nshtrainer 0.10.13__py3-none-any.whl → 0.10.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,13 +12,36 @@ from .base import CallbackConfigBase
12
12
  log = logging.getLogger(__name__)
13
13
 
14
14
 
15
+ class WandbWatchConfig(CallbackConfigBase):
16
+ name: Literal["wandb_watch"] = "wandb_watch"
17
+
18
+ enabled: bool = True
19
+ """Enable watching the model for wandb."""
20
+
21
+ log: str | None = None
22
+ """Log type for wandb."""
23
+
24
+ log_graph: bool = True
25
+ """Whether to log the graph for wandb."""
26
+
27
+ log_freq: int = 100
28
+ """Log frequency for wandb."""
29
+
30
+ def __bool__(self):
31
+ return self.enabled
32
+
33
+ @override
34
+ def create_callbacks(self, root_config):
35
+ yield WandbWatchCallback(self)
36
+
37
+
15
38
  @runtime_checkable
16
39
  class _HasWandbLogModuleProtocol(Protocol):
17
40
  def wandb_log_module(self) -> nn.Module | None: ...
18
41
 
19
42
 
20
43
  class WandbWatchCallback(Callback):
21
- def __init__(self, config: "WandbWatchConfig"):
44
+ def __init__(self, config: WandbWatchConfig):
22
45
  super().__init__()
23
46
 
24
47
  self.config = config
@@ -78,26 +101,3 @@ class WandbWatchCallback(Callback):
78
101
  log_graph=self.config.log_graph,
79
102
  )
80
103
  setattr(pl_module, "_model_watched", True)
81
-
82
-
83
- class WandbWatchConfig(CallbackConfigBase):
84
- name: Literal["wandb_watch"] = "wandb_watch"
85
-
86
- enabled: bool = True
87
- """Enable watching the model for wandb."""
88
-
89
- log: str | None = None
90
- """Log type for wandb."""
91
-
92
- log_graph: bool = True
93
- """Whether to log the graph for wandb."""
94
-
95
- log_freq: int = 100
96
- """Log frequency for wandb."""
97
-
98
- def __bool__(self):
99
- return self.enabled
100
-
101
- @override
102
- def create_callbacks(self, root_config):
103
- yield WandbWatchCallback(self)
@@ -20,6 +20,7 @@ from typing import (
20
20
 
21
21
  import nshconfig as C
22
22
  import numpy as np
23
+ import pkg_resources
23
24
  import torch
24
25
  from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
25
26
  from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
@@ -213,7 +214,7 @@ class BaseLoggerConfig(C.Config, ABC):
213
214
  """Enable this logger."""
214
215
 
215
216
  priority: int = 0
216
- """Priority of the logger. Higher values are logged first."""
217
+ """Priority of the logger. Higher priority loggers are created first."""
217
218
 
218
219
  log_dir: DirectoryPath | None = None
219
220
  """Directory to save the logs to. If None, will use the default log directory for the trainer."""
@@ -266,7 +267,8 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
266
267
  """Enable WandB logging."""
267
268
 
268
269
  priority: int = 2
269
- """Priority of the logger. Higher values are logged first."""
270
+ """Priority of the logger. Higher priority loggers are created first,
271
+ and the highest priority logger is the "main" logger for PyTorch Lightning."""
270
272
 
271
273
  project: str | None = None
272
274
  """WandB project name to use for the logger. If None, will use the root config's project name."""
@@ -286,8 +288,17 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
286
288
  offline: bool = False
287
289
  """Whether to run WandB in offline mode."""
288
290
 
289
- def offline_(self):
290
- self.offline = True
291
+ use_wandb_core: bool = False
292
+ """Whether to use the new `wandb-core` backend for WandB.
293
+ `wandb-core` is a new backend for WandB that is faster and more efficient than the old backend.
294
+ """
295
+
296
+ def offline_(self, value: bool = True):
297
+ self.offline = value
298
+ return self
299
+
300
+ def core_(self, value: bool = True):
301
+ self.use_wandb_core = value
291
302
  return self
292
303
 
293
304
  @override
@@ -295,6 +306,25 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
295
306
  if not self.enabled:
296
307
  return None
297
308
 
309
+ # If `wandb-core` is enabled, we should use the new backend.
310
+ if self.use_wandb_core:
311
+ try:
312
+ import wandb # type: ignore
313
+
314
+ # The minimum version that supports the new backend is 0.17.5
315
+ if pkg_resources.parse_version(
316
+ wandb.__version__
317
+ ) < pkg_resources.parse_version("0.17.5"):
318
+ log.warning(
319
+ "The version of WandB installed does not support the `wandb-core` backend. "
320
+ "Unable to use the `wandb-core` backend for WandB."
321
+ )
322
+ else:
323
+ wandb.require("core")
324
+ log.critical("Using the `wandb-core` backend for WandB.")
325
+ except ImportError:
326
+ pass
327
+
298
328
  from lightning.pytorch.loggers.wandb import WandbLogger
299
329
 
300
330
  save_dir = root_config.directory._resolve_log_directory_for_logger(
@@ -329,7 +359,7 @@ class CSVLoggerConfig(BaseLoggerConfig):
329
359
  """Enable CSV logging."""
330
360
 
331
361
  priority: int = 0
332
- """Priority of the logger. Higher values are logged first."""
362
+ """Priority of the logger. Higher priority loggers are created first."""
333
363
 
334
364
  prefix: str = ""
335
365
  """A string to put at the beginning of metric keys."""
@@ -383,7 +413,7 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
383
413
  """Enable TensorBoard logging."""
384
414
 
385
415
  priority: int = 2
386
- """Priority of the logger. Higher values are logged first."""
416
+ """Priority of the logger. Higher priority loggers are created first."""
387
417
 
388
418
  log_graph: bool = False
389
419
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.10.13
3
+ Version: 0.10.14
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -23,7 +23,7 @@ nshtrainer/callbacks/on_exception_checkpoint.py,sha256=x42BYZ2ejf2rhqPLCmT5nyWKh
23
23
  nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
24
24
  nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
25
25
  nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
26
- nshtrainer/callbacks/wandb_watch.py,sha256=EJ93mtJlph4BZsXh8HJPNiw2VNSm2N6TOwpCwqRAeKI,2923
26
+ nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
27
27
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
28
28
  nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
29
29
  nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
@@ -52,7 +52,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
52
52
  nshtrainer/metrics/_config.py,sha256=hWWS4IXENRyH3RmJ7z1Wx1n3Lt1sNMlGOrcU6PW15o0,1104
53
53
  nshtrainer/model/__init__.py,sha256=NpvyQHmGaHB8xdraHmm8l7kDHLmvJSgBNQKkfYqtgyI,1454
54
54
  nshtrainer/model/base.py,sha256=AXRfEsFAT0Ln7zjYVPU5NgtHS_c8FZM-M4pyLamO7OA,17516
55
- nshtrainer/model/config.py,sha256=65UDzt3ZZFUQaHMlK7f9wzwyGH3cDyHGtjZ2eOjHvVo,53360
55
+ nshtrainer/model/config.py,sha256=z6kSkTirvRsyW3YIDTG1uAmK4fCC-gNAQrMi7Osxiow,54643
56
56
  nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
57
57
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
58
58
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -79,6 +79,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
79
79
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
80
80
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
81
81
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
82
- nshtrainer-0.10.13.dist-info/METADATA,sha256=HpGl8_E6q2l2nQrIzU5ibNEyUXj8adF8cxMzouUSpAg,696
83
- nshtrainer-0.10.13.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
84
- nshtrainer-0.10.13.dist-info/RECORD,,
82
+ nshtrainer-0.10.14.dist-info/METADATA,sha256=7aheATImk1o69ugMxCBAbWhgWfRgRKYbP_QeFaqUGbM,696
83
+ nshtrainer-0.10.14.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
84
+ nshtrainer-0.10.14.dist-info/RECORD,,