oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.9rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +40 -0
- ads/aqua/app.py +506 -0
- ads/aqua/cli.py +96 -0
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +836 -0
- ads/aqua/client/openai_client.py +305 -0
- ads/aqua/common/__init__.py +5 -0
- ads/aqua/common/decorator.py +125 -0
- ads/aqua/common/entities.py +269 -0
- ads/aqua/common/enums.py +122 -0
- ads/aqua/common/errors.py +109 -0
- ads/aqua/common/utils.py +1285 -0
- ads/aqua/config/__init__.py +4 -0
- ads/aqua/config/container_config.py +248 -0
- ads/aqua/config/evaluation/__init__.py +4 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
- ads/aqua/config/utils/__init__.py +4 -0
- ads/aqua/config/utils/serializer.py +339 -0
- ads/aqua/constants.py +116 -0
- ads/aqua/data.py +14 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation/__init__.py +8 -0
- ads/aqua/evaluation/constants.py +53 -0
- ads/aqua/evaluation/entities.py +186 -0
- ads/aqua/evaluation/errors.py +70 -0
- ads/aqua/evaluation/evaluation.py +1814 -0
- ads/aqua/extension/__init__.py +42 -0
- ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
- ads/aqua/extension/base_handler.py +90 -0
- ads/aqua/extension/common_handler.py +121 -0
- ads/aqua/extension/common_ws_msg_handler.py +36 -0
- ads/aqua/extension/deployment_handler.py +298 -0
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +30 -0
- ads/aqua/extension/evaluation_handler.py +129 -0
- ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
- ads/aqua/extension/finetune_handler.py +96 -0
- ads/aqua/extension/model_handler.py +390 -0
- ads/aqua/extension/models/__init__.py +0 -0
- ads/aqua/extension/models/ws_models.py +145 -0
- ads/aqua/extension/models_ws_msg_handler.py +50 -0
- ads/aqua/extension/ui_handler.py +282 -0
- ads/aqua/extension/ui_websocket_handler.py +130 -0
- ads/aqua/extension/utils.py +133 -0
- ads/aqua/finetuning/__init__.py +7 -0
- ads/aqua/finetuning/constants.py +23 -0
- ads/aqua/finetuning/entities.py +181 -0
- ads/aqua/finetuning/finetuning.py +749 -0
- ads/aqua/model/__init__.py +8 -0
- ads/aqua/model/constants.py +60 -0
- ads/aqua/model/entities.py +385 -0
- ads/aqua/model/enums.py +32 -0
- ads/aqua/model/model.py +2114 -0
- ads/aqua/modeldeployment/__init__.py +8 -0
- ads/aqua/modeldeployment/constants.py +10 -0
- ads/aqua/modeldeployment/deployment.py +1326 -0
- ads/aqua/modeldeployment/entities.py +653 -0
- ads/aqua/modeldeployment/inference.py +74 -0
- ads/aqua/modeldeployment/utils.py +543 -0
- ads/aqua/resources/gpu_shapes_index.json +94 -0
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +476 -0
- ads/aqua/ui.py +499 -0
- ads/automl/__init__.py +9 -0
- ads/automl/driver.py +330 -0
- ads/automl/provider.py +975 -0
- ads/bds/__init__.py +5 -0
- ads/bds/auth.py +127 -0
- ads/bds/big_data_service.py +255 -0
- ads/catalog/__init__.py +19 -0
- ads/catalog/model.py +1576 -0
- ads/catalog/notebook.py +461 -0
- ads/catalog/project.py +468 -0
- ads/catalog/summary.py +178 -0
- ads/common/__init__.py +11 -0
- ads/common/analyzer.py +65 -0
- ads/common/artifact/.model-ignore +63 -0
- ads/common/artifact/__init__.py +10 -0
- ads/common/auth.py +1122 -0
- ads/common/card_identifier.py +83 -0
- ads/common/config.py +647 -0
- ads/common/data.py +165 -0
- ads/common/decorator/__init__.py +9 -0
- ads/common/decorator/argument_to_case.py +88 -0
- ads/common/decorator/deprecate.py +69 -0
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/decorator/runtime_dependency.py +178 -0
- ads/common/decorator/threaded.py +97 -0
- ads/common/decorator/utils.py +35 -0
- ads/common/dsc_file_system.py +303 -0
- ads/common/error.py +14 -0
- ads/common/extended_enum.py +81 -0
- ads/common/function/__init__.py +5 -0
- ads/common/function/fn_util.py +142 -0
- ads/common/function/func_conf.yaml +25 -0
- ads/common/ipython.py +76 -0
- ads/common/model.py +679 -0
- ads/common/model_artifact.py +1759 -0
- ads/common/model_artifact_schema.json +107 -0
- ads/common/model_export_util.py +664 -0
- ads/common/model_metadata.py +24 -0
- ads/common/object_storage_details.py +296 -0
- ads/common/oci_client.py +175 -0
- ads/common/oci_datascience.py +46 -0
- ads/common/oci_logging.py +1144 -0
- ads/common/oci_mixin.py +957 -0
- ads/common/oci_resource.py +136 -0
- ads/common/serializer.py +559 -0
- ads/common/utils.py +1852 -0
- ads/common/word_lists.py +1491 -0
- ads/common/work_request.py +189 -0
- ads/data_labeling/__init__.py +13 -0
- ads/data_labeling/boundingbox.py +253 -0
- ads/data_labeling/constants.py +47 -0
- ads/data_labeling/data_labeling_service.py +244 -0
- ads/data_labeling/interface/__init__.py +5 -0
- ads/data_labeling/interface/loader.py +16 -0
- ads/data_labeling/interface/parser.py +16 -0
- ads/data_labeling/interface/reader.py +23 -0
- ads/data_labeling/loader/__init__.py +5 -0
- ads/data_labeling/loader/file_loader.py +241 -0
- ads/data_labeling/metadata.py +110 -0
- ads/data_labeling/mixin/__init__.py +5 -0
- ads/data_labeling/mixin/data_labeling.py +232 -0
- ads/data_labeling/ner.py +129 -0
- ads/data_labeling/parser/__init__.py +5 -0
- ads/data_labeling/parser/dls_record_parser.py +388 -0
- ads/data_labeling/parser/export_metadata_parser.py +94 -0
- ads/data_labeling/parser/export_record_parser.py +473 -0
- ads/data_labeling/reader/__init__.py +5 -0
- ads/data_labeling/reader/dataset_reader.py +574 -0
- ads/data_labeling/reader/dls_record_reader.py +121 -0
- ads/data_labeling/reader/export_record_reader.py +62 -0
- ads/data_labeling/reader/jsonl_reader.py +75 -0
- ads/data_labeling/reader/metadata_reader.py +203 -0
- ads/data_labeling/reader/record_reader.py +263 -0
- ads/data_labeling/record.py +52 -0
- ads/data_labeling/visualizer/__init__.py +5 -0
- ads/data_labeling/visualizer/image_visualizer.py +525 -0
- ads/data_labeling/visualizer/text_visualizer.py +357 -0
- ads/database/__init__.py +5 -0
- ads/database/connection.py +338 -0
- ads/dataset/__init__.py +10 -0
- ads/dataset/capabilities.md +51 -0
- ads/dataset/classification_dataset.py +339 -0
- ads/dataset/correlation.py +226 -0
- ads/dataset/correlation_plot.py +563 -0
- ads/dataset/dask_series.py +173 -0
- ads/dataset/dataframe_transformer.py +110 -0
- ads/dataset/dataset.py +1979 -0
- ads/dataset/dataset_browser.py +360 -0
- ads/dataset/dataset_with_target.py +995 -0
- ads/dataset/exception.py +25 -0
- ads/dataset/factory.py +987 -0
- ads/dataset/feature_engineering_transformer.py +35 -0
- ads/dataset/feature_selection.py +107 -0
- ads/dataset/forecasting_dataset.py +26 -0
- ads/dataset/helper.py +1450 -0
- ads/dataset/label_encoder.py +99 -0
- ads/dataset/mixin/__init__.py +5 -0
- ads/dataset/mixin/dataset_accessor.py +134 -0
- ads/dataset/pipeline.py +58 -0
- ads/dataset/plot.py +710 -0
- ads/dataset/progress.py +86 -0
- ads/dataset/recommendation.py +297 -0
- ads/dataset/recommendation_transformer.py +502 -0
- ads/dataset/regression_dataset.py +14 -0
- ads/dataset/sampled_dataset.py +1050 -0
- ads/dataset/target.py +98 -0
- ads/dataset/timeseries.py +18 -0
- ads/dbmixin/__init__.py +5 -0
- ads/dbmixin/db_pandas_accessor.py +153 -0
- ads/environment/__init__.py +9 -0
- ads/environment/ml_runtime.py +66 -0
- ads/evaluations/README.md +14 -0
- ads/evaluations/__init__.py +109 -0
- ads/evaluations/evaluation_plot.py +983 -0
- ads/evaluations/evaluator.py +1334 -0
- ads/evaluations/statistical_metrics.py +543 -0
- ads/experiments/__init__.py +9 -0
- ads/experiments/capabilities.md +0 -0
- ads/explanations/__init__.py +21 -0
- ads/explanations/base_explainer.py +142 -0
- ads/explanations/capabilities.md +83 -0
- ads/explanations/explainer.py +190 -0
- ads/explanations/mlx_global_explainer.py +1050 -0
- ads/explanations/mlx_interface.py +386 -0
- ads/explanations/mlx_local_explainer.py +287 -0
- ads/explanations/mlx_whatif_explainer.py +201 -0
- ads/feature_engineering/__init__.py +20 -0
- ads/feature_engineering/accessor/__init__.py +5 -0
- ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
- ads/feature_engineering/accessor/mixin/__init__.py +5 -0
- ads/feature_engineering/accessor/mixin/correlation.py +166 -0
- ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
- ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
- ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
- ads/feature_engineering/accessor/mixin/utils.py +65 -0
- ads/feature_engineering/accessor/series_accessor.py +431 -0
- ads/feature_engineering/adsimage/__init__.py +5 -0
- ads/feature_engineering/adsimage/image.py +192 -0
- ads/feature_engineering/adsimage/image_reader.py +170 -0
- ads/feature_engineering/adsimage/interface/__init__.py +5 -0
- ads/feature_engineering/adsimage/interface/reader.py +19 -0
- ads/feature_engineering/adsstring/__init__.py +7 -0
- ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
- ads/feature_engineering/adsstring/string/__init__.py +8 -0
- ads/feature_engineering/data_schema.json +57 -0
- ads/feature_engineering/dataset/__init__.py +5 -0
- ads/feature_engineering/dataset/zip_code_data.py +42062 -0
- ads/feature_engineering/exceptions.py +40 -0
- ads/feature_engineering/feature_type/__init__.py +133 -0
- ads/feature_engineering/feature_type/address.py +184 -0
- ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
- ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
- ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
- ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
- ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
- ads/feature_engineering/feature_type/adsstring/string.py +258 -0
- ads/feature_engineering/feature_type/base.py +58 -0
- ads/feature_engineering/feature_type/boolean.py +183 -0
- ads/feature_engineering/feature_type/category.py +146 -0
- ads/feature_engineering/feature_type/constant.py +137 -0
- ads/feature_engineering/feature_type/continuous.py +151 -0
- ads/feature_engineering/feature_type/creditcard.py +314 -0
- ads/feature_engineering/feature_type/datetime.py +190 -0
- ads/feature_engineering/feature_type/discrete.py +134 -0
- ads/feature_engineering/feature_type/document.py +43 -0
- ads/feature_engineering/feature_type/gis.py +251 -0
- ads/feature_engineering/feature_type/handler/__init__.py +5 -0
- ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
- ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
- ads/feature_engineering/feature_type/handler/warnings.py +128 -0
- ads/feature_engineering/feature_type/integer.py +142 -0
- ads/feature_engineering/feature_type/ip_address.py +144 -0
- ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
- ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
- ads/feature_engineering/feature_type/lat_long.py +256 -0
- ads/feature_engineering/feature_type/object.py +43 -0
- ads/feature_engineering/feature_type/ordinal.py +132 -0
- ads/feature_engineering/feature_type/phone_number.py +135 -0
- ads/feature_engineering/feature_type/string.py +171 -0
- ads/feature_engineering/feature_type/text.py +93 -0
- ads/feature_engineering/feature_type/unknown.py +43 -0
- ads/feature_engineering/feature_type/zip_code.py +164 -0
- ads/feature_engineering/feature_type_manager.py +406 -0
- ads/feature_engineering/schema.py +795 -0
- ads/feature_engineering/utils.py +245 -0
- ads/feature_store/.readthedocs.yaml +19 -0
- ads/feature_store/README.md +65 -0
- ads/feature_store/__init__.py +9 -0
- ads/feature_store/common/__init__.py +0 -0
- ads/feature_store/common/enums.py +339 -0
- ads/feature_store/common/exceptions.py +18 -0
- ads/feature_store/common/spark_session_singleton.py +125 -0
- ads/feature_store/common/utils/__init__.py +0 -0
- ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
- ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
- ads/feature_store/common/utils/transformation_utils.py +82 -0
- ads/feature_store/common/utils/utility.py +403 -0
- ads/feature_store/data_validation/__init__.py +0 -0
- ads/feature_store/data_validation/great_expectation.py +129 -0
- ads/feature_store/dataset.py +1230 -0
- ads/feature_store/dataset_job.py +530 -0
- ads/feature_store/docs/Dockerfile +7 -0
- ads/feature_store/docs/Makefile +44 -0
- ads/feature_store/docs/conf.py +28 -0
- ads/feature_store/docs/requirements.txt +14 -0
- ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
- ads/feature_store/docs/source/cicd.rst +137 -0
- ads/feature_store/docs/source/conf.py +86 -0
- ads/feature_store/docs/source/data_versioning.rst +33 -0
- ads/feature_store/docs/source/dataset.rst +388 -0
- ads/feature_store/docs/source/dataset_job.rst +27 -0
- ads/feature_store/docs/source/demo.rst +70 -0
- ads/feature_store/docs/source/entity.rst +78 -0
- ads/feature_store/docs/source/feature_group.rst +624 -0
- ads/feature_store/docs/source/feature_group_job.rst +29 -0
- ads/feature_store/docs/source/feature_store.rst +122 -0
- ads/feature_store/docs/source/feature_store_class.rst +123 -0
- ads/feature_store/docs/source/feature_validation.rst +66 -0
- ads/feature_store/docs/source/figures/cicd.png +0 -0
- ads/feature_store/docs/source/figures/data_validation.png +0 -0
- ads/feature_store/docs/source/figures/data_versioning.png +0 -0
- ads/feature_store/docs/source/figures/dataset.gif +0 -0
- ads/feature_store/docs/source/figures/dataset.png +0 -0
- ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
- ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
- ads/feature_store/docs/source/figures/entity.png +0 -0
- ads/feature_store/docs/source/figures/feature_group.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
- ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
- ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
- ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
- ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
- ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
- ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
- ads/feature_store/docs/source/figures/overview.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
- ads/feature_store/docs/source/figures/stats_1.png +0 -0
- ads/feature_store/docs/source/figures/stats_2.png +0 -0
- ads/feature_store/docs/source/figures/stats_d.png +0 -0
- ads/feature_store/docs/source/figures/stats_fg.png +0 -0
- ads/feature_store/docs/source/figures/transformation.png +0 -0
- ads/feature_store/docs/source/figures/transformations.gif +0 -0
- ads/feature_store/docs/source/figures/validation.png +0 -0
- ads/feature_store/docs/source/figures/validation_fg.png +0 -0
- ads/feature_store/docs/source/figures/validation_results.png +0 -0
- ads/feature_store/docs/source/figures/validation_summary.png +0 -0
- ads/feature_store/docs/source/index.rst +81 -0
- ads/feature_store/docs/source/module.rst +8 -0
- ads/feature_store/docs/source/notebook.rst +94 -0
- ads/feature_store/docs/source/overview.rst +47 -0
- ads/feature_store/docs/source/quickstart.rst +176 -0
- ads/feature_store/docs/source/release_notes.rst +194 -0
- ads/feature_store/docs/source/setup_feature_store.rst +81 -0
- ads/feature_store/docs/source/statistics.rst +58 -0
- ads/feature_store/docs/source/transformation.rst +199 -0
- ads/feature_store/docs/source/ui.rst +65 -0
- ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
- ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
- ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
- ads/feature_store/entity.py +718 -0
- ads/feature_store/execution_strategy/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
- ads/feature_store/execution_strategy/engine/__init__.py +0 -0
- ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
- ads/feature_store/execution_strategy/execution_strategy.py +113 -0
- ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
- ads/feature_store/execution_strategy/spark/__init__.py +0 -0
- ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
- ads/feature_store/feature.py +192 -0
- ads/feature_store/feature_group.py +1494 -0
- ads/feature_store/feature_group_expectation.py +346 -0
- ads/feature_store/feature_group_job.py +602 -0
- ads/feature_store/feature_lineage/__init__.py +0 -0
- ads/feature_store/feature_lineage/graphviz_service.py +180 -0
- ads/feature_store/feature_option_details.py +50 -0
- ads/feature_store/feature_statistics/__init__.py +0 -0
- ads/feature_store/feature_statistics/statistics_service.py +99 -0
- ads/feature_store/feature_store.py +699 -0
- ads/feature_store/feature_store_registrar.py +518 -0
- ads/feature_store/input_feature_detail.py +149 -0
- ads/feature_store/mixin/__init__.py +4 -0
- ads/feature_store/mixin/oci_feature_store.py +145 -0
- ads/feature_store/model_details.py +73 -0
- ads/feature_store/query/__init__.py +0 -0
- ads/feature_store/query/filter.py +266 -0
- ads/feature_store/query/generator/__init__.py +0 -0
- ads/feature_store/query/generator/query_generator.py +298 -0
- ads/feature_store/query/join.py +161 -0
- ads/feature_store/query/query.py +403 -0
- ads/feature_store/query/validator/__init__.py +0 -0
- ads/feature_store/query/validator/query_validator.py +57 -0
- ads/feature_store/response/__init__.py +0 -0
- ads/feature_store/response/response_builder.py +68 -0
- ads/feature_store/service/__init__.py +0 -0
- ads/feature_store/service/oci_dataset.py +139 -0
- ads/feature_store/service/oci_dataset_job.py +199 -0
- ads/feature_store/service/oci_entity.py +125 -0
- ads/feature_store/service/oci_feature_group.py +164 -0
- ads/feature_store/service/oci_feature_group_job.py +214 -0
- ads/feature_store/service/oci_feature_store.py +182 -0
- ads/feature_store/service/oci_lineage.py +87 -0
- ads/feature_store/service/oci_transformation.py +104 -0
- ads/feature_store/statistics/__init__.py +0 -0
- ads/feature_store/statistics/abs_feature_value.py +49 -0
- ads/feature_store/statistics/charts/__init__.py +0 -0
- ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
- ads/feature_store/statistics/charts/box_plot.py +148 -0
- ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
- ads/feature_store/statistics/charts/probability_distribution.py +68 -0
- ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
- ads/feature_store/statistics/feature_stat.py +126 -0
- ads/feature_store/statistics/generic_feature_value.py +33 -0
- ads/feature_store/statistics/statistics.py +41 -0
- ads/feature_store/statistics_config.py +101 -0
- ads/feature_store/templates/feature_store_template.yaml +45 -0
- ads/feature_store/transformation.py +499 -0
- ads/feature_store/validation_output.py +57 -0
- ads/hpo/__init__.py +9 -0
- ads/hpo/_imports.py +91 -0
- ads/hpo/ads_search_space.py +439 -0
- ads/hpo/distributions.py +325 -0
- ads/hpo/objective.py +280 -0
- ads/hpo/search_cv.py +1657 -0
- ads/hpo/stopping_criterion.py +75 -0
- ads/hpo/tuner_artifact.py +413 -0
- ads/hpo/utils.py +91 -0
- ads/hpo/validation.py +140 -0
- ads/hpo/visualization/__init__.py +5 -0
- ads/hpo/visualization/_contour.py +23 -0
- ads/hpo/visualization/_edf.py +20 -0
- ads/hpo/visualization/_intermediate_values.py +21 -0
- ads/hpo/visualization/_optimization_history.py +25 -0
- ads/hpo/visualization/_parallel_coordinate.py +169 -0
- ads/hpo/visualization/_param_importances.py +26 -0
- ads/jobs/__init__.py +53 -0
- ads/jobs/ads_job.py +663 -0
- ads/jobs/builders/__init__.py +5 -0
- ads/jobs/builders/base.py +156 -0
- ads/jobs/builders/infrastructure/__init__.py +6 -0
- ads/jobs/builders/infrastructure/base.py +165 -0
- ads/jobs/builders/infrastructure/dataflow.py +1252 -0
- ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
- ads/jobs/builders/infrastructure/utils.py +65 -0
- ads/jobs/builders/runtimes/__init__.py +5 -0
- ads/jobs/builders/runtimes/artifact.py +338 -0
- ads/jobs/builders/runtimes/base.py +325 -0
- ads/jobs/builders/runtimes/container_runtime.py +242 -0
- ads/jobs/builders/runtimes/python_runtime.py +1016 -0
- ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
- ads/jobs/cli.py +104 -0
- ads/jobs/env_var_parser.py +131 -0
- ads/jobs/extension.py +160 -0
- ads/jobs/schema/__init__.py +5 -0
- ads/jobs/schema/infrastructure_schema.json +116 -0
- ads/jobs/schema/job_schema.json +42 -0
- ads/jobs/schema/runtime_schema.json +183 -0
- ads/jobs/schema/validator.py +141 -0
- ads/jobs/serializer.py +296 -0
- ads/jobs/templates/__init__.py +5 -0
- ads/jobs/templates/container.py +6 -0
- ads/jobs/templates/driver_notebook.py +177 -0
- ads/jobs/templates/driver_oci.py +500 -0
- ads/jobs/templates/driver_python.py +48 -0
- ads/jobs/templates/driver_pytorch.py +852 -0
- ads/jobs/templates/driver_utils.py +615 -0
- ads/jobs/templates/hostname_from_env.c +55 -0
- ads/jobs/templates/oci_metrics.py +181 -0
- ads/jobs/utils.py +104 -0
- ads/llm/__init__.py +28 -0
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/v02/client.py +295 -0
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/chain.py +268 -0
- ads/llm/chat_template.py +31 -0
- ads/llm/deploy.py +63 -0
- ads/llm/guardrails/__init__.py +5 -0
- ads/llm/guardrails/base.py +442 -0
- ads/llm/guardrails/huggingface.py +44 -0
- ads/llm/langchain/__init__.py +5 -0
- ads/llm/langchain/plugins/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/llm/langchain/plugins/llms/__init__.py +5 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
- ads/llm/requirements.txt +3 -0
- ads/llm/serialize.py +219 -0
- ads/llm/serializers/__init__.py +0 -0
- ads/llm/serializers/retrieval_qa.py +153 -0
- ads/llm/serializers/runnable_parallel.py +27 -0
- ads/llm/templates/score_chain.jinja2 +155 -0
- ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
- ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
- ads/model/__init__.py +52 -0
- ads/model/artifact.py +573 -0
- ads/model/artifact_downloader.py +254 -0
- ads/model/artifact_uploader.py +267 -0
- ads/model/base_properties.py +238 -0
- ads/model/common/.model-ignore +66 -0
- ads/model/common/__init__.py +5 -0
- ads/model/common/utils.py +142 -0
- ads/model/datascience_model.py +2635 -0
- ads/model/deployment/__init__.py +20 -0
- ads/model/deployment/common/__init__.py +5 -0
- ads/model/deployment/common/utils.py +308 -0
- ads/model/deployment/model_deployer.py +466 -0
- ads/model/deployment/model_deployment.py +1846 -0
- ads/model/deployment/model_deployment_infrastructure.py +671 -0
- ads/model/deployment/model_deployment_properties.py +493 -0
- ads/model/deployment/model_deployment_runtime.py +838 -0
- ads/model/extractor/__init__.py +5 -0
- ads/model/extractor/automl_extractor.py +74 -0
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/extractor/huggingface_extractor.py +88 -0
- ads/model/extractor/keras_extractor.py +84 -0
- ads/model/extractor/lightgbm_extractor.py +93 -0
- ads/model/extractor/model_info_extractor.py +114 -0
- ads/model/extractor/model_info_extractor_factory.py +105 -0
- ads/model/extractor/pytorch_extractor.py +87 -0
- ads/model/extractor/sklearn_extractor.py +112 -0
- ads/model/extractor/spark_extractor.py +89 -0
- ads/model/extractor/tensorflow_extractor.py +85 -0
- ads/model/extractor/xgboost_extractor.py +94 -0
- ads/model/framework/__init__.py +5 -0
- ads/model/framework/automl_model.py +178 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/framework/huggingface_model.py +399 -0
- ads/model/framework/lightgbm_model.py +266 -0
- ads/model/framework/pytorch_model.py +266 -0
- ads/model/framework/sklearn_model.py +250 -0
- ads/model/framework/spark_model.py +326 -0
- ads/model/framework/tensorflow_model.py +254 -0
- ads/model/framework/xgboost_model.py +258 -0
- ads/model/generic_model.py +3518 -0
- ads/model/model_artifact_boilerplate/README.md +381 -0
- ads/model/model_artifact_boilerplate/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
- ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
- ads/model/model_artifact_boilerplate/score.py +61 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_introspect.py +331 -0
- ads/model/model_metadata.py +1810 -0
- ads/model/model_metadata_mixin.py +460 -0
- ads/model/model_properties.py +63 -0
- ads/model/model_version_set.py +739 -0
- ads/model/runtime/__init__.py +5 -0
- ads/model/runtime/env_info.py +306 -0
- ads/model/runtime/model_deployment_details.py +37 -0
- ads/model/runtime/model_provenance_details.py +58 -0
- ads/model/runtime/runtime_info.py +81 -0
- ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
- ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
- ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
- ads/model/runtime/utils.py +201 -0
- ads/model/serde/__init__.py +5 -0
- ads/model/serde/common.py +40 -0
- ads/model/serde/model_input.py +547 -0
- ads/model/serde/model_serializer.py +1184 -0
- ads/model/service/__init__.py +5 -0
- ads/model/service/oci_datascience_model.py +1076 -0
- ads/model/service/oci_datascience_model_deployment.py +500 -0
- ads/model/service/oci_datascience_model_version_set.py +176 -0
- ads/model/transformer/__init__.py +5 -0
- ads/model/transformer/onnx_transformer.py +324 -0
- ads/mysqldb/__init__.py +5 -0
- ads/mysqldb/mysql_db.py +227 -0
- ads/opctl/__init__.py +18 -0
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/__init__.py +5 -0
- ads/opctl/backend/ads_dataflow.py +353 -0
- ads/opctl/backend/ads_ml_job.py +710 -0
- ads/opctl/backend/ads_ml_pipeline.py +164 -0
- ads/opctl/backend/ads_model_deployment.py +209 -0
- ads/opctl/backend/base.py +146 -0
- ads/opctl/backend/local.py +1053 -0
- ads/opctl/backend/marketplace/__init__.py +9 -0
- ads/opctl/backend/marketplace/helm_helper.py +173 -0
- ads/opctl/backend/marketplace/local_marketplace.py +271 -0
- ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
- ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
- ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
- ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
- ads/opctl/backend/marketplace/models/__init__.py +5 -0
- ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
- ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
- ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
- ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
- ads/opctl/cli.py +707 -0
- ads/opctl/cmds.py +869 -0
- ads/opctl/conda/__init__.py +5 -0
- ads/opctl/conda/cli.py +193 -0
- ads/opctl/conda/cmds.py +749 -0
- ads/opctl/conda/config.yaml +34 -0
- ads/opctl/conda/manifest_template.yaml +13 -0
- ads/opctl/conda/multipart_uploader.py +188 -0
- ads/opctl/conda/pack.py +89 -0
- ads/opctl/config/__init__.py +5 -0
- ads/opctl/config/base.py +57 -0
- ads/opctl/config/diagnostics/__init__.py +5 -0
- ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
- ads/opctl/config/merger.py +255 -0
- ads/opctl/config/resolver.py +297 -0
- ads/opctl/config/utils.py +79 -0
- ads/opctl/config/validator.py +17 -0
- ads/opctl/config/versioner.py +68 -0
- ads/opctl/config/yaml_parsers/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/base.py +58 -0
- ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
- ads/opctl/constants.py +66 -0
- ads/opctl/decorator/__init__.py +5 -0
- ads/opctl/decorator/common.py +129 -0
- ads/opctl/diagnostics/__init__.py +5 -0
- ads/opctl/diagnostics/__main__.py +25 -0
- ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
- ads/opctl/diagnostics/check_requirements.py +144 -0
- ads/opctl/diagnostics/requirement_exception.py +9 -0
- ads/opctl/distributed/README.md +109 -0
- ads/opctl/distributed/__init__.py +5 -0
- ads/opctl/distributed/certificates.py +32 -0
- ads/opctl/distributed/cli.py +207 -0
- ads/opctl/distributed/cmds.py +731 -0
- ads/opctl/distributed/common/__init__.py +5 -0
- ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
- ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
- ads/opctl/distributed/common/cluster_config_helper.py +103 -0
- ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
- ads/opctl/distributed/common/cluster_runner.py +54 -0
- ads/opctl/distributed/common/framework_factory.py +29 -0
- ads/opctl/docker/Dockerfile.job +103 -0
- ads/opctl/docker/Dockerfile.job.arm +107 -0
- ads/opctl/docker/Dockerfile.job.gpu +175 -0
- ads/opctl/docker/base-env.yaml +13 -0
- ads/opctl/docker/cuda.repo +6 -0
- ads/opctl/docker/operator/.dockerignore +0 -0
- ads/opctl/docker/operator/Dockerfile +41 -0
- ads/opctl/docker/operator/Dockerfile.gpu +85 -0
- ads/opctl/docker/operator/cuda.repo +6 -0
- ads/opctl/docker/operator/environment.yaml +8 -0
- ads/opctl/forecast.py +11 -0
- ads/opctl/index.yaml +3 -0
- ads/opctl/model/__init__.py +5 -0
- ads/opctl/model/cli.py +65 -0
- ads/opctl/model/cmds.py +73 -0
- ads/opctl/operator/README.md +4 -0
- ads/opctl/operator/__init__.py +31 -0
- ads/opctl/operator/cli.py +344 -0
- ads/opctl/operator/cmd.py +596 -0
- ads/opctl/operator/common/__init__.py +5 -0
- ads/opctl/operator/common/backend_factory.py +460 -0
- ads/opctl/operator/common/const.py +27 -0
- ads/opctl/operator/common/data/synthetic.csv +16001 -0
- ads/opctl/operator/common/dictionary_merger.py +148 -0
- ads/opctl/operator/common/errors.py +42 -0
- ads/opctl/operator/common/operator_config.py +99 -0
- ads/opctl/operator/common/operator_loader.py +811 -0
- ads/opctl/operator/common/operator_schema.yaml +130 -0
- ads/opctl/operator/common/operator_yaml_generator.py +152 -0
- ads/opctl/operator/common/utils.py +208 -0
- ads/opctl/operator/lowcode/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
- ads/opctl/operator/lowcode/anomaly/README.md +207 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +167 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
- ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +116 -0
- ads/opctl/operator/lowcode/common/errors.py +47 -0
- ads/opctl/operator/lowcode/common/transformations.py +296 -0
- ads/opctl/operator/lowcode/common/utils.py +384 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
- ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
- ads/opctl/operator/lowcode/forecast/README.md +209 -0
- ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
- ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
- ads/opctl/operator/lowcode/forecast/const.py +92 -0
- ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
- ads/opctl/operator/lowcode/forecast/errors.py +26 -0
- ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
- ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
- ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
- ads/opctl/operator/lowcode/forecast/model/prophet.py +445 -0
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
- ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
- ads/opctl/operator/lowcode/forecast/utils.py +397 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
- ads/opctl/operator/lowcode/pii/MLoperator +17 -0
- ads/opctl/operator/lowcode/pii/README.md +208 -0
- ads/opctl/operator/lowcode/pii/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/__main__.py +78 -0
- ads/opctl/operator/lowcode/pii/cmd.py +39 -0
- ads/opctl/operator/lowcode/pii/constant.py +84 -0
- ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
- ads/opctl/operator/lowcode/pii/errors.py +27 -0
- ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
- ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
- ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
- ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
- ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
- ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
- ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
- ads/opctl/operator/lowcode/pii/model/report.py +487 -0
- ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
- ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
- ads/opctl/operator/lowcode/pii/utils.py +43 -0
- ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
- ads/opctl/operator/lowcode/recommender/README.md +206 -0
- ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
- ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
- ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
- ads/opctl/operator/lowcode/recommender/constant.py +30 -0
- ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
- ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
- ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
- ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
- ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
- ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
- ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
- ads/opctl/operator/lowcode/recommender/utils.py +13 -0
- ads/opctl/operator/runtime/__init__.py +5 -0
- ads/opctl/operator/runtime/const.py +17 -0
- ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
- ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
- ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/runtime.py +115 -0
- ads/opctl/schema.yaml.yml +36 -0
- ads/opctl/script.py +40 -0
- ads/opctl/spark/__init__.py +5 -0
- ads/opctl/spark/cli.py +43 -0
- ads/opctl/spark/cmds.py +147 -0
- ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
- ads/opctl/utils.py +344 -0
- ads/oracledb/__init__.py +5 -0
- ads/oracledb/oracle_db.py +346 -0
- ads/pipeline/__init__.py +39 -0
- ads/pipeline/ads_pipeline.py +2279 -0
- ads/pipeline/ads_pipeline_run.py +772 -0
- ads/pipeline/ads_pipeline_step.py +605 -0
- ads/pipeline/builders/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/custom_script.py +32 -0
- ads/pipeline/cli.py +119 -0
- ads/pipeline/extension.py +291 -0
- ads/pipeline/schema/__init__.py +5 -0
- ads/pipeline/schema/cs_step_schema.json +35 -0
- ads/pipeline/schema/ml_step_schema.json +31 -0
- ads/pipeline/schema/pipeline_schema.json +71 -0
- ads/pipeline/visualizer/__init__.py +5 -0
- ads/pipeline/visualizer/base.py +570 -0
- ads/pipeline/visualizer/graph_renderer.py +272 -0
- ads/pipeline/visualizer/text_renderer.py +84 -0
- ads/secrets/__init__.py +11 -0
- ads/secrets/adb.py +386 -0
- ads/secrets/auth_token.py +86 -0
- ads/secrets/big_data_service.py +365 -0
- ads/secrets/mysqldb.py +149 -0
- ads/secrets/oracledb.py +160 -0
- ads/secrets/secrets.py +407 -0
- ads/telemetry/__init__.py +7 -0
- ads/telemetry/base.py +69 -0
- ads/telemetry/client.py +125 -0
- ads/telemetry/telemetry.py +257 -0
- ads/templates/dataflow_pyspark.jinja2 +13 -0
- ads/templates/dataflow_sparksql.jinja2 +22 -0
- ads/templates/func.jinja2 +20 -0
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score-pkl.jinja2 +173 -0
- ads/templates/score.jinja2 +322 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- ads/templates/score_generic.jinja2 +165 -0
- ads/templates/score_huggingface_pipeline.jinja2 +217 -0
- ads/templates/score_lightgbm.jinja2 +185 -0
- ads/templates/score_onnx.jinja2 +407 -0
- ads/templates/score_onnx_new.jinja2 +473 -0
- ads/templates/score_oracle_automl.jinja2 +185 -0
- ads/templates/score_pyspark.jinja2 +154 -0
- ads/templates/score_pytorch.jinja2 +219 -0
- ads/templates/score_scikit-learn.jinja2 +184 -0
- ads/templates/score_tensorflow.jinja2 +184 -0
- ads/templates/score_xgboost.jinja2 +178 -0
- ads/text_dataset/__init__.py +5 -0
- ads/text_dataset/backends.py +211 -0
- ads/text_dataset/dataset.py +445 -0
- ads/text_dataset/extractor.py +207 -0
- ads/text_dataset/options.py +53 -0
- ads/text_dataset/udfs.py +22 -0
- ads/text_dataset/utils.py +49 -0
- ads/type_discovery/__init__.py +9 -0
- ads/type_discovery/abstract_detector.py +21 -0
- ads/type_discovery/constant_detector.py +41 -0
- ads/type_discovery/continuous_detector.py +54 -0
- ads/type_discovery/credit_card_detector.py +99 -0
- ads/type_discovery/datetime_detector.py +92 -0
- ads/type_discovery/discrete_detector.py +118 -0
- ads/type_discovery/document_detector.py +146 -0
- ads/type_discovery/ip_detector.py +68 -0
- ads/type_discovery/latlon_detector.py +90 -0
- ads/type_discovery/phone_number_detector.py +63 -0
- ads/type_discovery/type_discovery_driver.py +87 -0
- ads/type_discovery/typed_feature.py +594 -0
- ads/type_discovery/unknown_detector.py +41 -0
- ads/type_discovery/zipcode_detector.py +48 -0
- ads/vault/__init__.py +7 -0
- ads/vault/vault.py +237 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/METADATA +150 -150
- oracle_ads-2.13.9rc1.dist-info/RECORD +858 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/WHEEL +1 -2
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/entry_points.txt +2 -1
- oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
- oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,1050 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8; -*-
|
3
|
+
|
4
|
+
# Copyright (c) 2020, 2023 Oracle and/or its affiliates.
|
5
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
from abc import ABC, abstractmethod
|
9
|
+
|
10
|
+
from ads.common import logger, utils
|
11
|
+
from ads.explanations.base_explainer import GlobalExplainer
|
12
|
+
from ads.explanations.mlx_interface import check_tabular_or_text
|
13
|
+
from ads.explanations.mlx_interface import init_lime_explainer
|
14
|
+
from ads.explanations.mlx_interface import init_permutation_importance_explainer
|
15
|
+
from ads.explanations.mlx_interface import (
|
16
|
+
init_partial_dependence_explainer,
|
17
|
+
init_ale_explainer,
|
18
|
+
)
|
19
|
+
from ads.common.decorator.runtime_dependency import (
|
20
|
+
runtime_dependency,
|
21
|
+
OptionalDependency,
|
22
|
+
)
|
23
|
+
from ads.common.decorator.deprecate import deprecated
|
24
|
+
|
25
|
+
|
26
|
+
class MLXGlobalExplainer(GlobalExplainer):
|
27
|
+
"""
|
28
|
+
Global Explainer class.
|
29
|
+
|
30
|
+
Generates global explanations to help understand the general model
|
31
|
+
behavior. Supported explanations:
|
32
|
+
|
33
|
+
- (Tabular) Feature Permutation Importance.
|
34
|
+
- (Tabular) Partial Dependence Plots (PDP) & Individual Conditional
|
35
|
+
Expectation (ICE).
|
36
|
+
- (Text) Aggregate local explanations (global explanation approximation
|
37
|
+
constructed from multiple local explanations).
|
38
|
+
|
39
|
+
Supports:
|
40
|
+
|
41
|
+
- Binary classification.
|
42
|
+
- Multi-class classification.
|
43
|
+
- Regression.
|
44
|
+
|
45
|
+
"""
|
46
|
+
|
47
|
+
@deprecated(
    details="Working with AutoML has moved from within ADS to working directly with the AutoMLx library. AutoMLx are preinstalled in conda pack automlx_p38_cpu_v2 and later, and can now be updated independently of ADS. AutoMLx documentation may be found at https://docs.oracle.com/en-us/iaas/tools/automlx/latest/html/multiversion/v23.1.1/index.html. Notebook examples are in Oracle's samples repository: https://github.com/oracle-samples/oci-data-science-ai-samples/tree/master/notebook_examples and a migration tutorial can be found at https://accelerated-data-science.readthedocs.io/en/latest/user_guide/model_training/automl/quick_start.html .",
    raise_error=True,
)
def __init__(self):
    """
    Creates the explainer with all of its explainer handles unset.

    The feature-importance, partial-dependence and accumulated-local-effects
    explainers start as `None`; the `compute_*` methods of this class create
    the latter two on first use when they are still `None`.

    NOTE(review): the constructor is wrapped in `deprecated(..., raise_error=True)`;
    presumably this makes instantiation raise with the migration message above --
    confirm against `ads.common.decorator.deprecate`.
    """
    # NOTE(review): `super(GlobalExplainer, self)` starts the MRO lookup *after*
    # `GlobalExplainer`, so `GlobalExplainer.__init__` itself is not invoked here.
    # Confirm that this is intentional.
    super(GlobalExplainer, self).__init__()
    # Chained assignment binds the targets left to right.
    self.explainer = self.selected_features = self.pdp_explainer = self.ale_explainer = None
|
57
|
+
|
58
|
+
def compute_feature_importance(
    self,
    n_iter=20,
    sampling=None,
    balance=False,
    scoring_metric=None,
    selected_features=None,
):
    """
    Generates a global explanation to help understand the general behavior
    of the model. This explainer identifies which features are most important
    to the model.

    If the dataset is tabular, computes a global feature permutation importance
    explanation. If the dataset is text, approximates a global explanation by
    generating and aggregating multiple local explanations.

    Parameters
    ----------
    n_iter : int, optional
        Number of iterations of the permutation importance algorithm to
        perform. Increasing this value increases the quality/stability of
        the explanation, but increases the explanation time. Default value is 20.
        Only used for tabular datasets.
    sampling : dict, optional
        If not `None`, the dataset is clustered or sampled according to the
        provided technique. `sampling` is a dictionary containing the technique
        to use and the corresponding parameters. Format is described below:

        - `technique`: Either `cluster` or `random`.
        - If `cluster`, also requires:

            - `eps`: Maximum distance between two samples to be considered
              in the same cluster.
            - `min_samples`: Minimum number of samples to include in each
              cluster.

        - If `random`, also requires:

            - `n_samples`: Number of samples to return.

        By default None. Note that text datasets are always sampled. If not provided
        with a sampling option, defaults to 40 random samples.
    balance : bool, optional
        If True, the dataset will be balanced via sampling. If 'sampling' is not
        set, the sampling technique defaults to 'random'.
    scoring_metric : string, optional
        If specified, the name of the scoring metric that is passed on to the
        permutation importance explainer. The original documentation attributes
        the out-of-the-box metrics to "ScyPy" (presumably scikit-learn style
        scorer names -- TODO confirm). Supported Metrics:

        - Multi-class Classification

            `f1_weighted`, `f1_micro`, `f1_macro`, `recall_weighted`, `recall_micro`, `recall_macro`,
            `accuracy`, `balanced_accuracy`, `roc_auc`, `precision_weighted`, `precision_macro`,
            `precision_micro`

        - Binary Classification

            Same as multi-class classification

        - Regression

            `r2`, `neg_mean_squared_error`, `neg_root_mean_squared_error`, `neg_mean_absolute_error`,
            `neg_median_absolute_error`, `neg_mean_absolute_percentage_error`,
            `neg_symmetric_mean_absolute_percentage_error`
    selected_features: list[str], list[int], optional
        List of the selected features. It can be any subset of
        the original features that are in the dataset provided to the model.
        Default value is None.

    Returns
    -------
    :class:FeatureImportance
        `FeaturePermutationImportance` explanation object.

    Raises
    ------
    Exception
        If `scoring_metric` is given but is not allowed for the problem type
        (tabular datasets only).
    IndexError
        If the explainer raises an `IndexError`. When `selected_features` was
        provided, the error is re-raised with a message pointing at the
        selection (the original error is chained as its cause); otherwise the
        original error propagates unchanged.
    """
    self.selected_features = selected_features
    self.configure_feature_importance(selected_features=self.selected_features)

    if self.explainer.config.type == "text":
        labels = list(range(len(self.class_names)))
        # Text datasets are always down-sampled; this requirement should be
        # lifted at some point.
        if sampling is None:
            sampling = {"technique": "random", "n_samples": 40}
        explanation = self.explainer.explain_aggregate_local(
            self.X_test, sampling=sampling, labels=labels
        )
        return FeatureImportance(explanation, self.class_names, self.explainer.config)

    # Tabular dataset: validate the requested metric against the problem type.
    if self.mode_ == "regression":
        allowed_metrics = [
            "r2",
            "neg_mean_squared_error",
            "neg_root_mean_squared_error",
            "neg_mean_absolute_error",
            "neg_median_absolute_error",
            "neg_mean_absolute_percentage_error",
            "neg_symmetric_mean_absolute_percentage_error",
        ]
    else:
        # Binary and multi-class classification accept the same metrics, so a
        # single list serves both cases.
        allowed_metrics = [
            "f1_weighted",
            "f1_micro",
            "f1_macro",
            "recall_weighted",
            "recall_micro",
            "recall_macro",
            "accuracy",
            "balanced_accuracy",
            "roc_auc",
            "precision_weighted",
            "precision_macro",
            "precision_micro",
        ]

    if scoring_metric is not None and scoring_metric not in allowed_metrics:
        raise Exception(
            "Scoring Metric not supported for this type of problem: {}, for problem type {}, the availble supported metrics are {}".format(
                scoring_metric, self.mode_, allowed_metrics
            )
        )

    if balance and sampling is None:
        sampling = {"technique": "random"}

    try:
        explanation = self.explainer.compute(
            self.X_test,
            self.y_test,
            n_iter=n_iter,
            sampling=sampling,
            balance=balance,
            scoring_metric=scoring_metric,
        )
    except IndexError as e:
        if selected_features is None:
            # Without a feature selection there is nothing more specific to
            # report. Re-raise instead of swallowing the error, which would
            # leave `explanation` unbound for the return statement below.
            raise
        raise IndexError(
            f"Unable to calculate permutation importance due to: {e}. "
            f"selected_features must be a list of features within the bounds of the existing features "
            f"(that were provided to model). Provided selected_features: {selected_features}."
        ) from e
    except Exception as e:
        logger.error(
            f"Unable to calculate permutation importance scores due to: {e}."
        )
        raise

    return FeatureImportance(explanation, self.class_names, self.explainer.config)
|
217
|
+
|
218
|
+
def compute_partial_dependence(
    self, features, partial_range=(0.00, 1.0), num_samples=30, sampling=None
):
    """
    Generates a global partial dependence plot (PDP) and individual conditional
    expectation (ICE) plots to help understand the relationship between feature
    values and the model target.

    Only supported for tabular datasets.

    Parameters
    ----------

    features : list of int, list of str
        Feature(s) to explain. A single value is wrapped in a list. Every entry
        is converted with `str()` and matched case-insensitively against the
        column names of `X_train`. At most two features are supported.
        NOTE(review): integer entries are matched as names, not as positional
        indices -- confirm the intended behaviour.
    partial_range : tuple, optional
        2-tuple with the minimum and maximum percentile values to consider for the PDP from
        the feature's train distribution. Must be between 0.0 and 1.0.
        Defaults to `partial_range = (0.00, 1.0)`.
    num_samples : int, optional
        Maximum number of samples to generate for each feature within the
        `partial_range` of its value distribution. Increasing this value
        generates more points to evaluate, but increases the explanation
        time. If there are fewer unique values for a feature within the
        `partial_range`, the number of unique values is selected. For two-feature
        PDP, the total number of evaluated samples is the multiplication
        of `num_samples`. Default value is 30.
    sampling : dict, optional
        If not None, the dataset will be clustered or sampled according to the
        provided technique. 'sampling' is a dictionary containing the technique
        to use and the corresponding parameters. Format is described below:

        - `technique`: Either "cluster" or "random".
        - `cluster` also requires:

            - `eps`: Maximum distance between two samples to be considered
              in the same cluster.
            - `min_samples`: Minimum number of samples to include in each
              cluster.

        - `random` also requires:

            - 'n_samples': Number of samples to return.

        Default value is `None` (no sampling).

    Returns
    -------
    :class:MLXPartialDependencies
        `MLXPartialDependencies` object, or `None` (after printing a message)
        when one of the requested features is not a column of `X_train`.

    Raises
    ------
    ValueError
        If more than two features match, or if the dataset is not tabular.
    """
    # Create the PDP explainer on first use.
    if self.pdp_explainer is None:
        self._init_partial_dependence()

    requested = features if isinstance(features, list) else [features]

    # Upper-case both sides so that the name matching is case-insensitive.
    wanted = [str(name).upper() for name in requested]
    known = np.char.upper(self.X_train.columns.tolist())

    # Unknown feature names are reported on stdout and the call returns None.
    if not all(np.isin(wanted, known)):
        print("One or more features (%s) does not exist in data." % str(wanted))
        print("Existing features: %s" % str(known))
        return None

    # Positions of the requested columns (in column order, not request order).
    feature_ids = np.where(np.isin(known, wanted))[0].tolist()

    if check_tabular_or_text(self.est, self.X_train) != "tabular":
        raise ValueError(
            "Partial Dependence Plot is not supported for text classification dataset."
        )
    if len(feature_ids) > 2:
        raise ValueError("Maximum number of partial dependency features is 2.")

    pdp = self.pdp_explainer.compute(
        data=self.X_train,
        partial_ids=feature_ids,
        partial_range=partial_range,
        num_samples=num_samples,
        sampling=sampling,
    )
    return MLXPartialDependencies(pdp=pdp, pdp_exp=self.pdp_explainer)
|
308
|
+
|
309
|
+
def compute_accumulated_local_effects(
    self,
    feature,
    partial_range=(0.00, 1.0),
    num_samples=30,
    sampling=None,
    corr_threshold=0.7,
):
    """
    Generates the accumulated local effects plots to help understand the relationship between feature
    values and the model target.

    Only supported for tabular datasets.

    Parameters
    ----------
    feature : str
        Feature name to explain. It is converted with `str()` and matched
        case-insensitively against the column names of `X_train`. A list is
        accepted as well, but at most one feature may match.
    partial_range : tuple, optional
        Min/max percentile values to consider for the ALE from the
        feature's train distribution. Must be between 0.0 and 1.0.
        By default `partial_range = (0.00, 1.0)`.
    num_samples : int, optional
        Maximum number of samples to generate for each feature within the
        `partial_range` of its value distribution. Increasing this value
        generates more points to evaluate, but increases the explanation
        time. If there are fewer unique values for a feature within the
        `partial_range`, the number of unique values is selected.
        Default value is 30.
    sampling : dict, optional
        If not `None`, the dataset is clustered or sampled according to the
        provided technique. `sampling` is a dictionary containing the technique
        to use and the corresponding parameters. The format is:

        - `technique`: Can be either "cluster" or "random".
        - `cluster` also requires:

            - `eps`: Maximum distance between two samples to be considered
              in the same cluster.
            - `min_samples`: Minimum number of samples to include in each
              cluster.

        - `random` also requires:

            - `n_samples`: Number of samples to return.

        Defaults to `None` (no sampling).

    corr_threshold : float, optional
        Value between 0.0 and 1.0 for which a feature is considered highly correlated with
        another feature (Default = 0.7).

    Returns
    -------
    :class:MLXAccumulatedLocalEffects
        `AccumulatedLocalEffects` explanation object, or `None` (after printing
        a message) when the requested feature is not a column of `X_train`.

    Raises
    ------
    ValueError
        If more than one feature matches, or if the dataset is not tabular.
    """
    # Create the ALE explainer on first use.
    if self.ale_explainer is None:
        self._init_accumulated_local_effects()

    # A list is used internally so that two-feature support can be added later.
    requested = feature if isinstance(feature, list) else [feature]

    # Upper-case both sides so that the name matching is case-insensitive.
    wanted = [str(name).upper() for name in requested]
    known = np.char.upper(self.X_train.columns.tolist())

    # Unknown feature names are reported on stdout and the call returns None.
    if not all(np.isin(wanted, known)):
        print("One or more features (%s) does not exist in data." % str(wanted))
        print("Existing features: %s" % str(known))
        return None

    # Positions of the requested columns within X_train.
    feature_ids = np.where(np.isin(known, wanted))[0].tolist()

    if check_tabular_or_text(self.est, self.X_train) != "tabular":
        raise ValueError(
            "Accumulated Local Effects Plot is not supported for text classification dataset."
        )
    if len(feature_ids) > 1:
        raise ValueError(
            "Maximum number of Accumulated Local Effects features is 1."
        )

    ale = self.ale_explainer.compute(
        data=self.X_train,
        partial_ids=feature_ids,
        partial_range=partial_range,
        num_samples=num_samples,
        sampling=sampling,
        corr_threshold=corr_threshold,
    )
    return MLXAccumulatedLocalEffects(ale=ale, ale_exp=self.ale_explainer)
|
407
|
+
|
408
|
+
@runtime_dependency(module="IPython", install_from=OptionalDependency.NOTEBOOK)
def show_in_notebook(self):  # pragma: no cover
    """
    Generates and visualizes the global feature importance explanation.

    Computes the feature importance for the currently selected features and
    displays its HTML rendering in the notebook. A partial dependence
    computation is also triggered, but its result is not displayed.
    """
    with utils.get_progress_bar(3, description="Model Explanation") as bar:
        bar.update("begin computing")
        bar.update("calculating feature importance")
        explainer_holder = self.compute_feature_importance(
            selected_features=self.selected_features
        )
        plot1 = explainer_holder.show_in_notebook()
        bar.update("calculating partial dependence plot")
        # NOTE(review): the whole explanation object is passed where a feature
        # name is expected; `compute_partial_dependence` stringifies its input
        # and matches it against column names, so this most likely never
        # matches. The intended value was probably the top-ranked feature:
        # explainer_holder.explanation.get_global_explanation().index[0]
        pdp_plot_feature_name = explainer_holder.explanation
        # The PDP result is not rendered yet; the call is kept so that existing
        # behaviour (including lazy explainer initialization) is unchanged.
        self.compute_partial_dependence([pdp_plot_feature_name])

        # `IPython.display` is the public location of `display` and `HTML`;
        # importing them from `IPython.core.display` is deprecated.
        from IPython.display import display, HTML

        display(HTML(plot1.data))
|
430
|
+
|
431
|
+
def configure_feature_importance(self, **kwargs):
    """
    Validates and initializes the feature importance explainer based on the provided
    configuration parameters in kwargs. Tabular datasets use the feature permutation
    importance explainer, text datasets use the aggregate local explainer.

    Supported configuration options:

    - For tabular datasets:

        - `client`: Currently only allowed to be None to disable parallelization.
        - `random_state`: None, int, or instance of Randomstate.
        - `selected_features`: None, or list of the selected features.

    - For text datasets:

        - `surrogate_model`: Surrogate model to use. Can be 'linear' or 'decision_tree'.
        - `num_samples`: Number of generated samples to fit the surrogate model. Int.
        - `exp_sorting`: Feature importance sorting. Can be 'absolute' or 'ordered'.
        - `scale_weight`: Normalizes the feature importance coefficients from LIME to sum to one.
        - `client`: Currently only allowed to be None to disable parallelization.
        - `batch_size`: Number of local explanations per Dask worker.
        - `random_state`: None, int, or instance of Randomstate.
        - `selected_features`: None, or list of the selected features.

    Parameters
    ----------
    kwargs : dict
        Keyword parameter dictionary.

    Returns
    -------
    MLXGlobalExplainer
        the modified instance (self)

    Raises
    ------
    ValueError
        If an unsupported option is passed, `client` is not None,
        `surrogate_model` is not 'linear', 'decision_tree' or None, or
        `selected_features` is neither None nor a list.
    """
    # Options shared by both dataset kinds.
    avail_args = ["client", "random_state"]
    if check_tabular_or_text(self.est, self.X_train) != "tabular":
        # Text datasets additionally accept the surrogate-model settings.
        avail_args += [
            "surrogate_model",
            "num_samples",
            "exp_sorting",
            "scale_weight",
            "batch_size",
        ]
    avail_args.append("selected_features")

    # Reject the first option (in call order) that is not supported.
    for option in kwargs:
        if option in avail_args:
            continue
        raise ValueError(
            "Unexpected argument for the feature importance explainer: {}".format(
                option
            )
        )

    client = kwargs.get("client", None)
    if client is not None:
        raise ValueError(
            "Invalid client provided. Currently only supports disabling parallelization "
            "by setting client=None"
        )

    surrogate_model = kwargs.get("surrogate_model", None)
    if surrogate_model not in ["linear", "decision_tree", None]:
        raise ValueError(
            "Invalid surrogate_model provided. Currently only supports linear or decision_tree"
        )

    selected_features = kwargs.get("selected_features")
    if not (selected_features is None or isinstance(selected_features, list)):
        raise ValueError(
            f"selected_features ({selected_features}) value must be a list of features, "
            f"but it is of type: {type(selected_features)}."
        )

    self._init_feature_importance(**kwargs)
    return self
|
508
|
+
|
509
|
+
def configure_partial_dependence(self, **kwargs):
    """
    Check the supplied options and (re)build the partial dependence explainer.

    According to the original documentation, only tabular datasets are supported.

    Supported configuration options:

        - `client`: Currently only allowed to be None to disable parallelization.

    Parameters
    ----------
    kwargs : dict
        Keyword parameter dictionary. It is forwarded unchanged to
        ``_init_partial_dependence`` once validation succeeds.

    Returns
    -------
    MLXGlobalExplainer
        the modified instance (self)

    Raises
    ------
    ValueError
        If an option other than ``client`` is supplied, or if ``client`` is
        not None.
    """
    # Reject the first option that is not recognised.
    for option in kwargs:
        if option != "client":
            raise ValueError(
                "Unexpected argument for the partial dependence explainer: {}".format(
                    option
                )
            )

    # Parallel execution is not available; the only accepted value is None.
    if kwargs.get("client") is not None:
        raise ValueError(
            "Invalid client provided. Currently only supports disabling parallelization "
            "by setting client=None"
        )

    self._init_partial_dependence(**kwargs)
    return self
|
542
|
+
|
543
|
+
def configure_accumulated_local_effects(self, **kwargs):
    """
    Check the supplied options and (re)build the accumulated local effects explainer.

    According to the original documentation, only tabular datasets are supported.

    Supported configuration options:

        - `client`: Currently only allowed to be None to disable parallelization.

    Parameters
    ----------
    kwargs : dict
        Keyword parameter dictionary. It is forwarded unchanged to
        ``_init_accumulated_local_effects`` once validation succeeds.

    Returns
    -------
    MLXGlobalExplainer
        the modified instance (self)

    Raises
    ------
    ValueError
        If an option other than ``client`` is supplied, or if ``client`` is
        not None.
    """
    # Collect unknown options in call order so the first one is reported.
    unknown_options = [name for name in kwargs if name != "client"]
    if unknown_options:
        raise ValueError(
            "Unexpected argument for the accumulated local effects explainer: {}".format(
                unknown_options[0]
            )
        )

    # Parallel execution is not available; the only accepted value is None.
    requested_client = kwargs.get("client")
    if requested_client is not None:
        raise ValueError(
            "Invalid client provided. Currently only supports disabling parallelization "
            "by setting client=None"
        )

    self._init_accumulated_local_effects(**kwargs)
    return self
|
577
|
+
|
578
|
+
def feature_importance_summary(self):
    """
    Show detailed information about the feature importance explainer.

    When no explainer exists yet, ``compute_feature_importance`` is called
    first with the instance's ``selected_features``.

    Returns
    -------
    str
        HTML object representing the explainer summary (the value returned
        by the explainer's ``show_in_notebook``).
    """
    explainer_missing = self.explainer is None
    if explainer_missing:
        self.compute_feature_importance(selected_features=self.selected_features)

    summary = self.explainer.show_in_notebook()
    return summary
|
592
|
+
|
593
|
+
def partial_dependence_summary(self):
    """
    Show detailed information about the partial dependence explainer.

    When no explainer exists yet, it is created through
    ``_init_partial_dependence`` with no extra options.

    Returns
    -------
    str
        HTML object representing the explainer summary (the value returned
        by the explainer's ``show_in_notebook``).
    """
    explainer_missing = self.pdp_explainer is None
    if explainer_missing:
        self._init_partial_dependence()

    summary = self.pdp_explainer.show_in_notebook()
    return summary
|
606
|
+
|
607
|
+
def accumulated_local_effects_summary(self):
    """
    Show detailed information about the accumulated local effects explainer.

    When no explainer exists yet, it is created through
    ``_init_accumulated_local_effects`` with no extra options.

    Returns
    -------
    str
        HTML object representing the explainer summary (the value returned
        by the explainer's ``show_in_notebook``).
    """
    explainer_missing = self.ale_explainer is None
    if explainer_missing:
        self._init_accumulated_local_effects()

    summary = self.ale_explainer.show_in_notebook()
    return summary
|
620
|
+
|
621
|
+
def _init_feature_importance(self, **kwargs):
    """
    Internal helper that builds the feature importance explainer.

    Tabular datasets use the feature permutation importance explainer; any
    other dataset kind (text) uses the aggregate local (LIME based) explainer.
    The result is stored on ``self.explainer``.

    Parameters
    ----------
    kwargs : dict
        Keyword parameter dictionary, forwarded to the selected initializer.
    """
    if self.mode == "regression":
        # NOTE(review): this sets ``class_names_`` (trailing underscore), while
        # the initializer below receives ``self.class_names`` -- confirm that
        # the two attributes are meant to differ.
        self.class_names_ = ["Target"]

    dataset_kind = check_tabular_or_text(self.est, self.X_train)
    build_explainer = (
        init_permutation_importance_explainer
        if dataset_kind == "tabular"
        else init_lime_explainer
    )

    # Both initializers share the same calling convention.
    self.explainer = build_explainer(
        self.explainer,
        self.est,
        self.X_train,
        self.y_train,
        self.mode,
        class_names=self.class_names,
        **kwargs,
    )
|
654
|
+
|
655
|
+
def _init_partial_dependence(self, **kwargs):
    """
    Internal helper that builds the partial dependence explainer and stores
    it on ``self.pdp_explainer``.

    Parameters
    ----------
    kwargs : dict
        Keyword parameter dictionary, forwarded to
        ``init_partial_dependence_explainer``.
    """
    if self.mode == "regression":
        # NOTE(review): this sets ``class_names_`` (trailing underscore), while
        # the initializer below receives ``self.class_names`` -- confirm that
        # the two attributes are meant to differ.
        self.class_names_ = ["Target"]

    refreshed_explainer = init_partial_dependence_explainer(
        self.pdp_explainer,
        self.est,
        self.X_train,
        self.y_train,
        self.mode,
        class_names=self.class_names,
        **kwargs,
    )
    self.pdp_explainer = refreshed_explainer
|
675
|
+
|
676
|
+
def _init_accumulated_local_effects(self, **kwargs):
    """
    Internal helper that builds the accumulated local effects explainer and
    stores it on ``self.ale_explainer``.

    Parameters
    ----------
    kwargs : dict
        Keyword parameter dictionary, forwarded to ``init_ale_explainer``.
    """
    if self.mode == "regression":
        # NOTE(review): this sets ``class_names_`` (trailing underscore), while
        # the initializer below receives ``self.class_names`` -- confirm that
        # the two attributes are meant to differ.
        self.class_names_ = ["Target"]

    refreshed_explainer = init_ale_explainer(
        self.ale_explainer,
        self.est,
        self.X_train,
        self.y_train,
        self.mode,
        class_names=self.class_names,
        **kwargs,
    )
    self.ale_explainer = refreshed_explainer
|
696
|
+
|
697
|
+
|
698
|
+
class MLXFeatureDependenceExplanation(ABC):
    """
    Abstract base for the feature dependence explanation wrappers.

    Keeps the wrapped explanation object (``fd``) together with ``fd_exp`` and
    forwards the raw-data accessors to ``fd``. Concrete subclasses provide
    ``show_in_notebook``.
    """

    __name__ = "MLXFeatureDependenceExplanation"

    def __init__(self, fd, fd_exp):
        self.fd, self.fd_exp = fd, fd_exp

    @abstractmethod
    def show_in_notebook(
        self,
        labels=None,
        cscale="YIGnBu",
        show_distribution=True,
        discrete_threshold=0.15,
        # TODO: re-add ``line_gap=0`` (and document it) once ALE handles two features.
        show_correlation_warning=True,
        centered=False,
        show_median=True,
    ):  # pragma: no cover
        """
        Visualize PDP/ICE plots in the Notebook.

        Parameters
        ----------
        labels : tuple, list, int, bool, str, optional
            labels to visualize.
        cscale : str, optional
            Plotly color scale to use for the heatmap. See the standard Plotly color
            scales for available options. Default value is "YIGnBu".
            NOTE(review): Plotly spells this scale "YlGnBu" (lowercase L); confirm
            that the plotting backend accepts "YIGnBu".
        show_distribution : bool, optional
            If `True`, the feature's value distribution (from the train set) is shown
            along the corresponding axis in the 1-feature or 2-feature plot.
            Default is `True`.
        discrete_threshold : float, optional
            Value between 0.0 and 1.0 indicating the fraction of unique values required
            for a numerical feature to be considered discrete or continuous.
            Default is 0.15.
        show_correlation_warning : bool, optional
            If `True`, the correlated feature warning is shown. Default is `True`.
        centered : bool, optional
            If `True`, ICE plots are centered based on the first value of each sample
            (i.e., all values are subtracted from the first value). Default is `False`.
        show_median : bool, optional
            If `True`, a median line is included in the ICE explanation plot.
            Default is `True`.

        Returns
        -------
        str
            Plotly HTML object containing a line chart, heat map, or violin plot for
            this feature dependence explanation
        """

    def as_dataframe(self):
        """
        Return the raw explanation data as a pandas.DataFrame.

        Returns
        -------
        pandas.DataFrame
            DataFrame containing the raw explanation data, as produced by the
            wrapped object's ``as_dataframe``.
        """
        raw_frame = self.fd.as_dataframe()
        return raw_frame

    def get_diagnostics(self):
        """
        Extract the raw explanation and evaluation data from the explanation
        object (the data used to generate the visualizations).

        Returns
        -------
        dict
            Dictionary containing the raw explanation/evaluation data, as
            produced by the wrapped object's ``get_diagnostic``.
        """
        diagnostics = self.fd.get_diagnostic()
        return diagnostics
|
771
|
+
|
772
|
+
|
773
|
+
class MLXPartialDependencies(MLXFeatureDependenceExplanation):
    """
    Partial dependence explanation object constructed by the
    :class:`MLXGlobalExplainer`.

    Contains functions to visualize the explanation and extract raw explanation data.
    """

    __name__ = "MLXPartialDependencies"

    def __init__(self, pdp, pdp_exp):
        super().__init__(pdp, pdp_exp)

    def show_in_notebook(
        self,
        mode="pdp",
        labels=None,
        cscale="YIGnBu",
        show_distribution=True,
        discrete_threshold=0.15,
        line_gap=0,
        show_correlation_warning=True,
        centered=False,
        show_median=True,
    ):
        """
        Visualize PDP/ICE plots in the Notebook.

        Every argument is forwarded by keyword to the wrapped explanation's
        ``show_in_notebook``.

        Parameters
        ----------
        mode : str, optional
            Type to visualize. Either "pdp" or "ice". Default is "pdp".
        labels : tuple, list, int, bool, str, optional
            labels to visualize.
        cscale : str, optional
            Plotly color scale to use for the heatmap. See the standard Plotly color
            scales for available options. Default value is "YIGnBu".
            NOTE(review): Plotly spells this scale "YlGnBu" (lowercase L); confirm
            that the plotting backend accepts "YIGnBu".
        show_distribution : bool, optional
            If `True`, the feature's value distribution (from the train set) is shown
            along the corresponding axis in the 1-feature or 2-feature plot.
            Default is `True`.
        discrete_threshold : float, optional
            Value between 0.0 and 1.0 indicating the fraction of unique values required
            for a numerical feature to be considered discrete or continuous.
            Default is 0.15.
        line_gap : int, optional
            Width of the gap between values in the two-feature PDP heat map.
            Default is 0.
        show_correlation_warning : bool, optional
            If `True`, the correlated feature warning is shown. Default is `True`.
        centered : bool, optional
            If `True`, ICE plots are centered based on the first value of each sample
            (i.e., all values are subtracted from the first value). Default is `False`.
        show_median : bool, optional
            If `True`, a median line is included in the ICE explanation plot.
            Default is `True`.

        Returns
        -------
        str
            Plotly HTML object containing a line chart, heat map, or violin plot for
            this feature dependence explanation
        """
        plot_options = {
            "mode": mode,
            "labels": labels,
            "cscale": cscale,
            "show_distribution": show_distribution,
            "discrete_threshold": discrete_threshold,
            "line_gap": line_gap,
            "show_correlation_warning": show_correlation_warning,
            "centered": centered,
            "show_median": show_median,
        }
        return self.fd.show_in_notebook(**plot_options)
|
841
|
+
|
842
|
+
|
843
|
+
class MLXAccumulatedLocalEffects(MLXFeatureDependenceExplanation):
    """
    Accumulated Local Effects explanation object constructed by the
    :class:`MLXGlobalExplainer`.

    Contains functions to visualize the explanation in a Notebook and extract the
    raw explanation data.
    """

    __name__ = "MLXAccumulatedLocalEffects"

    def __init__(self, ale, ale_exp):
        super().__init__(ale, ale_exp)

    def show_in_notebook(
        self,
        labels=None,
        cscale="YIGnBu",
        show_distribution=True,
        discrete_threshold=0.15,
        show_correlation_warning=True,
        centered=False,
        show_median=True,
    ):
        """
        Visualize ALE plots in the Notebook.

        The arguments are forwarded by keyword to the wrapped explanation's
        ``show_in_notebook``; ``mode`` is always "pdp" and ``line_gap`` is always 0.

        Parameters
        ----------
        labels : tuple, list, int, bool, str, optional
            labels to visualize.
        cscale : str, optional
            Plotly color scale to use for the heatmap. See the standard Plotly color
            scales for available options. Default value is "YIGnBu".
            NOTE(review): Plotly spells this scale "YlGnBu" (lowercase L); confirm
            that the plotting backend accepts "YIGnBu".
        show_distribution : bool, optional
            If `True`, the feature's value distribution (from the train set) is shown
            along the corresponding axis in the 1-feature plot. Default is `True`.
        discrete_threshold : float, optional
            Value between 0.0 and 1.0 indicating the fraction of unique values required
            for a numerical feature to be considered discrete or continuous.
            Default is 0.15.
        show_correlation_warning : bool, optional
            If `True`, the correlated feature warning is shown. Default is `True`.
        centered : bool, optional
            If `True`, ALE plots are centered based on the first value of each sample
            (i.e., all values are subtracted from the first value). Default is `False`.
        show_median : bool, optional
            If `True`, a median line is included in the ALE explanation plot.
            Default is `True`.

        Returns
        -------
        str
            Plotly HTML object containing a line chart, heat map, or violin plot for
            this feature dependence explanation
        """
        # Fixed settings: ALE only exposes the single "pdp"-style rendering.
        plot_options = {
            "mode": "pdp",
            "labels": labels,
            "cscale": cscale,
            "show_distribution": show_distribution,
            "discrete_threshold": discrete_threshold,
            "line_gap": 0,
            "show_correlation_warning": show_correlation_warning,
            "centered": centered,
            "show_median": show_median,
        }
        return self.fd.show_in_notebook(**plot_options)
|
906
|
+
|
907
|
+
|
908
|
+
class FeatureImportance:
    """
    Feature Permutation Importance explanation object constructed by the
    :class:`MLXGlobalExplainer` class.

    Contains functions to visualize the explanation in a Notebook and extract the
    raw explanation data.
    """

    def __init__(self, explanation, class_names, type):
        self.explanation = explanation
        self.class_names = class_names
        self.type = type
        # Store class names as a plain list so that ``.index`` lookups work.
        if isinstance(self.class_names, np.ndarray):
            self.class_names = self.class_names.tolist()

    def _resolve_label_indices(self, labels):
        """Translate the user supplied ``labels`` into a list of label indices."""
        # The public API documents scalar labels (int, bool, str). Wrap them so
        # that a class name is not iterated character by character and an index
        # of 0 is not mistaken for "no labels given". ``bool`` is covered by
        # ``int``; an empty string keeps meaning "no labels given".
        if isinstance(labels, int) or (isinstance(labels, str) and labels):
            labels = [labels]
        # None or an empty collection selects every label.
        if not labels:
            return list(range(len(self.class_names)))
        # Class names are mapped to their index; other values pass through.
        return [
            self.class_names.index(label) if isinstance(label, str) else label
            for label in labels
        ]

    def show_in_notebook(
        self,
        mode=None,
        show=None,
        labels=None,
        cscale="YIGnBu",
        colormap=None,
        return_wordcloud=False,
        n_features=None,
        **kwargs,
    ):
        """
        Generates a visualization for the global feature importance explanation.
        Depending on the type of explanation, different visualizations are supported.
        See the "mode" and "show" parameters below.

        For tabular explanations only ``n_features`` and ``mode`` are forwarded to
        the underlying explanation. For text explanations with exactly two class
        names the underlying explanation is rendered with its own defaults and the
        remaining options are not forwarded.

        Parameters
        ----------
        mode : str
            Type of visualization to generate. Certain visualization modes are only
            supported for either text or tabular datasets. Supported options:

                - `bar`: Generates a horizontal bar chart for the most important
                  features (text and tabular).
                - `stacked`: Generates a stacked horizontal bar chart for the most
                  important features (text and tabular).
                - `box_plot`: Generates a box plot for the most important features,
                  which provides more information about the explanation over the
                  different iterations of the feature permutation importance
                  algorithm (tabular only).
                - `detailed`: Generates a scatter plot for the most important
                  features, providing even more information about the explanation
                  over the different iterations of the Feature Permutation
                  Importance algorithm (tabular only).
                - `heatmap`: Generates a heatmap representing the average
                  feature/word importance over multiple local explanations
                  (aggregates local explanations). Average feature importance is
                  measured by the fraction of local explanations where a given
                  feature was assigned a given importance (text only).
                - `wordcloud`: Generates a wordcloud from the average feature
                  importance. Features/words with higher importance are larger than
                  features/words with lower importance (text only).

            Default value is "bar" for tabular and "wordcloud" for text.
        show : str
            (text only) Secondary visualization mode for configuring the visualization.
            Can be one of:

                - `absolute`: The absolute value of feature importances are shown
                  (i.e., a feature that is highly important towards or against the
                  target label is considered important).
                - `posneg`: Shows both the positive and negative global feature
                  attributions. For bar, the features are ordered based on their
                  absolute feature importance (sum of pos/neg) and a dual bar chart
                  shows the fraction of local explanations where the feature
                  contributed both towards and against the corresponding label. For
                  wordcloud, two wordclouds are generated for the positive and
                  negative feature importances. Only valid for mode=bar and
                  mode=wordcloud. `mode=heatmap` defaults to `show=absolute`.

        labels : tuple, list, int, bool, str
            (text only) Labels to visualize, given as indices or class names. A
            single label may be passed directly instead of a one-element list. If
            `None` (or empty), all of the labels that the explanation was generated
            for will be visualized. By default None.
        cscale : str, optional
            Plotly color scale to use for the heatmap. See the standard Plotly color
            scales for available options. Default value is "YIGnBu".
            NOTE(review): Plotly spells this scale "YlGnBu" (lowercase L); confirm
            that the plotting backend accepts "YIGnBu".
        colormap : list of str, optional
            List of colormaps to use for the wordclouds. One per label. Defaults to `None`.
        return_wordcloud : bool, optional
            If `True`, the generated wordcloud objects are returned instead of
            visualized. Defaults to `False`.
        n_features : int, optional
            (tabular only). Allows the user to visualize a subset of the top-N most
            important features from the explainer. If `n_features` is `None` or
            greater than the total number of features, all features are shown. If
            `n_features` is not an `int` or <= 0, an exception is thrown.
        kwargs : dict
            Keyword arguments for configuring the wordclouds.

        Returns
        -------
        str, list of wordcloud
            HTML string for the visualization or list of generated wordcloud objects
            if `return_wordcloud=True`, two per label (+/-).

        Raises
        ------
        ValueError
            If ``labels`` is given for a non-text explanation, or if a class name in
            ``labels`` is not one of ``class_names``.
        """
        if self.type.type != "text":
            if labels:
                raise ValueError("label is supported only for text explanation.")
            return self.explanation.show_in_notebook(
                n_features=n_features, mode=mode if mode else "bar"
            )

        # Resolve labels first so unknown class names are reported in every case.
        label_indices = self._resolve_label_indices(labels)

        if len(self.class_names) == 2:
            # Binary text explanations use the backend's default rendering.
            return self.explanation.show_in_notebook()

        return self.explanation.show_in_notebook(
            mode=mode if mode else "wordcloud",
            show=show,
            labels=label_indices,
            cscale=cscale,
            colormap=colormap,
            return_wordcloud=return_wordcloud,
            **kwargs,
        )

    def get_global_explanation(self):
        """
        Returns the raw global explanation data only.

        Returns
        -------
        dict
            Dictionary containing raw explanation data.
        """
        return self.explanation.get_global_explanation()

    def get_diagnostics(self):
        """
        Extracts the raw explanation and evaluation data from the explanation object
        (Used to generate the visualizations).

        Returns
        -------
        dict
            Dictionary containing the raw explanation/evaluation data.
        """
        return self.explanation.get_diagnostic()
|
1046
|
+
|
1047
|
+
|
1048
|
+
class GlobalExplanationsException(TypeError):
    """Error type for global explanations; a ``TypeError`` carrying ``msg``."""

    def __init__(self, msg):
        super().__init__(msg)
|