LZGraphs 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- LZGraphs/BagOfWords/BOWEncoder.py +104 -0
- LZGraphs/BagOfWords/__init__.py +0 -0
- LZGraphs/Exceptions/__init__.py +550 -0
- LZGraphs/Graphs/AminoAcidPositional.py +950 -0
- LZGraphs/Graphs/LZGraphBase.py +788 -0
- LZGraphs/Graphs/Naive.py +735 -0
- LZGraphs/Graphs/NucleotideDoublePositional.py +587 -0
- LZGraphs/Graphs/__init__.py +0 -0
- LZGraphs/Metrics/Metrics.py +356 -0
- LZGraphs/Metrics/__init__.py +22 -0
- LZGraphs/Metrics/entropy.py +477 -0
- LZGraphs/Mixins/GeneLogicMixin.py +129 -0
- LZGraphs/Mixins/GenePredictionMixin.py +266 -0
- LZGraphs/Mixins/RandomWalkMixin.py +154 -0
- LZGraphs/Mixins/__init__.py +3 -0
- LZGraphs/Utilities/NodeEdgeSaturationProbe.py +427 -0
- LZGraphs/Utilities/Utilities.py +101 -0
- LZGraphs/Utilities/__init__.py +1 -0
- LZGraphs/Utilities/decomposition.py +69 -0
- LZGraphs/Utilities/graph_operations.py +123 -0
- LZGraphs/Utilities/misc.py +95 -0
- LZGraphs/Visualization/Visualize.py +196 -0
- LZGraphs/Visualization/__init__.py +0 -0
- LZGraphs/__init__.py +50 -0
- LZGraphs/py.typed +0 -0
- lzgraphs-1.0.0.dist-info/METADATA +236 -0
- lzgraphs-1.0.0.dist-info/RECORD +30 -0
- lzgraphs-1.0.0.dist-info/WHEEL +5 -0
- lzgraphs-1.0.0.dist-info/licenses/LICENSE +21 -0
- lzgraphs-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,950 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import re
|
|
3
|
+
import time
|
|
4
|
+
from collections import Counter
|
|
5
|
+
from typing import List, Tuple, Union, Optional, Generator
|
|
6
|
+
|
|
7
|
+
import networkx as nx
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
from tqdm.auto import tqdm
|
|
11
|
+
|
|
12
|
+
# Replace these imports with the correct paths in your package
|
|
13
|
+
from .LZGraphBase import LZGraphBase
|
|
14
|
+
from ..Utilities.decomposition import lempel_ziv_decomposition
|
|
15
|
+
from ..Utilities.misc import window, choice
|
|
16
|
+
from ..Exceptions import (
|
|
17
|
+
EmptyDataError,
|
|
18
|
+
MissingColumnError,
|
|
19
|
+
InvalidSequenceError,
|
|
20
|
+
NoGeneDataError,
|
|
21
|
+
GeneAnnotationError,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# --------------------------------------------------------------------------
|
|
25
|
+
# Global Logger Configuration
|
|
26
|
+
# --------------------------------------------------------------------------
|
|
27
|
+
# Configure logging so that users see log messages without setting it up themselves
|
|
28
|
+
# Module-level logger; both the logger and its console handler use INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# A console handler is attached only when this logger has no handler yet.
if len(logger.handlers) == 0:
    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s - %(message)s")
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    console_handler.setLevel(logging.INFO)
    logger.addHandler(console_handler)
|
|
37
|
+
|
|
38
|
+
# --------------------------------------------------------------------------
|
|
39
|
+
# Helper Functions
|
|
40
|
+
# --------------------------------------------------------------------------
|
|
41
|
+
|
|
42
|
+
def derive_lz_and_position(cdr3_sequence: str) -> Tuple[List[str], List[int]]:
    """
    Split a CDR3 amino acid sequence into LZ subpatterns and, for each
    subpattern, report the running total of characters consumed up to and
    including it.

    Args:
        cdr3_sequence: Sequence handed to ``lempel_ziv_decomposition``.

    Returns:
        A pair ``(subpatterns, end_positions)`` of equal length, where
        ``end_positions[i]`` is the summed length of ``subpatterns[0..i]``.
    """
    pieces = lempel_ziv_decomposition(cdr3_sequence)
    end_positions: List[int] = []
    for piece in pieces:
        # Extend the previous running total (0 before the first subpattern).
        consumed_so_far = end_positions[-1] if end_positions else 0
        end_positions.append(consumed_so_far + len(piece))
    return pieces, end_positions
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def path_to_sequence(lz_subpatterns: List[str]) -> str:
    """
    Rebuild the amino acid sequence spelled out by a list of node labels.

    Every label is passed through ``AAPLZGraph.clean_node`` (which strips the
    positional suffix) and the results are concatenated in order.
    """
    return ''.join(map(AAPLZGraph.clean_node, lz_subpatterns))
|
|
63
|
+
|
|
64
|
+
# --------------------------------------------------------------------------
|
|
65
|
+
# The AAPLZGraph Class
|
|
66
|
+
# --------------------------------------------------------------------------
|
|
67
|
+
|
|
68
|
+
class AAPLZGraph(LZGraphBase):
|
|
69
|
+
"""
|
|
70
|
+
Implements the "Amino Acid Positional" version of the LZGraph for analyzing
|
|
71
|
+
amino-acid sequences, especially for immunological data.
|
|
72
|
+
|
|
73
|
+
Each node is labeled as:
|
|
74
|
+
{LZ_subpattern}_{end_position_in_sequence}  (1-based cumulative length up to and including the sub-pattern)
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
# Valid amino acid characters (standard 20 amino acids)
|
|
78
|
+
VALID_AMINO_ACIDS = set('ACDEFGHIKLMNPQRSTVWY')
|
|
79
|
+
|
|
80
|
+
def __init__(
    self,
    data: pd.DataFrame,
    verbose: bool = True,
    calculate_trainset_pgen: bool = False,
    validate_sequences: bool = True,
    initial_state_threshold: int = 5
):
    """
    Create an amino-acid-positional LZGraph from a DataFrame.

    The DataFrame must contain at least a column "cdr3_amino_acid".
    Optionally, columns "V" and "J" may also be provided to embed
    gene information. If both are present, self.genetic is set to True.

    Args:
        data (pd.DataFrame): Input data for constructing the graph. Must contain
            a "cdr3_amino_acid" column; optionally "V" and "J" columns.
        verbose (bool): Forwarded to ``verbose_driver`` for progress reporting.
        calculate_trainset_pgen (bool): If True, compute PGEN for each sequence in
            `data` and store the result in ``self.train_pgen``.
        validate_sequences (bool): If True, check that sequences contain only
            standard amino acids. Set to False to skip that check.
        initial_state_threshold (int): Initial states observed this many times
            or fewer are discarded before the initial-state probabilities are
            computed. The default (5) reproduces the previous fixed cut-off;
            lower it for small repertoires where no initial state would
            otherwise survive the filter.

    Raises:
        TypeError: If data is not a pandas DataFrame.
        MissingColumnError: If the "cdr3_amino_acid" column is absent.
        EmptyDataError: If data has no rows.
        ValueError: If the CDR3 or gene columns contain nulls, or the CDR3
            column contains empty strings.
        InvalidSequenceError: If validate_sequences is True and a checked
            sequence is not a string or contains non-standard characters.
    """
    super().__init__()  # Initialize LZGraphBase

    # Input validation (raises before any graph state is touched)
    self._validate_input(data, validate_sequences)

    # Validation has already rejected non-DataFrame input, so only the
    # presence of both gene columns decides whether gene data is used.
    self.genetic = ("V" in data.columns) and ("J" in data.columns)

    # Load gene data if present
    if self.genetic:
        self._load_gene_data(data)
        self.verbose_driver(0, verbose)  # "Gene Information Loaded"

    # Build the graph with a custom routine
    self.__simultaneous_graph_construction(data)
    self.verbose_driver(1, verbose)  # "Graph Constructed"

    # Convert the counters filled during construction to Series
    self.length_distribution = pd.Series(self.lengths)
    self.terminal_states = pd.Series(self.terminal_states)
    self.initial_states = pd.Series(self.initial_states)

    self.length_distribution_proba = self.terminal_states / self.terminal_states.sum()

    # Filter out rarely observed initial states, then normalize what remains
    self.initial_states = self.initial_states[self.initial_states > initial_state_threshold]
    self.initial_states_probability = self.initial_states / self.initial_states.sum()

    self.verbose_driver(2, verbose)  # "Graph Metadata Derived"

    # Derive subpattern probabilities & normalize edges
    self._derive_subpattern_individual_probability()
    self.verbose_driver(8, verbose)

    self._normalize_edge_weights()
    self.verbose_driver(3, verbose)

    # If gene data is available, normalize gene weights
    if self.genetic:
        self._batch_gene_weight_normalization(n_process=3, verbose=verbose)
        self.verbose_driver(4, verbose)

    # Additional map derivations
    self.edges_list = None
    self._derive_terminal_state_map()
    self.verbose_driver(7, verbose)
    self._derive_stop_probability_data()
    self.verbose_driver(8, verbose)
    self.verbose_driver(5, verbose)

    # Optionally compute the PGEN for each training sequence
    if calculate_trainset_pgen:
        logger.info("Calculating PGEN for the training set. This may take some time...")
        self.train_pgen = np.array([
            self.walk_probability(seq, verbose=False)
            for seq in data["cdr3_amino_acid"]
        ])

    self.constructor_end_time = time.time()
    self.verbose_driver(6, verbose)
    self.verbose_driver(-2, verbose)
|
|
171
|
+
|
|
172
|
+
# --------------------------------------------------------------------------
|
|
173
|
+
# Input Validation
|
|
174
|
+
# --------------------------------------------------------------------------
|
|
175
|
+
|
|
176
|
+
def _validate_input(self, data: pd.DataFrame, validate_sequences: bool) -> None:
|
|
177
|
+
"""
|
|
178
|
+
Validate input data before graph construction.
|
|
179
|
+
|
|
180
|
+
Args:
|
|
181
|
+
data: Input DataFrame
|
|
182
|
+
validate_sequences: Whether to check sequence content
|
|
183
|
+
|
|
184
|
+
Raises:
|
|
185
|
+
TypeError: If data is not a pandas DataFrame
|
|
186
|
+
ValueError: If required columns are missing or data is invalid
|
|
187
|
+
"""
|
|
188
|
+
# Check type
|
|
189
|
+
if not isinstance(data, pd.DataFrame):
|
|
190
|
+
raise TypeError(
|
|
191
|
+
f"Expected pandas DataFrame, got {type(data).__name__}. "
|
|
192
|
+
"Please provide a DataFrame with a 'cdr3_amino_acid' column."
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
# Check for required column
|
|
196
|
+
if 'cdr3_amino_acid' not in data.columns:
|
|
197
|
+
raise MissingColumnError(
|
|
198
|
+
column_name='cdr3_amino_acid',
|
|
199
|
+
available_columns=list(data.columns)
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
# Check for empty data
|
|
203
|
+
if len(data) == 0:
|
|
204
|
+
raise EmptyDataError("DataFrame is empty. Cannot build LZGraph from zero sequences.")
|
|
205
|
+
|
|
206
|
+
# Check for null values in CDR3 column
|
|
207
|
+
null_count = data['cdr3_amino_acid'].isna().sum()
|
|
208
|
+
if null_count > 0:
|
|
209
|
+
raise ValueError(
|
|
210
|
+
f"Found {null_count} null values in 'cdr3_amino_acid' column. "
|
|
211
|
+
"Please remove or fill null values before building the graph."
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
# Check for empty strings
|
|
215
|
+
empty_count = (data['cdr3_amino_acid'].str.len() == 0).sum()
|
|
216
|
+
if empty_count > 0:
|
|
217
|
+
raise ValueError(
|
|
218
|
+
f"Found {empty_count} empty strings in 'cdr3_amino_acid' column. "
|
|
219
|
+
"Please remove empty sequences before building the graph."
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
# Validate sequence content if requested
|
|
223
|
+
if validate_sequences:
|
|
224
|
+
self._validate_sequence_content(data['cdr3_amino_acid'])
|
|
225
|
+
|
|
226
|
+
# Validate gene columns if present
|
|
227
|
+
if 'V' in data.columns and 'J' in data.columns:
|
|
228
|
+
self._validate_gene_columns(data)
|
|
229
|
+
|
|
230
|
+
def _validate_sequence_content(self, sequences: pd.Series) -> None:
|
|
231
|
+
"""
|
|
232
|
+
Validate that sequences contain only valid amino acid characters.
|
|
233
|
+
|
|
234
|
+
Args:
|
|
235
|
+
sequences: Series of amino acid sequences
|
|
236
|
+
|
|
237
|
+
Raises:
|
|
238
|
+
ValueError: If invalid characters are found
|
|
239
|
+
"""
|
|
240
|
+
# Sample up to 1000 sequences for validation (performance)
|
|
241
|
+
sample_size = min(1000, len(sequences))
|
|
242
|
+
sample = sequences.sample(n=sample_size, random_state=42) if len(sequences) > sample_size else sequences
|
|
243
|
+
|
|
244
|
+
invalid_chars_found = set()
|
|
245
|
+
invalid_sequences = []
|
|
246
|
+
|
|
247
|
+
for seq in sample:
|
|
248
|
+
if not isinstance(seq, str):
|
|
249
|
+
raise InvalidSequenceError(
|
|
250
|
+
sequence=str(seq),
|
|
251
|
+
message=f"Sequence must be a string, got {type(seq).__name__}: {seq}"
|
|
252
|
+
)
|
|
253
|
+
invalid_in_seq = set(seq.upper()) - self.VALID_AMINO_ACIDS
|
|
254
|
+
if invalid_in_seq:
|
|
255
|
+
invalid_chars_found.update(invalid_in_seq)
|
|
256
|
+
if len(invalid_sequences) < 3:
|
|
257
|
+
invalid_sequences.append(seq)
|
|
258
|
+
|
|
259
|
+
if invalid_chars_found:
|
|
260
|
+
examples = ", ".join(f"'{s}'" for s in invalid_sequences[:3])
|
|
261
|
+
raise InvalidSequenceError(
|
|
262
|
+
sequence=invalid_sequences[0] if invalid_sequences else None,
|
|
263
|
+
invalid_chars=''.join(sorted(invalid_chars_found)),
|
|
264
|
+
message=(
|
|
265
|
+
f"Found invalid amino acid characters: {sorted(invalid_chars_found)}. "
|
|
266
|
+
f"Valid amino acids are: {sorted(self.VALID_AMINO_ACIDS)}. "
|
|
267
|
+
f"Example invalid sequences: {examples}"
|
|
268
|
+
)
|
|
269
|
+
)
|
|
270
|
+
|
|
271
|
+
def _validate_gene_columns(self, data: pd.DataFrame) -> None:
|
|
272
|
+
"""
|
|
273
|
+
Validate V and J gene columns.
|
|
274
|
+
|
|
275
|
+
Args:
|
|
276
|
+
data: DataFrame with V and J columns
|
|
277
|
+
|
|
278
|
+
Raises:
|
|
279
|
+
ValueError: If gene columns contain invalid data
|
|
280
|
+
"""
|
|
281
|
+
# Check for nulls in gene columns
|
|
282
|
+
v_nulls = data['V'].isna().sum()
|
|
283
|
+
j_nulls = data['J'].isna().sum()
|
|
284
|
+
|
|
285
|
+
if v_nulls > 0 or j_nulls > 0:
|
|
286
|
+
raise ValueError(
|
|
287
|
+
f"Found null values in gene columns: V has {v_nulls} nulls, "
|
|
288
|
+
f"J has {j_nulls} nulls. Please fill or remove rows with missing genes."
|
|
289
|
+
)
|
|
290
|
+
|
|
291
|
+
# --------------------------------------------------------------------------
|
|
292
|
+
# Overridden / specialized methods
|
|
293
|
+
# --------------------------------------------------------------------------
|
|
294
|
+
|
|
295
|
+
@staticmethod
def encode_sequence(amino_acid: str) -> List[str]:
    """
    Turn an amino acid string into the list of node labels it maps to.

    Each label has the format '{LZ_subpattern}_{position}', where the
    position is the value returned by ``derive_lz_and_position`` for that
    sub-pattern.
    """
    lz, locs = derive_lz_and_position(amino_acid)
    encoded = []
    for subp, pos in zip(lz, locs):
        encoded.append(f"{subp}_{pos}")
    return encoded
|
|
303
|
+
|
|
304
|
+
@staticmethod
|
|
305
|
+
def clean_node(base: str) -> str:
|
|
306
|
+
"""
|
|
307
|
+
Given a sub-pattern that might look like "ABC_10", extract only the amino acids ("ABC").
|
|
308
|
+
"""
|
|
309
|
+
match = re.search(r'[A-Z]+', base)
|
|
310
|
+
return match.group(0) if match else ""
|
|
311
|
+
|
|
312
|
+
def _decomposed_sequence_generator(
    self,
    data: Union[pd.DataFrame, pd.Series]
) -> Generator:
    """
    Yield, per input sequence, the material needed to add it to the graph.

    As a side effect of consuming each sequence, the length counter
    (``self.lengths``) and the terminal/initial state counters are updated.

    Yields:
        (steps, locations, v, j) when self.genetic is True, otherwise
        (steps, locations). ``steps`` and ``locations`` are the results of
        ``window(..., 2)`` over the sub-patterns and their positions.
    """
    def decompose_and_register(cdr3):
        # Shared per-sequence work for both the gene and the gene-free path.
        lz, locs = derive_lz_and_position(cdr3)
        steps = window(lz, 2)
        locations = window(locs, 2)

        seq_length = len(cdr3)
        self.lengths[seq_length] = self.lengths.get(seq_length, 0) + 1
        self._update_terminal_states(f"{lz[-1]}_{locs[-1]}")
        self._update_initial_states(f"{lz[0]}_1")
        return steps, locations

    if not self.genetic:
        # Accept either a DataFrame with a "cdr3_amino_acid" column or a bare Series.
        seq_iter = data["cdr3_amino_acid"] if isinstance(data, pd.DataFrame) else data
        for cdr3 in tqdm(seq_iter, desc="Building Graph", leave=False):
            yield decompose_and_register(cdr3)
        return

    # DataFrame with cdr3_amino_acid, V, J columns
    rows = zip(data["cdr3_amino_acid"], data["V"], data["J"])
    for cdr3, v, j in tqdm(rows, desc="Building Graph", leave=False):
        steps, locations = decompose_and_register(cdr3)
        yield (steps, locations, v, j)
|
|
349
|
+
|
|
350
|
+
def __simultaneous_graph_construction(self, data: pd.DataFrame) -> None:
    """
    Build the graph in a single pass over the decomposed sequences.

    For every consecutive pair of sub-patterns the source node's observed
    frequency is incremented, the edge is inserted (with or without gene
    information, depending on self.genetic), and the target node is
    guaranteed an entry in ``per_node_observed_frequency``.
    """
    logger.debug("Starting custom __simultaneous_graph_construction...")
    processing_stream = self._decomposed_sequence_generator(data)
    with_genes = self.genetic

    for record in processing_stream:
        if with_genes:
            steps, locations, v, j = record
        else:
            steps, locations = record

        for (A, B), (loc_a, loc_b) in zip(steps, locations):
            A_ = f"{A}_{loc_a}"
            B_ = f"{B}_{loc_b}"
            self.per_node_observed_frequency[A_] = self.per_node_observed_frequency.get(A_, 0) + 1
            if with_genes:
                self._insert_edge_and_information(A_, B_, v, j)
            else:
                self._insert_edge_and_information_no_genes(A_, B_)
            # Registers the target node without counting it as a source.
            self.per_node_observed_frequency[B_] = self.per_node_observed_frequency.get(B_, 0)

    logger.debug("Finished custom __simultaneous_graph_construction.")
|
|
375
|
+
|
|
376
|
+
# --------------------------------------------------------------------------
|
|
377
|
+
# Probability / Gene-Related Methods
|
|
378
|
+
# --------------------------------------------------------------------------
|
|
379
|
+
|
|
380
|
+
def walk_probability(
    self,
    walk: Union[str, List[str]],
    verbose: bool = True,
    use_epsilon: bool = False,
    use_log: bool = False
) -> float:
    """
    Return the probability (PGEN) of generating a walk under this graph.

    The score starts from the individual probability of the first node and
    is multiplied by the weight of every edge along the walk. Edges that do
    not exist are skipped and counted; if any were missing, the score is
    then multiplied by (score ** (1 / number_of_steps)) once per missing edge.

    Args:
        walk: A raw sequence (encoded here) or a list of node labels.
        verbose: Whether to log a warning for each missing edge.
        use_epsilon: Not read by this method; kept for signature consistency.
        use_log: If True, accumulate and return the log-probability instead,
            which avoids numerical underflow on long walks.

    Returns:
        float: Probability of the walk, or its natural log if use_log=True.
        An empty walk yields machine epsilon; an unknown first node yields
        machine epsilon squared (both logged-transformed when use_log=True).
    """
    # If the user passed a raw sequence, encode it
    if isinstance(walk, str):
        lz, locs = derive_lz_and_position(walk)
        walk_ = [f"{subp}_{pos}" for subp, pos in zip(lz, locs)]
    else:
        walk_ = walk

    if len(walk_) == 0:
        logger.warning("Empty walk provided to walk_probability. Returning eps.")
        return np.log(np.finfo(float).eps) if use_log else np.finfo(float).eps

    # If the first subpattern isn't observed, return near-zero
    first_node = walk_[0]
    if first_node not in self.subpattern_individual_probability['proba']:
        eps_val = np.finfo(float).eps ** 2
        return np.log(eps_val) if use_log else eps_val

    start = self.subpattern_individual_probability['proba'][first_node]
    score = np.log(start) if use_log else start

    missing_count = 0
    total_steps = 0

    for step1, step2 in window(walk_, 2):
        total_steps += 1
        if self.graph.has_edge(step1, step2):
            edge_weight = self.graph[step1][step2]["weight"]
            if use_log:
                score += np.log(edge_weight)
            else:
                score *= edge_weight
            continue

        if verbose:
            logger.warning(f"No Edge Connecting: {step1} --> {step2}. Probability adjusted.")
        missing_count += 1

    if missing_count > 0 and total_steps > 0:
        # Geometric-mean penalty, applied once per missing edge.
        if use_log:
            score += (score / total_steps) * missing_count
        else:
            score *= np.power(score, 1.0 / total_steps) ** missing_count

    return score
|
|
467
|
+
|
|
468
|
+
def walk_log_probability(
    self,
    walk: Union[str, List[str]],
    verbose: bool = True
) -> float:
    """
    Shortcut for ``walk_probability(walk, verbose=verbose, use_log=True)``.

    Working in log-space avoids numerical underflow on long sequences.

    Args:
        walk: The walk as a string or list of sub-patterns.
        verbose: Whether to log missing-edge warnings.

    Returns:
        float: Log-probability of generating the walk.
    """
    log_probability = self.walk_probability(walk, use_log=True, verbose=verbose)
    return log_probability
|
|
488
|
+
|
|
489
|
+
def walk_gene_probability(
    self,
    walk: Union[str, List[str]],
    v: str,
    j: str,
    verbose: bool = True,
    use_epsilon: bool = False
) -> Tuple[float, float]:
    """
    Compute the probability of generating a walk under a specific (V, J) gene pair.

    Both results start at the marginal probability of the respective gene
    and are multiplied by the value stored under that gene on every edge
    of the walk.

    Args:
        walk: A raw sequence (encoded here) or a list of node labels.
        v: V gene name.
        j: J gene name.
        verbose: Whether to log a warning for missing edges or genes on an edge.
        use_epsilon: Selects the fallback value returned on failure.

    Returns:
        (proba_v, proba_j) as a tuple of floats. If a gene has no marginal
        entry, an edge is missing, or an edge lacks either gene, both values
        are machine epsilon when use_epsilon=True and 0.0 otherwise.
    """
    # Possibly re-encode the walk if the user passed a raw string
    if isinstance(walk, str):
        lz, locs = derive_lz_and_position(walk)
        walk_ = [f"{subp}_{pos}" for subp, pos in zip(lz, locs)]
    else:
        walk_ = walk

    # Single fallback shared by every failure path below.
    fallback = np.finfo(float).eps if use_epsilon else 0.0

    try:
        proba_v = self.marginal_vgenes.loc[v]
        proba_j = self.marginal_jgenes.loc[j]
    except KeyError:
        logger.warning(f"Gene {v} or {j} not found in the marginal distributions.")
        return (fallback, fallback)

    for step1, step2 in window(walk_, 2):
        if not self.graph.has_edge(step1, step2):
            if verbose:
                logger.warning(f"No edge for {step1}->{step2}.")
            return (fallback, fallback)

        e_data = self.graph[step1][step2]
        # An edge that does not carry both genes makes the walk impossible for this pair.
        if not (v in e_data and j in e_data):
            if verbose:
                logger.warning(f"Edge {step1}->{step2} missing {v} or {j}.")
            return (fallback, fallback)

        proba_v *= e_data[v]
        proba_j *= e_data[j]

    return proba_v, proba_j
|
|
540
|
+
|
|
541
|
+
# --------------------------------------------------------------------------
|
|
542
|
+
# Random Walk, Multi-gene Walk, and Variation Methods
|
|
543
|
+
# --------------------------------------------------------------------------
|
|
544
|
+
|
|
545
|
+
def multi_gene_random_walk(
    self,
    N: int,
    seq_len: Union[int, str],
    initial_state: Optional[str] = None,
    vj_init: str = "marginal"
):
    """
    Generate N random walks constrained to edges that carry the selected (V, J) pair.

    A (V, J) pair is drawn once up front and re-drawn whenever a walk gets
    stuck at (or one step after) its starting node. Deeper dead ends are
    handled by stepping back one node and blacklisting the dead-end node for
    the current (node, V, J) combination in ``self.genetic_walks_black_list``.

    Each walk stops as soon as it reaches one of the allowed final states:
    every terminal state when ``seq_len == 'unsupervised'``, otherwise the
    states returned by ``_length_specific_terminal_state(seq_len)``.

    Args:
        N (int): Number of random walks to generate.
        seq_len (int or 'unsupervised'): Desired sequence length or 'unsupervised'.
        initial_state (str): Optional initial node; a random one is drawn when None.
        vj_init (str): Passed to ``_select_random_vj_genes`` ('marginal' or 'combined').

    Returns:
        A list of tuples: [(walk, selected_v, selected_j), ...], where the
        genes are those in effect when the walk finished.
    """
    selected_v, selected_j = self._select_random_vj_genes(vj_init)

    # The set of node labels at which a walk is allowed to end.
    if seq_len == "unsupervised":
        final_states = set(self.terminal_states.index)
    else:
        final_states = set(self._length_specific_terminal_state(seq_len))

    if self.genetic_walks_black_list is None:
        self.genetic_walks_black_list = {}

    results = []
    for _ in tqdm(range(N), desc="Generating multi-gene walks"):
        current_state = self._random_initial_state() if initial_state is None else initial_state
        walk = [current_state]

        # Keep stepping until the walk sits on an allowed final state.
        while current_state not in final_states:
            if current_state not in self.graph:
                logger.warning(f"Current state {current_state} not in graph.")
                break

            # Columns are neighbour nodes, rows are edge attributes.
            edge_info = pd.DataFrame(dict(self.graph[current_state]))

            # Apply blacklist if present
            blacklist_key = (current_state, selected_v, selected_j)
            if blacklist_key in self.genetic_walks_black_list:
                edge_info = edge_info.drop(
                    columns=self.genetic_walks_black_list[blacklist_key],
                    errors="ignore"
                )

            # Neighbours whose edge carries both selected genes.
            if {selected_v, selected_j}.issubset(edge_info.index):
                candidates = edge_info.T[[selected_v, selected_j]].dropna(how="any").index
            else:
                candidates = []

            if len(candidates) == 0:
                if len(walk) > 2:
                    # Step back one node and remember not to revisit this dead end.
                    prev_state = walk[-2]
                    prev_key = (prev_state, selected_v, selected_j)
                    self.genetic_walks_black_list[prev_key] = \
                        self.genetic_walks_black_list.get(prev_key, []) + [walk[-1]]
                    current_state = prev_state
                    walk.pop()
                else:
                    # Too close to the start to backtrack: restart with new genes.
                    walk = walk[:1]
                    current_state = walk[0]
                    selected_v, selected_j = self._select_random_vj_genes(vj_init)
                continue

            # Weighted choice among the remaining edges
            w = edge_info.loc["weight", candidates]
            w = w / w.sum()
            current_state = np.random.choice(w.index, p=w.values)
            walk.append(current_state)

        results.append((walk, selected_v, selected_j))

    return results
|
|
651
|
+
|
|
652
|
+
def unsupervised_random_walk(self):
    """
    Walk the graph from a randomly drawn initial state, ignoring gene
    constraints, taking one `random_step` at a time until
    `is_stop_condition` returns True for the current state.

    Returns:
        (walk, sequence):
            - walk: list of node names
            - sequence: concatenation of `clean_node` applied to each node
    """
    state = self._random_initial_state()
    walk = [state]
    fragments = [self.clean_node(state)]

    while not self.is_stop_condition(state):
        state = self.random_step(state)
        walk.append(state)
        fragments.append(self.clean_node(state))

    return walk, ''.join(fragments)
|
|
674
|
+
|
|
675
|
+
def walk_genes(
    self,
    walk: List[str],
    dropna: bool = True,
    raise_error: bool = True
) -> pd.DataFrame:
    """
    Tabulate the gene attributes stored on each existing edge of a walk.

    Args:
        walk: The node path.
        dropna: If True, drop rows that are empty for every edge.
        raise_error: If True and the table is empty, raise GeneAnnotationError.

    Returns:
        A DataFrame whose rows are edge attribute names (excluding "weight",
        "Vsum" and "Jsum") and whose columns are "source->target" edge labels,
        plus a "type" column ("V" if the row name contains the letter v in
        any case, otherwise "J") and a "sum" column holding the numeric
        row total.

    Raises:
        GeneAnnotationError: If no gene data is found and raise_error is True.
    """
    bookkeeping_keys = ("weight", "Vsum", "Jsum")

    trans_genes = {}
    for i, _ in enumerate(walk[:-1]):
        if not self.graph.has_edge(walk[i], walk[i+1]):
            continue
        # Keep only the per-gene attributes of this edge.
        trans_genes[f"{walk[i]}->{walk[i+1]}"] = {
            key: value
            for key, value in self.graph[walk[i]][walk[i+1]].items()
            if key not in bookkeeping_keys
        }

    df = pd.DataFrame(trans_genes)
    if dropna:
        df.dropna(how="all", inplace=True)

    if raise_error and df.empty:
        raise GeneAnnotationError("No gene data found in the edges for the given walk.")

    # "type" is added first; it is non-numeric and therefore ignored by the sum.
    df["type"] = ["V" if "v" in idx.lower() else "J" for idx in df.index]
    df["sum"] = df.sum(axis=1, numeric_only=True)

    return df
|
|
713
|
+
|
|
714
|
+
def random_walk_distribution_based(self, length_distribution: pd.Series):
    """
    Creates random walks in proportion to a given length distribution.

    A pool of unsupervised random walks is generated (three times the total
    number of requested sequences), then for every requested length a sample
    without replacement is drawn from the pool entries of that length. When
    the pool holds fewer walks of a length than requested, a warning is
    logged and all available walks of that length are used.

    Args:
        length_distribution: A Series whose index is sequence lengths and
            whose values are how many sequences of that length are wanted.
            Values are truncated to integers before sampling.

    Returns:
        A 2D object array of shape [N, 2] where each row is (Seq, Walk).
        When nothing could be sampled the array has shape (0, 2).
    """
    # Oversample so that rarer lengths are likely to be present in the pool.
    oversample_factor = 3
    n_walks = int(length_distribution.sum() * oversample_factor)

    walks = []
    seqs = []
    logger.info("Generating ~%d random walks to filter by length distribution...", n_walks)
    for _ in tqdm(range(n_walks), desc="Random Walk Distribution"):
        rw, rseq = self.unsupervised_random_walk()
        walks.append(rw)
        seqs.append(rseq)

    pool = pd.DataFrame({"Seqs": seqs, "Walks": walks})
    pool["L"] = pool["Seqs"].str.len()

    samples = []
    # .items() avoids label lookups, which would return a Series (and break
    # the comparisons below) if the index contained duplicate lengths.
    for length_val, requested in length_distribution.items():
        # DataFrame.sample only accepts integer n.
        needed = int(requested)
        subset = pool[pool["L"] == length_val]
        if len(subset) < needed:
            logger.warning(
                "Requested %d sequences of length %s, but only found %d.",
                needed, length_val, len(subset)
            )
            needed = len(subset)
        if needed > 0:
            samples.append(subset.sample(n=needed, replace=False))

    if not samples:
        # Keep the documented [N, 2] layout even when N == 0.
        return np.empty((0, 2), dtype=object)

    final = pd.concat(samples, ignore_index=True)
    return final[["Seqs", "Walks"]].values
|
|
758
|
+
|
|
759
|
+
def get_gene_graph(self, v: str, j: str) -> nx.DiGraph:
    """
    Build a copy of the graph restricted to edges annotated with both genes.

    The edge list (with attributes) is materialised once and cached on
    ``self.edges_list``. Every edge whose attributes lack ``v`` or ``j`` is
    removed from a copy of ``self.graph``, and nodes left without any edge
    are removed as well.

    Args:
        v: V gene name that must be a key of the edge attributes.
        j: J gene name that must be a key of the edge attributes.

    Returns:
        The filtered copy of the graph; ``self.graph`` is not modified.
    """
    if self.edges_list is None:
        self.edges_list = list(self.graph.edges(data=True))

    unwanted_edges = [
        (src, dst)
        for src, dst, attrs in self.edges_list
        if not (v in attrs and j in attrs)
    ]

    gene_graph = self.graph.copy()
    gene_graph.remove_edges_from(unwanted_edges)
    # Materialise the isolates first: removing nodes while iterating the
    # generator would mutate the graph under iteration.
    isolated_nodes = list(nx.isolates(gene_graph))
    gene_graph.remove_nodes_from(isolated_nodes)
    return gene_graph
|
|
775
|
+
|
|
776
|
+
def cac_random_gene_walk(self, initial_state=None, vj_init="combined"):
    """
    Conduct a random walk in a "combine-and-conquer" style,
    using a subgraph that only contains edges with the selected V/J.

    If the subgraph for (V, J) doesn't exist yet, it is created with
    ``get_gene_graph`` and cached in ``self.cac_graphs``. The walk then
    proceeds until a terminal state of that subgraph is reached. When the
    current node has no usable outgoing edge, the walk backtracks one step
    and the dead-end node is recorded in ``self.genetic_walks_black_list``
    (keyed by ``(V, J, previous_node)``) so it is not chosen again from that
    node. The blacklist is stored on the instance.

    Args:
        initial_state: Node to start from. If None, a start node is drawn
            from ``self.initial_states`` restricted to the subgraph nodes,
            weighted by the (re-normalised) values of that Series.
        vj_init: Passed through to ``self._select_random_vj_genes``.

    Returns:
        A tuple ``(walk, selected_v, selected_j)``. If the walk gets stuck at
        its first node, ``walk`` contains only that node.
    """
    selected_v, selected_j = self._select_random_vj_genes(vj_init)
    gene_key = (selected_v, selected_j)

    # Build (and cache) the V/J-restricted subgraph on first use.
    if gene_key not in self.cac_graphs:
        self.cac_graphs[gene_key] = self.get_gene_graph(selected_v, selected_j)
    G = self.cac_graphs[gene_key]

    graph_nodes = set(G.nodes)
    # A set gives O(1) membership tests in the loop condition below.
    final_states = set(self.terminal_states.index) & graph_nodes

    if initial_state is None:
        start_candidates = list(set(self.initial_states.index) & graph_nodes)
        first_states = self.initial_states.loc[start_candidates]
        first_states = first_states / first_states.sum()
        current_state = np.random.choice(first_states.index, p=first_states.values)
    else:
        current_state = initial_state

    walk = [current_state]
    if self.genetic_walks_black_list is None:
        self.genetic_walks_black_list = {}
    black_list = self.genetic_walks_black_list

    while current_state not in final_states:
        # Columns are successor nodes, rows are edge attributes.
        edge_info = pd.DataFrame(dict(G[current_state]))

        # Hide successors already known to be dead ends from this node.
        banned = black_list.get((selected_v, selected_j, current_state))
        if banned:
            edge_info = edge_info.drop(columns=banned, errors="ignore")

        # Successors whose edge carries both selected genes.
        if edge_info.shape[1] > 0 and {selected_v, selected_j}.issubset(edge_info.index):
            candidates = edge_info.T[[selected_v, selected_j]].dropna(how="any").index
        else:
            candidates = edge_info.columns[:0]

        if len(candidates) == 0:
            if len(walk) == 1:
                # Stuck at the start: nothing to backtrack to.
                break
            # Backtrack exactly one step and re-evaluate the previous node
            # on the next iteration. Without the `continue`, the stale
            # (empty) edge table would trigger a second backtrack and
            # wrongly blacklist the previous node as well.
            dead_end = walk.pop()
            prev_state = walk[-1]
            black_list.setdefault((selected_v, selected_j, prev_state), []).append(dead_end)
            current_state = prev_state
            continue

        # Choose the next node proportionally to the edge weights.
        w = edge_info.loc["weight", candidates].astype(float)
        w = w / w.sum()
        next_state = np.random.choice(w.index, p=w.values)
        walk.append(next_state)
        current_state = next_state

    return walk, selected_v, selected_j
|
|
847
|
+
|
|
848
|
+
def sequence_variation_curve(self, cdr3_sample: str):
    """
    Compute the out-degree profile of a CDR3 sequence along the graph.

    The sequence is encoded with ``self.encode_sequence`` and the out-degree
    of each resulting node is looked up in ``self.graph``.

    Args:
        cdr3_sample: The CDR3 sequence to encode.

    Returns:
        A pair ``(encoded_subpatterns, out_degree_list)`` where
        ``out_degree_list[i]`` is the out-degree of the i-th encoded node.
    """
    nodes = self.encode_sequence(cdr3_sample)
    out_degrees = list(map(self.graph.out_degree, nodes))
    return nodes, out_degrees
|
|
858
|
+
|
|
859
|
+
def path_gene_table(
    self,
    cdr3_sample: str,
    threshold: Optional[float] = None
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Return two tables (for V genes and J genes) representing the V/J usage
    found on the edges of the path that encodes the given cdr3_sample.

    The gene table is obtained from ``self.walk_genes`` for the encoded
    sequence. A gene (row) is kept only if the number of edges on which it is
    missing (NaN) is strictly below the threshold. Each table is sorted by
    ascending number of missing edges.

    Args:
        cdr3_sample: The amino acid sequence to examine.
        threshold: Maximum (exclusive) number of edges a gene may be missing
            from. If None, defaults to a quarter of the encoded length for
            V genes and half of the encoded length for J genes.

    Returns:
        (vgene_table, jgene_table) as DataFrames. Rows are selected by name:
        a row goes to the V table if its name contains "V" and to the J table
        if its name contains "J" (case-insensitive). Both tables are empty
        when the walk carries no gene data.
    """
    encoded = self.encode_sequence(cdr3_sample)
    length = len(encoded)

    if threshold is None:
        threshold_v = length * 0.25
        threshold_j = length * 0.5
    else:
        threshold_v = threshold_j = threshold

    # Get gene table once (avoid duplicate expensive call)
    gene_table = self.walk_genes(encoded, dropna=False, raise_error=False)
    na_counts = gene_table.isna().sum(axis=1)

    # Classify rows by name in plain Python: the `.str` accessor raises
    # AttributeError on the non-string index of an empty gene table.
    gene_names = [str(name).upper() for name in gene_table.index]
    is_v = np.array(["V" in name for name in gene_names], dtype=bool)
    is_j = np.array(["J" in name for name in gene_names], dtype=bool)

    def _select(name_mask, max_missing):
        # Keep matching genes below the missing-edge limit, fewest NaNs first.
        keep = (na_counts < max_missing).to_numpy(dtype=bool) & name_mask
        subset = gene_table.loc[keep]
        # Stable sort keeps tied genes in their original relative order.
        order = subset.isna().sum(axis=1).sort_values(kind="stable").index
        return subset.loc[order]

    vgene_table = _select(is_v, threshold_v)
    jgene_table = _select(is_j, threshold_j)

    return vgene_table, jgene_table
|
|
904
|
+
|
|
905
|
+
def gene_variation(self, cdr3: str) -> pd.DataFrame:
    """
    Count the distinct V and J genes seen at each subpattern of a sequence.

    For the first encoded node the counts are the sizes of
    ``self.marginal_vgenes`` and ``self.marginal_jgenes``. For every later
    node the counts are the number of distinct attribute names containing
    "V" (respectively "J") across all incoming edges of that node, ignoring
    the "weight", "Vsum" and "Jsum" attributes.

    Args:
        cdr3: The CDR3 sequence to analyse.

    Returns:
        A DataFrame with the V rows followed by the J rows and columns:
          - 'genes': number of possible V or J genes
          - 'type': 'V' or 'J'
          - 'sp': the LZ subpattern

    Raises:
        NoGeneDataError: If ``self.genetic`` is falsy.
    """
    if not self.genetic:
        raise NoGeneDataError(
            operation="gene_repertoire_per_subpattern",
            message="Cannot compute gene repertoire: this LZGraph has no gene data (genetic=False)."
        )

    nodes = self.encode_sequence(cdr3)

    # Edge attributes that are bookkeeping values rather than gene names
    # ("Vsum"/"Jsum" would otherwise match the "V"/"J" substring tests).
    non_gene_keys = {"weight", "Vsum", "Jsum"}

    # First subpattern: full marginal V, J size
    v_counts = [len(self.marginal_vgenes)]
    j_counts = [len(self.marginal_jgenes)]

    for node in nodes[1:]:
        seen_v = set()
        seen_j = set()
        for pred, succ in self.graph.in_edges(node):
            # Gene names are like "TRBV30-1*01" and "TRBJ1-2*01", so test
            # for containment rather than a prefix.
            for gene in self.graph[pred][succ]:
                if gene in non_gene_keys:
                    continue
                if "V" in gene:
                    seen_v.add(gene)
                if "J" in gene:
                    seen_j.add(gene)
        v_counts.append(len(seen_v))
        j_counts.append(len(seen_j))

    # Combine into a DataFrame
    subpatterns = lempel_ziv_decomposition(cdr3)
    return pd.DataFrame({
        "genes": v_counts + j_counts,
        "type": (["V"] * len(v_counts)) + (["J"] * len(j_counts)),
        "sp": subpatterns + subpatterns
    })
|