nshtrainer 1.0.0b11__tar.gz → 1.0.0b12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/PKG-INFO +1 -1
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/pyproject.toml +1 -1
- nshtrainer-1.0.0b12/src/nshtrainer/data/datamodule.py +126 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/model/base.py +100 -2
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/trainer/_config.py +0 -1
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/trainer/trainer.py +38 -63
- nshtrainer-1.0.0b11/src/nshtrainer/data/datamodule.py +0 -57
- nshtrainer-1.0.0b11/src/nshtrainer/scripts/find_packages.py +0 -52
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/README.md +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/_directory.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/debug_flag.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/directory_setup.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/_checkpoint/loader/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/_directory/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/loggers/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/loggers/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/nn/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/trainer/checkpoint_connector/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/util/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/loggers/actsave.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/model/mixins/callback.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/model/mixins/debug.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/model/mixins/logger.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/util/config/dtype.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -0,0 +1,126 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from collections.abc import Callable, Mapping
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any, Generic, cast
|
7
|
+
|
8
|
+
import nshconfig as C
|
9
|
+
import torch
|
10
|
+
from lightning.pytorch import LightningDataModule
|
11
|
+
from lightning.pytorch.utilities.model_helpers import is_overridden
|
12
|
+
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
|
13
|
+
from typing_extensions import Never, TypeVar, deprecated, override
|
14
|
+
|
15
|
+
from ..model.mixins.callback import CallbackRegistrarModuleMixin
|
16
|
+
from ..model.mixins.debug import _DebugModuleMixin
|
17
|
+
|
18
|
+
THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)


class LightningDataModuleBase(
    _DebugModuleMixin,
    CallbackRegistrarModuleMixin,
    LightningDataModule,
    ABC,
    Generic[THparams],
):
    """Base ``LightningDataModule`` whose hyperparameters are a typed config object.

    Subclasses implement :meth:`hparams_cls` to declare the config class. The
    constructor validates incoming hyperparameters against that class and stores
    them with ``save_hyperparameters``. Lightning's ``load_from_checkpoint`` is
    disabled; use :meth:`from_checkpoint` instead.
    """

    @property
    @override
    def hparams(self) -> THparams:  # pyright: ignore[reportIncompatibleMethodOverride]
        """The validated hyperparameter config saved by ``__init__``."""
        return cast(THparams, super().hparams)

    @property
    @override
    def hparams_initial(self):  # pyright: ignore[reportIncompatibleMethodOverride]
        """Initial hyperparameters, JSON-dumped and nested under a ``"datamodule"`` key."""
        hparams = cast(THparams, super().hparams_initial)
        return cast(Never, {"datamodule": hparams.model_dump(mode="json")})

    @property
    @deprecated("Use `hparams` instead")
    def config(self):
        """Deprecated alias for :attr:`hparams`."""
        return cast(Never, self.hparams)

    @classmethod
    @abstractmethod
    def hparams_cls(cls) -> type[THparams]:
        """Return the config class used to validate this datamodule's hyperparameters."""
        ...

    @override
    def __init__(self, hparams: THparams | Mapping[str, Any]):
        """Validate ``hparams`` and save them as this datamodule's hyperparameters.

        Args:
            hparams: Either an instance of :meth:`hparams_cls`, or a mapping that
                is validated into one.

        Raises:
            TypeError: If ``hparams`` is neither a mapping nor an instance of
                :meth:`hparams_cls`.
        """
        super().__init__()

        # Validate and save hyperparameters
        hparams_cls = self.hparams_cls()
        if isinstance(hparams, Mapping):
            hparams = hparams_cls.model_validate(hparams)
        elif not isinstance(hparams, hparams_cls):
            raise TypeError(
                f"Expected hparams to be either a Mapping or an instance of {hparams_cls}, got {type(hparams)}"
            )
        hparams = hparams.model_deep_validate()
        self.save_hyperparameters(hparams)

    @override
    @classmethod
    def load_from_checkpoint(cls, *args, **kwargs) -> Never:
        """Disabled; always raises ``ValueError``. Use :meth:`from_checkpoint` instead."""
        raise ValueError("This method is not supported. Use `from_checkpoint` instead.")

    @classmethod
    def hparams_from_checkpoint(
        cls,
        ckpt_or_path: Mapping[str, Any] | str | Path,
        /,
        strict: bool | None = None,
        *,
        update_hparams: Callable[[THparams], THparams] | None = None,
        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ) -> THparams:
        """Rebuild this datamodule's hyperparameter config from a checkpoint.

        Args:
            ckpt_or_path: An already-loaded checkpoint mapping, or a path that is
                loaded with ``torch.load`` onto the CPU.
            strict: Forwarded to ``model_validate``.
            update_hparams: Optional hook applied to the validated config.
            update_hparams_dict: Optional hook applied to a copy of the raw
                hyperparameter dict before validation.

        Returns:
            The validated (and optionally updated) config.

        Raises:
            ValueError: If the checkpoint has no ``CHECKPOINT_HYPER_PARAMS_KEY`` entry.
        """
        # Accept any mapping, not only ``dict``: ``from_checkpoint`` accepts any
        # mapping and forwards it here unchanged.
        if isinstance(ckpt_or_path, Mapping):
            ckpt = ckpt_or_path
        else:
            # NOTE(review): relies on torch.load's default `weights_only`; newer
            # torch releases default it to True, which may reject checkpoints whose
            # hparams contain pickled config objects -- confirm for supported torch.
            ckpt = torch.load(ckpt_or_path, map_location="cpu")

        if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
            raise ValueError(
                f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
            )
        if update_hparams_dict is not None:
            # Give the hook a shallow copy so it cannot mutate the caller's checkpoint.
            hparams = update_hparams_dict(dict(hparams))

        hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
        if update_hparams is not None:
            hparams = update_hparams(hparams)

        return hparams

    @classmethod
    def from_checkpoint(
        cls,
        ckpt_or_path: Mapping[str, Any] | str | Path,
        /,
        strict: bool | None = None,
        map_location: torch.serialization.MAP_LOCATION = None,
        *,
        update_hparams: Callable[[THparams], THparams] | None = None,
        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ):
        """Construct a datamodule from a checkpoint.

        The hyperparameters are rebuilt with :meth:`hparams_from_checkpoint`, and a
        new instance is created from them. If the checkpoint has an entry keyed by
        this class's ``__qualname__``, it is passed to ``load_state_dict``.

        Args:
            ckpt_or_path: An already-loaded checkpoint mapping, or a path passed to
                ``torch.load``.
            strict: Forwarded to :meth:`hparams_from_checkpoint`.
            map_location: Forwarded to ``torch.load``; ignored when a mapping is given.
            update_hparams: See :meth:`hparams_from_checkpoint`.
            update_hparams_dict: See :meth:`hparams_from_checkpoint`.

        Returns:
            The new datamodule instance.
        """
        # Load checkpoint
        if isinstance(ckpt_or_path, Mapping):
            ckpt = ckpt_or_path
        else:
            ckpt = torch.load(ckpt_or_path, map_location=map_location)

        # Load hyperparameters from checkpoint
        hparams = cls.hparams_from_checkpoint(
            ckpt,
            strict=strict,
            update_hparams=update_hparams,
            update_hparams_dict=update_hparams_dict,
        )

        # Build the datamodule, then restore its state if the checkpoint carries
        # an entry named after the class's qualified name.
        datamodule = cls(hparams)
        if (state_key := datamodule.__class__.__qualname__) in ckpt:
            datamodule.load_state_dict(ckpt[state_key])

        return datamodule
|
@@ -2,7 +2,8 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
from abc import ABC, abstractmethod
|
5
|
-
from collections.abc import Mapping
|
5
|
+
from collections.abc import Callable, Mapping
|
6
|
+
from pathlib import Path
|
6
7
|
from typing import Any, Generic, Literal, cast
|
7
8
|
|
8
9
|
import nshconfig as C
|
@@ -10,11 +11,13 @@ import torch
|
|
10
11
|
import torch.distributed
|
11
12
|
from lightning.pytorch import LightningModule
|
12
13
|
from lightning.pytorch.profilers import PassThroughProfiler, Profiler
|
14
|
+
from lightning.pytorch.utilities.model_helpers import is_overridden
|
15
|
+
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
|
13
16
|
from typing_extensions import Never, TypeVar, deprecated, override
|
14
17
|
|
15
18
|
from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
|
16
19
|
from .mixins.callback import CallbackModuleMixin
|
17
|
-
from .mixins.debug import _DebugModuleMixin
|
20
|
+
from .mixins.debug import _DebugModuleMixin
|
18
21
|
from .mixins.logger import LoggerLightningModuleMixin
|
19
22
|
|
20
23
|
log = logging.getLogger(__name__)
|
@@ -241,3 +244,98 @@ class LightningModuleBase(
|
|
241
244
|
loss = sum((0.0 * v).sum() for v in self.parameters() if v.requires_grad)
|
242
245
|
loss = cast(torch.Tensor, loss)
|
243
246
|
return loss
|
247
|
+
|
248
|
+
@override
@classmethod
def load_from_checkpoint(cls, *args, **kwargs) -> Never:
    """Disabled override of Lightning's checkpoint loader.

    Every call raises; callers should use ``from_checkpoint`` instead.

    Raises:
        ValueError: Always.
    """
    message = "This method is not supported. Use `from_checkpoint` instead."
    raise ValueError(message)
|
252
|
+
|
253
|
+
@classmethod
def hparams_from_checkpoint(
    cls,
    ckpt_or_path: Mapping[str, Any] | str | Path,
    /,
    strict: bool | None = None,
    *,
    update_hparams: Callable[[THparams], THparams] | None = None,
    update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
):
    """Rebuild the model's hyperparameter config from a checkpoint.

    Args:
        ckpt_or_path: An already-loaded checkpoint mapping, or a path that is
            loaded with ``torch.load`` onto the CPU.
        strict: Forwarded to ``model_validate``.
        update_hparams: Optional hook applied to the validated config.
        update_hparams_dict: Optional hook applied to a copy of the raw
            hyperparameter dict before validation.

    Returns:
        The validated (and optionally updated) config.

    Raises:
        ValueError: If the checkpoint has no ``CHECKPOINT_HYPER_PARAMS_KEY`` entry.
    """
    # Accept any mapping, not only ``dict``: ``from_checkpoint`` accepts any
    # mapping and forwards it here unchanged.
    if isinstance(ckpt_or_path, Mapping):
        ckpt = ckpt_or_path
    else:
        # NOTE(review): relies on torch.load's default `weights_only`; newer torch
        # releases default it to True, which may reject checkpoints whose hparams
        # contain pickled config objects -- confirm for the supported torch version.
        ckpt = torch.load(ckpt_or_path, map_location="cpu")

    if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
        raise ValueError(
            f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
        )
    if update_hparams_dict is not None:
        # Give the hook a shallow copy so it cannot mutate the caller's checkpoint.
        hparams = update_hparams_dict(dict(hparams))

    hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
    if update_hparams is not None:
        hparams = update_hparams(hparams)

    return hparams
|
280
|
+
|
281
|
+
@classmethod
def from_checkpoint(
    cls,
    ckpt_or_path: Mapping[str, Any] | str | Path,
    /,
    strict: bool | None = None,
    map_location: torch.serialization.MAP_LOCATION = None,
    *,
    update_hparams: Callable[[THparams], THparams] | None = None,
    update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
):
    """Construct a model from a checkpoint and load its weights.

    Steps:
      1. Rebuild the config with ``hparams_from_checkpoint`` and call ``cls(hparams)``.
      2. Run ``configure_model`` if it is overridden.
      3. Call ``on_load_checkpoint(ckpt)``.
      4. Load ``ckpt["state_dict"]`` with ``load_state_dict``.

    Args:
        ckpt_or_path: An already-loaded checkpoint mapping, or a path passed to
            ``torch.load``.
        strict: Forwarded to ``hparams_from_checkpoint`` and also used as the
            ``strict`` flag for ``load_state_dict``. When ``None``, state-dict
            loading falls back to ``model.strict_loading``.
        map_location: Forwarded to ``torch.load``; ignored when a mapping is given.
        update_hparams: See ``hparams_from_checkpoint``.
        update_hparams_dict: See ``hparams_from_checkpoint``.

    Returns:
        The constructed model with its state dict loaded.

    Raises:
        ValueError: If the checkpoint lacks hyperparameters, or if an explicit
            ``strict`` conflicts with an explicitly set ``strict_loading``.
    """
    # Load checkpoint
    if isinstance(ckpt_or_path, Mapping):
        ckpt = ckpt_or_path
    else:
        ckpt = torch.load(ckpt_or_path, map_location=map_location)

    # Load hyperparameters from checkpoint.
    # NOTE(review): the same `strict` flag controls both config validation and
    # state-dict key matching; confirm that coupling is intended.
    hparams = cls.hparams_from_checkpoint(
        ckpt,
        strict=strict,
        update_hparams=update_hparams,
        update_hparams_dict=update_hparams_dict,
    )

    # Load model from checkpoint
    model = cls(hparams)

    # An explicit `strict` must agree with an explicitly set `strict_loading`.
    if (
        model._strict_loading is not None
        and strict is not None
        and strict != model.strict_loading
    ):
        raise ValueError(
            f"You set `.from_checkpoint(..., strict={strict!r})` which is in conflict with"
            f" `{cls.__name__}.strict_loading={model.strict_loading!r}`. Please set the same value for both of them."
        )
    strict = model.strict_loading if strict is None else strict

    # Run `configure_model` before loading weights so any layers it creates
    # exist when the state dict is applied.
    if is_overridden("configure_model", model):
        model.configure_model()

    # give model a chance to load something
    model.on_load_checkpoint(ckpt)

    # load the state_dict on the model automatically
    keys = model.load_state_dict(ckpt["state_dict"], strict=strict)

    # With non-strict loading, mismatched keys do not raise, so surface them as warnings.
    if not strict:
        if keys.missing_keys:
            rank_zero_warn(
                f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
            )
        if keys.unexpected_keys:
            rank_zero_warn(
                f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
            )

    return model
|
@@ -2,28 +2,19 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
import os
|
5
|
-
from collections.abc import Mapping, Sequence
|
5
|
+
from collections.abc import Callable, Mapping, Sequence
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import
|
7
|
+
from typing import TYPE_CHECKING, Any, cast
|
8
8
|
|
9
9
|
import torch
|
10
10
|
from lightning.fabric.plugins.environments.lsf import LSFEnvironment
|
11
11
|
from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
|
12
12
|
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
|
13
|
-
from lightning.fabric.utilities.cloud_io import _load as pl_load
|
14
|
-
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
|
15
13
|
from lightning.pytorch import LightningModule
|
16
14
|
from lightning.pytorch import Trainer as LightningTrainer
|
17
15
|
from lightning.pytorch.callbacks import Callback
|
18
|
-
from lightning.pytorch.core.saving import (
|
19
|
-
_default_map_location,
|
20
|
-
load_hparams_from_tags_csv,
|
21
|
-
load_hparams_from_yaml,
|
22
|
-
)
|
23
16
|
from lightning.pytorch.profilers import Profiler
|
24
17
|
from lightning.pytorch.trainer.states import TrainerFn
|
25
|
-
from lightning.pytorch.utilities.migration import pl_legacy_patch
|
26
|
-
from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
|
27
18
|
from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
|
28
19
|
from typing_extensions import Never, Unpack, assert_never, deprecated, override
|
29
20
|
|
@@ -473,62 +464,46 @@ class Trainer(LightningTrainer):
|
|
473
464
|
_callback._call_on_checkpoint_saved(self, filepath, metadata_path)
|
474
465
|
|
475
466
|
@classmethod
def hparams_from_checkpoint(
    cls,
    ckpt_or_path: Mapping[str, Any] | str | Path,
    /,
    strict: bool | None = None,
    *,
    update_hparams: Callable[[TrainerConfig], TrainerConfig] | None = None,
    update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
) -> TrainerConfig:
    """Rebuild the trainer's config from a checkpoint.

    Args:
        ckpt_or_path: An already-loaded checkpoint mapping, or a path that is
            loaded with ``torch.load`` onto the CPU.
        strict: Forwarded to ``model_validate``.
        update_hparams: Optional hook applied to the validated config.
        update_hparams_dict: Optional hook applied to a copy of the raw
            hyperparameter dict before validation.

    Returns:
        The validated (and optionally updated) trainer config.

    Raises:
        ValueError: If the checkpoint has no ``CHECKPOINT_HYPER_PARAMS_KEY`` entry.
    """
    # Accept any mapping, not only ``dict``, to match the model and datamodule loaders.
    if isinstance(ckpt_or_path, Mapping):
        ckpt = ckpt_or_path
    else:
        # NOTE(review): relies on torch.load's default `weights_only`; newer torch
        # releases default it to True, which may reject checkpoints whose hparams
        # contain pickled config objects -- confirm for the supported torch version.
        ckpt = torch.load(ckpt_or_path, map_location="cpu")

    if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
        raise ValueError(
            f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
        )
    if update_hparams_dict is not None:
        # Give the hook a shallow copy so it cannot mutate the caller's checkpoint.
        hparams = update_hparams_dict(dict(hparams))

    hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
    if update_hparams is not None:
        hparams = update_hparams(hparams)

    return hparams
|
531
493
|
|
532
|
-
|
533
|
-
|
534
|
-
|
494
|
+
@classmethod
def from_checkpoint(
    cls,
    path: str | Path,
    strict: bool | None = None,
    *,
    update_hparams: Callable[[TrainerConfig], TrainerConfig] | None = None,
    update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
):
    """Create a new trainer from the config stored in a checkpoint file.

    Only the config is restored: the hyperparameters are rebuilt via
    ``hparams_from_checkpoint`` and passed to the constructor. This method
    does not itself load any other checkpoint state.
    """
    restored_config = cls.hparams_from_checkpoint(
        path,
        strict=strict,
        update_hparams=update_hparams,
        update_hparams_dict=update_hparams_dict,
    )
    trainer = cls(restored_config)
    return trainer
|
@@ -1,57 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from abc import ABC, abstractmethod
|
4
|
-
from collections.abc import Mapping
|
5
|
-
from typing import Any, Generic, cast
|
6
|
-
|
7
|
-
import nshconfig as C
|
8
|
-
from lightning.pytorch import LightningDataModule
|
9
|
-
from typing_extensions import Never, TypeVar, deprecated, override
|
10
|
-
|
11
|
-
from ..model.mixins.callback import CallbackRegistrarModuleMixin
|
12
|
-
from ..model.mixins.debug import _DebugModuleMixin
|
13
|
-
|
14
|
-
THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)


class LightningDataModuleBase(
    _DebugModuleMixin,
    CallbackRegistrarModuleMixin,
    LightningDataModule,
    ABC,
    Generic[THparams],
):
    """Datamodule base class whose hyperparameters are held in a typed config.

    Subclasses provide the config class through :meth:`hparams_cls`.
    """

    @property
    @override
    def hparams(self) -> THparams:  # pyright: ignore[reportIncompatibleMethodOverride]
        """The config object stored by ``__init__``."""
        stored = super().hparams
        return cast(THparams, stored)

    @property
    @override
    def hparams_initial(self):  # pyright: ignore[reportIncompatibleMethodOverride]
        """Initial hyperparameters, JSON-dumped under a ``"datamodule"`` key."""
        initial = cast(THparams, super().hparams_initial)
        dumped = initial.model_dump(mode="json")
        return cast(Never, {"datamodule": dumped})

    @property
    @deprecated("Use `hparams` instead")
    def config(self):
        """Deprecated alias for :attr:`hparams`."""
        return cast(Never, self.hparams)

    @classmethod
    @abstractmethod
    def hparams_cls(cls) -> type[THparams]:
        """Return the config class for this datamodule."""
        ...

    @override
    def __init__(self, hparams: THparams | Mapping[str, Any]):
        """Validate ``hparams`` (a config instance or a mapping) and save them.

        Raises:
            TypeError: If ``hparams`` is neither a mapping nor a config instance.
        """
        super().__init__()

        hparams_cls = self.hparams_cls()
        # Reject anything that is neither a raw mapping nor an already-built config.
        if not isinstance(hparams, (Mapping, hparams_cls)):
            raise TypeError(
                f"Expected hparams to be either a Mapping or an instance of {hparams_cls}, got {type(hparams)}"
            )
        if isinstance(hparams, Mapping):
            hparams = hparams_cls.model_validate(hparams)
        self.save_hyperparameters(hparams.model_deep_validate())
|
@@ -1,52 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import argparse
|
4
|
-
import ast
|
5
|
-
import glob
|
6
|
-
import sys
|
7
|
-
from pathlib import Path
|
8
|
-
|
9
|
-
|
10
|
-
def get_imports(file_path: Path) -> set[str]:
    """Return the top-level package names imported by a Python source file.

    ``import a.b`` and absolute ``from a.b import c`` both contribute ``"a"``;
    relative imports (``from . import x``) are ignored. Imports anywhere in the
    file (including inside functions) are collected.

    If the file cannot be parsed, a message is printed to stderr and an empty
    set is returned.
    """
    # Read raw bytes and let the parser decode them. This honours PEP 263 coding
    # declarations instead of depending on the locale's default text encoding.
    with open(file_path, "rb") as file:
        source = file.read()

    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError):
        # SyntaxError covers invalid syntax and undecodable source; Python < 3.12
        # raises ValueError for source containing null bytes.
        print(f"Syntax error in file: {file_path}", file=sys.stderr)
        return set()

    imports: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            imports.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            # level == 0 means an absolute import.
            imports.add(node.module.split(".")[0])
    return imports
|
27
|
-
|
28
|
-
|
29
|
-
def main():
|
30
|
-
parser = argparse.ArgumentParser(
|
31
|
-
description="Find unique Python packages used in files."
|
32
|
-
)
|
33
|
-
parser.add_argument("glob_pattern", help="Glob pattern to match files")
|
34
|
-
parser.add_argument(
|
35
|
-
"--exclude-std", action="store_true", help="Exclude Python standard libraries"
|
36
|
-
)
|
37
|
-
args = parser.parse_args()
|
38
|
-
|
39
|
-
all_imports = set()
|
40
|
-
for file_path in glob.glob(args.glob_pattern, recursive=True):
|
41
|
-
all_imports.update(get_imports(Path(file_path)))
|
42
|
-
|
43
|
-
if args.exclude_std:
|
44
|
-
std_libs = set(sys.stdlib_module_names)
|
45
|
-
all_imports = all_imports - std_libs
|
46
|
-
|
47
|
-
for package in sorted(all_imports):
|
48
|
-
print(package)
|
49
|
-
|
50
|
-
|
51
|
-
if __name__ == "__main__":
|
52
|
-
main()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
RENAMED
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/_checkpoint/loader/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/actsave/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/base/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/ema/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py
RENAMED
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/print_table/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/timer/__init__.py
RENAMED
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py
RENAMED
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/loggers/_base/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/loggers/actsave/__init__.py
RENAMED
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/loggers/tensorboard/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/loggers/wandb/__init__.py
RENAMED
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/lr_scheduler/_base/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/metrics/_config/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/nn/nonlinearity/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/profiler/_base/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/profiler/advanced/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/profiler/pytorch/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/profiler/simple/__init__.py
RENAMED
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/trainer/_config/__init__.py
RENAMED
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/trainer/trainer/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/util/config/dtype/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/configs/util/config/duration/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py
RENAMED
File without changes
|
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b12}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|