genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645) hide show
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,580 @@
1
+ import logging
2
+ import os
3
+ import tempfile
4
+ import warnings
5
+
6
+ from packaging.version import Version
7
+
8
+ import mlflow.pytorch
9
+ from mlflow.exceptions import MlflowException
10
+ from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS
11
+ from mlflow.utils.autologging_utils import (
12
+ BatchMetricsLogger,
13
+ ExceptionSafeAbstractClass,
14
+ MlflowAutologgingQueueingClient,
15
+ disable_autologging,
16
+ get_autologging_config,
17
+ )
18
+ from mlflow.utils.checkpoint_utils import MlflowModelCheckpointCallbackBase
19
+
20
logging.basicConfig(level=logging.ERROR)

# Range of pytorch-lightning versions that autologging is declared compatible with.
_autolog_versions = _ML_PACKAGE_VERSIONS["pytorch-lightning"]["autologging"]
MIN_REQ_VERSION = Version(_autolog_versions["minimum"])
MAX_REQ_VERSION = Version(_autolog_versions["maximum"])

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only

# Why this module ships its own callback instead of relying on PyTorch Lightning's
# built-in MlflowLogger:
# 1. MlflowLogger doesn't provide a mechanism to store an entire model into mlflow.
#    Only the model checkpoint is saved.
# 2. Storing the model into mlflow goes through the `mlflow.pytorch` library, which
#    expects the `mlflow` object to be instantiated. With MlflowLogger, run management is
#    completely controlled by the logger class, so the mlflow object would have to be
#    re-instantiated by setting tracking uri, experiment_id and run_id, which may lead to
#    a race condition.
# TODO: Replace __MlflowPLCallback with Pytorch Lightning's built-in MlflowLogger
# once the above mentioned issues have been addressed

_logger = logging.getLogger(__name__)

# `ModelSummary` lives in a different module depending on the installed version.
_pl_version = Version(pl.__version__)
if _pl_version >= Version("1.5.0"):
    from pytorch_lightning.utilities.model_summary import ModelSummary
else:
    from pytorch_lightning.core.memory import ModelSummary
45
+
46
+
47
def _get_optimizer_name(optimizer):
    """
    Return the class name of ``optimizer``, looking through a ``LightningOptimizer`` wrapper.

    In pytorch-lightning 1.1.0, `LightningOptimizer` was introduced:
    https://github.com/PyTorchLightning/pytorch-lightning/pull/4658

    If a user sets `enable_pl_optimizer` to True when instantiating a `Trainer` object,
    each optimizer will be wrapped by `LightningOptimizer`:
    https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.html
    #pytorch_lightning.trainer.trainer.Trainer.params.enable_pl_optimizer

    Args:
        optimizer: The optimizer object (possibly wrapped) taken from the trainer.

    Returns:
        The class name of the underlying optimizer.
    """
    target = optimizer
    # The wrapper class only exists from 1.1.0 onwards, so only import it there.
    if Version(pl.__version__) >= Version("1.1.0"):
        from pytorch_lightning.core.optimizer import LightningOptimizer

        if isinstance(optimizer, LightningOptimizer):
            target = optimizer._optimizer
    return target.__class__.__name__
67
+
68
+
69
class __MlflowPLCallback(pl.Callback, metaclass=ExceptionSafeAbstractClass):
    """
    Callback for auto-logging metrics and parameters.

    Epoch-level metrics are taken from ``trainer.callback_metrics``; which Lightning hook is
    used to record them depends on the installed pytorch-lightning version (see the
    version-conditional hook definitions below). Step-level metrics are recorded from
    ``on_train_batch_end`` only when ``log_every_n_step`` is set.
    """

    def __init__(
        self, client, metrics_logger, run_id, log_models, log_every_n_epoch, log_every_n_step
    ):
        """
        Args:
            client: Client used for tags and params (``set_tags``, ``log_params``, ``flush``).
            metrics_logger: Logger used for metrics (``record_metrics``, ``flush``).
            run_id: ID of the MLflow run that receives tags and params.
            log_models: Stored on the instance; not read by any hook in this class.
            log_every_n_epoch: Epoch metrics are recorded every this many epochs.
            log_every_n_step: If truthy, step metrics are recorded every this many steps.

        Raises:
            MlflowException: If ``log_every_n_step`` is set with pytorch-lightning < 1.1.0.
        """
        if log_every_n_step and _pl_version < Version("1.1.0"):
            raise MlflowException(
                "log_every_n_step is only supported for PyTorch-Lightning >= 1.1.0"
            )
        self.early_stopping = False
        self.client = client
        self.metrics_logger = metrics_logger
        self.run_id = run_id
        self.log_models = log_models
        self.log_every_n_epoch = log_every_n_epoch
        self.log_every_n_step = log_every_n_step
        # Number of global-step increments per training step; updated in `on_train_start`.
        self._global_steps_per_training_step = 1
        # Sets for tracking which metrics are logged on steps and which are logged on epochs
        self._step_metrics = set()
        self._epoch_metrics = set()

    def _log_metrics(self, trainer, step, metric_items):
        """Record ``metric_items`` (name/value pairs) at ``step`` unless sanity checking."""
        # pytorch-lightning runs a few steps of validation in the beginning of training
        # as a sanity check to catch bugs without having to wait for the training routine
        # to complete. During this check, we should skip logging metrics.
        # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#num-sanity-val-steps
        sanity_checking = (
            # `running_sanity_check` has been renamed to `sanity_checking`:
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/9209
            trainer.sanity_checking
            if _pl_version > Version("1.4.5")
            else trainer.running_sanity_check
        )
        if sanity_checking:
            return

        # Cast metric value as float before passing into logger.
        metrics = {name: float(value) for name, value in metric_items}
        self.metrics_logger.record_metrics(metrics, step)

    def _log_epoch_metrics(self, trainer, pl_module):
        """Record epoch-level metrics for the current epoch of ``pl_module``."""
        # `trainer.callback_metrics` contains both training and validation metrics
        # and includes metrics logged on steps and epochs.
        # If we have logged any metrics on a step basis in mlflow, we exclude these from the
        # epoch level metrics to prevent mixing epoch and step based values.
        metric_items = [
            (name, val)
            for (name, val) in trainer.callback_metrics.items()
            if name not in self._step_metrics
        ]
        # Record which metrics are logged on epochs, so we don't try to log these on steps
        self._epoch_metrics.update(name for (name, _) in metric_items)
        if (pl_module.current_epoch + 1) % self.log_every_n_epoch == 0:
            self._log_metrics(trainer, pl_module.current_epoch, metric_items)

    # The hook definitions below are selected once, at class-creation time, using the
    # module-level `_pl_version`.

    # In pytorch-lightning >= 1.4.0, validation is run inside the training epoch and
    # `trainer.callback_metrics` contains both training and validation metrics of the
    # current training epoch when `on_train_epoch_end` is called:
    # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357
    if _pl_version >= Version("1.4.0dev"):

        @rank_zero_only
        def on_train_epoch_end(self, trainer, pl_module, *args):
            """
            Log loss and other metrics values after each train epoch

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
                args: additional positional arguments
            """
            self._log_epoch_metrics(trainer, pl_module)

    # In pytorch-lightning >= 1.2.0, logging metrics in `on_epoch_end` results in duplicate
    # metrics records because `on_epoch_end` is called after both train and validation
    # epochs (related PR: https://github.com/PyTorchLightning/pytorch-lightning/pull/5986)
    # As a workaround, use `on_train_epoch_end` and `on_validation_epoch_end` instead
    # in pytorch-lightning >= 1.2.0.
    elif _pl_version >= Version("1.2.0"):
        # NB: Override `on_train_epoch_end` with an additional `*args` parameter for
        # compatibility with versions of pytorch-lightning <= 1.2.0, which required an
        # `outputs` argument that was not used and is no longer defined in
        # pytorch-lightning >= 1.3.0

        @rank_zero_only
        def on_train_epoch_end(self, trainer, pl_module, *args):
            """
            Log loss and other metrics values after each train epoch

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
                args: additional positional arguments
            """
            # If validation loop is enabled (meaning `validation_step` is overridden),
            # log metrics in `on_validation_epoch_end` to avoid logging the same metrics
            # records twice
            if not trainer.enable_validation:
                self._log_epoch_metrics(trainer, pl_module)

        @rank_zero_only
        def on_validation_epoch_end(self, trainer, pl_module):
            """
            Log loss and other metrics values after each validation epoch

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
            """
            self._log_epoch_metrics(trainer, pl_module)

    else:

        @rank_zero_only
        def on_epoch_end(self, trainer, pl_module):
            """
            Log loss and other metrics values after each epoch

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
            """
            self._log_epoch_metrics(trainer, pl_module)

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, *args):
        """
        Log metric values after each step

        Does nothing unless ``log_every_n_step`` is set.

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
            args: additional positional arguments
        """
        if not self.log_every_n_step:
            return
        # When logging at the end of a batch step, we only want to log metrics that are logged
        # on steps. For forked metrics (metrics logged on both steps and epochs), we exclude the
        # metric with the non-forked name (eg. "loss" when we have "loss", "loss_step" and
        # "loss_epoch") so that this is only logged on epochs. We also record which metrics
        # we've logged per step, so we can later exclude these from metrics logged on epochs.
        metrics = _get_step_metrics(trainer)
        metric_items = [
            (name, val)
            for (name, val) in metrics.items()
            if (name not in self._epoch_metrics) and (f"{name}_step" not in metrics)
        ]
        self._step_metrics.update(name for (name, _) in metric_items)
        step = trainer.global_step
        # Convert the global step into a training-step count before applying the interval.
        if ((step // self._global_steps_per_training_step) + 1) % self.log_every_n_step == 0:
            self._log_metrics(trainer, step, metric_items)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """
        Tags the run as training and logs the epoch count and optimizer parameters

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
        """
        self.client.set_tags(self.run_id, {"Mode": "training"})

        params = {"epochs": trainer.max_epochs}

        # TODO For logging optimizer params - the following scenarios are to be revisited.
        # 1. In the current scenario, only the first optimizer details are logged.
        #    Code to be enhanced to log params when multiple optimizers are used.
        # 2. mlflow.log_params is used to store optimizer default values into mlflow.
        #    The keys in default dictionary are too short, Ex: (lr - learning_rate).
        #    Efficient mapping technique needs to be introduced
        #    to rename the optimizer parameters based on keys in default dictionary.

        if hasattr(trainer, "optimizers"):
            # Lightning >= 1.6.0 increments the global step every time an optimizer is stepped.
            # We assume every optimizer will be stepped in each training step.
            if _pl_version >= Version("1.6.0"):
                self._global_steps_per_training_step = len(trainer.optimizers)
            optimizer = trainer.optimizers[0]
            params["optimizer_name"] = _get_optimizer_name(optimizer)

            if hasattr(optimizer, "defaults"):
                params.update(optimizer.defaults)

        self.client.log_params(self.run_id, params)
        self.client.flush(synchronous=True)

    @rank_zero_only
    def on_train_end(self, trainer, pl_module):
        """
        Flushes any metrics and metadata still queued when training ends

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
        """
        # manually flush any remaining metadata from training
        self.metrics_logger.flush()
        self.client.flush(synchronous=True)

    @rank_zero_only
    def on_test_end(self, trainer, pl_module):
        """
        Tags the run as testing and logs the trainer's callback metrics on the testing end

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
        """
        self.client.set_tags(self.run_id, {"Mode": "testing"})
        self.client.flush(synchronous=True)

        self.metrics_logger.record_metrics(
            {key: float(value) for key, value in trainer.callback_metrics.items()}
        )
        self.metrics_logger.flush()
283
+
284
+
285
class MlflowModelCheckpointCallback(pl.Callback, MlflowModelCheckpointCallbackBase):
    """Callback for auto-logging pytorch-lightning model checkpoints to MLflow.
    This callback implementation only supports pytorch-lightning >= 1.6.0.

    Args:
        monitor: In automatic model checkpointing, the metric name to monitor if
            you set `model_checkpoint_save_best_only` to True.
        save_best_only: If True, automatic model checkpointing only saves when
            the model is considered the "best" model according to the quantity
            monitored and previous checkpoint model is overwritten.
        mode: one of {"min", "max"}. In automatic model checkpointing,
            if save_best_only=True, the decision to overwrite the current save file is made
            based on either the maximization or the minimization of the monitored quantity.
        save_weights_only: In automatic model checkpointing, if True, then
            only the model's weights will be saved. Otherwise, the optimizer states,
            lr-scheduler states, etc are added in the checkpoint too.
        save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
            saves the model after each epoch. When using integer, the callback
            saves the model at end of this many batches. Note that if the saving isn't
            aligned to epochs, the monitored metric may potentially be less reliable (it
            could reflect as little as 1 batch, since the metrics get reset
            every epoch). Defaults to `"epoch"`.

    .. code-block:: python
        :caption: Example

        import mlflow
        from mlflow.pytorch import MlflowModelCheckpointCallback
        from pytorch_lightning import Trainer

        mlflow.pytorch.autolog(checkpoint=True)

        model = MyLightningModuleNet()  # A custom-pytorch lightning model
        train_loader = create_train_dataset_loader()

        mlflow_checkpoint_callback = MlflowModelCheckpointCallback()

        trainer = Trainer(callbacks=[mlflow_checkpoint_callback])

        with mlflow.start_run() as run:
            trainer.fit(model, train_loader)

    """

    def __init__(
        self,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        save_weights_only=False,
        save_freq="epoch",
    ):
        base_options = {
            "checkpoint_file_suffix": ".pth",
            "monitor": monitor,
            "mode": mode,
            "save_best_only": save_best_only,
            "save_weights_only": save_weights_only,
            "save_freq": save_freq,
        }
        super().__init__(**base_options)
        # Set in `on_fit_start`; `save_checkpoint` needs it.
        self.trainer = None

    def save_checkpoint(self, filepath: str):
        """Dump the trainer's current checkpoint and write it to ``filepath``."""
        # `trainer.save_checkpoint` is deliberately not used here: its implementation
        # invokes `self.strategy.barrier("Trainer.save_checkpoint")`. In DDP training this
        # callback is only invoked in the rank 0 process, so that `barrier` call would
        # deadlock. The dump + strategy save is therefore done directly.
        connector = self.trainer._checkpoint_connector
        state = connector.dump_checkpoint(self.save_weights_only)
        self.trainer.strategy.save_checkpoint(state, filepath)

    @rank_zero_only
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Remember the trainer so that `save_checkpoint` can use it later."""
        self.trainer = trainer

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs,
        batch,
        batch_idx,
    ) -> None:
        """Run the checkpoint check every `save_freq` global steps (integer `save_freq`)."""
        if not isinstance(self.save_freq, int):
            return
        if not (trainer.global_step > 0 and trainer.global_step % self.save_freq == 0):
            return
        self.check_and_save_checkpoint_if_needed(
            current_epoch=trainer.current_epoch,
            global_step=trainer.global_step,
            metric_dict={
                name: float(value) for name, value in trainer.callback_metrics.items()
            },
        )

    @rank_zero_only
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Run the checkpoint check at the end of each epoch when `save_freq` is "epoch"."""
        if self.save_freq != "epoch":
            return
        self.check_and_save_checkpoint_if_needed(
            current_epoch=trainer.current_epoch,
            global_step=trainer.global_step,
            metric_dict={
                name: float(value) for name, value in trainer.callback_metrics.items()
            },
        )
387
+
388
+
389
# PyTorch-Lightning refactored the LoggerConnector class in version 1.4.0 and made metrics
# update on demand. Prior to this, the metrics from the current step were not available to
# callbacks immediately, so the view of metrics was off by one step.
# To avoid this problem, we access the metrics via the logger_connector for older versions.
if _pl_version < Version("1.4.0"):

    def _get_step_metrics(trainer):
        """Return the latest batch metrics from the trainer's logger connector cache."""
        cached_results = trainer.logger_connector.cached_results
        return cached_results.get_latest_batch_log_metrics()

else:

    def _get_step_metrics(trainer):
        """Return the trainer's callback metrics."""
        return trainer.callback_metrics
402
+
403
+
404
+ def _log_early_stop_params(early_stop_callback, client, run_id):
405
+ """
406
+ Logs early stopping configuration parameters to MLflow.
407
+
408
+ Args:
409
+ early_stop_callback: The early stopping callback instance used during training.
410
+ client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
411
+ run_id: The ID of the MLflow Run to which to log configuration parameters.
412
+ """
413
+ client.log_params(
414
+ run_id,
415
+ {
416
+ p: getattr(early_stop_callback, p)
417
+ for p in ["monitor", "mode", "patience", "min_delta", "stopped_epoch"]
418
+ if hasattr(early_stop_callback, p)
419
+ },
420
+ )
421
+
422
+
423
+ def _log_early_stop_metrics(early_stop_callback, client, run_id, model_id=None):
424
+ """
425
+ Logs early stopping behavior results (e.g. stopped epoch) as metrics to MLflow.
426
+
427
+ Args:
428
+ early_stop_callback: The early stopping callback instance used during training.
429
+ client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
430
+ run_id: The ID of the MLflow Run to which to log configuration parameters.
431
+ model_id: The ID of the LoggedModel to which the metrics are associated.
432
+ """
433
+ if early_stop_callback.stopped_epoch == 0:
434
+ return
435
+
436
+ metrics = {
437
+ "stopped_epoch": early_stop_callback.stopped_epoch,
438
+ "restored_epoch": early_stop_callback.stopped_epoch - max(1, early_stop_callback.patience),
439
+ }
440
+
441
+ if hasattr(early_stop_callback, "best_score"):
442
+ metrics["best_score"] = float(early_stop_callback.best_score)
443
+
444
+ if hasattr(early_stop_callback, "wait_count"):
445
+ metrics["wait_count"] = early_stop_callback.wait_count
446
+
447
+ client.log_metrics(run_id, metrics, model_id=model_id)
448
+
449
+
450
def patched_fit(original, self, *args, **kwargs):
    """
    A patched implementation of `pytorch_lightning.Trainer.fit` which enables logging the
    following parameters, metrics and artifacts:

    - Training epochs
    - Optimizer parameters
    - `EarlyStoppingCallback`_ parameters
    - Metrics stored in `trainer.callback_metrics`
    - Model checkpoints
    - Trained model

    .. _EarlyStoppingCallback:
        https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html

    Args:
        original: The unpatched `Trainer.fit` implementation.
        self: The `Trainer` instance on which `fit` was called.
        args: Positional arguments forwarded to `original`.
        kwargs: Keyword arguments forwarded to `original`.

    Returns:
        Whatever `original` returns.
    """
    if not MIN_REQ_VERSION <= _pl_version <= MAX_REQ_VERSION:
        warnings.warn(
            "Autologging is known to be compatible with pytorch-lightning versions between "
            f"{MIN_REQ_VERSION} and {MAX_REQ_VERSION} and may not succeed with packages "
            "outside this range."
        )

    with disable_autologging():
        run_id = mlflow.active_run().info.run_id
        tracking_uri = mlflow.get_tracking_uri()
        client = MlflowAutologgingQueueingClient(tracking_uri)

        log_models = get_autologging_config(mlflow.pytorch.FLAVOR_NAME, "log_models", True)
        model_id = None
        if log_models:
            # Create the logged model up front so metrics can be associated with it.
            model_id = mlflow.initialize_logged_model(name="model").model_id
        metrics_logger = BatchMetricsLogger(run_id, tracking_uri, model_id=model_id)

        log_every_n_epoch = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "log_every_n_epoch", 1
        )
        log_every_n_step = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "log_every_n_step", None
        )

        # Params are logged for every EarlyStopping callback found; the last one found is
        # the one whose results are logged after training.
        early_stop_callback = None
        for callback in self.callbacks:
            if isinstance(callback, pl.callbacks.early_stopping.EarlyStopping):
                early_stop_callback = callback
                _log_early_stop_params(early_stop_callback, client, run_id)

        # Avoid registering the metrics callback twice on a reused trainer.
        if not any(isinstance(callbacks, __MlflowPLCallback) for callbacks in self.callbacks):
            self.callbacks += [
                __MlflowPLCallback(
                    client, metrics_logger, run_id, log_models, log_every_n_epoch, log_every_n_step
                )
            ]

        model_checkpoint = get_autologging_config(mlflow.pytorch.FLAVOR_NAME, "checkpoint", True)
        if model_checkpoint:
            # MlflowModelCheckpointCallback only supports pytorch-lightning >= 1.6.0
            if _pl_version >= Version("1.6.0"):
                checkpoint_monitor = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_monitor", "val_loss"
                )
                checkpoint_mode = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_mode", "min"
                )
                checkpoint_save_best_only = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_save_best_only", True
                )
                checkpoint_save_weights_only = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_save_weights_only", False
                )
                checkpoint_save_freq = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_save_freq", "epoch"
                )

                if not any(
                    isinstance(callbacks, MlflowModelCheckpointCallback)
                    for callbacks in self.callbacks
                ):
                    self.callbacks += [
                        MlflowModelCheckpointCallback(
                            monitor=checkpoint_monitor,
                            mode=checkpoint_mode,
                            save_best_only=checkpoint_save_best_only,
                            save_weights_only=checkpoint_save_weights_only,
                            save_freq=checkpoint_save_freq,
                        )
                    ]
            else:
                warnings.warn(
                    "Automatic model checkpointing is disabled because this feature only "
                    "supports pytorch-lightning >= 1.6.0."
                )

        client.flush(synchronous=False)

        result = original(self, *args, **kwargs)

        if early_stop_callback is not None:
            _log_early_stop_metrics(early_stop_callback, client, run_id, model_id=model_id)

        # The ModelSummary constructor signature differs across versions.
        if _pl_version < Version("1.4.0"):
            summary = str(ModelSummary(self.model, mode="full"))
        else:
            summary = str(ModelSummary(self.model, max_depth=-1))

        with tempfile.TemporaryDirectory() as tempdir:
            summary_file = os.path.join(tempdir, "model_summary.txt")
            # Explicit encoding so the write does not depend on the platform default.
            with open(summary_file, "w", encoding="utf-8") as f:
                f.write(summary)

            mlflow.log_artifact(local_path=summary_file)

        if log_models:
            registered_model_name = get_autologging_config(
                mlflow.pytorch.FLAVOR_NAME, "registered_model_name", None
            )
            mlflow.pytorch.log_model(
                self.model,
                name="model",
                registered_model_name=registered_model_name,
                model_id=model_id,
            )

            if early_stop_callback is not None:
                # The trainer may have no checkpoint callback at all (e.g. checkpointing
                # disabled), in which case there is no best checkpoint to upload.
                checkpoint_callback = getattr(self, "checkpoint_callback", None)
                best_model_path = getattr(checkpoint_callback, "best_model_path", None)
                if best_model_path:
                    mlflow.log_artifact(
                        local_path=best_model_path,
                        artifact_path="restored_model_checkpoint",
                    )

        client.flush(synchronous=True)

    return result
@@ -0,0 +1,50 @@
1
+ import time
2
+
3
+ import mlflow
4
+ from mlflow.entities import Metric, Param
5
+ from mlflow.tracking import MlflowClient
6
+ from mlflow.utils.autologging_utils.metrics_queue import (
7
+ add_to_metrics_queue,
8
+ flush_metrics_queue,
9
+ )
10
+
11
+
12
def patched_add_hparams(original, self, hparam_dict, metric_dict, *args, **kwargs):
    """Log hparams and their metrics to the active MLflow run, then call ``original``.

    A synchronous ``log_batch`` call is used here since this is going to get called very
    infrequently. Nothing is logged when there is no active run or ``hparam_dict`` is empty.
    """
    active_run = mlflow.active_run()
    if active_run is None or not hparam_dict:
        return original(self, hparam_dict, metric_dict, *args, **kwargs)

    # str() is required by mlflow :(
    params = [Param(name, str(setting)) for name, setting in hparam_dict.items()]
    metrics = []
    for name, reading in metric_dict.items():
        metrics.append(Metric(name, reading, int(time.time() * 1000), 0))
    MlflowClient().log_batch(
        run_id=active_run.info.run_id, metrics=metrics, params=params, tags=[]
    )

    return original(self, hparam_dict, metric_dict, *args, **kwargs)
27
+
28
+
29
def patched_add_event(original, self, event, *args, mlflow_log_every_n_step, **kwargs):
    """Queue scalar summary values of ``event`` as MLflow metrics, then call ``original``.

    Values are queued only when there is an active run, the event is a summary event,
    ``mlflow_log_every_n_step`` is truthy and the global step is a multiple of it. The
    global step is the first extra positional argument if given, otherwise the
    ``global_step`` keyword argument, defaulting to 0.
    """
    active_run = mlflow.active_run()
    wants_logging = (
        active_run is not None
        and event.WhichOneof("what") == "summary"
        and mlflow_log_every_n_step
    )
    if wants_logging:
        step = (args[0] if args else kwargs.get("global_step")) or 0
        for entry in event.summary.value:
            if not entry.HasField("simple_value"):
                continue
            if step % mlflow_log_every_n_step != 0:
                continue
            add_to_metrics_queue(
                key=entry.tag,
                value=entry.simple_value,
                step=step,
                # Fall back to the current time when the event has no wall time.
                time=int((event.wall_time or time.time()) * 1000),
                run_id=active_run.info.run_id,
            )

    return original(self, event, *args, **kwargs)
46
+
47
+
48
def patched_add_summary(original, self, *args, **kwargs):
    """Flush the queued MLflow metrics, then delegate to ``original`` unchanged."""
    flush_metrics_queue()
    outcome = original(self, *args, **kwargs)
    return outcome
@@ -0,0 +1,35 @@
1
+ """
2
+ This module imports contents from CloudPickle in a way that is compatible with the
3
+ ``pickle_module`` parameter of PyTorch's model persistence function: ``torch.save``
4
+ (see https://github.com/pytorch/pytorch/blob/692898fe379c9092f5e380797c32305145cd06e1/torch/
5
+ serialization.py#L192). It is included as a distinct module from :mod:`mlflow.pytorch` to avoid
6
+ polluting the namespace with wildcard imports.
7
+
8
+ Calling ``torch.save(..., pickle_module=mlflow.pytorch.pickle_module)`` will persist PyTorch model
9
+ definitions using CloudPickle, leveraging improved pickling functionality such as the ability
10
+ to capture class definitions in the "__main__" scope.
11
+
12
+ TODO: Remove this module or make it an alias of CloudPickle when CloudPickle and PyTorch have
13
+ compatible pickling APIs.
14
+ """
15
+
16
+ # Import all contents of the CloudPickle module in an attempt to include all functions required
17
+ # by ``torch.save``.
18
+
19
+ # CloudPickle does not include `Unpickler` in its namespace, which is required by PyTorch for
20
+ # deserialization. Noting that CloudPickle's `load()` and `loads()` routines are aliases for
21
+ # `pickle.load()` and `pickle.loads()`, we therefore import Unpickler from the native
22
+ # Python pickle library.
23
+ from pickle import Unpickler # noqa: F401
24
+
25
+ from cloudpickle import * # noqa: F403
26
+
27
+ # PyTorch uses the ``Pickler`` class of the specified ``pickle_module``
28
+ # (https://github.com/pytorch/pytorch/blob/692898fe379c9092f5e380797c32305145cd06e1/torch/
29
+ # serialization.py#L290). Unfortunately, ``cloudpickle.Pickler`` is an alias for Python's native
30
+ # pickling class: ``pickle.Pickler``, instead of ``cloudpickle.CloudPickler``.
31
+ # https://github.com/cloudpipe/cloudpickle/pull/235 has been filed to correct the issue,
32
+ # but this import renaming is necessary until either the requested change has been incorporated
33
+ # into a CloudPickle release or the ``torch.save`` API has been updated to be compatible with
34
+ # the existing CloudPickle API.
35
+ from cloudpickle import CloudPickler as Pickler # noqa: F401