pytorch_kito-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kito/__init__.py +49 -0
- kito/callbacks/__init__.py +20 -0
- kito/callbacks/callback_base.py +107 -0
- kito/callbacks/csv_logger.py +66 -0
- kito/callbacks/ddp_aware_callback.py +60 -0
- kito/callbacks/early_stopping_callback.py +45 -0
- kito/callbacks/modelcheckpoint.py +78 -0
- kito/callbacks/tensorboard_callback_images.py +298 -0
- kito/callbacks/tensorboard_callbacks.py +132 -0
- kito/callbacks/txt_logger.py +57 -0
- kito/config/__init__.py +0 -0
- kito/config/moduleconfig.py +201 -0
- kito/data/__init__.py +35 -0
- kito/data/datapipeline.py +273 -0
- kito/data/datasets.py +166 -0
- kito/data/preprocessed_dataset.py +57 -0
- kito/data/preprocessing.py +318 -0
- kito/data/registry.py +96 -0
- kito/engine.py +841 -0
- kito/module.py +447 -0
- kito/strategies/__init__.py +0 -0
- kito/strategies/logger_strategy.py +51 -0
- kito/strategies/progress_bar_strategy.py +57 -0
- kito/strategies/readiness_validator.py +85 -0
- kito/utils/__init__.py +0 -0
- kito/utils/decorators.py +45 -0
- kito/utils/gpu_utils.py +94 -0
- kito/utils/loss_utils.py +38 -0
- kito/utils/ssim_utils.py +94 -0
- pytorch_kito-0.2.0.dist-info/METADATA +328 -0
- pytorch_kito-0.2.0.dist-info/RECORD +34 -0
- pytorch_kito-0.2.0.dist-info/WHEEL +5 -0
- pytorch_kito-0.2.0.dist-info/licenses/LICENSE +21 -0
- pytorch_kito-0.2.0.dist-info/top_level.txt +1 -0
kito/module.py
ADDED
@@ -0,0 +1,447 @@
import os

import torch
from packaging.version import parse
from torchsummary import summary as model_summary

from kito.config.moduleconfig import KitoModuleConfig


class KitoModule:
    """
    Base module for PyTorch models.

    Focused on model definition and single-batch operations.
    The Engine handles all iteration, callbacks, and orchestration.

    Usage:
        class MyModel(KitoModule):
            def build_inner_model(self):
                self.model = nn.Sequential(...)
                self.model_input_size = (3, 64, 64)
                self.standard_data_shape = (3, 64, 64)  # For inference

            def bind_optimizer(self):
                self.optimizer = torch.optim.Adam(
                    self.model.parameters(),
                    lr=self.learning_rate
                )

        # Use with Engine
        module = MyModel('MyModel', config)
        module.build()
        module.associate_optimizer()

        engine = Engine(module)
        engine.fit(train_loader, val_loader, max_epochs=100)
    """

    def __init__(self, model_name: str, config: KitoModuleConfig = None):
        """
        Initialize KitoModule.

        Args:
            model_name: Name of the model
            config: Optional config object for future extensibility
        """
        self.model_name = model_name
        self.config = config

        # Extract useful config values if provided
        if config is not None:
            self.learning_rate = config.training.learning_rate
            self.batch_size = config.training.batch_size
        else:
            self.learning_rate = None
            self.batch_size = None

        # Model components
        self.model = None
        self.device = None  # set by Engine
        self.model_input_size = None
        self.standard_data_shape = None  # For inference (set by subclass)
        self.optimizer = None

        # Loss function (set by subclass or from config)
        if config is not None:
            from kito.utils.loss_utils import get_loss
            self.loss = get_loss(config.model.loss)
        else:
            self.loss = None

        # State flags
        self._model_built = False
        self._optimizer_bound = False
        self._weights_loaded = False

    # ========================================================================
    # ABSTRACT METHODS - Must be implemented by subclasses
    # ========================================================================

    def build_inner_model(self, *args, **kwargs):
        """
        Build the model architecture.

        Must set:
            - self.model: The PyTorch model
            - self.model_input_size: Tuple of input shape (C, H, W) or (C, H, W, D)
            - self.standard_data_shape: Output shape for inference (optional)

        Example:
            def build_inner_model(self):
                self.model = nn.Sequential(
                    nn.Conv2d(3, 64, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(64, 3, 3, padding=1)
                )
                self.model_input_size = (3, 64, 64)
                self.standard_data_shape = (3, 64, 64)
        """
        raise NotImplementedError("Subclasses must implement build_inner_model().")

    def bind_optimizer(self, *args, **kwargs):
        """
        Set up the optimizer.

        Must set:
            - self.optimizer: The PyTorch optimizer

        Example:
            def bind_optimizer(self):
                self.optimizer = torch.optim.Adam(
                    self.model.parameters(),
                    lr=self.learning_rate
                )
        """
        raise NotImplementedError("Subclasses must implement bind_optimizer().")

    def _check_data_shape(self):
        """
        Check data shape on first batch.

        Optional - implement if you need to validate input shape.
        Called by Engine on first training batch.

        Example:
            def _check_data_shape(self):
                # Validate that data matches expected shape
                pass
        """
        pass  # Default: no checking

    # ========================================================================
    # SETUP METHODS
    # ========================================================================

    def _move_to_device(self, device: torch.device):
        """
        Internal method called by Engine to move model to device.

        Called AFTER build() by Engine.
        """
        self.device = device
        if self.model is not None:
            self.model.to(device)

    def build(self, *args, **kwargs):
        """Build the model (the Engine moves it to the device afterwards)."""
        self.build_inner_model(*args, **kwargs)
        self._model_built = True

    def associate_optimizer(self, *args, **kwargs):
        """Set up the optimizer."""
        if not self._model_built:
            raise RuntimeError("Must call build() before associate_optimizer()")
        self.bind_optimizer(*args, **kwargs)
        self._optimizer_bound = True

    # ========================================================================
    # SINGLE-BATCH OPERATIONS (Called by Engine)
    # ========================================================================

    def training_step(self, batch, pbar_handler=None):
        """
        Perform one training step on a single batch.

        Called by Engine for each training batch.

        Args:
            batch: Tuple of (inputs, targets) from DataLoader
            pbar_handler: Progress bar handler (optional, provided by Engine)

        Returns:
            dict: {'loss': tensor} - Must contain at least 'loss'

        Override this for custom training logic (freeze layers, gradient accumulation, etc.).

        Default implementation:
            1. Move data to device
            2. Zero gradients
            3. Forward pass
            4. Compute loss
            5. Backward pass
            6. Optimizer step
        """
        inputs, targets = batch

        # Move to device
        inputs = self.send_data_to_device(inputs)
        targets = targets.to(self.device)

        # Zero gradients
        self.optimizer.zero_grad()

        # Forward pass
        outputs = self.pass_data_through_model(inputs)

        # Compute loss
        loss = self.compute_loss((inputs, targets), outputs)

        # Backward pass
        loss.backward()

        # Optimizer step
        self.optimizer.step()

        return {'loss': loss}

    def validation_step(self, batch, pbar_handler=None):
        """
        Perform one validation step on a single batch.

        Called by Engine for each validation batch.

        Args:
            batch: Tuple of (inputs, targets) from DataLoader
            pbar_handler: Progress bar handler (optional, provided by Engine)

        Returns:
            dict: Must contain 'loss', 'outputs', 'inputs', 'targets'

        Override this for custom validation logic.
        """
        inputs, targets = batch

        # Move to device
        inputs = self.send_data_to_device(inputs)
        targets = targets.to(self.device)

        # Forward pass (no gradients)
        outputs = self.pass_data_through_model(inputs)

        # Compute loss
        loss = self.compute_loss((inputs, targets), outputs)

        return {
            'loss': loss,
            'outputs': outputs,
            'targets': targets,
            'inputs': inputs
        }

    def prediction_step(self, batch, pbar_handler=None):
        """
        Perform one prediction step on a single batch.

        Called by Engine for each inference batch.

        Args:
            batch: Input data from DataLoader (can be a tuple, tensor, or dict)
            pbar_handler: Progress bar handler (optional, provided by Engine)

        Returns:
            tensor: Model predictions

        Override this for custom prediction logic.
        """
        # Handle different batch formats
        if isinstance(batch, (tuple, list)):
            inputs = batch[0] if len(batch) > 0 else batch
        elif isinstance(batch, torch.Tensor):
            inputs = batch
        elif isinstance(batch, dict):
            # Handle dict batches (common in HuggingFace)
            inputs = batch.get('input', batch.get('data', batch))
        else:
            raise TypeError(f"Unsupported batch type: {type(batch)}")

        # Move to device
        inputs = self.send_data_to_device(inputs)

        # Forward pass
        outputs = self.pass_data_through_model(inputs)

        # Handle model outputs (for multi-output scenarios)
        outputs = self.handle_model_outputs(outputs)

        return outputs

    # ========================================================================
    # CORE OPERATIONS (Can be overridden)
    # ========================================================================

    def pass_data_through_model(self, data):
        """
        Forward pass through model.

        Override for multi-input models.

        Args:
            data: Input tensor(s)

        Returns:
            Output tensor(s)
        """
        return self.model(data)

    def compute_loss(self, data_pair, y_pred, **kwargs):
        """
        Compute loss from data pair and predictions.

        Override for custom loss computation or multi-output scenarios.

        Args:
            data_pair: Tuple of (inputs, targets)
            y_pred: Model predictions
            **kwargs: Additional arguments (e.g., epoch)

        Returns:
            Loss tensor
        """
        y_true = data_pair[1].to(self.device)
        return self.apply_loss(y_pred, y_true, **kwargs)

    def apply_loss(self, y_pred, y_true, **kwargs):
        """
        Apply loss function.

        Override for custom loss scenarios.

        Args:
            y_pred: Model predictions
            y_true: Ground truth
            **kwargs: Additional arguments

        Returns:
            Loss value
        """
        return self.loss(y_pred, y_true)

    def send_data_to_device(self, data):
        """
        Move data to device.

        Override for complex data structures (dict, nested tuples, etc.).

        Args:
            data: Data to move

        Returns:
            Data on device
        """
        return data.to(self.device)

    def handle_model_outputs(self, outputs):
        """
        Handle model outputs.

        Override for multi-output scenarios.

        Args:
            outputs: Raw model outputs

        Returns:
            Processed outputs
        """
        return outputs

    # ========================================================================
    # WEIGHTS (Model-specific operations)
    # ========================================================================

    def load_weights(self, weight_path: str, strict: bool = True):
        """
        Load model weights.

        Called by Engine but can be overridden for custom loading logic.

        Args:
            weight_path: Path to weight file
            strict: Strict state dict loading
        """
        if not os.path.exists(weight_path):
            raise FileNotFoundError(f"Weight file not found: {weight_path}")

        # Check file extension
        _, file_extension = os.path.splitext(weight_path)
        if file_extension != '.pt':
            raise ValueError(f"Invalid weight file: {weight_path}. Must be a .pt file.")

        # Load weights
        state_dict = torch.load(weight_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state_dict, strict=strict)

        self._weights_loaded = True

    def save_weights(self, weight_path: str):
        """
        Save model weights.

        Args:
            weight_path: Path to save weights
        """
        os.makedirs(os.path.dirname(weight_path) or '.', exist_ok=True)
        torch.save(self.model.state_dict(), weight_path)

    # ========================================================================
    # UTILITIES
    # ========================================================================

    def get_sample_input(self):
        """Get sample input tensor (for summaries, TensorBoard graphs, etc.)."""
        if self.model_input_size is None:
            raise ValueError("model_input_size not set. Call build() first.")
        return torch.randn(1, *self.model_input_size).to(self.device)

    def set_model_input_size(self, *args, **kwargs):
        """
        Hook for setting model input size in complex scenarios.

        Optional - most models set this in build_inner_model().
        """
        raise NotImplementedError("Subclasses can implement set_model_input_size() if needed.")

    def summary(self, summary_depth: int = 3):
        """Print model summary."""
        if self.model_input_size is None:
            raise ValueError("model_input_size not set. Call build() first.")

        torch_version = torch.__version__.split('+')[0]
        if parse(torch_version) < parse('2.6.0'):
            model_summary(self.model, self.model_input_size, batch_dim=0, depth=summary_depth)
        else:
            model_summary(self.model, self.model_input_size)

    # ========================================================================
    # STATE PROPERTIES
    # ========================================================================

    @property
    def is_built(self):
        """Check if model is built."""
        return self._model_built

    @property
    def is_optimizer_set(self):
        """Check if optimizer is set."""
        return self._optimizer_bound

    @property
    def is_weights_loaded(self):
        """Check if weights are loaded."""
        return self._weights_loaded
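As a sketch of the kind of override the training_step docstring invites, a subclass could accumulate gradients over several batches before stepping the optimizer. The accumulation_steps attribute and _step_count counter below are illustrative additions, not part of the package:

    class AccumulatingModule(KitoModule):
        """Hypothetical subclass: steps the optimizer every N batches."""

        accumulation_steps = 4  # illustrative value; grads must start zeroed

        def training_step(self, batch, pbar_handler=None):
            inputs, targets = batch
            inputs = self.send_data_to_device(inputs)
            targets = targets.to(self.device)

            outputs = self.pass_data_through_model(inputs)
            # Scale so the accumulated gradient matches one full-batch step
            loss = self.compute_loss((inputs, targets), outputs) / self.accumulation_steps
            loss.backward()

            self._step_count = getattr(self, '_step_count', 0) + 1
            if self._step_count % self.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            return {'loss': loss}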
kito/strategies/__init__.py
ADDED
File without changes
kito/strategies/logger_strategy.py
ADDED
@@ -0,0 +1,51 @@
import logging

import torch


class BaseLogger:
    def log_info(self, msg):
        raise NotImplementedError

    def log_warning(self, msg):
        raise NotImplementedError

    def log_error(self, msg):
        raise NotImplementedError


class DefaultLogger(BaseLogger):
    def __init__(self, log_level=logging.INFO):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(log_level)
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
        self.logger.addHandler(handler)

    def log_info(self, msg):
        self.logger.info(msg)

    def log_warning(self, msg):
        self.logger.warning(msg)

    def log_error(self, msg):
        self.logger.error(msg)


class DDPLogger(DefaultLogger):
    def __init__(self, log_level=logging.INFO):
        super().__init__(log_level)
        # Only log from rank 0
        self.is_driver = torch.distributed.get_rank() == 0

    def log_info(self, msg):
        if self.is_driver:
            super().log_info(msg)

    def log_warning(self, msg):
        if self.is_driver:
            super().log_warning(msg)

    def log_error(self, msg):
        if self.is_driver:
            super().log_error(msg)
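A plausible way to choose between these loggers at runtime (the make_logger helper is ours, and DDPLogger assumes the process group is already initialized, since torch.distributed.get_rank() fails otherwise):

    import torch.distributed as dist

    def make_logger():
        # DDPLogger only makes sense once the process group exists
        if dist.is_available() and dist.is_initialized():
            return DDPLogger()
        return DefaultLogger()

    logger = make_logger()
    logger.log_info("starting training run")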
kito/strategies/progress_bar_strategy.py
ADDED
@@ -0,0 +1,57 @@
from abc import ABC

import pkbar
import torch


class BaseProgressBarHandler(ABC):

    def init(self, n_target_elements, verbosity_level, message=None):
        pass

    def step(self, i, values):
        pass

    def finalize(self, i, values):
        pass


class StandardProgressBarHandler(BaseProgressBarHandler):
    def __init__(self):
        self.progress_bar = None

    def init(self, n_target_elements, verbosity_level, message=None):
        if message is not None:
            print(message)
        self.progress_bar = pkbar.Kbar(target=n_target_elements, always_stateful=False, width=25,
                                       verbose=verbosity_level)
        self.progress_bar.update(0, None)

    def step(self, i, values):
        self.progress_bar.update(i, values)

    def finalize(self, i, values):
        self.progress_bar.add(i, values)


class DDPProgressBarHandler(BaseProgressBarHandler):
    def __init__(self):
        self.progress_bar = None
        self.is_driver = (torch.distributed.get_rank() == 0)

    def init(self, n_target_elements, verbosity_level, message=None):
        if self.is_driver:
            if message is not None:
                print(message)
            self.progress_bar = pkbar.Kbar(target=n_target_elements, always_stateful=False, width=25,
                                           verbose=verbosity_level)

    def step(self, i, values):
        if self.is_driver:
            self.progress_bar.update(i, values)

    def finalize(self, i, values):
        if self.is_driver:
            self.progress_bar.add(i, values)
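For reference, a minimal driving loop for StandardProgressBarHandler (the batch count and loss values are made up; pkbar renders the named values next to the bar on each update, and Kbar.add() advances by a delta rather than to an absolute position):

    handler = StandardProgressBarHandler()
    n_batches = 100
    handler.init(n_batches, verbosity_level=1, message="Epoch 1/10")
    for i in range(n_batches - 1):
        loss = 1.0 / (i + 1)  # stand-in for a real training loss
        handler.step(i + 1, values=[("loss", loss)])
    # Close the epoch with an increment of 1 to land exactly on the target
    handler.finalize(1, values=[("loss", loss)])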
kito/strategies/readiness_validator.py
ADDED
@@ -0,0 +1,85 @@
class ReadinessValidator:
    """
    Validates module readiness for different operations.

    This replaces the decorator pattern with a cleaner strategy pattern
    that is easier to test and extend.

    Usage:
        # In Engine
        ReadinessValidator.check_for_training(module)
        ReadinessValidator.check_for_inference(module, weight_path)
    """

    @staticmethod
    def check_for_training(module):
        """
        Check if module is ready for training.

        Args:
            module: KitoModule instance

        Raises:
            RuntimeError: If module is not ready
        """
        if not module.is_built:
            raise RuntimeError(
                f"Module '{module.model_name}' not built. "
                "Call module.build() before training."
            )

        if not module.is_optimizer_set:
            raise RuntimeError(
                f"Module '{module.model_name}' optimizer not set. "
                "Call module.associate_optimizer() before training."
            )

        if module.learning_rate is None:
            raise RuntimeError(
                f"Module '{module.model_name}' learning_rate not set."
            )

    @staticmethod
    def check_for_inference(module, weight_path=None):
        """
        Check if module is ready for inference.

        Args:
            module: KitoModule instance
            weight_path: Optional weight path to check

        Raises:
            RuntimeError: If module is not ready
            FileNotFoundError: If weight_path is given but does not exist
        """
        if not module.is_built:
            raise RuntimeError(
                f"Module '{module.model_name}' not built. "
                "Call module.build() before inference."
            )

        if not module.is_weights_loaded:
            raise RuntimeError(
                f"Module '{module.model_name}' weights not loaded. "
                "Call module.load_weights() or engine.load_weights() before inference."
            )

        if weight_path is not None:
            import os
            if not os.path.exists(weight_path):
                raise FileNotFoundError(f"Weight file not found: {weight_path}")

    @staticmethod
    def check_data_loaders(train_loader=None, val_loader=None, test_loader=None):
        """
        Check if data loaders are provided.

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader
            test_loader: Test data loader

        Raises:
            ValueError: If required loaders are missing
        """
        if train_loader is None and val_loader is None and test_loader is None:
            raise ValueError("At least one data loader must be provided")
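The Usage docstring above suggests the Engine runs these checks up front; a condensed sketch of that call pattern (MyModel and train_loader are placeholders, and the Engine internals are abbreviated):

    module = MyModel('MyModel', config)  # any KitoModule subclass
    module.build()
    module.associate_optimizer()

    ReadinessValidator.check_for_training(module)  # raises if not ready
    ReadinessValidator.check_data_loaders(train_loader=train_loader)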
kito/utils/__init__.py
ADDED
File without changes
kito/utils/decorators.py
ADDED
@@ -0,0 +1,45 @@
from functools import wraps


def require_mode(mode: str):
    """
    Decorator to enforce train_mode compatibility in Engine methods.

    Validates that the requested operation (train or inference) matches
    the config.training.train_mode setting.

    Args:
        mode: Either 'train' or 'inference'

    Raises:
        RuntimeError: If the operation conflicts with the train_mode setting
        ValueError: If mode is not 'train' or 'inference'
    """
    if mode not in ('train', 'inference'):
        raise ValueError(f"Invalid mode: {mode}. Must be 'train' or 'inference'.")

    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            # Get train_mode from config (default True for backward compatibility)
            config_mode = getattr(self.config.training, 'train_mode', True)

            # Check compatibility
            if mode == 'train' and not config_mode:
                raise RuntimeError(
                    "Training not allowed: config.training.train_mode is set to False (inference mode).\n"
                    "To enable training, set config.training.train_mode=True in your config."
                )

            if mode == 'inference' and config_mode:
                raise RuntimeError(
                    "Inference not allowed: config.training.train_mode is set to True (training mode).\n"
                    "To enable inference, set config.training.train_mode=False in your config."
                )

            # Mode is compatible - execute the original function
            return func(self, *args, **kwargs)

        return wrapper

    return decorator
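Applied to an Engine method, the decorator would look roughly like this (the real Engine lives in kito/engine.py; this skeleton only illustrates the decoration pattern and assumed config access):

    from kito.utils.decorators import require_mode

    class Engine:
        def __init__(self, module):
            self.module = module
            self.config = module.config  # wrapper reads self.config.training.train_mode

        @require_mode('train')
        def fit(self, train_loader, val_loader=None, max_epochs=10):
            ...  # training loop

        @require_mode('inference')
        def predict(self, test_loader):
            ...  # inference loop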