google-ngrams 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google_ngrams/vnc.py ADDED
@@ -0,0 +1,1123 @@
1
+ import copy
2
+ import numpy as np
3
+ import polars as pl
4
+ import matplotlib.pyplot as plt
5
+ from textwrap import dedent
6
+ from matplotlib.figure import Figure
7
+ from scipy.cluster import hierarchy as sch
8
+
9
+
10
def _linkage_matrix(time_series,
                    frequency,
                    distance_measure='sd'):
    """
    Build a linkage matrix by repeatedly merging adjacent time points.

    At every step the two temporally neighbouring groups whose pooled
    values have the smallest variability are merged, so only adjacent
    groups can ever be joined.

    Parameters
    ----------
    time_series : np.ndarray
        The time point of each observation. Values are used as group
        labels; callers in this module pass them sorted ascending.
    frequency : np.ndarray
        The value observed at each time point (same length as
        ``time_series``).
    distance_measure : str
        ``'sd'`` (sample standard deviation of the pooled values) or
        ``'cv'`` (standard deviation divided by the mean).

    Returns
    -------
    np.ndarray
        One row per merge with the columns ``merge1``, ``merge2``,
        cumulative ``distance`` and ``size``.

    Raises
    ------
    ValueError
        If ``distance_measure`` is not ``'sd'`` or ``'cv'``.
    """
    if distance_measure not in ("sd", "cv"):
        # Previously an unknown measure left the candidate list empty and
        # surfaced as an opaque ValueError from np.argmin.
        raise ValueError("distance_measure must be one of 'sd' or 'cv'.")

    input_values = frequency.copy()
    years = time_series.copy()

    # position_collector[k] holds the 1-based positions of the observations
    # merged at step k - 1; entry 1 is a placeholder row dropped at the end.
    position_collector = {1: 0}
    # Distance paid at each step; the leading 0.0 belongs to the placeholder.
    step_distances = [0.0]
    number_of_steps = len(input_values) - 1

    for i in range(1, number_of_steps + 1):
        unique_years = np.unique(years)
        difference_checker = []

        # Score every pair of neighbouring groups.
        for first_name, second_name in zip(unique_years[:-1],
                                           unique_years[1:]):
            pooled_sample = input_values[np.isin(years,
                                                 [first_name,
                                                  second_name])]
            if np.sum(pooled_sample) == 0:
                difference_checker.append(0)
            elif distance_measure == "sd":
                difference_checker.append(np.std(pooled_sample, ddof=1))
            else:
                difference_checker.append(
                    np.std(pooled_sample, ddof=1) / np.mean(pooled_sample)
                )

        pos_to_be_merged = np.argmin(difference_checker)
        step_distances.append(float(np.min(difference_checker)))
        lower_name = unique_years[pos_to_be_merged]
        higher_name = unique_years[pos_to_be_merged + 1]

        # Relabel both groups with their (rounded) mean time so that they
        # count as a single group in the following iterations.
        matches = np.isin(years, [lower_name, higher_name])
        new_mean_age = round(np.mean(years[matches]), 4)
        position_collector[i + 1] = np.where(matches)[0] + 1
        years[matches] = new_mean_age

    hc_build = pl.DataFrame({
        'start': [
            min(pos)
            if isinstance(pos, (list, np.ndarray))
            else pos for pos in position_collector.values()
        ],
        'end': [
            max(pos)
            if isinstance(pos, (list, np.ndarray))
            else pos for pos in position_collector.values()
        ]
    })

    idx = np.arange(len(hc_build))
    starts = hc_build['start'].to_numpy()
    ends = hc_build['end'].to_numpy()

    # For every step, the earlier steps sharing its start / end position.
    y = [np.where(starts[:i] == starts[i])[0] for i in idx]
    z = [np.where(ends[:i] == ends[i])[0] for i in idx]

    merge1 = [m.max().item() if len(m) else np.nan for m in y]
    merge2 = [m.max().item() if len(m) else np.nan for m in z]

    hc_build = (
        hc_build.with_columns([
            pl.Series('merge1', [
                min(m1, m2) if not np.isnan(m1) and
                not np.isnan(m2) else np.nan for m1, m2 in zip(merge1, merge2)
            ]),
            pl.Series('merge2', [
                max(m1, m2) if not np.isnan(m1)
                else m2 for m1, m2 in zip(merge1, merge2)
            ])
        ])
    )

    # Steps with no earlier cluster on a side join an original observation,
    # encoded as the negated 1-based position.
    hc_build = (
        hc_build.with_columns([
            pl.when(
                pl.col('merge1').is_nan() &
                pl.col('merge2').is_nan()
            ).then(-pl.col('start')
                   ).otherwise(pl.col('merge1')).alias('merge1'),
            pl.when(
                pl.col('merge2')
                .is_nan()
            ).then(-pl.col('end')
                   ).otherwise(pl.col('merge2')).alias('merge2')
        ])
    )

    bounds = hc_build.select(pl.col('start', 'end'))
    to_merge = [-np.setdiff1d(
        bounds.row(i),
        bounds.slice(1, i-1).to_numpy().flatten()
    ) for i in idx]

    to_merge = [x[0].item() if len(x) > 0 else np.nan for x in to_merge]

    hc_build = (
        hc_build
        .with_columns([
            pl.when(pl.col('merge1').is_nan()
                    ).then(pl.Series(to_merge, strict=False)
                           ).otherwise(pl.col('merge1')).alias('merge1')
        ])
    )

    hc_build = hc_build.with_row_index()
    n = hc_build.height

    # Negative entries (observations) become 0-based leaf ids; non-negative
    # entries (earlier steps) are offset by n - 1 to become cluster ids.
    hc_build = (hc_build
                .with_columns(
                    pl.when(pl.col("merge1").lt(0))
                    .then(pl.col("merge1").mul(-1).sub(1))
                    .otherwise(pl.col('merge1').add(n-1)).alias('merge1')
                    )
                .with_columns(
                    pl.when(pl.col("merge2").lt(0))
                    .then(pl.col("merge2").mul(-1).sub(1))
                    .otherwise(pl.col('merge2').add(n-1)).alias('merge2')
                    )
                )

    hc_build = (
        hc_build
        .with_columns(distance=np.array(step_distances, dtype=np.float64))
        .with_columns(pl.col("distance").cum_sum().alias("distance"))
    )

    size = np.array(
        [
            len(x) if isinstance(x, (list, np.ndarray))
            else 1 for x in position_collector.values()
        ])

    hc_build = (
        hc_build
        .with_columns(size=size)
        .with_columns(pl.col("size").cast(pl.Float64))
    )

    # Drop the placeholder row.
    hc_build = hc_build.filter(pl.col("index") != 0)

    hc = hc_build.select("merge1", "merge2", "distance", "size").to_numpy()
    return hc
197
+
198
+
199
def _contract_linkage_matrix(Z: np.ndarray,
                             p=4):
    """
    Keep only the last ``p - 1`` merges of a linkage matrix and renumber
    them so that they form a self-contained linkage over ``p`` leaves.

    Parameters
    ----------
    Z : np.ndarray
        The linkage matrix.
    p : int
        The number of clusters to retain.

    Returns
    -------
    np.ndarray
        The contracted linkage matrix with updated cluster IDs
        and member counts.
    """
    Z = Z.copy()
    truncated_Z = Z[-(p - 1):]
    n_points = Z.shape[0] + 1

    # Rebuild every node of the full tree together with its leaf members.
    clusters = []
    for leaf in range(n_points):
        clusters.append({
            "node_id": leaf,
            "left": leaf,
            "right": leaf,
            "members": [leaf],
            "distance": 0,
            "n_members": 1,
        })
    for step, row in enumerate(Z):
        left_child, right_child = int(row[0]), int(row[1])
        merged_members = sorted(
            clusters[left_child]["members"]
            + clusters[right_child]["members"]
        )
        clusters.append({
            "node_id": step + n_points,
            "left": left_child,
            "right": right_child,
            "members": merged_members,
            "distance": row[2],
            "n_members": int(row[3]),
        })

    # Nodes referenced as children by the retained merges.
    node_map = []
    for row in truncated_Z:
        wanted = (int(row[0]), int(row[1]))
        node_map.extend(c for c in clusters if c["node_id"] in wanted)

    # A referenced node containing another referenced node stays an inner
    # node; all others become leaves of the contracted tree.
    filtered_node_map = []
    superset_node_map = []
    for node in node_map:
        own_members = set(node["members"])
        contains_other = any(
            other["node_id"] != node["node_id"]
            and own_members.issuperset(other["members"])
            for other in node_map
        )
        if contains_other:
            superset_node_map.append(node)
        else:
            filtered_node_map.append(node)

    # New leaves are numbered by their first member and count as size 1.
    new_leaves = sorted(filtered_node_map, key=lambda nd: nd["members"][0])
    for new_id, node in enumerate(new_leaves):
        node["truncated_id"] = new_id
        node["n_members"] = 1

    # Inner nodes are numbered after the leaves, in original id order.
    leaf_count = len(filtered_node_map)
    inner_nodes = sorted(superset_node_map, key=lambda nd: nd["node_id"])
    for rank, node in enumerate(inner_nodes):
        node["truncated_id"] = rank + leaf_count

    # An inner node's size is the number of new leaves it contains.
    for inner in superset_node_map:
        inner_members = set(inner["members"])
        inner["n_members"] = sum(
            1 for leaf_node in filtered_node_map
            if set(leaf_node["members"]).issubset(inner_members)
        )

    id_lookup = {}
    size_lookup = {}
    for node in filtered_node_map + superset_node_map:
        id_lookup[node["node_id"]] = node["truncated_id"]
        size_lookup[node["node_id"]] = node["n_members"]

    # Rewrite the retained rows in place (rows are views of truncated_Z).
    for row in truncated_Z:
        old_left, old_right = int(row[0]), int(row[1])
        row[3] = size_lookup[old_left] + size_lookup[old_right]
        row[0] = id_lookup[old_left]
        row[1] = id_lookup[old_right]

    return truncated_Z
312
+
313
+
314
def _contraction_mark_coordinates(Z: np.ndarray,
                                  p=4):
    """
    Compute where to draw contraction marks on a contracted dendrogram.

    Every merge hidden inside one of the retained leaves yields one mark
    at that leaf's horizontal position and the merge's distance.

    Parameters
    ----------
    Z : np.ndarray
        The linkage matrix.
    p : int
        The number of clusters to retain.

    Returns
    -------
    list
        A sorted list of ``(position, distance)`` tuples, where
        ``position`` is ``10 * truncated_id + 5``.
    """
    Z = Z.copy()
    truncated_Z = Z[-(p-1):]
    n_points = Z.shape[0] + 1

    # Rebuild every node of the full tree together with its leaf members.
    clusters = []
    for leaf in range(n_points):
        clusters.append({
            "node_id": leaf,
            "left": leaf,
            "right": leaf,
            "members": [leaf],
            "distance": 0,
            "n_members": 1,
        })
    for step, row in enumerate(Z):
        left_child, right_child = int(row[0]), int(row[1])
        merged_members = sorted(
            clusters[left_child]["members"]
            + clusters[right_child]["members"]
        )
        clusters.append({
            "node_id": step + n_points,
            "left": left_child,
            "right": right_child,
            "members": merged_members,
            "distance": row[2],
            "n_members": int(row[3]),
        })

    # Nodes referenced as children by the retained merges.
    node_map = []
    for row in truncated_Z:
        wanted = (int(row[0]), int(row[1]))
        node_map.extend(c for c in clusters if c["node_id"] in wanted)

    # Split referenced nodes into contracted leaves and inner nodes.
    filtered_node_map = []
    superset_node_map = []
    for node in node_map:
        own_members = set(node["members"])
        contains_other = any(
            other["node_id"] != node["node_id"]
            and own_members.issuperset(other["members"])
            for other in node_map
        )
        if contains_other:
            superset_node_map.append(node)
        else:
            filtered_node_map.append(node)

    # Everything not referenced by the retained merges is a candidate
    # hidden merge (or an original observation).
    excluded_node_ids = {node["node_id"] for node in filtered_node_map}
    excluded_node_ids.update(node["node_id"] for node in superset_node_map)
    non_excluded_clusters = [
        c for c in clusters if c["node_id"] not in excluded_node_ids
    ]

    # Collect, per contracted leaf, the distances of the merges it hides.
    subset_clusters = []
    for leaf_node in filtered_node_map:
        leaf_members = set(leaf_node["members"])
        hidden = [
            c["distance"] for c in non_excluded_clusters
            if c["n_members"] > 1 and set(c["members"]).issubset(leaf_members)
        ]
        if hidden:
            subset_clusters.append(
                {"node_id": leaf_node["node_id"], "distance": hidden}
            )

    # Contracted leaves are numbered by their first member.
    ordered_leaves = sorted(filtered_node_map, key=lambda nd: nd["members"][0])
    for new_id, node in enumerate(ordered_leaves):
        node["truncated_id"] = new_id
    node_id_to_truncated_id = {
        node["node_id"]: node["truncated_id"] for node in filtered_node_map
    }
    for entry in subset_clusters:
        entry["truncated_id"] = node_id_to_truncated_id[entry["node_id"]]

    contraction_marks = [
        (10.0 * entry["truncated_id"] + 5.0, hidden_distance)
        for entry in subset_clusters
        for hidden_distance in entry["distance"]
    ]
    contraction_marks.sort(key=lambda mark: (mark[0], mark[1]))

    return contraction_marks
437
+
438
+
439
def _convert_linkage_to_coordinates(Z: np.ndarray):
    """
    Converts a linkage matrix to coordinates for plotting a dendrogram.

    Leaves keep their index order: leaf ``i`` sits at ``x = 10 * i + 5``
    and every merged node is centred between its two children.

    Parameters
    ----------
    Z : np.ndarray
        The linkage matrix.

    Returns
    -------
    dict
        A dictionary containing 'icoord', 'dcoord', and 'ivl'
        for plotting the dendrogram. The link lists are sorted by the
        x position of each link's left arm.
    """
    n = Z.shape[0] + 1
    ivl = list(range(n))

    # x position and height of every node; leaves sit at height 0.
    positions = {i: (i + 1) * 10 - 5 for i in range(n)}
    heights = {i: 0 for i in range(n)}

    links = []
    for node_id, row in enumerate(Z, start=n):
        child_a = int(row[0])
        child_b = int(row[1])
        dist = row[2].item()

        x1 = positions[child_a]
        x2 = positions[child_b]
        positions[node_id] = (x1 + x2) / 2

        h1 = heights[child_a]
        h2 = heights[child_b]
        heights[node_id] = dist

        # One U-shaped link: up from child_a, across, down to child_b.
        links.append(([x1, x1, x2, x2], [h1, dist, dist, h2]))

    # Stable sort by the left arm's x position.
    links.sort(key=lambda link: link[0][0])
    icoord = [xs for xs, _ in links]
    dcoord = [hs for _, hs in links]

    return {"icoord": icoord, "dcoord": dcoord, "ivl": ivl}
490
+
491
+
492
def _vnc_calculate_info(Z: np.ndarray,
                        p=None,
                        truncate=False,
                        contraction_marks=False,
                        labels=None):
    """
    Assemble everything needed to draw a VNC dendrogram.

    Parameters
    ----------
    Z : np.ndarray
        The linkage matrix.
    p : int, optional
        The number of clusters to cut the tree into. Values outside
        ``2..n`` (``n`` being the number of observations) and ``None``
        disable clustering, truncation and contraction marks.
    truncate : bool
        Whether to contract the dendrogram to ``p`` leaves.
    contraction_marks : bool
        Whether to compute contraction marks; only honoured together
        with ``truncate``.
    labels : sequence, optional
        One label per observation. Default labels ``'0'..'n-1'`` are
        used when omitted or when the length does not match.

    Returns
    -------
    dict
        Plot coordinates ('icoord', 'dcoord', 'ivl') plus 'n', 'mh', 'p',
        'labels', 'color_list', 'leaves_color_list', 'clusters',
        'cluster_labels', 'dist_threshold' and 'contraction_marks'.
    """
    Z = Z.copy()
    n = Z.shape[0] + 1

    if labels is not None and len(labels) != n:
        labels = None
        print(dedent(
            """
            Dimensions of Z and labels are not consistent.
            Using default labels.
            """))
    if labels is None:
        labels = [str(i) for i in range(n)]

    # The original condition `p is not None and p > n or p < 2` evaluated
    # `None < 2` for the default p and raised TypeError.
    if p is None or p > n or p < 2:
        p = None
        truncate = False
        contraction_marks = False

    if p is not None:
        cluster_assignment = [i.item() for i in sch.cut_tree(Z, p)]

        # Group the labels by cluster, keeping first-seen cluster order.
        cluster_dict = {}
        for label, cluster in zip(labels, cluster_assignment):
            cluster_dict.setdefault(f'cluster_{cluster + 1}', []).append(label)

        cluster_list = [{key: value} for key, value in cluster_dict.items()]

        # A cluster is labelled by its only member or "first-last".
        cluster_labels = [
            str(members[0]) if len(members) == 1
            else f"{members[0]}-{members[-1]}"
            for members in cluster_dict.values()
        ]

        # The cut line sits halfway between the last merge still performed
        # (index n - p - 1) and the first merge skipped (index n - p).
        # With p == n no merge is performed, so the lower bound is 0
        # rather than the wrapped-around dist[-1].
        dist = Z[:, 2].tolist()
        upper = dist[n - p]
        lower = dist[n - p - 1] if p < n else 0.0
        dist_threshold = np.mean([upper, lower])
    else:
        dist_threshold = None
        cluster_list = None
        cluster_labels = None

    if truncate is True:
        truncated_Z = _contract_linkage_matrix(Z, p=p)

        if contraction_marks is True:
            contraction_marks = _contraction_mark_coordinates(Z, p=p)
        else:
            contraction_marks = None

        Z = truncated_Z
    else:
        contraction_marks = None

    R = _convert_linkage_to_coordinates(Z)

    Zn = Z.shape[0] + 1
    R['n'] = Zn
    R['mh'] = np.max(Z[:, 2])
    R['p'] = p
    R['labels'] = labels
    R['color_list'] = ['k'] * (Zn - 1)
    R['leaves_color_list'] = ['k'] * Zn
    R['clusters'] = cluster_list
    R['cluster_labels'] = cluster_labels
    R['dist_threshold'] = dist_threshold
    R["contraction_marks"] = contraction_marks

    return R
587
+
588
+
589
def _lowess(x,
            y,
            f=1./3.):
    """
    Locally weighted linear smoother returning fit and standard error.

    Notes
    -----
    - No robustifying iterations.
    - Local fits are straight lines (degree 1) with tricube weights.

    Parameters
    ----------
    x : np.ndarray
        Positions of the observations; need not be sorted.
    y : np.ndarray
        Observed values.
    f : float
        Fraction of the range of ``x`` used as the kernel half-width.

    Returns
    -------
    tuple
        ``(y_sm, y_stderr)``, both aligned with the input order and
        allocated with ``np.zeros_like(y)``.
    """
    def tricube(d):
        # Tricube kernel, clipped so points beyond the window get 0.
        return np.clip((1 - np.abs(d)**3)**3, 0, 1)

    # Kernel width in units of x.
    xwidth = f*(x.max()-x.min())
    N = len(x)

    # Work on sorted copies; results are written back by original index.
    order = np.argsort(x)
    x_sorted = x[order]
    y_sorted = y[order]

    y_sm = np.zeros_like(y)
    y_stderr = np.zeros_like(y)

    for i, place in enumerate(order):
        w = tricube(np.abs((x_sorted[i]-x_sorted))/xwidth)

        # Weighted least squares for intercept and slope around point i.
        A = np.stack([w, x_sorted*w]).T
        b = w * y_sorted
        ATA = A.T.dot(A)
        ATb = A.T.dot(b)
        sol = np.linalg.solve(ATA, ATb)

        # Fitted value at point i only.
        y_sm[place] = A[i].dot(sol)

        sigma2 = (np.sum((A.dot(sol) - y_sorted)**2)/N)
        y_stderr[place] = np.sqrt(sigma2 *
                                  A[i].dot(np.linalg.inv(ATA)
                                           ).dot(A[i]))
    return y_sm, y_stderr
634
+
635
+
636
class TimeSeries:
    """
    A frequency time series with helpers for plotting it and for
    periodizing it with Variability-based Neighbor Clustering (VNC).

    Linkage matrices for both distance measures ('sd' and 'cv') are
    computed once at construction time.
    """

    def __init__(self,
                 time_series: pl.DataFrame,
                 time_col: str,
                 values_col: str):
        """
        Parameters
        ----------
        time_series:
            A polars DataFrame holding the series.
        time_col:
            Name of the column with the time points.
        values_col:
            Name of the column with the (floating point) frequencies.

        Raises
        ------
        ValueError
            If a column is missing or the values column is not
            Float32/Float64.
        """
        time = time_series.get_column(time_col, default=None)
        values = time_series.get_column(values_col, default=None)

        if time is None:
            raise ValueError("""
                Invalid column.
                Check name. Couldn't find column in DataFrame.
                """)
        if values is None:
            raise ValueError("""
                Invalid column.
                Check name. Couldn't find column in DataFrame.
                """)
        if not isinstance(values.dtype, (pl.Float64, pl.Float32)):
            raise ValueError("""
                Invalid DataFrame.
                Expected a column of normalized frequencies.
                """)
        if len(time) != len(values):
            raise ValueError("""
                Your time and values vectors must be the same length.
                """)

        # Sort by column name so both arrays below share the same order.
        time_series = time_series.sort(time_col)
        self.time_intervals = time_series.get_column(time_col).to_numpy()
        self.frequencies = time_series.get_column(values_col).to_numpy()
        self.Z_sd = _linkage_matrix(time_series=self.time_intervals,
                                    frequency=self.frequencies)
        self.Z_cv = _linkage_matrix(time_series=self.time_intervals,
                                    frequency=self.frequencies,
                                    distance_measure='cv')
        # Third linkage column: cumulative merge distance.
        self.distances_sd = self.Z_sd[:, 2].copy()
        self.distances_cv = self.Z_cv[:, 2].copy()

        # Filled in by timeviz_vnc.
        self.clusters = None
        self.distance_threshold = None

    def timeviz_barplot(self,
                        width=8,
                        height=4,
                        dpi=150,
                        barwidth=4,
                        fill_color='#440154',
                        tick_interval=None,
                        label_rotation=None) -> Figure:
        """
        Generate a bar plot of token frequencies over time.

        Parameters
        ----------
        width:
            The width of the plot.
        height:
            The height of the plot.
        dpi:
            The resolution of the plot.
        barwidth:
            The width of the bars.
        fill_color:
            The color of the bars.
        tick_interval:
            Interval spacing for the tick labels. Defaults to the gap
            between the first two time points.
        label_rotation:
            Angle used to rotate tick labels. Defaults to 90.

        Returns
        -------
        Figure
            A matplotlib figure.

        """
        xx = self.time_intervals
        yy = self.frequencies

        rotation = 90 if label_rotation is None else label_rotation

        if tick_interval is not None:
            interval = tick_interval
        elif len(xx) > 1:
            interval = np.diff(xx)[0]
        else:
            # A single observation has no spacing to derive a step from.
            interval = 1

        start_value = np.min(xx)

        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)

        ax.bar(xx, yy, color=fill_color, edgecolor='black',
               linewidth=.5, width=barwidth)

        # Despine
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.axhline(y=0, color='black', linestyle='-', linewidth=.5)

        ax.tick_params(axis="x", which="both", labelrotation=rotation)
        ax.grid(axis='y', color='w', linestyle='--', linewidth=.5)
        ax.xaxis.set_major_locator(plt.MultipleLocator(base=interval,
                                                       offset=start_value))

        return fig

    def timeviz_scatterplot(self,
                            width=8,
                            height=4,
                            dpi=150,
                            point_color='black',
                            point_size=0.5,
                            ci='standard') -> Figure:
        """
        Generate a scatter plot of token frequencies over time
        with a smoothed fit line and a confidence interval.

        Parameters
        ----------
        width:
            The width of the plot.
        height:
            The height of the plot.
        dpi:
            The resolution of the plot.
        point_color:
            The color of the points.
        point_size:
            The size of the points.
        ci:
            The confidence interval. One of "standard" (95%),
            "strict" (97.5%) or "both". Any other value falls back
            to "standard".

        Returns
        -------
        Figure
            A matplotlib figure.

        """
        if ci not in ('standard', 'strict', 'both'):
            ci = "standard"

        xx = self.time_intervals
        yy = self.frequencies

        order = np.argsort(xx)

        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)

        y_sm, y_std = _lowess(xx, yy, f=1./5.)
        ax.plot(xx[order], y_sm[order],
                color='tomato', linewidth=.5, label='LOWESS')

        # (multiplier of the standard error, legend label) per band.
        # NOTE(review): the "97.5" band is drawn at +/- 1 standard error,
        # which is narrower than the 95% band; confirm whether the label
        # or the multiplier reflects the intent.
        bands = []
        if ci in ('standard', 'both'):
            bands.append((1.96, '95 uncertainty'))
        if ci in ('strict', 'both'):
            bands.append((1.0, '97.5 uncertainty'))
        for multiplier, label in bands:
            ax.fill_between(
                xx[order], y_sm[order] - multiplier*y_std[order],
                y_sm[order] + multiplier*y_std[order], alpha=0.3,
                label=label)

        ax.scatter(xx, yy, s=point_size, color=point_color, alpha=0.75)

        # Despine
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        # Drop negative y ticks; operate on this figure's axes rather
        # than on whatever pyplot considers current.
        ticks = [tick for tick in ax.get_yticks() if tick >= 0]
        ax.set_yticks(ticks)

        return fig

    def timeviz_screeplot(self,
                          width=6,
                          height=3,
                          dpi=150,
                          point_size=0.75,
                          distance="sd") -> Figure:
        """
        Generate a scree plot for determining clusters.

        Parameters
        ----------
        width:
            The width of the plot.
        height:
            The height of the plot.
        dpi:
            The resolution of the plot.
        point_size:
            The size of the points.
        distance:
            One of 'sd' (standard deviation)
            or 'cv' (coefficient of variation). Any other value
            falls back to 'sd'.

        Returns
        -------
        Figure
            A matplotlib figure.

        """
        if distance not in ('sd', 'cv'):
            distance = "sd"

        dist = self.distances_cv if distance == "cv" else self.distances_sd

        # Largest merge distance first.
        yy = dist[::-1]
        xx = np.arange(1, len(yy) + 1)
        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
        ax.scatter(x=xx,
                   y=yy,
                   marker='o',
                   s=point_size,
                   facecolors='none',
                   edgecolors='black')
        ax.set_xlabel('Clusters')
        ax.set_ylabel(f'Distance (in summed {distance})')

        # Despine
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        return fig

    def timeviz_vnc(self,
                    width=6,
                    height=4,
                    dpi=150,
                    font_size=10,
                    n_periods=1,
                    distance="sd",
                    orientation="horizontal",
                    cut_line=False,
                    periodize=False,
                    hide_labels=False) -> Figure:
        """
        Generate a dendrogram using the clustering method,
        "Variability-based Neighbor Clustering"(VNC),
        to identify periods in the historical development
        of P that accounts for the temporal ordering of the data.

        Also stores the resulting cluster membership in ``self.clusters``
        and the cut height in ``self.distance_threshold``.

        Parameters
        ----------
        width:
            The width of the plot.
        height:
            The height of the plot.
        dpi:
            The resolution of the plot.
        font_size:
            The font size for the labels.
        n_periods:
            The number of periods (or clusters).
        distance:
            One of 'sd' (standard deviation)
            or 'cv' (coefficient of variation).
        orientation:
            The orientation of the plot,
            either "horizontal" or "vertical".
        cut_line:
            Whether or not to include a cut line;
            applies only to non-periodized plots.
        periodize:
            The dendrogram can be hard to read when the original
            observation matrix from which the linkage is derived is
            large. Periodization is used to condense the dendrogram.
            Requires ``n_periods`` of at least 2.
        hide_labels:
            Whether or not to hide leaf labels.

        Returns
        -------
        Figure
            A matplotlib figure.

        """
        if distance not in ('sd', 'cv'):
            distance = "sd"
        if orientation not in ('horizontal', 'vertical'):
            orientation = "horizontal"

        Z = self.Z_cv if distance == "cv" else self.Z_sd

        if n_periods > len(Z):
            n_periods = 1
            periodize = False

        # Only an explicit True periodizes; with fewer than two periods
        # there are no cluster labels to put on a contracted tree.
        periodize = periodize is True and n_periods >= 2

        # NOTE(review): a cut line is forced on whenever more than one
        # period is requested on a non-periodized plot, so cut_line=False
        # cannot suppress it. Kept as-is to preserve existing output.
        if n_periods > 1 and n_periods <= len(Z) and not periodize:
            cut_line = True

        vertical = orientation == "vertical"

        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)

        X = _vnc_calculate_info(Z,
                                truncate=periodize,
                                p=n_periods,
                                contraction_marks=periodize,
                                labels=self.time_intervals)

        self.clusters = X['clusters']
        self.distance_threshold = X['dist_threshold']

        # contraction_marks is None for non-periodized plots.
        sch._plot_dendrogram(icoords=X['icoord'],
                             dcoords=X['dcoord'],
                             ivl=X['ivl'],
                             color_list=X['color_list'],
                             mh=X['mh'],
                             orientation='right' if vertical else 'top',
                             p=X['p'],
                             n=X['n'],
                             no_labels=False,
                             contraction_marks=X['contraction_marks'])

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left' if vertical else 'bottom'].set_visible(False)

        distance_label = f'Distance (in summed {distance})'
        leaf_labels = X['cluster_labels'] if periodize else X['labels']

        if vertical:
            ax.set_xlabel(distance_label)
            if hide_labels is not True:
                ax.set_yticklabels(leaf_labels,
                                   fontsize=font_size,
                                   rotation=0)
            else:
                ax.set_yticklabels([])
            # Flip the leaf axis so the first leaf is drawn at the top.
            ymin, ymax = ax.get_ylim()
            ax.set_ylim(ymax, ymin)
        else:
            ax.set_ylabel(distance_label)
            if hide_labels is not True:
                ax.set_xticklabels(leaf_labels,
                                   fontsize=font_size,
                                   rotation=90)
            else:
                ax.set_xticklabels([])

        plt.setp(ax.collections, linewidth=.5)

        if not periodize and cut_line and X['dist_threshold'] is not None:
            draw_cut = ax.axvline if vertical else ax.axhline
            draw_cut(X['dist_threshold'],
                     color='r',
                     alpha=0.7,
                     linestyle='--',
                     linewidth=.5)

        return fig

    def cluster_summary(self):
        """
        Print a summary of cluster membership.

        Uses the clusters stored by the most recent ``timeviz_vnc`` call.

        Returns
        -------
        Prints to the console.

        """
        cluster_list = self.clusters
        if cluster_list is None:
            print("No clusters to summarize.")
            return

        for i, cluster in enumerate(cluster_list, start=1):
            for value in cluster.values():
                print(f"Cluster {i} (n={len(value)}): {[str(v) for v in value]}")  # noqa: E501