synapse-sdk 1.0.0a35__py3-none-any.whl → 2025.11.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synapse-sdk might be problematic.
- synapse_sdk/__init__.py +24 -0
- synapse_sdk/cli/__init__.py +308 -5
- synapse_sdk/cli/alias/utils.py +1 -1
- synapse_sdk/cli/code_server.py +687 -0
- synapse_sdk/cli/config.py +440 -0
- synapse_sdk/cli/devtools.py +90 -0
- synapse_sdk/cli/plugin/publish.py +23 -15
- synapse_sdk/clients/agent/__init__.py +9 -3
- synapse_sdk/clients/agent/container.py +133 -0
- synapse_sdk/clients/agent/core.py +19 -0
- synapse_sdk/clients/agent/ray.py +298 -9
- synapse_sdk/clients/backend/__init__.py +28 -12
- synapse_sdk/clients/backend/annotation.py +9 -1
- synapse_sdk/clients/backend/core.py +31 -4
- synapse_sdk/clients/backend/data_collection.py +186 -0
- synapse_sdk/clients/backend/hitl.py +1 -1
- synapse_sdk/clients/backend/integration.py +4 -3
- synapse_sdk/clients/backend/ml.py +1 -1
- synapse_sdk/clients/backend/models.py +35 -1
- synapse_sdk/clients/base.py +309 -36
- synapse_sdk/clients/ray/serve.py +2 -0
- synapse_sdk/devtools/__init__.py +0 -0
- synapse_sdk/devtools/config.py +94 -0
- synapse_sdk/devtools/docs/.gitignore +20 -0
- synapse_sdk/devtools/docs/README.md +41 -0
- synapse_sdk/devtools/docs/blog/2019-05-28-first-blog-post.md +12 -0
- synapse_sdk/devtools/docs/blog/2019-05-29-long-blog-post.md +44 -0
- synapse_sdk/devtools/docs/blog/2021-08-01-mdx-blog-post.mdx +24 -0
- synapse_sdk/devtools/docs/blog/2021-08-26-welcome/docusaurus-plushie-banner.jpeg +0 -0
- synapse_sdk/devtools/docs/blog/2021-08-26-welcome/index.md +29 -0
- synapse_sdk/devtools/docs/blog/authors.yml +25 -0
- synapse_sdk/devtools/docs/blog/tags.yml +19 -0
- synapse_sdk/devtools/docs/docs/api/clients/agent.md +43 -0
- synapse_sdk/devtools/docs/docs/api/clients/annotation-mixin.md +378 -0
- synapse_sdk/devtools/docs/docs/api/clients/backend.md +420 -0
- synapse_sdk/devtools/docs/docs/api/clients/base.md +257 -0
- synapse_sdk/devtools/docs/docs/api/clients/core-mixin.md +477 -0
- synapse_sdk/devtools/docs/docs/api/clients/data-collection-mixin.md +422 -0
- synapse_sdk/devtools/docs/docs/api/clients/hitl-mixin.md +554 -0
- synapse_sdk/devtools/docs/docs/api/clients/index.md +391 -0
- synapse_sdk/devtools/docs/docs/api/clients/integration-mixin.md +571 -0
- synapse_sdk/devtools/docs/docs/api/clients/ml-mixin.md +578 -0
- synapse_sdk/devtools/docs/docs/api/clients/ray.md +342 -0
- synapse_sdk/devtools/docs/docs/api/index.md +52 -0
- synapse_sdk/devtools/docs/docs/api/plugins/categories.md +43 -0
- synapse_sdk/devtools/docs/docs/api/plugins/models.md +114 -0
- synapse_sdk/devtools/docs/docs/api/plugins/utils.md +328 -0
- synapse_sdk/devtools/docs/docs/categories.md +0 -0
- synapse_sdk/devtools/docs/docs/cli-usage.md +280 -0
- synapse_sdk/devtools/docs/docs/concepts/index.md +38 -0
- synapse_sdk/devtools/docs/docs/configuration.md +83 -0
- synapse_sdk/devtools/docs/docs/contributing.md +306 -0
- synapse_sdk/devtools/docs/docs/examples/index.md +29 -0
- synapse_sdk/devtools/docs/docs/faq.md +179 -0
- synapse_sdk/devtools/docs/docs/features/converters/index.md +455 -0
- synapse_sdk/devtools/docs/docs/features/index.md +24 -0
- synapse_sdk/devtools/docs/docs/features/utils/file.md +415 -0
- synapse_sdk/devtools/docs/docs/features/utils/network.md +378 -0
- synapse_sdk/devtools/docs/docs/features/utils/storage.md +57 -0
- synapse_sdk/devtools/docs/docs/features/utils/types.md +51 -0
- synapse_sdk/devtools/docs/docs/installation.md +94 -0
- synapse_sdk/devtools/docs/docs/introduction.md +47 -0
- synapse_sdk/devtools/docs/docs/plugins/categories/neural-net-plugins/train-action-overview.md +814 -0
- synapse_sdk/devtools/docs/docs/plugins/categories/pre-annotation-plugins/pre-annotation-plugin-overview.md +198 -0
- synapse_sdk/devtools/docs/docs/plugins/categories/pre-annotation-plugins/to-task-action-development.md +1645 -0
- synapse_sdk/devtools/docs/docs/plugins/categories/pre-annotation-plugins/to-task-overview.md +717 -0
- synapse_sdk/devtools/docs/docs/plugins/categories/pre-annotation-plugins/to-task-template-development.md +1380 -0
- synapse_sdk/devtools/docs/docs/plugins/categories/upload-plugins/upload-plugin-action.md +948 -0
- synapse_sdk/devtools/docs/docs/plugins/categories/upload-plugins/upload-plugin-overview.md +544 -0
- synapse_sdk/devtools/docs/docs/plugins/categories/upload-plugins/upload-plugin-template.md +766 -0
- synapse_sdk/devtools/docs/docs/plugins/export-plugins.md +1092 -0
- synapse_sdk/devtools/docs/docs/plugins/plugins.md +852 -0
- synapse_sdk/devtools/docs/docs/quickstart.md +78 -0
- synapse_sdk/devtools/docs/docs/troubleshooting.md +519 -0
- synapse_sdk/devtools/docs/docs/tutorial-basics/_category_.json +8 -0
- synapse_sdk/devtools/docs/docs/tutorial-basics/congratulations.md +23 -0
- synapse_sdk/devtools/docs/docs/tutorial-basics/create-a-blog-post.md +34 -0
- synapse_sdk/devtools/docs/docs/tutorial-basics/create-a-document.md +57 -0
- synapse_sdk/devtools/docs/docs/tutorial-basics/create-a-page.md +43 -0
- synapse_sdk/devtools/docs/docs/tutorial-basics/deploy-your-site.md +31 -0
- synapse_sdk/devtools/docs/docs/tutorial-basics/markdown-features.mdx +152 -0
- synapse_sdk/devtools/docs/docs/tutorial-extras/_category_.json +7 -0
- synapse_sdk/devtools/docs/docs/tutorial-extras/img/docsVersionDropdown.png +0 -0
- synapse_sdk/devtools/docs/docs/tutorial-extras/img/localeDropdown.png +0 -0
- synapse_sdk/devtools/docs/docs/tutorial-extras/manage-docs-versions.md +55 -0
- synapse_sdk/devtools/docs/docs/tutorial-extras/translate-your-site.md +88 -0
- synapse_sdk/devtools/docs/docusaurus.config.ts +148 -0
- synapse_sdk/devtools/docs/i18n/ko/code.json +325 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/clients/agent.md +43 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/clients/annotation-mixin.md +289 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/clients/backend.md +420 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/clients/base.md +257 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/clients/core-mixin.md +417 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/clients/data-collection-mixin.md +356 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/clients/hitl-mixin.md +192 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/clients/index.md +391 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/clients/integration-mixin.md +479 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/clients/ml-mixin.md +284 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/clients/ray.md +342 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/index.md +52 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/api/plugins/models.md +114 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/categories.md +0 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/cli-usage.md +280 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/concepts/index.md +38 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/configuration.md +83 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/contributing.md +306 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/examples/index.md +29 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/faq.md +179 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/features/converters/index.md +30 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/features/index.md +24 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/features/utils/file.md +415 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/features/utils/network.md +378 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/features/utils/storage.md +60 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/features/utils/types.md +51 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/installation.md +94 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/introduction.md +47 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/plugins/categories/neural-net-plugins/train-action-overview.md +815 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/plugins/categories/pre-annotation-plugins/pre-annotation-plugin-overview.md +198 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/plugins/categories/pre-annotation-plugins/to-task-action-development.md +1645 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/plugins/categories/pre-annotation-plugins/to-task-overview.md +717 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/plugins/categories/pre-annotation-plugins/to-task-template-development.md +1380 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/plugins/categories/upload-plugins/upload-plugin-action.md +948 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/plugins/categories/upload-plugins/upload-plugin-overview.md +544 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/plugins/categories/upload-plugins/upload-plugin-template.md +766 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/plugins/export-plugins.md +1092 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/plugins/plugins.md +117 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/quickstart.md +78 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current/troubleshooting.md +519 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-plugin-content-docs/current.json +34 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-theme-classic/footer.json +42 -0
- synapse_sdk/devtools/docs/i18n/ko/docusaurus-theme-classic/navbar.json +18 -0
- synapse_sdk/devtools/docs/package-lock.json +18784 -0
- synapse_sdk/devtools/docs/package.json +48 -0
- synapse_sdk/devtools/docs/sidebars.ts +122 -0
- synapse_sdk/devtools/docs/src/components/HomepageFeatures/index.tsx +71 -0
- synapse_sdk/devtools/docs/src/components/HomepageFeatures/styles.module.css +11 -0
- synapse_sdk/devtools/docs/src/css/custom.css +30 -0
- synapse_sdk/devtools/docs/src/pages/index.module.css +23 -0
- synapse_sdk/devtools/docs/src/pages/index.tsx +21 -0
- synapse_sdk/devtools/docs/src/pages/markdown-page.md +7 -0
- synapse_sdk/devtools/docs/static/.nojekyll +0 -0
- synapse_sdk/devtools/docs/static/img/docusaurus-social-card.jpg +0 -0
- synapse_sdk/devtools/docs/static/img/docusaurus.png +0 -0
- synapse_sdk/devtools/docs/static/img/favicon.ico +0 -0
- synapse_sdk/devtools/docs/static/img/logo.png +0 -0
- synapse_sdk/devtools/docs/static/img/undraw_docusaurus_mountain.svg +171 -0
- synapse_sdk/devtools/docs/static/img/undraw_docusaurus_react.svg +170 -0
- synapse_sdk/devtools/docs/static/img/undraw_docusaurus_tree.svg +40 -0
- synapse_sdk/devtools/docs/tsconfig.json +8 -0
- synapse_sdk/devtools/server.py +41 -0
- synapse_sdk/devtools/streamlit_app/__init__.py +5 -0
- synapse_sdk/devtools/streamlit_app/app.py +128 -0
- synapse_sdk/devtools/streamlit_app/services/__init__.py +11 -0
- synapse_sdk/devtools/streamlit_app/services/job_service.py +233 -0
- synapse_sdk/devtools/streamlit_app/services/plugin_service.py +236 -0
- synapse_sdk/devtools/streamlit_app/services/serve_service.py +95 -0
- synapse_sdk/devtools/streamlit_app/ui/__init__.py +15 -0
- synapse_sdk/devtools/streamlit_app/ui/config_tab.py +76 -0
- synapse_sdk/devtools/streamlit_app/ui/deployment_tab.py +66 -0
- synapse_sdk/devtools/streamlit_app/ui/http_tab.py +125 -0
- synapse_sdk/devtools/streamlit_app/ui/jobs_tab.py +573 -0
- synapse_sdk/devtools/streamlit_app/ui/serve_tab.py +346 -0
- synapse_sdk/devtools/streamlit_app/ui/status_bar.py +118 -0
- synapse_sdk/devtools/streamlit_app/utils/__init__.py +40 -0
- synapse_sdk/devtools/streamlit_app/utils/json_viewer.py +197 -0
- synapse_sdk/devtools/streamlit_app/utils/log_formatter.py +38 -0
- synapse_sdk/devtools/streamlit_app/utils/styles.py +241 -0
- synapse_sdk/devtools/streamlit_app/utils/ui_components.py +289 -0
- synapse_sdk/devtools/streamlit_app.py +10 -0
- synapse_sdk/loggers.py +65 -7
- synapse_sdk/plugins/README.md +1340 -0
- synapse_sdk/plugins/categories/base.py +73 -11
- synapse_sdk/plugins/categories/data_validation/actions/validation.py +72 -0
- synapse_sdk/plugins/categories/data_validation/templates/plugin/validation.py +33 -5
- synapse_sdk/plugins/categories/export/actions/__init__.py +3 -0
- synapse_sdk/plugins/categories/export/actions/export/__init__.py +28 -0
- synapse_sdk/plugins/categories/export/actions/export/action.py +165 -0
- synapse_sdk/plugins/categories/export/actions/export/enums.py +113 -0
- synapse_sdk/plugins/categories/export/actions/export/exceptions.py +53 -0
- synapse_sdk/plugins/categories/export/actions/export/models.py +74 -0
- synapse_sdk/plugins/categories/export/actions/export/run.py +195 -0
- synapse_sdk/plugins/categories/export/actions/{export.py → export/utils.py} +47 -82
- synapse_sdk/plugins/categories/export/templates/config.yaml +19 -1
- synapse_sdk/plugins/categories/export/templates/plugin/__init__.py +390 -0
- synapse_sdk/plugins/categories/export/templates/plugin/export.py +153 -129
- synapse_sdk/plugins/categories/neural_net/actions/deployment.py +9 -62
- synapse_sdk/plugins/categories/neural_net/actions/train.py +1062 -32
- synapse_sdk/plugins/categories/neural_net/actions/tune.py +534 -0
- synapse_sdk/plugins/categories/neural_net/templates/config.yaml +27 -5
- synapse_sdk/plugins/categories/neural_net/templates/plugin/inference.py +26 -10
- synapse_sdk/plugins/categories/pre_annotation/actions/__init__.py +4 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/pre_annotation/__init__.py +3 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/pre_annotation/action.py +10 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/__init__.py +28 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/action.py +145 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/enums.py +269 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/exceptions.py +14 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/factory.py +76 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/models.py +97 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/orchestrator.py +250 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/run.py +64 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/__init__.py +17 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/annotation.py +287 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/base.py +170 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/extraction.py +83 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/metrics.py +87 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/preprocessor.py +127 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task/strategies/validation.py +143 -0
- synapse_sdk/plugins/categories/pre_annotation/actions/to_task.py +966 -0
- synapse_sdk/plugins/categories/pre_annotation/templates/config.yaml +19 -0
- synapse_sdk/plugins/categories/pre_annotation/templates/plugin/to_task.py +40 -0
- synapse_sdk/plugins/categories/upload/actions/upload/__init__.py +19 -0
- synapse_sdk/plugins/categories/upload/actions/upload/action.py +232 -0
- synapse_sdk/plugins/categories/upload/actions/upload/context.py +185 -0
- synapse_sdk/plugins/categories/upload/actions/upload/enums.py +471 -0
- synapse_sdk/plugins/categories/upload/actions/upload/exceptions.py +36 -0
- synapse_sdk/plugins/categories/upload/actions/upload/factory.py +138 -0
- synapse_sdk/plugins/categories/upload/actions/upload/models.py +203 -0
- synapse_sdk/plugins/categories/upload/actions/upload/orchestrator.py +183 -0
- synapse_sdk/plugins/categories/upload/actions/upload/registry.py +113 -0
- synapse_sdk/plugins/categories/upload/actions/upload/run.py +179 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/base.py +107 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/cleanup.py +62 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/collection.py +63 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/generate.py +84 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/initialize.py +82 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/metadata.py +235 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/organize.py +203 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/upload.py +97 -0
- synapse_sdk/plugins/categories/upload/actions/upload/steps/validate.py +71 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/base.py +82 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/data_unit/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/data_unit/batch.py +39 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/data_unit/single.py +29 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/file_discovery/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/file_discovery/flat.py +258 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/file_discovery/recursive.py +281 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/metadata/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/metadata/excel.py +174 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/metadata/none.py +16 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/upload/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/upload/sync.py +84 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/validation/__init__.py +1 -0
- synapse_sdk/plugins/categories/upload/actions/upload/strategies/validation/default.py +60 -0
- synapse_sdk/plugins/categories/upload/actions/upload/utils.py +250 -0
- synapse_sdk/plugins/categories/upload/templates/README.md +470 -0
- synapse_sdk/plugins/categories/upload/templates/config.yaml +29 -2
- synapse_sdk/plugins/categories/upload/templates/plugin/__init__.py +294 -0
- synapse_sdk/plugins/categories/upload/templates/plugin/upload.py +88 -30
- synapse_sdk/plugins/models.py +122 -16
- synapse_sdk/plugins/templates/plugin-config-schema.json +406 -0
- synapse_sdk/plugins/templates/schema.json +491 -0
- synapse_sdk/plugins/templates/synapse-{{cookiecutter.plugin_code}}-plugin/requirements.txt +1 -1
- synapse_sdk/plugins/utils/__init__.py +46 -0
- synapse_sdk/plugins/utils/actions.py +119 -0
- synapse_sdk/plugins/utils/config.py +203 -0
- synapse_sdk/plugins/{utils.py → utils/legacy.py} +26 -46
- synapse_sdk/plugins/utils/ray_gcs.py +66 -0
- synapse_sdk/plugins/utils/registry.py +58 -0
- synapse_sdk/shared/__init__.py +25 -0
- synapse_sdk/shared/enums.py +93 -0
- synapse_sdk/utils/converters/__init__.py +240 -0
- synapse_sdk/utils/converters/coco/__init__.py +0 -0
- synapse_sdk/utils/converters/coco/from_dm.py +322 -0
- synapse_sdk/utils/converters/coco/to_dm.py +215 -0
- synapse_sdk/utils/converters/dm/__init__.py +56 -0
- synapse_sdk/utils/converters/dm/from_v1.py +627 -0
- synapse_sdk/utils/converters/dm/to_v1.py +367 -0
- synapse_sdk/utils/converters/pascal/__init__.py +0 -0
- synapse_sdk/utils/converters/pascal/from_dm.py +244 -0
- synapse_sdk/utils/converters/pascal/to_dm.py +214 -0
- synapse_sdk/utils/converters/yolo/__init__.py +0 -0
- synapse_sdk/utils/converters/yolo/from_dm.py +384 -0
- synapse_sdk/utils/converters/yolo/to_dm.py +267 -0
- synapse_sdk/utils/dataset.py +46 -0
- synapse_sdk/utils/encryption.py +158 -0
- synapse_sdk/utils/file/__init__.py +39 -0
- synapse_sdk/utils/file/archive.py +32 -0
- synapse_sdk/utils/file/checksum.py +56 -0
- synapse_sdk/utils/file/chunking.py +31 -0
- synapse_sdk/utils/file/download.py +385 -0
- synapse_sdk/utils/file/encoding.py +40 -0
- synapse_sdk/utils/file/io.py +22 -0
- synapse_sdk/utils/file/video/__init__.py +29 -0
- synapse_sdk/utils/file/video/transcode.py +307 -0
- synapse_sdk/utils/{file.py → file.py.backup} +84 -2
- synapse_sdk/utils/http.py +138 -0
- synapse_sdk/utils/network.py +293 -0
- synapse_sdk/utils/storage/__init__.py +36 -2
- synapse_sdk/utils/storage/providers/__init__.py +141 -0
- synapse_sdk/utils/storage/providers/file_system.py +134 -0
- synapse_sdk/utils/storage/providers/http.py +190 -0
- synapse_sdk/utils/storage/providers/s3.py +54 -6
- synapse_sdk/utils/storage/providers/sftp.py +31 -0
- synapse_sdk/utils/storage/registry.py +6 -0
- synapse_sdk-2025.11.7.dist-info/METADATA +122 -0
- synapse_sdk-2025.11.7.dist-info/RECORD +386 -0
- {synapse_sdk-1.0.0a35.dist-info → synapse_sdk-2025.11.7.dist-info}/WHEEL +1 -1
- synapse_sdk/clients/backend/dataset.py +0 -102
- synapse_sdk/plugins/categories/upload/actions/upload.py +0 -293
- synapse_sdk-1.0.0a35.dist-info/METADATA +0 -47
- synapse_sdk-1.0.0a35.dist-info/RECORD +0 -137
- {synapse_sdk-1.0.0a35.dist-info → synapse_sdk-2025.11.7.dist-info}/entry_points.txt +0 -0
- {synapse_sdk-1.0.0a35.dist-info → synapse_sdk-2025.11.7.dist-info}/licenses/LICENSE +0 -0
- {synapse_sdk-1.0.0a35.dist-info → synapse_sdk-2025.11.7.dist-info}/top_level.txt +0 -0
synapse_sdk/plugins/categories/neural_net/actions/train.py

@@ -1,10 +1,12 @@
 import copy
+import re
+import shutil
 import tempfile
-from
+from numbers import Number
 from pathlib import Path
-from typing import Annotated
+from typing import Annotated, Callable, Dict, Optional

-from pydantic import AfterValidator, BaseModel, field_validator
+from pydantic import AfterValidator, BaseModel, field_validator, model_validator
 from pydantic_core import PydanticCustomError

 from synapse_sdk.clients.exceptions import ClientError
@@ -13,63 +15,377 @@ from synapse_sdk.plugins.categories.decorators import register_action
 from synapse_sdk.plugins.enums import PluginCategory, RunMethod
 from synapse_sdk.plugins.models import Run
 from synapse_sdk.utils.file import archive, get_temp_path, unarchive
+from synapse_sdk.utils.module_loading import import_string
 from synapse_sdk.utils.pydantic.validators import non_blank


 class TrainRun(Run):
+    is_tune = False
+    completed_samples = 0
+    num_samples = 0
+    checkpoint_output = None
+
+    def set_progress(self, current, total, category=''):
+        if getattr(self, 'is_tune', False) and category == 'train':
+            # Ignore train progress updates in tune mode to keep trials-only bar
+            return
+        super().set_progress(current, total, category)
+
     def log_metric(self, category, key, value, **metrics):
         # TODO validate input via plugin config
-
+        data = {'category': category, 'key': key, 'value': value, 'metrics': metrics}
+
+        # Automatically add trial_id when is_tune=True
+        if self.is_tune:
+            try:
+                from ray import train

-
+                context = train.get_context()
+                trial_id = context.get_trial_id()
+                if trial_id:
+                    data['trial_id'] = trial_id
+            except Exception:
+                # If Ray context is not available, continue without trial_id
+                pass
+
+        self.log('metric', data)
+
+    def log_visualization(self, category, group, index, image, **meta):
         # TODO validate input via plugin config
-
+        data = {'category': category, 'group': group, 'index': index, **meta}
+
+        # Automatically add trial_id when is_tune=True
+        if self.is_tune:
+            try:
+                from ray import train
+
+                context = train.get_context()
+                trial_id = context.get_trial_id()
+                if trial_id:
+                    data['trial_id'] = trial_id
+            except Exception:
+                # If Ray context is not available, continue without trial_id
+                pass
+
+        self.log('visualization', data, file=image)
+
+    def log_trials(self, data=None, *, trials=None, base=None, hyperparameters=None, metrics=None, best_trial=''):
+        """
+        Log structured Ray Tune trial progress tables.
+
+        Args:
+            data (dict | None): Pre-built payload to send. Should contain
+                ``trials`` (dict) key.
+            base (list[str] | None): Column names that belong to the fixed base section.
+            trials (dict | None): Mapping of ``trial_id`` -> structured section values.
+            hyperparameters (list[str] | None): Column names belonging to hyperparameters.
+            metrics (list[str] | None): Column names belonging to metrics.
+            best_trial (str): Trial ID of the best trial (empty string during tuning, populated at the end).
+
+        Returns:
+            dict: The payload that was logged.
+        """
+        if data is None:
+            data = {
+                'base': base or [],
+                'trials': trials or {},
+                'hyperparameters': hyperparameters or [],
+                'metrics': metrics or [],
+                'best_trial': best_trial,
+            }
+        elif not isinstance(data, dict):
+            raise ValueError('log_trials expects a dictionary payload')
+
+        if 'trials' not in data:
+            raise ValueError('log_trials payload must include "trials" key')
+
+        data.setdefault('base', base or [])
+        data.setdefault('hyperparameters', hyperparameters or [])
+        data.setdefault('metrics', metrics or [])
+        data.setdefault('best_trial', best_trial)
+
+        self.log('trials', data)
+        # Keep track of the last snapshot so we can reuse it (e.g., when finalizing best_trial)
+        try:
+            self._last_trials_payload = copy.deepcopy(data)
+        except Exception:
+            self._last_trials_payload = data
+        return data
+
+
+class SearchAlgo(BaseModel):
+    """
+    Configuration for Ray Tune search algorithms.
+
+    Supported algorithms:
+    - 'bayesoptsearch': Bayesian optimization using Gaussian Processes
+    - 'hyperoptsearch': Tree-structured Parzen Estimator (TPE)
+    - 'basicvariantgenerator': Random search (default)
+
+    Attributes:
+        name (str): Name of the search algorithm (case-insensitive)
+        points_to_evaluate (Optional[dict]): Optional initial hyperparameter
+            configurations to evaluate before starting optimization
+
+    Example:
+        {
+            "name": "hyperoptsearch",
+            "points_to_evaluate": [
+                {"learning_rate": 0.001, "batch_size": 32}
+            ]
+        }
+    """
+
+    name: str
+    points_to_evaluate: Optional[dict] = None
+
+
+class Scheduler(BaseModel):
+    """
+    Configuration for Ray Tune schedulers.
+
+    Supported schedulers:
+    - 'fifo': First-In-First-Out scheduler (default, runs all trials)
+    - 'hyperband': HyperBand early stopping scheduler
+
+    Attributes:
+        name (str): Name of the scheduler (case-insensitive)
+        options (Optional[str]): Optional scheduler-specific configuration parameters
+
+    Example:
+        {
+            "name": "hyperband",
+            "options": {
+                "max_t": 100,
+                "reduction_factor": 3
+            }
+        }
+    """
+
+    name: str
+    options: Optional[str] = None
+
+
+class TuneConfig(BaseModel):
+    """
+    Configuration for Ray Tune hyperparameter optimization.

+    Used when is_tune=True to configure the hyperparameter search process.

-
-
-
-
+    Attributes:
+        mode (Optional[str]): Optimization mode - 'max' or 'min'
+        metric (Optional[str]): Name of the metric to optimize
+        num_samples (int): Number of hyperparameter configurations to try (default: 1)
+        max_concurrent_trials (Optional[int]): Maximum number of trials to run in parallel
+        search_alg (Optional[SearchAlgo]): Search algorithm configuration
+        scheduler (Optional[Scheduler]): Trial scheduler configuration
+
+    Example:
+        {
+            "mode": "max",
+            "metric": "accuracy",
+            "num_samples": 20,
+            "max_concurrent_trials": 4,
+            "search_alg": {
+                "name": "hyperoptsearch"
+            },
+            "scheduler": {
+                "name": "hyperband",
+                "options": {"max_t": 100}
+            }
+        }
+    """
+
+    mode: Optional[str] = None
+    metric: Optional[str] = None
+    num_samples: int = 1
+    max_concurrent_trials: Optional[int] = None
+    search_alg: Optional[SearchAlgo] = None
+    scheduler: Optional[Scheduler] = None


 class TrainParams(BaseModel):
+    """
+    Parameters for TrainAction supporting both regular training and hyperparameter tuning.
+
+    Attributes:
+        name (str): Name for the training/tuning job
+        description (str): Description of the job
+        checkpoint (int | None): Optional checkpoint ID to resume from
+        dataset (int): Dataset ID to use for training
+        is_tune (bool): Enable hyperparameter tuning mode (default: False)
+        tune_config (Optional[TuneConfig]): Tune configuration (required when is_tune=True)
+        num_cpus (Optional[int]): CPUs per trial (tuning mode only)
+        num_gpus (Optional[int]): GPUs per trial (tuning mode only)
+        hyperparameter (Optional[Any]): Fixed hyperparameters (required when is_tune=False)
+        hyperparameters (Optional[list]): Hyperparameter search space (required when is_tune=True)
+
+    Hyperparameter format when is_tune=True:
+        Each item in hyperparameters list must have:
+        - 'name': Parameter name (string)
+        - 'type': Distribution type (string)
+        - Type-specific parameters:
+            - uniform/quniform: 'min', 'max'
+            - loguniform/qloguniform: 'min', 'max', 'base'
+            - randn/qrandn: 'mean', 'sd'
+            - randint/qrandint: 'min', 'max'
+            - lograndint/qlograndint: 'min', 'max', 'base'
+            - choice/grid_search: 'options'
+
+    Example (Training mode):
+        {
+            "name": "my_training",
+            "dataset": 123,
+            "is_tune": false,
+            "hyperparameter": {
+                "epochs": 100,
+                "batch_size": 32,
+                "learning_rate": 0.001
+            }
+        }
+
+    Example (Tuning mode):
+        {
+            "name": "my_tuning",
+            "dataset": 123,
+            "is_tune": true,
+            "hyperparameters": [
+                {"name": "batch_size", "type": "choice", "options": [16, 32, 64]},
+                {"name": "learning_rate", "type": "loguniform", "min": 0.0001, "max": 0.01, "base": 10},
+                {"name": "epochs", "type": "randint", "min": 5, "max": 15}
+            ],
+            "tune_config": {
+                "mode": "max",
+                "metric": "accuracy",
+                "num_samples": 10
+            }
+        }
+    """
+
     name: Annotated[str, AfterValidator(non_blank)]
     description: str
     checkpoint: int | None
     dataset: int
-
+    is_tune: bool = False
+    tune_config: Optional[TuneConfig] = None
+    num_cpus: Optional[int] = None
+    num_gpus: Optional[int] = None
+    hyperparameter: Optional[dict] = None  # plan to be deprecated
+    hyperparameters: Optional[list] = None
+
+    @field_validator('hyperparameter', mode='before')
+    @classmethod
+    def validate_hyperparameter(cls, v, info):
+        """Validate hyperparameter for train mode (is_tune=False)"""
+        # Get is_tune flag to determine if this field should be validated
+        is_tune = info.data.get('is_tune', False)
+
+        # If is_tune=True, hyperparameter should be None/not used
+        # Just return whatever was passed (will be validated in model_validator)
+        if is_tune:
+            return v
+
+        # For train mode, hyperparameter should be a dict
+        if isinstance(v, dict):
+            return v
+        elif isinstance(v, list):
+            raise ValueError(
+                'hyperparameter must be a dict, not a list. '
+                'If you want to use hyperparameter tuning, '
+                'set "is_tune": true and use "hyperparameters" instead.'
+            )
+        else:
+            raise ValueError('hyperparameter must be a dict')
+
+    @field_validator('hyperparameters', mode='before')
+    @classmethod
+    def validate_hyperparameters(cls, v, info):
+        """Validate hyperparameters for tune mode (is_tune=True)"""
+        # Get is_tune flag to determine if this field should be validated
+        is_tune = info.data.get('is_tune', False)
+
+        # If is_tune=False, hyperparameters should be None/not used
+        # Just return whatever was passed (will be validated in model_validator)
+        if not is_tune:
+            return v
+
+        # For tune mode, hyperparameters should be a list
+        if isinstance(v, list):
+            return v
+        elif isinstance(v, dict):
+            raise ValueError(
+                'hyperparameters must be a list, not a dict. '
+                'If you want to use fixed hyperparameters for training, '
+                'set "is_tune": false and use "hyperparameter" instead.'
+            )
+        else:
+            raise ValueError('hyperparameters must be a list')

     @field_validator('name')
     @staticmethod
     def unique_name(value, info):
         action = info.context['action']
         client = action.client
+        is_tune = info.data.get('is_tune', False)
         try:
-
-
-
-
-
-
-
-
-
-
-
-
+            if not is_tune:
+                model_exists = client.exists('list_models', params={'name': value})
+                job_exists = client.exists(
+                    'list_jobs',
+                    params={
+                        'ids_ex': action.job_id,
+                        'category': 'neural_net',
+                        'job__action': 'train',
+                        'is_active': True,
+                        'params': f'name:{value.replace(":", "%3A")}',
+                    },
+                )
+                assert not model_exists and not job_exists, '존재하는 학습 이름입니다.'
+            else:
+                job_exists = client.exists(
+                    'list_jobs',
+                    params={
+                        'ids_ex': action.job_id,
+                        'category': 'neural_net',
+                        'job__action': 'train',
+                        'is_active': True,
+                        'params': f'name:{value}',
+                    },
+                )
+                assert not job_exists, '존재하는 튜닝 작업 이름입니다.'
         except ClientError:
             raise PydanticCustomError('client_error', '')
         return value

+    @model_validator(mode='after')
+    def validate_tune_params(self):
+        if self.is_tune:
+            # When is_tune=True, hyperparameters is required
+            if self.hyperparameters is None:
+                raise ValueError('hyperparameters is required when is_tune=True')
+            if self.hyperparameter is not None:
+                raise ValueError('hyperparameter should not be provided when is_tune=True, use hyperparameters instead')
+            if self.tune_config is None:
+                raise ValueError('tune_config is required when is_tune=True')
+        else:
+            # When is_tune=False, either hyperparameter or hyperparameters is required
+            if self.hyperparameter is None and self.hyperparameters is None:
+                raise ValueError('Either hyperparameter or hyperparameters is required when is_tune=False')
+
+            if self.hyperparameter is not None and self.hyperparameters is not None:
+                raise ValueError('Provide either hyperparameter or hyperparameters, but not both')
+
+            if self.hyperparameters is not None:
+                if not isinstance(self.hyperparameters, list) or len(self.hyperparameters) != 1:
+                    raise ValueError('hyperparameters must be a list containing a single dictionary')
+                self.hyperparameter = self.hyperparameters[0]
+                self.hyperparameters = None
+        return self
+

 @register_action
 class TrainAction(Action):
-
-    category = PluginCategory.NEURAL_NET
-    method = RunMethod.JOB
-    run_class = TrainRun
-    params_model = TrainParams
-    progress_categories = {
+    TRAIN_PROGRESS = {
         'dataset': {
             'proportion': 20,
         },
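The hunk above defines the tuning contract: SearchAlgo and Scheduler nest inside TuneConfig, and TrainParams switches validation rules on is_tune. As a rough sketch of how a tune-mode payload flows through these models (the import path is assumed from the file under diff, and Scheduler.options is annotated Optional[str] even though its docstring example shows a dict, so options is omitted here):

# Sketch only: import path assumed from the diffed module; not part of the diff.
from synapse_sdk.plugins.categories.neural_net.actions.train import TuneConfig

payload = {
    'mode': 'max',
    'metric': 'accuracy',
    'num_samples': 10,
    'max_concurrent_trials': 4,
    'search_alg': {'name': 'hyperoptsearch'},  # coerced into SearchAlgo
    'scheduler': {'name': 'hyperband'},        # coerced into Scheduler
}

config = TuneConfig(**payload)
assert config.search_alg.points_to_evaluate is None  # defaulted
assert config.scheduler.name == 'hyperband'

TrainParams itself is harder to exercise standalone, since its unique_name validator reads info.context['action'] and queries the backend client.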
@@ -81,8 +397,85 @@ class TrainAction(Action):
         },
     }

+    TUNE_PROGRESS = {
+        'dataset': {
+            'proportion': 20,
+        },
+        'trials': {
+            'proportion': 75,
+        },
+        'model_upload': {
+            'proportion': 5,
+        },
+    }
+
+    """
+    **Important notes when using train with is_tune=True:**
+
+    1. Path to the model output (which is the return value of your train function)
+       should be set to the checkpoint_output attribute of the run object **before**
+       starting the training.
+    2. Before exiting the training function, report the results to Tune.
+    3. When using own tune.py, take note of the difference in the order of parameters.
+       tune() function starts with hyperparameter, run, dataset, checkpoint, **kwargs
+       whereas the train() function starts with run, dataset, hyperparameter, checkpoint, **kwargs.
+    ----
+    1)
+    Set the output path for the checkpoint to export best model
+
+    output_path = Path('path/to/your/weights')
+    run.checkpoint_output = str(output_path)
+
+    2)
+    Before exiting the training function, report the results to Tune.
+    The results_dict should contain the metrics you want to report.
+
+    Example: (In train function)
+    results_dict = {
+        "accuracy": accuracy,
+        "loss": loss,
+        # Add other metrics as needed
+    }
+    if hasattr(self.dm_run, 'is_tune') and self.dm_run.is_tune:
+        tune.report(results_dict, checkpoint=tune.Checkpoint.from_directory(self.dm_run.checkpoint_output))
+
+
+    3)
+    tune() function takes hyperparameter, run, dataset, checkpoint, **kwargs in that order
+    whereas train() function takes run, dataset, hyperparameter, checkpoint, **kwargs in that order.
+
+    """
+
+    name = 'train'
+    category = PluginCategory.NEURAL_NET
+    method = RunMethod.JOB
+    run_class = TrainRun
+    params_model = TrainParams
+    progress_categories = None
+
+    def __init__(self, params, plugin_config, requirements=None, envs=None, job_id=None, direct=False, debug=False):
+        selected = self.TUNE_PROGRESS if (params or {}).get('is_tune') else self.TRAIN_PROGRESS
+        self.progress_categories = copy.deepcopy(selected)
+        super().__init__(
+            params, plugin_config, requirements=requirements, envs=envs, job_id=job_id, direct=direct, debug=debug
+        )
+
     def start(self):
-
+        if self.params.get('is_tune', False):
+            return self._start_tune()
+        else:
+            return self._start_train()
+
+    def _start_train(self):
+        """Original train logic"""
+        hyperparameter = self.params.get('hyperparameter')
+        if hyperparameter is None:
+            hyperparameters = self.params.get('hyperparameters') or []
+            if not hyperparameters:
+                raise ValueError('hyperparameter is missing for train mode')
+            hyperparameter = hyperparameters[0]
+            # Persist the normalized form so later steps (e.g., create_model) find it
+            self.params['hyperparameter'] = hyperparameter

         # download dataset
         self.run.log_message('Preparing dataset for training.')
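The class-level notes in this hunk spell out the contract for plugin train entrypoints under tuning. A minimal sketch of a train() honoring them, assuming the parameter order from the docstring and the run object from this diff (the training loop and metric values are placeholders):

# Hedged sketch of a plugin train() entrypoint; not part of the diff.
from pathlib import Path

from ray import tune


def train(run, dataset, hyperparameter, checkpoint=None, **kwargs):
    # 1) Point run.checkpoint_output at the weights directory before training starts.
    output_path = Path('weights')
    output_path.mkdir(parents=True, exist_ok=True)
    run.checkpoint_output = str(output_path)

    # ... training loop using `hyperparameter` and `dataset` goes here ...
    results_dict = {'accuracy': 0.91, 'loss': 0.22}  # placeholder metrics

    # 2) Report metrics (plus the checkpoint directory) back to Tune when tuning.
    if getattr(run, 'is_tune', False):
        tune.report(results_dict, checkpoint=tune.Checkpoint.from_directory(run.checkpoint_output))
    return run.checkpoint_output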
@@ -107,6 +500,301 @@
         self.run.end_log()
         return {'model_id': model['id'] if model else None}

+    def _start_tune(self):
+        """Tune logic using Ray Tune for hyperparameter optimization"""
+        from ray import tune
+
+        class _TuneTrialsLoggingCallback(tune.Callback):
+            """Capture Ray Tune trial table snapshots and forward them to run.log_trials."""
+
+            BASE_COLUMNS = ('trial_id', 'status')
+            METRIC_COLUMN_LIMIT = 4
+            RESERVED_RESULT_KEYS = {
+                'config',
+                'date',
+                'done',
+                'experiment_id',
+                'experiment_state',
+                'experiment_tag',
+                'hostname',
+                'iterations_since_restore',
+                'logdir',
+                'node_ip',
+                'pid',
+                'restored_from_trial_id',
+                'time_since_restore',
+                'time_this_iter_s',
+                'time_total',
+                'time_total_s',
+                'timestamp',
+                'timesteps_since_restore',
+                'timesteps_total',
+                'training_iteration',
+                'trial_id',
+            }
+
+            def __init__(self, run):
+                self.run = run
+                self.trial_rows: Dict[str, Dict[str, object]] = {}
+                self.config_keys: list[str] = []
+                self.metric_keys: list[str] = []
+                self._last_snapshot = None
+
+            def on_trial_result(self, iteration, trials, trial, result, **info):
+                self._record_trial(trial, result, status_override='RUNNING')
+                self._emit_snapshot()
+
+            def on_trial_complete(self, iteration, trials, trial, **info):
+                self._record_trial(trial, getattr(trial, 'last_result', None), status_override='TERMINATED')
+                self._emit_snapshot()
+
+            def on_trial_error(self, iteration, trials, trial, **info):
+                self._record_trial(trial, getattr(trial, 'last_result', None), status_override='ERROR')
+                self._emit_snapshot()
+
+            def on_step_end(self, iteration, trials, **info):
+                updated = False
+                for trial in trials or []:
+                    status = getattr(trial, 'status', None)
+                    existing = self.trial_rows.get(trial.trial_id)
+                    existing_status = existing.get('status') if existing else None
+                    if existing is None or (status and status != existing_status):
+                        self._record_trial(
+                            trial,
+                            getattr(trial, 'last_result', None),
+                            status_override=status,
+                        )
+                        updated = True
+                if updated:
+                    self._emit_snapshot()
+
+            def _record_trial(self, trial, result, status_override=None):
+                if not self.run or not getattr(self.run, 'log_trials', None):
+                    return
+
+                row = self.trial_rows.setdefault(trial.trial_id, {})
+                result = result or {}
+                if not isinstance(result, dict):
+                    result = {}
+
+                row['trial_id'] = trial.trial_id
+                row['status'] = status_override or getattr(trial, 'status', 'PENDING')
+                config_data = self._extract_config(trial.config or {})
+                metric_data = self._extract_metrics(result)
+
+                row.update(config_data)
+                row.update(metric_data)
+
+                self._track_columns(config_data.keys(), metric_data.keys())
+
+            def _extract_config(self, config):
+                flat = {}
+                if not isinstance(config, dict):
+                    return flat
+                for key, value in self._flatten_items(config):
+                    serialized = self._serialize_config_value(value)
+                    flat[key] = serialized
+                return flat
+
+            def _extract_metrics(self, result):
+                metrics = {}
+                if not isinstance(result, dict):
+                    return metrics
+
+                nested = result.get('metrics')
+                if isinstance(nested, dict):
+                    for key, value in self._flatten_items(nested, prefix='metrics'):
+                        serialized = self._serialize_metric_value(value)
+                        if serialized is not None:
+                            metrics[key] = serialized
+
+                for key, value in result.items():
+                    if key in self.RESERVED_RESULT_KEYS or key == 'metrics':
+                        continue
+                    if isinstance(value, dict):
+                        continue
+                    serialized = self._serialize_metric_value(value)
+                    if serialized is not None:
+                        metrics[key] = serialized
+
+                return metrics
+
+            def _track_columns(self, config_keys, metric_keys):
+                for key in config_keys:
+                    if key not in self.config_keys:
+                        self.config_keys.append(key)
+                for key in metric_keys:
+                    if key not in self.metric_keys and len(self.metric_keys) < self.METRIC_COLUMN_LIMIT:
+                        self.metric_keys.append(key)
+
+            def _emit_snapshot(self):
+                if not self.trial_rows:
+                    return
+
+                base_keys = list(self.BASE_COLUMNS)
+                config_keys = list(self.config_keys)
+                metric_keys = list(self.metric_keys)
+                columns = base_keys + config_keys + metric_keys
+
+                ordered_trials = {}
+                flat_rows = []
+                for trial_id in sorted(self.trial_rows.keys()):
+                    row = self.trial_rows[trial_id]
+                    base_values = [row.get(column) for column in base_keys]
+                    hyper_values = [row.get(column) for column in config_keys]
+                    metric_values = [row.get(column) for column in metric_keys]
+                    flat_values = base_values + hyper_values + metric_values
+                    ordered_trials[trial_id] = {
+                        'base': base_values,
+                        'hyperparameters': hyper_values,
+                        'metrics': metric_values,
+                    }
+                    flat_rows.append((trial_id, tuple(flat_values)))
+
+                snapshot = (
+                    tuple(columns),
+                    tuple(flat_rows),
+                )
+                if snapshot == self._last_snapshot:
+                    return
+                self._last_snapshot = snapshot
+
+                self.run.log_trials(
+                    base=base_keys,
+                    trials=ordered_trials,
+                    hyperparameters=config_keys,
+                    metrics=metric_keys,
+                    best_trial='',
+                )
+                self._update_trials_progress()
+
+            def _flatten_items(self, data, prefix=None):
+                if not isinstance(data, dict):
+                    return
+                for key, value in data.items():
+                    key_str = str(key)
+                    current = f'{prefix}/{key_str}' if prefix else key_str
+                    if isinstance(value, dict):
+                        yield from self._flatten_items(value, current)
+                    else:
+                        yield current, value
+
+            def _update_trials_progress(self):
+                total = getattr(self.run, 'num_samples', None)
+                if not total:
+                    return
+
+                completed_statuses = {'TERMINATED', 'ERROR'}
+                completed = sum(1 for row in self.trial_rows.values() if row.get('status') in completed_statuses)
+                completed = min(completed, total)
+
+                try:
+                    self.run.set_progress(completed, total, category='trials')
+                except Exception:  # pragma: no cover - safeguard against logging failures
+                    self.run.log_message('Failed to update trials progress.')
+
+            def _serialize_config_value(self, value):
+                if isinstance(value, (str, bool)) or value is None:
+                    return value
+                if isinstance(value, Number):
+                    return float(value) if not isinstance(value, bool) else value
+                return str(value)
+
+            def _serialize_metric_value(self, value):
+                if isinstance(value, Number):
+                    return float(value)
+                return None
+
+        # Mark run as tune
+        self.run.is_tune = True
+
+        # download dataset
+        self.run.log_message('Preparing dataset for hyperparameter tuning.')
+        input_dataset = self.get_dataset()
+
+        # retrieve checkpoint
+        checkpoint = None
+        if self.params['checkpoint']:
+            self.run.log_message('Retrieving checkpoint.')
+            checkpoint = self.get_model(self.params['checkpoint'])
+
+        # train dataset
+        self.run.log_message('Starting training for hyperparameter tuning.')
+
+        # Save num_samples to TrainRun for logging
+        self.run.num_samples = self.params['tune_config']['num_samples']
+
+        tune_config = self.params['tune_config']
+
+        entrypoint = self.entrypoint
+        if not self._tune_override_exists():
+            # entrypoint must be train entrypoint
+            train_entrypoint = entrypoint
+
+            def _tune(param_space, run, dataset, checkpoint=None, **kwargs):
+                return train_entrypoint(run, dataset, param_space, checkpoint, **kwargs)
+
+            entrypoint = _tune
+
+        entrypoint = self._wrap_tune_entrypoint(entrypoint, tune_config.get('metric'))
+
+        trainable = tune.with_parameters(entrypoint, run=self.run, dataset=input_dataset, checkpoint=checkpoint)
+
+        # Extract search_alg and scheduler as separate objects to avoid JSON serialization issues
+        search_alg = self.convert_tune_search_alg(tune_config)
+        scheduler = self.convert_tune_scheduler(tune_config)
+
+        # Create a copy of tune_config without non-serializable objects
+        tune_config_dict = {
+            'mode': tune_config.get('mode'),
+            'metric': tune_config.get('metric'),
+            'num_samples': tune_config.get('num_samples', 1),
+            'max_concurrent_trials': tune_config.get('max_concurrent_trials'),
+        }
+
+        # Add search_alg and scheduler to tune_config_dict only if they exist
+        if search_alg is not None:
+            tune_config_dict['search_alg'] = search_alg
+        if scheduler is not None:
+            tune_config_dict['scheduler'] = scheduler
+
+        hyperparameters = self.params['hyperparameters']
+        param_space = self.convert_tune_params(hyperparameters)
+        temp_path = tempfile.TemporaryDirectory()
+        trials_logger = _TuneTrialsLoggingCallback(self.run)
+
+        tuner = tune.Tuner(
+            tune.with_resources(trainable, resources=self.tune_resources),
+            tune_config=tune.TuneConfig(**tune_config_dict),
+            run_config=tune.RunConfig(
+                name=f'synapse_tune_hpo_{self.job_id}',
+                log_to_file=('stdout.log', 'stderr.log'),
+                storage_path=temp_path.name,
+                callbacks=[trials_logger],
+            ),
+            param_space=param_space,
+        )
+        result = tuner.fit()
+
+        trial_models_map, trial_models_summary = self._upload_tune_trial_models(result)
+
+        best_result = result.get_best_result()
+        self._override_best_trial(best_result)
+
+        # upload model_data
+        self.run.log_message('Registering best model data.')
+        self.run.set_progress(0, 1, category='model_upload')
+        if best_result.path not in trial_models_map:
+            trial_models_map[best_result.path] = self.create_model_from_result(best_result)
+        self.run.set_progress(1, 1, category='model_upload')
+
+        self.run.end_log()
+
+        return {
+            'best_result': best_result.config,
+            'trial_models': trial_models_summary,
+        }
+
     def get_dataset(self):
         client = self.run.client
         assert bool(client)
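_start_tune hands the validated hyperparameters list to convert_tune_params to build the Ray Tune param_space, but that helper's body falls outside the visible hunks. A hypothetical reconstruction based on the distribution types documented in TrainParams (the actual implementation in the release may differ):

# Hypothetical sketch of a hyperparameters-list -> Ray Tune search-space mapping.
from ray import tune

def convert_params_sketch(hyperparameters):
    # Each builder consumes one spec dict of the documented shape.
    builders = {
        'uniform': lambda s: tune.uniform(s['min'], s['max']),
        'loguniform': lambda s: tune.loguniform(s['min'], s['max'], base=s.get('base', 10)),
        'randn': lambda s: tune.randn(s['mean'], s['sd']),
        'randint': lambda s: tune.randint(s['min'], s['max']),
        'choice': lambda s: tune.choice(s['options']),
        'grid_search': lambda s: tune.grid_search(s['options']),
    }
    return {spec['name']: builders[spec['type']](spec) for spec in hyperparameters}

param_space = convert_params_sketch([
    {'name': 'batch_size', 'type': 'choice', 'options': [16, 32, 64]},
    {'name': 'learning_rate', 'type': 'loguniform', 'min': 0.0001, 'max': 0.01, 'base': 10},
    {'name': 'epochs', 'type': 'randint', 'min': 5, 'max': 15},
])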
@@ -145,9 +833,19 @@
         configuration_fields = ['hyperparameter']
         configuration = {field: params.pop(field) for field in configuration_fields}

-
+        run_name = params.get('name') or f'{self.plugin_release.name}-{self.job_id}'
+        unique_name = run_name
+
+        trial_id = self._extract_trial_id(path)
+        if trial_id:
+            unique_name = f'{run_name}_{trial_id}'
+
+        params['name'] = unique_name
+
+        temp_dir = tempfile.mkdtemp()
+        try:
             input_path = Path(path)
-            archive_path = Path(
+            archive_path = Path(temp_dir, 'archive.zip')
             archive(input_path, archive_path)

             return self.client.create_model({
@@ -157,3 +855,335 @@ class TrainAction(Action):
|
|
|
157
855
|
'configuration': configuration,
|
|
158
856
|
**params,
|
|
159
857
|
})
|
|
858
|
+
finally:
|
|
859
|
+
shutil.rmtree(temp_dir, ignore_errors=True)
|
|
860
|
+
|
|
861
|
+
@property
|
|
862
|
+
def tune_resources(self):
|
|
863
|
+
resources = {}
|
|
864
|
+
for option in ['num_cpus', 'num_gpus']:
|
|
865
|
+
option_value = self.params.get(option)
|
|
866
|
+
if option_value:
|
|
867
|
+
# Remove the 'num_' prefix and trailing s from the option name
|
|
868
|
+
resources[(lambda s: s[4:-1])(option)] = option_value
|
|
869
|
+
return resources
|
|
870
|
+
|
+    def _upload_tune_trial_models(self, result_grid):
+        trial_models = {}
+        trial_summaries = []
+
+        total_results = len(result_grid)
+
+        for index in range(total_results):
+            trial_result = result_grid[index]
+
+            if getattr(trial_result, 'error', None):
+                continue
+
+            try:
+                model = self.create_model_from_result(trial_result)
+            except Exception as exc:  # pragma: no cover - best effort logging
+                self.run.log_message(f'Failed to register model for trial at {trial_result.path}: {exc}')
+                continue
+
+            if model:
+                trial_models[trial_result.path] = model
+                trial_summaries.append({
+                    'trial_logdir': trial_result.path,
+                    'model_id': model.get('id'),
+                    'config': getattr(trial_result, 'config', None),
+                    'metrics': getattr(trial_result, 'metrics', None),
+                })
+
+        return trial_models, trial_summaries
+
+    def _override_best_trial(self, best_result):
+        if not best_result:
+            return
+
+        best_config = getattr(best_result, 'config', None)
+        if not isinstance(best_config, dict):
+            return
+
+        trial_id = getattr(best_result, 'trial_id', None)
+        if not trial_id:
+            path = getattr(best_result, 'path', None)
+            if path:
+                trial_id = self._extract_trial_id(path)
+
+        if not trial_id:
+            self.run.log_message('Skipping override_best_trial request: trial_id missing.')
+            return
+
+        payload = {'trial_id': trial_id, **best_config}
+
+        url = f'trains/{self.job_id}/override_best_trial/'
+        self.run.log_message(f'Calling override_best_trial: {url} payload={payload}')
+
+        try:
+            self.client._put(url, data=payload)
+            # Log trials with best_trial after successful PUT request
+            last_snapshot = getattr(self.run, '_last_trials_payload', None)
+            if isinstance(last_snapshot, dict) and 'trials' in last_snapshot:
+                final_snapshot = copy.deepcopy(last_snapshot)
+                final_snapshot['best_trial'] = trial_id
+                self.run.log_trials(data=final_snapshot)
+            else:
+                self.run.log_trials(best_trial=trial_id)
+        except ClientError as exc:  # pragma: no cover - network failure should not break run
+            self.run.log_message(f'Failed to override best trial: {exc}')
+
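The override request body is simply the winning trial's id merged with its config; everything else in the method is best-effort logging. A sketch of the payload shape, with hypothetical values:

```python
# Payload sent to trains/{job_id}/override_best_trial/ (values hypothetical).
best_config = {'lr': 0.003, 'batch_size': 32}   # best_result.config
trial_id = 'e460453e'

payload = {'trial_id': trial_id, **best_config}
assert payload == {'trial_id': 'e460453e', 'lr': 0.003, 'batch_size': 32}
```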
+    def create_model_from_result(self, result):
+        params = copy.deepcopy(self.params)
+        configuration_fields = ['hyperparameters']
+        configuration = {field: params.pop(field) for field in configuration_fields}
+        configuration['tune_trial'] = {
+            'config': getattr(result, 'config', None),
+            'metrics': getattr(result, 'metrics', None),
+            'logdir': getattr(result, 'path', None),
+        }
+
+        temp_dir = tempfile.mkdtemp()
+        archive_path = Path(temp_dir, 'archive.zip')
+
+        # Archive tune results
+        # https://docs.ray.io/en/latest/tune/tutorials/tune_get_data_in_and_out.html#getting-data-out-of-tune-using-checkpoints-other-artifacts
+        archive(result.path, archive_path)
+
+        unique_name = params.get('name') or f'{self.plugin_release.name}-{self.job_id}'
+        trial_id = getattr(result, 'trial_id', None) or self._extract_trial_id(result.path)
+        if trial_id:
+            unique_name = f'{unique_name}_{trial_id}'
+        params['name'] = unique_name
+
+        try:
+            return self.client.create_model({
+                'plugin': self.plugin_release.plugin,
+                'version': self.plugin_release.version,
+                'file': str(archive_path),
+                'configuration': configuration,
+                **params,
+            })
+        finally:
+            shutil.rmtree(temp_dir, ignore_errors=True)
+
+    @staticmethod
+    def convert_tune_scheduler(tune_config):
+        """
+        Convert YAML hyperparameter configuration to a Ray Tune scheduler.
+
+        Args:
+            tune_config (dict): Hyperparameter configuration.
+
+        Returns:
+            object: Ray Tune scheduler instance, or None if no scheduler is configured.
+
+        Supported schedulers:
+            - 'fifo': FIFOScheduler (default)
+            - 'asha': ASHAScheduler
+            - 'hyperband': HyperBandScheduler
+            - 'pbt': PopulationBasedTraining
+            - 'median': MedianStoppingRule
+        """
+
+        from ray.tune.schedulers import (
+            ASHAScheduler,
+            FIFOScheduler,
+            HyperBandScheduler,
+            MedianStoppingRule,
+            PopulationBasedTraining,
+        )
+
+        if tune_config.get('scheduler') is None:
+            return None
+
+        scheduler_map = {
+            'fifo': FIFOScheduler,
+            'asha': ASHAScheduler,
+            'hyperband': HyperBandScheduler,
+            'pbt': PopulationBasedTraining,
+            'median': MedianStoppingRule,
+        }
+
+        scheduler_type = tune_config['scheduler'].get('name', 'fifo').lower()
+        scheduler_class = scheduler_map.get(scheduler_type, FIFOScheduler)
+
+        # Pass options to the constructor when provided; otherwise use the default constructor
+        options = tune_config['scheduler'].get('options')
+
+        # Forward options only when they are neither None nor an empty dict
+        scheduler = scheduler_class(**options) if options else scheduler_class()
+
+        return scheduler
+
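A scheduler section this converter would accept, with hypothetical option values; `options`, when present and non-empty, is forwarded verbatim to the scheduler constructor:

```python
# Hypothetical scheduler config consumed by convert_tune_scheduler.
tune_config = {
    'scheduler': {
        'name': 'asha',
        'options': {'max_t': 100, 'grace_period': 10},
    },
}
# convert_tune_scheduler(tune_config) -> ASHAScheduler(max_t=100, grace_period=10)
```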
+    @staticmethod
+    def convert_tune_search_alg(tune_config):
+        """
+        Convert YAML hyperparameter configuration to a Ray Tune search algorithm.
+
+        Args:
+            tune_config (dict): Hyperparameter configuration.
+
+        Returns:
+            object: Ray Tune search algorithm instance, or None if none is configured.
+
+        Supported search algorithms:
+            - 'axsearch': Ax
+            - 'bayesoptsearch': Bayesian optimization
+            - 'hyperoptsearch': Tree-structured Parzen Estimator
+            - 'optunasearch': Optuna
+            - 'basicvariantgenerator': Random search (default)
+        """
+
+        if tune_config.get('search_alg') is None:
+            return None
+
+        search_alg_name = tune_config['search_alg']['name'].lower()
+        metric = tune_config['metric']
+        mode = tune_config['mode']
+        points_to_evaluate = tune_config['search_alg'].get('points_to_evaluate', None)
+
+        if search_alg_name == 'axsearch':
+            from ray.tune.search.ax import AxSearch
+
+            search_alg = AxSearch(metric=metric, mode=mode)
+        elif search_alg_name == 'bayesoptsearch':
+            from ray.tune.search.bayesopt import BayesOptSearch
+
+            search_alg = BayesOptSearch(metric=metric, mode=mode)
+        elif search_alg_name == 'hyperoptsearch':
+            from ray.tune.search.hyperopt import HyperOptSearch
+
+            search_alg = HyperOptSearch(metric=metric, mode=mode)
+        elif search_alg_name == 'optunasearch':
+            from ray.tune.search.optuna import OptunaSearch
+
+            search_alg = OptunaSearch(metric=metric, mode=mode)
+        elif search_alg_name == 'basicvariantgenerator':
+            from ray.tune.search.basic_variant import BasicVariantGenerator
+
+            search_alg = BasicVariantGenerator(points_to_evaluate=points_to_evaluate)
+        else:
+            raise ValueError(
+                f'Unsupported search algorithm: {search_alg_name}. '
+                f'Supported algorithms are: axsearch, bayesoptsearch, hyperoptsearch, '
+                f'optunasearch, basicvariantgenerator'
+            )
+
+        return search_alg
+
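A matching `search_alg` section, with hypothetical metric and mode; note that `metric` and `mode` live at the top level of the config, and `points_to_evaluate` is only honored by `basicvariantgenerator`:

```python
# Hypothetical search_alg config consumed by convert_tune_search_alg.
tune_config = {
    'metric': 'loss',
    'mode': 'min',
    'search_alg': {'name': 'hyperoptsearch'},
}
# convert_tune_search_alg(tune_config) -> HyperOptSearch(metric='loss', mode='min')
```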
+    @staticmethod
+    def convert_tune_params(param_list):
+        """
+        Convert YAML hyperparameter configuration to a Ray Tune parameter dictionary.
+
+        Args:
+            param_list (list): List of hyperparameter configurations.
+
+        Returns:
+            dict: Ray Tune parameter dictionary.
+        """
+        from ray import tune
+
+        param_handlers = {
+            'uniform': lambda p: tune.uniform(p['min'], p['max']),
+            'quniform': lambda p: tune.quniform(p['min'], p['max']),
+            'loguniform': lambda p: tune.loguniform(p['min'], p['max'], p['base']),
+            'qloguniform': lambda p: tune.qloguniform(p['min'], p['max'], p['base']),
+            'randn': lambda p: tune.randn(p['mean'], p['sd']),
+            'qrandn': lambda p: tune.qrandn(p['mean'], p['sd']),
+            'randint': lambda p: tune.randint(p['min'], p['max']),
+            'qrandint': lambda p: tune.qrandint(p['min'], p['max']),
+            'lograndint': lambda p: tune.lograndint(p['min'], p['max'], p['base']),
+            'qlograndint': lambda p: tune.qlograndint(p['min'], p['max'], p['base']),
+            'choice': lambda p: tune.choice(p['options']),
+            'grid_search': lambda p: tune.grid_search(p['options']),
+        }
+
+        param_space = {}
+
+        for param in param_list:
+            name = param['name']
+            param_type = param['type']
+
+            if param_type in param_handlers:
+                param_space[name] = param_handlers[param_type](param)
+            else:
+                raise ValueError(f'Unknown parameter type: {param_type}')
+
+        return param_space
+
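Each entry in `param_list` names a parameter, a distribution type, and that type's arguments. A hypothetical list and the space it maps to:

```python
# Hypothetical param_list and the resulting Ray Tune search space.
param_list = [
    {'name': 'lr', 'type': 'loguniform', 'min': 1e-4, 'max': 1e-1, 'base': 10},
    {'name': 'batch_size', 'type': 'choice', 'options': [16, 32, 64]},
    {'name': 'dropout', 'type': 'uniform', 'min': 0.0, 'max': 0.5},
]
# convert_tune_params(param_list) ->
# {'lr': tune.loguniform(0.0001, 0.1, 10),
#  'batch_size': tune.choice([16, 32, 64]),
#  'dropout': tune.uniform(0.0, 0.5)}
```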
+    @staticmethod
+    def _tune_override_exists(module_path='plugin.tune') -> bool:
+        try:
+            import_string(module_path)
+            return True
+        except ImportError:
+            return False
+
+    @staticmethod
+    def _extract_trial_id(path: str | Path) -> Optional[str]:
+        name = Path(path).name
+
+        ray_id_patterns = (
+            re.compile(r'([0-9a-f]{8})', re.IGNORECASE),  # e.g., e460453e
+            re.compile(r'([0-9a-f]{5}_[0-9]{5})', re.IGNORECASE),  # e.g., 7a2d0_00000
+        )
+        for pattern in ray_id_patterns:
+            match = pattern.search(name)
+            if match:
+                return match.group(1)
+        return None
+
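The two patterns target the id shapes Ray puts in trial directory names: a bare 8-hex-digit id and Tune's `<5 hex>_<5 digits>` trial suffix. A self-contained check with hypothetical paths:

```python
import re
from pathlib import Path

patterns = (
    re.compile(r'([0-9a-f]{8})', re.IGNORECASE),
    re.compile(r'([0-9a-f]{5}_[0-9]{5})', re.IGNORECASE),
)

def extract(path):
    name = Path(path).name
    for pattern in patterns:
        match = pattern.search(name)
        if match:
            return match.group(1)
    return None

assert extract('/tmp/ray/TorchTrainer_e460453e') == 'e460453e'
assert extract('/tmp/ray/train_fn_7a2d0_00000_0') == '7a2d0_00000'
```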
+    def _wrap_tune_entrypoint(self, entrypoint: Callable, metric_key: Optional[str]) -> Callable:
+        def _wrapped(*args, **kwargs):
+            last_metrics: Optional[Dict[str, float]] = None
+
+            try:
+                from ray import tune as ray_tune
+            except ImportError:
+                ray_tune = None
+
+            if ray_tune and hasattr(ray_tune, 'report'):
+                original_report = ray_tune.report
+
+                def caching_report(metrics, *r_args, **r_kwargs):
+                    nonlocal last_metrics
+                    if isinstance(metrics, dict):
+                        last_metrics = metrics.copy()
+                    return original_report(metrics, *r_args, **r_kwargs)
+
+                ray_tune.report = caching_report
+            else:
+                original_report = None
+
+            try:
+                result = entrypoint(*args, **kwargs)
+            finally:
+                if ray_tune and original_report:
+                    ray_tune.report = original_report
+
+            payload = self._normalize_tune_result(result, metric_key)
+            if last_metrics:
+                merged = last_metrics.copy()
+                merged.update(payload)
+                payload = merged
+
+            if metric_key and metric_key not in payload:
+                payload[metric_key] = (last_metrics or {}).get(metric_key, 0.0)
+
+            return payload
+
+        wrapper_name = getattr(entrypoint, '__name__', None)
+        if wrapper_name and (wrapper_name.startswith('_') or wrapper_name == '<lambda>'):
+            wrapper_name = None
+        final_name = wrapper_name or f'trial_{hash(entrypoint) & 0xFFFF:X}'
+        _wrapped.__name__ = final_name
+        _wrapped.__qualname__ = final_name
+
+        return _wrapped
+
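In effect, the wrapper lets a plugin entrypoint return a bare number, a dict, or nothing at all: whatever comes back is normalized, then merged over the last metrics it reported through `tune.report`. A hypothetical entrypoint and the outcome:

```python
# Hypothetical plugin entrypoint; the metric key 'accuracy' is assumed.
def entrypoint(config):
    return 0.91  # a bare float, not a metrics dict

# wrapped = action._wrap_tune_entrypoint(entrypoint, metric_key='accuracy')
# wrapped({'lr': 0.01}) -> {'accuracy': 0.91}
```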
+    @staticmethod
+    def _normalize_tune_result(result, metric_key: Optional[str]) -> Dict:
+        if isinstance(result, dict):
+            return result
+
+        if isinstance(result, Number):
+            target_key = metric_key or 'result'
+            return {target_key: result}
+
+        return {'result': result}
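A standalone restatement of the three normalization rules, with assertions as hypothetical examples:

```python
from numbers import Number
from typing import Dict, Optional


def normalize(result, metric_key: Optional[str]) -> Dict:
    if isinstance(result, dict):
        return result                              # dicts pass through
    if isinstance(result, Number):
        return {metric_key or 'result': result}    # numbers keyed by the tuned metric
    return {'result': result}                      # anything else wrapped as-is


assert normalize({'loss': 0.1}, 'loss') == {'loss': 0.1}
assert normalize(0.25, 'loss') == {'loss': 0.25}
assert normalize('done', 'loss') == {'result': 'done'}
```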