arize-phoenix 5.5.2__py3-none-any.whl → 5.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (186)
  1. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/METADATA +4 -7
  2. arize_phoenix-5.7.0.dist-info/RECORD +330 -0
  3. phoenix/config.py +50 -8
  4. phoenix/core/model.py +3 -3
  5. phoenix/core/model_schema.py +41 -50
  6. phoenix/core/model_schema_adapter.py +17 -16
  7. phoenix/datetime_utils.py +2 -2
  8. phoenix/db/bulk_inserter.py +10 -20
  9. phoenix/db/engines.py +2 -1
  10. phoenix/db/enums.py +2 -2
  11. phoenix/db/helpers.py +8 -7
  12. phoenix/db/insertion/dataset.py +9 -19
  13. phoenix/db/insertion/document_annotation.py +14 -13
  14. phoenix/db/insertion/helpers.py +6 -16
  15. phoenix/db/insertion/span_annotation.py +14 -13
  16. phoenix/db/insertion/trace_annotation.py +14 -13
  17. phoenix/db/insertion/types.py +19 -30
  18. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +8 -8
  19. phoenix/db/models.py +28 -28
  20. phoenix/experiments/evaluators/base.py +2 -1
  21. phoenix/experiments/evaluators/code_evaluators.py +4 -5
  22. phoenix/experiments/evaluators/llm_evaluators.py +157 -4
  23. phoenix/experiments/evaluators/utils.py +3 -2
  24. phoenix/experiments/functions.py +10 -21
  25. phoenix/experiments/tracing.py +2 -1
  26. phoenix/experiments/types.py +20 -29
  27. phoenix/experiments/utils.py +2 -1
  28. phoenix/inferences/errors.py +6 -5
  29. phoenix/inferences/fixtures.py +6 -5
  30. phoenix/inferences/inferences.py +37 -37
  31. phoenix/inferences/schema.py +11 -10
  32. phoenix/inferences/validation.py +13 -14
  33. phoenix/logging/_formatter.py +3 -3
  34. phoenix/metrics/__init__.py +5 -4
  35. phoenix/metrics/binning.py +2 -1
  36. phoenix/metrics/metrics.py +2 -1
  37. phoenix/metrics/mixins.py +7 -6
  38. phoenix/metrics/retrieval_metrics.py +2 -1
  39. phoenix/metrics/timeseries.py +5 -4
  40. phoenix/metrics/wrappers.py +2 -2
  41. phoenix/pointcloud/clustering.py +3 -4
  42. phoenix/pointcloud/pointcloud.py +7 -5
  43. phoenix/pointcloud/umap_parameters.py +2 -1
  44. phoenix/server/api/dataloaders/annotation_summaries.py +12 -19
  45. phoenix/server/api/dataloaders/average_experiment_run_latency.py +2 -2
  46. phoenix/server/api/dataloaders/cache/two_tier_cache.py +3 -2
  47. phoenix/server/api/dataloaders/dataset_example_revisions.py +3 -8
  48. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -5
  49. phoenix/server/api/dataloaders/document_evaluation_summaries.py +12 -18
  50. phoenix/server/api/dataloaders/document_evaluations.py +3 -7
  51. phoenix/server/api/dataloaders/document_retrieval_metrics.py +6 -13
  52. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +4 -8
  53. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -5
  54. phoenix/server/api/dataloaders/experiment_run_annotations.py +3 -7
  55. phoenix/server/api/dataloaders/experiment_run_counts.py +1 -5
  56. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -5
  57. phoenix/server/api/dataloaders/latency_ms_quantile.py +21 -30
  58. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +7 -13
  59. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  60. phoenix/server/api/dataloaders/record_counts.py +11 -18
  61. phoenix/server/api/dataloaders/span_annotations.py +3 -7
  62. phoenix/server/api/dataloaders/span_dataset_examples.py +3 -8
  63. phoenix/server/api/dataloaders/span_descendants.py +3 -7
  64. phoenix/server/api/dataloaders/span_projects.py +2 -2
  65. phoenix/server/api/dataloaders/token_counts.py +12 -19
  66. phoenix/server/api/dataloaders/trace_row_ids.py +3 -7
  67. phoenix/server/api/dataloaders/user_roles.py +3 -3
  68. phoenix/server/api/dataloaders/users.py +3 -3
  69. phoenix/server/api/helpers/__init__.py +4 -3
  70. phoenix/server/api/helpers/dataset_helpers.py +10 -9
  71. phoenix/server/api/helpers/playground_clients.py +671 -0
  72. phoenix/server/api/helpers/playground_registry.py +70 -0
  73. phoenix/server/api/helpers/playground_spans.py +325 -0
  74. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +2 -2
  75. phoenix/server/api/input_types/AddSpansToDatasetInput.py +2 -2
  76. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  77. phoenix/server/api/input_types/ChatCompletionMessageInput.py +13 -1
  78. phoenix/server/api/input_types/ClusterInput.py +2 -2
  79. phoenix/server/api/input_types/DeleteAnnotationsInput.py +1 -3
  80. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +2 -2
  81. phoenix/server/api/input_types/DeleteExperimentsInput.py +1 -3
  82. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  83. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  84. phoenix/server/api/input_types/Granularity.py +1 -1
  85. phoenix/server/api/input_types/InvocationParameters.py +156 -13
  86. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
  87. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  88. phoenix/server/api/mutations/__init__.py +4 -0
  89. phoenix/server/api/mutations/chat_mutations.py +374 -0
  90. phoenix/server/api/mutations/dataset_mutations.py +4 -4
  91. phoenix/server/api/mutations/experiment_mutations.py +1 -2
  92. phoenix/server/api/mutations/export_events_mutations.py +7 -7
  93. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  94. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  95. phoenix/server/api/mutations/user_mutations.py +4 -4
  96. phoenix/server/api/openapi/schema.py +2 -2
  97. phoenix/server/api/queries.py +61 -72
  98. phoenix/server/api/routers/oauth2.py +4 -4
  99. phoenix/server/api/routers/v1/datasets.py +22 -36
  100. phoenix/server/api/routers/v1/evaluations.py +6 -5
  101. phoenix/server/api/routers/v1/experiment_evaluations.py +2 -2
  102. phoenix/server/api/routers/v1/experiment_runs.py +2 -2
  103. phoenix/server/api/routers/v1/experiments.py +4 -4
  104. phoenix/server/api/routers/v1/spans.py +13 -12
  105. phoenix/server/api/routers/v1/traces.py +5 -5
  106. phoenix/server/api/routers/v1/utils.py +5 -5
  107. phoenix/server/api/schema.py +42 -10
  108. phoenix/server/api/subscriptions.py +347 -494
  109. phoenix/server/api/types/AnnotationSummary.py +3 -3
  110. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +44 -0
  111. phoenix/server/api/types/Cluster.py +8 -7
  112. phoenix/server/api/types/Dataset.py +5 -4
  113. phoenix/server/api/types/Dimension.py +3 -3
  114. phoenix/server/api/types/DocumentEvaluationSummary.py +8 -7
  115. phoenix/server/api/types/EmbeddingDimension.py +6 -5
  116. phoenix/server/api/types/EvaluationSummary.py +3 -3
  117. phoenix/server/api/types/Event.py +7 -7
  118. phoenix/server/api/types/Experiment.py +3 -3
  119. phoenix/server/api/types/ExperimentComparison.py +2 -4
  120. phoenix/server/api/types/GenerativeProvider.py +27 -3
  121. phoenix/server/api/types/Inferences.py +9 -8
  122. phoenix/server/api/types/InferencesRole.py +2 -2
  123. phoenix/server/api/types/Model.py +2 -2
  124. phoenix/server/api/types/Project.py +11 -18
  125. phoenix/server/api/types/Segments.py +3 -3
  126. phoenix/server/api/types/Span.py +45 -7
  127. phoenix/server/api/types/TemplateLanguage.py +9 -0
  128. phoenix/server/api/types/TimeSeries.py +8 -7
  129. phoenix/server/api/types/Trace.py +2 -2
  130. phoenix/server/api/types/UMAPPoints.py +6 -6
  131. phoenix/server/api/types/User.py +3 -3
  132. phoenix/server/api/types/node.py +1 -3
  133. phoenix/server/api/types/pagination.py +4 -4
  134. phoenix/server/api/utils.py +2 -4
  135. phoenix/server/app.py +76 -37
  136. phoenix/server/bearer_auth.py +4 -10
  137. phoenix/server/dml_event.py +3 -3
  138. phoenix/server/dml_event_handler.py +10 -24
  139. phoenix/server/grpc_server.py +3 -2
  140. phoenix/server/jwt_store.py +22 -21
  141. phoenix/server/main.py +17 -4
  142. phoenix/server/oauth2.py +3 -2
  143. phoenix/server/rate_limiters.py +5 -8
  144. phoenix/server/static/.vite/manifest.json +31 -31
  145. phoenix/server/static/assets/components-Csu8UKOs.js +1612 -0
  146. phoenix/server/static/assets/{index-DCzakdJq.js → index-Bk5C9EA7.js} +2 -2
  147. phoenix/server/static/assets/{pages-CAL1FDMt.js → pages-UeWaKXNs.js} +337 -442
  148. phoenix/server/static/assets/{vendor-6IcPAw_j.js → vendor-CtqfhlbC.js} +6 -6
  149. phoenix/server/static/assets/{vendor-arizeai-DRZuoyuF.js → vendor-arizeai-C_3SBz56.js} +2 -2
  150. phoenix/server/static/assets/{vendor-codemirror-DVE2_WBr.js → vendor-codemirror-wfdk9cjp.js} +1 -1
  151. phoenix/server/static/assets/{vendor-recharts-DwrexFA4.js → vendor-recharts-BiVnSv90.js} +1 -1
  152. phoenix/server/templates/index.html +1 -0
  153. phoenix/server/thread_server.py +1 -1
  154. phoenix/server/types.py +17 -29
  155. phoenix/services.py +8 -3
  156. phoenix/session/client.py +12 -24
  157. phoenix/session/data_extractor.py +3 -3
  158. phoenix/session/evaluation.py +1 -2
  159. phoenix/session/session.py +26 -21
  160. phoenix/trace/attributes.py +16 -28
  161. phoenix/trace/dsl/filter.py +17 -21
  162. phoenix/trace/dsl/helpers.py +3 -3
  163. phoenix/trace/dsl/query.py +13 -22
  164. phoenix/trace/fixtures.py +11 -17
  165. phoenix/trace/otel.py +5 -15
  166. phoenix/trace/projects.py +3 -2
  167. phoenix/trace/schemas.py +2 -2
  168. phoenix/trace/span_evaluations.py +9 -8
  169. phoenix/trace/span_json_decoder.py +3 -3
  170. phoenix/trace/span_json_encoder.py +2 -2
  171. phoenix/trace/trace_dataset.py +6 -5
  172. phoenix/trace/utils.py +6 -6
  173. phoenix/utilities/deprecation.py +3 -2
  174. phoenix/utilities/error_handling.py +3 -2
  175. phoenix/utilities/json.py +2 -1
  176. phoenix/utilities/logging.py +2 -2
  177. phoenix/utilities/project.py +1 -1
  178. phoenix/utilities/re.py +3 -4
  179. phoenix/utilities/template_formatters.py +16 -5
  180. phoenix/version.py +1 -1
  181. arize_phoenix-5.5.2.dist-info/RECORD +0 -321
  182. phoenix/server/static/assets/components-hX0LgYz3.js +0 -1428
  183. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/WHEEL +0 -0
  184. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/entry_points.txt +0 -0
  185. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  186. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,4 +1,4 @@
1
- from typing import List, Sequence
1
+ from collections.abc import Sequence
2
2
 
3
3
  import strawberry
4
4
  from sqlalchemy import delete, insert, update
@@ -19,7 +19,7 @@ from phoenix.server.dml_event import SpanAnnotationDeleteEvent, SpanAnnotationIn
19
19
 
20
20
  @strawberry.type
21
21
  class SpanAnnotationMutationPayload:
22
- span_annotations: List[SpanAnnotation]
22
+ span_annotations: list[SpanAnnotation]
23
23
  query: Query
24
24
 
25
25
 
@@ -27,7 +27,7 @@ class SpanAnnotationMutationPayload:
27
27
  class SpanAnnotationMutationMixin:
28
28
  @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
29
29
  async def create_span_annotations(
30
- self, info: Info[Context, None], input: List[CreateSpanAnnotationInput]
30
+ self, info: Info[Context, None], input: list[CreateSpanAnnotationInput]
31
31
  ) -> SpanAnnotationMutationPayload:
32
32
  inserted_annotations: Sequence[models.SpanAnnotation] = []
33
33
  async with info.context.db() as session:
@@ -61,7 +61,7 @@ class SpanAnnotationMutationMixin:
61
61
 
62
62
  @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
63
63
  async def patch_span_annotations(
64
- self, info: Info[Context, None], input: List[PatchAnnotationInput]
64
+ self, info: Info[Context, None], input: list[PatchAnnotationInput]
65
65
  ) -> SpanAnnotationMutationPayload:
66
66
  patched_annotations = []
67
67
  async with info.context.db() as session:
@@ -1,4 +1,4 @@
1
- from typing import List, Sequence
1
+ from collections.abc import Sequence
2
2
 
3
3
  import strawberry
4
4
  from sqlalchemy import delete, insert, update
@@ -19,7 +19,7 @@ from phoenix.server.dml_event import TraceAnnotationDeleteEvent, TraceAnnotation
19
19
 
20
20
  @strawberry.type
21
21
  class TraceAnnotationMutationPayload:
22
- trace_annotations: List[TraceAnnotation]
22
+ trace_annotations: list[TraceAnnotation]
23
23
  query: Query
24
24
 
25
25
 
@@ -27,7 +27,7 @@ class TraceAnnotationMutationPayload:
27
27
  class TraceAnnotationMutationMixin:
28
28
  @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
29
29
  async def create_trace_annotations(
30
- self, info: Info[Context, None], input: List[CreateTraceAnnotationInput]
30
+ self, info: Info[Context, None], input: list[CreateTraceAnnotationInput]
31
31
  ) -> TraceAnnotationMutationPayload:
32
32
  inserted_annotations: Sequence[models.TraceAnnotation] = []
33
33
  async with info.context.db() as session:
@@ -61,7 +61,7 @@ class TraceAnnotationMutationMixin:
61
61
 
62
62
  @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
63
63
  async def patch_trace_annotations(
64
- self, info: Info[Context, None], input: List[PatchAnnotationInput]
64
+ self, info: Info[Context, None], input: list[PatchAnnotationInput]
65
65
  ) -> TraceAnnotationMutationPayload:
66
66
  patched_annotations = []
67
67
  async with info.context.db() as session:
@@ -1,7 +1,7 @@
1
1
  import secrets
2
2
  from contextlib import AsyncExitStack
3
3
  from datetime import datetime, timezone
4
- from typing import List, Literal, Optional, Tuple
4
+ from typing import Literal, Optional
5
5
 
6
6
  import strawberry
7
7
  from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select
@@ -71,7 +71,7 @@ class PatchUserInput:
71
71
 
72
72
  @strawberry.input
73
73
  class DeleteUsersInput:
74
- user_ids: List[GlobalID]
74
+ user_ids: list[GlobalID]
75
75
 
76
76
 
77
77
  @strawberry.type
@@ -302,11 +302,11 @@ class UserMutationMixin:
302
302
  )
303
303
 
304
304
 
305
- def _select_role_id_by_name(role_name: str) -> Select[Tuple[int]]:
305
+ def _select_role_id_by_name(role_name: str) -> Select[tuple[int]]:
306
306
  return select(models.UserRole.id).where(models.UserRole.name == role_name)
307
307
 
308
308
 
309
- def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]:
309
+ def _select_user_by_id(user_id: int) -> Select[tuple[models.User]]:
310
310
  return (
311
311
  select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
312
312
  )
@@ -1,11 +1,11 @@
1
- from typing import Any, Dict
1
+ from typing import Any
2
2
 
3
3
  from fastapi.openapi.utils import get_openapi
4
4
 
5
5
  from phoenix.server.api.routers.v1 import REST_API_VERSION, create_v1_router
6
6
 
7
7
 
8
- def get_openapi_schema() -> Dict[str, Any]:
8
+ def get_openapi_schema() -> dict[str, Any]:
9
9
  v1_router = create_v1_router(authentication_enabled=False)
10
10
  return get_openapi(
11
11
  title="Arize-Phoenix REST API",
@@ -1,6 +1,6 @@
1
1
  from collections import defaultdict
2
2
  from datetime import datetime
3
- from typing import DefaultDict, Dict, List, Optional, Set, Union
3
+ from typing import Optional, Union
4
4
 
5
5
  import numpy as np
6
6
  import numpy.typing as npt
@@ -37,12 +37,17 @@ from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
37
37
  from phoenix.server.api.context import Context
38
38
  from phoenix.server.api.exceptions import NotFound, Unauthorized
39
39
  from phoenix.server.api.helpers import ensure_list
40
+ from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
41
+ from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
40
42
  from phoenix.server.api.input_types.ClusterInput import ClusterInput
41
43
  from phoenix.server.api.input_types.Coordinates import (
42
44
  InputCoordinate2D,
43
45
  InputCoordinate3D,
44
46
  )
45
47
  from phoenix.server.api.input_types.DatasetSort import DatasetSort
48
+ from phoenix.server.api.input_types.InvocationParameters import (
49
+ InvocationParameter,
50
+ )
46
51
  from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
47
52
  from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
48
53
  from phoenix.server.api.types.DatasetExample import DatasetExample
@@ -80,78 +85,62 @@ from phoenix.server.api.types.User import User, to_gql_user
80
85
  from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
81
86
  from phoenix.server.api.types.UserRole import UserRole
82
87
 
88
+ initialize_playground_clients()
89
+
83
90
 
84
91
  @strawberry.input
85
92
  class ModelsInput:
86
93
  provider_key: Optional[GenerativeProviderKey]
94
+ model_name: Optional[str] = None
87
95
 
88
96
 
89
97
  @strawberry.type
90
98
  class Query:
91
99
  @strawberry.field
92
- async def model_providers(self) -> List[GenerativeProvider]:
100
+ async def model_providers(self) -> list[GenerativeProvider]:
101
+ available_providers = PLAYGROUND_CLIENT_REGISTRY.list_all_providers()
93
102
  return [
94
103
  GenerativeProvider(
95
- name="OpenAI",
96
- key=GenerativeProviderKey.OPENAI,
97
- ),
98
- GenerativeProvider(
99
- name="Azure OpenAI",
100
- key=GenerativeProviderKey.AZURE_OPENAI,
101
- ),
102
- GenerativeProvider(
103
- name="Anthropic",
104
- key=GenerativeProviderKey.ANTHROPIC,
105
- ),
104
+ name=provider_key.value,
105
+ key=provider_key,
106
+ )
107
+ for provider_key in available_providers
106
108
  ]
107
109
 
108
110
  @strawberry.field
109
- async def models(self, input: Optional[ModelsInput] = None) -> List[GenerativeModel]:
110
- openai_models = [
111
- "o1-preview",
112
- "o1-preview-2024-09-12",
113
- "o1-mini",
114
- "o1-mini-2024-09-12",
115
- "gpt-4o",
116
- "gpt-4o-2024-08-06",
117
- "gpt-4o-2024-05-13",
118
- "chatgpt-4o-latest",
119
- "gpt-4o-mini",
120
- "gpt-4o-mini-2024-07-18",
121
- "gpt-4-turbo",
122
- "gpt-4-turbo-2024-04-09",
123
- "gpt-4-turbo-preview",
124
- "gpt-4-0125-preview",
125
- "gpt-4-1106-preview",
126
- "gpt-4",
127
- "gpt-4-0613",
128
- "gpt-3.5-turbo-0125",
129
- "gpt-3.5-turbo",
130
- "gpt-3.5-turbo-1106",
131
- "gpt-3.5-turbo-instruct",
132
- ]
133
- anthropic_models = [
134
- "claude-3-5-sonnet-20240620",
135
- "claude-3-opus-20240229",
136
- "claude-3-sonnet-20240229",
137
- "claude-3-haiku-20240307",
138
- ]
139
- openai_generative_models = [
140
- GenerativeModel(name=model_name, provider_key=GenerativeProviderKey.OPENAI)
141
- for model_name in openai_models
142
- ]
143
- anthropic_generative_models = [
144
- GenerativeModel(name=model_name, provider_key=GenerativeProviderKey.ANTHROPIC)
145
- for model_name in anthropic_models
146
- ]
147
-
148
- all_models = openai_generative_models + anthropic_generative_models
149
-
111
+ async def models(self, input: Optional[ModelsInput] = None) -> list[GenerativeModel]:
150
112
  if input is not None and input.provider_key is not None:
151
- return [model for model in all_models if model.provider_key == input.provider_key]
113
+ supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
114
+ supported_models = [
115
+ GenerativeModel(name=model_name, provider_key=input.provider_key)
116
+ for model_name in supported_model_names
117
+ ]
118
+ return supported_models
152
119
 
120
+ registered_models = PLAYGROUND_CLIENT_REGISTRY.list_all_models()
121
+ all_models: list[GenerativeModel] = []
122
+ for provider_key, model_name in registered_models:
123
+ if model_name is not None and provider_key is not None:
124
+ all_models.append(GenerativeModel(name=model_name, provider_key=provider_key))
153
125
  return all_models
154
126
 
127
+ @strawberry.field
128
+ async def model_invocation_parameters(
129
+ self, input: Optional[ModelsInput] = None
130
+ ) -> list[InvocationParameter]:
131
+ if input is None:
132
+ return []
133
+ provider_key = input.provider_key
134
+ model_name = input.model_name
135
+ if provider_key is not None:
136
+ client = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, model_name)
137
+ if client is None:
138
+ return []
139
+ invocation_parameters = client.supported_invocation_parameters()
140
+ return invocation_parameters
141
+ else:
142
+ return []
143
+
155
144
  @strawberry.field(permission_classes=[IsAdmin]) # type: ignore
156
145
  async def users(
157
146
  self,
@@ -183,7 +172,7 @@ class Query:
183
172
  async def user_roles(
184
173
  self,
185
174
  info: Info[Context, None],
186
- ) -> List[UserRole]:
175
+ ) -> list[UserRole]:
187
176
  async with info.context.db() as session:
188
177
  roles = await session.scalars(
189
178
  select(models.UserRole).where(models.UserRole.name != enums.UserRole.SYSTEM.value)
@@ -197,7 +186,7 @@ class Query:
197
186
  ]
198
187
 
199
188
  @strawberry.field(permission_classes=[IsAdmin]) # type: ignore
200
- async def user_api_keys(self, info: Info[Context, None]) -> List[UserApiKey]:
189
+ async def user_api_keys(self, info: Info[Context, None]) -> list[UserApiKey]:
201
190
  stmt = (
202
191
  select(models.ApiKey)
203
192
  .join(models.User)
@@ -209,7 +198,7 @@ class Query:
209
198
  return [to_gql_api_key(api_key) for api_key in api_keys]
210
199
 
211
200
  @strawberry.field(permission_classes=[IsAdmin]) # type: ignore
212
- async def system_api_keys(self, info: Info[Context, None]) -> List[SystemApiKey]:
201
+ async def system_api_keys(self, info: Info[Context, None]) -> list[SystemApiKey]:
213
202
  stmt = (
214
203
  select(models.ApiKey)
215
204
  .join(models.User)
@@ -304,8 +293,8 @@ class Query:
304
293
  async def compare_experiments(
305
294
  self,
306
295
  info: Info[Context, None],
307
- experiment_ids: List[GlobalID],
308
- ) -> List[ExperimentComparison]:
296
+ experiment_ids: list[GlobalID],
297
+ ) -> list[ExperimentComparison]:
309
298
  experiment_ids_ = [
310
299
  from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
311
300
  for experiment_id in experiment_ids
@@ -369,7 +358,7 @@ class Query:
369
358
 
370
359
  ExampleID: TypeAlias = int
371
360
  ExperimentID: TypeAlias = int
372
- runs: DefaultDict[ExampleID, DefaultDict[ExperimentID, List[OrmRun]]] = defaultdict(
361
+ runs: defaultdict[ExampleID, defaultdict[ExperimentID, list[OrmRun]]] = defaultdict(
373
362
  lambda: defaultdict(list)
374
363
  )
375
364
  async for run in await session.stream_scalars(
@@ -576,9 +565,9 @@ class Query:
576
565
  @strawberry.field
577
566
  def clusters(
578
567
  self,
579
- clusters: List[ClusterInput],
580
- ) -> List[Cluster]:
581
- clustered_events: Dict[str, Set[ID]] = defaultdict(set)
568
+ clusters: list[ClusterInput],
569
+ ) -> list[Cluster]:
570
+ clustered_events: dict[str, set[ID]] = defaultdict(set)
582
571
  for i, cluster in enumerate(clusters):
583
572
  clustered_events[cluster.id or str(i)].update(cluster.event_ids)
584
573
  return to_gql_clusters(
@@ -590,19 +579,19 @@ class Query:
590
579
  self,
591
580
  info: Info[Context, None],
592
581
  event_ids: Annotated[
593
- List[ID],
582
+ list[ID],
594
583
  strawberry.argument(
595
584
  description="Event ID of the coordinates",
596
585
  ),
597
586
  ],
598
587
  coordinates_2d: Annotated[
599
- Optional[List[InputCoordinate2D]],
588
+ Optional[list[InputCoordinate2D]],
600
589
  strawberry.argument(
601
590
  description="Point coordinates. Must be either 2D or 3D.",
602
591
  ),
603
592
  ] = UNSET,
604
593
  coordinates_3d: Annotated[
605
- Optional[List[InputCoordinate3D]],
594
+ Optional[list[InputCoordinate3D]],
606
595
  strawberry.argument(
607
596
  description="Point coordinates. Must be either 2D or 3D.",
608
597
  ),
@@ -625,7 +614,7 @@ class Query:
625
614
  description="HDBSCAN cluster selection epsilon",
626
615
  ),
627
616
  ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
628
- ) -> List[Cluster]:
617
+ ) -> list[Cluster]:
629
618
  coordinates_3d = ensure_list(coordinates_3d)
630
619
  coordinates_2d = ensure_list(coordinates_2d)
631
620
 
@@ -661,13 +650,13 @@ class Query:
661
650
  if len(event_ids) == 0:
662
651
  return []
663
652
 
664
- grouped_event_ids: Dict[
653
+ grouped_event_ids: dict[
665
654
  Union[InferencesRole, AncillaryInferencesRole],
666
- List[ID],
655
+ list[ID],
667
656
  ] = defaultdict(list)
668
- grouped_coordinates: Dict[
657
+ grouped_coordinates: dict[
669
658
  Union[InferencesRole, AncillaryInferencesRole],
670
- List[npt.NDArray[np.float64]],
659
+ list[npt.NDArray[np.float64]],
671
660
  ] = defaultdict(list)
672
661
 
673
662
  for event_id, coordinate in zip(event_ids, coordinates):
@@ -2,7 +2,7 @@ import re
2
2
  from dataclasses import dataclass
3
3
  from datetime import timedelta
4
4
  from random import randrange
5
- from typing import Any, Dict, Optional, Tuple, TypedDict
5
+ from typing import Any, Optional, TypedDict
6
6
  from urllib.parse import unquote, urlparse
7
7
 
8
8
  from authlib.common.security import generate_token
@@ -192,7 +192,7 @@ class UserInfo:
192
192
  profile_picture_url: Optional[str]
193
193
 
194
194
 
195
- def _validate_token_data(token_data: Dict[str, Any]) -> None:
195
+ def _validate_token_data(token_data: dict[str, Any]) -> None:
196
196
  """
197
197
  Performs basic validations on the token data returned by the IDP.
198
198
  """
@@ -201,7 +201,7 @@ def _validate_token_data(token_data: Dict[str, Any]) -> None:
201
201
  assert token_type.lower() == "bearer"
202
202
 
203
203
 
204
- def _parse_user_info(user_info: Dict[str, Any]) -> UserInfo:
204
+ def _parse_user_info(user_info: dict[str, Any]) -> UserInfo:
205
205
  """
206
206
  Parses user info from the IDP's ID token.
207
207
  """
@@ -321,7 +321,7 @@ async def _update_user_email(session: AsyncSession, /, *, user_id: int, email: s
321
321
 
322
322
  async def _email_and_username_exist(
323
323
  session: AsyncSession, /, *, email: str, username: Optional[str]
324
- ) -> Tuple[bool, bool]:
324
+ ) -> tuple[bool, bool]:
325
325
  """
326
326
  Checks whether the email and username are already in use.
327
327
  """
@@ -6,25 +6,11 @@ import logging
6
6
  import zlib
7
7
  from asyncio import QueueFull
8
8
  from collections import Counter
9
+ from collections.abc import Awaitable, Callable, Coroutine, Iterator, Mapping, Sequence
9
10
  from datetime import datetime
10
11
  from enum import Enum
11
12
  from functools import partial
12
- from typing import (
13
- Any,
14
- Awaitable,
15
- Callable,
16
- Coroutine,
17
- Dict,
18
- FrozenSet,
19
- Iterator,
20
- List,
21
- Mapping,
22
- Optional,
23
- Sequence,
24
- Tuple,
25
- Union,
26
- cast,
27
- )
13
+ from typing import Any, Optional, Union, cast
28
14
 
29
15
  import pandas as pd
30
16
  import pyarrow as pa
@@ -83,7 +69,7 @@ class Dataset(V1RoutesBaseModel):
83
69
  id: str
84
70
  name: str
85
71
  description: Optional[str]
86
- metadata: Dict[str, Any]
72
+ metadata: dict[str, Any]
87
73
  created_at: datetime
88
74
  updated_at: datetime
89
75
 
@@ -246,7 +232,7 @@ async def get_dataset(
246
232
  class DatasetVersion(V1RoutesBaseModel):
247
233
  version_id: str
248
234
  description: Optional[str]
249
- metadata: Dict[str, Any]
235
+ metadata: dict[str, Any]
250
236
  created_at: datetime
251
237
 
252
238
 
@@ -522,16 +508,16 @@ class FileContentEncoding(Enum):
522
508
 
523
509
  Name: TypeAlias = str
524
510
  Description: TypeAlias = Optional[str]
525
- InputKeys: TypeAlias = FrozenSet[str]
526
- OutputKeys: TypeAlias = FrozenSet[str]
527
- MetadataKeys: TypeAlias = FrozenSet[str]
511
+ InputKeys: TypeAlias = frozenset[str]
512
+ OutputKeys: TypeAlias = frozenset[str]
513
+ MetadataKeys: TypeAlias = frozenset[str]
528
514
  DatasetId: TypeAlias = int
529
515
  Examples: TypeAlias = Iterator[ExampleContent]
530
516
 
531
517
 
532
518
  def _process_json(
533
519
  data: Mapping[str, Any],
534
- ) -> Tuple[Examples, DatasetAction, Name, Description]:
520
+ ) -> tuple[Examples, DatasetAction, Name, Description]:
535
521
  name = data.get("name")
536
522
  if not name:
537
523
  raise ValueError("Dataset name is required")
@@ -547,7 +533,7 @@ def _process_json(
547
533
  raise ValueError(
548
534
  f"{k} should be a list of same length as input containing only dictionary objects"
549
535
  )
550
- examples: List[ExampleContent] = []
536
+ examples: list[ExampleContent] = []
551
537
  for i, obj in enumerate(inputs):
552
538
  example = ExampleContent(
553
539
  input=obj,
@@ -623,7 +609,7 @@ async def _check_table_exists(session: AsyncSession, name: str) -> bool:
623
609
 
624
610
 
625
611
  def _check_keys_exist(
626
- column_headers: FrozenSet[str],
612
+ column_headers: frozenset[str],
627
613
  input_keys: InputKeys,
628
614
  output_keys: OutputKeys,
629
615
  metadata_keys: MetadataKeys,
@@ -639,7 +625,7 @@ def _check_keys_exist(
639
625
 
640
626
  async def _parse_form_data(
641
627
  form: FormData,
642
- ) -> Tuple[
628
+ ) -> tuple[
643
629
  DatasetAction,
644
630
  Name,
645
631
  Description,
@@ -656,9 +642,9 @@ async def _parse_form_data(
656
642
  if not isinstance(file, UploadFile):
657
643
  raise ValueError("Malformed file in form data.")
658
644
  description = cast(Optional[str], form.get("description")) or file.filename
659
- input_keys = frozenset(filter(bool, cast(List[str], form.getlist("input_keys[]"))))
660
- output_keys = frozenset(filter(bool, cast(List[str], form.getlist("output_keys[]"))))
661
- metadata_keys = frozenset(filter(bool, cast(List[str], form.getlist("metadata_keys[]"))))
645
+ input_keys = frozenset(filter(bool, cast(list[str], form.getlist("input_keys[]"))))
646
+ output_keys = frozenset(filter(bool, cast(list[str], form.getlist("output_keys[]"))))
647
+ metadata_keys = frozenset(filter(bool, cast(list[str], form.getlist("metadata_keys[]"))))
662
648
  return (
663
649
  action,
664
650
  name,
@@ -672,16 +658,16 @@ async def _parse_form_data(
672
658
 
673
659
  class DatasetExample(V1RoutesBaseModel):
674
660
  id: str
675
- input: Dict[str, Any]
676
- output: Dict[str, Any]
677
- metadata: Dict[str, Any]
661
+ input: dict[str, Any]
662
+ output: dict[str, Any]
663
+ metadata: dict[str, Any]
678
664
  updated_at: datetime
679
665
 
680
666
 
681
667
  class ListDatasetExamplesData(V1RoutesBaseModel):
682
668
  dataset_id: str
683
669
  version_id: str
684
- examples: List[DatasetExample]
670
+ examples: list[DatasetExample]
685
671
 
686
672
 
687
673
  class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]):
@@ -914,7 +900,7 @@ async def get_dataset_jsonl_openai_evals(
914
900
  return content
915
901
 
916
902
 
917
- def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
903
+ def _get_content_csv(examples: list[models.DatasetExampleRevision]) -> bytes:
918
904
  records = [
919
905
  {
920
906
  "example_id": GlobalID(
@@ -930,7 +916,7 @@ def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
930
916
  return str(pd.DataFrame.from_records(records).to_csv(index=False)).encode()
931
917
 
932
918
 
933
- def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision]) -> bytes:
919
+ def _get_content_jsonl_openai_ft(examples: list[models.DatasetExampleRevision]) -> bytes:
934
920
  records = io.BytesIO()
935
921
  for ex in examples:
936
922
  records.write(
@@ -951,7 +937,7 @@ def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision])
951
937
  return records.read()
952
938
 
953
939
 
954
- def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision]) -> bytes:
940
+ def _get_content_jsonl_openai_evals(examples: list[models.DatasetExampleRevision]) -> bytes:
955
941
  records = io.BytesIO()
956
942
  for ex in examples:
957
943
  records.write(
@@ -980,7 +966,7 @@ def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision
980
966
 
981
967
  async def _get_db_examples(
982
968
  *, session: Any, id: str, version_id: Optional[str]
983
- ) -> Tuple[str, List[models.DatasetExampleRevision]]:
969
+ ) -> tuple[str, list[models.DatasetExampleRevision]]:
984
970
  dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
985
971
  dataset_version_id: Optional[int] = None
986
972
  if version_id:
@@ -1,6 +1,7 @@
1
1
  import gzip
2
+ from collections.abc import Callable
2
3
  from itertools import chain
3
- from typing import Any, Callable, Iterator, Optional, Tuple, Union, cast
4
+ from typing import Any, Iterator, Optional, Union, cast
4
5
 
5
6
  import pandas as pd
6
7
  import pyarrow as pa
@@ -208,7 +209,7 @@ async def _add_evaluations(state: State, evaluations: Evaluations) -> None:
208
209
  )
209
210
  for index, row in dataframe.iterrows():
210
211
  score, label, explanation = _get_annotation_result(row)
211
- document_annotation = cls(cast(Union[Tuple[str, int], Tuple[int, str]], index))(
212
+ document_annotation = cls(cast(Union[tuple[str, int], tuple[int, str]], index))(
212
213
  name=eval_name,
213
214
  annotator_kind="LLM",
214
215
  score=score,
@@ -245,7 +246,7 @@ async def _add_evaluations(state: State, evaluations: Evaluations) -> None:
245
246
 
246
247
  def _get_annotation_result(
247
248
  row: "pd.Series[Any]",
248
- ) -> Tuple[Optional[float], Optional[str], Optional[str]]:
249
+ ) -> tuple[Optional[float], Optional[str], Optional[str]]:
249
250
  return (
250
251
  cast(Optional[float], row.get("score")),
251
252
  cast(Optional[str], row.get("label")),
@@ -257,7 +258,7 @@ def _document_annotation_factory(
257
258
  span_id_idx: int,
258
259
  document_position_idx: int,
259
260
  ) -> Callable[
260
- [Union[Tuple[str, int], Tuple[int, str]]],
261
+ [Union[tuple[str, int], tuple[int, str]]],
261
262
  Callable[..., Precursors.DocumentAnnotation],
262
263
  ]:
263
264
  return lambda index: lambda **kwargs: Precursors.DocumentAnnotation(
@@ -356,6 +357,6 @@ def _read_sql_document_evaluations_into_dataframe(
356
357
 
357
358
  def _groupby_eval_name(
358
359
  evals_dataframe: DataFrame,
359
- ) -> Iterator[Tuple[EvaluationName, DataFrame]]:
360
+ ) -> Iterator[tuple[EvaluationName, DataFrame]]:
360
361
  for eval_name, evals_dataframe_for_name in evals_dataframe.groupby("name", as_index=False):
361
362
  yield str(eval_name), evals_dataframe_for_name
@@ -1,5 +1,5 @@
1
1
  from datetime import datetime
2
- from typing import Any, Dict, Literal, Optional
2
+ from typing import Any, Literal, Optional
3
3
 
4
4
  from fastapi import APIRouter, HTTPException
5
5
  from pydantic import Field
@@ -39,7 +39,7 @@ class UpsertExperimentEvaluationRequestBody(V1RoutesBaseModel):
39
39
  error: Optional[str] = Field(
40
40
  None, description="Optional error message if the evaluation encountered an error"
41
41
  )
42
- metadata: Optional[Dict[str, Any]] = Field(
42
+ metadata: Optional[dict[str, Any]] = Field(
43
43
  default=None, description="Metadata for the evaluation"
44
44
  )
45
45
  trace_id: Optional[str] = Field(default=None, description="Optional trace ID for tracking")
@@ -1,5 +1,5 @@
1
1
  from datetime import datetime
2
- from typing import Any, List, Optional
2
+ from typing import Any, Optional
3
3
 
4
4
  from fastapi import APIRouter, HTTPException
5
5
  from pydantic import Field
@@ -113,7 +113,7 @@ class ExperimentRunResponse(ExperimentRun):
113
113
  experiment_id: str = Field(description="The ID of the experiment")
114
114
 
115
115
 
116
- class ListExperimentRunsResponseBody(ResponseBody[List[ExperimentRunResponse]]):
116
+ class ListExperimentRunsResponseBody(ResponseBody[list[ExperimentRunResponse]]):
117
117
  pass
118
118
 
119
119
 
@@ -1,6 +1,6 @@
1
1
  from datetime import datetime
2
2
  from random import getrandbits
3
- from typing import Any, Dict, List, Optional
3
+ from typing import Any, Optional
4
4
 
5
5
  from fastapi import APIRouter, HTTPException, Path
6
6
  from pydantic import Field
@@ -40,7 +40,7 @@ class Experiment(V1RoutesBaseModel):
40
40
  description="The ID of the dataset version associated with the experiment"
41
41
  )
42
42
  repetitions: int = Field(description="Number of times the experiment is repeated")
43
- metadata: Dict[str, Any] = Field(description="Metadata of the experiment")
43
+ metadata: dict[str, Any] = Field(description="Metadata of the experiment")
44
44
  project_name: Optional[str] = Field(
45
45
  description="The name of the project associated with the experiment"
46
46
  )
@@ -60,7 +60,7 @@ class CreateExperimentRequestBody(V1RoutesBaseModel):
60
60
  description: Optional[str] = Field(
61
61
  default=None, description="An optional description of the experiment"
62
62
  )
63
- metadata: Optional[Dict[str, Any]] = Field(
63
+ metadata: Optional[dict[str, Any]] = Field(
64
64
  default=None, description="Metadata for the experiment"
65
65
  )
66
66
  version_id: Optional[str] = Field(
@@ -254,7 +254,7 @@ async def get_experiment(request: Request, experiment_id: str) -> GetExperimentR
254
254
  )
255
255
 
256
256
 
257
- class ListExperimentsResponseBody(ResponseBody[List[Experiment]]):
257
+ class ListExperimentsResponseBody(ResponseBody[list[Experiment]]):
258
258
  pass
259
259
 
260
260