scdataloader 1.6.4__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,4 @@
+import os
 from typing import Optional, Sequence, Union
 
 import lamindb as ln
@@ -15,7 +16,9 @@ from torch.utils.data.sampler import (
 
 from .collator import Collator
 from .data import Dataset
-from .utils import getBiomartTable
+from .utils import getBiomartTable, slurm_restart_count
+
+FILE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 class DataModule(L.LightningDataModule):
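Note: `slurm_restart_count` is imported from `.utils` but its implementation is not part of this diff. Judging by its use further down, it reports how many times the current job has been requeued, so the sampler can derive a per-restart seed. A minimal sketch of such a helper, assuming it reads Slurm's `SLURM_RESTART_COUNT` environment variable (the `use_mine` flag is opaque from this diff and is kept only for signature compatibility):

    import os

    def slurm_restart_count(use_mine: bool = False) -> int:
        # Slurm sets SLURM_RESTART_COUNT on requeued jobs; outside Slurm it is absent.
        return int(os.environ.get("SLURM_RESTART_COUNT", 0))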
@@ -32,22 +35,24 @@ class DataModule(L.LightningDataModule):
         use_default_col: bool = True,
         gene_position_tolerance: int = 10_000,
         # this is for the mappedCollection
-        clss_to_pred: list = ["organism_ontology_term_id"],
-        all_clss: list = ["organism_ontology_term_id"],
+        clss_to_predict: list = ["organism_ontology_term_id"],
         hierarchical_clss: list = [],
         # this is for the collator
         how: str = "random expr",
         organism_name: str = "organism_ontology_term_id",
         max_len: int = 1000,
         add_zero_genes: int = 100,
+        replacement: bool = True,
         do_gene_pos: Union[bool, str] = True,
         tp_name: Optional[str] = None,  # "heat_diff"
         assays_to_drop: list = [
-            "EFO:0008853",
-            "EFO:0010961",
-            "EFO:0030007",
-            "EFO:0030062",
+            # "EFO:0008853",  # patch seq
+            # "EFO:0010961",  # visium
+            "EFO:0030007",  # ATACseq
+            # "EFO:0030062",  # slide-seq
         ],
+        metacell_mode: float = 0.0,
+        modify_seed_on_requeue: bool = True,
         **kwargs,
     ):
         """
@@ -59,7 +64,6 @@ class DataModule(L.LightningDataModule):
 
         Args:
             collection_name (str): The lamindb collection to be used.
-            clss_to_weight (list, optional): The classes to weight in the trainer's weighted random sampler. Defaults to ["organism_ontology_term_id"].
             organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
             weight_scaler (int, optional): how much more you will see the most present vs less present category.
             train_oversampling_per_epoch (float, optional): The proportion of the dataset to include in the training set for each epoch. Defaults to 0.1.
@@ -81,23 +85,24 @@ class DataModule(L.LightningDataModule):
             organism_name (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
             tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
             hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
-            all_clss (list, optional): List of all classes. Defaults to ["organism_ontology_term_id"].
-            clss_to_pred (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
+            metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
+            clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
+            modify_seed_on_requeue (bool, optional): Whether to modify the seed on requeue. Defaults to True.
             **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
-
         see @file data.py and @file collator.py for more details about some of the parameters
         """
         if collection_name is not None:
             mdataset = Dataset(
                 ln.Collection.filter(name=collection_name).first(),
                 organisms=organisms,
-                obs=all_clss,
-                clss_to_pred=clss_to_pred,
+                clss_to_predict=clss_to_predict,
                 hierarchical_clss=hierarchical_clss,
+                metacell_mode=metacell_mode,
             )
-        # print(mdataset)
         # and location
+        self.metacell_mode = bool(metacell_mode)
         self.gene_pos = None
+        self.collection_name = collection_name
         if do_gene_pos:
             if type(do_gene_pos) is str:
                 print("seeing a string: loading gene positions as biomart parquet file")
@@ -127,7 +132,7 @@ class DataModule(L.LightningDataModule):
                     c.append(i)
                 prev_position = r["start_position"]
                 prev_chromosome = r["chromosome_name"]
-            print(f"reduced the size to {len(set(c))/len(biomart)}")
+            print(f"reduced the size to {len(set(c)) / len(biomart)}")
             biomart["pos"] = c
             mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
             self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
@@ -151,11 +156,13 @@
             org_to_id=mdataset.encoder[organism_name],
             tp_name=tp_name,
             organism_name=organism_name,
-            class_names=clss_to_weight,
+            class_names=clss_to_predict,
+            metacell_mode=bool(metacell_mode),
         )
         self.validation_split = validation_split
         self.test_split = test_split
         self.dataset = mdataset
+        self.replacement = replacement
         self.kwargs = kwargs
         if "sampler" in self.kwargs:
             self.kwargs.pop("sampler")
@@ -166,6 +173,9 @@
         self.clss_to_weight = clss_to_weight
         self.train_weights = None
         self.train_labels = None
+        self.modify_seed_on_requeue = modify_seed_on_requeue
+        self.nnz = None
+        self.restart_num = 0
         self.test_datasets = []
         self.test_idx = []
         super().__init__()
@@ -184,12 +194,8 @@
             f"\ttest datasets={str(self.test_datasets)},\n"
             f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
             f"\tclss_to_weight={self.clss_to_weight}\n"
-            + (
-                "\twith train_dataset size of=("
-                + str((self.train_weights != 0).sum())
-                + ")\n)"
-            )
-            if self.train_weights is not None
+            + ("\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)")
+            if self.idx_full is not None
             else ")"
         )
 
@@ -247,13 +253,23 @@
             stage (str, optional): The stage of the model training process.
                 It can be either 'fit' or 'test'. Defaults to None.
         """
+        SCALE = 10
+        if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
+            self.nnz = self.dataset.mapped_dataset.get_merged_labels("nnz")
+            self.clss_to_weight.remove("nnz")
+            (
+                (self.nnz.max() / SCALE)
+                / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
+            ).min()
         if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
             weights, labels = self.dataset.get_label_weights(
-                self.clss_to_weight, scaler=self.weight_scaler
+                self.clss_to_weight,
+                scaler=self.weight_scaler,
+                return_categories=True,
             )
         else:
             weights = np.ones(1)
-            labels = np.zeros(self.n_samples)
+            labels = np.zeros(self.n_samples, dtype=int)
         if isinstance(self.validation_split, int):
             len_valid = self.validation_split
         else:
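Note: the new parenthesized expression computes a damped inverse of the per-cell counts but immediately discards the result (the `.min()` is never assigned), so as written it is a no-op; only `self.nnz` is kept, to be handed to the sampler later as `element_weights`. The intent was presumably something along these lines (a sketch, not the released code):

    # hypothetical: keep the damped per-cell weights instead of discarding them
    damping = nnz.max() / SCALE
    element_weights = damping / ((1 + nnz - nnz.min()) + damping)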
@@ -268,12 +284,11 @@
 
         idx_full = []
         if len(self.assays_to_drop) > 0:
-            for i, a in enumerate(
-                self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id")
-            ):
-                if a not in self.assays_to_drop:
-                    idx_full.append(i)
-            idx_full = np.array(idx_full)
+            badloc = np.isin(
+                self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id"),
+                self.assays_to_drop,
+            )
+            idx_full = np.arange(len(labels))[~badloc]
         else:
             idx_full = np.arange(self.n_samples)
         if len_test > 0:
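The vectorized `np.isin` form keeps exactly the indices the removed loop kept: those whose assay is not listed in `assays_to_drop`. For example:

    import numpy as np

    assays = np.array(["EFO:0030007", "EFO:0008913", "EFO:0030007"])  # placeholder ids
    badloc = np.isin(assays, ["EFO:0030007"])
    idx_full = np.arange(len(assays))[~badloc]  # array([1])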
@@ -302,12 +317,15 @@
         np.random.shuffle(idx_full)
         if len_valid > 0:
             self.valid_idx = idx_full[:len_valid].copy()
+            # store it for later
             idx_full = idx_full[len_valid:]
         else:
             self.valid_idx = None
         weights = np.concatenate([weights, np.zeros(1)])
         labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
-
+        # some labels no longer exist, since they were replaced by len(weights) - 1;
+        # this means their associated weights should be 0.
+        # doing np.bincount(labels) * weights takes this into account.
         self.train_weights = weights
         self.train_labels = labels
         self.idx_full = idx_full
@@ -319,17 +337,31 @@
         #     int(self.n_samples*self.train_oversampling_per_epoch),
         #     replacement=True,
         # )
-        train_sampler = LabelWeightedSampler(
-            self.train_weights,
-            self.train_labels,
-            num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
+        try:
+            train_sampler = LabelWeightedSampler(
+                label_weights=self.train_weights,
+                labels=self.train_labels,
+                num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
+                element_weights=self.nnz,
+                replacement=self.replacement,
+                restart_num=self.restart_num,
+                modify_seed_on_requeue=self.modify_seed_on_requeue,
+            )
+        except ValueError as e:
+            raise ValueError(str(e) + " have you run `datamodule.setup()`?")
+        return DataLoader(
+            self.dataset,
+            sampler=train_sampler,
+            **self.kwargs,
+            **kwargs,
         )
-        return DataLoader(self.dataset, sampler=train_sampler, **self.kwargs, **kwargs)
 
     def val_dataloader(self):
         return (
             DataLoader(
-                self.dataset, sampler=SubsetRandomSampler(self.valid_idx), **self.kwargs
+                self.dataset,
+                sampler=SubsetRandomSampler(self.valid_idx),
+                **self.kwargs,
             )
             if self.valid_idx is not None
             else None
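Because the sampler is now built from state computed in `setup()` (weights, labels, `nnz`), `train_dataloader` re-raises a `ValueError` with a hint when that state is missing. The expected call order is, roughly:

    datamodule.setup()                       # computes weights, labels, and splits
    loader = datamodule.train_dataloader()
    for batch in loader:
        pass  # training step goes here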
@@ -346,7 +378,9 @@
 
     def predict_dataloader(self):
         return DataLoader(
-            self.dataset, sampler=SubsetRandomSampler(self.idx_full), **self.kwargs
+            self.dataset,
+            sampler=SubsetRandomSampler(self.idx_full),
+            **self.kwargs,
         )
 
     # def teardown(self):
@@ -359,18 +393,29 @@ class LabelWeightedSampler(Sampler[int]):
     label_weights: Sequence[float]
     klass_indices: Sequence[Sequence[int]]
     num_samples: int
-
+    nnz: Optional[Sequence[int]]
+    replacement: bool
+    restart_num: int
+    modify_seed_on_requeue: bool
     # when used, just set weights for each class (e.g. np.ones(num_classes)) and the labels of a dataset.
     # this will result in class-balanced sampling, no matter how imbalanced the labels are.
-    # NOTE: here we use replacement=True, you can change it if you don't upsample a class.
+
     def __init__(
-        self, label_weights: Sequence[float], labels: Sequence[int], num_samples: int
+        self,
+        label_weights: Sequence[float],
+        labels: Sequence[int],
+        num_samples: int,
+        replacement: bool = True,
+        element_weights: Sequence[float] = None,
+        restart_num: int = 0,
+        modify_seed_on_requeue: bool = True,
     ) -> None:
         """
 
         :param label_weights: list(len=num_classes)[float], weights for each class.
         :param labels: list(len=dataset_len)[int], labels of a dataset.
         :param num_samples: number of samples.
+        :param restart_num: if we are continuing a previous run, restart the sampler from the same point.
         """
 
         super(LabelWeightedSampler, self).__init__(None)
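For reference, the sampler can also be exercised standalone; a minimal sketch with three balanced classes (the module path is assumed, and outside Slurm `slurm_restart_count` should contribute 0):

    from scdataloader.datamodule import LabelWeightedSampler  # path assumed

    sampler = LabelWeightedSampler(
        label_weights=[1.0, 1.0, 1.0],  # equal class probabilities
        labels=[0, 0, 0, 0, 1, 1, 2],   # per-sample class ids
        num_samples=6,
        replacement=False,              # new in 1.8.0
    )
    print(list(iter(sampler)))          # at most 6 shuffled indices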
@@ -379,24 +424,73 @@ class LabelWeightedSampler(Sampler[int]):
 
         self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
         self.labels = torch.as_tensor(labels, dtype=torch.int)
+        self.element_weights = (
+            torch.as_tensor(element_weights, dtype=torch.float32)
+            if element_weights is not None
+            else None
+        )
+        self.replacement = replacement
         self.num_samples = num_samples
+        self.restart_num = slurm_restart_count(use_mine=True) + restart_num
+        self.modify_seed_on_requeue = modify_seed_on_requeue
         # list of tensors.
         self.klass_indices = [
             (self.labels == i_klass).nonzero().squeeze(1)
             for i_klass in range(len(label_weights))
         ]
+        self.klass_sizes = [len(klass_indices) for klass_indices in self.klass_indices]
 
     def __iter__(self):
         sample_labels = torch.multinomial(
-            self.label_weights, num_samples=self.num_samples, replacement=True
+            self.label_weights,
+            num_samples=self.num_samples,
+            replacement=True,
+            generator=None
+            if self.restart_num == 0 and not self.modify_seed_on_requeue
+            else torch.Generator().manual_seed(self.restart_num),
         )
         sample_indices = torch.empty_like(sample_labels)
         for i_klass, klass_index in enumerate(self.klass_indices):
             if klass_index.numel() == 0:
                 continue
             left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
-            right_inds = torch.randint(len(klass_index), size=(len(left_inds),))
-            sample_indices[left_inds] = klass_index[right_inds]
+            if len(left_inds) == 0:
+                continue
+            if self.element_weights is not None:
+                right_inds = torch.multinomial(
+                    self.element_weights[klass_index],
+                    num_samples=len(klass_index)
+                    if not self.replacement and len(klass_index) < len(left_inds)
+                    else len(left_inds),
+                    replacement=self.replacement,
+                    generator=None
+                    if self.restart_num == 0 and not self.modify_seed_on_requeue
+                    else torch.Generator().manual_seed(self.restart_num),
+                )
+            elif self.replacement:
+                right_inds = torch.randint(
+                    len(klass_index),
+                    size=(len(left_inds),),
+                    generator=None
+                    if self.restart_num == 0 and not self.modify_seed_on_requeue
+                    else torch.Generator().manual_seed(self.restart_num),
+                )
+            else:
+                maxelem = (
+                    len(left_inds)
+                    if len(left_inds) < len(klass_index)
+                    else len(klass_index)
+                )
+                right_inds = torch.randperm(len(klass_index))[:maxelem]
+            sample_indices[left_inds[: len(right_inds)]] = klass_index[right_inds]
+            if len(right_inds) < len(left_inds):
+                sample_indices[left_inds[len(right_inds) :]] = -1
+        # drop all -1
+        sample_indices = sample_indices[sample_indices != -1]
+        # torch shuffle
+        sample_indices = sample_indices[torch.randperm(len(sample_indices))]
+        self.num_samples = len(sample_indices)
+        # raise Exception("stop")
         yield from iter(sample_indices.tolist())
 
     def __len__(self):
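Two behavioral notes on the new `__iter__`, as written. First, without replacement, a class drawn more often than it has members marks the surplus positions as -1, which are then dropped, so an epoch can be shorter than the requested `num_samples`; `self.num_samples` (and hence `__len__`) shrinks accordingly after the first pass. Second, whenever `modify_seed_on_requeue` is true or `restart_num` is nonzero, each `torch.multinomial` call builds a fresh generator seeded with the same `restart_num`, so those draws repeat identically from epoch to epoch within a run; only the unseeded `torch.randperm` calls still vary.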