scdataloader 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/data.py CHANGED
@@ -1,83 +1,53 @@
 from dataclasses import dataclass, field
 
 import lamindb as ln
-import lnschema_bionty as lb
+import bionty as bt
 import pandas as pd
 from torch.utils.data import Dataset as torchDataset
-from typing import Union
+from typing import Union, Optional, Literal
 from scdataloader import mapped
 import warnings
 
-# TODO: manage load gene embeddings to make
-# from scprint.dataloader.embedder import embed
+from anndata import AnnData
+
 from scdataloader.utils import get_ancestry_mapping, load_genes
 
-LABELS_TOADD = {
-    "assay_ontology_term_id": {
-        "10x transcription profiling": "EFO:0030003",
-        "spatial transcriptomics": "EFO:0008994",
-        "10x 3' transcription profiling": "EFO:0030003",
-        "10x 5' transcription profiling": "EFO:0030004",
-    },
-    "disease_ontology_term_id": {
-        "metabolic disease": "MONDO:0005066",
-        "chronic kidney disease": "MONDO:0005300",
-        "chromosomal disorder": "MONDO:0019040",
-        "infectious disease": "MONDO:0005550",
-        "inflammatory disease": "MONDO:0021166",
-        # "immune system disease",
-        "disorder of development or morphogenesis": "MONDO:0021147",
-        "mitochondrial disease": "MONDO:0044970",
-        "psychiatric disorder": "MONDO:0002025",
-        "cancer or benign tumor": "MONDO:0002025",
-        "neoplasm": "MONDO:0005070",
-    },
-    "cell_type_ontology_term_id": {
-        "progenitor cell": "CL:0011026",
-        "hematopoietic cell": "CL:0000988",
-        "myoblast": "CL:0000056",
-        "myeloid cell": "CL:0000763",
-        "neuron": "CL:0000540",
-        "electrically active cell": "CL:0000211",
-        "epithelial cell": "CL:0000066",
-        "secretory cell": "CL:0000151",
-        "stem cell": "CL:0000034",
-        "non-terminally differentiated cell": "CL:0000055",
-        "supporting cell": "CL:0000630",
-    },
-}
+from .config import LABELS_TOADD
 
 
 @dataclass
 class Dataset(torchDataset):
     """
-    Dataset class to load a bunch of anndata from a lamin dataset in a memory efficient way.
+    Dataset class to load a bunch of anndata from a lamin dataset (Collection) in a memory efficient way.
+
+    This serves as a wrapper around lamin's mappedCollection to provide more features,
+    mostly, the management of hierarchical labels, the encoding of labels, the management of multiple species
 
-    For an example, see :meth:`~lamindb.Dataset.mapped`.
+    For an example of mappedDataset, see :meth:`~lamindb.Dataset.mapped`.
 
     .. note::
 
-        A similar data loader exists `here
+        A related data loader exists `here
         <https://github.com/Genentech/scimilarity>`__.
 
-    Attributes:
+    Args:
     ----
         lamin_dataset (lamindb.Dataset): lamin dataset to load
         genedf (pd.Dataframe): dataframe containing the genes to load
-        gene_embedding: dataframe containing the gene embeddings
        organisms (list[str]): list of organisms to load
-        obs (list[str]): list of observations to load
+            (for now only validates the the genes map to this organism)
+        obs (list[str]): list of observations to load from the Collection
        clss_to_pred (list[str]): list of observations to encode
-        hierarchical_clss: list of observations to map to a hierarchy
+        join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
+        hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
     """
 
     lamin_dataset: ln.Collection
-    genedf: pd.DataFrame = None
-    # gene_embedding: pd.DataFrame = None  # TODO: make it part of specialized dataset
-    organisms: Union[list[str], str] = field(
+    genedf: Optional[pd.DataFrame] = None
+    organisms: Optional[Union[list[str], str]] = field(
         default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
     )
-    obs: list[str] = field(
+    obs: Optional[list[str]] = field(
        default_factory=[
            "self_reported_ethnicity_ontology_term_id",
            "assay_ontology_term_id",
@@ -88,16 +58,16 @@ class Dataset(torchDataset):
             "sex_ontology_term_id",
             #'dataset_id',
             #'cell_culture',
-            "dpt_group",
-            "heat_diff",
-            "nnz",
+            #"dpt_group",
+            #"heat_diff",
+            #"nnz",
         ]
     )
     # set of obs to prepare for prediction (encode)
-    clss_to_pred: list[str] = field(default_factory=list)
+    clss_to_pred: Optional[list[str]] = field(default_factory=list)
     # set of obs that need to be hierarchically prepared
-    hierarchical_clss: list[str] = field(default_factory=list)
-    join_vars: str = "None"
+    hierarchical_clss: Optional[list[str]] = field(default_factory=list)
+    join_vars: Optional[Literal["auto", "inner", "None"]] = "None"
 
     def __post_init__(self):
         self.mapped_dataset = mapped.mapped(
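Taken together with the previous hunk, the Dataset dataclass now marks genedf, organisms, obs, clss_to_pred and hierarchical_clss as Optional, drops the gene_embedding field, and constrains join_vars to a Literal. A minimal construction sketch based only on the fields shown in this diff; the collection name is a placeholder and a configured lamindb instance is assumed:

    import lamindb as ln
    from scdataloader.data import Dataset

    col = ln.Collection.filter(name="my-collection").first()  # placeholder collection name
    dataset = Dataset(
        lamin_dataset=col,
        organisms=["NCBITaxon:9606"],
        obs=["cell_type_ontology_term_id", "assay_ontology_term_id"],
        clss_to_pred=["cell_type_ontology_term_id"],
        hierarchical_clss=["cell_type_ontology_term_id"],
    )
    print(dataset)  # __repr__ now returns the summary string instead of printing it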
@@ -111,6 +81,8 @@ class Dataset(torchDataset):
             print(
                 "won't do any check but we recommend to have your dataset coming from local storage"
             )
+        self.class_groupings = {}
+        self.class_topred = {}
         # generate tree from ontologies
         if len(self.hierarchical_clss) > 0:
             self.define_hierarchies(self.hierarchical_clss)
@@ -149,7 +121,12 @@ class Dataset(torchDataset):
 
     def __getitem__(self, *args, **kwargs):
         item = self.mapped_dataset.__getitem__(*args, **kwargs)
-        #item.update({"unseen_genes": self.get_unseen_mapped_dataset_elements(*args, **kwargs)})
+        # import pdb
+
+        # pdb.set_trace()
+        # item.update(
+        #     {"unseen_genes": self.get_unseen_mapped_dataset_elements(*args, **kwargs)}
+        # )
         # ret = {}
         # ret["count"] = item[0]
         # for i, val in enumerate(self.obs):
@@ -160,51 +137,36 @@ class Dataset(torchDataset):
         return item
 
     def __repr__(self):
-        print(
-            "total dataset size is {} Gb".format(
+        return (
+            "total dataset size is {} Gb\n".format(
                 sum([file.size for file in self.lamin_dataset.artifacts.all()]) / 1e9
             )
-        )
-        print("---")
-        print("dataset contains:")
-        print(" {} cells".format(self.mapped_dataset.__len__()))
-        print(" {} genes".format(self.genedf.shape[0]))
-        print(" {} labels".format(len(self.obs)))
-        print(" {} organisms".format(len(self.organisms)))
-        print(
-            "dataset contains {} classes to predict".format(
-                sum([len(self.class_topred[i]) for i in self.class_topred])
+            + "---\n"
+            + "dataset contains:\n"
+            + " {} cells\n".format(self.mapped_dataset.__len__())
+            + " {} genes\n".format(self.genedf.shape[0])
+            + " {} labels\n".format(len(self.obs))
+            + " {} clss_to_pred\n".format(len(self.clss_to_pred))
+            + " {} hierarchical_clss\n".format(len(self.hierarchical_clss))
+            + " {} join_vars\n".format(len(self.join_vars))
+            + " {} organisms\n".format(len(self.organisms))
+            + (
+                "dataset contains {} classes to predict\n".format(
+                    sum([len(self.class_topred[i]) for i in self.class_topred])
+                )
+                if len(self.class_topred) > 0
+                else ""
             )
         )
-        # print("embedding size is {}".format(self.gene_embedding.shape[1]))
-        return ""
 
     def get_label_weights(self, *args, **kwargs):
         return self.mapped_dataset.get_label_weights(*args, **kwargs)
 
-    def get_unseen_mapped_dataset_elements(self, idx):
+    def get_unseen_mapped_dataset_elements(self, idx: int):
         return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]
 
-    # def load_embeddings(self, genedfs, embedding_size=128, cache=True):
-    #     embeddings = []
-    #     for o in self.organisms:
-    #         genedf = genedfs[genedfs.organism == o]
-    #         org_name = lb.Organism.filter(ontology_id=o).one().scientific_name
-    #         embedding = embed(
-    #             genedf=genedf,
-    #             organism=org_name,
-    #             cache=cache,
-    #             fasta_path="/tmp/data/fasta/",
-    #             embedding_size=embedding_size,
-    #         )
-    #         genedf = pd.concat(
-    #             [genedf.set_index("ensembl_gene_id"), embedding], axis=1, join="inner"
-    #         )
-    #         genedf.columns = genedf.columns.astype(str)
-    #         embeddings.append(genedf)
-    #     return pd.concat(embeddings)
-
-    def define_hierarchies(self, labels):
+    def define_hierarchies(self, labels: list[str]):
+        # TODO: use all possible hierarchies instead of just the ones for which we have a sample annotated with
         self.class_groupings = {}
         self.class_topred = {}
         for label in labels:
@@ -223,37 +185,37 @@ class Dataset(torchDataset):
                 )
             elif label == "cell_type_ontology_term_id":
                 parentdf = (
-                    lb.CellType.filter()
+                    bt.CellType.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "tissue_ontology_term_id":
                 parentdf = (
-                    lb.Tissue.filter()
+                    bt.Tissue.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "disease_ontology_term_id":
                 parentdf = (
-                    lb.Disease.filter()
+                    bt.Disease.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "development_stage_ontology_term_id":
                 parentdf = (
-                    lb.DevelopmentalStage.filter()
+                    bt.DevelopmentalStage.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "assay_ontology_term_id":
                 parentdf = (
-                    lb.ExperimentalFactor.filter()
+                    bt.ExperimentalFactor.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "self_reported_ethnicity_ontology_term_id":
                 parentdf = (
-                    lb.Ethnicity.filter()
+                    bt.Ethnicity.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
@@ -267,6 +229,9 @@ class Dataset(torchDataset):
             cats = self.mapped_dataset.get_merged_categories(label)
             addition = set(LABELS_TOADD.get(label, {}).values())
             cats |= addition
+            # import pdb
+
+            # pdb.set_trace()
             groupings, _, lclass = get_ancestry_mapping(cats, parentdf)
             for i, j in groupings.items():
                 if len(j) == 0:
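Ontology handling now goes through bionty (bt) rather than lnschema_bionty (lb): for each ontology-backed label, a parent table is pulled from the matching bionty registry and passed to get_ancestry_mapping together with the merged categories (plus the extra terms from LABELS_TOADD). A condensed sketch of that flow for a single label, assuming a configured lamindb/bionty instance; the ontology ids are placeholders and only calls visible in the hunks above are used:

    import bionty as bt
    from scdataloader.utils import get_ancestry_mapping

    # Parent table for cell types, as built in define_hierarchies above.
    parentdf = (
        bt.CellType.filter()
        .df(include=["parents__ontology_id"])
        .set_index("ontology_id")
    )
    cats = {"CL:0000540", "CL:0000763"}  # placeholder ontology ids
    groupings, _, lclass = get_ancestry_mapping(cats, parentdf)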
@@ -316,7 +281,22 @@ class Dataset(torchDataset):
 
 
 class SimpleAnnDataset:
-    def __init__(self, adata, obs_to_output=[], layer=None):
+    def __init__(
+        self,
+        adata: AnnData,
+        obs_to_output: Optional[list[str]] = [],
+        layer: Optional[str] = None,
+    ):
+        """
+        SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
+        scDataloader and with your model during inference.
+
+        Args:
+        ----
+            adata (anndata.AnnData): anndata object to use
+            obs_to_output (list[str]): list of observations to output from anndata.obs
+            layer (str): layer of the anndata to use
+        """
         self.adata = adata
         self.obs_to_output = obs_to_output
         self.layer = layer
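SimpleAnnDataset gains type hints and a docstring describing its role as a thin inference-time wrapper. A small usage sketch based only on the constructor shown; the file path and obs column are placeholders:

    import anndata as ad
    from scdataloader.data import SimpleAnnDataset

    adata = ad.read_h5ad("my_cells.h5ad")  # placeholder path
    dataset = SimpleAnnDataset(
        adata,
        obs_to_output=["cell_type_ontology_term_id"],
        layer=None,  # per the docstring: which AnnData layer to use (None is the default)
    )

The hunk that follows adds a brand-new module to the package (its file name is not shown in this diff view), containing a Lightning DataModule and a label-weighted sampler.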
@@ -0,0 +1,375 @@
+import numpy as np
+import pandas as pd
+import lamindb as ln
+
+from torch.utils.data.sampler import (
+    WeightedRandomSampler,
+    SubsetRandomSampler,
+    SequentialSampler,
+)
+import torch
+from torch.utils.data import DataLoader, Sampler
+import lightning as L
+
+from typing import Optional, Union, Sequence
+
+from .data import Dataset
+from .collator import Collator
+from .utils import getBiomartTable
+
+
+class DataModule(L.LightningDataModule):
+    def __init__(
+        self,
+        collection_name: str,
+        label_to_weight: list = ["organism_ontology_term_id"],
+        organisms: list = ["NCBITaxon:9606"],
+        weight_scaler: int = 10,
+        train_oversampling_per_epoch: float = 0.1,
+        validation_split: float = 0.2,
+        test_split: float = 0,
+        gene_embeddings: str = "",
+        use_default_col: bool = True,
+        gene_position_tolerance: int = 10_000,
+        # this is for the mappedCollection
+        label_to_pred: list = ["organism_ontology_term_id"],
+        all_labels: list = ["organism_ontology_term_id"],
+        hierarchical_labels: list = [],
+        # this is for the collator
+        how: str = "random expr",
+        organism_name: str = "organism_ontology_term_id",
+        max_len: int = 1000,
+        add_zero_genes: int = 100,
+        do_gene_pos: Union[bool, str] = True,
+        tp_name: Optional[str] = None,  # "heat_diff"
+        assays_to_drop: list = [
+            "EFO:0008853",
+            "EFO:0010961",
+            "EFO:0030007",
+            "EFO:0030062",
+        ],
+        **kwargs,
+    ):
+        """
+        DataModule a pytorch lighting datamodule directly from a lamin Collection.
+        it can work with bare pytorch too
+
+        It implements train / val / test dataloaders. the train is weighted random, val is random, test is one to many separated datasets.
+        This is where the mappedCollection, dataset, and collator are combined to create the dataloaders.
+
+        Args:
+            collection_name (str): The lamindb collection to be used.
+            weight_scaler (int, optional): how much more you will see the most present vs less present category.
+            gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
+                any genes within this distance of each other will be considered at the same position.
+            gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
+                the file must have ensembl_gene_id as index.
+                This is used to subset the available genes further to the ones that have embeddings in your model.
+            organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
+            label_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
+            validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
+            test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
+                it will use a full dataset and will round to the nearest dataset's cell count.
+            **other args: see @file data.py and @file collator.py for more details
+            **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
+        """
+        if collection_name is not None:
+            mdataset = Dataset(
+                ln.Collection.filter(name=collection_name).first(),
+                organisms=organisms,
+                obs=all_labels,
+                clss_to_pred=label_to_pred,
+                hierarchical_clss=hierarchical_labels,
+            )
+        print(mdataset)
+        # and location
+        if do_gene_pos:
+            if type(do_gene_pos) is str:
+                print("seeing a string: loading gene positions as biomart parquet file")
+                biomart = pd.read_parquet(do_gene_pos)
+            else:
+                # and annotations
+                biomart = getBiomartTable(
+                    attributes=["start_position", "chromosome_name"]
+                ).set_index("ensembl_gene_id")
+                biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
+                biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
+                c = []
+                i = 0
+                prev_position = -100000
+                prev_chromosome = None
+                for _, r in biomart.iterrows():
+                    if (
+                        r["chromosome_name"] != prev_chromosome
+                        or r["start_position"] - prev_position > gene_position_tolerance
+                    ):
+                        i += 1
+                    c.append(i)
+                    prev_position = r["start_position"]
+                    prev_chromosome = r["chromosome_name"]
+                print(f"reduced the size to {len(set(c))/len(biomart)}")
+                biomart["pos"] = c
+            mdataset.genedf = biomart.loc[mdataset.genedf.index]
+            self.gene_pos = mdataset.genedf["pos"].tolist()
+
+        if gene_embeddings != "":
+            mdataset.genedf = mdataset.genedf.join(
+                pd.read_parquet(gene_embeddings), how="inner"
+            )
+            if do_gene_pos:
+                self.gene_pos = mdataset.genedf["pos"].tolist()
+        self.labels = {k: len(v) for k, v in mdataset.class_topred.items()}
+        # we might want not to order the genes by expression (or do it?)
+        # we might want to not introduce zeros and
+        if use_default_col:
+            kwargs["collate_fn"] = Collator(
+                organisms=organisms,
+                how=how,
+                valid_genes=mdataset.genedf.index.tolist(),
+                max_len=max_len,
+                add_zero_genes=add_zero_genes,
+                org_to_id=mdataset.encoder[organism_name],
+                tp_name=tp_name,
+                organism_name=organism_name,
+                class_names=label_to_weight,
+            )
+        self.validation_split = validation_split
+        self.test_split = test_split
+        self.dataset = mdataset
+        self.kwargs = kwargs
+        self.assays_to_drop = assays_to_drop
+        self.n_samples = len(mdataset)
+        self.weight_scaler = weight_scaler
+        self.train_oversampling_per_epoch = train_oversampling_per_epoch
+        self.label_to_weight = label_to_weight
+        self.train_weights = None
+        self.train_labels = None
+        super().__init__()
+
+    def __repr__(self):
+        return (
+            f"DataLoader(\n"
+            f"\twith a dataset=({self.dataset.__repr__()}\n)\n"
+            f"\tvalidation_split={self.validation_split},\n"
+            f"\ttest_split={self.test_split},\n"
+            f"\tn_samples={self.n_samples},\n"
+            f"\tweight_scaler={self.weight_scaler},\n"
+            f"\train_oversampling_per_epoch={self.train_oversampling_per_epoch},\n"
+            f"\tlabel_to_weight={self.label_to_weight}\n"
+            + (
+                "\twith train_dataset size of=("
+                + str((self.train_weights != 0).sum())
+                + ")\n)"
+            )
+            if self.train_weights is not None
+            else ")"
+        )
+
+    @property
+    def decoders(self):
+        """
+        decoders the decoders for any labels that would have been encoded
+
+        Returns:
+            dict[str, dict[int, str]]
+        """
+        decoders = {}
+        for k, v in self.dataset.encoder.items():
+            decoders[k] = {va: ke for ke, va in v.items()}
+        return decoders
+
+    @property
+    def cls_hierarchy(self):
+        """
+        cls_hierarchy the hierarchy of labels for any cls that would have a hierarchy
+
+        Returns:
+            dict[str, dict[str, str]]
+        """
+        cls_hierarchy = {}
+        for k, dic in self.dataset.class_groupings.items():
+            rdic = {}
+            for sk, v in dic.items():
+                rdic[self.dataset.encoder[k][sk]] = [
+                    self.dataset.encoder[k][i] for i in list(v)
+                ]
+            cls_hierarchy[k] = rdic
+        return cls_hierarchy
+
+    @property
+    def genes(self):
+        """
+        genes the genes used in this datamodule
+
+        Returns:
+            list
+        """
+        return self.dataset.genedf.index.tolist()
+
+    @property
+    def num_datasets(self):
+        return len(self.dataset.mapped_dataset.storages)
+
+    def setup(self, stage=None):
+        """
+        setup method is used to prepare the data for the training, validation, and test sets.
+        It shuffles the data, calculates weights for each set, and creates samplers for each set.
+
+        Args:
+            stage (str, optional): The stage of the model training process.
+                It can be either 'fit' or 'test'. Defaults to None.
+        """
+        if len(self.label_to_weight) > 0:
+            weights, labels = self.dataset.get_label_weights(
+                self.label_to_weight, scaler=self.weight_scaler
+            )
+        else:
+            weights = np.ones(1)
+            labels = np.zeros(self.n_samples)
+        if isinstance(self.validation_split, int):
+            len_valid = self.validation_split
+        else:
+            len_valid = int(self.n_samples * self.validation_split)
+        if isinstance(self.test_split, int):
+            len_test = self.test_split
+        else:
+            len_test = int(self.n_samples * self.test_split)
+        assert (
+            len_test + len_valid < self.n_samples
+        ), "test set + valid set size is configured to be larger than entire dataset."
+
+        idx_full = []
+        if len(self.assays_to_drop) > 0:
+            for i, a in enumerate(
+                self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id")
+            ):
+                if a not in self.assays_to_drop:
+                    idx_full.append(i)
+            idx_full = np.array(idx_full)
+        else:
+            idx_full = np.arange(self.n_samples)
+        test_datasets = []
+        if len_test > 0:
+            # this way we work on some never seen datasets
+            # keeping at least one
+            len_test = (
+                len_test
+                if len_test > self.dataset.mapped_dataset.n_obs_list[0]
+                else self.dataset.mapped_dataset.n_obs_list[0]
+            )
+            cs = 0
+            print("these files will be considered test datasets:")
+            for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
+                if cs + c > len_test:
+                    break
+                else:
+                    print(" " + self.dataset.mapped_dataset.path_list[i].path)
+                    test_datasets.append(self.dataset.mapped_dataset.path_list[i].path)
+                    cs += c
+
+            len_test = cs
+            print("perc test: ", len_test / self.n_samples)
+            self.test_idx = idx_full[:len_test]
+            idx_full = idx_full[len_test:]
+        else:
+            self.test_idx = None
+
+        np.random.shuffle(idx_full)
+        if len_valid > 0:
+            self.valid_idx = idx_full[:len_valid].copy()
+            idx_full = idx_full[len_valid:]
+        else:
+            self.valid_idx = None
+        weights = np.concatenate([weights, np.zeros(1)])
+        labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
+
+        self.train_weights = weights
+        self.train_labels = labels
+        self.idx_full = idx_full
+
+        return test_datasets
+
+    def train_dataloader(self, **kwargs):
+        # train_sampler = WeightedRandomSampler(
+        #     self.train_weights[self.train_labels],
+        #     int(self.n_samples*self.train_oversampling_per_epoch),
+        #     replacement=True,
+        # )
+        train_sampler = LabelWeightedSampler(
+            self.train_weights,
+            self.train_labels,
+            num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
+            # replacement=True,
+        )
+        return DataLoader(self.dataset, sampler=train_sampler, **self.kwargs, **kwargs)
+
+    def val_dataloader(self):
+        return (
+            DataLoader(
+                self.dataset, sampler=SubsetRandomSampler(self.valid_idx), **self.kwargs
+            )
+            if self.valid_idx is not None
+            else None
+        )
+
+    def test_dataloader(self):
+        return (
+            DataLoader(
+                self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
+            )
+            if self.test_idx is not None
+            else None
+        )
+
+    # def teardown(self):
+    #     clean up state after the trainer stops, delete files...
+    #     called on every process in DDP
+    #     pass
+
+
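A minimal sketch of driving the new DataModule, based on the constructor, setup() and the dataloader methods above; the collection name and model are placeholders, the module path scdataloader.datamodule is assumed, and extra keyword arguments are forwarded to the pytorch DataLoader:

    import lightning as L
    from scdataloader.datamodule import DataModule  # assumed path of the new module

    datamodule = DataModule(
        collection_name="my-collection",  # placeholder collection
        organisms=["NCBITaxon:9606"],
        all_labels=["organism_ontology_term_id", "cell_type_ontology_term_id"],
        label_to_pred=["cell_type_ontology_term_id"],
        label_to_weight=["cell_type_ontology_term_id"],
        batch_size=64,    # forwarded to DataLoader via **kwargs
        num_workers=4,
    )
    test_files = datamodule.setup()  # returns the files held out as test datasets
    # trainer = L.Trainer(max_epochs=1)
    # trainer.fit(model, datamodule=datamodule)  # model: any LightningModule (placeholder)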
+class LabelWeightedSampler(Sampler[int]):
+    label_weights: Sequence[float]
+    klass_indices: Sequence[Sequence[int]]
+    num_samples: int
+
+    # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
+    # this will result a class-balanced sampling, no matter how imbalance the labels are.
+    # NOTE: here we use replacement=True, you can change it if you don't upsample a class.
+    def __init__(
+        self, label_weights: Sequence[float], labels: Sequence[int], num_samples: int
+    ) -> None:
+        """
+
+        :param label_weights: list(len=num_classes)[float], weights for each class.
+        :param labels: list(len=dataset_len)[int], labels of a dataset.
+        :param num_samples: number of samples.
+        """
+
+        super(LabelWeightedSampler, self).__init__(None)
+        # reweight labels from counter otherwsie same weight to labels that have many elements vs a few
+        label_weights = np.array(label_weights) * np.bincount(labels)
+
+        self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
+        self.labels = torch.as_tensor(labels, dtype=torch.int)
+        self.num_samples = num_samples
+        # list of tensor.
+        self.klass_indices = [
+            (self.labels == i_klass).nonzero().squeeze(1)
+            for i_klass in range(len(label_weights))
+        ]
+
+    def __iter__(self):
+        sample_labels = torch.multinomial(
+            self.label_weights, num_samples=self.num_samples, replacement=True
+        )
+        sample_indices = torch.empty_like(sample_labels)
+        for i_klass, klass_index in enumerate(self.klass_indices):
+            if klass_index.numel() == 0:
+                continue
+            left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
+            right_inds = torch.randint(len(klass_index), size=(len(left_inds),))
+            sample_indices[left_inds] = klass_index[right_inds]
+        yield from iter(sample_indices.tolist())
+
+    def __len__(self):
+        return self.num_samples
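LabelWeightedSampler draws indices in two stages: __iter__ first samples a class id from label_weights (which __init__ rescales by the per-class counts via np.bincount), then picks a uniformly random index within that class, with replacement. A toy, self-contained sketch assuming the class above is importable:

    import numpy as np
    # from scdataloader.datamodule import LabelWeightedSampler  # assumed module path

    labels = np.array([0] * 80 + [1] * 15 + [2] * 5)  # imbalanced toy labels
    weights = np.ones(3)                              # one weight per class
    sampler = LabelWeightedSampler(label_weights=weights, labels=labels, num_samples=10)
    print(list(sampler))  # 10 dataset indices, drawn class-first, then index-within-class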