arize-phoenix 3.16.1__py3-none-any.whl → 7.7.1__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.

This version of arize-phoenix might be problematic.

Files changed (338)
  1. arize_phoenix-7.7.1.dist-info/METADATA +261 -0
  2. arize_phoenix-7.7.1.dist-info/RECORD +345 -0
  3. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/WHEEL +1 -1
  4. arize_phoenix-7.7.1.dist-info/entry_points.txt +3 -0
  5. phoenix/__init__.py +86 -14
  6. phoenix/auth.py +309 -0
  7. phoenix/config.py +675 -45
  8. phoenix/core/model.py +32 -30
  9. phoenix/core/model_schema.py +102 -109
  10. phoenix/core/model_schema_adapter.py +48 -45
  11. phoenix/datetime_utils.py +24 -3
  12. phoenix/db/README.md +54 -0
  13. phoenix/db/__init__.py +4 -0
  14. phoenix/db/alembic.ini +85 -0
  15. phoenix/db/bulk_inserter.py +294 -0
  16. phoenix/db/engines.py +208 -0
  17. phoenix/db/enums.py +20 -0
  18. phoenix/db/facilitator.py +113 -0
  19. phoenix/db/helpers.py +159 -0
  20. phoenix/db/insertion/constants.py +2 -0
  21. phoenix/db/insertion/dataset.py +227 -0
  22. phoenix/db/insertion/document_annotation.py +171 -0
  23. phoenix/db/insertion/evaluation.py +191 -0
  24. phoenix/db/insertion/helpers.py +98 -0
  25. phoenix/db/insertion/span.py +193 -0
  26. phoenix/db/insertion/span_annotation.py +158 -0
  27. phoenix/db/insertion/trace_annotation.py +158 -0
  28. phoenix/db/insertion/types.py +256 -0
  29. phoenix/db/migrate.py +86 -0
  30. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  31. phoenix/db/migrations/env.py +114 -0
  32. phoenix/db/migrations/script.py.mako +26 -0
  33. phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
  34. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
  35. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  36. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  37. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  38. phoenix/db/models.py +807 -0
  39. phoenix/exceptions.py +5 -1
  40. phoenix/experiments/__init__.py +6 -0
  41. phoenix/experiments/evaluators/__init__.py +29 -0
  42. phoenix/experiments/evaluators/base.py +158 -0
  43. phoenix/experiments/evaluators/code_evaluators.py +184 -0
  44. phoenix/experiments/evaluators/llm_evaluators.py +473 -0
  45. phoenix/experiments/evaluators/utils.py +236 -0
  46. phoenix/experiments/functions.py +772 -0
  47. phoenix/experiments/tracing.py +86 -0
  48. phoenix/experiments/types.py +726 -0
  49. phoenix/experiments/utils.py +25 -0
  50. phoenix/inferences/__init__.py +0 -0
  51. phoenix/{datasets → inferences}/errors.py +6 -5
  52. phoenix/{datasets → inferences}/fixtures.py +49 -42
  53. phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
  54. phoenix/{datasets → inferences}/schema.py +11 -11
  55. phoenix/{datasets → inferences}/validation.py +13 -14
  56. phoenix/logging/__init__.py +3 -0
  57. phoenix/logging/_config.py +90 -0
  58. phoenix/logging/_filter.py +6 -0
  59. phoenix/logging/_formatter.py +69 -0
  60. phoenix/metrics/__init__.py +5 -4
  61. phoenix/metrics/binning.py +4 -3
  62. phoenix/metrics/metrics.py +2 -1
  63. phoenix/metrics/mixins.py +7 -6
  64. phoenix/metrics/retrieval_metrics.py +2 -1
  65. phoenix/metrics/timeseries.py +5 -4
  66. phoenix/metrics/wrappers.py +9 -3
  67. phoenix/pointcloud/clustering.py +5 -5
  68. phoenix/pointcloud/pointcloud.py +7 -5
  69. phoenix/pointcloud/projectors.py +5 -6
  70. phoenix/pointcloud/umap_parameters.py +53 -52
  71. phoenix/server/api/README.md +28 -0
  72. phoenix/server/api/auth.py +44 -0
  73. phoenix/server/api/context.py +152 -9
  74. phoenix/server/api/dataloaders/__init__.py +91 -0
  75. phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
  76. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  77. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  78. phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
  79. phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
  80. phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
  81. phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
  82. phoenix/server/api/dataloaders/document_evaluations.py +31 -0
  83. phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
  84. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
  85. phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
  86. phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
  87. phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
  88. phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
  89. phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
  90. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
  91. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  92. phoenix/server/api/dataloaders/record_counts.py +116 -0
  93. phoenix/server/api/dataloaders/session_io.py +79 -0
  94. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  95. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  96. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  97. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  98. phoenix/server/api/dataloaders/span_annotations.py +26 -0
  99. phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
  100. phoenix/server/api/dataloaders/span_descendants.py +57 -0
  101. phoenix/server/api/dataloaders/span_projects.py +33 -0
  102. phoenix/server/api/dataloaders/token_counts.py +124 -0
  103. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  104. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  105. phoenix/server/api/dataloaders/user_roles.py +30 -0
  106. phoenix/server/api/dataloaders/users.py +33 -0
  107. phoenix/server/api/exceptions.py +48 -0
  108. phoenix/server/api/helpers/__init__.py +12 -0
  109. phoenix/server/api/helpers/dataset_helpers.py +217 -0
  110. phoenix/server/api/helpers/experiment_run_filters.py +763 -0
  111. phoenix/server/api/helpers/playground_clients.py +948 -0
  112. phoenix/server/api/helpers/playground_registry.py +70 -0
  113. phoenix/server/api/helpers/playground_spans.py +455 -0
  114. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  115. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  116. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  117. phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
  118. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  119. phoenix/server/api/input_types/ClusterInput.py +2 -2
  120. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  121. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
  122. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
  123. phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
  124. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  125. phoenix/server/api/input_types/DatasetSort.py +17 -0
  126. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  127. phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
  128. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  129. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  130. phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
  131. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  132. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  133. phoenix/server/api/input_types/Granularity.py +1 -1
  134. phoenix/server/api/input_types/InvocationParameters.py +162 -0
  135. phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
  136. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  137. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  138. phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
  139. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  140. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  141. phoenix/server/api/input_types/SpanSort.py +134 -69
  142. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  143. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  144. phoenix/server/api/input_types/UserRoleInput.py +9 -0
  145. phoenix/server/api/mutations/__init__.py +28 -0
  146. phoenix/server/api/mutations/api_key_mutations.py +167 -0
  147. phoenix/server/api/mutations/chat_mutations.py +593 -0
  148. phoenix/server/api/mutations/dataset_mutations.py +591 -0
  149. phoenix/server/api/mutations/experiment_mutations.py +75 -0
  150. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
  151. phoenix/server/api/mutations/project_mutations.py +57 -0
  152. phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
  153. phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
  154. phoenix/server/api/mutations/user_mutations.py +329 -0
  155. phoenix/server/api/openapi/__init__.py +0 -0
  156. phoenix/server/api/openapi/main.py +17 -0
  157. phoenix/server/api/openapi/schema.py +16 -0
  158. phoenix/server/api/queries.py +738 -0
  159. phoenix/server/api/routers/__init__.py +11 -0
  160. phoenix/server/api/routers/auth.py +284 -0
  161. phoenix/server/api/routers/embeddings.py +26 -0
  162. phoenix/server/api/routers/oauth2.py +488 -0
  163. phoenix/server/api/routers/v1/__init__.py +64 -0
  164. phoenix/server/api/routers/v1/datasets.py +1017 -0
  165. phoenix/server/api/routers/v1/evaluations.py +362 -0
  166. phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
  167. phoenix/server/api/routers/v1/experiment_runs.py +167 -0
  168. phoenix/server/api/routers/v1/experiments.py +308 -0
  169. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  170. phoenix/server/api/routers/v1/spans.py +267 -0
  171. phoenix/server/api/routers/v1/traces.py +208 -0
  172. phoenix/server/api/routers/v1/utils.py +95 -0
  173. phoenix/server/api/schema.py +44 -241
  174. phoenix/server/api/subscriptions.py +597 -0
  175. phoenix/server/api/types/Annotation.py +21 -0
  176. phoenix/server/api/types/AnnotationSummary.py +55 -0
  177. phoenix/server/api/types/AnnotatorKind.py +16 -0
  178. phoenix/server/api/types/ApiKey.py +27 -0
  179. phoenix/server/api/types/AuthMethod.py +9 -0
  180. phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
  181. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  182. phoenix/server/api/types/Cluster.py +25 -24
  183. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  184. phoenix/server/api/types/DataQualityMetric.py +31 -13
  185. phoenix/server/api/types/Dataset.py +288 -63
  186. phoenix/server/api/types/DatasetExample.py +85 -0
  187. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  188. phoenix/server/api/types/DatasetVersion.py +14 -0
  189. phoenix/server/api/types/Dimension.py +32 -31
  190. phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
  191. phoenix/server/api/types/EmbeddingDimension.py +56 -49
  192. phoenix/server/api/types/Evaluation.py +25 -31
  193. phoenix/server/api/types/EvaluationSummary.py +30 -50
  194. phoenix/server/api/types/Event.py +20 -20
  195. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  196. phoenix/server/api/types/Experiment.py +152 -0
  197. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  198. phoenix/server/api/types/ExperimentComparison.py +17 -0
  199. phoenix/server/api/types/ExperimentRun.py +119 -0
  200. phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
  201. phoenix/server/api/types/GenerativeModel.py +9 -0
  202. phoenix/server/api/types/GenerativeProvider.py +85 -0
  203. phoenix/server/api/types/Inferences.py +80 -0
  204. phoenix/server/api/types/InferencesRole.py +23 -0
  205. phoenix/server/api/types/LabelFraction.py +7 -0
  206. phoenix/server/api/types/MimeType.py +2 -2
  207. phoenix/server/api/types/Model.py +54 -54
  208. phoenix/server/api/types/PerformanceMetric.py +8 -5
  209. phoenix/server/api/types/Project.py +407 -142
  210. phoenix/server/api/types/ProjectSession.py +139 -0
  211. phoenix/server/api/types/Segments.py +4 -4
  212. phoenix/server/api/types/Span.py +221 -176
  213. phoenix/server/api/types/SpanAnnotation.py +43 -0
  214. phoenix/server/api/types/SpanIOValue.py +15 -0
  215. phoenix/server/api/types/SystemApiKey.py +9 -0
  216. phoenix/server/api/types/TemplateLanguage.py +10 -0
  217. phoenix/server/api/types/TimeSeries.py +19 -15
  218. phoenix/server/api/types/TokenUsage.py +11 -0
  219. phoenix/server/api/types/Trace.py +154 -0
  220. phoenix/server/api/types/TraceAnnotation.py +45 -0
  221. phoenix/server/api/types/UMAPPoints.py +7 -7
  222. phoenix/server/api/types/User.py +60 -0
  223. phoenix/server/api/types/UserApiKey.py +45 -0
  224. phoenix/server/api/types/UserRole.py +15 -0
  225. phoenix/server/api/types/node.py +4 -112
  226. phoenix/server/api/types/pagination.py +156 -57
  227. phoenix/server/api/utils.py +34 -0
  228. phoenix/server/app.py +864 -115
  229. phoenix/server/bearer_auth.py +163 -0
  230. phoenix/server/dml_event.py +136 -0
  231. phoenix/server/dml_event_handler.py +256 -0
  232. phoenix/server/email/__init__.py +0 -0
  233. phoenix/server/email/sender.py +97 -0
  234. phoenix/server/email/templates/__init__.py +0 -0
  235. phoenix/server/email/templates/password_reset.html +19 -0
  236. phoenix/server/email/types.py +11 -0
  237. phoenix/server/grpc_server.py +102 -0
  238. phoenix/server/jwt_store.py +505 -0
  239. phoenix/server/main.py +305 -116
  240. phoenix/server/oauth2.py +52 -0
  241. phoenix/server/openapi/__init__.py +0 -0
  242. phoenix/server/prometheus.py +111 -0
  243. phoenix/server/rate_limiters.py +188 -0
  244. phoenix/server/static/.vite/manifest.json +87 -0
  245. phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
  246. phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
  247. phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
  248. phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
  249. phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
  250. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
  251. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
  252. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
  253. phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
  254. phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
  255. phoenix/server/telemetry.py +68 -0
  256. phoenix/server/templates/index.html +82 -23
  257. phoenix/server/thread_server.py +3 -3
  258. phoenix/server/types.py +275 -0
  259. phoenix/services.py +27 -18
  260. phoenix/session/client.py +743 -68
  261. phoenix/session/data_extractor.py +31 -7
  262. phoenix/session/evaluation.py +3 -9
  263. phoenix/session/session.py +263 -219
  264. phoenix/settings.py +22 -0
  265. phoenix/trace/__init__.py +2 -22
  266. phoenix/trace/attributes.py +338 -0
  267. phoenix/trace/dsl/README.md +116 -0
  268. phoenix/trace/dsl/filter.py +663 -213
  269. phoenix/trace/dsl/helpers.py +73 -21
  270. phoenix/trace/dsl/query.py +574 -201
  271. phoenix/trace/exporter.py +24 -19
  272. phoenix/trace/fixtures.py +368 -32
  273. phoenix/trace/otel.py +71 -219
  274. phoenix/trace/projects.py +3 -2
  275. phoenix/trace/schemas.py +33 -11
  276. phoenix/trace/span_evaluations.py +21 -16
  277. phoenix/trace/span_json_decoder.py +6 -4
  278. phoenix/trace/span_json_encoder.py +2 -2
  279. phoenix/trace/trace_dataset.py +47 -32
  280. phoenix/trace/utils.py +21 -4
  281. phoenix/utilities/__init__.py +0 -26
  282. phoenix/utilities/client.py +132 -0
  283. phoenix/utilities/deprecation.py +31 -0
  284. phoenix/utilities/error_handling.py +3 -2
  285. phoenix/utilities/json.py +109 -0
  286. phoenix/utilities/logging.py +8 -0
  287. phoenix/utilities/project.py +2 -2
  288. phoenix/utilities/re.py +49 -0
  289. phoenix/utilities/span_store.py +0 -23
  290. phoenix/utilities/template_formatters.py +99 -0
  291. phoenix/version.py +1 -1
  292. arize_phoenix-3.16.1.dist-info/METADATA +0 -495
  293. arize_phoenix-3.16.1.dist-info/RECORD +0 -178
  294. phoenix/core/project.py +0 -619
  295. phoenix/core/traces.py +0 -96
  296. phoenix/experimental/evals/__init__.py +0 -73
  297. phoenix/experimental/evals/evaluators.py +0 -413
  298. phoenix/experimental/evals/functions/__init__.py +0 -4
  299. phoenix/experimental/evals/functions/classify.py +0 -453
  300. phoenix/experimental/evals/functions/executor.py +0 -353
  301. phoenix/experimental/evals/functions/generate.py +0 -138
  302. phoenix/experimental/evals/functions/processing.py +0 -76
  303. phoenix/experimental/evals/models/__init__.py +0 -14
  304. phoenix/experimental/evals/models/anthropic.py +0 -175
  305. phoenix/experimental/evals/models/base.py +0 -170
  306. phoenix/experimental/evals/models/bedrock.py +0 -221
  307. phoenix/experimental/evals/models/litellm.py +0 -134
  308. phoenix/experimental/evals/models/openai.py +0 -448
  309. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  310. phoenix/experimental/evals/models/vertex.py +0 -173
  311. phoenix/experimental/evals/models/vertexai.py +0 -186
  312. phoenix/experimental/evals/retrievals.py +0 -96
  313. phoenix/experimental/evals/templates/__init__.py +0 -50
  314. phoenix/experimental/evals/templates/default_templates.py +0 -472
  315. phoenix/experimental/evals/templates/template.py +0 -195
  316. phoenix/experimental/evals/utils/__init__.py +0 -172
  317. phoenix/experimental/evals/utils/threads.py +0 -27
  318. phoenix/server/api/helpers.py +0 -11
  319. phoenix/server/api/routers/evaluation_handler.py +0 -109
  320. phoenix/server/api/routers/span_handler.py +0 -70
  321. phoenix/server/api/routers/trace_handler.py +0 -60
  322. phoenix/server/api/types/DatasetRole.py +0 -23
  323. phoenix/server/static/index.css +0 -6
  324. phoenix/server/static/index.js +0 -7447
  325. phoenix/storage/span_store/__init__.py +0 -23
  326. phoenix/storage/span_store/text_file.py +0 -85
  327. phoenix/trace/dsl/missing.py +0 -60
  328. phoenix/trace/langchain/__init__.py +0 -3
  329. phoenix/trace/langchain/instrumentor.py +0 -35
  330. phoenix/trace/llama_index/__init__.py +0 -3
  331. phoenix/trace/llama_index/callback.py +0 -102
  332. phoenix/trace/openai/__init__.py +0 -3
  333. phoenix/trace/openai/instrumentor.py +0 -30
  334. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/IP_NOTICE +0 -0
  335. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/LICENSE +0 -0
  336. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  337. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  338. /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
phoenix/server/api/helpers/playground_clients.py (new file; item 111 above)
@@ -0,0 +1,948 @@
+import asyncio
+import importlib.util
+import inspect
+import json
+import os
+import time
+from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator, Callable, Iterator
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Hashable, Mapping, Optional, Union
+
+from openinference.instrumentation import safe_json_dumps
+from openinference.semconv.trace import (
+    OpenInferenceLLMProviderValues,
+    OpenInferenceLLMSystemValues,
+    SpanAttributes,
+)
+from strawberry import UNSET
+from strawberry.scalars import JSON as JSONScalarType
+from typing_extensions import TypeAlias, assert_never
+
+from phoenix.evals.models.rate_limiters import (
+    AsyncCallable,
+    GenericType,
+    ParameterSpec,
+    RateLimiter,
+    RateLimitError,
+)
+from phoenix.server.api.exceptions import BadRequest
+from phoenix.server.api.helpers.playground_registry import PROVIDER_DEFAULT, register_llm_client
+from phoenix.server.api.input_types.GenerativeModelInput import GenerativeModelInput
+from phoenix.server.api.input_types.InvocationParameters import (
+    BoundedFloatInvocationParameter,
+    CanonicalParameterName,
+    FloatInvocationParameter,
+    IntInvocationParameter,
+    InvocationParameter,
+    InvocationParameterInput,
+    JSONInvocationParameter,
+    StringListInvocationParameter,
+    extract_parameter,
+    validate_invocation_parameters,
+)
+from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
+from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
+    FunctionCallChunk,
+    TextChunk,
+    ToolCallChunk,
+)
+from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
+
+if TYPE_CHECKING:
+    from anthropic.types import MessageParam, TextBlockParam, ToolResultBlockParam
+    from google.generativeai.types import ContentType
+    from openai import AsyncAzureOpenAI, AsyncOpenAI
+    from openai.types import CompletionUsage
+    from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessageToolCallParam
+
+SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
+ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
+
+
+class Dependency:
+    """
+    Set the module_name to the import name if it is different from the install name
+    """
+
+    def __init__(self, name: str, module_name: Optional[str] = None):
+        self.name = name
+        self.module_name = module_name
+
+    @property
+    def import_name(self) -> str:
+        return self.module_name or self.name
+
+
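Note: the `module_name` override matters when a pip distribution name differs from its importable module, as with google-generativeai declared further down in this file. A minimal illustration:

Dependency(name="openai")  # install name and import name match: `import openai`
Dependency(name="google-generativeai", module_name="google.generativeai")  # names differ
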
+class KeyedSingleton:
+    _instances: dict[Hashable, "KeyedSingleton"] = {}
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> "KeyedSingleton":
+        if "singleton_key" in kwargs:
+            singleton_key = kwargs.pop("singleton_key")
+        elif args:
+            singleton_key = args[0]
+            args = args[1:]
+        else:
+            raise ValueError("singleton_key must be provided")
+
+        instance_key = (cls, singleton_key)
+        if instance_key not in cls._instances:
+            instance = super().__new__(cls)
+            cls._instances[instance_key] = instance
+        return cls._instances[instance_key]
+
+
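Note: `__new__` caches one instance per `(class, key)` pair. A minimal sketch of the resulting behavior (the `Limiter` subclass is hypothetical):

class Limiter(KeyedSingleton):
    pass

a = Limiter("openai")
b = Limiter("openai")
c = Limiter("anthropic")
assert a is b       # same (class, key) -> the cached instance is reused
assert a is not c   # different key -> a distinct instance

Python still runs `__init__` on every construction, so `PlaygroundRateLimiter.__init__` below re-applies its settings to the cached instance each time; that is harmless here because the settings are constants.
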
+class PlaygroundRateLimiter(RateLimiter, KeyedSingleton):
+    """
+    A rate limiter class that will be instantiated once per `singleton_key`.
+    """
+
+    def __init__(self, singleton_key: Hashable, rate_limit_error: Optional[type[BaseException]]):
+        super().__init__(
+            rate_limit_error=rate_limit_error,
+            max_rate_limit_retries=3,
+            initial_per_second_request_rate=1.0,
+            maximum_per_second_request_rate=3.0,
+            enforcement_window_minutes=0.05,
+            rate_reduction_factor=0.5,
+            rate_increase_factor=0.01,
+            cooldown_seconds=5,
+            verbose=False,
+        )
+
+    # TODO: update the rate limiter class in phoenix.evals to support decorated sync functions
+    def _alimit(
+        self, fn: Callable[ParameterSpec, GenericType]
+    ) -> AsyncCallable[ParameterSpec, GenericType]:
+        @wraps(fn)
+        async def wrapper(*args: Any, **kwargs: Any) -> GenericType:
+            self._initialize_async_primitives()
+            assert self._rate_limit_handling_lock is not None and isinstance(
+                self._rate_limit_handling_lock, asyncio.Lock
+            )
+            assert self._rate_limit_handling is not None and isinstance(
+                self._rate_limit_handling, asyncio.Event
+            )
+            try:
+                try:
+                    await asyncio.wait_for(self._rate_limit_handling.wait(), 120)
+                except asyncio.TimeoutError:
+                    self._rate_limit_handling.set()  # Set the event as a failsafe
+                await self._throttler.async_wait_until_ready()
+                request_start_time = time.time()
+                maybe_coroutine = fn(*args, **kwargs)
+                if inspect.isawaitable(maybe_coroutine):
+                    return await maybe_coroutine  # type: ignore[no-any-return]
+                else:
+                    return maybe_coroutine
+            except self._rate_limit_error:
+                async with self._rate_limit_handling_lock:
+                    self._rate_limit_handling.clear()  # prevent new requests from starting
+                    self._throttler.on_rate_limit_error(request_start_time, verbose=self._verbose)
+                    try:
+                        for _attempt in range(self._max_rate_limit_retries):
+                            try:
+                                request_start_time = time.time()
+                                await self._throttler.async_wait_until_ready()
+                                maybe_coroutine = fn(*args, **kwargs)
+                                if inspect.isawaitable(maybe_coroutine):
+                                    return await maybe_coroutine  # type: ignore[no-any-return]
+                                else:
+                                    return maybe_coroutine
+                            except self._rate_limit_error:
+                                self._throttler.on_rate_limit_error(
+                                    request_start_time, verbose=self._verbose
+                                )
+                                continue
+                    finally:
+                        self._rate_limit_handling.set()  # allow new requests to start
+            raise RateLimitError(f"Exceeded max ({self._max_rate_limit_retries}) retries")
+
+        return wrapper
+
+
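Note: `_alimit` is how the clients below wrap the provider SDK entry points (e.g. `self.rate_limiter._alimit(self.client.chat.completions.create)`). A minimal sketch with hypothetical names:

class FakeRateLimitError(Exception):
    """Stand-in for a provider SDK's rate-limit exception."""

async def flaky_call(prompt: str) -> str:  # hypothetical API call
    return prompt.upper()

limiter = PlaygroundRateLimiter("example-provider", FakeRateLimitError)
throttled = limiter._alimit(flaky_call)  # async wrapper: throttles, retries on FakeRateLimitError
# result = await throttled("hello")     # call from inside an event loop
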
+class PlaygroundStreamingClient(ABC):
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+    ) -> None:
+        self._attributes: dict[str, Any] = dict()
+
+    @classmethod
+    @abstractmethod
+    def dependencies(cls) -> list[Dependency]:
+        # A list of dependencies this client needs to run
+        ...
+
+    @classmethod
+    @abstractmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]: ...
+
+    @abstractmethod
+    async def chat_completion_create(
+        self,
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
+        **invocation_parameters: Any,
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        # a yield statement is needed to satisfy the type-checker
+        # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
+        yield TextChunk(content="")
+
+    @classmethod
+    def construct_invocation_parameters(
+        cls, invocation_parameters: list[InvocationParameterInput]
+    ) -> dict[str, Any]:
+        supported_params = cls.supported_invocation_parameters()
+        params = {param.invocation_name: param for param in supported_params}
+
+        formatted_invocation_parameters = dict()
+
+        for param_input in invocation_parameters:
+            invocation_name = param_input.invocation_name
+            if invocation_name not in params:
+                raise ValueError(f"Unsupported invocation parameter: {invocation_name}")
+
+            param_def = params[invocation_name]
+            value = extract_parameter(param_def, param_input)
+            if value is not UNSET:
+                formatted_invocation_parameters[invocation_name] = value
+        validate_invocation_parameters(supported_params, formatted_invocation_parameters)
+        return formatted_invocation_parameters
+
+    @classmethod
+    def dependencies_are_installed(cls) -> bool:
+        try:
+            for dependency in cls.dependencies():
+                import_name = dependency.import_name
+                if importlib.util.find_spec(import_name) is None:
+                    return False
+            return True
+        except ValueError:
+            # happens in some cases if the spec is None
+            return False
+
+    @property
+    def attributes(self) -> dict[str, Any]:
+        return self._attributes
+
+
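Note: the abstract surface above is small, so a toy subclass makes it concrete. This echo client is purely illustrative and is not part of the diff:

class EchoStreamingClient(PlaygroundStreamingClient):
    """Hypothetical client that streams back the last user message."""

    @classmethod
    def dependencies(cls) -> list[Dependency]:
        return []  # no third-party SDK required

    @classmethod
    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
        return []

    async def chat_completion_create(self, messages, tools, **invocation_parameters):
        _role, content, _tool_call_id, _tool_calls = messages[-1]
        yield TextChunk(content=content)
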
+class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
+    def __init__(
+        self,
+        *,
+        client: Union["AsyncOpenAI", "AsyncAzureOpenAI"],
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+    ) -> None:
+        from openai import RateLimitError as OpenAIRateLimitError
+
+        super().__init__(model=model, api_key=api_key)
+        self.client = client
+        self.model_name = model.name
+        self.rate_limiter = PlaygroundRateLimiter(model.provider_key, OpenAIRateLimitError)
+
+    @classmethod
+    def dependencies(cls) -> list[Dependency]:
+        return [Dependency(name="openai")]
+
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            BoundedFloatInvocationParameter(
+                invocation_name="temperature",
+                canonical_name=CanonicalParameterName.TEMPERATURE,
+                label="Temperature",
+                default_value=1.0,
+                min_value=0.0,
+                max_value=2.0,
+            ),
+            IntInvocationParameter(
+                invocation_name="max_tokens",
+                canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                label="Max Tokens",
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="frequency_penalty",
+                label="Frequency Penalty",
+                default_value=0.0,
+                min_value=-2.0,
+                max_value=2.0,
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="presence_penalty",
+                label="Presence Penalty",
+                default_value=0.0,
+                min_value=-2.0,
+                max_value=2.0,
+            ),
+            StringListInvocationParameter(
+                invocation_name="stop",
+                canonical_name=CanonicalParameterName.STOP_SEQUENCES,
+                label="Stop Sequences",
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="top_p",
+                canonical_name=CanonicalParameterName.TOP_P,
+                label="Top P",
+                default_value=1.0,
+                min_value=0.0,
+                max_value=1.0,
+            ),
+            IntInvocationParameter(
+                invocation_name="seed",
+                canonical_name=CanonicalParameterName.RANDOM_SEED,
+                label="Seed",
+            ),
+            JSONInvocationParameter(
+                invocation_name="tool_choice",
+                label="Tool Choice",
+                canonical_name=CanonicalParameterName.TOOL_CHOICE,
+            ),
+            JSONInvocationParameter(
+                invocation_name="response_format",
+                label="Response Format",
+                canonical_name=CanonicalParameterName.RESPONSE_FORMAT,
+            ),
+        ]
+
+    async def chat_completion_create(
+        self,
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
+        **invocation_parameters: Any,
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        from openai import NOT_GIVEN
+        from openai.types.chat import ChatCompletionStreamOptionsParam
+
+        # Convert standard messages to OpenAI messages
+        openai_messages = []
+        for message in messages:
+            openai_message = self.to_openai_chat_completion_param(*message)
+            if openai_message is not None:
+                openai_messages.append(openai_message)
+        tool_call_ids: dict[int, str] = {}
+        token_usage: Optional["CompletionUsage"] = None
+        throttled_create = self.rate_limiter._alimit(self.client.chat.completions.create)
+        async for chunk in await throttled_create(
+            messages=openai_messages,
+            model=self.model_name,
+            stream=True,
+            stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
+            tools=tools or NOT_GIVEN,
+            **invocation_parameters,
+        ):
+            if (usage := chunk.usage) is not None:
+                token_usage = usage
+                continue
+            if not chunk.choices:
+                # for Azure, initial chunk contains the content filter
+                continue
+            choice = chunk.choices[0]
+            delta = choice.delta
+            if choice.finish_reason is None:
+                if isinstance(chunk_content := delta.content, str):
+                    text_chunk = TextChunk(content=chunk_content)
+                    yield text_chunk
+                if (tool_calls := delta.tool_calls) is not None:
+                    for tool_call_index, tool_call in enumerate(tool_calls):
+                        tool_call_id = (
+                            tool_call.id
+                            if tool_call.id is not None
+                            else tool_call_ids[tool_call_index]
+                        )
+                        tool_call_ids[tool_call_index] = tool_call_id
+                        if (function := tool_call.function) is not None:
+                            tool_call_chunk = ToolCallChunk(
+                                id=tool_call_id,
+                                function=FunctionCallChunk(
+                                    name=function.name or "",
+                                    arguments=function.arguments or "",
+                                ),
+                            )
+                            yield tool_call_chunk
+        if token_usage is not None:
+            self._attributes.update(dict(self._llm_token_counts(token_usage)))
+
+    def to_openai_chat_completion_param(
+        self,
+        role: ChatCompletionMessageRole,
+        content: JSONScalarType,
+        tool_call_id: Optional[str] = None,
+        tool_calls: Optional[list[JSONScalarType]] = None,
+    ) -> Optional["ChatCompletionMessageParam"]:
+        from openai.types.chat import (
+            ChatCompletionAssistantMessageParam,
+            ChatCompletionSystemMessageParam,
+            ChatCompletionToolMessageParam,
+            ChatCompletionUserMessageParam,
+        )
+
+        if role is ChatCompletionMessageRole.USER:
+            return ChatCompletionUserMessageParam(
+                {
+                    "content": content,
+                    "role": "user",
+                }
+            )
+        if role is ChatCompletionMessageRole.SYSTEM:
+            return ChatCompletionSystemMessageParam(
+                {
+                    "content": content,
+                    "role": "system",
+                }
+            )
+        if role is ChatCompletionMessageRole.AI:
+            if tool_calls is None:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                    }
+                )
+            else:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                        "tool_calls": [
+                            self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
+                        ],
+                    }
+                )
+        if role is ChatCompletionMessageRole.TOOL:
+            if tool_call_id is None:
+                raise ValueError("tool_call_id is required for tool messages")
+            return ChatCompletionToolMessageParam(
+                {"content": content, "role": "tool", "tool_call_id": tool_call_id}
+            )
+        assert_never(role)
+
+    def to_openai_tool_call_param(
+        self,
+        tool_call: JSONScalarType,
+    ) -> "ChatCompletionMessageToolCallParam":
+        from openai.types.chat import ChatCompletionMessageToolCallParam
+
+        return ChatCompletionMessageToolCallParam(
+            id=tool_call.get("id", ""),
+            function={
+                "name": tool_call.get("function", {}).get("name", ""),
+                "arguments": safe_json_dumps(tool_call.get("function", {}).get("arguments", "")),
+            },
+            type="function",
+        )
+
+    @staticmethod
+    def _llm_token_counts(usage: "CompletionUsage") -> Iterator[tuple[str, Any]]:
+        yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
+        yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
+        yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
+
+
+@register_llm_client(
+    provider_key=GenerativeProviderKey.OPENAI,
+    model_names=[
+        PROVIDER_DEFAULT,
+        "gpt-4o",
+        "gpt-4o-2024-08-06",
+        "gpt-4o-2024-05-13",
+        "chatgpt-4o-latest",
+        "gpt-4o-mini",
+        "gpt-4o-mini-2024-07-18",
+        "gpt-4-turbo",
+        "gpt-4-turbo-2024-04-09",
+        "gpt-4-turbo-preview",
+        "gpt-4-0125-preview",
+        "gpt-4-1106-preview",
+        "gpt-4",
+        "gpt-4-0613",
+        "gpt-3.5-turbo-0125",
+        "gpt-3.5-turbo",
+        "gpt-3.5-turbo-1106",
+        "gpt-3.5-turbo-instruct",
+    ],
+)
+class OpenAIStreamingClient(OpenAIBaseStreamingClient):
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+    ) -> None:
+        from openai import AsyncOpenAI
+
+        # todo: check if custom base url is set before raising error to allow
+        # for custom endpoints that don't require an API key
+        if not (api_key := api_key or os.environ.get("OPENAI_API_KEY")):
+            raise BadRequest("An API key is required for OpenAI models")
+        client = AsyncOpenAI(api_key=api_key)
+        super().__init__(client=client, model=model, api_key=api_key)
+        self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.OPENAI.value
+        self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.OPENAI.value
+
+
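Note: `@register_llm_client` keys each class by provider and model name, with `PROVIDER_DEFAULT` as the provider-wide fallback. A hypothetical sketch of the lookup this enables (the real registry lives in playground_registry.py, file 112 above):

registry: dict[tuple[GenerativeProviderKey, Any], type[PlaygroundStreamingClient]] = {}

def resolve_client(
    provider: GenerativeProviderKey, model_name: str
) -> type[PlaygroundStreamingClient]:
    # Try the exact model first, then fall back to the provider's default entry.
    return registry.get((provider, model_name)) or registry[(provider, PROVIDER_DEFAULT)]
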
+@register_llm_client(
+    provider_key=GenerativeProviderKey.OPENAI,
+    model_names=[
+        "o1",
+        "o1-2024-12-17",
+        "o1-mini",
+        "o1-mini-2024-09-12",
+        "o1-preview",
+        "o1-preview-2024-09-12",
+    ],
+)
+class OpenAIO1StreamingClient(OpenAIStreamingClient):
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            IntInvocationParameter(
+                invocation_name="max_completion_tokens",
+                canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                label="Max Completion Tokens",
+            ),
+            IntInvocationParameter(
+                invocation_name="seed",
+                canonical_name=CanonicalParameterName.RANDOM_SEED,
+                label="Seed",
+            ),
+            JSONInvocationParameter(
+                invocation_name="tool_choice",
+                label="Tool Choice",
+                canonical_name=CanonicalParameterName.TOOL_CHOICE,
+            ),
+            JSONInvocationParameter(
+                invocation_name="response_format",
+                label="Response Format",
+                canonical_name=CanonicalParameterName.RESPONSE_FORMAT,
+            ),
+        ]
+
+    def to_openai_chat_completion_param(
+        self,
+        role: ChatCompletionMessageRole,
+        content: JSONScalarType,
+        tool_call_id: Optional[str] = None,
+        tool_calls: Optional[list[JSONScalarType]] = None,
+    ) -> Optional["ChatCompletionMessageParam"]:
+        from openai.types.chat import (
+            ChatCompletionAssistantMessageParam,
+            ChatCompletionToolMessageParam,
+            ChatCompletionUserMessageParam,
+        )
+
+        if role is ChatCompletionMessageRole.USER:
+            return ChatCompletionUserMessageParam(
+                {
+                    "content": content,
+                    "role": "user",
+                }
+            )
+        if role is ChatCompletionMessageRole.SYSTEM:
+            return None  # System messages are not supported for o1 models
+        if role is ChatCompletionMessageRole.AI:
+            if tool_calls is None:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                    }
+                )
+            else:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                        "tool_calls": [
+                            self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
+                        ],
+                    }
+                )
+        if role is ChatCompletionMessageRole.TOOL:
+            if tool_call_id is None:
+                raise ValueError("tool_call_id is required for tool messages")
+            return ChatCompletionToolMessageParam(
+                {"content": content, "role": "tool", "tool_call_id": tool_call_id}
+            )
+        assert_never(role)
+
+    @staticmethod
+    def _llm_token_counts(usage: "CompletionUsage") -> Iterator[tuple[str, Any]]:
+        yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
+        yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
+        yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
+
+
+@register_llm_client(
+    provider_key=GenerativeProviderKey.AZURE_OPENAI,
+    model_names=[
+        PROVIDER_DEFAULT,
+    ],
+)
+class AzureOpenAIStreamingClient(OpenAIBaseStreamingClient):
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+    ):
+        from openai import AsyncAzureOpenAI
+
+        if not (api_key := api_key or os.environ.get("AZURE_OPENAI_API_KEY")):
+            raise BadRequest("An Azure API key is required for Azure OpenAI models")
+        if not (endpoint := model.endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")):
+            raise BadRequest("An Azure endpoint is required for Azure OpenAI models")
+        if not (api_version := model.api_version or os.environ.get("OPENAI_API_VERSION")):
+            raise BadRequest("An OpenAI API version is required for Azure OpenAI models")
+        client = AsyncAzureOpenAI(
+            api_key=api_key,
+            azure_endpoint=endpoint,
+            api_version=api_version,
+        )
+        super().__init__(client=client, model=model, api_key=api_key)
+        self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.AZURE.value
+        self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.OPENAI.value
+
+
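Note: given the three checks above, the Azure client can be configured entirely through environment variables; the values below are placeholders, not real settings:

import os

os.environ["AZURE_OPENAI_API_KEY"] = "<key>"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<resource>.openai.azure.com"
os.environ["OPENAI_API_VERSION"] = "<api-version, e.g. 2024-02-01>"
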
+@register_llm_client(
+    provider_key=GenerativeProviderKey.ANTHROPIC,
+    model_names=[
+        PROVIDER_DEFAULT,
+        "claude-3-5-sonnet-latest",
+        "claude-3-5-haiku-latest",
+        "claude-3-5-sonnet-20241022",
+        "claude-3-5-haiku-20241022",
+        "claude-3-5-sonnet-20240620",
+        "claude-3-opus-latest",
+        "claude-3-sonnet-20240229",
+        "claude-3-haiku-20240307",
+    ],
+)
+class AnthropicStreamingClient(PlaygroundStreamingClient):
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+    ) -> None:
+        import anthropic
+
+        super().__init__(model=model, api_key=api_key)
+        self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.ANTHROPIC.value
+        self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.ANTHROPIC.value
+        if not (api_key := api_key or os.environ.get("ANTHROPIC_API_KEY")):
+            raise BadRequest("An API key is required for Anthropic models")
+        self.client = anthropic.AsyncAnthropic(api_key=api_key)
+        self.model_name = model.name
+        self.rate_limiter = PlaygroundRateLimiter(model.provider_key, anthropic.RateLimitError)
+
+    @classmethod
+    def dependencies(cls) -> list[Dependency]:
+        return [Dependency(name="anthropic")]
+
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            IntInvocationParameter(
+                invocation_name="max_tokens",
+                canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                label="Max Tokens",
+                default_value=1024,
+                required=True,
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="temperature",
+                canonical_name=CanonicalParameterName.TEMPERATURE,
+                label="Temperature",
+                default_value=1.0,
+                min_value=0.0,
+                max_value=1.0,
+            ),
+            StringListInvocationParameter(
+                invocation_name="stop_sequences",
+                canonical_name=CanonicalParameterName.STOP_SEQUENCES,
+                label="Stop Sequences",
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="top_p",
+                canonical_name=CanonicalParameterName.TOP_P,
+                label="Top P",
+                default_value=1.0,
+                min_value=0.0,
+                max_value=1.0,
+            ),
+            JSONInvocationParameter(
+                invocation_name="tool_choice",
+                label="Tool Choice",
+                canonical_name=CanonicalParameterName.TOOL_CHOICE,
+            ),
+        ]
+
+    async def chat_completion_create(
+        self,
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
+        **invocation_parameters: Any,
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        import anthropic.lib.streaming as anthropic_streaming
+        import anthropic.types as anthropic_types
+
+        anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
+        anthropic_params = {
+            "messages": anthropic_messages,
+            "model": self.model_name,
+            "system": system_prompt,
+            "tools": tools,
+            **invocation_parameters,
+        }
+        throttled_stream = self.rate_limiter._alimit(self.client.messages.stream)
+        async with await throttled_stream(**anthropic_params) as stream:
+            async for event in stream:
+                if isinstance(event, anthropic_types.RawMessageStartEvent):
+                    self._attributes.update(
+                        {LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
+                    )
+                elif isinstance(event, anthropic_streaming.TextEvent):
+                    yield TextChunk(content=event.text)
+                elif isinstance(event, anthropic_streaming.MessageStopEvent):
+                    self._attributes.update(
+                        {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
+                    )
+                elif (
+                    isinstance(event, anthropic_streaming.ContentBlockStopEvent)
+                    and event.content_block.type == "tool_use"
+                ):
+                    tool_call_chunk = ToolCallChunk(
+                        id=event.content_block.id,
+                        function=FunctionCallChunk(
+                            name=event.content_block.name,
+                            arguments=json.dumps(event.content_block.input),
+                        ),
+                    )
+                    yield tool_call_chunk
+                elif isinstance(
+                    event,
+                    (
+                        anthropic_types.RawContentBlockStartEvent,
+                        anthropic_types.RawContentBlockDeltaEvent,
+                        anthropic_types.RawMessageDeltaEvent,
+                        anthropic_streaming.ContentBlockStopEvent,
+                        anthropic_streaming.InputJsonEvent,
+                    ),
+                ):
+                    # event types emitted by the stream that don't contain useful information
+                    pass
+                elif isinstance(event, anthropic_streaming.InputJsonEvent):
+                    raise NotImplementedError
+                else:
+                    assert_never(event)
+
+    def _build_anthropic_messages(
+        self,
+        messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
+    ) -> tuple[list["MessageParam"], str]:
+        anthropic_messages: list["MessageParam"] = []
+        system_prompt = ""
+        for role, content, _tool_call_id, _tool_calls in messages:
+            tool_aware_content = self._anthropic_message_content(content, _tool_calls)
+            if role == ChatCompletionMessageRole.USER:
+                anthropic_messages.append({"role": "user", "content": tool_aware_content})
+            elif role == ChatCompletionMessageRole.AI:
+                anthropic_messages.append({"role": "assistant", "content": tool_aware_content})
+            elif role == ChatCompletionMessageRole.SYSTEM:
+                system_prompt += content + "\n"
+            elif role == ChatCompletionMessageRole.TOOL:
+                anthropic_messages.append(
+                    {
+                        "role": "user",
+                        "content": [
+                            {
+                                "type": "tool_result",
+                                "tool_use_id": _tool_call_id or "",
+                                "content": content or "",
+                            }
+                        ],
+                    }
+                )
+            else:
+                assert_never(role)
+
+        return anthropic_messages, system_prompt
+
+    def _anthropic_message_content(
+        self, content: str, tool_calls: Optional[list[JSONScalarType]]
+    ) -> Union[str, list[Union["ToolResultBlockParam", "TextBlockParam"]]]:
+        if tool_calls:
+            # Anthropic combines tool calls and the reasoning text into a single message object
+            tool_use_content: list[Union["ToolResultBlockParam", "TextBlockParam"]] = []
+            if content:
+                tool_use_content.append({"type": "text", "text": content})
+            tool_use_content.extend(tool_calls)
+            return tool_use_content
+
+        return content
+
+
+@register_llm_client(
+    provider_key=GenerativeProviderKey.GEMINI,
+    model_names=[
+        PROVIDER_DEFAULT,
+        "gemini-2.0-flash-exp",
+        "gemini-1.5-flash",
+        "gemini-1.5-flash-8b",
+        "gemini-1.5-pro",
+        "gemini-1.0-pro",
+    ],
+)
+class GeminiStreamingClient(PlaygroundStreamingClient):
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+    ) -> None:
+        import google.generativeai as google_genai
+
+        super().__init__(model=model, api_key=api_key)
+        self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.GOOGLE.value
+        self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.VERTEXAI.value
+        if not (
+            api_key := api_key
+            or os.environ.get("GEMINI_API_KEY")
+            or os.environ.get("GOOGLE_API_KEY")
+        ):
+            raise BadRequest("An API key is required for Gemini models")
+        google_genai.configure(api_key=api_key)
+        self.model_name = model.name
+
+    @classmethod
+    def dependencies(cls) -> list[Dependency]:
+        return [Dependency(name="google-generativeai", module_name="google.generativeai")]
+
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            BoundedFloatInvocationParameter(
+                invocation_name="temperature",
+                canonical_name=CanonicalParameterName.TEMPERATURE,
+                label="Temperature",
+                default_value=1.0,
+                min_value=0.0,
+                max_value=2.0,
+            ),
+            IntInvocationParameter(
+                invocation_name="max_output_tokens",
+                canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                label="Max Output Tokens",
+            ),
+            StringListInvocationParameter(
+                invocation_name="stop_sequences",
+                canonical_name=CanonicalParameterName.STOP_SEQUENCES,
+                label="Stop Sequences",
+            ),
+            FloatInvocationParameter(
+                invocation_name="presence_penalty",
+                label="Presence Penalty",
+                default_value=0.0,
+            ),
+            FloatInvocationParameter(
+                invocation_name="frequency_penalty",
+                label="Frequency Penalty",
+                default_value=0.0,
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="top_p",
+                canonical_name=CanonicalParameterName.TOP_P,
+                label="Top P",
+                default_value=1.0,
+                min_value=0.0,
+                max_value=1.0,
+            ),
+            IntInvocationParameter(
+                invocation_name="top_k",
+                label="Top K",
+            ),
+        ]
+
+    async def chat_completion_create(
+        self,
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
+        **invocation_parameters: Any,
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        import google.generativeai as google_genai
+
+        gemini_message_history, current_message, system_prompt = self._build_gemini_messages(
+            messages
+        )
+
+        model_args = {"model_name": self.model_name}
+        if system_prompt:
+            model_args["system_instruction"] = system_prompt
+        client = google_genai.GenerativeModel(**model_args)
+
+        gemini_config = google_genai.GenerationConfig(
+            **invocation_parameters,
+        )
+        gemini_params = {
+            "content": current_message,
+            "generation_config": gemini_config,
+            "stream": True,
+        }
+
+        chat = client.start_chat(history=gemini_message_history)
+        stream = await chat.send_message_async(**gemini_params)
+        async for event in stream:
+            self._attributes.update(
+                {
+                    LLM_TOKEN_COUNT_PROMPT: event.usage_metadata.prompt_token_count,
+                    LLM_TOKEN_COUNT_COMPLETION: event.usage_metadata.candidates_token_count,
+                    LLM_TOKEN_COUNT_TOTAL: event.usage_metadata.total_token_count,
+                }
+            )
+            yield TextChunk(content=event.text)
+
+    def _build_gemini_messages(
+        self,
+        messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
+    ) -> tuple[list["ContentType"], str, str]:
+        gemini_message_history: list["ContentType"] = []
+        system_prompts = []
+        for role, content, _tool_call_id, _tool_calls in messages:
+            if role == ChatCompletionMessageRole.USER:
+                gemini_message_history.append({"role": "user", "parts": content})
+            elif role == ChatCompletionMessageRole.AI:
+                gemini_message_history.append({"role": "model", "parts": content})
+            elif role == ChatCompletionMessageRole.SYSTEM:
+                system_prompts.append(content)
+            elif role == ChatCompletionMessageRole.TOOL:
+                raise NotImplementedError
+            else:
+                assert_never(role)
+        if gemini_message_history:
+            prompt = gemini_message_history.pop()["parts"]
+        else:
+            prompt = ""
+
+        return gemini_message_history, prompt, "\n".join(system_prompts)
+
+
+def initialize_playground_clients() -> None:
+    """
+    Ensure that all playground clients are registered at import time.
+    """
+    pass
+
+
+LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
+LLM_SYSTEM = SpanAttributes.LLM_SYSTEM
+LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
+LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
+LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
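
Note: registration happens as a side effect of the `@register_llm_client` decorators when this module is imported; `initialize_playground_clients` is a no-op that makes the import explicit at startup:

from phoenix.server.api.helpers.playground_clients import initialize_playground_clients

initialize_playground_clients()  # importing the module already ran the decorators; this documents intent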