cuvis-ai-schemas 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuvis_ai_schemas/__init__.py +5 -0
- cuvis_ai_schemas/discovery/__init__.py +6 -0
- cuvis_ai_schemas/enums/__init__.py +5 -0
- cuvis_ai_schemas/enums/types.py +30 -0
- cuvis_ai_schemas/execution/__init__.py +12 -0
- cuvis_ai_schemas/execution/context.py +41 -0
- cuvis_ai_schemas/execution/monitoring.py +83 -0
- cuvis_ai_schemas/extensions/__init__.py +3 -0
- cuvis_ai_schemas/extensions/ui/__init__.py +8 -0
- cuvis_ai_schemas/extensions/ui/port_display.py +159 -0
- cuvis_ai_schemas/grpc/__init__.py +3 -0
- cuvis_ai_schemas/grpc/v1/__init__.py +11 -0
- cuvis_ai_schemas/grpc/v1/cuvis_ai_pb2.py +240 -0
- cuvis_ai_schemas/grpc/v1/cuvis_ai_pb2.pyi +1046 -0
- cuvis_ai_schemas/grpc/v1/cuvis_ai_pb2_grpc.py +1290 -0
- cuvis_ai_schemas/pipeline/__init__.py +17 -0
- cuvis_ai_schemas/pipeline/config.py +238 -0
- cuvis_ai_schemas/pipeline/ports.py +48 -0
- cuvis_ai_schemas/plugin/__init__.py +6 -0
- cuvis_ai_schemas/plugin/config.py +118 -0
- cuvis_ai_schemas/plugin/manifest.py +95 -0
- cuvis_ai_schemas/training/__init__.py +40 -0
- cuvis_ai_schemas/training/callbacks.py +137 -0
- cuvis_ai_schemas/training/config.py +135 -0
- cuvis_ai_schemas/training/data.py +73 -0
- cuvis_ai_schemas/training/optimizer.py +94 -0
- cuvis_ai_schemas/training/run.py +198 -0
- cuvis_ai_schemas/training/scheduler.py +69 -0
- cuvis_ai_schemas/training/trainer.py +40 -0
- cuvis_ai_schemas-0.1.0.dist-info/METADATA +111 -0
- cuvis_ai_schemas-0.1.0.dist-info/RECORD +34 -0
- cuvis_ai_schemas-0.1.0.dist-info/WHEEL +5 -0
- cuvis_ai_schemas-0.1.0.dist-info/licenses/LICENSE +190 -0
- cuvis_ai_schemas-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""Callback configuration schemas."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import timedelta
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
try:
|
|
12
|
+
from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
|
|
13
|
+
except ImportError:
|
|
14
|
+
cuvis_ai_pb2 = None # type: ignore[assignment]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class _BaseConfig(BaseModel):
    """Shared base for callback schemas.

    Enforces strict validation: unknown keys are rejected, assignments are
    re-validated, and fields may be populated by name as well as by alias.
    """

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        populate_by_name=True,
    )
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class EarlyStoppingConfig(_BaseConfig):
    """Early stopping callback configuration.

    NOTE(review): field names mirror Lightning's ``EarlyStopping`` callback
    arguments, but the actual mapping happens outside this schema — confirm
    against the consumer.
    """

    # Required: name of the logged metric this callback watches.
    monitor: str = Field(description="Metric to monitor")
    patience: int = Field(default=10, ge=1, description="Number of epochs to wait")
    # Direction of improvement for `monitor`; not constrained here.
    mode: str = Field(default="min", description="min or max")
    min_delta: float = Field(default=0.0, ge=0.0, description="Minimum change to qualify")
    stopping_threshold: float | None = Field(
        default=None, description="Stop once monitored metric reaches this threshold"
    )
    verbose: bool = Field(default=True, description="Whether to log state changes")
    strict: bool = Field(default=True, description="Whether to crash if monitor is not found")
    check_finite: bool = Field(
        default=True, description="Stop when monitor becomes NaN or infinite"
    )
    divergence_threshold: float | None = Field(
        default=None,
        description="Stop training when monitor becomes worse than this threshold",
    )
    # None means "let the consumer decide" rather than an explicit choice.
    check_on_train_epoch_end: bool | None = Field(
        default=None,
        description="Whether to run early stopping at end of training epoch",
    )
    log_rank_zero_only: bool = Field(
        default=False, description="Log status only for rank 0 process"
    )
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ModelCheckpointConfig(_BaseConfig):
    """Model checkpoint callback configuration.

    NOTE(review): field names mirror Lightning's ``ModelCheckpoint``
    constructor arguments; the mapping happens outside this schema.
    """

    dirpath: str = Field(default="checkpoints", description="Directory to save checkpoints")
    filename: str | None = Field(default=None, description="Checkpoint filename pattern")
    monitor: str = Field(default="val_loss", description="Metric to monitor")
    # Direction of improvement for `monitor`; not constrained here.
    mode: str = Field(default="min", description="min or max")
    # -1 is a sentinel meaning "keep every checkpoint".
    save_top_k: int = Field(default=3, ge=-1, description="Save top k checkpoints (-1 for all)")
    every_n_epochs: int = Field(default=1, ge=1, description="Save every n epochs")
    # Tri-state: False/None disable, True saves a copy, "link" symlinks it.
    save_last: bool | Literal["link"] | None = Field(
        default=False, description="Also save last checkpoint (or 'link' for symlink)"
    )
    auto_insert_metric_name: bool = Field(
        default=True, description="Automatically insert metric name into filename"
    )
    verbose: bool = Field(default=False, description="Verbosity mode")
    save_on_exception: bool = Field(
        default=False, description="Whether to save checkpoint when exception is raised"
    )
    save_weights_only: bool = Field(
        default=False,
        description="If True, only save model weights, not optimizer states",
    )
    every_n_train_steps: int | None = Field(
        default=None,
        description="How many training steps to wait before saving checkpoint",
    )
    # timedelta is NOT JSON/YAML-native; serializers must convert it first.
    train_time_interval: timedelta | None = Field(
        default=None, description="Checkpoints monitored at specified time interval"
    )
    save_on_train_epoch_end: bool | None = Field(
        default=None,
        description="Whether to run checkpointing at end of training epoch",
    )
    enable_version_counter: bool = Field(
        default=True, description="Whether to append version to existing file name"
    )
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class LearningRateMonitorConfig(_BaseConfig):
    """Configuration for a learning-rate monitor callback.

    Controls the logging cadence for the learning rate and whether the
    optimizer's momentum and weight-decay values are logged alongside it.
    """

    logging_interval: Literal["step", "epoch"] | None = Field(
        default="epoch",
        description="Log lr at 'epoch' or 'step'",
    )
    log_momentum: bool = Field(
        default=False,
        description="Log momentum values as well",
    )
    log_weight_decay: bool = Field(
        default=False,
        description="Log weight decay values as well",
    )
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class CallbacksConfig(_BaseConfig):
    """Callbacks configuration.

    Aggregates the individual callback schemas; every section is optional,
    so an empty instance means "no callbacks configured".
    """

    checkpoint: ModelCheckpointConfig | None = Field(
        default=None,
        description="Model checkpoint configuration",
        # Accepted as "model_checkpoint" on input; populate_by_name on the
        # base class also allows the attribute name "checkpoint".
        alias="model_checkpoint",
    )
    early_stopping: list[EarlyStoppingConfig] = Field(
        default_factory=list, description="Early stopping configuration(s)"
    )
    learning_rate_monitor: LearningRateMonitorConfig | None = Field(
        default=None, description="Learning rate monitor configuration"
    )

    def to_proto(self) -> Any:
        """Convert to protobuf message (requires proto extra).

        The whole model is serialized to JSON and carried opaquely in the
        proto's ``config_bytes`` field.

        Raises:
            ImportError: If the generated gRPC bindings are not installed.
        """
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as e:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from e

        return cuvis_ai_pb2.CallbacksConfig(config_bytes=self.model_dump_json().encode("utf-8"))

    @classmethod
    def from_proto(cls, proto_config: Any) -> CallbacksConfig:
        """Create from protobuf message (requires proto extra)."""
        return cls.model_validate_json(proto_config.config_bytes.decode("utf-8"))
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# Public API of this module.
__all__ = [
    "EarlyStoppingConfig",
    "ModelCheckpointConfig",
    "LearningRateMonitorConfig",
    "CallbacksConfig",
]
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Main training configuration schema."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
8
|
+
|
|
9
|
+
from cuvis_ai_schemas.training.callbacks import CallbacksConfig
|
|
10
|
+
from cuvis_ai_schemas.training.optimizer import OptimizerConfig
|
|
11
|
+
from cuvis_ai_schemas.training.scheduler import SchedulerConfig
|
|
12
|
+
from cuvis_ai_schemas.training.trainer import TrainerConfig
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
try:
|
|
16
|
+
from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
|
|
17
|
+
except ImportError:
|
|
18
|
+
cuvis_ai_pb2 = None # type: ignore[assignment]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TrainingConfig(BaseModel):
    """Complete training configuration.

    Top-level hyperparameters (``max_epochs``, ``gradient_clip_val``,
    ``accumulate_grad_batches``) are mirrored into the nested ``trainer``
    config by ``_sync_trainer_fields`` after validation.
    """

    seed: int = Field(default=42, ge=0, description="Random seed for reproducibility")
    optimizer: OptimizerConfig = Field(
        default_factory=OptimizerConfig, description="Optimizer configuration"
    )
    scheduler: SchedulerConfig | None = Field(
        default=None, description="Learning rate scheduler (optional)"
    )
    callbacks: CallbacksConfig | None = Field(
        default=None, description="Training callbacks (optional)"
    )
    trainer: TrainerConfig = Field(
        default_factory=TrainerConfig, description="Lightning Trainer configuration"
    )
    max_epochs: int = Field(default=100, ge=1, le=10000, description="Maximum training epochs")
    batch_size: int = Field(default=32, ge=1, description="Batch size")
    num_workers: int = Field(default=4, ge=0, description="Number of data loading workers")
    gradient_clip_val: float | None = Field(
        default=None, ge=0.0, description="Gradient clipping value (optional)"
    )
    accumulate_grad_batches: int = Field(
        default=1, ge=1, description="Accumulate gradients over n batches"
    )

    model_config = ConfigDict(extra="forbid", validate_assignment=True, populate_by_name=True)

    @model_validator(mode="after")
    def _sync_trainer_fields(self) -> TrainingConfig:
        """Keep top-level hyperparameters in sync with trainer config.

        Rule: a trainer-level value wins only when the top-level field was
        NOT explicitly set by the caller; otherwise the top-level value is
        pushed down into ``trainer``.

        NOTE(review): the assignments below happen with
        ``validate_assignment=True`` in effect, which re-validates on every
        assignment — confirm with the pydantic version in use that this
        does not re-enter this validator recursively.
        """
        # model_fields_set holds the names the caller explicitly provided.
        fields_set: set[str] = getattr(self, "model_fields_set", set())

        # max_epochs: prefer explicit trainer value when top-level not provided
        if "max_epochs" not in fields_set and self.trainer.max_epochs is not None:
            self.max_epochs = self.trainer.max_epochs
        else:
            # Also covers the case where neither side was set: the top-level
            # default (100) overwrites a trainer.max_epochs of None.
            self.trainer.max_epochs = self.max_epochs

        # gradient_clip_val
        if "gradient_clip_val" not in fields_set and self.trainer.gradient_clip_val is not None:
            self.gradient_clip_val = self.trainer.gradient_clip_val
        elif self.gradient_clip_val is not None:
            # Unlike max_epochs, a None top-level value is NOT pushed down
            # (elif, not else) — presumably intentional since None means
            # "no clipping"; confirm.
            self.trainer.gradient_clip_val = self.gradient_clip_val

        # accumulate_grad_batches
        if (
            "accumulate_grad_batches" not in fields_set
            and self.trainer.accumulate_grad_batches is not None
        ):
            self.accumulate_grad_batches = self.trainer.accumulate_grad_batches
        else:
            self.trainer.accumulate_grad_batches = self.accumulate_grad_batches

        # callbacks
        if self.callbacks is not None:
            self.trainer.callbacks = self.callbacks
        return self

    def to_proto(self) -> Any:
        """Convert to protobuf message (requires proto extra).

        Raises:
            ImportError: If the generated gRPC bindings are not installed.
        """
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as e:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from e

        return cuvis_ai_pb2.TrainingConfig(config_bytes=self.model_dump_json().encode("utf-8"))

    @classmethod
    def from_proto(cls, proto_config: Any) -> TrainingConfig:
        """Create from protobuf message (requires proto extra)."""
        return cls.model_validate_json(proto_config.config_bytes.decode("utf-8"))

    def to_json(self) -> str:
        """JSON serialization helper for legacy callers."""
        return self.model_dump_json()

    @classmethod
    def from_json(cls, payload: str) -> TrainingConfig:
        """Create from JSON string."""
        return cls.model_validate_json(payload)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return self.model_dump()

    def to_dict_config(self) -> dict[str, Any]:
        """Compatibility shim for legacy OmegaConf usage.

        Falls back to a plain dict when OmegaConf is unavailable; the broad
        ``except Exception`` is a deliberate best-effort guard.
        """
        try:
            from omegaconf import OmegaConf
        except Exception:
            return self.model_dump()

        return OmegaConf.create(self.model_dump())  # type: ignore[return-value]

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> TrainingConfig:
        """Create from dictionary."""
        return cls.model_validate(data)

    @classmethod
    def from_dict_config(cls, config: dict[str, Any]) -> TrainingConfig:
        """Create from DictConfig (OmegaConf) or dictionary."""
        # Duck-typed class-name check avoids importing omegaconf unless the
        # caller actually passed a DictConfig.
        if config.__class__.__name__ == "DictConfig":  # Avoid hard dependency in type hints
            from omegaconf import OmegaConf

            config = OmegaConf.to_container(config, resolve=True)  # type: ignore[assignment]
        elif not isinstance(config, dict):
            config = dict(config)
        return cls.model_validate(config)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# Public API of this module.
__all__ = ["TrainingConfig"]
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""Data configuration schema."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
try:
|
|
11
|
+
from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
|
|
12
|
+
except ImportError:
|
|
13
|
+
cuvis_ai_pb2 = None # type: ignore[assignment]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DataConfig(BaseModel):
    """Data loading configuration.

    Describes where the cube data lives, how it is split into train /
    validation / test subsets, and how it is batched.
    """

    cu3s_file_path: str = Field(description="Path to .cu3s file")
    annotation_json_path: str | None = Field(
        default=None, description="Path to annotation JSON (optional)"
    )
    train_ids: list[int] = Field(default_factory=list, description="Training sample IDs")
    val_ids: list[int] = Field(default_factory=list, description="Validation sample IDs")
    test_ids: list[int] = Field(default_factory=list, description="Test sample IDs")
    train_split: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Training split ratio"
    )
    val_split: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Validation split ratio"
    )
    shuffle: bool = Field(default=True, description="Shuffle dataset")
    batch_size: int = Field(default=1, ge=1, description="Batch size")
    processing_mode: str = Field(default="Reflectance", description="Raw or Reflectance mode")

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        populate_by_name=True,
    )

    def to_proto(self) -> Any:
        """Serialize into the gRPC ``DataConfig`` message (proto extra needed)."""
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as err:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from err

        payload = self.model_dump_json().encode("utf-8")
        return cuvis_ai_pb2.DataConfig(config_bytes=payload)

    @classmethod
    def from_proto(cls, proto_config):
        """Rebuild a ``DataConfig`` from its gRPC message (proto extra needed)."""
        raw = proto_config.config_bytes.decode("utf-8")
        return cls.model_validate_json(raw)

    def to_json(self) -> str:
        """Dump the model as a JSON string (legacy-caller helper)."""
        return self.model_dump_json()

    @classmethod
    def from_json(cls, payload: str) -> DataConfig:
        """Build an instance from a JSON string."""
        return cls.model_validate_json(payload)

    def to_dict(self) -> dict[str, Any]:
        """Dump the model as a plain dictionary."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> DataConfig:
        """Build an instance from a plain dictionary."""
        return cls.model_validate(data)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Public API of this module.
__all__ = ["DataConfig"]
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Optimizer configuration schema."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
try:
|
|
11
|
+
from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
|
|
12
|
+
except ImportError:
|
|
13
|
+
cuvis_ai_pb2 = None # type: ignore[assignment]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OptimizerConfig(BaseModel):
    """Optimizer configuration with constraints and documentation."""

    # Free-form identifier; the description lists expected values but
    # nothing here restricts them to that set.
    name: str = Field(
        default="adamw",
        description="Optimizer type: adamw, sgd, adam",
    )
    lr: float = Field(
        default=1e-3,
        gt=0.0,
        le=1.0,
        description="Learning rate",
        # NOTE(review): this "minimum" only decorates the emitted JSON
        # schema; the bound actually enforced is gt=0.0 above — confirm
        # which lower bound is intended.
        json_schema_extra={"minimum": 1e-6},
    )
    weight_decay: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="L2 regularization coefficient",
    )
    momentum: float | None = Field(
        default=0.9,
        ge=0.0,
        le=1.0,
        description="Momentum factor (for SGD)",
    )
    betas: tuple[float, float] | None = Field(default=None, description="Adam betas (beta1, beta2)")

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        populate_by_name=True,
        json_schema_extra={
            "examples": [
                {
                    "name": "adamw",
                    "lr": 0.001,
                    "weight_decay": 0.01,
                }
            ]
        },
    )

    @field_validator("betas")
    @classmethod
    def _validate_betas(cls, value: tuple[float, float] | None) -> tuple[float, float] | None:
        """Validate that betas is a tuple of exactly 2 floats."""
        if value is None:
            return value
        if len(value) != 2:
            raise ValueError("betas must be a tuple of length 2")
        return value

    @field_validator("lr")
    @classmethod
    def _validate_lr(cls, value: float) -> float:
        """Validate that learning rate is non-zero.

        Defensive: the ``gt=0.0`` constraint on the field already rejects 0
        before this after-validator runs, so the branch below should be
        unreachable in practice.
        """
        if value == 0:
            raise ValueError("Learning rate must be non-zero")
        return value

    def to_proto(self) -> Any:
        """Convert to protobuf message (requires proto extra).

        Raises:
            ImportError: If the generated gRPC bindings are not installed.
        """
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as e:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from e

        return cuvis_ai_pb2.OptimizerConfig(config_bytes=self.model_dump_json().encode("utf-8"))

    @classmethod
    def from_proto(cls, proto_config: Any) -> OptimizerConfig:
        """Create from protobuf message (requires proto extra)."""
        return cls.model_validate_json(proto_config.config_bytes.decode("utf-8"))
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Public API of this module.
__all__ = ["OptimizerConfig"]
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
"""Training run configuration schema."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
import yaml
|
|
9
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
10
|
+
|
|
11
|
+
from cuvis_ai_schemas.training.config import TrainingConfig
|
|
12
|
+
from cuvis_ai_schemas.training.data import DataConfig
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
try:
|
|
16
|
+
from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
|
|
17
|
+
except ImportError:
|
|
18
|
+
cuvis_ai_pb2 = None # type: ignore[assignment]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class PipelineMetadata(BaseModel):
    """Pipeline metadata for documentation and discovery."""

    # Required human-readable identifier of the pipeline.
    name: str
    description: str = ""
    # Free-form creation timestamp; format is not enforced here.
    created: str = ""
    tags: list[str] = Field(default_factory=list)
    author: str = ""
    cuvis_ai_version: str = Field(default="0.1.0")

    model_config = ConfigDict(extra="forbid", validate_assignment=True, populate_by_name=True)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> PipelineMetadata:
        """Create from dictionary."""
        return cls.model_validate(data)

    def to_proto(self) -> Any:
        """Convert to protobuf message (requires proto extra).

        Unlike the ``config_bytes``-based configs in this package, metadata
        maps field-by-field onto the proto message.

        Raises:
            ImportError: If the generated gRPC bindings are not installed.
        """
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as e:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from e

        return cuvis_ai_pb2.PipelineMetadata(
            name=self.name,
            description=self.description,
            created=self.created,
            # Copy so the proto does not alias this model's list.
            tags=list(self.tags),
            author=self.author,
            cuvis_ai_version=self.cuvis_ai_version,
        )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class PipelineConfig(BaseModel):
    """Pipeline structure configuration.

    Holds the node definitions, the connections between them, the set of
    nodes frozen during training, and optional descriptive metadata.
    """

    name: str = Field(default="", description="Pipeline name")
    nodes: list[dict[str, Any]] = Field(description="Node definitions")
    connections: list[dict[str, Any]] = Field(description="Node connections")
    frozen_nodes: list[str] = Field(
        default_factory=list, description="Node names to keep frozen during training"
    )
    metadata: PipelineMetadata | None = Field(
        default=None, description="Optional pipeline metadata"
    )

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        populate_by_name=True,
    )

    def to_proto(self) -> Any:
        """Serialize into the gRPC ``PipelineConfig`` message (proto extra needed)."""
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as err:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from err

        payload = self.model_dump_json().encode("utf-8")
        return cuvis_ai_pb2.PipelineConfig(config_bytes=payload)

    @classmethod
    def from_proto(cls, proto_config):
        """Rebuild a ``PipelineConfig`` from its gRPC message (proto extra needed)."""
        raw = proto_config.config_bytes.decode("utf-8")
        return cls.model_validate_json(raw)

    def to_json(self) -> str:
        """Dump the model as a JSON string (legacy-caller helper)."""
        return self.model_dump_json()

    @classmethod
    def from_json(cls, payload: str) -> PipelineConfig:
        """Build an instance from a JSON string."""
        return cls.model_validate_json(payload)

    def to_dict(self) -> dict[str, Any]:
        """Dump the model as a plain dictionary."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> PipelineConfig:
        """Build an instance from a plain dictionary."""
        return cls.model_validate(data)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class TrainRunConfig(BaseModel):
    """Complete reproducible training configuration.

    Bundles the pipeline, data, and (optional) training sub-configurations
    together with run-level metadata so that a run can be replayed from a
    single YAML file.
    """

    name: str = Field(description="Train run identifier")
    pipeline: PipelineConfig | None = Field(
        default=None, description="Pipeline configuration (optional if already built)"
    )
    data: DataConfig = Field(description="Data configuration")

    training: TrainingConfig | None = Field(
        default=None,
        description="Training configuration (required if gradient training)",
    )

    loss_nodes: list[str] = Field(
        default_factory=list, description="Loss node names for gradient training"
    )
    metric_nodes: list[str] = Field(
        default_factory=list, description="Metric node names for monitoring"
    )
    freeze_nodes: list[str] = Field(
        default_factory=list, description="Node names to keep frozen during training"
    )
    unfreeze_nodes: list[str] = Field(
        default_factory=list, description="Node names to unfreeze during training"
    )
    output_dir: str = Field(default="./outputs", description="Output directory for artifacts")
    tags: dict[str, str] = Field(default_factory=dict, description="Metadata tags for tracking")

    model_config = ConfigDict(extra="forbid", validate_assignment=True, populate_by_name=True)

    def to_proto(self) -> Any:
        """Convert to protobuf message (requires proto extra).

        Raises:
            ImportError: If the generated gRPC bindings are not installed.
        """
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as e:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from e

        return cuvis_ai_pb2.TrainRunConfig(config_bytes=self.model_dump_json().encode("utf-8"))

    @classmethod
    def from_proto(cls, proto_config: Any) -> TrainRunConfig:
        """Create from protobuf message (requires proto extra)."""
        return cls.model_validate_json(proto_config.config_bytes.decode("utf-8"))

    def to_json(self) -> str:
        """JSON serialization helper for legacy callers."""
        return self.model_dump_json()

    @classmethod
    def from_json(cls, payload: str) -> TrainRunConfig:
        """Create from JSON string."""
        return cls.model_validate_json(payload)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> TrainRunConfig:
        """Create from dictionary."""
        return cls.model_validate(data)

    def save_to_file(self, path: str | Path) -> None:
        """Save configuration to YAML file.

        Dumps with ``mode="json"`` so that values which are not YAML-safe in
        their raw Python form — e.g. ``timedelta``
        (``ModelCheckpointConfig.train_time_interval``) and tuples
        (``OptimizerConfig.betas``) — are converted to JSON-compatible
        primitives first; ``yaml.safe_dump`` raises a RepresenterError on
        the raw objects. Pydantic coerces the primitives back on load.
        """
        output_path = Path(path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("w", encoding="utf-8") as f:
            yaml.safe_dump(self.model_dump(mode="json"), f, sort_keys=False)

    @classmethod
    def load_from_file(cls, path: str | Path) -> TrainRunConfig:
        """Load configuration from YAML file."""
        with Path(path).open("r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        return cls.from_dict(data)

    @model_validator(mode="after")
    def _validate_training_config(self) -> TrainRunConfig:
        """Ensure training config has optimizer if provided.

        NOTE(review): ``TrainingConfig.optimizer`` is non-optional with a
        default factory, so this check appears unreachable today; kept as a
        guard in case that field ever becomes optional.
        """
        if self.training is not None and self.training.optimizer is None:
            raise ValueError("Training configuration must include optimizer when provided")
        return self
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
# Public API of this module.
__all__ = ["PipelineMetadata", "PipelineConfig", "TrainRunConfig"]
|