cccpm 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,821 @@
1
+ import math
2
+ import warnings
3
+
4
+ import matplotlib
5
+ from matplotlib import patches
6
+ from nilearn._utils.param_validation import check_threshold
7
+ from scipy import interpolate
8
+
9
+ from collections import defaultdict
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ from math import radians
13
+ from typing import Union, Tuple
14
+ from collections import OrderedDict
15
+
16
+ from scipy.stats import scoreatpercentile
17
+
18
+
19
class ROI_to_degree:
    """
    Map each ROI index to an angular position (in degrees) on the chord rim.

    Within each network's rim segment, ROIs are spread evenly in order of
    their y coordinate (coords[idx][1]), lowest y first: the k-th ROI of a
    network holding n ROIs sits at fraction (k + 0.5) / n of the way from the
    segment's low degree to its high degree. When no coordinates are
    available, plot_chord(...) passes random ones, so the order is random.
    """

    def __init__(self, coords, idx_to_label, network_low_high, network_counts):
        """
        :param coords: sequence of coordinates; coords[idx][1] is the y value
            of ROI idx. Every position of coords must be a key of
            idx_to_label.
        :param idx_to_label: dict mapping ROI index -> network label
        :param network_low_high: dict mapping label -> (low_deg, high_deg)
        :param network_counts: dict mapping label -> number of ROIs in it
        """
        self.idx_to_degree = {}
        placed = defaultdict(int)  # how many ROIs each network has so far
        ordered = sorted(enumerate(coords), key=lambda item: item[1][1])
        for roi_idx, _coord in ordered:
            label = idx_to_label[roi_idx]
            low, high = network_low_high[label]
            slot = placed[label]
            fraction = (slot + 0.5) / network_counts[label]
            self.idx_to_degree[roi_idx] = low + fraction * (high - low)
            placed[label] = slot + 1

    def get_degree(self, idx):
        """Return the angular position (degrees) assigned to ROI ``idx``."""
        return self.idx_to_degree[idx]
+
44
+
45
+ def _network_order_colors_check(network_order, network_colors, idx_to_label):
46
+ if network_order is None:
47
+ network_order = sorted(list(set(idx_to_label.values())))
48
+ else:
49
+ idx_labels = set(idx_to_label.values())
50
+ network_order = [network for network in network_order if
51
+ network in idx_labels]
52
+ if idx_labels != set(network_order):
53
+ missing_labels = idx_labels - set(network_order)
54
+ raise ValueError('network_order, if specified, must contain all '
55
+ f'idx_to_label labels. '
56
+ f'Missing labels: {missing_labels}')
57
+
58
+ if network_colors is None:
59
+ default_colors = []
60
+ for i in range(1 + len(network_order) // 10):
61
+ default_colors += plt.rcParams['axes.prop_cycle'].by_key()['color']
62
+ network_colors = dict((network, color) for network, color in
63
+ zip(network_order, default_colors[:len(network_order)]))
64
+ else:
65
+ idx_labels = set(idx_to_label.values())
66
+ if not idx_labels.issubset(set(network_colors.keys())):
67
+ missing_labels = set(idx_labels) - set(network_colors.keys())
68
+ raise ValueError('network_colors, if specified, must describe all '
69
+ f'idx_to_label labels. '
70
+ f'Missing labels: {missing_labels}')
71
+ return network_order, network_colors
72
+
73
+
74
def plot_chord(idx_to_label: dict,
               edges: Union[list, np.ndarray],
               fp_chord: Union[str, None] = None,
               edge_weights: Union[list, np.ndarray, None] = None,
               network_order: Union[list, None] = None,
               network_colors: Union[dict, None] = None,
               colors: Union[None, str, tuple, list] = None,
               linewidths: Union[None, float, int, list] = None,
               alphas: Union[None, float, int, list] = None,
               cmap: Union[None, str, matplotlib.colors.Colormap] = None,
               coords: Union[list, np.ndarray, None] = None,
               arc_setting: bool = True,
               cbar: Union[None, plt.colorbar] = None,
               do_ROI_circles: bool = False,
               do_ROI_circles_specific: bool = True,
               ROI_circle_radius: float = 0.005,
               black_BG: bool = False,
               label_fontsize: int = 60,
               do_monkeypatch: bool = True,
               vmin: Union[None, int, float] = None,
               vmax: Union[None, int, float] = None,
               plot_count: bool = False,
               plot_abs_sum: bool = False,
               norm_thickness: bool = False,
               edge_threshold: Union[float, int, str] = 0.,
               dpi: int = 400) -> None:
    """
    Plot a chord diagram and either save it to ``fp_chord`` or leave it open
    as the current matplotlib figure. For most arguments, passing None makes
    the function generate and use a default.

    :param idx_to_label: dict mapping each ROI index to its chord label
        (e.g., {0: "FPCN"})
    :param edges: list of pairs (a, b), where a = index of the 1st ROI and
        b = index of the 2nd ROI
    :param fp_chord: filepath to save the chord diagram. If given, the figure
        is saved at ``dpi`` and then closed. If None, the figure is left open
        as the current figure (call plt.show() to display it in a window).
    :param edge_weights: list of edge weights; the nth weight corresponds to
        the nth edge. If None, every edge gets weight 1 (unweighted diagram:
        black arcs of equal width).
    :param network_order: list specifying the order of the labels (rims)
    :param network_colors: dict mapping each label to a matplotlib color
    :param colors: arc color(s). A single color applies to every arc; a list
        gives one color per edge (kept aligned with ``edges`` when
        ``edge_threshold`` removes edges). If None, colors are derived from
        the edge weights.
    :param linewidths: same conventions as ``colors``, for arc widths. If
        None, widths are generated from the edge weights.
    :param alphas: same conventions as ``colors``, for arc transparency.
    :param cmap: None, a matplotlib colormap name, or a Colormap. It is
        resolved here and forwarded to plot_arcs(...), which colors arcs with
        its own diverging palette, so it currently has no visible effect.
    :param coords: list of coordinates such that the nth entry corresponds to
        ROI n; the y coordinate (index 1) orders ROIs within their rim. If
        None, random coordinates are generated (this assumes the ROI indices
        are 0..len(idx_to_label)-1).
    :param arc_setting: forwarded to plot_arcs(...) as ``seven_point_arc``;
        True/False give slightly different arc shapes.
    :param cbar: currently unused; no colorbar is drawn.
    :param do_ROI_circles: if True, draw a small circle at each ROI's position
        on the rim.
    :param do_ROI_circles_specific: used with do_ROI_circles; if True, only
        ROIs that appear in the (thresholded) edges get a circle.
    :param ROI_circle_radius: radius used for the ROI circles
    :param black_BG: forwarded to plot_rim_and_labels(...), which currently
        does not use it.
    :param label_fontsize: font size of the rim labels
    :param do_monkeypatch: forwarded to the rim-label drawing; see
        plot_rim_label(...).
    :param vmin: lower weight bound forwarded to plot_arcs(...); defaults to
        the minimum (thresholded) edge weight.
    :param vmax: upper weight bound forwarded to plot_arcs(...); defaults to
        the maximum (thresholded) edge weight.
    :param plot_count: draw one arc per network pair, whose thickness depends
        on the number of edges between the two networks.
    :param plot_abs_sum: like plot_count, but thickness depends on the sum of
        the absolute weights between the two networks.
    :param norm_thickness: with plot_count/plot_abs_sum, normalize thickness
        by the networks' rim sizes (see plot_arcs(...)).
    :param edge_threshold: This parameter acts the same as edge_threshold in
        nilearn.plotting.plot_connectome. Edges whose abs(weight) does not
        exceed the threshold are omitted. edge_threshold can be a number or a
        string representing a percentile (e.g., "25%"); the latter omits
        edges below that percentile of abs(weight). Only applied when
        edge_weights is given.
    :param dpi: resolution used when saving to fp_chord.
    :raises ValueError: if network_order/network_colors miss labels, or if no
        edges remain to plot.
    """
    network_order, network_colors = \
        _network_order_colors_check(network_order, network_colors, idx_to_label)

    if edge_weights is None:
        edge_weights = [1] * len(edges)
    else:
        # Threshold the edge *positions* together with the weights, so that
        # list-valued colors/linewidths/alphas can be filtered identically and
        # stay aligned with the surviving edges.
        edge_weights, kept = _threshold_proc(edge_threshold, edge_weights,
                                             list(range(len(edges))))
        edges = [edges[k] for k in kept]
        colors, linewidths, alphas = (
            [style[k] for k in kept] if isinstance(style, list) else style
            for style in (colors, linewidths, alphas))

    if len(edges) == 0:
        raise ValueError('No edges left to plot (edges whose abs(weight) '
                         'does not exceed edge_threshold are removed).')

    if vmin is None:
        vmin = min(edge_weights)
    if vmax is None:
        vmax = max(edge_weights)
    # NOTE: vmin == vmax is legitimate (e.g. unweighted edges, all weight 1),
    # so it is no longer rejected here.

    if cmap is None:
        cmap = plt.get_cmap('Greys' if abs(vmin - vmax) < 1e-6 else 'turbo')
    elif isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    if coords is None:
        coords = np.random.random(size=(len(idx_to_label), 3))

    fig = plt.figure(figsize=(15, 15))
    radius = 0.6
    network_low_high, network_counts, network_centers, network_starts_ends = \
        plot_rim_and_labels(idx_to_label, network_order, network_colors, radius,
                            black_BG=black_BG, label_fontsize=label_fontsize,
                            do_monkeypatch=do_monkeypatch)

    plot_arcs(edges, idx_to_label, network_low_high, network_counts,
              edge_weights, network_centers, network_starts_ends,
              radius, cmap, coords, linewidths=linewidths,
              seven_point_arc=arc_setting, colors=colors,
              alphas=alphas, vmin=vmin, vmax=vmax,
              plot_count=plot_count,
              plot_abs_sum=plot_abs_sum,
              norm_thickness=norm_thickness)

    if do_ROI_circles:
        exclusive_idx = None
        if do_ROI_circles_specific:
            exclusive_idx = {roi for edge in edges for roi in edge}
        plot_ROI_circles(coords, idx_to_label, network_low_high,
                         network_counts, network_colors, radius,
                         exclusive_idx=exclusive_idx,
                         ROI_circle_radius=ROI_circle_radius)

    plt.axis('off')
    plt.tight_layout()
    if fp_chord is not None:
        plt.savefig(fp_chord, dpi=dpi)
        # Close (not just clear) so repeated calls don't accumulate figures.
        plt.close(fig)
201
+
202
+
203
def plot_rim_and_labels(idx_to_label: dict, network_order: list,
                        network_colors: dict, radius: Union[float, int],
                        rim_border: Union[float, int] = 1.0,
                        black_BG: bool = False,
                        label_fontsize: int = 60,
                        do_monkeypatch: bool = True) -> Tuple[dict, dict,
                                                              dict, dict]:
    """
    Draw every rim segment and its label. Networks are laid out in
    ``network_order`` with increasing degrees starting at 0, and each network
    gets a share of the 360 degrees proportional to its number of ROIs.

    :param idx_to_label: dict mapping each ROI index to its chord label
    :param network_order: list specifying the order of the labels (rims)
    :param network_colors: dict mapping each label to a matplotlib color
    :param radius: radius of the outer edge of the rims
    :param rim_border: degrees of white spacing left on each side of a rim
    :param black_BG: currently unused
    :param label_fontsize: font size of the rim labels
    :param do_monkeypatch: forwarded to plot_rim_label(...)
    :return: (network_low_high, network_counts, network_center,
        network_starts_ends), all keyed by label:
        network_low_high: (start, end) degrees of the rim, excluding the
            rim_border whitespace on both sides
        network_counts: number of ROIs (a defaultdict returning 0 for
            unknown labels)
        network_center: center degree of the rim
        network_starts_ends: (start, end) degrees including the whitespace
    """
    total_rois = len(idx_to_label)

    network_counts = defaultdict(int)
    for label in idx_to_label.values():
        network_counts[label] += 1

    network_low_high, network_center, network_starts_ends = {}, {}, {}
    rois_before = 0  # ROIs belonging to networks already laid out
    for network in network_order:
        n_rois = network_counts[network]
        degree_st = rois_before / total_rois * 360
        degree_end = (rois_before + n_rois) / total_rois * 360

        plot_rim(degree_st, degree_end, rim_border=rim_border, radius=radius,
                 color=network_colors[network])
        plot_rim_label(degree_st, degree_end, network,
                       radius=radius,
                       label_fontsize=label_fontsize,
                       do_monkeypatch=do_monkeypatch)

        network_low_high[network] = (degree_st + rim_border,
                                     degree_end - rim_border)
        network_center[network] = \
            (rois_before + n_rois * 0.5) / total_rois * 360
        network_starts_ends[network] = (degree_st, degree_end)
        rois_before += n_rois

    return network_low_high, network_counts, network_center, network_starts_ends
252
+
253
+
254
def get_character_degree_locations(text: str, degree_st: Union[float, int],
                                   degree_end: Union[float, int],
                                   rim_border: Union[float, int],
                                   font_kwargs: dict) -> Tuple[list, str]:
    """
    Compute an angular position (in degrees) for each character of ``text``,
    spacing the characters by their rendered widths and centering them on the
    rim segment [degree_st, degree_end]. If the text does not fit within the
    segment (minus ``rim_border`` on each side), it is shortened and ends
    with '.'.

    Widths are measured by rendering temporary invisible Text artists on the
    current axes; each is removed right after it is measured, and each
    distinct character is measured only once.

    :param text: the string written around the rim
    :param degree_st: starting degree of the rim (whitespace included)
    :param degree_end: ending degree of the rim (whitespace included)
    :param rim_border: degrees of white spacing on each side of the rim
    :param font_kwargs: text settings passed to plt.text (e.g. fontname,
        fontsize); a 'fontname' of 'monospace' enables small kerning tweaks
    :return: (char_degs, text), where char_degs[i] is the degree computed for
        character i of the returned (possibly shortened) text. An empty text
        yields ([], '').
    """
    if not text:
        return [], text

    renderer = plt.gcf().canvas.get_renderer()
    ax = plt.gca()
    widths = {}

    def _width(string):
        """Rendered width of ``string`` in data units (cached per string)."""
        if string not in widths:
            probe = plt.text(0, 0, string, rotation=0,
                             rotation_mode='anchor', ha='center',
                             alpha=0, **font_kwargs)
            bbox = probe.get_window_extent(renderer=renderer).transformed(
                ax.transData.inverted())
            probe.remove()  # don't leave invisible artists on the axes
            widths[string] = bbox.width
        return widths[string]

    rim_span = degree_end - degree_st
    usable_span = rim_span - rim_border * 2
    arc_len = 2 * np.pi * usable_span / 360
    is_monospace = ('fontname' in font_kwargs and
                    font_kwargs['fontname'].lower() == 'monospace')

    char_degs = [0]
    char_deg = 0
    for char_i in range(len(text) - 1):
        char0 = text[char_i]
        char1 = text[char_i + 1]
        width = _width(char0) / 2 + _width(char1) / 2
        width *= 1.1  # a little extra room between neighbouring characters
        boost = 1
        if is_monospace:
            # Given the shapes of some letters, they look a bit better when
            # pushed to be closer to the previous/following letter in the
            # text. These hand-tuned offsets were chosen for the Yeo atlas
            # labels and a few other label sets; they may not suit every set.
            if char1 in ['D', 'P', 'C']:
                boost -= .4
            elif char0 in ['A']:
                boost -= .8
            elif char1 in ['A']:
                boost += .4

        char_deg += width / arc_len * rim_span + boost
        if char_deg > usable_span:
            # Out of room: keep the characters placed so far, with the last
            # placed one replaced by '.'.
            text = text[:len(char_degs) - 1] + '.'
            break
        char_degs.append(char_deg)

    # Shift every position so that their mean sits at the rim's midpoint.
    mean_deg = sum(char_degs) / len(char_degs)
    shift = (degree_end + degree_st) / 2 - mean_deg
    char_degs = [deg + shift for deg in char_degs]
    if text[-1] == '.':
        # A period only takes up a small fraction of its horizontal space, so
        # shift all the characters back a bit to keep the label looking
        # centered.
        subtract = _width('.') / arc_len * rim_span * .75
        char_degs = [deg - subtract for deg in char_degs]

    return char_degs, text
333
+
334
+
335
def plot_rim_label(degree_st: Union[float, int], degree_end: Union[float, int],
                   text: str,
                   radius: float = 0.55,
                   label_fontsize: Union[float, int] = 60,
                   do_monkeypatch: bool = False):
    """
    Write ``text`` as one straight label just outside the rim segment
    [degree_st, degree_end], anchored at 1.05 * radius on the segment's
    midpoint and rotated to run radially outward. Labels whose midpoint lies
    in (90, 270] degrees are rotated by an extra 180 degrees and
    right-aligned, so the text still ends at the rim (presumably to avoid
    upside-down text).

    :param degree_st: starting degree of the rim (whitespace included)
    :param degree_end: ending degree of the rim (whitespace included)
    :param text: the string written next to the rim
    :param radius: radius of the chord diagram
    :param label_fontsize: label font size (the font is monospace)
    :param do_monkeypatch: if True, try to apply nichord's patch_RendererAgg
        fix for rotated characters. If nichord cannot be imported, a warning
        is issued and the label is drawn without the patch.
    """
    if do_monkeypatch:
        try:
            from nichord import patch_RendererAgg
        except ImportError:
            warnings.warn('do_monkeypatch=True but nichord.patch_RendererAgg '
                          'could not be imported; drawing the label without '
                          'the patch.')
        else:
            patch_RendererAgg.do_monkey_patch()

    # monospace is the easiest font to work with
    font_kwargs = {'fontname': 'monospace', 'fontsize': label_fontsize}

    # Anchor the label slightly outside the rim, at the segment's midpoint.
    mid_deg = (degree_st + degree_end) / 2
    mid_rad = np.deg2rad(mid_deg)
    x = np.cos(mid_rad) * radius * 1.05
    y = np.sin(mid_rad) * radius * 1.05

    if mid_deg <= 90 or mid_deg > 270:
        rotation_angle, ha = mid_deg, 'left'
    else:
        rotation_angle, ha = mid_deg + 180, 'right'

    plt.text(x, y, text, rotation=rotation_angle, rotation_mode='anchor',
             va='center', ha=ha, color='k', **font_kwargs)
378
+
379
+
380
def plot_rim(degree_st: Union[float, int], degree_end: Union[float, int],
             rim_border: Union[float, int] = 1,
             radius: float = 0.55, color: Union[str, tuple] = 'black'):
    """
    Draw one rim segment of the chord diagram: a colored Wedge of the full
    radius (inset by ``rim_border`` degrees on each side), covered by a white
    Wedge 0.02 smaller that spans the whole segment, which leaves a thin
    colored band. Also fixes the axes to an equal aspect ratio with limits
    [-0.65, 0.65] on both axes.

    :param degree_st: starting degree of the rim (whitespace included)
    :param degree_end: ending degree of the rim (whitespace included)
    :param rim_border: degrees of white spacing on each side of the rim
    :param radius: radius of the chord diagram
    :param color: rim color
    """
    axes = plt.gca()
    axes.set_aspect('equal', adjustable='box')

    view_limit = 0.65
    plt.xlim(-view_limit, view_limit)
    plt.ylim(-view_limit, view_limit)

    origin = (.0, .0)
    colored_wedge = patches.Wedge(origin, radius,
                                  degree_st + rim_border,
                                  degree_end - rim_border,
                                  color=color, alpha=1)
    inner_white_wedge = patches.Wedge(origin, radius - .02,
                                      degree_st, degree_end,
                                      color='white')
    axes.add_patch(colored_wedge)      # larger, colored wedge
    axes.add_patch(inner_white_wedge)  # smaller white wedge on top
405
+
406
+
407
+
408
def _group_edges_by_network_pair(edges, edge_weights, idx_to_label):
    """
    Collect edge weights per unordered network pair, keyed (larger label,
    smaller label) and kept in first-seen order. Self-loops (i == j) are
    skipped.
    """
    network_edges = OrderedDict()
    for (i, j), w in zip(edges, edge_weights):
        if i == j:
            continue  # an ROI connected to itself has no arc
        i_network = idx_to_label[i]
        j_network = idx_to_label[j]
        if i_network >= j_network:
            key = (i_network, j_network)
        elif j_network > i_network:
            key = (j_network, i_network)
        else:
            continue  # labels that cannot be ordered were skipped before too
        network_edges.setdefault(key, []).append(w)
    return network_edges


def _network_pair_thickness(network_edges, network_starts_ends, plot_abs_sum,
                            norm_thickness, max_linewidth):
    """
    Compute a thickness per network pair for the aggregated (plot_count /
    plot_abs_sum) mode.

    :return: (thickness, total) where thickness maps pair -> thickness and
        total is the sum of thicknesses *before* the max_linewidth rescaling
        (callers divide by it to get line widths).
    """
    if plot_abs_sum:
        thickness = {pair: sum(np.abs(l_w))
                     for pair, l_w in network_edges.items()}
    else:
        thickness = {pair: len(l_w) for pair, l_w in network_edges.items()}

    if norm_thickness:
        # Divide by the value expected if edges were spread uniformly over
        # the circle (proportional to the product of both rims' spans), then
        # min-max rescale.
        std = {}
        min_std = 1e10
        max_std = 0
        for pair, value in thickness.items():
            span0 = network_starts_ends[pair[0]][1] - \
                network_starts_ends[pair[0]][0]
            span1 = network_starts_ends[pair[1]][1] - \
                network_starts_ends[pair[1]][0]
            if pair[0] == pair[1]:
                expected = (span0 * span1) / (720 * 360)
            else:
                expected = (span0 * span1) / (360 * 360)
            std[pair] = value / expected * 25
            min_std = min(min_std, std[pair])
            max_std = max(max_std, std[pair])
        # NOTE: plot_arcs' sub_min_thickness flag is ignored; this offset was
        # always applied (the original test was `if True or ...`).
        min_std -= .1
        thickness = {pair: (value - min_std) / (max_std - min_std)
                     for pair, value in std.items()}

    total = sum(thickness.values())
    max_seen = 0
    for value in thickness.values():  # ensures arcs are never too thick
        max_seen = max(max_seen, value / total * 200)
    if max_seen > max_linewidth:
        for pair, value in thickness.items():
            thickness[pair] = value / max_seen * max_linewidth
    return thickness, total


def _network_pair_arc_degrees(edge, network_centers, network_starts_ends):
    """
    Return (deg0, deg1), the arc endpoints for an aggregated arc between the
    two networks in ``edge``. Within-network arcs are spread around the rim
    center; between-network endpoints are shifted along their rims depending
    on how far apart the two rims are.
    """
    deg0 = network_centers[edge[0]]
    deg1 = network_centers[edge[1]]
    span0 = network_starts_ends[edge[0]][1] - network_starts_ends[edge[0]][0]
    span1 = network_starts_ends[edge[1]][1] - network_starts_ends[edge[1]][0]

    if edge[0] == edge[1]:
        return deg0 - 0.15 * span0, deg1 + 0.15 * span1

    # Unwrap so the two centers are at most 180 degrees apart.
    deg0_, deg1_ = deg0, deg1
    if abs(deg0 + 360 - deg1) < abs(deg0 - deg1):
        deg1_ -= 360
    if abs(deg1 + 360 - deg0) < abs(deg0 - deg1):
        deg0_ -= 360
    assert abs(deg0_ - deg1_) <= 180, f'Bad: {deg0_} | {deg1_}'

    rel_dist = (deg0_ - deg1_) / 180
    if rel_dist < 0:
        rel_dist_ = -1 - rel_dist
    elif rel_dist > 0:
        rel_dist_ = 1 - rel_dist
    else:
        rel_dist_ = 0.0  # was left unbound before (UnboundLocalError)
    rel_dist_ = np.sign(rel_dist_) * (abs(rel_dist_) ** 2)
    return deg0 - rel_dist_ * span0 / 2, deg1 + rel_dist_ * span1 / 2


def _diverging_norm_bounds(lo, hi):
    """
    Return (vmin, vmax) for a TwoSlopeNorm centred on 0.

    TwoSlopeNorm requires vmin < 0 < vmax. When the weights straddle 0 the
    data range is used as-is; otherwise (all weights >= 0, all <= 0, or all
    equal) a range symmetric about 0 covering the data is used, falling back
    to (-1, 1) when every weight is 0.
    """
    if lo < 0 < hi:
        return lo, hi
    half_span = max(abs(lo), abs(hi))
    if half_span == 0:
        half_span = 1.0
    return -half_span, half_span


def plot_arcs(edges: list, idx_to_label: dict, network_low_high: dict,
              network_counts: dict,
              edge_weights: list,
              network_centers: dict,
              network_starts_ends: dict,
              radius: float,
              cmap: Union[None, str, matplotlib.colors.Colormap],
              coords: list,
              colors: Union[None, str, tuple, list],
              linewidths: Union[None, float, int, list],
              alphas: Union[None, float, int, list],
              seven_point_arc: bool = False,
              vmin: Union[float, int] = -1,
              vmax: Union[float, int] = 1,
              plot_count: bool = True,
              plot_abs_sum: bool = False,
              norm_thickness: bool = False,
              max_linewidth: Union[float, int] = 28,
              sub_min_thickness: bool = False,
              ) -> Tuple[float, float]:
    """
    Plot the arcs between ROIs (or, with plot_count/plot_abs_sum, one arc
    per network pair). Within each rim, ROIs are positioned by their y
    coordinate via ROI_to_degree. Arcs are drawn from weakest to strongest
    (by |weight|, or by pair thickness in aggregated mode), so strong arcs end
    up on top.

    :param edges: list of pairs (a, b) of ROI indices
    :param idx_to_label: dict mapping each ROI index to its chord label
    :param network_low_high: see plot_rim_and_labels(...)
    :param network_counts: see plot_rim_and_labels(...)
    :param edge_weights: list of edge weights, the nth weight for the nth
        edge. If None, random weights in [-1, 1) are used.
    :param network_centers: see plot_rim_and_labels(...)
    :param network_starts_ends: see plot_rim_and_labels(...)
    :param radius: radius of the outer part of the rim/circle
    :param cmap: accepted for compatibility; arc colors come from a seaborn
        diverging palette centred on 0 instead.
    :param coords: list of coordinates, the nth entry for ROI n
    :param colors: None (colors from the weights; black if all weights are
        1), a single color, or a list with one color per edge (index = the
        edge's position in ``edges``; in aggregated mode, the network pair's
        first-seen order).
    :param linewidths: same conventions as colors.
    :param alphas: same conventions as colors.
    :param seven_point_arc: changes the look of long arcs (forced to False
        in aggregated mode)
    :param vmin: returned unchanged unless recomputed in aggregated mode
    :param vmax: returned unchanged unless recomputed in aggregated mode
    :param plot_count: one arc per network pair, thickness from edge counts
    :param plot_abs_sum: one arc per network pair, thickness from the sum of
        |weight|
    :param norm_thickness: normalize aggregated thickness by rim sizes
    :param max_linewidth: cap used when rescaling aggregated thicknesses
    :param sub_min_thickness: currently ignored
    :return: (vmin, vmax); in aggregated mode, the min/max of the per-pair
        mean weights.
    """
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import TwoSlopeNorm

    edges = list(edges)
    roi2degree = ROI_to_degree(coords, idx_to_label, network_low_high,
                               network_counts)
    if edge_weights is None:
        edge_weights = np.random.random(len(edges)) * 2 - 1
    edge_weights = list(edge_weights)

    aggregate = plot_count or plot_abs_sum
    network2thickness = None
    total_thickness = None
    if aggregate:
        network_edges = _group_edges_by_network_pair(edges, edge_weights,
                                                     idx_to_label)
        if not network_edges:
            return vmin, vmax
        network2thickness, total_thickness = _network_pair_thickness(
            network_edges, network_starts_ends, plot_abs_sum,
            norm_thickness, max_linewidth)
        edge_weights = [np.mean(l_w) for l_w in network_edges.values()]
        vmin = min(edge_weights)
        vmax = max(edge_weights)
        edges = list(network_edges.keys())
        seven_point_arc = False

    if not edges:
        return vmin, vmax

    # Draw order: weakest first so the strongest arcs are drawn on top. The
    # original indices are kept so list-valued styles stay aligned per edge.
    if aggregate:
        draw_order = sorted(range(len(edges)),
                            key=lambda k: network2thickness[edges[k]])
    else:
        draw_order = sorted(range(len(edges)),
                            key=lambda k: abs(edge_weights[k]))

    unweighted = all(weight == 1 for weight in edge_weights)
    scalar_mappable = None
    if not unweighted:
        import seaborn as sns
        palette = sns.diverging_palette(220, 20, n=256, as_cmap=True)
        norm_lo, norm_hi = _diverging_norm_bounds(np.min(edge_weights),
                                                  np.max(edge_weights))
        scalar_mappable = ScalarMappable(
            norm=TwoSlopeNorm(vmin=norm_lo, vcenter=0, vmax=norm_hi),
            cmap=palette)

    for idx in draw_order:
        edge = edges[idx]
        edge_weight = edge_weights[idx]
        if edge_weight == 0:
            continue

        if aggregate:
            deg0, deg1 = _network_pair_arc_degrees(edge, network_centers,
                                                   network_starts_ends)
        else:
            deg0 = roi2degree.get_degree(edge[0])
            deg1 = roi2degree.get_degree(edge[1])

        color = linewidth = alpha = None
        if colors is None:
            color = 'k' if unweighted else scalar_mappable.to_rgba(edge_weight)
        if linewidths is None:
            if aggregate:
                linewidth = network2thickness[edge] / total_thickness * 250
            elif unweighted:
                linewidth = 1.7
            else:
                linewidth = 1.7 + abs(edge_weight)
        if alphas is None:
            alpha = 0.5 if unweighted else 0.7

        if color is None:
            color = colors[idx] if isinstance(colors, list) else colors
        if linewidth is None:
            linewidth = (linewidths[idx] if isinstance(linewidths, list)
                         else linewidths)
        if alpha is None:
            alpha = alphas[idx] if isinstance(alphas, list) else alphas

        if abs(deg0 - deg1) < .000000000001:
            continue  # ignores nodes' connections to themselves

        if aggregate:
            center0 = network_centers[edge[0]]
            center1 = network_centers[edge[1]]
        else:
            center0 = network_centers[idx_to_label[edge[0]]]
            center1 = network_centers[idx_to_label[edge[1]]]

        plot_arc(deg0, deg1,
                 center0, center1,
                 radius, color=color,
                 linewidth=linewidth, alpha=alpha,
                 seven_point_arc=seven_point_arc)
    return vmin, vmax
657
+
658
+
659
def _smooth_arc_points(control_points, radius):
    """
    Fit a smoothing cubic B-spline through ``control_points`` and sample it
    at 1000 points. Only samples strictly inside ``radius`` are kept.

    :return: (xs, ys) tuples, or None if no sample lies inside ``radius``.
    """
    x, y = zip(*control_points)
    # Scope the warning filter to this fit instead of mutating the global
    # filters for the whole process.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore",
                                message="divide by zero encountered in divide")
        tck, _ = interpolate.splprep([list(x), list(y)], k=3, s=.0002,
                                     quiet=True)
        xnew, ynew = interpolate.splev(np.linspace(0, 1, 1000), tck)
    inside = [(xv, yv) for xv, yv in zip(xnew, ynew)
              if cart_to_polar(xv, yv)[0] < radius]
    if not inside:
        return None
    xs, ys = zip(*inside)
    return xs, ys


def plot_arc(deg0: Union[int, float], deg1: Union[int, float],
             rim_center_deg0: Union[int, float],
             rim_center_deg1: Union[int, float],
             radius: float, alpha: float = 0.5,
             color: Union[str, tuple] = 'black', seven_point_arc: bool = False,
             linewidth: Union[float, int] = 1):
    """
    Draw the arc between two ROIs as a smoothed spline through 5 control
    points (both endpoints, two points just inside them, and a midpoint whose
    distance from the center shrinks as the ROIs get farther apart). With
    seven_point_arc and ROIs at least pi/2 apart, two extra points near the
    ROIs' rim centers are added (7 points).

    :param deg0: degree position of ROI 0
    :param deg1: degree position of ROI 1
    :param rim_center_deg0: degree of the center of ROI 0's rim
    :param rim_center_deg1: degree of the center of ROI 1's rim
    :param radius: radius of the outer part of the rim/circle; the arc starts
        0.02 inside it
    :param alpha: arc transparency
    :param color: arc color
    :param seven_point_arc: True/False changes the look of long arcs
    :param linewidth: width of the arc
    """
    theta0 = radians(deg0)
    theta1 = radians(deg1)
    rim_center_theta0 = radians(rim_center_deg0)
    rim_center_theta1 = radians(rim_center_deg1)

    radius = radius - 0.02
    if abs(deg0 - deg1) < .000001:
        return  # zero-length arc: nothing to draw

    start_cart = polar_to_cart(radius, theta0)
    end_cart = polar_to_cart(radius, theta1)
    start_inner_cart = polar_to_cart(radius - .02, theta0)
    end_inner_cart = polar_to_cart(radius - .02, theta1)

    # Angular separation of the two ROIs, folded into [0, pi].
    theta0_ = (theta0 + np.pi) % (2 * np.pi) - np.pi
    theta1_ = (theta1 + np.pi) % (2 * np.pi) - np.pi
    dif = abs(theta0_ - theta1_)
    dif = 2 * np.pi - dif if dif > np.pi else dif

    theta_mid = (theta0 + theta1) / 2
    if abs(theta1 - theta0) > np.pi:
        theta_mid += np.pi  # take the midpoint on the shorter side

    less_than_half_away = abs(theta0 - theta1) < np.pi / 2 or \
        abs(theta0 - theta1 - 2 * np.pi) < np.pi / 2 or \
        abs(theta0 - theta1 + 2 * np.pi) < np.pi / 2

    if not seven_point_arc or less_than_half_away:
        dif_radius = (radius - .02) * ((1 - (dif / np.pi)) ** 1.35)
        mid_cart = polar_to_cart(dif_radius, theta_mid)
        control_points = [start_cart, start_inner_cart, mid_cart,
                          end_inner_cart, end_cart]
    else:
        core_radius = radius * (0.5 + 0.5 * (1 - (dif / np.pi)) ** 0.5)
        start_network_cart = polar_to_cart(core_radius,
                                           (theta0 + rim_center_theta0) / 2)
        end_network_cart = polar_to_cart(core_radius,
                                         (theta1 + rim_center_theta1) / 2)
        dif_radius = radius * ((1 - (dif / np.pi)) ** 1.5)
        mid_cart = polar_to_cart(dif_radius, theta_mid)
        control_points = [start_cart, start_inner_cart, start_network_cart,
                          mid_cart, end_network_cart, end_inner_cart,
                          end_cart]

    curve = _smooth_arc_points(control_points, radius)
    if curve is None:
        return
    xnew, ynew = curve
    plt.gca().plot(xnew, ynew, alpha=alpha, linewidth=linewidth, color=color,
                   solid_joinstyle='bevel', solid_capstyle='round')
755
+
756
+
757
def plot_ROI_circles(coords: list, idx_to_label: dict, network_low_high: dict,
                     network_counts: dict,
                     network_colors: dict, radius: float,
                     exclusive_idx: Union[None, list, set] = None,
                     ROI_circle_radius=0.005) -> None:
    """
    Optional piece of the chord diagram: draw a small circle, in the ROI's
    network color, at each ROI's position on its rim (0.02 inside
    ``radius``).

    :param coords: list of coordinates, the nth entry for ROI n (see
        ROI_to_degree)
    :param idx_to_label: dict mapping each ROI index to its chord label
        (e.g., {0: "FPCN"})
    :param network_low_high: see plot_rim_and_labels(...)
    :param network_counts: see plot_rim_and_labels(...)
    :param network_colors: dict mapping each label to a matplotlib color
    :param radius: radius of the outer part of the rim/circle
    :param exclusive_idx: if given, only ROIs contained in it get a circle
    :param ROI_circle_radius: radius of each circle
    """
    ax = plt.gca()
    degrees = ROI_to_degree(coords, idx_to_label, network_low_high,
                            network_counts)
    wanted = (idx for idx in idx_to_label
              if exclusive_idx is None or idx in exclusive_idx)
    for idx in wanted:
        x, y = polar_to_cart(radius - 0.02, radians(degrees.get_degree(idx)))
        marker = plt.Circle((x, y), ROI_circle_radius,
                            color=network_colors[idx_to_label[idx]],
                            zorder=100000)
        ax.add_patch(marker)
787
+
788
+
789
def polar_to_cart(r: Union[float, int],
                  theta: Union[float, int]) -> Tuple[float, float]:
    """
    Convert the polar point (r, theta), with theta in radians, to Cartesian
    coordinates (x, y).
    """
    x = r * math.cos(theta)
    y = r * math.sin(theta)
    return x, y
795
+
796
+
797
def cart_to_polar(x: Union[float, int],
                  y: Union[float, int]) -> Tuple[float, float]:
    """
    Convert the Cartesian point (x, y) to polar coordinates (r, theta), with
    theta in radians as returned by numpy.angle (range (-pi, pi]).
    """
    as_complex = x + y * 1j
    magnitude = np.abs(as_complex)
    angle = np.angle(as_complex)
    return magnitude, angle
805
+
806
+
807
+ def _threshold_proc(edge_threshold, edge_weights, edges):
808
+ if edge_threshold == 0: return edge_weights, edges
809
+ edge_threshold = check_threshold(
810
+ edge_threshold,
811
+ np.abs(edge_weights),
812
+ scoreatpercentile,
813
+ "edge_threshold",
814
+ )
815
+
816
+ if edge_threshold > 0:
817
+ edges = [edge for edge, weight in zip(edges, edge_weights)
818
+ if abs(weight) > edge_threshold]
819
+ edge_weights = [weight for weight in edge_weights
820
+ if abs(weight) > edge_threshold]
821
+ return edge_weights, edges