nshtrainer 0.1.0 (tar.gz)

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. nshtrainer-0.1.0/PKG-INFO +18 -0
  2. nshtrainer-0.1.0/README.md +0 -0
  3. nshtrainer-0.1.0/pyproject.toml +18 -0
  4. nshtrainer-0.1.0/src/nshtrainer/__init__.py +64 -0
  5. nshtrainer-0.1.0/src/nshtrainer/_experimental/__init__.py +2 -0
  6. nshtrainer-0.1.0/src/nshtrainer/_experimental/flops/__init__.py +48 -0
  7. nshtrainer-0.1.0/src/nshtrainer/_experimental/flops/flop_counter.py +787 -0
  8. nshtrainer-0.1.0/src/nshtrainer/_experimental/flops/module_tracker.py +140 -0
  9. nshtrainer-0.1.0/src/nshtrainer/_snoop.py +216 -0
  10. nshtrainer-0.1.0/src/nshtrainer/_submit/print_environment_info.py +31 -0
  11. nshtrainer-0.1.0/src/nshtrainer/_submit/session/_output.py +12 -0
  12. nshtrainer-0.1.0/src/nshtrainer/_submit/session/_script.py +109 -0
  13. nshtrainer-0.1.0/src/nshtrainer/_submit/session/lsf.py +467 -0
  14. nshtrainer-0.1.0/src/nshtrainer/_submit/session/slurm.py +573 -0
  15. nshtrainer-0.1.0/src/nshtrainer/_submit/session/unified.py +350 -0
  16. nshtrainer-0.1.0/src/nshtrainer/actsave/__init__.py +7 -0
  17. nshtrainer-0.1.0/src/nshtrainer/actsave/_callback.py +75 -0
  18. nshtrainer-0.1.0/src/nshtrainer/actsave/_loader.py +144 -0
  19. nshtrainer-0.1.0/src/nshtrainer/actsave/_saver.py +337 -0
  20. nshtrainer-0.1.0/src/nshtrainer/callbacks/__init__.py +35 -0
  21. nshtrainer-0.1.0/src/nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  22. nshtrainer-0.1.0/src/nshtrainer/callbacks/base.py +113 -0
  23. nshtrainer-0.1.0/src/nshtrainer/callbacks/early_stopping.py +112 -0
  24. nshtrainer-0.1.0/src/nshtrainer/callbacks/ema.py +383 -0
  25. nshtrainer-0.1.0/src/nshtrainer/callbacks/finite_checks.py +75 -0
  26. nshtrainer-0.1.0/src/nshtrainer/callbacks/gradient_skipping.py +103 -0
  27. nshtrainer-0.1.0/src/nshtrainer/callbacks/interval.py +322 -0
  28. nshtrainer-0.1.0/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  29. nshtrainer-0.1.0/src/nshtrainer/callbacks/log_epoch.py +35 -0
  30. nshtrainer-0.1.0/src/nshtrainer/callbacks/norm_logging.py +187 -0
  31. nshtrainer-0.1.0/src/nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  32. nshtrainer-0.1.0/src/nshtrainer/callbacks/print_table.py +90 -0
  33. nshtrainer-0.1.0/src/nshtrainer/callbacks/throughput_monitor.py +56 -0
  34. nshtrainer-0.1.0/src/nshtrainer/callbacks/timer.py +157 -0
  35. nshtrainer-0.1.0/src/nshtrainer/callbacks/wandb_watch.py +103 -0
  36. nshtrainer-0.1.0/src/nshtrainer/config.py +289 -0
  37. nshtrainer-0.1.0/src/nshtrainer/data/__init__.py +4 -0
  38. nshtrainer-0.1.0/src/nshtrainer/data/balanced_batch_sampler.py +132 -0
  39. nshtrainer-0.1.0/src/nshtrainer/data/transform.py +67 -0
  40. nshtrainer-0.1.0/src/nshtrainer/lr_scheduler/__init__.py +18 -0
  41. nshtrainer-0.1.0/src/nshtrainer/lr_scheduler/_base.py +101 -0
  42. nshtrainer-0.1.0/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  43. nshtrainer-0.1.0/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  44. nshtrainer-0.1.0/src/nshtrainer/model/__init__.py +44 -0
  45. nshtrainer-0.1.0/src/nshtrainer/model/base.py +641 -0
  46. nshtrainer-0.1.0/src/nshtrainer/model/config.py +2064 -0
  47. nshtrainer-0.1.0/src/nshtrainer/model/modules/callback.py +157 -0
  48. nshtrainer-0.1.0/src/nshtrainer/model/modules/debug.py +42 -0
  49. nshtrainer-0.1.0/src/nshtrainer/model/modules/distributed.py +70 -0
  50. nshtrainer-0.1.0/src/nshtrainer/model/modules/logger.py +170 -0
  51. nshtrainer-0.1.0/src/nshtrainer/model/modules/profiler.py +24 -0
  52. nshtrainer-0.1.0/src/nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  53. nshtrainer-0.1.0/src/nshtrainer/model/modules/shared_parameters.py +72 -0
  54. nshtrainer-0.1.0/src/nshtrainer/nn/__init__.py +19 -0
  55. nshtrainer-0.1.0/src/nshtrainer/nn/mlp.py +106 -0
  56. nshtrainer-0.1.0/src/nshtrainer/nn/module_dict.py +66 -0
  57. nshtrainer-0.1.0/src/nshtrainer/nn/module_list.py +50 -0
  58. nshtrainer-0.1.0/src/nshtrainer/nn/nonlinearity.py +157 -0
  59. nshtrainer-0.1.0/src/nshtrainer/optimizer.py +62 -0
  60. nshtrainer-0.1.0/src/nshtrainer/runner.py +21 -0
  61. nshtrainer-0.1.0/src/nshtrainer/scripts/check_env.py +41 -0
  62. nshtrainer-0.1.0/src/nshtrainer/scripts/find_packages.py +51 -0
  63. nshtrainer-0.1.0/src/nshtrainer/trainer/__init__.py +1 -0
  64. nshtrainer-0.1.0/src/nshtrainer/trainer/signal_connector.py +208 -0
  65. nshtrainer-0.1.0/src/nshtrainer/trainer/trainer.py +340 -0
  66. nshtrainer-0.1.0/src/nshtrainer/typecheck.py +144 -0
  67. nshtrainer-0.1.0/src/nshtrainer/util/environment.py +119 -0
  68. nshtrainer-0.1.0/src/nshtrainer/util/seed.py +11 -0
  69. nshtrainer-0.1.0/src/nshtrainer/util/singleton.py +89 -0
  70. nshtrainer-0.1.0/src/nshtrainer/util/slurm.py +49 -0
  71. nshtrainer-0.1.0/src/nshtrainer/util/typed.py +2 -0
  72. nshtrainer-0.1.0/src/nshtrainer/util/typing_utils.py +19 -0
nshtrainer-0.1.0/PKG-INFO
@@ -0,0 +1,18 @@
+ Metadata-Version: 2.1
+ Name: nshtrainer
+ Version: 0.1.0
+ Summary:
+ Author: Nima Shoghi
+ Author-email: nimashoghi@gmail.com
+ Requires-Python: >=3.10,<4.0
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Requires-Dist: nshconfig (>=0.2.0,<0.3.0)
+ Requires-Dist: nshrunner (>=0.1.0,<0.2.0)
+ Requires-Dist: torch
+ Requires-Dist: typing-extensions
+ Description-Content-Type: text/markdown
+
+
nshtrainer-0.1.0/README.md (file without changes)
nshtrainer-0.1.0/pyproject.toml
@@ -0,0 +1,18 @@
+ [tool.poetry]
+ name = "nshtrainer"
+ version = "0.1.0"
+ description = ""
+ authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
+ readme = "README.md"
+
+ [tool.poetry.dependencies]
+ python = "^3.10"
+ nshconfig = "^0.2.0"
+ nshrunner = "^0.1.0"
+ torch = "*"
+ typing-extensions = "*"
+
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
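Note that the caret constraints above are Poetry shorthand: at build time Poetry expands them into the PEP 440 ranges visible in the PKG-INFO hunk, so `^0.2.0` becomes `>=0.2.0,<0.3.0` and `^0.1.0` becomes `>=0.1.0,<0.2.0`. A minimal sketch of what such a range accepts, using the `packaging` library (not a dependency of this package; used here purely for illustration):

from packaging.specifiers import SpecifierSet

# Poetry's "^0.2.0" is exported to the sdist metadata as the equivalent
# PEP 440 range ">=0.2.0,<0.3.0": any release below the next breaking
# (minor, for 0.x versions) release is allowed.
nshconfig_range = SpecifierSet(">=0.2.0,<0.3.0")

print("0.2.5" in nshconfig_range)  # True: within the caret range
print("0.3.0" in nshconfig_range)  # False: next minor is excluded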
nshtrainer-0.1.0/src/nshtrainer/__init__.py
@@ -0,0 +1,64 @@
+ from . import _experimental as _experimental
+ from . import actsave as actsave
+ from . import callbacks as callbacks
+ from . import lr_scheduler as lr_scheduler
+ from . import model as model
+ from . import nn as nn
+ from . import optimizer as optimizer
+ from . import snapshot as snapshot
+ from . import typecheck as typecheck
+ from ._snoop import snoop as snoop
+ from .actsave import ActLoad as ActLoad
+ from .actsave import ActSave as ActSave
+ from .config import MISSING as MISSING
+ from .config import AllowMissing as AllowMissing
+ from .config import Field as Field
+ from .config import MissingField as MissingField
+ from .config import PrivateAttr as PrivateAttr
+ from .config import TypedConfig as TypedConfig
+ from .data import dataset_transform as dataset_transform
+ from .log import init_python_logging as init_python_logging
+ from .log import lovely as lovely
+ from .log import pretty as pretty
+ from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
+ from .model import ActSaveConfig as ActSaveConfig
+ from .model import Base as Base
+ from .model import BaseConfig as BaseConfig
+ from .model import BaseLoggerConfig as BaseLoggerConfig
+ from .model import BaseProfilerConfig as BaseProfilerConfig
+ from .model import CheckpointLoadingConfig as CheckpointLoadingConfig
+ from .model import CheckpointSavingConfig as CheckpointSavingConfig
+ from .model import ConfigList as ConfigList
+ from .model import DirectoryConfig as DirectoryConfig
+ from .model import (
+     EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
+ )
+ from .model import EnvironmentConfig as EnvironmentConfig
+ from .model import (
+     EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
+ )
+ from .model import (
+     EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
+ )
+ from .model import GradientClippingConfig as GradientClippingConfig
+ from .model import LightningDataModuleBase as LightningDataModuleBase
+ from .model import LightningModuleBase as LightningModuleBase
+ from .model import LoggingConfig as LoggingConfig
+ from .model import MetricConfig as MetricConfig
+ from .model import OptimizationConfig as OptimizationConfig
+ from .model import PrimaryMetricConfig as PrimaryMetricConfig
+ from .model import PythonLogging as PythonLogging
+ from .model import ReproducibilityConfig as ReproducibilityConfig
+ from .model import RunnerConfig as RunnerConfig
+ from .model import SanityCheckingConfig as SanityCheckingConfig
+ from .model import SeedConfig as SeedConfig
+ from .model import TrainerConfig as TrainerConfig
+ from .model import WandbWatchConfig as WandbWatchConfig
+ from .nn import TypedModuleDict as TypedModuleDict
+ from .nn import TypedModuleList as TypedModuleList
+ from .optimizer import OptimizerConfig as OptimizerConfig
+ from .runner import Runner as Runner
+ from .runner import SnapshotConfig as SnapshotConfig
+ from .trainer import Trainer as Trainer
+ from .util.singleton import Registry as Registry
+ from .util.singleton import Singleton as Singleton
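Every line here uses the `from .module import name as name` form: under PEP 484's explicit re-export convention, the redundant alias tells strict type checkers that the name is an intentional part of the public API rather than an incidental import. The practical effect is that downstream code can import everything from the package root. A short sketch, using only names re-exported above:

# Flat imports from the package root, enabled by the re-exports above:
from nshtrainer import ActSave, Trainer, TypedConfig

# Equivalent imports reaching into the submodules directly:
from nshtrainer.actsave import ActSave
from nshtrainer.config import TypedConfig
from nshtrainer.trainer import Trainer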
nshtrainer-0.1.0/src/nshtrainer/_experimental/__init__.py
@@ -0,0 +1,2 @@
+ from .flops import MEASURE_FLOPS_AVAILABLE as MEASURE_FLOPS_AVAILABLE
+ from .flops import measure_flops as measure_flops
nshtrainer-0.1.0/src/nshtrainer/_experimental/flops/__init__.py
@@ -0,0 +1,48 @@
+ from collections.abc import Callable
+
+ import torch
+ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
+
+ MEASURE_FLOPS_AVAILABLE = _TORCH_GREATER_EQUAL_2_1
+
+
+ def measure_flops(
+     forward_fn: Callable[[], torch.Tensor],
+     loss_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+     display: bool = True,
+ ) -> int:
+     """Utility to compute the total number of FLOPs used by a module during training or during inference.
+
+     It's recommended to create a meta-device model for this:
+
+     Example::
+
+         with torch.device("meta"):
+             model = MyModel()
+             x = torch.randn(2, 32)
+
+         model_fwd = lambda: model(x)
+         fwd_flops = measure_flops(model_fwd)
+
+         model_loss = lambda y: y.sum()
+         fwd_and_bwd_flops = measure_flops(model_fwd, model_loss)
+
+     Args:
+         forward_fn: A function that runs ``forward`` on the model and returns the result.
+         loss_fn: A function that computes the loss given the ``forward_fn`` output. If provided, the loss and ``backward``
+             FLOPs will be included in the result.
+         display: Whether to print the per-module FLOP table after counting.
+
+     """
+     if not MEASURE_FLOPS_AVAILABLE:
+         raise ImportError("`measure_flops` requires PyTorch >= 2.1.")
+
+     from .flop_counter import FlopCounterMode
+
+     flop_counter = FlopCounterMode(display=display)
+     with flop_counter:
+         if loss_fn is None:
+             forward_fn()
+         else:
+             loss_fn(forward_fn()).backward()
+     return flop_counter.get_total_flops()
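For reference, a minimal end-to-end sketch of the API above, following the docstring's example (the `nn.Linear` model and the tensor shapes are arbitrary placeholders). Meta-device modules and tensors carry only shape and dtype, so the counter can trace large models without allocating real memory or running real kernels:

import torch
import torch.nn as nn

from nshtrainer._experimental import MEASURE_FLOPS_AVAILABLE, measure_flops

if MEASURE_FLOPS_AVAILABLE:  # i.e. PyTorch >= 2.1
    # Build the model and a dummy batch on the meta device.
    with torch.device("meta"):
        model = nn.Linear(32, 4)
        x = torch.randn(2, 32)

    # Forward-only FLOPs.
    fwd_flops = measure_flops(lambda: model(x), display=False)

    # Forward + loss + backward FLOPs.
    fwd_bwd_flops = measure_flops(
        lambda: model(x), loss_fn=lambda y: y.sum(), display=False
    )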