synapse-sdk 1.0.0a11__py3-none-any.whl → 2026.1.1b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synapse-sdk might be problematic. Click here for more details.
- synapse_sdk/__init__.py +24 -0
- synapse_sdk/cli/__init__.py +9 -8
- synapse_sdk/cli/agent/__init__.py +25 -0
- synapse_sdk/cli/agent/config.py +104 -0
- synapse_sdk/cli/agent/select.py +197 -0
- synapse_sdk/cli/auth.py +104 -0
- synapse_sdk/cli/main.py +1025 -0
- synapse_sdk/cli/plugin/__init__.py +58 -0
- synapse_sdk/cli/plugin/create.py +566 -0
- synapse_sdk/cli/plugin/job.py +196 -0
- synapse_sdk/cli/plugin/publish.py +322 -0
- synapse_sdk/cli/plugin/run.py +131 -0
- synapse_sdk/cli/plugin/test.py +200 -0
- synapse_sdk/clients/README.md +239 -0
- synapse_sdk/clients/__init__.py +5 -0
- synapse_sdk/clients/_template.py +266 -0
- synapse_sdk/clients/agent/__init__.py +84 -29
- synapse_sdk/clients/agent/async_ray.py +289 -0
- synapse_sdk/clients/agent/container.py +83 -0
- synapse_sdk/clients/agent/plugin.py +101 -0
- synapse_sdk/clients/agent/ray.py +296 -39
- synapse_sdk/clients/backend/__init__.py +152 -12
- synapse_sdk/clients/backend/annotation.py +164 -22
- synapse_sdk/clients/backend/core.py +101 -0
- synapse_sdk/clients/backend/data_collection.py +292 -0
- synapse_sdk/clients/backend/hitl.py +87 -0
- synapse_sdk/clients/backend/integration.py +374 -46
- synapse_sdk/clients/backend/ml.py +134 -22
- synapse_sdk/clients/backend/models.py +247 -0
- synapse_sdk/clients/base.py +538 -59
- synapse_sdk/clients/exceptions.py +35 -7
- synapse_sdk/clients/pipeline/__init__.py +5 -0
- synapse_sdk/clients/pipeline/client.py +636 -0
- synapse_sdk/clients/protocols.py +178 -0
- synapse_sdk/clients/utils.py +86 -8
- synapse_sdk/clients/validation.py +58 -0
- synapse_sdk/enums.py +76 -0
- synapse_sdk/exceptions.py +168 -0
- synapse_sdk/integrations/__init__.py +74 -0
- synapse_sdk/integrations/_base.py +119 -0
- synapse_sdk/integrations/_context.py +53 -0
- synapse_sdk/integrations/ultralytics/__init__.py +78 -0
- synapse_sdk/integrations/ultralytics/_callbacks.py +126 -0
- synapse_sdk/integrations/ultralytics/_patches.py +124 -0
- synapse_sdk/loggers.py +476 -95
- synapse_sdk/mcp/MCP.md +69 -0
- synapse_sdk/mcp/__init__.py +48 -0
- synapse_sdk/mcp/__main__.py +6 -0
- synapse_sdk/mcp/config.py +349 -0
- synapse_sdk/mcp/prompts/__init__.py +4 -0
- synapse_sdk/mcp/resources/__init__.py +4 -0
- synapse_sdk/mcp/server.py +1352 -0
- synapse_sdk/mcp/tools/__init__.py +6 -0
- synapse_sdk/plugins/__init__.py +133 -9
- synapse_sdk/plugins/action.py +229 -0
- synapse_sdk/plugins/actions/__init__.py +82 -0
- synapse_sdk/plugins/actions/dataset/__init__.py +37 -0
- synapse_sdk/plugins/actions/dataset/action.py +471 -0
- synapse_sdk/plugins/actions/export/__init__.py +55 -0
- synapse_sdk/plugins/actions/export/action.py +183 -0
- synapse_sdk/plugins/actions/export/context.py +59 -0
- synapse_sdk/plugins/actions/inference/__init__.py +84 -0
- synapse_sdk/plugins/actions/inference/action.py +285 -0
- synapse_sdk/plugins/actions/inference/context.py +81 -0
- synapse_sdk/plugins/actions/inference/deployment.py +322 -0
- synapse_sdk/plugins/actions/inference/serve.py +252 -0
- synapse_sdk/plugins/actions/train/__init__.py +54 -0
- synapse_sdk/plugins/actions/train/action.py +326 -0
- synapse_sdk/plugins/actions/train/context.py +57 -0
- synapse_sdk/plugins/actions/upload/__init__.py +49 -0
- synapse_sdk/plugins/actions/upload/action.py +165 -0
- synapse_sdk/plugins/actions/upload/context.py +61 -0
- synapse_sdk/plugins/config.py +98 -0
- synapse_sdk/plugins/context/__init__.py +109 -0
- synapse_sdk/plugins/context/env.py +113 -0
- synapse_sdk/plugins/datasets/__init__.py +113 -0
- synapse_sdk/plugins/datasets/converters/__init__.py +76 -0
- synapse_sdk/plugins/datasets/converters/base.py +347 -0
- synapse_sdk/plugins/datasets/converters/yolo/__init__.py +9 -0
- synapse_sdk/plugins/datasets/converters/yolo/from_dm.py +468 -0
- synapse_sdk/plugins/datasets/converters/yolo/to_dm.py +381 -0
- synapse_sdk/plugins/datasets/formats/__init__.py +82 -0
- synapse_sdk/plugins/datasets/formats/dm.py +351 -0
- synapse_sdk/plugins/datasets/formats/yolo.py +240 -0
- synapse_sdk/plugins/decorators.py +83 -0
- synapse_sdk/plugins/discovery.py +790 -0
- synapse_sdk/plugins/docs/ACTION_DEV_GUIDE.md +933 -0
- synapse_sdk/plugins/docs/ARCHITECTURE.md +1225 -0
- synapse_sdk/plugins/docs/LOGGING_SYSTEM.md +683 -0
- synapse_sdk/plugins/docs/OVERVIEW.md +531 -0
- synapse_sdk/plugins/docs/PIPELINE_GUIDE.md +145 -0
- synapse_sdk/plugins/docs/README.md +513 -0
- synapse_sdk/plugins/docs/STEP.md +656 -0
- synapse_sdk/plugins/enums.py +70 -10
- synapse_sdk/plugins/errors.py +92 -0
- synapse_sdk/plugins/executors/__init__.py +43 -0
- synapse_sdk/plugins/executors/local.py +99 -0
- synapse_sdk/plugins/executors/ray/__init__.py +18 -0
- synapse_sdk/plugins/executors/ray/base.py +282 -0
- synapse_sdk/plugins/executors/ray/job.py +298 -0
- synapse_sdk/plugins/executors/ray/jobs_api.py +511 -0
- synapse_sdk/plugins/executors/ray/packaging.py +137 -0
- synapse_sdk/plugins/executors/ray/pipeline.py +792 -0
- synapse_sdk/plugins/executors/ray/task.py +257 -0
- synapse_sdk/plugins/models/__init__.py +26 -0
- synapse_sdk/plugins/models/logger.py +173 -0
- synapse_sdk/plugins/models/pipeline.py +25 -0
- synapse_sdk/plugins/pipelines/__init__.py +81 -0
- synapse_sdk/plugins/pipelines/action_pipeline.py +417 -0
- synapse_sdk/plugins/pipelines/context.py +107 -0
- synapse_sdk/plugins/pipelines/display.py +311 -0
- synapse_sdk/plugins/runner.py +114 -0
- synapse_sdk/plugins/schemas/__init__.py +19 -0
- synapse_sdk/plugins/schemas/results.py +152 -0
- synapse_sdk/plugins/steps/__init__.py +63 -0
- synapse_sdk/plugins/steps/base.py +128 -0
- synapse_sdk/plugins/steps/context.py +90 -0
- synapse_sdk/plugins/steps/orchestrator.py +128 -0
- synapse_sdk/plugins/steps/registry.py +103 -0
- synapse_sdk/plugins/steps/utils/__init__.py +20 -0
- synapse_sdk/plugins/steps/utils/logging.py +85 -0
- synapse_sdk/plugins/steps/utils/timing.py +71 -0
- synapse_sdk/plugins/steps/utils/validation.py +68 -0
- synapse_sdk/plugins/templates/__init__.py +50 -0
- synapse_sdk/plugins/templates/base/.gitignore.j2 +26 -0
- synapse_sdk/plugins/templates/base/.synapseignore.j2 +11 -0
- synapse_sdk/plugins/templates/base/README.md.j2 +26 -0
- synapse_sdk/plugins/templates/base/plugin/__init__.py.j2 +1 -0
- synapse_sdk/plugins/templates/base/pyproject.toml.j2 +14 -0
- synapse_sdk/plugins/templates/base/requirements.txt.j2 +1 -0
- synapse_sdk/plugins/templates/custom/plugin/main.py.j2 +18 -0
- synapse_sdk/plugins/templates/data_validation/plugin/validate.py.j2 +32 -0
- synapse_sdk/plugins/templates/export/plugin/export.py.j2 +36 -0
- synapse_sdk/plugins/templates/neural_net/plugin/inference.py.j2 +36 -0
- synapse_sdk/plugins/templates/neural_net/plugin/train.py.j2 +33 -0
- synapse_sdk/plugins/templates/post_annotation/plugin/post_annotate.py.j2 +32 -0
- synapse_sdk/plugins/templates/pre_annotation/plugin/pre_annotate.py.j2 +32 -0
- synapse_sdk/plugins/templates/smart_tool/plugin/auto_label.py.j2 +44 -0
- synapse_sdk/plugins/templates/upload/plugin/upload.py.j2 +35 -0
- synapse_sdk/plugins/testing/__init__.py +25 -0
- synapse_sdk/plugins/testing/sample_actions.py +98 -0
- synapse_sdk/plugins/types.py +206 -0
- synapse_sdk/plugins/upload.py +595 -64
- synapse_sdk/plugins/utils.py +325 -37
- synapse_sdk/shared/__init__.py +25 -0
- synapse_sdk/utils/__init__.py +1 -0
- synapse_sdk/utils/auth.py +74 -0
- synapse_sdk/utils/file/__init__.py +58 -0
- synapse_sdk/utils/file/archive.py +449 -0
- synapse_sdk/utils/file/checksum.py +167 -0
- synapse_sdk/utils/file/download.py +286 -0
- synapse_sdk/utils/file/io.py +129 -0
- synapse_sdk/utils/file/requirements.py +36 -0
- synapse_sdk/utils/network.py +168 -0
- synapse_sdk/utils/storage/__init__.py +238 -0
- synapse_sdk/utils/storage/config.py +188 -0
- synapse_sdk/utils/storage/errors.py +52 -0
- synapse_sdk/utils/storage/providers/__init__.py +13 -0
- synapse_sdk/utils/storage/providers/base.py +76 -0
- synapse_sdk/utils/storage/providers/gcs.py +168 -0
- synapse_sdk/utils/storage/providers/http.py +250 -0
- synapse_sdk/utils/storage/providers/local.py +126 -0
- synapse_sdk/utils/storage/providers/s3.py +177 -0
- synapse_sdk/utils/storage/providers/sftp.py +208 -0
- synapse_sdk/utils/storage/registry.py +125 -0
- synapse_sdk/utils/websocket.py +99 -0
- synapse_sdk-2026.1.1b2.dist-info/METADATA +715 -0
- synapse_sdk-2026.1.1b2.dist-info/RECORD +172 -0
- {synapse_sdk-1.0.0a11.dist-info → synapse_sdk-2026.1.1b2.dist-info}/WHEEL +1 -1
- synapse_sdk-2026.1.1b2.dist-info/licenses/LICENSE +201 -0
- locale/en/LC_MESSAGES/messages.mo +0 -0
- locale/en/LC_MESSAGES/messages.po +0 -39
- locale/ko/LC_MESSAGES/messages.mo +0 -0
- locale/ko/LC_MESSAGES/messages.po +0 -34
- synapse_sdk/cli/create_plugin.py +0 -10
- synapse_sdk/clients/agent/core.py +0 -7
- synapse_sdk/clients/agent/service.py +0 -15
- synapse_sdk/clients/backend/dataset.py +0 -51
- synapse_sdk/clients/ray/__init__.py +0 -6
- synapse_sdk/clients/ray/core.py +0 -22
- synapse_sdk/clients/ray/serve.py +0 -20
- synapse_sdk/i18n.py +0 -35
- synapse_sdk/plugins/categories/__init__.py +0 -0
- synapse_sdk/plugins/categories/base.py +0 -235
- synapse_sdk/plugins/categories/data_validation/__init__.py +0 -0
- synapse_sdk/plugins/categories/data_validation/actions/__init__.py +0 -0
- synapse_sdk/plugins/categories/data_validation/actions/validation.py +0 -10
- synapse_sdk/plugins/categories/data_validation/templates/config.yaml +0 -3
- synapse_sdk/plugins/categories/data_validation/templates/plugin/__init__.py +0 -0
- synapse_sdk/plugins/categories/data_validation/templates/plugin/validation.py +0 -5
- synapse_sdk/plugins/categories/decorators.py +0 -13
- synapse_sdk/plugins/categories/export/__init__.py +0 -0
- synapse_sdk/plugins/categories/export/actions/__init__.py +0 -0
- synapse_sdk/plugins/categories/export/actions/export.py +0 -10
- synapse_sdk/plugins/categories/import/__init__.py +0 -0
- synapse_sdk/plugins/categories/import/actions/__init__.py +0 -0
- synapse_sdk/plugins/categories/import/actions/import.py +0 -10
- synapse_sdk/plugins/categories/neural_net/__init__.py +0 -0
- synapse_sdk/plugins/categories/neural_net/actions/__init__.py +0 -0
- synapse_sdk/plugins/categories/neural_net/actions/deployment.py +0 -45
- synapse_sdk/plugins/categories/neural_net/actions/inference.py +0 -18
- synapse_sdk/plugins/categories/neural_net/actions/test.py +0 -10
- synapse_sdk/plugins/categories/neural_net/actions/train.py +0 -143
- synapse_sdk/plugins/categories/neural_net/templates/config.yaml +0 -12
- synapse_sdk/plugins/categories/neural_net/templates/plugin/__init__.py +0 -0
- synapse_sdk/plugins/categories/neural_net/templates/plugin/inference.py +0 -4
- synapse_sdk/plugins/categories/neural_net/templates/plugin/test.py +0 -2
- synapse_sdk/plugins/categories/neural_net/templates/plugin/train.py +0 -14
- synapse_sdk/plugins/categories/post_annotation/__init__.py +0 -0
- synapse_sdk/plugins/categories/post_annotation/actions/__init__.py +0 -0
- synapse_sdk/plugins/categories/post_annotation/actions/post_annotation.py +0 -10
- synapse_sdk/plugins/categories/post_annotation/templates/config.yaml +0 -3
- synapse_sdk/plugins/categories/post_annotation/templates/plugin/__init__.py +0 -0
- synapse_sdk/plugins/categories/post_annotation/templates/plugin/post_annotation.py +0 -3
- synapse_sdk/plugins/categories/pre_annotation/__init__.py +0 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/__init__.py +0 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/pre_annotation.py +0 -10
- synapse_sdk/plugins/categories/pre_annotation/templates/config.yaml +0 -3
- synapse_sdk/plugins/categories/pre_annotation/templates/plugin/__init__.py +0 -0
- synapse_sdk/plugins/categories/pre_annotation/templates/plugin/pre_annotation.py +0 -3
- synapse_sdk/plugins/categories/registry.py +0 -16
- synapse_sdk/plugins/categories/smart_tool/__init__.py +0 -0
- synapse_sdk/plugins/categories/smart_tool/actions/__init__.py +0 -0
- synapse_sdk/plugins/categories/smart_tool/actions/auto_label.py +0 -37
- synapse_sdk/plugins/categories/smart_tool/templates/config.yaml +0 -7
- synapse_sdk/plugins/categories/smart_tool/templates/plugin/__init__.py +0 -0
- synapse_sdk/plugins/categories/smart_tool/templates/plugin/auto_label.py +0 -11
- synapse_sdk/plugins/categories/templates.py +0 -32
- synapse_sdk/plugins/cli/__init__.py +0 -21
- synapse_sdk/plugins/cli/publish.py +0 -37
- synapse_sdk/plugins/cli/run.py +0 -67
- synapse_sdk/plugins/exceptions.py +0 -22
- synapse_sdk/plugins/models.py +0 -121
- synapse_sdk/plugins/templates/cookiecutter.json +0 -11
- synapse_sdk/plugins/templates/hooks/post_gen_project.py +0 -3
- synapse_sdk/plugins/templates/hooks/pre_prompt.py +0 -21
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/.env +0 -24
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/.env.dist +0 -24
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/.gitignore +0 -27
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/.pre-commit-config.yaml +0 -7
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/README.md +0 -5
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/config.yaml +0 -6
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/main.py +0 -4
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/plugin/__init__.py +0 -0
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/pyproject.toml +0 -13
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/requirements.txt +0 -1
- synapse_sdk/shared/enums.py +0 -8
- synapse_sdk/utils/debug.py +0 -5
- synapse_sdk/utils/file.py +0 -87
- synapse_sdk/utils/module_loading.py +0 -29
- synapse_sdk/utils/pydantic/__init__.py +0 -0
- synapse_sdk/utils/pydantic/config.py +0 -4
- synapse_sdk/utils/pydantic/errors.py +0 -33
- synapse_sdk/utils/pydantic/validators.py +0 -7
- synapse_sdk/utils/storage.py +0 -91
- synapse_sdk/utils/string.py +0 -11
- synapse_sdk-1.0.0a11.dist-info/LICENSE +0 -21
- synapse_sdk-1.0.0a11.dist-info/METADATA +0 -43
- synapse_sdk-1.0.0a11.dist-info/RECORD +0 -111
- {synapse_sdk-1.0.0a11.dist-info → synapse_sdk-2026.1.1b2.dist-info}/entry_points.txt +0 -0
- {synapse_sdk-1.0.0a11.dist-info → synapse_sdk-2026.1.1b2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
"""Deployment action base class for Ray Serve deployments."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any, TypeVar
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
|
|
9
|
+
from synapse_sdk.plugins.action import BaseAction
|
|
10
|
+
from synapse_sdk.plugins.actions.inference.context import DeploymentContext
|
|
11
|
+
from synapse_sdk.plugins.steps import Orchestrator, StepRegistry
|
|
12
|
+
|
|
13
|
+
P = TypeVar('P', bound=BaseModel)
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from synapse_sdk.clients.agent import AgentClient
|
|
17
|
+
from synapse_sdk.clients.backend import BackendClient
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class DeploymentProgressCategories:
    """Names of the progress categories reported by deployment workflows.

    Each attribute holds the category string to pass to ``set_progress()``
    for one phase of a deployment:

    - ``INITIALIZE``: Ray cluster initialization
    - ``DEPLOY``: deploying the application to Ray Serve
    - ``REGISTER``: registering the serve application with the backend

    Example:
        >>> self.set_progress(1, 3, self.progress.INITIALIZE)
        >>> self.set_progress(2, 3, self.progress.DEPLOY)
    """

    # Phase 1 -- Ray cluster initialization.
    INITIALIZE: str = 'initialize'

    # Phase 2 -- deployment to Ray Serve.
    DEPLOY: str = 'deploy'

    # Phase 3 -- registration with the backend.
    REGISTER: str = 'register'
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class BaseDeploymentAction(BaseAction[P]):
    """Base class for Ray Serve deployment actions.

    Provides helper methods for deploying inference endpoints to Ray Serve.
    Handles Ray initialization, deployment creation, and backend registration.

    Supports two execution modes:
        1. Simple execute: Override execute() directly for simple deployments
        2. Step-based: Override setup_steps() to register workflow steps

    Attributes:
        progress: Standard progress category names.
        entrypoint: The serve deployment class to deploy (set in subclass).

    Example (simple execute):
        >>> class MyDeploymentAction(BaseDeploymentAction[MyParams]):
        ...     action_name = 'deployment'
        ...     category = 'neural_net'
        ...     params_model = MyParams
        ...     entrypoint = MyServeDeployment
        ...
        ...     def execute(self) -> dict[str, Any]:
        ...         self.ray_init()
        ...         self.set_progress(1, 3, self.progress.INITIALIZE)
        ...         self.deploy()
        ...         self.set_progress(2, 3, self.progress.DEPLOY)
        ...         app_id = self.register_serve_application()
        ...         self.set_progress(3, 3, self.progress.REGISTER)
        ...         return {'serve_application': app_id}

    Example (step-based):
        >>> class MyDeploymentAction(BaseDeploymentAction[MyParams]):
        ...     entrypoint = MyServeDeployment
        ...
        ...     def setup_steps(self, registry: StepRegistry[DeploymentContext]) -> None:
        ...         registry.register(InitializeRayStep())
        ...         registry.register(DeployStep())
        ...         registry.register(RegisterStep())
    """

    progress = DeploymentProgressCategories()

    # Override in subclass with your serve deployment class
    entrypoint: type | None = None

    @property
    def client(self) -> BackendClient:
        """Backend client from context.

        Returns:
            BackendClient instance.

        Raises:
            RuntimeError: If no client in context.
        """
        if self.ctx.client is None:
            raise RuntimeError('No client in context. Provide a client via RuntimeContext.')
        return self.ctx.client

    @property
    def agent_client(self) -> AgentClient:
        """Agent client from context.

        Returns:
            AgentClient instance for Ray operations.

        Raises:
            RuntimeError: If no agent_client in context.
        """
        if self.ctx.agent_client is None:
            raise RuntimeError('No agent_client in context. Provide an agent_client via RuntimeContext.')
        return self.ctx.agent_client

    def setup_steps(self, registry: StepRegistry[DeploymentContext]) -> None:
        """Register workflow steps for step-based execution.

        Override this method to register custom steps for deployment workflow.
        If steps are registered, step-based execution takes precedence.

        Args:
            registry: StepRegistry to register steps with.

        Example:
            >>> def setup_steps(self, registry: StepRegistry[DeploymentContext]) -> None:
            ...     registry.register(InitializeRayStep())
            ...     registry.register(DeployStep())
            ...     registry.register(RegisterStep())
        """
        pass  # Default: no steps, uses simple execute()

    def _params_dict(self) -> dict[str, Any]:
        """Return the action params as a plain dict.

        Uses ``model_dump()`` when the params object provides it and falls
        back to ``dict()`` otherwise. Shared by create_context() and
        get_ray_actor_options() so both read the params the same way.
        """
        if hasattr(self.params, 'model_dump'):
            return self.params.model_dump()
        return dict(self.params)

    def create_context(self) -> DeploymentContext:
        """Create deployment context for step-based workflow.

        Override to customize context creation or add additional state.

        Returns:
            DeploymentContext instance with params and runtime context.
        """
        params_dict = self._params_dict()
        return DeploymentContext(
            runtime_ctx=self.ctx,
            params=params_dict,
            model_id=params_dict.get('model_id'),
            serve_app_name=self.get_serve_app_name(),
            route_prefix=self.get_route_prefix(),
            ray_actor_options=self.get_ray_actor_options(),
        )

    def run(self) -> Any:
        """Run the action, using steps if registered.

        This method is called by executors. It asks setup_steps() to fill a
        fresh StepRegistry; if anything was registered the steps are run by an
        Orchestrator, otherwise execute() is called.

        Returns:
            Action result (dict or any return type). In step-based mode the
            orchestrator result is augmented with ``deployed`` and, when set,
            ``serve_application`` taken from the context.
        """
        # Check if steps are registered
        registry: StepRegistry[DeploymentContext] = StepRegistry()
        self.setup_steps(registry)

        if registry:
            # Step-based execution
            context = self.create_context()
            orchestrator: Orchestrator[DeploymentContext] = Orchestrator(
                registry=registry,
                context=context,
                progress_callback=lambda curr, total: self.set_progress(curr, total),
            )
            result = orchestrator.execute()

            # Add context data to result
            if context.serve_app_id:
                result['serve_application'] = context.serve_app_id
            result['deployed'] = context.deployed

            return result

        # Simple execute mode
        return self.execute()

    def get_serve_app_name(self) -> str:
        """Get the name for the Ray Serve application.

        Default uses plugin release code from SYNAPSE_PLUGIN_RELEASE_CODE env var,
        falling back to 'synapse-serve-app' when it is unset or empty.
        Override for custom naming.

        Returns:
            Serve application name.
        """
        return self.ctx.env.get_str('SYNAPSE_PLUGIN_RELEASE_CODE', 'synapse-serve-app') or 'synapse-serve-app'

    def get_route_prefix(self) -> str:
        """Get the route prefix for the deployment.

        Default uses plugin release checksum from SYNAPSE_PLUGIN_RELEASE_CHECKSUM env var,
        falling back to 'default' when it is unset or empty.
        Override for custom routing.

        Returns:
            Route prefix string (e.g., '/abc123').
        """
        checksum = self.ctx.env.get_str('SYNAPSE_PLUGIN_RELEASE_CHECKSUM', 'default') or 'default'
        return f'/{checksum}'

    def get_ray_actor_options(self) -> dict[str, Any]:
        """Get Ray actor options for the deployment.

        Default always sets ``runtime_env`` (from get_runtime_env()) and copies
        ``num_cpus``, ``num_gpus`` and ``memory`` from the params when they are
        present and not None. Override for custom resource allocation.

        Returns:
            Dict with Ray actor options (runtime_env, num_cpus, num_gpus, memory).
        """
        options: dict[str, Any] = {
            'runtime_env': self.get_runtime_env(),
        }

        params_dict = self._params_dict()

        for option in ('num_cpus', 'num_gpus', 'memory'):
            value = params_dict.get(option)
            # Compare with None instead of truthiness so an explicit 0 (e.g.
            # num_cpus=0 for a lightweight replica) is forwarded to Ray rather
            # than silently replaced by Ray's default.
            if value is not None:
                options[option] = value

        return options

    def get_runtime_env(self) -> dict[str, Any]:
        """Get Ray runtime environment.

        Override to customize the runtime environment for deployments.

        Returns:
            Dict with runtime environment configuration (empty by default).
        """
        return {}

    def ray_init(self, **kwargs: Any) -> None:
        """Initialize Ray cluster connection.

        Call this before deploying to ensure Ray is connected. Does nothing if
        Ray is already initialized.

        Args:
            **kwargs: Additional arguments for ray.init().

        Raises:
            ImportError: If Ray is not installed.
        """
        try:
            import ray
        except ImportError as err:
            raise ImportError(
                "Ray is required for deployment actions. Install with: pip install 'synapse-sdk[ray]'"
            ) from err

        if not ray.is_initialized():
            ray.init(**kwargs)

    def deploy(self) -> None:
        """Deploy the inference endpoint to Ray Serve.

        Uses the entrypoint class and current configuration to create
        a Ray Serve deployment. Any existing application with the same name
        is deleted first (best effort).

        Raises:
            RuntimeError: If entrypoint is not set.
            ImportError: If Ray Serve is not installed.
        """
        if self.entrypoint is None:
            raise RuntimeError(
                'entrypoint must be set to a serve deployment class. Example: entrypoint = MyServeDeployment'
            )

        try:
            from ray import serve
        except ImportError as err:
            raise ImportError(
                "Ray Serve is required for deployment actions. Install with: pip install 'synapse-sdk[ray]'"
            ) from err

        # Get deployment configuration
        app_name = self.get_serve_app_name()
        route_prefix = self.get_route_prefix()
        ray_actor_options = self.get_ray_actor_options()

        # Best-effort removal of a previous application with the same name.
        # Every failure (including "does not exist") is deliberately ignored.
        # NOTE(review): this also hides connection errors, and deleting before
        # serve.run() presumably leaves the route unavailable until the new
        # version is up -- confirm whether an in-place update is preferable.
        try:
            serve.delete(app_name)
        except Exception:
            pass

        # Get backend URL for the deployment
        backend_url = self.ctx.env.get_str('SYNAPSE_PLUGIN_RUN_HOST', '') or ''

        # The entrypoint should be a BaseServeDeployment subclass; its
        # constructor receives the backend URL as its single argument.
        deployment = serve.deployment(ray_actor_options=ray_actor_options)(self.entrypoint).bind(backend_url)

        serve.run(
            deployment,
            name=app_name,
            route_prefix=route_prefix,
        )

    def register_serve_application(self) -> int | None:
        """Register the serve application with the backend.

        Looks up the application's current state through the agent client and
        creates a serve application record (job, status, raw data) in the
        backend for tracking.

        Returns:
            Serve application ID if created; None when the context has no job
            ID, the agent lookup raises, or the agent returns no application.
        """
        job_id = self.ctx.job_id
        if not job_id:
            return None

        app_name = self.get_serve_app_name()

        # Get serve application status from Ray. Any failure here (including a
        # missing agent_client) is treated as "nothing to register".
        try:
            serve_app = self.agent_client.get_serve_application(app_name)
        except Exception:
            return None

        # Guard against an empty lookup instead of crashing on serve_app.get().
        if serve_app is None:
            return None

        # Register with backend
        result = self.client.create_serve_application({
            'job': job_id,
            'status': serve_app.get('status'),
            'data': serve_app,
        })

        return result.get('id')
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""Base Ray Serve deployment class for inference endpoints."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import tempfile
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseServeDeployment(ABC):
|
|
11
|
+
"""Base class for Ray Serve inference deployments.
|
|
12
|
+
|
|
13
|
+
Provides model loading with multiplexing support. Subclasses implement
|
|
14
|
+
_get_model() to load their specific model format and infer() to run
|
|
15
|
+
inference.
|
|
16
|
+
|
|
17
|
+
This class is designed to be used with Ray Serve's @serve.deployment
|
|
18
|
+
decorator and supports model multiplexing via @serve.multiplexed().
|
|
19
|
+
|
|
20
|
+
Attributes:
|
|
21
|
+
backend_url: URL of the Synapse backend for model fetching.
|
|
22
|
+
_model_cache: Internal cache for loaded models.
|
|
23
|
+
|
|
24
|
+
Example:
|
|
25
|
+
>>> from ray import serve
|
|
26
|
+
>>> from fastapi import FastAPI
|
|
27
|
+
>>>
|
|
28
|
+
>>> app = FastAPI()
|
|
29
|
+
>>>
|
|
30
|
+
>>> @serve.deployment
|
|
31
|
+
>>> @serve.ingress(app)
|
|
32
|
+
>>> class MyInference(BaseServeDeployment):
|
|
33
|
+
... async def _get_model(self, model_info: dict) -> Any:
|
|
34
|
+
... import torch
|
|
35
|
+
... return torch.load(model_info['path'] / 'model.pt')
|
|
36
|
+
...
|
|
37
|
+
... async def infer(self, inputs: list[dict]) -> list[dict]:
|
|
38
|
+
... model = await self.get_model()
|
|
39
|
+
... return [{'prediction': model(inp)} for inp in inputs]
|
|
40
|
+
>>>
|
|
41
|
+
>>> # Deploy with:
|
|
42
|
+
>>> deployment = MyInference.bind(backend_url='https://api.example.com')
|
|
43
|
+
>>> serve.run(deployment)
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(self, backend_url: str) -> None:
    """Set up the deployment with an empty model cache.

    Args:
        backend_url: Base URL of the Synapse backend used when fetching
            models for incoming requests.
    """
    # Loaded models, keyed by the model token they were loaded from.
    self._model_cache: dict[str, Any] = {}
    self.backend_url = backend_url
|
|
54
|
+
|
|
55
|
+
async def _load_model_from_token(self, model_token: str) -> Any:
    """Load a model from an encoded model token.

    Decodes the JWT token carrying the model reference and user credentials,
    fetches the model record from the backend, downloads and extracts its
    artifact archive into a temporary directory, then delegates to
    ``_get_model()``.

    The temporary directory (``model_info['path']`` as seen by ``_get_model``)
    is deleted as soon as ``_get_model()`` returns. Implementations must
    therefore finish reading every file they need before returning; they must
    not keep lazy or memory-mapped references into that path.

    Args:
        model_token: JWT-encoded token with model info.

    Returns:
        Loaded model object (format depends on _get_model implementation).

    Raises:
        ImportError: If PyJWT is not installed.
        ValueError: If the backend model record has no file URL.
    """
    try:
        import jwt
    except ImportError as err:
        raise ImportError('PyJWT is required for model token decoding. Install with: pip install PyJWT') from err

    from pathlib import Path

    from synapse_sdk.clients.backend import BackendClient
    from synapse_sdk.utils.file.archive import extract_archive
    from synapse_sdk.utils.file.download import download_file

    # Decode token to get model info.
    # NOTE(review): the token is verified with HS256 using ``backend_url`` as
    # the key. A URL is not a secret, so this check presumably cannot stop a
    # forged token -- confirm with the backend whether a real signing key is
    # intended here.
    model_info = jwt.decode(model_token, self.backend_url, algorithms=['HS256'])

    # Create backend client with the user credentials carried by the token
    client = BackendClient(
        base_url=self.backend_url,
        access_token=model_info['token'],
        tenant=model_info.get('tenant'),
    )

    # Fetch model metadata
    model = client.get_model(int(model_info['model']))

    if not model.get('file'):
        raise ValueError(f'Model {model_info["model"]} has no file URL')

    # NOTE(review): get_model(), download_file() and extract_archive() are
    # called without await, so they appear to run synchronously and block this
    # replica's event loop while the model is fetched.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        archive_path = temp_path / 'model.zip'
        download_file(model['file'], archive_path)
        extract_archive(archive_path, temp_dir)

        # Load while the extracted artifacts still exist on disk.
        model['path'] = temp_path
        return await self._get_model(model)
|
|
106
|
+
|
|
107
|
+
async def get_model(self) -> Any:
    """Get the model for the current request.

    Reads the model token from the request's multiplexed model ID via
    ``serve.get_multiplexed_model_id()``; Ray Serve populates it from the
    ``serve_multiplexed_model_id`` request header. The token is then loaded
    through ``_load_model_multiplexed()``, which caches loaded models per
    token. This class does not apply Ray Serve's ``@serve.multiplexed()``
    decorator itself.

    Returns:
        Loaded model object.

    Raises:
        ImportError: If Ray Serve is not installed.
        ValueError: If the request carries no multiplexed model ID.
    """
    # Import here to avoid issues when Ray is not installed
    try:
        from ray import serve
    except ImportError as err:
        raise ImportError("Ray Serve is required. Install with: pip install 'synapse-sdk[ray]'") from err

    model_id = serve.get_multiplexed_model_id()
    if not model_id:
        # Without this check an empty ID would reach jwt.decode() and fail
        # with an opaque token-decoding error.
        raise ValueError(
            'No multiplexed model ID on the request. '
            "Send the model token in the 'serve_multiplexed_model_id' header."
        )
    return await self._load_model_multiplexed(model_id)
|
|
128
|
+
|
|
129
|
+
async def _load_model_multiplexed(self, model_id: str) -> Any:
    """Return the model for ``model_id``, loading it on first use.

    Models are memoized in ``self._model_cache`` keyed by ``model_id``. On a
    cache miss the model is obtained from ``self._load_model_from_token``,
    stored in the cache, and returned.

    Note:
        Despite the wording in related docs, this method is not itself
        decorated with ``@serve.multiplexed()``. Nothing in this method
        evicts cache entries.

    Args:
        model_id: The model token/ID taken from the request header.

    Returns:
        Loaded model object.
    """
    if model_id not in self._model_cache:
        # Cache miss: resolve the token, fetch the artifacts and remember the result.
        loaded = await self._load_model_from_token(model_id)
        self._model_cache[model_id] = loaded
        return loaded
    return self._model_cache[model_id]
|
|
149
|
+
|
|
150
|
+
@abstractmethod
async def _get_model(self, model_info: dict[str, Any]) -> Any:
    """Turn downloaded model artifacts into a ready-to-use model object.

    This is the framework-specific loading hook that subclasses must
    implement; the base version only raises. It is invoked once the model
    archive has been downloaded and extracted.

    Args:
        model_info: Model metadata dict. Its ``'path'`` entry is a
            ``pathlib.Path`` pointing at the extracted artifact directory.

    Returns:
        The loaded model object; its type depends on the framework used.

    Raises:
        NotImplementedError: Always, unless overridden.

    Note:
        When called from the token-based loader in this class, the artifact
        directory is a ``tempfile.TemporaryDirectory`` that is removed as
        soon as this coroutine returns. Implementations should therefore
        read what they need eagerly instead of keeping lazy references to
        files under ``model_info['path']``.

    Example (PyTorch):
        >>> async def _get_model(self, model_info: dict) -> Any:
        ...     import torch
        ...     model_path = model_info['path'] / 'model.pt'
        ...     return torch.load(model_path)

    Example (ONNX):
        >>> async def _get_model(self, model_info: dict) -> Any:
        ...     import onnxruntime as ort
        ...     model_path = model_info['path'] / 'model.onnx'
        ...     return ort.InferenceSession(str(model_path))
    """
    hint = (
        'Override _get_model() to load your model format. '
        'Example: return torch.load(model_info["path"] / "model.pt")'
    )
    raise NotImplementedError(hint)
|
|
180
|
+
|
|
181
|
+
@abstractmethod
async def infer(self, *args: Any, **kwargs: Any) -> Any:
    """Run the model on the given inputs and return its predictions.

    Subclasses must implement this; the base version only raises. Inside an
    implementation, ``await self.get_model()`` yields the model selected
    for the current request.

    Args:
        *args: Inference inputs; their format is implementation-defined.
        **kwargs: Extra inference options; also implementation-defined.

    Returns:
        Inference results in an implementation-defined format.

    Raises:
        NotImplementedError: Always, unless overridden.

    Example:
        >>> async def infer(self, inputs: list[dict]) -> list[dict]:
        ...     model = await self.get_model()
        ...     return [
        ...         {'prediction': model.predict(inp['data']).tolist()}
        ...         for inp in inputs
        ...     ]
    """
    hint = 'Override infer() to implement inference logic. Example: return model.predict(inputs)'
    raise NotImplementedError(hint)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def create_serve_multiplexed_model_id(
|
|
210
|
+
model_id: int | str,
|
|
211
|
+
token: str,
|
|
212
|
+
backend_url: str,
|
|
213
|
+
tenant: str | None = None,
|
|
214
|
+
) -> str:
|
|
215
|
+
"""Create a JWT-encoded model ID for serve multiplexing.
|
|
216
|
+
|
|
217
|
+
This helper creates the token that should be passed in the
|
|
218
|
+
'serve_multiplexed_model_id' header for inference requests.
|
|
219
|
+
|
|
220
|
+
Args:
|
|
221
|
+
model_id: The model ID to encode.
|
|
222
|
+
token: User access token for authentication.
|
|
223
|
+
backend_url: Backend URL (used as JWT secret).
|
|
224
|
+
tenant: Optional tenant identifier.
|
|
225
|
+
|
|
226
|
+
Returns:
|
|
227
|
+
JWT-encoded model token string.
|
|
228
|
+
|
|
229
|
+
Example:
|
|
230
|
+
>>> model_token = create_serve_multiplexed_model_id(
|
|
231
|
+
... model_id=123,
|
|
232
|
+
... token='user_access_token',
|
|
233
|
+
... backend_url='https://api.example.com',
|
|
234
|
+
... tenant='my-tenant',
|
|
235
|
+
... )
|
|
236
|
+
>>> # Use in request headers:
|
|
237
|
+
>>> headers = {'serve_multiplexed_model_id': model_token}
|
|
238
|
+
"""
|
|
239
|
+
try:
|
|
240
|
+
import jwt
|
|
241
|
+
except ImportError:
|
|
242
|
+
raise ImportError('PyJWT is required for model token encoding. Install with: pip install PyJWT')
|
|
243
|
+
|
|
244
|
+
payload: dict[str, Any] = {
|
|
245
|
+
'model': str(model_id),
|
|
246
|
+
'token': token,
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
if tenant:
|
|
250
|
+
payload['tenant'] = tenant
|
|
251
|
+
|
|
252
|
+
return jwt.encode(payload, backend_url, algorithm='HS256')
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Train action module with optional workflow step support.
|
|
2
|
+
|
|
3
|
+
Provides the training action base class and its supporting types:
|
|
4
|
+
- BaseTrainAction: Base class for training workflows
|
|
5
|
+
- TrainContext: Training-specific context extending BaseStepContext
|
|
6
|
+
- TrainProgressCategories: Standard progress category names
|
|
7
|
+
|
|
8
|
+
For step infrastructure (BaseStep, StepRegistry, Orchestrator),
|
|
9
|
+
use the steps module:
|
|
10
|
+
from synapse_sdk.plugins.steps import BaseStep, StepRegistry
|
|
11
|
+
|
|
12
|
+
Example (simple execute):
|
|
13
|
+
>>> class MyTrainAction(BaseTrainAction[MyParams]):
|
|
14
|
+
... def execute(self) -> dict[str, Any]:
|
|
15
|
+
... dataset = self.get_dataset()
|
|
16
|
+
... # ... train model ...
|
|
17
|
+
... return {'model_id': model['id']}
|
|
18
|
+
|
|
19
|
+
Example (step-based):
|
|
20
|
+
>>> from synapse_sdk.plugins.steps import BaseStep, StepResult
|
|
21
|
+
>>>
|
|
22
|
+
>>> class LoadDatasetStep(BaseStep[TrainContext]):
|
|
23
|
+
... @property
|
|
24
|
+
... def name(self) -> str:
|
|
25
|
+
... return 'load_dataset'
|
|
26
|
+
...
|
|
27
|
+
... @property
|
|
28
|
+
... def progress_weight(self) -> float:
|
|
29
|
+
... return 0.2
|
|
30
|
+
...
|
|
31
|
+
... def execute(self, context: TrainContext) -> StepResult:
|
|
32
|
+
... context.dataset = load_data(context.params['dataset_id'])
|
|
33
|
+
... return StepResult(success=True)
|
|
34
|
+
>>>
|
|
35
|
+
>>> class MyTrainAction(BaseTrainAction[MyParams]):
|
|
36
|
+
... def setup_steps(self, registry) -> None:
|
|
37
|
+
... registry.register(LoadDatasetStep())
|
|
38
|
+
... registry.register(TrainStep())
|
|
39
|
+
... registry.register(UploadModelStep())
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
from synapse_sdk.plugins.actions.train.action import (
|
|
43
|
+
BaseTrainAction,
|
|
44
|
+
BaseTrainParams,
|
|
45
|
+
TrainProgressCategories,
|
|
46
|
+
)
|
|
47
|
+
from synapse_sdk.plugins.actions.train.context import TrainContext
|
|
48
|
+
|
|
49
|
+
# Names exported by a star-import of this package; each one comes from the
# ``action`` / ``context`` submodule imports above.
__all__ = [
    'BaseTrainAction', 'BaseTrainParams',
    'TrainContext', 'TrainProgressCategories',
]
|