genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,919 @@
1
+ """
2
+ The ``mlflow.langchain`` module provides an API for logging and loading LangChain models.
3
+ This module exports multivariate LangChain models in the langchain flavor and univariate
4
+ LangChain models in the pyfunc flavor:
5
+
6
+ LangChain (native) format
7
+ This is the main flavor that can be accessed with LangChain APIs.
8
+ :py:mod:`mlflow.pyfunc`
9
+ Produced for use by generic pyfunc-based deployment tools and for batch inference.
10
+
11
+ .. _LangChain:
12
+ https://python.langchain.com/en/latest/index.html
13
+ """
14
+
15
+ import logging
16
+ import os
17
+ import tempfile
18
+ import warnings
19
+ from typing import Any, Iterator, Optional, Union
20
+
21
+ import cloudpickle
22
+ import pandas as pd
23
+ import yaml
24
+ from packaging.version import Version
25
+
26
+ import mlflow
27
+ from mlflow import pyfunc
28
+ from mlflow.entities.model_registry.prompt import Prompt
29
+ from mlflow.exceptions import MlflowException
30
+ from mlflow.langchain.constants import FLAVOR_NAME
31
+ from mlflow.langchain.databricks_dependencies import _detect_databricks_dependencies
32
+ from mlflow.langchain.runnables import _load_runnables, _save_runnables
33
+ from mlflow.langchain.utils.logging import (
34
+ _BASE_LOAD_KEY,
35
+ _MODEL_LOAD_KEY,
36
+ _RUNNABLE_LOAD_KEY,
37
+ _load_base_lcs,
38
+ _save_base_lcs,
39
+ _validate_and_prepare_lc_model_or_path,
40
+ lc_runnables_types,
41
+ patch_langchain_type_to_cls_dict,
42
+ register_pydantic_v1_serializer_cm,
43
+ )
44
+ from mlflow.models import Model, ModelInputExample, ModelSignature
45
+ from mlflow.models.dependencies_schemas import (
46
+ _clear_dependencies_schemas,
47
+ _get_dependencies_schema_from_model,
48
+ _get_dependencies_schemas,
49
+ )
50
+ from mlflow.models.model import (
51
+ MLMODEL_FILE_NAME,
52
+ MODEL_CODE_PATH,
53
+ MODEL_CONFIG,
54
+ _update_active_model_id_based_on_mlflow_model,
55
+ )
56
+ from mlflow.models.resources import DatabricksFunction, Resource, _ResourceBuilder
57
+ from mlflow.models.signature import _infer_signature_from_input_example
58
+ from mlflow.models.utils import (
59
+ _convert_llm_input_data,
60
+ _load_model_code_path,
61
+ _save_example,
62
+ )
63
+ from mlflow.pyfunc import FLAVOR_NAME as PYFUNC_FLAVOR_NAME
64
+ from mlflow.pyfunc.context import Context
65
+ from mlflow.tracing.provider import trace_disabled
66
+ from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
67
+ from mlflow.tracking.artifact_utils import _download_artifact_from_uri
68
+ from mlflow.types.schema import ColSpec, DataType, Schema
69
+ from mlflow.utils.annotations import experimental
70
+ from mlflow.utils.databricks_utils import (
71
+ _get_databricks_serverless_env_vars,
72
+ is_in_databricks_model_serving_environment,
73
+ is_in_databricks_serverless_runtime,
74
+ is_mlflow_tracing_enabled_in_model_serving,
75
+ )
76
+ from mlflow.utils.docstring_utils import (
77
+ LOG_MODEL_PARAM_DOCS,
78
+ docstring_version_compatibility_warning,
79
+ format_docstring,
80
+ )
81
+ from mlflow.utils.environment import (
82
+ _CONDA_ENV_FILE_NAME,
83
+ _CONSTRAINTS_FILE_NAME,
84
+ _PYTHON_ENV_FILE_NAME,
85
+ _REQUIREMENTS_FILE_NAME,
86
+ _mlflow_conda_env,
87
+ _process_conda_env,
88
+ _process_pip_requirements,
89
+ _PythonEnv,
90
+ _validate_env_arguments,
91
+ )
92
+ from mlflow.utils.file_utils import get_total_file_size, write_to
93
+ from mlflow.utils.model_utils import (
94
+ _add_code_from_conf_to_system_path,
95
+ _get_flavor_configuration,
96
+ _validate_and_copy_code_paths,
97
+ _validate_and_copy_file_to_directory,
98
+ _validate_and_get_model_config_from_file,
99
+ _validate_and_prepare_target_save_path,
100
+ )
101
+ from mlflow.utils.requirements_utils import _get_pinned_requirement
102
+
103
+ logger = logging.getLogger(mlflow.__name__)
104
+
105
+ _MODEL_TYPE_KEY = "model_type"
106
+
107
+
108
+ def get_default_pip_requirements():
109
+ """
110
+ Returns:
111
+ A list of default pip requirements for MLflow Models produced by this flavor.
112
+ Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
113
+ that, at a minimum, contains these requirements.
114
+ """
115
+ # pin pydantic and cloudpickle version as they are used in langchain
116
+ # model saving and loading
117
+ return list(map(_get_pinned_requirement, ["langchain", "pydantic", "cloudpickle"]))
118
+
119
+
120
+ def get_default_conda_env():
121
+ """
122
+ Returns:
123
+ The default Conda environment for MLflow Models produced by calls to
124
+ :func:`save_model()` and :func:`log_model()`.
125
+ """
126
+ return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())
127
+
128
+
129
+ @experimental(version="2.3.0")
130
+ @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
131
+ @docstring_version_compatibility_warning(FLAVOR_NAME)
132
+ @trace_disabled # Suppress traces for internal predict calls while saving model
133
+ def save_model(
134
+ lc_model,
135
+ path,
136
+ conda_env=None,
137
+ code_paths=None,
138
+ mlflow_model=None,
139
+ signature: ModelSignature = None,
140
+ input_example: ModelInputExample = None,
141
+ pip_requirements=None,
142
+ extra_pip_requirements=None,
143
+ metadata=None,
144
+ loader_fn=None,
145
+ persist_dir=None,
146
+ model_config=None,
147
+ streamable: Optional[bool] = None,
148
+ ):
149
+ """
150
+ Save a LangChain model to a path on the local file system.
151
+
152
+ Args:
153
+ lc_model: A LangChain model, which could be a
154
+ `Chain <https://python.langchain.com/docs/modules/chains/>`_,
155
+ `Agent <https://python.langchain.com/docs/modules/agents/>`_,
156
+ `retriever <https://python.langchain.com/docs/modules/data_connection/retrievers/>`_,
157
+ or `RunnableSequence <https://python.langchain.com/docs/modules/chains/foundational/sequential_chains#using-lcel>`_,
158
+ or a path containing the `LangChain model code <https://github.com/mlflow/mlflow/blob/master/examples/langchain/chain_as_code_driver.py>`_
159
+ for the above types. When using model as path, make sure to set the model
160
+ by using :func:`mlflow.models.set_model()`.
161
+
162
+ .. Note:: Experimental: Using model as path may change or be removed in a future
163
+ release without warning.
164
+ path: Local path where the serialized model (as YAML) is to be saved.
165
+ conda_env: {{ conda_env }}
166
+ code_paths: {{ code_paths }}
167
+ mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
168
+ signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
169
+ describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
170
+ If not specified, the model signature would be set according to
171
+ `lc_model.input_keys` and `lc_model.output_keys` as columns names, and
172
+ `DataType.string` as the column type.
173
+ Alternatively, you can explicitly specify the model signature.
174
+ The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
175
+ from datasets with valid model input (e.g. the training dataset with target
176
+ column omitted) and valid model output (e.g. model predictions generated on
177
+ the training dataset), for example:
178
+
179
+ .. code-block:: python
180
+
181
+ from mlflow.models import infer_signature
182
+
183
+ chain = LLMChain(llm=llm, prompt=prompt)
184
+ predictions = chain.run(input_str)
185
+ input_columns = [
186
+ {"type": "string", "name": input_key} for input_key in chain.input_keys
187
+ ]
188
+ signature = infer_signature(input_columns, predictions)
189
+
190
+ input_example: {{ input_example }}
191
+ pip_requirements: {{ pip_requirements }}
192
+ extra_pip_requirements: {{ extra_pip_requirements }}
193
+ metadata: {{ metadata }}
194
+ loader_fn: A function that's required for models containing objects that aren't natively
195
+ serialized by LangChain.
196
+ This function takes a string `persist_dir` as an argument and returns the
197
+ specific object that the model needs. Depending on the model,
198
+ this could be a retriever, vectorstore, requests_wrapper, embeddings, or
199
+ database. For RetrievalQA Chain and retriever models, the object is a
200
+ (`retriever <https://python.langchain.com/docs/modules/data_connection/retrievers/>`_).
201
+ For APIChain models, it's a
202
+ (`requests_wrapper <https://python.langchain.com/docs/modules/agents/tools/integrations/requests>`_).
203
+ For HypotheticalDocumentEmbedder models, it's an
204
+ (`embeddings <https://python.langchain.com/docs/modules/data_connection/text_embedding/>`_).
205
+ For SQLDatabaseChain models, it's a
206
+ (`database <https://python.langchain.com/docs/modules/agents/toolkits/sql_database>`_).
207
+ persist_dir: The directory where the object is stored. The `loader_fn`
208
+ takes this string as the argument to load the object.
209
+ This is optional for models containing objects that aren't natively
210
+ serialized by LangChain. MLflow logs the content in this directory as
211
+ artifacts in the subdirectory named `persist_dir_data`.
212
+
213
+ Here is the code snippet for logging a RetrievalQA chain with `loader_fn`
214
+ and `persist_dir`:
215
+
216
+ .. Note:: In langchain_community >= 0.0.27, loading pickled data requires providing the
217
+ ``allow_dangerous_deserialization`` argument.
218
+
219
+ .. code-block:: python
220
+
221
+ qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=db.as_retriever())
222
+
223
+
224
+ def load_retriever(persist_directory):
225
+ embeddings = OpenAIEmbeddings()
226
+ vectorstore = FAISS.load_local(
227
+ persist_directory,
228
+ embeddings,
229
+ # you may need to add the line below
230
+ # for langchain_community >= 0.0.27
231
+ allow_dangerous_deserialization=True,
232
+ )
233
+ return vectorstore.as_retriever()
234
+
235
+
236
+ with mlflow.start_run() as run:
237
+ logged_model = mlflow.langchain.log_model(
238
+ qa,
239
+ name="retrieval_qa",
240
+ loader_fn=load_retriever,
241
+ persist_dir=persist_dir,
242
+ )
243
+
244
+ See a complete example in examples/langchain/retrieval_qa_chain.py.
245
+ model_config: The model configuration to apply to the model if saving model from code. This
246
+ configuration is available during model loading.
247
+
248
+ .. Note:: Experimental: This parameter may change or be removed in a future
249
+ release without warning.
250
+ streamable: A boolean value indicating if the model supports streaming prediction. If
251
+ True, the model must implement `stream` method. If None, streamable is
252
+ set to True if the model implements `stream` method. Default to `None`.
253
+ """
254
+ with tempfile.TemporaryDirectory() as temp_dir:
255
+ import langchain
256
+ from langchain.schema import BaseRetriever
257
+
258
+ lc_model_or_path = _validate_and_prepare_lc_model_or_path(lc_model, loader_fn, temp_dir)
259
+
260
+ _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)
261
+
262
+ path = os.path.abspath(path)
263
+ _validate_and_prepare_target_save_path(path)
264
+
265
+ if isinstance(model_config, str):
266
+ model_config = _validate_and_get_model_config_from_file(model_config)
267
+
268
+ model_code_path = None
269
+ if isinstance(lc_model_or_path, str):
270
+ # The LangChain model is defined as Python code located in the file at the path
271
+ # specified by `lc_model`. Verify that the path exists and, if so, copy it to the
272
+ # model directory along with any other specified code modules
273
+ model_code_path = lc_model_or_path
274
+
275
+ lc_model = _load_model_code_path(model_code_path, model_config)
276
+ _validate_and_copy_file_to_directory(model_code_path, path, "code")
277
+ else:
278
+ lc_model = lc_model_or_path
279
+
280
+ code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)
281
+
282
+ if mlflow_model is None:
283
+ mlflow_model = Model()
284
+ saved_example = _save_example(mlflow_model, input_example, path)
285
+
286
+ if signature is None:
287
+ if saved_example is not None:
288
+ wrapped_model = _LangChainModelWrapper(lc_model)
289
+ signature = _infer_signature_from_input_example(saved_example, wrapped_model)
290
+ else:
291
+ if hasattr(lc_model, "input_keys"):
292
+ input_columns = [
293
+ ColSpec(type=DataType.string, name=input_key)
294
+ for input_key in lc_model.input_keys
295
+ ]
296
+ input_schema = Schema(input_columns)
297
+ else:
298
+ input_schema = None
299
+ if (
300
+ hasattr(lc_model, "output_keys")
301
+ and len(lc_model.output_keys) == 1
302
+ and not isinstance(lc_model, BaseRetriever)
303
+ ):
304
+ output_columns = [
305
+ ColSpec(type=DataType.string, name=output_key)
306
+ for output_key in lc_model.output_keys
307
+ ]
308
+ output_schema = Schema(output_columns)
309
+ else:
310
+ # TODO: empty output schema if multiple output_keys or is a retriever. fix later!
311
+ # https://databricks.atlassian.net/browse/ML-34706
312
+ output_schema = None
313
+
314
+ signature = (
315
+ ModelSignature(input_schema, output_schema)
316
+ if input_schema or output_schema
317
+ else None
318
+ )
319
+
320
+ if signature is not None:
321
+ mlflow_model.signature = signature
322
+ if metadata is not None:
323
+ mlflow_model.metadata = metadata
324
+
325
+ with _get_dependencies_schemas() as dependencies_schemas:
326
+ schema = dependencies_schemas.to_dict()
327
+ if schema is not None:
328
+ if mlflow_model.metadata is None:
329
+ mlflow_model.metadata = {}
330
+ mlflow_model.metadata.update(schema)
331
+
332
+ if streamable is None:
333
+ streamable = hasattr(lc_model, "stream")
334
+
335
+ model_data_kwargs = {}
336
+ flavor_conf = {}
337
+ if not isinstance(model_code_path, str):
338
+ model_data_kwargs = _save_model(lc_model, path, loader_fn, persist_dir)
339
+ flavor_conf = {
340
+ _MODEL_TYPE_KEY: lc_model.__class__.__name__,
341
+ **model_data_kwargs,
342
+ }
343
+
344
+ pyfunc.add_to_model(
345
+ mlflow_model,
346
+ loader_module="mlflow.langchain",
347
+ conda_env=_CONDA_ENV_FILE_NAME,
348
+ python_env=_PYTHON_ENV_FILE_NAME,
349
+ code=code_dir_subpath,
350
+ predict_stream_fn="predict_stream",
351
+ streamable=streamable,
352
+ model_code_path=model_code_path,
353
+ model_config=model_config,
354
+ **model_data_kwargs,
355
+ )
356
+
357
+ needs_databricks_auth = False
358
+ if Version(langchain.__version__) >= Version("0.0.311") and mlflow_model.resources is None:
359
+ if databricks_resources := _detect_databricks_dependencies(lc_model):
360
+ logger.info(
361
+ "Attempting to auto-detect Databricks resource dependencies for the "
362
+ "current langchain model. Dependency auto-detection is "
363
+ "best-effort and may not capture all dependencies of your langchain "
364
+ "model, resulting in authorization errors when serving or querying "
365
+ "your model. We recommend that you explicitly pass `resources` "
366
+ "to mlflow.langchain.log_model() to ensure authorization to "
367
+ "dependent resources succeeds when the model is deployed."
368
+ )
369
+ serialized_databricks_resources = _ResourceBuilder.from_resources(databricks_resources)
370
+ mlflow_model.resources = serialized_databricks_resources
371
+ needs_databricks_auth = any(
372
+ isinstance(r, DatabricksFunction) for r in databricks_resources
373
+ )
374
+
375
+ mlflow_model.add_flavor(
376
+ FLAVOR_NAME,
377
+ langchain_version=langchain.__version__,
378
+ code=code_dir_subpath,
379
+ streamable=streamable,
380
+ **flavor_conf,
381
+ )
382
+ if size := get_total_file_size(path):
383
+ mlflow_model.model_size_bytes = size
384
+ mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))
385
+
386
+ if conda_env is None:
387
+ if pip_requirements is None:
388
+ default_reqs = get_default_pip_requirements()
389
+ extra_env_vars = (
390
+ _get_databricks_serverless_env_vars()
391
+ if needs_databricks_auth and is_in_databricks_serverless_runtime()
392
+ else None
393
+ )
394
+ inferred_reqs = mlflow.models.infer_pip_requirements(
395
+ str(path), FLAVOR_NAME, fallback=default_reqs, extra_env_vars=extra_env_vars
396
+ )
397
+ default_reqs = sorted(set(inferred_reqs).union(default_reqs))
398
+ else:
399
+ default_reqs = None
400
+ conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
401
+ default_reqs, pip_requirements, extra_pip_requirements
402
+ )
403
+ else:
404
+ conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)
405
+
406
+ with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
407
+ yaml.safe_dump(conda_env, stream=f, default_flow_style=False)
408
+
409
+ if pip_constraints:
410
+ write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))
411
+
412
+ write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))
413
+
414
+ _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
415
+
416
+
417
+ @experimental(version="2.3.0")
418
+ @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
419
+ @docstring_version_compatibility_warning(FLAVOR_NAME)
420
+ @trace_disabled # Suppress traces for internal predict calls while logging model
421
+ def log_model(
422
+ lc_model,
423
+ artifact_path: Optional[str] = None,
424
+ conda_env=None,
425
+ code_paths=None,
426
+ registered_model_name=None,
427
+ signature: ModelSignature = None,
428
+ input_example: ModelInputExample = None,
429
+ await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
430
+ pip_requirements=None,
431
+ extra_pip_requirements=None,
432
+ metadata=None,
433
+ loader_fn=None,
434
+ persist_dir=None,
435
+ run_id=None,
436
+ model_config=None,
437
+ streamable=None,
438
+ resources: Optional[Union[list[Resource], str]] = None,
439
+ prompts: Optional[list[Union[str, Prompt]]] = None,
440
+ name: Optional[str] = None,
441
+ params: Optional[dict[str, Any]] = None,
442
+ tags: Optional[dict[str, Any]] = None,
443
+ model_type: Optional[str] = None,
444
+ step: int = 0,
445
+ model_id: Optional[str] = None,
446
+ ):
447
+ """
448
+ Log a LangChain model as an MLflow artifact for the current run.
449
+
450
+ Args:
451
+ lc_model: A LangChain model, which could be a
452
+ `Chain <https://python.langchain.com/docs/modules/chains/>`_,
453
+ `Agent <https://python.langchain.com/docs/modules/agents/>`_, or
454
+ `retriever <https://python.langchain.com/docs/modules/data_connection/retrievers/>`_
455
+ or a path containing the `LangChain model code <https://github.com/mlflow/mlflow/blob/master/examples/langchain/chain_as_code_driver.py>`_
456
+ for the above types. When using model as path, make sure to set the model
457
+ by using :func:`mlflow.models.set_model()`.
458
+
459
+ .. Note:: Experimental: Using model as path may change or be removed in a future
460
+ release without warning.
461
+ artifact_path: Deprecated. Use `name` instead.
462
+ conda_env: {{ conda_env }}
463
+ code_paths: {{ code_paths }}
464
+ registered_model_name: If given, create a model
465
+ version under ``registered_model_name``, also creating a
466
+ registered model if one with the given name does not exist.
467
+ signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
468
+ describes model input and output
469
+ :py:class:`Schema <mlflow.types.Schema>`.
470
+ If not specified, the model signature would be set according to
471
+ `lc_model.input_keys` and `lc_model.output_keys` as columns names, and
472
+ `DataType.string` as the column type.
473
+ Alternatively, you can explicitly specify the model signature.
474
+ The model signature can be :py:func:`inferred
475
+ <mlflow.models.infer_signature>` from datasets with valid model input
476
+ (e.g. the training dataset with target column omitted) and valid model
477
+ output (e.g. model predictions generated on the training dataset),
478
+ for example:
479
+
480
+ .. code-block:: python
481
+
482
+ from mlflow.models import infer_signature
483
+
484
+ chain = LLMChain(llm=llm, prompt=prompt)
485
+ predictions = chain.run(input_str)
486
+ input_columns = [
487
+ {"type": "string", "name": input_key} for input_key in chain.input_keys
488
+ ]
489
+ signature = infer_signature(input_columns, predictions)
490
+
491
+ input_example: {{ input_example }}
492
+ await_registration_for: Number of seconds to wait for the model version
493
+ to finish being created and reach ``READY`` status.
494
+ By default, the function waits for five minutes.
495
+ Specify 0 or None to skip waiting.
496
+ pip_requirements: {{ pip_requirements }}
497
+ extra_pip_requirements: {{ extra_pip_requirements }}
498
+ metadata: {{ metadata }}
499
+ loader_fn: A function that's required for models containing objects that aren't natively
500
+ serialized by LangChain.
501
+ This function takes a string `persist_dir` as an argument and returns the
502
+ specific object that the model needs. Depending on the model,
503
+ this could be a retriever, vectorstore, requests_wrapper, embeddings, or
504
+ database. For RetrievalQA Chain and retriever models, the object is a
505
+ (`retriever <https://python.langchain.com/docs/modules/data_connection/retrievers/>`_).
506
+ For APIChain models, it's a
507
+ (`requests_wrapper <https://python.langchain.com/docs/modules/agents/tools/integrations/requests>`_).
508
+ For HypotheticalDocumentEmbedder models, it's an
509
+ (`embeddings <https://python.langchain.com/docs/modules/data_connection/text_embedding/>`_).
510
+ For SQLDatabaseChain models, it's a
511
+ (`database <https://python.langchain.com/docs/modules/agents/toolkits/sql_database>`_).
512
+ persist_dir: The directory where the object is stored. The `loader_fn`
513
+ takes this string as the argument to load the object.
514
+ This is optional for models containing objects that aren't natively
515
+ serialized by LangChain. MLflow logs the content in this directory as
516
+ artifacts in the subdirectory named `persist_dir_data`.
517
+
518
+ Here is the code snippet for logging a RetrievalQA chain with `loader_fn`
519
+ and `persist_dir`:
520
+
521
+ .. Note:: In langchain_community >= 0.0.27, loading pickled data requires providing the
522
+ ``allow_dangerous_deserialization`` argument.
523
+
524
+ .. code-block:: python
525
+
526
+ qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=db.as_retriever())
527
+
528
+
529
+ def load_retriever(persist_directory):
530
+ embeddings = OpenAIEmbeddings()
531
+ vectorstore = FAISS.load_local(
532
+ persist_directory,
533
+ embeddings,
534
+ # you may need to add the line below
535
+ # for langchain_community >= 0.0.27
536
+ allow_dangerous_deserialization=True,
537
+ )
538
+ return vectorstore.as_retriever()
539
+
540
+
541
+ with mlflow.start_run() as run:
542
+ logged_model = mlflow.langchain.log_model(
543
+ qa,
544
+ name="retrieval_qa",
545
+ loader_fn=load_retriever,
546
+ persist_dir=persist_dir,
547
+ )
548
+
549
+ See a complete example in examples/langchain/retrieval_qa_chain.py.
550
+ run_id: run_id to associate with this model version. If specified, we resume the
551
+ run and log the model to that run. Otherwise, a new run is created.
552
+ Default to None.
553
+ model_config: The model configuration to apply to the model if saving model from code. This
554
+ configuration is available during model loading.
555
+
556
+ .. Note:: Experimental: This parameter may change or be removed in a future
557
+ release without warning.
558
+ streamable: A boolean value indicating if the model supports streaming prediction. If
559
+ True, the model must implement `stream` method. If None, streamable is
560
+ set to True if the model implements `stream` method. Default to `None`.
561
+ resources: A list of model resources or a resources.yaml file containing a list of
562
+ resources required to serve the model. If logging a LangChain model with dependencies
563
+ (e.g. on LLM model serving endpoints), we encourage explicitly passing dependencies
564
+ via this parameter. Otherwise, ``log_model`` will attempt to infer dependencies,
565
+ but dependency auto-inference is best-effort and may miss some dependencies.
566
+ prompts: {{ prompts }}
567
+
568
+ name: {{ name }}
569
+ params: {{ params }}
570
+ tags: {{ tags }}
571
+ model_type: {{ model_type }}
572
+ step: {{ step }}
573
+ model_id: {{ model_id }}
574
+
575
+ Returns:
576
+ A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
577
+ metadata of the logged model.
578
+ """
579
+ return Model.log(
580
+ artifact_path=artifact_path,
581
+ name=name,
582
+ flavor=mlflow.langchain,
583
+ registered_model_name=registered_model_name,
584
+ lc_model=lc_model,
585
+ conda_env=conda_env,
586
+ code_paths=code_paths,
587
+ signature=signature,
588
+ input_example=input_example,
589
+ await_registration_for=await_registration_for,
590
+ pip_requirements=pip_requirements,
591
+ extra_pip_requirements=extra_pip_requirements,
592
+ metadata=metadata,
593
+ loader_fn=loader_fn,
594
+ persist_dir=persist_dir,
595
+ run_id=run_id,
596
+ model_config=model_config,
597
+ streamable=streamable,
598
+ resources=resources,
599
+ prompts=prompts,
600
+ params=params,
601
+ tags=tags,
602
+ model_type=model_type,
603
+ step=step,
604
+ model_id=model_id,
605
+ )
606
+
607
+
608
+ # patch_langchain_type_to_cls_dict here as we attempt to load model
609
+ # if it's saved by `dict` method
610
+ @patch_langchain_type_to_cls_dict
611
+ def _save_model(model, path, loader_fn, persist_dir):
612
+ if Version(cloudpickle.__version__) < Version("2.1.0"):
613
+ warnings.warn(
614
+ "If you are constructing a custom LangChain model, "
615
+ "please upgrade cloudpickle to version 2.1.0 or later "
616
+ "using `pip install cloudpickle>=2.1.0` "
617
+ "to ensure the model can be loaded correctly."
618
+ )
619
+
620
+ with register_pydantic_v1_serializer_cm():
621
+ if isinstance(model, lc_runnables_types()):
622
+ return _save_runnables(model, path, loader_fn=loader_fn, persist_dir=persist_dir)
623
+ else:
624
+ return _save_base_lcs(model, path, loader_fn, persist_dir)
625
+
626
+
627
+ @patch_langchain_type_to_cls_dict
628
+ def _load_model(local_model_path, flavor_conf):
629
+ # model_type is not accurate as the class can be subclass
630
+ # of supported types, we define _MODEL_LOAD_KEY to ensure
631
+ # which load function to use
632
+ model_load_fn = flavor_conf.get(_MODEL_LOAD_KEY)
633
+ with register_pydantic_v1_serializer_cm():
634
+ if model_load_fn == _RUNNABLE_LOAD_KEY:
635
+ model = _load_runnables(local_model_path, flavor_conf)
636
+ elif model_load_fn == _BASE_LOAD_KEY:
637
+ model = _load_base_lcs(local_model_path, flavor_conf)
638
+ else:
639
+ raise mlflow.MlflowException(
640
+ "Failed to load LangChain model. Unknown model type: "
641
+ f"{flavor_conf.get(_MODEL_TYPE_KEY)}"
642
+ )
643
+ return model
644
+
645
+
646
+ class _LangChainModelWrapper:
647
+ def __init__(self, lc_model, model_path=None):
648
+ self.lc_model = lc_model
649
+ self.model_path = model_path
650
+
651
+ def get_raw_model(self):
652
+ """
653
+ Returns the underlying model.
654
+ """
655
+ return self.lc_model
656
+
657
+ def predict(
658
+ self,
659
+ data: Union[pd.DataFrame, list[Union[str, dict[str, Any]]], Any],
660
+ params: Optional[dict[str, Any]] = None,
661
+ ) -> list[Union[str, dict[str, Any]]]:
662
+ """
663
+ Args:
664
+ data: Model input data.
665
+ params: Additional parameters to pass to the model for inference.
666
+
667
+ Returns:
668
+ Model predictions.
669
+ """
670
+ # TODO: We don't automatically turn tracing on in OSS model serving, because we haven't
671
+ # implemented storage option for traces in OSS model serving (counterpart to the
672
+ # Inference Table in Databricks model serving).
673
+ if (
674
+ is_in_databricks_model_serving_environment()
675
+ # TODO: This env var was once used for controlling whether or not to inject the
676
+ # tracer in Databricks model serving. However, now we have the new env var
677
+ # `ENABLE_MLFLOW_TRACING` to control that. We don't remove this condition
678
+ # right now in the interest of caution, but we should remove this condition
679
+ # after making sure that the functionality is stable.
680
+ and os.environ.get("MLFLOW_ENABLE_TRACE_IN_SERVING", "false").lower() == "true"
681
+ # if this is False, tracing is disabled and we shouldn't inject the tracer
682
+ and is_mlflow_tracing_enabled_in_model_serving()
683
+ ):
684
+ from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
685
+
686
+ callbacks = [MlflowLangchainTracer()]
687
+ else:
688
+ callbacks = None
689
+
690
+ return self._predict_with_callbacks(data, params, callback_handlers=callbacks)
691
+
692
+ def _update_dependencies_schemas_in_prediction_context(
693
+ self, callback_handlers
694
+ ) -> Optional[Context]:
695
+ from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
696
+
697
+ if (
698
+ callback_handlers
699
+ and (
700
+ tracer := next(
701
+ (c for c in callback_handlers if isinstance(c, MlflowLangchainTracer)), None
702
+ )
703
+ )
704
+ and self.model_path
705
+ ):
706
+ model = Model.load(self.model_path)
707
+ context = tracer._prediction_context
708
+ if context and (schema := _get_dependencies_schema_from_model(model)):
709
+ context.update(**schema)
710
+ return context
711
+
712
+ @experimental(version="2.10.0")
713
+ def _predict_with_callbacks(
714
+ self,
715
+ data: Union[pd.DataFrame, list[Union[str, dict[str, Any]]], Any],
716
+ params: Optional[dict[str, Any]] = None,
717
+ callback_handlers=None,
718
+ convert_chat_responses=False,
719
+ ) -> list[Union[str, dict[str, Any]]]:
720
+ """
721
+ Args:
722
+ data: Model input data.
723
+ params: Additional parameters to pass to the model for inference.
724
+ callback_handlers: Callback handlers to pass to LangChain.
725
+ convert_chat_responses: If true, forcibly convert response to chat model
726
+ response format.
727
+
728
+ Returns:
729
+ Model predictions.
730
+ """
731
+ from mlflow.langchain.api_request_parallel_processor import process_api_requests
732
+
733
+ context = self._update_dependencies_schemas_in_prediction_context(callback_handlers)
734
+ messages, return_first_element = self._prepare_predict_messages(data)
735
+ results = process_api_requests(
736
+ lc_model=self.lc_model,
737
+ requests=messages,
738
+ callback_handlers=callback_handlers,
739
+ convert_chat_responses=convert_chat_responses,
740
+ params=params or {},
741
+ context=context,
742
+ )
743
+ return results[0] if return_first_element else results
744
+
745
+ def _prepare_predict_messages(self, data):
746
+ """
747
+ Return a tuple of (preprocessed_data, return_first_element)
748
+ `preprocessed_data` is always a list,
749
+ and `return_first_element` means if True, we should return the first element
750
+ of inference result, otherwise we should return the whole inference result.
751
+ """
752
+ data = _convert_llm_input_data(data)
753
+
754
+ if not isinstance(data, list):
755
+ # if the input data is not a list (i.e. single input),
756
+ # we still need to convert it to a one-element list `[data]`
757
+ # because `process_api_requests` only accepts list as valid input.
758
+ # and in this case,
759
+ # we should return the first element of the inference result
760
+ # because we change input `data` to `[data]`
761
+ return [data], True
762
+ if isinstance(data, list):
763
+ return data, False
764
+ raise mlflow.MlflowException.invalid_parameter_value(
765
+ "Input must be a pandas DataFrame or a list "
766
+ f"for model {self.lc_model.__class__.__name__}"
767
+ )
768
+
769
+ def _prepare_predict_stream_messages(self, data):
770
+ data = _convert_llm_input_data(data)
771
+
772
+ if isinstance(data, list):
773
+ # `predict_stream` only accepts single input.
774
+ # but `enforce_schema` might convert single input into a list like `[single_input]`
775
+ # so extract the first element in the list.
776
+ if len(data) != 1:
777
+ raise MlflowException(
778
+ f"'predict_stream' requires single input, but it got input data {data}"
779
+ )
780
+ return data[0]
781
+ return data
782
+
783
def predict_stream(
    self,
    data: Any,
    params: Optional[dict[str, Any]] = None,
) -> Iterator[Union[str, dict[str, Any]]]:
    """
    Stream predictions for a single input.

    Args:
        data: Model input data, only single input is allowed.
        params: Additional parameters to pass to the model for inference.
            ``None`` is treated as an empty dict.

    Returns:
        An iterator of model prediction chunks.
    """
    from mlflow.langchain.api_request_parallel_processor import process_stream_request

    single_request = self._prepare_predict_stream_messages(data)
    inference_params = params or {}
    return process_stream_request(
        lc_model=self.lc_model,
        request_json=single_request,
        params=inference_params,
    )
806
+
807
def _predict_stream_with_callbacks(
    self,
    data: Any,
    params: Optional[dict[str, Any]] = None,
    callback_handlers=None,
    convert_chat_responses=False,
) -> Iterator[Union[str, dict[str, Any]]]:
    """
    Stream predictions for a single input, forwarding LangChain callback handlers.

    Args:
        data: Model input data, only single input is allowed.
        params: Additional parameters to pass to the model for inference.
            ``None`` is treated as an empty dict.
        callback_handlers: Callback handlers to pass to LangChain.
        convert_chat_responses: If true, forcibly convert response to chat model
            response format.

    Returns:
        An iterator of model prediction chunks.
    """
    from mlflow.langchain.api_request_parallel_processor import process_stream_request

    # The dependency schemas are pushed to the prediction context before the
    # input is prepared, mirroring the order used for streaming with callbacks.
    self._update_dependencies_schemas_in_prediction_context(callback_handlers)
    single_request = self._prepare_predict_stream_messages(data)
    inference_params = params or {}
    return process_stream_request(
        lc_model=self.lc_model,
        request_json=single_request,
        callback_handlers=callback_handlers,
        convert_chat_responses=convert_chat_responses,
        params=inference_params,
    )
838
+
839
+
840
def _load_pyfunc(path: str, model_config: Optional[dict[str, Any]] = None):
    """Load PyFunc implementation for LangChain. Called by ``pyfunc.load_model``.

    Args:
        path: Local filesystem path to the MLflow Model with the ``langchain`` flavor.
        model_config: Optional configuration forwarded to ``_load_model_from_local_fs``
            as its ``model_config_overrides`` argument.

    Returns:
        A ``_LangChainModelWrapper`` built from the loaded model and ``path``.
    """
    lc_model = _load_model_from_local_fs(path, model_config)
    return _LangChainModelWrapper(lc_model, path)
847
+
848
+
849
def _load_model_from_local_fs(local_model_path, model_config_overrides=None):
    """
    Load the LangChain model stored under ``local_model_path``.

    When a model code path is recorded in either the pyfunc or the langchain flavor
    configuration, the model is rebuilt with ``_load_model_code_path`` using the saved
    model config merged with ``model_config_overrides`` (override keys win).
    Otherwise the model is loaded with ``_load_model`` from the langchain flavor
    configuration.

    Args:
        local_model_path: Local filesystem path of the MLflow model directory.
        model_config_overrides: Optional dict merged on top of the saved model config.
            Only used when the model was saved with a model code path.

    Returns:
        The loaded model object.
    """
    mlflow_model = Model.load(local_model_path)
    flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
    pyfunc_flavor_conf = _get_flavor_configuration(
        model_path=local_model_path, flavor_name=PYFUNC_FLAVOR_NAME
    )
    # Add code from the langchain flavor to the system path
    _add_code_from_conf_to_system_path(local_model_path, flavor_conf)

    # The model code path and the model config were previously saved in the langchain
    # flavor, but are now also saved inside the pyfunc flavor. For backwards
    # compatibility with previously saved models, both places are checked.
    saved_as_code = MODEL_CODE_PATH in pyfunc_flavor_conf or MODEL_CODE_PATH in flavor_conf
    if not saved_as_code:
        model = _load_model(local_model_path, flavor_conf)
    else:
        saved_config = pyfunc_flavor_conf.get(MODEL_CONFIG, flavor_conf.get(MODEL_CONFIG, None))
        if isinstance(saved_config, str):
            # A string config is a file reference; resolve its base name inside the
            # model directory and read the config from that file.
            saved_config = _validate_and_get_model_config_from_file(
                os.path.join(local_model_path, os.path.basename(saved_config))
            )

        code_file_name = os.path.basename(
            pyfunc_flavor_conf.get(MODEL_CODE_PATH, flavor_conf.get(MODEL_CODE_PATH, None))
        )
        model_code_path = os.path.join(local_model_path, code_file_name)
        try:
            model = _load_model_code_path(
                model_code_path, {**(saved_config or {}), **(model_config_overrides or {})}
            )
        finally:
            # Clean up the dependencies schema, which is set globally while loading
            # the model, so that it is not reused by the next model loading.
            _clear_dependencies_schemas()

    # Set the active model only after loading, since the experiment ID might be set
    # during the model loading process.
    _update_active_model_id_based_on_mlflow_model(mlflow_model)
    return model
890
+
891
+
892
@experimental(version="2.3.0")
@docstring_version_compatibility_warning(FLAVOR_NAME)
@trace_disabled  # Suppress traces while loading model
def load_model(model_uri, dst_path=None):
    """
    Load a LangChain model from a local file or a run.

    Args:
        model_uri: The location, in URI format, of the MLflow model. For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/tracking.html#
            artifact-locations>`_.
        dst_path: The local filesystem path to which to download the model artifact.
            This directory must already exist. If unspecified, a local output
            path will be created.

    Returns:
        A LangChain model instance.
    """
    # The URI is coerced to ``str`` before the artifact is downloaded.
    downloaded_path = _download_artifact_from_uri(
        artifact_uri=str(model_uri), output_path=dst_path
    )
    return _load_model_from_local_fs(downloaded_path)