LZGraphs 2.0.0__tar.gz → 2.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/CHANGELOG.md +1 -1
  2. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/PKG-INFO +1 -1
  3. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/__init__.py +23 -1
  4. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/graphs/amino_acid_positional.py +37 -45
  5. lzgraphs-2.1.1/src/LZGraphs/graphs/edge_data.py +197 -0
  6. lzgraphs-2.1.1/src/LZGraphs/graphs/graph_operations.py +115 -0
  7. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/graphs/lz_graph_base.py +861 -139
  8. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/graphs/naive.py +6 -5
  9. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/graphs/nucleotide_double_positional.py +18 -36
  10. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/metrics/__init__.py +22 -0
  11. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/metrics/convenience.py +8 -0
  12. lzgraphs-2.1.1/src/LZGraphs/metrics/entropy.py +1007 -0
  13. lzgraphs-2.1.1/src/LZGraphs/metrics/pgen_distribution.py +351 -0
  14. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/mixins/gene_logic.py +8 -42
  15. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/mixins/gene_prediction.py +33 -85
  16. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/mixins/random_walk.py +8 -11
  17. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/utilities/__init__.py +0 -4
  18. lzgraphs-2.1.1/src/LZGraphs/utilities/helpers.py +50 -0
  19. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/utilities/misc.py +1 -13
  20. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs.egg-info/PKG-INFO +1 -1
  21. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs.egg-info/SOURCES.txt +5 -0
  22. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/tests/test_aap_lzgraph.py +15 -24
  23. lzgraphs-2.1.1/tests/test_analytical_distribution.py +481 -0
  24. lzgraphs-2.1.1/tests/test_lzpgen_distribution.py +262 -0
  25. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/tests/test_metrics.py +204 -0
  26. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/tests/test_naive_lzgraph.py +14 -16
  27. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/tests/test_ndp_lzgraph.py +14 -22
  28. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/tests/test_new_features.py +74 -0
  29. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/tests/test_pgen_fixes.py +3 -3
  30. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/tests/test_serialization.py +1 -1
  31. lzgraphs-2.1.1/tests/test_simulate.py +225 -0
  32. lzgraphs-2.0.0/src/LZGraphs/graphs/graph_operations.py +0 -164
  33. lzgraphs-2.0.0/src/LZGraphs/metrics/entropy.py +0 -504
  34. lzgraphs-2.0.0/src/LZGraphs/utilities/helpers.py +0 -105
  35. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/CONTRIBUTING.md +0 -0
  36. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/LICENSE +0 -0
  37. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/MANIFEST.in +0 -0
  38. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/README.md +0 -0
  39. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/pyproject.toml +0 -0
  40. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/requirements.txt +0 -0
  41. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/setup.cfg +0 -0
  42. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/bag_of_words/__init__.py +0 -0
  43. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/bag_of_words/bow_encoder.py +0 -0
  44. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/exceptions/__init__.py +0 -0
  45. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/graphs/__init__.py +0 -0
  46. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/metrics/diversity.py +0 -0
  47. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/metrics/saturation.py +0 -0
  48. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/mixins/__init__.py +0 -0
  49. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/py.typed +0 -0
  50. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/utilities/decomposition.py +0 -0
  51. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/visualization/__init__.py +0 -0
  52. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs/visualization/visualize.py +0 -0
  53. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs.egg-info/dependency_links.txt +0 -0
  54. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs.egg-info/requires.txt +0 -0
  55. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/src/LZGraphs.egg-info/top_level.txt +0 -0
  56. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/tests/test_base_class_methods.py +0 -0
  57. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/tests/test_bow_encoder.py +0 -0
  58. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/tests/test_diversity_theory.py +0 -0
  59. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/tests/test_graph_operations.py +0 -0
  60. {lzgraphs-2.0.0 → lzgraphs-2.1.1}/tests/test_utilities.py +0 -0
@@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
9
9
 
10
10
  ### Added
11
11
  - Custom exceptions module with comprehensive exception hierarchy for better error handling
12
- - Information-theoretic metrics module (`LZGraphs.Metrics.entropy`)
12
+ - Information-theoretic metrics module (`LZGraphs.metrics.entropy`)
13
13
  - `node_entropy()` - Shannon entropy of node probability distribution
14
14
  - `edge_entropy()` - Shannon entropy of edge transition probabilities
15
15
  - `graph_entropy()` - Combined graph entropy measure
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: LZGraphs
3
- Version: 2.0.0
3
+ Version: 2.1.1
4
4
  Summary: An Implementation of LZ76 Based Graphs for Repertoire Representation and Analysis
5
5
  Author-email: Thomas Konstantinovsky <thomaskon90@gmail.com>
6
6
  Maintainer-email: Thomas Konstantinovsky <thomaskon90@gmail.com>
@@ -1,4 +1,4 @@
1
- __version__ = "2.0.0"
1
+ __version__ = "2.1.1"
2
2
 
3
3
  # =============================================================================
4
4
  # Graph classes
@@ -44,6 +44,13 @@ from .metrics.entropy import (
44
44
  cross_entropy,
45
45
  kl_divergence,
46
46
  mutual_information_genes,
47
+ transition_predictability,
48
+ graph_compression_ratio,
49
+ repertoire_compressibility_index,
50
+ transition_kl_divergence,
51
+ transition_jsd,
52
+ transition_mutual_information_profile,
53
+ path_entropy_rate,
47
54
  )
48
55
 
49
56
  # =============================================================================
@@ -56,6 +63,11 @@ from .metrics.saturation import NodeEdgeSaturationProbe
56
63
  # =============================================================================
57
64
  from .metrics.convenience import compare_repertoires
58
65
 
66
+ # =============================================================================
67
+ # Metrics - PGen Distribution
68
+ # =============================================================================
69
+ from .metrics.pgen_distribution import LZPgenDistribution, compare_lzpgen_distributions
70
+
59
71
  # =============================================================================
60
72
  # Utilities
61
73
  # =============================================================================
@@ -145,10 +157,20 @@ __all__ = [
145
157
  'cross_entropy',
146
158
  'kl_divergence',
147
159
  'mutual_information_genes',
160
+ 'transition_predictability',
161
+ 'graph_compression_ratio',
162
+ 'repertoire_compressibility_index',
163
+ 'transition_kl_divergence',
164
+ 'transition_jsd',
165
+ 'transition_mutual_information_profile',
166
+ 'path_entropy_rate',
148
167
  # Saturation
149
168
  'NodeEdgeSaturationProbe',
150
169
  # Convenience
151
170
  'compare_repertoires',
171
+ # PGen distribution
172
+ 'LZPgenDistribution',
173
+ 'compare_lzpgen_distributions',
152
174
  # Utilities
153
175
  'generate_kmer_dictionary',
154
176
  'lempel_ziv_decomposition',
@@ -1,5 +1,4 @@
1
1
  import logging
2
- import re
3
2
  import time
4
3
  from typing import List, Tuple, Union, Optional, Generator
5
4
 
@@ -10,7 +9,7 @@ from tqdm.auto import tqdm
10
9
 
11
10
  from .lz_graph_base import LZGraphBase
12
11
  from ..utilities.decomposition import lempel_ziv_decomposition
13
- from ..utilities.misc import window, choice
12
+ from ..utilities.misc import window
14
13
  from ..exceptions import (
15
14
  EmptyDataError,
16
15
  MissingColumnError,
@@ -144,18 +143,10 @@ class AAPLZGraph(LZGraphBase):
144
143
  self._normalize_edge_weights()
145
144
  self.verbose_driver(3, verbose)
146
145
 
147
- # If gene data is available, normalize gene weights in parallel
148
- if self.genetic:
149
- self._batch_gene_weight_normalization(verbose=verbose)
150
- self.verbose_driver(4, verbose)
151
-
152
146
  # Additional map derivations
153
147
  self.edges_list = None
154
- self._derive_terminal_state_map()
155
- self.verbose_driver(7, verbose)
156
148
  self._derive_stop_probability_data()
157
- self.verbose_driver(8, verbose)
158
- self.verbose_driver(5, verbose)
149
+ self.verbose_driver(9, verbose)
159
150
 
160
151
  # Optionally compute the PGEN for each sequence
161
152
  if calculate_trainset_pgen:
@@ -306,8 +297,8 @@ class AAPLZGraph(LZGraphBase):
306
297
  """
307
298
  Given a sub-pattern that might look like "ABC_10", extract only the amino acids ("ABC").
308
299
  """
309
- match = re.search(r'[A-Z]+', base)
310
- return match.group(0) if match else ""
300
+ idx = base.rfind('_')
301
+ return base[:idx] if idx > 0 else base
311
302
 
312
303
  def _decomposed_sequence_generator(
313
304
  self,
@@ -416,16 +407,16 @@ class AAPLZGraph(LZGraphBase):
416
407
  val = np.finfo(float).eps if use_epsilon else 0.0
417
408
  return (val, val)
418
409
 
419
- e_data = self.graph[step1][step2]
410
+ ed = self.graph[step1][step2]['data']
420
411
  # If these genes aren't on the edge, it's effectively 0
421
- if v not in e_data or j not in e_data:
412
+ if not ed.has_gene(v) or not ed.has_gene(j):
422
413
  if verbose:
423
414
  logger.warning(f"Edge {step1}->{step2} missing {v} or {j}.")
424
415
  val = np.finfo(float).eps if use_epsilon else 0.0
425
416
  return (val, val)
426
417
 
427
- proba_v *= e_data[v]
428
- proba_j *= e_data[j]
418
+ proba_v *= ed.v_probability(v)
419
+ proba_j *= ed.j_probability(j)
429
420
 
430
421
  return proba_v, proba_j
431
422
 
@@ -484,19 +475,18 @@ class AAPLZGraph(LZGraphBase):
484
475
  logger.warning(f"Current state {current_state} not in graph.")
485
476
  break
486
477
 
487
- edge_info = pd.DataFrame(dict(self.graph[current_state]))
478
+ # Get edges that have both selected V and J genes
479
+ edges = self.outgoing_edges(current_state)
488
480
  # Apply blacklist if present
489
481
  if (current_state, selected_v, selected_j) in self.genetic_walks_black_list:
490
482
  blacklisted = self.genetic_walks_black_list[(current_state, selected_v, selected_j)]
491
- edge_info = edge_info.drop(columns=blacklisted, errors="ignore")
483
+ edges = {nb: ed for nb, ed in edges.items() if nb not in blacklisted}
492
484
 
493
- # Check for presence of selected V/J genes
494
- # We'll consider edges that contain both selected_v and selected_j
495
- # in the attribute keys
496
- sub_df = edge_info.T[[selected_v, selected_j]].dropna(how="any") if \
497
- {selected_v, selected_j}.issubset(edge_info.index) else pd.DataFrame()
485
+ # Filter to edges containing both V and J genes
486
+ valid_edges = {nb: ed for nb, ed in edges.items()
487
+ if ed.has_gene(selected_v) and ed.has_gene(selected_j)}
498
488
 
499
- if sub_df.empty:
489
+ if not valid_edges:
500
490
  # No valid edges
501
491
  if len(walk) > 2:
502
492
  prev_state = walk[-2]
@@ -510,10 +500,11 @@ class AAPLZGraph(LZGraphBase):
510
500
  selected_v, selected_j = self._select_random_vj_genes(vj_init)
511
501
  continue
512
502
 
513
- # Weighted choice among these edges
514
- w = edge_info.loc["weight", sub_df.index]
515
- w /= w.sum()
516
- if w.empty:
503
+ # Weighted choice among valid edges
504
+ nbs = list(valid_edges.keys())
505
+ weights = np.array([valid_edges[nb].weight for nb in nbs])
506
+ w_sum = weights.sum()
507
+ if w_sum == 0:
517
508
  # Again, no valid edges
518
509
  if len(walk) > 2:
519
510
  prev_state = walk[-2]
@@ -527,7 +518,8 @@ class AAPLZGraph(LZGraphBase):
527
518
  selected_v, selected_j = self._select_random_vj_genes(vj_init)
528
519
  continue
529
520
 
530
- current_state = np.random.choice(w.index, p=w.values)
521
+ weights /= w_sum
522
+ current_state = np.random.choice(nbs, p=weights)
531
523
  walk.append(current_state)
532
524
 
533
525
  results.append((walk, selected_v, selected_j))
@@ -594,7 +586,8 @@ class AAPLZGraph(LZGraphBase):
594
586
 
595
587
  to_drop = []
596
588
  for src, dst, attrs in self.edges_list:
597
- if (v not in attrs) or (j not in attrs):
589
+ ed = attrs.get('data')
590
+ if ed is None or not (ed.has_gene(v) and ed.has_gene(j)):
598
591
  to_drop.append((src, dst))
599
592
 
600
593
  G = self.graph.copy()
@@ -632,15 +625,14 @@ class AAPLZGraph(LZGraphBase):
632
625
  self.genetic_walks_black_list = {}
633
626
 
634
627
  while current_state not in final_states:
635
- edge_info = pd.DataFrame(dict(G[current_state]))
628
+ # Get outgoing edges from the gene subgraph
629
+ edges = {nb: G[current_state][nb]['data'] for nb in G[current_state]}
636
630
  # Apply blacklist
637
631
  if (selected_v, selected_j, current_state) in self.genetic_walks_black_list:
638
- edge_info = edge_info.drop(
639
- columns=self.genetic_walks_black_list[(selected_v, selected_j, current_state)],
640
- errors="ignore"
641
- )
632
+ blacklisted = self.genetic_walks_black_list[(selected_v, selected_j, current_state)]
633
+ edges = {nb: ed for nb, ed in edges.items() if nb not in blacklisted}
642
634
 
643
- if edge_info.shape[1] == 0:
635
+ if not edges:
644
636
  if len(walk) > 1:
645
637
  prev_state = walk[-2]
646
638
  blacklisted_cols = self.genetic_walks_black_list.get((selected_v, selected_j, prev_state), [])
@@ -649,14 +641,13 @@ class AAPLZGraph(LZGraphBase):
649
641
  walk.pop()
650
642
  current_state = walk[-1]
651
643
  else:
652
- # Stuck at the start
653
644
  break
654
645
  continue
655
646
 
656
- sub_df = edge_info.T[[selected_v, selected_j]].dropna(how="any") if \
657
- {selected_v, selected_j}.issubset(edge_info.index) else pd.DataFrame()
658
- if sub_df.empty:
659
- # No valid edges — backtrack or break
647
+ # Filter to edges containing both V and J genes
648
+ valid_edges = {nb: ed for nb, ed in edges.items()
649
+ if ed.has_gene(selected_v) and ed.has_gene(selected_j)}
650
+ if not valid_edges:
660
651
  if len(walk) > 1:
661
652
  prev_state = walk[-2]
662
653
  blacklisted_cols = self.genetic_walks_black_list.get((selected_v, selected_j, prev_state), [])
@@ -668,9 +659,10 @@ class AAPLZGraph(LZGraphBase):
668
659
  break
669
660
  continue
670
661
 
671
- w = edge_info.loc["weight", sub_df.index]
672
- w /= w.sum()
673
- next_state = np.random.choice(w.index, p=w.values)
662
+ nbs = list(valid_edges.keys())
663
+ weights = np.array([valid_edges[nb].weight for nb in nbs])
664
+ weights /= weights.sum()
665
+ next_state = np.random.choice(nbs, p=weights)
674
666
  walk.append(next_state)
675
667
  current_state = next_state
676
668
 
@@ -0,0 +1,197 @@
1
+ """
2
+ EdgeData: Encapsulates all data for a single directed edge in an LZGraph.
3
+
4
+ Raw counts are the source of truth. Normalized probabilities are cached
5
+ after calling normalize() and are read-only.
6
+ """
7
+
8
+ from ..utilities.misc import _is_v_gene, _is_j_gene
9
+
10
+ __all__ = ["EdgeData"]
11
+
12
+
13
class EdgeData:
    """Stores all data for a single directed edge in an LZGraph.

    Raw counts are the source of truth. The normalized transition
    probability is cached by ``normalize()`` and exposed read-only via
    the ``weight`` property.

    Attributes:
        count (int): Raw transition count (source of truth).
        v_genes (dict): {gene_name: raw_count} for V genes.
        j_genes (dict): {gene_name: raw_count} for J genes.

    Equality compares raw counts only (``count``, ``v_genes``, ``j_genes``);
    the cached weight is derived state and is ignored. Because ``__eq__`` is
    overridden and instances are mutable, they are unhashable.
    """
    __slots__ = ('count', '_weight', 'v_genes', 'j_genes')

    # Overriding __eq__ already sets __hash__ to None implicitly; spell it
    # out so the unhashability of this mutable record is clearly intended.
    __hash__ = None

    def __init__(self):
        self.count = 0
        self._weight = 0.0
        self.v_genes = {}
        self.j_genes = {}

    @property
    def weight(self):
        """Cached transition probability P(B|A), set by normalize()."""
        return self._weight

    @property
    def vsum(self):
        """Total count of V gene observations on this edge."""
        return sum(self.v_genes.values())

    @property
    def jsum(self):
        """Total count of J gene observations on this edge."""
        return sum(self.j_genes.values())

    @property
    def is_genetic(self):
        """Whether this edge has any gene data."""
        return bool(self.v_genes or self.j_genes)

    @staticmethod
    def _share(counts, total, gene):
        """Return counts[gene] / total, or 0.0 when total is not positive."""
        return counts.get(gene, 0) / total if total > 0 else 0.0

    def record(self, v_gene=None, j_gene=None):
        """Record one traversal during graph construction.

        Args:
            v_gene (str, optional): V gene to record.
            j_gene (str, optional): J gene to record.
        """
        self.count += 1
        if v_gene is not None:
            self.v_genes[v_gene] = self.v_genes.get(v_gene, 0) + 1
        if j_gene is not None:
            self.j_genes[j_gene] = self.j_genes.get(j_gene, 0) + 1

    def unrecord(self, v_gene=None, j_gene=None):
        """Remove one traversal (for sequence removal).

        The transition count never drops below zero, and a gene whose count
        reaches zero is removed from its dict entirely. The cached weight is
        not touched; call normalize() again to refresh it.

        Args:
            v_gene (str, optional): V gene to decrement.
            j_gene (str, optional): J gene to decrement.
        """
        self.count = max(0, self.count - 1)
        if v_gene is not None and v_gene in self.v_genes:
            self.v_genes[v_gene] -= 1
            if self.v_genes[v_gene] <= 0:
                del self.v_genes[v_gene]
        if j_gene is not None and j_gene in self.j_genes:
            self.j_genes[j_gene] -= 1
            if self.j_genes[j_gene] <= 0:
                del self.j_genes[j_gene]

    def merge(self, other):
        """Merge another EdgeData's raw counts into this one (for graph union).

        Only raw counts are merged; the cached weight must be refreshed with
        normalize() afterwards.

        Args:
            other (EdgeData): The edge data to merge in.
        """
        self.count += other.count
        for g, c in other.v_genes.items():
            self.v_genes[g] = self.v_genes.get(g, 0) + c
        for g, c in other.j_genes.items():
            self.j_genes[g] = self.j_genes.get(g, 0) + c

    def normalize(self, node_frequency, alpha=0.0, n_successors=0):
        """Compute and cache the transition probability from the raw count.

        With ``alpha > 0`` Laplace smoothing is applied:
        ``(count + alpha) / (node_frequency + alpha * n_successors)``.
        Otherwise the weight is ``count / node_frequency``. In both cases a
        non-positive denominator yields 0.0.

        Args:
            node_frequency (int): Total outgoing count from source node.
            alpha (float): Laplace smoothing parameter.
            n_successors (int): Number of successors (for Laplace smoothing).
        """
        if alpha > 0:
            denom = node_frequency + alpha * n_successors
            self._weight = (self.count + alpha) / denom if denom > 0 else 0.0
        elif node_frequency > 0:
            self._weight = self.count / node_frequency
        else:
            self._weight = 0.0

    def v_probability(self, gene):
        """Return P(gene) among V genes on this edge (0.0 if no V data)."""
        return self._share(self.v_genes, self.vsum, gene)

    def j_probability(self, gene):
        """Return P(gene) among J genes on this edge (0.0 if no J data)."""
        return self._share(self.j_genes, self.jsum, gene)

    def has_gene(self, gene):
        """Check if a gene (V or J) is present on this edge."""
        return gene in self.v_genes or gene in self.j_genes

    def gene_dict(self):
        """Return {gene: probability} dict for all genes on this edge.

        If a name appears among both V and J genes, the J probability wins.
        """
        vsum, jsum = self.vsum, self.jsum
        result = {g: self._share(self.v_genes, vsum, g) for g in self.v_genes}
        result.update(
            {g: self._share(self.j_genes, jsum, g) for g in self.j_genes}
        )
        return result

    def to_legacy_dict(self):
        """Convert to flat dict matching old edge attribute format.

        Returns:
            dict: {weight, count, Vsum, Jsum, gene_name: probability, ...}.
            'Vsum'/'Jsum' and the gene entries are only present when the
            corresponding gene dict is non-empty.
        """
        d = {'weight': self._weight, 'count': self.count}
        if self.v_genes:
            # Sum once: calling v_probability() per gene would re-sum all
            # gene counts each time (quadratic in the number of genes).
            vsum = self.vsum
            d['Vsum'] = vsum
            for g in self.v_genes:
                d[g] = self._share(self.v_genes, vsum, g)
        if self.j_genes:
            jsum = self.jsum
            d['Jsum'] = jsum
            for g in self.j_genes:
                d[g] = self._share(self.j_genes, jsum, g)
        return d

    @classmethod
    def from_legacy_dict(cls, d, node_frequency=0):
        """Reconstruct EdgeData from an old-format flat dict.

        Used for loading old saves where edge data was stored as
        {weight, Vsum, Jsum, gene_name: probability, ...}. Gene counts are
        recovered as round(probability * Vsum/Jsum). Keys recognised as
        neither V nor J genes, and gene keys whose sum is missing or zero,
        are dropped.

        Args:
            d (dict): Old-format edge attribute dictionary.
            node_frequency (int): Per-node observed frequency, used to
                recover the count when ``d`` has no (or a zero) 'count'.

        Returns:
            EdgeData: Reconstructed edge data.
        """
        edge = cls()
        edge._weight = d.get('weight', 0.0)
        edge.count = d.get('count', 0)
        if edge.count == 0 and node_frequency > 0:
            edge.count = int(round(edge._weight * node_frequency))

        vsum = d.get('Vsum', 0)
        jsum = d.get('Jsum', 0)
        for key, val in d.items():
            if key in ('weight', 'count', 'Vsum', 'Jsum'):
                continue
            if _is_v_gene(key) and vsum > 0:
                edge.v_genes[key] = int(round(val * vsum))
            elif _is_j_gene(key) and jsum > 0:
                edge.j_genes[key] = int(round(val * jsum))
        return edge

    # Explicit pickle state: a tuple of every slot value, in __slots__ order.
    def __getstate__(self):
        return (self.count, self._weight, self.v_genes, self.j_genes)

    def __setstate__(self, state):
        self.count, self._weight, self.v_genes, self.j_genes = state

    def __eq__(self, other):
        if not isinstance(other, EdgeData):
            return NotImplemented
        return (self.count == other.count
                and self.v_genes == other.v_genes
                and self.j_genes == other.j_genes)

    def __repr__(self):
        return (f"EdgeData(count={self.count}, weight={self._weight:.4f}, "
                f"v={len(self.v_genes)}, j={len(self.j_genes)})")
@@ -0,0 +1,115 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ from .edge_data import EdgeData
5
+ from ..exceptions import IncompatibleGraphsError
6
+
7
+
8
+ __all__ = ['graph_union']
9
+
10
+
11
def graph_union(graphA, graphB):
    """Perform a union operation between two graphs.

    graphA is updated in-place to be the equivalent of the union of both.
    The result is logically equal to constructing a graph from the combined
    sequences of two separate repertoires.

    Since EdgeData stores raw counts as the source of truth, the union merges
    counts and then recalculates all derived probabilities via
    ``graphA.recalculate()``.

    Args:
        graphA (LZGraph): An LZGraph (will be modified in-place).
        graphB (LZGraph): An LZGraph of the same class as graphA (one may be
            a subclass of the other). It is only read, never modified.

    Returns:
        LZGraph: graphA, updated with the union of both graphs.

    Raises:
        IncompatibleGraphsError: If neither graph is an instance of the
            other's type.
    """
    if not isinstance(graphA, type(graphB)) and not isinstance(graphB, type(graphA)):
        raise IncompatibleGraphsError(
            type1=type(graphA).__name__,
            type2=type(graphB).__name__,
            message="Both graphs must be of the same type for union operation."
        )

    # Sequence counts used to weight the gene marginals in step 4. They MUST
    # be captured before step 2 folds graphB.initial_states into graphA's;
    # reading them afterwards would count graphB twice.
    n_a = graphA.initial_states.sum()
    n_b = graphB.initial_states.sum()

    # 1. Merge edges (raw counts)
    for a, b in graphB.graph.edges:
        ed_b = graphB.graph[a][b]['data']
        if graphA.graph.has_edge(a, b):
            graphA.graph[a][b]['data'].merge(ed_b)
        else:
            # Ensure nodes exist
            if a not in graphA.graph:
                graphA.graph.add_node(a)
            if b not in graphA.graph:
                graphA.graph.add_node(b)
            # Copy B's counts into a fresh EdgeData so graphA never shares
            # (and later mutates) graphB's edge objects.
            ed_new = EdgeData()
            ed_new.merge(ed_b)
            graphA.graph.add_edge(a, b, data=ed_new)

    # Also add any nodes from B that have no edges
    for node in graphB.graph.nodes:
        if node not in graphA.graph:
            graphA.graph.add_node(node)

    # 2. Merge sequence-level counts
    # (per_node_observed_frequency is recomputed in recalculate())
    graphA.initial_states = graphA.initial_states.combine(
        graphB.initial_states, lambda x, y: x + y, fill_value=0
    )
    graphA.terminal_states = graphA.terminal_states.combine(
        graphB.terminal_states, lambda x, y: x + y, fill_value=0
    )
    graphA.n_subpatterns += graphB.n_subpatterns
    graphA.n_transitions += graphB.n_transitions

    # 3. Merge lengths
    if hasattr(graphB, 'lengths'):
        for length, count in graphB.lengths.items():
            graphA.lengths[length] = graphA.lengths.get(length, 0) + count

    # 4. Merge gene-level data (if genetic)
    if graphA.genetic and graphB.genetic:
        # Marginal gene distributions are probabilities, so combine them as
        # an average weighted by each graph's pre-union sequence count.
        n_total = n_a + n_b
        if n_total > 0:
            w_a = n_a / n_total
            w_b = n_b / n_total

            def _weighted(x, y):
                return x * w_a + y * w_b

            graphA.marginal_vgenes = graphA.marginal_vgenes.combine(
                graphB.marginal_vgenes, _weighted, fill_value=0
            )
            graphA.marginal_jgenes = graphA.marginal_jgenes.combine(
                graphB.marginal_jgenes, _weighted, fill_value=0
            )
            graphA.vj_probabilities = graphA.vj_probabilities.combine(
                graphB.vj_probabilities, _weighted, fill_value=0
            )

        # Merge length_distribution counts
        if hasattr(graphA, 'length_distribution') and hasattr(graphB, 'length_distribution'):
            graphA.length_distribution = graphA.length_distribution.combine(
                graphB.length_distribution, lambda x, y: x + y, fill_value=0
            )

        # Merge observed gene sets
        if hasattr(graphB, 'observed_vgenes'):
            graphA.observed_vgenes = list(
                set(graphA.observed_vgenes) | set(graphB.observed_vgenes)
            )
        if hasattr(graphB, 'observed_jgenes'):
            graphA.observed_jgenes = list(
                set(graphA.observed_jgenes) | set(graphB.observed_jgenes)
            )

    # 5. Recalculate ALL derived state from raw counts
    graphA.recalculate()

    # Clear cached edges list
    if hasattr(graphA, 'edges_list'):
        graphA.edges_list = None

    return graphA
+ return graphA