nshtrainer 1.4.0__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
--- a/nshtrainer/callbacks/log_epoch.py
+++ b/nshtrainer/callbacks/log_epoch.py
@@ -18,32 +18,119 @@ log = logging.getLogger(__name__)
 class LogEpochCallbackConfig(CallbackConfigBase):
     name: Literal["log_epoch"] = "log_epoch"
 
+    metric_name: str = "computed_epoch"
+    """The name of the metric to log the epoch as."""
+
+    train: bool = True
+    """Whether to log the epoch during training."""
+
+    val: bool = True
+    """Whether to log the epoch during validation."""
+
+    test: bool = True
+    """Whether to log the epoch during testing."""
+
     @override
     def create_callbacks(self, trainer_config):
-        yield LogEpochCallback()
+        yield LogEpochCallback(self)
+
+
+def _worker_fn(
+    trainer: Trainer,
+    pl_module: LightningModule,
+    num_batches_prop: str,
+    dataloader_idx: int | None = None,
+    *,
+    metric_name: str,
+):
+    if trainer.logger is None:
+        return
+
+    # If trainer.num_{training/val/test}_batches is not set or is nan/inf, we cannot calculate the epoch
+    if not (num_batches := getattr(trainer, num_batches_prop, None)):
+        log.warning(f"Trainer has no valid `{num_batches_prop}`. Cannot log epoch.")
+        return
+
+    # If the trainer has a dataloader_idx, num_batches is a list of num_batches for each dataloader.
+    if dataloader_idx is not None:
+        assert isinstance(num_batches, list), (
+            f"Expected num_batches to be a list, got {type(num_batches)}"
+        )
+        assert 0 <= dataloader_idx < len(num_batches), (
+            f"Expected dataloader_idx to be between 0 and {len(num_batches)}, got {dataloader_idx}"
+        )
+        num_batches = num_batches[dataloader_idx]
+
+    if (
+        not isinstance(num_batches, (int, float))
+        or math.isnan(num_batches)
+        or math.isinf(num_batches)
+    ):
+        log.warning(
+            f"Trainer has no valid `{num_batches_prop}` (got {num_batches=}). Cannot log epoch."
+        )
+        return
+
+    epoch = pl_module.global_step / num_batches
+    pl_module.log(metric_name, epoch, on_step=True, on_epoch=False)
 
 
 class LogEpochCallback(Callback):
-    def __init__(self, metric_name: str = "computed_epoch"):
+    def __init__(self, config: LogEpochCallbackConfig):
         super().__init__()
 
-        self.metric_name = metric_name
+        self.config = config
 
     @override
     def on_train_batch_start(
         self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
     ):
-        if trainer.logger is None:
+        if trainer.logger is None or not self.config.train:
+            return
+
+        _worker_fn(
+            trainer,
+            pl_module,
+            "num_training_batches",
+            metric_name=self.config.metric_name,
+        )
+
+    @override
+    def on_validation_batch_start(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        if trainer.logger is None or not self.config.val:
             return
 
-        # If trainer.num_training_batches is not set or is nan/inf, we cannot calculate the epoch
-        if (
-            not trainer.num_training_batches
-            or math.isnan(trainer.num_training_batches)
-            or math.isinf(trainer.num_training_batches)
-        ):
-            log.warning("Trainer has no valid num_training_batches. Cannot log epoch.")
+        _worker_fn(
+            trainer,
+            pl_module,
+            "num_val_batches",
+            dataloader_idx=dataloader_idx,
+            metric_name=self.config.metric_name,
+        )
+
+    @override
+    def on_test_batch_start(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        if trainer.logger is None or not self.config.test:
             return
 
-        epoch = pl_module.global_step / trainer.num_training_batches
-        pl_module.log(self.metric_name, epoch, on_step=True, on_epoch=False)
+        _worker_fn(
+            trainer,
+            pl_module,
+            "num_test_batches",
+            dataloader_idx=dataloader_idx,
+            metric_name=self.config.metric_name,
+        )
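
In short, where 1.4.0 only logged a computed epoch during training, 1.4.1 factors the logic into `_worker_fn`, computes a fractional epoch as `pl_module.global_step / num_batches`, and can log it during validation and testing as well, each stage gated by its config flag. A minimal usage sketch follows; the config fields are taken from the diff above, but the `trainer_config` value and how the callbacks are attached to a trainer are assumptions:

# Hypothetical sketch (not part of the diff): per-stage epoch logging in 1.4.1.
config = LogEpochCallbackConfig(
    metric_name="computed_epoch",  # metric key for the fractional epoch
    train=True,   # log global_step / num_training_batches on train batches
    val=True,     # log on validation batches (new in 1.4.1)
    test=False,   # skip test batches
)
callbacks = list(config.create_callbacks(trainer_config))  # -> [LogEpochCallback(config)]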
--- a/nshtrainer/trainer/_config.py
+++ b/nshtrainer/trainer/_config.py
@@ -419,7 +419,7 @@ class DirectoryConfig(C.Config):
 
 class TrainerConfig(C.Config):
     # region Active Run Configuration
-    id: Annotated[str, C.AllowMissing()] = C.MISSING
+    id: C.AllowMissing[str] = C.MISSING
     """ID of the run."""
     name: list[str] = []
    """Run name in parts. Full name is constructed by joining the parts with spaces."""
--- nshtrainer-1.4.0.dist-info/METADATA
+++ nshtrainer-1.4.1.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.4.0
+Version: 1.4.1
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -14,7 +14,7 @@ Provides-Extra: extra
 Requires-Dist: GitPython ; extra == "extra"
 Requires-Dist: huggingface-hub ; extra == "extra"
 Requires-Dist: lightning
-Requires-Dist: nshconfig (>0.39)
+Requires-Dist: nshconfig (>=0.43)
 Requires-Dist: nshrunner ; extra == "extra"
 Requires-Dist: nshutils ; extra == "extra"
 Requires-Dist: numpy
--- nshtrainer-1.4.0.dist-info/RECORD
+++ nshtrainer-1.4.1.dist-info/RECORD
@@ -21,7 +21,7 @@ nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,1
 nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH7SgWst3o,2185
 nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB0QG3JdHkQ8ek,3523
 nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
-nshtrainer/callbacks/log_epoch.py,sha256=B5Dm8XVZwCzKUhUWfT_5PDdDac993191OsbcxxuSVJE,1457
+nshtrainer/callbacks/log_epoch.py,sha256=C2yUww8lAuCX-dy06tsw95yCBOfFd2mfGs0VhrEq1oU,3775
 nshtrainer/callbacks/lr_monitor.py,sha256=v45ehnwNO987087HfiOY5aIrVRbwdKMgPYRFHs1fyEE,1444
 nshtrainer/callbacks/metric_validation.py,sha256=4RDr1FuNKfro-6QEtmcFqT4iNf2twmJVNk9y-8nq9bg,2882
 nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
@@ -133,7 +133,7 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
 nshtrainer/trainer/__init__.py,sha256=jRaHdaFK8wxNrN1bleT9cf29iZahL_-XkWo5TWz2CmA,550
-nshtrainer/trainer/_config.py,sha256=FWEspBYt_bjLhUSkJApkC9pfYBTlFBHmIQRFNGpGjAc,45849
+nshtrainer/trainer/_config.py,sha256=GL8DtuH-6x2aHcRlEcmzyhEBMRRldiSazNAeNmPw7gM,45836
 nshtrainer/trainer/_distributed_prediction_result.py,sha256=bQw8Z6PT694UUf-zQPkech6CxyUSy8bAIexfSfPej0U,2507
 nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
@@ -159,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.4.0.dist-info/METADATA,sha256=PIV_5Swp1HhgFU2ZBj_X1tCeOBfNhrhTXOFB1vgunno,979
-nshtrainer-1.4.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-nshtrainer-1.4.0.dist-info/RECORD,,
+nshtrainer-1.4.1.dist-info/METADATA,sha256=QL69Trcmw3NF3UOovpqVJbzBTtHJtnDDxAzxyj9EX24,980
+nshtrainer-1.4.1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+nshtrainer-1.4.1.dist-info/RECORD,,