oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.10rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (858) hide show
  1. ads/aqua/__init__.py +40 -0
  2. ads/aqua/app.py +507 -0
  3. ads/aqua/cli.py +96 -0
  4. ads/aqua/client/__init__.py +3 -0
  5. ads/aqua/client/client.py +836 -0
  6. ads/aqua/client/openai_client.py +305 -0
  7. ads/aqua/common/__init__.py +5 -0
  8. ads/aqua/common/decorator.py +125 -0
  9. ads/aqua/common/entities.py +274 -0
  10. ads/aqua/common/enums.py +134 -0
  11. ads/aqua/common/errors.py +109 -0
  12. ads/aqua/common/utils.py +1295 -0
  13. ads/aqua/config/__init__.py +4 -0
  14. ads/aqua/config/container_config.py +247 -0
  15. ads/aqua/config/evaluation/__init__.py +4 -0
  16. ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
  17. ads/aqua/config/utils/__init__.py +4 -0
  18. ads/aqua/config/utils/serializer.py +339 -0
  19. ads/aqua/constants.py +116 -0
  20. ads/aqua/data.py +14 -0
  21. ads/aqua/dummy_data/icon.txt +1 -0
  22. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  23. ads/aqua/dummy_data/oci_models.json +1 -0
  24. ads/aqua/dummy_data/readme.md +26 -0
  25. ads/aqua/evaluation/__init__.py +8 -0
  26. ads/aqua/evaluation/constants.py +53 -0
  27. ads/aqua/evaluation/entities.py +186 -0
  28. ads/aqua/evaluation/errors.py +70 -0
  29. ads/aqua/evaluation/evaluation.py +1814 -0
  30. ads/aqua/extension/__init__.py +42 -0
  31. ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
  32. ads/aqua/extension/base_handler.py +90 -0
  33. ads/aqua/extension/common_handler.py +121 -0
  34. ads/aqua/extension/common_ws_msg_handler.py +36 -0
  35. ads/aqua/extension/deployment_handler.py +381 -0
  36. ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
  37. ads/aqua/extension/errors.py +30 -0
  38. ads/aqua/extension/evaluation_handler.py +129 -0
  39. ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
  40. ads/aqua/extension/finetune_handler.py +96 -0
  41. ads/aqua/extension/model_handler.py +390 -0
  42. ads/aqua/extension/models/__init__.py +0 -0
  43. ads/aqua/extension/models/ws_models.py +145 -0
  44. ads/aqua/extension/models_ws_msg_handler.py +50 -0
  45. ads/aqua/extension/ui_handler.py +300 -0
  46. ads/aqua/extension/ui_websocket_handler.py +130 -0
  47. ads/aqua/extension/utils.py +133 -0
  48. ads/aqua/finetuning/__init__.py +7 -0
  49. ads/aqua/finetuning/constants.py +23 -0
  50. ads/aqua/finetuning/entities.py +181 -0
  51. ads/aqua/finetuning/finetuning.py +749 -0
  52. ads/aqua/model/__init__.py +8 -0
  53. ads/aqua/model/constants.py +60 -0
  54. ads/aqua/model/entities.py +385 -0
  55. ads/aqua/model/enums.py +32 -0
  56. ads/aqua/model/model.py +2134 -0
  57. ads/aqua/model/utils.py +52 -0
  58. ads/aqua/modeldeployment/__init__.py +6 -0
  59. ads/aqua/modeldeployment/constants.py +10 -0
  60. ads/aqua/modeldeployment/deployment.py +1315 -0
  61. ads/aqua/modeldeployment/entities.py +653 -0
  62. ads/aqua/modeldeployment/utils.py +543 -0
  63. ads/aqua/resources/gpu_shapes_index.json +94 -0
  64. ads/aqua/server/__init__.py +4 -0
  65. ads/aqua/server/__main__.py +24 -0
  66. ads/aqua/server/app.py +47 -0
  67. ads/aqua/server/aqua_spec.yml +1291 -0
  68. ads/aqua/training/__init__.py +4 -0
  69. ads/aqua/training/exceptions.py +476 -0
  70. ads/aqua/ui.py +519 -0
  71. ads/automl/__init__.py +9 -0
  72. ads/automl/driver.py +330 -0
  73. ads/automl/provider.py +975 -0
  74. ads/bds/__init__.py +5 -0
  75. ads/bds/auth.py +127 -0
  76. ads/bds/big_data_service.py +255 -0
  77. ads/catalog/__init__.py +19 -0
  78. ads/catalog/model.py +1576 -0
  79. ads/catalog/notebook.py +461 -0
  80. ads/catalog/project.py +468 -0
  81. ads/catalog/summary.py +178 -0
  82. ads/common/__init__.py +11 -0
  83. ads/common/analyzer.py +65 -0
  84. ads/common/artifact/.model-ignore +63 -0
  85. ads/common/artifact/__init__.py +10 -0
  86. ads/common/auth.py +1122 -0
  87. ads/common/card_identifier.py +83 -0
  88. ads/common/config.py +647 -0
  89. ads/common/data.py +165 -0
  90. ads/common/decorator/__init__.py +9 -0
  91. ads/common/decorator/argument_to_case.py +88 -0
  92. ads/common/decorator/deprecate.py +69 -0
  93. ads/common/decorator/require_nonempty_arg.py +65 -0
  94. ads/common/decorator/runtime_dependency.py +178 -0
  95. ads/common/decorator/threaded.py +97 -0
  96. ads/common/decorator/utils.py +35 -0
  97. ads/common/dsc_file_system.py +303 -0
  98. ads/common/error.py +14 -0
  99. ads/common/extended_enum.py +81 -0
  100. ads/common/function/__init__.py +5 -0
  101. ads/common/function/fn_util.py +142 -0
  102. ads/common/function/func_conf.yaml +25 -0
  103. ads/common/ipython.py +76 -0
  104. ads/common/model.py +679 -0
  105. ads/common/model_artifact.py +1759 -0
  106. ads/common/model_artifact_schema.json +107 -0
  107. ads/common/model_export_util.py +664 -0
  108. ads/common/model_metadata.py +24 -0
  109. ads/common/object_storage_details.py +296 -0
  110. ads/common/oci_client.py +179 -0
  111. ads/common/oci_datascience.py +46 -0
  112. ads/common/oci_logging.py +1144 -0
  113. ads/common/oci_mixin.py +957 -0
  114. ads/common/oci_resource.py +136 -0
  115. ads/common/serializer.py +559 -0
  116. ads/common/utils.py +1852 -0
  117. ads/common/word_lists.py +1491 -0
  118. ads/common/work_request.py +189 -0
  119. ads/config.py +1 -0
  120. ads/data_labeling/__init__.py +13 -0
  121. ads/data_labeling/boundingbox.py +253 -0
  122. ads/data_labeling/constants.py +47 -0
  123. ads/data_labeling/data_labeling_service.py +244 -0
  124. ads/data_labeling/interface/__init__.py +5 -0
  125. ads/data_labeling/interface/loader.py +16 -0
  126. ads/data_labeling/interface/parser.py +16 -0
  127. ads/data_labeling/interface/reader.py +23 -0
  128. ads/data_labeling/loader/__init__.py +5 -0
  129. ads/data_labeling/loader/file_loader.py +241 -0
  130. ads/data_labeling/metadata.py +110 -0
  131. ads/data_labeling/mixin/__init__.py +5 -0
  132. ads/data_labeling/mixin/data_labeling.py +232 -0
  133. ads/data_labeling/ner.py +129 -0
  134. ads/data_labeling/parser/__init__.py +5 -0
  135. ads/data_labeling/parser/dls_record_parser.py +388 -0
  136. ads/data_labeling/parser/export_metadata_parser.py +94 -0
  137. ads/data_labeling/parser/export_record_parser.py +473 -0
  138. ads/data_labeling/reader/__init__.py +5 -0
  139. ads/data_labeling/reader/dataset_reader.py +574 -0
  140. ads/data_labeling/reader/dls_record_reader.py +121 -0
  141. ads/data_labeling/reader/export_record_reader.py +62 -0
  142. ads/data_labeling/reader/jsonl_reader.py +75 -0
  143. ads/data_labeling/reader/metadata_reader.py +203 -0
  144. ads/data_labeling/reader/record_reader.py +263 -0
  145. ads/data_labeling/record.py +52 -0
  146. ads/data_labeling/visualizer/__init__.py +5 -0
  147. ads/data_labeling/visualizer/image_visualizer.py +525 -0
  148. ads/data_labeling/visualizer/text_visualizer.py +357 -0
  149. ads/database/__init__.py +5 -0
  150. ads/database/connection.py +338 -0
  151. ads/dataset/__init__.py +10 -0
  152. ads/dataset/capabilities.md +51 -0
  153. ads/dataset/classification_dataset.py +339 -0
  154. ads/dataset/correlation.py +226 -0
  155. ads/dataset/correlation_plot.py +563 -0
  156. ads/dataset/dask_series.py +173 -0
  157. ads/dataset/dataframe_transformer.py +110 -0
  158. ads/dataset/dataset.py +1979 -0
  159. ads/dataset/dataset_browser.py +360 -0
  160. ads/dataset/dataset_with_target.py +995 -0
  161. ads/dataset/exception.py +25 -0
  162. ads/dataset/factory.py +987 -0
  163. ads/dataset/feature_engineering_transformer.py +35 -0
  164. ads/dataset/feature_selection.py +107 -0
  165. ads/dataset/forecasting_dataset.py +26 -0
  166. ads/dataset/helper.py +1450 -0
  167. ads/dataset/label_encoder.py +99 -0
  168. ads/dataset/mixin/__init__.py +5 -0
  169. ads/dataset/mixin/dataset_accessor.py +134 -0
  170. ads/dataset/pipeline.py +58 -0
  171. ads/dataset/plot.py +710 -0
  172. ads/dataset/progress.py +86 -0
  173. ads/dataset/recommendation.py +297 -0
  174. ads/dataset/recommendation_transformer.py +502 -0
  175. ads/dataset/regression_dataset.py +14 -0
  176. ads/dataset/sampled_dataset.py +1050 -0
  177. ads/dataset/target.py +98 -0
  178. ads/dataset/timeseries.py +18 -0
  179. ads/dbmixin/__init__.py +5 -0
  180. ads/dbmixin/db_pandas_accessor.py +153 -0
  181. ads/environment/__init__.py +9 -0
  182. ads/environment/ml_runtime.py +66 -0
  183. ads/evaluations/README.md +14 -0
  184. ads/evaluations/__init__.py +109 -0
  185. ads/evaluations/evaluation_plot.py +983 -0
  186. ads/evaluations/evaluator.py +1334 -0
  187. ads/evaluations/statistical_metrics.py +543 -0
  188. ads/experiments/__init__.py +9 -0
  189. ads/experiments/capabilities.md +0 -0
  190. ads/explanations/__init__.py +21 -0
  191. ads/explanations/base_explainer.py +142 -0
  192. ads/explanations/capabilities.md +83 -0
  193. ads/explanations/explainer.py +190 -0
  194. ads/explanations/mlx_global_explainer.py +1050 -0
  195. ads/explanations/mlx_interface.py +386 -0
  196. ads/explanations/mlx_local_explainer.py +287 -0
  197. ads/explanations/mlx_whatif_explainer.py +201 -0
  198. ads/feature_engineering/__init__.py +20 -0
  199. ads/feature_engineering/accessor/__init__.py +5 -0
  200. ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
  201. ads/feature_engineering/accessor/mixin/__init__.py +5 -0
  202. ads/feature_engineering/accessor/mixin/correlation.py +166 -0
  203. ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
  204. ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
  205. ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
  206. ads/feature_engineering/accessor/mixin/utils.py +65 -0
  207. ads/feature_engineering/accessor/series_accessor.py +431 -0
  208. ads/feature_engineering/adsimage/__init__.py +5 -0
  209. ads/feature_engineering/adsimage/image.py +192 -0
  210. ads/feature_engineering/adsimage/image_reader.py +170 -0
  211. ads/feature_engineering/adsimage/interface/__init__.py +5 -0
  212. ads/feature_engineering/adsimage/interface/reader.py +19 -0
  213. ads/feature_engineering/adsstring/__init__.py +7 -0
  214. ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
  215. ads/feature_engineering/adsstring/string/__init__.py +8 -0
  216. ads/feature_engineering/data_schema.json +57 -0
  217. ads/feature_engineering/dataset/__init__.py +5 -0
  218. ads/feature_engineering/dataset/zip_code_data.py +42062 -0
  219. ads/feature_engineering/exceptions.py +40 -0
  220. ads/feature_engineering/feature_type/__init__.py +133 -0
  221. ads/feature_engineering/feature_type/address.py +184 -0
  222. ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
  223. ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
  224. ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
  225. ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
  226. ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
  227. ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
  228. ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
  229. ads/feature_engineering/feature_type/adsstring/string.py +258 -0
  230. ads/feature_engineering/feature_type/base.py +58 -0
  231. ads/feature_engineering/feature_type/boolean.py +183 -0
  232. ads/feature_engineering/feature_type/category.py +146 -0
  233. ads/feature_engineering/feature_type/constant.py +137 -0
  234. ads/feature_engineering/feature_type/continuous.py +151 -0
  235. ads/feature_engineering/feature_type/creditcard.py +314 -0
  236. ads/feature_engineering/feature_type/datetime.py +190 -0
  237. ads/feature_engineering/feature_type/discrete.py +134 -0
  238. ads/feature_engineering/feature_type/document.py +43 -0
  239. ads/feature_engineering/feature_type/gis.py +251 -0
  240. ads/feature_engineering/feature_type/handler/__init__.py +5 -0
  241. ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
  242. ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
  243. ads/feature_engineering/feature_type/handler/warnings.py +128 -0
  244. ads/feature_engineering/feature_type/integer.py +142 -0
  245. ads/feature_engineering/feature_type/ip_address.py +144 -0
  246. ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
  247. ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
  248. ads/feature_engineering/feature_type/lat_long.py +256 -0
  249. ads/feature_engineering/feature_type/object.py +43 -0
  250. ads/feature_engineering/feature_type/ordinal.py +132 -0
  251. ads/feature_engineering/feature_type/phone_number.py +135 -0
  252. ads/feature_engineering/feature_type/string.py +171 -0
  253. ads/feature_engineering/feature_type/text.py +93 -0
  254. ads/feature_engineering/feature_type/unknown.py +43 -0
  255. ads/feature_engineering/feature_type/zip_code.py +164 -0
  256. ads/feature_engineering/feature_type_manager.py +406 -0
  257. ads/feature_engineering/schema.py +795 -0
  258. ads/feature_engineering/utils.py +245 -0
  259. ads/feature_store/.readthedocs.yaml +19 -0
  260. ads/feature_store/README.md +65 -0
  261. ads/feature_store/__init__.py +9 -0
  262. ads/feature_store/common/__init__.py +0 -0
  263. ads/feature_store/common/enums.py +339 -0
  264. ads/feature_store/common/exceptions.py +18 -0
  265. ads/feature_store/common/spark_session_singleton.py +125 -0
  266. ads/feature_store/common/utils/__init__.py +0 -0
  267. ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
  268. ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
  269. ads/feature_store/common/utils/transformation_utils.py +82 -0
  270. ads/feature_store/common/utils/utility.py +403 -0
  271. ads/feature_store/data_validation/__init__.py +0 -0
  272. ads/feature_store/data_validation/great_expectation.py +129 -0
  273. ads/feature_store/dataset.py +1230 -0
  274. ads/feature_store/dataset_job.py +530 -0
  275. ads/feature_store/docs/Dockerfile +7 -0
  276. ads/feature_store/docs/Makefile +44 -0
  277. ads/feature_store/docs/conf.py +28 -0
  278. ads/feature_store/docs/requirements.txt +14 -0
  279. ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
  280. ads/feature_store/docs/source/cicd.rst +137 -0
  281. ads/feature_store/docs/source/conf.py +86 -0
  282. ads/feature_store/docs/source/data_versioning.rst +33 -0
  283. ads/feature_store/docs/source/dataset.rst +388 -0
  284. ads/feature_store/docs/source/dataset_job.rst +27 -0
  285. ads/feature_store/docs/source/demo.rst +70 -0
  286. ads/feature_store/docs/source/entity.rst +78 -0
  287. ads/feature_store/docs/source/feature_group.rst +624 -0
  288. ads/feature_store/docs/source/feature_group_job.rst +29 -0
  289. ads/feature_store/docs/source/feature_store.rst +122 -0
  290. ads/feature_store/docs/source/feature_store_class.rst +123 -0
  291. ads/feature_store/docs/source/feature_validation.rst +66 -0
  292. ads/feature_store/docs/source/figures/cicd.png +0 -0
  293. ads/feature_store/docs/source/figures/data_validation.png +0 -0
  294. ads/feature_store/docs/source/figures/data_versioning.png +0 -0
  295. ads/feature_store/docs/source/figures/dataset.gif +0 -0
  296. ads/feature_store/docs/source/figures/dataset.png +0 -0
  297. ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
  298. ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
  299. ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
  300. ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
  301. ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
  302. ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
  303. ads/feature_store/docs/source/figures/entity.png +0 -0
  304. ads/feature_store/docs/source/figures/feature_group.png +0 -0
  305. ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
  306. ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
  307. ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
  308. ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
  309. ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
  310. ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
  311. ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
  312. ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
  313. ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
  314. ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
  315. ads/feature_store/docs/source/figures/overview.png +0 -0
  316. ads/feature_store/docs/source/figures/resource_manager.png +0 -0
  317. ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
  318. ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
  319. ads/feature_store/docs/source/figures/stats_1.png +0 -0
  320. ads/feature_store/docs/source/figures/stats_2.png +0 -0
  321. ads/feature_store/docs/source/figures/stats_d.png +0 -0
  322. ads/feature_store/docs/source/figures/stats_fg.png +0 -0
  323. ads/feature_store/docs/source/figures/transformation.png +0 -0
  324. ads/feature_store/docs/source/figures/transformations.gif +0 -0
  325. ads/feature_store/docs/source/figures/validation.png +0 -0
  326. ads/feature_store/docs/source/figures/validation_fg.png +0 -0
  327. ads/feature_store/docs/source/figures/validation_results.png +0 -0
  328. ads/feature_store/docs/source/figures/validation_summary.png +0 -0
  329. ads/feature_store/docs/source/index.rst +81 -0
  330. ads/feature_store/docs/source/module.rst +8 -0
  331. ads/feature_store/docs/source/notebook.rst +94 -0
  332. ads/feature_store/docs/source/overview.rst +47 -0
  333. ads/feature_store/docs/source/quickstart.rst +176 -0
  334. ads/feature_store/docs/source/release_notes.rst +194 -0
  335. ads/feature_store/docs/source/setup_feature_store.rst +81 -0
  336. ads/feature_store/docs/source/statistics.rst +58 -0
  337. ads/feature_store/docs/source/transformation.rst +199 -0
  338. ads/feature_store/docs/source/ui.rst +65 -0
  339. ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
  340. ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
  341. ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
  342. ads/feature_store/entity.py +718 -0
  343. ads/feature_store/execution_strategy/__init__.py +0 -0
  344. ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
  345. ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
  346. ads/feature_store/execution_strategy/engine/__init__.py +0 -0
  347. ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
  348. ads/feature_store/execution_strategy/execution_strategy.py +113 -0
  349. ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
  350. ads/feature_store/execution_strategy/spark/__init__.py +0 -0
  351. ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
  352. ads/feature_store/feature.py +192 -0
  353. ads/feature_store/feature_group.py +1494 -0
  354. ads/feature_store/feature_group_expectation.py +346 -0
  355. ads/feature_store/feature_group_job.py +602 -0
  356. ads/feature_store/feature_lineage/__init__.py +0 -0
  357. ads/feature_store/feature_lineage/graphviz_service.py +180 -0
  358. ads/feature_store/feature_option_details.py +50 -0
  359. ads/feature_store/feature_statistics/__init__.py +0 -0
  360. ads/feature_store/feature_statistics/statistics_service.py +99 -0
  361. ads/feature_store/feature_store.py +699 -0
  362. ads/feature_store/feature_store_registrar.py +518 -0
  363. ads/feature_store/input_feature_detail.py +149 -0
  364. ads/feature_store/mixin/__init__.py +4 -0
  365. ads/feature_store/mixin/oci_feature_store.py +145 -0
  366. ads/feature_store/model_details.py +73 -0
  367. ads/feature_store/query/__init__.py +0 -0
  368. ads/feature_store/query/filter.py +266 -0
  369. ads/feature_store/query/generator/__init__.py +0 -0
  370. ads/feature_store/query/generator/query_generator.py +298 -0
  371. ads/feature_store/query/join.py +161 -0
  372. ads/feature_store/query/query.py +403 -0
  373. ads/feature_store/query/validator/__init__.py +0 -0
  374. ads/feature_store/query/validator/query_validator.py +57 -0
  375. ads/feature_store/response/__init__.py +0 -0
  376. ads/feature_store/response/response_builder.py +68 -0
  377. ads/feature_store/service/__init__.py +0 -0
  378. ads/feature_store/service/oci_dataset.py +139 -0
  379. ads/feature_store/service/oci_dataset_job.py +199 -0
  380. ads/feature_store/service/oci_entity.py +125 -0
  381. ads/feature_store/service/oci_feature_group.py +164 -0
  382. ads/feature_store/service/oci_feature_group_job.py +214 -0
  383. ads/feature_store/service/oci_feature_store.py +182 -0
  384. ads/feature_store/service/oci_lineage.py +87 -0
  385. ads/feature_store/service/oci_transformation.py +104 -0
  386. ads/feature_store/statistics/__init__.py +0 -0
  387. ads/feature_store/statistics/abs_feature_value.py +49 -0
  388. ads/feature_store/statistics/charts/__init__.py +0 -0
  389. ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
  390. ads/feature_store/statistics/charts/box_plot.py +148 -0
  391. ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
  392. ads/feature_store/statistics/charts/probability_distribution.py +68 -0
  393. ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
  394. ads/feature_store/statistics/feature_stat.py +126 -0
  395. ads/feature_store/statistics/generic_feature_value.py +33 -0
  396. ads/feature_store/statistics/statistics.py +41 -0
  397. ads/feature_store/statistics_config.py +101 -0
  398. ads/feature_store/templates/feature_store_template.yaml +45 -0
  399. ads/feature_store/transformation.py +499 -0
  400. ads/feature_store/validation_output.py +57 -0
  401. ads/hpo/__init__.py +9 -0
  402. ads/hpo/_imports.py +91 -0
  403. ads/hpo/ads_search_space.py +439 -0
  404. ads/hpo/distributions.py +325 -0
  405. ads/hpo/objective.py +280 -0
  406. ads/hpo/search_cv.py +1657 -0
  407. ads/hpo/stopping_criterion.py +75 -0
  408. ads/hpo/tuner_artifact.py +413 -0
  409. ads/hpo/utils.py +91 -0
  410. ads/hpo/validation.py +140 -0
  411. ads/hpo/visualization/__init__.py +5 -0
  412. ads/hpo/visualization/_contour.py +23 -0
  413. ads/hpo/visualization/_edf.py +20 -0
  414. ads/hpo/visualization/_intermediate_values.py +21 -0
  415. ads/hpo/visualization/_optimization_history.py +25 -0
  416. ads/hpo/visualization/_parallel_coordinate.py +169 -0
  417. ads/hpo/visualization/_param_importances.py +26 -0
  418. ads/jobs/__init__.py +53 -0
  419. ads/jobs/ads_job.py +663 -0
  420. ads/jobs/builders/__init__.py +5 -0
  421. ads/jobs/builders/base.py +156 -0
  422. ads/jobs/builders/infrastructure/__init__.py +6 -0
  423. ads/jobs/builders/infrastructure/base.py +165 -0
  424. ads/jobs/builders/infrastructure/dataflow.py +1252 -0
  425. ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
  426. ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
  427. ads/jobs/builders/infrastructure/utils.py +65 -0
  428. ads/jobs/builders/runtimes/__init__.py +5 -0
  429. ads/jobs/builders/runtimes/artifact.py +338 -0
  430. ads/jobs/builders/runtimes/base.py +325 -0
  431. ads/jobs/builders/runtimes/container_runtime.py +242 -0
  432. ads/jobs/builders/runtimes/python_runtime.py +1016 -0
  433. ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
  434. ads/jobs/cli.py +104 -0
  435. ads/jobs/env_var_parser.py +131 -0
  436. ads/jobs/extension.py +160 -0
  437. ads/jobs/schema/__init__.py +5 -0
  438. ads/jobs/schema/infrastructure_schema.json +116 -0
  439. ads/jobs/schema/job_schema.json +42 -0
  440. ads/jobs/schema/runtime_schema.json +183 -0
  441. ads/jobs/schema/validator.py +141 -0
  442. ads/jobs/serializer.py +296 -0
  443. ads/jobs/templates/__init__.py +5 -0
  444. ads/jobs/templates/container.py +6 -0
  445. ads/jobs/templates/driver_notebook.py +177 -0
  446. ads/jobs/templates/driver_oci.py +500 -0
  447. ads/jobs/templates/driver_python.py +48 -0
  448. ads/jobs/templates/driver_pytorch.py +852 -0
  449. ads/jobs/templates/driver_utils.py +615 -0
  450. ads/jobs/templates/hostname_from_env.c +55 -0
  451. ads/jobs/templates/oci_metrics.py +181 -0
  452. ads/jobs/utils.py +104 -0
  453. ads/llm/__init__.py +28 -0
  454. ads/llm/autogen/__init__.py +2 -0
  455. ads/llm/autogen/constants.py +15 -0
  456. ads/llm/autogen/reports/__init__.py +2 -0
  457. ads/llm/autogen/reports/base.py +67 -0
  458. ads/llm/autogen/reports/data.py +103 -0
  459. ads/llm/autogen/reports/session.py +526 -0
  460. ads/llm/autogen/reports/templates/chat_box.html +13 -0
  461. ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
  462. ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
  463. ads/llm/autogen/reports/utils.py +56 -0
  464. ads/llm/autogen/v02/__init__.py +4 -0
  465. ads/llm/autogen/v02/client.py +295 -0
  466. ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
  467. ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
  468. ads/llm/autogen/v02/loggers/__init__.py +6 -0
  469. ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
  470. ads/llm/autogen/v02/loggers/session_logger.py +580 -0
  471. ads/llm/autogen/v02/loggers/utils.py +86 -0
  472. ads/llm/autogen/v02/runtime_logging.py +163 -0
  473. ads/llm/chain.py +268 -0
  474. ads/llm/chat_template.py +31 -0
  475. ads/llm/deploy.py +63 -0
  476. ads/llm/guardrails/__init__.py +5 -0
  477. ads/llm/guardrails/base.py +442 -0
  478. ads/llm/guardrails/huggingface.py +44 -0
  479. ads/llm/langchain/__init__.py +5 -0
  480. ads/llm/langchain/plugins/__init__.py +5 -0
  481. ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
  482. ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
  483. ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
  484. ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
  485. ads/llm/langchain/plugins/llms/__init__.py +5 -0
  486. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
  487. ads/llm/requirements.txt +3 -0
  488. ads/llm/serialize.py +219 -0
  489. ads/llm/serializers/__init__.py +0 -0
  490. ads/llm/serializers/retrieval_qa.py +153 -0
  491. ads/llm/serializers/runnable_parallel.py +27 -0
  492. ads/llm/templates/score_chain.jinja2 +155 -0
  493. ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
  494. ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
  495. ads/model/__init__.py +52 -0
  496. ads/model/artifact.py +573 -0
  497. ads/model/artifact_downloader.py +254 -0
  498. ads/model/artifact_uploader.py +267 -0
  499. ads/model/base_properties.py +238 -0
  500. ads/model/common/.model-ignore +66 -0
  501. ads/model/common/__init__.py +5 -0
  502. ads/model/common/utils.py +142 -0
  503. ads/model/datascience_model.py +2635 -0
  504. ads/model/deployment/__init__.py +20 -0
  505. ads/model/deployment/common/__init__.py +5 -0
  506. ads/model/deployment/common/utils.py +308 -0
  507. ads/model/deployment/model_deployer.py +466 -0
  508. ads/model/deployment/model_deployment.py +1846 -0
  509. ads/model/deployment/model_deployment_infrastructure.py +671 -0
  510. ads/model/deployment/model_deployment_properties.py +493 -0
  511. ads/model/deployment/model_deployment_runtime.py +838 -0
  512. ads/model/extractor/__init__.py +5 -0
  513. ads/model/extractor/automl_extractor.py +74 -0
  514. ads/model/extractor/embedding_onnx_extractor.py +80 -0
  515. ads/model/extractor/huggingface_extractor.py +88 -0
  516. ads/model/extractor/keras_extractor.py +84 -0
  517. ads/model/extractor/lightgbm_extractor.py +93 -0
  518. ads/model/extractor/model_info_extractor.py +114 -0
  519. ads/model/extractor/model_info_extractor_factory.py +105 -0
  520. ads/model/extractor/pytorch_extractor.py +87 -0
  521. ads/model/extractor/sklearn_extractor.py +112 -0
  522. ads/model/extractor/spark_extractor.py +89 -0
  523. ads/model/extractor/tensorflow_extractor.py +85 -0
  524. ads/model/extractor/xgboost_extractor.py +94 -0
  525. ads/model/framework/__init__.py +5 -0
  526. ads/model/framework/automl_model.py +178 -0
  527. ads/model/framework/embedding_onnx_model.py +438 -0
  528. ads/model/framework/huggingface_model.py +399 -0
  529. ads/model/framework/lightgbm_model.py +266 -0
  530. ads/model/framework/pytorch_model.py +266 -0
  531. ads/model/framework/sklearn_model.py +250 -0
  532. ads/model/framework/spark_model.py +326 -0
  533. ads/model/framework/tensorflow_model.py +254 -0
  534. ads/model/framework/xgboost_model.py +258 -0
  535. ads/model/generic_model.py +3518 -0
  536. ads/model/model_artifact_boilerplate/README.md +381 -0
  537. ads/model/model_artifact_boilerplate/__init__.py +5 -0
  538. ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
  539. ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
  540. ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
  541. ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
  542. ads/model/model_artifact_boilerplate/score.py +61 -0
  543. ads/model/model_file_description_schema.json +68 -0
  544. ads/model/model_introspect.py +331 -0
  545. ads/model/model_metadata.py +1810 -0
  546. ads/model/model_metadata_mixin.py +460 -0
  547. ads/model/model_properties.py +63 -0
  548. ads/model/model_version_set.py +739 -0
  549. ads/model/runtime/__init__.py +5 -0
  550. ads/model/runtime/env_info.py +306 -0
  551. ads/model/runtime/model_deployment_details.py +37 -0
  552. ads/model/runtime/model_provenance_details.py +58 -0
  553. ads/model/runtime/runtime_info.py +81 -0
  554. ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
  555. ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
  556. ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
  557. ads/model/runtime/utils.py +201 -0
  558. ads/model/serde/__init__.py +5 -0
  559. ads/model/serde/common.py +40 -0
  560. ads/model/serde/model_input.py +547 -0
  561. ads/model/serde/model_serializer.py +1184 -0
  562. ads/model/service/__init__.py +5 -0
  563. ads/model/service/oci_datascience_model.py +1076 -0
  564. ads/model/service/oci_datascience_model_deployment.py +500 -0
  565. ads/model/service/oci_datascience_model_version_set.py +176 -0
  566. ads/model/transformer/__init__.py +5 -0
  567. ads/model/transformer/onnx_transformer.py +324 -0
  568. ads/mysqldb/__init__.py +5 -0
  569. ads/mysqldb/mysql_db.py +227 -0
  570. ads/opctl/__init__.py +18 -0
  571. ads/opctl/anomaly_detection.py +11 -0
  572. ads/opctl/backend/__init__.py +5 -0
  573. ads/opctl/backend/ads_dataflow.py +353 -0
  574. ads/opctl/backend/ads_ml_job.py +710 -0
  575. ads/opctl/backend/ads_ml_pipeline.py +164 -0
  576. ads/opctl/backend/ads_model_deployment.py +209 -0
  577. ads/opctl/backend/base.py +146 -0
  578. ads/opctl/backend/local.py +1053 -0
  579. ads/opctl/backend/marketplace/__init__.py +9 -0
  580. ads/opctl/backend/marketplace/helm_helper.py +173 -0
  581. ads/opctl/backend/marketplace/local_marketplace.py +271 -0
  582. ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
  583. ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
  584. ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
  585. ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
  586. ads/opctl/backend/marketplace/models/__init__.py +5 -0
  587. ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
  588. ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
  589. ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
  590. ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
  591. ads/opctl/cli.py +707 -0
  592. ads/opctl/cmds.py +869 -0
  593. ads/opctl/conda/__init__.py +5 -0
  594. ads/opctl/conda/cli.py +193 -0
  595. ads/opctl/conda/cmds.py +749 -0
  596. ads/opctl/conda/config.yaml +34 -0
  597. ads/opctl/conda/manifest_template.yaml +13 -0
  598. ads/opctl/conda/multipart_uploader.py +188 -0
  599. ads/opctl/conda/pack.py +89 -0
  600. ads/opctl/config/__init__.py +5 -0
  601. ads/opctl/config/base.py +57 -0
  602. ads/opctl/config/diagnostics/__init__.py +5 -0
  603. ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
  604. ads/opctl/config/merger.py +255 -0
  605. ads/opctl/config/resolver.py +297 -0
  606. ads/opctl/config/utils.py +79 -0
  607. ads/opctl/config/validator.py +17 -0
  608. ads/opctl/config/versioner.py +68 -0
  609. ads/opctl/config/yaml_parsers/__init__.py +7 -0
  610. ads/opctl/config/yaml_parsers/base.py +58 -0
  611. ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
  612. ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
  613. ads/opctl/constants.py +66 -0
  614. ads/opctl/decorator/__init__.py +5 -0
  615. ads/opctl/decorator/common.py +129 -0
  616. ads/opctl/diagnostics/__init__.py +5 -0
  617. ads/opctl/diagnostics/__main__.py +25 -0
  618. ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
  619. ads/opctl/diagnostics/check_requirements.py +144 -0
  620. ads/opctl/diagnostics/requirement_exception.py +9 -0
  621. ads/opctl/distributed/README.md +109 -0
  622. ads/opctl/distributed/__init__.py +5 -0
  623. ads/opctl/distributed/certificates.py +32 -0
  624. ads/opctl/distributed/cli.py +207 -0
  625. ads/opctl/distributed/cmds.py +731 -0
  626. ads/opctl/distributed/common/__init__.py +5 -0
  627. ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
  628. ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
  629. ads/opctl/distributed/common/cluster_config_helper.py +103 -0
  630. ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
  631. ads/opctl/distributed/common/cluster_runner.py +54 -0
  632. ads/opctl/distributed/common/framework_factory.py +29 -0
  633. ads/opctl/docker/Dockerfile.job +103 -0
  634. ads/opctl/docker/Dockerfile.job.arm +107 -0
  635. ads/opctl/docker/Dockerfile.job.gpu +175 -0
  636. ads/opctl/docker/base-env.yaml +13 -0
  637. ads/opctl/docker/cuda.repo +6 -0
  638. ads/opctl/docker/operator/.dockerignore +0 -0
  639. ads/opctl/docker/operator/Dockerfile +41 -0
  640. ads/opctl/docker/operator/Dockerfile.gpu +85 -0
  641. ads/opctl/docker/operator/cuda.repo +6 -0
  642. ads/opctl/docker/operator/environment.yaml +8 -0
  643. ads/opctl/forecast.py +11 -0
  644. ads/opctl/index.yaml +3 -0
  645. ads/opctl/model/__init__.py +5 -0
  646. ads/opctl/model/cli.py +65 -0
  647. ads/opctl/model/cmds.py +73 -0
  648. ads/opctl/operator/README.md +4 -0
  649. ads/opctl/operator/__init__.py +31 -0
  650. ads/opctl/operator/cli.py +344 -0
  651. ads/opctl/operator/cmd.py +596 -0
  652. ads/opctl/operator/common/__init__.py +5 -0
  653. ads/opctl/operator/common/backend_factory.py +460 -0
  654. ads/opctl/operator/common/const.py +27 -0
  655. ads/opctl/operator/common/data/synthetic.csv +16001 -0
  656. ads/opctl/operator/common/dictionary_merger.py +148 -0
  657. ads/opctl/operator/common/errors.py +42 -0
  658. ads/opctl/operator/common/operator_config.py +99 -0
  659. ads/opctl/operator/common/operator_loader.py +811 -0
  660. ads/opctl/operator/common/operator_schema.yaml +130 -0
  661. ads/opctl/operator/common/operator_yaml_generator.py +152 -0
  662. ads/opctl/operator/common/utils.py +208 -0
  663. ads/opctl/operator/lowcode/__init__.py +5 -0
  664. ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
  665. ads/opctl/operator/lowcode/anomaly/README.md +207 -0
  666. ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
  667. ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
  668. ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
  669. ads/opctl/operator/lowcode/anomaly/const.py +167 -0
  670. ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
  671. ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
  672. ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
  673. ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
  674. ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
  675. ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
  676. ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
  677. ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
  678. ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
  679. ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
  680. ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
  681. ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
  682. ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
  683. ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
  684. ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
  685. ads/opctl/operator/lowcode/common/__init__.py +5 -0
  686. ads/opctl/operator/lowcode/common/const.py +10 -0
  687. ads/opctl/operator/lowcode/common/data.py +116 -0
  688. ads/opctl/operator/lowcode/common/errors.py +47 -0
  689. ads/opctl/operator/lowcode/common/transformations.py +296 -0
  690. ads/opctl/operator/lowcode/common/utils.py +384 -0
  691. ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
  692. ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
  693. ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
  694. ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
  695. ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
  696. ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
  697. ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
  698. ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
  699. ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
  700. ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
  701. ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
  702. ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
  703. ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
  704. ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
  705. ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
  706. ads/opctl/operator/lowcode/forecast/README.md +209 -0
  707. ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
  708. ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
  709. ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
  710. ads/opctl/operator/lowcode/forecast/const.py +92 -0
  711. ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
  712. ads/opctl/operator/lowcode/forecast/errors.py +26 -0
  713. ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
  714. ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
  715. ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
  716. ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
  717. ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
  718. ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
  719. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
  720. ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
  721. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
  722. ads/opctl/operator/lowcode/forecast/model/prophet.py +450 -0
  723. ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
  724. ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
  725. ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
  726. ads/opctl/operator/lowcode/forecast/utils.py +397 -0
  727. ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
  728. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
  729. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
  730. ads/opctl/operator/lowcode/pii/MLoperator +17 -0
  731. ads/opctl/operator/lowcode/pii/README.md +208 -0
  732. ads/opctl/operator/lowcode/pii/__init__.py +5 -0
  733. ads/opctl/operator/lowcode/pii/__main__.py +78 -0
  734. ads/opctl/operator/lowcode/pii/cmd.py +39 -0
  735. ads/opctl/operator/lowcode/pii/constant.py +84 -0
  736. ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
  737. ads/opctl/operator/lowcode/pii/errors.py +27 -0
  738. ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
  739. ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
  740. ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
  741. ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
  742. ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
  743. ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
  744. ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
  745. ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
  746. ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
  747. ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
  748. ads/opctl/operator/lowcode/pii/model/report.py +487 -0
  749. ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
  750. ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
  751. ads/opctl/operator/lowcode/pii/utils.py +43 -0
  752. ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
  753. ads/opctl/operator/lowcode/recommender/README.md +206 -0
  754. ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
  755. ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
  756. ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
  757. ads/opctl/operator/lowcode/recommender/constant.py +30 -0
  758. ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
  759. ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
  760. ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
  761. ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
  762. ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
  763. ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
  764. ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
  765. ads/opctl/operator/lowcode/recommender/utils.py +13 -0
  766. ads/opctl/operator/runtime/__init__.py +5 -0
  767. ads/opctl/operator/runtime/const.py +17 -0
  768. ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
  769. ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
  770. ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
  771. ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
  772. ads/opctl/operator/runtime/runtime.py +115 -0
  773. ads/opctl/schema.yaml.yml +36 -0
  774. ads/opctl/script.py +40 -0
  775. ads/opctl/spark/__init__.py +5 -0
  776. ads/opctl/spark/cli.py +43 -0
  777. ads/opctl/spark/cmds.py +147 -0
  778. ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
  779. ads/opctl/utils.py +344 -0
  780. ads/oracledb/__init__.py +5 -0
  781. ads/oracledb/oracle_db.py +346 -0
  782. ads/pipeline/__init__.py +39 -0
  783. ads/pipeline/ads_pipeline.py +2279 -0
  784. ads/pipeline/ads_pipeline_run.py +772 -0
  785. ads/pipeline/ads_pipeline_step.py +605 -0
  786. ads/pipeline/builders/__init__.py +5 -0
  787. ads/pipeline/builders/infrastructure/__init__.py +5 -0
  788. ads/pipeline/builders/infrastructure/custom_script.py +32 -0
  789. ads/pipeline/cli.py +119 -0
  790. ads/pipeline/extension.py +291 -0
  791. ads/pipeline/schema/__init__.py +5 -0
  792. ads/pipeline/schema/cs_step_schema.json +35 -0
  793. ads/pipeline/schema/ml_step_schema.json +31 -0
  794. ads/pipeline/schema/pipeline_schema.json +71 -0
  795. ads/pipeline/visualizer/__init__.py +5 -0
  796. ads/pipeline/visualizer/base.py +570 -0
  797. ads/pipeline/visualizer/graph_renderer.py +272 -0
  798. ads/pipeline/visualizer/text_renderer.py +84 -0
  799. ads/secrets/__init__.py +11 -0
  800. ads/secrets/adb.py +386 -0
  801. ads/secrets/auth_token.py +86 -0
  802. ads/secrets/big_data_service.py +365 -0
  803. ads/secrets/mysqldb.py +149 -0
  804. ads/secrets/oracledb.py +160 -0
  805. ads/secrets/secrets.py +407 -0
  806. ads/telemetry/__init__.py +7 -0
  807. ads/telemetry/base.py +69 -0
  808. ads/telemetry/client.py +122 -0
  809. ads/telemetry/telemetry.py +257 -0
  810. ads/templates/dataflow_pyspark.jinja2 +13 -0
  811. ads/templates/dataflow_sparksql.jinja2 +22 -0
  812. ads/templates/func.jinja2 +20 -0
  813. ads/templates/schemas/openapi.json +1740 -0
  814. ads/templates/score-pkl.jinja2 +173 -0
  815. ads/templates/score.jinja2 +322 -0
  816. ads/templates/score_embedding_onnx.jinja2 +202 -0
  817. ads/templates/score_generic.jinja2 +165 -0
  818. ads/templates/score_huggingface_pipeline.jinja2 +217 -0
  819. ads/templates/score_lightgbm.jinja2 +185 -0
  820. ads/templates/score_onnx.jinja2 +407 -0
  821. ads/templates/score_onnx_new.jinja2 +473 -0
  822. ads/templates/score_oracle_automl.jinja2 +185 -0
  823. ads/templates/score_pyspark.jinja2 +154 -0
  824. ads/templates/score_pytorch.jinja2 +219 -0
  825. ads/templates/score_scikit-learn.jinja2 +184 -0
  826. ads/templates/score_tensorflow.jinja2 +184 -0
  827. ads/templates/score_xgboost.jinja2 +178 -0
  828. ads/text_dataset/__init__.py +5 -0
  829. ads/text_dataset/backends.py +211 -0
  830. ads/text_dataset/dataset.py +445 -0
  831. ads/text_dataset/extractor.py +207 -0
  832. ads/text_dataset/options.py +53 -0
  833. ads/text_dataset/udfs.py +22 -0
  834. ads/text_dataset/utils.py +49 -0
  835. ads/type_discovery/__init__.py +9 -0
  836. ads/type_discovery/abstract_detector.py +21 -0
  837. ads/type_discovery/constant_detector.py +41 -0
  838. ads/type_discovery/continuous_detector.py +54 -0
  839. ads/type_discovery/credit_card_detector.py +99 -0
  840. ads/type_discovery/datetime_detector.py +92 -0
  841. ads/type_discovery/discrete_detector.py +118 -0
  842. ads/type_discovery/document_detector.py +146 -0
  843. ads/type_discovery/ip_detector.py +68 -0
  844. ads/type_discovery/latlon_detector.py +90 -0
  845. ads/type_discovery/phone_number_detector.py +63 -0
  846. ads/type_discovery/type_discovery_driver.py +87 -0
  847. ads/type_discovery/typed_feature.py +594 -0
  848. ads/type_discovery/unknown_detector.py +41 -0
  849. ads/type_discovery/zipcode_detector.py +48 -0
  850. ads/vault/__init__.py +7 -0
  851. ads/vault/vault.py +237 -0
  852. {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/METADATA +150 -149
  853. oracle_ads-2.13.10rc0.dist-info/RECORD +858 -0
  854. {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/WHEEL +1 -2
  855. {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/entry_points.txt +2 -1
  856. oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
  857. oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
  858. {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/licenses/LICENSE.txt +0 -0
ads/hpo/search_cv.py ADDED
@@ -0,0 +1,1657 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*--
3
+
4
+ # Copyright (c) 2020, 2023 Oracle and/or its affiliates.
5
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
+
7
+ import importlib
8
+ import multiprocessing
9
+ import os
10
+ import uuid
11
+ import psutil
12
+ from enum import Enum, auto
13
+ from time import time, sleep
14
+
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+ import pandas as pd
18
+
19
+ import logging
20
+ from ads.common import logger
21
+ from ads.common import utils
22
+ from ads.common.data import ADSData
23
+ from ads.common.decorator.runtime_dependency import (
24
+ runtime_dependency,
25
+ OptionalDependency,
26
+ )
27
+ from ads.hpo._imports import try_import
28
+ from ads.hpo.ads_search_space import get_model2searchspace
29
+ from ads.hpo.distributions import *
30
+ from ads.hpo.objective import _Objective
31
+ from ads.hpo.stopping_criterion import NTrials, ScoreValue, TimeBudget
32
+ from ads.hpo.utils import _num_samples, _safe_indexing, _update_space_name
33
+ from ads.hpo.validation import (
34
+ assert_is_estimator,
35
+ assert_model_is_supported,
36
+ assert_strategy_valid,
37
+ assert_tuner_is_fitted,
38
+ validate_fit_params,
39
+ validate_pipeline,
40
+ validate_search_space,
41
+ validate_params_for_plot,
42
+ )
43
+
44
+
45
+ with try_import() as _imports:
46
+ from sklearn.base import BaseEstimator, clone, is_classifier
47
+ from sklearn.model_selection import BaseCrossValidator # NOQA
48
+ from sklearn.model_selection import check_cv, cross_validate
49
+ from sklearn.pipeline import Pipeline, make_pipeline
50
+ from sklearn.utils import check_random_state
51
+ from sklearn.exceptions import NotFittedError
52
+
53
+ try:
54
+ from sklearn.metrics import check_scoring
55
+ except:
56
+ from sklearn.metrics.scorer import check_scoring
57
+
58
+
59
+ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union # NOQA
60
+
61
+
62
class State(Enum):
    """Status values a tuner can be in (``ADSTuner`` starts in ``INITIATED``)."""

    INITIATED = auto()
    RUNNING = auto()
    HALTED = auto()
    TERMINATED = auto()
    COMPLETED = auto()
68
+
69
+
70
class InvalidStateTransition(Exception):  # pragma: no cover
    """
    Raised when an invalid transition request is made, such as calling
    halt without a running process.
    """
77
+
78
+
79
class ExitCriterionError(Exception):  # pragma: no cover
    """
    Raised when exit status is requested for an exit criterion other than
    the one the tuner was initialized with.

    For example, if an HPO study exits on the number of trials and a request
    is made for the time remaining (a different exit criterion), this
    exception is raised.
    """
88
+
89
+
90
class DuplicatedStudyError(Exception):  # pragma: no cover
    """
    Raised when a new tuner process is created with a study name that
    already exists in storage.
    """
95
+
96
+
97
class NoRestartError(Exception):  # pragma: no cover
    """
    Raised when the number of seconds since the HPO process was last resumed
    from a halt is requested but cannot be given.

    This can happen if the process has been terminated, or if it was never
    halted and then resumed to begin with.
    """
105
+
106
+
107
class DataScienceObjective:
    """Callable wrapper that binds an objective to its data.

    It stands in for the lambda used previously, because Python cannot pickle
    local or lambda functions.
    """

    def __init__(self, objective, X_res, y_res):
        self.objective, self.X_res, self.y_res = objective, X_res, y_res

    def __call__(self, trial):
        """Evaluate the wrapped objective on the stored data for ``trial``."""
        objective_fn = self.objective
        return objective_fn(self.X_res, self.y_res, trial)
117
+
118
+
119
+ class ADSTuner(BaseEstimator):
120
+ """
121
+ Hyperparameter search with cross-validation.
122
+ """
123
+
124
+ _required_parameters = ["model"]
125
+
126
@property
def sklearn_steps(self):
    """
    Returns
    -------
    Search space which corresponds to the best candidate parameter setting:
    ``best_params`` passed through ``_update_space_name`` with this tuner's
    step name.
    NOTE(review): the result is presumably a dict keyed by
    ``<step_name>__<param>`` -- confirm against ``ads.hpo.utils``.
    """
    best = self.best_params
    return _update_space_name(best, step_name=self._step_name)
135
+
136
@property
def best_index(self):
    """
    Returns
    -------
    Index label of the row in ``trials`` with the highest ``value``, i.e.
    the index which corresponds to the best candidate parameter setting.
    """
    scores = self.trials["value"]
    return scores.idxmax()
145
+
146
@property
def best_params(self):
    """
    Returns
    -------
    Dict[str, Any]
        Parameters of the best trial, with the step-name prefix removed.
    """
    self._check_is_fitted()
    raw_params = self._study.best_params
    return self._remove_step_name(raw_params)
156
+
157
@property
def best_score(self):
    """
    Returns
    -------
    float
        Mean cross-validated score of the best estimator, read from the
        study's ``best_value``.
    """
    self._check_is_fitted()
    study = self._study
    return study.best_value
167
+
168
@property
def score_remaining(self):
    """
    Returns
    -------
    float
        The optimal score minus the best score reached so far.

    Raises
    ------
    :class:`ExitCriterionError`
        Error is raised if there is no score-based criteria for tuning.
    """
    target = self._optimal_score
    if target is not None:
        return target - self.best_score
    raise ExitCriterionError("Tuner does not have a score-based exit condition.")
187
+
188
@property
def scoring_name(self):
    """
    Returns
    -------
    str
        Scoring name, as produced by ``_extract_scoring_name``.
    """
    name = self._extract_scoring_name()
    return name
197
+
198
@property
def n_trials(self):
    """
    Returns
    -------
    int
        Number of trials recorded in ``trials``. Alias for `trial_count`.
    """
    self._check_is_fitted()
    recorded = self.trials
    return len(recorded)

# Alias for n_trials
trial_count = n_trials
211
+
212
@property
def trials_remaining(self):
    """
    Returns
    -------
    int
        The number of trials remaining in the budget: the trial budget minus
        the current trial count, plus the previous trial count.

    Raises
    ------
    :class:`ExitCriterionError`
        Raised if the current tuner does not include a trials-based exit
        condition.
    """
    budget = self._n_trials
    if budget is None:
        raise ExitCriterionError(
            "This tuner does not include a trials-based exit condition"
        )
    return budget - self.n_trials + self._previous_trial_count
231
+
232
@property
def trials(self):
    """
    Returns
    -------
    :class:`pandas.DataFrame`
        Trial data up to this point. While the tuner is halted this is the
        stored trial dataframe (an empty frame if none is stored); otherwise
        it is a copy of the study's trials dataframe.
    """
    if not self.is_halted():
        return self._study.trials_dataframe().copy()
    snapshot = self._trial_dataframe
    return pd.DataFrame() if snapshot is None else snapshot
246
+
247
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def __init__(
    self,
    model,  # type: Union[BaseEstimator, Pipeline]
    strategy="perfunctory",  # type: Union[str, Mapping[str, optuna.distributions.BaseDistribution]]
    scoring=None,  # type: Optional[Union[Callable[..., float], str]]
    cv=5,  # type: Optional[int]
    study_name=None,  # type: Optional[str]
    storage=None,  # type: Optional[str]
    load_if_exists=True,  # type: Optional[bool]
    random_state=None,  # type: Optional[int]
    loglevel=logging.INFO,  # type: Optional[int]
    n_jobs=1,  # type: Optional[int]
    X=None,  # type: Union[List[List[float]], np.ndarray, pd.DataFrame, spmatrix, ADSData]
    y=None,  # type: Optional[Union[OneDimArrayLikeType, TwoDimArrayLikeType]]
):
    # type: (...) -> None
    """
    Returns a hyperparameter tuning object

    Parameters
    ----------
    model:
        Object to use to fit the data. This is assumed to implement the
        scikit-learn estimator or pipeline interface.
    strategy:
        ``perfunctory``, ``detailed`` or a dictionary/mapping of hyperparameter
        and its distribution. If :obj:`perfunctory`, picks a few
        relatively more important hyperparameters to tune. If :obj:`detailed`,
        extends to a larger search space. If :obj:`dict`, user defined search
        space: Dictionary where keys are hyperparameters and values are distributions.
        Distributions are assumed to implement the ads distribution interface.
    scoring: Optional[Union[Callable[..., float], str]]
        String or callable to evaluate the predictions on the validation data.
        If :obj:`None`, ``score`` on the estimator is used.
    cv: int
        Integer to specify the number of folds in a CV splitter. Must be
        greater than 1.
        If :obj:`estimator` is a classifier and :obj:`y` is
        either binary or multiclass,
        ``sklearn.model_selection.StratifiedKFold`` is used. Otherwise,
        ``sklearn.model_selection.KFold`` is used.
    study_name: str,
        Name of the current experiment for the ADSTuner object. One ADSTuner
        object can only be attached to one study_name. Defaults to
        ``hpo_<uuid4>``.
    storage:
        Database URL. (e.g. sqlite:///example.db). Default to sqlite:////tmp/hpo_*.db.
    load_if_exists:
        Flag to control the behavior to handle a conflict of study names.
        In the case where a study named ``study_name`` already exists in the ``storage``,
        a :class:`DuplicatedStudyError` is raised if ``load_if_exists`` is
        set to :obj:`False`.
        Otherwise, the existing one is returned.
    random_state:
        Seed of the pseudo random number generator. If int, this is the
        seed used by the random number generator. If :obj:`None`, the global random state from
        ``numpy.random`` is used.
    loglevel:
        loglevel. can be logging.NOTSET, logging.INFO, logging.DEBUG, logging.WARNING
    n_jobs: int
        Number of parallel jobs. :obj:`-1` means using all processors.
    X: TwoDimArrayLikeType, Union[List[List[float]], np.ndarray,
    pd.DataFrame, spmatrix, ADSData]
        Training data.
    y: Union[OneDimArrayLikeType, TwoDimArrayLikeType], optional
        OneDimArrayLikeType: Union[List[float], np.ndarray, pd.Series]
        TwoDimArrayLikeType: Union[List[List[float]], np.ndarray, pd.DataFrame, spmatrix, ADSData]
        Target.

    Raises
    ------
    AssertionError
        If ``cv`` is not greater than 1.
    :class:`DuplicatedStudyError`
        If the study already exists in ``storage`` and ``load_if_exists`` is
        :obj:`False`.

    Example::

        from ads.hpo.stopping_criterion import *
        from ads.hpo.search_cv import ADSTuner
        from sklearn.datasets import load_iris
        from sklearn.svm import SVC

        tuner = ADSTuner(
            SVC(),
            strategy='detailed',
            scoring='f1_weighted',
            random_state=42
        )

        X, y = load_iris(return_X_y=True)
        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
    """
    _imports.check()
    self._n_jobs = n_jobs
    assert (
        cv > 1
    ), "k-fold cross-validation requires at least one train/test split by setting cv=2 or more"
    self.cv = cv
    self._error_score = np.nan
    self.model = model
    self._check_pipeline()
    self._step_name = None
    self._extract_estimator()
    # strategy and _param_distributions must exist (as None) before
    # _check_strategy runs, because it reads both attributes.
    self.strategy = None
    self._param_distributions = None
    self._check_strategy(strategy)
    self.strategy = strategy
    self._param_distributions = self._get_param_distributions(self.strategy)
    self._enable_pruning = hasattr(self.model, "partial_fit")
    self._max_iter = 100
    self.__random_state = random_state  # to be used in export_trials
    # this calls the randomstate.setter which turns self.random_state into a np.random.RandomState instance
    # make it hard to be serialized.
    self.random_state = check_random_state(random_state)

    self._return_train_score = False
    self.scoring = scoring
    self._subsample = 1.0
    self.loglevel = loglevel
    self._trial_dataframe = None
    self._status = State.INITIATED
    self.study_name = (
        study_name if study_name is not None else "hpo_" + str(uuid.uuid4())
    )
    self.storage = (
        "sqlite:////tmp/hpo_" + str(uuid.uuid4()) + ".db"
        if storage is None
        else storage
    )
    self.oci_client = None

    # NOTE(review): the sampler seed is drawn from NumPy's global RNG, not
    # from ``random_state`` -- confirm this is intended.
    seed = np.random.randint(0, np.iinfo("int32").max)

    self.sampler = optuna.samplers.TPESampler(seed=seed)
    self.median_pruner = self._pruner(
        class_name="median_pruner",
        n_startup_trials=5,
        n_warmup_steps=1,
        interval_steps=1,
    )
    self.load_if_exists = load_if_exists
    try:
        self._study = optuna.study.create_study(
            study_name=self.study_name,
            direction="maximize",
            pruner=self.median_pruner,
            sampler=self.sampler,
            storage=self.storage,
            load_if_exists=self.load_if_exists,
        )
    except optuna.exceptions.DuplicatedStudyError as e:
        if self.load_if_exists:
            # NOTE(review): this branch only logs; ``self._study`` is not
            # assigned here.
            logger.info(
                "Using an existing study with name '{}' instead of "
                "creating a new one.".format(self.study_name)
            )
        else:
            # Chain the Optuna error so the original cause stays in the traceback.
            raise DuplicatedStudyError(
                f"The study_name `{self.study_name}` exists in the {self.storage}. Either set load_if_exists=True, or use a new study_name."
            ) from e
    self._init_data(X, y)
401
+
402
def search_space(self, strategy=None, overwrite=False):
    """
    Returns the search space. If strategy is not passed in, return the existing search
    space. When strategy is passed in, overwrite the existing search space if overwrite
    is set True, otherwise, only update the existing search space.

    Parameters
    ----------
    strategy: Union[str, dict], optional
        ``perfunctory``, ``detailed`` or a dictionary/mapping of the hyperparameters
        and their distributions. If :obj:`perfunctory`, picks a few relatively
        more important hyperparameters to tune. If :obj:`detailed`, extends to a
        larger search space. If :obj:`dict`, user defined search space: Dictionary
        where keys are parameters and values are distributions. Distributions are
        assumed to implement the ads distribution interface.
    overwrite: bool, optional
        Ignored when strategy is None. Otherwise, search space is overwritten if overwrite
        is set True and updated if it is False.

    Returns
    -------
    dict
        A mapping of the hyperparameters and their distributions.

    Example::

        from ads.hpo.stopping_criterion import *
        from ads.hpo.search_cv import ADSTuner
        from sklearn.datasets import load_iris
        from sklearn.linear_model import SGDClassifier

        tuner = ADSTuner(
            SGDClassifier(),
            strategy='detailed',
            scoring='f1_weighted',
            random_state=42
        )
        tuner.search_space({'max_iter': 100})
        X, y = load_iris(return_X_y=True)
        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
        tuner.search_space()
    """
    assert hasattr(
        self, "_param_distributions"
    ), "Call <code>ADSTuner</code> first."
    if strategy:
        self._check_strategy(strategy)
        self.strategy = strategy
        if overwrite:
            self._param_distributions = self._get_param_distributions(self.strategy)
        else:
            # Merge into the existing space rather than replacing it.
            additions = self._get_param_distributions(self.strategy)
            self._param_distributions.update(additions)
    return self._remove_step_name(self._param_distributions)
458
+
459
+ @staticmethod
460
+ def _remove_step_name(param_distributions):
461
+ search_space = {}
462
+ for param, distributions in param_distributions.items():
463
+ if "__" in param:
464
+ param = param.split("__")[1]
465
+ search_space[param] = distributions
466
+ return search_space
467
+
468
def _check_pipeline(self):
    """Run ``validate_pipeline`` on ``self.model`` and store its result back on ``self.model``."""
    validated = validate_pipeline(self.model)
    self.model = validated
470
+
471
def _get_internal_param_distributions(self, strategy):
    """Build the built-in search space for ``self.model`` under ``strategy``.

    For a ``Pipeline``, every step whose class has a registered search space
    is considered; the last matching step wins and its name is stored in
    ``self._step_name``. If no step matches, the search space is empty and a
    "Nothing to tune." warning is logged. For any other model, the model must
    be supported and its registered search space is used directly.

    Parameters
    ----------
    strategy: str
        Strategy name passed to the registered search-space class.

    Returns
    -------
    The parameter distributions suggested for the model, after they have been
    checked with ``_check_search_space``.
    """
    if isinstance(self.model, Pipeline):
        model2searchspace = get_model2searchspace()
        # Start empty so a pipeline without any tunable step reaches the
        # warning below instead of hitting an unbound local.
        # NOTE(review): assumes ``suggest_space`` returns a dict -- confirm.
        param_distributions = {}
        for step_name, step in self.model.steps:
            if step.__class__ in model2searchspace:
                self._step_name = step_name
                param_distributions = model2searchspace[step.__class__](
                    strategy
                ).suggest_space(step_name=step_name)
        if len(param_distributions) == 0:
            logger.warning("Nothing to tune.")
    else:
        assert_model_is_supported(self.model)
        param_distributions = get_model2searchspace()[self.model.__class__](
            strategy
        ).suggest_space()
    self._check_search_space(param_distributions)
    return param_distributions
488
+
489
+ def _get_param_distributions(self, strategy):
490
+ if isinstance(strategy, str):
491
+ param_distributions = self._get_internal_param_distributions(strategy)
492
+ if isinstance(strategy, dict):
493
+ param_distributions = _update_space_name(
494
+ strategy, step_name=self._step_name
495
+ )
496
+ self._check_search_space(param_distributions)
497
+ return param_distributions
498
+
499
def _check_search_space(self, param_distributions):
    """Check ``param_distributions`` against the model's parameter names.

    Delegates to ``validate_search_space`` with the keys of
    ``self.model.get_params()``.
    """
    model_param_names = self.model.get_params().keys()
    validate_search_space(model_param_names, param_distributions)
501
+
502
def _check_is_fitted(self):
    """Delegate the fitted-state check to ``assert_tuner_is_fitted``."""
    assert_tuner_is_fitted(self)
504
+
505
def _check_strategy(self, strategy):
    """Validate ``strategy`` through ``assert_strategy_valid``.

    The helper receives, in order, the current search space, the requested
    strategy and the strategy currently stored on the tuner.
    """
    current_space = self._param_distributions
    current_strategy = self.strategy
    assert_strategy_valid(current_space, strategy, current_strategy)
507
+
508
+ def _add_halt_time(self):
509
+ """Adds a new start time window to the start/stop log. This happens in two cases: when the tuning process
510
+ has commenced and when it resumes following a halt
511
+ """
512
+ self._time_log.append(dict(halt=time(), resume=None))
513
+
514
+ def _add_resume_time(self):
515
+ """Adds a new stopping time to the last window in the time log. This happens when the HPO process is
516
+ halted or terminated.
517
+ """
518
+ if len(self._time_log) > 0:
519
+ entry = self._time_log.pop()
520
+ if entry["resume"] is not None:
521
+ raise Exception("Cannot close a time window without an opening time.")
522
+ self._time_log.append(dict(halt=entry["halt"], resume=time()))
523
+
524
def tune(
    self,
    X=None,  # type: TwoDimArrayLikeType
    y=None,  # type: Optional[Union[OneDimArrayLikeType, TwoDimArrayLikeType]]
    exit_criterion=[],  # type: Optional[list]
    loglevel=None,  # type: Optional[int]
    synchronous=False,  # type: Optional[boolean]
):
    """
    Run hyperparameter tuning until one of the ``exit_criterion`` is met.
    With no criterion, 50 trials are run.

    Parameters
    ----------
    X: TwoDimArrayLikeType, Union[List[List[float]], np.ndarray, pd.DataFrame, spmatrix, ADSData]
        Training data.
    y: Union[OneDimArrayLikeType, TwoDimArrayLikeType], optional
        OneDimArrayLikeType: Union[List[float], np.ndarray, pd.Series]
        TwoDimArrayLikeType: Union[List[List[float]], np.ndarray, pd.DataFrame, spmatrix, ADSData]

        Target.
    exit_criterion: list, optional
        A list of ads stopping criteria. Can be `ScoreValue()`, `NTrials()`,
        `TimeBudget()`, for example
        [ScoreValue(0.96), NTrials(40), TimeBudget(10)]. Tuning exits as soon
        as any criterion in the list is satisfied.
    loglevel: int, optional
        Log level.
    synchronous: boolean, optional
        Tune synchronously or not. Defaults to `False`.

    Returns
    -------
    None
        Nothing

    Raises
    ------
    ValueError
        If no training data was passed here or stored on the tuner.
    InvalidStateTransition
        If a tuning process is already running or halted.

    Example::

        from ads.hpo.stopping_criterion import *
        from ads.hpo.search_cv import ADSTuner
        from sklearn.datasets import load_iris
        from sklearn.svm import SVC

        tuner = ADSTuner(
                        SVC(),
                        strategy='detailed',
                        scoring='f1_weighted',
                        random_state=42
                    )
        tuner.search_space({'max_iter': 100})
        X, y = load_iris(return_X_y=True)
        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
    """
    # Record how many trials exist before this run starts.
    try:
        self._previous_trial_count = self.trial_count
    except NotFittedError:
        self._previous_trial_count = 0
    except Exception as e:
        # Use the same module logger as the rest of this class.
        logger.error(f"Error retrieving previous trial count: {e}")
        raise

    self._init_data(X, y)
    if self.X is None:
        raise ValueError(
            "Need to either pass the data to `X` and `y` in `tune()`, or to `ADSTuner`."
        )
    if self.is_running():
        raise InvalidStateTransition(
            "Running process found. Do you need to call terminate() to stop before calling tune()?"
        )
    if self.is_halted():
        raise InvalidStateTransition(
            "Halted process found. You need to call resume()."
        )

    # Every call to tune() starts with a fresh time log and fresh shared
    # start/stop timestamps, which the worker process fills in.
    self._global_start = multiprocessing.Value("d", 0.0)
    self._global_stop = multiprocessing.Value("d", 0.0)
    self._time_log = []

    self._tune(
        X=self.X,
        y=self.y,
        exit_criterion=exit_criterion,
        loglevel=loglevel,
        synchronous=synchronous,
    )

    # Do not return before the worker has started its clock. Stop waiting
    # if the worker has already exited, so a worker that dies before
    # setting the start time cannot block the caller forever.
    while self._global_start.value == 0.0 and self._tune_process.is_alive():
        sleep(0.01)
619
+
620
+ def _init_data(self, X, y):
621
+ if X is not None:
622
+ if isinstance(X, ADSData):
623
+ self.y = X.y
624
+ self.X = X.X
625
+ else:
626
+ self.X = X
627
+ self.y = y
628
+
629
def halt(self):
    """
    Halt the current running tuning process.

    The live status is checked through ``is_running()``, so a worker that
    has already exited is reported as not running instead of having its
    process id suspended.

    Returns
    -------
    None
        Nothing

    Raises
    ------
    `InvalidStateTransition` if no running process is found

    Example::

        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
        tuner.halt()
    """
    if not self.is_running():
        raise InvalidStateTransition(
            "No running process found. Do you need to call tune()?"
        )
    # Snapshot the trials before the worker is suspended.
    self._trial_dataframe = self._study.trials_dataframe().copy()
    psutil.Process(self._tune_process.pid).suspend()
    self._status = State.HALTED
    self._add_halt_time()
669
+
670
def resume(self):
    """
    Resume the tuning process that is currently halted.

    Returns
    -------
    None
        Nothing

    Raises
    ------
    `InvalidStateTransition` if the tuner is not halted

    Example::

        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
        tuner.halt()
        tuner.resume()
    """
    if not self.is_halted():
        raise InvalidStateTransition("No paused process found.")
    worker = psutil.Process(self._tune_process.pid)
    worker.resume()
    self._add_resume_time()
    self._status = State.RUNNING
704
+
705
def wait(self):
    """
    Block until the running tuning process has finished.

    Returns
    -------
    None
        Nothing

    Raises
    ------
    `InvalidStateTransition` if no process is running

    Example::

        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
        tuner.wait()
    """
    if not self.is_running():
        raise InvalidStateTransition("No running process.")
    self._tune_process.join()
    self._status = State.COMPLETED
737
+
738
def terminate(self):
    """
    Terminate the running tuning process.

    The worker is terminated and joined, the status becomes terminated and
    trials still recorded as running are marked as failed.

    Returns
    -------
    None
        Nothing

    Raises
    ------
    `RuntimeError` if no process is running

    Example::

        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
        tuner.terminate()
    """
    if not self.is_running():
        raise RuntimeError("No running process found. Do you need to call tune()?")
    worker = self._tune_process
    worker.terminate()
    worker.join()
    self._status = State.TERMINATED
    self._update_failed_trial_state()
773
+
774
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def _update_failed_trial_state(self):
    """Mark every trial still recorded as running as failed.

    Called by ``terminate()`` after the worker has been stopped.
    """
    from optuna.trial import TrialState

    for trial in self._study.trials:
        if trial.state == TrialState.RUNNING:
            # Use the imported ``TrialState`` rather than the deprecated
            # ``optuna.structs`` alias.
            # NOTE(review): this goes through optuna's private storage API;
            # confirm it exists in the optuna version in use.
            self._study._storage.set_trial_state(
                trial._trial_id, TrialState.FAIL
            )
783
+
784
@property
def time_remaining(self):
    """Seconds left in the time budget of the study.

    Returns
    -------
    int or float: budget minus elapsed time, never below 0; 0 once the
        tuner has completed or been terminated.

    Raises
    ------
    :class:`ExitCriterionError`
        If no time budget was part of the exit criteria.
    """
    budget = self._time_budget
    if budget is None:
        raise ExitCriterionError(
            "This tuner does not include a time-based exit condition"
        )
    if self.is_completed() or self.is_terminated():
        return 0
    return max(budget - self.time_elapsed, 0)
804
+
805
@property
def time_since_resume(self):
    """Seconds since the process was last resumed from a halt.

    Returns
    -------
    float, int or None: time since the last resume while running; 0 while
        halted; ``None`` in any other non-terminated state.

    Raises
    ------
    `NoRestartError` if the process is running but has not been resumed,
    or if it has been terminated.
    `Exception` if the time log is empty.
    """
    if len(self._time_log) == 0:
        raise Exception("Time log should not be empty")
    last_time_resumed = self._time_log[-1].get("resume")

    if self.is_running():
        if last_time_resumed is None:
            raise NoRestartError("The process has not been resumed")
        return time() - last_time_resumed
    if self.is_halted():
        # While halted, no time has passed since a resume.
        return 0
    if self.is_terminated():
        raise NoRestartError("The process has been terminated")
    return None
832
+
833
@property
def time_elapsed(self):
    """Seconds the HPO process has spent searching, excluding halted time.

    Returns
    -------
    float: The number of seconds the HPO process has been searching
    """
    halted_total = 0.0

    for window in self._time_log:
        halted_at = window.get("halt")
        resumed_at = window.get("resume")

        if resumed_at is None:
            # Currently halted: searching stopped at the halt timestamp.
            return halted_at - self._global_start.value - halted_total

        # Closed window: count the gap between halt and resume as halted.
        halted_total += resumed_at - halted_at

    # Every halt was resumed. A non-zero stop timestamp means the worker
    # has exited; otherwise the search is still in progress.
    if self._global_stop.value != 0:
        end_time = self._global_stop.value
    else:
        end_time = time()

    wall_clock = end_time - self._global_start.value
    return wall_clock - halted_total
867
+
868
def best_scores(self, n: int = 5, reverse: bool = True):
    """Return the best scores from the study.

    Parameters
    ----------
    n: int
        The maximum number of results to return. Defaults to 5. If it is
        not a positive integer (for example `None` or a negative number),
        all scores are returned.
    reverse: bool
        Whether to sort in descending order. Defaults to `True`.

    Returns
    -------
    list[float or int]
        The non-null trial scores, sorted.

    Raises
    ------
    `ValueError` if there are no trials, or if none of the trials has a
    score yet.
    """
    trials = self.trials
    if len(trials) < 1:
        raise ValueError("No score data to show")

    scores = trials.value
    scores = scores[scores.notnull()]
    # Filtering never yields ``None``; an empty result is what signals
    # "trials exist but none has been scored".
    if scores.empty:
        raise ValueError(
            f"No score data despite valid trial data. Trial data length: {len(trials)}"
        )

    ranked = sorted(scores, reverse=reverse)
    if not isinstance(n, int) or n <= 0:
        return ranked
    return ranked[:n]
902
+
903
def get_status(self):
    """
    Return the status of the current tuning process.

    Alias for the property `status`.

    Returns
    -------
    :class:`Status`
        The status of the process

    Example::

        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)])
        tuner.get_status()
    """
    current_status = self.status
    return current_status
933
+
934
def is_running(self):
    """
    Tell whether the tuner status is ``State.RUNNING``.

    Returns
    -------
    bool
        `True` if the :class:`ADSTuner` instance is running; `False` otherwise.
    """
    current_status = self.status
    return current_status == State.RUNNING
942
+
943
def is_halted(self):
    """
    Tell whether the tuner status is ``State.HALTED``.

    Returns
    -------
    bool
        `True` if the :class:`ADSTuner` instance is halted; `False` otherwise.
    """
    current_status = self.status
    return current_status == State.HALTED
951
+
952
def is_terminated(self):
    """
    Tell whether the tuner status is ``State.TERMINATED``.

    Returns
    -------
    bool
        `True` if the :class:`ADSTuner` instance has been terminated; `False` otherwise.
    """
    current_status = self.status
    return current_status == State.TERMINATED
960
+
961
def is_completed(self):
    """
    Tell whether the tuner status is ``State.COMPLETED``.

    Returns
    -------
    bool
        `True` if the :class:`ADSTuner` instance has completed; `False` otherwise.
    """
    current_status = self.status
    return current_status == State.COMPLETED
969
+
970
def _is_tuning_started(self):
    """
    Tell whether tuning is in progress, i.e. halted or running.

    Returns
    -------
    bool
        `True` if the status is ``State.HALTED`` or ``State.RUNNING``;
        `False` otherwise.
    """
    if self.status == State.HALTED:
        return True
    return self.status == State.RUNNING
979
+
980
def _is_tuning_finished(self):
    """
    Tell whether tuning has ended, i.e. completed or terminated.

    Returns
    -------
    bool
        `True` if the status is ``State.COMPLETED`` or
        ``State.TERMINATED``; `False` otherwise.
    """
    if self.status == State.COMPLETED:
        return True
    return self.status == State.TERMINATED
989
+
990
@property
def status(self):
    """
    Current status of the tuning process.

    A stored halted, terminated or initiated status is returned as is.
    Otherwise the status is derived from the worker process: running while
    it is alive, completed when it is not (or was never created).

    Returns
    -------
    :class:`Status`
        The status of the current tuning process.
    """
    stored_status = self._status
    if stored_status == State.HALTED:
        return stored_status
    if stored_status == State.TERMINATED:
        return stored_status
    if stored_status == State.INITIATED:
        return stored_status
    worker_alive = hasattr(self, "_tune_process") and self._tune_process.is_alive()
    return State.RUNNING if worker_alive else State.COMPLETED
1009
+
1010
+ def _extract_exit_criterion(self, exit_criterion):
1011
+ # handle the exit criterion
1012
+ self._time_budget = None
1013
+ self._n_trials = None
1014
+ self.exit_criterion = []
1015
+ self._optimal_score = None
1016
+ if exit_criterion is None or len(exit_criterion) == 0:
1017
+ self._n_trials = 50
1018
+ for i, criteria in enumerate(exit_criterion):
1019
+ if isinstance(criteria, TimeBudget):
1020
+ self._time_budget = criteria()
1021
+ elif isinstance(criteria, NTrials):
1022
+ self._n_trials = criteria()
1023
+ elif isinstance(criteria, ScoreValue):
1024
+ self._optimal_score = criteria.score
1025
+ self.exit_criterion.append(criteria)
1026
+ else:
1027
+ raise NotImplementedError(
1028
+ "``{}`` is not supported!".format(criteria.__class__.__name__)
1029
+ )
1030
+
1031
def _extract_estimator(self):
    """Set ``self.estimator`` from ``self.model``.

    For a ``Pipeline`` the last step accepted by ``_is_estimator`` becomes
    the estimator and its name is stored in ``self._step_name``. Any other
    model is used directly. The result is checked with
    ``assert_is_estimator``.
    """
    if not isinstance(self.model, Pipeline):
        self.estimator = self.model
    else:
        for step_name, step in self.model.steps:
            if not self._is_estimator(step):
                continue
            self._step_name = step_name
            self.estimator = step
    assert_is_estimator(self.estimator)
1042
+
1043
+ def _extract_scoring_name(self):
1044
+ if isinstance(self.scoring, str):
1045
+ return self.scoring
1046
+ if not callable(self._scorer):
1047
+ return (
1048
+ self._scorer
1049
+ if isinstance(self._scorer, str)
1050
+ else str(self._scorer).split("(")[1].split(")")[0]
1051
+ )
1052
+ else:
1053
+ if is_classifier(self.model):
1054
+ return "mean accuracy"
1055
+ else:
1056
+ return "r2"
1057
+
1058
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def _set_logger(self, loglevel, class_name):
    """Apply the tuner's log level to the named logging backend.

    A non-``None`` ``loglevel`` is first stored in ``self.loglevel``. Only
    the ``"optuna"`` backend is supported.

    Raises
    ------
    NotImplementedError
        For any other ``class_name``.
    """
    if loglevel is not None:
        self.loglevel = loglevel
    if class_name != "optuna":
        raise NotImplementedError("{} is not supported.".format(class_name))
    optuna.logging.set_verbosity(self.loglevel)
1066
+
1067
def _set_sample_indices(self, X, random_state):
    """Choose the row indices of ``X`` that are used for tuning.

    ``self._subsample`` caps the number of samples; a float is treated as a
    fraction of the sample count. When the cap is below the sample count,
    that many indices are drawn without replacement. The stored indices are
    in ascending order.
    """
    max_samples = self._subsample
    n_samples = _num_samples(X)
    self._sample_indices = np.arange(n_samples)

    if isinstance(max_samples, float):
        max_samples = int(max_samples * n_samples)

    if max_samples < n_samples:
        # NOTE(review): needs ``random_state`` to provide ``choice``;
        # confirm a RandomState, not a raw seed, is passed.
        drawn = random_state.choice(
            self._sample_indices, max_samples, replace=False
        )
        self._sample_indices = drawn

    self._sample_indices.sort()
1081
+
1082
def _get_fit_params_res(self, X):
    """Return the validated fit parameters for the sampled rows.

    No fit parameters are supplied here, so an empty mapping is passed to
    ``validate_fit_params`` together with ``X`` and the sample indices.
    """
    empty_fit_params = {}
    return validate_fit_params(X, empty_fit_params, self._sample_indices)
1089
+
1090
+ def _can_tune(self):
1091
+ assert hasattr(self, "model"), "Call <code>ADSTuner</code> first."
1092
+ if self._param_distributions == {}:
1093
+ logger.warning("Nothing to tune.")
1094
+
1095
+ if self._param_distributions is None:
1096
+ raise NotImplementedError(
1097
+ "There was no model specified or the model is not supported."
1098
+ )
1099
+
1100
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def _tune(
    self,
    X,  # type: TwoDimArrayLikeType
    y,  # type: Optional[Union[OneDimArrayLikeType, TwoDimArrayLikeType]]
    exit_criterion=[],  # type: Optional[list]
    loglevel=None,  # type: Optional[int]
    synchronous=False,  # type: Optional[boolean]
):
    # type: (...) -> None
    """
    Prepare the study and launch the tuning worker process.

    Validates the tuner, configures logging, resolves exit criteria and the
    estimator, subsamples the data, creates the study and objective, and
    starts ``ADSTuner.optimizer`` in a separate process. In synchronous
    mode the call waits for the worker and marks the tuner completed.

    The module logger level is restored on exit, including when setup or
    launch raises.

    Parameters
    ----------
    X: TwoDimArrayLikeType
        Training data.
    y: Union[OneDimArrayLikeType, TwoDimArrayLikeType], optional
        Target.
    exit_criterion: list, optional
        Stopping criteria, see ``tune``.
    loglevel: int, optional
        Log level.
    synchronous: boolean, optional
        Whether to wait for the worker process to finish.

    Returns
    -------
    None
        Nothing
    """
    self._can_tune()
    self._set_logger(loglevel=loglevel, class_name="optuna")
    self._extract_exit_criterion(exit_criterion)
    self._extract_estimator()
    random_state = self.random_state
    old_level = logger.getEffectiveLevel()
    try:
        logger.setLevel(self.loglevel)
        if not synchronous:
            # Background runs only report errors.
            optuna.logging.set_verbosity(optuna.logging.ERROR)
            logger.setLevel(logging.ERROR)

        self._set_sample_indices(X, random_state)
        X_res = _safe_indexing(X, self._sample_indices)
        y_res = _safe_indexing(y, self._sample_indices)
        groups_res = _safe_indexing(None, self._sample_indices)
        fit_params_res = self._get_fit_params_res(X)

        classifier = is_classifier(self.model)
        cv = check_cv(self.cv, y_res, classifier=classifier)
        self._n_splits = cv.get_n_splits(X_res, y_res, groups=groups_res)

        # scoring
        self._scorer = check_scoring(self.estimator, scoring=self.scoring)

        self._study = optuna.study.create_study(
            study_name=self.study_name,
            direction="maximize",
            pruner=self.median_pruner,
            sampler=self.sampler,
            storage=self.storage,
            load_if_exists=self.load_if_exists,
        )
        objective = _Objective(
            self.model,
            self._param_distributions,
            cv,
            self._enable_pruning,
            self._error_score,
            fit_params_res,
            groups_res,
            self._max_iter,
            self._return_train_score,
            self._scorer,
            self.scoring_name,
            self._step_name,
        )

        if synchronous:
            logger.info(
                "Optimizing hyperparameters using {} "
                "samples...".format(_num_samples(self._sample_indices))
            )

        # The worker opens the study itself from the same name and storage;
        # the shared values carry its start and stop timestamps back.
        self._tune_process = multiprocessing.Process(
            target=ADSTuner.optimizer,
            args=(
                self.study_name,
                self.median_pruner,
                self.sampler,
                self.storage,
                self.load_if_exists,
                DataScienceObjective(objective, X_res, y_res),
                self._global_start,
                self._global_stop,
            ),
            kwargs=dict(
                n_jobs=self._n_jobs,
                n_trials=self._n_trials,
                timeout=self._time_budget,
                show_progress_bar=False,
                callbacks=self.exit_criterion,
                gc_after_trial=False,
            ),
        )

        self._tune_process.start()
        self._status = State.RUNNING

        if synchronous:
            self._tune_process.join()
            logger.info("Finished hyperparemeter search!")
            self._status = State.COMPLETED
    finally:
        # Always undo the temporary log-level change made above.
        logger.setLevel(old_level)
1197
+
1198
@staticmethod
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def optimizer(
    study_name,
    pruner,
    sampler,
    storage,
    load_if_exists,
    objective_func,
    global_start,
    global_stop,
    **kwargs,
):
    """
    Static method for running the ADSTuner tuning process.

    Parameters
    ----------
    study_name: str
        The name of the study.
    pruner
        The pruning method for pruning trials.
    sampler
        The sampling method used for tuning.
    storage: str
        Storage endpoint.
    load_if_exists: bool
        Load existing study if it exists.
    objective_func
        The objective function to be maximized.
    global_start: :class:`multiprocessing.Value`
        Receives the time at which optimization starts.
    global_stop: :class:`multiprocessing.Value`
        Receives the time at which optimization ends, whether it succeeded
        or raised.
    kwargs: dict
        Keyword/value pairs passed into the optimize process.

    Raises
    ------
    :class:`Exception`
        Any exception thrown by the underlying optimization process, after
        its traceback has been printed.

    Returns
    -------
    None
        Nothing
    """
    import traceback

    study = optuna.study.create_study(
        study_name=study_name,
        direction="maximize",
        pruner=pruner,
        sampler=sampler,
        storage=storage,
        load_if_exists=load_if_exists,
    )
    try:
        global_start.value = time()
        study.optimize(objective_func, **kwargs)
    except Exception:
        traceback.print_exc()
        raise
    finally:
        # Record the stop time on failure too, so elapsed time stops
        # growing once this process has exited.
        global_stop.value = time()
1264
+
1265
+ @staticmethod
1266
+ def _is_estimator(step):
1267
+ return hasattr(step, "fit") and (
1268
+ not hasattr(step, "transform")
1269
+ or hasattr(step, "predict")
1270
+ or hasattr(step, "fit_predict")
1271
+ )
1272
+
1273
@staticmethod
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def _pruner(class_name, **kwargs):
    """Create the pruner named by ``class_name``.

    Only ``"median_pruner"`` is supported; ``kwargs`` are forwarded to
    ``optuna.pruners.MedianPruner``.

    Raises
    ------
    NotImplementedError
        For any other ``class_name``.
    """
    if class_name != "median_pruner":
        raise NotImplementedError("{} is not supported.".format(class_name))
    return optuna.pruners.MedianPruner(**kwargs)
1280
+
1281
def trials_export(
    self, file_uri, metadata=None, script_dict={"model": None, "scoring": None}
):
    """Export the metadata and files needed to rebuild this ADSTuner to object storage.

    The data itself is not stored. To continue tuning from the exported
    trials after an import, the dataset has to be provided again.

    Parameters
    ----------
    file_uri: str
        Object storage path, 'oci://bucketname@namespace/filepath/on/objectstorage'. For example,
        `oci://test_bucket@ociodsccust/tuner/test.zip`
    metadata: str, optional
        User defined metadata
    script_dict: dict, optional
        Script paths for model and scoring. This is only recommended for unsupported
        models and user-defined scoring functions. The dictionary may only use the keys
        `model` and `scoring`, with the respective script paths as values. The scripts must
        import the libraries they need, and must set the variables ``model`` and ``scoring``
        to your model and scoring function.

    Returns
    -------
    None
        Nothing

    Raises
    ------
    AssertionError
        If tuning has not finished, or if ``script_dict`` has other keys.

    Example::

        # Print out a list of supported models
        from ads.hpo.ads_search_space import model_list
        print(model_list)

        # Example scoring dictionary
        {'model':'/home/datascience/advanced-ds/notebooks/scratch/ADSTunerV2/mymodel.py',
        'scoring':'/home/datascience/advanced-ds/notebooks/scratch/ADSTunerV2/customized_scoring.py'}

    Example::

        tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)], synchronous=True)
        tuner.trials_export('oci://<bucket_name>@<namespace>/tuner/test.zip')
    """
    from ads.hpo.tuner_artifact import UploadTunerArtifact

    assert self._is_tuning_finished()
    allowed_keys = {"model", "scoring"}
    assert (
        script_dict.keys() <= allowed_keys
    ), "script_dict keys can only be model and scoring."

    uploader = UploadTunerArtifact(self, file_uri, metadata)
    uploader.upload(script_dict)
1345
+
1346
@classmethod
def trials_import(cls, file_uri, delete_zip_file=True, target_file_path=None):
    """Rebuild a tuner from an artifact stored in object storage.

    The metadata found in the artifact is stored on the class as
    ``metadata``.

    Parameters
    ----------
    file_uri: str
        'oci://bucketname@namespace/filepath/on/objectstorage'
        Example: 'oci://<bucket_name>@<namespace>/tuner/test.zip'
    delete_zip_file: bool, defaults to True, optional
        Whether to delete the zip file afterwards.
    target_file_path: str, optional
        The path where the zip file will be saved. For example, '/home/datascience/myfile.zip'.

    Returns
    -------
    :class:`ADSTuner`
        ADSTuner object

    Examples
    --------
    >>> from ads.hpo.stopping_criterion import *
    >>> from ads.hpo.search_cv import ADSTuner
    >>> from sklearn.datasets import load_iris
    >>> X, y = load_iris(return_X_y=True)
    >>> tuner = ADSTuner.trials_import('oci://<bucket_name>@<namespace>/tuner/test.zip')
    >>> tuner.tune(X=X, y=y, exit_criterion=[TimeBudget(1)], synchronous=True)
    """
    from ads.hpo.tuner_artifact import DownloadTunerArtifact

    artifact = DownloadTunerArtifact(file_uri, target_file_path=target_file_path)
    extracted = artifact.extract_tuner_args(delete_zip_file=delete_zip_file)
    tuner_args, cls.metadata = extracted
    return cls(**tuner_args)
1381
+
1382
@runtime_dependency(module="IPython", install_from=OptionalDependency.NOTEBOOK)
def _plot(
    self,  # type: ADSTuner
    plot_module,  # type: str
    plot_func,  # type: str
    time_interval=0.5,  # type: float
    fig_size=(800, 500),  # type: tuple
    **kwargs,
):
    """Load a plotting module and keep its plot refreshed while tuning runs.

    The module ``visualization/<plot_module>.py`` next to this file is
    loaded and its ``plot_func`` is called with the study, ``fig_size`` and
    ``kwargs``. While the tuner is running, the plot is redrawn whenever
    more scored trials exist than at the previous check (the parameter
    importance plot additionally waits for at least four scored trials).
    A final plot is drawn once the tuner is no longer running.
    """
    if fig_size:
        logger.warning(
            "The param fig_size will be depreciated in future releases.",
        )

    module_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "visualization",
        plot_module + ".py",
    )
    spec = importlib.util.spec_from_file_location("plot", module_path)
    plot = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(plot)

    _imports.check()
    assert self._study is not None, "Need to call <code>.tune()</code> first."

    def scored_trial_count():
        # Trials whose "value" is not null.
        return len(self.trials[~self.trials["value"].isnull()])

    def render():
        getattr(plot, plot_func)(study=self._study, fig_size=fig_size, **kwargs)

    ntrials = 0
    if plot_func == "_plot_param_importances":
        print("Waiting for more trials before evaluating the param importance.")
    while self.status == State.RUNNING:
        import time
        from IPython.display import clear_output

        time.sleep(time_interval)
        if scored_trial_count() > ntrials:
            if plot_func == "_plot_param_importances":
                if scored_trial_count() >= 4:
                    clear_output(wait=True)
                    render()
                    clear_output(wait=True)
            else:
                render()
                clear_output(wait=True)
        if len(self.trials) == 0:
            # No trials yet: show an empty placeholder figure.
            plt.figure()
            plt.title("Intermediate Values Plot")
            plt.xlabel("Step")
            plt.ylabel("Intermediate Value")
            plt.show(block=False)

        ntrials = scored_trial_count()
    render()
1439
+
1440
def plot_best_scores(
    self,
    best=True,  # type: bool
    inferior=True,  # type: bool
    time_interval=1,  # type: float
    fig_size=(800, 500),  # type: tuple
):
    """Plot the optimization history of all trials in a study.

    Parameters
    ----------
    best:
        Whether to plot the lines for the best scores so far.
    inferior:
        Whether to plot the dots for the actual objective scores.
    time_interval:
        How often (in seconds) the plot refreshes to pick up new trial results.
    fig_size: tuple
        Width and height of the figure.

    Returns
    -------
    None
        Nothing.
    """
    plot_options = dict(
        time_interval=time_interval,
        fig_size=fig_size,
        best=best,
        inferior=inferior,
    )
    self._plot(
        "_optimization_history",
        "_get_optimization_history_plot",
        **plot_options,
    )
1473
+
1474
@runtime_dependency(module="optuna", install_from=OptionalDependency.OPTUNA)
def plot_param_importance(
    self,
    importance_evaluator="Fanova",  # type: str
    time_interval=1,  # type: float
    fig_size=(800, 500),  # type: tuple
):
    """Plot hyperparameter importances.

    Failures while plotting are logged as an error instead of being
    raised; interrupts such as ``KeyboardInterrupt`` are not swallowed.

    Parameters
    ----------
    importance_evaluator: str
        Importance evaluator. Valid values: "Fanova", "MeanDecreaseImpurity". Defaults
        to "Fanova".
    time_interval: float
        How often the plot refreshes to pick up new trial results.
    fig_size: tuple
        Width and height of the figure.

    Raises
    ------
    :class:`AssertionError`
        Raised for unsupported importance evaluators.
    :class:`NotImplementedError`
        Raised for unsupported importance evaluators when assertions are
        disabled.

    Returns
    -------
    None
        Nothing.
    """
    assert importance_evaluator in [
        "MeanDecreaseImpurity",
        "Fanova",
    ], "Only support <code>MeanDecreaseImpurity</code> and <code>Fanova</code>."
    if importance_evaluator == "Fanova":
        evaluator = None
    elif importance_evaluator == "MeanDecreaseImpurity":
        evaluator = optuna.importance.MeanDecreaseImpurityImportanceEvaluator()
    else:
        # ``NotImplemented`` is a constant, not an exception class; raising
        # it would itself fail with a TypeError.
        raise NotImplementedError(
            f"{importance_evaluator} is not supported. It can be either `Fanova` or `MeanDecreaseImpurity`."
        )
    try:
        self._plot(
            plot_module="_param_importances",
            plot_func="_plot_param_importances",
            time_interval=time_interval,
            fig_size=fig_size,
            evaluator=evaluator,
        )
    except Exception:
        logger.error(
            msg="""Cannot calculate the hyperparameter importance. Increase the number of trials or time budget. """
        )
1527
+
1528
def plot_intermediate_scores(
    self,
    time_interval=1,  # type: float
    fig_size=(800, 500),  # type: tuple
):
    """
    Draw the intermediate values recorded for every trial in the study.

    Parameters
    ----------
    time_interval: float
        Time interval for the plot. Defaults to 1.
    fig_size: tuple[int, int]
        Figure size. Defaults to (800, 500).

    Returns
    -------
    None
        Nothing.
    """
    if not self._enable_pruning:
        # NOTE(review): the error is only logged; the plot call below still
        # runs when pruning was disabled. Confirm whether an early return
        # was intended here.
        no_pruning_message = (
            "Pruning was not used during tuning. "
            "There are no intermediate values to plot."
        )
        logger.error(msg=no_pruning_message)

    plot_options = {"time_interval": time_interval, "fig_size": fig_size}
    self._plot("_intermediate_values", "_get_intermediate_plot", **plot_options)
1560
+
1561
def plot_edf_scores(
    self,
    time_interval=1,  # type: float
    fig_size=(800, 500),  # type: tuple
):
    """
    Draw the empirical distribution function (EDF) of the scores.

    Only completed trials are used.

    Parameters
    ----------
    time_interval: float
        Time interval for the plot. Defaults to 1.
    fig_size: tuple[int, int]
        Figure size. Defaults to (800, 500).

    Returns
    -------
    None
        Nothing.
    """
    edf_options = {"time_interval": time_interval, "fig_size": fig_size}
    self._plot(
        "_edf",
        "_get_edf_plot",
        **edf_options,
    )
1586
+
1587
def plot_contour_scores(
    self,
    params=None,  # type: Optional[List[str]]
    time_interval=1,  # type: float
    fig_size=(800, 500),  # type: tuple
):
    """
    Draw a contour plot of the scores.

    Parameters
    ----------
    params: Optional[List[str]]
        Parameter list to visualize. Defaults to all.
    time_interval: float
        Time interval for the plot. Defaults to 1.
    fig_size: tuple[int, int]
        Figure size. Defaults to (800, 500).

    Returns
    -------
    None
        Nothing.
    """
    # Validation runs outside the try block, so any error it raises is not
    # converted into the warning below.
    validate_params_for_plot(params, self._param_distributions)

    contour_options = dict(
        time_interval=time_interval,
        fig_size=fig_size,
        params=params,
    )
    try:
        self._plot("_contour", "_get_contour_plot", **contour_options)
    except ValueError:
        # A ValueError from plotting is reported as a warning, not raised.
        too_few_trials_message = (
            "Cannot plot contour score."
            " Increase the number of trials or time budget."
        )
        logger.warning(msg=too_few_trials_message)
1624
+
1625
def plot_parallel_coordinate_scores(
    self,
    params=None,  # type: Optional[List[str]]
    time_interval=1,  # type: float
    fig_size=(800, 500),  # type: tuple
):
    """
    Draw the high-dimensional parameter relationships in a study.

    Note that a trial is not plotted if any of its parameters has a
    missing value.

    Parameters
    ----------
    params: Optional[List[str]]
        Parameter list to visualize. Defaults to all.
    time_interval: float
        Time interval for the plot. Defaults to 1.
    fig_size: tuple[int, int]
        Figure size. Defaults to (800, 500).

    Returns
    -------
    None
        Nothing.
    """
    # Check the requested parameters against the search space before plotting.
    validate_params_for_plot(params, self._param_distributions)

    coordinate_options = dict(
        time_interval=time_interval,
        fig_size=fig_size,
        params=params,
    )
    self._plot(
        "_parallel_coordinate",
        "_get_parallel_coordinate_plot",
        **coordinate_options,
    )