K-fold cross-validation (2024)

by Marco Taboga, PhD

Until now we have used the simplest of all cross-validation methods, which consists in testing our predictive models on a subset of the data (the test set) that has not been used for training or selecting the predictive models. This simple cross-validation method is sometimes called the holdout method.

There are more sophisticated cross-validation methods that allow us to obtain better predictive models, together with accurate estimates of an upper bound on the expected loss. In these methods, we perform multiple different partitions of the data into training, validation and test sets, we build a different predictive model for each partition, and finally we average the predictions made by the various models (so-called ensembling).

Here we introduce the most popular of these methods, called K-fold cross-validation.


Table of contents

  1. Folds

  2. Ensembling

  3. Advantages and disadvantages

  4. Python example

    1. Import the data

    2. Run the K-fold cross-validation on LightGBM boosted trees

    3. Put the model in production


Folds

The data is divided into $K$ subsets, called folds. Each fold contains (approximately) the same number of observations.

Then, for $k=1,\ldots,K$, we use the $k$-th fold (denoted by $F_k$) as the test set and use the remaining $K-1$ folds to train a predictive model $\hat{f}_k$.

For each fold, we compute the estimate of the expected loss

$$\hat{L}_k = \frac{1}{N_k}\sum_{i\in F_k} L\left(y_i,\hat{f}_k(x_i)\right)$$

where $x_i$ is a vector of inputs, $L\left(y_i,\hat{f}_k(x_i)\right)$ is the loss incurred by using $\hat{f}_k(x_i)$ as a forecast of the observed output $y_i$, and $N_k$ is the number of observations in the $k$-th fold.

Remark: the $K-1$ folds used to train a predictive model can be divided into training and validation sets if we need a validation set for model selection.

Ensembling

After training the $K$ models $\hat{f}_1,\ldots,\hat{f}_K$, we use their ensemble average as our final prediction

$$\hat{f}(x) = \frac{1}{K}\sum_{k=1}^{K}\hat{f}_k(x)$$

We then compute the average loss

$$\bar{L} = \frac{1}{K}\sum_{k=1}^{K}\hat{L}_k$$
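The procedure described so far (split into folds, train one model per fold, average the predictions, average the fold losses) can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: ordinary least squares stands in for whatever training routine is actually used, the loss is the squared error, the data are synthetic, and the function names `k_fold_ensemble` and `ensemble_predict` are hypothetical.

```python
import numpy as np

def k_fold_ensemble(x, y, n_folds=5, seed=0):
    """Train one least-squares model per fold; return the models and the fold losses."""
    rng = np.random.default_rng(seed)
    # Shuffle the observation indices and cut them into K folds of (almost) equal size
    folds = np.array_split(rng.permutation(len(y)), n_folds)
    models, fold_losses = [], []
    for k in range(n_folds):
        test_idx = folds[k]
        train_idx = np.concatenate([folds[j] for j in range(n_folds) if j != k])
        # Fit the k-th model on the remaining K-1 folds (assumption: OLS as the learner)
        coef, *_ = np.linalg.lstsq(x[train_idx], y[train_idx], rcond=None)
        models.append(coef)
        # Estimate the expected loss on the k-th fold (assumption: squared-error loss)
        fold_losses.append(np.mean((y[test_idx] - x[test_idx] @ coef) ** 2))
    return models, np.array(fold_losses)

def ensemble_predict(models, x_new):
    """Ensemble average: mean of the predictions of the K models."""
    return np.mean([x_new @ coef for coef in models], axis=0)

# Synthetic data, assumed for illustration only: linear signal plus noise
rng = np.random.default_rng(1)
x = rng.normal(size=(200, 3))
y = x @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=200)

models, fold_losses = k_fold_ensemble(x, y, n_folds=5)
average_loss = fold_losses.mean()   # average of the K fold losses
y_hat = ensemble_predict(models, x)  # final (ensemble-average) prediction
```

Each observation contributes to exactly one fold loss, while every model is fitted without seeing the fold on which it is evaluated.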

As previously discussed, this average is an estimate of an upper bound on the expected loss of the ensemble average (remember that the expected loss of the ensemble average equals the average expected loss of the models in the ensemble minus a correction term that measures the diversity of the models in the ensemble).
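For the squared-error loss, the decomposition recalled in the parenthesis can be written out explicitly. The block below is a sketch under that assumption (other loss functions require a different argument), writing $\hat{f}_k(x)$ for the prediction of the $k$-th model, $\hat{f}(x)$ for the ensemble average of the $K$ predictions, and $y$ for the observed output.

```latex
\begin{align*}
\left(y-\hat{f}(x)\right)^2
&= \frac{1}{K}\sum_{k=1}^{K}\left(y-\hat{f}_k(x)\right)^2
 - \frac{1}{K}\sum_{k=1}^{K}\left(\hat{f}_k(x)-\hat{f}(x)\right)^2 \\
&\le \frac{1}{K}\sum_{k=1}^{K}\left(y-\hat{f}_k(x)\right)^2
\end{align*}
```

The subtracted term is the diversity of the models around their average. It is never negative, so taking expectations on both sides shows that the average expected loss of the single models is an upper bound on the expected loss of the ensemble average.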

Advantages and disadvantages

K-fold cross-validation is straightforward to implement: once we have a routine for training a predictive model, we just run it $K$ times on the different partitions of the data. The only real disadvantage is the computational cost.

As a reward for facing an increased computational cost, we have two main advantages:

  • our final model (the ensemble average) has been trained on all the data (no data have been "wasted") and it enjoys some benefits from ensembling;

  • although we no longer have an unbiased estimate of the expected loss of the final model, the average loss across the folds is an estimate of an upper bound on it, which uses all the data in the sample and should therefore be quite precise.
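Both advantages rest on the fact that every observation is used for testing exactly once and for training K-1 times. A minimal check of this property with scikit-learn's `KFold` splitter is sketched below; the sample size of 500 and the choice of 5 folds are assumptions made for the illustration.

```python
import numpy as np
from sklearn.model_selection import KFold

n_obs, n_folds = 500, 5  # assumed values, for illustration only
k_fold = KFold(n_splits=n_folds, shuffle=True, random_state=0)

test_counts = np.zeros(n_obs, dtype=int)   # times each observation is used for testing
train_counts = np.zeros(n_obs, dtype=int)  # times each observation is used for training

# KFold.split only needs the number of rows, so a placeholder array is enough
for train_index, test_index in k_fold.split(np.zeros((n_obs, 1))):
    test_counts[test_index] += 1
    train_counts[train_index] += 1

print(np.unique(test_counts), np.unique(train_counts))
```

No observation is left out of the loss estimate, and no observation is left out of the training of the ensemble.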

Python example

Let us use K-fold cross-validation to improve on the simpler holdout cross-validation performed when we built a single boosted tree to predict the output variable in an artificially-generated data set.

Import the data

    # Import the packages used to load and manipulate the data
    import numpy as np  # NumPy is a Matlab-like package for array manipulation and linear algebra
    import pandas as pd  # Pandas is a data-analysis and table-manipulation tool
    import urllib.request  # urllib will be used to download the dataset

    # Import the function that performs sample splits from scikit-learn
    from sklearn.model_selection import train_test_split

    # Load the output variable with pandas (download with urllib if not downloaded previously)
    remoteAddress = 'https://www.statlect.com/ml-assets/y_artificial.csv'
    localAddress = './y_artificial.csv'
    try:
        y = pd.read_csv(localAddress, header=None)
    except FileNotFoundError:
        # The local copy does not exist yet: download it, then read it
        urllib.request.urlretrieve(remoteAddress, localAddress)
        y = pd.read_csv(localAddress, header=None)
    y = y.values  # Transform y into a numpy array

    # Print some information about the output variable
    print('Class and dimension of output variable:')
    print(type(y))
    print(y.shape)

    # Load the input variables with pandas
    remoteAddress = 'https://www.statlect.com/ml-assets/x_artificial.csv'
    localAddress = './x_artificial.csv'
    try:
        x = pd.read_csv(localAddress, header=None)
    except FileNotFoundError:
        urllib.request.urlretrieve(remoteAddress, localAddress)
        x = pd.read_csv(localAddress, header=None)
    x = x.values

    # Print some information about the input variables
    print('Class and dimension of input variables:')
    print(type(x))
    print(x.shape)

The output is:

    Class and dimension of output variable:
    <class 'numpy.ndarray'>
    (500, 1)
    Class and dimension of input variables:
    <class 'numpy.ndarray'>
    (500, 300)

Run the K-fold cross-validation on LightGBM boosted trees

We create 5 folds using the KFold class provided by the scikit-learn package.

Then, we use LightGBM to make predictions on each fold.

    # Import the LightGBM package
    import lightgbm as lgb

    # Import the functions that perform sample splits from scikit-learn
    from sklearn.model_selection import train_test_split, KFold

    # Import model-evaluation metrics from scikit-learn
    from sklearn.metrics import mean_squared_error, r2_score

    # Set number of folds and ensemble variables
    n_folds = 5
    ensemble = []
    mses_single_models = []
    mses_constant_predictions = []
    r_squareds_single_models = []

    # Initialize K-fold splitter
    K_fold = KFold(n_splits=n_folds, random_state=0, shuffle=True)

    # Iterate over folds
    for train_val_index, test_index in K_fold.split(x):

        # Get train_val (K-1 folds) and test (1 fold)
        x_train_val, x_test = x[train_val_index], x[test_index]
        y_train_val, y_test = y[train_val_index], y[test_index]

        # Partition the train_val set
        x_train, x_val, y_train, y_val = train_test_split(
            x_train_val, y_train_val, test_size=0.25, random_state=0)

        # Prepare dataset in LightGBM format
        y_train = np.squeeze(y_train)
        y_val = np.squeeze(y_val)
        train_set = lgb.Dataset(x_train, y_train)
        valid_set = lgb.Dataset(x_val, y_val)

        # Set algorithm parameters
        params = {
            'objective': 'regression',
            'learning_rate': 0.10,
            'metric': 'mse',
            'nthread': 8,
            'min_data_in_leaf': 10,
            'max_depth': 2,
            'verbose': -1
        }

        # Train the model. Early stopping is requested through a callback:
        # LightGBM 4.0 removed the early_stopping_rounds and verbose_eval
        # arguments of lgb.train and the silent argument of lgb.Dataset.
        boosted_tree = lgb.train(
            params=params,
            train_set=train_set,
            valid_sets=[valid_set],
            num_boost_round=10000,
            callbacks=[lgb.early_stopping(stopping_rounds=20, verbose=False)],
        )

        # Save the model in the ensemble list
        ensemble.append(boosted_tree)

        # Make predictions on test and compute performance metrics
        y_test_pred = boosted_tree.predict(x_test)
        mses_single_models.append(mean_squared_error(y_test, y_test_pred))
        mses_constant_predictions.append(mean_squared_error(y_test, 0*y_test + np.mean(y_train)))
        r_squareds_single_models.append(r2_score(y_test, y_test_pred))

    # Print performance metrics on test sample
    print('Test MSEs of models in the ensemble:')
    print(mses_single_models)
    print('Test MSEs of constant predictions equal to sample mean on training set:')
    print(mses_constant_predictions)
    print('Average test MSE of models in the ensemble:')
    print(np.mean(mses_single_models))
    print('')
    print('Test R squareds of models in the ensemble:')
    print(r_squareds_single_models)
    print('Average test R squared of models in the ensemble:')
    print(np.mean(r_squareds_single_models))

The output is:

    Test MSEs of models in the ensemble:
    [74.57708434382333, 37.439965145009516, 16.801894170417487, 54.55429632149937, 30.007456926513534]
    Test MSEs of constant predictions equal to sample mean on training set:
    [293.2026939915153, 167.6726253022108, 104.80130552664033, 213.67954000052373, 131.20844137862687]
    Average test MSE of models in the ensemble:
    42.67613938145264

    Test R squareds of models in the ensemble:
    [0.7414010556669601, 0.7762513029977958, 0.8356878267813256, 0.7416211645498416, 0.7712684886703587]
    Average test R squared of models in the ensemble:
    0.7732459677332564

There is significant variation in test mean squared errors (MSEs) across the folds, although it is mostly due to differences in the variance of the output variable across folds (as shown by the test MSEs of the constant predictions). The R squareds are more homogeneous, that is, the proportion of variance explained by the predictions is more stable across folds.
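The claim can be checked directly on the printed figures. The sketch below compares the relative dispersion (coefficient of variation) of the test MSEs with that of the share of the constant-prediction MSE removed by each model. The helper name `coefficient_of_variation` is hypothetical, and the ratio computed here is close to, but not the same as, the R squared reported by scikit-learn, which takes the mean of the test fold rather than of the training set as its baseline.

```python
from statistics import mean, pstdev

# Figures copied from the printed output of the K-fold run
mses_models = [74.57708434382333, 37.439965145009516, 16.801894170417487,
               54.55429632149937, 30.007456926513534]
mses_constant = [293.2026939915153, 167.6726253022108, 104.80130552664033,
                 213.67954000052373, 131.20844137862687]

# Share of the constant-prediction MSE that each model removes on its test fold
explained_share = [1 - m / c for m, c in zip(mses_models, mses_constant)]

def coefficient_of_variation(values):
    """Population standard deviation divided by the mean."""
    return pstdev(values) / mean(values)

cv_mse = coefficient_of_variation(mses_models)
cv_share = coefficient_of_variation(explained_share)
print(cv_mse, cv_share)
```

The relative dispersion of the explained shares is roughly ten times smaller than that of the raw MSEs, which is what the comparison of R squareds suggests.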

Anyway, the variability in test MSEs reveals that it was probably a good idea to run a K-fold cross-validation.

At this stage, it is not possible to say anything more precise about the benefits from using K-fold cross-validation instead of the simple holdout method (although there are good theoretical guarantees about them). The benefits are likely to become apparent when we put the model in production, which we simulate below.

Put the model in production

We now see how our predictive model performs in production, on new data that becomes available after we have completed the training.

    # Load the input and output variables with pandas
    y_production = pd.read_csv('./assets/y_artificial_production.csv', header=None)
    y_production = y_production.values
    y_production = np.squeeze(y_production)
    x_production = pd.read_csv('./assets/x_artificial_production.csv', header=None)
    x_production = x_production.values

    # Initialize ensemble variables on production dataset
    y_production_pred_ensemble = 0
    production_mses_single_models = []
    production_r_squareds_single_models = []

    # Make predictions with all models in the ensemble and compute ensemble average
    for model in ensemble:
        y_production_pred = model.predict(x_production)
        production_mses_single_models.append(mean_squared_error(y_production, y_production_pred))
        production_r_squareds_single_models.append(r2_score(y_production, y_production_pred))
        y_production_pred_ensemble += y_production_pred/n_folds

    # Print MSEs
    print('Production MSEs of models in the ensemble:')
    print(production_mses_single_models)
    print('Average production MSE of models in the ensemble:')
    print(np.mean(production_mses_single_models))
    print('Production MSE of ensemble average:')
    print(mean_squared_error(y_production, y_production_pred_ensemble))
    print('')

    # Print R squareds
    print('Production R squareds of models in the ensemble:')
    print(production_r_squareds_single_models)
    print('Average production R squared of models in the ensemble:')
    print(np.mean(production_r_squareds_single_models))
    print('Production R squared of ensemble average:')
    print(r2_score(y_production, y_production_pred_ensemble))

The output is:

    Production MSEs of models in the ensemble:
    [39.64176656361246, 32.90628594145072, 37.6388748646386, 44.169307277678485, 43.86534683792454]
    Average production MSE of models in the ensemble:
    39.64431629706096
    Production MSE of ensemble average:
    33.64136466519085

    Production R squareds of models in the ensemble:
    [0.7639469310457051, 0.8040543987387033, 0.7758734614027009, 0.736986985185174, 0.73879696493291]
    Average production R squared of models in the ensemble:
    0.7639317482610387
    Production R squared of ensemble average:
    0.7996772580685613

This is an excellent result! The performance of the ensemble average is significantly better than the average performance of the models in the ensemble: the production MSE falls from 39.64 to 33.64, and the R squared rises from 0.764 to 0.800.
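The size of the gain from ensembling can be read off the printed production figures. The short computation below is only arithmetic on those numbers; the variable names are chosen for illustration.

```python
from statistics import mean

# Figures copied from the printed production output
production_mses = [39.64176656361246, 32.90628594145072, 37.6388748646386,
                   44.169307277678485, 43.86534683792454]
ensemble_mse = 33.64136466519085

average_mse = mean(production_mses)          # average MSE of the single models
diversity_gain = average_mse - ensemble_mse  # reduction in MSE obtained by averaging
relative_gain = diversity_gain / average_mse # same reduction, as a share of the average MSE

print(average_mse, diversity_gain, relative_gain)
```

Averaging the five models lowers the production MSE by about 6 points, or roughly 15% of the average MSE of the single models. This difference is the realized value of the diversity term that makes the average loss an upper bound.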

How to cite

Please cite as:

Taboga, Marco (2021). "K-fold cross-validation", Lectures on machine learning. https://www.statlect.com/machine-learning/k-fold-cross-validation.



