SHAP Feature Attribution#

Advanced Python Twiga Time


What you’ll build

SHAP feature attributions for a LightGBM forecaster trained on MLVS-PT net load data:

  • a feature importance bar chart ranking which signals (net load history, solar irradiance, temperature) drive predictions most

  • a timestep attribution profile showing how far back the model looks

  • a (L, F) heatmap of mean |SHAP| per feature per lookback position

Prerequisites

  • 01 - Getting Started

  • 05 - ML Point Forecasting

Learning objectives

By the end of this notebook you will be able to:

  1. Explain SHAP values for time series - flat (B, L×F) inputs reshaped to (B, L, F), one attribution per lag per feature

  2. Use forecaster.explain() with TreeExplainer to compute SHAP values efficiently for LightGBM

  3. Rank features by mean |SHAP| averaged across samples and lookback timesteps

  4. Visualise which lookback timesteps carry the most predictive signal using the timestep attribution profile

  5. Build a feature × timestep heatmap to identify the most informative lag of each feature

Key concept - SHAP for time series

Standard feature importance gives one number per input column. For time series forecasting, the model’s input is a window of L past timesteps across F features - a flat vector of L × F values. SHAP decomposes the prediction into one attribution per input element.

Twiga reshapes the flat SHAP array (B, L*F) back to (B, L, F) - giving you one attribution per sample, per lookback step, per feature. This means you can ask:

  • Which features matter most overall? → average |SHAP| over B and L

  • Which timesteps matter most? → average |SHAP| over B and F

  • Which feature at which lag matters most? → the full (L, F) heatmap
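The three aggregations above can be sketched with plain NumPy on a synthetic SHAP array. This is an illustration only, not a Twiga API call; the shapes mirror this tutorial's config:

```python
import numpy as np

B, L, F = 4, 96, 3  # samples, lookback steps, features (illustrative sizes)
rng = np.random.default_rng(0)

# A flat SHAP array as an explainer returns it: one value per input element
shap_flat = rng.normal(size=(B, L * F))

# Reshape back to (B, L, F): one attribution per sample, per lag, per feature
shap_vals = shap_flat.reshape(B, L, F)

feature_importance = np.mean(np.abs(shap_vals), axis=(0, 1))   # (F,)  over B and L
timestep_importance = np.mean(np.abs(shap_vals), axis=(0, 2))  # (L,)  over B and F
heatmap = np.mean(np.abs(shap_vals), axis=0)                   # (L, F) over B

print(feature_importance.shape, timestep_importance.shape, heatmap.shape)
```

Averaging over different axis pairs of the same (B, L, F) array is all that separates the three views.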

historical_features (like Ghi and Temperature) are treated as features with unknown future values - only their past L timesteps are fed to the model, keeping the input shape a clean (B, L, F) with no future-covariate inflation.

Key concept - TreeExplainer vs. KernelExplainer

SHAP provides different explainers optimised for different model families:

  • TreeExplainer - exact, fast SHAP values for tree-based models (LightGBM, XGBoost, CatBoost, RandomForest). Exploits the tree structure directly. No background samples needed.

  • LinearExplainer - exact SHAP for linear models. Also fast.

  • KernelExplainer - model-agnostic, works on any black-box model. Much slower.

Twiga dispatches automatically: LightGBM/XGBoost/CatBoost/RandomForest → TreeExplainer, LinearRegression → LinearExplainer, anything else → KernelExplainer.
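As an illustration only (this is not Twiga's actual dispatch code), the selection logic amounts to a lookup on the model's class:

```python
# Hypothetical sketch of explainer dispatch by model family - not Twiga's
# real implementation, just the shape of the selection logic it describes.
TREE_MODELS = ("LGBMRegressor", "XGBRegressor", "CatBoostRegressor", "RandomForestRegressor")
LINEAR_MODELS = ("LinearRegression",)

def pick_explainer(model) -> str:
    name = type(model).__name__
    if name in TREE_MODELS:
        return "TreeExplainer"    # exact and fast, exploits tree structure
    if name in LINEAR_MODELS:
        return "LinearExplainer"  # exact for linear models
    return "KernelExplainer"      # model-agnostic fallback, much slower

class LGBMRegressor: ...          # stand-in classes so the sketch runs without lightgbm
class SomeNeuralNet: ...

print(pick_explainer(LGBMRegressor()))  # TreeExplainer
print(pick_explainer(SomeNeuralNet()))  # KernelExplainer
```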

Key concept - mean |SHAP| importance

The standard way to rank features by overall importance is to take the mean absolute SHAP value across all samples and all lookback timesteps:

\[\text{importance}(f) = \frac{1}{B \cdot L} \sum_{b=1}^{B} \sum_{l=1}^{L} |\phi_{b,l,f}|\]

We use absolute values because SHAP values can be positive or negative - averaging the raw signed values would let opposite-signed attributions cancel and understate a feature's importance.
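A quick NumPy check makes the cancellation point concrete:

```python
import numpy as np

# Attributions for one feature across four samples: strong but opposite-signed
phi = np.array([0.8, -0.8, 0.5, -0.5])

signed_mean = phi.mean()       # 0.0: cancels out, feature looks useless
abs_mean = np.abs(phi).mean()  # 0.65: reflects how much it actually moves predictions

print(signed_mean, abs_mean)
```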

1. Setup#

import warnings

warnings.filterwarnings("ignore")

from great_tables import GT, md
from IPython.display import clear_output
from lets_plot import LetsPlot, aes, geom_bar, geom_line, geom_tile, ggplot, ggsize, labs, scale_fill_gradient2
import numpy as np
import pandas as pd
from sklearn.preprocessing import RobustScaler, StandardScaler

LetsPlot.setup_html()

from twiga import TwigaForecaster
from twiga.core.config import DataPipelineConfig, ForecasterConfig
from twiga.core.explain import ShapExplainer
from twiga.core.plot import plot_metrics_bar, plot_timeseries
from twiga.core.plot.gt import twiga_gt
from twiga.core.plot.theme import TWIGA_PALETTE, twiga_theme
from twiga.core.utils import configure, get_logger
from twiga.models.ml import LIGHTGBMConfig

configure()
log = get_logger("tutorials")
log.info("Imports OK")
2026-05-03 16:46:06 | INFO     | twiga.tutorials | Imports OK

2. Load Data#

data = pd.read_parquet("../data/MLVS-PT.parquet")
data = data[["timestamp", "NetLoad(kW)", "Ghi", "Temperature"]]
data["timestamp"] = pd.to_datetime(data["timestamp"])
data = data.drop_duplicates(subset="timestamp").reset_index(drop=True)
# Restrict to 2019-2020 to keep tutorial execution fast
data = data[(data["timestamp"] >= "2019-01-01") & (data["timestamp"] <= "2020-12-31")].reset_index(drop=True)

log.info("Shape: %s", data.shape)

twiga_gt(
    GT(data.head())
    .tab_header(title=md("**Raw Data Sample**"), subtitle="First 5 rows of MLVS-PT")
    .cols_label(
        timestamp=md("**Timestamp**"),
        **{
            "NetLoad(kW)": md("**NetLoad (kW)**"),
            "Ghi": md("**Ghi (W/m^2)**"),
            "Temperature": md("**Temperature (C)**"),
        },
    )
    .tab_source_note("MLVS-PT dataset · Madeira, Portugal · 30-min resolution"),
    n_rows=5,
)

3. Chronological Splits#

We use the same fixed temporal splits as all other tutorials - never shuffle time series data.

train_df = data[data["timestamp"] < "2020-01-01"].reset_index(drop=True)
val_df = data[(data["timestamp"] >= "2020-01-01") & (data["timestamp"] < "2020-07-01")].reset_index(drop=True)
test_df = data[data["timestamp"] >= "2020-07-01"].reset_index(drop=True)

log.info(
    "train : %d rows  (%s to %s)", len(train_df), train_df["timestamp"].min().date(), train_df["timestamp"].max().date()
)
log.info("val   : %d rows  (%s to %s)", len(val_df), val_df["timestamp"].min().date(), val_df["timestamp"].max().date())
log.info(
    "test  : %d rows  (%s to %s)", len(test_df), test_df["timestamp"].min().date(), test_df["timestamp"].max().date()
)

4. Configure the Forecaster#

We use historical_features for Ghi and Temperature rather than exogenous_features. The distinction matters for SHAP:

| Config key | In feature_columns? | In covariate_columns? | Model receives future values? |
|---|---|---|---|
| historical_features | Yes | No | No - only past L steps |
| exogenous_features | Yes | Yes | Yes - past + future H steps |

Using historical_features keeps the model input a clean (B, L, F) array that ShapExplainer can reshape correctly. Adding future covariate steps would inflate the time dimension to L+H and break the attribution reshape.
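A quick shape check shows why appending future covariate steps breaks the attribution reshape. The sizes mirror this tutorial's config; the two-covariate count is just for illustration:

```python
import numpy as np

L, H, F = 96, 48, 3  # lookback, horizon, feature count (as configured below)
n_exog = 2           # hypothetical covariates with known future values

flat_hist = np.zeros(L * F)               # historical_features: past steps only
flat_exog = np.zeros(L * F + H * n_exog)  # plus H future steps per covariate

flat_hist.reshape(L, F)      # OK: 288 values == 96 * 3
try:
    flat_exog.reshape(L, F)  # 384 values cannot form a (96, 3) array
except ValueError as e:
    print("reshape fails:", e)
```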

data_config = DataPipelineConfig(
    target_feature="NetLoad(kW)",
    period="30min",
    lookback_window_size=96,  # 48 h of history at 30-min resolution
    forecast_horizon=48,  # predict next 24 h
    historical_features=["Ghi", "Temperature"],  # past values only — no future leakage
    input_scaler=StandardScaler(),
    target_scaler=RobustScaler(),
)

train_config = ForecasterConfig(
    split_freq="months",
    train_size=3,
    test_size=1,
)

lgb_config = LIGHTGBMConfig()

log.info("target_feature      : %s", data_config.target_feature)
log.info("historical_features : %s", data_config.historical_features)
log.info("F (num features)    : 3  (NetLoad(kW), Ghi, Temperature)")
log.info("L (lookback steps)  : %d", data_config.lookback_window_size)
log.info("H (horizon steps)   : %d", data_config.forecast_horizon)

5. Fit LightGBM Forecaster#

forecaster = TwigaForecaster(
    data_params=data_config,
    model_params=[lgb_config],
    train_params=train_config,
)
forecaster.fit(train_df=train_df, val_df=val_df)
clear_output()
log.info("Model fitted.")

6. Prepare Test Features#

DataPipeline.transform_features() returns the 3-D array (B, L, F) the model sees at predict time:

  • B: number of sliding windows in the test set

  • L: lookback_window_size (96 steps = 48 h)

  • F: number of features: NetLoad(kW), Ghi, Temperature

Because historical_features are past-only (not in covariate_columns), no future covariate steps are appended - the shape stays a clean (B, L, F).

X_test = forecaster.data_pipeline.transform_features(test_df)
log.info("X_test shape : %s  (B=%d, L=%d, F=%d)", X_test.shape, X_test.shape[0], X_test.shape[1], X_test.shape[2])
log.info("feature_columns : %s", forecaster.data_pipeline.feature_columns)

7. Compute SHAP Values#

forecaster.explain() builds a ShapExplainer, dispatches to shap.TreeExplainer for LightGBM (fast, exact), and returns a ShapResult with SHAP values of shape (B, L, F) averaged across all 48 horizon steps.

We subsample to 300 windows for speed - TreeExplainer is still fast for LightGBM.

n_explain = min(300, len(X_test))
X_explain = X_test[:n_explain]

result = forecaster.explain(X_explain, model_idx=0)

log.info("SHAP values shape    : %s  (B, L, F)", result.values.shape)
log.info("Base value           : %.4f", result.expected_value)
log.info("Feature names        : %s", result.feature_names)

8. Feature Importance Bar Chart#

result.mean_importance() returns features ranked by mean |SHAP|, averaged across all lookback timesteps and all samples.

Reading the chart - a taller bar means that feature pushes the prediction further from the baseline on average. The target’s own history (NetLoad(kW)) typically dominates; solar irradiance (Ghi) shows a strong effect because net load = demand minus solar output.

importance = result.mean_importance()
for i, (feat, val) in enumerate(importance.items(), 1):
    log.info("%d. %-25s  %.5f", i, feat, val)
imp_df = pd.DataFrame({"Feature": list(importance.keys()), "Mean |SHAP|": list(importance.values())})

p_imp = plot_metrics_bar(
    imp_df,
    metric_col="Mean |SHAP|",
    model_col="Feature",
    lower_is_better=False,
    title="LightGBM — SHAP Feature Importance (MLVS-PT)",
    x_label="Mean |SHAP|",
    fig_size=(600, 280),
)
p_imp

9. Timestep Attribution Profile#

How far back does the model look? result.timestep_importance() gives mean |SHAP| per lookback position. We plot it as a line chart using plot_timeseries.

Reading the chart - a spike near t0 means the model relies on the most recent observations. Peaks at multiples of 48 reveal daily seasonal patterns (one day = 48 steps at 30-min resolution).

ts_importance = result.timestep_importance()
# Build a positional DataFrame so Lets-Plot can render it as a line
ts_df = pd.DataFrame(
    {
        "step": range(len(result.timestep_labels)),
        "Mean |SHAP|": np.mean(np.abs(result.values), axis=(0, 2)),
    }
)

p_ts = (
    ggplot(ts_df, aes(x="step", y="Mean |SHAP|"))
    + geom_line(color=TWIGA_PALETTE[0], size=1.3)
    + labs(
        title="LightGBM — SHAP Importance by Lookback Position (MLVS-PT)",
        x="Lookback step (0 = oldest, 95 = most recent)",
        y="Mean |SHAP|",
        caption="Twiga Forecast",
    )
    + twiga_theme(grid=True)
    + ggsize(820, 300)
)
p_ts

10. Per-Feature × Per-Timestep Heatmap#

The raw result.values array has shape (B, L, F). Taking the mean over the batch dimension gives a (L, F) attribution matrix. We reshape it to long form so geom_tile can render it as a Twiga-styled heatmap.

Reading the chart - bright cells are high-attribution (feature, lag) pairs. Dark columns indicate features the model largely ignores at those timesteps.

mean_abs = np.mean(np.abs(result.values), axis=0)  # (L, F)
L, F = mean_abs.shape
feature_names = result.feature_names

rows = []
for l_idx in range(L):
    for f_idx in range(F):
        rows.append({"step": l_idx, "feature": feature_names[f_idx], "shap": mean_abs[l_idx, f_idx]})
heatmap_df = pd.DataFrame(rows)

p_heat = (
    ggplot(heatmap_df, aes(x="step", y="feature", fill="shap"))
    + geom_tile()
    + scale_fill_gradient2(low="#e8f5f8", mid="#069fac", high="#107591", midpoint=float(mean_abs.mean()))
    + labs(
        title="Mean |SHAP| per Feature x Lookback Position",
        x="Lookback step (0 = oldest, 95 = most recent)",
        y="Feature",
        fill="Mean |SHAP|",
        caption="Twiga Forecast",
    )
    + twiga_theme()
    + ggsize(820, 260)
)
p_heat

11. Using ShapExplainer Directly#

forecaster.explain() is a thin wrapper around ShapExplainer. You can also instantiate it directly for more control over background samples and model dispatch.

explainer = ShapExplainer(
    model=forecaster.models[0],
    data_pipeline=forecaster.data_pipeline,
)

result2 = explainer.explain(X_explain, n_background=100)
log.info("ShapResult values shape : %s  (B, L, F)", result2.values.shape)

importance2 = explainer.feature_importance(X_explain)
log.info("Top feature : %s", next(iter(importance2)))

12. API Summary#

api_df = pd.DataFrame(
    {
        "Object / Method": [
            "DataPipelineConfig(historical_features=[...])",
            "DataPipeline.transform_features(test_df)",
            "forecaster.explain(X, model_idx=0)",
            "ShapResult.mean_importance()",
            "ShapResult.timestep_importance()",
            "ShapResult.values",
            "ShapExplainer(model, data_pipeline)",
        ],
        "What it does": [
            "Declares features as past-only — keeps model input shape (B, L, F) with no future inflation",
            "Returns the (B, L, F) feature array the model sees at predict time",
            "Builds ShapExplainer, runs TreeExplainer for LightGBM, returns ShapResult averaged over horizon",
            "Returns features ranked by mean |SHAP| averaged over B and L",
            "Returns timestep labels ranked by mean |SHAP| averaged over B and F",
            "Raw SHAP array of shape (B, L, F) — use for custom heatmaps or aggregations",
            "Low-level explainer with full control over n_background and model dispatch",
        ],
    }
)

twiga_gt(
    GT(api_df)
    .tab_header(
        title=md("**Tutorial 16 — API Quick Reference**"),
        subtitle="SHAP feature attribution objects and methods",
    )
    .cols_label(**{c: md(f"**{c}**") for c in api_df.columns})
    .tab_source_note("Twiga Forecast · twiga.core.explain"),
    n_rows=len(api_df),
)

Wrapping up#

What you did

  • Trained a LightGBM forecaster on MLVS-PT with historical_features=["Ghi", "Temperature"]

  • Extracted the 3-D test features (B, L, F) from DataPipeline.transform_features()

  • Computed SHAP values with forecaster.explain() using TreeExplainer (fast, exact)

  • Ranked features by mean |SHAP| and plotted the importance bar chart with plot_metrics_bar

  • Visualised the timestep attribution profile as a Lets-Plot line chart

  • Built a (L, F) heatmap using geom_tile with the Twiga teal colour scale

  • Used ShapExplainer directly for fine-grained control

Key takeaways

  1. Use historical_features (not exogenous_features) for SHAP - this keeps the model input a clean (B, L, F) array without future covariate steps appended.

  2. SHAP values live in (B, L, F) space - one attribution per sample, lookback step, and feature.

  3. TreeExplainer is exact and fast for tree models - always prefer it over KernelExplainer when available.

  4. Mean |SHAP| uses absolute values to prevent positive and negative attributions from cancelling out.

  5. The timestep profile reveals how far back the model actually looks - many models ignore the oldest lags despite having them in the input.


What’s next?#

# ruff: noqa: E501, E701, E702
from IPython.display import HTML

_TEAL = "#107591"
_TEAL_MID = "#069fac"
_TEAL_LIGHT = "#e8f5f8"
_TEAL_BEST = "#d0ecf1"
_TEXT_DARK = "#2d3748"
_TEXT_MUTED = "#718096"
_WHITE = "#ffffff"

steps = [
    {
        "num": "05",
        "title": "ML Point Forecasting",
        "desc": "CatBoost · XGBoost · LightGBM — point forecasts",
        "tags": ["catboost", "xgboost"],
        "active": False,
    },
    {
        "num": "11",
        "title": "Hyperparameter Tuning",
        "desc": "Optuna TPE · search spaces · resumable SQLite",
        "tags": ["optuna", "HPO"],
        "active": False,
    },
    {
        "num": "12",
        "title": "Ensemble Strategies",
        "desc": "Mean · median · weighted-mean ensembles",
        "tags": ["ensemble", "weighted"],
        "active": False,
    },
    {
        "num": "14",
        "title": "Typed Forecast Results",
        "desc": "ForecastResult · ForecastCollection · typed dispatch",
        "tags": ["typed", "ForecastResult"],
        "active": False,
    },
    {
        "num": "16",
        "title": "SHAP Feature Attribution",
        "desc": "TreeExplainer · feature ranking · timestep attribution heatmap",
        "tags": ["SHAP", "explainability", "attribution"],
        "active": True,
    },
]
track_name = "Advanced Track"
footer = "You have reached the end of the Advanced Track. Consider contributing a custom model (13) or exploring the Twiga API reference."


def _b(t, bg, fg):
    return f'<span style="display:inline-block;background:{bg};color:{fg};font-size:10px;font-weight:600;padding:2px 7px;border-radius:10px;margin:2px 2px 0 0;">{t}</span>'


ch = ""
for i, s in enumerate(steps):
    a = s["active"]
    cb = _TEAL if a else _WHITE
    cbo = _TEAL if a else "#d1ecf1"
    nb = _TEAL_MID if a else _TEAL_LIGHT
    nf = _WHITE if a else _TEAL
    tf = _WHITE if a else _TEXT_DARK
    df = "#cce8ef" if a else _TEXT_MUTED
    bb = "#0d5f75" if a else _TEAL_BEST
    bf = "#b8e4ed" if a else _TEAL
    yh = (
        f'<span style="float:right;background:{_TEAL_MID};color:{_WHITE};font-size:10px;font-weight:700;padding:2px 10px;border-radius:12px;">\u2605 you are here</span>'
        if a
        else ""
    )
    bdg = "".join(_b(t, bb, bf) for t in s["tags"])
    shadow = "0 4px 14px rgba(16,117,145,.25)" if a else "0 1px 4px rgba(0,0,0,.06)"
    ch += f'<div style="background:{cb};border:2px solid {cbo};border-radius:12px;padding:16px 20px;display:flex;align-items:flex-start;gap:16px;box-shadow:{shadow};"><div style="min-width:44px;height:44px;background:{nb};color:{nf};border-radius:50%;display:flex;align-items:center;justify-content:center;font-size:15px;font-weight:800;flex-shrink:0;">{s["num"]}</div><div style="flex:1;"><div style="font-size:15px;font-weight:700;color:{tf};margin-bottom:4px;">{s["title"]}{yh}</div><div style="font-size:12.5px;color:{df};margin-bottom:8px;line-height:1.5;">{s["desc"]}</div><div>{bdg}</div></div></div>'
    if i < len(steps) - 1:
        ch += f'<div style="display:flex;justify-content:center;height:32px;"><svg width="24" height="32" viewBox="0 0 24 32" fill="none"><line x1="12" y1="0" x2="12" y2="24" stroke="{_TEAL_MID}" stroke-width="2" stroke-dasharray="4 3"/><polygon points="6,20 18,20 12,30" fill="{_TEAL_MID}"/></svg></div>'

HTML(
    f'<div style="font-family:Inter,\'Segoe UI\',sans-serif;max-width:640px;margin:8px 0;"><div style="background:linear-gradient(135deg,{_TEAL} 0%,{_TEAL_MID} 100%);border-radius:12px 12px 0 0;padding:14px 20px;display:flex;align-items:center;gap:10px;"><svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="{_WHITE}" stroke-width="2"><path d="M12 2L2 7l10 5 10-5-10-5z"/><path d="M2 17l10 5 10-5"/><path d="M2 12l10 5 10-5"/></svg><span style="color:{_WHITE};font-size:14px;font-weight:700;">Twiga Learning Path — {track_name}</span></div><div style="border:2px solid {_TEAL_LIGHT};border-top:none;border-radius:0 0 12px 12px;padding:20px 20px 16px;background:#f9fdfe;display:flex;flex-direction:column;">{ch}<div style="margin-top:16px;font-size:11.5px;color:{_TEXT_MUTED};text-align:center;border-top:1px solid {_TEAL_LIGHT};padding-top:12px;">{footer}</div></div></div>'
)