M3Drop 0.4.42__py3-none-any.whl → 0.4.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,401 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import h5py
5
+ import os
6
+ import time
7
+ import pickle
8
+ import gc
9
+ from scipy import sparse
10
+ from scipy import stats
11
+ import anndata
12
+
13
+ import statsmodels.api as sm
14
+ from scipy.stats import norm
15
+ from statsmodels.stats.multitest import multipletests
16
+
17
+ # [FIX] Strict Relative Imports
18
+ from .ControlDeviceCPU import ControlDevice
19
+ from .CoreCPU import hidden_calc_valsCPU, NBumiFitModelCPU, NBumiFitDispVsMeanCPU, dropout_prob_kernel_cpu
20
+
21
+ # ==========================================
22
+ # DIAGNOSTICS & COMPARISON (CPU)
23
+ # ==========================================
24
+
25
def NBumiFitBasicModelCPU(
    filename: str,
    stats: dict,
    mask_filename: str = None,
    mode: str = "auto",
    manual_target: int = 3000,
    phase_label: str = "Phase [1/1]",
    desc_label: str = None
) -> dict:
    """
    Fits the Basic Model by calculating Normalized Variance ON-THE-FLY (CPU Optimized).
    STRICT FLOAT64 ENFORCEMENT.

    Streams the CSR matrix stored under ``X`` in ``filename`` in row chunks
    sized by ControlDevice, normalizes each chunk by per-cell size factors
    (cell total / median cell total), rounds, and accumulates first and
    second moments to derive per-gene variance and NB size estimates.

    Parameters
    ----------
    filename : str
        Path to an HDF5/.h5ad file whose ``X`` group holds CSR ``data``,
        ``indices``, ``indptr`` and a ``shape`` attribute.
    stats : dict
        Precomputed statistics; must provide 'tis' (per-cell totals,
        pd.Series) and 'tjs' (per-gene totals, pd.Series).
    mask_filename : str, optional
        Path to a pickled boolean gene mask; all genes are kept when absent.
        NOTE(review): pickle.load executes arbitrary code — the mask file is
        assumed to be a trusted local artifact produced by this pipeline.
    mode, manual_target :
        Forwarded to ControlDevice to size the streaming chunks.
    phase_label, desc_label : str
        Progress-reporting labels for console output.

    Returns
    -------
    dict
        'var_obs' : pd.Series of normalized per-gene variance,
        'sizes'   : pd.Series of NB size (dispersion) estimates,
        'vals'    : the input ``stats`` passed through.
        Both Series are indexed by ``stats['tjs'].index`` — assumes that
        index already corresponds to the masked gene set; TODO confirm
        against the caller.
    """
    # 1. Get Raw Dimensions & Setup ControlDevice
    with h5py.File(filename, 'r') as f:
        indptr_cpu = f['X']['indptr'][:]
        total_rows = len(indptr_cpu) - 1
        raw_ng = f['X'].attrs['shape'][1]

    device = ControlDevice(
        indptr=indptr_cpu,
        total_rows=total_rows,
        n_genes=raw_ng,
        mode=mode,
        manual_target=manual_target
    )
    nc = device.total_rows

    if desc_label:
        print(f"{phase_label}: {desc_label}")

    # 2. Load Mask (boolean vector over the raw gene axis)
    if mask_filename and os.path.exists(mask_filename):
        with open(mask_filename, 'rb') as f:
            mask = pickle.load(f)
    else:
        mask = np.ones(raw_ng, dtype=bool)

    filtered_ng = int(np.sum(mask))

    # 3. Pre-calculate Size Factors (cell total / median non-zero total)
    cell_sums = stats['tis'].values
    median_sum = np.median(cell_sums[cell_sums > 0])

    # [FLOAT64] Explicit enforcement; zero-sum cells keep factor 1.0 so the
    # reciprocal below never divides by zero.
    size_factors = np.ones_like(cell_sums, dtype=np.float64)
    non_zero_mask = cell_sums > 0
    size_factors[non_zero_mask] = cell_sums[non_zero_mask] / median_sum

    # 4. Init Accumulators for streaming first/second moments
    sum_norm_x = np.zeros(filtered_ng, dtype=np.float64)
    sum_norm_sq = np.zeros(filtered_ng, dtype=np.float64)

    with h5py.File(filename, 'r') as f_in:
        h5_indptr = f_in['X']['indptr']
        h5_data = f_in['X']['data']
        h5_indices = f_in['X']['indices']

        current_row = 0
        while current_row < nc:
            # CPU prefers dense chunks if they fit in L3, but sparse is safer for RAM.
            # We use 'dense' mode here because we convert to dense for normalization anyway.
            end_row = device.get_next_chunk(current_row, mode='dense', overhead_multiplier=1.5)
            if end_row is None or end_row <= current_row:
                break

            chunk_size = end_row - current_row
            print(f"{phase_label}: Processing {end_row} of {nc} | Chunk: {chunk_size}", end='\r')

            start_idx, end_idx = h5_indptr[current_row], h5_indptr[end_row]
            if start_idx == end_idx:
                # Entirely-empty rows: nothing to accumulate for this chunk.
                current_row = end_row
                continue

            # [FLOAT64] Load Raw Chunk; indptr is rebased to chunk-local offsets.
            data = np.array(h5_data[start_idx:end_idx], dtype=np.float64)
            indices = np.array(h5_indices[start_idx:end_idx])
            indptr = np.array(h5_indptr[current_row:end_row+1] - h5_indptr[current_row])

            # Reconstruct CSR & Filter to the masked gene set
            raw_chunk = sparse.csr_matrix((data, indices, indptr), shape=(chunk_size, raw_ng))
            filtered_chunk = raw_chunk[:, mask]

            # Normalization (Vectorized CPU): left-multiply by diag(1/sf)
            sf_chunk = size_factors[current_row:end_row]
            recip_sf = 1.0 / sf_chunk
            D = sparse.diags(recip_sf)
            norm_chunk = D.dot(filtered_chunk)

            # Rounding (in-place on the sparse data array)
            np.round(norm_chunk.data, out=norm_chunk.data)

            # Accumulate; dense summation is fine at L3-sized chunks.
            norm_dense = norm_chunk.toarray()
            sum_norm_x += norm_dense.sum(axis=0)
            sum_norm_sq += (norm_dense ** 2).sum(axis=0)

            current_row = end_row

    # Final Calculations: Var = E[X^2] - E[X]^2, then NB size via
    # size = mean^2 / (var - mean) where over-dispersion exists.
    mean_norm = sum_norm_x / nc
    mean_sq_norm = sum_norm_sq / nc
    var_norm = mean_sq_norm - (mean_norm ** 2)

    denom = var_norm - mean_norm
    sizes = np.full(filtered_ng, 1000.0, dtype=np.float64)
    valid_mask = denom > 1e-6
    sizes[valid_mask] = mean_norm[valid_mask]**2 / denom[valid_mask]

    # Filtering outliers: replace NaN/non-positive sizes with a large cap.
    # [FIX] Guard against an empty slice (e.g. no genes survive the mask),
    # where np.nanmax would raise ValueError instead of returning NaN.
    finite_sizes = sizes[sizes < 1e6]
    if finite_sizes.size:
        with np.errstate(invalid='ignore'):
            max_size_val = np.nanmax(finite_sizes) * 10
    else:
        max_size_val = np.nan  # falls through to the 1000.0 default below

    if np.isnan(max_size_val) or max_size_val == 0:
        max_size_val = 1000.0
    sizes[np.isnan(sizes) | (sizes <= 0)] = max_size_val
    sizes[sizes < 1e-10] = 1e-10

    print("")
    print(f"{phase_label}: COMPLETE")

    return {
        'var_obs': pd.Series(var_norm, index=stats['tjs'].index),
        'sizes': pd.Series(sizes, index=stats['tjs'].index),
        'vals': stats
    }
155
+
156
def NBumiCheckFitFSCPU(
    filename: str,
    fit: dict,
    mode: str = "auto",
    manual_target: int = 3000,
    suppress_plot=False,
    plot_filename=None,
    phase_label="Phase [1/1]",
    desc_label: str = None
) -> dict:
    """
    Calculates expected dropouts using NUMBA KERNEL on CPU.

    Streams over cells in ControlDevice-sized chunks and evaluates the
    per-cell/per-gene dropout probability kernel, accumulating the expected
    dropout count per gene ('rowPs') and per cell ('colPs').

    Parameters
    ----------
    filename : str
        Path to the HDF5/.h5ad file; only ``X/indptr`` is read here (to size
        chunks) — the kernel works from the summary totals in ``fit``.
    fit : dict
        Model fit with 'vals' containing 'ng' (gene count), 'tjs' (gene
        totals), 'tis' (cell totals) and 'total' (grand total), plus the
        dispersion information consumed by NBumiFitDispVsMeanCPU.
    mode, manual_target :
        Forwarded to ControlDevice for chunk sizing.
    suppress_plot, plot_filename :
        Accepted for interface symmetry; not used in this function body.
    phase_label, desc_label : str
        Progress-reporting labels for console output.

    Returns
    -------
    dict
        'rowPs' : pd.Series, expected dropouts per gene (indexed like 'tjs'),
        'colPs' : pd.Series, expected dropouts per cell (indexed like 'tis').
    """
    vals = fit['vals']
    ng = vals['ng']

    # Only indptr is needed: it drives the chunk-size controller.
    with h5py.File(filename, 'r') as f:
        indptr_cpu = f['X']['indptr'][:]
        total_rows = len(indptr_cpu) - 1

    device = ControlDevice(
        indptr=indptr_cpu,
        total_rows=total_rows,
        n_genes=ng,
        mode=mode,
        manual_target=manual_target
    )
    nc = device.total_rows

    if desc_label:
        print(f"{phase_label}: {desc_label}")

    # Regression coefficients (intercept, slope) of dispersion vs. mean
    # in log space; used to smooth the per-gene size parameter.
    size_coeffs = NBumiFitDispVsMeanCPU(fit, suppress_plot=True)

    tjs = vals['tjs'].values.astype(np.float64)
    tis = vals['tis'].values.astype(np.float64)
    total = vals['total']

    # log of mean expression, with zero-mean genes left at log-value 0 so
    # the exp() below yields a finite smoothed size for them.
    mean_expression = tjs / nc
    log_mean_expression = np.zeros_like(mean_expression)
    valid_means = mean_expression > 0
    log_mean_expression[valid_means] = np.log(mean_expression[valid_means])
    smoothed_size = np.exp(size_coeffs[0] + size_coeffs[1] * log_mean_expression)

    # Accumulators: expected dropouts per gene / per cell.
    row_ps = np.zeros(ng, dtype=np.float64)
    col_ps = np.zeros(nc, dtype=np.float64)

    current_row = 0
    while current_row < nc:
        # Use dense mode for Numba efficiency
        end_row = device.get_next_chunk(current_row, mode='dense', overhead_multiplier=1.1)
        if end_row is None or end_row <= current_row: break

        chunk_size = end_row - current_row
        print(f"{phase_label}: Processing {end_row} of {nc} | Chunk: {chunk_size}", end='\r')

        tis_chunk = tis[current_row:end_row]

        # [CRITICAL] NUMBA KERNEL CALL
        # Prepare output buffer (one dropout probability per cell x gene).
        p_is_chunk = np.empty((chunk_size, ng), dtype=np.float64)

        dropout_prob_kernel_cpu(
            tjs,            # Gene totals
            tis_chunk,      # Cell totals (1D array, broadcasting handled inside kernel)
            total,          # Grand total
            smoothed_size,  # Smoothed NB size per gene
            p_is_chunk      # Output buffer, written in place by the kernel
        )

        # Sanitize kernel output: NaN -> 0, +inf -> 1, -inf -> 0.
        p_is_chunk = np.nan_to_num(p_is_chunk, nan=0.0, posinf=1.0, neginf=0.0)

        row_ps += p_is_chunk.sum(axis=0)
        col_ps[current_row:end_row] = p_is_chunk.sum(axis=1)

        current_row = end_row

    print("")
    print(f"{phase_label}: COMPLETE")

    return {
        'rowPs': pd.Series(row_ps, index=fit['vals']['tjs'].index),
        'colPs': pd.Series(col_ps, index=fit['vals']['tis'].index)
    }
241
+
242
def NBumiCompareModelsCPU(
    raw_filename: str,
    stats: dict,
    fit_adjust: dict,
    mask_filename: str = None,
    mode: str = "auto",
    manual_target: int = 3000,
    suppress_plot=False,
    plot_filename=None
) -> dict:
    """
    Compare the depth-adjusted model against the basic model (CPU path).

    Three phases: (1) fit the basic model on the raw matrix, (2) compute
    expected dropouts under the supplied depth-adjusted fit, (3) compute
    expected dropouts under the basic fit on a "virtual" dataset where
    every cell carries the mean sequencing depth. Each fit is scored by the
    sum of absolute deviations from the observed per-gene dropout rate, and
    both curves are plotted against the observations.

    Returns a dict with 'errors' (total absolute error per model) and
    'comparison_df' (per-gene observed vs. fitted dropout rates).
    """
    print(f"FUNCTION: NBumiCompareModelsCPU()")
    t_start = time.time()

    # Phase 1: basic (depth-independent) model fitted on the raw matrix.
    basic_fit = NBumiFitBasicModelCPU(
        raw_filename,
        stats,
        mask_filename=mask_filename,
        mode=mode,
        manual_target=manual_target,
        phase_label="Phase [1/3]",
        desc_label="Fitting Basic Model (Virtual)..."
    )

    # Phase 2: expected dropouts under the depth-adjusted model.
    adjusted_check = NBumiCheckFitFSCPU(
        raw_filename,
        fit_adjust,
        mode=mode,
        manual_target=manual_target,
        suppress_plot=True,
        phase_label="Phase [2/3]",
        desc_label="Calculating Depth-Adjusted Dropouts..."
    )

    # Phase 3: expected dropouts under the basic model, evaluated on a
    # virtual dataset where every cell has the mean total depth.
    virtual_stats = stats.copy()
    depth = stats['total'] / stats['nc']
    virtual_stats['tis'] = pd.Series(
        np.full(stats['nc'], depth),
        index=stats['tis'].index
    )

    basic_eval_fit = {
        'sizes': basic_fit['sizes'],
        'vals': virtual_stats,
        'var_obs': basic_fit['var_obs']
    }

    basic_check = NBumiCheckFitFSCPU(
        raw_filename,
        basic_eval_fit,
        mode=mode,
        manual_target=manual_target,
        suppress_plot=True,
        phase_label="Phase [3/3]",
        desc_label="Calculating Basic Dropouts..."
    )

    # Score both fits against the observed dropout rate.
    n_cells = stats['nc']
    mean_expr = stats['tjs'] / n_cells
    observed = stats['djs'] / n_cells
    adj_fit_rate = adjusted_check['rowPs'] / n_cells
    bas_fit_rate = basic_check['rowPs'] / n_cells

    err_adj = np.sum(np.abs(adj_fit_rate - observed))
    err_bas = np.sum(np.abs(bas_fit_rate - observed))

    comparison_df = pd.DataFrame({
        'mean_expr': mean_expr,
        'observed': observed,
        'adj_fit': adj_fit_rate,
        'bas_fit': bas_fit_rate
    })

    # Plot observed vs. both fitted dropout curves; subsample every other
    # point above 20k genes to keep the scatter responsive.
    plt.figure(figsize=(10, 6))
    order = np.argsort(mean_expr.values)
    shown = order[::2] if len(mean_expr) > 20000 else order

    layers = [
        (observed, 'black', 0.5, 'Observed'),
        (bas_fit_rate, 'purple', 0.6, f'Basic Fit (Error: {err_bas:.2f})'),
        (adj_fit_rate, 'goldenrod', 0.7, f'Depth-Adjusted Fit (Error: {err_adj:.2f})'),
    ]
    for series, color, alpha, label in layers:
        plt.scatter(mean_expr.iloc[shown], series.iloc[shown],
                    c=color, s=3, alpha=alpha, label=label)

    plt.xscale('log')
    plt.xlabel("Mean Expression")
    plt.ylabel("Dropout Rate")
    plt.title("M3Drop Model Comparison (CPU)")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.3)

    if plot_filename:
        plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
        print(f"Saving plot to: {plot_filename}")

    if not suppress_plot:
        plt.show()

    plt.close()

    print(f"Total time: {time.time() - t_start:.2f} seconds.\n")

    return {
        "errors": {"Depth-Adjusted": err_adj, "Basic": err_bas},
        "comparison_df": comparison_df
    }
356
+
357
def NBumiPlotDispVsMeanCPU(
    fit: dict,
    suppress_plot: bool = False,
    plot_filename: str = None
):
    """
    Plot per-gene dispersion ("sizes") against mean expression, overlaying
    the log-log regression line produced by NBumiFitDispVsMeanCPU.

    Parameters
    ----------
    fit : dict
        Fit result with 'sizes' (pd.Series) and 'vals' containing 'tjs'
        (gene totals) and 'nc' (cell count).
    suppress_plot : bool
        When True, the figure is not shown interactively.
    plot_filename : str, optional
        If given, the figure is also written to this path at 300 dpi.
    """
    print("FUNCTION: NBumiPlotDispVsMeanCPU()")
    t0 = time.time()

    vals = fit['vals']
    means = vals['tjs'].values / vals['nc']
    dispersions = fit['sizes'].values

    # Regression coefficients (intercept, slope) in log-log space.
    coeffs = NBumiFitDispVsMeanCPU(fit, suppress_plot=True)
    intercept = coeffs[0]
    slope = coeffs[1]

    # Evaluate the fitted line on a 100-point grid spanning the observed
    # (strictly positive) mean-expression range in log space.
    log_grid = np.linspace(
        np.log(means[means > 0].min()),
        np.log(means.max()),
        100
    )
    fitted = np.exp(intercept + slope * log_grid)

    plt.figure(figsize=(8, 6))
    plt.scatter(means, dispersions, label='Observed Dispersion', alpha=0.5, s=8)
    plt.plot(np.exp(log_grid), fitted, color='red', label='Regression Fit', linewidth=2)

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('Mean Expression')
    plt.ylabel('Dispersion Parameter (Sizes)')
    plt.title('Dispersion vs. Mean Expression (CPU)')
    plt.legend()
    plt.grid(True, which="both", linestyle='--', alpha=0.6)

    if plot_filename:
        plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
        print(f"Saving plot to: {plot_filename}")

    if not suppress_plot:
        plt.show()

    plt.close()

    print(f"Total time: {time.time() - t0:.2f} seconds.\n")