Feature Attribution (SHAP)#
Source Files
twiga/core/explain/shap_explainer.py
twiga/core/explain/__init__.py
Twiga provides SHAP-based feature attribution through ShapExplainer, which
wraps any fitted ML model and returns attributions shaped to the interpretable
(B, L, F) form: one value per sample, per lookback timestep, per feature.
Why SHAP?#
Standard feature importance (split-gain, permutation) assigns a single scalar
per input column. For time series, the input is a window of L past
observations, so there are actually L × F distinct inputs. SHAP attribution
lets you answer:
Which features drive predictions most (averaged over time)?
How far back does the model actually look?
Which recent timesteps have the highest influence?
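To see why a single scalar per column is not enough, note that a windowed input flattens into L × F distinct columns before it reaches a tabular estimator. A minimal NumPy sketch (shapes are illustrative):

```python
import numpy as np

B, L, F = 4, 168, 6          # batch, lookback window, features
X = np.random.rand(B, L, F)  # windowed input: one lookback window per sample

# What a tabular (sklearn-style) model actually sees: L * F flat columns.
X_flat = X.reshape(B, L * F)
print(X_flat.shape)          # (4, 1008): 1008 distinct inputs, not 6
```

A per-column importance over those 1008 columns cannot by itself tell you "which feature" apart from "which timestep"; the (B, L, F) attribution keeps both axes separate.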
Quick Start#
from twiga.core.explain import ShapExplainer
# After fitting a TwigaForecaster:
result = forecaster.explain(X_test) # convenience wrapper
# or directly:
explainer = ShapExplainer(forecaster.models[0], forecaster.data_pipeline)
result = explainer.explain(X_test) # X_test: (B, L, F)
X_test is the 3-D feature array produced by the data pipeline (same shape as
the array passed to the model’s fit() / predict() methods).
ShapResult#
explain() returns a ShapResult dataclass:
| Attribute | Shape | Description |
|---|---|---|
| values | (B, L, F) | SHAP values averaged across horizon steps |
|  |  | Original feature names from the data pipeline |
|  |  | Timestep labels for the lookback window |
|  |  | Flat feature names used internally |
| base_value | scalar | SHAP base value (mean prediction) |
print(result.values.shape) # (B, L, F)
importance = result.mean_importance() # dict: feature → mean |SHAP|
result.plot_importance(top_n=20) # bar chart
result.plot_timestep_importance() # line chart over lookback window
mean_importance()#
Returns a dict mapping each feature name to its mean |SHAP|, averaged over
all batch samples and lookback timesteps. Use this to rank features.
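The aggregation behind this can be sketched in plain NumPy (feature names here are illustrative, not Twiga's):

```python
import numpy as np

values = np.random.randn(32, 168, 3)      # (B, L, F) SHAP values
feature_names = ["load", "temp", "hour"]  # illustrative names

# Mean |SHAP| per feature: average over samples (axis 0) and timesteps (axis 1).
importance = dict(zip(feature_names, np.abs(values).mean(axis=(0, 1))))
ranked = sorted(importance, key=importance.get, reverse=True)
```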
timestep_importance()#
Returns a dict mapping each timestep label to its mean |SHAP|, averaged over
all batch samples and features. Use this to understand how far back in the
lookback window the model actually looks when predicting.
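The same idea along the time axis, as a sketch (label format assumed from the plot_timestep_importance description below: t-(L-1) oldest through t-0 most recent):

```python
import numpy as np

values = np.random.randn(32, 168, 3)           # (B, L, F) SHAP values
L = values.shape[1]
labels = [f"t-{L - 1 - i}" for i in range(L)]  # "t-167" (oldest) ... "t-0"

# Mean |SHAP| per timestep: average over samples (axis 0) and features (axis 2).
per_step = dict(zip(labels, np.abs(values).mean(axis=(0, 2))))
```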
plot_importance(top_n=20)#
Horizontal bar chart of the top-N features by mean |SHAP|.
plot_timestep_importance()#
Line chart showing mean |SHAP| at each lookback position from t-(L-1) (oldest)
to t0 (most recent).
ShapExplainer#
ShapExplainer(model, data_pipeline) accepts a fitted BaseRegressor subclass
and the fitted DataPipeline. Internally it:
Detects the model storage pattern (per-output list, single, or wrapped).
Dispatches to the right SHAP backend.
Reshapes the flat (B, L*F) SHAP output back to (B, L, F).
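The reshape step is a lossless round-trip, assuming the pipeline flattens windows in the usual timestep-major order (the order produced by a plain NumPy reshape):

```python
import numpy as np

B, L, F = 8, 24, 5
flat_shap = np.random.randn(B, L * F)  # what SHAP returns for a tabular model

# Undo the flattening: column k maps to (timestep k // F, feature k % F).
shap_3d = flat_shap.reshape(B, L, F)
assert np.array_equal(shap_3d.reshape(B, L * F), flat_shap)  # lossless
```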
Explainer dispatch#
| Estimator type | SHAP backend | Speed |
|---|---|---|
| LightGBM / XGBoost / CatBoost / RandomForest | TreeExplainer | Fast (exact) |
| LinearRegression / Ridge / Lasso | LinearExplainer | Fast (exact) |
| Other | KernelExplainer | Slow (model-agnostic) |
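The dispatch can be pictured as a type check on the estimator's class. This is only an illustrative sketch: the real detection in shap_explainer.py may inspect more than the class name, and the stand-in Ridge class below is a placeholder for sklearn's:

```python
TREE_MODELS = {"LGBMRegressor", "XGBRegressor", "CatBoostRegressor",
               "RandomForestRegressor"}
LINEAR_MODELS = {"LinearRegression", "Ridge", "Lasso"}

def pick_backend(estimator) -> str:
    """Choose a SHAP explainer class by estimator type (illustrative)."""
    name = type(estimator).__name__
    if name in TREE_MODELS:
        return "TreeExplainer"
    if name in LINEAR_MODELS:
        return "LinearExplainer"
    return "KernelExplainer"  # model-agnostic fallback

class Ridge:  # stand-in for sklearn.linear_model.Ridge
    pass

print(pick_backend(Ridge()))  # LinearExplainer
```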
Model storage patterns#
| Twiga model | Pattern | Internal detail |
|---|---|---|
|  | Per-output list | One estimator per horizon step; SHAP averaged |
|  | Single |  |
|  | Wrapped | Iterates over the wrapped estimators |
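For the per-output-list pattern, averaging the per-horizon SHAP arrays into a single attribution can be sketched as (shapes illustrative):

```python
import numpy as np

B, L, F, H = 16, 24, 4, 6  # batch, lookback, features, horizon steps

# One SHAP array per horizon-step estimator (the per-output-list pattern).
per_step_shap = [np.random.randn(B, L, F) for _ in range(H)]

# Average attributions across the H estimators into one (B, L, F) array.
avg_shap = np.mean(per_step_shap, axis=0)
```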
TwigaForecaster.explain()#
The explain() method on TwigaForecaster is a thin convenience wrapper:
result = forecaster.explain(
X, # (B, L, F) feature array
model_idx=0, # which model in forecaster.models to explain
    n_background=100,  # background set size for Linear/KernelExplainer
)
Installation#
SHAP is an optional dependency. Install it separately:
pip install shap
Example#
import numpy as np
from sklearn.preprocessing import StandardScaler
from twiga.core.config import DataPipelineConfig, ForecasterConfig
from twiga.forecaster.core import TwigaForecaster
from twiga.models.ml.lightgbm_model import LIGHTGBMConfig
# --- Fit ---
data_config = DataPipelineConfig(
target_feature='load',
period='1h',
lookback_window_size=168,
forecast_horizon=24,
lags=[1, 24, 168],
input_scaler=StandardScaler(),
)
train_config = ForecasterConfig(
split_freq='months', train_size=12, test_size=1,
window='expanding', project_name='shap_demo',
)
forecaster = TwigaForecaster(data_config, LIGHTGBMConfig(), train_config)
forecaster.fit(train_df)
# --- SHAP ---
X_test, _ = forecaster.data_pipeline.transform(test_df) # (B, L, F)
result = forecaster.explain(X_test, model_idx=0)
print("Top features:")
for feat, val in list(result.mean_importance().items())[:5]:
print(f" {feat}: {val:.4f}")
result.plot_importance(top_n=20)
result.plot_timestep_importance()
API Reference#
- class twiga.core.explain.ShapExplainer(model, data_pipeline)#
Bases: object
Compute SHAP attributions for any fitted Twiga ML model.
Dispatches to the most efficient SHAP backend based on the underlying estimator type, then reshapes values from the flat sklearn representation (n_samples, L*F) back to the interpretable (n_samples, L, F) form.
- Parameters:
model – A fitted BaseRegressor subclass.
data_pipeline – The fitted DataPipeline.
- Raises:
ImportError – If shap is not installed.
RuntimeError – If the model structure cannot be detected (not fitted).
Example
>>> explainer = ShapExplainer(forecaster.models[0], forecaster.data_pipeline)
>>> result = explainer.explain(X_test)  # X_test: (n_samples, L, F)
>>> importance = explainer.feature_importance(X_test)
>>> result.plot_importance(top_n=15)
- explain(X, n_background=100)#
Compute SHAP values for input samples.
- Parameters:
X (ndarray) – Input array of shape (n_samples, L, F), same format as model input. Usually obtained from DataPipeline.transform() or directly from TwigaForecaster’s internal data preparation.
n_background (int) – Maximum number of background samples to use for LinearExplainer and KernelExplainer. TreeExplainer ignores this parameter (it uses the tree structure directly).
- Return type:
ShapResult
- Returns:
ShapResult with values of shape (n_samples, L, F).
- Raises:
ValueError – If X is not 3-dimensional.
RuntimeError – If the model structure cannot be determined.
- feature_importance(X, n_background=100)#
Return mean absolute SHAP importance per feature.
Averages |SHAP| over all batch samples and lookback timesteps.
Limitations#
Neural networks (domain="nn") are not supported by ShapExplainer. Gradient-based attribution for NN models (GradientExplainer) is planned as a future extension.
NGBoost models are not yet supported, as they use a non-standard internal structure (NGBRegressor wrapping DecisionTreeRegressors).
SHAP computation can be slow for large B or large L*F. Use n_background to limit the background dataset size, and pass a random sub-sample of X if needed.
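Sub-sampling the input before explaining is plain NumPy indexing; a sketch (sizes illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((5000, 168, 10))  # large (B, L, F) batch

# Explain a random sub-sample of 200 windows instead of all 5000.
idx = rng.choice(len(X), size=200, replace=False)
X_small = X[idx]  # pass this smaller array to explain() instead of X
```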