pytorch-kito 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kito/__init__.py +49 -0
- kito/callbacks/__init__.py +20 -0
- kito/callbacks/callback_base.py +107 -0
- kito/callbacks/csv_logger.py +66 -0
- kito/callbacks/ddp_aware_callback.py +60 -0
- kito/callbacks/early_stopping_callback.py +45 -0
- kito/callbacks/modelcheckpoint.py +78 -0
- kito/callbacks/tensorboard_callback_images.py +298 -0
- kito/callbacks/tensorboard_callbacks.py +132 -0
- kito/callbacks/txt_logger.py +57 -0
- kito/config/__init__.py +0 -0
- kito/config/moduleconfig.py +201 -0
- kito/data/__init__.py +35 -0
- kito/data/datapipeline.py +273 -0
- kito/data/datasets.py +166 -0
- kito/data/preprocessed_dataset.py +57 -0
- kito/data/preprocessing.py +318 -0
- kito/data/registry.py +96 -0
- kito/engine.py +841 -0
- kito/module.py +447 -0
- kito/strategies/__init__.py +0 -0
- kito/strategies/logger_strategy.py +51 -0
- kito/strategies/progress_bar_strategy.py +57 -0
- kito/strategies/readiness_validator.py +85 -0
- kito/utils/__init__.py +0 -0
- kito/utils/decorators.py +45 -0
- kito/utils/gpu_utils.py +94 -0
- kito/utils/loss_utils.py +38 -0
- kito/utils/ssim_utils.py +94 -0
- pytorch_kito-0.2.0.dist-info/METADATA +328 -0
- pytorch_kito-0.2.0.dist-info/RECORD +34 -0
- pytorch_kito-0.2.0.dist-info/WHEEL +5 -0
- pytorch_kito-0.2.0.dist-info/licenses/LICENSE +21 -0
- pytorch_kito-0.2.0.dist-info/top_level.txt +1 -0
kito/callbacks/tensorboard_callback_images.py
ADDED

@@ -0,0 +1,298 @@
"""
Base and example callbacks for image plotting to TensorBoard.

The BaseImagePlotter provides a foundation for custom image plotting callbacks.
"""
from abc import abstractmethod
from typing import Optional, List

import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

from kito.callbacks.callback_base import Callback


class BaseImagePlotter(Callback):
    """
    Base class for plotting images to TensorBoard.

    Subclasses must implement `create_figure()` to define custom plotting logic.

    Args:
        log_dir: Directory for TensorBoard logs
        tag: Tag for the image in TensorBoard
        freq: Frequency of plotting (every N epochs)
        batch_indices: Which batch indices to visualize (e.g., [0, 1, 2])

    Example:
        class MyCustomPlotter(BaseImagePlotter):
            def create_figure(self, val_data, val_outputs, epoch, **kwargs):
                fig, ax = plt.subplots()
                # Your custom plotting logic here
                ax.imshow(val_outputs[0].cpu().numpy())
                return fig

        plotter = MyCustomPlotter('logs/images', tag='predictions')
    """

    def __init__(
        self,
        log_dir: str,
        tag: str = 'validation_images',
        freq: int = 1,
        batch_indices: Optional[List[int]] = None
    ):
        self.log_dir = log_dir
        self.tag = tag
        self.freq = freq
        self.batch_indices = batch_indices or [0]
        self.writer = None

    def on_train_begin(self, engine, model, **kwargs):
        """Initialize TensorBoard writer."""
        self.writer = SummaryWriter(self.log_dir)

    @abstractmethod
    def create_figure(self, val_data, val_outputs, epoch, **kwargs):
        """
        Create matplotlib figure for visualization.

        Must be implemented by subclasses.

        Args:
            val_data: Validation data (tuple of inputs and targets)
            val_outputs: Model predictions on validation data
            epoch: Current epoch number
            **kwargs: Additional context from engine

        Returns:
            matplotlib.figure.Figure or list of figures
        """
        pass

    def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
        """Plot images to TensorBoard."""
        if epoch % self.freq != 0:
            return

        # Get validation data and outputs from kwargs
        val_data = kwargs.get('val_data')
        val_outputs = kwargs.get('val_outputs')

        if val_data is None or val_outputs is None:
            return

        # Create figure(s)
        figures = self.create_figure(val_data, val_outputs, epoch, **kwargs)

        # Log to TensorBoard
        if isinstance(figures, (list, tuple)):
            # Multiple figures
            for i, fig in enumerate(figures):
                idx = self.batch_indices[i] if i < len(self.batch_indices) else i
                self.writer.add_figure(
                    f'{self.tag}_batch_{idx}',
                    fig,
                    global_step=epoch,
                    close=True
                )
        else:
            # Single figure
            self.writer.add_figure(
                self.tag,
                figures,
                global_step=epoch,
                close=True
            )

        self.writer.flush()

    def on_train_end(self, engine, model, **kwargs):
        """Close TensorBoard writer."""
        if self.writer:
            self.writer.close()


# ============================================================================
# EXAMPLE IMAGE PLOTTERS
# ============================================================================

class SimpleImagePlotter(BaseImagePlotter):
    """
    Simple image plotter for input/output comparison.

    Shows input, ground truth, and prediction side by side.

    Args:
        log_dir: Directory for TensorBoard logs
        tag: Tag for the image in TensorBoard
        freq: Frequency of plotting (every N epochs)
        batch_indices: Which batch indices to visualize
        cmap: Colormap for matplotlib (default: 'viridis')

    Example:
        plotter = SimpleImagePlotter(
            'logs/images',
            tag='validation',
            batch_indices=[0, 1, 2]
        )
    """

    def __init__(
        self,
        log_dir: str,
        tag: str = 'validation',
        freq: int = 1,
        batch_indices: Optional[List[int]] = None,
        cmap: str = 'viridis'
    ):
        super().__init__(log_dir, tag, freq, batch_indices)
        self.cmap = cmap

    def create_figure(self, val_data, val_outputs, epoch, **kwargs):
        """Create side-by-side comparison figure."""
        # Extract input and target
        val_input = val_data[0]  # (B, C, H, W) or (B, T, C, H, W)
        val_target = val_data[1]

        # Create figures for each batch index
        figures = []
        for idx in self.batch_indices:
            if idx >= val_input.shape[0]:
                break

            fig, axes = plt.subplots(1, 3, figsize=(15, 5))

            # Get data for this batch index
            input_img = self._prepare_image(val_input[idx])
            target_img = self._prepare_image(val_target[idx])
            pred_img = self._prepare_image(val_outputs[idx])

            # Plot
            axes[0].imshow(input_img, cmap=self.cmap)
            axes[0].set_title('Input')
            axes[0].axis('off')

            axes[1].imshow(target_img, cmap=self.cmap)
            axes[1].set_title('Ground Truth')
            axes[1].axis('off')

            axes[2].imshow(pred_img, cmap=self.cmap)
            axes[2].set_title('Prediction')
            axes[2].axis('off')

            plt.tight_layout()
            figures.append(fig)

        return figures if len(figures) > 1 else figures[0]

    def _prepare_image(self, tensor):
        """
        Prepare tensor for visualization.

        Handles:
        - Time series: (T, C, H, W) -> take middle frame
        - Multi-channel: (C, H, W) -> take first channel
        - Single channel: (1, H, W) -> squeeze
        """
        # Move to CPU and detach
        img = tensor.detach().cpu()

        # Handle time series
        if img.ndim == 4:  # (T, C, H, W)
            img = img[img.shape[0] // 2]  # Take middle frame

        # Handle channels
        if img.ndim == 3:  # (C, H, W)
            if img.shape[0] == 1:
                img = img[0]  # Single channel
            elif img.shape[0] == 3:
                img = img.permute(1, 2, 0)  # RGB
            else:
                img = img[0]  # Take first channel

        return img.numpy()


class DifferencePlotter(BaseImagePlotter):
    """
    Plot input, prediction, and difference (error) map.

    Useful for regression tasks to visualize prediction errors.

    Args:
        log_dir: Directory for TensorBoard logs
        tag: Tag for the image in TensorBoard
        freq: Frequency of plotting (every N epochs)
        batch_indices: Which batch indices to visualize

    Example:
        plotter = DifferencePlotter('logs/images', tag='errors')
    """

    def __init__(
        self,
        log_dir: str,
        tag: str = 'difference',
        freq: int = 1,
        batch_indices: Optional[List[int]] = None
    ):
        super().__init__(log_dir, tag, freq, batch_indices)

    def create_figure(self, val_data, val_outputs, epoch, **kwargs):
        """Create figure with input, prediction, and difference."""
        val_input = val_data[0]
        val_target = val_data[1]

        figures = []
        for idx in self.batch_indices:
            if idx >= val_input.shape[0]:
                break

            fig, axes = plt.subplots(1, 4, figsize=(20, 5))

            # Get data
            input_img = self._prepare_image(val_input[idx])
            target_img = self._prepare_image(val_target[idx])
            pred_img = self._prepare_image(val_outputs[idx])
            diff_img = target_img - pred_img

            # Plot
            axes[0].imshow(input_img, cmap='gray')
            axes[0].set_title('Input')
            axes[0].axis('off')

            axes[1].imshow(target_img, cmap='gray')
            axes[1].set_title('Target')
            axes[1].axis('off')

            axes[2].imshow(pred_img, cmap='gray')
            axes[2].set_title('Prediction')
            axes[2].axis('off')

            im = axes[3].imshow(diff_img, cmap='RdBu_r', vmin=-diff_img.std(), vmax=diff_img.std())
            axes[3].set_title('Difference (Target - Pred)')
            axes[3].axis('off')
            plt.colorbar(im, ax=axes[3])

            plt.tight_layout()
            figures.append(fig)

        return figures if len(figures) > 1 else figures[0]

    def _prepare_image(self, tensor):
        """Prepare tensor for visualization."""
        img = tensor.detach().cpu()

        # Handle time series
        if img.ndim == 4:
            img = img[img.shape[0] // 2]

        # Handle channels
        if img.ndim == 3:
            if img.shape[0] > 1:
                img = img[0]  # Take first channel
            else:
                img = img[0]

        return img.numpy()
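A quick way to exercise the plotters above outside a full training run is to drive the callback hooks by hand. The sketch below does that for SimpleImagePlotter with random tensors; the constructor arguments mirror the docstring example, while the dummy tensors and the manual hook calls are assumptions for illustration, since the engine that normally supplies `val_data`/`val_outputs` through `**kwargs` is not part of this file.

# Illustrative sketch (not part of the package)
import torch

from kito.callbacks.tensorboard_callback_images import SimpleImagePlotter

# Fake validation batch: (inputs, targets) plus model outputs, all shaped (B, C, H, W).
val_inputs = torch.randn(4, 1, 32, 32)
val_targets = torch.randn(4, 1, 32, 32)
val_outputs = torch.randn(4, 1, 32, 32)

plotter = SimpleImagePlotter('logs/images', tag='validation', batch_indices=[0, 1, 2])

# Drive the hooks by hand; in real use the Kito engine calls these.
plotter.on_train_begin(engine=None, model=None)
plotter.on_epoch_end(
    epoch=0, engine=None, model=None, logs={},
    val_data=(val_inputs, val_targets), val_outputs=val_outputs,
)
plotter.on_train_end(engine=None, model=None)
# TensorBoard event files with the three comparison figures end up under logs/images.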
kito/callbacks/tensorboard_callbacks.py
ADDED

@@ -0,0 +1,132 @@
from typing import Optional, Callable

from kito.callbacks.callback_base import Callback
from torch.utils.tensorboard import SummaryWriter


class TensorBoardScalars(Callback):
    """
    Log scalar metrics to TensorBoard.

    Args:
        log_dir: Directory for TensorBoard logs

    Example:
        tb_scalars = TensorBoardScalars('logs/tensorboard')
    """

    def __init__(self, log_dir: str):
        self.log_dir = log_dir
        self.writer = None

    def on_train_begin(self, engine, model, **kwargs):
        """Initialize TensorBoard writer."""
        self.writer = SummaryWriter(self.log_dir)

    def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
        """Log scalars to TensorBoard."""
        if logs is None:
            return

        for key, value in logs.items():
            self.writer.add_scalar(key, value, epoch)

        self.writer.flush()

    def on_train_end(self, engine, model, **kwargs):
        """Close TensorBoard writer."""
        if self.writer:
            self.writer.close()


class TensorBoardHistograms(Callback):
    """
    Log model parameter histograms to TensorBoard.

    Args:
        log_dir: Directory for TensorBoard logs
        freq: Frequency of histogram logging (every N epochs)

    Example:
        tb_histograms = TensorBoardHistograms('logs/tensorboard', freq=5)
    """

    def __init__(self, log_dir: str, freq: int = 1):
        self.log_dir = log_dir
        self.freq = freq
        self.writer = None

    def on_train_begin(self, engine, model, **kwargs):
        """Initialize TensorBoard writer."""
        self.writer = SummaryWriter(self.log_dir)

    def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
        """Log parameter histograms."""
        if epoch % self.freq != 0:
            return

        # Handle DDP
        model_to_log = model.module if hasattr(model, 'module') else model

        for name, param in model_to_log.named_parameters():
            self.writer.add_histogram(name, param, epoch)

        self.writer.flush()

    def on_train_end(self, engine, model, **kwargs):
        """Close TensorBoard writer."""
        if self.writer:
            self.writer.close()


class TensorBoardGraph(Callback):
    """
    Log model graph to TensorBoard.

    Args:
        log_dir: Directory for TensorBoard logs
        input_to_model: Function that returns sample input for the model

    Example:
        def get_input():
            return torch.randn(1, 3, 64, 64)

        tb_graph = TensorBoardGraph('logs/tensorboard', input_to_model=get_input)
    """

    def __init__(
        self,
        log_dir: str,
        input_to_model: Optional[Callable] = None
    ):
        self.log_dir = log_dir
        self.input_to_model = input_to_model
        self.writer = None
        self.logged = False

    def on_train_begin(self, engine, model, **kwargs):
        """Initialize TensorBoard writer and log graph."""
        if self.logged:
            return

        self.writer = SummaryWriter(self.log_dir)

        # Get sample input
        if self.input_to_model:
            sample_input = self.input_to_model()
        elif hasattr(engine, 'get_sample_input'):
            sample_input = engine.get_sample_input()
        else:
            return  # Can't log graph without input

        # Handle DDP
        model_to_log = model.module if hasattr(model, 'module') else model

        self.writer.add_graph(model_to_log, sample_input)
        self.writer.flush()
        self.logged = True

    def on_train_end(self, engine, model, **kwargs):
        """Close TensorBoard writer."""
        if self.writer:
            self.writer.close()
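The three callbacks above compose naturally; the sketch below wires them up the way their docstring examples suggest and drives the hooks by hand. The toy nn.Sequential model and the manual calls are assumptions used only to make the sketch self-contained, since the engine that normally invokes these hooks lives elsewhere in the package.

# Illustrative sketch (not part of the package)
import torch
from torch import nn

from kito.callbacks.tensorboard_callbacks import (
    TensorBoardGraph, TensorBoardHistograms, TensorBoardScalars,
)

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))

tb_scalars = TensorBoardScalars('logs/tensorboard')
tb_histograms = TensorBoardHistograms('logs/tensorboard', freq=5)
tb_graph = TensorBoardGraph('logs/tensorboard', input_to_model=lambda: torch.randn(1, 3, 64, 64))

callbacks = [tb_scalars, tb_histograms, tb_graph]

# Drive the hooks by hand; in real use the Kito engine calls these once per epoch.
for cb in callbacks:
    cb.on_train_begin(engine=None, model=model)   # TensorBoardGraph traces the model here

tb_scalars.on_epoch_end(epoch=0, engine=None, model=model,
                        logs={'train_loss': 0.42, 'val_loss': 0.51})
tb_histograms.on_epoch_end(epoch=0, engine=None, model=model)  # logged since 0 % 5 == 0

for cb in callbacks:
    cb.on_train_end(engine=None, model=model)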
kito/callbacks/txt_logger.py
ADDED

@@ -0,0 +1,57 @@
import datetime
import os

from kito.callbacks.callback_base import Callback


class TextLogger(Callback):
    """
    Log training metrics to a text file.

    Args:
        filename: Path to log file
        append: Append to existing file or overwrite

    Example:
        text_logger = TextLogger('logs/training.log')
    """

    def __init__(
        self,
        filename: str,
        append: bool = True
    ):
        self.filename = filename
        self.append = append
        self.file = None

        # Create directory
        os.makedirs(os.path.dirname(filename) if os.path.dirname(filename) else '.', exist_ok=True)

    def on_train_begin(self, engine, model, **kwargs):
        """Open log file."""
        mode = 'a' if self.append else 'w'
        self.file = open(self.filename, mode)

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.file.write(f"{'=' * 60}")
        self.file.write(f"Training started at {timestamp}")
        self.file.write(f"{'=' * 60}")
        self.file.flush()

    def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
        """Write metrics to log file."""
        if logs is None:
            return

        self.file.write(f"Epoch {epoch}: ")
        metrics_str = ", ".join(f"{k}={v:.4f}" for k, v in logs.items())
        self.file.write(metrics_str + "\n")
        self.file.flush()

    def on_train_end(self, engine, model, **kwargs):
        """Close log file."""
        if self.file:
            timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            self.file.write(f"\nTraining ended at {timestamp}\n")
            self.file.close()
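For completeness, the same manual-hook pattern applied to TextLogger; the metric names and values below are made up for illustration and are not defaults documented anywhere in this file.

# Illustrative sketch (not part of the package)
from kito.callbacks.txt_logger import TextLogger

text_logger = TextLogger('logs/training.log')  # append=True by default

text_logger.on_train_begin(engine=None, model=None)
text_logger.on_epoch_end(epoch=0, engine=None, model=None,
                         logs={'train_loss': 0.4321, 'val_loss': 0.5123})
text_logger.on_train_end(engine=None, model=None)
# logs/training.log now holds the start banner, the epoch 0 metrics, and the end timestamp.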
kito/config/__init__.py
ADDED

File without changes
kito/config/moduleconfig.py
ADDED

@@ -0,0 +1,201 @@
"""
Base configuration system for KitoModule framework.

Users can extend these base configs with their own custom parameters.
"""
from dataclasses import dataclass, field
from typing import Tuple, List, Optional, Dict, Any


@dataclass
class PreprocessingStepConfig:
    """
    Configuration for a single preprocessing step.

    Args:
        type: Name of preprocessing class (e.g., 'detrend', 'standardization')
        params: Dictionary of parameters to pass to preprocessing class

    Example:
        >>> step = PreprocessingStepConfig(
        ...     type='standardization',
        ...     params={'mean': 0.5, 'std': 0.2}
        ... )
    """
    type: str
    params: Dict[str, Any] = field(default_factory=dict)


@dataclass
class DataConfig:
    """
    Data loading and preprocessing configuration.

    This config defines:
    - What dataset to use (H5, memory, custom)
    - Where data is located
    - How to initialize the dataset (flexible args)
    - Memory loading strategy
    - How to split data (train/val/test)
    - What preprocessing to apply
    - DataLoader settings
    """

    # Dataset configuration
    dataset_type: str = 'h5dataset'  # 'h5dataset', 'memdataset', or custom

    # Simple path (backward compatible)
    dataset_path: str = ''

    # Flexible initialization args (for custom datasets)
    # If provided, takes precedence over dataset_path
    # Allows any constructor signature: Dataset(**dataset_init_args)
    dataset_init_args: Dict[str, Any] = field(default_factory=dict)

    # Memory management
    load_into_memory: bool = False  # Load entire dataset into RAM for faster training

    # Splitting ratios
    train_ratio: float = 0.8
    val_ratio: float = 0.1
    # test_ratio is implicit: 1 - train_ratio - val_ratio

    # Total samples to use (None = use all available)
    total_samples: Optional[int] = None

    # Preprocessing pipeline
    # List of preprocessing steps applied in order
    preprocessing: List[PreprocessingStepConfig] = field(default_factory=list)

    # DataLoader settings
    num_workers: int = 0
    prefetch_factor: int = 2
    pin_memory: bool = False
    persistent_workers: bool = False


@dataclass
class TrainingConfig:
    """Core training parameters required by KitoModule."""

    # Essential training parameters
    learning_rate: float
    n_train_epochs: int
    batch_size: int
    train_mode: bool  # True for training, False for inference

    # Verbosity (0=silent, 1=progress bar, 2=detailed)
    train_verbosity_level: int = 2
    val_verbosity_level: int = 2
    test_verbosity_level: int = 2

    # Distributed training
    distributed_training: bool = False
    master_gpu_id: int = 0  # Only used if not distributed

    # Weight initialization
    initialize_model_with_saved_weights: bool = False

    # device type initialization
    device_type: str = "cuda"  # "cuda", "mps", or "cpu"

    def __post_init__(self):
        """Validate device_type after initialization."""
        valid_devices = {"cuda", "mps", "cpu"}

        if self.device_type not in valid_devices:
            raise ValueError(
                f"Invalid device_type: '{self.device_type}'. "
                f"Must be one of {valid_devices}."
            )

        # Normalize to lowercase (user-friendly)
        self.device_type = self.device_type.lower()


@dataclass
class ModelConfig:
    """Core model parameters required by KitoModule."""

    # Data dimensions
    input_data_size: Tuple[int, ...]  # Flexible shape

    # Loss and optimization
    loss: str = ""

    # Callbacks and logging
    log_to_tensorboard: bool = False
    save_model_weights: bool = False
    text_logging: bool = False
    csv_logging: bool = False
    train_codename: str = "experiment"

    # Weights
    weight_load_path: str = ""

    # Inference
    save_inference_to_disk: bool = False
    inference_filename: str = ""

    # TensorBoard visualization (optional)
    tensorboard_img_id: str = "training_viz"
    batch_idx_viz: List[int] = field(default_factory=lambda: [0])


@dataclass
class WorkDirConfig:
    """Working directory configuration."""
    work_directory: str


@dataclass
class CallbacksConfig:
    """
    Configuration for Kito's built-in default callbacks.

    For custom callbacks, use instead:
        callbacks = engine.get_default_callbacks()
        callbacks.append(MyCustomCallback())
        engine.fit(..., callbacks=callbacks)
    """

    # === CSV Logger ===
    enable_csv_logger: bool = True

    # === Text Logger ===
    enable_text_logger: bool = True

    # === Model Checkpoint ===
    enable_model_checkpoint: bool = True
    checkpoint_monitor: str = 'val_loss'
    checkpoint_mode: str = 'min'  # 'min' or 'max'
    checkpoint_save_best_only: bool = True
    checkpoint_verbose: bool = False

    # === TensorBoard ===
    enable_tensorboard: bool = False  # Master switch
    tensorboard_scalars: bool = True
    tensorboard_histograms: bool = True
    tensorboard_histogram_freq: int = 5
    tensorboard_graph: bool = True
    tensorboard_images: bool = False
    tensorboard_image_freq: int = 1
    tensorboard_batch_indices: List[int] = field(default_factory=lambda: [0])


@dataclass
class KitoModuleConfig:
    """
    Base configuration container for KitoModule.

    Contains all configuration sections:
    - training: Training parameters
    - model: Model architecture and settings
    - workdir: Output directories
    - data: Dataset and preprocessing
    """
    training: TrainingConfig
    model: ModelConfig
    workdir: WorkDirConfig
    data: DataConfig
    callbacks: CallbacksConfig
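To tie the sections together, here is a minimal sketch of assembling a complete KitoModuleConfig from the dataclasses above. All field values are illustrative; the 'standardization' step reuses the example from the PreprocessingStepConfig docstring, and 'mse' is only a placeholder loss name, not a value documented in this file.

# Illustrative sketch (not part of the package)
from kito.config.moduleconfig import (
    CallbacksConfig, DataConfig, KitoModuleConfig, ModelConfig,
    PreprocessingStepConfig, TrainingConfig, WorkDirConfig,
)

config = KitoModuleConfig(
    training=TrainingConfig(
        learning_rate=1e-3,
        n_train_epochs=50,
        batch_size=16,
        train_mode=True,
        device_type='cuda',
    ),
    model=ModelConfig(
        input_data_size=(1, 64, 64),
        loss='mse',                      # placeholder name
        log_to_tensorboard=True,
        train_codename='experiment',
    ),
    workdir=WorkDirConfig(work_directory='runs/experiment'),
    data=DataConfig(
        dataset_type='h5dataset',
        dataset_path='data/train.h5',
        train_ratio=0.8,
        val_ratio=0.1,                   # test split is the remaining 0.1
        preprocessing=[
            PreprocessingStepConfig(type='standardization',
                                    params={'mean': 0.5, 'std': 0.2}),
        ],
        num_workers=4,
    ),
    callbacks=CallbacksConfig(enable_tensorboard=True, tensorboard_histogram_freq=5),
)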