oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.10rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +40 -0
- ads/aqua/app.py +507 -0
- ads/aqua/cli.py +96 -0
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +836 -0
- ads/aqua/client/openai_client.py +305 -0
- ads/aqua/common/__init__.py +5 -0
- ads/aqua/common/decorator.py +125 -0
- ads/aqua/common/entities.py +274 -0
- ads/aqua/common/enums.py +134 -0
- ads/aqua/common/errors.py +109 -0
- ads/aqua/common/utils.py +1295 -0
- ads/aqua/config/__init__.py +4 -0
- ads/aqua/config/container_config.py +247 -0
- ads/aqua/config/evaluation/__init__.py +4 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
- ads/aqua/config/utils/__init__.py +4 -0
- ads/aqua/config/utils/serializer.py +339 -0
- ads/aqua/constants.py +116 -0
- ads/aqua/data.py +14 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation/__init__.py +8 -0
- ads/aqua/evaluation/constants.py +53 -0
- ads/aqua/evaluation/entities.py +186 -0
- ads/aqua/evaluation/errors.py +70 -0
- ads/aqua/evaluation/evaluation.py +1814 -0
- ads/aqua/extension/__init__.py +42 -0
- ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
- ads/aqua/extension/base_handler.py +90 -0
- ads/aqua/extension/common_handler.py +121 -0
- ads/aqua/extension/common_ws_msg_handler.py +36 -0
- ads/aqua/extension/deployment_handler.py +381 -0
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +30 -0
- ads/aqua/extension/evaluation_handler.py +129 -0
- ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
- ads/aqua/extension/finetune_handler.py +96 -0
- ads/aqua/extension/model_handler.py +390 -0
- ads/aqua/extension/models/__init__.py +0 -0
- ads/aqua/extension/models/ws_models.py +145 -0
- ads/aqua/extension/models_ws_msg_handler.py +50 -0
- ads/aqua/extension/ui_handler.py +300 -0
- ads/aqua/extension/ui_websocket_handler.py +130 -0
- ads/aqua/extension/utils.py +133 -0
- ads/aqua/finetuning/__init__.py +7 -0
- ads/aqua/finetuning/constants.py +23 -0
- ads/aqua/finetuning/entities.py +181 -0
- ads/aqua/finetuning/finetuning.py +749 -0
- ads/aqua/model/__init__.py +8 -0
- ads/aqua/model/constants.py +60 -0
- ads/aqua/model/entities.py +385 -0
- ads/aqua/model/enums.py +32 -0
- ads/aqua/model/model.py +2134 -0
- ads/aqua/model/utils.py +52 -0
- ads/aqua/modeldeployment/__init__.py +6 -0
- ads/aqua/modeldeployment/constants.py +10 -0
- ads/aqua/modeldeployment/deployment.py +1315 -0
- ads/aqua/modeldeployment/entities.py +653 -0
- ads/aqua/modeldeployment/utils.py +543 -0
- ads/aqua/resources/gpu_shapes_index.json +94 -0
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +476 -0
- ads/aqua/ui.py +519 -0
- ads/automl/__init__.py +9 -0
- ads/automl/driver.py +330 -0
- ads/automl/provider.py +975 -0
- ads/bds/__init__.py +5 -0
- ads/bds/auth.py +127 -0
- ads/bds/big_data_service.py +255 -0
- ads/catalog/__init__.py +19 -0
- ads/catalog/model.py +1576 -0
- ads/catalog/notebook.py +461 -0
- ads/catalog/project.py +468 -0
- ads/catalog/summary.py +178 -0
- ads/common/__init__.py +11 -0
- ads/common/analyzer.py +65 -0
- ads/common/artifact/.model-ignore +63 -0
- ads/common/artifact/__init__.py +10 -0
- ads/common/auth.py +1122 -0
- ads/common/card_identifier.py +83 -0
- ads/common/config.py +647 -0
- ads/common/data.py +165 -0
- ads/common/decorator/__init__.py +9 -0
- ads/common/decorator/argument_to_case.py +88 -0
- ads/common/decorator/deprecate.py +69 -0
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/decorator/runtime_dependency.py +178 -0
- ads/common/decorator/threaded.py +97 -0
- ads/common/decorator/utils.py +35 -0
- ads/common/dsc_file_system.py +303 -0
- ads/common/error.py +14 -0
- ads/common/extended_enum.py +81 -0
- ads/common/function/__init__.py +5 -0
- ads/common/function/fn_util.py +142 -0
- ads/common/function/func_conf.yaml +25 -0
- ads/common/ipython.py +76 -0
- ads/common/model.py +679 -0
- ads/common/model_artifact.py +1759 -0
- ads/common/model_artifact_schema.json +107 -0
- ads/common/model_export_util.py +664 -0
- ads/common/model_metadata.py +24 -0
- ads/common/object_storage_details.py +296 -0
- ads/common/oci_client.py +179 -0
- ads/common/oci_datascience.py +46 -0
- ads/common/oci_logging.py +1144 -0
- ads/common/oci_mixin.py +957 -0
- ads/common/oci_resource.py +136 -0
- ads/common/serializer.py +559 -0
- ads/common/utils.py +1852 -0
- ads/common/word_lists.py +1491 -0
- ads/common/work_request.py +189 -0
- ads/config.py +1 -0
- ads/data_labeling/__init__.py +13 -0
- ads/data_labeling/boundingbox.py +253 -0
- ads/data_labeling/constants.py +47 -0
- ads/data_labeling/data_labeling_service.py +244 -0
- ads/data_labeling/interface/__init__.py +5 -0
- ads/data_labeling/interface/loader.py +16 -0
- ads/data_labeling/interface/parser.py +16 -0
- ads/data_labeling/interface/reader.py +23 -0
- ads/data_labeling/loader/__init__.py +5 -0
- ads/data_labeling/loader/file_loader.py +241 -0
- ads/data_labeling/metadata.py +110 -0
- ads/data_labeling/mixin/__init__.py +5 -0
- ads/data_labeling/mixin/data_labeling.py +232 -0
- ads/data_labeling/ner.py +129 -0
- ads/data_labeling/parser/__init__.py +5 -0
- ads/data_labeling/parser/dls_record_parser.py +388 -0
- ads/data_labeling/parser/export_metadata_parser.py +94 -0
- ads/data_labeling/parser/export_record_parser.py +473 -0
- ads/data_labeling/reader/__init__.py +5 -0
- ads/data_labeling/reader/dataset_reader.py +574 -0
- ads/data_labeling/reader/dls_record_reader.py +121 -0
- ads/data_labeling/reader/export_record_reader.py +62 -0
- ads/data_labeling/reader/jsonl_reader.py +75 -0
- ads/data_labeling/reader/metadata_reader.py +203 -0
- ads/data_labeling/reader/record_reader.py +263 -0
- ads/data_labeling/record.py +52 -0
- ads/data_labeling/visualizer/__init__.py +5 -0
- ads/data_labeling/visualizer/image_visualizer.py +525 -0
- ads/data_labeling/visualizer/text_visualizer.py +357 -0
- ads/database/__init__.py +5 -0
- ads/database/connection.py +338 -0
- ads/dataset/__init__.py +10 -0
- ads/dataset/capabilities.md +51 -0
- ads/dataset/classification_dataset.py +339 -0
- ads/dataset/correlation.py +226 -0
- ads/dataset/correlation_plot.py +563 -0
- ads/dataset/dask_series.py +173 -0
- ads/dataset/dataframe_transformer.py +110 -0
- ads/dataset/dataset.py +1979 -0
- ads/dataset/dataset_browser.py +360 -0
- ads/dataset/dataset_with_target.py +995 -0
- ads/dataset/exception.py +25 -0
- ads/dataset/factory.py +987 -0
- ads/dataset/feature_engineering_transformer.py +35 -0
- ads/dataset/feature_selection.py +107 -0
- ads/dataset/forecasting_dataset.py +26 -0
- ads/dataset/helper.py +1450 -0
- ads/dataset/label_encoder.py +99 -0
- ads/dataset/mixin/__init__.py +5 -0
- ads/dataset/mixin/dataset_accessor.py +134 -0
- ads/dataset/pipeline.py +58 -0
- ads/dataset/plot.py +710 -0
- ads/dataset/progress.py +86 -0
- ads/dataset/recommendation.py +297 -0
- ads/dataset/recommendation_transformer.py +502 -0
- ads/dataset/regression_dataset.py +14 -0
- ads/dataset/sampled_dataset.py +1050 -0
- ads/dataset/target.py +98 -0
- ads/dataset/timeseries.py +18 -0
- ads/dbmixin/__init__.py +5 -0
- ads/dbmixin/db_pandas_accessor.py +153 -0
- ads/environment/__init__.py +9 -0
- ads/environment/ml_runtime.py +66 -0
- ads/evaluations/README.md +14 -0
- ads/evaluations/__init__.py +109 -0
- ads/evaluations/evaluation_plot.py +983 -0
- ads/evaluations/evaluator.py +1334 -0
- ads/evaluations/statistical_metrics.py +543 -0
- ads/experiments/__init__.py +9 -0
- ads/experiments/capabilities.md +0 -0
- ads/explanations/__init__.py +21 -0
- ads/explanations/base_explainer.py +142 -0
- ads/explanations/capabilities.md +83 -0
- ads/explanations/explainer.py +190 -0
- ads/explanations/mlx_global_explainer.py +1050 -0
- ads/explanations/mlx_interface.py +386 -0
- ads/explanations/mlx_local_explainer.py +287 -0
- ads/explanations/mlx_whatif_explainer.py +201 -0
- ads/feature_engineering/__init__.py +20 -0
- ads/feature_engineering/accessor/__init__.py +5 -0
- ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
- ads/feature_engineering/accessor/mixin/__init__.py +5 -0
- ads/feature_engineering/accessor/mixin/correlation.py +166 -0
- ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
- ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
- ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
- ads/feature_engineering/accessor/mixin/utils.py +65 -0
- ads/feature_engineering/accessor/series_accessor.py +431 -0
- ads/feature_engineering/adsimage/__init__.py +5 -0
- ads/feature_engineering/adsimage/image.py +192 -0
- ads/feature_engineering/adsimage/image_reader.py +170 -0
- ads/feature_engineering/adsimage/interface/__init__.py +5 -0
- ads/feature_engineering/adsimage/interface/reader.py +19 -0
- ads/feature_engineering/adsstring/__init__.py +7 -0
- ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
- ads/feature_engineering/adsstring/string/__init__.py +8 -0
- ads/feature_engineering/data_schema.json +57 -0
- ads/feature_engineering/dataset/__init__.py +5 -0
- ads/feature_engineering/dataset/zip_code_data.py +42062 -0
- ads/feature_engineering/exceptions.py +40 -0
- ads/feature_engineering/feature_type/__init__.py +133 -0
- ads/feature_engineering/feature_type/address.py +184 -0
- ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
- ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
- ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
- ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
- ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
- ads/feature_engineering/feature_type/adsstring/string.py +258 -0
- ads/feature_engineering/feature_type/base.py +58 -0
- ads/feature_engineering/feature_type/boolean.py +183 -0
- ads/feature_engineering/feature_type/category.py +146 -0
- ads/feature_engineering/feature_type/constant.py +137 -0
- ads/feature_engineering/feature_type/continuous.py +151 -0
- ads/feature_engineering/feature_type/creditcard.py +314 -0
- ads/feature_engineering/feature_type/datetime.py +190 -0
- ads/feature_engineering/feature_type/discrete.py +134 -0
- ads/feature_engineering/feature_type/document.py +43 -0
- ads/feature_engineering/feature_type/gis.py +251 -0
- ads/feature_engineering/feature_type/handler/__init__.py +5 -0
- ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
- ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
- ads/feature_engineering/feature_type/handler/warnings.py +128 -0
- ads/feature_engineering/feature_type/integer.py +142 -0
- ads/feature_engineering/feature_type/ip_address.py +144 -0
- ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
- ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
- ads/feature_engineering/feature_type/lat_long.py +256 -0
- ads/feature_engineering/feature_type/object.py +43 -0
- ads/feature_engineering/feature_type/ordinal.py +132 -0
- ads/feature_engineering/feature_type/phone_number.py +135 -0
- ads/feature_engineering/feature_type/string.py +171 -0
- ads/feature_engineering/feature_type/text.py +93 -0
- ads/feature_engineering/feature_type/unknown.py +43 -0
- ads/feature_engineering/feature_type/zip_code.py +164 -0
- ads/feature_engineering/feature_type_manager.py +406 -0
- ads/feature_engineering/schema.py +795 -0
- ads/feature_engineering/utils.py +245 -0
- ads/feature_store/.readthedocs.yaml +19 -0
- ads/feature_store/README.md +65 -0
- ads/feature_store/__init__.py +9 -0
- ads/feature_store/common/__init__.py +0 -0
- ads/feature_store/common/enums.py +339 -0
- ads/feature_store/common/exceptions.py +18 -0
- ads/feature_store/common/spark_session_singleton.py +125 -0
- ads/feature_store/common/utils/__init__.py +0 -0
- ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
- ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
- ads/feature_store/common/utils/transformation_utils.py +82 -0
- ads/feature_store/common/utils/utility.py +403 -0
- ads/feature_store/data_validation/__init__.py +0 -0
- ads/feature_store/data_validation/great_expectation.py +129 -0
- ads/feature_store/dataset.py +1230 -0
- ads/feature_store/dataset_job.py +530 -0
- ads/feature_store/docs/Dockerfile +7 -0
- ads/feature_store/docs/Makefile +44 -0
- ads/feature_store/docs/conf.py +28 -0
- ads/feature_store/docs/requirements.txt +14 -0
- ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
- ads/feature_store/docs/source/cicd.rst +137 -0
- ads/feature_store/docs/source/conf.py +86 -0
- ads/feature_store/docs/source/data_versioning.rst +33 -0
- ads/feature_store/docs/source/dataset.rst +388 -0
- ads/feature_store/docs/source/dataset_job.rst +27 -0
- ads/feature_store/docs/source/demo.rst +70 -0
- ads/feature_store/docs/source/entity.rst +78 -0
- ads/feature_store/docs/source/feature_group.rst +624 -0
- ads/feature_store/docs/source/feature_group_job.rst +29 -0
- ads/feature_store/docs/source/feature_store.rst +122 -0
- ads/feature_store/docs/source/feature_store_class.rst +123 -0
- ads/feature_store/docs/source/feature_validation.rst +66 -0
- ads/feature_store/docs/source/figures/cicd.png +0 -0
- ads/feature_store/docs/source/figures/data_validation.png +0 -0
- ads/feature_store/docs/source/figures/data_versioning.png +0 -0
- ads/feature_store/docs/source/figures/dataset.gif +0 -0
- ads/feature_store/docs/source/figures/dataset.png +0 -0
- ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
- ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
- ads/feature_store/docs/source/figures/entity.png +0 -0
- ads/feature_store/docs/source/figures/feature_group.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
- ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
- ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
- ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
- ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
- ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
- ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
- ads/feature_store/docs/source/figures/overview.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
- ads/feature_store/docs/source/figures/stats_1.png +0 -0
- ads/feature_store/docs/source/figures/stats_2.png +0 -0
- ads/feature_store/docs/source/figures/stats_d.png +0 -0
- ads/feature_store/docs/source/figures/stats_fg.png +0 -0
- ads/feature_store/docs/source/figures/transformation.png +0 -0
- ads/feature_store/docs/source/figures/transformations.gif +0 -0
- ads/feature_store/docs/source/figures/validation.png +0 -0
- ads/feature_store/docs/source/figures/validation_fg.png +0 -0
- ads/feature_store/docs/source/figures/validation_results.png +0 -0
- ads/feature_store/docs/source/figures/validation_summary.png +0 -0
- ads/feature_store/docs/source/index.rst +81 -0
- ads/feature_store/docs/source/module.rst +8 -0
- ads/feature_store/docs/source/notebook.rst +94 -0
- ads/feature_store/docs/source/overview.rst +47 -0
- ads/feature_store/docs/source/quickstart.rst +176 -0
- ads/feature_store/docs/source/release_notes.rst +194 -0
- ads/feature_store/docs/source/setup_feature_store.rst +81 -0
- ads/feature_store/docs/source/statistics.rst +58 -0
- ads/feature_store/docs/source/transformation.rst +199 -0
- ads/feature_store/docs/source/ui.rst +65 -0
- ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
- ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
- ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
- ads/feature_store/entity.py +718 -0
- ads/feature_store/execution_strategy/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
- ads/feature_store/execution_strategy/engine/__init__.py +0 -0
- ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
- ads/feature_store/execution_strategy/execution_strategy.py +113 -0
- ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
- ads/feature_store/execution_strategy/spark/__init__.py +0 -0
- ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
- ads/feature_store/feature.py +192 -0
- ads/feature_store/feature_group.py +1494 -0
- ads/feature_store/feature_group_expectation.py +346 -0
- ads/feature_store/feature_group_job.py +602 -0
- ads/feature_store/feature_lineage/__init__.py +0 -0
- ads/feature_store/feature_lineage/graphviz_service.py +180 -0
- ads/feature_store/feature_option_details.py +50 -0
- ads/feature_store/feature_statistics/__init__.py +0 -0
- ads/feature_store/feature_statistics/statistics_service.py +99 -0
- ads/feature_store/feature_store.py +699 -0
- ads/feature_store/feature_store_registrar.py +518 -0
- ads/feature_store/input_feature_detail.py +149 -0
- ads/feature_store/mixin/__init__.py +4 -0
- ads/feature_store/mixin/oci_feature_store.py +145 -0
- ads/feature_store/model_details.py +73 -0
- ads/feature_store/query/__init__.py +0 -0
- ads/feature_store/query/filter.py +266 -0
- ads/feature_store/query/generator/__init__.py +0 -0
- ads/feature_store/query/generator/query_generator.py +298 -0
- ads/feature_store/query/join.py +161 -0
- ads/feature_store/query/query.py +403 -0
- ads/feature_store/query/validator/__init__.py +0 -0
- ads/feature_store/query/validator/query_validator.py +57 -0
- ads/feature_store/response/__init__.py +0 -0
- ads/feature_store/response/response_builder.py +68 -0
- ads/feature_store/service/__init__.py +0 -0
- ads/feature_store/service/oci_dataset.py +139 -0
- ads/feature_store/service/oci_dataset_job.py +199 -0
- ads/feature_store/service/oci_entity.py +125 -0
- ads/feature_store/service/oci_feature_group.py +164 -0
- ads/feature_store/service/oci_feature_group_job.py +214 -0
- ads/feature_store/service/oci_feature_store.py +182 -0
- ads/feature_store/service/oci_lineage.py +87 -0
- ads/feature_store/service/oci_transformation.py +104 -0
- ads/feature_store/statistics/__init__.py +0 -0
- ads/feature_store/statistics/abs_feature_value.py +49 -0
- ads/feature_store/statistics/charts/__init__.py +0 -0
- ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
- ads/feature_store/statistics/charts/box_plot.py +148 -0
- ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
- ads/feature_store/statistics/charts/probability_distribution.py +68 -0
- ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
- ads/feature_store/statistics/feature_stat.py +126 -0
- ads/feature_store/statistics/generic_feature_value.py +33 -0
- ads/feature_store/statistics/statistics.py +41 -0
- ads/feature_store/statistics_config.py +101 -0
- ads/feature_store/templates/feature_store_template.yaml +45 -0
- ads/feature_store/transformation.py +499 -0
- ads/feature_store/validation_output.py +57 -0
- ads/hpo/__init__.py +9 -0
- ads/hpo/_imports.py +91 -0
- ads/hpo/ads_search_space.py +439 -0
- ads/hpo/distributions.py +325 -0
- ads/hpo/objective.py +280 -0
- ads/hpo/search_cv.py +1657 -0
- ads/hpo/stopping_criterion.py +75 -0
- ads/hpo/tuner_artifact.py +413 -0
- ads/hpo/utils.py +91 -0
- ads/hpo/validation.py +140 -0
- ads/hpo/visualization/__init__.py +5 -0
- ads/hpo/visualization/_contour.py +23 -0
- ads/hpo/visualization/_edf.py +20 -0
- ads/hpo/visualization/_intermediate_values.py +21 -0
- ads/hpo/visualization/_optimization_history.py +25 -0
- ads/hpo/visualization/_parallel_coordinate.py +169 -0
- ads/hpo/visualization/_param_importances.py +26 -0
- ads/jobs/__init__.py +53 -0
- ads/jobs/ads_job.py +663 -0
- ads/jobs/builders/__init__.py +5 -0
- ads/jobs/builders/base.py +156 -0
- ads/jobs/builders/infrastructure/__init__.py +6 -0
- ads/jobs/builders/infrastructure/base.py +165 -0
- ads/jobs/builders/infrastructure/dataflow.py +1252 -0
- ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
- ads/jobs/builders/infrastructure/utils.py +65 -0
- ads/jobs/builders/runtimes/__init__.py +5 -0
- ads/jobs/builders/runtimes/artifact.py +338 -0
- ads/jobs/builders/runtimes/base.py +325 -0
- ads/jobs/builders/runtimes/container_runtime.py +242 -0
- ads/jobs/builders/runtimes/python_runtime.py +1016 -0
- ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
- ads/jobs/cli.py +104 -0
- ads/jobs/env_var_parser.py +131 -0
- ads/jobs/extension.py +160 -0
- ads/jobs/schema/__init__.py +5 -0
- ads/jobs/schema/infrastructure_schema.json +116 -0
- ads/jobs/schema/job_schema.json +42 -0
- ads/jobs/schema/runtime_schema.json +183 -0
- ads/jobs/schema/validator.py +141 -0
- ads/jobs/serializer.py +296 -0
- ads/jobs/templates/__init__.py +5 -0
- ads/jobs/templates/container.py +6 -0
- ads/jobs/templates/driver_notebook.py +177 -0
- ads/jobs/templates/driver_oci.py +500 -0
- ads/jobs/templates/driver_python.py +48 -0
- ads/jobs/templates/driver_pytorch.py +852 -0
- ads/jobs/templates/driver_utils.py +615 -0
- ads/jobs/templates/hostname_from_env.c +55 -0
- ads/jobs/templates/oci_metrics.py +181 -0
- ads/jobs/utils.py +104 -0
- ads/llm/__init__.py +28 -0
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/v02/client.py +295 -0
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/chain.py +268 -0
- ads/llm/chat_template.py +31 -0
- ads/llm/deploy.py +63 -0
- ads/llm/guardrails/__init__.py +5 -0
- ads/llm/guardrails/base.py +442 -0
- ads/llm/guardrails/huggingface.py +44 -0
- ads/llm/langchain/__init__.py +5 -0
- ads/llm/langchain/plugins/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/llm/langchain/plugins/llms/__init__.py +5 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
- ads/llm/requirements.txt +3 -0
- ads/llm/serialize.py +219 -0
- ads/llm/serializers/__init__.py +0 -0
- ads/llm/serializers/retrieval_qa.py +153 -0
- ads/llm/serializers/runnable_parallel.py +27 -0
- ads/llm/templates/score_chain.jinja2 +155 -0
- ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
- ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
- ads/model/__init__.py +52 -0
- ads/model/artifact.py +573 -0
- ads/model/artifact_downloader.py +254 -0
- ads/model/artifact_uploader.py +267 -0
- ads/model/base_properties.py +238 -0
- ads/model/common/.model-ignore +66 -0
- ads/model/common/__init__.py +5 -0
- ads/model/common/utils.py +142 -0
- ads/model/datascience_model.py +2635 -0
- ads/model/deployment/__init__.py +20 -0
- ads/model/deployment/common/__init__.py +5 -0
- ads/model/deployment/common/utils.py +308 -0
- ads/model/deployment/model_deployer.py +466 -0
- ads/model/deployment/model_deployment.py +1846 -0
- ads/model/deployment/model_deployment_infrastructure.py +671 -0
- ads/model/deployment/model_deployment_properties.py +493 -0
- ads/model/deployment/model_deployment_runtime.py +838 -0
- ads/model/extractor/__init__.py +5 -0
- ads/model/extractor/automl_extractor.py +74 -0
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/extractor/huggingface_extractor.py +88 -0
- ads/model/extractor/keras_extractor.py +84 -0
- ads/model/extractor/lightgbm_extractor.py +93 -0
- ads/model/extractor/model_info_extractor.py +114 -0
- ads/model/extractor/model_info_extractor_factory.py +105 -0
- ads/model/extractor/pytorch_extractor.py +87 -0
- ads/model/extractor/sklearn_extractor.py +112 -0
- ads/model/extractor/spark_extractor.py +89 -0
- ads/model/extractor/tensorflow_extractor.py +85 -0
- ads/model/extractor/xgboost_extractor.py +94 -0
- ads/model/framework/__init__.py +5 -0
- ads/model/framework/automl_model.py +178 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/framework/huggingface_model.py +399 -0
- ads/model/framework/lightgbm_model.py +266 -0
- ads/model/framework/pytorch_model.py +266 -0
- ads/model/framework/sklearn_model.py +250 -0
- ads/model/framework/spark_model.py +326 -0
- ads/model/framework/tensorflow_model.py +254 -0
- ads/model/framework/xgboost_model.py +258 -0
- ads/model/generic_model.py +3518 -0
- ads/model/model_artifact_boilerplate/README.md +381 -0
- ads/model/model_artifact_boilerplate/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
- ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
- ads/model/model_artifact_boilerplate/score.py +61 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_introspect.py +331 -0
- ads/model/model_metadata.py +1810 -0
- ads/model/model_metadata_mixin.py +460 -0
- ads/model/model_properties.py +63 -0
- ads/model/model_version_set.py +739 -0
- ads/model/runtime/__init__.py +5 -0
- ads/model/runtime/env_info.py +306 -0
- ads/model/runtime/model_deployment_details.py +37 -0
- ads/model/runtime/model_provenance_details.py +58 -0
- ads/model/runtime/runtime_info.py +81 -0
- ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
- ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
- ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
- ads/model/runtime/utils.py +201 -0
- ads/model/serde/__init__.py +5 -0
- ads/model/serde/common.py +40 -0
- ads/model/serde/model_input.py +547 -0
- ads/model/serde/model_serializer.py +1184 -0
- ads/model/service/__init__.py +5 -0
- ads/model/service/oci_datascience_model.py +1076 -0
- ads/model/service/oci_datascience_model_deployment.py +500 -0
- ads/model/service/oci_datascience_model_version_set.py +176 -0
- ads/model/transformer/__init__.py +5 -0
- ads/model/transformer/onnx_transformer.py +324 -0
- ads/mysqldb/__init__.py +5 -0
- ads/mysqldb/mysql_db.py +227 -0
- ads/opctl/__init__.py +18 -0
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/__init__.py +5 -0
- ads/opctl/backend/ads_dataflow.py +353 -0
- ads/opctl/backend/ads_ml_job.py +710 -0
- ads/opctl/backend/ads_ml_pipeline.py +164 -0
- ads/opctl/backend/ads_model_deployment.py +209 -0
- ads/opctl/backend/base.py +146 -0
- ads/opctl/backend/local.py +1053 -0
- ads/opctl/backend/marketplace/__init__.py +9 -0
- ads/opctl/backend/marketplace/helm_helper.py +173 -0
- ads/opctl/backend/marketplace/local_marketplace.py +271 -0
- ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
- ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
- ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
- ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
- ads/opctl/backend/marketplace/models/__init__.py +5 -0
- ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
- ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
- ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
- ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
- ads/opctl/cli.py +707 -0
- ads/opctl/cmds.py +869 -0
- ads/opctl/conda/__init__.py +5 -0
- ads/opctl/conda/cli.py +193 -0
- ads/opctl/conda/cmds.py +749 -0
- ads/opctl/conda/config.yaml +34 -0
- ads/opctl/conda/manifest_template.yaml +13 -0
- ads/opctl/conda/multipart_uploader.py +188 -0
- ads/opctl/conda/pack.py +89 -0
- ads/opctl/config/__init__.py +5 -0
- ads/opctl/config/base.py +57 -0
- ads/opctl/config/diagnostics/__init__.py +5 -0
- ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
- ads/opctl/config/merger.py +255 -0
- ads/opctl/config/resolver.py +297 -0
- ads/opctl/config/utils.py +79 -0
- ads/opctl/config/validator.py +17 -0
- ads/opctl/config/versioner.py +68 -0
- ads/opctl/config/yaml_parsers/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/base.py +58 -0
- ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
- ads/opctl/constants.py +66 -0
- ads/opctl/decorator/__init__.py +5 -0
- ads/opctl/decorator/common.py +129 -0
- ads/opctl/diagnostics/__init__.py +5 -0
- ads/opctl/diagnostics/__main__.py +25 -0
- ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
- ads/opctl/diagnostics/check_requirements.py +144 -0
- ads/opctl/diagnostics/requirement_exception.py +9 -0
- ads/opctl/distributed/README.md +109 -0
- ads/opctl/distributed/__init__.py +5 -0
- ads/opctl/distributed/certificates.py +32 -0
- ads/opctl/distributed/cli.py +207 -0
- ads/opctl/distributed/cmds.py +731 -0
- ads/opctl/distributed/common/__init__.py +5 -0
- ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
- ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
- ads/opctl/distributed/common/cluster_config_helper.py +103 -0
- ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
- ads/opctl/distributed/common/cluster_runner.py +54 -0
- ads/opctl/distributed/common/framework_factory.py +29 -0
- ads/opctl/docker/Dockerfile.job +103 -0
- ads/opctl/docker/Dockerfile.job.arm +107 -0
- ads/opctl/docker/Dockerfile.job.gpu +175 -0
- ads/opctl/docker/base-env.yaml +13 -0
- ads/opctl/docker/cuda.repo +6 -0
- ads/opctl/docker/operator/.dockerignore +0 -0
- ads/opctl/docker/operator/Dockerfile +41 -0
- ads/opctl/docker/operator/Dockerfile.gpu +85 -0
- ads/opctl/docker/operator/cuda.repo +6 -0
- ads/opctl/docker/operator/environment.yaml +8 -0
- ads/opctl/forecast.py +11 -0
- ads/opctl/index.yaml +3 -0
- ads/opctl/model/__init__.py +5 -0
- ads/opctl/model/cli.py +65 -0
- ads/opctl/model/cmds.py +73 -0
- ads/opctl/operator/README.md +4 -0
- ads/opctl/operator/__init__.py +31 -0
- ads/opctl/operator/cli.py +344 -0
- ads/opctl/operator/cmd.py +596 -0
- ads/opctl/operator/common/__init__.py +5 -0
- ads/opctl/operator/common/backend_factory.py +460 -0
- ads/opctl/operator/common/const.py +27 -0
- ads/opctl/operator/common/data/synthetic.csv +16001 -0
- ads/opctl/operator/common/dictionary_merger.py +148 -0
- ads/opctl/operator/common/errors.py +42 -0
- ads/opctl/operator/common/operator_config.py +99 -0
- ads/opctl/operator/common/operator_loader.py +811 -0
- ads/opctl/operator/common/operator_schema.yaml +130 -0
- ads/opctl/operator/common/operator_yaml_generator.py +152 -0
- ads/opctl/operator/common/utils.py +208 -0
- ads/opctl/operator/lowcode/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
- ads/opctl/operator/lowcode/anomaly/README.md +207 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +167 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
- ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +116 -0
- ads/opctl/operator/lowcode/common/errors.py +47 -0
- ads/opctl/operator/lowcode/common/transformations.py +296 -0
- ads/opctl/operator/lowcode/common/utils.py +384 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
- ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
- ads/opctl/operator/lowcode/forecast/README.md +209 -0
- ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
- ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
- ads/opctl/operator/lowcode/forecast/const.py +92 -0
- ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
- ads/opctl/operator/lowcode/forecast/errors.py +26 -0
- ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
- ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
- ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
- ads/opctl/operator/lowcode/forecast/model/prophet.py +450 -0
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
- ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
- ads/opctl/operator/lowcode/forecast/utils.py +397 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
- ads/opctl/operator/lowcode/pii/MLoperator +17 -0
- ads/opctl/operator/lowcode/pii/README.md +208 -0
- ads/opctl/operator/lowcode/pii/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/__main__.py +78 -0
- ads/opctl/operator/lowcode/pii/cmd.py +39 -0
- ads/opctl/operator/lowcode/pii/constant.py +84 -0
- ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
- ads/opctl/operator/lowcode/pii/errors.py +27 -0
- ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
- ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
- ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
- ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
- ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
- ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
- ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
- ads/opctl/operator/lowcode/pii/model/report.py +487 -0
- ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
- ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
- ads/opctl/operator/lowcode/pii/utils.py +43 -0
- ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
- ads/opctl/operator/lowcode/recommender/README.md +206 -0
- ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
- ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
- ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
- ads/opctl/operator/lowcode/recommender/constant.py +30 -0
- ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
- ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
- ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
- ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
- ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
- ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
- ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
- ads/opctl/operator/lowcode/recommender/utils.py +13 -0
- ads/opctl/operator/runtime/__init__.py +5 -0
- ads/opctl/operator/runtime/const.py +17 -0
- ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
- ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
- ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/runtime.py +115 -0
- ads/opctl/schema.yaml.yml +36 -0
- ads/opctl/script.py +40 -0
- ads/opctl/spark/__init__.py +5 -0
- ads/opctl/spark/cli.py +43 -0
- ads/opctl/spark/cmds.py +147 -0
- ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
- ads/opctl/utils.py +344 -0
- ads/oracledb/__init__.py +5 -0
- ads/oracledb/oracle_db.py +346 -0
- ads/pipeline/__init__.py +39 -0
- ads/pipeline/ads_pipeline.py +2279 -0
- ads/pipeline/ads_pipeline_run.py +772 -0
- ads/pipeline/ads_pipeline_step.py +605 -0
- ads/pipeline/builders/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/custom_script.py +32 -0
- ads/pipeline/cli.py +119 -0
- ads/pipeline/extension.py +291 -0
- ads/pipeline/schema/__init__.py +5 -0
- ads/pipeline/schema/cs_step_schema.json +35 -0
- ads/pipeline/schema/ml_step_schema.json +31 -0
- ads/pipeline/schema/pipeline_schema.json +71 -0
- ads/pipeline/visualizer/__init__.py +5 -0
- ads/pipeline/visualizer/base.py +570 -0
- ads/pipeline/visualizer/graph_renderer.py +272 -0
- ads/pipeline/visualizer/text_renderer.py +84 -0
- ads/secrets/__init__.py +11 -0
- ads/secrets/adb.py +386 -0
- ads/secrets/auth_token.py +86 -0
- ads/secrets/big_data_service.py +365 -0
- ads/secrets/mysqldb.py +149 -0
- ads/secrets/oracledb.py +160 -0
- ads/secrets/secrets.py +407 -0
- ads/telemetry/__init__.py +7 -0
- ads/telemetry/base.py +69 -0
- ads/telemetry/client.py +122 -0
- ads/telemetry/telemetry.py +257 -0
- ads/templates/dataflow_pyspark.jinja2 +13 -0
- ads/templates/dataflow_sparksql.jinja2 +22 -0
- ads/templates/func.jinja2 +20 -0
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score-pkl.jinja2 +173 -0
- ads/templates/score.jinja2 +322 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- ads/templates/score_generic.jinja2 +165 -0
- ads/templates/score_huggingface_pipeline.jinja2 +217 -0
- ads/templates/score_lightgbm.jinja2 +185 -0
- ads/templates/score_onnx.jinja2 +407 -0
- ads/templates/score_onnx_new.jinja2 +473 -0
- ads/templates/score_oracle_automl.jinja2 +185 -0
- ads/templates/score_pyspark.jinja2 +154 -0
- ads/templates/score_pytorch.jinja2 +219 -0
- ads/templates/score_scikit-learn.jinja2 +184 -0
- ads/templates/score_tensorflow.jinja2 +184 -0
- ads/templates/score_xgboost.jinja2 +178 -0
- ads/text_dataset/__init__.py +5 -0
- ads/text_dataset/backends.py +211 -0
- ads/text_dataset/dataset.py +445 -0
- ads/text_dataset/extractor.py +207 -0
- ads/text_dataset/options.py +53 -0
- ads/text_dataset/udfs.py +22 -0
- ads/text_dataset/utils.py +49 -0
- ads/type_discovery/__init__.py +9 -0
- ads/type_discovery/abstract_detector.py +21 -0
- ads/type_discovery/constant_detector.py +41 -0
- ads/type_discovery/continuous_detector.py +54 -0
- ads/type_discovery/credit_card_detector.py +99 -0
- ads/type_discovery/datetime_detector.py +92 -0
- ads/type_discovery/discrete_detector.py +118 -0
- ads/type_discovery/document_detector.py +146 -0
- ads/type_discovery/ip_detector.py +68 -0
- ads/type_discovery/latlon_detector.py +90 -0
- ads/type_discovery/phone_number_detector.py +63 -0
- ads/type_discovery/type_discovery_driver.py +87 -0
- ads/type_discovery/typed_feature.py +594 -0
- ads/type_discovery/unknown_detector.py +41 -0
- ads/type_discovery/zipcode_detector.py +48 -0
- ads/vault/__init__.py +7 -0
- ads/vault/vault.py +237 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/METADATA +150 -149
- oracle_ads-2.13.10rc0.dist-info/RECORD +858 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/WHEEL +1 -2
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/entry_points.txt +2 -1
- oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
- oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/licenses/LICENSE.txt +0 -0
ads/aqua/common/utils.py
ADDED
@@ -0,0 +1,1295 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
+
"""AQUA utils and constants."""
|
5
|
+
|
6
|
+
import asyncio
|
7
|
+
import base64
|
8
|
+
import json
|
9
|
+
import logging
|
10
|
+
import os
|
11
|
+
import random
|
12
|
+
import re
|
13
|
+
import shlex
|
14
|
+
import shutil
|
15
|
+
import subprocess
|
16
|
+
from datetime import datetime, timedelta
|
17
|
+
from functools import wraps
|
18
|
+
from pathlib import Path
|
19
|
+
from string import Template
|
20
|
+
from typing import Any, Dict, List, Optional, Union
|
21
|
+
|
22
|
+
import fsspec
|
23
|
+
import oci
|
24
|
+
from cachetools import TTLCache, cached
|
25
|
+
from huggingface_hub.constants import HF_HUB_CACHE
|
26
|
+
from huggingface_hub.file_download import repo_folder_name
|
27
|
+
from huggingface_hub.hf_api import HfApi, ModelInfo
|
28
|
+
from huggingface_hub.utils import (
|
29
|
+
GatedRepoError,
|
30
|
+
HfHubHTTPError,
|
31
|
+
RepositoryNotFoundError,
|
32
|
+
RevisionNotFoundError,
|
33
|
+
)
|
34
|
+
from oci.data_science.models import JobRun, Model
|
35
|
+
from oci.object_storage.models import ObjectSummary
|
36
|
+
from pydantic import BaseModel, ValidationError
|
37
|
+
|
38
|
+
from ads.aqua.common.entities import GPUShapesIndex
|
39
|
+
from ads.aqua.common.enums import (
|
40
|
+
CONTAINER_FAMILY_COMPATIBILITY,
|
41
|
+
InferenceContainerParamType,
|
42
|
+
InferenceContainerType,
|
43
|
+
RqsAdditionalDetails,
|
44
|
+
TextEmbeddingInferenceContainerParams,
|
45
|
+
)
|
46
|
+
from ads.aqua.common.errors import (
|
47
|
+
AquaFileNotFoundError,
|
48
|
+
AquaRuntimeError,
|
49
|
+
AquaValueError,
|
50
|
+
)
|
51
|
+
from ads.aqua.constants import (
|
52
|
+
AQUA_GA_LIST,
|
53
|
+
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
|
54
|
+
DEPLOYMENT_CONFIG,
|
55
|
+
FINE_TUNING_CONFIG,
|
56
|
+
HF_LOGIN_DEFAULT_TIMEOUT,
|
57
|
+
LICENSE,
|
58
|
+
MAXIMUM_ALLOWED_DATASET_IN_BYTE,
|
59
|
+
MODEL_BY_REFERENCE_OSS_PATH_KEY,
|
60
|
+
README,
|
61
|
+
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
|
62
|
+
SUPPORTED_FILE_FORMATS,
|
63
|
+
TEI_CONTAINER_DEFAULT_HOST,
|
64
|
+
TGI_INFERENCE_RESTRICTED_PARAMS,
|
65
|
+
UNKNOWN_JSON_STR,
|
66
|
+
VLLM_INFERENCE_RESTRICTED_PARAMS,
|
67
|
+
)
|
68
|
+
from ads.aqua.data import AquaResourceIdentifier
|
69
|
+
from ads.common import auth as authutil
|
70
|
+
from ads.common.auth import AuthState, default_signer
|
71
|
+
from ads.common.decorator.threaded import threaded
|
72
|
+
from ads.common.extended_enum import ExtendedEnum
|
73
|
+
from ads.common.object_storage_details import ObjectStorageDetails
|
74
|
+
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
|
75
|
+
from ads.common.utils import (
|
76
|
+
UNKNOWN,
|
77
|
+
copy_file,
|
78
|
+
get_console_link,
|
79
|
+
read_file,
|
80
|
+
upload_to_os,
|
81
|
+
)
|
82
|
+
from ads.config import (
|
83
|
+
AQUA_MODEL_DEPLOYMENT_FOLDER,
|
84
|
+
AQUA_SERVICE_MODELS_BUCKET,
|
85
|
+
CONDA_BUCKET_NAME,
|
86
|
+
CONDA_BUCKET_NS,
|
87
|
+
TENANCY_OCID,
|
88
|
+
)
|
89
|
+
from ads.model import DataScienceModel, ModelVersionSet
|
90
|
+
|
91
|
+
# Logger shared by the helpers in this module (registered under "ads.aqua").
logger = logging.getLogger("ads.aqua")


# Lookup from a lower-case metadata key to the matching constant imported from
# ads.aqua.constants.
# NOTE(review): presumably these constants are artifact file names - confirm.
DEFINED_METADATA_TO_FILE_MAP = {
    "readme": README,
    "license": LICENSE,
    "finetuneconfiguration": FINE_TUNING_CONFIG,
    "deploymentconfiguration": DEPLOYMENT_CONFIG,
}
|
100
|
+
|
101
|
+
|
102
|
+
class LifecycleStatus(ExtendedEnum):
    """Lifecycle status helpers for an evaluation and its job run."""

    UNKNOWN = ""

    @property
    def detail(self) -> str:
        """Returns the detail message registered for this status name."""
        fallback = f"No detail available for the status {self.name}."
        return LIFECYCLE_DETAILS_MAPPING.get(self.name, fallback)

    @staticmethod
    def get_status(evaluation_status: str, job_run_status: str = None):
        """
        Combines the evaluation status and the job run status into one status.

        Parameters
        ----------
        evaluation_status (str):
            The status of the evaluation.
        job_run_status (str):
            The status of the job run.

        Returns
        -------
        str
            The resolved lifecycle state: one of the ``JobRun.LIFECYCLE_STATE_*``
            values, or the given evaluation/job run status passed through.
        """
        if not job_run_status:
            logger.error("Failed to get jobrun state.")
            # case1: the job run was never created
            # case2: the job run was deleted - rqs cannot retrieve deleted resources
            return JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION

        # Anything other than an ACTIVE evaluation is reported as-is.
        if evaluation_status != Model.LIFECYCLE_STATE_ACTIVE:
            return evaluation_status

        running_states = {
            JobRun.LIFECYCLE_STATE_IN_PROGRESS,
            JobRun.LIFECYCLE_STATE_ACCEPTED,
        }
        if job_run_status in running_states:
            return JobRun.LIFECYCLE_STATE_IN_PROGRESS

        failed_states = {
            JobRun.LIFECYCLE_STATE_FAILED,
            JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
        }
        if job_run_status in failed_states:
            return JobRun.LIFECYCLE_STATE_FAILED

        return job_run_status
|
153
|
+
|
154
|
+
|
155
|
+
# Human-readable detail messages keyed by JobRun lifecycle-state constants.
# LifecycleStatus.detail looks entries up by status name and falls back to a
# generic message when the name is not listed here.
LIFECYCLE_DETAILS_MAPPING = {
    JobRun.LIFECYCLE_STATE_SUCCEEDED: "The evaluation ran successfully.",
    JobRun.LIFECYCLE_STATE_IN_PROGRESS: "The evaluation is running.",
    JobRun.LIFECYCLE_STATE_FAILED: "The evaluation failed.",
    JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION: "Missing jobrun information.",
}
|
161
|
+
|
162
|
+
|
163
|
+
def random_color_generator(word: str):
    """Derives a reproducible color pair from the given word.

    Parameters
    ----------
    word: str
        The text used to seed the color selection. The same word always
        yields the same colors.

    Returns
    -------
    tuple:
        ``(background, text_color)`` where ``background`` is a ``#rrggbb`` hex
        string and ``text_color`` is ``"black"`` or ``"white"``, chosen from
        the weighted brightness of the background.
    """
    seed = sum(ord(char) for char in word) % 13
    # Use a private generator: seeding the module-level RNG here would reset
    # the random state for every other user of `random` in the process.
    rng = random.Random(seed)
    r = rng.randint(10, 245)
    g = rng.randint(10, 245)
    b = rng.randint(10, 245)

    text_color = "black" if (0.299 * r + 0.587 * g + 0.114 * b) / 255 > 0.5 else "white"

    return f"#{r:02x}{g:02x}{b:02x}", text_color
|
173
|
+
|
174
|
+
|
175
|
+
def svg_to_base64_datauri(svg_contents: str):
    """Encodes the SVG markup as a base64 ``data:image/svg+xml`` URI string."""
    payload = base64.b64encode(svg_contents.encode()).decode()
    return f"data:image/svg+xml;base64,{payload}"
|
178
|
+
|
179
|
+
|
180
|
+
def create_word_icon(label: str, width: int = 150, return_as_datauri=True):
    """Builds a circular SVG icon showing a short text derived from the label.

    Parameters
    ----------
    label: str
        The text the icon is generated for. When the label starts with a letter
        and contains a number (optionally followed by one lower-case letter),
        the icon shows that letter plus the number; otherwise it shows the
        first character of the label (nothing for an empty label).
    width: int
        Width and height of the SVG.
    return_as_datauri: bool
        Whether to return a base64 data URI instead of the raw SVG markup.

    Returns
    -------
    str:
        The SVG markup, or its base64 data URI when `return_as_datauri` is set.
    """
    match = re.findall(r"(^[a-zA-Z]{1}).*?(\d+[a-z]?)", label)
    # Slicing instead of indexing keeps an empty label from raising IndexError.
    icon_text = "".join(match[0] if match else [label[:1]])
    icon_color, text_color = random_color_generator(label)
    cx = width / 2
    cy = width / 2
    r = width / 2
    fs = int(r / 25)

    t = Template(
        """
        <svg xmlns="http://www.w3.org/2000/svg" version="1.1" width="${width}" height="${width}">

            <style>
                text {
                    font-size: ${fs}em;
                    font-family: lucida console, Fira Mono, monospace;
                    text-anchor: middle;
                    stroke-width: 1px;
                    font-weight: bold;
                    alignment-baseline: central;
                }

            </style>

            <circle cx="${cx}" cy="${cy}" r="${r}" fill="${icon_color}" />
            <text x="50%" y="50%" fill="${text_color}">${icon_text}</text>
        </svg>
    """.strip()
    )

    # Pass only the placeholders the template uses rather than every local.
    icon_svg = t.substitute(
        width=width,
        fs=fs,
        cx=cx,
        cy=cy,
        r=r,
        icon_color=icon_color,
        text_color=text_color,
        icon_text=icon_text,
    )
    if return_as_datauri:
        return svg_to_base64_datauri(icon_svg)
    else:
        return icon_svg
|
216
|
+
|
217
|
+
|
218
|
+
def get_artifact_path(custom_metadata_list: List) -> str:
    """Looks up the model artifact path in the custom metadata of a model.

    Parameters
    ----------
    custom_metadata_list: List
        A list of custom metadata of OCI model.

    Returns
    -------
    str:
        The artifact path taken from the first entry whose key is
        `MODEL_BY_REFERENCE_OSS_PATH_KEY`, or `UNKNOWN` when no such entry
        exists or the lookup fails.
    """
    try:
        for entry in custom_metadata_list:
            if entry.key != MODEL_BY_REFERENCE_OSS_PATH_KEY:
                continue
            if ObjectStorageDetails.is_oci_path(entry.value):
                return entry.value
            # Not an OCI path: resolve the value against the service models bucket.
            return ObjectStorageDetails(
                AQUA_SERVICE_MODELS_BUCKET,
                CONDA_BUCKET_NS,
                entry.value,
            ).path
    except Exception as ex:
        logger.debug(ex)

    logger.debug("Failed to get artifact path from custom metadata.")
    return UNKNOWN
|
248
|
+
|
249
|
+
|
250
|
+
@threaded()
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
    """Reads a JSON config file located under the given path.

    Parameters
    ----------
    file_path: str
        The folder holding the config file; a trailing `/` is ignored.
    config_file_name: str
        The name of the config file.
    **kwargs:
        Additional arguments forwarded to `read_file`.

    Returns
    -------
    dict:
        The parsed config.

    Raises
    ------
    AquaFileNotFoundError:
        When the file content is missing or parses to an empty value.
    """
    folder = file_path.rstrip("/")
    artifact_path = f"{folder}/{config_file_name}"
    # Only Object Storage locations need a signer.
    if artifact_path.startswith("oci://"):
        signer = default_signer()
    else:
        signer = {}
    content = read_file(file_path=artifact_path, auth=signer, **kwargs)
    config = json.loads(content or UNKNOWN_JSON_STR)
    if config:
        return config
    raise AquaFileNotFoundError(
        f"Config file `{config_file_name}` is either empty or missing at {artifact_path}",
        500,
    )
|
263
|
+
|
264
|
+
|
265
|
+
def list_os_files_with_extension(oss_path: str, extension: str) -> List[str]:
    """
    Lists the objects under the given Object Storage path that end with the extension.

    Parameters:
    - oss_path: The Object Storage path whose objects are listed.
    - extension: The file extension to filter by, with or without the leading dot.

    Returns:
    - The matching object names with the path prefix and leading `/` removed.
    """
    oss_client = ObjectStorageDetails.from_path(oss_path)

    # Normalize the extension so that "txt" and ".txt" behave the same.
    suffix = extension if extension.startswith(".") else "." + extension
    objects: List[ObjectSummary] = oss_client.list_objects().objects

    matched = []
    for obj in objects:
        if not obj.name.endswith(suffix):
            continue
        relative_name = obj.name[len(oss_client.filepath) :]
        matched.append(relative_name.lstrip("/"))
    return matched
|
289
|
+
|
290
|
+
|
291
|
+
def is_valid_ocid(ocid: str) -> bool:
    """Tells whether the given value looks like an OCID.

    Parameters
    ----------
    ocid: str
        Oracle Cloud Identifier (OCID).

    Returns
    -------
    bool:
        True when the value is non-empty and starts with "ocid"
        (case-insensitive), False otherwise.
    """
    return bool(ocid) and ocid.lower().startswith("ocid")
|
308
|
+
|
309
|
+
|
310
|
+
def get_resource_type(ocid: str) -> str:
    """Gets resource type based on the given ocid.

    Parameters
    ----------
    ocid: str
        Oracle Cloud Identifier (OCID).

    Returns
    -------
    str:
        The resource type indicated in the given ocid, i.e. the second
        dot-separated segment.

    Raises
    -------
    ValueError:
        When the given ocid is not a valid ocid or has no resource type segment.
    """
    if not is_valid_ocid(ocid):
        raise ValueError(
            f"The given ocid {ocid} is not a valid ocid. "
            "Check out this page https://docs.oracle.com/en-us/iaas/Content/General/Concepts/identifiers.htm to see more details."
        )
    segments = ocid.split(".")
    # A value such as "ocid1" passes the prefix check but carries no type
    # segment; report it with the documented ValueError, not an IndexError.
    if len(segments) < 2:
        raise ValueError(
            f"The given ocid {ocid} does not contain a resource type. "
            "Check out this page https://docs.oracle.com/en-us/iaas/Content/General/Concepts/identifiers.htm to see more details."
        )
    return segments[1]
|
334
|
+
|
335
|
+
|
336
|
+
def query_resource(
    ocid, return_all: bool = True
) -> "oci.resource_search.models.ResourceSummary":
    """Looks up a single resource in the tenancy through the Search service.

    Parameters
    ----------
    ocid: str
        Oracle Cloud Identifier (OCID).
    return_all: bool
        Whether to return allAdditionalFields.

    Returns
    -------
    oci.resource_search.models.ResourceSummary:
        The first resource returned by the search.

    Raises
    ------
    AquaRuntimeError:
        When the search returns no resources.
    """
    if return_all:
        fields_clause = " return allAdditionalFields "
    else:
        fields_clause = " "
    resource_type = get_resource_type(ocid)
    query = f"query {resource_type} resources{fields_clause}where (identifier = '{ocid}')"
    logger.debug(query)

    resources = OCIResource.search(
        query,
        type=SEARCH_TYPE.STRUCTURED,
        tenant_id=TENANCY_OCID,
    )
    if len(resources) != 0:
        return resources[0]

    raise AquaRuntimeError(
        f"Failed to retreive {resource_type}'s information.",
        service_payload={"query": query, "tenant_id": TENANCY_OCID},
    )
|
370
|
+
|
371
|
+
|
372
|
+
def query_resources(
    compartment_id,
    resource_type: str,
    return_all: bool = True,
    tag_list: list = None,
    status_list: list = None,
    connect_by_ampersands: bool = True,
    **kwargs,
) -> List["oci.resource_search.models.ResourceSummary"]:
    """Searches a compartment for resources of the given type.

    Parameters
    ----------
    compartment_id: str
        The compartment ocid.
    resource_type: str
        The type of the target resources.
    return_all: bool
        Whether to return allAdditionalFields.
    tag_list: list
        Freeform tag keys used for filtering.
    status_list: list
        lifecycleState values used for filtering; they are always joined by `||`.
    connect_by_ampersands: bool
        Whether the tag conditions are joined by `&&`.
        if `connect_by_ampersands=False`, `||` will be used.
    **kwargs:
        Additional arguments forwarded to `OCIResource.search`.

    Returns
    -------
    List[oci.resource_search.models.ResourceSummary]:
        The retrieved resources.
    """
    fields_clause = " return allAdditionalFields " if return_all else " "
    lifecycle_filter = _construct_condition(
        field_name="lifecycleState",
        allowed_values=status_list,
        connect_by_ampersands=False,
    )
    tag_filter = _construct_condition(
        field_name="freeformTags.key",
        allowed_values=tag_list,
        connect_by_ampersands=connect_by_ampersands,
    )
    query = f"query {resource_type} resources{fields_clause}where (compartmentId = '{compartment_id}'{lifecycle_filter}{tag_filter})"
    logger.debug(query)
    logger.debug(f"tenant_id=`{TENANCY_OCID}`")

    search_results = OCIResource.search(
        query, type=SEARCH_TYPE.STRUCTURED, tenant_id=TENANCY_OCID, **kwargs
    )
    return search_results
|
424
|
+
|
425
|
+
|
426
|
+
def _construct_condition(
    field_name: str, allowed_values: list = None, connect_by_ampersands: bool = True
) -> str:
    """Builds an extra filter condition for a search query statement.

    Parameters
    ----------
    field_name: str
        The resource attribute each allowed value is compared against.
    allowed_values: list
        The values accepted for the field.
    connect_by_ampersands: bool
        Whether the per-value comparisons are joined by `&&`.
        if `connect_by_ampersands=False`, `||` will be used.

    Returns
    -------
    str:
        The condition in the form ` && (<comparisons>)`, or an empty string
        when there are no allowed values.
    """
    if not allowed_values:
        return ""

    separator = " && " if connect_by_ampersands else " || "
    comparisons = separator.join(
        f"{field_name} = '{value}'" for value in allowed_values
    )
    if not comparisons:
        return ""
    return f" && ({comparisons})"
|
455
|
+
|
456
|
+
|
457
|
+
def upload_local_to_os(
    src_uri: str, dst_uri: str, auth: dict = None, force_overwrite: bool = False
):
    """Validates a local file and uploads it through `upload_to_os`.

    Parameters
    ----------
    src_uri: str
        The local file path; `~` is expanded.
    dst_uri: str
        The destination passed to `upload_to_os`.
    auth: dict
        The auth passed to `upload_to_os`.
    force_overwrite: bool
        The overwrite flag passed to `upload_to_os`.

    Raises
    ------
    AquaFileNotFoundError:
        When the path does not point to a file.
    AquaValueError:
        When the file extension is not supported, the file is empty, or the
        file exceeds `MAXIMUM_ALLOWED_DATASET_IN_BYTE` bytes.
    """
    expanded_path = os.path.expanduser(src_uri)
    if not os.path.isfile(expanded_path):
        raise AquaFileNotFoundError("Invalid input file path. Specify a valid one.")

    file_format = Path(expanded_path).suffix.lstrip(".")
    if file_format not in SUPPORTED_FILE_FORMATS:
        raise AquaValueError(
            f"Invalid input file. Only {', '.join(SUPPORTED_FILE_FORMATS)} files are supported."
        )

    size_in_bytes = os.path.getsize(expanded_path)
    if size_in_bytes == 0:
        raise AquaValueError("Empty input file. Specify a valid file path.")
    if size_in_bytes > MAXIMUM_ALLOWED_DATASET_IN_BYTE:
        raise AquaValueError(
            f"Local dataset file can't exceed {MAXIMUM_ALLOWED_DATASET_IN_BYTE} bytes."
        )

    upload_to_os(
        src_uri=expanded_path,
        dst_uri=dst_uri,
        auth=auth,
        force_overwrite=force_overwrite,
    )
|
480
|
+
|
481
|
+
|
482
|
+
def sanitize_response(oci_client, response: list):
    """Serializes a response with the OCI client's own base client.

    Parameters
    ----------
    oci_client
        OCI client object

    response
        list of results from the OCI client

    Returns
    -------
        The value returned by `base_client.sanitize_for_serialization`.

    """
    serialize = oci_client.base_client.sanitize_for_serialization
    return serialize(response)
|
499
|
+
|
500
|
+
|
501
|
+
def _build_resource_identifier(
    id: str = None, name: str = None, region: str = None
) -> AquaResourceIdentifier:
    """Builds an AquaResourceIdentifier (id, name, console url) for a resource.

    Returns an empty AquaResourceIdentifier when anything fails along the way.
    """
    try:
        console_resource = CONSOLE_LINK_RESOURCE_TYPE_MAPPING.get(
            get_resource_type(id)
        )
        console_url = get_console_link(
            resource=console_resource,
            ocid=id,
            region=region,
        )
        return AquaResourceIdentifier(id=id, name=name, url=console_url)
    except Exception as e:
        logger.debug(
            f"Failed to construct AquaResourceIdentifier from given id=`{id}`, and name=`{name}`, {str(e)}"
        )
        return AquaResourceIdentifier()
|
522
|
+
|
523
|
+
|
524
|
+
def _get_experiment_info(
    model: Union[oci.resource_search.models.ResourceSummary, DataScienceModel],
) -> tuple:
    """Returns the (id, name) pair of the model version set the model belongs to."""
    if isinstance(model, oci.resource_search.models.ResourceSummary):
        # Search results carry the version set info in their additional details.
        details = model.additional_details
        return (
            details.get(RqsAdditionalDetails.MODEL_VERSION_SET_ID),
            details.get(RqsAdditionalDetails.MODEL_VERSION_SET_NAME),
        )
    return (model.model_version_set_id, model.model_version_set_name)
|
536
|
+
|
537
|
+
|
538
|
+
def _build_job_identifier(
    job_run_details: Union[
        oci.data_science.models.JobRun, oci.resource_search.models.ResourceSummary
    ] = None,
    **kwargs,
) -> AquaResourceIdentifier:
    """Builds an AquaResourceIdentifier from job run details.

    The id is read from `id` for a JobRun and from `identifier` otherwise.
    Returns an empty AquaResourceIdentifier when the details cannot be read.
    """
    try:
        if isinstance(job_run_details, oci.data_science.models.JobRun):
            job_id = job_run_details.id
        else:
            job_id = job_run_details.identifier
        return _build_resource_identifier(
            id=job_id, name=job_run_details.display_name, **kwargs
        )
    except Exception as e:
        logger.debug(
            f"Failed to get job details from job_run_details: {job_run_details}"
            f"DEBUG INFO:{str(e)}"
        )
        return AquaResourceIdentifier()
|
560
|
+
|
561
|
+
|
562
|
+
def get_max_version(versions):
    """Returns the highest dot-separated numeric version from the list.

    Returns `UNKNOWN` for an empty input. When two versions agree on all shared
    parts, the one with more parts wins; on a full tie the later one is kept.
    """
    if not versions:
        return UNKNOWN

    def pick_higher(left, right):
        # Compare part by part as integers so that "1.10" ranks above "1.9".
        left_parts = [int(part) for part in left.split(".")]
        right_parts = [int(part) for part in right.split(".")]

        for left_part, right_part in zip(left_parts, right_parts):
            if left_part != right_part:
                return right if left_part < right_part else left

        # All shared parts are equal: prefer the version with more parts.
        return left if len(left_parts) > len(right_parts) else right

    highest = versions[0]
    for candidate in versions[1:]:
        highest = pick_higher(highest, candidate)

    return highest
|
587
|
+
|
588
|
+
|
589
|
+
def fire_and_forget(func):
    """Decorator to push execution of methods to the background.

    The decorated callable is submitted to the event loop's default executor
    and the resulting future is returned immediately; callers may await it or
    ignore it.
    """
    from functools import partial

    @wraps(func)
    def wrapped(*args, **kwargs):
        # run_in_executor only forwards positional arguments, so keyword
        # arguments have to be bound up-front with functools.partial.
        return asyncio.get_event_loop().run_in_executor(
            None, partial(func, *args, **kwargs)
        )

    return wrapped
|
597
|
+
|
598
|
+
|
599
|
+
def extract_id_and_name_from_tag(tag: str):
    """Splits a ``<service_model_id>#<service_model_name>`` tag into its two parts.

    Returns
    -------
    tuple
        ``(base_model_ocid, base_model_name)``. Both are ``UNKNOWN`` when the
        tag cannot be split into exactly two ``#``-separated pieces. A debug
        message is logged when the id is not a valid OCID or the name is empty;
        the values are still returned as parsed.
    """
    base_model_ocid, base_model_name = UNKNOWN, UNKNOWN

    try:
        pieces = tag.split("#")
    except Exception:
        # Anything that cannot be split is treated as "no usable pieces".
        pieces = []

    if len(pieces) == 2:
        base_model_ocid, base_model_name = pieces

    tag_is_usable = is_valid_ocid(base_model_ocid) and base_model_name
    if not tag_is_usable:
        logger.debug(
            f"Invalid {tag}. Specify tag in the format as <service_model_id>#<service_model_name>."
        )

    return base_model_ocid, base_model_name
|
613
|
+
|
614
|
+
|
615
|
+
def get_resource_name(ocid: str) -> str:
    """Gets resource name based on the given ocid.

    Parameters
    ----------
    ocid: str
        Oracle Cloud Identifier (OCID).

    Returns
    -------
    str:
        The display name of the resource found for the ocid, or ``UNKNOWN``
        when the lookup returns nothing or fails.

    Raises
    -------
    ValueError:
        When the given ocid is not a valid ocid.
    """
    if not is_valid_ocid(ocid):
        # The two literals are concatenated, so the first one must end with a
        # space to keep the sentences apart in the final message.
        raise ValueError(
            f"The given ocid {ocid} is not a valid ocid. "
            "Check out this page https://docs.oracle.com/en-us/iaas/Content/General/Concepts/identifiers.htm to see more details."
        )
    try:
        resource = query_resource(ocid, return_all=False)
        name = resource.display_name if resource else UNKNOWN
    except Exception:
        # Best-effort lookup: any failure degrades to UNKNOWN instead of raising.
        name = UNKNOWN
    return name
|
644
|
+
|
645
|
+
|
646
|
+
def get_model_by_reference_paths(model_file_description: dict):
    """Extracts the base model path and the fine-tuned output path for models
    created by reference.

    Parameters
    ----------
    model_file_description: dict
        json dict containing model paths and objects for models created by reference.
        Its ``"models"`` entries are read for ``bucketName``, ``namespace`` and ``prefix``.

    Returns
    -------
    a tuple with base_model_path and fine_tune_output_path; the latter is
    ``UNKNOWN`` when only one model entry is present.

    Raises
    ------
    AquaValueError
        When the ``"models"`` entry is empty.
    """
    models = model_file_description["models"]

    if not models:
        raise AquaValueError(
            "Model path is not available in the model json artifact. "
            "Please check if the model created by reference has the correct artifact."
        )

    def to_oci_uri(artifact):
        # oci://<bucket>@<namespace>/<prefix> without a trailing slash
        uri = f"oci://{artifact['bucketName']}@{artifact['namespace']}/{artifact['prefix']}"
        return uri.rstrip("/")

    # The json has no flag to identify the base model, so the first entry is
    # taken as the base model and the second (if any) as the fine-tuned one.
    base_model_path = to_oci_uri(models[0])
    fine_tune_output_path = to_oci_uri(models[1]) if len(models) > 1 else UNKNOWN

    return base_model_path, fine_tune_output_path
|
684
|
+
|
685
|
+
|
686
|
+
def _is_valid_mvs(mvs: ModelVersionSet, target_tag: str) -> bool:
    """Tells whether the model version set carries the target tag.

    Parameters
    ----------
    mvs: ModelVersionSet
        The instance of `ads.model.ModelVersionSet`.
    target_tag: str
        Tag key expected to be present in the freeform tags of the MVS.

    Returns
    -------
    bool:
        True when ``freeform_tags`` is set and contains ``target_tag``.
    """
    tags = mvs.freeform_tags
    return tags is not None and target_tag in tags
|
705
|
+
|
706
|
+
|
707
|
+
def known_realm():
    """Checks whether the ``CONDA_BUCKET_NS`` environment variable holds one of
    the namespaces listed in ``AQUA_GA_LIST``.

    Returns
    -------
    bool:
        True when the current namespace is in ``AQUA_GA_LIST``.
    """
    namespace = os.environ.get("CONDA_BUCKET_NS")
    return namespace in AQUA_GA_LIST
|
716
|
+
|
717
|
+
|
718
|
+
def get_ocid_substring(ocid: str, key_len: int) -> str:
    """Returns the last ``key_len`` characters of the ocid.

    An empty string is returned when the ocid is empty/None or when it is not
    strictly longer than ``key_len``.
    """
    if not ocid or len(ocid) <= key_len:
        return ""
    return ocid[-key_len:]
|
722
|
+
|
723
|
+
|
724
|
+
def upload_folder(
    os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None
) -> str:
    """Upload the local folder to the object storage using ``oci os object bulk-upload``.

    Args:
        os_path (str): object storage URI with prefix. This is the path to upload
        local_dir (str): Local directory where the object is downloaded
        model_name (str): Name of the huggingface model
        exclude_pattern (optional, str): The matching pattern of files to be excluded from uploading.
    Returns:
        str: ``oci://`` URI of the uploaded folder (prefix of ``os_path`` plus ``model_name``).
    Raises:
        ValueError: when ``is_bucket_versioned()`` reports False for the target location.

    A failing upload command is logged as an error and is not re-raised; the
    destination URI is returned regardless.
    """
    os_details: ObjectStorageDetails = ObjectStorageDetails.from_path(os_path)
    if not os_details.is_bucket_versioned():
        raise ValueError(f"Version is not enabled at object storage location {os_path}")

    auth_state = AuthState()
    object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"

    # Assemble the CLI invocation piece by piece; it is tokenized with shlex below.
    command = (
        f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} "
        f"-bn {os_details.bucket} -ns {os_details.namespace} "
        f"--auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite"
    )
    if exclude_pattern:
        command = command + f" --exclude {exclude_pattern}"

    try:
        logger.info(f"Running: {command}")
        subprocess.check_call(shlex.split(command))
    except subprocess.CalledProcessError as e:
        logger.error(
            f"Error uploading the object. Exit code: {e.returncode} with error {e.stdout}"
        )

    bucket_uri = f"oci://{os_details.bucket}@{os_details.namespace}"
    return bucket_uri + "/" + object_path
|
754
|
+
|
755
|
+
|
756
|
+
def cleanup_local_hf_model_artifact(
    model_name: str,
    local_dir: str = None,
):
    """
    Deletes local artifacts downloaded from Hugging Face to free up disk space.

    Two locations are cleaned: the model folder under ``local_dir`` (only when
    ``local_dir`` is given and exists) and the model's folder in the Hugging
    Face hub cache. Deletion errors are ignored; the outcome is logged at
    debug level.

    Parameters
    ----------
    model_name (str): Name of the huggingface model
    local_dir (str): Local directory where the object is downloaded

    """
    if local_dir and os.path.exists(local_dir):
        joined = os.path.join(local_dir, model_name)
        # For names containing a path separator (e.g. "org/model"), the parent
        # of the joined path is the directory that gets removed.
        if "/" in model_name or os.sep in model_name:
            model_dir = os.path.dirname(joined)
        else:
            model_dir = joined

        shutil.rmtree(model_dir, ignore_errors=True)
        if not os.path.exists(model_dir):
            logger.debug(f"Deleted local model artifact directory: {model_dir}.")
        else:
            logger.debug(
                f"Could not delete local model artifact directory: {model_dir}"
            )

    cache_folder = repo_folder_name(repo_id=model_name, repo_type="model")
    hf_local_path = os.path.join(HF_HUB_CACHE, cache_folder)
    shutil.rmtree(hf_local_path, ignore_errors=True)

    if not os.path.exists(hf_local_path):
        logger.debug(
            f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}."
        )
    else:
        logger.debug(
            f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}."
        )
|
796
|
+
|
797
|
+
|
798
|
+
def is_service_managed_container(container):
    """Checks whether the container URI starts with the service managed scheme.

    A falsy ``container`` (``None`` or empty string) is returned unchanged;
    otherwise the result of the prefix check is returned.
    """
    if not container:
        return container
    return container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
|
800
|
+
|
801
|
+
|
802
|
+
def get_params_list(params: str) -> List[str]:
    """Splits a parameter string into individual ``--`` prefixed parameters.

    Parameters
    ----------
    params
        string of parameters separated by the -- delimiter

    Returns
    -------
    list of params, each stripped of surrounding whitespace and starting with
    ``--``. Text before the first ``--`` is discarded.

    """
    if not params:
        return []

    # The first chunk is whatever precedes the first "--" and is dropped.
    _, *chunks = params.split("--")
    return [f"--{chunk.strip()}" for chunk in chunks]
|
818
|
+
|
819
|
+
|
820
|
+
def get_params_dict(params: Union[str, List[str]]) -> dict:
    """Converts double-dash parameters into a dict of parameter name to value.

    Parameters
    ----------
    params:
        Either a parameter string (split with ``get_params_list``) or a list
        where each item holds one parameter, e.g. ``"--name value"``.

    Returns
    -------
    dict containing parameter keys and values. The first whitespace-separated
    token of each item is the key; the remaining tokens, joined by a single
    space, are the value. Items without a value map to ``UNKNOWN``.

    """
    entries = get_params_list(params) if isinstance(params, str) else params

    parsed = {}
    for entry in entries:
        tokens = entry.split()
        name = tokens[0]
        parsed[name] = " ".join(tokens[1:]) if len(tokens) > 1 else UNKNOWN
    return parsed
|
840
|
+
|
841
|
+
|
842
|
+
def get_combined_params(params1: str = None, params2: str = None) -> str:
    """
    Merges two double-dash parameter strings; values from ``params2`` take
    precedence over those from ``params1``.

    Parameters
    ----------
    params1:
        Parameter string with values
    params2:
        Parameter string whose values override the ones in ``params1``.

    Returns
    -------
    A single space-separated parameter string. When either input is empty the
    other one is returned unchanged.
    """
    if not params1:
        return params2
    if not params2:
        return params1

    # Start from params1 and let params2 overwrite matching keys.
    merged = get_params_dict(params1)
    merged.update(get_params_dict(params2))

    return " ".join(
        f"{key} {value}" if value else key for key, value in merged.items()
    )
|
871
|
+
|
872
|
+
|
873
|
+
def build_params_string(params: dict) -> str:
    """Renders a params dict as a single parameter string.

    Parameters
    ----------
    params:
        Parameter dict with key-value pairs

    Returns
    -------
    A params string where each entry is rendered as ``"<name> <value>"``, or
    just ``"<name>"`` when the value is falsy. ``UNKNOWN`` is returned for an
    empty dict.
    """
    if not params:
        return UNKNOWN

    pieces = []
    for name, value in params.items():
        pieces.append(f"{name} {value}" if value else f"{name}")
    return " ".join(pieces).strip()
|
892
|
+
|
893
|
+
|
894
|
+
def copy_model_config(artifact_path: str, os_path: str, auth: dict = None):
    """Copies the aqua model config folder from the artifact path to the user provided object storage path.
    Files are copied with ``force_overwrite=True``, so existing files at the destination are replaced.

    Parameters
    ----------
    artifact_path:
        Path of the aqua model where config folder is available.
    os_path:
        User provided path where config folder will be copied.
    auth: (Dict, optional). Defaults to None.
        Passed through to ``copy_file`` for each copied object.

    Returns
    -------
    None
        Nothing. Any failure is logged at debug level and not re-raised.
    """

    try:
        # The config folder sits next to the artifact inside the service bucket.
        config_prefix = f"{os.path.dirname(artifact_path).rstrip('/')}/config"
        source_dir = ObjectStorageDetails(
            AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, config_prefix
        ).path
        dest_dir = f"{os_path.rstrip('/')}/config"

        listing = ObjectStorageDetails.from_path(source_dir).list_objects(
            fields="name"
        )

        for entry in listing.objects:
            source_path = ObjectStorageDetails(
                AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, entry.name
            ).path
            # Only the file name is kept, so files land directly in the dest folder.
            destination_path = os.path.join(dest_dir, os.path.basename(entry.name))
            copy_file(
                uri_src=source_path,
                uri_dst=destination_path,
                force_overwrite=True,
                auth=auth,
            )
    except Exception as ex:
        logger.debug(ex)
        logger.debug(f"Failed to copy config folder from {artifact_path} to {os_path}.")
|
940
|
+
|
941
|
+
|
942
|
+
def get_container_params_type(container_type_name: str) -> str:
    """Maps the deployment container type name to the corresponding params name.

    Parameters
    ----------
    container_type_name: str
        type of deployment container, like odsc-vllm-serving or odsc-tgi-serving.

    Returns
    -------
        InferenceContainerParamType value, or ``UNKNOWN`` when the name matches
        none of the vLLM, TGI or llama.cpp container types.

    """
    # Substring match (case-insensitive on the name) rather than equality, in
    # case container_type_name changes in the future.
    lowered = container_type_name.lower()

    if InferenceContainerType.CONTAINER_TYPE_VLLM in lowered:
        return InferenceContainerParamType.PARAM_TYPE_VLLM
    if InferenceContainerType.CONTAINER_TYPE_TGI in lowered:
        return InferenceContainerParamType.PARAM_TYPE_TGI
    if InferenceContainerType.CONTAINER_TYPE_LLAMA_CPP in lowered:
        return InferenceContainerParamType.PARAM_TYPE_LLAMA_CPP
    return UNKNOWN
|
963
|
+
|
964
|
+
|
965
|
+
def get_restricted_params_by_container(container_type_name: str) -> set:
    """Looks up the restricted params for the given deployment container type.

    Parameters
    ----------
    container_type_name: str
        type of deployment container, like odsc-vllm-serving or odsc-tgi-serving.

    Returns
    -------
        Set of restricted params for vLLM or TGI containers; an empty set for
        any other container type.

    """
    # Substring match (case-insensitive on the name) rather than equality, in
    # case container_type_name changes in the future.
    lowered = container_type_name.lower()

    if InferenceContainerType.CONTAINER_TYPE_VLLM in lowered:
        return VLLM_INFERENCE_RESTRICTED_PARAMS
    if InferenceContainerType.CONTAINER_TYPE_TGI in lowered:
        return TGI_INFERENCE_RESTRICTED_PARAMS
    return set()
|
985
|
+
|
986
|
+
|
987
|
+
def get_huggingface_login_timeout() -> int:
    """Reads the huggingface login timeout from the ``HF_LOGIN_DEFAULT_TIMEOUT``
    environment variable.

    Returns
    -------
    timeout: int
        The env var value converted to int; ``HF_LOGIN_DEFAULT_TIMEOUT`` (the
        module constant) when the variable is unset or not a valid integer.

    """
    raw_value = os.environ.get("HF_LOGIN_DEFAULT_TIMEOUT", HF_LOGIN_DEFAULT_TIMEOUT)
    try:
        return int(raw_value)
    except ValueError:
        # Unparsable value in the environment: fall back to the default.
        return HF_LOGIN_DEFAULT_TIMEOUT
|
1004
|
+
|
1005
|
+
|
1006
|
+
def format_hf_custom_error_message(error: HfHubHTTPError):
    """
    Raises a user-friendly ``AquaRuntimeError`` based on the Hugging Face error response.

    This function never returns: every code path raises.

    Parameters
    ----------
    error (HfHubHTTPError): The caught exception.

    Raises
    ------
    AquaRuntimeError: A user-friendly error message. ``service_payload["error"]``
        names the detected error kind (``GatedRepoError``, ``RepositoryNotFoundError``,
        ``RevisionNotFoundError`` or the generic ``Error``).
    """
    # Extract the repository URL from the error message if present
    match = re.search(r"(https://huggingface.co/[^\s]+)", str(error))
    url = match.group(1) if match else "the requested Hugging Face URL."

    # GatedRepoError is a subclass of RepositoryNotFoundError in huggingface_hub,
    # so it has to be tested first; otherwise the gated branch is unreachable.
    if isinstance(error, GatedRepoError):
        raise AquaRuntimeError(
            reason=f"Access denied to `{url}` "
            "This repository is gated. Access is restricted to authorized users. "
            "Please request access or check with the repository administrator. "
            "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
            "To register your token, run this command in your terminal: `huggingface-cli login`",
            service_payload={"error": "GatedRepoError"},
        )

    if isinstance(error, RepositoryNotFoundError):
        raise AquaRuntimeError(
            reason=f"Failed to access `{url}`. Please check if the provided repository name is correct. "
            "If the repo is private, make sure you are authenticated and have a valid HF token registered. "
            "To register your token, run this command in your terminal: `huggingface-cli login`",
            service_payload={"error": "RepositoryNotFoundError"},
        )

    if isinstance(error, RevisionNotFoundError):
        raise AquaRuntimeError(
            reason=f"The specified revision could not be found at `{url}` "
            "Please check the revision identifier and try again.",
            service_payload={"error": "RevisionNotFoundError"},
        )

    # Fallback for any other Hugging Face HTTP error.
    raise AquaRuntimeError(
        reason=f"An error occurred while accessing `{url}` "
        "Please check your network connection and try again. "
        "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
        "To register your token, run this command in your terminal: `huggingface-cli login`",
        service_payload={"error": "Error"},
    )
|
1054
|
+
|
1055
|
+
|
1056
|
+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
def get_hf_model_info(repo_id: str) -> ModelInfo:
    """Fetches the model information object for the given Hugging Face repository.
    For models that require a token, this method assumes that the token
    validation is already done. The result is cached (single entry, 5 hour TTL).

    Parameters
    ----------
    repo_id: str
        hugging face model repository name

    Returns
    -------
        instance of ModelInfo object

    Raises
    ------
    AquaRuntimeError
        Raised by ``format_hf_custom_error_message`` when the hub call fails
        with an ``HfHubHTTPError``.
    """
    try:
        model_info = HfApi().model_info(repo_id=repo_id)
    except HfHubHTTPError as hub_error:
        raise format_hf_custom_error_message(hub_error) from hub_error
    return model_info
|
1075
|
+
|
1076
|
+
|
1077
|
+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
def list_hf_models(query: str) -> List[str]:
    """Searches Hugging Face for models matching the query.

    Up to 20 results are requested, sorted by downloads in descending order;
    only models whose ``disabled`` attribute is None are kept. The result is
    cached (single entry, 5 hour TTL).

    Returns
    -------
    List[str]
        Ids of the matching models.

    Raises
    ------
    AquaRuntimeError
        Raised by ``format_hf_custom_error_message`` when the hub call fails
        with an ``HfHubHTTPError``.
    """
    try:
        candidates = HfApi().list_models(
            model_name=query,
            sort="downloads",
            direction=-1,
            limit=20,
        )
        # Iterate inside the try block: the listing may be fetched lazily.
        enabled_ids = []
        for candidate in candidates:
            if candidate.disabled is None:
                enabled_ids.append(candidate.id)
        return enabled_ids
    except HfHubHTTPError as hub_error:
        raise format_hf_custom_error_message(hub_error) from hub_error
|
1089
|
+
|
1090
|
+
|
1091
|
+
def generate_tei_cmd_var(os_path: str) -> List[str]:
    """Generates the CMD params for the Text Embedding Inference container. Only the
    essential parameters for OCI model deployment are added, defaults are used for the rest.

    Parameters
    ----------
    os_path: str
        OCI bucket path where the model artifacts are uploaded - oci://bucket@namespace/prefix

    Returns
    -------
    cmd_var:
        List of command line arguments: the model-id flag followed by the
        artifact folder, then the port flag followed by
        ``TEI_CONTAINER_DEFAULT_HOST``.
    """
    flag_prefix = "--"
    # Prefix of the artifacts inside the bucket (trailing slash of os_path ignored).
    artifact_prefix = ObjectStorageDetails.from_path(os_path.rstrip("/")).filepath

    model_id_flag = f"{flag_prefix}{TextEmbeddingInferenceContainerParams.MODEL_ID}"
    model_folder = f"{AQUA_MODEL_DEPLOYMENT_FOLDER}{artifact_prefix}/"
    port_flag = f"{flag_prefix}{TextEmbeddingInferenceContainerParams.PORT}"

    return [model_id_flag, model_folder, port_flag, TEI_CONTAINER_DEFAULT_HOST]
|
1114
|
+
|
1115
|
+
|
1116
|
+
def parse_cmd_var(cmd_list: List[str]) -> dict:
    """Parses a list of command line tokens into a key-value dictionary.

    Every token starting with ``--`` becomes a key. Its value is the next
    token when that token does not itself start with ``--``; otherwise the
    value is None. Tokens that are neither a key nor the value directly
    following a key are ignored.
    """
    parsed = {}
    total = len(cmd_list)

    for position, token in enumerate(cmd_list):
        if not token.startswith("--"):
            continue
        following = position + 1
        has_value = following < total and not cmd_list[following].startswith("--")
        parsed[token] = cmd_list[following] if has_value else None

    return parsed
|
1130
|
+
|
1131
|
+
|
1132
|
+
def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
    """Combines the default CMD parameters with the user supplied ones, refusing
    any user parameter whose name/key is already part of the defaults.

    Parameters
    ----------
    cmd_var: List[str]
        Default list of parameters
    overrides: List[str]
        List of additional parameters supplied by the user

    Returns
    -------
        List[str] of combined parameters (all items converted to str); just the
        stringified defaults when ``overrides`` is empty.

    Raises
    ------
    AquaValueError
        When ``overrides`` contains a ``--`` key that is also present in ``cmd_var``.
    """
    defaults = list(map(str, cmd_var))
    if not overrides:
        return defaults
    extras = list(map(str, overrides))

    # A key present on both sides would silently change a protected default.
    default_keys = set(parse_cmd_var(defaults).keys())
    extra_keys = set(parse_cmd_var(extras).keys())
    common_keys = default_keys & extra_keys
    if common_keys:
        raise AquaValueError(
            f"The following CMD input cannot be overridden for model deployment: {', '.join(common_keys)}"
        )

    return defaults + extras
|
1162
|
+
|
1163
|
+
|
1164
|
+
def build_pydantic_error_message(ex: ValidationError):
    """
    Builds the error message for errors raised by the pydantic model validator.

    Returns a dict mapping the dotted ``loc`` (field path) to ``msg`` for every
    error that has a non-empty ``loc``. When no error carries a ``loc``, the
    ``msg`` values of all errors are joined with "; " and returned as a string.
    """
    details = ex.errors()

    by_field = {}
    for detail in details:
        location = detail.get("loc")
        if location:
            by_field[".".join(map(str, location))] = detail["msg"]

    if by_field:
        return by_field
    return "; ".join(detail["msg"] for detail in details)
|
1176
|
+
|
1177
|
+
|
1178
|
+
def is_pydantic_model(obj: object) -> bool:
    """
    Tells whether obj is a Pydantic model class or an instance of one.

    Args:
        obj: The object or class to check.

    Returns:
        bool: True if obj is a subclass or instance of BaseModel, False otherwise.
    """
    # Classes are checked as-is; for instances their type is checked.
    candidate = type(obj)
    if isinstance(obj, type):
        candidate = obj
    return issubclass(candidate, BaseModel)
|
1190
|
+
|
1191
|
+
|
1192
|
+
@cached(
    cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now),
    # `auth` is a dict (unhashable) and only carries credentials, so it must not
    # take part in the cache key; every call shares the single cache slot.
    key=lambda *args, **kwargs: "gpu_shapes_index",
)
def load_gpu_shapes_index(
    auth: Optional[Dict[str, Any]] = None,
) -> GPUShapesIndex:
    """
    Load the GPU shapes index, preferring the OS bucket copy over the local one.

    Attempts to read `gpu_shapes_index.json` from OCI Object Storage first
    (only when `CONDA_BUCKET_NS` is set); the local copy bundled under
    `../resources` is always read as well, and remote entries override the
    local ones. The merged result is cached for 5 minutes regardless of the
    `auth` value passed.

    Parameters
    ----------
    auth
        Optional auth dict (as returned by `ads.common.auth.default_signer()`)
        to pass through to `fsspec.open()`. Defaults to the default signer.

    Returns
    -------
    GPUShapesIndex
        Merged index where any shape present remotely supersedes the local entry.

    Notes
    -----
    Failures while reading or parsing either copy (including malformed JSON)
    are logged at debug level and treated as "no shapes from that source";
    they are not raised to the caller.
    """
    file_name = "gpu_shapes_index.json"

    # Try remote load
    remote_data: Dict[str, Any] = {}
    if CONDA_BUCKET_NS:
        try:
            auth = auth or authutil.default_signer()
            storage_path = (
                f"oci://{CONDA_BUCKET_NAME}@{CONDA_BUCKET_NS}/service_pack/{file_name}"
            )
            logger.debug(
                "Loading GPU shapes index from Object Storage: %s", storage_path
            )
            with fsspec.open(storage_path, mode="r", **auth) as f:
                remote_data = json.load(f)
            logger.debug(
                "Loaded %d shapes from Object Storage",
                len(remote_data.get("shapes", {})),
            )
        except Exception as ex:
            logger.debug("Remote load failed (%s); falling back to local", ex)

    # Load local copy
    local_data: Dict[str, Any] = {}
    local_path = os.path.join(os.path.dirname(__file__), "../resources", file_name)
    try:
        logger.debug("Loading GPU shapes index from local file: %s", local_path)
        with open(local_path) as f:
            local_data = json.load(f)
        logger.debug(
            "Loaded %d shapes from local file", len(local_data.get("shapes", {}))
        )
    except Exception as ex:
        logger.debug("Local load GPU shapes index failed (%s)", ex)

    # Merge: remote shapes override local
    local_shapes = local_data.get("shapes", {})
    remote_shapes = remote_data.get("shapes", {})
    merged_shapes = {**local_shapes, **remote_shapes}

    return GPUShapesIndex(shapes=merged_shapes)
|
1259
|
+
|
1260
|
+
|
1261
|
+
def get_preferred_compatible_family(selected_families: set[str]) -> str:
    """
    Picks the preferred container family for a set of container families.

    This is used in the context of multi-model deployment to handle cases
    where models selected for deployment use different, but compatible, container families.

    The input set is checked against each entry of `CONTAINER_FAMILY_COMPATIBILITY`
    in order. The key of the first entry whose compatibility list contains every
    selected family is returned.

    Parameters
    ----------
    selected_families : set[str]
        A set of container family identifiers.

    Returns
    -------
    Optional[str]
        The preferred container family if all families are compatible within one group;
        otherwise `None`, indicating that no compatible family group was found.

    Example
    -------
    >>> get_preferred_compatible_family({"odsc-vllm-serving", "odsc-vllm-serving-v1"})
    'odsc-vllm-serving-v1'

    >>> get_preferred_compatible_family({"odsc-vllm-serving", "odsc-tgi-serving"})
    None  # Incompatible families
    """
    matching_families = (
        preferred
        for preferred, compatible_list in CONTAINER_FAMILY_COMPATIBILITY.items()
        if selected_families.issubset(set(compatible_list))
    )
    return next(matching_families, None)
|