oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.9rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +40 -0
- ads/aqua/app.py +506 -0
- ads/aqua/cli.py +96 -0
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +836 -0
- ads/aqua/client/openai_client.py +305 -0
- ads/aqua/common/__init__.py +5 -0
- ads/aqua/common/decorator.py +125 -0
- ads/aqua/common/entities.py +269 -0
- ads/aqua/common/enums.py +122 -0
- ads/aqua/common/errors.py +109 -0
- ads/aqua/common/utils.py +1285 -0
- ads/aqua/config/__init__.py +4 -0
- ads/aqua/config/container_config.py +248 -0
- ads/aqua/config/evaluation/__init__.py +4 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
- ads/aqua/config/utils/__init__.py +4 -0
- ads/aqua/config/utils/serializer.py +339 -0
- ads/aqua/constants.py +116 -0
- ads/aqua/data.py +14 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation/__init__.py +8 -0
- ads/aqua/evaluation/constants.py +53 -0
- ads/aqua/evaluation/entities.py +186 -0
- ads/aqua/evaluation/errors.py +70 -0
- ads/aqua/evaluation/evaluation.py +1814 -0
- ads/aqua/extension/__init__.py +42 -0
- ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
- ads/aqua/extension/base_handler.py +90 -0
- ads/aqua/extension/common_handler.py +121 -0
- ads/aqua/extension/common_ws_msg_handler.py +36 -0
- ads/aqua/extension/deployment_handler.py +298 -0
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +30 -0
- ads/aqua/extension/evaluation_handler.py +129 -0
- ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
- ads/aqua/extension/finetune_handler.py +96 -0
- ads/aqua/extension/model_handler.py +390 -0
- ads/aqua/extension/models/__init__.py +0 -0
- ads/aqua/extension/models/ws_models.py +145 -0
- ads/aqua/extension/models_ws_msg_handler.py +50 -0
- ads/aqua/extension/ui_handler.py +282 -0
- ads/aqua/extension/ui_websocket_handler.py +130 -0
- ads/aqua/extension/utils.py +133 -0
- ads/aqua/finetuning/__init__.py +7 -0
- ads/aqua/finetuning/constants.py +23 -0
- ads/aqua/finetuning/entities.py +181 -0
- ads/aqua/finetuning/finetuning.py +749 -0
- ads/aqua/model/__init__.py +8 -0
- ads/aqua/model/constants.py +60 -0
- ads/aqua/model/entities.py +385 -0
- ads/aqua/model/enums.py +32 -0
- ads/aqua/model/model.py +2114 -0
- ads/aqua/modeldeployment/__init__.py +8 -0
- ads/aqua/modeldeployment/constants.py +10 -0
- ads/aqua/modeldeployment/deployment.py +1326 -0
- ads/aqua/modeldeployment/entities.py +653 -0
- ads/aqua/modeldeployment/inference.py +74 -0
- ads/aqua/modeldeployment/utils.py +543 -0
- ads/aqua/resources/gpu_shapes_index.json +94 -0
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +476 -0
- ads/aqua/ui.py +499 -0
- ads/automl/__init__.py +9 -0
- ads/automl/driver.py +330 -0
- ads/automl/provider.py +975 -0
- ads/bds/__init__.py +5 -0
- ads/bds/auth.py +127 -0
- ads/bds/big_data_service.py +255 -0
- ads/catalog/__init__.py +19 -0
- ads/catalog/model.py +1576 -0
- ads/catalog/notebook.py +461 -0
- ads/catalog/project.py +468 -0
- ads/catalog/summary.py +178 -0
- ads/common/__init__.py +11 -0
- ads/common/analyzer.py +65 -0
- ads/common/artifact/.model-ignore +63 -0
- ads/common/artifact/__init__.py +10 -0
- ads/common/auth.py +1122 -0
- ads/common/card_identifier.py +83 -0
- ads/common/config.py +647 -0
- ads/common/data.py +165 -0
- ads/common/decorator/__init__.py +9 -0
- ads/common/decorator/argument_to_case.py +88 -0
- ads/common/decorator/deprecate.py +69 -0
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/decorator/runtime_dependency.py +178 -0
- ads/common/decorator/threaded.py +97 -0
- ads/common/decorator/utils.py +35 -0
- ads/common/dsc_file_system.py +303 -0
- ads/common/error.py +14 -0
- ads/common/extended_enum.py +81 -0
- ads/common/function/__init__.py +5 -0
- ads/common/function/fn_util.py +142 -0
- ads/common/function/func_conf.yaml +25 -0
- ads/common/ipython.py +76 -0
- ads/common/model.py +679 -0
- ads/common/model_artifact.py +1759 -0
- ads/common/model_artifact_schema.json +107 -0
- ads/common/model_export_util.py +664 -0
- ads/common/model_metadata.py +24 -0
- ads/common/object_storage_details.py +296 -0
- ads/common/oci_client.py +175 -0
- ads/common/oci_datascience.py +46 -0
- ads/common/oci_logging.py +1144 -0
- ads/common/oci_mixin.py +957 -0
- ads/common/oci_resource.py +136 -0
- ads/common/serializer.py +559 -0
- ads/common/utils.py +1852 -0
- ads/common/word_lists.py +1491 -0
- ads/common/work_request.py +189 -0
- ads/data_labeling/__init__.py +13 -0
- ads/data_labeling/boundingbox.py +253 -0
- ads/data_labeling/constants.py +47 -0
- ads/data_labeling/data_labeling_service.py +244 -0
- ads/data_labeling/interface/__init__.py +5 -0
- ads/data_labeling/interface/loader.py +16 -0
- ads/data_labeling/interface/parser.py +16 -0
- ads/data_labeling/interface/reader.py +23 -0
- ads/data_labeling/loader/__init__.py +5 -0
- ads/data_labeling/loader/file_loader.py +241 -0
- ads/data_labeling/metadata.py +110 -0
- ads/data_labeling/mixin/__init__.py +5 -0
- ads/data_labeling/mixin/data_labeling.py +232 -0
- ads/data_labeling/ner.py +129 -0
- ads/data_labeling/parser/__init__.py +5 -0
- ads/data_labeling/parser/dls_record_parser.py +388 -0
- ads/data_labeling/parser/export_metadata_parser.py +94 -0
- ads/data_labeling/parser/export_record_parser.py +473 -0
- ads/data_labeling/reader/__init__.py +5 -0
- ads/data_labeling/reader/dataset_reader.py +574 -0
- ads/data_labeling/reader/dls_record_reader.py +121 -0
- ads/data_labeling/reader/export_record_reader.py +62 -0
- ads/data_labeling/reader/jsonl_reader.py +75 -0
- ads/data_labeling/reader/metadata_reader.py +203 -0
- ads/data_labeling/reader/record_reader.py +263 -0
- ads/data_labeling/record.py +52 -0
- ads/data_labeling/visualizer/__init__.py +5 -0
- ads/data_labeling/visualizer/image_visualizer.py +525 -0
- ads/data_labeling/visualizer/text_visualizer.py +357 -0
- ads/database/__init__.py +5 -0
- ads/database/connection.py +338 -0
- ads/dataset/__init__.py +10 -0
- ads/dataset/capabilities.md +51 -0
- ads/dataset/classification_dataset.py +339 -0
- ads/dataset/correlation.py +226 -0
- ads/dataset/correlation_plot.py +563 -0
- ads/dataset/dask_series.py +173 -0
- ads/dataset/dataframe_transformer.py +110 -0
- ads/dataset/dataset.py +1979 -0
- ads/dataset/dataset_browser.py +360 -0
- ads/dataset/dataset_with_target.py +995 -0
- ads/dataset/exception.py +25 -0
- ads/dataset/factory.py +987 -0
- ads/dataset/feature_engineering_transformer.py +35 -0
- ads/dataset/feature_selection.py +107 -0
- ads/dataset/forecasting_dataset.py +26 -0
- ads/dataset/helper.py +1450 -0
- ads/dataset/label_encoder.py +99 -0
- ads/dataset/mixin/__init__.py +5 -0
- ads/dataset/mixin/dataset_accessor.py +134 -0
- ads/dataset/pipeline.py +58 -0
- ads/dataset/plot.py +710 -0
- ads/dataset/progress.py +86 -0
- ads/dataset/recommendation.py +297 -0
- ads/dataset/recommendation_transformer.py +502 -0
- ads/dataset/regression_dataset.py +14 -0
- ads/dataset/sampled_dataset.py +1050 -0
- ads/dataset/target.py +98 -0
- ads/dataset/timeseries.py +18 -0
- ads/dbmixin/__init__.py +5 -0
- ads/dbmixin/db_pandas_accessor.py +153 -0
- ads/environment/__init__.py +9 -0
- ads/environment/ml_runtime.py +66 -0
- ads/evaluations/README.md +14 -0
- ads/evaluations/__init__.py +109 -0
- ads/evaluations/evaluation_plot.py +983 -0
- ads/evaluations/evaluator.py +1334 -0
- ads/evaluations/statistical_metrics.py +543 -0
- ads/experiments/__init__.py +9 -0
- ads/experiments/capabilities.md +0 -0
- ads/explanations/__init__.py +21 -0
- ads/explanations/base_explainer.py +142 -0
- ads/explanations/capabilities.md +83 -0
- ads/explanations/explainer.py +190 -0
- ads/explanations/mlx_global_explainer.py +1050 -0
- ads/explanations/mlx_interface.py +386 -0
- ads/explanations/mlx_local_explainer.py +287 -0
- ads/explanations/mlx_whatif_explainer.py +201 -0
- ads/feature_engineering/__init__.py +20 -0
- ads/feature_engineering/accessor/__init__.py +5 -0
- ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
- ads/feature_engineering/accessor/mixin/__init__.py +5 -0
- ads/feature_engineering/accessor/mixin/correlation.py +166 -0
- ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
- ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
- ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
- ads/feature_engineering/accessor/mixin/utils.py +65 -0
- ads/feature_engineering/accessor/series_accessor.py +431 -0
- ads/feature_engineering/adsimage/__init__.py +5 -0
- ads/feature_engineering/adsimage/image.py +192 -0
- ads/feature_engineering/adsimage/image_reader.py +170 -0
- ads/feature_engineering/adsimage/interface/__init__.py +5 -0
- ads/feature_engineering/adsimage/interface/reader.py +19 -0
- ads/feature_engineering/adsstring/__init__.py +7 -0
- ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
- ads/feature_engineering/adsstring/string/__init__.py +8 -0
- ads/feature_engineering/data_schema.json +57 -0
- ads/feature_engineering/dataset/__init__.py +5 -0
- ads/feature_engineering/dataset/zip_code_data.py +42062 -0
- ads/feature_engineering/exceptions.py +40 -0
- ads/feature_engineering/feature_type/__init__.py +133 -0
- ads/feature_engineering/feature_type/address.py +184 -0
- ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
- ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
- ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
- ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
- ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
- ads/feature_engineering/feature_type/adsstring/string.py +258 -0
- ads/feature_engineering/feature_type/base.py +58 -0
- ads/feature_engineering/feature_type/boolean.py +183 -0
- ads/feature_engineering/feature_type/category.py +146 -0
- ads/feature_engineering/feature_type/constant.py +137 -0
- ads/feature_engineering/feature_type/continuous.py +151 -0
- ads/feature_engineering/feature_type/creditcard.py +314 -0
- ads/feature_engineering/feature_type/datetime.py +190 -0
- ads/feature_engineering/feature_type/discrete.py +134 -0
- ads/feature_engineering/feature_type/document.py +43 -0
- ads/feature_engineering/feature_type/gis.py +251 -0
- ads/feature_engineering/feature_type/handler/__init__.py +5 -0
- ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
- ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
- ads/feature_engineering/feature_type/handler/warnings.py +128 -0
- ads/feature_engineering/feature_type/integer.py +142 -0
- ads/feature_engineering/feature_type/ip_address.py +144 -0
- ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
- ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
- ads/feature_engineering/feature_type/lat_long.py +256 -0
- ads/feature_engineering/feature_type/object.py +43 -0
- ads/feature_engineering/feature_type/ordinal.py +132 -0
- ads/feature_engineering/feature_type/phone_number.py +135 -0
- ads/feature_engineering/feature_type/string.py +171 -0
- ads/feature_engineering/feature_type/text.py +93 -0
- ads/feature_engineering/feature_type/unknown.py +43 -0
- ads/feature_engineering/feature_type/zip_code.py +164 -0
- ads/feature_engineering/feature_type_manager.py +406 -0
- ads/feature_engineering/schema.py +795 -0
- ads/feature_engineering/utils.py +245 -0
- ads/feature_store/.readthedocs.yaml +19 -0
- ads/feature_store/README.md +65 -0
- ads/feature_store/__init__.py +9 -0
- ads/feature_store/common/__init__.py +0 -0
- ads/feature_store/common/enums.py +339 -0
- ads/feature_store/common/exceptions.py +18 -0
- ads/feature_store/common/spark_session_singleton.py +125 -0
- ads/feature_store/common/utils/__init__.py +0 -0
- ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
- ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
- ads/feature_store/common/utils/transformation_utils.py +82 -0
- ads/feature_store/common/utils/utility.py +403 -0
- ads/feature_store/data_validation/__init__.py +0 -0
- ads/feature_store/data_validation/great_expectation.py +129 -0
- ads/feature_store/dataset.py +1230 -0
- ads/feature_store/dataset_job.py +530 -0
- ads/feature_store/docs/Dockerfile +7 -0
- ads/feature_store/docs/Makefile +44 -0
- ads/feature_store/docs/conf.py +28 -0
- ads/feature_store/docs/requirements.txt +14 -0
- ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
- ads/feature_store/docs/source/cicd.rst +137 -0
- ads/feature_store/docs/source/conf.py +86 -0
- ads/feature_store/docs/source/data_versioning.rst +33 -0
- ads/feature_store/docs/source/dataset.rst +388 -0
- ads/feature_store/docs/source/dataset_job.rst +27 -0
- ads/feature_store/docs/source/demo.rst +70 -0
- ads/feature_store/docs/source/entity.rst +78 -0
- ads/feature_store/docs/source/feature_group.rst +624 -0
- ads/feature_store/docs/source/feature_group_job.rst +29 -0
- ads/feature_store/docs/source/feature_store.rst +122 -0
- ads/feature_store/docs/source/feature_store_class.rst +123 -0
- ads/feature_store/docs/source/feature_validation.rst +66 -0
- ads/feature_store/docs/source/figures/cicd.png +0 -0
- ads/feature_store/docs/source/figures/data_validation.png +0 -0
- ads/feature_store/docs/source/figures/data_versioning.png +0 -0
- ads/feature_store/docs/source/figures/dataset.gif +0 -0
- ads/feature_store/docs/source/figures/dataset.png +0 -0
- ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
- ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
- ads/feature_store/docs/source/figures/entity.png +0 -0
- ads/feature_store/docs/source/figures/feature_group.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
- ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
- ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
- ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
- ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
- ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
- ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
- ads/feature_store/docs/source/figures/overview.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
- ads/feature_store/docs/source/figures/stats_1.png +0 -0
- ads/feature_store/docs/source/figures/stats_2.png +0 -0
- ads/feature_store/docs/source/figures/stats_d.png +0 -0
- ads/feature_store/docs/source/figures/stats_fg.png +0 -0
- ads/feature_store/docs/source/figures/transformation.png +0 -0
- ads/feature_store/docs/source/figures/transformations.gif +0 -0
- ads/feature_store/docs/source/figures/validation.png +0 -0
- ads/feature_store/docs/source/figures/validation_fg.png +0 -0
- ads/feature_store/docs/source/figures/validation_results.png +0 -0
- ads/feature_store/docs/source/figures/validation_summary.png +0 -0
- ads/feature_store/docs/source/index.rst +81 -0
- ads/feature_store/docs/source/module.rst +8 -0
- ads/feature_store/docs/source/notebook.rst +94 -0
- ads/feature_store/docs/source/overview.rst +47 -0
- ads/feature_store/docs/source/quickstart.rst +176 -0
- ads/feature_store/docs/source/release_notes.rst +194 -0
- ads/feature_store/docs/source/setup_feature_store.rst +81 -0
- ads/feature_store/docs/source/statistics.rst +58 -0
- ads/feature_store/docs/source/transformation.rst +199 -0
- ads/feature_store/docs/source/ui.rst +65 -0
- ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
- ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
- ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
- ads/feature_store/entity.py +718 -0
- ads/feature_store/execution_strategy/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
- ads/feature_store/execution_strategy/engine/__init__.py +0 -0
- ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
- ads/feature_store/execution_strategy/execution_strategy.py +113 -0
- ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
- ads/feature_store/execution_strategy/spark/__init__.py +0 -0
- ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
- ads/feature_store/feature.py +192 -0
- ads/feature_store/feature_group.py +1494 -0
- ads/feature_store/feature_group_expectation.py +346 -0
- ads/feature_store/feature_group_job.py +602 -0
- ads/feature_store/feature_lineage/__init__.py +0 -0
- ads/feature_store/feature_lineage/graphviz_service.py +180 -0
- ads/feature_store/feature_option_details.py +50 -0
- ads/feature_store/feature_statistics/__init__.py +0 -0
- ads/feature_store/feature_statistics/statistics_service.py +99 -0
- ads/feature_store/feature_store.py +699 -0
- ads/feature_store/feature_store_registrar.py +518 -0
- ads/feature_store/input_feature_detail.py +149 -0
- ads/feature_store/mixin/__init__.py +4 -0
- ads/feature_store/mixin/oci_feature_store.py +145 -0
- ads/feature_store/model_details.py +73 -0
- ads/feature_store/query/__init__.py +0 -0
- ads/feature_store/query/filter.py +266 -0
- ads/feature_store/query/generator/__init__.py +0 -0
- ads/feature_store/query/generator/query_generator.py +298 -0
- ads/feature_store/query/join.py +161 -0
- ads/feature_store/query/query.py +403 -0
- ads/feature_store/query/validator/__init__.py +0 -0
- ads/feature_store/query/validator/query_validator.py +57 -0
- ads/feature_store/response/__init__.py +0 -0
- ads/feature_store/response/response_builder.py +68 -0
- ads/feature_store/service/__init__.py +0 -0
- ads/feature_store/service/oci_dataset.py +139 -0
- ads/feature_store/service/oci_dataset_job.py +199 -0
- ads/feature_store/service/oci_entity.py +125 -0
- ads/feature_store/service/oci_feature_group.py +164 -0
- ads/feature_store/service/oci_feature_group_job.py +214 -0
- ads/feature_store/service/oci_feature_store.py +182 -0
- ads/feature_store/service/oci_lineage.py +87 -0
- ads/feature_store/service/oci_transformation.py +104 -0
- ads/feature_store/statistics/__init__.py +0 -0
- ads/feature_store/statistics/abs_feature_value.py +49 -0
- ads/feature_store/statistics/charts/__init__.py +0 -0
- ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
- ads/feature_store/statistics/charts/box_plot.py +148 -0
- ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
- ads/feature_store/statistics/charts/probability_distribution.py +68 -0
- ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
- ads/feature_store/statistics/feature_stat.py +126 -0
- ads/feature_store/statistics/generic_feature_value.py +33 -0
- ads/feature_store/statistics/statistics.py +41 -0
- ads/feature_store/statistics_config.py +101 -0
- ads/feature_store/templates/feature_store_template.yaml +45 -0
- ads/feature_store/transformation.py +499 -0
- ads/feature_store/validation_output.py +57 -0
- ads/hpo/__init__.py +9 -0
- ads/hpo/_imports.py +91 -0
- ads/hpo/ads_search_space.py +439 -0
- ads/hpo/distributions.py +325 -0
- ads/hpo/objective.py +280 -0
- ads/hpo/search_cv.py +1657 -0
- ads/hpo/stopping_criterion.py +75 -0
- ads/hpo/tuner_artifact.py +413 -0
- ads/hpo/utils.py +91 -0
- ads/hpo/validation.py +140 -0
- ads/hpo/visualization/__init__.py +5 -0
- ads/hpo/visualization/_contour.py +23 -0
- ads/hpo/visualization/_edf.py +20 -0
- ads/hpo/visualization/_intermediate_values.py +21 -0
- ads/hpo/visualization/_optimization_history.py +25 -0
- ads/hpo/visualization/_parallel_coordinate.py +169 -0
- ads/hpo/visualization/_param_importances.py +26 -0
- ads/jobs/__init__.py +53 -0
- ads/jobs/ads_job.py +663 -0
- ads/jobs/builders/__init__.py +5 -0
- ads/jobs/builders/base.py +156 -0
- ads/jobs/builders/infrastructure/__init__.py +6 -0
- ads/jobs/builders/infrastructure/base.py +165 -0
- ads/jobs/builders/infrastructure/dataflow.py +1252 -0
- ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
- ads/jobs/builders/infrastructure/utils.py +65 -0
- ads/jobs/builders/runtimes/__init__.py +5 -0
- ads/jobs/builders/runtimes/artifact.py +338 -0
- ads/jobs/builders/runtimes/base.py +325 -0
- ads/jobs/builders/runtimes/container_runtime.py +242 -0
- ads/jobs/builders/runtimes/python_runtime.py +1016 -0
- ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
- ads/jobs/cli.py +104 -0
- ads/jobs/env_var_parser.py +131 -0
- ads/jobs/extension.py +160 -0
- ads/jobs/schema/__init__.py +5 -0
- ads/jobs/schema/infrastructure_schema.json +116 -0
- ads/jobs/schema/job_schema.json +42 -0
- ads/jobs/schema/runtime_schema.json +183 -0
- ads/jobs/schema/validator.py +141 -0
- ads/jobs/serializer.py +296 -0
- ads/jobs/templates/__init__.py +5 -0
- ads/jobs/templates/container.py +6 -0
- ads/jobs/templates/driver_notebook.py +177 -0
- ads/jobs/templates/driver_oci.py +500 -0
- ads/jobs/templates/driver_python.py +48 -0
- ads/jobs/templates/driver_pytorch.py +852 -0
- ads/jobs/templates/driver_utils.py +615 -0
- ads/jobs/templates/hostname_from_env.c +55 -0
- ads/jobs/templates/oci_metrics.py +181 -0
- ads/jobs/utils.py +104 -0
- ads/llm/__init__.py +28 -0
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/v02/client.py +295 -0
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/chain.py +268 -0
- ads/llm/chat_template.py +31 -0
- ads/llm/deploy.py +63 -0
- ads/llm/guardrails/__init__.py +5 -0
- ads/llm/guardrails/base.py +442 -0
- ads/llm/guardrails/huggingface.py +44 -0
- ads/llm/langchain/__init__.py +5 -0
- ads/llm/langchain/plugins/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/llm/langchain/plugins/llms/__init__.py +5 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
- ads/llm/requirements.txt +3 -0
- ads/llm/serialize.py +219 -0
- ads/llm/serializers/__init__.py +0 -0
- ads/llm/serializers/retrieval_qa.py +153 -0
- ads/llm/serializers/runnable_parallel.py +27 -0
- ads/llm/templates/score_chain.jinja2 +155 -0
- ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
- ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
- ads/model/__init__.py +52 -0
- ads/model/artifact.py +573 -0
- ads/model/artifact_downloader.py +254 -0
- ads/model/artifact_uploader.py +267 -0
- ads/model/base_properties.py +238 -0
- ads/model/common/.model-ignore +66 -0
- ads/model/common/__init__.py +5 -0
- ads/model/common/utils.py +142 -0
- ads/model/datascience_model.py +2635 -0
- ads/model/deployment/__init__.py +20 -0
- ads/model/deployment/common/__init__.py +5 -0
- ads/model/deployment/common/utils.py +308 -0
- ads/model/deployment/model_deployer.py +466 -0
- ads/model/deployment/model_deployment.py +1846 -0
- ads/model/deployment/model_deployment_infrastructure.py +671 -0
- ads/model/deployment/model_deployment_properties.py +493 -0
- ads/model/deployment/model_deployment_runtime.py +838 -0
- ads/model/extractor/__init__.py +5 -0
- ads/model/extractor/automl_extractor.py +74 -0
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/extractor/huggingface_extractor.py +88 -0
- ads/model/extractor/keras_extractor.py +84 -0
- ads/model/extractor/lightgbm_extractor.py +93 -0
- ads/model/extractor/model_info_extractor.py +114 -0
- ads/model/extractor/model_info_extractor_factory.py +105 -0
- ads/model/extractor/pytorch_extractor.py +87 -0
- ads/model/extractor/sklearn_extractor.py +112 -0
- ads/model/extractor/spark_extractor.py +89 -0
- ads/model/extractor/tensorflow_extractor.py +85 -0
- ads/model/extractor/xgboost_extractor.py +94 -0
- ads/model/framework/__init__.py +5 -0
- ads/model/framework/automl_model.py +178 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/framework/huggingface_model.py +399 -0
- ads/model/framework/lightgbm_model.py +266 -0
- ads/model/framework/pytorch_model.py +266 -0
- ads/model/framework/sklearn_model.py +250 -0
- ads/model/framework/spark_model.py +326 -0
- ads/model/framework/tensorflow_model.py +254 -0
- ads/model/framework/xgboost_model.py +258 -0
- ads/model/generic_model.py +3518 -0
- ads/model/model_artifact_boilerplate/README.md +381 -0
- ads/model/model_artifact_boilerplate/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
- ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
- ads/model/model_artifact_boilerplate/score.py +61 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_introspect.py +331 -0
- ads/model/model_metadata.py +1810 -0
- ads/model/model_metadata_mixin.py +460 -0
- ads/model/model_properties.py +63 -0
- ads/model/model_version_set.py +739 -0
- ads/model/runtime/__init__.py +5 -0
- ads/model/runtime/env_info.py +306 -0
- ads/model/runtime/model_deployment_details.py +37 -0
- ads/model/runtime/model_provenance_details.py +58 -0
- ads/model/runtime/runtime_info.py +81 -0
- ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
- ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
- ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
- ads/model/runtime/utils.py +201 -0
- ads/model/serde/__init__.py +5 -0
- ads/model/serde/common.py +40 -0
- ads/model/serde/model_input.py +547 -0
- ads/model/serde/model_serializer.py +1184 -0
- ads/model/service/__init__.py +5 -0
- ads/model/service/oci_datascience_model.py +1076 -0
- ads/model/service/oci_datascience_model_deployment.py +500 -0
- ads/model/service/oci_datascience_model_version_set.py +176 -0
- ads/model/transformer/__init__.py +5 -0
- ads/model/transformer/onnx_transformer.py +324 -0
- ads/mysqldb/__init__.py +5 -0
- ads/mysqldb/mysql_db.py +227 -0
- ads/opctl/__init__.py +18 -0
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/__init__.py +5 -0
- ads/opctl/backend/ads_dataflow.py +353 -0
- ads/opctl/backend/ads_ml_job.py +710 -0
- ads/opctl/backend/ads_ml_pipeline.py +164 -0
- ads/opctl/backend/ads_model_deployment.py +209 -0
- ads/opctl/backend/base.py +146 -0
- ads/opctl/backend/local.py +1053 -0
- ads/opctl/backend/marketplace/__init__.py +9 -0
- ads/opctl/backend/marketplace/helm_helper.py +173 -0
- ads/opctl/backend/marketplace/local_marketplace.py +271 -0
- ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
- ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
- ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
- ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
- ads/opctl/backend/marketplace/models/__init__.py +5 -0
- ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
- ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
- ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
- ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
- ads/opctl/cli.py +707 -0
- ads/opctl/cmds.py +869 -0
- ads/opctl/conda/__init__.py +5 -0
- ads/opctl/conda/cli.py +193 -0
- ads/opctl/conda/cmds.py +749 -0
- ads/opctl/conda/config.yaml +34 -0
- ads/opctl/conda/manifest_template.yaml +13 -0
- ads/opctl/conda/multipart_uploader.py +188 -0
- ads/opctl/conda/pack.py +89 -0
- ads/opctl/config/__init__.py +5 -0
- ads/opctl/config/base.py +57 -0
- ads/opctl/config/diagnostics/__init__.py +5 -0
- ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
- ads/opctl/config/merger.py +255 -0
- ads/opctl/config/resolver.py +297 -0
- ads/opctl/config/utils.py +79 -0
- ads/opctl/config/validator.py +17 -0
- ads/opctl/config/versioner.py +68 -0
- ads/opctl/config/yaml_parsers/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/base.py +58 -0
- ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
- ads/opctl/constants.py +66 -0
- ads/opctl/decorator/__init__.py +5 -0
- ads/opctl/decorator/common.py +129 -0
- ads/opctl/diagnostics/__init__.py +5 -0
- ads/opctl/diagnostics/__main__.py +25 -0
- ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
- ads/opctl/diagnostics/check_requirements.py +144 -0
- ads/opctl/diagnostics/requirement_exception.py +9 -0
- ads/opctl/distributed/README.md +109 -0
- ads/opctl/distributed/__init__.py +5 -0
- ads/opctl/distributed/certificates.py +32 -0
- ads/opctl/distributed/cli.py +207 -0
- ads/opctl/distributed/cmds.py +731 -0
- ads/opctl/distributed/common/__init__.py +5 -0
- ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
- ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
- ads/opctl/distributed/common/cluster_config_helper.py +103 -0
- ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
- ads/opctl/distributed/common/cluster_runner.py +54 -0
- ads/opctl/distributed/common/framework_factory.py +29 -0
- ads/opctl/docker/Dockerfile.job +103 -0
- ads/opctl/docker/Dockerfile.job.arm +107 -0
- ads/opctl/docker/Dockerfile.job.gpu +175 -0
- ads/opctl/docker/base-env.yaml +13 -0
- ads/opctl/docker/cuda.repo +6 -0
- ads/opctl/docker/operator/.dockerignore +0 -0
- ads/opctl/docker/operator/Dockerfile +41 -0
- ads/opctl/docker/operator/Dockerfile.gpu +85 -0
- ads/opctl/docker/operator/cuda.repo +6 -0
- ads/opctl/docker/operator/environment.yaml +8 -0
- ads/opctl/forecast.py +11 -0
- ads/opctl/index.yaml +3 -0
- ads/opctl/model/__init__.py +5 -0
- ads/opctl/model/cli.py +65 -0
- ads/opctl/model/cmds.py +73 -0
- ads/opctl/operator/README.md +4 -0
- ads/opctl/operator/__init__.py +31 -0
- ads/opctl/operator/cli.py +344 -0
- ads/opctl/operator/cmd.py +596 -0
- ads/opctl/operator/common/__init__.py +5 -0
- ads/opctl/operator/common/backend_factory.py +460 -0
- ads/opctl/operator/common/const.py +27 -0
- ads/opctl/operator/common/data/synthetic.csv +16001 -0
- ads/opctl/operator/common/dictionary_merger.py +148 -0
- ads/opctl/operator/common/errors.py +42 -0
- ads/opctl/operator/common/operator_config.py +99 -0
- ads/opctl/operator/common/operator_loader.py +811 -0
- ads/opctl/operator/common/operator_schema.yaml +130 -0
- ads/opctl/operator/common/operator_yaml_generator.py +152 -0
- ads/opctl/operator/common/utils.py +208 -0
- ads/opctl/operator/lowcode/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
- ads/opctl/operator/lowcode/anomaly/README.md +207 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +167 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
- ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +116 -0
- ads/opctl/operator/lowcode/common/errors.py +47 -0
- ads/opctl/operator/lowcode/common/transformations.py +296 -0
- ads/opctl/operator/lowcode/common/utils.py +384 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
- ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
- ads/opctl/operator/lowcode/forecast/README.md +209 -0
- ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
- ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
- ads/opctl/operator/lowcode/forecast/const.py +92 -0
- ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
- ads/opctl/operator/lowcode/forecast/errors.py +26 -0
- ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
- ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
- ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
- ads/opctl/operator/lowcode/forecast/model/prophet.py +445 -0
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
- ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
- ads/opctl/operator/lowcode/forecast/utils.py +397 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
- ads/opctl/operator/lowcode/pii/MLoperator +17 -0
- ads/opctl/operator/lowcode/pii/README.md +208 -0
- ads/opctl/operator/lowcode/pii/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/__main__.py +78 -0
- ads/opctl/operator/lowcode/pii/cmd.py +39 -0
- ads/opctl/operator/lowcode/pii/constant.py +84 -0
- ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
- ads/opctl/operator/lowcode/pii/errors.py +27 -0
- ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
- ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
- ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
- ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
- ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
- ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
- ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
- ads/opctl/operator/lowcode/pii/model/report.py +487 -0
- ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
- ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
- ads/opctl/operator/lowcode/pii/utils.py +43 -0
- ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
- ads/opctl/operator/lowcode/recommender/README.md +206 -0
- ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
- ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
- ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
- ads/opctl/operator/lowcode/recommender/constant.py +30 -0
- ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
- ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
- ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
- ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
- ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
- ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
- ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
- ads/opctl/operator/lowcode/recommender/utils.py +13 -0
- ads/opctl/operator/runtime/__init__.py +5 -0
- ads/opctl/operator/runtime/const.py +17 -0
- ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
- ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
- ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/runtime.py +115 -0
- ads/opctl/schema.yaml.yml +36 -0
- ads/opctl/script.py +40 -0
- ads/opctl/spark/__init__.py +5 -0
- ads/opctl/spark/cli.py +43 -0
- ads/opctl/spark/cmds.py +147 -0
- ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
- ads/opctl/utils.py +344 -0
- ads/oracledb/__init__.py +5 -0
- ads/oracledb/oracle_db.py +346 -0
- ads/pipeline/__init__.py +39 -0
- ads/pipeline/ads_pipeline.py +2279 -0
- ads/pipeline/ads_pipeline_run.py +772 -0
- ads/pipeline/ads_pipeline_step.py +605 -0
- ads/pipeline/builders/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/custom_script.py +32 -0
- ads/pipeline/cli.py +119 -0
- ads/pipeline/extension.py +291 -0
- ads/pipeline/schema/__init__.py +5 -0
- ads/pipeline/schema/cs_step_schema.json +35 -0
- ads/pipeline/schema/ml_step_schema.json +31 -0
- ads/pipeline/schema/pipeline_schema.json +71 -0
- ads/pipeline/visualizer/__init__.py +5 -0
- ads/pipeline/visualizer/base.py +570 -0
- ads/pipeline/visualizer/graph_renderer.py +272 -0
- ads/pipeline/visualizer/text_renderer.py +84 -0
- ads/secrets/__init__.py +11 -0
- ads/secrets/adb.py +386 -0
- ads/secrets/auth_token.py +86 -0
- ads/secrets/big_data_service.py +365 -0
- ads/secrets/mysqldb.py +149 -0
- ads/secrets/oracledb.py +160 -0
- ads/secrets/secrets.py +407 -0
- ads/telemetry/__init__.py +7 -0
- ads/telemetry/base.py +69 -0
- ads/telemetry/client.py +125 -0
- ads/telemetry/telemetry.py +257 -0
- ads/templates/dataflow_pyspark.jinja2 +13 -0
- ads/templates/dataflow_sparksql.jinja2 +22 -0
- ads/templates/func.jinja2 +20 -0
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score-pkl.jinja2 +173 -0
- ads/templates/score.jinja2 +322 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- ads/templates/score_generic.jinja2 +165 -0
- ads/templates/score_huggingface_pipeline.jinja2 +217 -0
- ads/templates/score_lightgbm.jinja2 +185 -0
- ads/templates/score_onnx.jinja2 +407 -0
- ads/templates/score_onnx_new.jinja2 +473 -0
- ads/templates/score_oracle_automl.jinja2 +185 -0
- ads/templates/score_pyspark.jinja2 +154 -0
- ads/templates/score_pytorch.jinja2 +219 -0
- ads/templates/score_scikit-learn.jinja2 +184 -0
- ads/templates/score_tensorflow.jinja2 +184 -0
- ads/templates/score_xgboost.jinja2 +178 -0
- ads/text_dataset/__init__.py +5 -0
- ads/text_dataset/backends.py +211 -0
- ads/text_dataset/dataset.py +445 -0
- ads/text_dataset/extractor.py +207 -0
- ads/text_dataset/options.py +53 -0
- ads/text_dataset/udfs.py +22 -0
- ads/text_dataset/utils.py +49 -0
- ads/type_discovery/__init__.py +9 -0
- ads/type_discovery/abstract_detector.py +21 -0
- ads/type_discovery/constant_detector.py +41 -0
- ads/type_discovery/continuous_detector.py +54 -0
- ads/type_discovery/credit_card_detector.py +99 -0
- ads/type_discovery/datetime_detector.py +92 -0
- ads/type_discovery/discrete_detector.py +118 -0
- ads/type_discovery/document_detector.py +146 -0
- ads/type_discovery/ip_detector.py +68 -0
- ads/type_discovery/latlon_detector.py +90 -0
- ads/type_discovery/phone_number_detector.py +63 -0
- ads/type_discovery/type_discovery_driver.py +87 -0
- ads/type_discovery/typed_feature.py +594 -0
- ads/type_discovery/unknown_detector.py +41 -0
- ads/type_discovery/zipcode_detector.py +48 -0
- ads/vault/__init__.py +7 -0
- ads/vault/vault.py +237 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/METADATA +150 -150
- oracle_ads-2.13.9rc1.dist-info/RECORD +858 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/WHEEL +1 -2
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/entry_points.txt +2 -1
- oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
- oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,266 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8; -*-
|
3
|
+
|
4
|
+
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
|
5
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
|
7
|
+
|
8
|
+
from typing import Dict, List, Optional, Tuple, Union
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
import pandas as pd
|
12
|
+
from ads.common import logger
|
13
|
+
from ads.model.extractor.pytorch_extractor import PytorchExtractor
|
14
|
+
from ads.common.decorator.runtime_dependency import (
|
15
|
+
runtime_dependency,
|
16
|
+
OptionalDependency,
|
17
|
+
)
|
18
|
+
from ads.model.generic_model import FrameworkSpecificModel
|
19
|
+
from ads.model.model_properties import ModelProperties
|
20
|
+
from ads.model.serde.model_serializer import PyTorchModelSerializerType
|
21
|
+
from ads.model.common.utils import (
|
22
|
+
DEPRECATE_AS_ONNX_WARNING,
|
23
|
+
DEPRECATE_USE_TORCH_SCRIPT_WARNING,
|
24
|
+
)
|
25
|
+
from ads.model.serde.common import SERDE
|
26
|
+
|
27
|
+
# Default file name for a model serialized to ONNX format.
ONNX_MODEL_FILE_NAME = "model.onnx"
# Default file name for a model serialized with PyTorch.
PYTORCH_MODEL_FILE_NAME = "model.pt"
|
29
|
+
|
30
|
+
|
31
|
+
class PyTorchModel(FrameworkSpecificModel):
    """PyTorchModel class for estimators from Pytorch framework.

    Attributes
    ----------
    algorithm: str
        The algorithm of the model.
    artifact_dir: str
        Artifact directory to store the files needed for deployment.
    auth: Dict
        Default authentication is set using the `ads.set_auth` API. To override the
        default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create
        an authentication signer to instantiate an IdentityClient object.
    estimator: Callable
        A trained pytorch estimator/model using Pytorch.
    framework: str
        "pytorch", the framework name of the model.
    hyperparameter: dict
        The hyperparameters of the estimator.
    metadata_custom: ModelCustomMetadata
        The model custom metadata.
    metadata_provenance: ModelProvenanceMetadata
        The model provenance metadata.
    metadata_taxonomy: ModelTaxonomyMetadata
        The model taxonomy metadata.
    model_artifact: ModelArtifact
        This is built by calling prepare.
    model_deployment: ModelDeployment
        A ModelDeployment instance.
    model_file_name: str
        Name of the serialized model.
    model_id: str
        The model ID.
    properties: ModelProperties
        ModelProperties object required to save and deploy model.
        For more details, check https://accelerated-data-science.readthedocs.io/en/latest/ads.model.html#module-ads.model.model_properties.
    runtime_info: RuntimeInfo
        A RuntimeInfo instance.
    schema_input: Schema
        Schema describes the structure of the input data.
    schema_output: Schema
        Schema describes the structure of the output data.
    serialize: bool
        Whether to serialize the model by default. If False, you need to serialize the model manually,
        save it under artifact_dir and update the score.py manually.
    version: str
        The framework version of the model (the installed `torch.__version__`).

    Methods
    -------
    delete_deployment(...)
        Deletes the current model deployment.
    deploy(..., **kwargs)
        Deploys a model.
    from_model_artifact(uri, model_file_name, artifact_dir, ..., **kwargs)
        Loads model from the specified folder, or zip/tar archive.
    from_model_catalog(model_id, model_file_name, artifact_dir, ..., **kwargs)
        Loads model from model catalog.
    introspect(...)
        Runs model introspection.
    predict(data, ...)
        Returns prediction of input data run against the model deployment endpoint.
    prepare(..., **kwargs)
        Prepare and save the score.py, serialized model and runtime.yaml file.
    reload(...)
        Reloads the model artifact files: `score.py` and the `runtime.yaml`.
    save(..., **kwargs)
        Saves model artifacts to the model catalog.
    summary_status(...)
        Gets a summary table of the current status.
    verify(data, ...)
        Tests if deployment works in local environment.

    Examples
    --------
    >>> import tempfile
    >>> tmp_model_dir = tempfile.mkdtemp()
    >>> torch_model = PyTorchModel(estimator=torch_estimator,
    ... artifact_dir=tmp_model_dir)
    >>> inference_conda_env = "generalml_p37_cpu_v1"

    >>> torch_model.prepare(inference_conda_env=inference_conda_env, force_overwrite=True)
    >>> torch_model.reload()
    >>> torch_model.verify(...)
    >>> torch_model.save()
    >>> model_deployment = torch_model.deploy(wait_for_completion=False)
    >>> torch_model.predict(...)
    """

    _PREFIX = "pytorch"
    model_save_serializer_type = PyTorchModelSerializerType

    @runtime_dependency(module="torch", install_from=OptionalDependency.PYTORCH)
    def __init__(
        self,
        estimator: callable,
        artifact_dir: Optional[str] = None,
        properties: Optional[ModelProperties] = None,
        auth: Dict = None,
        model_save_serializer: Optional[SERDE] = model_save_serializer_type.TORCH,
        model_input_serializer: Optional[SERDE] = None,
        **kwargs,
    ):
        """
        Initiates a PyTorchModel instance.

        Parameters
        ----------
        estimator: callable
            Any model object generated by pytorch framework
        artifact_dir: str
            artifact directory to store the files needed for deployment.
        properties: (ModelProperties, optional). Defaults to None.
            ModelProperties object required to save and deploy model.
        auth :(Dict, optional). Defaults to None.
            The default authentication is set using `ads.set_auth` API. If you need to override the
            default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
            authentication signer and kwargs required to instantiate IdentityClient object.
        model_save_serializer: (SERDE or str, optional). Defaults to `PyTorchModelSerializerType.TORCH`.
            Instance of ads.model.SERDE. Used for serialize/deserialize model.
        model_input_serializer: (SERDE, optional). Defaults to None.
            Instance of ads.model.SERDE. Used for serialize/deserialize data.

        Returns
        -------
        PyTorchModel
            PyTorchModel instance.
        """
        super().__init__(
            estimator=estimator,
            artifact_dir=artifact_dir,
            properties=properties,
            auth=auth,
            model_save_serializer=model_save_serializer,
            model_input_serializer=model_input_serializer,
            **kwargs,
        )
        self._extractor = PytorchExtractor(estimator)
        self.framework = self._extractor.framework
        self.algorithm = self._extractor.algorithm
        self.hyperparameter = self._extractor.hyperparameter
        # `torch` is not imported at module level; it is expected to be made
        # available by the `runtime_dependency` decorator on this method.
        # NOTE(review): confirm against `runtime_dependency`'s implementation.
        self.version = torch.__version__

    def serialize_model(
        self,
        as_onnx: bool = False,
        force_overwrite: bool = False,
        X_sample: Optional[
            Union[
                Dict,
                str,
                List,
                Tuple,
                np.ndarray,
                pd.core.series.Series,
                pd.core.frame.DataFrame,
            ]
        ] = None,
        use_torch_script: Optional[bool] = None,
        **kwargs,
    ) -> None:
        """
        Serialize and save Pytorch model using ONNX or model specific method.

        Parameters
        ----------
        as_onnx: (bool, optional). Defaults to False.
            If set as True, convert into ONNX model.
        force_overwrite: (bool, optional). Defaults to False.
            If set as True, overwrite serialized model if exists.
        X_sample: Union[Dict, str, List, Tuple, np.ndarray, pd.Series, pd.DataFrame]. Defaults to None.
            A sample of input data that will be used to generate input schema and detect onnx_args.
        use_torch_script: (bool, optional). Defaults to None (If the default value has not been changed, it will be set as `False`).
            If set as `True`, the model will be serialized as a TorchScript program. Check https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format for more details.
            If set as `False`, it will only save the trained model's learned parameters, and the score.py
            need to be modified to construct the model class instance first. Check https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-state-dict-recommended for more details.
        **kwargs: optional params used to serialize pytorch model to onnx,
        including the following:
            onnx_args: (tuple or torch.Tensor), default to None
            Contains model inputs such that model(onnx_args) is a valid
            invocation of the model. Can be structured either as: 1) ONLY A
            TUPLE OF ARGUMENTS; 2) A TENSOR; 3) A TUPLE OF ARGUMENTS ENDING
            WITH A DICTIONARY OF NAMED ARGUMENTS
            input_names: (List[str], optional). Names to assign to the input
            nodes of the graph, in order.
            output_names: (List[str], optional). Names to assign to the output nodes of the graph, in order.
            dynamic_axes: (dict, optional), default to None. Specify axes of tensors as dynamic (i.e. known only at run-time).

        Returns
        -------
        None
            Nothing.

        Raises
        ------
        ValueError
            If both `as_onnx` and `use_torch_script` are True.
        """
        if use_torch_script is None:
            # Each literal ends with a space so the concatenated message keeps
            # its sentences separated.
            logger.warning(
                "In future the models will be saved in TorchScript format by default. Currently saving it using torch.save method. "
                "Set `use_torch_script` as `True` to serialize the model as a TorchScript program by `torch.jit.save()` "
                "and loaded using `torch.jit.load()` in score.py. "
                "You don't need to modify `load_model()` in score.py to load the model. "
                "Check https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format for more details. "
                "Set `use_torch_script` as `False` to save only the model parameters. "
                "The model class instance must be constructed before "
                "loading parameters in the predict function of score.py. "
                "Check https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-state-dict-recommended for more details."
            )
            use_torch_script = False

        if as_onnx and use_torch_script:
            raise ValueError("You can only save Pytorch model into one format.")

        if as_onnx:
            logger.warning(DEPRECATE_AS_ONNX_WARNING)
            self.set_model_save_serializer(self.model_save_serializer_type.ONNX)

        if use_torch_script:
            logger.warning(DEPRECATE_USE_TORCH_SCRIPT_WARNING)
            self.set_model_save_serializer(self.model_save_serializer_type.TORCHSCRIPT)

        super().serialize_model(
            as_onnx=as_onnx,
            force_overwrite=force_overwrite,
            X_sample=X_sample,
            **kwargs,
        )

    def _to_tensor(self, data):
        """Convert `data` with `torchvision.transforms.ToTensor`.

        Parameters
        ----------
        data:
            The input accepted by `torchvision.transforms.ToTensor`.

        Returns
        -------
        The converted value returned by `ToTensor`.

        Raises
        ------
        ModuleNotFoundError
            If `torchvision` is not installed.
        """
        # Only the import is guarded, so errors raised by the conversion itself
        # are not mislabelled as a missing `torchvision` package.
        try:
            import torchvision.transforms as transforms
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "The `torchvision` module was not found. Please run "
                f"`pip install {OptionalDependency.PYTORCH}`."
            ) from err

        convert_tensor = transforms.ToTensor()
        return convert_tensor(data)
|
@@ -0,0 +1,250 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*--
|
3
|
+
|
4
|
+
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
|
5
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
import pandas as pd
|
10
|
+
from ads.common import logger
|
11
|
+
from ads.model.extractor.sklearn_extractor import SklearnExtractor
|
12
|
+
from ads.model.generic_model import FrameworkSpecificModel
|
13
|
+
from ads.model.model_properties import ModelProperties
|
14
|
+
from ads.model.serde.model_serializer import SklearnModelSerializerType
|
15
|
+
from ads.model.common.utils import DEPRECATE_AS_ONNX_WARNING
|
16
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
17
|
+
from ads.model.serde.common import SERDE
|
18
|
+
|
19
|
+
|
20
|
+
class SklearnModel(FrameworkSpecificModel):
    """SklearnModel class for estimators from sklearn framework.

    Attributes
    ----------
    algorithm: str
        The algorithm of the model.
    artifact_dir: str
        Artifact directory to store the files needed for deployment.
    auth: Dict
        Default authentication is set using the `ads.set_auth` API. To override the
        default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create
        an authentication signer to instantiate an IdentityClient object.
    estimator: Callable
        A trained sklearn estimator/model using scikit-learn.
    framework: str
        "scikit-learn", the framework name of the model.
    hyperparameter: dict
        The hyperparameters of the estimator.
    metadata_custom: ModelCustomMetadata
        The model custom metadata.
    metadata_provenance: ModelProvenanceMetadata
        The model provenance metadata.
    metadata_taxonomy: ModelTaxonomyMetadata
        The model taxonomy metadata.
    model_artifact: ModelArtifact
        This is built by calling prepare.
    model_deployment: ModelDeployment
        A ModelDeployment instance.
    model_file_name: str
        Name of the serialized model.
    model_id: str
        The model ID.
    properties: ModelProperties
        ModelProperties object required to save and deploy model.
        For more details, check https://accelerated-data-science.readthedocs.io/en/latest/ads.model.html#module-ads.model.model_properties.
    runtime_info: RuntimeInfo
        A RuntimeInfo instance.
    schema_input: Schema
        Schema describes the structure of the input data.
    schema_output: Schema
        Schema describes the structure of the output data.
    serialize: bool
        Whether to serialize the model by default. If False, you need to serialize the model manually,
        save it under artifact_dir and update the score.py manually.
    version: str
        The framework version of the model.

    Methods
    -------
    delete_deployment(...)
        Deletes the current model deployment.
    deploy(..., **kwargs)
        Deploys a model.
    from_model_artifact(uri, model_file_name, artifact_dir, ..., **kwargs)
        Loads model from the specified folder, or zip/tar archive.
    from_model_catalog(model_id, model_file_name, artifact_dir, ..., **kwargs)
        Loads model from model catalog.
    introspect(...)
        Runs model introspection.
    predict(data, ...)
        Returns prediction of input data run against the model deployment endpoint.
    prepare(..., **kwargs)
        Prepare and save the score.py, serialized model and runtime.yaml file.
    reload(...)
        Reloads the model artifact files: `score.py` and the `runtime.yaml`.
    save(..., **kwargs)
        Saves model artifacts to the model catalog.
    summary_status(...)
        Gets a summary table of the current status.
    verify(data, ...)
        Tests if deployment works in local environment.

    Examples
    --------
    >>> import tempfile
    >>> from sklearn.model_selection import train_test_split
    >>> from ads.model.framework.sklearn_model import SklearnModel
    >>> from sklearn.linear_model import LogisticRegression
    >>> from sklearn.datasets import load_iris

    >>> iris = load_iris()
    >>> X, y = iris.data, iris.target
    >>> X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)
    >>> sklearn_estimator = LogisticRegression()
    >>> sklearn_estimator.fit(X_train, y_train)

    >>> tmp_model_dir = tempfile.mkdtemp()
    >>> sklearn_model = SklearnModel(estimator=sklearn_estimator,
    ... artifact_dir=tmp_model_dir)

    >>> sklearn_model.prepare(inference_conda_env="generalml_p37_cpu_v1", force_overwrite=True)
    >>> sklearn_model.reload()
    >>> sklearn_model.verify(X_test)
    >>> sklearn_model.save()
    >>> model_deployment = sklearn_model.deploy(wait_for_completion=False)
    >>> sklearn_model.predict(X_test)
    """

    _PREFIX = "sklearn"
    model_save_serializer_type = SklearnModelSerializerType

    def __init__(
        self,
        estimator: Callable,
        artifact_dir: Optional[str] = None,
        properties: Optional[ModelProperties] = None,
        auth: Dict = None,
        model_save_serializer: Optional[SERDE] = model_save_serializer_type.JOBLIB,
        model_input_serializer: Optional[SERDE] = None,
        **kwargs,
    ):
        """
        Initiates a SklearnModel instance.

        Parameters
        ----------
        estimator: Callable
            Sklearn Model
        artifact_dir: str
            Directory for generating artifacts.
        properties: (ModelProperties, optional). Defaults to None.
            ModelProperties object required to save and deploy model.
        auth :(Dict, optional). Defaults to None.
            The default authentication is set using `ads.set_auth` API. If you need to override the
            default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
            authentication signer and kwargs required to instantiate IdentityClient object.
        model_save_serializer: (SERDE or str, optional). Defaults to `SklearnModelSerializerType.JOBLIB`.
            Instance of ads.model.SERDE. Used for serialize/deserialize model.
        model_input_serializer: (SERDE, optional). Defaults to None.
            Instance of ads.model.SERDE. Used for serialize/deserialize data.

        Returns
        -------
        SklearnModel
            SklearnModel instance.

        Raises
        ------
        TypeError
            If the estimator's class is not from the `sklearn` or `onnxruntime`
            packages and `ignore_conda_error` is present and falsy at the time
            of the check.

        Examples
        --------
        >>> import tempfile
        >>> from sklearn.model_selection import train_test_split
        >>> from ads.model.framework.sklearn_model import SklearnModel
        >>> from sklearn.linear_model import LogisticRegression
        >>> from sklearn.datasets import load_iris

        >>> iris = load_iris()
        >>> X, y = iris.data, iris.target
        >>> X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)
        >>> sklearn_estimator = LogisticRegression()
        >>> sklearn_estimator.fit(X_train, y_train)

        >>> sklearn_model = SklearnModel(estimator=sklearn_estimator, artifact_dir=tempfile.mkdtemp())
        >>> sklearn_model.prepare(inference_conda_env="dataexpl_p37_cpu_v3")
        >>> sklearn_model.verify(X_test)
        >>> sklearn_model.save()
        >>> model_deployment = sklearn_model.deploy()
        >>> sklearn_model.predict(X_test)
        >>> sklearn_model.delete_deployment()
        """
        estimator_type = str(type(estimator))
        is_supported = estimator_type.startswith(
            ("<class 'sklearn.", "<class 'onnxruntime.")
        )
        # NOTE(review): this runs before `super().__init__()`, so the check only
        # raises when `ignore_conda_error` already exists on the instance/class
        # at this point. Behavior intentionally kept; confirm it is the intent.
        if (
            not is_supported
            and hasattr(self, "ignore_conda_error")
            and not self.ignore_conda_error
        ):
            raise TypeError(f"{estimator_type} is not supported in SklearnModel.")
        super().__init__(
            estimator=estimator,
            artifact_dir=artifact_dir,
            properties=properties,
            auth=auth,
            model_save_serializer=model_save_serializer,
            model_input_serializer=model_input_serializer,
            **kwargs,
        )
        self._extractor = SklearnExtractor(estimator)
        self.framework = self._extractor.framework
        self.algorithm = self._extractor.algorithm
        self.version = self._extractor.version
        self.hyperparameter = self._extractor.hyperparameter

    def serialize_model(
        self,
        as_onnx: Optional[bool] = False,
        initial_types: Optional[List[Tuple]] = None,
        force_overwrite: Optional[bool] = False,
        X_sample: Optional[
            Union[
                Dict,
                str,
                List,
                Tuple,
                np.ndarray,
                pd.core.series.Series,
                pd.core.frame.DataFrame,
            ]
        ] = None,
        **kwargs: Any,
    ):
        """
        Serialize and save scikit-learn model using ONNX or model specific method.

        Parameters
        ----------
        as_onnx: (bool, optional). Defaults to False.
            If set as True, provide initial_types or X_sample to convert into ONNX.
        initial_types: (List[Tuple], optional). Defaults to None.
            Each element is a tuple of a variable name and a type.
        force_overwrite: (bool, optional). Defaults to False.
            If set as True, overwrite serialized model if exists.
        X_sample: Union[Dict, str, List, Tuple, np.ndarray, pd.core.series.Series, pd.core.frame.DataFrame]. Defaults to None.
            Contains model inputs such that model(X_sample) is a valid invocation of the model.
            Used to generate initial_types.
        **kwargs:
            Additional keyword arguments forwarded unchanged to the parent
            `serialize_model`.

        Returns
        -------
        None
            Nothing.
        """
        if as_onnx:
            logger.warning(DEPRECATE_AS_ONNX_WARNING)
            self.set_model_save_serializer(self.model_save_serializer_type.ONNX)

        super().serialize_model(
            as_onnx=as_onnx,
            initial_types=initial_types,
            force_overwrite=force_overwrite,
            X_sample=X_sample,
            **kwargs,
        )
|