oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +40 -0
- ads/aqua/app.py +507 -0
- ads/aqua/cli.py +96 -0
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +836 -0
- ads/aqua/client/openai_client.py +305 -0
- ads/aqua/common/__init__.py +5 -0
- ads/aqua/common/decorator.py +125 -0
- ads/aqua/common/entities.py +274 -0
- ads/aqua/common/enums.py +134 -0
- ads/aqua/common/errors.py +109 -0
- ads/aqua/common/utils.py +1295 -0
- ads/aqua/config/__init__.py +4 -0
- ads/aqua/config/container_config.py +246 -0
- ads/aqua/config/evaluation/__init__.py +4 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
- ads/aqua/config/utils/__init__.py +4 -0
- ads/aqua/config/utils/serializer.py +339 -0
- ads/aqua/constants.py +116 -0
- ads/aqua/data.py +14 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation/__init__.py +8 -0
- ads/aqua/evaluation/constants.py +53 -0
- ads/aqua/evaluation/entities.py +186 -0
- ads/aqua/evaluation/errors.py +70 -0
- ads/aqua/evaluation/evaluation.py +1814 -0
- ads/aqua/extension/__init__.py +42 -0
- ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
- ads/aqua/extension/base_handler.py +90 -0
- ads/aqua/extension/common_handler.py +121 -0
- ads/aqua/extension/common_ws_msg_handler.py +36 -0
- ads/aqua/extension/deployment_handler.py +381 -0
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +30 -0
- ads/aqua/extension/evaluation_handler.py +129 -0
- ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
- ads/aqua/extension/finetune_handler.py +96 -0
- ads/aqua/extension/model_handler.py +390 -0
- ads/aqua/extension/models/__init__.py +0 -0
- ads/aqua/extension/models/ws_models.py +145 -0
- ads/aqua/extension/models_ws_msg_handler.py +50 -0
- ads/aqua/extension/ui_handler.py +300 -0
- ads/aqua/extension/ui_websocket_handler.py +130 -0
- ads/aqua/extension/utils.py +133 -0
- ads/aqua/finetuning/__init__.py +7 -0
- ads/aqua/finetuning/constants.py +23 -0
- ads/aqua/finetuning/entities.py +181 -0
- ads/aqua/finetuning/finetuning.py +749 -0
- ads/aqua/model/__init__.py +8 -0
- ads/aqua/model/constants.py +60 -0
- ads/aqua/model/entities.py +385 -0
- ads/aqua/model/enums.py +32 -0
- ads/aqua/model/model.py +2134 -0
- ads/aqua/model/utils.py +52 -0
- ads/aqua/modeldeployment/__init__.py +6 -0
- ads/aqua/modeldeployment/constants.py +10 -0
- ads/aqua/modeldeployment/deployment.py +1315 -0
- ads/aqua/modeldeployment/entities.py +653 -0
- ads/aqua/modeldeployment/utils.py +543 -0
- ads/aqua/resources/gpu_shapes_index.json +94 -0
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +476 -0
- ads/aqua/ui.py +519 -0
- ads/automl/__init__.py +9 -0
- ads/automl/driver.py +330 -0
- ads/automl/provider.py +975 -0
- ads/bds/__init__.py +5 -0
- ads/bds/auth.py +127 -0
- ads/bds/big_data_service.py +255 -0
- ads/catalog/__init__.py +19 -0
- ads/catalog/model.py +1576 -0
- ads/catalog/notebook.py +461 -0
- ads/catalog/project.py +468 -0
- ads/catalog/summary.py +178 -0
- ads/common/__init__.py +11 -0
- ads/common/analyzer.py +65 -0
- ads/common/artifact/.model-ignore +63 -0
- ads/common/artifact/__init__.py +10 -0
- ads/common/auth.py +1122 -0
- ads/common/card_identifier.py +83 -0
- ads/common/config.py +647 -0
- ads/common/data.py +165 -0
- ads/common/decorator/__init__.py +9 -0
- ads/common/decorator/argument_to_case.py +88 -0
- ads/common/decorator/deprecate.py +69 -0
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/decorator/runtime_dependency.py +178 -0
- ads/common/decorator/threaded.py +97 -0
- ads/common/decorator/utils.py +35 -0
- ads/common/dsc_file_system.py +303 -0
- ads/common/error.py +14 -0
- ads/common/extended_enum.py +81 -0
- ads/common/function/__init__.py +5 -0
- ads/common/function/fn_util.py +142 -0
- ads/common/function/func_conf.yaml +25 -0
- ads/common/ipython.py +76 -0
- ads/common/model.py +679 -0
- ads/common/model_artifact.py +1759 -0
- ads/common/model_artifact_schema.json +107 -0
- ads/common/model_export_util.py +664 -0
- ads/common/model_metadata.py +24 -0
- ads/common/object_storage_details.py +296 -0
- ads/common/oci_client.py +179 -0
- ads/common/oci_datascience.py +46 -0
- ads/common/oci_logging.py +1144 -0
- ads/common/oci_mixin.py +957 -0
- ads/common/oci_resource.py +136 -0
- ads/common/serializer.py +559 -0
- ads/common/utils.py +1852 -0
- ads/common/word_lists.py +1491 -0
- ads/common/work_request.py +189 -0
- ads/config.py +1 -0
- ads/data_labeling/__init__.py +13 -0
- ads/data_labeling/boundingbox.py +253 -0
- ads/data_labeling/constants.py +47 -0
- ads/data_labeling/data_labeling_service.py +244 -0
- ads/data_labeling/interface/__init__.py +5 -0
- ads/data_labeling/interface/loader.py +16 -0
- ads/data_labeling/interface/parser.py +16 -0
- ads/data_labeling/interface/reader.py +23 -0
- ads/data_labeling/loader/__init__.py +5 -0
- ads/data_labeling/loader/file_loader.py +241 -0
- ads/data_labeling/metadata.py +110 -0
- ads/data_labeling/mixin/__init__.py +5 -0
- ads/data_labeling/mixin/data_labeling.py +232 -0
- ads/data_labeling/ner.py +129 -0
- ads/data_labeling/parser/__init__.py +5 -0
- ads/data_labeling/parser/dls_record_parser.py +388 -0
- ads/data_labeling/parser/export_metadata_parser.py +94 -0
- ads/data_labeling/parser/export_record_parser.py +473 -0
- ads/data_labeling/reader/__init__.py +5 -0
- ads/data_labeling/reader/dataset_reader.py +574 -0
- ads/data_labeling/reader/dls_record_reader.py +121 -0
- ads/data_labeling/reader/export_record_reader.py +62 -0
- ads/data_labeling/reader/jsonl_reader.py +75 -0
- ads/data_labeling/reader/metadata_reader.py +203 -0
- ads/data_labeling/reader/record_reader.py +263 -0
- ads/data_labeling/record.py +52 -0
- ads/data_labeling/visualizer/__init__.py +5 -0
- ads/data_labeling/visualizer/image_visualizer.py +525 -0
- ads/data_labeling/visualizer/text_visualizer.py +357 -0
- ads/database/__init__.py +5 -0
- ads/database/connection.py +338 -0
- ads/dataset/__init__.py +10 -0
- ads/dataset/capabilities.md +51 -0
- ads/dataset/classification_dataset.py +339 -0
- ads/dataset/correlation.py +226 -0
- ads/dataset/correlation_plot.py +563 -0
- ads/dataset/dask_series.py +173 -0
- ads/dataset/dataframe_transformer.py +110 -0
- ads/dataset/dataset.py +1979 -0
- ads/dataset/dataset_browser.py +360 -0
- ads/dataset/dataset_with_target.py +995 -0
- ads/dataset/exception.py +25 -0
- ads/dataset/factory.py +987 -0
- ads/dataset/feature_engineering_transformer.py +35 -0
- ads/dataset/feature_selection.py +107 -0
- ads/dataset/forecasting_dataset.py +26 -0
- ads/dataset/helper.py +1450 -0
- ads/dataset/label_encoder.py +99 -0
- ads/dataset/mixin/__init__.py +5 -0
- ads/dataset/mixin/dataset_accessor.py +134 -0
- ads/dataset/pipeline.py +58 -0
- ads/dataset/plot.py +710 -0
- ads/dataset/progress.py +86 -0
- ads/dataset/recommendation.py +297 -0
- ads/dataset/recommendation_transformer.py +502 -0
- ads/dataset/regression_dataset.py +14 -0
- ads/dataset/sampled_dataset.py +1050 -0
- ads/dataset/target.py +98 -0
- ads/dataset/timeseries.py +18 -0
- ads/dbmixin/__init__.py +5 -0
- ads/dbmixin/db_pandas_accessor.py +153 -0
- ads/environment/__init__.py +9 -0
- ads/environment/ml_runtime.py +66 -0
- ads/evaluations/README.md +14 -0
- ads/evaluations/__init__.py +109 -0
- ads/evaluations/evaluation_plot.py +983 -0
- ads/evaluations/evaluator.py +1334 -0
- ads/evaluations/statistical_metrics.py +543 -0
- ads/experiments/__init__.py +9 -0
- ads/experiments/capabilities.md +0 -0
- ads/explanations/__init__.py +21 -0
- ads/explanations/base_explainer.py +142 -0
- ads/explanations/capabilities.md +83 -0
- ads/explanations/explainer.py +190 -0
- ads/explanations/mlx_global_explainer.py +1050 -0
- ads/explanations/mlx_interface.py +386 -0
- ads/explanations/mlx_local_explainer.py +287 -0
- ads/explanations/mlx_whatif_explainer.py +201 -0
- ads/feature_engineering/__init__.py +20 -0
- ads/feature_engineering/accessor/__init__.py +5 -0
- ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
- ads/feature_engineering/accessor/mixin/__init__.py +5 -0
- ads/feature_engineering/accessor/mixin/correlation.py +166 -0
- ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
- ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
- ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
- ads/feature_engineering/accessor/mixin/utils.py +65 -0
- ads/feature_engineering/accessor/series_accessor.py +431 -0
- ads/feature_engineering/adsimage/__init__.py +5 -0
- ads/feature_engineering/adsimage/image.py +192 -0
- ads/feature_engineering/adsimage/image_reader.py +170 -0
- ads/feature_engineering/adsimage/interface/__init__.py +5 -0
- ads/feature_engineering/adsimage/interface/reader.py +19 -0
- ads/feature_engineering/adsstring/__init__.py +7 -0
- ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
- ads/feature_engineering/adsstring/string/__init__.py +8 -0
- ads/feature_engineering/data_schema.json +57 -0
- ads/feature_engineering/dataset/__init__.py +5 -0
- ads/feature_engineering/dataset/zip_code_data.py +42062 -0
- ads/feature_engineering/exceptions.py +40 -0
- ads/feature_engineering/feature_type/__init__.py +133 -0
- ads/feature_engineering/feature_type/address.py +184 -0
- ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
- ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
- ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
- ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
- ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
- ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
- ads/feature_engineering/feature_type/adsstring/string.py +258 -0
- ads/feature_engineering/feature_type/base.py +58 -0
- ads/feature_engineering/feature_type/boolean.py +183 -0
- ads/feature_engineering/feature_type/category.py +146 -0
- ads/feature_engineering/feature_type/constant.py +137 -0
- ads/feature_engineering/feature_type/continuous.py +151 -0
- ads/feature_engineering/feature_type/creditcard.py +314 -0
- ads/feature_engineering/feature_type/datetime.py +190 -0
- ads/feature_engineering/feature_type/discrete.py +134 -0
- ads/feature_engineering/feature_type/document.py +43 -0
- ads/feature_engineering/feature_type/gis.py +251 -0
- ads/feature_engineering/feature_type/handler/__init__.py +5 -0
- ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
- ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
- ads/feature_engineering/feature_type/handler/warnings.py +128 -0
- ads/feature_engineering/feature_type/integer.py +142 -0
- ads/feature_engineering/feature_type/ip_address.py +144 -0
- ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
- ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
- ads/feature_engineering/feature_type/lat_long.py +256 -0
- ads/feature_engineering/feature_type/object.py +43 -0
- ads/feature_engineering/feature_type/ordinal.py +132 -0
- ads/feature_engineering/feature_type/phone_number.py +135 -0
- ads/feature_engineering/feature_type/string.py +171 -0
- ads/feature_engineering/feature_type/text.py +93 -0
- ads/feature_engineering/feature_type/unknown.py +43 -0
- ads/feature_engineering/feature_type/zip_code.py +164 -0
- ads/feature_engineering/feature_type_manager.py +406 -0
- ads/feature_engineering/schema.py +795 -0
- ads/feature_engineering/utils.py +245 -0
- ads/feature_store/.readthedocs.yaml +19 -0
- ads/feature_store/README.md +65 -0
- ads/feature_store/__init__.py +9 -0
- ads/feature_store/common/__init__.py +0 -0
- ads/feature_store/common/enums.py +339 -0
- ads/feature_store/common/exceptions.py +18 -0
- ads/feature_store/common/spark_session_singleton.py +125 -0
- ads/feature_store/common/utils/__init__.py +0 -0
- ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
- ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
- ads/feature_store/common/utils/transformation_utils.py +82 -0
- ads/feature_store/common/utils/utility.py +403 -0
- ads/feature_store/data_validation/__init__.py +0 -0
- ads/feature_store/data_validation/great_expectation.py +129 -0
- ads/feature_store/dataset.py +1230 -0
- ads/feature_store/dataset_job.py +530 -0
- ads/feature_store/docs/Dockerfile +7 -0
- ads/feature_store/docs/Makefile +44 -0
- ads/feature_store/docs/conf.py +28 -0
- ads/feature_store/docs/requirements.txt +14 -0
- ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
- ads/feature_store/docs/source/cicd.rst +137 -0
- ads/feature_store/docs/source/conf.py +86 -0
- ads/feature_store/docs/source/data_versioning.rst +33 -0
- ads/feature_store/docs/source/dataset.rst +388 -0
- ads/feature_store/docs/source/dataset_job.rst +27 -0
- ads/feature_store/docs/source/demo.rst +70 -0
- ads/feature_store/docs/source/entity.rst +78 -0
- ads/feature_store/docs/source/feature_group.rst +624 -0
- ads/feature_store/docs/source/feature_group_job.rst +29 -0
- ads/feature_store/docs/source/feature_store.rst +122 -0
- ads/feature_store/docs/source/feature_store_class.rst +123 -0
- ads/feature_store/docs/source/feature_validation.rst +66 -0
- ads/feature_store/docs/source/figures/cicd.png +0 -0
- ads/feature_store/docs/source/figures/data_validation.png +0 -0
- ads/feature_store/docs/source/figures/data_versioning.png +0 -0
- ads/feature_store/docs/source/figures/dataset.gif +0 -0
- ads/feature_store/docs/source/figures/dataset.png +0 -0
- ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
- ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
- ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
- ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
- ads/feature_store/docs/source/figures/entity.png +0 -0
- ads/feature_store/docs/source/figures/feature_group.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
- ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
- ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
- ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
- ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
- ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
- ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
- ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
- ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
- ads/feature_store/docs/source/figures/overview.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
- ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
- ads/feature_store/docs/source/figures/stats_1.png +0 -0
- ads/feature_store/docs/source/figures/stats_2.png +0 -0
- ads/feature_store/docs/source/figures/stats_d.png +0 -0
- ads/feature_store/docs/source/figures/stats_fg.png +0 -0
- ads/feature_store/docs/source/figures/transformation.png +0 -0
- ads/feature_store/docs/source/figures/transformations.gif +0 -0
- ads/feature_store/docs/source/figures/validation.png +0 -0
- ads/feature_store/docs/source/figures/validation_fg.png +0 -0
- ads/feature_store/docs/source/figures/validation_results.png +0 -0
- ads/feature_store/docs/source/figures/validation_summary.png +0 -0
- ads/feature_store/docs/source/index.rst +81 -0
- ads/feature_store/docs/source/module.rst +8 -0
- ads/feature_store/docs/source/notebook.rst +94 -0
- ads/feature_store/docs/source/overview.rst +47 -0
- ads/feature_store/docs/source/quickstart.rst +176 -0
- ads/feature_store/docs/source/release_notes.rst +194 -0
- ads/feature_store/docs/source/setup_feature_store.rst +81 -0
- ads/feature_store/docs/source/statistics.rst +58 -0
- ads/feature_store/docs/source/transformation.rst +199 -0
- ads/feature_store/docs/source/ui.rst +65 -0
- ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
- ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
- ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
- ads/feature_store/entity.py +718 -0
- ads/feature_store/execution_strategy/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
- ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
- ads/feature_store/execution_strategy/engine/__init__.py +0 -0
- ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
- ads/feature_store/execution_strategy/execution_strategy.py +113 -0
- ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
- ads/feature_store/execution_strategy/spark/__init__.py +0 -0
- ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
- ads/feature_store/feature.py +192 -0
- ads/feature_store/feature_group.py +1494 -0
- ads/feature_store/feature_group_expectation.py +346 -0
- ads/feature_store/feature_group_job.py +602 -0
- ads/feature_store/feature_lineage/__init__.py +0 -0
- ads/feature_store/feature_lineage/graphviz_service.py +180 -0
- ads/feature_store/feature_option_details.py +50 -0
- ads/feature_store/feature_statistics/__init__.py +0 -0
- ads/feature_store/feature_statistics/statistics_service.py +99 -0
- ads/feature_store/feature_store.py +699 -0
- ads/feature_store/feature_store_registrar.py +518 -0
- ads/feature_store/input_feature_detail.py +149 -0
- ads/feature_store/mixin/__init__.py +4 -0
- ads/feature_store/mixin/oci_feature_store.py +145 -0
- ads/feature_store/model_details.py +73 -0
- ads/feature_store/query/__init__.py +0 -0
- ads/feature_store/query/filter.py +266 -0
- ads/feature_store/query/generator/__init__.py +0 -0
- ads/feature_store/query/generator/query_generator.py +298 -0
- ads/feature_store/query/join.py +161 -0
- ads/feature_store/query/query.py +403 -0
- ads/feature_store/query/validator/__init__.py +0 -0
- ads/feature_store/query/validator/query_validator.py +57 -0
- ads/feature_store/response/__init__.py +0 -0
- ads/feature_store/response/response_builder.py +68 -0
- ads/feature_store/service/__init__.py +0 -0
- ads/feature_store/service/oci_dataset.py +139 -0
- ads/feature_store/service/oci_dataset_job.py +199 -0
- ads/feature_store/service/oci_entity.py +125 -0
- ads/feature_store/service/oci_feature_group.py +164 -0
- ads/feature_store/service/oci_feature_group_job.py +214 -0
- ads/feature_store/service/oci_feature_store.py +182 -0
- ads/feature_store/service/oci_lineage.py +87 -0
- ads/feature_store/service/oci_transformation.py +104 -0
- ads/feature_store/statistics/__init__.py +0 -0
- ads/feature_store/statistics/abs_feature_value.py +49 -0
- ads/feature_store/statistics/charts/__init__.py +0 -0
- ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
- ads/feature_store/statistics/charts/box_plot.py +148 -0
- ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
- ads/feature_store/statistics/charts/probability_distribution.py +68 -0
- ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
- ads/feature_store/statistics/feature_stat.py +126 -0
- ads/feature_store/statistics/generic_feature_value.py +33 -0
- ads/feature_store/statistics/statistics.py +41 -0
- ads/feature_store/statistics_config.py +101 -0
- ads/feature_store/templates/feature_store_template.yaml +45 -0
- ads/feature_store/transformation.py +499 -0
- ads/feature_store/validation_output.py +57 -0
- ads/hpo/__init__.py +9 -0
- ads/hpo/_imports.py +91 -0
- ads/hpo/ads_search_space.py +439 -0
- ads/hpo/distributions.py +325 -0
- ads/hpo/objective.py +280 -0
- ads/hpo/search_cv.py +1657 -0
- ads/hpo/stopping_criterion.py +75 -0
- ads/hpo/tuner_artifact.py +413 -0
- ads/hpo/utils.py +91 -0
- ads/hpo/validation.py +140 -0
- ads/hpo/visualization/__init__.py +5 -0
- ads/hpo/visualization/_contour.py +23 -0
- ads/hpo/visualization/_edf.py +20 -0
- ads/hpo/visualization/_intermediate_values.py +21 -0
- ads/hpo/visualization/_optimization_history.py +25 -0
- ads/hpo/visualization/_parallel_coordinate.py +169 -0
- ads/hpo/visualization/_param_importances.py +26 -0
- ads/jobs/__init__.py +53 -0
- ads/jobs/ads_job.py +663 -0
- ads/jobs/builders/__init__.py +5 -0
- ads/jobs/builders/base.py +156 -0
- ads/jobs/builders/infrastructure/__init__.py +6 -0
- ads/jobs/builders/infrastructure/base.py +165 -0
- ads/jobs/builders/infrastructure/dataflow.py +1252 -0
- ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
- ads/jobs/builders/infrastructure/utils.py +65 -0
- ads/jobs/builders/runtimes/__init__.py +5 -0
- ads/jobs/builders/runtimes/artifact.py +338 -0
- ads/jobs/builders/runtimes/base.py +325 -0
- ads/jobs/builders/runtimes/container_runtime.py +242 -0
- ads/jobs/builders/runtimes/python_runtime.py +1016 -0
- ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
- ads/jobs/cli.py +104 -0
- ads/jobs/env_var_parser.py +131 -0
- ads/jobs/extension.py +160 -0
- ads/jobs/schema/__init__.py +5 -0
- ads/jobs/schema/infrastructure_schema.json +116 -0
- ads/jobs/schema/job_schema.json +42 -0
- ads/jobs/schema/runtime_schema.json +183 -0
- ads/jobs/schema/validator.py +141 -0
- ads/jobs/serializer.py +296 -0
- ads/jobs/templates/__init__.py +5 -0
- ads/jobs/templates/container.py +6 -0
- ads/jobs/templates/driver_notebook.py +177 -0
- ads/jobs/templates/driver_oci.py +500 -0
- ads/jobs/templates/driver_python.py +48 -0
- ads/jobs/templates/driver_pytorch.py +852 -0
- ads/jobs/templates/driver_utils.py +615 -0
- ads/jobs/templates/hostname_from_env.c +55 -0
- ads/jobs/templates/oci_metrics.py +181 -0
- ads/jobs/utils.py +104 -0
- ads/llm/__init__.py +28 -0
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/v02/client.py +295 -0
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/chain.py +268 -0
- ads/llm/chat_template.py +31 -0
- ads/llm/deploy.py +63 -0
- ads/llm/guardrails/__init__.py +5 -0
- ads/llm/guardrails/base.py +442 -0
- ads/llm/guardrails/huggingface.py +44 -0
- ads/llm/langchain/__init__.py +5 -0
- ads/llm/langchain/plugins/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/llm/langchain/plugins/llms/__init__.py +5 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
- ads/llm/requirements.txt +3 -0
- ads/llm/serialize.py +219 -0
- ads/llm/serializers/__init__.py +0 -0
- ads/llm/serializers/retrieval_qa.py +153 -0
- ads/llm/serializers/runnable_parallel.py +27 -0
- ads/llm/templates/score_chain.jinja2 +155 -0
- ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
- ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
- ads/model/__init__.py +52 -0
- ads/model/artifact.py +573 -0
- ads/model/artifact_downloader.py +254 -0
- ads/model/artifact_uploader.py +267 -0
- ads/model/base_properties.py +238 -0
- ads/model/common/.model-ignore +66 -0
- ads/model/common/__init__.py +5 -0
- ads/model/common/utils.py +142 -0
- ads/model/datascience_model.py +2635 -0
- ads/model/deployment/__init__.py +20 -0
- ads/model/deployment/common/__init__.py +5 -0
- ads/model/deployment/common/utils.py +308 -0
- ads/model/deployment/model_deployer.py +466 -0
- ads/model/deployment/model_deployment.py +1846 -0
- ads/model/deployment/model_deployment_infrastructure.py +671 -0
- ads/model/deployment/model_deployment_properties.py +493 -0
- ads/model/deployment/model_deployment_runtime.py +838 -0
- ads/model/extractor/__init__.py +5 -0
- ads/model/extractor/automl_extractor.py +74 -0
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/extractor/huggingface_extractor.py +88 -0
- ads/model/extractor/keras_extractor.py +84 -0
- ads/model/extractor/lightgbm_extractor.py +93 -0
- ads/model/extractor/model_info_extractor.py +114 -0
- ads/model/extractor/model_info_extractor_factory.py +105 -0
- ads/model/extractor/pytorch_extractor.py +87 -0
- ads/model/extractor/sklearn_extractor.py +112 -0
- ads/model/extractor/spark_extractor.py +89 -0
- ads/model/extractor/tensorflow_extractor.py +85 -0
- ads/model/extractor/xgboost_extractor.py +94 -0
- ads/model/framework/__init__.py +5 -0
- ads/model/framework/automl_model.py +178 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/framework/huggingface_model.py +399 -0
- ads/model/framework/lightgbm_model.py +266 -0
- ads/model/framework/pytorch_model.py +266 -0
- ads/model/framework/sklearn_model.py +250 -0
- ads/model/framework/spark_model.py +326 -0
- ads/model/framework/tensorflow_model.py +254 -0
- ads/model/framework/xgboost_model.py +258 -0
- ads/model/generic_model.py +3518 -0
- ads/model/model_artifact_boilerplate/README.md +381 -0
- ads/model/model_artifact_boilerplate/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
- ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
- ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
- ads/model/model_artifact_boilerplate/score.py +61 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_introspect.py +331 -0
- ads/model/model_metadata.py +1810 -0
- ads/model/model_metadata_mixin.py +460 -0
- ads/model/model_properties.py +63 -0
- ads/model/model_version_set.py +739 -0
- ads/model/runtime/__init__.py +5 -0
- ads/model/runtime/env_info.py +306 -0
- ads/model/runtime/model_deployment_details.py +37 -0
- ads/model/runtime/model_provenance_details.py +58 -0
- ads/model/runtime/runtime_info.py +81 -0
- ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
- ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
- ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
- ads/model/runtime/utils.py +201 -0
- ads/model/serde/__init__.py +5 -0
- ads/model/serde/common.py +40 -0
- ads/model/serde/model_input.py +547 -0
- ads/model/serde/model_serializer.py +1184 -0
- ads/model/service/__init__.py +5 -0
- ads/model/service/oci_datascience_model.py +1076 -0
- ads/model/service/oci_datascience_model_deployment.py +500 -0
- ads/model/service/oci_datascience_model_version_set.py +176 -0
- ads/model/transformer/__init__.py +5 -0
- ads/model/transformer/onnx_transformer.py +324 -0
- ads/mysqldb/__init__.py +5 -0
- ads/mysqldb/mysql_db.py +227 -0
- ads/opctl/__init__.py +18 -0
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/__init__.py +5 -0
- ads/opctl/backend/ads_dataflow.py +353 -0
- ads/opctl/backend/ads_ml_job.py +710 -0
- ads/opctl/backend/ads_ml_pipeline.py +164 -0
- ads/opctl/backend/ads_model_deployment.py +209 -0
- ads/opctl/backend/base.py +146 -0
- ads/opctl/backend/local.py +1053 -0
- ads/opctl/backend/marketplace/__init__.py +9 -0
- ads/opctl/backend/marketplace/helm_helper.py +173 -0
- ads/opctl/backend/marketplace/local_marketplace.py +271 -0
- ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
- ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
- ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
- ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
- ads/opctl/backend/marketplace/models/__init__.py +5 -0
- ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
- ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
- ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
- ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
- ads/opctl/cli.py +707 -0
- ads/opctl/cmds.py +869 -0
- ads/opctl/conda/__init__.py +5 -0
- ads/opctl/conda/cli.py +193 -0
- ads/opctl/conda/cmds.py +749 -0
- ads/opctl/conda/config.yaml +34 -0
- ads/opctl/conda/manifest_template.yaml +13 -0
- ads/opctl/conda/multipart_uploader.py +188 -0
- ads/opctl/conda/pack.py +89 -0
- ads/opctl/config/__init__.py +5 -0
- ads/opctl/config/base.py +57 -0
- ads/opctl/config/diagnostics/__init__.py +5 -0
- ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
- ads/opctl/config/merger.py +255 -0
- ads/opctl/config/resolver.py +297 -0
- ads/opctl/config/utils.py +79 -0
- ads/opctl/config/validator.py +17 -0
- ads/opctl/config/versioner.py +68 -0
- ads/opctl/config/yaml_parsers/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/base.py +58 -0
- ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
- ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
- ads/opctl/constants.py +66 -0
- ads/opctl/decorator/__init__.py +5 -0
- ads/opctl/decorator/common.py +129 -0
- ads/opctl/diagnostics/__init__.py +5 -0
- ads/opctl/diagnostics/__main__.py +25 -0
- ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
- ads/opctl/diagnostics/check_requirements.py +144 -0
- ads/opctl/diagnostics/requirement_exception.py +9 -0
- ads/opctl/distributed/README.md +109 -0
- ads/opctl/distributed/__init__.py +5 -0
- ads/opctl/distributed/certificates.py +32 -0
- ads/opctl/distributed/cli.py +207 -0
- ads/opctl/distributed/cmds.py +731 -0
- ads/opctl/distributed/common/__init__.py +5 -0
- ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
- ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
- ads/opctl/distributed/common/cluster_config_helper.py +103 -0
- ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
- ads/opctl/distributed/common/cluster_runner.py +54 -0
- ads/opctl/distributed/common/framework_factory.py +29 -0
- ads/opctl/docker/Dockerfile.job +103 -0
- ads/opctl/docker/Dockerfile.job.arm +107 -0
- ads/opctl/docker/Dockerfile.job.gpu +175 -0
- ads/opctl/docker/base-env.yaml +13 -0
- ads/opctl/docker/cuda.repo +6 -0
- ads/opctl/docker/operator/.dockerignore +0 -0
- ads/opctl/docker/operator/Dockerfile +41 -0
- ads/opctl/docker/operator/Dockerfile.gpu +85 -0
- ads/opctl/docker/operator/cuda.repo +6 -0
- ads/opctl/docker/operator/environment.yaml +8 -0
- ads/opctl/forecast.py +11 -0
- ads/opctl/index.yaml +3 -0
- ads/opctl/model/__init__.py +5 -0
- ads/opctl/model/cli.py +65 -0
- ads/opctl/model/cmds.py +73 -0
- ads/opctl/operator/README.md +4 -0
- ads/opctl/operator/__init__.py +31 -0
- ads/opctl/operator/cli.py +344 -0
- ads/opctl/operator/cmd.py +596 -0
- ads/opctl/operator/common/__init__.py +5 -0
- ads/opctl/operator/common/backend_factory.py +460 -0
- ads/opctl/operator/common/const.py +27 -0
- ads/opctl/operator/common/data/synthetic.csv +16001 -0
- ads/opctl/operator/common/dictionary_merger.py +148 -0
- ads/opctl/operator/common/errors.py +42 -0
- ads/opctl/operator/common/operator_config.py +99 -0
- ads/opctl/operator/common/operator_loader.py +811 -0
- ads/opctl/operator/common/operator_schema.yaml +130 -0
- ads/opctl/operator/common/operator_yaml_generator.py +152 -0
- ads/opctl/operator/common/utils.py +208 -0
- ads/opctl/operator/lowcode/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
- ads/opctl/operator/lowcode/anomaly/README.md +207 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +167 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
- ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +116 -0
- ads/opctl/operator/lowcode/common/errors.py +47 -0
- ads/opctl/operator/lowcode/common/transformations.py +296 -0
- ads/opctl/operator/lowcode/common/utils.py +384 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
- ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
- ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
- ads/opctl/operator/lowcode/forecast/README.md +209 -0
- ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
- ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
- ads/opctl/operator/lowcode/forecast/const.py +92 -0
- ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
- ads/opctl/operator/lowcode/forecast/errors.py +26 -0
- ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
- ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
- ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
- ads/opctl/operator/lowcode/forecast/model/prophet.py +450 -0
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
- ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
- ads/opctl/operator/lowcode/forecast/utils.py +397 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
- ads/opctl/operator/lowcode/pii/MLoperator +17 -0
- ads/opctl/operator/lowcode/pii/README.md +208 -0
- ads/opctl/operator/lowcode/pii/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/__main__.py +78 -0
- ads/opctl/operator/lowcode/pii/cmd.py +39 -0
- ads/opctl/operator/lowcode/pii/constant.py +84 -0
- ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
- ads/opctl/operator/lowcode/pii/errors.py +27 -0
- ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
- ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
- ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
- ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
- ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
- ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
- ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
- ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
- ads/opctl/operator/lowcode/pii/model/report.py +487 -0
- ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
- ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
- ads/opctl/operator/lowcode/pii/utils.py +43 -0
- ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
- ads/opctl/operator/lowcode/recommender/README.md +206 -0
- ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
- ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
- ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
- ads/opctl/operator/lowcode/recommender/constant.py +30 -0
- ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
- ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
- ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
- ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
- ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
- ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
- ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
- ads/opctl/operator/lowcode/recommender/utils.py +13 -0
- ads/opctl/operator/runtime/__init__.py +5 -0
- ads/opctl/operator/runtime/const.py +17 -0
- ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
- ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
- ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
- ads/opctl/operator/runtime/runtime.py +115 -0
- ads/opctl/schema.yaml.yml +36 -0
- ads/opctl/script.py +40 -0
- ads/opctl/spark/__init__.py +5 -0
- ads/opctl/spark/cli.py +43 -0
- ads/opctl/spark/cmds.py +147 -0
- ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
- ads/opctl/utils.py +344 -0
- ads/oracledb/__init__.py +5 -0
- ads/oracledb/oracle_db.py +346 -0
- ads/pipeline/__init__.py +39 -0
- ads/pipeline/ads_pipeline.py +2279 -0
- ads/pipeline/ads_pipeline_run.py +772 -0
- ads/pipeline/ads_pipeline_step.py +605 -0
- ads/pipeline/builders/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/__init__.py +5 -0
- ads/pipeline/builders/infrastructure/custom_script.py +32 -0
- ads/pipeline/cli.py +119 -0
- ads/pipeline/extension.py +291 -0
- ads/pipeline/schema/__init__.py +5 -0
- ads/pipeline/schema/cs_step_schema.json +35 -0
- ads/pipeline/schema/ml_step_schema.json +31 -0
- ads/pipeline/schema/pipeline_schema.json +71 -0
- ads/pipeline/visualizer/__init__.py +5 -0
- ads/pipeline/visualizer/base.py +570 -0
- ads/pipeline/visualizer/graph_renderer.py +272 -0
- ads/pipeline/visualizer/text_renderer.py +84 -0
- ads/secrets/__init__.py +11 -0
- ads/secrets/adb.py +386 -0
- ads/secrets/auth_token.py +86 -0
- ads/secrets/big_data_service.py +365 -0
- ads/secrets/mysqldb.py +149 -0
- ads/secrets/oracledb.py +160 -0
- ads/secrets/secrets.py +407 -0
- ads/telemetry/__init__.py +7 -0
- ads/telemetry/base.py +69 -0
- ads/telemetry/client.py +122 -0
- ads/telemetry/telemetry.py +257 -0
- ads/templates/dataflow_pyspark.jinja2 +13 -0
- ads/templates/dataflow_sparksql.jinja2 +22 -0
- ads/templates/func.jinja2 +20 -0
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score-pkl.jinja2 +173 -0
- ads/templates/score.jinja2 +322 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- ads/templates/score_generic.jinja2 +165 -0
- ads/templates/score_huggingface_pipeline.jinja2 +217 -0
- ads/templates/score_lightgbm.jinja2 +185 -0
- ads/templates/score_onnx.jinja2 +407 -0
- ads/templates/score_onnx_new.jinja2 +473 -0
- ads/templates/score_oracle_automl.jinja2 +185 -0
- ads/templates/score_pyspark.jinja2 +154 -0
- ads/templates/score_pytorch.jinja2 +219 -0
- ads/templates/score_scikit-learn.jinja2 +184 -0
- ads/templates/score_tensorflow.jinja2 +184 -0
- ads/templates/score_xgboost.jinja2 +178 -0
- ads/text_dataset/__init__.py +5 -0
- ads/text_dataset/backends.py +211 -0
- ads/text_dataset/dataset.py +445 -0
- ads/text_dataset/extractor.py +207 -0
- ads/text_dataset/options.py +53 -0
- ads/text_dataset/udfs.py +22 -0
- ads/text_dataset/utils.py +49 -0
- ads/type_discovery/__init__.py +9 -0
- ads/type_discovery/abstract_detector.py +21 -0
- ads/type_discovery/constant_detector.py +41 -0
- ads/type_discovery/continuous_detector.py +54 -0
- ads/type_discovery/credit_card_detector.py +99 -0
- ads/type_discovery/datetime_detector.py +92 -0
- ads/type_discovery/discrete_detector.py +118 -0
- ads/type_discovery/document_detector.py +146 -0
- ads/type_discovery/ip_detector.py +68 -0
- ads/type_discovery/latlon_detector.py +90 -0
- ads/type_discovery/phone_number_detector.py +63 -0
- ads/type_discovery/type_discovery_driver.py +87 -0
- ads/type_discovery/typed_feature.py +594 -0
- ads/type_discovery/unknown_detector.py +41 -0
- ads/type_discovery/zipcode_detector.py +48 -0
- ads/vault/__init__.py +7 -0
- ads/vault/vault.py +237 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10.dist-info}/METADATA +150 -149
- oracle_ads-2.13.10.dist-info/RECORD +858 -0
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10.dist-info}/WHEEL +1 -2
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10.dist-info}/entry_points.txt +2 -1
- oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
- oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
- {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,852 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8; -*-
|
3
|
+
|
4
|
+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
|
5
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
"""This module requires oracle-ads>=2.6.8
|
7
|
+
"""
|
8
|
+
import getpass
|
9
|
+
import ipaddress
|
10
|
+
import logging
|
11
|
+
import multiprocessing
|
12
|
+
import os
|
13
|
+
import time
|
14
|
+
import shlex
|
15
|
+
import socket
|
16
|
+
import sys
|
17
|
+
import traceback
|
18
|
+
|
19
|
+
import oci
|
20
|
+
import psutil
|
21
|
+
import torch
|
22
|
+
from ads import set_auth
|
23
|
+
from ads.jobs import DataScienceJobRun
|
24
|
+
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
|
25
|
+
PythonRuntimeHandler,
|
26
|
+
)
|
27
|
+
from ads.opctl.distributed.common import cluster_config_helper
|
28
|
+
|
29
|
+
try:
|
30
|
+
# This is used by ADS and testing
|
31
|
+
from . import driver_utils
|
32
|
+
from .driver_oci import GitSSHKey, GitManager
|
33
|
+
from .oci_metrics import collect_metrics, METRIC_NAMESPACE
|
34
|
+
except ImportError:
|
35
|
+
# This is used when the script is in a job run.
|
36
|
+
import driver_utils
|
37
|
+
from driver_oci import GitSSHKey, GitManager
|
38
|
+
from oci_metrics import collect_metrics, METRIC_NAMESPACE
|
39
|
+
|
40
|
+
# Module logger, with its level configured by the shared driver utilities.
logger = driver_utils.set_log_level(logging.getLogger(__name__))
|
42
|
+
|
43
|
+
|
44
|
+
# Envs provisioned by the service
# The runner treats the presence of MAIN_JOB_RUN_OCID as "this is a worker node";
# its value is the OCID of the host job run.
CONST_ENV_HOST_JOB_RUN_OCID = "MAIN_JOB_RUN_OCID"
# OCID of the current job run; used as the host OCID when running on the host.
CONST_ENV_JOB_RUN_OCID = "JOB_RUN_OCID"
# Envs set by the ADS API
# Number of worker nodes; without NODE_COUNT the total is OCI__WORKER_COUNT + 1.
OCI__WORKER_COUNT = "OCI__WORKER_COUNT"
CONST_ENV_NODE_RANK = "NODE_RANK"
# Total number of nodes; takes precedence over OCI__WORKER_COUNT when set.
CONST_ENV_NODE_COUNT = "NODE_COUNT"
# Launch command specified by the user (e.g. "torchrun ..." or "deepspeed ...").
CONST_ENV_LAUNCH_CMD = "OCI__LAUNCH_CMD"
CONST_ENV_DEEPSPEED = "OCI__DEEPSPEED"
# Envs set by this module
CONST_ENV_WORLD_SIZE = "WORLD_SIZE"
CONST_ENV_LD_PRELOAD = "LD_PRELOAD"
# Envs for debugging only
# OCI_ODSC_SERVICE_ENDPOINT is used for all processes in the job run
CONST_ENV_ODSC_SERVICE_ENDPOINT = "OCI_ODSC_SERVICE_ENDPOINT"
# OCI_DS_SERVICE_ENDPOINT is used only by the training process
CONST_ENV_DS_SERVICE_ENDPOINT = "OCI_DS_SERVICE_ENDPOINT"

# Constants used in logs
# A job run prints one of these prefixes followed by a value; other job runs
# poll the job run logs for the prefix to discover that value.
LOG_PREFIX_HOST_IP = "Distributed Training HOST IP: "
LOG_PREFIX_NODE_IP = "Node IP: "
LOG_PREFIX_PUBLIC_KEY = "HOST PUBLIC KEY: "

# Other constants used within this script
# getpass.getuser() is only called when HOME is not set. Passing it as the
# default of os.environ.get() would evaluate it eagerly on every import, and
# getuser() can raise when the current user cannot be resolved.
if "HOME" in os.environ:
    USER_HOME = os.environ["HOME"]
else:
    USER_HOME = f"/home/{getpass.getuser()}"
SSH_DIR = os.environ.get("OCI__SSH_DIR", os.path.join(USER_HOME, ".ssh"))
DEFAULT_LAUNCHER = "torchrun"

# Set authentication method to resource principal
# This script is expected to be running inside the job run
if "OCI_RESOURCE_PRINCIPAL_VERSION" in os.environ:
    set_auth("resource_principal")
|
76
|
+
|
77
|
+
|
78
|
+
class LazyEvaluate:
    """Defers a function call until its result is rendered as a string.

    Intended for log arguments that are expensive to compute or may fail::

        logger.debug("The value is %s", LazyEvaluate(the_function, *args, **kwargs))

    Python logging only converts an argument with ``__str__()`` when the record
    is actually emitted. With a log level of INFO or above, ``the_function``
    in the example is therefore never called. When the record is emitted and
    the call (or its string conversion) raises, the traceback is logged at
    DEBUG level and ``"ERROR: <message>"`` is rendered instead, so the program
    keeps running.
    """

    def __init__(self, func, *args, **kwargs) -> None:
        self.func, self.args, self.kwargs = func, args, kwargs

    def eval(self):
        """Calls the wrapped function with the stored arguments and returns the result."""
        func, args, kwargs = self.func, self.args, self.kwargs
        return func(*args, **kwargs)

    def __str__(self) -> str:
        """Returns ``str()`` of the call result, or ``"ERROR: ..."`` if anything raises."""
        try:
            return str(self.eval())
        except Exception as err:
            logger.debug(traceback.format_exc())
            return "ERROR: " + str(err)
|
112
|
+
|
113
|
+
|
114
|
+
class Runner(driver_utils.JobRunner):
    """Base runner class for PyTorch training job.

    On construction the runner determines its own IP address and whether it
    is the host (no ``MAIN_JOB_RUN_OCID`` in the environment) or a worker node.
    It prints its IP address with a well-known prefix so other job runs can
    discover it from the logs. Subclasses set ``LAUNCHER`` and implement
    :meth:`run`.
    """

    # LAUNCHER stores the main command for launching the training job.
    # e.g. torchrun, deepspeed, accelerate, etc.
    LAUNCHER = ""

    def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
        super().__init__(code_dir)
        self.launch_cmd = os.environ.get(CONST_ENV_LAUNCH_CMD, "")

        self.ds_client = driver_utils.OCIHelper.init_oci_client(
            oci.data_science.DataScienceClient
        )
        self.ip = self.find_self_ip()
        # IP address of other nodes as a list
        self.node_ip_list = []
        # DataScienceJobRun objects of other nodes as a list
        self.node_runs = []

        if CONST_ENV_HOST_JOB_RUN_OCID in os.environ:
            # Worker node.
            # Print the node IP address to logs so that it can be obtained by the host.
            print(f"{LOG_PREFIX_NODE_IP}{self.ip}")
            self.host_ocid = os.environ[CONST_ENV_HOST_JOB_RUN_OCID]
            logger.debug("Host job run OCID: %s", self.host_ocid)
            # Resolved later by wait_for_host_ip_address().
            self.host_ip = None
            self.is_host = False
        else:
            # Host node.
            # Print the host IP address to logs so that it can be obtained by the nodes.
            print(f"{LOG_PREFIX_HOST_IP}{self.ip}")
            self.host_ocid = os.environ.get(CONST_ENV_JOB_RUN_OCID)
            self.host_ip = self.ip
            self.is_host = True

        self.host_job_run = DataScienceJobRun.from_ocid(self.host_ocid)
        self.entrypoint_env = PythonRuntimeHandler.CONST_CODE_ENTRYPOINT
        # The total number of nodes is NODE_COUNT if set, otherwise OCI__WORKER_COUNT + 1
        if CONST_ENV_NODE_COUNT in os.environ:
            self.node_count = int(os.environ[CONST_ENV_NODE_COUNT])
        else:
            self.node_count = int(os.environ.get(OCI__WORKER_COUNT, 0)) + 1
        logger.debug("Node count: %s", self.node_count)
        self.gpu_count = torch.cuda.device_count()
        logger.debug("GPU count on this node: %s", self.gpu_count)

        logger.debug("Runner initialized.")

    def launch_cmd_contains(self, arg) -> bool:
        """Checks if the cmd for launching the training contains specific keyword argument.

        This is a substring check on ``--<arg>``, so e.g. ``"nnode"`` also
        matches ``--nnodes``.
        """
        return f"--{arg}" in self.launch_cmd

    def wait_for_host_ip_address(self, timeout=15 * 60) -> "Runner":
        """Waits until the IP address of the host is obtained.

        The address is stored in ``self.host_ip``. On the host itself the
        address is already known and this returns immediately.

        Parameters
        ----------
        timeout : int, optional
            Timeout in seconds, by default 15 minutes.

        Returns
        -------
        Runner
            The runner itself, to allow method chaining.
            The IP address is available as ``self.host_ip``.

        Raises
        ------
        TimeoutError
            The host IP address was not found in the host job run logs in time.
        """
        if not self.host_ip:
            logger.info("Waiting for host's IP address...")
            self.host_ip = self.wait_for_ip_address(self.host_job_run, timeout)
        return self

    def wait_for_ip_address(self, job_run, timeout=15 * 60) -> str:
        """Waits until the IP address of a particular job run is obtained.

        Parameters
        ----------
        job_run : DataScienceJobRun
            A DataScienceJobRun object
        timeout : int, optional
            Timeout in seconds, by default 15 minutes.

        Returns
        -------
        str
            IP address
        """
        logger.info("Waiting for IP address of job run %s", job_run.id)
        # The host logs its IP with a different prefix than the worker nodes.
        if job_run == self.host_job_run:
            log_prefix = LOG_PREFIX_HOST_IP
        else:
            log_prefix = LOG_PREFIX_NODE_IP
        ip_address = self.wait_for_log(job_run, log_prefix, timeout).strip()
        logger.info("IP of %s: %s", job_run.id[-6:], ip_address)
        return ip_address

    def wait_for_log(self, job_run, log_prefix, timeout=15 * 60) -> str:
        """Waits until a log message with specific prefix is found in the logs of a job run.

        The logs are polled every 60 seconds.

        Parameters
        ----------
        job_run : DataScienceJobRun
            A DataScienceJobRun object
        log_prefix : str
            The prefix of the log message to look for.
        timeout : int, optional
            Timeout in seconds, by default 15 minutes.

        Returns
        -------
        str
            The log message without the prefix.

        Raises
        ------
        TimeoutError
            Failed to obtain the log message within the specific timeout.
        """
        logger.debug(
            "Waiting for logs with prefix '%s' from %s.", log_prefix, job_run.id
        )
        second_started = time.time()
        while True:
            log = self.check_job_run_logs(job_run=job_run, log_prefix=log_prefix)
            if log:
                return log
            if time.time() - second_started > timeout:
                raise TimeoutError(
                    f"Failed to obtain log with prefix {log_prefix} for {job_run.id} in {timeout} seconds."
                )
            time.sleep(60)

    @staticmethod
    def check_job_run_logs(job_run, log_prefix: str) -> str:
        """Checks the logs of a specific job run and find the log message with specific prefix.

        Parameters
        ----------
        job_run : DataScienceJobRun
            The Job run object from which the logs will be obtained.
        log_prefix : str
            The prefix to look for.

        Returns
        -------
        str or None
            The first log message starting with the prefix, without the prefix,
            or None if no such message exists yet.
        """
        logger.debug("Checking logs for job run %s", job_run.id)
        logs = job_run.logs()
        for log in logs:
            if log["message"].startswith(log_prefix):
                return log["message"][len(log_prefix) :]
        return None

    def find_self_ip(self):
        """Identifies the IP address of this node.

        When ``JOB_OCID`` is set, the address is the first one, on any local
        network interface, that falls within the CIDR block of the subnet
        configured for the job. The matching interface is also exported as
        ``GLOO_SOCKET_IFNAME`` and ``NCCL_SOCKET_IFNAME``. Otherwise, the
        address is resolved from the hostname.

        Returns
        -------
        str
            The IP address of this node.

        Raises
        ------
        EnvironmentError
            No local address falls within the CIDR block of the job subnet.
        """
        hostname = socket.gethostname()
        logger.debug("Hostname: %s", hostname)
        logger.debug(
            "Get Host by Addr: %s", LazyEvaluate(socket.gethostbyaddr, hostname)
        )
        logger.debug("FQDN: %s", LazyEvaluate(socket.getfqdn, hostname))

        if not os.environ.get("JOB_OCID"):
            ip = socket.gethostbyname(hostname)
            logger.info("Node IP address: %s", ip)
            return ip

        subnet_id = self.ds_client.get_job(
            os.environ["JOB_OCID"]
        ).data.job_infrastructure_configuration_details.subnet_id
        core_client = driver_utils.OCIHelper.init_oci_client(
            oci.core.VirtualNetworkClient
        )
        cidr = core_client.get_subnet(subnet_id).data.cidr_block
        network = ipaddress.ip_network(cidr)

        for interface, snics in psutil.net_if_addrs().items():
            # An interface can report several addresses (IPv4, IPv6, link-layer
            # MAC) in no guaranteed order. Check all of them and skip entries
            # that are not IP addresses, instead of only looking at the first one.
            for snic in snics:
                try:
                    address = ipaddress.ip_address(snic.address)
                except ValueError:
                    continue
                if address in network:
                    ip = snic.address
                    logger.info("Node IP address: %s", ip)
                    # Specify the network interface for NCCL/GLOO
                    os.environ["GLOO_SOCKET_IFNAME"] = interface
                    os.environ["NCCL_SOCKET_IFNAME"] = interface
                    return ip
        raise EnvironmentError("Unable to determine node IP address.")

    def fetch_code(self):
        """Fetches source code from Git if repo uri is specified.

        Returns
        -------
        Runner
            The runner itself.
        """
        if cluster_config_helper.OCI__RUNTIME_URI in os.environ:
            self._fetch_git(code_dir=self.code_dir)
        return self

    def _fetch_git(self, code_dir):
        """Fetches source code from Git repository into ``code_dir``.

        The repository URI, branch, commit and SSH key secret OCID are read
        from environment variables.
        """
        uri = os.environ.get(cluster_config_helper.OCI__RUNTIME_URI)
        branch = os.environ.get(cluster_config_helper.OCI__RUNTIME_GIT_BRANCH)
        commit = os.environ.get(cluster_config_helper.OCI__RUNTIME_GIT_COMMIT)
        secret_ocid = os.environ.get(cluster_config_helper.OCI__RUNTIME_GIT_SECRET_ID)
        # with GitSSHKey does nothing if secret_ocid is None or empty
        with GitSSHKey(secret_ocid):
            GitManager(uri, code_dir=code_dir).fetch_repo().checkout_code(
                branch=branch, commit=commit
            )

    def get_cmd_with_entrypoint_and_args(self, prefix: str = "") -> str:
        """Gets the command based on entrypoint and arguments.

        Parameters
        ----------
        prefix : str, optional
            Command prefix, by default ""
            This can be used to set environment variables for the command.
            e.g. ENV=1 command

        Returns
        -------
        str
            The command including the prefix, entrypoint and arguments.
        """
        cmd = os.environ[self.entrypoint_env]
        if prefix:
            cmd = prefix + " " + cmd
        if sys.argv[1:]:
            # The resulting command is executed through a shell (run_command is
            # given shell syntax elsewhere in this file). Quote each argument so
            # that arguments containing spaces or shell special characters reach
            # the training script unchanged instead of being split again.
            cmd += " " + " ".join(shlex.quote(arg) for arg in sys.argv[1:])
        return cmd

    def prepare_cmd(self, launch_args: list = None, prefix=""):
        """Prepares the command for starting the training.

        Parameters
        ----------
        launch_args : list
            The command and arguments for starting the training as a list.
            The list passed in is not modified.
        prefix : str, optional
            The prefix to be added to the launch_args in the command, by default ""
            This can be used to set environment variables for the command.
            e.g. ENV=1 command

        Returns
        -------
        str
            The command for starting the training.

        Raises
        ------
        ValueError
            The user-specified launch command does not start with ``LAUNCHER``.
        """
        # Work on a copy so that the caller's list is not modified by the appends below.
        launch_args = list(launch_args) if launch_args else []
        # Append launch cmd args specified by the user.
        if self.launch_cmd:
            if self.LAUNCHER:
                if not self.launch_cmd.startswith(self.LAUNCHER):
                    raise ValueError(f"Command not supported: '{self.launch_cmd}'. ")
                # Keep only what follows "<LAUNCHER> "; the launcher is prepended below.
                launch_args.append(self.launch_cmd[len(self.LAUNCHER) + 1 :])
            else:
                launch_args.append(self.launch_cmd)
        else:
            launch_args.append(self.get_cmd_with_entrypoint_and_args())

        launcher = f"{prefix} {self.LAUNCHER}" if prefix else self.LAUNCHER
        return f"{launcher} {' '.join(launch_args)}"

    def time_cmd(self, cmd):
        """Run the command and log the time used."""
        # Show current working directory for debugging purpose
        self.run_command("pwd", level=logging.DEBUG)
        # Show all environment variables
        self.run_command("printenv", level=logging.DEBUG)
        training_start_time = time.time()
        self.run_command(cmd, conda_prefix=self.conda_prefix, check=True)
        logger.info("Time: %s seconds.", time.time() - training_start_time)

    def run(self):
        """Starts the training. Must be implemented by subclasses."""
        raise NotImplementedError()
|
392
|
+
|
393
|
+
|
394
|
+
class TorchRunner(Runner):
    """Runner that launches the training with ``torchrun``.

    Unless the user's launch command already specifies them, ``--nnodes``,
    ``--nproc_per_node`` and c10d rendezvous options are added automatically.
    The host node is used as the rendezvous endpoint.
    """

    # Port on the host node used for the c10d rendezvous endpoint.
    RDZV_PORT = 29400
    LAUNCHER = "torchrun"

    def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
        super().__init__(code_dir)
        self.build_c_library()

    def build_c_library(self):
        """Compiles ``hostname_from_env.c`` into ``<conda_prefix>/lib/libhostname.so.1``.

        The library is referenced by :meth:`env_ld_preload`. If the C source
        file is not found next to this script, an error is logged and nothing
        is compiled.

        Returns
        -------
        TorchRunner
            The runner itself, whether or not the library was built.
        """
        C_SOURCE_CODE = "hostname_from_env.c"
        source_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), C_SOURCE_CODE
        )
        if not os.path.exists(source_path):
            logger.error("Source code %s not found.", source_path)
            # Return self here too, consistent with the success path.
            return self

        self.run_command(
            "gcc -fPIC -shared -Wl,-soname,libhostname.so.1 -ldl "
            f"-o {self.conda_prefix}/lib/libhostname.so.1 {source_path}",
            conda_prefix=self.conda_prefix,
            level=logging.DEBUG,
        )
        self.run_command(
            f"ls {self.conda_prefix}/lib/libhostname*", level=logging.DEBUG
        )

        return self

    def env_ld_preload(self) -> str:
        """Generate environment variable config for LD_PRELOAD and OCI__HOSTNAME.

        The return value can be used as the prefix of a bash command. It is an
        empty string when ``LD_PRELOAD`` is already defined by the user.
        """
        cmd_prefix = ""
        # Use LD_PRELOAD only if LD_PRELOAD is not defined by the user.
        # For pytorch>=2.0, we can use f"--local_addr={self.ip} " instead of LD_PRELOAD.
        # NOTE(review): presumably the preloaded library makes the hostname
        # resolve to OCI__HOSTNAME (this node's IP) — confirm in hostname_from_env.c.
        if CONST_ENV_LD_PRELOAD not in os.environ:
            cmd_prefix = f"LD_PRELOAD={self.conda_prefix}/lib/libhostname.so.1 OCI__HOSTNAME={self.ip}"
        return cmd_prefix

    def get_rdzv_conf(self) -> str:
        """Prepare additional rendezvous config for torch run.

        The default read_timeout is 60 seconds.
        The job run will fail if the node cannot reach the host within read_timeout.
        Here it defaults to 600 seconds and can be overridden with OCI__RDZV_TIMEOUT.
        """
        rdzv_timeout = os.environ.get("OCI__RDZV_TIMEOUT", "600")
        rdzv_conf = f"read_timeout={rdzv_timeout}"
        return rdzv_conf

    def run(self):
        """Builds the ``torchrun`` command and runs the training."""
        # One process per GPU; a single process on nodes without GPUs.
        nproc_per_node = self.gpu_count if self.gpu_count > 0 else 1

        launch_args = []
        # Add nnodes, nproc_per_node and rdzv args only if they are not specified by the user.
        # The "nnode" check is a substring match, so it also detects "--nnodes".
        if not self.launch_cmd_contains("nnode"):
            # Use torchrun's full option name rather than relying on argparse
            # prefix abbreviation of "--nnodes".
            launch_args.append(f"--nnodes={self.node_count}")
        if not self.launch_cmd_contains("nproc_per_node"):
            launch_args.append(f"--nproc_per_node={nproc_per_node}")
        if not self.launch_cmd_contains("rdzv_backend"):
            launch_args.extend(
                [
                    "--rdzv_backend=c10d",
                    f"--rdzv_endpoint={self.host_ip}:{self.RDZV_PORT}",
                    f"--rdzv_conf={self.get_rdzv_conf()}",
                ]
            )

        self.time_cmd(cmd=self.prepare_cmd(launch_args, prefix=self.env_ld_preload()))
|
466
|
+
|
467
|
+
|
468
|
+
class DeepSpeedRunner(Runner):
|
469
|
+
STOP_FILE = "/home/datascience/stop"
|
470
|
+
ERROR_FILE = "/home/datascience/error"
|
471
|
+
HOST_FILE = "/home/datascience/hostfile"
|
472
|
+
ENV_FILE = os.path.expanduser("~/.deepspeed_env")
|
473
|
+
LAUNCHER = "deepspeed"
|
474
|
+
|
475
|
+
def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
    """Initializes the base runner, then prepares the OS via ``update_os``."""
    super(DeepSpeedRunner, self).__init__(code_dir)
    self.update_os()
|
478
|
+
|
479
|
+
def update_os(self):
    """Prepares the container for multi-node DeepSpeed training.

    Generates SSH host keys, installs the required OS packages and starts the
    SSH server. The commands run in that order, each with ``check=True``.
    """
    setup_steps = (
        # Generate SSH host keys for SSH server
        "sudo ssh-keygen -A",
        # Install SSH server to accept SSH connections
        # DeepSpeed uses "hostname -I" to determine the IP address
        # pdsh is required for default multi node training
        # torch cpp extension uses which command to find compiler
        # DeepSpeed async_io requires libaio-devel
        "sudo --preserve-env yum install -y openssh-server hostname pdsh which libaio-devel",
        # Start SSH service
        "sudo /usr/sbin/sshd",
    )
    for step in setup_steps:
        self.run_command(step, level=logging.DEBUG, check=True)
|
494
|
+
|
495
|
+
def generate_key_pair(self):
|
496
|
+
self.run_command(
|
497
|
+
"ssh-keygen -q -t rsa -N '' <<< $'\ny'", level=logging.DEBUG, check=True
|
498
|
+
)
|
499
|
+
with open(os.path.join(SSH_DIR, "id_rsa.pub"), "r", encoding="utf-8") as f:
|
500
|
+
public_key = f.read()
|
501
|
+
print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}")
|
502
|
+
self.add_authoried_key(public_key)
|
503
|
+
self.run_command(
|
504
|
+
f"ssh-keyscan -H {self.host_ip} >> {SSH_DIR}/known_hosts",
|
505
|
+
level=logging.DEBUG,
|
506
|
+
check=True,
|
507
|
+
)
|
508
|
+
self.test_ssh_connection(self.host_ip)
|
509
|
+
# Check DeepSpeed compatibility
|
510
|
+
self.run_command(
|
511
|
+
"ds_report", conda_prefix=self.conda_prefix, level=logging.DEBUG
|
512
|
+
)
|
513
|
+
return self
|
514
|
+
|
515
|
+
@staticmethod
def add_authoried_key(public_key):
    """Appends ``public_key`` followed by a newline to ``SSH_DIR/authorized_keys``.

    ``SSH_DIR`` is created first if it does not exist.
    """
    os.makedirs(SSH_DIR, exist_ok=True)
    auth_keys_file = os.path.join(SSH_DIR, "authorized_keys")
    with open(auth_keys_file, "a", encoding="utf-8") as auth_keys:
        auth_keys.write(public_key + "\n")
    logger.debug("Public key saved to %s", auth_keys_file)
|
523
|
+
|
524
|
+
def fetch_host_public_key(self):
    """Authorize the host node's SSH public key on this worker node.

    Waits until the host job run's logs contain ``LOG_PREFIX_PUBLIC_KEY`` and
    appends the logged key to this node's authorized keys. Only then does it
    echo the key with the same prefix in this run's logs.
    """
    public_key = self.wait_for_log(self.host_job_run, LOG_PREFIX_PUBLIC_KEY)
    # Authorize first, then log. The host waits for this prefix in every
    # worker's logs before it starts connecting over SSH, so the line must
    # only appear once the key is actually in place.
    self.add_authoried_key(public_key)
    print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}")
|
529
|
+
|
530
|
+
def generate_hostfile(self):
    """Discover the worker nodes and write the DeepSpeed hostfile and SSH config.

    Worker runs are the runs of the host's job whose status is ``ACCEPTED`` or
    ``IN_PROGRESS``, excluding the host run itself. They are stored in
    ``self.node_runs`` and their IP addresses in ``self.node_ip_list``.

    The hostfile at ``HOST_FILE`` lists the host first and then each worker,
    one ``<ip> slots=<gpu_count>`` entry per line. The SSH config in
    ``SSH_DIR`` gets one ``Host`` section per node, using the datascience
    user and its ``id_rsa`` key.

    Returns
    -------
    self
        The runner, so calls can be chained.
    """
    runs = self.host_job_run.job.run_list()
    self.node_runs = [
        run
        for run in runs
        if run.status in ["ACCEPTED", "IN_PROGRESS"] and run.id != self.host_ocid
    ]
    self.node_ip_list = [self.wait_for_ip_address(run) for run in self.node_runs]
    logger.info("Node IPs: %s", self.node_ip_list)
    all_ips = [self.host_ip, *self.node_ip_list]

    # Hostfile: writelines() adds no line separators, so every entry carries
    # its own "\n". Otherwise consecutive entries would share one line.
    logger.debug("Writing hostfile to %s", self.HOST_FILE)
    os.makedirs(os.path.dirname(self.HOST_FILE), exist_ok=True)
    with open(self.HOST_FILE, "w", encoding="utf-8") as f:
        f.writelines(f"{ip} slots={self.gpu_count}\n" for ip in all_ips)
    self.run_command(f"cat {self.HOST_FILE}", level=logging.DEBUG)

    # SSH config: each directive must sit on its own line. Otherwise they
    # merge into the "Host" pattern list and never take effect.
    ssh_config_path = os.path.join(SSH_DIR, "config")
    logger.debug("Writing SSH config to %s", ssh_config_path)
    with open(ssh_config_path, "w", encoding="utf-8") as f:
        for ip in all_ips:
            f.write(
                f"\nHost {ip}\n"
                "    IdentityFile /home/datascience/.ssh/id_rsa\n"
                "    User datascience\n"
            )
    return self
|
569
|
+
|
570
|
+
def test_ssh_connection(self, host):
    """Try a password-less SSH login to ``host`` and log the outcome at DEBUG.

    Runs ``hostname -I`` on ``host``. A ``run_command`` return value of 0 is
    logged as OK and anything else as FAILED. ``run_command`` is called
    without ``check=True``, and this method raises nothing on failure itself.

    Parameters
    ----------
    host : str
        Address of the node to connect to.
    """
    exit_code = self.run_command(
        f"ssh -v -o PasswordAuthentication=no {host} hostname -I",
        level=logging.DEBUG,
    )
    message = (
        "SSH connection to %s - OK"
        if exit_code == 0
        else "SSH connection to %s - FAILED"
    )
    logger.debug(message, host)
|
579
|
+
|
580
|
+
def touch_file(self, filename):
    """Create an empty file with the given name on every worker node over SSH.

    The host uses this to signal the workers with ``STOP_FILE`` or
    ``ERROR_FILE``, which ``run_deepspeed_worker`` polls for. Each ``ssh``
    call runs with ``check=True``.

    Parameters
    ----------
    filename : str
        Path of the file to create on each worker node.
    """
    for node_ip in self.node_ip_list:
        # Name the actual file: this also sends ERROR_FILE, not only STOP_FILE.
        logger.debug("Sending %s to %s", filename, node_ip)
        self.run_command(
            f"ssh -v {node_ip} 'touch {filename}'",
            level=logging.DEBUG,
            check=True,
        )
|
589
|
+
|
590
|
+
def save_deepspeed_env(self):
    """Write the current environment to ``ENV_FILE`` for multi-node training.

    DeepSpeed runs multi-node training over SSH, and environment variables
    configured on the job run are not propagated to those SSH sessions.
    DeepSpeed loads them from this file instead.

    Skipped variables:

    - empty values;
    - values containing a line break;
    - node-specific names.

    Values containing a space are shell-quoted. Fixed NCCL/GLOO interface
    prefixes are appended at the end.
    """
    # These differ per job run, e.g. the network interface name is a unique
    # string such as ens300f0v1604.
    node_specific = ("NCCL_SOCKET_IFNAME", "GLOO_SOCKET_IFNAME", "JOB_RUN_OCID")
    with open(self.ENV_FILE, mode="w", encoding="utf-8") as env_file:
        for name, value in os.environ.items():
            # As of deepspeed==0.9.2, an empty value or a line break causes a
            # parsing error, since .deepspeed_env is parsed line by line.
            if not value or "\n" in value or name in node_specific:
                continue
            # Values containing a space may not be exported correctly via pdsh:
            # https://github.com/microsoft/DeepSpeed/blob/v0.9.2/deepspeed/launcher/multinode_runner.py#L79
            if " " in value:
                value = shlex.quote(value)
            env_file.write(f"{name}={value}\n")
        # NCCL/GLOO select the network interface by name prefix:
        # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-socket-ifname
        env_file.write("NCCL_SOCKET_IFNAME=ens\n")
        env_file.write("GLOO_SOCKET_IFNAME=ens\n")
    logger.debug("Environment variables saved to %s", self.ENV_FILE)
    self.run_command(f"cat {self.ENV_FILE}")
|
620
|
+
|
621
|
+
def run_deepspeed_host(self, launch_args=None):
    """Prepare the host node and launch the DeepSpeed training.

    For multi-node jobs (``node_count > 1``), this first:

    - sets up the SSH key pair, hostfile, SSH config and DeepSpeed env file;
    - waits until every worker run has logged ``LOG_PREFIX_PUBLIC_KEY``;
    - adds each worker's host key to ``known_hosts``.

    The command then runs on this node only. On success, every worker is sent
    ``STOP_FILE``. On any exception, including interrupts, workers are sent
    ``ERROR_FILE`` and the exception is re-raised.

    Parameters
    ----------
    launch_args : list of str, optional
        Additional command line arguments, by default None.
        The deepspeed host file should be specified in the launch args.
        For "deepspeed": --hostfile
        For "accelerate launch": --deepspeed_hostfile
    """
    if self.node_count > 1:
        self.generate_key_pair().generate_hostfile()
        self.save_deepspeed_env()
        # Wait for nodes to be ready
        for run in self.node_runs:
            self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY)

        for node_ip in self.node_ip_list:
            self.run_command(
                f"ssh-keyscan -H {node_ip} >> {SSH_DIR}/known_hosts",
                level=logging.DEBUG,
                check=True,
            )

    cmd = self.prepare_cmd(launch_args)
    # For DeepSpeed, we only need to run the cmd on the host
    try:
        self.time_cmd(cmd)
    except BaseException:
        # Deliberately broad (re-raised below) so workers are told to stop
        # even on KeyboardInterrupt/SystemExit.
        # Caution: file will not be generated if job run is killed from the console.
        self.touch_file(self.ERROR_FILE)
        raise
    else:
        # Signal stop
        self.touch_file(self.STOP_FILE)
|
656
|
+
|
657
|
+
def run_deepspeed_worker(self):
    """Keep this worker job run alive until the host job run finishes.

    First authorizes the host's public key, then checks for ``STOP_FILE``
    every 60 seconds. The process exits with:

    - status 1 if ``ERROR_FILE`` appears;
    - status 2 if the host job run's status is anything other than
      ACCEPTED, IN_PROGRESS or SUCCEEDED.

    A transient OCI error while refreshing the host run is ignored until the
    next check.
    """
    self.fetch_host_public_key()
    healthy_host_states = ["ACCEPTED", "IN_PROGRESS", "SUCCEEDED"]
    while True:
        if os.path.exists(self.STOP_FILE):
            break
        time.sleep(60)
        # The host touches the error file when its command fails.
        if os.path.exists(self.ERROR_FILE):
            logger.error("There is an error in the host job run.")
            sys.exit(1)
        try:
            self.host_job_run.sync()
        except oci.exceptions.TransientServiceError:
            # Transient failure: retry on the next iteration.
            continue
        if self.host_job_run.status in healthy_host_states:
            continue
        # e.g. the host job run was CANCELLED.
        logger.info(
            "Host job run status is %s. Stopping job run...",
            self.host_job_run.status,
        )
        sys.exit(2)
    logger.info("Job finished successfully. Stopping job run...")
|
683
|
+
|
684
|
+
def run(self):
    """Launch DeepSpeed on the host node, or act as a worker on other nodes.

    A multi-node host passes ``--hostfile=<HOST_FILE>`` to the launcher. A
    single-node host passes no extra arguments.
    """
    if not self.is_host:
        self.run_deepspeed_worker()
        return
    launch_args = [f"--hostfile={self.HOST_FILE}"] if self.node_count > 1 else []
    self.run_deepspeed_host(launch_args)
|
693
|
+
|
694
|
+
|
695
|
+
class GenericRunner(TorchRunner, DeepSpeedRunner):
    """Runner for running command other than ``torchrun``, ``deepspeed`` or ``accelerate``."""

    LAUNCHER = ""

    def use_deepspeed(self) -> bool:
        """Return True when ``CONST_ENV_DEEPSPEED`` is set to a non-empty value."""
        return bool(os.environ.get(CONST_ENV_DEEPSPEED))

    def set_env_var(self):
        """Set WORLD_SIZE, MASTER_ADDR and MASTER_PORT unless already defined.

        WORLD_SIZE is ``node_count * gpu_count``, MASTER_ADDR is the host IP
        and MASTER_PORT is ``RDZV_PORT``. Values are stored as strings, and
        variables the user already set are left untouched.
        """
        defaults = {
            "WORLD_SIZE": self.node_count * self.gpu_count,
            "MASTER_ADDR": self.host_ip,
            "MASTER_PORT": self.RDZV_PORT,
        }
        for name, value in defaults.items():
            os.environ.setdefault(name, str(value))

    def run(self):
        """Runs the user's command.

        Note that for TorchRunner or DeepSpeedRunner,
        we automatically add arguments for some settings,
        like the number of nodes and the host node address.

        This generic runner does not modify the command specified by the user.
        User needs to make sure the command can work on all nodes.
        User may use the environment variables in the command.
        """
        self.set_env_var()
        if not self.use_deepspeed():
            self.time_cmd(cmd=self.prepare_cmd(prefix=self.env_ld_preload()))
        elif self.is_host:
            self.run_deepspeed_host()
        else:
            self.run_deepspeed_worker()
|
735
|
+
|
736
|
+
|
737
|
+
class AccelerateRunner(TorchRunner, DeepSpeedRunner):
    """Runner for HuggingFace Accelerate.

    Adds the ``accelerate launch`` options listed in ``DEFAULT_ARGS`` unless
    the user's command already contains them. The job runs through the
    DeepSpeed host/worker flow when DeepSpeed is requested (``--use_deepspeed``
    or the ``CONST_ENV_DEEPSPEED`` environment variable). Otherwise the launch
    command runs directly.
    """

    # accelerate launch will add main_process_port for deepspeed cmd even if it is not needed.
    # https://github.com/huggingface/accelerate/blob/70920895e80f78d96d8f91e0beeb3ebdb8e5e5d6/src/accelerate/utils/launch.py#L233
    DEFAULT_ARGS = [
        "num_processes",
        "num_machines",
        "machine_rank",
        "main_process_ip",
        "main_process_port",
    ]
    TORCHRUN_ARGS = []
    LAUNCHER = "accelerate launch"

    def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
        super().__init__(code_dir)
        # For "accelerate launch", only one of the following options can be used at one time
        # `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`.
        # When a config file is not provided,
        # --multi_gpu will be set automatically if there is more than 1 GPU
        self.num_machines = self.node_count
        # Read from the environment as a string, so rank "0" is still truthy
        # and is kept by accelerate_args().
        self.machine_rank = os.environ["NODE_RANK"]
        # Total number of processes across all nodes
        # Here we assume all nodes are having the same shape
        self.num_processes = (self.gpu_count if self.gpu_count else 1) * self.node_count

        self.main_process_port = self.RDZV_PORT
        # Host IP is not ready at initialization; run() fills it in.
        self.main_process_ip = None

    def use_deepspeed(self):
        """Indicate if DeepSpeed is used."""
        # Accelerate supports DeepSpeed via the "--use_deepspeed" argument.
        return bool(
            os.environ.get(CONST_ENV_DEEPSPEED)
            or self.launch_cmd_contains("use_deepspeed")
        )

    def accelerate_args(self):
        """Gets the default arguments for the accelerate command.

        The values come from the attributes named in ``DEFAULT_ARGS``, which
        are assigned in ``__init__()`` and ``run()``. An attribute that is
        exactly True becomes a bare ``--flag``. Any other truthy value becomes
        ``--name value``. Falsy values are omitted, including attributes that
        ``run()`` reset to None because the user supplied them.

        Returns
        -------
        list of str
            The command line arguments.
        """
        args = []
        for arg in self.DEFAULT_ARGS:
            arg_val = getattr(self, arg, None)
            logger.debug("%s=%s", arg, arg_val)
            if arg_val is True:
                args.append(f"--{arg}")
            elif arg_val:
                args.extend([f"--{arg}", str(arg_val)])
        return args

    def run_with_torchrun(self):
        """Runs the launch command directly (without the DeepSpeed host/worker flow)."""
        launch_args = self.accelerate_args()
        for arg in self.TORCHRUN_ARGS:
            if not self.launch_cmd_contains(arg):
                launch_args.extend([f"--{arg}", f"{getattr(self, arg)}"])
        cmd = self.prepare_cmd(launch_args, prefix=self.env_ld_preload())
        self.time_cmd(cmd=cmd)

    def run_with_deepspeed(self):
        """Runs the job with DeepSpeed: launch on the host, wait on workers."""
        if self.is_host:
            launch_args = self.accelerate_args()
            # Use node_count, not num_machines. run() resets num_machines to
            # None when the user passes --num_machines, and `None > 1` would
            # raise TypeError.
            if self.node_count > 1:
                launch_args.append(f"--deepspeed_hostfile={self.HOST_FILE}")
            self.run_deepspeed_host(launch_args)
        else:
            self.run_deepspeed_worker()

    def run(self):
        """Fill in the host IP, drop user-specified defaults, then launch."""
        self.main_process_ip = self.host_ip
        # Check if any default argument is provided by the user
        for arg in self.DEFAULT_ARGS:
            if self.launch_cmd_contains(arg):
                logger.debug("%s found in command.", arg)
                setattr(self, arg, None)
        if self.use_deepspeed():
            self.run_with_deepspeed()
        else:
            self.run_with_torchrun()
|
822
|
+
|
823
|
+
|
824
|
+
def main():
    """Pick a runner from the launch command and run the training job.

    The runner class is chosen by the prefix of the ``CONST_ENV_LAUNCH_CMD``
    environment variable:

    - ``"torchrun "`` (also the default when unset or empty): ``TorchRunner``;
    - ``"deepspeed "``: ``DeepSpeedRunner``;
    - ``"accelerate "``: ``AccelerateRunner``;
    - anything else: ``GenericRunner``.

    Inputs are copied before the run and outputs after it completes.
    """
    launch_cmd = os.environ.get(CONST_ENV_LAUNCH_CMD)
    if not launch_cmd:
        # Use torchrun as default if launch cmd is not provided
        runner_class = TorchRunner
    else:
        prefix_to_runner = (
            ("torchrun ", TorchRunner),
            ("deepspeed ", DeepSpeedRunner),
            ("accelerate ", AccelerateRunner),
        )
        runner_class = next(
            (cls for prefix, cls in prefix_to_runner if launch_cmd.startswith(prefix)),
            GenericRunner,
        )

    runner: Runner = runner_class()
    runner.fetch_code().set_working_dir().setup_python_path().install_dependencies()

    driver_utils.OCIHelper.copy_inputs()

    runner.wait_for_host_ip_address().run()
    driver_utils.OCIHelper.copy_outputs()
|
844
|
+
|
845
|
+
|
846
|
+
if __name__ == "__main__":
    # Collect GPU metrics in a background daemon process, only when a metric
    # namespace is configured and at least one GPU is visible to torch.
    if METRIC_NAMESPACE and torch.cuda.device_count():
        metrics_collector = multiprocessing.Process(target=collect_metrics, daemon=True)
        metrics_collector.start()
    main()
|