structcast-model 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- structcast_model/__init__.py +21 -0
- structcast_model/base_trainer.py +345 -0
- structcast_model/builders/__init__.py +1 -0
- structcast_model/builders/auto_name.py +37 -0
- structcast_model/builders/base_builder.py +618 -0
- structcast_model/builders/jinja_filters.py +72 -0
- structcast_model/builders/schema.py +584 -0
- structcast_model/builders/torch_builder.py +174 -0
- structcast_model/commands/__init__.py +1 -0
- structcast_model/commands/cmd_torch.py +508 -0
- structcast_model/commands/main.py +61 -0
- structcast_model/commands/utils.py +75 -0
- structcast_model/torch/__init__.py +1 -0
- structcast_model/torch/layers/__init__.py +37 -0
- structcast_model/torch/layers/accuracy.py +26 -0
- structcast_model/torch/layers/add.py +16 -0
- structcast_model/torch/layers/channel_shuffle.py +27 -0
- structcast_model/torch/layers/concatenate.py +45 -0
- structcast_model/torch/layers/criteria_tracker.py +38 -0
- structcast_model/torch/layers/fold.py +118 -0
- structcast_model/torch/layers/lazy_norm.py +97 -0
- structcast_model/torch/layers/multiply.py +16 -0
- structcast_model/torch/layers/permute.py +72 -0
- structcast_model/torch/layers/reduce.py +24 -0
- structcast_model/torch/layers/reinmax.py +67 -0
- structcast_model/torch/layers/scale_identity.py +30 -0
- structcast_model/torch/layers/split.py +43 -0
- structcast_model/torch/optimizers.py +257 -0
- structcast_model/torch/trainer.py +682 -0
- structcast_model/torch/types.py +38 -0
- structcast_model/utils/__init__.py +1 -0
- structcast_model/utils/base.py +127 -0
- structcast_model-1.0.0.dist-info/METADATA +763 -0
- structcast_model-1.0.0.dist-info/RECORD +36 -0
- structcast_model-1.0.0.dist-info/WHEEL +4 -0
- structcast_model-1.0.0.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""StructCast-Model: Construct neural network models and training workflows by structcast package."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
__version__ = "1.0.0"
|
|
6
|
+
__all__ = ["base_trainer", "builders", "torch", "utils"]
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from structcast_model import base_trainer, builders, torch, utils
|
|
10
|
+
else:
|
|
11
|
+
import sys
|
|
12
|
+
|
|
13
|
+
from structcast.utils.lazy_import import LazySelectedImporter
|
|
14
|
+
|
|
15
|
+
import_structure = {
|
|
16
|
+
"builders": [],
|
|
17
|
+
"torch": [],
|
|
18
|
+
"utils": [],
|
|
19
|
+
"base_trainer": [],
|
|
20
|
+
}
|
|
21
|
+
sys.modules[__name__] = LazySelectedImporter(__name__, globals(), import_structure)
|
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
"""Base trainer for training a model."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable, Iterable, Mapping
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from logging import getLogger
|
|
6
|
+
from math import inf
|
|
7
|
+
from operator import gt, lt
|
|
8
|
+
from time import time
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeAlias, TypeVar
|
|
10
|
+
|
|
11
|
+
logger = getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
ModelT_contra = TypeVar("ModelT_contra", contravariant=True)
|
|
14
|
+
|
|
15
|
+
DatasetLike: TypeAlias = Iterable[dict[str, Any]]
|
|
16
|
+
"""Dataset-like object."""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_dataset(dataset: DatasetLike | Callable[[], DatasetLike]) -> Iterable[dict[str, Any]]:
    """Resolve *dataset* to an iterable, calling it first when it is a factory callable."""
    if callable(dataset):
        return dataset()
    return dataset
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_dataset_size(dataset: DatasetLike | Callable[[], DatasetLike]) -> int:
    """Return the number of items in the dataset.

    Args:
        dataset (DatasetLike | Callable[[], DatasetLike]): The dataset, or a callable
            returning it.

    Returns:
        int: ``len(dataset)`` when the dataset is sized; otherwise the count obtained
        by exhausting the iterable (note: this consumes a one-shot iterator).
    """
    dataset = get_dataset(dataset)
    if hasattr(dataset, "__len__"):
        # Sized container: use the len() built-in rather than calling the
        # dunder directly (idiomatic, and lets len() validate the result).
        return len(dataset)
    # Fallback: count by iteration; consumes generators/one-shot iterators.
    return sum(1 for _ in dataset)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Forward(Protocol[ModelT_contra]):
    """Protocol for forward pass configuration.

    Implementations receive one batch of inputs plus the models (as keyword
    arguments) and return a mapping of named outputs/criteria.
    """

    def __call__(self, inputs: Any, **models: ModelT_contra) -> dict[str, Any]:
        """Perform the forward pass for the given inputs and return the outputs and any additional information."""
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class Backward(Protocol):
    """Protocol for backward pass configuration.

    Implementations receive the current step index and the criteria produced by
    the forward pass (passed as keyword arguments by the trainer).
    """

    def __call__(self, step: int, *args: Any, **kwargs: Any) -> bool:
        """Perform the backward pass for the given step and losses, and return whether the model has been updated."""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass(kw_only=True)
|
|
47
|
+
class BaseInfo:
|
|
48
|
+
"""Base information for building a model."""
|
|
49
|
+
|
|
50
|
+
step: int = 0
|
|
51
|
+
"""The current training step."""
|
|
52
|
+
|
|
53
|
+
update: int = 0
|
|
54
|
+
"""The number of times the model has been updated."""
|
|
55
|
+
|
|
56
|
+
epoch: int = 0
|
|
57
|
+
"""The current epoch."""
|
|
58
|
+
|
|
59
|
+
history: dict[int, dict[str, Any]] = field(default_factory=dict)
|
|
60
|
+
"""History of training and validation logs."""
|
|
61
|
+
|
|
62
|
+
def logs(self, epoch: int | None = None) -> dict[str, Any]:
|
|
63
|
+
"""Get the log for the given epoch."""
|
|
64
|
+
if epoch is None:
|
|
65
|
+
return self.history.setdefault(self.epoch, {})
|
|
66
|
+
if epoch in self.history:
|
|
67
|
+
return self.history[epoch]
|
|
68
|
+
raise KeyError(f"No logs found for key: {epoch}.")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class Callback(Protocol, Generic[ModelT_contra]):
    """Protocol for callbacks.

    A callback receives the trainer's current state (``BaseInfo``) and the
    models (as keyword arguments) at a specific point of the training loop.
    """

    def __call__(self, info: BaseInfo, **models: ModelT_contra) -> None:
        """Call the callback with the given information."""
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class BestCallback(Protocol[ModelT_contra]):
    """Protocol for best criterion callback.

    Fired by ``BestCriterion`` whenever the monitored criterion improves.
    """

    def __call__(self, info: BaseInfo, target: str, best: float, **models: ModelT_contra) -> None:
        """Call the callback with the given info, target criterion, and best value."""
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def invoke_callback(callbacks: list[Callable[..., None]], info: BaseInfo, *args: Any, **models: ModelT_contra) -> None:
    """Invoke every callback in *callbacks*, in order, with the info, extras, and models."""
    for cb in callbacks:
        cb(info, *args, **models)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass(kw_only=True)
class Callbacks(Generic[ModelT_contra]):
    """Callbacks.

    Container of the hook lists fired at each point of the training loop.
    """

    on_update: list[Callback[ModelT_contra]] = field(default_factory=list)
    """Callbacks to call after each update."""

    on_training_begin: list[Callback[ModelT_contra]] = field(default_factory=list)
    """Callbacks to call at the beginning of training."""

    on_training_end: list[Callback[ModelT_contra]] = field(default_factory=list)
    """Callbacks to call at the end of training."""

    on_training_step_begin: list[Callback[ModelT_contra]] = field(default_factory=list)
    """Callbacks to be called at the beginning of each training step."""

    on_training_step_end: list[Callback[ModelT_contra]] = field(default_factory=list)
    """Callbacks to be called at the end of each training step."""

    on_validation_begin: list[Callback[ModelT_contra]] = field(default_factory=list)
    """Callbacks to be called at the beginning of validation."""

    on_validation_end: list[Callback[ModelT_contra]] = field(default_factory=list)
    """Callbacks to be called at the end of validation."""

    on_validation_step_begin: list[Callback[ModelT_contra]] = field(default_factory=list)
    """Callbacks to be called at the beginning of each validation step."""

    on_validation_step_end: list[Callback[ModelT_contra]] = field(default_factory=list)
    """Callbacks to be called at the end of each validation step."""

    on_epoch_begin: list[Callback[ModelT_contra]] = field(default_factory=list)
    """Callbacks to be called at the beginning of each epoch."""

    on_epoch_end: list[Callback[ModelT_contra]] = field(default_factory=list)
    """Callbacks to be called at the end of each epoch."""

    add_global_callbacks: bool = True
    """Whether to add global callbacks."""

    def __post_init__(self) -> None:
        """Append the global callbacks to every hook list (unless disabled)."""
        if not self.add_global_callbacks:
            # Also the path taken while constructing GLOBAL_CALLBACKS itself,
            # before the module-level name exists.
            return
        from dataclasses import fields

        # Iterate the declared hook fields instead of hand-listing each one;
        # a hook added to this class can no longer be forgotten here.
        for spec in fields(Callbacks):
            if spec.name == "add_global_callbacks":
                continue
            getattr(self, spec.name).extend(getattr(GLOBAL_CALLBACKS, spec.name))


GLOBAL_CALLBACKS = Callbacks[Any](add_global_callbacks=False)
"""Global callbacks."""
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class InferenceWrapper(Protocol[ModelT_contra]):
    """Protocol for inference wrapper.

    Receives the trainer info and the models (as keyword arguments) and returns
    a replacement mapping of models to use during evaluation.
    """

    def __call__(self, info: BaseInfo, **models: ModelT_contra) -> dict[str, Any]:
        """Wrap the model for inference, e.g., for quantization or ONNX export."""
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@dataclass(kw_only=True)
class BaseTrainer(BaseInfo, Callbacks[ModelT_contra]):
    """Base trainer for training a model."""

    training_step: Forward[ModelT_contra]
    """The forward pass configuration for training."""

    backward: Backward
    """The backward pass configuration."""

    tracker: Callable[..., dict[str, float]]
    """The tracker to log training and validation information."""

    inference_wrapper: InferenceWrapper[ModelT_contra] | None = None
    """An optional wrapper to apply to the model during inference, e.g., for quantization or ONNX export."""

    validation_step: Forward[ModelT_contra] | None = None
    """The forward pass configuration for validation."""

    training_prefix: str = ""
    """Prefix for training logs."""

    validation_prefix: str = "val_"
    """Prefix for validation logs."""

    history: dict[int, dict[str, Any]] = field(default_factory=dict)
    """History of training and validation logs."""

    def sync(self) -> None:
        """Synchronize the device if necessary. This is a no-op by default, but can be overridden by subclasses."""

    def train(self, dataset: DatasetLike | Callable[[], DatasetLike], **models: ModelT_contra) -> Mapping[str, Any]:
        """Train the model on the given dataset.

        Args:
            dataset (DatasetLike | Callable[[], DatasetLike]): The dataset to train on,
                which can be an iterable of input dictionaries or a callable that returns such an iterable.
            **models (ModelT): The models to train.

        Returns:
            Mapping[str, Any]: The logs from the last training step, or an empty
            mapping when the dataset yields nothing.
        """
        invoke_callback(self.on_training_begin, self, **models)
        tracker, training_step, backward, elapsed_time = self.tracker, self.training_step, self.backward, 0.0
        # BUGFIX: initialize so an empty dataset returns {} instead of raising
        # UnboundLocalError at the final `return logs`.
        logs: dict[str, Any] = {}
        for index, inputs in enumerate(get_dataset(dataset), start=1):
            self.step += 1
            invoke_callback(self.on_training_step_begin, self, **models)
            elapsed_time -= time()  # start timing this step
            criteria = training_step(inputs, **models)
            should_update = backward(self.step, **criteria)
            self.sync()  # flush async device work so the timing is honest
            elapsed_time += time()  # accumulate total wall-clock seconds
            # Report the mean time per step alongside the tracked criteria.
            logs = tracker(**criteria) | {"elapsed_time": elapsed_time / index}
            if self.training_prefix:
                logs = {f"{self.training_prefix}{k}": v for k, v in logs.items()}
            self.logs().update(logs)
            if should_update:
                self.update += 1
                invoke_callback(self.on_update, self, **models)
            invoke_callback(self.on_training_step_end, self, **models)
        invoke_callback(self.on_training_end, self, **models)
        return logs

    def evaluate(self, dataset: DatasetLike | Callable[[], DatasetLike], **models: ModelT_contra) -> Mapping[str, Any]:
        """Evaluate the model on the given dataset.

        Args:
            dataset (DatasetLike | Callable[[], DatasetLike]): The dataset to evaluate on,
                which can be an iterable of input dictionaries or a callable that returns such an iterable.
            **models (ModelT): The models to evaluate.

        Returns:
            Mapping[str, Any]: The logs from the last validation step, or an empty
            mapping when validation is skipped or the dataset yields nothing.
        """
        if self.validation_step is None:
            logger.warning("Validation step is not defined. Skipping evaluation.")
            return {}
        if self.inference_wrapper is not None:
            # The wrapper may substitute models (e.g. quantized/exported ones).
            models = self.inference_wrapper(self, **models)
        invoke_callback(self.on_validation_begin, self, **models)
        tracker, validation_step, elapsed_time = self.tracker, self.validation_step, 0.0
        # BUGFIX: initialize so an empty dataset returns {} instead of raising
        # UnboundLocalError at the final `return logs`.
        logs: dict[str, Any] = {}
        for index, data in enumerate(get_dataset(dataset), start=1):
            invoke_callback(self.on_validation_step_begin, self, **models)
            elapsed_time -= time()
            criteria = validation_step(data, **models)
            self.sync()
            elapsed_time += time()
            logs = tracker(**criteria) | {"elapsed_time": elapsed_time / index}
            if self.validation_prefix:
                logs = {f"{self.validation_prefix}{k}": v for k, v in logs.items()}
            self.logs().update(logs)
            invoke_callback(self.on_validation_step_end, self, **models)
        invoke_callback(self.on_validation_end, self, **models)
        return logs

    def fit(
        self,
        epochs: int,
        training_dataset: DatasetLike | Callable[[], DatasetLike],
        validation_dataset: DatasetLike | Callable[[], DatasetLike] | None = None,
        start_epoch: int = 1,
        validation_frequency: int = 1,
        **models: ModelT_contra,
    ) -> dict[int, dict[str, Any]]:
        """Fit the model.

        Args:
            epochs (int): Number of epochs to train.
            training_dataset (DatasetLike | Callable[[], DatasetLike]): Training dataset.
            validation_dataset (DatasetLike | Callable[[], DatasetLike] | None, optional): Validation dataset.
                Defaults to None.
            start_epoch (int, optional): Epoch to start training from. Defaults to 1.
            validation_frequency (int, optional): Frequency of validation. Defaults to 1.
            **models (ModelT): The models to train and validate.

        Raises:
            ValueError: If ``validation_frequency`` or ``start_epoch`` is below 1,
                or ``start_epoch`` exceeds ``epochs``.

        Returns:
            History of training and validation logs.
        """
        if validation_frequency < 1:
            raise ValueError("Validation frequency must be at least 1.")
        if start_epoch < 1:
            raise ValueError(f"Start epoch must be at least 1: {start_epoch}")
        if start_epoch > epochs:
            raise ValueError(f"Start epoch must be less than or equal to epochs: {start_epoch} > {epochs}")
        for epoch in range(start_epoch, epochs + 1):
            self.epoch = epoch
            invoke_callback(self.on_epoch_begin, self, **models)
            self.train(training_dataset, **models)
            # Validate every `validation_frequency` epochs, when a dataset is given.
            if validation_dataset is not None and epoch % validation_frequency == 0:
                self.evaluate(validation_dataset, **models)
            invoke_callback(self.on_epoch_end, self, **models)
        return self.history
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
@dataclass(kw_only=True, slots=True)
class BestCriterion(Generic[ModelT_contra]):
    """Callback to track the best criterion during training or validation."""

    target: str
    """The target criterion to monitor."""

    mode: Literal["min", "max"] = "min"
    """The mode to monitor the criterion. Either 'min' or 'max'."""

    on_best: list[BestCallback[ModelT_contra]] = field(default_factory=list)
    """Callbacks to be called when a new best criterion is found."""

    _best: float = field(init=False, repr=False)
    _compare: Callable[[float, float], bool] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Choose the comparison operator and starting sentinel from the mode."""
        if self.mode == "min":
            self._compare, self._best = lt, inf
        else:
            self._compare, self._best = gt, -inf

    def __call__(self, info: BaseInfo, **models: ModelT_contra) -> None:
        """Check and update the best criterion."""
        current: float | None = info.logs().get(self.target)
        if current is None:
            # Target criterion absent from this epoch's logs; nothing to do.
            return
        if not self._compare(current, self._best):
            return
        # New best value: record it, then notify the callbacks.
        self._best = current
        invoke_callback(self.on_best, info, self.target, self._best, **models)
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
__all__ = [
    "GLOBAL_CALLBACKS",
    "Backward",
    "BaseInfo",
    "BaseTrainer",
    "BestCallback",
    "BestCriterion",
    "Callback",
    "Callbacks",
    "DatasetLike",
    "Forward",
    "InferenceWrapper",
    "get_dataset",
    "get_dataset_size",
    "invoke_callback",
]


if not TYPE_CHECKING:
    import sys

    from structcast.utils.lazy_import import LazySelectedImporter

    # At runtime, replace this module with a lazy proxy (static analysis keeps
    # the plain module via the TYPE_CHECKING guard above).
    sys.modules[__name__] = LazySelectedImporter(__name__, globals())
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Builders module for StructCast-Model."""
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Auto Name."""
|
|
2
|
+
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AutoName:
    """Generate deduplicated names by suffixing repeats with a per-name counter."""

    def __init__(self, infix: str = "") -> None:
        """Initialize with *infix*, the separator placed between a repeated name and its counter."""
        self._infix = infix
        self._object_name_uids: dict[str, int] = defaultdict(int)

    def __call__(self, value: str) -> str:
        """Return *value* unchanged on first use, then "{value}{infix}{n}" for the n-th repeat.

        NOTE(review): a generated name can still collide with a later literal
        input (auto "x1" vs a user-supplied "x1") — presumably acceptable upstream.
        """
        count = self._object_name_uids[value]
        unique_name = value if count == 0 else f"{value}{self._infix}{count}"
        self._object_name_uids[value] = count + 1
        return unique_name

    def reset(self) -> None:
        """Reset the name counter."""
        self._object_name_uids.clear()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
__all__ = ["AutoName"]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
if not TYPE_CHECKING:
|
|
33
|
+
import sys
|
|
34
|
+
|
|
35
|
+
from structcast.utils.lazy_import import LazySelectedImporter
|
|
36
|
+
|
|
37
|
+
sys.modules[__name__] = LazySelectedImporter(__name__, globals())
|