nshtrainer 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer-0.1.0/PKG-INFO +18 -0
- nshtrainer-0.1.0/README.md +0 -0
- nshtrainer-0.1.0/pyproject.toml +18 -0
- nshtrainer-0.1.0/src/nshtrainer/__init__.py +64 -0
- nshtrainer-0.1.0/src/nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer-0.1.0/src/nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer-0.1.0/src/nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer-0.1.0/src/nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer-0.1.0/src/nshtrainer/_snoop.py +216 -0
- nshtrainer-0.1.0/src/nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer-0.1.0/src/nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer-0.1.0/src/nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer-0.1.0/src/nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer-0.1.0/src/nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer-0.1.0/src/nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer-0.1.0/src/nshtrainer/actsave/__init__.py +7 -0
- nshtrainer-0.1.0/src/nshtrainer/actsave/_callback.py +75 -0
- nshtrainer-0.1.0/src/nshtrainer/actsave/_loader.py +144 -0
- nshtrainer-0.1.0/src/nshtrainer/actsave/_saver.py +337 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/base.py +113 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/ema.py +383 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/interval.py +322 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/timer.py +157 -0
- nshtrainer-0.1.0/src/nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer-0.1.0/src/nshtrainer/config.py +289 -0
- nshtrainer-0.1.0/src/nshtrainer/data/__init__.py +4 -0
- nshtrainer-0.1.0/src/nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer-0.1.0/src/nshtrainer/data/transform.py +67 -0
- nshtrainer-0.1.0/src/nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer-0.1.0/src/nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer-0.1.0/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer-0.1.0/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer-0.1.0/src/nshtrainer/model/__init__.py +44 -0
- nshtrainer-0.1.0/src/nshtrainer/model/base.py +641 -0
- nshtrainer-0.1.0/src/nshtrainer/model/config.py +2064 -0
- nshtrainer-0.1.0/src/nshtrainer/model/modules/callback.py +157 -0
- nshtrainer-0.1.0/src/nshtrainer/model/modules/debug.py +42 -0
- nshtrainer-0.1.0/src/nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer-0.1.0/src/nshtrainer/model/modules/logger.py +170 -0
- nshtrainer-0.1.0/src/nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer-0.1.0/src/nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer-0.1.0/src/nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer-0.1.0/src/nshtrainer/nn/__init__.py +19 -0
- nshtrainer-0.1.0/src/nshtrainer/nn/mlp.py +106 -0
- nshtrainer-0.1.0/src/nshtrainer/nn/module_dict.py +66 -0
- nshtrainer-0.1.0/src/nshtrainer/nn/module_list.py +50 -0
- nshtrainer-0.1.0/src/nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer-0.1.0/src/nshtrainer/optimizer.py +62 -0
- nshtrainer-0.1.0/src/nshtrainer/runner.py +21 -0
- nshtrainer-0.1.0/src/nshtrainer/scripts/check_env.py +41 -0
- nshtrainer-0.1.0/src/nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer-0.1.0/src/nshtrainer/trainer/__init__.py +1 -0
- nshtrainer-0.1.0/src/nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer-0.1.0/src/nshtrainer/trainer/trainer.py +340 -0
- nshtrainer-0.1.0/src/nshtrainer/typecheck.py +144 -0
- nshtrainer-0.1.0/src/nshtrainer/util/environment.py +119 -0
- nshtrainer-0.1.0/src/nshtrainer/util/seed.py +11 -0
- nshtrainer-0.1.0/src/nshtrainer/util/singleton.py +89 -0
- nshtrainer-0.1.0/src/nshtrainer/util/slurm.py +49 -0
- nshtrainer-0.1.0/src/nshtrainer/util/typed.py +2 -0
- nshtrainer-0.1.0/src/nshtrainer/util/typing_utils.py +19 -0
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: nshtrainer
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary:
|
|
5
|
+
Author: Nima Shoghi
|
|
6
|
+
Author-email: nimashoghi@gmail.com
|
|
7
|
+
Requires-Python: >=3.10,<4.0
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Requires-Dist: nshconfig (>=0.2.0,<0.3.0)
|
|
13
|
+
Requires-Dist: nshrunner (>=0.1.0,<0.2.0)
|
|
14
|
+
Requires-Dist: torch
|
|
15
|
+
Requires-Dist: typing-extensions
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
|
|
18
|
+
|
|
File without changes
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
[tool.poetry]
|
|
2
|
+
name = "nshtrainer"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = ""
|
|
5
|
+
authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
|
|
6
|
+
readme = "README.md"
|
|
7
|
+
|
|
8
|
+
[tool.poetry.dependencies]
|
|
9
|
+
python = "^3.10"
|
|
10
|
+
nshconfig = "^0.2.0"
|
|
11
|
+
nshrunner = "^0.1.0"
|
|
12
|
+
torch = "*"
|
|
13
|
+
typing-extensions = "*"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
[build-system]
|
|
17
|
+
requires = ["poetry-core"]
|
|
18
|
+
build-backend = "poetry.core.masonry.api"
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from . import _experimental as _experimental
|
|
2
|
+
from . import actsave as actsave
|
|
3
|
+
from . import callbacks as callbacks
|
|
4
|
+
from . import lr_scheduler as lr_scheduler
|
|
5
|
+
from . import model as model
|
|
6
|
+
from . import nn as nn
|
|
7
|
+
from . import optimizer as optimizer
|
|
8
|
+
from . import snapshot as snapshot
|
|
9
|
+
from . import typecheck as typecheck
|
|
10
|
+
from ._snoop import snoop as snoop
|
|
11
|
+
from .actsave import ActLoad as ActLoad
|
|
12
|
+
from .actsave import ActSave as ActSave
|
|
13
|
+
from .config import MISSING as MISSING
|
|
14
|
+
from .config import AllowMissing as AllowMissing
|
|
15
|
+
from .config import Field as Field
|
|
16
|
+
from .config import MissingField as MissingField
|
|
17
|
+
from .config import PrivateAttr as PrivateAttr
|
|
18
|
+
from .config import TypedConfig as TypedConfig
|
|
19
|
+
from .data import dataset_transform as dataset_transform
|
|
20
|
+
from .log import init_python_logging as init_python_logging
|
|
21
|
+
from .log import lovely as lovely
|
|
22
|
+
from .log import pretty as pretty
|
|
23
|
+
from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
|
|
24
|
+
from .model import ActSaveConfig as ActSaveConfig
|
|
25
|
+
from .model import Base as Base
|
|
26
|
+
from .model import BaseConfig as BaseConfig
|
|
27
|
+
from .model import BaseLoggerConfig as BaseLoggerConfig
|
|
28
|
+
from .model import BaseProfilerConfig as BaseProfilerConfig
|
|
29
|
+
from .model import CheckpointLoadingConfig as CheckpointLoadingConfig
|
|
30
|
+
from .model import CheckpointSavingConfig as CheckpointSavingConfig
|
|
31
|
+
from .model import ConfigList as ConfigList
|
|
32
|
+
from .model import DirectoryConfig as DirectoryConfig
|
|
33
|
+
from .model import (
|
|
34
|
+
EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
|
|
35
|
+
)
|
|
36
|
+
from .model import EnvironmentConfig as EnvironmentConfig
|
|
37
|
+
from .model import (
|
|
38
|
+
EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
|
|
39
|
+
)
|
|
40
|
+
from .model import (
|
|
41
|
+
EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
|
|
42
|
+
)
|
|
43
|
+
from .model import GradientClippingConfig as GradientClippingConfig
|
|
44
|
+
from .model import LightningDataModuleBase as LightningDataModuleBase
|
|
45
|
+
from .model import LightningModuleBase as LightningModuleBase
|
|
46
|
+
from .model import LoggingConfig as LoggingConfig
|
|
47
|
+
from .model import MetricConfig as MetricConfig
|
|
48
|
+
from .model import OptimizationConfig as OptimizationConfig
|
|
49
|
+
from .model import PrimaryMetricConfig as PrimaryMetricConfig
|
|
50
|
+
from .model import PythonLogging as PythonLogging
|
|
51
|
+
from .model import ReproducibilityConfig as ReproducibilityConfig
|
|
52
|
+
from .model import RunnerConfig as RunnerConfig
|
|
53
|
+
from .model import SanityCheckingConfig as SanityCheckingConfig
|
|
54
|
+
from .model import SeedConfig as SeedConfig
|
|
55
|
+
from .model import TrainerConfig as TrainerConfig
|
|
56
|
+
from .model import WandbWatchConfig as WandbWatchConfig
|
|
57
|
+
from .nn import TypedModuleDict as TypedModuleDict
|
|
58
|
+
from .nn import TypedModuleList as TypedModuleList
|
|
59
|
+
from .optimizer import OptimizerConfig as OptimizerConfig
|
|
60
|
+
from .runner import Runner as Runner
|
|
61
|
+
from .runner import SnapshotConfig as SnapshotConfig
|
|
62
|
+
from .trainer import Trainer as Trainer
|
|
63
|
+
from .util.singleton import Registry as Registry
|
|
64
|
+
from .util.singleton import Singleton as Singleton
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from collections.abc import Callable

import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1

# The FLOP-counting machinery (torch.utils.flop_counter, which
# .flop_counter builds on) only exists on PyTorch >= 2.1.
MEASURE_FLOPS_AVAILABLE = _TORCH_GREATER_EQUAL_2_1


def measure_flops(
    forward_fn: Callable[[], torch.Tensor],
    loss_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
    display: bool = True,
) -> int:
    """Utility to compute the total number of FLOPs used by a module during training or during inference.

    It's recommended to create a meta-device model for this:

    Example::

        with torch.device("meta"):
            model = MyModel()
            x = torch.randn(2, 32)

        model_fwd = lambda: model(x)
        fwd_flops = measure_flops(model_fwd)

        model_loss = lambda y: y.sum()
        fwd_and_bwd_flops = measure_flops(model_fwd, model_loss)

    Args:
        forward_fn: A function that runs ``forward`` on the model and returns the result.
        loss_fn: A function that computes the loss given the ``forward_fn`` output. If provided, the loss and
            ``backward`` FLOPs will be included in the result.
        display: Whether the flop counter should print a per-module FLOP breakdown when the
            counting context exits.

    Returns:
        The total number of FLOPs counted.

    Raises:
        ImportError: If the installed PyTorch version is older than 2.1.
    """
    if not MEASURE_FLOPS_AVAILABLE:
        raise ImportError("`measure_flops` requires PyTorch >= 2.1.")

    # Imported lazily so that importing this module does not pull in the
    # torch>=2.1-only flop counter implementation.
    from .flop_counter import FlopCounterMode

    flop_counter = FlopCounterMode(display=display)
    with flop_counter:
        if loss_fn is None:
            forward_fn()
        else:
            loss_fn(forward_fn()).backward()
    return flop_counter.get_total_flops()