google-ngrams 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_ngrams/__init__.py +16 -0
- google_ngrams/data/__init__.py +14 -0
- google_ngrams/data/googlebooks_eng_all_totalcounts_20120701.parquet +0 -0
- google_ngrams/data/googlebooks_eng_gb_all_totalcounts_20120701.parquet +0 -0
- google_ngrams/data/googlebooks_eng_us_all_totalcounts_20120701.parquet +0 -0
- google_ngrams/ngrams.py +209 -0
- google_ngrams/vnc.py +1123 -0
- google_ngrams-0.1.0.dist-info/LICENSE +201 -0
- google_ngrams-0.1.0.dist-info/METADATA +126 -0
- google_ngrams-0.1.0.dist-info/RECORD +12 -0
- google_ngrams-0.1.0.dist-info/WHEEL +6 -0
- google_ngrams-0.1.0.dist-info/top_level.txt +1 -0
google_ngrams/vnc.py
ADDED
@@ -0,0 +1,1123 @@
import copy
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
from textwrap import dedent
from matplotlib.figure import Figure
from scipy.cluster import hierarchy as sch


def _linkage_matrix(time_series,
                    frequency,
                    distance_measure='sd'):
    """
    Iteratively merge temporally adjacent intervals (the VNC procedure)
    and encode the merge history as a scipy-style linkage matrix with
    columns [cluster 1, cluster 2, cumulative distance, cluster size].
    """
    input_values = frequency.copy()
    years = time_series.copy()

    data_collector = {}
    data_collector["0"] = input_values
    position_collector = {}
    position_collector[1] = 0
    overall_distance = 0
    number_of_steps = len(input_values) - 1

    for i in range(1, number_of_steps + 1):
        difference_checker = []
        unique_years = np.unique(years)

        for j in range(len(unique_years) - 1):
            first_name = unique_years[j]
            second_name = unique_years[j + 1]
            pooled_sample = input_values[np.isin(years,
                                                 [first_name,
                                                  second_name])]

            if distance_measure == "sd":
                difference_checker.append(0 if np.sum(pooled_sample) == 0
                                          else np.std(pooled_sample, ddof=1))
            elif distance_measure == "cv":
                difference_checker.append(
                    0 if np.sum(pooled_sample) == 0
                    else np.std(pooled_sample, ddof=1) / np.mean(pooled_sample)
                )

        # merge the adjacent pair with the smallest distance
        pos_to_be_merged = np.argmin(difference_checker)
        distance = np.min(difference_checker)
        overall_distance += distance
        lower_name = unique_years[pos_to_be_merged]
        higher_name = unique_years[pos_to_be_merged + 1]

        matches = np.isin(years, [lower_name, higher_name])
        new_mean_age = round(np.mean(years[matches]), 4)
        position_collector[i + 1] = np.where(matches)[0] + 1
        years[matches] = new_mean_age
        data_collector[f"{i}: {distance}"] = input_values

    hc_build = pl.DataFrame({
        'start': [
            min(pos)
            if isinstance(pos, (list, np.ndarray))
            else pos for pos in position_collector.values()
        ],
        'end': [
            max(pos)
            if isinstance(pos, (list, np.ndarray))
            else pos for pos in position_collector.values()
        ]
    })

    idx = np.arange(len(hc_build))

    y = [np.where(
        hc_build['start'].to_numpy()[:i] == hc_build['start'].to_numpy()[i]
    )[0] for i in idx]
    z = [np.where(
        hc_build['end'].to_numpy()[:i] == hc_build['end'].to_numpy()[i]
    )[0] for i in idx]

    merge1 = [
        y[i].max().item() if len(y[i]) else np.nan for i in range(len(y))
    ]
    merge2 = [
        z[i].max().item() if len(z[i]) else np.nan for i in range(len(z))
    ]

    hc_build = (
        hc_build.with_columns([
            pl.Series('merge1', [
                min(m1, m2) if not np.isnan(m1) and
                not np.isnan(m2) else np.nan for m1, m2 in zip(merge1, merge2)
            ]),
            pl.Series('merge2', [
                max(m1, m2) if not np.isnan(m1)
                else m2 for m1, m2 in zip(merge1, merge2)
            ])
        ])
    )

    hc_build = (
        hc_build.with_columns([
            pl.when(
                pl.col('merge1').is_nan() &
                pl.col('merge2').is_nan()
            ).then(-pl.col('start')
                   ).otherwise(pl.col('merge1')).alias('merge1'),
            pl.when(
                pl.col('merge2').is_nan()
            ).then(-pl.col('end')
                   ).otherwise(pl.col('merge2')).alias('merge2')
        ])
    )

    to_merge = [-np.setdiff1d(
        hc_build.select(
            pl.col('start', 'end')
        ).row(i),
        hc_build.select(
            pl.col('start', 'end')
        ).slice(1, i-1).to_numpy().flatten()
    ) for i in idx]

    to_merge = [x[0].item() if len(x) > 0 else np.nan for x in to_merge]

    hc_build = (
        hc_build
        .with_columns([
            pl.when(pl.col('merge1').is_nan()
                    ).then(pl.Series(to_merge, strict=False)
                           ).otherwise(pl.col('merge1')).alias('merge1')
        ])
    )

    hc_build = hc_build.with_row_index()
    n = hc_build.height

    # re-encode cluster ids: negatives become 0-based leaves,
    # the rest become internal nodes offset by n - 1
    hc_build = (hc_build
                .with_columns(
                    pl.when(pl.col("merge1").lt(0))
                    .then(pl.col("merge1").mul(-1).sub(1))
                    .otherwise(pl.col('merge1').add(n-1)).alias('merge1')
                )
                .with_columns(
                    pl.when(pl.col("merge2").lt(0))
                    .then(pl.col("merge2").mul(-1).sub(1))
                    .otherwise(pl.col('merge2').add(n-1)).alias('merge2')
                )
                )

    hc_build = (
        hc_build
        .with_columns(distance=np.array(list(data_collector.keys())))
        .with_columns(pl.col("distance").str.replace(r"(\d+: )", ""))
        .with_columns(pl.col("distance").cast(pl.Float64))
        .with_columns(pl.col("distance").cum_sum().alias("distance"))
    )

    size = np.array(
        [
            len(x) if isinstance(x, (list, np.ndarray))
            else 1 for x in position_collector.values()
        ])

    hc_build = (
        hc_build
        .with_columns(size=size)
        .with_columns(pl.col("size").cast(pl.Float64))
    )

    hc_build = hc_build.filter(pl.col("index") != 0)

    hc = hc_build.select("merge1", "merge2", "distance", "size").to_numpy()
    return hc
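
As a concrete reference, here is a minimal sketch of how `_linkage_matrix` could be exercised. The helper is module-private, so the import reaches into `google_ngrams.vnc` directly; the decade grid and frequencies are synthetic, illustrative values, not data from the package:

import numpy as np

from google_ngrams.vnc import _linkage_matrix

# Ten synthetic decades with made-up normalized frequencies.
years = np.arange(1900.0, 2000.0, 10.0)
freqs = np.array([0.10, 0.12, 0.11, 0.35, 0.36,
                  0.34, 0.33, 0.80, 0.85, 0.82])

# VNC only ever merges temporally adjacent groups, so the result is a
# scipy-style linkage matrix with one row per merge step: shape
# (n - 1, 4) with columns [id 1, id 2, cumulative distance, size].
Z = _linkage_matrix(years, freqs, distance_measure='sd')
print(Z.shape)  # (9, 4)
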
def _contract_linkage_matrix(Z: np.ndarray,
                             p=4):
    """
    Contracts the linkage matrix by reducing the number of clusters
    to a specified number.

    Parameters
    ----------
    Z : np.ndarray
        The linkage matrix.
    p : int
        The number of clusters to retain.

    Returns
    -------
    np.ndarray
        The contracted linkage matrix with updated cluster IDs
        and member counts.
    """
    Z = Z.copy()
    truncated_Z = Z[-(p - 1):]

    n_points = Z.shape[0] + 1
    clusters = [
        dict(node_id=i, left=i, right=i, members=[i], distance=0, n_members=1)
        for i in range(n_points)
    ]
    for z_i in range(Z.shape[0]):
        row = Z[z_i]
        left = int(row[0])
        right = int(row[1])
        cluster = dict(
            node_id=z_i + n_points,
            left=left,
            right=right,
            members=[],
            distance=row[2],
            n_members=int(row[3])
        )
        cluster["members"].extend(copy.deepcopy(clusters[left]["members"]))
        cluster["members"].extend(copy.deepcopy(clusters[right]["members"]))
        cluster["members"].sort()
        clusters.append(cluster)

    node_map = []
    for i in range(truncated_Z.shape[0]):
        node_ids = [int(truncated_Z[i, 0]), int(truncated_Z[i, 1])]
        for cluster in clusters:
            if cluster['node_id'] in node_ids:
                node_map.append(cluster)

    filtered_node_map = []
    superset_node_map = []

    for node in node_map:
        is_superset = False
        for other_node in node_map:
            if (
                node != other_node
                and set(
                    node['members']
                ).issuperset(set(other_node['members']))
            ):
                is_superset = True
                break
        if is_superset:
            superset_node_map.append(node)
        else:
            filtered_node_map.append(node)

    # Add 'truncated_id' to each dictionary in filtered_node_map
    for idx, node in enumerate(
        sorted(filtered_node_map, key=lambda x: x['members'][0])
    ):
        node['truncated_id'] = idx
        node['n_members'] = 1

    for idx, node in enumerate(
        sorted(superset_node_map, key=lambda x: x['node_id'])
    ):
        node['truncated_id'] = idx + len(filtered_node_map)

    # Adjust 'n_members' in superset_node_map to reflect
    # the number of filtered_node_map['members'] sets they contain
    for superset_node in superset_node_map:
        count = 0
        for filtered_node in filtered_node_map:
            if set(
                filtered_node['members']
            ).issubset(set(superset_node['members'])):
                count += 1
        superset_node['n_members'] = count

    # Create a mapping from node_id to truncated_id and n_members
    node_id_to_truncated_id = {
        node['node_id']: node['truncated_id']
        for node in filtered_node_map + superset_node_map
    }
    node_id_to_n_members = {
        node['node_id']: node['n_members']
        for node in filtered_node_map + superset_node_map
    }

    # Replace values in truncated_Z
    for i in range(truncated_Z.shape[0]):
        truncated_Z[i, 3] = (
            node_id_to_n_members[int(truncated_Z[i, 0])] +
            node_id_to_n_members[int(truncated_Z[i, 1])]
        )
        truncated_Z[i, 0] = node_id_to_truncated_id[int(truncated_Z[i, 0])]
        truncated_Z[i, 1] = node_id_to_truncated_id[int(truncated_Z[i, 1])]

    return truncated_Z
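
A sketch of the contraction step in isolation, again on synthetic data (same illustrative arrays as above):

import numpy as np

from google_ngrams.vnc import _contract_linkage_matrix, _linkage_matrix

years = np.arange(1900.0, 2000.0, 10.0)
freqs = np.array([0.10, 0.12, 0.11, 0.35, 0.36,
                  0.34, 0.33, 0.80, 0.85, 0.82])
Z = _linkage_matrix(years, freqs)

# Keep only the last p - 1 merges; the retained nodes are renumbered
# 0..p-1 and the size column counts contracted leaves, not data points.
truncated = _contract_linkage_matrix(Z, p=4)
print(truncated.shape)  # (3, 4)
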
def _contraction_mark_coordinates(Z: np.ndarray,
                                  p=4):
    """
    Generates contraction marks for a given linkage matrix.

    Parameters
    ----------
    Z : np.ndarray
        The linkage matrix.
    p : int
        The number of clusters to retain.

    Returns
    -------
    list
        A sorted list of tuples where each tuple contains
        a calculated value based on truncated_id and a distance value.
    """
    Z = Z.copy()
    truncated_Z = Z[-(p-1):]

    n_points = Z.shape[0] + 1
    clusters = [dict(node_id=i,
                     left=i,
                     right=i,
                     members=[i],
                     distance=0,
                     n_members=1) for i in range(n_points)]
    for z_i in range(Z.shape[0]):
        row = Z[z_i]
        left = int(row[0])
        right = int(row[1])
        cluster = dict(
            node_id=z_i + n_points,
            left=left, right=right,
            members=[],
            distance=row[2],
            n_members=int(row[3])
        )
        cluster["members"].extend(copy.deepcopy(clusters[left]["members"]))
        cluster["members"].extend(copy.deepcopy(clusters[right]["members"]))
        cluster["members"].sort()
        clusters.append(cluster)

    node_map = []
    for i in range(truncated_Z.shape[0]):
        node_ids = [int(truncated_Z[i, 0]), int(truncated_Z[i, 1])]
        for cluster in clusters:
            if cluster['node_id'] in node_ids:
                node_map.append(cluster)

    filtered_node_map = []
    superset_node_map = []

    for node in node_map:
        is_superset = False
        for other_node in node_map:
            if (node != other_node
                    and set(node['members']
                            ).issuperset(set(other_node['members']))):
                is_superset = True
                break
        if is_superset:
            superset_node_map.append(node)
        else:
            filtered_node_map.append(node)

    # Create a set of node_ids from filtered_node_map and superset_node_map
    excluded_node_ids = set(
        node['node_id'] for node in filtered_node_map
    ).union(node['node_id'] for node in superset_node_map)

    # Filter clusters that are not in excluded_node_ids
    non_excluded_clusters = [
        cluster for cluster in clusters
        if cluster['node_id'] not in excluded_node_ids
    ]

    # Create a list to store the result
    subset_clusters = []

    # Iterate over filtered_node_map
    for filtered_cluster in filtered_node_map:
        distances = []
        for cluster in non_excluded_clusters:
            if (
                cluster['n_members'] > 1
                and set(cluster['members']
                        ).issubset(set(filtered_cluster['members']))):
                distances.append(cluster['distance'])
        if distances:
            subset_clusters.append(
                {'node_id': filtered_cluster['node_id'], 'distance': distances}
            )

    # Add 'truncated_id' to each dictionary in filtered_node_map
    for idx, node in enumerate(
        sorted(filtered_node_map, key=lambda x: x['members'][0])
    ):
        node['truncated_id'] = idx

    # Create a mapping from node_id to truncated_id
    node_id_to_truncated_id = {
        node['node_id']: node['truncated_id'] for node in filtered_node_map
    }

    # Add 'truncated_id' to each dictionary in subset_clusters
    for cluster in subset_clusters:
        cluster['truncated_id'] = node_id_to_truncated_id[cluster['node_id']]

    # Create a list of tuples
    contraction_marks = []

    # Iterate over subset_clusters
    for cluster in subset_clusters:
        truncated_id = cluster['truncated_id']
        for distance in cluster['distance']:
            contraction_marks.append((10.0 * truncated_id + 5.0, distance))

    # Sort the list of tuples
    contraction_marks = sorted(contraction_marks, key=lambda x: (x[0], x[1]))

    return contraction_marks
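
A sketch of how the marks could be inspected (same synthetic series; the exact tuples depend on the data):

import numpy as np

from google_ngrams.vnc import (_contraction_mark_coordinates,
                               _linkage_matrix)

years = np.arange(1900.0, 2000.0, 10.0)
freqs = np.array([0.10, 0.12, 0.11, 0.35, 0.36,
                  0.34, 0.33, 0.80, 0.85, 0.82])
Z = _linkage_matrix(years, freqs)

# Each mark is (10.0 * truncated_id + 5.0, distance): the x center of a
# contracted leaf and the height of a merge hidden inside it.
marks = _contraction_mark_coordinates(Z, p=4)
print(marks)
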
def _convert_linkage_to_coordinates(Z: np.ndarray):
    """
    Converts a linkage matrix to coordinates for plotting a dendrogram.

    Parameters
    ----------
    Z : np.ndarray
        The linkage matrix.

    Returns
    -------
    dict
        A dictionary containing 'icoord', 'dcoord', and 'ivl'
        for plotting the dendrogram.
    """
    ivl = [i for i in range(Z.shape[0] + 1)]
    n = len(ivl)
    icoord = []
    dcoord = []
    clusters = {i: [i] for i in range(n)}
    current_index = n
    positions = {i: (i + 1) * 10 - 5 for i in range(n)}
    heights = {i: 0 for i in range(n)}

    for i in range(len(Z)):
        cluster1 = int(Z[i, 0])
        cluster2 = int(Z[i, 1])
        dist = Z[i, 2].item()
        new_cluster = clusters[cluster1] + clusters[cluster2]
        clusters[current_index] = new_cluster

        x1 = positions[cluster1]
        x2 = positions[cluster2]
        x_new = (x1 + x2) / 2
        positions[current_index] = x_new

        h1 = heights[cluster1]
        h2 = heights[cluster2]
        heights[current_index] = dist

        icoord.append([x1, x1, x2, x2])
        dcoord.append([h1, dist, dist, h2])

        current_index += 1

    # Sort icoord and dcoord by the first element in each icoord list
    sorted_indices = sorted(range(len(icoord)), key=lambda i: icoord[i][0])
    icoord = [icoord[i] for i in sorted_indices]
    dcoord = [dcoord[i] for i in sorted_indices]

    return {"icoord": icoord, "dcoord": dcoord, "ivl": ivl}
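
Because the coordinates follow the scipy dendrogram convention, they can be drawn directly; a minimal sketch (synthetic data, illustrative file name):

import matplotlib.pyplot as plt
import numpy as np

from google_ngrams.vnc import (_convert_linkage_to_coordinates,
                               _linkage_matrix)

years = np.arange(1900.0, 2000.0, 10.0)
freqs = np.array([0.10, 0.12, 0.11, 0.35, 0.36,
                  0.34, 0.33, 0.80, 0.85, 0.82])
coords = _convert_linkage_to_coordinates(_linkage_matrix(years, freqs))

# Each merge is one bracket: x corners icoord[k], y corners dcoord[k];
# leaves sit at x = 5, 15, 25, ... in left-to-right order.
fig, ax = plt.subplots()
for xs, ys in zip(coords['icoord'], coords['dcoord']):
    ax.plot(xs, ys, color='k', linewidth=0.5)
fig.savefig('toy_dendrogram.png')
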
def _vnc_calculate_info(Z: np.ndarray,
                        p=None,
                        truncate=False,
                        contraction_marks=False,
                        labels=None):
    """
    Assemble the plotting info (coordinates, labels, clusters,
    cut threshold) consumed by TimeSeries.timeviz_vnc.
    """
    Z = Z.copy()
    Zs = Z.shape
    n = Zs[0] + 1

    if labels is not None:
        if Zs[0] + 1 != len(labels):
            labels = None
            print(dedent(
                """
                Dimensions of Z and labels are not consistent.
                Using default labels.
                """))
    if labels is None:
        labels = [str(i) for i in range(Zs[0] + 1)]

    if p is not None and (p > n or p < 2):
        p = None
        truncate = False
        contraction_marks = False

    if p is not None:
        cluster_assignment = [i.item() for i in sch.cut_tree(Z, p)]

        # Create a dictionary to hold the clusters
        cluster_dict = {}

        # Iterate over the labels and clusters to populate the dictionary
        for label, cluster in zip(labels, cluster_assignment):
            cluster_key = f'cluster_{cluster + 1}'
            if cluster_key not in cluster_dict:
                cluster_dict[cluster_key] = []
            cluster_dict[cluster_key].append(label)

        # Convert the dictionary to a list of dictionaries
        cluster_list = [{key: value} for key, value in cluster_dict.items()]

        # Create a new list to hold the cluster labels
        cluster_labels = []

        # Iterate over the cluster_list to create the labels
        for cluster in cluster_list:
            for key, value in cluster.items():
                if len(value) == 1:
                    cluster_labels.append(str(value[0]))
                else:
                    cluster_labels.append(f"{value[0]}-{value[-1]}")

        # get distance for plotting cut line
        dist = [x[2].item() for x in Z]
        dist_threshold = np.mean(
            [dist[len(dist)-p+1], dist[len(dist)-p]]
        )
    else:
        dist_threshold = None
        cluster_list = None
        cluster_labels = None

    if truncate is True:
        truncated_Z = _contract_linkage_matrix(Z, p=p)

        if contraction_marks is True:
            contraction_marks = _contraction_mark_coordinates(Z, p=p)
        else:
            contraction_marks = None

        Z = truncated_Z
    else:
        contraction_marks = None

    R = _convert_linkage_to_coordinates(Z)

    mh = np.max(Z[:, 2])
    Zn = Z.shape[0] + 1
    color_list = ['k'] * (Zn - 1)
    leaves_color_list = ['k'] * Zn
    R['n'] = Zn
    R['mh'] = mh
    R['p'] = p
    R['labels'] = labels
    R['color_list'] = color_list
    R['leaves_color_list'] = leaves_color_list
    R['clusters'] = cluster_list
    R['cluster_labels'] = cluster_labels
    R['dist_threshold'] = dist_threshold
    R["contraction_marks"] = contraction_marks

    return R
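
A sketch of the assembled plotting info (synthetic data again; `p=3` asks for three periods):

import numpy as np

from google_ngrams.vnc import _linkage_matrix, _vnc_calculate_info

years = np.arange(1900.0, 2000.0, 10.0)
freqs = np.array([0.10, 0.12, 0.11, 0.35, 0.36,
                  0.34, 0.33, 0.80, 0.85, 0.82])
Z = _linkage_matrix(years, freqs)

# p=3 cuts the tree into three periods; 'dist_threshold' lands between
# the two merge heights separated by the cut.
info = _vnc_calculate_info(Z, p=3, labels=[str(int(y)) for y in years])
print(sorted(info.keys()))
print(info['clusters'])
print(info['dist_threshold'])
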
def _lowess(x,
            y,
            f=1./3.):
    """
    Basic LOWESS smoother with uncertainty.
    Note:
    - Not robust (so no iteration) and
      only normally distributed errors.
    - No higher order polynomials d=1
      so linear smoother.
    """
    # get some parameters
    # effective width after reduction factor
    xwidth = f*(x.max()-x.min())
    # number of obs
    N = len(x)
    # Don't assume the data is sorted
    order = np.argsort(x)
    # storage
    y_sm = np.zeros_like(y)
    y_stderr = np.zeros_like(y)
    # define the weighting function -- clipping too!
    tricube = lambda d: np.clip((1 - np.abs(d)**3)**3, 0, 1)  # noqa: E731
    # run the regression for each observation i
    for i in range(N):
        dist = np.abs((x[order][i]-x[order]))/xwidth
        w = tricube(dist)
        # form linear system with the weights
        A = np.stack([w, x[order]*w]).T
        b = w * y[order]
        ATA = A.T.dot(A)
        ATb = A.T.dot(b)
        # solve the system
        sol = np.linalg.solve(ATA, ATb)
        # predict for the observation only
        # equiv of A.dot(yest) just for k
        yest = A[i].dot(sol)
        place = order[i]
        y_sm[place] = yest
        sigma2 = (np.sum((A.dot(sol) - y[order])**2)/N)
        # Calculate the standard error
        y_stderr[place] = np.sqrt(sigma2 *
                                  A[i].dot(np.linalg.inv(ATA)
                                           ).dot(A[i]))
    return y_sm, y_stderr
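
`_lowess` is self-contained, so it can be sanity-checked on any noisy series; a sketch with a synthetic sine wave (arrays and file name are illustrative):

import matplotlib.pyplot as plt
import numpy as np

from google_ngrams.vnc import _lowess

rng = np.random.default_rng(42)
x = np.linspace(0.0, 10.0, 120)
y = np.sin(x) + rng.normal(scale=0.25, size=x.size)

# f is the fraction of the x range used as the smoothing window.
y_sm, y_stderr = _lowess(x, y, f=1.0 / 5.0)

fig, ax = plt.subplots()
ax.scatter(x, y, s=4, color='gray')
ax.plot(x, y_sm, color='tomato', label='LOWESS')
ax.fill_between(x, y_sm - 1.96 * y_stderr, y_sm + 1.96 * y_stderr,
                alpha=0.3)
ax.legend()
fig.savefig('lowess_check.png')
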
class TimeSeries:

    def __init__(self,
                 time_series: pl.DataFrame,
                 time_col: str,
                 values_col: str):

        time = time_series.get_column(time_col, default=None)
        values = time_series.get_column(values_col, default=None)

        if time is None:
            raise ValueError("""
                Invalid column.
                Check name. Couldn't find the time column in DataFrame.
                """)
        if values is None:
            raise ValueError("""
                Invalid column.
                Check name. Couldn't find the values column in DataFrame.
                """)
        if not isinstance(values.dtype, (pl.Float64, pl.Float32)):
            raise ValueError("""
                Invalid DataFrame.
                Expected a column of normalized frequencies.
                """)
        if len(time) != len(values):
            raise ValueError("""
                Your time and values vectors must be the same length.
                """)

        time_series = time_series.sort(time_col)
        self.time_intervals = time_series.get_column(time_col).to_numpy()
        self.frequencies = time_series.get_column(values_col).to_numpy()
        self.Z_sd = _linkage_matrix(time_series=self.time_intervals,
                                    frequency=self.frequencies)
        self.Z_cv = _linkage_matrix(time_series=self.time_intervals,
                                    frequency=self.frequencies,
                                    distance_measure='cv')
        self.distances_sd = np.array([self.Z_sd[i][2].item()
                                      for i in range(len(self.Z_sd))])
        self.distances_cv = np.array([self.Z_cv[i][2].item()
                                      for i in range(len(self.Z_cv))])

        self.clusters = None
        self.distance_threshold = None

    def timeviz_barplot(self,
                        width=8,
                        height=4,
                        dpi=150,
                        barwidth=4,
                        fill_color='#440154',
                        tick_interval=None,
                        label_rotation=None):
        """
        Generate a bar plot of token frequencies over time.

        Parameters
        ----------
        width:
            The width of the plot.
        height:
            The height of the plot.
        dpi:
            The resolution of the plot.
        barwidth:
            The width of the bars.
        fill_color:
            The color of the bars.
        tick_interval:
            Interval spacing for the tick labels.
        label_rotation:
            Angle used to rotate tick labels.

        Returns
        -------
        Figure
            A matplotlib figure.

        """
        xx = self.time_intervals
        yy = self.frequencies

        if label_rotation is None:
            rotation = 90
        else:
            rotation = label_rotation

        if tick_interval is None:
            interval = np.diff(xx)[0]
        else:
            interval = tick_interval

        start_value = np.min(xx)

        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)

        ax.bar(xx, yy, color=fill_color, edgecolor='black',
               linewidth=.5, width=barwidth)

        # Despine
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.axhline(y=0, color='black', linestyle='-', linewidth=.5)

        ax.tick_params(axis="x", which="both", labelrotation=rotation)
        ax.grid(axis='y', color='w', linestyle='--', linewidth=.5)
        ax.xaxis.set_major_locator(plt.MultipleLocator(base=interval,
                                                       offset=start_value))

        return fig

    def timeviz_scatterplot(self,
                            width=8,
                            height=4,
                            dpi=150,
                            point_color='black',
                            point_size=0.5,
                            ci='standard') -> Figure:
        """
        Generate a scatter plot of token frequencies over time
        with a smoothed fit line and a confidence interval.

        Parameters
        ----------
        width:
            The width of the plot.
        height:
            The height of the plot.
        dpi:
            The resolution of the plot.
        point_color:
            The color of the points.
        point_size:
            The size of the points.
        ci:
            The confidence interval. One of "standard" (95%),
            "strict" (97.5%) or "both".

        Returns
        -------
        Figure
            A matplotlib figure.

        """
        ci_types = ['standard', 'strict', 'both']
        if ci not in ci_types:
            ci = "standard"

        xx = self.time_intervals
        yy = self.frequencies

        order = np.argsort(xx)

        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)

        # run the smoother
        y_sm, y_std = _lowess(xx, yy, f=1./5.)
        # plot it
        ax.plot(xx[order], y_sm[order],
                color='tomato', linewidth=.5, label='LOWESS')
        if ci == 'standard':
            ax.fill_between(
                xx[order], y_sm[order] - 1.96*y_std[order],
                y_sm[order] + 1.96*y_std[order], alpha=0.3,
                label='95% uncertainty')
        if ci == 'strict':
            ax.fill_between(
                xx[order], y_sm[order] - y_std[order],
                y_sm[order] + y_std[order], alpha=0.3,
                label='97.5% uncertainty')
        if ci == 'both':
            ax.fill_between(
                xx[order], y_sm[order] - 1.96*y_std[order],
                y_sm[order] + 1.96*y_std[order], alpha=0.3,
                label='95% uncertainty')
            ax.fill_between(
                xx[order], y_sm[order] - y_std[order],
                y_sm[order] + y_std[order], alpha=0.3,
                label='97.5% uncertainty')

        ax.scatter(xx, yy, s=point_size, color=point_color, alpha=0.75)

        # Despine
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        # Keep only the non-negative y ticks
        ticks = [tick for tick in ax.get_yticks() if tick >= 0]
        ax.set_yticks(ticks)

        return fig

    def timeviz_screeplot(self,
                          width=6,
                          height=3,
                          dpi=150,
                          point_size=0.75,
                          distance="sd") -> Figure:
        """
        Generate a scree plot for determining clusters.

        Parameters
        ----------
        width:
            The width of the plot.
        height:
            The height of the plot.
        dpi:
            The resolution of the plot.
        point_size:
            The size of the points.
        distance:
            One of 'sd' (standard deviation)
            or 'cv' (coefficient of variation).

        Returns
        -------
        Figure
            A matplotlib figure.

        """
        dist_types = ['sd', 'cv']
        if distance not in dist_types:
            distance = "sd"

        if distance == "cv":
            dist = self.distances_cv
        else:
            dist = self.distances_sd

        # scree plot: merge distances, largest (fewest clusters) first
        yy = dist[::-1]
        xx = np.array([i for i in range(1, len(yy) + 1)])
        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
        ax.scatter(x=xx,
                   y=yy,
                   marker='o',
                   s=point_size,
                   facecolors='none',
                   edgecolors='black')
        ax.set_xlabel('Clusters')
        ax.set_ylabel(f'Distance (in summed {distance})')

        # Despine
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        return fig

    def timeviz_vnc(self,
                    width=6,
                    height=4,
                    dpi=150,
                    font_size=10,
                    n_periods=1,
                    distance="sd",
                    orientation="horizontal",
                    cut_line=False,
                    periodize=False,
                    hide_labels=False) -> Figure:
        """
        Generate a dendrogram using the clustering method
        "Variability-based Neighbor Clustering" (VNC),
        which identifies periods in historical data
        while respecting the temporal ordering of the observations.

        Parameters
        ----------
        width:
            The width of the plot.
        height:
            The height of the plot.
        dpi:
            The resolution of the plot.
        font_size:
            The font size for the labels.
        n_periods:
            The number of periods (or clusters).
        distance:
            One of 'sd' (standard deviation)
            or 'cv' (coefficient of variation).
        orientation:
            The orientation of the plot,
            either "horizontal" or "vertical".
        cut_line:
            Whether or not to include a cut line;
            applies only to non-periodized plots.
        periodize:
            The dendrogram can be hard to read when the original
            observation matrix from which the linkage is derived is
            large. Periodization is used to condense the dendrogram.
        hide_labels:
            Whether or not to hide leaf labels.

        Returns
        -------
        Figure
            A matplotlib figure.

        """
        dist_types = ['sd', 'cv']
        if distance not in dist_types:
            distance = "sd"
        orientation_types = ['horizontal', 'vertical']
        if orientation not in orientation_types:
            orientation = "horizontal"

        if distance == "cv":
            Z = self.Z_cv
        else:
            Z = self.Z_sd

        if n_periods > len(Z):
            n_periods = 1
            periodize = False

        if n_periods > 1 and n_periods <= len(Z) and periodize is not True:
            cut_line = True

        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)

        # Plot the corresponding dendrogram
        if orientation == "horizontal" and periodize is not True:
            X = _vnc_calculate_info(Z,
                                    p=n_periods,
                                    labels=self.time_intervals)

            self.clusters = X['clusters']

            sch._plot_dendrogram(icoords=X['icoord'],
                                 dcoords=X['dcoord'],
                                 ivl=X['ivl'],
                                 color_list=X['color_list'],
                                 mh=X['mh'],
                                 orientation='top',
                                 p=X['p'],
                                 n=X['n'],
                                 no_labels=False)

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.set_ylabel(f'Distance (in summed {distance})')

            if hide_labels is not True:
                ax.set_xticklabels(X['labels'],
                                   fontsize=font_size,
                                   rotation=90)
            else:
                ax.set_xticklabels([])

            plt.setp(ax.collections, linewidth=.5)

            if cut_line and X['dist_threshold'] is not None:
                ax.axhline(y=X['dist_threshold'],
                           color='r',
                           alpha=0.7,
                           linestyle='--',
                           linewidth=.5)

        if orientation == "horizontal" and periodize is True:
            X = _vnc_calculate_info(Z,
                                    truncate=True,
                                    p=n_periods,
                                    contraction_marks=True,
                                    labels=self.time_intervals)

            self.clusters = X['clusters']

            sch._plot_dendrogram(icoords=X['icoord'],
                                 dcoords=X['dcoord'],
                                 ivl=X['ivl'],
                                 color_list=X['color_list'],
                                 mh=X['mh'], orientation='top',
                                 p=X['p'],
                                 n=X['n'],
                                 no_labels=False,
                                 contraction_marks=X['contraction_marks'])

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.set_ylabel(f'Distance (in summed {distance})')

            if hide_labels is not True:
                ax.set_xticklabels(X['cluster_labels'],
                                   fontsize=font_size,
                                   rotation=90)
            else:
                ax.set_xticklabels([])

            plt.setp(ax.collections, linewidth=.5)

        if orientation == "vertical" and periodize is not True:
            X = _vnc_calculate_info(Z,
                                    p=n_periods,
                                    labels=self.time_intervals)

            self.clusters = X['clusters']

            sch._plot_dendrogram(icoords=X['icoord'],
                                 dcoords=X['dcoord'],
                                 ivl=X['ivl'],
                                 color_list=X['color_list'],
                                 mh=X['mh'],
                                 orientation='right',
                                 p=X['p'],
                                 n=X['n'],
                                 no_labels=False)

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['left'].set_visible(False)
            ax.set_xlabel(f'Distance (in summed {distance})')

            if hide_labels is not True:
                ax.set_yticklabels(X['labels'],
                                   fontsize=font_size,
                                   rotation=0)
            else:
                ax.set_yticklabels([])

            ymin, ymax = ax.get_ylim()
            ax.set_ylim(ymax, ymin)
            plt.setp(ax.collections, linewidth=.5)

            if cut_line and X['dist_threshold'] is not None:
                ax.axvline(x=X['dist_threshold'],
                           color='r',
                           alpha=0.7,
                           linestyle='--',
                           linewidth=.5)

        if orientation == "vertical" and periodize is True:
            X = _vnc_calculate_info(Z,
                                    truncate=True,
                                    p=n_periods,
                                    contraction_marks=True,
                                    labels=self.time_intervals)

            self.clusters = X['clusters']

            sch._plot_dendrogram(icoords=X['icoord'],
                                 dcoords=X['dcoord'],
                                 ivl=X['ivl'],
                                 color_list=X['color_list'],
                                 mh=X['mh'], orientation='right',
                                 p=X['p'],
                                 n=X['n'],
                                 no_labels=False,
                                 contraction_marks=X['contraction_marks'])

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['left'].set_visible(False)
            ax.set_xlabel(f'Distance (in summed {distance})')

            if hide_labels is not True:
                ax.set_yticklabels(X['cluster_labels'],
                                   fontsize=font_size,
                                   rotation=0)
            else:
                ax.set_yticklabels([])

            ymin, ymax = ax.get_ylim()
            ax.set_ylim(ymax, ymin)
            plt.setp(ax.collections, linewidth=.5)

        return fig

    def cluster_summary(self):
        """
        Print a summary of cluster membership.

        Returns
        -------
        Prints to the console.

        """
        cluster_list = self.clusters
        if cluster_list is not None:
            for i, cluster in enumerate(cluster_list, start=1):
                for key, value in cluster.items():
                    print(f"Cluster {i} (n={len(value)}): {[str(v) for v in value]}")  # noqa: E501
        else:
            print("No clusters to summarize.")
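
Putting the class together, an end-to-end sketch. The DataFrame below is synthetic; in real use the frequencies would come from the package's ngram data (see `google_ngrams/ngrams.py` in the file list above). File names in the `savefig` calls are illustrative:

import numpy as np
import polars as pl

from google_ngrams.vnc import TimeSeries

# Synthetic per-decade normalized frequencies; a float time axis keeps
# the merged-interval means exact.
df = pl.DataFrame({
    "decade": np.arange(1900.0, 2000.0, 10.0),
    "frequency": [0.10, 0.12, 0.11, 0.35, 0.36,
                  0.34, 0.33, 0.80, 0.85, 0.82],
})

ts = TimeSeries(df, time_col="decade", values_col="frequency")

ts.timeviz_barplot().savefig("bar.png")               # raw frequencies
ts.timeviz_scatterplot(ci="both").savefig("fit.png")  # LOWESS + bands
ts.timeviz_screeplot().savefig("scree.png")           # pick a period count
fig = ts.timeviz_vnc(n_periods=3, cut_line=True)      # dendrogram + cut
fig.savefig("vnc.png")
ts.cluster_summary()                                  # period membership
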