Causal Explanations with LEWIS
This notebook implements LEWIS, a causal explanation framework for black-box machine learning models. Using probabilistic contrastive counterfactuals, we compute Necessity (Nec), Sufficiency (Suf), and Necessity-and-Sufficiency (NeSuf) scores to quantify causal feature importance without relying on model internals. This notebook follows the methodology proposed by Galhotra et al. (SIGMOD 2021).
Reference:
Sainyam Galhotra, Romila Pradhan, Babak Salimi (2021). Explaining Black-Box Algorithms Using Probabilistic Contrastive Counterfactuals. SIGMOD ’21: International Conference on Management of Data, Virtual Event, China, June 20-25, 2021.
import numpy as np
import pandas as pd
import operator
import graphviz
import matplotlib.pyplot as plt
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import lingam
from lingam.utils import make_dot
print([np.__version__, pd.__version__, graphviz.__version__, lingam.__version__])
np.set_printoptions(precision=3, suppress=True)
np.random.seed(0)
['1.26.4', '2.3.0', '0.20.3', '1.12.1']
Test data
We create test data consisting of 7 variables.
m = np.array([
[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[ 0.4, 0.0,-0.6, 0.0, 0.0, 0.0, 0.0],
[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0],
[ 0.0, 0.2, 0.0, 0.0, 0.0,-0.7, 0.0],
[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[ 0.0, 0.5, 0.0,-0.1,-0.5, 0.0, 0.0],
])
generate_error = lambda p: np.random.uniform(-p, p, size=1000)
error_vars = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
params = [0.5 * np.sqrt(12 * v) for v in error_vars]
e = np.array([generate_error(p) for p in params])
cols = [f"x{i}" for i in range(len(m) - 1)] + ["y"]
X = np.linalg.pinv(np.eye(len(m)) - m) @ e
df = pd.DataFrame(X.T, columns=cols)
display(make_dot(m, labels=cols))
df.head()
| x0 | x1 | x2 | x3 | x4 | x5 | y | |
|---|---|---|---|---|---|---|---|
| 0 | 0.119568 | -0.182500 | 0.763061 | -0.373908 | -0.315997 | -0.326320 | -0.202553 |
| 1 | 0.527104 | -0.954103 | -0.058582 | 0.069640 | 0.319117 | -0.495717 | -0.818257 |
| 2 | 0.251718 | 0.007441 | 0.056720 | 0.154034 | -0.146964 | -1.056711 | 0.510195 |
| 3 | 0.109941 | 0.922016 | -0.611097 | 0.680521 | 0.069137 | -0.361232 | 0.605658 |
| 4 | -0.187007 | -1.346211 | 0.257302 | 0.447058 | -0.904446 | -0.655983 | -1.317686 |
Discretization
LEWIS assumes that all features have finite, discrete domains. Therefore, continuous-valued features must be discretized before computing LEWIS scores. Each continuous feature is binned into a small number of ordered categories (e.g., Low < Medium < High), enabling contrastive value pairs \((x, x′)\) and well-defined computation of Necessity, Sufficiency, and Necessity-and-Sufficiency scores.
kbd = KBinsDiscretizer(n_bins=2, encode="ordinal", strategy="uniform", subsample=None)
df["y"] = kbd.fit_transform(df["y"].values.reshape(-1, 1))
feature_names = [f"x{i}" for i in range(len(m) - 1)]
for name in feature_names:
kbd = KBinsDiscretizer(n_bins=5, encode="ordinal", strategy="uniform", subsample=None)
df[name] = kbd.fit_transform(df[name].values.reshape(-1, 1))
df.head()
| x0 | x1 | x2 | x3 | x4 | x5 | y | |
|---|---|---|---|---|---|---|---|
| 0 | 2.0 | 2.0 | 4.0 | 1.0 | 2.0 | 1.0 | 0.0 |
| 1 | 3.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 0.0 |
| 2 | 3.0 | 2.0 | 2.0 | 2.0 | 2.0 | 0.0 | 1.0 |
| 3 | 2.0 | 3.0 | 1.0 | 3.0 | 2.0 | 1.0 | 1.0 |
| 4 | 2.0 | 1.0 | 3.0 | 3.0 | 1.0 | 1.0 | 0.0 |
Model Training and Prediction with RandomForestClassifier
We train a black-box prediction model using a RandomForestClassifier. The model learns a non-linear decision function by aggregating predictions from multiple decision trees trained on bootstrapped samples of the data. After training, the model is used to generate predictions (and class probabilities) for each instance. These predicted outcomes serve as the decision variable used in subsequent LEWIS analysis, while the internal structure of the model remains opaque.
X = df[feature_names]
y = df["y"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
clf = RandomForestClassifier(max_depth=5, random_state=0)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
X_test["y"] = y_pred
X_test.head()
| x0 | x1 | x2 | x3 | x4 | x5 | y | |
|---|---|---|---|---|---|---|---|
| 993 | 2.0 | 2.0 | 4.0 | 2.0 | 1.0 | 2.0 | 0.0 |
| 859 | 1.0 | 0.0 | 4.0 | 3.0 | 2.0 | 1.0 | 0.0 |
| 298 | 4.0 | 3.0 | 4.0 | 1.0 | 4.0 | 1.0 | 0.0 |
| 553 | 2.0 | 1.0 | 4.0 | 2.0 | 0.0 | 4.0 | 0.0 |
| 672 | 3.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 |
Computing LEWIS Scores
We compute LEWIS explanation scores—Necessity (Nec), Sufficiency (Suf), and Necessity-and-Sufficiency (NeSuf)—for each selected feature using a causal, model-agnostic approach. Based on the given causal graph, we first identify an adjustment set that satisfies the backdoor criterion for the relationship between the target feature and the outcome. These variables are used only for probability estimation and are not part of the user-facing explanation.
For an ordinal feature with ordered values \((x^{(1)} \prec \dots \prec x^{(L)})\), we compute Nec, Suf, and NeSuf for all ordered value pairs \((x, x')\) such that \((x \succ x')\). After evaluating all pairs, we select the pair that maximizes NeSuf as the most causally responsive contrast. The corresponding Nec and Suf values for this same pair are then reported, ensuring that all three scores are aligned to a single, interpretable intervention.
backdoor = {
'x0': [],
'x1': [],
'x2': [],
'x3': ['x5'],
'x4': ['x1', 'x5'],
'x5': [],
}
def max_nesuf(score_dic, score, name):
pre_score = [0, 0, 0]
if name in score_dic.keys():
pre_score = score_dic[name]
if score[2] > pre_score[2]: # nesuf
score_dic[name] = score
return score_dic
score_dic = {}
for name in feature_names:
val = {}
for v in df[name].unique():
df_temp = df[df[name] == v]
vc = df_temp["y"].value_counts().reindex([0, 1], fill_value=0)
val[v] = vc[1] * 1.0 / (vc[0] + vc[1])
sorted_val = sorted(val.items(), key=operator.itemgetter(1))
for i, (before, _) in enumerate(sorted_val):
for after, _ in sorted_val[i + 1 :]:
lewis = lingam.LEWIS()
score = lewis.get_scores(
X_test,
x_names=[name],
x_values=[after],
x_prime_values=[before],
o_name="y",
c_names=backdoor[name])
score_dic = max_nesuf(score_dic, score, name)
print(f"{name}: {before}->{after}: score={score}")
x0: 0.0->1.0: score=(0.28876449588651437, 0.15671962820931234, 0.11307290998254543)
x0: 0.0->2.0: score=(0.37054073719681124, 0.22722774483181898, 0.16394438035923292)
x0: 0.0->3.0: score=(0.308951540247843, 0.1725738099335212, 0.12451167156864434)
x0: 0.0->4.0: score=(0.3041121895598342, 0.16868933442058193, 0.12170902996578553)
x0: 1.0->2.0: score=(0.11497772655799332, 0.08361171323457248, 0.05087147037668749)
x0: 1.0->3.0: score=(0.028383066149784867, 0.018800605652118815, 0.011438761586098911)
x0: 1.0->4.0: score=(0.02157891947822522, 0.014194218923715932, 0.008636119983240098)
x0: 2.0->3.0: score=(0.0, 0.0, 0.0)
x0: 2.0->4.0: score=(0.0, 0.0, 0.0)
x0: 3.0->4.0: score=(0.0, 0.0, 0.0)
x1: 0.0->1.0: score=(1.0, 0.058517831098751374, 0.058517831098751374)
x1: 0.0->2.0: score=(1.0, 0.3863071221311797, 0.3863071221311797)
x1: 0.0->3.0: score=(1.0, 0.6916588068976024, 0.6916588068976024)
x1: 0.0->4.0: score=(1.0, 0.610516293905489, 0.610516293905489)
x1: 1.0->2.0: score=(0.8485199269018905, 0.348163036815634, 0.3277892910324283)
x1: 1.0->3.0: score=(0.9153949454338188, 0.6724938577835782, 0.633140975798851)
x1: 1.0->4.0: score=(0.9041502549843325, 0.5863079313025591, 0.5519984628067376)
x1: 2.0->3.0: score=(0.44147733206212025, 0.4975643286375128, 0.3053516847664227)
x1: 2.0->4.0: score=(0.36724518905144576, 0.3653442623497988, 0.22420917177430932)
x1: 3.0->4.0: score=(0.0, 0.0, 0.0)
x2: 4.0->3.0: score=(0.37781207464998834, 0.12785476560870568, 0.10561677393784988)
x2: 4.0->2.0: score=(0.5960755908256189, 0.31071620571720965, 0.2566728201473194)
x2: 4.0->0.0: score=(0.6241556609044814, 0.3496612983303619, 0.2888441281382971)
x2: 4.0->1.0: score=(0.6962790808358484, 0.48269334060441804, 0.39873768641480434)
x2: 3.0->2.0: score=(0.350799987082434, 0.2096685653922427, 0.1510560462094695)
x2: 3.0->0.0: score=(0.3959311587667224, 0.2543229315200742, 0.18322735420044717)
x2: 3.0->1.0: score=(0.5118501873958845, 0.406857207955013, 0.2931209124769545)
x2: 2.0->0.0: score=(0.06951813121731897, 0.056500809878571416, 0.03217130799097767)
x2: 2.0->1.0: score=(0.2480748569145521, 0.2495012015568321, 0.142064866267485)
x2: 0.0->1.0: score=(0.19189704999930096, 0.2045580893963687, 0.10989355827650732)
x3: 0.0->3.0: score=(0.0, 0.0, 0.07576064369254501)
x3: 0.0->2.0: score=(0.24019885850208111, 0.0, 0.06756468929060586)
x3: 0.0->4.0: score=(0.4375471879801787, 0.33627411063087126, 0.2394474103512867)
x3: 0.0->1.0: score=(0.5215173421020096, 0.2077705378411908, 0.2560619882491983)
x3: 3.0->2.0: score=(5.8380840876851114e-05, 0.0, 0.0016793883085030473)
x3: 3.0->4.0: score=(0.06087756395956861, 0.18097315391122015, 0.1735621093691839)
x3: 3.0->1.0: score=(0.4050632226945426, 0.2552560187517776, 0.1901766872670955)
x3: 2.0->4.0: score=(0.08389050626406856, 0.26072855163571107, 0.17188272106068087)
x3: 2.0->1.0: score=(0.39325969742974465, 0.27799227073061517, 0.18849729895859244)
x3: 4.0->1.0: score=(0.009866759685432035, 0.3007171788408852, 0.016614577897911605)
x4: 4.0->3.0: score=(0.0, 0.0002437752006358201, 0.01417600649350649)
x4: 4.0->2.0: score=(0.0, 0.43864694304558277, 0.3479847269469221)
x4: 4.0->1.0: score=(0.36470213549225855, 0.6477582075236524, 0.5560876613241783)
x4: 4.0->0.0: score=(0.8540240160406432, 0.6193767844978033, 0.624207928623127)
x4: 3.0->2.0: score=(0.0, 0.3662085477202382, 0.33380872045341564)
x4: 3.0->1.0: score=(0.34372689656444605, 0.6086744374259367, 0.5419116548306718)
x4: 3.0->0.0: score=(0.8449920368334224, 0.6353532663565686, 0.6100319221296207)
x4: 2.0->1.0: score=(0.0, 0.408105331786863, 0.21286296955041636)
x4: 2.0->0.0: score=(0.24400727130132333, 0.5372555633357295, 0.27622320167620484)
x4: 1.0->0.0: score=(0.17773403145390043, 0.252226191821934, 0.07080119345612482)
x5: 1.0->0.0: score=(0.0, 0.0, 0.0)
x5: 1.0->2.0: score=(0.2233835942176238, 0.08602778414260957, 0.06622184994605693)
x5: 1.0->3.0: score=(0.5867549758794374, 0.42466177557054957, 0.3268930922716893)
x5: 1.0->4.0: score=(0.6595932839891901, 0.5795249365522417, 0.44610254432152086)
x5: 0.0->2.0: score=(0.4502876323702336, 0.1594757752886144, 0.13348733208373778)
x5: 0.0->3.0: score=(0.7074927867488406, 0.4708967005285375, 0.3941585744093701)
x5: 0.0->4.0: score=(0.7590499242327206, 0.613314857297541, 0.5133680264592017)
x5: 2.0->3.0: score=(0.46789042692929883, 0.37050797119720974, 0.2606712423256323)
x5: 2.0->4.0: score=(0.5616797256969114, 0.5399476525078897, 0.3798806943754639)
x5: 3.0->4.0: score=(0.17625937121629423, 0.2691689069247335, 0.11920945204983158)
Visualization of LEWIS Scores Sorted by NeSuf
Finally, we visualize the LEWIS scores using a horizontal bar chart, where features are sorted in descending order of NeSuf. NeSuf represents the strongest indicator of causal importance, capturing how often a feature change would flip the decision outcome in both directions. By sorting features by NeSuf, the plot highlights which features are most causally influential overall. For each feature, the corresponding Nec and Suf values for the same maximizing contrast are displayed alongside NeSuf, enabling a consistent and interpretable comparison.
def plot_scores_barh_featurecol(
df,
feature_col="feature",
value_cols=("Nec", "Suf", "NeSuf"),
title="LEWIS Scores",
reverse_for_image=True,
):
needed = [feature_col, *value_cols]
missing = [c for c in needed if c not in df.columns]
if missing:
raise ValueError(f"df is missing columns {missing}. Required columns: {needed}")
df_plot = (
df.loc[:, needed]
.dropna(subset=[feature_col])
.drop_duplicates(subset=[feature_col])
.set_index(feature_col)
.loc[:, list(value_cols)]
)
if reverse_for_image:
df_plot = df_plot.iloc[::-1]
y = np.arange(len(df_plot.index))
h = 0.22
fig, ax = plt.subplots(figsize=(3.0, 3.6), dpi=120)
bars1 = ax.barh(
y + h,
df_plot[value_cols[0]].values,
height=h,
color="#4C72B0",
edgecolor="black",
linewidth=0.6,
label=value_cols[0],
)
bars2 = ax.barh(
y,
df_plot[value_cols[1]].values,
height=h,
color="#F0E442",
edgecolor="black",
linewidth=0.6,
label=value_cols[1],
)
bars3 = ax.barh(
y - h,
df_plot[value_cols[2]].values,
height=h,
color="#2CA02C",
edgecolor="black",
linewidth=0.6,
label=value_cols[2],
)
ax.set_yticks(y)
ax.set_yticklabels(df_plot.index)
ax.set_xlim(0, 1.0)
ax.set_xticks([0.0, 0.5, 1.0])
ax.xaxis.tick_top()
ax.xaxis.set_label_position("top")
ax.set_xlabel(title)
ax.legend(loc="lower right", fontsize=7, frameon=True)
def add_labels(bars):
for b in bars:
w = b.get_width()
yy = b.get_y() + b.get_height() / 2
ax.text(w + 0.01, yy, f"{w:.2f}", va="center", ha="left", fontsize=7)
add_labels(bars1)
add_labels(bars2)
add_labels(bars3)
ax.grid(False)
plt.tight_layout()
plt.show()
nec = []
suf = []
nesuf = []
for name in feature_names:
score = score_dic[name]
nec.append(score[0])
suf.append(score[1])
nesuf.append(score[2])
score_df = pd.DataFrame()
score_df['feature'] = feature_names
score_df['Nec'] = nec
score_df['Suf'] = suf
score_df['NeSuf'] = nesuf
plot_scores_barh_featurecol(score_df.sort_values(['NeSuf'], ascending=False))