risk-network 0.0.8b3.tar.gz → 0.0.8b5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/PKG-INFO +1 -1
  2. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/__init__.py +1 -1
  3. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/network/plot.py +204 -102
  4. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk_network.egg-info/PKG-INFO +1 -1
  5. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/LICENSE +0 -0
  6. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/MANIFEST.in +0 -0
  7. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/README.md +0 -0
  8. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/pyproject.toml +0 -0
  9. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/annotations/__init__.py +0 -0
  10. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/annotations/annotations.py +0 -0
  11. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/annotations/io.py +0 -0
  12. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/constants.py +0 -0
  13. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/log/__init__.py +0 -0
  14. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/log/config.py +0 -0
  15. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/log/params.py +0 -0
  16. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/neighborhoods/__init__.py +0 -0
  17. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/neighborhoods/community.py +0 -0
  18. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/neighborhoods/domains.py +0 -0
  19. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/neighborhoods/neighborhoods.py +0 -0
  20. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/network/__init__.py +0 -0
  21. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/network/geometry.py +0 -0
  22. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/network/graph.py +0 -0
  23. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/network/io.py +0 -0
  24. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/risk.py +0 -0
  25. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/__init__.py +0 -0
  26. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/hypergeom.py +0 -0
  27. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/permutation/__init__.py +0 -0
  28. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/permutation/permutation.py +0 -0
  29. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/permutation/test_functions.py +0 -0
  30. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/poisson.py +0 -0
  31. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/stats.py +0 -0
  32. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk_network.egg-info/SOURCES.txt +0 -0
  33. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk_network.egg-info/dependency_links.txt +0 -0
  34. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk_network.egg-info/requires.txt +0 -0
  35. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk_network.egg-info/top_level.txt +0 -0
  36. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/setup.cfg +0 -0
  37. {risk_network-0.0.8b3 → risk_network-0.0.8b5}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: risk-network
3
- Version: 0.0.8b3
3
+ Version: 0.0.8b5
4
4
  Summary: A Python package for biological network analysis
5
5
  Author: Ira Horecka
6
6
  Author-email: Ira Horecka <ira89@icloud.com>
@@ -7,4 +7,4 @@ RISK: RISK Infers Spatial Kinships
7
7
 
8
8
  from risk.risk import RISK
9
9
 
10
- __version__ = "0.0.8-beta.3"
10
+ __version__ = "0.0.8-beta.5"
@@ -33,6 +33,7 @@ class NetworkPlotter:
33
33
  graph: NetworkGraph,
34
34
  figsize: Tuple = (10, 10),
35
35
  background_color: Union[str, List, Tuple, np.ndarray] = "white",
36
+ background_alpha: float = 1.0,
36
37
  ) -> None:
37
38
  """Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.
38
39
 
@@ -40,16 +41,23 @@ class NetworkPlotter:
40
41
  graph (NetworkGraph): The network data and attributes to be visualized.
41
42
  figsize (tuple, optional): Size of the figure in inches (width, height). Defaults to (10, 10).
42
43
  background_color (str, list, tuple, np.ndarray, optional): Background color of the plot. Defaults to "white".
44
+ background_alpha (float, optional): Transparency level of the background color. Defaults to 1.0.
43
45
  """
44
46
  self.graph = graph
45
47
  # Initialize the plot with the specified parameters
46
- self.ax = self._initialize_plot(graph, figsize, background_color)
48
+ self.ax = self._initialize_plot(
49
+ graph=graph,
50
+ figsize=figsize,
51
+ background_color=background_color,
52
+ background_alpha=background_alpha,
53
+ )
47
54
 
48
55
  def _initialize_plot(
49
56
  self,
50
57
  graph: NetworkGraph,
51
58
  figsize: Tuple,
52
59
  background_color: Union[str, List, Tuple, np.ndarray],
60
+ background_alpha: float = 1.0,
53
61
  ) -> plt.Axes:
54
62
  """Set up the plot with figure size and background color.
55
63
 
@@ -57,6 +65,7 @@ class NetworkPlotter:
57
65
  graph (NetworkGraph): The network data and attributes to be visualized.
58
66
  figsize (tuple): Size of the figure in inches (width, height).
59
67
  background_color (str): Background color of the plot.
68
+ background_alpha (float, optional): Transparency level of the background color. Defaults to 1.0.
60
69
 
61
70
  Returns:
62
71
  plt.Axes: The axis object for the plot.
@@ -76,7 +85,7 @@ class NetworkPlotter:
76
85
 
77
86
  # Set the background color of the plot
78
87
  # Convert color to RGBA using the _to_rgba helper function
79
- fig.patch.set_facecolor(_to_rgba(background_color, 1.0))
88
+ fig.patch.set_facecolor(_to_rgba(color=background_color, alpha=background_alpha))
80
89
  ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
81
90
  # Remove axis spines for a cleaner look
82
91
  for spine in ax.spines.values():
@@ -199,7 +208,7 @@ class NetworkPlotter:
199
208
  )
200
209
 
201
210
  # Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
202
- color = _to_rgba(color, outline_alpha)
211
+ color = _to_rgba(color=color, alpha=outline_alpha)
203
212
  # Extract node coordinates from the network graph
204
213
  node_coordinates = self.graph.node_coordinates
205
214
  # Calculate the center and radius of the bounding box around the network
@@ -218,7 +227,7 @@ class NetworkPlotter:
218
227
  )
219
228
  # Set the transparency of the fill if applicable
220
229
  if fill_alpha > 0:
221
- circle.set_facecolor(_to_rgba(color, fill_alpha))
230
+ circle.set_facecolor(_to_rgba(color=color, alpha=fill_alpha))
222
231
 
223
232
  self.ax.add_artist(circle)
224
233
 
@@ -263,7 +272,7 @@ class NetworkPlotter:
263
272
  )
264
273
 
265
274
  # Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
266
- color = _to_rgba(color, outline_alpha)
275
+ color = _to_rgba(color=color, alpha=outline_alpha)
267
276
  # Extract node coordinates from the network graph
268
277
  node_coordinates = self.graph.node_coordinates
269
278
  # Scale the node coordinates if needed
@@ -326,9 +335,15 @@ class NetworkPlotter:
326
335
 
327
336
  # Convert colors to RGBA using the _to_rgba helper function
328
337
  # If node_colors was generated using get_annotated_node_colors, its alpha values will override node_alpha
329
- node_color = _to_rgba(node_color, node_alpha, num_repeats=len(self.graph.network.nodes))
330
- node_edgecolor = _to_rgba(node_edgecolor, 1.0, num_repeats=len(self.graph.network.nodes))
331
- edge_color = _to_rgba(edge_color, edge_alpha, num_repeats=len(self.graph.network.edges))
338
+ node_color = _to_rgba(
339
+ color=node_color, alpha=node_alpha, num_repeats=len(self.graph.network.nodes)
340
+ )
341
+ node_edgecolor = _to_rgba(
342
+ color=node_edgecolor, alpha=1.0, num_repeats=len(self.graph.network.nodes)
343
+ )
344
+ edge_color = _to_rgba(
345
+ color=edge_color, alpha=edge_alpha, num_repeats=len(self.graph.network.edges)
346
+ )
332
347
 
333
348
  # Extract node coordinates from the network graph
334
349
  node_coordinates = self.graph.node_coordinates
@@ -405,9 +420,11 @@ class NetworkPlotter:
405
420
  ]
406
421
 
407
422
  # Convert colors to RGBA using the _to_rgba helper function
408
- node_color = _to_rgba(node_color, node_alpha, num_repeats=len(node_ids))
409
- node_edgecolor = _to_rgba(node_edgecolor, 1.0, num_repeats=len(node_ids))
410
- edge_color = _to_rgba(edge_color, edge_alpha, num_repeats=len(self.graph.network.edges))
423
+ node_color = _to_rgba(color=node_color, alpha=node_alpha, num_repeats=len(node_ids))
424
+ node_edgecolor = _to_rgba(color=node_edgecolor, alpha=1.0, num_repeats=len(node_ids))
425
+ edge_color = _to_rgba(
426
+ color=edge_color, alpha=edge_alpha, num_repeats=len(self.graph.network.edges)
427
+ )
411
428
 
412
429
  # Get the coordinates of the filtered nodes
413
430
  node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}
@@ -470,7 +487,9 @@ class NetworkPlotter:
470
487
  )
471
488
 
472
489
  # Ensure color is converted to RGBA with repetition matching the number of domains
473
- color = _to_rgba(color, alpha, num_repeats=len(self.graph.domain_id_to_node_ids_map))
490
+ color = _to_rgba(
491
+ color=color, alpha=alpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
492
+ )
474
493
  # Extract node coordinates from the network graph
475
494
  node_coordinates = self.graph.node_coordinates
476
495
  # Draw contours for each domain in the network
@@ -527,7 +546,7 @@ class NetworkPlotter:
527
546
  node_groups = [nodes]
528
547
 
529
548
  # Convert color to RGBA using the _to_rgba helper function
530
- color_rgba = _to_rgba(color, alpha)
549
+ color_rgba = _to_rgba(color=color, alpha=alpha)
531
550
 
532
551
  # Iterate over each group of nodes (either sublists or flat list)
533
552
  for sublist in node_groups:
@@ -705,10 +724,10 @@ class NetworkPlotter:
705
724
  arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
706
725
  arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
707
726
  max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
708
- max_label_lines (int, optional): Maximum number of lines in a label. Defaults to None (no limit).
709
727
  min_label_lines (int, optional): Minimum number of lines in a label. Defaults to 1.
710
- max_chars_per_line (int, optional): Maximum number of characters in a line to display. Defaults to None (no limit).
728
+ max_label_lines (int, optional): Maximum number of lines in a label. Defaults to None (no limit).
711
729
  min_chars_per_line (int, optional): Minimum number of characters in a line to display. Defaults to 1.
730
+ max_chars_per_line (int, optional): Maximum number of characters in a line to display. Defaults to None (no limit).
712
731
  words_to_omit (list, optional): List of words to omit from the labels. Defaults to None.
713
732
  overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
714
733
  ids_to_keep (list, tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
@@ -761,11 +780,11 @@ class NetworkPlotter:
761
780
  if words_to_omit:
762
781
  words_to_omit = set(word.lower() for word in words_to_omit)
763
782
 
764
- # Calculate the center and radius of the network
765
- domain_centroids = {}
783
+ # Calculate the center and radius of domains to position labels around the network
784
+ domain_id_to_centroid_map = {}
766
785
  for domain_id, node_ids in self.graph.domain_id_to_node_ids_map.items():
767
786
  if node_ids: # Skip if the domain has no nodes
768
- domain_centroids[domain_id] = self._calculate_domain_centroid(node_ids)
787
+ domain_id_to_centroid_map[domain_id] = self._calculate_domain_centroid(node_ids)
769
788
 
770
789
  # Initialize dictionaries and lists for valid indices
771
790
  valid_indices = [] # List of valid indices to plot colors and arrows
@@ -775,8 +794,8 @@ class NetworkPlotter:
775
794
  if ids_to_keep:
776
795
  # Process the ids_to_keep first INPLACE
777
796
  self._process_ids_to_keep(
797
+ domain_id_to_centroid_map=domain_id_to_centroid_map,
778
798
  ids_to_keep=ids_to_keep,
779
- domain_centroids=domain_centroids,
780
799
  ids_to_replace=ids_to_replace,
781
800
  words_to_omit=words_to_omit,
782
801
  max_labels=max_labels,
@@ -796,7 +815,7 @@ class NetworkPlotter:
796
815
  # Process remaining domains INPLACE to fill in additional labels, if there are slots left
797
816
  if remaining_labels and remaining_labels > 0:
798
817
  self._process_remaining_domains(
799
- domain_centroids=domain_centroids,
818
+ domain_id_to_centroid_map=domain_id_to_centroid_map,
800
819
  ids_to_keep=ids_to_keep,
801
820
  ids_to_replace=ids_to_replace,
802
821
  words_to_omit=words_to_omit,
@@ -816,12 +835,14 @@ class NetworkPlotter:
816
835
  best_label_positions = _calculate_best_label_positions(
817
836
  filtered_domain_centroids, center, radius, offset
818
837
  )
819
- # Convert colors to RGBA using the _to_rgba helper function
838
+ # Convert all domain colors to RGBA using the _to_rgba helper function
820
839
  fontcolor = _to_rgba(
821
- fontcolor, fontalpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
840
+ color=fontcolor, alpha=fontalpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
822
841
  )
823
842
  arrow_color = _to_rgba(
824
- arrow_color, arrow_alpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
843
+ color=arrow_color,
844
+ alpha=arrow_alpha,
845
+ num_repeats=len(self.graph.domain_id_to_node_ids_map),
825
846
  )
826
847
 
827
848
  # Annotate the network with labels
@@ -847,8 +868,11 @@ class NetworkPlotter:
847
868
  shrinkB=arrow_tip_shrink,
848
869
  ),
849
870
  )
850
- # Overlay domain ID at the centroid if requested
851
- if overlay_ids:
871
+
872
+ # Overlay domain ID at the centroid regardless of max_labels if requested
873
+ if overlay_ids:
874
+ for idx, domain in enumerate(self.graph.domain_id_to_node_ids_map):
875
+ centroid = domain_id_to_centroid_map[domain]
852
876
  self.ax.text(
853
877
  centroid[0],
854
878
  centroid[1],
@@ -907,8 +931,8 @@ class NetworkPlotter:
907
931
  node_groups = [nodes]
908
932
 
909
933
  # Convert fontcolor and arrow_color to RGBA
910
- fontcolor_rgba = _to_rgba(fontcolor, fontalpha)
911
- arrow_color_rgba = _to_rgba(arrow_color, arrow_alpha)
934
+ fontcolor_rgba = _to_rgba(color=fontcolor, alpha=fontalpha)
935
+ arrow_color_rgba = _to_rgba(color=arrow_color, alpha=arrow_alpha)
912
936
 
913
937
  # Calculate the bounding box around the network
914
938
  center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
@@ -977,8 +1001,8 @@ class NetworkPlotter:
977
1001
 
978
1002
  def _process_ids_to_keep(
979
1003
  self,
1004
+ domain_id_to_centroid_map: Dict[str, np.ndarray],
980
1005
  ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
981
- domain_centroids: Dict[str, np.ndarray],
982
1006
  ids_to_replace: Union[Dict[str, str], None],
983
1007
  words_to_omit: Union[List[str], None],
984
1008
  max_labels: Union[int, None],
@@ -993,8 +1017,8 @@ class NetworkPlotter:
993
1017
  """Process the ids_to_keep, apply filtering, and store valid domain centroids and terms.
994
1018
 
995
1019
  Args:
1020
+ domain_id_to_centroid_map (dict): Mapping of domain IDs to their centroids.
996
1021
  ids_to_keep (list, tuple, or np.ndarray, optional): IDs of domains that must be labeled.
997
- domain_centroids (dict): Mapping of domains to their centroids.
998
1022
  ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
999
1023
  words_to_omit (list, optional): List of words to omit from the labels. Defaults to None.
1000
1024
  max_labels (int, optional): Maximum number of labels allowed.
@@ -1020,25 +1044,30 @@ class NetworkPlotter:
1020
1044
 
1021
1045
  # Process each domain in ids_to_keep
1022
1046
  for domain in ids_to_keep:
1023
- if domain in self.graph.domain_id_to_domain_terms_map and domain in domain_centroids:
1024
- domain_terms = self._process_terms(
1047
+ if (
1048
+ domain in self.graph.domain_id_to_domain_terms_map
1049
+ and domain in domain_id_to_centroid_map
1050
+ ):
1051
+ domain_centroid = domain_id_to_centroid_map[domain]
1052
+ # No need to filter the domain terms if it is in ids_to_keep
1053
+ _ = self._validate_and_update_domain(
1025
1054
  domain=domain,
1055
+ domain_centroid=domain_centroid,
1056
+ domain_id_to_centroid_map=domain_id_to_centroid_map,
1026
1057
  ids_to_replace=ids_to_replace,
1027
1058
  words_to_omit=words_to_omit,
1059
+ min_label_lines=min_label_lines,
1028
1060
  max_label_lines=max_label_lines,
1029
1061
  min_chars_per_line=min_chars_per_line,
1030
1062
  max_chars_per_line=max_chars_per_line,
1063
+ filtered_domain_centroids=filtered_domain_centroids,
1064
+ filtered_domain_terms=filtered_domain_terms,
1065
+ valid_indices=valid_indices,
1031
1066
  )
1032
- num_domain_lines = len(domain_terms.split(TERM_DELIMITER))
1033
- # Check if the number of lines in the label is greater than or equal to the minimum
1034
- if num_domain_lines >= min_label_lines:
1035
- filtered_domain_terms[domain] = domain_terms
1036
- filtered_domain_centroids[domain] = domain_centroids[domain]
1037
- valid_indices.append(list(domain_centroids.keys()).index(domain))
1038
1067
 
1039
1068
  def _process_remaining_domains(
1040
1069
  self,
1041
- domain_centroids: Dict[str, np.ndarray],
1070
+ domain_id_to_centroid_map: Dict[str, np.ndarray],
1042
1071
  ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
1043
1072
  ids_to_replace: Union[Dict[str, str], None],
1044
1073
  words_to_omit: Union[List[str], None],
@@ -1054,7 +1083,7 @@ class NetworkPlotter:
1054
1083
  """Process remaining domains to fill in additional labels, respecting the remaining_labels limit.
1055
1084
 
1056
1085
  Args:
1057
- domain_centroids (dict): Mapping of domains to their centroids.
1086
+ domain_id_to_centroid_map (dict): Mapping of domain IDs to their centroids.
1058
1087
  ids_to_keep (list, tuple, or np.ndarray, optional): IDs of domains that must be labeled.
1059
1088
  ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
1060
1089
  words_to_omit (list, optional): List of words to omit from the labels. Defaults to None.
@@ -1066,13 +1095,16 @@ class NetworkPlotter:
1066
1095
  filtered_domain_centroids (dict): Dictionary to store filtered domain centroids (output).
1067
1096
  filtered_domain_terms (dict): Dictionary to store filtered domain terms (output).
1068
1097
  valid_indices (list): List to store valid indices (output).
1098
+
1099
+ Note:
1100
+ The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.
1069
1101
  """
1070
1102
  # Counter to track how many labels have been created
1071
1103
  label_count = 0
1072
1104
  # Collect domains not in ids_to_keep
1073
1105
  remaining_domains = {
1074
1106
  domain: centroid
1075
- for domain, centroid in domain_centroids.items()
1107
+ for domain, centroid in domain_id_to_centroid_map.items()
1076
1108
  if domain not in ids_to_keep and not pd.isna(domain)
1077
1109
  }
1078
1110
 
@@ -1112,26 +1144,89 @@ class NetworkPlotter:
1112
1144
 
1113
1145
  # Process the selected domains and add to filtered lists
1114
1146
  for domain in selected_domains:
1115
- centroid = remaining_domains[domain]
1116
- domain_terms = self._process_terms(
1147
+ domain_centroid = remaining_domains[domain]
1148
+ is_domain_valid = self._validate_and_update_domain(
1117
1149
  domain=domain,
1150
+ domain_centroid=domain_centroid,
1151
+ domain_id_to_centroid_map=domain_id_to_centroid_map,
1118
1152
  ids_to_replace=ids_to_replace,
1119
1153
  words_to_omit=words_to_omit,
1154
+ min_label_lines=min_label_lines,
1120
1155
  max_label_lines=max_label_lines,
1121
1156
  min_chars_per_line=min_chars_per_line,
1122
1157
  max_chars_per_line=max_chars_per_line,
1158
+ filtered_domain_centroids=filtered_domain_centroids,
1159
+ filtered_domain_terms=filtered_domain_terms,
1160
+ valid_indices=valid_indices,
1123
1161
  )
1124
- num_domain_lines = len(domain_terms.split(TERM_DELIMITER))
1125
- # Check if the number of lines in the label is greater than or equal to the minimum
1126
- if num_domain_lines >= min_label_lines:
1127
- filtered_domain_centroids[domain] = centroid
1128
- filtered_domain_terms[domain] = domain_terms
1129
- valid_indices.append(list(domain_centroids.keys()).index(domain))
1130
-
1162
+ # Increment the label count if the domain is valid
1163
+ if is_domain_valid:
1131
1164
  label_count += 1
1132
1165
  if label_count >= remaining_labels:
1133
1166
  break
1134
1167
 
1168
+ def _validate_and_update_domain(
1169
+ self,
1170
+ domain: str,
1171
+ domain_centroid: np.ndarray,
1172
+ domain_id_to_centroid_map: Dict[str, np.ndarray],
1173
+ ids_to_replace: Union[Dict[str, str], None],
1174
+ words_to_omit: Union[List[str], None],
1175
+ min_label_lines: int,
1176
+ max_label_lines: int,
1177
+ min_chars_per_line: int,
1178
+ max_chars_per_line: int,
1179
+ filtered_domain_centroids: Dict[str, np.ndarray],
1180
+ filtered_domain_terms: Dict[str, str],
1181
+ valid_indices: List[int],
1182
+ ) -> bool:
1183
+ """Validate and process the domain terms, updating relevant dictionaries if valid.
1184
+
1185
+ Args:
1186
+ domain (str): Domain ID to process.
1187
+ domain_centroid (np.ndarray): Centroid position of the domain.
1188
+ domain_id_to_centroid_map (dict): Mapping of domain IDs to their centroids.
1189
+ ids_to_replace (Union[Dict[str, str], None]): A dictionary mapping domain IDs to custom labels.
1190
+ words_to_omit (Union[List[str], None]): List of words to omit from the labels.
1191
+ min_label_lines (int): Minimum number of lines required in a label.
1192
+ max_label_lines (int): Maximum number of lines allowed in a label.
1193
+ min_chars_per_line (int): Minimum number of characters allowed per line.
1194
+ max_chars_per_line (int): Maximum number of characters allowed per line.
1195
+ filtered_domain_centroids (Dict[str, np.ndarray]): Dictionary to store valid domain centroids.
1196
+ filtered_domain_terms (Dict[str, str]): Dictionary to store valid domain terms.
1197
+ valid_indices (List[int]): List of valid domain indices.
1198
+
1199
+ Returns:
1200
+ bool: True if the domain is valid and added to the filtered dictionaries, False otherwise.
1201
+
1202
+ Note:
1203
+ The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.
1204
+ """
1205
+ # Process the domain terms
1206
+ domain_terms = self._process_terms(
1207
+ domain=domain,
1208
+ ids_to_replace=ids_to_replace,
1209
+ words_to_omit=words_to_omit,
1210
+ max_label_lines=max_label_lines,
1211
+ min_chars_per_line=min_chars_per_line,
1212
+ max_chars_per_line=max_chars_per_line,
1213
+ )
1214
+ # If domain_terms is empty, skip further processing
1215
+ if not domain_terms:
1216
+ return False
1217
+
1218
+ # Split the terms by TERM_DELIMITER and count the number of lines
1219
+ num_domain_lines = len(domain_terms.split(TERM_DELIMITER))
1220
+ # Check if the number of lines is greater than or equal to the minimum
1221
+ if num_domain_lines >= min_label_lines:
1222
+ filtered_domain_centroids[domain] = domain_centroid
1223
+ filtered_domain_terms[domain] = domain_terms
1224
+ # Add the index of the domain to the valid indices list
1225
+ valid_indices.append(list(domain_id_to_centroid_map.keys()).index(domain))
1226
+ return True
1227
+
1228
+ return False
1229
+
1135
1230
  def _process_terms(
1136
1231
  self,
1137
1232
  domain: str,
@@ -1152,7 +1247,7 @@ class NetworkPlotter:
1152
1247
  max_chars_per_line (int): Maximum number of characters in a line to display.
1153
1248
 
1154
1249
  Returns:
1155
- list: Processed terms, with words combined if necessary to fit within constraints.
1250
+ str: Processed terms separated by TERM_DELIMITER, with words combined if necessary to fit within constraints.
1156
1251
  """
1157
1252
  # Handle ids_to_replace logic
1158
1253
  if ids_to_replace and domain in ids_to_replace:
@@ -1213,11 +1308,11 @@ class NetworkPlotter:
1213
1308
  # Apply the alpha value for enriched nodes
1214
1309
  network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
1215
1310
  # Convert the non-enriched color to RGBA using the _to_rgba helper function
1216
- nonenriched_color = _to_rgba(nonenriched_color, nonenriched_alpha)
1311
+ nonenriched_color = _to_rgba(color=nonenriched_color, alpha=nonenriched_alpha)
1217
1312
  # Adjust node colors: replace any fully black nodes (RGB == 0) with the non-enriched color and its alpha
1218
1313
  adjusted_network_colors = np.where(
1219
1314
  np.all(network_colors[:, :3] == 0, axis=1, keepdims=True), # Check RGB values only
1220
- np.array([nonenriched_color]), # Apply the non-enriched color with alpha
1315
+ np.array(nonenriched_color), # Apply the non-enriched color with alpha
1221
1316
  network_colors, # Keep the original colors for enriched nodes
1222
1317
  )
1223
1318
  return adjusted_network_colors
@@ -1391,62 +1486,65 @@ class NetworkPlotter:
1391
1486
 
1392
1487
  def _to_rgba(
1393
1488
  color: Union[str, List, Tuple, np.ndarray],
1394
- alpha: float = 1.0,
1489
+ alpha: Union[float, None] = None,
1395
1490
  num_repeats: Union[int, None] = None,
1396
1491
  ) -> np.ndarray:
1397
- """Convert a color or array of colors to RGBA format, applying alpha only if the color is RGB.
1492
+ """Convert color(s) to RGBA format, applying alpha and repeating as needed.
1398
1493
 
1399
1494
  Args:
1400
1495
  color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Can be a string, list, tuple, or np.ndarray.
1401
- alpha (float, optional): Alpha value (transparency) to apply if the color is in RGB format. Defaults to 1.0.
1402
- num_repeats (int or None, optional): If provided, the color will be repeated this many times. Defaults to None.
1496
+ alpha (float, None, optional): Alpha value (transparency) to apply. If provided, it overrides any existing alpha values.
1497
+ num_repeats (int, None, optional): If provided, the color(s) will be repeated this many times. Defaults to None.
1403
1498
 
1404
1499
  Returns:
1405
- np.ndarray: The RGBA color or array of RGBA colors.
1500
+ np.ndarray: Array of RGBA colors repeated `num_repeats` times, if applicable.
1406
1501
  """
1407
- # Handle single color case (string, RGB, or RGBA)
1408
- if isinstance(color, str) or (
1409
- isinstance(color, (list, tuple, np.ndarray))
1410
- and len(color) in [3, 4]
1411
- and not any(isinstance(c, (list, tuple, np.ndarray)) for c in color)
1412
- ):
1413
- rgba_color = np.array(mcolors.to_rgba(color))
1414
- # Only set alpha if the input is an RGB color or a string (not RGBA)
1415
- if len(rgba_color) == 4 and (
1416
- len(color) == 3 or isinstance(color, str)
1417
- ): # If it's RGB or a string, set the alpha
1418
- rgba_color[3] = alpha
1419
1502
 
1420
- # Repeat the color if num_repeats argument is provided
1421
- if num_repeats is not None:
1422
- return np.array([rgba_color] * num_repeats)
1503
+ def convert_to_rgba(c: Union[str, List, Tuple, np.ndarray]) -> np.ndarray:
1504
+ """Convert a single color to RGBA format, handling strings, hex, and RGB/RGBA lists."""
1505
+ if isinstance(c, str):
1506
+ # Convert color names or hex values (e.g., 'red', '#FF5733') to RGBA
1507
+ rgba = np.array(mcolors.to_rgba(c))
1508
+ elif isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]:
1509
+ # Convert RGB (3) or RGBA (4) values to RGBA format
1510
+ rgba = np.array(mcolors.to_rgba(c))
1511
+ else:
1512
+ raise ValueError(
1513
+ f"Invalid color format: {c}. Must be a valid string or RGB/RGBA sequence."
1514
+ )
1423
1515
 
1424
- return rgba_color
1516
+ if alpha is not None: # Override alpha if provided
1517
+ rgba[3] = alpha
1518
+ return rgba
1425
1519
 
1426
- # Handle array of colors case (including strings, RGB, and RGBA)
1427
- elif isinstance(color, (list, tuple, np.ndarray)):
1428
- rgba_colors = []
1429
- for i in range(num_repeats):
1430
- # Reiterate over the colors if the number of repeats exceeds the number of colors
1431
- c = color[i % len(color)]
1432
- # Ensure each element is either a valid string or a list/tuple of length 3 (RGB) or 4 (RGBA)
1433
- if isinstance(c, str) or (
1434
- isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]
1435
- ):
1436
- rgba_c = np.array(mcolors.to_rgba(c))
1437
- # Apply alpha only to RGB colors (not RGBA) and strings
1438
- if len(rgba_c) == 4 and (len(c) == 3 or isinstance(c, str)):
1439
- rgba_c[3] = alpha
1440
-
1441
- rgba_colors.append(rgba_c)
1442
- else:
1443
- raise ValueError(f"Invalid color: {c}. Must be a valid RGB/RGBA or string color.")
1520
+ # If color is a 2D array of RGBA values, convert it to a list of lists
1521
+ if isinstance(color, np.ndarray) and color.ndim == 2 and color.shape[1] == 4:
1522
+ color = [list(c) for c in color]
1444
1523
 
1445
- # Repeat the colors if num_repeats argument is provided
1446
- if num_repeats is not None and len(rgba_colors) == 1:
1447
- return np.array([rgba_colors[0]] * num_repeats)
1524
+ # Handle a single color (string or RGB/RGBA list/tuple)
1525
+ if isinstance(color, (str, list, tuple)) and not any(
1526
+ isinstance(c, (list, tuple, np.ndarray)) for c in color
1527
+ ):
1528
+ rgba_color = convert_to_rgba(color)
1529
+ if num_repeats:
1530
+ return np.tile(
1531
+ rgba_color, (num_repeats, 1)
1532
+ ) # Repeat the color if num_repeats is provided
1533
+ return np.array([rgba_color]) # Return a single color wrapped in a numpy array
1534
+
1535
+ # Handle a list/array of colors
1536
+ elif isinstance(color, (list, tuple, np.ndarray)):
1537
+ rgba_colors = np.array(
1538
+ [convert_to_rgba(c) for c in color]
1539
+ ) # Convert each color in the list to RGBA
1540
+ # Handle repetition if num_repeats is provided
1541
+ if num_repeats:
1542
+ repeated_colors = np.array(
1543
+ [rgba_colors[i % len(rgba_colors)] for i in range(num_repeats)]
1544
+ )
1545
+ return repeated_colors
1448
1546
 
1449
- return np.array(rgba_colors)
1547
+ return rgba_colors
1450
1548
 
1451
1549
  else:
1452
1550
  raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")
@@ -1487,13 +1585,13 @@ def _calculate_bounding_box(
1487
1585
  return center, radius
1488
1586
 
1489
1587
 
1490
- def _combine_words(words: List[str], max_length: int, max_label_lines: int) -> str:
1491
- """Combine words to fit within the max_length and max_label_lines constraints,
1492
- and separate the final output by ':' for plotting.
1588
+ def _combine_words(words: List[str], max_chars_per_line: int, max_label_lines: int) -> str:
1589
+ """Combine words to fit within the max_chars_per_line and max_label_lines constraints,
1590
+ and separate the final output by TERM_DELIMITER for plotting.
1493
1591
 
1494
1592
  Args:
1495
1593
  words (List[str]): List of words to combine.
1496
- max_length (int): Maximum allowed length for a combined line.
1594
+ max_chars_per_line (int): Maximum number of characters in a line to display.
1497
1595
  max_label_lines (int): Maximum number of lines in a label.
1498
1596
 
1499
1597
  Returns:
@@ -1510,14 +1608,18 @@ def _combine_words(words: List[str], max_length: int, max_label_lines: int) -> s
1510
1608
  # Try to combine more words if possible, and ensure the combination fits within max_length
1511
1609
  for j in range(i + 1, len(words_batch)):
1512
1610
  next_word = words_batch[j]
1513
- if len(combined_word) + len(next_word) + 2 <= max_length: # +2 for ', '
1611
+ # Ensure that the combined word fits within the max_chars_per_line limit
1612
+ if len(combined_word) + len(next_word) + 1 <= max_chars_per_line: # +1 for space
1514
1613
  combined_word = f"{combined_word} {next_word}"
1515
1614
  i += 1 # Move past the combined word
1516
1615
  else:
1517
1616
  break # Stop combining if the length is exceeded
1518
1617
 
1519
- combined_lines.append(combined_word) # Add the combined word or single word
1520
- i += 1 # Move to the next word
1618
+ # Add the combined word only if it fits within the max_chars_per_line limit
1619
+ if len(combined_word) <= max_chars_per_line:
1620
+ combined_lines.append(combined_word) # Add the combined word
1621
+ # Move to the next word
1622
+ i += 1
1521
1623
 
1522
1624
  # Stop if we've reached the max_label_lines limit
1523
1625
  if len(combined_lines) >= max_label_lines:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: risk-network
3
- Version: 0.0.8b3
3
+ Version: 0.0.8b5
4
4
  Summary: A Python package for biological network analysis
5
5
  Author: Ira Horecka
6
6
  Author-email: Ira Horecka <ira89@icloud.com>
File without changes
File without changes
File without changes
File without changes