# Source code for moderndive.plots.models

"""Regression-model plotting helpers ported from the R ``moderndive`` package.

- :func:`gg_parallel_slopes`  ~ ``moderndive::gg_parallel_slopes`` (scatter + parallel lines)
- :func:`geom_parallel_slopes` ~ ``moderndive::geom_parallel_slopes`` (plotnine layer)
- :func:`gg_categorical_model` ~ ``moderndive::geom_categorical_model`` (categorical predictor)
- :func:`plot_3d_regression` ~ ``moderndive::plot_3d_regression`` (3D scatter + plane)

The non-plotly backend is plotnine. All functions accept polars or pandas frames.
``gg_*`` return a figure (plotly ``go.Figure`` or plotnine ``ggplot``);
``geom_parallel_slopes`` returns plotnine layers to add to a ``ggplot`` with ``+``.
"""

from __future__ import annotations

import re

import polars as pl

from .._messaging import helpful_error

# Qualitative palette (matches plotly's default) so plotnine/plotly look alike.
_PALETTE = [
    "#636EFA",
    "#EF553B",
    "#00CC96",
    "#AB63FA",
    "#FFA15A",
    "#19D3F3",
    "#FF6692",
    "#B6E880",
]
_FIT_COLOR = "#d62728"


def _to_pandas(data, columns):
    """Return ``columns`` of ``data`` as a pandas DataFrame, minus incomplete rows.

    ``data`` is used as-is when it is already a polars DataFrame; anything else
    is converted with ``pl.from_pandas``. Column selection and null removal are
    done in polars, and only the result is converted to pandas.

    NOTE(review): ``drop_nulls`` removes nulls only; float NaN values in a
    polars input are presumably kept — confirm whether callers rely on that.
    """
    frame = data if isinstance(data, pl.DataFrame) else pl.from_pandas(data)
    complete = frame.select(columns).drop_nulls()
    return complete.to_pandas()


def _resolve_engine(engine: str) -> str:
    if engine not in ("plotly", "plotnine"):
        raise ValueError(f"engine must be 'plotly' or 'plotnine', got {engine!r}")
    return engine


def _parallel_slopes_fit(pdf, response: str, explanatory: str, by: str):
    """Fit ``response ~ explanatory + C(by)``; return (intercepts, slope, levels).

    ``intercepts`` maps each level of ``by`` to its line intercept (common slope,
    differing intercepts — the definition of a parallel-slopes model).
    """
    import statsmodels.formula.api as smf

    pdf = pdf.copy()
    pdf[by] = pdf[by].astype(str)
    model = smf.ols(f"{response} ~ {explanatory} + C({by})", data=pdf).fit()
    params = model.params
    slope = float(params[explanatory])
    base = float(params["Intercept"])
    levels = sorted(pdf[by].unique())
    intercepts = {}
    for lvl in levels:
        key = f"C({by})[T.{lvl}]"
        intercepts[lvl] = base + (float(params[key]) if key in params.index else 0.0)
    return intercepts, slope, levels


def gg_parallel_slopes(
    data, response: str, explanatory: str, by: str, *, alpha: float = 1.0, engine: str = "plotly"
):
    """Scatterplot with a parallel-slopes regression model overlaid.

    Fits ``response ~ explanatory + C(by)`` (one common slope, a separate
    intercept per level of ``by``) and draws one fitted line per group over the
    data. ``alpha`` sets the point transparency (0–1), useful when points overlap.

    Parameters
    ----------
    data
        polars or pandas DataFrame; rows with a null in any of the three
        columns are dropped.
    response, explanatory
        Names of the outcome and predictor columns.
    by
        Grouping column; its values are converted to strings.
    alpha
        Opacity of the scatter points.
    engine
        ``"plotly"`` (default) or ``"plotnine"``.

    Returns
    -------
    A plotly ``go.Figure`` for ``engine="plotly"``, otherwise a plotnine ``ggplot``.

    Raises
    ------
    ValueError
        If ``engine`` is not ``"plotly"`` or ``"plotnine"``.
    """
    engine = _resolve_engine(engine)
    pdf = _to_pandas(data, [response, explanatory, by])
    pdf[by] = pdf[by].astype(str)

    if engine == "plotly":
        import plotly.express as px

        intercepts, slope, levels = _parallel_slopes_fit(pdf, response, explanatory, by)
        xmin, xmax = float(pdf[explanatory].min()), float(pdf[explanatory].max())
        # plotly express colours groups in order of first appearance in the
        # data, while the fitted lines below are coloured by *sorted* level.
        # Pin both the level order and the palette so each group's points and
        # fitted line share a colour.
        fig = px.scatter(
            pdf,
            x=explanatory,
            y=response,
            color=by,
            opacity=alpha,
            category_orders={by: levels},
            color_discrete_sequence=_PALETTE,
        )
        for i, lvl in enumerate(levels):
            b = intercepts[lvl]
            fig.add_scatter(
                x=[xmin, xmax],
                y=[b + slope * xmin, b + slope * xmax],
                mode="lines",
                line={"color": _PALETTE[i % len(_PALETTE)]},
                name=f"{lvl} fit",
                # Share the level's legend group so toggling the group's legend
                # entry also hides its fitted line.
                legendgroup=lvl,
                showlegend=False,
            )
        fig.update_layout(template="plotly_white")
        return fig

    from plotnine import aes, geom_point, ggplot, labs, theme_light

    # geom_parallel_slopes() fits the model itself, so no fit is needed here.
    return (
        ggplot(pdf, aes(x=explanatory, y=response, color=by))
        + geom_point(alpha=alpha)
        + geom_parallel_slopes(pdf, response, explanatory, by)
        + labs(x=explanatory, y=response, title="Parallel slopes model")
        + theme_light()
    )


def geom_parallel_slopes(data, response: str, explanatory: str, by: str, color: str | None = None):
    """plotnine layer(s) drawing the parallel-slopes fitted lines.

    Add to a ``ggplot`` with ``+`` (plotnine-only; for a plotly version call
    :func:`gg_parallel_slopes` with ``engine="plotly"``).

    Each level of ``by`` gets one straight line spanning the overall
    ``explanatory`` range of the data. With ``color=None`` the lines are
    coloured by ``by``; otherwise every line uses the given fixed colour.

    Returns
    -------
    A one-element list holding a plotnine ``geom_line`` layer.
    """
    import pandas as pd
    from plotnine import aes, geom_line

    frame = _to_pandas(data, [response, explanatory, by])
    frame[by] = frame[by].astype(str)
    intercepts, slope, levels = _parallel_slopes_fit(frame, response, explanatory, by)
    x_lo = float(frame[explanatory].min())
    x_hi = float(frame[explanatory].max())

    # A straight line needs only its two endpoints: (x_lo, fit) and (x_hi, fit).
    line_df = pd.DataFrame(
        [
            {explanatory: x, response: intercepts[lvl] + slope * x, by: lvl}
            for lvl in levels
            for x in (x_lo, x_hi)
        ]
    )

    if color is not None:
        # Fixed colour: map `by` to `group` so there is still one line per level.
        layer = geom_line(aes(x=explanatory, y=response, group=by), data=line_df, color=color)
    else:
        layer = geom_line(aes(x=explanatory, y=response, color=by), data=line_df)
    return [layer]


def gg_categorical_model(data, response: str, explanatory: str, *, engine: str = "plotly"):
    """Regression with one categorical predictor (~ ``geom_categorical_model``).

    Fits ``response ~ C(explanatory)``; each category's fitted value is its
    group mean, drawn as a horizontal marker over the data points. The points
    are jittered horizontally only, so observed response values are not moved.

    Parameters
    ----------
    data
        polars or pandas DataFrame; rows with a null in either column are dropped.
    response
        Numeric outcome column.
    explanatory
        Categorical predictor column; its values are converted to strings.
    engine
        ``"plotly"`` (default) or ``"plotnine"``.

    Returns
    -------
    A plotly ``go.Figure`` for ``engine="plotly"``, otherwise a plotnine ``ggplot``.

    Raises
    ------
    ValueError
        If ``engine`` is not ``"plotly"`` or ``"plotnine"``.
    """
    engine = _resolve_engine(engine)
    pdf = _to_pandas(data, [response, explanatory])
    pdf[explanatory] = pdf[explanatory].astype(str)
    levels = sorted(pdf[explanatory].unique())
    # With a single categorical predictor the least-squares fit is the group mean.
    means = pdf.groupby(explanatory)[response].mean()
    fitted = [float(means[lvl]) for lvl in levels]

    if engine == "plotly":
        import plotly.express as px

        fig = px.strip(pdf, x=explanatory, y=response, category_orders={explanatory: levels})
        fig.add_scatter(
            x=levels,
            y=fitted,
            mode="markers",
            marker={"symbol": "line-ew", "size": 28, "line": {"color": _FIT_COLOR, "width": 3}},
            name="fitted mean",
            showlegend=False,
        )
        fig.update_layout(template="plotly_white")
        return fig

    import pandas as pd
    from plotnine import aes, geom_point, ggplot, labs, position_jitter, theme_light

    mean_df = pd.DataFrame({explanatory: levels, response: fitted})
    return (
        ggplot(pdf, aes(x=explanatory, y=response))
        # Jitter sideways only. A bare "jitter" position also shifts y by up to
        # 40% of its resolution, which would distort the observed responses.
        + geom_point(position=position_jitter(height=0), alpha=0.5)
        + geom_point(data=mean_df, color=_FIT_COLOR, shape="_", size=8)
        + labs(x=explanatory, y=response, title="Categorical model")
        + theme_light()
    )


# A plain column name: a letter or underscore followed by word characters.
# plot_3d_regression() matches formula terms against it to reject in-formula
# transformations such as ``log(z)``.
_IDENTIFIER = re.compile(r"^[A-Za-z_]\w*$")

# R-parity alias: R's helper is named geom_categorical_model(); expose both names.
geom_categorical_model = gg_categorical_model


def _parse_3d_formula(formula) -> tuple[str, list[str]]:
    """Split a ``'z ~ x + y'`` formula into ``(outcome, [x, y])``.

    Only plain column names (matching ``_IDENTIFIER``) are accepted.

    Raises
    ------
    ValueError
        If ``formula`` is not a string with exactly one ``~``, contains an
        in-formula transformation, does not name exactly two predictors, or
        uses the same column more than once.
    """
    if not isinstance(formula, str) or formula.count("~") != 1:
        raise ValueError(
            helpful_error(
                f"formula must look like 'z ~ x + y', got {formula!r}.",
                "Put the outcome on the left of ~ and exactly two predictors on the right.",
            )
        )
    lhs, rhs = (part.strip() for part in formula.split("~", 1))
    predictors = [v.strip() for v in rhs.split("+") if v.strip()]
    if not _IDENTIFIER.match(lhs) or not all(_IDENTIFIER.match(p) for p in predictors):
        raise ValueError(
            helpful_error(
                "plot_3d_regression() does not support in-formula transformations "
                "(e.g. 'log(z) ~ x + y').",
                "Transform the columns of `data` first, then pass plain names like 'z ~ x + y'.",
            )
        )
    if len(predictors) != 2:
        raise ValueError(
            helpful_error(
                f"The right-hand side must name exactly two predictors (got {len(predictors)}).",
                "Use a formula like 'z ~ x + y'.",
            )
        )
    # A repeated name would otherwise surface later as an opaque duplicate-column
    # error from polars, or as a degenerate fit.
    if len({lhs, *predictors}) != 3:
        raise ValueError(
            helpful_error(
                f"The outcome and the two predictors must be three different columns, "
                f"got {formula!r}.",
                "Use a formula like 'z ~ x + y'.",
            )
        )
    return lhs, predictors


def plot_3d_regression(data, formula: str, n: int = 25):
    """Interactive 3D scatterplot with a fitted regression plane.

    Mirrors ``moderndive::plot_3d_regression``. Pass a formula ``z ~ x + y`` — one
    numeric outcome and exactly two numeric predictors — and get a plotly
    ``go.Figure`` with the data points and the fitted ``lm`` plane.

    In-formula transformations (e.g. ``log(z) ~ x + y``) are **not** supported,
    since the plane and the raw points would be on different scales; transform
    the columns of ``data`` first and pass plain names. ``n`` sets the plane's
    grid resolution per axis.

    Parameters
    ----------
    data
        polars or pandas DataFrame; rows with a null in any formula column are dropped.
    formula
        ``'z ~ x + y'`` naming three distinct numeric columns.
    n
        Grid points per axis for the plane; an integer ≥ 2 (NumPy integer
        scalars are accepted).

    Returns
    -------
    plotly ``go.Figure`` with a ``Scatter3d`` trace of the data and a surface
    trace of the fitted plane.

    Raises
    ------
    ValueError
        For an invalid ``n``, a malformed formula, missing or non-numeric
        columns, or fewer than three complete rows.
    """
    import numbers

    import numpy as np
    import pandas as pd
    import plotly.graph_objects as go
    import statsmodels.formula.api as smf

    df = data if isinstance(data, pl.DataFrame) else pl.from_pandas(data)
    # numbers.Integral also admits NumPy integer scalars such as np.int64(25).
    if not isinstance(n, numbers.Integral) or n < 2:
        raise ValueError(
            helpful_error(f"`n` must be an integer ≥ 2, got {n!r}.", "Try n=25 (the default).")
        )
    n = int(n)

    lhs, (x_var, y_var) = _parse_3d_formula(formula)
    variables = [lhs, x_var, y_var]

    missing = [c for c in variables if c not in df.columns]
    if missing:
        raise ValueError(
            helpful_error(
                f"Column(s) not found in the data: {', '.join(missing)}.",
                f"Available columns: {', '.join(df.columns)}.",
            )
        )
    for var in variables:
        if not df.schema[var].is_numeric():
            raise ValueError(
                helpful_error(
                    f"`{var}` must be numeric — a regression plane needs three continuous "
                    "variables.",
                    "For a categorical predictor, see gg_categorical_model().",
                )
            )

    pdf = df.select(lhs, x_var, y_var).drop_nulls().to_pandas()
    # The plane has three parameters (intercept + two slopes), so it needs at
    # least three complete observations.
    if len(pdf) < 3:
        raise ValueError(
            helpful_error(
                f"Need at least 3 rows with no missing values to fit a plane (got {len(pdf)}).",
                f"Check for missing values in {', '.join(variables)}.",
            )
        )
    model = smf.ols(f"{lhs} ~ {x_var} + {y_var}", data=pdf).fit()

    # Evaluate the fitted plane on an n x n grid spanning the observed ranges.
    # meshgrid(x, y) gives arrays of shape (len(y), len(x)), which matches the
    # z[row=y][col=x] layout of the surface trace below.
    x_seq = np.linspace(pdf[x_var].min(), pdf[x_var].max(), n)
    y_seq = np.linspace(pdf[y_var].min(), pdf[y_var].max(), n)
    grid_x, grid_y = np.meshgrid(x_seq, y_seq)
    grid = pd.DataFrame({x_var: grid_x.ravel(), y_var: grid_y.ravel()})
    z_pred = np.asarray(model.predict(grid)).reshape(grid_x.shape)

    fig = go.Figure(
        go.Scatter3d(
            x=pdf[x_var],
            y=pdf[y_var],
            z=pdf[lhs],
            mode="markers",
            marker={"size": 4},
            name="data",
        )
    )
    fig.add_surface(
        x=x_seq, y=y_seq, z=z_pred, opacity=0.6, showscale=False, name="regression plane"
    )
    fig.update_layout(
        template="plotly_white",
        scene={"xaxis_title": x_var, "yaxis_title": y_var, "zaxis_title": lhs},
    )
    return fig