synthetic-graph-benchmarks 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1230 @@
1
+ ###############################################################################
2
+ #
3
+ # Adapted from https://github.com/lrjconan/GRAN/ which in turn is adapted from https://github.com/JiaxuanYou/graph-generation
4
+ #
5
+ ###############################################################################
6
+ from dataclasses import dataclass
7
+
8
+ import networkx
9
+ # import graph_tool.all as gt
10
+
11
+ ##Navigate to the ./util/orca directory and compile orca.cpp
12
+ # g++ -O2 -std=c++11 -o orca orca.cpp
13
+ import os
14
+ import sys
15
+ import copy
16
+ import signal
17
+ import torch
18
+ import torch.nn as nn
19
+ import numpy as np
20
+ import networkx as nx
21
+ import subprocess as sp
22
+ import concurrent.futures
23
+
24
+ import pygsp as pg
25
+ import secrets
26
+ from string import ascii_uppercase, digits
27
+ from datetime import datetime
28
+ from scipy.linalg import eigvalsh
29
+ from scipy.stats import chi2
30
+ from synthetic_graph_benchmarks.dataset import Dataset
31
+ from synthetic_graph_benchmarks.dist_helper import (
32
+ compute_mmd,
33
+ gaussian_emd,
34
+ gaussian,
35
+ emd,
36
+ gaussian_tv,
37
+ disc,
38
+ )
39
+ from sklearn.cluster import SpectralClustering
40
+
41
+ import orca as orca_package
42
+
43
+ from synthetic_graph_benchmarks.utils import available_cpu_count
44
+ # from torch_geometric.utils import to_networkx
45
+ # import wandb
46
+
47
def compute_ratios(gen_metrics, ref_metrics, metrics_keys):
    """Compute ratios of generated-graph metrics to reference metrics.

    For every key in ``metrics_keys`` the value ``gen_metrics[key] / round(ref_metrics[key], 4)``
    is stored under ``key + "_ratio"``. Keys that are missing from ``ref_metrics``
    (or whose reference value cannot be rounded, e.g. ``None``) are skipped with a
    message. Keys whose rounded reference value is 0 are also skipped with a warning.
    An ``"average_ratio"`` entry holds the mean of the computed ratios, or -1 when
    no ratio could be computed.

    Args:
        gen_metrics: mapping from metric name to the value measured on generated graphs.
        ref_metrics: mapping from metric name to the reference value, or None.
        metrics_keys: metric names to compare.

    Returns:
        dict of ratios. It is empty when ``ref_metrics`` is None or ``metrics_keys`` is empty.
    """
    print("Computing ratios of metrics: ", metrics_keys)
    if ref_metrics is None or len(metrics_keys) == 0:
        print("WARNING: No reference metrics for ratio computation.")
        return {}

    ratios = {}
    for key in metrics_keys:
        try:
            ref_metric = round(ref_metrics[key], 4)
        except (KeyError, TypeError):
            # Missing key, or a value (such as None) that cannot be rounded. A bare
            # ``except`` here also swallowed KeyboardInterrupt/SystemExit.
            print(key, "not found")
            continue
        if ref_metric != 0.0:
            ratios[key + "_ratio"] = gen_metrics[key] / ref_metric
        else:
            print(f"WARNING: Reference {key} is 0. Skipping its ratio.")

    if len(ratios) > 0:
        ratios["average_ratio"] = sum(ratios.values()) / len(ratios)
    else:
        ratios["average_ratio"] = -1
        print("WARNING: no ratio being saved.")
    return ratios
71
+
72
# When True, the *_stats functions print the wall-clock time they spent.
PRINT_TIME = False

# Names exported by ``from <this module> import *``.
__all__ = [
    "degree_stats", "clustering_stats", "orbit_stats_all",
    "spectral_stats", "eval_acc_lobster_graph",
]
80
+
81
+
82
def handler(signum, frame):
    """Signal handler that turns a delivered signal into a ``TimeoutError``.

    Registered for SIGALRM below so that code arming ``signal.alarm`` can catch
    ``TimeoutError`` to bound slow operations.
    """
    raise TimeoutError


# Register the timeout handler for SIGALRM. SIGALRM exists only on POSIX
# platforms. Registering it unconditionally raised AttributeError elsewhere
# (e.g. Windows) and made the whole module unimportable, so the registration
# is skipped where the signal is unavailable.
if hasattr(signal, "SIGALRM"):
    signal.signal(signal.SIGALRM, handler)
89
+
90
+
91
def degree_worker(G):
    """Return the degree histogram of ``G`` as a NumPy array.

    Entry ``k`` is the number of nodes with degree ``k``, as produced by
    ``networkx.degree_histogram``.
    """
    histogram = nx.degree_histogram(G)
    return np.array(histogram)
93
+
94
+
95
def degree_stats(graph_ref_list, graph_pred_list, is_parallel=True, compute_emd=False):
    """Compute the MMD between the degree distributions of two unordered sets of graphs.

    Args:
        graph_ref_list: reference networkx graphs.
        graph_pred_list: generated networkx graphs; graphs with no nodes are ignored.
        is_parallel: compute the per-graph degree histograms with a thread pool.
        compute_emd: use the ``gaussian_emd`` kernel instead of ``gaussian_tv``.

    Returns:
        The value returned by ``compute_mmd`` for the two sets of degree histograms.
    """
    # A generated graph may be empty; it carries no degree information.
    non_empty_pred = [G for G in graph_pred_list if G.number_of_nodes() != 0]

    start = datetime.now()
    if is_parallel:
        with concurrent.futures.ThreadPoolExecutor(max_workers=available_cpu_count()) as pool:
            sample_ref = list(pool.map(degree_worker, graph_ref_list))
        with concurrent.futures.ThreadPoolExecutor(max_workers=available_cpu_count()) as pool:
            sample_pred = list(pool.map(degree_worker, non_empty_pred))
    else:
        sample_ref = [degree_worker(G) for G in graph_ref_list]
        sample_pred = [degree_worker(G) for G in non_empty_pred]

    # The EMD option follows GraphRNN's computation; the alternative is GRAN's MMD.
    kernel = gaussian_emd if compute_emd else gaussian_tv
    mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=kernel)

    elapsed = datetime.now() - start
    if PRINT_TIME:
        print("Time computing degree mmd: ", elapsed)
    return mmd_dist
137
+
138
+
139
+ ###############################################################################
140
+
141
+
142
def spectral_worker(G, n_eigvals=-1):
    """Return a normalized histogram of the normalized-Laplacian spectrum of ``G``.

    Args:
        G: networkx graph.
        n_eigvals: if positive, only eigenvalues ``1 .. n_eigvals`` of the
            ascending spectrum are kept (index 0 is dropped). Otherwise all
            eigenvalues are used.

    Returns:
        Array of 200 bin counts over ``[-1e-5, 2]``, divided by their sum.
    """
    try:
        eigs = eigvalsh(nx.normalized_laplacian_matrix(G).todense())
    except Exception:
        # Fall back to an all-zero spectrum when the decomposition fails. This was
        # a bare ``except``, which also swallowed KeyboardInterrupt/SystemExit.
        eigs = np.zeros(G.number_of_nodes())
    if n_eigvals > 0:
        eigs = eigs[1 : n_eigvals + 1]
    spectral_pmf, _ = np.histogram(eigs, bins=200, range=(-1e-5, 2), density=False)
    return spectral_pmf / spectral_pmf.sum()
153
+
154
+
155
def get_spectral_pmf(eigs, max_eig):
    """Histogram ``eigs`` into 200 bins over ``[-1e-5, max_eig]`` and normalize to sum 1.

    Values are first clipped into ``[0, max_eig]``. Eigenvalues above ``max_eig``
    therefore land in the last bin instead of being dropped.
    """
    clipped = np.clip(eigs, 0, max_eig)
    counts = np.histogram(clipped, bins=200, range=(-1e-5, max_eig), density=False)[0]
    total = counts.sum()
    return counts / total
161
+
162
+
163
def eigval_stats(
    eig_ref_list, eig_pred_list, max_eig=20, is_parallel=True, compute_emd=False
):
    """Compute the MMD between the eigenvalue distributions of two sets of graphs.

    Each entry of ``eig_ref_list`` / ``eig_pred_list`` is one graph's array of
    eigenvalues. ``get_spectral_pmf`` turns it into a normalized 200-bin
    histogram over ``[-1e-5, max_eig]``.

    Args:
        eig_ref_list: per-graph eigenvalue arrays of the reference set.
        eig_pred_list: per-graph eigenvalue arrays of the generated set.
        max_eig: upper end of the histogram range; larger eigenvalues are clipped.
        is_parallel: compute the histograms with a thread pool.
        compute_emd: use the ``emd`` kernel instead of ``gaussian_tv``.

    Returns:
        The value returned by ``compute_mmd`` for the two sets of histograms.
    """
    prev = datetime.now()
    if is_parallel:
        with concurrent.futures.ThreadPoolExecutor(max_workers=available_cpu_count()) as executor:
            sample_ref = list(
                executor.map(get_spectral_pmf, eig_ref_list, [max_eig] * len(eig_ref_list))
            )
        with concurrent.futures.ThreadPoolExecutor(max_workers=available_cpu_count()) as executor:
            # The max_eig list must be as long as eig_pred_list. executor.map stops
            # at the shortest iterable, so sizing it by eig_ref_list silently dropped
            # generated samples whenever there were more of them than reference ones.
            sample_pred = list(
                executor.map(get_spectral_pmf, eig_pred_list, [max_eig] * len(eig_pred_list))
            )
    else:
        # get_spectral_pmf has no default for max_eig, so it must be passed
        # explicitly. The sequential path previously raised TypeError.
        sample_ref = [get_spectral_pmf(eigs, max_eig) for eigs in eig_ref_list]
        sample_pred = [get_spectral_pmf(eigs, max_eig) for eigs in eig_pred_list]

    kernel = emd if compute_emd else gaussian_tv
    mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=kernel)

    elapsed = datetime.now() - prev
    if PRINT_TIME:
        print("Time computing eig mmd: ", elapsed)
    return mmd_dist
208
+
209
+
210
def eigh_worker(G):
    """Eigendecompose the normalized Laplacian of ``G``.

    Returns:
        ``(eigvals, eigvecs)`` as returned by ``np.linalg.eigh`` (eigenvalues in
        ascending order). If the decomposition fails, zero arrays with shapes
        ``L[0, :].shape`` and ``L.shape`` are returned instead.
    """
    L = nx.normalized_laplacian_matrix(G).todense()
    try:
        eigvals, eigvecs = np.linalg.eigh(L)
    except np.linalg.LinAlgError:
        # Only a failed decomposition is expected here. The previous bare ``except``
        # also swallowed KeyboardInterrupt/SystemExit and masked genuine bugs.
        eigvals = np.zeros(L[0, :].shape)
        eigvecs = np.zeros(L.shape)
    return (eigvals, eigvecs)
218
+
219
+
220
def compute_list_eigh(graph_list, is_parallel=False):
    """Eigendecompose the normalized Laplacian of every graph in ``graph_list``.

    Returns:
        Two lists aligned with ``graph_list``: the eigenvalue arrays and the
        eigenvector matrices produced by ``eigh_worker``.
    """
    if is_parallel:
        with concurrent.futures.ThreadPoolExecutor(max_workers=available_cpu_count()) as executor:
            decompositions = list(executor.map(eigh_worker, graph_list))
    else:
        decompositions = [eigh_worker(G) for G in graph_list]

    eigval_list = [values for values, _ in decompositions]
    eigvec_list = [vectors for _, vectors in decompositions]
    return eigval_list, eigvec_list
234
+
235
+
236
def get_spectral_filter_worker(eigvec, eigval, filters, bound=1.4):
    """Build concatenated histograms of spectral-filter responses for one graph.

    For each row ``g`` of ``filters.evaluate(eigval)`` the operator
    ``eigvec @ diag(g) @ eigvec.T`` is formed. The squared entries of each of its
    rows are summed, and those per-row sums are histogrammed into 100 bins over
    ``[0, bound]``.

    Returns:
        1-D array holding one 100-bin histogram per filter response, concatenated
        in filter order.
    """
    responses = filters.evaluate(eigval)
    operators = np.array([eigvec @ np.diag(response) @ eigvec.T for response in responses])
    row_energy = np.sum(operators**2, axis=2)
    histograms = [
        np.histogram(values, range=[0, bound], bins=100)[0]  # NOTE: change number of bins
        for values in row_energy
    ]
    return np.array(histograms).flatten()
248
+
249
+
250
def spectral_filter_stats(
    eigvec_ref_list,
    eigval_ref_list,
    eigvec_pred_list,
    eigval_pred_list,
    is_parallel=False,
    compute_emd=False,
):
    """Compute the MMD between spectral-filter (wavelet) histograms of two graph sets.

    A bank of 12 ``pygsp`` Abspline filters, built on a dummy graph with
    ``lmax = 2``, is applied to each graph's eigendecomposition through
    ``get_spectral_filter_worker``. The resulting histograms are compared with
    ``compute_mmd``.

    Args:
        eigvec_ref_list, eigval_ref_list: eigenvectors / eigenvalues of the reference graphs.
        eigvec_pred_list, eigval_pred_list: eigenvectors / eigenvalues of the generated graphs.
        is_parallel: compute histograms with a thread pool; errors propagate. In the
            sequential path, a graph whose histogram computation raises is skipped
            with a warning.
        compute_emd: use ``gaussian_emd`` instead of ``gaussian_tv``.

    Returns:
        The value returned by ``compute_mmd``.
    """
    prev = datetime.now()

    class DMG(object):
        """Dummy Normalized Graph"""

        lmax = 2

    n_filters = 12
    filters = pg.filters.Abspline(DMG, n_filters)
    # Largest filter response sampled on [0, 2); used as the histogram upper bound.
    bound = np.max(filters.evaluate(np.arange(0, 2, 0.01)))

    def _histograms(eigvecs, eigvals, label):
        """Return the filter histograms of one graph set (ref or pred)."""
        if is_parallel:
            n = len(eigvals)
            with concurrent.futures.ThreadPoolExecutor(max_workers=available_cpu_count()) as executor:
                return list(
                    executor.map(
                        get_spectral_filter_worker,
                        eigvecs,
                        eigvals,
                        [filters] * n,
                        [bound] * n,
                    )
                )
        samples = []
        for i in range(len(eigvals)):
            try:
                samples.append(
                    get_spectral_filter_worker(eigvecs[i], eigvals[i], filters, bound)
                )
            except Exception as err:
                # Best effort: skip the graph, but report it. This was a silent
                # bare ``except: pass``.
                print(f"WARNING: skipping {label} graph {i} in spectral filter stats: {err}")
        return samples

    sample_ref = _histograms(eigvec_ref_list, eigval_ref_list, "reference")
    sample_pred = _histograms(eigvec_pred_list, eigval_pred_list, "generated")

    if compute_emd:
        # EMD option uses the same computation as GraphRNN, the alternative is MMD as computed by GRAN
        mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian_emd)
    else:
        mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian_tv)

    elapsed = datetime.now() - prev
    if PRINT_TIME:
        print("Time computing spectral filter stats: ", elapsed)
    return mmd_dist
322
+
323
+
324
def spectral_stats(
    graph_ref_list, graph_pred_list, is_parallel=True, n_eigvals=-1, compute_emd=False
):
    """Compute the MMD between the normalized-Laplacian spectra of two sets of graphs.

    ``spectral_worker`` summarizes each graph's spectrum as a normalized 200-bin
    histogram over ``[-1e-5, 2]``.

    Args:
        graph_ref_list: reference networkx graphs.
        graph_pred_list: generated networkx graphs; graphs with no nodes are ignored.
        is_parallel: compute the histograms with a thread pool.
        n_eigvals: forwarded to ``spectral_worker``; if positive, only eigenvalues
            ``1 .. n_eigvals`` are histogrammed.
        compute_emd: use ``gaussian_emd`` instead of ``gaussian_tv``.

    Returns:
        The value returned by ``compute_mmd`` for the two sets of histograms.
    """
    # in case an empty graph is generated
    graph_pred_list_remove_empty = [
        G for G in graph_pred_list if G.number_of_nodes() != 0
    ]

    prev = datetime.now()
    if is_parallel:
        with concurrent.futures.ThreadPoolExecutor(max_workers=available_cpu_count()) as executor:
            sample_ref = list(
                executor.map(
                    spectral_worker, graph_ref_list, [n_eigvals] * len(graph_ref_list)
                )
            )
        with concurrent.futures.ThreadPoolExecutor(max_workers=available_cpu_count()) as executor:
            sample_pred = list(
                executor.map(
                    spectral_worker,
                    graph_pred_list_remove_empty,
                    [n_eigvals] * len(graph_pred_list_remove_empty),
                )
            )
    else:
        sample_ref = [spectral_worker(G, n_eigvals) for G in graph_ref_list]
        sample_pred = [spectral_worker(G, n_eigvals) for G in graph_pred_list_remove_empty]

    # EMD option uses the same computation as GraphRNN, the alternative is MMD as computed by GRAN
    kernel = gaussian_emd if compute_emd else gaussian_tv
    mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=kernel)

    elapsed = datetime.now() - prev
    if PRINT_TIME:
        # The label used to read "degree mmd" (copied from degree_stats).
        print("Time computing spectral mmd: ", elapsed)
    return mmd_dist
374
+
375
+
376
+ ###############################################################################
377
+
378
+
379
def clustering_worker(param):
    """Histogram the local clustering coefficients of one graph.

    Args:
        param: ``(G, bins)`` tuple holding a networkx graph and a bin count.

    Returns:
        Counts over ``bins`` equal-width bins spanning ``[0, 1]``.
    """
    graph, n_bins = param
    coefficients = list(nx.clustering(graph).values())
    counts = np.histogram(coefficients, bins=n_bins, range=(0.0, 1.0), density=False)[0]
    return counts
386
+
387
+
388
def clustering_stats(
    graph_ref_list, graph_pred_list, bins=100, is_parallel=True, compute_emd=False
):
    """Compute the MMD between clustering-coefficient histograms of two graph sets.

    Args:
        graph_ref_list: reference networkx graphs.
        graph_pred_list: generated networkx graphs; graphs with no nodes are ignored.
        bins: number of histogram bins over ``[0, 1]``.
        is_parallel: compute the histograms with a thread pool.
        compute_emd: use ``gaussian_emd`` (with ``distance_scaling=bins``) instead
            of ``gaussian_tv``. Both use ``sigma = 0.1``.

    Returns:
        The value returned by ``compute_mmd``.
    """
    non_empty_pred = [G for G in graph_pred_list if G.number_of_nodes() != 0]

    start = datetime.now()
    ref_jobs = [(G, bins) for G in graph_ref_list]
    pred_jobs = [(G, bins) for G in non_empty_pred]
    if is_parallel:
        with concurrent.futures.ThreadPoolExecutor(max_workers=available_cpu_count()) as pool:
            sample_ref = list(pool.map(clustering_worker, ref_jobs))
        with concurrent.futures.ThreadPoolExecutor(max_workers=available_cpu_count()) as pool:
            sample_pred = list(pool.map(clustering_worker, pred_jobs))
    else:
        sample_ref = [clustering_worker(job) for job in ref_jobs]
        sample_pred = [clustering_worker(job) for job in pred_jobs]

    if compute_emd:
        # The EMD option follows GraphRNN's computation; the alternative is GRAN's MMD.
        mmd_dist = compute_mmd(
            sample_ref,
            sample_pred,
            kernel=gaussian_emd,
            sigma=1.0 / 10,
            distance_scaling=bins,
        )
    else:
        mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian_tv, sigma=1.0 / 10)

    elapsed = datetime.now() - start
    if PRINT_TIME:
        print("Time computing clustering mmd: ", elapsed)
    return mmd_dist
452
+
453
+
454
# Motif/orbit name -> orbit-count columns of the ``orca`` output that are summed
# to count that motif.
motif_to_indices = {"3path": [1, 2], "4cycle": [8]}

# NOTE(review): not referenced in the visible code; presumably a marker for
# parsing orca's textual output.
COUNT_START_STR = "orbit counts:"
460
+
461
+
462
def edge_list_reindexed(G):
    """Return the edges of ``G`` relabelled with consecutive integer node ids.

    Nodes are numbered ``0 .. n-1`` in ``G.nodes()`` iteration order, and the
    lookup is keyed on ``str(node)``.
    """
    position = {str(node): index for index, node in enumerate(G.nodes())}
    return [(position[str(u)], position[str(v)]) for u, v in G.edges()]
473
+
474
+
475
+
476
def orca(graph):
    """Compute 4-node graphlet orbit counts for ``graph`` via ``orca_package.orca_nodes``.

    Nodes are relabelled to ``0 .. n-1`` first. Callers index the result as a
    2-D array (presumably one row per node, one column per orbit).
    """
    edges = np.array(edge_list_reindexed(graph))
    n_nodes = graph.number_of_nodes()
    return orca_package.orca_nodes(edges, n_nodes, graphlet_size=4)
478
+
479
def motif_stats(
    graph_ref_list,
    graph_pred_list,
    motif_type="4cycle",
    ground_truth_match=None,
    bins=100,
    compute_emd=False,
):
    """Compute the MMD between per-graph motif densities of two graph sets.

    For each graph, the orca orbit counts in columns ``motif_to_indices[motif_type]``
    are summed per node, totalled over nodes, and divided by the node count. The
    two sets of scalars are compared with ``compute_mmd`` using the ``gaussian``
    kernel and ``is_hist=False``. The same kernel is used whether or not
    ``compute_emd`` is set.

    Args:
        graph_ref_list: reference networkx graphs.
        graph_pred_list: generated networkx graphs; graphs with no nodes are ignored.
        motif_type: key of ``motif_to_indices``.
        ground_truth_match: if given, the fraction of nodes whose motif count equals
            this value is also computed per graph (it is not returned).
        bins: unused.
        compute_emd: accepted for interface symmetry; does not change the kernel.
    """
    graph_pred_list_remove_empty = [
        G for G in graph_pred_list if G.number_of_nodes() != 0
    ]
    indices = motif_to_indices[motif_type]

    def _motif_densities(graphs):
        """Return motif counts per node for each graph as a column vector."""
        densities = []
        match_fractions = []  # kept for parity with the reference code; never returned
        for graph in graphs:
            n_nodes = graph.number_of_nodes()
            per_node = np.sum(orca(graph)[:, indices], axis=1)
            if ground_truth_match is not None:
                n_matching = sum(1 for count in per_node if count == ground_truth_match)
                match_fractions.append(n_matching / n_nodes)
            densities.append(np.sum(per_node) / n_nodes)
        return np.array(densities)[:, None]

    total_counts_ref = _motif_densities(graph_ref_list)
    total_counts_pred = _motif_densities(graph_pred_list_remove_empty)

    return compute_mmd(total_counts_ref, total_counts_pred, kernel=gaussian, is_hist=False)
544
+
545
+
546
def orbit_stats_all(graph_ref_list, graph_pred_list, compute_emd=False):
    """Compute the MMD between average per-node orbit counts of two graph sets.

    Each graph is summarized by its orca orbit counts summed over nodes and divided
    by the node count. The summaries are compared with ``compute_mmd``
    (``is_hist=False``, ``sigma=30.0``).

    Args:
        graph_ref_list: reference networkx graphs.
        graph_pred_list: generated networkx graphs; graphs with no nodes are ignored.
        compute_emd: use the ``gaussian`` kernel instead of ``gaussian_tv``.

    Returns:
        The value returned by ``compute_mmd``.
    """

    def _mean_orbit_counts(graphs):
        """Return one row of per-node-averaged orbit counts per graph."""
        return np.array(
            [np.sum(orca(G), axis=0) / G.number_of_nodes() for G in graphs]
        )

    # Empty generated graphs must actually be dropped. The unfiltered list was
    # iterated before, so an empty graph divided by zero nodes.
    graph_pred_list_remove_empty = [
        G for G in graph_pred_list if G.number_of_nodes() != 0
    ]

    total_counts_ref = _mean_orbit_counts(graph_ref_list)
    total_counts_pred = _mean_orbit_counts(graph_pred_list_remove_empty)

    # EMD option uses the same computation as GraphRNN, the alternative is MMD as computed by GRAN
    kernel = gaussian if compute_emd else gaussian_tv
    return compute_mmd(
        total_counts_ref,
        total_counts_pred,
        kernel=kernel,
        is_hist=False,
        sigma=30.0,
    )
600
+
601
+
602
def eval_acc_lobster_graph(G_list):
    """Return the fraction of graphs in ``G_list`` that are lobster graphs.

    Each graph is deep-copied before it is checked.
    """
    copies = [copy.deepcopy(graph) for graph in G_list]
    n_lobsters = sum(1 for graph in copies if is_lobster_graph(graph))
    return n_lobsters / float(len(copies))
609
+
610
+
611
def eval_acc_tree_graph(G_list):
    """Return the fraction of graphs in ``G_list`` that are trees (``nx.is_tree``)."""
    n_trees = sum(1 for graph in G_list if nx.is_tree(graph))
    return n_trees / float(len(G_list))
617
+
618
+
619
def eval_acc_grid_graph(G_list, grid_start=10, grid_end=20):
    """Return the fraction of graphs in ``G_list`` that ``is_grid_graph`` accepts.

    ``grid_start`` and ``grid_end`` are accepted for interface compatibility and
    are not used.
    """
    n_grids = sum(1 for graph in G_list if is_grid_graph(graph))
    return n_grids / float(len(G_list))
625
+
626
+
627
def eval_acc_sbm_graph(
    G_list,
    p_intra=0.3,
    p_inter=0.005,
    strict=True,
    refinement_steps=100,
    is_parallel=True,
):
    """Return the mean ``is_sbm_graph`` result over ``G_list``.

    With ``strict=True`` this is the fraction of graphs accepted as SBM graphs.
    Otherwise it is the mean SBM score.
    """

    def _score(graph):
        return is_sbm_graph(
            graph,
            p_intra=p_intra,
            p_inter=p_inter,
            strict=strict,
            refinement_steps=refinement_steps,
        )

    total = 0.0
    if is_parallel:
        with concurrent.futures.ThreadPoolExecutor(max_workers=available_cpu_count()) as executor:
            for prob in executor.map(_score, list(G_list)):
                total += prob
    else:
        for graph in G_list:
            total += _score(graph)
    return total / float(len(G_list))
657
+
658
+
659
def eval_acc_planar_graph(G_list):
    """Return the fraction of graphs in ``G_list`` that are connected and planar."""
    n_planar = sum(1 for graph in G_list if is_planar_graph(graph))
    return n_planar / float(len(G_list))
665
+
666
+
667
def is_planar_graph(G):
    """Return True if ``G`` is connected and planar."""
    if not nx.is_connected(G):
        return False
    planar, _ = nx.check_planarity(G)
    return planar
669
+
670
+
671
def is_lobster_graph(G):
    """Check whether ``G`` is a lobster graph.

    ``G`` must be a tree. Its leaves (degree-1 nodes) are then stripped twice,
    following lobster -> caterpillar -> path. What remains must be a path: exactly
    two degree-1 nodes and every other node of degree 2. Alternatively it must have
    no nodes of degree 1 or 2 at all.
    """
    if not nx.is_tree(G):
        return False

    pruned = G.copy()
    for _ in range(2):
        pruned.remove_nodes_from([node for node, degree in pruned.degree() if degree == 1])

    degrees = [degree for _, degree in pruned.degree()]
    n_ends = degrees.count(1)
    n_inner = degrees.count(2)
    if n_ends == 2 and n_inner == pruned.number_of_nodes() - 2:
        return True
    return n_ends == 0 and n_inner == 0
702
+
703
+
704
# Lazily built table of the grid graphs ``is_grid_graph`` compares against,
# keyed by node count (as a string). Built once per process.
_ALL_GRIDS_BY_SIZE = None


def _all_grids_by_size():
    """Return ``{str(n_nodes): [grid graphs]}`` for every i x j grid with 2 <= i, j < 20."""
    global _ALL_GRIDS_BY_SIZE
    if _ALL_GRIDS_BY_SIZE is None:
        grids = {}
        for i in range(2, 20):
            for j in range(2, 20):
                G_grid = nx.grid_2d_graph(i, j)
                grids.setdefault(f"{len(G_grid.nodes())}", []).append(G_grid)
        _ALL_GRIDS_BY_SIZE = grids
    return _ALL_GRIDS_BY_SIZE


def is_grid_graph(G):
    """Check whether ``G`` is isomorphic to an i x j grid graph with 2 <= i, j < 20.

    ``G`` is compared only against grids with the same number of nodes. The cheap
    ``faster_could_be_isomorphic`` test runs before the exact isomorphism check.

    The grid table used to be pickled to the relative path ``data/all_grids.pt``
    with ``torch.save`` and re-read with ``torch.load`` on every call. That
    reloaded the file once per graph and crashed when no ``data/`` directory
    existed. It also fails on PyTorch >= 2.6, whose ``torch.load`` defaults to
    ``weights_only=True`` and rejects pickled networkx graphs. The table is now
    built in memory once per process instead.
    """
    candidates = _all_grids_by_size().get(f"{len(G.nodes())}", [])
    return any(
        nx.faster_could_be_isomorphic(G, G_grid) and nx.is_isomorphic(G, G_grid)
        for G_grid in candidates
    )
729
+
730
def _sbm_wald_score(binary_adj, labels, p_intra, p_inter, strict):
    """Return the mean Wald-test p-value of an SBM fit for ``labels``, or None if rejected by ``strict``."""
    # Relabel blocks to 0..k-1 so they can index the count matrix even when the
    # clustering leaves some label ids unused.
    _, block_of, node_counts = np.unique(labels, return_inverse=True, return_counts=True)
    if strict and ((node_counts > 40).any() or (node_counts < 20).any()):
        return None

    # One-hot block membership Z (n_nodes x n_blocks). Z^T A Z holds, off the
    # diagonal, the number of edges between blocks r and s. On the diagonal it
    # holds TWICE the number of edges inside block r, because every undirected
    # edge appears as both A[i, j] and A[j, i].
    n_nodes = binary_adj.shape[0]
    membership = np.zeros((n_nodes, node_counts.size))
    membership[np.arange(n_nodes), block_of] = 1.0
    edge_counts = membership.T @ binary_adj @ membership

    # Doubled intra count over n(n-1) ordered pairs == edges / C(n, 2). The old
    # loop counted each intra edge once, which halved est_p_intra.
    max_intra_edges = node_counts * (node_counts - 1)
    est_p_intra = np.diagonal(edge_counts) / (max_intra_edges + 1e-6)

    max_inter_edges = node_counts.reshape((-1, 1)) @ node_counts.reshape((1, -1))
    edge_counts_inter = edge_counts.copy()
    np.fill_diagonal(edge_counts_inter, 0)
    est_p_inter = edge_counts_inter / (max_inter_edges + 1e-6)

    # Wald test statistics for every estimated probability.
    W_p_intra = (est_p_intra - p_intra) ** 2 / (est_p_intra * (1 - est_p_intra) + 1e-6)
    W_p_inter = (est_p_inter - p_inter) ** 2 / (est_p_inter * (1 - est_p_inter) + 1e-6)

    W = W_p_inter.copy()
    np.fill_diagonal(W, W_p_intra)
    p = 1 - chi2.cdf(np.abs(W), 1)
    return p.mean()


def is_sbm_graph(G, p_intra=0.3, p_inter=0.005, strict=True, refinement_steps=100):
    """Score how closely ``G`` matches an SBM with the given edge probabilities.

    Blocks are recovered with spectral clustering on the adjacency matrix, used
    here instead of graph_tool. ``n_clusters`` runs from 2 up to
    ``min(5, n_nodes // 10 + 1)``. For each clustering, every intra- and
    inter-block edge probability is estimated and compared with
    ``p_intra`` / ``p_inter`` by a Wald test. The score is the mean p-value over
    all block pairs, and the best score over the tried cluster counts is kept.

    Args:
        G: networkx graph.
        p_intra: expected edge probability inside a block.
        p_inter: expected edge probability between two blocks.
        strict: if True, only clusterings whose blocks all have 20-40 nodes are
            considered, and a boolean is returned.
        refinement_steps: unused; kept for interface compatibility.

    Returns:
        ``best_score > 0.9`` when ``strict`` is True, else ``best_score``. Graphs
        with fewer than 4 nodes, and graphs whose detection raises, yield
        False (strict) or 0.0.
    """
    try:
        adj = nx.adjacency_matrix(G).toarray()

        if adj.shape[0] < 4:  # Too small for meaningful block detection
            return False if strict else 0.0

        # Unweighted, loop-free adjacency for edge counting. The pairwise loop
        # this replaces only looked at i < j and tested adj > 0.
        binary_adj = (adj > 0).astype(float)
        np.fill_diagonal(binary_adj, 0.0)

        best_score = 0.0
        for n_clusters in range(2, min(6, adj.shape[0] // 10 + 2)):
            try:
                clustering = SpectralClustering(
                    n_clusters=n_clusters,
                    affinity="precomputed",
                    random_state=42,
                    assign_labels="discretize",
                )
                labels = clustering.fit_predict(adj)
                p_mean = _sbm_wald_score(binary_adj, labels, p_intra, p_inter, strict)
                if p_mean is not None and p_mean > best_score:
                    best_score = p_mean
            except Exception as e:
                print(f"Error during spectral clustering with {n_clusters} clusters: {e}")
                continue

        if strict:
            return best_score > 0.9  # p value < 10%
        return best_score

    except Exception as e:
        print(f"Error during SBM detection: {e}")
        return False if strict else 0.0
819
def is_sbm_graph_dummy(G, p_intra=0.3, p_inter=0.005, strict=True, refinement_steps=100):
    """Placeholder for the graph_tool-based SBM check; always returns -1.

    The graph_tool implementation is disabled (its import is commented out in
    this module). Everything after the early ``return -1`` was unreachable and
    referenced the undefined name ``gt``, so it has been removed. Only the call
    signature is preserved; use ``is_sbm_graph`` for the
    spectral-clustering-based check.
    """
    return -1
875
+
876
+
877
def eval_fraction_isomorphic(fake_graphs, train_graphs):
    """Return the fraction of ``fake_graphs`` isomorphic to at least one training graph.

    The cheap ``faster_could_be_isomorphic`` check runs before the exact
    ``is_isomorphic`` check.
    """

    def _in_training_set(fake_g):
        return any(
            nx.faster_could_be_isomorphic(fake_g, train_g)
            and nx.is_isomorphic(fake_g, train_g)
            for train_g in train_graphs
        )

    n_matches = sum(1 for fake_g in fake_graphs if _in_training_set(fake_g))
    return n_matches / float(len(fake_graphs))
886
+
887
+
888
def eval_fraction_unique(fake_graphs, precise=False):
    """Return the fraction of ``fake_graphs`` that are not duplicates of an earlier graph.

    Each non-empty graph is compared against the graphs kept so far. With
    ``precise=True`` an exact ``nx.is_isomorphic`` check is used. Otherwise the
    weaker ``nx.could_be_isomorphic`` test is used, which may over-count
    duplicates. Graphs with no nodes are never counted as duplicates.
    """
    matches = nx.is_isomorphic if precise else nx.could_be_isomorphic
    kept = []
    n_duplicates = 0
    for fake_g in fake_graphs:
        is_duplicate = fake_g.number_of_nodes() != 0 and any(
            nx.faster_could_be_isomorphic(fake_g, old) and matches(fake_g, old)
            for old in kept
        )
        if is_duplicate:
            n_duplicates += 1
        else:
            kept.append(fake_g)

    # Fraction of distinct isomorphism classes in the fake graphs
    return (float(len(fake_graphs)) - n_duplicates) / float(len(fake_graphs))
915
+
916
+
917
def _is_isomorphic_with_timeout(g1, g2, seconds=60):
    """Return ``nx.is_isomorphic(g1, g2)``, or False if SIGALRM interrupts it.

    Arms ``signal.alarm(seconds)``. The module's SIGALRM handler raises
    ``TimeoutError``, which is treated as "not isomorphic". The alarm is always
    disarmed afterwards.
    """
    try:
        signal.alarm(seconds)
        return nx.is_isomorphic(g1, g2)
    except TimeoutError:
        print("Timeout: Skipping this iteration")
        return False
    finally:
        # Disable the alarm
        signal.alarm(0)


def eval_fraction_unique_non_isomorphic_valid(
    fake_graphs, train_graphs, validity_func=(lambda x: True)
):
    """Compute uniqueness, novelty and validity fractions of generated graphs.

    A fake graph is a duplicate if it is isomorphic to an earlier kept fake graph;
    each such comparison is bounded by a 60-second alarm. For every unique graph,
    the function checks whether it is isomorphic to a training graph. If not, it
    is counted as valid when ``validity_func`` accepts it.

    Returns:
        ``(frac_unique, frac_unique_non_isomorphic, frac_unique_non_isomorphic_valid)``,
        all relative to ``len(fake_graphs)``.
    """
    count_valid = 0
    count_isomorphic = 0
    count_non_unique = 0
    fake_evaluated = []
    for fake_g in fake_graphs:
        if any(_is_isomorphic_with_timeout(fake_g, fake_old) for fake_old in fake_evaluated):
            count_non_unique += 1
            continue

        fake_evaluated.append(fake_g)
        in_training_set = any(
            nx.faster_could_be_isomorphic(fake_g, train_g)
            and nx.is_isomorphic(fake_g, train_g)
            for train_g in train_graphs
        )
        if in_training_set:
            count_isomorphic += 1
        elif validity_func(fake_g):
            count_valid += 1

    n_fake = float(len(fake_graphs))
    # Fraction of distinct isomorphism classes in the fake graphs
    frac_unique = (n_fake - count_non_unique) / n_fake
    # ... that are additionally not in the training set
    frac_unique_non_isomorphic = (n_fake - count_non_unique - count_isomorphic) / n_fake
    # ... that are additionally valid
    frac_unique_non_isomorphic_valid = count_valid / n_fake
    return frac_unique, frac_unique_non_isomorphic, frac_unique_non_isomorphic_valid
966
+
967
+
968
+ class SpectreSamplingMetrics(nn.Module):
969
+ def __init__(self, dataset: Dataset, compute_emd, metrics_list):
970
+ super().__init__()
971
+
972
+ self.train_graphs = dataset.train_graphs
973
+ self.val_graphs = dataset.val_graphs
974
+ self.test_graphs = dataset.test_graphs if dataset.test_graphs is not None else dataset.val_graphs
975
+ self.num_graphs_test = len(self.test_graphs)
976
+ self.num_graphs_val = len(self.val_graphs)
977
+ self.compute_emd = compute_emd
978
+ self.metrics_list = metrics_list
979
+
980
+ # Store for wavelet computaiton
981
+ self.val_ref_eigvals, self.val_ref_eigvecs = compute_list_eigh(self.val_graphs)
982
+ self.test_ref_eigvals, self.test_ref_eigvecs = compute_list_eigh(
983
+ self.test_graphs
984
+ )
985
+
986
+ def forward(
987
+ self,
988
+ generated_graphs: list[networkx.Graph],
989
+ ref_metrics= { "val": None, "test": None},
990
+ test=False,
991
+ ):
992
+ reference_graphs = self.test_graphs if test else self.val_graphs
993
+ local_rank=0
994
+ if local_rank == 0:
995
+ print(
996
+ f"Computing sampling metrics between {len(generated_graphs)} generated graphs and {len(reference_graphs)}"
997
+ f" test graphs -- emd computation: {self.compute_emd}"
998
+ )
999
+ networkx_graphs = generated_graphs
1000
+ adjacency_matrices = []
1001
+ for graph in generated_graphs:
1002
+ A = networkx.adjacency_matrix(graph).todense()
1003
+ adjacency_matrices.append(A)
1004
+
1005
+ to_log = {}
1006
+ # np.savez("generated_adjs.npz", *adjacency_matrices)
1007
+
1008
+ if "degree" in self.metrics_list:
1009
+ if local_rank == 0:
1010
+ print("Computing degree stats..")
1011
+ degree = degree_stats(
1012
+ reference_graphs,
1013
+ networkx_graphs,
1014
+ is_parallel=True,
1015
+ compute_emd=self.compute_emd,
1016
+ )
1017
+ to_log["degree"] = degree
1018
+
1019
+ if "wavelet" in self.metrics_list:
1020
+ if local_rank == 0:
1021
+ print("Computing wavelet stats...")
1022
+
1023
+ ref_eigvecs = self.test_ref_eigvecs if test else self.val_ref_eigvecs
1024
+ ref_eigvals = self.test_ref_eigvals if test else self.val_ref_eigvals
1025
+
1026
+ pred_graph_eigvals, pred_graph_eigvecs = compute_list_eigh(networkx_graphs)
1027
+ wavelet = spectral_filter_stats(
1028
+ eigvec_ref_list=ref_eigvecs,
1029
+ eigval_ref_list=ref_eigvals,
1030
+ eigvec_pred_list=pred_graph_eigvecs,
1031
+ eigval_pred_list=pred_graph_eigvals,
1032
+ is_parallel=False,
1033
+ compute_emd=self.compute_emd,
1034
+ )
1035
+ to_log["wavelet"] = wavelet
1036
+
1037
+ if "spectre" in self.metrics_list:
1038
+ if local_rank == 0:
1039
+ print("Computing spectre stats...")
1040
+ spectre = spectral_stats(
1041
+ reference_graphs,
1042
+ networkx_graphs,
1043
+ is_parallel=True,
1044
+ n_eigvals=-1,
1045
+ compute_emd=self.compute_emd,
1046
+ )
1047
+
1048
+ to_log["spectre"] = spectre
1049
+
1050
+ if "clustering" in self.metrics_list:
1051
+ if local_rank == 0:
1052
+ print("Computing clustering stats...")
1053
+ clustering = clustering_stats(
1054
+ reference_graphs,
1055
+ networkx_graphs,
1056
+ bins=100,
1057
+ is_parallel=True,
1058
+ compute_emd=self.compute_emd,
1059
+ )
1060
+ to_log["clustering"] = clustering
1061
+
1062
+ if "motif" in self.metrics_list:
1063
+ if local_rank == 0:
1064
+ print("Computing motif stats")
1065
+ motif = motif_stats(
1066
+ reference_graphs,
1067
+ networkx_graphs,
1068
+ motif_type="4cycle",
1069
+ ground_truth_match=None,
1070
+ bins=100,
1071
+ compute_emd=self.compute_emd,
1072
+ )
1073
+ to_log["motif"] = motif
1074
+
1075
+ if "orbit" in self.metrics_list:
1076
+ if local_rank == 0:
1077
+ print("Computing orbit stats...")
1078
+ orbit = orbit_stats_all(
1079
+ reference_graphs, networkx_graphs, compute_emd=self.compute_emd
1080
+ )
1081
+ to_log["orbit"] = orbit
1082
+
1083
+ if "sbm" in self.metrics_list:
1084
+ if local_rank == 0:
1085
+ print("Computing accuracy...")
1086
+ sbm_acc = eval_acc_sbm_graph(
1087
+ networkx_graphs, refinement_steps=100, strict=True
1088
+ )
1089
+ to_log["sbm_acc"] = sbm_acc
1090
+
1091
+ if "planar" in self.metrics_list:
1092
+ if local_rank == 0:
1093
+ print("Computing planar accuracy...")
1094
+ planar_acc = eval_acc_planar_graph(networkx_graphs)
1095
+ to_log["planar_acc"] = planar_acc
1096
+
1097
+ if "tree" in self.metrics_list:
1098
+ if local_rank == 0:
1099
+ print("Computing tree accuracy...")
1100
+ tree_acc = eval_acc_tree_graph(networkx_graphs)
1101
+ to_log["tree_acc"] = tree_acc
1102
+
1103
+ if (
1104
+ "sbm" in self.metrics_list
1105
+ or "planar" in self.metrics_list
1106
+ or "tree" in self.metrics_list
1107
+ ):
1108
+ if local_rank == 0:
1109
+ print("Computing all fractions...")
1110
+ if "sbm" in self.metrics_list:
1111
+ validity_func = is_sbm_graph
1112
+ elif "planar" in self.metrics_list:
1113
+ validity_func = is_planar_graph
1114
+ elif "tree" in self.metrics_list:
1115
+ validity_func = nx.is_tree
1116
+ else:
1117
+ validity_func = None
1118
+ (
1119
+ frac_unique,
1120
+ frac_unique_non_isomorphic,
1121
+ fraction_unique_non_isomorphic_valid,
1122
+ ) = eval_fraction_unique_non_isomorphic_valid(
1123
+ networkx_graphs,
1124
+ self.train_graphs,
1125
+ validity_func,
1126
+ )
1127
+ frac_non_isomorphic = 1.0 - eval_fraction_isomorphic(
1128
+ networkx_graphs, self.train_graphs
1129
+ )
1130
+ to_log.update(
1131
+ {
1132
+ "sampling/frac_unique": frac_unique,
1133
+ "sampling/frac_unique_non_iso": frac_unique_non_isomorphic,
1134
+ "sampling/frac_unic_non_iso_valid": fraction_unique_non_isomorphic_valid,
1135
+ "sampling/frac_non_iso": frac_non_isomorphic,
1136
+ }
1137
+ )
1138
+
1139
+ ratios = compute_ratios(
1140
+ gen_metrics=to_log,
1141
+ ref_metrics=ref_metrics["test"] if test else ref_metrics["val"],
1142
+ metrics_keys=["degree", "clustering", "orbit", "spectre", "wavelet"],
1143
+ )
1144
+ to_log.update(ratios)
1145
+
1146
+ # if local_rank == 0:
1147
+ # print("Sampling statistics", to_log)
1148
+
1149
+ return to_log
1150
+
1151
+ def reset(self):
1152
+ pass
1153
+
1154
+
1155
+ class Comm20SamplingMetrics(SpectreSamplingMetrics):
1156
+
1157
+ def __init__(self, dataset: Dataset):
1158
+ super().__init__(
1159
+ dataset=dataset,
1160
+ compute_emd=True,
1161
+ metrics_list=["degree", "clustering", "orbit", "spectre", "wavelet"],
1162
+ )
1163
+
1164
+
1165
+ class PlanarSamplingMetrics(SpectreSamplingMetrics):
1166
+ def __init__(self, dataset: Dataset):
1167
+ super().__init__(
1168
+ dataset=dataset,
1169
+ compute_emd=False,
1170
+ metrics_list=[
1171
+ "degree",
1172
+ "clustering",
1173
+ "orbit",
1174
+ "spectre",
1175
+ "wavelet",
1176
+ "planar",
1177
+ ],
1178
+ )
1179
+
1180
+
1181
+ class SBMSamplingMetrics(SpectreSamplingMetrics):
1182
+ def __init__(self, dataset: Dataset):
1183
+ super().__init__(
1184
+ dataset=dataset,
1185
+ compute_emd=False,
1186
+ metrics_list=["degree", "clustering", "orbit", "spectre", "wavelet", "sbm"],
1187
+ )
1188
+
1189
+
1190
+ class TreeSamplingMetrics(SpectreSamplingMetrics):
1191
+ def __init__(self, dataset: Dataset):
1192
+ super().__init__(
1193
+ dataset=dataset,
1194
+ compute_emd=False,
1195
+ metrics_list=[
1196
+ "degree",
1197
+ "clustering",
1198
+ "orbit",
1199
+ "spectre",
1200
+ "wavelet",
1201
+ "tree",
1202
+ ],
1203
+ )
1204
+
1205
+
1206
+ class EgoSamplingMetrics(SpectreSamplingMetrics):
1207
+ def __init__(self, dataset: Dataset):
1208
+ super().__init__(
1209
+ dataset=dataset,
1210
+ compute_emd=False,
1211
+ metrics_list=["degree", "clustering", "orbit", "spectre", "wavelet"],
1212
+ )
1213
+
1214
+
1215
+ class ProteinSamplingMetrics(SpectreSamplingMetrics):
1216
+ def __init__(self, dataset: Dataset):
1217
+ super().__init__(
1218
+ dataset=dataset,
1219
+ compute_emd=False,
1220
+ metrics_list=["degree", "clustering", "orbit", "spectre", "wavelet"],
1221
+ )
1222
+
1223
+
1224
+ class IMDBSamplingMetrics(SpectreSamplingMetrics):
1225
+ def __init__(self, dataset: Dataset):
1226
+ super().__init__(
1227
+ dataset=dataset,
1228
+ compute_emd=False,
1229
+ metrics_list=["degree", "clustering", "orbit", "spectre", "wavelet"],
1230
+ )