sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,562 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Script to generate Pydantic classes from JSON schema
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
from typing import Dict, Any, Set
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def to_snake_case(name: str) -> str:
    """Return ``name`` rewritten from camelCase or PascalCase to snake_case.

    Two regex passes insert underscores at word boundaries, after which the
    whole result is lower-cased.
    """
    import re

    # Pass 1: split before a capital letter that begins a capitalised word
    # ("HTTPResponse" -> "HTTP_Response").
    before_word = re.compile("(.)([A-Z][a-z]+)")
    # Pass 2: split where a lower-case letter or digit meets a capital
    # ("getHTTP" -> "get_HTTP").
    lower_then_upper = re.compile("([a-z0-9])([A-Z])")

    separated = before_word.sub(r"\1_\2", name)
    separated = lower_then_upper.sub(r"\1_\2", separated)
    return separated.lower()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def to_pascal_case(name: str) -> str:
    """Return the PascalCase form of a snake_case ``name``.

    The name is split on underscores and each piece is passed through
    ``str.capitalize``, which upper-cases the first character and lower-cases
    the rest (so "fooBar" becomes "Foobar").
    """
    segments = name.split("_")
    return "".join(map(str.capitalize, segments))
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_python_type(json_type: str, format_info: Dict[str, Any] = None) -> str:
    """Map a JSON-schema ``type`` value to the Python type name used in generated code.

    Args:
        json_type: Schema type name. The list form JSON Schema allows for
            unions (e.g. ``["string", "null"]``) is accepted as well: ``"null"``
            entries are ignored and, if exactly one name remains, that name is
            mapped.
        format_info: Not used by this function; kept so existing call sites
            keep working.

    Returns:
        ``"str"``, ``"float"``, ``"int"``, ``"bool"``, ``"List"`` or ``"Dict"``
        for the known schema types, and ``"Any"`` for everything else
        (including ``None``, unknown names and ambiguous type lists).
    """
    type_mapping = {
        "string": "str",
        "number": "float",
        "integer": "int",
        "boolean": "bool",
        "array": "List",
        "object": "Dict",
    }

    if isinstance(json_type, (list, tuple)):
        # A list cannot be used as a dict key; reduce it to a single name first.
        non_null = [entry for entry in json_type if entry != "null"]
        if len(non_null) != 1:
            return "Any"
        json_type = non_null[0]

    if not isinstance(json_type, str):
        # Covers a missing type (None) and any other non-string value.
        return "Any"
    return type_mapping.get(json_type, "Any")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def generate_enum_class(name: str, enum_values: list) -> str:
    """Render the source of a ``str``-based ``Enum`` class for a schema enum.

    Args:
        name: snake_case name; converted with ``to_pascal_case`` to form the
            class name.
        enum_values: The values listed under the schema's ``enum`` keyword.
            Non-string values are rendered through ``str()``.

    Returns:
        The class source text followed by a blank line. Member names are the
        upper-cased values with every non-word character replaced by ``_``;
        names that would not be valid identifiers get a ``VALUE_`` prefix and
        repeated names get a numeric suffix. An empty ``enum_values`` yields a
        class whose body is ``pass``.
    """
    import re

    class_name = to_pascal_case(name)
    lines = [f"class {class_name}(str, Enum):"]
    used_names = set()

    # NOTE(review): members are emitted with a single leading space, as in the
    # original literal; confirm this is the intended generated-code indent.
    for value in enum_values:
        # Schema enums may hold numbers or booleans, which have no .upper().
        text = str(value)
        enum_name = re.sub(r"\W", "_", text.upper())
        if not enum_name.isidentifier():
            # e.g. "3d" -> "VALUE_3D", "" -> "VALUE_"
            enum_name = f"VALUE_{enum_name}"

        # Distinct values can normalise to one name ("a-b" / "a b"); a reused
        # member name would make the generated class fail on import.
        unique_name, suffix = enum_name, 2
        while unique_name in used_names:
            unique_name = f"{enum_name}_{suffix}"
            suffix += 1
        used_names.add(unique_name)

        # json.dumps yields a double-quoted literal with quotes and
        # backslashes escaped; plain values render exactly as before.
        lines.append(f" {unique_name} = {json.dumps(text, ensure_ascii=False)}")

    if not enum_values:
        # A class statement with no body is a syntax error.
        lines.append(" pass")
    return "\n".join(lines) + "\n\n"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def find_nested_class_name(prop_name: str, known_classes: Set[str]) -> str:
    """Look up the nested class that models ``prop_name``.

    Args:
        prop_name: Property name to look up.
        known_classes: Accepted for call-site compatibility; the lookup uses
            only the fixed table below and does not consult this set.

    Returns:
        The mapped class name, or an empty string when the property has no
        entry in the table.
    """
    # Fixed table of properties that have a dedicated nested class.
    class_for_property = {
        "training_job_details": "TrainingJobDetails",
        "training_environment": "TrainingEnvironment",
        "model_overview": "ModelOverview",
        "intended_uses": "IntendedUses",
        "business_details": "BusinessDetails",
        "training_details": "TrainingDetails",
        "additional_information": "AdditionalInformation",
    }

    try:
        return class_for_property[prop_name]
    except KeyError:
        return ""
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def generate_field_definition(
    prop_name: str,
    prop_schema: Dict[str, Any],
    required: bool = False,
    additional_classes=None,
    known_classes: Set[str] = None,
) -> str:
    """Build the source line that declares one field of a generated Pydantic model.

    Args:
        prop_name: Name of the property; used verbatim as the field name.
        prop_schema: JSON-schema fragment describing the property.
        required: Whether the property is required. Optional properties have
            their type wrapped in ``Optional[...]``.
        additional_classes: List that receives ``(class_name, schema)`` tuples
            for helper classes the field refers to. It is appended to in place;
            a fresh list is used when omitted.
        known_classes: Class names already defined; ``$ref`` targets not in
            this set are queued in ``additional_classes`` as generic classes.

    Returns:
        A string of the form ``"name: Type"``, ``"name: Type = default"`` or
        ``"name: Type = Field(...)"``.
    """
    if additional_classes is None:
        additional_classes = []
    if known_classes is None:
        known_classes = set()

    # $ref and enum properties return early: they never receive Field(...)
    # constraints and always default to None.
    # NOTE(review): a *required* $ref/enum field therefore still gets "= None"
    # without Optional[...]; kept as-is, confirm this is intended.
    if "$ref" in prop_schema:
        field_type_str = _ref_field_type(
            prop_name, prop_schema, additional_classes, known_classes
        )
        if not required:
            field_type_str = f"Optional[{field_type_str}]"
        return f"{prop_name}: {field_type_str} = None"

    if "enum" in prop_schema:
        field_type_str = _enum_field_type(prop_name, prop_schema["enum"])
        if not required:
            field_type_str = f"Optional[{field_type_str}]"
        return f"{prop_name}: {field_type_str} = None"

    field_type = prop_schema.get("type")
    if field_type == "array":
        field_type_str = _array_field_type(
            prop_name, prop_schema.get("items", {}), additional_classes
        )
    elif field_type == "object":
        field_type_str = _object_field_type(
            prop_name, prop_schema, additional_classes, known_classes
        )
    elif "anyOf" in prop_schema:
        field_type_str = _any_of_field_type(
            prop_name, prop_schema["anyOf"], additional_classes
        )
    elif not field_type and ("function" in prop_schema or "notes" in prop_schema):
        # Untyped schema with nested "function"/"notes" keys: refer to a class
        # named after the property.
        field_type_str = to_pascal_case(prop_name)
    else:
        field_type_str = get_python_type(field_type)

    if not required:
        field_type_str = f"Optional[{field_type_str}]"

    constraints = _field_constraints(prop_schema)
    if constraints:
        joined = ", ".join(constraints)
        if required:
            return f"{prop_name}: {field_type_str} = Field({joined})"
        return f"{prop_name}: {field_type_str} = Field(None, {joined})"

    if required:
        return f"{prop_name}: {field_type_str}"

    default_val = "None"
    if field_type == "array" and "default" in prop_schema:
        default_val = "Field(default_factory=list)"
    return f"{prop_name}: {field_type_str} = {default_val}"


def _ref_field_type(prop_name, prop_schema, additional_classes, known_classes) -> str:
    """Return the type name for a ``$ref`` property, queuing unknown targets as generic classes."""
    ref_name = prop_schema["$ref"].split("/")[-1]
    class_name = to_pascal_case(ref_name)
    # References that resolve to plain types instead of generated classes.
    type_mapping = {
        "axis_name_string": "str",
        "axis_name_array": "List[str]",
        "custom_property": "str",
    }

    # Special handling for the function enum.
    if ref_name == "function" and prop_name == "function":
        return "Function"
    if ref_name in type_mapping:
        return type_mapping[ref_name]
    if class_name not in known_classes:
        # Undefined reference: have the caller emit a generic class for it.
        additional_classes.append(
            (class_name, {"type": "object", "additionalProperties": True})
        )
    return class_name


def _enum_field_type(prop_name, enum_values) -> str:
    """Return ``Literal[...]`` for single-value or ``type`` enums, else the enum class name."""
    if len(enum_values) == 1:
        return f'Literal["{enum_values[0]}"]'
    if prop_name == "type":
        literal_values = ", ".join([f'"{val}"' for val in enum_values])
        return f"Literal[{literal_values}]"
    return to_pascal_case(prop_name)


def _array_field_type(prop_name, items, additional_classes) -> str:
    """Return the ``List[...]`` type for an array property from its ``items`` schema."""
    if "type" in items and items["type"] != "object":
        return f"List[{get_python_type(items['type'])}]"
    if items and "$ref" in items:
        return f"List[{to_pascal_case(items['$ref'].split('/')[-1])}]"
    if items and "anyOf" in items:
        union_types = [
            to_pascal_case(variant["$ref"].split("/")[-1])
            for variant in items["anyOf"]
            if "$ref" in variant
        ]
        if not union_types:
            # "List[Union[]]" is not valid Python; fall back to the same
            # catch-all name get_python_type uses for unknown types.
            return "List[Any]"
        return f"List[Union[{', '.join(union_types)}]]"

    # Inline object items get their own class built from the items schema;
    # empty or otherwise unrecognised items get a generic class.
    item_class_name = to_pascal_case(f"{prop_name}_item")
    if items and items.get("type") == "object":
        additional_classes.append((item_class_name, items))
    else:
        additional_classes.append(
            (item_class_name, {"type": "object", "additionalProperties": True})
        )
    return f"List[{item_class_name}]"


def _object_field_type(prop_name, prop_schema, additional_classes, known_classes) -> str:
    """Return the type for an object property: a class name or a ``Dict[str, ...]`` mapping."""
    if "additionalProperties" not in prop_schema:
        class_name = to_pascal_case(prop_name)
        if "properties" not in prop_schema:
            # No declared structure at all: queue a generic class.
            additional_classes.append(
                (class_name, {"type": "object", "additionalProperties": True})
            )
        return class_name

    add_props = prop_schema["additionalProperties"]
    if isinstance(add_props, dict) and "$ref" in add_props:
        ref_name = add_props["$ref"].split("/")[-1]
        return f"Dict[str, {to_pascal_case(ref_name)}]"
    if isinstance(add_props, dict) and "type" in add_props:
        return f"Dict[str, {get_python_type(add_props['type'])}]"

    # additionalProperties is a bool or a dict with neither $ref nor type:
    # prefer an existing nested class, otherwise queue a value class.
    nested_class_name = find_nested_class_name(prop_name, known_classes)
    if nested_class_name:
        direct_reference_props = (
            "training_job_details",
            "training_environment",
            "model_overview",
            "intended_uses",
            "business_details",
            "training_details",
            "additional_information",
        )
        if prop_name in direct_reference_props:
            return nested_class_name
        return f"Dict[str, {nested_class_name}]"

    value_class_name = to_pascal_case(f"{prop_name}_value")
    value_schema = (
        add_props
        if isinstance(add_props, dict)
        else {"type": "object", "additionalProperties": True}
    )
    additional_classes.append((value_class_name, value_schema))
    return f"Dict[str, {value_class_name}]"


def _any_of_field_type(prop_name, variants, additional_classes) -> str:
    """Return the union type for an ``anyOf`` property, queuing classes for complex variants."""
    union_types = []
    for i, any_of_item in enumerate(variants):
        # The array check must come first: "array" also satisfies the
        # != "object" test below, which would drop the item type.
        if "type" in any_of_item and any_of_item["type"] == "array":
            items = any_of_item.get("items", {})
            if "type" in items:
                union_types.append(f"List[{get_python_type(items['type'])}]")
            else:
                union_types.append("List")
        elif "type" in any_of_item and any_of_item["type"] != "object":
            union_types.append(get_python_type(any_of_item["type"]))
        else:
            # Objects and untyped variants each get their own class.
            item_class_name = to_pascal_case(f"{prop_name}_variant_{i}")
            additional_classes.append((item_class_name, any_of_item))
            union_types.append(item_class_name)

    if not union_types:
        # "Union[]" is not valid Python.
        return "Any"
    if len(union_types) == 1:
        return union_types[0]
    return f"Union[{', '.join(union_types)}]"


def _field_constraints(prop_schema) -> list:
    """Return the ``Field(...)`` keyword arguments implied by the schema's validation keywords."""
    constraints = []
    keyword_to_argument = (
        ("maxLength", "max_length"),
        ("minLength", "min_length"),
        ("maxItems", "max_length"),
        ("minItems", "min_length"),
    )
    for schema_keyword, field_argument in keyword_to_argument:
        if schema_keyword in prop_schema:
            constraints.append(f"{field_argument}={prop_schema[schema_keyword]}")
    if "pattern" in prop_schema:
        # Emit an escaped literal so backslashes and quotes inside the regex
        # survive in the generated source; simple patterns render unchanged.
        pattern_literal = json.dumps(prop_schema["pattern"], ensure_ascii=False)
        constraints.append(f"pattern={pattern_literal}")
    return constraints
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def generate_class(
    class_name: str, schema: Dict[str, Any], additional_classes=None, known_classes: Set[str] = None
) -> str:
    """Render the source text of a single Pydantic model.

    Args:
        class_name: Name of the class to emit.
        schema: Schema fragment; its ``properties`` mapping supplies the
            fields and its ``required`` list marks the mandatory ones.
        additional_classes: List forwarded to ``generate_field_definition``;
            a new empty list is used when omitted.
        known_classes: Set of class names forwarded to
            ``generate_field_definition``; a new empty set is used when
            omitted.

    Returns:
        The class source: a ``class <name>(BaseModel):`` header followed by
        one line per property (or ``pass`` when there are none), terminated
        by two extra newlines.
    """
    extra_classes = [] if additional_classes is None else additional_classes
    known_names = set() if known_classes is None else known_classes

    header = f"class {class_name}(BaseModel):"
    field_schemas = schema.get("properties", {})
    required_names = set(schema.get("required", []))

    # NOTE(review): the indent literals below are copied verbatim from the
    # diff view, which may have collapsed runs of spaces - confirm.
    if field_schemas:
        body = [
            f" {generate_field_definition(name, field_schema, name in required_names, extra_classes, known_names)}"
            for name, field_schema in field_schemas.items()
        ]
    else:
        # A model without properties still needs a syntactically valid body.
        body = [" pass"]

    return "\n".join([header, *body]) + "\n\n"
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def extract_class_dependencies(class_def: str) -> set:
    """Collect capitalised type names referenced by a class definition.

    The scan is purely textual. It looks for ``: Name`` annotations and for
    names wrapped directly in ``Optional[...]``, ``List[...]``,
    ``Dict[str, ...]`` and ``Union[...]``. Container names such as
    ``Optional`` or ``List`` are returned as well when they directly follow
    a colon, so callers are expected to intersect the result with the names
    they actually care about.

    Args:
        class_def: Source text of one generated class.

    Returns:
        Set of identifiers starting with an upper-case ASCII letter.
    """
    import re

    identifier = r"[A-Z][a-zA-Z0-9_]*"
    direct_patterns = (
        rf":\s*({identifier})",  # field: ClassName
        rf"Optional\[({identifier})\]",  # Optional[ClassName]
        rf"List\[({identifier})\]",  # List[ClassName]
        rf"Dict\[str,\s*({identifier})\]",  # Dict[str, ClassName]
    )

    dependencies = set()
    for pattern in direct_patterns:
        dependencies.update(re.findall(pattern, class_def))

    # Union[...] lists several members. Take only the leading identifier of
    # each one: this skips empty members (e.g. after a trailing comma), which
    # used to raise IndexError, and avoids recording fragments such as
    # "List[Foo" for nested generics.
    for union_body in re.findall(r"Union\[([^\]]+)\]", class_def):
        for member in union_body.split(","):
            leading = re.match(identifier, member.strip())
            if leading:
                dependencies.add(leading.group(0))

    return dependencies
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def topological_sort_classes(class_definitions: list) -> list:
    """Order class definitions so that dependencies come before their users.

    Only entries whose first line starts with ``class `` are considered;
    anything else is dropped from the result. When several entries define
    the same class name, the first one is kept. Dependencies are discovered
    with ``extract_class_dependencies`` and restricted to the classes present
    in ``class_definitions``.

    The result is deterministic: classes are visited in input order and each
    class's dependencies in sorted order, instead of in ``set`` iteration
    order, which varies between interpreter runs for strings.

    Args:
        class_definitions: Source text of the classes, one string per class.

    Returns:
        The retained definitions, dependencies first. Circular references do
        not raise; the edge that closes a cycle is ignored.
    """
    # Both dicts rely on insertion order to remember the input order.
    first_definition = {}
    dependencies = {}

    for class_def in class_definitions:
        stripped = class_def.strip()
        header = stripped.split("\n")[0]
        if not header.startswith("class "):
            continue
        class_name = header.split("(")[0].replace("class ", "").strip()
        dependencies.setdefault(class_name, set())
        # Same match rule as before: only "class Name(" headers are emitted.
        if class_name not in first_definition and stripped.startswith(f"class {class_name}("):
            first_definition[class_name] = class_def
            # Order by the dependencies of the text that is actually emitted.
            dependencies[class_name] = extract_class_dependencies(class_def)

    # Drop external names (typing helpers, builtins, ...) and fix the order.
    known_names = set(dependencies)
    ordered_dependencies = {
        name: sorted(deps.intersection(known_names)) for name, deps in dependencies.items()
    }

    sorted_classes = []
    finished = set()
    in_progress = set()

    def visit(class_name):
        # in_progress membership means a circular dependency: skip the edge.
        if class_name in in_progress or class_name in finished:
            return

        in_progress.add(class_name)
        for dependency in ordered_dependencies[class_name]:
            visit(dependency)
        in_progress.remove(class_name)
        finished.add(class_name)

        class_def = first_definition.get(class_name)
        if class_def is not None:
            sorted_classes.append(class_def)

    for class_name in ordered_dependencies:
        visit(class_name)

    return sorted_classes
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def generate_pydantic_from_schema(schema_path: str, output_path: str):
    """Generate a module of Pydantic classes from a JSON schema file.

    The schema is loaded from ``schema_path``. Enums and model classes are
    rendered from its ``definitions`` and ``properties`` sections, the model
    classes are ordered with ``topological_sort_classes``, duplicate class
    names are dropped, and the resulting source is written to
    ``output_path``.

    Args:
        schema_path: Path of the JSON schema file to read.
        output_path: Path of the Python file to write (overwritten).
    """
    # Read and write as UTF-8 explicitly rather than relying on the
    # platform's locale-dependent default encoding.
    with open(schema_path, "r", encoding="utf-8") as f:
        schema = json.load(f)

    output_lines = [
        "from typing import List, Optional, Dict, Union, Literal",
        "from pydantic import BaseModel, Field",
        "from enum import Enum",
        "",
        "",
    ]

    # additional_classes and known_classes are shared with generate_class
    # across the passes below, so the order of those passes is kept as is.
    additional_classes = []
    known_classes = set()
    all_class_definitions = []

    # Collect all known class names first
    definitions = schema.get("definitions", {})
    for def_name in definitions:
        known_classes.add(to_pascal_case(def_name))
        # Also add the original snake_case name for reference matching
        known_classes.add(def_name)

    properties = schema.get("properties", {})
    for prop_name, prop_schema in properties.items():
        if prop_schema.get("type") == "object":
            known_classes.add(to_pascal_case(prop_name))

    # Generate enums from definitions
    for def_name, def_schema in definitions.items():
        if "enum" in def_schema:
            output_lines.append(generate_enum_class(def_name, def_schema["enum"]))

    # Generate Function enum based on objective_function definition
    if "objective_function" in definitions:
        obj_func_def = definitions["objective_function"]
        if "properties" in obj_func_def and "function" in obj_func_def["properties"]:
            func_prop = obj_func_def["properties"]["function"]
            if "enum" in func_prop:
                output_lines.append(generate_enum_class("function", func_prop["enum"]))

    # Generate classes from definitions (including those without explicit type)
    for def_name, def_schema in definitions.items():
        if "enum" not in def_schema and (
            def_schema.get("type") == "object" or "properties" in def_schema
        ):
            class_name = to_pascal_case(def_name)
            class_def = generate_class(class_name, def_schema, additional_classes, known_classes)
            all_class_definitions.append(class_def)

    # Generate classes from main properties
    for prop_name, prop_schema in properties.items():
        if prop_schema.get("type") == "object":
            class_name = to_pascal_case(prop_name)
            class_def = generate_class(class_name, prop_schema, additional_classes, known_classes)
            all_class_definitions.append(class_def)

    # Generate nested object classes
    def find_nested_objects(obj):
        """Render classes for inline object properties and object array items."""
        nested_classes = []
        if isinstance(obj, dict):
            for prop_name, prop_schema in obj.get("properties", {}).items():
                if prop_schema.get("type") == "object" and "properties" in prop_schema:
                    class_name = to_pascal_case(prop_name)
                    class_def = generate_class(
                        class_name, prop_schema, additional_classes, known_classes
                    )
                    nested_classes.append(class_def)
                    # Only inline objects are descended into; array items are not.
                    nested_classes.extend(find_nested_objects(prop_schema))
                elif prop_schema.get("type") == "array":
                    items = prop_schema.get("items", {})
                    if items.get("type") == "object" and "properties" in items:
                        class_name = to_pascal_case(f"{prop_name}_item")
                        class_def = generate_class(
                            class_name, items, additional_classes, known_classes
                        )
                        nested_classes.append(class_def)
        return nested_classes

    all_class_definitions.extend(find_nested_objects(schema))

    # Generate additional classes
    # NOTE(review): generate_class is called here without additional_classes
    # and known_classes, unlike every other call in this function - confirm
    # that this is intended.
    for class_name, class_schema in additional_classes:
        if isinstance(class_schema, dict):
            # Handle generic objects with additionalProperties
            if class_schema.get("additionalProperties") is True and not class_schema.get(
                "properties"
            ):
                # Create a simple pass class for generic objects
                # NOTE(review): indent literal copied verbatim from the diff
                # view, which may have collapsed runs of spaces - confirm.
                all_class_definitions.append(f"class {class_name}(BaseModel):\n pass\n\n")
            else:
                class_def = generate_class(class_name, class_schema)
                all_class_definitions.append(class_def)

    # Generate main schema class
    main_class_name = schema.get("title", "Schema")
    if main_class_name == "SageMakerModelCardSchema":
        main_class_name = "ModelCardContent"
    elif not main_class_name.endswith("Schema"):
        # Titles already ending in "Schema" are used unchanged.
        main_class_name = f"{main_class_name}Schema"

    main_class = generate_class(main_class_name, schema, additional_classes, known_classes)
    all_class_definitions.append(main_class)

    # Sort classes by dependencies
    output_lines.extend(topological_sort_classes(all_class_definitions))

    # Remove duplicates while preserving order. Each entry is either a header
    # line or the full text of one class, so the first line identifies it.
    seen_classes = set()
    final_output = []
    for chunk in output_lines:
        if chunk.strip().startswith("class "):
            class_name = chunk.strip().split("(")[0].replace("class ", "").strip()
            if class_name in seen_classes:
                continue  # Skip duplicate class definitions
            seen_classes.add(class_name)
        final_output.append(chunk)

    # Post-process to replace any remaining type aliases
    final_content = "\n".join(final_output)
    final_content = final_content.replace("CustomProperty", "str")
    final_content = final_content.replace("AxisNameString", "str")
    final_content = final_content.replace("AxisNameArray", "List[str]")

    # Write output
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(final_content)

    print(f"Generated Pydantic classes in {output_path}")
|