genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
mlflow/pyfunc/model.py ADDED
@@ -0,0 +1,1473 @@
1
+ """
2
+ The ``mlflow.pyfunc.model`` module defines logic for saving and loading custom "python_function"
3
+ models with a user-defined ``PythonModel`` subclass.
4
+ """
5
+
6
+ import bz2
7
+ import gzip
8
+ import inspect
9
+ import logging
10
+ import lzma
11
+ import os
12
+ import shutil
13
+ from abc import ABCMeta, abstractmethod
14
+ from pathlib import Path
15
+ from typing import Any, Generator, Optional, Union
16
+
17
+ import cloudpickle
18
+ import pandas as pd
19
+ import yaml
20
+
21
+ import mlflow.pyfunc
22
+ import mlflow.utils
23
+ from mlflow.environment_variables import MLFLOW_LOG_MODEL_COMPRESSION
24
+ from mlflow.exceptions import MlflowException
25
+ from mlflow.models import Model
26
+ from mlflow.models.model import MLMODEL_FILE_NAME, MODEL_CODE_PATH
27
+ from mlflow.models.rag_signatures import ChatCompletionRequest, SplitChatMessagesRequest
28
+ from mlflow.models.signature import (
29
+ _extract_type_hints,
30
+ _is_context_in_predict_function_signature,
31
+ _TypeHints,
32
+ )
33
+ from mlflow.models.utils import _load_model_code_path
34
+ from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
35
+ from mlflow.pyfunc.utils import pyfunc
36
+ from mlflow.pyfunc.utils.data_validation import (
37
+ _check_func_signature,
38
+ _get_func_info_if_type_hint_supported,
39
+ _wrap_predict_with_pyfunc,
40
+ wrap_non_list_predict_pydantic,
41
+ )
42
+ from mlflow.pyfunc.utils.input_converter import _hydrate_dataclass
43
+ from mlflow.tracking.artifact_utils import _download_artifact_from_uri
44
+ from mlflow.types.agent import (
45
+ ChatAgentChunk,
46
+ ChatAgentMessage,
47
+ ChatAgentRequest,
48
+ ChatAgentResponse,
49
+ ChatContext,
50
+ )
51
+ from mlflow.types.llm import (
52
+ ChatCompletionChunk,
53
+ ChatCompletionResponse,
54
+ ChatMessage,
55
+ ChatParams,
56
+ )
57
+ from mlflow.types.utils import _is_list_dict_str, _is_list_str
58
+ from mlflow.utils.annotations import deprecated, experimental
59
+ from mlflow.utils.databricks_utils import (
60
+ _get_databricks_serverless_env_vars,
61
+ is_in_databricks_serverless_runtime,
62
+ )
63
+ from mlflow.utils.environment import (
64
+ _CONDA_ENV_FILE_NAME,
65
+ _CONSTRAINTS_FILE_NAME,
66
+ _PYTHON_ENV_FILE_NAME,
67
+ _REQUIREMENTS_FILE_NAME,
68
+ _mlflow_conda_env,
69
+ _process_conda_env,
70
+ _process_pip_requirements,
71
+ _PythonEnv,
72
+ )
73
+ from mlflow.utils.file_utils import TempDir, get_total_file_size, write_to
74
+ from mlflow.utils.model_utils import _get_flavor_configuration, _validate_infer_and_copy_code_paths
75
+ from mlflow.utils.pydantic_utils import IS_PYDANTIC_V2_OR_NEWER
76
+ from mlflow.utils.requirements_utils import _get_pinned_requirement
77
+
78
# Keys used in the pyfunc flavor configuration stored in the MLmodel file.
CONFIG_KEY_ARTIFACTS = "artifacts"
CONFIG_KEY_ARTIFACT_RELATIVE_PATH = "path"
CONFIG_KEY_ARTIFACT_URI = "uri"
CONFIG_KEY_PYTHON_MODEL = "python_model"
CONFIG_KEY_CLOUDPICKLE_VERSION = "cloudpickle_version"
CONFIG_KEY_COMPRESSION = "python_model_compression"
# Relative path (within the saved model directory) of the pickled PythonModel instance.
_SAVED_PYTHON_MODEL_SUBPATH = "python_model.pkl"
# Default `task` metadata values attached when logging chat-oriented models.
_DEFAULT_CHAT_MODEL_METADATA_TASK = "agent/v1/chat"
_DEFAULT_CHAT_AGENT_METADATA_TASK = "agent/v2/chat"
# Supported pickle-compression codecs: file extension and the stream-opening callable
# used in place of the builtin `open` when serializing/deserializing the model.
_COMPRESSION_INFO = {
    "lzma": {"ext": ".xz", "open": lzma.open},
    "bzip2": {"ext": ".bz2", "open": bz2.open},
    "gzip": {"ext": ".gz", "open": gzip.open},
}
_DEFAULT_RESPONSES_AGENT_METADATA_TASK = "agent/v1/responses"

_logger = logging.getLogger(__name__)
95
+
96
+
97
def get_default_pip_requirements():
    """
    Returns:
        A list of default pip requirements for MLflow Models produced by this flavor. Calls to
        :func:`save_model()` and :func:`log_model()` produce a pip environment that, at minimum,
        contains these requirements.
    """
    # cloudpickle is the flavor's only hard dependency: it serializes the PythonModel instance.
    pinned_cloudpickle = _get_pinned_requirement("cloudpickle")
    return [pinned_cloudpickle]
105
+
106
+
107
def get_default_conda_env():
    """
    Returns:
        The default Conda environment for MLflow Models produced by calls to
        :func:`save_model() <mlflow.pyfunc.save_model>` and
        :func:`log_model() <mlflow.pyfunc.log_model>` when a user-defined subclass of
        :class:`PythonModel` is provided.
    """
    # Delegate to the pip-requirements helper so the conda env stays in sync with it.
    pip_deps = get_default_pip_requirements()
    return _mlflow_conda_env(additional_pip_deps=pip_deps)
116
+
117
+
118
+ def _log_warning_if_params_not_in_predict_signature(logger, params):
119
+ if params:
120
+ logger.warning(
121
+ "The underlying model does not support passing additional parameters to the predict"
122
+ f" function. `params` {params} will be ignored."
123
+ )
124
+
125
+
126
class PythonModel:
    """
    Represents a generic Python model that evaluates inputs and produces API-compatible outputs.
    By subclassing :class:`~PythonModel`, users can create customized MLflow models with the
    "python_function" ("pyfunc") flavor, leveraging custom inference logic and artifact
    dependencies.
    """

    # NOTE(review): `__metaclass__` is Python-2 syntax and has no effect on Python 3, so this
    # class is not actually abstract at runtime and `@abstractmethod` below is inert without an
    # ABCMeta metaclass. Preserved as-is because making the class truly abstract would change
    # behavior for existing callers.
    __metaclass__ = ABCMeta

    def load_context(self, context):
        """
        Loads artifacts from the specified :class:`~PythonModelContext` that can be used by
        :func:`~PythonModel.predict` when evaluating inputs. When loading an MLflow model with
        :func:`~load_model`, this method is called as soon as the :class:`~PythonModel` is
        constructed.

        The same :class:`~PythonModelContext` will also be available during calls to
        :func:`~PythonModel.predict`, but it may be more efficient to override this method
        and load artifacts from the context at model load time.

        Args:
            context: A :class:`~PythonModelContext` instance containing artifacts that the model
                can use to perform inference.
        """

    @deprecated("predict_type_hints", "2.20.0")
    def _get_type_hints(self):
        # Deprecated alias retained for callers of the old private accessor.
        return self.predict_type_hints

    @property
    def predict_type_hints(self) -> _TypeHints:
        """
        Internal method to get type hints from the predict function signature.
        """
        # Cache the extracted hints on the instance; signature inspection is only needed once.
        if hasattr(self, "_predict_type_hints"):
            return self._predict_type_hints
        # The model-input argument index shifts by one when `predict` declares a `context`
        # parameter before it.
        if _is_context_in_predict_function_signature(func=self.predict):
            self._predict_type_hints = _extract_type_hints(self.predict, input_arg_index=1)
        else:
            self._predict_type_hints = _extract_type_hints(self.predict, input_arg_index=0)
        return self._predict_type_hints

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)

        # automatically wrap the predict method with pyfunc to ensure data validation
        # NB: skip wrapping for built-in classes defined in MLflow e.g. ChatModel
        if not cls.__module__.startswith("mlflow."):
            # TODO: ChatModel uses dataclass type hints which are not supported now, hence
            # we need to skip type hint based validation for user-defined subclasses
            # of ChatModel. Once we either (1) support dataclass type hints or (2) migrate
            # ChatModel to pydantic, we can remove this exclusion.
            # NB: issubclass(cls, ChatModel) does not work so we use a hacky attribute check
            if getattr(cls, "_skip_type_hint_validation", False):
                return

            # Only wrap a `predict` defined directly on this subclass (`cls.__dict__`), not one
            # inherited from a base class that was already processed.
            predict_attr = cls.__dict__.get("predict")
            if predict_attr is not None and callable(predict_attr):
                func_info = _get_func_info_if_type_hint_supported(predict_attr)
                setattr(cls, "predict", _wrap_predict_with_pyfunc(predict_attr, func_info))
            predict_stream_attr = cls.__dict__.get("predict_stream")
            if predict_stream_attr is not None and callable(predict_stream_attr):
                _check_func_signature(predict_stream_attr, "predict_stream")
        else:
            # MLflow-internal subclass: mark `predict` as already pyfunc-compatible so it is
            # not wrapped again downstream.
            cls.predict._is_pyfunc = True

    @abstractmethod
    def predict(self, context, model_input, params: Optional[dict[str, Any]] = None):
        """
        Evaluates a pyfunc-compatible input and produces a pyfunc-compatible output.
        For more information about the pyfunc input/output API, see the :ref:`pyfunc-inference-api`.

        Args:
            context: A :class:`~PythonModelContext` instance containing artifacts that the model
                can use to perform inference.
            model_input: A pyfunc-compatible input for the model to evaluate.
            params: Additional parameters to pass to the model for inference.

        .. tip::
            Since MLflow 2.20.0, `context` parameter can be removed from `predict` function
            signature if it's not used. `def predict(self, model_input, params=None)` is valid.
        """

    def predict_stream(self, context, model_input, params: Optional[dict[str, Any]] = None):
        """
        Evaluates a pyfunc-compatible input and produces an iterator of output.
        For more information about the pyfunc input API, see the :ref:`pyfunc-inference-api`.

        Args:
            context: A :class:`~PythonModelContext` instance containing artifacts that the model
                can use to perform inference.
            model_input: A pyfunc-compatible input for the model to evaluate.
            params: Additional parameters to pass to the model for inference.

        .. tip::
            Since MLflow 2.20.0, `context` parameter can be removed from `predict_stream` function
            signature if it's not used.
            `def predict_stream(self, model_input, params=None)` is valid.
        """
        raise NotImplementedError()
227
+
228
+
229
class _FunctionPythonModel(PythonModel):
    """
    Adapter wrapping a plain callable supplied as the ``python_model`` argument so it can be
    used wherever a :class:`~PythonModel` instance is expected.
    """

    def __init__(self, func, signature=None):
        self.signature = signature
        # Apply the @pyfunc decorator at most once; respect a pre-decorated callable.
        already_wrapped = getattr(func, "_is_pyfunc", False)
        self.func = func if already_wrapped else pyfunc(func)

    @property
    def predict_type_hints(self):
        # Lazily extract and cache type hints from the wrapped callable's signature.
        if not hasattr(self, "_predict_type_hints"):
            self._predict_type_hints = _extract_type_hints(self.func, input_arg_index=0)
        return self._predict_type_hints

    def predict(
        self,
        model_input,
        params: Optional[dict[str, Any]] = None,
    ):
        """
        Args:
            model_input: A pyfunc-compatible input for the model to evaluate.
            params: Additional parameters to pass to the model for inference.

        Returns:
            Model predictions.
        """
        # The wrapped callable accepts exactly one input argument for now; `params` is ignored.
        return self.func(model_input)
265
+
266
+
267
class PythonModelContext:
    """
    A collection of artifacts that a :class:`~PythonModel` can use when performing inference.
    Instances are created *implicitly* by the :func:`save_model() <mlflow.pyfunc.save_model>`
    and :func:`log_model() <mlflow.pyfunc.log_model>` persistence methods, from the contents
    specified by the ``artifacts`` parameter of those methods.
    """

    def __init__(self, artifacts, model_config):
        """
        Args:
            artifacts: A dictionary of ``<name, artifact_path>`` entries, where
                ``artifact_path`` is an absolute filesystem path to a given artifact.
            model_config: The model configuration to make available to the model at
                loading time.
        """
        self._artifacts = artifacts
        self._model_config = model_config

    @property
    def artifacts(self):
        """Dictionary of ``<name, artifact_path>`` entries; each path is absolute."""
        return self._artifacts

    @property
    def model_config(self):
        """Dictionary of ``<config, value>`` entries supplied when the model was logged."""
        return self._model_config
303
+
304
+
305
@deprecated("ResponsesAgent", "3.0.0")
class ChatModel(PythonModel, metaclass=ABCMeta):
    """
    .. tip::
        Since MLflow 3.0.0, we recommend using
        :py:class:`ResponsesAgent <mlflow.pyfunc.ResponsesAgent>` instead of
        :py:class:`ChatModel <mlflow.pyfunc.ChatModel>` unless you need strict compatibility
        with the OpenAI ChatCompletion API.

    A :class:`~PythonModel` specialization for implementing models compatible with popular
    LLM chat APIs. Subclasses implement a ``predict()`` method over chat messages rather than
    the generic pyfunc API. Input/output signatures and an input example are defined
    automatically, so they need not be supplied to
    :func:`mlflow.pyfunc.save_model() <mlflow.pyfunc.save_model>`.

    Compared with :class:`~PythonModel`:

    - **When to use**: choose ``ChatModel`` for a conversational model with the **standard**
      OpenAI-compatible chat schema; choose ``PythonModel`` for **full control** over the
      model's interface and behavior.
    - **Interface**: ``ChatModel`` is **fixed** to OpenAI's chat schema; ``PythonModel``
      leaves the input/output schema entirely to you.
    - **Setup**: ``ChatModel`` works out of the box with a pre-defined signature and input
      example; with ``PythonModel`` you define these yourself.
    - **Complexity**: ``ChatModel`` standardizes deployment and integration; a custom
      ``PythonModel`` must additionally handle the Pandas DataFrame conversion MLflow applies
      to inputs before they reach the model.
    """

    # ChatModel uses dataclass-based type hints, which the pyfunc type-hint validation
    # machinery does not support; opt out of PythonModel's automatic predict wrapping.
    _skip_type_hint_validation = True

    @abstractmethod
    def predict(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        """
        Evaluates a chat input and produces a chat output.

        Args:
            context: A :class:`~PythonModelContext` instance containing artifacts that the
                model can use to perform inference.
            messages (List[:py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`]):
                A list of :py:class:`ChatMessage <mlflow.types.llm.ChatMessage>` objects
                representing chat history.
            params (:py:class:`ChatParams <mlflow.types.llm.ChatParams>`):
                A :py:class:`ChatParams <mlflow.types.llm.ChatParams>` object containing
                parameters used to modify model behavior during inference.

        .. tip::
            Since MLflow 2.20.0, the `context` parameter may be removed from the `predict`
            signature if unused:
            `def predict(self, messages: list[ChatMessage], params: ChatParams)` is valid.

        Returns:
            A :py:class:`ChatCompletionResponse <mlflow.types.llm.ChatCompletionResponse>`
            object containing the model's response(s), as well as other metadata.
        """

    def predict_stream(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> Generator[ChatCompletionChunk, None, None]:
        """
        Streaming counterpart of :meth:`predict`; override to implement real streaming.

        Args:
            context: A :class:`~PythonModelContext` instance containing artifacts that the
                model can use to perform inference.
            messages (List[:py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`]):
                A list of :py:class:`ChatMessage <mlflow.types.llm.ChatMessage>` objects
                representing chat history.
            params (:py:class:`ChatParams <mlflow.types.llm.ChatParams>`):
                A :py:class:`ChatParams <mlflow.types.llm.ChatParams>` object containing
                parameters used to modify model behavior during inference.

        .. tip::
            Since MLflow 2.20.0, the `context` parameter may be removed from the
            `predict_stream` signature if unused:
            `def predict_stream(self, messages: list[ChatMessage], params: ChatParams)`
            is valid.

        Returns:
            A generator over :py:class:`ChatCompletionChunk <mlflow.types.llm.ChatCompletionChunk>`
            objects containing the model's response(s), as well as other metadata.
        """
        # Default behavior: streaming is opt-in and must be provided by the subclass.
        raise NotImplementedError(
            "Streaming implementation not provided. Please override the "
            "`predict_stream` method on your model to generate streaming "
            "predictions"
        )
413
+
414
+
415
class ChatAgent(PythonModel, metaclass=ABCMeta):
    """
    .. tip::
        Since MLflow 3.0.0, we recommend using
        :py:class:`ResponsesAgent <mlflow.pyfunc.ResponsesAgent>` instead of
        :py:class:`ChatAgent <mlflow.pyfunc.ChatAgent>`.

    **What is the ChatAgent Interface?**

    A framework-agnostic chat schema specification for authoring conversational agents.
    ChatAgent allows your agent to return multiple messages, return intermediate steps for
    tool-calling agents, confirm tool calls, and support multi-agent scenarios. Prefer
    ``ChatAgent`` over :py:class:`ChatModel <mlflow.pyfunc.ChatModel>` even for simple chat
    models (e.g. prompt-engineered LLMs), to retain flexibility for more agentic
    functionality later.

    The :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema is similar
    to, but not strictly compatible with, the OpenAI ChatCompletion schema. It diverges from
    :py:class:`ChatCompletionRequest <mlflow.types.llm.ChatCompletionRequest>` by adding:

    - an optional ``attachments`` attribute on every input/output message, so tools and
      internal agent calls can return additional outputs such as visualizations and progress
      indicators;
    - a ``context`` attribute (``conversation_id`` and ``user_id``) to let agent behavior
      depend on the querying user;
    - ``custom_inputs``, an arbitrary ``dict[str, Any]`` of additional inputs.

    The :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>` schema likewise
    adds a ``custom_outputs`` key (arbitrary ``dict[str, Any]``) and permits multiple output
    messages, improving display and evaluation of the internal tool calls and inter-agent
    communication that led to the final answer.

    **Authoring a ChatAgent**

    Subclass ``ChatAgent`` and implement ``predict`` (and optionally ``predict_stream``).
    Any agent-authoring framework may be used; the only hard requirement is the ``predict``
    interface:

    .. code-block:: python

        def predict(
            self,
            messages: list[ChatAgentMessage],
            context: Optional[ChatContext] = None,
            custom_inputs: Optional[dict[str, Any]] = None,
        ) -> ChatAgentResponse: ...

    Besides the typed arguments, both methods accept a single dict matching the
    :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema for ease of
    testing:

    .. code-block:: python

        chat_agent = MyChatAgent()
        chat_agent.predict(
            {
                "messages": [{"role": "user", "content": "What is 10 + 10?"}],
                "context": {"conversation_id": "123", "user_id": "456"},
            }
        )

    **Logging the ChatAgent**

    The `Models-from-Code <https://mlflow.org/docs/latest/model/models-from-code.html>`_
    approach is recommended. When logging a ChatAgent, MLflow automatically infers an
    input/output signature adhering to the ``ChatAgentRequest``/``ChatAgentResponse``
    schemas, appends ``{"task": "agent/v2/chat"}`` to the metadata, and supplies
    ``mlflow.types.agent.CHAT_AGENT_INPUT_EXAMPLE`` when no input example is given (any
    provided example must be a dict with the ``ChatAgentRequest`` schema). The loaded pyfunc
    model can then be queried with a single ``ChatAgentRequest``-shaped dictionary.

    **Migrating from ChatModel to ChatAgent**

    - Subclass ``ChatAgent`` instead of ``ChatModel``.
    - Move ``load_context`` logic into your ``ChatAgent``'s ``__init__``.
    - Use ``.model_dump_compat()`` instead of ``.to_dict()`` when converting inputs to
      dictionaries, e.g. ``[msg.model_dump_compat() for msg in messages]``.
    - Return a :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>` instead
      of a :py:class:`ChatCompletionResponse <mlflow.types.llm.ChatCompletionResponse>`.

    **ChatAgent Connectors**

    MLflow provides convenience APIs for wrapping agents written in popular authoring
    frameworks — see the LangGraph example in the
    :py:class:`ChatAgentState <mlflow.langchain.chat_agent_langgraph.ChatAgentState>`
    docstring, which also covers streaming.
    """

    # ChatAgent performs its own pydantic-based input handling in __init_subclass__ below,
    # so skip PythonModel's type-hint based predict wrapping.
    _skip_type_hint_validation = True

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        # Wrap both entry points so callers may pass a single ChatAgentRequest-shaped dict
        # in place of the unpacked (messages, context, custom_inputs) arguments.
        for method_name in ("predict", "predict_stream"):
            implementation = cls.__dict__.get(method_name)
            if not callable(implementation):
                continue
            wrapped = wrap_non_list_predict_pydantic(
                implementation,
                ChatAgentRequest,
                "Invalid dictionary input for a ChatAgent. Expected a dictionary with the "
                "ChatAgentRequest schema.",
                unpack=True,
            )
            setattr(cls, method_name, wrapped)

    def _convert_messages_to_dict(self, messages: list[ChatAgentMessage]):
        # Serialize ChatAgentMessage objects to plain dicts, dropping None-valued fields.
        return [message.model_dump_compat(exclude_none=True) for message in messages]

    # nb: `messages` (not `model_input`) keeps default traces compatible with mlflow
    # evaluate, and `custom_inputs` stays a top-level key for ease of use.
    @abstractmethod
    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """
        Given a ChatAgent input, returns a ChatAgent output. A single dict matching the
        :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema may also be
        passed in place of the typed arguments, for ease of testing.

        Args:
            messages (List[:py:class:`ChatAgentMessage <mlflow.types.agent.ChatAgentMessage>`]):
                A list of :py:class:`ChatAgentMessage <mlflow.types.agent.ChatAgentMessage>`
                objects representing the chat history.
            context (:py:class:`ChatContext <mlflow.types.agent.ChatContext>`):
                A :py:class:`ChatContext <mlflow.types.agent.ChatContext>` object containing
                conversation_id and user_id. **Optional** Defaults to None.
            custom_inputs (Dict[str, Any]):
                An optional param to provide arbitrary additional inputs to the model. The
                dictionary values must be JSON-serializable. **Optional** Defaults to None.

        Returns:
            A :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>` object
            containing the model's response, as well as other metadata.
        """

    # nb: `messages` (not `model_input`) keeps default traces compatible with mlflow
    # evaluate, and `custom_inputs` stays a top-level key for ease of use.
    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """
        Given a ChatAgent input, returns a generator of streaming ChatAgent output chunks.
        A single dict matching the
        :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema may also
        be passed in place of the typed arguments, for ease of testing.

        Override this method in your subclass to support streaming, keeping in mind:

        - Streamed messages must be of type
          :py:class:`ChatAgentChunk <mlflow.types.agent.ChatAgentChunk>`, each chunk holding
          partial output from a single response message.
        - At most one chunk per response may contain the ``custom_outputs`` key.
        - Chunks holding partial content of one response message must share the same ``id``;
          the consuming client aggregates their content and usage stats, e.g.:

        .. code-block:: python

            {"delta": {"role": "assistant", "content": "Born", "id": "123"}}
            {"delta": {"role": "assistant", "content": " in", "id": "123"}}
            {"delta": {"role": "assistant", "content": " data", "id": "123"}}

        Args:
            messages (List[:py:class:`ChatAgentMessage <mlflow.types.agent.ChatAgentMessage>`]):
                A list of :py:class:`ChatAgentMessage <mlflow.types.agent.ChatAgentMessage>`
                objects representing the chat history.
            context (:py:class:`ChatContext <mlflow.types.agent.ChatContext>`):
                A :py:class:`ChatContext <mlflow.types.agent.ChatContext>` object containing
                conversation_id and user_id. **Optional** Defaults to None.
            custom_inputs (Dict[str, Any]):
                An optional param to provide arbitrary additional inputs to the model. The
                dictionary values must be JSON-serializable. **Optional** Defaults to None.

        Returns:
            A generator over :py:class:`ChatAgentChunk <mlflow.types.agent.ChatAgentChunk>`
            objects containing the model's response(s), as well as other metadata.
        """
        # Default behavior: streaming is opt-in and must be provided by the subclass.
        raise NotImplementedError(
            "Streaming implementation not provided. Please override the "
            "`predict_stream` method on your model to generate streaming predictions"
        )
800
+
801
+
802
def _check_compression_supported(compression):
    """Return True if ``compression`` is a recognized compression method.

    Args:
        compression: Compression method name (e.g. from the
            ``MLFLOW_LOG_MODEL_COMPRESSION`` env var), or a falsy value for
            "no compression".

    Returns:
        True when ``compression`` is a key of ``_COMPRESSION_INFO``; False
        otherwise. A truthy-but-unrecognized value additionally emits a
        warning and callers fall back to uncompressed storage/loading.
    """
    if compression in _COMPRESSION_INFO:
        return True
    if compression:
        supported = ", ".join(sorted(_COMPRESSION_INFO))
        # Fix: the original two f-string fragments concatenated without a
        # separator, yielding "...'<method>'Please select...".
        mlflow.pyfunc._logger.warning(
            f"Unrecognized compression method '{compression}'. "
            f"Please select one of: {supported}. Falling back to uncompressed storage/loading."
        )
    return False
812
+
813
+
814
def _maybe_compress_cloudpickle_dump(python_model, path, compression):
    """Serialize ``python_model`` to ``path`` via cloudpickle, optionally compressed.

    The opener is looked up from ``_COMPRESSION_INFO``; an unknown (or falsy)
    ``compression`` falls back to the plain builtin ``open``.
    """
    opener = _COMPRESSION_INFO.get(compression, {}).get("open", open)
    with opener(path, "wb") as sink:
        cloudpickle.dump(python_model, sink)
818
+
819
+
820
def _maybe_decompress_cloudpickle_load(path, compression):
    """
    Genesis-Flow: Secure model loading with safety checks.

    Args:
        path: Path to the (possibly compressed) cloudpickled model file.
        compression: Compression method name, or a falsy value / ``"none"``
            for an uncompressed file. Unrecognized methods only trigger a
            warning (via ``_check_compression_supported``) and the file is
            opened uncompressed.

    Returns:
        The deserialized model, loaded through ``SecureModelLoader``.
    """
    from mlflow.utils.secure_loading import SecureModelLoader, SecurityError

    _check_compression_supported(compression)

    # For compressed files, we need to decompress first then load securely
    if compression and compression != "none":
        import shutil
        import tempfile

        file_open = _COMPRESSION_INFO.get(compression, {}).get("open", open)
        with file_open(path, "rb") as compressed_f:
            # delete=False so the temp file survives the `with` and can be
            # reopened by the secure loader (required on Windows).
            with tempfile.NamedTemporaryFile(delete=False) as temp_f:
                # Stream the decompressed bytes instead of materializing the
                # whole payload in memory at once — model pickles can be large.
                shutil.copyfileobj(compressed_f, temp_f)
                temp_path = temp_f.name
        try:
            return SecureModelLoader.safe_cloudpickle_load(temp_path)
        finally:
            os.unlink(temp_path)
    else:
        # Direct secure loading for uncompressed files
        return SecureModelLoader.safe_cloudpickle_load(path)
843
+
844
+
845
if IS_PYDANTIC_V2_OR_NEWER:
    # The Responses schemas below are pydantic-v2-only, so the whole class is
    # defined conditionally.
    from mlflow.types.responses import (
        ResponsesAgentRequest,
        ResponsesAgentResponse,
        ResponsesAgentStreamEvent,
    )

    @experimental(version="3.0.0")
    class ResponsesAgent(PythonModel, metaclass=ABCMeta):
        """
        A base class for creating ResponsesAgent models. It can be used as a wrapper around any
        agent framework to create an agent model that can be deployed to MLflow. Has a few helper
        methods to help create output items that can be a part of a ResponsesAgentResponse or
        ResponsesAgentStreamEvent.

        See https://www.mlflow.org/docs/latest/llms/responses-agent-intro/ for more details.
        """

        # Inputs are validated against ResponsesAgentRequest in __init_subclass__,
        # so the generic pyfunc type-hint validation is bypassed for this class.
        _skip_type_hint_validation = True

        def __init_subclass__(cls, **kwargs) -> None:
            """Wrap subclass ``predict``/``predict_stream`` to accept plain dicts.

            Each override is wrapped so that a dict input is validated/coerced
            into a ``ResponsesAgentRequest`` before the user implementation runs.
            """
            super().__init_subclass__(**kwargs)
            for attr_name in ("predict", "predict_stream"):
                # cls.__dict__ (not getattr) so only methods defined directly on
                # this subclass are wrapped — inherited ones were already wrapped
                # when their defining class was created.
                attr = cls.__dict__.get(attr_name)
                if callable(attr):
                    setattr(
                        cls,
                        attr_name,
                        wrap_non_list_predict_pydantic(
                            attr,
                            ResponsesAgentRequest,
                            "Invalid dictionary input for a ResponsesAgent. "
                            "Expected a dictionary with the ResponsesRequest schema.",
                        ),
                    )

        @abstractmethod
        def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
            """
            Given a ResponsesAgentRequest, returns a ResponsesAgentResponse.

            You can see example implementations at
            https://www.mlflow.org/docs/latest/llms/responses-agent-intro#simple-chat-example and
            https://www.mlflow.org/docs/latest/llms/responses-agent-intro#tool-calling-example.
            """

        def predict_stream(
            self, request: ResponsesAgentRequest
        ) -> Generator[ResponsesAgentStreamEvent, None, None]:
            """
            Given a ResponsesAgentRequest, returns a generator of ResponsesAgentStreamEvent objects.

            See more details at
            https://www.mlflow.org/docs/latest/llms/responses-agent-intro#streaming-agent-output.

            You can see example implementations at
            https://www.mlflow.org/docs/latest/llms/responses-agent-intro#simple-chat-example and
            https://www.mlflow.org/docs/latest/llms/responses-agent-intro#tool-calling-example.
            """
            # Streaming is opt-in; subclasses override this to support it.
            raise NotImplementedError(
                "Streaming implementation not provided. Please override the "
                "`predict_stream` method on your model to generate streaming predictions"
            )

        def create_text_delta(self, delta: str, item_id: str) -> dict[str, Any]:
            """Helper method to create a dictionary conforming to the text delta schema for
            streaming.

            Read more at https://www.mlflow.org/docs/latest/llms/responses-agent-intro/#streaming-agent-output.
            """
            return {
                "type": "response.output_text.delta",
                "item_id": item_id,
                "delta": delta,
            }

        def create_text_output_item(self, text: str, id: str) -> dict[str, Any]:
            """Helper method to create a dictionary conforming to the text output item schema.

            Read more at https://www.mlflow.org/docs/latest/llms/responses-agent-intro/#creating-agent-output.

            Args:
                text (str): The text to be outputted.
                id (str): The id of the output item.
            """
            return {
                "id": id,
                "content": [
                    {
                        "text": text,
                        "type": "output_text",
                    }
                ],
                "role": "assistant",
                "type": "message",
            }

        def create_function_call_item(
            self, id: str, call_id: str, name: str, arguments: str
        ) -> dict[str, Any]:
            """Helper method to create a dictionary conforming to the function call item schema.

            Read more at https://www.mlflow.org/docs/latest/llms/responses-agent-intro/#creating-agent-output.

            Args:
                id (str): The id of the output item.
                call_id (str): The id of the function call.
                name (str): The name of the function to be called.
                arguments (str): The arguments to be passed to the function.
            """
            return {
                "type": "function_call",
                "id": id,
                "call_id": call_id,
                "name": name,
                "arguments": arguments,
            }

        def create_function_call_output_item(self, call_id: str, output: str) -> dict[str, Any]:
            """Helper method to create a dictionary conforming to the function call output item
            schema.

            Read more at https://www.mlflow.org/docs/latest/llms/responses-agent-intro/#creating-agent-output.

            Args:
                call_id (str): The id of the function call.
                output (str): The output of the function call.
            """
            return {
                "type": "function_call_output",
                "call_id": call_id,
                "output": output,
            }
978
+
979
+
980
def _save_model_with_class_artifacts_params(  # noqa: D417
    path,
    python_model,
    signature=None,
    artifacts=None,
    conda_env=None,
    code_paths=None,
    mlflow_model=None,
    pip_requirements=None,
    extra_pip_requirements=None,
    model_config=None,
    streamable=None,
    model_code_path=None,
    infer_code_paths=False,
):
    """
    Args:
        path: The path to which to save the Python model.
        python_model: An instance of a subclass of :class:`~PythonModel`. ``python_model``
            defines how the model loads artifacts and how it performs inference.
        artifacts: A dictionary containing ``<name, artifact_uri>`` entries. Remote artifact URIs
            are resolved to absolute filesystem paths, producing a dictionary of
            ``<name, absolute_path>`` entries, (e.g. {"file": "absolute_path"}).
            ``python_model`` can reference these resolved entries as the ``artifacts`` property
            of the ``context`` attribute. If ``<artifact_name, 'hf:/repo_id'>``(e.g.
            {"bert-tiny-model": "hf:/prajjwal1/bert-tiny"}) is provided, then the model can be
            fetched from huggingface hub using repo_id `prajjwal1/bert-tiny` directly. If ``None``,
            no artifacts are added to the model.
        conda_env: Either a dictionary representation of a Conda environment or the path to a Conda
            environment yaml file. If provided, this decsribes the environment this model should be
            run in. At minimum, it should specify the dependencies contained in
            :func:`get_default_conda_env()`. If ``None``, the default
            :func:`get_default_conda_env()` environment is added to the model.
        code_paths: A list of local filesystem paths to Python file dependencies (or directories
            containing file dependencies). These files are *prepended* to the system path before the
            model is loaded.
        mlflow_model: The model to which to add the ``mlflow.pyfunc`` flavor.
        model_config: The model configuration for the flavor. Model configuration is available
            during model loading time.

            .. Note:: Experimental: This parameter may change or be removed in a future release
                without warning.

        model_code_path: The path to the code that is being logged as a PyFunc model. Can be used
            to load python_model when python_model is None.

            .. Note:: Experimental: This parameter may change or be removed in a future release
                without warning.

        streamable: A boolean value indicating if the model supports streaming prediction,
            If None, MLflow will try to inspect if the model supports streaming
            by checking if `predict_stream` method exists. Default None.
    """
    if mlflow_model is None:
        mlflow_model = Model()

    # Extra keys merged into the flavor config; always records the cloudpickle
    # version so loading can warn on version mismatch.
    custom_model_config_kwargs = {
        CONFIG_KEY_CLOUDPICKLE_VERSION: cloudpickle.__version__,
    }
    if callable(python_model):
        # A bare callable is wrapped so it can be pickled/loaded like a PythonModel.
        python_model = _FunctionPythonModel(func=python_model, signature=signature)

    saved_python_model_subpath = _SAVED_PYTHON_MODEL_SUBPATH

    compression = MLFLOW_LOG_MODEL_COMPRESSION.get()
    if compression:
        if _check_compression_supported(compression):
            custom_model_config_kwargs[CONFIG_KEY_COMPRESSION] = compression
            # e.g. ".lzma"/".gz" appended so the loader can find the file.
            saved_python_model_subpath += _COMPRESSION_INFO[compression]["ext"]
        else:
            # Unrecognized method: warning already emitted; store uncompressed.
            compression = None

    # If model_code_path is defined, we load the model into python_model, but we don't want to
    # pickle/save the python_model since the module won't be able to be imported.
    if not model_code_path:
        try:
            _maybe_compress_cloudpickle_dump(
                python_model, os.path.join(path, saved_python_model_subpath), compression
            )
        except Exception as e:
            raise MlflowException(
                "Failed to serialize Python model. Please save the model into a python file "
                "and use code-based logging method instead. See"
                "https://mlflow.org/docs/latest/models.html#models-from-code for more information."
            ) from e

        custom_model_config_kwargs[CONFIG_KEY_PYTHON_MODEL] = saved_python_model_subpath

    if artifacts:
        saved_artifacts_config = {}
        with TempDir() as tmp_artifacts_dir:
            saved_artifacts_dir_subpath = "artifacts"
            hf_prefix = "hf:/"
            for artifact_name, artifact_uri in artifacts.items():
                if artifact_uri.startswith(hf_prefix):
                    # "hf:/<repo_id>" URIs are fetched straight from the
                    # Hugging Face Hub into the model's artifacts directory.
                    try:
                        from huggingface_hub import snapshot_download
                    except ImportError as e:
                        raise MlflowException(
                            "Failed to import huggingface_hub. Please install huggingface_hub "
                            f"to log the model with artifact_uri {artifact_uri}. Error: {e}"
                        )

                    repo_id = artifact_uri[len(hf_prefix) :]
                    try:
                        snapshot_location = snapshot_download(
                            repo_id=repo_id,
                            local_dir=os.path.join(
                                path, saved_artifacts_dir_subpath, artifact_name
                            ),
                            local_dir_use_symlinks=False,
                        )
                    except Exception as e:
                        raise MlflowException.invalid_parameter_value(
                            "Failed to download snapshot from Hugging Face Hub with artifact_uri: "
                            f"{artifact_uri}. Error: {e}"
                        )
                    saved_artifact_subpath = (
                        Path(snapshot_location).relative_to(Path(os.path.realpath(path))).as_posix()
                    )
                else:
                    # All other URIs go through the generic artifact downloader
                    # into a temp dir that is moved under `path` afterwards.
                    tmp_artifact_path = _download_artifact_from_uri(
                        artifact_uri=artifact_uri, output_path=tmp_artifacts_dir.path()
                    )

                    relative_path = (
                        Path(tmp_artifact_path)
                        .relative_to(Path(tmp_artifacts_dir.path()))
                        .as_posix()
                    )

                    saved_artifact_subpath = os.path.join(
                        saved_artifacts_dir_subpath, relative_path
                    )

                saved_artifacts_config[artifact_name] = {
                    CONFIG_KEY_ARTIFACT_RELATIVE_PATH: saved_artifact_subpath,
                    CONFIG_KEY_ARTIFACT_URI: artifact_uri,
                }

            shutil.move(tmp_artifacts_dir.path(), os.path.join(path, saved_artifacts_dir_subpath))
        custom_model_config_kwargs[CONFIG_KEY_ARTIFACTS] = saved_artifacts_config

    if streamable is None:
        # Streamable iff the subclass overrode PythonModel.predict_stream.
        streamable = python_model.__class__.predict_stream != PythonModel.predict_stream

    if model_code_path:
        loader_module = mlflow.pyfunc.loaders.code_model.__name__
    elif python_model:
        loader_module = _get_pyfunc_loader_module(python_model)
    else:
        raise MlflowException(
            "Either `python_model` or `model_code_path` must be provided to save the model.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    mlflow.pyfunc.add_to_model(
        model=mlflow_model,
        loader_module=loader_module,
        code=None,
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
        model_config=model_config,
        streamable=streamable,
        model_code_path=model_code_path,
        **custom_model_config_kwargs,
    )
    if size := get_total_file_size(path):
        mlflow_model.model_size_bytes = size
    # `mlflow_model.save` must be called before _validate_infer_and_copy_code_paths as it
    # internally infers dependency, and MLmodel file is required to successfully load the model
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    saved_code_subpath = _validate_infer_and_copy_code_paths(
        code_paths,
        path,
        infer_code_paths,
        mlflow.pyfunc.FLAVOR_NAME,
    )
    mlflow_model.flavors[mlflow.pyfunc.FLAVOR_NAME][mlflow.pyfunc.CODE] = saved_code_subpath

    # `mlflow_model.code` is updated, re-generate `MLmodel` file.
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    if conda_env is None:
        if pip_requirements is None:
            default_reqs = get_default_pip_requirements()
            extra_env_vars = (
                _get_databricks_serverless_env_vars()
                if is_in_databricks_serverless_runtime()
                else None
            )
            # To ensure `_load_pyfunc` can successfully load the model during the dependency
            # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
            inferred_reqs = mlflow.models.infer_pip_requirements(
                path,
                mlflow.pyfunc.FLAVOR_NAME,
                fallback=default_reqs,
                extra_env_vars=extra_env_vars,
            )
            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
        else:
            default_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`
    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
1202
+
1203
+
1204
def _load_context_model_and_signature(
    model_path: str, model_config: Optional[dict[str, Any]] = None
):
    """Load a saved pyfunc model's context, python model, and signature.

    Reads the pyfunc flavor configuration under ``model_path`` and either
    (a) executes the logged model code (``MODEL_CODE_PATH`` case) or
    (b) deserializes the cloudpickled model, warning on cloudpickle version
    mismatches. Artifacts are resolved relative to ``model_path`` and
    ``load_context`` is invoked before returning.

    Returns:
        A ``(context, python_model, signature)`` tuple.
    """
    pyfunc_config = _get_flavor_configuration(
        model_path=model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
    )
    signature = mlflow.models.Model.load(model_path).signature

    if MODEL_CODE_PATH in pyfunc_config:
        # Models-from-code: re-execute the logged source file instead of unpickling.
        conf_model_code_path = pyfunc_config.get(MODEL_CODE_PATH)
        model_code_path = os.path.join(model_path, os.path.basename(conf_model_code_path))
        python_model = _load_model_code_path(model_code_path, model_config)

        if callable(python_model):
            python_model = _FunctionPythonModel(python_model, signature=signature)
    else:
        python_model_cloudpickle_version = pyfunc_config.get(CONFIG_KEY_CLOUDPICKLE_VERSION, None)
        if python_model_cloudpickle_version is None:
            mlflow.pyfunc._logger.warning(
                "The version of CloudPickle used to save the model could not be found in the "
                "MLmodel configuration"
            )
        elif python_model_cloudpickle_version != cloudpickle.__version__:
            # CloudPickle does not have a well-defined cross-version compatibility policy. Micro
            # version releases have been known to cause incompatibilities. Therefore, we match on
            # the full library version
            mlflow.pyfunc._logger.warning(
                "The version of CloudPickle that was used to save the model, `CloudPickle %s`, "
                "differs from the version of CloudPickle that is currently running, `CloudPickle "
                "%s`, and may be incompatible",
                python_model_cloudpickle_version,
                cloudpickle.__version__,
            )
        python_model_compression = pyfunc_config.get(CONFIG_KEY_COMPRESSION, None)

        python_model_subpath = pyfunc_config.get(CONFIG_KEY_PYTHON_MODEL, None)
        if python_model_subpath is None:
            raise MlflowException("Python model path was not specified in the model configuration")
        python_model = _maybe_decompress_cloudpickle_load(
            os.path.join(model_path, python_model_subpath), python_model_compression
        )

    # Resolve each logged artifact's relative path against the local model dir.
    artifacts = {}
    for saved_artifact_name, saved_artifact_info in pyfunc_config.get(
        CONFIG_KEY_ARTIFACTS, {}
    ).items():
        artifacts[saved_artifact_name] = os.path.join(
            model_path, saved_artifact_info[CONFIG_KEY_ARTIFACT_RELATIVE_PATH]
        )

    context = PythonModelContext(artifacts=artifacts, model_config=model_config)
    # Give the model a chance to lazily load heavyweight state from its artifacts.
    python_model.load_context(context=context)

    return context, python_model, signature
1258
+
1259
+
1260
def _load_pyfunc(model_path: str, model_config: Optional[dict[str, Any]] = None):
    """Load the saved model at ``model_path`` and wrap it for pyfunc inference."""
    loaded_context, loaded_model, loaded_signature = _load_context_model_and_signature(
        model_path, model_config
    )
    return _PythonModelPyfuncWrapper(
        python_model=loaded_model,
        context=loaded_context,
        signature=loaded_signature,
    )
1267
+
1268
+
1269
+ def _get_first_string_column(pdf):
1270
+ iter_string_columns = (col for col, val in pdf.iloc[0].items() if isinstance(val, str))
1271
+ return next(iter_string_columns, None)
1272
+
1273
+
1274
class _PythonModelPyfuncWrapper:
    """
    Wrapper class that creates a predict function such that
    predict(model_input: pd.DataFrame) -> model's output as pd.DataFrame (pandas DataFrame)
    """

    def __init__(self, python_model: PythonModel, context, signature):
        """
        Args:
            python_model: An instance of a subclass of :class:`~PythonModel`.
            context: A :class:`~PythonModelContext` instance containing artifacts that
                ``python_model`` may use when performing inference.
            signature: :class:`~ModelSignature` instance describing model input and output.
        """
        self.python_model = python_model
        self.context = context
        self.signature = signature

    def _convert_input(self, model_input):
        """Coerce a pandas DataFrame input into the shape the model's type hints expect.

        Non-DataFrame inputs pass through unchanged. Branch order matters:
        list[str] is checked before list[dict[str, str]], and the RAG dataclass
        hydration only applies to single-row inputs.
        """
        hints = self.python_model.predict_type_hints
        # we still need this for backwards compatibility
        if isinstance(model_input, pd.DataFrame):
            if _is_list_str(hints.input):
                # Model wants list[str]: feed it the first string column's values.
                first_string_column = _get_first_string_column(model_input)
                if first_string_column is None:
                    raise MlflowException.invalid_parameter_value(
                        "Expected model input to contain at least one string column"
                    )
                return model_input[first_string_column].tolist()
            elif _is_list_dict_str(hints.input):
                if (
                    len(self.signature.inputs) == 1
                    and next(iter(self.signature.inputs)).name is None
                ):
                    # Unnamed single-column signature: prefer the first string
                    # column as records; otherwise unwrap the lone column.
                    if first_string_column := _get_first_string_column(model_input):
                        return model_input[[first_string_column]].to_dict(orient="records")
                    if len(model_input.columns) == 1:
                        return model_input.to_dict("list")[0]
                return model_input.to_dict(orient="records")
            elif isinstance(hints.input, type) and (
                issubclass(hints.input, ChatCompletionRequest)
                or issubclass(hints.input, SplitChatMessagesRequest)
            ):
                # If the type hint is a RAG dataclass, we hydrate it
                # If there are multiple rows, we should throw
                if len(model_input) > 1:
                    raise MlflowException(
                        "Expected a single input for dataclass type hint, but got multiple rows"
                    )
                # Since single input is expected, we take the first row
                return _hydrate_dataclass(hints.input, model_input.iloc[0])
        return model_input

    def predict(self, model_input, params: Optional[dict[str, Any]] = None):
        """
        Args:
            model_input: Model input data as one of dict, str, bool, bytes, float, int, str type.
            params: Additional parameters to pass to the model for inference.

        Returns:
            Model predictions as an iterator of chunks. The chunks in the iterator must be type of
            dict or string. Chunk dict fields are determined by the model implementation.
        """
        # Only forward `params`/`context` when the model's predict signature accepts them.
        parameters = inspect.signature(self.python_model.predict).parameters
        kwargs = {}
        if "params" in parameters:
            kwargs["params"] = params
        else:
            _log_warning_if_params_not_in_predict_signature(_logger, params)
        if _is_context_in_predict_function_signature(parameters=parameters):
            return self.python_model.predict(
                self.context, self._convert_input(model_input), **kwargs
            )
        else:
            return self.python_model.predict(self._convert_input(model_input), **kwargs)

    def predict_stream(self, model_input, params: Optional[dict[str, Any]] = None):
        """
        Args:
            model_input: LLM Model single input.
            params: Additional parameters to pass to the model for inference.

        Returns:
            Streaming predictions.
        """
        # Mirrors `predict`: forward `params`/`context` only if accepted.
        parameters = inspect.signature(self.python_model.predict_stream).parameters
        kwargs = {}
        if "params" in parameters:
            kwargs["params"] = params
        else:
            _log_warning_if_params_not_in_predict_signature(_logger, params)
        if _is_context_in_predict_function_signature(parameters=parameters):
            return self.python_model.predict_stream(
                self.context, self._convert_input(model_input), **kwargs
            )
        else:
            return self.python_model.predict_stream(self._convert_input(model_input), **kwargs)
1371
+
1372
+
1373
def _get_pyfunc_loader_module(python_model):
    """Select the loader module name matching the concrete model flavor.

    Falls back to this module's own name for plain ``PythonModel`` instances.
    """
    if isinstance(python_model, ChatModel):
        return mlflow.pyfunc.loaders.chat_model.__name__
    if isinstance(python_model, ChatAgent):
        return mlflow.pyfunc.loaders.chat_agent.__name__
    # ResponsesAgent only exists under pydantic v2+, so guard the isinstance check.
    if IS_PYDANTIC_V2_OR_NEWER and isinstance(python_model, ResponsesAgent):
        return mlflow.pyfunc.loaders.responses_agent.__name__
    return __name__
1381
+
1382
+
1383
class ModelFromDeploymentEndpoint(PythonModel):
    """
    A PythonModel wrapper for invoking an MLflow Deployments endpoint.
    This class is particularly used for running evaluation against an MLflow Deployments endpoint.
    """

    def __init__(self, endpoint, params):
        # endpoint: name of the MLflow Deployments endpoint (used as "endpoints:/<name>").
        # params: extra parameters forwarded with every deployments API call.
        self.endpoint = endpoint
        self.params = params

    def predict(
        self, context, model_input: Union[pd.DataFrame, dict[str, Any], list[dict[str, Any]]]
    ):
        """
        Run prediction on the input data.

        Args:
            context: A :class:`~PythonModelContext` instance containing artifacts that the model
                can use to perform inference.
            model_input: The input data for prediction, either of the following:
                - Pandas DataFrame: If the default evaluator is used, input is a DF
                    that contains the multiple request payloads in a single column.
                - A dictionary: If the model_type is "databricks-agents" and the
                    Databricks RAG evaluator is used, this PythonModel can be invoked
                    with a single dict corresponding to the ChatCompletionsRequest schema.
                - A list of dictionaries: Currently we don't have any evaluator that
                    gives this input format, but we keep this for future use cases and
                    compatibility with normal pyfunc models.

        Return:
            The prediction result. The return type will be consistent with the model input type,
            e.g., if the input is a Pandas DataFrame, the return will be a Pandas Series.
        """
        if isinstance(model_input, dict):
            return self._predict_single(model_input)
        elif isinstance(model_input, list) and all(isinstance(data, dict) for data in model_input):
            return [self._predict_single(data) for data in model_input]
        elif isinstance(model_input, pd.DataFrame):
            # Evaluation DataFrames must carry the whole payload in one column.
            if len(model_input.columns) != 1:
                raise MlflowException(
                    f"The number of input columns must be 1, but got {model_input.columns}. "
                    "Multi-column input is not supported for evaluating an MLflow Deployments "
                    "endpoint. Please include the input text or payload in a single column.",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            input_column = model_input.columns[0]

            # One deployments call per row; results are returned positionally.
            predictions = [self._predict_single(data) for data in model_input[input_column]]
            return pd.Series(predictions)
        else:
            raise MlflowException(
                f"Invalid input data type: {type(model_input)}. The input data must be either "
                "a Pandas DataFrame, a dictionary, or a list of dictionaries containing the "
                "request payloads for evaluating an MLflow Deployments endpoint.",
                error_code=INVALID_PARAMETER_VALUE,
            )

    def _predict_single(self, data: Union[str, dict[str, Any]]) -> dict[str, Any]:
        """
        Send a single prediction request to the MLflow Deployments endpoint.

        Args:
            data: The single input data for prediction. If the input data is a string, we will
                construct the request payload from it. If the input data is a dictionary, we
                will directly use it as the request payload.

        Returns:
            The prediction result from the MLflow Deployments endpoint as a dictionary.
        """
        from mlflow.metrics.genai.model_utils import call_deployments_api, get_endpoint_type

        endpoint_type = get_endpoint_type(f"endpoints:/{self.endpoint}")

        if isinstance(data, str):
            # If the input payload is string, MLflow needs to construct the JSON
            # payload based on the endpoint type. If the endpoint type is not
            # set on the endpoint, we will default to chat format.
            endpoint_type = endpoint_type or "llm/v1/chat"
            prediction = call_deployments_api(self.endpoint, data, self.params, endpoint_type)
        elif isinstance(data, dict):
            # If the input is dictionary, we assume the input is already in the
            # compatible format for the endpoint.
            prediction = call_deployments_api(self.endpoint, data, self.params, endpoint_type)
        else:
            raise MlflowException(
                f"Invalid input data type: {type(data)}. The feature column of the evaluation "
                "dataset must contain only strings or dictionaries containing the request "
                "payload for evaluating an MLflow Deployments endpoint.",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return prediction