google_ngrams-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_ngrams/__init__.py +19 -0
- google_ngrams/data/__init__.py +14 -0
- google_ngrams/data/googlebooks_eng_all_totalcounts_20120701.parquet +0 -0
- google_ngrams/data/googlebooks_eng_gb_all_totalcounts_20120701.parquet +0 -0
- google_ngrams/data/googlebooks_eng_us_all_totalcounts_20120701.parquet +0 -0
- google_ngrams/ngrams.py +341 -0
- google_ngrams/scatter_helpers.py +187 -0
- google_ngrams/vnc.py +518 -0
- google_ngrams/vnc_helpers.py +809 -0
- google_ngrams-0.2.0.dist-info/METADATA +144 -0
- google_ngrams-0.2.0.dist-info/RECORD +14 -0
- google_ngrams-0.2.0.dist-info/WHEEL +5 -0
- google_ngrams-0.2.0.dist-info/licenses/LICENSE +162 -0
- google_ngrams-0.2.0.dist-info/top_level.txt +1 -0
google_ngrams/vnc_helpers.py
@@ -0,0 +1,809 @@
import copy
import numpy as np
import polars as pl
from textwrap import dedent

__all__ = [
    "_linkage_matrix",
    "_vnc_calculate_info",
    "_plot_dendrogram",
    "_cut_tree_simple"
]


def _cut_tree_simple(Z, n_clusters=None, height=None):
    """
    Assign flat cluster labels by replaying the first merges of a
    linkage matrix (a simplified stand-in for scipy's cut_tree).
    """
    if height is not None:
        raise NotImplementedError(
            """
            height-based cuts not supported in simplified implementation
            """
        )
    n = Z.shape[0] + 1
    if n_clusters is None or n_clusters >= n:
        return np.arange(n, dtype=int)
    if n_clusters <= 1:
        return np.zeros(n, dtype=int)
    merges_to_apply = n - int(n_clusters)
    labels = np.arange(n, dtype=int)
    node_leaves = [None] * (2 * n - 1)
    for i in range(n):
        node_leaves[i] = [i]
    for i in range(Z.shape[0]):
        left = int(Z[i, 0])
        right = int(Z[i, 1])
        leaves = node_leaves[left] + node_leaves[right]
        leaves.sort()
        node_leaves[n + i] = leaves
        if i < merges_to_apply:
            block = leaves
            block_labels = labels[block]
            new_label = block_labels.min()
            max_label = block_labels.max()
            labels[block] = new_label
            # compress labels above max_label downward
            mask = labels > max_label
            labels[mask] -= (max_label - new_label)
        else:
            break
    return labels
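
# Illustrative sketch added for this review (not part of the published
# wheel): a minimal check of _cut_tree_simple on a hypothetical 4-leaf
# linkage in scipy's (left, right, distance, size) row layout.
def _demo_cut_tree_simple():
    Z = np.array([
        [0.0, 1.0, 0.5, 2.0],   # leaves 0 and 1 merge into node 4
        [2.0, 3.0, 0.7, 2.0],   # leaves 2 and 3 merge into node 5
        [4.0, 5.0, 1.2, 4.0],   # nodes 4 and 5 merge into the root
    ])
    labels = _cut_tree_simple(Z, n_clusters=2)
    assert labels.tolist() == [0, 0, 1, 1]
    return labels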

def _remove_dups(L):
    """
    Remove duplicates AND preserve the original order of the elements.

    The set class is not guaranteed to do this.
    """
    seen_before = set()
    L2 = []
    for i in L:
        if i not in seen_before:
            seen_before.add(i)
            L2.append(i)
    return L2
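
# Illustrative sketch (not in the wheel): _remove_dups keeps first
# occurrences in their original order, unlike a plain set.
def _demo_remove_dups():
    assert _remove_dups(['b', 'a', 'b', 'c']) == ['b', 'a', 'c']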

_dtextsizes = {20: 12, 30: 10, 50: 8, 85: 6, np.inf: 5}
_drotation = {20: 0, 40: 45, np.inf: 90}
_dtextsortedkeys = list(_dtextsizes.keys())
_dtextsortedkeys.sort()
_drotationsortedkeys = list(_drotation.keys())
_drotationsortedkeys.sort()


def _get_tick_text_size(p):
    for k in _dtextsortedkeys:
        if p <= k:
            return _dtextsizes[k]


def _get_tick_rotation(p):
    for k in _drotationsortedkeys:
        if p <= k:
            return _drotation[k]
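
# Illustrative sketch (not in the wheel): the lookup tables map a leaf
# count to tick-label size and rotation; the smallest threshold at or
# above the count wins.
def _demo_tick_settings():
    assert _get_tick_text_size(25) == 10   # 25 <= 30 -> 10 pt
    assert _get_tick_rotation(25) == 45    # 25 <= 40 -> 45 degrees
    assert _get_tick_rotation(20) == 0     # thresholds are inclusive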

def _contract_linkage_matrix(
        Z: np.ndarray,
        p=4
) -> np.ndarray:
    """
    Contracts the linkage matrix by reducing the number of clusters
    to a specified number.

    Parameters
    ----------
    Z : np.ndarray
        The linkage matrix.
    p : int
        The number of clusters to retain.

    Returns
    -------
    np.ndarray
        The contracted linkage matrix with updated cluster IDs
        and member counts.
    """
    Z = Z.copy()
    truncated_Z = Z[-(p - 1):]

    n_points = Z.shape[0] + 1
    clusters = [
        dict(node_id=i, left=i, right=i, members=[i], distance=0, n_members=1)
        for i in range(n_points)
    ]
    for z_i in range(Z.shape[0]):
        row = Z[z_i]
        left = int(row[0])
        right = int(row[1])
        cluster = dict(
            node_id=z_i + n_points,
            left=left,
            right=right,
            members=[],
            distance=row[2],
            n_members=int(row[3])
        )
        cluster["members"].extend(copy.deepcopy(clusters[left]["members"]))
        cluster["members"].extend(copy.deepcopy(clusters[right]["members"]))
        cluster["members"].sort()
        clusters.append(cluster)

    node_map = []
    for i in range(truncated_Z.shape[0]):
        node_ids = [int(truncated_Z[i, 0]), int(truncated_Z[i, 1])]
        for cluster in clusters:
            if cluster['node_id'] in node_ids:
                node_map.append(cluster)

    filtered_node_map = []
    superset_node_map = []

    for node in node_map:
        is_superset = False
        for other_node in node_map:
            if (
                node != other_node
                and set(
                    node['members']
                ).issuperset(set(other_node['members']))
            ):
                is_superset = True
                break
        if is_superset:
            superset_node_map.append(node)
        else:
            filtered_node_map.append(node)

    # Add 'truncated_id' to each dictionary in filtered_node_map
    for idx, node in enumerate(
        sorted(filtered_node_map, key=lambda x: x['members'][0])
    ):
        node['truncated_id'] = idx
        node['n_members'] = 1

    for idx, node in enumerate(
        sorted(superset_node_map, key=lambda x: x['node_id'])
    ):
        node['truncated_id'] = idx + len(filtered_node_map)

    # Adjust 'n_members' in superset_node_map to reflect
    # the number of filtered_node_map['members'] sets they contain
    for superset_node in superset_node_map:
        count = 0
        for filtered_node in filtered_node_map:
            if set(
                filtered_node['members']
            ).issubset(set(superset_node['members'])):
                count += 1
        superset_node['n_members'] = count

    # Create a mapping from node_id to truncated_id and n_members
    node_id_to_truncated_id = {
        node['node_id']: node['truncated_id']
        for node in filtered_node_map + superset_node_map
    }
    node_id_to_n_members = {
        node['node_id']: node['n_members']
        for node in filtered_node_map + superset_node_map
    }

    # Replace values in truncated_Z
    for i in range(truncated_Z.shape[0]):
        truncated_Z[i, 3] = (
            node_id_to_n_members[int(truncated_Z[i, 0])] +
            node_id_to_n_members[int(truncated_Z[i, 1])]
        )
        truncated_Z[i, 0] = node_id_to_truncated_id[int(truncated_Z[i, 0])]
        truncated_Z[i, 1] = node_id_to_truncated_id[int(truncated_Z[i, 1])]

    return truncated_Z
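
# Illustrative sketch (not in the wheel): contracting a hypothetical
# 4-leaf chain linkage to p=3 keeps the last p-1 merges and renumbers
# the surviving nodes; the collapsed {0, 1} pair becomes leaf 0.
def _demo_contract_linkage():
    Z = np.array([
        [0.0, 1.0, 0.2, 2.0],
        [4.0, 2.0, 0.5, 3.0],
        [5.0, 3.0, 0.9, 4.0],
    ])
    truncated = _contract_linkage_matrix(Z, p=3)
    assert truncated.tolist() == [
        [0.0, 1.0, 0.5, 2.0],
        [3.0, 2.0, 0.9, 3.0],
    ]
    return truncated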

def _contraction_mark_coordinates(
        Z: np.ndarray,
        p=4
) -> list:
    """
    Generates contraction marks for a given linkage matrix.

    Parameters
    ----------
    Z : np.ndarray
        The linkage matrix.
    p : int
        The number of clusters to retain.

    Returns
    -------
    list
        A sorted list of tuples where each tuple contains
        a calculated value based on truncated_id and a distance value.
    """
    Z = Z.copy()
    truncated_Z = Z[-(p - 1):]

    n_points = Z.shape[0] + 1
    clusters = [dict(node_id=i,
                     left=i,
                     right=i,
                     members=[i],
                     distance=0,
                     n_members=1) for i in range(n_points)]
    for z_i in range(Z.shape[0]):
        row = Z[z_i]
        left = int(row[0])
        right = int(row[1])
        cluster = dict(
            node_id=z_i + n_points,
            left=left, right=right,
            members=[],
            distance=row[2],
            n_members=int(row[3])
        )
        cluster["members"].extend(copy.deepcopy(clusters[left]["members"]))
        cluster["members"].extend(copy.deepcopy(clusters[right]["members"]))
        cluster["members"].sort()
        clusters.append(cluster)

    node_map = []
    for i in range(truncated_Z.shape[0]):
        node_ids = [int(truncated_Z[i, 0]), int(truncated_Z[i, 1])]
        for cluster in clusters:
            if cluster['node_id'] in node_ids:
                node_map.append(cluster)

    filtered_node_map = []
    superset_node_map = []

    for node in node_map:
        is_superset = False
        for other_node in node_map:
            if (node != other_node
                    and set(node['members']
                            ).issuperset(set(other_node['members']))):
                is_superset = True
                break
        if is_superset:
            superset_node_map.append(node)
        else:
            filtered_node_map.append(node)

    # Create a set of node_ids from filtered_node_map and superset_node_map
    excluded_node_ids = set(
        node['node_id'] for node in filtered_node_map
    ).union(node['node_id'] for node in superset_node_map)

    # Filter clusters that are not in excluded_node_ids
    non_excluded_clusters = [
        cluster for cluster in clusters
        if cluster['node_id'] not in excluded_node_ids
    ]

    # Create a list to store the result
    subset_clusters = []

    # Iterate over filtered_node_map
    for filtered_cluster in filtered_node_map:
        distances = []
        for cluster in non_excluded_clusters:
            if (
                cluster['n_members'] > 1
                and set(cluster['members']
                        ).issubset(set(filtered_cluster['members']))
            ):
                distances.append(cluster['distance'])
        if distances:
            subset_clusters.append(
                {'node_id': filtered_cluster['node_id'], 'distance': distances}
            )

    # Add 'truncated_id' to each dictionary in filtered_node_map
    for idx, node in enumerate(
        sorted(filtered_node_map, key=lambda x: x['members'][0])
    ):
        node['truncated_id'] = idx

    # Create a mapping from node_id to truncated_id
    node_id_to_truncated_id = {
        node['node_id']: node['truncated_id'] for node in filtered_node_map
    }

    # Add 'truncated_id' to each dictionary in subset_clusters
    for cluster in subset_clusters:
        cluster['truncated_id'] = node_id_to_truncated_id[cluster['node_id']]

    # Create a list of tuples
    contraction_marks = []

    # Iterate over subset_clusters
    for cluster in subset_clusters:
        truncated_id = cluster['truncated_id']
        for distance in cluster['distance']:
            contraction_marks.append((10.0 * truncated_id + 5.0, distance))

    # Sort the list of tuples
    contraction_marks = sorted(contraction_marks, key=lambda x: (x[0], x[1]))

    return contraction_marks
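
# Illustrative sketch (not in the wheel): with p=2 on the same chain
# linkage, the leaf collapsing {0, 1, 2} carries a mark at the height of
# the merge hidden beneath it (the 0.2 merge), centered on its stem at
# x = 10 * truncated_id + 5.
def _demo_contraction_marks():
    Z = np.array([
        [0.0, 1.0, 0.2, 2.0],
        [4.0, 2.0, 0.5, 3.0],
        [5.0, 3.0, 0.9, 4.0],
    ])
    marks = _contraction_mark_coordinates(Z, p=2)
    assert marks == [(5.0, 0.2)]
    return marks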

def _convert_linkage_to_coordinates(
        Z: np.ndarray
) -> dict:
    """
    Converts a linkage matrix to coordinates for plotting a dendrogram.

    Parameters
    ----------
    Z : np.ndarray
        The linkage matrix.

    Returns
    -------
    dict
        A dictionary containing 'icoord', 'dcoord', and 'ivl'
        for plotting the dendrogram.
    """
    ivl = [i for i in range(Z.shape[0] + 1)]
    n = len(ivl)
    icoord = []
    dcoord = []
    clusters = {i: [i] for i in range(n)}
    current_index = n
    positions = {i: (i + 1) * 10 - 5 for i in range(n)}
    heights = {i: 0 for i in range(n)}

    for i in range(len(Z)):
        cluster1 = int(Z[i, 0])
        cluster2 = int(Z[i, 1])
        dist = Z[i, 2].item()
        new_cluster = clusters[cluster1] + clusters[cluster2]
        clusters[current_index] = new_cluster

        x1 = positions[cluster1]
        x2 = positions[cluster2]
        x_new = (x1 + x2) / 2
        positions[current_index] = x_new

        h1 = heights[cluster1]
        h2 = heights[cluster2]
        heights[current_index] = dist

        icoord.append([x1, x1, x2, x2])
        dcoord.append([h1, dist, dist, h2])

        current_index += 1

    # Sort icoord and dcoord by the first element in each icoord list
    sorted_indices = sorted(range(len(icoord)), key=lambda i: icoord[i][0])
    icoord = [icoord[i] for i in sorted_indices]
    dcoord = [dcoord[i] for i in sorted_indices]

    return {"icoord": icoord, "dcoord": dcoord, "ivl": ivl}

def _vnc_calculate_info(
        Z: np.ndarray,
        p=None,
        truncate=False,
        contraction_marks=False,
        labels=None
) -> dict:
    Z = Z.copy()
    Zs = Z.shape
    n = Zs[0] + 1

    if labels is not None:
        if Zs[0] + 1 != len(labels):
            labels = None
            print(dedent(
                """
                Dimensions of Z and labels are not consistent.
                Using default labels.
                """))
    if labels is None:
        labels = [str(i) for i in range(Zs[0] + 1)]

    # an out-of-range p cannot be used for cutting or truncation
    if p is not None and (p > n or p < 2):
        p = None
    if p is None:
        truncate = False
        contraction_marks = False

    if p is not None:
        cluster_assignment = [i.item() for i in _cut_tree_simple(Z, p)]

        # Create a dictionary to hold the clusters
        cluster_dict = {}

        # Iterate over the labels and clusters to populate the dictionary
        for label, cluster in zip(labels, cluster_assignment):
            cluster_key = f'cluster_{cluster + 1}'
            if cluster_key not in cluster_dict:
                cluster_dict[cluster_key] = []
            cluster_dict[cluster_key].append(label)

        # Convert the dictionary to a list of dictionaries
        cluster_list = [{key: value} for key, value in cluster_dict.items()]

        # Create a new list to hold the cluster labels
        cluster_labels = []

        # Iterate over the cluster_list to create the labels
        for cluster in cluster_list:
            for key, value in cluster.items():
                if len(value) == 1:
                    cluster_labels.append(str(value[0]))
                else:
                    cluster_labels.append(f"{value[0]}-{value[-1]}")

        # get distance for plotting cut line
        dist = [x[2].item() for x in Z]
        dist_threshold = np.mean(
            [dist[len(dist) - p + 1], dist[len(dist) - p]]
        )
    else:
        dist_threshold = None
        cluster_list = None
        cluster_labels = None

    if truncate is True:
        truncated_Z = _contract_linkage_matrix(Z, p=p)

        if contraction_marks is True:
            contraction_marks = _contraction_mark_coordinates(Z, p=p)
        else:
            contraction_marks = None

        Z = truncated_Z
    else:
        contraction_marks = None

    R = _convert_linkage_to_coordinates(Z)

    mh = np.max(Z[:, 2])
    Zn = Z.shape[0] + 1
    color_list = ['k'] * (Zn - 1)
    leaves_color_list = ['k'] * Zn
    R['n'] = Zn
    R['mh'] = mh
    R['p'] = p
    R['labels'] = labels
    R['color_list'] = color_list
    R['leaves_color_list'] = leaves_color_list
    R['clusters'] = cluster_list
    R['cluster_labels'] = cluster_labels
    R['dist_threshold'] = dist_threshold
    R["contraction_marks"] = contraction_marks

    return R
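
# Illustrative sketch (not in the wheel): for the chain linkage used
# above, p=2 groups the first three periods into one cluster, and the
# cut line falls halfway between the last two merge heights:
# (0.5 + 0.9) / 2 = 0.7.
def _demo_vnc_info():
    Z = np.array([
        [0.0, 1.0, 0.2, 2.0],
        [4.0, 2.0, 0.5, 3.0],
        [5.0, 3.0, 0.9, 4.0],
    ])
    info = _vnc_calculate_info(
        Z, p=2, labels=["1900", "1910", "1920", "1930"]
    )
    assert info["cluster_labels"] == ["1900-1920", "1930"]
    assert abs(info["dist_threshold"] - 0.7) < 1e-12
    return info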

def _linkage_matrix(
        time_series,
        frequency,
        distance_measure='sd'
) -> np.ndarray:
    """
    Build a scipy-style linkage matrix by iteratively merging
    chronologically adjacent periods (variability-based neighbor
    clustering).
    """
    input_values = frequency.copy()
    years = time_series.copy()

    data_collector = {}
    data_collector["0"] = input_values
    position_collector = {}
    position_collector[1] = 0
    overall_distance = 0
    number_of_steps = len(input_values) - 1

    # at each step, merge the adjacent pair of periods with the lowest
    # pooled variability (standard deviation or coefficient of variation)
    for i in range(1, number_of_steps + 1):
        difference_checker = []
        unique_years = np.unique(years)

        for j in range(len(unique_years) - 1):
            first_name = unique_years[j]
            second_name = unique_years[j + 1]
            pooled_sample = input_values[np.isin(years,
                                                 [first_name,
                                                  second_name])]

            if distance_measure == "sd":
                difference_checker.append(0 if np.sum(pooled_sample) == 0
                                          else np.std(pooled_sample, ddof=1))
            elif distance_measure == "cv":
                difference_checker.append(
                    0 if np.sum(pooled_sample) == 0
                    else np.std(pooled_sample, ddof=1) / np.mean(pooled_sample)
                )

        pos_to_be_merged = np.argmin(difference_checker)
        distance = np.min(difference_checker)
        overall_distance += distance
        lower_name = unique_years[pos_to_be_merged]
        higher_name = unique_years[pos_to_be_merged + 1]

        matches = np.isin(years, [lower_name, higher_name])
        new_mean_age = round(np.mean(years[matches]), 4)
        position_collector[i + 1] = np.where(matches)[0] + 1
        years[matches] = new_mean_age
        data_collector[f"{i}: {distance}"] = input_values

    # reconstruct the merge history from the recorded member positions
    hc_build = pl.DataFrame({
        'start': [
            min(pos)
            if isinstance(pos, (list, np.ndarray))
            else pos for pos in position_collector.values()
        ],
        'end': [
            max(pos)
            if isinstance(pos, (list, np.ndarray))
            else pos for pos in position_collector.values()
        ]
    })

    idx = np.arange(len(hc_build))

    y = [np.where(
        hc_build['start'].to_numpy()[:i] == hc_build['start'].to_numpy()[i]
    )[0] for i in idx]
    z = [np.where(
        hc_build['end'].to_numpy()[:i] == hc_build['end'].to_numpy()[i]
    )[0] for i in idx]

    merge1 = [
        y[i].max().item() if len(y[i]) else np.nan for i in range(len(y))
    ]
    merge2 = [
        z[i].max().item() if len(z[i]) else np.nan for i in range(len(z))
    ]

    hc_build = hc_build.with_columns([
        pl.Series('merge1', [
            min(m1, m2) if not np.isnan(m1) and not np.isnan(m2)
            else np.nan for m1, m2 in zip(merge1, merge2)
        ]),
        pl.Series('merge2', [
            max(m1, m2) if not np.isnan(m1)
            else m2 for m1, m2 in zip(merge1, merge2)
        ])
    ])

    hc_build = (
        hc_build.with_columns([
            pl.when(
                pl.col('merge1').is_nan() &
                pl.col('merge2').is_nan()
            ).then(-pl.col('start')
                   ).otherwise(pl.col('merge1')).alias('merge1'),
            pl.when(
                pl.col('merge2')
                .is_nan()
            ).then(-pl.col('end')
                   ).otherwise(pl.col('merge2')).alias('merge2')
        ])
    )

    to_merge = [-np.setdiff1d(
        hc_build.select(
            pl.col('start', 'end')
        ).row(i),
        hc_build.select(
            pl.col('start', 'end')
        ).slice(1, i - 1).to_numpy().flatten()
    ) for i in idx]

    to_merge = [x[0].item() if len(x) > 0 else np.nan for x in to_merge]

    hc_build = (
        hc_build
        .with_columns([
            pl.when(pl.col('merge1').is_nan()
                    ).then(pl.Series(to_merge, strict=False)
                           ).otherwise(pl.col('merge1')).alias('merge1')
        ])
    )

    hc_build = hc_build.with_row_index()
    n = hc_build.height

    # recode references: negative values are singletons (leaf ids),
    # positive values are earlier merge steps (internal node ids)
    hc_build = (hc_build
                .with_columns(
                    pl.when(pl.col("merge1").lt(0))
                    .then(pl.col("merge1").mul(-1).sub(1))
                    .otherwise(pl.col('merge1').add(n - 1)).alias('merge1')
                )
                .with_columns(
                    pl.when(pl.col("merge2").lt(0))
                    .then(pl.col("merge2").mul(-1).sub(1))
                    .otherwise(pl.col('merge2').add(n - 1)).alias('merge2')
                )
                )

    hc_build = (
        hc_build
        .with_columns(distance=np.array(list(data_collector.keys())))
        .with_columns(pl.col("distance").str.replace(r"(\d+: )", ""))
        .with_columns(pl.col("distance").cast(pl.Float64))
        .with_columns(pl.col("distance").cum_sum().alias("distance"))
    )

    size = np.array(
        [
            len(x) if isinstance(x, (list, np.ndarray))
            else 1 for x in position_collector.values()
        ])

    hc_build = (
        hc_build
        .with_columns(size=size)
        .with_columns(pl.col("size").cast(pl.Float64))
    )

    hc_build = hc_build.filter(pl.col("index") != 0)

    hc = hc_build.select("merge1", "merge2", "distance", "size").to_numpy()
    return hc
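
# Illustrative sketch (not in the wheel): a hypothetical 4-period series.
# The two quiet neighbor pairs merge first, then the two halves join;
# VNC only ever merges chronological neighbors. Inputs should be float
# arrays, since merged years are averaged in place.
def _demo_linkage_matrix():
    years = np.array([1900.0, 1910.0, 1920.0, 1930.0])
    freq = np.array([1.0, 2.0, 8.0, 9.5])
    hc = _linkage_matrix(years, freq, distance_measure="sd")
    assert hc.shape == (3, 4)
    assert hc[:, :2].tolist() == [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
    assert hc[:, 3].tolist() == [2.0, 2.0, 4.0]
    return hc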

def _plot_dendrogram(
        icoords,
        dcoords,
        ivl,
        p,
        n,
        mh,
        orientation,
        no_labels,
        color_list,
        leaf_font_size=None,
        leaf_rotation=None,
        contraction_marks=None,
        ax=None,
        above_threshold_color='C0'
):
    # Import matplotlib here so that it's not imported unless dendrograms
    # are plotted. Raise an informative error if importing fails.
    try:
        # if an axis is provided, don't use pylab at all
        if ax is None:
            import matplotlib.pylab
        import matplotlib.patches
        import matplotlib.collections
    except ImportError as e:
        raise ImportError("You must install the matplotlib library to plot "
                          "the dendrogram. Use no_plot=True to calculate the "
                          "dendrogram without plotting.") from e

    if ax is None:
        ax = matplotlib.pylab.gca()
        # if we're using pylab, we want to trigger a draw at the end
        trigger_redraw = True
    else:
        trigger_redraw = False

    # Independent variable plot width
    ivw = len(ivl) * 10
    # Dependent variable plot height
    dvw = mh + mh * 0.05

    iv_ticks = np.arange(5, len(ivl) * 10 + 5, 10)
    if orientation in ('top', 'bottom'):
        if orientation == 'top':
            ax.set_ylim([0, dvw])
            ax.set_xlim([0, ivw])
        else:
            ax.set_ylim([dvw, 0])
            ax.set_xlim([0, ivw])

        xlines = icoords
        ylines = dcoords
        if no_labels:
            ax.set_xticks([])
            ax.set_xticklabels([])
        else:
            ax.set_xticks(iv_ticks)

            if orientation == 'top':
                ax.xaxis.set_ticks_position('bottom')
            else:
                ax.xaxis.set_ticks_position('top')

            # Make the tick marks invisible because they cover up the links
            for line in ax.get_xticklines():
                line.set_visible(False)

            leaf_rot = (float(_get_tick_rotation(len(ivl)))
                        if (leaf_rotation is None) else leaf_rotation)
            leaf_font = (float(_get_tick_text_size(len(ivl)))
                         if (leaf_font_size is None) else leaf_font_size)
            ax.set_xticklabels(ivl, rotation=leaf_rot, size=leaf_font)

    elif orientation in ('left', 'right'):
        if orientation == 'left':
            ax.set_xlim([dvw, 0])
            ax.set_ylim([0, ivw])
        else:
            ax.set_xlim([0, dvw])
            ax.set_ylim([0, ivw])

        xlines = dcoords
        ylines = icoords
        if no_labels:
            ax.set_yticks([])
            ax.set_yticklabels([])
        else:
            ax.set_yticks(iv_ticks)

            if orientation == 'left':
                ax.yaxis.set_ticks_position('right')
            else:
                ax.yaxis.set_ticks_position('left')

            # Make the tick marks invisible because they cover up the links
            for line in ax.get_yticklines():
                line.set_visible(False)

            leaf_font = (float(_get_tick_text_size(len(ivl)))
                         if (leaf_font_size is None) else leaf_font_size)

            if leaf_rotation is not None:
                ax.set_yticklabels(ivl, rotation=leaf_rotation,
                                   size=leaf_font)
            else:
                ax.set_yticklabels(ivl, size=leaf_font)

    # Let's use collections instead. This way there is a separate legend item
    # for each tree grouping, rather than stupidly one for each line segment.
    colors_used = _remove_dups(color_list)
    color_to_lines = {}
    for color in colors_used:
        color_to_lines[color] = []
    for (xline, yline, color) in zip(xlines, ylines, color_list):
        color_to_lines[color].append(list(zip(xline, yline)))

    colors_to_collections = {}
    # Construct the collections.
    for color in colors_used:
        coll = matplotlib.collections.LineCollection(color_to_lines[color],
                                                     colors=(color,))
        colors_to_collections[color] = coll

    # Add all the groupings below the color threshold.
    for color in colors_used:
        if color != above_threshold_color:
            ax.add_collection(colors_to_collections[color])
    # If there's a grouping of links above the color threshold, it goes last.
    if above_threshold_color in colors_to_collections:
        ax.add_collection(colors_to_collections[above_threshold_color])

    if contraction_marks is not None:
        Ellipse = matplotlib.patches.Ellipse
        for (x, y) in contraction_marks:
            if orientation in ('left', 'right'):
                e = Ellipse((y, x), width=dvw / 100, height=1.0)
            else:
                e = Ellipse((x, y), width=1.0, height=dvw / 100)
            ax.add_artist(e)
            e.set_clip_box(ax.bbox)
            e.set_alpha(0.5)
            e.set_facecolor('k')

    if trigger_redraw:
        matplotlib.pylab.draw_if_interactive()
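
# Illustrative end-to-end sketch (not part of the wheel): build a VNC
# linkage from a tiny hypothetical series, derive the plotting info, and
# render it; the "Agg" backend choice is an assumption for headless use.
def _demo_vnc_pipeline():
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    years = np.array([1900.0, 1910.0, 1920.0, 1930.0])
    freq = np.array([1.0, 2.0, 8.0, 9.5])
    Z = _linkage_matrix(years, freq, distance_measure="sd")
    info = _vnc_calculate_info(Z, p=2, labels=[str(int(y)) for y in years])

    fig, ax = plt.subplots()
    _plot_dendrogram(
        info["icoord"], info["dcoord"], info["labels"],
        p=info["p"], n=info["n"], mh=info["mh"],
        orientation="top", no_labels=False,
        color_list=info["color_list"],
        contraction_marks=info["contraction_marks"],
        ax=ax,
    )
    return fig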