yamcot 1.0.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yamcot/__init__.py +46 -0
- yamcot/_core/__init__.py +17 -0
- yamcot/_core/_core.cp310-win_amd64.pyd +0 -0
- yamcot/_core/bindings.cpp +28 -0
- yamcot/_core/core_functions.h +29 -0
- yamcot/_core/fasta_to_plain.h +182 -0
- yamcot/_core/mco_prc.cpp +1476 -0
- yamcot/_core/pfm_to_pwm.h +130 -0
- yamcot/cli.py +621 -0
- yamcot/comparison.py +1066 -0
- yamcot/execute.py +97 -0
- yamcot/functions.py +787 -0
- yamcot/io.py +522 -0
- yamcot/models.py +1161 -0
- yamcot/pipeline.py +402 -0
- yamcot/ragged.py +126 -0
- yamcot-1.0.0.dist-info/METADATA +433 -0
- yamcot-1.0.0.dist-info/RECORD +21 -0
- yamcot-1.0.0.dist-info/WHEEL +5 -0
- yamcot-1.0.0.dist-info/entry_points.txt +3 -0
- yamcot-1.0.0.dist-info/licenses/LICENSE +21 -0
yamcot/pipeline.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unified pipeline for motif comparison operations.
|
|
3
|
+
This module provides a unified interface for score-based, sequence-based, and tomtom-like comparisons.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import random
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Optional, Union
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from yamcot.comparison import DataComparator, MotaliComparator, TomtomComparator, UniversalMotifComparator
|
|
14
|
+
from yamcot.functions import scores_to_frequencies
|
|
15
|
+
from yamcot.io import read_fasta
|
|
16
|
+
from yamcot.models import MotifModel
|
|
17
|
+
from yamcot.ragged import RaggedData, ragged_from_list
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Pipeline:
    """
    Unified pipeline for motif comparison operations.

    Handles the score-based ("profile"), sequence-based ("motif", "motali") and
    TomTom-like ("tomtom-like") comparison paths for the model types loadable by
    ``MotifModel.create_from_file`` (e.g. BAMM, PWM, Sitega). Comparison-type
    strings are matched case-insensitively.
    """

    # Integer code for each nucleotide; any other character is encoded as "N" (4).
    _BASE_CODES = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}
    _UNKNOWN_BASE_CODE = 4
    # Alphabet sampled uniformly when generating random sequences (never "N").
    _RANDOM_BASES = ("A", "C", "G", "T")

    # Keyword arguments forwarded to each comparator; any other key in **kwargs is dropped.
    _DATA_COMPARATOR_PARAMS = (
        "name",
        "metric",
        "n_permutations",
        "distortion_level",
        "n_jobs",
        "seed",
        "filter_type",
        "filter_threshold",
        "search_range",
        "min_kernel_size",
        "max_kernel_size",
    )
    # UniversalMotifComparator currently accepts the same options as DataComparator.
    _MOTIF_COMPARATOR_PARAMS = _DATA_COMPARATOR_PARAMS
    _TOMTOM_COMPARATOR_PARAMS = ("metric", "n_permutations", "permute_rows", "n_jobs", "seed", "pfm_mode")
    _MOTALI_COMPARATOR_PARAMS = ("fasta_path", "threshold", "tmp_directory")

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _select_kwargs(kwargs: dict, allowed) -> dict:
        """Return the entries of ``kwargs`` whose keys appear in ``allowed`` (in ``allowed`` order)."""
        return {key: kwargs[key] for key in allowed if key in kwargs}

    def load_scores(self, profile_path: Union[str, Path]) -> RaggedData:
        """
        Load pre-calculated scores from a text-based profile file in FASTA-like format.

        Blank lines and header lines (starting with '>') are skipped, so files
        without any headers are accepted as well. Every other line is parsed as
        one record of numerical scores: it is split on commas if it contains a
        comma, otherwise on tabs if it contains a tab, otherwise on whitespace.
        Trailing commas on a score line are ignored.

        Args:
            profile_path: Path to the profile file containing scores in text format.

        Returns:
            RaggedData (float32) with one entry per score line, in file order.

        Raises:
            ValueError: If a score line contains a token that is not a number;
                the message includes the file path and the 1-based line number.
        """
        path = Path(profile_path)
        scores_list = []

        # utf-8-sig drops a leading BOM, which would otherwise hide the first '>' header.
        # errors="replace" keeps stray bytes in (skipped) header lines from aborting the
        # load; a corrupted score line still fails in float() with a clear message.
        with open(path, "r", encoding="utf-8-sig", errors="replace") as handle:
            for line_no, raw_line in enumerate(handle, start=1):
                line = raw_line.strip()
                if not line or line.startswith(">"):
                    continue
                try:
                    scores = self._parse_score_line(line)
                except ValueError as err:
                    raise ValueError(f"{path}:{line_no}: cannot parse scores ({err})") from err
                scores_list.append(np.array(scores, dtype=np.float32))

        if not scores_list:
            # No score lines: return an empty container with a single zero offset.
            return RaggedData(np.empty(0, dtype=np.float32), np.zeros(1, dtype=np.int64))

        return ragged_from_list(scores_list, dtype=np.float32)

    @staticmethod
    def _parse_score_line(line: str) -> list:
        """Split one stripped score line (comma, else tab, else whitespace) and parse each token as float."""
        if "," in line:
            tokens = line.rstrip(",").split(",")
        elif "\t" in line:
            tokens = line.split("\t")
        else:
            tokens = line.split()
        return [float(token) for token in tokens]

    def normalize_scores(self, scores: RaggedData) -> RaggedData:
        """
        Normalize scores by delegating to ``scores_to_frequencies``.

        Args:
            scores: RaggedData containing raw scores.

        Returns:
            The RaggedData produced by ``scores_to_frequencies``.
        """
        return scores_to_frequencies(scores)

    def execute_score_comparison(
        self, profile1_path: Union[str, Path], profile2_path: Union[str, Path], **kwargs
    ) -> Any:
        """
        Execute score-based comparison using DataComparator.

        Both profiles are loaded with :meth:`load_scores` and normalized with
        :meth:`normalize_scores` before being compared.

        Args:
            profile1_path: Path to first profile file.
            profile2_path: Path to second profile file.
            **kwargs: Comparison options; only the keys in ``_DATA_COMPARATOR_PARAMS``
                are forwarded to DataComparator, the rest are ignored.

        Returns:
            Whatever ``DataComparator.compare`` returns.
        """
        freq1 = self.normalize_scores(self.load_scores(profile1_path))
        freq2 = self.normalize_scores(self.load_scores(profile2_path))

        comparator = DataComparator(**self._select_kwargs(kwargs, self._DATA_COMPARATOR_PARAMS))
        return comparator.compare(freq1, freq2)

    def load_sequences(
        self, seq_source: Union[str, Path, None], num_sequences: int = 1000, seq_length: int = 200
    ) -> RaggedData:
        """
        Load sequences from a FASTA file, or generate random ones.

        Args:
            seq_source: Path to a sequence file read with ``read_fasta``, or None to
                generate random sequences instead.
            num_sequences: Number of sequences to generate when ``seq_source`` is None.
            seq_length: Length of each generated sequence when ``seq_source`` is None.

        Returns:
            RaggedData with the sequences; generated sequences are int8-encoded with
            A=0, C=1, G=2, T=3 using the module-level ``random`` generator.
        """
        if seq_source is not None:
            return read_fasta(str(seq_source))

        sequences = [
            self._encode_sequence(self._generate_random_sequence(seq_length)) for _ in range(num_sequences)
        ]
        return ragged_from_list(sequences, dtype=np.int8)

    def _encode_sequence(self, seq: str) -> np.ndarray:
        """Encode a DNA string as int8 codes (A=0, C=1, G=2, T=3, anything else=4), case-insensitively."""
        codes = self._BASE_CODES
        unknown = self._UNKNOWN_BASE_CODE
        return np.array([codes.get(base.upper(), unknown) for base in seq], dtype=np.int8)

    def _generate_random_sequence(self, length: int) -> str:
        """Generate a random DNA string of ``length`` bases drawn uniformly from A/C/G/T."""
        return "".join(random.choice(self._RANDOM_BASES) for _ in range(length))

    def execute_tomtom_comparison(
        self, model1: MotifModel, model2: MotifModel, sequences: Optional[RaggedData], **kwargs
    ) -> Any:
        """
        Execute TomTom-like comparison using TomtomComparator.

        Args:
            model1: First model.
            model2: Second model.
            sequences: Sequences passed through to the comparator (used for model
                conversion when needed), or None.
            **kwargs: Comparison options; only the keys in ``_TOMTOM_COMPARATOR_PARAMS``
                are forwarded to TomtomComparator, the rest are ignored.

        Returns:
            Whatever ``TomtomComparator.compare`` returns.
        """
        comparator = TomtomComparator(**self._select_kwargs(kwargs, self._TOMTOM_COMPARATOR_PARAMS))
        return comparator.compare(model1, model2, sequences)

    def execute_motif_comparison(
        self,
        model1: MotifModel,
        model2: MotifModel,
        sequences: RaggedData,
        promoters: RaggedData,
        comparison_type: str = "motif",
        **kwargs,
    ) -> Any:
        """
        Execute sequence-based comparison using the comparator for ``comparison_type``.

        Args:
            model1: First model.
            model2: Second model.
            sequences: Sequences the comparator scans.
            promoters: Sequences used to compute a model's threshold table when it
                has none yet ("motali" only).
            comparison_type: 'motif' (UniversalMotifComparator) or 'motali'
                (MotaliComparator), case-insensitive.
            **kwargs: Comparison options, filtered per comparator.

        Returns:
            Whatever the selected comparator's ``compare`` returns.

        Raises:
            ValueError: If ``comparison_type`` is neither 'motif' nor 'motali'.
        """
        kind = comparison_type.lower()

        if kind == "motif":
            comparator = UniversalMotifComparator(**self._select_kwargs(kwargs, self._MOTIF_COMPARATOR_PARAMS))
            return comparator.compare(model1, model2, sequences)

        if kind == "motali":
            # Make sure both models have a threshold table; it is always computed
            # from the promoters argument, never from ``sequences``.
            for model in (model1, model2):
                if getattr(model, "_threshold_table", None) is None:
                    model.get_threshold_table(promoters)

            comparator = MotaliComparator(**self._select_kwargs(kwargs, self._MOTALI_COMPARATOR_PARAMS))
            return comparator.compare(model1, model2, sequences)

        raise ValueError(f"Unknown comparison type: {comparison_type}")

    def _load_models(
        self,
        model1_path: Union[str, Path],
        model2_path: Union[str, Path],
        model1_type: str,
        model2_type: str,
    ) -> "tuple[MotifModel, MotifModel]":
        """Load both models with ``MotifModel.create_from_file``; raise ValueError if either returns None."""
        self.logger.info(
            "Loading models from %s (type: %s) and %s (type: %s)", model1_path, model1_type, model2_path, model2_type
        )
        model1 = MotifModel.create_from_file(str(model1_path), model1_type)
        model2 = MotifModel.create_from_file(str(model2_path), model2_type)

        if model1 is None or model2 is None:
            raise ValueError("Failed to load one or both models")
        return model1, model2

    def run_pipeline(
        self,
        model1_path: Union[str, Path],
        model2_path: Union[str, Path],
        model1_type: str,
        model2_type: str,
        comparison_type: str = "motif",
        seq_source1: Optional[Union[str, Path]] = None,
        seq_source2: Optional[Union[str, Path]] = None,
        num_sequences: int = 1000,
        seq_length: int = 200,
        **kwargs,
    ) -> Any:
        """
        Main entry point for the unified pipeline.

        Args:
            model1_path: Path to first model (or score profile for 'profile').
            model2_path: Path to second model (or score profile for 'profile').
            model1_type: Type of first model ('pwm', 'bamm', 'sitega'); unused for 'profile'.
            model2_type: Type of second model ('pwm', 'bamm', 'sitega'); unused for 'profile'.
            comparison_type: 'profile', 'motif', 'motali' or 'tomtom-like' (case-insensitive).
            seq_source1: Sequence file for the 'motif'/'motali' paths; None generates
                random sequences.
            seq_source2: Promoter file for the 'motif'/'motali' paths; None generates
                random sequences.
            num_sequences: Number of sequences to generate when needed.
            seq_length: Length of sequences to generate when needed.
            **kwargs: Additional comparison options, filtered per comparator.

        Returns:
            The selected comparator's result.

        Raises:
            ValueError: If the comparison type is unknown or a model fails to load.
        """
        self.logger.info("Starting pipeline with comparison_type='%s'", comparison_type)
        kind = comparison_type.lower()

        if kind == "profile":
            self.logger.info("Executing score-based comparison")
            return self.execute_score_comparison(model1_path, model2_path, **kwargs)

        if kind in ("motif", "motali"):
            self.logger.info("Executing scan-based comparison")
            model1, model2 = self._load_models(model1_path, model2_path, model1_type, model2_type)

            self.logger.info("Loading sequences from source: %s", seq_source1)
            sequences = self.load_sequences(seq_source1, num_sequences, seq_length)
            self.logger.info("Loading promoters from source: %s", seq_source2)
            promoters = self.load_sequences(seq_source2, num_sequences, seq_length)

            self.logger.info("Running %s comparison", comparison_type)
            result = self.execute_motif_comparison(
                model1, model2, sequences=sequences, promoters=promoters, comparison_type=comparison_type, **kwargs
            )
            self.logger.info("Pipeline completed successfully")
            return result

        if kind == "tomtom-like":
            self.logger.info("Executing TomTom-like comparison")
            model1, model2 = self._load_models(model1_path, model2_path, model1_type, model2_type)

            # This path ignores seq_source1/seq_source2: when sequences are needed
            # (for model conversion) they are always randomly generated.
            sequences: Optional[RaggedData] = None
            if model1.model_type != model2.model_type:
                # Different model types: force pfm_mode. ``kwargs`` is this call's own
                # dict, so the caller's arguments are not mutated.
                self.logger.warning("WARNING! models have different origin, switch to `pfm_mode`")
                self.logger.info("Generation sequences for model conversion")
                sequences = self.load_sequences(None, num_sequences, seq_length)
                kwargs["pfm_mode"] = True
            elif kwargs.get("pfm_mode"):
                self.logger.info("Generation sequences for model conversion")
                sequences = self.load_sequences(None, num_sequences, seq_length)

            self.logger.info("Running %s comparison", comparison_type)
            result = self.execute_tomtom_comparison(model1, model2, sequences, **kwargs)
            self.logger.info("Pipeline completed successfully")
            return result

        raise ValueError(
            f"Unknown comparison type: {comparison_type}. Expected 'profile', 'motif', 'motali', or 'tomtom-like'."
        )
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def run_pipeline(
    model1_path: Union[str, Path],
    model2_path: Union[str, Path],
    model1_type: str,
    model2_type: str,
    comparison_type: str = "motif",
    seq_source1: Optional[Union[str, Path]] = None,
    seq_source2: Optional[Union[str, Path]] = None,
    num_sequences: int = 1000,
    seq_length: int = 200,
    **kwargs,
) -> Any:
    """
    Convenience wrapper: build a fresh ``Pipeline`` and run it once.

    Every argument is forwarded unchanged to ``Pipeline.run_pipeline``.

    Args:
        model1_path: Path to the first model (or score profile).
        model2_path: Path to the second model (or score profile).
        model1_type: Type of the first model ('pwm', 'bamm', 'sitega').
        model2_type: Type of the second model ('pwm', 'bamm', 'sitega').
        comparison_type: One of 'profile', 'motif', 'motali', 'tomtom-like'.
        seq_source1: First sequence file, or None to generate sequences.
        seq_source2: Second sequence file, or None to generate sequences.
        num_sequences: How many sequences to generate when needed.
        seq_length: Length of each generated sequence.
        **kwargs: Extra comparison options, passed through untouched.

    Returns:
        Whatever ``Pipeline.run_pipeline`` returns.
    """
    named_options = {
        "model1_path": model1_path,
        "model2_path": model2_path,
        "model1_type": model1_type,
        "model2_type": model2_type,
        "comparison_type": comparison_type,
        "seq_source1": seq_source1,
        "seq_source2": seq_source2,
        "num_sequences": num_sequences,
        "seq_length": seq_length,
    }
    runner = Pipeline()
    return runner.run_pipeline(**named_options, **kwargs)
|
yamcot/ragged.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class RaggedData:
    """
    Class for storing ragged (variable-length) arrays.

    Uses a flattened representation: all elements live in one ``data`` array and
    sequence ``i`` occupies ``data[offsets[i]:offsets[i + 1]]``. Sequences of
    different lengths are therefore stored without padding.
    """

    def __init__(self, data: np.ndarray, offsets: np.ndarray):
        """
        Initialize the RaggedData object.

        Parameters
        ----------
        data : np.ndarray
            Flattened array containing all the data elements.
        offsets : np.ndarray
            Array of indices indicating the start of each sequence in the data array.
            Its length should be (num_sequences + 1); the last element marks the
            end of the last sequence.
        """
        self.data = data
        self.offsets = offsets

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(num_sequences={self.num_sequences}, "
            f"total_elements={self.total_elements()}, dtype={getattr(self.data, 'dtype', None)})"
        )

    def get_length(self, i: int) -> int:
        """
        Return the length of the i-th sequence.

        Parameters
        ----------
        i : int
            Index of the sequence.

        Returns
        -------
        int
            Length of the i-th sequence, as a plain Python int.
        """
        # Convert the NumPy scalar difference so the result matches the ``int`` annotation.
        return int(self.offsets[i + 1] - self.offsets[i])

    def get_slice(self, i: int) -> np.ndarray:
        """
        Return a slice of data for the i-th sequence (view).

        Parameters
        ----------
        i : int
            Index of the sequence.

        Returns
        -------
        np.ndarray
            View of the data array for the i-th sequence; no copy is made.
        """
        return self.data[self.offsets[i] : self.offsets[i + 1]]

    def total_elements(self) -> int:
        """
        Return the total number of elements across all sequences.

        Returns
        -------
        int
            Total number of elements in all sequences (``data.size``).
        """
        return self.data.size

    @property
    def num_sequences(self) -> int:
        """
        Return the number of sequences.

        Returns
        -------
        int
            Number of sequences stored in this object (``offsets.size - 1``).
        """
        return self.offsets.size - 1
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def ragged_from_list(data_list: List[np.ndarray], dtype=None) -> RaggedData:
    """
    Create RaggedData from a list of numpy arrays.

    Every input array is copied exactly once into a single preallocated flat
    buffer; sequence boundaries are recorded in an int64 offsets array of length
    ``len(data_list) + 1``.

    Parameters
    ----------
    data_list : List[np.ndarray]
        List of 1-D arrays (or array-likes) of potentially different lengths.
    dtype : data-type, optional
        Data type for the resulting RaggedData. If None, uses the dtype of the
        first array (float32 when ``data_list`` is empty). Elements are cast on
        assignment into the flat buffer.

    Returns
    -------
    RaggedData
        A RaggedData object containing all input arrays, in order.
    """
    # Compare against None explicitly: ``np.dtype`` instances may be falsy (their
    # len() is the number of fields), so ``dtype if dtype else ...`` could silently
    # discard an explicitly requested dtype such as np.dtype("int8").
    if len(data_list) == 0:
        empty_dtype = dtype if dtype is not None else np.float32
        return RaggedData(np.empty(0, dtype=empty_dtype), np.zeros(1, dtype=np.int64))

    if dtype is None:
        # asarray returns ndarrays unchanged and also accepts plain sequences.
        dtype = np.asarray(data_list[0]).dtype

    n = len(data_list)
    lengths = np.fromiter((len(item) for item in data_list), dtype=np.int64, count=n)

    offsets = np.zeros(n + 1, dtype=np.int64)
    np.cumsum(lengths, out=offsets[1:])

    data = np.empty(offsets[-1], dtype=dtype)
    for item, start, stop in zip(data_list, offsets[:-1], offsets[1:]):
        data[start:stop] = item

    return RaggedData(data, offsets)
|