synapse-sdk 1.0.0a23__py3-none-any.whl → 2025.12.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synapse_sdk/__init__.py +24 -0
- synapse_sdk/cli/__init__.py +310 -5
- synapse_sdk/cli/alias/__init__.py +22 -0
- synapse_sdk/cli/alias/create.py +36 -0
- synapse_sdk/cli/alias/dataclass.py +31 -0
- synapse_sdk/cli/alias/default.py +16 -0
- synapse_sdk/cli/alias/delete.py +15 -0
- synapse_sdk/cli/alias/list.py +19 -0
- synapse_sdk/cli/alias/read.py +15 -0
- synapse_sdk/cli/alias/update.py +17 -0
- synapse_sdk/cli/alias/utils.py +61 -0
- synapse_sdk/cli/code_server.py +687 -0
- synapse_sdk/cli/config.py +440 -0
- synapse_sdk/cli/devtools.py +90 -0
- synapse_sdk/cli/plugin/__init__.py +33 -0
- synapse_sdk/cli/{create_plugin.py → plugin/create.py} +2 -2
- synapse_sdk/{plugins/cli → cli/plugin}/publish.py +23 -15
- synapse_sdk/clients/agent/__init__.py +9 -3
- synapse_sdk/clients/agent/container.py +143 -0
- synapse_sdk/clients/agent/core.py +19 -0
- synapse_sdk/clients/agent/ray.py +298 -9
- synapse_sdk/clients/backend/__init__.py +30 -12
- synapse_sdk/clients/backend/annotation.py +13 -5
- synapse_sdk/clients/backend/core.py +31 -4
- synapse_sdk/clients/backend/data_collection.py +186 -0
- synapse_sdk/clients/backend/hitl.py +17 -0
- synapse_sdk/clients/backend/integration.py +16 -1
- synapse_sdk/clients/backend/ml.py +5 -1
- synapse_sdk/clients/backend/models.py +78 -0
- synapse_sdk/clients/base.py +384 -41
- synapse_sdk/clients/ray/serve.py +2 -0
- synapse_sdk/clients/validators/collections.py +31 -0
- synapse_sdk/devtools/config.py +94 -0
- synapse_sdk/devtools/server.py +41 -0
- synapse_sdk/devtools/streamlit_app/__init__.py +5 -0
- synapse_sdk/devtools/streamlit_app/app.py +128 -0
- synapse_sdk/devtools/streamlit_app/services/__init__.py +11 -0
- synapse_sdk/devtools/streamlit_app/services/job_service.py +233 -0
- synapse_sdk/devtools/streamlit_app/services/plugin_service.py +236 -0
- synapse_sdk/devtools/streamlit_app/services/serve_service.py +95 -0
- synapse_sdk/devtools/streamlit_app/ui/__init__.py +15 -0
- synapse_sdk/devtools/streamlit_app/ui/config_tab.py +76 -0
- synapse_sdk/devtools/streamlit_app/ui/deployment_tab.py +66 -0
- synapse_sdk/devtools/streamlit_app/ui/http_tab.py +125 -0
- synapse_sdk/devtools/streamlit_app/ui/jobs_tab.py +573 -0
- synapse_sdk/devtools/streamlit_app/ui/serve_tab.py +346 -0
- synapse_sdk/devtools/streamlit_app/ui/status_bar.py +118 -0
- synapse_sdk/devtools/streamlit_app/utils/__init__.py +40 -0
- synapse_sdk/devtools/streamlit_app/utils/json_viewer.py +197 -0
- synapse_sdk/devtools/streamlit_app/utils/log_formatter.py +38 -0
- synapse_sdk/devtools/streamlit_app/utils/styles.py +241 -0
- synapse_sdk/devtools/streamlit_app/utils/ui_components.py +289 -0
- synapse_sdk/devtools/streamlit_app.py +10 -0
- synapse_sdk/loggers.py +120 -9
- synapse_sdk/plugins/README.md +1340 -0
- synapse_sdk/plugins/__init__.py +0 -13
- synapse_sdk/plugins/categories/base.py +117 -11
- synapse_sdk/plugins/categories/data_validation/actions/validation.py +72 -0
- synapse_sdk/plugins/categories/data_validation/templates/plugin/validation.py +33 -5
- synapse_sdk/plugins/categories/export/actions/__init__.py +3 -0
- synapse_sdk/plugins/categories/export/actions/export/__init__.py +28 -0
- synapse_sdk/plugins/categories/export/actions/export/action.py +165 -0
- synapse_sdk/plugins/categories/export/actions/export/enums.py +113 -0
- synapse_sdk/plugins/categories/export/actions/export/exceptions.py +53 -0
- synapse_sdk/plugins/categories/export/actions/export/models.py +74 -0
- synapse_sdk/plugins/categories/export/actions/export/run.py +195 -0
- synapse_sdk/plugins/categories/export/actions/export/utils.py +187 -0
- synapse_sdk/plugins/categories/export/templates/config.yaml +21 -0
- synapse_sdk/plugins/categories/export/templates/plugin/__init__.py +390 -0
- synapse_sdk/plugins/categories/export/templates/plugin/export.py +160 -0
- synapse_sdk/plugins/categories/neural_net/actions/deployment.py +13 -12
- synapse_sdk/plugins/categories/neural_net/actions/train.py +1134 -31
- synapse_sdk/plugins/categories/neural_net/actions/tune.py +534 -0
- synapse_sdk/plugins/categories/neural_net/base/inference.py +1 -1
- synapse_sdk/plugins/categories/neural_net/templates/config.yaml +32 -4
- synapse_sdk/plugins/categories/neural_net/templates/plugin/inference.py +26 -10
- synapse_sdk/plugins/categories/pre_annotation/actions/__init__.py +4 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/pre_annotation/__init__.py +3 -0
- synapse_sdk/plugins/categories/{export/actions/export.py → pre_annotation/actions/pre_annotation/action.py} +4 -4
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/__init__.py +28 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/action.py +148 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/enums.py +269 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/exceptions.py +14 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/factory.py +76 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/models.py +100 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/orchestrator.py +248 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/run.py +64 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/__init__.py +17 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/annotation.py +265 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/base.py +170 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/extraction.py +83 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/metrics.py +92 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/preprocessor.py +243 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/validation.py +143 -0
- synapse_sdk/plugins/categories/pre_annotation/templates/config.yaml +19 -0
- synapse_sdk/plugins/categories/pre_annotation/templates/plugin/to_task.py +40 -0
- synapse_sdk/plugins/categories/smart_tool/templates/config.yaml +2 -0
- synapse_sdk/plugins/categories/upload/__init__.py +0 -0
- synapse_sdk/plugins/categories/upload/actions/__init__.py +0 -0
- synapse_sdk/plugins/categories/upload/actions/upload/__init__.py +19 -0
- synapse_sdk/plugins/categories/upload/actions/upload/action.py +236 -0
- synapse_sdk/plugins/categories/upload/actions/upload/context.py +185 -0
- synapse_sdk/plugins/categories/upload/actions/upload/enums.py +493 -0
- synapse_sdk/plugins/categories/upload/actions/upload/exceptions.py +36 -0
- synapse_sdk/plugins/categories/upload/actions/upload/factory.py +138 -0
- synapse_sdk/plugins/categories/upload/actions/upload/models.py +214 -0
- synapse_sdk/plugins/categories/upload/actions/upload/orchestrator.py +183 -0
- synapse_sdk/plugins/categories/upload/actions/upload/registry.py +113 -0
- synapse_sdk/plugins/categories/upload/actions/upload/run.py +179 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/base.py +107 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/cleanup.py +62 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/collection.py +63 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/generate.py +91 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/initialize.py +82 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/metadata.py +235 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/organize.py +201 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/upload.py +104 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/validate.py +71 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/base.py +82 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/data_unit/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/data_unit/batch.py +39 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/data_unit/single.py +29 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/file_discovery/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/file_discovery/flat.py +300 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/file_discovery/recursive.py +287 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/metadata/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/metadata/excel.py +174 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/metadata/none.py +16 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/upload/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/upload/sync.py +84 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/validation/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/validation/default.py +60 -0
- synapse_sdk/plugins/categories/upload/actions/upload/utils.py +250 -0
- synapse_sdk/plugins/categories/upload/templates/README.md +470 -0
- synapse_sdk/plugins/categories/upload/templates/config.yaml +33 -0
- synapse_sdk/plugins/categories/upload/templates/plugin/__init__.py +310 -0
- synapse_sdk/plugins/categories/upload/templates/plugin/upload.py +102 -0
- synapse_sdk/plugins/enums.py +3 -1
- synapse_sdk/plugins/models.py +148 -11
- synapse_sdk/plugins/templates/plugin-config-schema.json +406 -0
- synapse_sdk/plugins/templates/schema.json +491 -0
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/config.yaml +1 -0
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/requirements.txt +1 -1
- synapse_sdk/plugins/utils/__init__.py +46 -0
- synapse_sdk/plugins/utils/actions.py +119 -0
- synapse_sdk/plugins/utils/config.py +203 -0
- synapse_sdk/plugins/{utils.py → utils/legacy.py} +26 -46
- synapse_sdk/plugins/utils/ray_gcs.py +66 -0
- synapse_sdk/plugins/utils/registry.py +58 -0
- synapse_sdk/shared/__init__.py +25 -0
- synapse_sdk/shared/enums.py +93 -0
- synapse_sdk/types.py +19 -0
- synapse_sdk/utils/converters/__init__.py +240 -0
- synapse_sdk/utils/converters/coco/__init__.py +0 -0
- synapse_sdk/utils/converters/coco/from_dm.py +322 -0
- synapse_sdk/utils/converters/coco/to_dm.py +215 -0
- synapse_sdk/utils/converters/dm/__init__.py +57 -0
- synapse_sdk/utils/converters/dm/base.py +137 -0
- synapse_sdk/utils/converters/dm/from_v1.py +273 -0
- synapse_sdk/utils/converters/dm/to_v1.py +321 -0
- synapse_sdk/utils/converters/dm/tools/__init__.py +214 -0
- synapse_sdk/utils/converters/dm/tools/answer.py +95 -0
- synapse_sdk/utils/converters/dm/tools/bounding_box.py +132 -0
- synapse_sdk/utils/converters/dm/tools/bounding_box_3d.py +121 -0
- synapse_sdk/utils/converters/dm/tools/classification.py +75 -0
- synapse_sdk/utils/converters/dm/tools/keypoint.py +117 -0
- synapse_sdk/utils/converters/dm/tools/named_entity.py +111 -0
- synapse_sdk/utils/converters/dm/tools/polygon.py +122 -0
- synapse_sdk/utils/converters/dm/tools/polyline.py +124 -0
- synapse_sdk/utils/converters/dm/tools/prompt.py +94 -0
- synapse_sdk/utils/converters/dm/tools/relation.py +86 -0
- synapse_sdk/utils/converters/dm/tools/segmentation.py +141 -0
- synapse_sdk/utils/converters/dm/tools/segmentation_3d.py +83 -0
- synapse_sdk/utils/converters/dm/types.py +168 -0
- synapse_sdk/utils/converters/dm/utils.py +162 -0
- synapse_sdk/utils/converters/dm_legacy/__init__.py +56 -0
- synapse_sdk/utils/converters/dm_legacy/from_v1.py +627 -0
- synapse_sdk/utils/converters/dm_legacy/to_v1.py +367 -0
- synapse_sdk/utils/converters/pascal/__init__.py +0 -0
- synapse_sdk/utils/converters/pascal/from_dm.py +244 -0
- synapse_sdk/utils/converters/pascal/to_dm.py +214 -0
- synapse_sdk/utils/converters/yolo/__init__.py +0 -0
- synapse_sdk/utils/converters/yolo/from_dm.py +384 -0
- synapse_sdk/utils/converters/yolo/to_dm.py +267 -0
- synapse_sdk/utils/dataset.py +46 -0
- synapse_sdk/utils/encryption.py +158 -0
- synapse_sdk/utils/file/__init__.py +58 -0
- synapse_sdk/utils/file/archive.py +32 -0
- synapse_sdk/utils/file/checksum.py +56 -0
- synapse_sdk/utils/file/chunking.py +31 -0
- synapse_sdk/utils/file/download.py +385 -0
- synapse_sdk/utils/file/encoding.py +40 -0
- synapse_sdk/utils/file/io.py +22 -0
- synapse_sdk/utils/file/upload.py +165 -0
- synapse_sdk/utils/file/video/__init__.py +29 -0
- synapse_sdk/utils/file/video/transcode.py +307 -0
- synapse_sdk/utils/file.py.backup +301 -0
- synapse_sdk/utils/http.py +138 -0
- synapse_sdk/utils/network.py +309 -0
- synapse_sdk/utils/storage/__init__.py +72 -0
- synapse_sdk/utils/storage/providers/__init__.py +183 -0
- synapse_sdk/utils/storage/providers/file_system.py +134 -0
- synapse_sdk/utils/storage/providers/gcp.py +13 -0
- synapse_sdk/utils/storage/providers/http.py +190 -0
- synapse_sdk/utils/storage/providers/s3.py +91 -0
- synapse_sdk/utils/storage/providers/sftp.py +47 -0
- synapse_sdk/utils/storage/registry.py +17 -0
- synapse_sdk-2025.12.3.dist-info/METADATA +123 -0
- synapse_sdk-2025.12.3.dist-info/RECORD +279 -0
- {synapse_sdk-1.0.0a23.dist-info → synapse_sdk-2025.12.3.dist-info}/WHEEL +1 -1
- synapse_sdk/clients/backend/dataset.py +0 -51
- synapse_sdk/plugins/categories/import/actions/import.py +0 -10
- synapse_sdk/plugins/cli/__init__.py +0 -21
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/.env +0 -24
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/.env.dist +0 -24
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/main.py +0 -4
- synapse_sdk/utils/file.py +0 -168
- synapse_sdk/utils/storage.py +0 -91
- synapse_sdk-1.0.0a23.dist-info/METADATA +0 -44
- synapse_sdk-1.0.0a23.dist-info/RECORD +0 -114
- /synapse_sdk/{plugins/cli → cli/plugin}/run.py +0 -0
- /synapse_sdk/{plugins/categories/import → clients/validators}/__init__.py +0 -0
- /synapse_sdk/{plugins/categories/import/actions → devtools}/__init__.py +0 -0
- {synapse_sdk-1.0.0a23.dist-info → synapse_sdk-2025.12.3.dist-info}/entry_points.txt +0 -0
- {synapse_sdk-1.0.0a23.dist-info → synapse_sdk-2025.12.3.dist-info/licenses}/LICENSE +0 -0
- {synapse_sdk-1.0.0a23.dist-info → synapse_sdk-2025.12.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,534 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import tempfile
|
|
3
|
+
from numbers import Number
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Annotated, Optional
|
|
6
|
+
|
|
7
|
+
from pydantic import AfterValidator, BaseModel, field_validator
|
|
8
|
+
from pydantic_core import PydanticCustomError
|
|
9
|
+
|
|
10
|
+
from synapse_sdk.clients.exceptions import ClientError
|
|
11
|
+
from synapse_sdk.plugins.categories.decorators import register_action
|
|
12
|
+
from synapse_sdk.plugins.categories.neural_net.actions.train import TrainAction, TrainRun
|
|
13
|
+
from synapse_sdk.plugins.enums import PluginCategory, RunMethod
|
|
14
|
+
from synapse_sdk.utils.file import archive
|
|
15
|
+
from synapse_sdk.utils.module_loading import import_string
|
|
16
|
+
from synapse_sdk.utils.pydantic.validators import non_blank
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TuneRun(TrainRun):
    """Run state for hyperparameter tuning jobs.

    Extends TrainRun with tuning-specific bookkeeping that the plugin's
    train/tune entrypoints and the TuneAction read and write.
    """

    # Lets shared train code branch on whether this run is a tuning run.
    is_tune = True
    # Trials finished so far — presumably updated as trials complete;
    # TODO(review): confirm where this counter is incremented.
    completed_samples = 0
    # Total number of trials requested; set from tune_config['num_samples']
    # in TuneAction.start().
    num_samples = 0
    # Directory where the plugin's train function must place its model
    # output before reporting to Tune (see TuneAction docstring).
    checkpoint_output = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SearchAlgo(BaseModel):
    """
    Configuration for Ray Tune search algorithms.

    Supported algorithms:
        - 'axsearch': Ax (Bayesian/bandit optimization)
        - 'bayesoptsearch': Bayesian optimization using Gaussian Processes
        - 'hyperoptsearch': Tree-structured Parzen Estimator (TPE)
        - 'optunasearch': Optuna
        - 'basicvariantgenerator': Random search (default)

    Attributes:
        name (str): Name of the search algorithm (case-insensitive)
        points_to_evaluate (Optional[list[dict]]): Optional initial hyperparameter
            configurations to evaluate before starting optimization; each item
            is one configuration dict.

    Example:
        {
            "name": "hyperoptsearch",
            "points_to_evaluate": [
                {"learning_rate": 0.001, "batch_size": 32}
            ]
        }
    """

    name: str
    # Declared as a list of configuration dicts, matching both the example
    # above and Ray Tune's ``points_to_evaluate`` parameter. The previous
    # ``Optional[dict]`` annotation rejected the documented list-of-dicts
    # input at validation time.
    points_to_evaluate: Optional[list[dict]] = None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class Scheduler(BaseModel):
    """
    Configuration for Ray Tune schedulers.

    Supported schedulers:
        - 'fifo': First-In-First-Out scheduler (default, runs all trials)
        - 'asha': ASHA early stopping scheduler
        - 'hyperband': HyperBand early stopping scheduler
        - 'pbt': Population Based Training
        - 'median': Median stopping rule

    Attributes:
        name (str): Name of the scheduler (case-insensitive)
        options (Optional[dict]): Optional scheduler-specific configuration
            parameters, passed as keyword arguments to the scheduler class.

    Example:
        {
            "name": "hyperband",
            "options": {
                "max_t": 100,
                "reduction_factor": 3
            }
        }
    """

    name: str
    # Keyword arguments forwarded to the scheduler constructor. Declared as a
    # dict (previously mistyped as ``Optional[str]``): the consumer unpacks it
    # with ``scheduler_class(**options)`` and the example above is a mapping,
    # so a string value could never have worked.
    options: Optional[dict] = None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class TuneConfig(BaseModel):
    """
    Configuration for Ray Tune hyperparameter optimization.

    Attributes:
        mode (Optional[str]): Optimization mode - 'max' or 'min'
        metric (Optional[str]): Name of the metric to optimize
        num_samples (int): Number of hyperparameter configurations to try (default: 1)
        max_concurrent_trials (Optional[int]): Maximum number of trials to run in parallel
        search_alg (Optional[SearchAlgo]): Search algorithm configuration
        scheduler (Optional[Scheduler]): Trial scheduler configuration

    Note:
        ``mode`` and ``metric`` default to None, but they are read by
        TuneAction.convert_tune_search_alg and forwarded to the search
        algorithm — presumably they should be set whenever a search
        algorithm other than random search is used (TODO confirm).

    Example:
        {
            "mode": "max",
            "metric": "accuracy",
            "num_samples": 20,
            "max_concurrent_trials": 4,
            "search_alg": {
                "name": "hyperoptsearch"
            },
            "scheduler": {
                "name": "hyperband",
                "options": {"max_t": 100}
            }
        }
    """

    mode: Optional[str] = None
    metric: Optional[str] = None
    num_samples: int = 1
    max_concurrent_trials: Optional[int] = None
    search_alg: Optional[SearchAlgo] = None
    scheduler: Optional[Scheduler] = None
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class TuneParams(BaseModel):
    """
    Parameters for TuneAction (DEPRECATED - use TrainAction with is_tune=True instead).

    Attributes:
        name (str): Name for the tuning job
        description (str): Description of the job
        checkpoint (int | None): Optional checkpoint ID to resume from
        dataset (int): Dataset ID to use for training
        hyperparameter (list): Hyperparameter search space
        tune_config (TuneConfig): Tune configuration

    Hyperparameter format:
        Each item in hyperparameter list must have:
        - 'name': Parameter name (string)
        - 'type': Distribution type (string)
        - Type-specific parameters:
            - uniform/quniform: 'min', 'max'
            - loguniform/qloguniform: 'min', 'max', 'base'
            - randn/qrandn: 'mean', 'sd'
            - randint/qrandint: 'min', 'max'
            - lograndint/qlograndint: 'min', 'max', 'base'
            - choice/grid_search: 'options'

    Example:
        {
            "name": "my_tuning",
            "dataset": 123,
            "hyperparameter": [
                {"name": "batch_size", "type": "choice", "options": [16, 32, 64]},
                {"name": "learning_rate", "type": "loguniform", "min": 0.0001, "max": 0.01, "base": 10},
                {"name": "epochs", "type": "randint", "min": 5, "max": 15}
            ],
            "tune_config": {
                "mode": "max",
                "metric": "accuracy",
                "num_samples": 10
            }
        }
    """

    # Non-blank job name; uniqueness is checked by the validator below.
    name: Annotated[str, AfterValidator(non_blank)]
    description: str
    # Required field — callers must pass an explicit checkpoint ID or None.
    checkpoint: int | None
    dataset: int
    hyperparameter: list
    tune_config: TuneConfig

    @field_validator('name')
    @staticmethod
    def unique_name(value, info):
        # Rejects names already used by another active tune job.
        # NOTE(review): assumes validation is invoked with a context that
        # contains the running action (info.context['action']); calling
        # model validation without that context would raise TypeError —
        # confirm all callers supply it.
        action = info.context['action']
        client = action.client
        try:
            # Ask the backend whether any other active 'tune' job (excluding
            # this one) already uses this name.
            job_exists = client.exists(
                'list_jobs',
                params={
                    'ids_ex': action.job_id,
                    'category': 'neural_net',
                    'job__action': 'tune',
                    'is_active': True,
                    'params': f'name:{value}',
                },
            )
            # Pydantic turns AssertionError raised inside validators into a
            # validation error. The (Korean) message means "A tuning job with
            # this name already exists."
            assert not job_exists, '존재하는 튜닝 작업 이름입니다.'
        except ClientError:
            # Backend lookup failed — surface a generic client error.
            raise PydanticCustomError('client_error', '')
        return value
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@register_action
class TuneAction(TrainAction):
    """
    **DEPRECATED**: This action is deprecated. Please use TrainAction with is_tune=True instead.

    To migrate from tune to train with tuning:
    - Change action from "tune" to "train"
    - Add "is_tune": true to params
    - Keep tune_config and hyperparameter as they are

    Example:
        {
            "action": "train",
            "params": {
                "is_tune": true,
                "tune_config": { ... },
                "hyperparameter": [ ... ]
            }
        }

    **Must read** Important notes before using Tune:

    1. Path to the model output (which is the return value of your train function)
    should be set to the checkpoint_output attribute of the run object **before**
    starting the training.
    2. Before exiting the training function, report the results to Tune.
    3. When using own tune.py, take note of the difference in the order of parameters.
    tune() function starts with hyperparameter, run, dataset, checkpoint, **kwargs
    whereas the train() function starts with run, dataset, hyperparameter, checkpoint, **kwargs.
    ----
    1)
    Set the output path for the checkpoint to export best model

    output_path = Path('path/to/your/weights')
    run.checkpoint_output = str(output_path)

    2)
    Before exiting the training function, report the results to Tune.
    The results_dict should contain the metrics you want to report.

    Example: (In train function)
    results_dict = {
        "accuracy": accuracy,
        "loss": loss,
        # Add other metrics as needed
    }
    if hasattr(self.dm_run, 'is_tune') and self.dm_run.is_tune:
        tune.report(results_dict, checkpoint=tune.Checkpoint.from_directory(self.dm_run.checkpoint_output))


    3)
    tune() function takes hyperparameter, run, dataset, checkpoint, **kwargs in that order
    whereas train() function takes run, dataset, hyperparameter, checkpoint, **kwargs in that order.

    --------------------------------------------------------------------------------------------------------

    **Important** Things you must read before using Tune (translated from Korean):

    1. In this plugin's train function, before the training code runs, assign
    the path of the resulting model file (the train function's return value)
    to the checkpoint_output attribute.
    2. Before training finishes, report the results to Tune.
    3. When creating and using your own tune.py in the plugin, note that the
    parameter order differs.

    ----
    1)
    Specify the path where the checkpoint will be written.
    output_path = Path('path/to/your/weights')
    run.checkpoint_output = str(output_path)

    2)
    Before training finishes, report the results to Tune.
    results_dict = {
        "accuracy": accuracy,
        "loss": loss,
        # Add other metrics as needed
    }
    if hasattr(self.dm_run, 'is_tune') and self.dm_run.is_tune:
        tune.report(results_dict, checkpoint=tune.Checkpoint.from_directory(self.dm_run.checkpoint_output))

    3)
    The tune() function takes hyperparameter, run, dataset, checkpoint, **kwargs in that order,
    while the train() function takes run, dataset, hyperparameter, checkpoint, **kwargs in that order.
    """

    name = 'tune'
    category = PluginCategory.NEURAL_NET
    # Tuning executes as a background job rather than an inline task.
    method = RunMethod.JOB
    run_class = TuneRun
    params_model = TuneParams
    # Weighted progress reporting; proportions sum to 100.
    progress_categories = {
        'dataset': {
            'proportion': 5,
        },
        'trials': {
            'proportion': 90,
        },
        'model_upload': {
            'proportion': 5,
        },
    }
|
|
285
|
+
|
|
286
|
+
def start(self):
|
|
287
|
+
from ray import tune
|
|
288
|
+
|
|
289
|
+
# download dataset
|
|
290
|
+
self.run.log_message('Preparing dataset for hyperparameter tuning.')
|
|
291
|
+
input_dataset = self.get_dataset()
|
|
292
|
+
|
|
293
|
+
# retrieve checkpoint
|
|
294
|
+
checkpoint = None
|
|
295
|
+
if self.params['checkpoint']:
|
|
296
|
+
self.run.log_message('Retrieving checkpoint.')
|
|
297
|
+
checkpoint = self.get_model(self.params['checkpoint'])
|
|
298
|
+
|
|
299
|
+
# train dataset
|
|
300
|
+
self.run.log_message('Starting training for hyperparameter tuning.')
|
|
301
|
+
|
|
302
|
+
# Save num_samples to TuneRun for logging
|
|
303
|
+
self.run.num_samples = self.params['tune_config']['num_samples']
|
|
304
|
+
|
|
305
|
+
entrypoint = self.entrypoint
|
|
306
|
+
if not self._tune_override_exists():
|
|
307
|
+
# entrypoint must be train entrypoint
|
|
308
|
+
train_entrypoint = entrypoint
|
|
309
|
+
|
|
310
|
+
def _tune(param_space, run, dataset, checkpoint=None, **kwargs):
|
|
311
|
+
result = train_entrypoint(run, dataset, param_space, checkpoint, **kwargs)
|
|
312
|
+
if isinstance(result, Number) or isinstance(result, dict):
|
|
313
|
+
return result
|
|
314
|
+
return {'result': result}
|
|
315
|
+
|
|
316
|
+
entrypoint = _tune
|
|
317
|
+
|
|
318
|
+
trainable = tune.with_parameters(entrypoint, run=self.run, dataset=input_dataset, checkpoint=checkpoint)
|
|
319
|
+
|
|
320
|
+
tune_config = self.params['tune_config']
|
|
321
|
+
|
|
322
|
+
tune_config['search_alg'] = self.convert_tune_search_alg(tune_config)
|
|
323
|
+
tune_config['scheduler'] = self.convert_tune_scheduler(tune_config)
|
|
324
|
+
|
|
325
|
+
hyperparameter = self.params['hyperparameter']
|
|
326
|
+
param_space = self.convert_tune_params(hyperparameter)
|
|
327
|
+
temp_path = tempfile.TemporaryDirectory()
|
|
328
|
+
|
|
329
|
+
tuner = tune.Tuner(
|
|
330
|
+
tune.with_resources(trainable, resources=self.tune_resources),
|
|
331
|
+
tune_config=tune.TuneConfig(**tune_config),
|
|
332
|
+
run_config=tune.RunConfig(
|
|
333
|
+
name=f'synapse_tune_hpo_{self.job_id}',
|
|
334
|
+
log_to_file=('stdout.log', 'stderr.log'),
|
|
335
|
+
storage_path=temp_path.name,
|
|
336
|
+
),
|
|
337
|
+
param_space=param_space,
|
|
338
|
+
)
|
|
339
|
+
result = tuner.fit()
|
|
340
|
+
|
|
341
|
+
best_result = result.get_best_result()
|
|
342
|
+
|
|
343
|
+
# upload model_data
|
|
344
|
+
self.run.log_message('Registering best model data.')
|
|
345
|
+
self.run.set_progress(0, 1, category='model_upload')
|
|
346
|
+
self.create_model_from_result(best_result)
|
|
347
|
+
self.run.set_progress(1, 1, category='model_upload')
|
|
348
|
+
|
|
349
|
+
self.run.end_log()
|
|
350
|
+
|
|
351
|
+
return {'best_result': best_result.config}
|
|
352
|
+
|
|
353
|
+
@property
|
|
354
|
+
def tune_resources(self):
|
|
355
|
+
resources = {}
|
|
356
|
+
for option in ['num_cpus', 'num_gpus']:
|
|
357
|
+
option_value = self.params.get(option)
|
|
358
|
+
if option_value:
|
|
359
|
+
# Remove the 'num_' prefix and trailing s from the option name
|
|
360
|
+
resources[(lambda s: s[4:-1])(option)] = option_value
|
|
361
|
+
return resources
|
|
362
|
+
|
|
363
|
+
def create_model_from_result(self, result):
|
|
364
|
+
params = copy.deepcopy(self.params)
|
|
365
|
+
configuration_fields = ['hyperparameter']
|
|
366
|
+
configuration = {field: params.pop(field) for field in configuration_fields}
|
|
367
|
+
|
|
368
|
+
with tempfile.TemporaryDirectory() as temp_path:
|
|
369
|
+
archive_path = Path(temp_path, 'archive.zip')
|
|
370
|
+
|
|
371
|
+
# Archive tune results
|
|
372
|
+
# https://docs.ray.io/en/latest/tune/tutorials/tune_get_data_in_and_out.html#getting-data-out-of-tune-using-checkpoints-other-artifacts
|
|
373
|
+
archive(result.path, archive_path)
|
|
374
|
+
|
|
375
|
+
return self.client.create_model({
|
|
376
|
+
'plugin': self.plugin_release.plugin,
|
|
377
|
+
'version': self.plugin_release.version,
|
|
378
|
+
'file': str(archive_path),
|
|
379
|
+
'configuration': configuration,
|
|
380
|
+
**params,
|
|
381
|
+
})
|
|
382
|
+
|
|
383
|
+
    @staticmethod
    def convert_tune_scheduler(tune_config):
        """
        Convert YAML hyperparameter configuration to a Ray Tune scheduler.

        Args:
            tune_config (dict): Hyperparameter configuration; reads the
                optional 'scheduler' entry ({'name': ..., 'options': ...}).

        Returns:
            object: Ray Tune scheduler instance, or None when no scheduler
            is configured.

        Raises:
            ValueError: If the configured scheduler name is not supported.

        Supported schedulers:
            - 'fifo': FIFOScheduler (default)
            - 'asha': ASHAScheduler
            - 'hyperband': HyperBandScheduler
            - 'pbt': PopulationBasedTraining
            - 'median': MedianStoppingRule
        """

        from ray.tune.schedulers import (
            ASHAScheduler,
            FIFOScheduler,
            HyperBandScheduler,
            MedianStoppingRule,
            PopulationBasedTraining,
        )

        if tune_config.get('scheduler') is None:
            return None

        # Lowercased scheduler name -> Ray Tune scheduler class.
        scheduler_map = {
            'fifo': FIFOScheduler,
            'asha': ASHAScheduler,
            'hyperband': HyperBandScheduler,
            'pbt': PopulationBasedTraining,
            'median': MedianStoppingRule,
        }

        scheduler_type = tune_config['scheduler'].get('name', 'fifo').lower()
        scheduler_class = scheduler_map.get(scheduler_type)

        if scheduler_class is None:
            raise ValueError(
                f'Unsupported scheduler: {scheduler_type}. Supported schedulers are: {", ".join(scheduler_map.keys())}'
            )

        # Forward scheduler-specific options when present; otherwise call the
        # default constructor.
        options = tune_config['scheduler'].get('options')

        # Only pass options when they are neither None nor empty.
        scheduler = scheduler_class(**options) if options else scheduler_class()

        return scheduler
|
|
433
|
+
|
|
434
|
+
@staticmethod
|
|
435
|
+
def convert_tune_search_alg(tune_config):
|
|
436
|
+
"""
|
|
437
|
+
Convert YAML hyperparameter configuration to Ray Tune search algorithm.
|
|
438
|
+
|
|
439
|
+
Args:
|
|
440
|
+
tune_config (dict): Hyperparameter configuration.
|
|
441
|
+
|
|
442
|
+
Returns:
|
|
443
|
+
object: Ray Tune search algorithm instance or None
|
|
444
|
+
|
|
445
|
+
Supported search algorithms:
|
|
446
|
+
- 'bayesoptsearch': Bayesian optimization
|
|
447
|
+
- 'hyperoptsearch': Tree-structured Parzen Estimator
|
|
448
|
+
- 'basicvariantgenerator': Random search (default)
|
|
449
|
+
"""
|
|
450
|
+
|
|
451
|
+
if tune_config.get('search_alg') is None:
|
|
452
|
+
return None
|
|
453
|
+
|
|
454
|
+
search_alg_name = tune_config['search_alg']['name'].lower()
|
|
455
|
+
metric = tune_config['metric']
|
|
456
|
+
mode = tune_config['mode']
|
|
457
|
+
points_to_evaluate = tune_config['search_alg'].get('points_to_evaluate', None)
|
|
458
|
+
|
|
459
|
+
if search_alg_name == 'axsearch':
|
|
460
|
+
from ray.tune.search.ax import AxSearch
|
|
461
|
+
|
|
462
|
+
search_alg = AxSearch(metric=metric, mode=mode)
|
|
463
|
+
elif search_alg_name == 'bayesoptsearch':
|
|
464
|
+
from ray.tune.search.bayesopt import BayesOptSearch
|
|
465
|
+
|
|
466
|
+
search_alg = BayesOptSearch(metric=metric, mode=mode)
|
|
467
|
+
elif search_alg_name == 'hyperoptsearch':
|
|
468
|
+
from ray.tune.search.hyperopt import HyperOptSearch
|
|
469
|
+
|
|
470
|
+
search_alg = HyperOptSearch(metric=metric, mode=mode)
|
|
471
|
+
elif search_alg_name == 'optunasearch':
|
|
472
|
+
from ray.tune.search.optuna import OptunaSearch
|
|
473
|
+
|
|
474
|
+
search_alg = OptunaSearch(metric=metric, mode=mode)
|
|
475
|
+
elif search_alg_name == 'basicvariantgenerator':
|
|
476
|
+
from ray.tune.search.basic_variant import BasicVariantGenerator
|
|
477
|
+
|
|
478
|
+
search_alg = BasicVariantGenerator(points_to_evaluate=points_to_evaluate)
|
|
479
|
+
else:
|
|
480
|
+
raise ValueError(
|
|
481
|
+
f'Unsupported search algorithm: {search_alg_name}. '
|
|
482
|
+
f'Supported algorithms are: bayesoptsearch, hyperoptsearch, basicvariantgenerator'
|
|
483
|
+
)
|
|
484
|
+
|
|
485
|
+
return search_alg
|
|
486
|
+
|
|
487
|
+
@staticmethod
def convert_tune_params(param_list):
    """Convert YAML hyperparameter configuration to a Ray Tune parameter dict.

    Args:
        param_list (list): List of hyperparameter configurations. Each item
            is a dict with 'name', 'type', and the type-specific keys
            ('min'/'max', 'mean'/'sd', 'q', 'base', or 'options').

    Returns:
        dict: Ray Tune parameter search space keyed by parameter name.

    Raises:
        ValueError: If a parameter entry has an unknown 'type'.
    """
    from ray import tune

    # Dispatch table mapping the YAML 'type' to the matching Ray Tune
    # sampling function. Quantized variants (q*) require a 'q' step size and
    # log-based variants take 'base' as the logarithm base.
    param_handlers = {
        'uniform': lambda p: tune.uniform(p['min'], p['max']),
        # Fixed: tune.quniform requires the quantization step 'q'.
        'quniform': lambda p: tune.quniform(p['min'], p['max'], p['q']),
        'loguniform': lambda p: tune.loguniform(p['min'], p['max'], p['base']),
        # Fixed: 'base' was previously passed into the 'q' positional slot.
        'qloguniform': lambda p: tune.qloguniform(p['min'], p['max'], p['q'], base=p['base']),
        'randn': lambda p: tune.randn(p['mean'], p['sd']),
        # Fixed: tune.qrandn requires the quantization step 'q'.
        'qrandn': lambda p: tune.qrandn(p['mean'], p['sd'], p['q']),
        'randint': lambda p: tune.randint(p['min'], p['max']),
        # Fixed: tune.qrandint requires the quantization step 'q'.
        'qrandint': lambda p: tune.qrandint(p['min'], p['max'], p['q']),
        'lograndint': lambda p: tune.lograndint(p['min'], p['max'], p['base']),
        # Fixed: 'base' was previously passed into the 'q' positional slot.
        'qlograndint': lambda p: tune.qlograndint(p['min'], p['max'], p['q'], base=p['base']),
        'choice': lambda p: tune.choice(p['options']),
        'grid_search': lambda p: tune.grid_search(p['options']),
    }

    param_space = {}

    for param in param_list:
        name = param['name']
        param_type = param['type']

        if param_type in param_handlers:
            param_space[name] = param_handlers[param_type](param)
        else:
            raise ValueError(f'Unknown parameter type: {param_type}')

    return param_space
|
|
527
|
+
|
|
528
|
+
@staticmethod
def _tune_override_exists(module_path='plugin.tune') -> bool:
    """Return True when a tune override module is importable at *module_path*.

    Args:
        module_path (str): Dotted path of the candidate override module.

    Returns:
        bool: True if the module can be imported, False otherwise.
    """
    try:
        import_string(module_path)
    except ImportError:
        return False
    return True
|
|
@@ -1,18 +1,46 @@
|
|
|
1
|
-
|
|
1
|
+
name: plugin_name
|
|
2
|
+
code: plugin_code
|
|
3
|
+
version: plugin_version
|
|
4
|
+
readme: README.md
|
|
5
|
+
description: This is plugin_name plugin
|
|
6
|
+
category: neural_net
|
|
7
|
+
tasks:
|
|
8
|
+
- image.object_detection
|
|
2
9
|
data_type: image
|
|
10
|
+
package_manager: uv
|
|
3
11
|
actions:
|
|
4
12
|
train:
|
|
5
13
|
dataset: dataset
|
|
6
14
|
entrypoint: plugin.train.train
|
|
7
15
|
metrics:
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
16
|
+
train:
|
|
17
|
+
epoch:
|
|
18
|
+
loss: # Use the plugin's internal variable as the key, and the user-facing title in name.
|
|
19
|
+
name: Loss
|
|
20
|
+
validation:
|
|
21
|
+
epoch:
|
|
22
|
+
acc:
|
|
23
|
+
name: Accuracy
|
|
24
|
+
hyperparameters:
|
|
25
|
+
ui_schema: |
|
|
26
|
+
Dumped FormKit Schema for hyperparameters
|
|
27
|
+
visualizations:
|
|
28
|
+
validation_samples_per_epochs: # key must match the name passed to log_visualization
|
|
29
|
+
type: vis_type
|
|
30
|
+
name: user-facing title
|
|
31
|
+
options:
|
|
32
|
+
group_name: Epoch
|
|
33
|
+
thumbnail_size: [50, 50]
|
|
34
|
+
options:
|
|
35
|
+
visualize: false # Whether to visualize the training process
|
|
11
36
|
deployment:
|
|
12
37
|
entrypoint: plugin.inference.MockNetInference
|
|
13
38
|
inference:
|
|
14
39
|
method: restapi
|
|
15
40
|
endpoints:
|
|
16
41
|
- method: get
|
|
42
|
+
required_resources: # Specify required resources for inference deployment
|
|
43
|
+
required_cpu_count: 1
|
|
44
|
+
required_gpu_count: 0.1
|
|
17
45
|
test:
|
|
18
46
|
entrypoint: plugin.test.test
|
|
@@ -1,14 +1,30 @@
|
|
|
1
|
-
from
|
|
1
|
+
from pydantic import BaseModel
|
|
2
2
|
|
|
3
|
+
# for load file with synapse
|
|
4
|
+
# from synapse_sdk.types import FileField
|
|
5
|
+
from synapse_sdk.plugins.categories.neural_net.base.inference import BaseInference, app
|
|
3
6
|
|
|
4
|
-
class MockNetInference:
|
|
5
|
-
model_id = None
|
|
6
7
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
return model_id
|
|
8
|
+
class InputData(BaseModel):  # Pydantic
    """Request body schema for the inference endpoint."""

    # Raw input forwarded to the model by the '/' endpoint.
    input_string: str
|
|
10
10
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
11
|
+
|
|
12
|
+
class ResNetInference(BaseInference):
    """Template inference service built on the SDK's neural_net base.

    Endpoints are registered on the shared ``app`` object imported from the
    SDK's inference base module.
    """

    async def _get_model(self, model):  # Load model
        # NOTE(review): only the 'path' key of ``model`` is read here —
        # presumably a directory containing model weights; confirm with caller.
        model_directory_path = model['path']

        # implement model_load code
        # Placeholder: the "loaded model" is just the path string until real
        # loading code is implemented.
        model = model_directory_path

        return model  # return loaded_model

    @app.post('/load_model')
    async def load_model(self):
        # Warm-up endpoint: triggers model loading/caching via the base class.
        await self.get_model()

    @app.post('/')
    async def infer(self, data: InputData):
        model = await self.get_model()
        results = model(data.input_string)  # This is Sample code. implement your model's prediction code

        return results
|
|
@@ -4,7 +4,7 @@ from synapse_sdk.plugins.enums import PluginCategory, RunMethod
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
@register_action
|
|
7
|
-
class
|
|
8
|
-
name = '
|
|
9
|
-
category = PluginCategory.
|
|
10
|
-
method = RunMethod.
|
|
7
|
+
class PreAnnotationAction(Action):
|
|
8
|
+
name = 'pre_annotation'
|
|
9
|
+
category = PluginCategory.PRE_ANNOTATION
|
|
10
|
+
method = RunMethod.TASK
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Public API of the pre-annotation to-task package.

Re-exports the core action/run classes plus advanced extension points
(orchestrator, strategy factory, context) so callers can import them
directly from this package.
"""

from .action import ToTaskAction
from .enums import AnnotateTaskDataStatus, AnnotationMethod, LogCode
from .exceptions import CriticalError, PreAnnotationToTaskFailed

# Advanced imports for extending the system
from .factory import ToTaskStrategyFactory
from .models import MetricsRecord, ToTaskParams, ToTaskResult
from .orchestrator import ToTaskOrchestrator
from .run import ToTaskRun
from .strategies.base import ToTaskContext

__all__ = [
    # Core public API (maintains backward compatibility)
    'ToTaskAction',
    'ToTaskRun',
    'ToTaskParams',
    'ToTaskResult',
    'AnnotationMethod',
    'AnnotateTaskDataStatus',
    'LogCode',
    'CriticalError',
    'PreAnnotationToTaskFailed',
    'MetricsRecord',
    # Advanced components for customization and testing
    'ToTaskOrchestrator',
    'ToTaskContext',
    'ToTaskStrategyFactory',
]
|