nshtrainer 1.5.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,7 +35,7 @@ class LogEpochCallbackConfig(CallbackConfigBase):
35
35
  yield LogEpochCallback(self)
36
36
 
37
37
 
38
- def _worker_fn(
38
+ def _log_on_step(
39
39
  trainer: Trainer,
40
40
  pl_module: LightningModule,
41
41
  num_batches_prop: str,
@@ -75,6 +75,19 @@ def _worker_fn(
75
75
  pl_module.log(metric_name, epoch, on_step=True, on_epoch=False)
76
76
 
77
77
 
78
+ def _log_on_epoch(
79
+ trainer: Trainer,
80
+ pl_module: LightningModule,
81
+ *,
82
+ metric_name: str,
83
+ ):
84
+ if trainer.logger is None:
85
+ return
86
+
87
+ epoch = pl_module.current_epoch + 1
88
+ pl_module.log(metric_name, epoch, on_step=False, on_epoch=True)
89
+
90
+
78
91
  class LogEpochCallback(Callback):
79
92
  def __init__(self, config: LogEpochCallbackConfig):
80
93
  super().__init__()
@@ -85,16 +98,27 @@ class LogEpochCallback(Callback):
85
98
  def on_train_batch_start(
86
99
  self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
87
100
  ):
88
- if trainer.logger is None or not self.config.train:
101
+ if not self.config.train:
89
102
  return
90
103
 
91
- _worker_fn(
104
+ _log_on_step(
92
105
  trainer,
93
106
  pl_module,
94
107
  "num_training_batches",
95
108
  metric_name=self.config.metric_name,
96
109
  )
97
110
 
111
+ @override
112
+ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
113
+ if not self.config.train:
114
+ return
115
+
116
+ _log_on_epoch(
117
+ trainer,
118
+ pl_module,
119
+ metric_name=self.config.metric_name,
120
+ )
121
+
98
122
  @override
99
123
  def on_validation_batch_start(
100
124
  self,
@@ -104,10 +128,10 @@ class LogEpochCallback(Callback):
104
128
  batch_idx: int,
105
129
  dataloader_idx: int = 0,
106
130
  ) -> None:
107
- if trainer.logger is None or not self.config.val:
131
+ if not self.config.val:
108
132
  return
109
133
 
110
- _worker_fn(
134
+ _log_on_step(
111
135
  trainer,
112
136
  pl_module,
113
137
  "num_val_batches",
@@ -115,6 +139,19 @@ class LogEpochCallback(Callback):
115
139
  metric_name=self.config.metric_name,
116
140
  )
117
141
 
142
+ @override
143
+ def on_validation_epoch_end(
144
+ self, trainer: Trainer, pl_module: LightningModule
145
+ ) -> None:
146
+ if not self.config.val:
147
+ return
148
+
149
+ _log_on_epoch(
150
+ trainer,
151
+ pl_module,
152
+ metric_name=self.config.metric_name,
153
+ )
154
+
118
155
  @override
119
156
  def on_test_batch_start(
120
157
  self,
@@ -124,13 +161,24 @@ class LogEpochCallback(Callback):
124
161
  batch_idx: int,
125
162
  dataloader_idx: int = 0,
126
163
  ) -> None:
127
- if trainer.logger is None or not self.config.test:
164
+ if not self.config.test:
128
165
  return
129
166
 
130
- _worker_fn(
167
+ _log_on_step(
131
168
  trainer,
132
169
  pl_module,
133
170
  "num_test_batches",
134
171
  dataloader_idx=dataloader_idx,
135
172
  metric_name=self.config.metric_name,
136
173
  )
174
+
175
+ @override
176
+ def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
177
+ if not self.config.test:
178
+ return
179
+
180
+ _log_on_epoch(
181
+ trainer,
182
+ pl_module,
183
+ metric_name=self.config.metric_name,
184
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.5.0
3
+ Version: 1.5.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -21,7 +21,7 @@ nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,1
21
21
  nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH7SgWst3o,2185
22
22
  nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB0QG3JdHkQ8ek,3523
23
23
  nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
24
- nshtrainer/callbacks/log_epoch.py,sha256=C2yUww8lAuCX-dy06tsw95yCBOfFd2mfGs0VhrEq1oU,3775
24
+ nshtrainer/callbacks/log_epoch.py,sha256=-uC5ss9p_ngXUCrSIUwViFcaaVX6ALUzIAKxoDgZrac,4823
25
25
  nshtrainer/callbacks/lr_monitor.py,sha256=v45ehnwNO987087HfiOY5aIrVRbwdKMgPYRFHs1fyEE,1444
26
26
  nshtrainer/callbacks/metric_validation.py,sha256=4RDr1FuNKfro-6QEtmcFqT4iNf2twmJVNk9y-8nq9bg,2882
27
27
  nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
@@ -159,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
159
159
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
160
160
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
161
161
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
162
- nshtrainer-1.5.0.dist-info/METADATA,sha256=fbtia7kDnNxHx_8VE0I-zFtmlF-HMAxH5raSiPjtl7w,980
163
- nshtrainer-1.5.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
164
- nshtrainer-1.5.0.dist-info/RECORD,,
162
+ nshtrainer-1.5.1.dist-info/METADATA,sha256=ct7S8c2O-oHJ2yw3-pApipwmP_r07z8lmFu40FhQY-k,980
163
+ nshtrainer-1.5.1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
164
+ nshtrainer-1.5.1.dist-info/RECORD,,