nshtrainer 0.14.2__tar.gz → 0.15.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/PKG-INFO +2 -1
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/pyproject.toml +2 -1
- nshtrainer-0.15.0/src/nshtrainer/_experimental/__init__.py +1 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/trainer/signal_connector.py +3 -1
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/_environment_info.py +39 -14
- nshtrainer-0.14.2/src/nshtrainer/_experimental/__init__.py +0 -2
- nshtrainer-0.14.2/src/nshtrainer/_experimental/flops/__init__.py +0 -48
- nshtrainer-0.14.2/src/nshtrainer/_experimental/flops/flop_counter.py +0 -787
- nshtrainer-0.14.2/src/nshtrainer/_experimental/flops/module_tracker.py +0 -140
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/README.md +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/config.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.14.2 → nshtrainer-0.15.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.14.2
+Version: 0.15.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -16,6 +16,7 @@ Requires-Dist: nshconfig
 Requires-Dist: nshrunner
 Requires-Dist: nshutils
 Requires-Dist: numpy
+Requires-Dist: packaging
 Requires-Dist: psutil
 Requires-Dist: pytorch-lightning
 Requires-Dist: tensorboard ; extra == "extra"
{nshtrainer-0.14.2 → nshtrainer-0.15.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.14.2"
+version = "0.15.0"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -14,6 +14,7 @@ psutil = "*"
 numpy = "*"
 torch = "*"
 typing-extensions = "*"
+packaging = "*"
 lightning = "*"
 pytorch-lightning = "*"
 torchmetrics = { version = "*", optional = true }
nshtrainer-0.15.0/src/nshtrainer/_experimental/__init__.py (new file)

@@ -0,0 +1 @@
+from lightning.fabric.utilities.throughput import measure_flops as measure_flops
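The new `_experimental/__init__.py` replaces the vendored FLOPs module (removed below) with a re-export of Lightning's own utility. A minimal sketch of how the re-export can be used, assuming Lightning's `measure_flops(model, forward_fn, loss_fn=None)` signature; `MyModel` is a hypothetical stand-in for any `nn.Module`:

```python
import torch

from nshtrainer._experimental import measure_flops


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# Meta-device tensors carry shapes and dtypes but no storage, so FLOPs can
# be counted without allocating real weights.
with torch.device("meta"):
    model = MyModel()
    x = torch.randn(2, 32)

fwd_flops = measure_flops(model, lambda: model(x))
fwd_and_bwd_flops = measure_flops(model, lambda: model(x), lambda y: y.sum())
```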
{nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/trainer/signal_connector.py

@@ -1,5 +1,6 @@
 import logging
 import os
+import platform
 import re
 import signal
 import subprocess
@@ -25,6 +26,7 @@ log = logging.getLogger(__name__)
 
 _SIGNUM = int | signal.Signals
 _HANDLER: TypeAlias = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
+_IS_WINDOWS = platform.system() == "Windows"
 
 
 def _resolve_requeue_signals():
@@ -57,7 +59,7 @@ class _SignalConnector(_LightningSignalConnector):
         handlers: list[_HANDLER],
         replace_existing: bool = False,
     ):
-        if …
+        if _IS_WINDOWS:
             log.info(
                 f"Signal {signum.name} has no handlers or is not supported on Windows."
             )
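The connector change gates handler registration on a module-level platform check instead of attempting it on Windows. A standalone sketch of the same pattern, under the assumption that the surrounding elided code registers handlers via `signal.signal` (this is not the package's actual `_SignalConnector`):

```python
import logging
import platform
import signal
from collections.abc import Callable
from typing import Any

log = logging.getLogger(__name__)

_IS_WINDOWS = platform.system() == "Windows"


def register_handler(signum: signal.Signals, handler: Callable[..., Any]) -> None:
    # Windows lacks POSIX signals such as SIGUSR1, so registration is skipped
    # with a log message instead of raising.
    if _IS_WINDOWS:
        log.info(f"Signal {signum.name} has no handlers or is not supported on Windows.")
        return
    signal.signal(signum, handler)
```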
{nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/_environment_info.py

@@ -1,4 +1,5 @@
 import getpass
+import importlib.metadata
 import inspect
 import logging
 import os
@@ -12,6 +13,7 @@ from typing import TYPE_CHECKING, Any, cast
 import nshconfig as C
 import psutil
 import torch
+from packaging import version
 from typing_extensions import Self
 
 from .slurm import parse_slurm_node_list
@@ -398,23 +400,46 @@ class EnvironmentPackageConfig(C.Config):
 
     @classmethod
     def from_current_environment(cls):
-        # Add Python package information
         python_packages: dict[str, Self] = {}
         try:
-            … (twelve removed lines, truncated in the published diff)
+            for dist in importlib.metadata.distributions():
+                try:
+                    # Get package metadata
+                    metadata = dist.metadata
+
+                    # Parse the version, stripping any local version identifier
+                    pkg_version = version.parse(dist.version)
+                    clean_version = (
+                        f"{pkg_version.major}.{pkg_version.minor}.{pkg_version.micro}"
+                    )
+
+                    # Get requirements
+                    requires = []
+                    for req in dist.requires or []:
+                        try:
+                            requires.append(str(req))
+                        except ValueError:
+                            # If there's an invalid requirement, we'll skip it
+                            log.warning(
+                                f"Skipping invalid requirement for {dist.name}: {req}"
+                            )
+
+                    python_packages[dist.name] = cls(
+                        name=dist.name,
+                        version=clean_version,
+                        path=Path(str(f)) if (f := dist.locate_file("")) else None,
+                        summary=metadata["Summary"] if "Summary" in metadata else None,
+                        author=metadata["Author"] if "Summary" in metadata else None,
+                        license=metadata["License"] if "Summary" in metadata else None,
+                        requires=requires,
+                    )
+                except Exception as e:
+                    log.warning(f"Error processing package {dist.name}: {str(e)}")
+
         except ImportError:
-            log.warning(…)
+            log.warning(
+                "importlib.metadata not available, skipping package information"
+            )
 
         return python_packages
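The rewritten `from_current_environment` enumerates every installed distribution with `importlib.metadata` and normalizes each version to `major.minor.micro` via the newly added `packaging` dependency, which explains the `PKG-INFO` and `pyproject.toml` additions above. (Note that the published hunk gates `author` and `license` on `"Summary" in metadata`, which looks like a copy-paste slip from the `summary` line.) A minimal, self-contained sketch of the enumeration approach; `installed_package_versions` is an illustrative helper, not part of the package:

```python
import importlib.metadata

from packaging import version


def installed_package_versions() -> dict[str, str]:
    """Map each installed distribution to a normalized major.minor.micro version."""
    packages: dict[str, str] = {}
    for dist in importlib.metadata.distributions():
        try:
            # Re-rendering as major.minor.micro drops local identifiers
            # such as "+cu121" from versions like "2.1.0+cu121".
            v = version.parse(dist.version)
            packages[dist.name] = f"{v.major}.{v.minor}.{v.micro}"
        except Exception as exc:
            # Mirror the diff's behavior: skip malformed metadata and continue.
            print(f"Error processing package {dist.name}: {exc}")
    return packages


if __name__ == "__main__":
    for name, ver in sorted(installed_package_versions().items()):
        print(f"{name}=={ver}")
```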
nshtrainer-0.14.2/src/nshtrainer/_experimental/flops/__init__.py (removed)

@@ -1,48 +0,0 @@
-from collections.abc import Callable
-
-import torch
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
-
-MEASURE_FLOPS_AVAILABLE = _TORCH_GREATER_EQUAL_2_1
-
-
-def measure_flops(
-    forward_fn: Callable[[], torch.Tensor],
-    loss_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
-    display: bool = True,
-) -> int:
-    """Utility to compute the total number of FLOPs used by a module during training or during inference.
-
-    It's recommended to create a meta-device model for this:
-
-    Example::
-
-        with torch.device("meta"):
-            model = MyModel()
-            x = torch.randn(2, 32)
-
-        model_fwd = lambda: model(x)
-        fwd_flops = measure_flops(model, model_fwd)
-
-        model_loss = lambda y: y.sum()
-        fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)
-
-    Args:
-        model: The model whose FLOPs should be measured.
-        forward_fn: A function that runs ``forward`` on the model and returns the result.
-        loss_fn: A function that computes the loss given the ``forward_fn`` output. If provided, the loss and `backward`
-            FLOPs will be included in the result.
-
-    """
-    if not MEASURE_FLOPS_AVAILABLE:
-        raise ImportError("`measure_flops` requires PyTorch >= 2.1.")
-
-    from .flop_counter import FlopCounterMode
-
-    flop_counter = FlopCounterMode(display=display)
-    with flop_counter:
-        if loss_fn is None:
-            forward_fn()
-        else:
-            loss_fn(forward_fn()).backward()
-    return flop_counter.get_total_flops()