scdataloader 1.6.4__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,6 +16,10 @@ from torch.utils.data.sampler import (
 from .collator import Collator
 from .data import Dataset
 from .utils import getBiomartTable
+import os
+
+
+FILE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 class DataModule(L.LightningDataModule):
@@ -32,22 +36,24 @@ class DataModule(L.LightningDataModule):
         use_default_col: bool = True,
         gene_position_tolerance: int = 10_000,
         # this is for the mappedCollection
-        clss_to_pred: list = ["organism_ontology_term_id"],
-        all_clss: list = ["organism_ontology_term_id"],
+        clss_to_predict: list = ["organism_ontology_term_id"],
         hierarchical_clss: list = [],
         # this is for the collator
         how: str = "random expr",
         organism_name: str = "organism_ontology_term_id",
         max_len: int = 1000,
         add_zero_genes: int = 100,
+        replacement: bool = True,
         do_gene_pos: Union[bool, str] = True,
         tp_name: Optional[str] = None,  # "heat_diff"
         assays_to_drop: list = [
-            "EFO:0008853",
-            "EFO:0010961",
-            "EFO:0030007",
-            "EFO:0030062",
+            # "EFO:0008853",  # patch seq
+            # "EFO:0010961",  # visium
+            "EFO:0030007",  # ATACseq
+            # "EFO:0030062",  # slide-seq
         ],
+        restart_num: int = 0,
+        metacell_mode: float = 0.0,
         **kwargs,
     ):
         """
@@ -59,7 +65,6 @@ class DataModule(L.LightningDataModule):
 
         Args:
             collection_name (str): The lamindb collection to be used.
-            clss_to_weight (list, optional): The classes to weight in the trainer's weighted random sampler. Defaults to ["organism_ontology_term_id"].
             organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
             weight_scaler (int, optional): how much more you will see the most present vs less present category.
             train_oversampling_per_epoch (float, optional): The proportion of the dataset to include in the training set for each epoch. Defaults to 0.1.
@@ -81,23 +86,23 @@ class DataModule(L.LightningDataModule):
             organism_name (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
             tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
             hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
-            all_clss (list, optional): List of all classes. Defaults to ["organism_ontology_term_id"].
-            clss_to_pred (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
+            metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
+            clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
             **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
-
+            restart_num (int, optional): The number of the restart if we are continuing a previous run -> /!\ NEEDS TO BE SET. Defaults to 0.
         see @file data.py and @file collator.py for more details about some of the parameters
         """
         if collection_name is not None:
             mdataset = Dataset(
                 ln.Collection.filter(name=collection_name).first(),
                 organisms=organisms,
-                obs=all_clss,
-                clss_to_pred=clss_to_pred,
+                clss_to_predict=clss_to_predict,
                 hierarchical_clss=hierarchical_clss,
+                metacell_mode=metacell_mode,
             )
-        # print(mdataset)
         # and location
         self.gene_pos = None
+        self.collection_name = collection_name
         if do_gene_pos:
             if type(do_gene_pos) is str:
                 print("seeing a string: loading gene positions as biomart parquet file")
@@ -151,11 +156,12 @@ class DataModule(L.LightningDataModule):
                 org_to_id=mdataset.encoder[organism_name],
                 tp_name=tp_name,
                 organism_name=organism_name,
-                class_names=clss_to_weight,
+                class_names=clss_to_predict,
             )
         self.validation_split = validation_split
         self.test_split = test_split
         self.dataset = mdataset
+        self.replacement = replacement
         self.kwargs = kwargs
         if "sampler" in self.kwargs:
             self.kwargs.pop("sampler")
@@ -166,6 +172,8 @@ class DataModule(L.LightningDataModule):
         self.clss_to_weight = clss_to_weight
         self.train_weights = None
         self.train_labels = None
+        self.nnz = None
+        self.restart_num = restart_num
         self.test_datasets = []
         self.test_idx = []
         super().__init__()
@@ -184,12 +192,8 @@ class DataModule(L.LightningDataModule):
             f"\ttest datasets={str(self.test_datasets)},\n"
             f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
             f"\tclss_to_weight={self.clss_to_weight}\n"
-            + (
-                "\twith train_dataset size of=("
-                + str((self.train_weights != 0).sum())
-                + ")\n)"
-            )
-            if self.train_weights is not None
+            + ("\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)")
+            if self.idx_full is not None
             else ")"
         )
 
@@ -247,9 +251,17 @@ class DataModule(L.LightningDataModule):
             stage (str, optional): The stage of the model training process.
                 It can be either 'fit' or 'test'. Defaults to None.
         """
+        SCALE = 10
         if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
+            if "nnz" in self.clss_to_weight:
+                self.nnz = self.dataset.mapped_dataset.get_merged_labels("nnz")
+                self.clss_to_weight.remove("nnz")
+                (
+                    (self.nnz.max() / SCALE)
+                    / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
+                ).min()
             weights, labels = self.dataset.get_label_weights(
-                self.clss_to_weight, scaler=self.weight_scaler
+                self.clss_to_weight, scaler=self.weight_scaler, return_categories=True
             )
         else:
             weights = np.ones(1)
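The new `nnz` branch pulls per-cell non-zero gene counts out of `clss_to_weight` and evaluates a smoothed inverse-scaling expression (note that, as shipped, its result is a bare expression statement and is not assigned to anything). A standalone sketch of the formula itself, with made-up counts:

```python
import numpy as np

# Made-up non-zero-gene counts for five cells; not data from the package.
nnz = np.array([500, 1_000, 2_000, 5_000, 10_000])
SCALE = 10

# Smoothed inverse scaling: the value is near 1 for the cell with the fewest
# expressed genes and decays as nnz grows, bounded by the nnz.max() / SCALE term.
w = (nnz.max() / SCALE) / ((1 + nnz - nnz.min()) + (nnz.max() / SCALE))
print(w)  # approx. [0.999, 0.666, 0.400, 0.182, 0.095]
```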
@@ -268,12 +280,11 @@ class DataModule(L.LightningDataModule):
 
         idx_full = []
         if len(self.assays_to_drop) > 0:
-            for i, a in enumerate(
-                self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id")
-            ):
-                if a not in self.assays_to_drop:
-                    idx_full.append(i)
-            idx_full = np.array(idx_full)
+            badloc = np.isin(
+                self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id"),
+                self.assays_to_drop,
+            )
+            idx_full = np.arange(len(labels))[~badloc]
         else:
             idx_full = np.arange(self.n_samples)
         if len_test > 0:
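The per-element Python loop over assay labels is replaced by a vectorized `np.isin` mask. A toy illustration with hypothetical assay labels (only the dropped EFO code mirrors the defaults above):

```python
import numpy as np

# Toy assay labels for six cells; the non-dropped codes are placeholders.
assays = np.array(["EFO:0008931", "EFO:0030007", "EFO:0008931",
                   "EFO:0030007", "EFO:0009922", "EFO:0008931"])
assays_to_drop = ["EFO:0030007"]  # ATACseq, dropped by default in 1.7.0

badloc = np.isin(assays, assays_to_drop)      # boolean mask of cells to drop
idx_full = np.arange(len(assays))[~badloc]    # indices of the cells kept
print(idx_full)                               # [0 2 4 5]
```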
@@ -302,12 +313,15 @@ class DataModule(L.LightningDataModule):
         np.random.shuffle(idx_full)
         if len_valid > 0:
             self.valid_idx = idx_full[:len_valid].copy()
+            # store it for later
             idx_full = idx_full[len_valid:]
         else:
             self.valid_idx = None
         weights = np.concatenate([weights, np.zeros(1)])
         labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
-
+        # some labels will now not exist anymore as replaced by len(weights) - 1.
+        # this means that the associated weights should be 0.
+        # by doing np.bincount(labels)*weights this will be taken into account
         self.train_weights = weights
         self.train_labels = labels
         self.idx_full = idx_full
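The comments added here document a sentinel-class trick: cells excluded from training are relabeled to an extra class whose weight is zero, so the weighted sampler can never draw them. A self-contained illustration with toy arrays:

```python
import numpy as np

# Toy setup: three real classes, per-cell labels, and the kept training indices.
weights = np.array([1.0, 2.0, 0.5])
labels = np.array([0, 1, 1, 2, 0, 2])
idx_full = np.array([0, 1, 3, 5])

# Append the zero-weight sentinel class and reroute excluded cells to it.
weights = np.concatenate([weights, np.zeros(1)])
labels[~np.isin(np.arange(len(labels)), idx_full)] = len(weights) - 1

print(labels)                         # [0 1 3 2 3 2] — cells 2 and 4 rerouted
print(np.bincount(labels) * weights)  # [1. 2. 1. 0.] — sentinel contributes 0
```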
@@ -319,17 +333,30 @@ class DataModule(L.LightningDataModule):
         # int(self.n_samples*self.train_oversampling_per_epoch),
         # replacement=True,
         # )
-        train_sampler = LabelWeightedSampler(
-            self.train_weights,
-            self.train_labels,
-            num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
+        try:
+            train_sampler = LabelWeightedSampler(
+                self.train_weights,
+                self.train_labels,
+                num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
+                element_weights=self.nnz,
+                replacement=self.replacement,
+                restart_num=self.restart_num,
+            )
+        except ValueError as e:
+            raise ValueError(e + "have you run `datamodule.setup()`?")
+        return DataLoader(
+            self.dataset,
+            sampler=train_sampler,
+            **self.kwargs,
+            **kwargs,
         )
-        return DataLoader(self.dataset, sampler=train_sampler, **self.kwargs, **kwargs)
 
     def val_dataloader(self):
         return (
             DataLoader(
-                self.dataset, sampler=SubsetRandomSampler(self.valid_idx), **self.kwargs
+                self.dataset,
+                sampler=SubsetRandomSampler(self.valid_idx),
+                **self.kwargs,
             )
             if self.valid_idx is not None
             else None
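The sampler construction is now wrapped in a `try`/`except` that turns a `ValueError` into a hint that `setup()` may not have been run, since `train_weights` and `train_labels` only exist afterwards. A hypothetical call sequence, assuming `datamodule` is a configured `DataModule`:

```python
# Hypothetical driver code; `datamodule` is the instance built earlier.
datamodule.setup()                     # must run first: builds weights, labels, idx_full
train_dl = datamodule.train_dataloader()
val_dl = datamodule.val_dataloader()   # None when no validation split was made

for batch in train_dl:
    pass  # one epoch draws n_samples * train_oversampling_per_epoch cells
```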
@@ -346,7 +373,9 @@ class DataModule(L.LightningDataModule):
 
     def predict_dataloader(self):
         return DataLoader(
-            self.dataset, sampler=SubsetRandomSampler(self.idx_full), **self.kwargs
+            self.dataset,
+            sampler=SubsetRandomSampler(self.idx_full),
+            **self.kwargs,
         )
 
     # def teardown(self):
@@ -359,18 +388,27 @@ class LabelWeightedSampler(Sampler[int]):
     label_weights: Sequence[float]
     klass_indices: Sequence[Sequence[int]]
     num_samples: int
-
+    nnz: Optional[Sequence[int]]
+    replacement: bool
+    restart_num: int
     # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
     # this will result a class-balanced sampling, no matter how imbalance the labels are.
-    # NOTE: here we use replacement=True, you can change it if you don't upsample a class.
+
     def __init__(
-        self, label_weights: Sequence[float], labels: Sequence[int], num_samples: int
+        self,
+        label_weights: Sequence[float],
+        labels: Sequence[int],
+        num_samples: int,
+        replacement: bool = True,
+        element_weights: Sequence[float] = None,
+        restart_num=0,
     ) -> None:
         """
 
         :param label_weights: list(len=num_classes)[float], weights for each class.
         :param labels: list(len=dataset_len)[int], labels of a dataset.
         :param num_samples: number of samples.
+        :param restart_num: if we are continuing a previous run, we need to restart the sampler from the same point.
         """
 
         super(LabelWeightedSampler, self).__init__(None)
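The widened signature makes the sampler usable on its own, outside the `DataModule`. A toy instantiation; all values below are invented, and exact behavior may depend on your torch version:

```python
import torch

# Hypothetical standalone use of LabelWeightedSampler with toy data.
label_weights = [1.0, 1.0, 0.0]   # class 2 has weight 0 and is never sampled
labels = [0, 0, 1, 1, 1, 2, 2]    # one class label per dataset element
nnz = torch.tensor([100.0, 200.0, 50.0, 400.0, 300.0, 10.0, 10.0])

sampler = LabelWeightedSampler(
    label_weights,
    labels,
    num_samples=4,
    replacement=True,
    element_weights=nnz,  # biases picks within a class toward high-nnz cells
    restart_num=0,        # 0 keeps the default (unseeded) RNG
)
print(list(sampler))      # four indices drawn from classes 0 and 1 only
```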
@@ -379,24 +417,69 @@ class LabelWeightedSampler(Sampler[int]):
 
         self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
         self.labels = torch.as_tensor(labels, dtype=torch.int)
+        self.element_weights = (
+            torch.as_tensor(element_weights, dtype=torch.float32)
+            if element_weights is not None
+            else None
+        )
+        self.replacement = replacement
         self.num_samples = num_samples
+        self.restart_num = restart_num
         # list of tensor.
         self.klass_indices = [
             (self.labels == i_klass).nonzero().squeeze(1)
             for i_klass in range(len(label_weights))
         ]
+        self.klass_sizes = [len(klass_indices) for klass_indices in self.klass_indices]
 
     def __iter__(self):
         sample_labels = torch.multinomial(
-            self.label_weights, num_samples=self.num_samples, replacement=True
+            self.label_weights,
+            num_samples=self.num_samples,
+            replacement=True,
+            generator=None
+            if self.restart_num == 0
+            else torch.Generator().manual_seed(self.restart_num),
         )
         sample_indices = torch.empty_like(sample_labels)
+
         for i_klass, klass_index in enumerate(self.klass_indices):
             if klass_index.numel() == 0:
                 continue
             left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
-            right_inds = torch.randint(len(klass_index), size=(len(left_inds),))
+            if len(left_inds) == 0:
+                continue
+            if self.element_weights is not None:
+                right_inds = torch.multinomial(
+                    self.element_weights[klass_index],
+                    num_samples=len(klass_index)
+                    if not self.replacement and len(klass_index) < len(left_inds)
+                    else len(left_inds),
+                    replacement=self.replacement,
+                    generator=None
+                    if self.restart_num == 0
+                    else torch.Generator().manual_seed(self.restart_num),
+                )
+            elif self.replacement:
+                right_inds = torch.randint(
+                    len(klass_index),
+                    size=(len(left_inds),),
+                    generator=None
+                    if self.restart_num == 0
+                    else torch.Generator().manual_seed(self.restart_num),
+                )
+            else:
+                maxelem = (
+                    len(left_inds)
+                    if len(left_inds) < len(klass_index)
+                    else len(klass_index)
+                )
+                right_inds = torch.randperm(len(klass_index))[:maxelem]
             sample_indices[left_inds] = klass_index[right_inds]
+        # torch shuffle
+        sample_indices = sample_indices[torch.randperm(len(sample_indices))]
+        print(sample_indices.tolist()[:10], sample_labels[:10])
+        # raise Exception("stop")
         yield from iter(sample_indices.tolist())
 
     def __len__(self):
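The `restart_num` plumbing works because seeding a `torch.Generator` makes `torch.multinomial` deterministic, so a resumed run constructed with the same `restart_num` re-draws the same sample sequence. A minimal sketch of that property:

```python
import torch

# Two generators with the same seed stand in for the original and resumed run.
weights = torch.tensor([0.7, 0.3])
g1 = torch.Generator().manual_seed(42)
g2 = torch.Generator().manual_seed(42)

draw1 = torch.multinomial(weights, num_samples=8, replacement=True, generator=g1)
draw2 = torch.multinomial(weights, num_samples=8, replacement=True, generator=g2)
assert torch.equal(draw1, draw2)  # identical draws across "restarts"
```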