synapse-sdk 1.0.0b5__py3-none-any.whl → 2025.12.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (167)
  1. synapse_sdk/__init__.py +24 -0
  2. synapse_sdk/cli/code_server.py +305 -33
  3. synapse_sdk/clients/agent/__init__.py +2 -1
  4. synapse_sdk/clients/agent/container.py +143 -0
  5. synapse_sdk/clients/agent/ray.py +296 -38
  6. synapse_sdk/clients/backend/annotation.py +1 -1
  7. synapse_sdk/clients/backend/core.py +31 -4
  8. synapse_sdk/clients/backend/data_collection.py +82 -7
  9. synapse_sdk/clients/backend/hitl.py +1 -1
  10. synapse_sdk/clients/backend/ml.py +1 -1
  11. synapse_sdk/clients/base.py +211 -61
  12. synapse_sdk/loggers.py +46 -0
  13. synapse_sdk/plugins/README.md +1340 -0
  14. synapse_sdk/plugins/categories/base.py +59 -9
  15. synapse_sdk/plugins/categories/export/actions/__init__.py +3 -0
  16. synapse_sdk/plugins/categories/export/actions/export/__init__.py +28 -0
  17. synapse_sdk/plugins/categories/export/actions/export/action.py +165 -0
  18. synapse_sdk/plugins/categories/export/actions/export/enums.py +113 -0
  19. synapse_sdk/plugins/categories/export/actions/export/exceptions.py +53 -0
  20. synapse_sdk/plugins/categories/export/actions/export/models.py +74 -0
  21. synapse_sdk/plugins/categories/export/actions/export/run.py +195 -0
  22. synapse_sdk/plugins/categories/export/actions/export/utils.py +187 -0
  23. synapse_sdk/plugins/categories/export/templates/config.yaml +19 -1
  24. synapse_sdk/plugins/categories/export/templates/plugin/__init__.py +390 -0
  25. synapse_sdk/plugins/categories/export/templates/plugin/export.py +153 -177
  26. synapse_sdk/plugins/categories/neural_net/actions/train.py +1130 -32
  27. synapse_sdk/plugins/categories/neural_net/actions/tune.py +157 -4
  28. synapse_sdk/plugins/categories/neural_net/templates/config.yaml +7 -4
  29. synapse_sdk/plugins/categories/pre_annotation/actions/__init__.py +4 -0
  30. synapse_sdk/plugins/categories/pre_annotation/actions/pre_annotation/__init__.py +3 -0
  31. synapse_sdk/plugins/categories/pre_annotation/actions/pre_annotation/action.py +10 -0
  32. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/__init__.py +28 -0
  33. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/action.py +148 -0
  34. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/enums.py +269 -0
  35. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/exceptions.py +14 -0
  36. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/factory.py +76 -0
  37. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/models.py +100 -0
  38. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/orchestrator.py +248 -0
  39. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/run.py +64 -0
  40. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/__init__.py +17 -0
  41. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/annotation.py +265 -0
  42. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/base.py +170 -0
  43. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/extraction.py +83 -0
  44. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/metrics.py +92 -0
  45. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/preprocessor.py +243 -0
  46. synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/validation.py +143 -0
  47. synapse_sdk/plugins/categories/upload/actions/upload/__init__.py +19 -0
  48. synapse_sdk/plugins/categories/upload/actions/upload/action.py +236 -0
  49. synapse_sdk/plugins/categories/upload/actions/upload/context.py +185 -0
  50. synapse_sdk/plugins/categories/upload/actions/upload/enums.py +493 -0
  51. synapse_sdk/plugins/categories/upload/actions/upload/exceptions.py +36 -0
  52. synapse_sdk/plugins/categories/upload/actions/upload/factory.py +138 -0
  53. synapse_sdk/plugins/categories/upload/actions/upload/models.py +214 -0
  54. synapse_sdk/plugins/categories/upload/actions/upload/orchestrator.py +183 -0
  55. synapse_sdk/plugins/categories/upload/actions/upload/registry.py +113 -0
  56. synapse_sdk/plugins/categories/upload/actions/upload/run.py +179 -0
  57. synapse_sdk/plugins/categories/upload/actions/upload/steps/__init__.py +1 -0
  58. synapse_sdk/plugins/categories/upload/actions/upload/steps/base.py +107 -0
  59. synapse_sdk/plugins/categories/upload/actions/upload/steps/cleanup.py +62 -0
  60. synapse_sdk/plugins/categories/upload/actions/upload/steps/collection.py +63 -0
  61. synapse_sdk/plugins/categories/upload/actions/upload/steps/generate.py +91 -0
  62. synapse_sdk/plugins/categories/upload/actions/upload/steps/initialize.py +82 -0
  63. synapse_sdk/plugins/categories/upload/actions/upload/steps/metadata.py +235 -0
  64. synapse_sdk/plugins/categories/upload/actions/upload/steps/organize.py +201 -0
  65. synapse_sdk/plugins/categories/upload/actions/upload/steps/upload.py +104 -0
  66. synapse_sdk/plugins/categories/upload/actions/upload/steps/validate.py +71 -0
  67. synapse_sdk/plugins/categories/upload/actions/upload/strategies/__init__.py +1 -0
  68. synapse_sdk/plugins/categories/upload/actions/upload/strategies/base.py +82 -0
  69. synapse_sdk/plugins/categories/upload/actions/upload/strategies/data_unit/__init__.py +1 -0
  70. synapse_sdk/plugins/categories/upload/actions/upload/strategies/data_unit/batch.py +39 -0
  71. synapse_sdk/plugins/categories/upload/actions/upload/strategies/data_unit/single.py +29 -0
  72. synapse_sdk/plugins/categories/upload/actions/upload/strategies/file_discovery/__init__.py +1 -0
  73. synapse_sdk/plugins/categories/upload/actions/upload/strategies/file_discovery/flat.py +300 -0
  74. synapse_sdk/plugins/categories/upload/actions/upload/strategies/file_discovery/recursive.py +287 -0
  75. synapse_sdk/plugins/categories/upload/actions/upload/strategies/metadata/__init__.py +1 -0
  76. synapse_sdk/plugins/categories/upload/actions/upload/strategies/metadata/excel.py +174 -0
  77. synapse_sdk/plugins/categories/upload/actions/upload/strategies/metadata/none.py +16 -0
  78. synapse_sdk/plugins/categories/upload/actions/upload/strategies/upload/__init__.py +1 -0
  79. synapse_sdk/plugins/categories/upload/actions/upload/strategies/upload/sync.py +84 -0
  80. synapse_sdk/plugins/categories/upload/actions/upload/strategies/validation/__init__.py +1 -0
  81. synapse_sdk/plugins/categories/upload/actions/upload/strategies/validation/default.py +60 -0
  82. synapse_sdk/plugins/categories/upload/actions/upload/utils.py +250 -0
  83. synapse_sdk/plugins/categories/upload/templates/README.md +470 -0
  84. synapse_sdk/plugins/categories/upload/templates/config.yaml +28 -2
  85. synapse_sdk/plugins/categories/upload/templates/plugin/__init__.py +310 -0
  86. synapse_sdk/plugins/categories/upload/templates/plugin/upload.py +82 -20
  87. synapse_sdk/plugins/models.py +111 -9
  88. synapse_sdk/plugins/templates/plugin-config-schema.json +7 -0
  89. synapse_sdk/plugins/templates/schema.json +7 -0
  90. synapse_sdk/plugins/utils/__init__.py +3 -0
  91. synapse_sdk/plugins/utils/ray_gcs.py +66 -0
  92. synapse_sdk/shared/__init__.py +25 -0
  93. synapse_sdk/utils/converters/dm/__init__.py +42 -41
  94. synapse_sdk/utils/converters/dm/base.py +137 -0
  95. synapse_sdk/utils/converters/dm/from_v1.py +208 -562
  96. synapse_sdk/utils/converters/dm/to_v1.py +258 -304
  97. synapse_sdk/utils/converters/dm/tools/__init__.py +214 -0
  98. synapse_sdk/utils/converters/dm/tools/answer.py +95 -0
  99. synapse_sdk/utils/converters/dm/tools/bounding_box.py +132 -0
  100. synapse_sdk/utils/converters/dm/tools/bounding_box_3d.py +121 -0
  101. synapse_sdk/utils/converters/dm/tools/classification.py +75 -0
  102. synapse_sdk/utils/converters/dm/tools/keypoint.py +117 -0
  103. synapse_sdk/utils/converters/dm/tools/named_entity.py +111 -0
  104. synapse_sdk/utils/converters/dm/tools/polygon.py +122 -0
  105. synapse_sdk/utils/converters/dm/tools/polyline.py +124 -0
  106. synapse_sdk/utils/converters/dm/tools/prompt.py +94 -0
  107. synapse_sdk/utils/converters/dm/tools/relation.py +86 -0
  108. synapse_sdk/utils/converters/dm/tools/segmentation.py +141 -0
  109. synapse_sdk/utils/converters/dm/tools/segmentation_3d.py +83 -0
  110. synapse_sdk/utils/converters/dm/types.py +168 -0
  111. synapse_sdk/utils/converters/dm/utils.py +162 -0
  112. synapse_sdk/utils/converters/dm_legacy/__init__.py +56 -0
  113. synapse_sdk/utils/converters/dm_legacy/from_v1.py +627 -0
  114. synapse_sdk/utils/converters/dm_legacy/to_v1.py +367 -0
  115. synapse_sdk/utils/file/__init__.py +58 -0
  116. synapse_sdk/utils/file/archive.py +32 -0
  117. synapse_sdk/utils/file/checksum.py +56 -0
  118. synapse_sdk/utils/file/chunking.py +31 -0
  119. synapse_sdk/utils/file/download.py +385 -0
  120. synapse_sdk/utils/file/encoding.py +40 -0
  121. synapse_sdk/utils/file/io.py +22 -0
  122. synapse_sdk/utils/file/upload.py +165 -0
  123. synapse_sdk/utils/file/video/__init__.py +29 -0
  124. synapse_sdk/utils/file/video/transcode.py +307 -0
  125. synapse_sdk/utils/{file.py → file.py.backup} +77 -0
  126. synapse_sdk/utils/network.py +272 -0
  127. synapse_sdk/utils/storage/__init__.py +6 -2
  128. synapse_sdk/utils/storage/providers/file_system.py +6 -0
  129. {synapse_sdk-1.0.0b5.dist-info → synapse_sdk-2025.12.3.dist-info}/METADATA +19 -2
  130. {synapse_sdk-1.0.0b5.dist-info → synapse_sdk-2025.12.3.dist-info}/RECORD +134 -74
  131. synapse_sdk/devtools/docs/.gitignore +0 -20
  132. synapse_sdk/devtools/docs/README.md +0 -41
  133. synapse_sdk/devtools/docs/blog/2019-05-28-first-blog-post.md +0 -12
  134. synapse_sdk/devtools/docs/blog/2019-05-29-long-blog-post.md +0 -44
  135. synapse_sdk/devtools/docs/blog/2021-08-01-mdx-blog-post.mdx +0 -24
  136. synapse_sdk/devtools/docs/blog/2021-08-26-welcome/docusaurus-plushie-banner.jpeg +0 -0
  137. synapse_sdk/devtools/docs/blog/2021-08-26-welcome/index.md +0 -29
  138. synapse_sdk/devtools/docs/blog/authors.yml +0 -25
  139. synapse_sdk/devtools/docs/blog/tags.yml +0 -19
  140. synapse_sdk/devtools/docs/docusaurus.config.ts +0 -138
  141. synapse_sdk/devtools/docs/package-lock.json +0 -17455
  142. synapse_sdk/devtools/docs/package.json +0 -47
  143. synapse_sdk/devtools/docs/sidebars.ts +0 -44
  144. synapse_sdk/devtools/docs/src/components/HomepageFeatures/index.tsx +0 -71
  145. synapse_sdk/devtools/docs/src/components/HomepageFeatures/styles.module.css +0 -11
  146. synapse_sdk/devtools/docs/src/css/custom.css +0 -30
  147. synapse_sdk/devtools/docs/src/pages/index.module.css +0 -23
  148. synapse_sdk/devtools/docs/src/pages/index.tsx +0 -21
  149. synapse_sdk/devtools/docs/src/pages/markdown-page.md +0 -7
  150. synapse_sdk/devtools/docs/static/.nojekyll +0 -0
  151. synapse_sdk/devtools/docs/static/img/docusaurus-social-card.jpg +0 -0
  152. synapse_sdk/devtools/docs/static/img/docusaurus.png +0 -0
  153. synapse_sdk/devtools/docs/static/img/favicon.ico +0 -0
  154. synapse_sdk/devtools/docs/static/img/logo.png +0 -0
  155. synapse_sdk/devtools/docs/static/img/undraw_docusaurus_mountain.svg +0 -171
  156. synapse_sdk/devtools/docs/static/img/undraw_docusaurus_react.svg +0 -170
  157. synapse_sdk/devtools/docs/static/img/undraw_docusaurus_tree.svg +0 -40
  158. synapse_sdk/devtools/docs/tsconfig.json +0 -8
  159. synapse_sdk/plugins/categories/export/actions/export.py +0 -346
  160. synapse_sdk/plugins/categories/export/enums.py +0 -7
  161. synapse_sdk/plugins/categories/neural_net/actions/gradio.py +0 -151
  162. synapse_sdk/plugins/categories/pre_annotation/actions/to_task.py +0 -943
  163. synapse_sdk/plugins/categories/upload/actions/upload.py +0 -954
  164. {synapse_sdk-1.0.0b5.dist-info → synapse_sdk-2025.12.3.dist-info}/WHEEL +0 -0
  165. {synapse_sdk-1.0.0b5.dist-info → synapse_sdk-2025.12.3.dist-info}/entry_points.txt +0 -0
  166. {synapse_sdk-1.0.0b5.dist-info → synapse_sdk-2025.12.3.dist-info}/licenses/LICENSE +0 -0
  167. {synapse_sdk-1.0.0b5.dist-info → synapse_sdk-2025.12.3.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,11 @@
  import copy
+ import shutil
  import tempfile
- from decimal import Decimal
+ from numbers import Number
  from pathlib import Path
- from typing import Annotated
+ from typing import Annotated, Callable, Dict, Optional
 
- from pydantic import AfterValidator, BaseModel, field_validator
+ from pydantic import AfterValidator, BaseModel, field_validator, model_validator
  from pydantic_core import PydanticCustomError
 
  from synapse_sdk.clients.exceptions import ClientError
@@ -13,63 +14,378 @@ from synapse_sdk.plugins.categories.decorators import register_action
13
14
  from synapse_sdk.plugins.enums import PluginCategory, RunMethod
14
15
  from synapse_sdk.plugins.models import Run
15
16
  from synapse_sdk.utils.file import archive, get_temp_path, unarchive
17
+ from synapse_sdk.utils.module_loading import import_string
16
18
  from synapse_sdk.utils.pydantic.validators import non_blank
17
19
 
18
20
 
19
21
  class TrainRun(Run):
22
+ is_tune = False
23
+ completed_samples = 0
24
+ num_samples = 0
25
+ checkpoint_output = None
26
+
27
+ def set_progress(self, current, total, category=''):
28
+ if getattr(self, 'is_tune', False) and category == 'train':
29
+ # Ignore train progress updates in tune mode to keep trials-only bar
30
+ return
31
+ super().set_progress(current, total, category)
32
+
20
33
  def log_metric(self, category, key, value, **metrics):
21
34
  # TODO validate input via plugin config
22
- self.log('metric', {'category': category, 'key': key, 'value': value, 'metrics': metrics})
35
+ data = {'category': category, 'key': key, 'value': value, 'metrics': metrics}
36
+
37
+ # Automatically add trial_id when is_tune=True
38
+ if self.is_tune:
39
+ try:
40
+ from ray import train
41
+
42
+ context = train.get_context()
43
+ trial_id = context.get_trial_id()
44
+ if trial_id:
45
+ data['trial_id'] = trial_id
46
+ except Exception:
47
+ # If Ray context is not available, continue without trial_id
48
+ pass
49
+
50
+ self.log('metric', data)
23
51
 
24
52
  def log_visualization(self, category, group, index, image, **meta):
25
53
  # TODO validate input via plugin config
26
- self.log('visualization', {'category': category, 'group': group, 'index': index, **meta}, file=image)
54
+ data = {'category': category, 'group': group, 'index': index, **meta}
55
+
56
+ # Automatically add trial_id when is_tune=True
57
+ if self.is_tune:
58
+ try:
59
+ from ray import train
60
+
61
+ context = train.get_context()
62
+ trial_id = context.get_trial_id()
63
+ if trial_id:
64
+ data['trial_id'] = trial_id
65
+ except Exception:
66
+ # If Ray context is not available, continue without trial_id
67
+ pass
68
+
69
+ self.log('visualization', data, file=image)
70
+
71
+ def log_trials(self, data=None, *, trials=None, base=None, hyperparameters=None, metrics=None, best_trial=''):
72
+ """
73
+ Log structured Ray Tune trial progress tables.
74
+
75
+ Args:
76
+ data (dict | None): Pre-built payload to send. Should contain
77
+ ``trials`` (dict) key.
78
+ base (list[str] | None): Column names that belong to the fixed base section.
79
+ trials (dict | None): Mapping of ``trial_id`` -> structured section values.
80
+ hyperparameters (list[str] | None): Column names belonging to hyperparameters.
81
+ metrics (list[str] | None): Column names belonging to metrics.
82
+ best_trial (str): Trial ID of the best trial (empty string during tuning, populated at the end).
83
+
84
+ Returns:
85
+ dict: The payload that was logged.
86
+ """
87
+ if data is None:
88
+ data = {
89
+ 'base': base or [],
90
+ 'trials': trials or {},
91
+ 'hyperparameters': hyperparameters or [],
92
+ 'metrics': metrics or [],
93
+ 'best_trial': best_trial,
94
+ }
95
+ elif not isinstance(data, dict):
96
+ raise ValueError('log_trials expects a dictionary payload')
97
+
98
+ if 'trials' not in data:
99
+ raise ValueError('log_trials payload must include "trials" key')
100
+
101
+ data.setdefault('base', base or [])
102
+ data.setdefault('hyperparameters', hyperparameters or [])
103
+ data.setdefault('metrics', metrics or [])
104
+ data.setdefault('best_trial', best_trial)
105
+
106
+ self.log('trials', data)
107
+ # Keep track of the last snapshot so we can reuse it (e.g., when finalizing best_trial)
108
+ try:
109
+ self._last_trials_payload = copy.deepcopy(data)
110
+ except Exception:
111
+ self._last_trials_payload = data
112
+ return data
27
113
 
28
114
 
29
- class Hyperparameter(BaseModel):
30
- batch_size: int
31
- epochs: int
32
- learning_rate: Decimal
115
+ class SearchAlgo(BaseModel):
116
+ """
117
+ Configuration for Ray Tune search algorithms.
118
+
119
+ Supported algorithms:
120
+ - 'bayesoptsearch': Bayesian optimization using Gaussian Processes
121
+ - 'hyperoptsearch': Tree-structured Parzen Estimator (TPE)
122
+ - 'basicvariantgenerator': Random search (default)
123
+
124
+ Attributes:
125
+ name (str): Name of the search algorithm (case-insensitive)
126
+ points_to_evaluate (Optional[dict]): Optional initial hyperparameter
127
+ configurations to evaluate before starting optimization
128
+
129
+ Example:
130
+ {
131
+ "name": "hyperoptsearch",
132
+ "points_to_evaluate": [
133
+ {"learning_rate": 0.001, "batch_size": 32}
134
+ ]
135
+ }
136
+ """
137
+
138
+ name: str
139
+ points_to_evaluate: Optional[dict] = None
140
+
141
+
142
+ class Scheduler(BaseModel):
143
+ """
144
+ Configuration for Ray Tune schedulers.
145
+
146
+ Supported schedulers:
147
+ - 'fifo': First-In-First-Out scheduler (default, runs all trials)
148
+ - 'hyperband': HyperBand early stopping scheduler
149
+
150
+ Attributes:
151
+ name (str): Name of the scheduler (case-insensitive)
152
+ options (Optional[str]): Optional scheduler-specific configuration parameters
153
+
154
+ Example:
155
+ {
156
+ "name": "hyperband",
157
+ "options": {
158
+ "max_t": 100,
159
+ "reduction_factor": 3
160
+ }
161
+ }
162
+ """
163
+
164
+ name: str
165
+ options: Optional[str] = None
166
+
167
+
168
+ class TuneConfig(BaseModel):
169
+ """
170
+ Configuration for Ray Tune hyperparameter optimization.
171
+
172
+ Used when is_tune=True to configure the hyperparameter search process.
173
+
174
+ Attributes:
175
+ mode (Optional[str]): Optimization mode - 'max' or 'min'
176
+ metric (Optional[str]): Name of the metric to optimize
177
+ num_samples (int): Number of hyperparameter configurations to try (default: 1)
178
+ max_concurrent_trials (Optional[int]): Maximum number of trials to run in parallel
179
+ search_alg (Optional[SearchAlgo]): Search algorithm configuration
180
+ scheduler (Optional[Scheduler]): Trial scheduler configuration
181
+
182
+ Example:
183
+ {
184
+ "mode": "max",
185
+ "metric": "accuracy",
186
+ "num_samples": 20,
187
+ "max_concurrent_trials": 4,
188
+ "search_alg": {
189
+ "name": "hyperoptsearch"
190
+ },
191
+ "scheduler": {
192
+ "name": "hyperband",
193
+ "options": {"max_t": 100}
194
+ }
195
+ }
196
+ """
197
+
198
+ mode: Optional[str] = None
199
+ metric: Optional[str] = None
200
+ num_samples: int = 1
201
+ max_concurrent_trials: Optional[int] = None
202
+ search_alg: Optional[SearchAlgo] = None
203
+ scheduler: Optional[Scheduler] = None
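For reference, a tuning payload that satisfies these models can be validated directly with pydantic. A minimal sketch, assuming the models are importable from the module this file defines (synapse_sdk.plugins.categories.neural_net.actions.train); all values are hypothetical:

    from synapse_sdk.plugins.categories.neural_net.actions.train import TuneConfig

    # Hypothetical payload; nested dicts are coerced into SearchAlgo/Scheduler models.
    cfg = TuneConfig(
        mode='max',
        metric='accuracy',
        num_samples=10,
        max_concurrent_trials=2,
        search_alg={'name': 'hyperoptsearch'},
        scheduler={'name': 'fifo'},
    )
    assert cfg.search_alg.name == 'hyperoptsearch'
    print(cfg.model_dump(exclude_none=True))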
33
204
 
34
205
 
35
206
  class TrainParams(BaseModel):
207
+ """
208
+ Parameters for TrainAction supporting both regular training and hyperparameter tuning.
209
+
210
+ Attributes:
211
+ name (str): Name for the training/tuning job
212
+ description (str): Description of the job
213
+ checkpoint (int | None): Optional checkpoint ID to resume from
214
+ dataset (int): Dataset ID to use for training
215
+ is_tune (bool): Enable hyperparameter tuning mode (default: False)
216
+ tune_config (Optional[TuneConfig]): Tune configuration (required when is_tune=True)
217
+ num_cpus (Optional[int]): CPUs per trial (tuning mode only)
218
+ num_gpus (Optional[int]): GPUs per trial (tuning mode only)
219
+ hyperparameter (Optional[Any]): Fixed hyperparameters (required when is_tune=False)
220
+ hyperparameters (Optional[list]): Hyperparameter search space (required when is_tune=True)
221
+
222
+ Hyperparameter format when is_tune=True:
223
+ Each item in hyperparameters list must have:
224
+ - 'name': Parameter name (string)
225
+ - 'type': Distribution type (string)
226
+ - Type-specific parameters:
227
+ - uniform/quniform: 'min', 'max'
228
+ - loguniform/qloguniform: 'min', 'max', 'base'
229
+ - randn/qrandn: 'mean', 'sd'
230
+ - randint/qrandint: 'min', 'max'
231
+ - lograndint/qlograndint: 'min', 'max', 'base'
232
+ - choice/grid_search: 'options'
233
+
234
+ Example (Training mode):
235
+ {
236
+ "name": "my_training",
237
+ "dataset": 123,
238
+ "is_tune": false,
239
+ "hyperparameter": {
240
+ "epochs": 100,
241
+ "batch_size": 32,
242
+ "learning_rate": 0.001
243
+ }
244
+ }
245
+
246
+ Example (Tuning mode):
247
+ {
248
+ "name": "my_tuning",
249
+ "dataset": 123,
250
+ "is_tune": true,
251
+ "hyperparameters": [
252
+ {"name": "batch_size", "type": "choice", "options": [16, 32, 64]},
253
+ {"name": "learning_rate", "type": "loguniform", "min": 0.0001, "max": 0.01, "base": 10},
254
+ {"name": "epochs", "type": "randint", "min": 5, "max": 15}
255
+ ],
256
+ "tune_config": {
257
+ "mode": "max",
258
+ "metric": "accuracy",
259
+ "num_samples": 10
260
+ }
261
+ }
262
+ """
263
+
36
264
  name: Annotated[str, AfterValidator(non_blank)]
37
265
  description: str
38
266
  checkpoint: int | None
39
267
  dataset: int
40
- hyperparameter: Hyperparameter
268
+ is_tune: bool = False
269
+ tune_config: Optional[TuneConfig] = None
270
+ num_cpus: Optional[int] = None
271
+ num_gpus: Optional[int] = None
272
+ hyperparameter: Optional[dict] = None # plan to be deprecated
273
+ hyperparameters: Optional[list] = None
274
+
275
+ @field_validator('hyperparameter', mode='before')
276
+ @classmethod
277
+ def validate_hyperparameter(cls, v, info):
278
+ """Validate hyperparameter for train mode (is_tune=False)"""
279
+ # Get is_tune flag to determine if this field should be validated
280
+ is_tune = info.data.get('is_tune', False)
281
+
282
+ # If is_tune=True, hyperparameter should be None/not used
283
+ # Just return whatever was passed (will be validated in model_validator)
284
+ if is_tune:
285
+ return v
286
+
287
+ # For train mode, hyperparameter should be a dict
288
+ if isinstance(v, dict):
289
+ return v
290
+ elif isinstance(v, list):
291
+ raise ValueError(
292
+ 'hyperparameter must be a dict, not a list. '
293
+ 'If you want to use hyperparameter tuning, '
294
+ 'set "is_tune": true and use "hyperparameters" instead.'
295
+ )
296
+ else:
297
+ raise ValueError('hyperparameter must be a dict')
298
+
299
+ @field_validator('hyperparameters', mode='before')
300
+ @classmethod
301
+ def validate_hyperparameters(cls, v, info):
302
+ """Validate hyperparameters for tune mode (is_tune=True)"""
303
+ # Get is_tune flag to determine if this field should be validated
304
+ is_tune = info.data.get('is_tune', False)
305
+
306
+ # If is_tune=False, hyperparameters should be None/not used
307
+ # Just return whatever was passed (will be validated in model_validator)
308
+ if not is_tune:
309
+ return v
310
+
311
+ # For tune mode, hyperparameters should be a list
312
+ if isinstance(v, list):
313
+ return v
314
+ elif isinstance(v, dict):
315
+ raise ValueError(
316
+ 'hyperparameters must be a list, not a dict. '
317
+ 'If you want to use fixed hyperparameters for training, '
318
+ 'set "is_tune": false and use "hyperparameter" instead.'
319
+ )
320
+ else:
321
+ raise ValueError('hyperparameters must be a list')
41
322
 
42
323
  @field_validator('name')
43
324
  @staticmethod
44
325
  def unique_name(value, info):
45
326
  action = info.context['action']
46
327
  client = action.client
328
+ is_tune = info.data.get('is_tune', False)
329
+ encoded_value = value.replace(':', '%3A').replace(',', '%2C')
47
330
  try:
48
- model_exists = client.exists('list_models', params={'name': value})
49
- job_exists = client.exists(
50
- 'list_jobs',
51
- params={
52
- 'ids_ex': action.job_id,
53
- 'category': 'neural_net',
54
- 'job__action': 'train',
55
- 'is_active': True,
56
- 'params': f'name:{value.replace(":", "%3A")}',
57
- },
58
- )
59
- assert not model_exists and not job_exists, '존재하는 학습 이름입니다.'
331
+ if not is_tune:
332
+ model_exists = client.exists('list_models', params={'name': value})
333
+ job_exists = client.exists(
334
+ 'list_jobs',
335
+ params={
336
+ 'ids_ex': action.job_id,
337
+ 'category': 'neural_net',
338
+ 'job__action': 'train',
339
+ 'is_active': True,
340
+ 'params': f'name:{encoded_value}',
341
+ },
342
+ )
343
+ assert not model_exists and not job_exists, '존재하는 학습 이름입니다.'
344
+ else:
345
+ job_exists = client.exists(
346
+ 'list_jobs',
347
+ params={
348
+ 'ids_ex': action.job_id,
349
+ 'category': 'neural_net',
350
+ 'job__action': 'train',
351
+ 'is_active': True,
352
+ 'params': f'name:{encoded_value}',
353
+ },
354
+ )
355
+ assert not job_exists, '존재하는 튜닝 작업 이름입니다.'
60
356
  except ClientError:
61
357
  raise PydanticCustomError('client_error', '')
62
358
  return value
63
359
 
360
+ @model_validator(mode='after')
361
+ def validate_tune_params(self):
362
+ if self.is_tune:
363
+ # When is_tune=True, hyperparameters is required
364
+ if self.hyperparameters is None:
365
+ raise ValueError('hyperparameters is required when is_tune=True')
366
+ if self.hyperparameter is not None:
367
+ raise ValueError('hyperparameter should not be provided when is_tune=True, use hyperparameters instead')
368
+ if self.tune_config is None:
369
+ raise ValueError('tune_config is required when is_tune=True')
370
+ else:
371
+ # When is_tune=False, either hyperparameter or hyperparameters is required
372
+ if self.hyperparameter is None and self.hyperparameters is None:
373
+ raise ValueError('Either hyperparameter or hyperparameters is required when is_tune=False')
374
+
375
+ if self.hyperparameter is not None and self.hyperparameters is not None:
376
+ raise ValueError('Provide either hyperparameter or hyperparameters, but not both')
377
+
378
+ if self.hyperparameters is not None:
379
+ if not isinstance(self.hyperparameters, list) or len(self.hyperparameters) != 1:
380
+ raise ValueError('hyperparameters must be a list containing a single dictionary')
381
+ self.hyperparameter = self.hyperparameters[0]
382
+ self.hyperparameters = None
383
+ return self
384
+
64
385
 
65
386
  @register_action
66
387
  class TrainAction(Action):
67
- name = 'train'
68
- category = PluginCategory.NEURAL_NET
69
- method = RunMethod.JOB
70
- run_class = TrainRun
71
- params_model = TrainParams
72
- progress_categories = {
388
+ TRAIN_PROGRESS = {
73
389
  'dataset': {
74
390
  'proportion': 20,
75
391
  },
@@ -81,8 +397,88 @@ class TrainAction(Action):
81
397
  },
82
398
  }
83
399
 
400
+ TUNE_PROGRESS = {
401
+ 'dataset': {
402
+ 'proportion': 20,
403
+ },
404
+ 'trials': {
405
+ 'proportion': 75,
406
+ },
407
+ 'model_upload': {
408
+ 'proportion': 5,
409
+ },
410
+ }
411
+
412
+ """
413
+ **Important notes when using train with is_tune=True:**
+
+ 1. The path to the model output (the return value of your train function) must be
+ assigned to the run object's checkpoint_output attribute **before** training starts.
+ 2. Before the training function returns, report the results to Tune.
+ 3. When supplying your own tune.py, note the difference in parameter order:
+ tune() starts with hyperparameter, run, dataset, checkpoint, **kwargs,
+ whereas train() starts with run, dataset, hyperparameter, checkpoint, **kwargs.
+ ----
+ 1)
+ Set the output path for the checkpoint so the best model can be exported:
+
+ output_path = Path('path/to/your/weights')
+ run.checkpoint_output = str(output_path)
+
+ 2)
+ Before the training function returns, report the results to Tune.
+ results_dict should contain the metrics you want to report.
+
+ Example (inside the train function):
+ results_dict = {
+ "accuracy": accuracy,
+ "loss": loss,
+ # Add other metrics as needed
+ }
+ if hasattr(self.dm_run, 'is_tune') and self.dm_run.is_tune:
+ tune.report(results_dict, checkpoint=tune.Checkpoint.from_directory(self.dm_run.checkpoint_output))
+
+ 3)
+ tune() takes hyperparameter, run, dataset, checkpoint, **kwargs in that order,
+ whereas train() takes run, dataset, hyperparameter, checkpoint, **kwargs in that order.
+
+ """
448
+
449
+ name = 'train'
450
+ category = PluginCategory.NEURAL_NET
451
+ method = RunMethod.JOB
452
+ run_class = TrainRun
453
+ params_model = TrainParams
454
+ progress_categories = None
455
+
456
+ def __init__(self, params, plugin_config, requirements=None, envs=None, job_id=None, direct=False, debug=False):
457
+ selected = self.TUNE_PROGRESS if (params or {}).get('is_tune') else self.TRAIN_PROGRESS
458
+ self.progress_categories = copy.deepcopy(selected)
459
+ super().__init__(
460
+ params, plugin_config, requirements=requirements, envs=envs, job_id=job_id, direct=direct, debug=debug
461
+ )
462
+
84
463
  def start(self):
85
- hyperparameter = self.params['hyperparameter']
464
+ try:
465
+ if self.params.get('is_tune', False):
466
+ return self._start_tune()
467
+ return self._start_train()
468
+ finally:
469
+ # Always emit completion log so backend can record end time even on failures
470
+ self.run.end_log()
471
+
472
+ def _start_train(self):
473
+ """Original train logic"""
474
+ hyperparameter = self.params.get('hyperparameter')
475
+ if hyperparameter is None:
476
+ hyperparameters = self.params.get('hyperparameters') or []
477
+ if not hyperparameters:
478
+ raise ValueError('hyperparameter is missing for train mode')
479
+ hyperparameter = hyperparameters[0]
480
+ # Persist the normalized form so later steps (e.g., create_model) find it
481
+ self.params['hyperparameter'] = hyperparameter
86
482
 
87
483
  # download dataset
88
484
  self.run.log_message('Preparing dataset for training.')
@@ -104,9 +500,311 @@ class TrainAction(Action):
104
500
  model = self.create_model(result)
105
501
  self.run.set_progress(1, 1, category='model_upload')
106
502
 
107
- self.run.end_log()
108
503
  return {'model_id': model['id'] if model else None}
109
504
 
505
+ def _start_tune(self):
506
+ """Tune logic using Ray Tune for hyperparameter optimization"""
507
+ from ray import tune
508
+
509
+ # Ensure Ray is connected to the cluster so GPU resources are visible to trials
510
+ self.ray_init()
511
+
512
+ class _TuneTrialsLoggingCallback(tune.Callback):
513
+ """Capture Ray Tune trial table snapshots and forward them to run.log_trials."""
514
+
515
+ BASE_COLUMNS = ('trial_id', 'status')
516
+ METRIC_COLUMN_LIMIT = 4
517
+ RESERVED_RESULT_KEYS = {
518
+ 'config',
519
+ 'date',
520
+ 'done',
521
+ 'experiment_id',
522
+ 'experiment_state',
523
+ 'experiment_tag',
524
+ 'hostname',
525
+ 'iterations_since_restore',
526
+ 'logdir',
527
+ 'node_ip',
528
+ 'pid',
529
+ 'restored_from_trial_id',
530
+ 'time_since_restore',
531
+ 'time_this_iter_s',
532
+ 'time_total',
533
+ 'time_total_s',
534
+ 'timestamp',
535
+ 'timesteps_since_restore',
536
+ 'timesteps_total',
537
+ 'training_iteration',
538
+ 'trial_id',
539
+ }
540
+
541
+ def __init__(self, run):
542
+ self.run = run
543
+ self.trial_rows: Dict[str, Dict[str, object]] = {}
544
+ self.config_keys: list[str] = []
545
+ self.metric_keys: list[str] = []
546
+ self._last_snapshot = None
547
+
548
+ def on_trial_result(self, iteration, trials, trial, result, **info):
549
+ self._record_trial(trial, result, status_override='RUNNING')
550
+ self._emit_snapshot()
551
+
552
+ def on_trial_complete(self, iteration, trials, trial, **info):
553
+ self._record_trial(trial, getattr(trial, 'last_result', None), status_override='TERMINATED')
554
+ self._emit_snapshot()
555
+
556
+ def on_trial_error(self, iteration, trials, trial, **info):
557
+ self._record_trial(trial, getattr(trial, 'last_result', None), status_override='ERROR')
558
+ self._emit_snapshot()
559
+
560
+ def on_step_end(self, iteration, trials, **info):
561
+ updated = False
562
+ for trial in trials or []:
563
+ status = getattr(trial, 'status', None)
564
+ existing = self.trial_rows.get(trial.trial_id)
565
+ existing_status = existing.get('status') if existing else None
566
+ if existing is None or (status and status != existing_status):
567
+ self._record_trial(
568
+ trial,
569
+ getattr(trial, 'last_result', None),
570
+ status_override=status,
571
+ )
572
+ updated = True
573
+ if updated:
574
+ self._emit_snapshot()
575
+
576
+ def _record_trial(self, trial, result, status_override=None):
577
+ if not self.run or not getattr(self.run, 'log_trials', None):
578
+ return
579
+
580
+ row = self.trial_rows.setdefault(trial.trial_id, {})
581
+ result = result or {}
582
+ if not isinstance(result, dict):
583
+ result = {}
584
+
585
+ row['trial_id'] = trial.trial_id
586
+ row['status'] = status_override or getattr(trial, 'status', 'PENDING')
587
+ config_data = self._extract_config(trial.config or {})
588
+ metric_data = self._extract_metrics(result)
589
+
590
+ row.update(config_data)
591
+ row.update(metric_data)
592
+
593
+ self._track_columns(config_data.keys(), metric_data.keys())
594
+
595
+ def _extract_config(self, config):
596
+ flat = {}
597
+ if not isinstance(config, dict):
598
+ return flat
599
+ for key, value in self._flatten_items(config):
600
+ serialized = self._serialize_config_value(value)
601
+ flat[key] = serialized
602
+ return flat
603
+
604
+ def _extract_metrics(self, result):
605
+ metrics = {}
606
+ if not isinstance(result, dict):
607
+ return metrics
608
+
609
+ nested = result.get('metrics')
610
+ if isinstance(nested, dict):
611
+ for key, value in self._flatten_items(nested, prefix='metrics'):
612
+ serialized = self._serialize_metric_value(value)
613
+ if serialized is not None:
614
+ metrics[key] = serialized
615
+
616
+ for key, value in result.items():
617
+ if key in self.RESERVED_RESULT_KEYS or key == 'metrics':
618
+ continue
619
+ if isinstance(value, dict):
620
+ continue
621
+ serialized = self._serialize_metric_value(value)
622
+ if serialized is not None:
623
+ metrics[key] = serialized
624
+
625
+ return metrics
626
+
627
+ def _track_columns(self, config_keys, metric_keys):
628
+ for key in config_keys:
629
+ if key not in self.config_keys:
630
+ self.config_keys.append(key)
631
+ for key in metric_keys:
632
+ if key not in self.metric_keys and len(self.metric_keys) < self.METRIC_COLUMN_LIMIT:
633
+ self.metric_keys.append(key)
634
+
635
+ def _emit_snapshot(self):
636
+ if not self.trial_rows:
637
+ return
638
+
639
+ base_keys = list(self.BASE_COLUMNS)
640
+ config_keys = list(self.config_keys)
641
+ metric_keys = list(self.metric_keys)
642
+ columns = base_keys + config_keys + metric_keys
643
+
644
+ ordered_trials = {}
645
+ flat_rows = []
646
+ for trial_id in sorted(self.trial_rows.keys()):
647
+ row = self.trial_rows[trial_id]
648
+ base_values = [row.get(column) for column in base_keys]
649
+ hyper_values = [row.get(column) for column in config_keys]
650
+ metric_values = [row.get(column) for column in metric_keys]
651
+ flat_values = base_values + hyper_values + metric_values
652
+ ordered_trials[trial_id] = {
653
+ 'base': base_values,
654
+ 'hyperparameters': hyper_values,
655
+ 'metrics': metric_values,
656
+ }
657
+ flat_rows.append((trial_id, tuple(flat_values)))
658
+
659
+ snapshot = (
660
+ tuple(columns),
661
+ tuple(flat_rows),
662
+ )
663
+ if snapshot == self._last_snapshot:
664
+ return
665
+ self._last_snapshot = snapshot
666
+
667
+ self.run.log_trials(
668
+ base=base_keys,
669
+ trials=ordered_trials,
670
+ hyperparameters=config_keys,
671
+ metrics=metric_keys,
672
+ best_trial='',
673
+ )
674
+ self._update_trials_progress()
675
+
676
+ def _flatten_items(self, data, prefix=None):
677
+ if not isinstance(data, dict):
678
+ return
679
+ for key, value in data.items():
680
+ key_str = str(key)
681
+ current = f'{prefix}/{key_str}' if prefix else key_str
682
+ if isinstance(value, dict):
683
+ yield from self._flatten_items(value, current)
684
+ else:
685
+ yield current, value
686
+
687
+ def _update_trials_progress(self):
688
+ total = getattr(self.run, 'num_samples', None)
689
+ if not total:
690
+ return
691
+
692
+ completed_statuses = {'TERMINATED', 'ERROR'}
693
+ completed = sum(1 for row in self.trial_rows.values() if row.get('status') in completed_statuses)
694
+ completed = min(completed, total)
695
+
696
+ try:
697
+ self.run.set_progress(completed, total, category='trials')
698
+ except Exception: # pragma: no cover - safeguard against logging failures
699
+ self.run.log_message('Failed to update trials progress.')
700
+
701
+ def _serialize_config_value(self, value):
702
+ if isinstance(value, (str, bool)) or value is None:
703
+ return value
704
+ if isinstance(value, Number):
705
+ return float(value) if not isinstance(value, bool) else value
706
+ return str(value)
707
+
708
+ def _serialize_metric_value(self, value):
709
+ if isinstance(value, Number):
710
+ return float(value)
711
+ return None
712
+
713
+ # Mark run as tune
714
+ self.run.is_tune = True
715
+
716
+ # download dataset
717
+ self.run.log_message('Preparing dataset for hyperparameter tuning.')
718
+ input_dataset = self.get_dataset()
719
+
720
+ # retrieve checkpoint
721
+ checkpoint = None
722
+ if self.params['checkpoint']:
723
+ self.run.log_message('Retrieving checkpoint.')
724
+ checkpoint = self.get_model(self.params['checkpoint'])
725
+
726
+ # train dataset
727
+ self.run.log_message('Starting training for hyperparameter tuning.')
728
+
729
+ # Save num_samples to TrainRun for logging
730
+ self.run.num_samples = self.params['tune_config']['num_samples']
731
+
732
+ tune_config = self.params['tune_config']
733
+
734
+ entrypoint = self.entrypoint
735
+ if not self._tune_override_exists():
736
+ # entrypoint must be train entrypoint
737
+ train_entrypoint = entrypoint
738
+
739
+ def _tune(param_space, run, dataset, checkpoint=None, **kwargs):
740
+ return train_entrypoint(run, dataset, param_space, checkpoint, **kwargs)
741
+
742
+ entrypoint = _tune
743
+
744
+ entrypoint = self._wrap_tune_entrypoint(entrypoint, tune_config.get('metric'))
745
+
746
+ train_fn = tune.with_parameters(entrypoint, run=self.run, dataset=input_dataset, checkpoint=checkpoint)
747
+
748
+ # Extract search_alg and scheduler as separate objects to avoid JSON serialization issues
749
+ search_alg = self.convert_tune_search_alg(tune_config)
750
+ scheduler = self.convert_tune_scheduler(tune_config)
751
+
752
+ # Create a copy of tune_config without non-serializable objects
753
+ tune_config_dict = {
754
+ 'mode': tune_config.get('mode'),
755
+ 'metric': tune_config.get('metric'),
756
+ 'num_samples': tune_config.get('num_samples', 1),
757
+ 'max_concurrent_trials': tune_config.get('max_concurrent_trials'),
758
+ }
759
+
760
+ # Add search_alg and scheduler to tune_config_dict only if they exist
761
+ if search_alg is not None:
762
+ tune_config_dict['search_alg'] = search_alg
763
+ if scheduler is not None:
764
+ tune_config_dict['scheduler'] = scheduler
765
+
766
+ hyperparameters = self.params['hyperparameters']
767
+ param_space = self.convert_tune_params(hyperparameters)
768
+ temp_path = tempfile.TemporaryDirectory()
769
+ trials_logger = _TuneTrialsLoggingCallback(self.run)
770
+
771
+ trainable = tune.with_resources(train_fn, {'cpu': 1, 'gpu': 0.5})
772
+ print('tune_config :', tune_config)
773
+ print('tune_config_dict :', tune_config_dict)
774
+ # print('self.tune_resources :', self.tune_resources)
775
+ # trainable = tune.with_resources(train_fn, self.tune_resources)
776
+
777
+ tuner = tune.Tuner(
778
+ trainable,
779
+ tune_config=tune.TuneConfig(**tune_config_dict),
780
+ run_config=tune.RunConfig(
781
+ name=f'synapse_tune_hpo_{self.job_id}',
782
+ log_to_file=('stdout.log', 'stderr.log'),
783
+ storage_path=temp_path.name,
784
+ callbacks=[trials_logger],
785
+ ),
786
+ param_space=param_space,
787
+ )
788
+ result = tuner.fit()
789
+
790
+ trial_models_map, trial_models_summary = self._upload_tune_trial_models(result)
791
+
792
+ best_result = result.get_best_result()
793
+ artifact_path = self._get_tune_artifact_path(best_result)
794
+ self._override_best_trial(best_result, artifact_path)
795
+
796
+ # upload model_data
797
+ self.run.log_message('Registering best model data.')
798
+ self.run.set_progress(0, 1, category='model_upload')
799
+ if artifact_path not in trial_models_map:
800
+ trial_models_map[artifact_path] = self.create_model_from_result(best_result, artifact_path=artifact_path)
801
+ self.run.set_progress(1, 1, category='model_upload')
802
+
803
+ return {
804
+ 'best_result': best_result.config,
805
+ 'trial_models': trial_models_summary,
806
+ }
807
+
110
808
  def get_dataset(self):
111
809
  client = self.run.client
112
810
  assert bool(client)
@@ -145,9 +843,20 @@ class TrainAction(Action):
145
843
  configuration_fields = ['hyperparameter']
146
844
  configuration = {field: params.pop(field) for field in configuration_fields}
147
845
 
148
- with tempfile.TemporaryDirectory() as temp_path:
846
+ run_name = params.get('name') or f'{self.plugin_release.name}-{self.job_id}'
847
+ unique_name = run_name
848
+
849
+ # Derive a stable id from the path for naming
850
+ trial_id = self._resolve_trial_id(type('Result', (), {})(), artifact_path=path)
851
+ if trial_id:
852
+ unique_name = f'{run_name}_{trial_id}'
853
+
854
+ params['name'] = unique_name
855
+
856
+ temp_dir = tempfile.mkdtemp()
857
+ try:
149
858
  input_path = Path(path)
150
- archive_path = Path(temp_path, 'archive.zip')
859
+ archive_path = Path(temp_dir, 'archive.zip')
151
860
  archive(input_path, archive_path)
152
861
 
153
862
  return self.client.create_model({
@@ -157,3 +866,392 @@ class TrainAction(Action):
157
866
  'configuration': configuration,
158
867
  **params,
159
868
  })
869
+ finally:
870
+ shutil.rmtree(temp_dir, ignore_errors=True)
871
+
872
+ @property
873
+ def tune_resources(self):
874
+ resources = {}
875
+ for option in ['num_cpus', 'num_gpus']:
876
+ option_value = self.params.get(option)
877
+ if option_value:
878
+ # Remove the 'num_' prefix and trailing s from the option name
879
+ resources[(lambda s: s[4:-1])(option)] = option_value
880
+ return resources
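A quick worked example of the key mapping performed above, with hypothetical values: the slice option[4:-1] drops the 'num_' prefix and the trailing 's', giving the Ray resource names.

    # Hypothetical params; mirrors the slicing done in tune_resources above.
    params = {'num_cpus': 8, 'num_gpus': 1}
    resources = {name[4:-1]: value for name, value in params.items() if value}
    assert resources == {'cpu': 8, 'gpu': 1}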
881
+
882
+ def _upload_tune_trial_models(self, result_grid):
883
+ trial_models = {}
884
+ trial_summaries = []
885
+
886
+ total_results = len(result_grid)
887
+
888
+ for index in range(total_results):
889
+ trial_result = result_grid[index]
890
+
891
+ if getattr(trial_result, 'error', None):
892
+ continue
893
+
894
+ artifact_path = self._get_tune_artifact_path(trial_result)
895
+ if not artifact_path:
896
+ trial_id = getattr(trial_result, 'trial_id', None)
897
+ self.run.log_message(f'Skipping model registration: no checkpoint path for trial {trial_id}')
898
+ continue
899
+
900
+ try:
901
+ model = self.create_model_from_result(trial_result, artifact_path=artifact_path)
902
+ except Exception as exc: # pragma: no cover - best effort logging
903
+ self.run.log_message(f'Failed to register model for trial at {artifact_path}: {exc}')
904
+ continue
905
+
906
+ if model:
907
+ trial_models[artifact_path] = model
908
+ trial_summaries.append({
909
+ 'trial_logdir': artifact_path,
910
+ 'model_id': model.get('id'),
911
+ 'config': getattr(trial_result, 'config', None),
912
+ 'metrics': getattr(trial_result, 'metrics', None),
913
+ })
914
+
915
+ return trial_models, trial_summaries
916
+
917
+ def _override_best_trial(self, best_result, artifact_path=None):
918
+ if not best_result:
919
+ return
920
+
921
+ best_config = getattr(best_result, 'config', None)
922
+ if not isinstance(best_config, dict):
923
+ return
924
+
925
+ if artifact_path is None:
926
+ artifact_path = self._get_tune_artifact_path(best_result)
927
+
928
+ trial_id = self._resolve_trial_id(best_result, artifact_path)
929
+
930
+ if not trial_id:
931
+ self.run.log_message('Skipping override_best_trial request: trial_id missing.')
932
+ return
933
+
934
+ payload = {'trial_id': trial_id, **best_config}
935
+
936
+ url = f'trains/{self.job_id}/override_best_trial/'
937
+ self.run.log_message(f'Calling override_best_trial: {url} payload={payload}')
938
+
939
+ try:
940
+ self.client._put(url, data=payload)
941
+ # Log trials with best_trial after successful PUT request
942
+ last_snapshot = getattr(self.run, '_last_trials_payload', None)
943
+ if isinstance(last_snapshot, dict) and 'trials' in last_snapshot:
944
+ final_snapshot = copy.deepcopy(last_snapshot)
945
+ final_snapshot['best_trial'] = trial_id
946
+ self.run.log_trials(data=final_snapshot)
947
+ else:
948
+ self.run.log_trials(best_trial=trial_id)
949
+ except ClientError as exc: # pragma: no cover - network failure should not break run
950
+ self.run.log_message(f'Failed to override best trial: {exc}')
951
+
952
+ def create_model_from_result(self, result, *, artifact_path=None):
953
+ params = copy.deepcopy(self.params)
954
+ configuration_fields = ['hyperparameters']
955
+ configuration = {field: params.pop(field) for field in configuration_fields}
956
+ configuration['tune_trial'] = {
957
+ 'config': getattr(result, 'config', None),
958
+ 'metrics': getattr(result, 'metrics', None),
959
+ 'logdir': artifact_path or getattr(result, 'path', None),
960
+ }
961
+
962
+ if artifact_path is None:
963
+ artifact_path = self._get_tune_artifact_path(result)
964
+
965
+ if not artifact_path:
966
+ raise ValueError('No checkpoint path available to create model from result.')
967
+
968
+ temp_dir = tempfile.mkdtemp()
969
+ archive_path = Path(temp_dir, 'archive.zip')
970
+
971
+ # Archive tune results
972
+ # https://docs.ray.io/en/latest/tune/tutorials/tune_get_data_in_and_out.html#getting-data-out-of-tune-using-checkpoints-other-artifacts
973
+ archive(artifact_path, archive_path)
974
+
975
+ unique_name = params.get('name') or f'{self.plugin_release.name}-{self.job_id}'
976
+ trial_id = self._resolve_trial_id(result, artifact_path)
977
+ if trial_id:
978
+ unique_name = f'{unique_name}_{trial_id}'
979
+ params['name'] = unique_name
980
+
981
+ try:
982
+ return self.client.create_model({
983
+ 'plugin': self.plugin_release.plugin,
984
+ 'version': self.plugin_release.version,
985
+ 'file': str(archive_path),
986
+ 'configuration': configuration,
987
+ **params,
988
+ })
989
+ finally:
990
+ shutil.rmtree(temp_dir, ignore_errors=True)
991
+
992
+ @staticmethod
993
+ def convert_tune_scheduler(tune_config):
994
+ """
995
+ Convert YAML hyperparameter configuration to a Ray Tune scheduler.
996
+
997
+ Args:
998
+ tune_config (dict): Hyperparameter configuration.
999
+
1000
+ Returns:
1001
+ object: Ray Tune scheduler instance.
1002
+
1003
+ Supported schedulers:
1004
+ - 'fifo': FIFOScheduler (default)
1005
+ - 'hyperband': HyperBandScheduler
1006
+ """
1007
+
1008
+ from ray.tune.schedulers import (
1009
+ ASHAScheduler,
1010
+ FIFOScheduler,
1011
+ HyperBandScheduler,
1012
+ MedianStoppingRule,
1013
+ PopulationBasedTraining,
1014
+ )
1015
+
1016
+ if tune_config.get('scheduler') is None:
1017
+ return None
1018
+
1019
+ scheduler_map = {
1020
+ 'fifo': FIFOScheduler,
1021
+ 'asha': ASHAScheduler,
1022
+ 'hyperband': HyperBandScheduler,
1023
+ 'pbt': PopulationBasedTraining,
1024
+ 'median': MedianStoppingRule,
1025
+ }
1026
+
1027
+ scheduler_type = tune_config['scheduler'].get('name', 'fifo').lower()
1028
+ scheduler_class = scheduler_map.get(scheduler_type, FIFOScheduler)
1029
+
1030
+ # Pass options to the constructor when provided; otherwise call it with defaults
1031
+ options = tune_config['scheduler'].get('options')
1032
+
1033
+ # Only pass options when it is neither None nor an empty dict
1034
+ scheduler = scheduler_class(**options) if options else scheduler_class()
1035
+
1036
+ return scheduler
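A hedged usage sketch for the scheduler mapping above, assuming TrainAction is importable from this module; the option values are hypothetical:

    from ray.tune.schedulers import HyperBandScheduler

    from synapse_sdk.plugins.categories.neural_net.actions.train import TrainAction

    # Hypothetical scheduler options, matching the Scheduler example shown earlier.
    tune_config = {'scheduler': {'name': 'hyperband', 'options': {'max_t': 100, 'reduction_factor': 3}}}
    scheduler = TrainAction.convert_tune_scheduler(tune_config)
    assert isinstance(scheduler, HyperBandScheduler)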
1037
+
1038
+ @staticmethod
1039
+ def convert_tune_search_alg(tune_config):
1040
+ """
1041
+ Convert YAML hyperparameter configuration to Ray Tune search algorithm.
1042
+
1043
+ Args:
1044
+ tune_config (dict): Hyperparameter configuration.
1045
+
1046
+ Returns:
1047
+ object: Ray Tune search algorithm instance or None
1048
+
1049
+ Supported search algorithms:
1050
+ - 'bayesoptsearch': Bayesian optimization
1051
+ - 'hyperoptsearch': Tree-structured Parzen Estimator
1052
+ - 'basicvariantgenerator': Random search (default)
1053
+ """
1054
+
1055
+ if tune_config.get('search_alg') is None:
1056
+ return None
1057
+
1058
+ search_alg_name = tune_config['search_alg']['name'].lower()
1059
+ metric = tune_config['metric']
1060
+ mode = tune_config['mode']
1061
+ points_to_evaluate = tune_config['search_alg'].get('points_to_evaluate', None)
1062
+
1063
+ if search_alg_name == 'axsearch':
1064
+ from ray.tune.search.ax import AxSearch
1065
+
1066
+ search_alg = AxSearch(metric=metric, mode=mode)
1067
+ elif search_alg_name == 'bayesoptsearch':
1068
+ from ray.tune.search.bayesopt import BayesOptSearch
1069
+
1070
+ search_alg = BayesOptSearch(metric=metric, mode=mode)
1071
+ elif search_alg_name == 'hyperoptsearch':
1072
+ from ray.tune.search.hyperopt import HyperOptSearch
1073
+
1074
+ search_alg = HyperOptSearch(metric=metric, mode=mode)
1075
+ elif search_alg_name == 'optunasearch':
1076
+ from ray.tune.search.optuna import OptunaSearch
1077
+
1078
+ search_alg = OptunaSearch(metric=metric, mode=mode)
1079
+ elif search_alg_name == 'basicvariantgenerator':
1080
+ from ray.tune.search.basic_variant import BasicVariantGenerator
1081
+
1082
+ search_alg = BasicVariantGenerator(
1083
+ points_to_evaluate=points_to_evaluate, max_concurrent=tune_config['max_concurrent_trials']
1084
+ )
1085
+ else:
1086
+ raise ValueError(
1087
+ f'Unsupported search algorithm: {search_alg_name}. '
1088
+ f'Supported algorithms are: bayesoptsearch, hyperoptsearch, basicvariantgenerator'
1089
+ )
1090
+
1091
+ return search_alg
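Similarly, a minimal sketch for the search-algorithm mapping; the metric name is hypothetical and HyperOptSearch needs the optional hyperopt package installed:

    from ray.tune.search.hyperopt import HyperOptSearch

    from synapse_sdk.plugins.categories.neural_net.actions.train import TrainAction

    # Hypothetical config fragment in the same dict form _start_tune passes in.
    tune_config = {'metric': 'accuracy', 'mode': 'max', 'search_alg': {'name': 'hyperoptsearch'}}
    search_alg = TrainAction.convert_tune_search_alg(tune_config)
    assert isinstance(search_alg, HyperOptSearch)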
1092
+
1093
+ @staticmethod
1094
+ def convert_tune_params(param_list):
1095
+ """
1096
+ Convert YAML hyperparameter configuration to Ray Tune parameter dictionary.
1097
+
1098
+ Args:
1099
+ param_list (list): List of hyperparameter configurations.
1100
+
1101
+ Returns:
1102
+ dict: Ray Tune parameter dictionary
1103
+ """
1104
+ from ray import tune
1105
+
1106
+ param_handlers = {
1107
+ 'uniform': lambda p: tune.uniform(p['min'], p['max']),
1108
+ 'quniform': lambda p: tune.quniform(p['min'], p['max']),
1109
+ 'loguniform': lambda p: tune.loguniform(p['min'], p['max'], p['base']),
1110
+ 'qloguniform': lambda p: tune.qloguniform(p['min'], p['max'], p['base']),
1111
+ 'randn': lambda p: tune.randn(p['mean'], p['sd']),
1112
+ 'qrandn': lambda p: tune.qrandn(p['mean'], p['sd']),
1113
+ 'randint': lambda p: tune.randint(p['min'], p['max']),
1114
+ 'qrandint': lambda p: tune.qrandint(p['min'], p['max']),
1115
+ 'lograndint': lambda p: tune.lograndint(p['min'], p['max'], p['base']),
1116
+ 'qlograndint': lambda p: tune.qlograndint(p['min'], p['max'], p['base']),
1117
+ 'choice': lambda p: tune.choice(p['options']),
1118
+ 'grid_search': lambda p: tune.grid_search(p['options']),
1119
+ }
1120
+
1121
+ param_space = {}
1122
+
1123
+ for param in param_list:
1124
+ name = param['name']
1125
+ param_type = param['type']
1126
+
1127
+ if param_type in param_handlers:
1128
+ param_space[name] = param_handlers[param_type](param)
1129
+ else:
1130
+ raise ValueError(f'Unknown parameter type: {param_type}')
1131
+
1132
+ return param_space
1133
+
1134
+ @staticmethod
1135
+ def _tune_override_exists(module_path='plugin.tune') -> bool:
1136
+ try:
1137
+ import_string(module_path)
1138
+ return True
1139
+ except ImportError:
1140
+ return False
1141
+
1142
+ @staticmethod
1143
+ def _resolve_trial_id(result, artifact_path: Optional[str] = None) -> Optional[str]:
1144
+ """
1145
+ Return a stable trial_id.
1146
+
1147
+ Priority:
1148
+ 1. result.trial_id (Ray provided)
1149
+ 2. metrics['trial_id'] if present
1150
+ 3. Deterministic hash of artifact_path
1151
+ """
1152
+ trial_id = getattr(result, 'trial_id', None)
1153
+ if trial_id:
1154
+ return str(trial_id)
1155
+
1156
+ metrics = getattr(result, 'metrics', None)
1157
+ if isinstance(metrics, dict):
1158
+ trial_id = metrics.get('trial_id')
1159
+ if trial_id:
1160
+ return str(trial_id)
1161
+
1162
+ if artifact_path:
1163
+ import hashlib
1164
+
1165
+ return hashlib.sha1(str(artifact_path).encode()).hexdigest()[:12]
1166
+
1167
+ return None
1168
+
1169
+ def _get_tune_artifact_path(self, result) -> Optional[str]:
1170
+ """
1171
+ Determine the artifact/checkpoint path for a Ray Tune result.
1172
+
1173
+ Priority:
1174
+ 1. checkpoint_output provided via metrics (if present)
1175
+ 2. Explicit checkpoint path reported by Ray (result.checkpoint.*)
1176
+ No fallback to result.path to avoid mixing with trial logdir.
1177
+ """
1178
+ metrics = getattr(result, 'metrics', None)
1179
+ if isinstance(metrics, dict):
1180
+ for key in ('checkpoint_output', 'checkpoint', 'result'):
1181
+ path = metrics.get(key)
1182
+ if path:
1183
+ return str(path)
1184
+
1185
+ checkpoint = getattr(result, 'checkpoint', None)
1186
+ if checkpoint:
1187
+ for attr in ('path', '_local_path', '_uri'):
1188
+ path = getattr(checkpoint, attr, None)
1189
+ if path:
1190
+ return str(path)
1191
+ try:
1192
+ tmp_dir = tempfile.mkdtemp()
1193
+ checkpoint.to_directory(tmp_dir)
1194
+ return tmp_dir
1195
+ except Exception:
1196
+ pass
1197
+
1198
+ return None
1199
+
1200
+ def _wrap_tune_entrypoint(self, entrypoint: Callable, metric_key: Optional[str]) -> Callable:
1201
+ def _wrapped(*args, **kwargs):
1202
+ last_metrics: Optional[Dict[str, float]] = None
1203
+
1204
+ try:
1205
+ from ray import tune as ray_tune
1206
+ except ImportError:
1207
+ ray_tune = None
1208
+
1209
+ if ray_tune and hasattr(ray_tune, 'report'):
1210
+ original_report = ray_tune.report
1211
+
1212
+ def caching_report(metrics, *r_args, **r_kwargs):
1213
+ nonlocal last_metrics
1214
+ if isinstance(metrics, dict):
1215
+ last_metrics = metrics.copy()
1216
+ return original_report(metrics, *r_args, **r_kwargs)
1217
+
1218
+ ray_tune.report = caching_report
1219
+ else:
1220
+ original_report = None
1221
+
1222
+ try:
1223
+ result = entrypoint(*args, **kwargs)
1224
+ finally:
1225
+ if ray_tune and original_report:
1226
+ ray_tune.report = original_report
1227
+
1228
+ payload = self._normalize_tune_result(result, metric_key)
1229
+ if last_metrics:
1230
+ merged = last_metrics.copy()
1231
+ merged.update(payload)
1232
+ payload = merged
1233
+
1234
+ if metric_key and metric_key not in payload:
1235
+ payload[metric_key] = (last_metrics or {}).get(metric_key, 0.0)
1236
+
1237
+ return payload
1238
+
1239
+ wrapper_name = getattr(entrypoint, '__name__', None)
1240
+ if wrapper_name and (wrapper_name.startswith('_') or wrapper_name == '<lambda>'):
1241
+ wrapper_name = None
1242
+ final_name = wrapper_name or f'trial_{hash(entrypoint) & 0xFFFF:X}'
1243
+ _wrapped.__name__ = final_name
1244
+ _wrapped.__qualname__ = final_name
1245
+
1246
+ return _wrapped
1247
+
1248
+ @staticmethod
1249
+ def _normalize_tune_result(result, metric_key: Optional[str]) -> Dict:
1250
+ if isinstance(result, dict):
1251
+ return result
1252
+
1253
+ if isinstance(result, Number):
1254
+ target_key = metric_key or 'result'
1255
+ return {target_key: result}
1256
+
1257
+ return {'result': result}
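A few hypothetical checks that spell out the normalization rules above (calling the private helper directly, for illustration only):

    from synapse_sdk.plugins.categories.neural_net.actions.train import TrainAction

    # Hypothetical inputs exercising each branch of _normalize_tune_result.
    assert TrainAction._normalize_tune_result({'accuracy': 0.9}, 'accuracy') == {'accuracy': 0.9}
    assert TrainAction._normalize_tune_result(0.9, 'accuracy') == {'accuracy': 0.9}
    assert TrainAction._normalize_tune_result('done', 'accuracy') == {'result': 'done'}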