edgepython 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edgepython/io.py ADDED
@@ -0,0 +1,1887 @@
1
+ # This code was written by Claude (Anthropic). The project was directed by Lior Pachter.
2
+ """
3
+ I/O functions for edgePython.
4
+
5
+ Port of edgeR's readDGE, read10X, featureCountsToMatrix,
6
+ catchSalmon, catchKallisto, catchRSEM, catchOarfish.
7
+ """
8
+
9
+ import os
10
+ import warnings
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+
15
def read_dge(files, path=None, columns=(0, 1), group=None, labels=None, sep='\t'):
    """Read and collate count data files into a single DGEList.

    Port of edgeR's readDGE. Each file contributes one column of counts;
    the union of gene IDs across all files (in first-seen order) forms
    the rows, and genes absent from a file are filled with zero.

    Parameters
    ----------
    files : list of str or DataFrame
        File names, or a DataFrame with a 'files' column.
    path : str, optional
        Path prefix joined to each file name.
    columns : tuple of int
        Column indices (0-indexed) of the gene-ID and count columns.
    group : array-like, optional
        Group factor for the samples.
    labels : list of str, optional
        Sample labels; defaults to the file base names.
    sep : str
        Field separator.

    Returns
    -------
    DGEList-like dict with 'counts', 'samples' and 'genes'.

    Raises
    ------
    ValueError
        If a DataFrame input lacks a 'files' column, or any file contains
        repeated gene identifiers.
    """
    from .dgelist import make_dgelist

    if isinstance(files, pd.DataFrame):
        samples = files.copy()
        if labels is not None:
            samples.index = labels
        if 'files' not in samples.columns:
            raise ValueError("file names not found")
        file_list = samples['files'].astype(str).tolist()
    else:
        file_list = [str(f) for f in files]
        if labels is None:
            labels = [os.path.splitext(os.path.basename(f))[0] for f in file_list]
        samples = pd.DataFrame({'files': file_list}, index=labels)

    nfiles = len(file_list)

    if group is not None:
        samples['group'] = group
    if 'group' not in samples.columns:
        samples['group'] = 1

    # Read each file once. Data is stored per position rather than keyed
    # by file name, so a repeated file name cannot silently overwrite
    # another sample's data (the previous dict-keyed version collided).
    all_tags = []
    all_counts = []
    for fn in file_list:
        fp = os.path.join(path, fn) if path is not None else fn
        df = pd.read_csv(fp, sep=sep, header=0)
        tag_col = df.columns[columns[0]]
        count_col = df.columns[columns[1]]
        tags = df[tag_col].astype(str).values
        if len(tags) != len(set(tags)):
            raise ValueError(f"Repeated row names in {fn}. Row names must be unique.")
        all_tags.append(tags)
        all_counts.append(df[count_col].values)

    # Union of gene IDs in first-seen order across all files.
    all_gene_ids = []
    seen = set()
    for tags in all_tags:
        for t in tags:
            if t not in seen:
                all_gene_ids.append(t)
                seen.add(t)

    ntags = len(all_gene_ids)
    counts = np.zeros((ntags, nfiles), dtype=np.float64)
    tag_to_idx = {t: i for i, t in enumerate(all_gene_ids)}

    # Vectorized scatter of each sample's counts into the collated matrix;
    # genes absent from a file keep their zero fill.
    for i in range(nfiles):
        row_idx = [tag_to_idx[t] for t in all_tags[i]]
        counts[row_idx, i] = all_counts[i]

    samples['lib.size'] = counts.sum(axis=0)
    samples['norm.factors'] = 1.0

    return make_dgelist(counts, samples=samples,
                        genes=pd.DataFrame({'GeneID': all_gene_ids}))
100
+
101
+
102
def read_10x(path='.', mtx=None, genes=None, barcodes=None, as_dgelist=True):
    """Read 10X Genomics CellRanger output.

    Port of edgeR's read10X. File names are auto-detected in *path*
    when not given explicitly (gzipped variants are preferred).

    Parameters
    ----------
    path : str
        Directory containing 10X files.
    mtx : str, optional
        Matrix file name.
    genes : str, optional
        Genes/features file name.
    barcodes : str, optional
        Barcodes file name.
    as_dgelist : bool
        Return DGEList-like dict.

    Returns
    -------
    DGEList-like dict or dict with counts/genes/samples.

    Raises
    ------
    FileNotFoundError
        If the matrix or genes/features file cannot be located.
    """
    from scipy.io import mmread

    files = os.listdir(path)

    if mtx is None:
        for candidate in ['matrix.mtx.gz', 'matrix.mtx']:
            if candidate in files:
                mtx = candidate
                break
        if mtx is None:
            raise FileNotFoundError("Can't find matrix.mtx file")

    if genes is None:
        for candidate in ['features.tsv.gz', 'features.tsv', 'genes.tsv.gz', 'genes.tsv']:
            if candidate in files:
                genes = candidate
                break
        if genes is None:
            raise FileNotFoundError("Can't find genes/features file")

    if barcodes is None:
        for candidate in ['barcodes.tsv.gz', 'barcodes.tsv']:
            if candidate in files:
                barcodes = candidate
                break

    mtx_path = os.path.join(path, mtx)
    genes_path = os.path.join(path, genes)

    # Read matrix. mmread returns a sparse matrix for coordinate-format
    # files but a plain ndarray for array-format files, so handle both
    # (the previous `.todense()` call failed on ndarray input).
    mat = mmread(mtx_path)
    if hasattr(mat, 'toarray'):
        mat = mat.toarray()
    y = np.asarray(mat, dtype=np.float64)

    # Read gene info (no header row in 10X features/genes files)
    gene_df = pd.read_csv(genes_path, sep='\t', header=None)
    if gene_df.shape[1] >= 2:
        gene_df.columns = ['GeneID', 'Symbol'] + [f'col{i}' for i in range(2, gene_df.shape[1])]
    else:
        gene_df.columns = ['GeneID']

    # Read barcodes (optional; samples_df stays None when absent)
    samples_df = None
    if barcodes is not None:
        barcodes_path = os.path.join(path, barcodes)
        bc = pd.read_csv(barcodes_path, sep='\t', header=None)
        samples_df = pd.DataFrame({'Barcode': bc.iloc[:, 0].values})

    if as_dgelist:
        from .dgelist import make_dgelist
        return make_dgelist(y, genes=gene_df, samples=samples_df)

    return {'counts': y, 'genes': gene_df, 'samples': samples_df}
176
+
177
+
178
def catch_salmon(paths, verbose=True):
    """Read Salmon quantification output.

    Thin wrapper over the shared Salmon/kallisto reader with the tool
    fixed to 'salmon' (reads quant.sf from each directory).

    Parameters
    ----------
    paths : list of str
        Paths to Salmon output directories.
    verbose : bool
        Print progress.

    Returns
    -------
    dict with counts, annotation, and samples.
    """
    return _catch_quant(paths=paths, tool='salmon', verbose=verbose)
193
+
194
+
195
def catch_kallisto(paths, verbose=True):
    """Read kallisto quantification output.

    Thin wrapper over the shared Salmon/kallisto reader with the tool
    fixed to 'kallisto' (reads abundance.tsv from each directory).

    Parameters
    ----------
    paths : list of str
        Paths to kallisto output directories.
    verbose : bool
        Print progress.

    Returns
    -------
    dict with counts, annotation, and samples.
    """
    return _catch_quant(paths=paths, tool='kallisto', verbose=verbose)
210
+
211
+
212
def catch_rsem(files, verbose=True):
    """Read RSEM quantification output.

    Collates one count column per input file into a genes-by-samples
    matrix. The count column is 'expected_count' when present, then
    'FPKM', otherwise the second column of the file; gene IDs are taken
    from the first column of the first file.

    Parameters
    ----------
    files : list of str
        RSEM output files.
    verbose : bool
        Print progress.

    Returns
    -------
    dict with counts and annotation.
    """
    sample_columns = []
    ids = None

    for fname in files:
        table = pd.read_csv(fname, sep='\t')
        if 'expected_count' in table.columns:
            chosen = 'expected_count'
        elif 'FPKM' in table.columns:
            chosen = 'FPKM'
        else:
            chosen = table.columns[1]

        if ids is None:
            ids = table.iloc[:, 0].values
        sample_columns.append(table[chosen].values)

    matrix = np.column_stack(sample_columns)
    names = [os.path.splitext(os.path.basename(fname))[0] for fname in files]

    return {
        'counts': matrix,
        'annotation': pd.DataFrame({'GeneID': ids}),
        'samples': pd.DataFrame({'files': files}, index=names),
    }
250
+
251
+
252
def feature_counts_to_dgelist(files):
    """Convert featureCounts output to DGEList.

    Parameters
    ----------
    files : str or list of str
        featureCounts output file(s). Each file may carry one or more
        count columns after the standard annotation columns.

    Returns
    -------
    DGEList-like dict with gene annotation, counts, and sample labels
    taken from the featureCounts count-column headers.
    """
    from .dgelist import make_dgelist

    if isinstance(files, str):
        files = [files]

    # Standard featureCounts annotation columns; everything else in the
    # table is treated as a per-sample count column.
    meta_cols = ['Geneid', 'Chr', 'Start', 'End', 'Strand', 'Length']

    all_counts = []
    gene_info = None
    sample_names = []

    for f in files:
        df = pd.read_csv(f, sep='\t', comment='#')
        # featureCounts format: Geneid, Chr, Start, End, Strand, Length, Count1, Count2, ...
        # Gene annotation is taken from the first file only.
        if gene_info is None:
            gene_cols = [c for c in meta_cols if c in df.columns]
            if gene_cols:
                gene_info = df[gene_cols].copy()
            else:
                gene_info = df.iloc[:, :1].copy()

        count_cols = [c for c in df.columns if c not in meta_cols]
        for col in count_cols:
            all_counts.append(df[col].values)
            sample_names.append(col)

    counts = np.column_stack(all_counts)
    # Previously sample_names was collected but never used, losing the
    # sample labels; pass them through as the samples table.
    samples = pd.DataFrame({'files': sample_names}, index=sample_names)
    return make_dgelist(counts, genes=gene_info, samples=samples)
292
+
293
+
294
def _catch_quant(paths, tool='salmon', verbose=True):
    """Internal function to read Salmon/kallisto output (legacy).

    Reads the plain-text quantification table from each directory in
    *paths* and collates the counts into a transcripts-by-samples
    matrix. Annotation (IDs and effective lengths) is taken from the
    first sample only.
    """
    # Per-tool file name and column layout.
    if tool == 'salmon':
        quant_name, id_col, count_col, length_col = (
            'quant.sf', 'Name', 'NumReads', 'EffectiveLength')
    else:
        quant_name, id_col, count_col, length_col = (
            'abundance.tsv', 'target_id', 'est_counts', 'eff_length')

    columns = []
    labels = []
    gene_ids = None
    eff_length = None

    for p in paths:
        labels.append(os.path.basename(os.path.normpath(p)))

        quant_file = os.path.join(p, quant_name)
        if not os.path.exists(quant_file):
            raise FileNotFoundError(f"Cannot find {quant_file}")

        df = pd.read_csv(quant_file, sep='\t')

        if gene_ids is None:
            gene_ids = df[id_col].values
            if length_col in df.columns:
                eff_length = df[length_col].values

        columns.append(df[count_col].values)

        if verbose:
            print(f"Reading {labels[-1]}...")

    counts = np.column_stack(columns)
    annotation = pd.DataFrame({'GeneID': gene_ids})
    if eff_length is not None:
        annotation['Length'] = eff_length

    return {
        'counts': counts,
        'annotation': annotation,
        'samples': pd.DataFrame({'files': paths}, index=labels),
    }
345
+
346
+
347
+ # =====================================================================
348
+ # Overdispersion estimation (shared core)
349
+ # =====================================================================
350
+
351
def _accumulate_overdispersion(boot, overdisp, df_arr):
    """Fold one sample's bootstrap counts into the running overdispersion sums.

    Shared per-sample accumulation step of edgeR's catchSalmon,
    catchKallisto, catchRSEM, catchOarfish ports. For every transcript
    whose bootstrap mean is positive, adds sum((boot - mean)^2) / mean
    to *overdisp* and (n_boot - 1) to *df_arr*; both arrays are updated
    in-place.

    Parameters
    ----------
    boot : ndarray, shape (n_tx, n_boot)
        Bootstrap count matrix for one sample.
    overdisp : ndarray, shape (n_tx,)
        Running overdispersion accumulator (modified in-place).
    df_arr : ndarray, shape (n_tx,)
        Running degrees of freedom accumulator (modified in-place).
    """
    n_boot = boot.shape[1]
    means = boot.mean(axis=1)
    keep = means > 0
    deviations = boot[keep] - means[keep][:, None]
    overdisp[keep] += (deviations ** 2).sum(axis=1) / means[keep]
    df_arr[keep] += n_boot - 1
371
+
372
+
373
def _estimate_overdispersion(overdisp, df_arr):
    """Estimate per-transcript overdispersion with moderate shrinkage.

    Port of the overdispersion finalization shared across all edgeR
    catch* functions. Applies limited moderation with DFPrior=3.

    Parameters
    ----------
    overdisp : ndarray, shape (n_tx,)
        Accumulated sum of (Boot - M)^2 / M across samples.
        Note: this array is also modified in-place before being returned.
    df_arr : ndarray, shape (n_tx,)
        Accumulated degrees of freedom.

    Returns
    -------
    overdisp_final : ndarray, shape (n_tx,)
        Moderated overdispersion estimates (>= 1).
    overdisp_prior : float
        Prior overdispersion value used for shrinkage.
    """
    from scipy.stats import f as f_dist

    # Transcripts that received at least one bootstrap/Gibbs observation.
    pos = df_arr > 0
    n_pos = np.sum(pos)

    if n_pos > 0:
        # Convert accumulated sums of squares into mean overdispersions.
        overdisp[pos] = overdisp[pos] / df_arr[pos]

        # Prior overdispersion: median estimate rescaled by the median of
        # an F distribution (dfn = median residual df, dfd = DFPrior),
        # floored at 1 so shrinkage never pushes estimates below 1.
        df_median = float(np.median(df_arr[pos]))
        df_prior = 3.0
        overdisp_prior = float(np.median(overdisp[pos])) / f_dist.ppf(0.5, dfn=df_median, dfd=df_prior)
        if overdisp_prior < 1.0:
            overdisp_prior = 1.0

        # Shrink each estimate toward the prior, weighted by degrees of freedom.
        overdisp[pos] = (df_prior * overdisp_prior + df_arr[pos] * overdisp[pos]) / (df_prior + df_arr[pos])
        # np.maximum allocates a new array, so in-place mutation of the
        # caller's array stops at this point.
        overdisp = np.maximum(overdisp, 1.0)
        # Transcripts with no bootstrap data fall back to the prior value.
        overdisp[~pos] = overdisp_prior
    else:
        # No bootstrap information at all: nothing can be estimated.
        overdisp[:] = np.nan
        overdisp_prior = np.nan

    return overdisp, overdisp_prior
415
+
416
+
417
+ # =====================================================================
418
+ # Format-specific readers for read_data()
419
+ # =====================================================================
420
+
421
def _read_kallisto_h5(paths, verbose):
    """Read kallisto H5 output with bootstrap overdispersion.

    Port of edgeR's catchKallisto.

    Parameters
    ----------
    paths : list of str
        Kallisto output directories, each containing an abundance.h5.
    verbose : bool
        Print one progress line per sample.

    Returns
    -------
    tuple
        (counts, annotation, ids, overdisp_prior): transcripts-by-samples
        count matrix, annotation DataFrame indexed by transcript ID,
        transcript IDs, and the prior overdispersion (None when no sample
        carried bootstrap replicates).

    Raises
    ------
    ImportError
        If h5py is not installed.
    FileNotFoundError
        If abundance.h5 is missing from any directory.
    """
    try:
        import h5py
    except ImportError:
        raise ImportError(
            "h5py package required for kallisto H5 format. "
            "Install with: pip install h5py"
        )

    n_samples = len(paths)
    counts = None  # allocated on the first sample, once n_tx is known
    overdisp = None  # running sum of (boot - mean)^2 / mean per transcript
    df_arr = None  # running bootstrap degrees of freedom per transcript
    ids = None
    lengths = None
    eff_lengths = None

    for j, p in enumerate(paths):
        h5_file = os.path.join(p, 'abundance.h5')
        if not os.path.exists(h5_file):
            raise FileNotFoundError(f"abundance.h5 not found in {p}")

        with h5py.File(h5_file, 'r') as f:
            # Transcript count and bootstrap count come from the 'aux' group
            # of the kallisto HDF5 layout.
            n_tx = len(f['aux']['lengths'])
            n_boot = int(np.asarray(f['aux']['num_bootstrap']).flat[0])

            if verbose:
                label = os.path.basename(os.path.normpath(p))
                print(f"Reading {label}, {n_tx} transcripts, {n_boot} bootstraps")

            # NOTE(review): assumes every sample has the same n_tx and
            # transcript order as the first one — confirm upstream.
            if j == 0:
                counts = np.zeros((n_tx, n_samples), dtype=np.float64)
                overdisp = np.zeros(n_tx, dtype=np.float64)
                df_arr = np.zeros(n_tx, dtype=np.int64)

            # Store annotation from each sample (R uses the last sample's aux)
            raw_ids = f['aux']['ids'][:]
            ids = np.array([s.decode('utf-8') if isinstance(s, bytes) else str(s)
                            for s in raw_ids])
            lengths = np.asarray(f['aux']['lengths'][:], dtype=np.int64)
            eff_lengths = np.asarray(f['aux']['eff_lengths'][:], dtype=np.float64)

            counts[:, j] = f['est_counts'][:]

            # Each bootstrap replicate is stored as a separate dataset
            # bs0, bs1, ... under the 'bootstrap' group.
            if n_boot > 0 and 'bootstrap' in f:
                boot = np.column_stack([f['bootstrap'][f'bs{k}'][:] for k in range(n_boot)])
                _accumulate_overdispersion(boot, overdisp, df_arr)

    has_bootstraps = np.any(df_arr > 0)
    ann_dict = {'Length': lengths, 'EffectiveLength': eff_lengths}
    if has_bootstraps:
        overdisp_final, overdisp_prior = _estimate_overdispersion(overdisp, df_arr)
        ann_dict['Overdispersion'] = overdisp_final
    else:
        overdisp_prior = None
    annotation = pd.DataFrame(ann_dict, index=ids)

    return counts, annotation, ids, overdisp_prior
483
+
484
+
485
def _read_kallisto_tsv(paths, verbose):
    """Read kallisto plain-text output (no bootstraps).

    Parameters
    ----------
    paths : list of str
        Kallisto output directories, each containing an abundance.tsv.
    verbose : bool
        Print a progress line per sample.

    Returns
    -------
    tuple
        (counts, annotation, ids, overdisp_prior) where counts is a
        transcripts-by-samples float array, annotation holds Length and
        EffectiveLength indexed by transcript ID, and overdisp_prior is
        always None (TSV output carries no bootstrap samples).

    Raises
    ------
    FileNotFoundError
        If any directory lacks an abundance.tsv file.
    """
    all_data = []
    ids = None
    lengths = None
    eff_lengths = None

    for j, p in enumerate(paths):
        tsv_file = os.path.join(p, 'abundance.tsv')
        if not os.path.exists(tsv_file):
            raise FileNotFoundError(f"abundance.tsv not found in {p}")

        if verbose:
            label = os.path.basename(os.path.normpath(p))
            print(f"Reading {label}...")

        df = pd.read_csv(tsv_file, sep='\t')
        # Annotation is taken from the first sample only; later samples
        # are assumed to share the same transcript order.
        if j == 0:
            ids = df['target_id'].values.astype(str)
            lengths = df['length'].values.astype(np.int64)
            eff_lengths = df['eff_length'].values.astype(np.float64)
        all_data.append(df['est_counts'].values.astype(np.float64))

    counts = np.column_stack(all_data)
    annotation = pd.DataFrame({
        'Length': lengths,
        'EffectiveLength': eff_lengths,
    }, index=ids)

    return counts, annotation, ids, None
516
+
517
+
518
def _read_kallisto(paths, fmt, verbose):
    """Read kallisto output, choosing the H5 or TSV reader.

    When *fmt* is None the format is auto-detected: H5 is used if the
    first sample directory contains an abundance.h5 file, TSV otherwise.
    """
    if fmt is None:
        has_h5 = os.path.exists(os.path.join(paths[0], 'abundance.h5'))
        fmt = 'h5' if has_h5 else 'tsv'

    reader = _read_kallisto_h5 if fmt == 'h5' else _read_kallisto_tsv
    return reader(paths, verbose)
528
+
529
+
530
def _read_salmon(paths, verbose):
    """Read Salmon output with bootstrap overdispersion.

    Port of edgeR's catchSalmon.

    Parameters
    ----------
    paths : list of str
        Salmon output directories, each containing quant.sf and
        optionally aux_info/meta_info.json and bootstrap data.
    verbose : bool
        Print one progress line per sample.

    Returns
    -------
    tuple
        (counts, annotation, ids, overdisp_prior): transcripts-by-samples
        count matrix, annotation DataFrame indexed by transcript ID,
        transcript IDs, and the prior overdispersion (None when no sample
        carried bootstrap replicates).
    """
    import json
    import gzip

    n_samples = len(paths)
    counts = None  # allocated on the first sample, once n_tx is known
    overdisp = None  # running sum of (boot - mean)^2 / mean per transcript
    df_arr = None  # running bootstrap degrees of freedom per transcript
    ids = None
    lengths = None
    eff_lengths = None

    for j, p in enumerate(paths):
        label = os.path.basename(os.path.normpath(p))
        quant_file = os.path.join(p, 'quant.sf')
        meta_file = os.path.join(p, 'aux_info', 'meta_info.json')
        boot_file = os.path.join(p, 'aux_info', 'bootstrap', 'bootstraps.gz')

        if not os.path.exists(quant_file):
            raise FileNotFoundError(f"quant.sf not found in {p}")

        # Read meta info for bootstrap count
        n_boot = 0
        n_tx_meta = None
        if os.path.exists(meta_file):
            with open(meta_file) as mf:
                meta = json.load(mf)
            # NOTE(review): n_tx_meta is read but never used below;
            # transcript count comes from quant.sf instead.
            n_tx_meta = meta.get('num_targets') or meta.get('num_valid_targets')
            n_boot = meta.get('num_bootstraps', 0)
            samp_type = meta.get('samp_type', 'bootstrap')
        else:
            samp_type = 'bootstrap'

        quant_df = pd.read_csv(quant_file, sep='\t')
        n_tx = len(quant_df)

        if verbose:
            if n_boot > 0:
                print(f"Reading {label}, {n_tx} transcripts, {n_boot} {samp_type} samples")
            else:
                print(f"Reading {label}, {n_tx} transcripts")

        # NOTE(review): assumes every sample has the same n_tx and
        # transcript order as the first one — confirm upstream.
        if j == 0:
            counts = np.zeros((n_tx, n_samples), dtype=np.float64)
            overdisp = np.zeros(n_tx, dtype=np.float64)
            df_arr = np.zeros(n_tx, dtype=np.int64)
            ids = quant_df['Name'].values.astype(str)
            lengths = quant_df['Length'].values.astype(np.int64)
            eff_lengths = quant_df['EffectiveLength'].values.astype(np.float64)

        counts[:, j] = quant_df['NumReads'].values.astype(np.float64)

        # Read binary bootstrap samples
        if n_boot > 0 and os.path.exists(boot_file):
            with gzip.open(boot_file, 'rb') as bf:
                raw = bf.read()
            # R: readBin(con, "double", n=NTx*NBoot); dim(Boot) <- c(NTx, NBoot)
            # R fills column-major: first NTx values = bootstrap 0, etc.
            # NOTE(review): assumes the file holds exactly n_tx*n_boot
            # float64 values; reshape raises otherwise — confirm against
            # the Salmon serialization format for this version.
            all_vals = np.frombuffer(raw, dtype=np.float64)
            boot = all_vals.reshape((n_tx, n_boot), order='F')
            _accumulate_overdispersion(boot, overdisp, df_arr)

    has_bootstraps = np.any(df_arr > 0)
    ann_dict = {'Length': lengths, 'EffectiveLength': eff_lengths}
    if has_bootstraps:
        overdisp_final, overdisp_prior = _estimate_overdispersion(overdisp, df_arr)
        ann_dict['Overdispersion'] = overdisp_final
    else:
        overdisp_prior = None
    annotation = pd.DataFrame(ann_dict, index=ids)

    return counts, annotation, ids, overdisp_prior
606
+
607
+
608
def _read_oarfish(paths, path, verbose):
    """Read oarfish output with parquet bootstrap overdispersion.

    Port of edgeR's catchOarfish.

    Parameters
    ----------
    paths : list of str or None
        Sample prefixes or full '<prefix>.quant' paths. When None,
        '.quant' files are auto-discovered in *path*.
    path : str or None
        Directory to scan when *paths* is None; defaults to '.'.
    verbose : bool
        Print one progress line per sample.

    Returns
    -------
    tuple
        (counts, annotation, ids, overdisp_prior): transcripts-by-samples
        count matrix, annotation DataFrame indexed by transcript ID,
        transcript IDs, and the prior overdispersion (None when no sample
        carried bootstrap replicates).
    """
    import json

    # If paths is None, auto-discover .quant files in path
    if paths is None:
        if path is None:
            path = '.'
        quant_files = sorted([f for f in os.listdir(path) if f.endswith('.quant')])
        if not quant_files:
            raise FileNotFoundError(f"No oarfish .quant files found in {path}")
        # [:-6] strips the '.quant' suffix to form the sample prefix.
        prefixes = [os.path.join(path, f[:-6]) for f in quant_files]
    else:
        # paths is list of prefixes or full .quant paths
        prefixes = []
        for p in paths:
            if p.endswith('.quant'):
                prefixes.append(p[:-6])
            else:
                prefixes.append(p)

    n_samples = len(prefixes)
    counts = None  # allocated on the first sample, once n_tx is known
    overdisp = None  # running sum of (boot - mean)^2 / mean per transcript
    df_arr = None  # running bootstrap degrees of freedom per transcript
    ids = None
    lengths = None

    for j, prefix in enumerate(prefixes):
        quant_file = f"{prefix}.quant"
        meta_file = f"{prefix}.meta_info.json"
        boot_file = f"{prefix}.infreps.pq"

        if not os.path.exists(quant_file):
            raise FileNotFoundError(f"{quant_file} not found")

        n_boot = 0
        if os.path.exists(meta_file):
            with open(meta_file) as mf:
                meta = json.load(mf)
            n_boot = meta.get('num_bootstraps', 0)

        quant_df = pd.read_csv(quant_file, sep='\t')
        n_tx = len(quant_df)

        if verbose:
            label = os.path.basename(prefix)
            print(f"Reading {label}, {n_tx} transcripts, {n_boot} bootstraps")

        # NOTE(review): assumes every sample has the same n_tx and
        # transcript order as the first one — confirm upstream.
        if j == 0:
            counts = np.zeros((n_tx, n_samples), dtype=np.float64)
            overdisp = np.zeros(n_tx, dtype=np.float64)
            df_arr = np.zeros(n_tx, dtype=np.int64)
            ids = quant_df['tname'].values.astype(str)
            lengths = quant_df['len'].values.astype(np.int64)

        counts[:, j] = quant_df['num_reads'].values.astype(np.float64)

        if n_boot > 0 and os.path.exists(boot_file):
            try:
                # NOTE(review): assumes the parquet table is laid out as
                # transcripts (rows) x bootstrap replicates (columns), as
                # required by _accumulate_overdispersion — confirm against
                # the oarfish infreps format.
                boot = pd.read_parquet(boot_file).values.astype(np.float64)
            except ImportError:
                raise ImportError(
                    "pyarrow package required for oarfish parquet bootstraps. "
                    "Install with: pip install pyarrow"
                )
            _accumulate_overdispersion(boot, overdisp, df_arr)

    has_bootstraps = np.any(df_arr > 0)
    ann_dict = {'Length': lengths}
    if has_bootstraps:
        overdisp_final, overdisp_prior = _estimate_overdispersion(overdisp, df_arr)
        ann_dict['Overdispersion'] = overdisp_final
    else:
        overdisp_prior = None
    annotation = pd.DataFrame(ann_dict, index=ids)

    return counts, annotation, ids, overdisp_prior
689
+
690
+
691
def _read_rsem_data(files, path, ngibbs, verbose):
    """Read RSEM output with Gibbs-based overdispersion.

    Port of edgeR's catchRSEM.

    Parameters
    ----------
    files : list of str
        RSEM quantification files.
    path : str or None
        Optional directory prefix joined to each file name.
    ngibbs : int or list of int
        Number of Gibbs samples per file; a scalar is broadcast to all
        files. (NOTE(review): a list shorter than *files* raises
        IndexError — confirm callers always match lengths.)
    verbose : bool
        Print one progress line per sample.

    Returns
    -------
    tuple
        (counts, annotation, ids, overdisp_prior): transcripts-by-samples
        count matrix, annotation DataFrame indexed by transcript ID,
        transcript IDs, and the prior overdispersion (None when no file
        carried posterior Gibbs statistics).
    """
    n_samples = len(files)
    if isinstance(ngibbs, (int, float)):
        ngibbs = [int(ngibbs)] * n_samples

    counts = None  # allocated on the first sample, once n_tx is known
    overdisp = None  # running Gibbs overdispersion accumulator
    df_arr = None  # running degrees of freedom per transcript
    ids = None
    lengths = None
    eff_lengths = None

    for j, f in enumerate(files):
        full_path = os.path.join(path, f) if path else f
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"{full_path} not found")

        quant_df = pd.read_csv(full_path, sep='\t')
        if 'expected_count' not in quant_df.columns:
            raise ValueError(f"File {f} doesn't contain expected_count column")

        n_tx = len(quant_df)
        ng = ngibbs[j]

        if verbose:
            print(f"Reading {os.path.basename(f)}, {n_tx} transcripts, {ng} Gibbs samples")

        # Annotation is taken from the first file; later files are assumed
        # to share the same transcript order.
        if j == 0:
            counts = np.zeros((n_tx, n_samples), dtype=np.float64)
            overdisp = np.zeros(n_tx, dtype=np.float64)
            df_arr = np.zeros(n_tx, dtype=np.int64)
            id_col = 'transcript_id' if 'transcript_id' in quant_df.columns else quant_df.columns[0]
            ids = quant_df[id_col].values.astype(str)
            lengths = quant_df['length'].values.astype(np.int64) if 'length' in quant_df.columns else None
            eff_lengths = quant_df['effective_length'].values.astype(np.float64) if 'effective_length' in quant_df.columns else None

        counts[:, j] = quant_df['expected_count'].values.astype(np.float64)

        # RSEM Gibbs overdispersion: (ngibbs-1) * S^2 / M
        # Unlike the bootstrap readers, RSEM provides posterior mean and
        # standard deviation directly rather than raw replicates.
        M_col = quant_df.get('posterior_mean_count')
        S_col = quant_df.get('posterior_standard_deviation_of_count')
        if M_col is not None and S_col is not None and ng > 0:
            M = M_col.values.astype(np.float64)
            S = S_col.values.astype(np.float64)
            pos = M > 0
            overdisp[pos] += (ng - 1) * (S[pos] ** 2) / M[pos]
            df_arr[pos] += ng - 1

    has_bootstraps = np.any(df_arr > 0)
    ann_dict = {}
    if lengths is not None:
        ann_dict['Length'] = lengths
    if eff_lengths is not None:
        ann_dict['EffectiveLength'] = eff_lengths
    if has_bootstraps:
        overdisp_final, overdisp_prior = _estimate_overdispersion(overdisp, df_arr)
        ann_dict['Overdispersion'] = overdisp_final
    else:
        overdisp_prior = None
    annotation = pd.DataFrame(ann_dict, index=ids)

    return counts, annotation, ids, overdisp_prior
757
+
758
+
759
def _read_anndata(data, group, labels, obs_col, layer, verbose):
    """Read an AnnData object or .h5ad file into a DGEList.

    Parameters
    ----------
    data : str or anndata.AnnData
        Path to an .h5ad file, or an in-memory AnnData object.
    group : array-like or None
        Group factor; when None and *obs_col* is given, taken from obs.
    labels : list of str or None
        Sample labels; default to the AnnData obs_names.
    obs_col : str or None
        Column of adata.obs to use as the group factor.
    layer : str or None
        Layer to read counts from; when None, the 'counts' layer is
        preferred if present, otherwise adata.X.
    verbose : bool
        Print a progress line when reading from file.

    Returns
    -------
    DGEList-like dict with a genes-by-samples count matrix.

    Raises
    ------
    ImportError
        If the anndata package is not installed.
    ValueError
        If *layer* or *obs_col* is not found in the AnnData object.
    """
    try:
        import anndata
    except ImportError:
        raise ImportError(
            "anndata package required for AnnData/.h5ad import. "
            "Install with: pip install anndata"
        )
    from .dgelist import make_dgelist

    if isinstance(data, str):
        if verbose:
            print(f"Reading {data}...")
        adata = anndata.read_h5ad(data)
    else:
        adata = data

    # Extract count matrix: AnnData is obs×var (samples×genes)
    # edgePython needs genes×samples -> transpose
    if layer is not None:
        if layer not in adata.layers:
            raise ValueError(f"Layer '{layer}' not found in AnnData. "
                             f"Available: {list(adata.layers.keys())}")
        X = adata.layers[layer]
    elif 'counts' in adata.layers:
        # Prefer raw counts layer when available (common Scanpy convention)
        X = adata.layers['counts']
    else:
        X = adata.X

    # Handle sparse matrices
    # Duck-typed sparse check: anything exposing toarray() and nnz.
    if hasattr(X, 'toarray') and hasattr(X, 'nnz'):
        shape = X.shape
        nnz = X.nnz
        density = nnz / (shape[0] * shape[1]) if shape[0] * shape[1] > 0 else 0
        warnings.warn(
            f"Densifying sparse AnnData matrix ({shape[0]} x {shape[1]}, "
            f"{100*density:.1f}% non-zero, "
            f"{shape[0] * shape[1] * 8 / 1e6:.0f} MB dense). "
            f"edgePython stores counts as dense arrays.",
            stacklevel=2,
        )
        X = X.toarray()
    counts = np.asarray(X, dtype=np.float64).T  # genes × samples

    # Gene annotation from .var
    genes_df = adata.var.copy() if len(adata.var.columns) > 0 else None

    # Sample labels
    if labels is None:
        labels = list(adata.obs_names)

    # Group from obs_col
    if group is None and obs_col is not None:
        if obs_col in adata.obs.columns:
            group = adata.obs[obs_col].values
        else:
            raise ValueError(f"Column '{obs_col}' not found in AnnData.obs. "
                             f"Available: {list(adata.obs.columns)}")

    dge = make_dgelist(counts, group=group, genes=genes_df)
    if labels is not None:
        dge['samples'].index = labels
    return dge
824
+
825
+
826
def _parse_rds_metadata(path):
    """Parse the key=value metadata file written by the R extraction script.

    Returns an empty dict when the file does not exist. Lines without an
    '=' are skipped, and values equal to the literal string 'NA' are
    mapped to None.
    """
    if not os.path.exists(path):
        return {}

    result = {}
    with open(path) as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if '=' not in stripped:
                continue
            key, _, value = stripped.partition('=')
            result[key] = None if value == 'NA' else value
    return result
838
+
839
+
840
def _build_rds_extraction_script(filepath, tmpdir):
    """Build an R script that extracts components from an RDS file.

    The generated script writes counts.csv / samples.csv / genes.csv
    (plus dispersion, AveLogCPM, and size-factor files when present) and
    a key=value metadata.txt into *tmpdir*. It handles DGEList,
    SummarizedExperiment/DESeqDataSet, and plain matrix/data.frame
    objects, and stops with an R error for anything else.

    Parameters
    ----------
    filepath : str
        Path to the .rds file.
    tmpdir : str
        Directory the R script writes its CSV/metadata output into.

    Returns
    -------
    str
        R source code suitable for running via Rscript.
    """
    # Forward slashes keep the embedded string literals valid on Windows.
    r_filepath = filepath.replace('\\', '/')
    r_tmpdir = tmpdir.replace('\\', '/')

    return f'''
suppressPackageStartupMessages(library(methods))
x <- readRDS("{r_filepath}")
tmpdir <- "{r_tmpdir}"
cls <- class(x)[1]

if (inherits(x, "DGEList")) {{
    write.csv(x$counts, file.path(tmpdir, "counts.csv"))
    write.csv(x$samples, file.path(tmpdir, "samples.csv"))
    if (!is.null(x$genes)) {{
        write.csv(x$genes, file.path(tmpdir, "genes.csv"))
    }}

    has_common <- !is.null(x$common.dispersion)
    has_trended <- !is.null(x$trended.dispersion)
    has_tagwise <- !is.null(x$tagwise.dispersion)
    has_alc <- !is.null(x$AveLogCPM)

    if (has_trended || has_tagwise) {{
        disp_df <- data.frame(row.names=rownames(x$counts))
        if (has_trended) disp_df$trended <- x$trended.dispersion
        if (has_tagwise) disp_df$tagwise <- x$tagwise.dispersion
        write.csv(disp_df, file.path(tmpdir, "dispersions.csv"))
    }}

    if (has_alc) {{
        write.csv(data.frame(value=x$AveLogCPM, row.names=rownames(x$counts)),
                  file.path(tmpdir, "AveLogCPM.csv"))
    }}

    prior_df <- ifelse(is.null(x$prior.df), "NA", as.character(x$prior.df))

    metadata <- c(
        paste0("class=", cls),
        paste0("nrow=", nrow(x$counts)),
        paste0("ncol=", ncol(x$counts)),
        paste0("has_genes=", !is.null(x$genes)),
        paste0("has_common_dispersion=", has_common),
        paste0("has_trended_dispersion=", has_trended),
        paste0("has_tagwise_dispersion=", has_tagwise),
        paste0("has_AveLogCPM=", has_alc),
        paste0("common.dispersion=", ifelse(has_common, x$common.dispersion, "NA")),
        paste0("prior.df=", prior_df),
        paste0("has_size_factors=FALSE")
    )
    writeLines(metadata, file.path(tmpdir, "metadata.txt"))

}} else if (isClass("SummarizedExperiment") && is(x, "SummarizedExperiment")) {{
    suppressPackageStartupMessages(library(SummarizedExperiment))

    counts_mat <- as.matrix(assay(x))
    write.csv(counts_mat, file.path(tmpdir, "counts.csv"))

    cd <- as.data.frame(colData(x))
    write.csv(cd, file.path(tmpdir, "samples.csv"))

    rd <- as.data.frame(rowData(x))
    if (ncol(rd) > 0) {{
        write.csv(rd, file.path(tmpdir, "genes.csv"))
    }}

    has_sf <- FALSE
    if (is(x, "DESeqDataSet")) {{
        tryCatch({{
            suppressPackageStartupMessages(library(DESeq2))
            sf <- sizeFactors(x)
            if (!is.null(sf)) {{
                write.csv(data.frame(value=sf, row.names=colnames(x)),
                          file.path(tmpdir, "size_factors.csv"))
                has_sf <- TRUE
            }}
        }}, error = function(e) {{}})
    }}

    metadata <- c(
        paste0("class=", cls),
        paste0("nrow=", nrow(counts_mat)),
        paste0("ncol=", ncol(counts_mat)),
        paste0("has_genes=", ncol(rd) > 0),
        paste0("has_common_dispersion=FALSE"),
        paste0("has_trended_dispersion=FALSE"),
        paste0("has_tagwise_dispersion=FALSE"),
        paste0("has_AveLogCPM=FALSE"),
        paste0("common.dispersion=NA"),
        paste0("prior.df=NA"),
        paste0("has_size_factors=", has_sf)
    )
    writeLines(metadata, file.path(tmpdir, "metadata.txt"))

}} else if (is.matrix(x) || is.data.frame(x)) {{
    write.csv(as.matrix(x), file.path(tmpdir, "counts.csv"))
    metadata <- c(
        paste0("class=", cls),
        paste0("nrow=", nrow(x)),
        paste0("ncol=", ncol(x)),
        paste0("has_genes=FALSE"),
        paste0("has_common_dispersion=FALSE"),
        paste0("has_trended_dispersion=FALSE"),
        paste0("has_tagwise_dispersion=FALSE"),
        paste0("has_AveLogCPM=FALSE"),
        paste0("common.dispersion=NA"),
        paste0("prior.df=NA"),
        paste0("has_size_factors=FALSE")
    )
    writeLines(metadata, file.path(tmpdir, "metadata.txt"))

}} else {{
    stop(paste0("Unsupported R object class: ", cls,
                ". Expected DGEList, SummarizedExperiment, DESeqDataSet, matrix, or data.frame."))
}}
'''
956
+
957
+
958
def _read_rds(filepath, group=None, verbose=True):
    """Read an R .rds file containing a DGEList, SummarizedExperiment, or DESeqDataSet.

    Uses R (via subprocess) to extract components to temporary CSV files,
    then loads them into a DGEList. Requires R to be installed and
    accessible as 'Rscript' on PATH.

    Parameters
    ----------
    filepath : str
        Path to the .rds file.
    group : array-like, optional
        Sample group assignments. If None, the 'group' column from the
        R object's sample table (when present) is used.
    verbose : bool
        Print progress messages.

    Returns
    -------
    DGEList
        With counts, samples, genes, and (when present in the R object)
        dispersions, AveLogCPM, prior.df and DESeq2 size factors restored.

    Raises
    ------
    FileNotFoundError
        If `filepath` does not exist.
    RuntimeError
        If Rscript is missing, the R extraction fails, or it times out.
    """
    import subprocess
    import shutil
    import tempfile

    from .dgelist import make_dgelist

    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"RDS file not found: {filepath}")

    rscript = shutil.which('Rscript')
    if rscript is None:
        raise RuntimeError(
            "Rscript not found on PATH. R must be installed to read .rds files. "
            "Install R from https://cran.r-project.org/"
        )

    tmpdir = tempfile.mkdtemp(prefix='edgepy_rds_')

    try:
        r_script = _build_rds_extraction_script(os.path.abspath(filepath), tmpdir)
        script_path = os.path.join(tmpdir, 'extract.R')
        with open(script_path, 'w') as f:
            f.write(r_script)

        if verbose:
            print(f"Reading {os.path.basename(filepath)} via R...")

        result = subprocess.run(
            [rscript, '--no-save', '--no-restore', script_path],
            capture_output=True, text=True, timeout=120
        )

        if result.returncode != 0:
            err_msg = result.stderr.strip() or result.stdout.strip()
            raise RuntimeError(f"R failed to read {filepath}:\n{err_msg}")

        # Parse metadata
        metadata = _parse_rds_metadata(os.path.join(tmpdir, 'metadata.txt'))

        def _meta_float(key):
            # The R extraction script writes the literal string "NA" for
            # missing numeric fields (e.g. "prior.df=NA"); float('NA')
            # raises ValueError, so treat NA/unparseable as absent.
            val = metadata.get(key)
            if val is None or val == 'NA':
                return None
            try:
                return float(val)
            except ValueError:
                return None

        if verbose:
            cls = metadata.get('class', 'unknown')
            nrow = metadata.get('nrow', '?')
            ncol = metadata.get('ncol', '?')
            print(f" {cls}: {nrow} genes x {ncol} samples")

        # Load counts
        counts_path = os.path.join(tmpdir, 'counts.csv')
        if not os.path.exists(counts_path):
            raise RuntimeError("R extraction did not produce counts.csv")
        counts_df = pd.read_csv(counts_path, index_col=0)
        gene_ids = list(counts_df.index.astype(str))
        sample_names = list(counts_df.columns)
        counts = counts_df.values.astype(np.float64)

        # Load sample info
        samples_path = os.path.join(tmpdir, 'samples.csv')
        samples_df = pd.read_csv(samples_path, index_col=0) if os.path.exists(samples_path) else None

        r_group = None
        lib_size = None
        norm_factors = None

        if samples_df is not None:
            if 'group' in samples_df.columns:
                r_group = samples_df['group'].values
            if 'lib.size' in samples_df.columns:
                lib_size = samples_df['lib.size'].values.astype(np.float64)
            if 'norm.factors' in samples_df.columns:
                norm_factors = samples_df['norm.factors'].values.astype(np.float64)

        # Explicit group argument takes precedence over the R object's group.
        if group is None:
            group = r_group

        # Load gene annotation
        genes_path = os.path.join(tmpdir, 'genes.csv')
        if os.path.exists(genes_path):
            genes_df = pd.read_csv(genes_path, index_col=0)
        else:
            # Create minimal genes DataFrame to preserve row names
            genes_df = pd.DataFrame(index=gene_ids)

        # Build DGEList
        dge = make_dgelist(
            counts, lib_size=lib_size, norm_factors=norm_factors,
            group=group, genes=genes_df,
        )

        # Restore original row/column names
        if gene_ids:
            if 'genes' in dge and dge['genes'] is not None:
                dge['genes'].index = gene_ids
        if sample_names:
            dge['samples'].index = sample_names

        # Restore dispersions (guarded against the "NA" placeholder)
        if metadata.get('has_common_dispersion') == 'TRUE':
            val = _meta_float('common.dispersion')
            if val is not None:
                dge['common.dispersion'] = val

        disp_path = os.path.join(tmpdir, 'dispersions.csv')
        if os.path.exists(disp_path):
            disp_df = pd.read_csv(disp_path, index_col=0)
            if 'trended' in disp_df.columns:
                dge['trended.dispersion'] = disp_df['trended'].values.astype(np.float64)
            if 'tagwise' in disp_df.columns:
                dge['tagwise.dispersion'] = disp_df['tagwise'].values.astype(np.float64)

        alc_path = os.path.join(tmpdir, 'AveLogCPM.csv')
        if metadata.get('has_AveLogCPM') == 'TRUE' and os.path.exists(alc_path):
            alc_df = pd.read_csv(alc_path, index_col=0)
            dge['AveLogCPM'] = alc_df['value'].values.astype(np.float64)

        # prior.df has no has_* flag in the metadata, so it is frequently
        # the literal "NA"; only restore a genuine numeric value.
        prior_df = _meta_float('prior.df')
        if prior_df is not None:
            dge['prior.df'] = prior_df

        # DESeqDataSet size factors
        sf_path = os.path.join(tmpdir, 'size_factors.csv')
        if metadata.get('has_size_factors') == 'TRUE' and os.path.exists(sf_path):
            sf_df = pd.read_csv(sf_path, index_col=0)
            dge['deseq2.size.factors'] = sf_df['value'].values.astype(np.float64)

        return dge

    except subprocess.TimeoutExpired as exc:
        raise RuntimeError(
            f"R subprocess timed out reading {filepath}. "
            "The file may be very large or R may be unresponsive."
        ) from exc
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
1096
+
1097
+
1098
def _read_table_file(data, path, columns, sep, group, verbose):
    """Read a CSV/TSV count table (e.g., exported from R).

    Two layouts are handled:
    1. A single file holding a genes-by-samples matrix (gene IDs in the
       first column), detected when every data column is numeric.
    2. One or more per-sample files, which are delegated to read_dge.
    """
    from .dgelist import make_dgelist

    file_names = [data] if isinstance(data, str) else list(data)

    # Single file: try reading it as an R-style exported count matrix.
    if len(file_names) == 1:
        name = file_names[0]
        full_path = os.path.join(path, name) if path else name

        # A .csv extension overrides the caller-supplied separator.
        delim = ',' if full_path.endswith('.csv') else sep

        if verbose:
            print(f"Reading {os.path.basename(full_path)}...")

        table = pd.read_csv(full_path, sep=delim, index_col=0)

        # A matrix export contains nothing but numeric columns.
        n_numeric = len(table.select_dtypes(include=[np.number]).columns)
        if n_numeric == len(table.columns):
            dge = make_dgelist(
                table.values.astype(np.float64),
                group=group,
                genes=pd.DataFrame(index=table.index),
            )
            dge['samples'].index = list(table.columns)
            return dge

    # Multiple files, or a single file with non-numeric columns:
    # treat as per-sample count files.
    return read_dge(file_names, path=path, columns=columns or (0, 1),
                    group=group, sep=sep)
1137
+
1138
+
1139
+ def _auto_detect_source(data, fmt):
1140
+ """Detect data source from data argument type and file structure."""
1141
+ if isinstance(data, np.ndarray):
1142
+ return 'matrix'
1143
+ if isinstance(data, pd.DataFrame):
1144
+ return 'dataframe'
1145
+ # scipy.sparse matrices
1146
+ try:
1147
+ import scipy.sparse as sp
1148
+ if sp.issparse(data):
1149
+ return 'sparse'
1150
+ except ImportError:
1151
+ pass
1152
+ if isinstance(data, str):
1153
+ if data.endswith('.h5ad'):
1154
+ return 'anndata'
1155
+ if data.lower().endswith('.rds'):
1156
+ return 'rds'
1157
+ if os.path.isdir(data):
1158
+ try:
1159
+ contents = os.listdir(data)
1160
+ except OSError:
1161
+ raise ValueError(f"Cannot list directory: {data}")
1162
+ if 'matrix.mtx' in contents or 'matrix.mtx.gz' in contents:
1163
+ return '10x'
1164
+ if 'quant.sf' in contents:
1165
+ return 'salmon'
1166
+ if 'abundance.tsv' in contents or 'abundance.h5' in contents:
1167
+ return 'kallisto'
1168
+ if any(f.endswith('.quant') for f in contents):
1169
+ return 'oarfish'
1170
+ elif os.path.isfile(data):
1171
+ if data.endswith('.isoforms.results'):
1172
+ return 'rsem'
1173
+ return 'table'
1174
+ raise ValueError(f"Cannot auto-detect source from path: {data}")
1175
+
1176
+ if isinstance(data, (list, tuple)) and len(data) > 0:
1177
+ first = str(data[0])
1178
+ if os.path.isdir(first):
1179
+ try:
1180
+ contents = os.listdir(first)
1181
+ except OSError:
1182
+ raise ValueError(f"Cannot list directory: {first}")
1183
+ if 'quant.sf' in contents:
1184
+ return 'salmon'
1185
+ if 'abundance.tsv' in contents or 'abundance.h5' in contents:
1186
+ return 'kallisto'
1187
+ elif os.path.isfile(first):
1188
+ if first.endswith('.isoforms.results'):
1189
+ return 'rsem'
1190
+ if first.endswith('.quant'):
1191
+ return 'oarfish'
1192
+ return 'table'
1193
+
1194
+ raise ValueError(
1195
+ "Cannot auto-detect data source. Please specify source='kallisto', "
1196
+ "'salmon', 'oarfish', 'rsem', 'anndata', '10x', 'table', or 'matrix'."
1197
+ )
1198
+
1199
+
1200
+ def _get_anndata_type():
1201
+ """Return anndata.AnnData class without hard import."""
1202
+ try:
1203
+ import anndata
1204
+ return anndata.AnnData
1205
+ except ImportError:
1206
+ return None
1207
+
1208
+
1209
+ # =====================================================================
1210
+ # Universal read_data() function
1211
+ # =====================================================================
1212
+
1213
def read_data(
    data,
    *,
    source=None,
    format=None,
    path=None,
    group=None,
    labels=None,
    columns=None,
    sep='\t',
    obs_col=None,
    layer=None,
    ngibbs=100,
    verbose=True,
):
    """Universal data import for edgePython.

    Reads count data from various sources and returns a DGEList.
    When bootstrap/Gibbs resampling information is available, computes
    overdispersion estimates following edgeR's algorithm.

    Parameters
    ----------
    data : various
        Input data. Accepts:
        - list of str: paths to quantification directories
          (kallisto/salmon) or files (RSEM/oarfish/table)
        - str: path to .h5ad file, .rds file (DGEList/SummarizedExperiment/
          DESeqDataSet), 10X directory, quantification directory, or
          count table (.csv/.tsv/.txt)
        - AnnData object: in-memory AnnData
        - DGEList: returned as-is (pass-through)
        - ndarray: count matrix (genes x samples)
        - scipy.sparse matrix (CSR/CSC): sparse count matrix
          (will be densified with a warning)
        - DataFrame: count matrix with gene names as index
    source : str or None
        Data source. Auto-detected if None.
        One of: 'kallisto', 'salmon', 'oarfish', 'rsem', '10x',
        'table', 'anndata', 'rds', 'sparse', 'matrix', 'dataframe'.
    format : str or None
        For kallisto: 'h5' or 'tsv'. If None, prefers H5 when available.
    path : str or None
        Base path prefix for relative file paths.
    group : array-like or None
        Sample group assignments.
    labels : list of str or None
        Sample labels. Auto-generated from directory/file names if None.
    columns : tuple of int or None
        For table format: (gene_id_col, count_col) as 0-based indices.
    sep : str
        Field separator for table/CSV files.
    obs_col : str or None
        For AnnData: column in .obs to use as group factor.
    layer : str or None
        For AnnData: layer name to use instead of .X.
    ngibbs : int or array-like
        For RSEM: number of Gibbs samples per sample.
    verbose : bool
        Print progress messages.

    Returns
    -------
    DGEList
        With keys: counts, samples, genes.
        When bootstraps are available: genes DataFrame includes
        'Overdispersion' column and dge['overdispersion.prior'] is set.
    """
    from .dgelist import make_dgelist
    from .classes import DGEList

    # --- Pass-through ---
    # Existing DGEList (or any dict with a 'counts' key) is returned
    # unchanged; group/labels arguments are NOT applied in this case.
    if isinstance(data, DGEList):
        return data
    if isinstance(data, dict) and 'counts' in data:
        return data

    # --- AnnData (in-memory object) ---
    # Detected before auto-detection; _get_anndata_type returns None when
    # the anndata package is not installed, in which case this is skipped.
    _anndata_cls = _get_anndata_type()
    if _anndata_cls is not None and isinstance(data, _anndata_cls):
        source = 'anndata'

    # --- Auto-detect ---
    if source is None:
        source = _auto_detect_source(data, format)

    # --- Dispatch ---
    if source == 'anndata':
        return _read_anndata(data, group=group, labels=labels,
                             obs_col=obs_col, layer=layer, verbose=verbose)

    if source == 'rds':
        return _read_rds(data, group=group, verbose=verbose)

    if source == '10x':
        # For an explicit source='10x' the directory may come either as
        # `data` (str) or via the `path` argument.
        p = data if isinstance(data, str) else path
        return read_10x(p, as_dgelist=True)

    if source == 'sparse':
        shape = data.shape
        nnz = data.nnz
        density = nnz / (shape[0] * shape[1]) if shape[0] * shape[1] > 0 else 0
        # Warn with the projected dense memory footprint (8 bytes/float64).
        warnings.warn(
            f"Densifying sparse matrix ({shape[0]} x {shape[1]}, "
            f"{100*density:.1f}% non-zero, "
            f"{shape[0] * shape[1] * 8 / 1e6:.0f} MB dense). "
            f"edgePython stores counts as dense arrays.",
            stacklevel=2,
        )
        counts = np.asarray(data.toarray(), dtype=np.float64)
        return make_dgelist(counts, group=group)

    if source == 'matrix':
        counts = np.asarray(data, dtype=np.float64)
        # A 1-D vector is treated as a single-sample column.
        if counts.ndim == 1:
            counts = counts.reshape(-1, 1)
        return make_dgelist(counts, group=group)

    if source == 'dataframe':
        # Index becomes gene names; columns become sample names.
        genes_df = pd.DataFrame(index=data.index)
        counts = data.values.astype(np.float64)
        dge = make_dgelist(counts, group=group, genes=genes_df)
        dge['samples'].index = list(data.columns)
        return dge

    if source == 'table':
        return _read_table_file(data, path=path, columns=columns,
                                sep=sep, group=group, verbose=verbose)

    # --- Quantification tools: data is path(s) ---
    if isinstance(data, str):
        paths = [data]
    elif isinstance(data, (list, tuple)):
        paths = list(data)
    else:
        raise ValueError(f"Expected path or list of paths for source='{source}'")

    if path is not None:
        paths = [os.path.join(path, p) for p in paths]

    if labels is None:
        # normpath strips trailing slashes so basename is never empty.
        labels = [os.path.basename(os.path.normpath(p)) for p in paths]

    # Each reader returns (counts, annotation, transcript ids,
    # overdispersion prior or NaN when no resamples were available).
    overdisp_prior = None
    if source == 'kallisto':
        counts, annotation, ids, overdisp_prior = _read_kallisto(paths, format, verbose)
    elif source == 'salmon':
        counts, annotation, ids, overdisp_prior = _read_salmon(paths, verbose)
    elif source == 'oarfish':
        # Oarfish accepts either explicit .quant files (list input) or a
        # directory to scan (str input); exactly one of the two is passed.
        counts, annotation, ids, overdisp_prior = _read_oarfish(
            paths if isinstance(data, (list, tuple)) else None,
            path=data if isinstance(data, str) and os.path.isdir(data) else None,
            verbose=verbose)
    elif source == 'rsem':
        counts, annotation, ids, overdisp_prior = _read_rsem_data(
            paths, path=None, ngibbs=ngibbs, verbose=verbose)
    else:
        raise ValueError(f"Unknown source: {source!r}")

    dge = make_dgelist(counts, group=group, genes=annotation)

    # Restore transcript IDs as genes index (make_dgelist overwrites with numeric)
    if ids is not None and 'genes' in dge:
        dge['genes'].index = ids

    if overdisp_prior is not None and not np.isnan(overdisp_prior):
        dge['overdispersion.prior'] = overdisp_prior

    if labels is not None:
        dge['samples'].index = labels

    return dge
1385
+
1386
+
1387
+ # =====================================================================
1388
+ # Bismark methylation coverage
1389
+ # =====================================================================
1390
+
1391
def read_bismark2dge(files, sample_names=None, verbose=True):
    """Read Bismark methylation coverage files into a DGEList.

    Port of edgeR's readBismark2DGE.

    Reads Bismark ``.cov`` coverage files and collates them into a single
    DGEList with two columns per sample (methylated and unmethylated
    counts). Column ordering is interleaved:
    ``Sample1-Me, Sample1-Un, Sample2-Me, Sample2-Un, ...``

    Parameters
    ----------
    files : list of str
        Paths to Bismark coverage files. Each file is tab-delimited with
        six columns: chr, start, end, methylation%, count_methylated,
        count_unmethylated.
    sample_names : list of str, optional
        Sample names. If None, derived from file names (extensions stripped).
    verbose : bool
        Print progress messages.

    Returns
    -------
    DGEList
        With ``2 * nsamples`` columns and a ``genes`` DataFrame containing
        ``Chr`` and ``Locus`` columns.
    """
    from .dgelist import make_dgelist

    files = [str(f) for f in files]
    nsamples = len(files)

    if sample_names is None:
        sample_names = []
        for f in files:
            name = os.path.basename(f)
            # Strip up to 3 extensions (matching R's removeExt×3)
            for _ in range(3):
                root, ext = os.path.splitext(name)
                if ext:
                    name = root
                else:
                    break
            sample_names.append(name)

    # Read all files, collecting chromosome names and loci
    chr_rle_list = []
    locus_list = []
    count_list = []
    chr_names = []

    for i, f in enumerate(files):
        if verbose:
            print(f"Reading {f}")
        x = pd.read_csv(f, sep='\t', header=None)
        # Columns: 0=chr, 1=start, 2=end, 3=meth%, 4=Me, 5=Un
        chrs = x.iloc[:, 0].values.astype(str)
        loci = x.iloc[:, 1].values.astype(np.int64)
        me_un = x.iloc[:, [4, 5]].values.astype(np.int64)

        # Collect unique chromosome names in order of appearance
        # (linear scan per row; fine for the small number of chromosomes,
        # since the list stays short).
        for c in chrs:
            if c not in chr_names:
                chr_names.append(c)

        chr_rle_list.append(chrs)
        locus_list.append(loci)
        count_list.append(me_un)

    if verbose:
        print("Hashing ...")

    # Map chromosome names to integers (1-based so the fraction below is
    # never zero).
    chr_to_int = {c: i + 1 for i, c in enumerate(chr_names)}
    hash_base = len(chr_names) + 1

    # Hash genomic positions: chr_int / hash_base + locus
    # Each (chromosome, locus) pair becomes one float: integer part is the
    # locus, fractional part encodes the chromosome.
    # NOTE(review): assumes float64 precision is sufficient to keep the
    # chromosome fraction distinguishable for the largest locus values —
    # verify for genomes with very large coordinates or many chromosomes.
    hash_list = []
    hash_unique = []
    hash_set = set()
    for i in range(nsamples):
        chr_ints = np.array([chr_to_int[c] for c in chr_rle_list[i]],
                            dtype=np.float64)
        h = chr_ints / hash_base + locus_list[i].astype(np.float64)
        hash_list.append(h)
        # Preserve first-appearance order across samples (matches the
        # order R produces from its hashing step).
        for v in h:
            if v not in hash_set:
                hash_unique.append(v)
                hash_set.add(v)

    hash_unique = np.array(hash_unique, dtype=np.float64)
    n_loci = len(hash_unique)

    if verbose:
        print("Collating counts ...")

    # Build merged count matrix with interleaved columns
    # Column order: S1-Me, S1-Un, S2-Me, S2-Un, ...
    # Loci absent from a sample keep zero counts in that sample's columns.
    counts = np.zeros((n_loci, nsamples * 2), dtype=np.int64)
    hash_to_row = {v: idx for idx, v in enumerate(hash_unique)}

    for i in range(nsamples):
        h = hash_list[i]
        rows = np.array([hash_to_row[v] for v in h], dtype=np.int64)
        # NOTE: plain assignment — if one file contains duplicate loci,
        # the last occurrence wins rather than being summed.
        counts[rows, 2 * i] = count_list[i][:, 0]  # Me
        counts[rows, 2 * i + 1] = count_list[i][:, 1]  # Un

    # Unhash: recover chromosome and locus
    # Truncation recovers the locus; the rounded remainder times hash_base
    # recovers the 1-based chromosome index.
    locus_arr = hash_unique.astype(np.int64)
    chr_int_arr = np.round((hash_unique - locus_arr) * hash_base).astype(int)
    chr_arr = np.array([chr_names[ci - 1] for ci in chr_int_arr])

    # Column names: interleaved
    col_names = []
    for sn in sample_names:
        col_names.append(f"{sn}-Me")
        col_names.append(f"{sn}-Un")

    # Row names: "<chr>-<locus>"
    row_names = [f"{c}-{l}" for c, l in zip(chr_arr, locus_arr)]

    # Gene annotation
    genes_df = pd.DataFrame({'Chr': chr_arr, 'Locus': locus_arr},
                            index=row_names)

    # Build DGEList
    counts_float = counts.astype(np.float64)
    y = make_dgelist(counts_float, genes=genes_df)
    y['samples'].index = col_names
    # Set row names on the DGEList
    if 'genes' in y and y['genes'] is not None:
        y['genes'].index = row_names

    return y
1525
+
1526
+
1527
+ # =====================================================================
1528
+ # Pseudo-bulk aggregation (Seurat2PB)
1529
+ # =====================================================================
1530
+
1531
def seurat_to_pb(object, sample, cluster="cluster"):
    """Convert single-cell data to pseudo-bulk DGEList.

    Port of edgeR's Seurat2PB. Aggregates raw counts of cells sharing
    the same sample and cluster identity into pseudo-bulk columns.

    In R, this function takes a Seurat object. In Python, this function
    accepts an AnnData object (the standard Python single-cell container),
    a dict with 'counts' and 'obs' keys, or a raw count matrix with
    separate metadata.

    Parameters
    ----------
    object : AnnData, dict, or ndarray
        Single-cell data. Accepted formats:

        - **AnnData**: Uses ``.X`` or ``.layers['counts']`` for the
          count matrix (obs x var = cells x genes). The ``sample``
          and ``cluster`` columns must be present in ``.obs``.
        - **dict**: Must contain ``'counts'`` (genes x cells ndarray)
          and ``'obs'`` (DataFrame with sample/cluster columns).
        - **ndarray**: Raw count matrix (genes x cells). In this case,
          ``sample`` must be an array-like of per-cell sample labels
          and ``cluster`` must be an array-like of per-cell cluster
          labels.
    sample : str or array-like
        If str, the column name in ``.obs`` (AnnData) or ``obs``
        (dict) identifying the biological sample each cell belongs to.
        If array-like, per-cell sample labels (length = n_cells).
    cluster : str or array-like
        If str, the column name in ``.obs`` or ``obs`` identifying
        the cell cluster. Default ``"cluster"``.
        If array-like, per-cell cluster labels.

    Returns
    -------
    DGEList
        Pseudo-bulk DGEList with one column per sample-cluster
        combination. The ``samples`` DataFrame contains ``sample``
        and ``cluster`` columns.
    """
    from .dgelist import make_dgelist

    # --- Extract counts and metadata ---
    counts = None
    obs = None

    # Check for AnnData
    _anndata_cls = _get_anndata_type()
    if _anndata_cls is not None and isinstance(object, _anndata_cls):
        # AnnData: obs x var (cells x genes) -> need to transpose.
        # Prefer the 'counts' layer (raw counts) over .X when present.
        if 'counts' in object.layers:
            X = object.layers['counts']
        else:
            X = object.X
        if hasattr(X, 'toarray'):
            X = X.toarray()
        counts = np.asarray(X, dtype=np.float64).T  # genes x cells
        obs = object.obs
        gene_names = list(object.var_names)
        gene_info = object.var.copy() if len(object.var.columns) > 0 else None
    elif isinstance(object, dict):
        counts = np.asarray(object['counts'], dtype=np.float64)
        obs = object.get('obs')
        gene_names = None
        gene_info = object.get('genes')
    elif isinstance(object, np.ndarray):
        counts = np.asarray(object, dtype=np.float64)
        obs = None
        gene_names = None
        gene_info = None
    else:
        raise TypeError(
            f"Expected AnnData, dict, or ndarray, got {type(object).__name__}"
        )

    n_genes, n_cells = counts.shape

    # --- Get sample and cluster labels ---
    if isinstance(sample, str):
        if obs is None:
            raise ValueError(
                "sample is a column name but no obs/metadata provided"
            )
        if sample not in obs.columns:
            raise ValueError(
                f"Column '{sample}' not found in obs. "
                f"Available: {list(obs.columns)}"
            )
        sample_labels = obs[sample].values
    else:
        sample_labels = np.asarray(sample)
        if len(sample_labels) != n_cells:
            raise ValueError(
                f"sample length ({len(sample_labels)}) != "
                f"number of cells ({n_cells})"
            )

    if isinstance(cluster, str):
        if obs is None:
            raise ValueError(
                "cluster is a column name but no obs/metadata provided"
            )
        if cluster not in obs.columns:
            raise ValueError(
                f"Column '{cluster}' not found in obs. "
                f"Available: {list(obs.columns)}"
            )
        cluster_labels = obs[cluster].values
    else:
        cluster_labels = np.asarray(cluster)
        if len(cluster_labels) != n_cells:
            raise ValueError(
                f"cluster length ({len(cluster_labels)}) != "
                f"number of cells ({n_cells})"
            )

    # --- Create combined sample_cluster factor ---
    sample_labels = np.asarray(sample_labels, dtype=str)
    cluster_labels = np.asarray(cluster_labels, dtype=str)
    combined = np.array([
        f"{s}_cluster{c}" for s, c in zip(sample_labels, cluster_labels)
    ])

    # Get unique groups preserving order of appearance, tracking the
    # original (sample, cluster) pair per group. Keeping the pair avoids
    # re-parsing "_cluster" out of the combined label later, which would
    # mis-split sample names that themselves contain "_cluster".
    seen = {}
    unique_groups = []
    group_pairs = []  # (sample, cluster) per unique group
    for s, c, g in zip(sample_labels, cluster_labels, combined):
        if g not in seen:
            seen[g] = len(unique_groups)
            unique_groups.append(g)
            group_pairs.append((s, c))

    n_groups = len(unique_groups)

    # --- Aggregate counts by matrix multiplication (matching R) ---
    # Build indicator matrix: n_cells x n_groups
    group_idx = np.array([seen[g] for g in combined])
    indicator = np.zeros((n_cells, n_groups), dtype=np.float64)
    indicator[np.arange(n_cells), group_idx] = 1.0

    # counts (genes x cells) @ indicator (cells x groups) = pb (genes x groups)
    counts_pb = counts @ indicator

    # --- Build sample metadata ---
    sample_pb = [s for s, _ in group_pairs]
    cluster_pb = [c for _, c in group_pairs]

    samples_df = pd.DataFrame({
        'sample': sample_pb,
        'cluster': cluster_pb,
    }, index=unique_groups)

    # --- Build gene annotation ---
    genes_df = None
    if gene_info is not None:
        genes_df = gene_info.copy()
        if 'gene' not in genes_df.columns and gene_names is not None:
            genes_df.insert(0, 'gene', gene_names)
    elif gene_names is not None:
        genes_df = pd.DataFrame({'gene': gene_names})

    # --- Create DGEList ---
    dge = make_dgelist(counts_pb, genes=genes_df)
    # Set sample metadata
    dge['samples']['sample'] = samples_df['sample'].values
    dge['samples']['cluster'] = samples_df['cluster'].values
    dge['samples'].index = unique_groups

    return dge
1705
+
1706
+
1707
+ # =====================================================================
1708
+ # AnnData export
1709
+ # =====================================================================
1710
+
1711
def to_anndata(obj, adata=None):
    """Convert edgePython results to AnnData format.

    Stores results in a predictable schema compatible with the Scanpy
    ecosystem. Can either create a new AnnData or update an existing
    one in-place.

    Schema
    ------
    .X : ndarray
        Raw counts (samples x genes, transposed from edgePython layout).
    .layers['counts'] : ndarray
        Copy of raw counts in the standard Scanpy layer.
    .obs : DataFrame
        Sample metadata: group, lib_size, norm_factors.
    .var : DataFrame
        Gene metadata from genes DataFrame, plus per-gene DE results
        (logFC, logCPM, PValue, FDR, F/LR) and dispersions.
    .varm['edgepython_coefficients'] : ndarray
        GLM coefficient matrix (genes x coefficients) when available.
    .uns['edgepython'] : dict
        Global results: common_dispersion, prior_df, design matrix,
        test method, contrast info, overdispersion prior.

    Parameters
    ----------
    obj : DGEList, DGELRT, DGEExact, TopTags, or DGEGLM
        edgePython result object.
    adata : AnnData or None
        Existing AnnData to update. If None, creates a new one.

    Returns
    -------
    AnnData

    Raises
    ------
    ImportError
        If the anndata package is not installed.
    ValueError
        If `obj` has neither counts nor a results table and no existing
        `adata` was supplied.
    """
    try:
        import anndata
    except ImportError:
        raise ImportError(
            "anndata package required for AnnData export. "
            "Install with: pip install anndata"
        )

    # Names listed to document the accepted input types; not referenced below.
    from .classes import DGEList, DGEExact, DGEGLM, DGELRT, TopTags

    # --- Extract counts if available ---
    counts = None
    if 'counts' in obj and obj['counts'] is not None:
        counts = obj['counts']

    # Infer the gene dimension from whichever component is present.
    n_genes = None
    if counts is not None:
        n_genes, n_samples = counts.shape
    elif 'coefficients' in obj and obj['coefficients'] is not None:
        n_genes = obj['coefficients'].shape[0]
    elif 'table' in obj and obj['table'] is not None:
        n_genes = len(obj['table'])

    # --- Build or update AnnData ---
    if adata is None:
        if counts is not None:
            # Create new with counts: AnnData is obs×var (samples×genes)
            # NOTE(review): the dtype= keyword to AnnData() is deprecated in
            # newer anndata releases — confirm the supported version range.
            adata = anndata.AnnData(
                X=counts.T.copy(),
                dtype=np.float64,
            )
            adata.layers['counts'] = counts.T.copy()
        else:
            # No counts (e.g., TopTags or DGELRT without counts)
            # Create minimal AnnData from table or coefficients
            if 'table' in obj and obj['table'] is not None:
                table = obj['table']
                n_vars = len(table)
                n_obs = obj['samples'].shape[0] if 'samples' in obj else 1
                # Placeholder zero matrix; only .var/.uns carry information.
                adata = anndata.AnnData(
                    X=np.zeros((n_obs, n_vars), dtype=np.float64),
                    dtype=np.float64,
                )
            else:
                raise ValueError(
                    "No counts or results table found in object. "
                    "Pass a DGEList, DGELRT, DGEExact, or TopTags."
                )
    else:
        # Update existing — just add results, don't touch .X
        pass

    # --- .obs: sample metadata ---
    if 'samples' in obj and obj['samples'] is not None:
        sam = obj['samples']
        adata.obs_names = list(sam.index)
        if 'group' in sam.columns:
            adata.obs['group'] = sam['group'].values
        if 'lib.size' in sam.columns:
            adata.obs['lib_size'] = sam['lib.size'].values
        if 'norm.factors' in sam.columns:
            adata.obs['norm_factors'] = sam['norm.factors'].values

    # --- .var: gene metadata ---
    if 'genes' in obj and obj['genes'] is not None:
        genes_df = obj['genes']
        adata.var_names = list(genes_df.index)
        for col in genes_df.columns:
            adata.var[col] = genes_df[col].values

    # --- .var: DE test results ---
    table = None
    if 'table' in obj and obj['table'] is not None:
        t = obj['table']
        if isinstance(t, pd.DataFrame) and len(t) > 0:
            table = t

    if table is not None:
        n_var = adata.shape[1]
        if len(table) == n_var:
            # Full table — assign directly by position
            for col in table.columns:
                adata.var[col] = table[col].values
        else:
            # Partial (e.g., top n genes) — fill NaN first
            # NOTE(review): .loc alignment assumes table.index is a subset
            # of adata.var_names — confirm for TopTags inputs.
            for col in table.columns:
                adata.var[col] = np.nan
                adata.var.loc[table.index, col] = table[col].values

    # --- .var: dispersions ---
    # Each vector is copied only when its length matches the gene dimension.
    n_var = adata.shape[1]
    if 'tagwise.dispersion' in obj and obj['tagwise.dispersion'] is not None:
        v = obj['tagwise.dispersion']
        if hasattr(v, '__len__') and len(v) == n_var:
            adata.var['tagwise_dispersion'] = v
    if 'trended.dispersion' in obj and obj['trended.dispersion'] is not None:
        v = obj['trended.dispersion']
        if hasattr(v, '__len__') and len(v) == n_var:
            adata.var['trended_dispersion'] = v
    if 'dispersion' in obj and obj['dispersion'] is not None:
        v = obj['dispersion']
        if hasattr(v, '__len__') and len(v) == n_var:
            adata.var['dispersion'] = v
    if 'AveLogCPM' in obj and obj['AveLogCPM'] is not None:
        v = obj['AveLogCPM']
        n_var = adata.shape[1]
        if hasattr(v, '__len__') and len(v) == n_var:
            adata.var['AveLogCPM'] = v

    # --- .varm: GLM coefficients ---
    if 'coefficients' in obj and obj['coefficients'] is not None:
        coefs = obj['coefficients']
        n_var = adata.shape[1]
        if isinstance(coefs, np.ndarray) and coefs.shape[0] == n_var:
            adata.varm['edgepython_coefficients'] = coefs

    # --- .uns: global / scalar results ---
    uns = {}
    if 'common.dispersion' in obj and obj['common.dispersion'] is not None:
        uns['common_dispersion'] = float(obj['common.dispersion'])
    if 'method' in obj and obj['method'] is not None:
        uns['method'] = obj['method']
    if 'prior.df' in obj and obj['prior.df'] is not None:
        uns['prior_df'] = float(obj['prior.df'])
    if 'df.prior' in obj and obj['df.prior'] is not None:
        # df.prior may be scalar or per-gene; keep lists JSON-friendly.
        v = obj['df.prior']
        uns['df_prior'] = float(v) if np.isscalar(v) else v.tolist()
    if 'design' in obj and obj['design'] is not None:
        # NOTE(review): assumes design has a .tolist() (ndarray) — confirm
        # it is never stored as a plain list upstream.
        uns['design'] = obj['design'].tolist()
    if 'comparison' in obj:
        uns['comparison'] = obj['comparison']
    if 'test' in obj:
        uns['test'] = obj['test']
    if 'adjust.method' in obj:
        uns['adjust_method'] = obj['adjust.method']
    if 'overdispersion.prior' in obj:
        uns['overdispersion_prior'] = float(obj['overdispersion.prior'])

    if uns:
        adata.uns['edgepython'] = uns

    return adata