# Source code for synrxn.split.repeated_kfold

from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple, List, Iterator, Union, Iterable, Any
import warnings

import numpy as np
import pandas as pd
from sklearn.model_selection import (
    KFold,
    StratifiedKFold,
    GroupKFold,
    train_test_split,
)
from sklearn.utils import check_random_state, indexable


@dataclass
class SplitIndices:
    """
    Container for indices of a single (repeat, fold) split.

    :param repeat: Repeat index (0-based).
    :param fold: Fold index within the repeat (0-based).
    :param train_idx: Numpy array of training row indices.
    :param val_idx: Numpy array of validation row indices.
    :param test_idx: Numpy array of test row indices.
    """

    repeat: int
    fold: int
    train_idx: np.ndarray
    val_idx: np.ndarray
    test_idx: np.ndarray

    def __eq__(self, other: object) -> bool:
        """
        Compare two splits by repeat, fold and the contents of their index arrays.

        Defined explicitly because the dataclass-generated ``__eq__`` compares the
        ndarray fields with ``==``. That produces element-wise arrays whose truth
        value is ambiguous, so comparing two distinct instances raised ``ValueError``.

        :param other: Object to compare against.
        :returns: True if both are SplitIndices with equal repeat/fold and equal
            train/val/test arrays; ``NotImplemented`` for other types.
        """
        if not isinstance(other, SplitIndices):
            return NotImplemented
        return (
            self.repeat == other.repeat
            and self.fold == other.fold
            and np.array_equal(self.train_idx, other.train_idx)
            and np.array_equal(self.val_idx, other.val_idx)
            and np.array_equal(self.test_idx, other.test_idx)
        )

    def __repr__(self) -> str:
        # Report sizes rather than full index arrays to keep the repr short.
        return (
            f"SplitIndices(repeat={self.repeat}, fold={self.fold}, "
            f"train={len(self.train_idx)}, val={len(self.val_idx)}, test={len(self.test_idx)})"
        )
class RepeatedKFoldsSplitter:
    """
    Repeated K-Fold splitter producing (train, val, test) for each outer fold.

    - sklearn-compatible ``split(X, y=None, groups=None, stratify=None)`` yields
      ``(train_idx, holdout_idx)`` where ``holdout_idx == val + test`` (useful for
      sklearn ``cross_validate``).
    - Use ``split_with_val(X, y=None, groups=None, stratify=None)`` to receive
      (train, val, test) triples as ``SplitIndices`` objects
      (repeat, fold, train_idx, val_idx, test_idx).

    The ``ratio`` argument is a (train, val, test) tuple. Its val and test entries
    control the proportion used when splitting each outer holdout into validation
    and test sets. For example, ``ratio=(8, 1, 1)`` splits the holdout val:test = 1:1.
    The train share itself is set by ``n_splits``: train is everything outside
    the outer holdout.

    Notes on stratification:

    - Pass ``y`` (array-like) to stratify by labels (sklearn-conventional).
    - Alternatively pass ``stratify`` to ``split(...)``, ``split_with_val(...)`` or
      ``prepare_splits(...)``. ``stratify`` may be:

      - a column name (str) when ``X`` is a pandas.DataFrame, or
      - an array-like of the same length as ``X``.

    - If both ``y`` and ``stratify`` are provided, ``stratify`` takes precedence.
    - If ``StratifiedKFold`` rejects the labels, the splitter falls back to
      non-stratified ``KFold`` and emits a warning. This happens, for example,
      when no class has ``n_splits`` members or the target is continuous.

    Notes on groups and shuffling:

    - If ``groups`` is given, ``GroupKFold`` is used for the outer split and label
      stratification is disabled. ``GroupKFold`` is built without shuffling, so
      every repeat gets the same outer folds, and the val/test split of a holdout
      does not respect groups.
    - With ``shuffle=False`` the outer folds are not shuffled, and no
      ``random_state`` is passed to them because sklearn forbids it. Every repeat
      therefore gets the same outer folds. The val/test split of each holdout is
      still randomised per repeat.

    Notes on caching:

    - Splits are computed on the first ``split``/``split_with_val`` call and reused
      by later calls on data with the same number of samples. Data with a
      different number of samples triggers a recomputation. ``prepare_splits``
      always recomputes.

    Example (Sphinx-style)::

        >>> from sklearn.datasets import make_classification
        >>> X, y = make_classification(n_samples=500, n_features=20, weights=[0.9, 0.1], random_state=0)
        >>> splitter = RepeatedKFoldsSplitter(n_splits=5, n_repeats=2, ratio=(8,1,1), shuffle=True, random_state=1)
        >>> # sklearn-style outer cross-validation (y is used for stratification)
        >>> for train_idx, hold_idx in splitter.split(X, y):
        ...     print(len(train_idx), len(hold_idx))
        >>> # explicit train/val/test
        >>> for s in splitter.split_with_val(X, stratify=y):
        ...     X_train, X_val, X_test = X[s.train_idx], X[s.val_idx], X[s.test_idx]

    :param n_splits: Number of outer folds (k). Must be >= 2.
    :param n_repeats: Number of repeats (how many times to reshuffle-and-split). Must be >= 1.
    :param ratio: Tuple of three positive ints (train, val, test) like (8, 1, 1).
    :param shuffle: Whether to shuffle before splitting each repeat.
    :param random_state: Base random state for reproducible repeats.
    :raises ValueError: If any of the parameters above is out of range.
    """

    def __init__(
        self,
        n_splits: int = 5,
        n_repeats: int = 1,
        ratio: Tuple[int, int, int] = (8, 1, 1),
        shuffle: bool = True,
        random_state: Optional[int] = None,
    ):
        if n_splits < 2:
            raise ValueError("n_splits must be >= 2")
        if n_repeats < 1:
            raise ValueError("n_repeats must be >= 1")
        if len(ratio) != 3:
            raise ValueError("ratio must have exactly three entries (train, val, test)")
        if any(int(r) <= 0 for r in ratio):
            raise ValueError("All entries of ratio must be positive integers")
        self.n_splits = int(n_splits)
        self.n_repeats = int(n_repeats)
        self.ratio = (int(ratio[0]), int(ratio[1]), int(ratio[2]))
        self.shuffle = bool(shuffle)
        self.random_state = random_state
        val_weight, test_weight = self.ratio[1], self.ratio[2]
        self._val_frac_within_holdout = val_weight / (val_weight + test_weight)
        # Computed directly rather than as ``1.0 - val_frac``. That subtraction can
        # round up by one ulp, and train_test_split applies ceil() to
        # ``test_size * n``, which would then move an extra sample into test.
        self._test_frac_within_holdout = test_weight / (val_weight + test_weight)
        # computed state
        self._splits: List[SplitIndices] = []
        # number of samples the cached splits were computed for (None = no cache)
        self._fit_n_samples: Optional[int] = None
        # original X (DataFrame or array-like) kept for get_split(as_frame=True)
        self._X_provided: Optional[Any] = None

    def __repr__(self) -> str:
        return (
            f"RepeatedKFoldsSplitter(n_splits={self.n_splits}, n_repeats={self.n_repeats}, "
            f"ratio={self.ratio}, generated_splits={len(self._splits)}, random_state={self.random_state})"
        )

    def __len__(self) -> int:
        return len(self._splits)

    # sklearn API

    def get_n_splits(self, X=None, y=None, groups=None) -> int:
        """
        Return how many (train, holdout) splits will be produced.

        :param X: Feature matrix or dataframe (ignored for counting).
        :param y: Labels (ignored for counting).
        :param groups: Groups (ignored for counting).
        :returns: Total number of outer splits (n_splits * n_repeats).
        """
        return self.n_splits * self.n_repeats

    def split(
        self,
        X: Any,
        y: Optional[Any] = None,
        groups: Optional[Any] = None,
        stratify: Optional[Union[str, Any]] = None,
    ) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
        """
        sklearn-compatible generator yielding (train_idx, holdout_idx), where
        holdout_idx == val + test.

        Being a generator, inputs are validated and splits computed when
        iteration starts.

        :param X: Feature matrix or pandas.DataFrame.
        :param y: Labels (array-like). Used for stratification if `stratify` is None.
        :param groups: Group labels for GroupKFold (optional).
        :param stratify: Column name (when X is a DataFrame) or array-like used to
            stratify folds (optional). Takes precedence over `y`.
        :yields: Tuples (train_idx, holdout_idx) for each repeat/fold.
        :raises ValueError: If `stratify` is a str and X is not a DataFrame, or
            the dataset is too small for `n_splits`.
        """
        X_arr, y_arr, groups_arr = self._prepare_inputs(X, y, groups, stratify)
        self._ensure_splits(X_arr, y_arr, groups_arr)
        for s in self._splits:
            # astype() already returns a fresh array, so no extra copy is needed
            holdout = np.concatenate([s.val_idx, s.test_idx]).astype(int)
            yield s.train_idx.copy(), holdout

    def split_with_val(
        self,
        X: Any,
        y: Optional[Any] = None,
        groups: Optional[Any] = None,
        stratify: Optional[Union[str, Any]] = None,
    ) -> Iterable[SplitIndices]:
        """
        Yield SplitIndices objects containing (train, val, test) indices.

        :param X: Feature matrix or pandas.DataFrame.
        :param y: Labels (array-like). Used for stratification if `stratify` is None.
        :param groups: Group labels for GroupKFold (optional).
        :param stratify: Column name or array-like used to stratify folds (optional).
        :yields: SplitIndices objects for each repeat and fold (the stored objects,
            not copies).
        """
        X_arr, y_arr, groups_arr = self._prepare_inputs(X, y, groups, stratify)
        self._ensure_splits(X_arr, y_arr, groups_arr)
        yield from self._splits

    def prepare_splits(
        self,
        X: Any,
        y: Optional[Any] = None,
        groups: Optional[Any] = None,
        stratify: Optional[Union[str, Any]] = None,
    ) -> None:
        """
        Compute and store all splits immediately, discarding any cached splits.

        After calling this, get_split(...), indexing and iter_splits() may be used.

        :param X: Feature matrix or pandas.DataFrame.
        :param y: Labels (array-like). Used for stratification if `stratify` is None.
        :param groups: Group labels for GroupKFold (optional).
        :param stratify: Column name or array-like used to stratify folds (optional).
        """
        X_arr, y_arr, groups_arr = self._prepare_inputs(X, y, groups, stratify)
        self._compute_splits(X_arr, y_arr, groups_arr)

    # internals

    def _prepare_inputs(
        self,
        X: Any,
        y: Optional[Any],
        groups: Optional[Any],
        stratify: Optional[Union[str, Any]],
    ) -> Tuple[Any, Optional[Any], Optional[Any]]:
        """
        Resolve `stratify`/`y` into one label array, make inputs indexable and
        remember X for get_split(as_frame=True).

        :returns: (X_arr, y_arr, groups_arr) as returned by sklearn's ``indexable``.
        :raises ValueError: If `stratify` is a str and X is not a DataFrame.
        """
        if isinstance(stratify, str):
            if not isinstance(X, pd.DataFrame):
                raise ValueError(
                    "When passing stratify as a column name (str), X must be a pandas.DataFrame"
                )
            labels = X[stratify].values
        elif stratify is not None:
            labels = np.asarray(stratify)
        else:
            # explicit stratify takes precedence over y
            labels = y
        X_arr, y_arr, groups_arr = indexable(X, labels, groups)
        self._X_provided = X
        return X_arr, y_arr, groups_arr

    @staticmethod
    def _n_samples(X_arr: Any) -> int:
        """Number of rows of X; uses ``shape[0]`` so sparse matrices work too."""
        shape = getattr(X_arr, "shape", None)
        if shape is not None and len(shape) > 0:
            return int(shape[0])
        return len(X_arr)

    def _ensure_splits(
        self, X_arr: Any, y_arr: Optional[Any], groups_arr: Optional[Any]
    ) -> None:
        """Compute splits unless cached ones exist for a dataset of the same size."""
        n = self._n_samples(X_arr)
        # Cached indices for a different-sized dataset would be out of range or
        # silently wrong, so they are recomputed.
        if not self._splits or self._fit_n_samples != n:
            self._compute_splits(X_arr, y_arr, groups_arr)

    def _compute_splits(
        self, X_arr: Any, y_arr: Optional[Any], groups_arr: Optional[Any]
    ) -> None:
        """
        Internal: compute and store all SplitIndices for repeats/folds.

        :param X_arr: indexable X (array-like).
        :param y_arr: indexable labels used for stratification (or None).
        :param groups_arr: groups (or None).
        :raises ValueError: If the dataset has fewer samples than `n_splits`, or a
            holdout is too small to split into validation and test sets.
        """
        n = self._n_samples(X_arr)
        if n < self.n_splits:
            raise ValueError(
                f"n_splits={self.n_splits} is larger than dataset size {n}"
            )

        # Drop previous results first so a failure below never leaves stale or
        # partially computed splits behind.
        self._splits = []
        self._fit_n_samples = None

        base_rs = check_random_state(self.random_state)
        max_seed = np.iinfo(np.int32).max
        # groups take priority over label stratification
        use_groups = groups_arr is not None
        use_stratify = (y_arr is not None) and not use_groups

        splits: List[SplitIndices] = []
        for r in range(self.n_repeats):
            # One seed per repeat drives the outer shuffle; `rs`, derived from the
            # same seed, drives the val/test split of every holdout in the repeat.
            rs_seed = int(base_rs.randint(max_seed))
            rs = check_random_state(rs_seed)
            holdouts, stratified = self._outer_holdouts(
                X_arr, y_arr, groups_arr, rs_seed, use_groups, use_stratify
            )
            y_values = np.asarray(y_arr) if stratified else None

            for fold_i, hold_idx in enumerate(holdouts):
                hold_idx = np.asarray(hold_idx, dtype=int)
                # train = every row not in the holdout (ascending order)
                mask = np.ones(n, dtype=bool)
                mask[hold_idx] = False
                train_idx = np.flatnonzero(mask)

                hold_targets = y_values[hold_idx] if y_values is not None else None
                val_idx, test_idx = self._split_holdout(
                    hold_idx, hold_targets, rs, r, fold_i
                )
                splits.append(
                    SplitIndices(
                        repeat=r,
                        fold=fold_i,
                        train_idx=np.asarray(train_idx, dtype=int),
                        val_idx=np.asarray(val_idx, dtype=int),
                        test_idx=np.asarray(test_idx, dtype=int),
                    )
                )

        self._splits = splits
        self._fit_n_samples = n

    def _outer_holdouts(
        self,
        X_arr: Any,
        y_arr: Optional[Any],
        groups_arr: Optional[Any],
        seed: int,
        use_groups: bool,
        use_stratify: bool,
    ) -> Tuple[List[np.ndarray], bool]:
        """
        Compute the outer holdout index arrays for one repeat.

        :returns: (list of holdout index arrays, whether they are label-stratified).
        """
        # sklearn raises ValueError if random_state is set while shuffle=False.
        seed_arg = seed if self.shuffle else None
        if use_groups:
            cv = GroupKFold(n_splits=self.n_splits)
            return [test for _, test in cv.split(X_arr, y_arr, groups_arr)], False
        if use_stratify:
            try:
                cv = StratifiedKFold(
                    n_splits=self.n_splits,
                    shuffle=self.shuffle,
                    random_state=seed_arg,
                )
                # split() is a lazy generator: its errors (continuous target,
                # too few members per class, ...) only surface while iterating,
                # so the folds must be materialised inside this try block.
                return [test for _, test in cv.split(X_arr, y_arr)], True
            except ValueError as exc:
                warnings.warn(
                    f"StratifiedKFold failed (likely too few members in some classes): {exc}. "
                    "Falling back to KFold.",
                    UserWarning,
                )
        cv = KFold(n_splits=self.n_splits, shuffle=self.shuffle, random_state=seed_arg)
        return [test for _, test in cv.split(X_arr)], False

    def _split_holdout(
        self,
        hold_idx: np.ndarray,
        hold_targets: Optional[np.ndarray],
        rs: np.random.RandomState,
        repeat: int,
        fold: int,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Split one outer holdout into (val_idx, test_idx) using the val:test ratio.

        Stratifies by `hold_targets` when given. If that fails, it warns and falls
        back to a non-stratified split.

        :raises ValueError: If the holdout has fewer than 2 samples.
        """
        if len(hold_idx) < 2:
            raise ValueError(
                f"Holdout of repeat={repeat}, fold={fold} has {len(hold_idx)} sample(s); "
                "at least 2 are required to split it into validation and test sets"
            )
        max_seed = np.iinfo(np.int32).max
        test_size = self._test_frac_within_holdout
        if hold_targets is not None:
            try:
                val_idx, test_idx = train_test_split(
                    hold_idx,
                    test_size=test_size,
                    random_state=int(rs.randint(max_seed)),
                    shuffle=True,
                    stratify=hold_targets,
                )
                return val_idx, test_idx
            except ValueError:
                warnings.warn(
                    "Inner holdout stratified split failed (likely too few members in some classes). "
                    "Falling back to non-stratified split.",
                    UserWarning,
                )
        val_idx, test_idx = train_test_split(
            hold_idx,
            test_size=test_size,
            random_state=int(rs.randint(max_seed)),
            shuffle=True,
            stratify=None,
        )
        return val_idx, test_idx

    def _find_split(self, repeat: int, fold: int) -> SplitIndices:
        """Return the stored split for (repeat, fold) or raise IndexError."""
        for s in self._splits:
            if s.repeat == repeat and s.fold == fold:
                return s
        raise IndexError(f"No split for repeat={repeat}, fold={fold}")

    # access helpers

    def get_split(self, repeat: int = 0, fold: int = 0, as_frame: bool = False):
        """
        Retrieve either index arrays (train_idx, val_idx, test_idx) or slices of the
        originally provided X (if it was a DataFrame or array-like) when as_frame=True.

        :param repeat: Repeat index (0-based).
        :param fold: Fold index within the repeat (0-based).
        :param as_frame: If True, return slices of the original X (DataFrame rows
            with a reset index, or ndarray rows) rather than indices.
        :returns: Tuple of (train, val, test) either as index arrays or as slices of X.
            Falls back to index arrays if X cannot be sliced.
        :raises RuntimeError: If no splits have been computed yet.
        :raises IndexError: If the requested (repeat, fold) does not exist.
        """
        if not self._splits:
            raise RuntimeError(
                "Call .split(X, ...) or .split_with_val(X, ...) before requesting a split"
            )
        found = self._find_split(repeat, fold)
        X = self._X_provided
        if as_frame and X is not None:
            if isinstance(X, pd.DataFrame):
                return (
                    X.iloc[found.train_idx].reset_index(drop=True),
                    X.iloc[found.val_idx].reset_index(drop=True),
                    X.iloc[found.test_idx].reset_index(drop=True),
                )
            try:
                arr = np.asarray(X)
                return arr[found.train_idx], arr[found.val_idx], arr[found.test_idx]
            except Exception:
                # Deliberate best-effort: fall through to returning indices when
                # X cannot be converted/sliced as an array.
                pass
        return found.train_idx.copy(), found.val_idx.copy(), found.test_idx.copy()

    def iter_splits(self) -> Iterator[SplitIndices]:
        """
        Iterate over computed splits in order (repeat major, fold minor).

        :returns: Iterator of SplitIndices objects.
        """
        yield from self._splits

    def __getitem__(self, key: Union[int, Tuple[int, int]]) -> SplitIndices:
        """
        Allow indexing into computed splits.

        - splitter[0] -> first stored SplitIndices (by stored order); numpy
          integers are accepted too
        - splitter[(repeat, fold)] -> SplitIndices for that repeat and fold

        :param key: int or (repeat, fold)
        :raises RuntimeError: if splits have not been computed yet.
        :raises IndexError: if the requested split is not found.
        :raises TypeError: if key type is unsupported.
        :returns: SplitIndices
        """
        if not self._splits:
            raise RuntimeError(
                "No splits computed. Call split(...) or split_with_val(...) before indexing."
            )
        if isinstance(key, (int, np.integer)):
            return self._splits[int(key)]
        if isinstance(key, tuple) and len(key) == 2:
            return self._find_split(int(key[0]), int(key[1]))
        raise TypeError("Key must be int or tuple(repeat, fold)")

    @property
    def splits(self) -> List[SplitIndices]:
        """Return a copy of computed splits."""
        return list(self._splits)

    @property
    def n_generated_splits(self) -> int:
        """Number of generated (repeat, fold) splits."""
        return len(self._splits)