
Cross Validation

TimeGapSplit

We allow for a time series split that contains a gap.

You won't always need it, but you should consider it in these two situations:

  • If you have multiple samples per timestamp: you want to make sure that a timestamp doesn't appear in both the training and validation folds.
  • If your target looks \(x\) days ahead into the future, you cannot construct the target for the last \(x\) days of your available data. It means that when you put your model in production, the first day that you are going to score is always \(x\) days after your last training sample, so you should select the best model according to that setup.

    In other words, if you include those first \(x\) days in the validation fold, your metric might be overestimated because those days might be easier to predict, since they are closest to the training set. If you want to be strict in terms of robustness, you might want to replicate this real-world behaviour exactly in the CV, and thus introduce a gap of \(x\) days between your training and validation folds.

TimeGapSplit provides four parameters to reproduce your production setup in your cross-validation scheme. We demonstrate this in the code examples below.
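For instance, as a minimal sketch (the 7-day horizon and the durations here are made up for illustration, and df stands for a DataFrame with a date column like the one built in the examples below), a target that looks 7 days ahead would be mirrored like this:

from datetime import timedelta
from sklego.model_selection import TimeGapSplit

# Hypothetical numbers: the target looks 7 days ahead, so the validation
# fold starts 7 days after the last training sample, just like in production.
cv = TimeGapSplit(
    date_serie=df["date"],            # a pandas Series of dates, see below
    train_duration=timedelta(days=10),
    valid_duration=timedelta(days=5),
    gap_duration=timedelta(days=7),   # matches the 7-day look-ahead
)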

Examples

Let's make some random data to start with, and then define a plotting helper function.

from datetime import timedelta

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

from sklego.model_selection import TimeGapSplit

# Plotting helper function
def plot_cv(cv, X):
    """Plot all the folds on time axis"""
    X_index_df = cv._join_date_and_x(X)

    plt.figure(figsize=(16, 4))
    for i, split in enumerate(cv.split(X)):
        x_idx, y_idx = split
        x_dates = X_index_df.iloc[x_idx]["__date__"].unique()
        y_dates = X_index_df.iloc[y_idx]["__date__"].unique()
        plt.plot(x_dates, i*np.ones(x_dates.shape), c="steelblue")
        plt.plot(y_dates, i*np.ones(y_dates.shape), c="orange")

    plt.legend(("training", "validation"), loc="upper left")
    plt.ylabel("Fold id")
    plt.axvline(x=X_index_df["__date__"].min(), color="gray", label="x")
    plt.axvline(x=X_index_df["__date__"].max(), color="gray", label="d")

# Random data creation
df = (pd.DataFrame(np.random.randint(0, 30, size=(30, 4)), columns=list("ABCy"))
      .assign(date=pd.date_range(start="1/1/2018", end="1/30/2018")[::-1]))

print(df.shape)
# (30, 5)

print(df.head())
|   | A  | B  | C  | y  | date                |
|---|----|----|----|----|---------------------|
| 0 | 28 | 9  | 24 | 5  | 2018-01-30 00:00:00 |
| 1 | 5  | 0  | 19 | 1  | 2018-01-29 00:00:00 |
| 2 | 8  | 1  | 29 | 2  | 2018-01-28 00:00:00 |
| 3 | 11 | 4  | 21 | 19 | 2018-01-27 00:00:00 |
| 4 | 19 | 26 | 6  | 2  | 2018-01-26 00:00:00 |
Example 1
cv = TimeGapSplit(
    date_serie=df["date"],
    train_duration=timedelta(days=10),
    valid_duration=timedelta(days=2),
    gap_duration=timedelta(days=1)
)

plot_cv(cv, df)

example-1

Example 2
cv = TimeGapSplit(
    date_serie=df["date"],
    train_duration=timedelta(days=10),
    valid_duration=timedelta(days=5),
    gap_duration=timedelta(days=1)
)

plot_cv(cv, df)

example-2

window="expanding" is the closest to scikit-learn implementation:

Example 3
cv = TimeGapSplit(
    date_serie=df["date"],
    train_duration=timedelta(days=10),
    valid_duration=timedelta(days=2),
    gap_duration=timedelta(days=1),
    window="expanding"
)

plot_cv(cv, df)

example-3

If train_duration is not passed, the training duration is the maximum possible without overlapping the validation folds:

Example 4
cv = TimeGapSplit(
    date_serie=df["date"],
    train_duration=None,
    valid_duration=timedelta(days=3),
    gap_duration=timedelta(days=2),
    n_splits=3
)

plot_cv(cv, df)

example-4

If the train and validation durations would lead to an unwanted number of splits, n_splits can set a maximum number of splits:

Example 5
cv = TimeGapSplit(
    date_serie=df["date"],
    train_duration=timedelta(days=10),
    valid_duration=timedelta(days=2),
    gap_duration=timedelta(days=1),
    n_splits=4
)

plot_cv(cv, df)

example-5

Summary
cv.summary(df)
| Start date          | End date            | Period          | Unique days | nbr samples |
|---------------------|---------------------|-----------------|-------------|-------------|
| 2018-01-01 00:00:00 | 2018-01-10 00:00:00 | 9 days 00:00:00 | 10          | 10          |
| 2018-01-12 00:00:00 | 2018-01-13 00:00:00 | 1 days 00:00:00 | 2           | 2           |
| 2018-01-06 00:00:00 | 2018-01-15 00:00:00 | 9 days 00:00:00 | 10          | 10          |
| 2018-01-17 00:00:00 | 2018-01-18 00:00:00 | 1 days 00:00:00 | 2           | 2           |
| 2018-01-10 00:00:00 | 2018-01-19 00:00:00 | 9 days 00:00:00 | 10          | 10          |
| 2018-01-21 00:00:00 | 2018-01-22 00:00:00 | 1 days 00:00:00 | 2           | 2           |
| 2018-01-15 00:00:00 | 2018-01-24 00:00:00 | 9 days 00:00:00 | 10          | 10          |
| 2018-01-26 00:00:00 | 2018-01-27 00:00:00 | 1 days 00:00:00 | 2           | 2           |
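Since TimeGapSplit follows the scikit-learn splitter interface, it can also be passed directly as the cv argument of helpers such as cross_val_score. A minimal sketch (Ridge is just a stand-in model, not part of the original example; note that X has to keep the same index as date_serie so the folds can be joined back to the dates):

from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_val_score

X, y = df[["A", "B", "C"]], df["y"]

cv = TimeGapSplit(
    date_serie=df["date"],
    train_duration=timedelta(days=10),
    valid_duration=timedelta(days=2),
    gap_duration=timedelta(days=1),
)

# one score per fold, folds generated by TimeGapSplit
scores = cross_val_score(Ridge(), X, y, cv=cv)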

GroupTimeSeriesSplit

In a time series problem it is possible that not every time unit (e.g. year) has the same number of rows/observations. This makes a normal k-fold split impractical: you cannot specify a fixed timeframe per fold (e.g. 5 years), because the fold sizes can end up very different.

With GroupTimeSeriesSplit you can specify the number of folds you want (e.g. n_splits=3) and GroupTimeSeriesSplit will compute the folds itself in such a way that the number of observations per fold is as similar as possible.

The folds are created with a smartly modified brute-force method. This still means that for higher n_splits values combined with many unique time periods (e.g. 100 different years, thus 100 groups), generating the optimal split points can take minutes to hours.

Info

A UserWarning is raised when GroupTimeSeriesSplit expects to run for longer than a minute. Of course, the actual runtime depends on your machine's specifications.

Examples

First let's create an example data set:

import numpy as np
import pandas as pd

X = np.random.randint(low=1, high=1000, size=17)
y = np.random.randint(low=1, high=1000, size=17)
groups = np.array([2000,2000,2000,2001,2002,2002,2003,2004,2004,2004,2004,2004,2005,2005,2006,2006,2007])

df = pd.DataFrame(np.vstack((X,y)).T, index=groups, columns=['X','y'])
df.head(10)
|      | X   | y   |
|------|-----|-----|
| 2000 | 583 | 481 |
| 2000 | 414 | 617 |
| 2000 | 669 | 627 |
| 2001 | 812 | 604 |
| 2002 | 800 | 248 |
| 2002 | 966 | 503 |
| 2003 | 719 | 650 |
| 2004 | 476 | 939 |
| 2004 | 743 | 170 |
| 2004 | 142 | 893 |

Create a GroupTimeSeriesSplit cross-validator with n_splits=3 folds:

from sklego.model_selection import GroupTimeSeriesSplit
cv = GroupTimeSeriesSplit(n_splits=3)

def print_folds(cv, X, y, groups):
    for kfold, (train, test) in enumerate(cv.split(X, y, groups)):
        print(f"Fold {kfold+1}:")
        print(f"Train = {df.iloc[train].index.tolist()}")
        print(f"Test = {df.iloc[test].index.tolist()}\n\n")

print_folds(cv, X, y, groups)
Fold 1:
Train = [2000, 2000, 2000, 2001]
Test = [2002, 2002, 2003]


Fold 2:
Train = [2002, 2002, 2003]
Test = [2004, 2004, 2004, 2004, 2004]


Fold 3:
Train = [2004, 2004, 2004, 2004, 2004]
Test = [2005, 2005, 2006, 2006, 2007]

grp-ts-split

As you can see above, GroupTimeSeriesSplit keeps the time order chronological and makes sure that the same time value won't appear in both the train and test set of the same fold.

GroupTimeSeriesSplit also has a .summary() method, which shows which time values are grouped together. Because the train and test folds need to be in chronological order, the number of groups is always n_splits + 1 (see the four groups in the image above with n_splits=3).

Summary
cv.summary()
| index | observations | group | obs_per_group | ideal_group_size | diff_from_ideal_group_size |
|-------|--------------|-------|---------------|------------------|----------------------------|
| 2000  | 3            | 0     | 4             | 4                | 0                          |
| 2001  | 1            | 0     | 4             | 4                | 0                          |
| 2002  | 2            | 1     | 3             | 4                | -1                         |
| 2003  | 1            | 1     | 3             | 4                | -1                         |
| 2004  | 5            | 2     | 5             | 4                | 1                          |
| 2005  | 2            | 3     | 5             | 4                | 1                          |
| 2006  | 2            | 3     | 5             | 4                | 1                          |
| 2007  | 1            | 3     | 5             | 4                | 1                          |

To use GroupTimeSeriesSplit with sklearn's GridSearchCV:

from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

# reshape X for the Lasso model
X = X.reshape(-1,1)

# initiate the cross validator
cv = GroupTimeSeriesSplit(n_splits=3)

# generate the train-test splits
cv_splits = cv.split(X=X, y=y, groups=groups)

# create a pipeline with the Lasso model
pipe = Pipeline([("reg", Lasso(random_state=0, tol=0.1))])


# initiate GridSearchCV with cv_splits as the cv parameter
alphas = [0.1, 0.5, 0.8]
grid = GridSearchCV(pipe, {"reg__alpha": alphas}, cv=cv_splits)
grid.fit(X, y)
grid.best_estimator_.get_params()["reg__alpha"]
# 0.8
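Alternatively, instead of pre-computing cv_splits you can hand the splitter itself plus the groups to cross_val_score, which forwards groups to GroupTimeSeriesSplit.split and builds the folds on the fly. This is a sketch, not part of the original example:

from sklearn.model_selection import cross_val_score

# pass the cross-validator object and the groups directly
scores = cross_val_score(pipe, X, y, groups=groups, cv=GroupTimeSeriesSplit(n_splits=3))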

Cluster-Kfold

The ClusterFoldValidation object is a cross-validator that splits the data into n_splits folds, where each fold is determined by a clustering algorithm. This is not a common pattern, probably more like an anti-pattern really, but it might be useful when you want to make sure that the train and test sets are very distinct. This can be seen as a way to make it harder for the algorithm to perform well, because the training sets are sampled differently than the test sets.

Example

Here's how you could set up a cross validator that uses KMeans.

Using Kmeans to generate folds
from sklego.model_selection import ClusterFoldValidation
from sklearn.cluster import KMeans

clusterer = KMeans(n_clusters=5, random_state=42)
folder = ClusterFoldValidation(clusterer)
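The clusterer does not have to be KMeans: any scikit-learn clusterer that can assign a label to each sample (i.e. exposes fit and predict) should work the same way. A small sketch with MiniBatchKMeans, which is our own substitution and not part of the original docs:

from sklearn.cluster import MiniBatchKMeans
from sklego.model_selection import ClusterFoldValidation

# same idea, just a different (assumed compatible) clustering algorithm
folder = ClusterFoldValidation(MiniBatchKMeans(n_clusters=5, random_state=42))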

You can also use other clustering methods, but the nice thing about KMeans is that it demos well. Here's how it would generate folds on a uniform dataset.

Using Kmeans to generate folds
import matplotlib.pylab as plt
import numpy as np

X_orig = np.random.uniform(0, 1, (1000, 2))
for i, split in enumerate(folder.split(X_orig)):
    x_train, x_valid = split
    plt.scatter(X_orig[x_valid, 0], X_orig[x_valid, 1], label=f"split {i}")
plt.legend();

example-1

As you can see, each split will focus on a cluster of the data. Hopefully this also makes it clear that this method will ensure that each validation set will be rather distinct from the train set. These sets are not only exclusive, but they are also from a different region of the data by design.

Note that this image is mostly for illustrative purposes because you typically won't directly generate these folds yourself. Instead you'd use a helper function like cross_val_score or GridSearchCV to do this for you.

More realistic example
from sklearn.model_selection import cross_val_score

# Given an existing pipeline and X,y dataset, you probably would do something like this:
fold_method = ClusterFoldValidation(
    KMeans(n_clusters=5, random_state=42)
)
cross_val_score(pipeline, X, y, cv=fold_method)