genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,671 @@
1
+ """Utility functions for mlflow.langchain."""
2
+
3
+ import contextlib
4
+ import functools
5
+ import importlib
6
+ import json
7
+ import logging
8
+ import os
9
+ import re
10
+ import shutil
11
+ import types
12
+ import warnings
13
+ from functools import lru_cache
14
+ from importlib.util import find_spec
15
+ from typing import Any, Callable, NamedTuple
16
+
17
+ import cloudpickle
18
+ import yaml
19
+ from packaging import version
20
+ from packaging.version import Version
21
+
22
+ import mlflow
23
+ from mlflow.exceptions import MlflowException
24
+ from mlflow.models.utils import _validate_and_get_model_code_path
25
+ from mlflow.protos.databricks_pb2 import INTERNAL_ERROR
26
+ from mlflow.utils.class_utils import _get_class_from_string
27
+
28
+ _AGENT_PRIMITIVES_FILE_NAME = "agent_primitive_args.json"
29
+ _AGENT_PRIMITIVES_DATA_KEY = "agent_primitive_data"
30
+ _AGENT_DATA_FILE_NAME = "agent.yaml"
31
+ _AGENT_DATA_KEY = "agent_data"
32
+ _TOOLS_DATA_FILE_NAME = "tools.pkl"
33
+ _TOOLS_DATA_KEY = "tools_data"
34
+ _LOADER_FN_FILE_NAME = "loader_fn.pkl"
35
+ _LOADER_FN_KEY = "loader_fn"
36
+ _LOADER_ARG_KEY = "loader_arg"
37
+ _PERSIST_DIR_NAME = "persist_dir_data"
38
+ _PERSIST_DIR_KEY = "persist_dir"
39
+ _MODEL_DATA_YAML_FILE_NAME = "model.yaml"
40
+ _MODEL_DATA_PKL_FILE_NAME = "model.pkl"
41
+ _MODEL_DATA_FOLDER_NAME = "model"
42
+ _MODEL_DATA_KEY = "model_data"
43
+ _MODEL_TYPE_KEY = "model_type"
44
+ _RUNNABLE_LOAD_KEY = "runnable_load"
45
+ _BASE_LOAD_KEY = "base_load"
46
+ _CONFIG_LOAD_KEY = "config_load"
47
+ _PICKLE_LOAD_KEY = "pickle_load"
48
+ _MODEL_LOAD_KEY = "model_load"
49
+ _UNSUPPORTED_MODEL_WARNING_MESSAGE = (
50
+ "MLflow does not guarantee support for Chains outside of the subclasses of LLMChain, found %s"
51
+ )
52
+ _UNSUPPORTED_LLM_WARNING_MESSAGE = (
53
+ "MLflow does not guarantee support for LLMs outside of HuggingFacePipeline and OpenAI, found %s"
54
+ )
55
+
56
+
57
+ _CHAT_MODELS_ERROR_MSG = re.compile("Loading (openai-chat|azure-openai-chat) LLM not supported")
58
+
59
+
60
+ try:
61
+ import langchain_community
62
+
63
+ # Since langchain-community 0.0.27, saving or loading a module that relies on the pickle
64
+ # deserialization requires passing `allow_dangerous_deserialization=True`.
65
+ IS_PICKLE_SERIALIZATION_RESTRICTED = Version(langchain_community.__version__) >= Version(
66
+ "0.0.27"
67
+ )
68
+ except ImportError:
69
+ IS_PICKLE_SERIALIZATION_RESTRICTED = False
70
+
71
+ logger = logging.getLogger(__name__)
72
+
73
+
74
+ @lru_cache
75
+ def base_lc_types():
76
+ # add this import to avoid missing module error
77
+ import langchain.agents
78
+ import langchain.agents.agent
79
+ import langchain.chains.base
80
+ import langchain.schema
81
+
82
+ return (
83
+ langchain.chains.base.Chain,
84
+ langchain.agents.agent.AgentExecutor,
85
+ langchain.schema.BaseRetriever,
86
+ )
87
+
88
+
89
+ @lru_cache
90
+ def picklable_runnable_types():
91
+ """
92
+ Runnable types that can be pickled and unpickled by cloudpickle.
93
+ """
94
+ from langchain.chat_models.base import SimpleChatModel
95
+ from langchain.prompts import ChatPromptTemplate
96
+ from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
97
+
98
+ return (
99
+ SimpleChatModel,
100
+ ChatPromptTemplate,
101
+ RunnablePassthrough,
102
+ RunnableLambda,
103
+ )
104
+
105
+
106
+ @lru_cache
107
+ def lc_runnable_with_steps_types():
108
+ from langchain.schema.runnable import RunnableParallel, RunnableSequence
109
+
110
+ return (RunnableParallel, RunnableSequence)
111
+
112
+
113
+ def lc_runnable_assign_types():
114
+ from langchain.schema.runnable.passthrough import RunnableAssign
115
+
116
+ return (RunnableAssign,)
117
+
118
+
119
+ def lc_runnable_branch_types():
120
+ from langchain.schema.runnable import RunnableBranch
121
+
122
+ return (RunnableBranch,)
123
+
124
+
125
+ def lc_runnable_binding_types():
126
+ from langchain.schema.runnable import RunnableBinding
127
+
128
+ return (RunnableBinding,)
129
+
130
+
131
+ def lc_runnables_types():
132
+ return (
133
+ picklable_runnable_types()
134
+ + lc_runnable_with_steps_types()
135
+ + lc_runnable_branch_types()
136
+ + lc_runnable_assign_types()
137
+ + lc_runnable_binding_types()
138
+ )
139
+
140
+
141
+ def langgraph_types():
142
+ try:
143
+ from langgraph.graph.state import CompiledStateGraph
144
+
145
+ return (CompiledStateGraph,)
146
+ except ImportError:
147
+ return ()
148
+
149
+
150
+ def supported_lc_types():
151
+ return base_lc_types() + lc_runnables_types() + langgraph_types()
152
+
153
+
154
+ # Wrapping as a function to avoid callign supported_lc_types() at import time
155
+ def get_unsupported_model_message(model_type):
156
+ return (
157
+ "MLflow langchain flavor only supports subclasses of "
158
+ f"{supported_lc_types()}, found {model_type}."
159
+ )
160
+
161
+
162
+ @lru_cache
163
+ def custom_type_to_loader_dict():
164
+ # helper function to load output_parsers from config
165
+ def _load_output_parser(config: dict[str, Any]) -> Any:
166
+ """Load output parser."""
167
+ from langchain.schema.output_parser import StrOutputParser
168
+
169
+ output_parser_type = config.pop("_type", None)
170
+ if output_parser_type == "default":
171
+ return StrOutputParser(**config)
172
+ else:
173
+ raise ValueError(f"Unsupported output parser {output_parser_type}")
174
+
175
+ return {"default": _load_output_parser}
176
+
177
+
178
+ class _SpecialChainInfo(NamedTuple):
179
+ loader_arg: str
180
+
181
+
182
+ def _get_special_chain_info_or_none(chain):
183
+ for (
184
+ special_chain_class,
185
+ loader_arg,
186
+ ) in _get_map_of_special_chain_class_to_loader_arg().items():
187
+ if isinstance(chain, special_chain_class):
188
+ return _SpecialChainInfo(loader_arg=loader_arg)
189
+
190
+
191
+ @lru_cache
192
+ def _get_map_of_special_chain_class_to_loader_arg():
193
+ import langchain
194
+
195
+ from mlflow.langchain.retriever_chain import _RetrieverChain
196
+
197
+ class_name_to_loader_arg = {
198
+ "langchain.chains.RetrievalQA": "retriever",
199
+ "langchain.chains.APIChain": "requests_wrapper",
200
+ "langchain.chains.HypotheticalDocumentEmbedder": "embeddings",
201
+ }
202
+ # NB: SQLDatabaseChain was migrated to langchain_experimental beginning with version 0.0.247
203
+ if version.parse(langchain.__version__) <= version.parse("0.0.246"):
204
+ class_name_to_loader_arg["langchain.chains.SQLDatabaseChain"] = "database"
205
+ else:
206
+ if find_spec("langchain_experimental"):
207
+ # Add this entry only if langchain_experimental is installed
208
+ class_name_to_loader_arg["langchain_experimental.sql.SQLDatabaseChain"] = "database"
209
+
210
+ class_to_loader_arg = {
211
+ _RetrieverChain: "retriever",
212
+ }
213
+ for class_name, loader_arg in class_name_to_loader_arg.items():
214
+ try:
215
+ cls = _get_class_from_string(class_name)
216
+ class_to_loader_arg[cls] = loader_arg
217
+ except Exception:
218
+ logger.warning(
219
+ "Unexpected import failure for class '%s'. Please file an issue at"
220
+ " https://github.com/mlflow/mlflow/issues/.",
221
+ class_name,
222
+ exc_info=True,
223
+ )
224
+
225
+ return class_to_loader_arg
226
+
227
+
228
+ @lru_cache
229
+ def _get_supported_llms():
230
+ supported_llms = set()
231
+
232
+ def try_adding_llm(module, class_name):
233
+ if cls := getattr(module, class_name, None):
234
+ supported_llms.add(cls)
235
+
236
+ def safe_import_and_add(module_name, class_name):
237
+ """Add conditional support for `partner` and `community` APIs in langchain"""
238
+ try:
239
+ module = importlib.import_module(module_name)
240
+ try_adding_llm(module, class_name)
241
+ except ImportError:
242
+ pass
243
+
244
+ safe_import_and_add("langchain.llms.openai", "OpenAI")
245
+ # HuggingFacePipeline is moved to langchain_huggingface since langchain 0.2.0
246
+ safe_import_and_add("langchain.llms", "HuggingFacePipeline")
247
+ safe_import_and_add("langchain.langchain_huggingface", "HuggingFacePipeline")
248
+ safe_import_and_add("langchain_openai", "OpenAI")
249
+ safe_import_and_add("langchain_databricks", "ChatDatabricks")
250
+ safe_import_and_add("databricks_langchain", "ChatDatabricks")
251
+
252
+ for llm_name in ["Databricks", "Mlflow"]:
253
+ safe_import_and_add("langchain.llms", llm_name)
254
+
255
+ for chat_model_name in [
256
+ "ChatDatabricks",
257
+ "ChatMlflow",
258
+ "ChatOpenAI",
259
+ "AzureChatOpenAI",
260
+ ]:
261
+ safe_import_and_add("langchain.chat_models", chat_model_name)
262
+
263
+ return supported_llms
264
+
265
+
266
+ def _agent_executor_contains_unsupported_llm(lc_model, _SUPPORTED_LLMS):
267
+ import langchain.agents.agent
268
+
269
+ return (
270
+ isinstance(lc_model, langchain.agents.agent.AgentExecutor)
271
+ # 'RunnableMultiActionAgent' object has no attribute 'llm_chain'
272
+ and hasattr(lc_model.agent, "llm_chain")
273
+ and not any(
274
+ isinstance(lc_model.agent.llm_chain.llm, supported_llm)
275
+ for supported_llm in _SUPPORTED_LLMS
276
+ )
277
+ )
278
+
279
+
280
+ # temp_dir is only required when lc_model could be a file path
281
+ def _validate_and_prepare_lc_model_or_path(lc_model, loader_fn, temp_dir=None):
282
+ import langchain.agents.agent
283
+ import langchain.chains.base
284
+ import langchain.chains.llm
285
+ import langchain.llms.huggingface_hub
286
+ import langchain.llms.openai
287
+ import langchain.schema
288
+
289
+ # lc_model is a file path
290
+ if isinstance(lc_model, str):
291
+ return _validate_and_get_model_code_path(lc_model, temp_dir)
292
+
293
+ if not isinstance(lc_model, supported_lc_types()):
294
+ raise mlflow.MlflowException.invalid_parameter_value(
295
+ get_unsupported_model_message(type(lc_model).__name__)
296
+ )
297
+
298
+ _SUPPORTED_LLMS = _get_supported_llms()
299
+ if isinstance(lc_model, langchain.chains.llm.LLMChain) and not any(
300
+ isinstance(lc_model.llm, supported_llm) for supported_llm in _SUPPORTED_LLMS
301
+ ):
302
+ logger.warning(
303
+ _UNSUPPORTED_LLM_WARNING_MESSAGE,
304
+ type(lc_model.llm).__name__,
305
+ )
306
+
307
+ if _agent_executor_contains_unsupported_llm(lc_model, _SUPPORTED_LLMS):
308
+ logger.warning(
309
+ _UNSUPPORTED_LLM_WARNING_MESSAGE,
310
+ type(lc_model.agent.llm_chain.llm).__name__,
311
+ )
312
+
313
+ if special_chain_info := _get_special_chain_info_or_none(lc_model):
314
+ if loader_fn is None:
315
+ raise mlflow.MlflowException.invalid_parameter_value(
316
+ f"For {type(lc_model).__name__} models, a `loader_fn` must be provided."
317
+ )
318
+ if not isinstance(loader_fn, types.FunctionType):
319
+ raise mlflow.MlflowException.invalid_parameter_value(
320
+ "The `loader_fn` must be a function that returns a {loader_arg}.".format(
321
+ loader_arg=special_chain_info.loader_arg
322
+ )
323
+ )
324
+
325
+ # If lc_model is a retriever, wrap it in a _RetrieverChain
326
+ if isinstance(lc_model, langchain.schema.BaseRetriever):
327
+ from mlflow.langchain.retriever_chain import _RetrieverChain
328
+
329
+ if loader_fn is None:
330
+ raise mlflow.MlflowException.invalid_parameter_value(
331
+ f"For {type(lc_model).__name__} models, a `loader_fn` must be provided."
332
+ )
333
+ if not isinstance(loader_fn, types.FunctionType):
334
+ raise mlflow.MlflowException.invalid_parameter_value(
335
+ "The `loader_fn` must be a function that returns a retriever."
336
+ )
337
+ lc_model = _RetrieverChain(retriever=lc_model)
338
+
339
+ return lc_model
340
+
341
+
342
+ def _save_base_lcs(model, path, loader_fn=None, persist_dir=None):
343
+ from langchain.agents.agent import AgentExecutor
344
+ from langchain.chains.base import Chain
345
+ from langchain.chains.llm import LLMChain
346
+ from langchain.chat_models.base import BaseChatModel
347
+
348
+ model_data_path = os.path.join(path, _MODEL_DATA_YAML_FILE_NAME)
349
+ model_data_kwargs = {
350
+ _MODEL_DATA_KEY: _MODEL_DATA_YAML_FILE_NAME,
351
+ _MODEL_LOAD_KEY: _BASE_LOAD_KEY,
352
+ }
353
+
354
+ if isinstance(model, (LLMChain, BaseChatModel)):
355
+ model.save(model_data_path)
356
+ elif isinstance(model, AgentExecutor):
357
+ if model.agent and getattr(model.agent, "llm_chain", None):
358
+ model.agent.llm_chain.save(model_data_path)
359
+
360
+ if model.agent:
361
+ agent_data_path = os.path.join(path, _AGENT_DATA_FILE_NAME)
362
+ model.save_agent(agent_data_path)
363
+ model_data_kwargs[_AGENT_DATA_KEY] = _AGENT_DATA_FILE_NAME
364
+
365
+ if model.tools:
366
+ tools_data_path = os.path.join(path, _TOOLS_DATA_FILE_NAME)
367
+ try:
368
+ with open(tools_data_path, "wb") as f:
369
+ cloudpickle.dump(model.tools, f)
370
+ except Exception as e:
371
+ raise mlflow.MlflowException(
372
+ "Error when attempting to pickle the AgentExecutor tools. "
373
+ "This model likely does not support serialization."
374
+ ) from e
375
+ model_data_kwargs[_TOOLS_DATA_KEY] = _TOOLS_DATA_FILE_NAME
376
+ else:
377
+ raise mlflow.MlflowException.invalid_parameter_value(
378
+ "For initializing the AgentExecutor, tools must be provided."
379
+ )
380
+
381
+ key_to_ignore = ["llm_chain", "agent", "tools", "callback_manager"]
382
+ temp_dict = {k: v for k, v in model.__dict__.items() if k not in key_to_ignore}
383
+
384
+ agent_primitive_path = os.path.join(path, _AGENT_PRIMITIVES_FILE_NAME)
385
+ with open(agent_primitive_path, "w") as config_file:
386
+ json.dump(temp_dict, config_file, indent=4)
387
+
388
+ model_data_kwargs[_AGENT_PRIMITIVES_DATA_KEY] = _AGENT_PRIMITIVES_FILE_NAME
389
+
390
+ elif special_chain_info := _get_special_chain_info_or_none(model):
391
+ # Save loader_fn by pickling
392
+ loader_fn_path = os.path.join(path, _LOADER_FN_FILE_NAME)
393
+ with open(loader_fn_path, "wb") as f:
394
+ cloudpickle.dump(loader_fn, f)
395
+ model_data_kwargs[_LOADER_FN_KEY] = _LOADER_FN_FILE_NAME
396
+ model_data_kwargs[_LOADER_ARG_KEY] = special_chain_info.loader_arg
397
+
398
+ if persist_dir is not None:
399
+ if os.path.exists(persist_dir):
400
+ # Save persist_dir by copying into subdir _PERSIST_DIR_NAME
401
+ persist_dir_data_path = os.path.join(path, _PERSIST_DIR_NAME)
402
+ shutil.copytree(persist_dir, persist_dir_data_path)
403
+ model_data_kwargs[_PERSIST_DIR_KEY] = _PERSIST_DIR_NAME
404
+ else:
405
+ raise mlflow.MlflowException.invalid_parameter_value(
406
+ "The directory provided for persist_dir does not exist."
407
+ )
408
+
409
+ # Save model
410
+ model.save(model_data_path)
411
+ elif isinstance(model, Chain):
412
+ logger.warning(get_unsupported_model_message(type(model).__name__))
413
+ model.save(model_data_path)
414
+ else:
415
+ raise mlflow.MlflowException.invalid_parameter_value(
416
+ get_unsupported_model_message(type(model).__name__)
417
+ )
418
+
419
+ return model_data_kwargs
420
+
421
+
422
+ def _load_from_pickle(path):
423
+ with open(path, "rb") as f:
424
+ return cloudpickle.load(f)
425
+
426
+
427
+ def _load_from_json(path):
428
+ with open(path) as f:
429
+ return json.load(f)
430
+
431
+
432
+ def _load_from_yaml(path):
433
+ with open(path) as f:
434
+ return yaml.safe_load(f)
435
+
436
+
437
+ def _get_path_by_key(root_path, key, conf):
438
+ key_path = conf.get(key)
439
+ return os.path.join(root_path, key_path) if key_path else None
440
+
441
+
442
+ def _patch_loader(loader_func: Callable[..., Any]) -> Callable[..., Any]:
443
+ """
444
+ Patch LangChain loader function like load_chain() to handle the breaking change introduced in
445
+ LangChain 0.1.12.
446
+
447
+ Since langchain-community 0.0.27, loading a module that relies on the pickle deserialization
448
+ requires the `allow_dangerous_deserialization` flag to be set to True, for security reasons.
449
+ However, this flag could not be specified via the LangChain's loading API like load_chain(),
450
+ load_llm(), until LangChain 0.1.14. As a result, such module cannot be loaded with MLflow
451
+ with earlier version of LangChain and we have to tell the user to upgrade LangChain to 0.0.14
452
+ or above.
453
+
454
+ Args:
455
+ loader_func: The LangChain loader function to be patched e.g. load_chain().
456
+
457
+ Returns:
458
+ The patched loader function.
459
+ """
460
+ if not IS_PICKLE_SERIALIZATION_RESTRICTED:
461
+ return loader_func
462
+
463
+ import langchain
464
+
465
+ if Version(langchain.__version__) >= Version("0.1.14"):
466
+ # For LangChain 0.1.14 and above, we can pass `allow_dangerous_deserialization` flag
467
+ # via the loader APIs. Since the model is serialized by the user (or someone who has
468
+ # access to the tracking server), it is safe to set this flag to True.
469
+ def patched_loader(*args, **kwargs):
470
+ return loader_func(*args, **kwargs, allow_dangerous_deserialization=True)
471
+ else:
472
+
473
+ def patched_loader(*args, **kwargs):
474
+ try:
475
+ return loader_func(*args, **kwargs)
476
+ except ValueError as e:
477
+ if "This code relies on the pickle module" in str(e):
478
+ raise MlflowException(
479
+ "Since langchain-community 0.0.27, loading a module that relies on "
480
+ "the pickle deserialization requires the `allow_dangerous_deserialization` "
481
+ "flag to be set to True when loading. However, this flag is not supported "
482
+ "by the installed version of LangChain. Please upgrade LangChain to 0.1.14 "
483
+ "or above by running `pip install langchain>=0.1.14`.",
484
+ error_code=INTERNAL_ERROR,
485
+ ) from e
486
+ else:
487
+ raise
488
+
489
+ return patched_loader
490
+
491
+
492
+ def _load_base_lcs(
493
+ local_model_path,
494
+ conf,
495
+ ):
496
+ lc_model_path = os.path.join(
497
+ local_model_path, conf.get(_MODEL_DATA_KEY, _MODEL_DATA_YAML_FILE_NAME)
498
+ )
499
+
500
+ agent_path = _get_path_by_key(local_model_path, _AGENT_DATA_KEY, conf)
501
+ tools_path = _get_path_by_key(local_model_path, _TOOLS_DATA_KEY, conf)
502
+ agent_primitive_path = _get_path_by_key(local_model_path, _AGENT_PRIMITIVES_DATA_KEY, conf)
503
+ loader_fn_path = _get_path_by_key(local_model_path, _LOADER_FN_KEY, conf)
504
+ persist_dir = _get_path_by_key(local_model_path, _PERSIST_DIR_KEY, conf)
505
+
506
+ model_type = conf.get(_MODEL_TYPE_KEY)
507
+ loader_arg = conf.get(_LOADER_ARG_KEY)
508
+
509
+ from langchain.chains.loading import load_chain
510
+
511
+ from mlflow.langchain.retriever_chain import _RetrieverChain
512
+
513
+ if loader_arg is not None:
514
+ if loader_fn_path is None:
515
+ raise mlflow.MlflowException.invalid_parameter_value(
516
+ "Missing file for loader_fn which is required to build the model."
517
+ )
518
+ loader_fn = _load_from_pickle(loader_fn_path)
519
+ kwargs = {loader_arg: loader_fn(persist_dir)}
520
+ if model_type == _RetrieverChain.__name__:
521
+ model = _RetrieverChain.load(lc_model_path, **kwargs).retriever
522
+ else:
523
+ model = _patch_loader(load_chain)(lc_model_path, **kwargs)
524
+ elif agent_path is None and tools_path is None:
525
+ model = _patch_loader(load_chain)(lc_model_path)
526
+ else:
527
+ from langchain.agents import initialize_agent
528
+
529
+ llm = _patch_loader(load_chain)(lc_model_path)
530
+ tools = []
531
+ kwargs = {}
532
+
533
+ if os.path.exists(tools_path):
534
+ tools = _load_from_pickle(tools_path)
535
+ else:
536
+ raise mlflow.MlflowException(
537
+ "Missing file for tools which is required to build the AgentExecutor object."
538
+ )
539
+
540
+ if os.path.exists(agent_primitive_path):
541
+ kwargs = _load_from_json(agent_primitive_path)
542
+
543
+ model = initialize_agent(tools=tools, llm=llm, agent_path=agent_path, **kwargs)
544
+ return model
545
+
546
+
547
+ def patch_langchain_type_to_cls_dict(func):
548
+ @functools.wraps(func)
549
+ def wrapper(*args, **kwargs):
550
+ def _load_chat_openai():
551
+ from langchain_community.chat_models import ChatOpenAI
552
+
553
+ return ChatOpenAI
554
+
555
+ def _load_azure_chat_openai():
556
+ from langchain_community.chat_models import AzureChatOpenAI
557
+
558
+ return AzureChatOpenAI
559
+
560
+ def _load_chat_databricks():
561
+ from databricks_langchain import ChatDatabricks
562
+
563
+ return ChatDatabricks
564
+
565
+ def _patched_get_type_to_cls_dict(original):
566
+ def _wrapped():
567
+ return {
568
+ **original(),
569
+ "openai-chat": _load_chat_openai,
570
+ "azure-openai-chat": _load_azure_chat_openai,
571
+ "chat-databricks": _load_chat_databricks,
572
+ }
573
+
574
+ return _wrapped
575
+
576
+ modules_to_patch = [
577
+ "langchain_databricks",
578
+ "langchain.llms",
579
+ "langchain_community.llms.loading",
580
+ ]
581
+ originals = {}
582
+ for name in modules_to_patch:
583
+ try:
584
+ module = importlib.import_module(name)
585
+ originals[name] = module.get_type_to_cls_dict # Record original impl for cleanup
586
+ except (ImportError, AttributeError):
587
+ continue
588
+ module.get_type_to_cls_dict = _patched_get_type_to_cls_dict(originals[name])
589
+
590
+ try:
591
+ return func(*args, **kwargs)
592
+ except ValueError as e:
593
+ if m := _CHAT_MODELS_ERROR_MSG.search(str(e)):
594
+ model_name = "ChatOpenAI" if m.group(1) == "openai-chat" else "AzureChatOpenAI"
595
+ raise mlflow.MlflowException(
596
+ f"Loading {model_name} chat model is not supported in MLflow with the "
597
+ "current version of LangChain. Please upgrade LangChain to 0.0.307 or above "
598
+ "by running `pip install langchain>=0.0.307`."
599
+ ) from e
600
+ else:
601
+ raise
602
+ finally:
603
+ # Clean up the patch
604
+ for module_name, original_impl in originals.items():
605
+ module = importlib.import_module(module_name)
606
+ module.get_type_to_cls_dict = original_impl
607
+
608
+ return wrapper
609
+
610
+
611
+ def register_pydantic_serializer():
612
+ """
613
+ Helper function to pickle pydantic fields for pydantic v1.
614
+ Pydantic's Cython validators are not serializable.
615
+ https://github.com/cloudpipe/cloudpickle/issues/408
616
+ """
617
+ import pydantic
618
+
619
+ if Version(pydantic.__version__) >= Version("2.0.0"):
620
+ return
621
+
622
+ import pydantic.fields
623
+
624
+ def custom_serializer(obj):
625
+ return {
626
+ "name": obj.name,
627
+ # outer_type_ is the original type for ModelFields,
628
+ # while type_ can be updated later with the nested type
629
+ # like int for List[int].
630
+ "type_": obj.outer_type_,
631
+ "class_validators": obj.class_validators,
632
+ "model_config": obj.model_config,
633
+ "default": obj.default,
634
+ "default_factory": obj.default_factory,
635
+ "required": obj.required,
636
+ "final": obj.final,
637
+ "alias": obj.alias,
638
+ "field_info": obj.field_info,
639
+ }
640
+
641
+ def custom_deserializer(kwargs):
642
+ return pydantic.fields.ModelField(**kwargs)
643
+
644
+ def _CloudPicklerReducer(obj):
645
+ return custom_deserializer, (custom_serializer(obj),)
646
+
647
+ warnings.warn(
648
+ "Using custom serializer to pickle pydantic.fields.ModelField classes, "
649
+ "this might miss some fields and validators. To avoid this, "
650
+ "please upgrade pydantic to v2 using `pip install pydantic -U` with "
651
+ "langchain 0.0.267 and above."
652
+ )
653
+ cloudpickle.CloudPickler.dispatch[pydantic.fields.ModelField] = _CloudPicklerReducer
654
+
655
+
656
+ def unregister_pydantic_serializer():
657
+ import pydantic
658
+
659
+ if Version(pydantic.__version__) >= Version("2.0.0"):
660
+ return
661
+
662
+ cloudpickle.CloudPickler.dispatch.pop(pydantic.fields.ModelField, None)
663
+
664
+
665
+ @contextlib.contextmanager
666
+ def register_pydantic_v1_serializer_cm():
667
+ try:
668
+ register_pydantic_serializer()
669
+ yield
670
+ finally:
671
+ unregister_pydantic_serializer()
@@ -0,0 +1,36 @@
1
+ import inspect
2
+
3
+ from packaging.version import Version
4
+
5
+
6
+ def convert_to_serializable(response):
7
+ """
8
+ Convert the response to a JSON serializable format.
9
+
10
+ LangChain response objects often contains Pydantic objects, which causes an serialization
11
+ error when the model is served behind REST endpoint.
12
+ """
13
+ import langchain
14
+
15
+ # LangChain >= 0.3.0 uses Pydantic 2.x while < 0.3.0 is based on Pydantic 1.x.
16
+ if Version(langchain.__version__) >= Version("0.3.0"):
17
+ from pydantic import BaseModel
18
+
19
+ if isinstance(response, BaseModel):
20
+ return response.model_dump()
21
+ else:
22
+ from langchain_core.pydantic_v1 import BaseModel as LangChainBaseModel
23
+
24
+ if isinstance(response, LangChainBaseModel):
25
+ return response.dict()
26
+
27
+ if inspect.isgenerator(response):
28
+ return (convert_to_serializable(chunk) for chunk in response)
29
+ elif isinstance(response, dict):
30
+ return {k: convert_to_serializable(v) for k, v in response.items()}
31
+ elif isinstance(response, list):
32
+ return [convert_to_serializable(v) for v in response]
33
+ elif isinstance(response, tuple):
34
+ return tuple(convert_to_serializable(v) for v in response)
35
+
36
+ return response
File without changes