yamcot 1.0.0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yamcot/io.py ADDED
@@ -0,0 +1,522 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ from typing import Iterable, List, Tuple, Union
6
+
7
+ import numpy as np
8
+
9
+ from yamcot.functions import pfm_to_pwm
10
+ from yamcot.ragged import RaggedData
11
+
12
+
13
def read_fasta(path: str) -> RaggedData:
    """Read a FASTA file and return integer-encoded sequences.

    Parameters
    ----------
    path : str
        Path to a FASTA formatted file.

    Returns
    -------
    RaggedData
        All sequences packed into one flat int8 buffer with offsets
        (0=A, 1=C, 2=G, 3=T, 4=N and any other character). Packing into
        RaggedData avoids many small per-sequence allocations and is
        memory-efficient for large files.
    """
    # Translation table for fast bytes -> int8 conversion.
    # Every byte not in ACGT/acgt (including 'N') maps to 4.
    trans_table = bytearray([4] * 256)
    for char, code in zip(b"ACGTacgt", [0, 1, 2, 3] * 2, strict=False):
        trans_table[char] = code

    sequences: List[np.ndarray] = []

    with open(path, "r") as handle:
        current_seq_bytes = bytearray()
        for line in handle:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                # Header line: flush the record accumulated so far.
                if current_seq_bytes:
                    encoded = np.frombuffer(current_seq_bytes.translate(trans_table), dtype=np.int8).copy()
                    sequences.append(encoded)
                    current_seq_bytes.clear()
            else:
                current_seq_bytes.extend(line.encode("ascii", errors="ignore"))

        # Flush the last record (files do not end with a header line).
        if current_seq_bytes:
            encoded = np.frombuffer(current_seq_bytes.translate(trans_table), dtype=np.int8).copy()
            sequences.append(encoded)

    if not sequences:
        return RaggedData(np.empty(0, dtype=np.int8), np.zeros(1, dtype=np.int64))

    # Pack into RaggedData: offsets via a single cumulative sum, data via
    # one concatenate instead of Python-level copy loops.
    lengths = np.fromiter((len(seq) for seq in sequences), dtype=np.int64, count=len(sequences))
    offsets = np.zeros(len(sequences) + 1, dtype=np.int64)
    np.cumsum(lengths, out=offsets[1:])
    total_data = np.concatenate(sequences)

    return RaggedData(total_data, offsets)
75
+
76
+
77
def write_fasta(sequences: Union[RaggedData, Iterable[np.ndarray]], path: str) -> None:
    """Write integer-encoded sequences to a FASTA file.

    Each record gets a numeric header (">0", ">1", ...).

    Parameters
    ----------
    sequences : Union[RaggedData, Iterable[np.ndarray]]
        RaggedData object or a collection of integer-encoded sequences
        (0=A, 1=C, 2=G, 3=T, 4=N). Codes outside [0, 4] are clipped into
        that range before decoding.
    path : str
        Path to the output file.
    """
    # Lookup table for converting integer codes back to symbols.
    decoder = np.array(["A", "C", "G", "T", "N"], dtype="U1")

    # Normalize both accepted input kinds to a single iterable so the
    # write loop is not duplicated.
    if isinstance(sequences, RaggedData):
        seq_iter: Iterable[np.ndarray] = (sequences.get_slice(i) for i in range(sequences.num_sequences))
    else:
        seq_iter = sequences

    with open(path, "w") as out:
        for idx, seq_int in enumerate(seq_iter):
            # Clip into the decoder's valid index range to avoid IndexError.
            safe_seq = np.clip(seq_int, 0, 4)
            seq_str = "".join(decoder[safe_seq])
            out.write(f">{idx}\n")
            out.write(f"{seq_str}\n")
106
+
107
+
108
def read_meme(path: str, index: int = 0) -> Tuple[np.ndarray, Tuple[str, int], int]:
    """Read one motif from a MEME formatted file.

    Parameters
    ----------
    path : str
        Path to the MEME file.
    index : int, default 0
        Zero-based index of the motif to extract.

    Returns
    -------
    Tuple[np.ndarray, Tuple[str, int], int]
        The motif matrix of shape (4, L), a (name, length) pair, and the
        total number of motifs present in the file.

    Raises
    ------
    ValueError
        If the file contains no motifs at all.
    IndexError
        If ``index`` is outside the range of motifs in the file.
    """
    found_matrix: np.ndarray | None = None
    found_meta: Tuple[str, int] | None = None
    total = 0

    with open(path) as fh:
        current = fh.readline()
        while current:
            if not current.startswith("MOTIF"):
                current = fh.readline()
                continue

            wanted = total == index
            total += 1
            motif_name = current.strip().split()[1]

            # The following line holds "letter-probability matrix: ... w= <len> ..."
            header_tokens = fh.readline().strip().split()
            try:
                motif_len = int(header_tokens[header_tokens.index("w=") + 1])
            except (ValueError, IndexError):
                motif_len = 0

            if wanted:
                rows = []
                for _ in range(motif_len):
                    tokens = fh.readline().strip().split()
                    if tokens:
                        rows.append([float(tok) for tok in tokens])
                # Transpose so columns become positions: shape (4, L).
                found_matrix = np.array(rows, dtype=np.float32).T
                found_meta = (motif_name, motif_len)
            else:
                # Skip the probability rows of motifs we do not need.
                for _ in range(motif_len):
                    fh.readline()

            current = fh.readline()

    if found_matrix is None:
        if total == 0:
            raise ValueError(f"No motifs found in {path}")
        raise IndexError(f"Motif index {index} out of range. File contains {total} motifs.")

    # found_meta is always assigned together with found_matrix.
    assert found_meta is not None

    return found_matrix, found_meta, total
181
+
182
+
183
+ def write_meme(motifs: List[np.ndarray], info: List[Tuple[str, int]], path: str) -> None:
184
+ """Write a list of motifs to a MEME formatted file.
185
+
186
+ Parameters
187
+ ----------
188
+ motifs : List[np.ndarray]
189
+ List of motif matrices of shape (5, L). Only the first four rows
190
+ (A, C, G, T) are written; the fifth row is ignored.
191
+ info : List[Tuple[str, int]]
192
+ A list of (name, length) tuples corresponding to the motifs.
193
+ path : str
194
+ Path of the output file.
195
+ """
196
+ with open(path, "w") as out:
197
+ out.write("MEME version 4\n\n")
198
+ out.write("ALPHABET= ACGT\n\n")
199
+ out.write("strands: + -\n\n")
200
+ out.write("Background letter frequencies\n")
201
+ out.write("A 0.25 C 0.25 G 0.25 T 0.25\n\n")
202
+ for motif, (name, length) in zip(motifs, info, strict=False):
203
+ out.write(f"MOTIF {name}\n")
204
+ out.write(f"letter-probability matrix: alength= 4 w= {length}\n")
205
+ for row in motif[:4].T:
206
+ out.write(" " + " ".join(f"{val:.6f}" for val in row) + "\n")
207
+ out.write("\n")
208
+
209
+
210
def read_sitega(path: str) -> tuple[np.ndarray, int, float, float]:
    """Parse a SiteGA output file into a dinucleotide-dependency matrix.

    Parameters
    ----------
    path : str
        Path to the SiteGA output file (typically ends with '.mat').

    Returns
    -------
    tuple[np.ndarray, int, float, float]
        A tuple of: SiteGA matrix of shape (5, 5, length), motif length,
        minimum score value, maximum score value.
    """
    nuc_index = {"A": 0, "C": 1, "G": 2, "T": 3}
    with open(path) as fh:
        _ = fh.readline().strip()  # model name (unused)
        _ = int(fh.readline().strip().split()[0])  # LPD count (unused)
        length = int(fh.readline().strip().split()[0])
        minimum = float(fh.readline().strip().split()[0])
        maximum = float(fh.readline().strip().split()[0])

        sitega = np.zeros((5, 5, length), dtype=np.float32)
        # Remaining lines: "start stop value <flag> dinucleotide".
        for record in fh:
            start_s, stop_s, value_s, _flag, dinuc = record.strip().split()
            dinuc = dinuc.upper()
            first, second = nuc_index[dinuc[0]], nuc_index[dinuc[1]]
            start, stop = int(start_s), int(stop_s)
            span = stop - start + 1
            # Spread the segment's total value evenly across its positions.
            sitega[first, second, start : stop + 1] += float(value_s) / span
    return np.array(sitega, dtype=np.float32), length, minimum, maximum
243
+
244
+
245
def parse_file_content(filepath: str) -> tuple[dict[int, list[np.ndarray]], int, int]:
    """Parse BaMM file content, skipping '#' comment lines.

    Parameters
    ----------
    filepath : str
        Path to the BaMM file to parse.

    Returns
    -------
    tuple[dict[int, list[np.ndarray]], int, int]
        A tuple of: mapping from order index to per-position coefficient
        arrays, maximum order found in the file, and the number of
        positions (motif length).

    Raises
    ------
    FileNotFoundError
        If the specified file does not exist.
    ValueError
        If the file holds no usable data or block orders are inconsistent.
    """
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"File {filepath} not found")

    with open(filepath, "r") as fh:
        text = fh.read()

    # Blocks (one per motif position) are separated by blank lines.
    parsed_blocks: list[list[np.ndarray]] = []
    for block in text.strip().split("\n\n"):
        rows: list[np.ndarray] = []
        for raw_line in block.strip().split("\n"):
            stripped = raw_line.strip()
            # Drop blank lines and '#' comments.
            if not stripped or stripped.startswith("#"):
                continue
            tokens = stripped.split()
            if not tokens:
                continue
            rows.append(np.array([float(tok) for tok in tokens], dtype=np.float32))
        if rows:
            parsed_blocks.append(rows)

    if not parsed_blocks:
        raise ValueError(f"No valid data found in {filepath}")

    num_positions = len(parsed_blocks)
    # The first block defines the expected order for the whole file.
    max_order = len(parsed_blocks[0]) - 1

    # Regroup from per-position blocks into per-order lists.
    data_by_order: dict[int, list[np.ndarray]] = {}
    for order in range(max_order + 1):
        per_position = []
        for pos, block_rows in enumerate(parsed_blocks):
            if len(block_rows) <= order:
                raise ValueError(f"Inconsistent orders in block {pos}")
            per_position.append(block_rows[order])
        data_by_order[order] = per_position

    return data_by_order, max_order, num_positions
313
+
314
+
315
def read_bamm(motif_path: str, bg_path: str, target_order: int) -> np.ndarray:
    """Read BaMM files, apply ramp-up logic, and add padding for 'N' (index 4).

    Reads the motif and background BaMM files, computes per-position
    log2-odds ratios, broadcasts lower-order coefficients up to the target
    order ("ramp-up"), and pads the alphabet with an 'N' entry (index 4)
    whose score is the positional minimum over all ACGT contexts.

    Parameters
    ----------
    motif_path : str
        Path to the motif BaMM file (.ihbcp format).
    bg_path : str
        Path to the background BaMM file (.hbcp format).
    target_order : int
        Target order for the BaMM model (determines tensor dimensions).
        If it exceeds the maximum order available in the files it is
        clamped to that maximum and a warning is logged.

    Returns
    -------
    np.ndarray
        Tensor of shape (5, 5, ..., 5, Length) with target_order + 2
        dimensions. The final axis is the motif position; the leading axes
        are nucleotide contexts including 'N' padding at index 4.
    """
    # 1. Parse both files; the usable order is limited by whichever is lower.
    motif_raw, max_order_file, motif_length = parse_file_content(motif_path)
    bg_raw, max_order_bg, _ = parse_file_content(bg_path)

    if max_order_file > max_order_bg:
        max_order_file = max_order_bg
    if target_order > max_order_file:
        # Remember the requested order BEFORE clamping so the warning
        # reports the caller's value, not the already-clamped one.
        requested_order = target_order
        target_order = max_order_file
        logger = logging.getLogger(__name__)
        logger.warning(
            "Target order %d exceeds file max order %d, target order set as max order",
            requested_order,
            max_order_file,
        )

    # 2. Build the pure ACGT tensor first so the positional minima used for
    # 'N' padding are computed over valid contexts only.
    acgt_slices = []

    for pos in range(motif_length):
        # Ramp-up: early positions lack full context, so use order = pos.
        current_k = min(pos, target_order)

        p_motif = motif_raw[current_k][pos]
        p_bg = bg_raw[current_k][0]

        # Epsilon guards against log(0) and division by zero.
        epsilon = 1e-10
        log_odds = np.log2((p_motif + epsilon) / (p_bg + epsilon))

        # Reshape flat coefficients into a (4, ..., 4) context tensor.
        shape_k = [4] * (current_k + 1)
        tensor_k = log_odds.reshape(shape_k)

        if current_k < target_order:
            # Broadcast lower-order scores across the missing leading
            # context dimensions (ramp-up).
            missing_dims = target_order - current_k
            expand_shape = [1] * missing_dims + shape_k
            tensor_expanded = tensor_k.reshape(expand_shape)
            target_shape_4 = [4] * (target_order + 1)
            tensor_final = np.broadcast_to(tensor_expanded, target_shape_4).copy()
        else:
            tensor_final = tensor_k

        acgt_slices.append(tensor_final)

    # Stack to (4, 4, ..., 4, Length); contains only valid ACGT scores.
    acgt_tensor = np.stack(acgt_slices, axis=-1)

    # 3. Create the padded (5, 5, ..., 5, Length) tensor.
    # Positional minimum over all ACGT context axes (everything but the
    # last, which is the position axis).
    reduce_axes = tuple(range(target_order + 1))
    min_scores_per_pos = np.min(acgt_tensor, axis=reduce_axes)  # shape (Length,)

    new_shape = [5] * (target_order + 1) + [motif_length]

    # Initialize with the positional minima; a (Length,) array broadcasts
    # from the last dimension over all leading axes.
    final_tensor = np.ones(new_shape, dtype=np.float32) * min_scores_per_pos

    # 4. Copy the real ACGT scores into the [0:4, ..., 0:4, :] sub-block.
    slice_objs = [slice(0, 4)] * (target_order + 1) + [slice(None)]
    final_tensor[tuple(slice_objs)] = acgt_tensor

    # Return as contiguous float32 array for Numba.
    return np.array(final_tensor, dtype=np.float32)
411
+
412
+
413
def write_sitega(motif, path: str) -> None:
    """Write SiteGA motif to a file in the .mat format understood by mco_prc.exe.

    Parameters
    ----------
    motif : SitegaMotif
        The SiteGA motif to write.
    path : str
        Path to the output file.
    """
    matrix = motif.matrix
    nucleotides = {0: "A", 1: "C", 2: "G", 3: "T"}

    # Collect runs of equal non-zero values per dinucleotide track as
    # (start, stop, value, dinucleotide) segments in a single pass.
    segments = []

    for first in range(4):
        for second in range(4):
            track = matrix[first, second, :]
            # Skip tracks that are numerically all-zero.
            if np.all(np.abs(track) <= 1e-9):
                continue

            dinucl = nucleotides[first] + nucleotides[second]
            pos = 0
            while pos < motif.length:
                # Advance past (numerically) zero entries.
                while pos < motif.length and abs(track[pos]) <= 1e-9:
                    pos += 1
                if pos >= motif.length:
                    break

                # Start of a non-zero run.
                run_start = pos
                run_value = track[pos]

                # Extend the run while consecutive values match.
                while pos + 1 < motif.length and abs(track[pos + 1] - run_value) < 1e-9:
                    pos += 1

                segments.append({"start": run_start, "stop": pos, "val": run_value, "dinucl": dinucl})
                pos += 1

    # The LPD count is simply the number of collected segments.
    lpd_count = len(segments)

    with open(path, "w") as out:
        out.write(f"{motif.name}\n")
        out.write(f"{lpd_count}\tLPD count\n")
        out.write(f"{motif.length}\tModel length\n")
        out.write(f"{motif.minimum:.12f}\tMinimum\n")
        out.write(f"{motif.maximum:.12f}\tRazmah\n")

        # Each segment stores the per-position value; the file stores the
        # total over the range, so multiply back by the run length.
        for seg in segments:
            span = seg["stop"] - seg["start"] + 1
            total_value = seg["val"] * span
            out.write(f"{seg['start']}\t{seg['stop']}\t{total_value:.12f}\t0\t{seg['dinucl'].lower()}\n")
476
+
477
+
478
def write_pfm(pfm: np.ndarray, name: str, length: int, path: str) -> None:
    """Write a Position Frequency Matrix to a file.

    Parameters
    ----------
    pfm : np.ndarray
        Position frequency matrix of shape (4, length).
    name : str
        Motif name, written as a FASTA-style '>' header.
    length : int
        Length of the motif (kept for the caller's interface; not used
        for formatting).
    path : str
        Path to the output file.
    """
    with open(path, "w") as out:
        out.write(f">{name}\n")
        # Transpose so each output line is one motif position (tab-separated).
        np.savetxt(out, pfm.T, fmt="%.6f", delimiter="\t")
496
+
497
+
498
def read_pfm(path: str) -> tuple[np.ndarray, int, float, float]:
    """Read a Position Frequency Matrix (PFM) from a file and convert it to a PWM.

    Parameters
    ----------
    path : str
        Path to the PFM file ('>' header lines are treated as comments).

    Returns
    -------
    tuple[np.ndarray, int, float, float]
        A tuple of: PWM with shape (5, L) whose 5th row holds the
        per-column minima, motif length, minimum possible score, and
        maximum possible score.
    """
    frequencies = np.loadtxt(path, comments=">").T
    motif_length = frequencies.shape[1]
    weights = pfm_to_pwm(frequencies)
    # Score bounds: sum of the worst/best letter score at every position.
    lowest = np.sum(weights.min(axis=0))
    highest = np.sum(weights.max(axis=0))
    # Append per-column minima as a 5th row (presumably the score used for
    # ambiguous 'N' bases, mirroring the padding in read_bamm — verify).
    column_minima = np.min(weights, axis=0).reshape(1, weights.shape[1])
    weights = np.concatenate((weights, column_minima), axis=0)
    weights = np.array(weights, dtype=np.float32)
    return weights, motif_length, lowest, highest