scdataloader 1.8.0__py3-none-any.whl → 1.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/VERSION CHANGED
@@ -1 +1 @@
1
- 1.8.0
1
+ 1.8.1
scdataloader/__init__.py CHANGED
@@ -2,5 +2,6 @@ from .collator import Collator
2
2
  from .data import Dataset, SimpleAnnDataset
3
3
  from .datamodule import DataModule
4
4
  from .preprocess import Preprocessor
5
+ from importlib.metadata import version
5
6
 
6
- __version__ = "1.7.0"
7
+ __version__ = version("scdataloader")
scdataloader/collator.py CHANGED
@@ -24,7 +24,6 @@ class Collator:
24
24
  genelist: list[str] = [],
25
25
  downsample: Optional[float] = None, # don't use it for training!
26
26
  save_output: Optional[str] = None,
27
- metacell_mode: bool = False,
28
27
  ):
29
28
  """
30
29
  This class is responsible for collating data for the scPRINT model. It handles the
@@ -62,7 +61,6 @@ class Collator:
62
61
  This is usually done by the scPRINT model during training but this option allows you to do it directly from the collator
63
62
  save_output (str, optional): If not None, saves the output to a file. Defaults to None.
64
63
  This is mainly for debugging purposes
65
- metacell_mode (bool, optional): Whether to sample a metacell. Defaults to False.
66
64
  """
67
65
  self.organisms = organisms
68
66
  self.genedf = load_genes(organisms)
@@ -82,7 +80,6 @@ class Collator:
82
80
  self.accepted_genes = {}
83
81
  self.downsample = downsample
84
82
  self.to_subset = {}
85
- self.metacell_mode = metacell_mode
86
83
  self._setup(org_to_id, valid_genes, genelist)
87
84
 
88
85
  def _setup(self, org_to_id=None, valid_genes=[], genelist=[]):
@@ -135,6 +132,7 @@ class Collator:
135
132
  dataset = []
136
133
  nnz_loc = []
137
134
  is_meta = []
135
+ knn_cells = []
138
136
  for elem in batch:
139
137
  organism_id = elem[self.organism_name]
140
138
  if organism_id not in self.organism_ids:
@@ -145,14 +143,24 @@ class Collator:
145
143
  total_count.append(expr.sum())
146
144
  if len(self.accepted_genes) > 0:
147
145
  expr = expr[self.accepted_genes[organism_id]]
146
+ if "knn_cells" in elem:
147
+ elem["knn_cells"] = elem["knn_cells"][
148
+ :, self.accepted_genes[organism_id]
149
+ ]
148
150
  if self.how == "most expr":
149
- nnz_loc = np.where(expr > 0)[0]
151
+ if "knn_cells" in elem:
152
+ nnz_loc = np.where(expr + elem["knn_cells"].sum(0) > 0)[0]
153
+ else:
154
+ nnz_loc = np.where(expr > 0)[0]
150
155
  ma = self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc)
151
156
  loc = np.argsort(expr)[-(ma):][::-1]
152
157
  # nnz_loc = [1] * 30_000
153
158
  # loc = np.argsort(expr)[-(self.max_len) :][::-1]
154
159
  elif self.how == "random expr":
155
- nnz_loc = np.where(expr > 0)[0]
160
+ if "knn_cells" in elem:
161
+ nnz_loc = np.where(expr + elem["knn_cells"].sum(0) > 0)[0]
162
+ else:
163
+ nnz_loc = np.where(expr > 0)[0]
156
164
  loc = nnz_loc[
157
165
  np.random.choice(
158
166
  len(nnz_loc),
@@ -171,7 +179,10 @@ class Collator:
171
179
  "all",
172
180
  "some",
173
181
  ]:
174
- zero_loc = np.where(expr == 0)[0]
182
+ if "knn_cells" in elem:
183
+ zero_loc = np.where(expr + elem["knn_cells"].sum(0) == 0)[0]
184
+ else:
185
+ zero_loc = np.where(expr == 0)[0]
175
186
  zero_loc = zero_loc[
176
187
  np.random.choice(
177
188
  len(zero_loc),
@@ -185,9 +196,13 @@ class Collator:
185
196
  )
186
197
  ]
187
198
  loc = np.concatenate((loc, zero_loc), axis=None)
199
+ if "knn_cells" in elem:
200
+ knn_cells.append(elem["knn_cells"][:, loc])
188
201
  expr = expr[loc]
189
202
  loc = loc + self.start_idx[organism_id]
190
203
  if self.how == "some":
204
+ if "knn_cells" in elem:
205
+ knn_cells[-1] = knn_cells[-1][self.to_subset[organism_id]]
191
206
  expr = expr[self.to_subset[organism_id]]
192
207
  loc = loc[self.to_subset[organism_id]]
193
208
  exprs.append(expr)
@@ -197,7 +212,7 @@ class Collator:
197
212
  tp.append(elem[self.tp_name])
198
213
  else:
199
214
  tp.append(0)
200
- if self.metacell_mode:
215
+ if "is_meta" in elem:
201
216
  is_meta.append(elem["is_meta"])
202
217
  other_classes.append([elem[i] for i in self.class_names])
203
218
  expr = np.array(exprs)
@@ -207,6 +222,7 @@ class Collator:
207
222
  other_classes = np.array(other_classes)
208
223
  dataset = np.array(dataset)
209
224
  is_meta = np.array(is_meta)
225
+ knn_cells = np.array(knn_cells)
210
226
  # normalize counts
211
227
  if self.norm_to is not None:
212
228
  expr = (expr * self.norm_to) / total_count[:, None]
@@ -217,15 +233,6 @@ class Collator:
217
233
  if self.n_bins:
218
234
  pass
219
235
 
220
- # find the associated gene ids (given the species)
221
-
222
- # get the NN cells
223
-
224
- # do encoding / selection a la scGPT
225
-
226
- # do encoding of graph location
227
- # encode all the edges in some sparse way
228
- # normalizing total counts between 0,1
229
236
  ret = {
230
237
  "x": Tensor(expr),
231
238
  "genes": Tensor(gene_locs).int(),
@@ -233,8 +240,10 @@ class Collator:
233
240
  "tp": Tensor(tp),
234
241
  "depth": Tensor(total_count),
235
242
  }
236
- if self.metacell_mode:
243
+ if len(is_meta) > 0:
237
244
  ret.update({"is_meta": Tensor(is_meta).int()})
245
+ if len(knn_cells) > 0:
246
+ ret.update({"knn_cells": Tensor(knn_cells).int()})
238
247
  if len(dataset) > 0:
239
248
  ret.update({"dataset": Tensor(dataset).to(long)})
240
249
  if self.downsample is not None:
scdataloader/data.py CHANGED
@@ -58,6 +58,7 @@ class Dataset(torchDataset):
58
58
  hierarchical_clss: Optional[list[str]] = field(default_factory=list)
59
59
  join_vars: Literal["inner", "outer"] | None = None
60
60
  metacell_mode: float = 0.0
61
+ get_knn_cells: bool = False
61
62
 
62
63
  def __post_init__(self):
63
64
  self.mapped_dataset = mapped(
@@ -69,6 +70,7 @@ class Dataset(torchDataset):
69
70
  stream=True,
70
71
  parallel=True,
71
72
  metacell_mode=self.metacell_mode,
73
+ get_knn_cells=self.get_knn_cells,
72
74
  )
73
75
  print(
74
76
  "won't do any check but we recommend to have your dataset coming from local storage"
@@ -371,6 +373,7 @@ def mapped(
371
373
  is_run_input: bool | None = None,
372
374
  metacell_mode: bool = False,
373
375
  meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
376
+ get_knn_cells: bool = False,
374
377
  ) -> MappedCollection:
375
378
  path_list = []
376
379
  for artifact in dataset.artifacts.all():
@@ -397,5 +400,6 @@ def mapped(
397
400
  dtype=dtype,
398
401
  meta_assays=meta_assays,
399
402
  metacell_mode=metacell_mode,
403
+ get_knn_cells=get_knn_cells,
400
404
  )
401
405
  return ds
@@ -52,6 +52,7 @@ class DataModule(L.LightningDataModule):
52
52
  # "EFO:0030062", # slide-seq
53
53
  ],
54
54
  metacell_mode: float = 0.0,
55
+ get_knn_cells: bool = False,
55
56
  modify_seed_on_requeue: bool = True,
56
57
  **kwargs,
57
58
  ):
@@ -88,6 +89,7 @@ class DataModule(L.LightningDataModule):
88
89
  metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
89
90
  clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
90
91
  modify_seed_on_requeue (bool, optional): Whether to modify the seed on requeue. Defaults to True.
92
+ get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
91
93
  **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
92
94
  see @file data.py and @file collator.py for more details about some of the parameters
93
95
  """
@@ -98,6 +100,7 @@ class DataModule(L.LightningDataModule):
98
100
  clss_to_predict=clss_to_predict,
99
101
  hierarchical_clss=hierarchical_clss,
100
102
  metacell_mode=metacell_mode,
103
+ get_knn_cells=get_knn_cells,
101
104
  )
102
105
  # and location
103
106
  self.metacell_mode = bool(metacell_mode)
@@ -157,7 +160,6 @@ class DataModule(L.LightningDataModule):
157
160
  tp_name=tp_name,
158
161
  organism_name=organism_name,
159
162
  class_names=clss_to_predict,
160
- metacell_mode=bool(metacell_mode),
161
163
  )
162
164
  self.validation_split = validation_split
163
165
  self.test_split = test_split
scdataloader/mapped.py CHANGED
@@ -96,8 +96,9 @@ class MappedCollection:
96
96
  cache_categories: Enable caching categories of ``obs_keys`` for faster access.
97
97
  parallel: Enable sampling with multiple processes.
98
98
  dtype: Convert numpy arrays from ``.X``, ``.layers`` and ``.obsm``
99
- meta_assays: Assays to check for metacells.
100
- metacell_mode: Mode for metacells.
99
+ meta_assays: Assays that are already defined as metacells.
100
+ metacell_mode: frequency at which to sample a metacell (an average of k-nearest neighbors).
101
+ get_knn_cells: Whether to also dataload the k-nearest neighbors of each queried cells.
101
102
  """
102
103
 
103
104
  def __init__(
@@ -114,6 +115,7 @@ class MappedCollection:
114
115
  parallel: bool = False,
115
116
  dtype: str | None = None,
116
117
  metacell_mode: float = 0.0,
118
+ get_knn_cells: bool = False,
117
119
  meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
118
120
  ):
119
121
  if join not in {None, "inner", "outer"}: # pragma: nocover
@@ -166,6 +168,7 @@ class MappedCollection:
166
168
  self.metacell_mode = metacell_mode
167
169
  self.path_list = path_list
168
170
  self.meta_assays = meta_assays
171
+ self.get_knn_cells = get_knn_cells
169
172
  self._make_connections(path_list, parallel)
170
173
 
171
174
  self._cache_cats: dict = {}
@@ -396,12 +399,15 @@ class MappedCollection:
396
399
  label_idx = self.encoders[label][label_idx]
397
400
  out[label] = label_idx
398
401
 
399
- out["is_meta"] = False
400
- if len(self.meta_assays) > 0 and "assay_ontology_term_id" in self.obs_keys:
401
- if out["assay_ontology_term_id"] in self.meta_assays:
402
- out["is_meta"] = True
403
- return out
404
402
  if self.metacell_mode > 0:
403
+ if (
404
+ len(self.meta_assays) > 0
405
+ and "assay_ontology_term_id" in self.obs_keys
406
+ ):
407
+ if out["assay_ontology_term_id"] in self.meta_assays:
408
+ out["is_meta"] = True
409
+ return out
410
+ out["is_meta"] = False
405
411
  if np.random.random() < self.metacell_mode:
406
412
  out["is_meta"] = True
407
413
  distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
@@ -410,6 +416,18 @@ class MappedCollection:
410
416
  out[layers_key] += self._get_data_idx(
411
417
  lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
412
418
  )
419
+ elif self.get_knn_cells:
420
+ distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
421
+ nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
422
+ out["knn_cells"] = np.array(
423
+ [
424
+ self._get_data_idx(
425
+ lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
426
+ )
427
+ for i in nn_idx
428
+ ],
429
+ dtype=int,
430
+ )
413
431
 
414
432
  return out
415
433
 
@@ -64,6 +64,11 @@ class Preprocessor:
64
64
  """
65
65
  Initializes the preprocessor and configures the workflow steps.
66
66
 
67
+ Your dataset should contain at least the following obs:
68
+ - `organism_ontology_term_id` with the ontology id of the organism of your anndata
69
+ - gene names in the `var.index` field of your anndata that map to the ensembl_gene nomenclature
70
+ or the hugo gene symbols nomenclature (if the later, set `is_symbol` to True)
71
+
67
72
  Args:
68
73
  filter_gene_by_counts (int or bool, optional): Determines whether to filter genes by counts.
69
74
  If int, filters genes with counts. Defaults to False.
@@ -130,6 +135,14 @@ class Preprocessor:
130
135
  self.keepdata = keepdata
131
136
 
132
137
  def __call__(self, adata, dataset_id=None) -> AnnData:
138
+ if "organism_ontology_term_id" not in adata[0].obs.columns:
139
+ raise ValueError(
140
+ "organism_ontology_term_id not found in adata.obs, you need to add an ontology term id for the organism of your anndata"
141
+ )
142
+ if not adata[0].var.index.str.contains("ENS").any() and not self.is_symbol:
143
+ raise ValueError(
144
+ "gene names in the `var.index` field of your anndata should map to the ensembl_gene nomenclature else set `is_symbol` to True if using hugo symbols"
145
+ )
133
146
  if adata[0].obs.organism_ontology_term_id.iloc[0] not in self.organisms:
134
147
  raise ValueError(
135
148
  "we cannot work with this organism",
scdataloader/utils.py CHANGED
@@ -154,7 +154,7 @@ def getBiomartTable(
154
154
  return res
155
155
 
156
156
 
157
- def validate(adata: AnnData, organism: str, need_all=True):
157
+ def validate(adata: AnnData, organism: str, need_all=False):
158
158
  """
159
159
  validate checks if the adata object is valid for lamindb
160
160
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: scdataloader
3
- Version: 1.8.0
3
+ Version: 1.8.1
4
4
  Summary: a dataloader for single cell data in lamindb
5
5
  Project-URL: repository, https://github.com/jkobject/scDataLoader
6
6
  Author-email: jkobject <jkobject@gmail.com>
@@ -16,7 +16,7 @@ Requires-Dist: harmonypy>=0.0.10
16
16
  Requires-Dist: ipykernel>=6.20.0
17
17
  Requires-Dist: lamindb[bionty,cellregistry,jupyter,ourprojects,zarr]<2,>=1.0.4
18
18
  Requires-Dist: leidenalg>=0.8.0
19
- Requires-Dist: lightning>=2.0.0
19
+ Requires-Dist: lightning>=2.3.0
20
20
  Requires-Dist: matplotlib>=3.5.0
21
21
  Requires-Dist: numpy==1.26.0
22
22
  Requires-Dist: palantir>=1.3.3
@@ -0,0 +1,16 @@
1
+ scdataloader/VERSION,sha256=Jc7Jc50yGOSKzF2MPUMz4dYkEhxberO83ccdD6ATS4M,6
2
+ scdataloader/__init__.py,sha256=1SyT5MzcFl8mfp5NB4idgYQ4insXbDRd-EBNvoz_dXQ,225
3
+ scdataloader/__main__.py,sha256=3aZnqYrH8XDT9nW9Dbb3o9kr-sx1STmXDQHxBo_h_q0,8719
4
+ scdataloader/base.py,sha256=M1gD59OffRdLOgS1vHKygOomUoAMuzjpRtAfM3SBKF8,338
5
+ scdataloader/collator.py,sha256=UWyTSFEYCAVcBRreFItzDgTyBx224u-ThjjW9x-osHY,12301
6
+ scdataloader/config.py,sha256=tu9hkUiU2HfaIiVzdmrjbzt73yV4zP-t8lDuJqyGcDA,6546
7
+ scdataloader/data.py,sha256=xWlNU6cJmrzP4BFMsJDIksLaxe1pUfgDBlQ_IeLIXj0,15578
8
+ scdataloader/datamodule.py,sha256=6Oby-BySXaWYr34PocgCq25FLH1QUX-EsWOZI6EVjgw,21128
9
+ scdataloader/mapped.py,sha256=DzryqhELXo-s5RgdmRFaa8zLiGjyjFKn7wW77lGLTaI,26900
10
+ scdataloader/preprocess.py,sha256=Ewla5GYD_8YBqCDr7kaOwrYN_ok0YmYvYpwbxTComXg,35764
11
+ scdataloader/utils.py,sha256=F5ZhdalHbxdZOs9aZ9RP9LTHGsmuoofgC39W9GS7EA4,28362
12
+ scdataloader-1.8.1.dist-info/METADATA,sha256=NLNmj2mWRQFpwUpMxaTHuWK309MYBiBzxBxk9Nd0KD8,9946
13
+ scdataloader-1.8.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
14
+ scdataloader-1.8.1.dist-info/entry_points.txt,sha256=VXAN1m_CjbdLJ6SKYR0sBLGDV4wvv31ri7fWWuwbpno,60
15
+ scdataloader-1.8.1.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
16
+ scdataloader-1.8.1.dist-info/RECORD,,
@@ -1,16 +0,0 @@
1
- scdataloader/VERSION,sha256=PrHvlLWJDKcnFYsQYUJoXIczsKzlvLTPPwrBT58GQ_Q,6
2
- scdataloader/__init__.py,sha256=GYqFXVzcgkqwcWodyHQSa3bnCuWsBt9jWYHEcLnx6xU,170
3
- scdataloader/__main__.py,sha256=3aZnqYrH8XDT9nW9Dbb3o9kr-sx1STmXDQHxBo_h_q0,8719
4
- scdataloader/base.py,sha256=M1gD59OffRdLOgS1vHKygOomUoAMuzjpRtAfM3SBKF8,338
5
- scdataloader/collator.py,sha256=n_DI630Eqo-C_G02krFD-Ixj3EKReZfW84VZy5wZHCw,11758
6
- scdataloader/config.py,sha256=tu9hkUiU2HfaIiVzdmrjbzt73yV4zP-t8lDuJqyGcDA,6546
7
- scdataloader/data.py,sha256=nLw0yCe0Sj0RGR9ioYKszwzuah-KRG0tpyjOh8xjNuY,15430
8
- scdataloader/datamodule.py,sha256=7xTaa6I2Yj6ikGy-bLmrsr0-9VrQUO9vW17bqhhcyJU,20972
9
- scdataloader/mapped.py,sha256=GCAygW7-JcEQ7sB-dsiA_nTPaA3Df5AcSd79_GFhh9k,26053
10
- scdataloader/preprocess.py,sha256=cHKUkGJVpnWfAVsSpl_B_IOmh8aQ0WAF2QPclhkA2eA,34876
11
- scdataloader/utils.py,sha256=GoRSEZ8aqmB8KussSTb95BxUBWlcLtErB_HGe0iZwic,28361
12
- scdataloader-1.8.0.dist-info/METADATA,sha256=UkC5E9nEXo1qf3QmNc1mz8Lvk5HL3sfmt8WiiAIqtGo,9946
13
- scdataloader-1.8.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
14
- scdataloader-1.8.0.dist-info/entry_points.txt,sha256=VXAN1m_CjbdLJ6SKYR0sBLGDV4wvv31ri7fWWuwbpno,60
15
- scdataloader-1.8.0.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
16
- scdataloader-1.8.0.dist-info/RECORD,,