oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.9rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +40 -0
- ads/aqua/app.py +506 -0
- ads/aqua/cli.py +96 -0
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +836 -0
- ads/aqua/client/openai_client.py +305 -0
- ads/aqua/common/__init__.py +5 -0
- ads/aqua/common/decorator.py +125 -0
- ads/aqua/common/entities.py +269 -0
- ads/aqua/common/enums.py +122 -0
- ads/aqua/common/errors.py +109 -0
- ads/aqua/common/utils.py +1285 -0
- ads/aqua/config/__init__.py +4 -0
- ads/aqua/config/container_config.py +248 -0
- ads/aqua/config/evaluation/__init__.py +4 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
- ads/aqua/config/utils/__init__.py +4 -0
- ads/aqua/config/utils/serializer.py +339 -0
- ads/aqua/constants.py +116 -0
- ads/aqua/data.py +14 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation/__init__.py +8 -0
- ads/aqua/evaluation/constants.py +53 -0
- ads/aqua/evaluation/entities.py +186 -0
- ads/aqua/evaluation/errors.py +70 -0
- ads/aqua/evaluation/evaluation.py +1814 -0
- ads/aqua/extension/__init__.py +42 -0
- ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
- ads/aqua/extension/base_handler.py +90 -0
- ads/aqua/extension/common_handler.py +121 -0
- ads/aqua/extension/common_ws_msg_handler.py +36 -0
- ads/aqua/extension/deployment_handler.py +298 -0
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +30 -0
- ads/aqua/extension/evaluation_handler.py +129 -0
- ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
- ads/aqua/extension/finetune_handler.py +96 -0
- ads/aqua/extension/model_handler.py +390 -0
- ads/aqua/extension/models/__init__.py +0 -0
- ads/aqua/extension/models/ws_models.py +145 -0
- ads/aqua/extension/models_ws_msg_handler.py +50 -0
- ads/aqua/extension/ui_handler.py +282 -0
- ads/aqua/extension/ui_websocket_handler.py +130 -0
- ads/aqua/extension/utils.py +133 -0
- ads/aqua/finetuning/__init__.py +7 -0
- ads/aqua/finetuning/constants.py +23 -0
- ads/aqua/finetuning/entities.py +181 -0
- ads/aqua/finetuning/finetuning.py +749 -0
- ads/aqua/model/__init__.py +8 -0
- ads/aqua/model/constants.py +60 -0
- ads/aqua/model/entities.py +385 -0
- ads/aqua/model/enums.py +32 -0
- ads/aqua/model/model.py +2114 -0
- ads/aqua/modeldeployment/__init__.py +8 -0
- ads/aqua/modeldeployment/constants.py +10 -0
- ads/aqua/modeldeployment/deployment.py +1326 -0
- ads/aqua/modeldeployment/entities.py +653 -0
- ads/aqua/modeldeployment/inference.py +74 -0
- ads/aqua/modeldeployment/utils.py +543 -0
- ads/aqua/resources/gpu_shapes_index.json +94 -0
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +476 -0
- ads/aqua/ui.py +499 -0
- ads/automl/__init__.py +9 -0
- ads/automl/driver.py +330 -0
- ads/automl/provider.py +975 -0
- ads/bds/__init__.py +5 -0
- ads/bds/auth.py +127 -0
- ads/bds/big_data_service.py +255 -0
- ads/catalog/__init__.py +19 -0
- ads/catalog/model.py +1576 -0
- ads/catalog/notebook.py +461 -0
- ads/catalog/project.py +468 -0
- ads/catalog/summary.py +178 -0
- ads/common/__init__.py +11 -0
- ads/common/analyzer.py +65 -0
- ads/common/artifact/.model-ignore +63 -0
- ads/common/artifact/__init__.py +10 -0
- ads/common/auth.py +1122 -0
- ads/common/card_identifier.py +83 -0
- ads/common/config.py +647 -0
- ads/common/data.py +165 -0
- ads/common/decorator/__init__.py +9 -0
- ads/common/decorator/argument_to_case.py +88 -0
- ads/common/decorator/deprecate.py +69 -0
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/decorator/runtime_dependency.py +178 -0
- ads/common/decorator/threaded.py +97 -0
- ads/common/decorator/utils.py +35 -0
- ads/common/dsc_file_system.py +303 -0
- ads/common/error.py +14 -0
- ads/common/extended_enum.py +81 -0
- ads/common/function/__init__.py +5 -0
- ads/common/function/fn_util.py +142 -0
- ads/common/function/func_conf.yaml +25 -0
- ads/common/ipython.py +76 -0
- ads/common/model.py +679 -0
- ads/common/model_artifact.py +1759 -0
- ads/common/model_artifact_schema.json +107 -0
- ads/common/model_export_util.py +664 -0
- ads/common/model_metadata.py +24 -0
- ads/common/object_storage_details.py +296 -0
- ads/common/oci_client.py +175 -0
- ads/common/oci_datascience.py +46 -0
- ads/common/oci_logging.py +1144 -0
- ads/common/oci_mixin.py +957 -0
- ads/common/oci_resource.py +136 -0
- ads/common/serializer.py +559 -0
- ads/common/utils.py +1852 -0
- ads/common/word_lists.py +1491 -0
- ads/common/work_request.py +189 -0
- ads/data_labeling/__init__.py +13 -0
- ads/data_labeling/boundingbox.py +253 -0
- ads/data_labeling/constants.py +47 -0
- ads/data_labeling/data_labeling_service.py +244 -0
- ads/data_labeling/interface/__init__.py +5 -0
- ads/data_labeling/interface/loader.py +16 -0
- ads/data_labeling/interface/parser.py +16 -0
- ads/data_labeling/interface/reader.py +23 -0
- ads/data_labeling/loader/__init__.py +5 -0
- ads/data_labeling/loader/file_loader.py +241 -0
- ads/data_labeling/metadata.py +110 -0
- ads/data_labeling/mixin/__init__.py +5 -0
- ads/data_labeling/mixin/data_labeling.py +232 -0
- ads/data_labeling/ner.py +129 -0
- ads/data_labeling/parser/__init__.py +5 -0
- ads/data_labeling/parser/dls_record_parser.py +388 -0
- ads/data_labeling/parser/export_metadata_parser.py +94 -0
- ads/data_labeling/parser/export_record_parser.py +473 -0
- ads/data_labeling/reader/__init__.py +5 -0
- ads/data_labeling/reader/dataset_reader.py +574 -0
- ads/data_labeling/reader/dls_record_reader.py +121 -0
- ads/data_labeling/reader/export_record_reader.py +62 -0
- ads/data_labeling/reader/jsonl_reader.py +75 -0
- ads/data_labeling/reader/metadata_reader.py +203 -0
- ads/data_labeling/reader/record_reader.py +263 -0
- ads/data_labeling/record.py +52 -0
- ads/data_labeling/visualizer/__init__.py +5 -0
- ads/data_labeling/visualizer/image_visualizer.py +525 -0
- ads/data_labeling/visualizer/text_visualizer.py +357 -0
- ads/database/__init__.py +5 -0
- ads/database/connection.py +338 -0
- ads/dataset/__init__.py +10 -0
- ads/dataset/capabilities.md +51 -0
- ads/dataset/classification_dataset.py +339 -0
- ads/dataset/correlation.py +226 -0
- ads/dataset/correlation_plot.py +563 -0
- ads/dataset/dask_series.py +173 -0
- ads/dataset/dataframe_transformer.py +110 -0
- ads/dataset/dataset.py +1979 -0
- ads/dataset/dataset_browser.py +360 -0
- ads/dataset/dataset_with_target.py +995 -0
- ads/dataset/exception.py +25 -0
- ads/dataset/factory.py +987 -0
- ads/dataset/feature_engineering_transformer.py +35 -0
- ads/dataset/feature_selection.py +107 -0
- ads/dataset/forecasting_dataset.py +26 -0
- ads/dataset/helper.py +1450 -0
- ads/dataset/label_encoder.py +99 -0
- ads/dataset/mixin/__init__.py +5 -0
- ads/dataset/mixin/dataset_accessor.py +134 -0
- ads/dataset/pipeline.py +58 -0
- ads/dataset/plot.py +710 -0
- ads/dataset/progress.py +86 -0
- ads/dataset/recommendation.py +297 -0
- ads/dataset/recommendation_transformer.py +502 -0
- ads/dataset/regression_dataset.py +14 -0
- ads/dataset/sampled_dataset.py +1050 -0
- ads/dataset/target.py +98 -0
- ads/dataset/timeseries.py +18 -0
- ads/dbmixin/__init__.py +5 -0
- ads/dbmixin/db_pandas_accessor.py +153 -0
- ads/environment/__init__.py +9 -0
- ads/environment/ml_runtime.py +66 -0
- ads/evaluations/README.md +14 -0
- ads/evaluations/__init__.py +109 -0
- ads/evaluations/evaluation_plot.py +983 -0
- ads/evaluations/evaluator.py +1334 -0
- ads/evaluations/statistical_metrics.py +543 -0
- ads/experiments/__init__.py +9 -0
- ads/experiments/capabilities.md +0 -0
- ads/explanations/__init__.py +21 -0
- ads/explanations/base_explainer.py +142 -0
- ads/explanations/capabilities.md +83 -0
- ads/explanations/explainer.py +190 -0
- ads/explanations/mlx_global_explainer.py +1050 -0
- ads/explanations/mlx_interface.py +386 -0
- ads/explanations/mlx_local_explainer.py +287 -0
- ads/explanations/mlx_whatif_explainer.py +201 -0
- ads/feature_engineering/__init__.py +20 -0
- ads/feature_engineering/accessor/__init__.py +5 -0
- ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
- ads/feature_engineering/accessor/mixin/__init__.py +5 -0
- ads/feature_engineering/accessor/mixin/correlation.py +166 -0
- ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
- ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
- ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
- ads/feature_engineering/accessor/mixin/utils.py +65 -0
- ads/feature_engineering/accessor/series_accessor.py +431 -0
- ads/feature_engineering/adsimage/__init__.py +5 -0
- ads/feature_engineering/adsimage/image.py +192 -0
- ads/feature_engineering/adsimage/image_reader.py +170 -0
- ads/feature_engineering/adsimage/interface/__init__.py +5 -0
- ads/feature_engineering/adsimage/interface/reader.py +19 -0
- ads/feature_engineering/adsstring/__init__.py +7 -0
- ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
- ads/feature_engineering/adsstring/string/__init__.py +8 -0
- ads/feature_engineering/data_schema.json +57 -0
- ads/feature_engineering/dataset/__init__.py +5 -0
- ads/feature_engineering/dataset/zip_code_data.py +42062 -0
- ads/feature_engineering/exceptions.py +40 -0
- ads/feature_engineering/feature_type/__init__.py +133 -0
- ads/feature_engineering/feature_type/address.py +184 -0
- ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
- ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
- ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
- ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
- ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
- ads/feature_engineering/feature_type/adsstring/string.py +258 -0
- ads/feature_engineering/feature_type/base.py +58 -0
- ads/feature_engineering/feature_type/boolean.py +183 -0
- ads/feature_engineering/feature_type/category.py +146 -0
- ads/feature_engineering/feature_type/constant.py +137 -0
- ads/feature_engineering/feature_type/continuous.py +151 -0
- ads/feature_engineering/feature_type/creditcard.py +314 -0
- ads/feature_engineering/feature_type/datetime.py +190 -0
- ads/feature_engineering/feature_type/discrete.py +134 -0
- ads/feature_engineering/feature_type/document.py +43 -0
- ads/feature_engineering/feature_type/gis.py +251 -0
- ads/feature_engineering/feature_type/handler/__init__.py +5 -0
- ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
- ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
- ads/feature_engineering/feature_type/handler/warnings.py +128 -0
- ads/feature_engineering/feature_type/integer.py +142 -0
- ads/feature_engineering/feature_type/ip_address.py +144 -0
- ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
- ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
- ads/feature_engineering/feature_type/lat_long.py +256 -0
- ads/feature_engineering/feature_type/object.py +43 -0
- ads/feature_engineering/feature_type/ordinal.py +132 -0
- ads/feature_engineering/feature_type/phone_number.py +135 -0
- ads/feature_engineering/feature_type/string.py +171 -0
- ads/feature_engineering/feature_type/text.py +93 -0
- ads/feature_engineering/feature_type/unknown.py +43 -0
- ads/feature_engineering/feature_type/zip_code.py +164 -0
- ads/feature_engineering/feature_type_manager.py +406 -0
- ads/feature_engineering/schema.py +795 -0
- ads/feature_engineering/utils.py +245 -0
- ads/feature_store/.readthedocs.yaml +19 -0
- ads/feature_store/README.md +65 -0
- ads/feature_store/__init__.py +9 -0
- ads/feature_store/common/__init__.py +0 -0
- ads/feature_store/common/enums.py +339 -0
- ads/feature_store/common/exceptions.py +18 -0
- ads/feature_store/common/spark_session_singleton.py +125 -0
- ads/feature_store/common/utils/__init__.py +0 -0
- ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
- ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
- ads/feature_store/common/utils/transformation_utils.py +82 -0
- ads/feature_store/common/utils/utility.py +403 -0
- ads/feature_store/data_validation/__init__.py +0 -0
- ads/feature_store/data_validation/great_expectation.py +129 -0
- ads/feature_store/dataset.py +1230 -0
- ads/feature_store/dataset_job.py +530 -0
- ads/feature_store/docs/Dockerfile +7 -0
- ads/feature_store/docs/Makefile +44 -0
- ads/feature_store/docs/conf.py +28 -0
- ads/feature_store/docs/requirements.txt +14 -0
- ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
- ads/feature_store/docs/source/cicd.rst +137 -0
- ads/feature_store/docs/source/conf.py +86 -0
- ads/feature_store/docs/source/data_versioning.rst +33 -0
- ads/feature_store/docs/source/dataset.rst +388 -0
- ads/feature_store/docs/source/dataset_job.rst +27 -0
- ads/feature_store/docs/source/demo.rst +70 -0
- ads/feature_store/docs/source/entity.rst +78 -0
- ads/feature_store/docs/source/feature_group.rst +624 -0
- ads/feature_store/docs/source/feature_group_job.rst +29 -0
- ads/feature_store/docs/source/feature_store.rst +122 -0
- ads/feature_store/docs/source/feature_store_class.rst +123 -0
- ads/feature_store/docs/source/feature_validation.rst +66 -0
- ads/feature_store/docs/source/figures/cicd.png +0 -0
- ads/feature_store/docs/source/figures/data_validation.png +0 -0
- ads/feature_store/docs/source/figures/data_versioning.png +0 -0
- ads/feature_store/docs/source/figures/dataset.gif +0 -0
- ads/feature_store/docs/source/figures/dataset.png +0 -0
- ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
- ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
- ads/feature_store/docs/source/figures/entity.png +0 -0
- ads/feature_store/docs/source/figures/feature_group.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
- ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
- ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
- ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
- ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
- ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
- ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
- ads/feature_store/docs/source/figures/overview.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
- ads/feature_store/docs/source/figures/stats_1.png +0 -0
- ads/feature_store/docs/source/figures/stats_2.png +0 -0
- ads/feature_store/docs/source/figures/stats_d.png +0 -0
- ads/feature_store/docs/source/figures/stats_fg.png +0 -0
- ads/feature_store/docs/source/figures/transformation.png +0 -0
- ads/feature_store/docs/source/figures/transformations.gif +0 -0
- ads/feature_store/docs/source/figures/validation.png +0 -0
- ads/feature_store/docs/source/figures/validation_fg.png +0 -0
- ads/feature_store/docs/source/figures/validation_results.png +0 -0
- ads/feature_store/docs/source/figures/validation_summary.png +0 -0
- ads/feature_store/docs/source/index.rst +81 -0
- ads/feature_store/docs/source/module.rst +8 -0
- ads/feature_store/docs/source/notebook.rst +94 -0
- ads/feature_store/docs/source/overview.rst +47 -0
- ads/feature_store/docs/source/quickstart.rst +176 -0
- ads/feature_store/docs/source/release_notes.rst +194 -0
- ads/feature_store/docs/source/setup_feature_store.rst +81 -0
- ads/feature_store/docs/source/statistics.rst +58 -0
- ads/feature_store/docs/source/transformation.rst +199 -0
- ads/feature_store/docs/source/ui.rst +65 -0
- ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
- ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
- ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
- ads/feature_store/entity.py +718 -0
- ads/feature_store/execution_strategy/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
- ads/feature_store/execution_strategy/engine/__init__.py +0 -0
- ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
- ads/feature_store/execution_strategy/execution_strategy.py +113 -0
- ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
- ads/feature_store/execution_strategy/spark/__init__.py +0 -0
- ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
- ads/feature_store/feature.py +192 -0
- ads/feature_store/feature_group.py +1494 -0
- ads/feature_store/feature_group_expectation.py +346 -0
- ads/feature_store/feature_group_job.py +602 -0
- ads/feature_store/feature_lineage/__init__.py +0 -0
- ads/feature_store/feature_lineage/graphviz_service.py +180 -0
- ads/feature_store/feature_option_details.py +50 -0
- ads/feature_store/feature_statistics/__init__.py +0 -0
- ads/feature_store/feature_statistics/statistics_service.py +99 -0
- ads/feature_store/feature_store.py +699 -0
- ads/feature_store/feature_store_registrar.py +518 -0
- ads/feature_store/input_feature_detail.py +149 -0
- ads/feature_store/mixin/__init__.py +4 -0
- ads/feature_store/mixin/oci_feature_store.py +145 -0
- ads/feature_store/model_details.py +73 -0
- ads/feature_store/query/__init__.py +0 -0
- ads/feature_store/query/filter.py +266 -0
- ads/feature_store/query/generator/__init__.py +0 -0
- ads/feature_store/query/generator/query_generator.py +298 -0
- ads/feature_store/query/join.py +161 -0
- ads/feature_store/query/query.py +403 -0
- ads/feature_store/query/validator/__init__.py +0 -0
- ads/feature_store/query/validator/query_validator.py +57 -0
- ads/feature_store/response/__init__.py +0 -0
- ads/feature_store/response/response_builder.py +68 -0
- ads/feature_store/service/__init__.py +0 -0
- ads/feature_store/service/oci_dataset.py +139 -0
- ads/feature_store/service/oci_dataset_job.py +199 -0
- ads/feature_store/service/oci_entity.py +125 -0
- ads/feature_store/service/oci_feature_group.py +164 -0
- ads/feature_store/service/oci_feature_group_job.py +214 -0
- ads/feature_store/service/oci_feature_store.py +182 -0
- ads/feature_store/service/oci_lineage.py +87 -0
- ads/feature_store/service/oci_transformation.py +104 -0
- ads/feature_store/statistics/__init__.py +0 -0
- ads/feature_store/statistics/abs_feature_value.py +49 -0
- ads/feature_store/statistics/charts/__init__.py +0 -0
- ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
- ads/feature_store/statistics/charts/box_plot.py +148 -0
- ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
- ads/feature_store/statistics/charts/probability_distribution.py +68 -0
- ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
- ads/feature_store/statistics/feature_stat.py +126 -0
- ads/feature_store/statistics/generic_feature_value.py +33 -0
- ads/feature_store/statistics/statistics.py +41 -0
- ads/feature_store/statistics_config.py +101 -0
- ads/feature_store/templates/feature_store_template.yaml +45 -0
- ads/feature_store/transformation.py +499 -0
- ads/feature_store/validation_output.py +57 -0
- ads/hpo/__init__.py +9 -0
- ads/hpo/_imports.py +91 -0
- ads/hpo/ads_search_space.py +439 -0
- ads/hpo/distributions.py +325 -0
- ads/hpo/objective.py +280 -0
- ads/hpo/search_cv.py +1657 -0
- ads/hpo/stopping_criterion.py +75 -0
- ads/hpo/tuner_artifact.py +413 -0
- ads/hpo/utils.py +91 -0
- ads/hpo/validation.py +140 -0
- ads/hpo/visualization/__init__.py +5 -0
- ads/hpo/visualization/_contour.py +23 -0
- ads/hpo/visualization/_edf.py +20 -0
- ads/hpo/visualization/_intermediate_values.py +21 -0
- ads/hpo/visualization/_optimization_history.py +25 -0
- ads/hpo/visualization/_parallel_coordinate.py +169 -0
- ads/hpo/visualization/_param_importances.py +26 -0
- ads/jobs/__init__.py +53 -0
- ads/jobs/ads_job.py +663 -0
- ads/jobs/builders/__init__.py +5 -0
- ads/jobs/builders/base.py +156 -0
- ads/jobs/builders/infrastructure/__init__.py +6 -0
- ads/jobs/builders/infrastructure/base.py +165 -0
- ads/jobs/builders/infrastructure/dataflow.py +1252 -0
- ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
- ads/jobs/builders/infrastructure/utils.py +65 -0
- ads/jobs/builders/runtimes/__init__.py +5 -0
- ads/jobs/builders/runtimes/artifact.py +338 -0
- ads/jobs/builders/runtimes/base.py +325 -0
- ads/jobs/builders/runtimes/container_runtime.py +242 -0
- ads/jobs/builders/runtimes/python_runtime.py +1016 -0
- ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
- ads/jobs/cli.py +104 -0
- ads/jobs/env_var_parser.py +131 -0
- ads/jobs/extension.py +160 -0
- ads/jobs/schema/__init__.py +5 -0
- ads/jobs/schema/infrastructure_schema.json +116 -0
- ads/jobs/schema/job_schema.json +42 -0
- ads/jobs/schema/runtime_schema.json +183 -0
- ads/jobs/schema/validator.py +141 -0
- ads/jobs/serializer.py +296 -0
- ads/jobs/templates/__init__.py +5 -0
- ads/jobs/templates/container.py +6 -0
- ads/jobs/templates/driver_notebook.py +177 -0
- ads/jobs/templates/driver_oci.py +500 -0
- ads/jobs/templates/driver_python.py +48 -0
- ads/jobs/templates/driver_pytorch.py +852 -0
- ads/jobs/templates/driver_utils.py +615 -0
- ads/jobs/templates/hostname_from_env.c +55 -0
- ads/jobs/templates/oci_metrics.py +181 -0
- ads/jobs/utils.py +104 -0
- ads/llm/__init__.py +28 -0
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/v02/client.py +295 -0
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/chain.py +268 -0
- ads/llm/chat_template.py +31 -0
- ads/llm/deploy.py +63 -0
- ads/llm/guardrails/__init__.py +5 -0
- ads/llm/guardrails/base.py +442 -0
- ads/llm/guardrails/huggingface.py +44 -0
- ads/llm/langchain/__init__.py +5 -0
- ads/llm/langchain/plugins/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/llm/langchain/plugins/llms/__init__.py +5 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
- ads/llm/requirements.txt +3 -0
- ads/llm/serialize.py +219 -0
- ads/llm/serializers/__init__.py +0 -0
- ads/llm/serializers/retrieval_qa.py +153 -0
- ads/llm/serializers/runnable_parallel.py +27 -0
- ads/llm/templates/score_chain.jinja2 +155 -0
- ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
- ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
- ads/model/__init__.py +52 -0
- ads/model/artifact.py +573 -0
- ads/model/artifact_downloader.py +254 -0
- ads/model/artifact_uploader.py +267 -0
- ads/model/base_properties.py +238 -0
- ads/model/common/.model-ignore +66 -0
- ads/model/common/__init__.py +5 -0
- ads/model/common/utils.py +142 -0
- ads/model/datascience_model.py +2635 -0
- ads/model/deployment/__init__.py +20 -0
- ads/model/deployment/common/__init__.py +5 -0
- ads/model/deployment/common/utils.py +308 -0
- ads/model/deployment/model_deployer.py +466 -0
- ads/model/deployment/model_deployment.py +1846 -0
- ads/model/deployment/model_deployment_infrastructure.py +671 -0
- ads/model/deployment/model_deployment_properties.py +493 -0
- ads/model/deployment/model_deployment_runtime.py +838 -0
- ads/model/extractor/__init__.py +5 -0
- ads/model/extractor/automl_extractor.py +74 -0
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/extractor/huggingface_extractor.py +88 -0
- ads/model/extractor/keras_extractor.py +84 -0
- ads/model/extractor/lightgbm_extractor.py +93 -0
- ads/model/extractor/model_info_extractor.py +114 -0
- ads/model/extractor/model_info_extractor_factory.py +105 -0
- ads/model/extractor/pytorch_extractor.py +87 -0
- ads/model/extractor/sklearn_extractor.py +112 -0
- ads/model/extractor/spark_extractor.py +89 -0
- ads/model/extractor/tensorflow_extractor.py +85 -0
- ads/model/extractor/xgboost_extractor.py +94 -0
- ads/model/framework/__init__.py +5 -0
- ads/model/framework/automl_model.py +178 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/framework/huggingface_model.py +399 -0
- ads/model/framework/lightgbm_model.py +266 -0
- ads/model/framework/pytorch_model.py +266 -0
- ads/model/framework/sklearn_model.py +250 -0
- ads/model/framework/spark_model.py +326 -0
- ads/model/framework/tensorflow_model.py +254 -0
- ads/model/framework/xgboost_model.py +258 -0
- ads/model/generic_model.py +3518 -0
- ads/model/model_artifact_boilerplate/README.md +381 -0
- ads/model/model_artifact_boilerplate/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
- ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
- ads/model/model_artifact_boilerplate/score.py +61 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_introspect.py +331 -0
- ads/model/model_metadata.py +1810 -0
- ads/model/model_metadata_mixin.py +460 -0
- ads/model/model_properties.py +63 -0
- ads/model/model_version_set.py +739 -0
- ads/model/runtime/__init__.py +5 -0
- ads/model/runtime/env_info.py +306 -0
- ads/model/runtime/model_deployment_details.py +37 -0
- ads/model/runtime/model_provenance_details.py +58 -0
- ads/model/runtime/runtime_info.py +81 -0
- ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
- ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
- ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
- ads/model/runtime/utils.py +201 -0
- ads/model/serde/__init__.py +5 -0
- ads/model/serde/common.py +40 -0
- ads/model/serde/model_input.py +547 -0
- ads/model/serde/model_serializer.py +1184 -0
- ads/model/service/__init__.py +5 -0
- ads/model/service/oci_datascience_model.py +1076 -0
- ads/model/service/oci_datascience_model_deployment.py +500 -0
- ads/model/service/oci_datascience_model_version_set.py +176 -0
- ads/model/transformer/__init__.py +5 -0
- ads/model/transformer/onnx_transformer.py +324 -0
- ads/mysqldb/__init__.py +5 -0
- ads/mysqldb/mysql_db.py +227 -0
- ads/opctl/__init__.py +18 -0
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/__init__.py +5 -0
- ads/opctl/backend/ads_dataflow.py +353 -0
- ads/opctl/backend/ads_ml_job.py +710 -0
- ads/opctl/backend/ads_ml_pipeline.py +164 -0
- ads/opctl/backend/ads_model_deployment.py +209 -0
- ads/opctl/backend/base.py +146 -0
- ads/opctl/backend/local.py +1053 -0
- ads/opctl/backend/marketplace/__init__.py +9 -0
- ads/opctl/backend/marketplace/helm_helper.py +173 -0
- ads/opctl/backend/marketplace/local_marketplace.py +271 -0
- ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
- ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
- ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
- ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
- ads/opctl/backend/marketplace/models/__init__.py +5 -0
- ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
- ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
- ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
- ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
- ads/opctl/cli.py +707 -0
- ads/opctl/cmds.py +869 -0
- ads/opctl/conda/__init__.py +5 -0
- ads/opctl/conda/cli.py +193 -0
- ads/opctl/conda/cmds.py +749 -0
- ads/opctl/conda/config.yaml +34 -0
- ads/opctl/conda/manifest_template.yaml +13 -0
- ads/opctl/conda/multipart_uploader.py +188 -0
- ads/opctl/conda/pack.py +89 -0
- ads/opctl/config/__init__.py +5 -0
- ads/opctl/config/base.py +57 -0
- ads/opctl/config/diagnostics/__init__.py +5 -0
- ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
- ads/opctl/config/merger.py +255 -0
- ads/opctl/config/resolver.py +297 -0
- ads/opctl/config/utils.py +79 -0
- ads/opctl/config/validator.py +17 -0
- ads/opctl/config/versioner.py +68 -0
- ads/opctl/config/yaml_parsers/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/base.py +58 -0
- ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
- ads/opctl/constants.py +66 -0
- ads/opctl/decorator/__init__.py +5 -0
- ads/opctl/decorator/common.py +129 -0
- ads/opctl/diagnostics/__init__.py +5 -0
- ads/opctl/diagnostics/__main__.py +25 -0
- ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
- ads/opctl/diagnostics/check_requirements.py +144 -0
- ads/opctl/diagnostics/requirement_exception.py +9 -0
- ads/opctl/distributed/README.md +109 -0
- ads/opctl/distributed/__init__.py +5 -0
- ads/opctl/distributed/certificates.py +32 -0
- ads/opctl/distributed/cli.py +207 -0
- ads/opctl/distributed/cmds.py +731 -0
- ads/opctl/distributed/common/__init__.py +5 -0
- ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
- ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
- ads/opctl/distributed/common/cluster_config_helper.py +103 -0
- ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
- ads/opctl/distributed/common/cluster_runner.py +54 -0
- ads/opctl/distributed/common/framework_factory.py +29 -0
- ads/opctl/docker/Dockerfile.job +103 -0
- ads/opctl/docker/Dockerfile.job.arm +107 -0
- ads/opctl/docker/Dockerfile.job.gpu +175 -0
- ads/opctl/docker/base-env.yaml +13 -0
- ads/opctl/docker/cuda.repo +6 -0
- ads/opctl/docker/operator/.dockerignore +0 -0
- ads/opctl/docker/operator/Dockerfile +41 -0
- ads/opctl/docker/operator/Dockerfile.gpu +85 -0
- ads/opctl/docker/operator/cuda.repo +6 -0
- ads/opctl/docker/operator/environment.yaml +8 -0
- ads/opctl/forecast.py +11 -0
- ads/opctl/index.yaml +3 -0
- ads/opctl/model/__init__.py +5 -0
- ads/opctl/model/cli.py +65 -0
- ads/opctl/model/cmds.py +73 -0
- ads/opctl/operator/README.md +4 -0
- ads/opctl/operator/__init__.py +31 -0
- ads/opctl/operator/cli.py +344 -0
- ads/opctl/operator/cmd.py +596 -0
- ads/opctl/operator/common/__init__.py +5 -0
- ads/opctl/operator/common/backend_factory.py +460 -0
- ads/opctl/operator/common/const.py +27 -0
- ads/opctl/operator/common/data/synthetic.csv +16001 -0
- ads/opctl/operator/common/dictionary_merger.py +148 -0
- ads/opctl/operator/common/errors.py +42 -0
- ads/opctl/operator/common/operator_config.py +99 -0
- ads/opctl/operator/common/operator_loader.py +811 -0
- ads/opctl/operator/common/operator_schema.yaml +130 -0
- ads/opctl/operator/common/operator_yaml_generator.py +152 -0
- ads/opctl/operator/common/utils.py +208 -0
- ads/opctl/operator/lowcode/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
- ads/opctl/operator/lowcode/anomaly/README.md +207 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +167 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
- ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +116 -0
- ads/opctl/operator/lowcode/common/errors.py +47 -0
- ads/opctl/operator/lowcode/common/transformations.py +296 -0
- ads/opctl/operator/lowcode/common/utils.py +384 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
- ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
- ads/opctl/operator/lowcode/forecast/README.md +209 -0
- ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
- ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
- ads/opctl/operator/lowcode/forecast/const.py +92 -0
- ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
- ads/opctl/operator/lowcode/forecast/errors.py +26 -0
- ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
- ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
- ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
- ads/opctl/operator/lowcode/forecast/model/prophet.py +445 -0
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
- ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
- ads/opctl/operator/lowcode/forecast/utils.py +397 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
- ads/opctl/operator/lowcode/pii/MLoperator +17 -0
- ads/opctl/operator/lowcode/pii/README.md +208 -0
- ads/opctl/operator/lowcode/pii/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/__main__.py +78 -0
- ads/opctl/operator/lowcode/pii/cmd.py +39 -0
- ads/opctl/operator/lowcode/pii/constant.py +84 -0
- ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
- ads/opctl/operator/lowcode/pii/errors.py +27 -0
- ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
- ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
- ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
- ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
- ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
- ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
- ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
- ads/opctl/operator/lowcode/pii/model/report.py +487 -0
- ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
- ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
- ads/opctl/operator/lowcode/pii/utils.py +43 -0
- ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
- ads/opctl/operator/lowcode/recommender/README.md +206 -0
- ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
- ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
- ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
- ads/opctl/operator/lowcode/recommender/constant.py +30 -0
- ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
- ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
- ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
- ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
- ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
- ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
- ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
- ads/opctl/operator/lowcode/recommender/utils.py +13 -0
- ads/opctl/operator/runtime/__init__.py +5 -0
- ads/opctl/operator/runtime/const.py +17 -0
- ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
- ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
- ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/runtime.py +115 -0
- ads/opctl/schema.yaml.yml +36 -0
- ads/opctl/script.py +40 -0
- ads/opctl/spark/__init__.py +5 -0
- ads/opctl/spark/cli.py +43 -0
- ads/opctl/spark/cmds.py +147 -0
- ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
- ads/opctl/utils.py +344 -0
- ads/oracledb/__init__.py +5 -0
- ads/oracledb/oracle_db.py +346 -0
- ads/pipeline/__init__.py +39 -0
- ads/pipeline/ads_pipeline.py +2279 -0
- ads/pipeline/ads_pipeline_run.py +772 -0
- ads/pipeline/ads_pipeline_step.py +605 -0
- ads/pipeline/builders/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/custom_script.py +32 -0
- ads/pipeline/cli.py +119 -0
- ads/pipeline/extension.py +291 -0
- ads/pipeline/schema/__init__.py +5 -0
- ads/pipeline/schema/cs_step_schema.json +35 -0
- ads/pipeline/schema/ml_step_schema.json +31 -0
- ads/pipeline/schema/pipeline_schema.json +71 -0
- ads/pipeline/visualizer/__init__.py +5 -0
- ads/pipeline/visualizer/base.py +570 -0
- ads/pipeline/visualizer/graph_renderer.py +272 -0
- ads/pipeline/visualizer/text_renderer.py +84 -0
- ads/secrets/__init__.py +11 -0
- ads/secrets/adb.py +386 -0
- ads/secrets/auth_token.py +86 -0
- ads/secrets/big_data_service.py +365 -0
- ads/secrets/mysqldb.py +149 -0
- ads/secrets/oracledb.py +160 -0
- ads/secrets/secrets.py +407 -0
- ads/telemetry/__init__.py +7 -0
- ads/telemetry/base.py +69 -0
- ads/telemetry/client.py +125 -0
- ads/telemetry/telemetry.py +257 -0
- ads/templates/dataflow_pyspark.jinja2 +13 -0
- ads/templates/dataflow_sparksql.jinja2 +22 -0
- ads/templates/func.jinja2 +20 -0
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score-pkl.jinja2 +173 -0
- ads/templates/score.jinja2 +322 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- ads/templates/score_generic.jinja2 +165 -0
- ads/templates/score_huggingface_pipeline.jinja2 +217 -0
- ads/templates/score_lightgbm.jinja2 +185 -0
- ads/templates/score_onnx.jinja2 +407 -0
- ads/templates/score_onnx_new.jinja2 +473 -0
- ads/templates/score_oracle_automl.jinja2 +185 -0
- ads/templates/score_pyspark.jinja2 +154 -0
- ads/templates/score_pytorch.jinja2 +219 -0
- ads/templates/score_scikit-learn.jinja2 +184 -0
- ads/templates/score_tensorflow.jinja2 +184 -0
- ads/templates/score_xgboost.jinja2 +178 -0
- ads/text_dataset/__init__.py +5 -0
- ads/text_dataset/backends.py +211 -0
- ads/text_dataset/dataset.py +445 -0
- ads/text_dataset/extractor.py +207 -0
- ads/text_dataset/options.py +53 -0
- ads/text_dataset/udfs.py +22 -0
- ads/text_dataset/utils.py +49 -0
- ads/type_discovery/__init__.py +9 -0
- ads/type_discovery/abstract_detector.py +21 -0
- ads/type_discovery/constant_detector.py +41 -0
- ads/type_discovery/continuous_detector.py +54 -0
- ads/type_discovery/credit_card_detector.py +99 -0
- ads/type_discovery/datetime_detector.py +92 -0
- ads/type_discovery/discrete_detector.py +118 -0
- ads/type_discovery/document_detector.py +146 -0
- ads/type_discovery/ip_detector.py +68 -0
- ads/type_discovery/latlon_detector.py +90 -0
- ads/type_discovery/phone_number_detector.py +63 -0
- ads/type_discovery/type_discovery_driver.py +87 -0
- ads/type_discovery/typed_feature.py +594 -0
- ads/type_discovery/unknown_detector.py +41 -0
- ads/type_discovery/zipcode_detector.py +48 -0
- ads/vault/__init__.py +7 -0
- ads/vault/vault.py +237 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/METADATA +150 -150
- oracle_ads-2.13.9rc1.dist-info/RECORD +858 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/WHEEL +1 -2
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/entry_points.txt +2 -1
- oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
- oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,979 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
|
7
|
+
"""LLM for OCI data science model deployment endpoint."""
|
8
|
+
|
9
|
+
import json
|
10
|
+
import logging
|
11
|
+
import traceback
|
12
|
+
from typing import (
|
13
|
+
Any,
|
14
|
+
AsyncIterator,
|
15
|
+
Callable,
|
16
|
+
Dict,
|
17
|
+
Iterator,
|
18
|
+
List,
|
19
|
+
Literal,
|
20
|
+
Optional,
|
21
|
+
Union,
|
22
|
+
)
|
23
|
+
|
24
|
+
import aiohttp
|
25
|
+
import requests
|
26
|
+
from langchain_community.utilities.requests import Requests
|
27
|
+
from langchain_core.callbacks import (
|
28
|
+
AsyncCallbackManagerForLLMRun,
|
29
|
+
CallbackManagerForLLMRun,
|
30
|
+
)
|
31
|
+
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
|
32
|
+
from langchain_core.load.serializable import Serializable
|
33
|
+
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
34
|
+
from langchain_core.utils import get_from_dict_or_env
|
35
|
+
from pydantic import Field, model_validator
|
36
|
+
|
37
|
+
logger = logging.getLogger(name=__name__)


# Used as the ``timeout=`` of each inference request when the caller does not
# pass ``request_timeout``.
DEFAULT_TIME_OUT: int = 300
# Content-Type header value sent with JSON request bodies.
DEFAULT_CONTENT_TYPE_JSON: str = "application/json"
# Default value of the ``model`` field of OCIModelDeploymentLLM.
DEFAULT_MODEL_NAME: str = "odsc-llm"
# Default ``route`` header value added by OCIModelDeploymentLLM._headers.
DEFAULT_INFERENCE_ENDPOINT: str = "/v1/completions"
|
44
|
+
|
45
|
+
|
46
|
+
class TokenExpiredError(Exception):
    """Signal that a request was rejected after the signer's token was refreshed.

    ``BaseOCIModelDeployment._check_response`` raises this on a 401/404
    response once the security token has been refreshed. It is listed as a
    retryable error so the request is attempted again with the new token.
    """
|
48
|
+
|
49
|
+
|
50
|
+
class ServerError(Exception):
    """Signal an HTTP error status returned by the inference endpoint.

    ``BaseOCIModelDeployment._check_response`` raises this for HTTP errors
    that are not treated as an expired token.
    """
|
52
|
+
|
53
|
+
|
54
|
+
def _create_retry_decorator(
    llm: "BaseOCIModelDeployment",
    *,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Build the retry decorator used around model deployment requests.

    Args:
        llm: The deployment wrapper; its ``max_retries`` bounds the attempts.
        run_manager: Optional callback manager forwarded to LangChain's
            ``create_base_retry_decorator``.

    Returns:
        A decorator that retries on ``requests`` connect timeouts and on
        ``TokenExpiredError``.
    """
    retryable_errors = [requests.exceptions.ConnectTimeout, TokenExpiredError]
    return create_base_retry_decorator(
        error_types=retryable_errors,
        max_retries=llm.max_retries,
        run_manager=run_manager,
    )
|
67
|
+
|
68
|
+
|
69
|
+
class BaseOCIModelDeployment(Serializable):
    """Base class for LLMs deployed on OCI Data Science Model Deployment.

    Holds the connection settings (auth, endpoint, streaming, retries, default
    headers) and the HTTP plumbing shared by sync and async callers: header
    construction and request signing, retried completion calls, HTTP status
    checking, and parsing of server-sent-event (SSE) lines.
    """

    auth: dict = Field(default_factory=dict, exclude=True)
    """ADS auth dictionary for OCI authentication:
    https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
    This can be generated by calling `ads.common.auth.api_keys()`
    or `ads.common.auth.resource_principal()`. If this is not
    provided then `ads.common.auth.default_signer()` will be used."""

    endpoint: str = ""
    """The URI of the endpoint from the deployed Model Deployment model."""

    streaming: bool = False
    """Whether to stream the results or not."""

    max_retries: int = 3
    """Maximum number of retries to make when generating."""

    default_headers: Optional[Dict[str, Any]] = None
    """The headers to be added to the Model Deployment request."""

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Dict:
        """Check that oracle-ads is installed and fill in auth and endpoint.

        Runs before field validation. If no ``auth`` is given, the ADS default
        signer is used. ``endpoint`` is resolved from ``values`` or the
        ``OCI_LLM_ENDPOINT`` environment variable via ``get_from_dict_or_env``.

        Raises:
            ImportError: If the ``ads`` package cannot be imported.
        """
        try:
            import ads

        except ImportError as ex:
            raise ImportError(
                "Could not import ads python package. "
                "Please install it with `pip install oracle_ads`."
            ) from ex

        if not values.get("auth"):
            values["auth"] = ads.common.auth.default_signer()

        values["endpoint"] = get_from_dict_or_env(
            values,
            "endpoint",
            "OCI_LLM_ENDPOINT",
        )
        return values

    def _headers(
        self, is_async: Optional[bool] = False, body: Optional[dict] = None
    ) -> Dict:
        """Construct and return the headers for a request.

        Args:
            is_async (bool, optional): Indicates if the request is asynchronous.
                Defaults to `False`. Async requests are signed here, because the
                async HTTP client is not given the signer.
            body (optional): The JSON body that is signed for async requests.

        Returns:
            Dict: A new dictionary containing the headers for the request.
                ``default_headers`` itself is never modified.
        """
        # Copy so that signed/per-request headers never leak into the user's
        # ``default_headers`` dict across calls.
        headers = dict(self.default_headers or {})
        if is_async:
            signer = self.auth["signer"]
            _req = requests.Request("POST", self.endpoint, json=body)
            req = _req.prepare()
            req = signer(req)
            for key, value in req.headers.items():
                headers[key] = value

            if self.streaming:
                headers.update(
                    {"enable-streaming": "true", "Accept": "text/event-stream"}
                )
            return headers

        headers.update(
            {
                "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
                "enable-streaming": "true",
                "Accept": "text/event-stream",
            }
            if self.streaming
            else {
                "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
            }
        )

        return headers

    def completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Send a synchronous POST to the endpoint, using tenacity to retry.

        Args:
            run_manager: Optional callback manager forwarded to the retry
                decorator.
            **kwargs: Must contain ``data`` (the request payload). Optional
                ``request_timeout`` (default ``DEFAULT_TIME_OUT``) and
                ``stream`` (default ``self.streaming``). Remaining keys are
                passed through to the POST call.

        Returns:
            The response object after its HTTP status has been checked.

        Raises:
            KeyError: If ``data`` is missing from ``kwargs``.
            TokenExpiredError: Propagated so the retry decorator can retry.
            RuntimeError: Wrapping any other failure.
        """
        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            # Popped before the ``try`` so the error-logging branch below can
            # always reference these names.
            request_timeout = kwargs.pop("request_timeout", DEFAULT_TIME_OUT)
            data = kwargs.pop("data")
            stream = kwargs.pop("stream", self.streaming)
            try:
                request = Requests(
                    headers=self._headers(), auth=self.auth.get("signer")
                )
                response = request.post(
                    url=self.endpoint,
                    data=data,
                    timeout=request_timeout,
                    stream=stream,
                    **kwargs,
                )
                self._check_response(response)
                return response
            except TokenExpiredError:
                raise
            except Exception as err:
                # NOTE(review): requests' ConnectTimeout is an Exception too, so
                # it is wrapped in RuntimeError here and the retry decorator
                # (which lists ConnectTimeout as retryable) never sees it.
                traceback.print_exc()
                logger.debug(
                    f"Requests payload: {data}. Requests arguments: "
                    f"url={self.endpoint},timeout={request_timeout},stream={stream}. "
                    f"Additional request kwargs={kwargs}."
                )
                raise RuntimeError(
                    f"Error occurs by inference endpoint: {str(err)}"
                ) from err

        return _completion_with_retry(**kwargs)

    async def acompletion_with_retry(
        self,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Send an asynchronous POST to the endpoint, using tenacity to retry.

        Args:
            run_manager: Optional callback manager forwarded to the retry
                decorator.
            **kwargs: Must contain ``data`` (the request payload). Optional
                ``request_timeout`` (default ``DEFAULT_TIME_OUT``) and
                ``stream`` (default ``self.streaming``). Other keys are ignored.

        Returns:
            In stream mode, an async iterator of parsed SSE lines. Otherwise,
            the decoded JSON response.

        Raises:
            KeyError: If ``data`` is missing from ``kwargs``.
            TokenExpiredError: Propagated so the retry decorator can retry.
            RuntimeError: Wrapping any other failure.
        """
        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        @retry_decorator
        async def _completion_with_retry(**kwargs: Any) -> Any:
            request_timeout = kwargs.pop("request_timeout", DEFAULT_TIME_OUT)
            data = kwargs.pop("data")
            stream = kwargs.pop("stream", self.streaming)
            try:
                request = Requests(headers=self._headers(is_async=True, body=data))
                if stream:
                    response = request.apost(
                        url=self.endpoint,
                        data=data,
                        timeout=request_timeout,
                    )
                    # NOTE(review): the status check happens lazily inside
                    # _aiter_sse, i.e. only once the caller iterates, so it
                    # is outside this retry wrapper.
                    return self._aiter_sse(response)
                async with request.apost(
                    url=self.endpoint,
                    data=data,
                    timeout=request_timeout,
                ) as resp:
                    self._check_response(resp)
                    result = await resp.json()
                    return result
            except TokenExpiredError:
                raise
            except Exception as err:
                traceback.print_exc()
                logger.debug(
                    f"Requests payload: `{data}`. "
                    f"Stream mode={stream}. "
                    f"Requests kwargs: url={self.endpoint}, timeout={request_timeout}."
                )
                raise RuntimeError(
                    f"Error occurs by inference endpoint: {str(err)}"
                ) from err

        return await _completion_with_retry(**kwargs)

    def _check_response(self, response: Any) -> None:
        """Handle server errors by checking the response status.

        Supports both ``requests.Response`` (``status_code``, ``text``
        property) and ``aiohttp.ClientResponse`` (``status``; its ``text`` is
        an async method and cannot be read from this sync helper).

        Args:
            response:
                The response object from either `requests` or `aiohttp` library.

        Raises:
            TokenExpiredError:
                If the status code is 401 or 404 and the token refresh succeeded.
            ServerError:
                If any other HTTP error occurs.
        """
        try:
            response.raise_for_status()
        except (
            requests.exceptions.HTTPError,
            aiohttp.ClientResponseError,
        ) as http_err:
            status_code = (
                response.status_code
                if hasattr(response, "status_code")
                else response.status
            )
            if status_code in (401, 404) and self._refresh_signer():
                raise TokenExpiredError() from http_err

            # requests exposes the body as a ``text`` str property; for aiohttp
            # ``text`` is a coroutine method, so use the error message instead.
            response_text = getattr(response, "text", None)
            if not isinstance(response_text, str):
                response_text = getattr(http_err, "message", "")
            raise ServerError(
                f"Server error: {str(http_err)}. \nMessage: {response_text}"
            ) from http_err

    def _parse_stream(self, lines: Iterator[bytes]) -> Iterator[str]:
        """Parse a stream of byte lines and yield parsed string lines.

        Args:
            lines (Iterator[bytes]):
                An iterator that yields lines in byte format.

        Yields:
            Iterator[str]:
                The payload of each SSE ``data:`` line (see
                ``_parse_stream_line``). Skipped lines produce nothing.
        """
        for line in lines:
            _line = self._parse_stream_line(line)
            if _line is not None:
                yield _line

    async def _parse_stream_async(
        self,
        lines: aiohttp.StreamReader,
    ) -> AsyncIterator[str]:
        """
        Asynchronously parse a stream of byte lines and yield parsed string lines.

        Args:
            lines (aiohttp.StreamReader):
                An `aiohttp.StreamReader` object that yields lines in byte format.

        Yields:
            AsyncIterator[str]:
                The payload of each SSE ``data:`` line (see
                ``_parse_stream_line``). Skipped lines produce nothing.
        """
        async for line in lines:
            _line = self._parse_stream_line(line)
            if _line is not None:
                yield _line

    def _parse_stream_line(self, line: bytes) -> Optional[str]:
        """Parse a single SSE byte line.

        Args:
            line (bytes): A single line in byte format.

        Returns:
            Optional[str]:
                The text after a case-insensitive ``data:`` prefix, with
                surrounding whitespace removed. ``None`` for blank lines,
                non-``data:`` lines and the ``[DONE]`` terminator.
        """
        line = line.strip()
        if not line:
            return None
        _line = line.decode("utf-8")

        if _line.lower().startswith("data:"):
            _line = _line[5:].lstrip()

            if _line.startswith("[DONE]"):
                return None
            return _line
        return None

    async def _aiter_sse(
        self,
        async_cntx_mgr: Any,
    ) -> AsyncIterator[str]:
        """Asynchronously iterate over server-sent events (SSE).

        The response status is checked when the context manager is entered,
        that is, on first iteration.

        Args:
            async_cntx_mgr: An asynchronous context manager that yields a client
                response object.

        Yields:
            AsyncIterator[str]: Parsed server-sent event payloads as JSON
                strings.
        """
        async with async_cntx_mgr as client_resp:
            self._check_response(client_resp)
            async for line in self._parse_stream_async(client_resp.content):
                yield line

    def _refresh_signer(self) -> bool:
        """Attempt to refresh the security token using the signer.

        Returns:
            bool: `True` if the signer has ``refresh_security_token`` and it was
                called, `False` otherwise.
        """
        if self.auth.get("signer", None) and hasattr(
            self.auth["signer"], "refresh_security_token"
        ):
            self.auth["signer"].refresh_security_token()
            return True
        return False

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by LangChain."""
        return True
|
369
|
+
|
370
|
+
|
371
|
+
class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
|
372
|
+
"""LLM deployed on OCI Data Science Model Deployment.
|
373
|
+
|
374
|
+
To use, you must provide the model HTTP endpoint from your deployed
|
375
|
+
model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict.
|
376
|
+
|
377
|
+
To authenticate, `oracle-ads` has been used to automatically load
|
378
|
+
credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
|
379
|
+
|
380
|
+
Make sure to have the required policies to access the OCI Data
|
381
|
+
Science Model Deployment endpoint. See:
|
382
|
+
https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
|
383
|
+
|
384
|
+
Example:
|
385
|
+
|
386
|
+
.. code-block:: python
|
387
|
+
|
388
|
+
from langchain_community.llms import OCIModelDeploymentLLM
|
389
|
+
|
390
|
+
llm = OCIModelDeploymentLLM(
|
391
|
+
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<ocid>/predict",
|
392
|
+
model="odsc-llm",
|
393
|
+
streaming=True,
|
394
|
+
model_kwargs={"frequency_penalty": 1.0},
|
395
|
+
headers={
|
396
|
+
"route": "/v1/completions",
|
397
|
+
# other request headers ...
|
398
|
+
}
|
399
|
+
)
|
400
|
+
llm.invoke("tell me a joke.")
|
401
|
+
|
402
|
+
Customized Usage:
|
403
|
+
|
404
|
+
Users can inherit from this class and override the `_process_response`, `_process_stream_response`,
|
405
|
+
`_construct_json_body` methods to meet customized needs.
|
406
|
+
|
407
|
+
.. code-block:: python
|
408
|
+
|
409
|
+
from langchain_community.llms import OCIModelDeploymentLLM
|
410
|
+
|
411
|
+
class MyCutomizedModel(OCIModelDeploymentLLM):
|
412
|
+
def _process_stream_response(self, response_json:dict) -> GenerationChunk:
|
413
|
+
print("My customized output stream handler.")
|
414
|
+
return GenerationChunk()
|
415
|
+
|
416
|
+
def _process_response(self, response_json:dict) -> List[Generation]:
|
417
|
+
print("My customized output handler.")
|
418
|
+
return [Generation()]
|
419
|
+
|
420
|
+
def _construct_json_body(self, prompt: str, param:dict) -> dict:
|
421
|
+
print("My customized input handler.")
|
422
|
+
return {}
|
423
|
+
|
424
|
+
llm = MyCutomizedModel(
|
425
|
+
endpoint=f"https://modeldeployment.us-ashburn-1.oci.customer-oci.com/{ocid}/predict",
|
426
|
+
model="<model_name>",
|
427
|
+
)
|
428
|
+
|
429
|
+
llm.invoke("tell me a joke.")
|
430
|
+
|
431
|
+
""" # noqa: E501
|
432
|
+
|
433
|
+
model: str = DEFAULT_MODEL_NAME
|
434
|
+
"""The name of the model."""
|
435
|
+
|
436
|
+
stop: Optional[List[str]] = None
|
437
|
+
"""Stop words to use when generating. Model output is cut off
|
438
|
+
at the first occurrence of any of these substrings."""
|
439
|
+
|
440
|
+
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
441
|
+
"""Keyword arguments to pass to the model."""
|
442
|
+
|
443
|
+
@property
|
444
|
+
def _llm_type(self) -> str:
|
445
|
+
"""Return type of llm."""
|
446
|
+
return "oci_model_deployment_endpoint"
|
447
|
+
|
448
|
+
@property
|
449
|
+
def _default_params(self) -> Dict[str, Any]:
|
450
|
+
"""Get the default parameters."""
|
451
|
+
return {
|
452
|
+
"model": self.model,
|
453
|
+
"stop": self.stop,
|
454
|
+
"stream": self.streaming,
|
455
|
+
}
|
456
|
+
|
457
|
+
@property
|
458
|
+
def _identifying_params(self) -> Dict[str, Any]:
|
459
|
+
"""Get the identifying parameters."""
|
460
|
+
_model_kwargs = self.model_kwargs or {}
|
461
|
+
return {
|
462
|
+
**{"endpoint": self.endpoint, "model_kwargs": _model_kwargs},
|
463
|
+
**self._default_params,
|
464
|
+
}
|
465
|
+
|
466
|
+
def _headers(
    self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
    """Build request headers, adding a default ``route`` header.

    Args:
        is_async (bool, optional): Whether the request is asynchronous.
            Defaults to `False`.
        body (optional): The request body used when signing async requests.

    Returns:
        Dict: The base-class headers plus ``route`` set to
            ``DEFAULT_INFERENCE_ENDPOINT``. A ``route`` key in the base headers
            wins, because it is unpacked after the default.
    """
    base_headers = super()._headers(is_async=is_async, body=body)
    return {"route": DEFAULT_INFERENCE_ENDPOINT, **base_headers}
|
484
|
+
|
485
|
+
def _generate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> LLMResult:
    """Call out to OCI Data Science Model Deployment endpoint with k unique prompts.

    Each prompt is sent as its own request. In streaming mode the chunks
    produced by ``_stream`` are concatenated into a single generation;
    otherwise the JSON response is converted with ``_process_response``.

    Args:
        prompts: The prompts to pass into the service.
        stop: Optional list of stop words to use when generating.
        run_manager: Optional callback manager for the run.
        **kwargs: Extra parameters passed to ``_invocation_params`` and to the
            underlying request / stream call.

    Returns:
        The full LLM output.

    Example:
        .. code-block:: python

            response = llm.invoke("Tell me a joke.")
            response = llm.generate(["Tell me a joke."])
    """
    generations: List[List[Generation]] = []
    params = self._invocation_params(stop, **kwargs)
    for prompt in prompts:
        if self.streaming:
            # _stream builds its own request body from the prompt, so no body
            # is constructed here (it used to be built and then discarded).
            generation = GenerationChunk(text="")
            for chunk in self._stream(
                prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                generation += chunk
            generations.append([generation])
        else:
            body = self._construct_json_body(prompt, params)
            res = self.completion_with_retry(
                data=body,
                run_manager=run_manager,
                **kwargs,
            )
            generations.append(self._process_response(res.json()))
    return LLMResult(generations=generations)
|
526
|
+
|
527
|
+
async def _agenerate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> LLMResult:
    """Call out to OCI Data Science Model Deployment endpoint async with k unique prompts.

    Each prompt is sent as its own request. In streaming mode the chunks
    produced by ``_astream`` are concatenated into a single generation;
    otherwise the decoded JSON response is converted with
    ``_process_response``.

    Args:
        prompts: The prompts to pass into the service.
        stop: Optional list of stop words to use when generating.
        run_manager: Optional async callback manager for the run.
        **kwargs: Extra parameters passed to ``_invocation_params`` and to the
            underlying request / stream call.

    Returns:
        The full LLM output.

    Example:
        .. code-block:: python

            response = await llm.ainvoke("Tell me a joke.")
            response = await llm.agenerate(["Tell me a joke."])
    """  # noqa: E501
    generations: List[List[Generation]] = []
    params = self._invocation_params(stop, **kwargs)
    for prompt in prompts:
        if self.streaming:
            # The streaming branch hands the raw prompt to _astream, so the
            # JSON body is only built for the non-streaming request below.
            generation = GenerationChunk(text="")
            async for chunk in self._astream(
                prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                generation += chunk
            generations.append([generation])
        else:
            body = self._construct_json_body(prompt, params)
            res = await self.acompletion_with_retry(
                data=body,
                run_manager=run_manager,
                **kwargs,
            )
            generations.append(self._process_response(res))
    return LLMResult(generations=generations)
|
568
|
+
|
569
|
+
def _stream(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Iterator[GenerationChunk]:
    """Stream a completion for ``prompt`` from the OCI Data Science Model Deployment endpoint.

    Args:
        prompt (str):
            The prompt to pass into the model.
        stop (List[str], Optional):
            List of stop words to use when generating.
        kwargs:
            requests_kwargs:
                Additional ``**kwargs`` forwarded to ``completion_with_retry``
                (presumably passed on to ``requests.post``).
            Any other keyword arguments are merged into the request
            parameters by ``_invocation_params``.

    Returns:
        An iterator of GenerationChunks, one per parsed stream line.

    Example:

        .. code-block:: python

            response = llm.stream("Tell me a joke.")

    """
    extra_request_args = kwargs.pop("requests_kwargs", {})
    # Set before ``_invocation_params`` so that ``_default_params``
    # implementations reading ``self.streaming`` (the ``"stream"`` key)
    # see streaming enabled.
    # NOTE(review): the flag is never reset here, so later ``_agenerate``
    # calls on the same instance also take the streaming branch — confirm.
    self.streaming = True
    request_params = self._invocation_params(stop, **kwargs)
    payload = self._construct_json_body(prompt, request_params)

    response = self.completion_with_retry(
        data=payload,
        run_manager=run_manager,
        stream=True,
        **extra_request_args,
    )
    for raw_line in self._parse_stream(response.iter_lines()):
        piece = self._handle_sse_line(raw_line)
        if run_manager:
            run_manager.on_llm_new_token(piece.text, chunk=piece)
        yield piece
|
613
|
+
|
614
|
+
async def _astream(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
    """Stream OCI Data Science Model Deployment endpoint async on given prompt.

    Args:
        prompt (str):
            The prompt to pass into the model.
        stop (List[str], Optional):
            List of stop words to use when generating.
        kwargs:
            requests_kwargs:
                Additional ``**kwargs`` forwarded to
                ``acompletion_with_retry``.
            Any other keyword arguments are merged into the request
            parameters by ``_invocation_params``.

    Returns:
        An async iterator of GenerationChunks, one per line yielded by the
        streaming response.

    Example:

        .. code-block:: python

            async for chunk in llm.astream("Tell me a joke."):
                print(chunk, end="", flush=True)

    """
    requests_kwargs = kwargs.pop("requests_kwargs", {})
    # Set before ``_invocation_params`` so that ``_default_params``
    # implementations reading ``self.streaming`` (the ``"stream"`` key)
    # see streaming enabled.
    # NOTE(review): the flag is never reset here, so later ``_agenerate``
    # calls on the same instance also take the streaming branch — confirm.
    self.streaming = True
    params = self._invocation_params(stop, **kwargs)
    body = self._construct_json_body(prompt, params)

    # ``acompletion_with_retry`` is awaited to obtain an async iterable of
    # stream lines; each line is decoded best-effort by ``_handle_sse_line``.
    async for line in await self.acompletion_with_retry(
        data=body, run_manager=run_manager, stream=True, **requests_kwargs
    ):
        chunk = self._handle_sse_line(line)
        if run_manager:
            await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
        yield chunk
|
657
|
+
|
658
|
+
def _construct_json_body(self, prompt: str, params: dict) -> dict:
    """Build the completion request payload as a new dictionary.

    The payload holds ``"prompt"`` first, followed by every entry of
    ``params``. If ``params`` itself contains a ``"prompt"`` key, that value
    wins over the ``prompt`` argument.
    """
    payload = {"prompt": prompt}
    payload.update(params)
    return payload
|
664
|
+
|
665
|
+
def _invocation_params(
    self, stop: Optional[List[str]] = None, **kwargs: Any
) -> dict:
    """Combine the default parameters with model and per-call overrides.

    Precedence, lowest to highest: ``self._default_params`` (with ``stop``
    replaced by the ``stop`` argument when it is truthy), then
    ``self.model_kwargs``, then ``kwargs``.

    Args:
        stop: Optional stop words. When falsy, the default's ``"stop"`` value
            is kept, or ``[]`` if the defaults have no ``"stop"`` key.
        **kwargs: Per-call parameter overrides.

    Returns:
        A new dict of request parameters.
    """
    # Copy before writing ``stop`` so the override can never leak into a
    # dict that ``_default_params`` might hand out again on a later call.
    params = dict(self._default_params)
    params["stop"] = stop or params.get("stop", [])
    model_kwargs = self.model_kwargs or {}
    return {**params, **model_kwargs, **kwargs}
|
673
|
+
|
674
|
+
def _process_stream_response(self, response_json: dict) -> GenerationChunk:
    """Turn one OpenAI-style streaming payload into a ``GenerationChunk``.

    Only the first entry of ``choices`` is read. Its ``"text"`` becomes the
    chunk text, defaulting to ``""`` when the key is absent.

    Raises:
        ValueError: If ``choices`` is missing or empty, the payload is not
            subscriptable, or the first choice is not a dict.
    """
    try:
        first_choice = response_json["choices"][0]
        if not isinstance(first_choice, dict):
            raise TypeError("Endpoint response is not well formed.")
    except (KeyError, IndexError, TypeError) as err:
        raise ValueError("Error while formatting response payload.") from err

    chunk_text = first_choice.get("text", "")
    return GenerationChunk(text=chunk_text)
|
684
|
+
|
685
|
+
def _process_response(self, response_json: dict) -> List[Generation]:
    """Formats a response in the OpenAI completions spec.

    Args:
        response_json (dict): The JSON response from the completion endpoint.

    Returns:
        List[Generation]: One ``Generation`` per entry in ``choices``. The
        text comes from the choice's ``"text"``, and ``generation_info``
        comes from ``_generate_info``.

    Raises:
        ValueError: If the response JSON is not well-formed, for example
            ``choices`` is missing or is not a list, or a choice is not a
            dict.

    """
    try:
        choices = response_json["choices"]
        if not isinstance(choices, list):
            raise TypeError("Endpoint response is not well formed.")
    except (KeyError, TypeError) as e:
        raise ValueError("Error while formatting response payload.") from e

    generations = []
    for choice in choices:
        # Without this check a non-dict choice raises AttributeError on
        # ``.get``, breaking the ValueError contract documented above.
        if not isinstance(choice, dict):
            raise ValueError(
                "Error while formatting response payload: "
                "each choice must be a JSON object."
            )
        generations.append(
            Generation(
                text=choice.get("text"),
                generation_info=self._generate_info(choice),
            )
        )

    return generations
|
716
|
+
|
717
|
+
def _generate_info(self, choice: dict) -> Any:
    """Collect optional per-choice metadata for ``Generation.generation_info``.

    ``finish_reason`` is kept only when truthy. ``logprobs`` and ``index``
    are kept whenever they are not ``None``. Returns ``None`` instead of an
    empty dict when nothing was collected.
    """
    info = {}

    reason = choice.get("finish_reason")
    if reason:
        info["finish_reason"] = reason

    for field_name in ("logprobs", "index"):
        value = choice.get(field_name)
        if value is not None:
            info[field_name] = value

    return info if info else None
|
731
|
+
|
732
|
+
def _handle_sse_line(self, line: str) -> GenerationChunk:
    """Decode one streamed line into a ``GenerationChunk``, best-effort.

    The line is parsed as JSON and handed to ``_process_stream_response``.
    Any exception along the way produces an empty chunk instead of
    propagating, so a single unparseable line does not abort the stream.
    """
    try:
        payload = json.loads(line)
        chunk = self._process_stream_response(payload)
    except Exception:
        chunk = GenerationChunk(text="")
    return chunk
|
738
|
+
|
739
|
+
|
740
|
+
class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
    """OCI Data Science Model Deployment TGI Endpoint.

    To use, you must provide the model HTTP endpoint from your deployed
    model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict.

    To authenticate, `oracle-ads` has been used to automatically load
    credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html

    Make sure to have the required policies to access the OCI Data
    Science Model Deployment endpoint. See:
    https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint

    The ``api`` field selects the request/response format:

    * ``"/v1/completions"`` (default): OpenAI-style completion payloads,
      built and parsed by the base class.
    * ``"/generate"``: payloads of the form
      ``{"inputs": <prompt>, "parameters": {...}}``, with the text read from
      the response's ``"generated_text"`` field.

    Example:
        .. code-block:: python

            from langchain_community.llms import OCIModelDeploymentTGI

            llm = OCIModelDeploymentTGI(
                endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict",
                api="/v1/completions",
                streaming=True,
                temperature=0.2,
                seed=42,
                # other model parameters ...
            )

    """

    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""

    temperature: float = 0.2
    """A non-negative float that tunes the degree of randomness in generation."""

    k: int = -1
    """Number of most likely tokens to consider at each step.
    Only sent for the ``/generate`` api; non-positive values are sent as None."""

    p: float = 0.75
    """Total probability mass of tokens to consider at each step."""

    best_of: int = 1
    """Generates best_of completions server-side and returns the "best"
    (the one with the highest log probability per token).
    Only sent for the ``/generate`` api.
    """

    api: Literal["/generate", "/v1/completions"] = "/v1/completions"
    """Api spec."""

    frequency_penalty: float = 0.0
    """Penalizes repeated tokens according to frequency. Between 0 and 1.
    Only sent for the ``/v1/completions`` api."""

    seed: Optional[int] = None
    """Random sampling seed. Only sent for the ``/v1/completions`` api."""

    repetition_penalty: Optional[float] = None
    """The parameter for repetition penalty. 1.0 means no penalty.
    Only sent for the ``/v1/completions`` api."""

    suffix: Optional[str] = None
    """Text that comes after the completion (OpenAI completions ``suffix``).
    Only sent for the ``/v1/completions`` api."""

    do_sample: bool = True
    """If set to True, this parameter enables decoding strategies such as
    multi-nominal sampling, beam-search multi-nominal sampling, Top-K
    sampling and Top-p sampling. Only sent for the ``/generate`` api.
    """

    watermark: bool = True
    """Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
    Defaults to True. Only sent for the ``/generate`` api."""

    return_full_text: bool = False
    """Whether to prepend the prompt to the generated text. Defaults to False.
    Only sent for the ``/generate`` api."""

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "oci_model_deployment_tgi_endpoint"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for invoking OCI model deployment TGI endpoint.

        The key set depends on ``self.api``. ``/v1/completions`` uses
        OpenAI-style names (``max_tokens``, ``stream``, ...). ``/generate``
        uses native names (``max_new_tokens``, ``top_k``, ``do_sample``, ...).
        """
        return (
            {
                "model": self.model,  # can be any
                "frequency_penalty": self.frequency_penalty,
                "max_tokens": self.max_tokens,
                "repetition_penalty": self.repetition_penalty,
                "temperature": self.temperature,
                "top_p": self.p,
                "seed": self.seed,
                "stream": self.streaming,
                "suffix": self.suffix,
                "stop": self.stop,
            }
            if self.api == "/v1/completions"
            else {
                "best_of": self.best_of,
                "max_new_tokens": self.max_tokens,
                "temperature": self.temperature,
                "top_k": (
                    self.k if self.k > 0 else None
                ),  # `top_k` must be strictly positive'
                "top_p": self.p,
                "do_sample": self.do_sample,
                "return_full_text": self.return_full_text,
                "watermark": self.watermark,
                "stop": self.stop,
            }
        )

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters: endpoint, api, model kwargs and defaults."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{
                "endpoint": self.endpoint,
                "api": self.api,
                "model_kwargs": _model_kwargs,
            },
            **self._default_params,
        }

    def _construct_json_body(self, prompt: str, params: dict) -> dict:
        """Construct request payload.

        ``/v1/completions`` delegates to the base class. ``/generate`` wraps
        the prompt and parameters as ``{"inputs": ..., "parameters": ...}``.
        """
        if self.api == "/v1/completions":
            return super()._construct_json_body(prompt, params)

        return {
            "inputs": prompt,
            "parameters": params,
        }

    def _process_response(self, response_json: dict) -> List[Generation]:
        """Formats response.

        ``/v1/completions`` delegates to the base class's OpenAI-style parser.
        For ``/generate`` the payload must be a JSON object with a
        ``"generated_text"`` field, which becomes a single ``Generation``.

        Raises:
            ValueError: If the ``/generate`` payload is not subscriptable by
                key (e.g. a list or None) or lacks ``"generated_text"``.
        """
        if self.api == "/v1/completions":
            return super()._process_response(response_json)

        try:
            text = response_json["generated_text"]
        except (KeyError, TypeError) as e:
            # TypeError covers non-dict payloads, so every malformed response
            # surfaces as ValueError, matching the base class parser.
            raise ValueError(
                f"Error while formatting response payload. response_json={response_json}"
            ) from e

        return [Generation(text=text)]
|
887
|
+
|
888
|
+
|
889
|
+
class OCIModelDeploymentVLLM(OCIModelDeploymentLLM):
|
890
|
+
"""VLLM deployed on OCI Data Science Model Deployment
|
891
|
+
|
892
|
+
To use, you must provide the model HTTP endpoint from your deployed
|
893
|
+
model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict.
|
894
|
+
|
895
|
+
To authenticate, `oracle-ads` has been used to automatically load
|
896
|
+
credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
|
897
|
+
|
898
|
+
Make sure to have the required policies to access the OCI Data
|
899
|
+
Science Model Deployment endpoint. See:
|
900
|
+
https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
|
901
|
+
|
902
|
+
Example:
|
903
|
+
.. code-block:: python
|
904
|
+
|
905
|
+
from langchain_community.llms import OCIModelDeploymentVLLM
|
906
|
+
|
907
|
+
llm = OCIModelDeploymentVLLM(
|
908
|
+
endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict",
|
909
|
+
model="odsc-llm",
|
910
|
+
streaming=False,
|
911
|
+
temperature=0.2,
|
912
|
+
max_tokens=512,
|
913
|
+
n=3,
|
914
|
+
best_of=3,
|
915
|
+
# other model parameters
|
916
|
+
)
|
917
|
+
|
918
|
+
"""
|
919
|
+
|
920
|
+
max_tokens: int = 256
|
921
|
+
"""Denotes the number of tokens to predict per generation."""
|
922
|
+
|
923
|
+
temperature: float = 0.2
|
924
|
+
"""A non-negative float that tunes the degree of randomness in generation."""
|
925
|
+
|
926
|
+
p: float = 0.75
|
927
|
+
"""Total probability mass of tokens to consider at each step."""
|
928
|
+
|
929
|
+
best_of: int = 1
|
930
|
+
"""Generates best_of completions server-side and returns the "best"
|
931
|
+
(the one with the highest log probability per token).
|
932
|
+
"""
|
933
|
+
|
934
|
+
n: int = 1
|
935
|
+
"""Number of output sequences to return for the given prompt."""
|
936
|
+
|
937
|
+
k: int = -1
|
938
|
+
"""Number of most likely tokens to consider at each step."""
|
939
|
+
|
940
|
+
frequency_penalty: float = 0.0
|
941
|
+
"""Penalizes repeated tokens according to frequency. Between 0 and 1."""
|
942
|
+
|
943
|
+
presence_penalty: float = 0.0
|
944
|
+
"""Penalizes repeated tokens. Between 0 and 1."""
|
945
|
+
|
946
|
+
use_beam_search: bool = False
|
947
|
+
"""Whether to use beam search instead of sampling."""
|
948
|
+
|
949
|
+
ignore_eos: bool = False
|
950
|
+
"""Whether to ignore the EOS token and continue generating tokens after
|
951
|
+
the EOS token is generated."""
|
952
|
+
|
953
|
+
logprobs: Optional[int] = None
|
954
|
+
"""Number of log probabilities to return per output token."""
|
955
|
+
|
956
|
+
@property
def _llm_type(self) -> str:
    """Identifier reported for the vLLM model-deployment wrapper."""
    llm_type_name = "oci_model_deployment_vllm_endpoint"
    return llm_type_name
|
960
|
+
|
961
|
+
@property
def _default_params(self) -> Dict[str, Any]:
    """Default vLLM request parameters built from the instance fields.

    Field-to-key renames: ``streaming`` -> ``"stream"``, ``k`` -> ``"top_k"``,
    ``p`` -> ``"top_p"``. Every other key matches its field name.
    """
    return dict(
        best_of=self.best_of,
        frequency_penalty=self.frequency_penalty,
        ignore_eos=self.ignore_eos,
        logprobs=self.logprobs,
        max_tokens=self.max_tokens,
        model=self.model,
        n=self.n,
        presence_penalty=self.presence_penalty,
        stop=self.stop,
        stream=self.streaming,
        temperature=self.temperature,
        top_k=self.k,
        top_p=self.p,
        use_beam_search=self.use_beam_search,
    )
|