# Source code for moderndive.plots.models

"""Regression-model plotting helpers ported from the R ``moderndive`` package.

- :func:`gg_parallel_slopes`  ~ ``moderndive::gg_parallel_slopes`` (scatter + parallel lines)
- :func:`geom_parallel_slopes` ~ ``moderndive::geom_parallel_slopes`` (plotnine layer)
- :func:`gg_categorical_model` ~ ``moderndive::geom_categorical_model`` (categorical predictor)

The non-plotly backend is plotnine. All functions accept polars or pandas frames.
``gg_*`` return a figure (plotly ``go.Figure`` or plotnine ``ggplot``);
``geom_parallel_slopes`` returns plotnine layers to add to a ``ggplot`` with ``+``.
"""

from __future__ import annotations

import polars as pl

# Qualitative palette (matches plotly's default) so plotnine/plotly look alike.
# gg_parallel_slopes indexes it modulo its length, so colors repeat once there
# are more groups than entries.
_PALETTE = [
    "#636EFA",
    "#EF553B",
    "#00CC96",
    "#AB63FA",
    "#FFA15A",
    "#19D3F3",
    "#FF6692",
    "#B6E880",
]
# Color of the fitted-mean markers drawn by gg_categorical_model.
_FIT_COLOR = "#d62728"


def _to_pandas(data, columns):
    """Return ``columns`` of ``data`` as a pandas frame with null rows removed.

    A polars ``DataFrame`` is used as is; anything else is first handed to
    ``pl.from_pandas``. Rows holding a null in any selected column are dropped
    before converting to pandas.
    """
    if isinstance(data, pl.DataFrame):
        frame = data
    else:
        frame = pl.from_pandas(data)
    selected = frame.select(columns)
    return selected.drop_nulls().to_pandas()


def _resolve_engine(engine: str) -> str:
    if engine not in ("plotly", "plotnine"):
        raise ValueError(f"engine must be 'plotly' or 'plotnine', got {engine!r}")
    return engine


def _parallel_slopes_fit(pdf, response: str, explanatory: str, by: str):
    """Fit ``response ~ explanatory + C(by)``; return (intercepts, slope, levels).

    ``intercepts`` maps each level of ``by`` to its line intercept (common slope,
    differing intercepts — the definition of a parallel-slopes model).

    Args:
        pdf: pandas frame holding the three columns; it is not modified.
        response: name of the outcome column.
        explanatory: name of the numeric predictor column.
        by: name of the grouping column; its values are compared as strings.

    Returns:
        ``(intercepts, slope, levels)`` where ``levels`` is the sorted list of
        string-cast values of ``by``, ``slope`` is the shared slope, and
        ``intercepts`` is a dict keyed by level.
    """
    import statsmodels.formula.api as smf

    # Fit on a copy whose columns carry fixed, formula-safe aliases. Splicing
    # the caller's column names into the formula breaks (or silently misparses)
    # for names that are not plain Python identifiers, e.g. "body mass" or "a-b".
    frame = pdf[[response, explanatory, by]].copy()
    frame.columns = ["y", "x", "g"]
    frame["g"] = frame["g"].astype(str)

    params = smf.ols("y ~ x + C(g)", data=frame).fit().params
    slope = float(params["x"])
    base = float(params["Intercept"])
    levels = sorted(frame["g"].unique())

    # A level without its own treatment term (the reference level) sits at the
    # base intercept, hence the 0.0 default.
    intercepts = {
        lvl: base + float(params.get(f"C(g)[T.{lvl}]", 0.0)) for lvl in levels
    }
    return intercepts, slope, levels


def gg_parallel_slopes(
    data, response: str, explanatory: str, by: str, *, engine: str = "plotly"
):
    """Scatterplot with a parallel-slopes regression model overlaid.

    Fits ``response ~ explanatory + C(by)`` (one common slope, a separate
    intercept per level of ``by``) and draws one fitted line per group over the data.

    Args:
        data: polars or pandas frame containing the three columns.
        response: name of the outcome column (y axis).
        explanatory: name of the numeric predictor column (x axis).
        by: name of the grouping column; values are cast to ``str``.
        engine: ``"plotly"`` (default) or ``"plotnine"``.

    Returns:
        A plotly figure for ``engine="plotly"``, otherwise a plotnine ``ggplot``.

    Raises:
        ValueError: if ``engine`` is not one of the two supported names.
    """
    engine = _resolve_engine(engine)
    pdf = _to_pandas(data, [response, explanatory, by])
    pdf[by] = pdf[by].astype(str)

    if engine == "plotly":
        import plotly.express as px

        intercepts, slope, levels = _parallel_slopes_fit(pdf, response, explanatory, by)
        xmin, xmax = float(pdf[explanatory].min()), float(pdf[explanatory].max())

        # Pin both the group order and the color sequence of the scatter so the
        # i-th sorted level gets _PALETTE[i]; the fitted lines below use the
        # same indexing, which keeps each line the color of its own points.
        fig = px.scatter(
            pdf,
            x=explanatory,
            y=response,
            color=by,
            category_orders={by: levels},
            color_discrete_sequence=_PALETTE,
        )
        for i, lvl in enumerate(levels):
            b = intercepts[lvl]
            fig.add_scatter(
                x=[xmin, xmax],
                y=[b + slope * xmin, b + slope * xmax],
                mode="lines",
                line={"color": _PALETTE[i % len(_PALETTE)]},
                name=f"{lvl} fit",
                showlegend=False,
            )
        fig.update_layout(template="plotly_white")
        return fig

    from plotnine import aes, geom_point, ggplot, labs, theme_light

    # geom_parallel_slopes fits the model itself, so no fit is needed here.
    return (
        ggplot(pdf, aes(x=explanatory, y=response, color=by))
        + geom_point()
        + geom_parallel_slopes(pdf, response, explanatory, by)
        + labs(x=explanatory, y=response, title="Parallel slopes model")
        + theme_light()
    )
def geom_parallel_slopes(
    data, response: str, explanatory: str, by: str, color: str | None = None
):
    """plotnine layer(s) drawing the parallel-slopes fitted lines.

    Add to a ``ggplot`` with ``+`` (plotnine-only; for a plotly version call
    :func:`gg_parallel_slopes` with ``engine="plotly"``).

    Each group's line is described by its two endpoints at the smallest and
    largest observed ``explanatory`` value. With ``color=None`` the lines are
    colored by ``by``; otherwise every line uses the given fixed ``color``.

    Returns:
        A one-element list holding the ``geom_line`` layer.
    """
    import pandas as pd
    from plotnine import aes, geom_line

    frame = _to_pandas(data, [response, explanatory, by])
    frame[by] = frame[by].astype(str)
    intercepts, slope, levels = _parallel_slopes_fit(frame, response, explanatory, by)
    x_lo = float(frame[explanatory].min())
    x_hi = float(frame[explanatory].max())

    # Two rows per level: the line's left and right endpoint.
    endpoints = [
        {explanatory: x, response: intercepts[lvl] + slope * x, by: lvl}
        for lvl in levels
        for x in (x_lo, x_hi)
    ]
    line_df = pd.DataFrame(endpoints)

    if color is not None:
        mapping = aes(x=explanatory, y=response, group=by)
        return [geom_line(mapping, data=line_df, color=color)]
    return [geom_line(aes(x=explanatory, y=response, color=by), data=line_df)]
def gg_categorical_model(
    data, response: str, explanatory: str, *, engine: str = "plotly"
):
    """Regression with one categorical predictor (~ ``geom_categorical_model``).

    Fits ``response ~ C(explanatory)``; each category's fitted value is its
    group mean, drawn as a horizontal marker over the (jittered) data points.

    Returns:
        A plotly figure for ``engine="plotly"``, otherwise a plotnine ``ggplot``.

    Raises:
        ValueError: if ``engine`` is not ``"plotly"`` or ``"plotnine"``.
    """
    engine = _resolve_engine(engine)
    frame = _to_pandas(data, [response, explanatory])
    frame[explanatory] = frame[explanatory].astype(str)
    group_means = frame.groupby(explanatory)[response].mean()
    levels = sorted(frame[explanatory].unique())
    # Fitted value per category, aligned with the sorted level order.
    fitted = [float(group_means[lvl]) for lvl in levels]

    if engine != "plotly":
        import pandas as pd
        from plotnine import aes, geom_point, ggplot, labs, theme_light

        mean_df = pd.DataFrame({explanatory: levels, response: fitted})
        return (
            ggplot(frame, aes(x=explanatory, y=response))
            + geom_point(position="jitter", alpha=0.5)
            + geom_point(data=mean_df, color=_FIT_COLOR, shape="_", size=8)
            + labs(x=explanatory, y=response, title="Categorical model")
            + theme_light()
        )

    import plotly.express as px

    fig = px.strip(
        frame, x=explanatory, y=response, category_orders={explanatory: levels}
    )
    mean_marker = {
        "symbol": "line-ew",
        "size": 28,
        "line": {"color": _FIT_COLOR, "width": 3},
    }
    fig.add_scatter(
        x=levels,
        y=fitted,
        mode="markers",
        marker=mean_marker,
        name="fitted mean",
        showlegend=False,
    )
    fig.update_layout(template="plotly_white")
    return fig
# R-parity alias: R's helper is named geom_categorical_model(); expose both names.
# The assignment must sit on its own line; on the comment line it never runs.
geom_categorical_model = gg_categorical_model