M3Drop 0.4.42-py3-none-any.whl → 0.4.45-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
m3Drop/CoreCPU.py ADDED
@@ -0,0 +1,508 @@
+ import time
+ import psutil
+ import h5py
+ import numpy as np
+ import anndata
+ import pandas as pd
+ import os
+ import sys
+ import pickle
+
+ # [OPTIMIZATION] Use Numba for near-C++ speed on CPU
+ try:
+     import numba
+     from numba import jit, prange
+ except ImportError:
+     print("CRITICAL ERROR: 'numba' not found. Please install it (pip install numba) for CPU optimization.")
+     sys.exit(1)
+
+ import statsmodels.api as sm
+ import matplotlib.pyplot as plt
+ from scipy.stats import norm
+ from scipy import sparse
+ from statsmodels.stats.multitest import multipletests
+
+ # [FIX] Strict Relative Import
+ # This ensures that if ControlDeviceCPU fails to load (e.g. missing dependency),
+ # the real error is shown instead of being masked.
+ from .ControlDeviceCPU import ControlDevice
+
+ # ==========================================
+ # NUMBA KERNELS (CPU OPTIMIZED)
+ # ==========================================
+
+ @jit(nopython=True, cache=True)
+ def nan_replace_cpu(x):
+     """Replaces NaNs with 0 and Infs with 1 (+inf) or 0 (-inf), in place.
+     Assumes x is C-contiguous so ravel() returns a view."""
+     flat = x.ravel()
+     for i in range(flat.size):
+         val = flat[i]
+         if np.isnan(val):
+             flat[i] = 0.0
+         elif np.isinf(val):
+             flat[i] = 1.0 if val > 0 else 0.0
+     return x
+
+ @jit(nopython=True, parallel=True, fastmath=True)
+ def dropout_prob_kernel_cpu(tj, ti, total, exp_size, out_matrix):
+     """
+     Calculates dropout probabilities using Negative Binomial logic.
+     Parallelized across CPU cores.
+     """
+     rows = out_matrix.shape[0]
+     cols = out_matrix.shape[1]
+
+     # Numba handles the broadcasting loops explicitly for max speed
+     for r in prange(rows):
+         ti_val = ti[r]
+         for c in range(cols):
+             mu = (tj[c] * ti_val) / total
+             size_val = exp_size[c]
+
+             base = (mu / size_val) + 1.0
+             if base < 1e-12:
+                 base = 1e-12
+
+             # pow(base, -size_val)
+             val = base ** (-size_val)
+
+             if np.isnan(val):
+                 out_matrix[r, c] = 0.0
+             elif np.isinf(val):
+                 out_matrix[r, c] = 1.0 if val > 0 else 0.0
+             else:
+                 out_matrix[r, c] = val
+
+ @jit(nopython=True, cache=True)
+ def dropout_variance_inplace_cpu(p):
+     """Calculates the Bernoulli variance p * (1 - p) in place."""
+     flat = p.ravel()
+     for i in range(flat.size):
+         val = flat[i]
+         flat[i] = val - (val * val)
+
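+ # NOTE: for a Negative Binomial with mean mu and size (dispersion) s, the zero
+ # probability is P(X = 0) = (1 + mu / s) ** (-s). The kernel above evaluates
+ # exactly this for every (cell, gene) pair with mu = tj[c] * ti[r] / total, and
+ # dropout_variance_inplace_cpu then turns each probability p into the variance
+ # p * (1 - p) of the corresponding Bernoulli dropout event.
+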
+ # ==========================================
+ # STAGE 1: MASK GENERATION
+ # ==========================================
+
+ def ConvertDataSparseCPU(input_filename: str, output_mask_filename: str, mode: str = "auto", manual_target: int = 3000):
+     start_time = time.perf_counter()
+     print(f"FUNCTION: ConvertDataSparseCPU() | FILE: {input_filename}")
+
+     device = ControlDevice.from_h5ad(input_filename, mode=mode, manual_target=manual_target)
+     n_cells = device.total_rows
+     n_genes = device.n_genes
+
+     with h5py.File(input_filename, 'r') as f_in:
+         x_group_in = f_in['X']
+         print("Phase [1/1]: Identifying expressed genes...")
+         genes_to_keep_mask = np.zeros(n_genes, dtype=bool)
+
+         h5_indptr = x_group_in['indptr']
+         h5_indices = x_group_in['indices']
+
+         current_row = 0
+         while current_row < n_cells:
+             # Overhead 1.0 is fine for a sparse scan on CPU
+             end_row = device.get_next_chunk(current_row, mode='sparse', overhead_multiplier=1.0)
+             if end_row is None or end_row <= current_row:
+                 break
+
+             chunk_size = end_row - current_row
+             print(f"Phase [1/1]: Scanning rows {end_row} of {n_cells} | Chunk: {chunk_size}", end='\r')
+
+             start_idx, end_idx = h5_indptr[current_row], h5_indptr[end_row]
+             if start_idx == end_idx:
+                 current_row = end_row
+                 continue
+
+             # Any gene with at least one non-zero entry in this chunk is kept.
+             indices = h5_indices[start_idx:end_idx]
+             unique_indices = np.unique(indices)
+             genes_to_keep_mask[unique_indices] = True
+
+             current_row = end_row
+
+     n_genes_to_keep = int(np.sum(genes_to_keep_mask))
+     print(f"\nPhase [1/1]: COMPLETE | Result: {n_genes_to_keep} / {n_genes} genes retained.")
+
+     print(f"Saving mask to {output_mask_filename}...")
+     with open(output_mask_filename, 'wb') as f:
+         pickle.dump(genes_to_keep_mask, f)
+
+     end_time = time.perf_counter()
+     print(f"Total time: {end_time - start_time:.2f} seconds.\n")
+
+ # ==========================================
+ # STAGE 2: STATISTICS
+ # ==========================================
+
+ def hidden_calc_valsCPU(filename: str, mask_filename: str, mode: str = "auto", manual_target: int = 3000) -> dict:
+     start_time = time.perf_counter()
+     print(f"FUNCTION: hidden_calc_valsCPU() | FILE: {filename}")
+
+     # 1. Load Mask
+     with open(mask_filename, 'rb') as f:
+         mask = pickle.load(f)
+     ng_filtered = int(np.sum(mask))
+
+     # 2. Init Device
+     with h5py.File(filename, 'r') as f:
+         indptr_cpu = f['X']['indptr'][:]
+         total_rows = len(indptr_cpu) - 1
+
+     device = ControlDevice(
+         indptr=indptr_cpu,
+         total_rows=total_rows,
+         n_genes=ng_filtered,
+         mode=mode,
+         manual_target=manual_target
+     )
+     nc = device.total_rows
+
+     adata_meta = anndata.read_h5ad(filename, backed='r')
+     tis = np.zeros(nc, dtype='float64')
+     cell_non_zeros = np.zeros(nc, dtype='int64')
+     tjs = np.zeros(ng_filtered, dtype=np.float64)
+     gene_non_zeros = np.zeros(ng_filtered, dtype=np.int32)
+
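+     # Naming apparently mirrors the original M3Drop R implementation: t = totals,
+     # d = dropout (zero) counts; the i suffix indexes cells (rows), j indexes
+     # genes (columns). So tjs are per-gene count sums and djs the number of
+     # cells in which each gene is zero.
+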
+     print("Phase [1/2]: Calculating statistics...")
+     with h5py.File(filename, 'r') as f_in:
+         x_group = f_in['X']
+         h5_indptr = x_group['indptr']
+         h5_data = x_group['data']
+         h5_indices = x_group['indices']
+
+         current_row = 0
+         while current_row < nc:
+             end_row = device.get_next_chunk(current_row, mode='sparse', overhead_multiplier=1.1)
+             if end_row is None or end_row <= current_row:
+                 break
+
+             chunk_size = end_row - current_row
+             print(f"Phase [1/2]: Processing {end_row} of {nc} | Chunk: {chunk_size}", end='\r')
+
+             start_idx, end_idx = h5_indptr[current_row], h5_indptr[end_row]
+             data = np.array(h5_data[start_idx:end_idx], dtype=np.float64)
+             indices = np.array(h5_indices[start_idx:end_idx])
+             indptr = np.array(h5_indptr[current_row:end_row+1] - h5_indptr[current_row])
+
+             # Use Scipy CSR for CPU operations
+             chunk_csr = sparse.csr_matrix((data, indices, indptr), shape=(chunk_size, len(mask)))
+
+             # --- VIRTUAL FILTER + CEIL ---
+             chunk_csr = chunk_csr[:, mask]
+             chunk_csr.data = np.ceil(chunk_csr.data)
+             # -----------------------------
+
+             tis[current_row:end_row] = np.array(chunk_csr.sum(axis=1)).flatten()
+             cell_non_zeros[current_row:end_row] = np.diff(chunk_csr.indptr)
+
+             # np.add.at performs unbuffered accumulation, so gene indices repeated
+             # across rows of the chunk are summed correctly.
+             np.add.at(tjs, chunk_csr.indices, chunk_csr.data)
+
+             unique_indices, counts = np.unique(chunk_csr.indices, return_counts=True)
+             np.add.at(gene_non_zeros, unique_indices, counts)
+
+             current_row = end_row
+
+     print(f"\nPhase [1/2]: COMPLETE{' ' * 50}")
+
+     print("Phase [2/2]: Finalizing stats...")
+     dis = ng_filtered - cell_non_zeros
+     djs = nc - gene_non_zeros
+     total = tjs.sum()
+     print("Phase [2/2]: COMPLETE")
+
+     end_time = time.perf_counter()
+     print(f"Total time: {end_time - start_time:.2f} seconds.\n")
+
+     filtered_var_index = adata_meta.var.index[mask]
+
+     return {
+         "tis": pd.Series(tis, index=adata_meta.obs.index),
+         "tjs": pd.Series(tjs, index=filtered_var_index),
+         "dis": pd.Series(dis, index=adata_meta.obs.index),
+         "djs": pd.Series(djs, index=filtered_var_index),
+         "total": total,
+         "nc": nc,
+         "ng": ng_filtered
+     }
+
+ def NBumiFitModelCPU(raw_filename: str, mask_filename: str, stats: dict, mode: str = "auto", manual_target: int = 3000) -> dict:
+     start_time = time.perf_counter()
+     print(f"FUNCTION: NBumiFitModelCPU() | FILE: {raw_filename}")
+
+     with open(mask_filename, 'rb') as f:
+         mask = pickle.load(f)
+     ng_filtered = stats['ng']
+
+     with h5py.File(raw_filename, 'r') as f:
+         indptr_cpu = f['X']['indptr'][:]
+         total_rows = len(indptr_cpu) - 1
+
+     device = ControlDevice(indptr=indptr_cpu, total_rows=total_rows, n_genes=ng_filtered, mode=mode, manual_target=manual_target)
+     nc = device.total_rows
+
+     tjs = stats['tjs'].values
+     tis = stats['tis'].values
+     total = stats['total']
+
+     # Numpy accumulators
+     sum_x_sq = np.zeros(ng_filtered, dtype=np.float64)
+     sum_2xmu = np.zeros(ng_filtered, dtype=np.float64)
+
+     print("Phase [1/3]: Pre-calculating sum of squared expectations...")
+     # With mu_ij = tj_j * ti_i / total, sum_i(mu_ij^2) factorises per gene:
+     sum_tis_sq = np.sum(tis**2)
+     sum_mu_sq = (tjs**2 / total**2) * sum_tis_sq
+     print("Phase [1/3]: COMPLETE")
+
+     print("Phase [2/3]: Calculating variance components...")
+     with h5py.File(raw_filename, 'r') as f_in:
+         x_group = f_in['X']
+         h5_indptr = x_group['indptr']
+         h5_data = x_group['data']
+         h5_indices = x_group['indices']
+
+         current_row = 0
+         while current_row < nc:
+             # L3 optimization is critical here for CPU performance
+             end_row = device.get_next_chunk(current_row, mode='sparse', overhead_multiplier=1.1)
+             if end_row is None or end_row <= current_row:
+                 break
+
+             chunk_size = end_row - current_row
+             print(f"Phase [2/3]: Processing {end_row} of {nc} | Chunk: {chunk_size}", end='\r')
+
+             start_idx, end_idx = h5_indptr[current_row], h5_indptr[end_row]
+             data = np.array(h5_data[start_idx:end_idx], dtype=np.float64)
+             indices = np.array(h5_indices[start_idx:end_idx])
+             indptr = np.array(h5_indptr[current_row:end_row+1] - h5_indptr[current_row])
+
+             chunk_csr = sparse.csr_matrix((data, indices, indptr), shape=(chunk_size, len(mask)))
+             chunk_csr = chunk_csr[:, mask]
+             chunk_csr.data = np.ceil(chunk_csr.data)
+
+             # Accumulate X^2
+             np.add.at(sum_x_sq, chunk_csr.indices, chunk_csr.data**2)
+
+             # Vectorized calculation of the 2 * x * mu cross term. To avoid
+             # expanding dense matrices, we work directly on the CSR structure;
+             # on CPU, touching only the non-zeros is efficient enough.
+
+             # Map local row indices to global cell indices
+             row_indices = np.repeat(np.arange(chunk_size), np.diff(chunk_csr.indptr)) + current_row
+             global_tis = tis[row_indices]
+
+             term_vals = 2 * chunk_csr.data * tjs[chunk_csr.indices] * global_tis / total
+             np.add.at(sum_2xmu, chunk_csr.indices, term_vals)
+
+             current_row = end_row
+
+     print(f"\nPhase [2/3]: COMPLETE {' ' * 50}")
+
+     print("Phase [3/3]: Finalizing dispersion...")
+     # sum_i((x - mu)^2) expanded as sum(x^2) - 2*sum(x*mu) + sum(mu^2); zeros
+     # contribute nothing to the first two terms, so sparse accumulation suffices.
+     sum_sq_dev = sum_x_sq - sum_2xmu + sum_mu_sq
+     var_obs = sum_sq_dev / (nc - 1)
+
+     # Method-of-moments NB size: Var = mu + mu^2 / size, summed over cells, gives
+     # size_j = sum_i(mu_ij^2) / (sum_i((x_ij - mu_ij)^2) - tj_j), since sum_i(mu_ij) = tj_j.
+     sizes = np.full(ng_filtered, 10000.0)
+     numerator = sum_mu_sq  # same quantity as the Phase 1 sum of squared expectations
+     denominator = sum_sq_dev - tjs
+
+     # Guard genes with near-Poisson (or under-dispersed) behaviour.
+     stable_mask = denominator > 1e-6
+     sizes[stable_mask] = numerator[stable_mask] / denominator[stable_mask]
+     sizes[sizes <= 0] = 10000.0
+
+     print("Phase [3/3]: COMPLETE")
+
+     end_time = time.perf_counter()
+     print(f"Total time: {end_time - start_time:.2f} seconds.\n")
+
+     return {
+         'var_obs': pd.Series(var_obs, index=stats['tjs'].index),
+         'sizes': pd.Series(sizes, index=stats['tjs'].index),
+         'vals': stats
+     }
+
+ def NBumiFitDispVsMeanCPU(fit: dict, suppress_plot=True):
+     vals = fit['vals']
+     size_g = fit['sizes'].values
+     tjs = vals['tjs'].values
+     mean_expression = tjs / vals['nc']
+
+     forfit = (np.isfinite(size_g)) & (size_g < 1e6) & (mean_expression > 1e-3) & (size_g > 0)
+     # 'where' alone leaves masked entries uninitialized; supply an 'out' buffer.
+     log2_mean_expr = np.log2(mean_expression, out=np.zeros_like(mean_expression), where=(mean_expression > 0))
+
+     higher = log2_mean_expr > 4
+     if np.sum(higher & forfit) > 2000:
+         forfit = higher & forfit
+
+     y = np.log(size_g[forfit])
+     x = np.log(mean_expression[forfit])
+
+     # OLS fit of ln(size) = a + b * ln(mean)
+     X = sm.add_constant(x)
+     model = sm.OLS(y, X).fit()
+
+     if not suppress_plot:
+         plt.figure(figsize=(7, 6))
+         plt.scatter(x, y, alpha=0.5, s=1)
+         plt.plot(x, model.fittedvalues, color='red')
+         plt.show()
+
+     return model.params
+
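+ # NOTE: the OLS coefficients returned above parameterize the dispersion-mean
+ # trend exp(a + b * ln(mean)); both feature-selection routines below use it to
+ # compute each gene's expected size given its mean expression.
+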
+ def NBumiFeatureSelectionHighVarCPU(fit: dict) -> pd.DataFrame:
+     start_time = time.perf_counter()
+     print("FUNCTION: NBumiFeatureSelectionHighVarCPU()")
+
+     vals = fit['vals']
+     coeffs = NBumiFitDispVsMeanCPU(fit, suppress_plot=True)
+     mean_expression = vals['tjs'].values / vals['nc']
+
+     with np.errstate(divide='ignore', invalid='ignore'):
+         log_mean_expression = np.log(mean_expression)
+         log_mean_expression[np.isneginf(log_mean_expression)] = 0
+         exp_size = np.exp(coeffs[0] + coeffs[1] * log_mean_expression)
+         res = np.log(fit['sizes'].values) - np.log(exp_size)
+
+     # A smaller size means higher dispersion, so the most negative residuals are
+     # the most variable genes; the ascending sort puts them first.
+     results_df = pd.DataFrame({'Gene': fit['sizes'].index, 'Residual': res})
+     final_table = results_df.sort_values(by='Residual', ascending=True)
+
+     end_time = time.perf_counter()
+     print(f"Total time: {end_time - start_time:.4f} seconds.\n")
+     return final_table
+
+ def NBumiFeatureSelectionCombinedDropCPU(
+     fit: dict,
+     raw_filename: str,
+     method="fdr_bh",
+     qval_thresh=0.05,
+     mode: str = "auto",
+     manual_target: int = 3000
+ ) -> pd.DataFrame:
+
+     start_time = time.perf_counter()
+     print(f"FUNCTION: NBumiFeatureSelectionCombinedDropCPU() | FILE: {raw_filename}")
+
+     ng_filtered = fit['vals']['ng']
+
+     with h5py.File(raw_filename, 'r') as f:
+         indptr_cpu = f['X']['indptr'][:]
+         total_rows = len(indptr_cpu) - 1
+
+     device = ControlDevice(indptr=indptr_cpu, total_rows=total_rows, n_genes=ng_filtered, mode=mode, manual_target=manual_target)
+     nc = device.total_rows
+
+     print("Phase [1/3]: Initializing arrays...")
+     vals = fit['vals']
+     coeffs = NBumiFitDispVsMeanCPU(fit, suppress_plot=True)
+
+     tjs = vals['tjs'].values
+     tis = vals['tis'].values
+     total = vals['total']
+
+     mean_expression = vals['tjs'].values / nc
+     with np.errstate(divide='ignore'):
+         exp_size = np.exp(coeffs[0] + coeffs[1] * np.log(mean_expression))
+
+     # Pre-allocate accumulators
+     p_sum = np.zeros(ng_filtered, dtype=np.float64)
+     p_var_sum = np.zeros(ng_filtered, dtype=np.float64)
+     print("Phase [1/3]: COMPLETE")
+
+     print("Phase [2/3]: Calculating dropout stats (Virtual)...")
+
+     current_row = 0
+     while current_row < nc:
+         # Dense mode allows Numba to rip through the data
+         end_row = device.get_next_chunk(current_row, mode='dense', overhead_multiplier=1.1)
+         if end_row is None or end_row <= current_row:
+             break
+
+         chunk_size = end_row - current_row
+         print(f"Phase [2/3]: Processing {end_row} of {nc} | Chunk: {chunk_size}", end='\r')
+
+         tis_chunk = tis[current_row:end_row]
+         work_matrix = np.empty((chunk_size, ng_filtered), dtype=np.float64)
+
+         # Call the Numba kernel: expected dropout probability per (cell, gene)
+         dropout_prob_kernel_cpu(
+             tjs,
+             tis_chunk,
+             total,
+             exp_size,
+             work_matrix
+         )
+
+         p_sum += work_matrix.sum(axis=0)
+
+         # In-place variance calc
+         dropout_variance_inplace_cpu(work_matrix)
+         p_var_sum += work_matrix.sum(axis=0)
+
+         current_row = end_row
+
+     print(f"\nPhase [2/3]: COMPLETE {' ' * 50}")
+
+     print("Phase [3/3]: Statistical testing...")
+
+     droprate_exp = p_sum / nc
+     droprate_exp_err = np.sqrt(p_var_sum / (nc**2))
+     droprate_obs = vals['djs'].values / nc
+
+     diff = droprate_obs - droprate_exp
+     combined_err = np.sqrt(droprate_exp_err**2 + (droprate_obs * (1 - droprate_obs) / nc))
+
+     with np.errstate(divide='ignore', invalid='ignore'):
+         Zed = diff / combined_err
+
+     pvalue = norm.sf(Zed)
+
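+     # One-sided z-test per gene: Z = (obs - exp) / SE with
+     # SE^2 = Var(expected rate) + obs * (1 - obs) / nc; norm.sf(Z) keeps genes
+     # that drop out MORE often than the NB fit predicts.
+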
+     results_df = pd.DataFrame({'Gene': vals['tjs'].index, 'p.value': pvalue, 'effect_size': diff})
+     results_df = results_df.sort_values(by='p.value')
+
+     qval = multipletests(results_df['p.value'].fillna(1), method=method)[1]
+     results_df['q.value'] = qval
+     final_table = results_df[results_df['q.value'] < qval_thresh]
+
+     print("Phase [3/3]: COMPLETE")
+     end_time = time.perf_counter()
+     print(f"Total time: {end_time - start_time:.2f} seconds.\n")
+
+     return final_table[['Gene', 'effect_size', 'p.value', 'q.value']]
+
+ def NBumiCombinedDropVolcanoCPU(results_df: pd.DataFrame, qval_thresh=0.05, effect_size_thresh=0.25, top_n_genes=10, suppress_plot=False, plot_filename=None):
+     start_time = time.perf_counter()
+     print("FUNCTION: NBumiCombinedDropVolcanoCPU()")
+
+     # Standard Matplotlib code - safe for CPU
+     df = results_df.copy()
+     # Replace exact-zero q-values so -log10 stays finite.
+     if (df['q.value'] == 0).any():
+         non_zero_min = df[df['q.value'] > 0]['q.value'].min()
+         df['q.value'] = df['q.value'].replace(0, non_zero_min)
+
+     df['-log10_qval'] = -np.log10(df['q.value'])
+     df['color'] = 'grey'
+     df.loc[(df['q.value'] < qval_thresh) & (df['effect_size'] > effect_size_thresh), 'color'] = 'red'
+     df.loc[(df['q.value'] < qval_thresh) & (df['effect_size'] < -effect_size_thresh), 'color'] = 'blue'
+
+     plt.figure(figsize=(10, 8))
+     plt.scatter(df['effect_size'], df['-log10_qval'], c=df['color'], s=10, alpha=0.6, edgecolors='none')
+
+     plt.axvline(x=effect_size_thresh, linestyle='--', color='grey', linewidth=0.8)
+     plt.axvline(x=-effect_size_thresh, linestyle='--', color='grey', linewidth=0.8)
+     plt.axhline(y=-np.log10(qval_thresh), linestyle='--', color='grey', linewidth=0.8)
+
+     # Label the most significant genes
+     top_genes = df.nsmallest(top_n_genes, 'q.value')
+     for _, row in top_genes.iterrows():
+         plt.text(row['effect_size'], row['-log10_qval'], row['Gene'], fontsize=9, fontweight='bold')
+
+     plt.title('Volcano Plot: Dropout Rate vs Significance (CPU)')
+     plt.xlabel('Effect Size (Observed - Expected Dropout Rate)')
+     plt.ylabel('-log10 (FDR Adjusted p-value)')
+     plt.grid(True, linestyle='--', alpha=0.3)
+     # Grab the Axes handle before the figure is closed below.
+     ax = plt.gca()
+
+     if plot_filename:
+         print(f"Saving plot to: {plot_filename}")
+         plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
+
+     if not suppress_plot:
+         plt.show()
+
+     plt.close()
+
+     end_time = time.perf_counter()
+     print(f"Total time: {end_time - start_time:.2f} seconds.\n")
+     return ax
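
For orientation, a minimal end-to-end sketch of how the functions added in this file fit together. The file names (data.h5ad, mask.pkl, volcano.png) are placeholders, not paths shipped with the package, and this assumes numba and the ControlDeviceCPU dependencies are installed:

    from m3Drop.CoreCPU import (
        ConvertDataSparseCPU,
        hidden_calc_valsCPU,
        NBumiFitModelCPU,
        NBumiFeatureSelectionCombinedDropCPU,
        NBumiCombinedDropVolcanoCPU,
    )

    # Stage 1: one sparse pass over the .h5ad file; pickles the expressed-gene mask.
    ConvertDataSparseCPU("data.h5ad", "mask.pkl")

    # Stage 2: per-cell / per-gene totals and dropout counts, then the NB fit.
    stats = hidden_calc_valsCPU("data.h5ad", "mask.pkl")
    fit = NBumiFitModelCPU("data.h5ad", "mask.pkl", stats)

    # Keep genes that drop out significantly more often than the fit predicts,
    # then plot effect size against FDR-adjusted significance.
    hits = NBumiFeatureSelectionCombinedDropCPU(fit, "data.h5ad", qval_thresh=0.05)
    NBumiCombinedDropVolcanoCPU(hits, plot_filename="volcano.png")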