genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,1104 @@
1
+ import abc
2
+ import functools
3
+ import inspect
4
+ import itertools
5
+ import uuid
6
+ from contextlib import asynccontextmanager, contextmanager
7
+ from typing import Any, Callable, NamedTuple, Optional
8
+
9
+ import mlflow
10
+ import mlflow.utils.autologging_utils
11
+ from mlflow.entities.run_status import RunStatus
12
+ from mlflow.environment_variables import _MLFLOW_AUTOLOGGING_TESTING
13
+ from mlflow.exceptions import MlflowException
14
+ from mlflow.utils import gorilla, is_iterator
15
+ from mlflow.utils.autologging_utils import _logger
16
+ from mlflow.utils.autologging_utils.events import AutologgingEventLoggerWrapper
17
+ from mlflow.utils.autologging_utils.logging_and_warnings import (
18
+ MlflowEventsAndWarningsBehaviorGlobally,
19
+ NonMlflowWarningsBehaviorForCurrentThread,
20
+ )
21
+ from mlflow.utils.mlflow_tags import MLFLOW_AUTOLOGGING
22
+
23
# Module-level registry of applied autologging patches, keyed by autologging integration
# name. `revert_patches` iterates the entry stored under an integration and then removes it.
_AUTOLOGGING_PATCHES = {}


# Function attribute used for testing purposes to verify that a given function
# has been wrapped with the `exception_safe_function_for_class` and
# `picklable_exception_safe_function` decorators
_ATTRIBUTE_EXCEPTION_SAFE = "exception_safe"


# Template for the warning logged when patch code fails; it is formatted with the
# autologging integration name followed by the error that was raised.
_ERROR_MSG = "Encountered unexpected error during {} autologging: {}"
33
+
34
+
35
def exception_safe_function_for_class(function):
    """
    Return a wrapper around `function` that guards against unexpected errors raised
    during autologging.

    Outside of test mode, any `Exception` raised by `function` is logged as a warning and
    the wrapper returns `None`. In test mode the exception is re-raised unchanged.

    The wrapper is defined locally, so it is not picklable. A class instance whose methods
    are decorated with this function should still be picklable, because pickle only saves
    instance attributes, not methods.
    See https://docs.python.org/3/library/pickle.html#pickling-class-instances for details.
    """
    # In test mode, tag the wrapped function so tests can verify it was made exception safe.
    if is_testing():
        setattr(function, _ATTRIBUTE_EXCEPTION_SAFE, True)

    def guarded(*args, **kwargs):
        try:
            return function(*args, **kwargs)
        except Exception as e:
            # Test mode surfaces the failure; otherwise it is downgraded to a warning
            # and the call falls through to an implicit `None` result.
            if is_testing():
                raise
            _logger.warning("Encountered unexpected error during autologging: %s", e)

    return update_wrapper_extended(guarded, function)
57
+
58
+
59
def _safe_function(function, *args, **kwargs):
    """
    Call `function(*args, **kwargs)` with broad exception handling.

    Returns the result of `function`. If `function` raises an `Exception`, the error is
    re-raised in test mode; otherwise it is logged as a warning and `None` is returned.
    """
    try:
        outcome = function(*args, **kwargs)
    except Exception as e:
        if is_testing():
            raise
        _logger.warning("Encountered unexpected error during autologging: %s", e)
        return None
    return outcome
67
+
68
+
69
def picklable_exception_safe_function(function):
    """
    Return an exception-safe wrapper around `function` that stays picklable.

    The wrapper is a `functools.partial` over the module-level `_safe_function` rather than
    a locally defined closure.
    """
    # In test mode, tag the wrapped function so tests can verify it was made exception safe.
    if is_testing():
        setattr(function, _ATTRIBUTE_EXCEPTION_SAFE, True)

    guarded = functools.partial(_safe_function, function)
    return update_wrapper_extended(guarded, function)
78
+
79
+
80
def _exception_safe_class_factory(base_class):
    """
    Build and return an exception safe metaclass that derives from `base_class`.
    """

    class _ExceptionSafeClass(base_class):
        """
        Metaclass that wraps every callable defined in a class body with broad error
        handling, guarding against unexpected errors during autologging.

        Rationale: Patched autologging functions commonly pass additional class instances
        as arguments to their underlying original training routines; for example, Keras
        autologging constructs a subclass of `keras.callbacks.Callback` and forwards it to
        `Model.fit()`. Wrapping all functions of such classes in a broad try / catch keeps
        errors raised inside their methods from disrupting model training.

        Note: class methods and static methods are not wrapped, as they are not always
        Python callables and are difficult to wrap.
        """

        def __new__(cls, name, bases, dct):
            # Collect the names first, then wrap; non-callable entries (which include
            # classmethod / staticmethod objects that fail `callable`) are left untouched.
            callable_names = [attr for attr in dct if callable(dct[attr])]
            for attr in callable_names:
                dct[attr] = exception_safe_function_for_class(dct[attr])
            return base_class.__new__(cls, name, bases, dct)

    return _ExceptionSafeClass
108
+
109
+
110
# Exception safe metaclass built on plain `type`, for classes without an ABC base.
ExceptionSafeClass = _exception_safe_class_factory(type)

# `ExceptionSafeClass` causes an error when used with an abstract class.
#
# ```
# class AbstractClass(abc.ABC):
#     ...
#
# class DerivedClass(AbstractClass, metaclass=ExceptionSafeClass):
#     ...
# ```
#
# This raises:
#
# ```
# TypeError: metaclass conflict: the metaclass of a derived class must be
#            a (non-strict) subclass of the metaclasses of all its bases.
# ```
#
# To avoid this error, create `ExceptionSafeAbstractClass` that is based on `abc.ABCMeta`.
ExceptionSafeAbstractClass = _exception_safe_class_factory(abc.ABCMeta)
131
+
132
+
133
def with_managed_run(autologging_integration, patch_function, tags=None):
    """Wrap `patch_function` so that it executes within an active MLflow run.

    The returned function behaves as follows:

    - A run is started only when no run is active and no training session is active at
      the time the wrapper is invoked.

    - A run started by the wrapper is ended with the `FINISHED` status when
      `patch_function` returns.

    - A run started by the wrapper is ended with the `FAILED` status when
      `patch_function` raises an `Exception` or `KeyboardInterrupt`; the exception is
      then re-raised.

    Nested or non-fluent runs created by `patch_function` are not managed here;
    `patch_function` is responsible for terminating them before it returns or raises.

    Args:
        autologging_integration: The autologging integration associated
            with the `patch_function`.
        patch_function: A function object compatible with `safe_patch`.
        tags: A dictionary of string tags to set on each managed run created during the
            execution of `patch_function`.

    Returns:
        A function with the signature `(original, *args, **kwargs)` that returns whatever
        `patch_function` returns.
    """
    from mlflow.tracking.fluent import active_run
    from mlflow.utils.autologging_utils import _has_active_training_session

    def patch_with_managed_run(original, *args, **kwargs):
        managed_run = None
        # If there is an active training session but no active run in the current thread,
        # the thread was spawned by `estimator.fit` as a worker thread. Autologging should
        # be disabled in such worker threads, so no managed run is created for them.
        if not (active_run() or _has_active_training_session()):
            managed_run = mlflow.start_run(tags=tags)
            _logger.info(
                "Created MLflow autologging run with ID '%s', which will track hyperparameters,"
                " performance metrics, model artifacts, and lineage information for the"
                " current %s workflow",
                managed_run.info.run_id,
                autologging_integration,
            )

        try:
            result = patch_function(original, *args, **kwargs)
        except (Exception, KeyboardInterrupt):
            # Keyboard interrupts are handled alongside standard exceptions so that runs
            # are terminated if a user interrupts training (e.g. via sigint / ctrl-c).
            if managed_run:
                mlflow.end_run(RunStatus.to_string(RunStatus.FAILED))
            raise

        if managed_run:
            mlflow.end_run(RunStatus.to_string(RunStatus.FINISHED))
        return result

    return patch_with_managed_run
195
+
196
+
197
def is_testing():
    """
    Report whether autologging is running in test mode, as given by the
    `MLFLOW_AUTOLOGGING_TESTING` environment variable.

    Test mode performs additional validation during autologging, including:

    - Checking the exception safety of arguments passed to model training functions
      (i.e. all additional arguments should be "exception safe" functions or classes)
    - Disabling exception handling for patched function logic, so that errors in patch
      code surface during testing instead of being logged as warnings
    """
    testing_flag = _MLFLOW_AUTOLOGGING_TESTING.get()
    return testing_flag
209
+
210
+
211
def _resolve_extra_tags(autologging_integration, extra_tags):
    """
    Build the tag dictionary applied to runs managed by autologging.

    The result always maps `MLFLOW_AUTOLOGGING` to `autologging_integration`. Entries from
    `extra_tags` are merged in; a user-supplied `MLFLOW_AUTOLOGGING` entry is removed from
    `extra_tags` itself (in place) and a warning is logged.

    Args:
        autologging_integration: Name of the autologging integration.
        extra_tags: Optional dictionary of additional tags. Falsy values are ignored.

    Returns:
        A new dictionary of tags.

    Raises:
        MlflowException: If `extra_tags` is truthy but is not a dictionary.
    """
    resolved = {MLFLOW_AUTOLOGGING: autologging_integration}
    if not extra_tags:
        return resolved

    if not isinstance(extra_tags, dict):
        raise mlflow.exceptions.MlflowException.invalid_parameter_value(
            f"Invalid `extra_tags` type: expecting dictionary, "
            f"received `{type(extra_tags).__name__}`"
        )

    if MLFLOW_AUTOLOGGING in extra_tags:
        # The reserved tag is dropped from the caller's dictionary, not from a copy.
        extra_tags.pop(MLFLOW_AUTOLOGGING)
        _logger.warning(
            f"Tag `{MLFLOW_AUTOLOGGING}` is ignored as it is a reserved tag by MLflow "
            f"autologging."
        )
    resolved.update(extra_tags)
    return resolved
228
+
229
+
230
def safe_patch(
    autologging_integration,
    destination,
    function_name,
    patch_function,
    manage_run=False,
    extra_tags=None,
):
    """Patches the specified `function_name` on the specified `destination` class for autologging
    purposes, preceding its implementation with an error-safe copy of the specified patch
    `patch_function` with the following error handling behavior:
        - Exceptions thrown from the underlying / original function
          (`<destination>.<function_name>`) are propagated to the caller.
        - Exceptions thrown from other parts of the patched implementation (`patch_function`)
          are caught and logged as warnings.

    Args:
        autologging_integration: The name of the autologging integration associated with the
            patch.
        destination: The Python class on which the patch is being defined.
        function_name: The name of the function to patch on the specified `destination` class.
        patch_function: The patched function code to apply. The first argument should be reserved
            for an `original` argument representing the underlying / original function. Subsequent
            arguments should be identical to those of the original function being patched.
        manage_run: If `True`, applies the `with_managed_run` wrapper to the specified
            `patch_function`, which automatically creates & terminates an MLflow
            active run during patch code execution if necessary. If `False`,
            does not apply the `with_managed_run` wrapper to the specified
            `patch_function`.
        extra_tags: A dictionary of extra tags to set on each managed run created by autologging.

    Raises:
        MlflowException: If `manage_run` is `True` for an async `patch_function`, or if an
            async `patch_function` targets a property.
        RuntimeError: If the attribute resolved through the descriptor protocol differs from
            the raw attribute on `destination`.
    """
    from mlflow.tracking.fluent import active_run
    from mlflow.utils.autologging_utils import autologging_is_disabled, get_autologging_config

    # NB: Checking the signature of the patch function rather than original, so that we don't
    # accidentally change the behavior of existing patches that may use sync patch function
    # for async original functions (e.g. LangChain).
    is_async_function = inspect.iscoroutinefunction(patch_function)

    if manage_run:
        if is_async_function:
            raise MlflowException("manage_run parameter is not supported for async functions.")

        tags = _resolve_extra_tags(autologging_integration, extra_tags)
        patch_function = with_managed_run(
            autologging_integration,
            patch_function,
            tags=tags,
        )

    original_fn = gorilla.get_original_attribute(
        destination, function_name, bypass_descriptor_protocol=False
    )
    # Retrieve raw attribute while bypassing the descriptor protocol
    raw_original_obj = gorilla.get_original_attribute(
        destination, function_name, bypass_descriptor_protocol=True
    )
    if original_fn != raw_original_obj:
        raise RuntimeError(f"Unsupported patch on {destination}.{function_name}")
    elif isinstance(original_fn, property):
        if is_async_function:
            raise MlflowException("Patching async property methods is not supported.")

        is_property_method = True

        # For property decorated methods (a kind of method delegation), e.g.
        # class A:
        #   @property
        #   def f1(self):
        #     ...
        #     return delegated_f1
        #
        # suppose `a1` is an instance of class `A`,
        # `A.f1.fget` will get the original `def f1(self)` method,
        # and `A.f1.fget(a1)` will be equivalent to `a1.f1()` and
        # its return value will be the `delegated_f1` function.
        # So using the `property.fget` we can construct the (delegated) "original_fn"
        def original(self, *args, **kwargs):
            # the `original_fn.fget` will get the original method decorated by `property`
            # the `original_fn.fget(self)` will get the delegated function returned by the
            # property decorated method.
            bound_delegate_method = original_fn.fget(self)
            return bound_delegate_method(*args, **kwargs)

    else:
        original = original_fn
        is_property_method = False

    def safe_patch_function(*args, **kwargs):
        """
        A safe wrapper around the specified `patch_function` implementation designed to
        handle exceptions thrown during the execution of `patch_function`. This wrapper
        distinguishes exceptions thrown from the underlying / original function
        (`<destination>.<function_name>`) from exceptions thrown from other parts of
        `patch_function`. This distinction is made by passing an augmented version of the
        underlying / original function to `patch_function` that uses nonlocal state to track
        whether or not it has been executed and whether or not it threw an exception.
        Exceptions thrown from the underlying / original function are propagated to the caller,
        while exceptions thrown from other parts of `patch_function` are caught and logged as
        warnings.

        NB: PLEASE BE SUPER CAREFUL WHEN MODIFYING THIS FUNCTION. IT IS USED IN A WIDE VARIETY
        OF CONTEXTS AND CRITICAL PATH IN DBR/MLR BY DEFAULT. ANY BUG HERE CAN BREAK USERS'
        WORKLOAD WITHOUT THEM TAKING ANY ACTION.
        """
        # Reroute warnings encountered during the patch function implementation to an MLflow event
        # logger, and enforce silent mode if applicable (i.e. if the corresponding autologging
        # integration was called with `silent=True`), hiding MLflow event logging statements and
        # hiding all warnings in the autologging preamble and postamble (i.e. the code surrounding
        # the user's original / underlying ML function). Non-MLflow warnings are enabled during the
        # execution of the original / underlying ML function
        #
        # Note that we've opted *not* to apply this context manager as a decorator on
        # `safe_patch_function` because the context-manager-as-decorator pattern uses
        # `contextlib.ContextDecorator`, which creates generator expressions that cannot be pickled
        # during model serialization by ML frameworks such as scikit-learn
        is_silent_mode = get_autologging_config(autologging_integration, "silent", False)
        with (
            MlflowEventsAndWarningsBehaviorGlobally(
                # MLflow warnings emitted during autologging training sessions are likely not
                # actionable and result from the autologging implementation invoking another MLflow
                # API. Accordingly, we reroute these warnings to the MLflow event logger with level
                # WARNING For reference, see recommended warning and event logging behaviors from
                # https://docs.python.org/3/howto/logging.html#when-to-use-logging
                reroute_warnings=True,
                disable_event_logs=is_silent_mode,
                disable_warnings=is_silent_mode,
            ),
            NonMlflowWarningsBehaviorForCurrentThread(
                # non-MLflow Warnings emitted during the autologging preamble (before the original /
                # underlying ML function is called) and postamble (after the original / underlying
                # ML function is called) are likely not actionable and result from the autologging
                # implementation invoking an API from a dependent library. Accordingly, we reroute
                # these warnings to the MLflow event logger with level WARNING. For reference, see
                # recommended warning and event logging behaviors from
                # https://docs.python.org/3/howto/logging.html#when-to-use-logging
                reroute_warnings=True,
                disable_warnings=is_silent_mode,
            ),
        ):
            # Only bound in test mode; it is read again below under the same `is_testing()` guard.
            if is_testing():
                preexisting_run_for_testing = active_run()

            # Whether or not to exclude autologged content from user-created fluent runs
            # (i.e. runs created manually via `mlflow.start_run()`)
            exclusive = get_autologging_config(autologging_integration, "exclusive", False)
            user_created_fluent_run_is_active = (
                active_run() and not _AutologgingSessionManager.active_session()
            )
            active_session_failed = (
                _AutologgingSessionManager.active_session() is not None
                and _AutologgingSessionManager.active_session().state == "failed"
            )

            if (
                active_session_failed
                or autologging_is_disabled(autologging_integration)
                or (user_created_fluent_run_is_active and exclusive)
                or (
                    mlflow.utils.autologging_utils._AUTOLOGGING_GLOBALLY_DISABLED
                    and autologging_integration
                )
            ):
                # If the autologging integration associated with this patch is disabled,
                # or if the current autologging integration is in exclusive mode and a user-created
                # fluent run is active, call the original function and return. Restore the original
                # warning behavior during original function execution, since autologging is being
                # skipped
                with NonMlflowWarningsBehaviorForCurrentThread(
                    disable_warnings=False,
                    reroute_warnings=False,
                ):
                    return original(*args, **kwargs)

            # Whether or not the original / underlying function has been called during the
            # execution of patched code
            original_has_been_called = False
            # The value returned by the call to the original / underlying function during
            # the execution of patched code
            original_result = None
            # Whether or not an exception was raised from within the original / underlying function
            # during the execution of patched code
            failed_during_original = False
            # The active MLflow run (if any) associated with patch code execution
            patch_function_run_for_testing = None
            # The exception raised during executing patching function
            patch_error = None

            with _AutologgingSessionManager.start_session(autologging_integration) as session:
                event_logger = AutologgingEventLoggerWrapper(session, destination, function_name)

                def call_original_fn_with_event_logging(original_fn, og_args, og_kwargs):
                    # Calls `original_fn`, emitting start / success / error events, and records
                    # in `failed_during_original` that the failure came from the original function.
                    try:
                        event_logger.log_original_function_start(og_args, og_kwargs)

                        original_fn_result = original_fn(*og_args, **og_kwargs)

                        event_logger.log_original_function_success(og_args, og_kwargs)
                        return original_fn_result
                    except Exception as e:
                        event_logger.log_original_function_error(og_args, og_kwargs, e)

                        nonlocal failed_during_original
                        failed_during_original = True
                        raise

                try:

                    def call_original(*og_args, **og_kwargs):
                        def _original_fn(*_og_args, **_og_kwargs):
                            if is_testing():
                                _validate_args(
                                    autologging_integration,
                                    function_name,
                                    args,
                                    kwargs,
                                    og_args,
                                    og_kwargs,
                                )
                                # By the time `original` is called by the patch implementation, we
                                # assume that either: 1. the patch implementation has already
                                # created an MLflow run or 2. the patch code will not create an
                                # MLflow run during the current execution. Here, we capture a
                                # reference to the active run, which we will use later on to
                                # determine whether or not the patch implementation created
                                # a run and perform validation if necessary
                                nonlocal patch_function_run_for_testing
                                patch_function_run_for_testing = active_run()

                            nonlocal original_has_been_called
                            original_has_been_called = True

                            nonlocal original_result
                            # Show all non-MLflow warnings as normal (i.e. not as event logs)
                            # during original function execution, even if silent mode is enabled
                            # (`silent=True`), since these warnings originate from the ML framework
                            # or one of its dependencies and are likely relevant to the caller
                            with NonMlflowWarningsBehaviorForCurrentThread(
                                disable_warnings=False,
                                reroute_warnings=False,
                            ):
                                original_result = original(*_og_args, **_og_kwargs)
                                return original_result

                        return call_original_fn_with_event_logging(_original_fn, og_args, og_kwargs)

                    # Apply the name, docstring, and signature of `original` to `call_original`.
                    # This is important because several autologging patch implementations inspect
                    # the signature of the `original` argument during execution
                    call_original = update_wrapper_extended(call_original, original)

                    event_logger.log_patch_function_start(args, kwargs)

                    # The return value of `patch_function` is discarded; the caller receives
                    # `original_result` (or a fresh call to `original`) further below.
                    patch_function(call_original, *args, **kwargs)

                    session.state = "succeeded"
                    event_logger.log_patch_function_success(args, kwargs)

                except Exception as e:
                    session.state = "failed"
                    patch_error = e
                    # Exceptions thrown during execution of the original function should be
                    # propagated to the caller. Additionally, exceptions encountered during test
                    # mode should be reraised to detect bugs in autologging implementations
                    if failed_during_original or is_testing():
                        raise

                if is_testing() and not preexisting_run_for_testing:
                    # If an MLflow run was created during the execution of patch code, verify that
                    # it is no longer active and that it contains expected autologging tags
                    assert not active_run(), (
                        f"Autologging integration {autologging_integration} leaked an active run"
                    )
                    if patch_function_run_for_testing:
                        _validate_autologging_run(
                            autologging_integration, patch_function_run_for_testing.info.run_id
                        )
                try:
                    if original_has_been_called:
                        return original_result
                    else:
                        # The patch code failed before reaching the original function, so the
                        # original is invoked here with the caller's arguments.
                        return call_original_fn_with_event_logging(original, args, kwargs)
                finally:
                    # If original function succeeds, but `patch_function_exception` exists,
                    # it represent patching code unexpected failure, so we call
                    # `log_patch_function_error` in this case.
                    # If original function failed, we don't call `log_patch_function_error`
                    # even if `patch_function_exception` exists, because original function failure
                    # means there's some error in user code (e.g. user provide wrong arguments)
                    if patch_error is not None and not failed_during_original:
                        event_logger.log_patch_function_error(args, kwargs, patch_error)
                        _logger.warning(_ERROR_MSG.format(autologging_integration, patch_error))

    async def async_safe_patch_function(*args, **kwargs):
        """
        Async version of safe_patch_function.

        This code brainlessly copies the synchronous version of the function, but with async
        context managers and async functions. This is done to avoid the risk of introducing
        any bugs or regressions in the async version of the function. Note that we need to
        be really careful here, because autologging is enabled by-default in DBR/MLR, hence
        any bug here can break users' workload without them taking any action.

        That said, some long comments are omitted in this version to avoid redundancy. If
        you want to understand the context of the code better, please refer to the
        synchronous version as well.
        """
        is_silent_mode = get_autologging_config(autologging_integration, "silent", False)
        async with (
            MlflowEventsAndWarningsBehaviorGlobally(
                reroute_warnings=True,
                disable_event_logs=is_silent_mode,
                disable_warnings=is_silent_mode,
            ),
            NonMlflowWarningsBehaviorForCurrentThread(
                disable_warnings=is_silent_mode,
                reroute_warnings=True,
            ),
        ):
            # Only bound in test mode; it is read again below under the same `is_testing()` guard.
            if is_testing():
                preexisting_run_for_testing = active_run()

            # Whether or not to exclude autologged content from user-created fluent runs
            # (i.e. runs created manually via `mlflow.start_run()`)
            exclusive = get_autologging_config(autologging_integration, "exclusive", False)
            user_created_fluent_run_is_active = (
                active_run() and not _AutologgingSessionManager.active_session()
            )
            active_session_failed = (
                _AutologgingSessionManager.active_session() is not None
                and _AutologgingSessionManager.active_session().state == "failed"
            )

            if (
                active_session_failed
                or autologging_is_disabled(autologging_integration)
                or (user_created_fluent_run_is_active and exclusive)
                or (
                    mlflow.utils.autologging_utils._AUTOLOGGING_GLOBALLY_DISABLED
                    and autologging_integration
                )
            ):
                # Autologging is skipped: await the original with normal warning behavior.
                async with NonMlflowWarningsBehaviorForCurrentThread(False, False):
                    return await original(*args, **kwargs)

            original_has_been_called = False
            original_result = None
            failed_during_original = False
            patch_function_run_for_testing = None
            patch_error = None

            async with _AutologgingSessionManager.astart_session(
                autologging_integration
            ) as session:
                event_logger = AutologgingEventLoggerWrapper(session, destination, function_name)

                async def call_original_fn_with_event_logging(original_fn, og_args, og_kwargs):
                    try:
                        event_logger.log_original_function_start(og_args, og_kwargs)
                        original_fn_result = await original_fn(*og_args, **og_kwargs)
                        event_logger.log_original_function_success(og_args, og_kwargs)
                        return original_fn_result
                    except Exception as e:
                        event_logger.log_original_function_error(og_args, og_kwargs, e)
                        nonlocal failed_during_original
                        failed_during_original = True
                        raise

                try:

                    async def call_original(*og_args, **og_kwargs):
                        async def _original_fn(*_og_args, **_og_kwargs):
                            if is_testing():
                                _validate_args(
                                    autologging_integration,
                                    function_name,
                                    args,
                                    kwargs,
                                    og_args,
                                    og_kwargs,
                                )
                                nonlocal patch_function_run_for_testing
                                patch_function_run_for_testing = active_run()

                            nonlocal original_has_been_called
                            original_has_been_called = True

                            nonlocal original_result
                            async with NonMlflowWarningsBehaviorForCurrentThread(False, False):
                                original_result = await original(*_og_args, **_og_kwargs)
                                return original_result

                        return await call_original_fn_with_event_logging(
                            _original_fn, og_args, og_kwargs
                        )

                    # Apply the name, docstring, and signature of `original` to `call_original`.
                    # This is important because several autologging patch implementations inspect
                    # the signature of the `original` argument during execution
                    call_original = update_wrapper_extended(call_original, original)

                    event_logger.log_patch_function_start(args, kwargs)

                    await patch_function(call_original, *args, **kwargs)

                    session.state = "succeeded"
                    event_logger.log_patch_function_success(args, kwargs)

                except Exception as e:
                    session.state = "failed"
                    patch_error = e
                    # Exceptions thrown during execution of the original function should be
                    # propagated to the caller. Additionally, exceptions encountered during test
                    # mode should be reraised to detect bugs in autologging implementations
                    if failed_during_original or is_testing():
                        raise

                if is_testing() and not preexisting_run_for_testing:
                    # If an MLflow run was created during the execution of patch code, verify that
                    # it is no longer active and that it contains expected autologging tags
                    assert not active_run(), (
                        f"Autologging integration {autologging_integration} leaked an active run"
                    )
                    if patch_function_run_for_testing:
                        _validate_autologging_run(
                            autologging_integration, patch_function_run_for_testing.info.run_id
                        )
                try:
                    if original_has_been_called:
                        return original_result
                    else:
                        return await call_original_fn_with_event_logging(original, args, kwargs)
                finally:
                    # Report a patch-code failure only when the original function itself did
                    # not fail (same rule as the synchronous version).
                    if patch_error is not None and not failed_during_original:
                        event_logger.log_patch_function_error(args, kwargs, patch_error)
                        _logger.warning(_ERROR_MSG.format(autologging_integration, patch_error))

    if is_property_method:
        # Create a patched function (also property decorated)
        # like:
        #
        # class A:
        # @property
        # def get_bound_safe_patch_fn(self):
        #   original_fn.fget(self) # do availability check
        #   return bound_safe_patch_fn
        #
        # Suppose `a1` is instance of class A,
        # then `a1.get_bound_safe_patch_fn(*args, **kwargs)` will be equivalent to
        # `bound_safe_patch_fn(*args, **kwargs)`
        def get_bound_safe_patch_fn(self):
            # This `original_fn.fget` call is for availability check, if it raise error
            # then `hasattr(obj, {func_name})` will return False
            # so it mimic the original property behavior.
            original_fn.fget(self)

            def bound_safe_patch_fn(*args, **kwargs):
                return safe_patch_function(self, *args, **kwargs)

            # Make bound method `instance.target_method` keep the same doc and signature.
            # Here return the bound safe patch function because user call property decorated
            # method will like `instance.property_decorated_method(...)`, and internally it will
            # call the `bound_safe_patch_fn`, the argument list don't include the `self` argument,
            # so return bound function here.
            return update_wrapper_extended(bound_safe_patch_fn, original_fn.fget)

        # Make unbound method `class.target_method` keep the same doc and signature
        get_bound_safe_patch_fn = update_wrapper_extended(get_bound_safe_patch_fn, original_fn.fget)
        safe_patch_obj = property(get_bound_safe_patch_fn)
    elif is_async_function:
        safe_patch_obj = update_wrapper_extended(async_safe_patch_function, original)
    else:
        safe_patch_obj = update_wrapper_extended(safe_patch_function, original)

    # Install the wrapper on `destination` and record the patch under the integration name.
    new_patch = _wrap_patch(destination, function_name, safe_patch_obj)
    _store_patch(autologging_integration, new_patch)
706
+
707
+
708
def revert_patches(autologging_integration):
    """Revert every patch recorded for the given autologging integration.

    Each patch stored in `_AUTOLOGGING_PATCHES` under `autologging_integration` is passed to
    `gorilla.revert`, after which the integration's entry is removed from the registry.
    Calling this for an integration with no recorded patches is a no-op.

    Args:
        autologging_integration: The name of the autologging integration whose patches
            should be reverted.

    """
    recorded_patches = _AUTOLOGGING_PATCHES.get(autologging_integration, [])
    for recorded_patch in recorded_patches:
        gorilla.revert(recorded_patch)

    _AUTOLOGGING_PATCHES.pop(autologging_integration, None)
721
+
722
+
723
class AutologgingSession:
    """Represents an active autologging session.

    A session carries three fields:
      - integration: the name of the autologging integration corresponding to the session
      - id: a unique session identifier (e.g., a UUID)
      - state: the state of the session, one of running/succeeded/failed
    """

    def __init__(self, integration, id_):
        self.integration, self.id = integration, id_
        # Every session starts out as "running".
        self.state = "running"
732
+
733
+
734
+ class _AutologgingSessionManager:
735
+ _session = None
736
+
737
+ @classmethod
738
+ @contextmanager
739
+ def start_session(cls, integration):
740
+ try:
741
+ prev_session = cls._session
742
+ if prev_session is None:
743
+ session_id = uuid.uuid4().hex
744
+ cls._session = AutologgingSession(integration, session_id)
745
+ yield cls._session
746
+ finally:
747
+ # Only end the session upon termination of the context if we created
748
+ # the session; otherwise, leave the session open for later termination
749
+ # by its creator
750
+ if prev_session is None:
751
+ cls._end_session()
752
+
753
+ @classmethod
754
+ @asynccontextmanager
755
+ async def astart_session(cls, integration):
756
+ try:
757
+ prev_session = cls._session
758
+ if prev_session is None:
759
+ session_id = uuid.uuid4().hex
760
+ cls._session = AutologgingSession(integration, session_id)
761
+ yield cls._session
762
+ finally:
763
+ if prev_session is None:
764
+ cls._end_session()
765
+
766
+ @classmethod
767
+ def active_session(cls):
768
+ return cls._session
769
+
770
+ @classmethod
771
+ def _end_session(cls):
772
+ cls._session = None
773
+
774
+
775
def update_wrapper_extended(wrapper, wrapped):
    """Make `wrapper` look like `wrapped`, including its call signature.

    This extends `functools.update_wrapper`: besides the metadata that function copies,
    the signature of `wrapped` is attached to `wrapper` as `__signature__`.

    Args:
        wrapper: The function to update. It is modified in place.
        wrapped: The function whose metadata and signature are applied.

    Returns:
        The `wrapper` object itself, now carrying the docstring and signature of `wrapped`.
    """
    # `functools.update_wrapper` mutates `wrapper` and returns that same object.
    functools.update_wrapper(wrapper, wrapped)

    # Some frameworks disallow signature inspection, causing `inspect.signature()` to throw
    # (one such example is `tensorflow.estimator.Estimator.export_savedmodel()`), so a
    # failure here is logged and otherwise ignored.
    try:
        wrapper.__signature__ = inspect.signature(wrapped)
    except Exception:
        _logger.debug("Failed to restore original signature for wrapper around %s", wrapped)

    return wrapper
793
+
794
+
795
def _wrap_patch(destination, name, patch_obj, settings=None):
    """Build a gorilla patch for `{destination}.{name}` and apply it.

    Args:
        destination: Patch destination.
        name: Name of the attribute at the destination.
        patch_obj: Patch object, it should be a function or a property decorated function
            to be assigned to the patch point {destination}.{name}.
        settings: Settings for gorilla.Patch. When omitted, settings with
            `allow_hit=True` and `store_hit=True` are used.

    Returns:
        The `gorilla.Patch` that was applied.
    """
    effective_settings = (
        gorilla.Settings(allow_hit=True, store_hit=True) if settings is None else settings
    )
    applied_patch = gorilla.Patch(destination, name, patch_obj, settings=effective_settings)
    gorilla.apply(applied_patch)
    return applied_patch
812
+
813
+
814
def _store_patch(autologging_integration, patch):
    """
    Record a patch under its autologging integration in `_AUTOLOGGING_PATCHES` so that it
    can be reverted later when autologging is disabled.

    Args:
        autologging_integration: The name of the autologging integration associated with the
            patch.
        patch: The patch to be stored.
    """
    # Create the integration's patch set on first use, then add to it.
    _AUTOLOGGING_PATCHES.setdefault(autologging_integration, set()).add(patch)
828
+
829
+
830
def _validate_autologging_run(autologging_integration, run_id):
    """
    For testing purposes, checks that the MLflow run with id `run_id` looks like a run
    produced by `autologging_integration`:

    - The run's `MLFLOW_AUTOLOGGING` tag equals the name of the autologging integration
    - The run has a terminal status (e.g., KILLED, FAILED, FINISHED)

    Args:
        autologging_integration: The name of the autologging integration expected in the tag.
        run_id: The id of the run to inspect.

    Raises:
        AssertionError: If either property does not hold.
    """
    from mlflow.tracking.client import MlflowClient

    run = MlflowClient().get_run(run_id)

    actual_tag = run.data.tags.get(MLFLOW_AUTOLOGGING)
    assert actual_tag == autologging_integration, (
        f"Autologging run with id {run_id} failed to set autologging tag with expected value. "
        f"Expected: '{autologging_integration}', Actual: '{actual_tag}'"
    )

    parsed_status = RunStatus.from_string(run.info.status)
    assert RunStatus.is_terminated(parsed_status), (
        f"Autologging run with id {run_id} has a non-terminal status '{run.info.status}'"
    )
850
+
851
+
852
class ValidationExemptArgument(NamedTuple):
    """
    A NamedTuple representing the properties of an argument that is exempt from validation

    autologging_integration: The name of the autologging integration.
    function_name: The name of the function that is being validated.
    type_function: A Callable that accepts an object and returns True if the given object matches
                   the argument type. Returns False otherwise.
    positional_argument_index: The index of the argument in the function signature, or None
                               if the exemption does not apply to positional usage.
    keyword_argument_name: The name of the argument in the function signature, or None if
                           the exemption does not apply to keyword usage.
    """

    autologging_integration: str
    function_name: str
    type_function: Callable[..., Any]
    positional_argument_index: Optional[int] = None
    keyword_argument_name: Optional[str] = None

    def matches(
        self,
        autologging_integration,
        function_name,
        value,
        argument_index=None,
        argument_name=None,
    ):
        """
        This method checks if the properties provided through the function arguments match the
        properties defined in the NamedTuple.

        Args:
            autologging_integration: The name of an autologging integration.
            function_name: The name of the function that is being matched.
            value: The value of the argument.
            argument_index: The index of the argument, if it is passed as a positional
                argument. Otherwise it is None.
            argument_name: The name of the argument, if it is passed as a keyword
                argument. Otherwise it is None.

        Returns:
            A truthy value if the integration and function name are equal, the supplied
            index or name identifies the exempt argument, and `type_function(value)` is
            truthy. A falsy value otherwise.
        """
        if (
            self.autologging_integration != autologging_integration
            or self.function_name != function_name
        ):
            return False

        # A selector the caller did not supply (None) must never match: comparing
        # None == None would otherwise make an exemption that only declares one selector
        # match every argument passed the other way.
        index_matches = (
            argument_index is not None and self.positional_argument_index == argument_index
        )
        name_matches = (
            argument_name is not None and self.keyword_argument_name == argument_name
        )

        # `type_function` is only consulted once the argument has been identified.
        return (index_matches or name_matches) and self.type_function(value)
904
+
905
+
906
# WARNING: Exemptions should NOT be introduced unless absolutely necessary. If deemed necessary,
# clear reasons must be provided as comment in addition to thorough integration tests.
#
# The `x` argument of `fit` is exempt for both the "tensorflow" and "keras" integrations.
# When extracting the implicitly defined `batch_size` in the case that `x` is a generator or a
# generator class, the first element has to be consumed and then restored back to the
# generator in order to calculate the `batch_size`. This means that:
#   1. The type of `x` becomes 'generator' regardless of whether the user provided `x` as a
#      generator or as a custom generator class.
#   2. The instance of `x` is different, since the generator is reconstructed after consuming
#      the first element.
_VALIDATION_EXEMPT_ARGUMENTS = [
    ValidationExemptArgument(integration_name, "fit", is_iterator, 1, "x")
    for integration_name in ("tensorflow", "keras")
]
919
+
920
+
921
def _is_arg_exempt_from_validation(
    autologging_integration,
    function_name,
    argument,
    argument_index=None,
    argument_name=None,
):
    """Tell whether an argument is exempt from autolog safety validations (both the type
    check and the immutability check), by consulting every entry of
    `_VALIDATION_EXEMPT_ARGUMENTS`.

    Args:
        autologging_integration: The name of the autologging integration.
        function_name: The name of the function that is being validated.
        argument: The actual argument.
        argument_index: The index of the argument, if it is passed as a positional
            argument. Otherwise it is None.
        argument_name: The name of the argument, if it is passed as a keyword argument.
            Otherwise it is None.

    Returns:
        True if at least one registered exemption matches, False otherwise.
    """
    for registered_exemption in _VALIDATION_EXEMPT_ARGUMENTS:
        # Stop at the first exemption that applies.
        if registered_exemption.matches(
            autologging_integration,
            function_name,
            argument,
            argument_index,
            argument_name,
        ):
            return True
    return False
953
+
954
+
955
+ def _validate_args(
956
+ autologging_integration,
957
+ function_name,
958
+ user_call_args,
959
+ user_call_kwargs,
960
+ autologging_call_args,
961
+ autologging_call_kwargs,
962
+ ):
963
+ """
964
+ Used for testing purposes to verify that, when a patched ML function calls its underlying
965
+ / original ML function, the following properties are satisfied:
966
+
967
+ - All arguments supplied to the patched ML function are forwarded to the
968
+ original ML function
969
+ - Any additional arguments supplied to the original function are exception safe (i.e.
970
+ they are either functions decorated with the `@exception_safe_function_for_class` or
971
+ `@pickalable_exception_safe_function` decorators, or classes / instances of classes with
972
+ type `ExceptionSafeClass`
973
+ """
974
+
975
+ def _validate_new_input(inp):
976
+ """
977
+ Validates a new input (arg or kwarg) introduced to the underlying / original ML function
978
+ call during the execution of a patched ML function. The new input is valid if:
979
+
980
+ - The new input is a function that has been decorated with
981
+ `exception_safe_function_for_class` or `pickalable_exception_safe_function`
982
+ - OR the new input is a class with the `ExceptionSafeClass` metaclass
983
+ - OR the new input is a list and each of its elements is valid according to the
984
+ these criteria
985
+ """
986
+ if type(inp) == list:
987
+ for item in inp:
988
+ _validate_new_input(item)
989
+ elif isinstance(inp, dict) and "callbacks" in inp:
990
+ _validate_new_input(inp["callbacks"])
991
+ elif callable(inp):
992
+ assert getattr(inp, _ATTRIBUTE_EXCEPTION_SAFE, False), (
993
+ f"New function argument '{inp}' passed to original function is not exception-safe."
994
+ " Please decorate the function with `exception_safe_function` or "
995
+ "`pickalable_exception_safe_function`"
996
+ )
997
+ else:
998
+ assert hasattr(inp, "__class__") and type(inp.__class__) in [
999
+ ExceptionSafeClass,
1000
+ ExceptionSafeAbstractClass,
1001
+ ], (
1002
+ f"Invalid new input '{inp}'. New args / kwargs introduced to `original` function "
1003
+ "calls by patched code must either be functions decorated with "
1004
+ "`exception_safe_function_for_class`, instances of classes with the "
1005
+ "`ExceptionSafeClass` or `ExceptionSafeAbstractClass` metaclass safe or lists of "
1006
+ "such exception safe functions / classes."
1007
+ )
1008
+
1009
+ def _assert_autologging_input_positional_args_are_superset(
1010
+ autologging_call_input, user_call_input
1011
+ ):
1012
+ length_diff = len(autologging_call_input) - len(user_call_input)
1013
+ assert length_diff >= 0, (
1014
+ f"{length_diff} expected inputs are missing from the call to the original function."
1015
+ )
1016
+
1017
+ def _assert_autologging_input_kwargs_are_superset(autologging_call_input, user_call_input):
1018
+ assert set(user_call_input.keys()).issubset(set(autologging_call_input.keys())), (
1019
+ "Keyword or dictionary arguments to original function omit"
1020
+ " one or more expected keys: '{}'".format(
1021
+ set(user_call_input.keys()) - set(autologging_call_input.keys())
1022
+ )
1023
+ )
1024
+
1025
+ def _validate(autologging_call_input, user_call_input=None):
1026
+ """
1027
+ Validates that the specified `autologging_call_input` and `user_call_input`
1028
+ are compatible. If `user_call_input` is `None`, then `autologging_call_input`
1029
+ is regarded as a new input added by autologging and is validated using
1030
+ `_validate_new_input`. Otherwise, the following properties must hold:
1031
+
1032
+ - `autologging_call_input` and `user_call_input` must have the same type
1033
+ (referred to as "input type")
1034
+ - if the input type is a tuple, list or dictionary, then `autologging_call_input` must
1035
+ be equivalent to `user_call_input` or be a superset of `user_call_input`
1036
+ - for all other input types, `autologging_call_input` and `user_call_input`
1037
+ must be equivalent by reference equality or by object equality
1038
+
1039
+ Args:
1040
+ autologging_call_input: call input from autologging.
1041
+ user_call_input: call input from user.
1042
+ """
1043
+
1044
+ if user_call_input is None and autologging_call_input is not None:
1045
+ _validate_new_input(autologging_call_input)
1046
+ return
1047
+
1048
+ assert type(autologging_call_input) == type(user_call_input), (
1049
+ "Type of input to original function '{}' does not match expected type '{}'".format(
1050
+ type(autologging_call_input), type(user_call_input)
1051
+ )
1052
+ )
1053
+
1054
+ if type(autologging_call_input) in [list, tuple]:
1055
+ _assert_autologging_input_positional_args_are_superset(
1056
+ autologging_call_input, user_call_input
1057
+ )
1058
+ # If the autologging call input is longer than the user call input, we `zip_longest`
1059
+ # will pad the user call input with `None` values to ensure that the subsequent calls
1060
+ # to `_validate` identify new inputs added by the autologging call
1061
+ for a, u in itertools.zip_longest(autologging_call_input, user_call_input):
1062
+ _validate(a, u)
1063
+ elif type(autologging_call_input) == dict:
1064
+ _assert_autologging_input_kwargs_are_superset(autologging_call_input, user_call_input)
1065
+ for key in autologging_call_input.keys():
1066
+ _validate(autologging_call_input[key], user_call_input.get(key, None))
1067
+
1068
+ else:
1069
+ assert (
1070
+ autologging_call_input is user_call_input
1071
+ or autologging_call_input == user_call_input
1072
+ ), (
1073
+ "Input to original function does not match expected input."
1074
+ f" Original: '{autologging_call_input}'. Expected: '{user_call_input}'"
1075
+ )
1076
+
1077
+ # Similar validation logic found in _validate, unraveling the list of arguments to exclude
1078
+ # checks for any validation exempt positional arguments.
1079
+ _assert_autologging_input_positional_args_are_superset(autologging_call_args, user_call_args)
1080
+ for index, autologging_call_arg, user_call_arg in itertools.zip_longest(
1081
+ range(len(user_call_args)), autologging_call_args, user_call_args
1082
+ ):
1083
+ if not _is_arg_exempt_from_validation(
1084
+ autologging_integration,
1085
+ function_name,
1086
+ user_call_arg,
1087
+ argument_index=index,
1088
+ ):
1089
+ _validate(autologging_call_arg, user_call_arg)
1090
+
1091
+ # Similar validation logic found in _validate, unraveling the dictionary of arguments to exclude
1092
+ # checks for any validation exempt keyword arguments.
1093
+ _assert_autologging_input_kwargs_are_superset(autologging_call_kwargs, user_call_kwargs)
1094
+ for key in autologging_call_kwargs.keys():
1095
+ if not _is_arg_exempt_from_validation(
1096
+ autologging_integration,
1097
+ function_name,
1098
+ user_call_kwargs.get(key, None),
1099
+ argument_name=key,
1100
+ ):
1101
+ _validate(
1102
+ autologging_call_kwargs[key],
1103
+ user_call_kwargs.get(key, None),
1104
+ )