pycorpdiff 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. pycorpdiff/__init__.py +126 -0
  2. pycorpdiff/_backends/__init__.py +3 -0
  3. pycorpdiff/_backends/pandas.py +3 -0
  4. pycorpdiff/_backends/polars.py +3 -0
  5. pycorpdiff/collocation/__init__.py +19 -0
  6. pycorpdiff/collocation/cooccurrence.py +65 -0
  7. pycorpdiff/collocation/measures.py +102 -0
  8. pycorpdiff/collocation/network.py +233 -0
  9. pycorpdiff/collocation/shift.py +146 -0
  10. pycorpdiff/compare.py +345 -0
  11. pycorpdiff/corpus.py +411 -0
  12. pycorpdiff/datasets/__init__.py +27 -0
  13. pycorpdiff/datasets/_data/hansard_sample.parquet +0 -0
  14. pycorpdiff/datasets/_generate_hansard.py +221 -0
  15. pycorpdiff/datasets/hansard.py +235 -0
  16. pycorpdiff/datasets/histwords.py +221 -0
  17. pycorpdiff/explain.py +177 -0
  18. pycorpdiff/io/__init__.py +16 -0
  19. pycorpdiff/io/duckdb.py +92 -0
  20. pycorpdiff/io/huggingface.py +142 -0
  21. pycorpdiff/io/readers.py +138 -0
  22. pycorpdiff/keyness/__init__.py +26 -0
  23. pycorpdiff/keyness/bayes.py +50 -0
  24. pycorpdiff/keyness/chi_squared.py +94 -0
  25. pycorpdiff/keyness/correction.py +34 -0
  26. pycorpdiff/keyness/dispersion.py +89 -0
  27. pycorpdiff/keyness/effect_sizes.py +65 -0
  28. pycorpdiff/keyness/loglikelihood.py +92 -0
  29. pycorpdiff/keyness/multicorpus.py +143 -0
  30. pycorpdiff/keyness/permutation.py +154 -0
  31. pycorpdiff/py.typed +0 -0
  32. pycorpdiff/results.py +635 -0
  33. pycorpdiff/semantic/__init__.py +18 -0
  34. pycorpdiff/semantic/alignment.py +53 -0
  35. pycorpdiff/semantic/embed.py +84 -0
  36. pycorpdiff/semantic/shift.py +224 -0
  37. pycorpdiff/semantic/trajectory.py +166 -0
  38. pycorpdiff/stats.py +69 -0
  39. pycorpdiff/temporal/__init__.py +15 -0
  40. pycorpdiff/temporal/bocpd.py +233 -0
  41. pycorpdiff/temporal/causal_impact.py +293 -0
  42. pycorpdiff/temporal/changepoint.py +92 -0
  43. pycorpdiff/temporal/forecast.py +405 -0
  44. pycorpdiff/temporal/its.py +123 -0
  45. pycorpdiff/temporal/slicing.py +174 -0
  46. pycorpdiff/tokenize.py +110 -0
  47. pycorpdiff/viz/__init__.py +37 -0
  48. pycorpdiff/viz/bocpd.py +173 -0
  49. pycorpdiff/viz/causal_impact.py +142 -0
  50. pycorpdiff/viz/collocation.py +48 -0
  51. pycorpdiff/viz/dispersion.py +117 -0
  52. pycorpdiff/viz/forecast.py +129 -0
  53. pycorpdiff/viz/keyness.py +96 -0
  54. pycorpdiff/viz/network.py +186 -0
  55. pycorpdiff/viz/scattertext.py +160 -0
  56. pycorpdiff/viz/semantic_forecast.py +114 -0
  57. pycorpdiff/viz/trajectory.py +48 -0
  58. pycorpdiff-0.1.0a0.dist-info/METADATA +230 -0
  59. pycorpdiff-0.1.0a0.dist-info/RECORD +61 -0
  60. pycorpdiff-0.1.0a0.dist-info/WHEEL +4 -0
  61. pycorpdiff-0.1.0a0.dist-info/licenses/LICENSE +21 -0
pycorpdiff/tokenize.py ADDED
@@ -0,0 +1,110 @@
1
+ """Tokenizer protocol and the default regex tokenizer.
2
+
3
+ The :class:`Tokenizer` protocol is the package's only extension point for
4
+ language-specific preprocessing. Adapters around spaCy, Stanza, jieba,
5
+ fugashi, etc. need to satisfy a single ``__call__(text: str) -> list[str]``
6
+ contract — no inheritance, no registration.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import re
12
+ from dataclasses import dataclass, field
13
+ from typing import Protocol, runtime_checkable
14
+
15
+
16
@runtime_checkable
class Tokenizer(Protocol):
    """Structural type for tokenizers: ``(text: str) -> list[str]``.

    Any callable with this shape qualifies — plain functions, lambdas,
    bound methods, or objects that define ``__call__``. No subclassing or
    registration is required. Because the protocol is
    ``runtime_checkable``, ``isinstance(obj, Tokenizer)`` only verifies
    that ``obj`` is callable (has ``__call__``); it does not inspect the
    signature or the return type.
    """

    def __call__(self, text: str) -> list[str]:
        """Split ``text`` into a list of token strings."""
        ...
21
+
22
+
23
@dataclass(frozen=True)
class RegexTokenizer:
    """Default tokenizer: every regex match in the (optionally lowercased) text.

    With the default ``pattern`` of ``\\w+``, a token is any maximal run of
    Unicode word characters (letters, digits and underscores in any
    script). Non-Latin text therefore splits on whitespace and punctuation
    out of the box. This tokenizer does no lemmatisation, CJK word
    segmentation or multi-word-expression handling; plug a spaCy / Stanza /
    jieba adapter satisfying the ``Tokenizer`` protocol in when you need
    those.

    Parameters
    ----------
    pattern
        Regular expression whose matches become tokens. Tokens come from
        ``re.Pattern.findall``. A pattern containing capturing groups
        therefore yields the group text (or tuples of groups) rather than
        the whole match.
    lowercase
        Apply ``str.lower`` to the text before matching.

    The compiled pattern is cached on the instance during construction. It
    is excluded from ``repr``, equality and hashing.
    """

    pattern: str = r"\w+"
    lowercase: bool = True
    _compiled: re.Pattern[str] = field(init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        # Frozen dataclass: the cache must be written around __setattr__.
        compiled = re.compile(self.pattern, re.UNICODE)
        object.__setattr__(self, "_compiled", compiled)

    def __call__(self, text: str) -> list[str]:
        source = text.lower() if self.lowercase else text
        return self._compiled.findall(source)
47
+
48
+
49
@dataclass(frozen=True)
class NgramTokenizer:
    """Wrap a base tokenizer so it emits ``sep``-joined n-grams.

    A window of ``n`` tokens slides over the base tokenizer's output, and
    each window is joined into a single string. The result is still a flat
    ``list[str]``, so every downstream analytical surface — keyness,
    dispersion, collocation, semantic shift — handles n-grams exactly like
    ordinary terms.

    Parameters
    ----------
    base
        Tokenizer producing the unigrams. Defaults to
        :class:`RegexTokenizer`.
    n
        N-gram order (``2`` = bigrams, ``3`` = trigrams). Must be ``>= 1``.
        With ``n=1`` the base tokenizer's output is returned unchanged.
    sep
        String placed between the tokens of each n-gram. The default
        ``"_"`` follows gensim's phrase convention. Note that the default
        ``RegexTokenizer`` pattern ``\\w+`` can itself produce tokens that
        contain ``"_"`` (e.g. ``"new_york"``). Such a token cannot be told
        apart from a joined bigram, so choose a ``sep`` outside your token
        alphabet if that matters.
    include_lower
        When ``True``, also emit every lower order ``k < n``. All unigrams
        come first, then all bigrams, and so on up to order ``n``. This is
        handy for keyness-ranking single words and their multi-word
        collocations side by side in one Comparison.

    Raises
    ------
    ValueError
        If ``n < 1``.

    Examples
    --------
    >>> from pycorpdiff.tokenize import NgramTokenizer
    >>> tok = NgramTokenizer(n=2)
    >>> tok("the cat sat on the mat")
    ['the_cat', 'cat_sat', 'sat_on', 'on_the', 'the_mat']
    >>> NgramTokenizer(n=2, include_lower=True)("a b c")
    ['a', 'b', 'c', 'a_b', 'b_c']
    """

    base: Tokenizer = field(default_factory=RegexTokenizer)
    n: int = 2
    sep: str = "_"
    include_lower: bool = False

    def __post_init__(self) -> None:
        if self.n < 1:
            raise ValueError(f"n must be >= 1; got {self.n}")

    def __call__(self, text: str) -> list[str]:
        tokens = self.base(text)
        if self.n == 1:
            return tokens
        lowest_order = 1 if self.include_lower else self.n
        grams: list[str] = []
        for order in range(lowest_order, self.n + 1):
            if order == 1:
                grams.extend(tokens)
                continue
            # Zipping `order` progressively shifted slices yields every
            # contiguous window. zip stops at the shortest slice, so a text
            # with fewer than `order` tokens contributes nothing.
            windows = zip(*(tokens[offset:] for offset in range(order)))
            grams.extend(self.sep.join(window) for window in windows)
        return grams
@@ -0,0 +1,37 @@
1
+ """Visualisation helpers — altair-first, matplotlib for paper-grade figures.
2
+
3
+ Every Result type's ``.plot()`` method delegates here. Plot functions
4
+ also accept a bare DataFrame so users can call
5
+ ``pcd.viz.keyness_volcano(df)`` directly without going through a Result.
6
+
7
+ altair is an optional dependency declared in the ``viz`` extra. Each
8
+ plot function lazily imports altair on first call; the friendly
9
+ ImportError lives at that boundary.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from .bocpd import bocpd_plot
15
+ from .causal_impact import causal_impact_plot
16
+ from .collocation import collocation_diverging_bar
17
+ from .dispersion import dispersion_plot
18
+ from .forecast import forecast_plot
19
+ from .keyness import keyness_top_n_bar, keyness_volcano
20
+ from .network import network_plot
21
+ from .scattertext import scattertext_plot
22
+ from .semantic_forecast import semantic_forecast_plot
23
+ from .trajectory import trajectory_with_ci
24
+
25
# Public plotting API of ``pycorpdiff.viz``, re-exported from the
# submodules imported in this package. The trailing comment on each entry
# names the submodule it comes from. The list is kept alphabetical.
__all__ = [
    "bocpd_plot",  # .bocpd
    "causal_impact_plot",  # .causal_impact
    "collocation_diverging_bar",  # .collocation
    "dispersion_plot",  # .dispersion
    "forecast_plot",  # .forecast
    "keyness_top_n_bar",  # .keyness
    "keyness_volcano",  # .keyness
    "network_plot",  # .network
    "scattertext_plot",  # .scattertext
    "semantic_forecast_plot",  # .semantic_forecast
    "trajectory_with_ci",  # .trajectory
]
@@ -0,0 +1,173 @@
1
+ """BOCPD diagnostic plot — series + run-length posterior heatmap + MAP line.
2
+
3
+ Three stacked panels sharing the time axis:
4
+
5
+ 1. The input series with detected changepoints flagged.
6
+ 2. Heatmap of the run-length posterior P(r_t | data so far) on a
7
+ log colour scale — the canonical BOCPD diagnostic figure.
8
+ 3. MAP run length over time. Visible drops mark changepoints.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import TYPE_CHECKING
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ if TYPE_CHECKING:
19
+ import altair as alt
20
+
21
+ from ..temporal.bocpd import BocpdResult
22
+
23
+
24
def bocpd_plot(
    result: BocpdResult,
    *,
    width: int = 660,
    height_per_panel: int = 180,
    max_run_length_shown: int = 40,
    threshold: int = 3,
) -> alt.Chart:
    """Three-panel BOCPD diagnostic chart.

    Parameters
    ----------
    result
        A :class:`BocpdResult`. The chart reads its ``series``,
        ``run_length_posterior``, ``map_run_length`` and ``hazard``
        attributes, and calls ``detected_changepoints(threshold=...)``.
    width
        Width of each panel (panels are vertically concatenated).
    height_per_panel
        Height of each panel.
    max_run_length_shown
        Truncate the heatmap at this run length — most posterior mass
        lives in the first few dozen run lengths.
    threshold
        MAP-run-length threshold below which a step is flagged as a
        detected changepoint in panel 1. It is also drawn as a dashed
        reference rule in panel 3.

    Returns
    -------
    altair.VConcatChart
        The series, run-length heatmap and MAP run-length panels, stacked
        vertically with a shared x (time) scale.

    Notes
    -----
    The heatmap has ``T x (max_run_length_shown + 1)`` cells, which can
    exceed altair's default 5000-row limit for inline DataFrames on longer
    series. The limit is not switched off globally via
    ``alt.data_transformers.disable_max_rows()``, because that would
    silently change altair's behaviour for every other chart in the
    session. Instead, only the heatmap data is pre-serialised with
    ``alt.to_values``. altair embeds such pre-built inline values without
    routing them through the row-limiting data transformer.
    """
    import altair as alt

    # Period-indexed series are mapped to timestamps so that every panel
    # can share one temporal (":T") x axis.
    periods = result.series.index
    if isinstance(periods, pd.PeriodIndex):
        period_axis: pd.Index = pd.Index(periods.to_timestamp())
    else:
        period_axis = pd.Index(periods)

    series_df = pd.DataFrame(
        {"period": period_axis, "value": result.series.to_numpy(dtype=float)}
    )
    flagged_df = result.detected_changepoints(threshold=threshold).copy()
    if len(flagged_df) and isinstance(flagged_df["period"].iloc[0], pd.Period):
        flagged_df["period"] = flagged_df["period"].apply(lambda p: p.to_timestamp())

    # ---------- Panel 1: series + flagged changepoints ----------
    series_line = (
        alt.Chart(series_df)
        .mark_line(strokeWidth=2, color="#0b6e7c")
        .encode(
            x=alt.X("period:T", title=None),
            y=alt.Y("value:Q", title=None),
            tooltip=["period", alt.Tooltip("value:Q", format=".5f")],
        )
    )
    series_points = (
        alt.Chart(series_df)
        .mark_point(filled=True, size=40, color="#0b6e7c")
        .encode(x="period:T", y="value:Q")
    )
    cp_rules = (
        alt.Chart(flagged_df)
        .mark_rule(color="#e63946", strokeDash=[4, 3], opacity=0.7)
        .encode(x="period:T")
    )
    panel1 = (series_line + series_points + cp_rules).properties(
        width=width,
        height=height_per_panel,
        title=alt.TitleParams(
            text=f"Observed series — {len(flagged_df)} flagged changepoint(s)",
            subtitle=f"red lines: MAP run length ≤ {threshold} (hazard = {result.hazard})",
        ),
    )

    # ---------- Panel 2: run-length posterior heatmap ----------
    R = np.asarray(result.run_length_posterior)
    R_shown = R[:, : max_run_length_shown + 1]
    # log10 compresses the dynamic range of the mostly tiny posterior
    # values. Clipping first keeps the argument strictly positive, so no
    # divide-by-zero can occur.
    log_R = np.log10(np.clip(R_shown, 1e-12, 1.0))
    n_t, n_r = log_R.shape
    # Long format, one row per (period, run length) cell, ordered
    # run-length-major: all periods for r=0, then all periods for r=1, ...
    # This matches the row order of log_R.T.ravel().
    heat_df = pd.DataFrame(
        {
            "period": np.tile(period_axis.to_numpy(), n_r),
            "run_length": np.repeat(np.arange(n_r), n_t),
            "log_posterior": log_R.T.ravel(),
        }
    )
    heatmap = (
        # Pre-serialised inline values bypass the max-rows check for this
        # chart only (see Notes). They carry no pandas dtypes, so every
        # field below needs an explicit encoding type.
        alt.Chart(alt.to_values(heat_df))
        .mark_rect()
        .encode(
            x=alt.X("period:T", title=None),
            y=alt.Y("run_length:O", title="run length r", sort="descending"),
            color=alt.Color(
                "log_posterior:Q",
                scale=alt.Scale(scheme="viridis", domain=[-6, 0]),
                title="log₁₀ P(r | data)",
            ),
            tooltip=[
                alt.Tooltip("period:T"),
                alt.Tooltip("run_length:Q"),
                alt.Tooltip("log_posterior:Q", format=".2f"),
            ],
        )
        .properties(
            width=width,
            height=height_per_panel,
            title="Run-length posterior P(r_t | data through t) — log₁₀ scale",
        )
    )

    # ---------- Panel 3: MAP run length over time ----------
    map_df = pd.DataFrame(
        {
            "period": period_axis,
            "map_run_length": result.map_run_length.to_numpy(dtype=int),
        }
    )
    map_line = (
        alt.Chart(map_df)
        .mark_line(strokeWidth=2, color="#1f7a3e")
        .encode(
            x=alt.X("period:T", title=None),
            y=alt.Y("map_run_length:Q", title="MAP run length"),
            tooltip=["period", "map_run_length"],
        )
    )
    map_points = (
        alt.Chart(map_df)
        .mark_point(filled=True, size=40, color="#1f7a3e")
        .encode(x="period:T", y="map_run_length:Q")
    )
    threshold_rule = (
        alt.Chart(pd.DataFrame({"y": [threshold]}))
        .mark_rule(color="#888", strokeDash=[3, 3], opacity=0.6)
        .encode(y="y:Q")
    )
    panel3 = (map_line + map_points + threshold_rule).properties(
        width=width,
        height=height_per_panel,
        title="MAP run length — visible drops mark detected changepoints",
    )

    return alt.vconcat(panel1, heatmap, panel3).resolve_scale(x="shared")  # type: ignore[no-any-return]
@@ -0,0 +1,142 @@
1
+ """Three-panel causal-impact plot (Brodersen et al. 2015 style).
2
+
3
+ - Panel 1 — observed series vs counterfactual (dashed) with CrI band
4
+ - Panel 2 — pointwise effect (observed − counterfactual) with CrI band,
5
+ zero reference line
6
+ - Panel 3 — cumulative effect with CrI band, zero reference line
7
+
8
+ Stacked vertically, sharing the time axis. The visual grammar matches
9
+ the figures from Brodersen, Gallusser, Koehler, Remy & Scott (2015).
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import TYPE_CHECKING
15
+
16
+ import pandas as pd
17
+
18
+ if TYPE_CHECKING:
19
+ import altair as alt
20
+
21
+ from ..temporal.causal_impact import CausalImpactResult
22
+
23
+
24
def causal_impact_plot(
    result: CausalImpactResult,
    *,
    width: int = 640,
    height_per_panel: int = 200,
) -> alt.Chart:
    """Three-panel observed/counterfactual + effect plot.

    Parameters
    ----------
    result
        A :class:`CausalImpactResult` from
        :meth:`TemporalTrajectory.causal_impact`.
        - ``result.table`` must provide the ``period``, ``observed``,
          ``counterfactual*``, ``pointwise*`` and ``cumulative*`` columns
          plotted below.
        - ``event_date``, ``target``, ``n_pre`` and ``level`` feed the
          event marker and the titles.
    width
        Width of each panel.
    height_per_panel
        Height of each panel. Total chart height ≈ 3× this.

    Returns
    -------
    altair.VConcatChart
        The three panels stacked vertically with a shared x (time) scale.
    """
    import altair as alt

    df = result.table.copy()
    if len(df) and isinstance(df["period"].iloc[0], pd.Period):
        df["period"] = df["period"].apply(lambda p: p.to_timestamp())

    # Normalise the event date once. pd.Timestamp() cannot convert a
    # pd.Period, so mirror the table's Period -> Timestamp handling here.
    event_date = result.event_date
    if isinstance(event_date, pd.Period):
        event_ts = event_date.to_timestamp()
    else:
        event_ts = pd.Timestamp(event_date)

    # Truncating with int(level * 100) misreports some levels: 0.29 * 100 is
    # 28.999999999999996 in binary floating point, and 0.975 loses its ".5".
    # ":g" rounds to six significant digits and drops a trailing ".0".
    level_pct = f"{result.level * 100:g}"

    # The event marker: a vertical rule at the event date.
    event_df = pd.DataFrame({"event": [event_ts]})
    event_rule = (
        alt.Chart(event_df)
        .mark_rule(color="#999", strokeDash=[3, 3], strokeWidth=1.5)
        .encode(x="event:T")
    )
    zero_rule = (
        alt.Chart(pd.DataFrame({"y": [0.0]}))
        .mark_rule(color="#666", strokeWidth=1)
        .encode(y="y:Q")
    )

    x_axis = alt.X("period:T", title=None)

    # ---------- Panel 1: observed vs counterfactual ----------
    base1 = alt.Chart(df).encode(x=x_axis)
    cf_band = base1.mark_area(opacity=0.18, color="#888").encode(
        y=alt.Y("counterfactual_lower:Q", title=f"{result.target} rate"),
        y2="counterfactual_upper:Q",
    )
    cf_line = base1.mark_line(
        color="#888", strokeDash=[5, 4], strokeWidth=2
    ).encode(y="counterfactual:Q")
    observed_line = base1.mark_line(color="#0b6e7c", strokeWidth=2.5).encode(
        y="observed:Q",
        tooltip=[
            "period",
            alt.Tooltip("observed:Q", format=".5f"),
            alt.Tooltip("counterfactual:Q", format=".5f"),
            alt.Tooltip("counterfactual_lower:Q", format=".5f"),
            alt.Tooltip("counterfactual_upper:Q", format=".5f"),
        ],
    )
    panel1 = (
        (cf_band + cf_line + observed_line + event_rule)
        .properties(
            width=width,
            height=height_per_panel,
            title=alt.TitleParams(
                text=f"Observed (teal) vs counterfactual (gray) — {result.target!r}",
                subtitle=(
                    f"event = {event_ts.date()} · "
                    f"BSTS local linear trend on {result.n_pre} pre-event periods"
                ),
            ),
        )
    )

    # ---------- Panel 2: pointwise effect ----------
    base2 = alt.Chart(df).encode(x=x_axis)
    pw_band = base2.mark_area(opacity=0.22, color="#e63946").encode(
        y=alt.Y("pointwise_lower:Q", title="pointwise effect"),
        y2="pointwise_upper:Q",
    )
    pw_line = base2.mark_line(color="#e63946", strokeWidth=2).encode(
        y="pointwise_effect:Q",
        tooltip=[
            "period",
            alt.Tooltip("pointwise_effect:Q", format=".5f"),
            alt.Tooltip("pointwise_lower:Q", format=".5f"),
            alt.Tooltip("pointwise_upper:Q", format=".5f"),
        ],
    )
    panel2 = (pw_band + pw_line + zero_rule + event_rule).properties(
        width=width,
        height=height_per_panel,
        title=alt.TitleParams(
            text="Pointwise effect (observed − counterfactual)",
            subtitle=f"{level_pct}% credible interval shaded",
        ),
    )

    # ---------- Panel 3: cumulative effect ----------
    base3 = alt.Chart(df).encode(x=x_axis)
    cum_band = base3.mark_area(opacity=0.22, color="#1f7a3e").encode(
        y=alt.Y("cumulative_lower:Q", title="cumulative effect"),
        y2="cumulative_upper:Q",
    )
    cum_line = base3.mark_line(color="#1f7a3e", strokeWidth=2).encode(
        y="cumulative_effect:Q",
        tooltip=[
            "period",
            alt.Tooltip("cumulative_effect:Q", format=".5f"),
            alt.Tooltip("cumulative_lower:Q", format=".5f"),
            alt.Tooltip("cumulative_upper:Q", format=".5f"),
        ],
    )
    panel3 = (cum_band + cum_line + zero_rule + event_rule).properties(
        width=width,
        height=height_per_panel,
        title="Cumulative effect",
    )

    return alt.vconcat(panel1, panel2, panel3).resolve_scale(x="shared")  # type: ignore[no-any-return]
@@ -0,0 +1,48 @@
1
+ """Collocation-shift visualisation — diverging horizontal bar."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import pandas as pd
8
+
9
+ if TYPE_CHECKING:
10
+ import altair as alt
11
+
12
+
13
def collocation_diverging_bar(
    df: pd.DataFrame,
    n: int = 20,
    width: int = 500,
    height: int | None = None,
) -> alt.Chart:
    """Diverging bar chart of the ``n`` largest collocate shifts.

    ``df`` needs a numeric ``shift`` column (A − B) and a ``collocate``
    column. Every column is shown in the tooltip. The ``n`` rows with the
    largest absolute shift are kept.
    - Bars with ``shift >= 0`` are A-leaning collocates (gained around the
      target in A) and are drawn blue.
    - Negative bars are B-leaning (lost from A, gained in B) and are drawn
      red.
    Bars are ordered by signed shift, so the divergence reads top to
    bottom. When ``height`` is ``None`` it grows by 18 px per bar, with a
    200 px minimum.
    """
    import altair as alt

    # Rank by magnitude, keep the top n, then order by the signed value.
    magnitude = df["shift"].abs()
    strongest = df.assign(_abs=magnitude).nlargest(n, "_abs").drop(columns="_abs")
    subset = strongest.sort_values("shift", ascending=False)

    if height is not None:
        chart_height = height
    else:
        chart_height = max(200, 18 * len(subset))

    bar_colour = alt.condition(
        alt.datum.shift >= 0,
        alt.value("#1f77b4"),
        alt.value("#d62728"),
    )
    x_enc = alt.X("shift:Q", title="Shift (A − B)")
    y_enc = alt.Y("collocate:N", sort="-x", title=None)

    bars = alt.Chart(subset).mark_bar()
    bars = bars.encode(
        x=x_enc,
        y=y_enc,
        color=bar_colour,
        tooltip=list(subset.columns),
    )
    return bars.properties(width=width, height=chart_height)  # type: ignore[no-any-return]
@@ -0,0 +1,117 @@
1
+ """Dispersion plots: *where* in a corpus does a term appear?
2
+
3
+ A term can have the same overall frequency in two corpora and be
4
+ distributed completely differently — even across one corpus, a high-
5
+ frequency word can be clustered in a few documents or evenly spread.
6
+ The dispersion plot answers "where" by marking each occurrence at the
7
+ relevant document index along a horizontal axis.
8
+
9
+ This is the classic Mosteller / Stubbs / Brezina visualisation for
10
+ "how representative is a frequency count" — companion to the
11
+ :func:`pycorpdiff.keyness.juilland_d` / ``dispersion_dp`` numerical
12
+ measures pycorpdiff already exposes.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import TYPE_CHECKING, Any
18
+
19
+ import pandas as pd
20
+
21
+ if TYPE_CHECKING:
22
+ import altair as alt
23
+
24
+ from ..corpus import Corpus, CorpusSlice
25
+
26
+
27
def dispersion_plot(
    corpus: Corpus | CorpusSlice,
    targets: str | list[str],
    width: int = 600,
    height: int | None = None,
) -> alt.Chart:
    """Visualise *where* each ``target`` appears in the corpus.

    Each occurrence of every target becomes a small vertical tick at its
    document index. Stacked rows show one target at a time, so "this term
    is concentrated in the first third of the corpus" or "this term is
    evenly spread" reads off the plot at a glance.

    Parameters
    ----------
    corpus
        A :class:`Corpus` or :class:`CorpusSlice`. Its ``tokens()`` output
        is scanned, one token sequence per document.
    targets
        Single target or a list. Each distinct target gets its own
        horizontal row, in first-seen order; repeated entries are ignored.
        Targets are matched against tokens by exact string equality.
    width
        Chart width in pixels.
    height
        Chart height. If ``None``, scales with the number of distinct
        targets (40 px each, 60 px minimum).

    Returns
    -------
    altair.Chart
        Interactive chart with per-occurrence ticks coloured by target.
        Requires the ``[viz]`` extra.

    Example
    -------
    >>> import pycorpdiff as pcd
    >>> corpus = pcd.load_hansard_sample()
    >>> chart = pcd.viz.dispersion_plot(corpus, ['criminal', 'family'])  # doctest: +SKIP
    >>> chart.save('dispersion.svg')  # doctest: +SKIP
    """
    import altair as alt

    raw_targets = [targets] if isinstance(targets, str) else list(targets)
    # De-duplicate while preserving order. Duplicates would otherwise
    # inflate the height and repeat entries in the axis sort order.
    target_list = list(dict.fromkeys(raw_targets))
    # Set membership keeps the scan O(total tokens) regardless of how many
    # targets are requested.
    wanted = set(target_list)

    docs_tokens = corpus.tokens()
    n_docs = len(docs_tokens)

    # Collect (doc_index, target) for every occurrence.
    rows: list[dict[str, Any]] = [
        {"doc_index": doc_idx, "target": token}
        for doc_idx, tokens in enumerate(docs_tokens)
        for token in tokens
        if token in wanted
    ]

    if rows:
        points = pd.DataFrame(rows, columns=["doc_index", "target"])
    else:
        # Empty corpus or no matches: emit a typed empty frame, because
        # altair cannot infer column types from an untyped empty frame.
        points = pd.DataFrame(
            {
                "doc_index": pd.Series([], dtype="int64"),
                "target": pd.Series([], dtype="object"),
            }
        )

    if height is None:
        height = max(60, 40 * len(target_list))

    chart = (
        alt.Chart(points)
        .mark_tick(thickness=1.5)
        .encode(
            x=alt.X(
                "doc_index:Q",
                title=None,
                # Fix the domain to the full document range so gaps at
                # either end of the corpus stay visible.
                scale=alt.Scale(domain=[0, max(1, n_docs - 1)]),
            ),
            y=alt.Y("target:N", title=None, sort=target_list),
            color=alt.Color("target:N", legend=None, sort=target_list),
            tooltip=[
                alt.Tooltip("doc_index:Q"),
                alt.Tooltip("target:N"),
            ],
        )
        .properties(
            width=width,
            height=height,
            title=alt.TitleParams(
                text="Dispersion plot — where each term occurs in the corpus",
                subtitle=f"{n_docs} documents on the x-axis, one tick per occurrence",
            ),
        )
    )
    return chart  # type: ignore[no-any-return]