GLDF 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4764 @@
1
+ """Tigramite plotting package."""
2
+
3
+ # Author: Jakob Runge <jakob@jakob-runge.com>
4
+ #
5
+ # License: GNU General Public License v3.0
6
+
7
+ import numpy as np
8
+ import json, warnings, os, pathlib
9
+ import matplotlib
10
+ import networkx as nx
11
+ from matplotlib.colors import ListedColormap
12
+ import matplotlib.transforms as transforms
13
+ from matplotlib import pyplot, ticker
14
+ from matplotlib.ticker import FormatStrFormatter
15
+ import matplotlib.patches as mpatches
16
+ from matplotlib.collections import PatchCollection
17
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
18
+ import sys
19
+ from operator import sub
20
+ import tigramite.data_processing as pp
21
+ from copy import deepcopy
22
+ import matplotlib.path as mpath
23
+ import matplotlib.patheffects as PathEffects
24
+ from mpl_toolkits.axisartist.axislines import Axes
25
+ import csv
26
+
27
+
28
+ # TODO: Add proper docstrings to internal functions...
29
+
30
+
31
+ def _par_corr_trafo(cmi):
32
+ """Transformation of CMI to partial correlation scale."""
33
+
34
+ # Set negative values to small positive number
35
+ # (zero would be interpreted as non-significant in some functions)
36
+ if np.ndim(cmi) == 0:
37
+ if cmi < 0.0:
38
+ cmi = 1e-8
39
+ else:
40
+ cmi[cmi < 0.0] = 1e-8
41
+
42
+ return np.sqrt(1.0 - np.exp(-2.0 * cmi))
43
+
44
+
45
+ def _par_corr_to_cmi(par_corr):
46
+ """Transformation of partial correlation to CMI scale."""
47
+
48
+ return -0.5 * np.log(1.0 - par_corr ** 2)
49
+
50
+
51
+ def _myround(x, base=5, round_mode="updown"):
52
+ """Rounds x to a float with precision base."""
53
+
54
+ if round_mode == "updown":
55
+ return base * round(float(x) / base)
56
+ elif round_mode == "down":
57
+ return base * np.floor(float(x) / base)
58
+ elif round_mode == "up":
59
+ return base * np.ceil(float(x) / base)
60
+
61
+ return base * round(float(x) / base)
62
+
63
+
64
+ def _make_nice_axes(ax, where=None, skip=1, color=None):
65
+ """Makes nice axes."""
66
+
67
+ if where is None:
68
+ where = ["left", "bottom"]
69
+ if color is None:
70
+ color = {"left": "black", "right": "black", "bottom": "black", "top": "black"}
71
+
72
+ if type(skip) == int:
73
+ skip_x = skip_y = skip
74
+ else:
75
+ skip_x = skip[0]
76
+ skip_y = skip[1]
77
+
78
+ for loc, spine in ax.spines.items():
79
+ if loc in where:
80
+ spine.set_position(("outward", 5)) # outward by 10 points
81
+ spine.set_color(color[loc])
82
+ if loc == "left" or loc == "right":
83
+ pyplot.setp(ax.get_yticklines(), color=color[loc])
84
+ pyplot.setp(ax.get_yticklabels(), color=color[loc])
85
+ if loc == "top" or loc == "bottom":
86
+ pyplot.setp(ax.get_xticklines(), color=color[loc])
87
+ elif loc in [
88
+ item for item in ["left", "bottom", "right", "top"] if item not in where
89
+ ]:
90
+ spine.set_color("none") # don't draw spine
91
+ else:
92
+ raise ValueError("unknown spine location: %s" % loc)
93
+
94
+ # ax.xaxis.get_major_formatter().set_useOffset(False)
95
+
96
+ # turn off ticks where there is no spine
97
+ if "top" in where and "bottom" not in where:
98
+ ax.xaxis.set_ticks_position("top")
99
+ if skip_x > 1:
100
+ ax.set_xticks(ax.get_xticks()[::skip_x])
101
+ elif "bottom" in where:
102
+ ax.xaxis.set_ticks_position("bottom")
103
+ if skip_x > 1:
104
+ ax.set_xticks(ax.get_xticks()[::skip_x])
105
+ else:
106
+ ax.xaxis.set_ticks_position("none")
107
+ ax.xaxis.set_ticklabels([])
108
+ if "right" in where and "left" not in where:
109
+ ax.yaxis.set_ticks_position("right")
110
+ if skip_y > 1:
111
+ ax.set_yticks(ax.get_yticks()[::skip_y])
112
+ elif "left" in where:
113
+ ax.yaxis.set_ticks_position("left")
114
+ if skip_y > 1:
115
+ ax.set_yticks(ax.get_yticks()[::skip_y])
116
+ else:
117
+ ax.yaxis.set_ticks_position("none")
118
+ ax.yaxis.set_ticklabels([])
119
+
120
+ ax.patch.set_alpha(0.0)
121
+
122
+
123
+ def _get_absmax(val_matrix):
124
+ """Get value at absolute maximum in lag function array.
125
+ For an (N, N, tau)-array this comutes the lag of the absolute maximum
126
+ along the tau-axis and stores the (positive or negative) value in
127
+ the (N,N)-array absmax."""
128
+
129
+ absmax_indices = np.abs(val_matrix).argmax(axis=2)
130
+ i, j = np.indices(val_matrix.shape[:2])
131
+
132
+ return val_matrix[i, j, absmax_indices]
133
+
134
+
135
def _add_timeseries(
    dataframe,
    fig_axes,
    grey_masked_samples=False,
    show_meanline=False,
    data_linewidth=1.0,
    color="black",
    alpha=1.,
    grey_alpha=1.0,
    selected_dataset=0,
    selected_variables=None,
):
    """Adds time series plots to a sequence of axes.

    For each component index ``j`` the column ``data[:, j]`` of the selected
    dataset is plotted against ``dataframe.datatime`` on ``axes[j]``. Values
    equal to ``dataframe.missing_flag`` are masked out (not drawn). If the
    dataframe has a mask, masked samples are hidden from the main line and
    can optionally be highlighted in grey.

    Parameters
    ----------
    dataframe : data object
        Tigramite dataframe; this function reads ``values``, ``mask``,
        ``missing_flag``, ``datatime``, ``N`` and ``vector_vars``.
    fig_axes : tuple
        ``(fig, axes)``; ``axes`` must be indexable by component number.
        ``fig`` itself is not used here.
    grey_masked_samples : {False, 'fill', 'data'}, optional (default: False)
        Only used when ``dataframe.mask`` is not None. 'fill' shades masked
        time steps in grey, 'data' draws the full series in grey underneath;
        any other value disables highlighting.
    show_meanline : bool
        Show mean of (unmasked) data as horizontal line.
    data_linewidth : float, optional (default: 1.)
        Linewidth (also used as marker size).
    color : str, optional (default: black)
        Line color.
    alpha : float
        Alpha opacity of the data line.
    grey_alpha : float, optional (default: 1.)
        Opacity of the grey highlighting.
    selected_dataset : int, optional (default: 0)
        In case of multiple datasets in dataframe, plot this one.
    selected_variables : list, optional (default: None)
        Variables whose components determine how many panels are drawn.
        If None, all ``dataframe.N`` variables are used.
    """
    fig, axes = fig_axes

    # Read in all attributes from dataframe
    data = dataframe.values[selected_dataset]
    if dataframe.mask is not None:
        mask = dataframe.mask[selected_dataset]
    else:
        mask = None

    missing_flag = dataframe.missing_flag
    time = dataframe.datatime[selected_dataset]
    T = len(time)

    if selected_variables is None:
        selected_variables = list(range(dataframe.N))

    # Total number of panels = summed number of components of the selected variables.
    nb_components = sum([len(dataframe.vector_vars[var]) for var in selected_variables])

    for j in range(nb_components):

        ax = axes[j]
        # NOTE(review): j counts panels, but is used directly as a column
        # index. If selected_variables is not a prefix of range(N), this plots
        # the first nb_components columns rather than the selected variables'
        # components — confirm against the vector_vars layout.
        dataseries = data[:, j]

        if missing_flag is not None:
            dataseries_nomissing = np.ma.masked_where(
                dataseries == missing_flag, dataseries
            )
        else:
            # All-zero (False) condition: nothing is masked; this only wraps
            # the series as a masked array so both branches share a type.
            dataseries_nomissing = np.ma.masked_where(
                np.zeros(dataseries.shape), dataseries
            )


        if mask is not None:
            maskseries = mask[:, j]

            # Hide masked samples in addition to missing values.
            maskdata = np.ma.masked_where(maskseries, dataseries_nomissing)

            if grey_masked_samples == "fill":
                # Grey band spanning the unmasked data range at masked time steps.
                ax.fill_between(
                    time,
                    maskdata.min(),
                    maskdata.max(),
                    where=maskseries,
                    color="grey",
                    interpolate=True,
                    linewidth=0.0,
                    alpha=grey_alpha,
                )
            elif grey_masked_samples == "data":
                # Full (non-missing) series in grey; the unmasked part is
                # drawn on top in `color` below.
                ax.plot(
                    time,
                    dataseries_nomissing,
                    color="grey",
                    marker=".",
                    markersize=data_linewidth,
                    linewidth=data_linewidth,
                    clip_on=False,
                    alpha=grey_alpha,
                )
            if show_meanline:
                ax.plot(time, maskdata.mean() * np.ones(T), lw=data_linewidth / 2., color=color)

            ax.plot(
                time,
                maskdata,
                color=color,
                linewidth=data_linewidth,
                marker=".",
                markersize=data_linewidth,
                clip_on=False,
                alpha=alpha,
            )
        else:
            if show_meanline:
                ax.plot(time, dataseries_nomissing.mean() * np.ones(T), lw=data_linewidth / 2., color=color)

            ax.plot(
                time,
                dataseries_nomissing,
                color=color,
                linewidth=data_linewidth,
                clip_on=False,
                alpha=alpha,
            )
260
+
261
+
262
def plot_timeseries(
    dataframe=None,
    save_name=None,
    fig_axes=None,
    figsize=None,
    var_units=None,
    time_label="",
    grey_masked_samples=False,
    show_meanline=False,
    data_linewidth=1.0,
    skip_ticks_data_x=1,
    skip_ticks_data_y=1,
    label_fontsize=10,
    color='black',
    alpha=1.,
    tick_label_size=6,
    selected_dataset=0,
    adjust_plot=True,
    selected_variables=None,
):
    """Create and save figure of stacked panels with time series.

    One panel is drawn per component of the selected variables. The panels
    are stacked vertically and share the x-axis.

    Parameters
    ----------
    dataframe : data object
        This is the Tigramite dataframe object (required despite the None
        default). Reads var_names, datatime, N, vector_vars, values, mask
        and missing_flag.
    save_name : str, optional (default: None)
        Name of figure file to save figure. If None, the figure is not saved
        (it is only returned).
    fig_axes : tuple, optional (default: None)
        (fig, axes) to draw into. If None they are created as
        fig, axes = pyplot.subplots(nb_components, sharex=True, ...).
        ``axes`` may be a single Axes or a 1-D sequence of axes.
    figsize : tuple of floats, optional (default: None)
        Figure size if new figure is created. If None, default pyplot figsize
        is used.
    var_units : list of str, optional (default: None)
        Units of variables, shown in brackets after the variable name.
    time_label : str, optional (default: '')
        Label of time axis.
    grey_masked_samples : {False, 'fill', 'data'}, optional (default: False)
        Whether to mark masked samples by grey fills ('fill') or grey data
        ('data').
    show_meanline : bool, optional (default: False)
        Whether to plot a horizontal line at the mean.
    data_linewidth : float, optional (default: 1.)
        Linewidth.
    skip_ticks_data_x : int, optional (default: 1)
        Keep only every skip_ticks_data_x-th x tick.
    skip_ticks_data_y : int, optional (default: 1)
        Keep only every skip_ticks_data_y-th y tick.
    label_fontsize : int, optional (default: 10)
        Fontsize of variable labels.
    color : str, optional (default: black)
        Line color.
    alpha : float
        Alpha opacity.
    tick_label_size : int, optional (default: 6)
        Fontsize of tick labels.
    selected_dataset : int, optional (default: 0)
        In case of multiple datasets in dataframe, plot this one.
    adjust_plot : bool, optional (default: True)
        Whether to style the axes, set labels and adjust the layout.
    selected_variables : list, optional (default: None)
        List of variables which to plot.

    Returns
    -------
    fig, axes
        The figure and the axes object(s) that were drawn into.
    """

    var_names = dataframe.var_names
    time = dataframe.datatime[selected_dataset]

    N = dataframe.N

    if selected_variables is None:
        selected_variables = list(range(N))

    nb_components_per_var = [len(dataframe.vector_vars[var]) for var in selected_variables]
    # Panel index at which each selected variable's first component starts.
    N_index = [sum(nb_components_per_var[:i]) for i in range(len(nb_components_per_var))]
    nb_components = sum(nb_components_per_var)

    if var_units is None:
        var_units = ["" for _ in range(N)]

    if fig_axes is None:
        fig, axes = pyplot.subplots(nb_components, sharex=True, figsize=figsize)
    else:
        fig, axes = fig_axes

    # pyplot.subplots returns a bare Axes (not an array) for a single panel,
    # which made axes[i] fail; index through a 1-D view instead.
    axes_seq = np.atleast_1d(axes)

    if adjust_plot:
        for i in range(nb_components):

            ax = axes_seq[i]

            if i == nb_components - 1:
                _make_nice_axes(
                    ax, where=["left", "bottom"], skip=(skip_ticks_data_x, skip_ticks_data_y)
                )
                ax.set_xlabel(r"%s" % time_label, fontsize=label_fontsize)
            else:
                _make_nice_axes(ax, where=["left"], skip=(skip_ticks_data_x, skip_ticks_data_y))

            ax.xaxis.set_major_formatter(FormatStrFormatter("%.0f"))
            ax.label_outer()

            ax.set_xlim(time[0], time[-1])

            # Label only the first panel of each variable.
            if i in N_index:
                # NOTE(review): k is the position within selected_variables but
                # indexes var_names/var_units directly; this matches the
                # positional column choice in _add_timeseries — confirm intent
                # for non-prefix selections.
                k = N_index.index(i)
                if var_units[k]:
                    ax.set_ylabel(r"%s [%s]" % (var_names[k], var_units[k]),
                                  fontsize=label_fontsize)
                else:
                    ax.set_ylabel(r"%s" % (var_names[k]), fontsize=label_fontsize)

            ax.tick_params(axis='both', which='major', labelsize=tick_label_size)

    _add_timeseries(
        dataframe=dataframe,
        fig_axes=(fig, axes_seq),
        grey_masked_samples=grey_masked_samples,
        show_meanline=show_meanline,
        data_linewidth=data_linewidth,
        color=color,
        selected_dataset=selected_dataset,
        alpha=alpha,
        selected_variables=selected_variables
    )

    if adjust_plot:
        fig.subplots_adjust(bottom=0.15, top=0.9, left=0.15, right=0.95, hspace=0.3)
        # Target `fig` explicitly: pyplot.tight_layout() acts on the *current*
        # figure, which need not be `fig` when fig_axes is supplied.
        fig.tight_layout()

    if save_name is not None:
        fig.savefig(save_name)

    return fig, axes
397
+
398
+
399
def plot_lagfuncs(val_matrix,
                  name=None,
                  setup_args=None,
                  add_lagfunc_args=None):
    """Wrapper helper function to plot lag functions.

    Sets up the matrix object and plots the lag function, see parameters in
    setup_matrix and add_lagfuncs. The figure is then saved (or shown) via
    setup_matrix.savefig.

    Parameters
    ----------
    val_matrix : array_like
        Matrix of shape (N, N, tau_max+1) containing test statistic values.
    name : str, optional (default: None)
        File name. If None, figure is shown in window.
    setup_args : dict, optional (default: None)
        Keyword arguments for setting up the lag function matrix, see doc of
        setup_matrix. None means no extra arguments.
    add_lagfunc_args : dict, optional (default: None)
        Keyword arguments for adding a lag function matrix, see doc of
        add_lagfuncs. None means no extra arguments.

    Returns
    -------
    matrix : object
        Further lag functions can be overlaid using the
        matrix.add_lagfuncs(val_matrix) function.

    Raises
    ------
    ValueError
        If val_matrix is not of shape (N, N, tau_max+1).
    """
    # None defaults instead of shared mutable {} defaults.
    setup_args = {} if setup_args is None else setup_args
    add_lagfunc_args = {} if add_lagfunc_args is None else add_lagfunc_args

    # asanyarray keeps masked arrays (and their masks) intact.
    val_matrix = np.asanyarray(val_matrix)
    if val_matrix.ndim != 3 or val_matrix.shape[0] != val_matrix.shape[1]:
        raise ValueError(
            "val_matrix must have shape (N, N, tau_max+1), got %s" % (val_matrix.shape,)
        )

    N, _, tau_max_plusone = val_matrix.shape
    tau_max = tau_max_plusone - 1

    matrix = setup_matrix(N=N, tau_max=tau_max, **setup_args)
    matrix.add_lagfuncs(val_matrix=val_matrix, **add_lagfunc_args)
    matrix.savefig(name=name)

    return matrix
434
+
435
+
436
class setup_matrix:
    """Create matrix of lag function panels.

    Class to set up a figure object with an N x N grid of panels, panel
    (i, j) showing the lag function of variable i -> variable j. The method
    add_lagfuncs(...) plots a val_matrix of shape (N, N, tau_max+1). Multiple
    lag functions can be overlaid for comparison.

    Parameters
    ----------
    N : int
        Number of variables.
    tau_max : int
        Maximum time lag.
    var_names : list, optional (default: None)
        List of variable names. If None, range(N) is used.
    figsize : tuple of floats, optional (default: None)
        Figure size if new figure is created. If None, default pyplot figsize
        is used.
    minimum : float, optional (default: -1.)
        Lower y-axis limit (rounded down to a multiple of y_base if given).
    maximum : float, optional (default: 1.)
        Upper y-axis limit (rounded up to a multiple of y_base if given).
    label_space_left : float, optional (default: 0.1)
        Fraction of horizontal figure space to allocate left of plot for labels.
    label_space_top : float, optional (default: 0.05)
        Fraction of vertical figure space to allocate top of plot for labels.
    legend_width : float, optional (default: 0.15)
        Fraction of horizontal figure space to allocate right of plot for
        legend.
    legend_fontsize : int, optional (default: 10)
        Fontsize of the legend.
    x_base : float, optional (default: 1.)
        x-tick intervals to show. Must be a whole number when lag_array is
        given.
    y_base : float, optional (default: .5)
        y-tick intervals to show. If None, no y-tick locators are set and
        the y-limits are minimum and maximum as given.
    tick_label_size : int, optional (default: 6)
        Fontsize of tick labels.
    plot_gridlines : bool, optional (default: False)
        Whether to show a grid.
    lag_units : str, optional (default: '')
        Units shown in the lag-axis label.
    lag_array : array, optional (default: None)
        Optional tick labels replacing np.arange(0, tau_max+1); must have
        shape (tau_max+1,).
    label_fontsize : int, optional (default: 10)
        Fontsize of variable labels.
    """

    def __init__(
        self,
        N,
        tau_max,
        var_names=None,
        figsize=None,
        minimum=-1,
        maximum=1,
        label_space_left=0.1,
        label_space_top=0.05,
        legend_width=0.15,
        legend_fontsize=10,
        x_base=1.0,
        y_base=0.5,
        tick_label_size=6,
        plot_gridlines=False,
        lag_units="",
        lag_array=None,
        label_fontsize=10,
    ):

        self.tau_max = tau_max

        self.labels = []
        self.lag_units = lag_units
        self.lag_array = lag_array
        self.x_base = 1 if x_base is None else x_base

        self.legend_width = legend_width
        self.legend_fontsize = legend_fontsize

        self.label_space_left = label_space_left
        self.label_space_top = label_space_top
        self.label_fontsize = label_fontsize

        self.fig = pyplot.figure(figsize=figsize)

        self.axes_dict = {}

        if var_names is None:
            var_names = range(N)

        # Tick positions and y-limits are identical for every panel: compute
        # them once. (Locator objects themselves are created per axis below.)
        x_major = x_minor = None
        if x_base is not None:
            x_major = np.arange(0, self.tau_max + 1, x_base)
            # Half-step minor ticks only when x_base / 2 is a whole number.
            if x_base / 2.0 % 1 == 0:
                x_minor = np.arange(0, self.tau_max + 1, x_base / 2.0)

        if y_base is not None:
            y_low = _myround(minimum, y_base, "down")
            y_high = _myround(maximum, y_base, "up")
            y_major = np.arange(y_low, y_high + y_base, y_base)
            y_minor = np.arange(y_low, y_high + y_base, y_base / 2.0)
        else:
            # Previously _myround(..., None, ...) raised TypeError here.
            y_low, y_high = minimum, maximum

        plot_index = 1
        for i in range(N):
            for j in range(N):
                ax = self.fig.add_subplot(N, N, plot_index)
                self.axes_dict[(i, j)] = ax
                # Plot process labels
                if j == 0:
                    # Row label: x in figure coords, y in axes coords.
                    trans = transforms.blended_transform_factory(
                        self.fig.transFigure, ax.transAxes
                    )
                    ax.text(
                        0.01,
                        0.5,
                        "%s" % str(var_names[i]),
                        fontsize=label_fontsize,
                        horizontalalignment="left",
                        verticalalignment="center",
                        transform=trans,
                    )
                if i == 0:
                    # Column label: x in axes coords, y in figure coords.
                    trans = transforms.blended_transform_factory(
                        ax.transAxes, self.fig.transFigure
                    )
                    ax.text(
                        0.5,
                        0.99,
                        r"${\to}$ " + "%s" % str(var_names[j]),
                        fontsize=label_fontsize,
                        horizontalalignment="center",
                        verticalalignment="top",
                        transform=trans,
                    )

                # Make nice axis
                _make_nice_axes(ax, where=["left", "bottom"], skip=(1, 1))
                if x_base is not None:
                    ax.xaxis.set_major_locator(ticker.FixedLocator(x_major))
                    if x_minor is not None:
                        ax.xaxis.set_minor_locator(ticker.FixedLocator(x_minor))
                if y_base is not None:
                    ax.yaxis.set_major_locator(ticker.FixedLocator(y_major))
                    ax.yaxis.set_minor_locator(ticker.FixedLocator(y_minor))

                ax.set_ylim(y_low, y_high)
                if j != 0:
                    ax.get_yaxis().set_ticklabels([])
                ax.set_xlim(0, self.tau_max)
                if plot_gridlines:
                    ax.grid(
                        True,
                        which="major",
                        color="black",
                        linestyle="dotted",
                        dashes=(1, 1),
                        linewidth=0.05,
                        zorder=-5,
                    )
                ax.tick_params(axis='both', which='major', labelsize=tick_label_size)
                ax.tick_params(axis='both', which='minor', labelsize=tick_label_size)

                plot_index += 1

    def add_lagfuncs(
        self,
        val_matrix,
        sig_thres=None,
        conf_matrix=None,
        color="black",
        label=None,
        two_sided_thres=True,
        marker=".",
        markersize=5,
        alpha=1.0,
    ):
        """Add lag function plot from val_matrix array.

        Lag zero is omitted on diagonal panels (i == j).

        Parameters
        ----------
        val_matrix : array_like
            Matrix of shape (N, N, tau_max+1) containing test statistic values.
        sig_thres : array-like, optional (default: None)
            Matrix of significance thresholds. Must be of same shape as
            val_matrix.
        conf_matrix : array-like, optional (default: None)
            Matrix of shape (N, N, tau_max+1, 2) containing lower and upper
            confidence bounds.
        color : str, optional (default: 'black')
            Line color.
        label : str
            Test statistic label, shown in the legend drawn by savefig.
        two_sided_thres : bool, optional (default: True)
            Whether to draw sig_thres for pos. and neg. values.
        marker : matplotlib marker symbol, optional (default: '.')
            Marker.
        markersize : int, optional (default: 5)
            Marker size.
        alpha : float, optional (default: 1.)
            Opacity.
        """

        if label is not None:
            self.labels.append((label, color, marker, markersize, alpha))

        for (i, j), ax in self.axes_dict.items():
            # Diagonal panels start at lag 1 (skip a variable's lag-0 self link).
            start = int(i == j)
            lags = range(start, self.tau_max + 1)

            ax.plot(
                lags,
                np.copy(val_matrix[i, j, start:]),
                linestyle="",
                color=color,
                marker=marker,
                markersize=markersize,
                alpha=alpha,
                clip_on=False,
            )
            if conf_matrix is not None:
                conf = np.copy(conf_matrix[i, j, start:])
                # Lower bound (index 0) then upper bound (index 1).
                for bound in (0, 1):
                    ax.plot(
                        lags,
                        conf[:, bound],
                        linestyle="",
                        color=color,
                        marker="_",
                        markersize=markersize - 2,
                        alpha=alpha,
                        clip_on=False,
                    )

            # Dotted zero reference line.
            ax.plot(
                lags,
                np.zeros(self.tau_max + 1 - start),
                color="black",
                linestyle="dotted",
                linewidth=0.1,
            )

            if sig_thres is not None:
                ax.plot(
                    lags,
                    sig_thres[i, j, start:],
                    color=color,
                    linestyle="solid",
                    linewidth=0.1,
                    alpha=alpha,
                )
                if two_sided_thres:
                    ax.plot(
                        lags,
                        -sig_thres[i, j, start:],
                        color=color,
                        linestyle="solid",
                        linewidth=0.1,
                        alpha=alpha,
                    )

    def savefig(self, name=None):
        """Finalise the layout and save or show the matrix figure.

        Draws a shared legend (if labels were registered via add_lagfuncs),
        adjusts the subplot spacing, adds the lag-axis label and, if
        lag_array was given, replaces the x tick labels.

        Parameters
        ----------
        name : str, optional (default: None)
            File name. If None, figure is shown in window.

        Raises
        ------
        ValueError
            If lag_array does not have shape (tau_max+1,), or x_base is not a
            whole number while lag_array is given.
        """
        lag_label = r"lag $\tau$ [%s]" % self.lag_units

        if len(self.labels) > 0:
            # Trick to plot legend: an invisible full-figure axis that only
            # carries empty proxy lines, one per registered label.
            axlegend = self.fig.add_subplot(111, frameon=False)
            for side in ("left", "right", "bottom", "top"):
                axlegend.spines[side].set_color("none")
            axlegend.set_xticks([])
            axlegend.set_yticks([])

            for label, color, marker, markersize, alpha in self.labels:
                axlegend.plot(
                    [],
                    [],
                    linestyle="",
                    color=color,
                    marker=marker,
                    markersize=markersize,
                    label=label,
                    alpha=alpha,
                )
            axlegend.legend(
                loc="upper left",
                ncol=1,
                bbox_to_anchor=(1.05, 0.0, 0.1, 1.0),
                borderaxespad=0,
                fontsize=self.legend_fontsize,
            ).draw_frame(False)

            right = 1.0 - self.legend_width
            label_x = 0.5
        else:
            right = 0.95
            label_x = 0.55

        self.fig.subplots_adjust(
            left=self.label_space_left,
            right=right,
            top=1.0 - self.label_space_top,
            hspace=0.35,
            wspace=0.35,
        )
        # Draw on self.fig explicitly: pyplot.figtext targets the *current*
        # figure, which may differ if another figure was created meanwhile.
        self.fig.text(
            label_x,
            0.01,
            lag_label,
            horizontalalignment="center",
            fontsize=self.label_fontsize,
        )

        if self.lag_array is not None:
            lag_array = np.asarray(self.lag_array)
            if lag_array.shape != (self.tau_max + 1,):
                raise ValueError(
                    "lag_array must have shape (tau_max + 1,) = (%d,), got %s"
                    % (self.tau_max + 1, lag_array.shape)
                )
            # Slicing needs an int step; the default x_base=1.0 is a float and
            # previously raised TypeError here.
            if float(self.x_base) != int(self.x_base):
                raise ValueError("x_base must be a whole number when lag_array is given")
            step = int(self.x_base)
            for ax in self.axes_dict.values():
                ax.set_xticklabels(lag_array[::step])

        if name is not None:
            self.fig.savefig(name)
        else:
            pyplot.show()
808
+
809
+
810
+
811
def plot_scatterplots(dataframe,
                      name=None,
                      setup_args=None,
                      add_scatterplot_args=None,
                      selected_dataset=0):
    """Wrapper helper function to plot scatter plots.

    Sets up the matrix object and plots the scatter plots, see parameters in
    setup_scatter_matrix and add_scatterplot. Finally calls
    setup_scatter_matrix.adjustfig.

    Parameters
    ----------
    dataframe : data object
        Tigramite dataframe object. It must have the attributes dataframe.values
        yielding a numpy array of shape (observations T, variables N) and
        optionally a mask of the same shape and a missing values flag.
    name : str, optional (default: None)
        File name. If None, figure is shown in window.
    setup_args : dict, optional (default: None)
        Keyword arguments for setting up the scatter plot matrix, see doc of
        setup_scatter_matrix. None means no extra arguments.
    add_scatterplot_args : dict, optional (default: None)
        Keyword arguments for adding a scatter plot matrix. None means no
        extra arguments.
    selected_dataset : int, optional (default: 0)
        In case of multiple datasets in dataframe, plot this one.

    Returns
    -------
    matrix : object
        Further scatter plot can be overlaid using the
        matrix.add_scatterplot function.
    """
    # None defaults instead of shared mutable {} defaults.
    setup_args = {} if setup_args is None else setup_args
    add_scatterplot_args = {} if add_scatterplot_args is None else add_scatterplot_args

    N = dataframe.N

    matrix = setup_scatter_matrix(N=N, var_names=dataframe.var_names, **setup_args)
    matrix.add_scatterplot(dataframe=dataframe, selected_dataset=selected_dataset, **add_scatterplot_args)
    matrix.adjustfig(name=name)

    return matrix
851
+
852
+
853
+ class setup_scatter_matrix:
854
+ """Create matrix of scatter plot panels.
855
+ Class to setup figure object. The function add_scatterplot allows to plot
856
+ scatterplots of variables in the dataframe. Multiple scatter plots can be
857
+ overlaid for comparison.
858
+
859
+ Parameters
860
+ ----------
861
+ N : int
862
+ Number of variables
863
+ var_names : list, optional (default: None)
864
+ List of variable names. If None, range(N) is used.
865
+ figsize : tuple of floats, optional (default: None)
866
+ Figure size if new figure is created. If None, default pyplot figsize
867
+ is used.
868
+ label_space_left : float, optional (default: 0.1)
869
+ Fraction of horizontal figure space to allocate left of plot for labels.
870
+ label_space_top : float, optional (default: 0.05)
871
+ Fraction of vertical figure space to allocate top of plot for labels.
872
+ legend_width : float, optional (default: 0.15)
873
+ Fraction of horizontal figure space to allocate right of plot for
874
+ legend.
875
+ tick_label_size : int, optional (default: 6)
876
+ Fontsize of tick labels.
877
+ plot_gridlines : bool, optional (default: False)
878
+ Whether to show a grid.
879
+ label_fontsize : int, optional (default: 10)
880
+ Fontsize of variable labels.
881
+ """
882
+
883
+ def __init__(
884
+ self,
885
+ N,
886
+ var_names=None,
887
+ figsize=None,
888
+ label_space_left=0.1,
889
+ label_space_top=0.05,
890
+ legend_width=0.15,
891
+ legend_fontsize=10,
892
+ plot_gridlines=False,
893
+ tick_label_size=6,
894
+ label_fontsize=10,
895
+ ):
896
+
897
+ self.labels = []
898
+
899
+ self.legend_width = legend_width
900
+ self.legend_fontsize = legend_fontsize
901
+
902
+ self.label_space_left = label_space_left
903
+ self.label_space_top = label_space_top
904
+ self.label_fontsize = label_fontsize
905
+
906
+ self.fig = pyplot.figure(figsize=figsize)
907
+
908
+ self.axes_dict = {}
909
+
910
+ if var_names is None:
911
+ var_names = range(N)
912
+
913
+ plot_index = 1
914
+ for i in range(N):
915
+ for j in range(N):
916
+ self.axes_dict[(i, j)] = self.fig.add_subplot(N, N, plot_index, axes_class=Axes)
917
+ # Plot process labels
918
+ if j == 0:
919
+ trans = transforms.blended_transform_factory(
920
+ self.fig.transFigure, self.axes_dict[(i, j)].transAxes
921
+ )
922
+ self.axes_dict[(i, j)].text(
923
+ 0.01,
924
+ 0.5,
925
+ "%s" % str(var_names[i]),
926
+ fontsize=label_fontsize,
927
+ horizontalalignment="left",
928
+ verticalalignment="center",
929
+ transform=trans,
930
+ )
931
+ if i == 0:
932
+ trans = transforms.blended_transform_factory(
933
+ self.axes_dict[(i, j)].transAxes, self.fig.transFigure
934
+ )
935
+ self.axes_dict[(i, j)].text(
936
+ 0.5,
937
+ 0.99,
938
+ r"${\to}$ " + "%s" % str(var_names[j]),
939
+ fontsize=label_fontsize,
940
+ horizontalalignment="center",
941
+ verticalalignment="top",
942
+ transform=trans,
943
+ )
944
+
945
+ self.axes_dict[(i, j)].axis["right"].set_visible(False)
946
+ self.axes_dict[(i, j)].axis["top"].set_visible(False)
947
+
948
+ if j != 0:
949
+ self.axes_dict[(i, j)].get_yaxis().set_ticklabels([])
950
+ if i != N - 1:
951
+ self.axes_dict[(i, j)].get_xaxis().set_ticklabels([])
952
+
953
+ if plot_gridlines:
954
+ self.axes_dict[(i, j)].grid(
955
+ True,
956
+ which="major",
957
+ color="black",
958
+ linestyle="dotted",
959
+ dashes=(1, 1),
960
+ linewidth=0.05,
961
+ zorder=-5,
962
+ )
963
+ self.axes_dict[(i, j)].tick_params(axis='both', which='major', labelsize=tick_label_size)
964
+
965
+ plot_index += 1
966
+
967
+ def add_scatterplot(
968
+ self,
969
+ dataframe,
970
+ matrix_lags=None,
971
+ color="black",
972
+ label=None,
973
+ marker=".",
974
+ markersize=5,
975
+ alpha=.2,
976
+ selected_dataset=0,
977
+ ):
978
+ """Add scatter plot.
979
+
980
+ Parameters
981
+ ----------
982
+ dataframe : data object
983
+ Tigramite dataframe object. It must have the attributes dataframe.values
984
+ yielding a numpy array of shape (observations T, variables N) and
985
+ optionally a mask of the same shape and a missing values flag.
986
+ matrix_lags : array
987
+ Lags to use in scatter plots. Either None or of shape (N, N). Then the
988
+ entry matrix_lags[i, j] = tau will depict the scatter plot of
989
+ time series (i, -tau) vs (j, 0). If None, tau = 0 for i != j and for i = j
990
+ tau = 1.
991
+ color : str, optional (default: 'black')
992
+ Line color.
993
+ label : str
994
+ Test statistic label.
995
+ marker : matplotlib marker symbol, optional (default: '.')
996
+ Marker.
997
+ markersize : int, optional (default: 5)
998
+ Marker size.
999
+ alpha : float, optional (default: 1.)
1000
+ Opacity.
1001
+ selected_dataset : int, optional (default: 0)
1002
+ In case of multiple datasets in dataframe, plot this one.
1003
+ """
1004
+
1005
+ if matrix_lags is not None and np.any(matrix_lags < 0):
1006
+ raise ValueError("matrix_lags must be non-negative!")
1007
+
1008
+ data = dataframe.values[selected_dataset]
1009
+ if dataframe.mask is not None:
1010
+ mask = dataframe.mask[selected_dataset]
1011
+
1012
+ T, dim = data.shape
1013
+
1014
+ if label is not None:
1015
+ self.labels.append((label, color, marker, markersize, alpha))
1016
+
1017
+ for ij in list(self.axes_dict):
1018
+ i = ij[0]
1019
+ j = ij[1]
1020
+ if matrix_lags is None:
1021
+ if i == j:
1022
+ lag = 1
1023
+ else:
1024
+ lag = 0
1025
+ else:
1026
+ lag = matrix_lags[i,j]
1027
+ x = np.copy(data[:T-lag, i])
1028
+ y = np.copy(data[lag:, j])
1029
+ if dataframe.mask is not None:
1030
+ x[mask[:T-lag, i]==1] = np.nan
1031
+ y[mask[lag:, j]==1] = np.nan
1032
+
1033
+ # print(i, j, lag, x.shape, y.shape)
1034
+ self.axes_dict[(i, j)].scatter(
1035
+ y, x, # NEW: inverted to match rows and columns!
1036
+ color=color,
1037
+ marker=marker,
1038
+ s=markersize,
1039
+ alpha=alpha,
1040
+ clip_on=False,
1041
+ label=r"$\tau{=}%d$" %lag,
1042
+ )
1043
+ # self.axes_dict[(i, j)].text(0., 1., r"$\tau{=}%d$" %lag,
1044
+ # fontsize=self.legend_fontsize,
1045
+ # ha='left', va='top',
1046
+ # transform=self.axes_dict[(i, j)].transAxes)
1047
+
1048
+
1049
def adjustfig(self, name=None):
    """Finalize the scatter-plot matrix figure, then save or show it.

    Every panel gets a frameless legend whose entry texts are recoloured with
    the colours registered in ``self.labels``. When at least one labelled
    scatter layer was registered, an invisible full-figure axes is added to
    host a shared legend to the right of the panels, and the subplot area is
    shrunk by ``self.legend_width`` to make room for it.

    Parameters
    ----------
    name : str, optional (default: None)
        File name. If None, figure is shown in window.
    """
    # Entries of self.labels are (label, color, marker, markersize, alpha).
    label_colors = [entry[1] for entry in self.labels]

    # Per-panel legends ("trick to plot legends").
    for panel in self.axes_dict.values():
        panel_legend = panel.legend(
            ncol=1,
            fontsize=self.legend_fontsize - 2,
            labelcolor=label_colors,
        )
        panel_legend.set_frame_on(False)

    if not self.labels:
        self.fig.subplots_adjust(
            left=self.label_space_left,
            bottom=0.05,
            right=0.95,
            top=1.0 - self.label_space_top,
            hspace=0.35,
            wspace=0.35,
        )
    else:
        # Invisible axes spanning the whole figure, used only as a legend host.
        legend_host = self.fig.add_subplot(111, frameon=False)
        for side in ("left", "right", "bottom", "top"):
            legend_host.spines[side].set_color("none")
        legend_host.set_xticks([])
        legend_host.set_yticks([])

        # Empty proxy artists carry the style of each registered layer.
        for entry in self.labels:
            label, color, marker, markersize, alpha = entry[:5]
            legend_host.plot(
                [],
                [],
                linestyle="",
                color=color,
                marker=marker,
                markersize=markersize,
                label=label,
                alpha=alpha,
            )

        shared_legend = legend_host.legend(
            loc="upper left",
            ncol=1,
            bbox_to_anchor=(1.05, 0.0, 0.1, 1.0),
            borderaxespad=0,
            fontsize=self.legend_fontsize,
        )
        shared_legend.set_frame_on(False)

        self.fig.subplots_adjust(
            left=self.label_space_left,
            bottom=0.05,
            right=1.0 - self.legend_width,
            top=1.0 - self.label_space_top,
            hspace=0.5,
            wspace=0.35,
        )

    if name is None:
        pyplot.show()
    else:
        self.fig.savefig(name)
1133
+
1134
+
1135
def plot_densityplots(dataframe,
                      name=None,
                      setup_args=None,
                      add_densityplot_args=None,
                      selected_dataset=0,
                      show_marginal_densities_on_diagonal=True):
    """Wrapper helper function to plot density plots.

    Sets up the matrix object and plots the density plots, see parameters in
    setup_density_matrix and add_densityplot. The figure is then finalized
    with matrix.adjustfig (saved to ``name`` or shown in a window).

    By default the diagonal shows the marginal densities.

    Requires seaborn.

    Parameters
    ----------
    dataframe : data object
        Tigramite dataframe object. It must have the attributes dataframe.values
        yielding a numpy array of shape (observations T, variables N) and
        optionally a mask of the same shape and a missing values flag.
    name : str, optional (default: None)
        File name. If None, figure is shown in window.
    setup_args : dict, optional (default: None)
        Arguments for setting up the density plot matrix, see doc of
        setup_density_matrix. A ``var_names`` entry overrides
        dataframe.var_names. None means no extra arguments.
    add_densityplot_args : dict, optional (default: None)
        Arguments for adding a density plot matrix, see doc of
        setup_density_matrix.add_densityplot. None means no extra arguments.
    selected_dataset : int, optional (default: 0)
        In case of multiple datasets in dataframe, plot this one.
    show_marginal_densities_on_diagonal : bool, optional (default: True)
        Flag to show marginal densities on the diagonal of the density plots.

    Returns
    -------
    matrix : object
        Further density plots can be overlaid using the
        matrix.add_densityplot function.
    """
    # None sentinels instead of shared mutable ``{}`` defaults.
    if setup_args is None:
        setup_args = {}
    if add_densityplot_args is None:
        add_densityplot_args = {}

    # Variable names default to the dataframe's, but an explicit entry in
    # setup_args wins instead of raising a duplicate-keyword TypeError.
    setup_kwargs = {"var_names": dataframe.var_names, **setup_args}

    matrix = setup_density_matrix(N=dataframe.N, **setup_kwargs)
    matrix.add_densityplot(
        dataframe=dataframe,
        selected_dataset=selected_dataset,
        show_marginal_densities_on_diagonal=show_marginal_densities_on_diagonal,
        **add_densityplot_args,
    )
    matrix.adjustfig(name=name)

    return matrix
1183
+
1184
+
1185
class setup_density_matrix:
    """Create matrix of density plot panels.

    Class to set up a figure object holding an N x N grid of axes. The method
    add_densityplot plots density estimates of (lagged) variable pairs of a
    dataframe into the panels.

    Further density plots can be overlaid by calling add_densityplot again
    before adjustfig.

    Parameters
    ----------
    N : int
        Number of variables.
    var_names : list, optional (default: None)
        List of variable names. If None, range(N) is used.
    figsize : tuple of floats, optional (default: None)
        Figure size if new figure is created. If None, default pyplot figsize
        is used.
    label_space_left : float, optional (default: 0.15)
        Fraction of horizontal figure space to allocate left of plot for labels.
    label_space_top : float, optional (default: 0.05)
        Fraction of vertical figure space to allocate top of plot for labels.
    legend_width : float, optional (default: 0.15)
        Fraction of horizontal figure space to allocate right of plot for
        legend.
    legend_fontsize : int, optional (default: 10)
        Fontsize of the figure-level legend; per-panel legends use
        legend_fontsize - 2.
    tick_label_size : int, optional (default: 6)
        Fontsize of tick labels.
    plot_gridlines : bool, optional (default: False)
        Whether to show a grid.
    label_fontsize : int, optional (default: 10)
        Fontsize of variable labels.
    """

    def __init__(
        self,
        N,
        var_names=None,
        figsize=None,
        label_space_left=0.15,
        label_space_top=0.05,
        legend_width=0.15,
        legend_fontsize=10,
        tick_label_size=6,
        plot_gridlines=False,
        label_fontsize=10,
    ):
        # One (label, color) tuple is appended per add_densityplot call.
        self.labels = []
        # Lags of the most recent add_densityplot call. Initialized here so
        # adjustfig does not fail with AttributeError on an empty matrix.
        self.matrix_lags = None

        self.legend_width = legend_width
        self.legend_fontsize = legend_fontsize

        self.label_space_left = label_space_left
        self.label_space_top = label_space_top
        self.label_fontsize = label_fontsize

        self.fig = pyplot.figure(figsize=figsize)

        self.axes_dict = {}

        if var_names is None:
            var_names = range(N)

        plot_index = 1
        for i in range(N):
            for j in range(N):
                ax = self.fig.add_subplot(N, N, plot_index)
                self.axes_dict[(i, j)] = ax

                # Row label: x in figure coordinates, y in axes coordinates,
                # so the text sits at the left figure edge of row i.
                if j == 0:
                    trans = transforms.blended_transform_factory(
                        self.fig.transFigure, ax.transAxes
                    )
                    ax.text(
                        0.01,
                        0.5,
                        "%s" % str(var_names[i]),
                        fontsize=label_fontsize,
                        horizontalalignment="left",
                        verticalalignment="center",
                        transform=trans,
                    )
                # Column label: x in axes coordinates, y in figure coordinates,
                # so the text sits at the top figure edge of column j.
                if i == 0:
                    trans = transforms.blended_transform_factory(
                        ax.transAxes, self.fig.transFigure
                    )
                    ax.text(
                        0.5,
                        0.99,
                        r"${\to}$ " + "%s" % str(var_names[j]),
                        fontsize=label_fontsize,
                        horizontalalignment="center",
                        verticalalignment="top",
                        transform=trans,
                    )

                # Diagonal panels hold 1-D marginals by default, so only the
                # bottom axis is styled there.
                if i == j:
                    _make_nice_axes(ax, where=["bottom"], skip=(1, 1))
                else:
                    _make_nice_axes(ax, where=["left", "bottom"], skip=(1, 1))

                if plot_gridlines:
                    ax.grid(
                        True,
                        which="major",
                        color="black",
                        linestyle="dotted",
                        dashes=(1, 1),
                        linewidth=0.05,
                        zorder=-5,
                    )
                ax.tick_params(axis='both', which='major', labelsize=tick_label_size)
                plot_index += 1

    def add_densityplot(
        self,
        dataframe,
        matrix_lags=None,
        label=None,
        label_color='black',
        snskdeplot_args=None,
        snskdeplot_diagonal_args=None,
        selected_dataset=0,
        show_marginal_densities_on_diagonal=True
    ):
        """Add density function plot.

        Parameters
        ----------
        dataframe : data object
            Tigramite dataframe object. It must have the attributes dataframe.values
            yielding a numpy array of shape (observations T, variables N) and
            optionally a mask of the same shape and a missing values flag.
        matrix_lags : array-like, optional (default: None)
            Lags to use in the density plots. Either None or a non-negative
            array (or nested list) of shape (N, N). Then the entry
            matrix_lags[i, j] = tau will depict the density of
            time series (i, -tau) vs (j, 0). If None, tau = 0 for i != j and
            for i = j tau = 1.
        label : string, optional (default: None)
            Label of this plot.
        label_color : string, optional (default: 'black')
            Color of the diagonal marginal densities and of this plot's
            legend entries.
        snskdeplot_args : dict, optional (default: None)
            Optional parameters to pass to sns.kdeplot() for i != j for
            off-diagonal plots. If None, {'cmap': 'Greys'} is used.
        snskdeplot_diagonal_args : dict, optional (default: None)
            Optional parameters to pass to sns.kdeplot() for i == j on
            diagonal. If None, no extra parameters are passed.
        selected_dataset : int, optional (default: 0)
            In case of multiple datasets in dataframe, plot this one.
        show_marginal_densities_on_diagonal : bool, optional (default: True)
            Flag to show marginal densities on the diagonal of the density
            plots; if False the diagonal shows lagged bivariate densities.

        Raises
        ------
        ValueError
            If any entry of matrix_lags is negative.
        """

        # Seaborn is only needed for this method, so it is imported lazily.
        import seaborn as sns

        # None sentinels instead of shared mutable dict defaults.
        if snskdeplot_args is None:
            snskdeplot_args = {'cmap': 'Greys'}
        if snskdeplot_diagonal_args is None:
            snskdeplot_diagonal_args = {}

        # set seaborn style
        sns.set_style("white")

        # Accept nested lists as well as arrays, and only store lags that
        # passed validation.
        if matrix_lags is not None:
            matrix_lags = np.asarray(matrix_lags)
            if np.any(matrix_lags < 0):
                raise ValueError("matrix_lags must be non-negative!")
        self.matrix_lags = matrix_lags

        data = dataframe.values[selected_dataset]
        mask = None
        if dataframe.mask is not None:
            mask = dataframe.mask[selected_dataset]

        T, _ = data.shape

        # Recorded for every call, including label=None.
        self.labels.append((label, label_color))

        for (i, j), ax in self.axes_dict.items():
            if matrix_lags is None:
                lag = 1 if i == j else 0
            else:
                lag = matrix_lags[i, j]

            # x: variable i shifted back by `lag`; y: variable j at time t.
            x = np.copy(data[:T - lag, i])
            y = np.copy(data[lag:, j])
            # Missing values are already NaN after dataframe init; masked
            # samples are blanked out here.
            if mask is not None:
                x[mask[:T - lag, i] == 1] = np.nan
                y[mask[lag:, j] == 1] = np.nan

            if i == j and show_marginal_densities_on_diagonal:
                sns.kdeplot(x,
                            color=label_color,
                            **snskdeplot_diagonal_args,
                            ax=ax)
                ax.set_ylabel("")
            else:
                # Arguments swapped so panel (i, j) matches rows/columns.
                sns.kdeplot(x=y, y=x,
                            **snskdeplot_args,
                            ax=ax)

    def adjustfig(self, name=None, show_labels=True):
        """Adjust matrix figure.

        Parameters
        ----------
        name : str, optional (default: None)
            File name. If None, figure is shown in window.
        show_labels : bool, optional (default: True)
            Whether to draw the per-panel lag legends on off-diagonal panels
            and, if more than one density plot was added, the figure-level
            legend of plot labels.
        """
        colors = [item[1] for item in self.labels]

        # Per-panel lag legends. They are only built when requested and when
        # at least one density plot was added: previously an empty list here
        # caused a NameError, and show_labels=False still produced empty
        # legends.
        if show_labels and colors:
            for (i, j), ax in self.axes_dict.items():
                if i == j:
                    continue
                if self.matrix_lags is None:
                    lag = 0
                else:
                    lag = self.matrix_lags[i, j]
                # Invisible proxy artist that only carries the tau label.
                ax.plot(
                    [],
                    [],
                    linestyle="",
                    color=colors[-1],
                    label=r"$\tau{=}%d$" % lag,
                )
                ax.legend(
                    ncol=1,
                    fontsize=self.legend_fontsize - 2,
                    labelcolor=colors,
                ).set_frame_on(False)

        if show_labels and len(self.labels) > 1:
            # Invisible axes spanning the whole figure, used as legend host.
            axlegend = self.fig.add_subplot(111, frameon=False)
            for side in ("left", "right", "bottom", "top"):
                axlegend.spines[side].set_color("none")
            axlegend.set_xticks([])
            axlegend.set_yticks([])

            for label, color in self.labels:
                axlegend.plot(
                    [],
                    [],
                    linestyle="-",
                    color=color,
                    label=label,
                )
            axlegend.legend(
                loc="upper left",
                ncol=1,
                bbox_to_anchor=(1.05, 0.0, 0.1, 1.0),
                borderaxespad=0,
                fontsize=self.legend_fontsize,
            ).set_frame_on(False)

            self.fig.subplots_adjust(
                bottom=0.08,
                left=self.label_space_left,
                right=1.0 - self.legend_width,
                top=1.0 - self.label_space_top,
                hspace=0.5,
                wspace=0.35,
            )

        else:
            self.fig.subplots_adjust(
                left=self.label_space_left,
                bottom=0.08,
                right=0.95,
                top=1.0 - self.label_space_top,
                hspace=0.35,
                wspace=0.35,
            )

        if name is not None:
            self.fig.savefig(name)
        else:
            pyplot.show()
1502
+
1503
+ def _draw_network_with_curved_edges(
1504
+ fig,
1505
+ ax,
1506
+ G,
1507
+ pos,
1508
+ node_rings,
1509
+ node_labels,
1510
+ node_label_size=10,
1511
+ node_alpha=1.0,
1512
+ standard_size=100,
1513
+ node_aspect=None,
1514
+ standard_cmap="OrRd",
1515
+ standard_color_links='black',
1516
+ standard_color_nodes='lightgrey',
1517
+ log_sizes=False,
1518
+ cmap_links="YlOrRd",
1519
+ # cmap_links_edges="YlOrRd",
1520
+ links_vmin=0.0,
1521
+ links_vmax=1.0,
1522
+ links_edges_vmin=0.0,
1523
+ links_edges_vmax=1.0,
1524
+ links_ticks=0.2,
1525
+ links_edges_ticks=0.2,
1526
+ link_label_fontsize=8,
1527
+ arrowstyle="->, head_width=0.4, head_length=1",
1528
+ arrowhead_size=3.0,
1529
+ curved_radius=0.2,
1530
+ label_fontsize=4,
1531
+ label_fraction=0.5,
1532
+ link_colorbar_label="link",
1533
+ tick_label_size=6,
1534
+ # link_edge_colorbar_label='link_edge',
1535
+ inner_edge_curved=False,
1536
+ inner_edge_style="solid",
1537
+ # network_lower_bound=0.2,
1538
+ network_left_bound=None,
1539
+ show_colorbar=True,
1540
+ special_nodes=None,
1541
+ autodep_sig_lags=None,
1542
+ show_autodependency_lags=False,
1543
+ transform='data',
1544
+ node_classification=None,
1545
+ max_lag=0,
1546
+ special_links=None,
1547
+ ):
1548
+ """Function to draw a network from networkx graph instance.
1549
+ Various attributes are used to specify the graph's properties.
1550
+ This function is just a beta-template for now that can be further
1551
+ customized.
1552
+ """
1553
+
1554
+ if transform == 'data':
1555
+ transform = ax.transData
1556
+
1557
+ from matplotlib.patches import FancyArrowPatch, Circle, Ellipse
1558
+
1559
+ ax.spines["left"].set_color("none")
1560
+ ax.spines["right"].set_color("none")
1561
+ ax.spines["bottom"].set_color("none")
1562
+ ax.spines["top"].set_color("none")
1563
+ ax.set_xticks([])
1564
+ ax.set_yticks([])
1565
+
1566
+ N = len(G)
1567
+
1568
+ # This fixes a positioning bug in matplotlib.
1569
+ ax.scatter(0, 0, zorder=-10, alpha=0)
1570
+
1571
+ def draw_edge(
1572
+ ax,
1573
+ u,
1574
+ v,
1575
+ d,
1576
+ seen,
1577
+ arrowstyle= "Simple, head_width=2, head_length=2, tail_width=1",
1578
+ outer_edge=True,
1579
+ cycle_edges_call_count=0
1580
+ ):
1581
+
1582
+ # avoiding attribute error raised by changes in networkx
1583
+ if hasattr(G, "node"):
1584
+ # works with networkx 1.10
1585
+ n1 = G.node[u]["patch"]
1586
+ n2 = G.node[v]["patch"]
1587
+ else:
1588
+ # works with networkx 2.4
1589
+ n1 = G.nodes[u]["patch"]
1590
+ n2 = G.nodes[v]["patch"]
1591
+
1592
+ # print("+++++++++++++++++++++++==cmap_links ", cmap_links)
1593
+ if outer_edge:
1594
+
1595
+ rad = -1.0 * curved_radius
1596
+ if cmap_links is not None:
1597
+ facecolor = data_to_rgb_links.to_rgba(d["outer_edge_color"])
1598
+ else:
1599
+ if d["outer_edge_color"] is not None:
1600
+ facecolor = d["outer_edge_color"]
1601
+ else:
1602
+ facecolor = standard_color_links
1603
+
1604
+ width = d["outer_edge_width"]
1605
+ alpha = d["outer_edge_alpha"]
1606
+ if (u, v) in seen:
1607
+ rad = seen.get((u, v))
1608
+ rad = (rad + np.sign(rad) * 0.1) * -1.0
1609
+ arrowstyle = arrowstyle
1610
+ # link_edge = d['outer_edge_edge']
1611
+ linestyle = 'solid' # d.get("outer_edge_style")
1612
+
1613
+ if cycle_edges_call_count==0:
1614
+
1615
+ if d.get("outer_edge_attribute", None) == "spurious":
1616
+ facecolor = "grey"
1617
+
1618
+ if d.get("outer_edge_type") in ["<-o", "<--", "<-x", "<-+"]:
1619
+ n1, n2 = n2, n1
1620
+
1621
+ if d.get("outer_edge_type") in [
1622
+ "o-o",
1623
+ "o--",
1624
+ "--o",
1625
+ "---",
1626
+ "x-x",
1627
+ "x--",
1628
+ "--x",
1629
+ "o-x",
1630
+ "x-o",
1631
+ # "+->",
1632
+ # "<-+",
1633
+ ]:
1634
+ arrowstyle = "-"
1635
+ # linewidth = width*factor
1636
+ elif d.get("outer_edge_type") == "<->":
1637
+ # arrowstyle = "<->, head_width=0.4, head_length=1"
1638
+ arrowstyle = "Simple, head_width=2, head_length=2, tail_width=1" #%float(width/20.)
1639
+ elif d.get("outer_edge_type") in ["o->", "-->", "<-o", "<--", "<-x", "x->", "+->", "<-+"]:
1640
+ # arrowstyle = "->, head_width=0.4, head_length=1"
1641
+ # arrowstyle = "->, head_width=0.4, head_length=1, width=10"
1642
+ arrowstyle = "Simple, head_width=2, head_length=2, tail_width=1" #%float(width/20.)
1643
+ else:
1644
+ arrowstyle = "Simple, head_width=2, head_length=2, tail_width=1" #%float(width/20.)
1645
+ # raise ValueError("edge type %s not valid." %d.get("outer_edge_type"))
1646
+
1647
+ if special_links is not None and special_links.vanishing_link(u, v):
1648
+ facecolor = special_links.get_color(u, v)
1649
+ # draw_edge(ax, u, v, d, seen, arrowstyle, outer_edge, # color edge plus forground edge (eg corr-color)
1650
+ # cycle_edges_call_count=cycle_edges_call_count + 1)
1651
+ arrowstyle = "Simple, head_width=1.9, head_length=1.8, tail_width=0.8" #%float(width/20.)
1652
+
1653
+ else: # cycle_edges_call_count > 1 # Added: Draw Vanishing edge
1654
+ facecolor = special_links.get_color(u, v)
1655
+ arrowstyle = "Simple, head_width=2.5, head_length=2.2, tail_width=1.5" #%float(width/20.)
1656
+
1657
+ else: # if inner edge:
1658
+ rad = -1.0 * inner_edge_curved * curved_radius
1659
+ if cmap_links is not None:
1660
+ facecolor = data_to_rgb_links.to_rgba(d["inner_edge_color"])
1661
+ else:
1662
+ if d["inner_edge_color"] is not None:
1663
+ facecolor = d["inner_edge_color"]
1664
+ else:
1665
+ # print("HERE")
1666
+ facecolor = standard_color_links
1667
+
1668
+ width = d["inner_edge_width"]
1669
+ alpha = d["inner_edge_alpha"]
1670
+
1671
+ if d.get("inner_edge_attribute", None) == "spurious":
1672
+ facecolor = "grey"
1673
+ # print(d.get("inner_edge_type"))
1674
+ if d.get("inner_edge_type") in ["<-o", "<--", "<-x", "<-+"]:
1675
+ n1, n2 = n2, n1
1676
+
1677
+ def get_arrow_style(edge_type: str, scale: float=1.0) -> str:
1678
+ if edge_type in [
1679
+ "o-o",
1680
+ "o--",
1681
+ "--o",
1682
+ "---",
1683
+ "x-x",
1684
+ "x--",
1685
+ "--x",
1686
+ "o-x",
1687
+ "x-o",
1688
+ ]:
1689
+ # return "-"
1690
+ return f"simple, tail_width={scale * 1.0}"
1691
+ elif edge_type == "<->":
1692
+ # arrowstyle = "<->, head_width=0.4, head_length=1"
1693
+ return f"Simple, head_width={scale * 2.0}, head_length={scale * 2.0}, tail_width={scale * 1.0}" #%float(width/20.)
1694
+ elif edge_type in ["o->", "-->", "<-o", "<--", "<-x", "x->", "+->", "<-+"]:
1695
+ # arrowstyle = "->, head_width=0.4, head_length=1"
1696
+ return f"Simple, head_width={scale * 2.0}, head_length={scale * 2.0}, tail_width={scale * 1.0}" #%float(width/20.)
1697
+ elif edge_type == "<O>": # Added: Draw Edge-Flip
1698
+ # arrowstyle = "<->, head_width=0.4, head_length=1"
1699
+ assert False
1700
+ if cycle_edges_call_count < 3:
1701
+ draw_edge(ax, v, u, d, seen, arrowstyle, outer_edge, cycle_edges_call_count=cycle_edges_call_count+1)
1702
+
1703
+ rad = -2.0 * curved_radius
1704
+ if cycle_edges_call_count < 2: # drawn in inverse order (draw executes synchronously ...)
1705
+ facecolor = data_to_rgb_links.to_rgba(d["inner_edge_color"])
1706
+ arrowstyle = "Simple, head_width=1.9, head_length=2.0, tail_width=0.7" #%float(width/20.)
1707
+ else:
1708
+ facecolor = special_links.get_color(u, v)
1709
+ arrowstyle = "Simple, head_width=2.5, head_length=2.2, tail_width=1.5" #%float(width/20.)
1710
+ else:
1711
+ return f"Simple, head_width={scale * 2.0}, head_length={scale * 2.0}, tail_width={scale * 1.0}" #%float(width/20.)
1712
+
1713
+
1714
+ if cycle_edges_call_count==0:
1715
+
1716
+ # added:
1717
+ arrowstyle = None
1718
+ if special_links is not None and special_links.vanishing_link(u, v):
1719
+ arrowstyle = get_arrow_style( d.get("inner_edge_type"), scale=0.8 )
1720
+
1721
+ facecolor = special_links.get_color(u, v)
1722
+ # draw_edge(ax, u, v, d, seen, arrowstyle, outer_edge, # color edge plus forground edge (eg corr-color)
1723
+ # cycle_edges_call_count=cycle_edges_call_count + 1)
1724
+
1725
+ # eg for '°-°' for arrowstyle = "-" ...
1726
+ else:
1727
+ arrowstyle = get_arrow_style( d.get("inner_edge_type") )
1728
+
1729
+ # raise ValueError("edge type %s not valid." %d.get("inner_edge_type"))
1730
+
1731
+ linestyle = 'solid' #d.get("inner_edge_style")
1732
+
1733
+ else: # cycle_edges_call_count > 1 # Added: Draw Vanishing edge
1734
+ facecolor = special_links.get_color(u, v)
1735
+ arrowstyle = get_arrow_style( d.get("inner_edge_type"), scale=1.3 )
1736
+
1737
+ linestyle = 'solid' #d.get("inner_edge_style")
1738
+
1739
+ coor1 = n1.center
1740
+ coor2 = n2.center
1741
+
1742
+ marker_size = width ** 2
1743
+ figuresize = fig.get_size_inches()
1744
+
1745
+ # print("COLOR ", facecolor)
1746
+ # print(u, v, outer_edge, "outer ", d.get("outer_edge_type"), "inner ", d.get("inner_edge_type"), width, arrowstyle, linestyle)
1747
+
1748
+ if ((outer_edge is True and d.get("outer_edge_type") == "<->")
1749
+ or (outer_edge is False and d.get("inner_edge_type") == "<->")):
1750
+ e_p = FancyArrowPatch(
1751
+ coor1,
1752
+ coor2,
1753
+ arrowstyle=arrowstyle,
1754
+ connectionstyle=f"arc3,rad={rad}",
1755
+ mutation_scale=1*width,
1756
+ lw=0., #width / 2.,
1757
+ aa=True,
1758
+ alpha=alpha,
1759
+ linestyle=linestyle,
1760
+ color=facecolor,
1761
+ clip_on=False,
1762
+ patchA=n1,
1763
+ patchB=n2,
1764
+ shrinkA=7,
1765
+ shrinkB=0,
1766
+ zorder=-1,
1767
+ capstyle="butt",
1768
+ transform=transform,
1769
+ )
1770
+ ax.add_artist(e_p)
1771
+
1772
+ e_p_back = FancyArrowPatch(
1773
+ coor2,
1774
+ coor1,
1775
+ arrowstyle=arrowstyle,
1776
+ connectionstyle=f"arc3,rad={-rad}",
1777
+ mutation_scale=1*width,
1778
+ lw=0., #width / 2.,
1779
+ aa=True,
1780
+ alpha=alpha,
1781
+ linestyle=linestyle,
1782
+ color=facecolor,
1783
+ clip_on=False,
1784
+ patchA=n2,
1785
+ patchB=n1,
1786
+ shrinkA=7,
1787
+ shrinkB=0,
1788
+ zorder=-1,
1789
+ capstyle="butt",
1790
+ transform=transform,
1791
+ )
1792
+ ax.add_artist(e_p_back)
1793
+
1794
+ else:
1795
+ if arrowstyle == '-':
1796
+ lw = 1*width
1797
+ else:
1798
+ lw = 0.
1799
+ # e_p = FancyArrowPatch(
1800
+ # coor1,
1801
+ # coor2,
1802
+ # arrowstyle=arrowstyle,
1803
+ # connectionstyle=f"arc3,rad={rad}",
1804
+ # mutation_scale=np.sqrt(width)*2*1.1,
1805
+ # lw=lw*1.1, #width / 2.,
1806
+ # aa=True,
1807
+ # alpha=alpha,
1808
+ # linestyle=linestyle,
1809
+ # color='white',
1810
+ # clip_on=False,
1811
+ # patchA=n1,
1812
+ # patchB=n2,
1813
+ # shrinkA=0,
1814
+ # shrinkB=0,
1815
+ # zorder=-1,
1816
+ # capstyle="butt",
1817
+ # )
1818
+ # ax.add_artist(e_p)
1819
+ e_p = FancyArrowPatch(
1820
+ coor1,
1821
+ coor2,
1822
+ arrowstyle=arrowstyle,
1823
+ connectionstyle=f"arc3,rad={rad}",
1824
+ mutation_scale=1*width,
1825
+ lw=lw, #width / 2.,
1826
+ aa=True,
1827
+ alpha=alpha,
1828
+ linestyle=linestyle,
1829
+ color=facecolor,
1830
+ clip_on=False,
1831
+ patchA=n1,
1832
+ patchB=n2,
1833
+ shrinkA=0,
1834
+ shrinkB=0,
1835
+ # zorder=-1,
1836
+ capstyle="butt",
1837
+ transform=transform,
1838
+ )
1839
+ ax.add_artist(e_p)
1840
+
1841
+ e_p_marker = FancyArrowPatch(
1842
+ coor1,
1843
+ coor2,
1844
+ arrowstyle='-',
1845
+ connectionstyle=f"arc3,rad={rad}",
1846
+ mutation_scale=1*width,
1847
+ lw=0., #width / 2.,
1848
+ aa=True,
1849
+ alpha=0.,
1850
+ linestyle=linestyle,
1851
+ color=facecolor,
1852
+ clip_on=False,
1853
+ patchA=n1,
1854
+ patchB=n2,
1855
+ shrinkA=0,
1856
+ shrinkB=0,
1857
+ zorder=-10,
1858
+ capstyle="butt",
1859
+ transform=transform,
1860
+ )
1861
+ ax.add_artist(e_p_marker)
1862
+
1863
+ # marker_path = e_p_marker.get_path()
1864
+ vertices = e_p_marker.get_path().vertices.copy()
1865
+ # vertices = e_p_marker.get_verts()
1866
+ # vertices = e_p_marker.get_path().to_polygons(transform=None)[0]
1867
+ # print(vertices.shape)
1868
+ m, n = vertices.shape
1869
+
1870
+ # print(vertices)
1871
+ start = vertices[0]
1872
+ end = vertices[-1]
1873
+
1874
+ # This must be added to avoid rescaling of the plot, when no 'o'
1875
+ # or 'x' is added to the graph.
1876
+ ax.scatter(*start, zorder=-10, alpha=0, transform=transform,)
1877
+
1878
+ if outer_edge:
1879
+ if d.get("outer_edge_type") in ["o->", "o--"]:
1880
+ circle_marker_start = ax.scatter(
1881
+ *start,
1882
+ marker="o",
1883
+ s=marker_size,
1884
+ facecolor="w",
1885
+ edgecolor=facecolor,
1886
+ zorder=1,
1887
+ transform=transform,
1888
+ )
1889
+ ax.add_collection(circle_marker_start)
1890
+ elif d.get("outer_edge_type") == "<-o":
1891
+ circle_marker_end = ax.scatter(
1892
+ *start,
1893
+ marker="o",
1894
+ s=marker_size,
1895
+ facecolor="w",
1896
+ edgecolor=facecolor,
1897
+ zorder=1,
1898
+ transform=transform,
1899
+ )
1900
+ ax.add_collection(circle_marker_end)
1901
+ elif d.get("outer_edge_type") == "--o":
1902
+ circle_marker_end = ax.scatter(
1903
+ *end,
1904
+ marker="o",
1905
+ s=marker_size,
1906
+ facecolor="w",
1907
+ edgecolor=facecolor,
1908
+ zorder=1,
1909
+ transform=transform,
1910
+ )
1911
+ ax.add_collection(circle_marker_end)
1912
+ elif d.get("outer_edge_type") in ["x--", "x->"]:
1913
+ circle_marker_start = ax.scatter(
1914
+ *start,
1915
+ marker="X",
1916
+ s=marker_size,
1917
+ facecolor="w",
1918
+ edgecolor=facecolor,
1919
+ zorder=1,
1920
+ transform=transform,
1921
+ )
1922
+ ax.add_collection(circle_marker_start)
1923
+ elif d.get("outer_edge_type") in ["+--", "+->"]:
1924
+ circle_marker_start = ax.scatter(
1925
+ *start,
1926
+ marker="P",
1927
+ s=marker_size,
1928
+ facecolor="w",
1929
+ edgecolor=facecolor,
1930
+ zorder=1,
1931
+ transform=transform,
1932
+ )
1933
+ ax.add_collection(circle_marker_start)
1934
+ elif d.get("outer_edge_type") == "<-x":
1935
+ circle_marker_end = ax.scatter(
1936
+ *start,
1937
+ marker="X",
1938
+ s=marker_size,
1939
+ facecolor="w",
1940
+ edgecolor=facecolor,
1941
+ zorder=1,
1942
+ transform=transform,
1943
+ )
1944
+ ax.add_collection(circle_marker_end)
1945
+ elif d.get("outer_edge_type") == "<-+":
1946
+ circle_marker_end = ax.scatter(
1947
+ *start,
1948
+ marker="P",
1949
+ s=marker_size,
1950
+ facecolor="w",
1951
+ edgecolor=facecolor,
1952
+ zorder=1,
1953
+ transform=transform,
1954
+ )
1955
+ ax.add_collection(circle_marker_end)
1956
+ elif d.get("outer_edge_type") == "--x":
1957
+ circle_marker_end = ax.scatter(
1958
+ *end,
1959
+ marker="X",
1960
+ s=marker_size,
1961
+ facecolor="w",
1962
+ edgecolor=facecolor,
1963
+ zorder=1,
1964
+ transform=transform,
1965
+ )
1966
+ ax.add_collection(circle_marker_end)
1967
+ elif d.get("outer_edge_type") == "o-o":
1968
+ circle_marker_start = ax.scatter(
1969
+ *start,
1970
+ marker="o",
1971
+ s=marker_size,
1972
+ facecolor="w",
1973
+ edgecolor=facecolor,
1974
+ zorder=1,
1975
+ transform=transform,
1976
+ )
1977
+ ax.add_collection(circle_marker_start)
1978
+ circle_marker_end = ax.scatter(
1979
+ *end,
1980
+ marker="o",
1981
+ s=marker_size,
1982
+ facecolor="w",
1983
+ edgecolor=facecolor,
1984
+ zorder=1,
1985
+ transform=transform,
1986
+ )
1987
+ ax.add_collection(circle_marker_end)
1988
+ elif d.get("outer_edge_type") == "x-x":
1989
+ circle_marker_start = ax.scatter(
1990
+ *start,
1991
+ marker="X",
1992
+ s=marker_size,
1993
+ facecolor="w",
1994
+ edgecolor=facecolor,
1995
+ zorder=1,
1996
+ transform=transform,
1997
+ )
1998
+ ax.add_collection(circle_marker_start)
1999
+ circle_marker_end = ax.scatter(
2000
+ *end,
2001
+ marker="X",
2002
+ s=marker_size,
2003
+ facecolor="w",
2004
+ edgecolor=facecolor,
2005
+ zorder=1,
2006
+ transform=transform,
2007
+ )
2008
+ ax.add_collection(circle_marker_end)
2009
+ elif d.get("outer_edge_type") == "o-x":
2010
+ circle_marker_start = ax.scatter(
2011
+ *start,
2012
+ marker="o",
2013
+ s=marker_size,
2014
+ facecolor="w",
2015
+ edgecolor=facecolor,
2016
+ zorder=1,
2017
+ transform=transform,
2018
+ )
2019
+ ax.add_collection(circle_marker_start)
2020
+ circle_marker_end = ax.scatter(
2021
+ *end,
2022
+ marker="X",
2023
+ s=marker_size,
2024
+ facecolor="w",
2025
+ edgecolor=facecolor,
2026
+ zorder=1,
2027
+ transform=transform,
2028
+ )
2029
+ ax.add_collection(circle_marker_end)
2030
+ elif d.get("outer_edge_type") == "x-o":
2031
+ circle_marker_start = ax.scatter(
2032
+ *start,
2033
+ marker="X",
2034
+ s=marker_size,
2035
+ facecolor="w",
2036
+ edgecolor=facecolor,
2037
+ zorder=1,
2038
+ transform=transform,
2039
+ )
2040
+ ax.add_collection(circle_marker_start)
2041
+ circle_marker_end = ax.scatter(
2042
+ *end,
2043
+ marker="o",
2044
+ s=marker_size,
2045
+ facecolor="w",
2046
+ edgecolor=facecolor,
2047
+ zorder=1,
2048
+ transform=transform,
2049
+ )
2050
+ ax.add_collection(circle_marker_end)
2051
+
2052
+ else:
2053
+ if d.get("inner_edge_type") in ["o->", "o--"]:
2054
+ circle_marker_start = ax.scatter(
2055
+ *start,
2056
+ marker="o",
2057
+ s=marker_size,
2058
+ facecolor="w",
2059
+ edgecolor=facecolor,
2060
+ zorder=1,
2061
+ transform=transform,
2062
+ )
2063
+ ax.add_collection(circle_marker_start)
2064
+ elif d.get("inner_edge_type") == "<-o":
2065
+ circle_marker_end = ax.scatter(
2066
+ *start,
2067
+ marker="o",
2068
+ s=marker_size,
2069
+ facecolor="w",
2070
+ edgecolor=facecolor,
2071
+ zorder=1,
2072
+ transform=transform,
2073
+ )
2074
+ ax.add_collection(circle_marker_end)
2075
+ elif d.get("inner_edge_type") == "--o":
2076
+ circle_marker_end = ax.scatter(
2077
+ *end,
2078
+ marker="o",
2079
+ s=marker_size,
2080
+ facecolor="w",
2081
+ edgecolor=facecolor,
2082
+ zorder=1,
2083
+ transform=transform,
2084
+ )
2085
+ ax.add_collection(circle_marker_end)
2086
+ elif d.get("inner_edge_type") in ["x--", "x->"]:
2087
+ circle_marker_start = ax.scatter(
2088
+ *start,
2089
+ marker="X",
2090
+ s=marker_size,
2091
+ facecolor="w",
2092
+ edgecolor=facecolor,
2093
+ zorder=1,
2094
+ transform=transform,
2095
+ )
2096
+ ax.add_collection(circle_marker_start)
2097
+ elif d.get("inner_edge_type") in ["+--", "+->"]:
2098
+ circle_marker_start = ax.scatter(
2099
+ *start,
2100
+ marker="P",
2101
+ s=marker_size,
2102
+ facecolor="w",
2103
+ edgecolor=facecolor,
2104
+ zorder=1,
2105
+ transform=transform,
2106
+ )
2107
+ ax.add_collection(circle_marker_start)
2108
+ elif d.get("inner_edge_type") == "<-x":
2109
+ circle_marker_end = ax.scatter(
2110
+ *start,
2111
+ marker="X",
2112
+ s=marker_size,
2113
+ facecolor="w",
2114
+ edgecolor=facecolor,
2115
+ zorder=1,
2116
+ transform=transform,
2117
+ )
2118
+ ax.add_collection(circle_marker_end)
2119
+ elif d.get("inner_edge_type") == "<-+":
2120
+ circle_marker_end = ax.scatter(
2121
+ *start,
2122
+ marker="P",
2123
+ s=marker_size,
2124
+ facecolor="w",
2125
+ edgecolor=facecolor,
2126
+ zorder=1,
2127
+ transform=transform,
2128
+ )
2129
+ ax.add_collection(circle_marker_end)
2130
+ elif d.get("inner_edge_type") == "--x":
2131
+ circle_marker_end = ax.scatter(
2132
+ *end,
2133
+ marker="X",
2134
+ s=marker_size,
2135
+ facecolor="w",
2136
+ edgecolor=facecolor,
2137
+ zorder=1,
2138
+ transform=transform,
2139
+ )
2140
+ ax.add_collection(circle_marker_end)
2141
+ elif d.get("inner_edge_type") == "o-o":
2142
+ circle_marker_start = ax.scatter(
2143
+ *start,
2144
+ marker="o",
2145
+ s=marker_size,
2146
+ facecolor="w",
2147
+ edgecolor=facecolor,
2148
+ zorder=1,
2149
+ transform=transform,
2150
+ )
2151
+ ax.add_collection(circle_marker_start)
2152
+ circle_marker_end = ax.scatter(
2153
+ *end,
2154
+ marker="o",
2155
+ s=marker_size,
2156
+ facecolor="w",
2157
+ edgecolor=facecolor,
2158
+ zorder=1,
2159
+ transform=transform,
2160
+ )
2161
+ ax.add_collection(circle_marker_end)
2162
+ elif d.get("inner_edge_type") == "x-x":
2163
+ circle_marker_start = ax.scatter(
2164
+ *start,
2165
+ marker="X",
2166
+ s=marker_size,
2167
+ facecolor="w",
2168
+ edgecolor=facecolor,
2169
+ zorder=1,
2170
+ transform=transform,
2171
+ )
2172
+ ax.add_collection(circle_marker_start)
2173
+ circle_marker_end = ax.scatter(
2174
+ *end,
2175
+ marker="X",
2176
+ s=marker_size,
2177
+ facecolor="w",
2178
+ edgecolor=facecolor,
2179
+ zorder=1,
2180
+ transform=transform,
2181
+ )
2182
+ ax.add_collection(circle_marker_end)
2183
+ elif d.get("inner_edge_type") == "o-x":
2184
+ circle_marker_start = ax.scatter(
2185
+ *start,
2186
+ marker="o",
2187
+ s=marker_size,
2188
+ facecolor="w",
2189
+ edgecolor=facecolor,
2190
+ zorder=1,
2191
+ transform=transform,
2192
+ )
2193
+ ax.add_collection(circle_marker_start)
2194
+ circle_marker_end = ax.scatter(
2195
+ *end,
2196
+ marker="X",
2197
+ s=marker_size,
2198
+ facecolor="w",
2199
+ edgecolor=facecolor,
2200
+ zorder=1,
2201
+ transform=transform,
2202
+ )
2203
+ ax.add_collection(circle_marker_end)
2204
+ elif d.get("inner_edge_type") == "x-o":
2205
+ circle_marker_start = ax.scatter(
2206
+ *start,
2207
+ marker="X",
2208
+ s=marker_size,
2209
+ facecolor="w",
2210
+ edgecolor=facecolor,
2211
+ zorder=1,
2212
+ transform=transform,
2213
+ )
2214
+ ax.add_collection(circle_marker_start)
2215
+ circle_marker_end = ax.scatter(
2216
+ *end,
2217
+ marker="o",
2218
+ s=marker_size,
2219
+ facecolor="w",
2220
+ edgecolor=facecolor,
2221
+ zorder=1,
2222
+ transform=transform,
2223
+ )
2224
+ ax.add_collection(circle_marker_end)
2225
+
2226
+
2227
+
2228
+ if d["label"] is not None and outer_edge:
2229
+ def closest_node(node, nodes):
2230
+ nodes = np.asarray(nodes)
2231
+ node = node.reshape(1, 2)
2232
+ dist_2 = np.sum((nodes - node)**2, axis=1)
2233
+ return np.argmin(dist_2)
2234
+
2235
+ # Attach labels of lags
2236
+ # trans = None # patch.get_transform()
2237
+ # path = e_p.get_path()
2238
+ vertices = e_p_marker.get_path().vertices.copy()
2239
+ verts = e_p.get_path().to_polygons(transform=None)[0]
2240
+ # print(verts)
2241
+ # print(verts.shape)
2242
+ # print(vertices.shape)
2243
+ # for num, vert in enumerate(verts):
2244
+ # ax.text(vert[0], vert[1], str(num),
2245
+ # transform=transform,)
2246
+ # ax.scatter(verts[:,0], verts[:,1])
2247
+ # mid_point = np.array([(start[0] + end[0])/2., (start[1] + end[1])/2.])
2248
+ # print(start, end, mid_point)
2249
+ # ax.scatter(mid_point[0], mid_point[1], marker='x',
2250
+ # s=100, zorder=10, transform=transform,)
2251
+ closest_node = closest_node(vertices[int(len(vertices)/2.),:], verts)
2252
+ # print(closest_node, verts[closest_node])
2253
+ # ax.scatter(verts[closest_node][0], verts[closest_node][1], marker='x')
2254
+
2255
+ if len(vertices) > 2:
2256
+ # label_vert = vertices[int(len(vertices)/2.),:] #verts[1, :]
2257
+ label_vert = verts[closest_node] #verts[1, :]
2258
+ l = d["label"]
2259
+ string = str(l)
2260
+ txt = ax.text(
2261
+ label_vert[0],
2262
+ label_vert[1],
2263
+ string,
2264
+ fontsize=link_label_fontsize,
2265
+ verticalalignment="center",
2266
+ horizontalalignment="center",
2267
+ color="w",
2268
+ zorder=1,
2269
+ transform=transform,
2270
+ )
2271
+ txt.set_path_effects(
2272
+ [PathEffects.withStroke(linewidth=2, foreground="k")]
2273
+ )
2274
+
2275
+ return rad
2276
+
2277
+ # Collect all edge weights to get color scale
2278
+ all_links_weights = []
2279
+ all_links_edge_weights = []
2280
+ for (u, v, d) in G.edges(data=True):
2281
+ if u != v:
2282
+ if d["outer_edge"] and d["outer_edge_color"] is not None:
2283
+ all_links_weights.append(d["outer_edge_color"])
2284
+ if d["inner_edge"] and d["inner_edge_color"] is not None:
2285
+ all_links_weights.append(d["inner_edge_color"])
2286
+
2287
+ if cmap_links is not None and len(all_links_weights) > 0:
2288
+ if links_vmin is None:
2289
+ links_vmin = np.array(all_links_weights).min()
2290
+ if links_vmax is None:
2291
+ links_vmax = np.array(all_links_weights).max()
2292
+ data_to_rgb_links = pyplot.cm.ScalarMappable(
2293
+ norm=None, cmap=pyplot.get_cmap(cmap_links)
2294
+ )
2295
+ data_to_rgb_links.set_array(np.array(all_links_weights))
2296
+ data_to_rgb_links.set_clim(vmin=links_vmin, vmax=links_vmax)
2297
+ # Create colorbars for links
2298
+
2299
+ # setup colorbar axes.
2300
+ if show_colorbar:
2301
+ # cax_e = pyplot.axes(
2302
+ # [
2303
+ # 0.55,
2304
+ # ax.get_subplotspec().get_position(ax.figure).bounds[1] + 0.02,
2305
+ # 0.4,
2306
+ # 0.025 + (len(all_links_edge_weights) == 0) * 0.035,
2307
+ # ],
2308
+ # frameon=False,
2309
+ # )
2310
+ bbox_ax = ax.get_position()
2311
+ width = bbox_ax.xmax-bbox_ax.xmin
2312
+ height = bbox_ax.ymax-bbox_ax.ymin
2313
+ # print(bbox_ax.xmin, bbox_ax.xmax, bbox_ax.ymin, bbox_ax.ymax)
2314
+ # cax_e = fig.add_axes(
2315
+ # [
2316
+ # bbox_ax.xmax - width*0.45,
2317
+ # bbox_ax.ymin-0.075*height+network_lower_bound-0.15,
2318
+ # width*0.4,
2319
+ # 0.075*height, #0.025 + (len(all_links_edge_weights) == 0) * 0.035,
2320
+ # ],
2321
+ # frameon=False,
2322
+ # )
2323
+ cax_e = ax.inset_axes(
2324
+ [
2325
+ 0.55, -0.07, 0.4, 0.07
2326
+ # bbox_ax.xmax - width*0.45,
2327
+ # bbox_ax.ymin-0.075*height+network_lower_bound-0.15,
2328
+ # width*0.4,
2329
+ # 0.075*height, #0.025 + (len(all_links_edge_weights) == 0) * 0.035,
2330
+ ],
2331
+ frameon=False,)
2332
+ # divider = make_axes_locatable(ax)
2333
+
2334
+ # cax_e = divider.append_axes('bottom', size='5%', pad=0.05, frameon=False,)
2335
+
2336
+ cb_e = pyplot.colorbar(
2337
+ data_to_rgb_links, cax=cax_e, orientation="horizontal"
2338
+ )
2339
+ # try:
2340
+ ticks_here = np.arange(
2341
+ _myround(links_vmin, links_ticks, "down"),
2342
+ _myround(links_vmax, links_ticks, "up") + links_ticks,
2343
+ links_ticks,
2344
+ )
2345
+ cb_e.set_ticks(ticks_here[(links_vmin <= ticks_here) & (ticks_here <= links_vmax)])
2346
+ # except:
2347
+ # print('no ticks given')
2348
+
2349
+ cb_e.outline.clear()
2350
+ cax_e.set_xlabel(
2351
+ link_colorbar_label, labelpad=1, fontsize=label_fontsize, zorder=10
2352
+ )
2353
+ cax_e.tick_params(axis='both', which='major', labelsize=tick_label_size)
2354
+
2355
+ ##
2356
+ # Draw nodes
2357
+ ##
2358
+ node_sizes = np.zeros((len(node_rings), N))
2359
+ for ring in list(node_rings): # iterate through to get all node sizes
2360
+ if node_rings[ring]["sizes"] is not None:
2361
+ node_sizes[ring] = node_rings[ring]["sizes"]
2362
+
2363
+ else:
2364
+ node_sizes[ring] = standard_size
2365
+ max_sizes = node_sizes.max(axis=1)
2366
+ total_max_size = node_sizes.sum(axis=0).max()
2367
+ node_sizes /= total_max_size
2368
+ node_sizes *= standard_size
2369
+
2370
+ def get_aspect(ax):
2371
+ # Total figure size
2372
+ figW, figH = ax.get_figure().get_size_inches()
2373
+ # print(figW, figH)
2374
+ # Axis size on figure
2375
+ _, _, w, h = ax.get_position().bounds
2376
+ # Ratio of display units
2377
+ # print(w, h)
2378
+ disp_ratio = (figH * h) / (figW * w)
2379
+ # Ratio of data units
2380
+ # Negative over negative because of the order of subtraction
2381
+ data_ratio = sub(*ax.get_ylim()) / sub(*ax.get_xlim())
2382
+ # print(data_ratio, disp_ratio)
2383
+ return disp_ratio / data_ratio
2384
+
2385
+ if node_aspect is None:
2386
+ node_aspect = get_aspect(ax)
2387
+
2388
+ # start drawing the outer ring first...
2389
+ for ring in list(node_rings)[::-1]:
2390
+ # print ring
2391
+ # dictionary of rings: {0:{'sizes':(N,)-array, 'color_array':(N,)-array
2392
+ # or None, 'cmap':string, 'vmin':float or None, 'vmax':float or None}}
2393
+ if node_rings[ring]["color_array"] is not None:
2394
+ color_data = node_rings[ring]["color_array"]
2395
+ if node_rings[ring]["vmin"] is not None:
2396
+ vmin = node_rings[ring]["vmin"]
2397
+ else:
2398
+ vmin = node_rings[ring]["color_array"].min()
2399
+ if node_rings[ring]["vmax"] is not None:
2400
+ vmax = node_rings[ring]["vmax"]
2401
+ else:
2402
+ vmax = node_rings[ring]["color_array"].max()
2403
+ if node_rings[ring]["cmap"] is not None:
2404
+ cmap = node_rings[ring]["cmap"]
2405
+ else:
2406
+ cmap = standard_cmap
2407
+ data_to_rgb = pyplot.cm.ScalarMappable(
2408
+ norm=None, cmap=pyplot.get_cmap(cmap)
2409
+ )
2410
+ data_to_rgb.set_array(color_data)
2411
+ data_to_rgb.set_clim(vmin=vmin, vmax=vmax)
2412
+ colors = [data_to_rgb.to_rgba(color_data[n]) for n in G]
2413
+
2414
+ if node_rings[ring]["colorbar"]:
2415
+ # Create colorbars for nodes
2416
+ # cax_n = pyplot.axes([.8 + ring*0.11,
2417
+ # ax.get_subplotspec().get_position(ax.figure).bounds[1]+0.05, 0.025, 0.35], frameon=False) #
2418
+ # setup colorbar axes.
2419
+ # setup colorbar axes.
2420
+ bbox_ax = ax.get_position()
2421
+ # print(bbox_ax.xmin, bbox_ax.xmax, bbox_ax.ymin, bbox_ax.ymax)
2422
+ cax_n = ax.inset_axes(
2423
+ [
2424
+ 0.05, -0.07, 0.4, 0.07
2425
+ # bbox_ax.xmin + width*0.05,
2426
+ # bbox_ax.ymin-0.075*height+network_lower_bound-0.15,
2427
+ # width*0.4,
2428
+ # 0.075*height, #0.025 + (len(all_links_edge_weights) == 0) * 0.035,
2429
+ ],
2430
+ frameon=False,
2431
+ )
2432
+ cb_n = pyplot.colorbar(data_to_rgb, cax=cax_n, orientation="horizontal")
2433
+ # try:
2434
+ ticks_here = np.arange(
2435
+ _myround(vmin, node_rings[ring]["ticks"], "down"),
2436
+ _myround(vmax, node_rings[ring]["ticks"], "up")
2437
+ + node_rings[ring]["ticks"],
2438
+ node_rings[ring]["ticks"],
2439
+ )
2440
+ cb_n.set_ticks(ticks_here[(vmin <= ticks_here) & (ticks_here <= vmax)])
2441
+ # except:
2442
+ # print ('no ticks given')
2443
+ cb_n.outline.clear()
2444
+ # cb_n.set_ticks()
2445
+ cax_n.set_xlabel(
2446
+ node_rings[ring]["label"], labelpad=1, fontsize=label_fontsize
2447
+ )
2448
+ cax_n.tick_params(axis='both', which='major', labelsize=tick_label_size)
2449
+ else:
2450
+ colors = None
2451
+ vmin = None
2452
+ vmax = None
2453
+
2454
+ for n in G:
2455
+ if type(node_alpha) == dict:
2456
+ alpha = node_alpha[n]
2457
+ else:
2458
+ alpha = 1.0
2459
+
2460
+ if special_nodes is not None:
2461
+ if n in special_nodes:
2462
+ color_here = special_nodes[n]
2463
+ else:
2464
+ color_here = 'grey'
2465
+ else:
2466
+ if colors is None:
2467
+ color_here = standard_color_nodes
2468
+ else:
2469
+ color_here = colors[n]
2470
+
2471
+ c = Ellipse(
2472
+ pos[n],
2473
+ width=node_sizes[: ring + 1].sum(axis=0)[n] * node_aspect,
2474
+ height=node_sizes[: ring + 1].sum(axis=0)[n],
2475
+ clip_on=False,
2476
+ facecolor=color_here,
2477
+ edgecolor=color_here,
2478
+ zorder=-ring - 1 + 2,
2479
+ transform=transform,
2480
+ )
2481
+
2482
+ # else:
2483
+ # if special_nodes is not None and n in special_nodes:
2484
+ # color_here = special_nodes[n]
2485
+ # else:
2486
+ # color_here = colors[n]
2487
+ # c = Ellipse(
2488
+ # pos[n],
2489
+ # width=node_sizes[: ring + 1].sum(axis=0)[n] * node_aspect,
2490
+ # height=node_sizes[: ring + 1].sum(axis=0)[n],
2491
+ # clip_on=False,
2492
+ # facecolor=colors[n],
2493
+ # edgecolor=colors[n],
2494
+ # zorder=-ring - 1,
2495
+ # )
2496
+
2497
+ ax.add_patch(c)
2498
+
2499
+ if node_classification is not None and node_classification[n] in ["space_context_last", "space_dummy_last", "time_dummy_last"]:
2500
+ node_height = node_sizes[: ring + 1].sum(axis=0)[n]
2501
+ node_width_difference_to_height = node_height * (1 - node_aspect)
2502
+
2503
+ c_wide = mpatches.FancyBboxPatch((pos[n-max_lag+1][0] + node_width_difference_to_height / 2, pos[n-max_lag+1][1]),
2504
+ (pos[n][0] - pos[n-max_lag+1][0] - node_width_difference_to_height),
2505
+ 0.,
2506
+ boxstyle=mpatches.BoxStyle.Round(pad=0.5 * node_height),
2507
+ facecolor=color_here,
2508
+ edgecolor=color_here,
2509
+ )
2510
+
2511
+ ax.add_patch(c_wide)
2512
+
2513
+
2514
+ # avoiding attribute error raised by changes in networkx
2515
+ if hasattr(G, "node"):
2516
+ # works with networkx 1.10
2517
+ G.node[n]["patch"] = c
2518
+ else:
2519
+ # works with networkx 2.4
2520
+ G.nodes[n]["patch"] = c
2521
+
2522
+ if ring == 0:
2523
+ ax.text(
2524
+ pos[n][0],
2525
+ pos[n][1],
2526
+ node_labels[n],
2527
+ fontsize=node_label_size,
2528
+ horizontalalignment="center",
2529
+ verticalalignment="center",
2530
+ alpha=1.0,
2531
+ zorder=5.,
2532
+ transform=transform,
2533
+ )
2534
+ if show_autodependency_lags:
2535
+ ax.text(
2536
+ pos[n][0],
2537
+ pos[n][1],
2538
+ autodep_sig_lags[n],
2539
+ fontsize=link_label_fontsize,
2540
+ horizontalalignment="center",
2541
+ verticalalignment="center",
2542
+ color="black",
2543
+ zorder=5.,
2544
+ transform=transform,
2545
+ )
2546
+
2547
+ # Draw edges
2548
+ seen = {}
2549
+ for (u, v, d) in G.edges(data=True):
2550
+ if d.get("no_links"):
2551
+ d["inner_edge_alpha"] = 1e-8
2552
+ d["outer_edge_alpha"] = 1e-8
2553
+ if u != v:
2554
+ if d["outer_edge"]:
2555
+ seen[(u, v)] = draw_edge(ax, u, v, d, seen, outer_edge=True)
2556
+ if d["inner_edge"]:
2557
+ seen[(u, v)] = draw_edge(ax, u, v, d, seen, outer_edge=False)
2558
+
2559
+ # if network_left_bound is not None:
2560
+ # network_right_bound = 0.98
2561
+ # else:
2562
+ # network_right_bound = None
2563
+ # fig.subplots_adjust(bottom=network_lower_bound, left=network_left_bound, right=network_right_bound) #, right=0.97)
2564
+
2565
+
2566
def plot_graph(
    graph,
    val_matrix=None,
    var_names=None,
    fig_ax=None,
    figsize=None,
    save_name=None,
    link_colorbar_label="MCI",
    node_colorbar_label="auto-MCI",
    link_width=None,
    link_attribute=None,
    node_pos=None,
    arrow_linewidth=8.0,
    vmin_edges=-1,
    vmax_edges=1.0,
    edge_ticks=0.4,
    cmap_edges="RdBu_r",
    vmin_nodes=-1,
    vmax_nodes=1.0,
    node_ticks=0.4,
    cmap_nodes="RdBu_r",
    node_size=0.3,
    node_aspect=None,
    arrowhead_size=20,
    curved_radius=0.2,
    label_fontsize=10,
    tick_label_size=6,
    alpha=1.0,
    node_label_size=10,
    link_label_fontsize=10,
    lag_array=None,
    show_colorbar=True,
    inner_edge_style="dashed",
    link_matrix=None,
    special_nodes=None,
    show_autodependency_lags=False,
    special_links=None,
):
    """Creates a network plot.

    This is still in beta. The network is defined from links in graph. Nodes
    denote variables, straight links contemporaneous dependencies and curved
    arrows lagged dependencies. The node color denotes the auto-dependency
    value at the lag with maximal absolute value, and the link color the
    cross-dependency value at the lag with maximal absolute value (only lags
    where a link exists are considered). The link label lists the lags with
    significant dependency in order of absolute magnitude. The network can
    also be plotted over a map drawn before on the same axis. Then the node
    positions can be supplied in appropriate axis coordinates via node_pos.

    Parameters
    ----------
    graph : string or bool array-like
        Either string matrix providing graph or bool array providing only
        adjacencies, of shape (N, N) or (N, N, tau_max+1). Must be of same
        shape as val_matrix. Time series graphs of shape
        (N, N, tau_max+1, tau_max+1) are not supported; use
        plot_time_series_graph instead.
    val_matrix : array_like, optional (default: None)
        Matrix of shape (N, N, tau_max+1) containing test statistic values.
        If None, links and nodes are not colored by value and no colorbars
        are drawn.
    var_names : list, optional (default: None)
        List of variable names. If None, range(N) is used.
    fig_ax : tuple of figure and axis object, optional (default: None)
        Figure and axes instance. If None they are created.
    figsize : tuple
        Size of figure (only used when fig_ax is None).
    save_name : str, optional (default: None)
        Name of figure file to save the figure to (at 300 dpi). If None, the
        figure is not saved and (fig, ax) is returned instead.
    link_colorbar_label : str, optional (default: 'MCI')
        Test statistic label.
    node_colorbar_label : str, optional (default: 'auto-MCI')
        Test statistic label for auto-dependencies.
    link_width : array-like, optional (default: None)
        Array of val_matrix.shape specifying relative link width with maximum
        given by arrow_linewidth. If None, all links have same width.
    link_attribute : array-like, optional (default: None)
        String array of val_matrix.shape specifying link attributes.
    node_pos : dictionary, optional (default: None)
        Dictionary of node positions in axis coordinates of form
        node_pos = {'x':array of shape (N,), 'y':array of shape(N)}. These
        coordinates could have been transformed before for basemap plots. You can
        also add a key 'transform':ccrs.PlateCarree() in order to plot graphs on
        a map using cartopy. If None, a circular layout is used.
    arrow_linewidth : float, optional (default: 8.0)
        Linewidth.
    vmin_edges : float, optional (default: -1)
        Link colorbar scale lower bound.
    vmax_edges : float, optional (default: 1)
        Link colorbar scale upper bound.
    edge_ticks : float, optional (default: 0.4)
        Link tick mark interval.
    cmap_edges : str, optional (default: 'RdBu_r')
        Colormap for links.
    vmin_nodes : float, optional (default: -1)
        Node colorbar scale lower bound.
    vmax_nodes : float, optional (default: 1)
        Node colorbar scale upper bound.
    node_ticks : float, optional (default: 0.4)
        Node tick mark interval.
    cmap_nodes : str, optional (default: 'RdBu_r')
        Colormap for nodes.
    node_size : float, optional (default: 0.3)
        Node size.
    node_aspect : float, optional (default: None)
        Ratio between the height and width of the variable nodes. If None,
        it is derived from the aspect ratio of the axis.
    arrowhead_size : int, optional (default: 20)
        Size of link arrow head. Passed on to FancyArrowPatch object.
    curved_radius : float, optional (default: 0.2)
        Curvature of links. Passed on to FancyArrowPatch object.
    label_fontsize : int, optional (default: 10)
        Fontsize of colorbar labels.
    tick_label_size : int, optional (default: 6)
        Fontsize of tick labels.
    alpha : float, optional (default: 1.)
        Opacity assigned to all links.
    node_label_size : int, optional (default: 10)
        Fontsize of node labels.
    link_label_fontsize : int, optional (default: 10)
        Fontsize of link labels.
    lag_array : array, optional (default: None)
        Optional specification of lags overwriting np.arange(0, tau_max+1)
    show_colorbar : bool, optional (default: True)
        Whether to show colorbars for links and nodes.
    inner_edge_style : str, optional (default: 'dashed')
        Currently not used by this function; contemporaneous links are
        always drawn solid. Kept for backward compatibility.
    link_matrix : deprecated
        Must be None; passing anything else raises ValueError.
    special_nodes : dict, optional (default: None)
        Mapping from (variable index, lag) tuples to colors. Entries with
        lag >= -tau_max are collapsed onto their variable index and those
        nodes are drawn in the given color; all other nodes are then drawn
        in grey.
    show_autodependency_lags : bool, optional (default: False)
        Shows significant autodependencies for a node.
    special_links : optional (default: None)
        Forwarded unchanged to the internal network-drawing routine.

    Returns
    -------
    fig, ax : matplotlib figure and axes
        Only returned when save_name is None; otherwise None is returned.

    Raises
    ------
    ValueError
        If link_matrix is given or graph has four dimensions.
    """
    if link_matrix is not None:
        raise ValueError("link_matrix is deprecated and replaced by graph array"
                         " which is now returned by all methods.")

    if fig_ax is None:
        fig = pyplot.figure(figsize=figsize)
        ax = fig.add_subplot(111, frame_on=False)
    else:
        fig, ax = fig_ax

    # np.array copies, so the caller's graph is never modified below (a dummy
    # link may be written into it); plain nested lists are accepted as well.
    graph = np.array(graph).squeeze()

    if graph.ndim == 4:
        raise ValueError("Time series graph of shape (N,N,tau_max+1,tau_max+1) cannot be represented by plot_graph,"
                         " use plot_time_series_graph instead.")

    if graph.ndim == 2:
        # If a non-time series (N,N)-graph is given, insert a dummy dimension
        graph = np.expand_dims(graph, axis=2)

    no_coloring = val_matrix is None
    if no_coloring:
        cmap_edges = None
        cmap_nodes = None

    (graph, val_matrix, link_width, link_attribute) = _check_matrices(
        graph, val_matrix, link_width, link_attribute)

    N, _, n_lags = graph.shape
    tau_max = n_lags - 1

    # True if every existing link is an auto-dependency (i == j); this also
    # covers a completely empty graph.
    no_links = (np.count_nonzero(graph != "")
                == np.count_nonzero(np.diagonal(graph) != ""))
    if no_links:
        # Workaround: insert a dummy link so the graph has at least one edge.
        # It stays invisible because edges flagged with no_links are drawn
        # with alpha 1e-8.
        graph[0, 1, 0] = "xxx"

    if var_names is None:
        var_names = range(N)

    # Only draw contemporaneous links in one direction: keep the upper
    # triangle at lag zero.
    link_matrix_upper = np.copy(graph)
    link_matrix_upper[:, :, 0] = np.triu(link_matrix_upper[:, :, 0])

    net = np.any(link_matrix_upper != "", axis=2)
    G = nx.DiGraph(net)

    node_color = list(np.zeros(N))

    if show_autodependency_lags:
        autodep_sig_lags = np.full(N, None, dtype="object")
    else:
        autodep_sig_lags = None

    # Add edge attributes; cross links (u != v) describe contemporaneous
    # (inner) and lagged (outer) edges, self-loops (u == v) the node color.
    for (u, v, dic) in G.edges(data=True):
        dic["no_links"] = no_links

        # Lags (>= 1) at which a lagged link u -> v exists.
        if tau_max > 0:
            sig_lags = np.where(link_matrix_upper[u, v, 1:] != "")[0] + 1
        else:
            sig_lags = np.array([], dtype=int)

        # Lag of the strongest (absolute value) existing lagged link; 0 if
        # there is none.
        if len(sig_lags) > 0:
            argmax = sig_lags[np.abs(val_matrix[u, v][sig_lags]).argmax()]
        else:
            argmax = 0

        if u != v:
            # Contemporaneous link (lag zero), drawn as a straight edge.
            dic["inner_edge"] = link_matrix_upper[u, v, 0]
            dic["inner_edge_type"] = link_matrix_upper[u, v, 0]
            dic["inner_edge_alpha"] = alpha
            dic["inner_edge_color"] = None if no_coloring else val_matrix[u, v, 0]
            if link_width is None:
                dic["inner_edge_width"] = arrow_linewidth
            else:
                dic["inner_edge_width"] = (
                    link_width[u, v, 0] / link_width.max() * arrow_linewidth
                )
            if link_attribute is None:
                dic["inner_edge_attribute"] = None
            else:
                dic["inner_edge_attribute"] = link_attribute[u, v, 0]
            dic["inner_edge_style"] = "solid"

            # Lagged links, drawn as one curved edge that represents the
            # strongest lag.
            dic["outer_edge"] = len(sig_lags) > 0
            dic["outer_edge_type"] = link_matrix_upper[u, v, argmax]
            dic["outer_edge_alpha"] = alpha
            if link_width is None:
                dic["outer_edge_width"] = arrow_linewidth
            else:
                dic["outer_edge_width"] = (
                    link_width[u, v, argmax] / link_width.max() * arrow_linewidth
                )
            if link_attribute is None:
                dic["outer_edge_attribute"] = None
            else:
                dic["outer_edge_attribute"] = link_attribute[u, v, argmax]
            dic["outer_edge_color"] = None if no_coloring else val_matrix[u, v][argmax]

            # Edge label: significant lags ordered by decreasing absolute value.
            if tau_max > 0:
                lags = np.abs(val_matrix[u, v][1:]).argsort()[::-1] + 1
            else:
                lags = []
            sig_lag_set = set(sig_lags.tolist())
            if lag_array is not None:
                lag_names = [str(lag_array[lag]) for lag in lags if lag in sig_lag_set]
            else:
                lag_names = [str(lag) for lag in lags if lag in sig_lag_set]
            dic["label"] = ",".join(lag_names)
        else:
            # Auto-dependency: node color is the value at the strongest lag.
            node_color[u] = None if no_coloring else val_matrix[u, v][argmax]

            if show_autodependency_lags:
                # Leading newlines place the text below the node label.
                autodep_sig_lags[u] = "\n\n\n" + ",".join(
                    str(lag) for lag in sig_lags.tolist()
                )

            dic["inner_edge_attribute"] = None
            dic["outer_edge_attribute"] = None

    if special_nodes is not None:
        # Collapse (variable, lag) keys onto the variable index, ignoring
        # nodes outside the plotted lag range.
        special_nodes = {
            i: color
            for (i, tau), color in special_nodes.items()
            if tau >= -tau_max
        }

    if node_pos is None:
        pos = nx.circular_layout(deepcopy(G))
    else:
        pos = {i: (node_pos["x"][i], node_pos["y"][i]) for i in range(N)}

    if node_pos is not None and "transform" in node_pos:
        transform = node_pos["transform"]
    else:
        transform = ax.transData

    if cmap_nodes is None:
        node_color = None

    # Dictionary of rings: {0: {'sizes': (N,)-array or None, 'color_array':
    # (N,)-array or None, 'cmap': string, 'vmin': float or None, 'vmax':
    # float or None, 'ticks': float, 'label': string or None, 'colorbar': bool}}
    node_rings = {
        0: {
            "sizes": None,
            "color_array": node_color,
            "cmap": cmap_nodes,
            "vmin": vmin_nodes,
            "vmax": vmax_nodes,
            "ticks": node_ticks,
            "label": node_colorbar_label,
            "colorbar": show_colorbar,
        }
    }

    _draw_network_with_curved_edges(
        fig=fig,
        ax=ax,
        G=deepcopy(G),
        pos=pos,
        node_rings=node_rings,
        node_labels=var_names,
        node_label_size=node_label_size,
        node_alpha=alpha,
        standard_size=node_size,
        node_aspect=node_aspect,
        standard_cmap="OrRd",
        standard_color_nodes="lightgrey",
        standard_color_links="black",
        log_sizes=False,
        cmap_links=cmap_edges,
        links_vmin=vmin_edges,
        links_vmax=vmax_edges,
        links_ticks=edge_ticks,
        tick_label_size=tick_label_size,
        arrowstyle="simple",
        arrowhead_size=arrowhead_size,
        curved_radius=curved_radius,
        label_fontsize=label_fontsize,
        link_label_fontsize=link_label_fontsize,
        link_colorbar_label=link_colorbar_label,
        show_colorbar=show_colorbar,
        special_nodes=special_nodes,
        autodep_sig_lags=autodep_sig_lags,
        show_autodependency_lags=show_autodependency_lags,
        transform=transform,
        special_links=special_links,
    )

    if save_name is not None:
        # Save the figure that was drawn into; pyplot.savefig would save the
        # *current* figure, which need not be `fig` when fig_ax is supplied.
        fig.savefig(save_name, dpi=300)
    else:
        return fig, ax
2980
+
2981
+
2982
+ def _reverse_patt(patt):
2983
+ """Inverts a link pattern"""
2984
+
2985
+ if patt == "":
2986
+ return ""
2987
+
2988
+ left_mark, middle_mark, right_mark = patt[0], patt[1], patt[2]
2989
+ if left_mark == "<":
2990
+ new_right_mark = ">"
2991
+ else:
2992
+ new_right_mark = left_mark
2993
+ if right_mark == ">":
2994
+ new_left_mark = "<"
2995
+ else:
2996
+ new_left_mark = right_mark
2997
+
2998
+ return new_left_mark + middle_mark + new_right_mark
2999
+
3000
+ # if patt in ['---', 'o--', '--o', 'o-o', '']:
3001
+ # return patt[::-1]
3002
+ # elif patt == '<->':
3003
+ # return '<->'
3004
+ # elif patt == 'o->':
3005
+ # return '<-o'
3006
+ # elif patt == '<-o':
3007
+ # return 'o->'
3008
+ # elif patt == '-->':
3009
+ # return '<--'
3010
+ # elif patt == '<--':
3011
+ # return '-->'
3012
+
3013
+
3014
def _check_matrices(graph, val_matrix, link_width, link_attribute):
    """Validate a graph together with its companion matrices.

    Parameters
    ----------
    graph : np.ndarray
        Graph of shape (N, N, tau_max + 1), or auxiliary graph of shape
        (N, N, tau_max + 1, tau_max + 1).
        String arrays (str or bytes) hold link patterns such as ``"-->"`` or
        ``"o-o"``, with ``""`` meaning no link. Their patterns are kept and
        only the dtype is normalized.
        Arrays of any other dtype are read as (N, N, tau_max + 1) adjacency
        arrays and converted as follows:
        - a lag-zero entry set in one direction only becomes ``"-->"`` /
          ``"<--"``;
        - a lag-zero entry set in both directions becomes ``"o-o"``;
        - a lagged entry becomes ``"-->"``.
    val_matrix : np.ndarray or None
        Values with the same shape as ``graph``. If None, an int indicator
        of ``graph != ""`` is returned in its place.
    link_width : np.ndarray or None
        Link widths with the same shape as ``graph``; must be non-negative.
    link_attribute : np.ndarray or None
        Link attributes with the same shape as ``graph``.

    Returns
    -------
    tuple
        ``(graph, val_matrix, link_width, link_attribute)``, where ``graph``
        has dtype ``"<U3"``.

    Raises
    ------
    ValueError
        Raised in any of these cases:
        - ``graph`` contains an unknown link pattern;
        - a link and its reverse disagree in ``graph``, ``val_matrix``,
          ``link_width`` or ``link_attribute``;
        - ``link_width`` has negative entries.
    """
    valid_patterns = (
        "---", "o--", "--o", "o-o", "o->", "<-o", "-->", "<--", "<->",
        "x-o", "o-x", "x--", "--x", "x->", "<-x", "x-x", "<-+", "+->",
        "<O>",  # contemporaneous cycle in union-graph
    )

    if graph.dtype.kind in ("U", "S"):
        if graph.dtype != "<U3":
            # String graphs (byte strings, '<U4', an all-empty '<U1', ...)
            # already carry link patterns. Only normalize the dtype; treating
            # them as adjacencies would overwrite the patterns.
            converted = graph.astype("<U3")
            if not np.array_equal(converted, graph.astype(str)):
                # Some entry is longer than three characters.
                raise ValueError("Invalid graph entry.")
            graph = converted
    else:
        # Transform adjacency array to new graph data type U3
        adjacency = graph
        graph = np.full(adjacency.shape, "", dtype="<U3")
        for i, j, tau in zip(*np.where(adjacency)):
            if tau == 0:
                if adjacency[j, i, 0] == 0:
                    graph[i, j, 0] = "-->"
                    graph[j, i, 0] = "<--"
                else:
                    graph[i, j, 0] = "o-o"
                    graph[j, i, 0] = "o-o"
            else:
                graph[i, j, tau] = "-->"

    if graph.ndim == 4:
        # Every link appears twice, as (i, j, taui, tauj) and
        # (j, i, tauj, taui); both entries must agree.
        for i, j, taui, tauj in zip(*np.where(graph)):
            if graph[i, j, taui, tauj] not in valid_patterns:
                raise ValueError("Invalid graph entry.")
            if graph[i, j, taui, tauj] != _reverse_patt(graph[j, i, tauj, taui]):
                raise ValueError(
                    "graph needs to have consistent entries: "
                    "graph[i, j, taui, tauj] == _reverse_patt(graph[j, i, tauj, taui])")
            if (
                val_matrix is not None
                and val_matrix[i, j, taui, tauj] != val_matrix[j, i, tauj, taui]
            ):
                raise ValueError(
                    "val_matrix needs to have consistent entries: "
                    "val_matrix[i, j, taui, tauj] == val_matrix[j, i, tauj, taui]")
            if (
                link_width is not None
                and link_width[i, j, taui, tauj] != link_width[j, i, tauj, taui]
            ):
                raise ValueError(
                    "link_width needs to have consistent entries: "
                    "link_width[i, j, taui, tauj] == link_width[j, i, tauj, taui]")
            if (
                link_attribute is not None
                and link_attribute[i, j, taui, tauj] != link_attribute[j, i, tauj, taui]
            ):
                raise ValueError(
                    "link_attribute needs to have consistent entries: "
                    "link_attribute[i, j, taui, tauj] == link_attribute[j, i, tauj, taui]")
    else:
        for i, j, tau in zip(*np.where(graph)):
            # Check the pattern first so that malformed entries are reported
            # as such, as in the 4-d branch above.
            if graph[i, j, tau] not in valid_patterns:
                raise ValueError("Invalid graph entry.")
            if tau != 0:
                continue
            # Lag-zero links are stored in both directions and must agree.
            if graph[i, j, 0] != _reverse_patt(graph[j, i, 0]):
                raise ValueError(
                    "graph needs to have consistent lag-zero links, but "
                    " graph[%d,%d,0]=%s and graph[%d,%d,0]=%s)" %(i, j, graph[i, j, 0], j, i, graph[j, i, 0])
                )
            if (
                val_matrix is not None
                and val_matrix[i, j, 0] != val_matrix[j, i, 0]
            ):
                raise ValueError("val_matrix needs to be symmetric for lag-zero")
            if (
                link_width is not None
                and link_width[i, j, 0] != link_width[j, i, 0]
            ):
                raise ValueError("link_width needs to be symmetric for lag-zero")
            if (
                link_attribute is not None
                and link_attribute[i, j, 0] != link_attribute[j, i, 0]
            ):
                raise ValueError(
                    "link_attribute needs to be symmetric for lag-zero"
                )

    if val_matrix is None:
        val_matrix = (graph != "").astype("int")

    if link_width is not None and not np.all(link_width >= 0.0):
        raise ValueError("link_width must be non-negative")

    return graph, val_matrix, link_width, link_attribute
3141
+
3142
+
3143
def plot_time_series_graph(
    graph,
    val_matrix=None,
    var_names=None,
    fig_ax=None,
    figsize=None,
    link_colorbar_label="MCI",
    save_name=None,
    link_width=None,
    link_attribute=None,
    arrow_linewidth=4,
    vmin_edges=-1,
    vmax_edges=1.0,
    edge_ticks=0.4,
    cmap_edges="RdBu_r",
    order=None,
    node_size=0.1,
    node_aspect=None,
    arrowhead_size=20,
    curved_radius=0.2,
    label_fontsize=10,
    tick_label_size=6,
    alpha=1.0,
    inner_edge_style="dashed",
    link_matrix=None,
    special_nodes=None,
    node_classification=None,
    standard_color_links='black',
    standard_color_nodes='lightgrey',
):
    """Creates a time series graph.

    This is still in beta. Every variable is drawn as a row of nodes, one per
    time column labelled t - tau_max, ..., t. The time series graph's links
    are colored by val_matrix.

    Parameters
    ----------
    graph : np.ndarray of strings or bools
        Either string matrix providing graph or bool array providing only
        adjacencies. Either of shape (N, N, tau_max + 1) or as auxiliary
        graph of dims (N, N, tau_max+1, tau_max+1) describing auxADMG. The
        array is not modified.
    val_matrix : array_like, optional (default: None)
        Matrix of same shape as graph containing test statistic values. If
        None, links are not colored.
    var_names : list, optional (default: None)
        List of variable names. If None, range(N) is used.
    fig_ax : tuple of figure and axis object, optional (default: None)
        Figure and axes instance. If None they are created.
    figsize : tuple
        Size of figure.
    link_colorbar_label : str, optional (default: 'MCI')
        Test statistic label.
    save_name : str, optional (default: None)
        Name of figure file to save figure (at dpi=300). If None, the figure
        and axes are returned instead.
    link_width : array-like, optional (default: None)
        Array of val_matrix.shape specifying relative link width with maximum
        given by arrow_linewidth. If None, all links have same width.
    link_attribute : array-like, optional (default: None)
        Array of graph.shape specifying specific in drawing the graph (for
        internal use).
    arrow_linewidth : float, optional (default: 4)
        Linewidth.
    vmin_edges : float, optional (default: -1)
        Link colorbar scale lower bound.
    vmax_edges : float, optional (default: 1)
        Link colorbar scale upper bound.
    edge_ticks : float, optional (default: 0.4)
        Link tick mark interval.
    cmap_edges : str, optional (default: 'RdBu_r')
        Colormap for links.
    order : list, optional (default: None)
        Permutation of range(N). Variable n is drawn in row order[n],
        counted from the top.
    node_size : float, optional (default: 0.1)
        Node size.
    node_aspect : float, optional (default: None)
        Ratio between the height and width of the variable nodes.
    arrowhead_size : int, optional (default: 20)
        Size of link arrow head. Passed on to FancyArrowPatch object.
    curved_radius : float, optional (default: 0.2)
        Curvature of links. Passed on to FancyArrowPatch object.
    label_fontsize : int, optional (default: 10)
        Fontsize of colorbar and variable labels. Time labels use 80% of it.
    tick_label_size : int, optional (default: 6)
        Fontsize of tick labels.
    alpha : float, optional (default: 1.)
        Opacity.
    inner_edge_style : string, optional (default: 'dashed')
        Style of inner_edge contemporaneous links.
    link_matrix : None
        Deprecated; passing anything but None raises ValueError.
    special_nodes : dict
        Dictionary of format {(i, -tau): 'blue', ...} to color special nodes.
        Nodes outside the plotted window (-tau_max <= -tau <= 0) are ignored.
    node_classification : dict or None (default: None)
        Dictionary of format {i: 'space_context', ...} to classify nodes into
        system, context, or dummy nodes. Keys of the dictionary are from
        {0, ..., N-1} where N is the number of nodes. Options for the values
        are "system", "time_context", "space_context", "time_dummy", or
        "space_dummy". Space_contexts and dummy nodes need to be represented
        as a single node in the time series graph. In case no value is
        supplied all nodes are treated as system nodes, i.e. are plotted in
        a time-resolved manner.
    standard_color_links : str, optional (default: 'black')
        Passed on to the network drawing routine.
    standard_color_nodes : str, optional (default: 'lightgrey')
        Passed on to the network drawing routine.

    Returns
    -------
    (fig, ax) if save_name is None, otherwise None.

    Raises
    ------
    ValueError
        If link_matrix is given, if order is not a permutation of range(N),
        or if the matrices are invalid or inconsistent.
    """
    if link_matrix is not None:
        raise ValueError("link_matrix is deprecated and replaced by graph array"
                         " which is now returned by all methods.")

    if fig_ax is None:
        fig = pyplot.figure(figsize=figsize)
        ax = fig.add_subplot(111, frame_on=False)
    else:
        fig, ax = fig_ax

    # Record this before _check_matrices replaces a missing val_matrix by an
    # indicator matrix.
    no_coloring = val_matrix is None
    if no_coloring:
        cmap_edges = None

    graph, val_matrix, link_width, link_attribute = _check_matrices(
        graph, val_matrix, link_width, link_attribute
    )

    if graph.ndim == 4:
        N, N, max_lag, _ = graph.shape
    else:
        N, N, max_lag = graph.shape
    tau_max = max_lag - 1

    if np.count_nonzero(graph == "") == graph.size:
        # Empty graph: add a placeholder link flagged via the "no_links" edge
        # attribute. Copy first, because _check_matrices hands back the
        # caller's own array when it already has dtype '<U3'.
        graph = graph.copy()
        if graph.ndim == 4:
            graph[0, 1, 0, 0] = "---"
        else:
            graph[0, 1, 0] = "---"
        no_links = True
    else:
        no_links = False

    if var_names is None:
        var_names = range(N)

    if order is None:
        order = range(N)

    if set(order) != set(range(N)):
        raise ValueError("order must be a permutation of range(N)")

    def translate(row, lag):
        # Flat node index of variable `row` in time column `lag`; column
        # max_lag - 1 corresponds to time t.
        return row * max_lag + lag

    n_nodes = N * max_lag
    tsg = np.zeros((n_nodes, n_nodes))
    tsg_val = np.zeros((n_nodes, n_nodes))
    tsg_width = np.zeros((n_nodes, n_nodes))
    tsg_style = np.zeros((n_nodes, n_nodes), dtype=graph.dtype)
    if link_attribute is not None:
        tsg_attr = np.zeros((n_nodes, n_nodes), dtype=link_attribute.dtype)

    # The widest link is drawn with arrow_linewidth; compute the maximum once.
    max_width = None if link_width is None else link_width.max()

    def scaled_width(width):
        # All-zero widths would otherwise give 0 / 0 = nan.
        if max_width == 0:
            return 0.0
        return width / max_width * arrow_linewidth

    if graph.ndim == 4:
        # 4-dimensional graphs represent the finite-time window projection of
        # stationary 3-d graphs and list every node pair explicitly; draw each
        # link only once, from the lower to the higher flat node index.
        for i, j, taui, tauj in np.column_stack(np.where(graph)):
            src = translate(i, max_lag - 1 - taui)
            dst = translate(j, max_lag - 1 - tauj)
            if src >= dst:
                continue
            tsg[src, dst] = 1.0
            tsg_val[src, dst] = val_matrix[i, j, taui, tauj]
            tsg_style[src, dst] = graph[i, j, taui, tauj]
            if link_width is not None:
                tsg_width[src, dst] = scaled_width(link_width[i, j, taui, tauj])
            if link_attribute is not None:
                tsg_attr[src, dst] = link_attribute[i, j, taui, tauj]
    else:
        # Lag-zero links are stored in both directions; keep only the upper
        # triangle so that each one is drawn once.
        link_matrix_tsg = np.copy(graph)
        link_matrix_tsg[:, :, 0] = np.triu(graph[:, :, 0])

        for i, j, tau in np.column_stack(np.where(link_matrix_tsg)):
            # Repeat the link for every time column t of the window in which
            # its source column t - tau is still visible.
            for t in range(tau, max_lag):
                src = translate(i, t - tau)
                dst = translate(j, t)
                tsg[src, dst] = 1.0
                tsg_val[src, dst] = val_matrix[i, j, tau]
                tsg_style[src, dst] = graph[i, j, tau]
                if link_width is not None:
                    tsg_width[src, dst] = scaled_width(link_width[i, j, tau])
                if link_attribute is not None:
                    tsg_attr[src, dst] = link_attribute[i, j, tau]

    G = nx.DiGraph(tsg)

    if special_nodes is not None:
        special_nodes_tsg = {}
        for (i, tau), color in special_nodes.items():
            # Only nodes inside the plotted window; a positive lag would map
            # onto the row of another variable.
            if -tau_max <= tau <= 0:
                special_nodes_tsg[translate(i, max_lag - 1 + tau)] = color
        special_nodes = special_nodes_tsg

    if node_classification is None:
        node_classification = {i: "system" for i in range(N)}
    node_classification_tsg = {}
    for node, node_class in node_classification.items():
        for lag in range(max_lag):
            if lag == 0:
                suffix = "_first"
            elif lag == max_lag - 1:
                suffix = "_last"
            else:
                suffix = "_middle"
            node_classification_tsg[translate(node, lag)] = node_class + suffix

    # Add attributes, contemporaneous and lagged links are handled separately
    for u, v, dic in G.edges(data=True):
        dic["no_links"] = no_links
        if u == v:
            continue
        # Links between adjacent variable indices in adjacent (or equal) time
        # columns are "inner" edges, all others are "outer" edges.
        i, col_i = divmod(u, max_lag)
        j, col_j = divmod(v, max_lag)
        if abs(i - j) <= 1 and abs(col_j - col_i) <= 1:
            inout = "inner"
        else:
            inout = "outer"
        dic["inner_edge"] = inout == "inner"
        dic["outer_edge"] = inout == "outer"

        dic[f"{inout}_edge_type"] = tsg_style[u, v]
        dic[f"{inout}_edge_alpha"] = alpha
        dic[f"{inout}_edge_width"] = (
            arrow_linewidth if link_width is None else tsg_width[u, v]
        )
        dic[f"{inout}_edge_attribute"] = (
            None if link_attribute is None else tsg_attr[u, v]
        )
        dic[f"{inout}_edge_color"] = None if no_coloring else tsg_val[u, v]
        dic["label"] = None

    # Raw node coordinates: x = time column, y = 1 - variable index. Both are
    # rescaled to [0, 1], with 0 where a coordinate takes a single value.
    node_index = np.arange(n_nodes)
    posarray = np.column_stack(
        [node_index % max_lag, 1.0 - node_index // max_lag]
    ).astype(float)
    pos_min = posarray.min(axis=0)
    pos_span = posarray.max(axis=0) - pos_min
    with np.errstate(divide="ignore", invalid="ignore"):
        pos_tmp = (posarray - pos_min) / pos_span
    pos_tmp[np.isnan(pos_tmp)] = 0.0

    pos = {}
    for n in range(N):
        for lag in range(max_lag):
            pos[translate(n, lag)] = pos_tmp[translate(order[n], lag)]

    node_rings = {
        0: {"sizes": None, "color_array": None, "label": "", "colorbar": False,}
    }

    node_labels = ["" for _ in range(n_nodes)]

    # Auxiliary (4-d) graphs drawn without values get no link colorbar.
    show_colorbar = not (graph.ndim == 4 and no_coloring)

    _draw_network_with_curved_edges(
        fig=fig,
        ax=ax,
        G=deepcopy(G),
        pos=pos,
        node_rings=node_rings,
        node_labels=node_labels,
        node_alpha=alpha,
        standard_size=node_size,
        node_aspect=node_aspect,
        standard_cmap="OrRd",
        standard_color_nodes=standard_color_nodes,
        standard_color_links=standard_color_links,
        log_sizes=False,
        cmap_links=cmap_edges,
        links_vmin=vmin_edges,
        links_vmax=vmax_edges,
        links_ticks=edge_ticks,
        arrowstyle="simple",
        arrowhead_size=arrowhead_size,
        curved_radius=curved_radius,
        label_fontsize=label_fontsize,
        tick_label_size=tick_label_size,
        label_fraction=0.5,
        link_colorbar_label=link_colorbar_label,
        inner_edge_curved=False,
        inner_edge_style=inner_edge_style,
        special_nodes=special_nodes,
        show_colorbar=show_colorbar,
        node_classification=node_classification_tsg,
        max_lag=max_lag,
    )

    # Variable names to the left of each row (x in axes, y in data coords).
    trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)
    for i in range(N):
        ax.text(
            0.,
            pos[order[i] * max_lag][1],
            f"{var_names[order[i]]}",
            fontsize=label_fontsize,
            horizontalalignment="right",
            verticalalignment="center",
            transform=trans,
        )

    # Time labels above each column (x in data, y in axes coords).
    trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    for lag in range(max_lag - 1, -1, -1):
        if lag == max_lag - 1:
            time_label = r"$t$"
        else:
            time_label = r"$t-%s$" % str(max_lag - lag - 1)
        ax.text(
            pos[lag][0],
            1.0,
            time_label,
            fontsize=int(label_fontsize * 0.8),
            horizontalalignment="center",
            verticalalignment="bottom",
            transform=trans,
        )

    if save_name is not None:
        pyplot.savefig(save_name, dpi=300)
    else:
        return fig, ax
3545
+
3546
+
3547
def plot_mediation_time_series_graph(
    path_node_array,
    tsg_path_val_matrix,
    var_names=None,
    fig_ax=None,
    figsize=None,
    link_colorbar_label="link coeff. (edge color)",
    node_colorbar_label="MCE (node color)",
    save_name=None,
    link_width=None,
    arrow_linewidth=8,
    vmin_edges=-1,
    vmax_edges=1.0,
    edge_ticks=0.4,
    cmap_edges="RdBu_r",
    order=None,
    vmin_nodes=-1.0,
    vmax_nodes=1.0,
    node_ticks=0.4,
    cmap_nodes="RdBu_r",
    node_size=0.1,
    node_aspect=None,
    arrowhead_size=20,
    curved_radius=0.2,
    label_fontsize=12,
    alpha=1.0,
    node_label_size=12,
    tick_label_size=6,
    standard_color_links='black',
    standard_color_nodes='lightgrey',
):
    """Creates a mediation time series graph plot.

    This is still in beta. Every variable is drawn as a row of nodes, one per
    time column. Nodes are colored by path_node_array and links by the
    entries of tsg_path_val_matrix.

    Parameters
    ----------
    path_node_array : array_like
        Array of shape (N,) containing node values. All time copies of
        variable n get the value path_node_array[n].
    tsg_path_val_matrix : np.ndarray
        Matrix of shape (N * max_lag, N * max_lag) containing link weight
        values. Node n * max_lag + k is variable n in time column k, and
        column max_lag - 1 is labelled t. Nonzero entries define the links.
        The matrix is not modified.
    var_names : list, optional (default: None)
        List of variable names. If None, range(N) is used.
    fig_ax : tuple of figure and axis object, optional (default: None)
        Figure and axes instance. If None they are created.
    figsize : tuple
        Size of figure.
    save_name : str, optional (default: None)
        Name of figure file to save figure. If None, figure is shown in
        window via pyplot.show().
    link_colorbar_label : str, optional (default: 'link coeff. (edge color)')
        Link colorbar label.
    node_colorbar_label : str, optional (default: 'MCE (node color)')
        Node colorbar label.
    link_width : array-like, optional (default: None)
        Only checked to be non-negative; it is currently not used for
        drawing, so all links have width arrow_linewidth.
    order : list, optional (default: None)
        Permutation of range(N). Variable n is drawn in row order[n],
        counted from the top.
    arrow_linewidth : float, optional (default: 8)
        Linewidth.
    vmin_edges : float, optional (default: -1)
        Link colorbar scale lower bound.
    vmax_edges : float, optional (default: 1)
        Link colorbar scale upper bound.
    edge_ticks : float, optional (default: 0.4)
        Link tick mark interval.
    cmap_edges : str, optional (default: 'RdBu_r')
        Colormap for links.
    vmin_nodes : float, optional (default: -1)
        Node colorbar scale lower bound.
    vmax_nodes : float, optional (default: 1)
        Node colorbar scale upper bound.
    node_ticks : float, optional (default: 0.4)
        Node tick mark interval.
    cmap_nodes : str, optional (default: 'RdBu_r')
        Colormap for nodes.
    node_size : float, optional (default: 0.1)
        Node size.
    node_aspect : float, optional (default: None)
        Ratio between the height and width of the variable nodes.
    arrowhead_size : int, optional (default: 20)
        Size of link arrow head. Passed on to FancyArrowPatch object.
    curved_radius : float, optional (default: 0.2)
        Curvature of links. Passed on to FancyArrowPatch object.
    label_fontsize : int, optional (default: 12)
        Fontsize of colorbar, variable and time labels.
    alpha : float, optional (default: 1.)
        Opacity.
    node_label_size : int, optional (default: 12)
        Fontsize of node labels.
    tick_label_size : int, optional (default: 6)
        Fontsize of tick labels.
    standard_color_links : str, optional (default: 'black')
        Passed on to the network drawing routine.
    standard_color_nodes : str, optional (default: 'lightgrey')
        Passed on to the network drawing routine.

    Raises
    ------
    ValueError
        If link_width has negative entries or order is not a permutation of
        range(N).
    """
    N = len(path_node_array)
    Nmaxlag = tsg_path_val_matrix.shape[0]
    max_lag = Nmaxlag // N

    if var_names is None:
        var_names = range(N)

    if fig_ax is None:
        fig = pyplot.figure(figsize=figsize)
        ax = fig.add_subplot(111, frame_on=False)
    else:
        fig, ax = fig_ax

    if link_width is not None and not np.all(link_width >= 0.0):
        raise ValueError("link_width must be non-negative")

    if order is None:
        order = range(N)

    if set(order) != set(range(N)):
        raise ValueError("order must be a permutation of range(N)")

    # Work on a copy so that the placeholder link added below never leaks
    # into the caller's matrix.
    tsg = np.copy(tsg_path_val_matrix)

    # There are no links if all nonzero entries (if any) lie on the diagonal.
    no_links = np.count_nonzero(tsg) == np.count_nonzero(np.diagonal(tsg))
    if no_links:
        # Placeholder link, flagged via the "no_links" edge attribute.
        tsg[0, 1] = 1

    G = nx.DiGraph(tsg)

    # Add attributes, contemporaneous and lagged links are handled separately
    for u, v, dic in G.edges(data=True):
        dic["no_links"] = no_links
        dic["outer_edge_attribute"] = None
        if u == v:
            continue

        # Links within one time column are inner (contemporaneous) edges.
        is_inner = u % max_lag == v % max_lag
        dic["inner_edge"] = is_inner
        dic["outer_edge"] = not is_inner

        dic["inner_edge_alpha"] = alpha
        # Inner edges combine both directions u->v and v->u via _get_absmax
        # (presumably picking the larger-magnitude value; see its definition).
        dic["inner_edge_color"] = _get_absmax(
            np.array([[[tsg[u, v], tsg[v, u]]]])
        ).squeeze()
        dic["inner_edge_width"] = arrow_linewidth

        dic["outer_edge_alpha"] = alpha
        dic["outer_edge_width"] = arrow_linewidth
        dic["outer_edge_color"] = tsg[u, v]
        dic["label"] = None

    # Raw node coordinates: x = time column, y = 1 - variable index. Both are
    # rescaled to [0, 1], with 0 where a coordinate takes a single value.
    n_nodes = N * max_lag
    node_index = np.arange(n_nodes)
    posarray = np.column_stack(
        [node_index % max_lag, 1.0 - node_index // max_lag]
    ).astype(float)
    pos_min = posarray.min(axis=0)
    pos_span = posarray.max(axis=0) - pos_min
    with np.errstate(divide="ignore", invalid="ignore"):
        pos_tmp = (posarray - pos_min) / pos_span
    pos_tmp[np.isnan(pos_tmp)] = 0.0

    pos = {}
    for n in range(N):
        for lag in range(max_lag):
            pos[n * max_lag + lag] = pos_tmp[order[n] * max_lag + lag]

    # Every time copy of variable n takes the node value path_node_array[n].
    node_color = np.repeat(np.asarray(path_node_array, dtype=float), max_lag)

    node_rings = {
        0: {
            "sizes": None,
            "color_array": node_color,
            "cmap": cmap_nodes,
            "vmin": vmin_nodes,
            "vmax": vmax_nodes,
            "ticks": node_ticks,
            "label": node_colorbar_label,
            "colorbar": True,
        }
    }

    node_labels = ["" for _ in range(n_nodes)]

    _draw_network_with_curved_edges(
        fig=fig,
        ax=ax,
        G=deepcopy(G),
        pos=pos,
        node_rings=node_rings,
        node_labels=node_labels,
        node_label_size=node_label_size,
        node_alpha=alpha,
        standard_size=node_size,
        node_aspect=node_aspect,
        standard_cmap="OrRd",
        standard_color_nodes=standard_color_nodes,
        standard_color_links=standard_color_links,
        log_sizes=False,
        cmap_links=cmap_edges,
        links_vmin=vmin_edges,
        links_vmax=vmax_edges,
        links_ticks=edge_ticks,
        tick_label_size=tick_label_size,
        arrowhead_size=arrowhead_size,
        curved_radius=curved_radius,
        label_fontsize=label_fontsize,
        label_fraction=0.5,
        link_colorbar_label=link_colorbar_label,
        inner_edge_curved=True,
    )

    # Variable names to the left of each row (x in axes, y in data coords).
    trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)
    for i in range(N):
        ax.text(
            0.,
            pos[order[i] * max_lag][1],
            "%s" % str(var_names[order[i]]),
            fontsize=label_fontsize,
            horizontalalignment="right",
            verticalalignment="center",
            transform=trans,
        )

    # Time labels above each column (x in data, y in axes coords).
    trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    for lag in range(max_lag - 1, -1, -1):
        if lag == max_lag - 1:
            time_label = r"$t$"
        else:
            time_label = r"$t-%s$" % str(max_lag - lag - 1)
        ax.text(
            pos[lag][0],
            1.0,
            time_label,
            fontsize=label_fontsize,
            horizontalalignment="center",
            verticalalignment="bottom",
            transform=trans,
        )

    if save_name is not None:
        pyplot.savefig(save_name)
    else:
        pyplot.show()
3854
+
3855
+
3856
+ def plot_mediation_graph(
3857
+ path_val_matrix,
3858
+ path_node_array=None,
3859
+ var_names=None,
3860
+ fig_ax=None,
3861
+ figsize=None,
3862
+ save_name=None,
3863
+ link_colorbar_label="link coeff. (edge color)",
3864
+ node_colorbar_label="MCE (node color)",
3865
+ link_width=None,
3866
+ node_pos=None,
3867
+ arrow_linewidth=10.0,
3868
+ vmin_edges=-1,
3869
+ vmax_edges=1.0,
3870
+ edge_ticks=0.4,
3871
+ cmap_edges="RdBu_r",
3872
+ vmin_nodes=-1.0,
3873
+ vmax_nodes=1.0,
3874
+ node_ticks=0.4,
3875
+ cmap_nodes="RdBu_r",
3876
+ node_size=0.3,
3877
+ node_aspect=None,
3878
+ arrowhead_size=20,
3879
+ curved_radius=0.2,
3880
+ label_fontsize=10,
3881
+ tick_label_size=6,
3882
+ lag_array=None,
3883
+ alpha=1.0,
3884
+ node_label_size=10,
3885
+ link_label_fontsize=10,
3886
+ # network_lower_bound=0.2,
3887
+ standard_color_links='black',
3888
+ standard_color_nodes='lightgrey',
3889
+ ):
3890
+ """Creates a network plot visualizing the pathways of a mediation analysis.
3891
+ This is still in beta. The network is defined from non-zero entries in
3892
+ ``path_val_matrix``. Nodes denote variables, straight links contemporaneous
3893
+ dependencies and curved arrows lagged dependencies. The node color denotes
3894
+ the mediated causal effect (MCE) and the link color the value at the lag
3895
+ with maximal link coefficient. The link label lists the lags with
3896
+ significant dependency in order of absolute magnitude. The network can also
3897
+ be plotted over a map drawn before on the same axis. Then the node positions
3898
+ can be supplied in appropriate axis coordinates via node_pos.
3899
+
3900
+ Parameters
3901
+ ----------
3902
+ path_val_matrix : array_like
3903
+ Matrix of shape (N, N, tau_max+1) containing link weight values.
3904
+ path_node_array: array_like
3905
+ Array of shape (N,) containing node values.
3906
+ var_names : list, optional (default: None)
3907
+ List of variable names. If None, range(N) is used.
3908
+ fig_ax : tuple of figure and axis object, optional (default: None)
3909
+ Figure and axes instance. If None they are created.
3910
+ figsize : tuple
3911
+ Size of figure.
3912
+ save_name : str, optional (default: None)
3913
+ Name of figure file to save figure. If None, figure is shown in window.
3914
+ link_colorbar_label : str, optional (default: 'link coeff. (edge color)')
3915
+ Link colorbar label.
3916
+ node_colorbar_label : str, optional (default: 'MCE (node color)')
3917
+ Node colorbar label.
3918
+ link_width : array-like, optional (default: None)
3919
+ Array of val_matrix.shape specifying relative link width with maximum
3920
+ given by arrow_linewidth. If None, all links have same width.
3921
+ node_pos : dictionary, optional (default: None)
3922
+ Dictionary of node positions in axis coordinates of form
3923
+ node_pos = {'x':array of shape (N,), 'y':array of shape(N)}. These
3924
+ coordinates could have been transformed before for basemap plots. You can
3925
+ also add a key 'transform':ccrs.PlateCarree() in order to plot graphs on
3926
+ a map using cartopy.
3927
+ arrow_linewidth : float, optional (default: 30)
3928
+ Linewidth.
3929
+ vmin_edges : float, optional (default: -1)
3930
+ Link colorbar scale lower bound.
3931
+ vmax_edges : float, optional (default: 1)
3932
+ Link colorbar scale upper bound.
3933
+ edge_ticks : float, optional (default: 0.4)
3934
+ Link tick mark interval.
3935
+ cmap_edges : str, optional (default: 'RdBu_r')
3936
+ Colormap for links.
3937
+ vmin_nodes : float, optional (default: 0)
3938
+ Node colorbar scale lower bound.
3939
+ vmax_nodes : float, optional (default: 1)
3940
+ Node colorbar scale upper bound.
3941
+ node_ticks : float, optional (default: 0.4)
3942
+ Node tick mark interval.
3943
+ cmap_nodes : str, optional (default: 'OrRd')
3944
+ Colormap for links.
3945
+ node_size : int, optional (default: 0.3)
3946
+ Node size.
3947
+ node_aspect : float, optional (default: None)
3948
+ Ratio between the heigth and width of the varible nodes.
3949
+ arrowhead_size : int, optional (default: 20)
3950
+ Size of link arrow head. Passed on to FancyArrowPatch object.
3951
+ curved_radius, float, optional (default: 0.2)
3952
+ Curvature of links. Passed on to FancyArrowPatch object.
3953
+ label_fontsize : int, optional (default: 10)
3954
+ Fontsize of colorbar labels.
3955
+ alpha : float, optional (default: 1.)
3956
+ Opacity.
3957
+ node_label_size : int, optional (default: 10)
3958
+ Fontsize of node labels.
3959
+ link_label_fontsize : int, optional (default: 6)
3960
+ Fontsize of link labels.
3961
+ lag_array : array, optional (default: None)
3962
+ Optional specification of lags overwriting np.arange(0, tau_max+1)
3963
+ """
3964
+ val_matrix = path_val_matrix
3965
+
3966
+ if fig_ax is None:
3967
+ fig = pyplot.figure(figsize=figsize)
3968
+ ax = fig.add_subplot(111, frame_on=False)
3969
+ else:
3970
+ fig, ax = fig_ax
3971
+
3972
+ if link_width is not None and not np.all(link_width >= 0.0):
3973
+ raise ValueError("link_width must be non-negative")
3974
+
3975
+ N, N, dummy = val_matrix.shape
3976
+ tau_max = dummy - 1
3977
+
3978
+ if np.count_nonzero(val_matrix) == np.count_nonzero(np.diagonal(val_matrix)):
3979
+ diagonal = True
3980
+ else:
3981
+ diagonal = False
3982
+
3983
+ if np.count_nonzero(val_matrix) == val_matrix.size or diagonal:
3984
+ val_matrix[0, 1, 0] = 1
3985
+ no_links = True
3986
+ else:
3987
+ no_links = False
3988
+
3989
+ if var_names is None:
3990
+ var_names = range(N)
3991
+
3992
+ # Define graph links by absolute maximum (positive or negative like for
3993
+ # partial correlation)
3994
+ # val_matrix[np.abs(val_matrix) < sig_thres] = 0.
3995
+ graph = val_matrix != 0.0
3996
+ net = _get_absmax(val_matrix)
3997
+ G = nx.DiGraph(net)
3998
+
3999
+ node_color = np.zeros(N)
4000
+ # list of all strengths for color map
4001
+ all_strengths = []
4002
+ # Add attributes, contemporaneous and lagged links are handled separately
4003
+ for (u, v, dic) in G.edges(data=True):
4004
+ dic["outer_edge_attribute"] = None
4005
+ dic["no_links"] = no_links
4006
+ # average lagfunc for link u --> v ANDOR u -- v
4007
+ if tau_max > 0:
4008
+ # argmax of absolute maximum
4009
+ argmax = np.abs(val_matrix[u, v][1:]).argmax() + 1
4010
+ else:
4011
+ argmax = 0
4012
+ if u != v:
4013
+ # For contemp links masking or finite samples can lead to different
4014
+ # values for u--v and v--u
4015
+ # Here we use the maximum for the width and weight (=color)
4016
+ # of the link
4017
+ # Draw link if u--v OR v--u at lag 0 is nonzero
4018
+ # dic['inner_edge'] = ((np.abs(val_matrix[u, v][0]) >=
4019
+ # sig_thres[u, v][0]) or
4020
+ # (np.abs(val_matrix[v, u][0]) >=
4021
+ # sig_thres[v, u][0]))
4022
+ dic["inner_edge"] = graph[u, v, 0] or graph[v, u, 0]
4023
+ dic["inner_edge_alpha"] = alpha
4024
+ # value at argmax of average
4025
+ if np.abs(val_matrix[u, v][0] - val_matrix[v, u][0]) > 0.0001:
4026
+ print(
4027
+ "Contemporaneous I(%d; %d)=%.3f != I(%d; %d)=%.3f"
4028
+ % (u, v, val_matrix[u, v][0], v, u, val_matrix[v, u][0])
4029
+ + " due to conditions, finite sample effects or "
4030
+ "masking, here edge color = "
4031
+ "larger (absolute) value."
4032
+ )
4033
+ dic["inner_edge_color"] = _get_absmax(
4034
+ np.array([[[val_matrix[u, v][0], val_matrix[v, u][0]]]])
4035
+ ).squeeze()
4036
+ if link_width is None:
4037
+ dic["inner_edge_width"] = arrow_linewidth
4038
+ else:
4039
+ dic["inner_edge_width"] = (
4040
+ link_width[u, v, 0] / link_width.max() * arrow_linewidth
4041
+ )
4042
+
4043
+ all_strengths.append(dic["inner_edge_color"])
4044
+
4045
+ if tau_max > 0:
4046
+ # True if ensemble mean at lags > 0 is nonzero
4047
+ # dic['outer_edge'] = np.any(
4048
+ # np.abs(val_matrix[u, v][1:]) >= sig_thres[u, v][1:])
4049
+ dic["outer_edge"] = np.any(graph[u, v, 1:])
4050
+ else:
4051
+ dic["outer_edge"] = False
4052
+ dic["outer_edge_alpha"] = alpha
4053
+ if link_width is None:
4054
+ # fraction of nonzero values
4055
+ dic["outer_edge_width"] = arrow_linewidth
4056
+ else:
4057
+ dic["outer_edge_width"] = (
4058
+ link_width[u, v, argmax] / link_width.max() * arrow_linewidth
4059
+ )
4060
+
4061
+ # value at argmax of average
4062
+ dic["outer_edge_color"] = val_matrix[u, v][argmax]
4063
+ all_strengths.append(dic["outer_edge_color"])
4064
+
4065
+ # Sorted list of significant lags (only if robust wrt
4066
+ # d['min_ensemble_frac'])
4067
+ if tau_max > 0:
4068
+ lags = np.abs(val_matrix[u, v][1:]).argsort()[::-1] + 1
4069
+ sig_lags = (np.where(graph[u, v, 1:])[0] + 1).tolist()
4070
+ else:
4071
+ lags, sig_lags = [], []
4072
+ if lag_array is not None:
4073
+ dic["label"] = ",".join([str(lag_array[l]) for l in lags if l in sig_lags]) #str([str(lag_array[l]) for l in lags if l in sig_lags])[1:-1].replace(" ", "")
4074
+ else:
4075
+ dic["label"] = ",".join([str(l) for l in lags if l in sig_lags]) # str([str(l) for l in lags if l in sig_lags])[1:-1].replace(" ", "")
4076
+ else:
4077
+ # Node color is max of average autodependency
4078
+ node_color[u] = val_matrix[u, v][argmax]
4079
+
4080
+ # dic['outer_edge_edge'] = False
4081
+ # dic['outer_edge_edgecolor'] = None
4082
+ # dic['inner_edge_edge'] = False
4083
+ # dic['inner_edge_edgecolor'] = None
4084
+
4085
+ node_color = path_node_array
4086
+ # print node_color
4087
+ # If no links are present, set value to zero
4088
+ if len(all_strengths) == 0:
4089
+ all_strengths = [0.0]
4090
+
4091
+ if node_pos is None:
4092
+ pos = nx.circular_layout(deepcopy(G))
4093
+ # pos = nx.spring_layout(deepcopy(G))
4094
+ else:
4095
+ pos = {}
4096
+ for i in range(N):
4097
+ pos[i] = (node_pos["x"][i], node_pos["y"][i])
4098
+
4099
+ if node_pos is not None and 'transform' in node_pos:
4100
+ transform = node_pos['transform']
4101
+ else: transform = ax.transData
4102
+
4103
+ node_rings = {
4104
+ 0: {
4105
+ "sizes": None,
4106
+ "color_array": node_color,
4107
+ "cmap": cmap_nodes,
4108
+ "vmin": vmin_nodes,
4109
+ "vmax": vmax_nodes,
4110
+ "ticks": node_ticks,
4111
+ "label": node_colorbar_label,
4112
+ "colorbar": True,
4113
+ }
4114
+ }
4115
+
4116
+ _draw_network_with_curved_edges(
4117
+ fig=fig,
4118
+ ax=ax,
4119
+ G=deepcopy(G),
4120
+ pos=pos,
4121
+ # dictionary of rings: {0:{'sizes':(N,)-array, 'color_array':(N,)-array
4122
+ # or None, 'cmap':string,
4123
+ node_rings=node_rings,
4124
+ # 'vmin':float or None, 'vmax':float or None, 'label':string or None}}
4125
+ node_labels=var_names,
4126
+ node_label_size=node_label_size,
4127
+ node_alpha=alpha,
4128
+ standard_size=node_size,
4129
+ node_aspect=node_aspect,
4130
+ standard_cmap="OrRd",
4131
+ standard_color_nodes=standard_color_nodes,
4132
+ standard_color_links=standard_color_links,
4133
+ log_sizes=False,
4134
+ cmap_links=cmap_edges,
4135
+ links_vmin=vmin_edges,
4136
+ links_vmax=vmax_edges,
4137
+ links_ticks=edge_ticks,
4138
+ tick_label_size=tick_label_size,
4139
+ # cmap_links_edges='YlOrRd', links_edges_vmin=-1., links_edges_vmax=1.,
4140
+ # links_edges_ticks=.2, link_edge_colorbar_label='link_edge',
4141
+ arrowhead_size=arrowhead_size,
4142
+ curved_radius=curved_radius,
4143
+ label_fontsize=label_fontsize,
4144
+ link_label_fontsize=link_label_fontsize,
4145
+ link_colorbar_label=link_colorbar_label,
4146
+ # network_lower_bound=network_lower_bound,
4147
+ # label_fraction=label_fraction,
4148
+ # inner_edge_style=inner_edge_style
4149
+ transform=transform
4150
+ )
4151
+
4152
+ # fig.subplots_adjust(left=0.1, right=.9, bottom=.25, top=.95)
4153
+ # savestring = os.path.expanduser(save_name)
4154
+ if save_name is not None:
4155
+ pyplot.savefig(save_name)
4156
+ else:
4157
+ pyplot.show()
4158
+
4159
+
4160
+ #
4161
+ # Functions to plot time series graphs from links including ancestors
4162
+ #
4163
+ def plot_tsg(links, X, Y, Z=None, anc_x=None, anc_y=None, anc_xy=None):
4164
+ """Plots TSG that is input in format (N*max_lag, N*max_lag).
4165
+ Compared to the tigramite plotting function here links
4166
+ X^i_{t-tau} --> X^j_t can be missing for different t'. Helpful to
4167
+ visualize the conditioned TSG.
4168
+ """
4169
+
4170
+ def varlag2node(var, lag):
4171
+ """Translate from (var, lag) notation to node in TSG.
4172
+ lag must be <= 0.
4173
+ """
4174
+ return var * max_lag + lag
4175
+
4176
+ def node2varlag(node):
4177
+ """Translate from node in TSG to (var, -tau) notation.
4178
+ Here tau is <= 0.
4179
+ """
4180
+ var = node // max_lag
4181
+ tau = node % (max_lag) - (max_lag - 1)
4182
+ return var, tau
4183
+
4184
+ def _get_minmax_lag(links):
4185
+ """Helper function to retrieve tau_min and tau_max from links
4186
+ """
4187
+
4188
+ N = len(links)
4189
+
4190
+ # Get maximum time lag
4191
+ min_lag = np.inf
4192
+ max_lag = 0
4193
+ for j in range(N):
4194
+ for link_props in links[j]:
4195
+ var, lag = link_props[0]
4196
+ coeff = link_props[1]
4197
+ # func = link_props[2]
4198
+ if coeff != 0.:
4199
+ min_lag = min(min_lag, abs(lag))
4200
+ max_lag = max(max_lag, abs(lag))
4201
+ return min_lag, max_lag
4202
+
4203
+ def _links_to_tsg(link_coeffs, max_lag=None):
4204
+ """Transform link_coeffs to time series graph.
4205
+ TSG is of shape (N*max_lag, N*max_lag).
4206
+ """
4207
+ N = len(link_coeffs)
4208
+
4209
+ # Get maximum lag
4210
+ min_lag_links, max_lag_links = _get_minmax_lag(link_coeffs)
4211
+
4212
+ # max_lag of TSG is max lag in links + 1 for the zero lag.
4213
+ if max_lag is None:
4214
+ max_lag = max_lag_links + 1
4215
+
4216
+ tsg = np.zeros((N * max_lag, N * max_lag))
4217
+
4218
+ for j in range(N):
4219
+ for link_props in link_coeffs[j]:
4220
+ i, lag = link_props[0]
4221
+ tau = abs(lag)
4222
+ coeff = link_props[1]
4223
+ # func = link_props[2]
4224
+ if coeff != 0.0:
4225
+ for t in range(max_lag):
4226
+ if (
4227
+ 0 <= varlag2node(i, t - tau)
4228
+ and varlag2node(i, t - tau) % max_lag
4229
+ <= varlag2node(j, t) % max_lag
4230
+ ):
4231
+ tsg[varlag2node(i, t - tau), varlag2node(j, t)] = 1.0
4232
+
4233
+ return tsg
4234
+
4235
+ color_list = ["lightgrey", "grey", "black", "red", "blue", "orange"]
4236
+ listcmap = ListedColormap(color_list)
4237
+
4238
+ N = len(links)
4239
+
4240
+ min_lag_links, max_lag_links = _get_minmax_lag(links)
4241
+ max_lag = max_lag_links
4242
+
4243
+ for anc in X + Y:
4244
+ max_lag = max(max_lag, abs(anc[1]))
4245
+ for anc in Y:
4246
+ max_lag = max(max_lag, abs(anc[1]))
4247
+ if Z is not None:
4248
+ for anc in Z:
4249
+ max_lag = max(max_lag, abs(anc[1]))
4250
+
4251
+ if anc_x is not None:
4252
+ for anc in anc_x:
4253
+ max_lag = max(max_lag, abs(anc[1]))
4254
+ if anc_y is not None:
4255
+ for anc in anc_y:
4256
+ max_lag = max(max_lag, abs(anc[1]))
4257
+ if anc_xy is not None:
4258
+ for anc in anc_xy:
4259
+ max_lag = max(max_lag, abs(anc[1]))
4260
+
4261
+ max_lag = max_lag + 1
4262
+
4263
+ tsg = _links_to_tsg(links, max_lag=max_lag)
4264
+
4265
+ G = nx.DiGraph(tsg)
4266
+
4267
+ figsize = (3, 3)
4268
+ link_colorbar_label = "MCI"
4269
+ arrow_linewidth = 8.0
4270
+ vmin_edges = -1
4271
+ vmax_edges = 1.0
4272
+ edge_ticks = 0.4
4273
+ cmap_edges = "RdBu_r"
4274
+ order = None
4275
+ node_size = .1
4276
+ arrowhead_size = 20
4277
+ curved_radius = 0.2
4278
+ label_fontsize = 10
4279
+ alpha = 1.0
4280
+ node_label_size = 10
4281
+ # label_space_left = 0.1
4282
+ # label_space_top = 0.0
4283
+ # network_lower_bound = 0.2
4284
+ inner_edge_style = "dashed"
4285
+
4286
+ node_color = np.ones(N * max_lag) # , dtype = 'object')
4287
+ node_color[:] = 0
4288
+
4289
+ if anc_x is not None:
4290
+ for n in [varlag2node(itau[0], max_lag - 1 + itau[1]) for itau in anc_x]:
4291
+ node_color[n] = 3
4292
+ if anc_y is not None:
4293
+ for n in [varlag2node(itau[0], max_lag - 1 + itau[1]) for itau in anc_y]:
4294
+ node_color[n] = 4
4295
+ if anc_xy is not None:
4296
+ for n in [varlag2node(itau[0], max_lag - 1 + itau[1]) for itau in anc_xy]:
4297
+ node_color[n] = 5
4298
+
4299
+ for x in X:
4300
+ node_color[varlag2node(x[0], max_lag - 1 + x[1])] = 2
4301
+ for y in Y:
4302
+ node_color[varlag2node(y[0], max_lag - 1 + y[1])] = 2
4303
+ if Z is not None:
4304
+ for z in Z:
4305
+ node_color[varlag2node(z[0], max_lag - 1 + z[1])] = 1
4306
+
4307
+ fig = pyplot.figure(figsize=figsize)
4308
+ ax = fig.add_subplot(111, frame_on=False)
4309
+ var_names = range(N)
4310
+ order = range(N)
4311
+
4312
+ # list of all strengths for color map
4313
+ all_strengths = []
4314
+ # Add attributes, contemporaneous and lagged links are handled separately
4315
+ for (u, v, dic) in G.edges(data=True):
4316
+ if u != v:
4317
+ if tsg[u, v] and tsg[v, u]:
4318
+ dic["inner_edge"] = True
4319
+ dic["outer_edge"] = False
4320
+ else:
4321
+ dic["inner_edge"] = False
4322
+ dic["outer_edge"] = True
4323
+
4324
+ dic["inner_edge_alpha"] = alpha
4325
+ dic["inner_edge_color"] = tsg[u, v]
4326
+
4327
+ dic["inner_edge_width"] = arrow_linewidth
4328
+ dic["inner_edge_attribute"] = dic["outer_edge_attribute"] = None
4329
+
4330
+ all_strengths.append(dic["inner_edge_color"])
4331
+ dic["outer_edge_alpha"] = alpha
4332
+ dic["outer_edge_width"] = dic["inner_edge_width"] = arrow_linewidth
4333
+
4334
+ # value at argmax of average
4335
+ dic["outer_edge_color"] = tsg[u, v]
4336
+
4337
+ all_strengths.append(dic["outer_edge_color"])
4338
+ dic["label"] = None
4339
+
4340
+ # If no links are present, set value to zero
4341
+ if len(all_strengths) == 0:
4342
+ all_strengths = [0.0]
4343
+
4344
+ posarray = np.zeros((N * max_lag, 2))
4345
+ for i in range(N * max_lag):
4346
+ posarray[i] = np.array([(i % max_lag), (1.0 - i // max_lag)])
4347
+
4348
+ pos_tmp = {}
4349
+ for i in range(N * max_lag):
4350
+ pos_tmp[i] = np.array(
4351
+ [
4352
+ ((i % max_lag) - posarray.min(axis=0)[0])
4353
+ / (posarray.max(axis=0)[0] - posarray.min(axis=0)[0]),
4354
+ ((1.0 - i // max_lag) - posarray.min(axis=0)[1])
4355
+ / (posarray.max(axis=0)[1] - posarray.min(axis=0)[1]),
4356
+ ]
4357
+ )
4358
+ pos_tmp[i][np.isnan(pos_tmp[i])] = 0.0
4359
+
4360
+ pos = {}
4361
+ for n in range(N):
4362
+ for tau in range(max_lag):
4363
+ pos[n * max_lag + tau] = pos_tmp[order[n] * max_lag + tau]
4364
+
4365
+ node_rings = {
4366
+ 0: {
4367
+ "sizes": None,
4368
+ "color_array": node_color,
4369
+ "label": "",
4370
+ "colorbar": False,
4371
+ "cmap": listcmap,
4372
+ "vmin": 0,
4373
+ "vmax": len(color_list),
4374
+ }
4375
+ }
4376
+
4377
+ node_labels = ["" for i in range(N * max_lag)]
4378
+
4379
+ _draw_network_with_curved_edges(
4380
+ fig=fig,
4381
+ ax=ax,
4382
+ G=deepcopy(G),
4383
+ pos=pos,
4384
+ node_rings=node_rings,
4385
+ node_labels=node_labels,
4386
+ node_label_size=node_label_size,
4387
+ node_alpha=alpha,
4388
+ standard_size=node_size,
4389
+ node_aspect=None,
4390
+ standard_cmap="OrRd",
4391
+ standard_color_links='black',
4392
+ standard_color_nodes='lightgrey',
4393
+ log_sizes=False,
4394
+ cmap_links=cmap_edges,
4395
+ links_vmin=vmin_edges,
4396
+ links_vmax=vmax_edges,
4397
+ links_ticks=edge_ticks,
4398
+ arrowstyle="simple",
4399
+ arrowhead_size=arrowhead_size,
4400
+ curved_radius=curved_radius,
4401
+ label_fontsize=label_fontsize,
4402
+ label_fraction=0.5,
4403
+ link_colorbar_label=link_colorbar_label,
4404
+ inner_edge_curved=True,
4405
+ # network_lower_bound=network_lower_bound,
4406
+ inner_edge_style=inner_edge_style,
4407
+ )
4408
+
4409
+ for i in range(N):
4410
+ trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)
4411
+ ax.text(
4412
+ 0.,
4413
+ pos[order[i] * max_lag][1],
4414
+ "%s" % str(var_names[order[i]]),
4415
+ fontsize=label_fontsize,
4416
+ horizontalalignment="right",
4417
+ verticalalignment="center",
4418
+ transform=trans,
4419
+ )
4420
+
4421
+ for tau in np.arange(max_lag - 1, -1, -1):
4422
+ trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
4423
+ if tau == max_lag - 1:
4424
+ ax.text(
4425
+ pos[tau][0],
4426
+ 1.0, #- label_space_top,
4427
+ r"$t$",
4428
+ fontsize=int(label_fontsize * 0.7),
4429
+ horizontalalignment="center",
4430
+ verticalalignment="bottom",
4431
+ transform=trans,
4432
+ )
4433
+ else:
4434
+ ax.text(
4435
+ pos[tau][0],
4436
+ 1.0, # - label_space_top,
4437
+ r"$t-%s$" % str(max_lag - tau - 1),
4438
+ fontsize=int(label_fontsize * 0.7),
4439
+ horizontalalignment="center",
4440
+ verticalalignment="bottom",
4441
+ transform=trans,
4442
+ )
4443
+
4444
+ return fig, ax
4445
+
4446
+ def write_csv(
4447
+ graph,
4448
+ save_name,
4449
+ val_matrix=None,
4450
+ var_names=None,
4451
+ link_width=None,
4452
+ link_attribute=None,
4453
+ digits=5,
4454
+ ):
4455
+ """Writes all links in a graph to a csv file.
4456
+
4457
+ Format is each link in a row as 'Variable i', 'Variable j', 'Time lag of i', 'Link type i --- j',
4458
+ with optional further columns for entries in [val_matrix link_attribute, link_width].
4459
+
4460
+ Parameters
4461
+ ----------
4462
+ graph : string or bool array-like, optional (default: None)
4463
+ Either string matrix providing graph or bool array providing only adjacencies
4464
+ Must be of same shape as val_matrix.
4465
+ save_name : str
4466
+ Name of figure file to save figure. If None, figure is shown in window.
4467
+ val_matrix : array_like
4468
+ Matrix of shape (N, N, tau_max+1) containing test statistic values.
4469
+ var_names : list, optional (default: None)
4470
+ List of variable names. If None, range(N) is used.
4471
+ link_width : array-like, optional (default: None)
4472
+ Array of val_matrix.shape specifying relative link width with maximum
4473
+ given by arrow_linewidth. If None, all links have same width.
4474
+ link_attribute : array-like, optional (default: None)
4475
+ String array of val_matrix.shape specifying link attributes.
4476
+ digits : int
4477
+ Number of significant digits for writing link value and width.
4478
+ """
4479
+
4480
+ graph = np.copy(graph.squeeze())
4481
+
4482
+ N = len(graph)
4483
+
4484
+ if val_matrix is None:
4485
+ val_matrix_exists = false
4486
+ else:
4487
+ val_matrix_exists = True
4488
+
4489
+ if graph.ndim == 4:
4490
+ raise ValueError("Time series graph of shape (N,N,tau_max+1,tau_max+1) cannot be represented by plot_graph,"
4491
+ " use plot_time_series_graph instead.")
4492
+
4493
+ if graph.ndim == 2:
4494
+ # If a non-time series (N,N)-graph is given, insert a dummy dimension
4495
+ graph = np.expand_dims(graph, axis = 2)
4496
+
4497
+ (graph, val_matrix, link_width, link_attribute) = _check_matrices(
4498
+ graph, val_matrix, link_width, link_attribute)
4499
+
4500
+ if var_names is None:
4501
+ var_names = range(N)
4502
+
4503
+
4504
+ header = ['Variable i', 'Variable j', 'Time lag of i', 'Link type i --- j']
4505
+ if val_matrix_exists:
4506
+ header.append('Link value')
4507
+ if link_attribute is not None:
4508
+ header.append('Link attribute')
4509
+ if link_width is not None:
4510
+ header.append('Link width')
4511
+
4512
+
4513
+ with open(save_name, 'w', encoding='UTF8', newline='') as f:
4514
+ writer = csv.writer(f)
4515
+
4516
+ # write the header
4517
+ writer.writerow(header)
4518
+
4519
+ # write the link data
4520
+ for (i, j, tau) in zip(*np.where(graph!='')):
4521
+ # Only consider contemporaneous links once
4522
+ if tau > 0 or i <= j:
4523
+ row = [str(var_names[i]), str(var_names[j]), f"{tau}", graph[i,j,tau]]
4524
+ if val_matrix_exists:
4525
+ row.append(f"{val_matrix[i,j,tau]:.{digits}}")
4526
+ if link_attribute is not None:
4527
+ row.append(link_attribute[i,j,tau])
4528
+ if link_width is not None:
4529
+ row.append(f"{link_width[i,j,tau]:.{digits}}")
4530
+
4531
+ writer.writerow(row)
4532
+
4533
+
4534
+ if __name__ == "__main__":
4535
+
4536
+ import sys
4537
+ matplotlib.rc('xtick', labelsize=6)
4538
+ matplotlib.rc('ytick', labelsize=6)
4539
+
4540
+ # Consider some toy data
4541
+ import tigramite
4542
+ import tigramite.toymodels.structural_causal_processes as toys
4543
+ import tigramite.data_processing as pp
4544
+ from tigramite.causal_effects import CausalEffects
4545
+
4546
+
4547
+ # T = 1000
4548
+ def lin_f(x): return x
4549
+ # auto_coeff = 0.3
4550
+ # coeff = 1.
4551
+ # links = {
4552
+ # 0: [((0, -1), auto_coeff, lin_f)],
4553
+ # 1: [((1, -1), auto_coeff, lin_f), ((0, 0), coeff, lin_f)],
4554
+ # 2: [((2, -1), auto_coeff, lin_f), ((1, 0), coeff, lin_f)],
4555
+ # }
4556
+ # data, nonstat = toys.structural_causal_process(links, T=T,
4557
+ # noises=None, seed=7)
4558
+ # dataframe = pp.DataFrame(data, var_names=range(len(links)))
4559
+
4560
+ # links = {
4561
+ # 0: [((0, -1), 1.5*auto_coeff, lin_f)],
4562
+ # 1: [((1, -1), 1.5*auto_coeff, lin_f), ((0, 0), 1.5*coeff, lin_f)],
4563
+ # 2: [((2, -1), 1.5*auto_coeff, lin_f), ((1, 0), 1.5*coeff, lin_f)],
4564
+ # }
4565
+ # data2, nonstat = toys.structural_causal_process(links, T=T,
4566
+ # noises=None, seed=7)
4567
+ # dataframe2 = pp.DataFrame(data2, var_names=range(len(links)))
4568
+ # plot_densityplots(dataframe, name='test.pdf')
4569
+
4570
+ # N = len(links)
4571
+
4572
+
4573
+ # parcorr = ParCorr(significance='analytic')
4574
+ # pcmci = PCMCI(
4575
+ # dataframe=dataframe,
4576
+ # cond_ind_test=parcorr,
4577
+ # verbosity=1)
4578
+
4579
+
4580
+ correlations = np.random.rand(3, 3, 5) - 0.5 #pcmci.get_lagged_dependencies(tau_max=20, val_only=True)['val_matrix']
4581
+ lag_func_matrix = plot_lagfuncs(val_matrix=correlations, setup_args={
4582
+ 'label_space_left':0.05,
4583
+ 'minimum': 0.0,
4584
+ 'maximum':.05,
4585
+ 'x_base':5,
4586
+ 'y_base':.5})
4587
+ plt.show()
4588
+
4589
+
4590
+ # N = len(links)
4591
+ # matrix = setup_density_matrix(N=N, var_names=dataframe.var_names)
4592
+ # matrix.add_densityplot(dataframe=dataframe,
4593
+ # # selected_dataset=0,
4594
+ # **{
4595
+ # 'label':'Weak',
4596
+ # 'label_color':'blue',
4597
+ # "snskdeplot_args" : {'cmap':'Reds'},
4598
+ # }), #{'cmap':'Blues', 'alpha':0.3}})
4599
+ # matrix.add_densityplot(dataframe=dataframe2, selected_dataset=0,
4600
+ # **{'label':'Strong',
4601
+ # 'label_color':'red',
4602
+ # "snskdeplot_args" : {'cmap':'Blues', 'alpha':0.3}})
4603
+ # matrix.adjustfig(name='test.pdf')
4604
+
4605
+ # matrix = setup_scatter_matrix(N=dataframe.N,
4606
+ # var_names=dataframe.var_names)
4607
+ # matrix_lags = np.ones((3, 3)).astype('int')
4608
+ # matrix.add_scatterplot(dataframe=dataframe, matrix_lags=matrix_lags,
4609
+ # label='ones', alpha=0.4)
4610
+ # matrix_lags = 2*np.ones((3, 3)).astype('int')
4611
+ # matrix.add_scatterplot(dataframe=dataframe, matrix_lags=matrix_lags,
4612
+ # label='twos', color='red', alpha=0.4)
4613
+
4614
+ # matrix.savefig(name='scattertest.pdf')
4615
+
4616
+
4617
+ # pyplot.show()
4618
+ # sys.exit(0)
4619
+
4620
+
4621
+ # val_matrix = np.zeros((4, 4, 3))
4622
+
4623
+ # # Complete test case
4624
+ # graph = np.zeros((3,3,2), dtype='<U3')
4625
+ # val_matrix = 0.*np.random.rand(*graph.shape)
4626
+ # val_matrix[:,:,0] = 0.2
4627
+ # graph[:] = ""
4628
+ # # graph[0, 1, 0] = "<-+"
4629
+ # # graph[1, 0, 0] = "+->"
4630
+ # graph[0, 0, 1] = "-->"
4631
+ # graph[1, 1, 1] = "-->"
4632
+
4633
+ # graph[0, 1, 1] = "+->"
4634
+ # # graph[1, 0, 1] = "o-o"
4635
+
4636
+ # graph[1, 2, 0] = "<->"
4637
+ # graph[2, 1, 0] = "<->"
4638
+
4639
+ # graph[0, 2, 0] = "x-x"
4640
+ # # graph[2, 0, 0] = "x-x"
4641
+ # nolinks = np.zeros(graph.shape)
4642
+ # # nolinks[range(4), range(4), 1] = 1
4643
+
4644
+ # # graph = graph[:2, :2, :]
4645
+
4646
+ # # fig, axes = pyplot.subplots(nrows=1, ncols=1, figsize=(6, 5))
4647
+
4648
+
4649
+ # # import cartopy.crs as ccrs
4650
+ # graph = np.ones((5, 5, 2), dtype='<U3')
4651
+ # graph[:] = ""
4652
+ # graph[3, :, 1] = '+->'
4653
+
4654
+ # fig = pyplot.figure(figsize=(8, 6))
4655
+ # fig = pyplot.figure(figsize=(10, 5))
4656
+ # ax = fig.add_subplot(1, 1, 1, projection=ccrs.Mollweide())
4657
+ # make the map global rather than have it zoom in to
4658
+ # the extents of any plotted data
4659
+ # ax.set_global()
4660
+ # ax.stock_img()
4661
+ # ax.coastlines()
4662
+ # # ymax = 1.
4663
+ # node_pos = {'x':np.linspace(0, ymax, graph.shape[0]), 'y':np.linspace(0, ymax, graph.shape[0]),}
4664
+ # node_pos = {'x':np.array([10,-20,80,-50,80]),
4665
+ # 'y':np.array([-10,70,60,-40,50]),
4666
+ # 'transform':ccrs.PlateCarree(), # t.PlateCarree()
4667
+ # }
4668
+
4669
+ # plot_time_series_graph(graph=graph,
4670
+ # # fig_ax = (fig, ax),
4671
+ # # val_matrix=val_matrix,
4672
+ # # figsize=(5, 5),
4673
+ # # var_names = ['Var %s' %i for i in range(len(graph))],
4674
+ # # arrow_linewidth=6,
4675
+ # # label_space_left = label_space_left,
4676
+ # # label_space_top = label_space_top,
4677
+ # # # network_lower_bound=network_lower_bound,
4678
+ # save_name="tsg_test.pdf"
4679
+ # )
4680
+ # pyplot.tight_layout()
4681
+
4682
+ # network_lower_bound = 0.
4683
+ # show_colorbar=True
4684
+ # plot_graph(graph=graph,
4685
+ # fig_ax = (fig, ax),
4686
+ # node_pos = node_pos,
4687
+ # node_size = 20,
4688
+ # # val_matrix=val_matrix,
4689
+ # # figsize=(5, 5),
4690
+ # # var_names = ['Var %s' %i for i in range(len(graph))],
4691
+ # # arrow_linewidth=6,
4692
+ # # label_space_left = label_space_left,
4693
+ # # label_space_top = label_space_top,
4694
+ # # # network_lower_bound=network_lower_bound,
4695
+ # save_name="tsg_test.pdf"
4696
+ # )
4697
+ # pyplot.tight_layout()
4698
+ # axes[0,0].scatter(np.random.rand(100), np.random.rand(100))
4699
+
4700
+ # plot_graph(graph=graph,
4701
+ # fig_ax = (fig, axes[0,0]),
4702
+ # val_matrix=val_matrix,
4703
+ # # figsize=(5, 5),
4704
+ # var_names = ['Variable %s' %i for i in range(len(graph))],
4705
+ # arrow_linewidth=6,
4706
+ # # label_space_left = label_space_left,
4707
+ # # label_space_top = label_space_top,
4708
+ # # save_name="tsg_test.pdf"
4709
+ # )
4710
+ # plot_graph(graph=graph,
4711
+ # fig_ax = (fig, axes[0,1]),
4712
+ # val_matrix=val_matrix,
4713
+ # var_names = ['Var %s' %i for i in range(len(graph))],
4714
+ # arrow_linewidth=6,
4715
+ # # label_space_left = label_space_left,
4716
+ # # label_space_top = label_space_top,
4717
+ # )
4718
+ # plot_graph(graph=graph,
4719
+ # fig_ax = (fig, axes[1,0]),
4720
+ # val_matrix=val_matrix,
4721
+ # var_names = ['Var %s' %i for i in range(len(graph))],
4722
+ # arrow_linewidth=6,
4723
+ # # label_space_left = label_space_left,
4724
+ # # label_space_top = label_space_top,
4725
+ # )
4726
+ # plot_graph(graph=graph,
4727
+ # fig_ax = (fig, axes[1,1]),
4728
+ # val_matrix=val_matrix,
4729
+ # var_names = ['Var %s' %i for i in range(len(graph))],
4730
+ # arrow_linewidth=6,
4731
+ # n
4732
+ # # label_space_left = label_space_left,
4733
+ # # label_space_top = label_space_top,
4734
+ # )
4735
+ # # pyplot.subplots_adjust(wspace=0.3, hspace=0.2)
4736
+ # pyplot.tight_layout()
4737
+ # pyplot.savefig("test.pdf")
4738
+
4739
+ # def lin_f(x): return x
4740
+
4741
+ # links_coeffs = {0: [((0, -1), 0.3, lin_f)], #, ((1, -1), 0.5, lin_f)],
4742
+ # 1: [((1, -1), 0.3, lin_f), ((0, 0), 0.7, lin_f), ((2, -1), 0.5, lin_f)],
4743
+ # 2: [],
4744
+ # 3: [((3, -1), 0., lin_f), ((2, 0), 0.6, lin_f),]
4745
+ # }
4746
+ # graph = CausalEffects.get_graph_from_dict(links_coeffs, tau_max=None)
4747
+
4748
+ # val_matrix = np.random.randn(*graph.shape)
4749
+ # val_matrix[:,:,0] = 0.
4750
+ # write_csv(graph=graph,
4751
+ # val_matrix=val_matrix,
4752
+ # var_names=[r'$X^{%d}$' %i for i in range(graph.shape[0])],
4753
+ # link_width=np.ones(graph.shape),
4754
+ # link_attribute = np.ones(graph.shape, dtype='<U10'),
4755
+ # save_name='test.csv')
4756
+
4757
+ # # print(graph)
4758
+ # X = [(0,-1)]
4759
+ # Y = [(1,0)]
4760
+ # causal_effects = CausalEffects(graph, graph_type='stationary_dag', X=X, Y=Y, S=None,
4761
+ # hidden_variables=[(2, 0), (2, -1), (2, -2)],
4762
+ # verbosity=0)
4763
+
4764
+ # pyplot.show()