genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645) hide show
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,619 @@
1
+ import functools
2
+ import logging
3
+ import os
4
+ import subprocess
5
+ import tempfile
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+
10
+ from mlflow.environment_variables import _MLFLOW_TESTING
11
+ from mlflow.metrics.base import MetricValue, standard_aggregations
12
+
13
+ _logger = logging.getLogger(__name__)
14
+
15
+
16
+ # used to silently fail with invalid metric params
17
def noop(*args, **kwargs):
    """Accept any positional or keyword arguments, do nothing, and return ``None``.

    Handed back in place of a real metric function when the metric parameters are invalid,
    so that the metric fails silently.
    """
    return
19
+
20
+
21
# Human-readable descriptions of the evaluated columns, interpolated into warning messages.
targets_col_specifier = "the column specified by the `targets` parameter"
predictions_col_specifier = (
    "the column specified by the `predictions` parameter "
    "or the model output column"
)
25
+
26
+
27
+ def _validate_text_data(data, metric_name, col_specifier):
28
+ """Validates that the data is a list of strs and is non-empty"""
29
+ if data is None or len(data) == 0:
30
+ _logger.warning(
31
+ f"Cannot calculate {metric_name} for empty inputs: "
32
+ f"{col_specifier} is empty or the parameter is not specified. Skipping metric logging."
33
+ )
34
+ return False
35
+
36
+ for row, line in enumerate(data):
37
+ if not isinstance(line, str):
38
+ _logger.warning(
39
+ f"Cannot calculate {metric_name} for non-string inputs. "
40
+ f"Non-string found for {col_specifier} on row {row}. Skipping metric logging."
41
+ )
42
+ return False
43
+
44
+ return True
45
+
46
+
47
+ def _validate_array_like_id_data(data, metric_name, col_specifier):
48
+ """Validates that the data is a list of lists/np.ndarrays of strings/ints and is non-empty"""
49
+ if data is None or len(data) == 0:
50
+ return False
51
+
52
+ for index, value in data.items():
53
+ if not (
54
+ (isinstance(value, list) and all(isinstance(val, (str, int)) for val in value))
55
+ or (
56
+ isinstance(value, np.ndarray)
57
+ and (np.issubdtype(value.dtype, str) or np.issubdtype(value.dtype, int))
58
+ )
59
+ ):
60
+ _logger.warning(
61
+ f"Cannot calculate metric '{metric_name}' for non-arraylike of string or int "
62
+ f"inputs. Non-arraylike of strings/ints found for {col_specifier} on row "
63
+ f"{index}, value {value}. Skipping metric logging."
64
+ )
65
+ return False
66
+
67
+ return True
68
+
69
+
70
def _token_count_eval_fn(predictions, targets=None, metrics=None):
    """Count the ``cl100k_base`` tokens of each prediction.

    Args:
        predictions: Iterable of model outputs.
        targets: Unused.
        metrics: Unused.

    Returns:
        A ``MetricValue`` whose scores hold the token count of each string prediction and
        ``None`` for each non-string prediction. No aggregate results are produced.
    """
    import tiktoken

    # ref: https://github.com/openai/tiktoken/issues/75
    # Only set TIKTOKEN_CACHE_DIR if not already set by user
    os.environ.setdefault("TIKTOKEN_CACHE_DIR", "")
    encoding = tiktoken.get_encoding("cl100k_base")

    token_counts = [
        len(encoding.encode(prediction)) if isinstance(prediction, str) else None
        for prediction in predictions
    ]
    return MetricValue(scores=token_counts, aggregate_results={})
90
+
91
+
92
def _load_from_github(path: str, module_type: str = "metric"):
    """Load an ``evaluate`` module from a sparse checkout of the huggingface/evaluate repo.

    Args:
        path: Name of the module directory under ``<module_type>s/`` in the repository.
        module_type: Kind of module; an ``s`` is appended to form the top-level directory name.

    Returns:
        The result of ``evaluate.load`` on the checked-out module directory.

    Raises:
        subprocess.CalledProcessError: If any of the git commands exits with a non-zero status.
    """
    import evaluate

    with tempfile.TemporaryDirectory() as tmpdir:
        repo_dir = Path(tmpdir)
        # Clone without blobs or a working tree, then check out only the requested directory.
        clone_cmd = [
            "git",
            "clone",
            "--filter=blob:none",
            "--no-checkout",
            "https://github.com/huggingface/evaluate.git",
            repo_dir,
        ]
        subprocess.check_call(clone_cmd)

        module_dir = f"{module_type}s/{path}"
        for git_args in (["sparse-checkout", "set", module_dir], ["checkout"]):
            subprocess.check_call(["git", *git_args], cwd=repo_dir)

        # Load while the temporary directory still exists.
        return evaluate.load(str(repo_dir / module_dir))
111
+
112
+
113
@functools.lru_cache(maxsize=8)
def _cached_evaluate_load(path: str, module_type: str = "metric"):
    """Load an ``evaluate`` module, caching up to 8 distinct ``(path, module_type)`` results.

    Args:
        path: Module identifier passed to ``evaluate.load``.
        module_type: Module kind passed to ``evaluate.load``.

    Returns:
        The loaded module.

    Raises:
        OSError: If ``evaluate.load`` fails with an ``OSError`` (which includes
            ``FileNotFoundError``) and the ``_MLFLOW_TESTING`` flag is not set.
    """
    import evaluate

    try:
        return evaluate.load(path, module_type=module_type)
    except OSError:  # FileNotFoundError is a subclass of OSError
        if not _MLFLOW_TESTING.get():
            raise
        # `evaluate.load` is highly unstable and often fails due to a network error or
        # huggingface hub being down. In testing, we want to avoid this instability, so we
        # load the metric from the evaluate repository on GitHub.
        return _load_from_github(path, module_type=module_type)
126
+
127
+
128
def _toxicity_eval_fn(predictions, targets=None, metrics=None):
    """Compute per-prediction toxicity scores and the overall toxicity ratio.

    Args:
        predictions: Collection of model output strings.
        targets: Unused.
        metrics: Unused.

    Returns:
        A ``MetricValue`` with the per-row toxicity scores and the standard aggregations plus
        a ``"ratio"`` entry, or ``None`` (after a warning is logged) when the predictions are
        not valid text or the toxicity measurement cannot be loaded.
    """
    if not _validate_text_data(predictions, "toxicity", predictions_col_specifier):
        return None

    try:
        toxicity = _cached_evaluate_load("toxicity", module_type="measurement")
    except Exception as e:
        _logger.warning(
            f"Failed to load 'toxicity' metric (error: {e!r}), skipping metric logging."
        )
        return None

    scores = toxicity.compute(predictions=predictions)["toxicity"]
    ratio_result = toxicity.compute(predictions=predictions, aggregation="ratio")

    aggregates = dict(standard_aggregations(scores))
    aggregates["ratio"] = ratio_result["toxicity_ratio"]
    return MetricValue(scores=scores, aggregate_results=aggregates)
150
+
151
+
152
def _flesch_kincaid_eval_fn(predictions, targets=None, metrics=None):
    """Score each prediction with ``textstat.flesch_kincaid_grade``.

    Args:
        predictions: Collection of model output strings.
        targets: Unused.
        metrics: Unused.

    Returns:
        A ``MetricValue`` with per-row scores and their standard aggregations, or ``None``
        (after a warning is logged) when the predictions are not valid text or ``textstat``
        cannot be imported.
    """
    if not _validate_text_data(predictions, "flesch_kincaid", predictions_col_specifier):
        return None

    try:
        import textstat
    except ImportError:
        _logger.warning(
            "Failed to import textstat for flesch kincaid metric, skipping metric logging. "
            "Please install textstat using 'pip install textstat'."
        )
        return None

    grade_levels = list(map(textstat.flesch_kincaid_grade, predictions))
    return MetricValue(
        scores=grade_levels,
        aggregate_results=standard_aggregations(grade_levels),
    )
170
+
171
+
172
def _ari_eval_fn(predictions, targets=None, metrics=None):
    """Score each prediction with ``textstat.automated_readability_index``.

    Args:
        predictions: Collection of model output strings.
        targets: Unused.
        metrics: Unused.

    Returns:
        A ``MetricValue`` with per-row scores and their standard aggregations, or ``None``
        (after a warning is logged) when the predictions are not valid text or ``textstat``
        cannot be imported.
    """
    if not _validate_text_data(predictions, "ari", predictions_col_specifier):
        return None

    try:
        import textstat
    except ImportError:
        _logger.warning(
            "Failed to import textstat for automated readability index metric, "
            "skipping metric logging. "
            "Please install textstat using 'pip install textstat'."
        )
        return None

    readability = list(map(textstat.automated_readability_index, predictions))
    return MetricValue(
        scores=readability,
        aggregate_results=standard_aggregations(readability),
    )
191
+
192
+
193
+ def _accuracy_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
194
+ if targets is not None and len(targets) != 0:
195
+ from sklearn.metrics import accuracy_score
196
+
197
+ acc = accuracy_score(y_true=targets, y_pred=predictions, sample_weight=sample_weight)
198
+ return MetricValue(aggregate_results={"exact_match": acc})
199
+
200
+
201
def _rouge1_eval_fn(predictions, targets=None, metrics=None):
    """Compute the per-row ``rouge1`` score of predictions against targets.

    Args:
        predictions: Collection of model output strings.
        targets: Collection of reference strings.
        metrics: Unused.

    Returns:
        A ``MetricValue`` with per-row scores and their standard aggregations, or ``None``
        (after a warning is logged) when either input is not valid text or the rouge metric
        cannot be loaded.
    """
    # Targets are validated first; predictions are only checked when the targets are valid.
    if not (
        _validate_text_data(targets, "rouge1", targets_col_specifier)
        and _validate_text_data(predictions, "rouge1", predictions_col_specifier)
    ):
        return None

    try:
        rouge = _cached_evaluate_load("rouge")
    except Exception as e:
        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
        return None

    rouge_output = rouge.compute(
        predictions=predictions,
        references=targets,
        rouge_types=["rouge1"],
        use_aggregator=False,
    )
    row_scores = rouge_output["rouge1"]
    return MetricValue(scores=row_scores, aggregate_results=standard_aggregations(row_scores))
223
+
224
+
225
def _rouge2_eval_fn(predictions, targets=None, metrics=None):
    """Compute the per-row ``rouge2`` score of predictions against targets.

    Args:
        predictions: Collection of model output strings.
        targets: Collection of reference strings.
        metrics: Unused.

    Returns:
        A ``MetricValue`` with per-row scores and their standard aggregations, or ``None``
        (after a warning is logged) when either input is not valid text or the rouge metric
        cannot be loaded.
    """
    # Targets are validated first; predictions are only checked when the targets are valid.
    if not (
        _validate_text_data(targets, "rouge2", targets_col_specifier)
        and _validate_text_data(predictions, "rouge2", predictions_col_specifier)
    ):
        return None

    try:
        rouge = _cached_evaluate_load("rouge")
    except Exception as e:
        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
        return None

    rouge_output = rouge.compute(
        predictions=predictions,
        references=targets,
        rouge_types=["rouge2"],
        use_aggregator=False,
    )
    row_scores = rouge_output["rouge2"]
    return MetricValue(scores=row_scores, aggregate_results=standard_aggregations(row_scores))
247
+
248
+
249
def _rougeL_eval_fn(predictions, targets=None, metrics=None):
    """Compute the per-row ``rougeL`` score of predictions against targets.

    Args:
        predictions: Collection of model output strings.
        targets: Collection of reference strings.
        metrics: Unused.

    Returns:
        A ``MetricValue`` with per-row scores and their standard aggregations, or ``None``
        (after a warning is logged) when either input is not valid text or the rouge metric
        cannot be loaded.
    """
    # Targets are validated first; predictions are only checked when the targets are valid.
    if not (
        _validate_text_data(targets, "rougeL", targets_col_specifier)
        and _validate_text_data(predictions, "rougeL", predictions_col_specifier)
    ):
        return None

    try:
        rouge = _cached_evaluate_load("rouge")
    except Exception as e:
        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
        return None

    rouge_output = rouge.compute(
        predictions=predictions,
        references=targets,
        rouge_types=["rougeL"],
        use_aggregator=False,
    )
    row_scores = rouge_output["rougeL"]
    return MetricValue(scores=row_scores, aggregate_results=standard_aggregations(row_scores))
271
+
272
+
273
def _rougeLsum_eval_fn(predictions, targets=None, metrics=None):
    """Compute the per-row ``rougeLsum`` score of predictions against targets.

    Args:
        predictions: Collection of model output strings.
        targets: Collection of reference strings.
        metrics: Unused.

    Returns:
        A ``MetricValue`` with per-row scores and their standard aggregations, or ``None``
        (after a warning is logged) when either input is not valid text or the rouge metric
        cannot be loaded.
    """
    # Targets are validated first; predictions are only checked when the targets are valid.
    if not (
        _validate_text_data(targets, "rougeLsum", targets_col_specifier)
        and _validate_text_data(predictions, "rougeLsum", predictions_col_specifier)
    ):
        return None

    try:
        rouge = _cached_evaluate_load("rouge")
    except Exception as e:
        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
        return None

    rouge_output = rouge.compute(
        predictions=predictions,
        references=targets,
        rouge_types=["rougeLsum"],
        use_aggregator=False,
    )
    row_scores = rouge_output["rougeLsum"]
    return MetricValue(scores=row_scores, aggregate_results=standard_aggregations(row_scores))
295
+
296
+
297
+ def _mae_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
298
+ if targets is not None and len(targets) != 0:
299
+ from sklearn.metrics import mean_absolute_error
300
+
301
+ mae = mean_absolute_error(targets, predictions, sample_weight=sample_weight)
302
+ return MetricValue(aggregate_results={"mean_absolute_error": mae})
303
+
304
+
305
+ def _mse_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
306
+ if targets is not None and len(targets) != 0:
307
+ from sklearn.metrics import mean_squared_error
308
+
309
+ mse = mean_squared_error(targets, predictions, sample_weight=sample_weight)
310
+ return MetricValue(aggregate_results={"mean_squared_error": mse})
311
+
312
+
313
def _root_mean_squared_error(*, y_true, y_pred, sample_weight):
    """Compute the root mean squared error with whichever sklearn API is available.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values.
        sample_weight: Per-sample weights forwarded to sklearn.

    Returns:
        The value returned by ``sklearn.metrics.root_mean_squared_error`` when that function
        can be imported, otherwise by ``mean_squared_error(..., squared=False)``.
    """
    try:
        from sklearn.metrics import root_mean_squared_error as compute_rmse
    except ImportError:
        # If root_mean_squared_error is unavailable, fall back to
        # `mean_squared_error(..., squared=False)`, which is deprecated in scikit-learn >= 1.4.
        from sklearn.metrics import mean_squared_error

        compute_rmse = functools.partial(mean_squared_error, squared=False)

    return compute_rmse(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
326
+
327
+
328
+ def _rmse_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
329
+ if targets is not None and len(targets) != 0:
330
+ rmse = _root_mean_squared_error(
331
+ y_true=targets, y_pred=predictions, sample_weight=sample_weight
332
+ )
333
+ return MetricValue(aggregate_results={"root_mean_squared_error": rmse})
334
+
335
+
336
+ def _r2_score_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
337
+ if targets is not None and len(targets) != 0:
338
+ from sklearn.metrics import r2_score
339
+
340
+ r2 = r2_score(targets, predictions, sample_weight=sample_weight)
341
+ return MetricValue(aggregate_results={"r2_score": r2})
342
+
343
+
344
+ def _max_error_eval_fn(predictions, targets=None, metrics=None):
345
+ if targets is not None and len(targets) != 0:
346
+ from sklearn.metrics import max_error
347
+
348
+ error = max_error(targets, predictions)
349
+ return MetricValue(aggregate_results={"max_error": error})
350
+
351
+
352
+ def _mape_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
353
+ if targets is not None and len(targets) != 0:
354
+ from sklearn.metrics import mean_absolute_percentage_error
355
+
356
+ mape = mean_absolute_percentage_error(targets, predictions, sample_weight=sample_weight)
357
+ return MetricValue(aggregate_results={"mean_absolute_percentage_error": mape})
358
+
359
+
360
+ def _recall_eval_fn(
361
+ predictions, targets=None, metrics=None, pos_label=1, average="binary", sample_weight=None
362
+ ):
363
+ if targets is not None and len(targets) != 0:
364
+ from sklearn.metrics import recall_score
365
+
366
+ recall = recall_score(
367
+ targets, predictions, pos_label=pos_label, average=average, sample_weight=sample_weight
368
+ )
369
+ return MetricValue(aggregate_results={"recall_score": recall})
370
+
371
+
372
+ def _precision_eval_fn(
373
+ predictions, targets=None, metrics=None, pos_label=1, average="binary", sample_weight=None
374
+ ):
375
+ if targets is not None and len(targets) != 0:
376
+ from sklearn.metrics import precision_score
377
+
378
+ precision = precision_score(
379
+ targets,
380
+ predictions,
381
+ pos_label=pos_label,
382
+ average=average,
383
+ sample_weight=sample_weight,
384
+ )
385
+ return MetricValue(aggregate_results={"precision_score": precision})
386
+
387
+
388
+ def _f1_score_eval_fn(
389
+ predictions, targets=None, metrics=None, pos_label=1, average="binary", sample_weight=None
390
+ ):
391
+ if targets is not None and len(targets) != 0:
392
+ from sklearn.metrics import f1_score
393
+
394
+ f1 = f1_score(
395
+ targets,
396
+ predictions,
397
+ pos_label=pos_label,
398
+ average=average,
399
+ sample_weight=sample_weight,
400
+ )
401
+ return MetricValue(aggregate_results={"f1_score": f1})
402
+
403
+
404
def _precision_at_k_eval_fn(k):
    """Build the evaluation function for the ``precision_at_k`` retrieval metric.

    Args:
        k: Number of top retrieved documents to consider; must be a positive integer.

    Returns:
        A function ``(predictions, targets)`` returning a ``MetricValue`` with per-row precision
        and standard aggregations (or ``None`` when validation of either input fails). When
        ``k`` is invalid, a warning is logged and ``noop`` is returned instead.
    """
    if not isinstance(k, int) or k <= 0:
        _logger.warning(
            f"Cannot calculate 'precision_at_k' for invalid parameter 'k'. "
            f"'k' should be a positive integer; found: {k}. Skipping metric logging."
        )
        return noop

    def _fn(predictions, targets):
        # Predictions are validated first; targets are only checked when predictions are valid.
        if not (
            _validate_array_like_id_data(predictions, "precision_at_k", predictions_col_specifier)
            and _validate_array_like_id_data(targets, "precision_at_k", targets_col_specifier)
        ):
            return None

        scores = []
        for target, prediction in zip(targets, predictions):
            relevant_ids = set(target)
            # only include the top k retrieved chunks
            top_k = prediction[:k]
            if len(top_k) == 0:
                # when no documents are retrieved, precision is 0
                scores.append(0)
                continue
            hits = sum(1 for doc in top_k if doc in relevant_ids)
            scores.append(hits / len(top_k))

        return MetricValue(scores=scores, aggregate_results=standard_aggregations(scores))

    return _fn
433
+
434
+
435
+ def _expand_duplicate_retrieved_docs(predictions, targets):
436
+ counter = {}
437
+ expanded_predictions = []
438
+ expanded_targets = targets
439
+ for doc_id in predictions:
440
+ if doc_id not in counter:
441
+ counter[doc_id] = 1
442
+ expanded_predictions.append(doc_id)
443
+ else:
444
+ counter[doc_id] += 1
445
+ new_doc_id = (
446
+ f"{doc_id}_bc574ae_{counter[doc_id]}" # adding a random string to avoid collisions
447
+ )
448
+ expanded_predictions.append(new_doc_id)
449
+ if doc_id in expanded_targets:
450
+ expanded_targets.add(new_doc_id)
451
+ return expanded_predictions, expanded_targets
452
+
453
+
454
def _prepare_row_for_ndcg(predictions, targets):
    """Convert one row of retrieved and ground-truth doc IDs into arrays for ndcg calculation.

    Args:
        predictions: A list of strings of at most k doc IDs retrieved.
        targets: A list of strings of ground-truth doc IDs.

    Returns:
        A ``(y_score, y_true)`` tuple, in that order. Both are float32 ndarrays of shape
        ``(1, n_labels)``, where ``n_labels`` is the number of distinct doc IDs in the union of
        the duplicate-expanded predictions and targets, but never less than 2.

        - ``y_score`` holds ``1 - i * 1e-6`` in the column of the ``i``-th retrieved doc and
          0 elsewhere.
        - ``y_true`` holds 1 in the column of every ground-truth doc and 0 elsewhere.
    """
    # sklearn does an internal sort of y_score, so to preserve the order of our retrieved
    # docs, we need to modify the relevance value slightly
    rank_penalty = 1e-6

    # support predictions containing duplicate doc ID
    retrieved, relevant = _expand_duplicate_retrieved_docs(predictions, set(targets))

    column_of = {doc_id: col for col, doc_id in enumerate(relevant.union(retrieved))}
    # sklearn.metrics.ndcg_score requires at least 2 labels
    width = max(len(column_of), 2)
    y_true = np.zeros((1, width), dtype=np.float32)
    y_score = np.zeros((1, width), dtype=np.float32)

    for rank, doc_id in enumerate(retrieved):
        # Higher-ranked docs get a marginally higher score; all scores stay approximately 1.
        y_score[0, column_of[doc_id]] = 1 - rank * rank_penalty
    for doc_id in relevant:
        y_true[0, column_of[doc_id]] = 1
    return y_score, y_true
487
+
488
+
489
def _ndcg_at_k_eval_fn(k):
    """Build the evaluation function for the ``ndcg_at_k`` retrieval metric.

    Args:
        k: Number of top retrieved documents to consider; must be a positive integer.

    Returns:
        A function ``(predictions, targets)`` returning a ``MetricValue`` with per-row NDCG
        scores and standard aggregations (or ``None`` when validation of either input fails).
        When ``k`` is invalid, a warning is logged and ``noop`` is returned instead.
    """
    if not isinstance(k, int) or k <= 0:
        _logger.warning(
            f"Cannot calculate 'ndcg_at_k' for invalid parameter 'k'. "
            f"'k' should be a positive integer; found: {k}. Skipping metric logging."
        )
        return noop

    def _fn(predictions, targets):
        from sklearn.metrics import ndcg_score

        # Predictions are validated first; targets are only checked when predictions are valid.
        if not (
            _validate_array_like_id_data(predictions, "ndcg_at_k", predictions_col_specifier)
            and _validate_array_like_id_data(targets, "ndcg_at_k", targets_col_specifier)
        ):
            return None

        scores = []
        for ground_truth, retrieved in zip(targets, predictions):
            nothing_retrieved = len(retrieved) == 0
            nothing_expected = len(ground_truth) == 0

            if nothing_retrieved and nothing_expected:
                # 1. No ground truth doc IDs and no retrieved documents: no error is made.
                scores.append(1)
            elif nothing_retrieved or nothing_expected:
                # 2. No ground truth doc IDs but documents are retrieved: the score is 0.
                # 3. Ground truth doc IDs but no documents are retrieved: the score is 0.
                scores.append(0)
            else:
                # only include the top k retrieved chunks
                top_k = retrieved[:k]
                y_score, y_true = _prepare_row_for_ndcg(top_k, ground_truth)
                scores.append(ndcg_score(y_true, y_score, k=len(top_k), ignore_ties=True))

        return MetricValue(scores=scores, aggregate_results=standard_aggregations(scores))

    return _fn
528
+
529
+
530
def _recall_at_k_eval_fn(k):
    """Build the evaluation function for the ``recall_at_k`` retrieval metric.

    Args:
        k: Number of top retrieved documents to consider; must be a positive integer.

    Returns:
        A function ``(predictions, targets)`` returning a ``MetricValue`` with per-row recall
        and standard aggregations (or ``None`` when validation of either input fails). When
        ``k`` is invalid, a warning is logged and ``noop`` is returned instead.
    """
    if not isinstance(k, int) or k <= 0:
        _logger.warning(
            f"Cannot calculate 'recall_at_k' for invalid parameter 'k'. "
            f"'k' should be a positive integer; found: {k}. Skipping metric logging."
        )
        return noop

    def _row_recall(target, prediction):
        """Recall of the top-k distinct retrieved IDs against the distinct target IDs."""
        relevant_ids = set(target)
        # only include the top k retrieved chunks
        top_k_ids = set(prediction[:k])
        hits = len(relevant_ids.intersection(top_k_ids))
        if len(relevant_ids) > 0:
            return hits / len(relevant_ids)
        # With no ground truth: reward an empty retrieval (1), penalize a non-empty one (0).
        return 1 if len(top_k_ids) == 0 else 0

    def _fn(predictions, targets):
        # Predictions are validated first; targets are only checked when predictions are valid.
        if not (
            _validate_array_like_id_data(predictions, "recall_at_k", predictions_col_specifier)
            and _validate_array_like_id_data(targets, "recall_at_k", targets_col_specifier)
        ):
            return None

        scores = [
            _row_recall(target, prediction) for target, prediction in zip(targets, predictions)
        ]
        return MetricValue(scores=scores, aggregate_results=standard_aggregations(scores))

    return _fn
562
+
563
+
564
def _bleu_eval_fn(predictions, targets=None, metrics=None):
    """Compute a BLEU score for each prediction against its single reference target.

    Rows whose prediction or target is an empty string, and rows for which the BLEU
    computation raises, receive a score of 0 so that the scores stay aligned with the inputs.

    Args:
        predictions: Collection of model output strings.
        targets: Collection of reference strings.
        metrics: Unused.

    Returns:
        A ``MetricValue`` with per-row BLEU scores and their standard aggregations, or ``None``
        (after logging) when either input is not valid text, the bleu metric cannot be loaded,
        or no scores were produced.
    """
    # Validate input data
    if not _validate_text_data(targets, "bleu", targets_col_specifier):
        # Single-line message: a triple-quoted literal would embed a newline and the source
        # indentation in the emitted log record.
        _logger.error(
            "Target validation failed. Ensure targets are valid for BLEU computation."
        )
        return None
    if not _validate_text_data(predictions, "bleu", predictions_col_specifier):
        _logger.error(
            "Prediction validation failed. Ensure predictions are valid for BLEU computation."
        )
        return None

    # Load BLEU metric
    try:
        bleu = _cached_evaluate_load("bleu")
    except Exception as e:
        _logger.warning(f"Failed to load 'bleu' metric (error: {e!r}), skipping metric logging.")
        return None

    # Calculate BLEU scores for each prediction-target pair
    result = []
    invalid_indices = []

    for i, (prediction, target) in enumerate(zip(predictions, targets)):
        if len(target) == 0 or len(prediction) == 0:
            invalid_indices.append(i)
            result.append(0)  # Append 0 as a placeholder for invalid entries
            continue

        try:
            score = bleu.compute(predictions=[prediction], references=[[target]])
            result.append(score["bleu"])
        except Exception as e:
            _logger.warning(f"Failed to calculate BLEU for row {i} (error: {e!r}). Skipping.")
            result.append(0)  # Append 0 for consistency if an unexpected error occurs

    # Log warning for any invalid indices
    if invalid_indices:
        _logger.warning(
            f"BLEU score calculation skipped for the following indices "
            f"due to empty target or prediction: {invalid_indices}. "
            f"A score of 0 was appended for these entries."
        )

    # Defensive: validation above rejects empty inputs, so this is not normally reached.
    if not result:
        _logger.warning("No BLEU scores were calculated due to input errors.")
        return None

    return MetricValue(
        scores=result,
        aggregate_results=standard_aggregations(result),
    )
mlflow/mismatch.py ADDED
@@ -0,0 +1,34 @@
1
+ import importlib.metadata
2
+ import warnings
3
+ from typing import Optional
4
+
5
+
6
+ def _get_version(package_name: str) -> Optional[str]:
7
+ try:
8
+ return importlib.metadata.version(package_name)
9
+ except importlib.metadata.PackageNotFoundError:
10
+ return None
11
+
12
+
13
def _check_version_mismatch() -> None:
    """
    Warns if both mlflow and mlflow-skinny are installed but their versions are different.

    No warning is issued when either package is missing or when either version string
    contains ``"dev"``.

    Reference: https://github.com/pypa/pip/issues/4625
    """
    mlflow_ver = _get_version("mlflow")
    if not mlflow_ver or "dev" in mlflow_ver:
        return

    # Only looked up once the mlflow version has passed the checks above.
    skinny_ver = _get_version("mlflow-skinny")
    if not skinny_ver or "dev" in skinny_ver:
        return

    if mlflow_ver == skinny_ver:
        return

    message = (
        f"Versions of mlflow ({mlflow_ver}) and mlflow-skinny ({skinny_ver}) "
        "are different. This may lead to unexpected behavior. "
        "Please install the same version of both packages."
    )
    warnings.warn(message, category=UserWarning, stacklevel=2)
@@ -0,0 +1,34 @@
1
+ from mlflow.mistral.autolog import patched_class_call
2
+ from mlflow.utils.annotations import experimental
3
+ from mlflow.utils.autologging_utils import autologging_integration, safe_patch
4
+
5
+ FLAVOR_NAME = "mistral"
6
+
7
+
8
@experimental(version="2.21.0")
@autologging_integration(FLAVOR_NAME)
def autolog(
    log_traces: bool = True,
    disable: bool = False,
    silent: bool = False,
):
    """
    Enables (or disables) and configures autologging from Mistral AI to MLflow.
    Only synchronous calls to the Text generation API are supported.
    Asynchronous APIs and streaming are not recorded.

    Args:
        log_traces: If ``True``, traces are logged for Mistral AI models.
            If ``False``, no traces are collected during inference. Default to ``True``.
        disable: If ``True``, disables the Mistral AI autologging. Default to ``False``.
        silent: If ``True``, suppress all event logs and warnings from MLflow during Mistral AI
            autologging. If ``False``, show all events and warnings.
    """
    # Imported lazily so that `mistralai` is only required when autologging is requested.
    from mistralai.chat import Chat

    # Patch the synchronous `Chat.complete` method with the tracing wrapper.
    safe_patch(FLAVOR_NAME, Chat, "complete", patched_class_call)