How Can I Trust My ML Predictions?

Machine learning models can be extremely useful for solving critical business problems, but the nature of many modeling techniques can raise more questions than answers if great care is not taken in designing the model and interpreting its outputs.

In this blog, we will look at some data science best practices that we use to create reliable models that perform well in a real-world context. Then we’ll look at how to use model explanations both to understand what drives individual predictions and to see how much each feature matters to the model overall.

Isn’t ML Modeling a Black Box?

Many people treat machine learning models as a black box: data goes in and answers come out, with no need to understand what is happening inside the box. Some modeling techniques produce very intuitive results. A linear regression model, for example, simply multiplies the inputs by a set of numbers and adds them together. More complex models like neural networks or gradient-boosted decision trees can capture much richer patterns, but it is much harder to tell how they arrive at individual predictions.

Luckily, even the most complex model is just using statistics to find trends in the data, and we can leverage the same understanding of statistics to probe into the most complex black box and find out what is going on.

Best Practices for Creating Models

Before we get to understanding the behavior of a trained model, let’s review a few best practices for creating one.

Create a Dataset That Looks Like the Real World

Metrics like accuracy, precision, and recall help data scientists describe how well a model performs on a set of data. To know how a model will perform once it is deployed, these metrics must be calculated on a test dataset that accurately represents the data the model will receive in the real world.

For example, if you create a model that predicts housing prices and achieves an average error of $200 on houses in Iowa, you should not expect the same accuracy on predictions in New York City. So, if you expect your model to be used in every state, metrics should be calculated on a mix of data from each state.

In statistics, selection bias is the term for when the way a dataset was collected makes it unrepresentative of the reality it is meant to describe.
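If the data you collected already reflects the real-world mix (for example, the share of houses from each state), a stratified split keeps that mix intact in the test set. The sketch below assumes scikit-learn and pandas and uses a small synthetic table with a hypothetical `state` column. Stratifying only preserves the proportions already present in your data; it cannot correct a dataset that was collected with selection bias.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical, synthetic housing records: one row per house, tagged with its state.
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "state": rng.choice(["IA", "NY", "CA", "TX"], size=1_000, p=[0.4, 0.2, 0.3, 0.1]),
    "sqft": rng.normal(1_800, 400, size=1_000),
})

# stratify= keeps each state's share of rows (nearly) identical in the train and
# test sets, so test metrics reflect the whole mix of states rather than one region.
train_df, test_df = train_test_split(
    df, test_size=0.25, stratify=df["state"], random_state=42
)

print(df["state"].value_counts(normalize=True).round(2))
print(test_df["state"].value_counts(normalize=True).round(2))
```

Once a model is trained, computing the error separately for each state on the test set (for example with `test_df.groupby("state")`) is a quick way to spot a region where the model performs much worse than the others.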

Split the Data

Now that we have a representative sample of data, we need to make sure we are not overconfident in the model’s abilities. Most complex models tend to tailor their predictions to the dataset they are trained on, so a portion of the data, our test dataset, must be held back from training and used only for calculating our metrics.

You can imagine that if the training and testing datasets were identical, a sufficiently complex model could memorize all of the answers and appear to get very good results. As soon as the model sees a question it has not memorized, though, it will start to fail.

This is called overfitting: the model’s predictions do not generalize well beyond the data it has seen. By focusing on metrics calculated on the held-out test dataset, we get a much more realistic estimate of how the model will actually perform on new data.
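The memorization effect is easy to reproduce. This sketch (synthetic data, scikit-learn assumed) fits an unrestricted decision tree, which keeps splitting until it has effectively memorized its training labels, and compares its accuracy on the training data with its accuracy on a held-out test set.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data; flip_y=0.2 assigns the class of roughly 20% of samples at random,
# so no model can legitimately be close to 100% accurate on unseen data.
X, y = make_classification(n_samples=2_000, n_features=20, n_informative=5,
                           flip_y=0.2, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

# With no depth limit, the tree splits until every training leaf is pure.
tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

train_acc = accuracy_score(y_train, tree.predict(X_train))
test_acc = accuracy_score(y_test, tree.predict(X_test))
print(f"train accuracy: {train_acc:.2f}")
print(f"test accuracy:  {test_acc:.2f}")
```

Expect perfect training accuracy and noticeably lower test accuracy; that gap is exactly the overfitting a held-out test set exists to expose.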

Data Leakage

There are some considerations to remember when choosing how to split your dataset into training and testing sets. Let’s start with an example.

If you want to predict when it is a good time to buy or sell stocks based on the past week of behavior, you might have a record for each day over the last few years.

If the record for January 1st, 2021 is in the training data and the record for January 2nd, 2021 is in the testing data, you have data leakage: the two records’ one-week windows share six of their seven days, so data the model was trained on “leaks” into the test dataset. To prevent this for predictions that are tied to a specific time, it is common to train and test on completely distinct time periods.
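A minimal sketch of a leakage-free split for the stock example, assuming pandas and a made-up daily price series. Each record's features are the trailing seven days of prices, so besides splitting by date, the sketch drops the first six days after the cutoff; otherwise the earliest test windows would still overlap the last training windows, just as in the January example.

```python
import numpy as np
import pandas as pd

WINDOW = 7  # each record uses the trailing week of prices as its features

# Hypothetical daily price series (a random walk), one value per calendar day.
dates = pd.date_range("2020-01-01", "2021-12-31", freq="D")
rng = np.random.default_rng(0)
prices = pd.Series(100 + rng.normal(0, 1, len(dates)).cumsum(), index=dates)

# One row per day: lag_0 is that day's price, lag_6 is the price six days earlier.
# (The prediction target is omitted to keep the sketch short.)
features = pd.concat(
    {f"lag_{k}": prices.shift(k) for k in range(WINDOW)}, axis=1
).dropna()

# Chronological split: train on records before the cutoff, test on later ones,
# skipping the first WINDOW - 1 days after the cutoff so that no test window
# contains a day that also appears in a training window.
cutoff = pd.Timestamp("2021-07-01")
train = features[features.index < cutoff]
test = features[features.index >= cutoff + pd.Timedelta(days=WINDOW - 1)]

print(train.index.max().date(), test.index.min().date())  # → 2021-06-30 2021-07-07
```

If the target looks ahead (say, the next day's price), the gap has to grow accordingly. For cross validation on time series, scikit-learn's `TimeSeriesSplit` accepts a `gap` argument for the same purpose.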

If the task is to predict a user’s likelihood to make a purchase on a website, you might want to make sure all of the records for each specific user end up on the same side of the split, so that the model isn’t rewarded for memorizing a specific user’s behavior instead of learning generalizable trends. 
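For the per-user case, scikit-learn's `GroupShuffleSplit` keeps every record for a given group on the same side of the split. The user IDs, features, and labels below are randomly generated placeholders.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical session log: 500 records from 50 users (about 10 records each).
rng = np.random.default_rng(0)
n_records = 500
user_ids = rng.integers(0, 50, size=n_records)
X = rng.normal(size=(n_records, 3))      # placeholder session features
y = rng.integers(0, 2, size=n_records)   # placeholder "made a purchase" labels

# GroupShuffleSplit assigns whole groups (users) to one side of the split;
# test_size=0.2 means roughly 20% of the users, not 20% of the rows.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(splitter.split(X, y, groups=user_ids))

shared_users = set(user_ids[train_idx]) & set(user_ids[test_idx])
print(f"users in both sets: {len(shared_users)}")  # → users in both sets: 0
```

For cross validation, `GroupKFold` applies the same rule to every fold.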

Cross Validation

When you try many different model parameters (hyperparameter optimization) or even different algorithms (logistic regression, decision trees, neural networks, etc.), another safeguard against overfitting is some form of cross validation.

Cross validation estimates a model’s performance by repeatedly training a given model setup on one portion of the data and calculating metrics on the remainder.

Each iteration holds out a different portion for calculating the metrics, and the results are averaged across all iterations. The set of parameters with the best average score can then be used to train a final model on the whole training set.

Figure: Example of 4-fold cross validation.
In each iteration, ¼ of the data is left out of training and used for calculating metrics. Note that in most cases the samples should be shuffled before being assigned to folds to avoid bias (time-ordered data, as discussed above, is an exception).

When done correctly, this gives us much more confidence that the chosen model and its parameters are not overly tuned to the training data. Then, when our metrics are calculated on the testing dataset, we have a fair picture of how the model will perform on further unseen data.
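The 4-fold procedure combined with a hyperparameter search can be sketched with scikit-learn's `GridSearchCV`. The synthetic dataset and the decision-tree parameter grid are arbitrary choices for illustration. The important structure is that the test set is split off first, cross validation runs only on the training portion, and the test set is scored once at the end.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, KFold, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=2_000, n_features=20, n_informative=5,
                           flip_y=0.2, random_state=0)
# Hold the test set out first; cross validation only ever sees X_train.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

# 4 shuffled folds, matching the 4-fold example.
cv = KFold(n_splits=4, shuffle=True, random_state=0)
search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid={"max_depth": [2, 4, 8, None], "min_samples_leaf": [1, 5, 20]},
    cv=cv,
    scoring="accuracy",
)
# Each of the 12 settings is trained 4 times; with refit=True (the default),
# the best setting is then retrained on all of X_train.
search.fit(X_train, y_train)

print("best settings:", search.best_params_)
print(f"mean CV accuracy: {search.best_score_:.3f}")
print(f"test accuracy:    {search.score(X_test, y_test):.3f}")
```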

Data Drift

Keep an eye on how the distribution of the features being fed into the model compares to that of the original training data. If future data drifts too far from the original distribution, model performance will likely degrade. Updating the training set and redeploying new models when you notice significant drift can keep your models at peak performance.
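One common way to monitor drift (not necessarily the approach the author uses) is to compare each feature's recent production values against its training values with a two-sample Kolmogorov-Smirnov test. The sketch below assumes SciPy and uses synthetic, hypothetical income values; the significance level `alpha=0.01` is an arbitrary choice.

```python
import numpy as np
from scipy.stats import ks_2samp

# Hypothetical feature values: what the model was trained on vs. two production batches.
rng = np.random.default_rng(0)
train_income = rng.normal(loc=5.0, scale=1.5, size=5_000)
prod_same = rng.normal(loc=5.0, scale=1.5, size=1_000)     # same distribution
prod_shifted = rng.normal(loc=6.0, scale=1.5, size=1_000)  # mean has drifted upward

def drift_check(reference, current, alpha=0.01):
    """Two-sample Kolmogorov-Smirnov test. Returns the KS statistic and whether
    the p-value is below alpha (i.e. the samples look like different distributions)."""
    result = ks_2samp(reference, current)
    return result.statistic, result.pvalue < alpha

print(drift_check(train_income, prod_same))
print(drift_check(train_income, prod_shifted))
```

With large batches the test will flag even tiny, harmless shifts, so in practice it can help to also require the KS statistic itself to exceed a threshold that matters for your use case.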

What is Model Explainability, and Why Does it Matter?

After you create your model, you might still have questions about what it is actually doing.

  • What features are contributing to each prediction? 
  • What features are most important overall?
  • Do the predictions make any sense at all?
  • Does the algorithm have biases against particular groups of people?

These are important questions for businesses, both for internal understanding and for compliance with government regulations. Some governments recognize a Right to Explanation for individuals whose lives are significantly affected by decisions based on the output of an algorithm, so you might be required to explain any prediction your model makes. Answering these questions can also teach you a lot about your business.

Some traditional methods of inspecting a model depend heavily on the algorithm at play. A linear model gives coefficients, but a large coefficient doesn’t always signal an important feature: coefficients depend on each feature’s scale, and a feature can have a large coefficient while being very sparsely populated. Tree-based models can report gain or coverage, which summarize how much a feature’s splits improve the model or how many samples those splits affect, but say nothing about whether the feature pushes predictions up or down. When you want to compare models from drastically different families, you need a model-agnostic approach.
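A quick way to see why raw coefficients can mislead: rescaling a feature changes its coefficient without changing the model's predictions at all. The sketch below uses synthetic, hypothetical data with scikit-learn's `LinearRegression`; the same income is expressed once in dollars and once in thousands of dollars.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical data: price depends on income and number of rooms.
rng = np.random.default_rng(0)
income_dollars = rng.normal(60_000, 15_000, size=500)
rooms = rng.normal(5, 1, size=500)
price = 3.0 * income_dollars + 20_000 * rooms + rng.normal(0, 10_000, size=500)

X_dollars = np.column_stack([income_dollars, rooms])
X_thousands = np.column_stack([income_dollars / 1_000, rooms])  # same information, new units

m_dollars = LinearRegression().fit(X_dollars, price)
m_thousands = LinearRegression().fit(X_thousands, price)

# The income coefficient grows 1,000x, yet both models make identical predictions.
print(m_dollars.coef_, m_thousands.coef_)
print(np.allclose(m_dollars.predict(X_dollars), m_thousands.predict(X_thousands)))  # → True
```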

My favorite framework for explaining model predictions is SHAP (SHapley Additive exPlanations). There is a full-featured Python library, shap, that provides methods for calculating and visualizing explanations for individual predictions, feature importance, and correlating features with outcomes.

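SHAP values help explain individual predictions by assigning each feature a value that measures how much that feature pushed the prediction above or below a baseline, the model’s average prediction over a background dataset. The values are additive: the baseline plus all of the feature contributions equals the prediction. The methods for calculating SHAP values are based on Shapley values from cooperative game theory, and the package comes with many ways to speed up calculations by using approximations or parallelization.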
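To make the game-theory idea concrete, here is a brute-force computation of Shapley values for a tiny, made-up model. It is a sketch for intuition only: it uses one particular way of "removing" a feature (replacing it with values from a background sample), and the shap library relies on much faster specialized or approximate algorithms rather than this exponential enumeration. The function names `shapley_values` and `f` are hypothetical.

```python
import itertools
import math

import numpy as np

def shapley_values(f, x, background):
    """Exact Shapley values for one instance x by brute force over feature subsets.

    The value of a coalition S is the model's average output when the features in S
    are fixed to x's values and the remaining features keep each background row's values.
    Cost grows as 2**n_features, so this is only practical for a handful of features.
    """
    n = len(x)

    def value(subset):
        data = background.copy()
        cols = list(subset)
        data[:, cols] = x[cols]
        return f(data).mean()

    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            for subset in itertools.combinations(others, size):
                phi[i] += weight * (value(subset + (i,)) - value(subset))
    return phi

# Hypothetical model: a linear term plus an interaction between features 1 and 2.
def f(X):
    return 2 * X[:, 0] + X[:, 1] * X[:, 2]

rng = np.random.default_rng(0)
background = rng.normal(size=(200, 3))
x = np.array([1.0, 2.0, 3.0])

phi = shapley_values(f, x, background)
baseline = f(background).mean()
# Additivity: baseline + sum of contributions equals the prediction f(x) = 8.
print(phi, baseline + phi.sum(), f(x[None, :])[0])
```

Because feature 0 enters the model only through the linear term, its Shapley value works out to exactly 2 times the difference between its value and its background mean, while the interaction's credit is split between features 1 and 2.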

For this example, let’s look at an explanation for a model that predicts median home prices for different areas of California using the California housing dataset. This plot shows the breakdown of how much each feature contributes to a single prediction.

The feature values for this record are shown on the left, and each feature’s contribution is shown as a colored arrow. This plot lets us easily explain how the model arrived at its prediction. Explanations like this could be built into a model’s response in production to help comply with Right to Explanation laws.

Next, we can take the individual contributions and see how they are distributed across the entire test dataset. To sort the features by importance, we can calculate the average absolute size of each feature’s contributions. Then, for each feature, we can plot the contribution of every record, showing where records are densely packed and coloring each point by the feature’s value. This plot can help a data scientist spot major issues with the model’s understanding of the data and see the relative importance of all the features at a glance.
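The ranking step is simple arithmetic on the matrix of SHAP values: take the mean of the absolute values in each column. The numbers below are made up purely to show the calculation; they are not results from the housing model.

```python
import pandas as pd

# Made-up SHAP values for illustration only: rows are records, columns are features.
shap_matrix = pd.DataFrame(
    [[0.8, -0.1, 0.05],
     [-0.6, 0.2, -0.05],
     [1.0, -0.3, 0.02]],
    columns=["MedInc", "AveOccup", "HouseAge"],
)

# Global importance = average absolute contribution, so positive and negative
# pushes both count toward a feature's importance instead of cancelling out.
importance = shap_matrix.abs().mean().sort_values(ascending=False)
print(importance)  # MedInc first, then AveOccup, then HouseAge
```

With the shap library, `shap.plots.bar(shap_values)` draws these averages directly for an Explanation object, and `shap.plots.beeswarm(shap_values)` orders its rows the same way by default.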

The last plot I want to highlight shows the shape of the relationship between a feature’s values and its impact on the model. You can see on the left that the model picked out a mostly linear relationship between median income and the model target, median home value.

The model’s dependence on average occupancy is more complex. Not only is there a gentle curve, but you can also see from the coloring that the size of the impact from occupancy depends on median income as well. From this, we can tell that the model picked up on nonlinear interactions between these two features. 


Do not forget that no model, however complex, is truly a black box. With the right tools, machine learning models can not only provide accurate predictions for the future but also provide valuable insights into your data.

If you are interested in doing data science in a robust, future-proof way, the data scientists at phData would love to help! 
