LZGraphs 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,950 @@
1
+ import logging
2
+ import re
3
+ import time
4
+ from collections import Counter
5
+ from typing import List, Tuple, Union, Optional, Generator
6
+
7
+ import networkx as nx
8
+ import numpy as np
9
+ import pandas as pd
10
+ from tqdm.auto import tqdm
11
+
12
+ # Replace these imports with the correct paths in your package
13
+ from .LZGraphBase import LZGraphBase
14
+ from ..Utilities.decomposition import lempel_ziv_decomposition
15
+ from ..Utilities.misc import window, choice
16
+ from ..Exceptions import (
17
+ EmptyDataError,
18
+ MissingColumnError,
19
+ InvalidSequenceError,
20
+ NoGeneDataError,
21
+ GeneAnnotationError,
22
+ )
23
+
24
# --------------------------------------------------------------------------
# Global Logger Configuration
# --------------------------------------------------------------------------
# Configure logging so that users see log messages without setting it up themselves
# NOTE(review): this runs at import time and forces this module's logger to INFO.
# Libraries conventionally leave handler/level configuration to the application;
# confirm this import-time side effect is intended before changing it.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Only attach a console handler when the logger has none yet, so re-importing
# the module (or a caller that pre-configured this logger) does not produce
# duplicated output lines.
if not logger.handlers:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    # Format: timestamp, level, logger name, message.
    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s - %(message)s")
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
37
+
38
+ # --------------------------------------------------------------------------
39
+ # Helper Functions
40
+ # --------------------------------------------------------------------------
41
+
42
def derive_lz_and_position(cdr3_sequence: str) -> Tuple[List[str], List[int]]:
    """
    Split a CDR3 amino acid sequence into LZ sub-patterns and pair each one
    with the running (cumulative) length of the sequence up to and including it.

    Returns:
        (sub-patterns, cumulative lengths) -- two lists of equal length. The
        i-th cumulative length is the sum of the lengths of sub-patterns 0..i.
    """
    lz_subpatterns = lempel_ziv_decomposition(cdr3_sequence)

    cumulative_lengths = []
    for piece in lz_subpatterns:
        # Extend the previous running total (0 before the first sub-pattern).
        already_covered = cumulative_lengths[-1] if cumulative_lengths else 0
        cumulative_lengths.append(already_covered + len(piece))

    return lz_subpatterns, cumulative_lengths
54
+
55
+
56
def path_to_sequence(lz_subpatterns: List[str]) -> str:
    """
    Rebuild an amino acid string from a list of graph node labels.

    Each label is passed through ``AAPLZGraph.clean_node`` to strip its
    positional suffix, and the remaining fragments are concatenated in order.
    """
    return ''.join(map(AAPLZGraph.clean_node, lz_subpatterns))
63
+
64
+ # --------------------------------------------------------------------------
65
+ # The AAPLZGraph Class
66
+ # --------------------------------------------------------------------------
67
+
68
class AAPLZGraph(LZGraphBase):
    """
    Implements the "Amino Acid Positional" version of the LZGraph for analyzing
    amino-acid sequences, especially for immunological data.

    Each node is labeled as:
        {LZ_subpattern}_{end_position_in_sequence}

    where the position is the 1-based cumulative length of the sequence up to
    and including that sub-pattern (see ``derive_lz_and_position``).
    """

    # Valid amino acid characters (standard 20 amino acids)
    VALID_AMINO_ACIDS = set('ACDEFGHIKLMNPQRSTVWY')

    # Edge attributes that are bookkeeping values rather than gene names.
    _NON_GENE_EDGE_KEYS = ("weight", "Vsum", "Jsum")

    def __init__(
        self,
        data: pd.DataFrame,
        verbose: bool = True,
        calculate_trainset_pgen: bool = False,
        validate_sequences: bool = True,
        initial_state_threshold: int = 5,
    ):
        """
        Create an amino-acid-positional LZGraph from a DataFrame.

        The DataFrame must contain at least a column "cdr3_amino_acid".
        Optionally, columns "V" and "J" may also be provided to embed
        gene information. If both columns are present, self.genetic is set to True.

        Args:
            data (pd.DataFrame): Input data for constructing the graph. Must contain
                a "cdr3_amino_acid" column; optionally "V" and "J" columns.
            verbose (bool): Forwarded to ``verbose_driver`` for progress reporting.
            calculate_trainset_pgen (bool): If True, compute PGEN for each sequence
                in `data` and store the result in ``self.train_pgen``.
            validate_sequences (bool): If True, validate that (a sample of) the
                sequences contain only standard amino acids. Set to False to skip.
            initial_state_threshold (int): Initial states observed this many times
                or fewer are discarded before the initial-state probabilities are
                computed. Defaults to 5, the previously hard-coded value.

        Raises:
            TypeError: If data is not a pandas DataFrame.
            MissingColumnError: If the "cdr3_amino_acid" column is missing.
            EmptyDataError: If the DataFrame has no rows.
            ValueError: If null/empty sequences or null gene values are found.
            InvalidSequenceError: If sequence validation fails.
        """
        super().__init__()  # Initialize LZGraphBase

        # Input validation (raises on any problem, so `data` is a DataFrame below)
        self._validate_input(data, validate_sequences)

        # Gene data is available only when both gene columns are present
        self.genetic = ("V" in data.columns) and ("J" in data.columns)

        # Load gene data if present
        if self.genetic:
            self._load_gene_data(data)
            self.verbose_driver(0, verbose)  # "Gene Information Loaded"

        # Build the graph with a custom routine
        self.__simultaneous_graph_construction(data)
        self.verbose_driver(1, verbose)  # "Graph Constructed"

        # Convert dicts to Series and normalize
        self.length_distribution = pd.Series(self.lengths)
        self.terminal_states = pd.Series(self.terminal_states)
        self.initial_states = pd.Series(self.initial_states)

        # NOTE(review): despite its name this is the normalized *terminal state*
        # distribution, not the normalized length distribution. Left unchanged
        # because code outside this file may rely on its index; confirm intent.
        self.length_distribution_proba = self.terminal_states / self.terminal_states.sum()

        # Filter out rarely observed initial states
        self.initial_states = self.initial_states[self.initial_states > initial_state_threshold]
        self.initial_states_probability = self.initial_states / self.initial_states.sum()

        self.verbose_driver(2, verbose)  # "Graph Metadata Derived"

        # Derive subpattern probabilities & normalize edges
        self._derive_subpattern_individual_probability()
        self.verbose_driver(8, verbose)

        self._normalize_edge_weights()
        self.verbose_driver(3, verbose)

        # If gene data is available, normalize gene weights in parallel
        if self.genetic:
            self._batch_gene_weight_normalization(n_process=3, verbose=verbose)
            self.verbose_driver(4, verbose)

        # Additional map derivations
        self.edges_list = None  # lazily filled by get_gene_graph
        self._derive_terminal_state_map()
        self.verbose_driver(7, verbose)
        self._derive_stop_probability_data()
        self.verbose_driver(8, verbose)
        self.verbose_driver(5, verbose)

        # Optionally compute the PGEN for each sequence
        if calculate_trainset_pgen:
            logger.info("Calculating PGEN for the training set. This may take some time...")
            self.train_pgen = np.array([
                self.walk_probability(seq, verbose=False)
                for seq in data["cdr3_amino_acid"]
            ])

        self.constructor_end_time = time.time()
        self.verbose_driver(6, verbose)
        self.verbose_driver(-2, verbose)

    # --------------------------------------------------------------------------
    # Input Validation
    # --------------------------------------------------------------------------

    def _validate_input(self, data: pd.DataFrame, validate_sequences: bool) -> None:
        """
        Validate input data before graph construction.

        Args:
            data: Input DataFrame
            validate_sequences: Whether to check sequence content

        Raises:
            TypeError: If data is not a pandas DataFrame
            MissingColumnError: If the 'cdr3_amino_acid' column is missing
            EmptyDataError: If the DataFrame has no rows
            ValueError: If null or empty sequences (or null genes) are present
            InvalidSequenceError: If sequence content validation fails
        """
        # Check type
        if not isinstance(data, pd.DataFrame):
            raise TypeError(
                f"Expected pandas DataFrame, got {type(data).__name__}. "
                "Please provide a DataFrame with a 'cdr3_amino_acid' column."
            )

        # Check for required column
        if 'cdr3_amino_acid' not in data.columns:
            raise MissingColumnError(
                column_name='cdr3_amino_acid',
                available_columns=list(data.columns)
            )

        # Check for empty data
        if len(data) == 0:
            raise EmptyDataError("DataFrame is empty. Cannot build LZGraph from zero sequences.")

        sequences = data['cdr3_amino_acid']

        # Check for null values in CDR3 column
        null_count = sequences.isna().sum()
        if null_count > 0:
            raise ValueError(
                f"Found {null_count} null values in 'cdr3_amino_acid' column. "
                "Please remove or fill null values before building the graph."
            )

        # Check for empty strings. The isinstance test keeps this working on
        # non-string columns (where the `.str` accessor would raise), so such
        # input reaches the clearer type error in _validate_sequence_content.
        empty_count = sequences.map(lambda s: isinstance(s, str) and len(s) == 0).sum()
        if empty_count > 0:
            raise ValueError(
                f"Found {empty_count} empty strings in 'cdr3_amino_acid' column. "
                "Please remove empty sequences before building the graph."
            )

        # Validate sequence content if requested
        if validate_sequences:
            self._validate_sequence_content(sequences)

        # Validate gene columns if present
        if 'V' in data.columns and 'J' in data.columns:
            self._validate_gene_columns(data)

    def _validate_sequence_content(self, sequences: pd.Series) -> None:
        """
        Validate that sequences contain only valid amino acid characters.

        Only a sample of at most 1000 sequences (fixed random seed) is checked,
        and characters are upper-cased before comparison.

        Args:
            sequences: Series of amino acid sequences

        Raises:
            InvalidSequenceError: If a sampled value is not a string or contains
                characters outside ``VALID_AMINO_ACIDS``
        """
        # Sample up to 1000 sequences for validation (performance)
        sample_size = min(1000, len(sequences))
        if len(sequences) > sample_size:
            sample = sequences.sample(n=sample_size, random_state=42)
        else:
            sample = sequences

        invalid_chars_found = set()
        invalid_sequences = []  # at most three examples for the error message

        for seq in sample:
            if not isinstance(seq, str):
                raise InvalidSequenceError(
                    sequence=str(seq),
                    message=f"Sequence must be a string, got {type(seq).__name__}: {seq}"
                )
            invalid_in_seq = set(seq.upper()) - self.VALID_AMINO_ACIDS
            if invalid_in_seq:
                invalid_chars_found.update(invalid_in_seq)
                if len(invalid_sequences) < 3:
                    invalid_sequences.append(seq)

        if invalid_chars_found:
            examples = ", ".join(f"'{s}'" for s in invalid_sequences[:3])
            raise InvalidSequenceError(
                sequence=invalid_sequences[0] if invalid_sequences else None,
                invalid_chars=''.join(sorted(invalid_chars_found)),
                message=(
                    f"Found invalid amino acid characters: {sorted(invalid_chars_found)}. "
                    f"Valid amino acids are: {sorted(self.VALID_AMINO_ACIDS)}. "
                    f"Example invalid sequences: {examples}"
                )
            )

    def _validate_gene_columns(self, data: pd.DataFrame) -> None:
        """
        Validate V and J gene columns.

        Args:
            data: DataFrame with V and J columns

        Raises:
            ValueError: If either gene column contains null values
        """
        # Check for nulls in gene columns
        v_nulls = data['V'].isna().sum()
        j_nulls = data['J'].isna().sum()

        if v_nulls > 0 or j_nulls > 0:
            raise ValueError(
                f"Found null values in gene columns: V has {v_nulls} nulls, "
                f"J has {j_nulls} nulls. Please fill or remove rows with missing genes."
            )

    # --------------------------------------------------------------------------
    # Overridden / specialized methods
    # --------------------------------------------------------------------------

    @staticmethod
    def encode_sequence(amino_acid: str) -> List[str]:
        """
        Convert an amino acid string into LZ sub-patterns with positions.
        Each sub-pattern has the format: '{LZ_subpattern}_{position}', where
        position is the cumulative length up to and including the sub-pattern.
        """
        lz, locs = derive_lz_and_position(amino_acid)
        return [f"{subp}_{pos}" for subp, pos in zip(lz, locs)]

    @staticmethod
    def clean_node(base: str) -> str:
        """
        Given a sub-pattern that might look like "ABC_10", extract only the amino acids ("ABC").

        Returns the first run of upper-case ASCII letters, or "" if there is none.
        """
        match = re.search(r'[A-Z]+', base)
        return match.group(0) if match else ""

    def _decompose_and_register(self, cdr3: str):
        """
        Decompose one sequence and record its length, terminal state and
        initial state on the graph. Returns (steps, locations): the pairwise
        windows over the sub-patterns and over their cumulative positions.
        """
        lz, locs = derive_lz_and_position(cdr3)

        self.lengths[len(cdr3)] = self.lengths.get(len(cdr3), 0) + 1
        self._update_terminal_states(f"{lz[-1]}_{locs[-1]}")
        self._update_initial_states(f"{lz[0]}_1")

        return window(lz, 2), window(locs, 2)

    def _decomposed_sequence_generator(
        self,
        data: Union[pd.DataFrame, pd.Series]
    ) -> Generator:
        """
        A generator that yields the information needed to build the graph:
        (steps, locations, v, j) if self.genetic == True, otherwise (steps, locations).

        As a side effect of iteration, each sequence's length, terminal state
        and initial state are registered on the instance.
        """
        if self.genetic:
            # DataFrame with cdr3_amino_acid, V, J columns
            rows = zip(data["cdr3_amino_acid"], data["V"], data["J"])
            for cdr3, v, j in tqdm(rows, desc="Building Graph", leave=False):
                steps, locations = self._decompose_and_register(cdr3)
                yield (steps, locations, v, j)
        else:
            # Possibly just a "cdr3_amino_acid" column
            seq_iter = data["cdr3_amino_acid"] if isinstance(data, pd.DataFrame) else data
            for cdr3 in tqdm(seq_iter, desc="Building Graph", leave=False):
                yield self._decompose_and_register(cdr3)

    def __simultaneous_graph_construction(self, data: pd.DataFrame) -> None:
        """
        Custom simultaneous construction of the graph, mirroring the parent's
        _simultaneous_graph_construction but applying our specialized decomposition.

        For every consecutive pair of sub-patterns an edge is inserted (with gene
        information when available) and the source node's observed frequency is
        incremented.
        """
        logger.debug("Starting custom __simultaneous_graph_construction...")
        processing_stream = self._decomposed_sequence_generator(data)

        for record in processing_stream:
            steps, locations = record[0], record[1]
            genes = record[2:]  # (v, j) when genetic, empty tuple otherwise

            for (A, B), (loc_a, loc_b) in zip(steps, locations):
                A_ = f"{A}_{loc_a}"
                self.per_node_observed_frequency[A_] = self.per_node_observed_frequency.get(A_, 0) + 1
                B_ = f"{B}_{loc_b}"
                if self.genetic:
                    self._insert_edge_and_information(A_, B_, *genes)
                else:
                    self._insert_edge_and_information_no_genes(A_, B_)
                # Only make sure the target has an entry; it is counted when it
                # later appears as the source of an edge.
                self.per_node_observed_frequency[B_] = self.per_node_observed_frequency.get(B_, 0)

        logger.debug("Finished custom __simultaneous_graph_construction.")

    # --------------------------------------------------------------------------
    # Probability / Gene-Related Methods
    # --------------------------------------------------------------------------

    def walk_probability(
        self,
        walk: Union[str, List[str]],
        verbose: bool = True,
        use_epsilon: bool = False,
        use_log: bool = False
    ) -> float:
        """
        Given a walk (a sequence or a pre-encoded LZ pattern list), return
        the probability (PGEN) of generating it under this graph.

        If edges are missing, the product over the existing edges is multiplied
        by its own geometric mean once per missing edge.
        If verbose=True, log warnings on missing edges.

        Args:
            walk: The walk as a string or list of sub-patterns.
            verbose: Whether to log missing-edge warnings.
            use_epsilon: Not used in the main logic here, but kept for consistency.
            use_log: If True, return log-probability instead of probability.
                Recommended for long sequences (>30 amino acids) to prevent
                numerical underflow. Default is False for backward compatibility.

        Returns:
            float: Probability of generating the walk, or log-probability if use_log=True.
            An empty walk yields machine epsilon; an unobserved first node yields
            machine epsilon squared (or their logs when use_log=True).
        """
        # If the user passed a raw sequence, encode it
        walk_ = AAPLZGraph.encode_sequence(walk) if isinstance(walk, str) else walk

        if len(walk_) == 0:
            logger.warning("Empty walk provided to walk_probability. Returning eps.")
            return np.log(np.finfo(float).eps) if use_log else np.finfo(float).eps

        # If the first subpattern isn't observed, return near-zero
        first_node = walk_[0]
        if first_node not in self.subpattern_individual_probability['proba']:
            eps_val = np.finfo(float).eps ** 2
            return np.log(eps_val) if use_log else eps_val

        # Single pass over the walk: collect weights of existing edges and
        # count the missing ones.
        edge_weights = []
        missing_count = 0
        total_steps = 0
        for step1, step2 in window(walk_, 2):
            if self.graph.has_edge(step1, step2):
                edge_weights.append(self.graph[step1][step2]["weight"])
            else:
                if verbose:
                    logger.warning(f"No Edge Connecting: {step1} --> {step2}. Probability adjusted.")
                missing_count += 1
            total_steps += 1

        penalize = missing_count > 0 and total_steps > 0
        start_proba = self.subpattern_individual_probability['proba'][first_node]

        if use_log:
            # Log-space computation to prevent underflow
            log_proba = np.log(start_proba)
            for edge_weight in edge_weights:
                log_proba += np.log(edge_weight)
            if penalize:
                # gmean = proba^(1/total_steps) => log(gmean) = log_proba / total_steps
                log_gmean = log_proba / total_steps
                log_proba += log_gmean * missing_count
            return log_proba

        # Probability-space computation
        proba = start_proba
        for edge_weight in edge_weights:
            proba *= edge_weight
        if penalize:
            gmean = np.power(proba, 1.0 / total_steps)
            proba *= (gmean ** missing_count)
        return proba

    def walk_log_probability(
        self,
        walk: Union[str, List[str]],
        verbose: bool = True
    ) -> float:
        """
        Convenience method to compute log-probability of a walk.
        Equivalent to walk_probability(walk, use_log=True).

        Recommended for long sequences (>30 amino acids) to prevent
        numerical underflow.

        Args:
            walk: The walk as a string or list of sub-patterns.
            verbose: Whether to log missing-edge warnings.

        Returns:
            float: Log-probability of generating the walk.
        """
        return self.walk_probability(walk, verbose=verbose, use_log=True)

    def walk_gene_probability(
        self,
        walk: Union[str, List[str]],
        v: str,
        j: str,
        verbose: bool = True,
        use_epsilon: bool = False
    ) -> Tuple[float, float]:
        """
        Compute the probability of generating a walk under a specific (V, J) gene pair.
        We start with the marginal probabilities for v and j, then multiply by
        the per-edge values stored under those gene names.

        Returns:
            (proba_v, proba_j) as a tuple of floats.
            If a gene is unknown, an edge is missing, or an edge lacks either gene,
            both entries are 0.0 (or machine epsilon if use_epsilon=True).
        """
        # Possibly re-encode the walk if the user passed a raw string
        walk_ = AAPLZGraph.encode_sequence(walk) if isinstance(walk, str) else walk

        fallback_value = np.finfo(float).eps if use_epsilon else 0.0
        fallback = (fallback_value, fallback_value)

        try:
            proba_v = self.marginal_vgenes.loc[v]
            proba_j = self.marginal_jgenes.loc[j]
        except KeyError:
            logger.warning(f"Gene {v} or {j} not found in the marginal distributions.")
            return fallback

        for step1, step2 in window(walk_, 2):
            if not self.graph.has_edge(step1, step2):
                if verbose:
                    logger.warning(f"No edge for {step1}->{step2}.")
                return fallback

            e_data = self.graph[step1][step2]
            # If these genes aren't on the edge, it's effectively 0
            if v not in e_data or j not in e_data:
                if verbose:
                    logger.warning(f"Edge {step1}->{step2} missing {v} or {j}.")
                return fallback

            proba_v *= e_data[v]
            proba_j *= e_data[j]

        return proba_v, proba_j

    # --------------------------------------------------------------------------
    # Random Walk, Multi-gene Walk, and Variation Methods
    # --------------------------------------------------------------------------

    def multi_gene_random_walk(
        self,
        N: int,
        seq_len: Union[int, str],
        initial_state: Optional[str] = None,
        vj_init: str = "marginal"
    ):
        """
        Generate N random walks constrained to edges that carry the currently
        selected (V, J) pair.

        A (V, J) pair is selected once up front and re-selected whenever a walk
        dead-ends within its first two nodes. Deeper dead ends are handled by
        blacklisting the dead-end node for the pair and backtracking one step.
        A walk ends as soon as it reaches one of the allowed final states.

        If seq_len is an integer, the final states are those returned by
        ``_length_specific_terminal_state(seq_len)``.
        If seq_len == 'unsupervised', all terminal states are allowed.

        Args:
            N (int): Number of random walks to generate.
            seq_len (int or 'unsupervised'): Desired sequence length or 'unsupervised'.
            initial_state (str): Optional initial node; a random one is drawn otherwise.
            vj_init (str): 'marginal' or 'combined' for random gene selection.

        Returns:
            A list of tuples: [(walk, selected_v, selected_j), ...].
        """
        selected_v, selected_j = self._select_random_vj_genes(vj_init)

        if seq_len == "unsupervised":
            final_states = list(self.terminal_states.index)
        else:
            final_states = self._length_specific_terminal_state(seq_len)

        if self.genetic_walks_black_list is None:
            self.genetic_walks_black_list = {}

        results = []
        for _ in tqdm(range(N), desc="Generating multi-gene walks"):
            current_state = self._random_initial_state() if initial_state is None else initial_state
            walk = [current_state]

            # Walk until a valid final state is reached. (The previous version
            # compared node labels against the index of a value_counts() Series,
            # i.e. against integer counts, so this test could never succeed.)
            while current_state not in final_states:
                if current_state not in self.graph:
                    logger.warning(f"Current state {current_state} not in graph.")
                    break

                # Columns = successor nodes, rows = edge attributes
                edge_info = pd.DataFrame(dict(self.graph[current_state]))

                # Apply blacklist if present
                blacklist_key = (current_state, selected_v, selected_j)
                if blacklist_key in self.genetic_walks_black_list:
                    edge_info = edge_info.drop(
                        columns=self.genetic_walks_black_list[blacklist_key],
                        errors="ignore"
                    )

                # Keep only successors whose edge carries both selected genes
                if {selected_v, selected_j}.issubset(edge_info.index):
                    sub_df = edge_info.T[[selected_v, selected_j]].dropna(how="any")
                else:
                    sub_df = pd.DataFrame()

                if sub_df.empty:
                    # No valid edges
                    if len(walk) > 2:
                        # Blacklist the dead end for this gene pair and step back
                        prev_state = walk[-2]
                        prev_key = (prev_state, selected_v, selected_j)
                        self.genetic_walks_black_list[prev_key] = \
                            self.genetic_walks_black_list.get(prev_key, []) + [walk[-1]]
                        current_state = prev_state
                        walk.pop()
                    else:
                        # Restart from the first node with a fresh gene pair
                        walk = walk[:1]
                        current_state = walk[0]
                        selected_v, selected_j = self._select_random_vj_genes(vj_init)
                    continue

                # Weighted choice among these edges
                w = edge_info.loc["weight", sub_df.index]
                w = w / w.sum()
                current_state = np.random.choice(w.index, p=w.values)
                walk.append(current_state)

            results.append((walk, selected_v, selected_j))

        return results

    def unsupervised_random_walk(self):
        """
        Conduct a random walk from a randomly selected initial state
        to a final state, ignoring gene constraints. The walk stops when
        `is_stop_condition` is True.

        Returns:
            (walk, sequence):
              - walk: list of node names
              - sequence: cleaned amino-acid sequence of the walk
        """
        current_state = self._random_initial_state()
        walk = [current_state]
        fragments = [self.clean_node(current_state)]

        while not self.is_stop_condition(current_state):
            current_state = self.random_step(current_state)
            walk.append(current_state)
            fragments.append(self.clean_node(current_state))

        return walk, ''.join(fragments)

    def walk_genes(
        self,
        walk: List[str],
        dropna: bool = True,
        raise_error: bool = True
    ) -> pd.DataFrame:
        """
        Given a walk (list of nodes), return a DataFrame of gene usage at each edge.

        Args:
            walk: The node path.
            dropna: If True, drop genes (rows) that have no value on any edge.
            raise_error: If True and result is empty, raise GeneAnnotationError.

        Returns:
            A DataFrame where rows = gene names and columns = edges in walk
            ("src->dst"), plus a "type" column ('V' if the gene name contains
            "v"/"V", else 'J') and a "sum" column (row sum of numeric values).

        Raises:
            GeneAnnotationError: If no gene data is found and raise_error is True.
        """
        trans_genes = {}
        for src, dst in zip(walk, walk[1:]):
            if self.graph.has_edge(src, dst):
                edge_attrs = self.graph[src][dst].copy()
                # Remove the bookkeeping keys so only gene names remain
                for remove_key in self._NON_GENE_EDGE_KEYS:
                    edge_attrs.pop(remove_key, None)
                trans_genes[f"{src}->{dst}"] = edge_attrs

        df = pd.DataFrame(trans_genes)
        if dropna:
            df.dropna(how="all", inplace=True)

        if df.empty and raise_error:
            raise GeneAnnotationError("No gene data found in the edges for the given walk.")

        # Add gene type and sum columns for clarity
        df["type"] = ["V" if "v" in idx.lower() else "J" for idx in df.index]
        df["sum"] = df.sum(axis=1, numeric_only=True)

        return df

    def random_walk_distribution_based(self, length_distribution: pd.Series):
        """
        Creates random walks in proportion to a given length distribution.
        We generate three times the requested total of unsupervised walks,
        then sample from them to match the specified distribution.

        Args:
            length_distribution: A Series whose index is lengths and values are
                                 how many sequences of that length we want.

        Returns:
            A 2D array of shape [N, 2], where each row is (Seq, Walk); an empty
            array if nothing could be sampled. Lengths for which fewer walks
            were generated than requested contribute all walks available.
        """
        N = int(length_distribution.sum() * 3)  # oversample, then filter

        walks = []
        seqs = []
        logger.info(f"Generating ~{N} random walks to filter by length distribution...")
        for _ in tqdm(range(N), desc="Random Walk Distribution"):
            rw, rseq = self.unsupervised_random_walk()
            walks.append(rw)
            seqs.append(rseq)

        df = pd.DataFrame({"Seqs": seqs, "Walks": walks})
        df["L"] = df["Seqs"].str.len()

        samples = []
        for length_val in length_distribution.index:
            # DataFrame.sample requires an integer n
            needed = int(length_distribution[length_val])
            subset = df[df["L"] == length_val]
            if len(subset) < needed:
                logger.warning(
                    f"Requested {needed} sequences of length {length_val}, but only found {len(subset)}."
                )
                needed = len(subset)
            if needed > 0:
                samples.append(subset.sample(n=needed, replace=False))

        if not samples:
            return np.array([])

        final = pd.concat(samples, ignore_index=True)
        return final[["Seqs", "Walks"]].values

    def get_gene_graph(self, v: str, j: str) -> nx.DiGraph:
        """
        Returns a copy of the graph containing only edges that carry both gene
        v and gene j; nodes left without any edge are removed as well.
        """
        # The edge list is materialized once and reused across calls
        if self.edges_list is None:
            self.edges_list = list(self.graph.edges(data=True))

        to_drop = [
            (src, dst)
            for src, dst, attrs in self.edges_list
            if (v not in attrs) or (j not in attrs)
        ]

        G = self.graph.copy()
        G.remove_edges_from(to_drop)
        G.remove_nodes_from(list(nx.isolates(G)))
        return G

    def cac_random_gene_walk(self, initial_state=None, vj_init="combined"):
        """
        Conduct a random walk in a "combine-and-conquer" style,
        using a subgraph that only contains edges with the selected V/J.

        If the subgraph for (V, J) doesn't exist yet, create and cache it. Then pick
        a random initial state from that subgraph (unless one is given) and walk
        until a final node is reached. Dead ends are blacklisted for the gene
        pair and the walk backtracks one step; if the walk is stuck at its first
        node, it stops there.

        Returns:
            (walk, selected_v, selected_j)
        """
        selected_v, selected_j = self._select_random_vj_genes(vj_init)

        gene_key = (selected_v, selected_j)
        if gene_key not in self.cac_graphs:
            self.cac_graphs[gene_key] = self.get_gene_graph(selected_v, selected_j)
        G = self.cac_graphs[gene_key]

        subgraph_nodes = set(G.nodes)
        final_states = set(self.terminal_states.index) & subgraph_nodes

        if initial_state is None:
            # Draw the start among initial states present in the subgraph,
            # proportionally to their observed counts
            candidates = list(set(self.initial_states.index) & subgraph_nodes)
            first_states = self.initial_states.loc[candidates]
            first_states = first_states / first_states.sum()
            current_state = np.random.choice(first_states.index, p=first_states.values)
        else:
            current_state = initial_state

        walk = [current_state]
        if self.genetic_walks_black_list is None:
            self.genetic_walks_black_list = {}

        while current_state not in final_states:
            # Columns = successor nodes, rows = edge attributes
            edge_info = pd.DataFrame(dict(G[current_state]))

            # Apply blacklist
            blacklist_key = (selected_v, selected_j, current_state)
            if blacklist_key in self.genetic_walks_black_list:
                edge_info = edge_info.drop(
                    columns=self.genetic_walks_black_list[blacklist_key],
                    errors="ignore"
                )

            # An edge_info without columns has an empty index, so it also
            # ends up here with an empty sub_df.
            if {selected_v, selected_j}.issubset(edge_info.index):
                sub_df = edge_info.T[[selected_v, selected_j]].dropna(how="any")
            else:
                sub_df = pd.DataFrame()

            if sub_df.empty:
                # No valid edges
                if len(walk) <= 1:
                    # Stuck at the start
                    break
                # Blacklist the dead end and backtrack exactly one step, then
                # re-evaluate the previous node's remaining edges. (Previously a
                # missing `continue` made the walk backtrack twice.)
                prev_key = (selected_v, selected_j, walk[-2])
                blacklisted_cols = self.genetic_walks_black_list.get(prev_key, [])
                blacklisted_cols.append(current_state)
                self.genetic_walks_black_list[prev_key] = blacklisted_cols
                walk.pop()
                current_state = walk[-1]
                continue

            w = edge_info.loc["weight", sub_df.index]
            w = w / w.sum()
            current_state = np.random.choice(w.index, p=w.values)
            walk.append(current_state)

        return walk, selected_v, selected_j

    def sequence_variation_curve(self, cdr3_sample: str):
        """
        Given a CDR3 sequence, return two lists:
          (encoded_subpatterns, out_degree_list)
        where out_degree_list[i] is the out-degree of the node in the graph
        corresponding to the i-th subpattern (0 if the node is not in the graph).
        """
        encoded = self.encode_sequence(cdr3_sample)
        # out_degree() on an unknown node does not return an int, so unseen
        # nodes are reported explicitly as 0.
        curve = [
            self.graph.out_degree(node) if node in self.graph else 0
            for node in encoded
        ]
        return encoded, curve

    def path_gene_table(
        self,
        cdr3_sample: str,
        threshold: Optional[float] = None
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Return two tables (for V genes and J genes) representing the V/J usage
        along the edges of the given cdr3_sample. A gene is kept only if its
        number of missing (NaN) entries is below the threshold.

        Args:
            cdr3_sample: The amino acid sequence to examine.
            threshold: Maximum number (exclusive) of missing entries per gene.
                       If None, defaults to 0.25 * n for V genes and 0.5 * n
                       for J genes, where n is the number of encoded sub-patterns.

        Returns:
            (vgene_table, jgene_table) as DataFrames, sorted by ascending NaN count.
        """
        encoded = self.encode_sequence(cdr3_sample)
        length = len(encoded)

        if threshold is None:
            threshold_v = length * 0.25
            threshold_j = length * 0.5
        else:
            threshold_v = threshold
            threshold_j = threshold

        # Get gene table once (avoid duplicate expensive call)
        gene_table = self.walk_genes(encoded, dropna=False, raise_error=False)

        # Without any gene rows the index may not be string-typed, so the
        # `.str` filters below cannot be applied; return two empty tables.
        if gene_table.empty:
            return gene_table.copy(), gene_table.copy()

        na_counts = gene_table.isna().sum(axis=1)

        # For V genes
        mask_v = na_counts < threshold_v
        vgene_table = gene_table[mask_v & gene_table.index.str.contains("V", case=False)]

        # For J genes
        mask_j = na_counts < threshold_j
        jgene_table = gene_table[mask_j & gene_table.index.str.contains("J", case=False)]

        # Sort by ascending number of NaNs
        jgene_table = jgene_table.loc[jgene_table.isna().sum(axis=1).sort_values().index]
        vgene_table = vgene_table.loc[vgene_table.isna().sum(axis=1).sort_values().index]

        return vgene_table, jgene_table

    def gene_variation(self, cdr3: str) -> pd.DataFrame:
        """
        Return a DataFrame that shows how many V and J genes are possible
        for each subpattern in the given cdr3 sequence.

        For the first sub-pattern the full marginal V/J repertoire sizes are
        used; for each later one, the genes found on the edges entering its node.

        The DataFrame columns:
          - 'genes': number of possible V or J genes
          - 'type': 'V' or 'J'
          - 'sp': the LZ subpattern

        Raises:
            NoGeneDataError: If the graph was built without gene data.
        """
        if not self.genetic:
            raise NoGeneDataError(
                operation="gene_repertoire_per_subpattern",
                message="Cannot compute gene repertoire: this LZGraph has no gene data (genetic=False)."
            )

        # Decompose once; node labels are derived from the same decomposition
        lz, locs = derive_lz_and_position(cdr3)
        lz_subpatterns = list(lz)
        encoded_a = [f"{subp}_{pos}" for subp, pos in zip(lz_subpatterns, locs)]

        # First subpattern: full marginal V, J size
        n_v_genes = [len(self.marginal_vgenes)]
        n_j_genes = [len(self.marginal_jgenes)]

        for node in encoded_a[1:]:
            v_genes = set()
            j_genes = set()
            for e_a, e_b in self.graph.in_edges(node):
                for gene in self.graph[e_a][e_b]:
                    # Skip weight, Vsum, Jsum ("Vsum"/"Jsum" would otherwise
                    # match the substring tests below)
                    if gene in self._NON_GENE_EDGE_KEYS:
                        continue
                    # Gene names are like "TRBV30-1*01" and "TRBJ1-2*01", so use 'in' not startswith
                    if "V" in gene:
                        v_genes.add(gene)
                    if "J" in gene:
                        j_genes.add(gene)

            n_v_genes.append(len(v_genes))
            n_j_genes.append(len(j_genes))

        # Combine into a DataFrame: all V rows first, then all J rows
        j_df = pd.DataFrame({
            "genes": n_v_genes + n_j_genes,
            "type": (["V"] * len(n_v_genes)) + (["J"] * len(n_j_genes)),
            "sp": lz_subpatterns + lz_subpatterns
        })
        return j_df