nshtrainer 0.34.1__py3-none-any.whl → 0.34.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -81,6 +81,22 @@ class BalancedBatchSampler(BatchSampler):
81
81
  ):
82
82
  super().__init__(sampler, batch_size, drop_last=drop_last)
83
83
 
84
+ # Validate the dataset
85
+ dataset = self._unwrap_dataset(self.distributed_sampler.dataset)
86
+ # Dataset much either implement `data_sizes`, or we need to provide a custom
87
+ # implementation of the dataset sizes function.
88
+ if isinstance(dataset, DatasetWithSizes):
89
+ log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
90
+
91
+ elif self._data_sizes_fn is not None:
92
+ log.critical("BalancedBatchSampler: Using custom data_sizes_fn")
93
+ else:
94
+ raise ValueError(
95
+ "Dataset must implement the `data_sizes` method, "
96
+ "or a custom data_sizes_fn must be provided "
97
+ "to the BalancedBatchSampler."
98
+ )
99
+
84
100
  self._device = device
85
101
  self._data_sizes_fn = data_sizes_fn
86
102
 
@@ -97,7 +113,6 @@ class BalancedBatchSampler(BatchSampler):
97
113
  # Dataset much either implement `data_sizes`, or we need to provide a custom
98
114
  # implementation of the dataset sizes function.
99
115
  if isinstance(dataset, DatasetWithSizes):
100
- log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
101
116
  return dataset.data_sizes(indices)
102
117
 
103
118
  if (data_sizes_fn := self._data_sizes_fn) is not None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.34.1
3
+ Version: 0.34.2
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -32,7 +32,7 @@ nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50
32
32
  nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
33
33
  nshtrainer/config.py,sha256=6U7B-kCIMrfEnF_y92RuBm1WfASW7k05Zsm2uHBzRrk,8205
34
34
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
35
- nshtrainer/data/balanced_batch_sampler.py,sha256=WAjhbO9EsZ_UadhdW3obBsjvEDMc2V-irpjegqIb7AI,4791
35
+ nshtrainer/data/balanced_batch_sampler.py,sha256=ybMJF-CguaZ17fLEweZ5suaGOiHOMEm3Bn8rQfGTzGQ,5445
36
36
  nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
37
37
  nshtrainer/ll/__init__.py,sha256=L-aTi1V1bbvnZjOro8NvI393zbHQSFR9movWSRK9Mds,2477
38
38
  nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
@@ -97,6 +97,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
97
97
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
98
98
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
99
99
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
100
- nshtrainer-0.34.1.dist-info/METADATA,sha256=c_iXv-CQLl6kig2u3lmrP4EDSOdMvq8L2WewYiFL-8Q,916
101
- nshtrainer-0.34.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
102
- nshtrainer-0.34.1.dist-info/RECORD,,
100
+ nshtrainer-0.34.2.dist-info/METADATA,sha256=DQyYTUO0wpboH1gy3nSRJV6EsWCpY7Kb_ldD8v4BQFY,916
101
+ nshtrainer-0.34.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
102
+ nshtrainer-0.34.2.dist-info/RECORD,,