cuvis-ai-schemas 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuvis_ai_schemas/__init__.py +5 -0
- cuvis_ai_schemas/discovery/__init__.py +6 -0
- cuvis_ai_schemas/enums/__init__.py +5 -0
- cuvis_ai_schemas/enums/types.py +30 -0
- cuvis_ai_schemas/execution/__init__.py +12 -0
- cuvis_ai_schemas/execution/context.py +41 -0
- cuvis_ai_schemas/execution/monitoring.py +83 -0
- cuvis_ai_schemas/extensions/__init__.py +3 -0
- cuvis_ai_schemas/extensions/ui/__init__.py +8 -0
- cuvis_ai_schemas/extensions/ui/port_display.py +159 -0
- cuvis_ai_schemas/grpc/__init__.py +3 -0
- cuvis_ai_schemas/grpc/v1/__init__.py +11 -0
- cuvis_ai_schemas/grpc/v1/cuvis_ai_pb2.py +240 -0
- cuvis_ai_schemas/grpc/v1/cuvis_ai_pb2.pyi +1046 -0
- cuvis_ai_schemas/grpc/v1/cuvis_ai_pb2_grpc.py +1290 -0
- cuvis_ai_schemas/pipeline/__init__.py +17 -0
- cuvis_ai_schemas/pipeline/config.py +238 -0
- cuvis_ai_schemas/pipeline/ports.py +48 -0
- cuvis_ai_schemas/plugin/__init__.py +6 -0
- cuvis_ai_schemas/plugin/config.py +118 -0
- cuvis_ai_schemas/plugin/manifest.py +95 -0
- cuvis_ai_schemas/training/__init__.py +40 -0
- cuvis_ai_schemas/training/callbacks.py +137 -0
- cuvis_ai_schemas/training/config.py +135 -0
- cuvis_ai_schemas/training/data.py +73 -0
- cuvis_ai_schemas/training/optimizer.py +94 -0
- cuvis_ai_schemas/training/run.py +198 -0
- cuvis_ai_schemas/training/scheduler.py +69 -0
- cuvis_ai_schemas/training/trainer.py +40 -0
- cuvis_ai_schemas-0.1.0.dist-info/METADATA +111 -0
- cuvis_ai_schemas-0.1.0.dist-info/RECORD +34 -0
- cuvis_ai_schemas-0.1.0.dist-info/WHEEL +5 -0
- cuvis_ai_schemas-0.1.0.dist-info/licenses/LICENSE +190 -0
- cuvis_ai_schemas-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Pipeline structure schemas."""
|
|
2
|
+
|
|
3
|
+
from cuvis_ai_schemas.pipeline.config import (
|
|
4
|
+
ConnectionConfig,
|
|
5
|
+
NodeConfig,
|
|
6
|
+
PipelineConfig,
|
|
7
|
+
PipelineMetadata,
|
|
8
|
+
)
|
|
9
|
+
from cuvis_ai_schemas.pipeline.ports import PortSpec
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"PipelineConfig",
|
|
13
|
+
"PipelineMetadata",
|
|
14
|
+
"NodeConfig",
|
|
15
|
+
"ConnectionConfig",
|
|
16
|
+
"PortSpec",
|
|
17
|
+
]
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
"""Pipeline configuration schemas."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import yaml
|
|
8
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class _BaseConfig(BaseModel):
    """Shared base for all pipeline schema models.

    Enforces strict validation: unknown fields are rejected, attribute
    assignments are re-validated, and fields may be populated by alias.
    """

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        populate_by_name=True,
    )
|
22
|
+
class PipelineMetadata(_BaseConfig):
    """Pipeline metadata for documentation and discovery.

    Attributes
    ----------
    name : str
        Pipeline name.
    description : str
        Human-readable description.
    created : str
        Creation timestamp (ISO format).
    tags : list[str]
        Tags for categorization and search.
    author : str
        Author name or email.
    cuvis_ai_version : str
        Version of cuvis-ai-schemas used.
    """

    name: str
    description: str = ""
    created: str = ""
    tags: list[str] = Field(default_factory=list)
    author: str = ""
    cuvis_ai_version: str = "0.1.0"

    def to_dict(self) -> dict[str, Any]:
        """Return the metadata as a plain dictionary."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> PipelineMetadata:
        """Build metadata from a plain dictionary."""
        return cls.model_validate(data)

    def to_proto(self) -> cuvis_ai_pb2.PipelineMetadata:
        """Serialize the metadata into its proto message form.

        Requires cuvis-ai-schemas[proto] to be installed.

        Returns
        -------
        cuvis_ai_pb2.PipelineMetadata
            Proto message representation
        """
        try:
            # Imported lazily so the proto extra stays optional.
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as err:
            msg = "Proto support not installed. Install with: pip install cuvis-ai-schemas[proto]"
            raise ImportError(msg) from err

        payload = {
            "name": self.name,
            "description": self.description,
            "created": self.created,
            "tags": list(self.tags),
            "author": self.author,
            "cuvis_ai_version": self.cuvis_ai_version,
        }
        return cuvis_ai_pb2.PipelineMetadata(**payload)
class NodeConfig(_BaseConfig):
    """Configuration of a single node inside a pipeline.

    Attributes
    ----------
    id : str
        Unique node identifier
    class_name : str
        Fully-qualified class name (e.g., 'my_package.MyNode').
        Alias: 'class' for backward compatibility
    params : dict[str, Any]
        Node parameters/hyperparameters.
        Alias: 'hparams' for backward compatibility
    """

    id: str = Field(description="Unique node identifier")
    class_name: str = Field(alias="class", description="Fully-qualified class name")
    params: dict[str, Any] = Field(
        alias="hparams", default_factory=dict, description="Node parameters"
    )
105
|
+
class ConnectionConfig(_BaseConfig):
    """A directed edge between two nodes' ports.

    Attributes
    ----------
    from_node : str
        Source node ID
    from_port : str
        Source port name
    to_node : str
        Target node ID
    to_port : str
        Target port name
    """

    # Source endpoint of the connection.
    from_node: str = Field(description="Source node ID")
    from_port: str = Field(description="Source port name")
    # Target endpoint of the connection.
    to_node: str = Field(description="Target node ID")
    to_port: str = Field(description="Target port name")
|
126
|
+
class PipelineConfig(_BaseConfig):
    """Pipeline structure configuration.

    Attributes
    ----------
    name : str
        Pipeline name
    nodes : list[dict[str, Any]]
        Node definitions as plain dicts (see NodeConfig for the expected keys)
    connections : list[dict[str, Any]]
        Node connections as plain dicts (see ConnectionConfig for the expected keys)
    frozen_nodes : list[str]
        Node IDs to keep frozen during training
    metadata : PipelineMetadata | None
        Optional pipeline metadata
    """

    name: str = Field(default="", description="Pipeline name")
    nodes: list[dict[str, Any]] = Field(description="Node definitions")
    connections: list[dict[str, Any]] = Field(description="Node connections")
    frozen_nodes: list[str] = Field(
        default_factory=list, description="Node names to keep frozen during training"
    )
    metadata: PipelineMetadata | None = Field(
        default=None, description="Optional pipeline metadata"
    )

    def to_proto(self) -> cuvis_ai_pb2.PipelineConfig:
        """Serialize into a proto message (JSON payload in ``config_bytes``).

        Requires cuvis-ai-schemas[proto] to be installed.

        Returns
        -------
        cuvis_ai_pb2.PipelineConfig
            Proto message representation
        """
        try:
            # Imported lazily so the proto extra stays optional.
            from cuvis_ai_schemas.grpc.v1 import cuvis_ai_pb2
        except ImportError as err:
            msg = "Proto support not installed. Install with: pip install cuvis-ai-schemas[proto]"
            raise ImportError(msg) from err

        encoded = self.model_dump_json().encode("utf-8")
        return cuvis_ai_pb2.PipelineConfig(config_bytes=encoded)

    @classmethod
    def from_proto(cls, proto_config: cuvis_ai_pb2.PipelineConfig) -> PipelineConfig:
        """Deserialize a proto message produced by :meth:`to_proto`.

        Parameters
        ----------
        proto_config : cuvis_ai_pb2.PipelineConfig
            Proto message to deserialize

        Returns
        -------
        PipelineConfig
            Loaded configuration
        """
        return cls.model_validate_json(proto_config.config_bytes.decode("utf-8"))

    def to_json(self) -> str:
        """Serialize to a JSON string."""
        return self.model_dump_json()

    @classmethod
    def from_json(cls, payload: str) -> PipelineConfig:
        """Deserialize from a JSON string."""
        return cls.model_validate_json(payload)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> PipelineConfig:
        """Deserialize from a plain dictionary."""
        return cls.model_validate(data)

    def save_to_file(self, path: str | Path) -> None:
        """Write the configuration to a YAML file, creating parent directories.

        Parameters
        ----------
        path : str | Path
            Output file path
        """
        from pathlib import Path  # runtime import; module-level one is TYPE_CHECKING-only

        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("w", encoding="utf-8") as handle:
            yaml.safe_dump(self.model_dump(), handle, sort_keys=False)

    @classmethod
    def load_from_file(cls, path: str | Path) -> PipelineConfig:
        """Read a YAML file (as written by :meth:`save_to_file`) and validate it.

        Parameters
        ----------
        path : str | Path
            Input file path

        Returns
        -------
        PipelineConfig
            Loaded configuration
        """
        from pathlib import Path  # runtime import; module-level one is TYPE_CHECKING-only

        with Path(path).open("r", encoding="utf-8") as handle:
            raw = yaml.safe_load(handle)
        return cls.from_dict(raw)
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Port specification for node inputs and outputs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class PortSpec:
    """Schema describing one input or output port of a node.

    This is a lightweight declaration only; the full compatibility-checking
    logic lives in cuvis-ai-core.

    Attributes
    ----------
    dtype : Any
        Data type for the port (e.g., torch.Tensor, torch.float32, int, str)
    shape : tuple[int | str, ...]
        Expected shape, where each entry is one of:
        - a positive integer for a fixed dimension
        - ``-1`` for a flexible dimension
        - a string for a symbolic dimension (resolved from node attributes)
    description : str
        Human-readable description of the port
    optional : bool
        Whether the port is optional (for inputs)

    Examples
    --------
    >>> # Fixed shape tensor port
    >>> port = PortSpec(dtype=torch.Tensor, shape=(1, 3, 224, 224))

    >>> # Flexible batch dimension
    >>> port = PortSpec(dtype=torch.Tensor, shape=(-1, 3, 224, 224))

    >>> # Symbolic dimension from node attribute
    >>> port = PortSpec(dtype=torch.Tensor, shape=(-1, "num_channels", 224, 224))

    >>> # Scalar port
    >>> port = PortSpec(dtype=float, shape=())
    """

    dtype: Any
    shape: tuple[int | str, ...]
    description: str = ""
    optional: bool = False
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""Plugin configuration schemas."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class _BasePluginConfig(BaseModel):
    """Common base for every plugin configuration type.

    All plugin types inherit from this class so validation behavior and
    error handling stay consistent across plugin kinds.
    """

    model_config = ConfigDict(
        extra="forbid",  # Reject unknown fields (catch typos)
        validate_assignment=True,  # Validate on attribute assignment
        populate_by_name=True,  # Allow field aliases
    )

    provides: list[str] = Field(
        min_length=1,  # At least one class required
        description="List of fully-qualified class paths this plugin provides",
    )

    @field_validator("provides")
    @classmethod
    def _validate_class_paths(cls, value: list[str]) -> list[str]:
        """Reject class paths that are empty or not dotted."""
        for class_path in value:
            if class_path and "." in class_path:
                continue
            msg = (
                f"Invalid class path '{class_path}'. "
                "Must be fully-qualified (e.g., 'package.module.ClassName')"
            )
            raise ValueError(msg)
        return value
|
+
|
|
40
|
+
class GitPluginConfig(_BasePluginConfig):
    """Git repository plugin configuration.

    Supports:
    - SSH URLs: git@gitlab.com:user/repo.git
    - HTTPS URLs: https://github.com/user/repo.git
    - Git tags only: v1.2.3, v0.1.0-alpha, etc.

    Note: Branches and commit hashes are NOT supported for reproducibility.
    """

    repo: str = Field(
        description="Git repository URL (SSH or HTTPS)",
        min_length=1,
    )

    tag: str = Field(
        description="Git tag (e.g., v1.2.3, v0.1.0-alpha). "
        "Branches and commit hashes are not supported.",
        min_length=1,
    )

    @field_validator("repo")
    @classmethod
    def _validate_repo_url(cls, value: str) -> str:
        """Validate Git repository URL format."""
        # str.startswith accepts a tuple of prefixes — one call replaces the
        # former chain of three `or`-ed checks; behavior is identical.
        if not value.startswith(("git@", "https://", "http://")):
            msg = f"Invalid repo URL '{value}'. Must start with 'git@', 'https://', or 'http://'"
            raise ValueError(msg)
        return value

    @field_validator("tag")
    @classmethod
    def _validate_tag(cls, value: str) -> str:
        """Validate Git tag is not empty or whitespace-only; return it stripped."""
        if not value.strip():
            msg = "Git tag cannot be empty"
            raise ValueError(msg)
        return value.strip()
|
+
|
|
83
|
+
class LocalPluginConfig(_BasePluginConfig):
    """Local filesystem plugin configuration.

    Supports:
    - Absolute paths: /home/user/my-plugin
    - Relative paths: ../my-plugin (resolved relative to manifest file)
    - Windows paths: C:\\Users\\user\\my-plugin
    """

    path: str = Field(
        min_length=1,
        description="Absolute or relative path to plugin directory",
    )

    @field_validator("path")
    @classmethod
    def _validate_path(cls, value: str) -> str:
        """Reject blank paths; return the path stripped of surrounding whitespace."""
        stripped = value.strip()
        if not stripped:
            msg = "Path cannot be empty"
            raise ValueError(msg)
        return stripped

    def resolve_path(self, manifest_dir: Path) -> Path:
        """Resolve this plugin's path to an absolute path.

        Args:
            manifest_dir: Directory containing the manifest file

        Returns:
            Absolute path to plugin directory
        """
        candidate = Path(self.path)
        if candidate.is_absolute():
            # Absolute paths are used verbatim.
            return candidate
        # Relative paths are interpreted against the manifest's directory.
        return (manifest_dir / candidate).resolve()
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Plugin manifest schema."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Annotated
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
8
|
+
|
|
9
|
+
from cuvis_ai_schemas.plugin.config import GitPluginConfig, LocalPluginConfig
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PluginManifest(BaseModel):
    """Complete plugin manifest containing all plugin configurations.

    This is the root configuration object validated when loading
    a plugins.yaml file or dictionary.
    """

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
    )

    plugins: dict[
        str,
        Annotated[
            GitPluginConfig | LocalPluginConfig,
            # No explicit discriminator: pydantic picks the variant from the
            # fields present in each entry.
            Field(discriminator=None),
        ],
    ] = Field(
        default_factory=dict,
        description="Map of plugin names to their configurations",
    )

    @field_validator("plugins")
    @classmethod
    def _validate_plugin_names(cls, value: dict) -> dict:
        """Ensure plugin names are valid Python identifiers."""
        for name in value:  # iterating a dict yields its keys
            if not name.isidentifier():
                msg = f"Invalid plugin name '{name}'. Must be a valid Python identifier"
                raise ValueError(msg)
        return value

    @classmethod
    def from_yaml(cls, yaml_path: Path) -> "PluginManifest":
        """Load and validate a manifest from a YAML file.

        Args:
            yaml_path: Path to YAML file

        Returns:
            Validated PluginManifest instance

        Raises:
            FileNotFoundError: If yaml_path doesn't exist
        """
        if not yaml_path.exists():
            msg = f"Plugin manifest not found: {yaml_path}"
            raise FileNotFoundError(msg)

        with yaml_path.open("r", encoding="utf-8") as handle:
            payload = yaml.safe_load(handle)

        # An empty (or all-comment) file parses to None: treat as "no plugins".
        if not payload:
            return cls(plugins={})

        return cls.model_validate(payload)

    @classmethod
    def from_dict(cls, data: dict) -> "PluginManifest":
        """Load and validate a manifest from a dictionary.

        Args:
            data: Dictionary containing plugin configurations

        Returns:
            Validated PluginManifest instance
        """
        return cls.model_validate(data)

    def to_yaml(self, yaml_path: Path) -> None:
        """Save the manifest to a YAML file, creating parent directories.

        Args:
            yaml_path: Path where YAML file should be saved
        """
        yaml_path.parent.mkdir(parents=True, exist_ok=True)
        with yaml_path.open("w", encoding="utf-8") as handle:
            yaml.safe_dump(
                self.model_dump(exclude_none=True),
                handle,
                sort_keys=False,
                default_flow_style=False,
            )
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Training configuration schemas for cuvis-ai."""
|
|
2
|
+
|
|
3
|
+
from cuvis_ai_schemas.training.callbacks import (
|
|
4
|
+
CallbacksConfig,
|
|
5
|
+
EarlyStoppingConfig,
|
|
6
|
+
LearningRateMonitorConfig,
|
|
7
|
+
ModelCheckpointConfig,
|
|
8
|
+
)
|
|
9
|
+
from cuvis_ai_schemas.training.config import TrainingConfig
|
|
10
|
+
from cuvis_ai_schemas.training.data import DataConfig
|
|
11
|
+
from cuvis_ai_schemas.training.optimizer import OptimizerConfig
|
|
12
|
+
from cuvis_ai_schemas.training.run import (
|
|
13
|
+
PipelineConfig,
|
|
14
|
+
PipelineMetadata,
|
|
15
|
+
TrainRunConfig,
|
|
16
|
+
)
|
|
17
|
+
from cuvis_ai_schemas.training.scheduler import SchedulerConfig
|
|
18
|
+
from cuvis_ai_schemas.training.trainer import TrainerConfig
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
# Callbacks
|
|
22
|
+
"CallbacksConfig",
|
|
23
|
+
"EarlyStoppingConfig",
|
|
24
|
+
"LearningRateMonitorConfig",
|
|
25
|
+
"ModelCheckpointConfig",
|
|
26
|
+
# Config
|
|
27
|
+
"TrainingConfig",
|
|
28
|
+
# Data
|
|
29
|
+
"DataConfig",
|
|
30
|
+
# Optimizer
|
|
31
|
+
"OptimizerConfig",
|
|
32
|
+
# Scheduler
|
|
33
|
+
"SchedulerConfig",
|
|
34
|
+
# Trainer
|
|
35
|
+
"TrainerConfig",
|
|
36
|
+
# Run
|
|
37
|
+
"PipelineConfig",
|
|
38
|
+
"PipelineMetadata",
|
|
39
|
+
"TrainRunConfig",
|
|
40
|
+
]
|