tigramite-fast 5.2.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. tigramite/__init__.py +0 -0
  2. tigramite/causal_effects.py +1525 -0
  3. tigramite/causal_mediation.py +1592 -0
  4. tigramite/data_processing.py +1574 -0
  5. tigramite/graphs.py +1509 -0
  6. tigramite/independence_tests/LBFGS.py +1114 -0
  7. tigramite/independence_tests/__init__.py +0 -0
  8. tigramite/independence_tests/cmiknn.py +661 -0
  9. tigramite/independence_tests/cmiknn_mixed.py +1397 -0
  10. tigramite/independence_tests/cmisymb.py +286 -0
  11. tigramite/independence_tests/gpdc.py +664 -0
  12. tigramite/independence_tests/gpdc_torch.py +820 -0
  13. tigramite/independence_tests/gsquared.py +190 -0
  14. tigramite/independence_tests/independence_tests_base.py +1310 -0
  15. tigramite/independence_tests/oracle_conditional_independence.py +1582 -0
  16. tigramite/independence_tests/pairwise_CI.py +383 -0
  17. tigramite/independence_tests/parcorr.py +369 -0
  18. tigramite/independence_tests/parcorr_mult.py +485 -0
  19. tigramite/independence_tests/parcorr_wls.py +451 -0
  20. tigramite/independence_tests/regressionCI.py +403 -0
  21. tigramite/independence_tests/robust_parcorr.py +403 -0
  22. tigramite/jpcmciplus.py +966 -0
  23. tigramite/lpcmci.py +3649 -0
  24. tigramite/models.py +2257 -0
  25. tigramite/pcmci.py +3935 -0
  26. tigramite/pcmci_base.py +1218 -0
  27. tigramite/plotting.py +4735 -0
  28. tigramite/rpcmci.py +467 -0
  29. tigramite/toymodels/__init__.py +0 -0
  30. tigramite/toymodels/context_model.py +261 -0
  31. tigramite/toymodels/non_additive.py +1231 -0
  32. tigramite/toymodels/structural_causal_processes.py +1201 -0
  33. tigramite/toymodels/surrogate_generator.py +319 -0
  34. tigramite_fast-5.2.10.1.dist-info/METADATA +182 -0
  35. tigramite_fast-5.2.10.1.dist-info/RECORD +38 -0
  36. tigramite_fast-5.2.10.1.dist-info/WHEEL +5 -0
  37. tigramite_fast-5.2.10.1.dist-info/licenses/license.txt +621 -0
  38. tigramite_fast-5.2.10.1.dist-info/top_level.txt +1 -0
tigramite/plotting.py ADDED
@@ -0,0 +1,4735 @@
1
+ """Tigramite plotting package."""
2
+
3
+ # Author: Jakob Runge <jakob@jakob-runge.com>
4
+ #
5
+ # License: GNU General Public License v3.0
6
+
7
+ import numpy as np
8
+ import json, warnings, os, pathlib
9
+ import matplotlib
10
+ import networkx as nx
11
+ from matplotlib.colors import ListedColormap
12
+ import matplotlib.transforms as transforms
13
+ from matplotlib import pyplot, ticker
14
+ from matplotlib.ticker import FormatStrFormatter
15
+ import matplotlib.patches as mpatches
16
+ from matplotlib.collections import PatchCollection
17
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
18
+ import sys
19
+ from operator import sub
20
+ import tigramite.data_processing as pp
21
+ from copy import deepcopy
22
+ import matplotlib.path as mpath
23
+ import matplotlib.patheffects as PathEffects
24
+ from mpl_toolkits.axisartist.axislines import Axes
25
+ import csv
26
+
27
+
28
+ # TODO: Add proper docstrings to internal functions...
29
+
30
+
31
+ def _par_corr_trafo(cmi):
32
+ """Transformation of CMI to partial correlation scale."""
33
+
34
+ # Set negative values to small positive number
35
+ # (zero would be interpreted as non-significant in some functions)
36
+ if np.ndim(cmi) == 0:
37
+ if cmi < 0.0:
38
+ cmi = 1e-8
39
+ else:
40
+ cmi[cmi < 0.0] = 1e-8
41
+
42
+ return np.sqrt(1.0 - np.exp(-2.0 * cmi))
43
+
44
+
45
+ def _par_corr_to_cmi(par_corr):
46
+ """Transformation of partial correlation to CMI scale."""
47
+
48
+ return -0.5 * np.log(1.0 - par_corr ** 2)
49
+
50
+
51
+ def _myround(x, base=5, round_mode="updown"):
52
+ """Rounds x to a float with precision base."""
53
+
54
+ if round_mode == "updown":
55
+ return base * round(float(x) / base)
56
+ elif round_mode == "down":
57
+ return base * np.floor(float(x) / base)
58
+ elif round_mode == "up":
59
+ return base * np.ceil(float(x) / base)
60
+
61
+ return base * round(float(x) / base)
62
+
63
+
64
+ def _make_nice_axes(ax, where=None, skip=1, color=None):
65
+ """Makes nice axes."""
66
+
67
+ if where is None:
68
+ where = ["left", "bottom"]
69
+ if color is None:
70
+ color = {"left": "black", "right": "black", "bottom": "black", "top": "black"}
71
+
72
+ if type(skip) == int:
73
+ skip_x = skip_y = skip
74
+ else:
75
+ skip_x = skip[0]
76
+ skip_y = skip[1]
77
+
78
+ for loc, spine in ax.spines.items():
79
+ if loc in where:
80
+ spine.set_position(("outward", 5)) # outward by 10 points
81
+ spine.set_color(color[loc])
82
+ if loc == "left" or loc == "right":
83
+ pyplot.setp(ax.get_yticklines(), color=color[loc])
84
+ pyplot.setp(ax.get_yticklabels(), color=color[loc])
85
+ if loc == "top" or loc == "bottom":
86
+ pyplot.setp(ax.get_xticklines(), color=color[loc])
87
+ elif loc in [
88
+ item for item in ["left", "bottom", "right", "top"] if item not in where
89
+ ]:
90
+ spine.set_color("none") # don't draw spine
91
+ else:
92
+ raise ValueError("unknown spine location: %s" % loc)
93
+
94
+ # ax.xaxis.get_major_formatter().set_useOffset(False)
95
+
96
+ # turn off ticks where there is no spine
97
+ if "top" in where and "bottom" not in where:
98
+ ax.xaxis.set_ticks_position("top")
99
+ if skip_x > 1:
100
+ ax.set_xticks(ax.get_xticks()[::skip_x])
101
+ elif "bottom" in where:
102
+ ax.xaxis.set_ticks_position("bottom")
103
+ if skip_x > 1:
104
+ ax.set_xticks(ax.get_xticks()[::skip_x])
105
+ else:
106
+ ax.xaxis.set_ticks_position("none")
107
+ ax.xaxis.set_ticklabels([])
108
+ if "right" in where and "left" not in where:
109
+ ax.yaxis.set_ticks_position("right")
110
+ if skip_y > 1:
111
+ ax.set_yticks(ax.get_yticks()[::skip_y])
112
+ elif "left" in where:
113
+ ax.yaxis.set_ticks_position("left")
114
+ if skip_y > 1:
115
+ ax.set_yticks(ax.get_yticks()[::skip_y])
116
+ else:
117
+ ax.yaxis.set_ticks_position("none")
118
+ ax.yaxis.set_ticklabels([])
119
+
120
+ ax.patch.set_alpha(0.0)
121
+
122
+
123
+ def _get_absmax(val_matrix):
124
+ """Get value at absolute maximum in lag function array.
125
+ For an (N, N, tau)-array this comutes the lag of the absolute maximum
126
+ along the tau-axis and stores the (positive or negative) value in
127
+ the (N,N)-array absmax."""
128
+
129
+ absmax_indices = np.abs(val_matrix).argmax(axis=2)
130
+ i, j = np.indices(val_matrix.shape[:2])
131
+
132
+ return val_matrix[i, j, absmax_indices]
133
+
134
+
135
+ def _add_timeseries(
136
+ dataframe,
137
+ fig_axes,
138
+ grey_masked_samples=False,
139
+ show_meanline=False,
140
+ data_linewidth=1.0,
141
+ color="black",
142
+ alpha=1.,
143
+ grey_alpha=1.0,
144
+ selected_dataset=0,
145
+ selected_variables=None,
146
+ ):
147
+ """Adds a time series plot to an axis.
148
+ Plot of dataseries is added to axis. Allows for proper visualization of
149
+ masked data.
150
+
151
+ Parameters
152
+ ----------
153
+ fig : figure instance
154
+ Figure instance.
155
+ axes : axis instance
156
+ Either gridded axis object or single axis instance.
157
+ grey_masked_samples : bool, optional (default: False)
158
+ Whether to mark masked samples by grey fills ('fill') or grey data
159
+ ('data').
160
+ show_meanline : bool
161
+ Show mean of data as horizontal line.
162
+ data_linewidth : float, optional (default: 1.)
163
+ Linewidth.
164
+ color : str, optional (default: black)
165
+ Line color.
166
+ alpha : float
167
+ Alpha opacity.
168
+ grey_alpha : float, optional (default: 1.)
169
+ Opacity of fill_between.
170
+ selected_dataset : int, optional (default: 0)
171
+ In case of multiple datasets in dataframe, plot this one.
172
+ selected_variables : list, optional (default: None)
173
+ List of variables which to plot.
174
+ """
175
+ fig, axes = fig_axes
176
+
177
+ # Read in all attributes from dataframe
178
+ data = dataframe.values[selected_dataset]
179
+ if dataframe.mask is not None:
180
+ mask = dataframe.mask[selected_dataset]
181
+ else:
182
+ mask = None
183
+
184
+ missing_flag = dataframe.missing_flag
185
+ time = dataframe.datatime[selected_dataset]
186
+ T = len(time)
187
+
188
+ if selected_variables is None:
189
+ selected_variables = list(range(dataframe.N))
190
+
191
+ nb_components = sum([len(dataframe.vector_vars[var]) for var in selected_variables])
192
+
193
+ for j in range(nb_components):
194
+
195
+ ax = axes[j]
196
+ dataseries = data[:, j]
197
+
198
+ if missing_flag is not None:
199
+ dataseries_nomissing = np.ma.masked_where(
200
+ dataseries == missing_flag, dataseries
201
+ )
202
+ else:
203
+ dataseries_nomissing = np.ma.masked_where(
204
+ np.zeros(dataseries.shape), dataseries
205
+ )
206
+
207
+
208
+ if mask is not None:
209
+ maskseries = mask[:, j]
210
+
211
+ maskdata = np.ma.masked_where(maskseries, dataseries_nomissing)
212
+
213
+ if grey_masked_samples == "fill":
214
+ ax.fill_between(
215
+ time,
216
+ maskdata.min(),
217
+ maskdata.max(),
218
+ where=maskseries,
219
+ color="grey",
220
+ interpolate=True,
221
+ linewidth=0.0,
222
+ alpha=grey_alpha,
223
+ )
224
+ elif grey_masked_samples == "data":
225
+ ax.plot(
226
+ time,
227
+ dataseries_nomissing,
228
+ color="grey",
229
+ marker=".",
230
+ markersize=data_linewidth,
231
+ linewidth=data_linewidth,
232
+ clip_on=False,
233
+ alpha=grey_alpha,
234
+ )
235
+ if show_meanline:
236
+ ax.plot(time, maskdata.mean() * np.ones(T), lw=data_linewidth / 2., color=color)
237
+
238
+ ax.plot(
239
+ time,
240
+ maskdata,
241
+ color=color,
242
+ linewidth=data_linewidth,
243
+ marker=".",
244
+ markersize=data_linewidth,
245
+ clip_on=False,
246
+ alpha=alpha,
247
+ )
248
+ else:
249
+ if show_meanline:
250
+ ax.plot(time, dataseries_nomissing.mean() * np.ones(T), lw=data_linewidth / 2., color=color)
251
+
252
+ ax.plot(
253
+ time,
254
+ dataseries_nomissing,
255
+ color=color,
256
+ linewidth=data_linewidth,
257
+ clip_on=False,
258
+ alpha=alpha,
259
+ )
260
+
261
+
262
+ def plot_timeseries(
263
+ dataframe=None,
264
+ save_name=None,
265
+ fig_axes=None,
266
+ figsize=None,
267
+ var_units=None,
268
+ time_label="",
269
+ grey_masked_samples=False,
270
+ show_meanline=False,
271
+ data_linewidth=1.0,
272
+ skip_ticks_data_x=1,
273
+ skip_ticks_data_y=1,
274
+ label_fontsize=10,
275
+ color='black',
276
+ alpha=1.,
277
+ tick_label_size=6,
278
+ selected_dataset=0,
279
+ adjust_plot=True,
280
+ selected_variables=None,
281
+ ):
282
+ """Create and save figure of stacked panels with time series.
283
+
284
+ Parameters
285
+ ----------
286
+ dataframe : data object, optional
287
+ This is the Tigramite dataframe object. It has the attributes
288
+ dataframe.values yielding a np array of shape (observations T,
289
+ variables N) and optionally a mask of the same shape.
290
+ save_name : str, optional (default: None)
291
+ Name of figure file to save figure. If None, figure is shown in window.
292
+ fig_axes : subplots instance, optional (default: None)
293
+ Figure and axes instance. If None they are created as
294
+ fig, axes = pyplot.subplots(N,...)
295
+ figsize : tuple of floats, optional (default: None)
296
+ Figure size if new figure is created. If None, default pyplot figsize
297
+ is used.
298
+ var_units : list of str, optional (default: None)
299
+ Units of variables.
300
+ time_label : str, optional (default: '')
301
+ Label of time axis.
302
+ grey_masked_samples : bool, optional (default: False)
303
+ Whether to mark masked samples by grey fills ('fill') or grey data
304
+ ('data').
305
+ show_meanline : bool, optional (default: False)
306
+ Whether to plot a horizontal line at the mean.
307
+ data_linewidth : float, optional (default: 1.)
308
+ Linewidth.
309
+ skip_ticks_data_x : int, optional (default: 1)
310
+ Skip every other tickmark.
311
+ skip_ticks_data_y : int, optional (default: 2)
312
+ Skip every other tickmark.
313
+ label_fontsize : int, optional (default: 10)
314
+ Fontsize of variable labels.
315
+ tick_label_size : int, optional (default: 6)
316
+ Fontsize of tick labels.
317
+ color : str, optional (default: black)
318
+ Line color.
319
+ alpha : float
320
+ Alpha opacity.
321
+ selected_dataset : int, optional (default: 0)
322
+ In case of multiple datasets in dataframe, plot this one.
323
+ selected_variables : list, optional (default: None)
324
+ List of variables which to plot.
325
+ """
326
+
327
+ var_names = dataframe.var_names
328
+ time = dataframe.datatime[selected_dataset]
329
+
330
+ N = dataframe.N
331
+
332
+ if selected_variables is None:
333
+ selected_variables = list(range(N))
334
+
335
+ nb_components_per_var = [len(dataframe.vector_vars[var]) for var in selected_variables]
336
+ N_index = [sum(nb_components_per_var[:i]) for i, el in enumerate(nb_components_per_var)]
337
+ nb_components = sum(nb_components_per_var)
338
+
339
+ if var_units is None:
340
+ var_units = ["" for i in range(N)]
341
+
342
+ if fig_axes is None:
343
+ fig, axes = pyplot.subplots(nb_components, sharex=True, figsize=figsize)
344
+ else:
345
+ fig, axes = fig_axes
346
+
347
+ if adjust_plot:
348
+ for i in range(nb_components):
349
+
350
+ ax = axes[i]
351
+
352
+ if (i == nb_components - 1):
353
+ _make_nice_axes(
354
+ ax, where=["left", "bottom"], skip=(skip_ticks_data_x, skip_ticks_data_y)
355
+ )
356
+ ax.set_xlabel(r"%s" % time_label, fontsize=label_fontsize)
357
+ else:
358
+ _make_nice_axes(ax, where=["left"], skip=(skip_ticks_data_x, skip_ticks_data_y))
359
+ # ax.get_xaxis().get_major_formatter().set_useOffset(False)
360
+
361
+ ax.xaxis.set_major_formatter(FormatStrFormatter("%.0f"))
362
+ ax.label_outer()
363
+
364
+ ax.set_xlim(time[0], time[-1])
365
+
366
+ # trans = transforms.blended_transform_factory(fig.transFigure, ax.transAxes)
367
+ if i in N_index:
368
+ if var_units[N_index.index(i)]:
369
+ ax.set_ylabel(r"%s [%s]" % (var_names[N_index.index(i)], var_units[N_index.index(i)]),
370
+ fontsize=label_fontsize)
371
+ else:
372
+ ax.set_ylabel(r"%s" % (var_names[N_index.index(i)]), fontsize=label_fontsize)
373
+
374
+ ax.tick_params(axis='both', which='major', labelsize=tick_label_size)
375
+ # ax.tick_params(axis='both', which='minor', labelsize=tick_label_size)
376
+
377
+ _add_timeseries(
378
+ dataframe=dataframe,
379
+ fig_axes=(fig, axes),
380
+ grey_masked_samples=grey_masked_samples,
381
+ show_meanline=show_meanline,
382
+ data_linewidth=data_linewidth,
383
+ color=color,
384
+ selected_dataset=selected_dataset,
385
+ alpha=alpha,
386
+ selected_variables=selected_variables
387
+ )
388
+
389
+ if adjust_plot:
390
+ fig.subplots_adjust(bottom=0.15, top=0.9, left=0.15, right=0.95, hspace=0.3)
391
+ pyplot.tight_layout()
392
+
393
+ if save_name is not None:
394
+ fig.savefig(save_name)
395
+
396
+ return fig, axes
397
+
398
+
399
def plot_lagfuncs(val_matrix,
                  name=None,
                  setup_args=None,
                  add_lagfunc_args=None):
    """Wrapper helper function to plot lag functions.

    Sets up the matrix object and plots the lagfunction, see parameters in
    setup_matrix and add_lagfuncs.

    Parameters
    ----------
    val_matrix : array_like
        Matrix of shape (N, N, tau_max+1) containing test statistic values.
    name : str, optional (default: None)
        File name. If None, figure is shown in window.
    setup_args : dict, optional (default: None)
        Arguments for setting up the lag function matrix, see doc of
        setup_matrix.
    add_lagfunc_args : dict, optional (default: None)
        Arguments for adding a lag function matrix, see doc of add_lagfuncs.

    Returns
    -------
    matrix : object
        Further lag functions can be overlaid using the
        matrix.add_lagfuncs(val_matrix) function.

    Raises
    ------
    ValueError
        If val_matrix is not three-dimensional with equal first two sizes.
    """
    # asanyarray accepts nested lists while keeping ndarray subclasses.
    val_matrix = np.asanyarray(val_matrix)
    if val_matrix.ndim != 3 or val_matrix.shape[0] != val_matrix.shape[1]:
        raise ValueError(
            "val_matrix must have shape (N, N, tau_max + 1), got %s"
            % (val_matrix.shape,)
        )

    # None defaults avoid shared mutable default dicts.
    if setup_args is None:
        setup_args = {}
    if add_lagfunc_args is None:
        add_lagfunc_args = {}

    N, _, tau_max_plusone = val_matrix.shape
    tau_max = tau_max_plusone - 1

    matrix = setup_matrix(N=N, tau_max=tau_max, **setup_args)
    matrix.add_lagfuncs(val_matrix=val_matrix, **add_lagfunc_args)
    matrix.savefig(name=name)

    return matrix
434
+
435
+
436
class setup_matrix:
    """Create matrix of lag function panels.

    Class to setup figure object. The function add_lagfuncs(...) allows to plot
    the val_matrix of shape (N, N, tau_max+1). Multiple lagfunctions can be
    overlaid for comparison.

    Parameters
    ----------
    N : int
        Number of variables
    tau_max : int
        Maximum time lag.
    var_names : list, optional (default: None)
        List of variable names. If None, range(N) is used.
    figsize : tuple of floats, optional (default: None)
        Figure size if new figure is created. If None, default pyplot figsize
        is used.
    minimum : float, optional (default: -1.)
        Lower y-axis limit.
    maximum : float, optional (default: 1.)
        Upper y-axis limit.
    label_space_left : float, optional (default: 0.1)
        Fraction of horizontal figure space to allocate left of plot for labels.
    label_space_top : float, optional (default: 0.05)
        Fraction of vertical figure space to allocate top of plot for labels.
    label_rotation_left : float, optional (default: 0)
        Rotation of variable labels. Set to 90 for vertical labels on y-axis.
    legend_width : float, optional (default: 0.15)
        Fraction of horizontal figure space to allocate right of plot for
        legend.
    legend_fontsize : int, optional (default: 10)
        Fontsize of legend entries.
    x_base : float, optional (default: 1.)
        x-tick intervals to show. If None, the default x-ticks are kept.
    y_base : float, optional (default: .5)
        y-tick intervals to show; the y-limits are then rounded outward to
        multiples of y_base. If None, the default y-ticks are kept and the
        y-limits are exactly (minimum, maximum).
    tick_label_size : int, optional (default: 6)
        Fontsize of tick labels.
    plot_gridlines : bool, optional (default: False)
        Whether to show a grid.
    lag_units : str, optional (default: '')
        Units shown in the lag-axis label.
    lag_array : array, optional (default: None)
        Optional specification of lags overwriting np.arange(0, tau_max+1)
        as x-tick labels. Must have length tau_max+1 and requires an integer
        x_base.
    label_fontsize : int, optional (default: 10)
        Fontsize of variable labels.
    """

    def __init__(
        self,
        N,
        tau_max,
        var_names=None,
        figsize=None,
        minimum=-1,
        maximum=1,
        label_space_left=0.1,
        label_space_top=0.05,
        label_rotation_left=0,
        legend_width=0.15,
        legend_fontsize=10,
        x_base=1.0,
        y_base=0.5,
        tick_label_size=6,
        plot_gridlines=False,
        lag_units="",
        lag_array=None,
        label_fontsize=10,
    ):

        self.tau_max = tau_max

        self.labels = []
        self.lag_units = lag_units
        self.lag_array = lag_array
        # Step used in savefig() when subsampling lag_array for tick labels.
        self.x_base = 1 if x_base is None else x_base

        self.legend_width = legend_width
        self.legend_fontsize = legend_fontsize

        self.label_space_left = label_space_left
        self.label_space_top = label_space_top
        self.label_fontsize = label_fontsize

        self.fig = pyplot.figure(figsize=figsize)

        self.axes_dict = {}

        if var_names is None:
            var_names = range(N)

        if y_base is not None:
            # y-range rounded outward to multiples of y_base.
            y_low = _myround(minimum, y_base, "down")
            y_high = _myround(maximum, y_base, "up")

        plot_index = 1
        for i in range(N):
            for j in range(N):
                ax = self.fig.add_subplot(N, N, plot_index)
                self.axes_dict[(i, j)] = ax
                # Plot process labels
                if j == 0:
                    trans = transforms.blended_transform_factory(
                        self.fig.transFigure, ax.transAxes
                    )
                    ax.text(
                        0.01,
                        0.5,
                        "%s" % str(var_names[i]),
                        fontsize=label_fontsize,
                        horizontalalignment="left",
                        verticalalignment="center",
                        rotation=label_rotation_left,
                        transform=trans,
                    )
                if i == 0:
                    trans = transforms.blended_transform_factory(
                        ax.transAxes, self.fig.transFigure
                    )
                    ax.text(
                        0.5,
                        0.99,
                        r"${\to}$ " + "%s" % str(var_names[j]),
                        fontsize=label_fontsize,
                        horizontalalignment="center",
                        verticalalignment="top",
                        transform=trans,
                    )

                # Make nice axis
                _make_nice_axes(ax, where=["left", "bottom"], skip=(1, 1))
                if x_base is not None:
                    ax.xaxis.set_major_locator(
                        ticker.FixedLocator(np.arange(0, self.tau_max + 1, x_base))
                    )
                    # Minor x-ticks at half steps only when those are integers.
                    if x_base / 2.0 % 1 == 0:
                        ax.xaxis.set_minor_locator(
                            ticker.FixedLocator(
                                np.arange(0, self.tau_max + 1, x_base / 2.0)
                            )
                        )
                if y_base is not None:
                    ax.yaxis.set_major_locator(
                        ticker.FixedLocator(np.arange(y_low, y_high + y_base, y_base))
                    )
                    ax.yaxis.set_minor_locator(
                        ticker.FixedLocator(np.arange(y_low, y_high + y_base, y_base / 2.0))
                    )
                    ax.set_ylim(y_low, y_high)
                else:
                    ax.set_ylim(minimum, maximum)

                if j != 0:
                    ax.get_yaxis().set_ticklabels([])
                ax.set_xlim(0, self.tau_max)
                if plot_gridlines:
                    ax.grid(
                        True,
                        which="major",
                        color="black",
                        linestyle="dotted",
                        dashes=(1, 1),
                        linewidth=0.05,
                        zorder=-5,
                    )
                ax.tick_params(axis='both', which='major', labelsize=tick_label_size)
                ax.tick_params(axis='both', which='minor', labelsize=tick_label_size)

                plot_index += 1

    def add_lagfuncs(
        self,
        val_matrix,
        sig_thres=None,
        conf_matrix=None,
        color="black",
        label=None,
        two_sided_thres=True,
        marker=".",
        markersize=5,
        alpha=1.0,
    ):
        """Add lag function plot from val_matrix array.

        Diagonal panels (i == j) omit lag 0.

        Parameters
        ----------
        val_matrix : array_like
            Matrix of shape (N, N, tau_max+1) containing test statistic values.
        sig_thres : array-like, optional (default: None)
            Matrix of significance thresholds. Must be of same shape as
            val_matrix.
        conf_matrix : array-like, optional (default: None)
            Matrix of shape (N, N, tau_max+1, 2) containing confidence bounds.
        color : str, optional (default: 'black')
            Line color.
        label : str
            Test statistic label; if given, it is added to the legend drawn
            by savefig.
        two_sided_thres : bool, optional (default: True)
            Whether to draw sig_thres for pos. and neg. values.
        marker : matplotlib marker symbol, optional (default: '.')
            Marker.
        markersize : int, optional (default: 5)
            Marker size.
        alpha : float, optional (default: 1.)
            Opacity.
        """

        if label is not None:
            self.labels.append((label, color, marker, markersize, alpha))

        for (i, j), ax in self.axes_dict.items():
            start = int(i == j)  # skip lag 0 on the diagonal
            lags = range(start, self.tau_max + 1)

            maskedres = np.copy(val_matrix[i, j, start:])
            ax.plot(
                lags,
                maskedres,
                linestyle="",
                color=color,
                marker=marker,
                markersize=markersize,
                alpha=alpha,
                clip_on=False,
            )
            if conf_matrix is not None:
                maskedconfres = np.copy(conf_matrix[i, j, start:])
                # Lower bound (column 0), then upper bound (column 1).
                for bound in (0, 1):
                    ax.plot(
                        lags,
                        maskedconfres[:, bound],
                        linestyle="",
                        color=color,
                        marker="_",
                        markersize=markersize - 2,
                        alpha=alpha,
                        clip_on=False,
                    )

            # Zero reference line.
            ax.plot(
                lags,
                np.zeros(self.tau_max + 1 - start),
                color="black",
                linestyle="dotted",
                linewidth=0.1,
            )

            if sig_thres is not None:
                maskedsigres = sig_thres[i, j, start:]

                ax.plot(
                    lags,
                    maskedsigres,
                    color=color,
                    linestyle="solid",
                    linewidth=0.1,
                    alpha=alpha,
                )
                if two_sided_thres:
                    ax.plot(
                        lags,
                        -sig_thres[i, j, start:],
                        color=color,
                        linestyle="solid",
                        linewidth=0.1,
                        alpha=alpha,
                    )

    def savefig(self, name=None):
        """Finalize the layout, then save or show the matrix figure.

        Adds a legend for labelled lag functions (if any), adjusts subplot
        spacing, writes the lag-axis label and, if lag_array was given,
        relabels the x-ticks with its values.

        Parameters
        ----------
        name : str, optional (default: None)
            File name. If None, figure is shown in window.

        Raises
        ------
        ValueError
            If lag_array does not have length tau_max + 1, or is combined
            with a non-integer x_base.
        """

        if self.labels:
            # Trick to plot legend: an invisible full-figure axis carries
            # empty proxy lines with the recorded styles.
            axlegend = self.fig.add_subplot(111, frameon=False)
            for side in ("left", "right", "bottom", "top"):
                axlegend.spines[side].set_color("none")
            axlegend.set_xticks([])
            axlegend.set_yticks([])

            for label, color, marker, markersize, alpha in self.labels:
                axlegend.plot(
                    [],
                    [],
                    linestyle="",
                    color=color,
                    marker=marker,
                    markersize=markersize,
                    label=label,
                    alpha=alpha,
                )
            axlegend.legend(
                loc="upper left",
                ncol=1,
                bbox_to_anchor=(1.05, 0.0, 0.1, 1.0),
                borderaxespad=0,
                fontsize=self.legend_fontsize,
            ).draw_frame(False)

            right = 1.0 - self.legend_width
            xlabel_x = 0.5
        else:
            right = 0.95
            xlabel_x = 0.55

        self.fig.subplots_adjust(
            left=self.label_space_left,
            right=right,
            top=1.0 - self.label_space_top,
            hspace=0.35,
            wspace=0.35,
        )
        # Draw on this matrix's figure; pyplot.figtext would target the
        # current figure, which may be a different one.
        self.fig.text(
            xlabel_x,
            0.01,
            r"lag $\tau$ [%s]" % self.lag_units,
            horizontalalignment="center",
            fontsize=self.label_fontsize,
        )

        if self.lag_array is not None:
            lag_array = np.asarray(self.lag_array)
            if lag_array.shape != (self.tau_max + 1,):
                raise ValueError(
                    "lag_array must have shape (%d,), got %s"
                    % (self.tau_max + 1, lag_array.shape)
                )
            # Slicing needs an integer step; the default x_base is the float 1.0.
            step = int(self.x_base)
            if step != self.x_base or step < 1:
                raise ValueError(
                    "lag_array requires an integer x_base >= 1, got %r" % (self.x_base,)
                )
            for ax in self.axes_dict.values():
                ax.set_xticklabels(lag_array[::step])

        if name is not None:
            self.fig.savefig(name)
        else:
            pyplot.show()
812
+
813
+
814
+
815
def plot_scatterplots(dataframe,
                      name=None,
                      setup_args=None,
                      add_scatterplot_args=None,
                      selected_dataset=0):
    """Wrapper helper function to plot scatter plots.

    Sets up the matrix object and plots the scatter plots, see parameters in
    setup_scatter_matrix and add_scatterplot.

    Parameters
    ----------
    dataframe : data object
        Tigramite dataframe object. It must have the attributes dataframe.values
        yielding a numpy array of shape (observations T, variables N) and
        optionally a mask of the same shape and a missing values flag.
    name : str, optional (default: None)
        File name, passed on to the matrix's adjustfig method.
    setup_args : dict, optional (default: None)
        Arguments for setting up the scatter plot matrix, see doc of
        setup_scatter_matrix. A "var_names" entry overrides
        dataframe.var_names.
    add_scatterplot_args : dict, optional (default: None)
        Arguments for adding a scatter plot matrix.
    selected_dataset : int, optional (default: 0)
        In case of multiple datasets in dataframe, plot this one.

    Returns
    -------
    matrix : object
        Further scatter plot can be overlaid using the
        matrix.add_scatterplot function.
    """
    # Copy so the caller's dict is never modified; None avoids a shared
    # mutable default.
    setup_kwargs = {} if setup_args is None else dict(setup_args)
    # Default panel labels to the dataframe's names, but let an explicit
    # setup_args["var_names"] win instead of raising a duplicate-keyword error.
    setup_kwargs.setdefault("var_names", dataframe.var_names)
    scatter_kwargs = {} if add_scatterplot_args is None else add_scatterplot_args

    matrix = setup_scatter_matrix(N=dataframe.N, **setup_kwargs)
    matrix.add_scatterplot(dataframe=dataframe, selected_dataset=selected_dataset, **scatter_kwargs)
    matrix.adjustfig(name=name)

    return matrix
855
+
856
+
857
+ class setup_scatter_matrix:
858
+ """Create matrix of scatter plot panels.
859
+ Class to setup figure object. The function add_scatterplot allows to plot
860
+ scatterplots of variables in the dataframe. Multiple scatter plots can be
861
+ overlaid for comparison.
862
+
863
+ Parameters
864
+ ----------
865
+ N : int
866
+ Number of variables
867
+ var_names : list, optional (default: None)
868
+ List of variable names. If None, range(N) is used.
869
+ figsize : tuple of floats, optional (default: None)
870
+ Figure size if new figure is created. If None, default pyplot figsize
871
+ is used.
872
+ label_space_left : float, optional (default: 0.1)
873
+ Fraction of horizontal figure space to allocate left of plot for labels.
874
+ label_space_top : float, optional (default: 0.05)
875
+ Fraction of vertical figure space to allocate top of plot for labels.
876
+ label_rotation_left : float, optional (default: 0)
877
+ Rotation of variable labels. Set to 90 for vertical labels on y-axis.
878
+ legend_width : float, optional (default: 0.15)
879
+ Fraction of horizontal figure space to allocate right of plot for
880
+ legend.
881
+ tick_label_size : int, optional (default: 6)
882
+ Fontsize of tick labels.
883
+ plot_gridlines : bool, optional (default: False)
884
+ Whether to show a grid.
885
+ label_fontsize : int, optional (default: 10)
886
+ Fontsize of variable labels.
887
+ """
888
+
889
+ def __init__(
890
+ self,
891
+ N,
892
+ var_names=None,
893
+ figsize=None,
894
+ label_space_left=0.1,
895
+ label_space_top=0.05,
896
+ label_rotation_left=0,
897
+ legend_width=0.15,
898
+ legend_fontsize=10,
899
+ plot_gridlines=False,
900
+ tick_label_size=6,
901
+ label_fontsize=10,
902
+ ):
903
+
904
+ self.labels = []
905
+
906
+ self.legend_width = legend_width
907
+ self.legend_fontsize = legend_fontsize
908
+
909
+ self.label_space_left = label_space_left
910
+ self.label_space_top = label_space_top
911
+ self.label_fontsize = label_fontsize
912
+
913
+ self.fig = pyplot.figure(figsize=figsize)
914
+
915
+ self.axes_dict = {}
916
+
917
+ if var_names is None:
918
+ var_names = range(N)
919
+
920
+ plot_index = 1
921
+ for i in range(N):
922
+ for j in range(N):
923
+ self.axes_dict[(i, j)] = self.fig.add_subplot(N, N, plot_index, axes_class=Axes)
924
+ # Plot process labels
925
+ if j == 0:
926
+ trans = transforms.blended_transform_factory(
927
+ self.fig.transFigure, self.axes_dict[(i, j)].transAxes
928
+ )
929
+ self.axes_dict[(i, j)].text(
930
+ 0.01,
931
+ 0.5,
932
+ "%s" % str(var_names[i]),
933
+ fontsize=label_fontsize,
934
+ horizontalalignment="left",
935
+ verticalalignment="center",
936
+ rotation=label_rotation_left,
937
+ transform=trans,
938
+ )
939
+ if i == 0:
940
+ trans = transforms.blended_transform_factory(
941
+ self.axes_dict[(i, j)].transAxes, self.fig.transFigure
942
+ )
943
+ self.axes_dict[(i, j)].text(
944
+ 0.5,
945
+ 0.99,
946
+ r"${\to}$ " + "%s" % str(var_names[j]),
947
+ fontsize=label_fontsize,
948
+ horizontalalignment="center",
949
+ verticalalignment="top",
950
+ transform=trans,
951
+ )
952
+
953
+ self.axes_dict[(i, j)].axis["right"].set_visible(False)
954
+ self.axes_dict[(i, j)].axis["top"].set_visible(False)
955
+
956
+ if j != 0:
957
+ self.axes_dict[(i, j)].get_yaxis().set_ticklabels([])
958
+ if i != N - 1:
959
+ self.axes_dict[(i, j)].get_xaxis().set_ticklabels([])
960
+
961
+ if plot_gridlines:
962
+ self.axes_dict[(i, j)].grid(
963
+ True,
964
+ which="major",
965
+ color="black",
966
+ linestyle="dotted",
967
+ dashes=(1, 1),
968
+ linewidth=0.05,
969
+ zorder=-5,
970
+ )
971
+ self.axes_dict[(i, j)].tick_params(axis='both', which='major', labelsize=tick_label_size)
972
+
973
+ plot_index += 1
974
+
975
def add_scatterplot(
    self,
    dataframe,
    matrix_lags=None,
    color="black",
    label=None,
    marker=".",
    markersize=5,
    alpha=.2,
    selected_dataset=0,
):
    """Add scatter plot.

    Parameters
    ----------
    dataframe : data object
        Tigramite dataframe object. It must have the attributes dataframe.values
        yielding a numpy array of shape (observations T, variables N) and
        optionally a mask of the same shape and a missing values flag.
    matrix_lags : array-like, optional (default: None)
        Lags to use in scatter plots. Either None or a non-negative array-like
        (ndarray or nested list) of shape (N, N). Then the entry
        matrix_lags[i, j] = tau will depict the scatter plot of time series
        (i, -tau) vs (j, 0). If None, tau = 0 for i != j and for i = j tau = 1.
    color : str, optional (default: 'black')
        Marker color.
    label : str, optional (default: None)
        Label of this scatter plot in the figure-level legend. If None, no
        legend entry is recorded for this call.
    marker : matplotlib marker symbol, optional (default: '.')
        Marker.
    markersize : int, optional (default: 5)
        Marker size.
    alpha : float, optional (default: 0.2)
        Opacity.
    selected_dataset : int, optional (default: 0)
        In case of multiple datasets in dataframe, plot this one.

    Raises
    ------
    ValueError
        If matrix_lags contains negative entries.
    """
    if matrix_lags is not None:
        # Convert so that nested lists work for both the sign check and the
        # [i, j] indexing below.
        matrix_lags = np.asarray(matrix_lags)
        if np.any(matrix_lags < 0):
            raise ValueError("matrix_lags must be non-negative!")

    data = dataframe.values[selected_dataset]
    mask = None
    if dataframe.mask is not None:
        mask = dataframe.mask[selected_dataset]

    T = data.shape[0]

    if label is not None:
        self.labels.append((label, color, marker, markersize, alpha))

    for (i, j), ax in self.axes_dict.items():
        if matrix_lags is None:
            # Diagonal panels show the lag-1 autodependency by default.
            lag = 1 if i == j else 0
        else:
            lag = int(matrix_lags[i, j])

        # astype(float) copies the slice and makes it NaN-assignable even
        # when the underlying data has an integer dtype.
        x = data[:T - lag, i].astype(float)
        y = data[lag:, j].astype(float)
        if mask is not None:
            x[mask[:T - lag, i] == 1] = np.nan
            y[mask[lag:, j] == 1] = np.nan

        # Plot (y, x) so that the x-axis shows the column variable j and the
        # y-axis the lagged row variable i.
        ax.scatter(
            y,
            x,
            color=color,
            marker=marker,
            s=markersize,
            alpha=alpha,
            clip_on=False,
            label=r"$\tau{=}%d$" % lag,
        )
+
1056
+
1057
def adjustfig(self, name=None):
    """Finalize the scatter matrix layout, then save or show the figure.

    Every panel receives its own frameless lag legend. If at least one
    scatter plot was added with a label, a figure-level legend listing those
    labels is placed in the space reserved on the right (``legend_width``).

    Parameters
    ----------
    name : str, optional (default: None)
        File name. If None, figure is shown in window.
    """
    # NOTE(review): only scatter calls made with a label contribute a colour
    # here, while every scatter call adds a lag entry to each panel legend;
    # mixing labelled and unlabelled calls may colour panel legend texts
    # inconsistently.
    label_colors = [entry[1] for entry in self.labels]

    for panel in self.axes_dict.values():
        panel_legend = panel.legend(
            ncol=1,
            fontsize=self.legend_fontsize - 2,
            labelcolor=label_colors,
        )
        panel_legend.draw_frame(False)

    if self.labels:
        # Frameless full-figure axes that only hosts the global legend.
        legend_ax = self.fig.add_subplot(111, frameon=False)
        for side in ("left", "right", "bottom", "top"):
            legend_ax.spines[side].set_color("none")
        legend_ax.set_xticks([])
        legend_ax.set_yticks([])

        # Empty proxy lines: one legend entry per labelled scatter plot.
        for entry in self.labels:
            entry_label, entry_color, entry_marker, entry_size, entry_alpha = entry[:5]
            legend_ax.plot(
                [],
                [],
                linestyle="",
                color=entry_color,
                marker=entry_marker,
                markersize=entry_size,
                label=entry_label,
                alpha=entry_alpha,
            )
        global_legend = legend_ax.legend(
            loc="upper left",
            ncol=1,
            bbox_to_anchor=(1.05, 0.0, 0.1, 1.0),
            borderaxespad=0,
            fontsize=self.legend_fontsize,
        )
        global_legend.draw_frame(False)

        spacing = dict(
            bottom=0.05,
            left=self.label_space_left,
            right=1.0 - self.legend_width,
            top=1.0 - self.label_space_top,
            hspace=0.5,
            wspace=0.35,
        )
    else:
        spacing = dict(
            left=self.label_space_left,
            bottom=0.05,
            right=0.95,
            top=1.0 - self.label_space_top,
            hspace=0.35,
            wspace=0.35,
        )
    self.fig.subplots_adjust(**spacing)

    if name is None:
        pyplot.show()
    else:
        self.fig.savefig(name)
1141
+
1142
+
1143
def plot_densityplots(dataframe,
                      name=None,
                      setup_args=None,
                      add_densityplot_args=None,
                      selected_dataset=0,
                      show_marginal_densities_on_diagonal=True):
    """Wrapper helper function to plot density plots.
    Sets up the matrix object and plots the density plots, see parameters in
    setup_density_matrix and add_densityplot.

    The diagonal shows the marginal densities.

    Requires seaborn.

    Parameters
    ----------
    dataframe : data object
        Tigramite dataframe object. It must have the attributes dataframe.values
        yielding a numpy array of shape (observations T, variables N) and
        optionally a mask of the same shape and a missing values flag.
    name : str, optional (default: None)
        File name. If None, figure is shown in window.
    setup_args : dict, optional (default: None)
        Arguments for setting up the density plot matrix, see doc of
        setup_density_matrix. None means no extra arguments.
    add_densityplot_args : dict, optional (default: None)
        Arguments for adding a density plot matrix, see doc of
        setup_density_matrix.add_densityplot. None means no extra arguments.
    selected_dataset : int, optional (default: 0)
        In case of multiple datasets in dataframe, plot this one.
    show_marginal_densities_on_diagonal : bool, optional (default: True)
        Flag to show marginal densities on the diagonal of the density plots

    Returns
    -------
    matrix : object
        Further density plots can be overlaid using the
        matrix.add_densityplot function.
    """
    # None instead of a shared mutable {} default.
    if setup_args is None:
        setup_args = {}
    if add_densityplot_args is None:
        add_densityplot_args = {}

    matrix = setup_density_matrix(N=dataframe.N, var_names=dataframe.var_names,
                                  **setup_args)
    matrix.add_densityplot(
        dataframe=dataframe,
        selected_dataset=selected_dataset,
        show_marginal_densities_on_diagonal=show_marginal_densities_on_diagonal,
        **add_densityplot_args,
    )
    matrix.adjustfig(name=name)

    return matrix
1191
+
1192
+
1193
class setup_density_matrix:
    """Create matrix of density plot panels.
    Class to setup figure object. The function add_densityplot allows to plot
    density plots of variables in the dataframe.

    Further density plots can be overlaid using the matrix.add_densityplot
    function.

    Parameters
    ----------
    N : int
        Number of variables
    var_names : list, optional (default: None)
        List of variable names. If None, range(N) is used.
    figsize : tuple of floats, optional (default: None)
        Figure size if new figure is created. If None, default pyplot figsize
        is used.
    label_space_left : float, optional (default: 0.15)
        Fraction of horizontal figure space to allocate left of plot for labels.
    label_space_top : float, optional (default: 0.05)
        Fraction of vertical figure space to allocate top of plot for labels.
    label_rotation_left : float, optional (default: 0)
        Rotation of variable labels. Set to 90 for vertical labels on y-axis.
    legend_width : float, optional (default: 0.15)
        Fraction of horizontal figure space to allocate right of plot for
        legend.
    legend_fontsize : int, optional (default: 10)
        Fontsize of the figure-level legend; per-panel legends use
        legend_fontsize - 2.
    tick_label_size : int, optional (default: 6)
        Fontsize of tick labels.
    plot_gridlines : bool, optional (default: False)
        Whether to show a grid.
    label_fontsize : int, optional (default: 10)
        Fontsize of variable labels.
    """

    def __init__(
        self,
        N,
        var_names=None,
        figsize=None,
        label_space_left=0.15,
        label_space_top=0.05,
        label_rotation_left=0,
        legend_width=0.15,
        legend_fontsize=10,
        tick_label_size=6,
        plot_gridlines=False,
        label_fontsize=10,
    ):

        # One (label, label_color) tuple per add_densityplot call.
        self.labels = []
        # Lag matrix of the most recent add_densityplot call, read by
        # adjustfig. Initialised here so adjustfig also works before any
        # density plot has been added.
        self.matrix_lags = None

        self.legend_width = legend_width
        self.legend_fontsize = legend_fontsize

        self.label_space_left = label_space_left
        self.label_space_top = label_space_top
        self.label_fontsize = label_fontsize

        self.fig = pyplot.figure(figsize=figsize)

        self.axes_dict = {}

        if var_names is None:
            var_names = range(N)

        for i in range(N):
            for j in range(N):
                ax = self.fig.add_subplot(N, N, i * N + j + 1)
                self.axes_dict[(i, j)] = ax

                # Row labels: x in figure coordinates, y in axes coordinates.
                if j == 0:
                    trans = transforms.blended_transform_factory(
                        self.fig.transFigure, ax.transAxes
                    )
                    ax.text(
                        0.01,
                        0.5,
                        "%s" % str(var_names[i]),
                        fontsize=label_fontsize,
                        horizontalalignment="left",
                        verticalalignment="center",
                        rotation=label_rotation_left,
                        transform=trans,
                    )
                # Column labels: x in axes coordinates, y in figure coordinates.
                if i == 0:
                    trans = transforms.blended_transform_factory(
                        ax.transAxes, self.fig.transFigure
                    )
                    ax.text(
                        0.5,
                        0.99,
                        r"${\to}$ " + "%s" % str(var_names[j]),
                        fontsize=label_fontsize,
                        horizontalalignment="center",
                        verticalalignment="top",
                        transform=trans,
                    )

                # Diagonal panels pass only "bottom" to _make_nice_axes.
                if i == j:
                    _make_nice_axes(ax, where=["bottom"], skip=(1, 1))
                else:
                    _make_nice_axes(ax, where=["left", "bottom"], skip=(1, 1))

                if plot_gridlines:
                    ax.grid(
                        True,
                        which="major",
                        color="black",
                        linestyle="dotted",
                        dashes=(1, 1),
                        linewidth=0.05,
                        zorder=-5,
                    )
                ax.tick_params(axis='both', which='major', labelsize=tick_label_size)

    def add_densityplot(
        self,
        dataframe,
        matrix_lags=None,
        label=None,
        label_color='black',
        snskdeplot_args=None,
        snskdeplot_diagonal_args=None,
        selected_dataset=0,
        show_marginal_densities_on_diagonal=True
    ):
        """Add density function plot.

        Parameters
        ----------
        dataframe : data object
            Tigramite dataframe object. It must have the attributes dataframe.values
            yielding a numpy array of shape (observations T, variables N) and
            optionally a mask of the same shape and a missing values flag.
        matrix_lags : array-like, optional (default: None)
            Lags to use in density plots. Either None or a non-negative
            array-like (ndarray or nested list) of shape (N, N). Then the
            entry matrix_lags[i, j] = tau will depict the density plot of
            time series (i, -tau) vs (j, 0). If None, tau = 0 for i != j and
            for i = j tau = 1.
        label : string, optional (default: None)
            Label of this plot.
        label_color : string, optional (default: 'black')
            Color of line created just for legend; also used for the
            marginal densities on the diagonal.
        snskdeplot_args : dict, optional (default: None)
            Optional parameters to pass to sns.kdeplot() for off-diagonal
            plots. None means {'cmap': 'Greys'}.
        snskdeplot_diagonal_args : dict, optional (default: None)
            Optional parameters to pass to sns.kdeplot() for i == j on the
            diagonal. None means {}.
        selected_dataset : int, optional (default: 0)
            In case of multiple datasets in dataframe, plot this one.
        show_marginal_densities_on_diagonal : bool, optional (default: True)
            Flag to show marginal densities on the diagonal of the density plots

        Raises
        ------
        ValueError
            If matrix_lags contains negative entries.
        """

        # Use seaborn for this one
        import seaborn as sns

        # None instead of shared mutable dict defaults.
        if snskdeplot_args is None:
            snskdeplot_args = {'cmap': 'Greys'}
        if snskdeplot_diagonal_args is None:
            snskdeplot_diagonal_args = {}

        # set seaborn style
        sns.set_style("white")

        if matrix_lags is not None:
            # Convert so that nested lists work for both the sign check and
            # the [i, j] indexing here and in adjustfig.
            matrix_lags = np.asarray(matrix_lags)
            if np.any(matrix_lags < 0):
                raise ValueError("matrix_lags must be non-negative!")
        self.matrix_lags = matrix_lags

        data = dataframe.values[selected_dataset]
        mask = None
        if dataframe.mask is not None:
            mask = dataframe.mask[selected_dataset]

        T = data.shape[0]

        # Recorded even when label is None: adjustfig uses every entry's
        # colour for the per-panel legends.
        self.labels.append((label, label_color))

        for (i, j), ax in self.axes_dict.items():
            if matrix_lags is None:
                lag = 1 if i == j else 0
            else:
                lag = int(matrix_lags[i, j])

            # Data is set to NaN in dataframe init already; only the mask is
            # applied here. astype(float) copies and makes NaN assignable
            # even for integer data.
            x = data[:T - lag, i].astype(float)
            y = data[lag:, j].astype(float)
            if mask is not None:
                x[mask[:T - lag, i] == 1] = np.nan
                y[mask[lag:, j] == 1] = np.nan

            if i == j and show_marginal_densities_on_diagonal:
                sns.kdeplot(x,
                            color=label_color,
                            **snskdeplot_diagonal_args,
                            ax=ax)
                ax.set_ylabel("")
            else:
                # x-axis shows the column variable j, y-axis the lagged row
                # variable i.
                sns.kdeplot(x=y, y=x,
                            **snskdeplot_args,
                            ax=ax)

    def adjustfig(self, name=None, show_labels=True):
        """Adjust matrix figure.

        Parameters
        ----------
        name : str, optional (default: None)
            File name. If None, figure is shown in window.
        show_labels : bool, optional (default: True)
            If True, each off-diagonal panel gets a lag legend, and, when more
            than one density plot was added, a figure-level legend with the
            plot labels is drawn. If False, no legends are drawn.
        """
        colors = [item[1] for item in self.labels]
        # Colour of the (empty) proxy line: the most recent plot's colour.
        proxy_color = colors[-1] if colors else 'black'

        if show_labels:
            for (i, j), ax in self.axes_dict.items():
                if i == j:
                    continue
                if self.matrix_lags is None:
                    lag = 0
                else:
                    lag = self.matrix_lags[i, j]
                # Empty proxy line that only carries the lag label.
                ax.plot(
                    [],
                    [],
                    linestyle="",
                    color=proxy_color,
                    label=r"$\tau{=}%d$" % lag,
                )
                ax.legend(
                    ncol=1,
                    fontsize=self.legend_fontsize - 2,
                    labelcolor=colors,
                ).draw_frame(False)

        if show_labels and len(self.labels) > 1:
            # Frameless full-figure axes that only hosts the global legend.
            axlegend = self.fig.add_subplot(111, frameon=False)
            for side in ("left", "right", "bottom", "top"):
                axlegend.spines[side].set_color("none")
            axlegend.set_xticks([])
            axlegend.set_yticks([])

            for label, color in self.labels:
                axlegend.plot(
                    [],
                    [],
                    linestyle="-",
                    color=color,
                    label=label,
                )
            axlegend.legend(
                loc="upper left",
                ncol=1,
                bbox_to_anchor=(1.05, 0.0, 0.1, 1.0),
                borderaxespad=0,
                fontsize=self.legend_fontsize,
            ).draw_frame(False)

            self.fig.subplots_adjust(
                bottom=0.08,
                left=self.label_space_left,
                right=1.0 - self.legend_width,
                top=1.0 - self.label_space_top,
                hspace=0.5,
                wspace=0.35,
            )

        else:
            self.fig.subplots_adjust(
                left=self.label_space_left,
                bottom=0.08,
                right=0.95,
                top=1.0 - self.label_space_top,
                hspace=0.35,
                wspace=0.35,
            )

        if name is not None:
            self.fig.savefig(name)
        else:
            pyplot.show()
1514
+
1515
+ def _draw_network_with_curved_edges(
1516
+ fig,
1517
+ ax,
1518
+ G,
1519
+ pos,
1520
+ node_rings,
1521
+ node_labels,
1522
+ node_label_size=10,
1523
+ node_alpha=1.0,
1524
+ standard_size=100,
1525
+ node_aspect=None,
1526
+ standard_cmap="OrRd",
1527
+ standard_color_links='black',
1528
+ standard_color_nodes='lightgrey',
1529
+ log_sizes=False,
1530
+ cmap_links="YlOrRd",
1531
+ # cmap_links_edges="YlOrRd",
1532
+ links_vmin=0.0,
1533
+ links_vmax=1.0,
1534
+ links_edges_vmin=0.0,
1535
+ links_edges_vmax=1.0,
1536
+ links_ticks=0.2,
1537
+ links_edges_ticks=0.2,
1538
+ link_label_fontsize=8,
1539
+ arrowstyle="->, head_width=0.4, head_length=1",
1540
+ arrowhead_size=3.0,
1541
+ curved_radius=0.2,
1542
+ label_fontsize=4,
1543
+ label_fraction=0.5,
1544
+ link_colorbar_label="link",
1545
+ tick_label_size=6,
1546
+ # link_edge_colorbar_label='link_edge',
1547
+ inner_edge_curved=False,
1548
+ inner_edge_style="solid",
1549
+ # network_lower_bound=0.2,
1550
+ network_left_bound=None,
1551
+ show_colorbar=True,
1552
+ special_nodes=None,
1553
+ autodep_sig_lags=None,
1554
+ show_autodependency_lags=False,
1555
+ transform='data',
1556
+ node_classification=None,
1557
+ max_lag=0,
1558
+ ):
1559
+ """Function to draw a network from networkx graph instance.
1560
+ Various attributes are used to specify the graph's properties.
1561
+ This function is just a beta-template for now that can be further
1562
+ customized.
1563
+ """
1564
+
1565
+ if transform == 'data':
1566
+ transform = ax.transData
1567
+
1568
+ from matplotlib.patches import FancyArrowPatch, Circle, Ellipse
1569
+
1570
+ ax.spines["left"].set_color("none")
1571
+ ax.spines["right"].set_color("none")
1572
+ ax.spines["bottom"].set_color("none")
1573
+ ax.spines["top"].set_color("none")
1574
+ ax.set_xticks([])
1575
+ ax.set_yticks([])
1576
+
1577
+ N = len(G)
1578
+
1579
+ # This fixes a positioning bug in matplotlib.
1580
+ ax.scatter(0, 0, zorder=-10, alpha=0)
1581
+
1582
+ def draw_edge(
1583
+ ax,
1584
+ u,
1585
+ v,
1586
+ d,
1587
+ seen,
1588
+ arrowstyle= "Simple, head_width=2, head_length=2, tail_width=1",
1589
+ outer_edge=True,
1590
+ ):
1591
+
1592
+ # avoiding attribute error raised by changes in networkx
1593
+ if hasattr(G, "node"):
1594
+ # works with networkx 1.10
1595
+ n1 = G.node[u]["patch"]
1596
+ n2 = G.node[v]["patch"]
1597
+ else:
1598
+ # works with networkx 2.4
1599
+ n1 = G.nodes[u]["patch"]
1600
+ n2 = G.nodes[v]["patch"]
1601
+
1602
+ # print("+++++++++++++++++++++++==cmap_links ", cmap_links)
1603
+ if outer_edge:
1604
+ rad = -1.0 * curved_radius
1605
+ if cmap_links is not None:
1606
+ facecolor = data_to_rgb_links.to_rgba(d["outer_edge_color"])
1607
+ else:
1608
+ if d["outer_edge_color"] is not None:
1609
+ facecolor = d["outer_edge_color"]
1610
+ else:
1611
+ facecolor = standard_color_links
1612
+
1613
+ width = d["outer_edge_width"]
1614
+ alpha = d["outer_edge_alpha"]
1615
+ if (u, v) in seen:
1616
+ rad = seen.get((u, v))
1617
+ rad = (rad + np.sign(rad) * 0.1) * -1.0
1618
+ arrowstyle = arrowstyle
1619
+ # link_edge = d['outer_edge_edge']
1620
+ linestyle = 'solid' # d.get("outer_edge_style")
1621
+
1622
+ if d.get("outer_edge_attribute", None) == "spurious":
1623
+ facecolor = "grey"
1624
+
1625
+ if d.get("outer_edge_type") in ["<-o", "<--", "<-x", "<-+"]:
1626
+ n1, n2 = n2, n1
1627
+
1628
+ if d.get("outer_edge_type") in [
1629
+ "o-o",
1630
+ "o--",
1631
+ "--o",
1632
+ "---",
1633
+ "x-x",
1634
+ "x--",
1635
+ "--x",
1636
+ "o-x",
1637
+ "x-o",
1638
+ # "+->",
1639
+ # "<-+",
1640
+ ]:
1641
+ arrowstyle = "-"
1642
+ # linewidth = width*factor
1643
+ elif d.get("outer_edge_type") == "<->":
1644
+ # arrowstyle = "<->, head_width=0.4, head_length=1"
1645
+ arrowstyle = "Simple, head_width=2, head_length=2, tail_width=1" #%float(width/20.)
1646
+ elif d.get("outer_edge_type") in ["o->", "-->", "<-o", "<--", "<-x", "x->", "+->", "<-+"]:
1647
+ # arrowstyle = "->, head_width=0.4, head_length=1"
1648
+ # arrowstyle = "->, head_width=0.4, head_length=1, width=10"
1649
+ arrowstyle = "Simple, head_width=2, head_length=2, tail_width=1" #%float(width/20.)
1650
+ else:
1651
+ arrowstyle = "Simple, head_width=2, head_length=2, tail_width=1" #%float(width/20.)
1652
+ # raise ValueError("edge type %s not valid." %d.get("outer_edge_type"))
1653
+ else:
1654
+ rad = -1.0 * inner_edge_curved * curved_radius
1655
+ if cmap_links is not None:
1656
+ facecolor = data_to_rgb_links.to_rgba(d["inner_edge_color"])
1657
+ else:
1658
+ if d["inner_edge_color"] is not None:
1659
+ facecolor = d["inner_edge_color"]
1660
+ else:
1661
+ # print("HERE")
1662
+ facecolor = standard_color_links
1663
+
1664
+ width = d["inner_edge_width"]
1665
+ alpha = d["inner_edge_alpha"]
1666
+
1667
+ if d.get("inner_edge_attribute", None) == "spurious":
1668
+ facecolor = "grey"
1669
+ # print(d.get("inner_edge_type"))
1670
+ if d.get("inner_edge_type") in ["<-o", "<--", "<-x", "<-+"]:
1671
+ n1, n2 = n2, n1
1672
+
1673
+ if d.get("inner_edge_type") in [
1674
+ "o-o",
1675
+ "o--",
1676
+ "--o",
1677
+ "---",
1678
+ "x-x",
1679
+ "x--",
1680
+ "--x",
1681
+ "o-x",
1682
+ "x-o",
1683
+ ]:
1684
+ arrowstyle = "-"
1685
+ elif d.get("inner_edge_type") == "<->":
1686
+ # arrowstyle = "<->, head_width=0.4, head_length=1"
1687
+ arrowstyle = "Simple, head_width=2, head_length=2, tail_width=1" #%float(width/20.)
1688
+ elif d.get("inner_edge_type") in ["o->", "-->", "<-o", "<--", "<-x", "x->", "+->", "<-+"]:
1689
+ # arrowstyle = "->, head_width=0.4, head_length=1"
1690
+ arrowstyle = "Simple, head_width=2, head_length=2, tail_width=1" #%float(width/20.)
1691
+ else:
1692
+ arrowstyle = "Simple, head_width=2, head_length=2, tail_width=1" #%float(width/20.)
1693
+
1694
+ # raise ValueError("edge type %s not valid." %d.get("inner_edge_type"))
1695
+
1696
+ linestyle = 'solid' #d.get("inner_edge_style")
1697
+
1698
+ coor1 = n1.center
1699
+ coor2 = n2.center
1700
+
1701
+ marker_size = width ** 2
1702
+ figuresize = fig.get_size_inches()
1703
+
1704
+ # print("COLOR ", facecolor)
1705
+ # print(u, v, outer_edge, "outer ", d.get("outer_edge_type"), "inner ", d.get("inner_edge_type"), width, arrowstyle, linestyle)
1706
+
1707
+ if ((outer_edge is True and d.get("outer_edge_type") == "<->")
1708
+ or (outer_edge is False and d.get("inner_edge_type") == "<->")):
1709
+ e_p = FancyArrowPatch(
1710
+ coor1,
1711
+ coor2,
1712
+ arrowstyle=arrowstyle,
1713
+ connectionstyle=f"arc3,rad={rad}",
1714
+ mutation_scale=1*width,
1715
+ lw=0., #width / 2.,
1716
+ aa=True,
1717
+ alpha=alpha,
1718
+ linestyle=linestyle,
1719
+ color=facecolor,
1720
+ clip_on=False,
1721
+ patchA=n1,
1722
+ patchB=n2,
1723
+ shrinkA=7,
1724
+ shrinkB=0,
1725
+ zorder=-1,
1726
+ capstyle="butt",
1727
+ transform=transform,
1728
+ )
1729
+ ax.add_artist(e_p)
1730
+
1731
+ e_p_back = FancyArrowPatch(
1732
+ coor2,
1733
+ coor1,
1734
+ arrowstyle=arrowstyle,
1735
+ connectionstyle=f"arc3,rad={-rad}",
1736
+ mutation_scale=1*width,
1737
+ lw=0., #width / 2.,
1738
+ aa=True,
1739
+ alpha=alpha,
1740
+ linestyle=linestyle,
1741
+ color=facecolor,
1742
+ clip_on=False,
1743
+ patchA=n2,
1744
+ patchB=n1,
1745
+ shrinkA=7,
1746
+ shrinkB=0,
1747
+ zorder=-1,
1748
+ capstyle="butt",
1749
+ transform=transform,
1750
+ )
1751
+ ax.add_artist(e_p_back)
1752
+
1753
+ else:
1754
+ if arrowstyle == '-':
1755
+ lw = 1*width
1756
+ else:
1757
+ lw = 0.
1758
+ # e_p = FancyArrowPatch(
1759
+ # coor1,
1760
+ # coor2,
1761
+ # arrowstyle=arrowstyle,
1762
+ # connectionstyle=f"arc3,rad={rad}",
1763
+ # mutation_scale=np.sqrt(width)*2*1.1,
1764
+ # lw=lw*1.1, #width / 2.,
1765
+ # aa=True,
1766
+ # alpha=alpha,
1767
+ # linestyle=linestyle,
1768
+ # color='white',
1769
+ # clip_on=False,
1770
+ # patchA=n1,
1771
+ # patchB=n2,
1772
+ # shrinkA=0,
1773
+ # shrinkB=0,
1774
+ # zorder=-1,
1775
+ # capstyle="butt",
1776
+ # )
1777
+ # ax.add_artist(e_p)
1778
+ e_p = FancyArrowPatch(
1779
+ coor1,
1780
+ coor2,
1781
+ arrowstyle=arrowstyle,
1782
+ connectionstyle=f"arc3,rad={rad}",
1783
+ mutation_scale=1*width,
1784
+ lw=lw, #width / 2.,
1785
+ aa=True,
1786
+ alpha=alpha,
1787
+ linestyle=linestyle,
1788
+ color=facecolor,
1789
+ clip_on=False,
1790
+ patchA=n1,
1791
+ patchB=n2,
1792
+ shrinkA=0,
1793
+ shrinkB=0,
1794
+ # zorder=-1,
1795
+ capstyle="butt",
1796
+ transform=transform,
1797
+ )
1798
+ ax.add_artist(e_p)
1799
+
1800
+ e_p_marker = FancyArrowPatch(
1801
+ coor1,
1802
+ coor2,
1803
+ arrowstyle='-',
1804
+ connectionstyle=f"arc3,rad={rad}",
1805
+ mutation_scale=1*width,
1806
+ lw=0., #width / 2.,
1807
+ aa=True,
1808
+ alpha=0.,
1809
+ linestyle=linestyle,
1810
+ color=facecolor,
1811
+ clip_on=False,
1812
+ patchA=n1,
1813
+ patchB=n2,
1814
+ shrinkA=0,
1815
+ shrinkB=0,
1816
+ zorder=-10,
1817
+ capstyle="butt",
1818
+ transform=transform,
1819
+ )
1820
+ ax.add_artist(e_p_marker)
1821
+
1822
+ # marker_path = e_p_marker.get_path()
1823
+ vertices = e_p_marker.get_path().vertices.copy()
1824
+ # vertices = e_p_marker.get_verts()
1825
+ # vertices = e_p_marker.get_path().to_polygons(transform=None)[0]
1826
+ # print(vertices.shape)
1827
+ m, n = vertices.shape
1828
+
1829
+ # print(vertices)
1830
+ start = vertices[0]
1831
+ end = vertices[-1]
1832
+
1833
+ # This must be added to avoid rescaling of the plot, when no 'o'
1834
+ # or 'x' is added to the graph.
1835
+ ax.scatter(*start, zorder=-10, alpha=0, transform=transform,)
1836
+
1837
+ if outer_edge:
1838
+ if d.get("outer_edge_type") in ["o->", "o--"]:
1839
+ circle_marker_start = ax.scatter(
1840
+ *start,
1841
+ marker="o",
1842
+ s=marker_size,
1843
+ facecolor="w",
1844
+ edgecolor=facecolor,
1845
+ zorder=1,
1846
+ transform=transform,
1847
+ )
1848
+ ax.add_collection(circle_marker_start)
1849
+ elif d.get("outer_edge_type") == "<-o":
1850
+ circle_marker_end = ax.scatter(
1851
+ *start,
1852
+ marker="o",
1853
+ s=marker_size,
1854
+ facecolor="w",
1855
+ edgecolor=facecolor,
1856
+ zorder=1,
1857
+ transform=transform,
1858
+ )
1859
+ ax.add_collection(circle_marker_end)
1860
+ elif d.get("outer_edge_type") == "--o":
1861
+ circle_marker_end = ax.scatter(
1862
+ *end,
1863
+ marker="o",
1864
+ s=marker_size,
1865
+ facecolor="w",
1866
+ edgecolor=facecolor,
1867
+ zorder=1,
1868
+ transform=transform,
1869
+ )
1870
+ ax.add_collection(circle_marker_end)
1871
+ elif d.get("outer_edge_type") in ["x--", "x->"]:
1872
+ circle_marker_start = ax.scatter(
1873
+ *start,
1874
+ marker="X",
1875
+ s=marker_size,
1876
+ facecolor="w",
1877
+ edgecolor=facecolor,
1878
+ zorder=1,
1879
+ transform=transform,
1880
+ )
1881
+ ax.add_collection(circle_marker_start)
1882
+ elif d.get("outer_edge_type") in ["+--", "+->"]:
1883
+ circle_marker_start = ax.scatter(
1884
+ *start,
1885
+ marker="P",
1886
+ s=marker_size,
1887
+ facecolor="w",
1888
+ edgecolor=facecolor,
1889
+ zorder=1,
1890
+ transform=transform,
1891
+ )
1892
+ ax.add_collection(circle_marker_start)
1893
+ elif d.get("outer_edge_type") == "<-x":
1894
+ circle_marker_end = ax.scatter(
1895
+ *start,
1896
+ marker="X",
1897
+ s=marker_size,
1898
+ facecolor="w",
1899
+ edgecolor=facecolor,
1900
+ zorder=1,
1901
+ transform=transform,
1902
+ )
1903
+ ax.add_collection(circle_marker_end)
1904
+ elif d.get("outer_edge_type") == "<-+":
1905
+ circle_marker_end = ax.scatter(
1906
+ *start,
1907
+ marker="P",
1908
+ s=marker_size,
1909
+ facecolor="w",
1910
+ edgecolor=facecolor,
1911
+ zorder=1,
1912
+ transform=transform,
1913
+ )
1914
+ ax.add_collection(circle_marker_end)
1915
+ elif d.get("outer_edge_type") == "--x":
1916
+ circle_marker_end = ax.scatter(
1917
+ *end,
1918
+ marker="X",
1919
+ s=marker_size,
1920
+ facecolor="w",
1921
+ edgecolor=facecolor,
1922
+ zorder=1,
1923
+ transform=transform,
1924
+ )
1925
+ ax.add_collection(circle_marker_end)
1926
+ elif d.get("outer_edge_type") == "o-o":
1927
+ circle_marker_start = ax.scatter(
1928
+ *start,
1929
+ marker="o",
1930
+ s=marker_size,
1931
+ facecolor="w",
1932
+ edgecolor=facecolor,
1933
+ zorder=1,
1934
+ transform=transform,
1935
+ )
1936
+ ax.add_collection(circle_marker_start)
1937
+ circle_marker_end = ax.scatter(
1938
+ *end,
1939
+ marker="o",
1940
+ s=marker_size,
1941
+ facecolor="w",
1942
+ edgecolor=facecolor,
1943
+ zorder=1,
1944
+ transform=transform,
1945
+ )
1946
+ ax.add_collection(circle_marker_end)
1947
+ elif d.get("outer_edge_type") == "x-x":
1948
+ circle_marker_start = ax.scatter(
1949
+ *start,
1950
+ marker="X",
1951
+ s=marker_size,
1952
+ facecolor="w",
1953
+ edgecolor=facecolor,
1954
+ zorder=1,
1955
+ transform=transform,
1956
+ )
1957
+ ax.add_collection(circle_marker_start)
1958
+ circle_marker_end = ax.scatter(
1959
+ *end,
1960
+ marker="X",
1961
+ s=marker_size,
1962
+ facecolor="w",
1963
+ edgecolor=facecolor,
1964
+ zorder=1,
1965
+ transform=transform,
1966
+ )
1967
+ ax.add_collection(circle_marker_end)
1968
+ elif d.get("outer_edge_type") == "o-x":
1969
+ circle_marker_start = ax.scatter(
1970
+ *start,
1971
+ marker="o",
1972
+ s=marker_size,
1973
+ facecolor="w",
1974
+ edgecolor=facecolor,
1975
+ zorder=1,
1976
+ transform=transform,
1977
+ )
1978
+ ax.add_collection(circle_marker_start)
1979
+ circle_marker_end = ax.scatter(
1980
+ *end,
1981
+ marker="X",
1982
+ s=marker_size,
1983
+ facecolor="w",
1984
+ edgecolor=facecolor,
1985
+ zorder=1,
1986
+ transform=transform,
1987
+ )
1988
+ ax.add_collection(circle_marker_end)
1989
+ elif d.get("outer_edge_type") == "x-o":
1990
+ circle_marker_start = ax.scatter(
1991
+ *start,
1992
+ marker="X",
1993
+ s=marker_size,
1994
+ facecolor="w",
1995
+ edgecolor=facecolor,
1996
+ zorder=1,
1997
+ transform=transform,
1998
+ )
1999
+ ax.add_collection(circle_marker_start)
2000
+ circle_marker_end = ax.scatter(
2001
+ *end,
2002
+ marker="o",
2003
+ s=marker_size,
2004
+ facecolor="w",
2005
+ edgecolor=facecolor,
2006
+ zorder=1,
2007
+ transform=transform,
2008
+ )
2009
+ ax.add_collection(circle_marker_end)
2010
+
2011
+ else:
2012
+ if d.get("inner_edge_type") in ["o->", "o--"]:
2013
+ circle_marker_start = ax.scatter(
2014
+ *start,
2015
+ marker="o",
2016
+ s=marker_size,
2017
+ facecolor="w",
2018
+ edgecolor=facecolor,
2019
+ zorder=1,
2020
+ transform=transform,
2021
+ )
2022
+ ax.add_collection(circle_marker_start)
2023
+ elif d.get("inner_edge_type") == "<-o":
2024
+ circle_marker_end = ax.scatter(
2025
+ *start,
2026
+ marker="o",
2027
+ s=marker_size,
2028
+ facecolor="w",
2029
+ edgecolor=facecolor,
2030
+ zorder=1,
2031
+ transform=transform,
2032
+ )
2033
+ ax.add_collection(circle_marker_end)
2034
+ elif d.get("inner_edge_type") == "--o":
2035
+ circle_marker_end = ax.scatter(
2036
+ *end,
2037
+ marker="o",
2038
+ s=marker_size,
2039
+ facecolor="w",
2040
+ edgecolor=facecolor,
2041
+ zorder=1,
2042
+ transform=transform,
2043
+ )
2044
+ ax.add_collection(circle_marker_end)
2045
+ elif d.get("inner_edge_type") in ["x--", "x->"]:
2046
+ circle_marker_start = ax.scatter(
2047
+ *start,
2048
+ marker="X",
2049
+ s=marker_size,
2050
+ facecolor="w",
2051
+ edgecolor=facecolor,
2052
+ zorder=1,
2053
+ transform=transform,
2054
+ )
2055
+ ax.add_collection(circle_marker_start)
2056
+ elif d.get("inner_edge_type") in ["+--", "+->"]:
2057
+ circle_marker_start = ax.scatter(
2058
+ *start,
2059
+ marker="P",
2060
+ s=marker_size,
2061
+ facecolor="w",
2062
+ edgecolor=facecolor,
2063
+ zorder=1,
2064
+ transform=transform,
2065
+ )
2066
+ ax.add_collection(circle_marker_start)
2067
+ elif d.get("inner_edge_type") == "<-x":
2068
+ circle_marker_end = ax.scatter(
2069
+ *start,
2070
+ marker="X",
2071
+ s=marker_size,
2072
+ facecolor="w",
2073
+ edgecolor=facecolor,
2074
+ zorder=1,
2075
+ transform=transform,
2076
+ )
2077
+ ax.add_collection(circle_marker_end)
2078
+ elif d.get("inner_edge_type") == "<-+":
2079
+ circle_marker_end = ax.scatter(
2080
+ *start,
2081
+ marker="P",
2082
+ s=marker_size,
2083
+ facecolor="w",
2084
+ edgecolor=facecolor,
2085
+ zorder=1,
2086
+ transform=transform,
2087
+ )
2088
+ ax.add_collection(circle_marker_end)
2089
+ elif d.get("inner_edge_type") == "--x":
2090
+ circle_marker_end = ax.scatter(
2091
+ *end,
2092
+ marker="X",
2093
+ s=marker_size,
2094
+ facecolor="w",
2095
+ edgecolor=facecolor,
2096
+ zorder=1,
2097
+ transform=transform,
2098
+ )
2099
+ ax.add_collection(circle_marker_end)
2100
+ elif d.get("inner_edge_type") == "o-o":
2101
+ circle_marker_start = ax.scatter(
2102
+ *start,
2103
+ marker="o",
2104
+ s=marker_size,
2105
+ facecolor="w",
2106
+ edgecolor=facecolor,
2107
+ zorder=1,
2108
+ transform=transform,
2109
+ )
2110
+ ax.add_collection(circle_marker_start)
2111
+ circle_marker_end = ax.scatter(
2112
+ *end,
2113
+ marker="o",
2114
+ s=marker_size,
2115
+ facecolor="w",
2116
+ edgecolor=facecolor,
2117
+ zorder=1,
2118
+ transform=transform,
2119
+ )
2120
+ ax.add_collection(circle_marker_end)
2121
+ elif d.get("inner_edge_type") == "x-x":
2122
+ circle_marker_start = ax.scatter(
2123
+ *start,
2124
+ marker="X",
2125
+ s=marker_size,
2126
+ facecolor="w",
2127
+ edgecolor=facecolor,
2128
+ zorder=1,
2129
+ transform=transform,
2130
+ )
2131
+ ax.add_collection(circle_marker_start)
2132
+ circle_marker_end = ax.scatter(
2133
+ *end,
2134
+ marker="X",
2135
+ s=marker_size,
2136
+ facecolor="w",
2137
+ edgecolor=facecolor,
2138
+ zorder=1,
2139
+ transform=transform,
2140
+ )
2141
+ ax.add_collection(circle_marker_end)
2142
+ elif d.get("inner_edge_type") == "o-x":
2143
+ circle_marker_start = ax.scatter(
2144
+ *start,
2145
+ marker="o",
2146
+ s=marker_size,
2147
+ facecolor="w",
2148
+ edgecolor=facecolor,
2149
+ zorder=1,
2150
+ transform=transform,
2151
+ )
2152
+ ax.add_collection(circle_marker_start)
2153
+ circle_marker_end = ax.scatter(
2154
+ *end,
2155
+ marker="X",
2156
+ s=marker_size,
2157
+ facecolor="w",
2158
+ edgecolor=facecolor,
2159
+ zorder=1,
2160
+ transform=transform,
2161
+ )
2162
+ ax.add_collection(circle_marker_end)
2163
+ elif d.get("inner_edge_type") == "x-o":
2164
+ circle_marker_start = ax.scatter(
2165
+ *start,
2166
+ marker="X",
2167
+ s=marker_size,
2168
+ facecolor="w",
2169
+ edgecolor=facecolor,
2170
+ zorder=1,
2171
+ transform=transform,
2172
+ )
2173
+ ax.add_collection(circle_marker_start)
2174
+ circle_marker_end = ax.scatter(
2175
+ *end,
2176
+ marker="o",
2177
+ s=marker_size,
2178
+ facecolor="w",
2179
+ edgecolor=facecolor,
2180
+ zorder=1,
2181
+ transform=transform,
2182
+ )
2183
+ ax.add_collection(circle_marker_end)
2184
+
2185
+
2186
+
2187
+ if d["label"] is not None and outer_edge:
2188
+ def closest_node(node, nodes):
2189
+ nodes = np.asarray(nodes)
2190
+ node = node.reshape(1, 2)
2191
+ dist_2 = np.sum((nodes - node)**2, axis=1)
2192
+ return np.argmin(dist_2)
2193
+
2194
+ # Attach labels of lags
2195
+ # trans = None # patch.get_transform()
2196
+ # path = e_p.get_path()
2197
+ vertices = e_p_marker.get_path().vertices.copy()
2198
+ verts = e_p.get_path().to_polygons(transform=None)[0]
2199
+ # print(verts)
2200
+ # print(verts.shape)
2201
+ # print(vertices.shape)
2202
+ # for num, vert in enumerate(verts):
2203
+ # ax.text(vert[0], vert[1], str(num),
2204
+ # transform=transform,)
2205
+ # ax.scatter(verts[:,0], verts[:,1])
2206
+ # mid_point = np.array([(start[0] + end[0])/2., (start[1] + end[1])/2.])
2207
+ # print(start, end, mid_point)
2208
+ # ax.scatter(mid_point[0], mid_point[1], marker='x',
2209
+ # s=100, zorder=10, transform=transform,)
2210
+ closest_node = closest_node(vertices[int(len(vertices)/2.),:], verts)
2211
+ # print(closest_node, verts[closest_node])
2212
+ # ax.scatter(verts[closest_node][0], verts[closest_node][1], marker='x')
2213
+
2214
+ if len(vertices) > 2:
2215
+ # label_vert = vertices[int(len(vertices)/2.),:] #verts[1, :]
2216
+ label_vert = verts[closest_node] #verts[1, :]
2217
+ l = d["label"]
2218
+ string = str(l)
2219
+ txt = ax.text(
2220
+ label_vert[0],
2221
+ label_vert[1],
2222
+ string,
2223
+ fontsize=link_label_fontsize,
2224
+ verticalalignment="center",
2225
+ horizontalalignment="center",
2226
+ color="w",
2227
+ zorder=1,
2228
+ transform=transform,
2229
+ )
2230
+ txt.set_path_effects(
2231
+ [PathEffects.withStroke(linewidth=2, foreground="k")]
2232
+ )
2233
+
2234
+ return rad
2235
+
2236
+ # Collect all edge weights to get color scale
2237
+ all_links_weights = []
2238
+ all_links_edge_weights = []
2239
+ for (u, v, d) in G.edges(data=True):
2240
+ if u != v:
2241
+ if d["outer_edge"] and d["outer_edge_color"] is not None:
2242
+ all_links_weights.append(d["outer_edge_color"])
2243
+ if d["inner_edge"] and d["inner_edge_color"] is not None:
2244
+ all_links_weights.append(d["inner_edge_color"])
2245
+
2246
+ if cmap_links is not None and len(all_links_weights) > 0:
2247
+ if links_vmin is None:
2248
+ links_vmin = np.array(all_links_weights).min()
2249
+ if links_vmax is None:
2250
+ links_vmax = np.array(all_links_weights).max()
2251
+ data_to_rgb_links = pyplot.cm.ScalarMappable(
2252
+ norm=None, cmap=pyplot.get_cmap(cmap_links)
2253
+ )
2254
+ data_to_rgb_links.set_array(np.array(all_links_weights))
2255
+ data_to_rgb_links.set_clim(vmin=links_vmin, vmax=links_vmax)
2256
+ # Create colorbars for links
2257
+
2258
+ # setup colorbar axes.
2259
+ if show_colorbar:
2260
+ # cax_e = pyplot.axes(
2261
+ # [
2262
+ # 0.55,
2263
+ # ax.get_subplotspec().get_position(ax.figure).bounds[1] + 0.02,
2264
+ # 0.4,
2265
+ # 0.025 + (len(all_links_edge_weights) == 0) * 0.035,
2266
+ # ],
2267
+ # frameon=False,
2268
+ # )
2269
+ bbox_ax = ax.get_position()
2270
+ width = bbox_ax.xmax-bbox_ax.xmin
2271
+ height = bbox_ax.ymax-bbox_ax.ymin
2272
+ # print(bbox_ax.xmin, bbox_ax.xmax, bbox_ax.ymin, bbox_ax.ymax)
2273
+ # cax_e = fig.add_axes(
2274
+ # [
2275
+ # bbox_ax.xmax - width*0.45,
2276
+ # bbox_ax.ymin-0.075*height+network_lower_bound-0.15,
2277
+ # width*0.4,
2278
+ # 0.075*height, #0.025 + (len(all_links_edge_weights) == 0) * 0.035,
2279
+ # ],
2280
+ # frameon=False,
2281
+ # )
2282
+ cax_e = ax.inset_axes(
2283
+ [
2284
+ 0.55, -0.07, 0.4, 0.07
2285
+ # bbox_ax.xmax - width*0.45,
2286
+ # bbox_ax.ymin-0.075*height+network_lower_bound-0.15,
2287
+ # width*0.4,
2288
+ # 0.075*height, #0.025 + (len(all_links_edge_weights) == 0) * 0.035,
2289
+ ],
2290
+ frameon=False,)
2291
+ # divider = make_axes_locatable(ax)
2292
+
2293
+ # cax_e = divider.append_axes('bottom', size='5%', pad=0.05, frameon=False,)
2294
+
2295
+ cb_e = pyplot.colorbar(
2296
+ data_to_rgb_links, cax=cax_e, orientation="horizontal"
2297
+ )
2298
+ # try:
2299
+ ticks_here = np.arange(
2300
+ _myround(links_vmin, links_ticks, "down"),
2301
+ _myround(links_vmax, links_ticks, "up") + links_ticks,
2302
+ links_ticks,
2303
+ )
2304
+ cb_e.set_ticks(ticks_here[(links_vmin <= ticks_here) & (ticks_here <= links_vmax)])
2305
+ # except:
2306
+ # print('no ticks given')
2307
+
2308
+ cb_e.outline.clear()
2309
+ cax_e.set_xlabel(
2310
+ link_colorbar_label, labelpad=1, fontsize=label_fontsize, zorder=10
2311
+ )
2312
+ cax_e.tick_params(axis='both', which='major', labelsize=tick_label_size)
2313
+
2314
+ ##
2315
+ # Draw nodes
2316
+ ##
2317
+ node_sizes = np.zeros((len(node_rings), N))
2318
+ for ring in list(node_rings): # iterate through to get all node sizes
2319
+ if node_rings[ring]["sizes"] is not None:
2320
+ node_sizes[ring] = node_rings[ring]["sizes"]
2321
+
2322
+ else:
2323
+ node_sizes[ring] = standard_size
2324
+ max_sizes = node_sizes.max(axis=1)
2325
+ total_max_size = node_sizes.sum(axis=0).max()
2326
+ node_sizes /= total_max_size
2327
+ node_sizes *= standard_size
2328
+
2329
+ def get_aspect(ax):
2330
+ # Total figure size
2331
+ figW, figH = ax.get_figure().get_size_inches()
2332
+ # print(figW, figH)
2333
+ # Axis size on figure
2334
+ _, _, w, h = ax.get_position().bounds
2335
+ # Ratio of display units
2336
+ # print(w, h)
2337
+ disp_ratio = (figH * h) / (figW * w)
2338
+ # Ratio of data units
2339
+ # Negative over negative because of the order of subtraction
2340
+ data_ratio = sub(*ax.get_ylim()) / sub(*ax.get_xlim())
2341
+ # print(data_ratio, disp_ratio)
2342
+ return disp_ratio / data_ratio
2343
+
2344
+ if node_aspect is None:
2345
+ node_aspect = get_aspect(ax)
2346
+
2347
+ # start drawing the outer ring first...
2348
+ for ring in list(node_rings)[::-1]:
2349
+ # print ring
2350
+ # dictionary of rings: {0:{'sizes':(N,)-array, 'color_array':(N,)-array
2351
+ # or None, 'cmap':string, 'vmin':float or None, 'vmax':float or None}}
2352
+ if node_rings[ring]["color_array"] is not None:
2353
+ color_data = node_rings[ring]["color_array"]
2354
+ if node_rings[ring]["vmin"] is not None:
2355
+ vmin = node_rings[ring]["vmin"]
2356
+ else:
2357
+ vmin = node_rings[ring]["color_array"].min()
2358
+ if node_rings[ring]["vmax"] is not None:
2359
+ vmax = node_rings[ring]["vmax"]
2360
+ else:
2361
+ vmax = node_rings[ring]["color_array"].max()
2362
+ if node_rings[ring]["cmap"] is not None:
2363
+ cmap = node_rings[ring]["cmap"]
2364
+ else:
2365
+ cmap = standard_cmap
2366
+ data_to_rgb = pyplot.cm.ScalarMappable(
2367
+ norm=None, cmap=pyplot.get_cmap(cmap)
2368
+ )
2369
+ data_to_rgb.set_array(color_data)
2370
+ data_to_rgb.set_clim(vmin=vmin, vmax=vmax)
2371
+ colors = [data_to_rgb.to_rgba(color_data[n]) for n in G]
2372
+
2373
+ if node_rings[ring]["colorbar"]:
2374
+ # Create colorbars for nodes
2375
+ # cax_n = pyplot.axes([.8 + ring*0.11,
2376
+ # ax.get_subplotspec().get_position(ax.figure).bounds[1]+0.05, 0.025, 0.35], frameon=False) #
2377
+ # setup colorbar axes.
2378
+ # setup colorbar axes.
2379
+ bbox_ax = ax.get_position()
2380
+ # print(bbox_ax.xmin, bbox_ax.xmax, bbox_ax.ymin, bbox_ax.ymax)
2381
+ cax_n = ax.inset_axes(
2382
+ [
2383
+ 0.05, -0.07, 0.4, 0.07
2384
+ # bbox_ax.xmin + width*0.05,
2385
+ # bbox_ax.ymin-0.075*height+network_lower_bound-0.15,
2386
+ # width*0.4,
2387
+ # 0.075*height, #0.025 + (len(all_links_edge_weights) == 0) * 0.035,
2388
+ ],
2389
+ frameon=False,
2390
+ )
2391
+ cb_n = pyplot.colorbar(data_to_rgb, cax=cax_n, orientation="horizontal")
2392
+ # try:
2393
+ ticks_here = np.arange(
2394
+ _myround(vmin, node_rings[ring]["ticks"], "down"),
2395
+ _myround(vmax, node_rings[ring]["ticks"], "up")
2396
+ + node_rings[ring]["ticks"],
2397
+ node_rings[ring]["ticks"],
2398
+ )
2399
+ cb_n.set_ticks(ticks_here[(vmin <= ticks_here) & (ticks_here <= vmax)])
2400
+ # except:
2401
+ # print ('no ticks given')
2402
+ cb_n.outline.clear()
2403
+ # cb_n.set_ticks()
2404
+ cax_n.set_xlabel(
2405
+ node_rings[ring]["label"], labelpad=1, fontsize=label_fontsize
2406
+ )
2407
+ cax_n.tick_params(axis='both', which='major', labelsize=tick_label_size)
2408
+ else:
2409
+ colors = None
2410
+ vmin = None
2411
+ vmax = None
2412
+
2413
+ for n in G:
2414
+ if type(node_alpha) == dict:
2415
+ alpha = node_alpha[n]
2416
+ else:
2417
+ alpha = 1.0
2418
+
2419
+ if special_nodes is not None:
2420
+ if n in special_nodes:
2421
+ color_here = special_nodes[n]
2422
+ else:
2423
+ color_here = 'grey'
2424
+ else:
2425
+ if colors is None:
2426
+ color_here = standard_color_nodes
2427
+ else:
2428
+ color_here = colors[n]
2429
+
2430
+ c = Ellipse(
2431
+ pos[n],
2432
+ width=node_sizes[: ring + 1].sum(axis=0)[n] * node_aspect,
2433
+ height=node_sizes[: ring + 1].sum(axis=0)[n],
2434
+ clip_on=False,
2435
+ facecolor=color_here,
2436
+ edgecolor=color_here,
2437
+ zorder=-ring - 1 + 2,
2438
+ transform=transform,
2439
+ )
2440
+
2441
+ # else:
2442
+ # if special_nodes is not None and n in special_nodes:
2443
+ # color_here = special_nodes[n]
2444
+ # else:
2445
+ # color_here = colors[n]
2446
+ # c = Ellipse(
2447
+ # pos[n],
2448
+ # width=node_sizes[: ring + 1].sum(axis=0)[n] * node_aspect,
2449
+ # height=node_sizes[: ring + 1].sum(axis=0)[n],
2450
+ # clip_on=False,
2451
+ # facecolor=colors[n],
2452
+ # edgecolor=colors[n],
2453
+ # zorder=-ring - 1,
2454
+ # )
2455
+
2456
+ ax.add_patch(c)
2457
+
2458
+ if node_classification is not None and node_classification[n] in ["space_context_last", "space_dummy_last", "time_dummy_last"]:
2459
+ node_height = node_sizes[: ring + 1].sum(axis=0)[n]
2460
+ node_width_difference_to_height = node_height * (1 - node_aspect)
2461
+
2462
+ c_wide = mpatches.FancyBboxPatch((pos[n-max_lag+1][0] + node_width_difference_to_height / 2, pos[n-max_lag+1][1]),
2463
+ (pos[n][0] - pos[n-max_lag+1][0] - node_width_difference_to_height),
2464
+ 0.,
2465
+ boxstyle=mpatches.BoxStyle.Round(pad=0.5 * node_height),
2466
+ facecolor=color_here,
2467
+ edgecolor=color_here,
2468
+ )
2469
+
2470
+ ax.add_patch(c_wide)
2471
+
2472
+
2473
+ # avoiding attribute error raised by changes in networkx
2474
+ if hasattr(G, "node"):
2475
+ # works with networkx 1.10
2476
+ G.node[n]["patch"] = c
2477
+ else:
2478
+ # works with networkx 2.4
2479
+ G.nodes[n]["patch"] = c
2480
+
2481
+ if ring == 0:
2482
+ ax.text(
2483
+ pos[n][0],
2484
+ pos[n][1],
2485
+ node_labels[n],
2486
+ fontsize=node_label_size,
2487
+ horizontalalignment="center",
2488
+ verticalalignment="center",
2489
+ alpha=1.0,
2490
+ zorder=5.,
2491
+ transform=transform,
2492
+ )
2493
+ if show_autodependency_lags:
2494
+ ax.text(
2495
+ pos[n][0],
2496
+ pos[n][1],
2497
+ autodep_sig_lags[n],
2498
+ fontsize=link_label_fontsize,
2499
+ horizontalalignment="center",
2500
+ verticalalignment="center",
2501
+ color="black",
2502
+ zorder=5.,
2503
+ transform=transform,
2504
+ )
2505
+
2506
+ # Draw edges
2507
+ seen = {}
2508
+ for (u, v, d) in G.edges(data=True):
2509
+ if d.get("no_links"):
2510
+ d["inner_edge_alpha"] = 1e-8
2511
+ d["outer_edge_alpha"] = 1e-8
2512
+ if u != v:
2513
+ if d["outer_edge"]:
2514
+ seen[(u, v)] = draw_edge(ax, u, v, d, seen, outer_edge=True)
2515
+ if d["inner_edge"]:
2516
+ seen[(u, v)] = draw_edge(ax, u, v, d, seen, outer_edge=False)
2517
+
2518
+ # if network_left_bound is not None:
2519
+ # network_right_bound = 0.98
2520
+ # else:
2521
+ # network_right_bound = None
2522
+ # fig.subplots_adjust(bottom=network_lower_bound, left=network_left_bound, right=network_right_bound) #, right=0.97)
2523
+
2524
+
2525
def plot_graph(
    graph,
    val_matrix=None,
    var_names=None,
    fig_ax=None,
    figsize=None,
    save_name=None,
    link_colorbar_label="MCI",
    node_colorbar_label="auto-MCI",
    link_width=None,
    link_attribute=None,
    node_pos=None,
    arrow_linewidth=8.0,
    vmin_edges=-1,
    vmax_edges=1.0,
    edge_ticks=0.4,
    cmap_edges="RdBu_r",
    vmin_nodes=-1,
    vmax_nodes=1.0,
    node_ticks=0.4,
    cmap_nodes="RdBu_r",
    node_size=0.3,
    node_aspect=None,
    arrowhead_size=20,
    curved_radius=0.2,
    label_fontsize=10,
    tick_label_size=6,
    alpha=1.0,
    node_label_size=10,
    link_label_fontsize=10,
    lag_array=None,
    show_colorbar=True,
    inner_edge_style="dashed",
    link_matrix=None,
    special_nodes=None,
    show_autodependency_lags=False
):
    """Creates a network plot.

    This is still in beta. The network is defined from links in graph. Nodes
    denote variables, straight links contemporaneous dependencies and curved
    arrows lagged dependencies. The node color denotes the auto-dependency
    value at the lag with maximal absolute value and the link color the value
    at the lag with maximal absolute cross-dependency. The link label lists
    the lags with significant dependency in order of absolute magnitude. The
    network can also be plotted over a map drawn before on the same axis.
    Then the node positions can be supplied in appropriate axis coordinates
    via node_pos.

    Parameters
    ----------
    graph : string or bool array-like
        Either string matrix providing graph or bool array providing only
        adjacencies, of shape (N, N, tau_max+1) or (N, N). Must be of same
        shape as val_matrix. Time series graphs of shape
        (N, N, tau_max+1, tau_max+1) are not supported (use
        plot_time_series_graph).
    val_matrix : array_like, optional (default: None)
        Matrix of shape (N, N, tau_max+1) containing test statistic values.
        An (N, N) matrix is accepted together with an (N, N) graph. If None,
        neither links nor nodes are colored.
    var_names : list, optional (default: None)
        List of variable names. If None, range(N) is used.
    fig_ax : tuple of figure and axis object, optional (default: None)
        Figure and axes instance. If None they are created.
    figsize : tuple, optional (default: None)
        Size of figure (only used when fig_ax is None).
    save_name : str, optional (default: None)
        Name of figure file to save the figure to (dpi=300). If None, the
        figure is not saved and (fig, ax) is returned instead.
    link_colorbar_label : str, optional (default: 'MCI')
        Test statistic label.
    node_colorbar_label : str, optional (default: 'auto-MCI')
        Test statistic label for auto-dependencies.
    link_width : array-like, optional (default: None)
        Array of val_matrix.shape specifying relative link width with maximum
        given by arrow_linewidth. If None, all links have same width.
    link_attribute : array-like, optional (default: None)
        String array of val_matrix.shape specifying link attributes.
    node_pos : dictionary, optional (default: None)
        Dictionary of node positions in axis coordinates of form
        node_pos = {'x':array of shape (N,), 'y':array of shape(N)}. These
        coordinates could have been transformed before for basemap plots. You can
        also add a key 'transform':ccrs.PlateCarree() in order to plot graphs on
        a map using cartopy. If None, a circular layout is used.
    arrow_linewidth : float, optional (default: 8.0)
        Linewidth.
    vmin_edges : float or None, optional (default: -1)
        Link colorbar scale lower bound. If None, the minimum link value.
    vmax_edges : float or None, optional (default: 1)
        Link colorbar scale upper bound. If None, the maximum link value.
    edge_ticks : float, optional (default: 0.4)
        Link tick mark interval.
    cmap_edges : str, optional (default: 'RdBu_r')
        Colormap for links.
    vmin_nodes : float or None, optional (default: -1)
        Node colorbar scale lower bound. If None, the minimum node value.
    vmax_nodes : float or None, optional (default: 1)
        Node colorbar scale upper bound. If None, the maximum node value.
    node_ticks : float, optional (default: 0.4)
        Node tick mark interval.
    cmap_nodes : str, optional (default: 'RdBu_r')
        Colormap for nodes.
    node_size : float, optional (default: 0.3)
        Node size.
    node_aspect : float, optional (default: None)
        Ratio between the height and width of the variable nodes. If None,
        it is derived from the aspect ratio of the axes.
    arrowhead_size : int, optional (default: 20)
        Size of link arrow head. Passed on to FancyArrowPatch object.
    curved_radius : float, optional (default: 0.2)
        Curvature of links. Passed on to FancyArrowPatch object.
    label_fontsize : int, optional (default: 10)
        Fontsize of colorbar labels.
    tick_label_size : int, optional (default: 6)
        Fontsize of tick labels.
    alpha : float, optional (default: 1.)
        Opacity of the links. A scalar value does not change node opacity.
    node_label_size : int, optional (default: 10)
        Fontsize of node labels.
    link_label_fontsize : int, optional (default: 10)
        Fontsize of link labels.
    lag_array : array, optional (default: None)
        Optional specification of lags overwriting np.arange(0, tau_max+1)
        in the link labels.
    show_colorbar : bool, optional (default: True)
        Whether to show colorbars for links and nodes.
    inner_edge_style : str, optional (default: 'dashed')
        Currently unused: contemporaneous links are always drawn solid.
    link_matrix : None
        Deprecated; passing anything other than None raises ValueError.
    special_nodes : dict, optional (default: None)
        Dictionary mapping (i, tau) tuples to colors. Nodes i with
        tau >= -tau_max are drawn in the given color, all other nodes grey.
    show_autodependency_lags : bool, optional (default: False)
        Shows significant autodependencies for a node.

    Returns
    -------
    (fig, ax) : tuple or None
        Figure and axes if save_name is None, otherwise None.

    Raises
    ------
    ValueError
        If link_matrix is given or graph is a 4D time series graph; further
        ValueErrors may come from the consistency checks on the input arrays.
    """
    if link_matrix is not None:
        raise ValueError("link_matrix is deprecated and replaced by graph array"
                         " which is now returned by all methods.")

    if fig_ax is None:
        fig = pyplot.figure(figsize=figsize)
        ax = fig.add_subplot(111, frame_on=False)
    else:
        fig, ax = fig_ax

    # np.squeeze also accepts nested lists; the copy protects the caller's
    # array from the in-place "xxx" workaround below.
    graph = np.copy(np.squeeze(graph))

    if graph.ndim == 4:
        raise ValueError("Time series graph of shape (N,N,tau_max+1,tau_max+1) cannot be represented by plot_graph,"
                         " use plot_time_series_graph instead.")

    if graph.ndim == 2:
        # A non-time series (N,N)-graph gets a dummy lag dimension. Matching
        # 2D companion arrays get the same so everything is indexed [i, j, tau].
        graph = np.expand_dims(graph, axis=2)
        val_matrix, link_width, link_attribute = (
            np.expand_dims(arr, axis=2)
            if arr is not None and np.ndim(arr) == 2 else arr
            for arr in (val_matrix, link_width, link_attribute)
        )

    no_coloring = val_matrix is None
    if no_coloring:
        cmap_edges = None
        cmap_nodes = None

    (graph, val_matrix, link_width, link_attribute) = _check_matrices(
        graph, val_matrix, link_width, link_attribute)

    N, _, n_lags = graph.shape
    tau_max = n_lags - 1

    # True if there are no links at all or only auto-dependencies, i.e. every
    # non-empty entry lies on the i == j diagonal.
    no_links = (np.count_nonzero(graph != "")
                == np.count_nonzero(np.diagonal(graph) != ""))
    if no_links:
        graph[0, 1, 0] = "xxx"  # Workaround, will not be plotted...

    if var_names is None:
        var_names = range(N)

    # Only draw contemporaneous links in one direction: drop the lower
    # triangle at lag zero.
    link_matrix_upper = np.copy(graph)
    link_matrix_upper[:, :, 0] = np.triu(link_matrix_upper[:, :, 0])

    net = np.any(link_matrix_upper != "", axis=2)
    G = nx.DiGraph(net)

    # Node colors (auto-dependency values). Kept as an ndarray because the
    # drawing routine calls .min()/.max() on it when vmin/vmax_nodes is None.
    node_color = None if no_coloring else np.zeros(N)

    if show_autodependency_lags:
        autodep_sig_lags = np.full(N, None, dtype='object')
    else:
        autodep_sig_lags = None

    # Add attributes, contemporaneous and lagged links are handled separately
    for (u, v, dic) in G.edges(data=True):
        dic["no_links"] = no_links

        # Lag (> 0) with the largest absolute value among the existing lagged
        # links u -> v; 0 if there is none.
        argmax = 0
        if tau_max > 0:
            links = np.where(link_matrix_upper[u, v, 1:] != "")[0]
            if len(links) > 0:
                argmax = links[np.abs(val_matrix[u, v][1:][links]).argmax()] + 1

        if u != v:
            # Contemporaneous (lag-zero) link, drawn as a straight edge.
            dic["inner_edge"] = link_matrix_upper[u, v, 0]
            dic["inner_edge_type"] = link_matrix_upper[u, v, 0]
            dic["inner_edge_alpha"] = alpha
            if no_coloring:
                dic["inner_edge_color"] = None
            else:
                dic["inner_edge_color"] = val_matrix[u, v, 0]

            if link_width is None:
                dic["inner_edge_width"] = arrow_linewidth
            else:
                dic["inner_edge_width"] = (
                    link_width[u, v, 0] / link_width.max() * arrow_linewidth
                )

            if link_attribute is None:
                dic["inner_edge_attribute"] = None
            else:
                dic["inner_edge_attribute"] = link_attribute[u, v, 0]

            dic["inner_edge_style"] = "solid"

            # Lagged links, drawn as one curved edge representing the lag
            # with maximal absolute value.
            if tau_max > 0:
                dic["outer_edge"] = np.any(link_matrix_upper[u, v, 1:] != "")
            else:
                dic["outer_edge"] = False

            dic["outer_edge_type"] = link_matrix_upper[u, v, argmax]
            dic["outer_edge_alpha"] = alpha

            if link_width is None:
                dic["outer_edge_width"] = arrow_linewidth
            else:
                dic["outer_edge_width"] = (
                    link_width[u, v, argmax] / link_width.max() * arrow_linewidth
                )

            if link_attribute is None:
                dic["outer_edge_attribute"] = None
            else:
                dic["outer_edge_attribute"] = link_attribute[u, v, argmax]

            if no_coloring:
                dic["outer_edge_color"] = None
            else:
                dic["outer_edge_color"] = val_matrix[u, v][argmax]

            # Label: significant lags sorted by decreasing absolute value.
            if tau_max > 0:
                lags = np.abs(val_matrix[u, v][1:]).argsort()[::-1] + 1
                sig_lags = (np.where(link_matrix_upper[u, v, 1:] != "")[0] + 1).tolist()
            else:
                lags, sig_lags = [], []
            if lag_array is not None:
                dic["label"] = ",".join(
                    [str(lag_array[lag]) for lag in lags if lag in sig_lags])
            else:
                dic["label"] = ",".join(
                    [str(lag) for lag in lags if lag in sig_lags])
        else:
            # Self-loop: encodes the auto-dependency as the node color.
            if node_color is not None:
                node_color[u] = val_matrix[u, v][argmax]

            if show_autodependency_lags:
                autodep_sig_lags[u] = "\n\n\n" + ",".join(
                    str(i) for i in
                    (np.where(link_matrix_upper[u, v, 1:] != "")[0] + 1).tolist())

            dic["inner_edge_attribute"] = None
            dic["outer_edge_attribute"] = None

    if special_nodes is not None:
        # Keep only nodes within the plotted lag range, keyed by variable.
        special_nodes = {
            i: color
            for (i, tau), color in special_nodes.items()
            if tau >= -tau_max
        }

    if node_pos is None:
        pos = nx.circular_layout(deepcopy(G))
    else:
        pos = {i: (node_pos["x"][i], node_pos["y"][i]) for i in range(N)}

    if node_pos is not None and 'transform' in node_pos:
        transform = node_pos['transform']
    else:
        transform = ax.transData

    if cmap_nodes is None:
        node_color = None

    node_rings = {
        0: {
            "sizes": None,
            "color_array": node_color,
            "cmap": cmap_nodes,
            "vmin": vmin_nodes,
            "vmax": vmax_nodes,
            "ticks": node_ticks,
            "label": node_colorbar_label,
            "colorbar": show_colorbar,
        }
    }

    _draw_network_with_curved_edges(
        fig=fig,
        ax=ax,
        G=deepcopy(G),
        pos=pos,
        node_rings=node_rings,
        node_labels=var_names,
        node_label_size=node_label_size,
        node_alpha=alpha,
        standard_size=node_size,
        node_aspect=node_aspect,
        standard_cmap="OrRd",
        standard_color_nodes="lightgrey",
        standard_color_links="black",
        log_sizes=False,
        cmap_links=cmap_edges,
        links_vmin=vmin_edges,
        links_vmax=vmax_edges,
        links_ticks=edge_ticks,
        tick_label_size=tick_label_size,
        arrowstyle="simple",
        arrowhead_size=arrowhead_size,
        curved_radius=curved_radius,
        label_fontsize=label_fontsize,
        link_label_fontsize=link_label_fontsize,
        link_colorbar_label=link_colorbar_label,
        show_colorbar=show_colorbar,
        special_nodes=special_nodes,
        autodep_sig_lags=autodep_sig_lags,
        show_autodependency_lags=show_autodependency_lags,
        transform=transform
    )

    if save_name is not None:
        # Save the figure that was drawn on, not pyplot's "current" figure,
        # which can differ when fig_ax is supplied by the caller.
        fig_to_save = fig if fig is not None else ax.get_figure()
        fig_to_save.savefig(save_name, dpi=300)
        return None
    return fig, ax
2937
+
2938
+
2939
+ def _reverse_patt(patt):
2940
+ """Inverts a link pattern"""
2941
+
2942
+ if patt == "":
2943
+ return ""
2944
+
2945
+ left_mark, middle_mark, right_mark = patt[0], patt[1], patt[2]
2946
+ if left_mark == "<":
2947
+ new_right_mark = ">"
2948
+ else:
2949
+ new_right_mark = left_mark
2950
+ if right_mark == ">":
2951
+ new_left_mark = "<"
2952
+ else:
2953
+ new_left_mark = right_mark
2954
+
2955
+ return new_left_mark + middle_mark + new_right_mark
2956
+
2957
+ # if patt in ['---', 'o--', '--o', 'o-o', '']:
2958
+ # return patt[::-1]
2959
+ # elif patt == '<->':
2960
+ # return '<->'
2961
+ # elif patt == 'o->':
2962
+ # return '<-o'
2963
+ # elif patt == '<-o':
2964
+ # return 'o->'
2965
+ # elif patt == '-->':
2966
+ # return '<--'
2967
+ # elif patt == '<--':
2968
+ # return '-->'
2969
+
2970
+
2971
def _check_matrices(graph, val_matrix, link_width, link_attribute):
    """Validate the graph-related matrices and normalize ``graph`` to "<U3".

    Parameters
    ----------
    graph : array_like
        Either a string array of link patterns (e.g. ``"-->"``, ``"o-o"``,
        ``""`` for no link) of shape (N, N, tau_max + 1) or
        (N, N, tau_max + 1, tau_max + 1), or a non-string (e.g. boolean)
        adjacency array of shape (N, N, tau_max + 1). Non-string arrays are
        converted to link patterns:
        - a nonzero lagged entry becomes ``"-->"``;
        - a nonzero lag-zero entry becomes ``"-->"``/``"<--"`` when the
          reverse entry is zero, and ``"o-o"`` otherwise.
    val_matrix, link_width, link_attribute : array_like or None
        Optional arrays indexed like ``graph``. The entries for the two
        directions of a link must be equal. This applies to lag-zero links
        for 3-d graphs and to all links for 4-d graphs.

    Returns
    -------
    graph : ndarray of dtype "<U3"
        The validated link-pattern array. It is always a new array, so
        callers may modify it without touching their input.
    val_matrix : ndarray
        The input ``val_matrix``. If it was None, an integer array that is 1
        where ``graph`` has a link and 0 elsewhere.
    link_width, link_attribute
        Returned unchanged.

    Raises
    ------
    ValueError
        If a graph entry is not a valid link pattern, if the two directions
        of a link (or their values/widths/attributes) are inconsistent, if
        ``link_width`` has negative entries, or if a non-string graph is not
        3-dimensional.
    """
    valid_patterns = (
        "---", "o--", "--o", "o-o", "o->", "<-o", "-->", "<--", "<->",
        "x-o", "o-x", "x--", "--x", "x->", "<-x", "x-x", "<-+", "+->",
    )

    graph = np.asarray(graph)

    # Any unicode array (e.g. "<U4", or big-endian ">U3") already holds link
    # patterns. Comparing against the exact dtype "<U3" would re-interpret
    # such graphs as adjacency matrices and silently rewrite their marks.
    if graph.dtype.kind != "U":
        if graph.ndim != 3:
            raise ValueError(
                "Non-string graph arrays must have shape (N, N, tau_max + 1).")
        adjacency = graph
        graph = np.full(adjacency.shape, "", dtype="<U3")
        for i, j, tau in zip(*np.where(adjacency)):
            if tau == 0:
                # Lag zero: directed if only one direction is set, else o-o.
                if adjacency[j, i, 0] == 0:
                    graph[i, j, 0] = "-->"
                    graph[j, i, 0] = "<--"
                else:
                    graph[i, j, 0] = "o-o"
                    graph[j, i, 0] = "o-o"
            else:
                graph[i, j, tau] = "-->"

    if graph.ndim == 4:
        # Every link of a 4-d graph is stored in both directions.
        for i, j, taui, tauj in zip(*np.where(graph)):
            if graph[i, j, taui, tauj] not in valid_patterns:
                raise ValueError("Invalid graph entry.")
            if graph[i, j, taui, tauj] != _reverse_patt(graph[j, i, tauj, taui]):
                raise ValueError(
                    "graph needs to have consistent entries: "
                    "graph[i, j, taui, tauj] == _reverse_patt(graph[j, i, tauj, taui])")
            if (
                val_matrix is not None
                and val_matrix[i, j, taui, tauj] != val_matrix[j, i, tauj, taui]
            ):
                raise ValueError(
                    "val_matrix needs to have consistent entries: "
                    "val_matrix[i, j, taui, tauj] == val_matrix[j, i, tauj, taui]")
            if (
                link_width is not None
                and link_width[i, j, taui, tauj] != link_width[j, i, tauj, taui]
            ):
                raise ValueError(
                    "link_width needs to have consistent entries: "
                    "link_width[i, j, taui, tauj] == link_width[j, i, tauj, taui]")
            if (
                link_attribute is not None
                and link_attribute[i, j, taui, tauj] != link_attribute[j, i, tauj, taui]
            ):
                raise ValueError(
                    "link_attribute needs to have consistent entries: "
                    "link_attribute[i, j, taui, tauj] == link_attribute[j, i, tauj, taui]")
    else:
        # Only lag-zero links are stored in both directions; every non-empty
        # entry must be a valid pattern.
        for i, j, tau in zip(*np.where(graph)):
            if tau == 0:
                if graph[i, j, 0] != _reverse_patt(graph[j, i, 0]):
                    raise ValueError(
                        "graph needs to have consistent lag-zero links, but "
                        " graph[%d,%d,0]=%s and graph[%d,%d,0]=%s)" %(i, j, graph[i, j, 0], j, i, graph[j, i, 0])
                    )
                if (
                    val_matrix is not None
                    and val_matrix[i, j, 0] != val_matrix[j, i, 0]
                ):
                    raise ValueError("val_matrix needs to be symmetric for lag-zero")
                if (
                    link_width is not None
                    and link_width[i, j, 0] != link_width[j, i, 0]
                ):
                    raise ValueError("link_width needs to be symmetric for lag-zero")
                if (
                    link_attribute is not None
                    and link_attribute[i, j, 0] != link_attribute[j, i, 0]
                ):
                    raise ValueError(
                        "link_attribute needs to be symmetric for lag-zero"
                    )

            if graph[i, j, tau] not in valid_patterns:
                raise ValueError("Invalid graph entry.")

    # Every non-empty entry is now a validated 3-character pattern, so this
    # cast is lossless. astype also returns a fresh array, so callers never
    # alias (and later mutate) their input.
    graph = graph.astype("<U3")

    if val_matrix is None:
        val_matrix = (graph != "").astype("int")

    if link_width is not None and not np.all(link_width >= 0.0):
        raise ValueError("link_width must be non-negative")

    return graph, val_matrix, link_width, link_attribute
3096
+
3097
+
3098
def plot_time_series_graph(
    graph,
    val_matrix=None,
    var_names=None,
    fig_ax=None,
    figsize=None,
    link_colorbar_label="MCI",
    save_name=None,
    link_width=None,
    link_attribute=None,
    arrow_linewidth=4,
    vmin_edges=-1,
    vmax_edges=1.0,
    edge_ticks=0.4,
    cmap_edges="RdBu_r",
    order=None,
    node_size=0.1,
    node_aspect=None,
    arrowhead_size=20,
    curved_radius=0.2,
    label_fontsize=10,
    tick_label_size=6,
    alpha=1.0,
    inner_edge_style="dashed",
    link_matrix=None,
    special_nodes=None,
    node_classification=None,
    # aux_graph=None,
    standard_color_links='black',
    standard_color_nodes='lightgrey',
):
    """Creates a time series graph.
    This is still in beta. The time series graph's links are colored by
    val_matrix.

    Parameters
    ----------
    graph : string or bool array-like
        Either string matrix providing graph or bool array providing only
        adjacencies. Either of shape (N, N, tau_max + 1) or as auxiliary
        graph of dims (N, N, tau_max+1, tau_max+1) describing auxADMG.
        The array passed in is never modified.
    val_matrix : array_like, optional (default: None)
        Matrix of same shape as graph containing test statistic values.
        If None, links are not colored by value.
    var_names : list, optional (default: None)
        List of variable names. If None, range(N) is used.
    fig_ax : tuple of figure and axis object, optional (default: None)
        Figure and axes instance. If None they are created.
    figsize : tuple
        Size of figure.
    save_name : str, optional (default: None)
        Name of figure file to save figure (saved with dpi=300). If None,
        the figure and axis are returned instead.
    link_colorbar_label : str, optional (default: 'MCI')
        Test statistic label.
    link_width : array-like, optional (default: None)
        Array of val_matrix.shape specifying relative link width with maximum
        given by arrow_linewidth. If None, all links have same width.
        If all entries are zero, all links get zero width.
    link_attribute : array-like, optional (default: None)
        Array of graph.shape specifying specific in drawing the graph (for
        internal use).
    order : list, optional (default: None)
        order of variables from top to bottom.
    arrow_linewidth : float, optional (default: 4)
        Linewidth.
    vmin_edges : float, optional (default: -1)
        Link colorbar scale lower bound.
    vmax_edges : float, optional (default: 1)
        Link colorbar scale upper bound.
    edge_ticks : float, optional (default: 0.4)
        Link tick mark interval.
    cmap_edges : str, optional (default: 'RdBu_r')
        Colormap for links. Ignored (set to None) if val_matrix is None.
    node_size : int, optional (default: 0.1)
        Node size.
    node_aspect : float, optional (default: None)
        Ratio between the height and width of the variable nodes.
    arrowhead_size : int, optional (default: 20)
        Size of link arrow head. Passed on to FancyArrowPatch object.
    curved_radius, float, optional (default: 0.2)
        Curvature of links. Passed on to FancyArrowPatch object.
    label_fontsize : int, optional (default: 10)
        Fontsize of colorbar labels and variable names; time labels use
        80% of this size.
    tick_label_size : int, optional (default: 6)
        Fontsize of tick labels.
    alpha : float, optional (default: 1.)
        Opacity.
    inner_edge_style : string, optional (default: 'dashed')
        Style of inner_edge contemporaneous links.
    link_matrix : deprecated
        Must be None; passing it raises a ValueError.
    special_nodes : dict
        Dictionary of format {(i, -tau): 'blue', ...} to color special nodes.
        Entries with tau < -tau_max are ignored.
    node_classification : dict or None (default: None)
        Dictionary of format {i: 'space_context', ...} to classify nodes
        into system, context, or dummy nodes. Keys of the dictionary are
        from {0, ..., N-1} where N is the number of nodes. Options for the
        values are "system", "time_context", "space_context", "time_dummy",
        or "space_dummy". Space_contexts and dummy nodes need to be
        represented as a single node in the time series graph. In case no
        value is supplied all nodes are treated as system nodes, i.e. are
        plotted in a time-resolved manner.
    standard_color_links : str, optional (default: 'black')
        Passed on to the network drawing routine as standard_color_links.
    standard_color_nodes : str, optional (default: 'lightgrey')
        Passed on to the network drawing routine as standard_color_nodes.

    Returns
    -------
    fig, ax : figure and axis
        Only if ``save_name`` is None; otherwise the figure is saved and
        None is returned.

    Raises
    ------
    ValueError
        If ``link_matrix`` is given, if ``order`` is not a permutation of
        range(N), or if the matrices fail validation.
    """
    if link_matrix is not None:
        raise ValueError("link_matrix is deprecated and replaced by graph array"
                         " which is now returned by all methods.")

    if fig_ax is None:
        fig = pyplot.figure(figsize=figsize)
        ax = fig.add_subplot(111, frame_on=False)
    else:
        fig, ax = fig_ax

    # Decide on coloring before _check_matrices, which replaces a missing
    # val_matrix by a 0/1 link indicator array.
    no_coloring = val_matrix is None
    if no_coloring:
        cmap_edges = None

    (graph, val_matrix, link_width, link_attribute) = _check_matrices(
        graph, val_matrix, link_width, link_attribute
    )

    if graph.ndim == 4:
        N, N, dummy, _ = graph.shape
    else:
        N, N, dummy = graph.shape
    tau_max = dummy - 1
    max_lag = tau_max + 1

    if np.count_nonzero(graph == "") == graph.size:
        # Nothing to draw: add a placeholder link, flagged through the
        # "no_links" edge attribute below, so that the drawing routine has an
        # edge to work with. Copy first, in case _check_matrices returned the
        # caller's own array, so that the caller's graph is never modified.
        graph = np.copy(graph)
        if graph.ndim == 4:
            graph[0, 1, 0, 0] = "---"
        else:
            graph[0, 1, 0] = "---"
        no_links = True
    else:
        no_links = False

    if var_names is None:
        var_names = range(N)

    if order is None:
        order = range(N)

    if set(order) != set(range(N)):
        raise ValueError("order must be a permutation of range(N)")

    def translate(row, lag):
        # Index of the TSG node for variable ``row`` in time slot ``lag``;
        # slot max_lag - 1 is time t, slot 0 is t - tau_max.
        return row * max_lag + lag

    # Maximum link width, computed once (used to normalize link widths).
    max_width = link_width.max() if link_width is not None else None

    def scaled_width(width):
        # The widest link gets arrow_linewidth. Guard against all-zero
        # widths, which would otherwise yield NaN (0 / 0).
        if max_width > 0:
            return width / max_width * arrow_linewidth
        return 0.0

    n_tsg = N * max_lag
    tsg = np.zeros((n_tsg, n_tsg))
    tsg_val = np.zeros((n_tsg, n_tsg))
    tsg_width = np.zeros((n_tsg, n_tsg))
    tsg_style = np.zeros((n_tsg, n_tsg), dtype=graph.dtype)
    if link_attribute is not None:
        tsg_attr = np.zeros((n_tsg, n_tsg), dtype=link_attribute.dtype)

    if graph.ndim == 4:
        # 4-dimensional graphs represent the finite-time window projection of
        # stationary 3-d graphs; they are internally created in some classes.
        for i, j, taui, tauj in np.column_stack(np.where(graph)):
            src = translate(i, max_lag - 1 - taui)
            dst = translate(j, max_lag - 1 - tauj)
            # Each link is stored in both directions; draw it only once.
            if src >= dst:
                continue
            tsg[src, dst] = 1.0
            tsg_val[src, dst] = val_matrix[i, j, taui, tauj]
            tsg_style[src, dst] = graph[i, j, taui, tauj]
            if link_width is not None:
                tsg_width[src, dst] = scaled_width(link_width[i, j, taui, tauj])
            if link_attribute is not None:
                tsg_attr[src, dst] = link_attribute[i, j, taui, tauj]
    else:
        # Lag-zero links are stored in both directions; keep only the upper
        # triangle so each is drawn once.
        link_matrix_tsg = np.copy(graph)
        link_matrix_tsg[:, :, 0] = np.triu(graph[:, :, 0])

        for i, j, tau in np.column_stack(np.where(link_matrix_tsg)):
            # Repeat the stationary link at every time slot t whose source
            # slot t - tau still lies inside the plotted window.
            for t in range(tau, max_lag):
                src = translate(i, t - tau)
                dst = translate(j, t)
                tsg[src, dst] = 1.0
                tsg_val[src, dst] = val_matrix[i, j, tau]
                tsg_style[src, dst] = graph[i, j, tau]
                if link_width is not None:
                    tsg_width[src, dst] = scaled_width(link_width[i, j, tau])
                if link_attribute is not None:
                    tsg_attr[src, dst] = link_attribute[i, j, tau]

    G = nx.DiGraph(tsg)

    if special_nodes is not None:
        special_nodes_tsg = {}
        for node in special_nodes:
            i, tau = node
            if tau >= -tau_max:
                special_nodes_tsg[translate(i, max_lag-1 + tau)] = special_nodes[node]

        special_nodes = special_nodes_tsg

    if node_classification is None:
        node_classification = {i: "system" for i in range(N)}
    # Suffix marks whether a node sits in the first, last or a middle slot.
    node_classification_tsg = {}
    for node in node_classification:
        for tau in range(max_lag):
            if tau == 0:
                suffix = "_first"
            elif tau == max_lag-1:
                suffix = "_last"
            else:
                suffix = "_middle"
            node_classification_tsg[translate(node, tau)] = node_classification[node] + suffix

    for (u, v, dic) in G.edges(data=True):
        dic["no_links"] = no_links
        if u == v:
            continue
        i, taui = u // max_lag, -(max_lag - 1 - (u % max_lag))
        j, tauj = v // max_lag, -(max_lag - 1 - (v % max_lag))

        # Links whose variable indices and time slots each differ by at most
        # one are drawn as "inner" edges, all others as "outer" edges.
        if abs(i - j) <= 1 and abs(tauj - taui) <= 1:
            inout = "inner"
        else:
            inout = "outer"
        dic["inner_edge"] = inout == "inner"
        dic["outer_edge"] = inout == "outer"

        dic[f"{inout}_edge_type"] = tsg_style[u, v]
        dic[f"{inout}_edge_alpha"] = alpha
        dic[f"{inout}_edge_width"] = (
            arrow_linewidth if link_width is None else tsg_width[u, v]
        )
        dic[f"{inout}_edge_attribute"] = (
            None if link_attribute is None else tsg_attr[u, v]
        )
        dic[f"{inout}_edge_color"] = None if no_coloring else tsg_val[u, v]
        dic["label"] = None

    # Grid positions: x = time slot, y = variable row; normalized to [0, 1].
    posarray = np.zeros((n_tsg, 2))
    for node in range(n_tsg):
        posarray[node] = [node % max_lag, 1.0 - node // max_lag]
    pos_min = posarray.min(axis=0)
    pos_range = posarray.max(axis=0) - pos_min

    pos_tmp = {}
    with np.errstate(divide="ignore", invalid="ignore"):
        for node in range(n_tsg):
            node_pos = (posarray[node] - pos_min) / pos_range
            # A zero range (one time slot or one variable) gives 0 / 0 = NaN.
            node_pos[np.isnan(node_pos)] = 0.0
            pos_tmp[node] = node_pos

    pos = {}
    for n in range(N):
        for tau in range(max_lag):
            pos[n * max_lag + tau] = pos_tmp[order[n] * max_lag + tau]

    node_rings = {
        0: {"sizes": None, "color_array": None, "label": "", "colorbar": False,}
    }

    node_labels = ["" for i in range(n_tsg)]

    # Without link values there is nothing to show on a colorbar for the
    # auxiliary (4-d) graph. Testing val_matrix here would never trigger,
    # because _check_matrices always returns an array.
    show_colorbar = not (graph.ndim == 4 and no_coloring)

    _draw_network_with_curved_edges(
        fig=fig,
        ax=ax,
        G=deepcopy(G),
        pos=pos,
        node_rings=node_rings,
        node_labels=node_labels,
        node_alpha=alpha,
        standard_size=node_size,
        node_aspect=node_aspect,
        standard_cmap="OrRd",
        standard_color_nodes=standard_color_nodes,
        standard_color_links=standard_color_links,
        log_sizes=False,
        cmap_links=cmap_edges,
        links_vmin=vmin_edges,
        links_vmax=vmax_edges,
        links_ticks=edge_ticks,
        arrowstyle="simple",
        arrowhead_size=arrowhead_size,
        curved_radius=curved_radius,
        label_fontsize=label_fontsize,
        tick_label_size=tick_label_size,
        label_fraction=0.5,
        link_colorbar_label=link_colorbar_label,
        inner_edge_curved=False,
        inner_edge_style=inner_edge_style,
        special_nodes=special_nodes,
        show_colorbar=show_colorbar,
        node_classification=node_classification_tsg,
        max_lag=max_lag,
    )

    # Variable names to the left of each row.
    row_trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)
    for i in range(N):
        ax.text(
            0.,
            pos[order[i] * max_lag][1],
            f"{var_names[order[i]]}",
            fontsize=label_fontsize,
            horizontalalignment="right",
            verticalalignment="center",
            transform=row_trans,
        )

    # Time labels above each column.
    col_trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    for tau in np.arange(max_lag - 1, -1, -1):
        if tau == max_lag - 1:
            time_label = r"$t$"
        else:
            time_label = r"$t-%s$" % str(max_lag - tau - 1)
        ax.text(
            pos[tau][0],
            1.0,
            time_label,
            fontsize=int(label_fontsize * 0.8),
            horizontalalignment="center",
            verticalalignment="bottom",
            transform=col_trans,
        )

    if save_name is not None:
        pyplot.savefig(save_name, dpi=300)
    else:
        return fig, ax
3500
+
3501
+
3502
def plot_mediation_time_series_graph(
    path_node_array,
    tsg_path_val_matrix,
    var_names=None,
    fig_ax=None,
    figsize=None,
    link_colorbar_label="link coeff. (edge color)",
    node_colorbar_label="MCE (node color)",
    save_name=None,
    link_width=None,
    arrow_linewidth=8,
    vmin_edges=-1,
    vmax_edges=1.0,
    edge_ticks=0.4,
    cmap_edges="RdBu_r",
    order=None,
    vmin_nodes=-1.0,
    vmax_nodes=1.0,
    node_ticks=0.4,
    cmap_nodes="RdBu_r",
    node_size=0.1,
    node_aspect=None,
    arrowhead_size=20,
    curved_radius=0.2,
    label_fontsize=12,
    alpha=1.0,
    node_label_size=12,
    tick_label_size=6,
    standard_color_links='black',
    standard_color_nodes='lightgrey',
):
    """Creates a mediation time series graph plot.
    This is still in beta. The time series graph's links are colored by
    val_matrix.

    Parameters
    ----------
    path_node_array: array_like
        Array of shape (N,) containing node values.
    tsg_path_val_matrix : array_like
        Matrix of shape (N*(tau_max+1), N*(tau_max+1)) containing link
        weight values; node n*(tau_max+1) + k is variable n in time slot k
        (slot tau_max is time t). The array passed in is never modified.
    var_names : list, optional (default: None)
        List of variable names. If None, range(N) is used.
    fig_ax : tuple of figure and axis object, optional (default: None)
        Figure and axes instance. If None they are created.
    figsize : tuple
        Size of figure.
    save_name : str, optional (default: None)
        Name of figure file to save figure. If None, figure is shown in window.
    link_colorbar_label : str, optional (default: 'link coeff. (edge color)')
        Link colorbar label.
    node_colorbar_label : str, optional (default: 'MCE (node color)')
        Node colorbar label.
    link_width : array-like, optional (default: None)
        Only checked for non-negativity; it is currently not used for
        drawing, all links are drawn with arrow_linewidth.
    order : list, optional (default: None)
        order of variables from top to bottom.
    arrow_linewidth : float, optional (default: 8)
        Linewidth.
    vmin_edges : float, optional (default: -1)
        Link colorbar scale lower bound.
    vmax_edges : float, optional (default: 1)
        Link colorbar scale upper bound.
    edge_ticks : float, optional (default: 0.4)
        Link tick mark interval.
    cmap_edges : str, optional (default: 'RdBu_r')
        Colormap for links.
    vmin_nodes : float, optional (default: -1)
        Node colorbar scale lower bound.
    vmax_nodes : float, optional (default: 1)
        Node colorbar scale upper bound.
    node_ticks : float, optional (default: 0.4)
        Node tick mark interval.
    cmap_nodes : str, optional (default: 'RdBu_r')
        Colormap for nodes.
    node_size : int, optional (default: 0.1)
        Node size.
    node_aspect : float, optional (default: None)
        Ratio between the height and width of the variable nodes.
    arrowhead_size : int, optional (default: 20)
        Size of link arrow head. Passed on to FancyArrowPatch object.
    curved_radius, float, optional (default: 0.2)
        Curvature of links. Passed on to FancyArrowPatch object.
    label_fontsize : int, optional (default: 12)
        Fontsize of colorbar labels, variable names and time labels.
    alpha : float, optional (default: 1.)
        Opacity.
    node_label_size : int, optional (default: 12)
        Fontsize of node labels.
    tick_label_size : int, optional (default: 6)
        Fontsize of tick labels.
    standard_color_links : str, optional (default: 'black')
        Passed on to the network drawing routine as standard_color_links.
    standard_color_nodes : str, optional (default: 'lightgrey')
        Passed on to the network drawing routine as standard_color_nodes.

    Raises
    ------
    ValueError
        If ``link_width`` has negative entries or ``order`` is not a
        permutation of range(N).
    """
    N = len(path_node_array)
    # Work on a copy: a placeholder link may be inserted below and the
    # caller's matrix must not be modified.
    tsg = np.array(tsg_path_val_matrix)
    max_lag = tsg.shape[0] // N

    if var_names is None:
        var_names = range(N)

    if fig_ax is None:
        fig = pyplot.figure(figsize=figsize)
        ax = fig.add_subplot(111, frame_on=False)
    else:
        fig, ax = fig_ax

    if link_width is not None and not np.all(link_width >= 0.0):
        raise ValueError("link_width must be non-negative")

    if order is None:
        order = range(N)

    if set(order) != set(range(N)):
        raise ValueError("order must be a permutation of range(N)")

    # No drawable links when all nonzero entries (if any) are self-loops on
    # the diagonal. (The former extra test "all entries nonzero" wrongly
    # flagged fully dense matrices as having no links.)
    no_links = bool(np.count_nonzero(tsg) == np.count_nonzero(np.diagonal(tsg)))
    if no_links:
        # Placeholder so the drawing routine has an edge; flagged via the
        # "no_links" edge attribute.
        tsg[0, 1] = 1

    G = nx.DiGraph(tsg)

    for (u, v, dic) in G.edges(data=True):
        dic["no_links"] = no_links
        dic["outer_edge_attribute"] = None

        if u != v:
            # Same time slot -> contemporaneous (inner) edge.
            contemporaneous = u % max_lag == v % max_lag
            dic["inner_edge"] = contemporaneous
            dic["outer_edge"] = not contemporaneous

            dic["inner_edge_alpha"] = alpha
            dic["inner_edge_color"] = _get_absmax(
                np.array([[[tsg[u, v], tsg[v, u]]]])
            ).squeeze()
            dic["inner_edge_width"] = arrow_linewidth

            dic["outer_edge_alpha"] = alpha
            dic["outer_edge_width"] = arrow_linewidth
            dic["outer_edge_color"] = tsg[u, v]
            dic["label"] = None

    # Grid positions: x = time slot, y = variable row; normalized to [0, 1].
    posarray = np.zeros((N * max_lag, 2))
    for node in range(N * max_lag):
        posarray[node] = [node % max_lag, 1.0 - node // max_lag]
    pos_min = posarray.min(axis=0)
    pos_range = posarray.max(axis=0) - pos_min

    pos_tmp = {}
    with np.errstate(divide="ignore", invalid="ignore"):
        for node in range(N * max_lag):
            node_pos = (posarray[node] - pos_min) / pos_range
            # A zero range (one time slot or one variable) gives 0 / 0 = NaN.
            node_pos[np.isnan(node_pos)] = 0.0
            pos_tmp[node] = node_pos

    pos = {}
    for n in range(N):
        for tau in range(max_lag):
            pos[n * max_lag + tau] = pos_tmp[order[n] * max_lag + tau]

    # Every time slot of variable n gets the node value path_node_array[n].
    node_color = np.zeros(N * max_lag)
    for inet in range(N):
        node_color[inet * max_lag:(inet + 1) * max_lag] = path_node_array[inet]

    node_rings = {
        0: {
            "sizes": None,
            "color_array": node_color,
            "cmap": cmap_nodes,
            "vmin": vmin_nodes,
            "vmax": vmax_nodes,
            "ticks": node_ticks,
            "label": node_colorbar_label,
            "colorbar": True,
        }
    }

    node_labels = ["" for i in range(N * max_lag)]

    _draw_network_with_curved_edges(
        fig=fig,
        ax=ax,
        G=deepcopy(G),
        pos=pos,
        node_rings=node_rings,
        node_labels=node_labels,
        node_label_size=node_label_size,
        node_alpha=alpha,
        standard_size=node_size,
        node_aspect=node_aspect,
        standard_cmap="OrRd",
        standard_color_nodes=standard_color_nodes,
        standard_color_links=standard_color_links,
        log_sizes=False,
        cmap_links=cmap_edges,
        links_vmin=vmin_edges,
        links_vmax=vmax_edges,
        links_ticks=edge_ticks,
        tick_label_size=tick_label_size,
        arrowhead_size=arrowhead_size,
        curved_radius=curved_radius,
        label_fontsize=label_fontsize,
        label_fraction=0.5,
        link_colorbar_label=link_colorbar_label,
        inner_edge_curved=True,
    )

    # Variable names to the left of each row.
    row_trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)
    for i in range(N):
        ax.text(
            0.,
            pos[order[i] * max_lag][1],
            "%s" % str(var_names[order[i]]),
            fontsize=label_fontsize,
            horizontalalignment="right",
            verticalalignment="center",
            transform=row_trans,
        )

    # Time labels above each column.
    col_trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    for tau in np.arange(max_lag - 1, -1, -1):
        if tau == max_lag - 1:
            time_label = r"$t$"
        else:
            time_label = r"$t-%s$" % str(max_lag - tau - 1)
        ax.text(
            pos[tau][0],
            1.0,
            time_label,
            fontsize=label_fontsize,
            horizontalalignment="center",
            verticalalignment="bottom",
            transform=col_trans,
        )

    if save_name is not None:
        pyplot.savefig(save_name)
    else:
        pyplot.show()
3809
+
3810
+
3811
+ def plot_mediation_graph(
3812
+ path_val_matrix,
3813
+ path_node_array=None,
3814
+ var_names=None,
3815
+ fig_ax=None,
3816
+ figsize=None,
3817
+ save_name=None,
3818
+ link_colorbar_label="link coeff. (edge color)",
3819
+ node_colorbar_label="MCE (node color)",
3820
+ link_width=None,
3821
+ node_pos=None,
3822
+ arrow_linewidth=10.0,
3823
+ vmin_edges=-1,
3824
+ vmax_edges=1.0,
3825
+ edge_ticks=0.4,
3826
+ cmap_edges="RdBu_r",
3827
+ vmin_nodes=-1.0,
3828
+ vmax_nodes=1.0,
3829
+ node_ticks=0.4,
3830
+ cmap_nodes="RdBu_r",
3831
+ node_size=0.3,
3832
+ node_aspect=None,
3833
+ arrowhead_size=20,
3834
+ curved_radius=0.2,
3835
+ label_fontsize=10,
3836
+ tick_label_size=6,
3837
+ lag_array=None,
3838
+ alpha=1.0,
3839
+ node_label_size=10,
3840
+ link_label_fontsize=10,
3841
+ # network_lower_bound=0.2,
3842
+ standard_color_links='black',
3843
+ standard_color_nodes='lightgrey',
3844
+ ):
3845
+ """Creates a network plot visualizing the pathways of a mediation analysis.
3846
+ This is still in beta. The network is defined from non-zero entries in
3847
+ ``path_val_matrix``. Nodes denote variables, straight links contemporaneous
3848
+ dependencies and curved arrows lagged dependencies. The node color denotes
3849
+ the mediated causal effect (MCE) and the link color the value at the lag
3850
+ with maximal link coefficient. The link label lists the lags with
3851
+ significant dependency in order of absolute magnitude. The network can also
3852
+ be plotted over a map drawn before on the same axis. Then the node positions
3853
+ can be supplied in appropriate axis coordinates via node_pos.
3854
+
3855
+ Parameters
3856
+ ----------
3857
+ path_val_matrix : array_like
3858
+ Matrix of shape (N, N, tau_max+1) containing link weight values.
3859
+ path_node_array: array_like
3860
+ Array of shape (N,) containing node values.
3861
+ var_names : list, optional (default: None)
3862
+ List of variable names. If None, range(N) is used.
3863
+ fig_ax : tuple of figure and axis object, optional (default: None)
3864
+ Figure and axes instance. If None they are created.
3865
+ figsize : tuple
3866
+ Size of figure.
3867
+ save_name : str, optional (default: None)
3868
+ Name of figure file to save figure. If None, figure is shown in window.
3869
+ link_colorbar_label : str, optional (default: 'link coeff. (edge color)')
3870
+ Link colorbar label.
3871
+ node_colorbar_label : str, optional (default: 'MCE (node color)')
3872
+ Node colorbar label.
3873
+ link_width : array-like, optional (default: None)
3874
+ Array of val_matrix.shape specifying relative link width with maximum
3875
+ given by arrow_linewidth. If None, all links have same width.
3876
+ node_pos : dictionary, optional (default: None)
3877
+ Dictionary of node positions in axis coordinates of form
3878
+ node_pos = {'x':array of shape (N,), 'y':array of shape(N)}. These
3879
+ coordinates could have been transformed before for basemap plots. You can
3880
+ also add a key 'transform':ccrs.PlateCarree() in order to plot graphs on
3881
+ a map using cartopy.
3882
+ arrow_linewidth : float, optional (default: 30)
3883
+ Linewidth.
3884
+ vmin_edges : float, optional (default: -1)
3885
+ Link colorbar scale lower bound.
3886
+ vmax_edges : float, optional (default: 1)
3887
+ Link colorbar scale upper bound.
3888
+ edge_ticks : float, optional (default: 0.4)
3889
+ Link tick mark interval.
3890
+ cmap_edges : str, optional (default: 'RdBu_r')
3891
+ Colormap for links.
3892
+ vmin_nodes : float, optional (default: 0)
3893
+ Node colorbar scale lower bound.
3894
+ vmax_nodes : float, optional (default: 1)
3895
+ Node colorbar scale upper bound.
3896
+ node_ticks : float, optional (default: 0.4)
3897
+ Node tick mark interval.
3898
+ cmap_nodes : str, optional (default: 'OrRd')
3899
+ Colormap for links.
3900
+ node_size : int, optional (default: 0.3)
3901
+ Node size.
3902
+ node_aspect : float, optional (default: None)
3903
+ Ratio between the heigth and width of the varible nodes.
3904
+ arrowhead_size : int, optional (default: 20)
3905
+ Size of link arrow head. Passed on to FancyArrowPatch object.
3906
+ curved_radius, float, optional (default: 0.2)
3907
+ Curvature of links. Passed on to FancyArrowPatch object.
3908
+ label_fontsize : int, optional (default: 10)
3909
+ Fontsize of colorbar labels.
3910
+ alpha : float, optional (default: 1.)
3911
+ Opacity.
3912
+ node_label_size : int, optional (default: 10)
3913
+ Fontsize of node labels.
3914
+ link_label_fontsize : int, optional (default: 6)
3915
+ Fontsize of link labels.
3916
+ lag_array : array, optional (default: None)
3917
+ Optional specification of lags overwriting np.arange(0, tau_max+1)
3918
+ """
3919
+ val_matrix = path_val_matrix
3920
+
3921
+ if fig_ax is None:
3922
+ fig = pyplot.figure(figsize=figsize)
3923
+ ax = fig.add_subplot(111, frame_on=False)
3924
+ else:
3925
+ fig, ax = fig_ax
3926
+
3927
+ if link_width is not None and not np.all(link_width >= 0.0):
3928
+ raise ValueError("link_width must be non-negative")
3929
+
3930
+ N, N, dummy = val_matrix.shape
3931
+ tau_max = dummy - 1
3932
+
3933
+ if np.count_nonzero(val_matrix) == np.count_nonzero(np.diagonal(val_matrix)):
3934
+ diagonal = True
3935
+ else:
3936
+ diagonal = False
3937
+
3938
+ if np.count_nonzero(val_matrix) == val_matrix.size or diagonal:
3939
+ val_matrix[0, 1, 0] = 1
3940
+ no_links = True
3941
+ else:
3942
+ no_links = False
3943
+
3944
+ if var_names is None:
3945
+ var_names = range(N)
3946
+
3947
+ # Define graph links by absolute maximum (positive or negative like for
3948
+ # partial correlation)
3949
+ # val_matrix[np.abs(val_matrix) < sig_thres] = 0.
3950
+ graph = val_matrix != 0.0
3951
+ net = _get_absmax(val_matrix)
3952
+ G = nx.DiGraph(net)
3953
+
3954
+ node_color = np.zeros(N)
3955
+ # list of all strengths for color map
3956
+ all_strengths = []
3957
+ # Add attributes, contemporaneous and lagged links are handled separately
3958
+ for (u, v, dic) in G.edges(data=True):
3959
+ dic["outer_edge_attribute"] = None
3960
+ dic["no_links"] = no_links
3961
+ # average lagfunc for link u --> v ANDOR u -- v
3962
+ if tau_max > 0:
3963
+ # argmax of absolute maximum
3964
+ argmax = np.abs(val_matrix[u, v][1:]).argmax() + 1
3965
+ else:
3966
+ argmax = 0
3967
+ if u != v:
3968
+ # For contemp links masking or finite samples can lead to different
3969
+ # values for u--v and v--u
3970
+ # Here we use the maximum for the width and weight (=color)
3971
+ # of the link
3972
+ # Draw link if u--v OR v--u at lag 0 is nonzero
3973
+ # dic['inner_edge'] = ((np.abs(val_matrix[u, v][0]) >=
3974
+ # sig_thres[u, v][0]) or
3975
+ # (np.abs(val_matrix[v, u][0]) >=
3976
+ # sig_thres[v, u][0]))
3977
+ dic["inner_edge"] = graph[u, v, 0] or graph[v, u, 0]
3978
+ dic["inner_edge_alpha"] = alpha
3979
+ # value at argmax of average
3980
+ if np.abs(val_matrix[u, v][0] - val_matrix[v, u][0]) > 0.0001:
3981
+ print(
3982
+ "Contemporaneous I(%d; %d)=%.3f != I(%d; %d)=%.3f"
3983
+ % (u, v, val_matrix[u, v][0], v, u, val_matrix[v, u][0])
3984
+ + " due to conditions, finite sample effects or "
3985
+ "masking, here edge color = "
3986
+ "larger (absolute) value."
3987
+ )
3988
+ dic["inner_edge_color"] = _get_absmax(
3989
+ np.array([[[val_matrix[u, v][0], val_matrix[v, u][0]]]])
3990
+ ).squeeze()
3991
+ if link_width is None:
3992
+ dic["inner_edge_width"] = arrow_linewidth
3993
+ else:
3994
+ dic["inner_edge_width"] = (
3995
+ link_width[u, v, 0] / link_width.max() * arrow_linewidth
3996
+ )
3997
+
3998
+ all_strengths.append(dic["inner_edge_color"])
3999
+
4000
+ if tau_max > 0:
4001
+ # True if ensemble mean at lags > 0 is nonzero
4002
+ # dic['outer_edge'] = np.any(
4003
+ # np.abs(val_matrix[u, v][1:]) >= sig_thres[u, v][1:])
4004
+ dic["outer_edge"] = np.any(graph[u, v, 1:])
4005
+ else:
4006
+ dic["outer_edge"] = False
4007
+ dic["outer_edge_alpha"] = alpha
4008
+ if link_width is None:
4009
+ # fraction of nonzero values
4010
+ dic["outer_edge_width"] = arrow_linewidth
4011
+ else:
4012
+ dic["outer_edge_width"] = (
4013
+ link_width[u, v, argmax] / link_width.max() * arrow_linewidth
4014
+ )
4015
+
4016
+ # value at argmax of average
4017
+ dic["outer_edge_color"] = val_matrix[u, v][argmax]
4018
+ all_strengths.append(dic["outer_edge_color"])
4019
+
4020
+ # Sorted list of significant lags (only if robust wrt
4021
+ # d['min_ensemble_frac'])
4022
+ if tau_max > 0:
4023
+ lags = np.abs(val_matrix[u, v][1:]).argsort()[::-1] + 1
4024
+ sig_lags = (np.where(graph[u, v, 1:])[0] + 1).tolist()
4025
+ else:
4026
+ lags, sig_lags = [], []
4027
+ if lag_array is not None:
4028
+ dic["label"] = str([lag_array[l] for l in lags if l in sig_lags])[1:-1].replace(" ", "")
4029
+ else:
4030
+ dic["label"] = str([l for l in lags if l in sig_lags])[1:-1].replace(" ", "")
4031
+ else:
4032
+ # Node color is max of average autodependency
4033
+ node_color[u] = val_matrix[u, v][argmax]
4034
+
4035
+ # dic['outer_edge_edge'] = False
4036
+ # dic['outer_edge_edgecolor'] = None
4037
+ # dic['inner_edge_edge'] = False
4038
+ # dic['inner_edge_edgecolor'] = None
4039
+
4040
+ node_color = path_node_array
4041
+ # print node_color
4042
+ # If no links are present, set value to zero
4043
+ if len(all_strengths) == 0:
4044
+ all_strengths = [0.0]
4045
+
4046
+ if node_pos is None:
4047
+ pos = nx.circular_layout(deepcopy(G))
4048
+ # pos = nx.spring_layout(deepcopy(G))
4049
+ else:
4050
+ pos = {}
4051
+ for i in range(N):
4052
+ pos[i] = (node_pos["x"][i], node_pos["y"][i])
4053
+
4054
+ if node_pos is not None and 'transform' in node_pos:
4055
+ transform = node_pos['transform']
4056
+ else: transform = ax.transData
4057
+
4058
+ node_rings = {
4059
+ 0: {
4060
+ "sizes": None,
4061
+ "color_array": node_color,
4062
+ "cmap": cmap_nodes,
4063
+ "vmin": vmin_nodes,
4064
+ "vmax": vmax_nodes,
4065
+ "ticks": node_ticks,
4066
+ "label": node_colorbar_label,
4067
+ "colorbar": True,
4068
+ }
4069
+ }
4070
+
4071
+ _draw_network_with_curved_edges(
4072
+ fig=fig,
4073
+ ax=ax,
4074
+ G=deepcopy(G),
4075
+ pos=pos,
4076
+ # dictionary of rings: {0:{'sizes':(N,)-array, 'color_array':(N,)-array
4077
+ # or None, 'cmap':string,
4078
+ node_rings=node_rings,
4079
+ # 'vmin':float or None, 'vmax':float or None, 'label':string or None}}
4080
+ node_labels=var_names,
4081
+ node_label_size=node_label_size,
4082
+ node_alpha=alpha,
4083
+ standard_size=node_size,
4084
+ node_aspect=node_aspect,
4085
+ standard_cmap="OrRd",
4086
+ standard_color_nodes=standard_color_nodes,
4087
+ standard_color_links=standard_color_links,
4088
+ log_sizes=False,
4089
+ cmap_links=cmap_edges,
4090
+ links_vmin=vmin_edges,
4091
+ links_vmax=vmax_edges,
4092
+ links_ticks=edge_ticks,
4093
+ tick_label_size=tick_label_size,
4094
+ # cmap_links_edges='YlOrRd', links_edges_vmin=-1., links_edges_vmax=1.,
4095
+ # links_edges_ticks=.2, link_edge_colorbar_label='link_edge',
4096
+ arrowhead_size=arrowhead_size,
4097
+ curved_radius=curved_radius,
4098
+ label_fontsize=label_fontsize,
4099
+ link_label_fontsize=link_label_fontsize,
4100
+ link_colorbar_label=link_colorbar_label,
4101
+ # network_lower_bound=network_lower_bound,
4102
+ # label_fraction=label_fraction,
4103
+ # inner_edge_style=inner_edge_style
4104
+ transform=transform
4105
+ )
4106
+
4107
+ # fig.subplots_adjust(left=0.1, right=.9, bottom=.25, top=.95)
4108
+ # savestring = os.path.expanduser(save_name)
4109
+ if save_name is not None:
4110
+ pyplot.savefig(save_name)
4111
+ else:
4112
+ pyplot.show()
4113
+
4114
+
4115
#
# Functions to plot time series graphs from links including ancestors
#
def plot_tsg(links, X, Y, Z=None, anc_x=None, anc_y=None, anc_xy=None):
    """Build the time series graph (TSG) of ``links`` and plot it.

    The TSG is an explicit adjacency matrix of shape (N*max_lag, N*max_lag).
    Compared to the tigramite plotting function, links
    X^i_{t-tau} --> X^j_t can therefore be missing for different t'. This is
    helpful to visualize the conditioned TSG.

    Nodes are laid out on a grid with one row per variable (variable 0 on
    top) and one column per time point, labelled t-max_lag+1, ..., t from
    left to right. Node color codes index into the color list
    ['lightgrey', 'grey', 'black', 'red', 'blue', 'orange'] (passed with
    vmin=0, vmax=6):
    0 for unmarked nodes, 3 for anc_x, 4 for anc_y, 5 for anc_xy, 2 for X
    and Y, 1 for Z. The codes are assigned in that order, so later groups
    overwrite earlier ones. Two nodes linked in both directions are drawn
    with a dashed inner edge; all other links are drawn as outer edges.

    Parameters
    ----------
    links : dict or list
        Indexable by j in range(N). links[j] is an iterable of link entries
        whose first element is an (i, lag) pair and whose second element is
        the link coefficient. Entries with zero coefficient are ignored.
    X, Y : list of tuples
        (var, lag) nodes. lag is expected to be <= 0.
    Z : list of tuples, optional (default: None)
        (var, lag) conditioning nodes.
    anc_x, anc_y, anc_xy : list of tuples, optional (default: None)
        (var, lag) ancestor nodes.

    Returns
    -------
    fig, ax : matplotlib figure and axis holding the plot.
    """
    palette = ["lightgrey", "grey", "black", "red", "blue", "orange"]
    node_cmap = ListedColormap(palette)

    n_vars = len(links)

    def _max_abs_lag(link_dict):
        """Largest |lag| over links with nonzero coefficient (0 if none)."""
        largest = 0
        for j in range(len(link_dict)):
            for entry in link_dict[j]:
                _, lag = entry[0]
                if entry[1] != 0.0:
                    largest = max(largest, abs(lag))
        return largest

    # Number of time columns: the largest lag among the links and all
    # highlighted nodes, plus one column for lag zero.
    max_lag = _max_abs_lag(links)
    for group in (X + Y, Z, anc_x, anc_y, anc_xy):
        if group is not None:
            for node in group:
                max_lag = max(max_lag, abs(node[1]))
    max_lag += 1

    n_nodes = n_vars * max_lag

    def node_index(var, column):
        """Flat TSG index of variable ``var`` in time column ``column``."""
        return var * max_lag + column

    # Adjacency matrix of the TSG. A link i --> j at lag tau is repeated for
    # every time column t whose source column t - tau lies in the window.
    tsg = np.zeros((n_nodes, n_nodes))
    for j in range(n_vars):
        for entry in links[j]:
            i, lag = entry[0]
            tau = abs(lag)
            if entry[1] != 0.0:
                for t in range(max_lag):
                    source = node_index(i, t - tau)
                    target = node_index(j, t)
                    if 0 <= source and source % max_lag <= target % max_lag:
                        tsg[source, target] = 1.0

    G = nx.DiGraph(tsg)

    # Color codes per role; groups later in the tuple overwrite earlier ones.
    node_color = np.zeros(n_nodes)
    roles = ((anc_x, 3), (anc_y, 4), (anc_xy, 5), (X, 2), (Y, 2), (Z, 1))
    for group, code in roles:
        if group is None:
            continue
        for node in group:
            node_color[node_index(node[0], max_lag - 1 + node[1])] = code

    fig = pyplot.figure(figsize=(3, 3))
    ax = fig.add_subplot(111, frame_on=False)

    arrow_linewidth = 8.0
    alpha = 1.0
    label_fontsize = 10

    for u, v, attrs in G.edges(data=True):
        if u == v:
            continue
        weight = tsg[u, v]
        # Links present in both directions are drawn as inner edges only.
        both_ways = bool(tsg[u, v] and tsg[v, u])
        attrs.update({
            "inner_edge": both_ways,
            "outer_edge": not both_ways,
            "inner_edge_alpha": alpha,
            "inner_edge_color": weight,
            "inner_edge_width": arrow_linewidth,
            "inner_edge_attribute": None,
            "outer_edge_attribute": None,
            "outer_edge_alpha": alpha,
            "outer_edge_width": arrow_linewidth,
            "outer_edge_color": weight,
            "label": None,
        })

    # Grid coordinates (column, 1 - row), rescaled to [0, 1] per axis; a
    # degenerate axis (single row or column) gives 0/0 and is mapped to 0.
    flat = np.arange(n_nodes)
    grid = np.column_stack([flat % max_lag, 1.0 - flat // max_lag])
    lower = grid.min(axis=0)
    span = grid.max(axis=0) - lower
    scaled = (grid - lower) / span
    scaled[np.isnan(scaled)] = 0.0
    pos = {node: scaled[node].copy() for node in range(n_nodes)}

    node_rings = {
        0: {
            "sizes": None,
            "color_array": node_color,
            "label": "",
            "colorbar": False,
            "cmap": node_cmap,
            "vmin": 0,
            "vmax": len(palette),
        }
    }

    _draw_network_with_curved_edges(
        fig=fig,
        ax=ax,
        G=deepcopy(G),
        pos=pos,
        node_rings=node_rings,
        node_labels=[""] * n_nodes,
        node_label_size=10,
        node_alpha=alpha,
        standard_size=0.1,
        node_aspect=None,
        standard_cmap="OrRd",
        standard_color_links="black",
        standard_color_nodes="lightgrey",
        log_sizes=False,
        cmap_links="RdBu_r",
        links_vmin=-1,
        links_vmax=1.0,
        links_ticks=0.4,
        arrowstyle="simple",
        arrowhead_size=20,
        curved_radius=0.2,
        label_fontsize=label_fontsize,
        label_fraction=0.5,
        link_colorbar_label="MCI",
        inner_edge_curved=True,
        inner_edge_style="dashed",
    )

    # Variable names at the left edge of each row.
    row_transform = transforms.blended_transform_factory(ax.transAxes, ax.transData)
    for var in range(n_vars):
        ax.text(
            0.,
            pos[var * max_lag][1],
            str(var),
            fontsize=label_fontsize,
            horizontalalignment="right",
            verticalalignment="center",
            transform=row_transform,
        )

    # Time labels above each column, from t (rightmost) leftwards.
    column_transform = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    for lag in range(max_lag):
        column = max_lag - 1 - lag
        label = r"$t$" if lag == 0 else r"$t-%s$" % str(lag)
        ax.text(
            pos[column][0],
            1.0,
            label,
            fontsize=int(label_fontsize * 0.7),
            horizontalalignment="center",
            verticalalignment="bottom",
            transform=column_transform,
        )

    return fig, ax
4400
+
4401
def write_csv(
    graph,
    save_name,
    val_matrix=None,
    var_names=None,
    link_width=None,
    link_attribute=None,
    digits=5,
):
    """Write all links in a graph to a csv file.

    Each link is written as one row with the columns 'Variable i',
    'Variable j', 'Time lag of i', 'Link type i --- j'. These are followed by
    the optional columns 'Link value', 'Link attribute' and 'Link width' for
    whichever of val_matrix, link_attribute and link_width are given.
    Contemporaneous links (lag 0) appear as both (i, j) and (j, i) in the
    graph and are written only once, for i <= j.

    Parameters
    ----------
    graph : string or bool array-like
        Graph of shape (N, N, tau_max+1) or (N, N). Entries equal to '' are
        treated as "no link". Must be of same shape as val_matrix.
        NOTE(review): links are selected via ``graph != ''``, so a bool
        adjacency array is only handled as intended if _check_matrices
        converts it. Confirm.
    save_name : str
        Path of the csv file to write. An existing file is overwritten.
    val_matrix : array_like, optional (default: None)
        Matrix of shape (N, N, tau_max+1) containing test statistic values.
        If None, no 'Link value' column is written.
    var_names : list, optional (default: None)
        List of variable names. If None, range(N) is used.
    link_width : array-like, optional (default: None)
        Array of val_matrix.shape with link widths. Written as an extra
        column if given.
    link_attribute : array-like, optional (default: None)
        String array of val_matrix.shape specifying link attributes.
    digits : int, optional (default: 5)
        Number of significant digits for writing link value and width.

    Raises
    ------
    ValueError
        If graph is a time series graph of shape
        (N, N, tau_max+1, tau_max+1).
    """

    graph = np.copy(graph.squeeze())

    N = len(graph)

    # val_matrix is rebound by _check_matrices below, so record now whether
    # the caller actually supplied one (previously `false` -> NameError).
    val_matrix_exists = val_matrix is not None

    if graph.ndim == 4:
        raise ValueError("Time series graph of shape (N,N,tau_max+1,tau_max+1)"
                         " cannot be written by write_csv.")

    if graph.ndim == 2:
        # If a non-time series (N,N)-graph is given, insert a dummy dimension
        graph = np.expand_dims(graph, axis=2)

    (graph, val_matrix, link_width, link_attribute) = _check_matrices(
        graph, val_matrix, link_width, link_attribute)

    if var_names is None:
        var_names = range(N)

    header = ['Variable i', 'Variable j', 'Time lag of i', 'Link type i --- j']
    if val_matrix_exists:
        header.append('Link value')
    if link_attribute is not None:
        header.append('Link attribute')
    if link_width is not None:
        header.append('Link width')

    # newline='' is required by the csv module to avoid blank lines on Windows.
    with open(save_name, 'w', encoding='UTF8', newline='') as f:
        writer = csv.writer(f)

        # write the header
        writer.writerow(header)

        # write the link data
        for (i, j, tau) in zip(*np.where(graph != '')):
            # Only consider contemporaneous links once
            if tau == 0 and i > j:
                continue
            row = [str(var_names[i]), str(var_names[j]), f"{tau}", graph[i, j, tau]]
            if val_matrix_exists:
                row.append(f"{val_matrix[i, j, tau]:.{digits}}")
            if link_attribute is not None:
                row.append(link_attribute[i, j, tau])
            if link_width is not None:
                row.append(f"{link_width[i, j, tau]:.{digits}}")

            writer.writerow(row)
4487
+
4488
+
4489
if __name__ == "__main__":
    # Manual smoke test: render a small three-variable graph to test.pdf.
    import sys

    matplotlib.rc('xtick', labelsize=6)
    matplotlib.rc('ytick', labelsize=6)

    # Toy-data tooling; not used by the smoke test below.
    import tigramite
    import tigramite.toymodels.structural_causal_processes as toys
    import tigramite.data_processing as pp
    from tigramite.causal_effects import CausalEffects

    # Variable 0 drives variable 1 at lags 0, 1 and 2. The contemporaneous
    # link is stored in both orientations ("-->" and its mirror "<--").
    graph = np.zeros((3, 3, 3), dtype='<U3')
    graph[0, 1, :] = "-->"
    graph[1, 0, 0] = "<--"

    plot_graph(graph=graph, save_name="test.pdf")