nshtrainer 0.34.0__py3-none-any.whl → 0.34.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/config.py CHANGED
@@ -1,6 +1,3 @@
1
- from nshconfig._config import Config as Config
2
- from nshsnap._config import SnapshotConfig as SnapshotConfig
3
-
4
1
  from nshtrainer._checkpoint.loader import (
5
2
  BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
6
3
  )
@@ -65,13 +62,13 @@ from nshtrainer.callbacks.throughput_monitor import (
65
62
  )
66
63
  from nshtrainer.callbacks.timer import EpochTimerConfig as EpochTimerConfig
67
64
  from nshtrainer.callbacks.wandb_watch import WandbWatchConfig as WandbWatchConfig
68
- from nshtrainer.config import LRSchedulerConfig as LRSchedulerConfig
69
65
  from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
70
66
  from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
71
67
  from nshtrainer.loggers.tensorboard import (
72
68
  TensorboardLoggerConfig as TensorboardLoggerConfig,
73
69
  )
74
70
  from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
71
+ from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
75
72
  from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
76
73
  from nshtrainer.lr_scheduler.linear_warmup_cosine import (
77
74
  DurationConfig as DurationConfig,
@@ -81,6 +81,22 @@ class BalancedBatchSampler(BatchSampler):
81
81
  ):
82
82
  super().__init__(sampler, batch_size, drop_last=drop_last)
83
83
 
84
+ # Validate the dataset
85
+ dataset = self._unwrap_dataset(self.distributed_sampler.dataset)
86
+ # Dataset must either implement `data_sizes`, or we need to provide a custom
87
+ # implementation of the dataset sizes function.
88
+ if isinstance(dataset, DatasetWithSizes):
89
+ log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
90
+
91
+ elif self._data_sizes_fn is not None:
92
+ log.critical("BalancedBatchSampler: Using custom data_sizes_fn")
93
+ else:
94
+ raise ValueError(
95
+ "Dataset must implement the `data_sizes` method, "
96
+ "or a custom data_sizes_fn must be provided "
97
+ "to the BalancedBatchSampler."
98
+ )
99
+
84
100
  self._device = device
85
101
  self._data_sizes_fn = data_sizes_fn
86
102
 
@@ -97,7 +113,6 @@ class BalancedBatchSampler(BatchSampler):
97
113
  # Dataset must either implement `data_sizes`, or we need to provide a custom
98
114
  # implementation of the dataset sizes function.
99
115
  if isinstance(dataset, DatasetWithSizes):
100
- log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
101
116
  return dataset.data_sizes(indices)
102
117
 
103
118
  if (data_sizes_fn := self._data_sizes_fn) is not None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.34.0
3
+ Version: 0.34.2
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -30,9 +30,9 @@ nshtrainer/callbacks/shared_parameters.py,sha256=fqlDweFDXPV_bfcAWpRgaJIad9i5Aeh
30
30
  nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
31
31
  nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
32
32
  nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
33
- nshtrainer/config.py,sha256=skar_Wfz50_sU2NZS8PEjqofWeon4g4cyIgby3Da81g,8308
33
+ nshtrainer/config.py,sha256=6U7B-kCIMrfEnF_y92RuBm1WfASW7k05Zsm2uHBzRrk,8205
34
34
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
35
- nshtrainer/data/balanced_batch_sampler.py,sha256=WAjhbO9EsZ_UadhdW3obBsjvEDMc2V-irpjegqIb7AI,4791
35
+ nshtrainer/data/balanced_batch_sampler.py,sha256=ybMJF-CguaZ17fLEweZ5suaGOiHOMEm3Bn8rQfGTzGQ,5445
36
36
  nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
37
37
  nshtrainer/ll/__init__.py,sha256=L-aTi1V1bbvnZjOro8NvI393zbHQSFR9movWSRK9Mds,2477
38
38
  nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
@@ -97,6 +97,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
97
97
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
98
98
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
99
99
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
100
- nshtrainer-0.34.0.dist-info/METADATA,sha256=GYC9ejdKV3MCyOFhJcFjI-uedTWLGWj-SE5S79ruug4,916
101
- nshtrainer-0.34.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
102
- nshtrainer-0.34.0.dist-info/RECORD,,
100
+ nshtrainer-0.34.2.dist-info/METADATA,sha256=DQyYTUO0wpboH1gy3nSRJV6EsWCpY7Kb_ldD8v4BQFY,916
101
+ nshtrainer-0.34.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
102
+ nshtrainer-0.34.2.dist-info/RECORD,,