scdataloader 1.8.1__tar.gz → 1.9.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: scdataloader
- Version: 1.8.1
+ Version: 1.9.0
  Summary: a dataloader for single cell data in lamindb
  Project-URL: repository, https://github.com/jkobject/scDataLoader
  Author-email: jkobject <jkobject@gmail.com>
@@ -14,13 +14,14 @@ Requires-Dist: cellxgene-census>=0.1.0
  Requires-Dist: django>=4.0.0
  Requires-Dist: harmonypy>=0.0.10
  Requires-Dist: ipykernel>=6.20.0
+ Requires-Dist: jupytext>=1.16.0
  Requires-Dist: lamindb[bionty,cellregistry,jupyter,ourprojects,zarr]<2,>=1.0.4
  Requires-Dist: leidenalg>=0.8.0
- Requires-Dist: lightning>=2.3.0
  Requires-Dist: matplotlib>=3.5.0
  Requires-Dist: numpy==1.26.0
  Requires-Dist: palantir>=1.3.3
  Requires-Dist: pandas>=2.0.0
+ Requires-Dist: pytorch-lightning>=2.3.0
  Requires-Dist: scikit-misc>=0.5.0
  Requires-Dist: seaborn>=0.11.0
  Requires-Dist: torch==2.2.0
@@ -1,6 +1,6 @@
  [project]
  name = "scdataloader"
- version = "1.8.1"
+ version = "1.9.0"
  description = "a dataloader for single cell data in lamindb"
  authors = [
      {name = "jkobject", email = "jkobject@gmail.com"}
@@ -14,7 +14,7 @@ dependencies = [
      "lamindb[bionty,ourprojects,jupyter,cellregistry,zarr]>=1.0.4,<2",
      "cellxgene-census>=0.1.0",
      "torch==2.2.0",
-     "lightning>=2.3.0",
+     "pytorch-lightning>=2.3.0",
      "anndata>=0.9.0",
      "zarr>=2.10.0",
      "matplotlib>=3.5.0",
@@ -28,6 +28,8 @@ dependencies = [
      "scikit-misc>=0.5.0",
      "palantir>=1.3.3",
      "harmonypy>=0.0.10",
+     "jupytext>=1.16.0",
+
  ]

  [project.optional-dependencies]
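
Note on the dependency swap above: 1.9.0 moves from the unified `lightning` package to the standalone `pytorch-lightning` distribution. At >=2.3 both ship the same Trainer/LightningModule API, just under different import roots; a minimal sketch of the difference (the Trainer arguments are illustrative only):

# Old dependency ("lightning"): the unified package is imported at top level.
#   import lightning as L
#   trainer = L.Trainer(max_epochs=1)

# New dependency ("pytorch-lightning"): same API, classic import root.
import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)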
@@ -0,0 +1 @@
+ 1.9.0
@@ -148,19 +148,19 @@ class Collator:
                      :, self.accepted_genes[organism_id]
                  ]
              if self.how == "most expr":
+                 nnz_loc = np.where(expr > 0)[0]
                  if "knn_cells" in elem:
                      nnz_loc = np.where(expr + elem["knn_cells"].sum(0) > 0)[0]
+                     ma = self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc)
+                     loc = np.argsort(expr + elem["knn_cells"].mean(0))[-(ma):][::-1]
                  else:
                      nnz_loc = np.where(expr > 0)[0]
-                 ma = self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc)
-                 loc = np.argsort(expr)[-(ma):][::-1]
+                     ma = self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc)
+                     loc = np.argsort(expr)[-(ma):][::-1]
                  # nnz_loc = [1] * 30_000
                  # loc = np.argsort(expr)[-(self.max_len) :][::-1]
              elif self.how == "random expr":
-                 if "knn_cells" in elem:
-                     nnz_loc = np.where(expr + elem["knn_cells"].sum(0) > 0)[0]
-                 else:
-                     nnz_loc = np.where(expr > 0)[0]
+                 nnz_loc = np.where(expr > 0)[0]
                  loc = nnz_loc[
                      np.random.choice(
                          len(nnz_loc),
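
Annotation: the reworked "most expr" branch above computes `ma` and `loc` per branch, so that when neighbor cells (`knn_cells`) are present the ranking uses the cell's expression plus the neighbor mean. A self-contained sketch of that selection, with made-up `expr`, `knn_cells`, and `max_len` values:

import numpy as np

expr = np.array([0.0, 5.0, 0.0, 2.0, 9.0, 0.0, 1.0])  # one cell, 7 genes
knn_cells = np.random.rand(5, 7)  # hypothetical expression of 5 neighbor cells
max_len = 3  # hypothetical cap on genes kept per cell

# A gene counts as non-zero if the cell or any neighbor expresses it.
nnz_loc = np.where(expr + knn_cells.sum(0) > 0)[0]
ma = max_len if max_len < len(nnz_loc) else len(nnz_loc)
# Rank by the cell's expression plus the neighbor mean, highest first.
loc = np.argsort(expr + knn_cells.mean(0))[-ma:][::-1]
print(loc)  # indices of the `ma` most expressed genes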
@@ -180,33 +180,42 @@ class Collator:
                  "some",
              ]:
                  if "knn_cells" in elem:
-                     zero_loc = np.where(expr + elem["knn_cells"].sum(0) == 0)[0]
+                     # we complete with genes expressed in the knn
+                     nnz_loc = np.where(elem["knn_cells"].sum(0) > 0)[0]
+                     ma = self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc)
+                     # which is not a zero_loc in this context
+                     zero_loc = np.argsort(elem["knn_cells"].sum(0))[-(ma):][::-1]
                  else:
                      zero_loc = np.where(expr == 0)[0]
-                 zero_loc = zero_loc[
-                     np.random.choice(
-                         len(zero_loc),
-                         self.add_zero_genes
-                         + (
-                             0
-                             if self.max_len < len(nnz_loc)
-                             else self.max_len - len(nnz_loc)
-                         ),
-                         replace=False,
-                     )
-                 ]
+                     zero_loc = zero_loc[
+                         np.random.choice(
+                             len(zero_loc),
+                             self.add_zero_genes
+                             + (
+                                 0
+                                 if self.max_len < len(nnz_loc)
+                                 else self.max_len - len(nnz_loc)
+                             ),
+                             replace=False,
+                         )
+                     ]
                  loc = np.concatenate((loc, zero_loc), axis=None)
-             if "knn_cells" in elem:
-                 knn_cells.append(elem["knn_cells"][:, loc])
              expr = expr[loc]
-             loc = loc + self.start_idx[organism_id]
+             if "knn_cells" in elem:
+                 elem["knn_cells"] = elem["knn_cells"][:, loc]
              if self.how == "some":
                  if "knn_cells" in elem:
-                     knn_cells[-1] = knn_cells[-1][self.to_subset[organism_id]]
+                     elem["knn_cells"] = elem["knn_cells"][
+                         :, self.to_subset[organism_id]
+                     ]
                  expr = expr[self.to_subset[organism_id]]
                  loc = loc[self.to_subset[organism_id]]
              exprs.append(expr)
-             gene_locs.append(loc)
+             if "knn_cells" in elem:
+                 knn_cells.append(elem["knn_cells"])
+             # then we need to add the start_idx to the loc to give it the correct index
+             # according to the model
+             gene_locs.append(loc + self.start_idx[organism_id])

              if self.tp_name is not None:
                  tp.append(elem[self.tp_name])
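
Annotation: in the "some"/zero-padding path, batches are padded to a fixed length, previously with randomly sampled unexpressed genes and now, when `knn_cells` is present, with the genes most expressed across the neighbors. A toy sketch of the random-zero variant kept for the no-neighbor case (all values are made up):

import numpy as np

expr = np.array([3.0, 0.0, 7.0, 0.0, 0.0, 1.0])
loc = np.array([2, 0, 5])  # genes already selected as most expressed
add_zero_genes = 2
max_len = 3

zero_loc = np.where(expr == 0)[0]
# Pad with unexpressed genes so the model also trains on true zeros.
n_pad = add_zero_genes + (0 if max_len < len(loc) else max_len - len(loc))
zero_loc = zero_loc[np.random.choice(len(zero_loc), n_pad, replace=False)]
loc = np.concatenate((loc, zero_loc), axis=None)
print(expr[loc])  # three expressed values followed by two zeros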
@@ -243,7 +252,7 @@ class Collator:
          if len(is_meta) > 0:
              ret.update({"is_meta": Tensor(is_meta).int()})
          if len(knn_cells) > 0:
-             ret.update({"knn_cells": Tensor(knn_cells).int()})
+             ret.update({"knn_cells": Tensor(knn_cells)})
          if len(dataset) > 0:
              ret.update({"dataset": Tensor(dataset).to(long)})
          if self.downsample is not None:
  if self.downsample is not None:
@@ -251,6 +260,8 @@ class Collator:
251
260
  if self.save_output is not None:
252
261
  with open(self.save_output, "a") as f:
253
262
  np.savetxt(f, ret["x"].numpy())
263
+ with open(self.save_output + "_loc", "a") as f:
264
+ np.savetxt(f, gene_locs)
254
265
  return ret
255
266
 
256
267
 
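Annotation: with `save_output` set, the collator now also appends each batch's gene positions to a sidecar file suffixed `_loc`, so the dumped expression rows in the main file can be mapped back to gene indices. A minimal sketch of the append pattern (file names are placeholders):

import numpy as np

save_output = "batch_dump.txt"  # placeholder path
x = np.random.rand(2, 4)  # one batch of expression values
gene_locs = np.array([[3, 7, 1, 0], [2, 9, 4, 5]])  # matching gene indices

with open(save_output, "a") as f:  # "a": successive batches accumulate
    np.savetxt(f, x)
with open(save_output + "_loc", "a") as f:
    np.savetxt(f, gene_locs)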
@@ -118,7 +118,7 @@ MAIN_HUMAN_MOUSE_DEV_STAGE_MAP = {
      ],
      "HsapDv:0000258": [  # mature stage
          "MmusDv:0000110",  # mature stage
-         "HsapDv:0000204",
+         "HsapDv:0000204",  #
      ],
      "HsapDv:0000227": [  # late adult stage
          "MmusDv:0000091",  # 20 month-old stage
@@ -428,6 +428,7 @@ class MappedCollection:
              ],
              dtype=int,
          )
+         out["distances"] = distances[nn_idx]

          return out

@@ -9,7 +9,7 @@ import scanpy as sc
  from anndata import AnnData, read_h5ad
  from scipy.sparse import csr_matrix
  from upath import UPath
-
+ import gc
  from scdataloader import utils as data_utils

  FULL_LENGTH_ASSAYS = [
@@ -18,7 +18,7 @@ FULL_LENGTH_ASSAYS = [
      "EFO:0008931",
  ]

- MAXFILESIZE = 10_000_000_000
+ MAXFILESIZE = 5_000_000_000


  class Preprocessor:
@@ -135,6 +135,8 @@ class Preprocessor:
          self.keepdata = keepdata

      def __call__(self, adata, dataset_id=None) -> AnnData:
+         if self.additional_preprocess is not None:
+             adata = self.additional_preprocess(adata)
          if "organism_ontology_term_id" not in adata[0].obs.columns:
              raise ValueError(
                  "organism_ontology_term_id not found in adata.obs, you need to add an ontology term id for the organism of your anndata"
@@ -143,13 +145,11 @@ class Preprocessor:
              raise ValueError(
                  "gene names in the `var.index` field of your anndata should map to the ensembl_gene nomenclature else set `is_symbol` to True if using hugo symbols"
              )
-         if adata[0].obs.organism_ontology_term_id.iloc[0] not in self.organisms:
+         if adata.obs["organism_ontology_term_id"].iloc[0] not in self.organisms:
              raise ValueError(
                  "we cannot work with this organism",
-                 adata[0].obs.organism_ontology_term_id.iloc[0],
+                 adata.obs["organism_ontology_term_id"],
              )
-         if self.additional_preprocess is not None:
-             adata = self.additional_preprocess(adata)
          if adata.raw is not None and self.use_raw:
              adata.X = adata.raw.X
              del adata.raw
@@ -165,11 +165,12 @@ class Preprocessor:
              del adata.layers
          if len(adata.varm.keys()) > 0 and not self.keepdata:
              del adata.varm
-         if len(adata.obsm.keys()) > 0 and self.do_postp and not self.keepdata:
+         if len(adata.obsm.keys()) > 0 and not self.keepdata:
              del adata.obsm
-         if len(adata.obsp.keys()) > 0 and self.do_postp and not self.keepdata:
+         if len(adata.obsp.keys()) > 0 and not self.keepdata:
              del adata.obsp
          # check that it is a count
+
          print("checking raw counts")
          if np.abs(
              adata[:50_000].X.astype(int) - adata[:50_000].X
@@ -230,23 +231,51 @@ class Preprocessor:
              )
          )

-         if self.is_symbol or not adata.var.index.str.contains("ENS").any():
-             if not adata.var.index.str.contains("ENS").any():
-                 print("No ENS genes found, assuming gene symbols...")
-             genesdf["ensembl_gene_id"] = genesdf.index
-             var = (
-                 adata.var.merge(
-                     genesdf.drop_duplicates("symbol").set_index("symbol", drop=False),
-                     left_index=True,
-                     right_index=True,
-                     how="inner",
-                 )
-                 .sort_values(by="ensembl_gene_id")
-                 .set_index("ensembl_gene_id")
+         # Check if we have a mix of gene names and ensembl IDs
+         has_ens = adata.var.index.str.match(r"ENS.*\d{6,}$").any()
+         all_ens = adata.var.index.str.match(r"ENS.*\d{6,}$").all()
+
+         if not has_ens:
+             print("No ENS genes found, assuming gene symbols...")
+         elif not all_ens:
+             print("Mix of ENS and gene symbols found, converting all to ENS IDs...")
+
+         genesdf["ensembl_gene_id"] = genesdf.index
+
+         # For genes that are already ENS IDs, use them directly
+         ens_mask = adata.var.index.str.match(r"ENS.*\d{6,}$")
+         symbol_mask = ~ens_mask
+
+         # Handle symbol genes
+         if symbol_mask.any():
+             symbol_var = adata.var[symbol_mask].merge(
+                 genesdf.drop_duplicates("symbol").set_index("symbol", drop=False),
+                 left_index=True,
+                 right_index=True,
+                 how="inner",
+             )
+
+         # Handle ENS genes
+         if ens_mask.any():
+             ens_var = adata.var[ens_mask].merge(
+                 genesdf, left_index=True, right_index=True, how="inner"
              )
-             adata = adata[:, var["symbol"]]
-             adata.var = var
-             genesdf = genesdf.set_index("ensembl_gene_id")
+
+         # Combine and sort
+         if symbol_mask.any() and ens_mask.any():
+             var = pd.concat([symbol_var, ens_var])
+         elif symbol_mask.any():
+             var = symbol_var
+         else:
+             var = ens_var
+
+         adata = adata[:, var.index]
+         var = var.sort_values(by="ensembl_gene_id").set_index("ensembl_gene_id")
+         # Update adata with combined genes
+         adata.var = var
+         genesdf = genesdf.set_index("ensembl_gene_id")
+         # Drop duplicate genes, keeping first occurrence
+         adata = adata[:, ~adata.var.index.duplicated(keep="first")]

          intersect_genes = set(adata.var.index).intersection(set(genesdf.index))
          print(f"Removed {len(adata.var.index) - len(intersect_genes)} genes.")
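
Annotation: the detection above replaces the loose `str.contains("ENS")` with the anchored regex `ENS.*\d{6,}$`, which accepts species-prefixed Ensembl IDs (e.g. ENSMUSG...) but no longer misclassifies gene symbols that merely start with "ENS". A quick check on sample identifiers:

import pandas as pd

index = pd.Index(["ENSG00000139618", "ENSMUSG00000017167", "TP53", "ENSA"])
ens_mask = index.str.match(r"ENS.*\d{6,}$")
print(list(ens_mask))  # [True, True, False, False]
# "ENSA" is a real human gene symbol: contains("ENS") would have matched it,
# but the anchored pattern requiring a 6+ digit suffix does not.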
@@ -475,13 +504,17 @@ class LaminPreprocessor(Preprocessor):
              print(file)

              path = cache_path(file) if self.force_preloaded else file.cache()
-             backed = read_h5ad(path, backed="r")
-             if backed.obs.is_primary_data.sum() == 0:
-                 print(f"{file.key} only contains non primary cells.. dropping")
-                 # Save the stem_uid to a file to avoid loading it again
+             backed = file.open()
+             # backed = read_h5ad(path, backed="r")
+             if "is_primary_data" in backed.obs.columns:
+                 if backed.obs.is_primary_data.sum() == 0:
+                     print(f"{file.key} only contains non primary cells.. dropping")
+                     # Save the stem_uid to a file to avoid loading it again
                      with open("nonprimary.txt", "a") as f:
                          f.write(f"{file.stem_uid}\n")
                      continue
+             else:
+                 print("Warning: couldn't check unicity from is_primary_data column")
              if backed.shape[1] < 1000:
                  print(
                      f"{file.key} only contains less than 1000 genes and is likely not scRNAseq... dropping"
@@ -502,16 +535,23 @@ class LaminPreprocessor(Preprocessor):
                  block_size = int(
                      (np.ceil(badata.shape[0] / 30_000) * 30_000) // num_blocks
                  )
-                 print("num blocks ", num_blocks)
+                 print(
+                     "num blocks ",
+                     num_blocks,
+                     "block size ",
+                     block_size,
+                     "total elements ",
+                     badata.shape[0],
+                 )
                  for j in range(num_blocks):
-                     if j == 0 and i == 390:
-                         continue
                      start_index = j * block_size
                      end_index = min((j + 1) * block_size, badata.shape[0])
-                     block = badata[start_index:end_index].to_memory()
+                     block = badata[start_index:end_index]
+                     block = block.to_memory()
                      print(block)
                      block = super().__call__(
-                         block, dataset_id=file.stem_uid + "_p" + str(j)
+                         block,
+                         dataset_id=file.stem_uid + "_p" + str(j),
                      )
                      myfile = ln.Artifact.from_anndata(
                          block,
521
561
  + " p"
522
562
  + str(j)
523
563
  + " ( revises file "
524
- + str(file.key)
564
+ + str(file.stem_uid)
525
565
  + " )",
526
566
  version=version,
527
567
  )
528
568
  myfile.save()
569
+
529
570
  if self.keep_files:
530
571
  files.append(myfile)
572
+ del block
531
573
  else:
532
574
  del myfile
533
575
  del block
576
+ gc.collect()
534
577
 
535
578
  else:
536
579
  adata = super().__call__(adata, dataset_id=file.stem_uid)
@@ -543,6 +586,7 @@ class LaminPreprocessor(Preprocessor):
                  myfile.save()
                  if self.keep_files:
                      files.append(myfile)
+                     del adata
                  else:
                      del myfile
                      del adata
@@ -562,7 +606,12 @@ class LaminPreprocessor(Preprocessor):

          # issues with KLlggfw6I6lvmbqiZm46
          if self.keep_files:
-             dataset = ln.Collection(files, name=name, description=description)
+             # Reconstruct collection using keys
+             dataset = ln.Collection(
+                 [ln.Artifact.filter(key=k).one() for k in files],
+                 name=name,
+                 description=description,
+             )
              dataset.save()
              return dataset
          else:
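
Annotation: the collection is now rebuilt from freshly queried artifacts rather than from the handles accumulated during the run. A hedged sketch of the lamindb pattern used above (the keys are placeholders, and it assumes `files` holds artifact keys):

import lamindb as ln

keys = ["preprocessed/dataset_a.h5ad", "preprocessed/dataset_b.h5ad"]  # placeholders
artifacts = [ln.Artifact.filter(key=k).one() for k in keys]
collection = ln.Collection(artifacts, name="my-datasets", description="rebuilt by key")
collection.save()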
@@ -578,7 +578,6 @@ def load_genes(organisms: Union[str, list] = "NCBITaxon:9606"): # "NCBITaxon:10


  def populate_my_ontology(
-     organisms: List[str] = ["NCBITaxon:10090", "NCBITaxon:9606"],
      sex: List[str] = ["PATO:0000384", "PATO:0000383"],
      celltypes: List[str] = [],
      ethnicities: List[str] = [],
@@ -586,7 +585,7 @@ def populate_my_ontology(
      tissues: List[str] = [],
      diseases: List[str] = [],
      dev_stages: List[str] = [],
-     organism_clade: str = "vertebrates",
+     organisms_clade: List[str] = ["vertebrates", "plants"],
  ):
      """
      creates a local version of the lamin ontologies and add the required missing values in base ontologies
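
Annotation: the `organisms` list of NCBITaxon IDs is gone; callers now pass Ensembl clades via `organisms_clade`, and every organism in each clade gets registered (with a dedup-by-uid pass, see the next hunk). A usage sketch, assuming the function still lives in `scdataloader.utils` as in earlier releases:

from scdataloader import utils

# 1.8.1: explicit organism IDs
# utils.populate_my_ontology(organisms=["NCBITaxon:10090", "NCBITaxon:9606"])

# 1.9.0: one or more Ensembl clades; defaults to ["vertebrates", "plants"]
utils.populate_my_ontology(organisms_clade=["vertebrates"])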
@@ -622,23 +621,27 @@ def populate_my_ontology(
          ln.save(records)
          bt.CellType(name="unknown", ontology_id="unknown").save()
      # Organism
-     if organisms is not None:
-         names = (
-             bt.Organism.public(organism=organism_clade).df().index
-             if not organisms
-             else organisms
-         )
-         source = bt.PublicSource.filter(name="ensembl", organism=organism_clade).last()
-         records = [
-             organism_or_organismlist
-             if isinstance(organism_or_organismlist, bt.Organism)
-             else organism_or_organismlist[0]
-             for organism_or_organismlist in [
-                 bt.Organism.from_source(ontology_id=name, source=source)
-                 for name in names
+     if organisms_clade is not None:
+         records = []
+         for organism_clade in organisms_clade:
+             names = bt.Organism.public(organism=organism_clade).df().index
+             source = bt.PublicSource.filter(
+                 name="ensembl", organism=organism_clade
+             ).last()
+             records += [
+                 bt.Organism.from_source(name=name, source=source) for name in names
              ]
-         ]
-         ln.save(records)
+         nrecords = []
+         prevrec = set()
+         for rec in records:
+             if rec is None:
+                 continue
+             if not isinstance(rec, bt.Organism):
+                 rec = rec[0]
+             if rec.uid not in prevrec:
+                 nrecords.append(rec)
+                 prevrec.add(rec.uid)
+         ln.save(nrecords)
      bt.Organism(name="unknown", ontology_id="unknown").save()
      # Phenotype
      if sex is not None:
@@ -1 +0,0 @@
- 1.8.1