scdataloader 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
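In summary, 0.0.3 decouples `Preprocessor` from lamindb: the `lb` argument and the direct `bionty` import are gone (gene loading moves to `data_utils.load_genes`, validation becomes opt-out via a new `skip_validate` flag), `normalize_total`/`log1p` collapse into a single `normalize_sum` target written to a `"clean"` layer, and the lamindb looping logic is split out into a new `LaminPreprocessor` subclass that operates on `ln.Collection` (renamed from `ln.Dataset`). A minimal sketch of the new single-AnnData call path, inferred only from the diff below; the import path and the input file are assumptions, not confirmed by the diff:

```python
# Sketch of the 0.0.3 API as inferred from this diff; the module path is assumed.
import anndata as ad
from scdataloader.preprocess import Preprocessor  # assumed location

# Hypothetical input; the preprocessor expects cellxgene-style obs columns
# such as organism_ontology_term_id and assay_ontology_term_id.
adata = ad.read_h5ad("my_raw_counts.h5ad")

prep = Preprocessor(
    filter_gene_by_counts=False,
    filter_cell_by_counts=False,
    normalize_sum=1e4,    # replaces normalize_total=... and log1p=... from 0.0.2
    subset_hvg=2000,      # now an int; 0.0.3 flags HVGs (subset=False) instead of subsetting
    skip_validate=True,   # new flag; replaces the removed `lb` instance argument
)
adata = prep(adata)       # Preprocessor.__call__ now takes an AnnData directly
```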
@@ -1,8 +1,7 @@
- from typing import Callable, Optional, Union
+ from typing import Any, Callable, Optional, Union
  from uuid import uuid4

  import anndata as ad
- import bionty as bt
  import lamindb as ln
  import numpy as np
  import pandas as pd
@@ -28,18 +27,15 @@ class Preprocessor:

      def __init__(
          self,
-         lb,
          filter_gene_by_counts: Union[int, bool] = False,
          filter_cell_by_counts: Union[int, bool] = False,
-         normalize_total: Union[float, bool] = False,
-         log1p: bool = False,
-         subset_hvg: Union[int, bool] = False,
+         normalize_sum: float = 1e4,
+         keep_norm_layer: bool = False,
+         subset_hvg: int = 0,
          hvg_flavor: str = "seurat_v3",
          binning: Optional[int] = None,
          result_binned_key: str = "X_binned",
          length_normalize: bool = False,
-         additional_preprocess: Optional[Callable[[AnnData], AnnData]] = None,
-         additional_postprocess: Optional[Callable[[AnnData], AnnData]] = None,
          force_preprocess=False,
          min_dataset_size=100,
          min_valid_genes_id=10_000,
@@ -48,10 +44,10 @@ class Preprocessor:
          madoutlier=5,
          pct_mt_outlier=8,
          batch_key=None,
-         erase_prev_dataset: bool = False,
-         cache: bool = True,
-         stream: bool = False,
-     ):
+         skip_validate=False,
+         additional_preprocess: Optional[Callable[[AnnData], AnnData]] = None,
+         additional_postprocess: Optional[Callable[[AnnData], AnnData]] = None,
+     ) -> None:
          """
          Initializes the preprocessor and configures the workflow steps.

@@ -60,7 +56,7 @@ class Preprocessor:
                  If int, filters genes with counts. Defaults to False.
              filter_cell_by_counts (int or bool, optional): Determines whether to filter cells by counts.
                  If int, filters cells with counts. Defaults to False.
-             normalize_total (float or bool, optional): Determines whether to normalize the total counts of each cell to a specific value.
+             normalize_sum (float or bool, optional): Determines whether to normalize the total counts of each cell to a specific value.
                  Defaults to 1e4.
              log1p (bool, optional): Determines whether to apply log1p transform to the normalized data.
                  Defaults to True.
@@ -74,8 +70,8 @@ class Preprocessor:
          """
          self.filter_gene_by_counts = filter_gene_by_counts
          self.filter_cell_by_counts = filter_cell_by_counts
-         self.normalize_total = normalize_total
-         self.log1p = log1p
+         self.normalize_sum = normalize_sum
+         self.keep_norm_layer = keep_norm_layer
          self.subset_hvg = subset_hvg
          self.hvg_flavor = hvg_flavor
          self.binning = binning
@@ -83,7 +79,6 @@ class Preprocessor:
          self.additional_preprocess = additional_preprocess
          self.additional_postprocess = additional_postprocess
          self.force_preprocess = force_preprocess
-         self.lb = lb
          self.min_dataset_size = min_dataset_size
          self.min_valid_genes_id = min_valid_genes_id
          self.min_nnz_genes = min_nnz_genes
@@ -91,79 +86,10 @@ class Preprocessor:
          self.madoutlier = madoutlier
          self.pct_mt_outlier = pct_mt_outlier
          self.batch_key = batch_key
-         self.erase_prev_dataset = erase_prev_dataset
          self.length_normalize = length_normalize
-         self.cache = cache
-         self.stream = stream
+         self.skip_validate = skip_validate

-     def __call__(
-         self,
-         data: Union[ln.Dataset, AnnData] = None,
-         name="preprocessed dataset",
-         description="preprocessed dataset using scprint",
-         start_at=0,
-     ):
-         """
-         format controls the different input value wrapping, including categorical
-         binned style, fixed-sum normalized counts, log1p fixed-sum normalized counts, etc.
-
-         Args:
-             adata (AnnData): The AnnData object to preprocess.
-             batch_key (str, optional): The key of AnnData.obs to use for batch information. This arg
-                 is used in the highly variable gene selection step.
-         """
-         files = []
-         all_ready_processed_keys = set()
-         if self.cache:
-             for i in ln.Artifact.filter(description="preprocessed by scprint"):
-                 all_ready_processed_keys.add(i.initial_version.key)
-         if isinstance(data, AnnData):
-             return self.preprocess(data)
-         elif isinstance(data, ln.Dataset):
-             for i, file in enumerate(data.artifacts.all()[start_at:]):
-                 # use the counts matrix
-                 print(i)
-                 if file.key in all_ready_processed_keys:
-                     print(f"{file.key} is already processed")
-                     continue
-                 print(file)
-                 if file.backed().obs.is_primary_data.sum() == 0:
-                     print(f"{file.key} only contains non primary cells")
-                     continue
-                 adata = file.load(stream=self.stream)
-
-                 print(adata)
-                 try:
-                     adata = self.preprocess(adata)
-
-                 except ValueError as v:
-                     if v.args[0].startswith(
-                         "Dataset dropped because contains too many secondary"
-                     ):
-                         print(v)
-                         continue
-                     else:
-                         raise v
-                 try:
-                     file.save()
-                 except IntegrityError as e:
-                     # UNIQUE constraint failed: lnschema_bionty_organism.ontology_id
-                     print(f"seeing {e}... continuing")
-                 myfile = ln.Artifact(
-                     adata,
-                     is_new_version_of=file,
-                     description="preprocessed by scprint",
-                 )
-                 # issues with KLlggfw6I6lvmbqiZm46
-                 myfile.save()
-                 files.append(myfile)
-             dataset = ln.Dataset(files, name=name, description=description)
-             dataset.save()
-             return dataset
-         else:
-             raise ValueError("Please provide either anndata or ln.Dataset")
-
-     def preprocess(self, adata: AnnData):
+     def __call__(self, adata) -> AnnData:
          if self.additional_preprocess is not None:
              adata = self.additional_preprocess(adata)
          if adata.raw is not None:
@@ -183,8 +109,7 @@ class Preprocessor:
          del adata.varp
          # check that it is a count
          if (
-             int(adata.X[:100].max()) != adata.X[:100].max()
-             and not self.force_preprocess
+             np.abs(adata.X.astype(int) - adata.X).sum() and not self.force_preprocess
          ):  # check if likely raw data
              raise ValueError(
                  "Data is not raw counts, please check layers, find raw data, or bypass with force_preprocess"
@@ -208,7 +133,7 @@ class Preprocessor:
              sc.pp.filter_cells(adata, min_counts=self.filter_cell_by_counts)
          # if lost > 50% of the dataset, drop dataset
          # load the genes
-         genesdf = self.load_genes(adata.obs.organism_ontology_term_id[0])
+         genesdf = data_utils.load_genes(adata.obs.organism_ontology_term_id.iloc[0])

          if prevsize / adata.shape[0] > self.maxdropamount:
              raise Exception(
@@ -230,9 +155,7 @@ class Preprocessor:
          intersect_genes = set(adata.var.index).intersection(set(genesdf.index))
          print(f"Removed {len(adata.var.index) - len(intersect_genes)} genes.")
          if len(intersect_genes) < self.min_valid_genes_id:
-             raise Exception(
-                 "Dataset dropped due to too many genes not mapping to it"
-             )
+             raise Exception("Dataset dropped due to too many genes not mapping to it")
          adata = adata[:, list(intersect_genes)]
          # marking unseen genes
          unseen = set(genesdf.index) - set(adata.var.index)
@@ -245,28 +168,21 @@ class Preprocessor:
          adata = ad.concat([adata, emptyda], axis=1, join="outer", merge="only")
          # do a validation function
          adata.uns["unseen_genes"] = list(unseen)
-         data_utils.validate(
-             adata, self.lb, organism=adata.obs.organism_ontology_term_id[0]
-         )
+         if not self.skip_validate:
+             data_utils.validate(adata, organism=adata.obs.organism_ontology_term_id[0])
          # length normalization
          if (
              adata.obs["assay_ontology_term_id"].isin(FULL_LENGTH_ASSAYS).any()
              and self.length_normalize
          ):
              subadata = data_utils.length_normalize(
-                 adata[
-                     adata.obs["assay_ontology_term_id"].isin(
-                         FULL_LENGTH_ASSAYS
-                     )
-                 ],
+                 adata[adata.obs["assay_ontology_term_id"].isin(FULL_LENGTH_ASSAYS)],
              )

              adata = ad.concat(
                  [
                      adata[
-                         ~adata.obs["assay_ontology_term_id"].isin(
-                             FULL_LENGTH_ASSAYS
-                         )
+                         ~adata.obs["assay_ontology_term_id"].isin(FULL_LENGTH_ASSAYS)
                      ],
                      subadata,
                  ],
@@ -275,10 +191,11 @@ class Preprocessor:
                  merge="only",
              )
          # step 3: normalize total
-         if self.normalize_total:
-             sc.pp.normalize_total(adata, target_sum=self.normalize_total)
-         if self.log1p and not is_log1p(adata):
-             sc.pp.log1p(adata)
+         adata.layers["clean"] = sc.pp.log1p(
+             sc.pp.normalize_total(adata, target_sum=self.normalize_sum, inplace=False)[
+                 "X"
+             ]
+         )

          # QC
          adata.var[genesdf.columns] = genesdf.loc[adata.var.index]
@@ -288,17 +205,15 @@ class Preprocessor:

          adata.obs["outlier"] = (
              data_utils.is_outlier(adata, "total_counts", self.madoutlier)
-             | data_utils.is_outlier(
-                 adata, "n_genes_by_counts", self.madoutlier
-             )
+             | data_utils.is_outlier(adata, "n_genes_by_counts", self.madoutlier)
              | data_utils.is_outlier(
                  adata, "pct_counts_in_top_20_genes", self.madoutlier
              )
          )

-         adata.obs["mt_outlier"] = data_utils.is_outlier(
-             adata, "pct_counts_mt", 3
-         ) | (adata.obs["pct_counts_mt"] > self.pct_mt_outlier)
+         adata.obs["mt_outlier"] = data_utils.is_outlier(adata, "pct_counts_mt", 3) | (
+             adata.obs["pct_counts_mt"] > self.pct_mt_outlier
+         )
          total_outliers = (adata.obs["outlier"] | adata.obs["mt_outlier"]).sum()
          total_cells = adata.shape[0]
          percentage_outliers = (total_outliers / total_cells) * 100
@@ -313,16 +228,20 @@ class Preprocessor:
          if self.subset_hvg:
              sc.pp.highly_variable_genes(
                  adata,
+                 layer="clean",
                  n_top_genes=self.subset_hvg,
                  batch_key=self.batch_key,
                  flavor=self.hvg_flavor,
-                 subset=True,
+                 subset=False,
              )
          # based on the topometry paper https://www.biorxiv.org/content/10.1101/2022.03.14.484134v2
          # https://rapids-singlecell.readthedocs.io/en/latest/api/generated/rapids_singlecell.pp.pca.html#rapids_singlecell.pp.pca
-         sc.pp.neighbors(
-             adata, n_pcs=500 if adata.shape[0] > 500 else adata.shape[0] - 2
+
+         adata.obsm["clean_pca"] = sc.pp.pca(
+             adata.layers["clean"],
+             n_comps=300 if adata.shape[0] > 300 else adata.shape[0] - 2,
          )
+         sc.pp.neighbors(adata, use_rep="clean_pca")
          sc.tl.leiden(adata, key_added="leiden_3", resolution=3.0)
          sc.tl.leiden(adata, key_added="leiden_2", resolution=2.0)
          sc.tl.leiden(adata, key_added="leiden_1", resolution=1.0)
@@ -341,9 +260,7 @@ class Preprocessor:
              print("Binning data ...")
              if not isinstance(self.binning, int):
                  raise ValueError(
-                     "Binning arg must be an integer, but got {}.".format(
-                         self.binning
-                     )
+                     "Binning arg must be an integer, but got {}.".format(self.binning)
                  )
              # NOTE: the first bin is always a spectial for zero
              n_bins = self.binning
@@ -381,21 +298,87 @@ class Preprocessor:
              adata.obsm["bin_edges"] = np.stack(bin_edges)
          return adata

-     def load_genes(self, organism):
-         genesdf = bt.Gene(
-             organism=self.lb.Organism.filter(ontology_id=organism).first().name
-         ).df()
-         genesdf = genesdf.drop_duplicates(subset="ensembl_gene_id")
-         genesdf = genesdf.set_index("ensembl_gene_id")
-         # mitochondrial genes
-         genesdf["mt"] = genesdf.symbol.astype(str).str.startswith("MT-")
-         # ribosomal genes
-         genesdf["ribo"] = genesdf.symbol.astype(str).str.startswith(
-             ("RPS", "RPL")
-         )
-         # hemoglobin genes.
-         genesdf["hb"] = genesdf.symbol.astype(str).str.contains(("^HB[^(P)]"))
-         return genesdf
+
+ class LaminPreprocessor(Preprocessor):
+     def __init__(
+         self,
+         *args,
+         erase_prev_dataset: bool = False,
+         cache: bool = True,
+         stream: bool = False,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         self.erase_prev_dataset = erase_prev_dataset
+         self.cache = cache
+         self.stream = stream
+
+     def __call__(
+         self,
+         data: Union[ln.Collection, AnnData] = None,
+         name="preprocessed dataset",
+         description="preprocessed dataset using scprint",
+         start_at=0,
+         version="2",
+     ):
+         """
+         format controls the different input value wrapping, including categorical
+         binned style, fixed-sum normalized counts, log1p fixed-sum normalized counts, etc.
+
+         Args:
+             adata (AnnData): The AnnData object to preprocess.
+             batch_key (str, optional): The key of AnnData.obs to use for batch information. This arg
+                 is used in the highly variable gene selection step.
+         """
+         files = []
+         all_ready_processed_keys = set()
+         if self.cache:
+             for i in ln.Artifact.filter(description=description):
+                 all_ready_processed_keys.add(i.initial_version.key)
+         if isinstance(data, AnnData):
+             return self.preprocess(data)
+         elif isinstance(data, ln.Collection):
+             for i, file in enumerate(data.artifacts.all()[start_at:]):
+                 # use the counts matrix
+                 print(i)
+                 if file.key in all_ready_processed_keys:
+                     print(f"{file.key} is already processed")
+                     continue
+                 print(file)
+                 if file.backed().obs.is_primary_data.sum() == 0:
+                     print(f"{file.key} only contains non primary cells")
+                     continue
+                 adata = file.load(stream=self.stream)
+
+                 print(adata)
+                 try:
+                     adata = super().__call__(adata)
+
+                 except ValueError as v:
+                     if v.args[0].startswith(
+                         "Dataset dropped because contains too many secondary"
+                     ):
+                         print(v)
+                         continue
+                     else:
+                         raise v
+                 for name in ["stable_id", "created_at", "updated_at"]:
+                     if name in adata.var.columns:
+                         adata.var = adata.var.drop(columns=name)
+                 myfile = ln.Artifact(
+                     adata,
+                     is_new_version_of=file,
+                     description=description,
+                     version=version,
+                 )
+                 # issues with KLlggfw6I6lvmbqiZm46
+                 myfile.save()
+                 files.append(myfile)
+             dataset = ln.Collection(files, name=name, description=description)
+             dataset.save()
+             return dataset
+         else:
+             raise ValueError("Please provide either anndata or ln.Collection")


  def is_log1p(adata: AnnData) -> bool:
@@ -494,7 +477,7 @@ def additional_preprocess(adata):
      adata.obs["cell_culture"] = False
      # if cell_type contains the word "(cell culture)" then it is a cell culture and we mark it as so and remove this from the cell type
      loc = adata.obs["cell_type_ontology_term_id"].str.contains(
-         "(cell culture)"
+         "(cell culture)", regex=False
      )
      if loc.sum() > 0:
          adata.obs["cell_type_ontology_term_id"] = adata.obs[
@@ -505,7 +488,9 @@ def additional_preprocess(adata):
              loc, "cell_type_ontology_term_id"
          ].str.replace(" (cell culture)", "")

-     loc = adata.obs["tissue_ontology_term_id"].str.contains("(cell culture)")
+     loc = adata.obs["tissue_ontology_term_id"].str.contains(
+         "(cell culture)", regex=False
+     )
      if loc.sum() > 0:
          adata.obs.loc[loc, "cell_culture"] = True
          adata.obs["tissue_ontology_term_id"] = adata.obs[
@@ -515,7 +500,7 @@ def additional_preprocess(adata):
              loc, "tissue_ontology_term_id"
          ].str.replace(r" \(cell culture\)", "")

-     loc = adata.obs["tissue_ontology_term_id"].str.contains("(organoid)")
+     loc = adata.obs["tissue_ontology_term_id"].str.contains("(organoid)", regex=False)
      if loc.sum() > 0:
          adata.obs.loc[loc, "cell_culture"] = True
          adata.obs["tissue_ontology_term_id"] = adata.obs[
@@ -525,7 +510,7 @@ def additional_preprocess(adata):
              loc, "tissue_ontology_term_id"
          ].str.replace(r" \(organoid\)", "")

-     loc = adata.obs["tissue_ontology_term_id"].str.contains("CL:")
+     loc = adata.obs["tissue_ontology_term_id"].str.contains("CL:", regex=False)
      if loc.sum() > 0:
          adata.obs["tissue_ontology_term_id"] = adata.obs[
              "tissue_ontology_term_id"
@@ -553,12 +538,8 @@ def additional_postprocess(adata):
      )  # + "_" + adata.obs['dataset_id'].astype(str)

      # if group is too small
-     okgroup = [
-         i for i, j in adata.obs["dpt_group"].value_counts().items() if j >= 10
-     ]
-     not_okgroup = [
-         i for i, j in adata.obs["dpt_group"].value_counts().items() if j < 3
-     ]
+     okgroup = [i for i, j in adata.obs["dpt_group"].value_counts().items() if j >= 10]
+     not_okgroup = [i for i, j in adata.obs["dpt_group"].value_counts().items() if j < 3]
      # set the group to empty
      adata.obs.loc[adata.obs["dpt_group"].isin(not_okgroup), "dpt_group"] = ""
      adata.obs["heat_diff"] = np.nan