oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +40 -0
- ads/aqua/app.py +507 -0
- ads/aqua/cli.py +96 -0
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +836 -0
- ads/aqua/client/openai_client.py +305 -0
- ads/aqua/common/__init__.py +5 -0
- ads/aqua/common/decorator.py +125 -0
- ads/aqua/common/entities.py +274 -0
- ads/aqua/common/enums.py +134 -0
- ads/aqua/common/errors.py +109 -0
- ads/aqua/common/utils.py +1295 -0
- ads/aqua/config/__init__.py +4 -0
- ads/aqua/config/container_config.py +246 -0
- ads/aqua/config/evaluation/__init__.py +4 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
- ads/aqua/config/utils/__init__.py +4 -0
- ads/aqua/config/utils/serializer.py +339 -0
- ads/aqua/constants.py +116 -0
- ads/aqua/data.py +14 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation/__init__.py +8 -0
- ads/aqua/evaluation/constants.py +53 -0
- ads/aqua/evaluation/entities.py +186 -0
- ads/aqua/evaluation/errors.py +70 -0
- ads/aqua/evaluation/evaluation.py +1814 -0
- ads/aqua/extension/__init__.py +42 -0
- ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
- ads/aqua/extension/base_handler.py +90 -0
- ads/aqua/extension/common_handler.py +121 -0
- ads/aqua/extension/common_ws_msg_handler.py +36 -0
- ads/aqua/extension/deployment_handler.py +381 -0
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +30 -0
- ads/aqua/extension/evaluation_handler.py +129 -0
- ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
- ads/aqua/extension/finetune_handler.py +96 -0
- ads/aqua/extension/model_handler.py +390 -0
- ads/aqua/extension/models/__init__.py +0 -0
- ads/aqua/extension/models/ws_models.py +145 -0
- ads/aqua/extension/models_ws_msg_handler.py +50 -0
- ads/aqua/extension/ui_handler.py +300 -0
- ads/aqua/extension/ui_websocket_handler.py +130 -0
- ads/aqua/extension/utils.py +133 -0
- ads/aqua/finetuning/__init__.py +7 -0
- ads/aqua/finetuning/constants.py +23 -0
- ads/aqua/finetuning/entities.py +181 -0
- ads/aqua/finetuning/finetuning.py +749 -0
- ads/aqua/model/__init__.py +8 -0
- ads/aqua/model/constants.py +60 -0
- ads/aqua/model/entities.py +385 -0
- ads/aqua/model/enums.py +32 -0
- ads/aqua/model/model.py +2134 -0
- ads/aqua/model/utils.py +52 -0
- ads/aqua/modeldeployment/__init__.py +6 -0
- ads/aqua/modeldeployment/constants.py +10 -0
- ads/aqua/modeldeployment/deployment.py +1315 -0
- ads/aqua/modeldeployment/entities.py +653 -0
- ads/aqua/modeldeployment/utils.py +543 -0
- ads/aqua/resources/gpu_shapes_index.json +94 -0
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +476 -0
- ads/aqua/ui.py +519 -0
- ads/automl/__init__.py +9 -0
- ads/automl/driver.py +330 -0
- ads/automl/provider.py +975 -0
- ads/bds/__init__.py +5 -0
- ads/bds/auth.py +127 -0
- ads/bds/big_data_service.py +255 -0
- ads/catalog/__init__.py +19 -0
- ads/catalog/model.py +1576 -0
- ads/catalog/notebook.py +461 -0
- ads/catalog/project.py +468 -0
- ads/catalog/summary.py +178 -0
- ads/common/__init__.py +11 -0
- ads/common/analyzer.py +65 -0
- ads/common/artifact/.model-ignore +63 -0
- ads/common/artifact/__init__.py +10 -0
- ads/common/auth.py +1122 -0
- ads/common/card_identifier.py +83 -0
- ads/common/config.py +647 -0
- ads/common/data.py +165 -0
- ads/common/decorator/__init__.py +9 -0
- ads/common/decorator/argument_to_case.py +88 -0
- ads/common/decorator/deprecate.py +69 -0
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/decorator/runtime_dependency.py +178 -0
- ads/common/decorator/threaded.py +97 -0
- ads/common/decorator/utils.py +35 -0
- ads/common/dsc_file_system.py +303 -0
- ads/common/error.py +14 -0
- ads/common/extended_enum.py +81 -0
- ads/common/function/__init__.py +5 -0
- ads/common/function/fn_util.py +142 -0
- ads/common/function/func_conf.yaml +25 -0
- ads/common/ipython.py +76 -0
- ads/common/model.py +679 -0
- ads/common/model_artifact.py +1759 -0
- ads/common/model_artifact_schema.json +107 -0
- ads/common/model_export_util.py +664 -0
- ads/common/model_metadata.py +24 -0
- ads/common/object_storage_details.py +296 -0
- ads/common/oci_client.py +179 -0
- ads/common/oci_datascience.py +46 -0
- ads/common/oci_logging.py +1144 -0
- ads/common/oci_mixin.py +957 -0
- ads/common/oci_resource.py +136 -0
- ads/common/serializer.py +559 -0
- ads/common/utils.py +1852 -0
- ads/common/word_lists.py +1491 -0
- ads/common/work_request.py +189 -0
- ads/config.py +1 -0
- ads/data_labeling/__init__.py +13 -0
- ads/data_labeling/boundingbox.py +253 -0
- ads/data_labeling/constants.py +47 -0
- ads/data_labeling/data_labeling_service.py +244 -0
- ads/data_labeling/interface/__init__.py +5 -0
- ads/data_labeling/interface/loader.py +16 -0
- ads/data_labeling/interface/parser.py +16 -0
- ads/data_labeling/interface/reader.py +23 -0
- ads/data_labeling/loader/__init__.py +5 -0
- ads/data_labeling/loader/file_loader.py +241 -0
- ads/data_labeling/metadata.py +110 -0
- ads/data_labeling/mixin/__init__.py +5 -0
- ads/data_labeling/mixin/data_labeling.py +232 -0
- ads/data_labeling/ner.py +129 -0
- ads/data_labeling/parser/__init__.py +5 -0
- ads/data_labeling/parser/dls_record_parser.py +388 -0
- ads/data_labeling/parser/export_metadata_parser.py +94 -0
- ads/data_labeling/parser/export_record_parser.py +473 -0
- ads/data_labeling/reader/__init__.py +5 -0
- ads/data_labeling/reader/dataset_reader.py +574 -0
- ads/data_labeling/reader/dls_record_reader.py +121 -0
- ads/data_labeling/reader/export_record_reader.py +62 -0
- ads/data_labeling/reader/jsonl_reader.py +75 -0
- ads/data_labeling/reader/metadata_reader.py +203 -0
- ads/data_labeling/reader/record_reader.py +263 -0
- ads/data_labeling/record.py +52 -0
- ads/data_labeling/visualizer/__init__.py +5 -0
- ads/data_labeling/visualizer/image_visualizer.py +525 -0
- ads/data_labeling/visualizer/text_visualizer.py +357 -0
- ads/database/__init__.py +5 -0
- ads/database/connection.py +338 -0
- ads/dataset/__init__.py +10 -0
- ads/dataset/capabilities.md +51 -0
- ads/dataset/classification_dataset.py +339 -0
- ads/dataset/correlation.py +226 -0
- ads/dataset/correlation_plot.py +563 -0
- ads/dataset/dask_series.py +173 -0
- ads/dataset/dataframe_transformer.py +110 -0
- ads/dataset/dataset.py +1979 -0
- ads/dataset/dataset_browser.py +360 -0
- ads/dataset/dataset_with_target.py +995 -0
- ads/dataset/exception.py +25 -0
- ads/dataset/factory.py +987 -0
- ads/dataset/feature_engineering_transformer.py +35 -0
- ads/dataset/feature_selection.py +107 -0
- ads/dataset/forecasting_dataset.py +26 -0
- ads/dataset/helper.py +1450 -0
- ads/dataset/label_encoder.py +99 -0
- ads/dataset/mixin/__init__.py +5 -0
- ads/dataset/mixin/dataset_accessor.py +134 -0
- ads/dataset/pipeline.py +58 -0
- ads/dataset/plot.py +710 -0
- ads/dataset/progress.py +86 -0
- ads/dataset/recommendation.py +297 -0
- ads/dataset/recommendation_transformer.py +502 -0
- ads/dataset/regression_dataset.py +14 -0
- ads/dataset/sampled_dataset.py +1050 -0
- ads/dataset/target.py +98 -0
- ads/dataset/timeseries.py +18 -0
- ads/dbmixin/__init__.py +5 -0
- ads/dbmixin/db_pandas_accessor.py +153 -0
- ads/environment/__init__.py +9 -0
- ads/environment/ml_runtime.py +66 -0
- ads/evaluations/README.md +14 -0
- ads/evaluations/__init__.py +109 -0
- ads/evaluations/evaluation_plot.py +983 -0
- ads/evaluations/evaluator.py +1334 -0
- ads/evaluations/statistical_metrics.py +543 -0
- ads/experiments/__init__.py +9 -0
- ads/experiments/capabilities.md +0 -0
- ads/explanations/__init__.py +21 -0
- ads/explanations/base_explainer.py +142 -0
- ads/explanations/capabilities.md +83 -0
- ads/explanations/explainer.py +190 -0
- ads/explanations/mlx_global_explainer.py +1050 -0
- ads/explanations/mlx_interface.py +386 -0
- ads/explanations/mlx_local_explainer.py +287 -0
- ads/explanations/mlx_whatif_explainer.py +201 -0
- ads/feature_engineering/__init__.py +20 -0
- ads/feature_engineering/accessor/__init__.py +5 -0
- ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
- ads/feature_engineering/accessor/mixin/__init__.py +5 -0
- ads/feature_engineering/accessor/mixin/correlation.py +166 -0
- ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
- ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
- ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
- ads/feature_engineering/accessor/mixin/utils.py +65 -0
- ads/feature_engineering/accessor/series_accessor.py +431 -0
- ads/feature_engineering/adsimage/__init__.py +5 -0
- ads/feature_engineering/adsimage/image.py +192 -0
- ads/feature_engineering/adsimage/image_reader.py +170 -0
- ads/feature_engineering/adsimage/interface/__init__.py +5 -0
- ads/feature_engineering/adsimage/interface/reader.py +19 -0
- ads/feature_engineering/adsstring/__init__.py +7 -0
- ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
- ads/feature_engineering/adsstring/string/__init__.py +8 -0
- ads/feature_engineering/data_schema.json +57 -0
- ads/feature_engineering/dataset/__init__.py +5 -0
- ads/feature_engineering/dataset/zip_code_data.py +42062 -0
- ads/feature_engineering/exceptions.py +40 -0
- ads/feature_engineering/feature_type/__init__.py +133 -0
- ads/feature_engineering/feature_type/address.py +184 -0
- ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
- ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
- ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
- ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
- ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
- ads/feature_engineering/feature_type/adsstring/string.py +258 -0
- ads/feature_engineering/feature_type/base.py +58 -0
- ads/feature_engineering/feature_type/boolean.py +183 -0
- ads/feature_engineering/feature_type/category.py +146 -0
- ads/feature_engineering/feature_type/constant.py +137 -0
- ads/feature_engineering/feature_type/continuous.py +151 -0
- ads/feature_engineering/feature_type/creditcard.py +314 -0
- ads/feature_engineering/feature_type/datetime.py +190 -0
- ads/feature_engineering/feature_type/discrete.py +134 -0
- ads/feature_engineering/feature_type/document.py +43 -0
- ads/feature_engineering/feature_type/gis.py +251 -0
- ads/feature_engineering/feature_type/handler/__init__.py +5 -0
- ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
- ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
- ads/feature_engineering/feature_type/handler/warnings.py +128 -0
- ads/feature_engineering/feature_type/integer.py +142 -0
- ads/feature_engineering/feature_type/ip_address.py +144 -0
- ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
- ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
- ads/feature_engineering/feature_type/lat_long.py +256 -0
- ads/feature_engineering/feature_type/object.py +43 -0
- ads/feature_engineering/feature_type/ordinal.py +132 -0
- ads/feature_engineering/feature_type/phone_number.py +135 -0
- ads/feature_engineering/feature_type/string.py +171 -0
- ads/feature_engineering/feature_type/text.py +93 -0
- ads/feature_engineering/feature_type/unknown.py +43 -0
- ads/feature_engineering/feature_type/zip_code.py +164 -0
- ads/feature_engineering/feature_type_manager.py +406 -0
- ads/feature_engineering/schema.py +795 -0
- ads/feature_engineering/utils.py +245 -0
- ads/feature_store/.readthedocs.yaml +19 -0
- ads/feature_store/README.md +65 -0
- ads/feature_store/__init__.py +9 -0
- ads/feature_store/common/__init__.py +0 -0
- ads/feature_store/common/enums.py +339 -0
- ads/feature_store/common/exceptions.py +18 -0
- ads/feature_store/common/spark_session_singleton.py +125 -0
- ads/feature_store/common/utils/__init__.py +0 -0
- ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
- ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
- ads/feature_store/common/utils/transformation_utils.py +82 -0
- ads/feature_store/common/utils/utility.py +403 -0
- ads/feature_store/data_validation/__init__.py +0 -0
- ads/feature_store/data_validation/great_expectation.py +129 -0
- ads/feature_store/dataset.py +1230 -0
- ads/feature_store/dataset_job.py +530 -0
- ads/feature_store/docs/Dockerfile +7 -0
- ads/feature_store/docs/Makefile +44 -0
- ads/feature_store/docs/conf.py +28 -0
- ads/feature_store/docs/requirements.txt +14 -0
- ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
- ads/feature_store/docs/source/cicd.rst +137 -0
- ads/feature_store/docs/source/conf.py +86 -0
- ads/feature_store/docs/source/data_versioning.rst +33 -0
- ads/feature_store/docs/source/dataset.rst +388 -0
- ads/feature_store/docs/source/dataset_job.rst +27 -0
- ads/feature_store/docs/source/demo.rst +70 -0
- ads/feature_store/docs/source/entity.rst +78 -0
- ads/feature_store/docs/source/feature_group.rst +624 -0
- ads/feature_store/docs/source/feature_group_job.rst +29 -0
- ads/feature_store/docs/source/feature_store.rst +122 -0
- ads/feature_store/docs/source/feature_store_class.rst +123 -0
- ads/feature_store/docs/source/feature_validation.rst +66 -0
- ads/feature_store/docs/source/figures/cicd.png +0 -0
- ads/feature_store/docs/source/figures/data_validation.png +0 -0
- ads/feature_store/docs/source/figures/data_versioning.png +0 -0
- ads/feature_store/docs/source/figures/dataset.gif +0 -0
- ads/feature_store/docs/source/figures/dataset.png +0 -0
- ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
- ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
- ads/feature_store/docs/source/figures/entity.png +0 -0
- ads/feature_store/docs/source/figures/feature_group.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
- ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
- ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
- ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
- ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
- ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
- ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
- ads/feature_store/docs/source/figures/overview.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
- ads/feature_store/docs/source/figures/stats_1.png +0 -0
- ads/feature_store/docs/source/figures/stats_2.png +0 -0
- ads/feature_store/docs/source/figures/stats_d.png +0 -0
- ads/feature_store/docs/source/figures/stats_fg.png +0 -0
- ads/feature_store/docs/source/figures/transformation.png +0 -0
- ads/feature_store/docs/source/figures/transformations.gif +0 -0
- ads/feature_store/docs/source/figures/validation.png +0 -0
- ads/feature_store/docs/source/figures/validation_fg.png +0 -0
- ads/feature_store/docs/source/figures/validation_results.png +0 -0
- ads/feature_store/docs/source/figures/validation_summary.png +0 -0
- ads/feature_store/docs/source/index.rst +81 -0
- ads/feature_store/docs/source/module.rst +8 -0
- ads/feature_store/docs/source/notebook.rst +94 -0
- ads/feature_store/docs/source/overview.rst +47 -0
- ads/feature_store/docs/source/quickstart.rst +176 -0
- ads/feature_store/docs/source/release_notes.rst +194 -0
- ads/feature_store/docs/source/setup_feature_store.rst +81 -0
- ads/feature_store/docs/source/statistics.rst +58 -0
- ads/feature_store/docs/source/transformation.rst +199 -0
- ads/feature_store/docs/source/ui.rst +65 -0
- ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
- ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
- ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
- ads/feature_store/entity.py +718 -0
- ads/feature_store/execution_strategy/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
- ads/feature_store/execution_strategy/engine/__init__.py +0 -0
- ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
- ads/feature_store/execution_strategy/execution_strategy.py +113 -0
- ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
- ads/feature_store/execution_strategy/spark/__init__.py +0 -0
- ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
- ads/feature_store/feature.py +192 -0
- ads/feature_store/feature_group.py +1494 -0
- ads/feature_store/feature_group_expectation.py +346 -0
- ads/feature_store/feature_group_job.py +602 -0
- ads/feature_store/feature_lineage/__init__.py +0 -0
- ads/feature_store/feature_lineage/graphviz_service.py +180 -0
- ads/feature_store/feature_option_details.py +50 -0
- ads/feature_store/feature_statistics/__init__.py +0 -0
- ads/feature_store/feature_statistics/statistics_service.py +99 -0
- ads/feature_store/feature_store.py +699 -0
- ads/feature_store/feature_store_registrar.py +518 -0
- ads/feature_store/input_feature_detail.py +149 -0
- ads/feature_store/mixin/__init__.py +4 -0
- ads/feature_store/mixin/oci_feature_store.py +145 -0
- ads/feature_store/model_details.py +73 -0
- ads/feature_store/query/__init__.py +0 -0
- ads/feature_store/query/filter.py +266 -0
- ads/feature_store/query/generator/__init__.py +0 -0
- ads/feature_store/query/generator/query_generator.py +298 -0
- ads/feature_store/query/join.py +161 -0
- ads/feature_store/query/query.py +403 -0
- ads/feature_store/query/validator/__init__.py +0 -0
- ads/feature_store/query/validator/query_validator.py +57 -0
- ads/feature_store/response/__init__.py +0 -0
- ads/feature_store/response/response_builder.py +68 -0
- ads/feature_store/service/__init__.py +0 -0
- ads/feature_store/service/oci_dataset.py +139 -0
- ads/feature_store/service/oci_dataset_job.py +199 -0
- ads/feature_store/service/oci_entity.py +125 -0
- ads/feature_store/service/oci_feature_group.py +164 -0
- ads/feature_store/service/oci_feature_group_job.py +214 -0
- ads/feature_store/service/oci_feature_store.py +182 -0
- ads/feature_store/service/oci_lineage.py +87 -0
- ads/feature_store/service/oci_transformation.py +104 -0
- ads/feature_store/statistics/__init__.py +0 -0
- ads/feature_store/statistics/abs_feature_value.py +49 -0
- ads/feature_store/statistics/charts/__init__.py +0 -0
- ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
- ads/feature_store/statistics/charts/box_plot.py +148 -0
- ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
- ads/feature_store/statistics/charts/probability_distribution.py +68 -0
- ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
- ads/feature_store/statistics/feature_stat.py +126 -0
- ads/feature_store/statistics/generic_feature_value.py +33 -0
- ads/feature_store/statistics/statistics.py +41 -0
- ads/feature_store/statistics_config.py +101 -0
- ads/feature_store/templates/feature_store_template.yaml +45 -0
- ads/feature_store/transformation.py +499 -0
- ads/feature_store/validation_output.py +57 -0
- ads/hpo/__init__.py +9 -0
- ads/hpo/_imports.py +91 -0
- ads/hpo/ads_search_space.py +439 -0
- ads/hpo/distributions.py +325 -0
- ads/hpo/objective.py +280 -0
- ads/hpo/search_cv.py +1657 -0
- ads/hpo/stopping_criterion.py +75 -0
- ads/hpo/tuner_artifact.py +413 -0
- ads/hpo/utils.py +91 -0
- ads/hpo/validation.py +140 -0
- ads/hpo/visualization/__init__.py +5 -0
- ads/hpo/visualization/_contour.py +23 -0
- ads/hpo/visualization/_edf.py +20 -0
- ads/hpo/visualization/_intermediate_values.py +21 -0
- ads/hpo/visualization/_optimization_history.py +25 -0
- ads/hpo/visualization/_parallel_coordinate.py +169 -0
- ads/hpo/visualization/_param_importances.py +26 -0
- ads/jobs/__init__.py +53 -0
- ads/jobs/ads_job.py +663 -0
- ads/jobs/builders/__init__.py +5 -0
- ads/jobs/builders/base.py +156 -0
- ads/jobs/builders/infrastructure/__init__.py +6 -0
- ads/jobs/builders/infrastructure/base.py +165 -0
- ads/jobs/builders/infrastructure/dataflow.py +1252 -0
- ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
- ads/jobs/builders/infrastructure/utils.py +65 -0
- ads/jobs/builders/runtimes/__init__.py +5 -0
- ads/jobs/builders/runtimes/artifact.py +338 -0
- ads/jobs/builders/runtimes/base.py +325 -0
- ads/jobs/builders/runtimes/container_runtime.py +242 -0
- ads/jobs/builders/runtimes/python_runtime.py +1016 -0
- ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
- ads/jobs/cli.py +104 -0
- ads/jobs/env_var_parser.py +131 -0
- ads/jobs/extension.py +160 -0
- ads/jobs/schema/__init__.py +5 -0
- ads/jobs/schema/infrastructure_schema.json +116 -0
- ads/jobs/schema/job_schema.json +42 -0
- ads/jobs/schema/runtime_schema.json +183 -0
- ads/jobs/schema/validator.py +141 -0
- ads/jobs/serializer.py +296 -0
- ads/jobs/templates/__init__.py +5 -0
- ads/jobs/templates/container.py +6 -0
- ads/jobs/templates/driver_notebook.py +177 -0
- ads/jobs/templates/driver_oci.py +500 -0
- ads/jobs/templates/driver_python.py +48 -0
- ads/jobs/templates/driver_pytorch.py +852 -0
- ads/jobs/templates/driver_utils.py +615 -0
- ads/jobs/templates/hostname_from_env.c +55 -0
- ads/jobs/templates/oci_metrics.py +181 -0
- ads/jobs/utils.py +104 -0
- ads/llm/__init__.py +28 -0
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/v02/client.py +295 -0
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/chain.py +268 -0
- ads/llm/chat_template.py +31 -0
- ads/llm/deploy.py +63 -0
- ads/llm/guardrails/__init__.py +5 -0
- ads/llm/guardrails/base.py +442 -0
- ads/llm/guardrails/huggingface.py +44 -0
- ads/llm/langchain/__init__.py +5 -0
- ads/llm/langchain/plugins/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/llm/langchain/plugins/llms/__init__.py +5 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
- ads/llm/requirements.txt +3 -0
- ads/llm/serialize.py +219 -0
- ads/llm/serializers/__init__.py +0 -0
- ads/llm/serializers/retrieval_qa.py +153 -0
- ads/llm/serializers/runnable_parallel.py +27 -0
- ads/llm/templates/score_chain.jinja2 +155 -0
- ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
- ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
- ads/model/__init__.py +52 -0
- ads/model/artifact.py +573 -0
- ads/model/artifact_downloader.py +254 -0
- ads/model/artifact_uploader.py +267 -0
- ads/model/base_properties.py +238 -0
- ads/model/common/.model-ignore +66 -0
- ads/model/common/__init__.py +5 -0
- ads/model/common/utils.py +142 -0
- ads/model/datascience_model.py +2635 -0
- ads/model/deployment/__init__.py +20 -0
- ads/model/deployment/common/__init__.py +5 -0
- ads/model/deployment/common/utils.py +308 -0
- ads/model/deployment/model_deployer.py +466 -0
- ads/model/deployment/model_deployment.py +1846 -0
- ads/model/deployment/model_deployment_infrastructure.py +671 -0
- ads/model/deployment/model_deployment_properties.py +493 -0
- ads/model/deployment/model_deployment_runtime.py +838 -0
- ads/model/extractor/__init__.py +5 -0
- ads/model/extractor/automl_extractor.py +74 -0
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/extractor/huggingface_extractor.py +88 -0
- ads/model/extractor/keras_extractor.py +84 -0
- ads/model/extractor/lightgbm_extractor.py +93 -0
- ads/model/extractor/model_info_extractor.py +114 -0
- ads/model/extractor/model_info_extractor_factory.py +105 -0
- ads/model/extractor/pytorch_extractor.py +87 -0
- ads/model/extractor/sklearn_extractor.py +112 -0
- ads/model/extractor/spark_extractor.py +89 -0
- ads/model/extractor/tensorflow_extractor.py +85 -0
- ads/model/extractor/xgboost_extractor.py +94 -0
- ads/model/framework/__init__.py +5 -0
- ads/model/framework/automl_model.py +178 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/framework/huggingface_model.py +399 -0
- ads/model/framework/lightgbm_model.py +266 -0
- ads/model/framework/pytorch_model.py +266 -0
- ads/model/framework/sklearn_model.py +250 -0
- ads/model/framework/spark_model.py +326 -0
- ads/model/framework/tensorflow_model.py +254 -0
- ads/model/framework/xgboost_model.py +258 -0
- ads/model/generic_model.py +3518 -0
- ads/model/model_artifact_boilerplate/README.md +381 -0
- ads/model/model_artifact_boilerplate/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
- ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
- ads/model/model_artifact_boilerplate/score.py +61 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_introspect.py +331 -0
- ads/model/model_metadata.py +1810 -0
- ads/model/model_metadata_mixin.py +460 -0
- ads/model/model_properties.py +63 -0
- ads/model/model_version_set.py +739 -0
- ads/model/runtime/__init__.py +5 -0
- ads/model/runtime/env_info.py +306 -0
- ads/model/runtime/model_deployment_details.py +37 -0
- ads/model/runtime/model_provenance_details.py +58 -0
- ads/model/runtime/runtime_info.py +81 -0
- ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
- ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
- ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
- ads/model/runtime/utils.py +201 -0
- ads/model/serde/__init__.py +5 -0
- ads/model/serde/common.py +40 -0
- ads/model/serde/model_input.py +547 -0
- ads/model/serde/model_serializer.py +1184 -0
- ads/model/service/__init__.py +5 -0
- ads/model/service/oci_datascience_model.py +1076 -0
- ads/model/service/oci_datascience_model_deployment.py +500 -0
- ads/model/service/oci_datascience_model_version_set.py +176 -0
- ads/model/transformer/__init__.py +5 -0
- ads/model/transformer/onnx_transformer.py +324 -0
- ads/mysqldb/__init__.py +5 -0
- ads/mysqldb/mysql_db.py +227 -0
- ads/opctl/__init__.py +18 -0
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/__init__.py +5 -0
- ads/opctl/backend/ads_dataflow.py +353 -0
- ads/opctl/backend/ads_ml_job.py +710 -0
- ads/opctl/backend/ads_ml_pipeline.py +164 -0
- ads/opctl/backend/ads_model_deployment.py +209 -0
- ads/opctl/backend/base.py +146 -0
- ads/opctl/backend/local.py +1053 -0
- ads/opctl/backend/marketplace/__init__.py +9 -0
- ads/opctl/backend/marketplace/helm_helper.py +173 -0
- ads/opctl/backend/marketplace/local_marketplace.py +271 -0
- ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
- ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
- ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
- ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
- ads/opctl/backend/marketplace/models/__init__.py +5 -0
- ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
- ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
- ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
- ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
- ads/opctl/cli.py +707 -0
- ads/opctl/cmds.py +869 -0
- ads/opctl/conda/__init__.py +5 -0
- ads/opctl/conda/cli.py +193 -0
- ads/opctl/conda/cmds.py +749 -0
- ads/opctl/conda/config.yaml +34 -0
- ads/opctl/conda/manifest_template.yaml +13 -0
- ads/opctl/conda/multipart_uploader.py +188 -0
- ads/opctl/conda/pack.py +89 -0
- ads/opctl/config/__init__.py +5 -0
- ads/opctl/config/base.py +57 -0
- ads/opctl/config/diagnostics/__init__.py +5 -0
- ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
- ads/opctl/config/merger.py +255 -0
- ads/opctl/config/resolver.py +297 -0
- ads/opctl/config/utils.py +79 -0
- ads/opctl/config/validator.py +17 -0
- ads/opctl/config/versioner.py +68 -0
- ads/opctl/config/yaml_parsers/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/base.py +58 -0
- ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
- ads/opctl/constants.py +66 -0
- ads/opctl/decorator/__init__.py +5 -0
- ads/opctl/decorator/common.py +129 -0
- ads/opctl/diagnostics/__init__.py +5 -0
- ads/opctl/diagnostics/__main__.py +25 -0
- ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
- ads/opctl/diagnostics/check_requirements.py +144 -0
- ads/opctl/diagnostics/requirement_exception.py +9 -0
- ads/opctl/distributed/README.md +109 -0
- ads/opctl/distributed/__init__.py +5 -0
- ads/opctl/distributed/certificates.py +32 -0
- ads/opctl/distributed/cli.py +207 -0
- ads/opctl/distributed/cmds.py +731 -0
- ads/opctl/distributed/common/__init__.py +5 -0
- ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
- ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
- ads/opctl/distributed/common/cluster_config_helper.py +103 -0
- ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
- ads/opctl/distributed/common/cluster_runner.py +54 -0
- ads/opctl/distributed/common/framework_factory.py +29 -0
- ads/opctl/docker/Dockerfile.job +103 -0
- ads/opctl/docker/Dockerfile.job.arm +107 -0
- ads/opctl/docker/Dockerfile.job.gpu +175 -0
- ads/opctl/docker/base-env.yaml +13 -0
- ads/opctl/docker/cuda.repo +6 -0
- ads/opctl/docker/operator/.dockerignore +0 -0
- ads/opctl/docker/operator/Dockerfile +41 -0
- ads/opctl/docker/operator/Dockerfile.gpu +85 -0
- ads/opctl/docker/operator/cuda.repo +6 -0
- ads/opctl/docker/operator/environment.yaml +8 -0
- ads/opctl/forecast.py +11 -0
- ads/opctl/index.yaml +3 -0
- ads/opctl/model/__init__.py +5 -0
- ads/opctl/model/cli.py +65 -0
- ads/opctl/model/cmds.py +73 -0
- ads/opctl/operator/README.md +4 -0
- ads/opctl/operator/__init__.py +31 -0
- ads/opctl/operator/cli.py +344 -0
- ads/opctl/operator/cmd.py +596 -0
- ads/opctl/operator/common/__init__.py +5 -0
- ads/opctl/operator/common/backend_factory.py +460 -0
- ads/opctl/operator/common/const.py +27 -0
- ads/opctl/operator/common/data/synthetic.csv +16001 -0
- ads/opctl/operator/common/dictionary_merger.py +148 -0
- ads/opctl/operator/common/errors.py +42 -0
- ads/opctl/operator/common/operator_config.py +99 -0
- ads/opctl/operator/common/operator_loader.py +811 -0
- ads/opctl/operator/common/operator_schema.yaml +130 -0
- ads/opctl/operator/common/operator_yaml_generator.py +152 -0
- ads/opctl/operator/common/utils.py +208 -0
- ads/opctl/operator/lowcode/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
- ads/opctl/operator/lowcode/anomaly/README.md +207 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +167 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
- ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +116 -0
- ads/opctl/operator/lowcode/common/errors.py +47 -0
- ads/opctl/operator/lowcode/common/transformations.py +296 -0
- ads/opctl/operator/lowcode/common/utils.py +384 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
- ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
- ads/opctl/operator/lowcode/forecast/README.md +209 -0
- ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
- ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
- ads/opctl/operator/lowcode/forecast/const.py +92 -0
- ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
- ads/opctl/operator/lowcode/forecast/errors.py +26 -0
- ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
- ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
- ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
- ads/opctl/operator/lowcode/forecast/model/prophet.py +450 -0
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
- ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
- ads/opctl/operator/lowcode/forecast/utils.py +397 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
- ads/opctl/operator/lowcode/pii/MLoperator +17 -0
- ads/opctl/operator/lowcode/pii/README.md +208 -0
- ads/opctl/operator/lowcode/pii/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/__main__.py +78 -0
- ads/opctl/operator/lowcode/pii/cmd.py +39 -0
- ads/opctl/operator/lowcode/pii/constant.py +84 -0
- ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
- ads/opctl/operator/lowcode/pii/errors.py +27 -0
- ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
- ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
- ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
- ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
- ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
- ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
- ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
- ads/opctl/operator/lowcode/pii/model/report.py +487 -0
- ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
- ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
- ads/opctl/operator/lowcode/pii/utils.py +43 -0
- ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
- ads/opctl/operator/lowcode/recommender/README.md +206 -0
- ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
- ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
- ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
- ads/opctl/operator/lowcode/recommender/constant.py +30 -0
- ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
- ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
- ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
- ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
- ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
- ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
- ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
- ads/opctl/operator/lowcode/recommender/utils.py +13 -0
- ads/opctl/operator/runtime/__init__.py +5 -0
- ads/opctl/operator/runtime/const.py +17 -0
- ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
- ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
- ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/runtime.py +115 -0
- ads/opctl/schema.yaml.yml +36 -0
- ads/opctl/script.py +40 -0
- ads/opctl/spark/__init__.py +5 -0
- ads/opctl/spark/cli.py +43 -0
- ads/opctl/spark/cmds.py +147 -0
- ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
- ads/opctl/utils.py +344 -0
- ads/oracledb/__init__.py +5 -0
- ads/oracledb/oracle_db.py +346 -0
- ads/pipeline/__init__.py +39 -0
- ads/pipeline/ads_pipeline.py +2279 -0
- ads/pipeline/ads_pipeline_run.py +772 -0
- ads/pipeline/ads_pipeline_step.py +605 -0
- ads/pipeline/builders/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/custom_script.py +32 -0
- ads/pipeline/cli.py +119 -0
- ads/pipeline/extension.py +291 -0
- ads/pipeline/schema/__init__.py +5 -0
- ads/pipeline/schema/cs_step_schema.json +35 -0
- ads/pipeline/schema/ml_step_schema.json +31 -0
- ads/pipeline/schema/pipeline_schema.json +71 -0
- ads/pipeline/visualizer/__init__.py +5 -0
- ads/pipeline/visualizer/base.py +570 -0
- ads/pipeline/visualizer/graph_renderer.py +272 -0
- ads/pipeline/visualizer/text_renderer.py +84 -0
- ads/secrets/__init__.py +11 -0
- ads/secrets/adb.py +386 -0
- ads/secrets/auth_token.py +86 -0
- ads/secrets/big_data_service.py +365 -0
- ads/secrets/mysqldb.py +149 -0
- ads/secrets/oracledb.py +160 -0
- ads/secrets/secrets.py +407 -0
- ads/telemetry/__init__.py +7 -0
- ads/telemetry/base.py +69 -0
- ads/telemetry/client.py +122 -0
- ads/telemetry/telemetry.py +257 -0
- ads/templates/dataflow_pyspark.jinja2 +13 -0
- ads/templates/dataflow_sparksql.jinja2 +22 -0
- ads/templates/func.jinja2 +20 -0
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score-pkl.jinja2 +173 -0
- ads/templates/score.jinja2 +322 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- ads/templates/score_generic.jinja2 +165 -0
- ads/templates/score_huggingface_pipeline.jinja2 +217 -0
- ads/templates/score_lightgbm.jinja2 +185 -0
- ads/templates/score_onnx.jinja2 +407 -0
- ads/templates/score_onnx_new.jinja2 +473 -0
- ads/templates/score_oracle_automl.jinja2 +185 -0
- ads/templates/score_pyspark.jinja2 +154 -0
- ads/templates/score_pytorch.jinja2 +219 -0
- ads/templates/score_scikit-learn.jinja2 +184 -0
- ads/templates/score_tensorflow.jinja2 +184 -0
- ads/templates/score_xgboost.jinja2 +178 -0
- ads/text_dataset/__init__.py +5 -0
- ads/text_dataset/backends.py +211 -0
- ads/text_dataset/dataset.py +445 -0
- ads/text_dataset/extractor.py +207 -0
- ads/text_dataset/options.py +53 -0
- ads/text_dataset/udfs.py +22 -0
- ads/text_dataset/utils.py +49 -0
- ads/type_discovery/__init__.py +9 -0
- ads/type_discovery/abstract_detector.py +21 -0
- ads/type_discovery/constant_detector.py +41 -0
- ads/type_discovery/continuous_detector.py +54 -0
- ads/type_discovery/credit_card_detector.py +99 -0
- ads/type_discovery/datetime_detector.py +92 -0
- ads/type_discovery/discrete_detector.py +118 -0
- ads/type_discovery/document_detector.py +146 -0
- ads/type_discovery/ip_detector.py +68 -0
- ads/type_discovery/latlon_detector.py +90 -0
- ads/type_discovery/phone_number_detector.py +63 -0
- ads/type_discovery/type_discovery_driver.py +87 -0
- ads/type_discovery/typed_feature.py +594 -0
- ads/type_discovery/unknown_detector.py +41 -0
- ads/type_discovery/zipcode_detector.py +48 -0
- ads/vault/__init__.py +7 -0
- ads/vault/vault.py +237 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10.dist-info}/METADATA +150 -149
- oracle_ads-2.13.10.dist-info/RECORD +858 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10.dist-info}/WHEEL +1 -2
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10.dist-info}/entry_points.txt +2 -1
- oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
- oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,106 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
from ..const import AUTO_SELECT, SpeedAccuracyMode, SupportedModels
|
7
|
+
from ..model_evaluator import ModelEvaluator
|
8
|
+
from ..operator_config import ForecastOperatorConfig
|
9
|
+
from .arima import ArimaOperatorModel
|
10
|
+
from .automlx import AutoMLXOperatorModel
|
11
|
+
from .autots import AutoTSOperatorModel
|
12
|
+
from .base_model import ForecastOperatorBaseModel
|
13
|
+
from .forecast_datasets import ForecastDatasets
|
14
|
+
from .ml_forecast import MLForecastOperatorModel
|
15
|
+
from .neuralprophet import NeuralProphetOperatorModel
|
16
|
+
from .prophet import ProphetOperatorModel
|
17
|
+
|
18
|
+
|
19
|
+
class UnSupportedModelError(Exception):
    """Raised when a forecasting model type has no registered implementation."""

    def __init__(self, model_type: str):
        """
        Build the error message for the offending model type.

        Parameters
        ----------
        model_type: str
            The model type that was requested.
        """
        # The message lists the values reported by ``SupportedModels`` so the
        # caller can see the valid choices.
        message = (
            f"Model: `{model_type}` "
            f"is not supported. Supported models: {SupportedModels.values()}"
        )
        super().__init__(message)
|
25
|
+
|
26
|
+
|
27
|
+
class ForecastOperatorModelFactory:
    """
    Factory that resolves the configured model type to an operator model
    class and instantiates it.
    """

    # Registry of model types and the operator classes implementing them.
    _MAP = {
        SupportedModels.Prophet: ProphetOperatorModel,
        SupportedModels.Arima: ArimaOperatorModel,
        SupportedModels.NeuralProphet: NeuralProphetOperatorModel,
        SupportedModels.LGBForecast: MLForecastOperatorModel,
        SupportedModels.AutoMLX: AutoMLXOperatorModel,
        SupportedModels.AutoTS: AutoTSOperatorModel,
    }

    @classmethod
    def get_model(
        cls, operator_config: ForecastOperatorConfig, datasets: ForecastDatasets
    ) -> ForecastOperatorBaseModel:
        """
        Instantiates the forecasting operator model for the configured model type.

        When the configured model is ``AUTO_SELECT`` the model type is chosen by
        ``auto_select_model`` and ``spec.model_kwargs`` is reset to an empty dict
        afterwards.

        Parameters
        ----------
        operator_config: ForecastOperatorConfig
            The forecasting operator config. It may be mutated: see the notes
            on ``model_kwargs`` above and on the explanations accuracy mode below.
        datasets: ForecastDatasets
            Datasets for predictions

        Returns
        -------
        ForecastOperatorBaseModel
            The forecast operator model.

        Raises
        ------
        UnSupportedModelError
            In case of not supported model.
        """
        model_type = operator_config.spec.model

        if model_type == AUTO_SELECT:
            model_type = cls.auto_select_model(datasets, operator_config)
            # The kwargs were consumed by the selection step; start clean for
            # the model that was picked.
            operator_config.spec.model_kwargs = {}
            # set the explanations accuracy mode to AUTOMLX if the selected model is automlx
            if model_type == SupportedModels.AutoMLX:
                if (
                    operator_config.spec.explanations_accuracy_mode
                    == SpeedAccuracyMode.FAST_APPROXIMATE
                ):
                    operator_config.spec.explanations_accuracy_mode = (
                        SpeedAccuracyMode.AUTOMLX
                    )

        model_cls = cls._MAP.get(model_type)
        if model_cls is None:
            raise UnSupportedModelError(model_type)
        return model_cls(config=operator_config, datasets=datasets)

    @classmethod
    def auto_select_model(
        cls, datasets: ForecastDatasets, operator_config: ForecastOperatorConfig
    ) -> str:
        """
        Picks a model type by delegating to ``ModelEvaluator.find_best_model``.

        The evaluator is configured from ``operator_config.spec.model_kwargs``:
        ``model_list`` (defaults to every key registered in ``_MAP``),
        ``num_backtests`` (defaults to 5) and ``sample_ratio`` (defaults to 0.20).

        Parameters
        ------------
        datasets: ForecastDatasets
            Datasets for predictions
        operator_config: ForecastOperatorConfig
            The forecasting operator config.

        Returns
        --------
        str
            The type of the model.
        """
        kwargs = operator_config.spec.model_kwargs
        candidate_models = kwargs.get("model_list", cls._MAP.keys())
        backtest_count = kwargs.get("num_backtests", 5)
        ratio = kwargs.get("sample_ratio", 0.20)

        evaluator = ModelEvaluator(candidate_models, backtest_count, ratio)
        return evaluator.find_best_model(datasets, operator_config)
|
@@ -0,0 +1,492 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
from typing import Dict, List
|
7
|
+
|
8
|
+
import pandas as pd
|
9
|
+
|
10
|
+
from ads.opctl import logger
|
11
|
+
from ads.opctl.operator.lowcode.common.data import AbstractData
|
12
|
+
from ads.opctl.operator.lowcode.common.errors import (
|
13
|
+
DataMismatchError,
|
14
|
+
InvalidParameterError,
|
15
|
+
)
|
16
|
+
from ads.opctl.operator.lowcode.common.utils import (
|
17
|
+
get_frequency_in_seconds,
|
18
|
+
get_frequency_of_datetime,
|
19
|
+
)
|
20
|
+
|
21
|
+
from ..const import ForecastOutputColumns, SupportedModels
|
22
|
+
from ..operator_config import ForecastOperatorConfig
|
23
|
+
|
24
|
+
|
25
|
+
class HistoricalData(AbstractData):
    """Historical data for the forecast operator, registered as ``historical_data``."""

    def __init__(self, spec, historical_data=None):
        super().__init__(data=historical_data, name="historical_data", spec=spec)

    def _ingest_data(self, spec):
        """
        Record the datetime frequency, validate it against the model type and
        then run the base-class ingestion.

        The frequency is derived from the first level of the data index. If
        that lookup raises ``TypeError`` the frequency is set to ``None`` and
        a warning is logged instead of failing.
        """
        try:
            timestamps = self.data.index.get_level_values(0)
            self.freq = get_frequency_of_datetime(timestamps)
        except TypeError as err:
            logger.warning(
                f"Error determining frequency: {err.args}. Setting Frequency to None"
            )
            logger.debug(f"Full traceback: {err}")
            self.freq = None

        self._verify_dt_col(spec)
        super()._ingest_data(spec)

    def _verify_dt_col(self, spec):
        """
        Store the frequency in seconds and reject sub-hourly data for AutoMLX.

        Raises
        ------
        InvalidParameterError
            If the model is AutoMLX and the absolute frequency is below 3600 seconds.
        """
        # Check frequency is compatible with model type
        timestamps = self.data.index.get_level_values(0)
        self.freq_in_secs = get_frequency_in_seconds(timestamps)

        is_automlx = spec.model == SupportedModels.AutoMLX
        # Keep the short-circuit: the frequency is only inspected for AutoMLX.
        if is_automlx and abs(self.freq_in_secs) < 3600:
            message = (
                f"{SupportedModels.AutoMLX} requires data with a frequency of at least one hour. Please try using a different model,"
                " or select the 'auto' option."
            )
            raise InvalidParameterError(message)
|
52
|
+
|
53
|
+
|
54
|
+
class AdditionalData(AbstractData):
    """Additional (regressor) data for the forecast operator, registered as ``additional_data``."""

    def __init__(self, spec, historical_data, additional_data=None):
        """
        Build the additional data from one of three sources.

        * ``additional_data`` passed in directly: it is handed to the base
          class and its columns become the additional regressors.
        * ``spec.additional_data`` configured: the base class loads it and the
          dates are validated against the end of the historical data.
        * neither: an empty frame covering the historical dates plus the
          horizon is constructed by ``create_horizon``.

        Raises
        ------
        DataMismatchError
            If the configured additional data does not line up with the end of
            the historical data, or has too few unique dates to cover the
            horizon at all.
        """
        if additional_data is not None:
            super().__init__(spec=spec, name="additional_data", data=additional_data)
            self.additional_regressors = list(self.data.columns)
        elif spec.additional_data is not None:
            super().__init__(spec=spec, name="additional_data")
            add_dates = sorted(self.data.index.get_level_values(0).unique().tolist())
            hist_end = historical_data.get_max_time()

            # The checks below index ``horizon`` and ``horizon + 1`` positions
            # from the end; without this guard a short input surfaces as a bare
            # IndexError rather than a data error.
            if len(add_dates) < spec.horizon + 1:
                raise DataMismatchError(
                    f"The Additional Data must be present for all historical data and the entire horizon. The Additional Data only contains {len(add_dates)} unique dates, which cannot cover the last historical date plus a horizon of {spec.horizon}."
                )

            if hist_end > add_dates[-spec.horizon]:
                raise DataMismatchError(
                    f"The Historical Data ends on {hist_end}. The additional data horizon starts on {add_dates[-spec.horizon]}. The horizon should have exactly {spec.horizon} dates after the Historical at a frequency of {historical_data.freq}"
                )
            elif hist_end != add_dates[-(spec.horizon + 1)]:
                raise DataMismatchError(
                    f"The Additional Data must be present for all historical data and the entire horizon. The Historical Data ends on {hist_end}. The additonal data horizon starts after {add_dates[-(spec.horizon+1)]}. These should be the same date."
                )
        else:
            # No additional data anywhere: skip the base initialiser and set
            # the attributes this class relies on by hand.
            self.name = "additional_data"
            self.data = None
            self._data_dict = {}
            self.create_horizon(spec, historical_data)

    def create_horizon(self, spec, historical_data):
        """
        Construct a column-less frame indexed by (datetime, series) that spans
        each series' historical dates plus ``spec.horizon`` future dates.

        Future dates start at the last historical timestamp and use the
        historical frequency, falling back to ``pd.infer_freq`` on the last
        five datetime values when no frequency was recorded.
        """
        logger.debug("No additional data provided. Constructing horizon.")
        dt_name = spec.datetime_column.name
        future_dates = pd.Series(
            pd.date_range(
                start=historical_data.get_max_time(),
                periods=spec.horizon + 1,
                freq=historical_data.freq
                or pd.infer_freq(historical_data.data.reset_index()[dt_name][-5:]),
            ),
            name=dt_name,
        )
        add_dfs = []
        for s_id in historical_data.list_series_ids():
            df_i = historical_data.get_data_for_series(s_id)[dt_name]
            # The first generated date repeats the last historical date, so it
            # is dropped before appending.
            df_i = pd.DataFrame(pd.concat([df_i, future_dates[1:]]))
            df_i[ForecastOutputColumns.SERIES] = s_id
            df_i = df_i.set_index([dt_name, ForecastOutputColumns.SERIES])
            add_dfs.append(df_i)
        data = pd.concat(add_dfs, axis=1)
        self.data = data.sort_values(
            [dt_name, ForecastOutputColumns.SERIES], ascending=True
        )
        self.additional_regressors = []

    def _ingest_data(self, spec):
        """
        Record the data columns as the additional regressors and warn when
        there are none. ``spec`` is accepted for interface compatibility and
        is not used.
        """
        self.additional_regressors = list(self.data.columns)
        if not self.additional_regressors:
            logger.warning(
                f"No additional variables found in the additional_data. Only columns found: {self.data.columns}. Skipping for now."
            )
        # NOTE: no check is made here that the datetime column matches the
        # historical datetime column.
|
113
|
+
|
114
|
+
|
115
|
+
class TestData(AbstractData):
    """Test data for the forecast operator, registered as ``test_data``."""

    def __init__(self, spec, test_data):
        """
        Initialise the base class only when test data is available, either
        passed in directly or configured on ``spec``; the column names are
        recorded in both cases.
        """
        # ``spec.test_data`` is only consulted when nothing was passed in.
        has_test_data = test_data is not None or spec.test_data is not None
        if has_test_data:
            super().__init__(spec=spec, name="test_data", data=test_data)

        self.dt_column_name = spec.datetime_column.name
        self.target_name = spec.target_column
|
121
|
+
|
122
|
+
|
123
|
+
class ForecastDatasets:
    """Bundle of the historical, additional and test data used by a forecast run."""

    def __init__(
        self,
        config: ForecastOperatorConfig,
        historical_data=None,
        additional_data=None,
        test_data=None,
    ):
        """Builds the historical, additional and test datasets.

        Parameters
        ----------
        config: ForecastOperatorConfig
            The forecast operator configuration.
        historical_data: optional
            Historical data passed in directly. When omitted, both the
            historical and additional data are loaded from ``config.spec``.
        additional_data: optional
            Additional data passed in directly; only used when
            ``historical_data`` is also passed in.
        test_data: optional
            Test data passed in directly.
        """
        spec = config.spec

        self.historical_data: HistoricalData = None
        self.additional_data: AdditionalData = None
        self._horizon = spec.horizon
        self._datetime_column_name = spec.datetime_column.name
        self._target_col = spec.target_column

        if historical_data is None:
            self._load_data(spec)
        else:
            self.historical_data = HistoricalData(spec, historical_data)
            self.additional_data = AdditionalData(
                spec, self.historical_data, additional_data
            )

        self.test_data = TestData(spec, test_data)

    def _load_data(self, spec):
        """Loads forecasting input data from the spec.

        Explanations are switched off on ``spec`` (with a warning) when they
        were requested but no additional data is configured.
        """
        self.historical_data = HistoricalData(spec)
        self.additional_data = AdditionalData(spec, self.historical_data)

        explanations_without_data = (
            spec.generate_explanations and spec.additional_data is None
        )
        if explanations_without_data:
            logger.warning(
                "Unable to generate explanations as there is no additional data passed in. Either set generate_explanations to False, or pass in additional data."
            )
            spec.generate_explanations = False

    def get_all_data_long(self, include_horizon=True):
        """Merge historical and additional data on (datetime, series).

        An outer join is used when ``include_horizon`` is true, otherwise a
        left join from the historical data. The index is reset on return.
        """
        if include_horizon:
            join_kind = "outer"
        else:
            join_kind = "left"
        merged = pd.merge(
            self.historical_data.data,
            self.additional_data.data,
            how=join_kind,
            on=[self._datetime_column_name, ForecastOutputColumns.SERIES],
        )
        return merged.reset_index()

    def get_all_data_long_forecast_horizon(self):
        """Returns all data in long format for the forecast horizon.

        These are the rows of the outer-joined data whose target column is null.
        """
        merged = pd.merge(
            self.historical_data.data,
            self.additional_data.data,
            how="outer",
            on=[self._datetime_column_name, ForecastOutputColumns.SERIES],
        ).reset_index()
        target_missing = merged[self._target_col].isnull()
        return merged[target_missing].reset_index(drop=True)

    def get_data_multi_indexed(self):
        """Concatenate historical and additional data column-wise."""
        frames = [self.historical_data.data, self.additional_data.data]
        return pd.concat(frames, axis=1)

    def get_data_by_series(self, include_horizon=True):
        """Return a dict of series id to that series' merged historical and additional data.

        The per-series frames are merged on the datetime column, using an
        outer join when ``include_horizon`` is true and a left join otherwise.
        """
        hist_data = self.historical_data.get_dict_by_series()
        add_data = self.additional_data.get_dict_by_series()
        join_kind = "outer" if include_horizon else "left"
        # Note: ensure no duplicate column names
        return {
            s_id: pd.merge(
                hist_data[s_id],
                add_data[s_id],
                how=join_kind,
                on=[self._datetime_column_name],
            )
            for s_id in self.list_series_ids()
        }

    def get_data_at_series(self, s_id, include_horizon=True):
        """Return the merged data for a single series id.

        Raises
        ------
        InvalidParameterError
            If the series id cannot be retrieved.
        """
        per_series = self.get_data_by_series(include_horizon=include_horizon)
        try:
            return per_series[s_id]
        except Exception as err:
            raise InvalidParameterError(
                f"Unable to retrieve series id: {s_id} from data. Available series ids are: {self.list_series_ids()}"
            ) from err

    def get_horizon_at_series(self, s_id):
        """Return the last ``horizon`` rows of the merged data for a series."""
        series_frame = self.get_data_at_series(s_id)
        return series_frame[-self._horizon :]

    def has_artificial_series(self):
        """Return True when the spec declares target category columns."""
        category_columns = self.historical_data.spec.target_category_columns
        return bool(category_columns)

    def get_earliest_timestamp(self):
        """Return the minimum time of the historical data."""
        return self.historical_data.get_min_time()

    def get_latest_timestamp(self):
        """Return the maximum time of the historical data."""
        return self.historical_data.get_max_time()

    def get_additional_data_column_names(self):
        """Return the additional regressor column names."""
        return self.additional_data.additional_regressors

    def get_datetime_frequency(self):
        """Return the frequency recorded on the historical data."""
        return self.historical_data.freq

    def get_datetime_frequency_in_seconds(self):
        """Return the frequency in seconds recorded on the historical data."""
        return self.historical_data.freq_in_secs

    def get_num_rows(self):
        """Return the number of rows reported by the historical data."""
        return self.historical_data.get_num_rows()

    def list_series_ids(self, sorted=True):
        """Return the historical series ids, sorted in place when possible.

        Sorting is best effort: if the ids cannot be sorted they are returned
        in their original order.
        """
        series_ids = self.historical_data.list_series_ids()
        if not sorted:
            return series_ids
        try:
            series_ids.sort()
        except Exception:
            # Deliberate best effort; keep the original order.
            pass
        return series_ids

    def format_wide(self):
        """Return the per-series data side by side, keyed on the datetime column.

        For each series only the rows where the column named after the series
        id is not null are kept.
        """
        wide_parts = []
        for s_id, frame in self.get_data_by_series().items():
            observed = frame[frame[s_id].notna()]
            wide_parts.append(observed.set_index(self._datetime_column_name))
        data_merged = pd.concat(wide_parts, axis=1).reset_index()
        return data_merged

    def get_longest_datetime_column(self):
        """Return the datetime column of the wide-format data."""
        wide = self.format_wide()
        return wide[self._datetime_column_name]
|
260
|
+
|
261
|
+
|
262
|
+
class ForecastOutput:
|
263
|
+
def __init__(
|
264
|
+
self,
|
265
|
+
confidence_interval_width: float,
|
266
|
+
horizon: int,
|
267
|
+
target_column: str,
|
268
|
+
dt_column: str,
|
269
|
+
):
|
270
|
+
"""Forecast Output contains all the details required to generate the forecast.csv output file.
|
271
|
+
|
272
|
+
init
|
273
|
+
-------
|
274
|
+
confidence_interval_width: float value from OperatorSpec
|
275
|
+
horizon: int length of horizon
|
276
|
+
target_column: str the name of the original target column
|
277
|
+
dt_column: the name of the original datetime column
|
278
|
+
"""
|
279
|
+
self.series_id_map = {}
|
280
|
+
self._set_ci_column_names(confidence_interval_width)
|
281
|
+
self.horizon = horizon
|
282
|
+
self.target_column_name = target_column
|
283
|
+
self.dt_column_name = dt_column
|
284
|
+
|
285
|
+
def add_series_id(
|
286
|
+
self,
|
287
|
+
series_id: str,
|
288
|
+
forecast: pd.DataFrame,
|
289
|
+
overwrite: bool = False,
|
290
|
+
):
|
291
|
+
if not overwrite and series_id in self.series_id_map:
|
292
|
+
raise ValueError(
|
293
|
+
f"Attempting to update ForecastOutput for series_id {series_id} when this already exists. Set overwrite to True."
|
294
|
+
)
|
295
|
+
forecast = self._check_forecast_format(forecast)
|
296
|
+
self.series_id_map[series_id] = forecast
|
297
|
+
|
298
|
+
def init_series_output(self, series_id, data_at_series):
|
299
|
+
output_i = pd.DataFrame()
|
300
|
+
|
301
|
+
output_i["Date"] = data_at_series[self.dt_column_name]
|
302
|
+
output_i["Series"] = series_id
|
303
|
+
output_i["input_value"] = data_at_series[self.target_column_name]
|
304
|
+
|
305
|
+
output_i["fitted_value"] = float("nan")
|
306
|
+
output_i["forecast_value"] = float("nan")
|
307
|
+
output_i[self.lower_bound_name] = float("nan")
|
308
|
+
output_i[self.upper_bound_name] = float("nan")
|
309
|
+
self.series_id_map[series_id] = output_i
|
310
|
+
|
311
|
+
def populate_series_output(
|
312
|
+
self, series_id, fit_val, forecast_val, upper_bound, lower_bound
|
313
|
+
):
|
314
|
+
"""
|
315
|
+
This method should be run after init_series_output has been run on this series_id
|
316
|
+
|
317
|
+
Parameters:
|
318
|
+
-----------
|
319
|
+
series_id: [str, int] the series being forecasted
|
320
|
+
fit_val: numpy.array of length input_value - horizon
|
321
|
+
forecast_val: numpy.array of length horizon containing the forecasted values
|
322
|
+
upper_bound: numpy.array of length horizon containing the upper_bound values
|
323
|
+
lower_bound: numpy.array of length horizon containing the lower_bound values
|
324
|
+
|
325
|
+
Returns:
|
326
|
+
--------
|
327
|
+
None
|
328
|
+
"""
|
329
|
+
try:
|
330
|
+
output_i = self.series_id_map[series_id]
|
331
|
+
except KeyError as e:
|
332
|
+
raise ValueError(
|
333
|
+
f"Attempting to update output for series: {series_id}, however no series output has been initialized."
|
334
|
+
) from e
|
335
|
+
|
336
|
+
if (output_i.shape[0] - self.horizon) == len(fit_val):
|
337
|
+
output_i["fitted_value"].iloc[: -self.horizon] = (
|
338
|
+
fit_val # Note: may need to do len(output_i) - (len(fit_val) + horizon) : -horizon
|
339
|
+
)
|
340
|
+
elif (output_i.shape[0] - self.horizon) > len(fit_val):
|
341
|
+
logger.debug(
|
342
|
+
f"Fitted Values were only generated on a subset ({len(fit_val)}/{(output_i.shape[0] - self.horizon)}) of the data for Series: {series_id}."
|
343
|
+
)
|
344
|
+
start_idx = output_i.shape[0] - self.horizon - len(fit_val)
|
345
|
+
output_i["fitted_value"].iloc[start_idx : -self.horizon] = fit_val
|
346
|
+
else:
|
347
|
+
output_i["fitted_value"].iloc[start_idx : -self.horizon] = fit_val[
|
348
|
+
-(output_i.shape[0] - self.horizon) :
|
349
|
+
]
|
350
|
+
|
351
|
+
if len(forecast_val) != self.horizon:
|
352
|
+
raise ValueError(
|
353
|
+
f"Attempting to set forecast along horizon ({self.horizon}) for series: {series_id}, however forecast is only length {len(forecast_val)}"
|
354
|
+
)
|
355
|
+
output_i["forecast_value"].iloc[-self.horizon :] = forecast_val
|
356
|
+
|
357
|
+
if len(upper_bound) != self.horizon:
|
358
|
+
raise ValueError(
|
359
|
+
f"Attempting to set upper_bound along horizon ({self.horizon}) for series: {series_id}, however upper_bound is only length {len(upper_bound)}"
|
360
|
+
)
|
361
|
+
output_i[self.upper_bound_name].iloc[-self.horizon :] = upper_bound
|
362
|
+
|
363
|
+
if len(lower_bound) != self.horizon:
|
364
|
+
raise ValueError(
|
365
|
+
f"Attempting to set lower_bound along horizon ({self.horizon}) for series: {series_id}, however lower_bound is only length {len(lower_bound)}"
|
366
|
+
)
|
367
|
+
output_i[self.lower_bound_name].iloc[-self.horizon :] = lower_bound
|
368
|
+
|
369
|
+
self.series_id_map[series_id] = output_i
|
370
|
+
self.verify_series_output(series_id)
|
371
|
+
|
372
|
+
def verify_series_output(self, series_id):
|
373
|
+
forecast = self.series_id_map[series_id]
|
374
|
+
self._check_forecast_format(forecast)
|
375
|
+
|
376
|
+
def get_horizon_by_series(self, series_id):
|
377
|
+
return self.series_id_map[series_id][-self.horizon :]
|
378
|
+
|
379
|
+
def get_horizon_long(self):
|
380
|
+
df = pd.DataFrame()
|
381
|
+
for s_id in self.list_series_ids():
|
382
|
+
df = pd.concat([df, self.get_horizon_by_series(s_id)])
|
383
|
+
return df.reset_index(drop=True)
|
384
|
+
|
385
|
+
def get_forecast(self, series_id):
|
386
|
+
try:
|
387
|
+
return self.series_id_map[series_id]
|
388
|
+
except KeyError:
|
389
|
+
logger.debug(
|
390
|
+
f"No Forecast found for series_id: {series_id}. Returning empty DataFrame."
|
391
|
+
)
|
392
|
+
return pd.DataFrame()
|
393
|
+
|
394
|
+
def list_series_ids(self, sorted=True):
    """Return the keys of ``self.series_id_map`` as a new list.

    Args:
        sorted: when truthy, attempt an in-place ascending sort of the
            list. Any exception raised by the sort is swallowed and the
            list is returned as the failed sort left it.
    """
    series_ids = list(self.series_id_map)
    if not sorted:
        return series_ids
    try:
        series_ids.sort()
    except Exception:
        # Best effort only: ids that cannot be ordered are returned anyway.
        pass
    return series_ids
|
402
|
+
|
403
|
+
def _set_ci_column_names(self, confidence_interval_width):
|
404
|
+
yhat_lower_percentage = (100 - confidence_interval_width * 100) // 2
|
405
|
+
self.upper_bound_name = "p" + str(int(100 - yhat_lower_percentage))
|
406
|
+
self.lower_bound_name = "p" + str(int(yhat_lower_percentage))
|
407
|
+
|
408
|
+
def _check_forecast_format(self, forecast):
    """Assert that ``forecast`` has the expected output layout, then return it.

    The checks, in order: ``forecast`` is a ``pd.DataFrame``; it has
    exactly seven columns; the five ``ForecastOutputColumns`` fields listed
    below are present; the columns named by ``self.upper_bound_name`` and
    ``self.lower_bound_name`` are present; the frame is not empty. Only
    column membership is verified, not column order.

    Returns:
        The same ``forecast`` object, unchanged.

    Raises:
        AssertionError: on the first check that fails.
    """
    assert isinstance(forecast, pd.DataFrame)
    present = forecast.columns
    assert len(present) == 7, f"Expected just 7 columns, but got: {forecast.columns}"
    for fixed_column in (
        ForecastOutputColumns.DATE,
        ForecastOutputColumns.SERIES,
        ForecastOutputColumns.INPUT_VALUE,
        ForecastOutputColumns.FITTED_VALUE,
        ForecastOutputColumns.FORECAST_VALUE,
    ):
        assert fixed_column in present
    # The bound column names depend on the configured interval width, so
    # they are read from the instance instead of ForecastOutputColumns.
    for bound_column in (self.upper_bound_name, self.lower_bound_name):
        assert bound_column in present
    assert not forecast.empty
    return forecast
|
431
|
+
|
432
|
+
def get_forecast_long(self):
    """Stack every stored per-series frame into one long DataFrame.

    Frames are taken from ``self.series_id_map`` in the map's iteration
    order.

    Returns:
        pd.DataFrame: the concatenated rows with a fresh 0..n-1 index, or
        an empty DataFrame when the map holds no series.
    """
    # One concat over all frames instead of growing an accumulator inside a
    # loop, which re-copies all earlier rows on every iteration and feeds an
    # empty seed frame into ``pd.concat``.
    series_frames = list(self.series_id_map.values())
    # ``pd.concat`` raises ValueError on an empty sequence, so keep the
    # "no series -> empty frame" result explicit.
    combined = pd.concat(series_frames) if series_frames else pd.DataFrame()
    return combined.reset_index(drop=True)
|
437
|
+
|
438
|
+
|
439
|
+
class ForecastResults:
    """
    Forecast Results contains all outputs from the forecast run.
    This class is returned to users who use the Forecast's `operate` method.

    Every ``set_*`` method stores its argument as a plain instance
    attribute. The matching ``get_*`` method returns that attribute, or
    ``None`` when the attribute does not exist on the instance.
    """

    def _stored(self, attr_name):
        """Return the instance attribute ``attr_name``, or ``None`` if absent."""
        return getattr(self, attr_name, None)

    def set_forecast(self, df: pd.DataFrame):
        """Store ``df`` as ``self.forecast``."""
        self.forecast = df

    def get_forecast(self):
        """Return ``self.forecast``, or ``None`` if it is not set."""
        return self._stored("forecast")

    def set_metrics(self, df: pd.DataFrame):
        """Store ``df`` as ``self.metrics``."""
        self.metrics = df

    def get_metrics(self):
        """Return ``self.metrics``, or ``None`` if it is not set."""
        return self._stored("metrics")

    def set_test_metrics(self, df: pd.DataFrame):
        """Store ``df`` as ``self.test_metrics``."""
        self.test_metrics = df

    def get_test_metrics(self):
        """Return ``self.test_metrics``, or ``None`` if it is not set."""
        return self._stored("test_metrics")

    def set_local_explanations(self, df: pd.DataFrame):
        """Store ``df`` as ``self.local_explanations``."""
        self.local_explanations = df

    def get_local_explanations(self):
        """Return ``self.local_explanations``, or ``None`` if it is not set."""
        return self._stored("local_explanations")

    def set_global_explanations(self, df: pd.DataFrame):
        """Store ``df`` as ``self.global_explanations``."""
        self.global_explanations = df

    def get_global_explanations(self):
        """Return ``self.global_explanations``, or ``None`` if it is not set."""
        return self._stored("global_explanations")

    def set_model_parameters(self, df: pd.DataFrame):
        """Store ``df`` as ``self.model_parameters``."""
        self.model_parameters = df

    def get_model_parameters(self):
        """Return ``self.model_parameters``, or ``None`` if it is not set."""
        return self._stored("model_parameters")

    def set_models(self, models: List):
        """Store ``models`` as ``self.models``."""
        self.models = models

    def get_models(self):
        """Return ``self.models``, or ``None`` if it is not set."""
        return self._stored("models")

    def set_errors_dict(self, errors_dict: Dict):
        """Store ``errors_dict`` as ``self.errors_dict``."""
        self.errors_dict = errors_dict

    def get_errors_dict(self):
        """Return ``self.errors_dict``, or ``None`` if it is not set."""
        return self._stored("errors_dict")
|