oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.10rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +40 -0
- ads/aqua/app.py +507 -0
- ads/aqua/cli.py +96 -0
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +836 -0
- ads/aqua/client/openai_client.py +305 -0
- ads/aqua/common/__init__.py +5 -0
- ads/aqua/common/decorator.py +125 -0
- ads/aqua/common/entities.py +274 -0
- ads/aqua/common/enums.py +134 -0
- ads/aqua/common/errors.py +109 -0
- ads/aqua/common/utils.py +1295 -0
- ads/aqua/config/__init__.py +4 -0
- ads/aqua/config/container_config.py +247 -0
- ads/aqua/config/evaluation/__init__.py +4 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
- ads/aqua/config/utils/__init__.py +4 -0
- ads/aqua/config/utils/serializer.py +339 -0
- ads/aqua/constants.py +116 -0
- ads/aqua/data.py +14 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation/__init__.py +8 -0
- ads/aqua/evaluation/constants.py +53 -0
- ads/aqua/evaluation/entities.py +186 -0
- ads/aqua/evaluation/errors.py +70 -0
- ads/aqua/evaluation/evaluation.py +1814 -0
- ads/aqua/extension/__init__.py +42 -0
- ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
- ads/aqua/extension/base_handler.py +90 -0
- ads/aqua/extension/common_handler.py +121 -0
- ads/aqua/extension/common_ws_msg_handler.py +36 -0
- ads/aqua/extension/deployment_handler.py +381 -0
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +30 -0
- ads/aqua/extension/evaluation_handler.py +129 -0
- ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
- ads/aqua/extension/finetune_handler.py +96 -0
- ads/aqua/extension/model_handler.py +390 -0
- ads/aqua/extension/models/__init__.py +0 -0
- ads/aqua/extension/models/ws_models.py +145 -0
- ads/aqua/extension/models_ws_msg_handler.py +50 -0
- ads/aqua/extension/ui_handler.py +300 -0
- ads/aqua/extension/ui_websocket_handler.py +130 -0
- ads/aqua/extension/utils.py +133 -0
- ads/aqua/finetuning/__init__.py +7 -0
- ads/aqua/finetuning/constants.py +23 -0
- ads/aqua/finetuning/entities.py +181 -0
- ads/aqua/finetuning/finetuning.py +749 -0
- ads/aqua/model/__init__.py +8 -0
- ads/aqua/model/constants.py +60 -0
- ads/aqua/model/entities.py +385 -0
- ads/aqua/model/enums.py +32 -0
- ads/aqua/model/model.py +2134 -0
- ads/aqua/model/utils.py +52 -0
- ads/aqua/modeldeployment/__init__.py +6 -0
- ads/aqua/modeldeployment/constants.py +10 -0
- ads/aqua/modeldeployment/deployment.py +1315 -0
- ads/aqua/modeldeployment/entities.py +653 -0
- ads/aqua/modeldeployment/utils.py +543 -0
- ads/aqua/resources/gpu_shapes_index.json +94 -0
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +476 -0
- ads/aqua/ui.py +519 -0
- ads/automl/__init__.py +9 -0
- ads/automl/driver.py +330 -0
- ads/automl/provider.py +975 -0
- ads/bds/__init__.py +5 -0
- ads/bds/auth.py +127 -0
- ads/bds/big_data_service.py +255 -0
- ads/catalog/__init__.py +19 -0
- ads/catalog/model.py +1576 -0
- ads/catalog/notebook.py +461 -0
- ads/catalog/project.py +468 -0
- ads/catalog/summary.py +178 -0
- ads/common/__init__.py +11 -0
- ads/common/analyzer.py +65 -0
- ads/common/artifact/.model-ignore +63 -0
- ads/common/artifact/__init__.py +10 -0
- ads/common/auth.py +1122 -0
- ads/common/card_identifier.py +83 -0
- ads/common/config.py +647 -0
- ads/common/data.py +165 -0
- ads/common/decorator/__init__.py +9 -0
- ads/common/decorator/argument_to_case.py +88 -0
- ads/common/decorator/deprecate.py +69 -0
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/decorator/runtime_dependency.py +178 -0
- ads/common/decorator/threaded.py +97 -0
- ads/common/decorator/utils.py +35 -0
- ads/common/dsc_file_system.py +303 -0
- ads/common/error.py +14 -0
- ads/common/extended_enum.py +81 -0
- ads/common/function/__init__.py +5 -0
- ads/common/function/fn_util.py +142 -0
- ads/common/function/func_conf.yaml +25 -0
- ads/common/ipython.py +76 -0
- ads/common/model.py +679 -0
- ads/common/model_artifact.py +1759 -0
- ads/common/model_artifact_schema.json +107 -0
- ads/common/model_export_util.py +664 -0
- ads/common/model_metadata.py +24 -0
- ads/common/object_storage_details.py +296 -0
- ads/common/oci_client.py +179 -0
- ads/common/oci_datascience.py +46 -0
- ads/common/oci_logging.py +1144 -0
- ads/common/oci_mixin.py +957 -0
- ads/common/oci_resource.py +136 -0
- ads/common/serializer.py +559 -0
- ads/common/utils.py +1852 -0
- ads/common/word_lists.py +1491 -0
- ads/common/work_request.py +189 -0
- ads/config.py +1 -0
- ads/data_labeling/__init__.py +13 -0
- ads/data_labeling/boundingbox.py +253 -0
- ads/data_labeling/constants.py +47 -0
- ads/data_labeling/data_labeling_service.py +244 -0
- ads/data_labeling/interface/__init__.py +5 -0
- ads/data_labeling/interface/loader.py +16 -0
- ads/data_labeling/interface/parser.py +16 -0
- ads/data_labeling/interface/reader.py +23 -0
- ads/data_labeling/loader/__init__.py +5 -0
- ads/data_labeling/loader/file_loader.py +241 -0
- ads/data_labeling/metadata.py +110 -0
- ads/data_labeling/mixin/__init__.py +5 -0
- ads/data_labeling/mixin/data_labeling.py +232 -0
- ads/data_labeling/ner.py +129 -0
- ads/data_labeling/parser/__init__.py +5 -0
- ads/data_labeling/parser/dls_record_parser.py +388 -0
- ads/data_labeling/parser/export_metadata_parser.py +94 -0
- ads/data_labeling/parser/export_record_parser.py +473 -0
- ads/data_labeling/reader/__init__.py +5 -0
- ads/data_labeling/reader/dataset_reader.py +574 -0
- ads/data_labeling/reader/dls_record_reader.py +121 -0
- ads/data_labeling/reader/export_record_reader.py +62 -0
- ads/data_labeling/reader/jsonl_reader.py +75 -0
- ads/data_labeling/reader/metadata_reader.py +203 -0
- ads/data_labeling/reader/record_reader.py +263 -0
- ads/data_labeling/record.py +52 -0
- ads/data_labeling/visualizer/__init__.py +5 -0
- ads/data_labeling/visualizer/image_visualizer.py +525 -0
- ads/data_labeling/visualizer/text_visualizer.py +357 -0
- ads/database/__init__.py +5 -0
- ads/database/connection.py +338 -0
- ads/dataset/__init__.py +10 -0
- ads/dataset/capabilities.md +51 -0
- ads/dataset/classification_dataset.py +339 -0
- ads/dataset/correlation.py +226 -0
- ads/dataset/correlation_plot.py +563 -0
- ads/dataset/dask_series.py +173 -0
- ads/dataset/dataframe_transformer.py +110 -0
- ads/dataset/dataset.py +1979 -0
- ads/dataset/dataset_browser.py +360 -0
- ads/dataset/dataset_with_target.py +995 -0
- ads/dataset/exception.py +25 -0
- ads/dataset/factory.py +987 -0
- ads/dataset/feature_engineering_transformer.py +35 -0
- ads/dataset/feature_selection.py +107 -0
- ads/dataset/forecasting_dataset.py +26 -0
- ads/dataset/helper.py +1450 -0
- ads/dataset/label_encoder.py +99 -0
- ads/dataset/mixin/__init__.py +5 -0
- ads/dataset/mixin/dataset_accessor.py +134 -0
- ads/dataset/pipeline.py +58 -0
- ads/dataset/plot.py +710 -0
- ads/dataset/progress.py +86 -0
- ads/dataset/recommendation.py +297 -0
- ads/dataset/recommendation_transformer.py +502 -0
- ads/dataset/regression_dataset.py +14 -0
- ads/dataset/sampled_dataset.py +1050 -0
- ads/dataset/target.py +98 -0
- ads/dataset/timeseries.py +18 -0
- ads/dbmixin/__init__.py +5 -0
- ads/dbmixin/db_pandas_accessor.py +153 -0
- ads/environment/__init__.py +9 -0
- ads/environment/ml_runtime.py +66 -0
- ads/evaluations/README.md +14 -0
- ads/evaluations/__init__.py +109 -0
- ads/evaluations/evaluation_plot.py +983 -0
- ads/evaluations/evaluator.py +1334 -0
- ads/evaluations/statistical_metrics.py +543 -0
- ads/experiments/__init__.py +9 -0
- ads/experiments/capabilities.md +0 -0
- ads/explanations/__init__.py +21 -0
- ads/explanations/base_explainer.py +142 -0
- ads/explanations/capabilities.md +83 -0
- ads/explanations/explainer.py +190 -0
- ads/explanations/mlx_global_explainer.py +1050 -0
- ads/explanations/mlx_interface.py +386 -0
- ads/explanations/mlx_local_explainer.py +287 -0
- ads/explanations/mlx_whatif_explainer.py +201 -0
- ads/feature_engineering/__init__.py +20 -0
- ads/feature_engineering/accessor/__init__.py +5 -0
- ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
- ads/feature_engineering/accessor/mixin/__init__.py +5 -0
- ads/feature_engineering/accessor/mixin/correlation.py +166 -0
- ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
- ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
- ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
- ads/feature_engineering/accessor/mixin/utils.py +65 -0
- ads/feature_engineering/accessor/series_accessor.py +431 -0
- ads/feature_engineering/adsimage/__init__.py +5 -0
- ads/feature_engineering/adsimage/image.py +192 -0
- ads/feature_engineering/adsimage/image_reader.py +170 -0
- ads/feature_engineering/adsimage/interface/__init__.py +5 -0
- ads/feature_engineering/adsimage/interface/reader.py +19 -0
- ads/feature_engineering/adsstring/__init__.py +7 -0
- ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
- ads/feature_engineering/adsstring/string/__init__.py +8 -0
- ads/feature_engineering/data_schema.json +57 -0
- ads/feature_engineering/dataset/__init__.py +5 -0
- ads/feature_engineering/dataset/zip_code_data.py +42062 -0
- ads/feature_engineering/exceptions.py +40 -0
- ads/feature_engineering/feature_type/__init__.py +133 -0
- ads/feature_engineering/feature_type/address.py +184 -0
- ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
- ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
- ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
- ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
- ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
- ads/feature_engineering/feature_type/adsstring/string.py +258 -0
- ads/feature_engineering/feature_type/base.py +58 -0
- ads/feature_engineering/feature_type/boolean.py +183 -0
- ads/feature_engineering/feature_type/category.py +146 -0
- ads/feature_engineering/feature_type/constant.py +137 -0
- ads/feature_engineering/feature_type/continuous.py +151 -0
- ads/feature_engineering/feature_type/creditcard.py +314 -0
- ads/feature_engineering/feature_type/datetime.py +190 -0
- ads/feature_engineering/feature_type/discrete.py +134 -0
- ads/feature_engineering/feature_type/document.py +43 -0
- ads/feature_engineering/feature_type/gis.py +251 -0
- ads/feature_engineering/feature_type/handler/__init__.py +5 -0
- ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
- ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
- ads/feature_engineering/feature_type/handler/warnings.py +128 -0
- ads/feature_engineering/feature_type/integer.py +142 -0
- ads/feature_engineering/feature_type/ip_address.py +144 -0
- ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
- ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
- ads/feature_engineering/feature_type/lat_long.py +256 -0
- ads/feature_engineering/feature_type/object.py +43 -0
- ads/feature_engineering/feature_type/ordinal.py +132 -0
- ads/feature_engineering/feature_type/phone_number.py +135 -0
- ads/feature_engineering/feature_type/string.py +171 -0
- ads/feature_engineering/feature_type/text.py +93 -0
- ads/feature_engineering/feature_type/unknown.py +43 -0
- ads/feature_engineering/feature_type/zip_code.py +164 -0
- ads/feature_engineering/feature_type_manager.py +406 -0
- ads/feature_engineering/schema.py +795 -0
- ads/feature_engineering/utils.py +245 -0
- ads/feature_store/.readthedocs.yaml +19 -0
- ads/feature_store/README.md +65 -0
- ads/feature_store/__init__.py +9 -0
- ads/feature_store/common/__init__.py +0 -0
- ads/feature_store/common/enums.py +339 -0
- ads/feature_store/common/exceptions.py +18 -0
- ads/feature_store/common/spark_session_singleton.py +125 -0
- ads/feature_store/common/utils/__init__.py +0 -0
- ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
- ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
- ads/feature_store/common/utils/transformation_utils.py +82 -0
- ads/feature_store/common/utils/utility.py +403 -0
- ads/feature_store/data_validation/__init__.py +0 -0
- ads/feature_store/data_validation/great_expectation.py +129 -0
- ads/feature_store/dataset.py +1230 -0
- ads/feature_store/dataset_job.py +530 -0
- ads/feature_store/docs/Dockerfile +7 -0
- ads/feature_store/docs/Makefile +44 -0
- ads/feature_store/docs/conf.py +28 -0
- ads/feature_store/docs/requirements.txt +14 -0
- ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
- ads/feature_store/docs/source/cicd.rst +137 -0
- ads/feature_store/docs/source/conf.py +86 -0
- ads/feature_store/docs/source/data_versioning.rst +33 -0
- ads/feature_store/docs/source/dataset.rst +388 -0
- ads/feature_store/docs/source/dataset_job.rst +27 -0
- ads/feature_store/docs/source/demo.rst +70 -0
- ads/feature_store/docs/source/entity.rst +78 -0
- ads/feature_store/docs/source/feature_group.rst +624 -0
- ads/feature_store/docs/source/feature_group_job.rst +29 -0
- ads/feature_store/docs/source/feature_store.rst +122 -0
- ads/feature_store/docs/source/feature_store_class.rst +123 -0
- ads/feature_store/docs/source/feature_validation.rst +66 -0
- ads/feature_store/docs/source/figures/cicd.png +0 -0
- ads/feature_store/docs/source/figures/data_validation.png +0 -0
- ads/feature_store/docs/source/figures/data_versioning.png +0 -0
- ads/feature_store/docs/source/figures/dataset.gif +0 -0
- ads/feature_store/docs/source/figures/dataset.png +0 -0
- ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
- ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
- ads/feature_store/docs/source/figures/entity.png +0 -0
- ads/feature_store/docs/source/figures/feature_group.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
- ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
- ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
- ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
- ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
- ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
- ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
- ads/feature_store/docs/source/figures/overview.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
- ads/feature_store/docs/source/figures/stats_1.png +0 -0
- ads/feature_store/docs/source/figures/stats_2.png +0 -0
- ads/feature_store/docs/source/figures/stats_d.png +0 -0
- ads/feature_store/docs/source/figures/stats_fg.png +0 -0
- ads/feature_store/docs/source/figures/transformation.png +0 -0
- ads/feature_store/docs/source/figures/transformations.gif +0 -0
- ads/feature_store/docs/source/figures/validation.png +0 -0
- ads/feature_store/docs/source/figures/validation_fg.png +0 -0
- ads/feature_store/docs/source/figures/validation_results.png +0 -0
- ads/feature_store/docs/source/figures/validation_summary.png +0 -0
- ads/feature_store/docs/source/index.rst +81 -0
- ads/feature_store/docs/source/module.rst +8 -0
- ads/feature_store/docs/source/notebook.rst +94 -0
- ads/feature_store/docs/source/overview.rst +47 -0
- ads/feature_store/docs/source/quickstart.rst +176 -0
- ads/feature_store/docs/source/release_notes.rst +194 -0
- ads/feature_store/docs/source/setup_feature_store.rst +81 -0
- ads/feature_store/docs/source/statistics.rst +58 -0
- ads/feature_store/docs/source/transformation.rst +199 -0
- ads/feature_store/docs/source/ui.rst +65 -0
- ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
- ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
- ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
- ads/feature_store/entity.py +718 -0
- ads/feature_store/execution_strategy/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
- ads/feature_store/execution_strategy/engine/__init__.py +0 -0
- ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
- ads/feature_store/execution_strategy/execution_strategy.py +113 -0
- ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
- ads/feature_store/execution_strategy/spark/__init__.py +0 -0
- ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
- ads/feature_store/feature.py +192 -0
- ads/feature_store/feature_group.py +1494 -0
- ads/feature_store/feature_group_expectation.py +346 -0
- ads/feature_store/feature_group_job.py +602 -0
- ads/feature_store/feature_lineage/__init__.py +0 -0
- ads/feature_store/feature_lineage/graphviz_service.py +180 -0
- ads/feature_store/feature_option_details.py +50 -0
- ads/feature_store/feature_statistics/__init__.py +0 -0
- ads/feature_store/feature_statistics/statistics_service.py +99 -0
- ads/feature_store/feature_store.py +699 -0
- ads/feature_store/feature_store_registrar.py +518 -0
- ads/feature_store/input_feature_detail.py +149 -0
- ads/feature_store/mixin/__init__.py +4 -0
- ads/feature_store/mixin/oci_feature_store.py +145 -0
- ads/feature_store/model_details.py +73 -0
- ads/feature_store/query/__init__.py +0 -0
- ads/feature_store/query/filter.py +266 -0
- ads/feature_store/query/generator/__init__.py +0 -0
- ads/feature_store/query/generator/query_generator.py +298 -0
- ads/feature_store/query/join.py +161 -0
- ads/feature_store/query/query.py +403 -0
- ads/feature_store/query/validator/__init__.py +0 -0
- ads/feature_store/query/validator/query_validator.py +57 -0
- ads/feature_store/response/__init__.py +0 -0
- ads/feature_store/response/response_builder.py +68 -0
- ads/feature_store/service/__init__.py +0 -0
- ads/feature_store/service/oci_dataset.py +139 -0
- ads/feature_store/service/oci_dataset_job.py +199 -0
- ads/feature_store/service/oci_entity.py +125 -0
- ads/feature_store/service/oci_feature_group.py +164 -0
- ads/feature_store/service/oci_feature_group_job.py +214 -0
- ads/feature_store/service/oci_feature_store.py +182 -0
- ads/feature_store/service/oci_lineage.py +87 -0
- ads/feature_store/service/oci_transformation.py +104 -0
- ads/feature_store/statistics/__init__.py +0 -0
- ads/feature_store/statistics/abs_feature_value.py +49 -0
- ads/feature_store/statistics/charts/__init__.py +0 -0
- ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
- ads/feature_store/statistics/charts/box_plot.py +148 -0
- ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
- ads/feature_store/statistics/charts/probability_distribution.py +68 -0
- ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
- ads/feature_store/statistics/feature_stat.py +126 -0
- ads/feature_store/statistics/generic_feature_value.py +33 -0
- ads/feature_store/statistics/statistics.py +41 -0
- ads/feature_store/statistics_config.py +101 -0
- ads/feature_store/templates/feature_store_template.yaml +45 -0
- ads/feature_store/transformation.py +499 -0
- ads/feature_store/validation_output.py +57 -0
- ads/hpo/__init__.py +9 -0
- ads/hpo/_imports.py +91 -0
- ads/hpo/ads_search_space.py +439 -0
- ads/hpo/distributions.py +325 -0
- ads/hpo/objective.py +280 -0
- ads/hpo/search_cv.py +1657 -0
- ads/hpo/stopping_criterion.py +75 -0
- ads/hpo/tuner_artifact.py +413 -0
- ads/hpo/utils.py +91 -0
- ads/hpo/validation.py +140 -0
- ads/hpo/visualization/__init__.py +5 -0
- ads/hpo/visualization/_contour.py +23 -0
- ads/hpo/visualization/_edf.py +20 -0
- ads/hpo/visualization/_intermediate_values.py +21 -0
- ads/hpo/visualization/_optimization_history.py +25 -0
- ads/hpo/visualization/_parallel_coordinate.py +169 -0
- ads/hpo/visualization/_param_importances.py +26 -0
- ads/jobs/__init__.py +53 -0
- ads/jobs/ads_job.py +663 -0
- ads/jobs/builders/__init__.py +5 -0
- ads/jobs/builders/base.py +156 -0
- ads/jobs/builders/infrastructure/__init__.py +6 -0
- ads/jobs/builders/infrastructure/base.py +165 -0
- ads/jobs/builders/infrastructure/dataflow.py +1252 -0
- ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
- ads/jobs/builders/infrastructure/utils.py +65 -0
- ads/jobs/builders/runtimes/__init__.py +5 -0
- ads/jobs/builders/runtimes/artifact.py +338 -0
- ads/jobs/builders/runtimes/base.py +325 -0
- ads/jobs/builders/runtimes/container_runtime.py +242 -0
- ads/jobs/builders/runtimes/python_runtime.py +1016 -0
- ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
- ads/jobs/cli.py +104 -0
- ads/jobs/env_var_parser.py +131 -0
- ads/jobs/extension.py +160 -0
- ads/jobs/schema/__init__.py +5 -0
- ads/jobs/schema/infrastructure_schema.json +116 -0
- ads/jobs/schema/job_schema.json +42 -0
- ads/jobs/schema/runtime_schema.json +183 -0
- ads/jobs/schema/validator.py +141 -0
- ads/jobs/serializer.py +296 -0
- ads/jobs/templates/__init__.py +5 -0
- ads/jobs/templates/container.py +6 -0
- ads/jobs/templates/driver_notebook.py +177 -0
- ads/jobs/templates/driver_oci.py +500 -0
- ads/jobs/templates/driver_python.py +48 -0
- ads/jobs/templates/driver_pytorch.py +852 -0
- ads/jobs/templates/driver_utils.py +615 -0
- ads/jobs/templates/hostname_from_env.c +55 -0
- ads/jobs/templates/oci_metrics.py +181 -0
- ads/jobs/utils.py +104 -0
- ads/llm/__init__.py +28 -0
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/v02/client.py +295 -0
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/chain.py +268 -0
- ads/llm/chat_template.py +31 -0
- ads/llm/deploy.py +63 -0
- ads/llm/guardrails/__init__.py +5 -0
- ads/llm/guardrails/base.py +442 -0
- ads/llm/guardrails/huggingface.py +44 -0
- ads/llm/langchain/__init__.py +5 -0
- ads/llm/langchain/plugins/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/llm/langchain/plugins/llms/__init__.py +5 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
- ads/llm/requirements.txt +3 -0
- ads/llm/serialize.py +219 -0
- ads/llm/serializers/__init__.py +0 -0
- ads/llm/serializers/retrieval_qa.py +153 -0
- ads/llm/serializers/runnable_parallel.py +27 -0
- ads/llm/templates/score_chain.jinja2 +155 -0
- ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
- ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
- ads/model/__init__.py +52 -0
- ads/model/artifact.py +573 -0
- ads/model/artifact_downloader.py +254 -0
- ads/model/artifact_uploader.py +267 -0
- ads/model/base_properties.py +238 -0
- ads/model/common/.model-ignore +66 -0
- ads/model/common/__init__.py +5 -0
- ads/model/common/utils.py +142 -0
- ads/model/datascience_model.py +2635 -0
- ads/model/deployment/__init__.py +20 -0
- ads/model/deployment/common/__init__.py +5 -0
- ads/model/deployment/common/utils.py +308 -0
- ads/model/deployment/model_deployer.py +466 -0
- ads/model/deployment/model_deployment.py +1846 -0
- ads/model/deployment/model_deployment_infrastructure.py +671 -0
- ads/model/deployment/model_deployment_properties.py +493 -0
- ads/model/deployment/model_deployment_runtime.py +838 -0
- ads/model/extractor/__init__.py +5 -0
- ads/model/extractor/automl_extractor.py +74 -0
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/extractor/huggingface_extractor.py +88 -0
- ads/model/extractor/keras_extractor.py +84 -0
- ads/model/extractor/lightgbm_extractor.py +93 -0
- ads/model/extractor/model_info_extractor.py +114 -0
- ads/model/extractor/model_info_extractor_factory.py +105 -0
- ads/model/extractor/pytorch_extractor.py +87 -0
- ads/model/extractor/sklearn_extractor.py +112 -0
- ads/model/extractor/spark_extractor.py +89 -0
- ads/model/extractor/tensorflow_extractor.py +85 -0
- ads/model/extractor/xgboost_extractor.py +94 -0
- ads/model/framework/__init__.py +5 -0
- ads/model/framework/automl_model.py +178 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/framework/huggingface_model.py +399 -0
- ads/model/framework/lightgbm_model.py +266 -0
- ads/model/framework/pytorch_model.py +266 -0
- ads/model/framework/sklearn_model.py +250 -0
- ads/model/framework/spark_model.py +326 -0
- ads/model/framework/tensorflow_model.py +254 -0
- ads/model/framework/xgboost_model.py +258 -0
- ads/model/generic_model.py +3518 -0
- ads/model/model_artifact_boilerplate/README.md +381 -0
- ads/model/model_artifact_boilerplate/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
- ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
- ads/model/model_artifact_boilerplate/score.py +61 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_introspect.py +331 -0
- ads/model/model_metadata.py +1810 -0
- ads/model/model_metadata_mixin.py +460 -0
- ads/model/model_properties.py +63 -0
- ads/model/model_version_set.py +739 -0
- ads/model/runtime/__init__.py +5 -0
- ads/model/runtime/env_info.py +306 -0
- ads/model/runtime/model_deployment_details.py +37 -0
- ads/model/runtime/model_provenance_details.py +58 -0
- ads/model/runtime/runtime_info.py +81 -0
- ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
- ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
- ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
- ads/model/runtime/utils.py +201 -0
- ads/model/serde/__init__.py +5 -0
- ads/model/serde/common.py +40 -0
- ads/model/serde/model_input.py +547 -0
- ads/model/serde/model_serializer.py +1184 -0
- ads/model/service/__init__.py +5 -0
- ads/model/service/oci_datascience_model.py +1076 -0
- ads/model/service/oci_datascience_model_deployment.py +500 -0
- ads/model/service/oci_datascience_model_version_set.py +176 -0
- ads/model/transformer/__init__.py +5 -0
- ads/model/transformer/onnx_transformer.py +324 -0
- ads/mysqldb/__init__.py +5 -0
- ads/mysqldb/mysql_db.py +227 -0
- ads/opctl/__init__.py +18 -0
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/__init__.py +5 -0
- ads/opctl/backend/ads_dataflow.py +353 -0
- ads/opctl/backend/ads_ml_job.py +710 -0
- ads/opctl/backend/ads_ml_pipeline.py +164 -0
- ads/opctl/backend/ads_model_deployment.py +209 -0
- ads/opctl/backend/base.py +146 -0
- ads/opctl/backend/local.py +1053 -0
- ads/opctl/backend/marketplace/__init__.py +9 -0
- ads/opctl/backend/marketplace/helm_helper.py +173 -0
- ads/opctl/backend/marketplace/local_marketplace.py +271 -0
- ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
- ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
- ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
- ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
- ads/opctl/backend/marketplace/models/__init__.py +5 -0
- ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
- ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
- ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
- ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
- ads/opctl/cli.py +707 -0
- ads/opctl/cmds.py +869 -0
- ads/opctl/conda/__init__.py +5 -0
- ads/opctl/conda/cli.py +193 -0
- ads/opctl/conda/cmds.py +749 -0
- ads/opctl/conda/config.yaml +34 -0
- ads/opctl/conda/manifest_template.yaml +13 -0
- ads/opctl/conda/multipart_uploader.py +188 -0
- ads/opctl/conda/pack.py +89 -0
- ads/opctl/config/__init__.py +5 -0
- ads/opctl/config/base.py +57 -0
- ads/opctl/config/diagnostics/__init__.py +5 -0
- ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
- ads/opctl/config/merger.py +255 -0
- ads/opctl/config/resolver.py +297 -0
- ads/opctl/config/utils.py +79 -0
- ads/opctl/config/validator.py +17 -0
- ads/opctl/config/versioner.py +68 -0
- ads/opctl/config/yaml_parsers/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/base.py +58 -0
- ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
- ads/opctl/constants.py +66 -0
- ads/opctl/decorator/__init__.py +5 -0
- ads/opctl/decorator/common.py +129 -0
- ads/opctl/diagnostics/__init__.py +5 -0
- ads/opctl/diagnostics/__main__.py +25 -0
- ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
- ads/opctl/diagnostics/check_requirements.py +144 -0
- ads/opctl/diagnostics/requirement_exception.py +9 -0
- ads/opctl/distributed/README.md +109 -0
- ads/opctl/distributed/__init__.py +5 -0
- ads/opctl/distributed/certificates.py +32 -0
- ads/opctl/distributed/cli.py +207 -0
- ads/opctl/distributed/cmds.py +731 -0
- ads/opctl/distributed/common/__init__.py +5 -0
- ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
- ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
- ads/opctl/distributed/common/cluster_config_helper.py +103 -0
- ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
- ads/opctl/distributed/common/cluster_runner.py +54 -0
- ads/opctl/distributed/common/framework_factory.py +29 -0
- ads/opctl/docker/Dockerfile.job +103 -0
- ads/opctl/docker/Dockerfile.job.arm +107 -0
- ads/opctl/docker/Dockerfile.job.gpu +175 -0
- ads/opctl/docker/base-env.yaml +13 -0
- ads/opctl/docker/cuda.repo +6 -0
- ads/opctl/docker/operator/.dockerignore +0 -0
- ads/opctl/docker/operator/Dockerfile +41 -0
- ads/opctl/docker/operator/Dockerfile.gpu +85 -0
- ads/opctl/docker/operator/cuda.repo +6 -0
- ads/opctl/docker/operator/environment.yaml +8 -0
- ads/opctl/forecast.py +11 -0
- ads/opctl/index.yaml +3 -0
- ads/opctl/model/__init__.py +5 -0
- ads/opctl/model/cli.py +65 -0
- ads/opctl/model/cmds.py +73 -0
- ads/opctl/operator/README.md +4 -0
- ads/opctl/operator/__init__.py +31 -0
- ads/opctl/operator/cli.py +344 -0
- ads/opctl/operator/cmd.py +596 -0
- ads/opctl/operator/common/__init__.py +5 -0
- ads/opctl/operator/common/backend_factory.py +460 -0
- ads/opctl/operator/common/const.py +27 -0
- ads/opctl/operator/common/data/synthetic.csv +16001 -0
- ads/opctl/operator/common/dictionary_merger.py +148 -0
- ads/opctl/operator/common/errors.py +42 -0
- ads/opctl/operator/common/operator_config.py +99 -0
- ads/opctl/operator/common/operator_loader.py +811 -0
- ads/opctl/operator/common/operator_schema.yaml +130 -0
- ads/opctl/operator/common/operator_yaml_generator.py +152 -0
- ads/opctl/operator/common/utils.py +208 -0
- ads/opctl/operator/lowcode/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
- ads/opctl/operator/lowcode/anomaly/README.md +207 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +167 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
- ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +116 -0
- ads/opctl/operator/lowcode/common/errors.py +47 -0
- ads/opctl/operator/lowcode/common/transformations.py +296 -0
- ads/opctl/operator/lowcode/common/utils.py +384 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
- ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
- ads/opctl/operator/lowcode/forecast/README.md +209 -0
- ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
- ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
- ads/opctl/operator/lowcode/forecast/const.py +92 -0
- ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
- ads/opctl/operator/lowcode/forecast/errors.py +26 -0
- ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
- ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
- ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
- ads/opctl/operator/lowcode/forecast/model/prophet.py +450 -0
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
- ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
- ads/opctl/operator/lowcode/forecast/utils.py +397 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
- ads/opctl/operator/lowcode/pii/MLoperator +17 -0
- ads/opctl/operator/lowcode/pii/README.md +208 -0
- ads/opctl/operator/lowcode/pii/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/__main__.py +78 -0
- ads/opctl/operator/lowcode/pii/cmd.py +39 -0
- ads/opctl/operator/lowcode/pii/constant.py +84 -0
- ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
- ads/opctl/operator/lowcode/pii/errors.py +27 -0
- ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
- ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
- ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
- ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
- ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
- ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
- ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
- ads/opctl/operator/lowcode/pii/model/report.py +487 -0
- ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
- ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
- ads/opctl/operator/lowcode/pii/utils.py +43 -0
- ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
- ads/opctl/operator/lowcode/recommender/README.md +206 -0
- ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
- ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
- ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
- ads/opctl/operator/lowcode/recommender/constant.py +30 -0
- ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
- ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
- ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
- ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
- ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
- ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
- ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
- ads/opctl/operator/lowcode/recommender/utils.py +13 -0
- ads/opctl/operator/runtime/__init__.py +5 -0
- ads/opctl/operator/runtime/const.py +17 -0
- ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
- ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
- ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/runtime.py +115 -0
- ads/opctl/schema.yaml.yml +36 -0
- ads/opctl/script.py +40 -0
- ads/opctl/spark/__init__.py +5 -0
- ads/opctl/spark/cli.py +43 -0
- ads/opctl/spark/cmds.py +147 -0
- ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
- ads/opctl/utils.py +344 -0
- ads/oracledb/__init__.py +5 -0
- ads/oracledb/oracle_db.py +346 -0
- ads/pipeline/__init__.py +39 -0
- ads/pipeline/ads_pipeline.py +2279 -0
- ads/pipeline/ads_pipeline_run.py +772 -0
- ads/pipeline/ads_pipeline_step.py +605 -0
- ads/pipeline/builders/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/custom_script.py +32 -0
- ads/pipeline/cli.py +119 -0
- ads/pipeline/extension.py +291 -0
- ads/pipeline/schema/__init__.py +5 -0
- ads/pipeline/schema/cs_step_schema.json +35 -0
- ads/pipeline/schema/ml_step_schema.json +31 -0
- ads/pipeline/schema/pipeline_schema.json +71 -0
- ads/pipeline/visualizer/__init__.py +5 -0
- ads/pipeline/visualizer/base.py +570 -0
- ads/pipeline/visualizer/graph_renderer.py +272 -0
- ads/pipeline/visualizer/text_renderer.py +84 -0
- ads/secrets/__init__.py +11 -0
- ads/secrets/adb.py +386 -0
- ads/secrets/auth_token.py +86 -0
- ads/secrets/big_data_service.py +365 -0
- ads/secrets/mysqldb.py +149 -0
- ads/secrets/oracledb.py +160 -0
- ads/secrets/secrets.py +407 -0
- ads/telemetry/__init__.py +7 -0
- ads/telemetry/base.py +69 -0
- ads/telemetry/client.py +122 -0
- ads/telemetry/telemetry.py +257 -0
- ads/templates/dataflow_pyspark.jinja2 +13 -0
- ads/templates/dataflow_sparksql.jinja2 +22 -0
- ads/templates/func.jinja2 +20 -0
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score-pkl.jinja2 +173 -0
- ads/templates/score.jinja2 +322 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- ads/templates/score_generic.jinja2 +165 -0
- ads/templates/score_huggingface_pipeline.jinja2 +217 -0
- ads/templates/score_lightgbm.jinja2 +185 -0
- ads/templates/score_onnx.jinja2 +407 -0
- ads/templates/score_onnx_new.jinja2 +473 -0
- ads/templates/score_oracle_automl.jinja2 +185 -0
- ads/templates/score_pyspark.jinja2 +154 -0
- ads/templates/score_pytorch.jinja2 +219 -0
- ads/templates/score_scikit-learn.jinja2 +184 -0
- ads/templates/score_tensorflow.jinja2 +184 -0
- ads/templates/score_xgboost.jinja2 +178 -0
- ads/text_dataset/__init__.py +5 -0
- ads/text_dataset/backends.py +211 -0
- ads/text_dataset/dataset.py +445 -0
- ads/text_dataset/extractor.py +207 -0
- ads/text_dataset/options.py +53 -0
- ads/text_dataset/udfs.py +22 -0
- ads/text_dataset/utils.py +49 -0
- ads/type_discovery/__init__.py +9 -0
- ads/type_discovery/abstract_detector.py +21 -0
- ads/type_discovery/constant_detector.py +41 -0
- ads/type_discovery/continuous_detector.py +54 -0
- ads/type_discovery/credit_card_detector.py +99 -0
- ads/type_discovery/datetime_detector.py +92 -0
- ads/type_discovery/discrete_detector.py +118 -0
- ads/type_discovery/document_detector.py +146 -0
- ads/type_discovery/ip_detector.py +68 -0
- ads/type_discovery/latlon_detector.py +90 -0
- ads/type_discovery/phone_number_detector.py +63 -0
- ads/type_discovery/type_discovery_driver.py +87 -0
- ads/type_discovery/typed_feature.py +594 -0
- ads/type_discovery/unknown_detector.py +41 -0
- ads/type_discovery/zipcode_detector.py +48 -0
- ads/vault/__init__.py +7 -0
- ads/vault/vault.py +237 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/METADATA +150 -149
- oracle_ads-2.13.10rc0.dist-info/RECORD +858 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/WHEEL +1 -2
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/entry_points.txt +2 -1
- oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
- oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/licenses/LICENSE.txt +0 -0
ads/aqua/model/model.py
ADDED
@@ -0,0 +1,2134 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
+
import json
|
5
|
+
import os
|
6
|
+
import pathlib
|
7
|
+
import re
|
8
|
+
from datetime import datetime, timedelta
|
9
|
+
from threading import Lock
|
10
|
+
from typing import Any, Dict, List, Optional, Set, Union
|
11
|
+
|
12
|
+
import oci
|
13
|
+
from cachetools import TTLCache
|
14
|
+
from huggingface_hub import snapshot_download
|
15
|
+
from oci.data_science.models import JobRun, Metadata, Model, UpdateModelDetails
|
16
|
+
|
17
|
+
from ads.aqua import logger
|
18
|
+
from ads.aqua.app import AquaApp
|
19
|
+
from ads.aqua.common.entities import AquaMultiModelRef
|
20
|
+
from ads.aqua.common.enums import (
|
21
|
+
ConfigFolder,
|
22
|
+
CustomInferenceContainerTypeFamily,
|
23
|
+
FineTuningContainerTypeFamily,
|
24
|
+
InferenceContainerTypeFamily,
|
25
|
+
ModelFormat,
|
26
|
+
Platform,
|
27
|
+
Tags,
|
28
|
+
)
|
29
|
+
from ads.aqua.common.errors import (
|
30
|
+
AquaFileNotFoundError,
|
31
|
+
AquaRuntimeError,
|
32
|
+
AquaValueError,
|
33
|
+
)
|
34
|
+
from ads.aqua.common.utils import (
|
35
|
+
LifecycleStatus,
|
36
|
+
_build_resource_identifier,
|
37
|
+
cleanup_local_hf_model_artifact,
|
38
|
+
create_word_icon,
|
39
|
+
generate_tei_cmd_var,
|
40
|
+
get_artifact_path,
|
41
|
+
get_hf_model_info,
|
42
|
+
get_preferred_compatible_family,
|
43
|
+
list_os_files_with_extension,
|
44
|
+
load_config,
|
45
|
+
upload_folder,
|
46
|
+
)
|
47
|
+
from ads.aqua.config.container_config import AquaContainerConfig, Usage
|
48
|
+
from ads.aqua.constants import (
|
49
|
+
AQUA_MODEL_ARTIFACT_CONFIG,
|
50
|
+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
|
51
|
+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
|
52
|
+
AQUA_MODEL_ARTIFACT_FILE,
|
53
|
+
AQUA_MODEL_TOKENIZER_CONFIG,
|
54
|
+
AQUA_MODEL_TYPE_CUSTOM,
|
55
|
+
HF_METADATA_FOLDER,
|
56
|
+
LICENSE,
|
57
|
+
MODEL_BY_REFERENCE_OSS_PATH_KEY,
|
58
|
+
README,
|
59
|
+
READY_TO_DEPLOY_STATUS,
|
60
|
+
READY_TO_FINE_TUNE_STATUS,
|
61
|
+
READY_TO_IMPORT_STATUS,
|
62
|
+
TRAINING_METRICS_FINAL,
|
63
|
+
TRINING_METRICS,
|
64
|
+
VALIDATION_METRICS,
|
65
|
+
VALIDATION_METRICS_FINAL,
|
66
|
+
)
|
67
|
+
from ads.aqua.model.constants import (
|
68
|
+
AquaModelMetadataKeys,
|
69
|
+
FineTuningCustomMetadata,
|
70
|
+
FineTuningMetricCategories,
|
71
|
+
ModelCustomMetadataFields,
|
72
|
+
ModelType,
|
73
|
+
)
|
74
|
+
from ads.aqua.model.entities import (
|
75
|
+
AquaFineTuneModel,
|
76
|
+
AquaFineTuningMetric,
|
77
|
+
AquaModel,
|
78
|
+
AquaModelLicense,
|
79
|
+
AquaModelReadme,
|
80
|
+
AquaModelSummary,
|
81
|
+
ImportModelDetails,
|
82
|
+
ModelFileDescription,
|
83
|
+
ModelValidationResult,
|
84
|
+
)
|
85
|
+
from ads.aqua.model.enums import MultiModelSupportedTaskType
|
86
|
+
from ads.aqua.model.utils import (
|
87
|
+
extract_base_model_from_ft,
|
88
|
+
extract_fine_tune_artifacts_path,
|
89
|
+
)
|
90
|
+
from ads.common.auth import default_signer
|
91
|
+
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
|
92
|
+
from ads.common.utils import (
|
93
|
+
UNKNOWN,
|
94
|
+
get_console_link,
|
95
|
+
is_path_exists,
|
96
|
+
read_file,
|
97
|
+
)
|
98
|
+
from ads.config import (
|
99
|
+
AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME,
|
100
|
+
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
|
101
|
+
AQUA_DEPLOYMENT_CONTAINER_URI_METADATA_NAME,
|
102
|
+
AQUA_EVALUATION_CONTAINER_METADATA_NAME,
|
103
|
+
AQUA_FINETUNING_CONTAINER_METADATA_NAME,
|
104
|
+
AQUA_SERVICE_MODELS,
|
105
|
+
COMPARTMENT_OCID,
|
106
|
+
PROJECT_OCID,
|
107
|
+
SERVICE,
|
108
|
+
TENANCY_OCID,
|
109
|
+
USER,
|
110
|
+
)
|
111
|
+
from ads.model import DataScienceModel
|
112
|
+
from ads.model.common.utils import MetadataArtifactPathType
|
113
|
+
from ads.model.model_metadata import (
|
114
|
+
MetadataCustomCategory,
|
115
|
+
ModelCustomMetadata,
|
116
|
+
ModelCustomMetadataItem,
|
117
|
+
)
|
118
|
+
from ads.telemetry import telemetry
|
119
|
+
|
120
|
+
|
121
|
+
class AquaModelApp(AquaApp):
|
122
|
+
"""Provides a suite of APIs to interact with Aqua models within the Oracle
|
123
|
+
Cloud Infrastructure Data Science service, serving as an interface for
|
124
|
+
managing machine learning models.
|
125
|
+
|
126
|
+
|
127
|
+
Methods
|
128
|
+
-------
|
129
|
+
create(model_id: Union[str, AquaMultiModelRef], project_id: str = None, compartment_id: str = None, **kwargs) -> DataScienceModel
|
130
|
+
Creates custom aqua model from service model.
|
131
|
+
get(model_id: str) -> AquaModel:
|
132
|
+
Retrieves details of an Aqua model by its unique identifier.
|
133
|
+
list(compartment_id: str = None, project_id: str = None, **kwargs) -> List[AquaModelSummary]:
|
134
|
+
Lists all Aqua models within a specified compartment and/or project.
|
135
|
+
clear_model_list_cache()
|
136
|
+
Clears the cached list of models from the service models compartment.
|
137
|
+
register(model: str, os_path: str, local_dir: str = None)
|
138
|
+
|
139
|
+
Note:
|
140
|
+
This class is designed to work within the Oracle Cloud Infrastructure
|
141
|
+
and requires proper configuration and authentication set up to interact
|
142
|
+
with OCI services.
|
143
|
+
"""
|
144
|
+
|
145
|
+
_service_models_cache = TTLCache(
|
146
|
+
maxsize=10, ttl=timedelta(hours=5), timer=datetime.now
|
147
|
+
)
|
148
|
+
# Used for saving service model details
|
149
|
+
_service_model_details_cache = TTLCache(
|
150
|
+
maxsize=10, ttl=timedelta(hours=5), timer=datetime.now
|
151
|
+
)
|
152
|
+
_cache_lock = Lock()
|
153
|
+
|
154
|
+
@telemetry(entry_point="plugin=model&action=create", name="aqua")
def create(
    self,
    model_id: Union[str, AquaMultiModelRef],
    project_id: Optional[str] = None,
    compartment_id: Optional[str] = None,
    freeform_tags: Optional[Dict] = None,
    defined_tags: Optional[Dict] = None,
    **kwargs,
) -> DataScienceModel:
    """
    Creates a custom Aqua model from a service model.

    The source model is loaded by its OCID. If it is already tagged as a
    registered (``Tags.BASE_MODEL_CUSTOM``) or fine-tuned
    (``Tags.AQUA_FINE_TUNED_MODEL_TAG``) model, no copy is made and the
    loaded model is returned as is. Otherwise a new model is created by
    reference, carrying over the source model's file description, display
    name, description, tags, metadata and provenance.

    Parameters
    ----------
    model_id : Union[str, AquaMultiModelRef]
        The model ID as a string or a AquaMultiModelRef instance to be deployed.
    project_id : Optional[str]
        The project ID for the custom model. Defaults to None.
        If not provided, ``PROJECT_OCID`` is used.
    compartment_id : Optional[str]
        The compartment ID for the custom model. Defaults to None.
        If not provided, ``COMPARTMENT_OCID`` is used.
    freeform_tags : Optional[Dict]
        Freeform tags for the model. They override the source model's
        freeform tags on key collisions.
    defined_tags : Optional[Dict]
        Defined tags for the model. They override the source model's
        defined tags on key collisions.
    **kwargs
        Forwarded unchanged to ``DataScienceModel.create``.

    Returns
    -------
    DataScienceModel
        The instance of DataScienceModel.
    """
    if isinstance(model_id, AquaMultiModelRef):
        model_id = model_id.model_id

    service_model = DataScienceModel.from_id(model_id)
    target_project = project_id or PROJECT_OCID
    target_compartment = compartment_id or COMPARTMENT_OCID

    # The tags may be missing on the source model; normalize once so the
    # checks below and the tag merge treat "no tags" the same way.
    source_freeform_tags = service_model.freeform_tags or {}
    source_defined_tags = service_model.defined_tags or {}

    # Skip model copying if it is registered model or fine-tuned model
    already_in_user_compartment = (
        source_freeform_tags.get(Tags.BASE_MODEL_CUSTOM) is not None
        or source_freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG) is not None
    )
    if already_in_user_compartment:
        logger.info(
            f"Aqua Model {model_id} already exists in the user's compartment. "
            "Skipped copying."
        )
        return service_model

    # Combine tags: caller-supplied values win over the source model's.
    combined_freeform_tags = {**source_freeform_tags, **(freeform_tags or {})}
    combined_defined_tags = {**source_defined_tags, **(defined_tags or {})}

    custom_model = (
        DataScienceModel()
        .with_compartment_id(target_compartment)
        .with_project_id(target_project)
        .with_model_file_description(json_dict=service_model.model_file_description)
        .with_display_name(service_model.display_name)
        .with_description(service_model.description)
        .with_freeform_tags(**combined_freeform_tags)
        .with_defined_tags(**combined_defined_tags)
        .with_custom_metadata_list(service_model.custom_metadata_list)
        .with_defined_metadata_list(service_model.defined_metadata_list)
        .with_provenance_metadata(service_model.provenance_metadata)
        .create(model_by_reference=True, **kwargs)
    )
    logger.info(
        f"Aqua Model {custom_model.id} created with the service model {model_id}."
    )

    # Track unique models that were created in the user's compartment
    self.telemetry.record_event_async(
        category="aqua/service/model",
        action="create",
        detail=service_model.display_name,
    )

    return custom_model
|
241
|
+
|
242
|
+
@telemetry(entry_point="plugin=model&action=create", name="aqua")
def create_multi(
    self,
    models: List[AquaMultiModelRef],
    project_id: Optional[str] = None,
    compartment_id: Optional[str] = None,
    freeform_tags: Optional[Dict] = None,
    defined_tags: Optional[Dict] = None,
    **kwargs,  # noqa: ARG002
) -> DataScienceModel:
    """
    Creates a multi-model grouping using the provided model list.

    Each entry in ``models`` is resolved to its source model and updated in
    place (``model_name``, ``artifact_location`` and, for fine-tuned models,
    ``model_id`` and ``fine_tune_weights_location``). A new model is then
    created by reference with the combined file descriptions, and the
    serialized ``models`` list is attached to it as a custom metadata
    artifact.

    Parameters
    ----------
    models : List[AquaMultiModelRef]
        List of AquaMultiModelRef instances for creating a multi-model group.
    project_id : Optional[str]
        The project ID for the multi-model group.
        If not provided, ``PROJECT_OCID`` is used.
    compartment_id : Optional[str]
        The compartment ID for the multi-model group.
        If not provided, ``COMPARTMENT_OCID`` is used.
    freeform_tags : Optional[Dict]
        Freeform tags for the model.
    defined_tags : Optional[Dict]
        Defined tags for the model.
    **kwargs
        Accepted for interface compatibility; not used.

    Returns
    -------
    DataScienceModel
        Instance of DataScienceModel object.

    Raises
    ------
    AquaValueError
        If ``models`` is empty, if no inference container supports
        multi-model deployment, if a model has no artifact or no file
        description, if a model's deployment container is unsupported, or
        if the selected models use incompatible container families.
    """

    if not models:
        raise AquaValueError(
            "Model list cannot be empty. Please provide at least one model for deployment."
        )

    display_name_list = []
    model_file_description_list: List[ModelFileDescription] = []
    model_custom_metadata = ModelCustomMetadata()

    # A missing "inference" section is treated as "no containers", so the
    # caller gets the explicit AquaValueError below instead of a TypeError.
    service_inference_containers = (
        self.get_container_config().to_dict().get("inference") or []
    )

    multi_model_usages = [Usage.MULTI_MODEL, Usage.OTHER]
    supported_container_families = [
        container_config_item.family
        for container_config_item in service_inference_containers
        if any(
            usage.upper() in container_config_item.usages
            for usage in multi_model_usages
        )
    ]

    if not supported_container_families:
        raise AquaValueError(
            "Currently, there are no containers that support multi-model deployment."
        )

    selected_models_deployment_containers = set()

    # Process each model
    for model in models:
        source_model = DataScienceModel.from_id(model.model_id)
        display_name = source_model.display_name
        model_file_description = source_model.model_file_description
        # Update model name in user's input model
        model.model_name = model.model_name or display_name

        # TODO Uncomment the section below, if only service models should be allowed for multi-model deployment
        # if not source_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, UNKNOWN):
        #     raise AquaValueError(
        #         f"Invalid selected model {display_name}. "
        #         "Currently only service models are supported for multi model deployment."
        #     )

        # Check if the model is a fine-tuned model and, if so, record the
        # fine-tuned weights path in `fine_tune_weights_location`.
        # The tags may be missing on the source model, hence the `or {}`.
        is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG in (
            source_model.freeform_tags or {}
        )

        if is_fine_tuned_model:
            # The entry is re-pointed at the base model; the fine-tuned
            # weights are tracked separately.
            model.model_id, model.model_name = extract_base_model_from_ft(
                source_model
            )
            model_artifact_path, model.fine_tune_weights_location = (
                extract_fine_tune_artifacts_path(source_model)
            )
        else:
            # Retrieve model artifact for base models
            model_artifact_path = source_model.artifact

        display_name_list.append(display_name)

        self._extract_model_task(model, source_model)

        if not model_artifact_path:
            raise AquaValueError(
                f"Model '{display_name}' (ID: {model.model_id}) has no artifacts. "
                "Please register the model first."
            )

        # Update model artifact location in user's input model
        model.artifact_location = model_artifact_path

        if not model_file_description:
            raise AquaValueError(
                f"Model '{display_name}' (ID: {model.model_id}) has no file description. "
                "Please register the model first."
            )

        model_file_description_list.append(
            ModelFileDescription(**model_file_description)
        )

        # Validate deployment container consistency
        deployment_container = source_model.custom_metadata_list.get(
            ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
            ModelCustomMetadataItem(
                key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER
            ),
        ).value

        if deployment_container not in supported_container_families:
            raise AquaValueError(
                f"Unsupported deployment container '{deployment_container}' for model '{source_model.id}'. "
                f"Only '{supported_container_families}' are supported for multi-model deployments."
            )

        selected_models_deployment_containers.add(deployment_container)

    if not selected_models_deployment_containers:
        raise AquaValueError(
            "None of the selected models are associated with a recognized container family. "
            "Please review the selected models, or select a different group of models."
        )

    # Check whether all models in the group share the same container family
    if len(selected_models_deployment_containers) > 1:
        deployment_container = get_preferred_compatible_family(
            selected_families=selected_models_deployment_containers
        )
        if not deployment_container:
            raise AquaValueError(
                "The selected models are associated with different container families: "
                f"{list(selected_models_deployment_containers)}. "
                "For multi-model deployment, all models in the group must belong to the same container "
                "family or to compatible container families."
            )
    else:
        deployment_container = selected_models_deployment_containers.pop()

    # Generate model group details
    timestamp = datetime.now().strftime("%Y%m%d")
    model_group_display_name = f"model_group_{timestamp}"
    combined_models = ", ".join(display_name_list)
    model_group_description = f"Multi-model grouping using {combined_models}."

    # Add global metadata
    model_custom_metadata.add(
        key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
        value=deployment_container,
        description=f"Inference container mapping for {model_group_display_name}",
        category="Other",
    )
    model_custom_metadata.add(
        key=ModelCustomMetadataFields.MULTIMODEL_GROUP_COUNT,
        value=str(len(models)),
        description="Number of models in the group.",
        category="Other",
    )

    # Combine tags. The `Tags.AQUA_TAG` has been excluded, because we don't want to show
    # the models created for multi-model purpose in the AQUA models list.
    tags = {
        # Tags.AQUA_TAG: "active",
        Tags.MULTIMODEL_TYPE_TAG: "true",
        **(freeform_tags or {}),
    }

    # Create multi-model group. Project and compartment fall back to the
    # configured defaults, mirroring `create`.
    custom_model = (
        DataScienceModel()
        .with_compartment_id(compartment_id or COMPARTMENT_OCID)
        .with_project_id(project_id or PROJECT_OCID)
        .with_display_name(model_group_display_name)
        .with_description(model_group_description)
        .with_freeform_tags(**tags)
        .with_defined_tags(**(defined_tags or {}))
        .with_custom_metadata_list(model_custom_metadata)
    )

    # Update multi model file description to attach artifacts: flatten the
    # per-model file entries into a single description.
    combined_file_entries = [
        file_entry
        for file_description in model_file_description_list
        for file_entry in file_description.models
    ]
    custom_model.with_model_file_description(
        json_dict=ModelFileDescription(models=combined_file_entries).model_dump(
            by_alias=True
        )
    )

    # Finalize creation
    custom_model.create(model_by_reference=True)

    logger.info(
        f"Aqua Model '{custom_model.id}' created with models: {combined_models}."
    )

    # Create custom metadata for multi model metadata
    custom_model.create_custom_metadata_artifact(
        metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA,
        artifact_path_or_content=json.dumps(
            [model.model_dump() for model in models]
        ).encode(),
        path_type=MetadataArtifactPathType.CONTENT,
    )

    logger.debug(
        f"Multi model metadata uploaded for Aqua model: {custom_model.id}."
    )

    # Track telemetry event
    self.telemetry.record_event_async(
        category="aqua/multimodel",
        action="create",
        detail=combined_models,
    )

    return custom_model
|
474
|
+
|
475
|
+
@telemetry(entry_point="plugin=model&action=get", name="aqua")
def get(self, model_id: str) -> "AquaModel":
    """Gets the information of an Aqua model.

    Details of non fine-tuned models are served from, and stored in,
    `self._service_model_details_cache`. Fine-tuned models are built on every
    call and are not written to that cache.

    Parameters
    ----------
    model_id: str
        The model OCID.

    Returns
    -------
    AquaModel:
        The instance of AquaModel. For fine-tuned models an
        `AquaFineTuneModel` is built instead.

    Raises
    ------
    AquaRuntimeError:
        If `self._if_show` rejects the model, i.e. it is not an Aqua model.
    """

    cached_item = self._service_model_details_cache.get(model_id)
    if cached_item:
        logger.info(f"Fetching model details for model {model_id} from cache.")
        return cached_item

    logger.info(f"Fetching model details for model {model_id}.")
    ds_model = DataScienceModel.from_id(model_id)

    if not self._if_show(ds_model):
        raise AquaRuntimeError(
            f"Target model `{ds_model.id}` is not an Aqua model as it does not contain "
            f"{Tags.AQUA_TAG} tag."
        )

    is_fine_tuned_model = bool(
        ds_model.freeform_tags
        and ds_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG)
    )

    def _metadata_value(key: str):
        # A placeholder item is passed as the default so that a missing key
        # resolves to that placeholder's `.value`.
        return ds_model.custom_metadata_list.get(
            key, ModelCustomMetadataItem(key=key)
        ).value

    inference_container = _metadata_value(
        ModelCustomMetadataFields.DEPLOYMENT_CONTAINER
    )
    inference_container_uri = _metadata_value(
        ModelCustomMetadataFields.DEPLOYMENT_CONTAINER_URI
    )
    evaluation_container = _metadata_value(
        ModelCustomMetadataFields.EVALUATION_CONTAINER
    )
    finetuning_container: str = _metadata_value(
        ModelCustomMetadataFields.FINETUNE_CONTAINER
    )
    artifact_location = _metadata_value(ModelCustomMetadataFields.ARTIFACT_LOCATION)

    aqua_model_attributes = dict(
        **self._process_model(ds_model, self.region),
        project_id=ds_model.project_id,
        inference_container=inference_container,
        inference_container_uri=inference_container_uri,
        finetuning_container=finetuning_container,
        evaluation_container=evaluation_container,
        artifact_location=artifact_location,
    )

    if not is_fine_tuned_model:
        model_details = AquaModel(**aqua_model_attributes)
        self._service_model_details_cache[model_id] = model_details
        return model_details

    # Fine-tuned model: enrich the details with job run and source information.
    try:
        jobrun_ocid = ds_model.provenance_metadata.training_id
        jobrun = self.ds_client.get_job_run(jobrun_ocid).data
    except Exception as e:
        # Best effort: a missing or unreadable job run is tolerated.
        logger.debug(
            f"Missing jobrun information in the provenance metadata of the given model {model_id}."
            f"\nError: {str(e)}"
        )
        jobrun = None

    try:
        source_id = ds_model.custom_metadata_list.get(
            FineTuningCustomMetadata.FT_SOURCE
        ).value
    except ValueError as e:
        logger.debug(
            f"Custom metadata is missing {FineTuningCustomMetadata.FT_SOURCE} key for "
            f"model {model_id}.\nError: {str(e)}"
        )
        source_id = UNKNOWN

    try:
        source_name = ds_model.custom_metadata_list.get(
            FineTuningCustomMetadata.FT_SOURCE_NAME
        ).value
    except ValueError as e:
        logger.debug(
            f"Custom metadata is missing {FineTuningCustomMetadata.FT_SOURCE_NAME} key for "
            f"model {model_id}.\nError: {str(e)}"
        )
        source_name = UNKNOWN

    source_identifier = _build_resource_identifier(
        id=source_id,
        name=source_name,
        region=self.region,
    )

    ft_metrics = self._build_ft_metrics(ds_model.custom_metadata_list)

    # Without a usable (non-deleted) job run, the status is derived from
    # whether the model artifact exists.
    job_run_status = (
        jobrun.lifecycle_state
        if jobrun and jobrun.lifecycle_state != JobRun.LIFECYCLE_STATE_DELETED
        else (
            JobRun.LIFECYCLE_STATE_SUCCEEDED
            if self.if_artifact_exist(ds_model.id)
            else JobRun.LIFECYCLE_STATE_FAILED
        )
    )
    # TODO: change the argument's name.
    lifecycle_state = LifecycleStatus.get_status(
        evaluation_status=ds_model.lifecycle_state,
        job_run_status=job_run_status,
    )

    return AquaFineTuneModel(
        **aqua_model_attributes,
        source=source_identifier,
        # A succeeded job run is reported as an ACTIVE model.
        lifecycle_state=(
            Model.LIFECYCLE_STATE_ACTIVE
            if lifecycle_state == JobRun.LIFECYCLE_STATE_SUCCEEDED
            else lifecycle_state
        ),
        metrics=ft_metrics,
        model=ds_model,
        jobrun=jobrun,
        region=self.region,
    )
|
621
|
+
|
622
|
+
@telemetry(entry_point="plugin=model&action=delete", name="aqua")
def delete_model(self, model_id):
    """Deletes a registered or fine-tuned Aqua model.

    Parameters
    ----------
    model_id: str
        The model OCID.

    Returns
    -------
    The result of `DataScienceModel.delete()` for the target model.

    Raises
    ------
    AquaRuntimeError:
        If the model carries neither the registered-model tag nor the
        fine-tuned-model tag.
    """
    ds_model = DataScienceModel.from_id(model_id)
    # Freeform tags may be unset on a model; treat that as "no tags" so the
    # caller gets the AquaRuntimeError below instead of an AttributeError.
    freeform_tags = ds_model.freeform_tags or {}
    is_registered_model = freeform_tags.get(Tags.BASE_MODEL_CUSTOM)
    is_fine_tuned_model = freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG)

    if not (is_registered_model or is_fine_tuned_model):
        raise AquaRuntimeError(
            f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted."
        )

    logger.info(f"Deleting model {model_id}.")
    return ds_model.delete()
|
636
|
+
|
637
|
+
@telemetry(entry_point="plugin=model&action=edit", name="aqua")
def edit_registered_model(
    self, id, inference_container, inference_container_uri, enable_finetuning, task
):
    """Edits the default config of unverified registered model.

    Parameters
    ----------
    id: str
        The model OCID.
    inference_container: str.
        The inference container family name
    inference_container_uri: str
        The inference container uri for embedding models
    enable_finetuning: str
        Flag to enable or disable finetuning over the model. `"true"` and
        `"false"` (case-insensitive) are acted upon; `None` leaves the
        fine-tuning configuration untouched.
    task:
        The usecase type of the model. e.g , text-generation , text_embedding etc.

    Returns
    -------
    None
        The model is updated in place through `AquaApp().update_model`.

    Raises
    ------
    AquaRuntimeError:
        If the model is not a registered unverified model, if the container
        and container URI combination is invalid, or if fine-tuning is
        disabled on a model that does not have it enabled.
    """
    ds_model = DataScienceModel.from_id(id)
    # Freeform tags may be unset; fall back to an empty mapping so the check
    # below raises AquaRuntimeError rather than AttributeError.
    freeform_tags = ds_model.freeform_tags or {}

    # Only registered models that are not linked to a service model are editable.
    if not freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) or freeform_tags.get(
        Tags.AQUA_SERVICE_MODEL_TAG, None
    ):
        raise AquaRuntimeError("Only registered unverified models can be edited.")

    custom_metadata_list = ds_model.custom_metadata_list

    if inference_container:
        if (
            inference_container in CustomInferenceContainerTypeFamily
            and inference_container_uri is None
        ):
            raise AquaRuntimeError("Inference container URI must be provided.")
        custom_metadata_list.add(
            key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
            value=inference_container,
            category=MetadataCustomCategory.OTHER,
            description="Deployment container mapping for SMC",
            replace=True,
        )

    if inference_container_uri:
        # A URI is accepted for a custom container family, or when no
        # container family is passed at all.
        if not (
            inference_container in CustomInferenceContainerTypeFamily
            or inference_container is None
        ):
            raise AquaRuntimeError(
                f"Inference container URI can be edited only with container values: {CustomInferenceContainerTypeFamily.values()}"
            )
        custom_metadata_list.add(
            key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER_URI,
            value=inference_container_uri,
            category=MetadataCustomCategory.OTHER,
            description=f"Inference container URI for {ds_model.display_name}",
            replace=True,
        )

    if enable_finetuning is not None:
        finetuning_flag = enable_finetuning.lower()
        if finetuning_flag == "true":
            custom_metadata_list.add(
                key=ModelCustomMetadataFields.FINETUNE_CONTAINER,
                value=FineTuningContainerTypeFamily.AQUA_FINETUNING_CONTAINER_FAMILY,
                category=MetadataCustomCategory.OTHER,
                description="Fine-tuning container mapping for SMC",
                replace=True,
            )
            freeform_tags.update({Tags.READY_TO_FINE_TUNE: "true"})
        elif finetuning_flag == "false":
            try:
                custom_metadata_list.remove(
                    ModelCustomMetadataFields.FINETUNE_CONTAINER
                )
                freeform_tags.pop(Tags.READY_TO_FINE_TUNE)
            except Exception as ex:
                raise AquaRuntimeError(
                    f"The given model already doesn't support finetuning: {ex}"
                ) from ex

    # The "modelDescription" entry is dropped before the metadata is sent back.
    custom_metadata_list.remove("modelDescription")
    if task:
        freeform_tags.update({Tags.TASK: task})

    updated_custom_metadata_list = [
        Metadata(**metadata) for metadata in custom_metadata_list.to_dict()["data"]
    ]
    update_model_details = UpdateModelDetails(
        custom_metadata_list=updated_custom_metadata_list,
        freeform_tags=freeform_tags,
    )
    AquaApp().update_model(id, update_model_details)
    logger.info(f"Updated model details for the model {id}.")
|
740
|
+
|
741
|
+
def _extract_model_task(
    self,
    model: AquaMultiModelRef,
    source_model: DataScienceModel,
) -> None:
    """Sets `model.model_task` for a Multi Model Deployment.

    When the user did not supply a task, it is read from the source model's
    freeform tags. The task is then normalised (hyphens to underscores,
    lower-cased) and must be a member of `MultiModelSupportedTaskType`.

    Raises
    ------
    AquaValueError:
        If the normalised task is not a supported multi model task type.
    """
    # user does not supply model task, we extract from model metadata
    if not model.model_task:
        # Freeform tags may be unset on the source model.
        source_tags = source_model.freeform_tags or {}
        model.model_task = source_tags.get(Tags.TASK, UNKNOWN)

    task_tag = model.model_task.replace("-", "_").lower()
    # re-visit logic when more model task types are supported
    if task_tag not in MultiModelSupportedTaskType:
        raise AquaValueError(
            f"Invalid or missing {task_tag} tag for selected model {source_model.display_name}. "
            f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment."
        )
    model.model_task = task_tag
|
760
|
+
|
761
|
+
def _fetch_metric_from_metadata(
    self,
    custom_metadata_list: ModelCustomMetadata,
    target: str,
    category: str,
    metric_name: str,
) -> AquaFineTuningMetric:
    """Collects the scores of one metric from the model custom metadata.

    Items are matched on their `description` equal to `target`. When
    `metric_name` ends with "final" only the first match is kept. Any
    error while collecting yields a metric with an empty score list.
    """
    try:
        collected = []
        # We use description to group metrics
        matching_values = (
            item.value
            for item in custom_metadata_list._items
            if item.description == target
        )
        for value in matching_values:
            collected.append(value)
            if metric_name.endswith("final"):
                break
        metric = AquaFineTuningMetric(
            name=metric_name,
            category=category,
            scores=collected,
        )
    except Exception:
        metric = AquaFineTuningMetric(name=metric_name, category=category, scores=[])
    return metric
|
785
|
+
|
786
|
+
def _build_ft_metrics(
    self, custom_metadata_list: ModelCustomMetadata
) -> List[AquaFineTuningMetric]:
    """Builds the four fine-tuning metrics from the custom metadata.

    The result is ordered: validation (per epoch), training (per epoch),
    validation (final), training (final).
    """
    # (metadata target, metric category, metric name) for each metric, in output order.
    metric_specs = (
        (
            FineTuningCustomMetadata.VALIDATION_METRICS_EPOCH,
            FineTuningMetricCategories.VALIDATION,
            VALIDATION_METRICS,
        ),
        (
            FineTuningCustomMetadata.TRAINING_METRICS_EPOCH,
            FineTuningMetricCategories.TRAINING,
            TRINING_METRICS,
        ),
        (
            FineTuningCustomMetadata.VALIDATION_METRICS_FINAL,
            FineTuningMetricCategories.VALIDATION,
            VALIDATION_METRICS_FINAL,
        ),
        (
            FineTuningCustomMetadata.TRAINING_METRICS_FINAL,
            FineTuningMetricCategories.TRAINING,
            TRAINING_METRICS_FINAL,
        ),
    )

    return [
        self._fetch_metric_from_metadata(
            custom_metadata_list=custom_metadata_list,
            target=metadata_target,
            category=metric_category,
            metric_name=name,
        )
        for metadata_target, metric_category, name in metric_specs
    ]
|
825
|
+
|
826
|
+
def get_hf_tokenizer_config(self, model_id):
    """
    Gets the default model tokenizer config for the given Aqua model.
    This is the `config` attribute of the `get_config` result for
    `AQUA_MODEL_TOKENIZER_CONFIG`, looked up in the artifact config folder.

    Parameters
    ----------
    model_id: str
        The OCID of the Aqua model.

    Returns
    -------
    Dict:
        Model tokenizer config. A falsy config is logged and returned as is.
    """
    tokenizer_config = self.get_config(
        model_id, AQUA_MODEL_TOKENIZER_CONFIG, ConfigFolder.ARTIFACT
    ).config

    if tokenizer_config:
        return tokenizer_config

    logger.debug(
        f"{AQUA_MODEL_TOKENIZER_CONFIG} is not available for the model: {model_id}. "
        f"Check if the custom metadata has the artifact path set."
    )
    return tokenizer_config
|
852
|
+
|
853
|
+
@staticmethod
def to_aqua_model(
    model: Union[
        DataScienceModel,
        oci.data_science.models.model.Model,
        oci.data_science.models.ModelSummary,
        oci.resource_search.models.ResourceSummary,
    ],
    region: str,
) -> AquaModel:
    """Builds an `AquaModel` from the fields produced by `_process_model`."""
    model_fields = AquaModelApp._process_model(model, region)
    return AquaModel(**model_fields)
|
865
|
+
|
866
|
+
@staticmethod
def _process_model(
    model: Union[
        DataScienceModel,
        oci.data_science.models.model.Model,
        oci.data_science.models.ModelSummary,
        oci.resource_search.models.ResourceSummary,
    ],
    region: str,
    inference_containers: Optional[List[Any]] = None,
) -> dict:
    """Builds the dictionary of fields used to construct an AquaModelSummary.

    When `inference_containers` is not supplied (or is empty), the list is
    read from the "inference" entry of the Aqua container config.
    """

    # todo: revisit icon generation code
    # icon = self._load_icon(model.display_name)
    icon = ""

    freeform_tags = model.freeform_tags or {}

    # Merged view of the tags; freeform tags win over defined tags on key clashes.
    tags = {}
    for tag_source in (model.defined_tags or {}, freeform_tags):
        tags.update(tag_source)

    # Resource search results expose the OCID as `identifier` instead of `id`.
    if isinstance(model, oci.resource_search.models.ResourceSummary):
        model_id = model.identifier
    else:
        model_id = model.id

    console_link = get_console_link(
        resource="models",
        ocid=model_id,
        region=region,
    )

    if isinstance(model, (DataScienceModel, oci.data_science.models.model.Model)):
        description = model.description
    elif isinstance(model, oci.resource_search.models.ResourceSummary):
        description = model.additional_details.get("description")
    else:
        description = ""

    if tags:
        search_text = AquaModelApp._build_search_text(
            tags=tags, description=description
        )
    else:
        search_text = UNKNOWN

    is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG in freeform_tags

    # Readiness flags compare the upper-cased tag value with the expected status.
    deploy_tag = freeform_tags.get(Tags.AQUA_TAG, "").upper()
    ready_to_deploy = deploy_tag == READY_TO_DEPLOY_STATUS

    finetune_tag = freeform_tags.get(Tags.READY_TO_FINE_TUNE, "").upper()
    ready_to_finetune = finetune_tag == READY_TO_FINE_TUNE_STATUS

    import_tag = freeform_tags.get(Tags.READY_TO_IMPORT, "").upper()
    ready_to_import = import_tag == READY_TO_IMPORT_STATUS

    try:
        model_file = model.custom_metadata_list.get(AQUA_MODEL_ARTIFACT_FILE).value
    except Exception:
        # Any failure to read the artifact file entry is reported as UNKNOWN.
        model_file = UNKNOWN

    if not inference_containers:
        container_config = AquaApp().get_container_config().to_dict()
        inference_containers = container_config.get("inference")

    # The format tag holds a comma separated list; it defaults to safetensors.
    model_formats = (
        freeform_tags.get(Tags.MODEL_FORMAT, ModelFormat.SAFETENSORS)
        .upper()
        .split(",")
    )

    # Union of the platforms of every container that supports one of the formats.
    supported_platform: Set[str] = set()
    for container in inference_containers:
        if any(fmt in container.model_formats for fmt in model_formats):
            supported_platform.update(container.platforms)

    return {
        "compartment_id": model.compartment_id,
        "icon": icon or UNKNOWN,
        "id": model_id,
        "license": freeform_tags.get(Tags.LICENSE, UNKNOWN),
        "name": model.display_name,
        "organization": freeform_tags.get(Tags.ORGANIZATION, UNKNOWN),
        "task": freeform_tags.get(Tags.TASK, UNKNOWN),
        "time_created": str(model.time_created),
        "is_fine_tuned_model": is_fine_tuned_model,
        "tags": tags,
        "console_link": console_link,
        "search_text": search_text,
        "ready_to_deploy": ready_to_deploy,
        "ready_to_finetune": ready_to_finetune,
        "ready_to_import": ready_to_import,
        "nvidia_gpu_supported": Platform.NVIDIA_GPU in supported_platform,
        "arm_cpu_supported": Platform.ARM_CPU in supported_platform,
        "model_file": model_file,
        "model_formats": model_formats,
    }
|
972
|
+
|
973
|
+
@telemetry(entry_point="plugin=model&action=list", name="aqua")
def list(
    self,
    compartment_id: str = None,
    category: str = None,
    project_id: str = None,
    model_type: str = None,
    **kwargs,
) -> List["AquaModelSummary"]:
    """Lists all Aqua models within a specified compartment and/or project.
    If `category` is not specified, the method defaults to returning
    the service models within the pre-configured default compartment. By default, the list
    of models in the service compartment are cached. Use clear_model_list_cache() to invalidate
    the cache.

    Parameters
    ----------
    compartment_id: (str, optional). Defaults to `None`.
        The compartment OCID. Falls back to `COMPARTMENT_OCID` when not set.
    category: (str,optional). Defaults to `SERVICE`
        The category of the models to fetch. Can be either `USER` or `SERVICE`
    project_id: (str, optional). Defaults to `None`.
        The project OCID. It is recorded on every returned summary
        (`UNKNOWN` when not set).
    model_type: (str, optional). Defaults to `None`.
        Model type represents the type of model in the user compartment, can be either FT or BASE.
        Only used for the `USER` category, where it defaults to `ModelType.FT`.
    **kwargs:
        Additional keyword arguments that can be used to filter the results.

    Returns
    -------
    List[AquaModelSummary]:
        The list of the `ads.aqua.model.AquaModelSummary`.
    """

    # `category` is a named parameter, so it can never arrive through **kwargs.
    category = category or SERVICE
    compartment_id = compartment_id or COMPARTMENT_OCID
    if category == USER:
        # tracks number of times custom model listing was called
        self.telemetry.record_event_async(
            category="aqua/custom/model", action="list"
        )

        logger.info(f"Fetching custom models from compartment_id={compartment_id}.")
        model_type = model_type.upper() if model_type else ModelType.FT
        models = self._rqs(compartment_id, model_type=model_type)
        logger.info(f"Fetched {len(models)} models from {compartment_id}.")
    else:
        # tracks number of times service model listing was called
        self.telemetry.record_event_async(
            category="aqua/service/model", action="list"
        )

        if AQUA_SERVICE_MODELS in self._service_models_cache:
            logger.info("Returning service models list from cache.")
            return self._service_models_cache.get(AQUA_SERVICE_MODELS)

        lifecycle_state = kwargs.pop(
            "lifecycle_state", Model.LIFECYCLE_STATE_ACTIVE
        )
        models = self.list_resource(
            self.ds_client.list_models,
            compartment_id=compartment_id,
            lifecycle_state=lifecycle_state,
            category=category,
            **kwargs,
        )
        logger.info(f"Fetched {len(models)} service models.")

    # Resolve the container config once and share it across all models.
    inference_containers = self.get_container_config().to_dict().get("inference")
    aqua_models = [
        AquaModelSummary(
            **self._process_model(
                model=model,
                region=self.region,
                inference_containers=inference_containers,
            ),
            project_id=project_id or UNKNOWN,
        )
        for model in models
    ]

    if category == SERVICE:
        self._service_models_cache[AQUA_SERVICE_MODELS] = aqua_models

    return aqua_models
|
1062
|
+
|
1063
|
+
def clear_model_list_cache(
    self,
):
    """
    Allows user to clear the cached service models list.

    Returns
    -------
    dict: `{"cache_deleted": True}` when the cached list was present and
    removed, otherwise an empty dict.
    """
    with self._cache_lock:
        if AQUA_SERVICE_MODELS not in self._service_models_cache:
            return {}
        self._service_models_cache.pop(key=AQUA_SERVICE_MODELS)
        logger.info("Cleared models cache for service compartment.")
        return {
            "cache_deleted": True,
        }
|
1081
|
+
|
1082
|
+
def clear_model_details_cache(self, model_id):
    """
    Allows user to clear the cached details of one model.

    Returns
    -------
    dict with the key used and `"cache_deleted": True` when the model was
    present in the details cache, otherwise an empty dict.
    """
    with self._cache_lock:
        if model_id not in self._service_model_details_cache:
            return {}
        self._service_model_details_cache.pop(key=model_id)
        logger.info(f"Clearing model details cache for model {model_id}.")
        return {"key": {"model_id": model_id}, "cache_deleted": True}
|
1097
|
+
|
1098
|
+
@staticmethod
def list_valid_inference_containers():
    """Returns the `family` of each inference container in the Aqua container config."""
    container_config = AquaApp().get_container_config().to_dict()
    inference_containers = container_config.get("inference")
    return [container.family for container in inference_containers]
|
1103
|
+
|
1104
|
+
@telemetry(
    entry_point="plugin=model&action=get_defined_metadata_artifact_content",
    name="aqua",
)
def get_defined_metadata_artifact_content(self, model_id: str, metadata_key: str):
    """
    Gets the defined metadata artifact content for the given model

    Args:
        model_id: str
            model ocid whose defined metadata artifact is fetched
        metadata_key: str
            defined metadata key like Readme , License , DeploymentConfiguration , FinetuningConfiguration
    Returns:
        Whatever `self.get_config(model_id, metadata_key)` returns; a falsy
        result is logged at debug level and returned unchanged.

    """
    # NOTE(review): get_hf_tokenizer_config reads `.config` from the get_config
    # result, while this method returns the result itself - confirm callers
    # expect that shape.
    artifact_content = self.get_config(model_id, metadata_key)
    if artifact_content:
        return artifact_content

    logger.debug(
        f"Defined metadata artifact {metadata_key} for model: {model_id} is not available."
    )
    return artifact_content
|
1128
|
+
|
1129
|
+
@telemetry(
    entry_point="plugin=model&action=create_defined_metadata_artifact", name="aqua"
)
def create_defined_metadata_artifact(
    self,
    model_id: str,
    metadata_key: str,
    path_type: MetadataArtifactPathType,
    artifact_path_or_content: str,
) -> None:
    """
    Creates defined metadata artifact for the registered unverified model

    Args:
        model_id: str
            model ocid for which defined metadata artifact needs to be created
        metadata_key: str
            defined metadata key like Readme , License , DeploymentConfiguration , FinetuningConfiguration
        path_type: str
            path type of the given defined metadata can be local , oss or the content itself
        artifact_path_or_content: str
            It can be local path or oss path or the actual content itself
    Returns:
        None
    Raises:
        AquaRuntimeError: if the model is not an Aqua model, is not a
            registered unverified model, or the artifact creation fails.
    """

    ds_model = DataScienceModel.from_id(model_id)
    # Freeform tags may be unset; treat that as "no tags" so the caller gets
    # an AquaRuntimeError instead of an AttributeError.
    freeform_tags = ds_model.freeform_tags or {}

    if not freeform_tags.get(Tags.AQUA_TAG, None):
        raise AquaRuntimeError(f"Target model {model_id} is not an Aqua model.")

    is_registered_model = freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None)
    is_verified_model = freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None)
    if not is_registered_model or is_verified_model:
        raise AquaRuntimeError(
            f"Cannot create defined metadata artifact for model {model_id}"
        )

    try:
        ds_model.create_defined_metadata_artifact(
            metadata_key_name=metadata_key,
            artifact_path_or_content=artifact_path_or_content,
            path_type=path_type,
        )
    except Exception as ex:
        raise AquaRuntimeError(
            f"Error occurred in creating defined metadata artifact for model {model_id}: {ex}"
        ) from ex
|
1178
|
+
|
1179
|
+
def _create_model_catalog_entry(
    self,
    os_path: str,
    model_name: str,
    inference_container: str,
    finetuning_container: str,
    verified_model: DataScienceModel,
    validation_result: ModelValidationResult,
    compartment_id: Optional[str],
    project_id: Optional[str],
    inference_container_uri: Optional[str],
    freeform_tags: Optional[dict] = None,
    defined_tags: Optional[dict] = None,
) -> DataScienceModel:
    """Create model by reference from the object storage path

    Args:
        os_path (str): OCI where the model is uploaded - oci://bucket@namespace/prefix
        model_name (str): name of the model
        inference_container (str): selects service defaults
        finetuning_container (str): selects service defaults
        verified_model (DataScienceModel): If set, then copies all the tags and custom metadata information from the service verified model
        validation_result (ModelValidationResult): Source of the model formats, model file and tags recorded on the new model; may be falsy
        compartment_id (Optional[str]): Compartment Id of the compartment where the model has to be created
        project_id (Optional[str]): Project id of the project where the model has to be created
        inference_container_uri (Optional[str]): Inference container uri for BYOC
        freeform_tags (dict): Freeform tags for the model
        defined_tags (dict): Defined tags for the model

    Returns:
        DataScienceModel: Returns Datascience model instance.

    Raises:
        AquaRuntimeError: if there is no verified model and no inference
            container is given.
    """
    model = DataScienceModel()
    tags: Dict[str, str] = (
        {
            # Freeform tags may be unset on the verified model.
            **(verified_model.freeform_tags or {}),
            Tags.AQUA_SERVICE_MODEL_TAG: verified_model.id,
        }
        if verified_model
        else {
            Tags.AQUA_TAG: "active",
            Tags.BASE_MODEL_CUSTOM: "true",
        }
    )
    tags.update({Tags.BASE_MODEL_CUSTOM: "true"})

    if validation_result and validation_result.model_formats:
        tags.update({Tags.MODEL_FORMAT: ",".join(validation_result.model_formats)})

    # Remove `ready_to_import` tag that might get copied from service model.
    tags.pop(Tags.READY_TO_IMPORT, None)
    defined_metadata_dict = {}
    readme_file_path = os_path.rstrip("/") + "/" + README
    license_file_path = os_path.rstrip("/") + "/" + LICENSE
    if verified_model:
        # Verified model is a model in the service catalog that has no artifacts but contains
        # all the necessary metadata for deploying and fine tuning.
        # If set, then we copy all the model metadata.
        metadata = verified_model.custom_metadata_list
        if verified_model.model_file_description:
            model = model.with_model_file_description(
                json_dict=verified_model.model_file_description
            )
        defined_metadata_list = (
            verified_model.defined_metadata_list._to_oci_metadata()
        )
        # Collect the artifact content of every defined metadata entry so it
        # can be re-created on the new model once it exists.
        for defined_metadata in defined_metadata_list:
            if defined_metadata.has_artifact:
                content = (
                    self.ds_client.get_model_defined_metadatum_artifact_content(
                        verified_model.id, defined_metadata.key
                    ).data.content
                )
                defined_metadata_dict[defined_metadata.key] = content
    else:
        metadata = ModelCustomMetadata()
        if not inference_container:
            raise AquaRuntimeError(
                f"Require Inference container information. Model: {model_name} does not have associated inference "
                f"container defaults. Check docs for more information on how to pass inference container."
            )
        metadata.add(
            key=AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
            value=inference_container,
            description=f"Inference container mapping for {model_name}",
            category="Other",
        )
        if inference_container_uri:
            metadata.add(
                key=AQUA_DEPLOYMENT_CONTAINER_URI_METADATA_NAME,
                value=inference_container_uri,
                description=f"Inference container URI for {model_name}",
                category="Other",
            )

        inference_containers = (
            AquaContainerConfig.from_service_config(
                service_containers=self.list_service_containers()
            )
            .to_dict()
            .get("inference")
        )
        smc_container_set = {container.family for container in inference_containers}
        # only add cmd vars if inference container is not an SMC
        if (
            inference_container not in smc_container_set
            and inference_container in CustomInferenceContainerTypeFamily.values()
        ):
            cmd_vars = generate_tei_cmd_var(os_path)
            metadata.add(
                key=AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME,
                value=" ".join(cmd_vars),
                description=f"Inference container cmd vars for {model_name}",
                category="Other",
            )

        if finetuning_container:
            tags[Tags.READY_TO_FINE_TUNE] = "true"
            metadata.add(
                key=AQUA_FINETUNING_CONTAINER_METADATA_NAME,
                value=finetuning_container,
                description=f"Fine-tuning container mapping for {model_name}",
                category="Other",
            )
        else:
            logger.warning(
                "Proceeding with model registration without the fine-tuning container information. "
                "This model will not be available for fine tuning."
            )
        if validation_result and validation_result.model_file:
            metadata.add(
                key=AQUA_MODEL_ARTIFACT_FILE,
                value=validation_result.model_file,
                description=f"The model file for {model_name}",
                category="Other",
            )

        metadata.add(
            key=AQUA_EVALUATION_CONTAINER_METADATA_NAME,
            value="odsc-llm-evaluate",
            description="Evaluation container mapping for SMC",
            category="Other",
        )

        if validation_result and validation_result.tags:
            tags[Tags.TASK] = validation_result.tags.get(Tags.TASK, UNKNOWN)
            tags[Tags.ORGANIZATION] = validation_result.tags.get(
                Tags.ORGANIZATION, UNKNOWN
            )
            tags[Tags.LICENSE] = validation_result.tags.get(Tags.LICENSE, UNKNOWN)

    # Set artifact location to user bucket, and replace existing key if present.
    metadata.add(
        key=MODEL_BY_REFERENCE_OSS_PATH_KEY,
        value=os_path,
        description="artifact location",
        category="Other",
        replace=True,
    )
    # override tags with freeform tags if set
    tags = {**tags, **(freeform_tags or {})}
    model = (
        model.with_custom_metadata_list(metadata)
        .with_compartment_id(compartment_id or COMPARTMENT_OCID)
        .with_project_id(project_id or PROJECT_OCID)
        .with_artifact(os_path)
        .with_display_name(model_name)
        .with_freeform_tags(**tags)
        .with_defined_tags(**(defined_tags or {}))
    ).create(model_by_reference=True)
    logger.debug(f"Created model catalog entry for the model:\n{model}")
    for key, value in defined_metadata_dict.items():
        model.create_defined_metadata_artifact(
            key, value, MetadataArtifactPathType.CONTENT
        )

    # Readme/License uploads are best effort: failures are logged, not raised.
    if is_path_exists(readme_file_path):
        try:
            model.create_defined_metadata_artifact(
                AquaModelMetadataKeys.README,
                readme_file_path,
                MetadataArtifactPathType.OSS,
            )
        except Exception as ex:
            logger.error(
                f"Error Uploading Readme in defined metadata for model: {model.id} : {str(ex)}"
            )
    if not verified_model and is_path_exists(license_file_path):
        try:
            model.create_defined_metadata_artifact(
                AquaModelMetadataKeys.LICENSE,
                license_file_path,
                MetadataArtifactPathType.OSS,
            )
        except Exception as ex:
            logger.error(
                f"Error Uploading License in defined metadata for model: {model.id} : {str(ex)}"
            )
    return model
|
1382
|
+
|
1383
|
+
@staticmethod
|
1384
|
+
def get_model_files(os_path: str, model_format: str) -> List[str]:
    """
    Collect the model files available at an object storage path for a model format.

    For the SAFETENSORS format the ``.safetensors`` files are listed and the
    artifact config file name is added when the config can be loaded from
    ``os_path``; a config that cannot be loaded is only reported as a warning.
    For the GGUF format the ``.gguf`` files are listed.

    Args:
        os_path (str): The OS path where the model files are located.
        model_format (str): The format of the model files.

    Returns:
        List[str]: A list of model file names.

    """
    # todo: revisit this logic to account for .bin files. In the current state, .bin and .safetensor models
    #   are grouped in one category and validation checks for config.json files only.
    collected: List[str] = []

    if model_format == ModelFormat.SAFETENSORS:
        collected.extend(
            list_os_files_with_extension(oss_path=os_path, extension=".safetensors")
        )
        try:
            load_config(file_path=os_path, config_file_name=AQUA_MODEL_ARTIFACT_CONFIG)
        except Exception as err:
            # Best effort: the caller decides later whether a missing config is fatal.
            details = (
                err.reason if isinstance(err, AquaFileNotFoundError) else str(err)
            )
            message = (
                f"The model path {os_path} does not contain the file config.json. "
                f"Please check if the path is correct or the model artifacts are available at this location."
            )
            logger.warning(f"{message}\nDetails: {details}\n")
        else:
            collected.append(AQUA_MODEL_ARTIFACT_CONFIG)

    if model_format == ModelFormat.GGUF:
        collected.extend(
            list_os_files_with_extension(oss_path=os_path, extension=".gguf")
        )

    logger.debug(
        f"Fetched {len(collected)} model files from {os_path} for model format {model_format}."
    )
    return collected
|
1428
|
+
|
1429
|
+
@staticmethod
|
1430
|
+
def get_hf_model_files(model_name: str, model_format: str) -> List[str]:
    """
    List the files of a Hugging Face repository that belong to a model format.

    A file is selected when its upper-cased extension equals ``model_format``.
    For the SAFETENSORS format the artifact config file is selected as well.

    Args:
        model_name (str): The huggingface model name.
        model_format (str): The format of the model files.

    Returns:
        List[str]: A list of model file names.

    Raises:
        AquaValueError: If the repository information cannot be fetched or the
            repository reports no files.

    """
    # todo: revisit this logic to account for .bin files. In the current state, .bin and .safetensor models
    #   are grouped in one category and returns config.json file only.
    try:
        repo_files = get_hf_model_info(repo_id=model_name).siblings
    except Exception as err:
        raise AquaValueError(
            f"Could not get the model files of {model_name} from https://huggingface.co. "
            f"Error: {str(err)}."
        ) from err

    if not repo_files:
        raise AquaValueError(
            f"Failed to fetch the model files of {model_name} from https://huggingface.co."
        )

    matched: List[str] = []
    for repo_file in repo_files:
        file_name = repo_file.rfilename
        suffix = pathlib.Path(file_name).suffix[1:].upper()
        if (
            model_format == ModelFormat.SAFETENSORS
            and file_name == AQUA_MODEL_ARTIFACT_CONFIG
        ):
            matched.append(file_name)
        if suffix == model_format:
            matched.append(file_name)

    logger.debug(
        f"Fetched {len(matched)} model files for the model {model_name} for model format {model_format}."
    )
    return matched
|
1475
|
+
|
1476
|
+
def _validate_model(
    self,
    import_model_details: ImportModelDetails = None,
    model_name: str = None,
    verified_model: DataScienceModel = None,
) -> ModelValidationResult:
    """
    Validate the model artifacts and build the validation result.

    The model files are discovered either in the Hugging Face repository or at
    the object storage path, the supported model formats are derived from them
    (or copied from the verified model), and the format specific validation is
    run to fill in the telemetry model name and, for GGUF, the model file.

    Args:
        import_model_details (ImportModelDetails): Model details for importing the model.
        model_name (str): name of the model
        verified_model (DataScienceModel): If set, then copies all the tags and custom metadata information from
            the service verified model

    Returns:
        ModelValidationResult: The result of the model validation.

    Raises:
        AquaRuntimeError: If there is an error while loading the config file or if the model path is incorrect.
        AquaValueError: If the model format is not supported by AQUA.
    """
    result: ModelValidationResult = ModelValidationResult()
    config_in_hf_repo = False

    # Discover the candidate files for both supported formats.
    if import_model_details.download_from_hf:
        safetensors_files = self.get_hf_model_files(
            model_name, ModelFormat.SAFETENSORS
        )
        if safetensors_files and AQUA_MODEL_ARTIFACT_CONFIG in safetensors_files:
            config_in_hf_repo = True
        gguf_files = self.get_hf_model_files(model_name, ModelFormat.GGUF)
    else:
        safetensors_files = self.get_model_files(
            import_model_details.os_path, ModelFormat.SAFETENSORS
        )
        gguf_files = self.get_model_files(
            import_model_details.os_path, ModelFormat.GGUF
        )

    if not safetensors_files and not gguf_files:
        raise AquaRuntimeError(
            f"The model {model_name} does not contain either {ModelFormat.SAFETENSORS} "
            f"or {ModelFormat.GGUF} files in {import_model_details.os_path} or Hugging Face repository. "
            f"Please check if the path is correct or the model artifacts are available at this location."
        )

    if verified_model:
        # Service verified models already declare the formats they support.
        formats = self.to_aqua_model(verified_model, self.region).model_formats
    else:
        formats = []
        if safetensors_files:
            formats.append(ModelFormat.SAFETENSORS)
        if gguf_files:
            formats.append(ModelFormat.GGUF)

        # get tags for models from hf
        if import_model_details.download_from_hf:
            model_info = get_hf_model_info(repo_id=model_name)

            # Tag extraction is best effort: a failure is logged and ignored.
            try:
                license_value = UNKNOWN
                if model_info.tags:
                    license_tag = UNKNOWN
                    for hf_tag in model_info.tags:
                        if hf_tag.startswith("license:"):
                            license_tag = hf_tag
                            break
                    if license_tag:
                        license_value = license_tag.split(":")[1]

                result.tags = {
                    Tags.TASK: (model_info and model_info.pipeline_tag) or UNKNOWN,
                    Tags.ORGANIZATION: (
                        getattr(model_info, "author", UNKNOWN)
                        if model_info
                        else UNKNOWN
                    ),
                    Tags.LICENSE: license_value,
                }
            except Exception as err:
                logger.debug(
                    f"An error occurred while getting tag information for model {model_name}. "
                    f"Error: {str(err)}"
                )

    result.model_formats = formats

    # now as we know that at least one type of model files exist, validate the content of oss path.
    # for safetensors, we check if config.json files exist, and for gguf format we check if files with
    # gguf extension exist.
    has_safetensors = ModelFormat.SAFETENSORS in formats
    has_gguf = ModelFormat.GGUF in formats
    if has_safetensors and has_gguf:
        # Both formats are present: the inference container family picks the validation.
        validate_as_gguf = (
            import_model_details.inference_container.lower()
            == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
        )
    else:
        validate_as_gguf = has_gguf

    if validate_as_gguf:
        self._validate_gguf_format(
            import_model_details=import_model_details,
            verified_model=verified_model,
            gguf_model_files=gguf_files,
            validation_result=result,
            model_name=model_name,
        )
    elif has_safetensors:
        self._validate_safetensor_format(
            import_model_details=import_model_details,
            verified_model=verified_model,
            validation_result=result,
            hf_download_config_present=config_in_hf_repo,
            model_name=model_name,
        )

    return result
|
1615
|
+
|
1616
|
+
@staticmethod
|
1617
|
+
def _validate_safetensor_format(
    import_model_details: ImportModelDetails = None,
    verified_model: DataScienceModel = None,
    validation_result: ModelValidationResult = None,
    hf_download_config_present: bool = None,
    model_name: str = None,
):
    """
    Validate a SAFETENSORS model and set the telemetry model name on the result.

    Args:
        import_model_details (ImportModelDetails): Model details for importing the model.
        verified_model (DataScienceModel): The service verified model, if any.
        validation_result (ModelValidationResult): Result object updated in place.
        hf_download_config_present (bool): Whether the config file was found in the
            Hugging Face repository.
        model_name (str): name of the model

    Raises:
        AquaRuntimeError: If the config file is missing and
            ``ignore_model_artifact_check`` is not set.
    """
    if import_model_details.download_from_hf:
        # validates config.json exists for safetensors model from huggingface
        if (
            not hf_download_config_present
            and not import_model_details.ignore_model_artifact_check
        ):
            raise AquaRuntimeError(
                f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required "
                f"by {ModelFormat.SAFETENSORS} format model."
                f" Please check if the model name is correct in Hugging Face repository."
            )
        validation_result.telemetry_model_name = model_name
        return

    # validate if config.json is available from object storage, and get model name for telemetry
    model_config = None
    try:
        model_config = load_config(
            file_path=import_model_details.os_path,
            config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
        )
    except Exception as err:
        message = (
            f"The model path {import_model_details.os_path} does not contain the file config.json. "
            f"Please check if the path is correct or the model artifacts are available at this location."
        )
        if import_model_details.ignore_model_artifact_check:
            logger.warning(
                f"{message}\n"
                f"Proceeding with model registration as ignore_model_artifact_check field is set."
            )
        else:
            details = (
                err.reason if isinstance(err, AquaFileNotFoundError) else str(err)
            )
            logger.error(f"{message}\nDetails: {details}")
            raise AquaRuntimeError(message) from err

    if verified_model:
        # model_type validation, log message if metadata field doesn't match.
        try:
            expected_model_type = verified_model.custom_metadata_list.get(
                AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
            ).value
            if expected_model_type and model_config is not None:
                if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE not in model_config:
                    logger.debug(
                        f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
                        f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
                    )
                elif (
                    model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]
                    != expected_model_type
                ):
                    logger.debug(
                        f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
                        f" at {import_model_details.os_path} is invalid, expected {expected_model_type} for "
                        f"the model {model_name}. Please check if the path is correct or "
                        f"the correct model artifacts are available at this location."
                    )
        except Exception as err:
            # todo: raise exception if model_type doesn't match. Currently log message and pass since service
            #   models do not have this metadata.
            logger.debug(
                f"Error occurred while processing metadata for model {model_name}. "
                f"Exception: {str(err)}"
            )
        validation_result.telemetry_model_name = verified_model.display_name
        return

    # Custom model: prefer the config's model name, then its model type.
    telemetry_name = AQUA_MODEL_TYPE_CUSTOM
    if model_config is not None:
        for config_key in (
            AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
            AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
        ):
            if config_key in model_config:
                telemetry_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[config_key]}"
                break
    validation_result.telemetry_model_name = telemetry_name
|
1705
|
+
|
1706
|
+
@staticmethod
|
1707
|
+
def _validate_gguf_format(
    import_model_details: ImportModelDetails = None,
    verified_model: DataScienceModel = None,
    gguf_model_files: List[str] = None,
    validation_result: ModelValidationResult = None,
    model_name: str = None,
):
    """
    Validate a GGUF model and set the model file and telemetry name on the result.

    The model file comes from the verified model's custom metadata when a
    verified model is given, otherwise from ``import_model_details.model_file``.
    When no file is named, exactly one GGUF file must be available.

    Args:
        import_model_details (ImportModelDetails): Model details for importing the model.
        verified_model (DataScienceModel): The service verified model, if any.
        gguf_model_files (List[str]): GGUF files found for the model.
        validation_result (ModelValidationResult): Result object updated in place.
        model_name (str): name of the model

    Raises:
        AquaValueError: If a fine-tuning container is requested.
        AquaRuntimeError: If the verified model lacks the artifact file metadata,
            the named file is not available, or no/multiple GGUF files exist
            without a named file.
    """
    if import_model_details.finetuning_container:
        raise AquaValueError(
            "Fine-tuning is currently not supported with GGUF model format."
        )

    if not verified_model:
        requested_file = import_model_details.model_file
    else:
        try:
            requested_file = verified_model.custom_metadata_list.get(
                AQUA_MODEL_ARTIFACT_FILE
            ).value
        except ValueError as err:
            raise AquaRuntimeError(
                f"The model {verified_model.display_name} does not contain the custom metadata {AQUA_MODEL_ARTIFACT_FILE}. "
                f"Please check if the model has the valid metadata."
            ) from err

    # todo: have a separate error validation class for different type of error messages.
    if requested_file:
        if requested_file not in gguf_model_files:
            raise AquaRuntimeError(
                f"The model path {import_model_details.os_path} or the Hugging Face "
                f"model repository for {model_name} does not contain the file "
                f"{requested_file}. Please check if the path is correct or the model "
                f"artifacts are available at this location."
            )
        validation_result.model_file = requested_file
    else:
        available = len(gguf_model_files)
        if available == 0:
            raise AquaRuntimeError(
                f"The model path {import_model_details.os_path} or the Hugging Face model "
                f"repository for {model_name} does not contain any GGUF format files. "
                f"Please check if the path is correct or the model artifacts are available "
                f"at this location."
            )
        if available > 1:
            raise AquaRuntimeError(
                f"The model path {import_model_details.os_path} or the Hugging Face model "
                f"repository for {model_name} contains multiple GGUF format files. "
                f"Please specify the file that needs to be deployed using the model_file "
                f"parameter."
            )
        validation_result.model_file = gguf_model_files[0]

    if verified_model:
        telemetry_name = verified_model.display_name
    elif import_model_details.download_from_hf:
        telemetry_name = model_name
    else:
        telemetry_name = AQUA_MODEL_TYPE_CUSTOM
    validation_result.telemetry_model_name = telemetry_name
|
1766
|
+
|
1767
|
+
@staticmethod
|
1768
|
+
def _download_model_from_hf(
    model_name: str,
    os_path: str,
    local_dir: str = None,
    allow_patterns: List[str] = None,
    ignore_patterns: List[str] = None,
) -> str:
    """Download the model artifact from Hugging Face to a local folder and upload it
    to the object storage location.

    Parameters
    ----------
    model_name (str): The huggingface model name.
    os_path (str): The OS path where the model files are located.
    local_dir (str): The local temp dir to store the huggingface model.
    allow_patterns (list): Model files matching at least one pattern are downloaded.
        Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`.
        Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
    ignore_patterns (list): Model files matching any of the patterns are not downloaded.
        Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
        Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html

    Returns
    -------
    model_artifact_path (str): Location where the model artifacts are downloaded.
    """
    # Download into a per-model sub-folder when a local directory is requested.
    target_dir = local_dir
    if target_dir:
        target_dir = os.path.join(target_dir, model_name)
        os.makedirs(target_dir, exist_ok=True)

    # if local_dir is not set, the return value points to the cached data folder
    snapshot_dir = snapshot_download(
        repo_id=model_name,
        local_dir=target_dir,
        allow_patterns=allow_patterns,
        ignore_patterns=ignore_patterns,
    )

    logger.debug(
        f"Uploading local artifacts from local directory {snapshot_dir} to {os_path}."
    )
    # Upload to object storage and skip .cache/huggingface/ folder
    return upload_folder(
        os_path=os_path,
        local_dir=snapshot_dir,
        model_name=model_name,
        exclude_pattern=f"{HF_METADATA_FOLDER}*",
    )
|
1819
|
+
|
1820
|
+
def register(
    self, import_model_details: ImportModelDetails = None, **kwargs
) -> AquaModel:
    """Loads the model from object storage and registers as Model in Data Science Model catalog
    The inference container and finetuning container could be of type Service Managed Container(SMC) or custom.
    If it is custom, full container URI is expected. If it of type SMC, only the container family name is expected.

    For detailed information about CLI flags see: https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/cli-tips.md#register-model

    Args:
        import_model_details (ImportModelDetails): Model details for importing the model.
        kwargs:
            model (str): name of the model or OCID of the service model that has inference and finetuning information
            os_path (str): Object storage destination URI to store the downloaded model. Format: oci://bucket-name@namespace/prefix
            inference_container (str): selects service defaults
            finetuning_container (str): selects service defaults
            allow_patterns (list): Model files matching at least one pattern are downloaded.
                Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`.
                Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
            ignore_patterns (list): Model files matching any of the patterns are not downloaded.
                Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
                Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
            cleanup_model_cache (bool): Deletes downloaded files from local machine after model is successfully
            registered. Set to True by default.

    Returns:
        AquaModel:
            The registered model as a AquaModel object.
    """
    details = import_model_details or ImportModelDetails(**kwargs)
    requested_model = details.model

    # If OCID of a model is passed, we need to copy the defaults for Tags and metadata from the service model.
    verified_model: Optional[DataScienceModel] = None
    if requested_model.startswith("ocid") and "datasciencemodel" in requested_model:
        logger.info(f"Fetching details for model {details.model}.")
        verified_model = DataScienceModel.from_id(details.model)
    else:
        # When a model name is passed, look for a model with the same name in the
        # service model catalog and use it when one exists.
        service_model_id = self._find_matching_aqua_model(details.model)
        if service_model_id:
            logger.info(
                f"Found service model for {details.model}: {service_model_id}"
            )
            verified_model = DataScienceModel.from_id(service_model_id)

    # Copy the model name from the service model if `model` is ocid
    if verified_model:
        model_name = verified_model.display_name
    else:
        model_name = details.model

    # validate model and artifact
    validation_result = self._validate_model(
        import_model_details=details,
        model_name=model_name,
        verified_model=verified_model,
    )

    # Download the model from Hugging Face when requested.
    if details.download_from_hf:
        artifact_path = self._download_model_from_hf(
            model_name=model_name,
            os_path=details.os_path,
            local_dir=details.local_dir,
            allow_patterns=details.allow_patterns,
            ignore_patterns=details.ignore_patterns,
        )
    else:
        artifact_path = details.os_path
    artifact_path = artifact_path.rstrip("/")

    # Create Model catalog entry with pass by reference
    ds_model = self._create_model_catalog_entry(
        os_path=artifact_path,
        model_name=model_name,
        inference_container=details.inference_container,
        finetuning_container=details.finetuning_container,
        verified_model=verified_model,
        validation_result=validation_result,
        compartment_id=details.compartment_id,
        project_id=details.project_id,
        inference_container_uri=details.inference_container_uri,
        freeform_tags=details.freeform_tags,
        defined_tags=details.defined_tags,
    )

    def _metadata_value(field):
        # Read a custom metadata value, falling back to an empty item for the key.
        return ds_model.custom_metadata_list.get(
            field, ModelCustomMetadataItem(key=field)
        ).value

    # registered model will always have inference and evaluation container, but
    # fine-tuning container may be not set
    inference_container = _metadata_value(
        ModelCustomMetadataFields.DEPLOYMENT_CONTAINER
    )
    inference_container_uri = _metadata_value(
        ModelCustomMetadataFields.DEPLOYMENT_CONTAINER_URI
    )
    evaluation_container = _metadata_value(
        ModelCustomMetadataFields.EVALUATION_CONTAINER
    )
    finetuning_container: str = _metadata_value(
        ModelCustomMetadataFields.FINETUNE_CONTAINER
    )

    aqua_model_attributes = dict(
        **self._process_model(ds_model, self.region),
        project_id=ds_model.project_id,
        inference_container=inference_container,
        inference_container_uri=inference_container_uri,
        finetuning_container=finetuning_container,
        evaluation_container=evaluation_container,
        artifact_location=artifact_path,
    )

    self.telemetry.record_event_async(
        category="aqua/model",
        action="register",
        detail=validation_result.telemetry_model_name,
    )

    # Remove the locally downloaded artifacts once the model is registered.
    if details.download_from_hf and details.cleanup_model_cache:
        cleanup_local_hf_model_artifact(
            model_name=model_name, local_dir=details.local_dir
        )

    return AquaModel(**aqua_model_attributes)
|
1956
|
+
|
1957
|
+
def _if_show(self, model: DataScienceModel) -> bool:
    """Tell whether `list` should return the given model.

    A model is shown when its freeform tags contain the AQUA tag key, either as
    defined or in lower case. Models without freeform tags are never shown.
    """
    freeform_tags = model.freeform_tags
    if freeform_tags is None:
        return False

    tag_keys = freeform_tags.keys()
    if Tags.AQUA_TAG in tag_keys:
        return True
    return Tags.AQUA_TAG.lower() in tag_keys
|
1964
|
+
|
1965
|
+
def _load_icon(self, model_name: str) -> str:
    """Build the icon for the model as a data URI.

    Returns None when the icon cannot be created; the failure is logged at
    debug level.
    """
    # TODO: switch to the official logo
    icon = None
    try:
        icon = create_word_icon(model_name, return_as_datauri=True)
    except Exception as err:
        logger.debug(f"Failed to load icon for the model={model_name}: {str(err)}.")
    return icon
|
1974
|
+
|
1975
|
+
def _rqs(self, compartment_id: str, model_type="FT", **kwargs):
    """Use RQS to fetch models in the user tenancy.

    Builds a structured search query for active Data Science models in the
    compartment that carry both the AQUA tag and the tag matching
    ``model_type``, and runs it through ``OCIResource.search``.

    Raises
    ------
    AquaValueError
        If ``model_type`` is neither the fine-tuned nor the base model type.
    """
    # NOTE: a multi-model type filter (Tags.MULTIMODEL_TYPE_TAG) is not enabled here.
    if model_type == ModelType.FT:
        type_tag = Tags.AQUA_FINE_TUNED_MODEL_TAG
    elif model_type == ModelType.BASE:
        type_tag = Tags.BASE_MODEL_CUSTOM
    else:
        raise AquaValueError(
            f"Model of type {model_type} is unknown. The values should be in {ModelType.values()}"
        )

    lifecycle_clause = "&& lifecycleState = 'ACTIVE'"
    tags_clause = f"&& (freeformTags.key = '{Tags.AQUA_TAG}' && freeformTags.key = '{type_tag}')"
    search_query = (
        "query datasciencemodel resources where "
        f"(compartmentId = '{compartment_id}' {lifecycle_clause} {tags_clause})"
    )

    logger.info(search_query)
    logger.info(f"tenant_id={TENANCY_OCID}")
    return OCIResource.search(
        search_query, type=SEARCH_TYPE.STRUCTURED, tenant_id=TENANCY_OCID, **kwargs
    )
|
1996
|
+
|
1997
|
+
@staticmethod
|
1998
|
+
def _build_search_text(tags: dict, description: str = None) -> str:
|
1999
|
+
"""Constructs search_text field in response."""
|
2000
|
+
description = description or ""
|
2001
|
+
tags_text = (
|
2002
|
+
",".join(str(v) for v in tags.values()) if isinstance(tags, dict) else ""
|
2003
|
+
)
|
2004
|
+
separator = " " if description else ""
|
2005
|
+
return f"{description}{separator}{tags_text}"
|
2006
|
+
|
2007
|
+
@telemetry(entry_point="plugin=model&action=load_readme", name="aqua")
|
2008
|
+
def load_readme(self, model_id: str) -> AquaModelReadme:
    """Loads the readme or the model card for the given model.

    The readme is first read from the model's defined metadata. If that fails,
    the readme file is looked up in object storage relative to the model's
    artifact path. When neither source yields content, an empty model card is
    returned.

    Parameters
    ----------
    model_id: str
        The model id.

    Returns
    -------
    AquaModelReadme:
        The instance of AquaModelReadme.

    Raises
    ------
    AquaRuntimeError
        If the artifact path cannot be resolved from the model's custom metadata.
    """
    oci_model = self.ds_client.get_model(model_id).data
    artifact_path = get_artifact_path(oci_model.custom_metadata_list)
    if not artifact_path:
        # The trailing space keeps the two message fragments from running together.
        raise AquaRuntimeError(
            f"Readme could not be loaded. Failed to get artifact path from custom metadata for "
            f"the model {model_id}."
        )

    content = ""
    try:
        content = self.ds_client.get_model_defined_metadatum_artifact_content(
            model_id, AquaModelMetadataKeys.README
        ).data.content.decode("utf-8", errors="ignore")
        logger.info(f"Fetched {README} from defined metadata for model: {model_id}")
    except Exception as ex:
        logger.error(
            f"Readme could not be found for model: {model_id} in defined metadata : {str(ex)}"
        )
        # Fall back to object storage; artifact_path was already resolved and
        # validated above, so it is reused here.
        artifact_root = artifact_path.rstrip("/")
        # Candidate folders, in order: "artifact" next to the artifact path,
        # "artifact" inside it, and finally the artifact path itself.
        readme_path = os.path.join(os.path.dirname(artifact_path), "artifact")
        if not is_path_exists(readme_path):
            readme_path = os.path.join(artifact_root, "artifact")
            if not is_path_exists(readme_path):
                readme_path = f"{artifact_root}/"

        readme_file_path = os.path.join(readme_path, README)
        logger.info(f"Fetching {README} from {readme_file_path}")
        if is_path_exists(readme_file_path):
            try:
                content = str(read_file(readme_file_path, auth=default_signer()))
            except Exception as e:
                # Best effort: an unreadable file leaves the model card empty.
                logger.debug(
                    f"Error occurred while fetching config {README} at path {readme_file_path} : {str(e)}"
                )
    return AquaModelReadme(id=model_id, model_card=content)
|
2056
|
+
|
2057
|
+
@telemetry(entry_point="plugin=model&action=load_license", name="aqua")
|
2058
|
+
def load_license(self, model_id: str) -> AquaModelLicense:
    """Loads the license full text for the given model.

    The license is first read from the model's defined metadata. If that fails,
    the license file is looked up in object storage relative to the model's
    artifact path. When neither source yields content, an empty license text is
    returned.

    Parameters
    ----------
    model_id: str
        The model id.

    Returns
    -------
    AquaModelLicense:
        The instance of AquaModelLicense.

    Raises
    ------
    AquaRuntimeError
        If the artifact path cannot be resolved from the model's custom metadata.
    """
    oci_model = self.ds_client.get_model(model_id).data
    artifact_path = get_artifact_path(oci_model.custom_metadata_list)
    if not artifact_path:
        # The trailing space keeps the two message fragments from running together.
        raise AquaRuntimeError(
            f"License could not be loaded. Failed to get artifact path from custom metadata for "
            f"the model {model_id}."
        )

    content = ""
    try:
        content = self.ds_client.get_model_defined_metadatum_artifact_content(
            model_id, AquaModelMetadataKeys.LICENSE
        ).data.content.decode("utf-8", errors="ignore")
        logger.info(
            f"Fetched {LICENSE} from defined metadata for model: {model_id}"
        )
    except Exception as ex:
        logger.error(
            f"License could not be found for model: {model_id} in defined metadata : {str(ex)}"
        )
        # Fall back to object storage; artifact_path was already resolved and
        # validated above, so it is reused here.
        artifact_root = artifact_path.rstrip("/")
        # Candidate folders, in order: "config" next to the artifact path,
        # "config" inside it, and finally the artifact path itself.
        license_path = os.path.join(os.path.dirname(artifact_path), "config")
        if not is_path_exists(license_path):
            license_path = os.path.join(artifact_root, "config")
            if not is_path_exists(license_path):
                license_path = f"{artifact_root}/"

        license_file_path = os.path.join(license_path, LICENSE)
        logger.info(f"Fetching {LICENSE} from {license_file_path}")
        if is_path_exists(license_file_path):
            try:
                content = str(read_file(license_file_path, auth=default_signer()))
            except Exception as e:
                # Best effort: an unreadable file leaves the license text empty.
                # Report the file that was actually read, not only its folder.
                logger.debug(
                    f"Error occurred while fetching config {LICENSE} at path {license_file_path} : {str(e)}"
                )
    return AquaModelLicense(id=model_id, license=content)
|
2108
|
+
|
2109
|
+
def _find_matching_aqua_model(self, model_id: str) -> Optional[str]:
    """
    Look up a model returned by ``self.list()`` whose name equals ``model_id``,
    ignoring case.

    Parameters
    ----------
    model_id (str): Verified model ID to match.

    Returns
    -------
    Optional[str]
        Returns model ocid that matches the model in the service catalog else returns None.
    """
    # Lower-case the requested name once instead of on every comparison.
    wanted_name = model_id.lower()

    match = next(
        (
            summary
            for summary in self.list()
            if summary.name.lower() == wanted_name
        ),
        None,
    )
    if match is None:
        return None

    logger.info(
        f"Found matching verified model id {match.id} for the model {model_id}"
    )
    return match.id
|