nshtrainer 0.10.13__tar.gz → 0.10.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/PKG-INFO +1 -1
  2. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/pyproject.toml +1 -1
  3. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/wandb_watch.py +24 -24
  4. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/model/config.py +36 -6
  5. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/README.md +0 -0
  6. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/__init__.py +0 -0
  7. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/_checkpoint/loader.py +0 -0
  8. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  9. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/_checkpoint/saver.py +0 -0
  10. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/_experimental/__init__.py +0 -0
  11. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  12. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  13. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  14. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/__init__.py +0 -0
  15. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  16. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/actsave.py +0 -0
  17. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/base.py +0 -0
  18. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  19. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/ema.py +0 -0
  20. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  21. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  22. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/interval.py +0 -0
  23. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +0 -0
  24. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  25. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/model_checkpoint.py +0 -0
  26. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  27. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -0
  28. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/print_table.py +0 -0
  29. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  30. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/timer.py +0 -0
  31. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/data/__init__.py +0 -0
  32. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  33. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/data/transform.py +0 -0
  34. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/__init__.py +0 -0
  35. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/_experimental.py +0 -0
  36. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/actsave.py +0 -0
  37. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/callbacks.py +0 -0
  38. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/config.py +0 -0
  39. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/data.py +0 -0
  40. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/log.py +0 -0
  41. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  42. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/model.py +0 -0
  43. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/nn.py +0 -0
  44. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/optimizer.py +0 -0
  45. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/runner.py +0 -0
  46. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/snapshot.py +0 -0
  47. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/snoop.py +0 -0
  48. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/trainer.py +0 -0
  49. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/typecheck.py +0 -0
  50. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/ll/util.py +0 -0
  51. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  52. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  53. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  54. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  55. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/metrics/__init__.py +0 -0
  56. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/metrics/_config.py +0 -0
  57. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/model/__init__.py +0 -0
  58. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/model/base.py +0 -0
  59. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/callback.py +0 -0
  60. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/debug.py +0 -0
  61. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/distributed.py +0 -0
  62. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/logger.py +0 -0
  63. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/profiler.py +0 -0
  64. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  65. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  66. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/nn/__init__.py +0 -0
  67. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/nn/mlp.py +0 -0
  68. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/nn/module_dict.py +0 -0
  69. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/nn/module_list.py +0 -0
  70. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/nn/nonlinearity.py +0 -0
  71. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/optimizer.py +0 -0
  72. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/runner.py +0 -0
  73. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/scripts/find_packages.py +0 -0
  74. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/trainer/__init__.py +0 -0
  75. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  76. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  77. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/trainer/signal_connector.py +0 -0
  78. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/trainer/trainer.py +0 -0
  79. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/util/_environment_info.py +0 -0
  80. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/util/environment.py +0 -0
  81. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/util/seed.py +0 -0
  82. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/util/slurm.py +0 -0
  83. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/util/typed.py +0 -0
  84. {nshtrainer-0.10.13 → nshtrainer-0.10.14}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.10.13
3
+ Version: 0.10.14
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.10.13"
3
+ version = "0.10.14"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -12,13 +12,36 @@ from .base import CallbackConfigBase
12
12
  log = logging.getLogger(__name__)
13
13
 
14
14
 
15
+ class WandbWatchConfig(CallbackConfigBase):
16
+ name: Literal["wandb_watch"] = "wandb_watch"
17
+
18
+ enabled: bool = True
19
+ """Enable watching the model for wandb."""
20
+
21
+ log: str | None = None
22
+ """Log type for wandb."""
23
+
24
+ log_graph: bool = True
25
+ """Whether to log the graph for wandb."""
26
+
27
+ log_freq: int = 100
28
+ """Log frequency for wandb."""
29
+
30
+ def __bool__(self):
31
+ return self.enabled
32
+
33
+ @override
34
+ def create_callbacks(self, root_config):
35
+ yield WandbWatchCallback(self)
36
+
37
+
15
38
  @runtime_checkable
16
39
  class _HasWandbLogModuleProtocol(Protocol):
17
40
  def wandb_log_module(self) -> nn.Module | None: ...
18
41
 
19
42
 
20
43
  class WandbWatchCallback(Callback):
21
- def __init__(self, config: "WandbWatchConfig"):
44
+ def __init__(self, config: WandbWatchConfig):
22
45
  super().__init__()
23
46
 
24
47
  self.config = config
@@ -78,26 +101,3 @@ class WandbWatchCallback(Callback):
78
101
  log_graph=self.config.log_graph,
79
102
  )
80
103
  setattr(pl_module, "_model_watched", True)
81
-
82
-
83
- class WandbWatchConfig(CallbackConfigBase):
84
- name: Literal["wandb_watch"] = "wandb_watch"
85
-
86
- enabled: bool = True
87
- """Enable watching the model for wandb."""
88
-
89
- log: str | None = None
90
- """Log type for wandb."""
91
-
92
- log_graph: bool = True
93
- """Whether to log the graph for wandb."""
94
-
95
- log_freq: int = 100
96
- """Log frequency for wandb."""
97
-
98
- def __bool__(self):
99
- return self.enabled
100
-
101
- @override
102
- def create_callbacks(self, root_config):
103
- yield WandbWatchCallback(self)
@@ -20,6 +20,7 @@ from typing import (
20
20
 
21
21
  import nshconfig as C
22
22
  import numpy as np
23
+ import pkg_resources
23
24
  import torch
24
25
  from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
25
26
  from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
@@ -213,7 +214,7 @@ class BaseLoggerConfig(C.Config, ABC):
213
214
  """Enable this logger."""
214
215
 
215
216
  priority: int = 0
216
- """Priority of the logger. Higher values are logged first."""
217
+ """Priority of the logger. Higher priority loggers are created first."""
217
218
 
218
219
  log_dir: DirectoryPath | None = None
219
220
  """Directory to save the logs to. If None, will use the default log directory for the trainer."""
@@ -266,7 +267,8 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
266
267
  """Enable WandB logging."""
267
268
 
268
269
  priority: int = 2
269
- """Priority of the logger. Higher values are logged first."""
270
+ """Priority of the logger. Higher priority loggers are created first,
271
+ and the highest priority logger is the "main" logger for PyTorch Lightning."""
270
272
 
271
273
  project: str | None = None
272
274
  """WandB project name to use for the logger. If None, will use the root config's project name."""
@@ -286,8 +288,17 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
286
288
  offline: bool = False
287
289
  """Whether to run WandB in offline mode."""
288
290
 
289
- def offline_(self):
290
- self.offline = True
291
+ use_wandb_core: bool = False
292
+ """Whether to use the new `wandb-core` backend for WandB.
293
+ `wandb-core` is a new backend for WandB that is faster and more efficient than the old backend.
294
+ """
295
+
296
+ def offline_(self, value: bool = True):
297
+ self.offline = value
298
+ return self
299
+
300
+ def core_(self, value: bool = True):
301
+ self.use_wandb_core = value
291
302
  return self
292
303
 
293
304
  @override
@@ -295,6 +306,25 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
295
306
  if not self.enabled:
296
307
  return None
297
308
 
309
+ # If `wandb-core` is enabled, we should use the new backend.
310
+ if self.use_wandb_core:
311
+ try:
312
+ import wandb # type: ignore
313
+
314
+ # The minimum version that supports the new backend is 0.17.5
315
+ if pkg_resources.parse_version(
316
+ wandb.__version__
317
+ ) < pkg_resources.parse_version("0.17.5"):
318
+ log.warning(
319
+ "The version of WandB installed does not support the `wandb-core` backend. "
320
+ "Unable to use the `wandb-core` backend for WandB."
321
+ )
322
+ else:
323
+ wandb.require("core")
324
+ log.critical("Using the `wandb-core` backend for WandB.")
325
+ except ImportError:
326
+ pass
327
+
298
328
  from lightning.pytorch.loggers.wandb import WandbLogger
299
329
 
300
330
  save_dir = root_config.directory._resolve_log_directory_for_logger(
@@ -329,7 +359,7 @@ class CSVLoggerConfig(BaseLoggerConfig):
329
359
  """Enable CSV logging."""
330
360
 
331
361
  priority: int = 0
332
- """Priority of the logger. Higher values are logged first."""
362
+ """Priority of the logger. Higher priority loggers are created first."""
333
363
 
334
364
  prefix: str = ""
335
365
  """A string to put at the beginning of metric keys."""
@@ -383,7 +413,7 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
383
413
  """Enable TensorBoard logging."""
384
414
 
385
415
  priority: int = 2
386
- """Priority of the logger. Higher values are logged first."""
416
+ """Priority of the logger. Higher priority loggers are created first."""
387
417
 
388
418
  log_graph: bool = False
389
419
  """
File without changes