pytorch-kito 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kito/engine.py ADDED
@@ -0,0 +1,841 @@
+ import datetime
+ import os
+ from pathlib import Path
+ from typing import List, Optional
+
+ import h5py
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ from torch.utils.data import DataLoader
+
+ from kito.callbacks.callback_base import Callback, CallbackList
+ from kito.callbacks.ddp_aware_callback import DDPAwareCallback
+ from kito.config.moduleconfig import CallbacksConfig
+ from kito.data.datapipeline import GenericDataPipeline
+ from kito.module import KitoModule
+ from kito.strategies.logger_strategy import DDPLogger, DefaultLogger
+ from kito.strategies.progress_bar_strategy import (
+     StandardProgressBarHandler,
+     DDPProgressBarHandler
+ )
+ from kito.strategies.readiness_validator import ReadinessValidator
+ from kito.utils.decorators import require_mode
+ from kito.utils.gpu_utils import assign_device, get_available_devices
+
+
+ class Engine:
+     """
+     Engine for training, validation, and inference.
+
+     HYBRID APPROACH:
+     - Auto-builds the model if not built (the user can also build explicitly)
+     - Auto-sets the optimizer if not set (the user can also set it explicitly)
+
+     This provides:
+     1. Simplicity for beginners (it just works)
+     2. Control for advanced users (explicit calls)
+     3. Flexibility for power users (custom args, modifications)
+
+     The Engine:
+     1. Manages device assignment and DDP
+     2. Iterates over batches and epochs
+     3. Calls the module's single-batch methods (training_step, validation_step, prediction_step)
+     4. Manages callbacks (creates defaults, registers hooks)
+     5. Handles logging and progress bars
+
+     Args:
+         module: KitoModule instance
+         config: Configuration object (for device, DDP, callbacks, etc.)
+
+     Example (Simple - auto-build):
+         module = MyModel('MyModel', device, config)
+         engine = Engine(module, config)
+         engine.fit(train_loader, val_loader, max_epochs=100)  # Auto-builds and sets optimizer
+
+     Example (Advanced - explicit control):
+         module = MyModel('MyModel', device, config)
+         module.build(custom_layers=64)   # Custom build
+         module.summary()                 # Inspect
+         module.associate_optimizer()     # Custom optimizer setup
+
+         engine = Engine(module, config)
+         engine.fit(train_loader, val_loader, max_epochs=100)  # Uses pre-built model
+     """
+
+     def __init__(self, module: KitoModule, config):
+         """
+         Initialize Engine.
+
+         Args:
+             module: KitoModule instance (can be built or not)
+             config: Configuration object
+         """
+         self.max_epochs = None
+         self.module = module
+         self.config = config
+
+         # Extract config values
+         self.distributed_training = config.training.distributed_training
+         self.work_directory = config.workdir.work_directory
+
+         # Logger
+         self.logger = DDPLogger() if self.distributed_training else DefaultLogger()
+
+         # Device and DDP setup
+         self.gpu_id = (
+             dist.get_rank()
+             if self.distributed_training
+             else config.training.master_gpu_id
+         )
+         self.driver_device = (
+             self.gpu_id == 0
+             if self.distributed_training
+             else True
+         )
+         device_type = config.training.device_type
+         self.device = assign_device(device_type, self.gpu_id)
+
+         # Log device info
+         available = get_available_devices()
+         self.logger.log_info(
+             f"Device configuration:\n"
+             f"  Requested: {device_type}\n"
+             f"  Assigned: {self.device}\n"
+             f"  Available: {available}"
+         )
+
+         # Assign to module
+         self.module._move_to_device(self.device)
+
+         # Progress bars
+         self.train_pbar = (
+             DDPProgressBarHandler()
+             if self.distributed_training
+             else StandardProgressBarHandler()
+         )
+         self.val_pbar = (
+             DDPProgressBarHandler()
+             if self.distributed_training
+             else StandardProgressBarHandler()
+         )
+         self.inference_pbar = StandardProgressBarHandler()
+
+         # Training state
+         self.current_epoch = 0
+         self.stop_training = False
+
+         # First batch flag for data shape checking
+         self._first_train_batch = True
+
+     # ========================================================================
+     # AUTO-SETUP (Hybrid approach)
+     # ========================================================================
+
+     def _ensure_model_ready_for_training(self):
+         """
+         Ensure model is ready for training.
+
+         Auto-builds and auto-sets optimizer if needed.
+         Logs when doing so for transparency.
+         """
+         # Auto-build if needed
+         if not self.module.is_built:
+             self.logger.log_info(
+                 f"Model '{self.module.model_name}' not built. Building automatically..."
+             )
+             self.module.build()
+             self.module._move_to_device(self.device)  # Then move to device
+             self.logger.log_info("✓ Model built successfully.")
+
+         # Auto-setup optimizer if needed
+         if not self.module.is_optimizer_set:
+             self.logger.log_info(
+                 f"Optimizer not set for '{self.module.model_name}'. Setting up automatically..."
+             )
+             self.module.associate_optimizer()
+             self.logger.log_info(
+                 f"✓ Optimizer configured: {self.module.optimizer.__class__.__name__}"
+             )
+
+     def _ensure_model_ready_for_inference(self):
+         """
+         Ensure model is ready for inference.
+
+         Auto-builds if needed.
+         Does NOT auto-load weights (the user must do this explicitly).
+         """
+         # Auto-build if needed
+         if not self.module.is_built:
+             self.logger.log_info(
+                 f"Model '{self.module.model_name}' not built. Building automatically..."
+             )
+             self.module.build()
+             self.logger.log_info("✓ Model built successfully.")
+
+         # Check weights are loaded (don't auto-load, the user must be explicit)
+         if not self.module.is_weights_loaded:
+             self.logger.log_warning(
+                 f"⚠ Weights not loaded for '{self.module.model_name}'. "
+                 "Call module.load_weights() or engine.load_weights() before inference."
+             )
+
+     # ========================================================================
+     # FIT - Training + Validation
+     # ========================================================================
+     @require_mode('train')
+     def fit(
+         self,
+         train_loader: Optional[DataLoader] = None,
+         val_loader: Optional[DataLoader] = None,
+         data_pipeline: Optional[GenericDataPipeline] = None,
+         max_epochs: Optional[int] = None,
+         callbacks: Optional[List[Callback]] = None,
+     ):
+         """
+         Train the module.
+
+         HYBRID: Auto-builds the model and sets the optimizer if not already done.
+
+         Args:
+             train_loader: Training DataLoader
+             val_loader: Validation DataLoader
+             data_pipeline: GenericDataPipeline instance (alternative to passing loaders)
+             max_epochs: Maximum number of epochs (None = use config value)
+             callbacks: List of callbacks (None = create smart defaults)
+
+         Example (Simple):
+             engine.fit(train_loader, val_loader, max_epochs=100)
+             # Auto-builds and sets optimizer
+
+         Example (Advanced):
+             module.build(custom_layers=64)
+             module.associate_optimizer()
+             engine.fit(train_loader, val_loader, max_epochs=100)
+             # Uses pre-built model
+         """
+         # ===== HYBRID: Auto-setup if needed =====
+         self._ensure_model_ready_for_training()
+
+         if data_pipeline is not None:
+             train_loader = data_pipeline.train_dataloader()
+             val_loader = data_pipeline.val_dataloader()
+
+         if train_loader is None and data_pipeline is None:
+             raise ValueError("Must provide either train_loader or data_pipeline")
+
+         # Validate readiness (should always pass after auto-setup)
+         ReadinessValidator.check_for_training(self.module)
+         ReadinessValidator.check_data_loaders(train_loader, val_loader)
+
+         # Get max_epochs
+         if max_epochs is None:
+             max_epochs = self.config.training.n_train_epochs
+         self.max_epochs = max_epochs
+
+         # Wrap model for DDP if needed
+         if self.distributed_training:
+             self._wrap_model_ddp()
+
+         # Setup callbacks
+         if callbacks is None:
+             callbacks = self._create_default_callbacks()
+
+         # Wrap callbacks for DDP
+         if self.distributed_training:
+             callbacks = [DDPAwareCallback(cb) for cb in callbacks]
+
+         callbacks = CallbackList(callbacks)
+
+         # Log training info
+         self._log_training_info(max_epochs, len(callbacks.callbacks))
+
+         # Reset first batch flag
+         self._first_train_batch = True
+
+         # ===== HOOK: on_train_begin =====
+         callbacks.on_train_begin(engine=self, model=self.module.model)
+
+         try:
+             for epoch in range(max_epochs):
+                 self.current_epoch = epoch + 1
+
+                 # ===== HOOK: on_epoch_begin =====
+                 callbacks.on_epoch_begin(
+                     epoch=self.current_epoch,
+                     engine=self,
+                     model=self.module.model
+                 )
+
+                 # Train epoch
+                 train_loss = self._train_epoch(
+                     train_loader,
+                     self.config.training.train_verbosity_level
+                 )
+
+                 # Validate epoch
+                 val_loss, val_data, val_outputs = self._validate_epoch(
+                     val_loader,
+                     self.config.training.val_verbosity_level
+                 )
+
+                 # Prepare logs
+                 logs = {
+                     'train_loss': train_loss,
+                     'val_loss': val_loss
+                 }
+
+                 # ===== HOOK: on_epoch_end =====
+                 callbacks.on_epoch_end(
+                     epoch=self.current_epoch,
+                     engine=self,
+                     model=self.module.model,
+                     logs=logs,
+                     val_data=val_data,
+                     val_outputs=val_outputs
+                 )
+
+                 # Check early stopping
+                 if self.stop_training:
+                     self.logger.log_info(f"\nStopping training at epoch {self.current_epoch}")
+                     break
+
+         except KeyboardInterrupt:
+             self.logger.log_info(f"\nTraining interrupted at epoch {self.current_epoch}")
+
+         finally:
+             # ===== HOOK: on_train_end =====
+             callbacks.on_train_end(engine=self, model=self.module.model)
+
+             # Cleanup DDP
+             if self.distributed_training:
+                 dist.destroy_process_group()
+
+         self.logger.log_info(f"\nTraining of {self.module.model_name} completed.")
+
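+     # A minimal custom-callback sketch (illustrative only; it assumes the
+     # Callback base class from kito.callbacks.callback_base defines the hook
+     # names that CallbackList invokes in fit() above):
+     #
+     #     class LossPrinter(Callback):
+     #         def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
+     #             # logs carries 'train_loss' and 'val_loss', as packaged in fit()
+     #             print(f"epoch {epoch}: {logs}")
+     #
+     #     engine.fit(train_loader, val_loader, callbacks=[LossPrinter()])
+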
+     def _train_epoch(self, train_loader, verbosity_level):
+         """
+         Train for one epoch.
+
+         Iterates over batches and calls module.training_step(batch).
+         """
+         self.module.model.train()
+
+         # Set epoch for DDP sampler
+         if self.distributed_training:
+             train_loader.sampler.set_epoch(self.current_epoch)
+
+         # Init progress bar
+         self.train_pbar.init(
+             len(train_loader),
+             verbosity_level,
+             message=f"Epoch {self.current_epoch}/{self.max_epochs}"
+         )
+
+         # Accumulate loss
+         running_loss = 0.0
+
+         for batch_idx, batch in enumerate(train_loader):
+             # Check data shape on first batch
+             if self._first_train_batch:
+                 self.module._check_data_shape()
+                 self._first_train_batch = False
+
+             # ===== Call module's training_step =====
+             step_output = self.module.training_step(batch, self.train_pbar)
+
+             # Extract loss
+             loss = step_output['loss']
+             running_loss += loss.item()
+
+             # Update progress bar
+             self.train_pbar.step(
+                 batch_idx + 1,
+                 [("train_loss: ", round(loss.item(), 4))]
+             )
+
+         # Return average loss
+         return running_loss / len(train_loader)
+
+     def _validate_epoch(self, val_loader, verbosity_level):
+         """
+         Validate for one epoch.
+
+         Iterates over batches and calls module.validation_step(batch).
+         """
+         self.module.model.eval()
+
+         # Init progress bar
+         self.val_pbar.init(
+             len(val_loader) + 1,
+             verbosity_level
+         )
+
+         # Accumulate metrics
+         running_loss = 0.0
+         last_inputs = None
+         last_targets = None
+         last_outputs = None
+
+         with torch.no_grad():
+             for batch_idx, batch in enumerate(val_loader):
+                 # ===== Call module's validation_step =====
+                 step_output = self.module.validation_step(batch, self.val_pbar)
+
+                 # Extract metrics
+                 loss = step_output['loss']
+                 running_loss += loss.item()
+
+                 # Store last batch for callbacks
+                 last_outputs = step_output.get('outputs')
+                 last_targets = step_output.get('targets')
+                 last_inputs = step_output.get('inputs')
+
+                 # Update progress bar
+                 self.val_pbar.step(
+                     batch_idx + 1,
+                     [("val_loss: ", round(loss.item(), 4))]
+                 )
+
+         # Average loss
+         avg_loss = running_loss / len(val_loader)
+
+         # Package validation data for callbacks
+         val_data = (last_inputs, last_targets)
+
+         return avg_loss, val_data, last_outputs
+
+     # ========================================================================
+     # PREDICT - Inference
+     # ========================================================================
+     @require_mode('inference')
+     def predict(
+         self,
+         test_loader: DataLoader,
+         save_to_disk: bool = False,
+         output_path: Optional[str] = None,
+     ):
+         """
+         Run inference on test data.
+
+         HYBRID: Auto-builds the model if not already built.
+         Does NOT auto-load weights (the user must load them explicitly).
+
+         Args:
+             test_loader: Test DataLoader
+             save_to_disk: Save predictions to an HDF5 file
+             output_path: Path to save predictions (required if save_to_disk=True)
+
+         Returns:
+             numpy.ndarray: Predictions (if save_to_disk=False)
+             None: If save_to_disk=True
+
+         Example:
+             # Load weights first (explicit)
+             module.load_weights('weights/best.pt')
+
+             # Predict (auto-builds if needed)
+             predictions = engine.predict(test_loader)
+         """
+         # ===== HYBRID: Auto-setup if needed =====
+         self._ensure_model_ready_for_inference()
+
+         # Validate (will warn if weights not loaded)
+         ReadinessValidator.check_for_inference(self.module)
+         ReadinessValidator.check_data_loaders(test_loader=test_loader)
+
+         if save_to_disk and output_path is None:
+             raise ValueError("output_path required when save_to_disk=True")
+
+         # Log inference info
+         self._log_inference_info(len(test_loader), save_to_disk, output_path)
+
+         # Prepare storage
+         storage = self._prepare_prediction_storage(
+             test_loader,
+             save_to_disk,
+             output_path
+         )
+
+         # Run inference
+         self.module.model.eval()
+
+         self.inference_pbar.init(
+             len(test_loader),
+             self.config.training.test_verbosity_level,
+             message='Evaluating model in inference mode'
+         )
+
+         with torch.no_grad():
+             for batch_idx, batch in enumerate(test_loader):
+                 # ===== Call module's prediction_step =====
+                 outputs = self.module.prediction_step(batch, self.inference_pbar)
+
+                 # Store predictions
+                 self._store_predictions(
+                     storage,
+                     batch_idx,
+                     outputs,
+                     test_loader.batch_size,
+                     len(test_loader.dataset)
+                 )
+
+                 # Update progress bar
+                 self.inference_pbar.step(batch_idx + 1, values=None)
+
+         # Finalize storage
+         result = self._finalize_prediction_storage(storage, save_to_disk, output_path)
+
+         self.logger.log_info(f"\nInference of {self.module.model_name} completed.")
+
+         return result
+
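+     # Reading saved predictions back (illustrative sketch; the dataset name
+     # 'predictions' is the one created by _prepare_prediction_storage below):
+     #
+     #     engine.predict(test_loader, save_to_disk=True, output_path='preds.h5')
+     #     with h5py.File('preds.h5', 'r') as f:
+     #         preds = f['predictions'][:]
+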
+     def _prepare_prediction_storage(self, test_loader, save_to_disk, output_path):
+         """Prepare storage for predictions."""
+         # Infer output shape from first batch
+         sample_batch = next(iter(test_loader))
+
+         with torch.no_grad():
+             sample_output = self.module.prediction_step(sample_batch, None)
+
+         # Determine shapes
+         batch_size = sample_output.shape[0]
+         output_shape = sample_output.shape[1:]
+         total_samples = len(test_loader.dataset)
+         full_shape = (total_samples,) + output_shape
+
+         # Use module's standard_data_shape if set, otherwise infer
+         if self.module.standard_data_shape is not None:
+             output_shape = self.module.standard_data_shape
+             full_shape = (total_samples,) + output_shape
+
+         if save_to_disk:
+             # Validate output path
+             self._check_inference_save_path_valid(output_path)
+
+             # Create HDF5 file
+             h5_file = h5py.File(output_path, 'w')
+             h5_dataset = h5_file.create_dataset(
+                 'predictions',
+                 shape=full_shape,
+                 dtype='float32',
+                 chunks=(batch_size,) + output_shape
+             )
+
+             self.logger.log_info(f"Created HDF5 dataset '{output_path}' to store predictions.")
+
+             return {'type': 'disk', 'file': h5_file, 'dataset': h5_dataset}
+         else:
+             # Allocate memory
+             predictions = np.zeros(full_shape, dtype=np.float32)
+             self.logger.log_info("Created tensor in memory to store predictions.")
+             return {'type': 'memory', 'data': predictions}
+
+     def _store_predictions(self, storage, batch_idx, outputs, batch_size, dataset_len):
+         """Store predictions for a batch."""
+         batch_data = outputs.detach().cpu().numpy()
+
+         start_idx = batch_idx * batch_size
+         end_idx = min((batch_idx + 1) * batch_size, dataset_len)
+
+         if storage['type'] == 'disk':
+             storage['dataset'][start_idx:end_idx] = batch_data
+         else:
+             storage['data'][start_idx:end_idx] = batch_data
+
+     def _finalize_prediction_storage(self, storage, save_to_disk, output_path):
+         """Finalize storage and return result."""
+         if storage['type'] == 'disk':
+             storage['file'].close()
+             self.logger.log_info(f"Saved predictions to '{output_path}'")
+             return None
+         else:
+             return storage['data']
+
+     def _check_inference_save_path_valid(self, inference_filename):
+         """Validate inference output path."""
+         _, file_extension = os.path.splitext(inference_filename)
+
+         if os.path.isdir(inference_filename):
+             raise IsADirectoryError(
+                 f"ERROR: '{os.path.abspath(inference_filename)}' is a directory."
+             )
+
+         if file_extension != '.h5':
+             raise ValueError(
+                 f"ERROR: '{os.path.abspath(inference_filename)}' must have a .h5 extension."
+             )
+
+         if os.path.exists(inference_filename):
+             raise FileExistsError(
+                 f"ERROR: '{os.path.abspath(inference_filename)}' already exists."
+             )
+
+     # ========================================================================
+     # DDP
+     # ========================================================================
+
+     def _wrap_model_ddp(self):
+         """Wrap the model with DistributedDataParallel."""
+         from torch.nn.parallel import DistributedDataParallel
+
+         if not isinstance(self.module.model, DistributedDataParallel):
+             self.module.model = DistributedDataParallel(
+                 self.module.model,
+                 device_ids=[self.gpu_id]
+             )
+             self.module._move_to_device(self.device)
+             self.logger.log_info("Model wrapped in DistributedDataParallel.")
+
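+     # Process-group initialisation is assumed to happen before Engine is
+     # constructed, since __init__ already calls dist.get_rank(). A typical
+     # launch (illustrative, not prescribed by this module):
+     #
+     #     torchrun --nproc_per_node=4 train.py
+     #
+     # with dist.init_process_group("nccl") called in train.py beforehand.
+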
+     # ========================================================================
+     # DEFAULT CALLBACKS
+     # ========================================================================
+
+     def _create_default_callbacks(self):
+         """Create smart default callbacks based on config."""
+         from kito.callbacks.modelcheckpoint import ModelCheckpoint
+         from kito.callbacks.csv_logger import CSVLogger
+         from kito.callbacks.txt_logger import TextLogger
+         from kito.callbacks.tensorboard_callbacks import (
+             TensorBoardScalars,
+             TensorBoardHistograms,
+             TensorBoardGraph
+         )
+         from kito.callbacks.tensorboard_callback_images import SimpleImagePlotter
+
+         callbacks = []
+
+         # Get callbacks config (with defaults if not provided)
+         cb_config = getattr(self.config, 'callbacks', CallbacksConfig())
+
+         # Setup paths
+         timestamp = datetime.datetime.now().strftime("%d%b%Y-%H%M%S")
+         work_dir = Path(os.path.expandvars(self.work_directory))
+         model_name = self.module.model_name
+         train_codename = self.config.model.train_codename
+
+         # === CSV Logger ===
+         if cb_config.enable_csv_logger:
+             csv_dir = work_dir / "logs" / "csv"
+             csv_dir.mkdir(parents=True, exist_ok=True)
+             csv_path = csv_dir / f"{model_name}_{timestamp}_{train_codename}.csv"
+             callbacks.append(CSVLogger(str(csv_path)))
+
+         # === Text Logger ===
+         if cb_config.enable_text_logger:
+             log_dir = work_dir / "logs" / "text"
+             log_dir.mkdir(parents=True, exist_ok=True)
+             log_path = log_dir / f"{model_name}_{timestamp}_{train_codename}.log"
+             callbacks.append(TextLogger(str(log_path)))
+
+         # === Model Checkpoint ===
+         # TODO: consolidate cb_config.enable_model_checkpoint and
+         # config.model.save_model_weights into a single flag
+         if cb_config.enable_model_checkpoint and self.config.model.save_model_weights:
+             weight_dir = work_dir / "weights" / model_name
+             weight_dir.mkdir(parents=True, exist_ok=True)
+             weight_path = weight_dir / f"best_{model_name}_{timestamp}_{train_codename}.pt"
+
+             callbacks.append(
+                 ModelCheckpoint(
+                     filepath=str(weight_path),
+                     monitor=cb_config.checkpoint_monitor,
+                     save_best_only=cb_config.checkpoint_save_best_only,
+                     mode=cb_config.checkpoint_mode,
+                     verbose=cb_config.checkpoint_verbose
+                 )
+             )
+
+         # === TensorBoard ===
+         # TODO: consolidate cb_config.enable_tensorboard and
+         # config.model.log_to_tensorboard into a single flag
+         if cb_config.enable_tensorboard and self.config.model.log_to_tensorboard:
+             tb_dir = work_dir / "logs" / "tensorboard" / model_name / timestamp / train_codename
+             tb_dir.mkdir(parents=True, exist_ok=True)
+
+             # Scalars
+             if cb_config.tensorboard_scalars:
+                 callbacks.append(TensorBoardScalars(str(tb_dir)))
+
+             # Histograms
+             if cb_config.tensorboard_histograms:
+                 callbacks.append(
+                     TensorBoardHistograms(
+                         str(tb_dir),
+                         freq=cb_config.tensorboard_histogram_freq
+                     )
+                 )
+
+             # Model graph
+             if cb_config.tensorboard_graph:
+                 callbacks.append(
+                     TensorBoardGraph(
+                         str(tb_dir),
+                         input_to_model=lambda: self.module.get_sample_input()
+                     )
+                 )
+
+             # Image plotting
+             if cb_config.tensorboard_images:
+                 img_dir = tb_dir / 'images'
+                 callbacks.append(
+                     SimpleImagePlotter(
+                         log_dir=str(img_dir),
+                         tag=getattr(self.config.model, 'tensorboard_img_id', 'images'),
+                         freq=cb_config.tensorboard_image_freq,
+                         batch_indices=cb_config.tensorboard_batch_indices
+                     )
+                 )
+
+         return callbacks
+
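+     # Overriding the defaults via config.callbacks (illustrative sketch; the
+     # attribute names are the ones this method reads from CallbacksConfig):
+     #
+     #     config.callbacks = CallbacksConfig()
+     #     config.callbacks.enable_tensorboard = False
+     #     config.callbacks.checkpoint_monitor = 'val_loss'
+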
+     def get_default_callbacks(self):
+         """
+         Get default callbacks configured from config.callbacks.
+
+         Useful when you want to extend defaults with custom callbacks.
+
+         Returns:
+             List of callback instances
+
+         Example:
+             >>> # Get defaults
+             >>> callbacks = engine.get_default_callbacks()
+             >>>
+             >>> # Modify or extend
+             >>> callbacks.append(MyCustomCallback(param=10))
+             >>>
+             >>> # Or modify existing
+             >>> for cb in callbacks:
+             ...     if isinstance(cb, ModelCheckpoint):
+             ...         cb.verbose = False
+             >>>
+             >>> # Use them
+             >>> engine.fit(train_loader, val_loader, callbacks=callbacks)
+         """
+         return self._create_default_callbacks()
+
+     # ========================================================================
+     # WEIGHT LOADING (Convenience method)
+     # ========================================================================
+
+     def load_weights(self, weight_path: str, strict: bool = True):
+         """
+         Load weights into the module.
+
+         Convenience method that calls module.load_weights().
+
+         Args:
+             weight_path: Path to the weight file
+             strict: If True, require the checkpoint keys to match the model's state dict exactly
+         """
+         self.module.load_weights(weight_path, strict)
+         self.logger.log_info(f"Loaded weights from '{weight_path}' successfully.")
+
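+     # Typical explicit inference flow using this helper (illustrative):
+     #
+     #     engine.load_weights('weights/best_MyModel.pt')
+     #     predictions = engine.predict(test_loader)
+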
+     # ========================================================================
+     # LOGGING
+     # ========================================================================
+
+     def _log_training_info(self, max_epochs, num_callbacks):
+         """Log training configuration."""
+         attrib_to_log = {
+             'model_name': self.module.model_name,
+             'optimizer': self.module.optimizer.__class__.__name__,
+             'batch_size': self.module.batch_size,
+             'n_train_epochs': max_epochs,
+             'learning_rate': self.module.learning_rate,
+             'work_directory': self.work_directory,
+             'save_model_weights': self.config.model.save_model_weights,
+             'distributed_training': self.distributed_training,
+             'callbacks': num_callbacks
+         }
+
+         self.logger.log_info(
+             'Training configured with the following parameters:\n\n  ' +
+             '\n  '.join(f"{k} -> {v}" for k, v in attrib_to_log.items()) + '\n'
+         )
+
+     def _log_inference_info(self, num_batches, save_to_disk, output_path):
+         """Log inference configuration."""
+         self.logger.log_info(
+             f"\nRunning inference for {self.module.model_name}\n"
+             f"Batches: {num_batches}\n"
+             f"Device: {self.device}\n"
+             f"Save to disk: {save_to_disk}\n"
+             + (f"Output: {output_path}\n" if save_to_disk else "Output: In-memory\n")
+         )