"""Utilities for working with common plots."""
from collections.abc import Callable, Mapping
from functools import reduce, wraps
from html import escape
from pathlib import Path
from textwrap import dedent, indent
from typing import Any, Concatenate, Literal, ParamSpec, TextIO, TypeAlias, cast
from uuid import uuid4

import pandas as pd
import plotly.graph_objects as go
from typing_extensions import deprecated
ImageFormat: TypeAlias = Literal["svg", "pdf", "png", "jpg", "webp"]
[docs]
@deprecated("Use `plotly_to_html` instead.")
def export_html(
    fig: go.Figure,
    out: Path | str | TextIO,
    width: int = 800,
    height: int = 450,
    scale: float = 3,
    name: str = "plot",
    download_format: ImageFormat = "svg",
):
    """Save plotly figure to interactive html.

    Deprecated: use `plotly_to_html` instead.

    Args:
        fig: Plotly figure to save. It is copied, not modified.
        out:
            Path to save to (written as UTF-8) or file-like object to write
            to. A file-like object is written to but left open; the caller
            owns it.
        width: Unused; accepted only for backward compatibility.
        height: Unused; accepted only for backward compatibility.
        scale: Scale factor for the figure resolution when downloading.
        name:
            Name of the figure. It is set as the figure's title, which
            `plotly_to_html` then uses as the download filename.
        download_format: Format to use when downloading the figure.
    """
    # Copy before setting the title so the caller's figure is not mutated
    # (update_layout modifies the figure in place).
    titled_fig = go.Figure(fig).update_layout(title=name)
    # Render before touching the output, so a rendering error doesn't leave a
    # truncated/empty file behind.
    page = plotly_to_html(
        fig=titled_fig,
        download_scale=scale,
        download_format=download_format,
    )
    if isinstance(out, Path | str):
        with open(out, "w", encoding="utf-8") as f:
            f.write(page)
    else:
        # Don't close a stream we didn't open.
        out.write(page)
# Responsiveness modes: None or one of the named strategies.
ResponsiveMode: TypeAlias = Literal[None, "keep-aspect", "stretch-width"]
# Integer (or None) per responsive mode. NOTE(review): not referenced in this
# module; the values look like default widths in pixels. TODO confirm.
default_responsiveness: dict[ResponsiveMode, int | None] = {
    None: 400,
    "keep-aspect": 720,
    "stretch-width": 1080,
}
# Client-side script shipped next to this module. `plotly_to_html` embeds it
# when responsive=True. Decode explicitly as UTF-8 so the result doesn't
# depend on the platform's locale encoding.
plotly_responsive_js: str = (
    Path(__file__).parent / "responsive_plotly.js"
).read_text(encoding="utf-8")
def _ind(text: str, count: int) -> str:
"""Indent all lines in text except for first by count spaces."""
return indent(text, " " * count).lstrip(" " * count)
[docs]
def plotly_to_html(
    fig: go.Figure,
    write_to: Path | str | TextIO | None = None,
    responsive: bool = True,
    min_width: int = 500,
    max_width: int = 1500,
    max_height: int = 600,
    download_format: ImageFormat = "svg",
    download_scale: float = 3,
    full_html: bool = True,
    plotly_js_url: (
        str | None
    ) = "https://cdn.jsdelivr.net/npm/plotly.js@2/dist/plotly.min.js",
) -> str:
    """Convert plotly figure to interactive (and responsive) html.

    Args:
        fig: Plotly figure to save.
        write_to:
            Path to save to (written as UTF-8) or file-like object to write
            to. A file-like object is written to but not closed.
        responsive: Whether to make the figure responsive.
        min_width: Minimum width of the figure in pixels.
        max_width: Maximum width of the figure in pixels.
        max_height: Maximum height of the figure in pixels.
        download_format: Format to use when downloading the figure.
        download_scale: Scale factor for the figure resolution when downloading.
        full_html: Whether to wrap the figure in a full html document.
        plotly_js_url:
            URL to load plotly.js library from.
            Defaults to the latest 2.x release on jsdelivr.
            Set to None to leave out the plotly.js script tag.

    Returns:
        HTML string.
    """
    layout = cast(go.Layout, fig.layout).to_plotly_json()
    # Name the figure after its title. The name is used as the download
    # filename and the document <title>. Fall back to "plot" if the figure
    # has no title.
    name = cast(dict, layout.get("title") or {}).get("text") or "plot"
    fig_id = str(uuid4())[:10]
    script_id = str(uuid4())[:10]
    # Size used when the layout doesn't pin one.
    width = layout.get("width") or 720
    height = layout.get("height") or 480

    # Render the bare figure; plotly.js itself is loaded by a separate tag.
    fig_html: str = fig.to_html(
        include_plotlyjs=False,
        full_html=False,
        div_id=fig_id,
        config={
            "toImageButtonOptions": {
                "format": download_format,
                "filename": name,
                "height": height,
                "width": width,
                "scale": download_scale,
            },
        },
    )

    plotly_js_script = (
        f'<script src="{escape(plotly_js_url)}"></script>'
        if plotly_js_url is not None
        else ""
    )

    if responsive:
        # Settings for the responsive script, passed as attributes on its own
        # <script> tag. Values are HTML-escaped so quotes or ampersands can't
        # break the markup.
        script_attrs = {
            "plotly-js-url": plotly_js_url,
            "fig-id": fig_id,
            "fig-width": width,
            "fig-height": height,
            "min-width": min_width,
            "max-width": max_width,
            "max-height": max_height,
        }
        script_attr_str = " ".join(
            f'{attr}="{escape(str(value))}"' for attr, value in script_attrs.items()
        )
        # _ind keeps embedded multi-line text aligned with the template so
        # that dedent() can strip the common indentation.
        res_html = dedent(
            f"""
            <div class="plotly-responsive-container" style="width: 100%;">
                {_ind(plotly_js_script, 16)}
                {_ind(fig_html, 16)}
                <script
                    id="{script_id}"
                    {_ind(script_attr_str, 20)}
                >
                    var resplotScriptID = "{script_id}";
                    {_ind(plotly_responsive_js, 20)}
                </script>
            </div>
            """
        )
    else:
        # The figure's inline script needs plotly.js, so load it first here too.
        res_html = "\n".join(part for part in (plotly_js_script, fig_html) if part)

    if full_html:
        html = dedent(
            f"""
            <!doctype html>
            <html>
            <head>
                <meta charset="utf-8">
                <title>{escape(str(name))}</title>
            </head>
            <body>
                {_ind(res_html, 16)}
            </body>
            </html>
            """
        )
    else:
        html = res_html

    if isinstance(write_to, Path | str):
        with open(write_to, "w", encoding="utf-8") as f:
            f.write(html)
    elif write_to is not None:
        # Leave the caller's stream open; they own it.
        write_to.write(html)
    return html
# Parameters a plotting function takes after its leading DataFrame argument.
P = ParamSpec("P")

# Signature of a plotting function: a DataFrame first, then any further
# arguments described by ``P``, returning a plotly figure.
PlottingFunction: TypeAlias = Callable[Concatenate[pd.DataFrame, P], go.Figure]
def _merge_data(figs: Mapping[str, go.Figure | go.Frame]) -> list:
    """Concatenate the traces of all figures/frames into one list.

    Traces keep their per-figure order, and figures are taken in mapping
    order. A figure/frame whose ``data`` is falsy contributes nothing.
    """
    merged: list = []
    for source in figs.values():
        merged.extend(source.data or [])
    return merged
def _merge_layout(figs: Mapping[str, go.Figure | go.Frame]) -> dict[str, Any]:
    """Shallow-merge the layouts of several figures or frames into one dict.

    Layouts are combined in mapping order. For any top-level key present in
    several layouts, the last one wins outright: nested settings such as axes
    are replaced, not merged. An empty mapping yields ``{}``.
    """
    layouts = [
        cast(dict, fig.layout.to_plotly_json())  # type: ignore
        for fig in figs.values()
    ]
    # The {} initializer keeps reduce() from raising TypeError when the
    # mapping is empty.
    return reduce(lambda merged, nxt: {**merged, **nxt}, layouts, {})
def _get_frames(figs: Mapping[str, go.Figure], label: str) -> dict[str, go.Frame]:
    """Collect, for each group, its animation frame named ``label``.

    Groups without such a frame are left out. If a group has several frames
    with that name, the last one is kept.
    """
    selected: dict[str, go.Frame] = {}
    for group, fig in figs.items():
        for candidate in fig.frames:
            frame = cast(go.Frame, candidate)
            if frame.name == label:
                selected[group] = frame
    return selected
[docs]
def with_dropdown(
    group_by: str,
    dropdown_kwars: dict[str, Any] | None = None,
) -> Callable[[PlottingFunction[P]], PlottingFunction[P]]:
    """Add a dropdown to a plotting function.

    The decorated function is called once per group of
    ``data.groupby(group_by)``. The resulting traces are combined into one
    figure, and a dropdown menu switches which group's traces are visible.
    The first group is shown initially. If the merged layout has sliders,
    animation frames sharing a slider label are merged across groups.

    Args:
        group_by: Column to group by for the dropdown.
        dropdown_kwars:
            Keyword arguments to pass to the dropdown layout; they override
            the default menu settings (position, direction, padding, ...).

    Returns:
        Decorator that adds a dropdown to the returned plot. The decorated
        function raises ValueError if ``data`` yields no groups.
    """

    def decorator(func: PlottingFunction[P]) -> PlottingFunction[P]:
        """Add a dropdown to a plotting function."""

        @wraps(func)
        def wrapper(
            data: pd.DataFrame, *args: P.args, **kwargs: P.kwargs
        ) -> go.Figure:
            # One figure per group. Keys are stringified so they can serve as
            # button labels.
            figs = {
                str(group): func(sub_data, *args, **kwargs)
                for group, sub_data in data.groupby(group_by)
            }
            if not figs:
                raise ValueError(
                    f"Cannot add a dropdown: grouping by {group_by!r} "
                    "produced no groups."
                )

            layout = _merge_layout(figs)
            split_fig = go.Figure(
                data=_merge_data(figs),
                layout=layout,
                frames=_merge_frames(figs, layout),
            )

            # Source group of every trace in split_fig, in trace order. This
            # matches the concatenation order used by _merge_data.
            trace_sources = [group for group, fig in figs.items() for _ in fig.data]

            split_fig.update_layout(
                updatemenus=[
                    # Keep every menu the figures already had (e.g. animation
                    # controls), then append the group selector.
                    *split_fig.layout.updatemenus,  # type: ignore
                    {
                        "active": 0,
                        "buttons": [
                            {
                                "label": group,
                                "method": "update",
                                "args": [
                                    {
                                        "visible": [
                                            source == group
                                            for source in trace_sources
                                        ],
                                    },
                                ],
                            }
                            for group in figs
                        ],
                        "direction": "up",
                        "pad": {"r": 10, "t": 70},
                        "showactive": True,
                        "x": 1.1,
                        "xanchor": "left",
                        "y": 0,
                        "yanchor": "top",
                        **(dropdown_kwars or {}),
                    },
                ]
            )

            # Show only the first group initially, matching "active": 0.
            first_group = next(iter(figs))
            for trace, source in zip(split_fig.data, trace_sources):
                trace.visible = source == first_group
            return split_fig

        return wrapper

    return decorator


def _merge_frames(
    figs: Mapping[str, go.Figure], layout: Mapping[str, Any]
) -> list[go.Frame] | None:
    """Merge same-labelled animation frames across the per-group figures.

    Labels come from the steps of the first slider in ``layout``. Returns
    None when ``layout`` has no sliders (i.e. the figure is not animated).
    """
    if "sliders" not in layout:
        return None
    frames = []
    for step in layout["sliders"][0]["steps"]:
        label = step["label"]
        sub_frames = _get_frames(figs, label)
        frames.append(
            go.Frame(
                data=_merge_data(sub_frames),
                layout=_merge_layout(sub_frames),
                name=label,
            )
        )
    return frames