google-ngrams 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,809 @@
+ import copy
+ import numpy as np
+ import polars as pl
+ from textwrap import dedent
+
+ __all__ = [
+     "_linkage_matrix",
+     "_vnc_calculate_info",
+     "_plot_dendrogram",
+     "_cut_tree_simple"
+ ]
+
+
+ def _cut_tree_simple(Z, n_clusters=None, height=None):
+     if height is not None:
+         raise NotImplementedError(
+             "height-based cuts not supported in simplified implementation"
+         )
+     n = Z.shape[0] + 1
+     if n_clusters is None or n_clusters >= n:
+         return np.arange(n, dtype=int)
+     if n_clusters <= 1:
+         return np.zeros(n, dtype=int)
+     merges_to_apply = n - int(n_clusters)
+     labels = np.arange(n, dtype=int)
+     # node_leaves[i] holds the sorted leaf ids under node i
+     node_leaves = [None] * (2 * n - 1)
+     for i in range(n):
+         node_leaves[i] = [i]
+     for i in range(Z.shape[0]):
+         left = int(Z[i, 0])
+         right = int(Z[i, 1])
+         leaves = node_leaves[left] + node_leaves[right]
+         leaves.sort()
+         node_leaves[n + i] = leaves
+         if i < merges_to_apply:
+             block = leaves
+             block_labels = labels[block]
+             new_label = block_labels.min()
+             max_label = block_labels.max()
+             labels[block] = new_label
+             # compress labels above max_label downward
+             mask = labels > max_label
+             labels[mask] -= (max_label - new_label)
+         else:
+             break
+     return labels
+
+
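+ # Example sketch (synthetic data; `Z_demo` is illustrative only): cutting
+ # the two-merge linkage below into two clusters groups leaves 0 and 1 and
+ # leaves leaf 2 on its own.
+ #
+ #     Z_demo = np.array([[0.0, 1.0, 0.5, 2.0],
+ #                        [2.0, 3.0, 1.2, 3.0]])
+ #     _cut_tree_simple(Z_demo, n_clusters=2)  # -> array([0, 0, 1])
+
+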
+ def _remove_dups(L):
+     """
+     Remove duplicates AND preserve the original order of the elements.
+
+     The set class is not guaranteed to do this.
+     """
+     seen_before = set()
+     L2 = []
+     for i in L:
+         if i not in seen_before:
+             seen_before.add(i)
+             L2.append(i)
+     return L2
+
+
+ _dtextsizes = {20: 12, 30: 10, 50: 8, 85: 6, np.inf: 5}
+ _drotation = {20: 0, 40: 45, np.inf: 90}
+ _dtextsortedkeys = list(_dtextsizes.keys())
+ _dtextsortedkeys.sort()
+ _drotationsortedkeys = list(_drotation.keys())
+ _drotationsortedkeys.sort()
+
+
+ def _get_tick_text_size(p):
+     for k in _dtextsortedkeys:
+         if p <= k:
+             return _dtextsizes[k]
+
+
+ def _get_tick_rotation(p):
+     for k in _drotationsortedkeys:
+         if p <= k:
+             return _drotation[k]
+
+
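+ # Example: a 25-leaf dendrogram falls in the 30-leaf size band and the
+ # 40-leaf rotation band, so _get_tick_text_size(25) returns 10 and
+ # _get_tick_rotation(25) returns 45.
+
+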
+ def _contract_linkage_matrix(
+     Z: np.ndarray,
+     p=4
+ ) -> np.ndarray:
+     """
+     Contracts the linkage matrix by reducing the number of clusters
+     to a specified number.
+
+     Parameters
+     ----------
+     Z : np.ndarray
+         The linkage matrix.
+     p : int
+         The number of clusters to retain.
+
+     Returns
+     -------
+     np.ndarray
+         The contracted linkage matrix with updated cluster IDs
+         and member counts.
+     """
+     Z = Z.copy()
+     truncated_Z = Z[-(p - 1):]
+
+     n_points = Z.shape[0] + 1
+     clusters = [
+         dict(node_id=i, left=i, right=i, members=[i], distance=0, n_members=1)
+         for i in range(n_points)
+     ]
+     for z_i in range(Z.shape[0]):
+         row = Z[z_i]
+         left = int(row[0])
+         right = int(row[1])
+         cluster = dict(
+             node_id=z_i + n_points,
+             left=left,
+             right=right,
+             members=[],
+             distance=row[2],
+             n_members=int(row[3])
+         )
+         cluster["members"].extend(copy.deepcopy(clusters[left]["members"]))
+         cluster["members"].extend(copy.deepcopy(clusters[right]["members"]))
+         cluster["members"].sort()
+         clusters.append(cluster)
+
+     node_map = []
+     for i in range(truncated_Z.shape[0]):
+         node_ids = [int(truncated_Z[i, 0]), int(truncated_Z[i, 1])]
+         for cluster in clusters:
+             if cluster['node_id'] in node_ids:
+                 node_map.append(cluster)
+
+     filtered_node_map = []
+     superset_node_map = []
+
+     for node in node_map:
+         is_superset = False
+         for other_node in node_map:
+             if (
+                 node != other_node
+                 and set(
+                     node['members']
+                 ).issuperset(set(other_node['members']))
+             ):
+                 is_superset = True
+                 break
+         if is_superset:
+             superset_node_map.append(node)
+         else:
+             filtered_node_map.append(node)
+
+     # Add 'truncated_id' to each dictionary in filtered_node_map
+     for idx, node in enumerate(
+         sorted(filtered_node_map, key=lambda x: x['members'][0])
+     ):
+         node['truncated_id'] = idx
+         node['n_members'] = 1
+
+     for idx, node in enumerate(
+         sorted(superset_node_map, key=lambda x: x['node_id'])
+     ):
+         node['truncated_id'] = idx + len(filtered_node_map)
+
+     # Adjust 'n_members' in superset_node_map to reflect
+     # the number of filtered_node_map['members'] sets they contain
+     for superset_node in superset_node_map:
+         count = 0
+         for filtered_node in filtered_node_map:
+             if set(
+                 filtered_node['members']
+             ).issubset(set(superset_node['members'])):
+                 count += 1
+         superset_node['n_members'] = count
+
+     # Create a mapping from node_id to truncated_id and n_members
+     node_id_to_truncated_id = {
+         node['node_id']: node['truncated_id']
+         for node in filtered_node_map + superset_node_map
+     }
+     node_id_to_n_members = {
+         node['node_id']: node['n_members']
+         for node in filtered_node_map + superset_node_map
+     }
+
+     # Replace values in truncated_Z
+     for i in range(truncated_Z.shape[0]):
+         truncated_Z[i, 3] = (
+             node_id_to_n_members[int(truncated_Z[i, 0])] +
+             node_id_to_n_members[int(truncated_Z[i, 1])]
+         )
+         truncated_Z[i, 0] = node_id_to_truncated_id[int(truncated_Z[i, 0])]
+         truncated_Z[i, 1] = node_id_to_truncated_id[int(truncated_Z[i, 1])]
+
+     return truncated_Z
+
+
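+ # Example sketch (synthetic): for a linkage matrix `Z` with nine merges,
+ # _contract_linkage_matrix(Z, p=4) keeps only the last p - 1 = 3 rows,
+ # renumbers the surviving cluster ids consecutively from 0, and rewrites
+ # the fourth column to count retained clusters rather than leaves:
+ #
+ #     truncated = _contract_linkage_matrix(Z, p=4)  # shape (3, 4)
+
+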
+ def _contraction_mark_coordinates(
+     Z: np.ndarray,
+     p=4
+ ) -> list:
+     """
+     Generates contraction marks for a given linkage matrix.
+
+     Parameters
+     ----------
+     Z : np.ndarray
+         The linkage matrix.
+     p : int
+         The number of clusters to retain.
+
+     Returns
+     -------
+     list
+         A sorted list of tuples where each tuple contains
+         a calculated value based on truncated_id and a distance value.
+     """
+     Z = Z.copy()
+     truncated_Z = Z[-(p - 1):]
+
+     n_points = Z.shape[0] + 1
+     clusters = [
+         dict(node_id=i, left=i, right=i, members=[i], distance=0, n_members=1)
+         for i in range(n_points)
+     ]
+     for z_i in range(Z.shape[0]):
+         row = Z[z_i]
+         left = int(row[0])
+         right = int(row[1])
+         cluster = dict(
+             node_id=z_i + n_points,
+             left=left,
+             right=right,
+             members=[],
+             distance=row[2],
+             n_members=int(row[3])
+         )
+         cluster["members"].extend(copy.deepcopy(clusters[left]["members"]))
+         cluster["members"].extend(copy.deepcopy(clusters[right]["members"]))
+         cluster["members"].sort()
+         clusters.append(cluster)
+
+     node_map = []
+     for i in range(truncated_Z.shape[0]):
+         node_ids = [int(truncated_Z[i, 0]), int(truncated_Z[i, 1])]
+         for cluster in clusters:
+             if cluster['node_id'] in node_ids:
+                 node_map.append(cluster)
+
+     filtered_node_map = []
+     superset_node_map = []
+
+     for node in node_map:
+         is_superset = False
+         for other_node in node_map:
+             if (
+                 node != other_node
+                 and set(
+                     node['members']
+                 ).issuperset(set(other_node['members']))
+             ):
+                 is_superset = True
+                 break
+         if is_superset:
+             superset_node_map.append(node)
+         else:
+             filtered_node_map.append(node)
+
+     # Create a set of node_ids from filtered_node_map and superset_node_map
+     excluded_node_ids = set(
+         node['node_id'] for node in filtered_node_map
+     ).union(node['node_id'] for node in superset_node_map)
+
+     # Filter clusters that are not in excluded_node_ids
+     non_excluded_clusters = [
+         cluster for cluster in clusters
+         if cluster['node_id'] not in excluded_node_ids
+     ]
+
+     # Collect, for every retained cluster, the merge heights of the
+     # sub-clusters hidden inside it
+     subset_clusters = []
+
+     for filtered_cluster in filtered_node_map:
+         distances = []
+         for cluster in non_excluded_clusters:
+             if (
+                 cluster['n_members'] > 1
+                 and set(
+                     cluster['members']
+                 ).issubset(set(filtered_cluster['members']))
+             ):
+                 distances.append(cluster['distance'])
+         if distances:
+             subset_clusters.append(
+                 {'node_id': filtered_cluster['node_id'], 'distance': distances}
+             )
+
+     # Add 'truncated_id' to each dictionary in filtered_node_map
+     for idx, node in enumerate(
+         sorted(filtered_node_map, key=lambda x: x['members'][0])
+     ):
+         node['truncated_id'] = idx
+
+     # Create a mapping from node_id to truncated_id
+     node_id_to_truncated_id = {
+         node['node_id']: node['truncated_id'] for node in filtered_node_map
+     }
+
+     # Add 'truncated_id' to each dictionary in subset_clusters
+     for cluster in subset_clusters:
+         cluster['truncated_id'] = node_id_to_truncated_id[cluster['node_id']]
+
+     # Build (x, height) tuples in dendrogram coordinates
+     contraction_marks = []
+
+     for cluster in subset_clusters:
+         truncated_id = cluster['truncated_id']
+         for distance in cluster['distance']:
+             contraction_marks.append((10.0 * truncated_id + 5.0, distance))
+
+     # Sort the list of tuples
+     contraction_marks = sorted(contraction_marks, key=lambda x: (x[0], x[1]))
+
+     return contraction_marks
+
+
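+ # Example sketch: each returned tuple is (x, height), where
+ # x = 10.0 * truncated_id + 5.0 centres the mark under a collapsed leaf
+ # and height is the merge distance of a hidden sub-cluster, so the marks
+ # can be drawn directly in dendrogram coordinates:
+ #
+ #     marks = _contraction_mark_coordinates(Z, p=4)
+ #     # e.g. [(5.0, 0.2), (5.0, 0.7), (25.0, 0.4)] -- values depend on Z
+
+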
+ def _convert_linkage_to_coordinates(
+     Z: np.ndarray
+ ) -> dict:
+     """
+     Converts a linkage matrix to coordinates for plotting a dendrogram.
+
+     Parameters
+     ----------
+     Z : np.ndarray
+         The linkage matrix.
+
+     Returns
+     -------
+     dict
+         A dictionary containing 'icoord', 'dcoord', and 'ivl'
+         for plotting the dendrogram.
+     """
+     ivl = [i for i in range(Z.shape[0] + 1)]
+     n = len(ivl)
+     icoord = []
+     dcoord = []
+     clusters = {i: [i] for i in range(n)}
+     current_index = n
+     # leaves are centred at x = 5, 15, 25, ... (10 units apart)
+     positions = {i: (i + 1) * 10 - 5 for i in range(n)}
+     heights = {i: 0 for i in range(n)}
+
+     for i in range(len(Z)):
+         cluster1 = int(Z[i, 0])
+         cluster2 = int(Z[i, 1])
+         dist = Z[i, 2].item()
+         new_cluster = clusters[cluster1] + clusters[cluster2]
+         clusters[current_index] = new_cluster
+
+         x1 = positions[cluster1]
+         x2 = positions[cluster2]
+         x_new = (x1 + x2) / 2
+         positions[current_index] = x_new
+
+         h1 = heights[cluster1]
+         h2 = heights[cluster2]
+         heights[current_index] = dist
+
+         icoord.append([x1, x1, x2, x2])
+         dcoord.append([h1, dist, dist, h2])
+
+         current_index += 1
+
+     # Sort icoord and dcoord by the first element in each icoord list
+     sorted_indices = sorted(range(len(icoord)), key=lambda i: icoord[i][0])
+     icoord = [icoord[i] for i in sorted_indices]
+     dcoord = [dcoord[i] for i in sorted_indices]
+
+     return {"icoord": icoord, "dcoord": dcoord, "ivl": ivl}
+
+
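+ # Example (synthetic two-merge linkage): leaves sit at x = 5, 15, 25, and
+ # each merge contributes one four-point bracket:
+ #
+ #     Z_demo = np.array([[0.0, 1.0, 0.5, 2.0],
+ #                        [2.0, 3.0, 1.2, 3.0]])
+ #     _convert_linkage_to_coordinates(Z_demo)
+ #     # {'icoord': [[5, 5, 15, 15], [25, 25, 10, 10]],
+ #     #  'dcoord': [[0, 0.5, 0.5, 0], [0, 1.2, 1.2, 0.5]],
+ #     #  'ivl': [0, 1, 2]}
+
+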
+ def _vnc_calculate_info(
+     Z: np.ndarray,
+     p=None,
+     truncate=False,
+     contraction_marks=False,
+     labels=None
+ ) -> dict:
+     Z = Z.copy()
+     Zs = Z.shape
+     n = Zs[0] + 1
+
+     if labels is not None and len(labels) != n:
+         labels = None
+         print(dedent(
+             """
+             Dimensions of Z and labels are not consistent.
+             Using default labels.
+             """))
+     if labels is None:
+         labels = [str(i) for i in range(n)]
+
+     # p outside the valid range disables truncation and contraction marks
+     if p is not None and (p > n or p < 2):
+         p = None
+         truncate = False
+         contraction_marks = False
+
+     if p is not None:
+         cluster_assignment = [i.item() for i in _cut_tree_simple(Z, p)]
+
+         # Create a dictionary to hold the clusters
+         cluster_dict = {}
+
+         # Iterate over the labels and clusters to populate the dictionary
+         for label, cluster in zip(labels, cluster_assignment):
+             cluster_key = f'cluster_{cluster + 1}'
+             if cluster_key not in cluster_dict:
+                 cluster_dict[cluster_key] = []
+             cluster_dict[cluster_key].append(label)
+
+         # Convert the dictionary to a list of dictionaries
+         cluster_list = [{key: value} for key, value in cluster_dict.items()]
+
+         # Create a new list to hold the cluster labels
+         cluster_labels = []
+
+         # Iterate over the cluster_list to create the labels
+         for cluster in cluster_list:
+             for key, value in cluster.items():
+                 if len(value) == 1:
+                     cluster_labels.append(str(value[0]))
+                 else:
+                     cluster_labels.append(f"{value[0]}-{value[-1]}")
+
+         # get distance for plotting cut line
+         dist = [x[2].item() for x in Z]
+         dist_threshold = np.mean(
+             [dist[len(dist) - p + 1], dist[len(dist) - p]]
+         )
+     else:
+         dist_threshold = None
+         cluster_list = None
+         cluster_labels = None
+
+     if truncate is True:
+         truncated_Z = _contract_linkage_matrix(Z, p=p)
+
+         if contraction_marks is True:
+             contraction_marks = _contraction_mark_coordinates(Z, p=p)
+         else:
+             contraction_marks = None
+
+         Z = truncated_Z
+     else:
+         contraction_marks = None
+
+     R = _convert_linkage_to_coordinates(Z)
+
+     mh = np.max(Z[:, 2])
+     Zn = Z.shape[0] + 1
+     color_list = ['k'] * (Zn - 1)
+     leaves_color_list = ['k'] * Zn
+     R['n'] = Zn
+     R['mh'] = mh
+     R['p'] = p
+     R['labels'] = labels
+     R['color_list'] = color_list
+     R['leaves_color_list'] = leaves_color_list
+     R['clusters'] = cluster_list
+     R['cluster_labels'] = cluster_labels
+     R['dist_threshold'] = dist_threshold
+     R["contraction_marks"] = contraction_marks
+
+     return R
+
+
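+ # Example sketch: the returned dict bundles the dendrogram coordinates with
+ # plotting metadata ('n', 'mh', 'p', 'labels', 'color_list', 'clusters',
+ # 'cluster_labels', 'dist_threshold', 'contraction_marks'). For a four-leaf
+ # linkage `hc` over the years 2000-2003 in which adjacent years merge first:
+ #
+ #     info = _vnc_calculate_info(hc, p=2,
+ #                                labels=['2000', '2001', '2002', '2003'])
+ #     info['cluster_labels']  # -> ['2000-2001', '2002-2003']
+
+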
+ def _linkage_matrix(
+     time_series,
+     frequency,
+     distance_measure='sd'
+ ) -> np.ndarray:
+
+     # guard against a silent fall-through: an unknown measure would leave
+     # difference_checker empty and make np.argmin fail obscurely below
+     if distance_measure not in ('sd', 'cv'):
+         raise ValueError("distance_measure must be 'sd' or 'cv'")
+
+     input_values = frequency.copy()
+     years = time_series.copy()
+
+     data_collector = {}
+     data_collector["0"] = input_values
+     position_collector = {}
+     position_collector[1] = 0
+     overall_distance = 0
+     number_of_steps = len(input_values) - 1
+
+     for i in range(1, number_of_steps + 1):
+         difference_checker = []
+         unique_years = np.unique(years)
+
+         for j in range(len(unique_years) - 1):
+             first_name = unique_years[j]
+             second_name = unique_years[j + 1]
+             pooled_sample = input_values[
+                 np.isin(years, [first_name, second_name])
+             ]
+
+             if distance_measure == "sd":
+                 difference_checker.append(
+                     0 if np.sum(pooled_sample) == 0
+                     else np.std(pooled_sample, ddof=1)
+                 )
+             elif distance_measure == "cv":
+                 difference_checker.append(
+                     0 if np.sum(pooled_sample) == 0
+                     else np.std(pooled_sample, ddof=1) / np.mean(pooled_sample)
+                 )
+
+         pos_to_be_merged = np.argmin(difference_checker)
+         distance = np.min(difference_checker)
+         overall_distance += distance
+         lower_name = unique_years[pos_to_be_merged]
+         higher_name = unique_years[pos_to_be_merged + 1]
+
+         matches = np.isin(years, [lower_name, higher_name])
+         new_mean_age = round(np.mean(years[matches]), 4)
+         position_collector[i + 1] = np.where(matches)[0] + 1
+         years[matches] = new_mean_age
+         data_collector[f"{i}: {distance}"] = input_values
+
+     hc_build = pl.DataFrame({
+         'start': [
+             min(pos) if isinstance(pos, (list, np.ndarray)) else pos
+             for pos in position_collector.values()
+         ],
+         'end': [
+             max(pos) if isinstance(pos, (list, np.ndarray)) else pos
+             for pos in position_collector.values()
+         ]
+     })
+
+     idx = np.arange(len(hc_build))
+
+     y = [np.where(
+         hc_build['start'].to_numpy()[:i] == hc_build['start'].to_numpy()[i]
+     )[0] for i in idx]
+     z = [np.where(
+         hc_build['end'].to_numpy()[:i] == hc_build['end'].to_numpy()[i]
+     )[0] for i in idx]
+
+     merge1 = [
+         y[i].max().item() if len(y[i]) else np.nan for i in range(len(y))
+     ]
+     merge2 = [
+         z[i].max().item() if len(z[i]) else np.nan for i in range(len(z))
+     ]
+
+     hc_build = (
+         hc_build.with_columns([
+             pl.Series('merge1', [
+                 min(m1, m2) if not np.isnan(m1) and not np.isnan(m2)
+                 else np.nan for m1, m2 in zip(merge1, merge2)
+             ]),
+             pl.Series('merge2', [
+                 max(m1, m2) if not np.isnan(m1)
+                 else m2 for m1, m2 in zip(merge1, merge2)
+             ])
+         ])
+     )
+
+     hc_build = (
+         hc_build.with_columns([
+             pl.when(
+                 pl.col('merge1').is_nan() &
+                 pl.col('merge2').is_nan()
+             ).then(-pl.col('start')
+             ).otherwise(pl.col('merge1')).alias('merge1'),
+             pl.when(
+                 pl.col('merge2').is_nan()
+             ).then(-pl.col('end')
+             ).otherwise(pl.col('merge2')).alias('merge2')
+         ])
+     )
+
+     to_merge = [-np.setdiff1d(
+         hc_build.select(pl.col('start', 'end')).row(i),
+         hc_build.select(
+             pl.col('start', 'end')
+         ).slice(1, i - 1).to_numpy().flatten()
+     ) for i in idx]
+
+     to_merge = [x[0].item() if len(x) > 0 else np.nan for x in to_merge]
+
+     hc_build = (
+         hc_build
+         .with_columns([
+             pl.when(pl.col('merge1').is_nan()
+             ).then(pl.Series(to_merge, strict=False)
+             ).otherwise(pl.col('merge1')).alias('merge1')
+         ])
+     )
+
+     hc_build = hc_build.with_row_index()
+     n = hc_build.height
+
+     # convert signed merge references to scipy-style node ids:
+     # negative values are leaves (0-based), positive values are
+     # earlier merges offset by the number of leaves
+     hc_build = (hc_build
+                 .with_columns(
+                     pl.when(pl.col("merge1").lt(0))
+                     .then(pl.col("merge1").mul(-1).sub(1))
+                     .otherwise(pl.col('merge1').add(n - 1)).alias('merge1')
+                 )
+                 .with_columns(
+                     pl.when(pl.col("merge2").lt(0))
+                     .then(pl.col("merge2").mul(-1).sub(1))
+                     .otherwise(pl.col('merge2').add(n - 1)).alias('merge2')
+                 )
+                 )
+
+     hc_build = (
+         hc_build
+         .with_columns(distance=np.array(list(data_collector.keys())))
+         .with_columns(pl.col("distance").str.replace(r"(\d+: )", ""))
+         .with_columns(pl.col("distance").cast(pl.Float64))
+         .with_columns(pl.col("distance").cum_sum().alias("distance"))
+     )
+
+     size = np.array([
+         len(x) if isinstance(x, (list, np.ndarray))
+         else 1 for x in position_collector.values()
+     ])
+
+     hc_build = (
+         hc_build
+         .with_columns(size=size)
+         .with_columns(pl.col("size").cast(pl.Float64))
+     )
+
+     # drop the dummy first row that held the initial state
+     hc_build = hc_build.filter(pl.col("index") != 0)
+
+     hc = hc_build.select("merge1", "merge2", "distance", "size").to_numpy()
+     return hc
+
+
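+ # Example (synthetic): four yearly frequencies yield a scipy-style
+ # (n - 1) x 4 linkage matrix whose rows are
+ # [node id, node id, cumulative distance, cluster size]:
+ #
+ #     yrs = np.array([2000.0, 2001.0, 2002.0, 2003.0])
+ #     freq = np.array([1.0, 1.1, 5.0, 5.2])
+ #     _linkage_matrix(yrs, freq, distance_measure='sd').shape  # (3, 4)
+
+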
+ def _plot_dendrogram(
+     icoords,
+     dcoords,
+     ivl,
+     p,
+     n,
+     mh,
+     orientation,
+     no_labels,
+     color_list,
+     leaf_font_size=None,
+     leaf_rotation=None,
+     contraction_marks=None,
+     ax=None,
+     above_threshold_color='C0'
+ ):
+     # Import matplotlib here so that it's not imported unless dendrograms
+     # are plotted. Raise an informative error if importing fails.
+     try:
+         # patches and collections are needed on every code path below
+         import matplotlib.patches
+         import matplotlib.collections
+         # if an axis is provided, don't use pylab at all
+         if ax is None:
+             import matplotlib.pylab
+     except ImportError as e:
+         raise ImportError("You must install the matplotlib library to plot "
+                           "the dendrogram. Use no_plot=True to calculate the "
+                           "dendrogram without plotting.") from e
+
+     if ax is None:
+         ax = matplotlib.pylab.gca()
+         # if we're using pylab, we want to trigger a draw at the end
+         trigger_redraw = True
+     else:
+         trigger_redraw = False
+
+     # Independent variable plot width
+     ivw = len(ivl) * 10
+     # Dependent variable plot height
+     dvw = mh + mh * 0.05
+
+     iv_ticks = np.arange(5, len(ivl) * 10 + 5, 10)
+     if orientation in ('top', 'bottom'):
+         if orientation == 'top':
+             ax.set_ylim([0, dvw])
+             ax.set_xlim([0, ivw])
+         else:
+             ax.set_ylim([dvw, 0])
+             ax.set_xlim([0, ivw])
+
+         xlines = icoords
+         ylines = dcoords
+         if no_labels:
+             ax.set_xticks([])
+             ax.set_xticklabels([])
+         else:
+             ax.set_xticks(iv_ticks)
+
+             if orientation == 'top':
+                 ax.xaxis.set_ticks_position('bottom')
+             else:
+                 ax.xaxis.set_ticks_position('top')
+
+             # Make the tick marks invisible because they cover up the links
+             for line in ax.get_xticklines():
+                 line.set_visible(False)
+
+             leaf_rot = (float(_get_tick_rotation(len(ivl)))
+                         if (leaf_rotation is None) else leaf_rotation)
+             leaf_font = (float(_get_tick_text_size(len(ivl)))
+                          if (leaf_font_size is None) else leaf_font_size)
+             ax.set_xticklabels(ivl, rotation=leaf_rot, size=leaf_font)
+
+     elif orientation in ('left', 'right'):
+         if orientation == 'left':
+             ax.set_xlim([dvw, 0])
+             ax.set_ylim([0, ivw])
+         else:
+             ax.set_xlim([0, dvw])
+             ax.set_ylim([0, ivw])
+
+         xlines = dcoords
+         ylines = icoords
+         if no_labels:
+             ax.set_yticks([])
+             ax.set_yticklabels([])
+         else:
+             ax.set_yticks(iv_ticks)
+
+             if orientation == 'left':
+                 ax.yaxis.set_ticks_position('right')
+             else:
+                 ax.yaxis.set_ticks_position('left')
+
+             # Make the tick marks invisible because they cover up the links
+             for line in ax.get_yticklines():
+                 line.set_visible(False)
+
+             leaf_font = (float(_get_tick_text_size(len(ivl)))
+                          if (leaf_font_size is None) else leaf_font_size)
+
+             if leaf_rotation is not None:
+                 ax.set_yticklabels(ivl, rotation=leaf_rotation,
+                                    size=leaf_font)
+             else:
+                 ax.set_yticklabels(ivl, size=leaf_font)
+
+     # Let's use collections instead. This way there is a separate legend item
+     # for each tree grouping, rather than stupidly one for each line segment.
+     colors_used = _remove_dups(color_list)
+     color_to_lines = {}
+     for color in colors_used:
+         color_to_lines[color] = []
+     for (xline, yline, color) in zip(xlines, ylines, color_list):
+         color_to_lines[color].append(list(zip(xline, yline)))
+
+     colors_to_collections = {}
+     # Construct the collections.
+     for color in colors_used:
+         coll = matplotlib.collections.LineCollection(color_to_lines[color],
+                                                      colors=(color,))
+         colors_to_collections[color] = coll
+
+     # Add all the groupings below the color threshold.
+     for color in colors_used:
+         if color != above_threshold_color:
+             ax.add_collection(colors_to_collections[color])
+     # If there's a grouping of links above the color threshold, it goes last.
+     if above_threshold_color in colors_to_collections:
+         ax.add_collection(colors_to_collections[above_threshold_color])
+
+     if contraction_marks is not None:
+         Ellipse = matplotlib.patches.Ellipse
+         for (x, y) in contraction_marks:
+             if orientation in ('left', 'right'):
+                 e = Ellipse((y, x), width=dvw / 100, height=1.0)
+             else:
+                 e = Ellipse((x, y), width=1.0, height=dvw / 100)
+             ax.add_artist(e)
+             e.set_clip_box(ax.bbox)
+             e.set_alpha(0.5)
+             e.set_facecolor('k')
+
+     if trigger_redraw:
+         matplotlib.pylab.draw_if_interactive()
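+
+
+ if __name__ == "__main__":
+     # Minimal end-to-end sketch (assumes matplotlib is installed; the data
+     # are synthetic and purely illustrative).
+     import matplotlib.pyplot as plt
+
+     demo_years = np.array([2000.0, 2001.0, 2002.0, 2003.0])
+     demo_freq = np.array([1.0, 1.1, 5.0, 5.2])
+     hc = _linkage_matrix(demo_years, demo_freq, distance_measure="sd")
+     info = _vnc_calculate_info(hc, labels=["2000", "2001", "2002", "2003"])
+
+     fig, ax = plt.subplots()
+     _plot_dendrogram(
+         info["icoord"], info["dcoord"], info["labels"],
+         p=info["p"], n=info["n"], mh=info["mh"],
+         orientation="top", no_labels=False,
+         color_list=info["color_list"], ax=ax
+     )
+     plt.show()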