oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.10rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +40 -0
- ads/aqua/app.py +507 -0
- ads/aqua/cli.py +96 -0
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +836 -0
- ads/aqua/client/openai_client.py +305 -0
- ads/aqua/common/__init__.py +5 -0
- ads/aqua/common/decorator.py +125 -0
- ads/aqua/common/entities.py +274 -0
- ads/aqua/common/enums.py +134 -0
- ads/aqua/common/errors.py +109 -0
- ads/aqua/common/utils.py +1295 -0
- ads/aqua/config/__init__.py +4 -0
- ads/aqua/config/container_config.py +247 -0
- ads/aqua/config/evaluation/__init__.py +4 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
- ads/aqua/config/utils/__init__.py +4 -0
- ads/aqua/config/utils/serializer.py +339 -0
- ads/aqua/constants.py +116 -0
- ads/aqua/data.py +14 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation/__init__.py +8 -0
- ads/aqua/evaluation/constants.py +53 -0
- ads/aqua/evaluation/entities.py +186 -0
- ads/aqua/evaluation/errors.py +70 -0
- ads/aqua/evaluation/evaluation.py +1814 -0
- ads/aqua/extension/__init__.py +42 -0
- ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
- ads/aqua/extension/base_handler.py +90 -0
- ads/aqua/extension/common_handler.py +121 -0
- ads/aqua/extension/common_ws_msg_handler.py +36 -0
- ads/aqua/extension/deployment_handler.py +381 -0
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +30 -0
- ads/aqua/extension/evaluation_handler.py +129 -0
- ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
- ads/aqua/extension/finetune_handler.py +96 -0
- ads/aqua/extension/model_handler.py +390 -0
- ads/aqua/extension/models/__init__.py +0 -0
- ads/aqua/extension/models/ws_models.py +145 -0
- ads/aqua/extension/models_ws_msg_handler.py +50 -0
- ads/aqua/extension/ui_handler.py +300 -0
- ads/aqua/extension/ui_websocket_handler.py +130 -0
- ads/aqua/extension/utils.py +133 -0
- ads/aqua/finetuning/__init__.py +7 -0
- ads/aqua/finetuning/constants.py +23 -0
- ads/aqua/finetuning/entities.py +181 -0
- ads/aqua/finetuning/finetuning.py +749 -0
- ads/aqua/model/__init__.py +8 -0
- ads/aqua/model/constants.py +60 -0
- ads/aqua/model/entities.py +385 -0
- ads/aqua/model/enums.py +32 -0
- ads/aqua/model/model.py +2134 -0
- ads/aqua/model/utils.py +52 -0
- ads/aqua/modeldeployment/__init__.py +6 -0
- ads/aqua/modeldeployment/constants.py +10 -0
- ads/aqua/modeldeployment/deployment.py +1315 -0
- ads/aqua/modeldeployment/entities.py +653 -0
- ads/aqua/modeldeployment/utils.py +543 -0
- ads/aqua/resources/gpu_shapes_index.json +94 -0
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +476 -0
- ads/aqua/ui.py +519 -0
- ads/automl/__init__.py +9 -0
- ads/automl/driver.py +330 -0
- ads/automl/provider.py +975 -0
- ads/bds/__init__.py +5 -0
- ads/bds/auth.py +127 -0
- ads/bds/big_data_service.py +255 -0
- ads/catalog/__init__.py +19 -0
- ads/catalog/model.py +1576 -0
- ads/catalog/notebook.py +461 -0
- ads/catalog/project.py +468 -0
- ads/catalog/summary.py +178 -0
- ads/common/__init__.py +11 -0
- ads/common/analyzer.py +65 -0
- ads/common/artifact/.model-ignore +63 -0
- ads/common/artifact/__init__.py +10 -0
- ads/common/auth.py +1122 -0
- ads/common/card_identifier.py +83 -0
- ads/common/config.py +647 -0
- ads/common/data.py +165 -0
- ads/common/decorator/__init__.py +9 -0
- ads/common/decorator/argument_to_case.py +88 -0
- ads/common/decorator/deprecate.py +69 -0
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/decorator/runtime_dependency.py +178 -0
- ads/common/decorator/threaded.py +97 -0
- ads/common/decorator/utils.py +35 -0
- ads/common/dsc_file_system.py +303 -0
- ads/common/error.py +14 -0
- ads/common/extended_enum.py +81 -0
- ads/common/function/__init__.py +5 -0
- ads/common/function/fn_util.py +142 -0
- ads/common/function/func_conf.yaml +25 -0
- ads/common/ipython.py +76 -0
- ads/common/model.py +679 -0
- ads/common/model_artifact.py +1759 -0
- ads/common/model_artifact_schema.json +107 -0
- ads/common/model_export_util.py +664 -0
- ads/common/model_metadata.py +24 -0
- ads/common/object_storage_details.py +296 -0
- ads/common/oci_client.py +179 -0
- ads/common/oci_datascience.py +46 -0
- ads/common/oci_logging.py +1144 -0
- ads/common/oci_mixin.py +957 -0
- ads/common/oci_resource.py +136 -0
- ads/common/serializer.py +559 -0
- ads/common/utils.py +1852 -0
- ads/common/word_lists.py +1491 -0
- ads/common/work_request.py +189 -0
- ads/config.py +1 -0
- ads/data_labeling/__init__.py +13 -0
- ads/data_labeling/boundingbox.py +253 -0
- ads/data_labeling/constants.py +47 -0
- ads/data_labeling/data_labeling_service.py +244 -0
- ads/data_labeling/interface/__init__.py +5 -0
- ads/data_labeling/interface/loader.py +16 -0
- ads/data_labeling/interface/parser.py +16 -0
- ads/data_labeling/interface/reader.py +23 -0
- ads/data_labeling/loader/__init__.py +5 -0
- ads/data_labeling/loader/file_loader.py +241 -0
- ads/data_labeling/metadata.py +110 -0
- ads/data_labeling/mixin/__init__.py +5 -0
- ads/data_labeling/mixin/data_labeling.py +232 -0
- ads/data_labeling/ner.py +129 -0
- ads/data_labeling/parser/__init__.py +5 -0
- ads/data_labeling/parser/dls_record_parser.py +388 -0
- ads/data_labeling/parser/export_metadata_parser.py +94 -0
- ads/data_labeling/parser/export_record_parser.py +473 -0
- ads/data_labeling/reader/__init__.py +5 -0
- ads/data_labeling/reader/dataset_reader.py +574 -0
- ads/data_labeling/reader/dls_record_reader.py +121 -0
- ads/data_labeling/reader/export_record_reader.py +62 -0
- ads/data_labeling/reader/jsonl_reader.py +75 -0
- ads/data_labeling/reader/metadata_reader.py +203 -0
- ads/data_labeling/reader/record_reader.py +263 -0
- ads/data_labeling/record.py +52 -0
- ads/data_labeling/visualizer/__init__.py +5 -0
- ads/data_labeling/visualizer/image_visualizer.py +525 -0
- ads/data_labeling/visualizer/text_visualizer.py +357 -0
- ads/database/__init__.py +5 -0
- ads/database/connection.py +338 -0
- ads/dataset/__init__.py +10 -0
- ads/dataset/capabilities.md +51 -0
- ads/dataset/classification_dataset.py +339 -0
- ads/dataset/correlation.py +226 -0
- ads/dataset/correlation_plot.py +563 -0
- ads/dataset/dask_series.py +173 -0
- ads/dataset/dataframe_transformer.py +110 -0
- ads/dataset/dataset.py +1979 -0
- ads/dataset/dataset_browser.py +360 -0
- ads/dataset/dataset_with_target.py +995 -0
- ads/dataset/exception.py +25 -0
- ads/dataset/factory.py +987 -0
- ads/dataset/feature_engineering_transformer.py +35 -0
- ads/dataset/feature_selection.py +107 -0
- ads/dataset/forecasting_dataset.py +26 -0
- ads/dataset/helper.py +1450 -0
- ads/dataset/label_encoder.py +99 -0
- ads/dataset/mixin/__init__.py +5 -0
- ads/dataset/mixin/dataset_accessor.py +134 -0
- ads/dataset/pipeline.py +58 -0
- ads/dataset/plot.py +710 -0
- ads/dataset/progress.py +86 -0
- ads/dataset/recommendation.py +297 -0
- ads/dataset/recommendation_transformer.py +502 -0
- ads/dataset/regression_dataset.py +14 -0
- ads/dataset/sampled_dataset.py +1050 -0
- ads/dataset/target.py +98 -0
- ads/dataset/timeseries.py +18 -0
- ads/dbmixin/__init__.py +5 -0
- ads/dbmixin/db_pandas_accessor.py +153 -0
- ads/environment/__init__.py +9 -0
- ads/environment/ml_runtime.py +66 -0
- ads/evaluations/README.md +14 -0
- ads/evaluations/__init__.py +109 -0
- ads/evaluations/evaluation_plot.py +983 -0
- ads/evaluations/evaluator.py +1334 -0
- ads/evaluations/statistical_metrics.py +543 -0
- ads/experiments/__init__.py +9 -0
- ads/experiments/capabilities.md +0 -0
- ads/explanations/__init__.py +21 -0
- ads/explanations/base_explainer.py +142 -0
- ads/explanations/capabilities.md +83 -0
- ads/explanations/explainer.py +190 -0
- ads/explanations/mlx_global_explainer.py +1050 -0
- ads/explanations/mlx_interface.py +386 -0
- ads/explanations/mlx_local_explainer.py +287 -0
- ads/explanations/mlx_whatif_explainer.py +201 -0
- ads/feature_engineering/__init__.py +20 -0
- ads/feature_engineering/accessor/__init__.py +5 -0
- ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
- ads/feature_engineering/accessor/mixin/__init__.py +5 -0
- ads/feature_engineering/accessor/mixin/correlation.py +166 -0
- ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
- ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
- ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
- ads/feature_engineering/accessor/mixin/utils.py +65 -0
- ads/feature_engineering/accessor/series_accessor.py +431 -0
- ads/feature_engineering/adsimage/__init__.py +5 -0
- ads/feature_engineering/adsimage/image.py +192 -0
- ads/feature_engineering/adsimage/image_reader.py +170 -0
- ads/feature_engineering/adsimage/interface/__init__.py +5 -0
- ads/feature_engineering/adsimage/interface/reader.py +19 -0
- ads/feature_engineering/adsstring/__init__.py +7 -0
- ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
- ads/feature_engineering/adsstring/string/__init__.py +8 -0
- ads/feature_engineering/data_schema.json +57 -0
- ads/feature_engineering/dataset/__init__.py +5 -0
- ads/feature_engineering/dataset/zip_code_data.py +42062 -0
- ads/feature_engineering/exceptions.py +40 -0
- ads/feature_engineering/feature_type/__init__.py +133 -0
- ads/feature_engineering/feature_type/address.py +184 -0
- ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
- ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
- ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
- ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
- ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
- ads/feature_engineering/feature_type/adsstring/string.py +258 -0
- ads/feature_engineering/feature_type/base.py +58 -0
- ads/feature_engineering/feature_type/boolean.py +183 -0
- ads/feature_engineering/feature_type/category.py +146 -0
- ads/feature_engineering/feature_type/constant.py +137 -0
- ads/feature_engineering/feature_type/continuous.py +151 -0
- ads/feature_engineering/feature_type/creditcard.py +314 -0
- ads/feature_engineering/feature_type/datetime.py +190 -0
- ads/feature_engineering/feature_type/discrete.py +134 -0
- ads/feature_engineering/feature_type/document.py +43 -0
- ads/feature_engineering/feature_type/gis.py +251 -0
- ads/feature_engineering/feature_type/handler/__init__.py +5 -0
- ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
- ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
- ads/feature_engineering/feature_type/handler/warnings.py +128 -0
- ads/feature_engineering/feature_type/integer.py +142 -0
- ads/feature_engineering/feature_type/ip_address.py +144 -0
- ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
- ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
- ads/feature_engineering/feature_type/lat_long.py +256 -0
- ads/feature_engineering/feature_type/object.py +43 -0
- ads/feature_engineering/feature_type/ordinal.py +132 -0
- ads/feature_engineering/feature_type/phone_number.py +135 -0
- ads/feature_engineering/feature_type/string.py +171 -0
- ads/feature_engineering/feature_type/text.py +93 -0
- ads/feature_engineering/feature_type/unknown.py +43 -0
- ads/feature_engineering/feature_type/zip_code.py +164 -0
- ads/feature_engineering/feature_type_manager.py +406 -0
- ads/feature_engineering/schema.py +795 -0
- ads/feature_engineering/utils.py +245 -0
- ads/feature_store/.readthedocs.yaml +19 -0
- ads/feature_store/README.md +65 -0
- ads/feature_store/__init__.py +9 -0
- ads/feature_store/common/__init__.py +0 -0
- ads/feature_store/common/enums.py +339 -0
- ads/feature_store/common/exceptions.py +18 -0
- ads/feature_store/common/spark_session_singleton.py +125 -0
- ads/feature_store/common/utils/__init__.py +0 -0
- ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
- ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
- ads/feature_store/common/utils/transformation_utils.py +82 -0
- ads/feature_store/common/utils/utility.py +403 -0
- ads/feature_store/data_validation/__init__.py +0 -0
- ads/feature_store/data_validation/great_expectation.py +129 -0
- ads/feature_store/dataset.py +1230 -0
- ads/feature_store/dataset_job.py +530 -0
- ads/feature_store/docs/Dockerfile +7 -0
- ads/feature_store/docs/Makefile +44 -0
- ads/feature_store/docs/conf.py +28 -0
- ads/feature_store/docs/requirements.txt +14 -0
- ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
- ads/feature_store/docs/source/cicd.rst +137 -0
- ads/feature_store/docs/source/conf.py +86 -0
- ads/feature_store/docs/source/data_versioning.rst +33 -0
- ads/feature_store/docs/source/dataset.rst +388 -0
- ads/feature_store/docs/source/dataset_job.rst +27 -0
- ads/feature_store/docs/source/demo.rst +70 -0
- ads/feature_store/docs/source/entity.rst +78 -0
- ads/feature_store/docs/source/feature_group.rst +624 -0
- ads/feature_store/docs/source/feature_group_job.rst +29 -0
- ads/feature_store/docs/source/feature_store.rst +122 -0
- ads/feature_store/docs/source/feature_store_class.rst +123 -0
- ads/feature_store/docs/source/feature_validation.rst +66 -0
- ads/feature_store/docs/source/figures/cicd.png +0 -0
- ads/feature_store/docs/source/figures/data_validation.png +0 -0
- ads/feature_store/docs/source/figures/data_versioning.png +0 -0
- ads/feature_store/docs/source/figures/dataset.gif +0 -0
- ads/feature_store/docs/source/figures/dataset.png +0 -0
- ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
- ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
- ads/feature_store/docs/source/figures/entity.png +0 -0
- ads/feature_store/docs/source/figures/feature_group.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
- ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
- ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
- ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
- ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
- ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
- ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
- ads/feature_store/docs/source/figures/overview.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
- ads/feature_store/docs/source/figures/stats_1.png +0 -0
- ads/feature_store/docs/source/figures/stats_2.png +0 -0
- ads/feature_store/docs/source/figures/stats_d.png +0 -0
- ads/feature_store/docs/source/figures/stats_fg.png +0 -0
- ads/feature_store/docs/source/figures/transformation.png +0 -0
- ads/feature_store/docs/source/figures/transformations.gif +0 -0
- ads/feature_store/docs/source/figures/validation.png +0 -0
- ads/feature_store/docs/source/figures/validation_fg.png +0 -0
- ads/feature_store/docs/source/figures/validation_results.png +0 -0
- ads/feature_store/docs/source/figures/validation_summary.png +0 -0
- ads/feature_store/docs/source/index.rst +81 -0
- ads/feature_store/docs/source/module.rst +8 -0
- ads/feature_store/docs/source/notebook.rst +94 -0
- ads/feature_store/docs/source/overview.rst +47 -0
- ads/feature_store/docs/source/quickstart.rst +176 -0
- ads/feature_store/docs/source/release_notes.rst +194 -0
- ads/feature_store/docs/source/setup_feature_store.rst +81 -0
- ads/feature_store/docs/source/statistics.rst +58 -0
- ads/feature_store/docs/source/transformation.rst +199 -0
- ads/feature_store/docs/source/ui.rst +65 -0
- ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
- ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
- ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
- ads/feature_store/entity.py +718 -0
- ads/feature_store/execution_strategy/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
- ads/feature_store/execution_strategy/engine/__init__.py +0 -0
- ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
- ads/feature_store/execution_strategy/execution_strategy.py +113 -0
- ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
- ads/feature_store/execution_strategy/spark/__init__.py +0 -0
- ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
- ads/feature_store/feature.py +192 -0
- ads/feature_store/feature_group.py +1494 -0
- ads/feature_store/feature_group_expectation.py +346 -0
- ads/feature_store/feature_group_job.py +602 -0
- ads/feature_store/feature_lineage/__init__.py +0 -0
- ads/feature_store/feature_lineage/graphviz_service.py +180 -0
- ads/feature_store/feature_option_details.py +50 -0
- ads/feature_store/feature_statistics/__init__.py +0 -0
- ads/feature_store/feature_statistics/statistics_service.py +99 -0
- ads/feature_store/feature_store.py +699 -0
- ads/feature_store/feature_store_registrar.py +518 -0
- ads/feature_store/input_feature_detail.py +149 -0
- ads/feature_store/mixin/__init__.py +4 -0
- ads/feature_store/mixin/oci_feature_store.py +145 -0
- ads/feature_store/model_details.py +73 -0
- ads/feature_store/query/__init__.py +0 -0
- ads/feature_store/query/filter.py +266 -0
- ads/feature_store/query/generator/__init__.py +0 -0
- ads/feature_store/query/generator/query_generator.py +298 -0
- ads/feature_store/query/join.py +161 -0
- ads/feature_store/query/query.py +403 -0
- ads/feature_store/query/validator/__init__.py +0 -0
- ads/feature_store/query/validator/query_validator.py +57 -0
- ads/feature_store/response/__init__.py +0 -0
- ads/feature_store/response/response_builder.py +68 -0
- ads/feature_store/service/__init__.py +0 -0
- ads/feature_store/service/oci_dataset.py +139 -0
- ads/feature_store/service/oci_dataset_job.py +199 -0
- ads/feature_store/service/oci_entity.py +125 -0
- ads/feature_store/service/oci_feature_group.py +164 -0
- ads/feature_store/service/oci_feature_group_job.py +214 -0
- ads/feature_store/service/oci_feature_store.py +182 -0
- ads/feature_store/service/oci_lineage.py +87 -0
- ads/feature_store/service/oci_transformation.py +104 -0
- ads/feature_store/statistics/__init__.py +0 -0
- ads/feature_store/statistics/abs_feature_value.py +49 -0
- ads/feature_store/statistics/charts/__init__.py +0 -0
- ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
- ads/feature_store/statistics/charts/box_plot.py +148 -0
- ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
- ads/feature_store/statistics/charts/probability_distribution.py +68 -0
- ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
- ads/feature_store/statistics/feature_stat.py +126 -0
- ads/feature_store/statistics/generic_feature_value.py +33 -0
- ads/feature_store/statistics/statistics.py +41 -0
- ads/feature_store/statistics_config.py +101 -0
- ads/feature_store/templates/feature_store_template.yaml +45 -0
- ads/feature_store/transformation.py +499 -0
- ads/feature_store/validation_output.py +57 -0
- ads/hpo/__init__.py +9 -0
- ads/hpo/_imports.py +91 -0
- ads/hpo/ads_search_space.py +439 -0
- ads/hpo/distributions.py +325 -0
- ads/hpo/objective.py +280 -0
- ads/hpo/search_cv.py +1657 -0
- ads/hpo/stopping_criterion.py +75 -0
- ads/hpo/tuner_artifact.py +413 -0
- ads/hpo/utils.py +91 -0
- ads/hpo/validation.py +140 -0
- ads/hpo/visualization/__init__.py +5 -0
- ads/hpo/visualization/_contour.py +23 -0
- ads/hpo/visualization/_edf.py +20 -0
- ads/hpo/visualization/_intermediate_values.py +21 -0
- ads/hpo/visualization/_optimization_history.py +25 -0
- ads/hpo/visualization/_parallel_coordinate.py +169 -0
- ads/hpo/visualization/_param_importances.py +26 -0
- ads/jobs/__init__.py +53 -0
- ads/jobs/ads_job.py +663 -0
- ads/jobs/builders/__init__.py +5 -0
- ads/jobs/builders/base.py +156 -0
- ads/jobs/builders/infrastructure/__init__.py +6 -0
- ads/jobs/builders/infrastructure/base.py +165 -0
- ads/jobs/builders/infrastructure/dataflow.py +1252 -0
- ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
- ads/jobs/builders/infrastructure/utils.py +65 -0
- ads/jobs/builders/runtimes/__init__.py +5 -0
- ads/jobs/builders/runtimes/artifact.py +338 -0
- ads/jobs/builders/runtimes/base.py +325 -0
- ads/jobs/builders/runtimes/container_runtime.py +242 -0
- ads/jobs/builders/runtimes/python_runtime.py +1016 -0
- ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
- ads/jobs/cli.py +104 -0
- ads/jobs/env_var_parser.py +131 -0
- ads/jobs/extension.py +160 -0
- ads/jobs/schema/__init__.py +5 -0
- ads/jobs/schema/infrastructure_schema.json +116 -0
- ads/jobs/schema/job_schema.json +42 -0
- ads/jobs/schema/runtime_schema.json +183 -0
- ads/jobs/schema/validator.py +141 -0
- ads/jobs/serializer.py +296 -0
- ads/jobs/templates/__init__.py +5 -0
- ads/jobs/templates/container.py +6 -0
- ads/jobs/templates/driver_notebook.py +177 -0
- ads/jobs/templates/driver_oci.py +500 -0
- ads/jobs/templates/driver_python.py +48 -0
- ads/jobs/templates/driver_pytorch.py +852 -0
- ads/jobs/templates/driver_utils.py +615 -0
- ads/jobs/templates/hostname_from_env.c +55 -0
- ads/jobs/templates/oci_metrics.py +181 -0
- ads/jobs/utils.py +104 -0
- ads/llm/__init__.py +28 -0
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/v02/client.py +295 -0
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/chain.py +268 -0
- ads/llm/chat_template.py +31 -0
- ads/llm/deploy.py +63 -0
- ads/llm/guardrails/__init__.py +5 -0
- ads/llm/guardrails/base.py +442 -0
- ads/llm/guardrails/huggingface.py +44 -0
- ads/llm/langchain/__init__.py +5 -0
- ads/llm/langchain/plugins/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/llm/langchain/plugins/llms/__init__.py +5 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
- ads/llm/requirements.txt +3 -0
- ads/llm/serialize.py +219 -0
- ads/llm/serializers/__init__.py +0 -0
- ads/llm/serializers/retrieval_qa.py +153 -0
- ads/llm/serializers/runnable_parallel.py +27 -0
- ads/llm/templates/score_chain.jinja2 +155 -0
- ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
- ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
- ads/model/__init__.py +52 -0
- ads/model/artifact.py +573 -0
- ads/model/artifact_downloader.py +254 -0
- ads/model/artifact_uploader.py +267 -0
- ads/model/base_properties.py +238 -0
- ads/model/common/.model-ignore +66 -0
- ads/model/common/__init__.py +5 -0
- ads/model/common/utils.py +142 -0
- ads/model/datascience_model.py +2635 -0
- ads/model/deployment/__init__.py +20 -0
- ads/model/deployment/common/__init__.py +5 -0
- ads/model/deployment/common/utils.py +308 -0
- ads/model/deployment/model_deployer.py +466 -0
- ads/model/deployment/model_deployment.py +1846 -0
- ads/model/deployment/model_deployment_infrastructure.py +671 -0
- ads/model/deployment/model_deployment_properties.py +493 -0
- ads/model/deployment/model_deployment_runtime.py +838 -0
- ads/model/extractor/__init__.py +5 -0
- ads/model/extractor/automl_extractor.py +74 -0
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/extractor/huggingface_extractor.py +88 -0
- ads/model/extractor/keras_extractor.py +84 -0
- ads/model/extractor/lightgbm_extractor.py +93 -0
- ads/model/extractor/model_info_extractor.py +114 -0
- ads/model/extractor/model_info_extractor_factory.py +105 -0
- ads/model/extractor/pytorch_extractor.py +87 -0
- ads/model/extractor/sklearn_extractor.py +112 -0
- ads/model/extractor/spark_extractor.py +89 -0
- ads/model/extractor/tensorflow_extractor.py +85 -0
- ads/model/extractor/xgboost_extractor.py +94 -0
- ads/model/framework/__init__.py +5 -0
- ads/model/framework/automl_model.py +178 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/framework/huggingface_model.py +399 -0
- ads/model/framework/lightgbm_model.py +266 -0
- ads/model/framework/pytorch_model.py +266 -0
- ads/model/framework/sklearn_model.py +250 -0
- ads/model/framework/spark_model.py +326 -0
- ads/model/framework/tensorflow_model.py +254 -0
- ads/model/framework/xgboost_model.py +258 -0
- ads/model/generic_model.py +3518 -0
- ads/model/model_artifact_boilerplate/README.md +381 -0
- ads/model/model_artifact_boilerplate/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
- ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
- ads/model/model_artifact_boilerplate/score.py +61 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_introspect.py +331 -0
- ads/model/model_metadata.py +1810 -0
- ads/model/model_metadata_mixin.py +460 -0
- ads/model/model_properties.py +63 -0
- ads/model/model_version_set.py +739 -0
- ads/model/runtime/__init__.py +5 -0
- ads/model/runtime/env_info.py +306 -0
- ads/model/runtime/model_deployment_details.py +37 -0
- ads/model/runtime/model_provenance_details.py +58 -0
- ads/model/runtime/runtime_info.py +81 -0
- ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
- ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
- ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
- ads/model/runtime/utils.py +201 -0
- ads/model/serde/__init__.py +5 -0
- ads/model/serde/common.py +40 -0
- ads/model/serde/model_input.py +547 -0
- ads/model/serde/model_serializer.py +1184 -0
- ads/model/service/__init__.py +5 -0
- ads/model/service/oci_datascience_model.py +1076 -0
- ads/model/service/oci_datascience_model_deployment.py +500 -0
- ads/model/service/oci_datascience_model_version_set.py +176 -0
- ads/model/transformer/__init__.py +5 -0
- ads/model/transformer/onnx_transformer.py +324 -0
- ads/mysqldb/__init__.py +5 -0
- ads/mysqldb/mysql_db.py +227 -0
- ads/opctl/__init__.py +18 -0
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/__init__.py +5 -0
- ads/opctl/backend/ads_dataflow.py +353 -0
- ads/opctl/backend/ads_ml_job.py +710 -0
- ads/opctl/backend/ads_ml_pipeline.py +164 -0
- ads/opctl/backend/ads_model_deployment.py +209 -0
- ads/opctl/backend/base.py +146 -0
- ads/opctl/backend/local.py +1053 -0
- ads/opctl/backend/marketplace/__init__.py +9 -0
- ads/opctl/backend/marketplace/helm_helper.py +173 -0
- ads/opctl/backend/marketplace/local_marketplace.py +271 -0
- ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
- ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
- ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
- ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
- ads/opctl/backend/marketplace/models/__init__.py +5 -0
- ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
- ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
- ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
- ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
- ads/opctl/cli.py +707 -0
- ads/opctl/cmds.py +869 -0
- ads/opctl/conda/__init__.py +5 -0
- ads/opctl/conda/cli.py +193 -0
- ads/opctl/conda/cmds.py +749 -0
- ads/opctl/conda/config.yaml +34 -0
- ads/opctl/conda/manifest_template.yaml +13 -0
- ads/opctl/conda/multipart_uploader.py +188 -0
- ads/opctl/conda/pack.py +89 -0
- ads/opctl/config/__init__.py +5 -0
- ads/opctl/config/base.py +57 -0
- ads/opctl/config/diagnostics/__init__.py +5 -0
- ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
- ads/opctl/config/merger.py +255 -0
- ads/opctl/config/resolver.py +297 -0
- ads/opctl/config/utils.py +79 -0
- ads/opctl/config/validator.py +17 -0
- ads/opctl/config/versioner.py +68 -0
- ads/opctl/config/yaml_parsers/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/base.py +58 -0
- ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
- ads/opctl/constants.py +66 -0
- ads/opctl/decorator/__init__.py +5 -0
- ads/opctl/decorator/common.py +129 -0
- ads/opctl/diagnostics/__init__.py +5 -0
- ads/opctl/diagnostics/__main__.py +25 -0
- ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
- ads/opctl/diagnostics/check_requirements.py +144 -0
- ads/opctl/diagnostics/requirement_exception.py +9 -0
- ads/opctl/distributed/README.md +109 -0
- ads/opctl/distributed/__init__.py +5 -0
- ads/opctl/distributed/certificates.py +32 -0
- ads/opctl/distributed/cli.py +207 -0
- ads/opctl/distributed/cmds.py +731 -0
- ads/opctl/distributed/common/__init__.py +5 -0
- ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
- ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
- ads/opctl/distributed/common/cluster_config_helper.py +103 -0
- ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
- ads/opctl/distributed/common/cluster_runner.py +54 -0
- ads/opctl/distributed/common/framework_factory.py +29 -0
- ads/opctl/docker/Dockerfile.job +103 -0
- ads/opctl/docker/Dockerfile.job.arm +107 -0
- ads/opctl/docker/Dockerfile.job.gpu +175 -0
- ads/opctl/docker/base-env.yaml +13 -0
- ads/opctl/docker/cuda.repo +6 -0
- ads/opctl/docker/operator/.dockerignore +0 -0
- ads/opctl/docker/operator/Dockerfile +41 -0
- ads/opctl/docker/operator/Dockerfile.gpu +85 -0
- ads/opctl/docker/operator/cuda.repo +6 -0
- ads/opctl/docker/operator/environment.yaml +8 -0
- ads/opctl/forecast.py +11 -0
- ads/opctl/index.yaml +3 -0
- ads/opctl/model/__init__.py +5 -0
- ads/opctl/model/cli.py +65 -0
- ads/opctl/model/cmds.py +73 -0
- ads/opctl/operator/README.md +4 -0
- ads/opctl/operator/__init__.py +31 -0
- ads/opctl/operator/cli.py +344 -0
- ads/opctl/operator/cmd.py +596 -0
- ads/opctl/operator/common/__init__.py +5 -0
- ads/opctl/operator/common/backend_factory.py +460 -0
- ads/opctl/operator/common/const.py +27 -0
- ads/opctl/operator/common/data/synthetic.csv +16001 -0
- ads/opctl/operator/common/dictionary_merger.py +148 -0
- ads/opctl/operator/common/errors.py +42 -0
- ads/opctl/operator/common/operator_config.py +99 -0
- ads/opctl/operator/common/operator_loader.py +811 -0
- ads/opctl/operator/common/operator_schema.yaml +130 -0
- ads/opctl/operator/common/operator_yaml_generator.py +152 -0
- ads/opctl/operator/common/utils.py +208 -0
- ads/opctl/operator/lowcode/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
- ads/opctl/operator/lowcode/anomaly/README.md +207 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +167 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
- ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +116 -0
- ads/opctl/operator/lowcode/common/errors.py +47 -0
- ads/opctl/operator/lowcode/common/transformations.py +296 -0
- ads/opctl/operator/lowcode/common/utils.py +384 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
- ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
- ads/opctl/operator/lowcode/forecast/README.md +209 -0
- ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
- ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
- ads/opctl/operator/lowcode/forecast/const.py +92 -0
- ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
- ads/opctl/operator/lowcode/forecast/errors.py +26 -0
- ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
- ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
- ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
- ads/opctl/operator/lowcode/forecast/model/prophet.py +450 -0
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
- ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
- ads/opctl/operator/lowcode/forecast/utils.py +397 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
- ads/opctl/operator/lowcode/pii/MLoperator +17 -0
- ads/opctl/operator/lowcode/pii/README.md +208 -0
- ads/opctl/operator/lowcode/pii/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/__main__.py +78 -0
- ads/opctl/operator/lowcode/pii/cmd.py +39 -0
- ads/opctl/operator/lowcode/pii/constant.py +84 -0
- ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
- ads/opctl/operator/lowcode/pii/errors.py +27 -0
- ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
- ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
- ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
- ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
- ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
- ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
- ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
- ads/opctl/operator/lowcode/pii/model/report.py +487 -0
- ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
- ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
- ads/opctl/operator/lowcode/pii/utils.py +43 -0
- ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
- ads/opctl/operator/lowcode/recommender/README.md +206 -0
- ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
- ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
- ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
- ads/opctl/operator/lowcode/recommender/constant.py +30 -0
- ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
- ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
- ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
- ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
- ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
- ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
- ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
- ads/opctl/operator/lowcode/recommender/utils.py +13 -0
- ads/opctl/operator/runtime/__init__.py +5 -0
- ads/opctl/operator/runtime/const.py +17 -0
- ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
- ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
- ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/runtime.py +115 -0
- ads/opctl/schema.yaml.yml +36 -0
- ads/opctl/script.py +40 -0
- ads/opctl/spark/__init__.py +5 -0
- ads/opctl/spark/cli.py +43 -0
- ads/opctl/spark/cmds.py +147 -0
- ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
- ads/opctl/utils.py +344 -0
- ads/oracledb/__init__.py +5 -0
- ads/oracledb/oracle_db.py +346 -0
- ads/pipeline/__init__.py +39 -0
- ads/pipeline/ads_pipeline.py +2279 -0
- ads/pipeline/ads_pipeline_run.py +772 -0
- ads/pipeline/ads_pipeline_step.py +605 -0
- ads/pipeline/builders/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/custom_script.py +32 -0
- ads/pipeline/cli.py +119 -0
- ads/pipeline/extension.py +291 -0
- ads/pipeline/schema/__init__.py +5 -0
- ads/pipeline/schema/cs_step_schema.json +35 -0
- ads/pipeline/schema/ml_step_schema.json +31 -0
- ads/pipeline/schema/pipeline_schema.json +71 -0
- ads/pipeline/visualizer/__init__.py +5 -0
- ads/pipeline/visualizer/base.py +570 -0
- ads/pipeline/visualizer/graph_renderer.py +272 -0
- ads/pipeline/visualizer/text_renderer.py +84 -0
- ads/secrets/__init__.py +11 -0
- ads/secrets/adb.py +386 -0
- ads/secrets/auth_token.py +86 -0
- ads/secrets/big_data_service.py +365 -0
- ads/secrets/mysqldb.py +149 -0
- ads/secrets/oracledb.py +160 -0
- ads/secrets/secrets.py +407 -0
- ads/telemetry/__init__.py +7 -0
- ads/telemetry/base.py +69 -0
- ads/telemetry/client.py +122 -0
- ads/telemetry/telemetry.py +257 -0
- ads/templates/dataflow_pyspark.jinja2 +13 -0
- ads/templates/dataflow_sparksql.jinja2 +22 -0
- ads/templates/func.jinja2 +20 -0
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score-pkl.jinja2 +173 -0
- ads/templates/score.jinja2 +322 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- ads/templates/score_generic.jinja2 +165 -0
- ads/templates/score_huggingface_pipeline.jinja2 +217 -0
- ads/templates/score_lightgbm.jinja2 +185 -0
- ads/templates/score_onnx.jinja2 +407 -0
- ads/templates/score_onnx_new.jinja2 +473 -0
- ads/templates/score_oracle_automl.jinja2 +185 -0
- ads/templates/score_pyspark.jinja2 +154 -0
- ads/templates/score_pytorch.jinja2 +219 -0
- ads/templates/score_scikit-learn.jinja2 +184 -0
- ads/templates/score_tensorflow.jinja2 +184 -0
- ads/templates/score_xgboost.jinja2 +178 -0
- ads/text_dataset/__init__.py +5 -0
- ads/text_dataset/backends.py +211 -0
- ads/text_dataset/dataset.py +445 -0
- ads/text_dataset/extractor.py +207 -0
- ads/text_dataset/options.py +53 -0
- ads/text_dataset/udfs.py +22 -0
- ads/text_dataset/utils.py +49 -0
- ads/type_discovery/__init__.py +9 -0
- ads/type_discovery/abstract_detector.py +21 -0
- ads/type_discovery/constant_detector.py +41 -0
- ads/type_discovery/continuous_detector.py +54 -0
- ads/type_discovery/credit_card_detector.py +99 -0
- ads/type_discovery/datetime_detector.py +92 -0
- ads/type_discovery/discrete_detector.py +118 -0
- ads/type_discovery/document_detector.py +146 -0
- ads/type_discovery/ip_detector.py +68 -0
- ads/type_discovery/latlon_detector.py +90 -0
- ads/type_discovery/phone_number_detector.py +63 -0
- ads/type_discovery/type_discovery_driver.py +87 -0
- ads/type_discovery/typed_feature.py +594 -0
- ads/type_discovery/unknown_detector.py +41 -0
- ads/type_discovery/zipcode_detector.py +48 -0
- ads/vault/__init__.py +7 -0
- ads/vault/vault.py +237 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/METADATA +150 -149
- oracle_ads-2.13.10rc0.dist-info/RECORD +858 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/WHEEL +1 -2
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/entry_points.txt +2 -1
- oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
- oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,154 @@
|
|
1
|
+
# score.py {{SCORE_VERSION}} generated by ADS {{ADS_VERSION}} on {{time_created}}
|
2
|
+
import os
|
3
|
+
import sys
|
4
|
+
import json
|
5
|
+
from functools import lru_cache
|
6
|
+
from collections.abc import Iterable
|
7
|
+
|
8
|
+
# Importing pyspark libraries
|
9
|
+
from pyspark.ml.pipeline import PipelineModel
|
10
|
+
from pyspark.sql import SparkSession
|
11
|
+
from pyspark.sql.types import StructType
|
12
|
+
|
13
|
+
"""
|
14
|
+
Inference script. This script is used for prediction by scoring server when schema is known.
|
15
|
+
"""
|
16
|
+
|
17
|
+
model_name = '{{model_file_name}}'
|
18
|
+
spark_input_schema = model_name + "_input_data_schema.json"
|
19
|
+
|
20
|
+
spark = SparkSession \
|
21
|
+
.builder \
|
22
|
+
.appName('Spark Model Inference') \
|
23
|
+
.getOrCreate()
|
24
|
+
|
25
|
+
|
26
|
+
@lru_cache(maxsize=10)
def load_model(model_file_name=model_name):
    """
    Load the serialized Spark pipeline model.

    The model directory is taken from the ``OCI_DEPLOYMENT_PATH`` environment
    variable when it is set, otherwise from the directory containing this
    script. The directory is also put at the front of ``sys.path``.

    Parameters
    ----------
    model_file_name: name of the serialized model inside the model directory.

    Returns
    -------
    model: a model instance on which predict API can be invoked

    Raises
    ------
    Exception: if the pipeline model cannot be loaded for any reason.
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    model_dir = os.environ.get("OCI_DEPLOYMENT_PATH", script_dir)
    if model_dir not in sys.path:
        sys.path.insert(0, model_dir)
    try:
        print(f'Start loading {model_file_name} from model directory {model_dir} ...')
        model_path = os.path.join(model_dir, model_file_name)
        loaded_model = PipelineModel.load(model_path)
        print("Model is successfully loaded.")
        return loaded_model
    except Exception as e:
        raise Exception(f"Failed to load the model and gave the exception: {e}.")
|
46
|
+
|
47
|
+
@lru_cache(maxsize=1)
def fetch_data_type_from_schema(input_schema_path=os.path.join(os.path.dirname(os.path.realpath(__file__)), "input_schema.json")):
    """
    Returns data type information fetched from input_schema.json.

    Parameters
    ----------
    input_schema_path: path of input schema.

    Returns
    -------
    data_type: dict mapping each column ``name`` to its ``dtype`` as listed
        under the ``schema`` key of the file; empty when the file is missing.

    """
    data_type = {}
    if os.path.exists(input_schema_path):
        # Context manager so the schema file handle is closed deterministically.
        with open(input_schema_path) as schema_file:
            schema = json.load(schema_file)
        for col in schema['schema']:
            data_type[col['name']] = col['dtype']
    else:
        print("input_schema has to be passed in in order to recover the same data type. pass `X_sample` in `ads.model.framework.spark_model.SparkPipelineModel.prepare` function to generate the input_schema. Otherwise, the data type might be changed after serialization/deserialization.")
    return data_type
|
69
|
+
|
70
|
+
def deserialize(data, input_schema_path):
    """
    Deserialize json serialization data into a Spark DataFrame for predict.

    The Spark schema is read from the ``<model_name>_input_data_schema.json``
    file located next to this script (see ``spark_input_schema``).

    Parameters
    ----------
    data: serialized input data. When it is a dict, the payload is taken from
        its ``data`` key (falling back to the dict itself) and ``data_type``
        is only used in the error message.
    input_schema_path: path of input schema. Not used by this implementation;
        kept so the signature matches the other score templates.

    Returns
    -------
    data: deserialized input data, as returned by ``spark.read.json``.

    Raises
    ------
    TypeError: if Spark fails to build the DataFrame from the payload.

    """
    data_type = data.get('data_type', '') if isinstance(data, dict) else ''
    json_data = data.get('data', data) if isinstance(data, dict) else data

    spark_schema_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), spark_input_schema)
    with open(spark_schema_path) as f:
        spark_schema = StructType.fromJson(json.loads(f.read()))

    # A single JSON string (or a non-iterable scalar) is wrapped so that
    # parallelize() receives a collection of records.
    if isinstance(json_data, str) or not isinstance(json_data, Iterable):
        json_data = [json_data]

    try:
        rdd_data = spark.sparkContext.parallelize(json_data)
        return spark.read.json(rdd_data, schema=spark_schema)
    except Exception as e:
        # Catch Exception rather than everything so KeyboardInterrupt and
        # SystemExit propagate untouched; keep the original cause attached.
        raise TypeError(f"Unsupported data_type: {data_type} or failed to load data with provided spark schema: {spark_schema}. Only supported data types are: pyspark.sql.dataframe.DataFrame, pyspark.pandas.dataframe.DataFrame, pandas.DataFrame, json") from e
|
99
|
+
|
100
|
+
def pre_inference(data, input_schema_path):
    """
    Prepare the incoming payload for the model by deserializing it.

    Parameters
    ----------
    data: Data format as expected by the predict API of the core estimator.
    input_schema_path: path of input schema.

    Returns
    -------
    data: Data format after any processing.

    """
    return deserialize(data, input_schema_path)
|
116
|
+
|
117
|
+
def post_inference(yhat, input_cols):
    """
    Post-process the model results.

    Parameters
    ----------
    yhat: Data format after calling model.predict; must provide ``toPandas()``
        yielding a frame with a ``prediction`` column.
    input_cols: column names of the model input. Not used by this
        implementation.

    Returns
    -------
    yhat: the ``prediction`` column converted to a plain list.

    """
    pandas_frame = yhat.toPandas()
    prediction_column = pandas_frame['prediction']
    return prediction_column.to_list()
|
131
|
+
|
132
|
+
def predict(data, model=load_model(), input_schema_path=os.path.join(os.path.dirname(os.path.realpath(__file__)), "input_schema.json")):
    """
    Run the Spark pipeline model on the given data and return its predictions.

    Parameters
    ----------
    model: Model instance returned by load_model API
    data: Data format as expected by the predict API of the core estimator.
    input_schema_path: path of input schema.

    Returns
    -------
    predictions: Output from scoring server
        Format: {'prediction': output from model.transform after post-processing}

    """
    features = pre_inference(data, input_schema_path)
    input_cols = features.columns
    row_count = features.count()
    print("Dataset Count: " + str(row_count))

    transformed = model.transform(features)
    yhat = post_inference(transformed, input_cols)
    return {'prediction': yhat}
|
@@ -0,0 +1,219 @@
|
|
1
|
+
# score.py {{SCORE_VERSION}} generated by ADS {{ADS_VERSION}} on {{time_created}}
|
2
|
+
import os
|
3
|
+
import sys
|
4
|
+
from functools import lru_cache
|
5
|
+
import torch
|
6
|
+
import json
|
7
|
+
from typing import Dict, List
|
8
|
+
import numpy as np
|
9
|
+
import pandas as pd
|
10
|
+
from io import BytesIO
|
11
|
+
import base64
|
12
|
+
import logging
|
13
|
+
from random import randint
|
14
|
+
|
15
|
+
|
16
|
+
def get_torch_device():
    """
    Pick the device string used for inference.

    Returns
    -------
    str: "cpu" when no CUDA device is visible, "cuda:0" when exactly one is,
        otherwise "cuda:<i>" with ``i`` drawn at random among the visible
        devices.
    """
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        return "cpu"
    if gpu_count == 1:
        return "cuda:0"
    chosen_index = randint(0, gpu_count-1)
    return f"cuda:{chosen_index}"
|
24
|
+
|
25
|
+
|
26
|
+
model_name = '{{model_file_name}}'
|
27
|
+
device = torch.device(get_torch_device())
|
28
|
+
|
29
|
+
"""
|
30
|
+
Inference script. This script is used for prediction by scoring server when schema is known.
|
31
|
+
"""
|
32
|
+
|
33
|
+
@lru_cache(maxsize=10)
|
34
|
+
def load_model(model_file_name=model_name):
|
35
|
+
"""
|
36
|
+
Loads model from the serialized format
|
37
|
+
|
38
|
+
Returns
|
39
|
+
-------
|
40
|
+
model: a model instance on which predict API can be invoked
|
41
|
+
"""
|
42
|
+
model_dir = os.path.dirname(os.path.realpath(__file__))
|
43
|
+
if model_dir not in sys.path:
|
44
|
+
sys.path.insert(0, model_dir)
|
45
|
+
contents = os.listdir(model_dir)
|
46
|
+
if model_file_name not in contents:
|
47
|
+
raise Exception(f'{model_file_name} is not found in model directory {model_dir}')
|
48
|
+
|
49
|
+
{% if model_serializer == "torchscript" %}
|
50
|
+
print(f'Start loading {model_file_name} from model directory {model_dir} ...')
|
51
|
+
the_model = torch.jit.load(os.path.join(model_dir, model_file_name))
|
52
|
+
print(f"loading {model_file_name} is complete.")
|
53
|
+
|
54
|
+
{% elif model_serializer == "torch" %}
|
55
|
+
print(f'Start loading {model_file_name} from model directory {model_dir} ...')
|
56
|
+
model_state_dict = torch.load(os.path.join(model_dir, model_file_name))
|
57
|
+
print(f"loading {model_file_name} is complete.")
|
58
|
+
|
59
|
+
# User would need to provide reference to the TheModelClass and
|
60
|
+
# construct the the_model instance first before loading the parameters.
|
61
|
+
# the_model = TheModelClass(*args, **kwargs)
|
62
|
+
try:
|
63
|
+
the_model.load_state_dict(model_state_dict)
|
64
|
+
except NameError as e:
|
65
|
+
raise NotImplementedError("TheModelClass instance must be constructed before loading the parameters. Please modify the load_model() function in score.py." )
|
66
|
+
except Exception as e:
|
67
|
+
raise e
|
68
|
+
|
69
|
+
the_model.eval()
|
70
|
+
{% else %}
|
71
|
+
raise NotImplementedError("`load_model()` needs to be implemented.")
|
72
|
+
|
73
|
+
{% endif %}
|
74
|
+
print("Model is successfully loaded.")
|
75
|
+
the_model = the_model.to(device)
|
76
|
+
return the_model
|
77
|
+
|
78
|
+
@lru_cache(maxsize=1)
def fetch_data_type_from_schema(input_schema_path=os.path.join(os.path.dirname(os.path.realpath(__file__)), "input_schema.json")):
    """
    Returns data type information fetched from input_schema.json.

    Parameters
    ----------
    input_schema_path: path of input schema.

    Returns
    -------
    data_type: dict mapping each column ``name`` to its ``dtype`` as listed
        under the ``schema`` key of the file; empty when the file is missing.

    """
    data_type = {}
    if os.path.exists(input_schema_path):
        # Context manager so the schema file handle is closed deterministically.
        with open(input_schema_path) as schema_file:
            schema = json.load(schema_file)
        for col in schema['schema']:
            data_type[col['name']] = col['dtype']
    else:
        print("input_schema has to be passed in in order to recover the same data type. pass `X_sample` in `ads.model.framework.pytorch_model.PyTorchModel.prepare` function to generate the input_schema. Otherwise, the data type might be changed after serialization/deserialization.")
    return data_type
|
100
|
+
|
101
|
+
def deserialize(data, input_schema_path):
|
102
|
+
"""
|
103
|
+
Deserialize json-serialized data to data in original type when sent to
|
104
|
+
predict.
|
105
|
+
|
106
|
+
Parameters
|
107
|
+
----------
|
108
|
+
data: serialized input data.
|
109
|
+
input_schema_path: path of input schema.
|
110
|
+
|
111
|
+
Returns
|
112
|
+
-------
|
113
|
+
data: deserialized input data.
|
114
|
+
|
115
|
+
"""
|
116
|
+
{% if data_deserializer == "json" %}
|
117
|
+
if isinstance(data, bytes):
|
118
|
+
logging.warning(
|
119
|
+
"bytes are passed directly to the model. If the model expects a specific data format, you need to write the conversion logic in `deserialize()` yourself."
|
120
|
+
)
|
121
|
+
return data
|
122
|
+
|
123
|
+
data_type = data.get('data_type', '') if isinstance(data, dict) else ''
|
124
|
+
json_data = data.get('data', data) if isinstance(data, dict) else data
|
125
|
+
|
126
|
+
if "numpy.ndarray" in data_type:
|
127
|
+
load_bytes = BytesIO(base64.b64decode(json_data.encode('utf-8')))
|
128
|
+
return np.load(load_bytes, allow_pickle=True)
|
129
|
+
if "torch.Tensor" in data_type:
|
130
|
+
load_bytes = BytesIO(base64.b64decode(json_data.encode('utf-8')))
|
131
|
+
return torch.load(load_bytes)
|
132
|
+
if "pandas.core.series.Series" in data_type:
|
133
|
+
return pd.Series(json_data)
|
134
|
+
if "pandas.core.frame.DataFrame" in data_type or isinstance(json_data, str):
|
135
|
+
return pd.read_json(json_data, dtype=fetch_data_type_from_schema(input_schema_path))
|
136
|
+
|
137
|
+
return json_data
|
138
|
+
|
139
|
+
{% elif data_deserializer == "cloudpickle" %}
|
140
|
+
import cloudpickle
|
141
|
+
from pickle import UnpicklingError
|
142
|
+
deserialized_data = data
|
143
|
+
try:
|
144
|
+
deserialized_data = cloudpickle.loads(data)
|
145
|
+
except TypeError:
|
146
|
+
pass
|
147
|
+
except UnpicklingError:
|
148
|
+
logger.warning(
|
149
|
+
"bytes are passed directly to the model. If the model expects a specific data format, you need to write the conversion logic in `deserialize()` yourself."
|
150
|
+
)
|
151
|
+
|
152
|
+
return deserialized_data
|
153
|
+
|
154
|
+
{% else %}
|
155
|
+
# Add further data deserialization if needed
|
156
|
+
return data
|
157
|
+
{% endif %}
|
158
|
+
|
159
|
+
def pre_inference(data, input_schema_path):
    """
    Deserialize the payload and move it to the inference device.

    Parameters
    ----------
    data: Data format as expected by the predict API of the core estimator.
    input_schema_path: path of input schema.

    Returns
    -------
    data: Data format after any processing.
    """
    deserialized = deserialize(data, input_schema_path)

    # Add further data preprocessing if needed
    return deserialized.to(device)
|
177
|
+
|
178
|
+
def post_inference(yhat):
    """
    Post-process the model results.

    Parameters
    ----------
    yhat: Data format after calling model.predict.

    Returns
    -------
    yhat: a nested list when ``yhat`` is a ``torch.Tensor``; otherwise the
        value is returned unchanged.

    """
    # Add further data postprocessing if needed
    return yhat.tolist() if isinstance(yhat, torch.Tensor) else yhat
|
196
|
+
|
197
|
+
def predict(data, model=load_model(), input_schema_path=os.path.join(os.path.dirname(os.path.realpath(__file__)), "input_schema.json")):
    """
    Run the model on the given data with gradients disabled.

    Parameters
    ----------
    model: Model instance returned by load_model API
    data: Data format as expected by the predict API of the core estimator
    input_schema_path: path of input schema.

    Returns
    -------
    predictions: Output from scoring server
        Format: {'prediction': output from the model call after post-processing}

    """
    inputs = pre_inference(data, input_schema_path)

    with torch.no_grad():
        # Bring the raw output back to the CPU before post-processing.
        cpu_output = model(inputs).to("cpu")
        yhat = post_inference(cpu_output)
    return {'prediction': yhat}
|
@@ -0,0 +1,184 @@
|
|
1
|
+
# score.py {{SCORE_VERSION}} generated by ADS {{ADS_VERSION}} on {{time_created}}
|
2
|
+
import os
|
3
|
+
import sys
|
4
|
+
import json
|
5
|
+
import pandas as pd
|
6
|
+
import numpy as np
|
7
|
+
from functools import lru_cache
|
8
|
+
from io import BytesIO
|
9
|
+
import base64
|
10
|
+
import logging
|
11
|
+
|
12
|
+
model_name = '{{model_file_name}}'
|
13
|
+
|
14
|
+
|
15
|
+
"""
|
16
|
+
Inference script. This script is used for prediction by scoring server when schema is known.
|
17
|
+
"""
|
18
|
+
|
19
|
+
|
20
|
+
@lru_cache(maxsize=10)
|
21
|
+
def load_model(model_file_name=model_name):
|
22
|
+
"""
|
23
|
+
Loads model from the serialized format
|
24
|
+
|
25
|
+
Returns
|
26
|
+
-------
|
27
|
+
model: a model instance on which predict API can be invoked
|
28
|
+
"""
|
29
|
+
model_dir = os.path.dirname(os.path.realpath(__file__))
|
30
|
+
if model_dir not in sys.path:
|
31
|
+
sys.path.insert(0, model_dir)
|
32
|
+
contents = os.listdir(model_dir)
|
33
|
+
if model_file_name not in contents:
|
34
|
+
raise Exception(f'{model_file_name} is not found in model directory {model_dir}')
|
35
|
+
|
36
|
+
{% if model_serializer == "joblib" %}
|
37
|
+
print(f'Start loading {model_file_name} from model directory {model_dir} ...')
|
38
|
+
from joblib import load
|
39
|
+
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), model_file_name), "rb") as file:
|
40
|
+
loaded_model = load(file)
|
41
|
+
|
42
|
+
{% elif model_serializer == "cloudpickle" %}
|
43
|
+
import cloudpickle
|
44
|
+
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), model_file_name), "rb") as file:
|
45
|
+
loaded_model = cloudpickle.load(file)
|
46
|
+
|
47
|
+
{% else %}
|
48
|
+
raise NotImplementedError("`load_model()` needs to be implemented.")
|
49
|
+
|
50
|
+
{% endif %}
|
51
|
+
print("Model is successfully loaded.")
|
52
|
+
return loaded_model
|
53
|
+
|
54
|
+
|
55
|
+
@lru_cache(maxsize=1)
def fetch_data_type_from_schema(input_schema_path=os.path.join(os.path.dirname(os.path.realpath(__file__)), "input_schema.json")):
    """
    Returns data type information fetched from input_schema.json.

    Parameters
    ----------
    input_schema_path: path of input schema.

    Returns
    -------
    data_type: dict mapping each column ``name`` to its ``dtype`` as listed
        under the ``schema`` key of the file; empty when the file is missing.

    """
    data_type = {}
    if os.path.exists(input_schema_path):
        # Context manager so the schema file handle is closed deterministically.
        with open(input_schema_path) as schema_file:
            schema = json.load(schema_file)
        for col in schema['schema']:
            data_type[col['name']] = col['dtype']
    else:
        print("input_schema has to be passed in in order to recover the same data type. pass `X_sample` in `ads.model.framework.sklearn_model.SklearnModel.prepare` function to generate the input_schema. Otherwise, the data type might be changed after serialization/deserialization.")
    return data_type
|
77
|
+
|
78
|
+
def deserialize(data, input_schema_path):
|
79
|
+
"""
|
80
|
+
Deserialize json serialization data to data in original type when sent to predict.
|
81
|
+
|
82
|
+
Parameters
|
83
|
+
----------
|
84
|
+
data: serialized input data.
|
85
|
+
input_schema_path: path of input schema.
|
86
|
+
|
87
|
+
Returns
|
88
|
+
-------
|
89
|
+
data: deserialized input data.
|
90
|
+
|
91
|
+
"""
|
92
|
+
{% if data_deserializer == "json" %}
|
93
|
+
if isinstance(data, bytes):
|
94
|
+
logging.warning(
|
95
|
+
"bytes are passed directly to the model. If the model expects a specific data format, you need to write the conversion logic in `deserialize()` yourself."
|
96
|
+
)
|
97
|
+
return data
|
98
|
+
|
99
|
+
data_type = data.get('data_type', '') if isinstance(data, dict) else ''
|
100
|
+
json_data = data.get('data', data) if isinstance(data, dict) else data
|
101
|
+
|
102
|
+
if "numpy.ndarray" in data_type:
|
103
|
+
load_bytes = BytesIO(base64.b64decode(json_data.encode('utf-8')))
|
104
|
+
return np.load(load_bytes, allow_pickle=True)
|
105
|
+
if "pandas.core.series.Series" in data_type:
|
106
|
+
return pd.Series(json_data)
|
107
|
+
if "pandas.core.frame.DataFrame" in data_type or isinstance(json_data, str):
|
108
|
+
return pd.read_json(json_data, dtype=fetch_data_type_from_schema(input_schema_path))
|
109
|
+
if isinstance(json_data, dict):
|
110
|
+
return pd.DataFrame.from_dict(json_data)
|
111
|
+
|
112
|
+
return json_data
|
113
|
+
|
114
|
+
{% elif data_deserializer == "cloudpickle" %}
|
115
|
+
import cloudpickle
|
116
|
+
from pickle import UnpicklingError
|
117
|
+
deserialized_data = data
|
118
|
+
try:
|
119
|
+
deserialized_data = cloudpickle.loads(data)
|
120
|
+
except TypeError:
|
121
|
+
pass
|
122
|
+
except UnpicklingError:
|
123
|
+
logger.warning(
|
124
|
+
"bytes are passed directly to the model. If the model expects a specific data format, you need to write the conversion logic in `deserialize()` yourself."
|
125
|
+
)
|
126
|
+
|
127
|
+
return deserialized_data
|
128
|
+
|
129
|
+
{% else %}
|
130
|
+
# Add further data deserialization if needed
|
131
|
+
return data
|
132
|
+
{% endif %}
|
133
|
+
|
134
|
+
def pre_inference(data, input_schema_path):
    """
    Prepare the incoming payload for the model by deserializing it.

    Parameters
    ----------
    data: Data format as expected by the predict API of the core estimator.
    input_schema_path: path of input schema.

    Returns
    -------
    data: Data format after any processing.

    """
    return deserialize(data, input_schema_path)
|
150
|
+
|
151
|
+
def post_inference(yhat):
    """
    Post-process the model results.

    Parameters
    ----------
    yhat: Data format after calling model.predict; must provide ``tolist()``.

    Returns
    -------
    yhat: the predictions converted with ``tolist()``.

    """
    predictions = yhat.tolist()
    return predictions
|
165
|
+
|
166
|
+
def predict(data, model=load_model(), input_schema_path=os.path.join(os.path.dirname(os.path.realpath(__file__)), "input_schema.json")):
    """
    Run ``model.predict`` on the given data and return its predictions.

    Parameters
    ----------
    model: Model instance returned by load_model API
    data: Data format as expected by the predict API of the core estimator. For example, for scikit-learn models it could be numpy array/List of list/Pandas DataFrame
    input_schema_path: path of input schema.

    Returns
    -------
    predictions: Output from scoring server
        Format: {'prediction': output from model.predict method}

    """
    features = pre_inference(data, input_schema_path)
    raw_predictions = model.predict(features)
    yhat = post_inference(raw_predictions)
    return {'prediction': yhat}
|