Generalization Power of Machine Learning Algorithms

By Data Science Salon

If you have built machine learning pipelines, you have likely faced questions like "How well does this model generalize to unseen data?".

One of the most critical requirements of a successful machine learning model is that it generalizes to previously unseen data. New incoming data in a production pipeline will be predicted well only if it comes from a distribution similar to the training data. This is the property of generalization of machine learning solutions.

In this post, we will understand what generalization means and why it is important. We will also learn how to measure the generalization power of machine learning models and ensure they continue to perform well on unseen data.


What is Generalization and Why Is It Important?

Generalization is defined as the "model's ability to adapt properly to new, previously unseen data, drawn from the same distribution as the one used to create the model." The model learns patterns from the historical data used to train it, and performance on that data serves as a proxy estimate of future performance. This fundamental property highlights that machine learning models are inherently probabilistic and can only be expected to perform well as long as the unseen data comes from a distribution similar to the training data. The power of generalization is closely related to overfitting, as explained in the next section.

Overfitting vs Underfitting

Generalization requires a model that neither overfits nor underfits the training set: it captures the general signal in the data while ignoring the noise. Such a model is complex enough to learn the underlying characteristics of the data without memorizing it. Whether this balance has been struck can be inspected with cross-validation, which trains and tests the model on different subsets of the data, as sketched below.
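As a minimal sketch, assuming scikit-learn and a synthetic stand-in for your training data, cross-validation can surface an overfit model by comparing average training-fold scores against held-out-fold scores:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

# Synthetic data standing in for your training set
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)

model = RandomForestClassifier(random_state=42)

# 5-fold cross-validation; a large gap between train and test scores hints at overfitting
cv_results = cross_validate(model, X, y, cv=5, return_train_score=True)
print("Mean train accuracy:", cv_results["train_score"].mean())
print("Mean test accuracy:", cv_results["test_score"].mean())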

Specific techniques such as lasso and ridge regularization, along with ensemble models, reduce overfitting. There is a caveat, however: reducing overfitting introduces bias into the model output. Developers therefore need to balance overfitting against underfitting, a balance commonly called the bias-variance trade-off.
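A hedged illustration, again assuming scikit-learn and synthetic regression data, of how ridge and lasso constrain coefficient size compared with plain least squares:

from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso, LinearRegression, Ridge
from sklearn.model_selection import cross_val_score

# Synthetic regression data with nearly as many features as samples
X, y = make_regression(n_samples=120, n_features=100, noise=10.0, random_state=0)

for name, model in [("OLS", LinearRegression()),
                    ("Ridge", Ridge(alpha=1.0)),
                    ("Lasso", Lasso(alpha=1.0))]:
    # Compare cross-validated R^2; the regularized models penalize large coefficients
    r2 = cross_val_score(model, X, y, cv=5, scoring="r2").mean()
    print(f"{name}: mean CV R^2 = {r2:.3f}")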

The post on the dogma of bias and variance explains this statistical property in detail and elaborates on how to select the best-performing model, one that accurately learns the regularities in the data and has good generalization power.

Importance of Choosing the Right Sample Set

It is entirely possible that your model has learned the statistical associations in the data very well, yet that quality is not reflected in the results. A large part of the machine learning literature focuses on models, and much less on the right way to generate a sample training set.

Curating a representative sample is as much an art as modeling itself.

Let's first understand why we need to work with a sample. A sample is a subset of the population of all the data that could, in principle, describe the phenomenon being modeled. A carefully created sample set is a prerequisite for representing the real-world phenomenon and building robust and reliable models.

Some of the common sampling methods are shared below:

Random Sampling

This is one of the most popular and simplest ways to create a sample: each record has a uniform probability of being selected from the population. The downside is that records from minority classes may be under-represented or missed entirely, as the sketch below illustrates. The next technique, stratified sampling, addresses this limitation.
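A minimal sketch with pandas; the population_df DataFrame and its imbalanced label column are hypothetical:

import pandas as pd

# Hypothetical, imbalanced population: 950 majority rows and 50 minority rows
population_df = pd.DataFrame({
    "feature": range(1000),
    "label": [0] * 950 + [1] * 50,
})

# Simple random sample: every row has the same probability of selection,
# so the minority class may be under-represented or missed entirely
random_sample = population_df.sample(n=100, random_state=42)
print(random_sample["label"].value_counts())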

Stratified Sampling

To ensure that the generated dataset contains samples from every class or category of interest, stratified sampling divides the population into strata (groups) based on shared characteristics, for example regions or tier-based categories, and then samples from each stratum.
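A short sketch with pandas, reusing the same hypothetical imbalanced population as above, that draws the same fraction from every class:

import pandas as pd

# Same hypothetical imbalanced population as in the random-sampling sketch
population_df = pd.DataFrame({
    "feature": range(1000),
    "label": [0] * 950 + [1] * 50,
})

# Draw 10% from each label group so every class is represented in the sample
stratified_sample = population_df.groupby("label").sample(frac=0.1, random_state=42)
print(stratified_sample["label"].value_counts())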

Weighted Sampling

Weighted sampling assigns a weight to each record. For example, if the model needs to be retrained on newly collected data, a data scientist can use domain knowledge to assign 70% of the weight to the large chunk of older data, for learning generic attribute associations, and 30% to the new data, to adapt to changing patterns.
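A hedged sketch of the 70/30 idea with pandas; the split into old and new data and the column names are assumptions:

import pandas as pd

# Hypothetical data: 800 older rows and 200 newly collected rows
old_df = pd.DataFrame({"value": range(800), "is_new": False})
new_df = pd.DataFrame({"value": range(200), "is_new": True})
data = pd.concat([old_df, new_df], ignore_index=True)

# Give the old block 70% of the total sampling weight and the new block 30%,
# spreading each block's share evenly across its rows
data["weight"] = data["is_new"].map({False: 0.7 / len(old_df), True: 0.3 / len(new_df)})

weighted_sample = data.sample(n=100, weights="weight", random_state=42)
print(weighted_sample["is_new"].value_counts(normalize=True))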

How to Check Whether Two Data-Distributions are the Same?

Now that we have discussed the assumptions and constraints of the ML framework along with the strategy to create the right sample, let's see how to identify whether the test data and train data belong to the same distribution.

The Kolmogorov-Smirnov (KS) test is used to measure whether two distributions are similar. The KS statistic is defined as "a distance between the empirical distribution function of the sample and the cumulative distribution function of the reference distribution". The test is non-parametric: it makes no assumption about the underlying distribution, and its two-sample variant directly compares two sets of data.

You can check whether two sample sets come from similar distributions in Python; first, import the stats module from SciPy as shown below:

from scipy import stats

# Two-sample KS test: returns the KS statistic and the p-value
statistic, p_value = stats.ks_2samp(sample1, sample2)


The null hypothesis states that the two sample distributions are the same; the test looks for evidence to reject this hypothesis and conclude that the distributions differ. If you choose a 95% confidence level (a 5% significance level), you can reject the null hypothesis when the p-value from the code above is less than 0.05.
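Continuing the snippet above, the decision rule at the 5% significance level could look like this:

# Reject the null hypothesis (same distribution) at the 5% significance level
alpha = 0.05
if p_value < alpha:
    print("The two samples likely come from different distributions.")
else:
    print("No evidence that the two distributions differ.")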

Working With Non-Stationary Data 

So far, we have discussed multiple ways to ensure that a model is robust and generalizes to real-time data in production. All of this assumes that the underlying data distribution stays the same, i.e. is stationary. What would you do if the data dynamics keep varying with time and require the model to adjust its weights to changing user behavior?

Online Learning

Online learning algorithms make no assumptions about the underlying distribution and are well suited to learning from non-stationary, time-varying data.

It is important to distinguish online learning from batch learning, a difference captured well in the following definition:

“Online machine learning is a method of machine learning in which data becomes available in a sequential order and is used to update the best predictor for future data at each step, as opposed to batch learning techniques which generate the best predictor by learning on the entire training data set at once.”

An online learner dynamically adapts to varying data distributions and patterns by updating the model weights as data arrives. The model observes its mistakes and keeps recalibrating the weights to produce correct predictions, as the sketch below illustrates.
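A minimal sketch of this idea, assuming scikit-learn's SGDClassifier and a hypothetical stream of mini-batches whose distribution drifts over time:

import numpy as np
from sklearn.linear_model import SGDClassifier

rng = np.random.default_rng(0)
model = SGDClassifier(random_state=0)
classes = np.array([0, 1])

# Hypothetical stream: both the inputs and the decision rule drift a little each step
for step in range(10):
    X_batch = rng.normal(loc=step * 0.1, size=(100, 5))
    y_batch = (X_batch[:, 0] > step * 0.1).astype(int)
    # partial_fit updates the weights incrementally instead of retraining from scratch
    model.partial_fit(X_batch, y_batch, classes=classes)

print("Coefficients after streaming updates:", model.coef_)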

Domain Adaptation

What would you do if the model learned to rely on a certain feature, or a group of features, from the training set, but those features are unavailable or frequently missing in the production data? In such a scenario, domain adaptation lowers the weight placed on those features, reducing the model's reliance on their availability to generate accurate predictions.
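Domain adaptation is a broad topic in its own right; as a heavily simplified, hypothetical stand-in for the idea, one can refit the model without the unreliable features so that it cannot depend on them (real domain-adaptation methods reweight rather than drop features):

import pandas as pd
from sklearn.linear_model import LogisticRegression

# Hypothetical training data where "feature_c" is frequently missing in production
train_df = pd.DataFrame({
    "feature_a": [0.1, 0.4, 0.3, 0.9, 0.7, 0.2],
    "feature_b": [1.0, 0.2, 0.5, 0.8, 0.1, 0.6],
    "feature_c": [0.0, 1.0, 0.0, 1.0, 1.0, 0.0],
    "label":     [0, 1, 0, 1, 1, 0],
})

unreliable = ["feature_c"]
X_train = train_df.drop(columns=["label"] + unreliable)  # remove the feature the model must not rely on
y_train = train_df["label"]

model = LogisticRegression().fit(X_train, y_train)
print(dict(zip(X_train.columns, model.coef_[0])))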

Summary

This post highlighted the importance of selecting the best-performing model by assessing its generalization power on unseen data. It illustrated how to compare two distributions using the KS statistic in Python, and presented the challenges of modeling a non-stationary data distribution along with possible ways to handle such a scenario.