pycorpdiff 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. pycorpdiff/__init__.py +126 -0
  2. pycorpdiff/_backends/__init__.py +3 -0
  3. pycorpdiff/_backends/pandas.py +3 -0
  4. pycorpdiff/_backends/polars.py +3 -0
  5. pycorpdiff/collocation/__init__.py +19 -0
  6. pycorpdiff/collocation/cooccurrence.py +65 -0
  7. pycorpdiff/collocation/measures.py +102 -0
  8. pycorpdiff/collocation/network.py +233 -0
  9. pycorpdiff/collocation/shift.py +146 -0
  10. pycorpdiff/compare.py +345 -0
  11. pycorpdiff/corpus.py +411 -0
  12. pycorpdiff/datasets/__init__.py +27 -0
  13. pycorpdiff/datasets/_data/hansard_sample.parquet +0 -0
  14. pycorpdiff/datasets/_generate_hansard.py +221 -0
  15. pycorpdiff/datasets/hansard.py +235 -0
  16. pycorpdiff/datasets/histwords.py +221 -0
  17. pycorpdiff/explain.py +177 -0
  18. pycorpdiff/io/__init__.py +16 -0
  19. pycorpdiff/io/duckdb.py +92 -0
  20. pycorpdiff/io/huggingface.py +142 -0
  21. pycorpdiff/io/readers.py +138 -0
  22. pycorpdiff/keyness/__init__.py +26 -0
  23. pycorpdiff/keyness/bayes.py +50 -0
  24. pycorpdiff/keyness/chi_squared.py +94 -0
  25. pycorpdiff/keyness/correction.py +34 -0
  26. pycorpdiff/keyness/dispersion.py +89 -0
  27. pycorpdiff/keyness/effect_sizes.py +65 -0
  28. pycorpdiff/keyness/loglikelihood.py +92 -0
  29. pycorpdiff/keyness/multicorpus.py +143 -0
  30. pycorpdiff/keyness/permutation.py +154 -0
  31. pycorpdiff/py.typed +0 -0
  32. pycorpdiff/results.py +635 -0
  33. pycorpdiff/semantic/__init__.py +18 -0
  34. pycorpdiff/semantic/alignment.py +53 -0
  35. pycorpdiff/semantic/embed.py +84 -0
  36. pycorpdiff/semantic/shift.py +224 -0
  37. pycorpdiff/semantic/trajectory.py +166 -0
  38. pycorpdiff/stats.py +69 -0
  39. pycorpdiff/temporal/__init__.py +15 -0
  40. pycorpdiff/temporal/bocpd.py +233 -0
  41. pycorpdiff/temporal/causal_impact.py +293 -0
  42. pycorpdiff/temporal/changepoint.py +92 -0
  43. pycorpdiff/temporal/forecast.py +405 -0
  44. pycorpdiff/temporal/its.py +123 -0
  45. pycorpdiff/temporal/slicing.py +174 -0
  46. pycorpdiff/tokenize.py +110 -0
  47. pycorpdiff/viz/__init__.py +37 -0
  48. pycorpdiff/viz/bocpd.py +173 -0
  49. pycorpdiff/viz/causal_impact.py +142 -0
  50. pycorpdiff/viz/collocation.py +48 -0
  51. pycorpdiff/viz/dispersion.py +117 -0
  52. pycorpdiff/viz/forecast.py +129 -0
  53. pycorpdiff/viz/keyness.py +96 -0
  54. pycorpdiff/viz/network.py +186 -0
  55. pycorpdiff/viz/scattertext.py +160 -0
  56. pycorpdiff/viz/semantic_forecast.py +114 -0
  57. pycorpdiff/viz/trajectory.py +48 -0
  58. pycorpdiff-0.1.0a0.dist-info/METADATA +230 -0
  59. pycorpdiff-0.1.0a0.dist-info/RECORD +61 -0
  60. pycorpdiff-0.1.0a0.dist-info/WHEEL +4 -0
  61. pycorpdiff-0.1.0a0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,129 @@
1
+ """Forecast plot — solid history continues into dashed forecast.
2
+
3
+ Visual grammar: the same Wilson-CI band + line + points the trajectory
4
+ plot already uses for observed periods, then a *dashed* line + lighter
5
+ prediction-interval band for the forecast horizon. The visual handoff
6
+ between the two regions is what makes the chart read "this is the
7
+ observed history, this is what we project forward".
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import TYPE_CHECKING
13
+
14
+ import pandas as pd
15
+
16
+ if TYPE_CHECKING:
17
+ import altair as alt
18
+
19
+
20
def forecast_plot(
    history: pd.DataFrame,
    forecast: pd.DataFrame,
    width: int = 600,
    height: int = 320,
) -> alt.Chart:
    """Layered plot: observed Wilson-CI trajectory + dashed forecast band.

    The observed series is drawn as a confidence band, a solid line and
    filled points; the forecast as a lighter band, a dashed line and
    hollow points. When both tables have rows, a dashed connector joins
    each term's last observed value to its first forecast value.

    Parameters
    ----------
    history
        The trajectory table — must carry ``period``, ``term``,
        ``relfreq``, ``ci_lower``, ``ci_upper``. ``count`` and ``total``
        are added to the hover tooltip when present. May be empty.
    forecast
        The forecast table — must carry ``period``, ``term``,
        ``point``, ``ci_lower``, ``ci_upper``. May be empty.
    width, height
        Canvas dimensions.

    Returns
    -------
    alt.Chart
        The layered chart. Neither input frame is modified; the function
        works on copies.
    """
    import altair as alt

    h = history.copy()
    f = forecast.copy()
    # The x channel is encoded as temporal (``period:T``), so Period
    # values are converted to timestamps first. The ``len`` guard keeps
    # ``iloc[0]`` from raising on an empty table.
    for frame in (h, f):
        if len(frame) and isinstance(frame["period"].iloc[0], pd.Period):
            frame["period"] = frame["period"].apply(lambda p: p.to_timestamp())

    base_h = alt.Chart(h).encode(
        x=alt.X("period:T", title=None),
        color=alt.Color("term:N", title=None),
    )
    history_band = base_h.mark_area(opacity=0.18).encode(
        y=alt.Y("ci_lower:Q", title="Relative frequency"),
        y2="ci_upper:Q",
    )
    # ``count`` / ``total`` are not part of the documented minimum schema,
    # so they are only referenced when the table actually carries them.
    history_tooltip = [
        "period",
        "term",
        *(col for col in ("count", "total") if col in h.columns),
        alt.Tooltip("relfreq:Q", format=".5f"),
        alt.Tooltip("ci_lower:Q", format=".5f"),
        alt.Tooltip("ci_upper:Q", format=".5f"),
    ]
    history_line = base_h.mark_line(strokeWidth=2).encode(
        y="relfreq:Q",
        tooltip=history_tooltip,
    )
    history_points = base_h.mark_point(filled=True, size=50).encode(
        y="relfreq:Q",
    )

    base_f = alt.Chart(f).encode(
        x=alt.X("period:T", title=None),
        color=alt.Color("term:N", title=None),
    )
    forecast_band = base_f.mark_area(opacity=0.12).encode(
        y=alt.Y("ci_lower:Q"),
        y2="ci_upper:Q",
    )
    forecast_line = base_f.mark_line(strokeDash=[6, 4], strokeWidth=2).encode(
        y="point:Q",
        tooltip=[
            "period",
            "term",
            alt.Tooltip("point:Q", format=".5f"),
            alt.Tooltip("ci_lower:Q", format=".5f"),
            alt.Tooltip("ci_upper:Q", format=".5f"),
        ],
    )
    forecast_points = base_f.mark_point(
        filled=False, strokeWidth=2, size=50
    ).encode(y="point:Q")

    layers = [history_band, history_line, history_points]

    # Stitch history + forecast at the seam: a connector line from the
    # last observed value to the first forecast point so the chart
    # reads as a single trajectory rather than two disconnected lines.
    # When either side is empty there is nothing to connect, so the seam
    # layer is omitted rather than added as a mark-less placeholder.
    if len(h) and len(f):
        last_h = h.sort_values("period").groupby("term").tail(1)
        first_f = f.sort_values("period").groupby("term").head(1)
        seam = pd.concat(
            [
                last_h.assign(point=last_h["relfreq"])[
                    ["period", "term", "point"]
                ],
                first_f[["period", "term", "point"]],
            ]
        )
        seam_line = (
            alt.Chart(seam)
            .mark_line(strokeDash=[6, 4], strokeWidth=2, opacity=0.55)
            .encode(
                x="period:T",
                y="point:Q",
                color=alt.Color("term:N", legend=None),
            )
        )
        layers.append(seam_line)

    layers += [forecast_band, forecast_line, forecast_points]
    chart = alt.layer(*layers).properties(width=width, height=height)
    return chart  # type: ignore[no-any-return]
@@ -0,0 +1,96 @@
1
+ """Keyness visualisations — volcano plot and top-N bar chart."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ if TYPE_CHECKING:
11
+ import altair as alt
12
+
13
+
14
def keyness_volcano(
    df: pd.DataFrame,
    width: int = 600,
    height: int = 400,
    n_labels: int = 15,
) -> alt.Chart:
    """Volcano-style scatter: effect size (x) versus significance (y).

    Parameters
    ----------
    df
        Keyness table with ``term`` and ``p_value`` columns. The x axis
        uses ``log_ratio`` when that column exists and ``g2`` otherwise.
    width, height
        Canvas dimensions.
    n_labels
        Number of rows, ranked by absolute x value, that receive a text
        label; every row is drawn as a circle.

    Returns
    -------
    alt.Chart
        Circles layered with the text labels.
    """
    import altair as alt

    has_log_ratio = "log_ratio" in df.columns
    x_col = "log_ratio" if has_log_ratio else "g2"
    x_title = (
        "LogRatio (positive = overused in A)"
        if has_log_ratio
        else "Signed G² (positive = overused in A)"
    )

    # p-values are clipped into [1e-300, 1] before the log, which caps
    # the significance axis at 300 and avoids infinities.
    with np.errstate(divide="ignore"):
        clipped_p = np.clip(df["p_value"].to_numpy(), 1e-300, 1.0)
        neg_log_p = -np.log10(clipped_p)
    plot_df = df.assign(neg_log_p=neg_log_p)

    # Blue for non-negative x (A-leaning), red for negative (B-leaning).
    side_colour = alt.condition(
        alt.datum[x_col] >= 0,
        alt.value("#1f77b4"),
        alt.value("#d62728"),
    )
    x_field = f"{x_col}:Q"
    points = (
        alt.Chart(plot_df)
        .encode(
            x=alt.X(x_field, title=x_title),
            y=alt.Y("neg_log_p:Q", title="−log₁₀(p)"),
            tooltip=list(df.columns),
        )
        .mark_circle(opacity=0.55, size=60)
        .encode(color=side_colour)
    )

    with_magnitude = plot_df.assign(_abs=plot_df[x_col].abs())
    label_subset = with_magnitude.nlargest(n_labels, "_abs")
    labels = (
        alt.Chart(label_subset)
        .mark_text(align="left", dx=6, fontSize=10)
        .encode(x=x_field, y="neg_log_p:Q", text="term:N")
    )

    layered = points + labels
    return layered.properties(width=width, height=height)  # type: ignore[no-any-return]
61
+
62
+
63
def keyness_top_n_bar(
    df: pd.DataFrame,
    n: int = 20,
    width: int = 500,
    height: int | None = None,
) -> alt.Chart:
    """Horizontal bar chart of the ``n`` rows with the largest ``|g2|``.

    Bars show the signed ``g2`` value and are ordered by it, descending.
    Non-negative bars (A-leaning) are blue, negative bars (B-leaning)
    are red.

    Parameters
    ----------
    df
        Keyness table with at least ``term`` and ``g2`` columns.
    n
        Number of rows to keep.
    width
        Canvas width.
    height
        Canvas height. When ``None`` it is 18 pixels per bar, with a
        minimum of 200.

    Returns
    -------
    alt.Chart
        The bar chart.
    """
    import altair as alt

    magnitude = df["g2"].abs()
    subset = df.assign(_abs=magnitude).nlargest(n, "_abs").drop(columns="_abs")
    chart_height = max(200, 18 * len(subset)) if height is None else height

    side_colour = alt.condition(
        alt.datum.g2 >= 0,
        alt.value("#1f77b4"),
        alt.value("#d62728"),
    )
    bars = alt.Chart(subset).mark_bar().encode(
        x=alt.X("g2:Q", title="Signed G²"),
        y=alt.Y("term:N", sort="-x", title=None),
        color=side_colour,
        tooltip=list(subset.columns),
    )
    chart = bars.properties(width=width, height=chart_height)
    return chart  # type: ignore[no-any-return]
@@ -0,0 +1,186 @@
1
+ """Network plot for term co-occurrence graphs.
2
+
3
+ Renders a :class:`pycorpdiff.collocation.NetworkResult` as an altair
4
+ chart with circles for nodes, rules for edges, and text labels. Node
5
+ positions come from ``networkx`` when it is installed (Kamada-Kawai for
6
+ connected graphs, a spring layout otherwise); without it the nodes fall back to a circular layout, which
7
+ still gives a structurally faithful (if visually flatter) picture.
8
+
9
+ altair is intentionally pinned at the ``[viz]`` extra to avoid pulling
10
+ heavyweight rendering deps into the base install.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING
17
+
18
+ import pandas as pd
19
+
20
+ if TYPE_CHECKING:
21
+ import altair as alt
22
+
23
+ from ..collocation.network import NetworkResult
24
+
25
+
26
def network_plot(
    result: NetworkResult,
    *,
    width: int = 700,
    height: int = 700,
    max_edges: int = 100,
    label_top_n: int = 30,
    seed: int = 0,
) -> alt.Chart:
    """Plot a :class:`NetworkResult` as a force-directed-style network.

    Parameters
    ----------
    result
        The network to render. ``result.nodes`` is read as a table
        indexed by term with ``count`` and ``degree`` columns;
        ``result.edges`` as a table with ``source``, ``target`` and
        ``weight`` (plus ``cooccur_count`` for the tooltip);
        ``result.measure`` labels the edge-weight legend.
    width, height
        Canvas dimensions in pixels. Square by default.
    max_edges
        Only the top ``max_edges`` edges by ``|weight|`` are drawn,
        preventing dense networks from rendering as a black blob.
    label_top_n
        Inline-label budget. Only the ``label_top_n`` highest-degree
        nodes get text labels next to their dot; the others remain
        bare circles with hover tooltips.
    seed
        Random seed forwarded to the layout helper, which uses it only
        for the spring layout.

    Returns
    -------
    alt.Chart
        Interactive (pan/zoom) layered chart of edges, nodes and labels.
    """
    import altair as alt

    nodes = result.nodes.copy()
    # Rank explicitly by |weight| so the documented selection holds
    # whatever order the edge table arrives in. The stable sort keeps the
    # incoming order among ties.
    edges = (
        result.edges.sort_values(
            "weight", key=lambda w: w.abs(), ascending=False, kind="stable"
        )
        .head(max_edges)
        .copy()
    )

    positions = _layout(nodes, edges, seed=seed)
    nodes_xy = nodes.join(positions, how="left").reset_index(names="term")

    # Edge endpoints — attach source / target coordinates. Both merges
    # are inner joins, so an edge whose endpoint has no position is
    # dropped.
    edges_xy = edges.merge(
        positions.rename(columns={"x": "x_src", "y": "y_src"}),
        left_on="source",
        right_index=True,
    ).merge(
        positions.rename(columns={"x": "x_tgt", "y": "y_tgt"}),
        left_on="target",
        right_index=True,
    )

    # Every layer shares the same fixed domain, slightly wider than the
    # [-1, 1] box the layout is rescaled to.
    unit_scale = alt.Scale(domain=[-1.1, 1.1])

    edge_layer = (
        alt.Chart(edges_xy)
        .mark_rule(opacity=0.35, color="#777")
        .encode(
            x=alt.X("x_src:Q", axis=None, scale=unit_scale),
            y=alt.Y("y_src:Q", axis=None, scale=unit_scale),
            x2="x_tgt:Q",
            y2="y_tgt:Q",
            # Channel class matches the channel it is bound to.
            strokeWidth=alt.StrokeWidth(
                "weight:Q",
                scale=alt.Scale(range=[0.5, 4]),
                legend=alt.Legend(title=f"Edge weight ({result.measure})"),
            ),
            tooltip=["source:N", "target:N", "cooccur_count:Q", "weight:Q"],
        )
    )

    node_layer = (
        alt.Chart(nodes_xy)
        .mark_circle(opacity=0.85, color="#1f77b4")
        .encode(
            x=alt.X("x:Q", axis=None, scale=unit_scale),
            y=alt.Y("y:Q", axis=None, scale=unit_scale),
            size=alt.Size(
                "count:Q",
                scale=alt.Scale(range=[80, 600]),
                legend=alt.Legend(title="Term frequency"),
            ),
            tooltip=["term:N", "count:Q", "degree:Q"],
        )
    )

    labelled = nodes_xy.sort_values("degree", ascending=False).head(label_top_n)
    label_layer = (
        alt.Chart(labelled)
        .mark_text(dy=-10, fontSize=11, fontWeight="bold")
        .encode(
            x=alt.X("x:Q", axis=None, scale=unit_scale),
            y=alt.Y("y:Q", axis=None, scale=unit_scale),
            text="term:N",
        )
    )

    chart = (
        (edge_layer + node_layer + label_layer)
        .properties(width=width, height=height)
        .configure_view(strokeWidth=0)
        .interactive()
    )
    return chart  # type: ignore[no-any-return]
122
+
123
+
124
+ def _layout(
125
+ nodes: pd.DataFrame,
126
+ edges: pd.DataFrame,
127
+ seed: int = 0,
128
+ ) -> pd.DataFrame:
129
+ """Compute ``(x, y)`` for every node.
130
+
131
+ Uses ``networkx.spring_layout`` if available; otherwise falls back
132
+ to a circular layout (still a valid plot, just less informative).
133
+ Returns a DataFrame indexed by term with ``x`` and ``y`` columns
134
+ rescaled to ``[-1, 1]``.
135
+ """
136
+ try:
137
+ import networkx as nx
138
+ except ImportError:
139
+ return _circular_layout(nodes.index.tolist())
140
+
141
+ g = nx.Graph()
142
+ for term in nodes.index:
143
+ g.add_node(term)
144
+ for _, row in edges.iterrows():
145
+ g.add_edge(row["source"], row["target"], weight=abs(float(row["weight"])))
146
+
147
+ # Kamada-Kawai produces more uniformly-spaced layouts than the
148
+ # default spring algorithm on densely-connected graphs, which is
149
+ # the common case for corpus discourse networks (every top term
150
+ # tends to co-occur with many others). Falls back to a high-k
151
+ # spring layout for disconnected graphs (KK requires connectivity).
152
+ n = max(1, g.number_of_nodes())
153
+ try:
154
+ if nx.is_connected(g):
155
+ pos = nx.kamada_kawai_layout(g)
156
+ else:
157
+ pos = nx.spring_layout(
158
+ g, seed=seed, k=2.5 / math.sqrt(n), iterations=200
159
+ )
160
+ except (nx.NetworkXError, ValueError):
161
+ pos = nx.spring_layout(
162
+ g, seed=seed, k=2.5 / math.sqrt(n), iterations=200
163
+ )
164
+ coords = pd.DataFrame(
165
+ {"x": [pos[t][0] for t in nodes.index], "y": [pos[t][1] for t in nodes.index]},
166
+ index=nodes.index,
167
+ )
168
+ # Rescale to a square in [-1, 1].
169
+ for col in ("x", "y"):
170
+ lo, hi = coords[col].min(), coords[col].max()
171
+ span = hi - lo if hi != lo else 1.0
172
+ coords[col] = 2.0 * (coords[col] - lo) / span - 1.0
173
+ return coords
174
+
175
+
176
+ def _circular_layout(terms: list[str]) -> pd.DataFrame:
177
+ """Fallback layout when ``networkx`` isn't installed."""
178
+ n = len(terms)
179
+ coords = {
180
+ t: (math.cos(2 * math.pi * i / n), math.sin(2 * math.pi * i / n))
181
+ for i, t in enumerate(terms)
182
+ }
183
+ return pd.DataFrame(
184
+ {"x": [coords[t][0] for t in terms], "y": [coords[t][1] for t in terms]},
185
+ index=terms,
186
+ )
@@ -0,0 +1,160 @@
1
+ """Scattertext-style interactive scatter (Kessler 2017).
2
+
3
+ The signature Scattertext visualisation. Each term is a point whose
4
+ x-axis position is its rank-percentile in corpus A and whose y-axis
5
+ position is its rank-percentile in corpus B. Words common in both
6
+ land in the top-right corner; words common in only one side fall away
7
+ from the diagonal into one of the two off-diagonal "distinctiveness"
8
+ zones; words rare in both cluster near the origin.
9
+
10
+ The rank-based axes are the trick that makes Scattertext readable.
11
+ Plotting raw counts produces a plot dominated by stopwords in the
12
+ top-right corner with everything else crushed into the bottom-left;
13
+ rank-percentiles spread the whole vocabulary evenly across [0, 1].
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import TYPE_CHECKING
19
+
20
+ import numpy as np
21
+ import pandas as pd
22
+
23
+ if TYPE_CHECKING:
24
+ import altair as alt
25
+
26
+
27
def scattertext_plot(
    df: pd.DataFrame,
    *,
    label_a: str = "a",
    label_b: str = "b",
    width: int = 600,
    height: int = 600,
    n_labels: int = 20,
) -> alt.Chart:
    """Scattertext-style interactive scatter of keyness terms.

    Each term is plotted at its rank-percentile in corpus A (x) against
    its rank-percentile in corpus B (y). Point size follows ``|g2|`` and
    colour follows signed ``g2`` on a scale symmetric around zero.

    The chart is pan/zoom-able and every point has a hover tooltip. The
    ``n_labels`` rows with the largest ``g2`` and the ``n_labels`` rows
    with the smallest ``g2`` get inline text labels — others remain dots.

    Parameters
    ----------
    df
        A KeynessResult-shaped DataFrame with ``term``, ``count_a``,
        ``count_b`` and ``g2``. ``log_ratio``, ``percent_diff``,
        ``p_value`` and ``p_adjusted`` are added to the tooltip when
        present. An empty frame yields an empty chart.
    label_a, label_b
        Corpus names interpolated into the axis titles.
    width, height
        Square by default (600×600); the rank-percentile axes deserve
        equal visual weight.
    n_labels
        Per-side label budget. ``n_labels=20`` produces up to 40 text
        labels total.

    Returns
    -------
    alt.Chart
        Diagonal reference line, points and labels, layered.
    """
    import altair as alt

    if df.empty:
        empty_chart = (
            alt.Chart(df).mark_point().properties(width=width, height=height)
        )
        return empty_chart  # type: ignore[no-any-return]

    plot = df.copy()
    # rank() with method="average" handles ties cleanly; pct=True scales to
    # (0, 1]. Higher percentile == more common in that corpus.
    plot["percentile_a"] = plot["count_a"].rank(pct=True, method="average")
    plot["percentile_b"] = plot["count_b"].rank(pct=True, method="average")
    plot["abs_g2"] = plot["g2"].abs()

    # Symmetric colour scale around zero so the colour midpoint corresponds
    # to "equally common in both", not to the median G² of the table.
    # Zero, NaN (every g2 missing) and infinite maxima all fall back to 1.0
    # so the scale domain stays finite and non-degenerate.
    g2_max = float(plot["abs_g2"].max())
    if not np.isfinite(g2_max) or g2_max == 0.0:
        g2_max = 1.0
    g2_scale = alt.Scale(
        scheme="redblue",
        domain=[-g2_max, 0.0, g2_max],
        reverse=True,  # blue = A-leaning, red = B-leaning
    )

    tooltip_cols: list[str] = ["term", "count_a", "count_b", "g2"]
    tooltip_cols.extend(
        opt
        for opt in ("log_ratio", "percent_diff", "p_value", "p_adjusted")
        if opt in plot.columns
    )

    base = alt.Chart(plot).encode(
        x=alt.X(
            "percentile_a:Q",
            title=f"Frequency rank in {label_a} (→ more common)",
            scale=alt.Scale(domain=[0, 1]),
        ),
        y=alt.Y(
            "percentile_b:Q",
            title=f"Frequency rank in {label_b} (→ more common)",
            scale=alt.Scale(domain=[0, 1]),
        ),
        tooltip=tooltip_cols,
    )

    # The diagonal x = y is the "equally distinctive" line — terms on it
    # have the same rank in both corpora. Drawing it as a reference rule
    # helps readers calibrate the distinctiveness zones.
    diag = (
        alt.Chart(pd.DataFrame({"x": [0.0, 1.0], "y": [0.0, 1.0]}))
        .mark_line(strokeDash=[4, 4], color="#999", opacity=0.5)
        .encode(x="x:Q", y="y:Q")
    )

    points = base.mark_circle(opacity=0.55).encode(
        size=alt.Size(
            "abs_g2:Q",
            scale=alt.Scale(range=[20, 200]),
            legend=None,
        ),
        color=alt.Color("g2:Q", scale=g2_scale, title="Signed G²"),
    )

    # Pick the top-n_labels on each side by signed G². On small tables the
    # two selections can overlap, hence the de-duplication by term.
    a_leaning = plot.nlargest(n_labels, "g2")
    b_leaning = plot.nsmallest(n_labels, "g2")
    labelled = pd.concat([a_leaning, b_leaning], ignore_index=True).drop_duplicates(
        subset="term"
    )
    labels = (
        alt.Chart(labelled)
        .mark_text(align="left", dx=5, dy=-2, fontSize=10)
        .encode(
            x="percentile_a:Q",
            y="percentile_b:Q",
            text="term:N",
            color=alt.Color("g2:Q", scale=g2_scale, legend=None),
        )
    )

    chart = (
        (diag + points + labels)
        .properties(width=width, height=height)
        .interactive()
    )
    return chart  # type: ignore[no-any-return]
156
+
157
+
158
+ def _percentile_rank(series: pd.Series) -> np.ndarray:
159
+ """Internal helper used by tests; kept here for stability."""
160
+ return series.rank(pct=True, method="average").to_numpy()
@@ -0,0 +1,114 @@
1
+ """Plot for :func:`pycorpdiff.forecast_semantic_drift` output.
2
+
3
+ Same dashed-extension grammar as :func:`pycorpdiff.viz.forecast_plot`
4
+ but operates on the cosine-distance scale (``distance_from_baseline``
5
+ column on the history side, ``point`` + PI on the forecast side).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import TYPE_CHECKING
11
+
12
+ import pandas as pd
13
+
14
+ if TYPE_CHECKING:
15
+ import altair as alt
16
+
17
+
18
def semantic_forecast_plot(
    history: pd.DataFrame,
    forecast: pd.DataFrame,
    *,
    width: int = 600,
    height: int = 320,
) -> alt.Chart:
    """Layered plot of semantic drift with a forecast continuation.

    The observed distances are drawn as a solid line with filled points;
    the forecast as an interval band, a dashed line and hollow points.
    When both tables have rows, a dashed connector joins each target's
    last observed value to its first forecast value.

    Parameters
    ----------
    history
        The DataFrame returned by :func:`pycorpdiff.semantic_trajectory`
        — must carry ``period``, ``target``, ``distance_from_baseline``.
        ``n_contexts`` is added to the tooltip when present. May be empty.
    forecast
        The DataFrame returned by
        :func:`pycorpdiff.forecast_semantic_drift` — must carry
        ``period``, ``target``, ``point``, ``ci_lower``, ``ci_upper``.
        May be empty.
    width, height
        Canvas dimensions.

    Returns
    -------
    alt.Chart
        The layered chart. Neither input frame is modified; the function
        works on copies.
    """
    import altair as alt

    h = history.copy()
    f = forecast.copy()
    # The x channel is encoded as temporal (``period:T``), so Period
    # values are converted to timestamps first.
    for frame in (h, f):
        if len(frame) and isinstance(frame["period"].iloc[0], pd.Period):
            frame["period"] = frame["period"].apply(lambda p: p.to_timestamp())

    base_h = alt.Chart(h).encode(
        x=alt.X("period:T", title=None),
        color=alt.Color("target:N", title=None),
    )
    history_tooltip = [
        "period",
        "target",
        alt.Tooltip("distance_from_baseline:Q", format=".4f"),
    ]
    # Optional column: listed only when available, so the fallback no
    # longer repeats ``target`` in the tooltip.
    if "n_contexts" in h.columns:
        history_tooltip.append(alt.Tooltip("n_contexts:Q"))
    history_line = base_h.mark_line(strokeWidth=2.5).encode(
        y=alt.Y("distance_from_baseline:Q", title="cosine distance from baseline"),
        tooltip=history_tooltip,
    )
    history_points = base_h.mark_point(filled=True, size=55).encode(
        y="distance_from_baseline:Q",
    )

    base_f = alt.Chart(f).encode(
        x=alt.X("period:T", title=None),
        color=alt.Color("target:N", title=None),
    )
    forecast_band = base_f.mark_area(opacity=0.18).encode(
        y=alt.Y("ci_lower:Q"),
        y2="ci_upper:Q",
    )
    forecast_line = base_f.mark_line(strokeDash=[6, 4], strokeWidth=2).encode(
        y="point:Q",
        tooltip=[
            "period",
            "target",
            alt.Tooltip("point:Q", format=".4f"),
            alt.Tooltip("ci_lower:Q", format=".4f"),
            alt.Tooltip("ci_upper:Q", format=".4f"),
        ],
    )
    forecast_points = base_f.mark_point(filled=False, strokeWidth=2, size=55).encode(
        y="point:Q"
    )

    layers = [history_line, history_points]

    # Connector from the last observed value to the first forecast value
    # of each target. When either side is empty there is nothing to
    # connect, so the seam layer is omitted rather than added as a
    # mark-less placeholder.
    if len(h) and len(f):
        last_h = h.sort_values("period").groupby("target").tail(1)
        first_f = f.sort_values("period").groupby("target").head(1)
        seam = pd.concat(
            [
                last_h.assign(point=last_h["distance_from_baseline"])[
                    ["period", "target", "point"]
                ],
                first_f[["period", "target", "point"]],
            ]
        )
        seam_line = (
            alt.Chart(seam)
            .mark_line(strokeDash=[6, 4], strokeWidth=2, opacity=0.55)
            .encode(
                x="period:T",
                y="point:Q",
                color=alt.Color("target:N", legend=None),
            )
        )
        layers.append(seam_line)

    layers += [forecast_band, forecast_line, forecast_points]
    chart = alt.layer(*layers).properties(width=width, height=height)
    return chart  # type: ignore[no-any-return]