API#
Estimators#
- class tabicl.TabICLClassifier(n_estimators=8, norm_methods=None, feat_shuffle_method='latin', class_shuffle_method='shift', outlier_threshold=4.0, softmax_temperature=0.9, average_logits=True, support_many_classes=True, batch_size=8, kv_cache=False, model_path=None, allow_auto_download=True, checkpoint_version='tabicl-classifier-v2-20260212.ckpt', device=None, use_amp='auto', use_fa3='auto', offload_mode='auto', disk_offload_dir=None, random_state=42, n_jobs=None, verbose=False, inference_config=None)#
Tabular In-Context Learning (TabICL) Classifier with scikit-learn interface.
This classifier applies TabICL to tabular data classification, using an ensemble of transformed dataset views to improve predictions. The ensemble members are created by applying different normalization methods, feature permutations, and class label shifts.
- Parameters:
n_estimators (int, default=8) – Number of estimators for ensemble predictions.
norm_methods (str or list[str] or None, default=None) –
Normalization methods to apply:
’none’: No normalization
’power’: Yeo-Johnson power transform
’quantile’: Transform features to an approximately normal distribution.
’quantile_rtdl’: Quantile transform that adds noise to training data before fitting.
’robust’: Scale using median and quantiles
Can be a single string or a list of methods to use across ensemble members. When set to None, it will use [“none”, “power”].
feat_shuffle_method (str, default='latin') –
Feature permutation strategy:
’none’: No shuffling and preserve original feature order
’shift’: Circular shifting of feature columns
’random’: Random permutation of features
’latin’: Latin square patterns for systematic feature permutations
class_shuffle_method (str, default='shift') –
Class label permutation strategy:
’none’: No shuffling and preserve original class labels
’shift’: Circular shifting of class labels
’random’: Random permutation of class labels
’latin’: Latin square patterns for systematic class permutations
outlier_threshold (float, default=4.0) – Z-score threshold for outlier detection and clipping. Values with \(|z| > \text{threshold}\) are considered outliers.
softmax_temperature (float, default=0.9) – Temperature parameter \(\tau\) for the softmax function, applied as \(\text{softmax}(x / \tau)\). Lower values make predictions more confident; higher values make them more conservative.
average_logits (bool, default=True) – Whether to average the logits (True) or probabilities (False) of ensemble members. Averaging logits often produces better calibrated probabilities.
support_many_classes (bool, default=True) – Whether to enable many-class support which performs mixed-radix ensembling during column-wise embedding and hierarchical classification during in-context learning. Required when the number of classes exceeds the model’s max_classes limit.
batch_size (Optional[int], default=8) – Batch size for inference. If None, all ensemble members are processed in a single batch. Adjust this parameter based on available memory. Lower values use less memory but may be slower.
kv_cache (bool or str, default=False) –
Controls caching of training data computations to speed up subsequent predict_proba / predict calls. The cache is built during fit().
False: No caching.
True or “kv”: Cache key-value projections from both column embedding and ICL transformer layers. Fast inference but memory-heavy for large training sets.
“repr”: Cache column embedding KV projections and row interaction outputs (representations). Uses ~24x less memory than “kv” for the ICL part, at the cost of re-running the ICL transformer at predict time.
The cache retains whatever dtype the model produced during fit() (float16 when AMP is active, float32 otherwise). If the cache is later loaded on CPU or on CUDA without AMP, the tensors are automatically upcast to float32 to avoid dtype-mismatch errors.
model_path (Optional[str | Path], default=None) –
Path to the pre-trained model checkpoint file.
If provided and the file exists, it’s loaded directly.
If provided but the file doesn’t exist and allow_auto_download is true, the version specified by checkpoint_version is downloaded from Hugging Face Hub (repo: ‘jingang/TabICL’) to this path.
If None (default), the version specified by checkpoint_version is downloaded from Hugging Face Hub (repo: ‘jingang/TabICL’) and cached locally in the default Hugging Face cache directory (typically ~/.cache/huggingface/hub).
allow_auto_download (bool, default=True) – Whether to allow automatic download if the pretrained checkpoint cannot be found at the specified model_path.
checkpoint_version (str, default='tabicl-classifier-v2-20260212.ckpt') – Specifies which version of the pre-trained model checkpoint to use when model_path is None or points to a non-existent file (and allow_auto_download is true). Checkpoints are downloaded from https://huggingface.co/jingang/TabICL. Available versions:
‘tabicl-classifier-v2-20260212.ckpt’ (default): The latest best-performing version, used in our TabICLv2 paper.
‘tabicl-classifier-v1.1-20250506.ckpt’: An enhanced version of TabICLv1 using a precursor of the v2 prior.
‘tabicl-classifier-v1-20250208.ckpt’: The version used in our TabICLv1 paper.
device (Optional[str or torch.device], default=None) – Device to use for inference. If None, automatically selects CUDA if available, otherwise CPU. Can be specified as a string ('cuda', 'cpu', 'mps') or a torch.device object. MPS (Apple Silicon GPU) is supported but must be explicitly requested.
use_amp (bool or "auto", default="auto") –
Controls automatic mixed precision (AMP) for inference.
True / False: force on / off.
“auto”: Automatically enable AMP based on input data size using the following heuristic:
| Regime | AMP | FA3 |
| --- | --- | --- |
| Small (n < 1024 & feat < 60) | off | off |
| Medium (above small, n < 10240) | on | off |
| Large (n >= 10240) | on | on |
The above heuristic is based on the observation that AMP can introduce overhead that outweighs its benefits for small inputs. In addition, it assumes that the training set is large relative to the test set and does not account for KV-cache scenarios. If it is suboptimal for your workload, set it explicitly.
use_fa3 (bool or "auto", default="auto") – Whether to use Flash Attention 3, which can speed up inference for large datasets on NVIDIA Hopper GPUs such as the H100. Only effective when FA3 is installed.
True / False: force on / off.
“auto”: Automatically enable FA3 based on input data size using a simple heuristic (see above).
offload_mode (str or bool, default='auto') –
Controls where column-wise embedding outputs are stored during inference. Column-wise embedding produces a large tensor of shape (batch_size, n_rows, n_columns, embed_dim), which is the main memory bottleneck. Available options:
'auto': Automatically choose based on available memory (default).
'gpu' or False: Keep on GPU. Fastest but limited by VRAM.
'cpu' or True: Offload to CPU memory.
'disk': Offload to memory-mapped files (requires disk_offload_dir).
It only affects column-wise embedding (COL_CONFIG). For finer-grained control over all components, use inference_config.
disk_offload_dir (Optional[str], default=None) – Directory for memory-mapped files used when offload_mode='disk' or when offload_mode='auto' falls back to disk offloading. It only affects column-wise embedding (COL_CONFIG). For finer-grained control over all components, use inference_config.
random_state (int or None, default=42) – Random seed for reproducibility of ensemble generation, affecting feature shuffling and other randomized operations.
n_jobs (int or None, default=None) – Number of threads to use for PyTorch in case the model is run on CPU. None means using the PyTorch default, which is the number of physical CPU cores. Negative numbers mean that \(\max(1, n_{\text{logical\_cores}} + 1 + \text{n\_jobs})\) threads will be used. In particular, n_jobs=-1 means that all logical cores will be used.
verbose (bool, default=False) – Whether to print detailed information during inference.
inference_config (Optional[InferenceConfig | Dict[str, Dict[str, Any]]], default=None) –
Configuration for inference settings. This parameter provides fine-grained control over the three transformers in TabICL (column-wise, row-wise, and in-context learning).
WARNING: This parameter should only be used by advanced users who understand the internal architecture of TabICL and need precise control over inference.
- When None (default):
A new InferenceConfig object is created with default settings.
The device, use_amp, use_fa3, offload_mode, disk_offload_dir, and verbose parameters from the class initialization are applied to the relevant components.
- When Dict with allowed top-level keys “COL_CONFIG”, “ROW_CONFIG”, “ICL_CONFIG”:
A new InferenceConfig object is created with default settings.
Any values explicitly specified in the dictionary override the defaults.
device, use_amp, use_fa3, offload_mode, disk_offload_dir, and verbose from the class initialization are used if they are not specified in the dictionary.
- When InferenceConfig:
The provided InferenceConfig object is used directly without modification.
device, use_amp, use_fa3, offload_mode, disk_offload_dir, and verbose from the class initialization are ignored.
All settings must be explicitly defined in the provided InferenceConfig object.
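The interaction between softmax_temperature and average_logits can be illustrated with a small standalone sketch. It does not call TabICL; the helper names (softmax, ensemble_proba) and the toy logits are hypothetical, and applying the temperature inside the softmax in both branches is an assumption about ordering rather than a statement of the library's internals.

```python
import math


def softmax(logits, temperature=0.9):
    """Compute softmax(x / temperature) with the usual max-shift for stability."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_proba(member_logits, temperature=0.9, average_logits=True):
    """Combine per-member logits for one sample into class probabilities."""
    n = len(member_logits)
    if average_logits:
        # Average the raw logits first, then apply a single softmax.
        mean_logits = [sum(col) / n for col in zip(*member_logits)]
        return softmax(mean_logits, temperature)
    # Otherwise convert each member to probabilities and average those.
    probs = [softmax(row, temperature) for row in member_logits]
    return [sum(col) / n for col in zip(*probs)]


# Toy logits from two hypothetical ensemble members, three classes.
member_logits = [[2.0, 0.5, -1.0], [1.0, 1.5, -0.5]]
p_logit_avg = ensemble_proba(member_logits, average_logits=True)
p_prob_avg = ensemble_proba(member_logits, average_logits=False)

# A lower temperature sharpens the distribution, a higher one flattens it.
sharp = softmax([2.0, 0.5, -1.0], temperature=0.5)
flat = softmax([2.0, 0.5, -1.0], temperature=2.0)
```

Both averaging modes return a valid probability vector, but they generally disagree because softmax is nonlinear.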
- classes_#
Class labels known to the classifier.
- Type:
ndarray of shape (n_classes,)
- feature_names_in_#
Feature names seen during fit. Only set when the input X has feature names (e.g., a pandas DataFrame with string column names).
- Type:
ndarray of shape (n_features_in_,) or None
- X_encoder_#
Encoder for transforming input features to numerical values.
- Type:
TransformToNumerical
- y_encoder_#
Encoder for transforming class labels to integers and back.
- Type:
LabelEncoder
- ensemble_generator_#
Fitted ensemble generator that creates multiple dataset views.
- Type:
EnsembleGenerator
- model_#
The loaded TabICL model used for predictions.
- Type:
TabICL
- device_#
The device where the model is loaded and computations are performed.
- Type:
torch.device
- inference_config_#
The inference configuration.
- Type:
InferenceConfig
- cache_mode_#
The resolved caching mode, set during fit() based on the kv_cache init parameter. One of "kv", "repr", or None (no caching).
- Type:
str or None
- model_kv_cache_#
Pre-computed KV caches for training data, keyed by normalization method. Created during fit() when kv_cache is enabled. When set, predict_proba() reuses the cached key-value projections instead of re-processing training data, enabling faster inference on multiple test sets.
- Type:
OrderedDict[str, TabICLCache] or None
- fit(X, y)#
Fit the classifier to training data.
Prepares the model for prediction by:
Encoding class labels using LabelEncoder
Converting input features to numerical values
Fitting the ensemble generator to create transformed dataset views
Loading the pre-trained TabICL model
Optionally pre-computing KV caches for training data to speed up inference (controlled by the kv_cache init parameter)
The model itself is not trained on the data; it uses in-context learning at inference time.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training input data.
y (array-like of shape (n_samples,)) – Training target labels.
- Returns:
self – Fitted classifier instance.
- Return type:
TabICLClassifier
- Raises:
ValueError – If the number of classes exceeds the model’s maximum supported classes and many-class support is disabled.
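To make the idea of "transformed dataset views" concrete, the sketch below builds one view per ensemble member from a normalization method and a circular shift of the feature order. It is a simplified illustration only: shift_permutations and build_views are hypothetical helpers, and the round-robin pairing of normalization methods with permutations is an assumption, not the behavior of EnsembleGenerator.

```python
def shift_permutations(n_features, n_members):
    """Circular shifts of the column order: member k starts at column k."""
    base = list(range(n_features))
    perms = []
    for k in range(n_members):
        s = k % n_features
        perms.append(base[s:] + base[:s])
    return perms


def build_views(n_features, n_estimators, norm_methods=("none", "power")):
    """Pair each member with a normalization method and a feature order."""
    perms = shift_permutations(n_features, n_estimators)
    return [
        (norm_methods[i % len(norm_methods)], perms[i])
        for i in range(n_estimators)
    ]


views = build_views(n_features=4, n_estimators=4)
for norm, order in views:
    print(norm, order)
```

With as many shifts as features, every column index appears exactly once in every position across members, which is the defining property of a Latin square.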
- predict(X)#
Predict class labels for test samples.
Uses predict_proba to get class probabilities and returns the class with the highest probability for each sample.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Test samples for prediction. Columns that are entirely NaN are treated as masked features and excluded from inference. This is useful for computing SHAP values, where masked features are represented as all-NaN columns.
- Returns:
Predicted class labels for each test sample.
- Return type:
array-like of shape (n_samples,)
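The all-NaN masking convention described for X can be checked before calling predict. The following minimal sketch, using plain Python lists and a hypothetical helper all_nan_columns, identifies which columns would count as masked under that rule; it does not reproduce how TabICL excludes them internally.

```python
import math


def all_nan_columns(rows):
    """Return indices of columns in which every entry is NaN."""
    n_cols = len(rows[0])
    return [
        j for j in range(n_cols)
        if all(math.isnan(row[j]) for row in rows)
    ]


nan = float("nan")
# Column 1 is entirely NaN (masked); column 2 has only one missing value.
X_masked = [
    [1.0, nan, 3.0],
    [4.0, nan, nan],
    [7.0, nan, 9.0],
]
masked = all_nan_columns(X_masked)
print(masked)
```

A column with only some missing entries is not treated as masked under this rule.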
- predict_proba(X)#
Predict class probabilities for test samples.
Applies the ensemble of TabICL models to make predictions, with each ensemble member providing predictions that are then averaged. The method:
Transforms input data using the fitted encoders
Applies the ensemble generator to create multiple views
Forwards each view through the model
Corrects for class shuffles
Averages predictions across ensemble members
- Parameters:
X (array-like of shape (n_samples, n_features)) – Test samples for prediction. Columns that are entirely NaN are treated as masked features and excluded from inference. This is useful for computing SHAP values, where masked features are represented as all-NaN columns.
- Returns:
Class probabilities for each test sample.
- Return type:
np.ndarray of shape (n_samples, n_classes)
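The "corrects for class shuffles" and "averages predictions" steps can be sketched for the 'shift' strategy as follows. This is an illustrative reconstruction under the assumption that a member trained on labels shifted by s reports original class c in output column (c + s) mod n_classes; the function names and toy probabilities are hypothetical.

```python
def shift_labels(y, shift, n_classes):
    """Apply a circular shift to integer class labels."""
    return [(label + shift) % n_classes for label in y]


def unshift_proba(proba_row, shift):
    """Undo the shift: original class c sits in column (c + shift) % n."""
    n = len(proba_row)
    return [proba_row[(c + shift) % n] for c in range(n)]


def average_members(rows):
    """Average aligned probability rows across ensemble members."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


# Member 0 used the original labels, member 1 used labels shifted by 1.
member_outputs = [([0.6, 0.3, 0.1], 0), ([0.1, 0.7, 0.2], 1)]
aligned = [unshift_proba(row, shift) for row, shift in member_outputs]
final = average_members(aligned)
print(aligned[1], final)
```

Only after the columns are realigned to the original class order is averaging across members meaningful.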
- set_score_request(*, sample_weight='$UNCHANGED$')#
Configure whether metadata should be requested to be passed to the score method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- class tabicl.TabICLRegressor(n_estimators=8, norm_methods=None, feat_shuffle_method='latin', outlier_threshold=4.0, batch_size=8, kv_cache=False, model_path=None, allow_auto_download=True, checkpoint_version='tabicl-regressor-v2-20260212.ckpt', device=None, use_amp='auto', use_fa3='auto', offload_mode='auto', disk_offload_dir=None, random_state=42, n_jobs=None, verbose=False, inference_config=None)#
Tabular In-Context Learning (TabICL) Regressor with scikit-learn interface.
This regressor applies TabICL to tabular data regression, using an ensemble of transformed dataset views to improve predictions. The ensemble members are created by applying different normalization methods and feature permutations.
- Parameters:
n_estimators (int, default=8) – Number of estimators for ensemble predictions.
norm_methods (str or list[str] or None, default=None) –
Normalization methods to apply:
‘none’: No normalization
‘power’: Yeo-Johnson power transform
‘quantile’: Transform features to an approximately normal distribution.
‘quantile_rtdl’: Quantile transform that adds noise to training data before fitting.
‘robust’: Scale using median and quantiles
Can be a single string or a list of methods to use across ensemble members. When set to None, it will use [“none”, “power”].
feat_shuffle_method (str, default='latin') –
Feature permutation strategy:
‘none’: No shuffling and preserve original feature order
‘shift’: Circular shifting of feature columns
‘random’: Random permutation of features
‘latin’: Latin square patterns for systematic feature permutations
outlier_threshold (float, default=4.0) – Z-score threshold for outlier detection and clipping. Values with \(|z| > \text{threshold}\) are considered outliers.
batch_size (Optional[int], default=8) – Batch size for inference. If None, all ensemble members are processed in a single batch. Adjust this parameter based on available memory. Lower values use less memory but may be slower.
kv_cache (bool or str, default=False) –
Controls caching of training data computations to speed up subsequent predict calls. The cache is built during fit().
False: No caching.
True or “kv”: Cache key-value projections from both column embedding and ICL transformer layers. Fast inference but memory-heavy for large training sets.
“repr”: Cache column embedding KV projections and row interaction outputs (representations). Uses ~24x less memory than “kv” for the ICL part, at the cost of re-running the ICL transformer at predict time.
The cache retains whatever dtype the model produced during fit() (float16 when AMP is active, float32 otherwise). If the cache is later loaded on CPU or on CUDA without AMP, the tensors are automatically upcast to float32 to avoid dtype-mismatch errors.
model_path (Optional[str or Path], default=None) –
Path to the pre-trained model checkpoint file.
If provided and the file exists, it’s loaded directly.
If provided but the file doesn’t exist and allow_auto_download is true, the version specified by checkpoint_version is downloaded from Hugging Face Hub (repo: ‘jingang/TabICL’) to this path.
If None (default), the version specified by checkpoint_version is downloaded from Hugging Face Hub (repo: ‘jingang/TabICL’) and cached locally in the default Hugging Face cache directory (typically ~/.cache/huggingface/hub).
allow_auto_download (bool, default=True) – Whether to allow automatic download if the pretrained checkpoint cannot be found at the specified model_path.
checkpoint_version (str, default='tabicl-regressor-v2-20260212.ckpt') – Specifies which version of the pre-trained model checkpoint to use when model_path is None or points to a non-existent file (and allow_auto_download is true). Checkpoints are downloaded from https://huggingface.co/jingang/TabICL.
device (Optional[str or torch.device], default=None) – Device to use for inference. If None, automatically selects CUDA if available, otherwise CPU. Can be specified as a string ('cuda', 'cpu', 'mps') or a torch.device object. MPS (Apple Silicon GPU) is supported but must be explicitly requested.
use_amp (bool or "auto", default="auto") –
Controls automatic mixed precision (AMP) for inference.
True / False: force on / off.
“auto”: Automatically enable AMP based on input data size using the following heuristic:
| Regime | AMP | FA3 |
| --- | --- | --- |
| Small (n < 1024 & feat < 60) | off | off |
| Medium (above small, n < 10240) | on | off |
| Large (n >= 10240) | on | on |
The above heuristic is based on the observation that AMP can introduce overhead that outweighs its benefits for small inputs. In addition, it assumes that the training set is large relative to the test set and does not account for KV-cache scenarios. If it is suboptimal for your workload, set it explicitly.
use_fa3 (bool or "auto", default="auto") – Whether to use Flash Attention 3, which can speed up inference for large datasets on NVIDIA Hopper GPUs such as the H100. Only effective when FA3 is installed.
True / False: force on / off.
“auto”: Automatically enable FA3 based on input data size using a simple heuristic (see above).
offload_mode (str or bool, default='auto') –
Controls where column-wise embedding outputs are stored during inference. Column-wise embedding produces a large tensor of shape (batch_size, n_rows, n_columns, embed_dim), which is the main memory bottleneck. Available options:
'auto': Automatically choose based on available memory (default).
'gpu' or False: Keep on GPU. Fastest but limited by VRAM.
'cpu' or True: Offload to CPU memory.
'disk': Offload to memory-mapped files (requires disk_offload_dir).
It only affects column-wise embedding (COL_CONFIG). For finer-grained control over all components, use inference_config.
disk_offload_dir (Optional[str], default=None) – Directory for memory-mapped files used when offload_mode='disk' or when offload_mode='auto' falls back to disk offloading. It only affects column-wise embedding (COL_CONFIG). For finer-grained control over all components, use inference_config.
random_state (int or None, default=42) – Random seed for reproducibility of ensemble generation, affecting feature shuffling and other randomized operations.
n_jobs (int or None, default=None) – Number of threads to use for PyTorch in case the model is run on CPU. None means using the PyTorch default, which is the number of physical CPU cores. Negative numbers mean that \(\max(1, n_{\text{logical\_cores}} + 1 + \text{n\_jobs})\) threads will be used. In particular, n_jobs=-1 means that all logical cores will be used.
verbose (bool, default=False) – Whether to print detailed information during inference.
inference_config (Optional[InferenceConfig | Dict[str, Dict[str, Any]]], default=None) –
Configuration for inference settings. This parameter provides fine-grained control over the three transformers in TabICL (column-wise, row-wise, and in-context learning).
WARNING: This parameter should only be used by advanced users who understand the internal architecture of TabICL and need precise control over inference.
- When None (default):
A new InferenceConfig object is created with default settings.
The device, use_amp, use_fa3, offload_mode, disk_offload_dir, and verbose parameters from the class initialization are applied to the relevant components.
- When Dict with allowed top-level keys “COL_CONFIG”, “ROW_CONFIG”, “ICL_CONFIG”:
A new InferenceConfig object is created with default settings.
Any values explicitly specified in the dictionary override the defaults.
device, use_amp, use_fa3, offload_mode, disk_offload_dir, and verbose from the class initialization are used if they are not specified in the dictionary.
- When InferenceConfig:
The provided InferenceConfig object is used directly without modification.
device, use_amp, use_fa3, offload_mode, disk_offload_dir, and verbose from the class initialization are ignored.
All settings must be explicitly defined in the provided InferenceConfig object.
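The "auto" heuristic for use_amp and use_fa3 can be read as a small decision function. The sketch below is one plain reading of the Small / Medium / Large regimes, with n taken as the number of rows and feat as the number of features; resolve_auto is a hypothetical name and this is not the library's own code. FA3 still only takes effect when it is installed.

```python
def resolve_auto(n_samples, n_features):
    """Return (use_amp, use_fa3) for the documented size regimes."""
    if n_samples >= 10240:
        # Large: both AMP and FA3 are switched on.
        return True, True
    if n_samples < 1024 and n_features < 60:
        # Small: AMP overhead is assumed to outweigh its benefit.
        return False, False
    # Medium: everything above small with fewer than 10240 rows.
    return True, False


for shape in [(500, 10), (500, 100), (5000, 10), (20000, 10)]:
    print(shape, resolve_auto(*shape))
```

Note that a dataset with few rows but 60 or more features already falls into the medium regime.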
- feature_names_in_#
Feature names seen during fit. Only set when the input X has feature names (e.g., a pandas DataFrame with string column names).
- Type:
ndarray of shape (n_features_in_,) or None
- X_encoder_#
Encoder for transforming input features to numerical values.
- Type:
TransformToNumerical
- y_scaler_#
Scaler for transforming target values.
- Type:
StandardScaler
- ensemble_generator_#
Fitted ensemble generator that creates multiple dataset views.
- Type:
EnsembleGenerator
- model_#
The loaded TabICL model used for predictions.
- Type:
TabICL
- device_#
The device where the model is loaded and computations are performed.
- Type:
torch.device
- inference_config_#
The inference configuration.
- Type:
InferenceConfig
- cache_mode_#
The resolved caching mode, set during fit() based on the kv_cache init parameter. One of "kv", "repr", or None (no caching).
- Type:
str or None
- model_kv_cache_#
Pre-computed KV caches for training data, keyed by normalization method. Created during fit() when kv_cache is enabled. When set, predict() reuses the cached key-value projections instead of re-processing training data, enabling faster inference on multiple test sets.
- Type:
OrderedDict[str, TabICLCache] or None
- fit(X, y)#
Fit the regressor to training data.
Prepares the model for prediction by:
Scaling target values using StandardScaler
Converting input features to numerical values
Fitting the ensemble generator to create transformed dataset views
Loading the pre-trained TabICL model
Optionally pre-computing KV caches for training data to speed up inference (controlled by the kv_cache init parameter)
The model itself is not trained on the data; it uses in-context learning at inference time. This method only prepares the data transformations.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training input data.
y (array-like of shape (n_samples,)) – Training target values.
- Returns:
self – Fitted regressor instance.
- Return type:
TabICLRegressor
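Target scaling during fit and the matching inverse transform at prediction time can be sketched in a few lines. The code mirrors the arithmetic of a standard scaler (subtract the mean, divide by the population standard deviation) using only the standard library; the helper names are hypothetical and scikit-learn's StandardScaler is not invoked here.

```python
import statistics


def fit_target_scaler(y):
    """Return (mean, std) of the targets; a zero std is replaced by 1."""
    mean = statistics.fmean(y)
    std = statistics.pstdev(y)
    return mean, (std if std > 0 else 1.0)


def scale(y, mean, std):
    """Standardize targets to zero mean and unit variance."""
    return [(v - mean) / std for v in y]


def unscale(z, mean, std):
    """Map standardized predictions back to the original target scale."""
    return [v * std + mean for v in z]


y_train = [10.0, 20.0, 30.0]
mean, std = fit_target_scaler(y_train)
z = scale(y_train, mean, std)
restored = unscale(z, mean, std)
print(mean, std, z, restored)
```

Predictions produced in the standardized space have to pass through the inverse step before they are comparable with the original targets.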
- predict(X, output_type='mean', alphas=None)#
Predict target values for test samples.
Applies the ensemble of TabICL models to make predictions, with each ensemble member providing predictions that are then averaged. The method:
Transforms input data using the fitted encoders
Applies the ensemble generator to create multiple views
Forwards each view through the model
Averages predictions across ensemble members
Inverse transforms predictions to original scale
- Parameters:
X (array-like of shape (n_samples, n_features)) – Test samples for prediction. Columns that are entirely NaN are treated as masked features and excluded from inference. This is useful for computing SHAP values, where masked features are represented as all-NaN columns.
output_type (str or list of str, default="mean") –
Determines the type of output to return.
If "mean", returns the mean over the predicted distribution.
If "median", returns the median over the predicted distribution.
If "quantiles", returns the quantiles of the predicted distribution. The parameter alphas determines which quantiles are returned.
If "raw_quantiles", returns the raw quantiles (direct outputs of TabICL).
If a list of str, returns multiple types of outputs as specified in the list.
alphas (list of float or None, default=None) –
The probability levels to return if output_type="quantiles".
By default, the [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] quantiles are returned. The predictions per quantile match the input order.
- Returns:
An array of shape (n_samples,) if output_type is "mean" or "median", or an array of shape (n_samples, n_quantiles) if output_type is "quantiles" or "raw_quantiles".
If output_type is a list of str, returns a dictionary with keys as specified in the list and values as the corresponding predictions.
- Return type:
np.ndarray or dict
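Because the quantile columns follow the order of alphas, a prediction interval can be read directly off the returned array. The sketch below works on a toy stand-in for the (n_samples, n_quantiles) output rather than a real prediction; the numbers are invented for illustration and central_interval is a hypothetical helper.

```python
# Default probability levels documented for output_type="quantiles".
ALPHAS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]


def central_interval(quantiles, alphas, lower, upper):
    """Pick the (lower, upper) quantile columns for every sample."""
    lo, hi = alphas.index(lower), alphas.index(upper)
    return [(row[lo], row[hi]) for row in quantiles]


# Toy stand-in for a quantile prediction: 2 samples x 9 quantiles.
toy_quantiles = [
    [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.6],
    [5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0],
]

# 80% central interval from the 0.1 and 0.9 quantiles.
intervals = central_interval(toy_quantiles, ALPHAS, 0.1, 0.9)
# The 0.5 quantile column holds the per-sample median.
medians = [row[ALPHAS.index(0.5)] for row in toy_quantiles]
print(intervals, medians)
```

Passing a custom alphas list changes which columns exist, so the lookup should always use the same list that was given to predict.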
- set_predict_request(*, alphas='$UNCHANGED$', output_type='$UNCHANGED$')#
Configure whether metadata should be requested to be passed to the predict method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to predict if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to predict.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
- Returns:
self – The updated object.
- Return type:
- set_score_request(*, sample_weight='$UNCHANGED$')#
Configure whether metadata should be requested to be passed to the score method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- class tabicl.FinetunedTabICLClassifier(*, epochs=30, learning_rate=1e-5, weight_decay=0.01, grad_clip=1.0, amp=True, use_lr_scheduler=True, warmup_proportion=0.1, n_estimators_finetune=2, n_estimators_validation=2, n_estimators_inference=8, max_data_size=10_000, finetune_ctx_query_ratio=0.2, validation_split_ratio=0.1, early_stopping=True, patience=8, min_delta=1e-4, time_limit=None, save_interval=1, norm_methods=None, feat_shuffle_method='latin', outlier_threshold=4.0, model_path=None, allow_auto_download=True, checkpoint_version='tabicl-classifier-v2-20260212.ckpt', freeze_col=False, freeze_row=False, freeze_icl=False, device=None, random_state=42, verbose=False, wandb_kwargs=None, class_shuffle_method='shift', softmax_temperature=0.9, average_logits=True, support_many_classes=True, eval_metric='roc_auc', extra_classifier_kwargs=None)#
Fine-tune a pretrained TabICL for single-dataset classification.
Subclass of FinetunedTabICLBase that implements cross-entropy loss on the raw TabICL logits and ROC-AUC / log-loss / accuracy evaluation metrics.
Minimal usage:
    from tabicl import FinetunedTabICLClassifier

    clf = FinetunedTabICLClassifier(epochs=30, device="cuda", verbose=True)
    clf.fit(X_train, y_train, X_val=X_val, y_val=y_val)
    y_proba = clf.predict_proba(X_test)
- Parameters:
**Optimization**
epochs (int, default=30) – Number of passes through the fine-tuning meta-batches.
learning_rate (float, default=1e-5) – AdamW learning rate.
weight_decay (float, default=0.01) – AdamW weight decay.
grad_clip (float, default=1.0) – Max global gradient norm (0 disables).
amp (bool, default=True) – Use FP16 automatic mixed precision on CUDA.
use_lr_scheduler (bool, default=True) – Cosine-with-warmup LR schedule.
warmup_proportion (float, default=0.1) – Warmup fraction of total steps.
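A cosine-with-warmup schedule of the kind named by use_lr_scheduler and warmup_proportion can be sketched as a pure function of the step index. This follows the common formulation (linear ramp from zero, then half-cosine decay to zero); the exact curve used during fine-tuning is not specified here, so lr_at and its details should be read as assumptions.

```python
import math


def lr_at(step, total_steps, base_lr=1e-5, warmup_proportion=0.1):
    """Learning rate at a given step: linear warmup, then cosine decay."""
    warmup_steps = round(total_steps * warmup_proportion)
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr over the warmup steps.
        return base_lr * step / max(1, warmup_steps)
    # Cosine decay from base_lr down to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


schedule = [lr_at(s, total_steps=100) for s in range(101)]
print(schedule[0], schedule[5], schedule[10], schedule[55], schedule[100])
```

With the defaults and 100 steps, the rate peaks at base_lr after 10 warmup steps and falls to half of it midway through the decay phase.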
**Data pipeline**
n_estimators_finetune (int, default=2) – Ensemble size during training meta-batches.
n_estimators_validation (int, default=2) – Ensemble size during end-of-epoch validation.
n_estimators_inference (int, default=8) – Ensemble size of the final inner estimator used by predict() / predict_proba().
max_data_size (int, default=10_000) – Max samples per meta-dataset chunk.
finetune_ctx_query_ratio (float, default=0.2) – Query fraction inside each chunk.
validation_split_ratio (float, default=0.1) – Fraction of samples held out as the auto-split validation set when X_val / y_val are not passed to fit().
**Early stopping & time budget**
early_stopping (bool, default=True) – Stop after patience non-improving epochs.
patience (int, default=8) – Number of non-improving epochs tolerated.
min_delta (float, default=1e-4) – Minimum metric improvement that counts as an improvement.
time_limit (float or None, default=None) – Wall-clock budget in seconds; None disables.
save_interval (int, default=1) – Write an interval checkpoint every N epochs; best is always saved.
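The interplay of patience and min_delta can be sketched with a small tracker. It assumes scores are higher-is-better (a loss would be negated first) and that an epoch counts as an improvement only when it beats the best score by more than min_delta; EarlyStopper is a hypothetical class for illustration, not part of the tabicl API.

```python
class EarlyStopper:
    """Track a higher-is-better validation score across epochs."""

    def __init__(self, patience=8, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_epochs = 0

    def update(self, score):
        """Record one epoch; return True when training should stop."""
        if score > self.best + self.min_delta:
            self.best = score
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=2, min_delta=1e-4)
scores = [0.80, 0.85, 0.85005, 0.84, 0.90]
stopped_at = None
for epoch, score in enumerate(scores):
    if stopper.update(score):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)
```

The third score exceeds the best by less than min_delta, so it counts as a non-improving epoch and the run stops before the final, better score is ever seen.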
**Preprocessing**
norm_methods (str, list[str] or None, default=None) – Normalization methods forwarded to tabicl._sklearn.preprocessing.EnsembleGenerator.
feat_shuffle_method (str, default="latin") – Feature-permutation strategy for ensemble diversity.
outlier_threshold (float, default=4.0) – Z-score threshold for outlier clipping during preprocessing.
**Model loading**
model_path (str, Path or None, default=None) – Checkpoint file to fine-tune from. If None, the default TabICLv2 classifier checkpoint is downloaded from Hugging Face Hub.
allow_auto_download (bool, default=True) – Permit downloading the pretrained checkpoint when it isn’t cached.
checkpoint_version (str, default="tabicl-classifier-v2-20260212.ckpt") – Pretrained checkpoint version identifier.
**Freezing**
freeze_col (bool, default=False) – Freeze the column-embedding sub-module (weights and dropout/BN).
freeze_row (bool, default=False) – Freeze the row-interaction sub-module.
freeze_icl (bool, default=False) – Freeze the in-context-learning predictor.
**Device & logging**
device (str, torch.device or None, default=None) – Compute device; None auto-selects cuda when available.
random_state (int, default=42) – Seed for data splits and ensemble shuffle patterns.
verbose (bool, default=False) – Print a tqdm progress bar and one-line per-epoch summary.
wandb_kwargs (dict or None, default=None) – When provided, enables Weights & Biases tracking by instantiating WandbLogger(**wandb_kwargs) on rank 0. Supported keys are those of wandb.init(), most commonly project, name (the W&B run name), entity, tags, notes, group, mode ("online" / "offline" / "disabled"), and dir. All keys are forwarded verbatim to wandb.init.
**Classifier-specific**
class_shuffle_method (str, default="shift") – Class-label shuffle strategy for ensemble diversity.
softmax_temperature (float, default=0.9) – Softmax temperature used by the inner TabICLClassifier at inference time.
average_logits (bool, default=True) – If True, ensemble averaging is done on logits; else on probabilities.
support_many_classes (bool, default=True) – Enable TabICL’s mixed-radix ensembling when the dataset has more classes than the pretrained head’s native max_classes.
eval_metric ({"roc_auc", "log_loss", "accuracy"}, default="roc_auc") – Primary validation metric driving early stopping and best-weight selection. log_loss is internally negated so “higher is better” holds uniformly.
extra_classifier_kwargs (dict or None, default=None) – Additional kwargs forwarded to the inner TabICLClassifier (e.g. {"kv_cache": "kv"}).
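The interaction between patience, min_delta and the negated log_loss can be sketched in a few lines. This is only an illustration of the documented semantics, not the library's training loop: the function name run_early_stopping, the higher_is_better flag and the exact stopping rule (stop once patience consecutive epochs fail to improve by more than min_delta) are assumptions.

```python
import math


def run_early_stopping(metric_values, patience=8, min_delta=1e-4, higher_is_better=True):
    """Return (stop_epoch, best_epoch) for a list of per-epoch validation metrics.

    Hypothetical helper: metrics where lower is better (e.g. log_loss) are
    negated so that a single "higher is better" comparison covers every case.
    """
    best_score = -math.inf
    best_epoch = None
    bad_epochs = 0
    for epoch, value in enumerate(metric_values):
        score = value if higher_is_better else -value
        if score > best_score + min_delta:
            # Improvement larger than min_delta: remember it and reset the counter.
            best_score, best_epoch, bad_epochs = score, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch, best_epoch
    return len(metric_values) - 1, best_epoch


# log_loss improves once, then stalls; with patience=2 training stops at epoch 3.
stop_epoch, best_epoch = run_early_stopping(
    [0.70, 0.60, 0.61, 0.62, 0.63], patience=2, higher_is_better=False
)
```

The weights kept at the end would be those of best_epoch, which matches the description of eval_metric as driving both early stopping and best-weight selection.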
- property classes_#
Class labels in the order used by predict_proba().
- predict_proba(X)#
Predict class probabilities for X.
- Returns:
Probability that each sample belongs to each class, in the order given by classes_.
- Return type:
ndarray of shape (n_samples, n_classes)
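Because the columns of predict_proba() follow the order of classes_, mapping probabilities back to labels is an argmax over columns followed by a lookup. The sketch below uses plain lists as stand-ins for the fitted estimator's outputs; the helper name labels_from_proba and the sample values are illustrative only.

```python
def labels_from_proba(proba, classes):
    """Map each probability row to the label at its highest-probability column."""
    labels = []
    for row in proba:
        best_col = max(range(len(row)), key=row.__getitem__)
        labels.append(classes[best_col])
    return labels


classes = ["cat", "dog", "fish"]             # stand-in for clf.classes_
proba = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]   # stand-in for clf.predict_proba(X)
predicted = labels_from_proba(proba, classes)
```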
- set_fit_request(*, X_val='$UNCHANGED$', output_dir='$UNCHANGED$', y_val='$UNCHANGED$')#
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
X_val (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for X_val parameter in fit.
output_dir (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for output_dir parameter in fit.
y_val (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for y_val parameter in fit.
- Returns:
self – The updated object.
- Return type:
- set_score_request(*, sample_weight='$UNCHANGED$')#
Configure whether metadata should be requested to be passed to the score method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- class tabicl.FinetunedTabICLRegressor(*, epochs=30, learning_rate=1e-5, weight_decay=0.01, grad_clip=1.0, amp=True, use_lr_scheduler=True, warmup_proportion=0.1, n_estimators_finetune=2, n_estimators_validation=2, n_estimators_inference=8, max_data_size=10_000, finetune_ctx_query_ratio=0.2, validation_split_ratio=0.1, early_stopping=True, patience=8, min_delta=1e-4, time_limit=None, save_interval=1, norm_methods=None, feat_shuffle_method='latin', outlier_threshold=4.0, model_path=None, allow_auto_download=True, checkpoint_version='tabicl-regressor-v2-20260212.ckpt', freeze_col=False, freeze_row=False, freeze_icl=False, device=None, random_state=42, verbose=False, wandb_kwargs=None, eval_metric='mse', extra_regressor_kwargs=None)#
Fine-tune a pretrained TabICL for single-dataset regression.
Subclass of FinetunedTabICLBase that trains against pinball (quantile) loss applied directly to the raw quantile outputs of TabICL and scores validation with MSE / MAE / R² computed in raw y space via the wrapped TabICLRegressor.
Minimal usage:
    from tabicl import FinetunedTabICLRegressor

    reg = FinetunedTabICLRegressor(epochs=30, device="cuda", verbose=True)
    reg.fit(X_train, y_train, X_val=X_val, y_val=y_val)
    y_pred = reg.predict(X_test)
    quantiles = reg.predict(X_test, output_type="quantiles", alphas=[0.1, 0.5, 0.9])
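For readers unfamiliar with the pinball (quantile) loss named above, here is a minimal pure-Python sketch of the standard definition for a single quantile level. It is not the library's training code: how the loss is reduced across quantile levels and samples during fine-tuning is not stated here, and the function name pinball_loss is an illustrative choice.

```python
def pinball_loss(y_true, y_pred, alpha):
    """Mean pinball loss of predictions y_pred for the quantile level alpha.

    Under-prediction (y > q) is weighted by alpha, over-prediction by (1 - alpha).
    """
    total = 0.0
    for y, q in zip(y_true, y_pred):
        diff = y - q
        total += alpha * diff if diff >= 0 else (alpha - 1.0) * diff
    return total / len(y_true)


# For the 0.9 quantile, predicting too low costs nine times more than too high.
too_low = pinball_loss([3.0], [2.0], alpha=0.9)
too_high = pinball_loss([1.0], [2.0], alpha=0.9)
```

At alpha=0.5 the loss reduces to half the mean absolute error, which is why the 0.5 quantile corresponds to the median.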
- Parameters:
**Optimization**
epochs (int, default=30) – Number of passes through the fine-tuning meta-batches.
learning_rate (float, default=1e-5) – AdamW learning rate.
weight_decay (float, default=0.01) – AdamW weight decay.
grad_clip (float, default=1.0) – Max global gradient norm (0 disables).
amp (bool, default=True) – Use FP16 automatic mixed precision on CUDA.
use_lr_scheduler (bool, default=True) – Cosine-with-warmup LR schedule.
warmup_proportion (float, default=0.1) – Warmup fraction of total steps.
**Data pipeline**
n_estimators_finetune (int, default=2) – Ensemble size during training meta-batches.
n_estimators_validation (int, default=2) – Ensemble size during end-of-epoch validation.
n_estimators_inference (int, default=8) – Ensemble size of the final inner estimator used by predict().
max_data_size (int, default=10_000) – Max samples per meta-dataset chunk.
finetune_ctx_query_ratio (float, default=0.2) – Query fraction inside each chunk.
validation_split_ratio (float, default=0.1) – Size of auto-split validation set when X_val / y_val are not passed to fit().
**Early stopping & time budget**
early_stopping (bool, default=True) – Stop after patience non-improving epochs.
patience (int, default=8) – Number of non-improving epochs tolerated.
min_delta (float, default=1e-4) – Minimum metric improvement that counts as an improvement.
time_limit (float or None, default=None) – Wall-clock budget in seconds; None disables.
save_interval (int, default=1) – Write an interval checkpoint every N epochs; best is always saved.
**Preprocessing**
norm_methods (str, list[str] or None, default=None) – Normalization methods forwarded to tabicl._sklearn.preprocessing.EnsembleGenerator.
feat_shuffle_method (str, default="latin") – Feature-permutation strategy for ensemble diversity.
outlier_threshold (float, default=4.0) – Z-score threshold for outlier clipping during preprocessing.
**Model loading**
model_path (str, Path or None, default=None) – Checkpoint file to fine-tune from. None → download the default TabICLv2 regressor checkpoint from Hugging Face Hub.
allow_auto_download (bool, default=True) – Permit downloading the pretrained checkpoint when it isn’t cached.
checkpoint_version (str, default="tabicl-regressor-v2-20260212.ckpt") – Pretrained checkpoint version identifier.
**Freezing**
freeze_col (bool, default=False) – Freeze the column-embedding sub-module (weights and dropout/BN).
freeze_row (bool, default=False) – Freeze the row-interaction sub-module.
freeze_icl (bool, default=False) – Freeze the in-context-learning predictor.
**Device & logging**
device (str, torch.device or None, default=None) – Compute device; None auto-selects cuda when available.
random_state (int, default=42) – Seed for data splits and ensemble shuffle patterns.
verbose (bool, default=False) – Print a tqdm progress bar and one-line per-epoch summary.
wandb_kwargs (dict or None, default=None) – When provided, enables Weights & Biases tracking by instantiating WandbLogger(**wandb_kwargs) on rank 0. Supported keys are those of wandb.init(), most commonly project, name (the W&B run name), entity, tags, notes, group, mode ("online" / "offline" / "disabled"), and dir. All keys are forwarded verbatim to wandb.init.
**Regressor-specific**
eval_metric ({"mse", "mae", "r2"}, default="mse") – Primary validation metric driving early stopping and best-weight selection. Computed in raw y space via TabICLRegressor.predict(). mse and mae are internally negated so “higher is better” holds uniformly.
extra_regressor_kwargs (dict or None, default=None) – Additional kwargs forwarded to the inner TabICLRegressor.
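The use_lr_scheduler and warmup_proportion parameters describe a cosine-with-warmup schedule. A common form of that schedule is sketched below as a multiplier on learning_rate; the exact shape used by the library (for example, whether the rate decays all the way to zero) is not specified here, so treat the function lr_multiplier and its details as assumptions.

```python
import math


def lr_multiplier(step, total_steps, warmup_proportion=0.1):
    """Learning-rate multiplier: linear warmup, then cosine decay to zero."""
    warmup_steps = int(total_steps * warmup_proportion)
    if step < warmup_steps:
        # Linear ramp from 0 up to the base learning rate.
        return step / max(1, warmup_steps)
    # Cosine decay over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


# Effective learning rate at a given step for learning_rate=1e-5.
effective_lr = 1e-5 * lr_multiplier(step=50, total_steps=100, warmup_proportion=0.25)
```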
- predict(X, output_type='mean', alphas=None)#
Predict target values for X.
Thin wrapper around TabICLRegressor.predict() that forwards output_type / alphas unchanged. See that method for the full set of supported output types ("mean", "median", "quantiles", "raw_quantiles").
- set_fit_request(*, X_val='$UNCHANGED$', output_dir='$UNCHANGED$', y_val='$UNCHANGED$')#
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
X_val (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for X_val parameter in fit.
output_dir (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for output_dir parameter in fit.
y_val (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for y_val parameter in fit.
- Returns:
self – The updated object.
- Return type:
- set_predict_request(*, alphas='$UNCHANGED$', output_type='$UNCHANGED$')#
Configure whether metadata should be requested to be passed to the predict method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to predict if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to predict.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
- Returns:
self – The updated object.
- Return type:
- set_score_request(*, sample_weight='$UNCHANGED$')#
Configure whether metadata should be requested to be passed to the score method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- class tabicl.TabICLForecaster(max_context_length=4096, temporal_features=None, point_estimate='mean', tabicl_config=None)#
TabICL-based time series forecasting pipeline.
This pipeline uses TabICL for zero-shot time series forecasting. It handles the entire prediction workflow:
Data preprocessing (missing value handling, context slicing)
Feature engineering (calendar features, seasonal detection)
Prediction with probabilistic quantile estimates
- Parameters:
max_context_length (int, default=4096) – Maximum number of historical timesteps to use as context. The pipeline automatically slices to the last max_context_length timesteps if the historical data is longer.
temporal_features (list[str | TimeTransform] | None, default=None) –
Feature transforms to apply to the time series. Each element can be a string name or a TimeTransform instance. If None, defaults to ["index", "datetime", "periodic"].
Available string names:
"index" (IndexEncoder): sequential position features.
"datetime" (DatetimeEncoder): calendar features (day-of-week, month, etc.).
"fourier" (FourierEncoder): Fourier basis features for periodicity.
"periodic" (AutoPeriodicEncoder): automatically detected seasonal features via FFT.
Strings and TimeTransform instances can be mixed freely, which is useful when a built-in transform needs non-default parameters:
    temporal_features=["index", FourierEncoder(period=7)]
point_estimate ({"mean", "median"}, default="mean") – Method to select the point prediction from TabICL output.
tabicl_config (dict | None, default=None) – Configuration for TabICLRegressor initialization. If None, defaults to empty dict (uses default settings).
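To give a feel for what "Fourier basis features for periodicity" means, the sketch below builds sine/cosine columns for a fixed period from sequential positions. It illustrates the general technique only; it is not the implementation of FourierEncoder, and the function name fourier_features and the n_harmonics argument are hypothetical.

```python
import math


def fourier_features(positions, period, n_harmonics=1):
    """Return one [sin, cos, ...] row per position for the given period."""
    rows = []
    for t in positions:
        row = []
        for k in range(1, n_harmonics + 1):
            angle = 2.0 * math.pi * k * t / period
            row.extend([math.sin(angle), math.cos(angle)])
        rows.append(row)
    return rows


# Two weeks of daily positions with a weekly period: day 0 and day 7 get the same features.
weekly = fourier_features(range(14), period=7)
```

Features of this kind let a tabular model see that two timesteps one period apart occupy the same phase of the cycle, even though their raw index values differ.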
Notes
For time series with irregular timestamps, consider opting out of "periodic" (AutoPeriodicEncoder).
- predict(context_tsdf, future_tsdf, quantiles=None)#
Generate forecasts using TimeSeriesDataFrame objects.
This is the core prediction method. For a simpler pandas DataFrame interface, see predict_df.
- Parameters:
context_tsdf (TimeSeriesDataFrame) – Historical time series data used as context for prediction. Must contain a target column with historical values. May contain additional covariate columns.
future_tsdf (TimeSeriesDataFrame) – Future timestamps for which to generate predictions. Should contain the same covariate columns as context_tsdf. The target column should be NaN (will be filled with predictions).
quantiles (list[float] | None, default=None) – Quantiles to predict for probabilistic forecasting. Defaults to [0.1, 0.2, ..., 0.9].
- Returns:
Predictions containing:
target: Point predictions (mean by default)
One column per quantile (e.g., 0.1, 0.9) for prediction intervals
- Return type:
Notes
Only covariates present in both context_tsdf and future_tsdf will be used.
Context is automatically sliced to max_context_length if longer.
Missing values in context are handled automatically.
- predict_df(context_df, future_df=None, prediction_length=None, quantiles=None)#
Generate forecasts from pandas DataFrames.
This is the recommended user-facing API. It accepts standard pandas DataFrames and returns predictions as a DataFrame.
- Parameters:
context_df (pd.DataFrame) –
Historical time series data. Required columns:
timestamp: Timestamps for each observation (datetime).
target: Historical values to forecast from (numeric).
item_id (optional): Identifier for multiple time series. If omitted, assumes a single time series.
Additional columns are treated as known covariates.
future_df (pd.DataFrame or None, default=None) –
Future timestamps for prediction. Required columns:
timestamp: Future timestamps to forecast (datetime).
item_id (optional): Must match item_ids in context_df.
Covariate columns matching those in context_df.
Mutually exclusive with prediction_length. Use this when you have known future covariate values or irregular timestamps.
prediction_length (int or None, default=None) – Number of time steps to forecast into the future. Mutually exclusive with future_df. Use this for simple forecasting when you don’t have future covariates.
quantiles (list[float] | None, default=None) – Quantiles to predict for uncertainty estimation. Defaults to [0.1, 0.2, ..., 0.9].
- Returns:
Forecasts indexed by (item_id, timestamp) containing:
target: Point predictions (mean by default)
One column per quantile for prediction intervals
- Return type:
pd.DataFrame
- Raises:
ValueError – If both or neither of future_df and prediction_length are provided.
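The "exactly one of future_df or prediction_length" contract can be illustrated with a small stand-alone sketch. Everything here is hypothetical (the function resolve_future_timestamps, its arguments, and the assumption of a fixed step between timestamps); it only mirrors the documented behaviour of raising ValueError when both or neither option is given.

```python
from datetime import datetime, timedelta


def resolve_future_timestamps(last_timestamp, step, future_timestamps=None, prediction_length=None):
    """Return the timestamps to forecast, given exactly one of the two options."""
    if (future_timestamps is None) == (prediction_length is None):
        # Both given or both missing: the request is ambiguous.
        raise ValueError("Provide exactly one of future_timestamps or prediction_length.")
    if future_timestamps is not None:
        return list(future_timestamps)
    # Regular continuation of the series: one step per forecast horizon.
    return [last_timestamp + step * (i + 1) for i in range(prediction_length)]


horizon = resolve_future_timestamps(
    datetime(2019, 1, 3), timedelta(days=1), prediction_length=2
)
```

Explicit future timestamps are the natural choice for irregular series, since a fixed step cannot be extrapolated from them.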
- class tabicl.TabICLUnsupervised(n_estimators=8, categorical_features=None, max_categories=10, batch_size=8, random_state=42, device=None, estimator_params=None)#
Unsupervised learning with TabICL.
Supports three unsupervised tasks:
- Imputation: Fill NaN values by conditioning on observed features.
- Outlier detection: Score samples by their estimated joint density.
- Synthetic data generation: Autoregressive sampling from the learned density.
Estimates the joint density by decomposing it via the chain rule of probability:
P(X) = P(X_1) * P(X_2 | X_1) * ... * P(X_d | X_1, ..., X_{d-1})
Each conditional P(X_k | X_{<k}) is predicted by a TabICL classifier (for categorical features) or regressor (for numerical features). Multiple random feature orderings (permutations) are averaged to reduce the dependence on any single ordering.
- Parameters:
n_estimators (int, default=8) – Number of ensemble estimators per conditional prediction.
categorical_features (list[int] or None, default=None) – Indices of categorical features. If None, auto-detected from data based on max_categories.
max_categories (int, default=10) – Maximum unique values for auto-detection of categorical features.
batch_size (int or None, default=8) – Batch size for inner estimator inference.
random_state (int or None, default=42) – Random seed for reproducibility.
device (str or None, default=None) – Device for inference. None auto-selects CUDA or CPU.
estimator_params (dict or None, default=None) – Additional keyword arguments forwarded to the inner TabICLClassifier and TabICLRegressor (e.g. norm_methods, outlier_threshold).
- X_#
Copy of the training data, used as conditioning context for all predictions.
- Type:
np.ndarray of shape (n_samples, n_features)
- categorical_features_#
Indices of categorical features (user-supplied or auto-detected).
- numerical_features_#
Indices of numerical features (complement of categorical_features_).
- categories_#
Mapping from categorical feature index to its sorted unique values.
- _clf_model#
Shared classifier model weights (loaded once in fit()).
- Type:
torch.nn.Module or None
- _reg_model#
Shared regressor model weights (loaded once in fit()).
- Type:
torch.nn.Module or None
Examples
>>> import numpy as np
>>> from tabicl import TabICLUnsupervised
>>> X = np.random.standard_normal((50, 3))
>>> model = TabICLUnsupervised(n_estimators=4, device="cpu")
>>> model.fit(X)
>>> scores = model.score_samples(X, n_permutations=4)
>>> X_synth = model.generate(n_samples=10)
- fit(X, y=None)#
Store training data, detect categorical features, and load shared models.
The raw training data is stored in self.X_ and used as conditioning context for all downstream predictions. Shared model weights are loaded once here and injected into per-column estimators to avoid redundant torch.load() calls.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data. May contain NaN values.
y (ignored) – Not used, present for API consistency.
- Returns:
self
- Return type:
- generate(n_samples=100, temperature=1.0)#
Generate synthetic data by autoregressive sampling.
Features are sampled in the original feature order: each feature x_k is sampled from P(x_k | x_{<k}).
- impute(X, temperature=1e-8, n_iterations=2)#
Fill NaN values by conditioning on all other features.
For numerical features, predictions are drawn from the quantile-based ICDF: \(x \sim F^{-1}(u)\) where \(u\) is temperature-scaled around 0.5. For categorical features, classes are sampled from the temperature-scaled predictive distribution.
Multiple iterations (n_iterations > 1) refine imputed values iteratively: each pass conditions on the current best estimates of all other columns.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Data with NaN values to impute.
temperature (float, default=1e-8) – Temperature for sampling. Near 0 gives deterministic (median/mode), 1.0 gives full distribution sampling.
n_iterations (int, default=2) – Number of iterative refinement passes. With n_iterations=1 the method performs a single left-to-right sweep; higher values cycle through the missing columns repeatedly, each time conditioning on the most recently imputed values of other columns.
- Returns:
Data with NaN values filled.
- Return type:
np.ndarray of shape (n_samples, n_features)
- score_samples(X, n_permutations=4)#
Compute outlier scores via chain-rule log-probability.
Estimates the joint density by factoring it as a product of conditionals:
score(x) = exp((1/K) Sigma_k log P(x_{pi(k)} | x_{pi(<k)}))
where pi is a random permutation and averaging is over K permutations. Higher scores indicate more normal data points; lower scores indicate outliers.
For numerical features, P(x_k | ...) is the density from the quantile-based distribution (log_prob on the learned ICDF). For categorical features, it is the predicted class probability.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Data to score.
n_permutations (int, default=4) – Number of random feature orderings to average over.
- Returns:
Outlier scores. Higher = more normal, lower = more outlier.
- Return type:
np.ndarray of shape (n_samples,)
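One way to read the scoring rule above is: for each permutation, sum the log conditional probabilities along that feature ordering; average those sums over the K permutations; then exponentiate. The sketch below implements that reading for a single sample. It is an interpretation, not the library's code: conditional_prob is a hypothetical callable standing in for the per-feature TabICL classifier or regressor, and the toy version simply returns a constant.

```python
import math


def chain_rule_score(x, conditional_prob, permutations):
    """Score one sample x by averaging chain-rule log-probabilities over orderings."""
    log_probs = []
    for order in permutations:
        total = 0.0
        seen = []
        for feature in order:
            # Condition on the features that come earlier in this ordering.
            context = [(j, x[j]) for j in seen]
            total += math.log(conditional_prob(feature, x[feature], context))
            seen.append(feature)
        log_probs.append(total)
    return math.exp(sum(log_probs) / len(log_probs))


def toy_conditional(feature, value, context):
    """Placeholder model: every conditional has probability 0.5."""
    return 0.5


score = chain_rule_score([1.0, 2.0], toy_conditional, [(0, 1), (1, 0)])
```

Averaging in log space makes the result a geometric mean of the per-permutation density estimates, so one ordering with a very low estimate pulls the score down sharply.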
Inference configuration#
- class tabicl.InferenceConfig(COL_CONFIG=None, ROW_CONFIG=None, ICL_CONFIG=None)#
Configuration class for inference.
Forecasting utilities#
- class tabicl.forecast.TimeSeriesDataFrame(data, static_features=None, id_column=None, timestamp_column=None, num_cpus=-1, *args, **kwargs)#
A collection of univariate time series, where each row is identified by an (item_id, timestamp) pair.
For example, a time series dataframe could represent the daily sales of a collection of products, where each item_id corresponds to a product and timestamp corresponds to the day of the record.
- Parameters:
data (pd.DataFrame, str, pathlib.Path or Iterable) –
Time series data to construct a TimeSeriesDataFrame. The class currently supports four input formats.
1. Time series data in a pandas DataFrame format without multi-index. For example:
       item_id  timestamp  target
    0        0 2019-01-01       0
    1        0 2019-01-02       1
    2        0 2019-01-03       2
    3        1 2019-01-01       3
    4        1 2019-01-02       4
    5        1 2019-01-03       5
    6        2 2019-01-01       6
    7        2 2019-01-02       7
    8        2 2019-01-03       8
You can also use from_data_frame() for loading data in such format.
2. Path to a data file in CSV or Parquet format. The file must contain columns item_id and timestamp, as well as columns with time series values. This is similar to Option 1 above (pandas DataFrame format without multi-index). Both remote (e.g., S3) and local paths are accepted. You can also use from_path() for loading data in such format.
3. Time series data in pandas DataFrame format with multi-index on item_id and timestamp. For example:
                        target
    item_id timestamp
    0       2019-01-01       0
            2019-01-02       1
            2019-01-03       2
    1       2019-01-01       3
            2019-01-02       4
            2019-01-03       5
    2       2019-01-01       6
            2019-01-02       7
            2019-01-03       8
4. Time series data in Iterable format. For example:
    iterable_dataset = [
        {"target": [0, 1, 2], "start": pd.Period("01-01-2019", freq='D')},
        {"target": [3, 4, 5], "start": pd.Period("01-01-2019", freq='D')},
        {"target": [6, 7, 8], "start": pd.Period("01-01-2019", freq='D')}
    ]
You can also use from_iterable_dataset() for loading data in such format.
static_features (pd.DataFrame, str or pathlib.Path, optional) –
An optional dataframe describing the metadata of each individual time series that does not change with time. Can take real-valued or categorical values. For example, if TimeSeriesDataFrame contains sales of various products, static features may refer to time-independent features like color or brand.
The index of static_features must contain a single entry for each item present in the respective TimeSeriesDataFrame. For example, the following TimeSeriesDataFrame:
                        target
    item_id timestamp
    A       2019-01-01       0
            2019-01-02       1
            2019-01-03       2
    B       2019-01-01       3
            2019-01-02       4
            2019-01-03       5
is compatible with the following static_features:
             feat_1 feat_2
    item_id
    A           2.0    bar
    B           5.0    foo
TimeSeriesDataFrame will ensure consistency of static features during serialization/deserialization, copy and slice operations.
If static_features are provided during fit, the TimeSeriesPredictor expects the same metadata to be available during prediction time.
id_column (str, optional) – Name of the item_id column, if it’s different from the default. This argument is only used when constructing a TimeSeriesDataFrame using format 1 (DataFrame without multi-index) or 2 (path to a file).
timestamp_column (str, optional) – Name of the timestamp column, if it’s different from the default. This argument is only used when constructing a TimeSeriesDataFrame using format 1 (DataFrame without multi-index) or 2 (path to a file).
num_cpus (int, default = -1) – Number of CPU cores used to process the iterable dataset in parallel. Set to -1 to use all cores. This argument is only used when constructing a TimeSeriesDataFrame using format 4 (iterable dataset).
- assign(**kwargs)#
Assign new columns to the time series dataframe. See pandas.DataFrame.assign() for details.
- convert_frequency(freq, agg_numeric='mean', agg_categorical='first', num_cpus=-1, chunk_size=100, **kwargs)#
Convert each time series in the dataframe to the given frequency.
This method is useful for two purposes:
Converting an irregularly-sampled time series to a regular time index.
Aggregating time series data by downsampling (e.g., convert daily sales into weekly sales)
Standard df.groupby(...).resample(...) can be extremely slow for large datasets, so we parallelize this operation across multiple CPU cores.
- Parameters:
freq (Union[str, pd.DateOffset]) – Frequency to which the data should be converted. See pandas frequency aliases for supported values.
agg_numeric ({"max", "min", "sum", "mean", "median", "first", "last"}, default = "mean") – Aggregation method applied to numeric columns.
agg_categorical ({"first", "last"}, default = "first") – Aggregation method applied to categorical columns.
num_cpus (int, default = -1) – Number of CPU cores used when resampling in parallel. Set to -1 to use all cores.
chunk_size (int, default = 100) – Number of time series in a chunk assigned to each parallel worker.
**kwargs – Additional keyword arguments that will be passed to pandas.DataFrameGroupBy.resample.
- Returns:
ts_df – A new time series dataframe with time series resampled at the new frequency. Output may contain missing values represented by NaN if the original data does not have information for the given period.
- Return type:
Examples
Convert irregularly-sampled time series data to a regular index
>>> ts_df
                    target
item_id timestamp
0       2019-01-01     NaN
        2019-01-03     1.0
        2019-01-06     2.0
        2019-01-07     NaN
1       2019-02-04     3.0
        2019-02-07     4.0
>>> ts_df.convert_frequency(freq="D")
                    target
item_id timestamp
0       2019-01-01     NaN
        2019-01-02     NaN
        2019-01-03     1.0
        2019-01-04     NaN
        2019-01-05     NaN
        2019-01-06     2.0
        2019-01-07     NaN
1       2019-02-04     3.0
        2019-02-05     NaN
        2019-02-06     NaN
        2019-02-07     4.0
Downsample quarterly data to yearly frequency
>>> ts_df
                    target
item_id timestamp
0       2020-03-31     1.0
        2020-06-30     2.0
        2020-09-30     3.0
        2020-12-31     4.0
        2021-03-31     5.0
        2021-06-30     6.0
        2021-09-30     7.0
        2021-12-31     8.0
>>> ts_df.convert_frequency("YE")
                    target
item_id timestamp
0       2020-12-31     2.5
        2021-12-31     6.5
>>> ts_df.convert_frequency("YE", agg_numeric="sum")
                    target
item_id timestamp
0       2020-12-31    10.0
        2021-12-31    26.0
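The yearly figures in the downsampling example can be checked by hand. The sketch below reproduces them for a single item with the standard library only; it is a conceptual illustration of what agg_numeric does within each resampling bucket, not the parallel resampling that convert_frequency performs, and the helper name downsample_by_year is hypothetical.

```python
from datetime import date


def downsample_by_year(timestamps, values, agg="mean"):
    """Group values by calendar year and aggregate each group with mean or sum."""
    groups = {}
    for ts, value in zip(timestamps, values):
        groups.setdefault(ts.year, []).append(value)
    if agg == "sum":
        return {year: sum(vals) for year, vals in sorted(groups.items())}
    return {year: sum(vals) / len(vals) for year, vals in sorted(groups.items())}


quarter_ends = [
    date(2020, 3, 31), date(2020, 6, 30), date(2020, 9, 30), date(2020, 12, 31),
    date(2021, 3, 31), date(2021, 6, 30), date(2021, 9, 30), date(2021, 12, 31),
]
targets = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]

yearly_mean = downsample_by_year(quarter_ends, targets)            # agg_numeric="mean"
yearly_sum = downsample_by_year(quarter_ends, targets, agg="sum")  # agg_numeric="sum"
```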
- copy(deep=True)#
Make a copy of the TimeSeriesDataFrame.
When deep=True (default), a new object will be created with a copy of the calling object’s data and indices. Modifications to the data or indices of the copy will not be reflected in the original object.
When deep=False, a new object will be created without copying the calling object’s data or index (only references to the data and index are copied). Any changes to the data of the original will be reflected in the shallow copy (and vice versa).
For more details, see pandas documentation.
- dropna(how='any')#
Drop rows containing NaNs.
- Parameters:
how ({"any", "all"}, default = "any") –
Determine if row or column is removed from TimeSeriesDataFrame, when we have at least one NaN or all NaN.
"any": If any NaN values are present, drop that row or column.
"all": If all values are NaN, drop that row or column.
- fill_missing_values(method='auto', value=0.0)#
Fill missing values represented by NaN.
Note
This method assumes that the index of the TimeSeriesDataFrame is sorted by [item_id, timestamp].
If the index is not sorted, this method will log a warning and may produce an incorrect result.
- Parameters:
method (str, default = "auto") –
Method used to impute missing values.
"auto" - first forward fill (to fill the in-between and trailing NaNs), then backward fill (to fill the leading NaNs)
"ffill" or "pad" - propagate last valid observation forward. Note: missing values at the start of the time series are not filled.
"bfill" or "backfill" - use next valid observation to fill gap. Note: this may result in information leakage; missing values at the end of the time series are not filled.
"constant" - replace NaNs with the given constant value.
"interpolate" - fill NaN values using linear interpolation. Note: this may result in information leakage.
value (float, default = 0.0) – Value used by the “constant” imputation method.
Examples
>>> ts_df
                    target
item_id timestamp
0       2019-01-01     NaN
        2019-01-02     NaN
        2019-01-03     1.0
        2019-01-04     NaN
        2019-01-05     NaN
        2019-01-06     2.0
        2019-01-07     NaN
1       2019-02-04     NaN
        2019-02-05     3.0
        2019-02-06     NaN
        2019-02-07     4.0
>>> ts_df.fill_missing_values(method="auto")
                    target
item_id timestamp
0       2019-01-01     1.0
        2019-01-02     1.0
        2019-01-03     1.0
        2019-01-04     1.0
        2019-01-05     1.0
        2019-01-06     2.0
        2019-01-07     2.0
1       2019-02-04     3.0
        2019-02-05     3.0
        2019-02-06     3.0
        2019-02-07     4.0
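The "auto" method is a forward fill followed by a backward fill, applied to each item separately. The pure-Python sketch below reproduces the example output for both items. It is a conceptual illustration rather than the method's implementation, and the helper names ffill, bfill and fill_auto are illustrative.

```python
import math


def ffill(values):
    """Propagate the last valid observation forward; leading NaNs stay NaN."""
    out, last = [], math.nan
    for v in values:
        if not math.isnan(v):
            last = v
        out.append(last)
    return out


def bfill(values):
    """Backward fill, implemented as a forward fill on the reversed series."""
    return ffill(values[::-1])[::-1]


def fill_auto(values):
    """Forward fill in-between and trailing NaNs, then backward fill leading NaNs."""
    return bfill(ffill(values))


nan = math.nan
item_0 = fill_auto([nan, nan, 1.0, nan, nan, 2.0, nan])
item_1 = fill_auto([nan, 3.0, nan, 4.0])
```

Filling each item on its own is what keeps values from one series leaking into the next, which is also why the method expects the index to be sorted by [item_id, timestamp].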
- property freq#
Inferred pandas-compatible frequency of the timestamps in the dataframe.
Computed using a random subset of the time series for speed. This may sometimes result in incorrectly inferred values. For reliable results, use infer_frequency().
- classmethod from_data_frame(df, id_column=None, timestamp_column=None, static_features_df=None)#
Construct a TimeSeriesDataFrame from a pandas DataFrame.
- Parameters:
df (pd.DataFrame) –
A pd.DataFrame with ‘item_id’ and ‘timestamp’ as columns. For example:
       item_id  timestamp  target
    0        0 2019-01-01       0
    1        0 2019-01-02       1
    2        0 2019-01-03       2
    3        1 2019-01-01       3
    4        1 2019-01-02       4
    5        1 2019-01-03       5
    6        2 2019-01-01       6
    7        2 2019-01-02       7
    8        2 2019-01-03       8
id_column (str, optional) – Name of the ‘item_id’ column if column name is different
timestamp_column (str, optional) – Name of the ‘timestamp’ column if column name is different
static_features_df (pd.DataFrame, optional) –
A pd.DataFrame with ‘item_id’ column that contains the static features for each time series. For example:
       item_id feat_1  feat_2
    0        0    foo     0.5
    1        1    foo     2.2
    2        2    bar     0.1
- Returns:
ts_df – A dataframe in TimeSeriesDataFrame format.
- Return type:
- classmethod from_iterable_dataset(iterable_dataset, num_cpus=-1)#
Construct a TimeSeriesDataFrame from an Iterable of dictionaries, each of which represents a single time series.
This function also offers compatibility with GluonTS ListDataset format.
- Parameters:
iterable_dataset (Iterable) –
An iterator over dictionaries, each with a target field specifying the value of the (univariate) time series, and a start field with the starting time as a pandas Period. Example:
    iterable_dataset = [
        {"target": [0, 1, 2], "start": pd.Period("01-01-2019", freq='D')},
        {"target": [3, 4, 5], "start": pd.Period("01-01-2019", freq='D')},
        {"target": [6, 7, 8], "start": pd.Period("01-01-2019", freq='D')}
    ]
num_cpus (int, default = -1) – Number of CPU cores used to process the iterable dataset in parallel. Set to -1 to use all cores.
- Returns:
ts_df – A dataframe in TimeSeriesDataFrame format.
- Return type:
TimeSeriesDataFrame
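The record layout described above can be illustrated without pandas. The following is a minimal sketch, not the library's implementation: `expand_to_long_format` is a hypothetical helper, plain `datetime.date` values stand in for `pd.Period`, a daily step is assumed, and item IDs are assumed to follow the order of the input records.

```python
from datetime import date, timedelta


def expand_to_long_format(iterable_dataset, step=timedelta(days=1)):
    """Expand {"target", "start"} records into (item_id, timestamp, target) rows."""
    rows = []
    # Assumption: each record becomes one item, numbered by its position.
    for item_id, record in enumerate(iterable_dataset):
        for offset, value in enumerate(record["target"]):
            rows.append((item_id, record["start"] + offset * step, value))
    return rows


dataset = [
    {"target": [0, 1, 2], "start": date(2019, 1, 1)},
    {"target": [3, 4, 5], "start": date(2019, 1, 1)},
]
rows = expand_to_long_format(dataset)
```

Each resulting row corresponds to one line of the long-format table that a TimeSeriesDataFrame holds, with one (item_id, timestamp) pair per observation.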
- classmethod from_path(path, id_column=None, timestamp_column=None, static_features_path=None)#
Construct a TimeSeriesDataFrame from a CSV or Parquet file.
- Parameters:
path (str or pathlib.Path) –
Path to a local or remote (e.g., S3) file containing the time series data in CSV or Parquet format. Example file contents:
item_id,timestamp,target
0,2019-01-01,0
0,2019-01-02,1
0,2019-01-03,2
1,2019-01-01,3
1,2019-01-02,4
1,2019-01-03,5
2,2019-01-01,6
2,2019-01-02,7
2,2019-01-03,8
id_column (str, optional) – Name of the ‘item_id’ column if column name is different
timestamp_column (str, optional) – Name of the ‘timestamp’ column if column name is different
static_features_path (str or pathlib.Path, optional) –
Path to a local or remote (e.g., S3) file containing static features in CSV or Parquet format. Example file contents:
item_id,feat_1,feat_2
0,foo,0.5
1,foo,2.2
2,bar,0.1
- Returns:
ts_df – A dataframe in TimeSeriesDataFrame format.
- Return type:
TimeSeriesDataFrame
- classmethod from_pickle(filepath_or_buffer)#
Convenience method to read pickled time series dataframes. If the read pickle file refers to a plain pandas DataFrame, it will be cast to a TimeSeriesDataFrame.
- Parameters:
filepath_or_buffer (Any) – Filename provided as a string or an IOBuffer containing the pickled object.
- Returns:
ts_df – The pickled time series dataframe.
- Return type:
TimeSeriesDataFrame
- get_indptr()#
[Advanced] Get a numpy array of shape [num_items + 1] that points to the start and end of each time series.
This method assumes that the TimeSeriesDataFrame is sorted by [item_id, timestamp].
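To make the meaning of the pointer array concrete, here is a minimal pure-Python sketch. `compute_indptr` is a hypothetical helper that works on a plain list of item IDs already sorted by item; it is not the library's implementation, which returns a numpy array.

```python
def compute_indptr(sorted_item_ids):
    """Return boundaries such that item i occupies rows indptr[i]:indptr[i + 1]."""
    if not sorted_item_ids:
        return [0]
    indptr = [0]
    for pos in range(1, len(sorted_item_ids)):
        # A change of item ID marks the start of the next time series.
        if sorted_item_ids[pos] != sorted_item_ids[pos - 1]:
            indptr.append(pos)
    indptr.append(len(sorted_item_ids))
    return indptr


item_ids = ["A", "A", "A", "B", "B", "C"]
indptr = compute_indptr(item_ids)
```

With this layout, the rows of the second series are `item_ids[indptr[1]:indptr[2]]`, which is why the data must be sorted by [item_id, timestamp] first.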
- get_model_inputs_for_scoring(prediction_length, known_covariates_names=None)#
Prepare model inputs necessary to predict the last prediction_length time steps of each time series in the dataset.
- Parameters:
- Returns:
past_data (TimeSeriesDataFrame) – Data, where the last prediction_length time steps have been removed from the end of each time series.
known_covariates (TimeSeriesDataFrame or None) – If known_covariates_names was provided, dataframe with the values of the known covariates during the forecast horizon. Otherwise, None.
- infer_frequency(num_items=None, raise_if_irregular=False)#
Infer the time series frequency based on the timestamps of the observations.
- Parameters:
num_items (int or None, default = None) –
Number of items (individual time series) randomly selected to infer the frequency. Lower values speed up the method, but increase the chance that some items with invalid frequency are missed by subsampling.
If set to None, all items will be used for inferring the frequency.
raise_if_irregular (bool, default = False) – If True, an exception will be raised if some items have an irregular frequency, or if different items have different frequencies.
- Returns:
freq – If all time series have a regular frequency, returns a pandas-compatible frequency alias.
If some items have an irregular frequency or if different items have different frequencies, returns the string IRREG.
- Return type:
str
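The regular-versus-irregular decision can be sketched as follows. This is an illustrative simplification, not the library's code: `infer_step` is a hypothetical helper that returns a `timedelta` instead of a pandas frequency alias, and it treats any disagreement between consecutive gaps as irregular.

```python
from datetime import datetime, timedelta


def infer_step(timestamps_by_item):
    """Return the single step shared by all items, or "IRREG" otherwise."""
    steps = set()
    for timestamps in timestamps_by_item.values():
        # Collect the gap between each pair of consecutive observations.
        for earlier, later in zip(timestamps, timestamps[1:]):
            steps.add(later - earlier)
    if len(steps) == 1:
        return steps.pop()
    return "IRREG"


regular = {
    0: [datetime(2019, 1, 1), datetime(2019, 1, 2), datetime(2019, 1, 3)],
    1: [datetime(2019, 2, 1), datetime(2019, 2, 2)],
}
irregular = {
    0: [datetime(2019, 1, 1), datetime(2019, 1, 2), datetime(2019, 1, 5)],
}
regular_step = infer_step(regular)
irregular_step = infer_step(irregular)
```

Checking only a subset of the items, as num_items allows, speeds this up but can miss an item whose gaps differ from the rest.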
- property item_ids#
List of unique time series IDs contained in the data set.
- property num_items#
Number of items (time series) in the data set.
- num_timesteps_per_item()#
Number of observations in each time series in the dataframe.
Returns a pandas.Series with item_id as index and number of observations per item as values.
- slice_by_time(start_time, end_time)#
Select a subsequence from each time series between start (inclusive) and end (exclusive) timestamps.
- Parameters:
start_time (pd.Timestamp) – Start time (inclusive) of the slice for each time series.
end_time (pd.Timestamp) – End time (exclusive) of the slice for each time series.
- Returns:
ts_df – A new time series dataframe containing entries of the original time series between start and end timestamps.
- Return type:
TimeSeriesDataFrame
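The half-open interval semantics (start inclusive, end exclusive) can be shown with a minimal sketch. `slice_by_time_sketch` is a hypothetical stand-in that operates on a plain dict of (timestamp, value) lists rather than on a TimeSeriesDataFrame.

```python
from datetime import date


def slice_by_time_sketch(series_by_item, start_time, end_time):
    """Keep observations with start_time <= timestamp < end_time for every item."""
    return {
        item_id: [(ts, value) for ts, value in series if start_time <= ts < end_time]
        for item_id, series in series_by_item.items()
    }


data = {
    0: [(date(2019, 1, 1), 0), (date(2019, 1, 2), 1), (date(2019, 1, 3), 2)],
    1: [(date(2019, 1, 2), 3), (date(2019, 1, 3), 4), (date(2019, 1, 4), 5)],
}
sliced = slice_by_time_sketch(data, date(2019, 1, 2), date(2019, 1, 4))
```

Because the end is exclusive, consecutive calls with touching boundaries partition the data without duplicating any observation.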
- slice_by_timestep(start_index=None, end_index=None)#
Select a subsequence from each time series between start (inclusive) and end (exclusive) indices.
This operation is equivalent to selecting a slice [start_index : end_index] from each time series, and then combining these slices into a new TimeSeriesDataFrame. See examples below.
It is recommended to sort the index with ts_df.sort_index() before calling this method to take advantage of a fast optimized algorithm.
- Parameters:
start_index (int or None) – Start index (inclusive) of the slice for each time series. Negative values are counted from the end of each time series. When set to None, the slice starts from the beginning of each time series.
end_index (int or None) – End index (exclusive) of the slice for each time series. Negative values are counted from the end of each time series. When set to None, the slice includes the end of each time series.
- Returns:
ts_df (TimeSeriesDataFrame) – A new time series dataframe containing entries of the original time series between start and end indices.
Examples
>>> ts_df
                    target
item_id timestamp
0       2019-01-01       0
        2019-01-02       1
        2019-01-03       2
1       2019-01-02       3
        2019-01-03       4
        2019-01-04       5
2       2019-01-03       6
        2019-01-04       7
        2019-01-05       8
Select the first entry of each time series
>>> ts_df.slice_by_timestep(0, 1)
                    target
item_id timestamp
0       2019-01-01       0
1       2019-01-02       3
2       2019-01-03       6
Select the last 2 entries of each time series
>>> ts_df.slice_by_timestep(-2, None)
                    target
item_id timestamp
0       2019-01-02       1
        2019-01-03       2
1       2019-01-03       4
        2019-01-04       5
2       2019-01-04       7
        2019-01-05       8
Select all except the last entry of each time series
>>> ts_df.slice_by_timestep(None, -1)
                    target
item_id timestamp
0       2019-01-01       0
        2019-01-02       1
1       2019-01-02       3
        2019-01-03       4
2       2019-01-03       6
        2019-01-04       7
Copy the entire dataframe
>>> ts_df.slice_by_timestep(None, None)
                    target
item_id timestamp
0       2019-01-01       0
        2019-01-02       1
        2019-01-03       2
1       2019-01-02       3
        2019-01-03       4
        2019-01-04       5
2       2019-01-03       6
        2019-01-04       7
        2019-01-05       8
- sort_index(*args, **kwargs)#
Sort object by labels (along an axis).
Returns a new DataFrame sorted by label if inplace argument is False, otherwise updates the original DataFrame and returns None.
- Parameters:
axis ({0 or 'index', 1 or 'columns'}, default 0) – The axis along which to sort. The value 0 identifies the rows, and 1 identifies the columns.
level (int or level name or list of ints or list of level names) – If not None, sort on values in specified index level(s).
ascending (bool or list-like of bools, default True) – Sort ascending vs. descending. When the index is a MultiIndex the sort direction can be controlled for each level individually.
inplace (bool, default False) – Whether to modify the DataFrame rather than creating a new one.
kind ({'quicksort', 'mergesort', 'heapsort', 'stable'}, default 'quicksort') – Choice of sorting algorithm. See also numpy.sort() for more information. mergesort and stable are the only stable algorithms. For DataFrames, this option is only applied when sorting on a single column or label.
na_position ({'first', 'last'}, default 'last') – Puts NaNs at the beginning if first; last puts NaNs at the end. Not implemented for MultiIndex.
sort_remaining (bool, default True) – If True and sorting by level and index is multilevel, sort by other levels too (in order) after sorting by specified level.
ignore_index (bool, default False) – If True, the resulting axis will be labeled 0, 1, …, n - 1.
key (callable, optional) – If not None, apply the key function to the index values before sorting. This is similar to the key argument in the builtin sorted() function, with the notable difference that this key function should be vectorized. It should expect an Index and return an Index of the same shape. For MultiIndex inputs, the key is applied per level.
- Returns:
The original DataFrame sorted by the labels or None if inplace=True.
- Return type:
DataFrame or None
See also
Series.sort_index: Sort Series by the index.
DataFrame.sort_values: Sort DataFrame by the value.
Series.sort_values: Sort Series by the value.
Examples
>>> df = pd.DataFrame([1, 2, 3, 4, 5], index=[100, 29, 234, 1, 150],
...                   columns=['A'])
>>> df.sort_index()
     A
1    4
29   2
100  1
150  5
234  3
By default, it sorts in ascending order; to sort in descending order, use ascending=False
>>> df.sort_index(ascending=False)
     A
234  3
150  5
100  1
29   2
1    4
A key function can be specified which is applied to the index before sorting. For a MultiIndex this is applied to each level separately.
>>> df = pd.DataFrame({"a": [1, 2, 3, 4]}, index=['A', 'b', 'C', 'd'])
>>> df.sort_index(key=lambda x: x.str.lower())
   a
A  1
b  2
C  3
d  4
- split_by_time(cutoff_time)#
Split the dataframe into two TimeSeriesDataFrames, before and after a given cutoff_time.
- Parameters:
cutoff_time (pd.Timestamp) – The time to split the current dataframe into two dataframes.
- Returns:
data_before (TimeSeriesDataFrame) – Data frame containing time series before the cutoff_time (excluding cutoff_time).
data_after (TimeSeriesDataFrame) – Data frame containing time series after the cutoff_time (including cutoff_time).
- to_data_frame()#
Convert TimeSeriesDataFrame to a pandas.DataFrame
- train_test_split(prediction_length, end_index=None, suffix=None)#
Generate a train/test split from the given dataset.
This method can be used to generate splits for multi-window backtesting.
Note
This method automatically sorts the TimeSeriesDataFrame by [item_id, timestamp].
- Parameters:
prediction_length (int) – Number of time steps in a single evaluation window.
end_index (int, optional) – If given, all time series will be shortened up to end_index before the train/test splitting. In other words, test data will include the slice [:end_index] of each time series, and train data will include the slice [:end_index - prediction_length].
suffix (str, optional) – Suffix appended to all entries in the item_id index level.
- Returns:
train_data (TimeSeriesDataFrame) – Train portion of the data. Contains the slice [:-prediction_length] of each time series in test_data.
test_data (TimeSeriesDataFrame) – Test portion of the data. Contains the slice [:end_index] of each time series in the original dataset.
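The slicing rule stated above can be sketched on plain Python lists. `train_test_split_sketch` is a hypothetical helper, not the library method: it ignores timestamps, sorting, and the suffix argument, and only reproduces the two slices.

```python
def train_test_split_sketch(series_by_item, prediction_length, end_index=None):
    """Test data is values[:end_index]; train data drops its last window."""
    if prediction_length <= 0:
        raise ValueError("prediction_length must be positive")
    train, test = {}, {}
    for item_id, values in series_by_item.items():
        test_values = values[:end_index]  # end_index=None keeps the full series
        test[item_id] = test_values
        train[item_id] = test_values[:-prediction_length]
    return train, test


series = {"A": list(range(10))}
# Most recent evaluation window.
train_1, test_1 = train_test_split_sketch(series, prediction_length=2)
# An earlier window, obtained by shortening every series first.
train_2, test_2 = train_test_split_sketch(series, prediction_length=2, end_index=-2)
```

Repeating the call with end_index moved back by prediction_length each time yields the successive windows used in multi-window backtesting.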
- class tabicl.forecast.TimeTransformChain(transforms)#
Orchestrates feature generation for time series data.
Applies a sequence of TimeTransform instances to both training and test data, ensuring consistent feature columns across splits.
- Parameters:
transforms (list[TimeTransform]) – Transforms to apply sequentially.
- transform(train_tsdf, test_tsdf, target_column='target')#
Transform both training and test data with the configured transforms.
- Parameters:
train_tsdf (TimeSeriesDataFrame) – Training time series data.
test_tsdf (TimeSeriesDataFrame) – Test time series data.
target_column (str, default="target") – Name of the target column.
- Returns:
Transformed (train_tsdf, test_tsdf) with generated features.
- Return type:
- Raises:
ValueError – If target_column is not found in training data or if test data contains non-NaN target values.
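One plausible way to keep feature columns consistent across splits is to transform the concatenated data and split it back afterwards. The sketch below assumes that strategy; it is not taken from the library source. `apply_chain`, `add_position`, and `add_position_squared` are hypothetical names, and rows are plain dicts instead of a TimeSeriesDataFrame.

```python
import math


def add_position(rows):
    """Add a sequential position feature."""
    return [{**row, "position": i} for i, row in enumerate(rows)]


def add_position_squared(rows):
    """Depends on the column produced by the previous transform."""
    return [{**row, "position_sq": row["position"] ** 2} for row in rows]


def apply_chain(transforms, train_rows, test_rows, target_column="target"):
    if any(target_column not in row for row in train_rows):
        raise ValueError(f"{target_column!r} not found in training data")
    if any(not math.isnan(row.get(target_column, math.nan)) for row in test_rows):
        raise ValueError("test data must not contain known target values")
    # Assumption: transform train and test together, then split them again.
    combined = list(train_rows) + list(test_rows)
    for transform in transforms:
        combined = transform(combined)
    n_train = len(train_rows)
    return combined[:n_train], combined[n_train:]


train = [{"target": 1.0}, {"target": 2.0}, {"target": 3.0}]
test = [{"target": math.nan}, {"target": math.nan}]
train_out, test_out = apply_chain([add_position, add_position_squared], train, test)
```

Because the transforms run in order, a later transform can rely on columns created by an earlier one, and both splits end up with the same set of feature columns.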
- tabicl.forecast.plot_forecast(context_df, pred_df, test_df=None, item_ids=None, context_length=100, show_quantiles=True, show_points=False, linewidth=1.8)#
Plot forecast with historical context and optional ground truth.
Converts pandas DataFrames from the predict_df API to TimeSeriesDataFrame and delegates to plot_predictions.
- Parameters:
context_df (pd.DataFrame) – Historical data with columns timestamp, target, and optionally item_id.
pred_df (pd.DataFrame) – Predictions from predict_df, with multi-index (item_id, timestamp).
test_df (pd.DataFrame or None, default=None) – Optional ground truth for the forecast horizon.
item_ids (list or None, default=None) – Item IDs to plot. If None, plots all unique items.
context_length (int, default=100) – Number of historical points to show before the forecast.
show_quantiles (bool, default=True) – Whether to show the quantile prediction range.
show_points (bool, default=False) – Whether to show individual data points.
linewidth (float, default=1.8) – Line thickness for all plot lines.
Time feature transforms#
- class tabicl.forecast.transforms.TimeTransform#
Abstract base class for time series feature transforms.
Subclasses must implement generate to add feature columns to a DataFrame. Instances are callable via __call__, which delegates to generate.
- abstractmethod generate(df)#
Generate features for the given DataFrame.
- Parameters:
df (pd.DataFrame) – Input DataFrame to augment with features.
- Returns:
DataFrame with added feature columns.
- Return type:
pd.DataFrame
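The contract described for the base class (an abstract generate plus a __call__ that delegates to it) follows a common pattern that can be sketched with the standard library. `TimeTransformSketch` and `RunningIndex` are hypothetical names, and rows are plain dicts rather than a pandas DataFrame.

```python
from abc import ABC, abstractmethod


class TimeTransformSketch(ABC):
    """Illustrative base class: subclasses implement generate()."""

    @abstractmethod
    def generate(self, rows):
        """Return the rows with extra feature columns added."""

    def __call__(self, rows):
        # Calling the instance is shorthand for calling generate().
        return self.generate(rows)


class RunningIndex(TimeTransformSketch):
    """Adds a sequential integer index (0, 1, 2, ...) to each row."""

    def generate(self, rows):
        return [{**row, "running_index": i} for i, row in enumerate(rows)]


features = RunningIndex()([{"target": 1.0}, {"target": 2.0}])
```

Making instances callable lets a chain treat every transform as a plain function while each subclass only has to define generate.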
- class tabicl.forecast.transforms.IndexEncoder#
Transform that adds a running_index column.
Assigns a sequential integer index (0, 1, 2, …) to each row.
- class tabicl.forecast.transforms.DatetimeEncoder(components=None, seasonal_features=None)#
Transform that creates calendar-based temporal features.
Extracts calendar components (e.g., year) and encodes seasonal patterns (e.g., hour of day, day of week) as sin/cosine pairs using gluonts.time_feature.
- Parameters:
components (list[str] | None, default=None) – Calendar components to extract (e.g., ["year"]). If None, defaults to ["year"].
seasonal_features (dict[str, list[float]] | None, default=None) – Mapping of seasonal feature names to their natural periods. Each feature is encoded as sin/cosine pairs. If None, uses default temporal features (second, minute, hour, day, week, month).
- class tabicl.forecast.transforms.ExtendedDatetimeEncoder(components=None, additional_seasonal_features=None)#
Extended calendar feature transform with additional seasonal features.
Inherits from DatetimeEncoder and merges additional seasonal features with the defaults.
- class tabicl.forecast.transforms.FourierEncoder(periods, name_suffix=None)#
Transform that creates sin/cosine features for given periods.
For each period, adds sin_{period} and cos_{period} columns based on the row index position.
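As a rough illustration of such features, the sketch below assumes the usual formulation in which row i is mapped to the angle 2πi/period. The exact formula and column naming used by the library are not stated here, so `fourier_features` is a hypothetical helper and the name_suffix argument is ignored.

```python
import math


def fourier_features(n_rows, periods):
    """Return sin/cos columns for each period, keyed by column name."""
    features = {}
    for period in periods:
        # Assumption: the phase advances by one full cycle every `period` rows.
        angles = [2 * math.pi * i / period for i in range(n_rows)]
        features[f"sin_{period}"] = [math.sin(a) for a in angles]
        features[f"cos_{period}"] = [math.cos(a) for a in angles]
    return features


columns = fourier_features(n_rows=8, periods=[4])
```

Using both sine and cosine gives every position within a cycle a unique pair of values, which a single sine column would not.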
- class tabicl.forecast.transforms.AutoPeriodicEncoder(config=None)#
Transform that automatically detects and encodes seasonal periods.
Uses FFT-based spectral analysis to identify dominant seasonal periods in the target time series, then generates sin/cosine features for each detected period.
- Parameters:
config (PeriodicDetectionConfig | dict | None, default=None) – Configuration for periodicity detection. Accepts a dataclass instance, a dict of overrides, or None for defaults.
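Accepting a dataclass instance, a dict of overrides, or None is a common configuration pattern. A minimal sketch of how such an argument could be resolved is shown below; `DetectionConfigSketch` and `resolve_config` are hypothetical, only three of the documented fields are reproduced, and the library's own resolution logic may differ.

```python
from dataclasses import dataclass, replace


@dataclass
class DetectionConfigSketch:
    """Hypothetical stand-in holding a subset of the documented fields."""
    max_top_k: int = 5
    do_detrend: bool = True
    detrend_type: str = "linear"


def resolve_config(config=None):
    if config is None:
        return DetectionConfigSketch()  # all defaults
    if isinstance(config, dict):
        # Override only the given fields; unknown keys raise TypeError.
        return replace(DetectionConfigSketch(), **config)
    return config  # already a config instance


defaults = resolve_config()
overridden = resolve_config({"max_top_k": 3})
```

With this pattern, a caller who only wants to change one setting does not need to import or construct the full configuration class.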
- class tabicl.forecast.transforms.PeriodicDetectionConfig(max_top_k=5, do_detrend=True, detrend_type='linear', use_peaks_only=True, apply_hann_window=True, zero_padding_factor=2, round_to_closest_integer=True, validate_with_acf=False, sampling_interval=1.0, magnitude_threshold=0.05, relative_threshold=True, exclude_zero=True)#
Configuration for automatic periodicity detection via FFT.
- Parameters:
max_top_k (int) – Maximum number of dominant periods to detect.
do_detrend (bool) – Whether to remove trend before FFT.
detrend_type ({"first_diff", "loess", "linear", "constant"}) – Detrending method.
use_peaks_only (bool) – Whether to consider only local peaks in the FFT spectrum.
apply_hann_window (bool) – Whether to apply a Hann window to reduce spectral leakage.
zero_padding_factor (int) – Factor by which to zero-pad the signal for finer frequency resolution.
round_to_closest_integer (bool) – Whether to round detected periods to the nearest integer.
validate_with_acf (bool) – Whether to validate detected periods against autocorrelation.
sampling_interval (float) – Time interval between consecutive samples.
magnitude_threshold (float | None) – Threshold to filter out less significant frequency components.
relative_threshold (bool) – Whether magnitude_threshold is relative to the maximum FFT magnitude.
exclude_zero (bool) – Whether to exclude periods of 0 from the results.
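To show how a few of these settings interact, here is a deliberately simplified spectral period detector. It is a sketch under stated assumptions, not the library's algorithm: it uses a naive discrete Fourier transform from the standard library, removes only the mean (roughly the "constant" detrend), always treats the threshold as relative, and omits peak filtering, the Hann window, zero padding, and ACF validation. `detect_periods` is a hypothetical name.

```python
import cmath
import math


def detect_periods(values, max_top_k=5, magnitude_threshold=0.05):
    """Return up to max_top_k dominant periods, strongest first."""
    n = len(values)
    mean = sum(values) / n
    centered = [v - mean for v in values]
    magnitudes = {}
    # Skip k = 0, which corresponds to the mean rather than to a period.
    for k in range(1, n // 2 + 1):
        coeff = sum(
            x * cmath.exp(-2j * math.pi * k * t / n) for t, x in enumerate(centered)
        )
        magnitudes[k] = abs(coeff)
    peak = max(magnitudes.values())
    if peak == 0:
        return []
    # Relative threshold: keep components above a fraction of the strongest one.
    kept = [k for k, m in magnitudes.items() if m >= magnitude_threshold * peak]
    kept.sort(key=lambda k: magnitudes[k], reverse=True)
    # Frequency bin k corresponds to a period of n / k samples.
    return [round(n / k) for k in kept[:max_top_k]]


signal = [math.sin(2 * math.pi * t / 8) for t in range(64)]
mixed = [
    math.sin(2 * math.pi * t / 8) + 0.5 * math.sin(2 * math.pi * t / 16)
    for t in range(64)
]
single_period = detect_periods(signal)
two_periods = detect_periods(mixed)
```

The detected periods could then be passed to a sin/cosine encoder, which matches the two-stage description given for AutoPeriodicEncoder.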
Pre-training data#
- class tabicl.prior.PriorDataset(batch_size=256, batch_size_per_gp=4, batch_size_per_subgp=None, min_features=2, max_features=100, max_classes=10, min_seq_len=None, max_seq_len=1024, log_seq_len=False, seq_len_per_gp=False, min_train_size=0.1, max_train_size=0.9, replay_small=False, prior_type='mlp_scm', scm_fixed_hp=DEFAULT_FIXED_HP, scm_sampled_hp=DEFAULT_SAMPLED_HP, n_jobs=-1, num_threads_per_generate=1, device='cpu')#
Main dataset class that provides an infinite iterator over synthetic tabular datasets.
- Parameters:
batch_size (int, default=256) – Total number of datasets to generate per batch.
batch_size_per_gp (int, default=4) – Number of datasets per group, sharing similar characteristics.
batch_size_per_subgp (int, optional) – Number of datasets per subgroup, with more similar causal structures. If None, defaults to batch_size_per_gp.
min_features (int, default=2) – Minimum number of features per dataset.
max_features (int, default=100) – Maximum number of features per dataset.
max_classes (int, default=10) – Maximum number of target classes.
min_seq_len (int, optional) – Minimum samples per dataset. If None, uses max_seq_len directly.
max_seq_len (int, default=1024) – Maximum samples per dataset.
log_seq_len (bool, default=False) – If True, sample sequence length from a log-uniform distribution.
seq_len_per_gp (bool, default=False) – If True, sample sequence length per group, allowing variable-sized datasets.
min_train_size (int or float, default=0.1) – Position or ratio for train/test split start. If int, absolute position. If float between 0 and 1, specifies a fraction of sequence length.
max_train_size (int or float, default=0.9) – Position or ratio for train/test split end. If int, absolute position. If float between 0 and 1, specifies a fraction of sequence length.
replay_small (bool, default=False) – If True, occasionally sample smaller sequence lengths with specific distributions to ensure model robustness on smaller datasets.
prior_type (str, default="mlp_scm") –
Type of prior: ‘mlp_scm’ (default), ‘tree_scm’, ‘mix_scm’, or ‘dummy’.
SCM-based: Structural causal models with complex feature relationships
’mlp_scm’: MLP-based causal models
’tree_scm’: Tree-based causal models
’mix_scm’: Probabilistic mix of the above models
Dummy: Randomly generated datasets for debugging
scm_fixed_hp (dict, default=DEFAULT_FIXED_HP) – Fixed parameters for SCM-based priors.
scm_sampled_hp (dict, default=DEFAULT_SAMPLED_HP) – Parameters sampled during generation.
n_jobs (int, default=-1) – Number of parallel jobs to run (-1 means using all processors).
num_threads_per_generate (int, default=1) – Number of threads per job for dataset generation.
device (str, default="cpu") – Computation device (‘cpu’ or ‘cuda’).
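The sequence-length and train-size parameters can be read as defining a sampling procedure. The sketch below is one possible reading, not the library's code: it assumes uniform sampling of the split position between the two resolved bounds, and `sample_seq_len`, `resolve_position`, and `sample_train_size` are hypothetical helpers.

```python
import math
import random


def sample_seq_len(min_seq_len, max_seq_len, log=False, rng=random):
    """Pick a sequence length; with min_seq_len=None, use max_seq_len directly."""
    if min_seq_len is None:
        return max_seq_len
    if log:
        # Log-uniform: sample uniformly in log space, then map back.
        value = math.exp(rng.uniform(math.log(min_seq_len), math.log(max_seq_len)))
        return min(max_seq_len, max(min_seq_len, int(round(value))))
    return rng.randint(min_seq_len, max_seq_len)


def resolve_position(size, seq_len):
    """An int is an absolute position; a float is a fraction of seq_len."""
    if isinstance(size, float):
        return int(seq_len * size)
    return size


def sample_train_size(seq_len, min_train_size=0.1, max_train_size=0.9, rng=random):
    low = resolve_position(min_train_size, seq_len)
    high = resolve_position(max_train_size, seq_len)
    # Assumption: the split position is drawn uniformly between the bounds.
    return rng.randint(low, high)


rng = random.Random(0)
seq_len = sample_seq_len(128, 1024, log=True, rng=rng)
train_size = sample_train_size(1000, rng=rng)
```

Log-uniform sampling spends more draws on short sequences than uniform sampling does, which is the usual motivation for an option like log_seq_len.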
- get_batch(batch_size=None)#
Generate a new batch of datasets.
- Parameters:
batch_size (int, optional) – If provided, overrides the default batch size for this call.
- Returns:
X (Tensor or NestedTensor) –
For SCM-based priors:
If seq_len_per_gp=False, shape is (batch_size, seq_len, max_features).
If seq_len_per_gp=True, returns a NestedTensor.
For DummyPrior, random Gaussian values of shape (batch_size, seq_len, max_features).
y (Tensor or NestedTensor) –
For SCM-based priors:
If seq_len_per_gp=False, shape is (batch_size, seq_len).
If seq_len_per_gp=True, returns a NestedTensor.
For DummyPrior, random class labels of shape (batch_size, seq_len).
d (Tensor) – Number of active features per dataset, of shape (batch_size,).
seq_lens (Tensor) – Sequence length for each dataset, of shape (batch_size,).
train_sizes (Tensor) – Position of the train/test split for each dataset, of shape (batch_size,).
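The per-dataset metadata can be used to recover one dataset from a padded batch. The sketch below uses nested lists in place of tensors and assumes that active features occupy the first d columns, with padding after them; that layout is an assumption, and `split_one_dataset` is a hypothetical helper.

```python
def split_one_dataset(X_i, y_i, d_i, seq_len_i, train_size_i):
    """Trim padding from one dataset and split it at train_size_i."""
    # Assumption: the first d_i columns hold the active features.
    rows = [row[:d_i] for row in X_i[:seq_len_i]]
    labels = y_i[:seq_len_i]
    train = (rows[:train_size_i], labels[:train_size_i])
    test = (rows[train_size_i:], labels[train_size_i:])
    return train, test


# One dataset padded to max_features=3, with 2 active features.
X_i = [[0.1, 0.2, 0.0], [0.3, 0.4, 0.0], [0.5, 0.6, 0.0], [0.7, 0.8, 0.0]]
y_i = [0, 1, 0, 1]
(train_X, train_y), (test_X, test_y) = split_one_dataset(
    X_i, y_i, d_i=2, seq_len_i=4, train_size_i=3
)
```

The rows before the split position act as the in-context training examples, and the remaining rows are the ones to be predicted.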
SHAP interpretability#
- tabicl.shap.get_shap_explainer(estimator, X, predict_fn='predict_proba', **kwargs)#
Build a shap.Explainer with an all-NaN background.
- Parameters:
estimator (estimator object) – A fitted estimator.
X (array-like) – Used only to infer n_features.
predict_fn (str or callable, default="predict_proba") – Prediction method; resolved via getattr when a string.
**kwargs – Forwarded to shap.Explainer.
- Return type:
shap.Explainer
- tabicl.shap.get_shap_values(estimator, X_test, attribute_names=None, **kwargs)#
Compute SHAP values for a fitted estimator.
- Parameters:
estimator (estimator object) – A fitted TabICL estimator (classifier or regressor).
X_test (array-like or DataFrame) – Samples to explain.
attribute_names (list of str, optional) – Feature names (inferred from DataFrame columns when possible).
**kwargs – Forwarded to get_shap_explainer().
- Return type:
shap.Explanation
- tabicl.shap.get_shapiq_explainer(estimator, data, *, imputer='nan', index='k-SII', max_order=2, class_index=None, **kwargs)#
Create a shapiq explainer tuned for TabICL.
- Parameters:
estimator (estimator object) – A fitted TabICL estimator (or any sklearn-compatible estimator when using imputer="marginal").
data (array-like) – Background / reference data.
imputer (str or shapiq.Imputer instance, default="nan") –
How absent features are handled when evaluating coalitions:
"nan" (default): uses _NaNImputer so that absent features become NaN, exploiting TabICL’s native missing-feature handling. Deterministic, no sampling noise.
"marginal": standard marginal-sampling imputation.
"baseline": replace absent features with a fixed baseline value (typically the mean of data).
"conditional": conditional-sampling imputation.
Any shapiq.Imputer instance: forwarded directly to shapiq.TabularExplainer.
See the shapiq imputer documentation for details.
index (str, default="k-SII") –
Interaction index to compute. Common choices:
"SV": Shapley values (set max_order=1).
"k-SII": k-Shapley Interaction Index (default).
"SII": Shapley Interaction Index.
"STII": Shapley Taylor Interaction Index.
"FSII": Faithful Shapley Interaction Index.
"FBII": Faithful Banzhaf Interaction Index.
See the shapiq index documentation for the full list.
max_order (int, default=2) – Maximum interaction order.
class_index (int or None, default=None) – For classifiers, which class probability to explain.
**kwargs – Forwarded to shapiq.TabularExplainer.
- Return type:
shapiq.TabularExplainer
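The idea behind the "nan" imputer, where features outside a coalition are set to NaN and the model is queried directly, can be illustrated with an exact Shapley computation on a toy model. This is a from-scratch sketch, not shapiq's or TabICL's implementation: `toy_model` is a hypothetical predictor that simply skips NaN inputs, and the enumeration is only feasible for a handful of features.

```python
import math
from itertools import combinations


def toy_model(x):
    """Hypothetical predictor that tolerates NaN by ignoring missing inputs."""
    a, b, c = x
    total = 0.0
    if not math.isnan(a):
        total += 2.0 * a
    if not math.isnan(b):
        total += b
    if not math.isnan(a) and not math.isnan(c):
        total += a * c  # interaction term between features 0 and 2
    return total


def shapley_values_nan(model, x):
    """Exact Shapley values where absent features are replaced by NaN."""
    n = len(x)

    def value(coalition):
        masked = [x[i] if i in coalition else math.nan for i in range(n)]
        return model(masked)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = (
                math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            )
            for subset in combinations(others, size):
                members = set(subset)
                phi[i] += weight * (value(members | {i}) - value(members))
    return phi


x = [1.0, 2.0, 3.0]
phi = shapley_values_nan(toy_model, x)
```

No background sampling is involved, so repeated runs give identical attributions, which is the determinism noted for the "nan" option. The attributions sum to the difference between the prediction for x and the prediction for an all-NaN input.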
- tabicl.shap.plot_shap(shap_values, kind='bar')#
Plot SHAP explanations.
- Parameters:
shap_values (shap.Explanation) – Typically returned by get_shap_values().
kind (str or tuple of str, default="bar") – Which plots to show. Any combination of "bar", "beeswarm", and "scatter".
- tabicl.shap.plot_shap_feature(shap_values, feature, n_plots=1)#
Scatter plot of a single feature coloured by its top interactions.