genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
mlflow/types/llm.py ADDED
@@ -0,0 +1,935 @@
1
+ import time
2
+ import uuid
3
+ from dataclasses import asdict, dataclass, field, fields
4
+ from typing import Any, Literal, Optional
5
+
6
+ from mlflow.types.schema import AnyType, Array, ColSpec, DataType, Map, Object, Property, Schema
7
+
8
+ # TODO: Switch to pydantic in a future version of MLflow.
9
+ # For now, to prevent adding pydantic as a core dependency,
10
+ # we use dataclasses instead.
11
+ #
12
+ # Unfortunately, validation for generic types is not that
13
+ # straightforward. For example, `isinstance(thing, List[T])`
14
+ # is not supported, so the code here is a little ugly.
15
+
16
+
17
# Primitive type names defined by the JSON Schema specification. Used to
# validate the ``type`` field of ``ParamType``.
JSON_SCHEMA_TYPES = "string number integer object array boolean null".split()
18
+
19
+
20
class _BaseDataclass:
    """
    Shared validation and conversion helpers for the dataclasses in this module.

    Subclasses are ``@dataclass``-decorated and call the ``_validate_*`` /
    ``_convert_*`` helpers from their ``__post_init__``. Every validation
    failure is reported as a ``ValueError``.
    """

    @staticmethod
    def _is_instance(value, val_type):
        """
        ``isinstance`` with JSON-friendly numeric widening.

        An ``int`` (but not a ``bool``) is accepted where a ``float`` is expected,
        because JSON payloads do not distinguish ``1`` from ``1.0``.
        """
        if val_type is float and isinstance(value, int) and not isinstance(value, bool):
            return True
        return isinstance(value, val_type)

    def _validate_field(self, key, val_type, required):
        """
        Check that attribute ``key`` holds a ``val_type`` value.

        Raises:
            ValueError: If the value is ``None`` while ``required`` is True, or if
                it is not ``None`` and has the wrong type.
        """
        value = getattr(self, key, None)
        if required and value is None:
            raise ValueError(f"`{key}` is required")
        if value is not None and not self._is_instance(value, val_type):
            raise ValueError(
                f"`{key}` must be of type {val_type.__name__}, got {type(value).__name__}"
            )

    def _validate_literal(self, key, allowed_values, required):
        """
        Check that attribute ``key`` is one of ``allowed_values``.

        Raises:
            ValueError: If a required value is missing or the value is not allowed.
        """
        value = getattr(self, key, None)
        if required and value is None:
            raise ValueError(f"`{key}` is required")
        if value is not None and value not in allowed_values:
            raise ValueError(f"`{key}` must be one of {allowed_values}, got {value}")

    def _validate_list(self, key, val_type, required):
        """
        Check that attribute ``key`` is a list whose items are all ``val_type``.

        Raises:
            ValueError: If a required value is missing, the value is not a list,
                or any item has the wrong type.
        """
        values = getattr(self, key, None)
        if required and values is None:
            raise ValueError(f"`{key}` is required")
        if values is None:
            return
        if not isinstance(values, list):
            raise ValueError(f"`{key}` must be a list, got {type(values).__name__}")
        if not all(self._is_instance(v, val_type) for v in values):
            raise ValueError(f"All items in `{key}` must be of type {val_type.__name__}")

    def _convert_dataclass(self, key: str, cls: "_BaseDataclass", required=True):
        """
        Coerce attribute ``key`` into an instance of ``cls``.

        An existing ``cls`` instance is kept as-is. A dict is converted with
        ``cls.from_dict`` and stored back on ``self``.

        Raises:
            ValueError: If a required value is missing, the value is neither a
                ``cls`` instance nor a dict, or the dict cannot build ``cls``.
        """
        value = getattr(self, key)
        if value is None:
            if required:
                raise ValueError(f"`{key}` is required")
            return

        if isinstance(value, cls):
            return

        if not isinstance(value, dict):
            raise ValueError(
                f"Expected `{key}` to be either an instance of `{cls.__name__}` or "
                f"a dict matching the schema. Received `{type(value).__name__}`"
            )

        try:
            setattr(self, key, cls.from_dict(value))
        except TypeError as e:
            # TypeError comes from cls(**data), e.g. a missing required field.
            raise ValueError(f"Error when coercing {value} to {cls.__name__}: {e}") from e

    def _convert_dataclass_list(self, key: str, cls: "_BaseDataclass", required=True):
        """
        Coerce attribute ``key`` into a list of ``cls`` instances.

        If every item is a dict, each one is converted with ``cls.from_dict``.
        Otherwise every item must already be a ``cls`` instance; mixing dicts
        and instances is rejected.

        Raises:
            ValueError: If a required value is missing, the value is not a list,
                the items have mixed or unexpected types, or conversion fails.
        """
        values = getattr(self, key)
        if values is None:
            if required:
                raise ValueError(f"`{key}` is required")
            return
        if not isinstance(values, list):
            raise ValueError(f"`{key}` must be a list")

        if not values:
            return
        if all(isinstance(v, dict) for v in values):
            try:
                setattr(self, key, [cls.from_dict(v) for v in values])
            except TypeError as e:
                raise ValueError(f"Error when coercing {values} to {cls.__name__}: {e}") from e
        elif any(not isinstance(v, cls) for v in values):
            raise ValueError(
                f"Items in `{key}` must all have the same type: {cls.__name__} or dict"
            )

    def _convert_dataclass_map(self, key, cls, required=True):
        """
        Coerce the values of dict attribute ``key`` into ``cls`` instances.

        A new dict is built and assigned, so the caller's mapping is not mutated.

        Raises:
            ValueError: If a required value is missing, the value is not a dict,
                or any value is neither a ``cls`` instance nor a convertible dict.
        """
        mapping = getattr(self, key)
        if mapping is None:
            if required:
                raise ValueError(f"`{key}` is required")
            return

        if not isinstance(mapping, dict):
            raise ValueError(f"`{key}` must be a dict")

        new_mapping = {}
        for k, v in mapping.items():
            if isinstance(v, cls):
                new_mapping[k] = v
            elif isinstance(v, dict):
                try:
                    new_mapping[k] = cls.from_dict(v)
                except TypeError as e:
                    raise ValueError(f"Error when coercing {v} to {cls.__name__}: {e}") from e
            else:
                raise ValueError(
                    f"Items in `{key}` must be either an instance of `{cls.__name__}` "
                    f"or a dict matching the schema. Received `{type(v).__name__}`"
                )
        setattr(self, key, new_mapping)

    def to_dict(self):
        """
        Recursively convert to a dict, omitting dataclass fields whose value is ``None``.
        """
        return asdict(self, dict_factory=lambda obj: {k: v for (k, v) in obj if v is not None})

    @classmethod
    def from_dict(cls, data):
        """
        Create an instance of the class from a dict, ignoring any undefined fields.
        This is useful when the dict contains extra fields, causing cls(**data) to fail.
        """
        field_names = {f.name for f in fields(cls)}
        filtered_data = {k: v for k, v in data.items() if k in field_names}
        return cls(**filtered_data)
129
+
130
+
131
@dataclass
class FunctionToolCallArguments(_BaseDataclass):
    """
    Arguments of a function-type tool call emitted by the model.

    Args:
        name (str): Name of the tool being invoked.
        arguments (str): JSON-encoded string with the arguments to pass to the tool.
    """

    name: str
    arguments: str

    def __post_init__(self):
        for required_key in ("name", "arguments"):
            self._validate_field(required_key, str, True)

    def to_tool_call(self, id=None):
        """
        Wrap these arguments in a :py:class:`ToolCall`.

        Args:
            id (str): ID to assign to the tool call. When ``None``, a random
                UUID4 string is generated.

        Returns:
            A :py:class:`ToolCall` whose ``function`` is this object.
        """
        call_id = str(uuid.uuid4()) if id is None else id
        return ToolCall(id=call_id, function=self)
152
+
153
+
154
@dataclass
class ToolCall(_BaseDataclass):
    """
    A single tool invocation requested by the model.

    Args:
        function (:py:class:`FunctionToolCallArguments`): Name and arguments of the
            function to call. A dict is coerced into ``FunctionToolCallArguments``.
        id (str): Identifier of the call. Defaults to a freshly generated UUID4 string.
        type (str): Kind of tool call. Defaults to ``"function"``.
    """

    function: FunctionToolCallArguments
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    type: str = "function"

    def __post_init__(self):
        validate = self._validate_field
        validate("id", str, required=True)
        self._convert_dataclass("function", FunctionToolCallArguments, required=True)
        validate("type", str, required=True)
173
+
174
+
175
@dataclass
class ChatMessage(_BaseDataclass):
    """
    A message in a chat request or response.

    Args:
        role (str): The role of the entity that sent the message (e.g. ``"user"``,
            ``"system"``, ``"assistant"``, ``"tool"``).
        content (str): The content of the message.
            **Optional** Can be ``None`` if refusal or tool_calls are provided.
        refusal (str): The refusal message content.
            **Optional** Supplied if a refusal response is provided.
        name (str): The name of the entity that sent the message. **Optional**.
        tool_calls (List[:py:class:`ToolCall`]): A list of tool calls made by the model.
            **Optional** defaults to ``None``
        tool_call_id (str): The ID of the tool call that this message is a response to.
            **Optional** defaults to ``None``

    Raises:
        ValueError: If a field has the wrong type, if both ``content`` and
            ``refusal`` are set, or if ``content`` is missing when neither
            ``refusal`` nor ``tool_calls`` is provided.
    """

    role: str
    content: Optional[str] = None
    refusal: Optional[str] = None
    name: Optional[str] = None
    tool_calls: Optional[list[ToolCall]] = None
    tool_call_id: Optional[str] = None

    def __post_init__(self):
        self._validate_field("role", str, True)
        # Type-check `refusal` unconditionally. Previously only truthy values were
        # checked, so falsy non-strings such as 0 or [] were stored unvalidated.
        self._validate_field("refusal", str, False)

        if self.refusal:
            if self.content:
                raise ValueError("Both `content` and `refusal` cannot be set")
            # content may be empty/None here, but must still be a str if present
            self._validate_field("content", str, False)
        elif self.tool_calls:
            self._validate_field("content", str, False)
        else:
            self._validate_field("content", str, True)

        self._validate_field("name", str, False)
        self._convert_dataclass_list("tool_calls", ToolCall, False)
        self._validate_field("tool_call_id", str, False)
216
+
217
+
218
@dataclass
class ChatChoiceDelta(_BaseDataclass):
    """
    A streaming message delta in a chat response.

    Args:
        role (str): The role of the entity that sent the message (e.g. ``"user"``,
            ``"system"``, ``"assistant"``, ``"tool"``).
            **Optional** defaults to ``"assistant"``
            This is optional because OpenAI clients can explicitly return None for
            the role
        content (str): The content of the new token being streamed
            **Optional** Can be ``None`` on the last delta chunk or if refusal or
            tool_calls are provided
        refusal (str): The refusal message content.
            **Optional** Supplied if a refusal response is provided.
        name (str): The name of the entity that sent the message. **Optional**.
        tool_calls (List[:py:class:`ToolCall`]): A list of tool calls made by the model.
            **Optional** defaults to ``None``

    Raises:
        ValueError: If a field has the wrong type or both ``content`` and
            ``refusal`` are set.
    """

    role: Optional[str] = "assistant"
    content: Optional[str] = None
    refusal: Optional[str] = None
    name: Optional[str] = None
    tool_calls: Optional[list[ToolCall]] = None

    def __post_init__(self):
        self._validate_field("role", str, False)
        # Type-check `refusal` even when it is falsy (previously only truthy
        # values were checked).
        self._validate_field("refusal", str, False)
        if self.refusal and self.content:
            raise ValueError("Both `content` and `refusal` cannot be set")
        self._validate_field("content", str, False)
        self._validate_field("name", str, False)
        self._convert_dataclass_list("tool_calls", ToolCall, False)
255
+
256
+
257
@dataclass
class ParamType(_BaseDataclass):
    """
    A bare JSON Schema type descriptor.

    Args:
        type (str): One of the names in ``JSON_SCHEMA_TYPES`` ("string", "number",
            "integer", "object", "array", "boolean", "null").
    """

    type: Literal["string", "number", "integer", "object", "array", "boolean", "null"]

    def __post_init__(self):
        self._validate_literal("type", JSON_SCHEMA_TYPES, required=True)
263
+
264
+
265
@dataclass
class ParamProperty(ParamType):
    """
    Definition of one parameter of a function tool.

    Args:
        type (str): JSON Schema type of the parameter: "string", "number", "integer",
            "object", "array", "boolean", or "null".
        description (str): Human-readable explanation of the parameter.
            **Optional**, defaults to ``None``
        enum (List[str]): Restricts the parameter to these values.
            **Optional**, defaults to ``None``
        items (:py:class:`ParamType`): Element type when the parameter is an ``array``.
            A dict is coerced to :py:class:`ParamType`, which keeps only its ``type``
            key. An existing ``ParamType`` instance (including a ``ParamProperty``) is
            kept unchanged. **Optional**, defaults to ``None``
    """

    description: Optional[str] = None
    enum: Optional[list[str]] = None
    items: Optional[ParamType] = None

    def __post_init__(self):
        self._validate_field("description", str, required=False)
        self._validate_list("enum", str, required=False)
        self._convert_dataclass("items", ParamType, required=False)
        # The inherited check validates `type` last.
        super().__post_init__()
290
+
291
+
292
@dataclass
class ToolParamsSchema(_BaseDataclass):
    """
    JSON-Schema-style description of a tool's parameters object.

    Args:
        properties (Dict[str, :py:class:`ParamProperty`]): Parameter name to its
            definition. Dict values are coerced into ``ParamProperty``.
        type (str): Schema type of the parameters object. Only "object" is allowed.
        required (List[str]): Names of mandatory parameters. **Optional**, defaults to ``None``
        additionalProperties (bool): Whether parameters not listed in ``properties``
            are permitted. **Optional**, defaults to ``None``
    """

    properties: dict[str, ParamProperty]
    type: Literal["object"] = "object"
    required: Optional[list[str]] = None
    additionalProperties: Optional[bool] = None

    def __post_init__(self):
        self._convert_dataclass_map("properties", ParamProperty, required=True)
        self._validate_literal("type", ["object"], required=True)
        self._validate_list("required", str, required=False)
        self._validate_field("additionalProperties", bool, required=False)
316
+
317
+
318
@dataclass
class FunctionToolDefinition(_BaseDataclass):
    """
    Declaration of a function the model may call (the only supported tool kind).

    Args:
        name (str): Tool name.
        description (str): What the tool does and when to use it.
            **Optional**, defaults to ``None``
        parameters (:py:class:`ToolParamsSchema`): Schema of the tool's arguments.
            Omit it for a function without parameters. A dict is coerced into
            ``ToolParamsSchema``. **Optional**, defaults to ``None``
        strict (bool): Whether the model should follow the schema strictly.
            Defaults to ``False``.
    """

    name: str
    description: Optional[str] = None
    parameters: Optional[ToolParamsSchema] = None
    strict: bool = False

    def __post_init__(self):
        string_checks = (("name", True), ("description", False))
        for key, required in string_checks:
            self._validate_field(key, str, required)
        self._convert_dataclass("parameters", ToolParamsSchema, required=False)
        self._validate_field("strict", bool, required=True)

    def to_tool_definition(self):
        """
        Return this definition wrapped in a ``"function"`` :py:class:`ToolDefinition`.
        """
        return ToolDefinition(function=self, type="function")
350
+
351
+
352
@dataclass
class ToolDefinition(_BaseDataclass):
    """
    A tool the model is allowed to call.

    Args:
        function (:py:class:`FunctionToolDefinition`): The function being exposed.
            A dict is coerced into ``FunctionToolDefinition``.
        type (str): Tool kind. Only "function" is accepted.
    """

    function: FunctionToolDefinition
    type: Literal["function"] = "function"

    def __post_init__(self):
        # `type` is checked before `function`.
        self._validate_literal("type", ["function"], required=True)
        self._convert_dataclass("function", FunctionToolDefinition, required=True)
368
+
369
+
370
@dataclass
class ChatParams(_BaseDataclass):
    """
    Common parameters used for chat inference

    Args:
        temperature (float): A param used to control randomness and creativity during inference.
            **Optional**, defaults to ``1.0``
        max_tokens (int): The maximum number of new tokens to generate.
            **Optional**, defaults to ``None`` (unlimited)
        stop (List[str]): A list of tokens at which to stop generation.
            **Optional**, defaults to ``None``
        n (int): The number of responses to generate.
            **Optional**, defaults to ``1``
        stream (bool): Whether to stream back responses as they are generated.
            **Optional**, defaults to ``False``
        top_p (float): An optional param to control sampling with temperature, the model considers
            the results of the tokens with top_p probability mass. E.g., 0.1 means only the tokens
            comprising the top 10% probability mass are considered.
        top_k (int): An optional param for reducing the vocabulary size to top k tokens
            (sorted in descending order by their probabilities).
        frequency_penalty (float): An optional param of positive or negative value,
            positive values penalize new tokens based on
            their existing frequency in the text so far, decreasing the model's likelihood to repeat
            the same line verbatim.
        presence_penalty (float): An optional param of positive or negative value,
            positive values penalize new tokens based on whether they appear in the text so far,
            increasing the model's likelihood to talk about new topics.
        custom_inputs (Dict[str, Any]): An optional param to provide arbitrary additional context
            to the model. Keys must be strings; the dictionary values must be JSON-serializable.
        tools (List[:py:class:`ToolDefinition`]): An optional list of tools that can be called by
            the model.

    Raises:
        ValueError: If any field has the wrong type, or ``custom_inputs`` is not a
            dict with string keys.

    .. warning::

        In an upcoming MLflow release, default values for `temperature`, `n` and `stream` will be
        removed. Please provide these values explicitly in your code if needed.
    """

    temperature: float = 1.0
    max_tokens: Optional[int] = None
    stop: Optional[list[str]] = None
    n: int = 1
    stream: bool = False

    top_p: Optional[float] = None
    top_k: Optional[int] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None

    custom_inputs: Optional[dict[str, Any]] = None
    tools: Optional[list[ToolDefinition]] = None

    def __post_init__(self):
        self._validate_field("temperature", float, True)
        self._validate_field("max_tokens", int, False)
        self._validate_list("stop", str, False)
        self._validate_field("n", int, True)
        self._validate_field("stream", bool, True)

        self._validate_field("top_p", float, False)
        self._validate_field("top_k", int, False)
        self._validate_field("frequency_penalty", float, False)
        self._validate_field("presence_penalty", float, False)
        self._convert_dataclass_list("tools", ToolDefinition, False)

        self._validate_custom_inputs()

    def _validate_custom_inputs(self):
        """
        Check that ``custom_inputs``, when given, is a dict with only ``str`` keys.

        Values are not inspected; any value type is accepted here.
        """
        if self.custom_inputs is None:
            return
        if not isinstance(self.custom_inputs, dict):
            raise ValueError(
                "Expected `custom_inputs` to be a dictionary, "
                f"received `{type(self.custom_inputs).__name__}`"
            )
        for key in self.custom_inputs:
            if not isinstance(key, str):
                raise ValueError(
                    "Expected `custom_inputs` to be of type `Dict[str, Any]`, "
                    f"received key of type `{type(key).__name__}` (key: {key})"
                )

    @classmethod
    def keys(cls) -> set[str]:
        """
        Return the keys of the dataclass
        """
        return {f.name for f in fields(cls)}
456
+
457
+
458
@dataclass
class ChatCompletionRequest(ChatParams):
    """
    Request payload accepted by the chat endpoint: a list of messages plus the
    inference parameters inherited from :py:class:`ChatParams`.

    Args:
        messages (List[:py:class:`ChatMessage`]): Conversation passed to the model.
            Dicts are coerced into ``ChatMessage``. **Optional**, defaults to ``[]``
        temperature (float): Controls randomness and creativity during inference.
            **Optional**, defaults to ``1.0``
        max_tokens (int): Upper bound on newly generated tokens.
            **Optional**, defaults to ``None`` (unlimited)
        stop (List[str]): Tokens at which generation stops.
            **Optional**, defaults to ``None``
        n (int): Number of responses to generate. **Optional**, defaults to ``1``
        stream (bool): Whether responses are streamed as they are generated.
            **Optional**, defaults to ``False``
        top_p (float): Nucleus-sampling threshold: only tokens within the top_p
            probability mass are considered (e.g. 0.1 means the top 10%).
        top_k (int): Restricts sampling to the k most probable tokens.
        frequency_penalty (float): Positive values penalize tokens by how often they
            already appear, reducing verbatim repetition.
        presence_penalty (float): Positive values penalize tokens that already
            appeared at all, encouraging new topics.
        custom_inputs (Dict[str, Any]): Arbitrary extra context for the model; values
            must be JSON-serializable.
        tools (List[:py:class:`ToolDefinition`]): Tools the model may call.

    .. warning::

        In an upcoming MLflow release, default values for `temperature`, `n` and `stream` will be
        removed. Please provide these values explicitly in your code if needed.
    """

    messages: list[ChatMessage] = field(default_factory=list)

    def __post_init__(self):
        # Coerce message dicts first, then run the inherited ChatParams checks.
        self._convert_dataclass_list("messages", ChatMessage, required=True)
        super().__post_init__()
504
+
505
+
506
@dataclass
class TopTokenLogProb(_BaseDataclass):
    """
    A candidate token together with its log probability.

    Args:
        token: The token text.
        logprob: Log probability of the token. When the token is outside the 20
            most likely candidates, ``-9999.0`` is used to signal it is very unlikely.
        bytes: UTF-8 byte values of the token, or ``None`` if it has no byte
            representation. Needed when one character spans several tokens and
            their bytes must be joined to recover the text.
    """

    token: str
    logprob: float
    bytes: Optional[list[int]] = None

    def __post_init__(self):
        for key, expected in (("token", str), ("logprob", float)):
            self._validate_field(key, expected, True)
        self._validate_list("bytes", int, required=False)
531
+
532
+
533
@dataclass
class TokenLogProb(_BaseDataclass):
    """
    A generated content token with its log probability and the top alternatives.

    Args:
        token: The token text.
        logprob: Log probability of the token. When the token is outside the 20
            most likely candidates, ``-9999.0`` is used to signal it is very unlikely.
        bytes: UTF-8 byte values of the token, or ``None`` if it has no byte
            representation. Needed when one character spans several tokens and
            their bytes must be joined to recover the text.
        top_logprobs: The most likely tokens at this position with their log
            probabilities. Dicts are coerced into :py:class:`TopTokenLogProb`.
            Occasionally fewer entries than requested are returned.
    """

    token: str
    logprob: float
    top_logprobs: list[TopTokenLogProb]
    bytes: Optional[list[int]] = None

    def __post_init__(self):
        for key, expected in (("token", str), ("logprob", float)):
            self._validate_field(key, expected, True)
        self._convert_dataclass_list("top_logprobs", TopTokenLogProb, required=True)
        self._validate_list("bytes", int, required=False)
563
+
564
+
565
@dataclass
class ChatChoiceLogProbs(_BaseDataclass):
    """
    Per-token log probability information attached to a choice.

    Args:
        content: Content tokens with their log probabilities. Dicts are coerced
            into :py:class:`TokenLogProb`. **Optional**, defaults to ``None``
    """

    content: Optional[list[TokenLogProb]] = None

    def __post_init__(self):
        self._convert_dataclass_list("content", TokenLogProb, required=False)
578
+
579
+
580
@dataclass
class ChatChoice(_BaseDataclass):
    """
    One complete chat response produced by the model.
    ref: https://platform.openai.com/docs/api-reference/chat/object

    Args:
        message (:py:class:`ChatMessage`): The generated message; a dict is coerced
            into ``ChatMessage``.
        index (int): Position of this choice among the returned choices.
            Defaults to ``0``
        finish_reason (str): Why generation ended. **Optional**, defaults to ``"stop"``
        logprobs (:py:class:`ChatChoiceLogProbs`): Log probability details for the
            choice. **Optional**, defaults to ``None``
    """

    message: ChatMessage
    index: int = 0
    finish_reason: str = "stop"
    logprobs: Optional[ChatChoiceLogProbs] = None

    def __post_init__(self):
        for key, expected in (("index", int), ("finish_reason", str)):
            self._validate_field(key, expected, True)
        self._convert_dataclass("message", ChatMessage, required=True)
        self._convert_dataclass("logprobs", ChatChoiceLogProbs, required=False)
606
+
607
+
608
@dataclass
class ChatChunkChoice(_BaseDataclass):
    """
    One streamed chunk of a chat response.
    ref: https://platform.openai.com/docs/api-reference/chat/streaming

    Args:
        index (int): Position of this choice among the returned choices.
            defaults to ``0``
        delta (:py:class:`ChatChoiceDelta`): The incremental message content; a dict
            is coerced into ``ChatChoiceDelta``.
        finish_reason (str): Why generation ended. **Optional**, defaults to ``None``
        logprobs (:py:class:`ChatChoiceLogProbs`): Log probability details for the
            choice. **Optional**, defaults to ``None``
    """

    delta: ChatChoiceDelta
    index: int = 0
    finish_reason: Optional[str] = None
    logprobs: Optional[ChatChoiceLogProbs] = None

    def __post_init__(self):
        self._validate_field("index", int, required=True)
        self._validate_field("finish_reason", str, required=False)
        nested = (
            ("delta", ChatChoiceDelta, True),
            ("logprobs", ChatChoiceLogProbs, False),
        )
        for key, target_cls, required in nested:
            self._convert_dataclass(key, target_cls, required)
634
+
635
+
636
@dataclass
class TokenUsageStats(_BaseDataclass):
    """
    Token counts consumed by an inference call.

    Args:
        prompt_tokens (int): Tokens in the prompt. **Optional**, defaults to ``None``
        completion_tokens (int): Tokens in the generated completion.
            **Optional**, defaults to ``None``
        total_tokens (int): All tokens used. **Optional**, defaults to ``None``
    """

    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    total_tokens: Optional[int] = None

    def __post_init__(self):
        for counter in ("prompt_tokens", "completion_tokens", "total_tokens"):
            self._validate_field(counter, int, required=False)
658
+
659
+
660
@dataclass
class ChatCompletionResponse(_BaseDataclass):
    """
    The full response object returned by the chat endpoint.

    Args:
        choices (List[:py:class:`ChatChoice`]): A list of :py:class:`ChatChoice` objects
            containing the generated responses
        usage (:py:class:`TokenUsageStats`): An object describing the tokens used by the request.
            **Optional**, defaults to ``None``.
        id (str): The ID of the response. **Optional**, defaults to ``None``
        model (str): The name of the model used. **Optional**, defaults to ``None``
        object (str): The object type. Defaults to 'chat.completion'
        created (int): The time the response was created.
            **Optional**, defaults to the current time.
        custom_outputs (Dict[str, Any]): A field that can contain arbitrary additional context.
            Keys must be strings; the dictionary values must be JSON-serializable.
            **Optional**, defaults to ``None``

    Raises:
        ValueError: If a field has the wrong type, or ``custom_outputs`` is not a
            dict with string keys.
    """

    choices: list[ChatChoice]
    usage: Optional[TokenUsageStats] = None
    id: Optional[str] = None
    model: Optional[str] = None
    object: str = "chat.completion"
    created: int = field(default_factory=lambda: int(time.time()))
    custom_outputs: Optional[dict[str, Any]] = None

    def __post_init__(self):
        self._validate_field("id", str, False)
        self._validate_field("object", str, True)
        self._validate_field("created", int, True)
        self._validate_field("model", str, False)
        self._convert_dataclass_list("choices", ChatChoice)
        self._convert_dataclass("usage", TokenUsageStats, False)

        # `custom_outputs` is annotated Dict[str, Any] but was previously never
        # checked; validate it the same way ChatParams validates `custom_inputs`.
        if self.custom_outputs is not None:
            if not isinstance(self.custom_outputs, dict):
                raise ValueError(
                    "Expected `custom_outputs` to be a dictionary, "
                    f"received `{type(self.custom_outputs).__name__}`"
                )
            for key in self.custom_outputs:
                if not isinstance(key, str):
                    raise ValueError(
                        "Expected `custom_outputs` to be of type `Dict[str, Any]`, "
                        f"received key of type `{type(key).__name__}` (key: {key})"
                    )
695
+
696
+
697
@dataclass
class ChatCompletionChunk(_BaseDataclass):
    """
    The streaming chunk returned by the chat endpoint.
    ref: https://platform.openai.com/docs/api-reference/chat/streaming

    Args:
        choices (List[:py:class:`ChatChunkChoice`]): A list of :py:class:`ChatChunkChoice` objects
            containing the generated chunk of a streaming response
        usage (:py:class:`TokenUsageStats`): An object describing the tokens used by the request.
            **Optional**, defaults to ``None``.
        id (str): The ID of the response. **Optional**, defaults to ``None``
        model (str): The name of the model used. **Optional**, defaults to ``None``
        object (str): The object type. Defaults to 'chat.completion.chunk'
        created (int): The time the response was created.
            **Optional**, defaults to the current time.
        custom_outputs (Dict[str, Any]): A field that can contain arbitrary additional context.
            Keys must be strings; the dictionary values must be JSON-serializable.
            **Optional**, defaults to ``None``

    Raises:
        ValueError: If a field has the wrong type, or ``custom_outputs`` is not a
            dict with string keys.
    """

    choices: list[ChatChunkChoice]
    usage: Optional[TokenUsageStats] = None
    id: Optional[str] = None
    model: Optional[str] = None
    object: str = "chat.completion.chunk"
    created: int = field(default_factory=lambda: int(time.time()))
    custom_outputs: Optional[dict[str, Any]] = None

    def __post_init__(self):
        self._validate_field("id", str, False)
        self._validate_field("object", str, True)
        self._validate_field("created", int, True)
        self._validate_field("model", str, False)
        self._convert_dataclass_list("choices", ChatChunkChoice)
        self._convert_dataclass("usage", TokenUsageStats, False)

        # `custom_outputs` is annotated Dict[str, Any] but was previously never
        # checked; validate it the same way ChatParams validates `custom_inputs`.
        if self.custom_outputs is not None:
            if not isinstance(self.custom_outputs, dict):
                raise ValueError(
                    "Expected `custom_outputs` to be a dictionary, "
                    f"received `{type(self.custom_outputs).__name__}`"
                )
            for key in self.custom_outputs:
                if not isinstance(key, str):
                    raise ValueError(
                        "Expected `custom_outputs` to be of type `Dict[str, Any]`, "
                        f"received key of type `{type(key).__name__}` (key: {key})"
                    )
733
+
734
+
735
# turn off formatting for the model signatures to preserve readability
# fmt: off

# Optional "usage" column: an object with the three TokenUsageStats counters.
_token_usage_stats_col_spec = ColSpec(
    name="usage",
    type=Object([
        Property(counter_name, DataType.long)
        for counter_name in ("prompt_tokens", "completion_tokens", "total_tokens")
    ]),
    required=False,
)

# Optional free-form map columns for extra request/response context; each gets
# its own Map(AnyType()) instance.
_custom_inputs_col_spec, _custom_outputs_col_spec = (
    ColSpec(name=col_name, type=Map(AnyType()), required=False)
    for col_name in ("custom_inputs", "custom_outputs")
)
751
+
752
# Model-signature input schema for chat models.
# The "messages" column mirrors the fields of ChatMessage. The remaining columns
# mirror the ChatParams fields, and each is declared with required=False.
CHAT_MODEL_INPUT_SCHEMA = Schema(
    [
        ColSpec(
            name="messages",
            type=Array(
                Object(
                    [
                        Property("role", DataType.string),
                        Property("content", DataType.string, False),
                        Property("name", DataType.string, False),
                        Property("refusal", DataType.string, False),
                        # mirrors ToolCall / FunctionToolCallArguments
                        Property("tool_calls", Array(Object([
                            Property("id", DataType.string),
                            Property("function", Object([
                                Property("name", DataType.string),
                                Property("arguments", DataType.string),
                            ])),
                            Property("type", DataType.string),
                        ])), False),
                        Property("tool_call_id", DataType.string, False),
                    ]
                )
            ),
        ),
        ColSpec(name="temperature", type=DataType.double, required=False),
        ColSpec(name="max_tokens", type=DataType.long, required=False),
        ColSpec(name="stop", type=Array(DataType.string), required=False),
        ColSpec(name="n", type=DataType.long, required=False),
        ColSpec(name="stream", type=DataType.boolean, required=False),
        ColSpec(name="top_p", type=DataType.double, required=False),
        ColSpec(name="top_k", type=DataType.long, required=False),
        ColSpec(name="frequency_penalty", type=DataType.double, required=False),
        ColSpec(name="presence_penalty", type=DataType.double, required=False),
        # mirrors ToolDefinition -> FunctionToolDefinition -> ToolParamsSchema -> ParamProperty
        ColSpec(
            name="tools",
            type=Array(
                Object([
                    Property("type", DataType.string),
                    Property("function", Object([
                        Property("name", DataType.string),
                        Property("description", DataType.string, False),
                        Property("parameters", Object([
                            Property("properties", Map(Object([
                                Property("type", DataType.string),
                                Property("description", DataType.string, False),
                                Property("enum", Array(DataType.string), False),
                                # nested "items" carries only "type", matching ParamType
                                Property("items", Object([Property("type", DataType.string)]), False), # noqa
                            ]))),
                            Property("type", DataType.string, False),
                            Property("required", Array(DataType.string), False),
                            Property("additionalProperties", DataType.boolean, False),
                        ])),
                        Property("strict", DataType.boolean, False),
                    # NOTE(review): "function" is marked optional here, but
                    # ToolDefinition requires `function` at construction; confirm intent.
                    ]), False),
                ]),
            ),
            required=False,
        ),
        _custom_inputs_col_spec,
    ]
)
813
+
814
# Model-signature output schema for chat models, mirroring ChatCompletionResponse.
# NOTE(review): ChatChoice.logprobs has no Property in "choices". Also, "id" and
# "model" carry no required=False even though the dataclass fields are Optional.
# Confirm ColSpec's default requiredness and whether both omissions are intended.
CHAT_MODEL_OUTPUT_SCHEMA = Schema(
    [
        ColSpec(name="id", type=DataType.string),
        ColSpec(name="object", type=DataType.string),
        ColSpec(name="created", type=DataType.long),
        ColSpec(name="model", type=DataType.string),
        ColSpec(
            name="choices",
            type=Array(Object([
                Property("index", DataType.long),
                # mirrors ChatMessage
                Property("message", Object([
                    Property("role", DataType.string),
                    Property("content", DataType.string, False),
                    Property("name", DataType.string, False),
                    Property("refusal", DataType.string, False),
                    Property("tool_calls", Array(Object([
                        Property("id", DataType.string),
                        Property("function", Object([
                            Property("name", DataType.string),
                            Property("arguments", DataType.string),
                        ])),
                        Property("type", DataType.string),
                    ])), False),
                    Property("tool_call_id", DataType.string, False),
                ])),
                Property("finish_reason", DataType.string),
            ])),
        ),
        _token_usage_stats_col_spec,
        _custom_outputs_col_spec
    ]
)
846
+
847
# Minimal sample request for a chat model: one user turn plus a few optional
# request parameters. Key order is kept stable so serialized examples stay identical.
CHAT_MODEL_INPUT_EXAMPLE = {
    # Conversation history; here a single user message.
    "messages": [
        {
            "role": "user",
            "content": "Hello!",
        },
    ],
    # Optional request parameters (all declared as optional in the chat input schema).
    "temperature": 1.0,
    "max_tokens": 10,
    "stop": ["\n"],
    "n": 1,
    "stream": False,
}
857
+
858
# Input schema for completions models: a required ``prompt`` column followed by
# optional request parameters, declared here as a (name, type) table and expanded
# in the same order into non-required ColSpecs.
COMPLETIONS_MODEL_INPUT_SCHEMA = Schema(
    [ColSpec(name="prompt", type=DataType.string)]
    + [
        ColSpec(name=param_name, type=param_type, required=False)
        for param_name, param_type in (
            ("temperature", DataType.double),
            ("max_tokens", DataType.long),
            ("stop", Array(DataType.string)),
            ("n", DataType.long),
            ("stream", DataType.boolean),
        )
    ]
)
868
+
869
# Output schema for completions models.
# Top-level columns: id / object / created / model, a list of ``choices``
# (index, generated text, finish_reason), and a ``usage`` object holding three
# integer token counters. Every field here is required.
COMPLETIONS_MODEL_OUTPUT_SCHEMA = Schema(
    [
        ColSpec(name="id", type=DataType.string),
        ColSpec(name="object", type=DataType.string),
        ColSpec(name="created", type=DataType.long),
        ColSpec(name="model", type=DataType.string),
        ColSpec(
            name="choices",
            type=Array(Object([
                Property("index", DataType.long),
                Property("text", DataType.string),
                Property("finish_reason", DataType.string),
            ])),
        ),
        ColSpec(
            name="usage",
            type=Object([
                Property(counter, DataType.long)
                for counter in ("prompt_tokens", "completion_tokens", "total_tokens")
            ]),
        ),
    ]
)
902
+
903
# Input schema for embedding models: a single required string column ``input``.
EMBEDDING_MODEL_INPUT_SCHEMA = Schema([ColSpec(name="input", type=DataType.string)])
908
+
909
# Output schema for embedding models.
# ``data`` is a list of entries, each with its index, an object tag, and the
# embedding vector as an array of doubles. ``usage`` holds two integer token
# counters. Every field here is required.
EMBEDDING_MODEL_OUTPUT_SCHEMA = Schema(
    [
        ColSpec(name="object", type=DataType.string),
        ColSpec(
            name="data",
            type=Array(Object([
                Property("index", DataType.long),
                Property("object", DataType.string),
                Property("embedding", Array(DataType.double)),
            ])),
        ),
        ColSpec(
            name="usage",
            type=Object([
                Property(counter, DataType.long)
                for counter in ("prompt_tokens", "total_tokens")
            ]),
        ),
    ]
)
935
+ # fmt: on