genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
mlflow/openai/model.py ADDED
@@ -0,0 +1,824 @@
1
+ import importlib.metadata
2
+ import itertools
3
+ import logging
4
+ import os
5
+ import warnings
6
+ from functools import partial
7
+ from string import Formatter
8
+ from typing import Any, Optional, Union
9
+
10
+ import yaml
11
+ from packaging.version import Version
12
+
13
+ import mlflow
14
+ from mlflow import pyfunc
15
+ from mlflow.entities.model_registry.prompt import Prompt
16
+ from mlflow.environment_variables import MLFLOW_OPENAI_SECRET_SCOPE
17
+ from mlflow.exceptions import MlflowException
18
+ from mlflow.models import Model, ModelInputExample, ModelSignature
19
+ from mlflow.models.model import MLMODEL_FILE_NAME, _update_active_model_id_based_on_mlflow_model
20
+ from mlflow.models.signature import _infer_signature_from_input_example
21
+ from mlflow.models.utils import _save_example
22
+ from mlflow.openai.constant import FLAVOR_NAME
23
+ from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
24
+ from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
25
+ from mlflow.tracking.artifact_utils import _download_artifact_from_uri
26
+ from mlflow.types import ColSpec, Schema, TensorSpec
27
+ from mlflow.utils.annotations import experimental
28
+ from mlflow.utils.databricks_utils import (
29
+ check_databricks_secret_scope_access,
30
+ is_in_databricks_runtime,
31
+ )
32
+ from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
33
+ from mlflow.utils.environment import (
34
+ _CONDA_ENV_FILE_NAME,
35
+ _CONSTRAINTS_FILE_NAME,
36
+ _PYTHON_ENV_FILE_NAME,
37
+ _REQUIREMENTS_FILE_NAME,
38
+ _mlflow_conda_env,
39
+ _process_conda_env,
40
+ _process_pip_requirements,
41
+ _PythonEnv,
42
+ _validate_env_arguments,
43
+ )
44
+ from mlflow.utils.file_utils import write_to
45
+ from mlflow.utils.model_utils import (
46
+ _add_code_from_conf_to_system_path,
47
+ _get_flavor_configuration,
48
+ _validate_and_copy_code_paths,
49
+ _validate_and_prepare_target_save_path,
50
+ )
51
+ from mlflow.utils.openai_utils import (
52
+ _OAITokenHolder,
53
+ _OpenAIApiConfig,
54
+ _OpenAIEnvVar,
55
+ _validate_model_params,
56
+ )
57
+ from mlflow.utils.requirements_utils import _get_pinned_requirement
58
+
59
# Module-level logger for this flavor.
_logger = logging.getLogger(__name__)

# Name of the YAML file, inside the saved model directory, that stores the model
# name, the task, and any task-specific keyword arguments.
MODEL_FILENAME = "model.yaml"

# Tasks that get a ``python_function`` (pyfunc) flavor when saved; any other task
# is saved with the ``openai`` flavor only.
_PYFUNC_SUPPORTED_TASKS = (
    "chat.completions",
    "embeddings",
    "completions",
)
63
+
64
+
65
@experimental(version="2.3.0")
def get_default_pip_requirements():
    """
    Returns:
        A list of default pip requirements for MLflow Models produced by this flavor.
        Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
        that, at minimum, contains these requirements.
    """
    # Each package is pinned to the version installed in the current environment.
    packages = ("openai", "tiktoken", "tenacity")
    return [_get_pinned_requirement(package) for package in packages]
74
+
75
+
76
@experimental(version="2.3.0")
def get_default_conda_env():
    """
    Returns:
        The default Conda environment for MLflow Models produced by calls to
        :func:`save_model()` and :func:`log_model()`.
    """
    pip_deps = get_default_pip_requirements()
    return _mlflow_conda_env(additional_pip_deps=pip_deps)
84
+
85
+
86
def _get_obj_to_task_mapping():
    """Map OpenAI resource classes (and ``Images.edit``) to MLflow task names.

    Sync and async resource classes map to the same task name. The beta chat
    completion classes are added when ``openai.resources.beta.chat`` can be imported.
    """
    from openai import resources

    chat = resources.chat
    mapping = {
        resources.Audio: "audio",
        chat.Completions: "chat.completions",
        resources.Completions: "completions",
        resources.Images.edit: "images.edit",
        resources.Embeddings: "embeddings",
        resources.Files: "files",
        resources.Images: "images",
        resources.FineTuning: "fine_tuning",
        resources.Moderations: "moderations",
        resources.Models: "models",
        chat.AsyncCompletions: "chat.completions",
        resources.AsyncCompletions: "completions",
        resources.AsyncEmbeddings: "embeddings",
    }

    try:
        from openai.resources.beta.chat import completions as beta_chat
    except ImportError:
        # The installed openai version does not provide beta chat resources.
        return mapping

    mapping[beta_chat.AsyncCompletions] = "chat.completions"
    mapping[beta_chat.Completions] = "chat.completions"
    return mapping
117
+
118
+
119
def _get_model_name(model):
    """Return the model name for ``model``.

    Args:
        model: A model name string or, with ``openai<1``, an ``openai.Model`` instance.

    Returns:
        The model name string.

    Raises:
        MlflowException: If ``model`` is not one of the supported types.
    """
    import openai

    if isinstance(model, str):
        return model

    # ``openai.Model`` only exists before openai 1.0, so check the version first.
    legacy_openai = Version(_get_openai_package_version()).major < 1
    if legacy_openai and isinstance(model, openai.Model):
        return model.id

    raise mlflow.MlflowException(
        f"Unsupported model type: {type(model)}", error_code=INVALID_PARAMETER_VALUE
    )
131
+
132
+
133
def _get_task_name(task):
    """Resolve ``task`` to an MLflow task name such as ``"chat.completions"``.

    Args:
        task: Either a task name string (e.g. ``"embeddings"``) or an OpenAI resource
            object, resource class, or bound method (e.g. ``openai.chat.completions``).

    Returns:
        The task name string.

    Raises:
        MlflowException: If ``task`` is not a supported task name or object.
    """
    mapping = _get_obj_to_task_mapping()
    if isinstance(task, str):
        if task not in mapping.values():
            raise mlflow.MlflowException(
                f"Unsupported task: {task}", error_code=INVALID_PARAMETER_VALUE
            )
        return task

    try:
        task_name = (
            mapping.get(task)
            or mapping.get(task.__class__)
            # A bound method (e.g. ``Images.edit`` on an instance) is resolved via
            # its underlying function; plain objects have no ``__func__``, so fall
            # back to None instead of raising AttributeError.
            or mapping.get(getattr(task, "__func__", None))
        )
    except TypeError:
        # Unhashable objects (lists, dicts, ...) can never be mapping keys.
        task_name = None

    if task_name is None:
        raise mlflow.MlflowException(
            f"Unsupported task object: {task}", error_code=INVALID_PARAMETER_VALUE
        )
    return task_name
152
+
153
+
154
def _get_api_config() -> _OpenAIApiConfig:
    """Build the configuration of the OpenAI API to connect to.

    Values come from environment variables. ``api_type`` and ``api_version`` fall back
    to the ``openai`` module attributes, and the API base is read from
    ``OPENAI_API_BASE``, then ``OPENAI_BASE_URL``. Batch size and the token rate limit
    depend on whether the API type is an Azure variant.
    """
    import openai

    env = os.getenv
    api_type = env(_OpenAIEnvVar.OPENAI_API_TYPE.value, openai.api_type)
    api_version = env(_OpenAIEnvVar.OPENAI_API_VERSION.value, openai.api_version)
    api_base = env(_OpenAIEnvVar.OPENAI_API_BASE.value) or env(
        _OpenAIEnvVar.OPENAI_BASE_URL.value
    )
    deployment_id = env(_OpenAIEnvVar.OPENAI_DEPLOYMENT_NAME.value, None)
    organization = env(_OpenAIEnvVar.OPENAI_ORGANIZATION.value, None)

    is_azure = api_type in ("azure", "azure_ad", "azuread")
    # For non-Azure APIs the maximum batch size is 2048:
    # https://github.com/openai/openai-python/blob/b82a3f7e4c462a8a10fa445193301a3cefef9a4a/openai/embeddings_utils.py#L43
    # A smaller batch size is used to be safe.
    batch_size = 16 if is_azure else 1024
    max_tokens_per_minute = 60_000 if is_azure else 90_000

    return _OpenAIApiConfig(
        api_type=api_type,
        batch_size=batch_size,
        max_requests_per_minute=3_500,
        max_tokens_per_minute=max_tokens_per_minute,
        api_base=api_base,
        api_version=api_version,
        deployment_id=deployment_id,
        organization=organization,
    )
184
+
185
+
186
def _get_openai_package_version():
    """Return the version string of the installed ``openai`` distribution."""
    distribution = "openai"
    return importlib.metadata.version(distribution)
188
+
189
+
190
def _log_secrets_yaml(local_model_dir, scope):
    """Write ``openai.yaml`` into the model directory.

    The file maps each OpenAI environment variable name to ``"<scope>:<secret key>"``.

    Args:
        local_model_dir: Directory of the saved model.
        scope: Databricks secret scope name (taken from ``MLFLOW_OPENAI_SECRET_SCOPE``
            by ``save_model``).
    """
    secrets_path = os.path.join(local_model_dir, "openai.yaml")
    with open(secrets_path, "w") as f:
        secrets = {}
        for env_var in _OpenAIEnvVar:
            secrets[env_var.value] = f"{scope}:{env_var.secret_key}"
        yaml.safe_dump(secrets, f)
193
+
194
+
195
def _parse_format_fields(s) -> set[str]:
    """Parse the replacement-field names of a ``str.format``-style template.

    Example: ``"Hello {name}"`` -> ``{"name"}``. Escaped braces (``{{`` / ``}}``) are
    literal text and produce no field.

    Args:
        s: The template string.

    Returns:
        The set of field names. Positional fields are returned as written, e.g.
        ``"{}"`` yields ``""`` and ``"{0}"`` yields ``"0"``.
    """
    return {field for _, field, _, _ in Formatter().parse(s) if field is not None}
198
+
199
+
200
def _get_input_schema(task, content):
    """Derive the model input schema from the prompt template ``content``.

    A template with more than one format variable yields one named string column per
    variable. No template, or a template with at most one variable, yields a single
    unnamed string column.
    """
    variables = _ContentFormatter(task, content).variables if content else []
    if len(variables) > 1:
        return Schema([ColSpec(name=v, type="string") for v in variables])
    return Schema([ColSpec(type="string")])
212
+
213
+
214
@experimental(version="2.3.0")
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def save_model(
    model,
    task,
    path,
    conda_env=None,
    code_paths=None,
    mlflow_model=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    pip_requirements=None,
    extra_pip_requirements=None,
    metadata=None,
    **kwargs,
):
    """
    Save an OpenAI model to a path on the local file system.

    Args:
        model: The OpenAI model name.
        task: The task the model is performing, e.g., ``openai.chat.completions`` or
            ``'chat.completions'``.
        path: Local path where the model is to be saved.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
        signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
            describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
            The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
            from datasets with valid model input (e.g. the training dataset with target
            column omitted) and valid model output (e.g. model predictions generated on
            the training dataset), for example:

            .. code-block:: python

                from mlflow.models import infer_signature

                train = df.drop(columns=["target_label"])
                predictions = ...  # compute model predictions
                signature = infer_signature(train, predictions)
        input_example: {{ input_example }}
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}
        kwargs: Keyword arguments specific to the OpenAI task, such as the ``messages`` (see
            :ref:`mlflow.openai.messages` for more details on this parameter)
            or ``top_p`` value to use for chat completion.

    .. code-block:: python

        import mlflow
        import openai

        # Chat
        mlflow.openai.save_model(
            model="gpt-4o-mini",
            task=openai.chat.completions,
            messages=[{"role": "user", "content": "Tell me a joke."}],
            path="model",
        )

        # Completions
        mlflow.openai.save_model(
            model="text-davinci-002",
            task=openai.completions,
            prompt="{text}. The general sentiment of the text is",
            path="model",
        )

        # Embeddings
        mlflow.openai.save_model(
            model="text-embedding-ada-002",
            task=openai.embeddings,
            path="model",
        )
    """
    if Version(_get_openai_package_version()).major < 1:
        raise MlflowException("Only openai>=1.0 is supported.")

    import numpy as np

    _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)
    path = os.path.abspath(path)
    _validate_and_prepare_target_save_path(path)
    code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)
    task = _get_task_name(task)

    if mlflow_model is None:
        mlflow_model = Model()

    if signature is not None:
        if signature.params:
            # Validate the signature's default params against the task kwargs.
            _validate_model_params(
                task, kwargs, {p.name: p.default for p in signature.params.params}
            )
    elif task == "chat.completions":
        messages = kwargs.get("messages", [])
        if messages and not (
            all(isinstance(m, dict) for m in messages) and all(map(_is_valid_message, messages))
        ):
            raise mlflow.MlflowException.invalid_parameter_value(
                "If `messages` is provided, it must be a list of dictionaries with keys "
                "'role' and 'content'."
            )

        signature = ModelSignature(
            inputs=_get_input_schema(task, messages),
            outputs=Schema([ColSpec(type="string", name=None)]),
        )
    elif task == "completions":
        prompt = kwargs.get("prompt")
        signature = ModelSignature(
            inputs=_get_input_schema(task, prompt),
            outputs=Schema([ColSpec(type="string", name=None)]),
        )
    elif task == "embeddings":
        signature = ModelSignature(
            inputs=Schema([ColSpec(type="string", name=None)]),
            outputs=Schema([TensorSpec(type=np.dtype("float64"), shape=(-1,))]),
        )

    saved_example = _save_example(mlflow_model, input_example, path)
    # ``_OpenAIWrapper`` expects the same dict that is written to MODEL_FILENAME below
    # (it pops "task" and rejects tasks outside _PYFUNC_SUPPORTED_TASKS). Passing the
    # raw ``model`` argument, a name string, would always fail, so inference is limited
    # to pyfunc tasks and uses the model dict.
    if signature is None and saved_example is not None and task in _PYFUNC_SUPPORTED_TASKS:
        wrapped_model = _OpenAIWrapper({"model": _get_model_name(model), "task": task, **kwargs})
        signature = _infer_signature_from_input_example(saved_example, wrapped_model)

    if signature is not None:
        mlflow_model.signature = signature

    if metadata is not None:
        mlflow_model.metadata = metadata
    model_data_path = os.path.join(path, MODEL_FILENAME)
    model_dict = {
        "model": _get_model_name(model),
        "task": task,
        **kwargs,
    }
    with open(model_data_path, "w") as f:
        yaml.safe_dump(model_dict, f)

    if task in _PYFUNC_SUPPORTED_TASKS:
        pyfunc.add_to_model(
            mlflow_model,
            loader_module="mlflow.openai",
            data=MODEL_FILENAME,
            conda_env=_CONDA_ENV_FILE_NAME,
            python_env=_PYTHON_ENV_FILE_NAME,
            code=code_dir_subpath,
        )
    mlflow_model.add_flavor(
        FLAVOR_NAME,
        openai_version=_get_openai_package_version(),
        data=MODEL_FILENAME,
        code=code_dir_subpath,
    )
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    if is_in_databricks_runtime():
        if scope := MLFLOW_OPENAI_SECRET_SCOPE.get():
            url = "https://docs.databricks.com/en/machine-learning/model-serving/store-env-variable-model-serving.html"
            warnings.warn(
                "Specifying secrets for model serving with `MLFLOW_OPENAI_SECRET_SCOPE` is "
                f"deprecated. Use secrets-based environment variables ({url}) instead.",
                FutureWarning,
            )
            check_databricks_secret_scope_access(scope)
            _log_secrets_yaml(path, scope)

    if conda_env is None:
        if pip_requirements is None:
            default_reqs = get_default_pip_requirements()
            inferred_reqs = mlflow.models.infer_pip_requirements(
                path, FLAVOR_NAME, fallback=default_reqs
            )
            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
        else:
            default_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`
    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
411
+
412
+
413
@experimental(version="2.3.0")
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def log_model(
    model,
    task,
    artifact_path: Optional[str] = None,
    conda_env=None,
    code_paths=None,
    registered_model_name=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    metadata=None,
    prompts: Optional[list[Union[str, Prompt]]] = None,
    name: Optional[str] = None,
    params: Optional[dict[str, Any]] = None,
    tags: Optional[dict[str, Any]] = None,
    model_type: Optional[str] = None,
    step: int = 0,
    model_id: Optional[str] = None,
    **kwargs,
):
    """
    Log an OpenAI model as an MLflow artifact for the current run.

    Args:
        model: The OpenAI model name, e.g., ``"gpt-4o-mini"``. Only ``openai>=1.0`` is
            supported when saving, so ``openai.Model`` instances are not accepted.
        task: The task the model is performing, e.g., ``openai.chat.completions`` or
            ``'chat.completions'``.
        artifact_path: Deprecated. Use `name` instead.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        registered_model_name: If given, create a model version under
            ``registered_model_name``, also creating a registered model if one
            with the given name does not exist.
        signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
            describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
            The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
            from datasets with valid model input (e.g. the training dataset with target
            column omitted) and valid model output (e.g. model predictions generated on
            the training dataset), for example:

            .. code-block:: python

                from mlflow.models import infer_signature

                train = df.drop(columns=["target_label"])
                predictions = ...  # compute model predictions
                signature = infer_signature(train, predictions)

        input_example: {{ input_example }}
        await_registration_for: Number of seconds to wait for the model version to finish
            being created and is in ``READY`` status. By default, the function
            waits for five minutes. Specify 0 or None to skip waiting.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}
        prompts: {{ prompts }}
        name: {{ name }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}
        kwargs: Keyword arguments specific to the OpenAI task, such as the ``messages`` (see
            :ref:`mlflow.openai.messages` for more details on this parameter)
            or ``top_p`` value to use for chat completion.

    Returns:
        A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
        metadata of the logged model.

    .. code-block:: python
        :caption: Example

        import mlflow
        import openai
        import pandas as pd

        # Chat
        with mlflow.start_run():
            info = mlflow.openai.log_model(
                model="gpt-4o-mini",
                task=openai.chat.completions,
                messages=[{"role": "user", "content": "Tell me a joke about {animal}."}],
                name="model",
            )
            model = mlflow.pyfunc.load_model(info.model_uri)
            df = pd.DataFrame({"animal": ["cats", "dogs"]})
            print(model.predict(df))

        # Embeddings
        with mlflow.start_run():
            info = mlflow.openai.log_model(
                model="text-embedding-ada-002",
                task=openai.embeddings,
                name="embeddings",
            )
            model = mlflow.pyfunc.load_model(info.model_uri)
            print(model.predict(["hello", "world"]))
    """
    # All saving work is delegated to ``save_model`` through ``Model.log`` with this
    # module as the flavor.
    return Model.log(
        artifact_path=artifact_path,
        name=name,
        flavor=mlflow.openai,
        registered_model_name=registered_model_name,
        model=model,
        task=task,
        conda_env=conda_env,
        code_paths=code_paths,
        signature=signature,
        input_example=input_example,
        await_registration_for=await_registration_for,
        pip_requirements=pip_requirements,
        extra_pip_requirements=extra_pip_requirements,
        metadata=metadata,
        prompts=prompts,
        params=params,
        tags=tags,
        model_type=model_type,
        step=step,
        model_id=model_id,
        **kwargs,
    )
540
+
541
+
542
def _load_model(path):
    """Read the OpenAI model dict from the YAML data file at ``path``.

    If an ``MLmodel`` file sits next to ``path``, it is loaded so the active model ID
    can be updated from it.

    Args:
        path: Path to the model's YAML data file (``model.yaml`` as written by
            ``save_model``).

    Returns:
        The parsed YAML content. For files written by ``save_model`` this is the model
        name, the task, and the task kwargs.
    """
    mlmodel_path = os.path.join(os.path.dirname(path), MLMODEL_FILE_NAME)
    # Check for the MLmodel file itself: its directory can exist without it (e.g. a
    # model.yaml copied elsewhere), and loading would then fail.
    if os.path.isfile(mlmodel_path):
        mlflow_model = Model.load(mlmodel_path)
        _update_active_model_id_based_on_mlflow_model(mlflow_model)
    with open(path) as f:
        return yaml.safe_load(f)
549
+
550
+
551
def _is_valid_message(d):
    """Return True if ``d`` is a dict that has both a ``"content"`` and a ``"role"`` key."""
    if not isinstance(d, dict):
        return False
    return all(key in d for key in ("content", "role"))
553
+
554
+
555
class _ContentFormatter:
    """Fill ``str.format``-style placeholders in a prompt or chat-message template.

    For ``"completions"`` the template is a prompt string (default ``"{prompt}"``).
    For ``"chat.completions"`` it is a list of message dicts (default: one user message
    with content ``"{content}"``). If none of the messages contain a placeholder, a
    trailing user message ``"{content}"`` is appended.

    Attributes:
        template: The prompt string, or a shallow copy of the message list.
        format_fn: The bound formatting method for the task.
        variables: Sorted list of placeholder names found in the template.
    """

    def __init__(self, task, template=None):
        if task == "completions":
            self._init_completions(task, template)
        elif task == "chat.completions":
            self._init_chat(task, template)
        else:
            raise mlflow.MlflowException.invalid_parameter_value(
                f"Task type ``{task}`` is not supported for formatting."
            )

    def _init_completions(self, task, template):
        """Set up formatting of a single prompt string."""
        prompt = template or "{prompt}"
        if not isinstance(prompt, str):
            raise mlflow.MlflowException.invalid_parameter_value(
                f"Template for task {task} expects type `str`, but got {type(prompt)}."
            )
        self.template = prompt
        self.format_fn = self.format_prompt
        self.variables = sorted(_parse_format_fields(prompt))

    def _init_chat(self, task, template):
        """Set up formatting of a list of chat messages."""
        messages = template or [{"role": "user", "content": "{content}"}]
        if not all(_is_valid_message(message) for message in messages):
            raise mlflow.MlflowException.invalid_parameter_value(
                f"Template for task {task} expects type `dict` with keys 'content' "
                f"and 'role', but got {type(messages)}."
            )

        # Copy so that appending below does not modify the caller's list.
        self.template = messages.copy()
        self.format_fn = self.format_chat

        placeholders = set()
        for message in self.template:
            placeholders |= _parse_format_fields(message.get("content"))
            placeholders |= _parse_format_fields(message.get("role"))
        self.variables = sorted(placeholders)

        if not self.variables:
            # No placeholder anywhere: give the input its own trailing user message.
            self.template.append({"role": "user", "content": "{content}"})
            self.variables.append("content")

    def format(self, **params):
        """Render the template; every name in ``self.variables`` must be in ``params``."""
        missing = set(self.variables) - set(params)
        if missing:
            raise mlflow.MlflowException.invalid_parameter_value(
                f"Expected parameters {self.variables} to be provided, "
                f"only got {list(params)}, {list(missing)} are missing."
            )
        return self.format_fn(**params)

    def format_prompt(self, **params):
        """Fill the prompt string using only the template's own variables."""
        values = {name: params[name] for name in self.variables}
        return self.template.format(**values)

    def format_chat(self, **params):
        """Return new message dicts with ``role`` and ``content`` filled in.

        Keys other than ``role`` and ``content`` are not carried over.
        """
        values = {name: params[name] for name in self.variables}
        rendered = []
        for message in self.template:
            role = message.get("role").format(**values)
            content = message.get("content").format(**values)
            rendered.append({"role": role, "content": content})
        return rendered
615
+
616
+
617
def _first_string_column(pdf):
    """Return the name of the first column whose first-row value is a ``str``.

    Args:
        pdf: The input DataFrame (presumably pandas, given ``iloc``/``dtypes`` usage).

    Returns:
        The column name.

    Raises:
        MlflowException: If ``pdf`` has no rows, or no column holds a string in its
            first row.
    """
    # ``iloc[0]`` on an empty frame raises IndexError; report a clear error instead.
    if len(pdf) == 0:
        raise mlflow.MlflowException.invalid_parameter_value(
            f"Could not find a string column in the input data because it has no rows: "
            f"{pdf.dtypes.to_dict()}"
        )
    col = next((c for c, v in pdf.iloc[0].items() if isinstance(v, str)), None)
    if col is None:
        raise mlflow.MlflowException.invalid_parameter_value(
            f"Could not find a string column in the input data: {pdf.dtypes.to_dict()}"
        )
    return col
625
+
626
+
627
class _OpenAIWrapper:
    """pyfunc-compatible wrapper that sends OpenAI API requests for a saved model.

    Args:
        model: The model dict as stored in ``MODEL_FILENAME``: the model name under
            ``"model"``, the task under ``"task"``, plus any task kwargs such as
            ``"messages"`` or ``"prompt"``. The dict is copied, not modified.

    Raises:
        MlflowException: If the task is not in ``_PYFUNC_SUPPORTED_TASKS``.
    """

    def __init__(self, model):
        # Copy so that popping "task" does not mutate the caller's dict.
        model = dict(model)
        task = model.pop("task")
        if task not in _PYFUNC_SUPPORTED_TASKS:
            raise mlflow.MlflowException.invalid_parameter_value(
                f"Unsupported task: {task}. Supported tasks: {_PYFUNC_SUPPORTED_TASKS}."
            )
        self.model = model
        self.task = task
        self.api_config = _get_api_config()
        self.api_token = _OAITokenHolder(self.api_config.api_type)

        if self.task != "embeddings":
            self._setup_completions()

    def get_raw_model(self):
        """
        Returns the underlying model.
        """
        return self.model

    def _setup_completions(self):
        """Build the template formatter from the saved ``messages`` or ``prompt``."""
        if self.task == "chat.completions":
            self.template = self.model.get("messages", [])
        else:
            self.template = self.model.get("prompt")
        self.formatter = _ContentFormatter(self.task, self.template)

    def format_completions(self, params_list):
        """Render the template once per dict of template variables."""
        return [self.formatter.format(**params) for params in params_list]

    def get_params_list(self, data):
        """Turn input rows into dicts of template variables.

        With a single template variable, the column of that name is used if present,
        otherwise the first string column. With several variables, each must be a
        column of ``data``.
        """
        if len(self.formatter.variables) == 1:
            variable = self.formatter.variables[0]
            if variable in data.columns:
                return data[[variable]].to_dict(orient="records")
            else:
                first_string_column = _first_string_column(data)
                return [{variable: s} for s in data[first_string_column]]
        else:
            return data[self.formatter.variables].to_dict(orient="records")

    def get_client(self, max_retries: int, timeout: float):
        """Create an ``AzureOpenAI`` client for Azure API types, else an ``OpenAI`` client."""
        # with_option method should not be used before v1.3.8: https://github.com/openai/openai-python/issues/865
        if self.api_config.api_type in ("azure", "azure_ad", "azuread"):
            from openai import AzureOpenAI

            return AzureOpenAI(
                api_key=self.api_token.token,
                azure_endpoint=self.api_config.api_base,
                api_version=self.api_config.api_version,
                azure_deployment=self.api_config.deployment_id,
                max_retries=max_retries,
                timeout=timeout,
            )
        else:
            from openai import OpenAI

            return OpenAI(
                api_key=self.api_token.token,
                base_url=self.api_config.api_base,
                max_retries=max_retries,
                timeout=timeout,
            )

    def _predict_chat(self, data, params):
        """Send one chat completion request per input row; return the first choice texts."""
        from mlflow.openai.api_request_parallel_processor import process_api_requests

        _validate_model_params(self.task, self.model, params)
        max_retries = params.pop("max_retries", self.api_config.max_retries)
        timeout = params.pop("timeout", self.api_config.timeout)

        messages_list = self.format_completions(self.get_params_list(data))
        client = self.get_client(max_retries=max_retries, timeout=timeout)

        # NOTE(review): only ``model`` is forwarded from the saved model dict; other
        # saved task kwargs (e.g. ``top_p``) are not sent with the request — confirm
        # this is intended.
        requests = [
            partial(
                client.chat.completions.create,
                messages=messages,
                model=self.model["model"],
                **params,
            )
            for messages in messages_list
        ]

        results = process_api_requests(request_tasks=requests)

        return [r.choices[0].message.content for r in results]

    def _predict_completions(self, data, params):
        """Send prompts in batches of ``batch_size``; return the texts flattened in order."""
        from mlflow.openai.api_request_parallel_processor import process_api_requests

        _validate_model_params(self.task, self.model, params)
        prompts_list = self.format_completions(self.get_params_list(data))
        max_retries = params.pop("max_retries", self.api_config.max_retries)
        timeout = params.pop("timeout", self.api_config.timeout)
        batch_size = params.pop("batch_size", self.api_config.batch_size)
        _logger.debug("Requests are being batched by %s samples.", batch_size)

        client = self.get_client(max_retries=max_retries, timeout=timeout)

        requests = [
            partial(
                client.completions.create,
                prompt=prompts_list[i : i + batch_size],
                model=self.model["model"],
                **params,
            )
            for i in range(0, len(prompts_list), batch_size)
        ]

        results = process_api_requests(request_tasks=requests)

        return [row.text for batch in results for row in batch.choices]

    def _predict_embeddings(self, data, params):
        """Embed the first string column in batches; return the embeddings in order."""
        from mlflow.openai.api_request_parallel_processor import process_api_requests

        _validate_model_params(self.task, self.model, params)
        max_retries = params.pop("max_retries", self.api_config.max_retries)
        timeout = params.pop("timeout", self.api_config.timeout)
        batch_size = params.pop("batch_size", self.api_config.batch_size)
        _logger.debug("Requests are being batched by %s samples.", batch_size)

        first_string_column = _first_string_column(data)
        texts = data[first_string_column].tolist()

        client = self.get_client(max_retries=max_retries, timeout=timeout)

        requests = [
            partial(
                client.embeddings.create,
                input=texts[i : i + batch_size],
                model=self.model["model"],
                **params,
            )
            for i in range(0, len(texts), batch_size)
        ]

        results = process_api_requests(request_tasks=requests)

        return [row.embedding for batch in results for row in batch.data]

    def predict(self, data, params: Optional[dict[str, Any]] = None):
        """
        Args:
            data: Model input data.
            params: Additional parameters to pass to the model for inference. The dict
                is not modified. ``max_retries`` and ``timeout`` (and ``batch_size`` for
                completions and embeddings) are consumed from a copy. The remaining
                entries are passed to the OpenAI client call.

        Returns:
            Model predictions.
        """
        self.api_token.refresh()
        # The task-specific methods pop client options from ``params``; work on a copy
        # so the caller's dict can be reused across calls.
        params = dict(params or {})
        if self.task == "chat.completions":
            return self._predict_chat(data, params)
        elif self.task == "completions":
            return self._predict_completions(data, params)
        elif self.task == "embeddings":
            return self._predict_embeddings(data, params)
786
+
787
+
788
def _load_pyfunc(path):
    """Load the pyfunc implementation. Called by ``pyfunc.load_model``.

    Args:
        path: Path to the model's YAML data file (the pyfunc ``data`` entry, which
            ``save_model`` sets to ``model.yaml``).

    Returns:
        An ``_OpenAIWrapper`` around the loaded model dict.
    """
    model_dict = _load_model(path)
    return _OpenAIWrapper(model_dict)
795
+
796
+
797
@experimental(version="2.3.0")
def load_model(model_uri, dst_path=None):
    """
    Load an OpenAI model from a local file or a run.

    Args:
        model_uri: The location, in URI format, of the MLflow model. For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/tracking.html#
            artifact-locations>`_.
        dst_path: The local filesystem path to which to download the model artifact.
            This directory must already exist. If unspecified, a local output
            path will be created.

    Returns:
        A dictionary representing the OpenAI model.
    """
    local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
    flavor_conf = _get_flavor_configuration(local_model_path, FLAVOR_NAME)
    # Make any code logged with the model importable before reading it.
    _add_code_from_conf_to_system_path(local_model_path, flavor_conf)
    data_file = flavor_conf.get("data", MODEL_FILENAME)
    return _load_model(os.path.join(local_model_path, data_file))