oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.10rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +40 -0
- ads/aqua/app.py +507 -0
- ads/aqua/cli.py +96 -0
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +836 -0
- ads/aqua/client/openai_client.py +305 -0
- ads/aqua/common/__init__.py +5 -0
- ads/aqua/common/decorator.py +125 -0
- ads/aqua/common/entities.py +274 -0
- ads/aqua/common/enums.py +134 -0
- ads/aqua/common/errors.py +109 -0
- ads/aqua/common/utils.py +1295 -0
- ads/aqua/config/__init__.py +4 -0
- ads/aqua/config/container_config.py +247 -0
- ads/aqua/config/evaluation/__init__.py +4 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
- ads/aqua/config/utils/__init__.py +4 -0
- ads/aqua/config/utils/serializer.py +339 -0
- ads/aqua/constants.py +116 -0
- ads/aqua/data.py +14 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation/__init__.py +8 -0
- ads/aqua/evaluation/constants.py +53 -0
- ads/aqua/evaluation/entities.py +186 -0
- ads/aqua/evaluation/errors.py +70 -0
- ads/aqua/evaluation/evaluation.py +1814 -0
- ads/aqua/extension/__init__.py +42 -0
- ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
- ads/aqua/extension/base_handler.py +90 -0
- ads/aqua/extension/common_handler.py +121 -0
- ads/aqua/extension/common_ws_msg_handler.py +36 -0
- ads/aqua/extension/deployment_handler.py +381 -0
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +30 -0
- ads/aqua/extension/evaluation_handler.py +129 -0
- ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
- ads/aqua/extension/finetune_handler.py +96 -0
- ads/aqua/extension/model_handler.py +390 -0
- ads/aqua/extension/models/__init__.py +0 -0
- ads/aqua/extension/models/ws_models.py +145 -0
- ads/aqua/extension/models_ws_msg_handler.py +50 -0
- ads/aqua/extension/ui_handler.py +300 -0
- ads/aqua/extension/ui_websocket_handler.py +130 -0
- ads/aqua/extension/utils.py +133 -0
- ads/aqua/finetuning/__init__.py +7 -0
- ads/aqua/finetuning/constants.py +23 -0
- ads/aqua/finetuning/entities.py +181 -0
- ads/aqua/finetuning/finetuning.py +749 -0
- ads/aqua/model/__init__.py +8 -0
- ads/aqua/model/constants.py +60 -0
- ads/aqua/model/entities.py +385 -0
- ads/aqua/model/enums.py +32 -0
- ads/aqua/model/model.py +2134 -0
- ads/aqua/model/utils.py +52 -0
- ads/aqua/modeldeployment/__init__.py +6 -0
- ads/aqua/modeldeployment/constants.py +10 -0
- ads/aqua/modeldeployment/deployment.py +1315 -0
- ads/aqua/modeldeployment/entities.py +653 -0
- ads/aqua/modeldeployment/utils.py +543 -0
- ads/aqua/resources/gpu_shapes_index.json +94 -0
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +476 -0
- ads/aqua/ui.py +519 -0
- ads/automl/__init__.py +9 -0
- ads/automl/driver.py +330 -0
- ads/automl/provider.py +975 -0
- ads/bds/__init__.py +5 -0
- ads/bds/auth.py +127 -0
- ads/bds/big_data_service.py +255 -0
- ads/catalog/__init__.py +19 -0
- ads/catalog/model.py +1576 -0
- ads/catalog/notebook.py +461 -0
- ads/catalog/project.py +468 -0
- ads/catalog/summary.py +178 -0
- ads/common/__init__.py +11 -0
- ads/common/analyzer.py +65 -0
- ads/common/artifact/.model-ignore +63 -0
- ads/common/artifact/__init__.py +10 -0
- ads/common/auth.py +1122 -0
- ads/common/card_identifier.py +83 -0
- ads/common/config.py +647 -0
- ads/common/data.py +165 -0
- ads/common/decorator/__init__.py +9 -0
- ads/common/decorator/argument_to_case.py +88 -0
- ads/common/decorator/deprecate.py +69 -0
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/decorator/runtime_dependency.py +178 -0
- ads/common/decorator/threaded.py +97 -0
- ads/common/decorator/utils.py +35 -0
- ads/common/dsc_file_system.py +303 -0
- ads/common/error.py +14 -0
- ads/common/extended_enum.py +81 -0
- ads/common/function/__init__.py +5 -0
- ads/common/function/fn_util.py +142 -0
- ads/common/function/func_conf.yaml +25 -0
- ads/common/ipython.py +76 -0
- ads/common/model.py +679 -0
- ads/common/model_artifact.py +1759 -0
- ads/common/model_artifact_schema.json +107 -0
- ads/common/model_export_util.py +664 -0
- ads/common/model_metadata.py +24 -0
- ads/common/object_storage_details.py +296 -0
- ads/common/oci_client.py +179 -0
- ads/common/oci_datascience.py +46 -0
- ads/common/oci_logging.py +1144 -0
- ads/common/oci_mixin.py +957 -0
- ads/common/oci_resource.py +136 -0
- ads/common/serializer.py +559 -0
- ads/common/utils.py +1852 -0
- ads/common/word_lists.py +1491 -0
- ads/common/work_request.py +189 -0
- ads/config.py +1 -0
- ads/data_labeling/__init__.py +13 -0
- ads/data_labeling/boundingbox.py +253 -0
- ads/data_labeling/constants.py +47 -0
- ads/data_labeling/data_labeling_service.py +244 -0
- ads/data_labeling/interface/__init__.py +5 -0
- ads/data_labeling/interface/loader.py +16 -0
- ads/data_labeling/interface/parser.py +16 -0
- ads/data_labeling/interface/reader.py +23 -0
- ads/data_labeling/loader/__init__.py +5 -0
- ads/data_labeling/loader/file_loader.py +241 -0
- ads/data_labeling/metadata.py +110 -0
- ads/data_labeling/mixin/__init__.py +5 -0
- ads/data_labeling/mixin/data_labeling.py +232 -0
- ads/data_labeling/ner.py +129 -0
- ads/data_labeling/parser/__init__.py +5 -0
- ads/data_labeling/parser/dls_record_parser.py +388 -0
- ads/data_labeling/parser/export_metadata_parser.py +94 -0
- ads/data_labeling/parser/export_record_parser.py +473 -0
- ads/data_labeling/reader/__init__.py +5 -0
- ads/data_labeling/reader/dataset_reader.py +574 -0
- ads/data_labeling/reader/dls_record_reader.py +121 -0
- ads/data_labeling/reader/export_record_reader.py +62 -0
- ads/data_labeling/reader/jsonl_reader.py +75 -0
- ads/data_labeling/reader/metadata_reader.py +203 -0
- ads/data_labeling/reader/record_reader.py +263 -0
- ads/data_labeling/record.py +52 -0
- ads/data_labeling/visualizer/__init__.py +5 -0
- ads/data_labeling/visualizer/image_visualizer.py +525 -0
- ads/data_labeling/visualizer/text_visualizer.py +357 -0
- ads/database/__init__.py +5 -0
- ads/database/connection.py +338 -0
- ads/dataset/__init__.py +10 -0
- ads/dataset/capabilities.md +51 -0
- ads/dataset/classification_dataset.py +339 -0
- ads/dataset/correlation.py +226 -0
- ads/dataset/correlation_plot.py +563 -0
- ads/dataset/dask_series.py +173 -0
- ads/dataset/dataframe_transformer.py +110 -0
- ads/dataset/dataset.py +1979 -0
- ads/dataset/dataset_browser.py +360 -0
- ads/dataset/dataset_with_target.py +995 -0
- ads/dataset/exception.py +25 -0
- ads/dataset/factory.py +987 -0
- ads/dataset/feature_engineering_transformer.py +35 -0
- ads/dataset/feature_selection.py +107 -0
- ads/dataset/forecasting_dataset.py +26 -0
- ads/dataset/helper.py +1450 -0
- ads/dataset/label_encoder.py +99 -0
- ads/dataset/mixin/__init__.py +5 -0
- ads/dataset/mixin/dataset_accessor.py +134 -0
- ads/dataset/pipeline.py +58 -0
- ads/dataset/plot.py +710 -0
- ads/dataset/progress.py +86 -0
- ads/dataset/recommendation.py +297 -0
- ads/dataset/recommendation_transformer.py +502 -0
- ads/dataset/regression_dataset.py +14 -0
- ads/dataset/sampled_dataset.py +1050 -0
- ads/dataset/target.py +98 -0
- ads/dataset/timeseries.py +18 -0
- ads/dbmixin/__init__.py +5 -0
- ads/dbmixin/db_pandas_accessor.py +153 -0
- ads/environment/__init__.py +9 -0
- ads/environment/ml_runtime.py +66 -0
- ads/evaluations/README.md +14 -0
- ads/evaluations/__init__.py +109 -0
- ads/evaluations/evaluation_plot.py +983 -0
- ads/evaluations/evaluator.py +1334 -0
- ads/evaluations/statistical_metrics.py +543 -0
- ads/experiments/__init__.py +9 -0
- ads/experiments/capabilities.md +0 -0
- ads/explanations/__init__.py +21 -0
- ads/explanations/base_explainer.py +142 -0
- ads/explanations/capabilities.md +83 -0
- ads/explanations/explainer.py +190 -0
- ads/explanations/mlx_global_explainer.py +1050 -0
- ads/explanations/mlx_interface.py +386 -0
- ads/explanations/mlx_local_explainer.py +287 -0
- ads/explanations/mlx_whatif_explainer.py +201 -0
- ads/feature_engineering/__init__.py +20 -0
- ads/feature_engineering/accessor/__init__.py +5 -0
- ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
- ads/feature_engineering/accessor/mixin/__init__.py +5 -0
- ads/feature_engineering/accessor/mixin/correlation.py +166 -0
- ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
- ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
- ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
- ads/feature_engineering/accessor/mixin/utils.py +65 -0
- ads/feature_engineering/accessor/series_accessor.py +431 -0
- ads/feature_engineering/adsimage/__init__.py +5 -0
- ads/feature_engineering/adsimage/image.py +192 -0
- ads/feature_engineering/adsimage/image_reader.py +170 -0
- ads/feature_engineering/adsimage/interface/__init__.py +5 -0
- ads/feature_engineering/adsimage/interface/reader.py +19 -0
- ads/feature_engineering/adsstring/__init__.py +7 -0
- ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
- ads/feature_engineering/adsstring/string/__init__.py +8 -0
- ads/feature_engineering/data_schema.json +57 -0
- ads/feature_engineering/dataset/__init__.py +5 -0
- ads/feature_engineering/dataset/zip_code_data.py +42062 -0
- ads/feature_engineering/exceptions.py +40 -0
- ads/feature_engineering/feature_type/__init__.py +133 -0
- ads/feature_engineering/feature_type/address.py +184 -0
- ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
- ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
- ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
- ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
- ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
- ads/feature_engineering/feature_type/adsstring/string.py +258 -0
- ads/feature_engineering/feature_type/base.py +58 -0
- ads/feature_engineering/feature_type/boolean.py +183 -0
- ads/feature_engineering/feature_type/category.py +146 -0
- ads/feature_engineering/feature_type/constant.py +137 -0
- ads/feature_engineering/feature_type/continuous.py +151 -0
- ads/feature_engineering/feature_type/creditcard.py +314 -0
- ads/feature_engineering/feature_type/datetime.py +190 -0
- ads/feature_engineering/feature_type/discrete.py +134 -0
- ads/feature_engineering/feature_type/document.py +43 -0
- ads/feature_engineering/feature_type/gis.py +251 -0
- ads/feature_engineering/feature_type/handler/__init__.py +5 -0
- ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
- ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
- ads/feature_engineering/feature_type/handler/warnings.py +128 -0
- ads/feature_engineering/feature_type/integer.py +142 -0
- ads/feature_engineering/feature_type/ip_address.py +144 -0
- ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
- ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
- ads/feature_engineering/feature_type/lat_long.py +256 -0
- ads/feature_engineering/feature_type/object.py +43 -0
- ads/feature_engineering/feature_type/ordinal.py +132 -0
- ads/feature_engineering/feature_type/phone_number.py +135 -0
- ads/feature_engineering/feature_type/string.py +171 -0
- ads/feature_engineering/feature_type/text.py +93 -0
- ads/feature_engineering/feature_type/unknown.py +43 -0
- ads/feature_engineering/feature_type/zip_code.py +164 -0
- ads/feature_engineering/feature_type_manager.py +406 -0
- ads/feature_engineering/schema.py +795 -0
- ads/feature_engineering/utils.py +245 -0
- ads/feature_store/.readthedocs.yaml +19 -0
- ads/feature_store/README.md +65 -0
- ads/feature_store/__init__.py +9 -0
- ads/feature_store/common/__init__.py +0 -0
- ads/feature_store/common/enums.py +339 -0
- ads/feature_store/common/exceptions.py +18 -0
- ads/feature_store/common/spark_session_singleton.py +125 -0
- ads/feature_store/common/utils/__init__.py +0 -0
- ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
- ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
- ads/feature_store/common/utils/transformation_utils.py +82 -0
- ads/feature_store/common/utils/utility.py +403 -0
- ads/feature_store/data_validation/__init__.py +0 -0
- ads/feature_store/data_validation/great_expectation.py +129 -0
- ads/feature_store/dataset.py +1230 -0
- ads/feature_store/dataset_job.py +530 -0
- ads/feature_store/docs/Dockerfile +7 -0
- ads/feature_store/docs/Makefile +44 -0
- ads/feature_store/docs/conf.py +28 -0
- ads/feature_store/docs/requirements.txt +14 -0
- ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
- ads/feature_store/docs/source/cicd.rst +137 -0
- ads/feature_store/docs/source/conf.py +86 -0
- ads/feature_store/docs/source/data_versioning.rst +33 -0
- ads/feature_store/docs/source/dataset.rst +388 -0
- ads/feature_store/docs/source/dataset_job.rst +27 -0
- ads/feature_store/docs/source/demo.rst +70 -0
- ads/feature_store/docs/source/entity.rst +78 -0
- ads/feature_store/docs/source/feature_group.rst +624 -0
- ads/feature_store/docs/source/feature_group_job.rst +29 -0
- ads/feature_store/docs/source/feature_store.rst +122 -0
- ads/feature_store/docs/source/feature_store_class.rst +123 -0
- ads/feature_store/docs/source/feature_validation.rst +66 -0
- ads/feature_store/docs/source/figures/cicd.png +0 -0
- ads/feature_store/docs/source/figures/data_validation.png +0 -0
- ads/feature_store/docs/source/figures/data_versioning.png +0 -0
- ads/feature_store/docs/source/figures/dataset.gif +0 -0
- ads/feature_store/docs/source/figures/dataset.png +0 -0
- ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
- ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
- ads/feature_store/docs/source/figures/entity.png +0 -0
- ads/feature_store/docs/source/figures/feature_group.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
- ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
- ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
- ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
- ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
- ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
- ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
- ads/feature_store/docs/source/figures/overview.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
- ads/feature_store/docs/source/figures/stats_1.png +0 -0
- ads/feature_store/docs/source/figures/stats_2.png +0 -0
- ads/feature_store/docs/source/figures/stats_d.png +0 -0
- ads/feature_store/docs/source/figures/stats_fg.png +0 -0
- ads/feature_store/docs/source/figures/transformation.png +0 -0
- ads/feature_store/docs/source/figures/transformations.gif +0 -0
- ads/feature_store/docs/source/figures/validation.png +0 -0
- ads/feature_store/docs/source/figures/validation_fg.png +0 -0
- ads/feature_store/docs/source/figures/validation_results.png +0 -0
- ads/feature_store/docs/source/figures/validation_summary.png +0 -0
- ads/feature_store/docs/source/index.rst +81 -0
- ads/feature_store/docs/source/module.rst +8 -0
- ads/feature_store/docs/source/notebook.rst +94 -0
- ads/feature_store/docs/source/overview.rst +47 -0
- ads/feature_store/docs/source/quickstart.rst +176 -0
- ads/feature_store/docs/source/release_notes.rst +194 -0
- ads/feature_store/docs/source/setup_feature_store.rst +81 -0
- ads/feature_store/docs/source/statistics.rst +58 -0
- ads/feature_store/docs/source/transformation.rst +199 -0
- ads/feature_store/docs/source/ui.rst +65 -0
- ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
- ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
- ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
- ads/feature_store/entity.py +718 -0
- ads/feature_store/execution_strategy/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
- ads/feature_store/execution_strategy/engine/__init__.py +0 -0
- ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
- ads/feature_store/execution_strategy/execution_strategy.py +113 -0
- ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
- ads/feature_store/execution_strategy/spark/__init__.py +0 -0
- ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
- ads/feature_store/feature.py +192 -0
- ads/feature_store/feature_group.py +1494 -0
- ads/feature_store/feature_group_expectation.py +346 -0
- ads/feature_store/feature_group_job.py +602 -0
- ads/feature_store/feature_lineage/__init__.py +0 -0
- ads/feature_store/feature_lineage/graphviz_service.py +180 -0
- ads/feature_store/feature_option_details.py +50 -0
- ads/feature_store/feature_statistics/__init__.py +0 -0
- ads/feature_store/feature_statistics/statistics_service.py +99 -0
- ads/feature_store/feature_store.py +699 -0
- ads/feature_store/feature_store_registrar.py +518 -0
- ads/feature_store/input_feature_detail.py +149 -0
- ads/feature_store/mixin/__init__.py +4 -0
- ads/feature_store/mixin/oci_feature_store.py +145 -0
- ads/feature_store/model_details.py +73 -0
- ads/feature_store/query/__init__.py +0 -0
- ads/feature_store/query/filter.py +266 -0
- ads/feature_store/query/generator/__init__.py +0 -0
- ads/feature_store/query/generator/query_generator.py +298 -0
- ads/feature_store/query/join.py +161 -0
- ads/feature_store/query/query.py +403 -0
- ads/feature_store/query/validator/__init__.py +0 -0
- ads/feature_store/query/validator/query_validator.py +57 -0
- ads/feature_store/response/__init__.py +0 -0
- ads/feature_store/response/response_builder.py +68 -0
- ads/feature_store/service/__init__.py +0 -0
- ads/feature_store/service/oci_dataset.py +139 -0
- ads/feature_store/service/oci_dataset_job.py +199 -0
- ads/feature_store/service/oci_entity.py +125 -0
- ads/feature_store/service/oci_feature_group.py +164 -0
- ads/feature_store/service/oci_feature_group_job.py +214 -0
- ads/feature_store/service/oci_feature_store.py +182 -0
- ads/feature_store/service/oci_lineage.py +87 -0
- ads/feature_store/service/oci_transformation.py +104 -0
- ads/feature_store/statistics/__init__.py +0 -0
- ads/feature_store/statistics/abs_feature_value.py +49 -0
- ads/feature_store/statistics/charts/__init__.py +0 -0
- ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
- ads/feature_store/statistics/charts/box_plot.py +148 -0
- ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
- ads/feature_store/statistics/charts/probability_distribution.py +68 -0
- ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
- ads/feature_store/statistics/feature_stat.py +126 -0
- ads/feature_store/statistics/generic_feature_value.py +33 -0
- ads/feature_store/statistics/statistics.py +41 -0
- ads/feature_store/statistics_config.py +101 -0
- ads/feature_store/templates/feature_store_template.yaml +45 -0
- ads/feature_store/transformation.py +499 -0
- ads/feature_store/validation_output.py +57 -0
- ads/hpo/__init__.py +9 -0
- ads/hpo/_imports.py +91 -0
- ads/hpo/ads_search_space.py +439 -0
- ads/hpo/distributions.py +325 -0
- ads/hpo/objective.py +280 -0
- ads/hpo/search_cv.py +1657 -0
- ads/hpo/stopping_criterion.py +75 -0
- ads/hpo/tuner_artifact.py +413 -0
- ads/hpo/utils.py +91 -0
- ads/hpo/validation.py +140 -0
- ads/hpo/visualization/__init__.py +5 -0
- ads/hpo/visualization/_contour.py +23 -0
- ads/hpo/visualization/_edf.py +20 -0
- ads/hpo/visualization/_intermediate_values.py +21 -0
- ads/hpo/visualization/_optimization_history.py +25 -0
- ads/hpo/visualization/_parallel_coordinate.py +169 -0
- ads/hpo/visualization/_param_importances.py +26 -0
- ads/jobs/__init__.py +53 -0
- ads/jobs/ads_job.py +663 -0
- ads/jobs/builders/__init__.py +5 -0
- ads/jobs/builders/base.py +156 -0
- ads/jobs/builders/infrastructure/__init__.py +6 -0
- ads/jobs/builders/infrastructure/base.py +165 -0
- ads/jobs/builders/infrastructure/dataflow.py +1252 -0
- ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
- ads/jobs/builders/infrastructure/utils.py +65 -0
- ads/jobs/builders/runtimes/__init__.py +5 -0
- ads/jobs/builders/runtimes/artifact.py +338 -0
- ads/jobs/builders/runtimes/base.py +325 -0
- ads/jobs/builders/runtimes/container_runtime.py +242 -0
- ads/jobs/builders/runtimes/python_runtime.py +1016 -0
- ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
- ads/jobs/cli.py +104 -0
- ads/jobs/env_var_parser.py +131 -0
- ads/jobs/extension.py +160 -0
- ads/jobs/schema/__init__.py +5 -0
- ads/jobs/schema/infrastructure_schema.json +116 -0
- ads/jobs/schema/job_schema.json +42 -0
- ads/jobs/schema/runtime_schema.json +183 -0
- ads/jobs/schema/validator.py +141 -0
- ads/jobs/serializer.py +296 -0
- ads/jobs/templates/__init__.py +5 -0
- ads/jobs/templates/container.py +6 -0
- ads/jobs/templates/driver_notebook.py +177 -0
- ads/jobs/templates/driver_oci.py +500 -0
- ads/jobs/templates/driver_python.py +48 -0
- ads/jobs/templates/driver_pytorch.py +852 -0
- ads/jobs/templates/driver_utils.py +615 -0
- ads/jobs/templates/hostname_from_env.c +55 -0
- ads/jobs/templates/oci_metrics.py +181 -0
- ads/jobs/utils.py +104 -0
- ads/llm/__init__.py +28 -0
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/v02/client.py +295 -0
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/chain.py +268 -0
- ads/llm/chat_template.py +31 -0
- ads/llm/deploy.py +63 -0
- ads/llm/guardrails/__init__.py +5 -0
- ads/llm/guardrails/base.py +442 -0
- ads/llm/guardrails/huggingface.py +44 -0
- ads/llm/langchain/__init__.py +5 -0
- ads/llm/langchain/plugins/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/llm/langchain/plugins/llms/__init__.py +5 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
- ads/llm/requirements.txt +3 -0
- ads/llm/serialize.py +219 -0
- ads/llm/serializers/__init__.py +0 -0
- ads/llm/serializers/retrieval_qa.py +153 -0
- ads/llm/serializers/runnable_parallel.py +27 -0
- ads/llm/templates/score_chain.jinja2 +155 -0
- ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
- ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
- ads/model/__init__.py +52 -0
- ads/model/artifact.py +573 -0
- ads/model/artifact_downloader.py +254 -0
- ads/model/artifact_uploader.py +267 -0
- ads/model/base_properties.py +238 -0
- ads/model/common/.model-ignore +66 -0
- ads/model/common/__init__.py +5 -0
- ads/model/common/utils.py +142 -0
- ads/model/datascience_model.py +2635 -0
- ads/model/deployment/__init__.py +20 -0
- ads/model/deployment/common/__init__.py +5 -0
- ads/model/deployment/common/utils.py +308 -0
- ads/model/deployment/model_deployer.py +466 -0
- ads/model/deployment/model_deployment.py +1846 -0
- ads/model/deployment/model_deployment_infrastructure.py +671 -0
- ads/model/deployment/model_deployment_properties.py +493 -0
- ads/model/deployment/model_deployment_runtime.py +838 -0
- ads/model/extractor/__init__.py +5 -0
- ads/model/extractor/automl_extractor.py +74 -0
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/extractor/huggingface_extractor.py +88 -0
- ads/model/extractor/keras_extractor.py +84 -0
- ads/model/extractor/lightgbm_extractor.py +93 -0
- ads/model/extractor/model_info_extractor.py +114 -0
- ads/model/extractor/model_info_extractor_factory.py +105 -0
- ads/model/extractor/pytorch_extractor.py +87 -0
- ads/model/extractor/sklearn_extractor.py +112 -0
- ads/model/extractor/spark_extractor.py +89 -0
- ads/model/extractor/tensorflow_extractor.py +85 -0
- ads/model/extractor/xgboost_extractor.py +94 -0
- ads/model/framework/__init__.py +5 -0
- ads/model/framework/automl_model.py +178 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/framework/huggingface_model.py +399 -0
- ads/model/framework/lightgbm_model.py +266 -0
- ads/model/framework/pytorch_model.py +266 -0
- ads/model/framework/sklearn_model.py +250 -0
- ads/model/framework/spark_model.py +326 -0
- ads/model/framework/tensorflow_model.py +254 -0
- ads/model/framework/xgboost_model.py +258 -0
- ads/model/generic_model.py +3518 -0
- ads/model/model_artifact_boilerplate/README.md +381 -0
- ads/model/model_artifact_boilerplate/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
- ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
- ads/model/model_artifact_boilerplate/score.py +61 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_introspect.py +331 -0
- ads/model/model_metadata.py +1810 -0
- ads/model/model_metadata_mixin.py +460 -0
- ads/model/model_properties.py +63 -0
- ads/model/model_version_set.py +739 -0
- ads/model/runtime/__init__.py +5 -0
- ads/model/runtime/env_info.py +306 -0
- ads/model/runtime/model_deployment_details.py +37 -0
- ads/model/runtime/model_provenance_details.py +58 -0
- ads/model/runtime/runtime_info.py +81 -0
- ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
- ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
- ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
- ads/model/runtime/utils.py +201 -0
- ads/model/serde/__init__.py +5 -0
- ads/model/serde/common.py +40 -0
- ads/model/serde/model_input.py +547 -0
- ads/model/serde/model_serializer.py +1184 -0
- ads/model/service/__init__.py +5 -0
- ads/model/service/oci_datascience_model.py +1076 -0
- ads/model/service/oci_datascience_model_deployment.py +500 -0
- ads/model/service/oci_datascience_model_version_set.py +176 -0
- ads/model/transformer/__init__.py +5 -0
- ads/model/transformer/onnx_transformer.py +324 -0
- ads/mysqldb/__init__.py +5 -0
- ads/mysqldb/mysql_db.py +227 -0
- ads/opctl/__init__.py +18 -0
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/__init__.py +5 -0
- ads/opctl/backend/ads_dataflow.py +353 -0
- ads/opctl/backend/ads_ml_job.py +710 -0
- ads/opctl/backend/ads_ml_pipeline.py +164 -0
- ads/opctl/backend/ads_model_deployment.py +209 -0
- ads/opctl/backend/base.py +146 -0
- ads/opctl/backend/local.py +1053 -0
- ads/opctl/backend/marketplace/__init__.py +9 -0
- ads/opctl/backend/marketplace/helm_helper.py +173 -0
- ads/opctl/backend/marketplace/local_marketplace.py +271 -0
- ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
- ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
- ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
- ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
- ads/opctl/backend/marketplace/models/__init__.py +5 -0
- ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
- ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
- ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
- ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
- ads/opctl/cli.py +707 -0
- ads/opctl/cmds.py +869 -0
- ads/opctl/conda/__init__.py +5 -0
- ads/opctl/conda/cli.py +193 -0
- ads/opctl/conda/cmds.py +749 -0
- ads/opctl/conda/config.yaml +34 -0
- ads/opctl/conda/manifest_template.yaml +13 -0
- ads/opctl/conda/multipart_uploader.py +188 -0
- ads/opctl/conda/pack.py +89 -0
- ads/opctl/config/__init__.py +5 -0
- ads/opctl/config/base.py +57 -0
- ads/opctl/config/diagnostics/__init__.py +5 -0
- ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
- ads/opctl/config/merger.py +255 -0
- ads/opctl/config/resolver.py +297 -0
- ads/opctl/config/utils.py +79 -0
- ads/opctl/config/validator.py +17 -0
- ads/opctl/config/versioner.py +68 -0
- ads/opctl/config/yaml_parsers/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/base.py +58 -0
- ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
- ads/opctl/constants.py +66 -0
- ads/opctl/decorator/__init__.py +5 -0
- ads/opctl/decorator/common.py +129 -0
- ads/opctl/diagnostics/__init__.py +5 -0
- ads/opctl/diagnostics/__main__.py +25 -0
- ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
- ads/opctl/diagnostics/check_requirements.py +144 -0
- ads/opctl/diagnostics/requirement_exception.py +9 -0
- ads/opctl/distributed/README.md +109 -0
- ads/opctl/distributed/__init__.py +5 -0
- ads/opctl/distributed/certificates.py +32 -0
- ads/opctl/distributed/cli.py +207 -0
- ads/opctl/distributed/cmds.py +731 -0
- ads/opctl/distributed/common/__init__.py +5 -0
- ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
- ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
- ads/opctl/distributed/common/cluster_config_helper.py +103 -0
- ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
- ads/opctl/distributed/common/cluster_runner.py +54 -0
- ads/opctl/distributed/common/framework_factory.py +29 -0
- ads/opctl/docker/Dockerfile.job +103 -0
- ads/opctl/docker/Dockerfile.job.arm +107 -0
- ads/opctl/docker/Dockerfile.job.gpu +175 -0
- ads/opctl/docker/base-env.yaml +13 -0
- ads/opctl/docker/cuda.repo +6 -0
- ads/opctl/docker/operator/.dockerignore +0 -0
- ads/opctl/docker/operator/Dockerfile +41 -0
- ads/opctl/docker/operator/Dockerfile.gpu +85 -0
- ads/opctl/docker/operator/cuda.repo +6 -0
- ads/opctl/docker/operator/environment.yaml +8 -0
- ads/opctl/forecast.py +11 -0
- ads/opctl/index.yaml +3 -0
- ads/opctl/model/__init__.py +5 -0
- ads/opctl/model/cli.py +65 -0
- ads/opctl/model/cmds.py +73 -0
- ads/opctl/operator/README.md +4 -0
- ads/opctl/operator/__init__.py +31 -0
- ads/opctl/operator/cli.py +344 -0
- ads/opctl/operator/cmd.py +596 -0
- ads/opctl/operator/common/__init__.py +5 -0
- ads/opctl/operator/common/backend_factory.py +460 -0
- ads/opctl/operator/common/const.py +27 -0
- ads/opctl/operator/common/data/synthetic.csv +16001 -0
- ads/opctl/operator/common/dictionary_merger.py +148 -0
- ads/opctl/operator/common/errors.py +42 -0
- ads/opctl/operator/common/operator_config.py +99 -0
- ads/opctl/operator/common/operator_loader.py +811 -0
- ads/opctl/operator/common/operator_schema.yaml +130 -0
- ads/opctl/operator/common/operator_yaml_generator.py +152 -0
- ads/opctl/operator/common/utils.py +208 -0
- ads/opctl/operator/lowcode/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
- ads/opctl/operator/lowcode/anomaly/README.md +207 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +167 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
- ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +116 -0
- ads/opctl/operator/lowcode/common/errors.py +47 -0
- ads/opctl/operator/lowcode/common/transformations.py +296 -0
- ads/opctl/operator/lowcode/common/utils.py +384 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
- ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
- ads/opctl/operator/lowcode/forecast/README.md +209 -0
- ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
- ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
- ads/opctl/operator/lowcode/forecast/const.py +92 -0
- ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
- ads/opctl/operator/lowcode/forecast/errors.py +26 -0
- ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
- ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
- ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
- ads/opctl/operator/lowcode/forecast/model/prophet.py +450 -0
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
- ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
- ads/opctl/operator/lowcode/forecast/utils.py +397 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
- ads/opctl/operator/lowcode/pii/MLoperator +17 -0
- ads/opctl/operator/lowcode/pii/README.md +208 -0
- ads/opctl/operator/lowcode/pii/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/__main__.py +78 -0
- ads/opctl/operator/lowcode/pii/cmd.py +39 -0
- ads/opctl/operator/lowcode/pii/constant.py +84 -0
- ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
- ads/opctl/operator/lowcode/pii/errors.py +27 -0
- ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
- ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
- ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
- ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
- ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
- ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
- ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
- ads/opctl/operator/lowcode/pii/model/report.py +487 -0
- ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
- ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
- ads/opctl/operator/lowcode/pii/utils.py +43 -0
- ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
- ads/opctl/operator/lowcode/recommender/README.md +206 -0
- ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
- ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
- ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
- ads/opctl/operator/lowcode/recommender/constant.py +30 -0
- ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
- ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
- ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
- ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
- ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
- ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
- ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
- ads/opctl/operator/lowcode/recommender/utils.py +13 -0
- ads/opctl/operator/runtime/__init__.py +5 -0
- ads/opctl/operator/runtime/const.py +17 -0
- ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
- ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
- ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/runtime.py +115 -0
- ads/opctl/schema.yaml.yml +36 -0
- ads/opctl/script.py +40 -0
- ads/opctl/spark/__init__.py +5 -0
- ads/opctl/spark/cli.py +43 -0
- ads/opctl/spark/cmds.py +147 -0
- ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
- ads/opctl/utils.py +344 -0
- ads/oracledb/__init__.py +5 -0
- ads/oracledb/oracle_db.py +346 -0
- ads/pipeline/__init__.py +39 -0
- ads/pipeline/ads_pipeline.py +2279 -0
- ads/pipeline/ads_pipeline_run.py +772 -0
- ads/pipeline/ads_pipeline_step.py +605 -0
- ads/pipeline/builders/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/custom_script.py +32 -0
- ads/pipeline/cli.py +119 -0
- ads/pipeline/extension.py +291 -0
- ads/pipeline/schema/__init__.py +5 -0
- ads/pipeline/schema/cs_step_schema.json +35 -0
- ads/pipeline/schema/ml_step_schema.json +31 -0
- ads/pipeline/schema/pipeline_schema.json +71 -0
- ads/pipeline/visualizer/__init__.py +5 -0
- ads/pipeline/visualizer/base.py +570 -0
- ads/pipeline/visualizer/graph_renderer.py +272 -0
- ads/pipeline/visualizer/text_renderer.py +84 -0
- ads/secrets/__init__.py +11 -0
- ads/secrets/adb.py +386 -0
- ads/secrets/auth_token.py +86 -0
- ads/secrets/big_data_service.py +365 -0
- ads/secrets/mysqldb.py +149 -0
- ads/secrets/oracledb.py +160 -0
- ads/secrets/secrets.py +407 -0
- ads/telemetry/__init__.py +7 -0
- ads/telemetry/base.py +69 -0
- ads/telemetry/client.py +122 -0
- ads/telemetry/telemetry.py +257 -0
- ads/templates/dataflow_pyspark.jinja2 +13 -0
- ads/templates/dataflow_sparksql.jinja2 +22 -0
- ads/templates/func.jinja2 +20 -0
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score-pkl.jinja2 +173 -0
- ads/templates/score.jinja2 +322 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- ads/templates/score_generic.jinja2 +165 -0
- ads/templates/score_huggingface_pipeline.jinja2 +217 -0
- ads/templates/score_lightgbm.jinja2 +185 -0
- ads/templates/score_onnx.jinja2 +407 -0
- ads/templates/score_onnx_new.jinja2 +473 -0
- ads/templates/score_oracle_automl.jinja2 +185 -0
- ads/templates/score_pyspark.jinja2 +154 -0
- ads/templates/score_pytorch.jinja2 +219 -0
- ads/templates/score_scikit-learn.jinja2 +184 -0
- ads/templates/score_tensorflow.jinja2 +184 -0
- ads/templates/score_xgboost.jinja2 +178 -0
- ads/text_dataset/__init__.py +5 -0
- ads/text_dataset/backends.py +211 -0
- ads/text_dataset/dataset.py +445 -0
- ads/text_dataset/extractor.py +207 -0
- ads/text_dataset/options.py +53 -0
- ads/text_dataset/udfs.py +22 -0
- ads/text_dataset/utils.py +49 -0
- ads/type_discovery/__init__.py +9 -0
- ads/type_discovery/abstract_detector.py +21 -0
- ads/type_discovery/constant_detector.py +41 -0
- ads/type_discovery/continuous_detector.py +54 -0
- ads/type_discovery/credit_card_detector.py +99 -0
- ads/type_discovery/datetime_detector.py +92 -0
- ads/type_discovery/discrete_detector.py +118 -0
- ads/type_discovery/document_detector.py +146 -0
- ads/type_discovery/ip_detector.py +68 -0
- ads/type_discovery/latlon_detector.py +90 -0
- ads/type_discovery/phone_number_detector.py +63 -0
- ads/type_discovery/type_discovery_driver.py +87 -0
- ads/type_discovery/typed_feature.py +594 -0
- ads/type_discovery/unknown_detector.py +41 -0
- ads/type_discovery/zipcode_detector.py +48 -0
- ads/vault/__init__.py +7 -0
- ads/vault/vault.py +237 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/METADATA +150 -149
- oracle_ads-2.13.10rc0.dist-info/RECORD +858 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/WHEEL +1 -2
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/entry_points.txt +2 -1
- oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
- oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/licenses/LICENSE.txt +0 -0
ads/hpo/search_cv.py
ADDED
@@ -0,0 +1,1657 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*--
|
3
|
+
|
4
|
+
# Copyright (c) 2020, 2023 Oracle and/or its affiliates.
|
5
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
|
7
|
+
import importlib
|
8
|
+
import multiprocessing
|
9
|
+
import os
|
10
|
+
import uuid
|
11
|
+
import psutil
|
12
|
+
from enum import Enum, auto
|
13
|
+
from time import time, sleep
|
14
|
+
|
15
|
+
import matplotlib.pyplot as plt
|
16
|
+
import numpy as np
|
17
|
+
import pandas as pd
|
18
|
+
|
19
|
+
import logging
|
20
|
+
from ads.common import logger
|
21
|
+
from ads.common import utils
|
22
|
+
from ads.common.data import ADSData
|
23
|
+
from ads.common.decorator.runtime_dependency import (
|
24
|
+
runtime_dependency,
|
25
|
+
OptionalDependency,
|
26
|
+
)
|
27
|
+
from ads.hpo._imports import try_import
|
28
|
+
from ads.hpo.ads_search_space import get_model2searchspace
|
29
|
+
from ads.hpo.distributions import *
|
30
|
+
from ads.hpo.objective import _Objective
|
31
|
+
from ads.hpo.stopping_criterion import NTrials, ScoreValue, TimeBudget
|
32
|
+
from ads.hpo.utils import _num_samples, _safe_indexing, _update_space_name
|
33
|
+
from ads.hpo.validation import (
|
34
|
+
assert_is_estimator,
|
35
|
+
assert_model_is_supported,
|
36
|
+
assert_strategy_valid,
|
37
|
+
assert_tuner_is_fitted,
|
38
|
+
validate_fit_params,
|
39
|
+
validate_pipeline,
|
40
|
+
validate_search_space,
|
41
|
+
validate_params_for_plot,
|
42
|
+
)
|
43
|
+
|
44
|
+
|
45
|
+
with try_import() as _imports:
|
46
|
+
from sklearn.base import BaseEstimator, clone, is_classifier
|
47
|
+
from sklearn.model_selection import BaseCrossValidator # NOQA
|
48
|
+
from sklearn.model_selection import check_cv, cross_validate
|
49
|
+
from sklearn.pipeline import Pipeline, make_pipeline
|
50
|
+
from sklearn.utils import check_random_state
|
51
|
+
from sklearn.exceptions import NotFittedError
|
52
|
+
|
53
|
+
try:
|
54
|
+
from sklearn.metrics import check_scoring
|
55
|
+
except:
|
56
|
+
from sklearn.metrics.scorer import check_scoring
|
57
|
+
|
58
|
+
|
59
|
+
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union # NOQA
|
60
|
+
|
61
|
+
|
62
|
+
class State(Enum):
    """Lifecycle states of an :class:`ADSTuner` tuning process."""

    # Tuner object created; tuning has not started yet.
    INITIATED = auto()
    # A tuning process is actively running trials.
    RUNNING = auto()
    # Tuning is paused and may be resumed.
    HALTED = auto()
    # Tuning was stopped and cannot be resumed.
    TERMINATED = auto()
    # Tuning finished normally (an exit criterion was met).
    COMPLETED = auto()
|
68
|
+
|
69
|
+
|
70
|
+
class InvalidStateTransition(Exception):  # pragma: no cover
    """Raised when an illegal state transition is requested, for example
    calling ``halt`` when there is no running tuning process.
    """
|
77
|
+
|
78
|
+
|
79
|
+
class ExitCriterionError(Exception):  # pragma: no cover
    """Raised when an exit-status query does not match the exit criterion
    the tuner was configured with. For example, asking a tuner whose study
    exits after a fixed number of trials how much *time* remains queries a
    different exit criterion, so this exception is raised.
    """
|
88
|
+
|
89
|
+
|
90
|
+
class DuplicatedStudyError(Exception):  # pragma: no cover
    """Raised when a new tuner process is started with a study name that is
    already present in the storage backend.
    """
|
95
|
+
|
96
|
+
|
97
|
+
class NoRestartError(Exception):  # pragma: no cover
    """Raised when querying how many seconds have passed since the HPO
    process last resumed from a halt, but no such resume exists — either
    the process was terminated or it was never halted and resumed at all.
    """
|
105
|
+
|
106
|
+
|
107
|
+
class DataScienceObjective:
    """Picklable wrapper binding a dataset to an objective callable.

    Local functions and lambdas cannot be pickled, which breaks
    multiprocessing; this module-level class is a picklable substitute for
    ``lambda trial: objective(X_res, y_res, trial)``.
    """

    def __init__(self, objective, X_res, y_res):
        # Attribute names are kept stable: they form the pickled state.
        self.objective = objective
        self.X_res = X_res
        self.y_res = y_res

    def __call__(self, trial):
        # Delegate to the wrapped objective with the bound data.
        objective_fn = self.objective
        return objective_fn(self.X_res, self.y_res, trial)
|
117
|
+
|
118
|
+
|
119
|
+
class ADSTuner(BaseEstimator):
|
120
|
+
"""
|
121
|
+
Hyperparameter search with cross-validation.
|
122
|
+
"""
|
123
|
+
|
124
|
+
_required_parameters = ["model"]
|
125
|
+
|
126
|
+
@property
def sklearn_steps(self):
    """Best candidate parameters keyed with their pipeline step prefix.

    Returns
    -------
    dict
        Search space which corresponds to the best candidate parameter
        setting, with each name prefixed by its sklearn pipeline step.
    """
    best = self.best_params
    return _update_space_name(best, step_name=self._step_name)
|
135
|
+
|
136
|
+
@property
def best_index(self):
    """Index of the best trial.

    Returns
    -------
    int
        Index which corresponds to the best candidate parameter setting
        (the trial with the highest value).
    """
    scores = self.trials["value"]
    return scores.idxmax()
|
145
|
+
|
146
|
+
@property
def best_params(self):
    """Parameters of the best trial.

    Returns
    -------
    Dict[str, Any]
        Hyperparameter names (pipeline step prefixes stripped) mapped to
        the values of the best trial.
    """
    self._check_is_fitted()
    raw_params = self._study.best_params
    return self._remove_step_name(raw_params)
|
156
|
+
|
157
|
+
@property
def best_score(self):
    """Mean cross-validated score of the best estimator.

    Returns
    -------
    float
        Best value observed by the underlying study.
    """
    self._check_is_fitted()
    study = self._study
    return study.best_value
|
167
|
+
|
168
|
+
@property
def score_remaining(self):
    """Gap between the target (optimal) score and the best score so far.

    Returns
    -------
    float
        The difference between the optimal score and the best score.

    Raises
    ------
    :class:`ExitCriterionError`
        Raised if the tuner has no score-based exit condition.
    """
    if self._optimal_score is None:
        raise ExitCriterionError(
            "Tuner does not have a score-based exit condition."
        )
    return self._optimal_score - self.best_score
|
187
|
+
|
188
|
+
@property
def scoring_name(self):
    """
    Returns
    -------
    str
        Scoring name, derived from the configured ``scoring`` via
        ``_extract_scoring_name()``.
    """
    return self._extract_scoring_name()
|
197
|
+
|
198
|
+
@property
def n_trials(self):
    """Number of completed trials.

    Returns
    -------
    int
        Number of completed trials. Alias for `trial_count`.
    """
    self._check_is_fitted()
    completed = self.trials
    return len(completed)

# Alias for n_trials
trial_count = n_trials
|
211
|
+
|
212
|
+
@property
def trials_remaining(self):
    """Trials left in a trials-based exit condition's budget.

    Returns
    -------
    int
        The number of trials remaining in the budget.

    Raises
    ------
    :class:`ExitCriterionError`
        Raised if the current tuner does not include a trials-based exit
        condition.
    """
    if self._n_trials is None:
        raise ExitCriterionError(
            "This tuner does not include a trials-based exit condition"
        )
    # Budget is the configured trial count plus trials carried over from a
    # previous run, minus what has completed so far.
    remaining = self._n_trials + self._previous_trial_count - self.n_trials
    return remaining
|
231
|
+
|
232
|
+
@property
def trials(self):
    """Trial data collected up to this point.

    Returns
    -------
    :class:`pandas.DataFrame`
        Trial data. While the tuner is halted the cached snapshot is
        returned (empty frame if none was captured); otherwise a fresh
        copy is pulled from the underlying study.
    """
    if self.is_halted():
        cached = self._trial_dataframe
        return pd.DataFrame() if cached is None else cached
    return self._study.trials_dataframe().copy()
|
246
|
+
|
247
|
+
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def __init__(
    self,
    model,  # type: Union[BaseEstimator, Pipeline]
    strategy="perfunctory",  # type: Union[str, Mapping[str, optuna.distributions.BaseDistribution]]
    scoring=None,  # type: Optional[Union[Callable[..., float], str]]
    cv=5,  # type: Optional[int]
    study_name=None,  # type: Optional[str]
    storage=None,  # type: Optional[str]
    load_if_exists=True,  # type: Optional[bool]
    random_state=None,  # type: Optional[int]
    loglevel=logging.INFO,  # type: Optional[int]
    n_jobs=1,  # type: Optional[int]
    X=None,  # type: Union[List[List[float]], np.ndarray, pd.DataFrame, spmatrix, ADSData]
    y=None,  # type: Optional[Union[OneDimArrayLikeType, TwoDimArrayLikeType]]
):
    # type: (...) -> None
    """
    Returns a hyperparameter tuning object.

    Parameters
    ----------
    model:
        Object to use to fit the data. This is assumed to implement the
        scikit-learn estimator or pipeline interface.
    strategy:
        ``perfunctory``, ``detailed`` or a dictionary/mapping of
        hyperparameter and its distribution. If :obj:`perfunctory`, picks
        a few relatively more important hyperparameters to tune. If
        :obj:`detailed`, extends to a larger search space. If :obj:`dict`,
        user defined search space: dictionary where keys are
        hyperparameters and values are distributions implementing the ads
        distribution interface.
    scoring: Optional[Union[Callable[..., float], str]]
        String or callable to evaluate the predictions on the validation
        data. If :obj:`None`, ``score`` on the estimator is used.
    cv: int
        Number of folds in the CV splitter (must be at least 2). If the
        estimator is a classifier and :obj:`y` is binary or multiclass,
        ``sklearn.model_selection.StratifiedKFold`` is used, otherwise
        ``sklearn.model_selection.KFold``.
    study_name: str
        Name of the current experiment for the ADSTuner object. One
        ADSTuner object can only be attached to one study_name.
    storage:
        Database URL (e.g. sqlite:///example.db). Defaults to
        sqlite:////tmp/hpo_*.db.
    load_if_exists:
        Flag to control the behavior when a study named ``study_name``
        already exists in ``storage``: a :class:`DuplicatedStudyError` is
        raised if set to :obj:`False`, otherwise the existing study is
        reused.
    random_state:
        Seed of the pseudo random number generator. If int, this is the
        seed used by the random number generator. If :obj:`None`, the
        global random state from ``numpy.random`` is used.
    loglevel:
        Log level: logging.NOTSET, logging.INFO, logging.DEBUG or
        logging.WARNING.
    n_jobs: int
        Number of parallel jobs. :obj:`-1` means using all processors.
    X: TwoDimArrayLikeType
        Training data.
    y: Union[OneDimArrayLikeType, TwoDimArrayLikeType], optional
        Target.

    Example::

        from ads.hpo.stopping_criterion import *
        from ads.hpo.search_cv import ADSTuner
        from sklearn.datasets import load_iris
        from sklearn.svm import SVC

        tuner = ADSTuner(
            SVC(),
            strategy='detailed',
            scoring='f1_weighted',
            random_state=42
        )

        X, y = load_iris(return_X_y=True)
        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
    """
    _imports.check()
    self._n_jobs = n_jobs
    assert (
        cv > 1
    ), "k-fold cross-validation requires at least one train/test split by setting cv=2 or more"
    self.cv = cv
    self._error_score = np.nan
    self.model = model
    self._check_pipeline()
    self._step_name = None
    self._extract_estimator()
    # Strategy validation compares against the previous strategy/space, so
    # both are reset before the check and re-assigned after it passes.
    self.strategy = None
    self._param_distributions = None
    self._check_strategy(strategy)
    self.strategy = strategy
    self._param_distributions = self._get_param_distributions(self.strategy)
    self._enable_pruning = hasattr(self.model, "partial_fit")
    self._max_iter = 100
    self.__random_state = random_state  # raw value, re-used in export_trials
    # check_random_state turns the raw seed into a np.random.RandomState,
    # which is hard to serialize -- hence the raw copy above.
    self.random_state = check_random_state(random_state)

    self._return_train_score = False
    self.scoring = scoring
    self._subsample = 1.0
    self.loglevel = loglevel
    self._trial_dataframe = None
    self._status = State.INITIATED
    self.study_name = (
        study_name if study_name is not None else "hpo_" + str(uuid.uuid4())
    )
    self.storage = (
        "sqlite:////tmp/hpo_" + str(uuid.uuid4()) + ".db"
        if storage is None
        else storage
    )
    self.oci_client = None

    # BUG FIX: derive the sampler seed from the tuner's own RandomState so
    # that passing ``random_state`` makes the TPE sampler reproducible.
    # Previously the seed was drawn from the *global* numpy RNG
    # (np.random.randint), which ignored ``random_state`` entirely.
    seed = self.random_state.randint(0, np.iinfo("int32").max)

    self.sampler = optuna.samplers.TPESampler(seed=seed)
    self.median_pruner = self._pruner(
        class_name="median_pruner",
        n_startup_trials=5,
        n_warmup_steps=1,
        interval_steps=1,
    )
    self.load_if_exists = load_if_exists
    try:
        self._study = optuna.study.create_study(
            study_name=self.study_name,
            direction="maximize",
            pruner=self.median_pruner,
            sampler=self.sampler,
            storage=self.storage,
            load_if_exists=self.load_if_exists,
        )
    except optuna.exceptions.DuplicatedStudyError as e:
        if self.load_if_exists:
            logger.info(
                "Using an existing study with name '{}' instead of "
                "creating a new one.".format(self.study_name)
            )
        else:
            # Chain the original optuna error for easier debugging.
            raise DuplicatedStudyError(
                f"The study_name `{self.study_name}` exists in the {self.storage}. Either set load_if_exists=True, or use a new study_name."
            ) from e
    self._init_data(X, y)
|
401
|
+
|
402
|
+
def search_space(self, strategy=None, overwrite=False):
    """Get or modify the tuner's search space.

    Without ``strategy`` the existing search space is returned. With
    ``strategy``, the existing space is replaced when ``overwrite`` is
    True, and merged (updated) otherwise.

    Parameters
    ----------
    strategy: Union[str, dict], optional
        ``perfunctory``, ``detailed`` or a mapping of hyperparameters to
        their distributions. ``perfunctory`` picks a few relatively more
        important hyperparameters to tune; ``detailed`` extends to a
        larger search space; a dict is a user-defined space whose values
        implement the ads distribution interface.
    overwrite: bool, optional
        Ignored when ``strategy`` is None. Otherwise the search space is
        overwritten if True and updated in place if False.

    Returns
    -------
    dict
        A mapping of the hyperparameters and their distributions.

    Example::

        from ads.hpo.stopping_criterion import *
        from ads.hpo.search_cv import ADSTuner
        from sklearn.datasets import load_iris
        from sklearn.linear_model import SGDClassifier

        tuner = ADSTuner(
            SGDClassifier(),
            strategy='detailed',
            scoring='f1_weighted',
            random_state=42
        )
        tuner.search_space({'max_iter': 100})
        X, y = load_iris(return_X_y=True)
        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
        tuner.search_space()
    """
    assert hasattr(
        self, "_param_distributions"
    ), "Call <code>ADSTuner</code> first."
    if strategy:
        self._check_strategy(strategy)
        self.strategy = strategy
        fresh_space = self._get_param_distributions(self.strategy)
        if overwrite:
            self._param_distributions = fresh_space
        else:
            self._param_distributions.update(fresh_space)
    return self._remove_step_name(self._param_distributions)
|
458
|
+
|
459
|
+
@staticmethod
def _remove_step_name(param_distributions):
    """Strip the leading pipeline-step prefix from each parameter name.

    Parameters whose names carry a sklearn pipeline prefix
    (``"<step>__<param>"``) are re-keyed by the parameter name alone;
    names without a prefix are kept unchanged.

    Parameters
    ----------
    param_distributions: dict
        Mapping of (possibly prefixed) parameter names to distributions.

    Returns
    -------
    dict
        Mapping of unprefixed parameter names to distributions.
    """
    search_space = {}
    for param, distribution in param_distributions.items():
        if "__" in param:
            # BUG FIX: split on the FIRST "__" only, so nested parameter
            # names keep their trailing segments. split("__")[1] used to
            # collapse "step__sub__param" to just "sub".
            param = param.split("__", 1)[1]
        search_space[param] = distribution
    return search_space
|
467
|
+
|
468
|
+
def _check_pipeline(self):
    # Validate (and possibly normalize) the model via
    # ads.hpo.validation.validate_pipeline; note the result is
    # re-assigned to self.model.
    self.model = validate_pipeline(self.model)
|
470
|
+
|
471
|
+
def _get_internal_param_distributions(self, strategy):
    """Build the search space for a named strategy.

    For a Pipeline, each supported step's search space is generated and
    the last supported step wins (its name is remembered in
    ``self._step_name`` for key prefixing). For a bare estimator, the
    model class itself must be supported.

    Parameters
    ----------
    strategy: str
        Name of a predefined strategy (``perfunctory`` or ``detailed``).

    Returns
    -------
    dict
        Mapping of (possibly step-prefixed) hyperparameter names to
        distributions; empty when nothing is tunable.
    """
    # BUG FIX: initialize up front so a Pipeline without any supported
    # step returns an empty space instead of raising UnboundLocalError
    # at the _check_search_space call below.
    param_distributions = {}
    if isinstance(self.model, Pipeline):
        model2searchspace = get_model2searchspace()
        for step_name, step in self.model.steps:
            if step.__class__ in model2searchspace:
                self._step_name = step_name
                param_distributions = model2searchspace[step.__class__](
                    strategy
                ).suggest_space(step_name=step_name)
        if len(param_distributions) == 0:
            logger.warning("Nothing to tune.")
    else:
        assert_model_is_supported(self.model)
        param_distributions = get_model2searchspace()[self.model.__class__](
            strategy
        ).suggest_space()
    self._check_search_space(param_distributions)
    return param_distributions
|
488
|
+
|
489
|
+
def _get_param_distributions(self, strategy):
    """Resolve ``strategy`` into a validated search space.

    Parameters
    ----------
    strategy: Union[str, dict]
        A predefined strategy name or a user mapping of hyperparameters
        to distributions.

    Returns
    -------
    dict
        Validated search space.

    Raises
    ------
    ValueError
        If ``strategy`` is neither a str nor a dict.
    """
    if isinstance(strategy, str):
        # Named strategy: the helper validates the space itself, so no
        # second _check_search_space call is needed here.
        return self._get_internal_param_distributions(strategy)
    if isinstance(strategy, dict):
        param_distributions = _update_space_name(
            strategy, step_name=self._step_name
        )
        self._check_search_space(param_distributions)
        return param_distributions
    # BUG FIX: an unsupported strategy type previously fell through to an
    # UnboundLocalError; fail with a clear message instead.
    raise ValueError(
        f"`strategy` must be a str or a dict, got {type(strategy).__name__}."
    )
|
498
|
+
|
499
|
+
def _check_search_space(self, param_distributions):
    # Validate the proposed space against the model's actual
    # hyperparameter names (see ads.hpo.validation.validate_search_space).
    validate_search_space(self.model.get_params().keys(), param_distributions)
|
501
|
+
|
502
|
+
def _check_is_fitted(self):
    # Delegates to ads.hpo.validation; presumably raises when tune() has
    # not produced results yet -- confirm against assert_tuner_is_fitted.
    assert_tuner_is_fitted(self)
|
504
|
+
|
505
|
+
def _check_strategy(self, strategy):
    # Validate the requested strategy against the current search space and
    # the previously configured strategy (see ads.hpo.validation).
    assert_strategy_valid(self._param_distributions, strategy, self.strategy)
|
507
|
+
|
508
|
+
def _add_halt_time(self):
    """Open a new halt window in the halt/resume log.

    Appends an entry recording the current time as the halt time, with
    the resume time left unset until `_add_resume_time` closes the
    window. (NOTE(review): the previous docstring described the
    *resume* case; the code records a halt timestamp — presumably this
    is invoked when the tuning process is halted.)
    """
    self._time_log.append(dict(halt=time(), resume=None))
|
513
|
+
|
514
|
+
def _add_resume_time(self):
    """Close the most recent halt window with the current time.

    Sets the ``resume`` timestamp on the last entry of the halt/resume
    log. Does nothing when the log is empty.

    Raises
    ------
    Exception
        If the last window is already closed (its resume time is set).
    """
    if not self._time_log:
        return
    entry = self._time_log[-1]
    # BUG FIX: validate BEFORE mutating the log -- the old code popped
    # the entry first, so raising corrupted the log -- and report the
    # actual problem (the window is already closed, not missing an
    # opening time).
    if entry["resume"] is not None:
        raise Exception("Cannot close a time window that is already closed.")
    self._time_log[-1] = dict(halt=entry["halt"], resume=time())
|
523
|
+
|
524
|
+
def tune(
    self,
    X=None,  # type: TwoDimArrayLikeType
    y=None,  # type: Optional[Union[OneDimArrayLikeType, TwoDimArrayLikeType]]
    exit_criterion=None,  # type: Optional[list]
    loglevel=None,  # type: Optional[int]
    synchronous=False,  # type: Optional[boolean]
):
    """
    Run hyperparameter tuning until one of the <code>exit_criterion</code>
    is met. The default is to run 50 trials.

    Parameters
    ----------
    X: TwoDimArrayLikeType, Union[List[List[float]], np.ndarray, pd.DataFrame, spmatrix, ADSData]
        Training data.
    y: Union[OneDimArrayLikeType, TwoDimArrayLikeType], optional
        OneDimArrayLikeType: Union[List[float], np.ndarray, pd.Series]
        TwoDimArrayLikeType: Union[List[List[float]], np.ndarray, pd.DataFrame, spmatrix, ADSData]

        Target.
    exit_criterion: list, optional
        A list of ads stopping criterion. Can be `ScoreValue()`, `NTrials()`, `TimeBudget()`.
        For example, [ScoreValue(0.96), NTrials(40), TimeBudget(10)]. It will exit when any of the
        stopping criterion is satisfied in the `exit_criterion` list.
        By default, the run will stop after 50 trials.
    loglevel: int, optional
        Log level.
    synchronous: boolean, optional
        Tune synchronously or not. Defaults to `False`

    Returns
    -------
    None
        Nothing

    Example::

        from ads.hpo.stopping_criterion import *
        from ads.hpo.search_cv import ADSTuner
        from sklearn.datasets import load_iris
        from sklearn.svm import SVC

        tuner = ADSTuner(
            SVC(),
            strategy='detailed',
            scoring='f1_weighted',
            random_state=42
        )
        tuner.search_space({'max_iter': 100})
        X, y = load_iris(return_X_y=True)
        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
    """
    # The default was previously a shared mutable list (`exit_criterion=[]`);
    # resolve it per call to avoid state leaking between invocations.
    exit_criterion = [] if exit_criterion is None else exit_criterion

    # Get previous trial count to ensure proper counting.
    try:
        self._previous_trial_count = self.trial_count
    except NotFittedError:
        # First tune() on a fresh tuner: no trials exist yet.
        self._previous_trial_count = 0
    except Exception as e:
        # NOTE(review): `_logger` is presumably a module-level logger —
        # confirm it is defined; other methods in this class use `logger`.
        _logger.error(f"Error retrieving previous trial count: {e}")
        raise

    self._init_data(X, y)
    if self.X is None:
        raise ValueError(
            "Need to either pass the data to `X` and `y` in `tune()`, or to `ADSTuner`."
        )
    if self.is_running():
        raise InvalidStateTransition(
            "Running process found. Do you need to call terminate() to stop before calling tune()?"
        )
    if self.is_halted():
        raise InvalidStateTransition(
            "Halted process found. You need to call resume()."
        )

    # Initialize time log for every new call to tune(). The start/stop
    # timestamps are shared with the tuning subprocess via
    # multiprocessing.Value so both sides see the same clock.
    self._global_start = multiprocessing.Value("d", 0.0)
    self._global_stop = multiprocessing.Value("d", 0.0)
    self._time_log = []

    self._tune(
        X=self.X,
        y=self.y,
        exit_criterion=exit_criterion,
        loglevel=loglevel,
        synchronous=synchronous,
    )

    # Tune cannot exit before the clock starts in the subprocess.
    while self._global_start.value == 0.0:
        sleep(0.01)
|
619
|
+
|
620
|
+
def _init_data(self, X, y):
    """Store training data on the tuner, unpacking an ADSData bundle.

    A `None` X leaves any previously stored data untouched.
    """
    if X is None:
        return
    if isinstance(X, ADSData):
        # ADSData carries both features and target.
        self.X, self.y = X.X, X.y
    else:
        self.X, self.y = X, y
|
628
|
+
|
629
|
+
def halt(self):
    """
    Halt the current running tuning process.

    Returns
    -------
    None
        Nothing

    Raises
    ------
    `InvalidStateTransition` if no running process is found

    Example::

        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
        tuner.halt()
    """
    halt_possible = hasattr(self, "_tune_process") and self._status == State.RUNNING
    if not halt_possible:
        raise InvalidStateTransition(
            "No running process found. Do you need to call tune()?"
        )
    # Snapshot the trial results before suspending the worker process.
    self._trial_dataframe = self._study.trials_dataframe().copy()
    psutil.Process(self._tune_process.pid).suspend()
    self._status = State.HALTED
    self._add_halt_time()
|
669
|
+
|
670
|
+
def resume(self):
    """
    Resume the current halted tuning process.

    Returns
    -------
    None
        Nothing

    Example::

        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
        tuner.halt()
        tuner.resume()
    """
    if not self.is_halted():
        raise InvalidStateTransition("No paused process found.")
    # Wake the suspended worker process, then close the halt window.
    psutil.Process(self._tune_process.pid).resume()
    self._add_resume_time()
    self._status = State.RUNNING
|
704
|
+
|
705
|
+
def wait(self):
    """
    Block until the current tuning process finishes running.

    Returns
    -------
    None
        Nothing

    Example::

        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
        tuner.wait()
    """
    if not self.is_running():
        raise InvalidStateTransition("No running process.")
    self._tune_process.join()
    self._status = State.COMPLETED
|
737
|
+
|
738
|
+
def terminate(self):
    """
    Terminate the current tuning process.

    Returns
    -------
    None
        Nothing

    Example::

        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
        tuner.terminate()
    """
    if not self.is_running():
        # NOTE(review): raises RuntimeError while halt()/wait() raise
        # InvalidStateTransition — confirm whether the asymmetry is intended.
        raise RuntimeError("No running process found. Do you need to call tune()?")
    self._tune_process.terminate()
    self._tune_process.join()
    self._status = State.TERMINATED
    # Mark trials that were mid-flight when the process died as failed.
    self._update_failed_trial_state()
|
773
|
+
|
774
|
+
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def _update_failed_trial_state(self):
    """Mark any trial still in RUNNING state as FAIL in study storage.

    Called after the tuning subprocess is terminated so interrupted
    trials are not left permanently marked as running.
    """
    from optuna.trial import TrialState

    for trial in self._study.trials:
        if trial.state == TrialState.RUNNING:
            # Use optuna.trial.TrialState here: the legacy
            # `optuna.structs` alias used previously is deprecated and
            # removed in recent Optuna releases.
            self._study._storage.set_trial_state(
                trial._trial_id, TrialState.FAIL
            )
|
783
|
+
|
784
|
+
@property
def time_remaining(self):
    """Returns the number of seconds remaining in the study

    Returns
    -------
    int: Number of seconds remaining in the budget. 0 if complete/terminated

    Raises
    ------
    :class:`ExitCriterionError`
        Error is raised if time has not been included in the budget.
    """
    if self._time_budget is None:
        raise ExitCriterionError(
            "This tuner does not include a time-based exit condition"
        )
    if self.is_completed() or self.is_terminated():
        return 0
    remaining = self._time_budget - self.time_elapsed
    # Clamp at zero so an overshot budget never reports negative time.
    return remaining if remaining > 0 else 0
|
804
|
+
|
805
|
+
@property
def time_since_resume(self):
    """Return the seconds since the process has been resumed from a halt.

    Returns
    -------
    int: the number of seconds since the process was last resumed

    Raises
    ------
    `NoRestartError` is the process has not been resumed
    """
    if not self._time_log:
        raise Exception("Time log should not be empty")
    last_time_resumed = self._time_log[-1].get("resume")

    if self.is_running():
        if last_time_resumed is None:
            raise NoRestartError("The process has not been resumed")
        return time() - last_time_resumed
    if self.is_halted():
        # While halted, no time has elapsed since a resume.
        return 0
    if self.is_terminated():
        raise NoRestartError("The process has been terminated")
    # NOTE(review): a completed process falls through and implicitly
    # returns None — confirm whether that is intended.
|
832
|
+
|
833
|
+
@property
def time_elapsed(self):
    """Return the time in seconds that the HPO process has been searching

    Returns
    -------
    int: The number of seconds the HPO process has been searching
    """
    halted_total = 0.0

    # Sum the closed halt windows (time spent suspended). An open
    # window (resume is None) means we are halted right now.
    for window in self._time_log:
        halted_at = window.get("halt")
        resumed_at = window.get("resume")
        if resumed_at is None:
            # Currently halted: elapsed time froze at the halt instant.
            return halted_at - self._global_start.value - halted_total
        halted_total += resumed_at - halted_at

    # All halts were resumed. A nonzero global stop means the tuning
    # subprocess has already exited; otherwise use the wall clock.
    if self._global_stop.value != 0:
        wall_clock = self._global_stop.value - self._global_start.value
    else:
        wall_clock = time() - self._global_start.value

    return wall_clock - halted_total
|
867
|
+
|
868
|
+
def best_scores(self, n: int = 5, reverse: bool = True):
    """Return the best scores from the study

    Parameters
    ----------
    n: int
        The maximum number of results to show. Defaults to 5. If `None` or
        negative return all.
    reverse: bool
        Whether to reverse the sort order so results are in descending order.
        Defaults to `True`

    Returns
    -------
    list[float or int]
        List of the best scores

    Raises
    ------
    `ValueError` if there are no trials
    """
    if len(self.trials) < 1:
        raise ValueError("No score data to show")
    scores = self.trials.value
    scores = scores[scores.notnull()]
    # The filtered result is a (possibly empty) Series, never None, so
    # the previous `scores is None` test could never fire; check for
    # emptiness instead, as the error message intends.
    if len(scores) == 0:
        raise ValueError(
            f"No score data despite valid trial data. Trial data length: {len(self.trials)}"
        )
    ranked = sorted(scores, reverse=reverse)
    # A non-int or non-positive n means "return all scores".
    if not isinstance(n, int) or n <= 0:
        return ranked
    return ranked[:n]
|
902
|
+
|
903
|
+
def get_status(self):
    """
    Return the status of the current tuning process.

    Alias for the property `status`.

    Returns
    -------
    :class:`Status`
        The status of the process

    Example::

        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
        tuner.get_status()
    """
    return self.status
|
933
|
+
|
934
|
+
def is_running(self):
    """
    Returns
    -------
    bool
        `True` if the :class:`ADSTuner` instance is running; `False` otherwise.
    """
    # Compare against the derived `status` property, not `_status`,
    # so subprocess liveness is taken into account.
    return self.status == State.RUNNING
|
942
|
+
|
943
|
+
def is_halted(self):
    """
    Returns
    -------
    bool
        `True` if the :class:`ADSTuner` instance is halted; `False` otherwise.
    """
    return self.status == State.HALTED
|
951
|
+
|
952
|
+
def is_terminated(self):
    """
    Returns
    -------
    bool
        `True` if the :class:`ADSTuner` instance has been terminated; `False` otherwise.
    """
    return self.status == State.TERMINATED
|
960
|
+
|
961
|
+
def is_completed(self):
    """
    Returns
    -------
    bool
        `True` if the :class:`ADSTuner` instance has completed; `False` otherwise.
    """
    return self.status == State.COMPLETED
|
969
|
+
|
970
|
+
def _is_tuning_started(self):
    """
    Returns
    -------
    bool
        `True` if the :class:`ADSTuner` instance has been started (for
        example, halted or running); `False` otherwise.
    """
    return self.status in (State.HALTED, State.RUNNING)
|
979
|
+
|
980
|
+
def _is_tuning_finished(self):
    """
    Returns
    -------
    bool
        `True` if the :class:`ADSTuner` instance is finished running
        (i.e. completed or terminated); `False` otherwise.
    """
    return self.status in (State.COMPLETED, State.TERMINATED)
|
989
|
+
|
990
|
+
@property
def status(self):
    """
    Returns
    -------
    :class:`Status`
        The status of the current tuning process.
    """
    # Halted/terminated/initiated are tracked explicitly; otherwise the
    # state is derived from whether the tuning subprocess is alive.
    if self._status in (State.HALTED, State.TERMINATED, State.INITIATED):
        return self._status
    if hasattr(self, "_tune_process") and self._tune_process.is_alive():
        return State.RUNNING
    # Removed an unreachable trailing `return self._status` that followed
    # this exhaustive if/elif/else in the previous implementation.
    return State.COMPLETED
|
1009
|
+
|
1010
|
+
def _extract_exit_criterion(self, exit_criterion):
|
1011
|
+
# handle the exit criterion
|
1012
|
+
self._time_budget = None
|
1013
|
+
self._n_trials = None
|
1014
|
+
self.exit_criterion = []
|
1015
|
+
self._optimal_score = None
|
1016
|
+
if exit_criterion is None or len(exit_criterion) == 0:
|
1017
|
+
self._n_trials = 50
|
1018
|
+
for i, criteria in enumerate(exit_criterion):
|
1019
|
+
if isinstance(criteria, TimeBudget):
|
1020
|
+
self._time_budget = criteria()
|
1021
|
+
elif isinstance(criteria, NTrials):
|
1022
|
+
self._n_trials = criteria()
|
1023
|
+
elif isinstance(criteria, ScoreValue):
|
1024
|
+
self._optimal_score = criteria.score
|
1025
|
+
self.exit_criterion.append(criteria)
|
1026
|
+
else:
|
1027
|
+
raise NotImplementedError(
|
1028
|
+
"``{}`` is not supported!".format(criteria.__class__.__name__)
|
1029
|
+
)
|
1030
|
+
|
1031
|
+
def _extract_estimator(self):
    """Locate the estimator to tune, unwrapping a Pipeline if needed."""
    if isinstance(self.model, Pipeline):
        # Scan all steps; the last estimator-like step wins.
        for step_name, step in self.model.steps:
            if self._is_estimator(step):
                self._step_name = step_name
                self.estimator = step
    else:
        self.estimator = self.model
    assert_is_estimator(self.estimator)
|
1042
|
+
|
1043
|
+
def _extract_scoring_name(self):
|
1044
|
+
if isinstance(self.scoring, str):
|
1045
|
+
return self.scoring
|
1046
|
+
if not callable(self._scorer):
|
1047
|
+
return (
|
1048
|
+
self._scorer
|
1049
|
+
if isinstance(self._scorer, str)
|
1050
|
+
else str(self._scorer).split("(")[1].split(")")[0]
|
1051
|
+
)
|
1052
|
+
else:
|
1053
|
+
if is_classifier(self.model):
|
1054
|
+
return "mean accuracy"
|
1055
|
+
else:
|
1056
|
+
return "r2"
|
1057
|
+
|
1058
|
+
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def _set_logger(self, loglevel, class_name):
    """Set the verbosity of the named logging backend (optuna only)."""
    if loglevel is not None:
        self.loglevel = loglevel
    if class_name != "optuna":
        raise NotImplementedError("{} is not supported.".format(class_name))
    optuna.logging.set_verbosity(self.loglevel)
|
1066
|
+
|
1067
|
+
def _set_sample_indices(self, X, random_state):
    """Choose (and sort) the subsample of row indices used for tuning.

    A float `_subsample` is interpreted as a fraction of the dataset;
    an int as an absolute sample count.
    """
    n_samples = _num_samples(X)
    max_samples = self._subsample
    if isinstance(max_samples, float):
        max_samples = int(max_samples * n_samples)

    self._sample_indices = np.arange(n_samples)
    if max_samples < n_samples:
        self._sample_indices = random_state.choice(
            self._sample_indices, max_samples, replace=False
        )
    # Keep indices ordered so slicing preserves the original row order.
    self._sample_indices.sort()
|
1081
|
+
|
1082
|
+
def _get_fit_params_res(self, X):
    """Build fit params restricted to the subsampled rows.

    Starts from an empty fit-params dict; the previous `is not None`
    guard on a freshly-created dict was always true, so validation
    always runs.
    """
    fit_params = {}
    return validate_fit_params(X, fit_params, self._sample_indices)
|
1089
|
+
|
1090
|
+
def _can_tune(self):
    """Verify the tuner is initialized and has a usable search space.

    Warns (but does not fail) on an empty search space; raises when no
    search space could be derived at all.
    """
    assert hasattr(self, "model"), "Call <code>ADSTuner</code> first."
    if self._param_distributions == {}:
        logger.warning("Nothing to tune.")

    # None means no built-in space exists for this model type.
    if self._param_distributions is None:
        raise NotImplementedError(
            "There was no model specified or the model is not supported."
        )
|
1099
|
+
|
1100
|
+
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def _tune(
    self,
    X,  # type: TwoDimArrayLikeType
    y,  # type: Optional[Union[OneDimArrayLikeType, TwoDimArrayLikeType]]
    exit_criterion=[],  # type: Optional[list]
    loglevel=None,  # type: Optional[int]
    synchronous=False,  # type: Optional[boolean]
):
    # type: (...) -> tuple
    """
    Tune with all sets of parameters.
    """
    # Validate preconditions and unpack configuration before launching
    # the worker process.
    self._can_tune()
    self._set_logger(loglevel=loglevel, class_name="optuna")
    self._extract_exit_criterion(exit_criterion)
    self._extract_estimator()
    random_state = self.random_state
    old_level = logger.getEffectiveLevel()
    logger.setLevel(self.loglevel)
    if not synchronous:
        # Asynchronous runs suppress logging in the parent; output from
        # the subprocess would interleave confusingly otherwise.
        optuna.logging.set_verbosity(optuna.logging.ERROR)
        logger.setLevel(logging.ERROR)

    # Subsample the data (per self._subsample) before tuning.
    self._set_sample_indices(X, random_state)
    X_res = _safe_indexing(X, self._sample_indices)
    y_res = _safe_indexing(y, self._sample_indices)
    groups_res = _safe_indexing(None, self._sample_indices)
    fit_params_res = self._get_fit_params_res(X)

    classifier = is_classifier(self.model)
    cv = check_cv(self.cv, y_res, classifier=classifier)
    self._n_splits = cv.get_n_splits(X_res, y_res, groups=groups_res)

    # scoring
    self._scorer = check_scoring(self.estimator, scoring=self.scoring)

    # Create (or load) the study in the parent so self._study is usable
    # for live inspection; the subprocess re-opens the same storage.
    self._study = optuna.study.create_study(
        study_name=self.study_name,
        direction="maximize",
        pruner=self.median_pruner,
        sampler=self.sampler,
        storage=self.storage,
        load_if_exists=self.load_if_exists,
    )
    objective = _Objective(
        self.model,
        self._param_distributions,
        cv,
        self._enable_pruning,
        self._error_score,
        fit_params_res,
        groups_res,
        self._max_iter,
        self._return_train_score,
        self._scorer,
        self.scoring_name,
        self._step_name,
    )

    if synchronous:
        logger.info(
            "Optimizing hyperparameters using {} "
            "samples...".format(_num_samples(self._sample_indices))
        )

    # Run the optimization in a separate process so it can be halted,
    # resumed, and terminated independently of the caller.
    self._tune_process = multiprocessing.Process(
        target=ADSTuner.optimizer,
        args=(
            self.study_name,
            self.median_pruner,
            self.sampler,
            self.storage,
            self.load_if_exists,
            DataScienceObjective(objective, X_res, y_res),
            self._global_start,
            self._global_stop,
        ),
        kwargs=dict(
            n_jobs=self._n_jobs,
            n_trials=self._n_trials,
            timeout=self._time_budget,
            show_progress_bar=False,
            callbacks=self.exit_criterion,
            gc_after_trial=False,
        ),
    )

    self._tune_process.start()
    self._status = State.RUNNING

    if synchronous:
        # Block until the subprocess finishes before returning.
        self._tune_process.join()
        logger.info("Finished hyperparemeter search!")
        self._status = State.COMPLETED

    logger.setLevel(old_level)
|
1197
|
+
|
1198
|
+
@staticmethod
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def optimizer(
    study_name,
    pruner,
    sampler,
    storage,
    load_if_exists,
    objective_func,
    global_start,
    global_stop,
    **kwargs,
):
    """
    Static method for running ADSTuner tuning process

    Parameters
    ----------
    study_name: str
        The name of the study.
    pruner
        The pruning method for pruning trials.
    sampler
        The sampling method used for tuning.
    storage: str
        Storage endpoint.
    load_if_exists: bool
        Load existing study if it exists.
    objective_func
        The objective function to be maximized.
    global_start: :class:`multiprocesing.Value`
        The global start time.
    global_stop: :class:`multiprocessing.Value`
        The global stop time.
    kwargs: dict
        Keyword/value pairs passed into the optimize process

    Raises
    ------
    :class:`Exception`
        Raised for any exceptions thrown by the underlying optimization process

    Returns
    -------
    None
        Nothing
    """
    import traceback

    # Re-open the study from shared storage inside the worker process.
    study = optuna.study.create_study(
        study_name=study_name,
        direction="maximize",
        pruner=pruner,
        sampler=sampler,
        storage=storage,
        load_if_exists=load_if_exists,
    )
    try:
        # Publish the start/stop timestamps through the shared Values so
        # the parent process can compute elapsed time.
        global_start.value = time()
        study.optimize(objective_func, **kwargs)
        global_stop.value = time()
    except Exception as e:
        # Print the traceback in the subprocess (it would otherwise be
        # lost) before propagating the failure.
        traceback.print_exc()
        raise e
|
1264
|
+
|
1265
|
+
@staticmethod
|
1266
|
+
def _is_estimator(step):
|
1267
|
+
return hasattr(step, "fit") and (
|
1268
|
+
not hasattr(step, "transform")
|
1269
|
+
or hasattr(step, "predict")
|
1270
|
+
or hasattr(step, "fit_predict")
|
1271
|
+
)
|
1272
|
+
|
1273
|
+
@staticmethod
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def _pruner(class_name, **kwargs):
    """Factory for optuna pruners; only 'median_pruner' is supported."""
    if class_name != "median_pruner":
        raise NotImplementedError("{} is not supported.".format(class_name))
    return optuna.pruners.MedianPruner(**kwargs)
|
1280
|
+
|
1281
|
+
def trials_export(self, file_uri, metadata=None, script_dict=None):
    """Export the meta data as well as files needed to reconstruct the ADSTuner object to the object storage.
    Data is not stored. To resume the same ADSTuner object from object storage and continue tuning from previous trials,
    you have to provide the dataset.

    Parameters
    ----------
    file_uri: str
        Object storage path, 'oci://bucketname@namespace/filepath/on/objectstorage'. For example,
        `oci://test_bucket@ociodsccust/tuner/test.zip`
    metadata: str, optional
        User defined metadata
    script_dict: dict, optional
        Script paths for model and scoring. This is only recommended for unsupported
        models and user-defined scoring functions. You can store the model and scoring
        function in a dictionary with keys `model` and `scoring` and the respective
        paths as values. The model and scoring scripts must import necessary libraries
        for the script to run. The ``model`` and ``scoring`` variables must be set to
        your model and scoring function. Defaults to
        ``{"model": None, "scoring": None}``.

    Returns
    -------
    None
        Nothing

    Example::

        # Print out a list of supported models
        from ads.hpo.ads_search_space import model_list
        print(model_list)

        # Example scoring dictionary
        {'model':'/home/datascience/advanced-ds/notebooks/scratch/ADSTunerV2/mymodel.py',
        'scoring':'/home/datascience/advanced-ds/notebooks/scratch/ADSTunerV2/customized_scoring.py'}

    Example::

        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)], synchronous=True)
        tuner.trials_export('oci://<bucket_name>@<namespace>/tuner/test.zip')
    """
    # oci://bucketname@namespace/filename
    from ads.hpo.tuner_artifact import UploadTunerArtifact

    # Resolve the default inside the call: the previous dict-literal
    # default was a shared mutable object across invocations.
    if script_dict is None:
        script_dict = {"model": None, "scoring": None}

    assert self._is_tuning_finished()
    assert script_dict.keys() <= set(
        ["model", "scoring"]
    ), "script_dict keys can only be model and scoring."

    UploadTunerArtifact(self, file_uri, metadata).upload(script_dict)
|
1345
|
+
|
1346
|
+
@classmethod
def trials_import(cls, file_uri, delete_zip_file=True, target_file_path=None):
    """Import the database file from the object storage

    Parameters
    ----------
    file_uri: str
        'oci://bucketname@namespace/filepath/on/objectstorage'
        Example: 'oci://<bucket_name>@<namespace>/tuner/test.zip'
    delete_zip_file: bool, defaults to True, optional
        Whether delete the zip file afterwards.
    target_file_path: str, optional
        The path where the zip file will be saved. For example, '/home/datascience/myfile.zip'.

    Returns
    -------
    :class:`ADSTuner`
        ADSTuner object

    Examples
    --------
    >>> from ads.hpo.search_cv import ADSTuner
    >>> from sklearn.datasets import load_iris
    >>> X, y = load_iris(return_X_y=True)
    >>> tuner = ADSTuner.trials_import('oci://<bucket_name>@<namespace>/tuner/test.zip')
    >>> tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)], synchronous=True)
    """
    from ads.hpo.tuner_artifact import DownloadTunerArtifact

    artifact = DownloadTunerArtifact(
        file_uri, target_file_path=target_file_path
    )
    # NOTE: the extracted metadata is stored on the class, not the instance.
    tuner_args, cls.metadata = artifact.extract_tuner_args(
        delete_zip_file=delete_zip_file
    )
    return cls(**tuner_args)
|
1381
|
+
|
1382
|
+
@runtime_dependency(module="IPython", install_from=OptionalDependency.NOTEBOOK)
def _plot(
    self,  # type: ADSTuner
    plot_module,  # type: str
    plot_func,  # type: str
    time_interval=0.5,  # type: float
    fig_size=(800, 500),  # type: tuple
    **kwargs,
):
    # Live-plotting driver: dynamically loads the requested visualization
    # module and re-renders the plot while the tuning process is running,
    # then draws a final frame once it finishes.
    if fig_size:
        logger.warning(
            "The param fig_size will be depreciated in future releases.",
        )

    # Load the plot module by file path from the sibling "visualization"
    # package so the heavy plotting deps are only imported on demand.
    spec = importlib.util.spec_from_file_location(
        "plot",
        os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "visualization",
            plot_module + ".py",
        ),
    )
    plot = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(plot)

    _imports.check()
    assert self._study is not None, "Need to call <code>.tune()</code> first."
    ntrials = 0
    if plot_func == "_plot_param_importances":
        # Importance estimation needs several completed trials first.
        print("Waiting for more trials before evaluating the param importance.")
    while self.status == State.RUNNING:
        import time
        from IPython.display import clear_output

        time.sleep(time_interval)
        # Only redraw when a new trial with a score has landed.
        if len(self.trials[~self.trials["value"].isnull()]) > ntrials:
            if plot_func == "_plot_param_importances":
                # Param-importance plots need at least 4 scored trials.
                if len(self.trials[~self.trials["value"].isnull()]) >= 4:
                    clear_output(wait=True)
                    getattr(plot, plot_func)(
                        study=self._study, fig_size=fig_size, **kwargs
                    )
                    clear_output(wait=True)
            else:
                getattr(plot, plot_func)(
                    study=self._study, fig_size=fig_size, **kwargs
                )
                clear_output(wait=True)
                if len(self.trials) == 0:
                    # No trials yet: show an empty placeholder figure.
                    plt.figure()
                    plt.title("Intermediate Values Plot")
                    plt.xlabel("Step")
                    plt.ylabel("Intermediate Value")
                    plt.show(block=False)

            ntrials = len(self.trials[~self.trials["value"].isnull()])
    # Final render after the tuning process has stopped.
    getattr(plot, plot_func)(study=self._study, fig_size=fig_size, **kwargs)
|
1439
|
+
|
1440
|
+
def plot_best_scores(
    self,
    best=True,  # type: bool
    inferior=True,  # type: bool
    time_interval=1,  # type: float
    fig_size=(800, 500),  # type: tuple
):
    """Plot the optimization history of every trial in the study.

    Parameters
    ----------
    best: bool
        Whether to draw the running-best score lines.
    inferior: bool
        Whether to draw the dots for the raw objective scores.
    time_interval: float
        Seconds between refreshes while polling for new trial results.
    fig_size: tuple
        Width and height of the figure.

    Returns
    -------
    None
        Nothing.
    """
    # Collect the options first so the dispatch to the shared plotting
    # helper stays on one conceptual line.
    history_kwargs = {
        "time_interval": time_interval,
        "fig_size": fig_size,
        "best": best,
        "inferior": inferior,
    }
    self._plot(
        "_optimization_history",
        "_get_optimization_history_plot",
        **history_kwargs,
    )
|
1473
|
+
|
1474
|
+
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def plot_param_importance(
    self,
    importance_evaluator="Fanova",  # type: str
    time_interval=1,  # type: float
    fig_size=(800, 500),  # type: tuple
):
    """Plot hyperparameter importances.

    Parameters
    ----------
    importance_evaluator: str
        Importance evaluator. Valid values: "Fanova", "MeanDecreaseImpurity".
        Defaults to "Fanova".
    time_interval: float
        How often the plot refresh to check on the new trial results.
    fig_size: tuple
        Width and height of the figure.

    Raises
    ------
    :class:`NotImplementedError`
        Raised for unsupported importance evaluators.

    Returns
    -------
    None
        Nothing.
    """
    # Validate up front and raise the documented exception type.
    # The previous implementation used `assert` (stripped under `python -O`)
    # and `raise NotImplemented(...)` — `NotImplemented` is a constant, not an
    # exception class, so that raise would itself fail with a TypeError.
    if importance_evaluator == "Fanova":
        # None lets optuna fall back to its default FanovaImportanceEvaluator.
        evaluator = None
    elif importance_evaluator == "MeanDecreaseImpurity":
        evaluator = optuna.importance.MeanDecreaseImpurityImportanceEvaluator()
    else:
        raise NotImplementedError(
            f"{importance_evaluator} is not supported. It can be either `Fanova` or `MeanDecreaseImpurity`."
        )
    try:
        self._plot(
            plot_module="_param_importances",
            plot_func="_plot_param_importances",
            time_interval=time_interval,
            fig_size=fig_size,
            evaluator=evaluator,
        )
    except Exception:
        # Importance estimation needs enough completed trials; report the
        # likely remedy instead of surfacing an opaque traceback.
        # (Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate.)
        logger.error(
            msg="""Cannot calculate the hyperparameter importance. Increase the number of trials or time budget. """
        )
|
1527
|
+
|
1528
|
+
def plot_intermediate_scores(
    self,
    time_interval=1,  # type: float
    fig_size=(800, 500),  # type: tuple
):
    """Plot the intermediate values of every trial in the study.

    Parameters
    ----------
    time_interval: float
        Time interval for the plot. Defaults to 1.
    fig_size: tuple[int, int]
        Figure size. Defaults to (800, 500).

    Returns
    -------
    None
        Nothing.
    """
    # Intermediate values are only recorded when pruning ran during tuning;
    # warn the user, but still attempt the plot as before.
    if not self._enable_pruning:
        no_pruning_msg = (
            "Pruning was not used during tuning. "
            "There are no intermediate values to plot."
        )
        logger.error(msg=no_pruning_msg)

    self._plot(
        plot_module="_intermediate_values",
        plot_func="_get_intermediate_plot",
        time_interval=time_interval,
        fig_size=fig_size,
    )
|
1560
|
+
|
1561
|
+
def plot_edf_scores(
    self,
    time_interval=1,  # type: float
    fig_size=(800, 500),  # type: tuple
):
    """Plot the EDF (empirical distribution function) of the scores.

    Only completed trials are used.

    Parameters
    ----------
    time_interval: float
        Time interval for the plot. Defaults to 1.
    fig_size: tuple[int, int]
        Figure size. Defaults to (800, 500).

    Returns
    -------
    None
        Nothing.
    """
    # Delegate straight to the shared plotting helper with explicit keywords.
    self._plot(
        plot_module="_edf",
        plot_func="_get_edf_plot",
        time_interval=time_interval,
        fig_size=fig_size,
    )
|
1586
|
+
|
1587
|
+
def plot_contour_scores(
    self,
    params=None,  # type: Optional[List[str]]
    time_interval=1,  # type: float
    fig_size=(800, 500),  # type: tuple
):
    """Contour plot of the scores.

    Parameters
    ----------
    params: Optional[List[str]]
        Parameter list to visualize. Defaults to all.
    time_interval: float
        Time interval for the plot. Defaults to 1.
    fig_size: tuple[int, int]
        Figure size. Defaults to (800, 500).

    Returns
    -------
    None
        Nothing.
    """
    # Reject unknown parameter names before touching the plotting machinery.
    validate_params_for_plot(params, self._param_distributions)
    try:
        self._plot(
            plot_module="_contour",
            plot_func="_get_contour_plot",
            time_interval=time_interval,
            fig_size=fig_size,
            params=params,
        )
    except ValueError:
        # A contour needs enough completed trials to render; keep this
        # best-effort and tell the user how to get a plot next time.
        logger.warning(
            msg="Cannot plot contour score."
            " Increase the number of trials or time budget."
        )
|
1624
|
+
|
1625
|
+
def plot_parallel_coordinate_scores(
    self,
    params=None,  # type: Optional[List[str]]
    time_interval=1,  # type: float
    fig_size=(800, 500),  # type: tuple
):
    """Plot the high-dimensional parameter relationships in a study.

    Note that, if a parameter contains missing values, a trial with missing
    values is not plotted.

    Parameters
    ----------
    params: Optional[List[str]]
        Parameter list to visualize. Defaults to all.
    time_interval: float
        Time interval for the plot. Defaults to 1.
    fig_size: tuple[int, int]
        Figure size. Defaults to (800, 500).

    Returns
    -------
    None
        Nothing.
    """
    # Validate the requested parameter names first, then hand off to the
    # shared plotting helper.
    validate_params_for_plot(params, self._param_distributions)
    self._plot(
        plot_module="_parallel_coordinate",
        plot_func="_get_parallel_coordinate_plot",
        time_interval=time_interval,
        fig_size=fig_size,
        params=params,
    )
|