oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.9rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (857) hide show
  1. ads/aqua/__init__.py +40 -0
  2. ads/aqua/app.py +506 -0
  3. ads/aqua/cli.py +96 -0
  4. ads/aqua/client/__init__.py +3 -0
  5. ads/aqua/client/client.py +836 -0
  6. ads/aqua/client/openai_client.py +305 -0
  7. ads/aqua/common/__init__.py +5 -0
  8. ads/aqua/common/decorator.py +125 -0
  9. ads/aqua/common/entities.py +269 -0
  10. ads/aqua/common/enums.py +122 -0
  11. ads/aqua/common/errors.py +109 -0
  12. ads/aqua/common/utils.py +1285 -0
  13. ads/aqua/config/__init__.py +4 -0
  14. ads/aqua/config/container_config.py +248 -0
  15. ads/aqua/config/evaluation/__init__.py +4 -0
  16. ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
  17. ads/aqua/config/utils/__init__.py +4 -0
  18. ads/aqua/config/utils/serializer.py +339 -0
  19. ads/aqua/constants.py +116 -0
  20. ads/aqua/data.py +14 -0
  21. ads/aqua/dummy_data/icon.txt +1 -0
  22. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  23. ads/aqua/dummy_data/oci_models.json +1 -0
  24. ads/aqua/dummy_data/readme.md +26 -0
  25. ads/aqua/evaluation/__init__.py +8 -0
  26. ads/aqua/evaluation/constants.py +53 -0
  27. ads/aqua/evaluation/entities.py +186 -0
  28. ads/aqua/evaluation/errors.py +70 -0
  29. ads/aqua/evaluation/evaluation.py +1814 -0
  30. ads/aqua/extension/__init__.py +42 -0
  31. ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
  32. ads/aqua/extension/base_handler.py +90 -0
  33. ads/aqua/extension/common_handler.py +121 -0
  34. ads/aqua/extension/common_ws_msg_handler.py +36 -0
  35. ads/aqua/extension/deployment_handler.py +298 -0
  36. ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
  37. ads/aqua/extension/errors.py +30 -0
  38. ads/aqua/extension/evaluation_handler.py +129 -0
  39. ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
  40. ads/aqua/extension/finetune_handler.py +96 -0
  41. ads/aqua/extension/model_handler.py +390 -0
  42. ads/aqua/extension/models/__init__.py +0 -0
  43. ads/aqua/extension/models/ws_models.py +145 -0
  44. ads/aqua/extension/models_ws_msg_handler.py +50 -0
  45. ads/aqua/extension/ui_handler.py +282 -0
  46. ads/aqua/extension/ui_websocket_handler.py +130 -0
  47. ads/aqua/extension/utils.py +133 -0
  48. ads/aqua/finetuning/__init__.py +7 -0
  49. ads/aqua/finetuning/constants.py +23 -0
  50. ads/aqua/finetuning/entities.py +181 -0
  51. ads/aqua/finetuning/finetuning.py +749 -0
  52. ads/aqua/model/__init__.py +8 -0
  53. ads/aqua/model/constants.py +60 -0
  54. ads/aqua/model/entities.py +385 -0
  55. ads/aqua/model/enums.py +32 -0
  56. ads/aqua/model/model.py +2114 -0
  57. ads/aqua/modeldeployment/__init__.py +8 -0
  58. ads/aqua/modeldeployment/constants.py +10 -0
  59. ads/aqua/modeldeployment/deployment.py +1326 -0
  60. ads/aqua/modeldeployment/entities.py +653 -0
  61. ads/aqua/modeldeployment/inference.py +74 -0
  62. ads/aqua/modeldeployment/utils.py +543 -0
  63. ads/aqua/resources/gpu_shapes_index.json +94 -0
  64. ads/aqua/server/__init__.py +4 -0
  65. ads/aqua/server/__main__.py +24 -0
  66. ads/aqua/server/app.py +47 -0
  67. ads/aqua/server/aqua_spec.yml +1291 -0
  68. ads/aqua/training/__init__.py +4 -0
  69. ads/aqua/training/exceptions.py +476 -0
  70. ads/aqua/ui.py +499 -0
  71. ads/automl/__init__.py +9 -0
  72. ads/automl/driver.py +330 -0
  73. ads/automl/provider.py +975 -0
  74. ads/bds/__init__.py +5 -0
  75. ads/bds/auth.py +127 -0
  76. ads/bds/big_data_service.py +255 -0
  77. ads/catalog/__init__.py +19 -0
  78. ads/catalog/model.py +1576 -0
  79. ads/catalog/notebook.py +461 -0
  80. ads/catalog/project.py +468 -0
  81. ads/catalog/summary.py +178 -0
  82. ads/common/__init__.py +11 -0
  83. ads/common/analyzer.py +65 -0
  84. ads/common/artifact/.model-ignore +63 -0
  85. ads/common/artifact/__init__.py +10 -0
  86. ads/common/auth.py +1122 -0
  87. ads/common/card_identifier.py +83 -0
  88. ads/common/config.py +647 -0
  89. ads/common/data.py +165 -0
  90. ads/common/decorator/__init__.py +9 -0
  91. ads/common/decorator/argument_to_case.py +88 -0
  92. ads/common/decorator/deprecate.py +69 -0
  93. ads/common/decorator/require_nonempty_arg.py +65 -0
  94. ads/common/decorator/runtime_dependency.py +178 -0
  95. ads/common/decorator/threaded.py +97 -0
  96. ads/common/decorator/utils.py +35 -0
  97. ads/common/dsc_file_system.py +303 -0
  98. ads/common/error.py +14 -0
  99. ads/common/extended_enum.py +81 -0
  100. ads/common/function/__init__.py +5 -0
  101. ads/common/function/fn_util.py +142 -0
  102. ads/common/function/func_conf.yaml +25 -0
  103. ads/common/ipython.py +76 -0
  104. ads/common/model.py +679 -0
  105. ads/common/model_artifact.py +1759 -0
  106. ads/common/model_artifact_schema.json +107 -0
  107. ads/common/model_export_util.py +664 -0
  108. ads/common/model_metadata.py +24 -0
  109. ads/common/object_storage_details.py +296 -0
  110. ads/common/oci_client.py +175 -0
  111. ads/common/oci_datascience.py +46 -0
  112. ads/common/oci_logging.py +1144 -0
  113. ads/common/oci_mixin.py +957 -0
  114. ads/common/oci_resource.py +136 -0
  115. ads/common/serializer.py +559 -0
  116. ads/common/utils.py +1852 -0
  117. ads/common/word_lists.py +1491 -0
  118. ads/common/work_request.py +189 -0
  119. ads/data_labeling/__init__.py +13 -0
  120. ads/data_labeling/boundingbox.py +253 -0
  121. ads/data_labeling/constants.py +47 -0
  122. ads/data_labeling/data_labeling_service.py +244 -0
  123. ads/data_labeling/interface/__init__.py +5 -0
  124. ads/data_labeling/interface/loader.py +16 -0
  125. ads/data_labeling/interface/parser.py +16 -0
  126. ads/data_labeling/interface/reader.py +23 -0
  127. ads/data_labeling/loader/__init__.py +5 -0
  128. ads/data_labeling/loader/file_loader.py +241 -0
  129. ads/data_labeling/metadata.py +110 -0
  130. ads/data_labeling/mixin/__init__.py +5 -0
  131. ads/data_labeling/mixin/data_labeling.py +232 -0
  132. ads/data_labeling/ner.py +129 -0
  133. ads/data_labeling/parser/__init__.py +5 -0
  134. ads/data_labeling/parser/dls_record_parser.py +388 -0
  135. ads/data_labeling/parser/export_metadata_parser.py +94 -0
  136. ads/data_labeling/parser/export_record_parser.py +473 -0
  137. ads/data_labeling/reader/__init__.py +5 -0
  138. ads/data_labeling/reader/dataset_reader.py +574 -0
  139. ads/data_labeling/reader/dls_record_reader.py +121 -0
  140. ads/data_labeling/reader/export_record_reader.py +62 -0
  141. ads/data_labeling/reader/jsonl_reader.py +75 -0
  142. ads/data_labeling/reader/metadata_reader.py +203 -0
  143. ads/data_labeling/reader/record_reader.py +263 -0
  144. ads/data_labeling/record.py +52 -0
  145. ads/data_labeling/visualizer/__init__.py +5 -0
  146. ads/data_labeling/visualizer/image_visualizer.py +525 -0
  147. ads/data_labeling/visualizer/text_visualizer.py +357 -0
  148. ads/database/__init__.py +5 -0
  149. ads/database/connection.py +338 -0
  150. ads/dataset/__init__.py +10 -0
  151. ads/dataset/capabilities.md +51 -0
  152. ads/dataset/classification_dataset.py +339 -0
  153. ads/dataset/correlation.py +226 -0
  154. ads/dataset/correlation_plot.py +563 -0
  155. ads/dataset/dask_series.py +173 -0
  156. ads/dataset/dataframe_transformer.py +110 -0
  157. ads/dataset/dataset.py +1979 -0
  158. ads/dataset/dataset_browser.py +360 -0
  159. ads/dataset/dataset_with_target.py +995 -0
  160. ads/dataset/exception.py +25 -0
  161. ads/dataset/factory.py +987 -0
  162. ads/dataset/feature_engineering_transformer.py +35 -0
  163. ads/dataset/feature_selection.py +107 -0
  164. ads/dataset/forecasting_dataset.py +26 -0
  165. ads/dataset/helper.py +1450 -0
  166. ads/dataset/label_encoder.py +99 -0
  167. ads/dataset/mixin/__init__.py +5 -0
  168. ads/dataset/mixin/dataset_accessor.py +134 -0
  169. ads/dataset/pipeline.py +58 -0
  170. ads/dataset/plot.py +710 -0
  171. ads/dataset/progress.py +86 -0
  172. ads/dataset/recommendation.py +297 -0
  173. ads/dataset/recommendation_transformer.py +502 -0
  174. ads/dataset/regression_dataset.py +14 -0
  175. ads/dataset/sampled_dataset.py +1050 -0
  176. ads/dataset/target.py +98 -0
  177. ads/dataset/timeseries.py +18 -0
  178. ads/dbmixin/__init__.py +5 -0
  179. ads/dbmixin/db_pandas_accessor.py +153 -0
  180. ads/environment/__init__.py +9 -0
  181. ads/environment/ml_runtime.py +66 -0
  182. ads/evaluations/README.md +14 -0
  183. ads/evaluations/__init__.py +109 -0
  184. ads/evaluations/evaluation_plot.py +983 -0
  185. ads/evaluations/evaluator.py +1334 -0
  186. ads/evaluations/statistical_metrics.py +543 -0
  187. ads/experiments/__init__.py +9 -0
  188. ads/experiments/capabilities.md +0 -0
  189. ads/explanations/__init__.py +21 -0
  190. ads/explanations/base_explainer.py +142 -0
  191. ads/explanations/capabilities.md +83 -0
  192. ads/explanations/explainer.py +190 -0
  193. ads/explanations/mlx_global_explainer.py +1050 -0
  194. ads/explanations/mlx_interface.py +386 -0
  195. ads/explanations/mlx_local_explainer.py +287 -0
  196. ads/explanations/mlx_whatif_explainer.py +201 -0
  197. ads/feature_engineering/__init__.py +20 -0
  198. ads/feature_engineering/accessor/__init__.py +5 -0
  199. ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
  200. ads/feature_engineering/accessor/mixin/__init__.py +5 -0
  201. ads/feature_engineering/accessor/mixin/correlation.py +166 -0
  202. ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
  203. ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
  204. ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
  205. ads/feature_engineering/accessor/mixin/utils.py +65 -0
  206. ads/feature_engineering/accessor/series_accessor.py +431 -0
  207. ads/feature_engineering/adsimage/__init__.py +5 -0
  208. ads/feature_engineering/adsimage/image.py +192 -0
  209. ads/feature_engineering/adsimage/image_reader.py +170 -0
  210. ads/feature_engineering/adsimage/interface/__init__.py +5 -0
  211. ads/feature_engineering/adsimage/interface/reader.py +19 -0
  212. ads/feature_engineering/adsstring/__init__.py +7 -0
  213. ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
  214. ads/feature_engineering/adsstring/string/__init__.py +8 -0
  215. ads/feature_engineering/data_schema.json +57 -0
  216. ads/feature_engineering/dataset/__init__.py +5 -0
  217. ads/feature_engineering/dataset/zip_code_data.py +42062 -0
  218. ads/feature_engineering/exceptions.py +40 -0
  219. ads/feature_engineering/feature_type/__init__.py +133 -0
  220. ads/feature_engineering/feature_type/address.py +184 -0
  221. ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
  222. ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
  223. ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
  224. ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
  225. ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
  226. ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
  227. ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
  228. ads/feature_engineering/feature_type/adsstring/string.py +258 -0
  229. ads/feature_engineering/feature_type/base.py +58 -0
  230. ads/feature_engineering/feature_type/boolean.py +183 -0
  231. ads/feature_engineering/feature_type/category.py +146 -0
  232. ads/feature_engineering/feature_type/constant.py +137 -0
  233. ads/feature_engineering/feature_type/continuous.py +151 -0
  234. ads/feature_engineering/feature_type/creditcard.py +314 -0
  235. ads/feature_engineering/feature_type/datetime.py +190 -0
  236. ads/feature_engineering/feature_type/discrete.py +134 -0
  237. ads/feature_engineering/feature_type/document.py +43 -0
  238. ads/feature_engineering/feature_type/gis.py +251 -0
  239. ads/feature_engineering/feature_type/handler/__init__.py +5 -0
  240. ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
  241. ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
  242. ads/feature_engineering/feature_type/handler/warnings.py +128 -0
  243. ads/feature_engineering/feature_type/integer.py +142 -0
  244. ads/feature_engineering/feature_type/ip_address.py +144 -0
  245. ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
  246. ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
  247. ads/feature_engineering/feature_type/lat_long.py +256 -0
  248. ads/feature_engineering/feature_type/object.py +43 -0
  249. ads/feature_engineering/feature_type/ordinal.py +132 -0
  250. ads/feature_engineering/feature_type/phone_number.py +135 -0
  251. ads/feature_engineering/feature_type/string.py +171 -0
  252. ads/feature_engineering/feature_type/text.py +93 -0
  253. ads/feature_engineering/feature_type/unknown.py +43 -0
  254. ads/feature_engineering/feature_type/zip_code.py +164 -0
  255. ads/feature_engineering/feature_type_manager.py +406 -0
  256. ads/feature_engineering/schema.py +795 -0
  257. ads/feature_engineering/utils.py +245 -0
  258. ads/feature_store/.readthedocs.yaml +19 -0
  259. ads/feature_store/README.md +65 -0
  260. ads/feature_store/__init__.py +9 -0
  261. ads/feature_store/common/__init__.py +0 -0
  262. ads/feature_store/common/enums.py +339 -0
  263. ads/feature_store/common/exceptions.py +18 -0
  264. ads/feature_store/common/spark_session_singleton.py +125 -0
  265. ads/feature_store/common/utils/__init__.py +0 -0
  266. ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
  267. ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
  268. ads/feature_store/common/utils/transformation_utils.py +82 -0
  269. ads/feature_store/common/utils/utility.py +403 -0
  270. ads/feature_store/data_validation/__init__.py +0 -0
  271. ads/feature_store/data_validation/great_expectation.py +129 -0
  272. ads/feature_store/dataset.py +1230 -0
  273. ads/feature_store/dataset_job.py +530 -0
  274. ads/feature_store/docs/Dockerfile +7 -0
  275. ads/feature_store/docs/Makefile +44 -0
  276. ads/feature_store/docs/conf.py +28 -0
  277. ads/feature_store/docs/requirements.txt +14 -0
  278. ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
  279. ads/feature_store/docs/source/cicd.rst +137 -0
  280. ads/feature_store/docs/source/conf.py +86 -0
  281. ads/feature_store/docs/source/data_versioning.rst +33 -0
  282. ads/feature_store/docs/source/dataset.rst +388 -0
  283. ads/feature_store/docs/source/dataset_job.rst +27 -0
  284. ads/feature_store/docs/source/demo.rst +70 -0
  285. ads/feature_store/docs/source/entity.rst +78 -0
  286. ads/feature_store/docs/source/feature_group.rst +624 -0
  287. ads/feature_store/docs/source/feature_group_job.rst +29 -0
  288. ads/feature_store/docs/source/feature_store.rst +122 -0
  289. ads/feature_store/docs/source/feature_store_class.rst +123 -0
  290. ads/feature_store/docs/source/feature_validation.rst +66 -0
  291. ads/feature_store/docs/source/figures/cicd.png +0 -0
  292. ads/feature_store/docs/source/figures/data_validation.png +0 -0
  293. ads/feature_store/docs/source/figures/data_versioning.png +0 -0
  294. ads/feature_store/docs/source/figures/dataset.gif +0 -0
  295. ads/feature_store/docs/source/figures/dataset.png +0 -0
  296. ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
  297. ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
  298. ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
  299. ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
  300. ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
  301. ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
  302. ads/feature_store/docs/source/figures/entity.png +0 -0
  303. ads/feature_store/docs/source/figures/feature_group.png +0 -0
  304. ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
  305. ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
  306. ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
  307. ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
  308. ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
  309. ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
  310. ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
  311. ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
  312. ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
  313. ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
  314. ads/feature_store/docs/source/figures/overview.png +0 -0
  315. ads/feature_store/docs/source/figures/resource_manager.png +0 -0
  316. ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
  317. ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
  318. ads/feature_store/docs/source/figures/stats_1.png +0 -0
  319. ads/feature_store/docs/source/figures/stats_2.png +0 -0
  320. ads/feature_store/docs/source/figures/stats_d.png +0 -0
  321. ads/feature_store/docs/source/figures/stats_fg.png +0 -0
  322. ads/feature_store/docs/source/figures/transformation.png +0 -0
  323. ads/feature_store/docs/source/figures/transformations.gif +0 -0
  324. ads/feature_store/docs/source/figures/validation.png +0 -0
  325. ads/feature_store/docs/source/figures/validation_fg.png +0 -0
  326. ads/feature_store/docs/source/figures/validation_results.png +0 -0
  327. ads/feature_store/docs/source/figures/validation_summary.png +0 -0
  328. ads/feature_store/docs/source/index.rst +81 -0
  329. ads/feature_store/docs/source/module.rst +8 -0
  330. ads/feature_store/docs/source/notebook.rst +94 -0
  331. ads/feature_store/docs/source/overview.rst +47 -0
  332. ads/feature_store/docs/source/quickstart.rst +176 -0
  333. ads/feature_store/docs/source/release_notes.rst +194 -0
  334. ads/feature_store/docs/source/setup_feature_store.rst +81 -0
  335. ads/feature_store/docs/source/statistics.rst +58 -0
  336. ads/feature_store/docs/source/transformation.rst +199 -0
  337. ads/feature_store/docs/source/ui.rst +65 -0
  338. ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
  339. ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
  340. ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
  341. ads/feature_store/entity.py +718 -0
  342. ads/feature_store/execution_strategy/__init__.py +0 -0
  343. ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
  344. ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
  345. ads/feature_store/execution_strategy/engine/__init__.py +0 -0
  346. ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
  347. ads/feature_store/execution_strategy/execution_strategy.py +113 -0
  348. ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
  349. ads/feature_store/execution_strategy/spark/__init__.py +0 -0
  350. ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
  351. ads/feature_store/feature.py +192 -0
  352. ads/feature_store/feature_group.py +1494 -0
  353. ads/feature_store/feature_group_expectation.py +346 -0
  354. ads/feature_store/feature_group_job.py +602 -0
  355. ads/feature_store/feature_lineage/__init__.py +0 -0
  356. ads/feature_store/feature_lineage/graphviz_service.py +180 -0
  357. ads/feature_store/feature_option_details.py +50 -0
  358. ads/feature_store/feature_statistics/__init__.py +0 -0
  359. ads/feature_store/feature_statistics/statistics_service.py +99 -0
  360. ads/feature_store/feature_store.py +699 -0
  361. ads/feature_store/feature_store_registrar.py +518 -0
  362. ads/feature_store/input_feature_detail.py +149 -0
  363. ads/feature_store/mixin/__init__.py +4 -0
  364. ads/feature_store/mixin/oci_feature_store.py +145 -0
  365. ads/feature_store/model_details.py +73 -0
  366. ads/feature_store/query/__init__.py +0 -0
  367. ads/feature_store/query/filter.py +266 -0
  368. ads/feature_store/query/generator/__init__.py +0 -0
  369. ads/feature_store/query/generator/query_generator.py +298 -0
  370. ads/feature_store/query/join.py +161 -0
  371. ads/feature_store/query/query.py +403 -0
  372. ads/feature_store/query/validator/__init__.py +0 -0
  373. ads/feature_store/query/validator/query_validator.py +57 -0
  374. ads/feature_store/response/__init__.py +0 -0
  375. ads/feature_store/response/response_builder.py +68 -0
  376. ads/feature_store/service/__init__.py +0 -0
  377. ads/feature_store/service/oci_dataset.py +139 -0
  378. ads/feature_store/service/oci_dataset_job.py +199 -0
  379. ads/feature_store/service/oci_entity.py +125 -0
  380. ads/feature_store/service/oci_feature_group.py +164 -0
  381. ads/feature_store/service/oci_feature_group_job.py +214 -0
  382. ads/feature_store/service/oci_feature_store.py +182 -0
  383. ads/feature_store/service/oci_lineage.py +87 -0
  384. ads/feature_store/service/oci_transformation.py +104 -0
  385. ads/feature_store/statistics/__init__.py +0 -0
  386. ads/feature_store/statistics/abs_feature_value.py +49 -0
  387. ads/feature_store/statistics/charts/__init__.py +0 -0
  388. ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
  389. ads/feature_store/statistics/charts/box_plot.py +148 -0
  390. ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
  391. ads/feature_store/statistics/charts/probability_distribution.py +68 -0
  392. ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
  393. ads/feature_store/statistics/feature_stat.py +126 -0
  394. ads/feature_store/statistics/generic_feature_value.py +33 -0
  395. ads/feature_store/statistics/statistics.py +41 -0
  396. ads/feature_store/statistics_config.py +101 -0
  397. ads/feature_store/templates/feature_store_template.yaml +45 -0
  398. ads/feature_store/transformation.py +499 -0
  399. ads/feature_store/validation_output.py +57 -0
  400. ads/hpo/__init__.py +9 -0
  401. ads/hpo/_imports.py +91 -0
  402. ads/hpo/ads_search_space.py +439 -0
  403. ads/hpo/distributions.py +325 -0
  404. ads/hpo/objective.py +280 -0
  405. ads/hpo/search_cv.py +1657 -0
  406. ads/hpo/stopping_criterion.py +75 -0
  407. ads/hpo/tuner_artifact.py +413 -0
  408. ads/hpo/utils.py +91 -0
  409. ads/hpo/validation.py +140 -0
  410. ads/hpo/visualization/__init__.py +5 -0
  411. ads/hpo/visualization/_contour.py +23 -0
  412. ads/hpo/visualization/_edf.py +20 -0
  413. ads/hpo/visualization/_intermediate_values.py +21 -0
  414. ads/hpo/visualization/_optimization_history.py +25 -0
  415. ads/hpo/visualization/_parallel_coordinate.py +169 -0
  416. ads/hpo/visualization/_param_importances.py +26 -0
  417. ads/jobs/__init__.py +53 -0
  418. ads/jobs/ads_job.py +663 -0
  419. ads/jobs/builders/__init__.py +5 -0
  420. ads/jobs/builders/base.py +156 -0
  421. ads/jobs/builders/infrastructure/__init__.py +6 -0
  422. ads/jobs/builders/infrastructure/base.py +165 -0
  423. ads/jobs/builders/infrastructure/dataflow.py +1252 -0
  424. ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
  425. ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
  426. ads/jobs/builders/infrastructure/utils.py +65 -0
  427. ads/jobs/builders/runtimes/__init__.py +5 -0
  428. ads/jobs/builders/runtimes/artifact.py +338 -0
  429. ads/jobs/builders/runtimes/base.py +325 -0
  430. ads/jobs/builders/runtimes/container_runtime.py +242 -0
  431. ads/jobs/builders/runtimes/python_runtime.py +1016 -0
  432. ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
  433. ads/jobs/cli.py +104 -0
  434. ads/jobs/env_var_parser.py +131 -0
  435. ads/jobs/extension.py +160 -0
  436. ads/jobs/schema/__init__.py +5 -0
  437. ads/jobs/schema/infrastructure_schema.json +116 -0
  438. ads/jobs/schema/job_schema.json +42 -0
  439. ads/jobs/schema/runtime_schema.json +183 -0
  440. ads/jobs/schema/validator.py +141 -0
  441. ads/jobs/serializer.py +296 -0
  442. ads/jobs/templates/__init__.py +5 -0
  443. ads/jobs/templates/container.py +6 -0
  444. ads/jobs/templates/driver_notebook.py +177 -0
  445. ads/jobs/templates/driver_oci.py +500 -0
  446. ads/jobs/templates/driver_python.py +48 -0
  447. ads/jobs/templates/driver_pytorch.py +852 -0
  448. ads/jobs/templates/driver_utils.py +615 -0
  449. ads/jobs/templates/hostname_from_env.c +55 -0
  450. ads/jobs/templates/oci_metrics.py +181 -0
  451. ads/jobs/utils.py +104 -0
  452. ads/llm/__init__.py +28 -0
  453. ads/llm/autogen/__init__.py +2 -0
  454. ads/llm/autogen/constants.py +15 -0
  455. ads/llm/autogen/reports/__init__.py +2 -0
  456. ads/llm/autogen/reports/base.py +67 -0
  457. ads/llm/autogen/reports/data.py +103 -0
  458. ads/llm/autogen/reports/session.py +526 -0
  459. ads/llm/autogen/reports/templates/chat_box.html +13 -0
  460. ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
  461. ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
  462. ads/llm/autogen/reports/utils.py +56 -0
  463. ads/llm/autogen/v02/__init__.py +4 -0
  464. ads/llm/autogen/v02/client.py +295 -0
  465. ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
  466. ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
  467. ads/llm/autogen/v02/loggers/__init__.py +6 -0
  468. ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
  469. ads/llm/autogen/v02/loggers/session_logger.py +580 -0
  470. ads/llm/autogen/v02/loggers/utils.py +86 -0
  471. ads/llm/autogen/v02/runtime_logging.py +163 -0
  472. ads/llm/chain.py +268 -0
  473. ads/llm/chat_template.py +31 -0
  474. ads/llm/deploy.py +63 -0
  475. ads/llm/guardrails/__init__.py +5 -0
  476. ads/llm/guardrails/base.py +442 -0
  477. ads/llm/guardrails/huggingface.py +44 -0
  478. ads/llm/langchain/__init__.py +5 -0
  479. ads/llm/langchain/plugins/__init__.py +5 -0
  480. ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
  481. ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
  482. ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
  483. ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
  484. ads/llm/langchain/plugins/llms/__init__.py +5 -0
  485. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
  486. ads/llm/requirements.txt +3 -0
  487. ads/llm/serialize.py +219 -0
  488. ads/llm/serializers/__init__.py +0 -0
  489. ads/llm/serializers/retrieval_qa.py +153 -0
  490. ads/llm/serializers/runnable_parallel.py +27 -0
  491. ads/llm/templates/score_chain.jinja2 +155 -0
  492. ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
  493. ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
  494. ads/model/__init__.py +52 -0
  495. ads/model/artifact.py +573 -0
  496. ads/model/artifact_downloader.py +254 -0
  497. ads/model/artifact_uploader.py +267 -0
  498. ads/model/base_properties.py +238 -0
  499. ads/model/common/.model-ignore +66 -0
  500. ads/model/common/__init__.py +5 -0
  501. ads/model/common/utils.py +142 -0
  502. ads/model/datascience_model.py +2635 -0
  503. ads/model/deployment/__init__.py +20 -0
  504. ads/model/deployment/common/__init__.py +5 -0
  505. ads/model/deployment/common/utils.py +308 -0
  506. ads/model/deployment/model_deployer.py +466 -0
  507. ads/model/deployment/model_deployment.py +1846 -0
  508. ads/model/deployment/model_deployment_infrastructure.py +671 -0
  509. ads/model/deployment/model_deployment_properties.py +493 -0
  510. ads/model/deployment/model_deployment_runtime.py +838 -0
  511. ads/model/extractor/__init__.py +5 -0
  512. ads/model/extractor/automl_extractor.py +74 -0
  513. ads/model/extractor/embedding_onnx_extractor.py +80 -0
  514. ads/model/extractor/huggingface_extractor.py +88 -0
  515. ads/model/extractor/keras_extractor.py +84 -0
  516. ads/model/extractor/lightgbm_extractor.py +93 -0
  517. ads/model/extractor/model_info_extractor.py +114 -0
  518. ads/model/extractor/model_info_extractor_factory.py +105 -0
  519. ads/model/extractor/pytorch_extractor.py +87 -0
  520. ads/model/extractor/sklearn_extractor.py +112 -0
  521. ads/model/extractor/spark_extractor.py +89 -0
  522. ads/model/extractor/tensorflow_extractor.py +85 -0
  523. ads/model/extractor/xgboost_extractor.py +94 -0
  524. ads/model/framework/__init__.py +5 -0
  525. ads/model/framework/automl_model.py +178 -0
  526. ads/model/framework/embedding_onnx_model.py +438 -0
  527. ads/model/framework/huggingface_model.py +399 -0
  528. ads/model/framework/lightgbm_model.py +266 -0
  529. ads/model/framework/pytorch_model.py +266 -0
  530. ads/model/framework/sklearn_model.py +250 -0
  531. ads/model/framework/spark_model.py +326 -0
  532. ads/model/framework/tensorflow_model.py +254 -0
  533. ads/model/framework/xgboost_model.py +258 -0
  534. ads/model/generic_model.py +3518 -0
  535. ads/model/model_artifact_boilerplate/README.md +381 -0
  536. ads/model/model_artifact_boilerplate/__init__.py +5 -0
  537. ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
  538. ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
  539. ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
  540. ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
  541. ads/model/model_artifact_boilerplate/score.py +61 -0
  542. ads/model/model_file_description_schema.json +68 -0
  543. ads/model/model_introspect.py +331 -0
  544. ads/model/model_metadata.py +1810 -0
  545. ads/model/model_metadata_mixin.py +460 -0
  546. ads/model/model_properties.py +63 -0
  547. ads/model/model_version_set.py +739 -0
  548. ads/model/runtime/__init__.py +5 -0
  549. ads/model/runtime/env_info.py +306 -0
  550. ads/model/runtime/model_deployment_details.py +37 -0
  551. ads/model/runtime/model_provenance_details.py +58 -0
  552. ads/model/runtime/runtime_info.py +81 -0
  553. ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
  554. ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
  555. ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
  556. ads/model/runtime/utils.py +201 -0
  557. ads/model/serde/__init__.py +5 -0
  558. ads/model/serde/common.py +40 -0
  559. ads/model/serde/model_input.py +547 -0
  560. ads/model/serde/model_serializer.py +1184 -0
  561. ads/model/service/__init__.py +5 -0
  562. ads/model/service/oci_datascience_model.py +1076 -0
  563. ads/model/service/oci_datascience_model_deployment.py +500 -0
  564. ads/model/service/oci_datascience_model_version_set.py +176 -0
  565. ads/model/transformer/__init__.py +5 -0
  566. ads/model/transformer/onnx_transformer.py +324 -0
  567. ads/mysqldb/__init__.py +5 -0
  568. ads/mysqldb/mysql_db.py +227 -0
  569. ads/opctl/__init__.py +18 -0
  570. ads/opctl/anomaly_detection.py +11 -0
  571. ads/opctl/backend/__init__.py +5 -0
  572. ads/opctl/backend/ads_dataflow.py +353 -0
  573. ads/opctl/backend/ads_ml_job.py +710 -0
  574. ads/opctl/backend/ads_ml_pipeline.py +164 -0
  575. ads/opctl/backend/ads_model_deployment.py +209 -0
  576. ads/opctl/backend/base.py +146 -0
  577. ads/opctl/backend/local.py +1053 -0
  578. ads/opctl/backend/marketplace/__init__.py +9 -0
  579. ads/opctl/backend/marketplace/helm_helper.py +173 -0
  580. ads/opctl/backend/marketplace/local_marketplace.py +271 -0
  581. ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
  582. ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
  583. ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
  584. ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
  585. ads/opctl/backend/marketplace/models/__init__.py +5 -0
  586. ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
  587. ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
  588. ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
  589. ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
  590. ads/opctl/cli.py +707 -0
  591. ads/opctl/cmds.py +869 -0
  592. ads/opctl/conda/__init__.py +5 -0
  593. ads/opctl/conda/cli.py +193 -0
  594. ads/opctl/conda/cmds.py +749 -0
  595. ads/opctl/conda/config.yaml +34 -0
  596. ads/opctl/conda/manifest_template.yaml +13 -0
  597. ads/opctl/conda/multipart_uploader.py +188 -0
  598. ads/opctl/conda/pack.py +89 -0
  599. ads/opctl/config/__init__.py +5 -0
  600. ads/opctl/config/base.py +57 -0
  601. ads/opctl/config/diagnostics/__init__.py +5 -0
  602. ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
  603. ads/opctl/config/merger.py +255 -0
  604. ads/opctl/config/resolver.py +297 -0
  605. ads/opctl/config/utils.py +79 -0
  606. ads/opctl/config/validator.py +17 -0
  607. ads/opctl/config/versioner.py +68 -0
  608. ads/opctl/config/yaml_parsers/__init__.py +7 -0
  609. ads/opctl/config/yaml_parsers/base.py +58 -0
  610. ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
  611. ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
  612. ads/opctl/constants.py +66 -0
  613. ads/opctl/decorator/__init__.py +5 -0
  614. ads/opctl/decorator/common.py +129 -0
  615. ads/opctl/diagnostics/__init__.py +5 -0
  616. ads/opctl/diagnostics/__main__.py +25 -0
  617. ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
  618. ads/opctl/diagnostics/check_requirements.py +144 -0
  619. ads/opctl/diagnostics/requirement_exception.py +9 -0
  620. ads/opctl/distributed/README.md +109 -0
  621. ads/opctl/distributed/__init__.py +5 -0
  622. ads/opctl/distributed/certificates.py +32 -0
  623. ads/opctl/distributed/cli.py +207 -0
  624. ads/opctl/distributed/cmds.py +731 -0
  625. ads/opctl/distributed/common/__init__.py +5 -0
  626. ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
  627. ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
  628. ads/opctl/distributed/common/cluster_config_helper.py +103 -0
  629. ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
  630. ads/opctl/distributed/common/cluster_runner.py +54 -0
  631. ads/opctl/distributed/common/framework_factory.py +29 -0
  632. ads/opctl/docker/Dockerfile.job +103 -0
  633. ads/opctl/docker/Dockerfile.job.arm +107 -0
  634. ads/opctl/docker/Dockerfile.job.gpu +175 -0
  635. ads/opctl/docker/base-env.yaml +13 -0
  636. ads/opctl/docker/cuda.repo +6 -0
  637. ads/opctl/docker/operator/.dockerignore +0 -0
  638. ads/opctl/docker/operator/Dockerfile +41 -0
  639. ads/opctl/docker/operator/Dockerfile.gpu +85 -0
  640. ads/opctl/docker/operator/cuda.repo +6 -0
  641. ads/opctl/docker/operator/environment.yaml +8 -0
  642. ads/opctl/forecast.py +11 -0
  643. ads/opctl/index.yaml +3 -0
  644. ads/opctl/model/__init__.py +5 -0
  645. ads/opctl/model/cli.py +65 -0
  646. ads/opctl/model/cmds.py +73 -0
  647. ads/opctl/operator/README.md +4 -0
  648. ads/opctl/operator/__init__.py +31 -0
  649. ads/opctl/operator/cli.py +344 -0
  650. ads/opctl/operator/cmd.py +596 -0
  651. ads/opctl/operator/common/__init__.py +5 -0
  652. ads/opctl/operator/common/backend_factory.py +460 -0
  653. ads/opctl/operator/common/const.py +27 -0
  654. ads/opctl/operator/common/data/synthetic.csv +16001 -0
  655. ads/opctl/operator/common/dictionary_merger.py +148 -0
  656. ads/opctl/operator/common/errors.py +42 -0
  657. ads/opctl/operator/common/operator_config.py +99 -0
  658. ads/opctl/operator/common/operator_loader.py +811 -0
  659. ads/opctl/operator/common/operator_schema.yaml +130 -0
  660. ads/opctl/operator/common/operator_yaml_generator.py +152 -0
  661. ads/opctl/operator/common/utils.py +208 -0
  662. ads/opctl/operator/lowcode/__init__.py +5 -0
  663. ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
  664. ads/opctl/operator/lowcode/anomaly/README.md +207 -0
  665. ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
  666. ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
  667. ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
  668. ads/opctl/operator/lowcode/anomaly/const.py +167 -0
  669. ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
  670. ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
  671. ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
  672. ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
  673. ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
  674. ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
  675. ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
  676. ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
  677. ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
  678. ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
  679. ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
  680. ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
  681. ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
  682. ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
  683. ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
  684. ads/opctl/operator/lowcode/common/__init__.py +5 -0
  685. ads/opctl/operator/lowcode/common/const.py +10 -0
  686. ads/opctl/operator/lowcode/common/data.py +116 -0
  687. ads/opctl/operator/lowcode/common/errors.py +47 -0
  688. ads/opctl/operator/lowcode/common/transformations.py +296 -0
  689. ads/opctl/operator/lowcode/common/utils.py +384 -0
  690. ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
  691. ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
  692. ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
  693. ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
  694. ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
  695. ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
  696. ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
  697. ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
  698. ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
  699. ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
  700. ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
  701. ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
  702. ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
  703. ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
  704. ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
  705. ads/opctl/operator/lowcode/forecast/README.md +209 -0
  706. ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
  707. ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
  708. ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
  709. ads/opctl/operator/lowcode/forecast/const.py +92 -0
  710. ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
  711. ads/opctl/operator/lowcode/forecast/errors.py +26 -0
  712. ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
  713. ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
  714. ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
  715. ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
  716. ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
  717. ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
  718. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
  719. ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
  720. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
  721. ads/opctl/operator/lowcode/forecast/model/prophet.py +445 -0
  722. ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
  723. ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
  724. ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
  725. ads/opctl/operator/lowcode/forecast/utils.py +397 -0
  726. ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
  727. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
  728. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
  729. ads/opctl/operator/lowcode/pii/MLoperator +17 -0
  730. ads/opctl/operator/lowcode/pii/README.md +208 -0
  731. ads/opctl/operator/lowcode/pii/__init__.py +5 -0
  732. ads/opctl/operator/lowcode/pii/__main__.py +78 -0
  733. ads/opctl/operator/lowcode/pii/cmd.py +39 -0
  734. ads/opctl/operator/lowcode/pii/constant.py +84 -0
  735. ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
  736. ads/opctl/operator/lowcode/pii/errors.py +27 -0
  737. ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
  738. ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
  739. ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
  740. ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
  741. ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
  742. ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
  743. ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
  744. ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
  745. ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
  746. ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
  747. ads/opctl/operator/lowcode/pii/model/report.py +487 -0
  748. ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
  749. ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
  750. ads/opctl/operator/lowcode/pii/utils.py +43 -0
  751. ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
  752. ads/opctl/operator/lowcode/recommender/README.md +206 -0
  753. ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
  754. ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
  755. ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
  756. ads/opctl/operator/lowcode/recommender/constant.py +30 -0
  757. ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
  758. ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
  759. ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
  760. ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
  761. ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
  762. ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
  763. ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
  764. ads/opctl/operator/lowcode/recommender/utils.py +13 -0
  765. ads/opctl/operator/runtime/__init__.py +5 -0
  766. ads/opctl/operator/runtime/const.py +17 -0
  767. ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
  768. ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
  769. ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
  770. ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
  771. ads/opctl/operator/runtime/runtime.py +115 -0
  772. ads/opctl/schema.yaml.yml +36 -0
  773. ads/opctl/script.py +40 -0
  774. ads/opctl/spark/__init__.py +5 -0
  775. ads/opctl/spark/cli.py +43 -0
  776. ads/opctl/spark/cmds.py +147 -0
  777. ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
  778. ads/opctl/utils.py +344 -0
  779. ads/oracledb/__init__.py +5 -0
  780. ads/oracledb/oracle_db.py +346 -0
  781. ads/pipeline/__init__.py +39 -0
  782. ads/pipeline/ads_pipeline.py +2279 -0
  783. ads/pipeline/ads_pipeline_run.py +772 -0
  784. ads/pipeline/ads_pipeline_step.py +605 -0
  785. ads/pipeline/builders/__init__.py +5 -0
  786. ads/pipeline/builders/infrastructure/__init__.py +5 -0
  787. ads/pipeline/builders/infrastructure/custom_script.py +32 -0
  788. ads/pipeline/cli.py +119 -0
  789. ads/pipeline/extension.py +291 -0
  790. ads/pipeline/schema/__init__.py +5 -0
  791. ads/pipeline/schema/cs_step_schema.json +35 -0
  792. ads/pipeline/schema/ml_step_schema.json +31 -0
  793. ads/pipeline/schema/pipeline_schema.json +71 -0
  794. ads/pipeline/visualizer/__init__.py +5 -0
  795. ads/pipeline/visualizer/base.py +570 -0
  796. ads/pipeline/visualizer/graph_renderer.py +272 -0
  797. ads/pipeline/visualizer/text_renderer.py +84 -0
  798. ads/secrets/__init__.py +11 -0
  799. ads/secrets/adb.py +386 -0
  800. ads/secrets/auth_token.py +86 -0
  801. ads/secrets/big_data_service.py +365 -0
  802. ads/secrets/mysqldb.py +149 -0
  803. ads/secrets/oracledb.py +160 -0
  804. ads/secrets/secrets.py +407 -0
  805. ads/telemetry/__init__.py +7 -0
  806. ads/telemetry/base.py +69 -0
  807. ads/telemetry/client.py +125 -0
  808. ads/telemetry/telemetry.py +257 -0
  809. ads/templates/dataflow_pyspark.jinja2 +13 -0
  810. ads/templates/dataflow_sparksql.jinja2 +22 -0
  811. ads/templates/func.jinja2 +20 -0
  812. ads/templates/schemas/openapi.json +1740 -0
  813. ads/templates/score-pkl.jinja2 +173 -0
  814. ads/templates/score.jinja2 +322 -0
  815. ads/templates/score_embedding_onnx.jinja2 +202 -0
  816. ads/templates/score_generic.jinja2 +165 -0
  817. ads/templates/score_huggingface_pipeline.jinja2 +217 -0
  818. ads/templates/score_lightgbm.jinja2 +185 -0
  819. ads/templates/score_onnx.jinja2 +407 -0
  820. ads/templates/score_onnx_new.jinja2 +473 -0
  821. ads/templates/score_oracle_automl.jinja2 +185 -0
  822. ads/templates/score_pyspark.jinja2 +154 -0
  823. ads/templates/score_pytorch.jinja2 +219 -0
  824. ads/templates/score_scikit-learn.jinja2 +184 -0
  825. ads/templates/score_tensorflow.jinja2 +184 -0
  826. ads/templates/score_xgboost.jinja2 +178 -0
  827. ads/text_dataset/__init__.py +5 -0
  828. ads/text_dataset/backends.py +211 -0
  829. ads/text_dataset/dataset.py +445 -0
  830. ads/text_dataset/extractor.py +207 -0
  831. ads/text_dataset/options.py +53 -0
  832. ads/text_dataset/udfs.py +22 -0
  833. ads/text_dataset/utils.py +49 -0
  834. ads/type_discovery/__init__.py +9 -0
  835. ads/type_discovery/abstract_detector.py +21 -0
  836. ads/type_discovery/constant_detector.py +41 -0
  837. ads/type_discovery/continuous_detector.py +54 -0
  838. ads/type_discovery/credit_card_detector.py +99 -0
  839. ads/type_discovery/datetime_detector.py +92 -0
  840. ads/type_discovery/discrete_detector.py +118 -0
  841. ads/type_discovery/document_detector.py +146 -0
  842. ads/type_discovery/ip_detector.py +68 -0
  843. ads/type_discovery/latlon_detector.py +90 -0
  844. ads/type_discovery/phone_number_detector.py +63 -0
  845. ads/type_discovery/type_discovery_driver.py +87 -0
  846. ads/type_discovery/typed_feature.py +594 -0
  847. ads/type_discovery/unknown_detector.py +41 -0
  848. ads/type_discovery/zipcode_detector.py +48 -0
  849. ads/vault/__init__.py +7 -0
  850. ads/vault/vault.py +237 -0
  851. {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/METADATA +150 -150
  852. oracle_ads-2.13.9rc1.dist-info/RECORD +858 -0
  853. {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/WHEEL +1 -2
  854. {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/entry_points.txt +2 -1
  855. oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
  856. oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
  857. {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.9rc1.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,983 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8; -*-
3
+
4
+ # Copyright (c) 2020, 2023 Oracle and/or its affiliates.
5
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
+
7
+ from __future__ import print_function, absolute_import, division
8
+
9
+ import base64
10
+ from io import BytesIO
11
+ import matplotlib as mpl
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.lines as mlines
14
+ from matplotlib.ticker import FormatStrFormatter
15
+ import numpy as np
16
+ import math
17
+ from ads.common import logger
18
+ from ads.common.decorator.runtime_dependency import (
19
+ runtime_dependency,
20
+ OptionalDependency,
21
+ )
22
+ import itertools
23
+ import pandas as pd
24
+
25
+ MAX_TITLE_LEN = 20
26
+ MAX_LEGEND_LEN = 10
27
+ MAX_PLOTS_PER_ROW = 2
28
+ # Maximum class number evaluation plotting supporting for multiclass problems
29
+ MAX_PLOTTING_CLASSES = 10
30
+ # Maximum characters in class label able to be shown without being truncated
31
+ MAX_CHARACTERS_LEN = 13
32
+
33
+
34
def _fig_to_html(fig):
    """Render a figure as a self-contained HTML ``<img>`` tag.

    The figure is written as PNG into an in-memory buffer and embedded in the
    tag as a base64 ``data:`` URI, so no file is created on disk.

    Parameters
    ----------
    fig:
        Object exposing ``savefig(file, format=...)``.

    Returns
    -------
    str
        ``<img>`` markup whose ``src`` carries the base64-encoded PNG bytes.
    """
    buffer = BytesIO()
    fig.savefig(buffer, format="png")
    png_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return "<img src='data:image/png;base64,{}'>".format(png_b64)
41
+
42
+
43
+ class EvaluationPlot:
44
+ """EvaluationPlot holds data and methods for plots and it used to output them
45
+
46
+ Attributes
47
+ ----------
48
+ baseline (bool):
49
+ whether to plot the null model or zero information model
50
+ baseline_kwargs (dict):
51
+ keyword arguments for the baseline plot
52
+ color_wheel (dict):
53
+ color information used by the plot
54
+ font_sz (dict):
55
+ dictionary of font sizes keyed by size name ("xl", "l", "m", "s", "xs")
56
+ perfect (bool):
57
+ determines whether a "perfect" classifier curve is displayed
58
+ perfect_kwargs (dict):
59
+ parameters for the perfect classifier for precision/recall curves
60
+ prob_type (str):
61
+ model type, i.e. classification or regression
62
+
63
+ Methods
64
+ -------
65
+ get_legend_labels(legend_labels)
66
+ Renders the legend labels on the plot
67
+ plot(evaluation, plots, num_classes, perfect, baseline, legend_labels)
68
+ Generates the evaluation plot
69
+ """
70
+
71
+ # dict of font sizes keyed by size name
72
+ font_sz = {
73
+ "xl": 16, # Plot type title
74
+ "l": 14, # Individual plot title (name of model)
75
+ "m": 12, # Axis titles
76
+ "s": 10, # Axis labels
77
+ "xs": 8,
78
+ } # text within the plot
79
+
80
+ baseline_kwargs = {"ls": "--", "c": ".2"} # lw = ??
81
+ perfect_kwargs = {"label": "Perfect Classifier", "ls": "--", "color": "gold"}
82
+ perfect = None
83
+ baseline = None
84
+ prob_type = None
85
+ color_wheel = ["teal", "blueviolet", "forestgreen", "peru", "y", "dodgerblue", "r"]
86
+
87
+ _pretty_titles_map = {
88
+ "normalized_confusion_matrix": "Normalized Confusion Matrix",
89
+ "lift_chart": "Lift Chart",
90
+ "gain_chart": "Gain Chart",
91
+ "ks_statistics": "KS Statistics",
92
+ "residuals_qq": "Residuals Q-Q Plot",
93
+ "residuals_vs_predicted": "Residuals vs Predicted",
94
+ "residuals_vs_observed": "Residuals vs Observed",
95
+ "observed_vs_predicted": "Observed vs Predicted",
96
+ "precision_by_label": "Precision by Label",
97
+ "recall_by_label": "Recall by Label",
98
+ "f1_by_label": "F1 by Label",
99
+ "jaccard_by_label": "Jaccard by Label",
100
+ "pr_curve": "PR Curve",
101
+ "roc_curve": "ROC Curve",
102
+ "pr_and_roc_curve": "PR Curve, ROC Curve",
103
+ "lift_and_gain_chart": "Lift Chart, Gain Chart",
104
+ }
105
+
106
+ _ugly_titles_map = {v: k for k, v in _pretty_titles_map.items()}
107
+
108
+ double_overlay_plots = ["pr_and_roc_curve", "lift_and_gain_chart"]
109
+ single_overlay_plots = ["lift_chart", "gain_chart", "roc_curve", "pr_curve"]
110
+
111
+ _bin_plots = [
112
+ "pr_curve",
113
+ "roc_curve",
114
+ "lift_chart",
115
+ "gain_chart",
116
+ "normalized_confusion_matrix",
117
+ ]
118
+ _multi_plots = [
119
+ "normalized_confusion_matrix",
120
+ "roc_curve",
121
+ "pr_curve",
122
+ "precision_by_label",
123
+ "recall_by_label",
124
+ "f1_by_label",
125
+ "jaccard_by_label",
126
+ ]
127
+ _reg_plots = [
128
+ "observed_vs_predicted",
129
+ "residuals_qq",
130
+ "residuals_vs_predicted",
131
+ "residuals_vs_observed",
132
+ ]
133
+
134
+ # list of detailed descriptions of each plot type for every classification type, can be extended when adding more metrics
135
+ _bin_plots_details = """
136
+ In pattern recognition, information retrieval and binary classification, precision (also called positive predictive value)
137
+ is the fraction of relevant instances among the retrieved instances, while recall (also known as sensitivity) is the
138
+ fraction of relevant instances that have been retrieved over the total amount of relevant instances. \n
139
+ A receiver operating characteristic curve, or ROC curve, is a graphical plot that illustrates the diagnostic ability of a
140
+ binary classifier system as its discrimination threshold is varied. \n
141
+ In data mining and association rule learning, lift is a measure of the performance of a targeting model (association rule)
142
+ at predicting or classifying cases as having an enhanced response (with respect to the population as a whole), measured
143
+ against a random choice targeting model. \n
144
+ A gain graph is a graph whose edges are labelled 'invertibly', or 'orientably', by elements of a group G. \n
145
+ In the field of machine learning and specifically the problem of statistical classification, a confusion matrix, also known
146
+ as an error matrix, is a specific table layout that allows visualization of the performance of an algorithm, typically a
147
+ supervised learning one (in unsupervised learning it is usually called a matching matrix).
148
+ """
149
+ # Removed for now:
150
+ # Kuiper's test (ks_statistics) is used in statistics to test whether a given distribution, or family of
151
+ # distributions, is contradicted by evidence from a sample of data. \n
152
+
153
+ _multi_plots_details = """
154
+ In the field of machine learning and specifically the problem of statistical classification, a confusion matrix, also known
155
+ as an error matrix, is a specific table layout that allows visualization of the performance of an algorithm, typically a
156
+ supervised learning one (in unsupervised learning it is usually called a matching matrix). \n
157
+ A receiver operating characteristic curve, or ROC curve, is a graphical plot that illustrates the diagnostic ability of a
158
+ binary classifier system as its discrimination threshold is varied. \n
159
+ In pattern recognition, information retrieval and binary classification, precision (also called positive predictive value)
160
+ is the fraction of relevant instances among the retrieved instances, while recall (also known as sensitivity) is the
161
+ fraction of relevant instances that have been retrieved over the total amount of relevant instances. \n
162
+ In statistical analysis of binary classification, the F1 score (also F-score or F-measure) is a measure of a test's accuracy.
163
+ It considers both the precision p and the recall r of the test to compute the score: p is the number of correct positive results
164
+ divided by the number of all positive results returned by the classifier, and r is the number of correct positive results divided
165
+ by the number of all relevant samples (all samples that should have been identified as positive). The F1 score is the harmonic mean
166
+ of the precision and recall, where an F1 score reaches its best value at 1 (perfect precision and recall) and worst at 0. \n
167
+ The Jaccard index, also known as Intersection over Union and the Jaccard similarity coefficient, is a statistic used for gauging
168
+ the similarity and diversity of sample sets. The Jaccard coefficient measures similarity between finite sample sets, and is defined
169
+ as the size of the intersection divided by the size of the union of the sample sets
170
+ """
171
+ _reg_plots_details = """
172
+ In statistics, a Q–Q (quantile-quantile) plot is a probability plot, which is a graphical method for comparing two probability distributions
173
+ by plotting their quantiles against each other. \n
174
+ In statistics and optimization, errors and residuals are two closely related and easily confused measures of the deviation of an observed value
175
+ of an element of a statistical sample from its "theoretical value". The error (or disturbance) of an observed value is the deviation of the
176
+ observed value from the (unobservable) true value of a quantity of interest (for example, a population mean), and the residual of an observed
177
+ value is the difference between the observed value and the estimated value of the quantity of interest (for example, a sample mean).
178
+ """
179
+
180
@classmethod
def _get_formatted_title(cls, title, max_len=MAX_TITLE_LEN):
    """Shorten ``title`` for display.

    Titles shorter than ``max_len + 3`` characters are returned unchanged;
    anything longer is cut to its first ``max_len`` characters followed by
    ``"..."``.
    """
    if len(title) >= max_len + 3:
        return title[:max_len] + "..."
    return title
183
+
184
@classmethod
def get_legend_labels(cls, legend_labels):
    """Resolve the legend labels for the plots and display them as a table.

    When ``legend_labels`` is given, it is accepted only if it is a dict whose
    keys include every entry of ``cls.classes``. It is then displayed and
    stored in ``cls.legend_labels``; otherwise an error is logged and
    ``cls.legend_labels`` is not assigned.

    When ``legend_labels`` is ``None``, short labels are derived from
    ``cls.classes`` and stored in ``cls.legend_labels``. The table is displayed
    (and an info message logged) only when the set of generated labels differs
    from the set of class names.

    Parameters
    ----------
    legend_labels (dict):
        key/value dictionary containing legend label data, or ``None`` to
        generate the labels from ``cls.classes``

    Returns
    -------
    Nothing

    Examples
    --------

    EvaluationPlot.get_legend_labels({'class_0': 'green', 'class_1': 'yellow', 'class_2': 'red'})
    """

    # Shows the label mapping as a one-column, centre-aligned HTML table.
    # NOTE(review): the decorators presumably verify that the optional notebook
    # packages are installed before the body runs - confirm in runtime_dependency.
    @runtime_dependency(module="IPython", install_from=OptionalDependency.NOTEBOOK)
    @runtime_dependency(
        module="ipywidgets", object="HTML", install_from=OptionalDependency.NOTEBOOK
    )
    def render_legend_labels(label_dict):
        encodings = pd.DataFrame(
            pd.Series(label_dict, index=label_dict.keys()),
            columns=["Shortened labels"],
        )
        from IPython.core.display import display, HTML

        display(
            HTML(
                encodings.style.format(precision=4)
                .set_properties(**{"text-align": "center"})
                .set_table_styles(
                    [dict(selector="", props=[("text-align", "center")])]
                )
                .set_table_attributes("class=table")
                .set_caption(
                    '<div align="left"><b style="font-size:20px;">'
                    + "Legend for labels of the target feature:</b></div>"
                )
                .to_html()
            )
        )

    if legend_labels is not None:
        # CAUTION: cls.classes is a list of strings. Make sure users know that labels are all converted to strings.
        if isinstance(legend_labels, dict) and set(cls.classes).issubset(
            set(legend_labels.keys())
        ):
            render_legend_labels(legend_labels)
            cls.legend_labels = legend_labels
        else:
            logger.error(
                "The provided `legend_labels` is either not a Python dict or does not possess all possible class labels."
            )
        # User-supplied labels (valid or not) never fall through to the
        # automatic label generation below.
        return

    # Try to remove leading words shared by every label.
    # Returns the stripped labels and a prefix that ends in " ..." and holds the
    # removed words. When no shared leading word exists, the first `max_len`
    # characters are dropped from each label and `prefix` is returned unchanged.
    # NOTE(review): `lab.split()[0]` raises IndexError for an empty or
    # whitespace-only label, which looks reachable once a label consists only
    # of the shared words - confirm whether such labels can occur.
    def _check_for_redundant_words(label_vec, prefix, max_len=MAX_LEGEND_LEN):
        words_found = False
        classes = [lab for lab in label_vec]
        while len(set([lab.split()[0] for lab in classes])) <= 1:
            # remove that word from vec, and add it to prefix
            additional_prefix = classes[0].split()[0]
            # prefix[:-3] drops the trailing "..." before the next word is appended
            prefix = (
                prefix[:-3] + additional_prefix + " ..."
                if words_found
                else additional_prefix + " ..."
            )
            # +1 also removes the character that follows the word
            classes = [lab[len(additional_prefix) + 1 :] for lab in classes]
            words_found = True

        classes = (
            classes if words_found else [label[max_len:] for label in label_vec]
        )
        return classes, prefix

    # Returns a mapping from the real labels to pseudo-labels, used when the
    # pseudo-labels cannot simply be the first characters of the label.
    def _resolve_conflict(label_vec, prefix, max_len=MAX_LEGEND_LEN):
        classes, prefix = _check_for_redundant_words(label_vec, prefix)
        label_dict = _get_labels(classes, max_len=max_len)
        resolved = {}
        for orig, new in label_dict.items():
            # prefix[:-3] + orig rebuilds the full original label used as key;
            # the short label gets a leading "..." unless it already has one.
            resolved[prefix[:-3] + orig] = "..." + new if new[:3] != "..." else new
        return resolved

    # Returns a mapping from the provided list of strings to short substrings.
    # Labels shorter than max_len + 3 are kept as they are; longer ones are cut
    # to max_len characters plus "...". Labels that collapse to the same short
    # form are handed to _resolve_conflict.
    def _get_labels(classes, max_len=MAX_LEGEND_LEN):
        conflict_dict = {}
        for label in classes:
            prefix = label if len(label) < max_len + 3 else label[:max_len] + "..."
            if conflict_dict.get(prefix, None) is None:
                conflict_dict[prefix] = [label]
            else:
                conflict_dict[prefix].append(label)
        out = {}
        for k, v in conflict_dict.items():
            if len(v) == 1:
                out[v[0]] = k
            else:
                resolved = _resolve_conflict(v, k, max_len=MAX_LEGEND_LEN)
                out.update(resolved)
        return out

    cls.legend_labels = _get_labels(cls.classes)
    # Keys are the original labels and values the displayed ones; the sets
    # differ as soon as a label was changed.
    if set(cls.legend_labels.keys()) != set(cls.legend_labels.values()):
        logger.info(
            f"Class labels greater than {MAX_CHARACTERS_LEN} characters have been truncated. "
            "Use the `legend_labels` parameter to define labels."
        )
        render_legend_labels(cls.legend_labels)
298
+
299
+ # evaluation is a DataFrame with models as columns and metrics as rows
300
@classmethod
def plot(
    cls,
    evaluation,
    plots,
    num_classes,
    perfect=False,
    baseline=True,
    legend_labels=None,
):
    """Generates the evaluation plots and returns them as HTML image snippets.

    Parameters
    ----------
    evaluation (DataFrame):
        DataFrame with models as columns and metrics as rows.
    plots (list of str, optional):
        Names of the plots to generate. If `None`, the default list for the
        problem type (`_bin_plots`, `_multi_plots` or `_reg_plots`) is used.
        The list passed in is never modified.
    num_classes (int):
        The number of classes for the model. `2` selects the binary plots,
        more than `2` the multiclass plots, anything else the regression plots.
    perfect (bool, optional):
        Whether to display the curve of a perfect classifier. Default value is `False`.
    baseline (bool, optional):
        Whether to display the curve of the baseline, featureless model. Default value is `True`.
    legend_labels (dict, optional):
        Legend labels dictionary. Default value is `None`. If legend_labels not specified class names will be used for plots.

    Returns
    -------
    list of str
        HTML `<img>` snippets. For every plot type a title image is appended,
        followed by the plot image when the plot could be generated.
    """

    cls.perfect = perfect
    cls.baseline = baseline
    # get plots to show
    if num_classes == 2:
        cls.prob_type = "_bin"
    elif num_classes > 2:
        cls.prob_type = "_multi"
    else:
        cls.prob_type = "_reg"
    plot_details = getattr(cls, cls.prob_type + "_plots_details")
    if plots is None:
        plots = getattr(cls, cls.prob_type + "_plots")
        logger.info(
            "Showing plot types: {}.".format(
                ", ".join(
                    ["{}".format(EvaluationPlot._pretty_titles_map[str(p)]) for p in plots]
                )
            )
        )
        logger.info(plot_details)

    # Work on a private copy: the merging below removes and inserts entries,
    # which must not alter the class-level default lists (shared by every
    # call) or a list owned by the caller.
    plots = list(plots)

    if cls.prob_type == "_bin":
        # Paired binary plots are drawn side by side in one figure.
        if "lift_chart" in plots and "gain_chart" in plots:
            plots.remove("lift_chart")
            plots.remove("gain_chart")
            plots.insert(0, "lift_and_gain_chart")

        if "roc_curve" in plots and "pr_curve" in plots:
            plots.remove("roc_curve")
            plots.remove("pr_curve")
            plots.insert(0, "pr_and_roc_curve")
    elif cls.prob_type == "_multi":
        if (
            "normalized_confusion_matrix" in plots
            and len(evaluation[evaluation.columns[0]]["classes"])
            >= MAX_PLOTTING_CLASSES
        ):
            logger.error(
                f"Evaluation plotting is not yet supported for multiclass problems with {MAX_PLOTTING_CLASSES} or more classes."
            )
            plots = []
    classes = evaluation[evaluation.columns[0]]["classes"]
    if classes is not None:
        # CAUTION: class labels are converted to strings here.
        # If users are passing in legend_labels, they have to use strings as well.
        # Otherwise get_legend_labels() will complain.
        # If users are not passing in legend_labels, get_legend_labels() generates them from cls.classes.
        # cls.legend_labels are assigned/created in cls.get_legend_labels() and contain only strings as keys.
        cls.classes = [str(c) for c in classes]
        cls.get_legend_labels(legend_labels)

    mpl.style.use("default")
    html_raw = []
    for plot_type in plots:
        fig_title, fig = None, None
        try:
            # Thin, axis-less figure that only carries the plot type title.
            fig_title, ax_title = plt.subplots(1, 1, figsize=(18, 0.5), dpi=144)
            ax_title.text(
                0.5,
                0.5,
                cls._pretty_titles_map[plot_type],
                fontsize=16,
                fontweight="semibold",
                horizontalalignment="center",
                verticalalignment="center",
                transform=ax_title.transAxes,
            )
            ax_title.axis("off")
            html_raw.append(_fig_to_html(fig_title))
            if cls.prob_type == "_bin" and plot_type in cls.double_overlay_plots:
                fig, ax = plt.subplots(1, 2, figsize=(12, 4.5), dpi=144)
            elif cls.prob_type == "_bin" and plot_type in ["roc_curve", "pr_curve"]:
                # Same layout as the paired plots; the unused right panel is hidden.
                fig, axs = plt.subplots(1, 2, figsize=(12, 4.5), dpi=144)
                axs[1].axis("off")
                ax = [axs[0]]
            elif cls.prob_type == "_bin" and plot_type in [
                "lift_chart",
                "gain_chart",
            ]:
                fig, axs = plt.subplots(1, 2, figsize=(12, 4.5), dpi=144)
                axs[1].axis("off")
                # Unlike the curve plots, these helpers receive the bare axis.
                ax = axs[0]
            else:
                # One panel per model (column), MAX_PLOTS_PER_ROW panels per row.
                nrows = math.ceil(len(evaluation.columns) / MAX_PLOTS_PER_ROW)
                fig, ax = plt.subplots(
                    nrows, MAX_PLOTS_PER_ROW, figsize=(10, 4 * nrows), dpi=144
                )  # 10, 3.5
                ax = ax.flatten()
            # Dispatch to the plotting method named after the plot type.
            getattr(cls, "_" + plot_type)(ax, evaluation)
            fig.tight_layout()
            html_raw.append(_fig_to_html(fig))
        except KeyError:
            # A missing metric (or unknown plot type) skips this plot only.
            try:
                if fig_title:
                    plt.close(fig=fig_title)
                if fig:
                    plt.close(fig=fig)
            except Exception:
                # Best-effort cleanup; a failure to close must not mask the warning.
                pass
            logger.warning(
                f"Evaluator was not able to plot "
                f"{cls._pretty_titles_map.get(plot_type, plot_type)}, because the relevant "
                f"metrics had complications. Ensure that `predict` and `predict_proba` "
                f"are valid."
            )
    return html_raw
442
+
443
@classmethod
def _lift_and_gain_chart(cls, ax, evaluation):
    """Draw the lift chart and the gain chart on two neighbouring axes.

    Parameters
    ----------
    ax
        Indexable collection of axes; position 0 receives the lift chart and
        position 1 receives the gain chart.
    evaluation
        Evaluation table forwarded unchanged to both chart helpers.
    """
    for position, painter in enumerate((cls._lift_chart, cls._gain_chart)):
        painter(ax[position], evaluation)
447
+
448
@classmethod
def _lift_chart(cls, ax, evaluation):
    """Plot one lift curve per scored model on a single axes.

    Parameters
    ----------
    ax
        Axes-like object providing ``plot``, ``legend``, ``set_xlabel``,
        ``set_ylabel``, ``set_title``, ``grid`` and ``set_xlim``.
    evaluation
        Table with one column per model (iterated with ``.items()``).  Each
        column is read at the keys ``"y_score"``, ``"percentages"``,
        ``"lift"`` and ``"perfect_lift"``.
        NOTE(review): presumably a pandas DataFrame - confirm with caller.

    Notes
    -----
    Models whose ``"y_score"`` entry is ``None`` are skipped.  The "perfect"
    reference curve is taken from the first model that has scores; if no
    model has scores the reference curve is simply omitted instead of
    raising ``StopIteration``.
    """
    for mod_name, col in evaluation.items():
        if col["y_score"] is not None:
            ax.plot(
                col["percentages"][1:],
                # A leading 1 is prepended to the lift values before plotting.
                [1] + list(col["lift"]),
                label=cls._get_formatted_title(mod_name),
            )
    if cls.baseline:
        ax.plot([-10, 110], [1, 1], **cls.baseline_kwargs)
    if cls.perfect:
        # Pick the first scored model column directly rather than computing a
        # position and indexing a label-indexed Series with an integer, which
        # pandas treats as a deprecated positional fallback.
        reference = next(
            (col for _, col in evaluation.items() if col["y_score"] is not None),
            None,
        )
        if reference is not None:
            ax.plot(
                reference["percentages"][1:],
                [1] + list(reference["perfect_lift"]),
                **cls.perfect_kwargs,
            )
    ax.legend(loc="upper right", frameon=False)
    ax.set_xlabel("Percentage of Population", fontsize=12)
    ax.set_ylabel("Lift", fontsize=12)
    ax.set_title("Lift Chart", y=1.08, fontsize=14)
    ax.grid(linewidth=0.2, which="both")
    ax.set_xlim([-10, 110])
476
+
477
@classmethod
def _gain_chart(cls, ax, evaluation):
    """Plot one cumulative-gain curve per scored model on a single axes.

    Parameters
    ----------
    ax
        Axes-like object providing ``plot``, ``legend``, ``set_xlabel``,
        ``set_ylabel``, ``set_title``, ``grid``, ``set_xlim`` and
        ``set_ylim``.
    evaluation
        Table with one column per model (iterated with ``.items()``).  Each
        column is read at the keys ``"y_score"``, ``"percentages"``,
        ``"cumulative_gain"`` and ``"perfect_gain"``.
        NOTE(review): presumably a pandas DataFrame - confirm with caller.

    Notes
    -----
    Models whose ``"y_score"`` entry is ``None`` are skipped.  The "perfect"
    reference curve is taken from the first model that has scores; if no
    model has scores the reference curve is omitted instead of raising
    ``StopIteration``.
    """
    for mod_name, col in evaluation.items():
        if col["y_score"] is not None:
            ax.plot(
                col["percentages"],
                list(col["cumulative_gain"]),
                label=cls._get_formatted_title(mod_name),
            )
    if cls.baseline:
        # Diagonal reference line spanning the full plotted range.
        ax.plot([-10, 110], [-10, 110], **cls.baseline_kwargs)
    if cls.perfect:
        # Select the first scored model column directly; this avoids indexing
        # a label-indexed Series with an integer position (deprecated
        # positional fallback in pandas).
        reference = next(
            (col for _, col in evaluation.items() if col["y_score"] is not None),
            None,
        )
        if reference is not None:
            ax.plot(
                reference["percentages"],
                reference["perfect_gain"],
                **cls.perfect_kwargs,
            )
    ax.legend(loc="lower right", frameon=False)
    ax.set_xlabel("Percentage of Population", fontsize=12)
    ax.set_ylabel("Percentage of Positive Class", fontsize=12)
    ax.set_title("Gain Chart", y=1.08, fontsize=14)
    ax.grid(linewidth=0.2, which="both")
    ax.set_xlim([-10, 110])
    ax.set_ylim([-10, 110])
506
+
507
@classmethod
def _pr_and_roc_curve(cls, ax, evaluation):
    """Draw the precision-recall curve and the ROC curve on two axes.

    Parameters
    ----------
    ax
        Indexable collection of axes; position 0 receives the
        precision-recall curve and position 1 receives the ROC curve.  Each
        axes is wrapped in a one-element list because the curve helpers
        iterate over a sequence of axes.
    evaluation
        Evaluation table forwarded unchanged to both curve helpers.
    """
    for position, painter in enumerate((cls._pr_curve, cls._roc_curve)):
        painter([ax[position]], evaluation)
511
+
512
@classmethod
def _pr_curve(cls, axs, evaluation):
    """Draw precision-recall curves on each axes in ``axs``.

    Parameters
    ----------
    axs
        Iterable of axes-like objects.
    evaluation
        Table with one column per model.  Columns are read at the keys
        ``"y_score"``, ``"recall_values"``, ``"precision_values"``,
        ``"pr_best_model_score"`` and, depending on the problem type,
        ``"precision"`` or ``"classes"`` / ``"precision_by_label"``.

    Notes
    -----
    * When ``cls.prob_type == "_bin"`` every scored model is overlaid on the
      same axes.  Otherwise the i-th axes shows the i-th model with one
      curve per class.
    * Axes beyond the number of models are switched off.  All surplus axes
      are hidden (the loop continues instead of returning after the first
      one), matching the behaviour of the ``_*_by_label`` plot helpers.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            ax.axis("off")
            continue
        if cls.prob_type == "_bin":
            for mod_name, col in evaluation.items():
                if col["y_score"] is None:
                    continue
                ax.plot(
                    col["recall_values"],
                    col["precision_values"],
                    label="%s (Precision: %s)"
                    % (
                        cls._get_formatted_title(mod_name),
                        "{:.3f}".format(col["precision"]),
                    ),
                )
                # Star marker drawn in the colour of the curve just plotted.
                ax.plot(
                    *col["pr_best_model_score"],
                    color=ax.get_lines()[-1].get_color(),
                    marker="*",
                )
        else:
            mod = evaluation[evaluation.columns[i]]
            if mod["y_score"] is not None:
                for j, lab in enumerate(mod["classes"]):
                    # cls.legend_labels contains only strings as keys.
                    lab = str(lab)
                    ax.plot(
                        mod["recall_values"][j],
                        mod["precision_values"][j],
                        label="%s (Precision: %s)"
                        % (
                            cls.legend_labels[lab],
                            "{:.3f}".format(mod["precision_by_label"][j]),
                        ),
                    )
                    ax.plot(
                        *mod["pr_best_model_score"][j],
                        color=ax.get_lines()[-1].get_color(),
                        marker="*",
                    )

        ax.set_xlabel("Recall", fontsize=12)
        ax.set_ylabel("Precision", fontsize=12)
        ax.set_title("Precision Recall Curve", y=1.08, fontsize=14)
        ax.grid(linewidth=0.2, which="both")
        ax.set_xlim([-0.1, 1.1])
        ax.set_ylim([-0.1, 1.1])

        # Append a generic black star to the legend explaining the markers.
        handles, labels = ax.get_legend_handles_labels()
        star = mlines.Line2D(
            [],
            [],
            color="black",
            marker="*",
            linestyle="None",
            markersize=5,
            label="Minimum Error Rate",
        )
        handles.append(star)
        labels.append("Minimum Error Rate")
        ax.legend(
            loc="upper right",
            labels=labels,
            handles=handles,
            frameon=False,
            fontsize="x-small",
        )
583
+
584
@classmethod
def _roc_curve(cls, axs, evaluation):
    """Draw ROC curves on each axes in ``axs``.

    Parameters
    ----------
    axs
        Iterable of axes-like objects.
    evaluation
        Table with one column per model.  Columns are read at the keys
        ``"y_score"``, ``"auc"``, ``"roc_best_model_score"`` and, depending
        on the problem type, ``"false_positive_rate"`` /
        ``"true_positive_rate"`` or ``"classes"`` / ``"fpr_by_label"`` /
        ``"tpr_by_label"``.

    Notes
    -----
    * When ``cls.prob_type == "_bin"`` every scored model is overlaid on the
      same axes.  Otherwise the i-th axes shows the i-th model with one
      curve per class.
    * Axes beyond the number of models are switched off.  All surplus axes
      are hidden (the loop continues instead of returning after the first
      one), matching the behaviour of the ``_*_by_label`` plot helpers.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            ax.axis("off")
            continue
        if cls.prob_type == "_bin":
            for mod_name, col in evaluation.items():
                if col["y_score"] is None:
                    continue
                ax.plot(
                    col["false_positive_rate"],
                    col["true_positive_rate"],
                    label="%s (AUC: %s)"
                    % (
                        cls._get_formatted_title(mod_name),
                        "{:.3f}".format(col["auc"]),
                    ),
                )
                # Star marker drawn in the colour of the curve just plotted.
                ax.plot(
                    *col["roc_best_model_score"],
                    color=ax.get_lines()[-1].get_color(),
                    marker="*",
                )
        else:
            mod = evaluation[evaluation.columns[i]]
            if mod["y_score"] is not None:
                for j, lab in enumerate(mod["classes"]):
                    # cls.legend_labels contains only strings as keys.
                    lab = str(lab)
                    ax.plot(
                        mod["fpr_by_label"][j],
                        mod["tpr_by_label"][j],
                        label="%s (AUC: %s)"
                        % (cls.legend_labels[lab], "{:.3f}".format(mod["auc"][j])),
                    )
                    ax.plot(
                        *mod["roc_best_model_score"][j],
                        color=ax.get_lines()[-1].get_color(),
                        marker="*",
                    )
        if cls.baseline:
            # Diagonal reference line across the whole plotted range.
            ax.plot([-0.1, 1.1], [-0.1, 1.1], **cls.baseline_kwargs)
        ax.set_xlabel("False Positive Rate", fontsize=12)
        ax.set_ylabel("True Positive Rate", fontsize=12)
        ax.set_title("ROC Curve", y=1.08, fontsize=14)
        ax.grid(linewidth=0.2, which="both")
        ax.set_xlim([-0.1, 1.1])
        ax.set_ylim([-0.1, 1.1])

        # Append a generic black star to the legend explaining the markers.
        handles, labels = ax.get_legend_handles_labels()
        star = mlines.Line2D(
            [],
            [],
            color="black",
            marker="*",
            linestyle="None",
            markersize=5,
            label="Youden's J Statistic",
        )
        handles.append(star)
        labels.append("Youden's J Statistic")
        ax.legend(
            loc="lower right",
            labels=labels,
            handles=handles,
            frameon=False,
            fontsize="x-small",
        )
653
+
654
@classmethod
def _ks_statistics(cls, axs, evaluation):
    """Draw one KS-statistic plot per model, one model per axes.

    Parameters
    ----------
    axs
        Iterable of axes-like objects; the i-th axes shows the i-th model.
    evaluation
        Table with one column per model.  Columns are read at the keys
        ``"y_score"``, ``"ks_thresholds"``, ``"ks_pct1"``, ``"ks_pct2"``,
        ``"ks_labels"``, ``"max_distance_at"`` and ``"ks_statistic"``.

    Notes
    -----
    Axes beyond the number of models are switched off.  All surplus axes are
    hidden (the loop continues instead of returning after the first one),
    matching the behaviour of the ``_*_by_label`` plot helpers.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            ax.axis("off")
            continue
        model_name = evaluation.columns[i]
        mod = evaluation[model_name]

        ax.set_title(model_name, fontsize=14)
        if mod["y_score"] is not None:
            ax.plot(
                mod["ks_thresholds"],
                mod["ks_pct1"],
                lw=3,
                label=mod["ks_labels"][0],
            )
            ax.plot(
                mod["ks_thresholds"],
                mod["ks_pct2"],
                lw=3,
                label=mod["ks_labels"][1],
            )
            if cls.baseline:
                # Position of the threshold at which the curves are farthest
                # apart; the vertical marker spans the gap between them.
                idx = np.where(mod["ks_thresholds"] == mod["max_distance_at"])[0][0]
                ax.axvline(
                    mod["max_distance_at"],
                    *sorted([mod["ks_pct1"][idx], mod["ks_pct2"][idx]]),
                    label="KS Statistic: {:.3f} at {:.3f}".format(
                        mod["ks_statistic"], mod["max_distance_at"]
                    ),
                    linestyle="--",
                    color=".2",
                )

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.0])

        ax.set_xlabel("Threshold", fontsize=12)
        ax.set_ylabel("Percentage below threshold", fontsize=12)
        ax.tick_params(labelsize=10)
        ax.legend(loc="lower right", fontsize=8)
697
+
698
@classmethod
def _pretty_barh(
    cls, ax, x, y, axis_labels=None, title=None, axis_lim=None, plot_kwargs=None
):
    """Render an annotated horizontal bar chart on ``ax``.

    Parameters
    ----------
    ax
        Axes-like object to draw on.
    x
        Bar categories; each item is stringified and translated through
        ``cls.legend_labels`` (which contains only strings as keys).
    y
        Bar lengths; each value is also written on its bar with three
        decimals.
    axis_labels
        Optional pair ``[x_label, y_label]``; falsy entries are skipped.
    title
        Optional title, passed through ``cls._get_formatted_title``.
    axis_lim
        Optional x-axis limits.
    plot_kwargs
        Accepted but not used by this method.
    """
    bar_palette = ["teal", "blueviolet", "forestgreen", "peru", "y", "dodgerblue", "r"]
    display_names = []
    for item in x:
        display_names.append(cls.legend_labels[str(item)])
    ax.barh(display_names, y, color=bar_palette)

    # Value annotation anchored at the midpoint of each bar.
    for row, value in enumerate(y):
        ax.annotate("{:.3f}".format(value), xy=(value / 2, row), va="center", ha="left")

    if axis_labels:
        for position, apply_label in enumerate((ax.set_xlabel, ax.set_ylabel)):
            caption = axis_labels[position]
            if caption:
                apply_label(caption, fontsize=12)

    if title:
        ax.set_title(cls._get_formatted_title(title), y=1.08, fontsize=14)

    if axis_lim:
        ax.set_xlim(axis_lim)
721
+
722
@classmethod
def _precision_by_label(cls, axs, evaluation):
    """Draw per-class precision bars, one model per axes.

    The i-th axes shows the ``"precision_by_label"`` values of the i-th
    model column against its ``"classes"``; axes without a matching model
    are switched off.
    """
    model_names = evaluation.columns
    model_count = len(model_names)
    for position, axes in enumerate(axs):
        if position >= model_count:
            axes.axis("off")
            continue
        name = model_names[position]
        cls._pretty_barh(
            axes,
            evaluation[name]["classes"],
            evaluation[name]["precision_by_label"],
            axis_lim=[0, 1],
            axis_labels=["Precision", None],
            title=name,
        )
738
+
739
@classmethod
def _recall_by_label(cls, axs, evaluation):
    """Draw per-class recall bars, one model per axes.

    The i-th axes shows the ``"recall_by_label"`` values of the i-th model
    column against its ``"classes"``; axes without a matching model are
    switched off.
    """
    model_names = evaluation.columns
    model_count = len(model_names)
    for position, axes in enumerate(axs):
        if position >= model_count:
            axes.axis("off")
            continue
        name = model_names[position]
        cls._pretty_barh(
            axes,
            evaluation[name]["classes"],
            evaluation[name]["recall_by_label"],
            axis_lim=[0, 1],
            axis_labels=["Recall", None],
            title=name,
        )
755
+
756
@classmethod
def _f1_by_label(cls, axs, evaluation):
    """Draw per-class F1 bars, one model per axes.

    The i-th axes shows the ``"f1_by_label"`` values of the i-th model
    column against its ``"classes"``; axes without a matching model are
    switched off.
    """
    model_names = evaluation.columns
    model_count = len(model_names)
    for position, axes in enumerate(axs):
        if position >= model_count:
            axes.axis("off")
            continue
        name = model_names[position]
        cls._pretty_barh(
            axes,
            evaluation[name]["classes"],
            evaluation[name]["f1_by_label"],
            axis_lim=[0, 1],
            axis_labels=["F1 Score", None],
            title=name,
        )
772
+
773
@classmethod
def _jaccard_by_label(cls, axs, evaluation):
    """Draw per-class Jaccard-score bars, one model per axes.

    The i-th axes shows the ``"jaccard_by_label"`` values of the i-th model
    column against its ``"classes"``; axes without a matching model are
    switched off.
    """
    model_names = evaluation.columns
    model_count = len(model_names)
    for position, axes in enumerate(axs):
        if position >= model_count:
            axes.axis("off")
            continue
        name = model_names[position]
        cls._pretty_barh(
            axes,
            evaluation[name]["classes"],
            evaluation[name]["jaccard_by_label"],
            axis_lim=[0, 1],
            axis_labels=["Jaccard Score", None],
            title=name,
        )
789
+
790
@classmethod
def _pretty_scatter(
    cls,
    ax,
    x,
    y,
    s=5,
    alpha=1.0,
    title=None,
    legend=False,
    axis_labels=None,
    axis_lim=None,
    grid=True,
    label=None,
    plot_kwargs=None,
):
    """Render a scatter plot on ``ax`` with optional decorations.

    Parameters
    ----------
    ax
        Axes-like object to draw on.
    x, y
        Point coordinates.
    s, alpha, label
        Forwarded to ``ax.scatter`` as marker size, opacity and legend label.
    title
        Optional axes title.
    legend
        When truthy, a frameless legend is added.
    axis_labels
        Optional pair ``(x_label, y_label)``.
    axis_lim
        Optional pair ``(x_limits, y_limits)``; a falsy entry leaves that
        axis untouched.
    grid
        When truthy, a thin grid is drawn.
    plot_kwargs
        Optional extra keyword arguments for ``ax.scatter``.
    """
    extra_options = {} if plot_kwargs is None else plot_kwargs
    ax.scatter(x, y, s=s, label=label, marker="o", alpha=alpha, **extra_options)

    # Both axes show tick values with two decimals.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(FormatStrFormatter("%.2f"))

    if legend:
        ax.legend(frameon=False)

    if axis_labels:
        ax.set_xlabel(axis_labels[0])
        ax.set_ylabel(axis_labels[1])

    if title:
        ax.set_title(title, y=1.08, fontsize=14)

    if grid:
        ax.grid(linewidth=0.2)

    if axis_lim:
        for position, apply_limits in enumerate((ax.set_xlim, ax.set_ylim)):
            limits = axis_lim[position]
            if limits:
                apply_limits(limits)
825
+
826
@classmethod
def _top_2_features(cls, axs, evaluation):
    """Placeholder plot hook: accepts the standard arguments and draws nothing."""
    return None
829
+
830
@classmethod
def _residuals_qq(cls, axs, evaluation):
    """Draw a residual Q-Q scatter plot per model, one model per axes.

    Parameters
    ----------
    axs
        Iterable of axes-like objects; the i-th axes shows the i-th model.
    evaluation
        Table with one column per model.  Columns are read at the keys
        ``"norm_quantiles"`` (x values) and ``"residual_quantiles"``
        (y values).

    Notes
    -----
    Axes beyond the number of models are switched off.  All surplus axes are
    hidden (the loop continues instead of returning after the first one),
    matching the behaviour of the ``_*_by_label`` plot helpers.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            ax.axis("off")
            continue
        model_name = evaluation.columns[i]
        mod = evaluation[model_name]
        cls._pretty_scatter(
            ax,
            mod["norm_quantiles"],
            mod["residual_quantiles"],
            title=model_name,
            axis_lim=[(-2.7, 2.7), (-3.1, 3.1)],
            axis_labels=["Theoretical Quantiles", "Sample Quantiles"],
        )
        if cls.baseline:
            # Identity line, drawn far beyond the fixed axis limits.
            ax.plot((-100, 100), (-100, 100), **cls.baseline_kwargs)
850
+
851
@classmethod
def _residuals_vs_predicted(cls, axs, evaluation):
    """Scatter residuals against predicted values, one model per axes.

    Parameters
    ----------
    axs
        Iterable of axes-like objects; the i-th axes shows the i-th model.
    evaluation
        Table with one column per model.  Columns are read at the keys
        ``"y_pred"`` and ``"residuals"``.

    Notes
    -----
    * The x-limits extend the predicted range by 5% of the magnitude of each
      bound.  The padding uses the absolute value so that negative bounds
      are widened as well (subtracting ``min * 0.05`` from a negative
      minimum would move the limit inside the data and clip points).
    * Axes beyond the number of models are switched off.  All surplus axes
      are hidden (the loop continues instead of returning after the first
      one), matching the behaviour of the ``_*_by_label`` plot helpers.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            ax.axis("off")
            continue
        model_name = evaluation.columns[i]
        mod = evaluation[model_name]
        y_pred = np.asarray(mod["y_pred"])
        resid = np.asarray(mod["residuals"])
        cls._pretty_scatter(
            ax,
            y_pred,
            resid,
            s=4,
            alpha=0.5,
            title=model_name,
            axis_labels=["Predicted Values", "Residuals"],
        )
        lower, upper = y_pred.min(), y_pred.max()
        x_lim = (lower - abs(lower) * 0.05, upper + abs(upper) * 0.05)
        if cls.baseline:
            # Horizontal zero-residual reference line.
            ax.plot(x_lim, (0, 0), **cls.baseline_kwargs)
        ax.set_xlim(x_lim)
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
880
+
881
@classmethod
def _residuals_vs_observed(cls, axs, evaluation):
    """Scatter residuals against observed values, one model per axes.

    Parameters
    ----------
    axs
        Iterable of axes-like objects; the i-th axes shows the i-th model.
    evaluation
        Table with one column per model.  Columns are read at the keys
        ``"y_true"`` and ``"residuals"``.

    Notes
    -----
    * The x-limits extend the observed range by 5% of the magnitude of each
      bound.  The padding uses the absolute value so that negative bounds
      are widened as well (subtracting ``min * 0.05`` from a negative
      minimum would move the limit inside the data and clip points).
    * Axes beyond the number of models are switched off.  All surplus axes
      are hidden (the loop continues instead of returning after the first
      one), matching the behaviour of the ``_*_by_label`` plot helpers.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            ax.axis("off")
            continue
        model_name = evaluation.columns[i]
        mod = evaluation[model_name]
        y_true = np.asarray(mod["y_true"])
        cls._pretty_scatter(
            ax,
            y_true,
            mod["residuals"],
            s=4,
            alpha=0.5,
            title=model_name,
            axis_labels=["Observed Values", "Residuals"],
        )
        lower, upper = y_true.min(), y_true.max()
        x_lim = (lower - abs(lower) * 0.05, upper + abs(upper) * 0.05)
        if cls.baseline:
            # Horizontal zero-residual reference line.
            ax.plot(x_lim, (0, 0), **cls.baseline_kwargs)
        ax.set_xlim(x_lim)
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
909
+
910
@classmethod
def _observed_vs_predicted(cls, axs, evaluation):
    """Scatter predicted against observed values, one model per axes.

    Parameters
    ----------
    axs
        Iterable of axes-like objects; the i-th axes shows the i-th model.
    evaluation
        Table with one column per model.  Columns are read at the keys
        ``"y_true"`` (x values) and ``"y_pred"`` (y values).

    Notes
    -----
    * The x-limits extend the observed range by 5% of the magnitude of each
      bound.  The padding uses the absolute value so that negative bounds
      are widened as well (subtracting ``min * 0.05`` from a negative
      minimum would move the limit inside the data and clip points).
    * Axes beyond the number of models are switched off.  All surplus axes
      are hidden (the loop continues instead of returning after the first
      one), matching the behaviour of the ``_*_by_label`` plot helpers.
    """
    n_models = len(evaluation.columns)
    for i, ax in enumerate(axs):
        if i >= n_models:
            ax.axis("off")
            continue

        model_name = evaluation.columns[i]
        mod = evaluation[model_name]
        ax.scatter(mod["y_true"], mod["y_pred"], s=4, marker="o", alpha=0.5)

        y_true = np.asarray(mod["y_true"])

        yt_min = y_true.min()
        yt_max = y_true.max()

        x_lim = (yt_min - abs(yt_min) * 0.05, yt_max + abs(yt_max) * 0.05)
        if cls.baseline:
            # Identity line: predicted equals observed.
            ax.plot(x_lim, x_lim, **cls.baseline_kwargs)
        ax.set_xlabel("Observed Values")
        ax.set_ylabel("Predicted Values")
        ax.set_title(model_name, y=1.08, fontsize=14)
        ax.grid(linewidth=0.2)
        ax.set_xlim(x_lim)

        ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
938
+
939
+ @classmethod
940
+ def _normalized_confusion_matrix(cls, axs, evaluation):
941
+ for model_num, ax in enumerate(axs):
942
+ if model_num >= len(evaluation.columns):
943
+ ax.axis("off")
944
+ return
945
+ model_name = evaluation.columns[model_num]
946
+ mod = evaluation[model_name]
947
+ if cls.prob_type == "_bin":
948
+ labels = [str(lab == mod["positive_class"]) for lab in mod["classes"]]
949
+ else:
950
+ labels = cls.legend_labels.values()
951
+
952
+ raw_cm = mod["raw_confusion_matrix"]
953
+ cm = np.asarray(mod["confusion_matrix"])
954
+
955
+ ax.set_title(
956
+ "%s\n" % cls._get_formatted_title(model_name), y=1.08, fontsize=14
957
+ )
958
+ ax.imshow(cm, interpolation="nearest", cmap="BuGn")
959
+ x_tick_marks = np.arange(len(labels))
960
+ y_tick_marks = np.arange(len(labels))
961
+ ax.set_xticks(x_tick_marks)
962
+ ax.set_yticks(y_tick_marks)
963
+
964
+ ax.set_xticklabels(labels, rotation=90, fontsize=10)
965
+ ax.set_yticklabels(labels, fontsize=10)
966
+
967
+ for i, j in itertools.product(
968
+ range(raw_cm.shape[0]), range(raw_cm.shape[1])
969
+ ):
970
+ ax.text(
971
+ j,
972
+ i,
973
+ "%s [%s]" % (round(cm[i][j], 3), raw_cm[i, j]),
974
+ horizontalalignment="center",
975
+ verticalalignment="center",
976
+ rotation=45,
977
+ fontsize=max(3, 10 - max(raw_cm.shape[0], raw_cm.shape[1])),
978
+ color="white" if cm[i, j] > 0.5 else "black",
979
+ )
980
+
981
+ ax.set_ylabel("True label", fontsize=10)
982
+ ax.set_xlabel("Predicted label", fontsize=10)
983
+ ax.grid(False)