"""Dual-engine visualization for the infer grammar (plotly default, plotnine optional).

Every plotting function takes ``engine="plotly"|"plotnine"`` and defaults to
``"plotly"`` (the book's default engine); pass ``engine="plotnine"`` to get the
grammar-of-graphics objects instead.

``visualize()`` returns an :class:`InferPlot` wrapper that composes with shading
via ``+`` for *both* engines::

    visualize(null_dist, bins=25) + shade_p_value(obs_diff, direction="right")
    visualize(boot_means) + shade_confidence_interval(endpoints=percentile_ci)

``shade_p_value()`` / ``shade_confidence_interval()`` return an engine-neutral
:class:`ShadeSpec`; ``InferPlot.__add__`` turns it into plotnine layers or applies
plotly shapes as appropriate. A keyword form is also supported and is the most
ergonomic path for plotly::

    visualize(null_dist, shade_pvalue={"obs_stat": obs, "direction": "right"})
    visualize(boot_means, shade_ci=percentile_ci)
"""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from . import _common as C
# Public API of this module: the names exported by a star-import.
__all__ = [
    "InferPlot",
    "ShadeSpec",
    "visualize",
    "visualize_fit",
    "visualize_theoretical",
    "shade_p_value",
    "shade_confidence_interval",
]
@dataclass(frozen=True)
class ShadeSpec:
    """Immutable, engine-agnostic record of one shading request.

    A spec describes either a p-value tail or a confidence interval and carries
    no plotting-library objects; :class:`InferPlot` decides how to draw it.

    Two shapes are supported:

    * single-panel plots read the scalar fields (``obs_stat``/``direction`` for
      a p-value, ``lower``/``upper`` for a confidence interval);
    * faceted regression-fit plots (:func:`visualize_fit`) read ``per_term``, a
      tuple of ``(term, payload)`` pairs in which ``payload`` is the observed
      statistic (p-value) or a ``(lower, upper)`` pair (confidence interval),
      so that every facet is shaded from its own value.
    """

    # Discriminator: "p_value" | "confidence_interval".
    kind: str
    # Observed statistic for single-panel p-value shading.
    obs_stat: float | None = None
    # Tail direction for p-value shading.
    direction: str | None = None
    # Interval endpoints for single-panel confidence-interval shading.
    lower: float | None = None
    upper: float | None = None
    # Optional colour override passed through to the engine.
    color: str | None = None
    # Per-facet payloads as (term, payload) pairs; None for single-panel specs.
    per_term: tuple | None = None
class InferPlot:
    """A plot plus its engine, composable with :class:`ShadeSpec` (and, for the
    plotnine engine, any plotnine layer) via ``+``.

    Access the wrapped figure with :attr:`figure`; for the plotnine engine, the
    raw ``ggplot`` is also available via :attr:`gg`.

    Parameters
    ----------
    figure
        The engine-specific figure object being wrapped.
    engine
        ``"plotnine"`` selects the plotnine code paths; any other value is
        handled by the plotly code paths.
    terms
        Facet terms of a regression-fit plot. ``None`` for single-panel plots;
        per-term shading is only possible when this is set.
    """

    def __init__(self, figure, engine: str, terms: list[str] | None = None):
        self.figure = figure
        self.engine = engine
        # Facet terms for a regression-fit plot (enables per-facet shading).
        self.terms = terms

    def __add__(self, other):
        """Return a new :class:`InferPlot` with ``other`` layered on top.

        A :class:`ShadeSpec` carrying ``per_term`` data is routed to per-facet
        shading on either engine. Otherwise a :class:`ShadeSpec` becomes
        plotnine layers or is applied through the plotly helper, and any other
        object is added to the wrapped figure for the plotnine engine only.

        Raises
        ------
        TypeError
            If ``other`` is not a :class:`ShadeSpec` and the engine is plotly,
            or if per-term shading is requested on a plot without facet terms.
        """
        is_spec = isinstance(other, ShadeSpec)
        if is_spec and other.per_term is not None:
            return self._add_per_facet(other)

        if self.engine == "plotnine":
            # Engine modules are imported lazily so that only the engine in
            # use has to be importable.
            from . import _plotnine as P

            if is_spec:
                layers = (
                    P.shade_pvalue_layers(other)
                    if other.kind == "p_value"
                    else P.shade_ci_layers(other)
                )
                return InferPlot(self.figure + layers, "plotnine", self.terms)
            # Any other plotnine object/layer/list.
            return InferPlot(self.figure + other, "plotnine", self.terms)

        from . import _plotly as PX

        if is_spec:
            return InferPlot(PX.apply_shade_px(self.figure, other), "plotly", self.terms)
        raise TypeError(
            "Only a ShadeSpec (from shade_p_value/shade_confidence_interval) can be "
            "added to a plotly InferPlot."
        )

    def _add_per_facet(self, spec: ShadeSpec):
        """Apply per-term shading to a faceted regression-fit plot (both engines).

        Raises
        ------
        TypeError
            If this plot has no facet terms (``self.terms is None``).
        """
        if self.terms is None:
            raise TypeError(
                "Per-term shading requires a faceted fit plot from visualize_fit(); "
                "pass a per-term obs_stat/endpoints (a FitResult or a term-keyed table)."
            )
        if self.engine == "plotnine":
            from . import _plotnine as P

            return InferPlot(
                P.apply_fit_shade_gg(self.figure, spec, self.terms), "plotnine", self.terms
            )
        from . import _plotly as PX

        return InferPlot(PX.apply_fit_shade_px(self.figure, spec, self.terms), "plotly", self.terms)

    @property
    def gg(self):
        """The underlying plotnine ``ggplot`` (only when ``engine='plotnine'``).

        Raises
        ------
        AttributeError
            If the engine is not ``"plotnine"``.
        """
        if self.engine != "plotnine":
            raise AttributeError("`.gg` is only available when engine='plotnine'.")
        return self.figure

    def show(self, **kwargs):
        """Display the figure (delegates to the underlying figure's ``show``)."""
        return self.figure.show(**kwargs)

    def save(self, path, **kwargs):
        """Save the figure, forwarding ``**kwargs`` to the engine's writer.

        plotnine forwards to ``ggplot.save``; plotly writes an HTML file for
        ``.html``/``.htm`` paths (matched case-insensitively) and a static image
        otherwise (the latter needs the optional ``kaleido`` dependency —
        install ``moderndive[image]``).

        Parameters
        ----------
        path
            Destination path. For the plotly engine it is converted with
            ``str()`` before being handed to the writer.
        **kwargs
            Passed through unchanged to ``ggplot.save``, ``Figure.write_html``
            or ``Figure.write_image``.
        """
        if self.engine == "plotnine":
            return self.figure.save(path, **kwargs)
        path_str = str(path)
        # Compare the suffix case-insensitively so "REPORT.HTML" is not sent to
        # the static-image writer.
        if path_str.lower().endswith((".html", ".htm")):
            return self.figure.write_html(path_str, **kwargs)
        return self.figure.write_image(path_str, **kwargs)

    def _repr_mimebundle_(self, include=None, exclude=None):
        """Rich display for Jupyter/Quarto across both engines.

        Both plotnine ``ggplot`` and plotly ``Figure`` expose
        ``_repr_mimebundle_``, so we delegate to it and the wrapped figure renders
        exactly as a bare figure would. We fall back to ``text/html`` and finally
        to no output. This is required because ``_repr_html_`` alone returns
        ``None`` for a ``ggplot`` — without this method, plotnine-engine
        ``InferPlot``s render blank in notebooks and Quarto.
        """
        fig = self.figure
        if hasattr(fig, "_repr_mimebundle_"):
            return fig._repr_mimebundle_(include=include, exclude=exclude)
        if hasattr(fig, "_repr_html_"):
            html = fig._repr_html_()
            # A figure may implement _repr_html_ yet return None; emit no bundle
            # rather than a bundle whose text/html payload is None.
            if html is not None:
                return {"text/html": html}
        return None

    def _repr_html_(self):
        """Return the wrapped figure's HTML repr, or ``None`` if it has none."""
        fig = self.figure
        return fig._repr_html_() if hasattr(fig, "_repr_html_") else None

    def __repr__(self) -> str:
        """Return the ``repr`` of the wrapped figure."""
        return repr(self.figure)
def _coerce_pvalue_spec(value) -> ShadeSpec:
    """Normalise the ``shade_pvalue=`` keyword into a :class:`ShadeSpec`.

    A ready-made :class:`ShadeSpec` is returned unchanged; a ``dict`` is
    expanded as keyword arguments to :func:`shade_p_value`.

    Raises
    ------
    TypeError
        If ``value`` is neither a :class:`ShadeSpec` nor a ``dict``.
    """
    if not isinstance(value, (ShadeSpec, dict)):
        raise TypeError("shade_pvalue must be a ShadeSpec or a dict of shade_p_value() kwargs.")
    # The ShadeSpec check comes first, mirroring the pass-through-then-build order.
    return value if isinstance(value, ShadeSpec) else shade_p_value(**value)
def _coerce_ci_spec(value) -> ShadeSpec:
    """Normalise the ``shade_ci=`` keyword into a :class:`ShadeSpec`.

    A ready-made :class:`ShadeSpec` is returned unchanged; anything else is
    handed to :func:`shade_confidence_interval` as its ``endpoints``.
    """
    already_spec = isinstance(value, ShadeSpec)
    return value if already_spec else shade_confidence_interval(value)
def visualize(
    distribution,
    bins: int = 20,
    *,
    engine: str = "plotly",
    method: str = "simulation",
    shade_pvalue=None,
    shade_ci=None,
    **kwargs,
) -> InferPlot:
    """Plot a distribution of simulated statistics as an :class:`InferPlot`.

    ``method`` is ``"simulation"`` (histogram, default), ``"theoretical"`` (a
    normal-approximation density curve), or ``"both"`` (histogram in density units
    overlaid with the normal curve), mirroring R ``infer``'s ``visualize(method=)``.

    Shading can be requested in the same call through ``shade_pvalue=`` (a
    :class:`ShadeSpec` or a dict of :func:`shade_p_value` keyword arguments) and
    ``shade_ci=`` (a :class:`ShadeSpec` or CI endpoints), or added afterwards
    with ``+``. When both are given, the p-value shading is applied first.

    ``engine`` and ``method`` are normalised by ``C.resolve_engine`` and
    ``C.resolve_method``. Extra ``**kwargs`` are accepted but not used by this
    function.
    """
    backend = C.resolve_engine(engine)
    mode = C.resolve_method(method)

    # Import only the engine that is actually requested.
    if backend != "plotnine":
        from . import _plotly as PX

        figure = PX.visualize_px(distribution, bins, mode)
    else:
        from . import _plotnine as P

        figure = P.visualize_gg(distribution, bins, mode)

    result = InferPlot(figure, backend)
    # Each requested shading is coerced and layered in turn: p-value, then CI.
    for raw, coerce in ((shade_pvalue, _coerce_pvalue_spec), (shade_ci, _coerce_ci_spec)):
        if raw is not None:
            result = result + coerce(raw)
    return result
def visualize_fit(
    fit, bins: int = 20, *, engine: str = "plotly", shade_pvalue=None, shade_ci=None
) -> InferPlot:
    """Plot a regression fit distribution as faceted histograms, one panel per term.

    The facet terms are the distinct values of ``fit.data["term"]`` (taken with
    ``maintain_order=True``) and are stored on the returned :class:`InferPlot`
    so that per-term shading can be applied.

    Pass ``shade_pvalue=``/``shade_ci=`` to shade each facet from per-term values
    (a ``FitResult`` of observed estimates, or a term-keyed CI/p-value table), or
    compose the same per-term :class:`ShadeSpec` with ``+``. When both are given,
    the p-value shading is applied first.
    """
    backend = C.resolve_engine(engine)
    facet_terms = fit.data["term"].unique(maintain_order=True).to_list()

    # Import only the engine that is actually requested.
    if backend != "plotnine":
        from . import _plotly as PX

        figure = PX.visualize_fit_px(fit, bins)
    else:
        from . import _plotnine as P

        figure = P.visualize_fit_gg(fit, bins)

    result = InferPlot(figure, backend, facet_terms)
    # Each requested shading is coerced and layered in turn: p-value, then CI.
    for raw, coerce in ((shade_pvalue, _coerce_pvalue_spec), (shade_ci, _coerce_ci_spec)):
        if raw is not None:
            result = result + coerce(raw)
    return result
def visualize_theoretical(theoretical, bins: int = 100, *, engine: str = "plotly") -> InferPlot:
    """Density curve for a :class:`~moderndive.infer.TheoreticalDistribution`.

    The curve is evaluated at 400 evenly spaced points between the 0.001 and
    0.999 quantiles of the object returned by ``theoretical._dist()``. The
    ``bins`` parameter is accepted but not used by this function.
    """
    backend = C.resolve_engine(engine)
    law = theoretical._dist()
    # Span the central 99.8% of the distribution so extreme tails do not
    # stretch the x-axis.
    grid = np.linspace(law.ppf(0.001), law.ppf(0.999), 400)
    curve = law.pdf(grid)
    heading = f"Theoretical {theoretical.distribution} distribution"

    if backend != "plotnine":
        from . import _plotly as PX

        return InferPlot(PX.density_curve_px(grid, curve, heading), backend)

    from . import _plotnine as P

    return InferPlot(P.density_curve_gg(grid, curve, heading), backend)
def _per_term_obs(obs_stat) -> dict | None:
    """Build a ``{term: observed}`` mapping for per-facet p-value shading.

    Recognised inputs, checked in this order:

    * a polars frame with a ``term`` column — values come from ``estimate`` when
      that column exists, otherwise from ``stat``;
    * an object whose ``data`` attribute is a polars frame with ``term`` and
      ``estimate`` columns (e.g. an observed ``FitResult``);
    * a plain ``dict``, whose values are converted to ``float``.

    Anything else (such as a scalar) yields ``None``.
    """
    import polars as pl

    if isinstance(obs_stat, pl.DataFrame) and "term" in obs_stat.columns:
        value_key = "stat" if "estimate" not in obs_stat.columns else "estimate"
        by_term = {}
        for row in obs_stat.iter_rows(named=True):
            by_term[row["term"]] = float(row[value_key])
        return by_term

    wrapped = getattr(obs_stat, "data", None)
    if isinstance(wrapped, pl.DataFrame) and set(wrapped.columns) >= {"term", "estimate"}:
        by_term = {}
        for row in wrapped.iter_rows(named=True):
            by_term[row["term"]] = float(row["estimate"])
        return by_term

    if isinstance(obs_stat, dict):
        return {term: float(value) for term, value in obs_stat.items()}

    return None
def _per_term_ci(endpoints) -> dict | None:
    """Build a ``{term: (lower, upper)}`` mapping for per-facet CI shading.

    Only a polars frame with a ``term`` column is recognised; its ``lower_ci``
    and ``upper_ci`` columns supply the bounds as floats. Any other input
    yields ``None``.
    """
    import polars as pl

    is_term_table = isinstance(endpoints, pl.DataFrame) and "term" in endpoints.columns
    if not is_term_table:
        return None

    intervals = {}
    for row in endpoints.iter_rows(named=True):
        intervals[row["term"]] = (float(row["lower_ci"]), float(row["upper_ci"]))
    return intervals
def shade_p_value(obs_stat, direction: str, *, color: str | None = None) -> ShadeSpec:
    """Build a p-value shading spec; add it to a ``visualize()`` plot with ``+``.

    ``direction`` ∈ {right/greater, left/less, two-sided}; it is stored on the
    spec as given, without validation in this function.

    For a faceted :func:`visualize_fit` plot, pass a per-term ``obs_stat`` — an
    observed ``FitResult``, a ``term``-keyed frame, or a dict — to shade each
    facet. In that case the spec's ``per_term`` holds the ``(term, value)``
    pairs in sorted order; otherwise ``obs_stat`` is converted to ``float``.
    """
    by_term = _per_term_obs(obs_stat)
    if by_term is None:
        # Scalar observed statistic: single-panel spec.
        return ShadeSpec(kind="p_value", obs_stat=float(obs_stat), direction=direction, color=color)

    pairs = tuple(sorted(by_term.items()))
    return ShadeSpec(kind="p_value", direction=direction, color=color, per_term=pairs)
def shade_confidence_interval(endpoints, color: str | None = None) -> ShadeSpec:
    """Build a confidence-interval shading spec; add it to a ``visualize()`` plot with ``+``.

    ``endpoints`` is a CI DataFrame (``lower_ci``/``upper_ci``) or a ``(lower, upper)``
    tuple; the bounds are extracted by ``C.ci_endpoints``.

    For a faceted :func:`visualize_fit` plot, pass a per-term CI table (with a
    ``term`` column) to shade each facet from its own interval. In that case the
    spec's ``per_term`` holds the ``(term, (lower, upper))`` pairs in sorted order.
    """
    by_term = _per_term_ci(endpoints)
    if by_term is None:
        # Not a term-keyed table: single-panel spec.
        lower, upper = C.ci_endpoints(endpoints)
        return ShadeSpec(kind="confidence_interval", lower=lower, upper=upper, color=color)

    pairs = tuple(sorted(by_term.items()))
    return ShadeSpec(kind="confidence_interval", color=color, per_term=pairs)