structcast-model 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. structcast_model/__init__.py +21 -0
  2. structcast_model/base_trainer.py +345 -0
  3. structcast_model/builders/__init__.py +1 -0
  4. structcast_model/builders/auto_name.py +37 -0
  5. structcast_model/builders/base_builder.py +618 -0
  6. structcast_model/builders/jinja_filters.py +72 -0
  7. structcast_model/builders/schema.py +584 -0
  8. structcast_model/builders/torch_builder.py +174 -0
  9. structcast_model/commands/__init__.py +1 -0
  10. structcast_model/commands/cmd_torch.py +508 -0
  11. structcast_model/commands/main.py +61 -0
  12. structcast_model/commands/utils.py +75 -0
  13. structcast_model/torch/__init__.py +1 -0
  14. structcast_model/torch/layers/__init__.py +37 -0
  15. structcast_model/torch/layers/accuracy.py +26 -0
  16. structcast_model/torch/layers/add.py +16 -0
  17. structcast_model/torch/layers/channel_shuffle.py +27 -0
  18. structcast_model/torch/layers/concatenate.py +45 -0
  19. structcast_model/torch/layers/criteria_tracker.py +38 -0
  20. structcast_model/torch/layers/fold.py +118 -0
  21. structcast_model/torch/layers/lazy_norm.py +97 -0
  22. structcast_model/torch/layers/multiply.py +16 -0
  23. structcast_model/torch/layers/permute.py +72 -0
  24. structcast_model/torch/layers/reduce.py +24 -0
  25. structcast_model/torch/layers/reinmax.py +67 -0
  26. structcast_model/torch/layers/scale_identity.py +30 -0
  27. structcast_model/torch/layers/split.py +43 -0
  28. structcast_model/torch/optimizers.py +257 -0
  29. structcast_model/torch/trainer.py +682 -0
  30. structcast_model/torch/types.py +38 -0
  31. structcast_model/utils/__init__.py +1 -0
  32. structcast_model/utils/base.py +127 -0
  33. structcast_model-1.0.0.dist-info/METADATA +763 -0
  34. structcast_model-1.0.0.dist-info/RECORD +36 -0
  35. structcast_model-1.0.0.dist-info/WHEEL +4 -0
  36. structcast_model-1.0.0.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,21 @@
1
+ """StructCast-Model: Construct neural network models and training workflows by structcast package."""
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ __version__ = "1.0.0"
6
+ __all__ = ["base_trainer", "builders", "torch", "utils"]
7
+
8
+ if TYPE_CHECKING:
9
+ from structcast_model import base_trainer, builders, torch, utils
10
+ else:
11
+ import sys
12
+
13
+ from structcast.utils.lazy_import import LazySelectedImporter
14
+
15
+ import_structure = {
16
+ "builders": [],
17
+ "torch": [],
18
+ "utils": [],
19
+ "base_trainer": [],
20
+ }
21
+ sys.modules[__name__] = LazySelectedImporter(__name__, globals(), import_structure)
@@ -0,0 +1,345 @@
1
+ """Base trainer for training a model."""
2
+
3
+ from collections.abc import Callable, Iterable, Mapping
4
+ from dataclasses import dataclass, field
5
+ from logging import getLogger
6
+ from math import inf
7
+ from operator import gt, lt
8
+ from time import time
9
+ from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeAlias, TypeVar
10
+
11
+ logger = getLogger(__name__)
12
+
13
+ ModelT_contra = TypeVar("ModelT_contra", contravariant=True)
14
+
15
+ DatasetLike: TypeAlias = Iterable[dict[str, Any]]
16
+ """Dataset-like object."""
17
+
18
+
19
+ def get_dataset(dataset: DatasetLike | Callable[[], DatasetLike]) -> Iterable[dict[str, Any]]:
20
+ """Get the dataset."""
21
+ return dataset() if callable(dataset) else dataset
22
+
23
+
24
+ def get_dataset_size(dataset: DatasetLike | Callable[[], DatasetLike]) -> int:
25
+ """Get the size of the dataset."""
26
+ dataset = get_dataset(dataset)
27
+ if hasattr(dataset, "__len__"):
28
+ return dataset.__len__()
29
+ return sum(1 for _ in dataset)
30
+
31
+
32
+ class Forward(Protocol[ModelT_contra]):
33
+ """Protocol for forward pass configuration."""
34
+
35
+ def __call__(self, inputs: Any, **models: ModelT_contra) -> dict[str, Any]:
36
+ """Perform the forward pass for the given inputs and return the outputs and any additional information."""
37
+
38
+
39
+ class Backward(Protocol):
40
+ """Protocol for backward pass configuration."""
41
+
42
+ def __call__(self, step: int, *args: Any, **kwargs: Any) -> bool:
43
+ """Perform the backward pass for the given step and losses, and return whether the model has been updated."""
44
+
45
+
46
+ @dataclass(kw_only=True)
47
+ class BaseInfo:
48
+ """Base information for building a model."""
49
+
50
+ step: int = 0
51
+ """The current training step."""
52
+
53
+ update: int = 0
54
+ """The number of times the model has been updated."""
55
+
56
+ epoch: int = 0
57
+ """The current epoch."""
58
+
59
+ history: dict[int, dict[str, Any]] = field(default_factory=dict)
60
+ """History of training and validation logs."""
61
+
62
+ def logs(self, epoch: int | None = None) -> dict[str, Any]:
63
+ """Get the log for the given epoch."""
64
+ if epoch is None:
65
+ return self.history.setdefault(self.epoch, {})
66
+ if epoch in self.history:
67
+ return self.history[epoch]
68
+ raise KeyError(f"No logs found for key: {epoch}.")
69
+
70
+
71
+ class Callback(Protocol, Generic[ModelT_contra]):
72
+ """Protocol for callbacks."""
73
+
74
+ def __call__(self, info: BaseInfo, **models: ModelT_contra) -> None:
75
+ """Call the callback with the given information."""
76
+
77
+
78
+ class BestCallback(Protocol[ModelT_contra]):
79
+ """Protocol for best criterion callback."""
80
+
81
+ def __call__(self, info: BaseInfo, target: str, best: float, **models: ModelT_contra) -> None:
82
+ """Call the callback with the given info, target criterion, and best value."""
83
+
84
+
85
+ def invoke_callback(callbacks: list[Callable[..., None]], info: BaseInfo, *args: Any, **models: ModelT_contra) -> None:
86
+ """Invoke callback."""
87
+ for callback in callbacks:
88
+ callback(info, *args, **models)
89
+
90
+
91
+ @dataclass(kw_only=True)
92
+ class Callbacks(Generic[ModelT_contra]):
93
+ """Callbacks."""
94
+
95
+ on_update: list[Callback[ModelT_contra]] = field(default_factory=list)
96
+ """Callbacks to call after each update."""
97
+
98
+ on_training_begin: list[Callback[ModelT_contra]] = field(default_factory=list)
99
+ """Callbacks to call at the beginning of training."""
100
+
101
+ on_training_end: list[Callback[ModelT_contra]] = field(default_factory=list)
102
+ """Callbacks to call at the end of training."""
103
+
104
+ on_training_step_begin: list[Callback[ModelT_contra]] = field(default_factory=list)
105
+ """Callbacks to be called at the beginning of each training step."""
106
+
107
+ on_training_step_end: list[Callback[ModelT_contra]] = field(default_factory=list)
108
+ """Callbacks to be called at the end of each training step."""
109
+
110
+ on_validation_begin: list[Callback[ModelT_contra]] = field(default_factory=list)
111
+ """Callbacks to be called at the beginning of validation."""
112
+
113
+ on_validation_end: list[Callback[ModelT_contra]] = field(default_factory=list)
114
+ """Callbacks to be called at the end of validation."""
115
+
116
+ on_validation_step_begin: list[Callback[ModelT_contra]] = field(default_factory=list)
117
+ """Callbacks to be called at the beginning of each validation step."""
118
+
119
+ on_validation_step_end: list[Callback[ModelT_contra]] = field(default_factory=list)
120
+ """Callbacks to be called at the end of each validation step."""
121
+
122
+ on_epoch_begin: list[Callback[ModelT_contra]] = field(default_factory=list)
123
+ """Callbacks to be called at the beginning of each epoch."""
124
+
125
+ on_epoch_end: list[Callback[ModelT_contra]] = field(default_factory=list)
126
+ """Callbacks to be called at the end of each epoch."""
127
+
128
+ add_global_callbacks: bool = True
129
+ """Whether to add global callbacks."""
130
+
131
+ def __post_init__(self) -> None:
132
+ """Post initialization."""
133
+ if self.add_global_callbacks:
134
+ self.on_update.extend(GLOBAL_CALLBACKS.on_update)
135
+ self.on_training_begin.extend(GLOBAL_CALLBACKS.on_training_begin)
136
+ self.on_training_end.extend(GLOBAL_CALLBACKS.on_training_end)
137
+ self.on_training_step_begin.extend(GLOBAL_CALLBACKS.on_training_step_begin)
138
+ self.on_training_step_end.extend(GLOBAL_CALLBACKS.on_training_step_end)
139
+ self.on_validation_begin.extend(GLOBAL_CALLBACKS.on_validation_begin)
140
+ self.on_validation_end.extend(GLOBAL_CALLBACKS.on_validation_end)
141
+ self.on_validation_step_begin.extend(GLOBAL_CALLBACKS.on_validation_step_begin)
142
+ self.on_validation_step_end.extend(GLOBAL_CALLBACKS.on_validation_step_end)
143
+ self.on_epoch_begin.extend(GLOBAL_CALLBACKS.on_epoch_begin)
144
+ self.on_epoch_end.extend(GLOBAL_CALLBACKS.on_epoch_end)
145
+
146
+
147
+ GLOBAL_CALLBACKS = Callbacks[Any](add_global_callbacks=False)
148
+ """Global callbacks."""
149
+
150
+
151
+ class InferenceWrapper(Protocol[ModelT_contra]):
152
+ """Protocol for inference wrapper."""
153
+
154
+ def __call__(self, info: BaseInfo, **models: ModelT_contra) -> dict[str, Any]:
155
+ """Wrap the model for inference, e.g., for quantization or ONNX export."""
156
+
157
+
158
+ @dataclass(kw_only=True)
159
+ class BaseTrainer(BaseInfo, Callbacks[ModelT_contra]):
160
+ """Base trainer for training a model."""
161
+
162
+ training_step: Forward[ModelT_contra]
163
+ """The forward pass configuration for training."""
164
+
165
+ backward: Backward
166
+ """The backward pass configuration."""
167
+
168
+ tracker: Callable[..., dict[str, float]]
169
+ """The tracker to log training and validation information."""
170
+
171
+ inference_wrapper: InferenceWrapper[ModelT_contra] | None = None
172
+ """An optional wrapper to apply to the model during inference, e.g., for quantization or ONNX export."""
173
+
174
+ validation_step: Forward[ModelT_contra] | None = None
175
+ """The forward pass configuration for validation."""
176
+
177
+ training_prefix: str = ""
178
+ """ Prefix for training logs. """
179
+
180
+ validation_prefix: str = "val_"
181
+ """ Prefix for validation logs. """
182
+
183
+ history: dict[int, dict[str, Any]] = field(default_factory=dict)
184
+ """History of training and validation logs."""
185
+
186
+ def sync(self) -> None:
187
+ """Synchronize the device if necessary. This is a no-op by default, but can be overridden by subclasses."""
188
+
189
+ def train(self, dataset: DatasetLike | Callable[[], DatasetLike], **models: ModelT_contra) -> Mapping[str, Any]:
190
+ """Train the model on the given dataset.
191
+
192
+ Args:
193
+ dataset (DatasetLike | Callable[[], DatasetLike]): The dataset to train on,
194
+ which can be an iterable of input dictionaries or a callable that returns such an iterable.
195
+ **models (ModelT): The models to train.
196
+
197
+ Returns:
198
+ Mapping[str, Any]: The logs from training, which may include metrics and other information.
199
+ """
200
+ invoke_callback(self.on_training_begin, self, **models)
201
+ tracker, training_step, backward, elapsed_time = self.tracker, self.training_step, self.backward, 0.0
202
+ for index, inputs in enumerate(get_dataset(dataset), start=1):
203
+ self.step += 1
204
+ invoke_callback(self.on_training_step_begin, self, **models)
205
+ elapsed_time -= time()
206
+ criteria = training_step(inputs, **models)
207
+ should_update = backward(self.step, **criteria)
208
+ self.sync()
209
+ elapsed_time += time()
210
+ logs = tracker(**criteria) | {"elapsed_time": elapsed_time / index}
211
+ if self.training_prefix:
212
+ logs = {f"{self.training_prefix}{k}": v for k, v in logs.items()}
213
+ self.logs().update(logs)
214
+ if should_update:
215
+ self.update += 1
216
+ invoke_callback(self.on_update, self, **models)
217
+ invoke_callback(self.on_training_step_end, self, **models)
218
+ invoke_callback(self.on_training_end, self, **models)
219
+ return logs
220
+
221
+ def evaluate(self, dataset: DatasetLike | Callable[[], DatasetLike], **models: ModelT_contra) -> Mapping[str, Any]:
222
+ """Evaluate the model on the given dataset.
223
+
224
+ Args:
225
+ dataset (DatasetLike | Callable[[], DatasetLike]): The dataset to evaluate on,
226
+ which can be an iterable of input dictionaries or a callable that returns such an iterable.
227
+ **models (ModelT): The models to evaluate.
228
+
229
+ Returns:
230
+ Mapping[str, Any]: The logs from evaluation, which may include metrics and other information.
231
+ """
232
+ if self.validation_step is None:
233
+ logger.warning("Validation step is not defined. Skipping evaluation.")
234
+ return {}
235
+ if self.inference_wrapper is not None:
236
+ models = self.inference_wrapper(self, **models)
237
+ invoke_callback(self.on_validation_begin, self, **models)
238
+ tracker, validation_step, elapsed_time = self.tracker, self.validation_step, 0.0
239
+ for index, data in enumerate(get_dataset(dataset), start=1):
240
+ invoke_callback(self.on_validation_step_begin, self, **models)
241
+ elapsed_time -= time()
242
+ criteria = validation_step(data, **models)
243
+ self.sync()
244
+ elapsed_time += time()
245
+ logs = tracker(**criteria) | {"elapsed_time": elapsed_time / index}
246
+ if self.validation_prefix:
247
+ logs = {f"{self.validation_prefix}{k}": v for k, v in logs.items()}
248
+ self.logs().update(logs)
249
+ invoke_callback(self.on_validation_step_end, self, **models)
250
+ invoke_callback(self.on_validation_end, self, **models)
251
+ return logs
252
+
253
+ def fit(
254
+ self,
255
+ epochs: int,
256
+ training_dataset: DatasetLike | Callable[[], DatasetLike],
257
+ validation_dataset: DatasetLike | Callable[[], DatasetLike] | None = None,
258
+ start_epoch: int = 1,
259
+ validation_frequency: int = 1,
260
+ **models: ModelT_contra,
261
+ ) -> dict[int, dict[str, Any]]:
262
+ """Fit the model.
263
+
264
+ Args:
265
+ epochs (int): Number of epochs to train.
266
+ training_dataset (DatasetLike | Callable[[], DatasetLike]): Training dataset.
267
+ validation_dataset (DatasetLike | Callable[[], DatasetLike] | None, optional): Validation dataset.
268
+ Defaults to None.
269
+ start_epoch (int, optional): Epoch to start training from. Defaults to 1.
270
+ validation_frequency (int, optional): Frequency of validation. Defaults to 1.
271
+ **models (ModelT): The models to train and validate.
272
+
273
+ Returns:
274
+ History of training and validation logs.
275
+ """
276
+ if validation_frequency < 1:
277
+ raise ValueError("Validation frequency must be at least 1.")
278
+ if start_epoch < 1:
279
+ raise ValueError(f"Start epoch must be at least 1: {start_epoch}")
280
+ if start_epoch > epochs:
281
+ raise ValueError(f"Start epoch must be less than or equal to epochs: {start_epoch} > {epochs}")
282
+ for epoch in range(start_epoch, epochs + 1):
283
+ self.epoch = epoch
284
+ invoke_callback(self.on_epoch_begin, self, **models)
285
+ self.train(training_dataset, **models)
286
+ if validation_dataset is not None and epoch % validation_frequency == 0:
287
+ self.evaluate(validation_dataset, **models)
288
+ invoke_callback(self.on_epoch_end, self, **models)
289
+ return self.history
290
+
291
+
292
+ @dataclass(kw_only=True, slots=True)
293
+ class BestCriterion(Generic[ModelT_contra]):
294
+ """Callback to track the best criterion during training or validation."""
295
+
296
+ target: str
297
+ """The target criterion to monitor."""
298
+
299
+ mode: Literal["min", "max"] = "min"
300
+ """The mode to monitor the criterion. Either 'min' or 'max'."""
301
+
302
+ on_best: list[BestCallback[ModelT_contra]] = field(default_factory=list)
303
+ """Callbacks to be called when a new best criterion is found."""
304
+
305
+ _best: float = field(init=False, repr=False)
306
+ _compare: Callable[[float, float], bool] = field(init=False, repr=False)
307
+
308
+ def __post_init__(self) -> None:
309
+ """Post initialization."""
310
+ self._compare = lt if self.mode == "min" else gt
311
+ self._best = inf if self.mode == "min" else -inf
312
+
313
+ def __call__(self, info: BaseInfo, **models: ModelT_contra) -> None:
314
+ """Check and update the best criterion."""
315
+ current: float | None = info.logs().get(self.target, None)
316
+ if current is not None:
317
+ if self._compare(current, self._best):
318
+ self._best = current
319
+ invoke_callback(self.on_best, info, self.target, self._best, **models)
320
+
321
+
322
+ __all__ = [
323
+ "GLOBAL_CALLBACKS",
324
+ "Backward",
325
+ "BaseInfo",
326
+ "BaseTrainer",
327
+ "BestCallback",
328
+ "BestCriterion",
329
+ "Callback",
330
+ "Callbacks",
331
+ "DatasetLike",
332
+ "Forward",
333
+ "InferenceWrapper",
334
+ "get_dataset",
335
+ "get_dataset_size",
336
+ "invoke_callback",
337
+ ]
338
+
339
+
340
+ if not TYPE_CHECKING:
341
+ import sys
342
+
343
+ from structcast.utils.lazy_import import LazySelectedImporter
344
+
345
+ sys.modules[__name__] = LazySelectedImporter(__name__, globals())
@@ -0,0 +1 @@
1
+ """Builders module for StructCast-Model."""
@@ -0,0 +1,37 @@
1
+ """Auto Name."""
2
+
3
+ from collections import defaultdict
4
+ from typing import TYPE_CHECKING
5
+
6
+
7
+ class AutoName:
8
+ """Auto Name."""
9
+
10
+ def __init__(self, infix: str = "") -> None:
11
+ """Initialize AutoName."""
12
+ self._infix = infix
13
+ self._object_name_uids: dict[str, int] = defaultdict(int)
14
+
15
+ def __call__(self, value: str) -> str:
16
+ """Generate a unique name."""
17
+ if value in self._object_name_uids:
18
+ unique_name = f"{value}{self._infix}{self._object_name_uids[value]}"
19
+ else:
20
+ unique_name = value
21
+ self._object_name_uids[value] += 1
22
+ return unique_name
23
+
24
+ def reset(self) -> None:
25
+ """Reset the name counter."""
26
+ self._object_name_uids.clear()
27
+
28
+
29
+ __all__ = ["AutoName"]
30
+
31
+
32
+ if not TYPE_CHECKING:
33
+ import sys
34
+
35
+ from structcast.utils.lazy_import import LazySelectedImporter
36
+
37
+ sys.modules[__name__] = LazySelectedImporter(__name__, globals())