Demos

Two end-to-end walkthroughs ship in the repository root. Each one is a single self-contained script you can run with python demo_<name>.py after installing smmargins.

Williams-style logit walkthrough

demo_margins.py reproduces, on a simulated voting dataset, every core statistic in Richard Williams’ Margins01 notes and then exercises the 0.3 inference surface end-to-end:

  1. Adjusted predictions at specific values (APR / margins, at(...))

  2. APM vs AAP (margins, atmeans vs margins)

  3. MER vs MEM vs AME for a continuous covariate

  4. Discrete contrasts for a multi-level categorical variable

  5. Discrete change for a 0/1 dummy

  6. Williams’ classic interaction example: AME of age by sex

  7. Predicted probability over age, by sex (table for plotting)

  8. Analytic vs FD parity check

  9. Robust covariance via cov_type="HC3"

  10. Krinsky–Robb simulation VCE

  11. Pairs bootstrap VCE

  12. Simultaneous CIs via sup-t

  13. Cluster-robust SEs (cov_type="cluster" with cov_kwds=)

  14. Multiple-comparison adjustments side-by-side (Bonferroni / Šidák / sup-t)

  15. User-supplied parameter covariance (vcov=)

The script itself continues past this list with sections 16 to 20: prediction scales and a custom Transform, subgroup AMEs with observation weights, a joint Wald test with pairwise comparisons, the per-variable values= DSL, and plotting.

Highlights from the script

Fit a logit with an interaction:

fit = smf.logit(
    "voted ~ age + income + C(educ) + female + age:female",
    data=df,
).fit()
M = Margins(fit)

APR — predictions at policy-relevant ages, averaging everything else over the sample:

M.predict(atexog={"age": [25, 45, 65]})
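
As a rough illustration of what an APR computes (plain Python with made-up coefficients and a three-row sample, not the smmargins internals): fix age at the chosen value for every row, predict, and average.

```python
import math

def inv_logit(eta):
    return 1.0 / (1.0 + math.exp(-eta))

# Hypothetical coefficients and a tiny made-up sample (illustration only).
beta = {"const": -2.0, "age": 0.04, "female": 0.3}
sample = [
    {"age": 30, "female": 0},
    {"age": 50, "female": 1},
    {"age": 70, "female": 1},
]

def apr(age_value):
    """Average prediction with age overwritten by age_value in every row."""
    preds = [
        inv_logit(beta["const"] + beta["age"] * age_value
                  + beta["female"] * row["female"])
        for row in sample
    ]
    return sum(preds) / len(preds)

aprs = {a: apr(a) for a in (25, 45, 65)}
for a, p in aprs.items():
    print(f"age={a}: APR={p:.4f}")
```

The other covariates (here only female) keep their observed values, which is what separates an APR from a prediction at the means.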

MER, MEM, and AME for age — these can differ meaningfully in nonlinear models with interactions:

M.dydx("age", atexog={"age": [25, 45, 65]})   # MER
M.dydx("age", at="mean")                      # MEM
M.dydx("age")                                 # AME
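
To see why the three summaries need not agree, here is a minimal hand computation for a one-covariate logit. The coefficients and the four ages are invented for illustration; the last lines cross-check the analytic slope against a central finite difference, in the spirit of the script's parity check.

```python
import math

def inv_logit(eta):
    return 1.0 / (1.0 + math.exp(-eta))

# Hypothetical one-covariate logit: Pr(y=1) = inv_logit(b0 + b1 * age).
b0, b1 = -3.0, 0.05
ages = [20, 40, 60, 80]  # made-up sample

def slope(age):
    """Analytic d Pr / d age = b1 * p * (1 - p) for a logit."""
    p = inv_logit(b0 + b1 * age)
    return b1 * p * (1.0 - p)

mer = {a: slope(a) for a in (25, 45, 65)}      # at representative values
mem = slope(sum(ages) / len(ages))             # at the sample mean of age
ame = sum(slope(a) for a in ages) / len(ages)  # averaged over the sample

# Central finite difference as a cross-check on the analytic slope at age 45.
h = 1e-5
fd_45 = (inv_logit(b0 + b1 * (45 + h)) - inv_logit(b0 + b1 * (45 - h))) / (2 * h)

print(f"MEM={mem:.5f}  AME={ame:.5f}  MER(45)={mer[45]:.5f}  FD(45)={fd_45:.5f}")
```

Because the slope b1 * p * (1 - p) is nonlinear in age, the slope at the mean (MEM) and the mean of the slopes (AME) differ.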

Discrete AME for a multi-level factor with an explicit reference level:

M.dydx("educ", reference="college")

Williams’ interaction lesson — same model, AME of age for each sex:

M.dydx("age", atexog={"female": [0, 1]})

Robust SEs and alternative VCEs (sections 9–11):

Margins(fit, cov_type="HC3").dydx("age")          # heteroskedastic-robust
M.dydx("age", vce="simulation",
       n_sims=2000, sim_seed=42)                   # Krinsky–Robb
M.dydx("age", vce="bootstrap",
       n_boot=500, boot_seed=42)                   # pairs bootstrap
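
The idea behind the Krinsky–Robb VCE can be sketched in a few lines of standard-library Python. This is an illustration under assumed values (the estimates b_hat and covariance V are invented), not the package's implementation: draw parameters from a normal distribution centred on the estimates, re-evaluate the statistic for each draw, and take the standard deviation across draws.

```python
import math
import random
import statistics

def inv_logit(eta):
    return 1.0 / (1.0 + math.exp(-eta))

# Hypothetical estimates for Pr(y=1) = inv_logit(b0 + b1 * age) and their covariance.
b_hat = (-2.0, 0.04)
V = ((0.04, -0.0008),
     (-0.0008, 0.00002))
age = 45.0

# Cholesky factor of the 2x2 covariance (V = L L^T), written out by hand.
l11 = math.sqrt(V[0][0])
l21 = V[1][0] / l11
l22 = math.sqrt(V[1][1] - l21 ** 2)

rng = random.Random(42)
draws = []
for _ in range(20_000):
    z1, z2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    b0 = b_hat[0] + l11 * z1
    b1 = b_hat[1] + l21 * z1 + l22 * z2
    draws.append(inv_logit(b0 + b1 * age))  # statistic re-evaluated per draw

sim_se = statistics.stdev(draws)

# Delta-method SE for comparison: sqrt(J V J^T) with J = p(1-p) * (1, age).
p = inv_logit(b_hat[0] + b_hat[1] * age)
var_eta = V[0][0] + 2 * age * V[0][1] + age ** 2 * V[1][1]
delta_se = p * (1 - p) * math.sqrt(var_eta)

print(f"simulation SE={sim_se:.5f}  delta-method SE={delta_se:.5f}")
```

For a smooth statistic and a reasonably tight covariance the two standard errors should be close; the simulation route becomes more attractive when the statistic is strongly nonlinear in the parameters.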

Cluster-robust SEs through cov_type="cluster" with cluster IDs passed in cov_kwds (section 13):

Margins(fit, cov_type="cluster",
        cov_kwds={"groups": df["household"]}).dydx("age")

Family-wise CI methods side-by-side at five ages (section 14) — for a correlated family of predictions, sup-t is typically narrower than Bonferroni / Šidák:

common = dict(atexog={"age": [25, 35, 45, 55, 65]},
              vce="simulation", n_sims=4000, sim_seed=123)
M.predict(**common, ci_method="pointwise")
M.predict(**common, ci_method="bonferroni")
M.predict(**common, ci_method="sidak")
M.predict(**common, ci_method="sup-t")
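
The critical values behind these four methods can be sketched with the standard library alone. The equicorrelated family below (rho = 0.9) is an assumption chosen to mimic strongly correlated predictions; it is not taken from the demo, and the sup-t step is a simplified stand-in for what a simulation-based VCE would supply.

```python
import math
import random
from statistics import NormalDist

alpha, m = 0.05, 5
z = NormalDist().inv_cdf

crit_pointwise = z(1 - alpha / 2)
crit_bonferroni = z(1 - alpha / (2 * m))
crit_sidak = z(1 - (1 - (1 - alpha) ** (1 / m)) / 2)

# sup-t: simulate standardized draws for an equicorrelated family
# (rho = 0.9 is a made-up value) and take the (1 - alpha) quantile of max |z|.
rho = 0.9
rng = random.Random(123)
max_abs = []
for _ in range(20_000):
    shared = rng.gauss(0.0, 1.0)
    zs = [math.sqrt(rho) * shared + math.sqrt(1 - rho) * rng.gauss(0.0, 1.0)
          for _ in range(m)]
    max_abs.append(max(abs(v) for v in zs))
max_abs.sort()
crit_supt = max_abs[int((1 - alpha) * len(max_abs))]

print(f"pointwise={crit_pointwise:.3f}  sidak={crit_sidak:.3f}  "
      f"bonferroni={crit_bonferroni:.3f}  sup-t={crit_supt:.3f}")
```

Bonferroni and Šidák depend only on the family size m, so they cannot benefit from correlation; the sup-t quantile shrinks toward the pointwise value as the members of the family become more correlated.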

User-supplied parameter covariance (section 15) — drop in any \((k, k)\) matrix and smmargins sandwiches it through the Jacobian:

Margins(fit, vcov=my_vcov_matrix).dydx("age")
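
"Sandwiching through the Jacobian" refers to the delta-method variance \(J V J^\top\). The sketch below uses an invented two-parameter logit (b_hat, V, and the helper delta_se are all hypothetical) to show the consequence the demo checks in section 15: scaling V by 1.5 scales every standard error by \(\sqrt{1.5}\).

```python
import math

def inv_logit(eta):
    return 1.0 / (1.0 + math.exp(-eta))

def delta_se(jac, V):
    """sqrt(J V J^T) for a single statistic with Jacobian row `jac`."""
    k = len(jac)
    quad = sum(jac[i] * V[i][j] * jac[j] for i in range(k) for j in range(k))
    return math.sqrt(quad)

# Hypothetical estimates and covariance for Pr(y=1) = inv_logit(b0 + b1 * age).
b_hat = (-2.0, 0.04)
V = [[0.04, -0.0008], [-0.0008, 0.00002]]
age = 45.0

# Jacobian of the prediction with respect to (b0, b1): p(1-p) * (1, age).
p = inv_logit(b_hat[0] + b_hat[1] * age)
jac = [p * (1 - p), p * (1 - p) * age]

V_inflated = [[1.5 * v for v in row] for row in V]
se_default = delta_se(jac, V)
se_inflated = delta_se(jac, V_inflated)
ratio = se_inflated / se_default

print(f"SE default={se_default:.5f}  SE inflated={se_inflated:.5f}  ratio={ratio:.4f}")
```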

Full source

"""
demo_margins.py
===============

Walkthrough of the core analyses in Richard Williams' *Margins01* notes
(https://academicweb.nd.edu/~rwilliam/stats/Margins01.pdf), implemented
on top of StatsModels + patsy + the ``smmargins`` package.

Sections
--------
  1.  Adjusted predictions at specific values (APR / ``margins, at(...)``)
  2.  APM vs AAP (``margins, atmeans`` vs ``margins``)
  3.  MER vs MEM vs AME for a continuous covariate
  4.  Discrete contrast for a categorical variable
  5.  Discrete change for a 0/1 dummy
  6.  AME by interaction subgroup (Williams' motivating example)
  7.  Predicted probability over age, by sex (table for plotting)
  8.  Analytic vs FD parity check
  9.  Robust covariance (``cov_type="HC3"``)
  10. Krinsky–Robb simulation VCE
  11. Pairs bootstrap VCE
  12. Simultaneous CIs via sup-t
  13. Cluster-robust SEs (``cov_type="cluster"``)
  14. Multiple-comparison adjustments side-by-side (Bonferroni / Sidak / sup-t)
  15. User-supplied parameter covariance (``vcov=``)
  16. Prediction scales (``scale=``) and a custom ``Transform``
  17. Subgroup AMEs via ``over=`` and observation weights
  18. Joint Wald test and pairwise comparisons
  19. Per-variable DSL (``values=`` and ``default_values=``)
  20. Plotting (``plot_predictions`` and ``plot_slopes``)
"""

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

from smmargins import Margins, Transform

pd.options.display.width = 120
pd.options.display.float_format = "{: .4f}".format

# ---------------------------------------------------------------------------
# Simulate a binary-outcome dataset with structure similar to Williams' notes
# ---------------------------------------------------------------------------
rng = np.random.default_rng(7)
N = 5_000
df = pd.DataFrame(
    {
        "age":    rng.normal(45, 12, N).clip(18, 90),
        "income": rng.lognormal(10.5, 0.4, N),          # ~36k median
        "educ":   rng.choice(["hs", "college", "grad"], N, p=[0.4, 0.4, 0.2]),
        "female": rng.integers(0, 2, N),
    }
)
eta = (
    -4.0
    + 0.05 * df["age"]
    + 0.00001 * df["income"]
    + 0.8 * (df["educ"] == "college")
    + 1.4 * (df["educ"] == "grad")
    + 0.3 * df["female"]
    - 0.0004 * df["age"] * (df["female"])        # interaction
)
df["voted"] = (rng.uniform(0, 1, N) < 1 / (1 + np.exp(-eta))).astype(int)

print("Sample:")
print(df.head(3), "\n")

# ---------------------------------------------------------------------------
# Fit a logit with an interaction, like the Williams example
# ---------------------------------------------------------------------------
fit = smf.logit(
    "voted ~ age + income + C(educ) + female + age:female",
    data=df,
).fit(disp=False)
print("=" * 80)
print("Fitted logit")
print("=" * 80)
print(fit.summary().tables[1])
print()

# `analytic=True` is the default: the outer ∂g/∂β goes through
# `family.link.inverse_deriv` for any GLM (Logit/Probit/Poisson/...) and
# the identity link for OLS/WLS/GLS, falling back to central finite
# differences only when the link derivative isn't available. Set
# `analytic=False` to force FD; you'll get the same answers (see the
# parity check in section 8) but pay p extra forward
# predict() calls per statistic.
M = Margins(fit)

# ---------------------------------------------------------------------------
# 1. Adjusted predictions at representative values (APR)
#    Stata: margins, at(age=(25 45 65))
# ---------------------------------------------------------------------------
print("=" * 80)
print("1. APR  (predict at age=25,45,65; everything else at sample values)")
print("=" * 80)
print(M.predict(atexog={"age": [25, 45, 65]}))
print()

# ---------------------------------------------------------------------------
# 2. Adjusted prediction at means (APM)  vs  average adjusted prediction (AAP)
# ---------------------------------------------------------------------------
print("=" * 80)
print("2. APM  (margins, atmeans)   vs   AAP  (margins)")
print("=" * 80)
print("APM:"); print(M.predict(at="mean"))
print("\nAAP:"); print(M.predict())
print()

# ---------------------------------------------------------------------------
# 3. Marginal effect: MER vs MEM vs AME for `age`
#    (Williams points out these three can differ meaningfully in nonlinear
#    models with interactions)
# ---------------------------------------------------------------------------
print("=" * 80)
print("3. d Pr(voted)/d age : MER (at age=25,45,65),  MEM, and AME")
print("=" * 80)
print("MER (at age=25,45,65):")
print(M.dydx("age", atexog={"age": [25, 45, 65]}))
print("\nMEM (at means of everything):")
print(M.dydx("age", at="mean"))
print("\nAME (averaged over the sample):")
print(M.dydx("age"))
print()

# ---------------------------------------------------------------------------
# 4. Discrete contrast for the categorical variable `educ`
# ---------------------------------------------------------------------------
print("=" * 80)
print("4. Discrete AME for educ  (each level vs 'college' as reference)")
print("=" * 80)
print(M.dydx("educ", reference="college"))
print()

# ---------------------------------------------------------------------------
# 5. Discrete change for the dummy `female`  (auto-detected as discrete)
# ---------------------------------------------------------------------------
print("=" * 80)
print("5. AME for female (0/1 dummy):  Pr(voted|female=1) - Pr(voted|female=0)")
print("=" * 80)
print(M.dydx("female"))
print()

# ---------------------------------------------------------------------------
# 6. Interaction-sensitivity: marginal effect of age, separately for men/women
#    This is Williams' classic motivating example: the interaction coefficient
#    alone tells you little about what the marginal effect actually is for any
#    given subpopulation.
# ---------------------------------------------------------------------------
print("=" * 80)
print("6. AME of age, separately by sex  (Williams' interaction illustration)")
print("=" * 80)
print(M.dydx("age", atexog={"female": [0, 1]}))
print()

# ---------------------------------------------------------------------------
# 7. Adjusted predictions, age by sex — table suitable for plotting
# ---------------------------------------------------------------------------
print("=" * 80)
print("7. Predicted Pr(voted) over age, for each sex")
print("=" * 80)
tbl = M.predict(atexog={"age": list(range(20, 91, 10)), "female": [0, 1]})
print(tbl)

# ---------------------------------------------------------------------------
# 8. Analytic vs FD: same answers, faster path
#    Logit exposes `family.link.inverse_deriv`, so the analytic outer
#    Jacobian is used by default. Toggling `analytic=False` reroutes
#    every statistic through central finite differences — useful as a
#    sanity check or when working with a custom Link subclass that
#    doesn't implement inverse_deriv.
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("8. Analytic vs FD — same numbers, taken via different paths")
print("=" * 80)
M_fd = Margins(fit, analytic=False)
ame_an = M.dydx("age")
ame_fd = M_fd.dydx("age")
print(f"AME(age) analytic : est={ame_an.estimate[0]: .8f}  se={ame_an.se[0]: .8f}")
print(f"AME(age) FD       : est={ame_fd.estimate[0]: .8f}  se={ame_fd.se[0]: .8f}")
print(f"max abs diff      : "
      f"est {abs(ame_an.estimate[0] - ame_fd.estimate[0]): .2e}, "
      f"se {abs(ame_an.se[0] - ame_fd.se[0]): .2e}")

# ---------------------------------------------------------------------------
# 9. Robust covariance (Feature 1)
#    Recompute SEs with HC3 heteroskedasticity-consistent covariance.
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("9. Robust covariance — HC3")
print("=" * 80)
M_hc3 = Margins(fit, cov_type="HC3")
print(M_hc3.dydx("age"))

# ---------------------------------------------------------------------------
# 10. Krinsky–Robb simulation VCE (Feature 2)
#     Draw parameters from their sampling distribution and evaluate margins.
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("10. Krinsky–Robb simulation VCE")
print("=" * 80)
print(M.dydx("age", vce="simulation", n_sims=2000, sim_seed=42))

# ---------------------------------------------------------------------------
# 11. Bootstrap VCE (Feature 3)
#     Pairs bootstrap with 500 replications.
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("11. Bootstrap VCE")
print("=" * 80)
print(M.dydx("age", vce="bootstrap", n_boot=500, boot_seed=42))

# ---------------------------------------------------------------------------
# 12. Simultaneous CIs — sup-t (Feature 4)
#     Use simulation draws to compute simultaneous CIs for a family of margins.
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("12. Simultaneous CIs (sup-t)")
print("=" * 80)
print(M.predict(atexog={"age": [25, 45, 65]},
                vce="simulation", n_sims=2000, sim_seed=42,
                ci_method="sup-t"))

# ---------------------------------------------------------------------------
# 13. Cluster-robust SEs
#     Synthesize a clustering structure (e.g., households of ~10 voters who
#     share unobserved local effects). Cluster-robust SEs propagate that
#     correlation through the Jacobian to the AME.
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("13. Cluster-robust SEs vs nonrobust  (synthetic household clusters)")
print("=" * 80)
df_c = df.copy()
df_c["household"] = rng.integers(0, N // 10, N)  # ~10 obs per cluster
fit_c = smf.logit(
    "voted ~ age + income + C(educ) + female + age:female",
    data=df_c,
).fit(disp=False)
M_nonrobust = Margins(fit_c)
M_cluster = Margins(fit_c, cov_type="cluster",
                    cov_kwds={"groups": df_c["household"]})
ame_nr = M_nonrobust.dydx("age").se[0]
ame_cl = M_cluster.dydx("age").se[0]
print(f"AME(age) SE — nonrobust : {ame_nr: .6f}")
print(f"AME(age) SE — cluster    : {ame_cl: .6f}   (ratio {ame_cl / ame_nr:.2f}x)")

# ---------------------------------------------------------------------------
# 14. Multiple-comparison adjustments
#     A family of 5 adjusted predictions at different ages. Pointwise CIs
#     under-cover the joint event "all 5 contain the truth"; Bonferroni
#     and Sidak inflate the critical value uniformly; sup-t uses the
#     simulation draws to exploit correlation across the family.
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("14. Family-wise CI methods at age=25,35,45,55,65")
print("=" * 80)
ages = [25, 35, 45, 55, 65]
common = dict(atexog={"age": ages}, vce="simulation",
              n_sims=4000, sim_seed=123)
pw   = M.predict(**common, ci_method="pointwise")
bonf = M.predict(**common, ci_method="bonferroni")
sidk = M.predict(**common, ci_method="sidak")
supt = M.predict(**common, ci_method="sup-t")

widths = pd.DataFrame({
    "age":        ages,
    "pointwise":  pw.ci_upper   - pw.ci_lower,
    "bonferroni": bonf.ci_upper - bonf.ci_lower,
    "sidak":      sidk.ci_upper - sidk.ci_lower,
    "sup-t":      supt.ci_upper - supt.ci_lower,
}).set_index("age")
print("CI widths:")
print(widths)
print("\nBonferroni >= Sidak (always); for correlated margins sup-t is "
      "typically narrower than both.")

# ---------------------------------------------------------------------------
# 15. User-supplied vcov
#     Drop in any (k, k) covariance matrix you trust — e.g. a sandwich
#     computed offline, a Bayesian posterior covariance, or the output
#     of a custom resampling scheme — and smmargins will sandwich it
#     through the Jacobian without recomputing anything else.
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("15. User-supplied parameter covariance (vcov=)")
print("=" * 80)
V_default = fit.cov_params().to_numpy()
V_inflated = V_default * 1.5     # toy example: assume 50% wider sampling cov
M_v = Margins(fit, vcov=V_inflated)
ame_default = M.dydx("age").se[0]
ame_user    = M_v.dydx("age").se[0]
print(f"AME(age) SE — default cov_params() : {ame_default: .6f}")
print(f"AME(age) SE — vcov = 1.5 x default : {ame_user: .6f}   "
      f"(ratio {ame_user / ame_default:.3f}, expect ≈ sqrt(1.5)={np.sqrt(1.5):.3f})")


# ---------------------------------------------------------------------------
# 16. Prediction scales (`scale=`) and a custom Transform
#     The default response scale gives Λ(η) for logit. Switch to "linear"
#     to get the linear predictor (Stata 'xb'), to "or" for odds-ratio
#     scale, or pass a Transform for any user-defined λ.
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("16. Prediction scales (scale=) and custom Transform")
print("=" * 80)
ame_resp   = M.dydx("age").estimate[0]
ame_linear = M.dydx("age", scale="linear").estimate[0]
ame_or     = M.dydx("age", scale="or").estimate[0]
print(f"AME(age) on response scale      : {ame_resp: .6f}   (probability units)")
# With the age:female interaction in the model, the linear-scale slope is
# not just the age coefficient, so the label only states the units.
print(f"AME(age) on linear scale        : {ame_linear: .6f}   (log-odds per year of age)")
print(f"AME(age) on odds-ratio scale    : {ame_or: .6f}")

# Custom Transform: square the linear predictor (silly but illustrative)
square = Transform(
    value=lambda e: e ** 2,
    grad=lambda e: 2 * e,
    hess=lambda e: np.full_like(e, 2.0),
    name="square",
)
ame_sq = M.dydx("age", scale=square).estimate[0]
print(f"AME(age) under square transform : {ame_sq: .6f}")

# ---------------------------------------------------------------------------
# 17. Subgroup AMEs (over=) and observation weights
#     over= partitions the sample and averages within each subgroup,
#     keeping the FULL joint covariance (so cross-subgroup tests are valid).
#     weights= threads through every average; pass at construction time.
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("17. Subgroup AMEs (over=) and observation weights")
print("=" * 80)
print("AME(age) by educ:")
print(M.dydx("age", over="educ").summary())

w = rng.uniform(0.5, 2.0, N)
M_w = Margins(fit, weights=w)
ame_unw = M.dydx("age").estimate[0]
ame_w   = M_w.dydx("age").estimate[0]
print()
print(f"AME(age) — unweighted         : {ame_unw: .6f}")
print(f"AME(age) — sampling weighted  : {ame_w: .6f}")

# ---------------------------------------------------------------------------
# 18. Joint Wald test and pairwise comparisons
#     wald() tests linear restrictions on a result; pairwise() builds the
#     all-level-vs-level contrast matrix for a factor and applies whatever
#     ci_method= you ask for (sup-t needs simulation/bootstrap draws).
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("18. Joint Wald and pairwise comparisons")
print("=" * 80)
res = M.dydx(["age", "income", "female"])
joint = res.wald()
print("Joint H0: AME(age) = AME(income) = AME(female) = 0")
print(f"  chi^2 = {joint.stat:.4f},  df = {joint.df},  p = {joint.pvalue:.3g}")

res_educ = M.dydx("educ")
print()
print("Pairwise comparisons across educ levels (Bonferroni-adjusted CIs):")
print(res_educ.pairwise(by="educ", ci_method="bonferroni").summary())

# ---------------------------------------------------------------------------
# 19. Per-variable DSL (values=)
#     Fix specific variables without restating every column.
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("19. Per-variable DSL — values= and default_values=")
print("=" * 80)
print("Predictions at income=p25/p50/p75, everything else at mean:")
print(M.predict(values={"income": ["p25", "p50", "p75"]}, default_values="mean"))

print()
print("Counterfactual: income + 10% via Expr:")
from smmargins import Expr
print(M.predict(values={"income": Expr("income * 1.10")}))

# ---------------------------------------------------------------------------
# 20. Plotting
# ---------------------------------------------------------------------------
print()
print("=" * 80)
print("20. Plotting — prediction curve vs age")
print("=" * 80)
from smmargins import plot_predictions
fig, ax = plot_predictions(M, "age")
fig.savefig("demo_plot_predictions.png", dpi=150)
print("Saved demo_plot_predictions.png")

print()
print("Plotting — AME of age vs income, by sex:")
from smmargins import plot_slopes
fig2, ax2 = plot_slopes(M, "age", condition="income", by="female")
fig2.savefig("demo_plot_slopes.png", dpi=150)
print("Saved demo_plot_slopes.png")

Healthcare-style 2x2 difference-in-differences

demo_did.py answers a clinical question:

Is there a rate difference of condition \(X\) between groups A and B, with or without a preexisting condition \(Y\)?

The script fits a logit on simulated patient data and reports, on the probability scale:

  • 4 cell predictions \(P(X \mid \text{group}, Y)\)

  • 2 simple effects \(P(X \mid B, Y) - P(X \mid A, Y)\) at each \(Y\)

  • 1 difference-in-differences (whether the A-vs-B gap depends on \(Y\))

All three come with delta-method standard errors and confidence intervals. The DiD here is not the coefficient on the group:Y interaction: that coefficient is on the log-odds scale, while the clinical question is about probabilities. This is the point of Ai & Norton (2003), put into practice; see Mathematical motivation for the derivation.
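
A small numeric check makes the distinction concrete. The sketch below reuses the coefficients from the demo's simulated data-generating process, evaluated for a 55-year-old male; the helper did_probability is hypothetical and is not part of smmargins. Setting the interaction coefficient to zero still leaves a nonzero DiD on the probability scale, because the inverse link is nonlinear.

```python
import math

def inv_logit(eta):
    return 1.0 / (1.0 + math.exp(-eta))

def did_probability(base, b_group, b_y, b_inter):
    """DiD on the probability scale for a 2x2 logit with offset `base`."""
    p = {
        (g, y): inv_logit(base + b_group * g + b_y * y + b_inter * g * y)
        for g in (0, 1) for y in (0, 1)
    }
    simple_y0 = p[(1, 0)] - p[(0, 0)]   # B minus A among patients without Y
    simple_y1 = p[(1, 1)] - p[(0, 1)]   # B minus A among patients with Y
    return simple_y1 - simple_y0

# Linear-predictor offset for a 55-year-old male under the simulated coefficients.
base = -3.5 + 0.04 * 55
did_with = did_probability(base, 0.5, 1.1, 0.8)
did_without = did_probability(base, 0.5, 1.1, 0.0)   # interaction set to zero

print(f"interaction coefficient 0.8 -> probability-scale DiD {did_with:+.4f}")
print(f"interaction coefficient 0.0 -> probability-scale DiD {did_without:+.4f}")
```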

Highlights from the script

Fit and call did():

fit = smf.logit(
    "condition_X ~ C(group) + preexist_Y + C(group):preexist_Y "
    "+ age + female",
    data=df,
).fit()
M = Margins(fit)

did = M.did("group", "preexist_Y",
            group_levels=["A", "B"],
            condition_levels=[0, 1])
print(did)              # cells + simple effects + DiD

Same DiD at one specific patient profile (60-year-old male):

M.did("group", "preexist_Y",
      group_levels=["A", "B"], condition_levels=[0, 1],
      atexog={"age": 60, "female": 0})

Plot-ready cell table:

tbl = did.cells.summary()        # estimate / SE / CI per cell

Full source

"""
demo_did.py
===========

Difference-in-differences example directly matching the question:

    "Is there a rate difference of condition X between group A and B,
     with or without preexisting condition Y?"

We fit a logit for P(X=1) on group (A/B), preexisting Y (0/1), their
interaction, and control covariates. Then we use Margins.did() to get,
on the *probability* scale:

  * 4 cell predictions      P(X | group, Y)
  * 2 simple effects        P(X|B,Y) - P(X|A,Y)   at each Y
  * 1 DiD                   (simple effect at Y=1) - (simple effect at Y=0)

All with delta-method standard errors and CIs.

The DiD here is the "does the A-vs-B gap depend on Y?" question.  It is
NOT the coefficient on group×Y (that's on the log-odds scale); on the
probability scale you have to go through the inverse link — which is
exactly what Margins.did() does.
"""
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

from smmargins import Margins

pd.options.display.width = 140
pd.options.display.float_format = "{: .4f}".format

# ---------------------------------------------------------------------------
# Simulate patient-level data
# ---------------------------------------------------------------------------
rng = np.random.default_rng(42)
N = 6_000
df = pd.DataFrame({
    "group":        rng.choice(["A", "B"], N, p=[0.55, 0.45]),
    "preexist_Y":   rng.integers(0, 2, N),             # 0 = no Y, 1 = has Y
    "age":          rng.normal(55, 15, N).clip(18, 95),
    "female":       rng.integers(0, 2, N),
})

# True data-generating process:
#   * baseline rate of X depends on age, sex, and Y
#   * group B has a modest additive bump in the log-odds
#   * the group effect is AMPLIFIED among patients with preexisting Y
#     (this is the thing we want to detect)
eta = (
    -3.5
    + 0.04 * df["age"]
    - 0.3 * df["female"]
    + 0.5 * (df["group"] == "B")
    + 1.1 * df["preexist_Y"]
    + 0.8 * (df["group"] == "B") * df["preexist_Y"]   # interaction
)
df["condition_X"] = (rng.uniform(0, 1, N) < 1 / (1 + np.exp(-eta))).astype(int)

print("Raw sample rates of condition X by cell:")
print(df.groupby(["group", "preexist_Y"])["condition_X"].mean().round(4))
print()

# ---------------------------------------------------------------------------
# Fit the logit with the group × preexist_Y interaction + controls
# ---------------------------------------------------------------------------
fit = smf.logit(
    "condition_X ~ C(group) + preexist_Y + C(group):preexist_Y + age + female",
    data=df,
).fit(disp=False)

print("=" * 84)
print("Logit model (coefficients are on the log-odds scale)")
print("=" * 84)
print(fit.summary().tables[1])
print()

# ---------------------------------------------------------------------------
# DiD on the *probability* (response) scale — what the clinical question asks
# ---------------------------------------------------------------------------
# Margins(fit) uses analytic outer Jacobians via family.link.inverse_deriv
# when available (Logit qualifies), falling back to central finite
# differences otherwise. did() reuses predict()'s machinery, so it
# inherits the analytic path automatically. Set Margins(fit,
# analytic=False) to force FD if you ever want to cross-check.
M = Margins(fit)
did = M.did("group", "preexist_Y",
            group_levels=["A", "B"],
            condition_levels=[0, 1])

print("=" * 84)
print("DiD on the probability scale, averaged over age and sex distribution")
print("=" * 84)
print(did)

# ---------------------------------------------------------------------------
# Interpretation
# ---------------------------------------------------------------------------
pA0 = did.cells.estimate[0]   # group=A, Y=0
pA1 = did.cells.estimate[1]   # group=A, Y=1
pB0 = did.cells.estimate[2]   # group=B, Y=0
pB1 = did.cells.estimate[3]   # group=B, Y=1
se_simple_Y0 = did.simple_effects.se[0]
se_simple_Y1 = did.simple_effects.se[1]
did_est, did_se = did.did.estimate[0], did.did.se[0]

print()
print("=" * 84)
print("Plain-language summary")
print("=" * 84)
print(f"Condition X rate, group A, no preexisting Y : {pA0:.3%}")
print(f"Condition X rate, group A, with Y           : {pA1:.3%}")
print(f"Condition X rate, group B, no preexisting Y : {pB0:.3%}")
print(f"Condition X rate, group B, with Y           : {pB1:.3%}")
print()
print(f"Rate difference (B - A) among NO-Y patients  : "
      f"{(pB0 - pA0):+.3%}  (SE {se_simple_Y0:.3%})")
print(f"Rate difference (B - A) among WITH-Y patients: "
      f"{(pB1 - pA1):+.3%}  (SE {se_simple_Y1:.3%})")
print()
print(f"Difference-in-differences                   : "
      f"{did_est:+.3%}  (SE {did_se:.3%})")
print(f"  -> the B-vs-A gap is {abs(did_est):.3%} larger among patients "
      f"with preexisting Y.")
print(f"  -> 95% CI: ({did.did.ci_lower[0]:+.3%}, {did.did.ci_upper[0]:+.3%})")
print(f"  -> p-value: {did.did.pvalues[0]:.4g}")

# ---------------------------------------------------------------------------
# Sensitivity: DiD at a specific patient profile (e.g. 60-year-old male)
# ---------------------------------------------------------------------------
print()
print("=" * 84)
print("DiD at a specific profile: 60-year-old male")
print("=" * 84)
did_profile = M.did(
    "group", "preexist_Y",
    group_levels=["A", "B"], condition_levels=[0, 1],
    atexog={"age": 60, "female": 0},
)
print(did_profile.did)

# ---------------------------------------------------------------------------
# Bonus: plottable table of cell predictions with CIs
# ---------------------------------------------------------------------------
print()
print("=" * 84)
print("Cells with 95% CIs (suitable for a plot)")
print("=" * 84)
tbl = did.cells.summary().copy()
print(tbl)

# If you wanted to plot:
#   import matplotlib.pyplot as plt
#   fig, ax = plt.subplots()
#   for g in ["A", "B"]:
#       sub = tbl[tbl.index.str.contains(f"group={g}")]
#       ax.errorbar([0, 1],
#                   sub["prediction"].values,
#                   yerr=(sub["prediction"] - sub["[95% CI lo]"]).values,
#                   marker="o", label=f"group {g}", capsize=4)
#   ax.set_xticks([0, 1]); ax.set_xticklabels(["no Y", "with Y"])
#   ax.set_ylabel("P(condition X)"); ax.legend(); plt.show()