cuvis-ai-schemas 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. cuvis_ai_schemas/__init__.py +5 -0
  2. cuvis_ai_schemas/discovery/__init__.py +6 -0
  3. cuvis_ai_schemas/enums/__init__.py +5 -0
  4. cuvis_ai_schemas/enums/types.py +30 -0
  5. cuvis_ai_schemas/execution/__init__.py +12 -0
  6. cuvis_ai_schemas/execution/context.py +41 -0
  7. cuvis_ai_schemas/execution/monitoring.py +83 -0
  8. cuvis_ai_schemas/extensions/__init__.py +3 -0
  9. cuvis_ai_schemas/extensions/ui/__init__.py +8 -0
  10. cuvis_ai_schemas/extensions/ui/port_display.py +159 -0
  11. cuvis_ai_schemas/grpc/__init__.py +3 -0
  12. cuvis_ai_schemas/grpc/v1/__init__.py +11 -0
  13. cuvis_ai_schemas/grpc/v1/cuvis_ai_pb2.py +240 -0
  14. cuvis_ai_schemas/grpc/v1/cuvis_ai_pb2.pyi +1046 -0
  15. cuvis_ai_schemas/grpc/v1/cuvis_ai_pb2_grpc.py +1290 -0
  16. cuvis_ai_schemas/pipeline/__init__.py +17 -0
  17. cuvis_ai_schemas/pipeline/config.py +238 -0
  18. cuvis_ai_schemas/pipeline/ports.py +48 -0
  19. cuvis_ai_schemas/plugin/__init__.py +6 -0
  20. cuvis_ai_schemas/plugin/config.py +118 -0
  21. cuvis_ai_schemas/plugin/manifest.py +95 -0
  22. cuvis_ai_schemas/training/__init__.py +40 -0
  23. cuvis_ai_schemas/training/callbacks.py +137 -0
  24. cuvis_ai_schemas/training/config.py +135 -0
  25. cuvis_ai_schemas/training/data.py +73 -0
  26. cuvis_ai_schemas/training/optimizer.py +94 -0
  27. cuvis_ai_schemas/training/run.py +198 -0
  28. cuvis_ai_schemas/training/scheduler.py +69 -0
  29. cuvis_ai_schemas/training/trainer.py +40 -0
  30. cuvis_ai_schemas-0.1.0.dist-info/METADATA +111 -0
  31. cuvis_ai_schemas-0.1.0.dist-info/RECORD +34 -0
  32. cuvis_ai_schemas-0.1.0.dist-info/WHEEL +5 -0
  33. cuvis_ai_schemas-0.1.0.dist-info/licenses/LICENSE +190 -0
  34. cuvis_ai_schemas-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,137 @@
1
+ """Callback configuration schemas."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from datetime import timedelta
6
+ from typing import TYPE_CHECKING, Any, Literal
7
+
8
+ from pydantic import BaseModel, ConfigDict, Field
9
+
10
+ if TYPE_CHECKING:
11
+ try:
12
+ from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
13
+ except ImportError:
14
+ cuvis_ai_pb2 = None # type: ignore[assignment]
15
+
16
+
17
class _BaseConfig(BaseModel):
    """Shared base model enforcing strict validation for callback schemas.

    Subclasses reject unknown fields, re-validate on attribute assignment,
    and may be populated by field name as well as by alias.
    """

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        populate_by_name=True,
    )
21
+
22
+
23
class EarlyStoppingConfig(_BaseConfig):
    """Early stopping callback configuration.

    Field names and defaults mirror the constructor arguments of a
    Lightning-style ``EarlyStopping`` callback.
    NOTE(review): the exact consumer is not visible from this module —
    confirm against the trainer integration that builds the callback.
    """

    # Required: name of the logged metric this callback watches.
    monitor: str = Field(description="Metric to monitor")
    # Epochs without a qualifying improvement before training stops.
    patience: int = Field(default=10, ge=1, description="Number of epochs to wait")
    # Direction of improvement; documented as "min" or "max" but not
    # enforced here (plain str).
    mode: str = Field(default="min", description="min or max")
    min_delta: float = Field(default=0.0, ge=0.0, description="Minimum change to qualify")
    stopping_threshold: float | None = Field(
        default=None, description="Stop once monitored metric reaches this threshold"
    )
    verbose: bool = Field(default=True, description="Whether to log state changes")
    strict: bool = Field(default=True, description="Whether to crash if monitor is not found")
    check_finite: bool = Field(
        default=True, description="Stop when monitor becomes NaN or infinite"
    )
    divergence_threshold: float | None = Field(
        default=None,
        description="Stop training when monitor becomes worse than this threshold",
    )
    check_on_train_epoch_end: bool | None = Field(
        default=None,
        description="Whether to run early stopping at end of training epoch",
    )
    log_rank_zero_only: bool = Field(
        default=False, description="Log status only for rank 0 process"
    )
49
+
50
+
51
class ModelCheckpointConfig(_BaseConfig):
    """Model checkpoint callback configuration.

    Field names and defaults mirror the constructor arguments of a
    Lightning-style ``ModelCheckpoint`` callback.
    NOTE(review): ``train_time_interval`` is a ``timedelta`` — plain
    ``model_dump()`` of a config containing it is not YAML/JSON-safe
    without ``mode="json"``.
    """

    dirpath: str = Field(default="checkpoints", description="Directory to save checkpoints")
    filename: str | None = Field(default=None, description="Checkpoint filename pattern")
    monitor: str = Field(default="val_loss", description="Metric to monitor")
    # Direction of improvement; documented as "min" or "max" but not
    # enforced here (plain str).
    mode: str = Field(default="min", description="min or max")
    save_top_k: int = Field(default=3, ge=-1, description="Save top k checkpoints (-1 for all)")
    every_n_epochs: int = Field(default=1, ge=1, description="Save every n epochs")
    # True/False, or the literal string "link" to symlink the last checkpoint.
    save_last: bool | Literal["link"] | None = Field(
        default=False, description="Also save last checkpoint (or 'link' for symlink)"
    )
    auto_insert_metric_name: bool = Field(
        default=True, description="Automatically insert metric name into filename"
    )
    verbose: bool = Field(default=False, description="Verbosity mode")
    save_on_exception: bool = Field(
        default=False, description="Whether to save checkpoint when exception is raised"
    )
    save_weights_only: bool = Field(
        default=False,
        description="If True, only save model weights, not optimizer states",
    )
    every_n_train_steps: int | None = Field(
        default=None,
        description="How many training steps to wait before saving checkpoint",
    )
    train_time_interval: timedelta | None = Field(
        default=None, description="Checkpoints monitored at specified time interval"
    )
    save_on_train_epoch_end: bool | None = Field(
        default=None,
        description="Whether to run checkpointing at end of training epoch",
    )
    enable_version_counter: bool = Field(
        default=True, description="Whether to append version to existing file name"
    )
88
+
89
+
90
class LearningRateMonitorConfig(_BaseConfig):
    """Learning rate monitor callback configuration.

    Mirrors the constructor arguments of a Lightning-style
    ``LearningRateMonitor`` callback.
    """

    # None defers the interval choice to the consuming callback.
    logging_interval: Literal["step", "epoch"] | None = Field(
        default="epoch", description="Log lr at 'epoch' or 'step'"
    )
    log_momentum: bool = Field(default=False, description="Log momentum values as well")
    log_weight_decay: bool = Field(default=False, description="Log weight decay values as well")
98
+
99
+
100
class CallbacksConfig(_BaseConfig):
    """Aggregate configuration for the supported training callbacks.

    Round-trips through protobuf as a JSON payload carried in the
    message's ``config_bytes`` field.
    """

    checkpoint: ModelCheckpointConfig | None = Field(
        default=None,
        description="Model checkpoint configuration",
        alias="model_checkpoint",
    )
    early_stopping: list[EarlyStoppingConfig] = Field(
        default_factory=list, description="Early stopping configuration(s)"
    )
    learning_rate_monitor: LearningRateMonitorConfig | None = Field(
        default=None, description="Learning rate monitor configuration"
    )

    def to_proto(self) -> Any:
        """Convert to protobuf message (requires proto extra)."""
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as e:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from e

        payload = self.model_dump_json().encode("utf-8")
        return cuvis_ai_pb2.CallbacksConfig(config_bytes=payload)

    @classmethod
    def from_proto(cls, proto_config) -> CallbacksConfig:
        """Create from protobuf message (requires proto extra)."""
        raw = proto_config.config_bytes.decode("utf-8")
        return cls.model_validate_json(raw)
130
+
131
+
132
# Public API of this module.
__all__ = [
    "EarlyStoppingConfig",
    "ModelCheckpointConfig",
    "LearningRateMonitorConfig",
    "CallbacksConfig",
]
@@ -0,0 +1,135 @@
1
+ """Main training configuration schema."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
8
+
9
+ from cuvis_ai_schemas.training.callbacks import CallbacksConfig
10
+ from cuvis_ai_schemas.training.optimizer import OptimizerConfig
11
+ from cuvis_ai_schemas.training.scheduler import SchedulerConfig
12
+ from cuvis_ai_schemas.training.trainer import TrainerConfig
13
+
14
+ if TYPE_CHECKING:
15
+ try:
16
+ from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
17
+ except ImportError:
18
+ cuvis_ai_pb2 = None # type: ignore[assignment]
19
+
20
+
21
class TrainingConfig(BaseModel):
    """Complete training configuration.

    Bundles optimizer, scheduler, callbacks, and trainer settings with a
    few top-level convenience hyperparameters (``max_epochs``,
    ``gradient_clip_val``, ``accumulate_grad_batches``) that are kept in
    sync with the nested ``trainer`` config by ``_sync_trainer_fields``.
    """

    seed: int = Field(default=42, ge=0, description="Random seed for reproducibility")
    optimizer: OptimizerConfig = Field(
        default_factory=OptimizerConfig, description="Optimizer configuration"
    )
    scheduler: SchedulerConfig | None = Field(
        default=None, description="Learning rate scheduler (optional)"
    )
    callbacks: CallbacksConfig | None = Field(
        default=None, description="Training callbacks (optional)"
    )
    trainer: TrainerConfig = Field(
        default_factory=TrainerConfig, description="Lightning Trainer configuration"
    )
    max_epochs: int = Field(default=100, ge=1, le=10000, description="Maximum training epochs")
    batch_size: int = Field(default=32, ge=1, description="Batch size")
    num_workers: int = Field(default=4, ge=0, description="Number of data loading workers")
    gradient_clip_val: float | None = Field(
        default=None, ge=0.0, description="Gradient clipping value (optional)"
    )
    accumulate_grad_batches: int = Field(
        default=1, ge=1, description="Accumulate gradients over n batches"
    )

    model_config = ConfigDict(extra="forbid", validate_assignment=True, populate_by_name=True)

    @model_validator(mode="after")
    def _sync_trainer_fields(self) -> TrainingConfig:
        """Keep top-level hyperparameters in sync with trainer config.

        Resolution rule per field: if the caller did not set the top-level
        value explicitly (checked via ``model_fields_set``) and the nested
        trainer carries a value, the trainer value wins; otherwise the
        top-level value is pushed down into the trainer.

        NOTE(review): the assignments below run under
        ``validate_assignment=True``, so each one re-triggers validation
        (including this validator). The logic converges because on re-entry
        the assigned field appears in ``model_fields_set``, but the exact
        re-entry behavior depends on the pydantic version — confirm before
        restructuring.
        """
        fields_set: set[str] = getattr(self, "model_fields_set", set())

        # max_epochs: prefer explicit trainer value when top-level not provided
        if "max_epochs" not in fields_set and self.trainer.max_epochs is not None:
            self.max_epochs = self.trainer.max_epochs
        else:
            self.trainer.max_epochs = self.max_epochs

        # gradient_clip_val: same rule, but never overwrite the trainer
        # value with a top-level None.
        if "gradient_clip_val" not in fields_set and self.trainer.gradient_clip_val is not None:
            self.gradient_clip_val = self.trainer.gradient_clip_val
        elif self.gradient_clip_val is not None:
            self.trainer.gradient_clip_val = self.gradient_clip_val

        # accumulate_grad_batches
        if (
            "accumulate_grad_batches" not in fields_set
            and self.trainer.accumulate_grad_batches is not None
        ):
            self.accumulate_grad_batches = self.trainer.accumulate_grad_batches
        else:
            self.trainer.accumulate_grad_batches = self.accumulate_grad_batches

        # callbacks: top-level callbacks, when present, always win.
        if self.callbacks is not None:
            self.trainer.callbacks = self.callbacks
        return self

    def to_proto(self) -> Any:
        """Convert to protobuf message (requires proto extra).

        The config is carried as a JSON payload in ``config_bytes``.
        """
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as e:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from e

        return cuvis_ai_pb2.TrainingConfig(config_bytes=self.model_dump_json().encode("utf-8"))

    @classmethod
    def from_proto(cls, proto_config) -> TrainingConfig:
        """Create from protobuf message (requires proto extra)."""
        return cls.model_validate_json(proto_config.config_bytes.decode("utf-8"))

    def to_json(self) -> str:
        """JSON serialization helper for legacy callers."""
        return self.model_dump_json()

    @classmethod
    def from_json(cls, payload: str) -> TrainingConfig:
        """Create from JSON string."""
        return cls.model_validate_json(payload)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return self.model_dump()

    def to_dict_config(self) -> dict[str, Any]:
        """Compatibility shim for legacy OmegaConf usage.

        Returns an OmegaConf ``DictConfig`` when omegaconf is installed,
        otherwise falls back to a plain dict (best-effort by design).
        """
        try:
            from omegaconf import OmegaConf
        except Exception:
            return self.model_dump()

        return OmegaConf.create(self.model_dump())  # type: ignore[return-value]

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> TrainingConfig:
        """Create from dictionary."""
        return cls.model_validate(data)

    @classmethod
    def from_dict_config(cls, config: dict[str, Any]) -> TrainingConfig:
        """Create from DictConfig (OmegaConf) or dictionary."""
        # Duck-typed class-name check avoids a hard omegaconf dependency.
        if config.__class__.__name__ == "DictConfig":  # Avoid hard dependency in type hints
            from omegaconf import OmegaConf

            config = OmegaConf.to_container(config, resolve=True)  # type: ignore[assignment]
        elif not isinstance(config, dict):
            config = dict(config)
        return cls.model_validate(config)
133
+
134
+
135
# Public API of this module.
__all__ = ["TrainingConfig"]
@@ -0,0 +1,73 @@
1
+ """Data configuration schema."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
from pydantic import BaseModel, ConfigDict, Field, model_validator
8
+
9
+ if TYPE_CHECKING:
10
+ try:
11
+ from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
12
+ except ImportError:
13
+ cuvis_ai_pb2 = None # type: ignore[assignment]
14
+
15
+
16
class DataConfig(BaseModel):
    """Data loading configuration.

    Describes the hyperspectral data source (a ``.cu3s`` file plus an
    optional annotation JSON) and how samples are selected and batched.
    Samples may be chosen explicitly (``train_ids``/``val_ids``/``test_ids``)
    or by ratio (``train_split``/``val_split``); how the two mechanisms
    interact is decided by the consuming data-loading code.
    """

    cu3s_file_path: str = Field(description="Path to .cu3s file")
    annotation_json_path: str | None = Field(
        default=None, description="Path to annotation JSON (optional)"
    )
    train_ids: list[int] = Field(default_factory=list, description="Training sample IDs")
    val_ids: list[int] = Field(default_factory=list, description="Validation sample IDs")
    test_ids: list[int] = Field(default_factory=list, description="Test sample IDs")
    train_split: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Training split ratio"
    )
    val_split: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Validation split ratio"
    )
    shuffle: bool = Field(default=True, description="Shuffle dataset")
    batch_size: int = Field(default=1, ge=1, description="Batch size")
    processing_mode: str = Field(default="Reflectance", description="Raw or Reflectance mode")

    model_config = ConfigDict(extra="forbid", validate_assignment=True, populate_by_name=True)

    @model_validator(mode="after")
    def _validate_split_ratios(self) -> DataConfig:
        """Reject ratio splits that cannot be satisfied.

        Each ratio alone is bounded to [0, 1] by its field constraints,
        but nothing previously stopped ``train_split + val_split`` from
        exceeding 1.0, which no downstream splitter can honor.
        """
        if (
            self.train_split is not None
            and self.val_split is not None
            and self.train_split + self.val_split > 1.0
        ):
            raise ValueError("train_split + val_split must not exceed 1.0")
        return self

    def to_proto(self) -> Any:
        """Convert to protobuf message (requires proto extra).

        The config is carried as a JSON payload in ``config_bytes``.
        """
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as e:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from e

        return cuvis_ai_pb2.DataConfig(config_bytes=self.model_dump_json().encode("utf-8"))

    @classmethod
    def from_proto(cls, proto_config) -> DataConfig:
        """Create from protobuf message (requires proto extra)."""
        return cls.model_validate_json(proto_config.config_bytes.decode("utf-8"))

    def to_json(self) -> str:
        """JSON serialization helper for legacy callers."""
        return self.model_dump_json()

    @classmethod
    def from_json(cls, payload: str) -> DataConfig:
        """Create from JSON string."""
        return cls.model_validate_json(payload)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> DataConfig:
        """Create from dictionary."""
        return cls.model_validate(data)
71
+
72
+
73
# Public API of this module.
__all__ = ["DataConfig"]
@@ -0,0 +1,94 @@
1
+ """Optimizer configuration schema."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
8
+
9
+ if TYPE_CHECKING:
10
+ try:
11
+ from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
12
+ except ImportError:
13
+ cuvis_ai_pb2 = None # type: ignore[assignment]
14
+
15
+
16
class OptimizerConfig(BaseModel):
    """Optimizer configuration with constraints and documentation.

    Covers the common knobs for adamw/sgd/adam; optimizer-specific fields
    (``momentum``, ``betas``) are optional and interpreted by the consumer.
    """

    name: str = Field(
        default="adamw",
        description="Optimizer type: adamw, sgd, adam",
    )
    lr: float = Field(
        default=1e-3,
        gt=0.0,
        le=1.0,
        description="Learning rate",
        # NOTE(review): this advertises a JSON-schema minimum of 1e-6 while
        # runtime validation accepts any value > 0 — align if intentional.
        json_schema_extra={"minimum": 1e-6},
    )
    weight_decay: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="L2 regularization coefficient",
    )
    momentum: float | None = Field(
        default=0.9,
        ge=0.0,
        le=1.0,
        description="Momentum factor (for SGD)",
    )
    betas: tuple[float, float] | None = Field(default=None, description="Adam betas (beta1, beta2)")

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        populate_by_name=True,
        json_schema_extra={
            "examples": [
                {
                    "name": "adamw",
                    "lr": 0.001,
                    "weight_decay": 0.01,
                }
            ]
        },
    )

    # The former `_validate_lr` validator was removed: it checked `lr != 0`,
    # but the field's `gt=0.0` constraint already rejects 0 before any field
    # validator runs, so the check was unreachable.

    @field_validator("betas")
    @classmethod
    def _validate_betas(cls, value: tuple[float, float] | None) -> tuple[float, float] | None:
        """Validate that betas is a tuple of exactly 2 floats.

        Defensive: the ``tuple[float, float]`` annotation already enforces
        the length in pydantic v2; kept as an explicit guard with a clearer
        error message.
        """
        if value is None:
            return value
        if len(value) != 2:
            raise ValueError("betas must be a tuple of length 2")
        return value

    def to_proto(self) -> Any:
        """Convert to protobuf message (requires proto extra).

        The config is carried as a JSON payload in ``config_bytes``.
        """
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as e:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from e

        return cuvis_ai_pb2.OptimizerConfig(config_bytes=self.model_dump_json().encode("utf-8"))

    @classmethod
    def from_proto(cls, proto_config) -> OptimizerConfig:
        """Create from protobuf message (requires proto extra)."""
        return cls.model_validate_json(proto_config.config_bytes.decode("utf-8"))
92
+
93
+
94
# Public API of this module.
__all__ = ["OptimizerConfig"]
@@ -0,0 +1,198 @@
1
+ """Training run configuration schema."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ import yaml
9
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
10
+
11
+ from cuvis_ai_schemas.training.config import TrainingConfig
12
+ from cuvis_ai_schemas.training.data import DataConfig
13
+
14
+ if TYPE_CHECKING:
15
+ try:
16
+ from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
17
+ except ImportError:
18
+ cuvis_ai_pb2 = None # type: ignore[assignment]
19
+
20
+
21
class PipelineMetadata(BaseModel):
    """Pipeline metadata for documentation and discovery.

    Unlike the other schemas in this package, this one maps field-by-field
    onto its protobuf message rather than being carried as JSON bytes.
    """

    name: str
    description: str = ""
    created: str = ""
    tags: list[str] = Field(default_factory=list)
    author: str = ""
    cuvis_ai_version: str = Field(default="0.1.0")

    model_config = ConfigDict(extra="forbid", validate_assignment=True, populate_by_name=True)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> PipelineMetadata:
        """Create from dictionary."""
        return cls.model_validate(data)

    def to_proto(self) -> Any:
        """Convert to protobuf message (requires proto extra)."""
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as e:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from e

        proto_fields = {
            "name": self.name,
            "description": self.description,
            "created": self.created,
            "tags": list(self.tags),
            "author": self.author,
            "cuvis_ai_version": self.cuvis_ai_version,
        }
        return cuvis_ai_pb2.PipelineMetadata(**proto_fields)
59
+
60
+
61
class PipelineConfig(BaseModel):
    """Pipeline structure configuration.

    Holds the node/connection graph as loosely-typed dicts plus optional
    metadata, with JSON/dict/proto round-trip helpers.
    """

    name: str = Field(default="", description="Pipeline name")
    nodes: list[dict[str, Any]] = Field(description="Node definitions")
    connections: list[dict[str, Any]] = Field(description="Node connections")
    frozen_nodes: list[str] = Field(
        default_factory=list, description="Node names to keep frozen during training"
    )
    metadata: PipelineMetadata | None = Field(
        default=None, description="Optional pipeline metadata"
    )

    model_config = ConfigDict(extra="forbid", validate_assignment=True, populate_by_name=True)

    def to_proto(self) -> Any:
        """Convert to protobuf message (requires proto extra)."""
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as e:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from e

        payload = self.model_dump_json().encode("utf-8")
        return cuvis_ai_pb2.PipelineConfig(config_bytes=payload)

    @classmethod
    def from_proto(cls, proto_config) -> PipelineConfig:
        """Create from protobuf message (requires proto extra)."""
        raw = proto_config.config_bytes.decode("utf-8")
        return cls.model_validate_json(raw)

    def to_json(self) -> str:
        """JSON serialization helper for legacy callers."""
        return self.model_dump_json()

    @classmethod
    def from_json(cls, payload: str) -> PipelineConfig:
        """Create from JSON string."""
        return cls.model_validate_json(payload)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> PipelineConfig:
        """Create from dictionary."""
        return cls.model_validate(data)
109
+
110
+
111
class TrainRunConfig(BaseModel):
    """Complete reproducible training configuration.

    Bundles the pipeline definition, data source, and optional gradient
    training settings for one named run, with YAML/JSON/dict/proto
    (de)serialization helpers.
    """

    name: str = Field(description="Train run identifier")
    pipeline: PipelineConfig | None = Field(
        default=None, description="Pipeline configuration (optional if already built)"
    )
    data: DataConfig = Field(description="Data configuration")

    training: TrainingConfig | None = Field(
        default=None,
        description="Training configuration (required if gradient training)",
    )

    loss_nodes: list[str] = Field(
        default_factory=list, description="Loss node names for gradient training"
    )
    metric_nodes: list[str] = Field(
        default_factory=list, description="Metric node names for monitoring"
    )
    freeze_nodes: list[str] = Field(
        default_factory=list, description="Node names to keep frozen during training"
    )
    unfreeze_nodes: list[str] = Field(
        default_factory=list, description="Node names to unfreeze during training"
    )
    output_dir: str = Field(default="./outputs", description="Output directory for artifacts")
    tags: dict[str, str] = Field(default_factory=dict, description="Metadata tags for tracking")

    model_config = ConfigDict(extra="forbid", validate_assignment=True, populate_by_name=True)

    def to_proto(self) -> Any:
        """Convert to protobuf message (requires proto extra).

        The config is carried as a JSON payload in ``config_bytes``.
        """
        try:
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as e:
            raise ImportError(
                "Proto support requires the 'proto' extra: pip install cuvis-ai-schemas[proto]"
            ) from e

        return cuvis_ai_pb2.TrainRunConfig(config_bytes=self.model_dump_json().encode("utf-8"))

    @classmethod
    def from_proto(cls, proto_config) -> TrainRunConfig:
        """Create from protobuf message (requires proto extra)."""
        return cls.model_validate_json(proto_config.config_bytes.decode("utf-8"))

    def to_json(self) -> str:
        """JSON serialization helper for legacy callers."""
        return self.model_dump_json()

    @classmethod
    def from_json(cls, payload: str) -> TrainRunConfig:
        """Create from JSON string."""
        return cls.model_validate_json(payload)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> TrainRunConfig:
        """Create from dictionary."""
        return cls.model_validate(data)

    def save_to_file(self, path: str | Path) -> None:
        """Save configuration to YAML file.

        Dumps with ``mode="json"`` so non-primitive field values (e.g. the
        ``timedelta`` in nested checkpoint callback configs) are reduced to
        YAML-safe scalars; a plain ``model_dump()`` would make
        ``yaml.safe_dump`` raise ``RepresenterError`` on such values.
        Pydantic parses the JSON-mode representations back on load, so the
        round-trip through ``load_from_file`` is preserved.
        """
        output_path = Path(path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("w", encoding="utf-8") as f:
            yaml.safe_dump(self.model_dump(mode="json"), f, sort_keys=False)

    @classmethod
    def load_from_file(cls, path: str | Path) -> TrainRunConfig:
        """Load configuration from YAML file."""
        with Path(path).open("r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        return cls.from_dict(data)

    @model_validator(mode="after")
    def _validate_training_config(self) -> TrainRunConfig:
        """Ensure training config has optimizer if provided.

        NOTE(review): ``TrainingConfig.optimizer`` is non-Optional with a
        default factory, so this check cannot currently fail; kept as a
        guard in case that field is ever relaxed.
        """
        if self.training is not None and self.training.optimizer is None:
            raise ValueError("Training configuration must include optimizer when provided")
        return self
196
+
197
+
198
# Public API of this module.
__all__ = ["PipelineMetadata", "PipelineConfig", "TrainRunConfig"]