oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.10rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +40 -0
- ads/aqua/app.py +507 -0
- ads/aqua/cli.py +96 -0
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +836 -0
- ads/aqua/client/openai_client.py +305 -0
- ads/aqua/common/__init__.py +5 -0
- ads/aqua/common/decorator.py +125 -0
- ads/aqua/common/entities.py +274 -0
- ads/aqua/common/enums.py +134 -0
- ads/aqua/common/errors.py +109 -0
- ads/aqua/common/utils.py +1295 -0
- ads/aqua/config/__init__.py +4 -0
- ads/aqua/config/container_config.py +247 -0
- ads/aqua/config/evaluation/__init__.py +4 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
- ads/aqua/config/utils/__init__.py +4 -0
- ads/aqua/config/utils/serializer.py +339 -0
- ads/aqua/constants.py +116 -0
- ads/aqua/data.py +14 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation/__init__.py +8 -0
- ads/aqua/evaluation/constants.py +53 -0
- ads/aqua/evaluation/entities.py +186 -0
- ads/aqua/evaluation/errors.py +70 -0
- ads/aqua/evaluation/evaluation.py +1814 -0
- ads/aqua/extension/__init__.py +42 -0
- ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
- ads/aqua/extension/base_handler.py +90 -0
- ads/aqua/extension/common_handler.py +121 -0
- ads/aqua/extension/common_ws_msg_handler.py +36 -0
- ads/aqua/extension/deployment_handler.py +381 -0
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +30 -0
- ads/aqua/extension/evaluation_handler.py +129 -0
- ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
- ads/aqua/extension/finetune_handler.py +96 -0
- ads/aqua/extension/model_handler.py +390 -0
- ads/aqua/extension/models/__init__.py +0 -0
- ads/aqua/extension/models/ws_models.py +145 -0
- ads/aqua/extension/models_ws_msg_handler.py +50 -0
- ads/aqua/extension/ui_handler.py +300 -0
- ads/aqua/extension/ui_websocket_handler.py +130 -0
- ads/aqua/extension/utils.py +133 -0
- ads/aqua/finetuning/__init__.py +7 -0
- ads/aqua/finetuning/constants.py +23 -0
- ads/aqua/finetuning/entities.py +181 -0
- ads/aqua/finetuning/finetuning.py +749 -0
- ads/aqua/model/__init__.py +8 -0
- ads/aqua/model/constants.py +60 -0
- ads/aqua/model/entities.py +385 -0
- ads/aqua/model/enums.py +32 -0
- ads/aqua/model/model.py +2134 -0
- ads/aqua/model/utils.py +52 -0
- ads/aqua/modeldeployment/__init__.py +6 -0
- ads/aqua/modeldeployment/constants.py +10 -0
- ads/aqua/modeldeployment/deployment.py +1315 -0
- ads/aqua/modeldeployment/entities.py +653 -0
- ads/aqua/modeldeployment/utils.py +543 -0
- ads/aqua/resources/gpu_shapes_index.json +94 -0
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +476 -0
- ads/aqua/ui.py +519 -0
- ads/automl/__init__.py +9 -0
- ads/automl/driver.py +330 -0
- ads/automl/provider.py +975 -0
- ads/bds/__init__.py +5 -0
- ads/bds/auth.py +127 -0
- ads/bds/big_data_service.py +255 -0
- ads/catalog/__init__.py +19 -0
- ads/catalog/model.py +1576 -0
- ads/catalog/notebook.py +461 -0
- ads/catalog/project.py +468 -0
- ads/catalog/summary.py +178 -0
- ads/common/__init__.py +11 -0
- ads/common/analyzer.py +65 -0
- ads/common/artifact/.model-ignore +63 -0
- ads/common/artifact/__init__.py +10 -0
- ads/common/auth.py +1122 -0
- ads/common/card_identifier.py +83 -0
- ads/common/config.py +647 -0
- ads/common/data.py +165 -0
- ads/common/decorator/__init__.py +9 -0
- ads/common/decorator/argument_to_case.py +88 -0
- ads/common/decorator/deprecate.py +69 -0
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/decorator/runtime_dependency.py +178 -0
- ads/common/decorator/threaded.py +97 -0
- ads/common/decorator/utils.py +35 -0
- ads/common/dsc_file_system.py +303 -0
- ads/common/error.py +14 -0
- ads/common/extended_enum.py +81 -0
- ads/common/function/__init__.py +5 -0
- ads/common/function/fn_util.py +142 -0
- ads/common/function/func_conf.yaml +25 -0
- ads/common/ipython.py +76 -0
- ads/common/model.py +679 -0
- ads/common/model_artifact.py +1759 -0
- ads/common/model_artifact_schema.json +107 -0
- ads/common/model_export_util.py +664 -0
- ads/common/model_metadata.py +24 -0
- ads/common/object_storage_details.py +296 -0
- ads/common/oci_client.py +179 -0
- ads/common/oci_datascience.py +46 -0
- ads/common/oci_logging.py +1144 -0
- ads/common/oci_mixin.py +957 -0
- ads/common/oci_resource.py +136 -0
- ads/common/serializer.py +559 -0
- ads/common/utils.py +1852 -0
- ads/common/word_lists.py +1491 -0
- ads/common/work_request.py +189 -0
- ads/config.py +1 -0
- ads/data_labeling/__init__.py +13 -0
- ads/data_labeling/boundingbox.py +253 -0
- ads/data_labeling/constants.py +47 -0
- ads/data_labeling/data_labeling_service.py +244 -0
- ads/data_labeling/interface/__init__.py +5 -0
- ads/data_labeling/interface/loader.py +16 -0
- ads/data_labeling/interface/parser.py +16 -0
- ads/data_labeling/interface/reader.py +23 -0
- ads/data_labeling/loader/__init__.py +5 -0
- ads/data_labeling/loader/file_loader.py +241 -0
- ads/data_labeling/metadata.py +110 -0
- ads/data_labeling/mixin/__init__.py +5 -0
- ads/data_labeling/mixin/data_labeling.py +232 -0
- ads/data_labeling/ner.py +129 -0
- ads/data_labeling/parser/__init__.py +5 -0
- ads/data_labeling/parser/dls_record_parser.py +388 -0
- ads/data_labeling/parser/export_metadata_parser.py +94 -0
- ads/data_labeling/parser/export_record_parser.py +473 -0
- ads/data_labeling/reader/__init__.py +5 -0
- ads/data_labeling/reader/dataset_reader.py +574 -0
- ads/data_labeling/reader/dls_record_reader.py +121 -0
- ads/data_labeling/reader/export_record_reader.py +62 -0
- ads/data_labeling/reader/jsonl_reader.py +75 -0
- ads/data_labeling/reader/metadata_reader.py +203 -0
- ads/data_labeling/reader/record_reader.py +263 -0
- ads/data_labeling/record.py +52 -0
- ads/data_labeling/visualizer/__init__.py +5 -0
- ads/data_labeling/visualizer/image_visualizer.py +525 -0
- ads/data_labeling/visualizer/text_visualizer.py +357 -0
- ads/database/__init__.py +5 -0
- ads/database/connection.py +338 -0
- ads/dataset/__init__.py +10 -0
- ads/dataset/capabilities.md +51 -0
- ads/dataset/classification_dataset.py +339 -0
- ads/dataset/correlation.py +226 -0
- ads/dataset/correlation_plot.py +563 -0
- ads/dataset/dask_series.py +173 -0
- ads/dataset/dataframe_transformer.py +110 -0
- ads/dataset/dataset.py +1979 -0
- ads/dataset/dataset_browser.py +360 -0
- ads/dataset/dataset_with_target.py +995 -0
- ads/dataset/exception.py +25 -0
- ads/dataset/factory.py +987 -0
- ads/dataset/feature_engineering_transformer.py +35 -0
- ads/dataset/feature_selection.py +107 -0
- ads/dataset/forecasting_dataset.py +26 -0
- ads/dataset/helper.py +1450 -0
- ads/dataset/label_encoder.py +99 -0
- ads/dataset/mixin/__init__.py +5 -0
- ads/dataset/mixin/dataset_accessor.py +134 -0
- ads/dataset/pipeline.py +58 -0
- ads/dataset/plot.py +710 -0
- ads/dataset/progress.py +86 -0
- ads/dataset/recommendation.py +297 -0
- ads/dataset/recommendation_transformer.py +502 -0
- ads/dataset/regression_dataset.py +14 -0
- ads/dataset/sampled_dataset.py +1050 -0
- ads/dataset/target.py +98 -0
- ads/dataset/timeseries.py +18 -0
- ads/dbmixin/__init__.py +5 -0
- ads/dbmixin/db_pandas_accessor.py +153 -0
- ads/environment/__init__.py +9 -0
- ads/environment/ml_runtime.py +66 -0
- ads/evaluations/README.md +14 -0
- ads/evaluations/__init__.py +109 -0
- ads/evaluations/evaluation_plot.py +983 -0
- ads/evaluations/evaluator.py +1334 -0
- ads/evaluations/statistical_metrics.py +543 -0
- ads/experiments/__init__.py +9 -0
- ads/experiments/capabilities.md +0 -0
- ads/explanations/__init__.py +21 -0
- ads/explanations/base_explainer.py +142 -0
- ads/explanations/capabilities.md +83 -0
- ads/explanations/explainer.py +190 -0
- ads/explanations/mlx_global_explainer.py +1050 -0
- ads/explanations/mlx_interface.py +386 -0
- ads/explanations/mlx_local_explainer.py +287 -0
- ads/explanations/mlx_whatif_explainer.py +201 -0
- ads/feature_engineering/__init__.py +20 -0
- ads/feature_engineering/accessor/__init__.py +5 -0
- ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
- ads/feature_engineering/accessor/mixin/__init__.py +5 -0
- ads/feature_engineering/accessor/mixin/correlation.py +166 -0
- ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
- ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
- ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
- ads/feature_engineering/accessor/mixin/utils.py +65 -0
- ads/feature_engineering/accessor/series_accessor.py +431 -0
- ads/feature_engineering/adsimage/__init__.py +5 -0
- ads/feature_engineering/adsimage/image.py +192 -0
- ads/feature_engineering/adsimage/image_reader.py +170 -0
- ads/feature_engineering/adsimage/interface/__init__.py +5 -0
- ads/feature_engineering/adsimage/interface/reader.py +19 -0
- ads/feature_engineering/adsstring/__init__.py +7 -0
- ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
- ads/feature_engineering/adsstring/string/__init__.py +8 -0
- ads/feature_engineering/data_schema.json +57 -0
- ads/feature_engineering/dataset/__init__.py +5 -0
- ads/feature_engineering/dataset/zip_code_data.py +42062 -0
- ads/feature_engineering/exceptions.py +40 -0
- ads/feature_engineering/feature_type/__init__.py +133 -0
- ads/feature_engineering/feature_type/address.py +184 -0
- ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
- ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
- ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
- ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
- ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
- ads/feature_engineering/feature_type/adsstring/string.py +258 -0
- ads/feature_engineering/feature_type/base.py +58 -0
- ads/feature_engineering/feature_type/boolean.py +183 -0
- ads/feature_engineering/feature_type/category.py +146 -0
- ads/feature_engineering/feature_type/constant.py +137 -0
- ads/feature_engineering/feature_type/continuous.py +151 -0
- ads/feature_engineering/feature_type/creditcard.py +314 -0
- ads/feature_engineering/feature_type/datetime.py +190 -0
- ads/feature_engineering/feature_type/discrete.py +134 -0
- ads/feature_engineering/feature_type/document.py +43 -0
- ads/feature_engineering/feature_type/gis.py +251 -0
- ads/feature_engineering/feature_type/handler/__init__.py +5 -0
- ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
- ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
- ads/feature_engineering/feature_type/handler/warnings.py +128 -0
- ads/feature_engineering/feature_type/integer.py +142 -0
- ads/feature_engineering/feature_type/ip_address.py +144 -0
- ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
- ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
- ads/feature_engineering/feature_type/lat_long.py +256 -0
- ads/feature_engineering/feature_type/object.py +43 -0
- ads/feature_engineering/feature_type/ordinal.py +132 -0
- ads/feature_engineering/feature_type/phone_number.py +135 -0
- ads/feature_engineering/feature_type/string.py +171 -0
- ads/feature_engineering/feature_type/text.py +93 -0
- ads/feature_engineering/feature_type/unknown.py +43 -0
- ads/feature_engineering/feature_type/zip_code.py +164 -0
- ads/feature_engineering/feature_type_manager.py +406 -0
- ads/feature_engineering/schema.py +795 -0
- ads/feature_engineering/utils.py +245 -0
- ads/feature_store/.readthedocs.yaml +19 -0
- ads/feature_store/README.md +65 -0
- ads/feature_store/__init__.py +9 -0
- ads/feature_store/common/__init__.py +0 -0
- ads/feature_store/common/enums.py +339 -0
- ads/feature_store/common/exceptions.py +18 -0
- ads/feature_store/common/spark_session_singleton.py +125 -0
- ads/feature_store/common/utils/__init__.py +0 -0
- ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
- ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
- ads/feature_store/common/utils/transformation_utils.py +82 -0
- ads/feature_store/common/utils/utility.py +403 -0
- ads/feature_store/data_validation/__init__.py +0 -0
- ads/feature_store/data_validation/great_expectation.py +129 -0
- ads/feature_store/dataset.py +1230 -0
- ads/feature_store/dataset_job.py +530 -0
- ads/feature_store/docs/Dockerfile +7 -0
- ads/feature_store/docs/Makefile +44 -0
- ads/feature_store/docs/conf.py +28 -0
- ads/feature_store/docs/requirements.txt +14 -0
- ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
- ads/feature_store/docs/source/cicd.rst +137 -0
- ads/feature_store/docs/source/conf.py +86 -0
- ads/feature_store/docs/source/data_versioning.rst +33 -0
- ads/feature_store/docs/source/dataset.rst +388 -0
- ads/feature_store/docs/source/dataset_job.rst +27 -0
- ads/feature_store/docs/source/demo.rst +70 -0
- ads/feature_store/docs/source/entity.rst +78 -0
- ads/feature_store/docs/source/feature_group.rst +624 -0
- ads/feature_store/docs/source/feature_group_job.rst +29 -0
- ads/feature_store/docs/source/feature_store.rst +122 -0
- ads/feature_store/docs/source/feature_store_class.rst +123 -0
- ads/feature_store/docs/source/feature_validation.rst +66 -0
- ads/feature_store/docs/source/figures/cicd.png +0 -0
- ads/feature_store/docs/source/figures/data_validation.png +0 -0
- ads/feature_store/docs/source/figures/data_versioning.png +0 -0
- ads/feature_store/docs/source/figures/dataset.gif +0 -0
- ads/feature_store/docs/source/figures/dataset.png +0 -0
- ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
- ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
- ads/feature_store/docs/source/figures/entity.png +0 -0
- ads/feature_store/docs/source/figures/feature_group.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
- ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
- ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
- ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
- ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
- ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
- ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
- ads/feature_store/docs/source/figures/overview.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
- ads/feature_store/docs/source/figures/stats_1.png +0 -0
- ads/feature_store/docs/source/figures/stats_2.png +0 -0
- ads/feature_store/docs/source/figures/stats_d.png +0 -0
- ads/feature_store/docs/source/figures/stats_fg.png +0 -0
- ads/feature_store/docs/source/figures/transformation.png +0 -0
- ads/feature_store/docs/source/figures/transformations.gif +0 -0
- ads/feature_store/docs/source/figures/validation.png +0 -0
- ads/feature_store/docs/source/figures/validation_fg.png +0 -0
- ads/feature_store/docs/source/figures/validation_results.png +0 -0
- ads/feature_store/docs/source/figures/validation_summary.png +0 -0
- ads/feature_store/docs/source/index.rst +81 -0
- ads/feature_store/docs/source/module.rst +8 -0
- ads/feature_store/docs/source/notebook.rst +94 -0
- ads/feature_store/docs/source/overview.rst +47 -0
- ads/feature_store/docs/source/quickstart.rst +176 -0
- ads/feature_store/docs/source/release_notes.rst +194 -0
- ads/feature_store/docs/source/setup_feature_store.rst +81 -0
- ads/feature_store/docs/source/statistics.rst +58 -0
- ads/feature_store/docs/source/transformation.rst +199 -0
- ads/feature_store/docs/source/ui.rst +65 -0
- ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
- ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
- ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
- ads/feature_store/entity.py +718 -0
- ads/feature_store/execution_strategy/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
- ads/feature_store/execution_strategy/engine/__init__.py +0 -0
- ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
- ads/feature_store/execution_strategy/execution_strategy.py +113 -0
- ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
- ads/feature_store/execution_strategy/spark/__init__.py +0 -0
- ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
- ads/feature_store/feature.py +192 -0
- ads/feature_store/feature_group.py +1494 -0
- ads/feature_store/feature_group_expectation.py +346 -0
- ads/feature_store/feature_group_job.py +602 -0
- ads/feature_store/feature_lineage/__init__.py +0 -0
- ads/feature_store/feature_lineage/graphviz_service.py +180 -0
- ads/feature_store/feature_option_details.py +50 -0
- ads/feature_store/feature_statistics/__init__.py +0 -0
- ads/feature_store/feature_statistics/statistics_service.py +99 -0
- ads/feature_store/feature_store.py +699 -0
- ads/feature_store/feature_store_registrar.py +518 -0
- ads/feature_store/input_feature_detail.py +149 -0
- ads/feature_store/mixin/__init__.py +4 -0
- ads/feature_store/mixin/oci_feature_store.py +145 -0
- ads/feature_store/model_details.py +73 -0
- ads/feature_store/query/__init__.py +0 -0
- ads/feature_store/query/filter.py +266 -0
- ads/feature_store/query/generator/__init__.py +0 -0
- ads/feature_store/query/generator/query_generator.py +298 -0
- ads/feature_store/query/join.py +161 -0
- ads/feature_store/query/query.py +403 -0
- ads/feature_store/query/validator/__init__.py +0 -0
- ads/feature_store/query/validator/query_validator.py +57 -0
- ads/feature_store/response/__init__.py +0 -0
- ads/feature_store/response/response_builder.py +68 -0
- ads/feature_store/service/__init__.py +0 -0
- ads/feature_store/service/oci_dataset.py +139 -0
- ads/feature_store/service/oci_dataset_job.py +199 -0
- ads/feature_store/service/oci_entity.py +125 -0
- ads/feature_store/service/oci_feature_group.py +164 -0
- ads/feature_store/service/oci_feature_group_job.py +214 -0
- ads/feature_store/service/oci_feature_store.py +182 -0
- ads/feature_store/service/oci_lineage.py +87 -0
- ads/feature_store/service/oci_transformation.py +104 -0
- ads/feature_store/statistics/__init__.py +0 -0
- ads/feature_store/statistics/abs_feature_value.py +49 -0
- ads/feature_store/statistics/charts/__init__.py +0 -0
- ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
- ads/feature_store/statistics/charts/box_plot.py +148 -0
- ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
- ads/feature_store/statistics/charts/probability_distribution.py +68 -0
- ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
- ads/feature_store/statistics/feature_stat.py +126 -0
- ads/feature_store/statistics/generic_feature_value.py +33 -0
- ads/feature_store/statistics/statistics.py +41 -0
- ads/feature_store/statistics_config.py +101 -0
- ads/feature_store/templates/feature_store_template.yaml +45 -0
- ads/feature_store/transformation.py +499 -0
- ads/feature_store/validation_output.py +57 -0
- ads/hpo/__init__.py +9 -0
- ads/hpo/_imports.py +91 -0
- ads/hpo/ads_search_space.py +439 -0
- ads/hpo/distributions.py +325 -0
- ads/hpo/objective.py +280 -0
- ads/hpo/search_cv.py +1657 -0
- ads/hpo/stopping_criterion.py +75 -0
- ads/hpo/tuner_artifact.py +413 -0
- ads/hpo/utils.py +91 -0
- ads/hpo/validation.py +140 -0
- ads/hpo/visualization/__init__.py +5 -0
- ads/hpo/visualization/_contour.py +23 -0
- ads/hpo/visualization/_edf.py +20 -0
- ads/hpo/visualization/_intermediate_values.py +21 -0
- ads/hpo/visualization/_optimization_history.py +25 -0
- ads/hpo/visualization/_parallel_coordinate.py +169 -0
- ads/hpo/visualization/_param_importances.py +26 -0
- ads/jobs/__init__.py +53 -0
- ads/jobs/ads_job.py +663 -0
- ads/jobs/builders/__init__.py +5 -0
- ads/jobs/builders/base.py +156 -0
- ads/jobs/builders/infrastructure/__init__.py +6 -0
- ads/jobs/builders/infrastructure/base.py +165 -0
- ads/jobs/builders/infrastructure/dataflow.py +1252 -0
- ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
- ads/jobs/builders/infrastructure/utils.py +65 -0
- ads/jobs/builders/runtimes/__init__.py +5 -0
- ads/jobs/builders/runtimes/artifact.py +338 -0
- ads/jobs/builders/runtimes/base.py +325 -0
- ads/jobs/builders/runtimes/container_runtime.py +242 -0
- ads/jobs/builders/runtimes/python_runtime.py +1016 -0
- ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
- ads/jobs/cli.py +104 -0
- ads/jobs/env_var_parser.py +131 -0
- ads/jobs/extension.py +160 -0
- ads/jobs/schema/__init__.py +5 -0
- ads/jobs/schema/infrastructure_schema.json +116 -0
- ads/jobs/schema/job_schema.json +42 -0
- ads/jobs/schema/runtime_schema.json +183 -0
- ads/jobs/schema/validator.py +141 -0
- ads/jobs/serializer.py +296 -0
- ads/jobs/templates/__init__.py +5 -0
- ads/jobs/templates/container.py +6 -0
- ads/jobs/templates/driver_notebook.py +177 -0
- ads/jobs/templates/driver_oci.py +500 -0
- ads/jobs/templates/driver_python.py +48 -0
- ads/jobs/templates/driver_pytorch.py +852 -0
- ads/jobs/templates/driver_utils.py +615 -0
- ads/jobs/templates/hostname_from_env.c +55 -0
- ads/jobs/templates/oci_metrics.py +181 -0
- ads/jobs/utils.py +104 -0
- ads/llm/__init__.py +28 -0
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/v02/client.py +295 -0
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/chain.py +268 -0
- ads/llm/chat_template.py +31 -0
- ads/llm/deploy.py +63 -0
- ads/llm/guardrails/__init__.py +5 -0
- ads/llm/guardrails/base.py +442 -0
- ads/llm/guardrails/huggingface.py +44 -0
- ads/llm/langchain/__init__.py +5 -0
- ads/llm/langchain/plugins/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/llm/langchain/plugins/llms/__init__.py +5 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
- ads/llm/requirements.txt +3 -0
- ads/llm/serialize.py +219 -0
- ads/llm/serializers/__init__.py +0 -0
- ads/llm/serializers/retrieval_qa.py +153 -0
- ads/llm/serializers/runnable_parallel.py +27 -0
- ads/llm/templates/score_chain.jinja2 +155 -0
- ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
- ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
- ads/model/__init__.py +52 -0
- ads/model/artifact.py +573 -0
- ads/model/artifact_downloader.py +254 -0
- ads/model/artifact_uploader.py +267 -0
- ads/model/base_properties.py +238 -0
- ads/model/common/.model-ignore +66 -0
- ads/model/common/__init__.py +5 -0
- ads/model/common/utils.py +142 -0
- ads/model/datascience_model.py +2635 -0
- ads/model/deployment/__init__.py +20 -0
- ads/model/deployment/common/__init__.py +5 -0
- ads/model/deployment/common/utils.py +308 -0
- ads/model/deployment/model_deployer.py +466 -0
- ads/model/deployment/model_deployment.py +1846 -0
- ads/model/deployment/model_deployment_infrastructure.py +671 -0
- ads/model/deployment/model_deployment_properties.py +493 -0
- ads/model/deployment/model_deployment_runtime.py +838 -0
- ads/model/extractor/__init__.py +5 -0
- ads/model/extractor/automl_extractor.py +74 -0
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/extractor/huggingface_extractor.py +88 -0
- ads/model/extractor/keras_extractor.py +84 -0
- ads/model/extractor/lightgbm_extractor.py +93 -0
- ads/model/extractor/model_info_extractor.py +114 -0
- ads/model/extractor/model_info_extractor_factory.py +105 -0
- ads/model/extractor/pytorch_extractor.py +87 -0
- ads/model/extractor/sklearn_extractor.py +112 -0
- ads/model/extractor/spark_extractor.py +89 -0
- ads/model/extractor/tensorflow_extractor.py +85 -0
- ads/model/extractor/xgboost_extractor.py +94 -0
- ads/model/framework/__init__.py +5 -0
- ads/model/framework/automl_model.py +178 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/framework/huggingface_model.py +399 -0
- ads/model/framework/lightgbm_model.py +266 -0
- ads/model/framework/pytorch_model.py +266 -0
- ads/model/framework/sklearn_model.py +250 -0
- ads/model/framework/spark_model.py +326 -0
- ads/model/framework/tensorflow_model.py +254 -0
- ads/model/framework/xgboost_model.py +258 -0
- ads/model/generic_model.py +3518 -0
- ads/model/model_artifact_boilerplate/README.md +381 -0
- ads/model/model_artifact_boilerplate/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
- ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
- ads/model/model_artifact_boilerplate/score.py +61 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_introspect.py +331 -0
- ads/model/model_metadata.py +1810 -0
- ads/model/model_metadata_mixin.py +460 -0
- ads/model/model_properties.py +63 -0
- ads/model/model_version_set.py +739 -0
- ads/model/runtime/__init__.py +5 -0
- ads/model/runtime/env_info.py +306 -0
- ads/model/runtime/model_deployment_details.py +37 -0
- ads/model/runtime/model_provenance_details.py +58 -0
- ads/model/runtime/runtime_info.py +81 -0
- ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
- ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
- ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
- ads/model/runtime/utils.py +201 -0
- ads/model/serde/__init__.py +5 -0
- ads/model/serde/common.py +40 -0
- ads/model/serde/model_input.py +547 -0
- ads/model/serde/model_serializer.py +1184 -0
- ads/model/service/__init__.py +5 -0
- ads/model/service/oci_datascience_model.py +1076 -0
- ads/model/service/oci_datascience_model_deployment.py +500 -0
- ads/model/service/oci_datascience_model_version_set.py +176 -0
- ads/model/transformer/__init__.py +5 -0
- ads/model/transformer/onnx_transformer.py +324 -0
- ads/mysqldb/__init__.py +5 -0
- ads/mysqldb/mysql_db.py +227 -0
- ads/opctl/__init__.py +18 -0
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/__init__.py +5 -0
- ads/opctl/backend/ads_dataflow.py +353 -0
- ads/opctl/backend/ads_ml_job.py +710 -0
- ads/opctl/backend/ads_ml_pipeline.py +164 -0
- ads/opctl/backend/ads_model_deployment.py +209 -0
- ads/opctl/backend/base.py +146 -0
- ads/opctl/backend/local.py +1053 -0
- ads/opctl/backend/marketplace/__init__.py +9 -0
- ads/opctl/backend/marketplace/helm_helper.py +173 -0
- ads/opctl/backend/marketplace/local_marketplace.py +271 -0
- ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
- ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
- ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
- ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
- ads/opctl/backend/marketplace/models/__init__.py +5 -0
- ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
- ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
- ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
- ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
- ads/opctl/cli.py +707 -0
- ads/opctl/cmds.py +869 -0
- ads/opctl/conda/__init__.py +5 -0
- ads/opctl/conda/cli.py +193 -0
- ads/opctl/conda/cmds.py +749 -0
- ads/opctl/conda/config.yaml +34 -0
- ads/opctl/conda/manifest_template.yaml +13 -0
- ads/opctl/conda/multipart_uploader.py +188 -0
- ads/opctl/conda/pack.py +89 -0
- ads/opctl/config/__init__.py +5 -0
- ads/opctl/config/base.py +57 -0
- ads/opctl/config/diagnostics/__init__.py +5 -0
- ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
- ads/opctl/config/merger.py +255 -0
- ads/opctl/config/resolver.py +297 -0
- ads/opctl/config/utils.py +79 -0
- ads/opctl/config/validator.py +17 -0
- ads/opctl/config/versioner.py +68 -0
- ads/opctl/config/yaml_parsers/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/base.py +58 -0
- ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
- ads/opctl/constants.py +66 -0
- ads/opctl/decorator/__init__.py +5 -0
- ads/opctl/decorator/common.py +129 -0
- ads/opctl/diagnostics/__init__.py +5 -0
- ads/opctl/diagnostics/__main__.py +25 -0
- ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
- ads/opctl/diagnostics/check_requirements.py +144 -0
- ads/opctl/diagnostics/requirement_exception.py +9 -0
- ads/opctl/distributed/README.md +109 -0
- ads/opctl/distributed/__init__.py +5 -0
- ads/opctl/distributed/certificates.py +32 -0
- ads/opctl/distributed/cli.py +207 -0
- ads/opctl/distributed/cmds.py +731 -0
- ads/opctl/distributed/common/__init__.py +5 -0
- ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
- ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
- ads/opctl/distributed/common/cluster_config_helper.py +103 -0
- ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
- ads/opctl/distributed/common/cluster_runner.py +54 -0
- ads/opctl/distributed/common/framework_factory.py +29 -0
- ads/opctl/docker/Dockerfile.job +103 -0
- ads/opctl/docker/Dockerfile.job.arm +107 -0
- ads/opctl/docker/Dockerfile.job.gpu +175 -0
- ads/opctl/docker/base-env.yaml +13 -0
- ads/opctl/docker/cuda.repo +6 -0
- ads/opctl/docker/operator/.dockerignore +0 -0
- ads/opctl/docker/operator/Dockerfile +41 -0
- ads/opctl/docker/operator/Dockerfile.gpu +85 -0
- ads/opctl/docker/operator/cuda.repo +6 -0
- ads/opctl/docker/operator/environment.yaml +8 -0
- ads/opctl/forecast.py +11 -0
- ads/opctl/index.yaml +3 -0
- ads/opctl/model/__init__.py +5 -0
- ads/opctl/model/cli.py +65 -0
- ads/opctl/model/cmds.py +73 -0
- ads/opctl/operator/README.md +4 -0
- ads/opctl/operator/__init__.py +31 -0
- ads/opctl/operator/cli.py +344 -0
- ads/opctl/operator/cmd.py +596 -0
- ads/opctl/operator/common/__init__.py +5 -0
- ads/opctl/operator/common/backend_factory.py +460 -0
- ads/opctl/operator/common/const.py +27 -0
- ads/opctl/operator/common/data/synthetic.csv +16001 -0
- ads/opctl/operator/common/dictionary_merger.py +148 -0
- ads/opctl/operator/common/errors.py +42 -0
- ads/opctl/operator/common/operator_config.py +99 -0
- ads/opctl/operator/common/operator_loader.py +811 -0
- ads/opctl/operator/common/operator_schema.yaml +130 -0
- ads/opctl/operator/common/operator_yaml_generator.py +152 -0
- ads/opctl/operator/common/utils.py +208 -0
- ads/opctl/operator/lowcode/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
- ads/opctl/operator/lowcode/anomaly/README.md +207 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +167 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
- ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +116 -0
- ads/opctl/operator/lowcode/common/errors.py +47 -0
- ads/opctl/operator/lowcode/common/transformations.py +296 -0
- ads/opctl/operator/lowcode/common/utils.py +384 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
- ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
- ads/opctl/operator/lowcode/forecast/README.md +209 -0
- ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
- ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
- ads/opctl/operator/lowcode/forecast/const.py +92 -0
- ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
- ads/opctl/operator/lowcode/forecast/errors.py +26 -0
- ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
- ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
- ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
- ads/opctl/operator/lowcode/forecast/model/prophet.py +450 -0
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
- ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
- ads/opctl/operator/lowcode/forecast/utils.py +397 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
- ads/opctl/operator/lowcode/pii/MLoperator +17 -0
- ads/opctl/operator/lowcode/pii/README.md +208 -0
- ads/opctl/operator/lowcode/pii/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/__main__.py +78 -0
- ads/opctl/operator/lowcode/pii/cmd.py +39 -0
- ads/opctl/operator/lowcode/pii/constant.py +84 -0
- ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
- ads/opctl/operator/lowcode/pii/errors.py +27 -0
- ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
- ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
- ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
- ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
- ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
- ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
- ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
- ads/opctl/operator/lowcode/pii/model/report.py +487 -0
- ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
- ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
- ads/opctl/operator/lowcode/pii/utils.py +43 -0
- ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
- ads/opctl/operator/lowcode/recommender/README.md +206 -0
- ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
- ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
- ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
- ads/opctl/operator/lowcode/recommender/constant.py +30 -0
- ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
- ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
- ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
- ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
- ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
- ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
- ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
- ads/opctl/operator/lowcode/recommender/utils.py +13 -0
- ads/opctl/operator/runtime/__init__.py +5 -0
- ads/opctl/operator/runtime/const.py +17 -0
- ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
- ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
- ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/runtime.py +115 -0
- ads/opctl/schema.yaml.yml +36 -0
- ads/opctl/script.py +40 -0
- ads/opctl/spark/__init__.py +5 -0
- ads/opctl/spark/cli.py +43 -0
- ads/opctl/spark/cmds.py +147 -0
- ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
- ads/opctl/utils.py +344 -0
- ads/oracledb/__init__.py +5 -0
- ads/oracledb/oracle_db.py +346 -0
- ads/pipeline/__init__.py +39 -0
- ads/pipeline/ads_pipeline.py +2279 -0
- ads/pipeline/ads_pipeline_run.py +772 -0
- ads/pipeline/ads_pipeline_step.py +605 -0
- ads/pipeline/builders/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/custom_script.py +32 -0
- ads/pipeline/cli.py +119 -0
- ads/pipeline/extension.py +291 -0
- ads/pipeline/schema/__init__.py +5 -0
- ads/pipeline/schema/cs_step_schema.json +35 -0
- ads/pipeline/schema/ml_step_schema.json +31 -0
- ads/pipeline/schema/pipeline_schema.json +71 -0
- ads/pipeline/visualizer/__init__.py +5 -0
- ads/pipeline/visualizer/base.py +570 -0
- ads/pipeline/visualizer/graph_renderer.py +272 -0
- ads/pipeline/visualizer/text_renderer.py +84 -0
- ads/secrets/__init__.py +11 -0
- ads/secrets/adb.py +386 -0
- ads/secrets/auth_token.py +86 -0
- ads/secrets/big_data_service.py +365 -0
- ads/secrets/mysqldb.py +149 -0
- ads/secrets/oracledb.py +160 -0
- ads/secrets/secrets.py +407 -0
- ads/telemetry/__init__.py +7 -0
- ads/telemetry/base.py +69 -0
- ads/telemetry/client.py +122 -0
- ads/telemetry/telemetry.py +257 -0
- ads/templates/dataflow_pyspark.jinja2 +13 -0
- ads/templates/dataflow_sparksql.jinja2 +22 -0
- ads/templates/func.jinja2 +20 -0
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score-pkl.jinja2 +173 -0
- ads/templates/score.jinja2 +322 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- ads/templates/score_generic.jinja2 +165 -0
- ads/templates/score_huggingface_pipeline.jinja2 +217 -0
- ads/templates/score_lightgbm.jinja2 +185 -0
- ads/templates/score_onnx.jinja2 +407 -0
- ads/templates/score_onnx_new.jinja2 +473 -0
- ads/templates/score_oracle_automl.jinja2 +185 -0
- ads/templates/score_pyspark.jinja2 +154 -0
- ads/templates/score_pytorch.jinja2 +219 -0
- ads/templates/score_scikit-learn.jinja2 +184 -0
- ads/templates/score_tensorflow.jinja2 +184 -0
- ads/templates/score_xgboost.jinja2 +178 -0
- ads/text_dataset/__init__.py +5 -0
- ads/text_dataset/backends.py +211 -0
- ads/text_dataset/dataset.py +445 -0
- ads/text_dataset/extractor.py +207 -0
- ads/text_dataset/options.py +53 -0
- ads/text_dataset/udfs.py +22 -0
- ads/text_dataset/utils.py +49 -0
- ads/type_discovery/__init__.py +9 -0
- ads/type_discovery/abstract_detector.py +21 -0
- ads/type_discovery/constant_detector.py +41 -0
- ads/type_discovery/continuous_detector.py +54 -0
- ads/type_discovery/credit_card_detector.py +99 -0
- ads/type_discovery/datetime_detector.py +92 -0
- ads/type_discovery/discrete_detector.py +118 -0
- ads/type_discovery/document_detector.py +146 -0
- ads/type_discovery/ip_detector.py +68 -0
- ads/type_discovery/latlon_detector.py +90 -0
- ads/type_discovery/phone_number_detector.py +63 -0
- ads/type_discovery/type_discovery_driver.py +87 -0
- ads/type_discovery/typed_feature.py +594 -0
- ads/type_discovery/unknown_detector.py +41 -0
- ads/type_discovery/zipcode_detector.py +48 -0
- ads/vault/__init__.py +7 -0
- ads/vault/vault.py +237 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/METADATA +150 -149
- oracle_ads-2.13.10rc0.dist-info/RECORD +858 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/WHEEL +1 -2
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/entry_points.txt +2 -1
- oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
- oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,983 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8; -*-
|
3
|
+
|
4
|
+
# Copyright (c) 2020, 2023 Oracle and/or its affiliates.
|
5
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
|
7
|
+
from __future__ import print_function, absolute_import, division
|
8
|
+
|
9
|
+
import base64
|
10
|
+
from io import BytesIO
|
11
|
+
import matplotlib as mpl
|
12
|
+
import matplotlib.pyplot as plt
|
13
|
+
import matplotlib.lines as mlines
|
14
|
+
from matplotlib.ticker import FormatStrFormatter
|
15
|
+
import numpy as np
|
16
|
+
import math
|
17
|
+
from ads.common import logger
|
18
|
+
from ads.common.decorator.runtime_dependency import (
|
19
|
+
runtime_dependency,
|
20
|
+
OptionalDependency,
|
21
|
+
)
|
22
|
+
import itertools
|
23
|
+
import pandas as pd
|
24
|
+
|
25
|
+
MAX_TITLE_LEN = 20
|
26
|
+
MAX_LEGEND_LEN = 10
|
27
|
+
MAX_PLOTS_PER_ROW = 2
|
28
|
+
# Maximum class number evaluation plotting supporting for multiclass problems
|
29
|
+
MAX_PLOTTING_CLASSES = 10
|
30
|
+
# Maximum characters in class label able to be shown without being truncated
|
31
|
+
MAX_CHARACTERS_LEN = 13
|
32
|
+
|
33
|
+
|
34
|
+
def _fig_to_html(fig):
|
35
|
+
tmpfile = BytesIO()
|
36
|
+
fig.savefig(tmpfile, format="png")
|
37
|
+
encoded = base64.b64encode(tmpfile.getvalue()).decode("utf-8")
|
38
|
+
|
39
|
+
html = "<img src='data:image/png;base64,{}'>".format(encoded)
|
40
|
+
return html
|
41
|
+
|
42
|
+
|
43
|
+
class EvaluationPlot:
|
44
|
+
"""EvaluationPlot holds data and methods for plots and is used to output them
|
45
|
+
|
46
|
+
Attributes
|
47
|
+
----------
|
48
|
+
baseline (bool):
|
49
|
+
whether to plot the null model or zero information model
|
50
|
+
baseline_kwargs (dict):
|
51
|
+
keyword arguments for the baseline plot
|
52
|
+
color_wheel (list):
|
53
|
+
color information used by the plot
|
54
|
+
font_sz (dict):
|
55
|
+
dictionary of font sizes used by the plots, keyed by size name
|
56
|
+
perfect (bool):
|
57
|
+
determines whether a "perfect" classifier curve is displayed
|
58
|
+
perfect_kwargs (dict):
|
59
|
+
parameters for the perfect classifier for precision/recall curves
|
60
|
+
prob_type (str):
|
61
|
+
model type, i.e. classification or regression
|
62
|
+
|
63
|
+
Methods
|
64
|
+
-------
|
65
|
+
get_legend_labels(legend_labels)
|
66
|
+
Renders the legend labels on the plot
|
67
|
+
plot(evaluation, plots, num_classes, perfect, baseline, legend_labels)
|
68
|
+
Generates the evaluation plot
|
69
|
+
"""
|
70
|
+
|
71
|
+
# dict of font sizes keyed by size name
|
72
|
+
font_sz = {
|
73
|
+
"xl": 16, # Plot type title
|
74
|
+
"l": 14, # Individual plot title (name of model)
|
75
|
+
"m": 12, # Axis titles
|
76
|
+
"s": 10, # Axis labels
|
77
|
+
"xs": 8,
|
78
|
+
}  # text within the plot
|
79
|
+
|
80
|
+
baseline_kwargs = {"ls": "--", "c": ".2"} # lw = ??
|
81
|
+
perfect_kwargs = {"label": "Perfect Classifier", "ls": "--", "color": "gold"}
|
82
|
+
perfect = None
|
83
|
+
baseline = None
|
84
|
+
prob_type = None
|
85
|
+
color_wheel = ["teal", "blueviolet", "forestgreen", "peru", "y", "dodgerblue", "r"]
|
86
|
+
|
87
|
+
_pretty_titles_map = {
|
88
|
+
"normalized_confusion_matrix": "Normalized Confusion Matrix",
|
89
|
+
"lift_chart": "Lift Chart",
|
90
|
+
"gain_chart": "Gain Chart",
|
91
|
+
"ks_statistics": "KS Statistics",
|
92
|
+
"residuals_qq": "Residuals Q-Q Plot",
|
93
|
+
"residuals_vs_predicted": "Residuals vs Predicted",
|
94
|
+
"residuals_vs_observed": "Residuals vs Observed",
|
95
|
+
"observed_vs_predicted": "Observed vs Predicted",
|
96
|
+
"precision_by_label": "Precision by Label",
|
97
|
+
"recall_by_label": "Recall by Label",
|
98
|
+
"f1_by_label": "F1 by Label",
|
99
|
+
"jaccard_by_label": "Jaccard by Label",
|
100
|
+
"pr_curve": "PR Curve",
|
101
|
+
"roc_curve": "ROC Curve",
|
102
|
+
"pr_and_roc_curve": "PR Curve, ROC Curve",
|
103
|
+
"lift_and_gain_chart": "Lift Chart, Gain Chart",
|
104
|
+
}
|
105
|
+
|
106
|
+
_ugly_titles_map = {v: k for k, v in _pretty_titles_map.items()}
|
107
|
+
|
108
|
+
double_overlay_plots = ["pr_and_roc_curve", "lift_and_gain_chart"]
|
109
|
+
single_overlay_plots = ["lift_chart", "gain_chart", "roc_curve", "pr_curve"]
|
110
|
+
|
111
|
+
_bin_plots = [
|
112
|
+
"pr_curve",
|
113
|
+
"roc_curve",
|
114
|
+
"lift_chart",
|
115
|
+
"gain_chart",
|
116
|
+
"normalized_confusion_matrix",
|
117
|
+
]
|
118
|
+
_multi_plots = [
|
119
|
+
"normalized_confusion_matrix",
|
120
|
+
"roc_curve",
|
121
|
+
"pr_curve",
|
122
|
+
"precision_by_label",
|
123
|
+
"recall_by_label",
|
124
|
+
"f1_by_label",
|
125
|
+
"jaccard_by_label",
|
126
|
+
]
|
127
|
+
_reg_plots = [
|
128
|
+
"observed_vs_predicted",
|
129
|
+
"residuals_qq",
|
130
|
+
"residuals_vs_predicted",
|
131
|
+
"residuals_vs_observed",
|
132
|
+
]
|
133
|
+
|
134
|
+
# list of detailed descriptions of each plot type for every classification type, can be extended when adding more metrics
|
135
|
+
_bin_plots_details = """
|
136
|
+
In pattern recognition, information retrieval and binary classification, precision (also called positive predictive value)
|
137
|
+
is the fraction of relevant instances among the retrieved instances, while recall (also known as sensitivity) is the
|
138
|
+
fraction of relevant instances that have been retrieved over the total amount of relevant instances. \n
|
139
|
+
A receiver operating characteristic curve, or ROC curve, is a graphical plot that illustrates the diagnostic ability of a
|
140
|
+
binary classifier system as its discrimination threshold is varied. \n
|
141
|
+
In data mining and association rule learning, lift is a measure of the performance of a targeting model (association rule)
|
142
|
+
at predicting or classifying cases as having an enhanced response (with respect to the population as a whole), measured
|
143
|
+
against a random choice targeting model. \n
|
144
|
+
A gain graph is a graph whose edges are labelled 'invertibly', or 'orientably', by elements of a group G. \n
|
145
|
+
In the field of machine learning and specifically the problem of statistical classification, a confusion matrix, also known
|
146
|
+
as an error matrix, is a specific table layout that allows visualization of the performance of an algorithm, typically a
|
147
|
+
supervised learning one (in unsupervised learning it is usually called a matching matrix).
|
148
|
+
"""
|
149
|
+
# Removed for now:
|
150
|
+
# Kuiper 's test (ks_statistics) is used in statistics to test that whether a given distribution, or family of
|
151
|
+
# distributions, is contradicted by evidence from a sample of data. \n
|
152
|
+
|
153
|
+
_multi_plots_details = """
|
154
|
+
In the field of machine learning and specifically the problem of statistical classification, a confusion matrix, also known
|
155
|
+
as an error matrix, is a specific table layout that allows visualization of the performance of an algorithm, typically a
|
156
|
+
supervised learning one (in unsupervised learning it is usually called a matching matrix). \n
|
157
|
+
A receiver operating characteristic curve, or ROC curve, is a graphical plot that illustrates the diagnostic ability of a
|
158
|
+
binary classifier system as its discrimination threshold is varied. \n
|
159
|
+
In pattern recognition, information retrieval and binary classification, precision (also called positive predictive value)
|
160
|
+
is the fraction of relevant instances among the retrieved instances, while recall (also known as sensitivity) is the
|
161
|
+
fraction of relevant instances that have been retrieved over the total amount of relevant instances. \n
|
162
|
+
In statistical analysis of binary classification, the F1 score (also F-score or F-measure) is a measure of a test's accuracy.
|
163
|
+
It considers both the precision p and the recall r of the test to compute the score: p is the number of correct positive results
|
164
|
+
divided by the number of all positive results returned by the classifier, and r is the number of correct positive results divided
|
165
|
+
by the number of all relevant samples (all samples that should have been identified as positive). The F1 score is the harmonic mean
|
166
|
+
of the precision and recall, where an F1 score reaches its best value at 1 (perfect precision and recall) and worst at 0. \n
|
167
|
+
The Jaccard index, also known as Intersection over Union and the Jaccard similarity coefficient, is a statistic used for gauging
|
168
|
+
the similarity and diversity of sample sets. The Jaccard coefficient measures similarity between finite sample sets, and is defined
|
169
|
+
as the size of the intersection divided by the size of the union of the sample sets
|
170
|
+
"""
|
171
|
+
_reg_plots_details = """
|
172
|
+
In statistics, a Q–Q (quantile-quantile) plot is a probability plot, which is a graphical method for comparing two probability distributions
|
173
|
+
by plotting their quantiles against each other. \n
|
174
|
+
In statistics and optimization, errors and residuals are two closely related and easily confused measures of the deviation of an observed value
|
175
|
+
of an element of a statistical sample from its "theoretical value". The error (or disturbance) of an observed value is the deviation of the
|
176
|
+
observed value from the (unobservable) true value of a quantity of interest (for example, a population mean), and the residual of an observed
|
177
|
+
value is the difference between the observed value and the estimated value of the quantity of interest (for example, a sample mean).
|
178
|
+
"""
|
179
|
+
|
180
|
+
@classmethod
def _get_formatted_title(cls, title, max_len=MAX_TITLE_LEN):
    """Shorten ``title`` for display in a plot.

    Titles shorter than ``max_len + 3`` characters are returned unchanged.
    Longer titles are cut to their first ``max_len`` characters followed
    by "...".
    """
    if len(title) >= max_len + 3:
        return title[:max_len] + "..."
    return title
|
183
|
+
|
184
|
+
@classmethod
def get_legend_labels(cls, legend_labels):
    """Gets the legend labels, resolves any conflicts such as length, and renders
    the labels for the plot

    The resulting mapping (class label -> label shown in the legend) is stored
    on ``cls.legend_labels``. ``cls.classes`` must already be set; it is read as
    the list of class labels.

    Parameters
    ----------
    legend_labels (dict):
        key/value dictionary containing legend label data. When it is not
        ``None`` it must be a dict whose keys include every entry of
        ``cls.classes``; otherwise an error is logged and nothing is stored.
        When it is ``None`` the labels are generated from ``cls.classes``,
        truncating long labels and disambiguating labels that would
        truncate to the same text.

    Returns
    -------
    Nothing

    Examples
    --------

    EvaluationPlot.get_legend_labels({'class_0': 'green', 'class_1': 'yellow', 'class_2': 'red'})
    """

    @runtime_dependency(module="IPython", install_from=OptionalDependency.NOTEBOOK)
    @runtime_dependency(
        module="ipywidgets", object="HTML", install_from=OptionalDependency.NOTEBOOK
    )
    def render_legend_labels(label_dict):
        # Displays the label mapping as a small styled HTML table.
        encodings = pd.DataFrame(
            pd.Series(label_dict, index=label_dict.keys()),
            columns=["Shortened labels"],
        )
        # `IPython.display` is the public import location; importing `display`
        # from `IPython.core.display` is deprecated and fails on recent IPython.
        from IPython.display import display, HTML

        display(
            HTML(
                encodings.style.format(precision=4)
                .set_properties(**{"text-align": "center"})
                .set_table_styles(
                    [dict(selector="", props=[("text-align", "center")])]
                )
                .set_table_attributes("class=table")
                .set_caption(
                    '<div align="left"><b style="font-size:20px;">'
                    + "Legend for labels of the target feature:</b></div>"
                )
                .to_html()
            )
        )

    if legend_labels is not None:
        # CAUTION: cls.classes is a list of strings. Make sure users know that labels are all converted to strings.
        if isinstance(legend_labels, dict) and set(cls.classes).issubset(
            set(legend_labels.keys())
        ):
            render_legend_labels(legend_labels)
            cls.legend_labels = legend_labels
        else:
            logger.error(
                "The provided `legend_labels` is either not a Python dict or does not possess all possible class labels."
            )
        return

    # try to remove leading words
    def _check_for_redundant_words(label_vec, prefix, max_len=MAX_LEGEND_LEN):
        words_found = False
        classes = list(label_vec)
        # Strip the shared leading word only while every label still has text
        # left after it. Without this check a label that consists of just the
        # shared word becomes "" and `"".split()[0]` raises IndexError.
        while (
            classes
            and all(len(lab.split()) > 1 for lab in classes)
            and len({lab.split()[0] for lab in classes}) <= 1
        ):
            # remove that word from vec, and add it to prefix
            additional_prefix = classes[0].split()[0]
            prefix = (
                prefix[:-3] + additional_prefix + " ..."
                if words_found
                else additional_prefix + " ..."
            )
            classes = [lab[len(additional_prefix) + 1 :] for lab in classes]
            words_found = True

        # No shared word could be removed: drop the first `max_len` characters,
        # which are identical for all labels in a conflict group.
        classes = (
            classes if words_found else [label[max_len:] for label in label_vec]
        )
        return classes, prefix

    # returns mapping from real labels to pseudo-labels, when pseudo-labels are not the first X letters
    def _resolve_conflict(label_vec, prefix, max_len=MAX_LEGEND_LEN):
        classes, prefix = _check_for_redundant_words(
            label_vec, prefix, max_len=max_len
        )
        label_dict = _get_labels(classes, max_len=max_len)
        resolved = {}
        for orig, new in label_dict.items():
            # `prefix[:-3]` drops the trailing "..." so that prefix + remainder
            # rebuilds the original label, which is used as the key.
            resolved[prefix[:-3] + orig] = "..." + new if new[:3] != "..." else new
        return resolved

    # returns mapping from provided list of strings, to short, unique substrings
    def _get_labels(classes, max_len=MAX_LEGEND_LEN):
        conflict_dict = {}
        # `dict.fromkeys` drops repeated labels (keeping order); a label that
        # appears twice would otherwise conflict with itself and never resolve.
        for label in dict.fromkeys(classes):
            prefix = label if len(label) < max_len + 3 else label[:max_len] + "..."
            if conflict_dict.get(prefix, None) is None:
                conflict_dict[prefix] = [label]
            else:
                conflict_dict[prefix].append(label)
        out = {}
        for k, v in conflict_dict.items():
            if len(v) == 1:
                out[v[0]] = k
            else:
                resolved = _resolve_conflict(v, k, max_len=max_len)
                out.update(resolved)
        return out

    cls.legend_labels = _get_labels(cls.classes)
    # Keys differ from values only when at least one label was shortened.
    if set(cls.legend_labels.keys()) != set(cls.legend_labels.values()):
        logger.info(
            f"Class labels greater than {MAX_CHARACTERS_LEN} characters have been truncated. "
            "Use the `legend_labels` parameter to define labels."
        )
        render_legend_labels(cls.legend_labels)
|
298
|
+
|
299
|
+
# evaluation is a DataFrame with models as columns and metrics as rows
|
300
|
+
@classmethod
def plot(
    cls,
    evaluation,
    plots,
    num_classes,
    perfect=False,
    baseline=True,
    legend_labels=None,
):
    """Generates the evaluation plot

    Parameters
    ----------
    evaluation (DataFrame):
        DataFrame with models as columns and metrics as rows.
    plots (list of str, optional):
        Names of the plots to draw. If `None`, the default plot list for the
        problem type derived from `num_classes` is used. The list passed in
        is not modified.
    num_classes (int):
        The number of classes for the model. 2 selects the binary plots,
        more than 2 the multiclass plots, anything else the regression plots.
    perfect (bool, optional):
        Whether to display the curve of a perfect classifier. Default value is `False`.
    baseline (bool, optional):
        Whether to display the curve of the baseline, featureless model. Default value is `True`.
    legend_labels (dict, optional):
        Legend labels dictionary. Default value is `None`. If legend_labels not specified class names will be used for plots.

    Returns
    -------
    list of str
        HTML `<img>` snippets: for every plot type a title image followed by
        the image of the plot itself.
    """

    cls.perfect = perfect
    cls.baseline = baseline
    # get plots to show
    if num_classes == 2:
        cls.prob_type = "_bin"
    elif num_classes > 2:
        cls.prob_type = "_multi"
    else:
        cls.prob_type = "_reg"
    plot_details = getattr(cls, cls.prob_type + "_plots_details")
    if plots is None:
        plots = getattr(cls, cls.prob_type + "_plots")
        logger.info(
            "Showing plot types: {}.".format(
                ", ".join(EvaluationPlot._pretty_titles_map[str(p)] for p in plots)
            )
        )
        logger.info(plot_details)

    # Work on a private copy: the list is edited below and must not alter the
    # caller's list or the class-level default plot lists.
    plots = list(plots)

    if cls.prob_type == "_bin":
        # Pairs of related curves are drawn side by side as one combined plot.
        if "lift_chart" in plots and "gain_chart" in plots:
            plots.remove("lift_chart")
            plots.remove("gain_chart")
            plots.insert(0, "lift_and_gain_chart")

        if "roc_curve" in plots and "pr_curve" in plots:
            plots.remove("roc_curve")
            plots.remove("pr_curve")
            plots.insert(0, "pr_and_roc_curve")
    elif cls.prob_type == "_multi":
        if (
            "normalized_confusion_matrix" in plots
            and len(evaluation[evaluation.columns[0]]["classes"])
            >= MAX_PLOTTING_CLASSES
        ):
            logger.error(
                f"Evaluation plotting is not yet supported for multiclass problems with {MAX_PLOTTING_CLASSES} or more classes."
            )
            plots = []
    classes = evaluation[evaluation.columns[0]]["classes"]
    if classes is not None:
        # CAUTION: class labels are converted to strings here.
        # If users are passing in legend_labels, they have to use strings as well.
        # Otherwise get_legend_labels() will complain.
        # If users are not passing in legend_labels, get_legend_labels() generates them from cls.classes.
        # cls.legend_labels are assigned/created in cls.get_legend_labels() and contain only strings as keys.
        cls.classes = [str(c) for c in classes]
        cls.get_legend_labels(legend_labels)

    mpl.style.use("default")
    html_raw = []
    for plot_type in plots:
        fig_title, fig = None, None
        try:
            # A thin, axis-less figure that only carries the plot type title.
            fig_title, ax_title = plt.subplots(1, 1, figsize=(18, 0.5), dpi=144)
            ax_title.text(
                0.5,
                0.5,
                cls._pretty_titles_map[plot_type],
                fontsize=16,
                fontweight="semibold",
                horizontalalignment="center",
                verticalalignment="center",
                transform=ax_title.transAxes,
            )
            ax_title.axis("off")
            html_raw.append(_fig_to_html(fig_title))
            if cls.prob_type == "_bin" and plot_type in cls.double_overlay_plots:
                fig, ax = plt.subplots(1, 2, figsize=(12, 4.5), dpi=144)
            elif cls.prob_type == "_bin" and plot_type in ["roc_curve", "pr_curve"]:
                # Single curve: keep the two-column layout, hide the right axis.
                fig, axs = plt.subplots(1, 2, figsize=(12, 4.5), dpi=144)
                axs[1].axis("off")
                ax = [axs[0]]
            elif cls.prob_type == "_bin" and plot_type in [
                "lift_chart",
                "gain_chart",
            ]:
                fig, axs = plt.subplots(1, 2, figsize=(12, 4.5), dpi=144)
                axs[1].axis("off")
                ax = axs[0]
            else:
                # One subplot per model, MAX_PLOTS_PER_ROW per row.
                nrows = math.ceil(len(evaluation.columns) / MAX_PLOTS_PER_ROW)
                fig, ax = plt.subplots(
                    nrows, MAX_PLOTS_PER_ROW, figsize=(10, 4 * nrows), dpi=144
                )  # 10, 3.5
                ax = ax.flatten()
            # Dispatch to the drawing method named after the plot type.
            getattr(cls, "_" + plot_type)(ax, evaluation)
            fig.tight_layout()
            html_raw.append(_fig_to_html(fig))
        except KeyError:
            # Best-effort cleanup of the figures created for this plot type;
            # a failure to close must not mask the warning below.
            try:
                if fig_title:
                    plt.close(fig=fig_title)
                if fig:
                    plt.close(fig=fig)
            except Exception:
                pass
            logger.warning(
                f"Evaluator was not able to plot "
                f"{cls._pretty_titles_map.get(plot_type, plot_type)}, because the relevant "
                f"metrics had complications. Ensure that `predict` and `predict_proba` "
                f"are valid."
            )
    return html_raw
|
442
|
+
|
443
|
+
@classmethod
def _lift_and_gain_chart(cls, ax, evaluation):
    """Draw the lift chart on ``ax[0]`` and the gain chart on ``ax[1]``."""
    for position, drawer_name in enumerate(("_lift_chart", "_gain_chart")):
        getattr(cls, drawer_name)(ax[position], evaluation)
|
447
|
+
|
448
|
+
@classmethod
def _lift_chart(cls, ax, evaluation):
    """Draw the lift chart onto ``ax``.

    One curve is drawn per column of ``evaluation`` whose ``"y_score"`` entry
    is not ``None``. Depending on ``cls.baseline`` and ``cls.perfect``, a
    horizontal baseline at lift 1 and the ``"perfect_lift"`` curve of the
    first column that has scores are added as well.

    Parameters
    ----------
    ax:
        The matplotlib axis to draw on.
    evaluation (DataFrame):
        DataFrame with models as columns and metrics as rows. The rows
        ``"y_score"``, ``"percentages"`` and ``"lift"`` are read, plus
        ``"perfect_lift"`` when ``cls.perfect`` is set.

    Returns
    -------
    Nothing
    """
    for mod_name, col in evaluation.items():
        if col["y_score"] is not None:
            ax.plot(
                col["percentages"][1:],
                [1] + list(col["lift"]),
                label=cls._get_formatted_title(mod_name),
            )
    if cls.baseline:
        ax.plot([-10, 110], [1, 1], **cls.baseline_kwargs)
    if cls.perfect:
        # Position of the first model that has scores; `None` when no model
        # has any, in which case there is no perfect curve to draw.
        perf_idx = next(
            (
                idx
                for idx, scores in enumerate(evaluation.loc["y_score"])
                if scores is not None
            ),
            None,
        )
        if perf_idx is not None:
            # `perf_idx` is a position, so it is resolved with `.iloc`; plain
            # `[]` would be looked up as a label when model names are integers.
            ax.plot(
                evaluation.loc["percentages"].iloc[perf_idx][1:],
                [1] + list(evaluation.loc["perfect_lift"].iloc[perf_idx]),
                **cls.perfect_kwargs,
            )
    ax.legend(loc="upper right", frameon=False)
    ax.set_xlabel("Percentage of Population", fontsize=12)
    ax.set_ylabel("Lift", fontsize=12)
    ax.set_title("Lift Chart", y=1.08, fontsize=14)
    ax.grid(linewidth=0.2, which="both")
    ax.set_xlim([-10, 110])
|
476
|
+
|
477
|
+
@classmethod
def _gain_chart(cls, ax, evaluation):
    """Plot one cumulative-gain curve per scored model on ``ax``.

    Parameters
    ----------
    ax
        Axes-like object that receives the ``plot``, ``legend``, label,
        title, grid and axis-limit calls.
    evaluation
        Table with one column per model (presumably a pandas DataFrame --
        it is accessed through ``.items()`` and ``.loc``).  The rows
        ``"y_score"``, ``"percentages"`` and ``"cumulative_gain"`` are read
        for every model; ``"perfect_gain"`` is read only when
        ``cls.perfect`` is set.

    Models whose ``"y_score"`` entry is ``None`` are skipped.
    """
    for mod_name, col in evaluation.items():
        if col["y_score"] is not None:
            ax.plot(
                col["percentages"],
                list(col["cumulative_gain"]),
                label=cls._get_formatted_title(mod_name),
            )
    if cls.baseline:
        ax.plot([-10, 110], [-10, 110], **cls.baseline_kwargs)
    if cls.perfect:
        # Position of the first model that actually has scores.  A default
        # is supplied so that an evaluation without any scored model does
        # not leak StopIteration to the caller.
        perf_idx = next(
            (
                idx
                for idx, scores in enumerate(evaluation.loc["y_score"])
                if scores is not None
            ),
            None,
        )
        if perf_idx is not None:
            # perf_idx is a position, so use .iloc: plain [] on a Series is
            # label-based for integer column names and deprecated as a
            # positional fallback otherwise.
            ax.plot(
                evaluation.loc["percentages"].iloc[perf_idx],
                evaluation.loc["perfect_gain"].iloc[perf_idx],
                **cls.perfect_kwargs,
            )
    ax.legend(loc="lower right", frameon=False)
    ax.set_xlabel("Percentage of Population", fontsize=12)
    ax.set_ylabel("Percentage of Positive Class", fontsize=12)
    ax.set_title("Gain Chart", y=1.08, fontsize=14)
    ax.grid(linewidth=0.2, which="both")
    ax.set_xlim([-10, 110])
    ax.set_ylim([-10, 110])
|
506
|
+
|
507
|
+
@classmethod
def _pr_and_roc_curve(cls, ax, evaluation):
    """Draw the precision-recall curve and the ROC curve on a pair of axes.

    ``ax[0]`` is wrapped in a one-element list and handed to
    ``cls._pr_curve``; ``ax[1]`` is wrapped the same way and handed to
    ``cls._roc_curve``.  ``evaluation`` is forwarded unchanged to both.
    """
    panels = (cls._pr_curve, cls._roc_curve)
    for position, draw in enumerate(panels):
        draw([ax[position]], evaluation)
|
511
|
+
|
512
|
+
@classmethod
def _pr_curve(cls, axs, evaluation):
    """Draw precision-recall curves on every axes in ``axs``.

    When ``cls.prob_type == "_bin"`` each axes shows one curve per scored
    model.  Otherwise axes ``i`` shows one curve per class for the model in
    column ``i`` of ``evaluation``.  After each curve a ``*`` marker is
    drawn at that curve's ``"pr_best_model_score"`` point, in the curve's
    colour.

    Parameters
    ----------
    axs
        Iterable of axes-like objects.
    evaluation
        Table with one column per model (presumably a pandas DataFrame).
        Rows read here: ``"y_score"``, ``"recall_values"``,
        ``"precision_values"``, ``"pr_best_model_score"`` and either
        ``"precision"`` (binary) or ``"classes"`` / ``"precision_by_label"``
        (non-binary).

    Axes beyond the number of model columns are switched off.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            # Hide *every* surplus axes; returning here would leave all but
            # the first surplus axes visible and empty.
            ax.axis("off")
            continue
        if cls.prob_type == "_bin":
            for mod_name, col in evaluation.items():
                if col["y_score"] is None:
                    continue
                title = cls._get_formatted_title(mod_name)
                ax.plot(
                    col["recall_values"],
                    col["precision_values"],
                    label=f"{title} (Precision: {col['precision']:.3f})",
                )
                # Marker reuses the colour of the curve just drawn.
                ax.plot(
                    *col["pr_best_model_score"],
                    color=ax.get_lines()[-1].get_color(),
                    marker="*",
                )
        else:
            model_name = evaluation.columns[i]
            mod = evaluation[model_name]
            if mod["y_score"] is not None:
                for j, lab in enumerate(mod["classes"]):
                    # cls.legend_labels contains only strings as keys.
                    legend_name = cls.legend_labels[str(lab)]
                    ax.plot(
                        mod["recall_values"][j],
                        mod["precision_values"][j],
                        label=(
                            f"{legend_name} "
                            f"(Precision: {mod['precision_by_label'][j]:.3f})"
                        ),
                    )
                    ax.plot(
                        *mod["pr_best_model_score"][j],
                        color=ax.get_lines()[-1].get_color(),
                        marker="*",
                    )

        ax.set_xlabel("Recall", fontsize=12)
        ax.set_ylabel("Precision", fontsize=12)
        ax.set_title("Precision Recall Curve", y=1.08, fontsize=14)
        ax.grid(linewidth=0.2, which="both")
        ax.set_xlim([-0.1, 1.1])
        ax.set_ylim([-0.1, 1.1])
        # Append a legend-only handle explaining the "*" markers.
        handles, labels = ax.get_legend_handles_labels()
        star = mlines.Line2D(
            [],
            [],
            color="black",
            marker="*",
            linestyle="None",
            markersize=5,
            label="Minimum Error Rate",
        )
        handles.append(star)
        labels.append("Minimum Error Rate")
        ax.legend(
            loc="upper right",
            labels=labels,
            handles=handles,
            frameon=False,
            fontsize="x-small",
        )
|
583
|
+
|
584
|
+
@classmethod
def _roc_curve(cls, axs, evaluation):
    """Draw ROC curves on every axes in ``axs``.

    When ``cls.prob_type == "_bin"`` each axes shows one curve per scored
    model.  Otherwise axes ``i`` shows one curve per class for the model in
    column ``i`` of ``evaluation``.  After each curve a ``*`` marker is
    drawn at that curve's ``"roc_best_model_score"`` point, in the curve's
    colour.

    Parameters
    ----------
    axs
        Iterable of axes-like objects.
    evaluation
        Table with one column per model (presumably a pandas DataFrame).
        Rows read here: ``"y_score"``, ``"auc"``,
        ``"roc_best_model_score"`` and either ``"false_positive_rate"`` /
        ``"true_positive_rate"`` (binary) or ``"classes"`` /
        ``"fpr_by_label"`` / ``"tpr_by_label"`` (non-binary).

    Axes beyond the number of model columns are switched off.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            # Hide *every* surplus axes; returning here would leave all but
            # the first surplus axes visible and empty.
            ax.axis("off")
            continue
        if cls.prob_type == "_bin":
            for mod_name, col in evaluation.items():
                if col["y_score"] is None:
                    continue
                title = cls._get_formatted_title(mod_name)
                ax.plot(
                    col["false_positive_rate"],
                    col["true_positive_rate"],
                    label=f"{title} (AUC: {col['auc']:.3f})",
                )
                # Marker reuses the colour of the curve just drawn.
                ax.plot(
                    *col["roc_best_model_score"],
                    color=ax.get_lines()[-1].get_color(),
                    marker="*",
                )
        else:
            model_name = evaluation.columns[i]
            mod = evaluation[model_name]
            if mod["y_score"] is not None:
                for j, lab in enumerate(mod["classes"]):
                    # cls.legend_labels contains only strings as keys.
                    legend_name = cls.legend_labels[str(lab)]
                    ax.plot(
                        mod["fpr_by_label"][j],
                        mod["tpr_by_label"][j],
                        label=f"{legend_name} (AUC: {mod['auc'][j]:.3f})",
                    )
                    ax.plot(
                        *mod["roc_best_model_score"][j],
                        color=ax.get_lines()[-1].get_color(),
                        marker="*",
                    )
        if cls.baseline:
            ax.plot([-0.1, 1.1], [-0.1, 1.1], **cls.baseline_kwargs)
        ax.set_xlabel("False Positive Rate", fontsize=12)
        ax.set_ylabel("True Positive Rate", fontsize=12)
        ax.set_title("ROC Curve", y=1.08, fontsize=14)
        ax.grid(linewidth=0.2, which="both")
        ax.set_xlim([-0.1, 1.1])
        ax.set_ylim([-0.1, 1.1])
        # Append a legend-only handle explaining the "*" markers.
        handles, labels = ax.get_legend_handles_labels()
        star = mlines.Line2D(
            [],
            [],
            color="black",
            marker="*",
            linestyle="None",
            markersize=5,
            label="Youden's J Statistic",
        )
        handles.append(star)
        labels.append("Youden's J Statistic")
        ax.legend(
            loc="lower right",
            labels=labels,
            handles=handles,
            frameon=False,
            fontsize="x-small",
        )
|
653
|
+
|
654
|
+
@classmethod
def _ks_statistics(cls, axs, evaluation):
    """Draw a KS-statistic plot for each model, one model per axes.

    Axes ``i`` is titled with the name of column ``i`` of ``evaluation``
    and, when that model's ``"y_score"`` is not ``None``, shows the
    ``"ks_pct1"`` and ``"ks_pct2"`` series against ``"ks_thresholds"``.
    With ``cls.baseline`` set, a dashed vertical line is added at
    ``"max_distance_at"``, spanning the two series at that threshold.

    Parameters
    ----------
    axs
        Iterable of axes-like objects.
    evaluation
        Table with one column per model (presumably a pandas DataFrame).

    Axes beyond the number of model columns are switched off.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            # Hide *every* surplus axes; returning here would leave all but
            # the first surplus axes visible and empty.
            ax.axis("off")
            continue
        model_name = evaluation.columns[i]
        mod = evaluation[model_name]

        ax.set_title(model_name, fontsize=14)
        if mod["y_score"] is not None:
            ax.plot(
                mod["ks_thresholds"],
                mod["ks_pct1"],
                lw=3,
                label=mod["ks_labels"][0],
            )
            ax.plot(
                mod["ks_thresholds"],
                mod["ks_pct2"],
                lw=3,
                label=mod["ks_labels"][1],
            )
            # NOTE(review): nesting reconstructed -- the marker needs the KS
            # arrays, so it is kept under the y_score check; confirm against
            # the original indentation.
            if cls.baseline:
                # First index where the threshold equals max_distance_at.
                idx = np.where(mod["ks_thresholds"] == mod["max_distance_at"])[0][0]
                ax.axvline(
                    mod["max_distance_at"],
                    *sorted([mod["ks_pct1"][idx], mod["ks_pct2"][idx]]),
                    label="KS Statistic: {:.3f} at {:.3f}".format(
                        mod["ks_statistic"], mod["max_distance_at"]
                    ),
                    linestyle="--",
                    color=".2",
                )

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.0])

        ax.set_xlabel("Threshold", fontsize=12)
        ax.set_ylabel("Percentage below threshold", fontsize=12)
        ax.tick_params(labelsize=10)
        ax.legend(loc="lower right", fontsize=8)
|
697
|
+
|
698
|
+
@classmethod
def _pretty_barh(
    cls, ax, x, y, axis_labels=None, title=None, axis_lim=None, plot_kwargs=None
):
    """Draw an annotated horizontal bar chart on ``ax``.

    Each item of ``x`` is converted to ``str`` and translated through
    ``cls.legend_labels`` to obtain the bar label; ``y`` supplies the bar
    lengths.  Every bar is annotated with its value (three decimals) at
    half of its length.

    ``axis_labels`` is an optional two-item sequence (x label, y label);
    falsy entries are skipped.  ``title`` is passed through
    ``cls._get_formatted_title`` before being set.  ``axis_lim`` is applied
    to the x axis only.  ``plot_kwargs`` is accepted but not used.
    """
    palette = ["teal", "blueviolet", "forestgreen", "peru", "y", "dodgerblue", "r"]
    # cls.legend_labels contains only strings as keys.
    bar_names = [cls.legend_labels[str(raw)] for raw in x]
    ax.barh(bar_names, y, color=palette)

    for row, value in enumerate(y):
        ax.annotate(
            format(value, ".3f"), xy=(value / 2, row), va="center", ha="left"
        )

    if axis_labels:
        x_text = axis_labels[0]
        if x_text:
            ax.set_xlabel(x_text, fontsize=12)
        y_text = axis_labels[1]
        if y_text:
            ax.set_ylabel(y_text, fontsize=12)

    if title:
        ax.set_title(cls._get_formatted_title(title), y=1.08, fontsize=14)

    if axis_lim:
        ax.set_xlim(axis_lim)
|
721
|
+
|
722
|
+
@classmethod
def _precision_by_label(cls, axs, evaluation):
    """Draw a per-class precision bar chart for each model, one per axes.

    Axes ``i`` receives the ``"classes"`` and ``"precision_by_label"``
    entries of column ``i`` of ``evaluation`` via ``cls._pretty_barh``.
    Axes without a matching model column are switched off.
    """
    model_names = evaluation.columns
    model_count = len(model_names)
    for position, ax in enumerate(axs):
        if position >= model_count:
            ax.axis("off")
            continue
        name = model_names[position]
        cls._pretty_barh(
            ax,
            evaluation[name]["classes"],
            evaluation[name]["precision_by_label"],
            axis_lim=[0, 1],
            axis_labels=["Precision", None],
            title=name,
        )
|
738
|
+
|
739
|
+
@classmethod
def _recall_by_label(cls, axs, evaluation):
    """Draw a per-class recall bar chart for each model, one per axes.

    Axes ``i`` receives the ``"classes"`` and ``"recall_by_label"`` entries
    of column ``i`` of ``evaluation`` via ``cls._pretty_barh``.  Axes
    without a matching model column are switched off.
    """
    model_names = evaluation.columns
    model_count = len(model_names)
    for position, ax in enumerate(axs):
        if position >= model_count:
            ax.axis("off")
            continue
        name = model_names[position]
        cls._pretty_barh(
            ax,
            evaluation[name]["classes"],
            evaluation[name]["recall_by_label"],
            axis_lim=[0, 1],
            axis_labels=["Recall", None],
            title=name,
        )
|
755
|
+
|
756
|
+
@classmethod
def _f1_by_label(cls, axs, evaluation):
    """Draw a per-class F1 bar chart for each model, one per axes.

    Axes ``i`` receives the ``"classes"`` and ``"f1_by_label"`` entries of
    column ``i`` of ``evaluation`` via ``cls._pretty_barh``.  Axes without
    a matching model column are switched off.
    """
    model_names = evaluation.columns
    model_count = len(model_names)
    for position, ax in enumerate(axs):
        if position >= model_count:
            ax.axis("off")
            continue
        name = model_names[position]
        cls._pretty_barh(
            ax,
            evaluation[name]["classes"],
            evaluation[name]["f1_by_label"],
            axis_lim=[0, 1],
            axis_labels=["F1 Score", None],
            title=name,
        )
|
772
|
+
|
773
|
+
@classmethod
def _jaccard_by_label(cls, axs, evaluation):
    """Draw a per-class Jaccard-score bar chart for each model, one per axes.

    Axes ``i`` receives the ``"classes"`` and ``"jaccard_by_label"``
    entries of column ``i`` of ``evaluation`` via ``cls._pretty_barh``.
    Axes without a matching model column are switched off.
    """
    model_names = evaluation.columns
    model_count = len(model_names)
    for position, ax in enumerate(axs):
        if position >= model_count:
            ax.axis("off")
            continue
        name = model_names[position]
        cls._pretty_barh(
            ax,
            evaluation[name]["classes"],
            evaluation[name]["jaccard_by_label"],
            axis_lim=[0, 1],
            axis_labels=["Jaccard Score", None],
            title=name,
        )
|
789
|
+
|
790
|
+
@classmethod
def _pretty_scatter(
    cls,
    ax,
    x,
    y,
    s=5,
    alpha=1.0,
    title=None,
    legend=False,
    axis_labels=None,
    axis_lim=None,
    grid=True,
    label=None,
    plot_kwargs=None,
):
    """Draw a scatter plot on ``ax`` and apply optional decorations.

    ``x`` and ``y`` are scattered with round markers of size ``s`` and
    opacity ``alpha``; entries of ``plot_kwargs`` are forwarded to
    ``ax.scatter``.  Both axes get a two-decimal tick formatter.

    Optional decorations, each applied only when its argument is truthy:
    ``legend`` (frameless legend), ``axis_labels`` (x label, y label),
    ``title``, ``grid`` and ``axis_lim`` (x limits, y limits -- each of the
    two entries is applied only when itself truthy).
    """
    extra = {} if plot_kwargs is None else plot_kwargs
    ax.scatter(x, y, s=s, label=label, marker="o", alpha=alpha, **extra)

    # Same "%.2f" tick format on x first, then y.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(FormatStrFormatter("%.2f"))

    if legend:
        ax.legend(frameon=False)

    if axis_labels:
        ax.set_xlabel(axis_labels[0])
        ax.set_ylabel(axis_labels[1])

    if title:
        ax.set_title(title, y=1.08, fontsize=14)

    if grid:
        ax.grid(linewidth=0.2)

    if not axis_lim:
        return
    x_limits = axis_lim[0]
    if x_limits:
        ax.set_xlim(x_limits)
    y_limits = axis_lim[1]
    if y_limits:
        ax.set_ylim(y_limits)
|
825
|
+
|
826
|
+
@classmethod
def _top_2_features(cls, axs, evaluation):
    """Placeholder: intentionally draws nothing and returns ``None``."""
    return None
|
829
|
+
|
830
|
+
@classmethod
def _residuals_qq(cls, axs, evaluation):
    """Draw a residual Q-Q scatter plot for each model, one per axes.

    Axes ``i`` plots the ``"residual_quantiles"`` entry of column ``i`` of
    ``evaluation`` against its ``"norm_quantiles"`` entry through
    ``cls._pretty_scatter``, with fixed axis limits.  With ``cls.baseline``
    set, a diagonal reference line is added.

    Parameters
    ----------
    axs
        Iterable of axes-like objects.
    evaluation
        Table with one column per model (presumably a pandas DataFrame).

    Axes beyond the number of model columns are switched off.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            # Hide *every* surplus axes; returning here would leave all but
            # the first surplus axes visible and empty.
            ax.axis("off")
            continue
        model_name = evaluation.columns[i]
        mod = evaluation[model_name]
        cls._pretty_scatter(
            ax,
            mod["norm_quantiles"],
            mod["residual_quantiles"],
            title=model_name,
            axis_lim=[(-2.7, 2.7), (-3.1, 3.1)],
            axis_labels=["Theoretical Quantiles", "Sample Quantiles"],
        )
        if cls.baseline:
            # Endpoints lie far outside the fixed axis limits so the
            # diagonal spans the whole visible area.
            ax.plot((-100, 100), (-100, 100), **cls.baseline_kwargs)
|
850
|
+
|
851
|
+
@classmethod
def _residuals_vs_predicted(cls, axs, evaluation):
    """Scatter residuals against predicted values, one model per axes.

    Axes ``i`` plots the ``"residuals"`` entry of column ``i`` of
    ``evaluation`` against its ``"y_pred"`` entry through
    ``cls._pretty_scatter``.  The x range is the predicted-value range
    widened by 5% of each bound's magnitude.  With ``cls.baseline`` set, a
    horizontal line at zero is added.

    Parameters
    ----------
    axs
        Iterable of axes-like objects.
    evaluation
        Table with one column per model (presumably a pandas DataFrame).

    Axes beyond the number of model columns are switched off.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            # Hide *every* surplus axes; returning here would leave all but
            # the first surplus axes visible and empty.
            ax.axis("off")
            continue
        model_name = evaluation.columns[i]
        mod = evaluation[model_name]
        y_pred = np.asarray(mod["y_pred"])
        resid = np.asarray(mod["residuals"])
        cls._pretty_scatter(
            ax,
            y_pred,
            resid,
            s=4,
            alpha=0.5,
            title=model_name,
            axis_labels=["Predicted Values", "Residuals"],
        )
        lower, upper = y_pred.min(), y_pred.max()
        # Pad by the magnitude of each bound: "lower - lower * 0.05" would
        # move a negative bound toward zero and clip points.
        x_lim = (lower - abs(lower) * 0.05, upper + abs(upper) * 0.05)
        if cls.baseline:
            ax.plot(x_lim, (0, 0), **cls.baseline_kwargs)
        ax.set_xlim(x_lim)
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
|
880
|
+
|
881
|
+
@classmethod
def _residuals_vs_observed(cls, axs, evaluation):
    """Scatter residuals against observed values, one model per axes.

    Axes ``i`` plots the ``"residuals"`` entry of column ``i`` of
    ``evaluation`` against its ``"y_true"`` entry through
    ``cls._pretty_scatter``.  The x range is the observed-value range
    widened by 5% of each bound's magnitude.  With ``cls.baseline`` set, a
    horizontal line at zero is added.

    Parameters
    ----------
    axs
        Iterable of axes-like objects.
    evaluation
        Table with one column per model (presumably a pandas DataFrame).

    Axes beyond the number of model columns are switched off.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            # Hide *every* surplus axes; returning here would leave all but
            # the first surplus axes visible and empty.
            ax.axis("off")
            continue
        model_name = evaluation.columns[i]
        mod = evaluation[model_name]
        y_true = np.asarray(mod["y_true"])
        cls._pretty_scatter(
            ax,
            y_true,
            mod["residuals"],
            s=4,
            alpha=0.5,
            title=model_name,
            axis_labels=["Observed Values", "Residuals"],
        )
        lower, upper = y_true.min(), y_true.max()
        # Pad by the magnitude of each bound: "lower - lower * 0.05" would
        # move a negative bound toward zero and clip points.
        x_lim = (lower - abs(lower) * 0.05, upper + abs(upper) * 0.05)
        if cls.baseline:
            ax.plot(x_lim, (0, 0), **cls.baseline_kwargs)
        ax.set_xlim(x_lim)
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
|
909
|
+
|
910
|
+
@classmethod
def _observed_vs_predicted(cls, axs, evaluation):
    """Scatter predicted values against observed values, one model per axes.

    Axes ``i`` plots the ``"y_pred"`` entry of column ``i`` of
    ``evaluation`` against its ``"y_true"`` entry.  The x range is the
    observed-value range widened by 5% of each bound's magnitude.  With
    ``cls.baseline`` set, the identity line is drawn across that range.

    Parameters
    ----------
    axs
        Iterable of axes-like objects.
    evaluation
        Table with one column per model (presumably a pandas DataFrame).

    Axes beyond the number of model columns are switched off.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            # Hide *every* surplus axes; returning here would leave all but
            # the first surplus axes visible and empty.
            ax.axis("off")
            continue

        model_name = evaluation.columns[i]
        mod = evaluation[model_name]
        ax.scatter(mod["y_true"], mod["y_pred"], s=4, marker="o", alpha=0.5)

        y_true = np.asarray(mod["y_true"])
        yt_min = y_true.min()
        yt_max = y_true.max()

        # Pad by the magnitude of each bound: "yt_min - yt_min * 0.05" would
        # move a negative bound toward zero and clip points.
        x_lim = (yt_min - abs(yt_min) * 0.05, yt_max + abs(yt_max) * 0.05)
        if cls.baseline:
            ax.plot(x_lim, x_lim, **cls.baseline_kwargs)
        ax.set_xlabel("Observed Values")
        ax.set_ylabel("Predicted Values")
        ax.set_title(model_name, y=1.08, fontsize=14)
        ax.grid(linewidth=0.2)
        ax.set_xlim(x_lim)

        ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
|
938
|
+
|
939
|
+
@classmethod
|
940
|
+
def _normalized_confusion_matrix(cls, axs, evaluation):
|
941
|
+
for model_num, ax in enumerate(axs):
|
942
|
+
if model_num >= len(evaluation.columns):
|
943
|
+
ax.axis("off")
|
944
|
+
return
|
945
|
+
model_name = evaluation.columns[model_num]
|
946
|
+
mod = evaluation[model_name]
|
947
|
+
if cls.prob_type == "_bin":
|
948
|
+
labels = [str(lab == mod["positive_class"]) for lab in mod["classes"]]
|
949
|
+
else:
|
950
|
+
labels = cls.legend_labels.values()
|
951
|
+
|
952
|
+
raw_cm = mod["raw_confusion_matrix"]
|
953
|
+
cm = np.asarray(mod["confusion_matrix"])
|
954
|
+
|
955
|
+
ax.set_title(
|
956
|
+
"%s\n" % cls._get_formatted_title(model_name), y=1.08, fontsize=14
|
957
|
+
)
|
958
|
+
ax.imshow(cm, interpolation="nearest", cmap="BuGn")
|
959
|
+
x_tick_marks = np.arange(len(labels))
|
960
|
+
y_tick_marks = np.arange(len(labels))
|
961
|
+
ax.set_xticks(x_tick_marks)
|
962
|
+
ax.set_yticks(y_tick_marks)
|
963
|
+
|
964
|
+
ax.set_xticklabels(labels, rotation=90, fontsize=10)
|
965
|
+
ax.set_yticklabels(labels, fontsize=10)
|
966
|
+
|
967
|
+
for i, j in itertools.product(
|
968
|
+
range(raw_cm.shape[0]), range(raw_cm.shape[1])
|
969
|
+
):
|
970
|
+
ax.text(
|
971
|
+
j,
|
972
|
+
i,
|
973
|
+
"%s [%s]" % (round(cm[i][j], 3), raw_cm[i, j]),
|
974
|
+
horizontalalignment="center",
|
975
|
+
verticalalignment="center",
|
976
|
+
rotation=45,
|
977
|
+
fontsize=max(3, 10 - max(raw_cm.shape[0], raw_cm.shape[1])),
|
978
|
+
color="white" if cm[i, j] > 0.5 else "black",
|
979
|
+
)
|
980
|
+
|
981
|
+
ax.set_ylabel("True label", fontsize=10)
|
982
|
+
ax.set_xlabel("Predicted label", fontsize=10)
|
983
|
+
ax.grid(False)
|