pytorch_kito-0.2.0-py3-none-any.whl

kito/module.py ADDED
@@ -0,0 +1,447 @@
+ import os
+
+ import torch
+ from packaging.version import parse
+ from torchsummary import summary as model_summary
+
+ from kito.config.moduleconfig import KitoModuleConfig
+
+
+ class KitoModule:
+     """
+     Base module for PyTorch models.
+
+     Focused on model definition and single-batch operations.
+     The Engine handles all iteration, callbacks, and orchestration.
+
+     Usage:
+         class MyModel(KitoModule):
+             def build_inner_model(self):
+                 self.model = nn.Sequential(...)
+                 self.model_input_size = (3, 64, 64)
+                 self.standard_data_shape = (3, 64, 64)  # For inference
+
+             def bind_optimizer(self):
+                 self.optimizer = torch.optim.Adam(
+                     self.model.parameters(),
+                     lr=self.learning_rate
+                 )
+
+         # Use with Engine (the Engine assigns the device)
+         module = MyModel('MyModel', config)
+         module.build()
+         module.associate_optimizer()
+
+         engine = Engine(module)
+         engine.fit(train_loader, val_loader, max_epochs=100)
+     """
+
+     def __init__(self, model_name: str, config: KitoModuleConfig = None):
+         """
+         Initialize KitoModule.
+
+         Args:
+             model_name: Name of the model
+             config: Optional config object for future extensibility
+         """
+         self.model_name = model_name
+         self.config = config
+
+         # Extract useful config values if provided
+         if config is not None:
+             self.learning_rate = config.training.learning_rate
+             self.batch_size = config.training.batch_size
+         else:
+             self.learning_rate = None
+             self.batch_size = None
+
+         # Model components
+         self.model = None
+         self.device = None  # set by Engine
+         self.model_input_size = None
+         self.standard_data_shape = None  # For inference (set by subclass)
+         self.optimizer = None
+
+         # Loss function (set by subclass or from config)
+         if config is not None:
+             from kito.utils.loss_utils import get_loss
+             self.loss = get_loss(config.model.loss)
+         else:
+             self.loss = None
+
+         # State flags
+         self._model_built = False
+         self._optimizer_bound = False
+         self._weights_loaded = False
+
+     # ========================================================================
+     # ABSTRACT METHODS - Must be implemented by subclasses
+     # ========================================================================
+
+     def build_inner_model(self, *args, **kwargs):
+         """
+         Build the model architecture.
+
+         Must set:
+             - self.model: The PyTorch model
+             - self.model_input_size: Tuple of input shape (C, H, W) or (C, H, W, D)
+             - self.standard_data_shape: Output shape for inference (optional)
+
+         Example:
+             def build_inner_model(self):
+                 self.model = nn.Sequential(
+                     nn.Conv2d(3, 64, 3, padding=1),
+                     nn.ReLU(),
+                     nn.Conv2d(64, 3, 3, padding=1)
+                 )
+                 self.model_input_size = (3, 64, 64)
+                 self.standard_data_shape = (3, 64, 64)
+         """
+         raise NotImplementedError("Subclasses must implement build_inner_model().")
+
+     def bind_optimizer(self, *args, **kwargs):
+         """
+         Set up the optimizer.
+
+         Must set:
+             - self.optimizer: The PyTorch optimizer
+
+         Example:
+             def bind_optimizer(self):
+                 self.optimizer = torch.optim.Adam(
+                     self.model.parameters(),
+                     lr=self.learning_rate
+                 )
+         """
+         raise NotImplementedError("Subclasses must implement bind_optimizer().")
+
+     def _check_data_shape(self):
+         """
+         Check data shape on first batch.
+
+         Optional - implement if you need to validate input shape.
+         Called by Engine on first training batch.
+
+         Example:
+             def _check_data_shape(self):
+                 # Validate that data matches expected shape
+                 pass
+         """
+         pass  # Default: no checking
+
+     # ========================================================================
+     # SETUP METHODS
+     # ========================================================================
+
+     def _move_to_device(self, device: torch.device):
+         """
+         Internal method called by Engine to move model to device.
+
+         Called AFTER build() by Engine.
+         """
+         self.device = device
+         if self.model is not None:
+             self.model.to(device)
+
+     def build(self, *args, **kwargs):
+         """Build the model (the Engine then moves it to the device)."""
+         self.build_inner_model(*args, **kwargs)
+         self._model_built = True
+
+     def associate_optimizer(self, *args, **kwargs):
+         """Set up the optimizer."""
+         if not self._model_built:
+             raise RuntimeError("Must call build() before associate_optimizer()")
+         self.bind_optimizer(*args, **kwargs)
+         self._optimizer_bound = True
+
+     # ========================================================================
+     # SINGLE-BATCH OPERATIONS (Called by Engine)
+     # ========================================================================
+
+     def training_step(self, batch, pbar_handler=None):
+         """
+         Perform one training step on a single batch.
+
+         Called by Engine for each training batch.
+
+         Args:
+             batch: Tuple of (inputs, targets) from DataLoader
+             pbar_handler: Progress bar handler (optional, provided by Engine)
+
+         Returns:
+             dict: {'loss': tensor} - Must contain at least 'loss'
+
+         Override this for custom training logic (freezing layers, gradient accumulation, etc.).
+
+         Default implementation:
+             1. Move data to device
+             2. Zero gradients
+             3. Forward pass
+             4. Compute loss
+             5. Backward pass
+             6. Optimizer step
+         """
+         inputs, targets = batch
+
+         # Move to device
+         inputs = self.send_data_to_device(inputs)
+         targets = targets.to(self.device)
+
+         # Zero gradients
+         self.optimizer.zero_grad()
+
+         # Forward pass
+         outputs = self.pass_data_through_model(inputs)
+
+         # Compute loss
+         loss = self.compute_loss((inputs, targets), outputs)
+
+         # Backward pass
+         loss.backward()
+
+         # Optimizer step
+         self.optimizer.step()
+
+         return {'loss': loss}
+
+     def validation_step(self, batch, pbar_handler=None):
+         """
+         Perform one validation step on a single batch.
+
+         Called by Engine for each validation batch.
+
+         Args:
+             batch: Tuple of (inputs, targets) from DataLoader
+             pbar_handler: Progress bar handler (optional, provided by Engine)
+
+         Returns:
+             dict: Must contain 'loss', 'outputs', 'inputs', 'targets'
+
+         Override this for custom validation logic.
+         """
+         inputs, targets = batch
+
+         # Move to device
+         inputs = self.send_data_to_device(inputs)
+         targets = targets.to(self.device)
+
+         # Forward pass (no gradients)
+         outputs = self.pass_data_through_model(inputs)
+
+         # Compute loss
+         loss = self.compute_loss((inputs, targets), outputs)
+
+         return {
+             'loss': loss,
+             'outputs': outputs,
+             'targets': targets,
+             'inputs': inputs
+         }
+
+     def prediction_step(self, batch, pbar_handler=None):
+         """
+         Perform one prediction step on a single batch.
+
+         Called by Engine for each inference batch.
+
+         Args:
+             batch: Input data from DataLoader (can be tuple or tensor)
+             pbar_handler: Progress bar handler (optional, provided by Engine)
+
+         Returns:
+             tensor: Model predictions
+
+         Override this for custom prediction logic.
+         """
+         # Handle different batch formats
+         if isinstance(batch, (tuple, list)):
+             inputs = batch[0] if len(batch) > 0 else batch
+         elif isinstance(batch, torch.Tensor):
+             inputs = batch
+         elif isinstance(batch, dict):
+             # Handle dict batches (common in HuggingFace-style loaders);
+             # falls back to the whole dict if no known key is present
+             inputs = batch.get('input', batch.get('data', batch))
+         else:
+             raise TypeError(f"Unsupported batch type: {type(batch)}")
+
+         # Move to device
+         inputs = self.send_data_to_device(inputs)
+
+         # Forward pass
+         outputs = self.pass_data_through_model(inputs)
+
+         # Handle model outputs (for multi-output scenarios)
+         outputs = self.handle_model_outputs(outputs)
+
+         return outputs
+
+     # ========================================================================
+     # CORE OPERATIONS (Can be overridden)
+     # ========================================================================
+
+     def pass_data_through_model(self, data):
+         """
+         Forward pass through model.
+
+         Override for multi-input models.
+
+         Args:
+             data: Input tensor(s)
+
+         Returns:
+             Output tensor(s)
+         """
+         return self.model(data)
+
+     def compute_loss(self, data_pair, y_pred, **kwargs):
+         """
+         Compute loss from data pair and predictions.
+
+         Override for custom loss computation or multi-output scenarios.
+
+         Args:
+             data_pair: Tuple of (inputs, targets)
+             y_pred: Model predictions
+             **kwargs: Additional arguments (e.g., epoch)
+
+         Returns:
+             Loss tensor
+         """
+         y_true = data_pair[1].to(self.device)
+         return self.apply_loss(y_pred, y_true, **kwargs)
+
+     def apply_loss(self, y_pred, y_true, **kwargs):
+         """
+         Apply loss function.
+
+         Override for custom loss scenarios.
+
+         Args:
+             y_pred: Model predictions
+             y_true: Ground truth
+             **kwargs: Additional arguments
+
+         Returns:
+             Loss value
+         """
+         return self.loss(y_pred, y_true)
+
+     def send_data_to_device(self, data):
+         """
+         Move data to device.
+
+         Override for complex data structures (dict, nested tuples, etc.).
+
+         Args:
+             data: Data to move
+
+         Returns:
+             Data on device
+         """
+         return data.to(self.device)
+
+     def handle_model_outputs(self, outputs):
+         """
+         Handle model outputs.
+
+         Override for multi-output scenarios.
+
+         Args:
+             outputs: Raw model outputs
+
+         Returns:
+             Processed outputs
+         """
+         return outputs
+
+     # ========================================================================
+     # WEIGHTS (Model-specific operations)
+     # ========================================================================
+
+     def load_weights(self, weight_path: str, strict: bool = True):
+         """
+         Load model weights.
+
+         Called by Engine but can be overridden for custom loading logic.
+
+         Args:
+             weight_path: Path to weight file
+             strict: Strict state dict loading
+         """
+         if not os.path.exists(weight_path):
+             raise FileNotFoundError(f"Weight file not found: {weight_path}")
+
+         # Check file extension
+         _, file_extension = os.path.splitext(weight_path)
+         if file_extension != '.pt':
+             raise ValueError(f"Invalid weight file: {weight_path}. Must be a .pt file.")
+
+         # Load weights
+         state_dict = torch.load(weight_path, map_location=self.device, weights_only=True)
+         self.model.load_state_dict(state_dict, strict=strict)
+
+         self._weights_loaded = True
+
+     def save_weights(self, weight_path: str):
+         """
+         Save model weights.
+
+         Args:
+             weight_path: Path to save weights
+         """
+         os.makedirs(os.path.dirname(weight_path) or '.', exist_ok=True)
+         torch.save(self.model.state_dict(), weight_path)
+
+     # ========================================================================
+     # UTILITIES
+     # ========================================================================
+
+     def get_sample_input(self):
+         """Get sample input tensor (for summaries, TensorBoard graphs, etc.)."""
+         if self.model_input_size is None:
+             raise ValueError("model_input_size not set. Call build() first.")
+         return torch.randn(1, *self.model_input_size).to(self.device)
+
+     def set_model_input_size(self, *args, **kwargs):
+         """
+         Hook for setting model input size in complex scenarios.
+
+         Optional - most models set this in build_inner_model().
+         """
+         raise NotImplementedError("Subclasses can implement set_model_input_size() if needed.")
+
+     def summary(self, summary_depth: int = 3):
+         """Print model summary."""
+         if self.model_input_size is None:
+             raise ValueError("model_input_size not set. Call build() first.")
+
+         torch_version = torch.__version__.split('+')[0]
+         if parse(torch_version) < parse('2.6.0'):
+             model_summary(self.model, self.model_input_size, batch_dim=0, depth=summary_depth)
+         else:
+             model_summary(self.model, self.model_input_size)
+
+     # ========================================================================
+     # STATE PROPERTIES
+     # ========================================================================
+
+     @property
+     def is_built(self):
+         """Check if model is built."""
+         return self._model_built
+
+     @property
+     def is_optimizer_set(self):
+         """Check if optimizer is set."""
+         return self._optimizer_bound
+
+     @property
+     def is_weights_loaded(self):
+         """Check if weights are loaded."""
+         return self._weights_loaded
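
For orientation, a minimal sketch of the single-batch contract this class defines, exercised without the Engine. TinyDenoiser and the direct call to _move_to_device are illustrative only (the docstrings above say the Engine normally assigns the device); without a config, loss and the learning rate must be supplied by hand:

import torch
import torch.nn as nn

from kito.module import KitoModule


class TinyDenoiser(KitoModule):
    # Hypothetical subclass, for illustration only
    def build_inner_model(self):
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 3, 3, padding=1),
        )
        self.model_input_size = (3, 64, 64)
        self.standard_data_shape = (3, 64, 64)

    def bind_optimizer(self):
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)


module = TinyDenoiser('TinyDenoiser')        # no config: loss set manually below
module.loss = nn.MSELoss()
module.build()
module._move_to_device(torch.device('cpu'))  # normally done by the Engine
module.associate_optimizer()

# One (inputs, targets) batch, as training_step expects
batch = (torch.randn(4, 3, 64, 64), torch.randn(4, 3, 64, 64))
print(module.training_step(batch)['loss'].item())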
File without changes
@@ -0,0 +1,51 @@
+ import logging
+
+ import torch
+
+
+ class BaseLogger:
+     def log_info(self, msg):
+         raise NotImplementedError
+
+     def log_warning(self, msg):
+         raise NotImplementedError
+
+     def log_error(self, msg):
+         raise NotImplementedError
+
+
+ class DefaultLogger(BaseLogger):
+     def __init__(self, log_level=logging.INFO):
+         self.logger = logging.getLogger(__name__)
+         self.logger.setLevel(log_level)
+         handler = logging.StreamHandler()
+         handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
+         self.logger.addHandler(handler)
+
+     def log_info(self, msg):
+         self.logger.info(msg)
+
+     def log_warning(self, msg):
+         self.logger.warning(msg)
+
+     def log_error(self, msg):
+         self.logger.error(msg)
+
+
+ class DDPLogger(DefaultLogger):
+     def __init__(self, log_level=logging.INFO):
+         super().__init__(log_level)
+         # Only log from rank 0
+         self.is_driver = torch.distributed.get_rank() == 0
+
+     def log_info(self, msg):
+         if self.is_driver:
+             super().log_info(msg)
+
+     def log_warning(self, msg):
+         if self.is_driver:
+             super().log_warning(msg)
+
+     def log_error(self, msg):
+         if self.is_driver:
+             super().log_error(msg)
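
A quick usage sketch for these loggers; the DDP variant requires torch.distributed to be initialized first (the init_process_group call shown is the standard PyTorch entry point, not something this file does for you):

import logging

logger = DefaultLogger(log_level=logging.DEBUG)
logger.log_info("starting training")
logger.log_warning("falling back to CPU")

# In a DDP run, only rank 0 emits messages:
# torch.distributed.init_process_group(backend="nccl")
# ddp_logger = DDPLogger()
# ddp_logger.log_info("logged once, from the driver process")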
@@ -0,0 +1,57 @@
+ from abc import ABC
+
+ import pkbar
+ import torch
+
+
+ class BaseProgressBarHandler(ABC):
+
+     def init(self, n_target_elements, verbosity_level, message=None):
+         pass
+
+     def step(self, i, values):
+         pass
+
+     def finalize(self, i, values):
+         pass
+
+
+ class StandardProgressBarHandler(BaseProgressBarHandler):
+     def __init__(self):
+         self.progress_bar = None
+         # print("Epoch {}/{}".format(epoch + 1, self.n_train_epochs))  # might be replaced by logger
+
+     def init(self, n_target_elements, verbosity_level, message=None):
+         if message is not None:
+             print(message)
+         self.progress_bar = pkbar.Kbar(target=n_target_elements, always_stateful=False, width=25,
+                                        verbose=verbosity_level)
+         self.progress_bar.update(0, None)
+
+     def step(self, i, values):
+         self.progress_bar.update(i, values)
+
+     def finalize(self, i, values):
+         self.progress_bar.add(i, values)
+
+
+ class DDPProgressBarHandler(BaseProgressBarHandler):
+     def __init__(self):
+         self.progress_bar = None
+         self.is_driver = (torch.distributed.get_rank() == 0)
+         # print("Epoch {}/{}".format(epoch + 1, self.n_train_epochs))  # might be replaced by logger
+
+     def init(self, n_target_elements, verbosity_level, message=None):
+         if self.is_driver:
+             if message is not None:
+                 print(message)
+             self.progress_bar = pkbar.Kbar(target=n_target_elements, always_stateful=False, width=25,
+                                            verbose=verbosity_level)
+
+     def step(self, i, values):
+         if self.is_driver:
+             self.progress_bar.update(i, values)
+
+     def finalize(self, i, values):
+         if self.is_driver:
+             self.progress_bar.add(i, values)
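
A sketch of the handler lifecycle as the Engine presumably drives it (init once per epoch, step per batch, finalize at the end); the loop and loss values here are made up, and values follows pkbar's list-of-(name, value) convention:

handler = StandardProgressBarHandler()
handler.init(n_target_elements=100, verbosity_level=1, message="Epoch 1/10")
for i in range(100):
    loss = 1.0 / (i + 1)  # placeholder metric
    handler.step(i + 1, [("loss", loss)])
handler.finalize(1, [("val_loss", 0.05)])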
@@ -0,0 +1,85 @@
+ class ReadinessValidator:
+     """
+     Validates module readiness for different operations.
+
+     This replaces the decorator pattern with a cleaner strategy pattern
+     that is easier to test and extend.
+
+     Usage:
+         # In Engine
+         ReadinessValidator.check_for_training(module)
+         ReadinessValidator.check_for_inference(module, weight_path)
+     """
+
+     @staticmethod
+     def check_for_training(module):
+         """
+         Check if module is ready for training.
+
+         Args:
+             module: KitoModule instance
+
+         Raises:
+             RuntimeError: If module is not ready
+         """
+         if not module.is_built:
+             raise RuntimeError(
+                 f"Module '{module.model_name}' not built. "
+                 "Call module.build() before training."
+             )
+
+         if not module.is_optimizer_set:
+             raise RuntimeError(
+                 f"Module '{module.model_name}' optimizer not set. "
+                 "Call module.associate_optimizer() before training."
+             )
+
+         if module.learning_rate is None:
+             raise RuntimeError(
+                 f"Module '{module.model_name}' learning_rate not set."
+             )
+
+     @staticmethod
+     def check_for_inference(module, weight_path=None):
+         """
+         Check if module is ready for inference.
+
+         Args:
+             module: KitoModule instance
+             weight_path: Optional weight path to check
+
+         Raises:
+             RuntimeError: If module is not ready
+         """
+         if not module.is_built:
+             raise RuntimeError(
+                 f"Module '{module.model_name}' not built. "
+                 "Call module.build() before inference."
+             )
+
+         if not module.is_weights_loaded:
+             raise RuntimeError(
+                 f"Module '{module.model_name}' weights not loaded. "
+                 "Call module.load_weights() or engine.load_weights() before inference."
+             )
+
+         if weight_path is not None:
+             import os
+             if not os.path.exists(weight_path):
+                 raise FileNotFoundError(f"Weight file not found: {weight_path}")
+
+     @staticmethod
+     def check_data_loaders(train_loader=None, val_loader=None, test_loader=None):
+         """
+         Check if data loaders are provided.
+
+         Args:
+             train_loader: Training data loader
+             val_loader: Validation data loader
+             test_loader: Test data loader
+
+         Raises:
+             ValueError: If required loaders are missing
+         """
+         if train_loader is None and val_loader is None and test_loader is None:
+             raise ValueError("At least one data loader must be provided")
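
To illustrate the strategy pattern the docstring describes, here is a sketch of how an Engine method might fail fast before touching data; the fit signature and self.module attribute are assumed names, not confirmed Engine API:

def fit(self, train_loader, val_loader=None, max_epochs=10):
    # Validate everything up front, before the first batch is loaded
    ReadinessValidator.check_for_training(self.module)
    ReadinessValidator.check_data_loaders(train_loader=train_loader, val_loader=val_loader)
    for epoch in range(max_epochs):
        for batch in train_loader:
            self.module.training_step(batch)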
kito/utils/__init__.py ADDED
File without changes
@@ -0,0 +1,45 @@
+ from functools import wraps
+
+
+ def require_mode(mode: str):
+     """
+     Decorator to enforce train_mode compatibility in Engine methods.
+
+     Validates that the requested operation (train or inference) matches
+     the config.training.train_mode setting.
+
+     Args:
+         mode: Either 'train' or 'inference'
+
+     Raises:
+         RuntimeError: If the operation conflicts with the train_mode setting
+         ValueError: If mode is not 'train' or 'inference'
+     """
+     if mode not in ('train', 'inference'):
+         raise ValueError(f"Invalid mode: {mode}. Must be 'train' or 'inference'.")
+
+     def decorator(func):
+         @wraps(func)
+         def wrapper(self, *args, **kwargs):
+             # Get train_mode from config (default True for backward compatibility)
+             config_mode = getattr(self.config.training, 'train_mode', True)
+
+             # Check compatibility
+             if mode == 'train' and not config_mode:
+                 raise RuntimeError(
+                     "Training not allowed: config.training.train_mode is set to False (inference mode).\n"
+                     "To enable training, set config.training.train_mode=True in your config."
+                 )
+
+             if mode == 'inference' and config_mode:
+                 raise RuntimeError(
+                     "Inference not allowed: config.training.train_mode is set to True (training mode).\n"
+                     "To enable inference, set config.training.train_mode=False in your config."
+                 )
+
+             # Mode is compatible - execute the original function
+             return func(self, *args, **kwargs)
+
+         return wrapper
+
+     return decorator
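
A sketch of the decorator applied inside a hypothetical Engine; the wrapper only requires that self.config.training.train_mode exists, and the fit/predict names here are illustrative:

class Engine:
    def __init__(self, module, config):
        self.module = module
        self.config = config  # must expose config.training.train_mode

    @require_mode('train')
    def fit(self, train_loader, val_loader=None, max_epochs=10):
        ...  # training loop; raises if train_mode is False

    @require_mode('inference')
    def predict(self, test_loader):
        ...  # inference loop; raises if train_mode is True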