LZGraphs 2.0.0__tar.gz → 2.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/CHANGELOG.md +1 -1
  2. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/PKG-INFO +1 -1
  3. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/__init__.py +15 -1
  4. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/graphs/amino_acid_positional.py +33 -37
  5. lzgraphs-2.1.0/src/LZGraphs/graphs/edge_data.py +197 -0
  6. lzgraphs-2.1.0/src/LZGraphs/graphs/graph_operations.py +115 -0
  7. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/graphs/lz_graph_base.py +214 -67
  8. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/graphs/naive.py +5 -2
  9. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/graphs/nucleotide_double_positional.py +13 -29
  10. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/metrics/__init__.py +14 -0
  11. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/metrics/convenience.py +8 -0
  12. lzgraphs-2.1.0/src/LZGraphs/metrics/entropy.py +1007 -0
  13. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/mixins/gene_logic.py +8 -42
  14. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/mixins/gene_prediction.py +33 -85
  15. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/utilities/__init__.py +0 -4
  16. lzgraphs-2.1.0/src/LZGraphs/utilities/helpers.py +50 -0
  17. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/utilities/misc.py +1 -13
  18. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs.egg-info/PKG-INFO +1 -1
  19. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs.egg-info/SOURCES.txt +1 -0
  20. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/tests/test_metrics.py +204 -0
  21. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/tests/test_new_features.py +74 -0
  22. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/tests/test_pgen_fixes.py +3 -3
  23. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/tests/test_serialization.py +1 -1
  24. lzgraphs-2.0.0/src/LZGraphs/graphs/graph_operations.py +0 -164
  25. lzgraphs-2.0.0/src/LZGraphs/metrics/entropy.py +0 -504
  26. lzgraphs-2.0.0/src/LZGraphs/utilities/helpers.py +0 -105
  27. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/CONTRIBUTING.md +0 -0
  28. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/LICENSE +0 -0
  29. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/MANIFEST.in +0 -0
  30. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/README.md +0 -0
  31. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/pyproject.toml +0 -0
  32. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/requirements.txt +0 -0
  33. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/setup.cfg +0 -0
  34. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/bag_of_words/__init__.py +0 -0
  35. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/bag_of_words/bow_encoder.py +0 -0
  36. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/exceptions/__init__.py +0 -0
  37. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/graphs/__init__.py +0 -0
  38. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/metrics/diversity.py +0 -0
  39. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/metrics/saturation.py +0 -0
  40. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/mixins/__init__.py +0 -0
  41. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/mixins/random_walk.py +0 -0
  42. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/py.typed +0 -0
  43. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/utilities/decomposition.py +0 -0
  44. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/visualization/__init__.py +0 -0
  45. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs/visualization/visualize.py +0 -0
  46. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs.egg-info/dependency_links.txt +0 -0
  47. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs.egg-info/requires.txt +0 -0
  48. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/src/LZGraphs.egg-info/top_level.txt +0 -0
  49. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/tests/test_aap_lzgraph.py +0 -0
  50. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/tests/test_base_class_methods.py +0 -0
  51. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/tests/test_bow_encoder.py +0 -0
  52. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/tests/test_diversity_theory.py +0 -0
  53. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/tests/test_graph_operations.py +0 -0
  54. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/tests/test_naive_lzgraph.py +0 -0
  55. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/tests/test_ndp_lzgraph.py +0 -0
  56. {lzgraphs-2.0.0 → lzgraphs-2.1.0}/tests/test_utilities.py +0 -0
@@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
9
9
 
10
10
  ### Added
11
11
  - Custom exceptions module with comprehensive exception hierarchy for better error handling
12
- - Information-theoretic metrics module (`LZGraphs.Metrics.entropy`)
12
+ - Information-theoretic metrics module (`LZGraphs.metrics.entropy`)
13
13
  - `node_entropy()` - Shannon entropy of node probability distribution
14
14
  - `edge_entropy()` - Shannon entropy of edge transition probabilities
15
15
  - `graph_entropy()` - Combined graph entropy measure
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: LZGraphs
3
- Version: 2.0.0
3
+ Version: 2.1.0
4
4
  Summary: An Implementation of LZ76 Based Graphs for Repertoire Representation and Analysis
5
5
  Author-email: Thomas Konstantinovsky <thomaskon90@gmail.com>
6
6
  Maintainer-email: Thomas Konstantinovsky <thomaskon90@gmail.com>
@@ -1,4 +1,4 @@
1
- __version__ = "2.0.0"
1
+ __version__ = "2.1.0"
2
2
 
3
3
  # =============================================================================
4
4
  # Graph classes
@@ -44,6 +44,13 @@ from .metrics.entropy import (
44
44
  cross_entropy,
45
45
  kl_divergence,
46
46
  mutual_information_genes,
47
+ transition_predictability,
48
+ graph_compression_ratio,
49
+ repertoire_compressibility_index,
50
+ transition_kl_divergence,
51
+ transition_jsd,
52
+ transition_mutual_information_profile,
53
+ path_entropy_rate,
47
54
  )
48
55
 
49
56
  # =============================================================================
@@ -145,6 +152,13 @@ __all__ = [
145
152
  'cross_entropy',
146
153
  'kl_divergence',
147
154
  'mutual_information_genes',
155
+ 'transition_predictability',
156
+ 'graph_compression_ratio',
157
+ 'repertoire_compressibility_index',
158
+ 'transition_kl_divergence',
159
+ 'transition_jsd',
160
+ 'transition_mutual_information_profile',
161
+ 'path_entropy_rate',
148
162
  # Saturation
149
163
  'NodeEdgeSaturationProbe',
150
164
  # Convenience
@@ -144,11 +144,6 @@ class AAPLZGraph(LZGraphBase):
144
144
  self._normalize_edge_weights()
145
145
  self.verbose_driver(3, verbose)
146
146
 
147
- # If gene data is available, normalize gene weights in parallel
148
- if self.genetic:
149
- self._batch_gene_weight_normalization(verbose=verbose)
150
- self.verbose_driver(4, verbose)
151
-
152
147
  # Additional map derivations
153
148
  self.edges_list = None
154
149
  self._derive_terminal_state_map()
@@ -416,16 +411,16 @@ class AAPLZGraph(LZGraphBase):
416
411
  val = np.finfo(float).eps if use_epsilon else 0.0
417
412
  return (val, val)
418
413
 
419
- e_data = self.graph[step1][step2]
414
+ ed = self.graph[step1][step2]['data']
420
415
  # If these genes aren't on the edge, it's effectively 0
421
- if v not in e_data or j not in e_data:
416
+ if not ed.has_gene(v) or not ed.has_gene(j):
422
417
  if verbose:
423
418
  logger.warning(f"Edge {step1}->{step2} missing {v} or {j}.")
424
419
  val = np.finfo(float).eps if use_epsilon else 0.0
425
420
  return (val, val)
426
421
 
427
- proba_v *= e_data[v]
428
- proba_j *= e_data[j]
422
+ proba_v *= ed.v_probability(v)
423
+ proba_j *= ed.j_probability(j)
429
424
 
430
425
  return proba_v, proba_j
431
426
 
@@ -484,19 +479,18 @@ class AAPLZGraph(LZGraphBase):
484
479
  logger.warning(f"Current state {current_state} not in graph.")
485
480
  break
486
481
 
487
- edge_info = pd.DataFrame(dict(self.graph[current_state]))
482
+ # Get edges that have both selected V and J genes
483
+ edges = self.outgoing_edges(current_state)
488
484
  # Apply blacklist if present
489
485
  if (current_state, selected_v, selected_j) in self.genetic_walks_black_list:
490
486
  blacklisted = self.genetic_walks_black_list[(current_state, selected_v, selected_j)]
491
- edge_info = edge_info.drop(columns=blacklisted, errors="ignore")
487
+ edges = {nb: ed for nb, ed in edges.items() if nb not in blacklisted}
492
488
 
493
- # Check for presence of selected V/J genes
494
- # We'll consider edges that contain both selected_v and selected_j
495
- # in the attribute keys
496
- sub_df = edge_info.T[[selected_v, selected_j]].dropna(how="any") if \
497
- {selected_v, selected_j}.issubset(edge_info.index) else pd.DataFrame()
489
+ # Filter to edges containing both V and J genes
490
+ valid_edges = {nb: ed for nb, ed in edges.items()
491
+ if ed.has_gene(selected_v) and ed.has_gene(selected_j)}
498
492
 
499
- if sub_df.empty:
493
+ if not valid_edges:
500
494
  # No valid edges
501
495
  if len(walk) > 2:
502
496
  prev_state = walk[-2]
@@ -510,10 +504,11 @@ class AAPLZGraph(LZGraphBase):
510
504
  selected_v, selected_j = self._select_random_vj_genes(vj_init)
511
505
  continue
512
506
 
513
- # Weighted choice among these edges
514
- w = edge_info.loc["weight", sub_df.index]
515
- w /= w.sum()
516
- if w.empty:
507
+ # Weighted choice among valid edges
508
+ nbs = list(valid_edges.keys())
509
+ weights = np.array([valid_edges[nb].weight for nb in nbs])
510
+ w_sum = weights.sum()
511
+ if w_sum == 0:
517
512
  # Again, no valid edges
518
513
  if len(walk) > 2:
519
514
  prev_state = walk[-2]
@@ -527,7 +522,8 @@ class AAPLZGraph(LZGraphBase):
527
522
  selected_v, selected_j = self._select_random_vj_genes(vj_init)
528
523
  continue
529
524
 
530
- current_state = np.random.choice(w.index, p=w.values)
525
+ weights /= w_sum
526
+ current_state = np.random.choice(nbs, p=weights)
531
527
  walk.append(current_state)
532
528
 
533
529
  results.append((walk, selected_v, selected_j))
@@ -594,7 +590,8 @@ class AAPLZGraph(LZGraphBase):
594
590
 
595
591
  to_drop = []
596
592
  for src, dst, attrs in self.edges_list:
597
- if (v not in attrs) or (j not in attrs):
593
+ ed = attrs.get('data')
594
+ if ed is None or not (ed.has_gene(v) and ed.has_gene(j)):
598
595
  to_drop.append((src, dst))
599
596
 
600
597
  G = self.graph.copy()
@@ -632,15 +629,14 @@ class AAPLZGraph(LZGraphBase):
632
629
  self.genetic_walks_black_list = {}
633
630
 
634
631
  while current_state not in final_states:
635
- edge_info = pd.DataFrame(dict(G[current_state]))
632
+ # Get outgoing edges from the gene subgraph
633
+ edges = {nb: G[current_state][nb]['data'] for nb in G[current_state]}
636
634
  # Apply blacklist
637
635
  if (selected_v, selected_j, current_state) in self.genetic_walks_black_list:
638
- edge_info = edge_info.drop(
639
- columns=self.genetic_walks_black_list[(selected_v, selected_j, current_state)],
640
- errors="ignore"
641
- )
636
+ blacklisted = self.genetic_walks_black_list[(selected_v, selected_j, current_state)]
637
+ edges = {nb: ed for nb, ed in edges.items() if nb not in blacklisted}
642
638
 
643
- if edge_info.shape[1] == 0:
639
+ if not edges:
644
640
  if len(walk) > 1:
645
641
  prev_state = walk[-2]
646
642
  blacklisted_cols = self.genetic_walks_black_list.get((selected_v, selected_j, prev_state), [])
@@ -649,14 +645,13 @@ class AAPLZGraph(LZGraphBase):
649
645
  walk.pop()
650
646
  current_state = walk[-1]
651
647
  else:
652
- # Stuck at the start
653
648
  break
654
649
  continue
655
650
 
656
- sub_df = edge_info.T[[selected_v, selected_j]].dropna(how="any") if \
657
- {selected_v, selected_j}.issubset(edge_info.index) else pd.DataFrame()
658
- if sub_df.empty:
659
- # No valid edges — backtrack or break
651
+ # Filter to edges containing both V and J genes
652
+ valid_edges = {nb: ed for nb, ed in edges.items()
653
+ if ed.has_gene(selected_v) and ed.has_gene(selected_j)}
654
+ if not valid_edges:
660
655
  if len(walk) > 1:
661
656
  prev_state = walk[-2]
662
657
  blacklisted_cols = self.genetic_walks_black_list.get((selected_v, selected_j, prev_state), [])
@@ -668,9 +663,10 @@ class AAPLZGraph(LZGraphBase):
668
663
  break
669
664
  continue
670
665
 
671
- w = edge_info.loc["weight", sub_df.index]
672
- w /= w.sum()
673
- next_state = np.random.choice(w.index, p=w.values)
666
+ nbs = list(valid_edges.keys())
667
+ weights = np.array([valid_edges[nb].weight for nb in nbs])
668
+ weights /= weights.sum()
669
+ next_state = np.random.choice(nbs, p=weights)
674
670
  walk.append(next_state)
675
671
  current_state = next_state
676
672
 
@@ -0,0 +1,197 @@
1
+ """
2
+ EdgeData: Encapsulates all data for a single directed edge in an LZGraph.
3
+
4
+ Raw counts are the source of truth. Normalized probabilities are cached
5
+ after calling normalize() and are read-only.
6
+ """
7
+
8
+ from ..utilities.misc import _is_v_gene, _is_j_gene
9
+
10
+ __all__ = ["EdgeData"]
11
+
12
+
13
class EdgeData:
    """Store all data for a single directed edge in an LZGraph.

    Raw counts are the source of truth. The transition probability is
    cached by :meth:`normalize` and exposed read-only through
    :attr:`weight`; gene probabilities are always derived on demand
    from the raw gene counts.

    Attributes:
        count (int): Raw transition count (source of truth).
        v_genes (dict): {gene_name: raw_count} for V genes.
        j_genes (dict): {gene_name: raw_count} for J genes.
    """
    __slots__ = ('count', '_weight', 'v_genes', 'j_genes')

    # Defining __eq__ already makes instances unhashable; state it explicitly
    # so the intent (mutable value object) is visible to readers.
    __hash__ = None

    def __init__(self):
        self.count = 0
        self._weight = 0.0
        self.v_genes = {}
        self.j_genes = {}

    @property
    def weight(self):
        """Cached transition probability P(B|A), set by normalize()."""
        return self._weight

    @property
    def vsum(self):
        """Total count of V gene observations on this edge."""
        return sum(self.v_genes.values())

    @property
    def jsum(self):
        """Total count of J gene observations on this edge."""
        return sum(self.j_genes.values())

    @property
    def is_genetic(self):
        """Whether this edge has any gene data."""
        return bool(self.v_genes or self.j_genes)

    def record(self, v_gene=None, j_gene=None):
        """Record one traversal during graph construction.

        The transition count is always incremented; gene counters are
        only touched for the genes that are supplied.

        Args:
            v_gene (str, optional): V gene to record.
            j_gene (str, optional): J gene to record.
        """
        self.count += 1
        if v_gene is not None:
            self.v_genes[v_gene] = self.v_genes.get(v_gene, 0) + 1
        if j_gene is not None:
            self.j_genes[j_gene] = self.j_genes.get(j_gene, 0) + 1

    def unrecord(self, v_gene=None, j_gene=None):
        """Remove one traversal (for sequence removal).

        The transition count never drops below zero. A gene whose
        counter reaches zero is removed from its dict, so ``has_gene``
        stops reporting it. Genes not present on the edge are ignored.

        Args:
            v_gene (str, optional): V gene to decrement.
            j_gene (str, optional): J gene to decrement.
        """
        self.count = max(0, self.count - 1)
        if v_gene is not None and v_gene in self.v_genes:
            self.v_genes[v_gene] -= 1
            if self.v_genes[v_gene] <= 0:
                del self.v_genes[v_gene]
        if j_gene is not None and j_gene in self.j_genes:
            self.j_genes[j_gene] -= 1
            if self.j_genes[j_gene] <= 0:
                del self.j_genes[j_gene]

    def merge(self, other):
        """Merge another EdgeData into this one (for graph union).

        Only raw counts are added; the cached weight is left untouched
        and must be refreshed with :meth:`normalize` afterwards.

        Args:
            other (EdgeData): The edge data to merge in.
        """
        self.count += other.count
        for gene, n in other.v_genes.items():
            self.v_genes[gene] = self.v_genes.get(gene, 0) + n
        for gene, n in other.j_genes.items():
            self.j_genes[gene] = self.j_genes.get(gene, 0) + n

    def normalize(self, node_frequency, alpha=0.0, n_successors=0):
        """Compute and cache transition probability from raw count.

        With ``alpha > 0`` the weight is
        ``(count + alpha) / (node_frequency + alpha * n_successors)``;
        otherwise it is ``count / node_frequency``. A non-positive
        denominator yields a weight of ``0.0``.

        Args:
            node_frequency (int): Total outgoing count from source node.
            alpha (float): Laplace smoothing parameter.
            n_successors (int): Number of successors (for Laplace smoothing).
        """
        if alpha > 0:
            denom = node_frequency + alpha * n_successors
            self._weight = (self.count + alpha) / denom if denom > 0 else 0.0
        elif node_frequency > 0:
            self._weight = self.count / node_frequency
        else:
            self._weight = 0.0

    def v_probability(self, gene):
        """Return P(gene) among V genes on this edge (0.0 if none recorded)."""
        vsum = self.vsum
        return self.v_genes.get(gene, 0) / vsum if vsum > 0 else 0.0

    def j_probability(self, gene):
        """Return P(gene) among J genes on this edge (0.0 if none recorded)."""
        jsum = self.jsum
        return self.j_genes.get(gene, 0) / jsum if jsum > 0 else 0.0

    def has_gene(self, gene):
        """Check if a gene (V or J) is present on this edge."""
        return gene in self.v_genes or gene in self.j_genes

    def gene_dict(self):
        """Return {gene: probability} dict for all genes on this edge."""
        result = {}
        vsum, jsum = self.vsum, self.jsum
        for g, c in self.v_genes.items():
            result[g] = c / vsum if vsum > 0 else 0.0
        for g, c in self.j_genes.items():
            result[g] = c / jsum if jsum > 0 else 0.0
        return result

    def to_legacy_dict(self):
        """Convert to flat dict matching old edge attribute format.

        ``Vsum``/``Jsum`` and the per-gene probabilities are only
        emitted for the gene families that have data on this edge.

        Returns:
            dict: {weight, count, Vsum, Jsum, gene_name: probability, ...}
        """
        d = {'weight': self._weight, 'count': self.count}
        if self.v_genes:
            # Sum once: calling v_probability() per gene would re-sum the
            # whole dict on every iteration.
            vsum = self.vsum
            d['Vsum'] = vsum
            for g, c in self.v_genes.items():
                d[g] = c / vsum if vsum > 0 else 0.0
        if self.j_genes:
            jsum = self.jsum
            d['Jsum'] = jsum
            for g, c in self.j_genes.items():
                d[g] = c / jsum if jsum > 0 else 0.0
        return d

    @classmethod
    def from_legacy_dict(cls, d, node_frequency=0):
        """Reconstruct EdgeData from an old-format flat dict.

        Used for loading old saves where edge data was stored as
        {weight, Vsum, Jsum, gene_name: probability, ...}. Raw counts
        are recovered by rounding ``probability * total``, so they are
        approximations of the original counts.

        Args:
            d (dict): Old-format edge attribute dictionary.
            node_frequency (int): Per-node observed frequency for count recovery.

        Returns:
            EdgeData: Reconstructed edge data.
        """
        edge = cls()
        edge._weight = d.get('weight', 0.0)
        edge.count = d.get('count', 0)
        # Old saves may lack 'count'; estimate it from the stored weight.
        if edge.count == 0 and node_frequency > 0:
            edge.count = int(round(edge._weight * node_frequency))

        vsum = d.get('Vsum', 0)
        jsum = d.get('Jsum', 0)
        for key, val in d.items():
            if key in ('weight', 'count', 'Vsum', 'Jsum'):
                continue
            if _is_v_gene(key) and vsum > 0:
                edge.v_genes[key] = int(round(val * vsum))
            elif _is_j_gene(key) and jsum > 0:
                edge.j_genes[key] = int(round(val * jsum))
        return edge

    def __getstate__(self):
        # Explicit state tuple: slotted classes have no __dict__ to pickle.
        return (self.count, self._weight, self.v_genes, self.j_genes)

    def __setstate__(self, state):
        self.count, self._weight, self.v_genes, self.j_genes = state

    def __eq__(self, other):
        # Equality compares raw counts only; the cached weight is derived
        # state and is intentionally not part of the comparison.
        if not isinstance(other, EdgeData):
            return NotImplemented
        return (self.count == other.count
                and self.v_genes == other.v_genes
                and self.j_genes == other.j_genes)

    def __repr__(self):
        return (f"EdgeData(count={self.count}, weight={self._weight:.4f}, "
                f"v={len(self.v_genes)}, j={len(self.j_genes)})")
@@ -0,0 +1,115 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ from .edge_data import EdgeData
5
+ from ..exceptions import IncompatibleGraphsError
6
+
7
+
8
+ __all__ = ['graph_union']
9
+
10
+
11
def graph_union(graphA, graphB):
    """Perform a union operation between two graphs.

    graphA will be updated in-place to be the equivalent of the union
    of both. The result is logically equal to constructing a graph from
    the combined sequences of two separate repertoires.

    Since EdgeData stores raw counts as the source of truth, the union
    simply merges counts and then recalculates all derived probabilities.

    Args:
        graphA (LZGraph): An LZGraph (will be modified in-place).
        graphB (LZGraph): An LZGraph of the same class as graphA.

    Returns:
        LZGraph: graphA, updated with the union of both graphs.

    Raises:
        IncompatibleGraphsError: If neither graph is an instance of the
            other graph's class.
    """
    if not isinstance(graphA, type(graphB)) and not isinstance(graphB, type(graphA)):
        raise IncompatibleGraphsError(
            type1=type(graphA).__name__,
            type2=type(graphB).__name__,
            message="Both graphs must be of the same type for union operation."
        )

    # Per-graph sequence totals used to weight the gene distributions in
    # step 3. They must be read BEFORE step 2 merges initial_states;
    # otherwise graphA's total would already include graphB's sequences
    # and graphB would be counted twice in the denominator.
    n_a = graphA.initial_states.sum()
    n_b = graphB.initial_states.sum()
    n_total = n_a + n_b

    def _add(x, y):
        return x + y

    def _weighted_mean(x, y):
        return x * n_a / n_total + y * n_b / n_total

    # 1. Merge edges (raw counts)
    for a, b in graphB.graph.edges:
        ed_b = graphB.graph[a][b]['data']
        if graphA.graph.has_edge(a, b):
            graphA.graph[a][b]['data'].merge(ed_b)
        else:
            # Ensure nodes exist
            if a not in graphA.graph:
                graphA.graph.add_node(a)
            if b not in graphA.graph:
                graphA.graph.add_node(b)
            # Merge into a fresh EdgeData so graphA never shares the
            # mutable edge object (or its gene dicts) with graphB.
            ed_new = EdgeData()
            ed_new.merge(ed_b)
            graphA.graph.add_edge(a, b, data=ed_new)

    # Also add any nodes from B that have no edges
    for node in graphB.graph.nodes:
        if node not in graphA.graph:
            graphA.graph.add_node(node)

    # 2. Merge sequence-level counts
    # (per_node_observed_frequency is recomputed in recalculate())
    graphA.initial_states = graphA.initial_states.combine(
        graphB.initial_states, _add, fill_value=0
    )
    graphA.terminal_states = graphA.terminal_states.combine(
        graphB.terminal_states, _add, fill_value=0
    )
    graphA.n_subpatterns += graphB.n_subpatterns
    graphA.n_transitions += graphB.n_transitions

    # Merge lengths
    if hasattr(graphB, 'lengths'):
        for length, count in graphB.lengths.items():
            graphA.lengths[length] = graphA.lengths.get(length, 0) + count

    # 3. Merge gene-level data (if genetic)
    if graphA.genetic and graphB.genetic:
        # Weighted average of marginal gene distributions, weighted by
        # each graph's own (pre-merge) number of sequences.
        if n_total > 0:
            graphA.marginal_vgenes = graphA.marginal_vgenes.combine(
                graphB.marginal_vgenes, _weighted_mean, fill_value=0
            )
            graphA.marginal_jgenes = graphA.marginal_jgenes.combine(
                graphB.marginal_jgenes, _weighted_mean, fill_value=0
            )
            graphA.vj_probabilities = graphA.vj_probabilities.combine(
                graphB.vj_probabilities, _weighted_mean, fill_value=0
            )

        # Merge length_distribution counts
        if hasattr(graphA, 'length_distribution') and hasattr(graphB, 'length_distribution'):
            graphA.length_distribution = graphA.length_distribution.combine(
                graphB.length_distribution, _add, fill_value=0
            )

        # Merge observed gene sets
        if hasattr(graphB, 'observed_vgenes'):
            graphA.observed_vgenes = list(
                set(graphA.observed_vgenes) | set(graphB.observed_vgenes)
            )
        if hasattr(graphB, 'observed_jgenes'):
            graphA.observed_jgenes = list(
                set(graphA.observed_jgenes) | set(graphB.observed_jgenes)
            )

    # 4. Recalculate ALL derived state from raw counts
    graphA.recalculate()

    # Clear cached edges list
    if hasattr(graphA, 'edges_list'):
        graphA.edges_list = None

    return graphA