scdataloader 0.0.3__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,6 +9,7 @@ import scanpy as sc
  from anndata import AnnData
  from django.db import IntegrityError
  from scipy.sparse import csr_matrix
+ import os

  from scdataloader import utils as data_utils

@@ -18,6 +19,8 @@ FULL_LENGTH_ASSAYS = [
  "EFO:0008931",
  ]

+ MAXFILESIZE = 10_000_000_000
+

  class Preprocessor:
  """
@@ -30,23 +33,27 @@ class Preprocessor:
  filter_gene_by_counts: Union[int, bool] = False,
  filter_cell_by_counts: Union[int, bool] = False,
  normalize_sum: float = 1e4,
- keep_norm_layer: bool = False,
  subset_hvg: int = 0,
+ use_layer: Optional[str] = None,
+ is_symbol: bool = False,
  hvg_flavor: str = "seurat_v3",
  binning: Optional[int] = None,
  result_binned_key: str = "X_binned",
  length_normalize: bool = False,
- force_preprocess=False,
- min_dataset_size=100,
- min_valid_genes_id=10_000,
- min_nnz_genes=200,
- maxdropamount=2,
- madoutlier=5,
- pct_mt_outlier=8,
- batch_key=None,
- skip_validate=False,
+ force_preprocess: bool = False,
+ min_dataset_size: int = 100,
+ min_valid_genes_id: int = 10_000,
+ min_nnz_genes: int = 200,
+ maxdropamount: int = 50,
+ madoutlier: int = 5,
+ pct_mt_outlier: int = 8,
+ batch_key: Optional[str] = None,
+ skip_validate: bool = False,
  additional_preprocess: Optional[Callable[[AnnData], AnnData]] = None,
  additional_postprocess: Optional[Callable[[AnnData], AnnData]] = None,
+ do_postp: bool = True,
+ organisms: list[str] = ["NCBITaxon:9606", "NCBITaxon:10090"],
+ use_raw: bool = True,
  ) -> None:
  """
  Initializes the preprocessor and configures the workflow steps.
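For orientation, the reworked signature above can be exercised as in the following sketch. The import path and the argument values are illustrative assumptions, not defaults taken from the package; only the parameter names come from the diff.

    from scdataloader.preprocess import Preprocessor  # import path assumed, not shown in the diff

    preprocessor = Preprocessor(
        subset_hvg=2000,
        use_layer=None,                 # new in 1.0.1: read counts from a named layer instead of X
        is_symbol=False,                # new in 1.0.1: set True when var names are gene symbols
        batch_key="donor_id",           # illustrative obs column for HVG batching
        do_postp=True,                  # new in 1.0.1: gates normalization/PCA/Leiden/UMAP postprocessing
        organisms=["NCBITaxon:9606"],   # restrict to human; the default also allows mouse
        use_raw=True,                   # new in 1.0.1: copy adata.raw.X into X when present
    )
    # adata: an AnnData object holding raw counts
    processed = preprocessor(adata)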
@@ -67,14 +74,34 @@ class Preprocessor:
  binning (int, optional): Determines whether to bin the data into discrete values of number of bins provided.
  result_binned_key (str, optional): Specifies the key of :class:`~anndata.AnnData` to store the binned data.
  Defaults to "X_binned".
+ length_normalize (bool, optional): Determines whether to length normalize the data.
+ Defaults to False.
+ force_preprocess (bool, optional): Determines whether to bypass the check of raw counts.
+ Defaults to False.
+ min_dataset_size (int, optional): The minimum size required for a dataset to be kept.
+ Defaults to 100.
+ min_valid_genes_id (int, optional): The minimum number of valid genes to keep a dataset.
+ Defaults to 10_000.
+ min_nnz_genes (int, optional): The minimum number of non-zero genes to keep a cell.
+ Defaults to 200.
+ maxdropamount (int, optional): The maximum amount of dropped cells per dataset. (2 for 50% drop, 3 for 33% drop, etc.)
+ Defaults to 2.
+ madoutlier (int, optional): The maximum absolute deviation of the outlier samples.
+ Defaults to 5.
+ pct_mt_outlier (int, optional): The maximum percentage of mitochondrial genes outlier.
+ Defaults to 8.
+ batch_key (str, optional): The key of :class:`~anndata.AnnData.obs` to use for batch information.
+ This arg is used in the highly variable gene selection step.
+ skip_validate (bool, optional): Determines whether to skip the validation step.
+ Defaults to False.
  """
  self.filter_gene_by_counts = filter_gene_by_counts
  self.filter_cell_by_counts = filter_cell_by_counts
  self.normalize_sum = normalize_sum
- self.keep_norm_layer = keep_norm_layer
  self.subset_hvg = subset_hvg
  self.hvg_flavor = hvg_flavor
  self.binning = binning
+ self.organisms = organisms
  self.result_binned_key = result_binned_key
  self.additional_preprocess = additional_preprocess
  self.additional_postprocess = additional_postprocess
@@ -88,45 +115,71 @@ class Preprocessor:
  self.batch_key = batch_key
  self.length_normalize = length_normalize
  self.skip_validate = skip_validate
+ self.use_layer = use_layer
+ self.is_symbol = is_symbol
+ self.do_postp = do_postp
+ self.use_raw = use_raw

  def __call__(self, adata) -> AnnData:
+ if adata[0].obs.organism_ontology_term_id.iloc[0] not in self.organisms:
+ raise ValueError(
+ "we cannot work with this organism",
+ adata[0].obs.organism_ontology_term_id.iloc[0],
+ )
  if self.additional_preprocess is not None:
  adata = self.additional_preprocess(adata)
- if adata.raw is not None:
+ if adata.raw is not None and self.use_raw:
  adata.X = adata.raw.X
  del adata.raw
+ if self.use_layer is not None:
+ adata.X = adata.layers[self.use_layer]
  if adata.layers is not None:
+ if "counts" in adata.layers.keys():
+ if np.abs(adata[:50_000].X.astype(int) - adata[:50_000].X).sum():
+ print("X was not raw counts, using 'counts' layer")
+ adata.X = adata.layers["counts"].copy()
+ print("Dropping layers: ", adata.layers.keys())
  del adata.layers
  if len(adata.varm.keys()) > 0:
  del adata.varm
- if len(adata.obsm.keys()) > 0:
+ if len(adata.obsm.keys()) > 0 and self.do_postp:
  del adata.obsm
- if len(adata.obsp.keys()) > 0:
+ if len(adata.obsp.keys()) > 0 and self.do_postp:
  del adata.obsp
  if len(adata.uns.keys()) > 0:
  del adata.uns
  if len(adata.varp.keys()) > 0:
  del adata.varp
  # check that it is a count
- if (
- np.abs(adata.X.astype(int) - adata.X).sum() and not self.force_preprocess
- ): # check if likely raw data
- raise ValueError(
- "Data is not raw counts, please check layers, find raw data, or bypass with force_preprocess"
- )
+ print("checking raw counts")
+ if np.abs(
+ adata[:50_000].X.astype(int) - adata[:50_000].X
+ ).sum(): # check if likely raw data
+ if not self.force_preprocess:
+ raise ValueError(
+ "Data is not raw counts, please check layers, find raw data, or bypass with force_preprocess"
+ )
+ else:
+ print(
+ "Data is not raw counts, please check layers, find raw data, or bypass with force_preprocess"
+ )
  # please check layers
  # if not available count drop
+ prevsize = adata.shape[0]
+ # dropping non primary
+ if "is_primary_data" in adata.obs.columns:
+ adata = adata[adata.obs.is_primary_data]
+ if adata.shape[0] < self.min_dataset_size:
+ raise Exception("Dataset dropped due to too many secondary cells")
+ print(
+ "removed {} non primary cells, {} remaining".format(
+ prevsize - adata.shape[0], adata.shape[0]
+ )
+ )
  # # cleanup and dropping low expressed genes and unexpressed cells
  prevsize = adata.shape[0]
  adata.obs["nnz"] = np.array(np.sum(adata.X != 0, axis=1).flatten())[0]
- adata = adata[
- (adata.obs["nnz"] > self.min_nnz_genes)
- # or if slide-seq
- | (
- (adata.obs.assay_ontology_term_id == "EFO:0030062")
- & (adata.obs["nnz"] > (self.min_nnz_genes / 3))
- )
- ]
+ adata = adata[(adata.obs["nnz"] > self.min_nnz_genes)]
  if self.filter_gene_by_counts:
  sc.pp.filter_genes(adata, min_counts=self.filter_gene_by_counts)
  if self.filter_cell_by_counts:
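The raw-count check above uses a simple heuristic: raw counts are integers, so a nonzero difference between X and its integer cast on the first 50,000 cells signals already-normalized data. A minimal standalone sketch of the same idea (the function name and the sparse handling are illustrative, not from the package):

    import numpy as np
    from scipy.sparse import issparse

    def looks_like_raw_counts(X, n_cells=50_000) -> bool:
        """Heuristic check: True when the first n_cells rows contain only integer values."""
        sub = X[:n_cells]
        if issparse(sub):
            sub = sub.data  # only the stored (nonzero) values matter for the integer test
        sub = np.asarray(sub)
        return np.abs(sub.astype(int) - sub).sum() == 0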
@@ -145,12 +198,29 @@ class Preprocessor:
  "Dataset dropped due to low expressed genes and unexpressed cells: current size: "
  + str(adata.shape[0])
  )
- # dropping non primary
- adata = adata[adata.obs.is_primary_data]
- if adata.shape[0] < self.min_dataset_size:
- raise ValueError(
- "Dataset dropped because contains too many secondary cells"
+ print(
+ "filtered out {} cells, {} remaining".format(
+ prevsize - adata.shape[0], adata.shape[0]
+ )
+ )
+
+ if self.is_symbol or not adata.var.index.str.contains("ENSG").any():
+ if not adata.var.index.str.contains("ENSG").any():
+ print("No ENSG genes found, assuming gene symbols...")
+ genesdf["ensembl_gene_id"] = genesdf.index
+ var = (
+ adata.var.merge(
+ genesdf.drop_duplicates("symbol").set_index("symbol", drop=False),
+ left_index=True,
+ right_index=True,
+ how="inner",
+ )
+ .sort_values(by="ensembl_gene_id")
+ .set_index("ensembl_gene_id")
  )
+ adata = adata[:, var["symbol"]]
+ adata.var = var
+ genesdf = genesdf.set_index("ensembl_gene_id")

  intersect_genes = set(adata.var.index).intersection(set(genesdf.index))
  print(f"Removed {len(adata.var.index) - len(intersect_genes)} genes.")
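The symbol-to-Ensembl remapping added above expects a gene reference table (genesdf, built elsewhere in the package) with a symbol column and Ensembl IDs. A self-contained sketch of the same merge using a toy reference table (the column names mirror the diff; the data is made up):

    import anndata as ad
    import numpy as np
    import pandas as pd

    # toy stand-in for genesdf: indexed by gene symbol, carrying the Ensembl ID
    genesdf = pd.DataFrame(
        {"symbol": ["TP53", "BRCA1"], "ensembl_gene_id": ["ENSG00000141510", "ENSG00000012048"]}
    ).set_index("symbol", drop=False)

    adata = ad.AnnData(X=np.ones((3, 2)), var=pd.DataFrame(index=["TP53", "BRCA1"]))

    # inner-merge var on symbols, then reindex genes on Ensembl IDs
    var = (
        adata.var.merge(genesdf, left_index=True, right_index=True, how="inner")
        .sort_values(by="ensembl_gene_id")
        .set_index("ensembl_gene_id")
    )
    adata = adata[:, var["symbol"]]  # reorder columns to match the merged table
    adata.var = var                  # var is now indexed by Ensembl gene ID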
@@ -169,36 +239,39 @@ class Preprocessor:
  # do a validation function
  adata.uns["unseen_genes"] = list(unseen)
  if not self.skip_validate:
+ print("validating")
  data_utils.validate(adata, organism=adata.obs.organism_ontology_term_id[0])
- # length normalization
- if (
- adata.obs["assay_ontology_term_id"].isin(FULL_LENGTH_ASSAYS).any()
- and self.length_normalize
- ):
- subadata = data_utils.length_normalize(
- adata[adata.obs["assay_ontology_term_id"].isin(FULL_LENGTH_ASSAYS)],
- )
+ # length normalization
+ if (
+ adata.obs["assay_ontology_term_id"].isin(FULL_LENGTH_ASSAYS).any()
+ and self.length_normalize
+ ):
+ print("doing length norm")
+ subadata = data_utils.length_normalize(
+ adata[adata.obs["assay_ontology_term_id"].isin(FULL_LENGTH_ASSAYS)],
+ )

- adata = ad.concat(
- [
- adata[
- ~adata.obs["assay_ontology_term_id"].isin(FULL_LENGTH_ASSAYS)
+ adata = ad.concat(
+ [
+ adata[
+ ~adata.obs["assay_ontology_term_id"].isin(
+ FULL_LENGTH_ASSAYS
+ )
+ ],
+ subadata,
  ],
- subadata,
- ],
- axis=0,
- join="outer",
- merge="only",
- )
- # step 3: normalize total
- adata.layers["clean"] = sc.pp.log1p(
- sc.pp.normalize_total(adata, target_sum=self.normalize_sum, inplace=False)[
- "X"
- ]
- )
+ axis=0,
+ join="outer",
+ merge="only",
+ )

  # QC
+
  adata.var[genesdf.columns] = genesdf.loc[adata.var.index]
+ for name in ["stable_id", "created_at", "updated_at"]:
+ if name in adata.var.columns:
+ adata.var = adata.var.drop(columns=name)
+ print("starting QC")
  sc.pp.calculate_qc_metrics(
  adata, qc_vars=["mt", "ribo", "hb"], inplace=True, percent_top=[20]
  )
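Length normalization targets full-length protocols (the FULL_LENGTH_ASSAYS list, e.g. Smart-seq style assays) where longer transcripts accumulate more reads. The actual helper is data_utils.length_normalize and its implementation is not shown in this diff; the standard idea, as a rough standalone sketch with a hypothetical gene_lengths vector:

    import numpy as np
    from scipy.sparse import issparse

    def length_normalize(X, gene_lengths):
        """Scale each gene (column) by 1/length so full-length assays resemble per-molecule counts."""
        scale = 1.0 / np.asarray(gene_lengths, dtype=float)
        if issparse(X):
            return X.multiply(scale).tocsr()  # broadcasts the per-gene factor over columns
        return X * scale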
@@ -224,31 +297,38 @@ class Preprocessor:
  # raise Exception("More than 50% of the dataset has been dropped due to outliers.")
  # adata = adata[(~adata.obs.outlier) & (~adata.obs.mt_outlier)].copy()
  # remaining
- # step 5: subset hvg
- if self.subset_hvg:
- sc.pp.highly_variable_genes(
- adata,
- layer="clean",
- n_top_genes=self.subset_hvg,
- batch_key=self.batch_key,
- flavor=self.hvg_flavor,
- subset=False,
- )
+
  # based on the topometry paper https://www.biorxiv.org/content/10.1101/2022.03.14.484134v2
  # https://rapids-singlecell.readthedocs.io/en/latest/api/generated/rapids_singlecell.pp.pca.html#rapids_singlecell.pp.pca
-
- adata.obsm["clean_pca"] = sc.pp.pca(
- adata.layers["clean"],
- n_comps=300 if adata.shape[0] > 300 else adata.shape[0] - 2,
- )
- sc.pp.neighbors(adata, use_rep="clean_pca")
- sc.tl.leiden(adata, key_added="leiden_3", resolution=3.0)
- sc.tl.leiden(adata, key_added="leiden_2", resolution=2.0)
- sc.tl.leiden(adata, key_added="leiden_1", resolution=1.0)
- sc.tl.umap(adata)
- # additional
- if self.additional_postprocess is not None:
- adata = self.additional_postprocess(adata)
+ if self.do_postp:
+ print("normalize")
+ adata.layers["clean"] = sc.pp.log1p(
+ sc.pp.normalize_total(
+ adata, target_sum=self.normalize_sum, inplace=False
+ )["X"]
+ )
+ # step 5: subset hvg
+ if self.subset_hvg:
+ sc.pp.highly_variable_genes(
+ adata,
+ layer="clean",
+ n_top_genes=self.subset_hvg,
+ batch_key=self.batch_key,
+ flavor=self.hvg_flavor,
+ subset=False,
+ )
+ adata.obsm["clean_pca"] = sc.pp.pca(
+ adata.layers["clean"],
+ n_comps=300 if adata.shape[0] > 300 else adata.shape[0] - 2,
+ )
+ sc.pp.neighbors(adata, use_rep="clean_pca")
+ sc.tl.leiden(adata, key_added="leiden_3", resolution=3.0)
+ sc.tl.leiden(adata, key_added="leiden_2", resolution=2.0)
+ sc.tl.leiden(adata, key_added="leiden_1", resolution=1.0)
+ sc.tl.umap(adata)
+ # additional
+ if self.additional_postprocess is not None:
+ adata = self.additional_postprocess(adata)
  adata = adata[:, adata.var.sort_index().index]
  # create random ids for all cells
  adata.obs.index = [str(uuid4()) for _ in range(adata.shape[0])]
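One detail in the PCA call moved into the do_postp branch above: PCA cannot use more components than observations, so n_comps is capped at 300 and falls back to adata.shape[0] - 2 for very small datasets. A tiny worked example of that guard (the cell count is arbitrary):

    n_obs = 150                                   # e.g. a small dataset after filtering
    n_comps = 300 if n_obs > 300 else n_obs - 2   # -> 148, strictly below n_obs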
@@ -296,6 +376,7 @@ class Preprocessor:
  bin_edges.append(np.concatenate([[0], bins]))
  adata.layers[self.result_binned_key] = np.stack(binned_rows)
  adata.obsm["bin_edges"] = np.stack(bin_edges)
+ print("done")
  return adata


@@ -306,12 +387,14 @@ class LaminPreprocessor(Preprocessor):
  erase_prev_dataset: bool = False,
  cache: bool = True,
  stream: bool = False,
+ keep_files: bool = True,
  **kwargs,
  ):
  super().__init__(*args, **kwargs)
  self.erase_prev_dataset = erase_prev_dataset
  self.cache = cache
  self.stream = stream
+ self.keep_files = keep_files

  def __call__(
  self,
@@ -319,7 +402,7 @@ class LaminPreprocessor(Preprocessor):
  name="preprocessed dataset",
  description="preprocessed dataset using scprint",
  start_at=0,
- version="2",
+ version=2,
  ):
  """
  format controls the different input value wrapping, including categorical
@@ -334,49 +417,97 @@ class LaminPreprocessor(Preprocessor):
  all_ready_processed_keys = set()
  if self.cache:
  for i in ln.Artifact.filter(description=description):
- all_ready_processed_keys.add(i.initial_version.key)
+ all_ready_processed_keys.add(i.stem_uid)
  if isinstance(data, AnnData):
- return self.preprocess(data)
+ return super().__call__(data)
  elif isinstance(data, ln.Collection):
  for i, file in enumerate(data.artifacts.all()[start_at:]):
  # use the counts matrix
  print(i)
- if file.key in all_ready_processed_keys:
- print(f"{file.key} is already processed")
+ if file.stem_uid in all_ready_processed_keys:
+ print(f"{file.stem_uid} is already processed... not preprocessing")
  continue
  print(file)
- if file.backed().obs.is_primary_data.sum() == 0:
- print(f"{file.key} only contains non primary cells")
+ backed = file.backed()
+ if backed.obs.is_primary_data.sum() == 0:
+ print(f"{file.key} only contains non primary cells... dropping")
+ continue
+ if backed.shape[1] < 1000:
+ print(
+ f"{file.key} only contains less than 1000 genes and is likely not scRNAseq... dropping"
+ )
  continue
- adata = file.load(stream=self.stream)
+ if file.size <= MAXFILESIZE:
+ adata = file.load(stream=self.stream)
+ print(adata)
+ else:
+ badata = backed
+ print(badata)

- print(adata)
  try:
- adata = super().__call__(adata)
+ if file.size > MAXFILESIZE:
+ print(
+ f"dividing the dataset as it is too large: {file.size//1_000_000_000}Gb"
+ )
+ num_blocks = int(np.ceil(file.size / (MAXFILESIZE / 2)))
+ block_size = int(
+ (np.ceil(badata.shape[0] / 30_000) * 30_000) // num_blocks
+ )
+ print("num blocks ", num_blocks)
+ for i in range(num_blocks):
+ start_index = i * block_size
+ end_index = min((i + 1) * block_size, badata.shape[0])
+ block = badata[start_index:end_index].to_memory()
+ print(block)
+ block = super().__call__(block)
+ myfile = ln.Artifact(
+ block,
+ is_new_version_of=file,
+ description=description,
+ version=str(version) + "_s" + str(i),
+ )
+ myfile.save()
+ if self.keep_files:
+ files.append(myfile)
+ else:
+ del myfile
+ del block
+
+ else:
+ adata = super().__call__(adata)
+ myfile = ln.Artifact(
+ adata,
+ is_new_version_of=file,
+ description=description,
+ version=str(version),
+ )
+ myfile.save()
+ if self.keep_files:
+ files.append(myfile)
+ else:
+ del myfile
+ del adata

  except ValueError as v:
- if v.args[0].startswith(
- "Dataset dropped because contains too many secondary"
- ):
+ if v.args[0].startswith("we cannot work with this organism"):
  print(v)
  continue
  else:
  raise v
- for name in ["stable_id", "created_at", "updated_at"]:
- if name in adata.var.columns:
- adata.var = adata.var.drop(columns=name)
- myfile = ln.Artifact(
- adata,
- is_new_version_of=file,
- description=description,
- version=version,
- )
+ except Exception as e:
+ if e.args[0].startswith("Dataset dropped due to"):
+ print(e)
+ continue
+ else:
+ raise e
+
  # issues with KLlggfw6I6lvmbqiZm46
- myfile.save()
- files.append(myfile)
- dataset = ln.Collection(files, name=name, description=description)
- dataset.save()
- return dataset
+ if self.keep_files:
+ dataset = ln.Collection(files, name=name, description=description)
+ dataset.save()
+ return dataset
+ else:
+ return
  else:
  raise ValueError("Please provide either anndata or ln.Collection")
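The large-file branch above splits any artifact bigger than MAXFILESIZE (the 10 GB constant added at the top of the module) into blocks of cells, preprocesses each block in memory, and saves each block as its own versioned artifact ("2_s0", "2_s1", ...). The block arithmetic, worked through for a hypothetical 25 GB artifact with 1,200,000 cells (both numbers are illustrative):

    import numpy as np

    MAXFILESIZE = 10_000_000_000   # 10 GB, as defined near the top of the module
    file_size = 25_000_000_000     # hypothetical 25 GB artifact
    n_obs = 1_200_000              # hypothetical number of cells

    num_blocks = int(np.ceil(file_size / (MAXFILESIZE / 2)))             # ceil(25 / 5) = 5 blocks
    block_size = int((np.ceil(n_obs / 30_000) * 30_000) // num_blocks)   # 1_200_000 // 5 = 240_000 cells
    blocks = [(i * block_size, min((i + 1) * block_size, n_obs)) for i in range(num_blocks)]
    # each (start, end) slice is loaded with .to_memory(), preprocessed, and saved as version "2_s{i}"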
 
@@ -498,7 +629,7 @@ def additional_preprocess(adata):
  ].astype(str)
  adata.obs.loc[loc, "tissue_ontology_term_id"] = adata.obs.loc[
  loc, "tissue_ontology_term_id"
- ].str.replace(r" \(cell culture\)", "")
+ ].str.replace(" (cell culture)", "")

  loc = adata.obs["tissue_ontology_term_id"].str.contains("(organoid)", regex=False)
  if loc.sum() > 0:
@@ -508,7 +639,7 @@ def additional_preprocess(adata):
  ].astype(str)
  adata.obs.loc[loc, "tissue_ontology_term_id"] = adata.obs.loc[
  loc, "tissue_ontology_term_id"
- ].str.replace(r" \(organoid\)", "")
+ ].str.replace(" (organoid)", "")

  loc = adata.obs["tissue_ontology_term_id"].str.contains("CL:", regex=False)
  if loc.sum() > 0:
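Both replacements above drop the regex escaping in favor of a literal pattern. This matches the nearby str.contains(..., regex=False) calls, and with pandas 2.0+ Series.str.replace also defaults to regex=False, so the parenthesised suffix is matched as plain text. A quick standalone illustration (the ontology values are made up):

    import pandas as pd

    s = pd.Series(["UBERON:0002048 (cell culture)", "UBERON:0000955 (organoid)"])
    # pandas >= 2.0 defaults to regex=False, so "(...)" is treated as literal text, not a group
    print(s.str.replace(" (cell culture)", "").str.replace(" (organoid)", ""))
    # -> UBERON:0002048, UBERON:0000955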