pycorpdiff 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycorpdiff/__init__.py +126 -0
- pycorpdiff/_backends/__init__.py +3 -0
- pycorpdiff/_backends/pandas.py +3 -0
- pycorpdiff/_backends/polars.py +3 -0
- pycorpdiff/collocation/__init__.py +19 -0
- pycorpdiff/collocation/cooccurrence.py +65 -0
- pycorpdiff/collocation/measures.py +102 -0
- pycorpdiff/collocation/network.py +233 -0
- pycorpdiff/collocation/shift.py +146 -0
- pycorpdiff/compare.py +345 -0
- pycorpdiff/corpus.py +411 -0
- pycorpdiff/datasets/__init__.py +27 -0
- pycorpdiff/datasets/_data/hansard_sample.parquet +0 -0
- pycorpdiff/datasets/_generate_hansard.py +221 -0
- pycorpdiff/datasets/hansard.py +235 -0
- pycorpdiff/datasets/histwords.py +221 -0
- pycorpdiff/explain.py +177 -0
- pycorpdiff/io/__init__.py +16 -0
- pycorpdiff/io/duckdb.py +92 -0
- pycorpdiff/io/huggingface.py +142 -0
- pycorpdiff/io/readers.py +138 -0
- pycorpdiff/keyness/__init__.py +26 -0
- pycorpdiff/keyness/bayes.py +50 -0
- pycorpdiff/keyness/chi_squared.py +94 -0
- pycorpdiff/keyness/correction.py +34 -0
- pycorpdiff/keyness/dispersion.py +89 -0
- pycorpdiff/keyness/effect_sizes.py +65 -0
- pycorpdiff/keyness/loglikelihood.py +92 -0
- pycorpdiff/keyness/multicorpus.py +143 -0
- pycorpdiff/keyness/permutation.py +154 -0
- pycorpdiff/py.typed +0 -0
- pycorpdiff/results.py +635 -0
- pycorpdiff/semantic/__init__.py +18 -0
- pycorpdiff/semantic/alignment.py +53 -0
- pycorpdiff/semantic/embed.py +84 -0
- pycorpdiff/semantic/shift.py +224 -0
- pycorpdiff/semantic/trajectory.py +166 -0
- pycorpdiff/stats.py +69 -0
- pycorpdiff/temporal/__init__.py +15 -0
- pycorpdiff/temporal/bocpd.py +233 -0
- pycorpdiff/temporal/causal_impact.py +293 -0
- pycorpdiff/temporal/changepoint.py +92 -0
- pycorpdiff/temporal/forecast.py +405 -0
- pycorpdiff/temporal/its.py +123 -0
- pycorpdiff/temporal/slicing.py +174 -0
- pycorpdiff/tokenize.py +110 -0
- pycorpdiff/viz/__init__.py +37 -0
- pycorpdiff/viz/bocpd.py +173 -0
- pycorpdiff/viz/causal_impact.py +142 -0
- pycorpdiff/viz/collocation.py +48 -0
- pycorpdiff/viz/dispersion.py +117 -0
- pycorpdiff/viz/forecast.py +129 -0
- pycorpdiff/viz/keyness.py +96 -0
- pycorpdiff/viz/network.py +186 -0
- pycorpdiff/viz/scattertext.py +160 -0
- pycorpdiff/viz/semantic_forecast.py +114 -0
- pycorpdiff/viz/trajectory.py +48 -0
- pycorpdiff-0.1.0a0.dist-info/METADATA +230 -0
- pycorpdiff-0.1.0a0.dist-info/RECORD +61 -0
- pycorpdiff-0.1.0a0.dist-info/WHEEL +4 -0
- pycorpdiff-0.1.0a0.dist-info/licenses/LICENSE +21 -0
pycorpdiff/tokenize.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""Tokenizer protocol and the default regex tokenizer.
|
|
2
|
+
|
|
3
|
+
The :class:`Tokenizer` protocol is the package's only extension point for
|
|
4
|
+
language-specific preprocessing. Adapters around spaCy, Stanza, jieba,
|
|
5
|
+
fugashi, etc. need to satisfy a single ``__call__(text: str) -> list[str]``
|
|
6
|
+
contract — no inheritance, no registration.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import re
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from typing import Protocol, runtime_checkable
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@runtime_checkable
class Tokenizer(Protocol):
    """Structural type for tokenizers: a callable from text to tokens.

    Any object whose ``__call__`` takes one string and returns a list of
    token strings satisfies this protocol; no subclassing or registration
    is involved. The class is ``runtime_checkable``, so
    ``isinstance(obj, Tokenizer)`` is allowed, but such a check only
    confirms that ``__call__`` exists, not its signature or return type.
    """

    def __call__(self, text: str) -> list[str]:
        """Split ``text`` into a list of token strings."""
        ...
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(frozen=True)
class RegexTokenizer:
    r"""Default tokenizer: every non-overlapping regex match is a token.

    With the default pattern ``\w+`` a token is a run of word characters
    (letters, digits, underscore). Matching is Unicode-aware, so non-Latin
    scripts are split into word-like runs as well. Language-specific
    processing (lemmatisation, CJK segmentation, multi-word expressions,
    etc.) is out of scope here; wrap spaCy/Stanza/jieba in an adapter for
    that.

    Parameters
    ----------
    pattern
        Regular expression whose matches become the tokens. Matches are
        collected with ``findall``, so a pattern that contains capturing
        groups yields the group contents rather than the whole match.
    lowercase
        When ``True`` (the default) the text is lower-cased before it is
        matched.
    """

    pattern: str = r"\w+"
    lowercase: bool = True
    # Compiled form of ``pattern``. Derived in ``__post_init__``, hence not
    # an ``__init__`` argument and excluded from ``repr`` and comparisons.
    _compiled: re.Pattern[str] = field(init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        """Compile ``pattern`` once so that each call only has to match."""
        regex = re.compile(self.pattern, re.UNICODE)
        # The dataclass is frozen, so ordinary attribute assignment would
        # raise; go through ``object.__setattr__`` instead.
        object.__setattr__(self, "_compiled", regex)

    def __call__(self, text: str) -> list[str]:
        """Return the tokens found in ``text``, in order of appearance."""
        source = text.lower() if self.lowercase else text
        return self._compiled.findall(source)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass(frozen=True)
class NgramTokenizer:
    """Tokenizer adapter that turns a base tokenizer's output into n-grams.

    The wrapped tokenizer produces the unigrams; every window of ``n``
    consecutive unigrams is then joined with ``sep`` into one string. The
    result is a flat ``list[str]``, so downstream analyses (keyness,
    dispersion, collocation, semantic shift) can treat n-grams exactly
    like ordinary terms.

    Parameters
    ----------
    base
        Tokenizer that supplies the unigrams. Defaults to a
        :class:`RegexTokenizer` with default settings.
    n
        N-gram order: ``2`` for bigrams, ``3`` for trigrams. Must be
        ``>= 1``.
    sep
        String placed between the unigrams of one n-gram. The default
        ``"_"`` follows gensim's convention and avoids the ambiguity of
        whitespace-joined n-grams when the base tokenizer strips
        whitespace.
    include_lower
        If ``True``, all lower orders are emitted too (``n=3`` gives
        unigrams, then bigrams, then trigrams), which lets one Comparison
        rank single words and multi-word units side by side.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.

    Examples
    --------
    >>> from pycorpdiff.tokenize import NgramTokenizer
    >>> tok = NgramTokenizer(n=2)
    >>> tok("the cat sat on the mat")
    ['the_cat', 'cat_sat', 'sat_on', 'on_the', 'the_mat']
    >>> NgramTokenizer(n=2, include_lower=True)("a b c")
    ['a', 'b', 'c', 'a_b', 'b_c']
    """

    base: Tokenizer = field(default_factory=RegexTokenizer)
    n: int = 2
    sep: str = "_"
    include_lower: bool = False

    def __post_init__(self) -> None:
        """Reject n-gram orders below 1."""
        if self.n < 1:
            raise ValueError(f"n must be >= 1; got {self.n}")

    def __call__(self, text: str) -> list[str]:
        """Tokenize ``text`` and return the joined n-grams.

        For ``n == 1`` the base tokenizer's output is returned unchanged.
        Otherwise the n-grams are listed order by order (lowest first when
        ``include_lower`` is set), each order in text position. Texts with
        fewer tokens than a given order contribute nothing for that order.
        """
        tokens = self.base(text)
        if self.n == 1:
            return tokens

        lowest_order = 1 if self.include_lower else self.n
        grams: list[str] = []
        for order in range(lowest_order, self.n + 1):
            if order == 1:
                grams.extend(tokens)
            else:
                # Number of start positions that still leave room for a
                # full window; zero or negative means no n-gram fits.
                window_count = len(tokens) - order + 1
                grams.extend(
                    self.sep.join(tokens[pos : pos + order])
                    for pos in range(window_count)
                )
        return grams
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Visualisation helpers — altair-first, matplotlib for paper-grade figures.
|
|
2
|
+
|
|
3
|
+
Every Result type's ``.plot()`` method delegates here. Plot functions
|
|
4
|
+
also accept a bare DataFrame so users can call
|
|
5
|
+
``pcd.viz.keyness_volcano(df)`` directly without going through a Result.
|
|
6
|
+
|
|
7
|
+
altair is an optional dependency declared in the ``viz`` extra. Each
|
|
8
|
+
plot function lazily imports altair on first call; the friendly
|
|
9
|
+
ImportError lives at that boundary.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from .bocpd import bocpd_plot
|
|
15
|
+
from .causal_impact import causal_impact_plot
|
|
16
|
+
from .collocation import collocation_diverging_bar
|
|
17
|
+
from .dispersion import dispersion_plot
|
|
18
|
+
from .forecast import forecast_plot
|
|
19
|
+
from .keyness import keyness_top_n_bar, keyness_volcano
|
|
20
|
+
from .network import network_plot
|
|
21
|
+
from .scattertext import scattertext_plot
|
|
22
|
+
from .semantic_forecast import semantic_forecast_plot
|
|
23
|
+
from .trajectory import trajectory_with_ci
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
"bocpd_plot",
|
|
27
|
+
"causal_impact_plot",
|
|
28
|
+
"collocation_diverging_bar",
|
|
29
|
+
"dispersion_plot",
|
|
30
|
+
"forecast_plot",
|
|
31
|
+
"keyness_top_n_bar",
|
|
32
|
+
"keyness_volcano",
|
|
33
|
+
"network_plot",
|
|
34
|
+
"scattertext_plot",
|
|
35
|
+
"semantic_forecast_plot",
|
|
36
|
+
"trajectory_with_ci",
|
|
37
|
+
]
|
pycorpdiff/viz/bocpd.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""BOCPD diagnostic plot — series + run-length posterior heatmap + MAP line.
|
|
2
|
+
|
|
3
|
+
Three stacked panels sharing the time axis:
|
|
4
|
+
|
|
5
|
+
1. The input series with detected changepoints flagged.
|
|
6
|
+
2. Heatmap of the run-length posterior P(r_t | data so far) on a
|
|
7
|
+
log colour scale — the canonical BOCPD diagnostic figure.
|
|
8
|
+
3. MAP run length over time. Visible drops mark changepoints.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from typing import TYPE_CHECKING
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import pandas as pd
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
import altair as alt
|
|
20
|
+
|
|
21
|
+
from ..temporal.bocpd import BocpdResult
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def bocpd_plot(
    result: BocpdResult,
    *,
    width: int = 660,
    height_per_panel: int = 180,
    max_run_length_shown: int = 40,
    threshold: int = 3,
) -> alt.Chart:
    """Build the three-panel BOCPD diagnostic chart.

    The panels are stacked vertically and share the x scale:

    1. the observed series, with a dashed red rule at every period
       returned by ``result.detected_changepoints(threshold=threshold)``;
    2. a heatmap of ``log10`` of the run-length posterior, clipped to
       ``[1e-12, 1]`` and cut off after ``max_run_length_shown``;
    3. the MAP run length over time with a dashed rule at ``threshold``.

    Parameters
    ----------
    result
        A :class:`BocpdResult`. This function reads ``series``,
        ``run_length_posterior``, ``map_run_length``, ``hazard`` and calls
        ``detected_changepoints``.
    width
        Width of each panel.
    height_per_panel
        Height of each panel.
    max_run_length_shown
        Largest run length kept in the heatmap (columns
        ``0..max_run_length_shown`` of the posterior are shown).
    threshold
        MAP-run-length threshold passed to ``detected_changepoints`` and
        drawn as the reference line in panel 3.

    Returns
    -------
    altair.Chart
        The vertically concatenated chart.

    Notes
    -----
    Calling this function turns off altair's inline-data row limit via
    ``alt.data_transformers.disable_max_rows()``. That call acts on
    altair's module-level data-transformer settings, so it is not limited
    to the chart built here.
    """
    import altair as alt

    # The heatmap needs T x (max_run_length_shown + 1) rows, which exceeds
    # altair's default 5000-row inline-data limit for longer series, so
    # the limit is switched off (see Notes: this is not chart-local).
    alt.data_transformers.disable_max_rows()

    series_colour = "#0b6e7c"
    map_colour = "#1f7a3e"

    # Period indexes are converted to timestamps so that the temporal
    # ("period:T") encodings below can be used.
    raw_index = result.series.index
    time_axis: pd.Index = pd.Index(
        raw_index.to_timestamp() if isinstance(raw_index, pd.PeriodIndex) else raw_index
    )

    observed = pd.DataFrame(
        {"period": time_axis, "value": result.series.to_numpy(dtype=float)}
    )
    changepoints = result.detected_changepoints(threshold=threshold).copy()
    if len(changepoints) and isinstance(changepoints["period"].iloc[0], pd.Period):
        changepoints["period"] = changepoints["period"].apply(
            lambda p: p.to_timestamp()
        )

    # ---------- Panel 1: series + flagged changepoints ----------
    observed_chart = alt.Chart(observed)
    value_line = observed_chart.mark_line(strokeWidth=2, color=series_colour).encode(
        x=alt.X("period:T", title=None),
        y=alt.Y("value:Q", title=None),
        tooltip=["period", alt.Tooltip("value:Q", format=".5f")],
    )
    value_dots = observed_chart.mark_point(
        filled=True, size=40, color=series_colour
    ).encode(x="period:T", y="value:Q")
    changepoint_rules = (
        alt.Chart(changepoints)
        .mark_rule(color="#e63946", strokeDash=[4, 3], opacity=0.7)
        .encode(x="period:T")
    )
    panel1_title = alt.TitleParams(
        text=f"Observed series — {len(changepoints)} flagged changepoint(s)",
        subtitle=f"red lines: MAP run length ≤ {threshold} (hazard = {result.hazard})",
    )
    panel1 = (value_line + value_dots + changepoint_rules).properties(
        width=width,
        height=height_per_panel,
        title=panel1_title,
    )

    # ---------- Panel 2: run-length posterior heatmap ----------
    posterior = np.asarray(result.run_length_posterior)
    visible = posterior[:, : max_run_length_shown + 1]
    # Posterior values span many orders of magnitude; clipping and taking
    # log10 compresses them into the fixed colour domain used below.
    with np.errstate(divide="ignore"):
        log_posterior = np.log10(np.clip(visible, 1e-12, 1.0))
    n_steps, n_runs = log_posterior.shape
    # Long format, one row per (run length, period) cell: periods cycle
    # fastest, matching the transposed ravel of the posterior.
    # NOTE(review): assumes the posterior has one row per entry of
    # ``result.series`` — confirm against BocpdResult.
    heat_df = pd.DataFrame(
        {
            "period": np.tile(time_axis.to_numpy(), n_runs),
            "run_length": np.repeat(np.arange(n_runs), n_steps),
            "log_posterior": log_posterior.T.ravel(),
        }
    )
    heat_colour = alt.Color(
        "log_posterior:Q",
        scale=alt.Scale(scheme="viridis", domain=[-6, 0]),
        title="log₁₀ P(r | data)",
    )
    heatmap = (
        alt.Chart(heat_df)
        .mark_rect()
        .encode(
            x=alt.X("period:T", title=None),
            y=alt.Y("run_length:O", title="run length r", sort="descending"),
            color=heat_colour,
            tooltip=[
                "period",
                "run_length",
                alt.Tooltip("log_posterior:Q", format=".2f"),
            ],
        )
        .properties(
            width=width,
            height=height_per_panel,
            title="Run-length posterior P(r_t | data through t) — log₁₀ scale",
        )
    )

    # ---------- Panel 3: MAP run length over time ----------
    map_df = pd.DataFrame(
        {
            "period": time_axis,
            "map_run_length": result.map_run_length.to_numpy(dtype=int),
        }
    )
    map_chart = alt.Chart(map_df)
    map_line = map_chart.mark_line(strokeWidth=2, color=map_colour).encode(
        x=alt.X("period:T", title=None),
        y=alt.Y("map_run_length:Q", title="MAP run length"),
        tooltip=["period", "map_run_length"],
    )
    map_dots = map_chart.mark_point(filled=True, size=40, color=map_colour).encode(
        x="period:T", y="map_run_length:Q"
    )
    threshold_rule = (
        alt.Chart(pd.DataFrame({"y": [threshold]}))
        .mark_rule(color="#888", strokeDash=[3, 3], opacity=0.6)
        .encode(y="y:Q")
    )
    panel3 = (map_line + map_dots + threshold_rule).properties(
        width=width,
        height=height_per_panel,
        title="MAP run length — visible drops mark detected changepoints",
    )

    stacked = alt.vconcat(panel1, heatmap, panel3)
    return stacked.resolve_scale(x="shared")  # type: ignore[no-any-return]
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""Three-panel causal-impact plot (Brodersen et al. 2015 style).
|
|
2
|
+
|
|
3
|
+
- Panel 1 — observed series vs counterfactual (dashed) with CrI band
|
|
4
|
+
- Panel 2 — pointwise effect (observed − counterfactual) with CrI band,
|
|
5
|
+
zero reference line
|
|
6
|
+
- Panel 3 — cumulative effect with CrI band, zero reference line
|
|
7
|
+
|
|
8
|
+
Stacked vertically, sharing the time axis. The visual grammar matches
|
|
9
|
+
the figures from Brodersen, Gallusser, Koehler, Remy & Scott (2015).
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from typing import TYPE_CHECKING
|
|
15
|
+
|
|
16
|
+
import pandas as pd
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
import altair as alt
|
|
20
|
+
|
|
21
|
+
from ..temporal.causal_impact import CausalImpactResult
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def causal_impact_plot(
    result: CausalImpactResult,
    *,
    width: int = 640,
    height_per_panel: int = 200,
) -> alt.Chart:
    """Three-panel observed/counterfactual + effect plot.

    The panels are stacked vertically and share the x scale:

    1. observed series against the counterfactual (dashed) and its band;
    2. pointwise effect with its band and a zero reference line;
    3. cumulative effect with its band and a zero reference line.

    Every panel carries a dashed vertical rule at ``result.event_date``.

    Parameters
    ----------
    result
        A :class:`CausalImpactResult` from
        :meth:`TemporalTrajectory.causal_impact`. This function reads
        ``table`` (columns ``period``, ``observed``, ``counterfactual``,
        ``pointwise_effect``, ``cumulative_effect`` and the matching
        ``*_lower`` / ``*_upper`` bounds), ``event_date``, ``target``,
        ``n_pre`` and ``level``.
    width
        Width of each panel.
    height_per_panel
        Height of each panel. Total chart height ≈ 3× this.

    Returns
    -------
    altair.Chart
        The vertically concatenated chart.
    """
    import altair as alt

    df = result.table.copy()
    # Period values are converted to timestamps so the temporal
    # ("period:T") encoding can be used.
    if len(df) and isinstance(df["period"].iloc[0], pd.Period):
        df["period"] = df["period"].apply(lambda p: p.to_timestamp())

    event_ts = pd.Timestamp(result.event_date)
    # ``int(level * 100)`` truncates and mislabels the interval: 0.57 * 100
    # evaluates to 56.99999999999999 (-> "56%") and 0.975 became "97%".
    # ``:g`` rounds to six significant digits and drops trailing zeros,
    # giving "57" and "97.5" while 0.95 still renders as "95".
    level_pct = f"{result.level * 100:g}"

    # The event marker: a vertical rule at result.event_date.
    event_rule = (
        alt.Chart(pd.DataFrame({"event": [event_ts]}))
        .mark_rule(color="#999", strokeDash=[3, 3], strokeWidth=1.5)
        .encode(x="event:T")
    )
    zero_rule = (
        alt.Chart(pd.DataFrame({"y": [0.0]}))
        .mark_rule(color="#666", strokeWidth=1)
        .encode(y="y:Q")
    )

    x_axis = alt.X("period:T", title=None)

    # ---------- Panel 1: observed vs counterfactual ----------
    base1 = alt.Chart(df).encode(x=x_axis)
    cf_band = base1.mark_area(opacity=0.18, color="#888").encode(
        y=alt.Y("counterfactual_lower:Q", title=f"{result.target} rate"),
        y2="counterfactual_upper:Q",
    )
    cf_line = base1.mark_line(
        color="#888", strokeDash=[5, 4], strokeWidth=2
    ).encode(y="counterfactual:Q")
    observed_line = base1.mark_line(color="#0b6e7c", strokeWidth=2.5).encode(
        y="observed:Q",
        tooltip=[
            "period",
            alt.Tooltip("observed:Q", format=".5f"),
            alt.Tooltip("counterfactual:Q", format=".5f"),
            alt.Tooltip("counterfactual_lower:Q", format=".5f"),
            alt.Tooltip("counterfactual_upper:Q", format=".5f"),
        ],
    )
    panel1 = (
        (cf_band + cf_line + observed_line + event_rule)
        .properties(
            width=width,
            height=height_per_panel,
            title=alt.TitleParams(
                text=f"Observed (teal) vs counterfactual (gray) — {result.target!r}",
                subtitle=(
                    f"event = {event_ts.date()} · "
                    f"BSTS local linear trend on {result.n_pre} pre-event periods"
                ),
            ),
        )
    )

    # ---------- Panel 2: pointwise effect ----------
    base2 = alt.Chart(df).encode(x=x_axis)
    pw_band = base2.mark_area(opacity=0.22, color="#e63946").encode(
        y=alt.Y("pointwise_lower:Q", title="pointwise effect"),
        y2="pointwise_upper:Q",
    )
    pw_line = base2.mark_line(color="#e63946", strokeWidth=2).encode(
        y="pointwise_effect:Q",
        tooltip=[
            "period",
            alt.Tooltip("pointwise_effect:Q", format=".5f"),
            alt.Tooltip("pointwise_lower:Q", format=".5f"),
            alt.Tooltip("pointwise_upper:Q", format=".5f"),
        ],
    )
    panel2 = (pw_band + pw_line + zero_rule + event_rule).properties(
        width=width,
        height=height_per_panel,
        title=alt.TitleParams(
            text="Pointwise effect (observed − counterfactual)",
            subtitle=f"{level_pct}% credible interval shaded",
        ),
    )

    # ---------- Panel 3: cumulative effect ----------
    base3 = alt.Chart(df).encode(x=x_axis)
    cum_band = base3.mark_area(opacity=0.22, color="#1f7a3e").encode(
        y=alt.Y("cumulative_lower:Q", title="cumulative effect"),
        y2="cumulative_upper:Q",
    )
    cum_line = base3.mark_line(color="#1f7a3e", strokeWidth=2).encode(
        y="cumulative_effect:Q",
        tooltip=[
            "period",
            alt.Tooltip("cumulative_effect:Q", format=".5f"),
            alt.Tooltip("cumulative_lower:Q", format=".5f"),
            alt.Tooltip("cumulative_upper:Q", format=".5f"),
        ],
    )
    panel3 = (cum_band + cum_line + zero_rule + event_rule).properties(
        width=width,
        height=height_per_panel,
        title="Cumulative effect",
    )

    return alt.vconcat(panel1, panel2, panel3).resolve_scale(x="shared")  # type: ignore[no-any-return]
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Collocation-shift visualisation — diverging horizontal bar."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
import altair as alt
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def collocation_diverging_bar(
    df: pd.DataFrame,
    n: int = 20,
    width: int = 500,
    height: int | None = None,
) -> alt.Chart:
    """Diverging horizontal bar chart of the ``n`` largest collocate shifts.

    The ``n`` rows with the largest absolute ``shift`` are kept and drawn
    as one bar per ``collocate``, ordered by signed shift. Bars with
    ``shift >= 0`` (A-leaning: gained around the target in A) are blue;
    negative bars (B-leaning: lost from A, gained in B) are red.

    Parameters
    ----------
    df
        Frame with at least a numeric ``shift`` column and a ``collocate``
        column. All columns of the kept rows appear in the tooltip.
    n
        Number of collocates to show.
    width
        Chart width in pixels.
    height
        Chart height. If ``None``, 18 px per bar with a 200 px minimum.

    Returns
    -------
    altair.Chart
        The bar chart.
    """
    import altair as alt

    # Rank on a temporary absolute-value column, then drop it again so it
    # does not leak into the chart data or the tooltip.
    ranked = df.assign(_abs=df["shift"].abs()).nlargest(n, "_abs")
    subset = ranked.drop(columns="_abs").sort_values("shift", ascending=False)

    if height is None:
        height = max(200, 18 * len(subset))

    bar_colour = alt.condition(
        alt.datum.shift >= 0,
        alt.value("#1f77b4"),
        alt.value("#d62728"),
    )
    bars = alt.Chart(subset).mark_bar().encode(
        x=alt.X("shift:Q", title="Shift (A − B)"),
        y=alt.Y("collocate:N", sort="-x", title=None),
        color=bar_colour,
        tooltip=list(subset.columns),
    )
    return bars.properties(width=width, height=height)  # type: ignore[no-any-return]
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""Dispersion plots: *where* in a corpus does a term appear?
|
|
2
|
+
|
|
3
|
+
A term can have the same overall frequency in two corpora and be
|
|
4
|
+
distributed completely differently — even across one corpus, a high-
|
|
5
|
+
frequency word can be clustered in a few documents or evenly spread.
|
|
6
|
+
The dispersion plot answers "where" by marking each occurrence at the
|
|
7
|
+
relevant document index along a horizontal axis.
|
|
8
|
+
|
|
9
|
+
This is the classic Mosteller / Stubbs / Brezina visualisation for
|
|
10
|
+
"how representative is a frequency count" — companion to the
|
|
11
|
+
:func:`pycorpdiff.keyness.juilland_d` / ``dispersion_dp`` numerical
|
|
12
|
+
measures pycorpdiff already exposes.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from typing import TYPE_CHECKING, Any
|
|
18
|
+
|
|
19
|
+
import pandas as pd
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
import altair as alt
|
|
23
|
+
|
|
24
|
+
from ..corpus import Corpus, CorpusSlice
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def dispersion_plot(
    corpus: Corpus | CorpusSlice,
    targets: str | list[str],
    width: int = 600,
    height: int | None = None,
) -> alt.Chart:
    """Visualise *where* each ``target`` appears in the corpus.

    Each occurrence of every target becomes a small vertical tick at
    its document index. Stacked rows show one target at a time, so
    "this term is concentrated in the first third of the corpus" or
    "this term is evenly spread" reads off the plot at a glance.

    Parameters
    ----------
    corpus
        A :class:`Corpus` or :class:`CorpusSlice`. Its ``tokens()`` method
        is called once; the position of each token sequence in the
        returned collection is used as the document index.
    targets
        Single target or a list. Each gets its own horizontal row.
        Repeated entries are collapsed, keeping first-seen order. Targets
        are compared to tokens by exact equality.
    width
        Chart width in pixels.
    height
        Chart height. If ``None``, scales with the number of distinct
        targets (40 px each, 60 px minimum).

    Returns
    -------
    altair.Chart
        Interactive chart with per-occurrence ticks coloured by target.
        Requires the ``[viz]`` extra.

    Example
    -------
    >>> import pycorpdiff as pcd
    >>> corpus = pcd.load_hansard_sample()
    >>> chart = pcd.viz.dispersion_plot(corpus, ['criminal', 'family'])  # doctest: +SKIP
    >>> chart.save('dispersion.svg')  # doctest: +SKIP
    """
    import altair as alt

    requested = [targets] if isinstance(targets, str) else targets
    # De-duplicate while preserving order: a repeated target still maps to
    # a single row, so it must not inflate the computed height or the sort
    # lists handed to altair.
    target_list = list(dict.fromkeys(requested))
    # Membership is tested once per token in the corpus; a frozenset makes
    # each test O(1) instead of a scan over the target list.
    wanted = frozenset(target_list)

    docs_tokens = corpus.tokens()
    n_docs = len(docs_tokens)

    # Collect (doc_index, target) for every occurrence.
    rows: list[dict[str, Any]] = []
    for doc_idx, tokens in enumerate(docs_tokens):
        rows.extend(
            {"doc_index": doc_idx, "target": token}
            for token in tokens
            if token in wanted
        )

    if rows:
        points = pd.DataFrame(rows, columns=["doc_index", "target"])
    else:
        # Empty corpus or no matches — emit a typed empty frame so
        # altair can infer column types (it can't infer from empties).
        points = pd.DataFrame(
            {
                "doc_index": pd.Series([], dtype="int64"),
                "target": pd.Series([], dtype="object"),
            }
        )

    if height is None:
        height = max(60, 40 * len(target_list))

    chart = (
        alt.Chart(points)
        .mark_tick(thickness=1.5)
        .encode(
            x=alt.X(
                "doc_index:Q",
                title=None,
                # Pin the axis to the full corpus so empty stretches stay
                # visible; max(1, ...) avoids a zero-width domain for
                # corpora with zero or one document.
                scale=alt.Scale(domain=[0, max(1, n_docs - 1)]),
            ),
            y=alt.Y("target:N", title=None, sort=target_list),
            color=alt.Color("target:N", legend=None, sort=target_list),
            tooltip=[
                alt.Tooltip("doc_index:Q"),
                alt.Tooltip("target:N"),
            ],
        )
        .properties(
            width=width,
            height=height,
            title=alt.TitleParams(
                text="Dispersion plot — where each term occurs in the corpus",
                subtitle=f"{n_docs} documents on the x-axis, one tick per occurrence",
            ),
        )
    )
    return chart  # type: ignore[no-any-return]
|