genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645):
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,2111 @@
1
+ import ast
2
+ import base64
3
+ import json
4
+ import math
5
+ import operator
6
+ import re
7
+ import shlex
8
+ from dataclasses import asdict, dataclass
9
+ from typing import Any, Optional
10
+
11
+ import sqlparse
12
+ from packaging.version import Version
13
+ from sqlparse.sql import (
14
+ Comparison,
15
+ Identifier,
16
+ Parenthesis,
17
+ Statement,
18
+ Token,
19
+ TokenList,
20
+ )
21
+ from sqlparse.tokens import Token as TokenType
22
+
23
+ from mlflow.entities import LoggedModel, Metric, RunInfo
24
+ from mlflow.entities.model_registry.model_version_stages import STAGE_DELETED_INTERNAL
25
+ from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
26
+ from mlflow.exceptions import MlflowException
27
+ from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
28
+ from mlflow.store.db.db_types import MSSQL, MYSQL, POSTGRES, SQLITE
29
+ from mlflow.tracing.constant import TraceMetadataKey, TraceTagKey
30
+ from mlflow.utils.mlflow_tags import (
31
+ MLFLOW_DATASET_CONTEXT,
32
+ )
33
+
34
+
35
+ def _convert_like_pattern_to_regex(pattern, flags=0):
36
+ if not pattern.startswith("%"):
37
+ pattern = "^" + pattern
38
+ if not pattern.endswith("%"):
39
+ pattern = pattern + "$"
40
+ return re.compile(pattern.replace("_", ".").replace("%", ".*"), flags)
41
+
42
+
43
+ def _like(string, pattern):
44
+ return _convert_like_pattern_to_regex(pattern).match(string) is not None
45
+
46
+
47
+ def _ilike(string, pattern):
48
+ return _convert_like_pattern_to_regex(pattern, flags=re.IGNORECASE).match(string) is not None
49
+
50
+
51
def _join_in_comparison_tokens(tokens, search_traces=False):
    """
    Find a sequence of tokens that matches the pattern of an IN comparison or a NOT IN comparison,
    join the tokens into a single Comparison token. Otherwise, return the original list of tokens.

    When ``search_traces`` is True, a ``timestamp``/``timestamp_ms`` builtin followed by a numeric
    comparator and an integer literal is also joined into a single Comparison token.

    :param tokens: Top-level tokens of a parsed sqlparse statement.
    :param search_traces: Whether to also join trace timestamp comparisons.
    :return: The grouped tokens, with whitespace tokens dropped (or ``tokens`` unchanged on
        sqlparse < 0.4.4).
    :raises MlflowException: If a timestamp token is not followed by two more tokens.
    """
    if Version(sqlparse.__version__) < Version("0.4.4"):
        # In sqlparse < 0.4.4, IN is treated as a comparison, we don't need to join tokens
        return tokens

    non_whitespace_tokens = [t for t in tokens if not t.is_whitespace]
    joined_tokens = []
    num_tokens = len(non_whitespace_tokens)
    iterator = enumerate(non_whitespace_tokens)
    while elem := next(iterator, None):
        index, first = elem
        # We need at least 3 tokens to form an IN comparison or a NOT IN comparison
        if num_tokens - index < 3:
            joined_tokens.extend(non_whitespace_tokens[index:])
            break

        if search_traces:
            # timestamp
            if first.match(ttype=TokenType.Name.Builtin, values=["timestamp", "timestamp_ms"]):
                (_, second) = next(iterator, (None, None))
                (_, third) = next(iterator, (None, None))
                if any(x is None for x in [second, third]):
                    raise MlflowException(
                        f"Invalid comparison clause with token `{first}, {second}, {third}`, "
                        "expected 3 tokens",
                        error_code=INVALID_PARAMETER_VALUE,
                    )
                if (
                    second.match(
                        ttype=TokenType.Operator.Comparison,
                        values=SearchTraceUtils.VALID_NUMERIC_ATTRIBUTE_COMPARATORS,
                    )
                    and third.ttype == TokenType.Literal.Number.Integer
                ):
                    joined_tokens.append(Comparison(TokenList([first, second, third])))
                else:
                    # Keep the tokens as-is; later validation reports them as invalid.
                    joined_tokens.extend([first, second, third])
                # All three tokens are consumed either way. Falling through would append
                # `first` (which is not an Identifier) a second time below.
                continue

        # Wait until we encounter an identifier token
        if not isinstance(first, Identifier):
            joined_tokens.append(first)
            continue

        # At least 3 tokens remain (checked above), so these cannot be exhausted.
        (_, second) = next(iterator)
        (_, third) = next(iterator)

        # IN
        if (
            isinstance(first, Identifier)
            and second.match(ttype=TokenType.Keyword, values=["IN"])
            and isinstance(third, Parenthesis)
        ):
            joined_tokens.append(Comparison(TokenList([first, second, third])))
            continue

        (_, fourth) = next(iterator, (None, None))
        if fourth is None:
            joined_tokens.extend([first, second, third])
            break

        # NOT IN
        if (
            isinstance(first, Identifier)
            and second.match(ttype=TokenType.Keyword, values=["NOT"])
            and third.match(ttype=TokenType.Keyword, values=["IN"])
            and isinstance(fourth, Parenthesis)
        ):
            # Collapse NOT + IN into a single keyword so the comparator reads "NOT IN".
            joined_tokens.append(
                Comparison(TokenList([first, Token(TokenType.Keyword, "NOT IN"), fourth]))
            )
            continue

        joined_tokens.extend([first, second, third, fourth])

    return joined_tokens
131
+
132
+
133
class SearchUtils:
    """
    Helpers for parsing, validating and evaluating MLflow search filter and ``order_by``
    strings (runs and registered models), plus offset-based page tokens.
    """

    # --- Comparison / ordering operators -------------------------------------------------
    LIKE_OPERATOR = "LIKE"
    ILIKE_OPERATOR = "ILIKE"
    ASC_OPERATOR = "asc"
    DESC_OPERATOR = "desc"
    VALID_ORDER_BY_TAGS = [ASC_OPERATOR, DESC_OPERATOR]

    # --- Allowed comparators per entity type ---------------------------------------------
    VALID_METRIC_COMPARATORS = {">", ">=", "!=", "=", "<", "<="}
    VALID_PARAM_COMPARATORS = {"!=", "=", LIKE_OPERATOR, ILIKE_OPERATOR}
    VALID_TAG_COMPARATORS = {"!=", "=", LIKE_OPERATOR, ILIKE_OPERATOR}
    VALID_STRING_ATTRIBUTE_COMPARATORS = {"!=", "=", LIKE_OPERATOR, ILIKE_OPERATOR, "IN", "NOT IN"}
    # Same set object as the metric comparators.
    VALID_NUMERIC_ATTRIBUTE_COMPARATORS = VALID_METRIC_COMPARATORS
    VALID_DATASET_COMPARATORS = {"!=", "=", LIKE_OPERATOR, ILIKE_OPERATOR, "IN", "NOT IN"}

    # --- Searchable / orderable attribute keys -------------------------------------------
    _BUILTIN_NUMERIC_ATTRIBUTES = {"start_time", "end_time"}
    _ALTERNATE_NUMERIC_ATTRIBUTES = {"created", "Created"}
    _ALTERNATE_STRING_ATTRIBUTES = {"run name", "Run name", "Run Name"}
    NUMERIC_ATTRIBUTES = _BUILTIN_NUMERIC_ATTRIBUTES | _ALTERNATE_NUMERIC_ATTRIBUTES
    DATASET_ATTRIBUTES = {"name", "digest", "context"}
    VALID_SEARCH_ATTRIBUTE_KEYS = {
        *RunInfo.get_searchable_attributes(),
        *_ALTERNATE_NUMERIC_ATTRIBUTES,
        *_ALTERNATE_STRING_ATTRIBUTES,
    }
    VALID_ORDER_BY_ATTRIBUTE_KEYS = {
        *RunInfo.get_orderable_attributes(),
        *_ALTERNATE_NUMERIC_ATTRIBUTES,
    }

    # --- Entity-type prefixes (e.g. "metrics.acc") and their aliases ---------------------
    _METRIC_IDENTIFIER = "metric"
    _ALTERNATE_METRIC_IDENTIFIERS = {"metrics"}
    _PARAM_IDENTIFIER = "parameter"
    _ALTERNATE_PARAM_IDENTIFIERS = {"parameters", "param", "params"}
    _TAG_IDENTIFIER = "tag"
    _ALTERNATE_TAG_IDENTIFIERS = {"tags"}
    _ATTRIBUTE_IDENTIFIER = "attribute"
    _ALTERNATE_ATTRIBUTE_IDENTIFIERS = {"attr", "attributes", "run"}
    _DATASET_IDENTIFIER = "dataset"
    _ALTERNATE_DATASET_IDENTIFIERS = {"datasets"}
    _IDENTIFIERS = [
        _METRIC_IDENTIFIER,
        _PARAM_IDENTIFIER,
        _TAG_IDENTIFIER,
        _ATTRIBUTE_IDENTIFIER,
        _DATASET_IDENTIFIER,
    ]
    _VALID_IDENTIFIERS = {
        *_IDENTIFIERS,
        *_ALTERNATE_METRIC_IDENTIFIERS,
        *_ALTERNATE_PARAM_IDENTIFIERS,
        *_ALTERNATE_TAG_IDENTIFIERS,
        *_ALTERNATE_ATTRIBUTE_IDENTIFIERS,
        *_ALTERNATE_DATASET_IDENTIFIERS,
    }

    # --- sqlparse token types used when validating values --------------------------------
    STRING_VALUE_TYPES = {TokenType.Literal.String.Single}
    DELIMITER_VALUE_TYPES = {TokenType.Punctuation}
    WHITESPACE_VALUE_TYPE = TokenType.Text.Whitespace
    NUMERIC_VALUE_TYPES = {TokenType.Literal.Number.Integer, TokenType.Literal.Number.Float}

    # --- Registered Models constants -----------------------------------------------------
    ORDER_BY_KEY_TIMESTAMP = "timestamp"
    ORDER_BY_KEY_LAST_UPDATED_TIMESTAMP = "last_updated_timestamp"
    ORDER_BY_KEY_MODEL_NAME = "name"
    VALID_ORDER_BY_KEYS_REGISTERED_MODELS = {
        ORDER_BY_KEY_TIMESTAMP,
        ORDER_BY_KEY_LAST_UPDATED_TIMESTAMP,
        ORDER_BY_KEY_MODEL_NAME,
    }
    VALID_TIMESTAMP_ORDER_BY_KEYS = {ORDER_BY_KEY_TIMESTAMP, ORDER_BY_KEY_LAST_UPDATED_TIMESTAMP}
    # We encourage users to use timestamp for order-by
    RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS = {ORDER_BY_KEY_MODEL_NAME, ORDER_BY_KEY_TIMESTAMP}
201
+
202
@staticmethod
def get_comparison_func(comparator):
    """
    Return the Python callable that implements ``comparator``.

    The callable takes ``(lhs, rhs)`` and returns whether the comparison holds.

    :raises KeyError: If ``comparator`` is not a supported operator.
    """

    def _contains(lhs, rhs):
        return lhs in rhs

    def _not_contains(lhs, rhs):
        return lhs not in rhs

    comparison_funcs = {
        "=": operator.eq,
        "!=": operator.ne,
        "<": operator.lt,
        "<=": operator.le,
        ">": operator.gt,
        ">=": operator.ge,
        "LIKE": _like,
        "ILIKE": _ilike,
        "IN": _contains,
        "NOT IN": _not_contains,
    }
    return comparison_funcs[comparator]
216
+
217
@staticmethod
def get_sql_comparison_func(comparator, dialect):
    """
    Return a function that builds a SQLAlchemy filter expression for ``comparator``.

    The returned function takes ``(column, value)``. Its implementation depends on ``dialect``:
    PostgreSQL and SQLite use the plain comparison, while MSSQL and MySQL wrap string-typed
    columns so the comparison is case-sensitive.

    :param comparator: One of ``=``, ``!=``, ``<``, ``<=``, ``>``, ``>=``, ``LIKE``, ``ILIKE``,
        ``IN`` or ``NOT IN``.
    :param dialect: One of the ``POSTGRES``, ``SQLITE``, ``MSSQL`` or ``MYSQL`` db-type constants.
    :return: A ``(column, value) -> expression`` function.
    :raises KeyError: If ``dialect`` is not one of the four supported db types.
    """
    # Imported lazily so sqlalchemy is only required when SQL filtering is used.
    import sqlalchemy as sa

    def comparison_func(column, value):
        # LIKE / ILIKE / IN / NOT IN map to column methods; other comparators reuse the
        # operator functions from get_comparison_func applied directly to the column.
        if comparator == "LIKE":
            return column.like(value)
        elif comparator == "ILIKE":
            return column.ilike(value)
        elif comparator == "IN":
            return column.in_(value)
        elif comparator == "NOT IN":
            return ~column.in_(value)
        return SearchUtils.get_comparison_func(comparator)(column, value)

    def mssql_comparison_func(column, value):
        # Only string columns need collation handling.
        if not isinstance(column.type, sa.types.String):
            return comparison_func(column, value)

        # NOTE(review): "_CS_AS_KS_WS" in SQL Server collation names denotes a case-, accent-,
        # kana- and width-sensitive collation; presumably chosen to force case-sensitive
        # matching, mirroring the MySQL branch below.
        collated = column.collate("Japanese_Bushu_Kakusu_100_CS_AS_KS_WS")
        return comparison_func(collated, value)

    def mysql_comparison_func(column, value):
        if not isinstance(column.type, sa.types.String):
            return comparison_func(column, value)

        # MySQL is case insensitive by default, so we need to use the binary operator to
        # perform case sensitive comparisons.
        templates = {
            # Use non-binary ahead of binary comparison for runtime performance
            "=": "({column} = :value AND BINARY {column} = :value)",
            "!=": "({column} != :value OR BINARY {column} != :value)",
            "LIKE": "({column} LIKE :value AND BINARY {column} LIKE :value)",
        }
        if comparator in templates:
            # Assumes `column` is an ORM-mapped attribute (it reads .class_ and .key).
            # The value itself is bound as a parameter, never interpolated into the SQL text.
            column = f"{column.class_.__tablename__}.{column.key}"
            return sa.text(templates[comparator].format(column=column)).bindparams(
                sa.bindparam("value", value=value, unique=True)
            )

        # Remaining comparators (ordering, ILIKE, IN, NOT IN) use the default handling.
        return comparison_func(column, value)

    return {
        POSTGRES: comparison_func,
        SQLITE: comparison_func,
        MSSQL: mssql_comparison_func,
        MYSQL: mysql_comparison_func,
    }[dialect]
265
+
266
@staticmethod
def translate_key_alias(key):
    """
    Map a user-facing attribute alias to its canonical attribute name.

    ``created``/``Created`` become ``start_time`` and the ``run name`` spellings become
    ``run_name``; any other key is returned unchanged.
    """
    alias_table = (
        (("created", "Created"), "start_time"),
        (("run name", "Run name", "Run Name"), "run_name"),
    )
    for aliases, canonical in alias_table:
        if key in aliases:
            return canonical
    return key
273
+
274
@classmethod
def _trim_ends(cls, string_value):
    """Return ``string_value`` without its first and last characters (e.g. enclosing quotes)."""
    inner = string_value[1:-1]
    return inner
277
+
278
@classmethod
def _is_quoted(cls, value, pattern):
    """Return True if ``value`` is at least two characters long and both starts and ends with ``pattern``."""
    if len(value) < 2:
        return False
    return value.startswith(pattern) and value.endswith(pattern)
281
+
282
@classmethod
def _trim_backticks(cls, entity_type):
    """Remove backticks from identifier like `param`, if they exist."""
    return cls._trim_ends(entity_type) if cls._is_quoted(entity_type, "`") else entity_type
288
+
289
@classmethod
def _strip_quotes(cls, value, expect_quoted_value=False):
    """
    Remove one layer of single or double quotes around ``value``.

    String values are expected to be quoted, and keys that contain special characters are
    also expected to be enclosed in quotes.

    :param value: Raw token text.
    :param expect_quoted_value: When True, an unquoted ``value`` is an error.
    :return: ``value`` without its enclosing quotes, or unchanged if it is unquoted and
        quoting was not required.
    :raises MlflowException: If ``expect_quoted_value`` is True and ``value`` is not quoted.
    """
    for quote_char in ("'", '"'):
        if cls._is_quoted(value, quote_char):
            return cls._trim_ends(value)
    if not expect_quoted_value:
        return value
    raise MlflowException(
        "Parameter value is either not quoted or unidentified quote "
        f"types used for string value {value}. Use either single or double "
        "quotes.",
        error_code=INVALID_PARAMETER_VALUE,
    )
307
+
308
@classmethod
def _valid_entity_type(cls, entity_type):
    """
    Validate an entity-type prefix and return its canonical form.

    Backticks are stripped first; aliases such as ``params`` or ``metrics`` are mapped to
    ``parameter`` / ``metric`` and so on.

    :raises MlflowException: If the prefix is neither a canonical identifier nor an alias.
    """
    entity_type = cls._trim_backticks(entity_type)
    if entity_type not in cls._VALID_IDENTIFIERS:
        raise MlflowException(
            f"Invalid entity type '{entity_type}'. Valid values are {cls._IDENTIFIERS}",
            error_code=INVALID_PARAMETER_VALUE,
        )

    alias_groups = (
        (cls._ALTERNATE_PARAM_IDENTIFIERS, cls._PARAM_IDENTIFIER),
        (cls._ALTERNATE_METRIC_IDENTIFIERS, cls._METRIC_IDENTIFIER),
        (cls._ALTERNATE_TAG_IDENTIFIERS, cls._TAG_IDENTIFIER),
        (cls._ALTERNATE_ATTRIBUTE_IDENTIFIERS, cls._ATTRIBUTE_IDENTIFIER),
        (cls._ALTERNATE_DATASET_IDENTIFIERS, cls._DATASET_IDENTIFIER),
    )
    for aliases, canonical in alias_groups:
        if entity_type in aliases:
            return canonical
    # Already canonical: "metric", "parameter", "tag", "attribute" or "dataset".
    return entity_type
330
+
331
@classmethod
def _get_identifier(cls, identifier, valid_attributes):
    """
    Split ``<entity_type>.<key>`` into a normalized ``{"type": ..., "key": ...}`` dict.

    An identifier without a ``.`` is treated as an attribute key. Quotes and backticks
    around the key are removed.

    :param identifier: Raw identifier text, e.g. ``metrics.acc`` or ``run_name``.
    :param valid_attributes: Attribute keys accepted when the entity type is ``attribute``.
    :return: A dict with the canonical entity ``type`` and the cleaned ``key``.
    :raises MlflowException: On an unknown entity type, attribute key or dataset key.
    """
    # str.partition cannot fail, so (unlike the previous split + try/except ValueError,
    # whose handler was unreachable) no error handling is needed here.
    entity_type, separator, key = identifier.partition(".")
    if not separator:
        entity_type, key = cls._ATTRIBUTE_IDENTIFIER, identifier

    canonical_type = cls._valid_entity_type(entity_type)
    key = cls._trim_backticks(cls._strip_quotes(key))
    if canonical_type == cls._ATTRIBUTE_IDENTIFIER and key not in valid_attributes:
        raise MlflowException.invalid_parameter_value(
            f"Invalid attribute key '{key}' specified. Valid keys are '{valid_attributes}'"
        )
    elif canonical_type == cls._DATASET_IDENTIFIER and key not in cls.DATASET_ATTRIBUTES:
        raise MlflowException.invalid_parameter_value(
            f"Invalid dataset key '{key}' specified. Valid keys are '{cls.DATASET_ATTRIBUTES}'"
        )
    return {"type": canonical_type, "key": key}
358
+
359
@classmethod
def validate_list_supported(cls, key: str) -> None:
    """
    Ensure ``key`` may be compared against a parenthesized list of values.

    :raises MlflowException: Unless ``key`` is ``run_id``.
    """
    if key == "run_id":
        return
    raise MlflowException(
        "Only the 'run_id' attribute supports comparison with a list of quoted "
        "string values.",
        error_code=INVALID_PARAMETER_VALUE,
    )
367
+
368
@classmethod
def _get_value(cls, identifier_type, key, token):
    """
    Extract and validate the right-hand-side value of a comparison.

    :param identifier_type: Canonical entity type (metric, parameter, tag, attribute, dataset).
    :param key: The compared key (metric/param/tag name or attribute/dataset field).
    :param token: The sqlparse token that holds the value.
    :return: The raw numeric text for metrics and numeric attributes, the unquoted string for
        string values, or the result of ``_parse_run_ids`` for parenthesized lists.
    :raises MlflowException: If the value does not fit the identifier type or key.
    """
    if identifier_type == cls._METRIC_IDENTIFIER:
        if token.ttype not in cls.NUMERIC_VALUE_TYPES:
            raise MlflowException(
                f"Expected numeric value type for metric. Found {token.value}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return token.value
    elif identifier_type in (cls._PARAM_IDENTIFIER, cls._TAG_IDENTIFIER):
        # Identifier tokens are accepted too, presumably because sqlparse lexes
        # double-quoted strings as identifiers; _strip_quotes handles both quote styles.
        if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
            return cls._strip_quotes(token.value, expect_quoted_value=True)
        raise MlflowException(
            "Expected a quoted string value for "
            f"{identifier_type} (e.g. 'my-value'). Got value "
            f"{token.value}",
            error_code=INVALID_PARAMETER_VALUE,
        )
    elif identifier_type == cls._ATTRIBUTE_IDENTIFIER:
        if key in cls.NUMERIC_ATTRIBUTES:
            if token.ttype not in cls.NUMERIC_VALUE_TYPES:
                raise MlflowException(
                    f"Expected numeric value type for numeric attribute: {key}. "
                    f"Found {token.value}",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            return token.value
        elif token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
            return cls._strip_quotes(token.value, expect_quoted_value=True)
        elif isinstance(token, Parenthesis):
            cls.validate_list_supported(key)
            return cls._parse_run_ids(token)
        else:
            raise MlflowException(
                f"Expected a quoted string value for attributes. Got value {token.value}",
                error_code=INVALID_PARAMETER_VALUE,
            )
    elif identifier_type == cls._DATASET_IDENTIFIER:
        if key in cls.DATASET_ATTRIBUTES and (
            token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier)
        ):
            return cls._strip_quotes(token.value, expect_quoted_value=True)
        elif isinstance(token, Parenthesis):
            # Every dataset attribute (name, digest, context) supports list comparison;
            # the message now lists all of them instead of omitting 'context'.
            if key not in cls.DATASET_ATTRIBUTES:
                raise MlflowException(
                    "Only the dataset 'name', 'digest' and 'context' support comparison with "
                    "a list of quoted string values.",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            return cls._parse_run_ids(token)
        else:
            raise MlflowException(
                "Expected a quoted string value for dataset attributes. "
                f"Got value {token.value}",
                error_code=INVALID_PARAMETER_VALUE,
            )
    else:
        # Report every supported entity type and flag this as a client error.
        raise MlflowException(
            f"Invalid identifier type '{identifier_type}'. Expected one of {cls._IDENTIFIERS}.",
            error_code=INVALID_PARAMETER_VALUE,
        )
430
+
431
@classmethod
def _validate_comparison(cls, tokens, search_traces=False):
    """
    Sanity-check the whitespace-stripped tokens of one comparison clause.

    :param tokens: Expected to be ``[lhs, comparator, value]``.
    :param search_traces: When True, the left-hand side may also be a ``timestamp`` or
        ``timestamp_ms`` builtin token instead of an ``Identifier``.
    :raises MlflowException: If there are not exactly 3 tokens or the left-hand side is not
        an acceptable identifier.
    """
    base_error_string = "Invalid comparison clause"
    if len(tokens) != 3:
        raise MlflowException(
            f"{base_error_string}. Expected 3 tokens found {len(tokens)}",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if not isinstance(tokens[0], Identifier):
        if not search_traces:
            raise MlflowException(
                f"{base_error_string}. Expected 'Identifier' found '{tokens[0]}'",
                error_code=INVALID_PARAMETER_VALUE,
            )
        # Trace search also allows the timestamp builtins as the left-hand side.
        if search_traces and not tokens[0].match(
            ttype=TokenType.Name.Builtin, values=["timestamp", "timestamp_ms"]
        ):
            raise MlflowException(
                f"{base_error_string}. Expected 'TokenType.Name.Builtin' found '{tokens[0]}'",
                error_code=INVALID_PARAMETER_VALUE,
            )
    # NOTE(review): sqlparse grouped nodes appear to derive from Token, in which case
    # `not isinstance(..., Token)` is always False and the two checks below never raise.
    # Values are validated later by _get_value. Confirm before relying on these checks;
    # note that IN / NOT IN comparators are Keyword tokens, not Operator.Comparison.
    if not isinstance(tokens[1], Token) and tokens[1].ttype != TokenType.Operator.Comparison:
        raise MlflowException(
            f"{base_error_string}. Expected comparison found '{tokens[1]}'",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if not isinstance(tokens[2], Token) and (
        tokens[2].ttype not in cls.STRING_VALUE_TYPES.union(cls.NUMERIC_VALUE_TYPES)
        or isinstance(tokens[2], Identifier)
    ):
        raise MlflowException(
            f"{base_error_string}. Expected value token found '{tokens[2]}'",
            error_code=INVALID_PARAMETER_VALUE,
        )
465
+
466
@classmethod
def _get_comparison(cls, comparison):
    """
    Convert a sqlparse ``Comparison`` into a clause dict.

    :return: A dict with ``type`` and ``key`` (from the left-hand side), ``comparator`` and
        ``value`` entries.
    """
    parts = [tok for tok in comparison.tokens if not tok.is_whitespace]
    # Validate before unpacking so a wrong token count raises MlflowException.
    cls._validate_comparison(parts)
    lhs, op, rhs = parts
    clause = cls._get_identifier(lhs.value, cls.VALID_SEARCH_ATTRIBUTE_KEYS)
    clause["comparator"] = op.value
    clause["value"] = cls._get_value(clause["type"], clause["key"], rhs)
    return clause
474
+
475
@classmethod
def _invalid_statement_token_search_runs(cls, token):
    """
    Return True if ``token`` may not appear at the top level of a run filter.

    Only comparisons, whitespace and the ``AND`` keyword are allowed.
    """
    is_allowed = (
        isinstance(token, Comparison)
        or token.is_whitespace
        or token.match(ttype=TokenType.Keyword, values=["AND"])
    )
    return not is_allowed
484
+
485
@classmethod
def _process_statement(cls, statement):
    """
    Validate a parsed filter statement and convert its comparisons into clause dicts.

    :raises MlflowException: If the statement contains anything other than AND-joined
        comparisons.
    """
    tokens = _join_in_comparison_tokens(statement.tokens)
    invalids = [tok for tok in tokens if cls._invalid_statement_token_search_runs(tok)]
    if invalids:
        invalid_clauses = ", ".join(f"'{tok}'" for tok in invalids)
        raise MlflowException(
            f"Invalid clause(s) in filter string: {invalid_clauses}",
            error_code=INVALID_PARAMETER_VALUE,
        )
    return [cls._get_comparison(tok) for tok in tokens if isinstance(tok, Comparison)]
497
+
498
@classmethod
def parse_search_filter(cls, filter_string):
    """
    Parse a run search filter string into a list of clause dicts.

    :param filter_string: An AND-joined filter expression, e.g.
        ``metrics.acc > 0.9 AND params.model = 'cnn'``. An empty or None value yields no
        clauses.
    :return: A list of dicts with ``type``, ``key``, ``comparator`` and ``value`` entries.
    :raises MlflowException: If the string cannot be parsed, contains more than one
        statement, or contains invalid clauses.
    """
    if not filter_string:
        return []
    try:
        parsed = sqlparse.parse(filter_string)
    except Exception as e:
        # Chain the parser error so the underlying cause is not lost.
        raise MlflowException(
            f"Error on parsing filter '{filter_string}'", error_code=INVALID_PARAMETER_VALUE
        ) from e
    if len(parsed) == 0 or not isinstance(parsed[0], Statement):
        raise MlflowException(
            f"Invalid filter '{filter_string}'. Could not be parsed.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    elif len(parsed) > 1:
        raise MlflowException(
            f"Search filter contained multiple expression {filter_string!r}. "
            "Provide AND-ed expression list.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    return cls._process_statement(parsed[0])
520
+
521
@classmethod
def is_metric(cls, key_type, comparator):
    """
    Return True if ``key_type`` is the metric entity type.

    :raises MlflowException: If it is a metric but ``comparator`` is not one of
        ``VALID_METRIC_COMPARATORS``.
    """
    if key_type == cls._METRIC_IDENTIFIER:
        if comparator not in cls.VALID_METRIC_COMPARATORS:
            # Closing quote added so the message is balanced, matching is_param.
            raise MlflowException(
                f"Invalid comparator '{comparator}' not one of '{cls.VALID_METRIC_COMPARATORS}'",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return True
    return False
531
+
532
@classmethod
def is_param(cls, key_type, comparator):
    """
    Return True if ``key_type`` is the parameter entity type.

    :raises MlflowException: If it is a parameter but ``comparator`` is not one of
        ``VALID_PARAM_COMPARATORS``.
    """
    if key_type != cls._PARAM_IDENTIFIER:
        return False
    if comparator not in cls.VALID_PARAM_COMPARATORS:
        raise MlflowException(
            f"Invalid comparator '{comparator}' not one of '{cls.VALID_PARAM_COMPARATORS}'",
            error_code=INVALID_PARAMETER_VALUE,
        )
    return True
542
+
543
@classmethod
def is_tag(cls, key_type, comparator):
    """
    Return True if ``key_type`` is the tag entity type.

    :raises MlflowException: If it is a tag but ``comparator`` is not one of
        ``VALID_TAG_COMPARATORS``.
    """
    if key_type == cls._TAG_IDENTIFIER:
        if comparator not in cls.VALID_TAG_COMPARATORS:
            # Closing quote added so the message is balanced, matching is_param.
            raise MlflowException(
                f"Invalid comparator '{comparator}' not one of '{cls.VALID_TAG_COMPARATORS}'",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return True
    return False
553
+
554
@classmethod
def is_attribute(cls, key_type, key_name, comparator):
    """
    Return True if the clause targets a run attribute, either string or numeric.

    :raises MlflowException: If the comparator is invalid for that kind of attribute.
    """
    if cls.is_string_attribute(key_type, key_name, comparator):
        return True
    return cls.is_numeric_attribute(key_type, key_name, comparator)
559
+
560
@classmethod
def is_string_attribute(cls, key_type, key_name, comparator):
    """
    Return True if the clause targets a non-numeric run attribute.

    :raises MlflowException: If so, but ``comparator`` is not one of
        ``VALID_STRING_ATTRIBUTE_COMPARATORS``.
    """
    targets_string_attr = (
        key_type == cls._ATTRIBUTE_IDENTIFIER and key_name not in cls.NUMERIC_ATTRIBUTES
    )
    if not targets_string_attr:
        return False
    if comparator not in cls.VALID_STRING_ATTRIBUTE_COMPARATORS:
        raise MlflowException(
            f"Invalid comparator '{comparator}' not one of "
            f"'{cls.VALID_STRING_ATTRIBUTE_COMPARATORS}'",
            error_code=INVALID_PARAMETER_VALUE,
        )
    return True
571
+
572
@classmethod
def is_numeric_attribute(cls, key_type, key_name, comparator):
    """
    Return True if the clause targets a numeric run attribute (see ``NUMERIC_ATTRIBUTES``).

    :raises MlflowException: If so, but ``comparator`` is not one of
        ``VALID_NUMERIC_ATTRIBUTE_COMPARATORS``.
    """
    if key_type == cls._ATTRIBUTE_IDENTIFIER and key_name in cls.NUMERIC_ATTRIBUTES:
        if comparator not in cls.VALID_NUMERIC_ATTRIBUTE_COMPARATORS:
            # Report the numeric comparators that were actually checked; the message
            # previously listed the string-attribute comparators.
            raise MlflowException(
                f"Invalid comparator '{comparator}' not one of "
                f"'{cls.VALID_NUMERIC_ATTRIBUTE_COMPARATORS}'",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return True
    return False
583
+
584
@classmethod
def is_dataset(cls, key_type, comparator):
    """
    Return True if ``key_type`` is the dataset entity type.

    :raises MlflowException: If it is a dataset but ``comparator`` is not one of
        ``VALID_DATASET_COMPARATORS``.
    """
    if key_type == cls._DATASET_IDENTIFIER:
        if comparator not in cls.VALID_DATASET_COMPARATORS:
            # Closing quote added so the message is balanced, matching is_param.
            raise MlflowException(
                f"Invalid comparator '{comparator}' "
                f"not one of '{cls.VALID_DATASET_COMPARATORS}'",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return True
    return False
595
+
596
@classmethod
def _is_metric_on_dataset(cls, metric: Metric, dataset: dict[str, Any]) -> bool:
    """
    Return True if ``metric`` was logged against ``dataset``.

    The dataset names must match; the digest is compared only when ``dataset`` provides a
    ``dataset_digest``.
    """
    if metric.dataset_name != dataset.get("dataset_name"):
        return False
    expected_digest = dataset.get("dataset_digest")
    return expected_digest is None or expected_digest == metric.dataset_digest
602
+
603
@classmethod
def _does_run_match_clause(cls, run, sed):
    """
    Evaluate one parsed filter clause against a run, in memory.

    :param run: The run to test; reads ``run.data``, ``run.info`` and ``run.inputs``.
    :param sed: A clause dict as produced by ``parse_search_filter`` (``type``, ``key``,
        ``comparator``, ``value``).
    :return: True if the run satisfies the clause. A missing metric/param/tag (or a ``None``
        attribute value) never matches, whatever the comparator, ``!=`` included.
    :raises MlflowException: If the comparator is not valid for the key type, or the key
        type is unknown.
    """
    key_type = sed.get("type")
    key = sed.get("key")
    value = sed.get("value")
    # Comparators are normalized to upper case (e.g. "like" -> "LIKE").
    comparator = sed.get("comparator").upper()

    # Map aliases such as "created" / "Run Name" to the real attribute names.
    key = SearchUtils.translate_key_alias(key)

    # Each is_* check is True only for its own key type and raises on a comparator that is
    # invalid for that type.
    if cls.is_metric(key_type, comparator):
        lhs = run.data.metrics.get(key, None)
        # The parser keeps metric values as numeric text.
        value = float(value)
    elif cls.is_param(key_type, comparator):
        lhs = run.data.params.get(key, None)
    elif cls.is_tag(key_type, comparator):
        lhs = run.data.tags.get(key, None)
    elif cls.is_string_attribute(key_type, key, comparator):
        lhs = getattr(run.info, key)
    elif cls.is_numeric_attribute(key_type, key, comparator):
        lhs = getattr(run.info, key)
        value = int(value)
    elif cls.is_dataset(key_type, comparator):
        # A dataset clause matches if any of the run's dataset inputs satisfies it.
        if key == "context":
            # The context lives in the dataset input's MLFLOW_DATASET_CONTEXT tag.
            return any(
                SearchUtils.get_comparison_func(comparator)(tag.value if tag else None, value)
                for dataset_input in run.inputs.dataset_inputs
                for tag in dataset_input.tags
                if tag.key == MLFLOW_DATASET_CONTEXT
            )
        else:
            return any(
                SearchUtils.get_comparison_func(comparator)(
                    getattr(dataset_input.dataset, key), value
                )
                for dataset_input in run.inputs.dataset_inputs
            )
    else:
        raise MlflowException(
            f"Invalid search expression type '{key_type}'", error_code=INVALID_PARAMETER_VALUE
        )
    if lhs is None:
        return False

    return SearchUtils.get_comparison_func(comparator)(lhs, value)
647
+
648
@classmethod
def _does_model_match_clause(cls, model, sed):
    """
    Evaluate one parsed filter clause against a model, in memory.

    :param model: The model to test; reads ``model.metrics`` (a sequence of objects with
        ``key``/``value``), the ``model.params`` / ``model.tags`` mappings and ``model.info``.
    :param sed: A clause dict with ``type``, ``key``, ``comparator`` and ``value`` entries.
    :return: True if the model satisfies the clause; a missing value never matches.
    :raises MlflowException: If the comparator is not valid for the key type, or the key
        type is not supported for models (dataset clauses are not handled here).
    """
    key_type = sed.get("type")
    key = sed.get("key")
    value = sed.get("value")
    comparator = sed.get("comparator").upper()

    key = SearchUtils.translate_key_alias(key)

    if cls.is_metric(key_type, comparator):
        # If several metric entries share the key, the first one is used.
        matching_metrics = [metric for metric in model.metrics if metric.key == key]
        lhs = matching_metrics[0].value if matching_metrics else None
        value = float(value)
    elif cls.is_param(key_type, comparator):
        lhs = model.params.get(key, None)
    elif cls.is_tag(key_type, comparator):
        lhs = model.tags.get(key, None)
    # NOTE(review): attributes are read from `model.info` here, whereas
    # _get_model_value_for_sort reads them directly from `model`; confirm the model object
    # actually exposes `.info`. The numeric/string split also uses the run-oriented
    # NUMERIC_ATTRIBUTES set.
    elif cls.is_string_attribute(key_type, key, comparator):
        lhs = getattr(model.info, key)
    elif cls.is_numeric_attribute(key_type, key, comparator):
        lhs = getattr(model.info, key)
        value = int(value)
    else:
        raise MlflowException(
            f"Invalid model search expression type '{key_type}'",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if lhs is None:
        return False

    return SearchUtils.get_comparison_func(comparator)(lhs, value)
679
+
680
@classmethod
def filter(cls, runs, filter_string):
    """
    Return the runs that satisfy every clause of ``filter_string``.

    An empty filter returns ``runs`` itself, unfiltered.
    """
    if not filter_string:
        return runs
    clauses = cls.parse_search_filter(filter_string)
    return [
        run
        for run in runs
        if all(cls._does_run_match_clause(run, clause) for clause in clauses)
    ]
691
+
692
@classmethod
def _validate_order_by_and_generate_token(cls, order_by):
    """
    Parse an ``order_by`` clause and return its text for further splitting.

    Accepted shapes are a single identifier token, the bare ``timestamp`` key, or
    ``timestamp`` followed by whitespace and an order keyword.

    :param order_by: The raw ``order_by`` clause.
    :return: The identifier text, ``"timestamp"``, or ``"timestamp <order keyword>"``.
    :raises MlflowException: If the clause cannot be parsed or has an unsupported shape.
    """
    try:
        parsed = sqlparse.parse(order_by)
    except Exception as e:
        # Chain the parser error so the underlying cause is not lost.
        raise MlflowException(
            f"Error on parsing order_by clause '{order_by}'",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e
    if len(parsed) != 1 or not isinstance(parsed[0], Statement):
        raise MlflowException(
            f"Invalid order_by clause '{order_by}'. Could not be parsed.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    statement = parsed[0]
    # sqlparse >= 0.4.3 lexes `timestamp` as a builtin name; older versions as a keyword.
    ttype_for_timestamp = (
        TokenType.Name.Builtin
        if Version(sqlparse.__version__) >= Version("0.4.3")
        else TokenType.Keyword
    )
    tokens = statement.tokens

    if len(tokens) == 1 and isinstance(tokens[0], Identifier):
        token_value = tokens[0].value
    elif len(tokens) == 1 and tokens[0].match(
        ttype=ttype_for_timestamp, values=[cls.ORDER_BY_KEY_TIMESTAMP]
    ):
        token_value = cls.ORDER_BY_KEY_TIMESTAMP
    elif (
        tokens[0].match(ttype=ttype_for_timestamp, values=[cls.ORDER_BY_KEY_TIMESTAMP])
        and all(token.is_whitespace for token in tokens[1:-1])
        and tokens[-1].ttype == TokenType.Keyword.Order
    ):
        token_value = cls.ORDER_BY_KEY_TIMESTAMP + " " + tokens[-1].value
    else:
        raise MlflowException(
            f"Invalid order_by clause '{order_by}'. Could not be parsed.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    return token_value
733
+
734
@classmethod
def _parse_order_by_string(cls, order_by):
    """
    Split an ``order_by`` clause into its key text and sort direction.

    :param order_by: E.g. ``metrics.acc DESC`` or ``\\`run name\\` ASC``.
    :return: ``(key_text, is_ascending)``; the direction defaults to ascending.
    :raises MlflowException: If the clause cannot be parsed (including unbalanced quotes),
        has more than two parts, or the direction is not ``asc``/``desc``.
    """
    token_value = cls._validate_order_by_and_generate_token(order_by)
    is_ascending = True
    try:
        # Backticks become double quotes so shlex keeps a quoted key with spaces as one token.
        tokens = shlex.split(token_value.replace("`", '"'))
    except ValueError as e:
        # shlex raises ValueError (e.g. "No closing quotation") on unbalanced quotes; report
        # it as a client error instead of letting it escape.
        raise MlflowException(
            f"Invalid order_by clause '{order_by}'. Could not be parsed.",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e
    if len(tokens) > 2:
        raise MlflowException(
            f"Invalid order_by clause '{order_by}'. Could not be parsed.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    elif len(tokens) == 2:
        order_token = tokens[1].lower()
        if order_token not in cls.VALID_ORDER_BY_TAGS:
            raise MlflowException(
                f"Invalid ordering key in order_by clause '{order_by}'.",
                error_code=INVALID_PARAMETER_VALUE,
            )
        is_ascending = order_token == cls.ASC_OPERATOR
        token_value = tokens[0]
    return token_value, is_ascending
754
+
755
@classmethod
def parse_order_by_for_search_runs(cls, order_by):
    """
    Parse a run ``order_by`` clause.

    :return: ``(key_type, key, is_ascending)`` with a canonical entity type.
    """
    key_text, ascending = cls._parse_order_by_string(order_by)
    parsed = cls._get_identifier(key_text.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS)
    return parsed["type"], parsed["key"], ascending
760
+
761
@classmethod
def parse_order_by_for_search_registered_models(cls, order_by):
    """
    Parse a registered-model ``order_by`` clause.

    :return: ``(key, is_ascending)`` where ``key`` is in ``VALID_ORDER_BY_KEYS_REGISTERED_MODELS``.
    :raises MlflowException: If the key is not an allowed registered-model order-by key.
    """
    key_text, ascending = cls._parse_order_by_string(order_by)
    key_text = key_text.strip()
    if key_text in cls.VALID_ORDER_BY_KEYS_REGISTERED_MODELS:
        return key_text, ascending
    raise MlflowException(
        f"Invalid order by key '{key_text}' specified. Valid keys "
        f"are '{cls.RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS}'",
        error_code=INVALID_PARAMETER_VALUE,
    )
772
+
773
@classmethod
def _get_value_for_sort(cls, run, key_type, key, ascending):
    """
    Build the sort key of ``run`` for one order_by clause.

    :param run: The run being sorted.
    :param key_type: Canonical entity type (metric, parameter, tag or attribute).
    :param key: Metric/param/tag/attribute name; aliases are translated first.
    :param ascending: Direction of the clause (``sort`` passes ``reverse=not ascending``).
    :return: A ``(flag, value)`` tuple arranged so missing (None) and NaN values sort after
        all real values in either direction, with NaN ahead of None.
    :raises MlflowException: If ``key_type`` is not an orderable entity type.
    """
    key = SearchUtils.translate_key_alias(key)
    if key_type == cls._METRIC_IDENTIFIER:
        raw = run.data.metrics.get(key)
    elif key_type == cls._PARAM_IDENTIFIER:
        raw = run.data.params.get(key)
    elif key_type == cls._TAG_IDENTIFIER:
        raw = run.data.tags.get(key)
    elif key_type == cls._ATTRIBUTE_IDENTIFIER:
        raw = getattr(run.info, key)
    else:
        raise MlflowException(
            f"Invalid order_by entity type '{key_type}'", error_code=INVALID_PARAMETER_VALUE
        )

    missing = raw is None
    nan = isinstance(raw, float) and math.isnan(raw)
    if missing:
        raw = math.inf if ascending else -math.inf
    elif nan:
        raw = -math.inf if ascending else math.inf
    trailing = missing or nan
    # The flag is inverted for descending order because the whole key is reversed then.
    return (trailing, raw) if ascending else (not trailing, raw)
804
+
805
@classmethod
def _get_model_value_for_sort(cls, model, key_type, key, ascending):
    """
    Build the sort key of ``model`` for one order_by clause.

    Metric values come from the first entry in ``model.metrics`` with a matching key;
    attributes are read directly from ``model``.

    :return: A ``(flag, value)`` tuple arranged so missing (None) and NaN values sort after
        all real values in either direction, with NaN ahead of None.
    :raises MlflowException: If ``key_type`` is not an orderable entity type.
    """
    key = SearchUtils.translate_key_alias(key)
    if key_type == cls._METRIC_IDENTIFIER:
        first_match = next((m for m in model.metrics if m.key == key), None)
        raw = None if first_match is None else float(first_match.value)
    elif key_type == cls._PARAM_IDENTIFIER:
        raw = model.params.get(key)
    elif key_type == cls._TAG_IDENTIFIER:
        raw = model.tags.get(key)
    elif key_type == cls._ATTRIBUTE_IDENTIFIER:
        raw = getattr(model, key)
    else:
        raise MlflowException(
            f"Invalid models order_by entity type '{key_type}'",
            error_code=INVALID_PARAMETER_VALUE,
        )

    missing = raw is None
    nan = isinstance(raw, float) and math.isnan(raw)
    if missing:
        raw = math.inf if ascending else -math.inf
    elif nan:
        raw = -math.inf if ascending else math.inf
    trailing = missing or nan
    # The flag is inverted for descending order because the whole key is reversed then.
    return (trailing, raw) if ascending else (not trailing, raw)
838
+
839
@classmethod
def sort(cls, runs, order_by_list):
    """Sorts a set of runs based on their natural ordering and an overriding set of order_bys.
    Runs are naturally ordered first by start time descending, then by run id for tie-breaking.
    """
    ordered = sorted(runs, key=lambda r: (-r.info.start_time, r.info.run_id))
    if not order_by_list:
        return ordered
    # Python's sort is stable, so applying the clauses from last to first gives a
    # lexicographic ordering in which the first clause takes precedence.
    for clause in reversed(order_by_list):
        key_type, key, ascending = cls.parse_order_by_for_search_runs(clause)
        ordered.sort(
            key=lambda r: cls._get_value_for_sort(r, key_type, key, ascending),
            reverse=not ascending,
        )
    return ordered
858
+
859
+ @classmethod
860
+ def parse_start_offset_from_page_token(cls, page_token):
861
+ # Note: the page_token is expected to be a base64-encoded JSON that looks like
862
+ # { "offset": xxx }. However, this format is not stable, so it should not be
863
+ # relied upon outside of this method.
864
+ if not page_token:
865
+ return 0
866
+
867
+ try:
868
+ decoded_token = base64.b64decode(page_token)
869
+ except TypeError:
870
+ raise MlflowException(
871
+ "Invalid page token, could not base64-decode", error_code=INVALID_PARAMETER_VALUE
872
+ )
873
+ except base64.binascii.Error:
874
+ raise MlflowException(
875
+ "Invalid page token, could not base64-decode", error_code=INVALID_PARAMETER_VALUE
876
+ )
877
+
878
+ try:
879
+ parsed_token = json.loads(decoded_token)
880
+ except ValueError:
881
+ raise MlflowException(
882
+ f"Invalid page token, decoded value={decoded_token}",
883
+ error_code=INVALID_PARAMETER_VALUE,
884
+ )
885
+
886
+ offset_str = parsed_token.get("offset")
887
+ if not offset_str:
888
+ raise MlflowException(
889
+ f"Invalid page token, parsed value={parsed_token}",
890
+ error_code=INVALID_PARAMETER_VALUE,
891
+ )
892
+
893
+ try:
894
+ offset = int(offset_str)
895
+ except ValueError:
896
+ raise MlflowException(
897
+ f"Invalid page token, not stringable {offset_str}",
898
+ error_code=INVALID_PARAMETER_VALUE,
899
+ )
900
+
901
+ return offset
902
+
903
+ @classmethod
904
+ def create_page_token(cls, offset):
905
+ return base64.b64encode(json.dumps({"offset": offset}).encode("utf-8"))
906
+
907
+ @classmethod
908
+ def paginate(cls, runs, page_token, max_results):
909
+ """Paginates a set of runs based on an offset encoded into the page_token and a max
910
+ results limit. Returns a pair containing the set of paginated runs, followed by
911
+ an optional next_page_token if there are further results that need to be returned.
912
+ """
913
+ start_offset = cls.parse_start_offset_from_page_token(page_token)
914
+ final_offset = start_offset + max_results
915
+
916
+ paginated_runs = runs[start_offset:final_offset]
917
+ next_page_token = None
918
+ if final_offset < len(runs):
919
+ next_page_token = cls.create_page_token(final_offset)
920
+ return (paginated_runs, next_page_token)
921
+
922
    # Model Registry specific parser
    # TODO: Tech debt. Refactor search code into common utils, tracking server, and model
    # registry specific code.

    # Search keys for model-version queries. NOTE(review): not referenced in this part of
    # the file; presumably consumed by model-registry callers elsewhere — confirm.
    VALID_SEARCH_KEYS_FOR_MODEL_VERSIONS = {"name", "run_id", "source_path"}
    # Search keys for registered-model queries (same caveat as above).
    VALID_SEARCH_KEYS_FOR_REGISTERED_MODELS = {"name"}
928
+
929
+ @classmethod
930
+ def _check_valid_identifier_list(cls, tup: tuple[Any, ...]) -> None:
931
+ """
932
+ Validate that `tup` is a non-empty tuple of strings.
933
+ """
934
+ if len(tup) == 0:
935
+ raise MlflowException(
936
+ "While parsing a list in the query,"
937
+ " expected a non-empty list of string values, but got empty list",
938
+ error_code=INVALID_PARAMETER_VALUE,
939
+ )
940
+
941
+ if not all(isinstance(x, str) for x in tup):
942
+ raise MlflowException(
943
+ "While parsing a list in the query, expected string value, punctuation, "
944
+ f"or whitespace, but got different type in list: {tup}",
945
+ error_code=INVALID_PARAMETER_VALUE,
946
+ )
947
+
948
+ @classmethod
949
+ def _parse_list_from_sql_token(cls, token):
950
+ try:
951
+ parsed = ast.literal_eval(token.value)
952
+ except SyntaxError as e:
953
+ raise MlflowException(
954
+ "While parsing a list in the query,"
955
+ " expected a non-empty list of string values, but got ill-formed list.",
956
+ error_code=INVALID_PARAMETER_VALUE,
957
+ ) from e
958
+
959
+ parsed = parsed if isinstance(parsed, tuple) else (parsed,)
960
+ cls._check_valid_identifier_list(parsed)
961
+ return parsed
962
+
963
+ @classmethod
964
+ def _parse_run_ids(cls, token):
965
+ run_id_list = cls._parse_list_from_sql_token(token)
966
+ # Because MySQL IN clause is case-insensitive, but all run_ids only contain lower
967
+ # case letters, so that we filter out run_ids containing upper case letters here.
968
+ return [run_id for run_id in run_id_list if run_id.islower()]
969
+
970
+
971
class SearchExperimentsUtils(SearchUtils):
    """Parse, filter and sort experiments for experiment search."""

    VALID_SEARCH_ATTRIBUTE_KEYS = {"name", "creation_time", "last_update_time"}
    VALID_ORDER_BY_ATTRIBUTE_KEYS = {"name", "experiment_id", "creation_time", "last_update_time"}
    NUMERIC_ATTRIBUTES = {"creation_time", "last_update_time"}

    @classmethod
    def _invalid_statement_token_search_experiments(cls, token):
        """Return True if ``token`` is not a comparison, whitespace, or the ``AND`` keyword."""
        if (
            isinstance(token, Comparison)
            or token.is_whitespace
            or token.match(ttype=TokenType.Keyword, values=["AND"])
        ):
            return False
        return True

    @classmethod
    def _process_statement(cls, statement):
        """Convert a parsed filter statement into a list of comparison dicts.

        Raises:
            MlflowException: If the statement contains anything other than comparisons
                joined by ``AND``.
        """
        tokens = _join_in_comparison_tokens(statement.tokens)
        invalids = list(filter(cls._invalid_statement_token_search_experiments, tokens))
        if len(invalids) > 0:
            invalid_clauses = ", ".join(map(str, invalids))
            raise MlflowException.invalid_parameter_value(
                f"Invalid clause(s) in filter string: {invalid_clauses}"
            )
        return [cls._get_comparison(t) for t in tokens if isinstance(t, Comparison)]

    @classmethod
    def _get_identifier(cls, identifier, valid_attributes):
        """Split ``identifier`` into ``{"type": ..., "key": ...}``.

        A bare key (no ``.``) is treated as an attribute. Otherwise the text before the
        first ``.`` must be ``attribute``, ``tag`` or ``tags`` and is normalized through
        ``_valid_entity_type``. Quotes and backticks are stripped from the key before the
        attribute key is validated.

        Raises:
            MlflowException: If the entity type is unknown, or an attribute key is not in
                ``valid_attributes``.
        """
        tokens = identifier.split(".", maxsplit=1)
        if len(tokens) == 1:
            key = tokens[0]
            identifier = cls._ATTRIBUTE_IDENTIFIER
        else:
            entity_type, key = tokens
            valid_entity_types = ("attribute", "tag", "tags")
            if entity_type not in valid_entity_types:
                raise MlflowException.invalid_parameter_value(
                    f"Invalid entity type '{entity_type}'. "
                    f"Valid entity types are {valid_entity_types}"
                )
            identifier = cls._valid_entity_type(entity_type)

        key = cls._trim_backticks(cls._strip_quotes(key))
        if identifier == cls._ATTRIBUTE_IDENTIFIER and key not in valid_attributes:
            raise MlflowException.invalid_parameter_value(
                f"Invalid attribute key '{key}' specified. Valid keys are '{valid_attributes}'"
            )
        return {"type": identifier, "key": key}

    @classmethod
    def _get_comparison(cls, comparison):
        """Turn a sqlparse ``Comparison`` into a ``{type, key, comparator, value}`` dict."""
        stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace]
        cls._validate_comparison(stripped_comparison)
        left, comparator, right = stripped_comparison
        comp = cls._get_identifier(left.value, cls.VALID_SEARCH_ATTRIBUTE_KEYS)
        comp["comparator"] = comparator.value
        comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), right)
        return comp

    @classmethod
    def parse_order_by_for_search_experiments(cls, order_by):
        """Parse one ``order_by`` string into ``(type, key, is_ascending)``."""
        token_value, is_ascending = cls._parse_order_by_string(order_by)
        identifier = cls._get_identifier(token_value.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS)
        return identifier["type"], identifier["key"], is_ascending

    @classmethod
    def is_attribute(cls, key_type, comparator):
        """Return True if ``key_type`` is the attribute identifier.

        Raises:
            MlflowException: With ``INVALID_PARAMETER_VALUE`` if ``key_type`` is an attribute
                but ``comparator`` is not a supported string-attribute comparator.
        """
        if key_type == cls._ATTRIBUTE_IDENTIFIER:
            if comparator not in cls.VALID_STRING_ATTRIBUTE_COMPARATORS:
                raise MlflowException(
                    f"Invalid comparator '{comparator}' not one of "
                    f"'{cls.VALID_STRING_ATTRIBUTE_COMPARATORS}'",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            return True
        return False

    @classmethod
    def _does_experiment_match_clause(cls, experiment, sed):
        """Return True if ``experiment`` satisfies the parsed filter clause ``sed``.

        Args:
            experiment: The experiment to test.
            sed: A parsed clause dict with ``type``, ``key``, ``value`` and ``comparator``.

        Raises:
            MlflowException: If the clause refers to an unsupported entity type.
        """
        key_type = sed.get("type")
        key = sed.get("key")
        value = sed.get("value")
        comparator = sed.get("comparator").upper()

        if cls.is_string_attribute(key_type, key, comparator):
            lhs = getattr(experiment, key)
        elif cls.is_numeric_attribute(key_type, key, comparator):
            lhs = getattr(experiment, key)
            value = float(value)
        elif cls.is_tag(key_type, comparator):
            lhs = experiment.tags.get(key)
            # A missing tag, or a tag whose value is None, never matches. Return a real
            # boolean: returning the experiment object here would be truthy and make the
            # clause match regardless of comparator and value.
            if lhs is None:
                return False
        else:
            raise MlflowException(
                f"Invalid search expression type '{key_type}'", error_code=INVALID_PARAMETER_VALUE
            )

        return SearchUtils.get_comparison_func(comparator)(lhs, value)

    @classmethod
    def filter(cls, experiments, filter_string):
        """Return the experiments matching every clause of ``filter_string``.

        An empty filter returns ``experiments`` unchanged.
        """
        if not filter_string:
            return experiments
        parsed = cls.parse_search_filter(filter_string)

        def experiment_matches(experiment):
            return all(cls._does_experiment_match_clause(experiment, s) for s in parsed)

        return list(filter(experiment_matches, experiments))

    @classmethod
    def _get_sort_key(cls, order_by_list):
        """Build a ``sorted`` key function from ``order_by_list``.

        ``order_by_list`` may be empty or ``None``. ``experiment_id`` descending is appended
        as a tie-breaker unless the caller already orders by it. ``None`` attribute values
        sort last in both directions.

        Raises:
            MlflowException: If an order-by clause refers to a non-attribute entity.
        """
        order_by = []
        # Accept None like the registered-model / model-version sort helpers do.
        parsed_order_by = map(cls.parse_order_by_for_search_experiments, order_by_list or [])
        for type_, key, ascending in parsed_order_by:
            if type_ == "attribute":
                order_by.append((key, ascending))
            else:
                raise MlflowException.invalid_parameter_value(f"Invalid order_by entity: {type_}")

        # Add a tie-breaker
        if not any(key == "experiment_id" for key, _ in order_by):
            order_by.append(("experiment_id", False))

        # https://stackoverflow.com/a/56842689
        class _Sorter:
            def __init__(self, obj, ascending):
                self.obj = obj
                self.ascending = ascending

            # Only < and == are needed for use as a key parameter in the sorted function
            def __eq__(self, other):
                return other.obj == self.obj

            def __lt__(self, other):
                # None is treated as the largest value so it always sorts last.
                if self.obj is None:
                    return False
                elif other.obj is None:
                    return True
                elif self.ascending:
                    return self.obj < other.obj
                else:
                    return other.obj < self.obj

        def _apply_sorter(experiment, key, ascending):
            attr = getattr(experiment, key)
            return _Sorter(attr, ascending)

        return lambda experiment: tuple(_apply_sorter(experiment, k, asc) for (k, asc) in order_by)

    @classmethod
    def sort(cls, experiments, order_by_list):
        """Return a new list of ``experiments`` ordered according to ``order_by_list``."""
        return sorted(experiments, key=cls._get_sort_key(order_by_list))
1126
+
1127
+
1128
+ # https://stackoverflow.com/a/56842689
1129
+ class _Reversor:
1130
+ def __init__(self, obj):
1131
+ self.obj = obj
1132
+
1133
+ # Only need < and == are needed for use as a key parameter in the sorted function
1134
+ def __eq__(self, other):
1135
+ return other.obj == self.obj
1136
+
1137
+ def __lt__(self, other):
1138
+ if self.obj is None:
1139
+ return False
1140
+ if other.obj is None:
1141
+ return True
1142
+ return other.obj < self.obj
1143
+
1144
+
1145
+ def _apply_reversor(model, key, ascending):
1146
+ attr = getattr(model, key)
1147
+ return attr if ascending else _Reversor(attr)
1148
+
1149
+
1150
class SearchModelUtils(SearchUtils):
    """Filter and sort registered models for registered-model search."""

    NUMERIC_ATTRIBUTES = {"creation_timestamp", "last_updated_timestamp"}
    VALID_SEARCH_ATTRIBUTE_KEYS = {"name"}
    VALID_ORDER_BY_KEYS_REGISTERED_MODELS = {"name", "creation_timestamp", "last_updated_timestamp"}

    @classmethod
    def _does_registered_model_match_clauses(cls, model, sed):
        """Return True if ``model`` satisfies the parsed filter clause ``sed``.

        Args:
            model: The registered model to test.
            sed: A parsed clause dict with ``type``, ``key``, ``value`` and ``comparator``.

        Raises:
            MlflowException: If the clause refers to an unsupported entity type.
        """
        key_type = sed.get("type")
        key = sed.get("key")
        value = sed.get("value")
        comparator = sed.get("comparator").upper()

        # The supported comparators for each entity type are enforced by the
        # is_string_attribute / is_numeric_attribute / is_tag checks below.
        if cls.is_string_attribute(key_type, key, comparator):
            lhs = getattr(model, key)
        elif cls.is_numeric_attribute(key_type, key, comparator):
            lhs = getattr(model, key)
            value = int(value)
        elif cls.is_tag(key_type, comparator):
            # NB: We should use the private attribute `_tags` instead of the `tags` property
            # to consider all tags including reserved ones.
            lhs = model._tags.get(key, None)
        else:
            raise MlflowException(
                f"Invalid search expression type '{key_type}'", error_code=INVALID_PARAMETER_VALUE
            )

        # NB: Handling the special `mlflow.prompt.is_prompt` tag. This tag is used for
        # distinguishing between prompt models and normal models. For example, we want to
        # search for models only by the following filter string:
        #
        #   tags.`mlflow.prompt.is_prompt` != 'true'
        #   tags.`mlflow.prompt.is_prompt` = 'false'
        #
        # However, models do not have this tag, so lhs is None in this case. Instead of returning
        # False like normal tag filter, we need to return True here.
        if key == IS_PROMPT_TAG_KEY and lhs is None:
            return (comparator == "=" and value == "false") or (
                comparator == "!=" and value == "true"
            )

        if lhs is None:
            return False

        return SearchUtils.get_comparison_func(comparator)(lhs, value)

    @classmethod
    def filter(cls, registered_models, filter_string):
        """Filters a set of registered models based on a search filter string."""
        if not filter_string:
            return registered_models
        parsed = cls.parse_search_filter(filter_string)

        def registered_model_matches(model):
            return all(cls._does_registered_model_match_clauses(model, s) for s in parsed)

        return [
            registered_model
            for registered_model in registered_models
            if registered_model_matches(registered_model)
        ]

    @classmethod
    def parse_order_by_for_search_registered_models(cls, order_by):
        """Parse one ``order_by`` string into ``(type, key, is_ascending)``."""
        token_value, is_ascending = cls._parse_order_by_string(order_by)
        identifier = SearchExperimentsUtils._get_identifier(
            token_value.strip(), cls.VALID_ORDER_BY_KEYS_REGISTERED_MODELS
        )
        return identifier["type"], identifier["key"], is_ascending

    @classmethod
    def _get_sort_key(cls, order_by_list):
        """Build a ``sorted`` key function from ``order_by_list`` (may be empty or None).

        ``name`` ascending is appended as a tie-breaker unless already present.

        Raises:
            MlflowException: If an order-by clause refers to a non-attribute entity.
        """
        order_by = []
        parsed_order_by = map(cls.parse_order_by_for_search_registered_models, order_by_list or [])
        for type_, key, ascending in parsed_order_by:
            if type_ == "attribute":
                order_by.append((key, ascending))
            else:
                raise MlflowException.invalid_parameter_value(f"Invalid order_by entity: {type_}")

        # Add a tie-breaker
        if not any(key == "name" for key, _ in order_by):
            order_by.append(("name", True))

        return lambda model: tuple(_apply_reversor(model, k, asc) for (k, asc) in order_by)

    @classmethod
    def sort(cls, models, order_by_list):
        """Return a new list of ``models`` ordered according to ``order_by_list``."""
        return sorted(models, key=cls._get_sort_key(order_by_list))

    @classmethod
    def _process_statement(cls, statement):
        """Convert a parsed filter statement into a list of comparison dicts.

        Raises:
            MlflowException: If the statement contains anything other than comparisons
                joined by ``AND``.
        """
        tokens = _join_in_comparison_tokens(statement.tokens)
        invalids = list(filter(cls._invalid_statement_token_search_model_registry, tokens))
        if len(invalids) > 0:
            invalid_clauses = ", ".join(map(str, invalids))
            raise MlflowException.invalid_parameter_value(
                f"Invalid clause(s) in filter string: {invalid_clauses}"
            )
        return [cls._get_comparison(t) for t in tokens if isinstance(t, Comparison)]

    @classmethod
    def _get_model_search_identifier(cls, identifier, valid_attributes):
        """Split ``identifier`` into ``{"type": ..., "key": ...}`` for model search.

        A bare key is an attribute. ``attribute.``, ``tag.`` and ``tags.`` prefixes are
        accepted. Quotes and backticks are stripped from the key before validation.

        Raises:
            MlflowException: If the entity type is unknown, or an attribute key is not in
                ``valid_attributes``.
        """
        tokens = identifier.split(".", maxsplit=1)
        if len(tokens) == 1:
            key = tokens[0]
            identifier = cls._ATTRIBUTE_IDENTIFIER
        else:
            entity_type, key = tokens
            valid_entity_types = ("attribute", "tag", "tags")
            if entity_type not in valid_entity_types:
                raise MlflowException.invalid_parameter_value(
                    f"Invalid entity type '{entity_type}'. "
                    f"Valid entity types are {valid_entity_types}"
                )
            identifier = (
                cls._TAG_IDENTIFIER if entity_type in ("tag", "tags") else cls._ATTRIBUTE_IDENTIFIER
            )

        # Strip quoting before validating, so a quoted attribute such as `name` is accepted
        # (consistent with SearchExperimentsUtils._get_identifier).
        key = cls._trim_backticks(cls._strip_quotes(key))
        if identifier == cls._ATTRIBUTE_IDENTIFIER and key not in valid_attributes:
            raise MlflowException.invalid_parameter_value(
                f"Invalid attribute key '{key}' specified. Valid keys are '{valid_attributes}'"
            )

        return {"type": identifier, "key": key}

    @classmethod
    def _get_comparison(cls, comparison):
        """Turn a sqlparse ``Comparison`` into a ``{type, key, comparator, value}`` dict."""
        stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace]
        cls._validate_comparison(stripped_comparison)
        left, comparator, right = stripped_comparison
        comp = cls._get_model_search_identifier(left.value, cls.VALID_SEARCH_ATTRIBUTE_KEYS)
        comp["comparator"] = comparator.value.upper()
        comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), right)
        return comp

    @classmethod
    def _get_value(cls, identifier_type, key, token):
        """Extract the right-hand-side value of a comparison from ``token``.

        Tags require a quoted string. Attributes accept a quoted string, or a parenthesized
        list for ``run_id`` only.

        Raises:
            MlflowException: With ``INVALID_PARAMETER_VALUE`` for unsupported value forms or
                identifier types.
        """
        if identifier_type == cls._TAG_IDENTIFIER:
            if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
                return cls._strip_quotes(token.value, expect_quoted_value=True)
            raise MlflowException(
                "Expected a quoted string value for "
                f"{identifier_type} (e.g. 'my-value'). Got value "
                f"{token.value}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        elif identifier_type == cls._ATTRIBUTE_IDENTIFIER:
            if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
                return cls._strip_quotes(token.value, expect_quoted_value=True)
            elif isinstance(token, Parenthesis):
                if key != "run_id":
                    raise MlflowException(
                        "Only the 'run_id' attribute supports comparison with a list of quoted "
                        "string values.",
                        error_code=INVALID_PARAMETER_VALUE,
                    )
                return cls._parse_run_ids(token)
            else:
                raise MlflowException(
                    "Expected a quoted string value or a list of quoted string values for "
                    f"attributes. Got value {token.value}",
                    error_code=INVALID_PARAMETER_VALUE,
                )
        else:
            # Expected to be either "param" or "metric".
            raise MlflowException(
                "Invalid identifier type. Expected one of "
                f"{[cls._ATTRIBUTE_IDENTIFIER, cls._TAG_IDENTIFIER]}.",
                error_code=INVALID_PARAMETER_VALUE,
            )

    @classmethod
    def _invalid_statement_token_search_model_registry(cls, token):
        """Return True if ``token`` is not a comparison, whitespace, or the ``AND`` keyword."""
        if (
            isinstance(token, Comparison)
            or token.is_whitespace
            or token.match(ttype=TokenType.Keyword, values=["AND"])
        ):
            return False
        return True
1332
+
1333
+
1334
class SearchModelVersionUtils(SearchUtils):
    """Filter and sort model versions for model-version search."""

    NUMERIC_ATTRIBUTES = {"version_number", "creation_timestamp", "last_updated_timestamp"}
    VALID_SEARCH_ATTRIBUTE_KEYS = {
        "name",
        "version_number",
        "run_id",
        "source_path",
    }
    VALID_ORDER_BY_ATTRIBUTE_KEYS = {
        "name",
        "version_number",
        "creation_timestamp",
        "last_updated_timestamp",
    }
    VALID_STRING_ATTRIBUTE_COMPARATORS = {"!=", "=", "LIKE", "ILIKE", "IN"}

    @classmethod
    def _does_model_version_match_clauses(cls, mv, sed):
        """Return True if model version ``mv`` satisfies the parsed filter clause ``sed``.

        Args:
            mv: The model version to test.
            sed: A parsed clause dict with ``type``, ``key``, ``value`` and ``comparator``.

        Raises:
            MlflowException: If the clause refers to an unsupported entity type.
        """
        key_type = sed.get("type")
        key = sed.get("key")
        value = sed.get("value")
        comparator = sed.get("comparator").upper()

        if cls.is_string_attribute(key_type, key, comparator):
            # The search key `source_path` maps onto the `source` attribute.
            lhs = getattr(mv, "source" if key == "source_path" else key)
        elif cls.is_numeric_attribute(key_type, key, comparator):
            if key == "version_number":
                key = "version"
            lhs = getattr(mv, key)
            # Keep numeric literals as parsed so float bounds (e.g. `version_number >= 2.5`)
            # are not truncated. Only coerce non-numeric values such as quoted strings.
            if not isinstance(value, (int, float)):
                value = int(value)
        elif cls.is_tag(key_type, comparator):
            lhs = mv.tags.get(key, None)
        else:
            raise MlflowException(
                f"Invalid search expression type '{key_type}'", error_code=INVALID_PARAMETER_VALUE
            )

        # NB: Handling the special `mlflow.prompt.is_prompt` tag. This tag is used for
        # distinguishing between prompt models and normal models. For example, we want to
        # search for models only by the following filter string:
        #
        #   tags.`mlflow.prompt.is_prompt` != 'true'
        #   tags.`mlflow.prompt.is_prompt` = 'false'
        #
        # However, models do not have this tag, so lhs is None in this case. Instead of returning
        # False like normal tag filter, we need to return True here.
        if key == IS_PROMPT_TAG_KEY and lhs is None:
            return (comparator == "=" and value == "false") or (
                comparator == "!=" and value == "true"
            )

        if lhs is None:
            return False

        if comparator == "IN" and isinstance(value, (set, list)):
            return lhs in set(value)

        return SearchUtils.get_comparison_func(comparator)(lhs, value)

    @classmethod
    def filter(cls, model_versions, filter_string):
        """Filters a set of model versions based on a search filter string.

        Versions in the internal deleted stage are always excluded.
        """
        model_versions = [mv for mv in model_versions if mv.current_stage != STAGE_DELETED_INTERNAL]
        if not filter_string:
            return model_versions
        parsed = cls.parse_search_filter(filter_string)

        def model_version_matches(mv):
            return all(cls._does_model_version_match_clauses(mv, s) for s in parsed)

        return [mv for mv in model_versions if model_version_matches(mv)]

    @classmethod
    def parse_order_by_for_search_model_versions(cls, order_by):
        """Parse one ``order_by`` string into ``(type, key, is_ascending)``."""
        token_value, is_ascending = cls._parse_order_by_string(order_by)
        identifier = SearchExperimentsUtils._get_identifier(
            token_value.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS
        )
        return identifier["type"], identifier["key"], is_ascending

    @classmethod
    def _get_sort_key(cls, order_by_list):
        """Build a ``sorted`` key function from ``order_by_list`` (may be empty or None).

        ``name`` ascending and ``version`` descending are appended as tie-breakers unless
        the caller already orders by them.

        Raises:
            MlflowException: If an order-by clause refers to a non-attribute entity.
        """
        order_by = []
        parsed_order_by = map(cls.parse_order_by_for_search_model_versions, order_by_list or [])
        for type_, key, ascending in parsed_order_by:
            if type_ == "attribute":
                # Need to add this mapping because version is a keyword in sql
                if key == "version_number":
                    key = "version"
                order_by.append((key, ascending))
            else:
                raise MlflowException.invalid_parameter_value(f"Invalid order_by entity: {type_}")

        # Add tie-breakers. Keys were already mapped from "version_number" to "version"
        # above, so the check must look for the mapped name.
        if not any(key == "name" for key, _ in order_by):
            order_by.append(("name", True))
        if not any(key == "version" for key, _ in order_by):
            order_by.append(("version", False))

        return lambda model_version: tuple(
            _apply_reversor(model_version, k, asc) for (k, asc) in order_by
        )

    @classmethod
    def sort(cls, model_versions, order_by_list):
        """Return a new list of ``model_versions`` ordered according to ``order_by_list``."""
        return sorted(model_versions, key=cls._get_sort_key(order_by_list))

    @classmethod
    def _get_model_version_search_identifier(cls, identifier, valid_attributes):
        """Split ``identifier`` into ``{"type": ..., "key": ...}`` for model-version search.

        A bare key is an attribute. ``attribute.``, ``tag.`` and ``tags.`` prefixes are
        accepted. Quotes and backticks are stripped from the key before validation.

        Raises:
            MlflowException: If the entity type is unknown, or an attribute key is not in
                ``valid_attributes``.
        """
        tokens = identifier.split(".", maxsplit=1)
        if len(tokens) == 1:
            key = tokens[0]
            identifier = cls._ATTRIBUTE_IDENTIFIER
        else:
            entity_type, key = tokens
            valid_entity_types = ("attribute", "tag", "tags")
            if entity_type not in valid_entity_types:
                raise MlflowException.invalid_parameter_value(
                    f"Invalid entity type '{entity_type}'. "
                    f"Valid entity types are {valid_entity_types}"
                )
            identifier = (
                cls._TAG_IDENTIFIER if entity_type in ("tag", "tags") else cls._ATTRIBUTE_IDENTIFIER
            )

        # Strip quoting before validating, so a quoted attribute such as `run_id` is accepted
        # (consistent with SearchExperimentsUtils._get_identifier).
        key = cls._trim_backticks(cls._strip_quotes(key))
        if identifier == cls._ATTRIBUTE_IDENTIFIER and key not in valid_attributes:
            raise MlflowException.invalid_parameter_value(
                f"Invalid attribute key '{key}' specified. Valid keys are '{valid_attributes}'"
            )

        return {"type": identifier, "key": key}

    @classmethod
    def _get_comparison(cls, comparison):
        """Turn a sqlparse ``Comparison`` into a ``{type, key, comparator, value}`` dict."""
        stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace]
        cls._validate_comparison(stripped_comparison)
        left, comparator, right = stripped_comparison
        comp = cls._get_model_version_search_identifier(left.value, cls.VALID_SEARCH_ATTRIBUTE_KEYS)
        comp["comparator"] = comparator.value.upper()
        comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), right)
        return comp

    @classmethod
    def _get_value(cls, identifier_type, key, token):
        """Extract the right-hand-side value of a comparison from ``token``.

        Tags require a quoted string. Attributes accept a quoted string, a parenthesized
        list (``run_id`` only), or an integer/float literal (numeric attributes only).

        Raises:
            MlflowException: With ``INVALID_PARAMETER_VALUE`` for unsupported value forms or
                identifier types.
        """
        if identifier_type == cls._TAG_IDENTIFIER:
            if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
                return cls._strip_quotes(token.value, expect_quoted_value=True)
            raise MlflowException(
                "Expected a quoted string value for "
                f"{identifier_type} (e.g. 'my-value'). Got value "
                f"{token.value}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        elif identifier_type == cls._ATTRIBUTE_IDENTIFIER:
            if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
                return cls._strip_quotes(token.value, expect_quoted_value=True)
            elif isinstance(token, Parenthesis):
                if key != "run_id":
                    raise MlflowException(
                        "Only the 'run_id' attribute supports comparison with a list of quoted "
                        "string values.",
                        error_code=INVALID_PARAMETER_VALUE,
                    )
                return cls._parse_run_ids(token)
            elif token.ttype in cls.NUMERIC_VALUE_TYPES:
                if key not in cls.NUMERIC_ATTRIBUTES:
                    raise MlflowException(
                        f"Only the '{cls.NUMERIC_ATTRIBUTES}' attributes support comparison with "
                        "numeric values.",
                        error_code=INVALID_PARAMETER_VALUE,
                    )
                if token.ttype == TokenType.Literal.Number.Integer:
                    return int(token.value)
                elif token.ttype == TokenType.Literal.Number.Float:
                    return float(token.value)
            else:
                raise MlflowException(
                    "Expected a quoted string value or a list of quoted string values for "
                    f"attributes. Got value {token.value}",
                    error_code=INVALID_PARAMETER_VALUE,
                )
        else:
            # Expected to be either "param" or "metric".
            raise MlflowException(
                "Invalid identifier type. Expected one of "
                f"{[cls._ATTRIBUTE_IDENTIFIER, cls._TAG_IDENTIFIER]}.",
                error_code=INVALID_PARAMETER_VALUE,
            )

    @classmethod
    def _process_statement(cls, statement):
        """Convert a parsed filter statement into a list of comparison dicts.

        Raises:
            MlflowException: If the statement contains anything other than comparisons
                joined by ``AND``.
        """
        tokens = _join_in_comparison_tokens(statement.tokens)
        invalids = list(filter(cls._invalid_statement_token_search_model_version, tokens))
        if len(invalids) > 0:
            invalid_clauses = ", ".join(map(str, invalids))
            raise MlflowException.invalid_parameter_value(
                f"Invalid clause(s) in filter string: {invalid_clauses}"
            )
        return [cls._get_comparison(t) for t in tokens if isinstance(t, Comparison)]

    @classmethod
    def _invalid_statement_token_search_model_version(cls, token):
        """Return True if ``token`` is not a comparison, whitespace, or the ``AND`` keyword."""
        if (
            isinstance(token, Comparison)
            or token.is_whitespace
            or token.match(ttype=TokenType.Keyword, values=["AND"])
        ):
            return False
        return True

    @classmethod
    def parse_search_filter(cls, filter_string):
        """Parse ``filter_string`` into a list of comparison dicts (``[]`` if empty).

        Raises:
            MlflowException: If the string cannot be parsed, or contains more than one
                SQL statement.
        """
        if not filter_string:
            return []
        try:
            parsed = sqlparse.parse(filter_string)
        except Exception:
            raise MlflowException(
                f"Error on parsing filter '{filter_string}'", error_code=INVALID_PARAMETER_VALUE
            )
        if len(parsed) == 0 or not isinstance(parsed[0], Statement):
            raise MlflowException(
                f"Invalid filter '{filter_string}'. Could not be parsed.",
                error_code=INVALID_PARAMETER_VALUE,
            )
        elif len(parsed) > 1:
            raise MlflowException(
                f"Search filter contained multiple expression {filter_string!r}. "
                "Provide AND-ed expression list.",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return cls._process_statement(parsed[0])
1567
+
1568
+
1569
class SearchTraceUtils(SearchUtils):
    """
    Utility class for searching traces.
    """

    # Attribute keys accepted in trace search filter strings.
    VALID_SEARCH_ATTRIBUTE_KEYS = {
        "request_id",
        "timestamp",
        "timestamp_ms",
        "execution_time",
        "execution_time_ms",
        "status",
        # The following keys are mapped to tags or metadata
        "name",
        "run_id",
    }
    # Attribute keys accepted in ``order_by`` clauses for trace search.
    VALID_ORDER_BY_ATTRIBUTE_KEYS = {
        "experiment_id",
        "timestamp",
        "timestamp_ms",
        "execution_time",
        "execution_time_ms",
        "status",
        "request_id",
        # The following keys are mapped to tags or metadata
        "name",
        "run_id",
    }

    # Attribute keys treated as numeric (timestamps/durations and their aliases).
    # NOTE(review): presumably consumed by the base class's numeric-attribute checks — confirm.
    NUMERIC_ATTRIBUTES = {
        "timestamp_ms",
        "timestamp",
        "execution_time_ms",
        "execution_time",
    }

    # For now, don't support LIKE/ILIKE operators for trace search because it may
    # cause performance issues with large attributes and tags. We can revisit this
    # decision if we find a way to support them efficiently.
    VALID_TAG_COMPARATORS = {"!=", "="}
    VALID_STRING_ATTRIBUTE_COMPARATORS = {"!=", "=", "IN", "NOT IN"}

    # Canonical entity-type identifiers stored under "type" in parsed clauses.
    _REQUEST_METADATA_IDENTIFIER = "request_metadata"
    _TAG_IDENTIFIER = "tag"
    _ATTRIBUTE_IDENTIFIER = "attribute"

    # These are aliases for the base identifiers
    # e.g. trace.status is equivalent to attribute.status
    _ALTERNATE_IDENTIFIERS = {
        "tags": _TAG_IDENTIFIER,
        "attributes": _ATTRIBUTE_IDENTIFIER,
        "trace": _ATTRIBUTE_IDENTIFIER,
        "metadata": _REQUEST_METADATA_IDENTIFIER,
    }
    _IDENTIFIERS = {_TAG_IDENTIFIER, _REQUEST_METADATA_IDENTIFIER, _ATTRIBUTE_IDENTIFIER}
    # Every entity-type prefix accepted in a query: canonical names plus their aliases.
    _VALID_IDENTIFIERS = _IDENTIFIERS | set(_ALTERNATE_IDENTIFIERS.keys())

    # Attribute keys that support list comparisons. NOTE(review): inferred from the name;
    # the consumer is not visible in this part of the file — confirm.
    SUPPORT_IN_COMPARISON_ATTRIBUTE_KEYS = {"name", "status", "request_id", "run_id"}

    # Some search keys are defined differently in the DB models.
    # E.g. "name" is mapped to TraceTagKey.TRACE_NAME
    # (the rewrite is performed by `_replace_key_to_tag_or_metadata`).
    SEARCH_KEY_TO_TAG = {
        "name": TraceTagKey.TRACE_NAME,
    }
    # Search keys rewritten to request-metadata keys by `_replace_key_to_tag_or_metadata`.
    SEARCH_KEY_TO_METADATA = {
        "run_id": TraceMetadataKey.SOURCE_RUN,
    }
    # Alias for attribute keys
    SEARCH_KEY_TO_ATTRIBUTE = {
        "timestamp": "timestamp_ms",
        "execution_time": "execution_time_ms",
    }
1641
+
1642
+ @classmethod
1643
+ def filter(cls, traces, filter_string):
1644
+ """Filters a set of traces based on a search filter string."""
1645
+ if not filter_string:
1646
+ return traces
1647
+ parsed = cls.parse_search_filter_for_search_traces(filter_string)
1648
+
1649
+ def trace_matches(trace):
1650
+ return all(cls._does_trace_match_clause(trace, s) for s in parsed)
1651
+
1652
+ return list(filter(trace_matches, traces))
1653
+
1654
+ @classmethod
1655
+ def _does_trace_match_clause(cls, trace, sed):
1656
+ type_ = sed.get("type")
1657
+ key = sed.get("key")
1658
+ value = sed.get("value")
1659
+ comparator = sed.get("comparator").upper()
1660
+
1661
+ if cls.is_tag(type_, comparator):
1662
+ lhs = trace.tags.get(key)
1663
+ elif cls.is_request_metadata(type_, comparator):
1664
+ lhs = trace.request_metadata.get(key)
1665
+ elif cls.is_attribute(type_, key, comparator):
1666
+ lhs = getattr(trace, key)
1667
+ elif sed.get("type") == cls._TAG_IDENTIFIER:
1668
+ lhs = trace.tags.get(key)
1669
+ else:
1670
+ raise MlflowException(
1671
+ f"Invalid search key '{key}', supported are {cls.VALID_SEARCH_ATTRIBUTE_KEYS}",
1672
+ error_code=INVALID_PARAMETER_VALUE,
1673
+ )
1674
+ if lhs is None:
1675
+ return False
1676
+
1677
+ return SearchUtils.get_comparison_func(comparator)(lhs, value)
1678
+
1679
+ @classmethod
1680
+ def sort(cls, traces, order_by_list):
1681
+ return sorted(traces, key=cls._get_sort_key(order_by_list))
1682
+
1683
+ @classmethod
1684
+ def parse_order_by_for_search_traces(cls, order_by):
1685
+ token_value, is_ascending = cls._parse_order_by_string(order_by)
1686
+ identifier = cls._get_identifier(token_value.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS)
1687
+ identifier = cls._replace_key_to_tag_or_metadata(identifier)
1688
+ return identifier["type"], identifier["key"], is_ascending
1689
+
1690
+ @classmethod
1691
+ def parse_search_filter_for_search_traces(cls, filter_string):
1692
+ parsed = cls.parse_search_filter(filter_string)
1693
+ return [cls._replace_key_to_tag_or_metadata(p) for p in parsed]
1694
+
1695
+ @classmethod
1696
+ def _replace_key_to_tag_or_metadata(cls, parsed: dict[str, Any]):
1697
+ """
1698
+ Replace search key to tag or metadata key if it is in the mapping.
1699
+ """
1700
+ key = parsed.get("key").lower()
1701
+ if key in cls.SEARCH_KEY_TO_TAG:
1702
+ parsed["type"] = cls._TAG_IDENTIFIER
1703
+ parsed["key"] = cls.SEARCH_KEY_TO_TAG[key]
1704
+ elif key in cls.SEARCH_KEY_TO_METADATA:
1705
+ parsed["type"] = cls._REQUEST_METADATA_IDENTIFIER
1706
+ parsed["key"] = cls.SEARCH_KEY_TO_METADATA[key]
1707
+ elif key in cls.SEARCH_KEY_TO_ATTRIBUTE:
1708
+ parsed["key"] = cls.SEARCH_KEY_TO_ATTRIBUTE[key]
1709
+ return parsed
1710
+
1711
+ @classmethod
1712
+ def is_request_metadata(cls, key_type, comparator):
1713
+ if key_type == cls._REQUEST_METADATA_IDENTIFIER:
1714
+ # Request metadata accepts the same set of comparators as tags
1715
+ if comparator not in cls.VALID_TAG_COMPARATORS:
1716
+ raise MlflowException(
1717
+ f"Invalid comparator '{comparator}' not one of '{cls.VALID_TAG_COMPARATORS}'",
1718
+ error_code=INVALID_PARAMETER_VALUE,
1719
+ )
1720
+ return True
1721
+ return False
1722
+
1723
+ @classmethod
1724
+ def _valid_entity_type(cls, entity_type):
1725
+ entity_type = cls._trim_backticks(entity_type)
1726
+ if entity_type not in cls._VALID_IDENTIFIERS:
1727
+ raise MlflowException(
1728
+ f"Invalid entity type '{entity_type}'. Valid values are {cls._VALID_IDENTIFIERS}",
1729
+ error_code=INVALID_PARAMETER_VALUE,
1730
+ )
1731
+ elif entity_type in cls._ALTERNATE_IDENTIFIERS:
1732
+ return cls._ALTERNATE_IDENTIFIERS[entity_type]
1733
+ else:
1734
+ return entity_type
1735
+
1736
+ @classmethod
1737
+ def _get_sort_key(cls, order_by_list):
1738
+ order_by = []
1739
+ parsed_order_by = map(cls.parse_order_by_for_search_traces, order_by_list or [])
1740
+ for type_, key, ascending in parsed_order_by:
1741
+ if type_ == "attribute":
1742
+ order_by.append((key, ascending))
1743
+ else:
1744
+ raise MlflowException.invalid_parameter_value(
1745
+ f"Invalid order_by entity `{type_}` with key `{key}`"
1746
+ )
1747
+
1748
+ # Add a tie-breaker
1749
+ if not any(key == "timestamp_ms" for key, _ in order_by):
1750
+ order_by.append(("timestamp_ms", False))
1751
+ if not any(key == "request_id" for key, _ in order_by):
1752
+ order_by.append(("request_id", True))
1753
+
1754
+ return lambda trace: tuple(_apply_reversor(trace, k, asc) for (k, asc) in order_by)
1755
+
1756
+ @classmethod
1757
+ def _get_value(cls, identifier_type, key, token):
1758
+ if identifier_type == cls._TAG_IDENTIFIER:
1759
+ if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
1760
+ return cls._strip_quotes(token.value, expect_quoted_value=True)
1761
+ elif isinstance(token, Parenthesis):
1762
+ return cls._parse_attribute_lists(token)
1763
+ raise MlflowException(
1764
+ "Expected a quoted string value for "
1765
+ f"{identifier_type} (e.g. 'my-value'). Got value "
1766
+ f"{token.value}",
1767
+ error_code=INVALID_PARAMETER_VALUE,
1768
+ )
1769
+ elif identifier_type == cls._ATTRIBUTE_IDENTIFIER:
1770
+ if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
1771
+ return cls._strip_quotes(token.value, expect_quoted_value=True)
1772
+ elif isinstance(token, Parenthesis):
1773
+ if key not in cls.SUPPORT_IN_COMPARISON_ATTRIBUTE_KEYS:
1774
+ raise MlflowException(
1775
+ f"Only attributes in {cls.SUPPORT_IN_COMPARISON_ATTRIBUTE_KEYS} "
1776
+ "supports comparison with a list of quoted string values.",
1777
+ error_code=INVALID_PARAMETER_VALUE,
1778
+ )
1779
+ return cls._parse_attribute_lists(token)
1780
+ elif token.ttype in cls.NUMERIC_VALUE_TYPES:
1781
+ if key not in cls.NUMERIC_ATTRIBUTES:
1782
+ raise MlflowException(
1783
+ f"Only the '{cls.NUMERIC_ATTRIBUTES}' attributes support comparison with "
1784
+ "numeric values.",
1785
+ error_code=INVALID_PARAMETER_VALUE,
1786
+ )
1787
+ if token.ttype == TokenType.Literal.Number.Integer:
1788
+ return int(token.value)
1789
+ elif token.ttype == TokenType.Literal.Number.Float:
1790
+ return float(token.value)
1791
+ else:
1792
+ raise MlflowException(
1793
+ "Expected a quoted string value or a list of quoted string values for "
1794
+ f"attributes. Got value {token.value}",
1795
+ error_code=INVALID_PARAMETER_VALUE,
1796
+ )
1797
+ elif identifier_type == cls._REQUEST_METADATA_IDENTIFIER:
1798
+ if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
1799
+ return cls._strip_quotes(token.value, expect_quoted_value=True)
1800
+ else:
1801
+ raise MlflowException(
1802
+ "Expected a quoted string value for "
1803
+ f"{identifier_type} (e.g. 'my-value'). Got value "
1804
+ f"{token.value}",
1805
+ error_code=INVALID_PARAMETER_VALUE,
1806
+ )
1807
+ else:
1808
+ # Expected to be either "param" or "metric".
1809
+ raise MlflowException(
1810
+ f"Invalid identifier type: {identifier_type}. "
1811
+ f"Expected one of {cls._VALID_IDENTIFIERS}.",
1812
+ error_code=INVALID_PARAMETER_VALUE,
1813
+ )
1814
+
1815
+ @classmethod
1816
+ def _parse_attribute_lists(cls, token):
1817
+ return cls._parse_list_from_sql_token(token)
1818
+
1819
+ @classmethod
1820
+ def _process_statement(cls, statement):
1821
+ # check validity
1822
+ tokens = _join_in_comparison_tokens(statement.tokens, search_traces=True)
1823
+ invalids = list(filter(cls._invalid_statement_token_search_traces, tokens))
1824
+ if len(invalids) > 0:
1825
+ invalid_clauses = ", ".join(f"'{token}'" for token in invalids)
1826
+ raise MlflowException(
1827
+ f"Invalid clause(s) in filter string: {invalid_clauses}",
1828
+ error_code=INVALID_PARAMETER_VALUE,
1829
+ )
1830
+ return [cls._get_comparison(si) for si in tokens if isinstance(si, Comparison)]
1831
+
1832
+ @classmethod
1833
+ def _invalid_statement_token_search_traces(cls, token):
1834
+ if (
1835
+ isinstance(token, Comparison)
1836
+ or token.is_whitespace
1837
+ or token.match(ttype=TokenType.Keyword, values=["AND"])
1838
+ ):
1839
+ return False
1840
+ return True
1841
+
1842
+ @classmethod
1843
+ def _get_comparison(cls, comparison):
1844
+ stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace]
1845
+ cls._validate_comparison(stripped_comparison, search_traces=True)
1846
+ comp = cls._get_identifier(stripped_comparison[0].value, cls.VALID_SEARCH_ATTRIBUTE_KEYS)
1847
+ comp["comparator"] = stripped_comparison[1].value
1848
+ comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), stripped_comparison[2])
1849
+ return comp
1850
+
1851
+
1852
class SearchLoggedModelsUtils(SearchUtils):
    """In-memory filtering and sorting helpers for ``LoggedModel`` entities."""

    NUMERIC_ATTRIBUTES = {
        "creation_timestamp",
        "creation_time",
        "last_updated_timestamp",
        "last_updated_time",
    }
    VALID_SEARCH_ATTRIBUTE_KEYS = {
        "name",
        "model_id",
        "model_type",
        "status",
        "source_run_id",
    } | NUMERIC_ATTRIBUTES
    VALID_ORDER_BY_ATTRIBUTE_KEYS = VALID_SEARCH_ATTRIBUTE_KEYS
    # User-facing attribute aliases mapped to the attribute actually read from the model.
    # Shared by filtering and ordering so both accept the same names.
    # NOTE(review): assumes LoggedModel exposes the *_timestamp attributes but not
    # *_time ones (ordering already aliased creation_time) — confirm.
    _ATTRIBUTE_ALIASES = {
        "creation_time": "creation_timestamp",
        "last_updated_time": "last_updated_timestamp",
    }

    @classmethod
    def _does_logged_model_match_clause(
        cls,
        model: LoggedModel,
        condition: dict[str, Any],
        datasets: Optional[list[dict[str, Any]]] = None,
    ):
        """
        Evaluate one parsed filter clause against ``model``.

        Args:
            model: The logged model to test.
            condition: Parsed clause with ``type``, ``key``, ``value`` and ``comparator``.
            datasets: Optional dataset filters. When given, metric clauses only consider
                metrics recorded on one of these datasets.

        Returns:
            False if the model has no value for the key, otherwise the result of the
            comparison.

        Raises:
            MlflowException: If the key is not a metric, param, tag, or model attribute.
        """
        key_type = condition.get("type")
        key = condition.get("key")
        value = condition.get("value")
        comparator = condition.get("comparator").upper()

        key = SearchUtils.translate_key_alias(key)

        if cls.is_metric(key_type, comparator):
            matching_metrics = [metric for metric in model.metrics if metric.key == key]
            if datasets:
                matching_metrics = [
                    metric
                    for metric in matching_metrics
                    if any(cls._is_metric_on_dataset(metric, dataset) for dataset in datasets)
                ]
            # NOTE(review): filtering uses the first matching metric entry, while sorting
            # uses the latest by timestamp — confirm this difference is intended.
            lhs = matching_metrics[0].value if matching_metrics else None
            value = float(value)
        elif cls.is_param(key_type, comparator):
            lhs = model.params.get(key, None)
        elif cls.is_tag(key_type, comparator):
            lhs = model.tags.get(key, None)
        elif cls.is_numeric_attribute(key_type, key, comparator):
            # Resolve aliases such as ``creation_time`` to the stored attribute name.
            lhs = getattr(model, cls._ATTRIBUTE_ALIASES.get(key, key))
            value = int(value)
        elif hasattr(model, key):
            lhs = getattr(model, key)
        else:
            raise MlflowException.invalid_parameter_value(
                f"Invalid logged model search key '{key}'",
            )
        if lhs is None:
            return False

        return SearchUtils.get_comparison_func(comparator)(lhs, value)

    @classmethod
    def validate_list_supported(cls, key: str) -> None:
        """
        Override to allow logged model attributes to be used with IN/NOT IN.
        """

    @classmethod
    def filter_logged_models(
        cls,
        models: list[LoggedModel],
        filter_string: Optional[str] = None,
        datasets: Optional[list[dict[str, Any]]] = None,
    ):
        """
        Filter logged models by a search filter string and a list of dataset filters.

        When ``datasets`` is given but ``filter_string`` contains no metric clause, only
        models that have at least one metric on one of the datasets are kept.

        Returns:
            ``models`` unchanged if neither filter is given, otherwise a new list.
        """
        if not filter_string and not datasets:
            return models

        parsed = cls.parse_search_filter(filter_string)

        # If there are dataset filters but no metric filters in the filter string,
        # filter for models that have any metrics on the datasets
        if datasets and not any(
            cls.is_metric(s.get("type"), s.get("comparator").upper()) for s in parsed
        ):

            def model_has_metrics_on_datasets(model):
                return any(
                    any(cls._is_metric_on_dataset(metric, dataset) for dataset in datasets)
                    for metric in model.metrics
                )

            models = [model for model in models if model_has_metrics_on_datasets(model)]

        def model_matches(model):
            return all(cls._does_logged_model_match_clause(model, s, datasets) for s in parsed)

        return [model for model in models if model_matches(model)]

    @dataclass
    class OrderBy:
        # Attribute name, or ``metrics.<name>`` for metric ordering.
        field_name: str
        ascending: bool = True
        # Optional dataset restriction applied when ordering by a metric.
        dataset_name: Optional[str] = None
        dataset_digest: Optional[str] = None

    @classmethod
    def parse_order_by_for_logged_models(cls, order_by: dict[str, Any]) -> OrderBy:
        """
        Validate one ``order_by`` dict and convert it to an ``OrderBy``.

        Attribute aliases (e.g. ``creation_time``) are resolved to the stored name.

        Raises:
            MlflowException: If the entry is not a dict, lacks ``field_name``, names an
                invalid field, has a non-boolean ``ascending``, or has
                ``dataset_digest`` without ``dataset_name``.
        """
        if not isinstance(order_by, dict):
            raise MlflowException.invalid_parameter_value(
                "`order_by` must be a list of dictionaries."
            )
        field_name = order_by.get("field_name")
        if field_name is None:
            raise MlflowException.invalid_parameter_value(
                "`field_name` in the `order_by` clause must be specified."
            )
        if "." in field_name:
            entity = field_name.split(".", 1)[0]
            if entity != "metrics":
                raise MlflowException.invalid_parameter_value(
                    f"Invalid order by field name: {entity}, only `metrics.<name>` is allowed."
                )
        else:
            field_name = field_name.strip()
            if field_name not in cls.VALID_ORDER_BY_ATTRIBUTE_KEYS:
                raise MlflowException.invalid_parameter_value(
                    f"Invalid order by field name: {field_name}."
                )
        ascending = order_by.get("ascending", True)
        if ascending not in [True, False]:
            raise MlflowException.invalid_parameter_value(
                "Value of `ascending` in the `order_by` clause must be a boolean, got "
                f"{type(ascending)} for field {field_name}."
            )
        dataset_name = order_by.get("dataset_name")
        dataset_digest = order_by.get("dataset_digest")
        if dataset_digest and not dataset_name:
            raise MlflowException.invalid_parameter_value(
                "`dataset_digest` can only be specified if `dataset_name` is also specified."
            )

        return cls.OrderBy(
            cls._ATTRIBUTE_ALIASES.get(field_name, field_name),
            ascending,
            dataset_name,
            dataset_digest,
        )

    @classmethod
    def _apply_reversor_for_logged_model(
        cls,
        model: LoggedModel,
        order_by: OrderBy,
    ):
        """
        Compute the sort-key component of ``model`` for a single ``order_by``.

        For ``metrics.<name>``, the value of the latest (by timestamp) matching metric
        is used, optionally restricted to the requested dataset. Missing metrics sort
        last in ascending order via ``_LoggedModelMetricComp``.
        """
        if "." in order_by.field_name:
            metric_key = order_by.field_name.split(".", 1)[1]
            filtered_metrics = sorted(
                [
                    m
                    for m in model.metrics
                    if m.key == metric_key
                    and (not order_by.dataset_name or m.dataset_name == order_by.dataset_name)
                    and (not order_by.dataset_digest or m.dataset_digest == order_by.dataset_digest)
                ],
                key=lambda metric: metric.timestamp,
                reverse=True,
            )
            latest_metric_value = None if len(filtered_metrics) == 0 else filtered_metrics[0].value
            return (
                _LoggedModelMetricComp(latest_metric_value)
                if order_by.ascending
                else _Reversor(latest_metric_value)
            )
        else:
            value = getattr(model, order_by.field_name)
            return value if order_by.ascending else _Reversor(value)

    @classmethod
    def _get_sort_key(cls, order_by_list: Optional[list[dict[str, Any]]]):
        """
        Build a ``sorted`` key function from ``order_by`` dicts.

        Unless already requested, ties are broken by ``creation_timestamp`` (descending)
        and then ``model_id`` (ascending).
        """
        parsed_order_by = list(map(cls.parse_order_by_for_logged_models, order_by_list or []))

        # Add a tie-breaker
        if not any(order_by.field_name == "creation_timestamp" for order_by in parsed_order_by):
            parsed_order_by.append(cls.OrderBy("creation_timestamp", False))
        if not any(order_by.field_name == "model_id" for order_by in parsed_order_by):
            parsed_order_by.append(cls.OrderBy("model_id"))

        return lambda logged_model: tuple(
            cls._apply_reversor_for_logged_model(logged_model, order_by)
            for order_by in parsed_order_by
        )

    @classmethod
    def sort(cls, models, order_by_list):
        """Return a new list of ``models`` ordered according to ``order_by_list``."""
        return sorted(models, key=cls._get_sort_key(order_by_list))
2045
+
2046
+
2047
+ class _LoggedModelMetricComp:
2048
+ def __init__(self, obj):
2049
+ self.obj = obj
2050
+
2051
+ def __eq__(self, other):
2052
+ return other.obj == self.obj
2053
+
2054
+ def __lt__(self, other):
2055
+ if self.obj is None:
2056
+ return False
2057
+ if other.obj is None:
2058
+ return True
2059
+ return self.obj < other.obj
2060
+
2061
+
2062
+ @dataclass
2063
+ class SearchLoggedModelsPaginationToken:
2064
+ experiment_ids: list[str]
2065
+ filter_string: Optional[str] = None
2066
+ order_by: Optional[list[dict[str, Any]]] = None
2067
+ offset: int = 0
2068
+
2069
+ def to_json(self) -> str:
2070
+ return json.dumps(asdict(self))
2071
+
2072
+ def encode(self) -> str:
2073
+ return base64.b64encode(self.to_json().encode("utf-8")).decode("utf-8")
2074
+
2075
+ @classmethod
2076
+ def decode(cls, token: str) -> "SearchLoggedModelsPaginationToken":
2077
+ try:
2078
+ token = json.loads(base64.b64decode(token.encode("utf-8")).decode("utf-8"))
2079
+ except json.JSONDecodeError as e:
2080
+ raise MlflowException.invalid_parameter_value(f"Invalid page token: {token}. {e}")
2081
+
2082
+ return cls(
2083
+ experiment_ids=token.get("experiment_ids"),
2084
+ filter_string=token.get("filter_string") or None,
2085
+ order_by=token.get("order_by") or None,
2086
+ offset=token.get("offset") or 0,
2087
+ )
2088
+
2089
+ def validate(
2090
+ self,
2091
+ experiment_ids: list[str],
2092
+ filter_string: Optional[str],
2093
+ order_by: Optional[list[dict[str, Any]]],
2094
+ ) -> None:
2095
+ if self.experiment_ids != experiment_ids:
2096
+ raise MlflowException.invalid_parameter_value(
2097
+ f"Experiment IDs in the page token do not match the requested experiment IDs. "
2098
+ f"Expected: {experiment_ids}. Found: {self.experiment_ids}"
2099
+ )
2100
+
2101
+ if self.filter_string != filter_string:
2102
+ raise MlflowException.invalid_parameter_value(
2103
+ f"Filter string in the page token does not match the requested filter string. "
2104
+ f"Expected: {filter_string}. Found: {self.filter_string}"
2105
+ )
2106
+
2107
+ if self.order_by != order_by:
2108
+ raise MlflowException.invalid_parameter_value(
2109
+ f"Order by in the page token does not match the requested order by. "
2110
+ f"Expected: {order_by}. Found: {self.order_by}"
2111
+ )