nshtrainer 1.0.0b17__py3-none-any.whl → 1.0.0b19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -732,9 +732,7 @@ class TrainerConfig(C.Config):
732
732
  automatic selection based on the chosen accelerator. Default: ``"auto"``.
733
733
  """
734
734
 
735
- shared_parameters: SharedParametersCallbackConfig | None = (
736
- SharedParametersCallbackConfig()
737
- )
735
+ shared_parameters: SharedParametersCallbackConfig | None = None
738
736
  """If enabled, the model supports scaling the gradients of shared parameters that
739
737
  are registered in the self.shared_parameters list. This is useful for models that
740
738
  share parameters across multiple modules (e.g., in a GPT model) and want to
@@ -802,6 +800,10 @@ class TrainerConfig(C.Config):
802
800
  )
803
801
 
804
802
  def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
803
+ # Disable all callbacks if barebones mode is enabled
804
+ if self.barebones:
805
+ return
806
+
805
807
  yield self.early_stopping
806
808
  yield self.checkpoint_saving
807
809
  yield self.lr_monitor
@@ -823,6 +825,11 @@ class TrainerConfig(C.Config):
823
825
  yield from self.enabled_loggers()
824
826
  yield self.actsave_logger
825
827
 
828
+ def _nshtrainer_validate_before_run(self):
829
+ # shared_parameters is not supported under barebones mode
830
+ if self.barebones and self.shared_parameters:
831
+ raise ValueError("shared_parameters is not supported under barebones mode")
832
+
826
833
  # region Helper Methods
827
834
  def fast_dev_run_(self, value: int | bool = True, /):
828
835
  """
@@ -134,6 +134,10 @@ class Trainer(LightningTrainer):
134
134
  for key, value in update.items():
135
135
  _update_key(key, value)
136
136
 
137
+ # Set `barebones`
138
+ if hparams.barebones:
139
+ _update_kwargs(barebones=True)
140
+
137
141
  # Set `default_root_dir` if `auto_set_default_root_dir` is enabled.
138
142
  if hparams.auto_set_default_root_dir:
139
143
  if kwargs.get("default_root_dir"):
@@ -296,6 +300,7 @@ class Trainer(LightningTrainer):
296
300
  f"Got {type(hparams)=} instead."
297
301
  )
298
302
  hparams = hparams.model_deep_validate()
303
+ hparams._nshtrainer_validate_before_run()
299
304
 
300
305
  self._pre_init(hparams)
301
306
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 1.0.0b17
3
+ Version: 1.0.0b19
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -119,10 +119,10 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
119
119
  nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
120
120
  nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
121
121
  nshtrainer/trainer/__init__.py,sha256=MmoydVS6aYeav7zgDAUHxAQrV_PMQsbnZTCuPnLH9Wk,128
122
- nshtrainer/trainer/_config.py,sha256=Bx-32uNvmJ68qlRoK8xxh71gZcqkIkDJNkkOWoU5Hnc,33486
122
+ nshtrainer/trainer/_config.py,sha256=2AIr8w_ysRtn0yo49rwdduyBJ9bIAVpQdRpJoMg9Cd0,33806
123
123
  nshtrainer/trainer/_runtime_callback.py,sha256=T3epaj1YeIN0R8CS2cg5HNJIB21TyaD_PVNNOPJ6nJs,4200
124
124
  nshtrainer/trainer/signal_connector.py,sha256=YMJf6vTnW0JcnBkuYikm9x_9XscaokrCEzCn4THOGao,10776
125
- nshtrainer/trainer/trainer.py,sha256=vfQAr5H5HkDlIxjdEP8yhLDKplodxIws3sx3u_8qbkc,19381
125
+ nshtrainer/trainer/trainer.py,sha256=V5aRA6hBSxYi-Hbp-lg6b5mRCw_bc_0QzkJ7LG0c49M,19531
126
126
  nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
127
127
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
128
128
  nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
@@ -135,6 +135,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
135
135
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
136
136
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
137
137
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
138
- nshtrainer-1.0.0b17.dist-info/METADATA,sha256=6-ZUgdzm04noloz7OD0Y3jxg-pBe2u-WqYS4ZGFhWIU,937
139
- nshtrainer-1.0.0b17.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
140
- nshtrainer-1.0.0b17.dist-info/RECORD,,
138
+ nshtrainer-1.0.0b19.dist-info/METADATA,sha256=5WTNG0hJXpVLs3iZrpsIjgbccJeuIMG7i7t4MyAeIt4,937
139
+ nshtrainer-1.0.0b19.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
140
+ nshtrainer-1.0.0b19.dist-info/RECORD,,