risk-network 0.0.6b3__tar.gz → 0.0.6b5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37):
  1. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/PKG-INFO +1 -1
  2. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/__init__.py +1 -1
  3. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/annotations/io.py +29 -0
  4. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/neighborhoods/domains.py +24 -18
  5. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/network/graph.py +13 -9
  6. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/network/plot.py +80 -42
  7. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk_network.egg-info/PKG-INFO +1 -1
  8. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/LICENSE +0 -0
  9. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/MANIFEST.in +0 -0
  10. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/README.md +0 -0
  11. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/pyproject.toml +0 -0
  12. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/annotations/__init__.py +0 -0
  13. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/annotations/annotations.py +0 -0
  14. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/constants.py +0 -0
  15. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/log/__init__.py +0 -0
  16. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/log/console.py +0 -0
  17. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/log/params.py +0 -0
  18. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/neighborhoods/__init__.py +0 -0
  19. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/neighborhoods/community.py +0 -0
  20. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/neighborhoods/neighborhoods.py +0 -0
  21. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/network/__init__.py +0 -0
  22. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/network/geometry.py +0 -0
  23. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/network/io.py +0 -0
  24. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/risk.py +0 -0
  25. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/stats/__init__.py +0 -0
  26. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/stats/fisher_exact.py +0 -0
  27. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/stats/hypergeom.py +0 -0
  28. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/stats/permutation/__init__.py +0 -0
  29. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/stats/permutation/permutation.py +0 -0
  30. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/stats/permutation/test_functions.py +0 -0
  31. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk/stats/stats.py +0 -0
  32. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk_network.egg-info/SOURCES.txt +0 -0
  33. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk_network.egg-info/dependency_links.txt +0 -0
  34. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk_network.egg-info/requires.txt +0 -0
  35. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/risk_network.egg-info/top_level.txt +0 -0
  36. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/setup.cfg +0 -0
  37. {risk_network-0.0.6b3 → risk_network-0.0.6b5}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: risk-network
3
- Version: 0.0.6b3
3
+ Version: 0.0.6b5
4
4
  Summary: A Python package for biological network analysis
5
5
  Author: Ira Horecka
6
6
  Author-email: Ira Horecka <ira89@icloud.com>
@@ -7,4 +7,4 @@ RISK: RISK Infers Spatial Kinships
7
7
 
8
8
  from risk.risk import RISK
9
9
 
10
- __version__ = "0.0.6-beta.3"
10
+ __version__ = "0.0.6-beta.5"
@@ -153,6 +153,35 @@ class AnnotationsIO:
153
153
  # Load the annotations into the provided network
154
154
  return load_annotations(network, annotations_input)
155
155
 
156
+ def load_dict_annotation(self, content: Dict[str, Any], network: nx.Graph) -> Dict[str, Any]:
157
+ """Load annotations from a provided dictionary and convert them to a dictionary annotation.
158
+
159
+ Args:
160
+ content (dict): The annotations dictionary to load.
161
+ network (NetworkX graph): The network to which the annotations are related.
162
+
163
+ Returns:
164
+ dict: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
165
+ """
166
+ # Ensure the input content is a dictionary
167
+ if not isinstance(content, dict):
168
+ raise TypeError(
169
+ f"Expected 'content' to be a dictionary, but got {type(content).__name__} instead."
170
+ )
171
+
172
+ filetype = "Dictionary"
173
+ # Log the loading of the annotations from the dictionary
174
+ params.log_annotations(filepath="In-memory dictionary", filetype=filetype)
175
+ _log_loading(filetype, "In-memory dictionary")
176
+
177
+ # Load the annotations into the provided network
178
+ annotations_dict = load_annotations(network, content)
179
+ # Ensure the output is a dictionary
180
+ if not isinstance(annotations_dict, dict):
181
+ raise ValueError("Expected output to be a dictionary")
182
+
183
+ return annotations_dict
184
+
156
185
 
157
186
  def _load_matrix_file(
158
187
  filepath: str,
@@ -35,26 +35,31 @@ def define_domains(
35
35
  Returns:
36
36
  pd.DataFrame: DataFrame with the primary domain for each node.
37
37
  """
38
- # Perform hierarchical clustering on the binary enrichment matrix
39
- m = significant_neighborhoods_enrichment[:, top_annotations["top attributes"]].T
40
- best_linkage, best_metric, best_threshold = _optimize_silhouette_across_linkage_and_metrics(
41
- m, linkage_criterion, linkage_method, linkage_metric
42
- )
43
- try:
44
- Z = linkage(m, method=best_linkage, metric=best_metric)
45
- except ValueError as e:
46
- raise ValueError("No significant annotations found.") from e
38
+ # Check if there's more than one column in significant_neighborhoods_enrichment
39
+ if significant_neighborhoods_enrichment.shape[1] == 1:
40
+ print("Single annotation detected. Skipping clustering.")
41
+ top_annotations["domain"] = 1 # Assign a default domain or handle appropriately
42
+ else:
43
+ # Perform hierarchical clustering on the binary enrichment matrix
44
+ m = significant_neighborhoods_enrichment[:, top_annotations["top attributes"]].T
45
+ best_linkage, best_metric, best_threshold = _optimize_silhouette_across_linkage_and_metrics(
46
+ m, linkage_criterion, linkage_method, linkage_metric
47
+ )
48
+ try:
49
+ Z = linkage(m, method=best_linkage, metric=best_metric)
50
+ except ValueError as e:
51
+ raise ValueError("No significant annotations found.") from e
47
52
 
48
- print(
49
- f"Linkage criterion: '{linkage_criterion}'\nLinkage method: '{best_linkage}'\nLinkage metric: '{best_metric}'"
50
- )
51
- print(f"Optimal linkage threshold: {round(best_threshold, 3)}")
53
+ print(
54
+ f"Linkage criterion: '{linkage_criterion}'\nLinkage method: '{best_linkage}'\nLinkage metric: '{best_metric}'"
55
+ )
56
+ print(f"Optimal linkage threshold: {round(best_threshold, 3)}")
52
57
 
53
- max_d_optimal = np.max(Z[:, 2]) * best_threshold
54
- domains = fcluster(Z, max_d_optimal, criterion=linkage_criterion)
55
- # Assign domains to the annotations matrix
56
- top_annotations["domain"] = 0
57
- top_annotations.loc[top_annotations["top attributes"], "domain"] = domains
58
+ max_d_optimal = np.max(Z[:, 2]) * best_threshold
59
+ domains = fcluster(Z, max_d_optimal, criterion=linkage_criterion)
60
+ # Assign domains to the annotations matrix
61
+ top_annotations["domain"] = 0
62
+ top_annotations.loc[top_annotations["top attributes"], "domain"] = domains
58
63
 
59
64
  # Create DataFrames to store domain information
60
65
  node_to_enrichment = pd.DataFrame(
@@ -63,6 +68,7 @@ def define_domains(
63
68
  )
64
69
  node_to_domain = node_to_enrichment.groupby(level="domain", axis=1).sum()
65
70
 
71
+ # Find the maximum enrichment score for each node
66
72
  t_max = node_to_domain.loc[:, 1:].max(axis=1)
67
73
  t_idxmax = node_to_domain.loc[:, 1:].idxmax(axis=1)
68
74
  t_idxmax[t_max == 0] = 0
@@ -42,12 +42,16 @@ class NetworkGraph:
42
42
  node_enrichment_sums (np.ndarray): Array containing the enrichment sums for the nodes.
43
43
  """
44
44
  self.top_annotations = top_annotations
45
- self.domain_to_nodes = self._create_domain_to_nodes_map(domains)
45
+ self.domain_to_nodes_map = self._create_domain_to_nodes_map(domains)
46
46
  self.domains = domains
47
47
  self.trimmed_domain_to_term = self._create_domain_to_term_map(trimmed_domains)
48
48
  self.trimmed_domains = trimmed_domains
49
- self.node_label_to_id_map = node_label_to_id_map
50
49
  self.node_enrichment_sums = node_enrichment_sums
50
+ self.node_id_to_label_map = {v: k for k, v in node_label_to_id_map.items()}
51
+ self.node_label_to_enrichment_map = dict(
52
+ zip(node_label_to_id_map.keys(), node_enrichment_sums)
53
+ )
54
+ self.node_label_to_id_map = node_label_to_id_map
51
55
  # NOTE: self.network and self.node_coordinates are declared in _initialize_network
52
56
  self.network = None
53
57
  self.node_coordinates = None
@@ -63,12 +67,12 @@ class NetworkGraph:
63
67
  dict: A dictionary where keys are domain IDs and values are lists of nodes belonging to each domain.
64
68
  """
65
69
  cleaned_domains_matrix = domains.reset_index()[["index", "primary domain"]]
66
- node_to_domains = cleaned_domains_matrix.set_index("index")["primary domain"].to_dict()
67
- domain_to_nodes = defaultdict(list)
68
- for k, v in node_to_domains.items():
69
- domain_to_nodes[v].append(k)
70
+ node_to_domains_map = cleaned_domains_matrix.set_index("index")["primary domain"].to_dict()
71
+ domain_to_nodes_map = defaultdict(list)
72
+ for k, v in node_to_domains_map.items():
73
+ domain_to_nodes_map[v].append(k)
70
74
 
71
- return domain_to_nodes
75
+ return domain_to_nodes_map
72
76
 
73
77
  def _create_domain_to_term_map(self, trimmed_domains: pd.DataFrame) -> Dict[str, Any]:
74
78
  """Create a mapping from domain IDs to their corresponding terms.
@@ -154,7 +158,7 @@ class NetworkGraph:
154
158
  # Initialize composite colors array with shape (number of nodes, 4) for RGBA
155
159
  composite_colors = np.zeros((num_nodes, 4))
156
160
  # Assign colors to nodes based on domain_colors
157
- for domain_idx, nodes in self.domain_to_nodes.items():
161
+ for domain_idx, nodes in self.domain_to_nodes_map.items():
158
162
  color = domain_colors[domain_idx]
159
163
  for node in nodes:
160
164
  composite_colors[node] = color
@@ -185,7 +189,7 @@ class NetworkGraph:
185
189
  domain_colors = _get_colors(
186
190
  num_colors_to_generate=len(domains), cmap=cmap, color=color, random_seed=random_seed
187
191
  )
188
- return dict(zip(self.domain_to_nodes.keys(), domain_colors))
192
+ return dict(zip(self.domain_to_nodes_map.keys(), domain_colors))
189
193
 
190
194
 
191
195
  def _transform_colors(
@@ -137,9 +137,7 @@ class NetworkPlotter:
137
137
  )
138
138
  # Set the transparency of the fill if applicable
139
139
  if fill_alpha > 0:
140
- circle.set_facecolor(
141
- _to_rgba(color, fill_alpha)
142
- ) # Use _to_rgba to set the fill color with transparency
140
+ circle.set_facecolor(_to_rgba(color, fill_alpha))
143
141
 
144
142
  self.ax.add_artist(circle)
145
143
 
@@ -200,7 +198,7 @@ class NetworkPlotter:
200
198
  color=color,
201
199
  linestyle=linestyle,
202
200
  linewidth=linewidth,
203
- alpha=fill_alpha, # Use fill_alpha for the fill
201
+ alpha=fill_alpha,
204
202
  )
205
203
 
206
204
  def plot_network(
@@ -279,6 +277,7 @@ class NetworkPlotter:
279
277
  nodes: List,
280
278
  node_size: Union[int, np.ndarray] = 50,
281
279
  node_shape: str = "o",
280
+ node_edgewidth: float = 1.0,
282
281
  edge_width: float = 1.0,
283
282
  node_color: Union[str, List, Tuple, np.ndarray] = "white",
284
283
  node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
@@ -292,6 +291,7 @@ class NetworkPlotter:
292
291
  nodes (list): List of node labels to include in the subnetwork.
293
292
  node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
294
293
  node_shape (str, optional): Shape of the nodes. Defaults to "o".
294
+ node_edgewidth (float, optional): Width of the node edges. Defaults to 1.0.
295
295
  edge_width (float, optional): Width of the edges. Defaults to 1.0.
296
296
  node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
297
297
  node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
@@ -302,7 +302,6 @@ class NetworkPlotter:
302
302
  Raises:
303
303
  ValueError: If no valid nodes are found in the network graph.
304
304
  """
305
- # Don't log subnetwork parameters as they are specific to individual annotations
306
305
  # Filter to get node IDs and their coordinates
307
306
  node_ids = [
308
307
  self.graph.node_label_to_id_map.get(node)
@@ -312,10 +311,19 @@ class NetworkPlotter:
312
311
  if not node_ids:
313
312
  raise ValueError("No nodes found in the network graph.")
314
313
 
314
+ # Check if node_color is a single color or a list of colors
315
+ if not isinstance(node_color, (str, tuple, np.ndarray)):
316
+ node_color = [
317
+ node_color[nodes.index(node)]
318
+ for node in nodes
319
+ if node in self.graph.node_label_to_id_map
320
+ ]
321
+
315
322
  # Convert colors to RGBA using the _to_rgba helper function
316
- node_color = _to_rgba(node_color, node_alpha)
317
- node_edgecolor = _to_rgba(node_edgecolor, 1.0) # Node edges usually fully opaque
318
- edge_color = _to_rgba(edge_color, edge_alpha)
323
+ node_color = _to_rgba(node_color, node_alpha, num_repeats=len(node_ids))
324
+ node_edgecolor = _to_rgba(node_edgecolor, 1.0, num_repeats=len(node_ids))
325
+ edge_color = _to_rgba(edge_color, edge_alpha, num_repeats=len(self.graph.network.edges))
326
+
319
327
  # Get the coordinates of the filtered nodes
320
328
  node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}
321
329
 
@@ -328,6 +336,7 @@ class NetworkPlotter:
328
336
  node_shape=node_shape,
329
337
  node_color=node_color,
330
338
  edgecolors=node_edgecolor,
339
+ linewidths=node_edgewidth,
331
340
  ax=self.ax,
332
341
  )
333
342
  # Draw the edges between the specified nodes in the subnetwork
@@ -348,7 +357,8 @@ class NetworkPlotter:
348
357
  color: Union[str, List, Tuple, np.ndarray] = "white",
349
358
  linestyle: str = "solid",
350
359
  linewidth: float = 1.5,
351
- alpha: float = 0.2,
360
+ alpha: float = 1.0,
361
+ fill_alpha: float = 0.2,
352
362
  ) -> None:
353
363
  """Draw KDE contours for nodes in various domains of a network graph, highlighting areas of high density.
354
364
 
@@ -359,7 +369,8 @@ class NetworkPlotter:
359
369
  color (str, list, tuple, or np.ndarray, optional): Color of the contours. Can be a single color or an array of colors. Defaults to "white".
360
370
  linestyle (str, optional): Line style for the contours. Defaults to "solid".
361
371
  linewidth (float, optional): Line width for the contours. Defaults to 1.5.
362
- alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
372
+ alpha (float, optional): Transparency level of the contour lines. Defaults to 1.0.
373
+ fill_alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
363
374
  """
364
375
  # Log the contour plotting parameters
365
376
  params.log_plotter(
@@ -370,14 +381,15 @@ class NetworkPlotter:
370
381
  "custom" if isinstance(color, np.ndarray) else color
371
382
  ), # np.ndarray usually indicates custom colors
372
383
  contour_alpha=alpha,
384
+ contour_fill_alpha=fill_alpha,
373
385
  )
374
386
 
375
387
  # Ensure color is converted to RGBA with repetition matching the number of domains
376
- color = _to_rgba(color, alpha, num_repeats=len(self.graph.domain_to_nodes))
388
+ color = _to_rgba(color, alpha, num_repeats=len(self.graph.domain_to_nodes_map))
377
389
  # Extract node coordinates from the network graph
378
390
  node_coordinates = self.graph.node_coordinates
379
391
  # Draw contours for each domain in the network
380
- for idx, (_, nodes) in enumerate(self.graph.domain_to_nodes.items()):
392
+ for idx, (_, nodes) in enumerate(self.graph.domain_to_nodes_map.items()):
381
393
  if len(nodes) > 1:
382
394
  self._draw_kde_contour(
383
395
  self.ax,
@@ -390,6 +402,7 @@ class NetworkPlotter:
390
402
  linestyle=linestyle,
391
403
  linewidth=linewidth,
392
404
  alpha=alpha,
405
+ fill_alpha=fill_alpha,
393
406
  )
394
407
 
395
408
  def plot_subcontour(
@@ -401,7 +414,8 @@ class NetworkPlotter:
401
414
  color: Union[str, List, Tuple, np.ndarray] = "white",
402
415
  linestyle: str = "solid",
403
416
  linewidth: float = 1.5,
404
- alpha: float = 0.2,
417
+ alpha: float = 1.0,
418
+ fill_alpha: float = 0.2,
405
419
  ) -> None:
406
420
  """Plot a subcontour for a given set of nodes using Kernel Density Estimation (KDE).
407
421
 
@@ -413,7 +427,8 @@ class NetworkPlotter:
413
427
  color (str, list, tuple, or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array. Defaults to "white".
414
428
  linestyle (str, optional): Line style for the contour. Defaults to "solid".
415
429
  linewidth (float, optional): Line width for the contour. Defaults to 1.5.
416
- alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
430
+ alpha (float, optional): Transparency level of the contour lines. Defaults to 1.0.
431
+ fill_alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
417
432
 
418
433
  Raises:
419
434
  ValueError: If no valid nodes are found in the network graph.
@@ -443,6 +458,7 @@ class NetworkPlotter:
443
458
  linestyle=linestyle,
444
459
  linewidth=linewidth,
445
460
  alpha=alpha,
461
+ fill_alpha=fill_alpha,
446
462
  )
447
463
 
448
464
  def _draw_kde_contour(
@@ -456,7 +472,8 @@ class NetworkPlotter:
456
472
  color: Union[str, np.ndarray] = "white",
457
473
  linestyle: str = "solid",
458
474
  linewidth: float = 1.5,
459
- alpha: float = 0.5,
475
+ alpha: float = 1.0,
476
+ fill_alpha: float = 0.2,
460
477
  ) -> None:
461
478
  """Draw a Kernel Density Estimate (KDE) contour plot for a set of nodes on a given axis.
462
479
 
@@ -470,7 +487,8 @@ class NetworkPlotter:
470
487
  color (str or np.ndarray): Color for the contour. Can be a string or RGBA array. Defaults to "white".
471
488
  linestyle (str, optional): Line style for the contour. Defaults to "solid".
472
489
  linewidth (float, optional): Line width for the contour. Defaults to 1.5.
473
- alpha (float, optional): Transparency level for the contour fill. Defaults to 0.5.
490
+ alpha (float, optional): Transparency level for the contour lines. Defaults to 1.0.
491
+ fill_alpha (float, optional): Transparency level for the contour fill. Defaults to 0.2.
474
492
  """
475
493
  # Extract the positions of the specified nodes
476
494
  points = np.array([pos[n] for n in nodes])
@@ -496,8 +514,8 @@ class NetworkPlotter:
496
514
  min_density, max_density = z.min(), z.max()
497
515
  contour_levels = np.linspace(min_density, max_density, levels)[1:]
498
516
  contour_colors = [color for _ in range(levels - 1)]
499
- # Plot the filled contours only if alpha > 0
500
- if alpha > 0:
517
+ # Plot the filled contours using fill_alpha for transparency
518
+ if fill_alpha > 0:
501
519
  ax.contourf(
502
520
  x,
503
521
  y,
@@ -505,10 +523,10 @@ class NetworkPlotter:
505
523
  levels=contour_levels,
506
524
  colors=contour_colors,
507
525
  antialiased=True,
508
- alpha=alpha,
526
+ alpha=fill_alpha,
509
527
  )
510
528
 
511
- # Plot the contour lines without any change in behavior
529
+ # Plot the contour lines with the specified alpha for transparency
512
530
  c = ax.contour(
513
531
  x,
514
532
  y,
@@ -517,7 +535,9 @@ class NetworkPlotter:
517
535
  colors=contour_colors,
518
536
  linestyles=linestyle,
519
537
  linewidths=linewidth,
538
+ alpha=alpha,
520
539
  )
540
+ # Set linewidth for the contour lines to 0 for levels other than the base level
521
541
  for i in range(1, len(contour_levels)):
522
542
  c.collections[i].set_linewidth(0)
523
543
 
@@ -580,9 +600,9 @@ class NetworkPlotter:
580
600
  )
581
601
 
582
602
  # Convert colors to RGBA using the _to_rgba helper function, applying alpha separately for font and arrow
583
- fontcolor = _to_rgba(fontcolor, fontalpha, num_repeats=len(self.graph.domain_to_nodes))
603
+ fontcolor = _to_rgba(fontcolor, fontalpha, num_repeats=len(self.graph.domain_to_nodes_map))
584
604
  arrow_color = _to_rgba(
585
- arrow_color, arrow_alpha, num_repeats=len(self.graph.domain_to_nodes)
605
+ arrow_color, arrow_alpha, num_repeats=len(self.graph.domain_to_nodes_map)
586
606
  )
587
607
 
588
608
  # Normalize words_to_omit to lowercase
@@ -591,7 +611,7 @@ class NetworkPlotter:
591
611
 
592
612
  # Calculate the center and radius of the network
593
613
  domain_centroids = {}
594
- for domain, nodes in self.graph.domain_to_nodes.items():
614
+ for domain, nodes in self.graph.domain_to_nodes_map.items():
595
615
  if nodes: # Skip if the domain has no nodes
596
616
  domain_centroids[domain] = self._calculate_domain_centroid(nodes)
597
617
 
@@ -800,28 +820,28 @@ class NetworkPlotter:
800
820
  return adjusted_network_colors
801
821
 
802
822
  def get_annotated_node_sizes(
803
- self, enriched_nodesize: int = 50, nonenriched_nodesize: int = 25
823
+ self, enriched_size: int = 50, nonenriched_size: int = 25
804
824
  ) -> np.ndarray:
805
825
  """Adjust the sizes of nodes in the network graph based on whether they are enriched or not.
806
826
 
807
827
  Args:
808
- enriched_nodesize (int): Size for enriched nodes. Defaults to 50.
809
- nonenriched_nodesize (int): Size for non-enriched nodes. Defaults to 25.
828
+ enriched_size (int): Size for enriched nodes. Defaults to 50.
829
+ nonenriched_size (int): Size for non-enriched nodes. Defaults to 25.
810
830
 
811
831
  Returns:
812
832
  np.ndarray: Array of node sizes, with enriched nodes larger than non-enriched ones.
813
833
  """
814
- # Merge all enriched nodes from the domain_to_nodes dictionary
834
+ # Merge all enriched nodes from the domain_to_nodes_map dictionary
815
835
  enriched_nodes = set()
816
- for _, nodes in self.graph.domain_to_nodes.items():
836
+ for _, nodes in self.graph.domain_to_nodes_map.items():
817
837
  enriched_nodes.update(nodes)
818
838
 
819
839
  # Initialize all node sizes to the non-enriched size
820
- node_sizes = np.full(len(self.graph.network.nodes), nonenriched_nodesize)
840
+ node_sizes = np.full(len(self.graph.network.nodes), nonenriched_size)
821
841
  # Set the size for enriched nodes
822
842
  for node in enriched_nodes:
823
843
  if node in self.graph.network.nodes:
824
- node_sizes[node] = enriched_nodesize
844
+ node_sizes[node] = enriched_size
825
845
 
826
846
  return node_sizes
827
847
 
@@ -928,7 +948,7 @@ class NetworkPlotter:
928
948
  random_seed=random_seed,
929
949
  )
930
950
  annotated_colors = []
931
- for _, nodes in self.graph.domain_to_nodes.items():
951
+ for _, nodes in self.graph.domain_to_nodes_map.items():
932
952
  if len(nodes) > 1:
933
953
  # For multi-node domains, choose the brightest color based on RGB sum
934
954
  domain_colors = np.array([node_colors[node] for node in nodes])
@@ -969,33 +989,51 @@ def _to_rgba(
969
989
  alpha: float = 1.0,
970
990
  num_repeats: Union[int, None] = None,
971
991
  ) -> np.ndarray:
972
- """Convert a color or array of colors to RGBA format, applying or updating the alpha as needed.
992
+ """Convert a color or array of colors to RGBA format, applying alpha only if the color is RGB.
973
993
 
974
994
  Args:
975
995
  color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Can be a string, list, tuple, or np.ndarray.
976
- alpha (float, optional): Alpha value (transparency) to apply. Defaults to 1.0.
996
+ alpha (float, optional): Alpha value (transparency) to apply if the color is in RGB format. Defaults to 1.0.
977
997
  num_repeats (int or None, optional): If provided, the color will be repeated this many times. Defaults to None.
978
998
 
979
999
  Returns:
980
1000
  np.ndarray: The RGBA color or array of RGBA colors.
981
1001
  """
982
- # Handle single color case
1002
+ # Handle single color case (string, RGB, or RGBA)
983
1003
  if isinstance(color, str) or (
984
1004
  isinstance(color, (list, tuple, np.ndarray)) and len(color) in [3, 4]
985
1005
  ):
986
- rgba_color = np.array(mcolors.to_rgba(color, alpha))
987
- # Repeat the color if repeat argument is provided
1006
+ rgba_color = np.array(mcolors.to_rgba(color))
1007
+ # Only set alpha if the input is an RGB color or a string (not RGBA)
1008
+ if len(rgba_color) == 4 and (
1009
+ len(color) == 3 or isinstance(color, str)
1010
+ ): # If it's RGB or a string, set the alpha
1011
+ rgba_color[3] = alpha
1012
+
1013
+ # Repeat the color if num_repeats argument is provided
988
1014
  if num_repeats is not None:
989
1015
  return np.array([rgba_color] * num_repeats)
990
1016
 
991
1017
  return rgba_color
992
1018
 
993
- # Handle array of colors case
994
- elif isinstance(color, (list, tuple, np.ndarray)) and isinstance(
995
- color[0], (list, tuple, np.ndarray)
996
- ):
997
- rgba_colors = [mcolors.to_rgba(c, alpha) if len(c) == 3 else np.array(c) for c in color]
998
- # Repeat the colors if repeat argument is provided
1019
+ # Handle array of colors case (including strings, RGB, and RGBA)
1020
+ elif isinstance(color, (list, tuple, np.ndarray)):
1021
+ rgba_colors = []
1022
+ for c in color:
1023
+ # Ensure each element is either a valid string or a list/tuple of length 3 (RGB) or 4 (RGBA)
1024
+ if isinstance(c, str) or (
1025
+ isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]
1026
+ ):
1027
+ rgba_c = np.array(mcolors.to_rgba(c))
1028
+ # Apply alpha only to RGB colors (not RGBA) and strings
1029
+ if len(rgba_c) == 4 and (len(c) == 3 or isinstance(c, str)):
1030
+ rgba_c[3] = alpha
1031
+
1032
+ rgba_colors.append(rgba_c)
1033
+ else:
1034
+ raise ValueError(f"Invalid color: {c}. Must be a valid RGB/RGBA or string color.")
1035
+
1036
+ # Repeat the colors if num_repeats argument is provided
999
1037
  if num_repeats is not None and len(rgba_colors) == 1:
1000
1038
  return np.array([rgba_colors[0]] * num_repeats)
1001
1039
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: risk-network
3
- Version: 0.0.6b3
3
+ Version: 0.0.6b5
4
4
  Summary: A Python package for biological network analysis
5
5
  Author: Ira Horecka
6
6
  Author-email: Ira Horecka <ira89@icloud.com>
File without changes
File without changes
File without changes
File without changes