google-ngrams 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google_ngrams/vnc.py ADDED
@@ -0,0 +1,518 @@
1
+ import numpy as np
2
+ import polars as pl
3
+ import matplotlib.pyplot as plt
4
+ from matplotlib.figure import Figure
5
+ from .scatter_helpers import gam_smoother
6
+ from .vnc_helpers import (
7
+ _linkage_matrix,
8
+ _vnc_calculate_info,
9
+ _plot_dendrogram
10
+ )
11
+
12
+
13
+ class TimeSeries:
14
+
15
def __init__(self,
             time_series: pl.DataFrame,
             time_col: str,
             values_col: str) -> None:
    """
    Validate the input table and pre-compute the clustering inputs.

    Parameters
    ----------
    time_series:
        A polars DataFrame holding one row per time point.
    time_col:
        Name of the column holding the time points.
    values_col:
        Name of the column holding the frequencies.
        Must be a Float64 or Float32 column.

    Raises
    ------
    ValueError
        If either column is missing from the DataFrame, or if the
        values column is not a float column.

    """
    time = time_series.get_column(time_col, default=None)
    values = time_series.get_column(values_col, default=None)

    # Name the offending column so the two failures can be told apart.
    if time is None:
        raise ValueError(f"""
            Invalid column.
            Check name. Couldn't find column '{time_col}' in DataFrame.
            """)
    if values is None:
        raise ValueError(f"""
            Invalid column.
            Check name. Couldn't find column '{values_col}' in DataFrame.
            """)
    if not isinstance(values.dtype, (pl.Float64, pl.Float32)):
        raise ValueError("""
            Invalid DataFrame.
            Expected a column of normalized frequencies.
            """)
    # No length check is needed: both columns come from the same
    # DataFrame, so they always have the same number of rows.

    # Order the observations chronologically before clustering.
    time_series = time_series.sort(time_col)
    self.time_intervals = time_series.get_column(time_col).to_numpy()
    self.frequencies = time_series.get_column(values_col).to_numpy()

    # One linkage matrix per distance measure: the helper's default
    # and the explicit 'cv' measure.
    self.Z_sd = _linkage_matrix(time_series=self.time_intervals,
                                frequency=self.frequencies)
    self.Z_cv = _linkage_matrix(time_series=self.time_intervals,
                                frequency=self.frequencies,
                                distance_measure='cv')

    # The third entry of every linkage row, as plain Python floats.
    self.distances_sd = np.array([row[2].item() for row in self.Z_sd])
    self.distances_cv = np.array([row[2].item() for row in self.Z_cv])

    # Cluster membership; assigned by timeviz_vnc().
    self.clusters = None
    self.distance_threshold = None
58
+
59
def timeviz_barplot(
        self,
        width=8,
        height=4,
        dpi=150,
        barwidth=4,
        fill_color='#440154',
        tick_interval=None,
        label_rotation=None
        ) -> Figure:
    """
    Generate a bar plot of token frequencies over time.

    Parameters
    ----------
    width:
        The width of the plot.
    height:
        The height of the plot.
    dpi:
        The resolution of the plot.
    barwidth:
        The width of the bars.
    fill_color:
        The color of the bars.
    tick_interval:
        Interval spacing for the tick labels. Defaults to the gap
        between the first two time points.
    label_rotation:
        Angle used to rotate tick labels. Defaults to 90.

    Returns
    -------
    Figure
        A matplotlib figure.

    """
    xx = self.time_intervals
    yy = self.frequencies

    rotation = 90 if label_rotation is None else label_rotation

    if tick_interval is not None:
        interval = tick_interval
    elif len(xx) > 1:
        interval = np.diff(xx)[0]
    else:
        # A single observation has no spacing to derive an interval from.
        interval = None

    start_value = np.min(xx)

    fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)

    ax.bar(xx, yy, color=fill_color, edgecolor='black',
           linewidth=.5, width=barwidth)

    ax.set_ylabel('Frequency (per mil. words)')

    # Despine
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.axhline(y=0, color='black', linestyle='-', linewidth=.5)

    ax.tick_params(axis="x", which="both", labelrotation=rotation)
    ax.grid(axis='y', color='w', linestyle='--', linewidth=.5)

    # A missing or zero interval cannot drive a MultipleLocator, so
    # place one tick on each observed time point instead.
    if interval is None or interval == 0:
        ax.set_xticks(xx)
        return fig

    # Attempt to use newer matplotlib MultipleLocator with offset.
    # Older versions (<3.8) don't support the 'offset' kwarg, so we
    # gracefully fall back and manually align ticks.
    try:
        ax.xaxis.set_major_locator(
            plt.MultipleLocator(base=interval, offset=start_value)
        )
    except TypeError:
        # Fallback: no offset support. Use basic locator and, if needed,
        # force tick positions to start at the first time value.
        ax.xaxis.set_major_locator(plt.MultipleLocator(interval))
        # If the first tick that would be drawn by the locator would not
        # coincide with the first bar (start_value), set explicit ticks.
        if start_value % interval != 0:
            ticks = np.arange(start_value, xx.max() + interval, interval)
            ax.set_xticks(ticks)
    except Exception:
        # Deliberate best-effort last resort: just set ticks to the
        # observed time points rather than failing the whole plot.
        ax.set_xticks(xx)

    return fig
145
+
146
def timeviz_scatterplot(
        self,
        width=8,
        height=4,
        dpi=150,
        point_color='black',
        point_size=0.5,
        smoothing=7,
        confidence_interval=True
        ) -> Figure:
    """
    Generate a scatter plot of token frequencies over time
    with a smoothed fit line and a confidence interval.

    Parameters
    ----------
    width:
        The width of the plot.
    height:
        The height of the plot.
    dpi:
        The resolution of the plot.
    point_color:
        The color of the points.
    point_size:
        The size of the points.
    smoothing:
        A value between 1 and 9 specifying magnitude of smoothing.
        Values outside that range fall back to 7.
    confidence_interval:
        Whether to plot a confidence interval.

    Returns
    -------
    Figure
        A matplotlib figure.

    """
    # Out-of-range smoothing silently falls back to the default.
    if not 0 < smoothing < 10:
        smoothing = 7

    # Evaluate the flag once so computing the band and drawing it
    # can never disagree (e.g. for numpy booleans).
    show_ci = bool(confidence_interval)

    xx = self.time_intervals
    yy = self.frequencies

    # Lightweight spline-based smoother with optional bootstrap CI
    sm = gam_smoother(xx, yy, smoothing=smoothing, ci=show_ci)
    fit_line = sm.y_fit
    upper = sm.y_upper if show_ci else None
    lower = sm.y_lower if show_ci else None

    fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)

    # plot fit line
    ax.plot(xx, fit_line, color='tomato', linewidth=.5)

    # add confidence interval, if one was requested and both bounds exist
    if show_ci and lower is not None and upper is not None:
        ax.fill_between(xx, lower, upper, color='grey', alpha=0.2)

    ax.scatter(xx, yy, s=point_size, color=point_color, alpha=0.75)
    ax.set_ylabel('Frequency (per mil. words)')

    # Despine
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # Drop negative y ticks. Work on this figure's own axes rather than
    # whatever pyplot currently considers the active axes.
    ticks = [tick for tick in ax.get_yticks() if tick >= 0]
    ax.set_yticks(ticks)

    return fig
220
+
221
def timeviz_screeplot(self,
                      width=6,
                      height=3,
                      dpi=150,
                      point_size=0.75,
                      distance="sd") -> Figure:
    """
    Draw a scree plot to help choose the number of clusters.

    Parameters
    ----------
    width:
        The width of the plot.
    height:
        The height of the plot.
    dpi:
        The resolution of the plot.
    point_size:
        The size of the points.
    distance:
        One of 'sd' (standard deviation)
        or 'cv' (coefficient of variation).
        Any other value is treated as 'sd'.

    Returns
    -------
    Figure
        A matplotlib figure.

    """
    if distance not in ('sd', 'cv'):
        distance = "sd"

    merge_distances = (
        self.distances_cv if distance == "cv" else self.distances_sd
    )

    # Stored distances are plotted in reverse order against 1..n.
    heights = merge_distances[::-1]
    positions = np.arange(1, len(heights) + 1)

    fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
    ax.scatter(x=positions,
               y=heights,
               marker='o',
               s=point_size,
               facecolors='none',
               edgecolors='black')
    ax.set_xlabel('Clusters')
    ax.set_ylabel(f'Distance (in summed {distance})')

    # Despine
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    return fig
276
+
277
def timeviz_vnc(self,
                width=6,
                height=4,
                dpi=150,
                font_size=10,
                n_periods=1,
                distance="sd",
                orientation="horizontal",
                cut_line=False,
                periodize=False,
                hide_labels=False) -> Figure:
    """
    Generate a dendrogram using the clustering method,
    "Variability-based Neighbor Clustering"(VNC),
    to identify periods in the historical development
    of P that accounts for the temporal ordering of the data.

    The cluster membership computed for the plot is stored on
    ``self.clusters``.

    Parameters
    ----------
    width:
        The width of the plot.
    height:
        The height of the plot.
    dpi:
        The resolution of the plot.
    font_size:
        The font size for the labels.
    n_periods:
        The number of periods (or clusters). A value larger than
        the number of merges resets to 1 and disables periodization.
    distance:
        One of 'sd' (standard deviation)
        or 'cv' (coefficient of variation).
        Any other value is treated as 'sd'.
    orientation:
        The orientation of the plot,
        either 'horizontal' or 'vertical'.
        Any other value is treated as 'horizontal'.
    cut_line:
        Whether or not to include a cut line;
        applies only to non-periodized plots. Forced on when
        more than one period is requested without periodization.
    periodize:
        The dendrogram can be hard to read when the original
        observation matrix from which the linkage is derived is
        large. Periodization is used to condense the dendrogram.
    hide_labels:
        Whether or not to hide leaf labels.

    Returns
    -------
    Figure
        A matplotlib figure.

    """
    if distance not in ('sd', 'cv'):
        distance = "sd"
    if orientation not in ('horizontal', 'vertical'):
        orientation = "horizontal"

    # Normalise the flags once so truthy non-bool values (for example
    # numpy booleans) behave like the bools they stand for.
    periodize = bool(periodize)
    hide_labels = bool(hide_labels)

    Z = self.Z_cv if distance == "cv" else self.Z_sd

    if n_periods > len(Z):
        n_periods = 1
        periodize = False

    # n_periods cannot exceed len(Z) past this point.
    if n_periods > 1 and not periodize:
        cut_line = True

    fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)

    horizontal = orientation == "horizontal"

    # Compute the dendrogram geometry; the periodized variant is
    # truncated and carries contraction marks.
    if periodize:
        X = _vnc_calculate_info(Z,
                                truncate=True,
                                p=n_periods,
                                contraction_marks=True,
                                labels=self.time_intervals)
    else:
        X = _vnc_calculate_info(Z,
                                p=n_periods,
                                labels=self.time_intervals)

    self.clusters = X['clusters']
    # NOTE(review): self.distance_threshold is not updated here even
    # though X['dist_threshold'] is read below - confirm whether it
    # should be.

    # Plot the corresponding dendrogram.
    # NOTE(review): no axes object is passed, so the helper presumably
    # draws on the current pyplot axes (the `ax` created above) - confirm.
    plot_args = dict(icoords=X['icoord'],
                     dcoords=X['dcoord'],
                     ivl=X['ivl'],
                     color_list=X['color_list'],
                     mh=X['mh'],
                     orientation='top' if horizontal else 'right',
                     p=X['p'],
                     n=X['n'],
                     no_labels=False)
    if periodize:
        plot_args['contraction_marks'] = X['contraction_marks']
    _plot_dendrogram(**plot_args)

    # Despine: keep only the spine that carries the distance axis.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom' if horizontal else 'left'].set_visible(False)

    distance_label = f'Distance (in summed {distance})'
    leaf_labels = X['cluster_labels'] if periodize else X['labels']

    if horizontal:
        ax.set_ylabel(distance_label)
        if hide_labels:
            ax.set_xticklabels([])
        else:
            ax.set_xticklabels(leaf_labels,
                               fontsize=font_size,
                               rotation=90)
    else:
        ax.set_xlabel(distance_label)
        if hide_labels:
            ax.set_yticklabels([])
        else:
            ax.set_yticklabels(leaf_labels,
                               fontsize=font_size,
                               rotation=0)
        # Flip the leaf axis.
        ymin, ymax = ax.get_ylim()
        ax.set_ylim(ymax, ymin)

    plt.setp(ax.collections, linewidth=.5)

    # The cut line is drawn for non-periodized plots only.
    if cut_line and not periodize and X['dist_threshold'] is not None:
        line_style = dict(color='r',
                          alpha=0.7,
                          linestyle='--',
                          linewidth=.5)
        if horizontal:
            ax.axhline(y=X['dist_threshold'], **line_style)
        else:
            ax.axvline(x=X['dist_threshold'], **line_style)

    return fig
502
+
503
def cluster_summary(self) -> None:
    """
    Print a summary of cluster membership.

    Reads ``self.clusters``. When it is None, a notice is printed
    instead of a summary.

    Returns
    -------
    Prints to the console.

    """
    if self.clusters is None:
        print("No clusters to summarize.")
        return

    for i, cluster in enumerate(self.clusters, start=1):
        # Each cluster is a mapping; only its member collections are
        # reported, one line per entry.
        for members in cluster.values():
            print(f"Cluster {i} (n={len(members)}): {[str(m) for m in members]}")  # noqa: E501