scdataloader 0.0.4__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@ from torch.utils.data.sampler import (
     WeightedRandomSampler,
     SubsetRandomSampler,
     SequentialSampler,
+    RandomSampler,
 )
 import torch
 from torch.utils.data import DataLoader, Sampler
@@ -22,7 +23,7 @@ class DataModule(L.LightningDataModule):
     def __init__(
         self,
         collection_name: str,
-        label_to_weight: list = ["organism_ontology_term_id"],
+        clss_to_weight: list = ["organism_ontology_term_id"],
         organisms: list = ["NCBITaxon:9606"],
         weight_scaler: int = 10,
         train_oversampling_per_epoch: float = 0.1,
@@ -32,9 +33,9 @@ class DataModule(L.LightningDataModule):
         use_default_col: bool = True,
         gene_position_tolerance: int = 10_000,
         # this is for the mappedCollection
-        label_to_pred: list = ["organism_ontology_term_id"],
-        all_labels: list = ["organism_ontology_term_id"],
-        hierarchical_labels: list = [],
+        clss_to_pred: list = ["organism_ontology_term_id"],
+        all_clss: list = ["organism_ontology_term_id"],
+        hierarchical_clss: list = [],
         # this is for the collator
         how: str = "random expr",
         organism_name: str = "organism_ontology_term_id",
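Every `label_*` keyword in the constructor becomes `clss_*`, so 0.0.4 call sites fail with a `TypeError` until renamed. A minimal before/after sketch (collection and class names are placeholders; the import assumes the package's top-level export):

```python
from scdataloader import DataModule  # assumes the top-level export

# Before (0.0.4) -- these keywords no longer exist:
# dm = DataModule(
#     collection_name="my_collection",
#     label_to_weight=["cell_type_ontology_term_id"],
#     label_to_pred=["cell_type_ontology_term_id"],
#     all_labels=["cell_type_ontology_term_id"],
#     hierarchical_labels=["cell_type_ontology_term_id"],
# )

# After (1.0.1) -- same values, renamed keywords:
dm = DataModule(
    collection_name="my_collection",
    clss_to_weight=["cell_type_ontology_term_id"],
    clss_to_pred=["cell_type_ontology_term_id"],
    all_clss=["cell_type_ontology_term_id"],
    hierarchical_clss=["cell_type_ontology_term_id"],
)
```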
@@ -59,36 +60,55 @@ class DataModule(L.LightningDataModule):
 
         Args:
             collection_name (str): The lamindb collection to be used.
-            weight_scaler (int, optional): how much more you will see the most present vs less present category.
-            gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
-                any genes within this distance of each other will be considered at the same position.
-            gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
-                the file must have ensembl_gene_id as index.
-                This is used to subset the available genes further to the ones that have embeddings in your model.
+            clss_to_weight (list, optional): The classes to weight in the trainer's weighted random sampler. Defaults to ["organism_ontology_term_id"].
             organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
-            label_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
+            weight_scaler (int, optional): how much more you will see the most present vs less present category.
+            train_oversampling_per_epoch (float, optional): The proportion of the dataset to include in the training set for each epoch. Defaults to 0.1.
             validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
             test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
                 it will use a full dataset and will round to the nearest dataset's cell count.
-            **other args: see @file data.py and @file collator.py for more details
+            gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
+                the file must have ensembl_gene_id as index.
+                This is used to subset the available genes further to the ones that have embeddings in your model.
+            use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
+            gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
+                any genes within this distance of each other will be considered at the same position.
+            clss_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
+            assays_to_drop (list, optional): List of assays to drop from the dataset. Defaults to [].
+            do_gene_pos (Union[bool, str], optional): Whether to use gene positions. Defaults to True.
+            max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
+            add_zero_genes (int, optional): The number of zero genes to add to the input tensor. Defaults to 100.
+            how (str, optional): The method to use for the collator. Defaults to "random expr".
+            organism_name (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
+            tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
+            hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
+            all_clss (list, optional): List of all classes. Defaults to ["organism_ontology_term_id"].
+            clss_to_pred (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
             **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
+
+        see @file data.py and @file collator.py for more details about some of the parameters
         """
         if collection_name is not None:
             mdataset = Dataset(
                 ln.Collection.filter(name=collection_name).first(),
                 organisms=organisms,
-                obs=all_labels,
-                clss_to_pred=label_to_pred,
-                hierarchical_clss=hierarchical_labels,
+                obs=all_clss,
+                clss_to_pred=clss_to_pred,
+                hierarchical_clss=hierarchical_clss,
             )
-        print(mdataset)
+        # print(mdataset)
         # and location
+        self.gene_pos = None
         if do_gene_pos:
             if type(do_gene_pos) is str:
                 print("seeing a string: loading gene positions as biomart parquet file")
                 biomart = pd.read_parquet(do_gene_pos)
             else:
                 # and annotations
+                if organisms != ["NCBITaxon:9606"]:
+                    raise ValueError(
+                        "need to provide your own table as this automated function only works for humans for now"
+                    )
                 biomart = getBiomartTable(
                     attributes=["start_position", "chromosome_name"]
                 ).set_index("ensembl_gene_id")
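With the docstring now covering every parameter in signature order, a construction sketch may help; all values below are either the documented defaults or placeholders:

```python
from scdataloader import DataModule  # assumes the top-level export

dm = DataModule(
    collection_name="preprocessed dataset",  # a lamindb Collection name (placeholder)
    clss_to_weight=["cell_type_ontology_term_id"],
    organisms=["NCBITaxon:9606"],  # the only organism auto-supported for gene positions in 1.0.1
    weight_scaler=10,
    train_oversampling_per_epoch=0.1,
    validation_split=0.2,
    test_split=0.0,
    batch_size=64,  # forwarded to the pytorch DataLoader via **kwargs
)
```

Note the new guard in the hunk above: with `do_gene_pos=True` and any `organisms` other than `["NCBITaxon:9606"]`, 1.0.1 raises a `ValueError` asking you to supply your own biomart table.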
@@ -118,7 +138,7 @@ class DataModule(L.LightningDataModule):
         )
         if do_gene_pos:
             self.gene_pos = mdataset.genedf["pos"].tolist()
-        self.labels = {k: len(v) for k, v in mdataset.class_topred.items()}
+        self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
         # we might want not to order the genes by expression (or do it?)
         # we might want to not introduce zeros and
         if use_default_col:
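The renamed attribute clarifies intent: `self.classes` maps each class to be predicted to its number of categories, derived from `mdataset.class_topred`. Continuing the `dm` sketch above (names and counts are illustrative):

```python
# dm.classes after __init__ -- one entry per predicted class (illustrative values):
# {"organism_ontology_term_id": 2, "cell_type_ontology_term_id": 640}
n_categories = dm.classes["organism_ontology_term_id"]  # e.g. 2
```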
@@ -131,19 +151,23 @@ class DataModule(L.LightningDataModule):
                 org_to_id=mdataset.encoder[organism_name],
                 tp_name=tp_name,
                 organism_name=organism_name,
-                class_names=label_to_weight,
+                class_names=clss_to_weight,
             )
         self.validation_split = validation_split
         self.test_split = test_split
         self.dataset = mdataset
         self.kwargs = kwargs
+        if "sampler" in self.kwargs:
+            self.kwargs.pop("sampler")
         self.assays_to_drop = assays_to_drop
         self.n_samples = len(mdataset)
         self.weight_scaler = weight_scaler
         self.train_oversampling_per_epoch = train_oversampling_per_epoch
-        self.label_to_weight = label_to_weight
+        self.clss_to_weight = clss_to_weight
         self.train_weights = None
         self.train_labels = None
+        self.test_datasets = []
+        self.test_idx = []
         super().__init__()
 
     def __repr__(self):
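A behavioral change in this hunk: a `sampler` passed through `**kwargs` is now popped in `__init__`, since each `*_dataloader()` method installs its own sampler for its stage and a caller-supplied one would conflict. Continuing the `dm` sketch (the sampler below is a stand-in):

```python
from torch.utils.data.sampler import SequentialSampler

dm = DataModule(
    collection_name="preprocessed dataset",  # placeholder, as above
    sampler=SequentialSampler(range(100)),   # silently discarded in 1.0.1
)
# the DataLoader kwargs kept in dm.kwargs no longer contain "sampler"
```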
@@ -154,8 +178,12 @@ class DataModule(L.LightningDataModule):
             f"\ttest_split={self.test_split},\n"
             f"\tn_samples={self.n_samples},\n"
             f"\tweight_scaler={self.weight_scaler},\n"
-            f"\train_oversampling_per_epoch={self.train_oversampling_per_epoch},\n"
-            f"\tlabel_to_weight={self.label_to_weight}\n"
+            f"\ttrain_oversampling_per_epoch={self.train_oversampling_per_epoch},\n"
+            f"\tassays_to_drop={self.assays_to_drop},\n"
+            f"\tnum_datasets={len(self.dataset.mapped_dataset.storages)},\n"
+            f"\ttest datasets={str(self.test_datasets)},\n"
+            f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
+            f"\tclss_to_weight={self.clss_to_weight}\n"
             + (
                 "\twith train_dataset size of=("
                 + str((self.train_weights != 0).sum())
@@ -179,22 +207,22 @@ class DataModule(L.LightningDataModule):
         return decoders
 
     @property
-    def cls_hierarchy(self):
+    def labels_hierarchy(self):
         """
-        cls_hierarchy the hierarchy of labels for any cls that would have a hierarchy
+        labels_hierarchy the hierarchy of labels for any cls that would have a hierarchy
 
         Returns:
             dict[str, dict[str, str]]
         """
-        cls_hierarchy = {}
-        for k, dic in self.dataset.class_groupings.items():
+        labels_hierarchy = {}
+        for k, dic in self.dataset.labels_groupings.items():
             rdic = {}
             for sk, v in dic.items():
                 rdic[self.dataset.encoder[k][sk]] = [
                     self.dataset.encoder[k][i] for i in list(v)
                 ]
-            cls_hierarchy[k] = rdic
-        return cls_hierarchy
+            labels_hierarchy[k] = rdic
+        return labels_hierarchy
 
     @property
     def genes(self):
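`cls_hierarchy` becomes `labels_hierarchy`, and it now reads `labels_groupings` rather than `class_groupings`; the logic is unchanged: each parent label and its group members are passed through the label encoder, so the mapping lives in encoded (integer) space. Roughly, for a hierarchical class (all codes below are made up):

```python
# dm.labels_hierarchy, illustrative structure (integer codes are hypothetical):
# {
#     "cell_type_ontology_term_id": {
#         12: [3, 45, 67],  # encoded parent term -> encoded descendant terms
#     }
# }
children = dm.labels_hierarchy["cell_type_ontology_term_id"][12]
```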
@@ -219,9 +247,9 @@ class DataModule(L.LightningDataModule):
             stage (str, optional): The stage of the model training process.
                 It can be either 'fit' or 'test'. Defaults to None.
         """
-        if len(self.label_to_weight) > 0:
+        if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
             weights, labels = self.dataset.get_label_weights(
-                self.label_to_weight, scaler=self.weight_scaler
+                self.clss_to_weight, scaler=self.weight_scaler
            )
         else:
             weights = np.ones(1)
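The guard now also requires `weight_scaler > 0`, so setting the scaler to 0 disables weighting entirely (weights collapse to `np.ones(1)`). The exact computation lives in `Dataset.get_label_weights`; as a standalone illustration of the kind of scaled inverse-frequency weighting the docstring describes ("how much more you will see the most present vs less present category"), with all names local to this sketch:

```python
import numpy as np

def scaled_label_weights(labels: np.ndarray, scaler: float = 10.0) -> np.ndarray:
    """Per-sample weights ~ 1/frequency, capped so the rare/frequent sampling
    ratio never exceeds `scaler`. Illustrative, not the library's exact code."""
    codes, counts = np.unique(labels, return_counts=True)
    inv = counts.max() / counts        # rarer categories get larger weights
    inv = np.minimum(inv, scaler)      # cap the imbalance correction at `scaler`
    per_class = dict(zip(codes, inv))
    return np.array([per_class[l] for l in labels])

labels = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 2])
weights = scaled_label_weights(labels, scaler=10)
# these per-sample weights can feed torch.utils.data.WeightedRandomSampler
```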
@@ -248,7 +276,6 @@ class DataModule(L.LightningDataModule):
             idx_full = np.array(idx_full)
         else:
             idx_full = np.arange(self.n_samples)
-        test_datasets = []
         if len_test > 0:
             # this way we work on some never seen datasets
             # keeping at least one
@@ -258,17 +285,15 @@ class DataModule(L.LightningDataModule):
                 else self.dataset.mapped_dataset.n_obs_list[0]
             )
             cs = 0
-            print("these files will be considered test datasets:")
             for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
                 if cs + c > len_test:
                     break
                 else:
-                    print(" " + self.dataset.mapped_dataset.path_list[i].path)
-                    test_datasets.append(self.dataset.mapped_dataset.path_list[i].path)
+                    self.test_datasets.append(
+                        self.dataset.mapped_dataset._path_list[i].path
+                    )
                     cs += c
-
             len_test = cs
-            print("perc test: ", len_test / self.n_samples)
             self.test_idx = idx_full[:len_test]
             idx_full = idx_full[len_test:]
         else:
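The test split is now built from whole dataset files and stored on the instance: the loop accumulates per-file cell counts and claims entire files until the next one would overshoot `len_test`, then rounds `len_test` down to the claimed total (the `print` reporting moved into `__repr__`). A standalone sketch of the cut, with made-up counts:

```python
# Minimal sketch of the whole-dataset test split (illustrative counts):
n_obs_list = [5_000, 12_000, 8_000, 30_000]  # cells per stored dataset file
len_test = 20_000                            # requested number of test cells

test_files, cs = [], 0
for i, c in enumerate(n_obs_list):
    if cs + c > len_test:   # the next file would overshoot: stop here
        break
    test_files.append(i)    # in the module, a file path is recorded instead
    cs += c
len_test = cs               # rounded down to whole datasets: 17_000 here
```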
@@ -286,8 +311,7 @@ class DataModule(L.LightningDataModule):
         self.train_weights = weights
         self.train_labels = labels
         self.idx_full = idx_full
-
-        return test_datasets
+        return self.test_datasets
 
     def train_dataloader(self, **kwargs):
         # train_sampler = WeightedRandomSampler(
@@ -299,7 +323,6 @@ class DataModule(L.LightningDataModule):
             self.train_weights,
             self.train_labels,
             num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
-            # replacement=True,
         )
         return DataLoader(self.dataset, sampler=train_sampler, **self.kwargs, **kwargs)
 
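For scale, `num_samples=int(self.n_samples * self.train_oversampling_per_epoch)` is what makes an epoch a weighted subsample of the collection rather than a full pass; with the defaults:

```python
n_samples = 1_000_000                 # cells in the collection (illustrative)
train_oversampling_per_epoch = 0.1    # the documented default
num_samples = int(n_samples * train_oversampling_per_epoch)  # 100_000 draws per epoch
```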
@@ -321,6 +344,11 @@ class DataModule(L.LightningDataModule):
             else None
         )
 
+    def predict_dataloader(self):
+        return DataLoader(
+            self.dataset, sampler=SubsetRandomSampler(self.idx_full), **self.kwargs
+        )
+
     # def teardown(self):
     # clean up state after the trainer stops, delete files...
     # called on every process in DDP
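The new `predict_dataloader` serves the retained (non-test) indices in random order through the already-imported `SubsetRandomSampler`. A usage sketch (the collection name is a placeholder; `setup()` must run first so `idx_full` exists):

```python
dm = DataModule(collection_name="preprocessed dataset")  # placeholder
dm.setup()  # builds dm.idx_full and the test carve-out

loader = dm.predict_dataloader()  # SubsetRandomSampler over dm.idx_full
for batch in loader:
    ...  # run inference on each batch
```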