sklearn.model_selection.PredefinedSplit

class sklearn.model_selection.PredefinedSplit(test_fold)

Predefined split cross-validator

Provides train/test indices to split data into train/test sets using a predefined scheme specified by the user with the test_fold parameter.

Read more in the User Guide.

New in version 0.16.

Parameters
test_fold : array-like, shape (n_samples,)

The entry test_fold[i] represents the index of the test set that sample i belongs to. It is possible to exclude sample i from any test set (i.e. include sample i in every training set) by setting test_fold[i] equal to -1.

Examples

>>> import numpy as np
>>> from sklearn.model_selection import PredefinedSplit
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
>>> y = np.array([0, 0, 1, 1])
>>> test_fold = [0, 1, -1, 1]
>>> ps = PredefinedSplit(test_fold)
>>> ps.get_n_splits()
2
>>> print(ps)
PredefinedSplit(test_fold=array([ 0,  1, -1,  1]))
>>> for train_index, test_index in ps.split():
...     print("TRAIN:", train_index, "TEST:", test_index)
...     X_train, X_test = X[train_index], X[test_index]
...     y_train, y_test = y[train_index], y[test_index]
TRAIN: [1 2 3] TEST: [0]
TRAIN: [0 2] TEST: [1 3]
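
A common use of this class is evaluating on a single, fixed validation set. The sketch below assumes toy arrays (X_train, X_val and their labels are invented for illustration) and uses DummyClassifier only as a stand-in estimator. Marking the training rows with -1 and the validation rows with 0 should give exactly one split, which can then be passed as the cv argument of cross_val_score.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import PredefinedSplit, cross_val_score

# Toy data, invented for illustration: 4 training rows and 2 validation rows.
X_train = np.array([[0.0], [1.0], [2.0], [3.0]])
y_train = np.array([0, 0, 0, 1])
X_val = np.array([[4.0], [5.0]])
y_val = np.array([0, 1])

X_all = np.concatenate([X_train, X_val])
y_all = np.concatenate([y_train, y_val])

# -1 keeps a sample in every training set; 0 assigns it to test fold 0.
test_fold = np.concatenate(
    [np.full(len(X_train), -1), np.zeros(len(X_val), dtype=int)]
)
ps = PredefinedSplit(test_fold)

# Only one fold id other than -1 exists, so there is a single split.
splits = list(ps.split())
train_idx, val_idx = splits[0]

# Stand-in estimator; any scikit-learn estimator could take its place.
clf = DummyClassifier(strategy="most_frequent")
scores = cross_val_score(clf, X_all, y_all, cv=ps)
```

Because the training rows are never assigned to a test fold, the estimator is fitted once on them and scored once on the validation rows.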

Methods

get_n_splits(self[, X, y, groups])

Returns the number of splitting iterations in the cross-validator.

split(self[, X, y, groups])

Generate indices to split data into training and test set.

__init__(self, test_fold)

Initialize self. See help(type(self)) for accurate signature.

get_n_splits(self, X=None, y=None, groups=None)

Returns the number of splitting iterations in the cross-validator.

Parameters
X : object

Always ignored, exists for compatibility.

y : object

Always ignored, exists for compatibility.

groups : object

Always ignored, exists for compatibility.

Returns
n_splits : int

Returns the number of splitting iterations in the cross-validator.

split(self, X=None, y=None, groups=None)

Generate indices to split data into training and test set.

Parameters
X : object

Always ignored, exists for compatibility.

y : object

Always ignored, exists for compatibility.

groups : object

Always ignored, exists for compatibility.

Yields
train : ndarray

The training set indices for that split.

test : ndarray

The testing set indices for that split.
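
To make the relationship between test_fold and the yielded indices concrete, here is a minimal NumPy sketch of the behaviour described above. The helper name iter_predefined_splits is hypothetical, and the code is not presented as scikit-learn's implementation. It assumes one split per distinct value of test_fold other than -1, visited in ascending order.

```python
import numpy as np


def iter_predefined_splits(test_fold):
    """Yield (train, test) index arrays for each fold id other than -1.

    Hypothetical helper written for illustration only.
    """
    test_fold = np.asarray(test_fold)
    indices = np.arange(len(test_fold))

    # Distinct fold ids in ascending order, with the "never test" marker removed.
    fold_ids = np.unique(test_fold)
    fold_ids = fold_ids[fold_ids != -1]

    for fold_id in fold_ids:
        test_mask = test_fold == fold_id
        # Every sample outside the current test fold (including -1 entries)
        # goes into the training set.
        yield indices[~test_mask], indices[test_mask]


splits = [
    (train.tolist(), test.tolist())
    for train, test in iter_predefined_splits([0, 1, -1, 1])
]
```

With test_fold = [0, 1, -1, 1] this reproduces the two splits shown in the Examples section, and the number of splits equals the number of distinct fold ids other than -1.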