pytorch-kito 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kito/__init__.py +49 -0
- kito/callbacks/__init__.py +20 -0
- kito/callbacks/callback_base.py +107 -0
- kito/callbacks/csv_logger.py +66 -0
- kito/callbacks/ddp_aware_callback.py +60 -0
- kito/callbacks/early_stopping_callback.py +45 -0
- kito/callbacks/modelcheckpoint.py +78 -0
- kito/callbacks/tensorboard_callback_images.py +298 -0
- kito/callbacks/tensorboard_callbacks.py +132 -0
- kito/callbacks/txt_logger.py +57 -0
- kito/config/__init__.py +0 -0
- kito/config/moduleconfig.py +201 -0
- kito/data/__init__.py +35 -0
- kito/data/datapipeline.py +273 -0
- kito/data/datasets.py +166 -0
- kito/data/preprocessed_dataset.py +57 -0
- kito/data/preprocessing.py +318 -0
- kito/data/registry.py +96 -0
- kito/engine.py +841 -0
- kito/module.py +447 -0
- kito/strategies/__init__.py +0 -0
- kito/strategies/logger_strategy.py +51 -0
- kito/strategies/progress_bar_strategy.py +57 -0
- kito/strategies/readiness_validator.py +85 -0
- kito/utils/__init__.py +0 -0
- kito/utils/decorators.py +45 -0
- kito/utils/gpu_utils.py +94 -0
- kito/utils/loss_utils.py +38 -0
- kito/utils/ssim_utils.py +94 -0
- pytorch_kito-0.2.0.dist-info/METADATA +328 -0
- pytorch_kito-0.2.0.dist-info/RECORD +34 -0
- pytorch_kito-0.2.0.dist-info/WHEEL +5 -0
- pytorch_kito-0.2.0.dist-info/licenses/LICENSE +21 -0
- pytorch_kito-0.2.0.dist-info/top_level.txt +1 -0
kito/engine.py
ADDED
@@ -0,0 +1,841 @@
import datetime
import os
from pathlib import Path
from typing import List, Optional

import h5py
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader

from kito.callbacks.callback_base import Callback, CallbackList
from kito.callbacks.ddp_aware_callback import DDPAwareCallback
from kito.config.moduleconfig import CallbacksConfig
from kito.data.datapipeline import GenericDataPipeline
from kito.module import KitoModule
from kito.strategies.logger_strategy import DDPLogger, DefaultLogger
from kito.strategies.progress_bar_strategy import (
    StandardProgressBarHandler,
    DDPProgressBarHandler
)
from kito.strategies.readiness_validator import ReadinessValidator
from kito.utils.decorators import require_mode
from kito.utils.gpu_utils import assign_device, get_available_devices


class Engine:
    """
    Engine for training, validation, and inference.

    HYBRID APPROACH:
    - Auto-builds model if not built (user can also build explicitly)
    - Auto-sets optimizer if not set (user can also set explicitly)

    This provides:
    1. Simplicity for beginners (just works)
    2. Control for advanced users (explicit calls)
    3. Flexibility for power users (custom args, modifications)

    The Engine:
    1. Manages device assignment and DDP
    2. Iterates over batches and epochs
    3. Calls module's single-batch methods (training_step, validation_step, prediction_step)
    4. Manages callbacks (creates defaults, registers hooks)
    5. Handles logging and progress bars

    Args:
        module: KitoModule instance
        config: Configuration object (for device, DDP, callbacks, etc.)

    Example (Simple - auto-build):
        module = MyModel('MyModel', device, config)
        engine = Engine(module, config)
        engine.fit(train_loader, val_loader, max_epochs=100)  # Auto-builds and sets optimizer

    Example (Advanced - explicit control):
        module = MyModel('MyModel', device, config)
        module.build(custom_layers=64)  # Custom build
        module.summary()  # Inspect
        module.associate_optimizer()  # Custom optimizer setup

        engine = Engine(module, config)
        engine.fit(train_loader, val_loader, max_epochs=100)  # Uses pre-built model
    """

    def __init__(self, module: KitoModule, config):
        """
        Initialize Engine.

        Args:
            module: KitoModule instance (can be built or not)
            config: Configuration object
        """
        self.max_epochs = None
        self.module = module
        self.config = config

        # Extract config values
        self.distributed_training = config.training.distributed_training
        self.work_directory = config.workdir.work_directory

        # Logger
        self.logger = DDPLogger() if self.distributed_training else DefaultLogger()

        # Device and DDP setup
        self.gpu_id = (
            dist.get_rank()
            if self.distributed_training
            else config.training.master_gpu_id
        )
        self.driver_device = (
            self.gpu_id == 0
            if self.distributed_training
            else True
        )
        device_type = config.training.device_type
        self.device = assign_device(device_type, self.gpu_id)

        # Log device info
        available = get_available_devices()
        self.logger.log_info(
            f"Device configuration:\n"
            f"  Requested: {device_type}\n"
            f"  Assigned: {self.device}\n"
            f"  Available: {available}"
        )

        # Assign to module
        self.module._move_to_device(self.device)

        # Progress bars
        self.train_pbar = (
            DDPProgressBarHandler()
            if self.distributed_training
            else StandardProgressBarHandler()
        )
        self.val_pbar = (
            DDPProgressBarHandler()
            if self.distributed_training
            else StandardProgressBarHandler()
        )
        self.inference_pbar = StandardProgressBarHandler()

        # Training state
        self.current_epoch = 0
        self.stop_training = False

        # First batch flag for data shape checking
        self._first_train_batch = True

    # ========================================================================
    # AUTO-SETUP (Hybrid approach)
    # ========================================================================

    def _ensure_model_ready_for_training(self):
        """
        Ensure model is ready for training.

        Auto-builds and auto-sets optimizer if needed.
        Logs when doing so for transparency.
        """
        # Auto-build if needed
        if not self.module.is_built:
            self.logger.log_info(
                f"Model '{self.module.model_name}' not built. Building automatically..."
            )
            self.module.build()
            self.module._move_to_device(self.device)  # Then move to device
            self.logger.log_info("✓ Model built successfully.")

        # Auto-setup optimizer if needed
        if not self.module.is_optimizer_set:
            self.logger.log_info(
                f"Optimizer not set for '{self.module.model_name}'. Setting up automatically..."
            )
            self.module.associate_optimizer()
            self.logger.log_info(
                f"✓ Optimizer configured: {self.module.optimizer.__class__.__name__}"
            )

    def _ensure_model_ready_for_inference(self, weight_path: Optional[str] = None):
        """
        Ensure model is ready for inference.

        Auto-builds if needed.
        Does NOT auto-load weights (user must do this explicitly).
        """
        # Auto-build if needed
        if not self.module.is_built:
            self.logger.log_info(
                f"Model '{self.module.model_name}' not built. Building automatically..."
            )
            self.module.build()
            self.logger.log_info("✓ Model built successfully.")

        # Check weights are loaded (don't auto-load, user must be explicit)
        if not self.module.is_weights_loaded:
            self.logger.log_warning(
                f"⚠ Weights not loaded for '{self.module.model_name}'. "
                "Call module.load_weights() or engine.load_weights() before inference."
            )
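
    # NOTE: the two helpers above only consult readiness flags on the module
    # (is_built, is_optimizer_set, is_weights_loaded) and call its build() /
    # associate_optimizer() hooks. A minimal module satisfying that contract
    # would look roughly like this (illustrative sketch only; the actual base
    # class lives in kito/module.py):
    #
    #     class MyModel(KitoModule):
    #         def build(self):
    #             self.model = torch.nn.Linear(16, 1)
    #
    #         def associate_optimizer(self):
    #             self.optimizer = torch.optim.Adam(
    #                 self.model.parameters(), lr=self.learning_rate
    #             )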

    # ========================================================================
    # FIT - Training + Validation
    # ========================================================================
    @require_mode('train')
    def fit(
        self,
        train_loader: Optional[DataLoader] = None,
        val_loader: Optional[DataLoader] = None,
        data_pipeline: Optional[GenericDataPipeline] = None,
        max_epochs: Optional[int] = None,
        callbacks: Optional[List[Callback]] = None,
    ):
        """
        Train the module.

        HYBRID: Auto-builds model and sets optimizer if not already done.

        Args:
            train_loader: Training DataLoader
            val_loader: Validation DataLoader
            data_pipeline: GenericDataPipeline instance
            max_epochs: Maximum epochs (None = use config value)
            callbacks: List of callbacks (None = create smart defaults)

        Example (Simple):
            engine.fit(train_loader, val_loader, max_epochs=100)
            # Auto-builds and sets optimizer

        Example (Advanced):
            module.build(custom_layers=64)
            module.associate_optimizer()
            engine.fit(train_loader, val_loader, max_epochs=100)
            # Uses pre-built model
        """
        # ===== HYBRID: Auto-setup if needed =====
        self._ensure_model_ready_for_training()

        if data_pipeline is not None:
            train_loader = data_pipeline.train_dataloader()
            val_loader = data_pipeline.val_dataloader()

        if train_loader is None and data_pipeline is None:
            raise ValueError("Must provide either train_loader or data_pipeline")

        # Validate readiness (should always pass after auto-setup)
        ReadinessValidator.check_for_training(self.module)
        ReadinessValidator.check_data_loaders(train_loader, val_loader)

        # Get max_epochs
        if max_epochs is None:
            max_epochs = self.config.training.n_train_epochs
        self.max_epochs = max_epochs

        # Wrap model for DDP if needed
        if self.distributed_training:
            self._wrap_model_ddp()

        # Setup callbacks
        if callbacks is None:
            callbacks = self._create_default_callbacks()

        # Wrap callbacks for DDP
        if self.distributed_training:
            callbacks = [DDPAwareCallback(cb) for cb in callbacks]

        callbacks = CallbackList(callbacks)

        # Log training info
        self._log_training_info(max_epochs, len(callbacks.callbacks))

        # Reset first batch flag
        self._first_train_batch = True

        # ===== HOOK: on_train_begin =====
        callbacks.on_train_begin(engine=self, model=self.module.model)

        try:
            for epoch in range(max_epochs):
                self.current_epoch = epoch + 1

                # ===== HOOK: on_epoch_begin =====
                callbacks.on_epoch_begin(
                    epoch=self.current_epoch,
                    engine=self,
                    model=self.module.model
                )

                # Train epoch
                train_loss = self._train_epoch(
                    train_loader,
                    self.config.training.train_verbosity_level
                )

                # Validate epoch
                val_loss, val_data, val_outputs = self._validate_epoch(
                    val_loader,
                    self.config.training.val_verbosity_level
                )

                # Prepare logs
                logs = {
                    'train_loss': train_loss,
                    'val_loss': val_loss
                }

                # ===== HOOK: on_epoch_end =====
                callbacks.on_epoch_end(
                    epoch=self.current_epoch,
                    engine=self,
                    model=self.module.model,
                    logs=logs,
                    val_data=val_data,
                    val_outputs=val_outputs
                )

                # Check early stopping
                if self.stop_training:
                    self.logger.log_info(f"\nStopping training at epoch {self.current_epoch}")
                    break

        except KeyboardInterrupt:
            self.logger.log_info(f"\nTraining interrupted at epoch {self.current_epoch}")

        finally:
            # ===== HOOK: on_train_end =====
            callbacks.on_train_end(engine=self, model=self.module.model)

            # Cleanup DDP
            if self.distributed_training:
                dist.destroy_process_group()

            self.logger.log_info(f"\nTraining of {self.module.model_name} completed.")

    def _train_epoch(self, train_loader, verbosity_level):
        """
        Train for one epoch.

        Iterates over batches and calls module.training_step(batch).
        """
        self.module.model.train()

        # Set epoch for DDP sampler
        if self.distributed_training:
            train_loader.sampler.set_epoch(self.current_epoch)

        # Init progress bar
        self.train_pbar.init(
            len(train_loader),
            verbosity_level,
            message=f"Epoch {self.current_epoch}/{self.max_epochs}"
        )

        # Accumulate loss
        running_loss = 0.0

        for batch_idx, batch in enumerate(train_loader):
            # Check data shape on first batch
            if self._first_train_batch:
                self.module._check_data_shape()
                self._first_train_batch = False

            # ===== Call module's training_step =====
            step_output = self.module.training_step(batch, self.train_pbar)

            # Extract loss
            loss = step_output['loss']
            running_loss += loss.item()

            # Update progress bar
            self.train_pbar.step(
                batch_idx + 1,
                [("train_loss: ", float(f'{loss.item():.4f}'))]
            )

        # Return average loss
        return running_loss / len(train_loader)
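
    # The loop above never calls loss.backward() or optimizer.step(); it only
    # aggregates the returned loss. The backward pass and optimizer update are
    # therefore the module's responsibility. A training_step compatible with
    # this loop would look roughly like this (illustrative sketch; loss_fn is
    # a stand-in for whatever criterion the module defines):
    #
    #     def training_step(self, batch, pbar):
    #         inputs, targets = batch
    #         inputs, targets = inputs.to(self.device), targets.to(self.device)
    #         self.optimizer.zero_grad()
    #         outputs = self.model(inputs)
    #         loss = self.loss_fn(outputs, targets)
    #         loss.backward()
    #         self.optimizer.step()
    #         return {'loss': loss}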

    def _validate_epoch(self, val_loader, verbosity_level):
        """
        Validate for one epoch.

        Iterates over batches and calls module.validation_step(batch).
        """
        self.module.model.eval()

        # Init progress bar
        self.val_pbar.init(
            len(val_loader) + 1,
            verbosity_level
        )

        # Accumulate metrics
        running_loss = 0.0
        last_inputs = None
        last_targets = None
        last_outputs = None

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                # ===== Call module's validation_step =====
                step_output = self.module.validation_step(batch, self.val_pbar)

                # Extract metrics
                loss = step_output['loss']
                running_loss += loss.item()

                # Store last batch for callbacks
                last_outputs = step_output.get('outputs')
                last_targets = step_output.get('targets')
                last_inputs = step_output.get('inputs')

                # Update progress bar
                self.val_pbar.step(
                    batch_idx + 1,
                    [("val_loss: ", float(f'{loss.item():.4f}'))]
                )

        # Average loss
        avg_loss = running_loss / len(val_loader)

        # Package validation data for callbacks
        val_data = (last_inputs, last_targets)

        return avg_loss, val_data, last_outputs
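
    # validation_step must return a dict with at least 'loss'; 'inputs',
    # 'targets', and 'outputs' are optional. Only the values from the final
    # batch of the epoch are kept and forwarded to on_epoch_end as val_data
    # and val_outputs (consumed, presumably, by image-plotting callbacks such
    # as SimpleImagePlotter).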

    # ========================================================================
    # PREDICT - Inference
    # ========================================================================
    @require_mode('inference')
    def predict(
        self,
        test_loader,
        save_to_disk: bool = False,
        output_path: Optional[str] = None,
    ):
        """
        Run inference on test data.

        HYBRID: Auto-builds model if not already built.
        Does NOT auto-load weights (user must load explicitly).

        Args:
            test_loader: Test DataLoader
            save_to_disk: Save predictions to HDF5 file
            output_path: Path to save predictions (if save_to_disk=True)

        Returns:
            numpy.ndarray: Predictions (if save_to_disk=False)
            None: If save_to_disk=True

        Example:
            # Load weights first (explicit)
            module.load_weights('weights/best.pt')

            # Predict (auto-builds if needed)
            predictions = engine.predict(test_loader)
        """
        # ===== HYBRID: Auto-setup if needed =====
        self._ensure_model_ready_for_inference()

        # Validate (will warn if weights not loaded)
        ReadinessValidator.check_for_inference(self.module)
        ReadinessValidator.check_data_loaders(test_loader=test_loader)

        if save_to_disk and output_path is None:
            raise ValueError("output_path required when save_to_disk=True")

        # Log inference info
        self._log_inference_info(len(test_loader), save_to_disk, output_path)

        # Prepare storage
        storage = self._prepare_prediction_storage(
            test_loader,
            save_to_disk,
            output_path
        )

        # Run inference
        self.module.model.eval()

        self.inference_pbar.init(
            len(test_loader),
            self.config.training.test_verbosity_level,
            message='Evaluating model in inference mode'
        )

        with torch.no_grad():
            for batch_idx, batch in enumerate(test_loader):
                # ===== Call module's prediction_step =====
                outputs = self.module.prediction_step(batch, self.inference_pbar)

                # Store predictions
                self._store_predictions(
                    storage,
                    batch_idx,
                    outputs,
                    test_loader.batch_size,
                    len(test_loader.dataset)
                )

                # Update progress bar
                self.inference_pbar.step(batch_idx + 1, values=None)

        # Finalize storage
        result = self._finalize_prediction_storage(storage, save_to_disk, output_path)

        self.logger.log_info(f"\nInference of {self.module.model_name} completed.")

        return result
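
    # Unlike training_step/validation_step, prediction_step is expected to
    # return the raw output tensor (shape [batch, ...]) rather than a dict:
    # _prepare_prediction_storage reads outputs.shape from it and
    # _store_predictions calls outputs.cpu().detach().numpy() on it directly.
    # Slice positions are computed from batch_idx and batch_size, so
    # predictions are stored in loader order.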

    def _prepare_prediction_storage(self, test_loader, save_to_disk, output_path):
        """Prepare storage for predictions."""
        # Infer output shape from first batch
        sample_batch = next(iter(test_loader))

        with torch.no_grad():
            sample_output = self.module.prediction_step(sample_batch, None)

        # Determine shapes
        batch_size = sample_output.shape[0]
        output_shape = sample_output.shape[1:]
        total_samples = len(test_loader.dataset)
        full_shape = (total_samples,) + output_shape

        # Use module's standard_data_shape if set, otherwise infer
        if self.module.standard_data_shape is not None:
            output_shape = self.module.standard_data_shape
            full_shape = (total_samples,) + output_shape

        if save_to_disk:
            # Validate output path
            self._check_inference_save_path_valid(output_path)

            # Create HDF5 file
            h5_file = h5py.File(output_path, 'w')
            h5_dataset = h5_file.create_dataset(
                'predictions',
                shape=full_shape,
                dtype='float32',
                chunks=(batch_size,) + output_shape
            )

            self.logger.log_info(f"Created HDF5 dataset '{output_path}' to store predictions.")

            return {'type': 'disk', 'file': h5_file, 'dataset': h5_dataset}
        else:
            # Allocate memory
            predictions = np.zeros(full_shape, dtype=np.float32)
            self.logger.log_info("Created array in memory to store predictions.")
            return {'type': 'memory', 'data': predictions}
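
    # Shape inference above costs one extra forward pass on the first batch;
    # no sample is skipped, because next(iter(test_loader)) builds a fresh
    # iterator and predict() later iterates the loader from the start again.
    # HDF5 chunks are sized to one batch, keeping the per-batch slice writes
    # in _store_predictions aligned to whole chunks (except the final partial
    # batch).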

    def _store_predictions(self, storage, batch_idx, outputs, batch_size, dataset_len):
        """Store predictions for a batch."""
        batch_data = outputs.cpu().detach().numpy()

        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, dataset_len)

        if storage['type'] == 'disk':
            storage['dataset'][start_idx:end_idx] = batch_data
        else:
            storage['data'][start_idx:end_idx] = batch_data

    def _finalize_prediction_storage(self, storage, save_to_disk, output_path):
        """Finalize storage and return result."""
        if storage['type'] == 'disk':
            storage['file'].close()
            self.logger.log_info(f"Saved predictions to '{output_path}'")
            return None
        else:
            return storage['data']

    def _check_inference_save_path_valid(self, inference_filename):
        """Validate inference output path."""
        _, file_extension = os.path.splitext(inference_filename)

        if os.path.isdir(inference_filename):
            raise IsADirectoryError(
                f"ERROR: '{os.path.abspath(inference_filename)}' is a directory."
            )

        if file_extension != '.h5':
            raise ValueError(
                f"ERROR: '{os.path.abspath(inference_filename)}' must have .h5 extension."
            )

        if os.path.exists(inference_filename):
            raise FileExistsError(
                f"ERROR: '{os.path.abspath(inference_filename)}' already exists."
            )

    # ========================================================================
    # DDP
    # ========================================================================

    def _wrap_model_ddp(self):
        """Wrap model with DistributedDataParallel."""
        from torch.nn.parallel import DistributedDataParallel

        if not isinstance(self.module.model, DistributedDataParallel):
            self.module.model = DistributedDataParallel(
                self.module.model,
                device_ids=[self.gpu_id]
            )
            # self.module.model.to(self.device)
            self.module._move_to_device(self.device)
            self.logger.log_info("Model wrapped in DistributedDataParallel.")

    # ========================================================================
    # DEFAULT CALLBACKS
    # ========================================================================

    # Old implementation, superseded by the config-driven version below:
    '''def _create_default_callbacks(self):
        """Create smart default callbacks based on config."""
        # see whether to keep it like this + maybe make interfaces in __init__.py
        from kito.callbacks.modelcheckpoint import ModelCheckpoint
        from kito.callbacks.csv_logger import CSVLogger
        from kito.callbacks.txt_logger import TextLogger
        from kito.callbacks.tensorboard_callbacks import TensorBoardScalars, TensorBoardHistograms, TensorBoardGraph

        from kito.callbacks.tensorboard_callback_images import SimpleImagePlotter

        callbacks = []

        # Setup paths
        timestamp = datetime.datetime.now().strftime("%d%b%Y-%H%M%S")
        work_dir = Path(os.path.expandvars(self.work_directory))
        model_name = self.module.model_name
        train_codename = self.config.model.train_codename

        # === CSV Logger (always) ===
        csv_dir = work_dir / "logs" / "csv"
        csv_dir.mkdir(parents=True, exist_ok=True)
        csv_path = csv_dir / f"{model_name}_{timestamp}_{train_codename}.csv"
        callbacks.append(CSVLogger(str(csv_path)))

        # === Text Logger (always) ===
        log_dir = work_dir / "logs" / "text"
        log_dir.mkdir(parents=True, exist_ok=True)
        log_path = log_dir / f"{model_name}_{timestamp}_{train_codename}.log"
        callbacks.append(TextLogger(str(log_path)))

        # === Model Checkpoint (if configured) ===
        if self.config.model.save_model_weights:
            weight_dir = work_dir / "weights" / model_name
            weight_dir.mkdir(parents=True, exist_ok=True)
            weight_path = weight_dir / f"best_{model_name}_{timestamp}_{train_codename}.pt"

            callbacks.append(
                ModelCheckpoint(
                    filepath=str(weight_path),
                    monitor='val_loss',
                    save_best_only=True,
                    mode='min',
                    verbose=True
                )
            )

        # === TensorBoard (if configured) ===
        if self.config.model.log_to_tensorboard:
            tb_dir = work_dir / "logs" / "tensorboard" / model_name / timestamp / train_codename
            tb_dir.mkdir(parents=True, exist_ok=True)

            # Scalars
            callbacks.append(TensorBoardScalars(str(tb_dir)))

            # Histograms
            callbacks.append(TensorBoardHistograms(str(tb_dir), freq=5))

            # Model graph
            callbacks.append(
                TensorBoardGraph(
                    str(tb_dir),
                    input_to_model=lambda: self.module.get_sample_input()
                )
            )

            # Image plotting (if configured)
            if hasattr(self.config.model, 'tensorboard_img_id'):
                img_dir = tb_dir / 'images'
                callbacks.append(
                    SimpleImagePlotter(
                        log_dir=str(img_dir),
                        tag=self.config.model.tensorboard_img_id,
                        freq=1,
                        batch_indices=getattr(self.config.model, 'batch_idx_viz', [0])
                    )
                )

        return callbacks'''

    def _create_default_callbacks(self):
        """Create smart default callbacks based on config."""
        from kito.callbacks.modelcheckpoint import ModelCheckpoint
        from kito.callbacks.csv_logger import CSVLogger
        from kito.callbacks.txt_logger import TextLogger
        from kito.callbacks.tensorboard_callbacks import (
            TensorBoardScalars,
            TensorBoardHistograms,
            TensorBoardGraph
        )
        from kito.callbacks.tensorboard_callback_images import SimpleImagePlotter

        callbacks = []

        # Get callbacks config (with defaults if not provided)
        cb_config = getattr(self.config, 'callbacks', CallbacksConfig())

        # Setup paths
        timestamp = datetime.datetime.now().strftime("%d%b%Y-%H%M%S")
        work_dir = Path(os.path.expandvars(self.work_directory))
        model_name = self.module.model_name
        train_codename = self.config.model.train_codename

        # === CSV Logger ===
        if cb_config.enable_csv_logger:
            csv_dir = work_dir / "logs" / "csv"
            csv_dir.mkdir(parents=True, exist_ok=True)
            csv_path = csv_dir / f"{model_name}_{timestamp}_{train_codename}.csv"
            callbacks.append(CSVLogger(str(csv_path)))

        # === Text Logger ===
        if cb_config.enable_text_logger:
            log_dir = work_dir / "logs" / "text"
            log_dir.mkdir(parents=True, exist_ok=True)
            log_path = log_dir / f"{model_name}_{timestamp}_{train_codename}.log"
            callbacks.append(TextLogger(str(log_path)))

        # === Model Checkpoint ===
        if cb_config.enable_model_checkpoint and self.config.model.save_model_weights:  # should leave only one of the two
            weight_dir = work_dir / "weights" / model_name
            weight_dir.mkdir(parents=True, exist_ok=True)
            weight_path = weight_dir / f"best_{model_name}_{timestamp}_{train_codename}.pt"

            callbacks.append(
                ModelCheckpoint(
                    filepath=str(weight_path),
                    monitor=cb_config.checkpoint_monitor,
                    save_best_only=cb_config.checkpoint_save_best_only,
                    mode=cb_config.checkpoint_mode,
                    verbose=cb_config.checkpoint_verbose
                )
            )

        # === TensorBoard ===
        if cb_config.enable_tensorboard and self.config.model.log_to_tensorboard:  # should leave only one of the two
            tb_dir = work_dir / "logs" / "tensorboard" / model_name / timestamp / train_codename
            tb_dir.mkdir(parents=True, exist_ok=True)

            # Scalars
            if cb_config.tensorboard_scalars:
                callbacks.append(TensorBoardScalars(str(tb_dir)))

            # Histograms
            if cb_config.tensorboard_histograms:
                callbacks.append(
                    TensorBoardHistograms(
                        str(tb_dir),
                        freq=cb_config.tensorboard_histogram_freq
                    )
                )

            # Model graph
            if cb_config.tensorboard_graph:
                callbacks.append(
                    TensorBoardGraph(
                        str(tb_dir),
                        input_to_model=lambda: self.module.get_sample_input()
                    )
                )

            # Image plotting
            if cb_config.tensorboard_images:
                img_dir = tb_dir / 'images'
                callbacks.append(
                    SimpleImagePlotter(
                        log_dir=str(img_dir),
                        tag=getattr(self.config.model, 'tensorboard_img_id', 'images'),
                        freq=cb_config.tensorboard_image_freq,
                        batch_indices=cb_config.tensorboard_batch_indices
                    )
                )

        return callbacks
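
    # The cb_config attributes read above (enable_csv_logger, enable_text_logger,
    # enable_model_checkpoint, checkpoint_monitor, checkpoint_save_best_only,
    # checkpoint_mode, checkpoint_verbose, enable_tensorboard, tensorboard_*)
    # are fields of CallbacksConfig in kito/config/moduleconfig.py, so the
    # defaults are tuned from the config side, roughly:
    #
    #     config.callbacks.enable_tensorboard = False      # skip all TB callbacks
    #     config.callbacks.checkpoint_monitor = 'train_loss'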

    def get_default_callbacks(self):
        """
        Get default callbacks configured from config.callbacks.

        Useful when you want to extend defaults with custom callbacks.

        Returns:
            List of callback instances

        Example:
            >>> # Get defaults
            >>> callbacks = engine.get_default_callbacks()
            >>>
            >>> # Modify or extend
            >>> callbacks.append(MyCustomCallback(param=10))
            >>>
            >>> # Or modify existing
            >>> for cb in callbacks:
            ...     if isinstance(cb, ModelCheckpoint):
            ...         cb.verbose = False
            >>>
            >>> # Use them
            >>> engine.fit(train_loader, val_loader, callbacks=callbacks)
        """
        return self._create_default_callbacks()

    # ========================================================================
    # WEIGHT LOADING (Convenience method)
    # ========================================================================

    def load_weights(self, weight_path: str, strict: bool = True):
        """
        Load weights into module.

        Convenience method that calls module.load_weights().

        Args:
            weight_path: Path to weight file
            strict: Strict loading
        """
        self.module.load_weights(weight_path, strict)
        self.logger.log_info(f"Loaded weights from '{weight_path}' successfully.")

    # ========================================================================
    # LOGGING
    # ========================================================================

    def _log_training_info(self, max_epochs, num_callbacks):
        """Log training configuration."""
        attrib_to_log = {
            'model_name': self.module.model_name,
            'optimizer': self.module.optimizer.__class__.__name__,
            'batch_size': self.module.batch_size,
            'n_train_epochs': max_epochs,
            'learning_rate': self.module.learning_rate,
            'work_directory': self.work_directory,
            'save_model_weights': self.config.model.save_model_weights,
            'distributed_training': self.distributed_training,
            'callbacks': num_callbacks
        }

        self.logger.log_info(
            'Model being used with the following parameters:\n\n' +
            '\n '.join(f"{k} -> {v}" for k, v in attrib_to_log.items()) + '\n'
        )

    def _log_inference_info(self, num_batches, save_to_disk, output_path):
        """Log inference configuration."""
        self.logger.log_info(
            f"\nRunning inference for {self.module.model_name}\n"
            f"Batches: {num_batches}\n"
            f"Device: {self.device}\n"
            f"Save to disk: {save_to_disk}\n"
            + (f"Output: {output_path}\n" if save_to_disk else "Output: In-memory\n")
        )
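
Read end to end, the intended call pattern of this Engine is roughly the sketch below. This is illustrative, not packaged documentation: MyModel, config, and the loaders stand in for user code built on KitoModule and kito's config classes, and the @require_mode('train') / @require_mode('inference') decorators suggest fit() and predict() are gated by the configured mode, so the two phases would normally live in separate runs.

    from kito.engine import Engine

    module = MyModel('MyModel', device, config)   # user subclass of KitoModule
    engine = Engine(module, config)

    # Training run: model and optimizer are set up automatically if needed
    engine.fit(train_loader, val_loader, max_epochs=100)

    # Inference run: weights are never auto-loaded, so load them explicitly
    engine.load_weights('weights/best.pt')
    predictions = engine.predict(test_loader)     # ndarray held in memory
    engine.predict(test_loader, save_to_disk=True, output_path='predictions.h5')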