torch-tk 1.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_tk/__init__.py +29 -0
- torch_tk/checkpoints/__init__.py +1 -0
- torch_tk/checkpoints/checkpoint_manager.py +139 -0
- torch_tk/checkpoints/utils.py +33 -0
- torch_tk/diagnostics/__init__.py +3 -0
- torch_tk/diagnostics/diagnostics.py +281 -0
- torch_tk/diagnostics/loss.py +160 -0
- torch_tk/diagnostics/plotting.py +148 -0
- torch_tk/models/__init__.py +2 -0
- torch_tk/models/model.py +102 -0
- torch_tk/models/utils.py +18 -0
- torch_tk/optimizers/__init__.py +3 -0
- torch_tk/optimizers/adam.py +87 -0
- torch_tk/optimizers/sgd.py +79 -0
- torch_tk/optimizers/sgd_manual.py +104 -0
- torch_tk/test.py +5 -0
- torch_tk/training/__init__.py +1 -0
- torch_tk/training/trainer.py +429 -0
- torch_tk-1.0.8.dist-info/METADATA +197 -0
- torch_tk-1.0.8.dist-info/RECORD +23 -0
- torch_tk-1.0.8.dist-info/WHEEL +5 -0
- torch_tk-1.0.8.dist-info/licenses/LICENSE +13 -0
- torch_tk-1.0.8.dist-info/top_level.txt +1 -0
torch_tk/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
2
|
+
|
|
3
|
+
_DIST_NAME = "torch-tk"


def _detect_version():
    '''
    Return the installed version string of this package.

    The distribution name "torch-tk" is looked up first. If it is not
    installed, the top-level import package name is tried next. If neither
    is found, the placeholder "0.0.0+local" is returned.

    The lookup lives in a function so its temporaries do not become public
    attributes of the package namespace.
    '''
    try:
        return version(_DIST_NAME)
    except PackageNotFoundError:
        pass

    # __package__ is empty for a top-level module; fall back to the first
    # dotted component of the module name in that case.
    pkg = __package__ or __name__.split(".", 1)[0]
    try:
        return version(pkg)
    except PackageNotFoundError:
        return "0.0.0+local"


__version__ = _detect_version()
|
|
13
|
+
|
|
14
|
+
from .models.model import Model
|
|
15
|
+
from .optimizers.sgd import SGD
|
|
16
|
+
from .optimizers.adam import Adam
|
|
17
|
+
from .training.trainer import Trainer
|
|
18
|
+
from .checkpoints.checkpoint_manager import CheckPointManager
|
|
19
|
+
from .diagnostics.diagnostics import Diagnostics
|
|
20
|
+
|
|
21
|
+
# Names re-exported at the package top level (all bound by the imports above).
__all__ = [
    # Package metadata
    "__version__",
    # Model and optimizers
    "Model",
    "SGD",
    "Adam",
    # Training workflow and tooling
    "Trainer",
    "CheckPointManager",
    "Diagnostics",
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .checkpoint_manager import CheckPointManager
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
'''
|
|
2
|
+
Checkpoint utilities for saving and restoring self-describing model and optimizer state.
|
|
3
|
+
|
|
4
|
+
This module provides the CheckPointManager class, which saves checkpoints that
|
|
5
|
+
contain enough information to reconstruct both a model and its optimizer in the
|
|
6
|
+
state they were in when the checkpoint was saved. Each checkpoint stores the fully qualified
|
|
7
|
+
class path, constructor arguments from constructor_dict(), and state from
|
|
8
|
+
state_dict() for both objects.
|
|
9
|
+
|
|
10
|
+
Models and optimizers must be importable from a stable class path and must expose:
|
|
11
|
+
|
|
12
|
+
- constructor_dict()
|
|
13
|
+
- state_dict()
|
|
14
|
+
- load_state_dict()
|
|
15
|
+
|
|
16
|
+
A model must be reconstructible from its class path and constructor_dict().
|
|
17
|
+
An optimizer must be reconstructible from its class path, model.parameters(),
|
|
18
|
+
and constructor_dict().
|
|
19
|
+
|
|
20
|
+
The constructor data must be serializable. In practice, this means it should
|
|
21
|
+
contain only standard serializable Python values such as numbers, strings,
|
|
22
|
+
lists, tuples, and dictionaries.
|
|
23
|
+
|
|
24
|
+
This design is not suitable for optimizers that depend on non-serializable
|
|
25
|
+
constructor inputs, non-standard constructor signatures, custom parameter-group
|
|
26
|
+
reconstruction beyond model.parameters(), or runtime state not captured by
|
|
27
|
+
state_dict().
|
|
28
|
+
'''
|
|
29
|
+
|
|
30
|
+
from pathlib import Path
|
|
31
|
+
|
|
32
|
+
import torch
|
|
33
|
+
|
|
34
|
+
from .utils import class_path_of_instance, import_class
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class CheckPointManager:
    '''
    Save checkpoints and rebuild a model and optimizer from a checkpoint file.

    The checkpoint contains the epoch, class paths, constructor arguments, and
    state dictionaries for the model and optimizer.
    '''

    def __init__(self, model, optimizer, directory):
        '''
        Store the model, optimizer, and checkpoint directory.

        The directory is converted to a Path if necessary. It is not created
        here; save() creates it (including missing parents) when a checkpoint
        is written.
        '''
        self.model = model
        self.optimizer = optimizer
        self.directory = directory if isinstance(directory, Path) else Path(directory)

    def save(self, epoch):
        '''
        Save a checkpoint for the current model and optimizer state.

        The file is named ``<ModelClass>.<OptimizerClass>.epoch=<epoch>.pt``
        and placed in the manager's directory, which is created if needed.

        The checkpoint is first written to a temporary file next to the target
        and then renamed into place. An interrupted save therefore never
        leaves a truncated file under the final name, and the temporary file
        is removed on failure.

        Raises:
            TypeError: if the model or optimizer does not implement
                constructor_dict().

        Returns the path to the written checkpoint file.
        '''
        if not hasattr(self.model, 'constructor_dict'):
            raise TypeError('Model must implement constructor_dict().')

        if not hasattr(self.optimizer, 'constructor_dict'):
            raise TypeError('Optimizer must implement constructor_dict().')

        checkpoint = {
            'epoch': epoch,
            'model_class_path': class_path_of_instance(self.model),
            'model_constructor_dict': self.model.constructor_dict(),
            'model_state_dict': self.model.state_dict(),
            'optimizer_class_path': class_path_of_instance(self.optimizer),
            'optimizer_constructor_dict': self.optimizer.constructor_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }

        file_name = type(self.model).__name__ + '.' + type(self.optimizer).__name__ + '.epoch=' + str(epoch) + '.pt'

        self.directory.mkdir(parents=True, exist_ok=True)

        file_path = self.directory / file_name
        tmp_path = file_path.with_name(file_path.name + '.tmp')
        try:
            torch.save(checkpoint, tmp_path)
            # Path.replace overwrites an existing target; on the same
            # filesystem this is a rename, so readers never see a partial file.
            tmp_path.replace(file_path)
        finally:
            # No-op after a successful replace; removes a partial file otherwise.
            tmp_path.unlink(missing_ok=True)

        return file_path

    @classmethod
    def load_from_file(cls, file_path, device=None):
        '''
        Load a checkpoint file and reconstruct the model, optimizer, and manager.

        The model class is imported from the stored class path and called with
        the stored constructor 'args'/'kwargs', then its state dict is loaded.
        The optimizer class is called with model.parameters() followed by its
        stored 'args'/'kwargs', then its state dict is loaded. The returned
        manager writes to the checkpoint file's directory.

        If device is given:
        - the checkpoint tensors are mapped onto that device when loading;
        - a saved model constructor keyword argument named 'device' is
          overridden;
        - the model is moved to the device before the optimizer is built.

        Returns:
            checkpoint_manager, model, optimizer, epoch
        '''
        file_path = file_path if isinstance(file_path, Path) else Path(file_path)

        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        # Newer PyTorch versions default to weights_only=True, which requires
        # the constructor dicts to contain only plain Python values.
        checkpoint = torch.load(file_path, map_location=device)

        # Reconstruct model

        model_class = import_class(checkpoint['model_class_path'])

        model_constructor_dict = checkpoint['model_constructor_dict']
        model_args = model_constructor_dict.get('args', [])
        # Copy so the override below does not mutate the loaded checkpoint.
        model_kwargs = dict(model_constructor_dict.get('kwargs', {}))

        if 'device' in model_kwargs and device is not None:
            model_kwargs['device'] = device

        model = model_class(*model_args, **model_kwargs)
        model.load_state_dict(checkpoint['model_state_dict'])

        if device is not None:
            model = model.to(device)

        # Reconstruct optimizer (after the model move, so it references the
        # parameters on the target device)

        optimizer_class = import_class(checkpoint['optimizer_class_path'])

        optimizer_constructor_dict = checkpoint['optimizer_constructor_dict']
        optimizer_args = optimizer_constructor_dict.get('args', [])
        optimizer_kwargs = dict(optimizer_constructor_dict.get('kwargs', {}))

        optimizer = optimizer_class(model.parameters(), *optimizer_args, **optimizer_kwargs)

        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Create a checkpoint manager instance
        checkpoint_manager = cls(model, optimizer, file_path.parent)

        epoch = checkpoint['epoch']

        return checkpoint_manager, model, optimizer, epoch
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def class_path_of_instance(instance):
    '''
    Build the dotted import path of the class of ``instance``.

    The path is ``<defining module>.<qualified class name>``, for example
    ``'collections.OrderedDict'``. Because __qualname__ is used, nested
    classes keep the names of their enclosing classes.
    '''
    klass = type(instance)
    return '.'.join((klass.__module__, klass.__qualname__))
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def import_class(class_path):
    '''
    Import and return a class from a fully qualified class path.

    The longest prefix of class_path that can be imported as a module is
    imported. The remaining dotted components are then resolved as
    attributes, so nested classes referenced through __qualname__ are
    supported.

    Raises
    ------
    ImportError
        If no prefix of class_path can be imported as a module.
    ModuleNotFoundError
        If a module on the path exists but fails to import because one of its
        own dependencies is missing. This error is re-raised as-is rather
        than masked.
    AttributeError
        If a component after the module prefix does not exist.
    '''
    parts = class_path.split('.')

    for i in range(len(parts), 0, -1):
        module_name = '.'.join(parts[:i])
        try:
            obj = importlib.import_module(module_name)
            break
        except ModuleNotFoundError as err:
            # Only "module_name itself (or one of its parent packages) does not
            # exist" means this prefix is not a module and a shorter one should
            # be tried. Any other missing module was imported *inside* an
            # existing module and is a real error that must surface.
            missing = err.name
            if missing is not None and not (module_name == missing or module_name.startswith(missing + '.')):
                raise
            continue
    else:
        raise ImportError(f'Could not import any module prefix from {class_path!r}')

    for attr in parts[i:]:
        obj = getattr(obj, attr)

    return obj
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
'''
|
|
2
|
+
Utilities for storing, combining, and serializing sample-resolved training
|
|
3
|
+
diagnostics.
|
|
4
|
+
|
|
5
|
+
This module defines the Diagnostics class, which can be created from in-memory
|
|
6
|
+
data, a data loader, or a saved netCDF file. It stores per-sample loss values
|
|
7
|
+
together with model and optimizer metadata and can write the diagnostics back
|
|
8
|
+
to netCDF.
|
|
9
|
+
'''
|
|
10
|
+
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import torch
|
|
15
|
+
import xarray as xr
|
|
16
|
+
|
|
17
|
+
from .loss import per_sample_loss_from_data, per_sample_loss_from_data_loader
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Diagnostics:
    '''
    Store sample-resolved loss diagnostics together with training metadata.

    An instance holds model and optimizer identifiers, learning rate (kept as
    a string), batch size, an optional description, one or more epoch values,
    and the corresponding per-sample loss tensor, whose leading dimension
    indexes the epochs. It also supports loading from and saving to netCDF.
    '''

    @classmethod
    def from_data_loader(
        cls,
        model,
        loss_function_sample_resolved,
        optimizer,
        learning_rate,
        batch_size,
        data_loader,
        description: str = None,
        epoch=0,
    ):
        '''
        Create a Diagnostics instance from a model evaluated on a data loader.

        The method computes per-sample losses for the samples yielded by the
        loader and stores them together with the class names of the model and
        optimizer and the training metadata.
        '''
        # The mean loss returned alongside the per-sample values is not needed.
        _, per_sample_loss = per_sample_loss_from_data_loader(model, loss_function_sample_resolved, data_loader)

        return cls(
            type(model).__name__,
            type(optimizer).__name__,
            learning_rate,
            batch_size,
            epoch=epoch,
            per_sample_loss=per_sample_loss,
            description=description,
        )

    @classmethod
    def from_data(
        cls,
        model,
        loss_function_sample_resolved,
        optimizer,
        learning_rate,
        batch_size,
        x_data,
        y_data,
        description: str = None,
        epoch=0,
        chunk_size=None,
    ):
        '''
        Create a Diagnostics instance from in-memory input and target tensors.

        The method computes per-sample losses from the supplied tensors,
        optionally in chunks of chunk_size samples to reduce memory usage. It
        stores them together with the class names of the model and optimizer
        and the training metadata.
        '''
        _, per_sample_loss = per_sample_loss_from_data(
            model, loss_function_sample_resolved, x_data, y_data, chunk_size=chunk_size
        )

        return cls(
            type(model).__name__,
            type(optimizer).__name__,
            learning_rate,
            batch_size,
            epoch=epoch,
            per_sample_loss=per_sample_loss,
            description=description,
        )

    @classmethod
    def from_netcdf(cls, path):
        '''
        Create a Diagnostics instance from a netCDF file written by to_netcdf().

        This restores metadata, epoch values, and per-sample loss data. The
        file is closed even if reading it fails. A missing 'description'
        attribute (written that way when the description was None) yields
        description=None.
        '''
        path = path if isinstance(path, Path) else Path(path)

        with xr.open_dataset(path) as ds:
            return cls(
                ds.attrs['model'],
                ds.attrs['optimizer'],
                ds.attrs['learning_rate'],
                ds.attrs['batch_size'],
                description=ds.attrs.get('description'),
                per_sample_loss=torch.as_tensor(ds['per_sample_loss'].values),
                epoch=ds['epoch'].values,
            )

    def __init__(
        self, model_name, optimizer_name, learning_rate, batch_size, epoch=None, per_sample_loss=None, description: str = None
    ):
        '''
        Initialize a Diagnostics instance.

        Epoch values and per-sample losses must either both be provided or both
        be omitted. Epochs are stored as a 1-D int64 numpy array. A
        one-dimensional per-sample loss tensor is promoted to two dimensions,
        so that the leading dimension corresponds to epoch. learning_rate is
        stored as str(learning_rate).

        Raises ValueError if only one of epoch / per_sample_loss is given, or
        if the number of epochs does not match the leading dimension of the
        per-sample losses.
        '''
        if (epoch is None) != (per_sample_loss is None):
            raise ValueError('epoch and per_sample_loss must both be None or both not None.')

        self.model = model_name
        self.optimizer = optimizer_name

        self.learning_rate = str(learning_rate)
        self.batch_size = batch_size
        self.description = description

        if epoch is None:
            self.epoch = None
            self.per_sample_loss = None
            return

        self.epoch = np.atleast_1d(np.asarray(epoch, dtype=np.int64))

        # A single 1-D loss vector is one epoch's worth: add the epoch axis.
        if per_sample_loss.ndim == 1:
            self.per_sample_loss = per_sample_loss.unsqueeze(0)
        else:
            self.per_sample_loss = per_sample_loss

        if len(self.epoch) != self.per_sample_loss.shape[0]:
            raise ValueError('Number of epochs and number of per-sample loss instances do not match.')

    def __add__(self, other):
        '''
        Combine two compatible diagnostics objects.

        The two objects must have matching model, optimizer, learning rate,
        batch size, and description; a mismatch raises AssertionError. If one
        side holds no per-sample loss data, the result takes the metadata,
        epochs, and losses of the other side. Otherwise the epoch arrays and
        per-sample loss tensors are concatenated along the epoch dimension,
        with self first.
        '''
        if not isinstance(other, Diagnostics):
            return NotImplemented

        # Explicit raises instead of assert statements so the checks also run
        # under `python -O`; AssertionError is kept for backward compatibility.
        for label, mine, theirs in (
            ('models', self.model, other.model),
            ('optimizers', self.optimizer, other.optimizer),
            ('learning rates', self.learning_rate, other.learning_rate),
            ('batch sizes', self.batch_size, other.batch_size),
            ('descriptions', self.description, other.description),
        ):
            if mine != theirs:
                raise AssertionError(f"Cannot add Diagnostics objects with different {label}: {mine!r} != {theirs!r}")

        if self.per_sample_loss is None:
            source, epoch, per_sample_loss = other, other.epoch, other.per_sample_loss
        elif other.per_sample_loss is None:
            source, epoch, per_sample_loss = self, self.epoch, self.per_sample_loss
        else:
            source = self
            epoch = np.concatenate((self.epoch, other.epoch))
            per_sample_loss = torch.cat((self.per_sample_loss, other.per_sample_loss), dim=0)

        result = Diagnostics.__new__(Diagnostics)
        result.model = source.model
        result.optimizer = source.optimizer
        result.epoch = epoch
        result.learning_rate = source.learning_rate
        result.batch_size = source.batch_size
        result.description = source.description
        result.per_sample_loss = per_sample_loss

        return result

    def to_netcdf(self, directory, verbose=True):
        '''
        Save the diagnostics object to a netCDF file.

        The file is written under the given directory, which is created if
        necessary, as
        ``<model>.<optimizer>[.<description>].epoch=<first>_to_<last>.nc``.
        The description part is omitted when the description is None. The
        saved dataset includes epoch and sample coordinates, metadata
        attributes, and the per-sample loss array. The 'description' attribute
        is only written when the description is not None, because netCDF
        attributes cannot hold None.

        Raises ValueError if the instance holds no per-sample loss data.
        Returns the path of the written file.
        '''
        if self.per_sample_loss is None:
            raise ValueError('Cannot save Diagnostics without per-sample loss data.')

        directory = directory if isinstance(directory, Path) else Path(directory)

        name_parts = [self.model, self.optimizer]
        if self.description is not None:
            name_parts.append(self.description)
        name_parts.append('epoch=' + str(self.epoch[0]) + '_to_' + str(self.epoch[-1]) + '.nc')
        file_name = Path('.'.join(name_parts))

        # Create Xarray dataset
        ds = xr.Dataset()

        # Coordinates
        ds.coords['epoch'] = self.epoch
        ds.coords['sample'] = np.arange(self.per_sample_loss.shape[1])

        # Global attributes
        ds.attrs['model'] = self.model
        ds.attrs['optimizer'] = self.optimizer
        if self.description is not None:
            ds.attrs['description'] = self.description
        ds.attrs['learning_rate'] = self.learning_rate
        ds.attrs['batch_size'] = self.batch_size

        # 2D variables
        ds['per_sample_loss'] = (['epoch', 'sample'], self.per_sample_loss.detach().cpu().numpy())
        ds['per_sample_loss'].attrs['long_name'] = 'Mean per-sample loss'

        # Save
        file_path = directory / file_name
        file_path.parent.mkdir(parents=True, exist_ok=True)
        ds.to_netcdf(file_path, unlimited_dims='epoch')

        # Close
        ds.close()

        if verbose:
            print('Saved diagnostics in ', file_path)

        return file_path
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
'''
|
|
2
|
+
Loss-related utilities for evaluating models at per-sample resolution.
|
|
3
|
+
|
|
4
|
+
This module provides functions to compute per-sample losses from either
|
|
5
|
+
a data loader or in-memory tensors, preserve the model's training state
|
|
6
|
+
during evaluation, and identify the samples with the largest losses.
|
|
7
|
+
'''
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from torch_tk.models.utils import get_model_device
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def per_sample_loss_from_data_loader(model, loss_function_sample_resolved, data_loader):
    '''
    Compute per-sample loss values and their mean from a data loader.

    The model is evaluated in eval mode without gradient tracking, and its
    original training mode is restored afterward, even if an error is raised.
    Per-batch losses are collected in iteration order, so the result covers
    exactly the samples the loader yields. This may be fewer than
    ``len(data_loader.dataset)``, e.g. when the last batch is dropped.

    Arguments
    ----------
    model : torch.nn.Module
        Model used for prediction.
    loss_function_sample_resolved : callable
        Function taking (predictions, targets) and returning a 1-D tensor
        of per-sample losses.
    data_loader : iterable
        Yields (x, y) batches with the batch dimension first.

    Returns
    -------
    float
        Mean loss over all samples yielded by the data loader (NaN if it
        yields none).
    torch.Tensor
        A 1-D tensor of the default floating dtype containing the model loss
        for each yielded sample, always in CPU memory.

    Raises
    ------
    ValueError
        If the loss function does not return a 1-D tensor with one value per
        sample of the batch.
    '''
    device = get_model_device(model)
    # Results are stored in the default floating dtype.
    dtype = torch.get_default_dtype()

    was_training = model.training
    model.eval()

    batch_losses = []
    try:
        with torch.no_grad():
            for x, y in data_loader:
                batch_n = len(x)
                loss = loss_function_sample_resolved(model(x.to(device)), y.to(device))
                if loss.ndim != 1:
                    raise ValueError('Per-sample loss function does not produce a 1-dimensional tensor.')
                if loss.shape[0] != batch_n:
                    raise ValueError('Per-sample loss function produces fewer or more loss values than samples.')
                batch_losses.append(loss.to(dtype))
    finally:
        model.train(was_training)

    if batch_losses:
        per_sample_loss = torch.cat(batch_losses)
    else:
        per_sample_loss = torch.empty(0, dtype=dtype, device=device)

    mean_per_sample_loss = per_sample_loss.mean().item()

    return mean_per_sample_loss, per_sample_loss.detach().cpu()
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def per_sample_loss_from_data(model, loss_function_sample_resolved, x_data, y_data, chunk_size=None):
    '''
    Compute per-sample loss values and their mean from in-memory tensors.

    The model is evaluated in eval mode without gradient tracking, optionally
    in chunks to limit memory usage. Its original training mode is restored
    afterward, even if an error is raised.

    The loss of every chunk is validated before it is stored. A loss function
    that reduces over the batch (e.g. returns a scalar) is therefore rejected
    instead of being silently broadcast over the chunk. The returned
    per-sample loss tensor is moved to CPU memory.

    Arguments
    ----------
    model : torch.nn.Module
        Model used for prediction.
    loss_function_sample_resolved : callable
        Function taking (predictions, targets) and returning a 1-D tensor
        of per-sample losses.
    x_data : torch.Tensor
        Input data of shape (N, ...).
    y_data : torch.Tensor
        Target data of shape (N, ...).
    chunk_size : int, optional
        Number of samples to process at once. If not provided (or 0), all
        samples are processed at once.

    Returns
    -------
    float
        Mean loss over all samples.
    torch.Tensor
        A 1-D tensor containing the loss for each sample, always in CPU memory.
        NOTE(review): in chunked mode values are stored in the default float
        dtype, while without chunking the loss function's own dtype is kept.

    Raises
    ------
    ValueError
        If loss_function_sample_resolved does not return a 1-D tensor with one
        loss value per sample (checked for every chunk).
    '''
    shape_error = 'loss_function_sample_resolved must return a 1-D tensor with one loss value per sample.'

    device = get_model_device(model)

    was_training = model.training
    model.eval()

    data_n = x_data.shape[0]

    try:
        with torch.no_grad():
            if chunk_size:
                per_sample_loss = torch.empty(data_n, device=device)
                for i_start in range(0, data_n, chunk_size):
                    i_end = min(i_start + chunk_size, data_n)
                    xb = x_data[i_start:i_end].to(device)
                    yb = y_data[i_start:i_end].to(device)
                    loss = loss_function_sample_resolved(model(xb), yb)
                    # Validate before assigning: slice assignment would
                    # broadcast a scalar or length-1 result over the chunk.
                    if loss.ndim != 1 or loss.shape[0] != i_end - i_start:
                        raise ValueError(shape_error)
                    per_sample_loss[i_start:i_end] = loss
            else:
                per_sample_loss = loss_function_sample_resolved(model(x_data.to(device)), y_data.to(device))
                if per_sample_loss.ndim != 1 or per_sample_loss.shape[0] != data_n:
                    raise ValueError(shape_error)
    finally:
        model.train(was_training)

    mean_per_sample_loss = per_sample_loss.mean().item()

    return mean_per_sample_loss, per_sample_loss.detach().cpu()
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def model_worst_loss(model, loss_function_sample_resolved, x_data, y_data, n, chunk_size=None):
    '''
    Return the indices and loss values of the n worst-performing samples.

    Per-sample losses are computed with per_sample_loss_from_data and sorted
    in descending order; the first n entries are returned.

    Arguments
    ----------
    model : torch.nn.Module
        The model to evaluate.
    loss_function_sample_resolved : callable
        Function taking (predictions, targets) and returning a 1-D tensor of
        per-sample losses.
    x_data : torch.Tensor
        Input data, with batch dimension first.
    y_data : torch.Tensor
        Target data, with batch dimension first.
    n : int
        Number of samples to return. All samples are returned if n exceeds
        the number of samples.
    chunk_size : int, optional
        Number of samples to process at once.

    Returns
    -------
    list of int
        Indices into x_data of the samples with the largest loss, largest
        first.
    torch.Tensor
        The corresponding loss values (CPU memory), in the same order.
    '''
    # per_sample_loss_from_data already evaluates the model under
    # torch.no_grad(); the mean it returns is not needed here.
    _, per_sample_loss = per_sample_loss_from_data(
        model, loss_function_sample_resolved, x_data, y_data, chunk_size=chunk_size
    )

    # Indices such that per_sample_loss[idxs] is sorted in descending order
    idxs = torch.argsort(per_sample_loss, descending=True)

    # The indices of the n samples in x_data with the largest loss
    idxs_worst = idxs[:n].tolist()

    return idxs_worst, per_sample_loss[idxs_worst]
|