oracle-ads 2.13.9rc0__py3-none-any.whl → 2.13.10rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (858):
  1. ads/aqua/__init__.py +40 -0
  2. ads/aqua/app.py +507 -0
  3. ads/aqua/cli.py +96 -0
  4. ads/aqua/client/__init__.py +3 -0
  5. ads/aqua/client/client.py +836 -0
  6. ads/aqua/client/openai_client.py +305 -0
  7. ads/aqua/common/__init__.py +5 -0
  8. ads/aqua/common/decorator.py +125 -0
  9. ads/aqua/common/entities.py +274 -0
  10. ads/aqua/common/enums.py +134 -0
  11. ads/aqua/common/errors.py +109 -0
  12. ads/aqua/common/utils.py +1295 -0
  13. ads/aqua/config/__init__.py +4 -0
  14. ads/aqua/config/container_config.py +247 -0
  15. ads/aqua/config/evaluation/__init__.py +4 -0
  16. ads/aqua/config/evaluation/evaluation_service_config.py +147 -0
  17. ads/aqua/config/utils/__init__.py +4 -0
  18. ads/aqua/config/utils/serializer.py +339 -0
  19. ads/aqua/constants.py +116 -0
  20. ads/aqua/data.py +14 -0
  21. ads/aqua/dummy_data/icon.txt +1 -0
  22. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  23. ads/aqua/dummy_data/oci_models.json +1 -0
  24. ads/aqua/dummy_data/readme.md +26 -0
  25. ads/aqua/evaluation/__init__.py +8 -0
  26. ads/aqua/evaluation/constants.py +53 -0
  27. ads/aqua/evaluation/entities.py +186 -0
  28. ads/aqua/evaluation/errors.py +70 -0
  29. ads/aqua/evaluation/evaluation.py +1814 -0
  30. ads/aqua/extension/__init__.py +42 -0
  31. ads/aqua/extension/aqua_ws_msg_handler.py +76 -0
  32. ads/aqua/extension/base_handler.py +90 -0
  33. ads/aqua/extension/common_handler.py +121 -0
  34. ads/aqua/extension/common_ws_msg_handler.py +36 -0
  35. ads/aqua/extension/deployment_handler.py +381 -0
  36. ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
  37. ads/aqua/extension/errors.py +30 -0
  38. ads/aqua/extension/evaluation_handler.py +129 -0
  39. ads/aqua/extension/evaluation_ws_msg_handler.py +61 -0
  40. ads/aqua/extension/finetune_handler.py +96 -0
  41. ads/aqua/extension/model_handler.py +390 -0
  42. ads/aqua/extension/models/__init__.py +0 -0
  43. ads/aqua/extension/models/ws_models.py +145 -0
  44. ads/aqua/extension/models_ws_msg_handler.py +50 -0
  45. ads/aqua/extension/ui_handler.py +300 -0
  46. ads/aqua/extension/ui_websocket_handler.py +130 -0
  47. ads/aqua/extension/utils.py +133 -0
  48. ads/aqua/finetuning/__init__.py +7 -0
  49. ads/aqua/finetuning/constants.py +23 -0
  50. ads/aqua/finetuning/entities.py +181 -0
  51. ads/aqua/finetuning/finetuning.py +749 -0
  52. ads/aqua/model/__init__.py +8 -0
  53. ads/aqua/model/constants.py +60 -0
  54. ads/aqua/model/entities.py +385 -0
  55. ads/aqua/model/enums.py +32 -0
  56. ads/aqua/model/model.py +2134 -0
  57. ads/aqua/model/utils.py +52 -0
  58. ads/aqua/modeldeployment/__init__.py +6 -0
  59. ads/aqua/modeldeployment/constants.py +10 -0
  60. ads/aqua/modeldeployment/deployment.py +1315 -0
  61. ads/aqua/modeldeployment/entities.py +653 -0
  62. ads/aqua/modeldeployment/utils.py +543 -0
  63. ads/aqua/resources/gpu_shapes_index.json +94 -0
  64. ads/aqua/server/__init__.py +4 -0
  65. ads/aqua/server/__main__.py +24 -0
  66. ads/aqua/server/app.py +47 -0
  67. ads/aqua/server/aqua_spec.yml +1291 -0
  68. ads/aqua/training/__init__.py +4 -0
  69. ads/aqua/training/exceptions.py +476 -0
  70. ads/aqua/ui.py +519 -0
  71. ads/automl/__init__.py +9 -0
  72. ads/automl/driver.py +330 -0
  73. ads/automl/provider.py +975 -0
  74. ads/bds/__init__.py +5 -0
  75. ads/bds/auth.py +127 -0
  76. ads/bds/big_data_service.py +255 -0
  77. ads/catalog/__init__.py +19 -0
  78. ads/catalog/model.py +1576 -0
  79. ads/catalog/notebook.py +461 -0
  80. ads/catalog/project.py +468 -0
  81. ads/catalog/summary.py +178 -0
  82. ads/common/__init__.py +11 -0
  83. ads/common/analyzer.py +65 -0
  84. ads/common/artifact/.model-ignore +63 -0
  85. ads/common/artifact/__init__.py +10 -0
  86. ads/common/auth.py +1122 -0
  87. ads/common/card_identifier.py +83 -0
  88. ads/common/config.py +647 -0
  89. ads/common/data.py +165 -0
  90. ads/common/decorator/__init__.py +9 -0
  91. ads/common/decorator/argument_to_case.py +88 -0
  92. ads/common/decorator/deprecate.py +69 -0
  93. ads/common/decorator/require_nonempty_arg.py +65 -0
  94. ads/common/decorator/runtime_dependency.py +178 -0
  95. ads/common/decorator/threaded.py +97 -0
  96. ads/common/decorator/utils.py +35 -0
  97. ads/common/dsc_file_system.py +303 -0
  98. ads/common/error.py +14 -0
  99. ads/common/extended_enum.py +81 -0
  100. ads/common/function/__init__.py +5 -0
  101. ads/common/function/fn_util.py +142 -0
  102. ads/common/function/func_conf.yaml +25 -0
  103. ads/common/ipython.py +76 -0
  104. ads/common/model.py +679 -0
  105. ads/common/model_artifact.py +1759 -0
  106. ads/common/model_artifact_schema.json +107 -0
  107. ads/common/model_export_util.py +664 -0
  108. ads/common/model_metadata.py +24 -0
  109. ads/common/object_storage_details.py +296 -0
  110. ads/common/oci_client.py +179 -0
  111. ads/common/oci_datascience.py +46 -0
  112. ads/common/oci_logging.py +1144 -0
  113. ads/common/oci_mixin.py +957 -0
  114. ads/common/oci_resource.py +136 -0
  115. ads/common/serializer.py +559 -0
  116. ads/common/utils.py +1852 -0
  117. ads/common/word_lists.py +1491 -0
  118. ads/common/work_request.py +189 -0
  119. ads/config.py +1 -0
  120. ads/data_labeling/__init__.py +13 -0
  121. ads/data_labeling/boundingbox.py +253 -0
  122. ads/data_labeling/constants.py +47 -0
  123. ads/data_labeling/data_labeling_service.py +244 -0
  124. ads/data_labeling/interface/__init__.py +5 -0
  125. ads/data_labeling/interface/loader.py +16 -0
  126. ads/data_labeling/interface/parser.py +16 -0
  127. ads/data_labeling/interface/reader.py +23 -0
  128. ads/data_labeling/loader/__init__.py +5 -0
  129. ads/data_labeling/loader/file_loader.py +241 -0
  130. ads/data_labeling/metadata.py +110 -0
  131. ads/data_labeling/mixin/__init__.py +5 -0
  132. ads/data_labeling/mixin/data_labeling.py +232 -0
  133. ads/data_labeling/ner.py +129 -0
  134. ads/data_labeling/parser/__init__.py +5 -0
  135. ads/data_labeling/parser/dls_record_parser.py +388 -0
  136. ads/data_labeling/parser/export_metadata_parser.py +94 -0
  137. ads/data_labeling/parser/export_record_parser.py +473 -0
  138. ads/data_labeling/reader/__init__.py +5 -0
  139. ads/data_labeling/reader/dataset_reader.py +574 -0
  140. ads/data_labeling/reader/dls_record_reader.py +121 -0
  141. ads/data_labeling/reader/export_record_reader.py +62 -0
  142. ads/data_labeling/reader/jsonl_reader.py +75 -0
  143. ads/data_labeling/reader/metadata_reader.py +203 -0
  144. ads/data_labeling/reader/record_reader.py +263 -0
  145. ads/data_labeling/record.py +52 -0
  146. ads/data_labeling/visualizer/__init__.py +5 -0
  147. ads/data_labeling/visualizer/image_visualizer.py +525 -0
  148. ads/data_labeling/visualizer/text_visualizer.py +357 -0
  149. ads/database/__init__.py +5 -0
  150. ads/database/connection.py +338 -0
  151. ads/dataset/__init__.py +10 -0
  152. ads/dataset/capabilities.md +51 -0
  153. ads/dataset/classification_dataset.py +339 -0
  154. ads/dataset/correlation.py +226 -0
  155. ads/dataset/correlation_plot.py +563 -0
  156. ads/dataset/dask_series.py +173 -0
  157. ads/dataset/dataframe_transformer.py +110 -0
  158. ads/dataset/dataset.py +1979 -0
  159. ads/dataset/dataset_browser.py +360 -0
  160. ads/dataset/dataset_with_target.py +995 -0
  161. ads/dataset/exception.py +25 -0
  162. ads/dataset/factory.py +987 -0
  163. ads/dataset/feature_engineering_transformer.py +35 -0
  164. ads/dataset/feature_selection.py +107 -0
  165. ads/dataset/forecasting_dataset.py +26 -0
  166. ads/dataset/helper.py +1450 -0
  167. ads/dataset/label_encoder.py +99 -0
  168. ads/dataset/mixin/__init__.py +5 -0
  169. ads/dataset/mixin/dataset_accessor.py +134 -0
  170. ads/dataset/pipeline.py +58 -0
  171. ads/dataset/plot.py +710 -0
  172. ads/dataset/progress.py +86 -0
  173. ads/dataset/recommendation.py +297 -0
  174. ads/dataset/recommendation_transformer.py +502 -0
  175. ads/dataset/regression_dataset.py +14 -0
  176. ads/dataset/sampled_dataset.py +1050 -0
  177. ads/dataset/target.py +98 -0
  178. ads/dataset/timeseries.py +18 -0
  179. ads/dbmixin/__init__.py +5 -0
  180. ads/dbmixin/db_pandas_accessor.py +153 -0
  181. ads/environment/__init__.py +9 -0
  182. ads/environment/ml_runtime.py +66 -0
  183. ads/evaluations/README.md +14 -0
  184. ads/evaluations/__init__.py +109 -0
  185. ads/evaluations/evaluation_plot.py +983 -0
  186. ads/evaluations/evaluator.py +1334 -0
  187. ads/evaluations/statistical_metrics.py +543 -0
  188. ads/experiments/__init__.py +9 -0
  189. ads/experiments/capabilities.md +0 -0
  190. ads/explanations/__init__.py +21 -0
  191. ads/explanations/base_explainer.py +142 -0
  192. ads/explanations/capabilities.md +83 -0
  193. ads/explanations/explainer.py +190 -0
  194. ads/explanations/mlx_global_explainer.py +1050 -0
  195. ads/explanations/mlx_interface.py +386 -0
  196. ads/explanations/mlx_local_explainer.py +287 -0
  197. ads/explanations/mlx_whatif_explainer.py +201 -0
  198. ads/feature_engineering/__init__.py +20 -0
  199. ads/feature_engineering/accessor/__init__.py +5 -0
  200. ads/feature_engineering/accessor/dataframe_accessor.py +535 -0
  201. ads/feature_engineering/accessor/mixin/__init__.py +5 -0
  202. ads/feature_engineering/accessor/mixin/correlation.py +166 -0
  203. ads/feature_engineering/accessor/mixin/eda_mixin.py +266 -0
  204. ads/feature_engineering/accessor/mixin/eda_mixin_series.py +85 -0
  205. ads/feature_engineering/accessor/mixin/feature_types_mixin.py +211 -0
  206. ads/feature_engineering/accessor/mixin/utils.py +65 -0
  207. ads/feature_engineering/accessor/series_accessor.py +431 -0
  208. ads/feature_engineering/adsimage/__init__.py +5 -0
  209. ads/feature_engineering/adsimage/image.py +192 -0
  210. ads/feature_engineering/adsimage/image_reader.py +170 -0
  211. ads/feature_engineering/adsimage/interface/__init__.py +5 -0
  212. ads/feature_engineering/adsimage/interface/reader.py +19 -0
  213. ads/feature_engineering/adsstring/__init__.py +7 -0
  214. ads/feature_engineering/adsstring/oci_language/__init__.py +8 -0
  215. ads/feature_engineering/adsstring/string/__init__.py +8 -0
  216. ads/feature_engineering/data_schema.json +57 -0
  217. ads/feature_engineering/dataset/__init__.py +5 -0
  218. ads/feature_engineering/dataset/zip_code_data.py +42062 -0
  219. ads/feature_engineering/exceptions.py +40 -0
  220. ads/feature_engineering/feature_type/__init__.py +133 -0
  221. ads/feature_engineering/feature_type/address.py +184 -0
  222. ads/feature_engineering/feature_type/adsstring/__init__.py +5 -0
  223. ads/feature_engineering/feature_type/adsstring/common_regex_mixin.py +164 -0
  224. ads/feature_engineering/feature_type/adsstring/oci_language.py +93 -0
  225. ads/feature_engineering/feature_type/adsstring/parsers/__init__.py +5 -0
  226. ads/feature_engineering/feature_type/adsstring/parsers/base.py +47 -0
  227. ads/feature_engineering/feature_type/adsstring/parsers/nltk_parser.py +96 -0
  228. ads/feature_engineering/feature_type/adsstring/parsers/spacy_parser.py +221 -0
  229. ads/feature_engineering/feature_type/adsstring/string.py +258 -0
  230. ads/feature_engineering/feature_type/base.py +58 -0
  231. ads/feature_engineering/feature_type/boolean.py +183 -0
  232. ads/feature_engineering/feature_type/category.py +146 -0
  233. ads/feature_engineering/feature_type/constant.py +137 -0
  234. ads/feature_engineering/feature_type/continuous.py +151 -0
  235. ads/feature_engineering/feature_type/creditcard.py +314 -0
  236. ads/feature_engineering/feature_type/datetime.py +190 -0
  237. ads/feature_engineering/feature_type/discrete.py +134 -0
  238. ads/feature_engineering/feature_type/document.py +43 -0
  239. ads/feature_engineering/feature_type/gis.py +251 -0
  240. ads/feature_engineering/feature_type/handler/__init__.py +5 -0
  241. ads/feature_engineering/feature_type/handler/feature_validator.py +524 -0
  242. ads/feature_engineering/feature_type/handler/feature_warning.py +319 -0
  243. ads/feature_engineering/feature_type/handler/warnings.py +128 -0
  244. ads/feature_engineering/feature_type/integer.py +142 -0
  245. ads/feature_engineering/feature_type/ip_address.py +144 -0
  246. ads/feature_engineering/feature_type/ip_address_v4.py +138 -0
  247. ads/feature_engineering/feature_type/ip_address_v6.py +138 -0
  248. ads/feature_engineering/feature_type/lat_long.py +256 -0
  249. ads/feature_engineering/feature_type/object.py +43 -0
  250. ads/feature_engineering/feature_type/ordinal.py +132 -0
  251. ads/feature_engineering/feature_type/phone_number.py +135 -0
  252. ads/feature_engineering/feature_type/string.py +171 -0
  253. ads/feature_engineering/feature_type/text.py +93 -0
  254. ads/feature_engineering/feature_type/unknown.py +43 -0
  255. ads/feature_engineering/feature_type/zip_code.py +164 -0
  256. ads/feature_engineering/feature_type_manager.py +406 -0
  257. ads/feature_engineering/schema.py +795 -0
  258. ads/feature_engineering/utils.py +245 -0
  259. ads/feature_store/.readthedocs.yaml +19 -0
  260. ads/feature_store/README.md +65 -0
  261. ads/feature_store/__init__.py +9 -0
  262. ads/feature_store/common/__init__.py +0 -0
  263. ads/feature_store/common/enums.py +339 -0
  264. ads/feature_store/common/exceptions.py +18 -0
  265. ads/feature_store/common/spark_session_singleton.py +125 -0
  266. ads/feature_store/common/utils/__init__.py +0 -0
  267. ads/feature_store/common/utils/base64_encoder_decoder.py +72 -0
  268. ads/feature_store/common/utils/feature_schema_mapper.py +283 -0
  269. ads/feature_store/common/utils/transformation_utils.py +82 -0
  270. ads/feature_store/common/utils/utility.py +403 -0
  271. ads/feature_store/data_validation/__init__.py +0 -0
  272. ads/feature_store/data_validation/great_expectation.py +129 -0
  273. ads/feature_store/dataset.py +1230 -0
  274. ads/feature_store/dataset_job.py +530 -0
  275. ads/feature_store/docs/Dockerfile +7 -0
  276. ads/feature_store/docs/Makefile +44 -0
  277. ads/feature_store/docs/conf.py +28 -0
  278. ads/feature_store/docs/requirements.txt +14 -0
  279. ads/feature_store/docs/source/ads.feature_store.query.rst +20 -0
  280. ads/feature_store/docs/source/cicd.rst +137 -0
  281. ads/feature_store/docs/source/conf.py +86 -0
  282. ads/feature_store/docs/source/data_versioning.rst +33 -0
  283. ads/feature_store/docs/source/dataset.rst +388 -0
  284. ads/feature_store/docs/source/dataset_job.rst +27 -0
  285. ads/feature_store/docs/source/demo.rst +70 -0
  286. ads/feature_store/docs/source/entity.rst +78 -0
  287. ads/feature_store/docs/source/feature_group.rst +624 -0
  288. ads/feature_store/docs/source/feature_group_job.rst +29 -0
  289. ads/feature_store/docs/source/feature_store.rst +122 -0
  290. ads/feature_store/docs/source/feature_store_class.rst +123 -0
  291. ads/feature_store/docs/source/feature_validation.rst +66 -0
  292. ads/feature_store/docs/source/figures/cicd.png +0 -0
  293. ads/feature_store/docs/source/figures/data_validation.png +0 -0
  294. ads/feature_store/docs/source/figures/data_versioning.png +0 -0
  295. ads/feature_store/docs/source/figures/dataset.gif +0 -0
  296. ads/feature_store/docs/source/figures/dataset.png +0 -0
  297. ads/feature_store/docs/source/figures/dataset_lineage.png +0 -0
  298. ads/feature_store/docs/source/figures/dataset_statistics.png +0 -0
  299. ads/feature_store/docs/source/figures/dataset_statistics_viz.png +0 -0
  300. ads/feature_store/docs/source/figures/dataset_validation_results.png +0 -0
  301. ads/feature_store/docs/source/figures/dataset_validation_summary.png +0 -0
  302. ads/feature_store/docs/source/figures/drift_monitoring.png +0 -0
  303. ads/feature_store/docs/source/figures/entity.png +0 -0
  304. ads/feature_store/docs/source/figures/feature_group.png +0 -0
  305. ads/feature_store/docs/source/figures/feature_group_lineage.png +0 -0
  306. ads/feature_store/docs/source/figures/feature_group_statistics_viz.png +0 -0
  307. ads/feature_store/docs/source/figures/feature_store_deployment.png +0 -0
  308. ads/feature_store/docs/source/figures/feature_store_overview.png +0 -0
  309. ads/feature_store/docs/source/figures/featuregroup.gif +0 -0
  310. ads/feature_store/docs/source/figures/lineage_d1.png +0 -0
  311. ads/feature_store/docs/source/figures/lineage_d2.png +0 -0
  312. ads/feature_store/docs/source/figures/lineage_fg.png +0 -0
  313. ads/feature_store/docs/source/figures/logo-dark-mode.png +0 -0
  314. ads/feature_store/docs/source/figures/logo-light-mode.png +0 -0
  315. ads/feature_store/docs/source/figures/overview.png +0 -0
  316. ads/feature_store/docs/source/figures/resource_manager.png +0 -0
  317. ads/feature_store/docs/source/figures/resource_manager_feature_store_stack.png +0 -0
  318. ads/feature_store/docs/source/figures/resource_manager_home.png +0 -0
  319. ads/feature_store/docs/source/figures/stats_1.png +0 -0
  320. ads/feature_store/docs/source/figures/stats_2.png +0 -0
  321. ads/feature_store/docs/source/figures/stats_d.png +0 -0
  322. ads/feature_store/docs/source/figures/stats_fg.png +0 -0
  323. ads/feature_store/docs/source/figures/transformation.png +0 -0
  324. ads/feature_store/docs/source/figures/transformations.gif +0 -0
  325. ads/feature_store/docs/source/figures/validation.png +0 -0
  326. ads/feature_store/docs/source/figures/validation_fg.png +0 -0
  327. ads/feature_store/docs/source/figures/validation_results.png +0 -0
  328. ads/feature_store/docs/source/figures/validation_summary.png +0 -0
  329. ads/feature_store/docs/source/index.rst +81 -0
  330. ads/feature_store/docs/source/module.rst +8 -0
  331. ads/feature_store/docs/source/notebook.rst +94 -0
  332. ads/feature_store/docs/source/overview.rst +47 -0
  333. ads/feature_store/docs/source/quickstart.rst +176 -0
  334. ads/feature_store/docs/source/release_notes.rst +194 -0
  335. ads/feature_store/docs/source/setup_feature_store.rst +81 -0
  336. ads/feature_store/docs/source/statistics.rst +58 -0
  337. ads/feature_store/docs/source/transformation.rst +199 -0
  338. ads/feature_store/docs/source/ui.rst +65 -0
  339. ads/feature_store/docs/source/user_guides.setup.feature_store_operator.rst +66 -0
  340. ads/feature_store/docs/source/user_guides.setup.helm_chart.rst +192 -0
  341. ads/feature_store/docs/source/user_guides.setup.terraform.rst +338 -0
  342. ads/feature_store/entity.py +718 -0
  343. ads/feature_store/execution_strategy/__init__.py +0 -0
  344. ads/feature_store/execution_strategy/delta_lake/__init__.py +0 -0
  345. ads/feature_store/execution_strategy/delta_lake/delta_lake_service.py +375 -0
  346. ads/feature_store/execution_strategy/engine/__init__.py +0 -0
  347. ads/feature_store/execution_strategy/engine/spark_engine.py +316 -0
  348. ads/feature_store/execution_strategy/execution_strategy.py +113 -0
  349. ads/feature_store/execution_strategy/execution_strategy_provider.py +47 -0
  350. ads/feature_store/execution_strategy/spark/__init__.py +0 -0
  351. ads/feature_store/execution_strategy/spark/spark_execution.py +618 -0
  352. ads/feature_store/feature.py +192 -0
  353. ads/feature_store/feature_group.py +1494 -0
  354. ads/feature_store/feature_group_expectation.py +346 -0
  355. ads/feature_store/feature_group_job.py +602 -0
  356. ads/feature_store/feature_lineage/__init__.py +0 -0
  357. ads/feature_store/feature_lineage/graphviz_service.py +180 -0
  358. ads/feature_store/feature_option_details.py +50 -0
  359. ads/feature_store/feature_statistics/__init__.py +0 -0
  360. ads/feature_store/feature_statistics/statistics_service.py +99 -0
  361. ads/feature_store/feature_store.py +699 -0
  362. ads/feature_store/feature_store_registrar.py +518 -0
  363. ads/feature_store/input_feature_detail.py +149 -0
  364. ads/feature_store/mixin/__init__.py +4 -0
  365. ads/feature_store/mixin/oci_feature_store.py +145 -0
  366. ads/feature_store/model_details.py +73 -0
  367. ads/feature_store/query/__init__.py +0 -0
  368. ads/feature_store/query/filter.py +266 -0
  369. ads/feature_store/query/generator/__init__.py +0 -0
  370. ads/feature_store/query/generator/query_generator.py +298 -0
  371. ads/feature_store/query/join.py +161 -0
  372. ads/feature_store/query/query.py +403 -0
  373. ads/feature_store/query/validator/__init__.py +0 -0
  374. ads/feature_store/query/validator/query_validator.py +57 -0
  375. ads/feature_store/response/__init__.py +0 -0
  376. ads/feature_store/response/response_builder.py +68 -0
  377. ads/feature_store/service/__init__.py +0 -0
  378. ads/feature_store/service/oci_dataset.py +139 -0
  379. ads/feature_store/service/oci_dataset_job.py +199 -0
  380. ads/feature_store/service/oci_entity.py +125 -0
  381. ads/feature_store/service/oci_feature_group.py +164 -0
  382. ads/feature_store/service/oci_feature_group_job.py +214 -0
  383. ads/feature_store/service/oci_feature_store.py +182 -0
  384. ads/feature_store/service/oci_lineage.py +87 -0
  385. ads/feature_store/service/oci_transformation.py +104 -0
  386. ads/feature_store/statistics/__init__.py +0 -0
  387. ads/feature_store/statistics/abs_feature_value.py +49 -0
  388. ads/feature_store/statistics/charts/__init__.py +0 -0
  389. ads/feature_store/statistics/charts/abstract_feature_plot.py +37 -0
  390. ads/feature_store/statistics/charts/box_plot.py +148 -0
  391. ads/feature_store/statistics/charts/frequency_distribution.py +65 -0
  392. ads/feature_store/statistics/charts/probability_distribution.py +68 -0
  393. ads/feature_store/statistics/charts/top_k_frequent_elements.py +98 -0
  394. ads/feature_store/statistics/feature_stat.py +126 -0
  395. ads/feature_store/statistics/generic_feature_value.py +33 -0
  396. ads/feature_store/statistics/statistics.py +41 -0
  397. ads/feature_store/statistics_config.py +101 -0
  398. ads/feature_store/templates/feature_store_template.yaml +45 -0
  399. ads/feature_store/transformation.py +499 -0
  400. ads/feature_store/validation_output.py +57 -0
  401. ads/hpo/__init__.py +9 -0
  402. ads/hpo/_imports.py +91 -0
  403. ads/hpo/ads_search_space.py +439 -0
  404. ads/hpo/distributions.py +325 -0
  405. ads/hpo/objective.py +280 -0
  406. ads/hpo/search_cv.py +1657 -0
  407. ads/hpo/stopping_criterion.py +75 -0
  408. ads/hpo/tuner_artifact.py +413 -0
  409. ads/hpo/utils.py +91 -0
  410. ads/hpo/validation.py +140 -0
  411. ads/hpo/visualization/__init__.py +5 -0
  412. ads/hpo/visualization/_contour.py +23 -0
  413. ads/hpo/visualization/_edf.py +20 -0
  414. ads/hpo/visualization/_intermediate_values.py +21 -0
  415. ads/hpo/visualization/_optimization_history.py +25 -0
  416. ads/hpo/visualization/_parallel_coordinate.py +169 -0
  417. ads/hpo/visualization/_param_importances.py +26 -0
  418. ads/jobs/__init__.py +53 -0
  419. ads/jobs/ads_job.py +663 -0
  420. ads/jobs/builders/__init__.py +5 -0
  421. ads/jobs/builders/base.py +156 -0
  422. ads/jobs/builders/infrastructure/__init__.py +6 -0
  423. ads/jobs/builders/infrastructure/base.py +165 -0
  424. ads/jobs/builders/infrastructure/dataflow.py +1252 -0
  425. ads/jobs/builders/infrastructure/dsc_job.py +1894 -0
  426. ads/jobs/builders/infrastructure/dsc_job_runtime.py +1233 -0
  427. ads/jobs/builders/infrastructure/utils.py +65 -0
  428. ads/jobs/builders/runtimes/__init__.py +5 -0
  429. ads/jobs/builders/runtimes/artifact.py +338 -0
  430. ads/jobs/builders/runtimes/base.py +325 -0
  431. ads/jobs/builders/runtimes/container_runtime.py +242 -0
  432. ads/jobs/builders/runtimes/python_runtime.py +1016 -0
  433. ads/jobs/builders/runtimes/pytorch_runtime.py +204 -0
  434. ads/jobs/cli.py +104 -0
  435. ads/jobs/env_var_parser.py +131 -0
  436. ads/jobs/extension.py +160 -0
  437. ads/jobs/schema/__init__.py +5 -0
  438. ads/jobs/schema/infrastructure_schema.json +116 -0
  439. ads/jobs/schema/job_schema.json +42 -0
  440. ads/jobs/schema/runtime_schema.json +183 -0
  441. ads/jobs/schema/validator.py +141 -0
  442. ads/jobs/serializer.py +296 -0
  443. ads/jobs/templates/__init__.py +5 -0
  444. ads/jobs/templates/container.py +6 -0
  445. ads/jobs/templates/driver_notebook.py +177 -0
  446. ads/jobs/templates/driver_oci.py +500 -0
  447. ads/jobs/templates/driver_python.py +48 -0
  448. ads/jobs/templates/driver_pytorch.py +852 -0
  449. ads/jobs/templates/driver_utils.py +615 -0
  450. ads/jobs/templates/hostname_from_env.c +55 -0
  451. ads/jobs/templates/oci_metrics.py +181 -0
  452. ads/jobs/utils.py +104 -0
  453. ads/llm/__init__.py +28 -0
  454. ads/llm/autogen/__init__.py +2 -0
  455. ads/llm/autogen/constants.py +15 -0
  456. ads/llm/autogen/reports/__init__.py +2 -0
  457. ads/llm/autogen/reports/base.py +67 -0
  458. ads/llm/autogen/reports/data.py +103 -0
  459. ads/llm/autogen/reports/session.py +526 -0
  460. ads/llm/autogen/reports/templates/chat_box.html +13 -0
  461. ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
  462. ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
  463. ads/llm/autogen/reports/utils.py +56 -0
  464. ads/llm/autogen/v02/__init__.py +4 -0
  465. ads/llm/autogen/v02/client.py +295 -0
  466. ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
  467. ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
  468. ads/llm/autogen/v02/loggers/__init__.py +6 -0
  469. ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
  470. ads/llm/autogen/v02/loggers/session_logger.py +580 -0
  471. ads/llm/autogen/v02/loggers/utils.py +86 -0
  472. ads/llm/autogen/v02/runtime_logging.py +163 -0
  473. ads/llm/chain.py +268 -0
  474. ads/llm/chat_template.py +31 -0
  475. ads/llm/deploy.py +63 -0
  476. ads/llm/guardrails/__init__.py +5 -0
  477. ads/llm/guardrails/base.py +442 -0
  478. ads/llm/guardrails/huggingface.py +44 -0
  479. ads/llm/langchain/__init__.py +5 -0
  480. ads/llm/langchain/plugins/__init__.py +5 -0
  481. ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
  482. ads/llm/langchain/plugins/chat_models/oci_data_science.py +1027 -0
  483. ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
  484. ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
  485. ads/llm/langchain/plugins/llms/__init__.py +5 -0
  486. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +979 -0
  487. ads/llm/requirements.txt +3 -0
  488. ads/llm/serialize.py +219 -0
  489. ads/llm/serializers/__init__.py +0 -0
  490. ads/llm/serializers/retrieval_qa.py +153 -0
  491. ads/llm/serializers/runnable_parallel.py +27 -0
  492. ads/llm/templates/score_chain.jinja2 +155 -0
  493. ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
  494. ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
  495. ads/model/__init__.py +52 -0
  496. ads/model/artifact.py +573 -0
  497. ads/model/artifact_downloader.py +254 -0
  498. ads/model/artifact_uploader.py +267 -0
  499. ads/model/base_properties.py +238 -0
  500. ads/model/common/.model-ignore +66 -0
  501. ads/model/common/__init__.py +5 -0
  502. ads/model/common/utils.py +142 -0
  503. ads/model/datascience_model.py +2635 -0
  504. ads/model/deployment/__init__.py +20 -0
  505. ads/model/deployment/common/__init__.py +5 -0
  506. ads/model/deployment/common/utils.py +308 -0
  507. ads/model/deployment/model_deployer.py +466 -0
  508. ads/model/deployment/model_deployment.py +1846 -0
  509. ads/model/deployment/model_deployment_infrastructure.py +671 -0
  510. ads/model/deployment/model_deployment_properties.py +493 -0
  511. ads/model/deployment/model_deployment_runtime.py +838 -0
  512. ads/model/extractor/__init__.py +5 -0
  513. ads/model/extractor/automl_extractor.py +74 -0
  514. ads/model/extractor/embedding_onnx_extractor.py +80 -0
  515. ads/model/extractor/huggingface_extractor.py +88 -0
  516. ads/model/extractor/keras_extractor.py +84 -0
  517. ads/model/extractor/lightgbm_extractor.py +93 -0
  518. ads/model/extractor/model_info_extractor.py +114 -0
  519. ads/model/extractor/model_info_extractor_factory.py +105 -0
  520. ads/model/extractor/pytorch_extractor.py +87 -0
  521. ads/model/extractor/sklearn_extractor.py +112 -0
  522. ads/model/extractor/spark_extractor.py +89 -0
  523. ads/model/extractor/tensorflow_extractor.py +85 -0
  524. ads/model/extractor/xgboost_extractor.py +94 -0
  525. ads/model/framework/__init__.py +5 -0
  526. ads/model/framework/automl_model.py +178 -0
  527. ads/model/framework/embedding_onnx_model.py +438 -0
  528. ads/model/framework/huggingface_model.py +399 -0
  529. ads/model/framework/lightgbm_model.py +266 -0
  530. ads/model/framework/pytorch_model.py +266 -0
  531. ads/model/framework/sklearn_model.py +250 -0
  532. ads/model/framework/spark_model.py +326 -0
  533. ads/model/framework/tensorflow_model.py +254 -0
  534. ads/model/framework/xgboost_model.py +258 -0
  535. ads/model/generic_model.py +3518 -0
  536. ads/model/model_artifact_boilerplate/README.md +381 -0
  537. ads/model/model_artifact_boilerplate/__init__.py +5 -0
  538. ads/model/model_artifact_boilerplate/artifact_introspection_test/__init__.py +5 -0
  539. ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +427 -0
  540. ads/model/model_artifact_boilerplate/artifact_introspection_test/requirements.txt +2 -0
  541. ads/model/model_artifact_boilerplate/runtime.yaml +7 -0
  542. ads/model/model_artifact_boilerplate/score.py +61 -0
  543. ads/model/model_file_description_schema.json +68 -0
  544. ads/model/model_introspect.py +331 -0
  545. ads/model/model_metadata.py +1810 -0
  546. ads/model/model_metadata_mixin.py +460 -0
  547. ads/model/model_properties.py +63 -0
  548. ads/model/model_version_set.py +739 -0
  549. ads/model/runtime/__init__.py +5 -0
  550. ads/model/runtime/env_info.py +306 -0
  551. ads/model/runtime/model_deployment_details.py +37 -0
  552. ads/model/runtime/model_provenance_details.py +58 -0
  553. ads/model/runtime/runtime_info.py +81 -0
  554. ads/model/runtime/schemas/inference_env_info_schema.yaml +16 -0
  555. ads/model/runtime/schemas/model_provenance_schema.yaml +36 -0
  556. ads/model/runtime/schemas/training_env_info_schema.yaml +16 -0
  557. ads/model/runtime/utils.py +201 -0
  558. ads/model/serde/__init__.py +5 -0
  559. ads/model/serde/common.py +40 -0
  560. ads/model/serde/model_input.py +547 -0
  561. ads/model/serde/model_serializer.py +1184 -0
  562. ads/model/service/__init__.py +5 -0
  563. ads/model/service/oci_datascience_model.py +1076 -0
  564. ads/model/service/oci_datascience_model_deployment.py +500 -0
  565. ads/model/service/oci_datascience_model_version_set.py +176 -0
  566. ads/model/transformer/__init__.py +5 -0
  567. ads/model/transformer/onnx_transformer.py +324 -0
  568. ads/mysqldb/__init__.py +5 -0
  569. ads/mysqldb/mysql_db.py +227 -0
  570. ads/opctl/__init__.py +18 -0
  571. ads/opctl/anomaly_detection.py +11 -0
  572. ads/opctl/backend/__init__.py +5 -0
  573. ads/opctl/backend/ads_dataflow.py +353 -0
  574. ads/opctl/backend/ads_ml_job.py +710 -0
  575. ads/opctl/backend/ads_ml_pipeline.py +164 -0
  576. ads/opctl/backend/ads_model_deployment.py +209 -0
  577. ads/opctl/backend/base.py +146 -0
  578. ads/opctl/backend/local.py +1053 -0
  579. ads/opctl/backend/marketplace/__init__.py +9 -0
  580. ads/opctl/backend/marketplace/helm_helper.py +173 -0
  581. ads/opctl/backend/marketplace/local_marketplace.py +271 -0
  582. ads/opctl/backend/marketplace/marketplace_backend_runner.py +71 -0
  583. ads/opctl/backend/marketplace/marketplace_operator_interface.py +44 -0
  584. ads/opctl/backend/marketplace/marketplace_operator_runner.py +24 -0
  585. ads/opctl/backend/marketplace/marketplace_utils.py +212 -0
  586. ads/opctl/backend/marketplace/models/__init__.py +5 -0
  587. ads/opctl/backend/marketplace/models/bearer_token.py +94 -0
  588. ads/opctl/backend/marketplace/models/marketplace_type.py +70 -0
  589. ads/opctl/backend/marketplace/models/ocir_details.py +56 -0
  590. ads/opctl/backend/marketplace/prerequisite_checker.py +238 -0
  591. ads/opctl/cli.py +707 -0
  592. ads/opctl/cmds.py +869 -0
  593. ads/opctl/conda/__init__.py +5 -0
  594. ads/opctl/conda/cli.py +193 -0
  595. ads/opctl/conda/cmds.py +749 -0
  596. ads/opctl/conda/config.yaml +34 -0
  597. ads/opctl/conda/manifest_template.yaml +13 -0
  598. ads/opctl/conda/multipart_uploader.py +188 -0
  599. ads/opctl/conda/pack.py +89 -0
  600. ads/opctl/config/__init__.py +5 -0
  601. ads/opctl/config/base.py +57 -0
  602. ads/opctl/config/diagnostics/__init__.py +5 -0
  603. ads/opctl/config/diagnostics/distributed/default_requirements_config.yaml +62 -0
  604. ads/opctl/config/merger.py +255 -0
  605. ads/opctl/config/resolver.py +297 -0
  606. ads/opctl/config/utils.py +79 -0
  607. ads/opctl/config/validator.py +17 -0
  608. ads/opctl/config/versioner.py +68 -0
  609. ads/opctl/config/yaml_parsers/__init__.py +7 -0
  610. ads/opctl/config/yaml_parsers/base.py +58 -0
  611. ads/opctl/config/yaml_parsers/distributed/__init__.py +7 -0
  612. ads/opctl/config/yaml_parsers/distributed/yaml_parser.py +201 -0
  613. ads/opctl/constants.py +66 -0
  614. ads/opctl/decorator/__init__.py +5 -0
  615. ads/opctl/decorator/common.py +129 -0
  616. ads/opctl/diagnostics/__init__.py +5 -0
  617. ads/opctl/diagnostics/__main__.py +25 -0
  618. ads/opctl/diagnostics/check_distributed_job_requirements.py +212 -0
  619. ads/opctl/diagnostics/check_requirements.py +144 -0
  620. ads/opctl/diagnostics/requirement_exception.py +9 -0
  621. ads/opctl/distributed/README.md +109 -0
  622. ads/opctl/distributed/__init__.py +5 -0
  623. ads/opctl/distributed/certificates.py +32 -0
  624. ads/opctl/distributed/cli.py +207 -0
  625. ads/opctl/distributed/cmds.py +731 -0
  626. ads/opctl/distributed/common/__init__.py +5 -0
  627. ads/opctl/distributed/common/abstract_cluster_provider.py +449 -0
  628. ads/opctl/distributed/common/abstract_framework_spec_builder.py +88 -0
  629. ads/opctl/distributed/common/cluster_config_helper.py +103 -0
  630. ads/opctl/distributed/common/cluster_provider_factory.py +21 -0
  631. ads/opctl/distributed/common/cluster_runner.py +54 -0
  632. ads/opctl/distributed/common/framework_factory.py +29 -0
  633. ads/opctl/docker/Dockerfile.job +103 -0
  634. ads/opctl/docker/Dockerfile.job.arm +107 -0
  635. ads/opctl/docker/Dockerfile.job.gpu +175 -0
  636. ads/opctl/docker/base-env.yaml +13 -0
  637. ads/opctl/docker/cuda.repo +6 -0
  638. ads/opctl/docker/operator/.dockerignore +0 -0
  639. ads/opctl/docker/operator/Dockerfile +41 -0
  640. ads/opctl/docker/operator/Dockerfile.gpu +85 -0
  641. ads/opctl/docker/operator/cuda.repo +6 -0
  642. ads/opctl/docker/operator/environment.yaml +8 -0
  643. ads/opctl/forecast.py +11 -0
  644. ads/opctl/index.yaml +3 -0
  645. ads/opctl/model/__init__.py +5 -0
  646. ads/opctl/model/cli.py +65 -0
  647. ads/opctl/model/cmds.py +73 -0
  648. ads/opctl/operator/README.md +4 -0
  649. ads/opctl/operator/__init__.py +31 -0
  650. ads/opctl/operator/cli.py +344 -0
  651. ads/opctl/operator/cmd.py +596 -0
  652. ads/opctl/operator/common/__init__.py +5 -0
  653. ads/opctl/operator/common/backend_factory.py +460 -0
  654. ads/opctl/operator/common/const.py +27 -0
  655. ads/opctl/operator/common/data/synthetic.csv +16001 -0
  656. ads/opctl/operator/common/dictionary_merger.py +148 -0
  657. ads/opctl/operator/common/errors.py +42 -0
  658. ads/opctl/operator/common/operator_config.py +99 -0
  659. ads/opctl/operator/common/operator_loader.py +811 -0
  660. ads/opctl/operator/common/operator_schema.yaml +130 -0
  661. ads/opctl/operator/common/operator_yaml_generator.py +152 -0
  662. ads/opctl/operator/common/utils.py +208 -0
  663. ads/opctl/operator/lowcode/__init__.py +5 -0
  664. ads/opctl/operator/lowcode/anomaly/MLoperator +16 -0
  665. ads/opctl/operator/lowcode/anomaly/README.md +207 -0
  666. ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
  667. ads/opctl/operator/lowcode/anomaly/__main__.py +103 -0
  668. ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
  669. ads/opctl/operator/lowcode/anomaly/const.py +167 -0
  670. ads/opctl/operator/lowcode/anomaly/environment.yaml +10 -0
  671. ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
  672. ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +146 -0
  673. ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +162 -0
  674. ads/opctl/operator/lowcode/anomaly/model/automlx.py +99 -0
  675. ads/opctl/operator/lowcode/anomaly/model/autots.py +115 -0
  676. ads/opctl/operator/lowcode/anomaly/model/base_model.py +404 -0
  677. ads/opctl/operator/lowcode/anomaly/model/factory.py +110 -0
  678. ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +78 -0
  679. ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +78 -0
  680. ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +120 -0
  681. ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
  682. ads/opctl/operator/lowcode/anomaly/operator_config.py +127 -0
  683. ads/opctl/operator/lowcode/anomaly/schema.yaml +401 -0
  684. ads/opctl/operator/lowcode/anomaly/utils.py +88 -0
  685. ads/opctl/operator/lowcode/common/__init__.py +5 -0
  686. ads/opctl/operator/lowcode/common/const.py +10 -0
  687. ads/opctl/operator/lowcode/common/data.py +116 -0
  688. ads/opctl/operator/lowcode/common/errors.py +47 -0
  689. ads/opctl/operator/lowcode/common/transformations.py +296 -0
  690. ads/opctl/operator/lowcode/common/utils.py +384 -0
  691. ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator +13 -0
  692. ads/opctl/operator/lowcode/feature_store_marketplace/README.md +30 -0
  693. ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py +5 -0
  694. ads/opctl/operator/lowcode/feature_store_marketplace/__main__.py +116 -0
  695. ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py +85 -0
  696. ads/opctl/operator/lowcode/feature_store_marketplace/const.py +15 -0
  697. ads/opctl/operator/lowcode/feature_store_marketplace/environment.yaml +0 -0
  698. ads/opctl/operator/lowcode/feature_store_marketplace/models/__init__.py +4 -0
  699. ads/opctl/operator/lowcode/feature_store_marketplace/models/apigw_config.py +32 -0
  700. ads/opctl/operator/lowcode/feature_store_marketplace/models/db_config.py +43 -0
  701. ads/opctl/operator/lowcode/feature_store_marketplace/models/mysql_config.py +120 -0
  702. ads/opctl/operator/lowcode/feature_store_marketplace/models/serializable_yaml_model.py +34 -0
  703. ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +386 -0
  704. ads/opctl/operator/lowcode/feature_store_marketplace/schema.yaml +160 -0
  705. ads/opctl/operator/lowcode/forecast/MLoperator +25 -0
  706. ads/opctl/operator/lowcode/forecast/README.md +209 -0
  707. ads/opctl/operator/lowcode/forecast/__init__.py +5 -0
  708. ads/opctl/operator/lowcode/forecast/__main__.py +89 -0
  709. ads/opctl/operator/lowcode/forecast/cmd.py +40 -0
  710. ads/opctl/operator/lowcode/forecast/const.py +92 -0
  711. ads/opctl/operator/lowcode/forecast/environment.yaml +20 -0
  712. ads/opctl/operator/lowcode/forecast/errors.py +26 -0
  713. ads/opctl/operator/lowcode/forecast/model/__init__.py +5 -0
  714. ads/opctl/operator/lowcode/forecast/model/arima.py +279 -0
  715. ads/opctl/operator/lowcode/forecast/model/automlx.py +553 -0
  716. ads/opctl/operator/lowcode/forecast/model/autots.py +312 -0
  717. ads/opctl/operator/lowcode/forecast/model/base_model.py +875 -0
  718. ads/opctl/operator/lowcode/forecast/model/factory.py +106 -0
  719. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +492 -0
  720. ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +243 -0
  721. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +482 -0
  722. ads/opctl/operator/lowcode/forecast/model/prophet.py +450 -0
  723. ads/opctl/operator/lowcode/forecast/model_evaluator.py +244 -0
  724. ads/opctl/operator/lowcode/forecast/operator_config.py +234 -0
  725. ads/opctl/operator/lowcode/forecast/schema.yaml +506 -0
  726. ads/opctl/operator/lowcode/forecast/utils.py +397 -0
  727. ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
  728. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +285 -0
  729. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +246 -0
  730. ads/opctl/operator/lowcode/pii/MLoperator +17 -0
  731. ads/opctl/operator/lowcode/pii/README.md +208 -0
  732. ads/opctl/operator/lowcode/pii/__init__.py +5 -0
  733. ads/opctl/operator/lowcode/pii/__main__.py +78 -0
  734. ads/opctl/operator/lowcode/pii/cmd.py +39 -0
  735. ads/opctl/operator/lowcode/pii/constant.py +84 -0
  736. ads/opctl/operator/lowcode/pii/environment.yaml +17 -0
  737. ads/opctl/operator/lowcode/pii/errors.py +27 -0
  738. ads/opctl/operator/lowcode/pii/model/__init__.py +5 -0
  739. ads/opctl/operator/lowcode/pii/model/factory.py +82 -0
  740. ads/opctl/operator/lowcode/pii/model/guardrails.py +167 -0
  741. ads/opctl/operator/lowcode/pii/model/pii.py +145 -0
  742. ads/opctl/operator/lowcode/pii/model/processor/__init__.py +34 -0
  743. ads/opctl/operator/lowcode/pii/model/processor/email_replacer.py +34 -0
  744. ads/opctl/operator/lowcode/pii/model/processor/mbi_replacer.py +35 -0
  745. ads/opctl/operator/lowcode/pii/model/processor/name_replacer.py +225 -0
  746. ads/opctl/operator/lowcode/pii/model/processor/number_replacer.py +73 -0
  747. ads/opctl/operator/lowcode/pii/model/processor/remover.py +26 -0
  748. ads/opctl/operator/lowcode/pii/model/report.py +487 -0
  749. ads/opctl/operator/lowcode/pii/operator_config.py +95 -0
  750. ads/opctl/operator/lowcode/pii/schema.yaml +108 -0
  751. ads/opctl/operator/lowcode/pii/utils.py +43 -0
  752. ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
  753. ads/opctl/operator/lowcode/recommender/README.md +206 -0
  754. ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
  755. ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
  756. ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
  757. ads/opctl/operator/lowcode/recommender/constant.py +30 -0
  758. ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
  759. ads/opctl/operator/lowcode/recommender/model/base_model.py +212 -0
  760. ads/opctl/operator/lowcode/recommender/model/factory.py +56 -0
  761. ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
  762. ads/opctl/operator/lowcode/recommender/model/svd.py +106 -0
  763. ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
  764. ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
  765. ads/opctl/operator/lowcode/recommender/utils.py +13 -0
  766. ads/opctl/operator/runtime/__init__.py +5 -0
  767. ads/opctl/operator/runtime/const.py +17 -0
  768. ads/opctl/operator/runtime/container_runtime_schema.yaml +50 -0
  769. ads/opctl/operator/runtime/marketplace_runtime.py +50 -0
  770. ads/opctl/operator/runtime/python_marketplace_runtime_schema.yaml +21 -0
  771. ads/opctl/operator/runtime/python_runtime_schema.yaml +21 -0
  772. ads/opctl/operator/runtime/runtime.py +115 -0
  773. ads/opctl/schema.yaml.yml +36 -0
  774. ads/opctl/script.py +40 -0
  775. ads/opctl/spark/__init__.py +5 -0
  776. ads/opctl/spark/cli.py +43 -0
  777. ads/opctl/spark/cmds.py +147 -0
  778. ads/opctl/templates/diagnostic_report_template.jinja2 +102 -0
  779. ads/opctl/utils.py +344 -0
  780. ads/oracledb/__init__.py +5 -0
  781. ads/oracledb/oracle_db.py +346 -0
  782. ads/pipeline/__init__.py +39 -0
  783. ads/pipeline/ads_pipeline.py +2279 -0
  784. ads/pipeline/ads_pipeline_run.py +772 -0
  785. ads/pipeline/ads_pipeline_step.py +605 -0
  786. ads/pipeline/builders/__init__.py +5 -0
  787. ads/pipeline/builders/infrastructure/__init__.py +5 -0
  788. ads/pipeline/builders/infrastructure/custom_script.py +32 -0
  789. ads/pipeline/cli.py +119 -0
  790. ads/pipeline/extension.py +291 -0
  791. ads/pipeline/schema/__init__.py +5 -0
  792. ads/pipeline/schema/cs_step_schema.json +35 -0
  793. ads/pipeline/schema/ml_step_schema.json +31 -0
  794. ads/pipeline/schema/pipeline_schema.json +71 -0
  795. ads/pipeline/visualizer/__init__.py +5 -0
  796. ads/pipeline/visualizer/base.py +570 -0
  797. ads/pipeline/visualizer/graph_renderer.py +272 -0
  798. ads/pipeline/visualizer/text_renderer.py +84 -0
  799. ads/secrets/__init__.py +11 -0
  800. ads/secrets/adb.py +386 -0
  801. ads/secrets/auth_token.py +86 -0
  802. ads/secrets/big_data_service.py +365 -0
  803. ads/secrets/mysqldb.py +149 -0
  804. ads/secrets/oracledb.py +160 -0
  805. ads/secrets/secrets.py +407 -0
  806. ads/telemetry/__init__.py +7 -0
  807. ads/telemetry/base.py +69 -0
  808. ads/telemetry/client.py +122 -0
  809. ads/telemetry/telemetry.py +257 -0
  810. ads/templates/dataflow_pyspark.jinja2 +13 -0
  811. ads/templates/dataflow_sparksql.jinja2 +22 -0
  812. ads/templates/func.jinja2 +20 -0
  813. ads/templates/schemas/openapi.json +1740 -0
  814. ads/templates/score-pkl.jinja2 +173 -0
  815. ads/templates/score.jinja2 +322 -0
  816. ads/templates/score_embedding_onnx.jinja2 +202 -0
  817. ads/templates/score_generic.jinja2 +165 -0
  818. ads/templates/score_huggingface_pipeline.jinja2 +217 -0
  819. ads/templates/score_lightgbm.jinja2 +185 -0
  820. ads/templates/score_onnx.jinja2 +407 -0
  821. ads/templates/score_onnx_new.jinja2 +473 -0
  822. ads/templates/score_oracle_automl.jinja2 +185 -0
  823. ads/templates/score_pyspark.jinja2 +154 -0
  824. ads/templates/score_pytorch.jinja2 +219 -0
  825. ads/templates/score_scikit-learn.jinja2 +184 -0
  826. ads/templates/score_tensorflow.jinja2 +184 -0
  827. ads/templates/score_xgboost.jinja2 +178 -0
  828. ads/text_dataset/__init__.py +5 -0
  829. ads/text_dataset/backends.py +211 -0
  830. ads/text_dataset/dataset.py +445 -0
  831. ads/text_dataset/extractor.py +207 -0
  832. ads/text_dataset/options.py +53 -0
  833. ads/text_dataset/udfs.py +22 -0
  834. ads/text_dataset/utils.py +49 -0
  835. ads/type_discovery/__init__.py +9 -0
  836. ads/type_discovery/abstract_detector.py +21 -0
  837. ads/type_discovery/constant_detector.py +41 -0
  838. ads/type_discovery/continuous_detector.py +54 -0
  839. ads/type_discovery/credit_card_detector.py +99 -0
  840. ads/type_discovery/datetime_detector.py +92 -0
  841. ads/type_discovery/discrete_detector.py +118 -0
  842. ads/type_discovery/document_detector.py +146 -0
  843. ads/type_discovery/ip_detector.py +68 -0
  844. ads/type_discovery/latlon_detector.py +90 -0
  845. ads/type_discovery/phone_number_detector.py +63 -0
  846. ads/type_discovery/type_discovery_driver.py +87 -0
  847. ads/type_discovery/typed_feature.py +594 -0
  848. ads/type_discovery/unknown_detector.py +41 -0
  849. ads/type_discovery/zipcode_detector.py +48 -0
  850. ads/vault/__init__.py +7 -0
  851. ads/vault/vault.py +237 -0
  852. {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/METADATA +150 -149
  853. oracle_ads-2.13.10rc0.dist-info/RECORD +858 -0
  854. {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/WHEEL +1 -2
  855. {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/entry_points.txt +2 -1
  856. oracle_ads-2.13.9rc0.dist-info/RECORD +0 -9
  857. oracle_ads-2.13.9rc0.dist-info/top_level.txt +0 -1
  858. {oracle_ads-2.13.9rc0.dist-info → oracle_ads-2.13.10rc0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,1184 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8; -*-
3
+
4
+ # Copyright (c) 2023 Oracle and/or its affiliates.
5
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
+ import cloudpickle
7
+ import numpy as np
8
+ import pandas as pd
9
+ from ads.model.serde.common import Serializer, Deserializer
10
+ from ads.common.decorator.runtime_dependency import (
11
+ runtime_dependency,
12
+ OptionalDependency,
13
+ )
14
+ from ads.common import logger
15
+ from pandas.api.types import is_numeric_dtype, is_string_dtype
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+ from joblib import dump
18
+
19
+
20
# Identifiers for every model serialization format handled by this module.
MODEL_SERIALIZATION_TYPE_ONNX = "onnx"
MODEL_SERIALIZATION_TYPE_CLOUDPICKLE = "cloudpickle"
MODEL_SERIALIZATION_TYPE_TORCHSCRIPT = "torchscript"
# Historical misspelling, kept as an alias so existing imports keep working.
MODEL_SERIALIZATION_TYPE_TORHCSCRIPT = MODEL_SERIALIZATION_TYPE_TORCHSCRIPT
MODEL_SERIALIZATION_TYPE_TORCH = "torch"
MODEL_SERIALIZATION_TYPE_TORCH_ONNX = "torch_onnx"
MODEL_SERIALIZATION_TYPE_TF = "tf"
MODEL_SERIALIZATION_TYPE_TF_ONNX = "tf_onnx"
MODEL_SERIALIZATION_TYPE_JOBLIB = "joblib"
MODEL_SERIALIZATION_TYPE_SKLEARN_ONNX = "sklearn_onnx"
MODEL_SERIALIZATION_TYPE_LIGHTGBM = "lightgbm"
MODEL_SERIALIZATION_TYPE_LIGHTGBM_ONNX = "lightgbm_onnx"
MODEL_SERIALIZATION_TYPE_XGBOOST = "xgboost"
MODEL_SERIALIZATION_TYPE_XGBOOST_UBJ = "xgboost_ubj"
MODEL_SERIALIZATION_TYPE_XGBOOST_TXT = "xgboost_txt"
MODEL_SERIALIZATION_TYPE_XGBOOST_ONNX = "xgboost_onnx"
MODEL_SERIALIZATION_TYPE_SPARK = "spark"
MODEL_SERIALIZATION_TYPE_HUGGINGFACE = "huggingface"


# Ordered list of serialization type identifiers accepted as "supported".
# NOTE(review): MODEL_SERIALIZATION_TYPE_XGBOOST_UBJ / _XGBOOST_TXT are defined
# above but not listed here — confirm whether that omission is intentional.
SUPPORTED_MODEL_SERIALIZERS = [
    MODEL_SERIALIZATION_TYPE_ONNX,
    MODEL_SERIALIZATION_TYPE_CLOUDPICKLE,
    MODEL_SERIALIZATION_TYPE_TORCHSCRIPT,
    MODEL_SERIALIZATION_TYPE_TORCH,
    MODEL_SERIALIZATION_TYPE_TORCH_ONNX,
    MODEL_SERIALIZATION_TYPE_TF,
    MODEL_SERIALIZATION_TYPE_TF_ONNX,
    MODEL_SERIALIZATION_TYPE_JOBLIB,
    MODEL_SERIALIZATION_TYPE_SKLEARN_ONNX,
    MODEL_SERIALIZATION_TYPE_LIGHTGBM,
    MODEL_SERIALIZATION_TYPE_LIGHTGBM_ONNX,
    MODEL_SERIALIZATION_TYPE_XGBOOST,
    MODEL_SERIALIZATION_TYPE_XGBOOST_ONNX,
    MODEL_SERIALIZATION_TYPE_SPARK,
    MODEL_SERIALIZATION_TYPE_HUGGINGFACE,
]
56
+
57
+
58
class ModelSerializerType:
    """Framework-independent serialization choices: cloudpickle or ONNX."""

    CLOUDPICKLE = MODEL_SERIALIZATION_TYPE_CLOUDPICKLE  # "cloudpickle"
    ONNX = MODEL_SERIALIZATION_TYPE_ONNX  # "onnx"
61
+
62
+
63
class PyTorchModelSerializerType:
    """Serialization choices for PyTorch models."""

    TORCH = MODEL_SERIALIZATION_TYPE_TORCH  # "torch"
    # The module constant's name is misspelled ("TORHC"); its value is "torchscript".
    TORCHSCRIPT = MODEL_SERIALIZATION_TYPE_TORHCSCRIPT
    ONNX = MODEL_SERIALIZATION_TYPE_TORCH_ONNX  # "torch_onnx"
67
+
68
+
69
class TensorflowModelSerializerType:
    """Serialization choices for TensorFlow models."""

    TENSORFLOW = MODEL_SERIALIZATION_TYPE_TF  # "tf"
    ONNX = MODEL_SERIALIZATION_TYPE_TF_ONNX  # "tf_onnx"
72
+
73
+
74
class LightGBMModelSerializerType:
    """Serialization choices for LightGBM models."""

    LIGHTGBM = MODEL_SERIALIZATION_TYPE_LIGHTGBM  # "lightgbm"
    ONNX = MODEL_SERIALIZATION_TYPE_LIGHTGBM_ONNX  # "lightgbm_onnx"
77
+
78
+
79
class SklearnModelSerializerType:
    """Serialization choices for scikit-learn models."""

    JOBLIB = MODEL_SERIALIZATION_TYPE_JOBLIB  # "joblib"
    CLOUDPICKLE = MODEL_SERIALIZATION_TYPE_CLOUDPICKLE  # "cloudpickle"
    ONNX = MODEL_SERIALIZATION_TYPE_SKLEARN_ONNX  # "sklearn_onnx"
83
+
84
+
85
class XgboostModelSerializerType:
    """Serialization choices for XGBoost models."""

    XGBOOST = MODEL_SERIALIZATION_TYPE_XGBOOST  # "xgboost"
    ONNX = MODEL_SERIALIZATION_TYPE_XGBOOST_ONNX  # "xgboost_onnx"
88
+
89
+
90
class SparkModelSerializerType:
    """Serialization choice for Spark models (only the native Spark format)."""

    SPARK = MODEL_SERIALIZATION_TYPE_SPARK  # "spark"
92
+
93
+
94
class HuggingFaceSerializerType:
    """Serialization choice for HuggingFace pipelines (only the native format)."""

    HUGGINGFACE = MODEL_SERIALIZATION_TYPE_HUGGINGFACE  # "huggingface"
96
+
97
+
98
class ModelSerializer(Serializer):
    """Base class for creation of new model serializers.

    Args:
        model_file_suffix: Extension for the saved artifact, written without a
            leading dot (e.g. ``"pkl"``, ``"onnx"``); may be an empty string.
    """

    def __init__(self, model_file_suffix):
        super(ModelSerializer, self).__init__()
        # Subclasses pick the default; callers read it to build file names.
        self.model_file_suffix = model_file_suffix
104
+
105
+
106
class ModelDeserializer(Deserializer):
    """Base class for creation of new model deserializers.

    Concrete subclasses are expected to override :meth:`deserialize`.
    """

    def deserialize(self, **kwargs):
        """Load a model; the base class provides no implementation.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()
111
+
112
+
113
class CloudPickleModelSerializer(ModelSerializer):
    """Uses `Cloudpickle` to save model."""

    def __init__(self, model_file_suffix="pkl"):
        super().__init__(model_file_suffix=model_file_suffix)

    def serialize(self, estimator, model_path, **kwargs):
        """Uses `cloudpickle.dump` to save model. See https://docs.python.org/3/library/pickle.html#pickle.dump for more details.

        Args:
            estimator: The model to be saved.
            model_path: The file object or path of the model in which it is to be stored.
                A path is opened in binary write mode and closed afterwards;
                an object with a ``write`` method is written to directly and
                left open for the caller to manage.
            kwargs:
                model_save: (dict, optional).
                    Keyword arguments passed through to `cloudpickle.dump`
                    (e.g. ``protocol``). ``None`` is treated as "no options".
        """
        # `or {}` guards against an explicit model_save=None, which would
        # otherwise fail on `**None`.
        cloudpickle_kwargs = kwargs.pop("model_save", None) or {}
        if hasattr(model_path, "write"):
            # Already-open binary file object, as documented above.
            cloudpickle.dump(estimator, model_path, **cloudpickle_kwargs)
            return
        with open(model_path, "wb") as f:
            cloudpickle.dump(estimator, f, **cloudpickle_kwargs)
132
+
133
+
134
class JobLibModelSerializer(ModelSerializer):
    """Uses `Joblib` to save model."""

    def __init__(self, model_file_suffix="joblib"):
        super().__init__(model_file_suffix)

    def serialize(self, estimator, model_path, **kwargs):
        """Persist ``estimator`` with `joblib.dump`. See https://joblib.readthedocs.io/en/latest/generated/joblib.dump.html for more details.

        Args:
            estimator: The model to be saved.
            model_path: The file object or path of the model in which it is to be stored.
            kwargs:
                model_save: (dict, optional).
                    Extra keyword arguments forwarded unchanged to `joblib.dump`.
        """
        dump(estimator, model_path, **kwargs.pop("model_save", {}))
152
+
153
+
154
class SparkModelSerializer(ModelSerializer):
    """Save Spark Model."""

    def __init__(self, model_file_suffix=""):
        super().__init__(model_file_suffix)

    def serialize(self, estimator, model_path, **kwargs):
        """Write ``estimator`` to ``model_path`` via its writer, overwriting existing output.

        Args:
            estimator: Model exposing ``write()`` (presumably a Spark ML model).
            model_path: Destination handed to the writer's ``save``.
            kwargs: Accepted for interface compatibility; ignored.
        """
        writer = estimator.write()
        writer.overwrite().save(model_path)
162
+
163
+
164
class PyTorchModelSerializer(ModelSerializer):
    """Save PyTorch Model using torch.save(). See https://pytorch.org/docs/stable/generated/torch.save.html for more details."""

    def __init__(self, model_file_suffix="pt"):
        super().__init__(model_file_suffix)

    @runtime_dependency(module="torch", install_from=OptionalDependency.PYTORCH)
    def serialize(self, estimator, model_path, **kwargs):
        """Save the module's ``state_dict()`` (weights only, not the module object).

        Args:
            estimator: Object exposing ``state_dict()`` (a ``torch.nn.Module``).
            model_path: Destination passed straight to ``torch.save``.
            kwargs: Accepted for interface compatibility; ignored.
        """
        weights = estimator.state_dict()
        torch.save(weights, model_path)
173
+
174
+
175
class TorchScriptModelSerializer(ModelSerializer):
    """Save PyTorch Model using torchscript. See https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format for more details."""

    def __init__(self, model_file_suffix="pt"):
        super().__init__(model_file_suffix)

    @runtime_dependency(module="torch", install_from=OptionalDependency.PYTORCH)
    def serialize(self, estimator, model_path, **kwargs):
        """Compile ``estimator`` with ``torch.jit.script`` and save the result.

        Args:
            estimator: The PyTorch module to script.
            model_path: Destination passed to ``torch.jit.save``.
            kwargs: Accepted for interface compatibility; ignored.
        """
        torch.jit.save(torch.jit.script(estimator), model_path)
185
+
186
+
187
class LightGBMModelSerializer(ModelSerializer):
    """Save LightGBM Model through save_model into txt."""

    def __init__(self, model_file_suffix="txt"):
        super().__init__(model_file_suffix)

    def serialize(self, estimator, model_path, **kwargs):
        """Delegate to the estimator's own ``save_model``.

        Args:
            estimator: LightGBM model exposing ``save_model``.
            model_path: Destination passed to ``save_model``.
            kwargs: Accepted for interface compatibility; ignored.
        """
        save = estimator.save_model
        save(model_path)
195
+
196
+
197
class XgboostJsonModelSerializer(ModelSerializer):
    """Save Xgboost Model through xgboost.save_model into JSON."""

    def __init__(self, model_file_suffix="json"):
        super().__init__(model_file_suffix)

    def serialize(self, estimator, model_path, **kwargs):
        """Save Xgboost Model through xgboost.save_model. See
        https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.Booster.save_model
        for more details.

        Args:
            estimator: The model to be saved.
            model_path: The file object or path of the model in which it is to be stored.
            kwargs: Accepted for interface compatibility; ignored.
        """
        save = estimator.save_model
        save(model_path)
213
+
214
+
215
class XgboostTxtModelSerializer(ModelSerializer):
    """Save Xgboost Model through xgboost.save_model into txt."""

    def __init__(self, model_file_suffix="txt"):
        super().__init__(model_file_suffix)

    def serialize(self, estimator, model_path, **kwargs):
        """Save Xgboost Model through xgboost.save_model. See
        https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.Booster.save_model
        for more details.

        Args:
            estimator: The model to be saved.
            model_path: The file object or path of the model in which it is to be stored.
            kwargs: Accepted for interface compatibility; ignored.
        """
        # NOTE(review): save_model is called with no explicit format argument;
        # confirm what the installed xgboost writes for a ``.txt`` path.
        save = estimator.save_model
        save(model_path)
231
+
232
+
233
class XgboostUbjModelSerializer(ModelSerializer):
    """Save Xgboost Model through xgboost.save_model into binary JSON."""

    def __init__(self, model_file_suffix="ubj"):
        super().__init__(model_file_suffix)

    def serialize(self, estimator, model_path, **kwargs):
        """Save Xgboost Model through xgboost.save_model. See
        https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.Booster.save_model
        for more details.

        Args:
            estimator: The model to be saved.
            model_path: The file object or path of the model in which it is to be stored.
            kwargs: Accepted for interface compatibility; ignored.
        """
        save = estimator.save_model
        save(model_path)
249
+
250
+
251
class TensorFlowModelSerializer(ModelSerializer):
    """Save Tensorflow Model."""

    def __init__(self, model_file_suffix="h5"):
        super().__init__(model_file_suffix)

    def serialize(self, estimator, model_path, **kwargs):
        """Delegate to the model's own ``save`` method.

        Args:
            estimator: TensorFlow/Keras model exposing ``save``.
            model_path: Destination passed to ``save`` (default suffix ``h5``).
            kwargs: Accepted for interface compatibility; ignored.
        """
        save = estimator.save
        save(model_path)
259
+
260
+
261
class HuggingFaceModelSerializer(ModelSerializer):
    """Save HuggingFace Pipeline."""

    def __init__(self, model_file_suffix=""):
        super().__init__(model_file_suffix)

    def serialize(self, estimator, model_path, **kwargs):
        """Save the pipeline into the ``model_path`` directory.

        After the pipeline is saved, the model config is re-saved with
        ``use_pretrained_backbone`` set to ``False``. Note that this flag is
        changed on the caller's in-memory ``estimator.model.config`` as well.

        Args:
            estimator: Pipeline exposing ``save_pretrained`` and ``model.config``.
            model_path: Target directory for ``save_pretrained``.
            kwargs: Accepted for interface compatibility; ignored.
        """
        estimator.save_pretrained(save_directory=model_path)
        model_config = estimator.model.config
        model_config.use_pretrained_backbone = False
        model_config.save_pretrained(save_directory=model_path)
271
+
272
+
273
class OnnxModelSerializer(ModelSerializer):
    """Base class for creation of onnx converter for each model framework.

    Subclasses implement ``_to_onnx``. ``serialize`` stores the estimator on
    ``self.estimator`` (where ``_to_onnx`` implementations read it), converts
    it and writes the serialized ONNX bytes.
    """

    def __init__(self, model_file_suffix="onnx"):
        super().__init__(model_file_suffix=model_file_suffix)

    def serialize(
        self,
        estimator,
        model_path,
        initial_types: Optional[List[Tuple]] = None,
        X_sample: Optional[
            Union[
                Dict,
                str,
                List,
                Tuple,
                np.ndarray,
                pd.core.series.Series,
                pd.core.frame.DataFrame,
            ]
        ] = None,
        **kwargs,
    ):
        """Save model into onnx format.

        Args:
            estimator: The model to be saved.
            model_path: The file object or path of the model in which it is to be stored.
                A path is opened in binary write mode and closed afterwards;
                an object with a ``write`` method receives the bytes directly
                and is left open for the caller.
            initial_types: (List[Tuple], optional)
                a python list. Each element is a tuple of a variable name and a data type.
            X_sample: (any, optional). Defaults to None.
                Contains model inputs such that model(X_sample) is a valid
                invocation of the model, used to valid model input type.
            kwargs: Forwarded to ``_to_onnx``.
        """
        self.estimator = estimator
        onx = self._to_onnx(
            initial_types=initial_types,
            X_sample=X_sample,
            **kwargs,
        )
        payload = onx.SerializeToString()
        if hasattr(model_path, "write"):
            # Already-open binary file object, as documented above.
            model_path.write(payload)
            return
        with open(model_path, "wb") as f:
            f.write(payload)

    def _to_onnx(
        self,
        initial_types: Optional[List[Tuple]] = None,
        X_sample: Optional[
            Union[
                Dict,
                str,
                List,
                Tuple,
                np.ndarray,
                pd.core.series.Series,
                pd.core.frame.DataFrame,
            ]
        ] = None,
        **kwargs,
    ):
        """Convert ``self.estimator`` to an ONNX model; subclasses must override.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError
334
+
335
+
336
+ class SklearnOnnxModelSerializer(OnnxModelSerializer):
337
+ """Converts Skearn Model into Onnx."""
338
+
339
+ def __init__(self):
340
+ super().__init__()
341
+
342
+ @runtime_dependency(module="onnx", install_from=OptionalDependency.ONNX)
343
+ @runtime_dependency(module="xgboost", install_from=OptionalDependency.BOOSTED)
344
+ @runtime_dependency(module="lightgbm", install_from=OptionalDependency.BOOSTED)
345
+ @runtime_dependency(module="skl2onnx", install_from=OptionalDependency.ONNX)
346
+ @runtime_dependency(module="onnxmltools", install_from=OptionalDependency.ONNX)
347
+ @runtime_dependency(
348
+ module="onnxmltools.convert.xgboost.operator_converters.XGBoost",
349
+ object="convert_xgboost",
350
+ install_from=OptionalDependency.ONNX,
351
+ )
352
+ @runtime_dependency(
353
+ module="onnxmltools.convert.lightgbm.operator_converters.LightGbm",
354
+ object="convert_lightgbm",
355
+ install_from=OptionalDependency.ONNX,
356
+ )
357
+ def _to_onnx(
358
+ self,
359
+ initial_types: List[Tuple] = None,
360
+ X_sample: Optional[
361
+ Union[
362
+ Dict,
363
+ str,
364
+ List,
365
+ Tuple,
366
+ np.ndarray,
367
+ pd.core.series.Series,
368
+ pd.core.frame.DataFrame,
369
+ ]
370
+ ] = None,
371
+ **kwargs,
372
+ ):
373
+ """
374
+ Produces an equivalent ONNX model of the given scikit-learn model.
375
+
376
+ Parameters
377
+ ----------
378
+ initial_types: (List[Tuple], optional). Defaults to None.
379
+ Each element is a tuple of a variable name and a type.
380
+ X_sample: Union[Dict, str, List, np.ndarray, pd.core.series.Series, pd.core.frame.DataFrame,]. Defaults to None.
381
+ Contains model inputs such that model(X_sample) is a valid invocation of the model.
382
+ Used to generate initial_types.
383
+
384
+ Returns
385
+ -------
386
+ onnx.onnx_ml_pb2.ModelProto
387
+ An ONNX model (type: ModelProto) which is equivalent to the input scikit-learn model.
388
+ """
389
+ auto_generated_initial_types = None
390
+ if not initial_types:
391
+ if X_sample is None:
392
+ raise ValueError(
393
+ " At least one of `X_sample` or `initial_types` must be provided."
394
+ )
395
+ auto_generated_initial_types = self._generate_initial_types(X_sample)
396
+ if str(type(self.estimator)).startswith("<class 'sklearn.pipeline"):
397
+ model_types = []
398
+ model_types = [type(val[1]) for val in self.estimator.steps]
399
+ if xgboost.sklearn.XGBClassifier in model_types:
400
+ skl2onnx.update_registered_converter(
401
+ xgboost.XGBClassifier,
402
+ "XGBoostXGBClassifier",
403
+ skl2onnx.common.shape_calculator.calculate_linear_classifier_output_shapes,
404
+ convert_xgboost,
405
+ options=kwargs.pop(
406
+ "options", {"nocl": [True, False], "zipmap": [True, False]}
407
+ ),
408
+ )
409
+
410
+ if xgboost.sklearn.XGBRegressor in model_types:
411
+ skl2onnx.update_registered_converter(
412
+ xgboost.XGBRegressor,
413
+ "XGBoostXGBRegressor",
414
+ skl2onnx.common.shape_calculator.calculate_linear_regressor_output_shapes,
415
+ convert_xgboost,
416
+ )
417
+
418
+ if lightgbm.sklearn.LGBMClassifier in model_types:
419
+ skl2onnx.update_registered_converter(
420
+ lightgbm.LGBMClassifier,
421
+ "LightGbmLGBMClassifier",
422
+ skl2onnx.common.shape_calculator.calculate_linear_classifier_output_shapes,
423
+ convert_lightgbm,
424
+ options=kwargs.pop(
425
+ "options",
426
+ {"nocl": [True, False], "zipmap": [True, False, "columns"]},
427
+ ),
428
+ )
429
+
430
+ if lightgbm.sklearn.LGBMRegressor in model_types:
431
+
432
+ def skl2onnx_convert_lightgbm(scope, operator, container):
433
+ options = scope.get_options(operator.raw_operator)
434
+ if "split" in options:
435
+ if StrictVersion(onnxmltools.__version__) < StrictVersion(
436
+ "1.9.2"
437
+ ):
438
+ logger.warnings(
439
+ "Option split was released in version 1.9.2 but %s is "
440
+ "installed. It will be ignored."
441
+ % onnxmltools.__version__
442
+ )
443
+ operator.split = options["split"]
444
+ else:
445
+ operator.split = None
446
+ convert_lightgbm(scope, operator, container)
447
+
448
+ skl2onnx.update_registered_converter(
449
+ lightgbm.LGBMRegressor,
450
+ "LightGbmLGBMRegressor",
451
+ skl2onnx.common.shape_calculator.calculate_linear_regressor_output_shapes,
452
+ skl2onnx_convert_lightgbm,
453
+ options=kwargs.pop("options", {"split": None}),
454
+ )
455
+ if initial_types:
456
+ return skl2onnx.convert_sklearn(
457
+ self.estimator, initial_types=initial_types, **kwargs
458
+ )
459
+ else:
460
+ try:
461
+ return skl2onnx.convert_sklearn(
462
+ self.estimator,
463
+ initial_types=auto_generated_initial_types,
464
+ target_opset=None,
465
+ **kwargs,
466
+ )
467
+ except Exception as e:
468
+ raise ValueError(
469
+ "`initial_types` can not be autodetected. Please directly pass `initial_types`."
470
+ )
471
+ else:
472
+ if initial_types:
473
+ return onnxmltools.convert_sklearn(
474
+ self.estimator,
475
+ initial_types=initial_types,
476
+ targeted_onnx=onnx.__version__,
477
+ **kwargs,
478
+ )
479
+ else:
480
+ try:
481
+ return onnxmltools.convert_sklearn(
482
+ self.estimator,
483
+ initial_types=auto_generated_initial_types,
484
+ targeted_onnx=onnx.__version__,
485
+ **kwargs,
486
+ )
487
+ except Exception as e:
488
+ raise ValueError(
489
+ "`initial_types` can not be detected. Please directly pass initial_types."
490
+ )
491
+
492
+ @runtime_dependency(module="skl2onnx", install_from=OptionalDependency.ONNX)
493
+ def _generate_initial_types(self, X_sample: Any) -> List:
494
+ """Auto generate intial types.
495
+
496
+ Parameters
497
+ ----------
498
+ X_sample: (Any)
499
+ Train data.
500
+
501
+ Returns
502
+ -------
503
+ List
504
+ Initial types.
505
+ """
506
+ if self._is_all_numerical_array_dataframe(X_sample):
507
+ # if it's a dataframe and all the columns are numerical. Or
508
+ # it's not a dataframe, also try this.
509
+ if hasattr(X_sample, "shape") and len(X_sample.shape) >= 2:
510
+ auto_generated_initial_types = [
511
+ (
512
+ "input",
513
+ skl2onnx.common.data_types.FloatTensorType(
514
+ [None, X_sample.shape[1]]
515
+ ),
516
+ )
517
+ ]
518
+ elif hasattr(self.estimator, "n_features_in_"):
519
+ n_cols = self.estimator.n_features_in_
520
+ auto_generated_initial_types = [
521
+ (
522
+ "input",
523
+ skl2onnx.common.data_types.FloatTensorType([None, n_cols]),
524
+ )
525
+ ]
526
+ else:
527
+ raise ValueError(
528
+ "`initial_types` can not be detected. Please directly pass initial_types."
529
+ )
530
+ elif self.is_either_numerical_or_string_dataframe(X_sample):
531
+ # for dataframe and not all the columns are numerical, then generate
532
+ # the input types of all the columns one by one.
533
+ auto_generated_initial_types = []
534
+
535
+ for i, col in X_sample.items():
536
+ if is_numeric_dtype(col.dtypes):
537
+ auto_generated_initial_types.append(
538
+ (
539
+ col.name,
540
+ skl2onnx.common.data_types.FloatTensorType([None, 1]),
541
+ )
542
+ )
543
+ else:
544
+ auto_generated_initial_types.append(
545
+ (
546
+ col.name,
547
+ skl2onnx.common.data_types.StringTensorType([None, 1]),
548
+ )
549
+ )
550
+ else:
551
+ try:
552
+ auto_generated_initial_types = (
553
+ skl2onnx.common.data_types.guess_data_type(
554
+ np.array(X_sample) if isinstance(X_sample, list) else X_sample
555
+ )
556
+ )
557
+ except:
558
+ auto_generated_initial_types = None
559
+ return auto_generated_initial_types
560
+
561
+ @staticmethod
562
+ def _is_all_numerical_array_dataframe(
563
+ data: Union[pd.DataFrame, np.ndarray]
564
+ ) -> bool:
565
+ """Check whether all the columns are numerical for numpy array and dataframe.
566
+ For data with any other data types, it will return False.
567
+
568
+ Parameters
569
+ ----------
570
+ data: Union[pd.DataFrame, np.ndarray]
571
+
572
+ Returns
573
+ -------
574
+ bool
575
+ Whether all the columns in a pandas dataframe or numpy array are all numerical.
576
+ """
577
+ return (
578
+ isinstance(data, pd.DataFrame)
579
+ and all([is_numeric_dtype(dtype) for dtype in data.dtypes])
580
+ or (isinstance(data, np.ndarray) and is_numeric_dtype(data.dtype))
581
+ )
582
+
583
+ @staticmethod
584
+ def is_either_numerical_or_string_dataframe(data: pd.DataFrame) -> bool:
585
+ """Check whether all the columns are either numerical or string for dataframe."""
586
+ return isinstance(data, pd.DataFrame) and all(
587
+ [
588
+ is_numeric_dtype(col.dtypes) or is_string_dtype(col.dtypes)
589
+ for _, col in data.items()
590
+ ]
591
+ )
592
+
593
+
594
class LightGBMOnnxModelSerializer(OnnxModelSerializer):
    """Converts LightGBM model into onnx format."""

    def __init__(self):
        super().__init__()

    @runtime_dependency(
        module="skl2onnx.common.data_types",
        object="FloatTensorType",
        install_from=OptionalDependency.ONNX,
    )
    @runtime_dependency(
        module="onnxmltools.convert",
        object="convert_lightgbm",
        install_from=OptionalDependency.ONNX,
    )
    def _to_onnx(
        self,
        initial_types: List[Tuple] = None,
        X_sample: Optional[
            Union[
                Dict,
                str,
                List,
                Tuple,
                np.ndarray,
                pd.core.series.Series,
                pd.core.frame.DataFrame,
            ]
        ] = None,
        **kwargs,
    ):
        """
        Produces an equivalent ONNX model of the given LightGBM model.

        Parameters
        ----------
        initial_types: (List[Tuple], optional). Defaults to None.
            Each element is a tuple of a variable name and a type. When not
            provided, they are generated from ``X_sample`` or the estimator.
        X_sample: Union[Dict, str, List, np.ndarray, pd.core.series.Series, pd.core.frame.DataFrame,]. Defaults to None.
            Contains model inputs such that model(X_sample) is a valid invocation of the model.
            Used to generate initial_types.
        kwargs:
            target_opset: (optional). Defaults to None.
                Forwarded to ``convert_lightgbm`` as ``target_opset``.
            Any other keyword arguments are forwarded to ``convert_lightgbm``.

        Returns
        -------
        An ONNX model (type: ModelProto) which is equivalent to the input LightGBM model.

        Raises
        ------
        ValueError
            If initial types cannot be generated, or if the conversion with
            auto-generated initial types fails (the original error is chained).
        """
        target_opset = kwargs.pop("target_opset", None)

        if initial_types:
            return convert_lightgbm(
                self.estimator,
                initial_types=initial_types,
                target_opset=target_opset,
                **kwargs,
            )

        auto_generated_initial_types = self._generate_initial_types(X_sample)
        try:
            return convert_lightgbm(
                self.estimator,
                initial_types=auto_generated_initial_types,
                target_opset=target_opset,
                **kwargs,
            )
        except Exception as e:
            raise ValueError(
                "`initial_types` can not be detected. Please directly pass initial_types."
            ) from e

    @runtime_dependency(
        module="skl2onnx.common.data_types",
        object="FloatTensorType",
        install_from=OptionalDependency.ONNX,
    )
    def _generate_initial_types(self, X_sample: Any) -> List:
        """Auto generate initial types.

        The feature count is taken from ``X_sample`` when it has at least two
        dimensions. Otherwise it comes from the estimator's ``num_feature()``
        or its ``n_features_in_``, in that order.

        Parameters
        ----------
        X_sample: (Any)
            Train data.

        Returns
        -------
        List
            A single ``("input", FloatTensorType([None, n_features]))`` entry.

        Raises
        ------
        ValueError
            If the number of features cannot be determined.
        """
        if (
            X_sample is not None
            and hasattr(X_sample, "shape")
            and len(X_sample.shape) >= 2
        ):
            n_cols = X_sample.shape[1]
        elif hasattr(self.estimator, "num_feature"):
            # Estimator exposing ``num_feature()`` (presumably a native LightGBM Booster).
            n_cols = self.estimator.num_feature()
        elif hasattr(self.estimator, "n_features_in_"):
            # scikit-learn style estimator.
            n_cols = self.estimator.n_features_in_
        else:
            raise ValueError(
                "`initial_types` can not be detected. Please directly pass initial_types."
            )
        return [("input", FloatTensorType([None, n_cols]))]
696
+
697
+
698
class XgboostOnnxModelSerializer(OnnxModelSerializer):
    """Converts Xgboost model into onnx format."""

    def __init__(self):
        super().__init__()

    @runtime_dependency(module="onnx", install_from=OptionalDependency.ONNX)
    @runtime_dependency(module="xgboost", install_from=OptionalDependency.BOOSTED)
    @runtime_dependency(
        module="skl2onnx",
        object="convert_sklearn",
        install_from=OptionalDependency.ONNX,
    )
    @runtime_dependency(
        module="skl2onnx",
        object="update_registered_converter",
        install_from=OptionalDependency.ONNX,
    )
    @runtime_dependency(
        module="skl2onnx.common.data_types",
        object="FloatTensorType",
        install_from=OptionalDependency.ONNX,
    )
    @runtime_dependency(
        module="skl2onnx.common.shape_calculator",
        object="calculate_linear_classifier_output_shapes",
        install_from=OptionalDependency.ONNX,
    )
    @runtime_dependency(
        module="skl2onnx.common.shape_calculator",
        object="calculate_linear_regressor_output_shapes",
        install_from=OptionalDependency.ONNX,
    )
    @runtime_dependency(module="onnxmltools", install_from=OptionalDependency.ONNX)
    @runtime_dependency(
        module="onnxmltools.convert.xgboost.operator_converters.XGBoost",
        object="convert_xgboost",
        install_from=OptionalDependency.ONNX,
    )
    def _to_onnx(
        self,
        initial_types: List[Tuple] = None,
        X_sample: Union[list, tuple, pd.DataFrame, pd.Series, np.ndarray] = None,
        **kwargs,
    ):
        """
        Produces an equivalent ONNX model of the given Xgboost model.

        The conversion path depends on the estimator class:

        * classes defined in ``xgboost.sklearn`` are converted with
          ``skl2onnx.convert_sklearn``. Before that, converters are registered
          for ``XGBClassifier`` / ``XGBRegressor``. All ``kwargs`` are
          forwarded unchanged.
        * any other estimator is converted with ``onnxmltools.convert_xgboost``.
          ``target_opset`` is popped from ``kwargs`` and passed explicitly
          together with ``targeted_onnx=onnx.__version__``.

        Parameters
        ----------
        initial_types: (List[Tuple], optional). Defaults to None.
            Each element is a tuple of a variable name and a type.
        X_sample: Union[Dict, str, List, np.ndarray, pd.core.series.Series, pd.core.frame.DataFrame,]. Defaults to None.
            Contains model inputs such that model(X_sample) is a valid invocation of the model.
            Used to generate initial_types.

        Returns
        -------
        onnx.onnx_ml_pb2.ModelProto
            An ONNX model (type: ModelProto) which is equivalent to the input xgboost model.

        Raises
        ------
        ValueError
            If initial types cannot be generated, or if the conversion with
            auto-generated initial types fails (the original error is chained).
        """
        auto_generated_initial_types = None
        if not initial_types:
            auto_generated_initial_types = self._generate_initial_types(X_sample)

        estimator_type = type(self.estimator)
        if str(estimator_type).startswith("<class 'xgboost.sklearn."):
            # scikit-learn API: make skl2onnx aware of the xgboost converters.
            if estimator_type is xgboost.sklearn.XGBClassifier:
                update_registered_converter(
                    xgboost.XGBClassifier,
                    "XGBoostXGBClassifier",
                    calculate_linear_classifier_output_shapes,
                    convert_xgboost,
                    options={"nocl": [True, False], "zipmap": [True, False]},
                )
            elif estimator_type is xgboost.sklearn.XGBRegressor:
                update_registered_converter(
                    xgboost.XGBRegressor,
                    "XGBoostXGBRegressor",
                    calculate_linear_regressor_output_shapes,
                    convert_xgboost,
                )

            if initial_types:
                return convert_sklearn(
                    self.estimator, initial_types=initial_types, **kwargs
                )
            try:
                return convert_sklearn(
                    self.estimator,
                    initial_types=auto_generated_initial_types,
                    **kwargs,
                )
            except Exception as e:
                raise ValueError(
                    "`initial_types` can not be autodetected. Please directly pass `initial_types`."
                ) from e

        # xgboost learning API
        target_opset = kwargs.pop("target_opset", None)
        if initial_types:
            return onnxmltools.convert_xgboost(
                self.estimator,
                initial_types=initial_types,
                target_opset=target_opset,
                targeted_onnx=onnx.__version__,
                **kwargs,
            )
        try:
            return onnxmltools.convert_xgboost(
                self.estimator,
                initial_types=auto_generated_initial_types,
                target_opset=target_opset,
                targeted_onnx=onnx.__version__,
                **kwargs,
            )
        except Exception as e:
            raise ValueError(
                "`initial_types` can not be autodetected. Please directly pass `initial_types`."
            ) from e

    @runtime_dependency(
        module="skl2onnx.common.data_types",
        object="FloatTensorType",
        install_from=OptionalDependency.ONNX,
    )
    def _generate_initial_types(self, X_sample: Any) -> List:
        """Auto generate initial types.

        The feature count is taken, in order of preference, from:

        1. the estimator's ``n_features_in_`` (scikit-learn API);
        2. the length of the estimator's non-empty ``feature_names``
           (xgboost learning API);
        3. ``X_sample.shape[1]`` when ``X_sample`` has at least two dimensions.

        Parameters
        ----------
        X_sample: (Any)
            Train data.

        Returns
        -------
        List
            A single ``("input", FloatTensorType([None, n_features]))`` entry.

        Raises
        ------
        ValueError
            If the estimator gives no feature count and ``X_sample`` is
            missing or has fewer than two dimensions.
        """
        if hasattr(self.estimator, "n_features_in_"):
            # sklearn api
            n_cols = self.estimator.n_features_in_
            return [("input", FloatTensorType([None, n_cols]))]
        if hasattr(self.estimator, "feature_names") and self.estimator.feature_names:
            # xgboost learning api
            n_cols = len(self.estimator.feature_names)
            return [("input", FloatTensorType([None, n_cols]))]

        if X_sample is None:
            raise ValueError(
                "At least one of `X_sample` or `initial_types` must be provided."
            )
        if hasattr(X_sample, "shape") and len(X_sample.shape) >= 2:
            return [("input", FloatTensorType([None, X_sample.shape[1]]))]
        raise ValueError(
            "`initial_types` can not be detected. Please directly pass initial_types."
        )
865
+
866
+
867
class PytorchOnnxModelSerializer(OnnxModelSerializer):
    """Converts Pytorch model into onnx format."""

    def __init__(self):
        super().__init__()

    @runtime_dependency(module="torch", install_from=OptionalDependency.PYTORCH)
    def serialize(
        self,
        estimator,
        model_path: str,
        X_sample: Optional[
            Union[
                Dict,
                str,
                List,
                Tuple,
                np.ndarray,
                pd.core.series.Series,
                pd.core.frame.DataFrame,
            ]
        ] = None,
        **kwargs,
    ):
        """
        Exports the given Pytorch model into ONNX format.

        Parameters
        ----------
        estimator:
            The Pytorch model, passed as the model argument of ``torch.onnx.export``.
        model_path: str
            Path to save the serialized model.
        X_sample: Union[list, tuple, pd.Series, np.ndarray, pd.DataFrame]. Defaults to None.
            A sample of input data. It is used as ``onnx_args`` when
            ``onnx_args`` is not given.
        kwargs:
            onnx_args: (tuple or torch.Tensor, optional). Defaults to None.
                Contains model inputs such that model(onnx_args) is a valid
                invocation of the model. Can be structured either as: 1) ONLY A
                TUPLE OF ARGUMENTS; 2) A TENSOR; 3) A TUPLE OF ARGUMENTS ENDING
                WITH A DICTIONARY OF NAMED ARGUMENTS
            input_names: (List[str], optional). Defaults to ["input"].
                Names to assign to the input nodes of the graph, in order.
            output_names: (List[str], optional). Defaults to ["output"].
                Names to assign to the output nodes of the graph, in order.
            dynamic_axes: (dict, optional). Defaults to None.
                Specify axes of tensors as dynamic (i.e. known only at run-time).

        Returns
        -------
        None
            Nothing

        Raises
        ------
        AssertionError
            if onnx module is not support by the current version of torch
        ValueError
            if neither `onnx_args` nor `X_sample` is provided
            if `model_path` is not provided
        """
        onnx_args = kwargs.get("onnx_args", None)
        input_names = kwargs.get("input_names", ["input"])
        output_names = kwargs.get("output_names", ["output"])
        dynamic_axes = kwargs.get("dynamic_axes", None)

        # An explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O``; the exception type is kept for existing callers.
        if not hasattr(torch, "onnx"):
            raise AssertionError(
                f"This version of pytorch {torch.__version__} does not appear to support onnx "
                "conversion."
            )

        if onnx_args is None:
            if X_sample is None:
                raise ValueError(
                    "`onnx_args` can not be detected. The parameter `onnx_args` must be provided to export pytorch model as onnx."
                )
            logger.warning(
                "Since `onnx_args` is not provided, `onnx_args` is "
                "detected from `X_sample` to export pytorch model as onnx."
            )
            onnx_args = X_sample

        if not model_path:
            raise ValueError(
                "The parameter `model_path` must be provided to save the model file."
            )

        torch.onnx.export(
            estimator,
            args=onnx_args,
            f=model_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )
961
+
962
+
963
class TensorFlowOnnxModelSerializer(OnnxModelSerializer):
    """Converts Tensorflow model into onnx format."""

    def __init__(self):
        super().__init__()

    @runtime_dependency(module="tf2onnx", install_from=OptionalDependency.ONNX)
    @runtime_dependency(
        module="tensorflow",
        short_name="tf",
        install_from=OptionalDependency.TENSORFLOW,
    )
    def serialize(
        self,
        estimator,
        model_path: str = None,
        X_sample: Optional[
            Union[
                Dict,
                str,
                List,
                Tuple,
                np.ndarray,
                pd.core.series.Series,
                pd.core.frame.DataFrame,
            ]
        ] = None,
        **kwargs,
    ):
        """
        Exports the given Tensorflow model into ONNX format.

        When ``input_signature`` is not supplied, it is detected from the
        model's ``input_shape`` / ``input`` attributes if present. Otherwise
        it is detected from ``X_sample``, with the first dimension left
        unconstrained.

        Parameters
        ----------
        model_path: str, default to None
            Path to save the serialized model.
        X_sample: Union[list, tuple, pd.Series, np.ndarray, pd.DataFrame]. Defaults to None.
            A sample of input data that will be used to generate input schema and detect input_signature.
        kwargs:
            opset_version: (optional). Defaults to None.
                Passed to ``tf2onnx.convert.from_keras`` as ``opset``.
            input_signature: (optional). Defaults to None.
                Passed to ``tf2onnx.convert.from_keras`` as ``input_signature``.

        Returns
        -------
        None
            Nothing

        Raises
        ------
        ValueError
            if model_path is not provided, if input_signature can not be
            detected, or if the conversion with a detected signature fails.
        """
        opset_version = kwargs.get("opset_version", None)
        input_signature = kwargs.get("input_signature", None)

        if not model_path:
            raise ValueError(
                "The parameter `model_path` must be provided to save the model file."
            )

        if input_signature is not None:
            # Caller-supplied signature: convert as-is and let errors propagate.
            tf2onnx.convert.from_keras(
                estimator,
                input_signature=input_signature,
                opset=opset_version,
                output_path=model_path,
            )
            return

        if hasattr(estimator, "input_shape"):
            if not isinstance(estimator.input, list):
                # single input
                detected_input_signature = (
                    tf.TensorSpec(
                        estimator.input_shape,
                        dtype=estimator.input.dtype,
                        name="input",
                    ),
                )
            else:
                # multiple input: one spec per model input
                detected_input_signature = [
                    tf.TensorSpec(estimator.input_shape[i], dtype=model_input.dtype)
                    for i, model_input in enumerate(estimator.input)
                ]
        elif X_sample is not None and hasattr(X_sample, "shape"):
            logger.warning(
                "Since `input_signature` is not provided, `input_signature` is "
                "detected from `X_sample` to export tensorflow model as "
                "onnx."
            )
            X_sample_shape = list(X_sample.shape)
            # Leave the first (presumably batch) dimension unconstrained.
            X_sample_shape[0] = None
            # A pandas DataFrame has ``shape`` but no single ``dtype``.
            sample_dtype = (
                X_sample.dtype
                if hasattr(X_sample, "dtype")
                else np.asarray(X_sample).dtype
            )
            detected_input_signature = (
                tf.TensorSpec(X_sample_shape, dtype=sample_dtype, name="input"),
            )
        else:
            raise ValueError(
                "The parameter `input_signature` must be provided to export "
                "tensorflow model as onnx."
            )

        try:
            tf2onnx.convert.from_keras(
                estimator,
                input_signature=detected_input_signature,
                opset=opset_version,
                output_path=model_path,
            )
        except Exception as e:
            raise ValueError(
                "`input_signature` can not be autodetected. The parameter `input_signature` must be provided to export "
                "tensorflow model as onnx."
            ) from e
1078
+
1079
+
1080
# Concrete SERDE classes: each combines a serializer with the shared
# ``ModelDeserializer`` and exposes its serialization-format key as ``name``.


class OnnxModelSaveSERDE(OnnxModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_ONNX`` format."""

    name = MODEL_SERIALIZATION_TYPE_ONNX


class CloudpickleModelSaveSERDE(CloudPickleModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_CLOUDPICKLE`` format."""

    name = MODEL_SERIALIZATION_TYPE_CLOUDPICKLE


class JoblibModelSaveSERDE(JobLibModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_JOBLIB`` format."""

    name = MODEL_SERIALIZATION_TYPE_JOBLIB


class SparkModelSaveSERDE(SparkModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_SPARK`` format."""

    name = MODEL_SERIALIZATION_TYPE_SPARK


class HuggingFacePipelineSaveSERDE(HuggingFaceModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_HUGGINGFACE`` format."""

    name = MODEL_SERIALIZATION_TYPE_HUGGINGFACE


class TorchScriptModelSaveSERDE(TorchScriptModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the TorchScript format.

    NOTE: the constant name ``MODEL_SERIALIZATION_TYPE_TORHCSCRIPT`` is spelled
    this way where it is defined; it is referenced as-is.
    """

    name = MODEL_SERIALIZATION_TYPE_TORHCSCRIPT


class PyTorchModelSaveSERDE(PyTorchModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_TORCH`` format."""

    name = MODEL_SERIALIZATION_TYPE_TORCH


class PyTorchOnnxModelSaveSERDE(PytorchOnnxModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_TORCH_ONNX`` format."""

    name = MODEL_SERIALIZATION_TYPE_TORCH_ONNX


class TensorFlowModelSaveSERDE(TensorFlowModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_TF`` format."""

    name = MODEL_SERIALIZATION_TYPE_TF


class TensorFlowOnnxModelSaveSERDE(TensorFlowOnnxModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_TF_ONNX`` format."""

    name = MODEL_SERIALIZATION_TYPE_TF_ONNX


class SklearnOnnxModelSaveSERDE(SklearnOnnxModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_SKLEARN_ONNX`` format."""

    name = MODEL_SERIALIZATION_TYPE_SKLEARN_ONNX


class LightGBMModelSaveSERDE(LightGBMModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_LIGHTGBM`` format."""

    name = MODEL_SERIALIZATION_TYPE_LIGHTGBM


class LightGBMOnnxModelSaveSERDE(LightGBMOnnxModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_LIGHTGBM_ONNX`` format."""

    name = MODEL_SERIALIZATION_TYPE_LIGHTGBM_ONNX


class XgboostJsonModelSaveSERDE(XgboostJsonModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_XGBOOST`` format."""

    name = MODEL_SERIALIZATION_TYPE_XGBOOST


class XgboostUbjModelSaveSERDE(XgboostUbjModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_XGBOOST_UBJ`` format."""

    name = MODEL_SERIALIZATION_TYPE_XGBOOST_UBJ


class XgboostTxtModelSaveSERDE(XgboostTxtModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_XGBOOST_TXT`` format."""

    name = MODEL_SERIALIZATION_TYPE_XGBOOST_TXT


class XgboostOnnxModelSaveSERDE(XgboostOnnxModelSerializer, ModelDeserializer):
    """Serializer/deserializer for the ``MODEL_SERIALIZATION_TYPE_XGBOOST_ONNX`` format."""

    name = MODEL_SERIALIZATION_TYPE_XGBOOST_ONNX
1146
+
1147
+
1148
class ModelSerializerFactory:
    """Model Serializer Factory.

    Maps each serialization-format key (a ``MODEL_SERIALIZATION_TYPE_*``
    constant) to the SERDE class that handles that format.

    Returns
    -------
    model_save_serde: Instance of `ads.model.SERDE`.
    """

    _factory = {
        MODEL_SERIALIZATION_TYPE_CLOUDPICKLE: CloudpickleModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_ONNX: OnnxModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_TORHCSCRIPT: TorchScriptModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_TORCH: PyTorchModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_TORCH_ONNX: PyTorchOnnxModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_TF: TensorFlowModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_TF_ONNX: TensorFlowOnnxModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_JOBLIB: JoblibModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_SKLEARN_ONNX: SklearnOnnxModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_LIGHTGBM: LightGBMModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_LIGHTGBM_ONNX: LightGBMOnnxModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_XGBOOST: XgboostJsonModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_XGBOOST_UBJ: XgboostUbjModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_XGBOOST_TXT: XgboostTxtModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_XGBOOST_ONNX: XgboostOnnxModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_SPARK: SparkModelSaveSERDE,
        MODEL_SERIALIZATION_TYPE_HUGGINGFACE: HuggingFacePipelineSaveSERDE,
    }

    @classmethod
    def get(cls, se: str):
        """Return a new SERDE instance for the serialization format ``se``.

        Parameters
        ----------
        se: str
            Serialization-format key, one of the ``MODEL_SERIALIZATION_TYPE_*`` constants.

        Returns
        -------
        A freshly constructed instance of the registered SERDE class.

        Raises
        ------
        ValueError
            If ``se`` is not a registered format.
        """
        serde = cls._factory.get(se, None)
        if not serde:
            raise ValueError(
                f"This {se} format is not supported. "
                f"Currently support the following format: {SUPPORTED_MODEL_SERIALIZERS}."
            )
        return serde()