nshtrainer 0.14.2__tar.gz → 0.15.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/PKG-INFO +2 -1
  2. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/pyproject.toml +2 -1
  3. nshtrainer-0.15.0/src/nshtrainer/_experimental/__init__.py +1 -0
  4. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/trainer/signal_connector.py +3 -1
  5. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/_environment_info.py +39 -14
  6. nshtrainer-0.14.2/src/nshtrainer/_experimental/__init__.py +0 -2
  7. nshtrainer-0.14.2/src/nshtrainer/_experimental/flops/__init__.py +0 -48
  8. nshtrainer-0.14.2/src/nshtrainer/_experimental/flops/flop_counter.py +0 -787
  9. nshtrainer-0.14.2/src/nshtrainer/_experimental/flops/module_tracker.py +0 -140
  10. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/README.md +0 -0
  11. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/__init__.py +0 -0
  12. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
  13. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  14. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  15. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/__init__.py +0 -0
  16. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  17. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  18. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/base.py +0 -0
  19. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  20. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  21. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  22. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  23. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  24. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  25. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/ema.py +0 -0
  26. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  27. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  28. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/interval.py +0 -0
  29. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  30. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  31. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  32. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  33. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/timer.py +0 -0
  34. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  35. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/data/__init__.py +0 -0
  36. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  37. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/data/transform.py +0 -0
  38. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/__init__.py +0 -0
  39. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/_experimental.py +0 -0
  40. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/actsave.py +0 -0
  41. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/callbacks.py +0 -0
  42. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/config.py +0 -0
  43. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/data.py +0 -0
  44. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/log.py +0 -0
  45. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  46. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/model.py +0 -0
  47. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/nn.py +0 -0
  48. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/optimizer.py +0 -0
  49. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/runner.py +0 -0
  50. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/snapshot.py +0 -0
  51. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/snoop.py +0 -0
  52. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/trainer.py +0 -0
  53. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/typecheck.py +0 -0
  54. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/ll/util.py +0 -0
  55. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/loggers/__init__.py +0 -0
  56. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/loggers/_base.py +0 -0
  57. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/loggers/csv.py +0 -0
  58. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  59. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/loggers/wandb.py +0 -0
  60. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  61. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  62. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  63. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  64. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/metrics/__init__.py +0 -0
  65. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/metrics/_config.py +0 -0
  66. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/__init__.py +0 -0
  67. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/base.py +0 -0
  68. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/config.py +0 -0
  69. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/callback.py +0 -0
  70. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/debug.py +0 -0
  71. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/distributed.py +0 -0
  72. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/logger.py +0 -0
  73. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/profiler.py +0 -0
  74. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  75. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  76. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/nn/__init__.py +0 -0
  77. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/nn/mlp.py +0 -0
  78. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/nn/module_dict.py +0 -0
  79. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/nn/module_list.py +0 -0
  80. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  81. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/optimizer.py +0 -0
  82. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/runner.py +0 -0
  83. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/scripts/find_packages.py +0 -0
  84. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/trainer/__init__.py +0 -0
  85. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  86. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  87. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/trainer/trainer.py +0 -0
  88. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/_useful_types.py +0 -0
  89. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/environment.py +0 -0
  90. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/seed.py +0 -0
  91. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/slurm.py +0 -0
  92. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/typed.py +0 -0
  93. {nshtrainer-0.14.2 → nshtrainer-0.15.0}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.14.2
3
+ Version: 0.15.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -16,6 +16,7 @@ Requires-Dist: nshconfig
16
16
  Requires-Dist: nshrunner
17
17
  Requires-Dist: nshutils
18
18
  Requires-Dist: numpy
19
+ Requires-Dist: packaging
19
20
  Requires-Dist: psutil
20
21
  Requires-Dist: pytorch-lightning
21
22
  Requires-Dist: tensorboard ; extra == "extra"
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.14.2"
3
+ version = "0.15.0"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -14,6 +14,7 @@ psutil = "*"
14
14
  numpy = "*"
15
15
  torch = "*"
16
16
  typing-extensions = "*"
17
+ packaging = "*"
17
18
  lightning = "*"
18
19
  pytorch-lightning = "*"
19
20
  torchmetrics = { version = "*", optional = true }
@@ -0,0 +1 @@
1
+ from lightning.fabric.utilities.throughput import measure_flops as measure_flops
@@ -1,5 +1,6 @@
1
1
  import logging
2
2
  import os
3
+ import platform
3
4
  import re
4
5
  import signal
5
6
  import subprocess
@@ -25,6 +26,7 @@ log = logging.getLogger(__name__)
25
26
 
26
27
  _SIGNUM = int | signal.Signals
27
28
  _HANDLER: TypeAlias = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
29
+ _IS_WINDOWS = platform.system() == "Windows"
28
30
 
29
31
 
30
32
  def _resolve_requeue_signals():
@@ -57,7 +59,7 @@ class _SignalConnector(_LightningSignalConnector):
57
59
  handlers: list[_HANDLER],
58
60
  replace_existing: bool = False,
59
61
  ):
60
- if self._is_on_windows():
62
+ if _IS_WINDOWS:
61
63
  log.info(
62
64
  f"Signal {signum.name} has no handlers or is not supported on Windows."
63
65
  )
@@ -1,4 +1,5 @@
1
1
  import getpass
2
+ import importlib.metadata
2
3
  import inspect
3
4
  import logging
4
5
  import os
@@ -12,6 +13,7 @@ from typing import TYPE_CHECKING, Any, cast
12
13
  import nshconfig as C
13
14
  import psutil
14
15
  import torch
16
+ from packaging import version
15
17
  from typing_extensions import Self
16
18
 
17
19
  from .slurm import parse_slurm_node_list
@@ -398,23 +400,46 @@ class EnvironmentPackageConfig(C.Config):
398
400
 
399
401
  @classmethod
400
402
  def from_current_environment(cls):
401
- # Add Python package information
402
403
  python_packages: dict[str, Self] = {}
403
404
  try:
404
- import pkg_resources
405
-
406
- for package in pkg_resources.working_set:
407
- python_packages[package.key] = cls(
408
- name=package.project_name,
409
- version=package.version,
410
- path=Path(package.location) if package.location else None,
411
- summary=getattr(package, "summary", None),
412
- author=getattr(package, "author", None),
413
- license=getattr(package, "license", None),
414
- requires=[str(req) for req in package.requires()],
415
- )
405
+ for dist in importlib.metadata.distributions():
406
+ try:
407
+ # Get package metadata
408
+ metadata = dist.metadata
409
+
410
+ # Parse the version, stripping any local version identifier
411
+ pkg_version = version.parse(dist.version)
412
+ clean_version = (
413
+ f"{pkg_version.major}.{pkg_version.minor}.{pkg_version.micro}"
414
+ )
415
+
416
+ # Get requirements
417
+ requires = []
418
+ for req in dist.requires or []:
419
+ try:
420
+ requires.append(str(req))
421
+ except ValueError:
422
+ # If there's an invalid requirement, we'll skip it
423
+ log.warning(
424
+ f"Skipping invalid requirement for {dist.name}: {req}"
425
+ )
426
+
427
+ python_packages[dist.name] = cls(
428
+ name=dist.name,
429
+ version=clean_version,
430
+ path=Path(str(f)) if (f := dist.locate_file("")) else None,
431
+ summary=metadata["Summary"] if "Summary" in metadata else None,
432
+ author=metadata["Author"] if "Author" in metadata else None,
433
+ license=metadata["License"] if "License" in metadata else None,
434
+ requires=requires,
435
+ )
436
+ except Exception as e:
437
+ log.warning(f"Error processing package {dist.name}: {str(e)}")
438
+
416
439
  except ImportError:
417
- log.warning("pkg_resources not available, skipping package information")
440
+ log.warning(
441
+ "importlib.metadata not available, skipping package information"
442
+ )
418
443
 
419
444
  return python_packages
420
445
 
@@ -1,2 +0,0 @@
1
- from .flops import MEASURE_FLOPS_AVAILABLE as MEASURE_FLOPS_AVAILABLE
2
- from .flops import measure_flops as measure_flops
@@ -1,48 +0,0 @@
1
- from collections.abc import Callable
2
-
3
- import torch
4
- from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
5
-
6
- MEASURE_FLOPS_AVAILABLE = _TORCH_GREATER_EQUAL_2_1
7
-
8
-
9
- def measure_flops(
10
- forward_fn: Callable[[], torch.Tensor],
11
- loss_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
12
- display: bool = True,
13
- ) -> int:
14
- """Utility to compute the total number of FLOPs used by a module during training or during inference.
15
-
16
- It's recommended to create a meta-device model for this:
17
-
18
- Example::
19
-
20
- with torch.device("meta"):
21
- model = MyModel()
22
- x = torch.randn(2, 32)
23
-
24
- model_fwd = lambda: model(x)
25
- fwd_flops = measure_flops(model, model_fwd)
26
-
27
- model_loss = lambda y: y.sum()
28
- fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)
29
-
30
- Args:
31
- model: The model whose FLOPs should be measured.
32
- forward_fn: A function that runs ``forward`` on the model and returns the result.
33
- loss_fn: A function that computes the loss given the ``forward_fn`` output. If provided, the loss and `backward`
34
- FLOPs will be included in the result.
35
-
36
- """
37
- if not MEASURE_FLOPS_AVAILABLE:
38
- raise ImportError("`measure_flops` requires PyTorch >= 2.1.")
39
-
40
- from .flop_counter import FlopCounterMode
41
-
42
- flop_counter = FlopCounterMode(display=display)
43
- with flop_counter:
44
- if loss_fn is None:
45
- forward_fn()
46
- else:
47
- loss_fn(forward_fn()).backward()
48
- return flop_counter.get_total_flops()