napistu 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,146 +1,696 @@
1
- import inspect
2
- from typing import Optional, Union
1
+ from dataclasses import dataclass
2
+ import logging
3
+ from typing import Optional, Union, List, Dict, Any
3
4
 
4
5
  import pandas as pd
5
6
  import numpy as np
6
7
  import igraph as ig
8
+ import scipy.stats as stats
7
9
 
8
- from napistu.network.ng_core import NapistuGraph
10
+ from napistu.network.ig_utils import (
11
+ _parse_mask_input,
12
+ _get_attribute_masks,
13
+ _ensure_valid_attribute,
14
+ )
15
+ from napistu.statistics.quantiles import calculate_quantiles
16
+ from napistu.network.constants import (
17
+ MASK_KEYWORDS,
18
+ NAPISTU_GRAPH_VERTICES,
19
+ NET_PROPAGATION_DEFS,
20
+ NULL_STRATEGIES,
21
+ PARAMETRIC_NULL_DEFAULT_DISTRIBUTION,
22
+ VALID_NULL_STRATEGIES,
23
+ )
9
24
 
25
+ logger = logging.getLogger(__name__)
10
26
 
11
- def personalized_pagerank_by_attribute(
12
- napistu_graph: Union[NapistuGraph, ig.Graph],
13
- attribute: str,
14
- damping: float = 0.85,
15
- calculate_uniform_dist: bool = True,
27
+
28
@dataclass
class PropagationMethod:
    """
    Pairing of a propagation callable with the input constraint it imposes.

    Attributes
    ----------
    method : callable
        Function invoked as ``method(graph, attr_data, **kwargs)`` to propagate
        one vertex attribute over a graph.
    non_negative : bool
        Whether the attribute values handed to ``method`` must be non-negative.
    """

    # callable applied to (graph, attribute values, **extra args)
    method: callable
    # True when the method only accepts non-negative attribute values
    non_negative: bool
32
+
33
+
34
def network_propagation_with_null(
    graph: ig.Graph,
    attributes: List[str],
    null_strategy: str = NULL_STRATEGIES.NODE_PERMUTATION,
    propagation_method: Union[
        str, PropagationMethod
    ] = NET_PROPAGATION_DEFS.PERSONALIZED_PAGERANK,
    additional_propagation_args: Optional[dict] = None,
    n_samples: int = 100,
    **null_kwargs,
) -> pd.DataFrame:
    """
    Apply network propagation to attributes and compare against null distributions.

    This is the main orchestrator function that:
    1. Resolves the null generator for ``null_strategy`` (fails fast if unknown)
    2. Calculates observed propagated scores
    3. Generates the null distribution using the chosen strategy
    4. Compares observed vs null using quantiles (for sampled nulls) or ratios (for uniform)

    Parameters
    ----------
    graph : ig.Graph
        Input graph.
    attributes : List[str]
        Attribute names to propagate and test.
    null_strategy : str
        Null distribution strategy. One of: 'uniform', 'parametric', 'node_permutation', 'edge_permutation'.
    propagation_method : str or PropagationMethod
        Network propagation method to apply.
    additional_propagation_args : dict, optional
        Additional arguments to pass to the network propagation method.
    n_samples : int
        Number of null samples to generate (ignored for uniform null).
    **null_kwargs
        Additional arguments to pass to the null generator (e.g., mask, burn_in_ratio, etc.).

    Returns
    -------
    pd.DataFrame
        DataFrame with same structure as observed scores containing:
        - For uniform null: observed/uniform ratios
        - For other nulls: the output of ``calculate_quantiles(observed, null)``

    Raises
    ------
    ValueError
        If ``null_strategy`` is not a known strategy (raised by ``get_null_generator``
        before any propagation is computed).

    Examples
    --------
    >>> # Node permutation test with custom mask
    >>> result = network_propagation_with_null(
    ...     graph, ['gene_score'],
    ...     null_strategy='node_permutation',
    ...     n_samples=1000,
    ...     mask='measured_genes'
    ... )

    >>> # Edge permutation test
    >>> result = network_propagation_with_null(
    ...     graph, ['pathway_score'],
    ...     null_strategy='edge_permutation',
    ...     n_samples=100,
    ...     burn_in_ratio=10,
    ...     sampling_ratio=0.1
    ... )
    """
    # 1. Resolve the generator first: an unknown strategy should be rejected
    # before the (potentially expensive) observed propagation is run.
    null_generator = get_null_generator(null_strategy)
    is_uniform = null_strategy == NULL_STRATEGIES.UNIFORM

    # 2. Calculate observed propagated scores
    observed_scores = net_propagate_attributes(
        graph, attributes, propagation_method, additional_propagation_args
    )

    # 3. Generate null distribution; only the sampled nulls accept n_samples
    generator_args = {
        "graph": graph,
        "attributes": attributes,
        "propagation_method": propagation_method,
        "additional_propagation_args": additional_propagation_args,
    }
    if not is_uniform:
        generator_args["n_samples"] = n_samples

    null_distribution = null_generator(**generator_args, **null_kwargs)

    if is_uniform:
        # 4a. For uniform null: calculate observed/uniform ratios
        # Avoid division by zero by adding small epsilon
        epsilon = 1e-10
        return observed_scores / (null_distribution + epsilon)

    # 4b. For sampled nulls: calculate quantiles
    return calculate_quantiles(observed_scores, null_distribution)
134
+
135
+
136
def net_propagate_attributes(
    graph: ig.Graph,
    attributes: List[str],
    propagation_method: Union[
        str, PropagationMethod
    ] = NET_PROPAGATION_DEFS.PERSONALIZED_PAGERANK,
    additional_propagation_args: Optional[dict] = None,
) -> pd.DataFrame:
    """
    Run a network propagation method once per attribute and tabulate the results.

    Parameters
    ----------
    graph : ig.Graph
        The graph to propagate attributes over.
    attributes : List[str]
        Names of the vertex attributes to propagate.
    propagation_method : str or PropagationMethod
        Either the name of a registered method (e.g., 'personalized_pagerank')
        or a PropagationMethod instance.
    additional_propagation_args : dict, optional
        Extra keyword arguments forwarded to the propagation callable.

    Returns
    -------
    pd.DataFrame
        One row per vertex (indexed by vertex name when the graph has a name
        attribute, otherwise by vertex position) and one column per attribute,
        holding the propagated values.
    """
    method_spec = _ensure_propagation_method(propagation_method)
    # check every attribute up front so nothing is propagated if one is invalid
    _validate_vertex_attributes(graph, attributes, method_spec)

    extra_args = (
        {} if additional_propagation_args is None else additional_propagation_args
    )

    propagated_columns = []
    for attribute_name in attributes:
        seed_values = _ensure_valid_attribute(
            graph, attribute_name, non_negative=method_spec.non_negative
        )
        propagated_columns.append(method_spec.method(graph, seed_values, **extra_args))

    # row labels: vertex names if present, otherwise positional indices
    if NAPISTU_GRAPH_VERTICES.NAME in graph.vs.attributes():
        row_labels = graph.vs[NAPISTU_GRAPH_VERTICES.NAME]
    else:
        row_labels = list(range(graph.vcount()))

    return pd.DataFrame(
        np.column_stack(propagated_columns), index=row_labels, columns=attributes
    )
90
192
 
91
- return pd.DataFrame(data)
92
193
 
194
def uniform_null(
    graph: ig.Graph,
    attributes: List[str],
    propagation_method: Union[
        str, PropagationMethod
    ] = NET_PROPAGATION_DEFS.PERSONALIZED_PAGERANK,
    additional_propagation_args: Optional[dict] = None,
    mask: Optional[Union[str, np.ndarray, List, Dict]] = MASK_KEYWORDS.ATTR,
) -> pd.DataFrame:
    """
    Propagate a uniform signal spread evenly over each attribute's masked nodes.

    Parameters
    ----------
    graph : ig.Graph
        Input graph.
    attributes : List[str]
        Attribute names to generate nulls for.
    propagation_method : str or PropagationMethod
        Network propagation method to apply.
    additional_propagation_args : dict, optional
        Additional arguments to pass to the network propagation method.
    mask : str, np.ndarray, List, Dict, or None
        Mask specification. Default is "attr" (use each attribute as its own mask).

    Returns
    -------
    pd.DataFrame
        Propagated null sample with uniform distribution over masked nodes.
        Shape: (n_nodes, n_attributes)

    Raises
    ------
    ValueError
        If the mask for an attribute selects no nodes.
    """
    method_spec = _ensure_propagation_method(propagation_method)
    _validate_vertex_attributes(graph, attributes, method_spec)

    mask_specs = _parse_mask_input(mask, attributes)
    masks = _get_attribute_masks(graph, mask_specs)

    # work on a copy so the caller's graph keeps its observed attribute values
    null_graph = graph.copy()
    n_vertices = graph.vcount()

    for attr in attributes:
        attr_mask = masks[attr]
        n_masked = attr_mask.sum()

        if n_masked == 0:
            raise ValueError(f"No nodes in mask for attribute '{attr}'")

        # when an attribute masks itself, a single distinct positive value
        # is flagged with a warning below
        spec = mask_specs[attr]
        if isinstance(spec, str) and spec == attr:
            observed = np.array(graph.vs[attr])
            distinct_positive = np.unique(observed[observed > 0])
            if len(distinct_positive) == 1:
                logger.warning(
                    f"Attribute '{attr}' has constant non-zero values, uniform null may not be meaningful."
                )

        # equal weight on every masked node, zero elsewhere
        uniform_values = np.zeros(n_vertices)
        uniform_values[attr_mask] = 1.0 / n_masked
        null_graph.vs[attr] = uniform_values.tolist()

    return net_propagate_attributes(
        null_graph, attributes, method_spec, additional_propagation_args
    )
263
+
264
+
265
def parametric_null(
    graph: ig.Graph,
    attributes: List[str],
    propagation_method: Union[
        str, PropagationMethod
    ] = NET_PROPAGATION_DEFS.PERSONALIZED_PAGERANK,
    distribution: Union[str, Any] = PARAMETRIC_NULL_DEFAULT_DISTRIBUTION,
    additional_propagation_args: Optional[dict] = None,
    mask: Optional[Union[str, np.ndarray, List, Dict]] = MASK_KEYWORDS.ATTR,
    n_samples: int = 100,
    fit_kwargs: Optional[dict] = None,
) -> pd.DataFrame:
    """
    Generate parametric null distribution by fitting scipy.stats distribution to observed values.

    Parameters
    ----------
    graph : ig.Graph
        Input graph.
    attributes : List[str]
        Attribute names to generate nulls for.
    propagation_method : str or PropagationMethod
        Network propagation method to apply.
    distribution : str or scipy.stats distribution
        Distribution to fit. Can be:
        - String name (e.g., 'norm', 'gamma', 'beta', 'expon', 'lognorm')
        - SciPy stats distribution object (e.g., stats.gamma, stats.beta)
        The object must provide both ``fit`` and ``rvs``. SciPy's discrete
        distributions (e.g., stats.poisson) have no ``fit`` method and are
        therefore rejected with a ValueError at the fitting step.
    additional_propagation_args : dict, optional
        Additional arguments to pass to the network propagation method.
    mask : str, np.ndarray, List, Dict, or None
        Mask specification. Default is "attr" (use each attribute as its own mask).
    n_samples : int
        Number of null samples to generate. Must be at least 1.
    fit_kwargs : dict, optional
        Additional arguments passed to distribution.fit() method.
        Common examples:
        - For gamma: {'floc': 0} to fix location at 0
        - For beta: {'floc': 0, 'fscale': 1} to fix support to [0,1]

    Returns
    -------
    pd.DataFrame
        Propagated null samples with specified parametric distribution over masked nodes.
        Shape: (n_samples * n_nodes, n_attributes). The index repeats the node
        names once per sample.

    Raises
    ------
    ValueError
        If ``n_samples`` is less than 1, the distribution name is unknown, or
        fitting fails for an attribute.

    Examples
    --------
    >>> # Default distribution (PARAMETRIC_NULL_DEFAULT_DISTRIBUTION)
    >>> result = parametric_null(graph, ['gene_expression'])

    >>> # Gamma null for positive-valued data
    >>> result = parametric_null(graph, ['gene_expression'],
    ...                          distribution='gamma',
    ...                          fit_kwargs={'floc': 0})

    >>> # Beta null for data in [0,1]
    >>> result = parametric_null(graph, ['probabilities'],
    ...                          distribution='beta')

    >>> # Custom scipy distribution object
    >>> result = parametric_null(graph, ['waiting_times'],
    ...                          distribution=stats.expon)
    """
    # Without at least one sample there is nothing to stack; fail with a clear
    # message instead of the opaque error np.vstack raises on an empty list.
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    # Setup
    dist = _get_distribution_object(distribution)
    if fit_kwargs is None:
        fit_kwargs = {}

    # Validate attributes
    propagation_method = _ensure_propagation_method(propagation_method)
    _validate_vertex_attributes(graph, attributes, propagation_method)

    # Parse mask input and get masks
    mask_specs = _parse_mask_input(mask, attributes)
    masks = _get_attribute_masks(graph, mask_specs)

    # Fit distribution parameters for each attribute
    params = _fit_distribution_parameters(graph, attributes, masks, dist, fit_kwargs)

    # Get node names for output
    node_names = (
        graph.vs[NAPISTU_GRAPH_VERTICES.NAME]
        if NAPISTU_GRAPH_VERTICES.NAME in graph.vs.attributes()
        else list(range(graph.vcount()))
    )

    # Create null graph once (attributes are overwritten for each sample)
    null_graph = graph.copy()
    all_results = []

    for _ in range(n_samples):
        # Draw a fresh set of attribute values (modifies null_graph in-place)
        _generate_parametric_null_sample(
            null_graph,
            attributes,
            params,
            ensure_nonnegative=propagation_method.non_negative,
        )

        # Apply propagation method to null graph
        result = net_propagate_attributes(
            null_graph, attributes, propagation_method, additional_propagation_args
        )
        all_results.append(result)

    # Stack the per-sample frames; node_names is a list, so `*` repeats it
    full_index = node_names * n_samples
    all_data = np.vstack([result.values for result in all_results])

    return pd.DataFrame(all_data, index=full_index, columns=attributes)
376
+
377
+
378
def node_permutation_null(
    graph: ig.Graph,
    attributes: List[str],
    propagation_method: Union[
        str, PropagationMethod
    ] = NET_PROPAGATION_DEFS.PERSONALIZED_PAGERANK,
    additional_propagation_args: Optional[dict] = None,
    mask: Optional[Union[str, np.ndarray, List, Dict]] = MASK_KEYWORDS.ATTR,
    replace: bool = False,
    n_samples: int = 100,
) -> pd.DataFrame:
    """
    Generate null distribution by permuting node attribute values and apply propagation method.

    For every sample, the values of each attribute are shuffled (or resampled
    with replacement) among that attribute's masked nodes; values of unmasked
    nodes are left as observed. Randomness comes from NumPy's global RNG
    (``np.random``), so seed it with ``np.random.seed`` for reproducible output.

    Parameters
    ----------
    graph : ig.Graph
        Input graph.
    attributes : List[str]
        Attribute names to permute.
    propagation_method : str or PropagationMethod
        Network propagation method to apply.
    additional_propagation_args : dict, optional
        Additional arguments to pass to the network propagation method.
    mask : str, np.ndarray, List, Dict, or None
        Mask specification. Default is "attr" (use each attribute as its own mask).
    replace : bool
        Whether to sample with replacement.
    n_samples : int
        Number of null samples to generate. Must be at least 1.

    Returns
    -------
    pd.DataFrame
        Propagated null samples with permuted attribute values.
        Shape: (n_samples * n_nodes, n_attributes). The index repeats the node
        names once per sample.

    Raises
    ------
    ValueError
        If ``n_samples`` is less than 1.
    """
    # Without at least one sample there is nothing to stack; fail with a clear
    # message instead of the opaque error np.vstack raises on an empty list.
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    # Validate attributes
    propagation_method = _ensure_propagation_method(propagation_method)
    _validate_vertex_attributes(graph, attributes, propagation_method)

    # Parse mask input
    mask_specs = _parse_mask_input(mask, attributes)
    masks = _get_attribute_masks(graph, mask_specs)

    # The observed values and the masked positions do not change between
    # samples, so compute them once rather than inside the sampling loop.
    original_values = {attr: np.array(graph.vs[attr]) for attr in attributes}
    masked_indices = {attr: np.where(masks[attr])[0] for attr in attributes}
    masked_values = {
        attr: original_values[attr][masked_indices[attr]] for attr in attributes
    }

    # Get node names
    node_names = (
        graph.vs[NAPISTU_GRAPH_VERTICES.NAME]
        if NAPISTU_GRAPH_VERTICES.NAME in graph.vs.attributes()
        else list(range(graph.vcount()))
    )

    all_results = []

    # A single copy is reused; its attributes are overwritten for each sample
    null_graph = graph.copy()

    for _ in range(n_samples):
        for attr in attributes:
            pool = masked_values[attr]

            if replace:
                # Sample with replacement
                permuted_values = np.random.choice(pool, size=len(pool), replace=True)
            else:
                # Permute without replacement (returns a shuffled copy)
                permuted_values = np.random.permutation(pool)

            # Start from the observed values so unmasked nodes are untouched
            null_attr_values = original_values[attr].copy()
            null_attr_values[masked_indices[attr]] = permuted_values
            null_graph.vs[attr] = null_attr_values.tolist()

        # Apply propagation method to null graph
        result = net_propagate_attributes(
            null_graph, attributes, propagation_method, additional_propagation_args
        )
        all_results.append(result)

    # Stack the per-sample frames; node_names is a list, so `*` repeats it
    full_index = node_names * n_samples
    all_data = np.vstack([result.values for result in all_results])

    return pd.DataFrame(all_data, index=full_index, columns=attributes)
476
+
477
+
478
def edge_permutation_null(
    graph: ig.Graph,
    attributes: List[str],
    propagation_method: Union[
        str, PropagationMethod
    ] = NET_PROPAGATION_DEFS.PERSONALIZED_PAGERANK,
    additional_propagation_args: Optional[dict] = None,
    burn_in_ratio: float = 10,
    sampling_ratio: float = 0.1,
    n_samples: int = 100,
) -> pd.DataFrame:
    """
    Generate null distribution by edge rewiring and apply propagation method.

    A copy of the graph is rewired with igraph's ``Graph.rewire``: first a
    burn-in of ``burn_in_ratio * n_edges`` rewiring trials, then
    ``sampling_ratio * n_edges`` further trials before each sample. Both counts
    are truncated to integers. Vertex attributes are not modified.

    Parameters
    ----------
    graph : ig.Graph
        Input graph.
    attributes : List[str]
        Attribute names to use (values unchanged by rewiring).
    propagation_method : str or PropagationMethod
        Network propagation method to apply.
    additional_propagation_args : dict, optional
        Additional arguments to pass to the network propagation method.
    burn_in_ratio : float
        Multiplier of the edge count giving the number of initial rewiring trials.
    sampling_ratio : float
        Multiplier of the edge count giving the number of rewiring trials
        performed between consecutive samples.
    n_samples : int
        Number of null samples to generate. Must be at least 1.

    Returns
    -------
    pd.DataFrame
        Propagated null samples from rewired network.
        Shape: (n_samples * n_nodes, n_attributes). The index repeats the node
        names once per sample.

    Raises
    ------
    ValueError
        If ``n_samples`` is less than 1.
    """
    # Without at least one sample there is nothing to stack; fail with a clear
    # message instead of the opaque error np.vstack raises on an empty list.
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    # Validate attributes
    propagation_method = _ensure_propagation_method(propagation_method)
    _validate_vertex_attributes(graph, attributes, propagation_method)

    # Setup rewired graph
    null_graph = graph.copy()
    n_edges = len(null_graph.es)

    # rewire() takes an integer trial count; both ratios are floats, so the
    # products are truncated here (the burn-in previously passed a raw float
    # product whenever burn_in_ratio was not an int).
    burn_in_trials = int(burn_in_ratio * n_edges)
    trials_per_sample = int(sampling_ratio * n_edges)

    # Initial burn-in
    null_graph.rewire(n=burn_in_trials)

    # Get node names
    node_names = (
        graph.vs[NAPISTU_GRAPH_VERTICES.NAME]
        if NAPISTU_GRAPH_VERTICES.NAME in graph.vs.attributes()
        else list(range(graph.vcount()))
    )

    all_results = []

    for _ in range(n_samples):
        # Incremental rewiring on top of the previous sample's topology
        null_graph.rewire(n=trials_per_sample)

        # Apply propagation method to rewired graph (attributes unchanged)
        result = net_propagate_attributes(
            null_graph, attributes, propagation_method, additional_propagation_args
        )
        all_results.append(result)

    # Stack the per-sample frames; node_names is a list, so `*` repeats it
    full_index = node_names * n_samples
    all_data = np.vstack([result.values for result in all_results])

    return pd.DataFrame(all_data, index=full_index, columns=attributes)
553
+
554
+
555
# Registry mapping each null strategy name to the function that implements it
NULL_GENERATORS = {
    NULL_STRATEGIES.UNIFORM: uniform_null,
    NULL_STRATEGIES.PARAMETRIC: parametric_null,
    NULL_STRATEGIES.NODE_PERMUTATION: node_permutation_null,
    NULL_STRATEGIES.EDGE_PERMUTATION: edge_permutation_null,
}


def get_null_generator(strategy: str):
    """
    Look up the null generator function registered for a strategy name.

    Parameters
    ----------
    strategy : str
        Name of the null strategy.

    Returns
    -------
    callable
        The generator function stored in NULL_GENERATORS for ``strategy``.

    Raises
    ------
    ValueError
        If ``strategy`` is not one of VALID_NULL_STRATEGIES.
    """
    if strategy in VALID_NULL_STRATEGIES:
        return NULL_GENERATORS[strategy]

    raise ValueError(
        f"Unknown null strategy: {strategy}. Available: {VALID_NULL_STRATEGIES}"
    )
571
+
572
+
573
+ def _get_distribution_object(distribution: Union[str, Any]) -> Any:
574
+ """Get scipy.stats distribution object from string name or object."""
575
+ if isinstance(distribution, str):
576
+ try:
577
+ return getattr(stats, distribution)
578
+ except AttributeError:
579
+ raise ValueError(
580
+ f"Unknown distribution: '{distribution}'. "
581
+ f"Must be a valid scipy.stats distribution name."
582
+ )
583
+ return distribution
584
+
585
+
586
def _fit_distribution_parameters(
    graph: ig.Graph,
    attributes: List[str],
    masks: Dict[str, np.ndarray],
    distribution: Any,
    fit_kwargs: Dict[str, Any],
) -> Dict[str, Dict[str, Any]]:
    """
    Fit distribution parameters for each attribute using masked data.

    Only the strictly positive values among an attribute's masked nodes are
    passed to ``distribution.fit``.

    Parameters
    ----------
    graph : ig.Graph
        Graph holding the observed vertex attributes.
    attributes : List[str]
        Attribute names to fit.
    masks : Dict[str, np.ndarray]
        Per-attribute node mask used to index the attribute values.
    distribution : Any
        Object exposing ``fit(data, **fit_kwargs)``.
    fit_kwargs : Dict[str, Any]
        Keyword arguments forwarded to ``distribution.fit``.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        For each attribute, a dict with keys "fitted_params", "mask" and
        "distribution".

    Raises
    ------
    ValueError
        If an attribute has no positive values within its mask, or if fitting
        fails for any reason (the original exception is chained as the cause).
    """
    params = {}

    for attr in attributes:
        attr_mask = masks[attr]
        attr_values = np.array(graph.vs[attr])
        masked_values = attr_values[attr_mask]
        masked_nonzero = masked_values[masked_values > 0]

        if len(masked_nonzero) == 0:
            raise ValueError(f"No nonzero values in mask for attribute '{attr}'")

        try:
            # Let SciPy handle parameter estimation and validation
            fitted_params = distribution.fit(masked_nonzero, **fit_kwargs)
        except Exception as e:
            # Deliberately broad: any fitting failure, including a distribution
            # object without a `fit` method, is reported uniformly as ValueError.
            dist_name = (
                distribution.name
                if hasattr(distribution, "name")
                else str(distribution)
            )
            raise ValueError(
                f"Failed to fit {dist_name} distribution to attribute '{attr}': {str(e)}"
            ) from e

        params[attr] = {
            "fitted_params": fitted_params,
            "mask": attr_mask,
            "distribution": distribution,
        }

    return params
626
+
627
+
628
def _generate_parametric_null_sample(
    null_graph: ig.Graph,
    attributes: List[str],
    params: Dict[str, Dict[str, Any]],
    ensure_nonnegative: bool,
) -> None:
    """
    Overwrite each attribute on ``null_graph`` with one draw from its fitted distribution.

    Parameters
    ----------
    null_graph : ig.Graph
        Graph whose vertex attributes are replaced in-place.
    attributes : List[str]
        Attribute names to regenerate.
    params : Dict[str, Dict[str, Any]]
        Per-attribute dict with keys "mask", "fitted_params" and "distribution".
    ensure_nonnegative : bool
        If True, negative draws are clipped to 0 (a warning is logged when any occur).

    Returns
    -------
    None
        The graph is modified in-place.
    """
    n_vertices = null_graph.vcount()

    for attr in attributes:
        spec = params[attr]
        node_mask = spec["mask"]

        # one draw per masked node from the fitted distribution
        draws = spec["distribution"].rvs(*spec["fitted_params"], size=node_mask.sum())

        if ensure_nonnegative:
            # negative draws hint that the chosen distribution does not suit
            # the data; warn, then clip so the values stay usable
            if np.any(draws < 0):
                logger.warning(
                    f"Negative samples for attribute '{attr}' suggest that the wrong distribution is being used"
                )
            draws = np.maximum(draws, 0)

        # unmasked nodes get 0; masked nodes get the sampled values
        sample = np.zeros(n_vertices)
        sample[node_mask] = draws
        null_graph.vs[attr] = sample.tolist()
659
+
660
+
661
def _validate_vertex_attributes(
    graph: ig.Graph,
    attributes: List[str],
    propagation_method: Union[str, PropagationMethod],
) -> None:
    """
    Validate vertex attributes for propagation method.

    Each attribute is passed through ``_ensure_valid_attribute`` with the
    method's ``non_negative`` requirement; the returned values are discarded,
    so this function is used purely for the errors that helper raises.

    Parameters
    ----------
    graph : ig.Graph
        Graph holding the vertex attributes.
    attributes : List[str]
        Attribute names to check.
    propagation_method : str or PropagationMethod
        Registered method name or a PropagationMethod instance.

    Returns
    -------
    None
    """
    propagation_method = _ensure_propagation_method(propagation_method)

    for attr in attributes:
        _ensure_valid_attribute(
            graph, attr, non_negative=propagation_method.non_negative
        )
675
+
676
+
677
def _pagerank_wrapper(graph: ig.Graph, attr_data: np.ndarray, **kwargs):
    """
    Run ``graph.personalized_pagerank`` using ``attr_data`` as the reset vector.

    ``attr_data`` is converted to a plain list before being passed as ``reset``;
    any extra keyword arguments are forwarded unchanged.
    """
    reset_weights = attr_data.tolist()
    return graph.personalized_pagerank(reset=reset_weights, **kwargs)
679
+
680
+
681
# Personalized PageRank is flagged non_negative=True, so attributes are
# validated as non-negative before being handed to it.
_pagerank_method = PropagationMethod(method=_pagerank_wrapper, non_negative=True)

# Registry of propagation methods addressable by name. Annotated with
# typing.Dict to match the typing generics used throughout this module; the
# builtin `dict[...]` form is evaluated at import time and fails on Python < 3.9.
NET_PROPAGATION_METHODS: Dict[str, PropagationMethod] = {
    NET_PROPAGATION_DEFS.PERSONALIZED_PAGERANK: _pagerank_method
}
# Live view of the registry's keys (reflects later additions to the registry)
VALID_NET_PROPAGATION_METHODS = NET_PROPAGATION_METHODS.keys()
687
+
688
+
689
def _ensure_propagation_method(
    propagation_method: Union[str, PropagationMethod],
) -> PropagationMethod:
    """
    Resolve a propagation method given by name, or pass an object through.

    Parameters
    ----------
    propagation_method : str or PropagationMethod
        A key of NET_PROPAGATION_METHODS, or any non-string object, which is
        returned unchanged without further checks.

    Returns
    -------
    PropagationMethod
        The registered method for a string input, otherwise the input itself.

    Raises
    ------
    ValueError
        If a string is given that is not a registered method name.
    """
    # anything that is not a name is assumed to already be a method object
    if not isinstance(propagation_method, str):
        return propagation_method

    if propagation_method in VALID_NET_PROPAGATION_METHODS:
        return NET_PROPAGATION_METHODS[propagation_method]

    raise ValueError(f"Invalid propagation method: {propagation_method}")