torch-tk 1.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torch_tk/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ from importlib.metadata import PackageNotFoundError, version
2
+
3
_DIST_NAME = "torch-tk"

# Look up the installed version. The published distribution name is tried
# first and the import package name second; when neither has installed
# metadata, a local placeholder version is used instead.
for _candidate in (_DIST_NAME, __package__ or __name__.split(".", 1)[0]):
    try:
        __version__ = version(_candidate)
    except PackageNotFoundError:
        continue
    break
else:
    __version__ = "0.0.0+local"
13
+
14
+ from .models.model import Model
15
+ from .optimizers.sgd import SGD
16
+ from .optimizers.adam import Adam
17
+ from .training.trainer import Trainer
18
+ from .checkpoints.checkpoint_manager import CheckPointManager
19
+ from .diagnostics.diagnostics import Diagnostics
20
+
21
+ __all__ = [
22
+ "__version__",
23
+ "Model",
24
+ "SGD",
25
+ "Adam",
26
+ "Trainer",
27
+ "CheckPointManager",
28
+ "Diagnostics",
29
+ ]
@@ -0,0 +1 @@
1
+ from .checkpoint_manager import CheckPointManager
@@ -0,0 +1,139 @@
1
+ '''
2
+ Checkpoint utilities for saving and restoring self-describing model and optimizer state.
3
+
4
+ This module provides the CheckPointManager class, which saves checkpoints that
5
+ contain enough information to reconstruct both a model and its optimizer in the
6
+ state that created the checkpoint. Each checkpoint stores the fully qualified
7
+ class path, constructor arguments from constructor_dict(), and state from
8
+ state_dict() for both objects.
9
+
10
+ Models and optimizers must be importable from a stable class path and must expose:
11
+
12
+ - constructor_dict()
13
+ - state_dict()
14
+ - load_state_dict()
15
+
16
+ A model must be reconstructible from its class path and constructor_dict().
17
+ An optimizer must be reconstructible from its class path, model.parameters(),
18
+ and constructor_dict().
19
+
20
+ The constructor data must be serializable. In practice, this means it should
21
+ contain only standard serializable Python values such as numbers, strings,
22
+ lists, tuples, and dictionaries.
23
+
24
+ This design is not suitable for optimizers that depend on non-serializable
25
+ constructor inputs, non-standard constructor signatures, custom parameter-group
26
+ reconstruction beyond model.parameters(), or runtime state not captured by
27
+ state_dict().
28
+ '''
29
+
30
+ from pathlib import Path
31
+
32
+ import torch
33
+
34
+ from .utils import class_path_of_instance, import_class
35
+
36
+
37
class CheckPointManager:
    '''
    Persist and restore a model/optimizer pair as a self-describing checkpoint.

    Every checkpoint file records the epoch and, for both the model and the
    optimizer, the class path, the constructor arguments, and the state
    dictionary.
    '''

    def __init__(self, model, optimizer, directory):
        '''
        Remember the model, the optimizer, and the checkpoint directory.

        The directory is stored as a Path. It is not created here; save()
        creates it before writing a checkpoint.
        '''
        self.model = model
        self.optimizer = optimizer
        self.directory = directory if isinstance(directory, Path) else Path(directory)

    @staticmethod
    def _unpack_constructor_dict(constructor_dict):
        '''
        Split a constructor dictionary into positional and keyword arguments.

        The keyword arguments are returned as a fresh dict so that callers can
        modify them without touching the loaded checkpoint.
        '''
        return constructor_dict.get('args', []), dict(constructor_dict.get('kwargs', {}))

    def save(self, epoch):
        '''
        Write a checkpoint describing the current model and optimizer.

        Raises TypeError if the model or the optimizer lacks constructor_dict().

        Returns the path of the checkpoint file that was written.
        '''
        requirements = (
            (self.model, 'Model must implement constructor_dict().'),
            (self.optimizer, 'Optimizer must implement constructor_dict().'),
        )
        for candidate, message in requirements:
            if not hasattr(candidate, 'constructor_dict'):
                raise TypeError(message)

        model, optimizer = self.model, self.optimizer

        checkpoint = {'epoch': epoch}
        checkpoint['model_class_path'] = class_path_of_instance(model)
        checkpoint['model_constructor_dict'] = model.constructor_dict()
        checkpoint['model_state_dict'] = model.state_dict()
        checkpoint['optimizer_class_path'] = class_path_of_instance(optimizer)
        checkpoint['optimizer_constructor_dict'] = optimizer.constructor_dict()
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()

        # File name pattern: <ModelClass>.<OptimizerClass>.epoch=<epoch>.pt
        file_name = f'{type(model).__name__}.{type(optimizer).__name__}.epoch={epoch!s}.pt'

        self.directory.mkdir(parents=True, exist_ok=True)
        target = self.directory / file_name
        torch.save(checkpoint, target)
        return target

    @classmethod
    def load_from_file(cls, file_path, device=None):
        '''
        Rebuild the model, the optimizer, and a manager from a checkpoint file.

        Returns:
        checkpoint_manager, model, optimizer, epoch

        When device is given, the checkpoint is mapped onto that device, a saved
        model keyword argument named 'device' is replaced by it, and the rebuilt
        model is moved to it. The returned manager uses the directory that
        contains the checkpoint file.
        '''
        source = file_path if isinstance(file_path, Path) else Path(file_path)

        # NOTE(review): the file is deserialized by torch.load and the class paths
        # stored in it are imported and instantiated below, so only checkpoints
        # from trusted sources should be loaded.
        checkpoint = torch.load(source, map_location=device)

        # Model: import the class, rebuild it from its constructor data, restore state.
        model_class = import_class(checkpoint['model_class_path'])
        model_args, model_kwargs = cls._unpack_constructor_dict(checkpoint['model_constructor_dict'])
        if device is not None and 'device' in model_kwargs:
            model_kwargs['device'] = device

        model = model_class(*model_args, **model_kwargs)
        model.load_state_dict(checkpoint['model_state_dict'])
        if device is not None:
            model = model.to(device)

        # Optimizer: rebuilt on the parameters of the restored model.
        optimizer_class = import_class(checkpoint['optimizer_class_path'])
        optimizer_args, optimizer_kwargs = cls._unpack_constructor_dict(checkpoint['optimizer_constructor_dict'])
        optimizer = optimizer_class(model.parameters(), *optimizer_args, **optimizer_kwargs)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        return cls(model, optimizer, source.parent), model, optimizer, checkpoint['epoch']
@@ -0,0 +1,33 @@
1
+ import importlib
2
+
3
+
4
def class_path_of_instance(instance):
    '''
    Build the dotted "module.QualifiedName" path of the class of an instance.
    '''
    klass = type(instance)
    return '.'.join((klass.__module__, klass.__qualname__))
10
+
11
+
12
def import_class(class_path):
    '''
    Import and return the object named by a fully qualified class path.

    The longest importable module prefix of class_path is imported and the
    remaining dotted components are resolved with getattr(), so nested classes
    referenced through __qualname__ are supported.

    Raises
    ------
    ImportError
        If no prefix of class_path names an importable module.
    ModuleNotFoundError
        If a module on the path exists but fails to import because one of its
        own dependencies is missing. The original error is propagated so the
        real cause is not hidden.
    AttributeError
        If a component after the module prefix is not an attribute of the
        object resolved so far.
    '''
    parts = class_path.split('.')

    for split_at in range(len(parts), 0, -1):
        module_name = '.'.join(parts[:split_at])
        try:
            obj = importlib.import_module(module_name)
        except ModuleNotFoundError as exc:
            missing = exc.name
            # Fall back to a shorter prefix only when the module that could not
            # be found is part of the requested path. A different missing name
            # means an existing module failed on one of its own imports, and
            # that failure must surface unchanged.
            on_requested_path = (
                missing is None
                or missing == module_name
                or module_name.startswith(missing + '.')
            )
            if not on_requested_path:
                raise
            continue
        break
    else:
        raise ImportError(f'Could not import any module prefix from {class_path!r}')

    # Whatever follows the module prefix is an attribute chain (class, nested class, ...).
    for attr in parts[split_at:]:
        obj = getattr(obj, attr)

    return obj
@@ -0,0 +1,3 @@
1
+ from .diagnostics import Diagnostics
2
+ from .loss import model_worst_loss, per_sample_loss_from_data, per_sample_loss_from_data_loader
3
+ from .plotting import plot_diagnostics
@@ -0,0 +1,281 @@
1
+ '''
2
+ Utilities for storing, combining, and serializing sample-resolved training
3
+ diagnostics.
4
+
5
+ This module defines the Diagnostics class, which can be created from in-memory
6
+ data, a data loader, or a saved netCDF file. It stores per-sample loss values
7
+ together with model and optimizer metadata and can write the diagnostics back
8
+ to netCDF.
9
+ '''
10
+
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import torch
15
+ import xarray as xr
16
+
17
+ from .loss import per_sample_loss_from_data, per_sample_loss_from_data_loader
18
+
19
+
20
class Diagnostics:
    '''
    Store sample-resolved loss diagnostics together with training metadata.

    An instance holds the model and optimizer names, the learning rate (kept
    as a string), the batch size, an optional description, an array of epoch
    values, and a per-sample loss tensor whose leading dimension corresponds
    to epoch. Instances can be combined with ``+`` and can be written to and
    read from netCDF.
    '''

    @classmethod
    def from_data_loader(
        cls,
        model,
        loss_function_sample_resolved,
        optimizer,
        learning_rate,
        batch_size,
        data_loader,
        description: str = None,
        epoch=0,
    ):
        '''
        Create a Diagnostics instance from a model evaluated on a data loader.

        Per-sample losses are computed with per_sample_loss_from_data_loader()
        and stored together with the class names of the model and optimizer
        and the given training metadata.
        '''
        _, per_sample_loss = per_sample_loss_from_data_loader(
            model, loss_function_sample_resolved, data_loader
        )

        return cls(
            type(model).__name__,
            type(optimizer).__name__,
            learning_rate,
            batch_size,
            epoch=epoch,
            per_sample_loss=per_sample_loss,
            description=description,
        )

    @classmethod
    def from_data(
        cls,
        model,
        loss_function_sample_resolved,
        optimizer,
        learning_rate,
        batch_size,
        x_data,
        y_data,
        description: str = None,
        epoch=0,
        chunk_size=None,
    ):
        '''
        Create a Diagnostics instance from in-memory input and target tensors.

        Per-sample losses are computed with per_sample_loss_from_data(),
        optionally in chunks of chunk_size samples, and stored together with
        the class names of the model and optimizer and the given training
        metadata.
        '''
        _, per_sample_loss = per_sample_loss_from_data(
            model, loss_function_sample_resolved, x_data, y_data, chunk_size=chunk_size
        )

        return cls(
            type(model).__name__,
            type(optimizer).__name__,
            learning_rate,
            batch_size,
            epoch=epoch,
            per_sample_loss=per_sample_loss,
            description=description,
        )

    @classmethod
    def from_netcdf(cls, path):
        '''
        Create a Diagnostics instance from a saved netCDF file.

        Metadata is read from the dataset attributes, and the epoch values and
        per-sample losses from the 'epoch' and 'per_sample_loss' variables. A
        missing 'description' attribute is read as None, which matches files
        written by to_netcdf() for instances without a description.
        '''
        path = Path(path)

        # The context manager closes the dataset even if construction fails.
        with xr.open_dataset(path) as ds:
            return cls(
                ds.attrs['model'],
                ds.attrs['optimizer'],
                ds.attrs['learning_rate'],
                ds.attrs['batch_size'],
                description=ds.attrs.get('description'),
                per_sample_loss=torch.as_tensor(ds['per_sample_loss'].values),
                epoch=ds['epoch'].values,
            )

    def __init__(
        self, model_name, optimizer_name, learning_rate, batch_size, epoch=None, per_sample_loss=None, description: str = None
    ):
        '''
        Initialize a Diagnostics instance.

        Epoch values and per-sample losses must either both be provided or both
        be omitted. The learning rate is stored as a string and epoch values
        as a one-dimensional int64 array. If per-sample loss is
        one-dimensional, it is promoted to two dimensions so that the leading
        dimension corresponds to epoch.

        Raises ValueError if only one of epoch and per_sample_loss is given, or
        if the number of epochs differs from the leading dimension of the
        per-sample loss.
        '''
        if (epoch is None) != (per_sample_loss is None):
            raise ValueError('epoch and per_sample_loss must both be None or both not None.')

        self.model = model_name
        self.optimizer = optimizer_name

        self.learning_rate = str(learning_rate)
        self.batch_size = batch_size
        self.description = description

        if epoch is None:
            self.epoch = None
            self.per_sample_loss = None
            return

        self.epoch = np.atleast_1d(np.asarray(epoch, dtype=np.int64))

        # Add a leading dimension that corresponds to the value(s) in self.epoch
        if per_sample_loss.ndim == 1:
            per_sample_loss = per_sample_loss.unsqueeze(0)
        self.per_sample_loss = per_sample_loss

        # Check for consistency
        if len(self.epoch) != self.per_sample_loss.shape[0]:
            raise ValueError('Number of epochs and number of per-sample loss instances do not match.')

    def __add__(self, other):
        '''
        Combine two compatible diagnostics objects.

        The two objects must have matching metadata. When both contain
        per-sample loss data, their epoch arrays and per-sample loss tensors
        are concatenated along the epoch dimension. When only one contains
        data, the result carries that object's data.

        Returns NotImplemented if other is not a Diagnostics instance and
        raises AssertionError if any metadata field differs.
        '''
        if not isinstance(other, Diagnostics):
            return NotImplemented

        comparisons = (
            ('models', self.model, other.model),
            ('optimizers', self.optimizer, other.optimizer),
            ('learning rates', self.learning_rate, other.learning_rate),
            ('batch sizes', self.batch_size, other.batch_size),
            ('descriptions', self.description, other.description),
        )
        for label, mine, theirs in comparisons:
            # Raised explicitly instead of using `assert`, so that the check
            # is not removed when Python runs with -O.
            if not mine == theirs:
                raise AssertionError(
                    f"Cannot add Diagnostics objects with different {label}: {mine!r} != {theirs!r}"
                )

        # __init__ is bypassed because the attributes are already normalized.
        result = Diagnostics.__new__(Diagnostics)

        # Metadata is taken from whichever operand supplies the data.
        source = other if self.per_sample_loss is None else self
        result.model = source.model
        result.optimizer = source.optimizer
        result.learning_rate = source.learning_rate
        result.batch_size = source.batch_size
        result.description = source.description

        if self.per_sample_loss is None or other.per_sample_loss is None:
            result.epoch = source.epoch
            result.per_sample_loss = source.per_sample_loss
        else:
            result.epoch = np.concatenate((self.epoch, other.epoch))
            result.per_sample_loss = torch.cat((self.per_sample_loss, other.per_sample_loss), dim=0)

        return result

    def to_netcdf(self, directory, verbose=True):
        '''
        Save the diagnostics object to a netCDF file.

        The output file is written under the given directory, which is created
        if necessary. The file name is built from the model name, optimizer
        name, description (omitted when it is None), and the first and last
        epoch. The saved dataset includes epoch and sample coordinates,
        metadata attributes, and the per-sample loss array; the description
        attribute is written only when a description is set.

        Raises ValueError if the instance holds no epoch or per-sample loss
        data.

        Returns the path of the written file.
        '''
        if self.epoch is None or self.per_sample_loss is None:
            raise ValueError('Cannot save Diagnostics without epoch and per-sample loss data.')

        directory = Path(directory)

        name_parts = [self.model, self.optimizer]
        if self.description is not None:
            name_parts.append(self.description)
        file_name = (
            '.'.join(name_parts)
            + '.epoch='
            + str(self.epoch[0])
            + '_to_'
            + str(self.epoch[-1])
            + '.nc'
        )

        # Create Xarray dataset
        ds = xr.Dataset()

        # Coordinates
        ds.coords['epoch'] = self.epoch
        ds.coords['sample'] = np.arange(self.per_sample_loss.shape[1])

        # Global attributes
        ds.attrs['model'] = self.model
        ds.attrs['optimizer'] = self.optimizer
        if self.description is not None:
            # A None description is left out of the file entirely.
            ds.attrs['description'] = self.description
        ds.attrs['learning_rate'] = self.learning_rate
        ds.attrs['batch_size'] = self.batch_size

        # 2D variables
        ds['per_sample_loss'] = (['epoch', 'sample'], self.per_sample_loss.detach().cpu().numpy())
        ds['per_sample_loss'].attrs['long_name'] = 'Mean per-sample loss'

        # Save
        file_path = directory / file_name
        file_path.parent.mkdir(parents=True, exist_ok=True)
        ds.to_netcdf(file_path, unlimited_dims='epoch')

        # Close
        ds.close()

        if verbose:
            print('Saved diagnostics in ', file_path)

        return file_path
@@ -0,0 +1,160 @@
1
+ '''
2
+ Loss-related utilities for evaluating models at per-sample resolution.
3
+
4
+ This module provides functions to compute per-sample losses from either
5
+ a data loader or in-memory tensors, preserve the model's training state
6
+ during evaluation, and identify the samples with the largest losses.
7
+ '''
8
+
9
+ import torch
10
+
11
+ from torch_tk.models.utils import get_model_device
12
+
13
+
14
def per_sample_loss_from_data_loader(model, loss_function_sample_resolved, data_loader):
    '''
    Compute per-sample loss values and their mean from a data loader.

    The model is switched to evaluation mode and evaluated without gradient
    tracking. Its original training state is restored afterward, also when
    an error is raised. The returned per-sample loss tensor is moved to CPU
    memory.

    Arguments
    ----------
    model : torch.nn.Module
        Model used for prediction.
    loss_function_sample_resolved : callable
        Function taking (predictions, targets) and returning a 1-D tensor
        with one loss value per sample of the batch.
    data_loader
        Iterable of (x, y) batches that exposes a ``dataset`` attribute
        supporting len().

    Returns
    -------
    float
        Mean loss over all samples produced by the data loader.
    torch.Tensor
        A 1-D tensor containing the model loss for each sample produced by the
        data loader, always in CPU memory. If the loader yields fewer samples
        than len(data_loader.dataset), only the yielded samples are included.

    Raises
    ------
    ValueError
        If the loss function does not return a 1-D tensor with one value per
        sample in the batch, or if the loader yields more samples than
        len(data_loader.dataset).
    '''

    device = get_model_device(model)

    was_training = model.training
    model.eval()

    try:
        data_n = len(data_loader.dataset)

        # Buffer sized for the whole dataset; `filled` tracks how much is written.
        per_sample_loss = torch.empty(data_n, device=device)
        filled = 0

        with torch.no_grad():
            for x, y in data_loader:
                batch_n = len(x)
                loss = loss_function_sample_resolved(model(x.to(device)), y.to(device))
                if loss.ndim != 1:
                    raise ValueError('Per-sample loss function does not produce a 1-dimensional tensor.')
                if loss.shape[0] != batch_n:
                    raise ValueError('Per-sample loss function produces fewer or more loss values than samples.')
                if filled + batch_n > data_n:
                    raise ValueError('Data loader yields more samples than len(data_loader.dataset).')
                per_sample_loss[filled : filled + batch_n] = loss
                filled += batch_n

        # Entries beyond `filled` were never written (torch.empty leaves them
        # uninitialized), so they are dropped before computing the mean.
        per_sample_loss = per_sample_loss[:filled]

        mean_per_sample_loss = per_sample_loss.mean().item()
    finally:
        # Restore the caller's training/eval mode even on failure.
        model.train(was_training)

    return mean_per_sample_loss, per_sample_loss.detach().cpu()
59
+
60
+
61
def per_sample_loss_from_data(model, loss_function_sample_resolved, x_data, y_data, chunk_size=None):
    '''
    Compute per-sample loss values and their mean from in-memory tensors.

    The model is switched to evaluation mode and evaluated without gradient
    tracking, optionally in chunks to limit memory usage. Its original
    training state is restored afterward, also when an error is raised. The
    returned per-sample loss tensor is moved to CPU memory.

    Arguments
    ----------
    model : torch.nn.Module
        Model used for prediction.
    loss_function_sample_resolved : callable
        Function taking (predictions, targets) and returning a 1-D tensor
        of per-sample losses.
    x_data : torch.Tensor
        Input data of shape (N, ...).
    y_data : torch.Tensor
        Target data of shape (N, ...).
    chunk_size : int, optional
        Number of samples to process at once. If not provided (or 0), all
        samples are processed at once.

    Returns
    -------
    float
        Mean loss over all samples.
    torch.Tensor
        A 1-D tensor containing the loss for each sample, always in CPU memory.

    Raises
    ------
    ValueError
        If chunk_size is negative, or if the loss function does not return a
        1-D tensor with one loss value per sample (checked for every chunk).
    '''

    if chunk_size is not None and chunk_size < 0:
        # A negative step would skip the loop and leave the buffer uninitialized.
        raise ValueError('chunk_size must be a non-negative integer or None.')

    shape_error = 'loss_function_sample_resolved must return a 1-D tensor with one loss value per sample.'

    device = get_model_device(model)

    was_training = model.training
    model.eval()

    try:
        data_n = x_data.shape[0]

        with torch.no_grad():
            if chunk_size:
                per_sample_loss = torch.empty(data_n, device=device)
                for i_start in range(0, data_n, chunk_size):
                    i_end = min(i_start + chunk_size, data_n)
                    xb = x_data[i_start:i_end].to(device)
                    yb = y_data[i_start:i_end].to(device)
                    chunk_loss = loss_function_sample_resolved(model(xb), yb)
                    # Validate before assigning: a scalar or otherwise
                    # broadcastable result would be written into the slice
                    # without any error.
                    if chunk_loss.ndim != 1 or chunk_loss.shape[0] != i_end - i_start:
                        raise ValueError(shape_error)
                    per_sample_loss[i_start:i_end] = chunk_loss
            else:
                per_sample_loss = loss_function_sample_resolved(model(x_data.to(device)), y_data.to(device))
                if per_sample_loss.ndim != 1 or per_sample_loss.shape[0] != data_n:
                    raise ValueError(shape_error)

        mean_per_sample_loss = per_sample_loss.mean().item()
    finally:
        # Restore the caller's training/eval mode even on failure.
        model.train(was_training)

    return mean_per_sample_loss, per_sample_loss.detach().cpu()
119
+
120
+
121
def model_worst_loss(model, loss_function_sample_resolved, x_data, y_data, n, chunk_size=None):
    '''
    Find the n samples on which the model has the largest loss.

    Per-sample losses are computed with per_sample_loss_from_data() and
    ranked in descending order.

    Arguments
    ----------
    model : torch.nn.Module
        The model to evaluate.
    loss_function_sample_resolved : callable
        Function taking (predictions, targets) and returning a 1-D tensor
        of per-sample losses.
    x_data : torch.Tensor
        Input data, with batch dimension first.
    y_data : torch.Tensor
        Target data, with batch dimension first.
    n : int
        Number of worst samples to return.
    chunk_size : int, optional
        Number of samples to process at once.

    Returns
    -------
    list of int
        Indices into x_data of the n samples with the largest loss, ordered
        from largest to smallest loss.
    torch.Tensor
        The loss values at those indices, in the same order.
    '''

    # Per-sample loss of the model on every data sample
    with torch.no_grad():
        _, losses = per_sample_loss_from_data(
            model, loss_function_sample_resolved, x_data, y_data, chunk_size=chunk_size
        )

    # Sample indices ranked from largest to smallest loss
    ranking = torch.argsort(losses, descending=True)

    # Keep the first n entries of the ranking as a plain list of indices
    worst = ranking[:n].tolist()

    return worst, losses[worst]