arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221) hide show
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
phoenix/server/app.py CHANGED
@@ -4,7 +4,6 @@ import importlib
4
4
  import json
5
5
  import logging
6
6
  import os
7
- from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
8
7
  from contextlib import AbstractAsyncContextManager, AsyncExitStack
9
8
  from dataclasses import dataclass, field
10
9
  from datetime import datetime, timedelta, timezone
@@ -14,9 +13,14 @@ from types import MethodType
14
13
  from typing import (
15
14
  TYPE_CHECKING,
16
15
  Any,
16
+ AsyncIterator,
17
+ Awaitable,
18
+ Callable,
19
+ Iterable,
17
20
  NamedTuple,
18
21
  Optional,
19
22
  Protocol,
23
+ Sequence,
20
24
  TypedDict,
21
25
  Union,
22
26
  cast,
@@ -41,7 +45,6 @@ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoin
41
45
  from starlette.requests import Request
42
46
  from starlette.responses import JSONResponse, PlainTextResponse, RedirectResponse, Response
43
47
  from starlette.staticfiles import StaticFiles
44
- from starlette.status import HTTP_401_UNAUTHORIZED
45
48
  from starlette.templating import Jinja2Templates
46
49
  from starlette.types import Scope, StatefulLifespan
47
50
  from strawberry.extensions import SchemaExtension
@@ -63,7 +66,7 @@ from phoenix.config import (
63
66
  get_env_gql_extension_paths,
64
67
  get_env_grpc_interceptor_paths,
65
68
  get_env_host,
66
- get_env_host_root_path,
69
+ get_env_max_spans_queue_size,
67
70
  get_env_port,
68
71
  get_env_support_email,
69
72
  server_instrumentation_is_enabled,
@@ -77,21 +80,30 @@ from phoenix.db.facilitator import Facilitator
77
80
  from phoenix.db.helpers import SupportedSQLDialect
78
81
  from phoenix.exceptions import PhoenixMigrationError
79
82
  from phoenix.pointcloud.umap_parameters import UMAPParameters
83
+ from phoenix.server.api.auth_messages import AUTH_ERROR_MESSAGES, AuthErrorCode
80
84
  from phoenix.server.api.context import Context, DataLoaders
81
85
  from phoenix.server.api.dataloaders import (
82
86
  AnnotationConfigsByProjectDataLoader,
83
87
  AnnotationSummaryDataLoader,
88
+ AverageExperimentRepeatedRunGroupLatencyDataLoader,
84
89
  AverageExperimentRunLatencyDataLoader,
85
90
  CacheForDataLoaders,
91
+ DatasetDatasetSplitsDataLoader,
86
92
  DatasetExampleRevisionsDataLoader,
93
+ DatasetExamplesAndVersionsByExperimentRunDataLoader,
87
94
  DatasetExampleSpansDataLoader,
95
+ DatasetExampleSplitsDataLoader,
88
96
  DocumentEvaluationsDataLoader,
89
97
  DocumentEvaluationSummaryDataLoader,
90
98
  DocumentRetrievalMetricsDataLoader,
91
99
  ExperimentAnnotationSummaryDataLoader,
100
+ ExperimentDatasetSplitsDataLoader,
92
101
  ExperimentErrorRatesDataLoader,
102
+ ExperimentRepeatedRunGroupAnnotationSummariesDataLoader,
103
+ ExperimentRepeatedRunGroupsDataLoader,
93
104
  ExperimentRunAnnotations,
94
105
  ExperimentRunCountsDataLoader,
106
+ ExperimentRunsByExperimentAndExampleDataLoader,
95
107
  ExperimentSequenceNumberDataLoader,
96
108
  LastUsedTimesByGenerativeModelIdDataLoader,
97
109
  LatencyMsQuantileDataLoader,
@@ -102,6 +114,7 @@ from phoenix.server.api.dataloaders import (
102
114
  ProjectIdsByTraceRetentionPolicyIdDataLoader,
103
115
  PromptVersionSequenceNumberDataLoader,
104
116
  RecordCountDataLoader,
117
+ SessionAnnotationsBySessionDataLoader,
105
118
  SessionIODataLoader,
106
119
  SessionNumTracesDataLoader,
107
120
  SessionNumTracesWithErrorDataLoader,
@@ -116,6 +129,7 @@ from phoenix.server.api.dataloaders import (
116
129
  SpanCostDetailSummaryEntriesBySpanDataLoader,
117
130
  SpanCostDetailSummaryEntriesByTraceDataLoader,
118
131
  SpanCostSummaryByExperimentDataLoader,
132
+ SpanCostSummaryByExperimentRepeatedRunGroupDataLoader,
119
133
  SpanCostSummaryByExperimentRunDataLoader,
120
134
  SpanCostSummaryByGenerativeModelDataLoader,
121
135
  SpanCostSummaryByProjectDataLoader,
@@ -126,14 +140,17 @@ from phoenix.server.api.dataloaders import (
126
140
  SpanProjectsDataLoader,
127
141
  TableFieldsDataLoader,
128
142
  TokenCountDataLoader,
143
+ TokenPricesByModelDataLoader,
144
+ TraceAnnotationsByTraceDataLoader,
129
145
  TraceByTraceIdsDataLoader,
130
146
  TraceRetentionPolicyIdByProjectIdDataLoader,
131
147
  TraceRootSpansDataLoader,
132
148
  UserRolesDataLoader,
133
149
  UsersDataLoader,
134
150
  )
151
+ from phoenix.server.api.dataloaders.dataset_labels import DatasetLabelsDataLoader
135
152
  from phoenix.server.api.routers import (
136
- auth_router,
153
+ create_auth_router,
137
154
  create_embeddings_router,
138
155
  create_v1_router,
139
156
  oauth2_router,
@@ -151,6 +168,7 @@ from phoenix.server.grpc_server import GrpcServer
151
168
  from phoenix.server.jwt_store import JwtStore
152
169
  from phoenix.server.middleware.gzip import GZipMiddleware
153
170
  from phoenix.server.oauth2 import OAuth2Clients
171
+ from phoenix.server.prometheus import SPAN_QUEUE_REJECTIONS
154
172
  from phoenix.server.retention import TraceDataSweeper
155
173
  from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
156
174
  from phoenix.server.types import (
@@ -161,6 +179,7 @@ from phoenix.server.types import (
161
179
  LastUpdatedAt,
162
180
  TokenStore,
163
181
  )
182
+ from phoenix.server.utils import get_root_path, prepend_root_path
164
183
  from phoenix.settings import Settings
165
184
  from phoenix.trace.fixtures import (
166
185
  TracesFixture,
@@ -179,6 +198,8 @@ from phoenix.version import __version__ as phoenix_version
179
198
  if TYPE_CHECKING:
180
199
  from opentelemetry.trace import TracerProvider
181
200
 
201
+ from phoenix.config import LDAPConfig
202
+
182
203
  logger = logging.getLogger(__name__)
183
204
 
184
205
  router = APIRouter(include_in_schema=False)
@@ -235,11 +256,19 @@ class AppConfig(NamedTuple):
235
256
  web_manifest_path: Path
236
257
  authentication_enabled: bool
237
258
  """ Whether authentication is enabled """
259
+ auth_error_messages: dict[AuthErrorCode, str]
260
+ """ Mapping of auth error codes to user-friendly messages """
238
261
  oauth2_idps: Sequence[OAuth2Idp]
239
262
  basic_auth_disabled: bool = False
263
+ ldap_enabled: bool = False
264
+ """ Whether LDAP authentication is configured """
265
+ ldap_manual_user_creation_enabled: bool = False
266
+ """ Whether manual LDAP user creation is allowed (False when LDAP disabled or no email attr) """
240
267
  auto_login_idp_name: Optional[str] = None
241
268
  fullstory_org: Optional[str] = None
242
269
  """ FullStory organization ID for web analytics tracking """
270
+ scarf_sh_pixel_id: Optional[str] = None
271
+ """ Scarf.sh pixel ID for open-source analytics and usage """
243
272
  management_url: Optional[str] = None
244
273
  """ URL for a phoenix management interface, only visible to management users """
245
274
  support_email: Optional[str] = None
@@ -269,9 +298,6 @@ class Static(StaticFiles):
269
298
  return {}
270
299
  raise e
271
300
 
272
- def _sanitize_basename(self, basename: str) -> str:
273
- return basename[:-1] if basename.endswith("/") else basename
274
-
275
301
  async def get_response(self, path: str, scope: Scope) -> Response:
276
302
  # Redirect to the oauth2 login page if basic auth is disabled and auto_login is enabled
277
303
  # TODO: this needs to be refactored to be cleaner
@@ -280,14 +306,10 @@ class Static(StaticFiles):
280
306
  and self._app_config.basic_auth_disabled
281
307
  and self._app_config.auto_login_idp_name
282
308
  ):
283
- request = Request(scope)
284
- url = URL(
285
- str(
286
- Path(get_env_host_root_path())
287
- / f"oauth2/{self._app_config.auto_login_idp_name}/login"
288
- )
309
+ redirect_path = prepend_root_path(
310
+ scope, f"oauth2/{self._app_config.auto_login_idp_name}/login"
289
311
  )
290
- url = url.include_query_params(**request.query_params)
312
+ url = URL(redirect_path).include_query_params(**Request(scope).query_params)
291
313
  return RedirectResponse(url=url)
292
314
  try:
293
315
  response = await super().get_response(path, scope)
@@ -304,7 +326,7 @@ class Static(StaticFiles):
304
326
  "min_dist": self._app_config.min_dist,
305
327
  "n_neighbors": self._app_config.n_neighbors,
306
328
  "n_samples": self._app_config.n_samples,
307
- "basename": self._sanitize_basename(request.scope.get("root_path", "")),
329
+ "basename": get_root_path(scope),
308
330
  "platform_version": phoenix_version,
309
331
  "request": request,
310
332
  "is_development": self._app_config.is_development,
@@ -312,12 +334,16 @@ class Static(StaticFiles):
312
334
  "authentication_enabled": self._app_config.authentication_enabled,
313
335
  "oauth2_idps": self._app_config.oauth2_idps,
314
336
  "basic_auth_disabled": self._app_config.basic_auth_disabled,
337
+ "ldap_enabled": self._app_config.ldap_enabled,
338
+ "ldap_manual_user_creation_enabled": self._app_config.ldap_manual_user_creation_enabled, # noqa: E501
315
339
  "auto_login_idp_name": self._app_config.auto_login_idp_name,
316
340
  "fullstory_org": self._app_config.fullstory_org,
341
+ "scarf_sh_pixel_id": self._app_config.scarf_sh_pixel_id,
317
342
  "management_url": self._app_config.management_url,
318
343
  "support_email": self._app_config.support_email,
319
344
  "has_db_threshold": self._app_config.has_db_threshold,
320
345
  "allow_external_resources": self._app_config.allow_external_resources,
346
+ "auth_error_messages": self._app_config.auth_error_messages,
321
347
  },
322
348
  )
323
349
  except Exception as e:
@@ -340,7 +366,7 @@ class RequestOriginHostnameValidator(BaseHTTPMiddleware):
340
366
  if not (url := headers.get(key)):
341
367
  continue
342
368
  if urlparse(url).hostname not in self._trusted_hostnames:
343
- return Response(f"untrusted {key}", status_code=HTTP_401_UNAUTHORIZED)
369
+ return Response(f"untrusted {key}", status_code=401)
344
370
  return await call_next(request)
345
371
 
346
372
 
@@ -427,13 +453,13 @@ class Scaffolder(DaemonTask):
427
453
  def __init__(
428
454
  self,
429
455
  config: ScaffolderConfig,
430
- queue_span: Callable[[Span, ProjectName], Awaitable[None]],
431
- queue_evaluation: Callable[[pb.Evaluation], Awaitable[None]],
456
+ enqueue_span: Callable[[Span, ProjectName], Awaitable[None]],
457
+ enqueue_evaluation: Callable[[pb.Evaluation], Awaitable[None]],
432
458
  ) -> None:
433
459
  super().__init__()
434
460
  self._db = config.db
435
- self._queue_span = queue_span
436
- self._queue_evaluation = queue_evaluation
461
+ self._enqueue_span = enqueue_span
462
+ self._enqueue_evaluation = enqueue_evaluation
437
463
  self._tracing_fixtures = [
438
464
  get_trace_fixture_by_name(name) for name in set(config.tracing_fixture_names)
439
465
  ]
@@ -504,9 +530,9 @@ class Scaffolder(DaemonTask):
504
530
  project_name = fixture.project_name or fixture.name
505
531
  logger.info(f"Loading '{project_name}' fixtures...")
506
532
  for span in fixture_spans:
507
- await self._queue_span(span, project_name)
533
+ await self._enqueue_span(span, project_name)
508
534
  for evaluation in fixture_evals:
509
- await self._queue_evaluation(evaluation)
535
+ await self._enqueue_evaluation(evaluation)
510
536
 
511
537
  except FileNotFoundError:
512
538
  logger.warning(f"Fixture file not found for '{fixture.name}'")
@@ -529,6 +555,32 @@ class Scaffolder(DaemonTask):
529
555
  logger.error(f"Error processing dataset fixture: {e}")
530
556
 
531
557
 
558
+ class _CapacityIndicator(Protocol):
559
+ @property
560
+ def is_full(self) -> bool: ...
561
+
562
+
563
+ class CapacityInterceptor(AsyncServerInterceptor):
564
+ def __init__(self, indicator: _CapacityIndicator):
565
+ self._indicator = indicator
566
+
567
+ @override
568
+ async def intercept(
569
+ self,
570
+ method: Callable[[Any, grpc.aio.ServicerContext], Awaitable[Any]],
571
+ request_or_iterator: Any,
572
+ context: grpc.aio.ServicerContext,
573
+ method_name: str,
574
+ ) -> Any:
575
+ if self._indicator.is_full:
576
+ SPAN_QUEUE_REJECTIONS.inc()
577
+ context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
578
+ context.set_details("Server is at capacity and cannot process more requests")
579
+ return
580
+
581
+ return await method(request_or_iterator, context)
582
+
583
+
532
584
  def _lifespan(
533
585
  *,
534
586
  db: DbSessionFactory,
@@ -555,18 +607,23 @@ def _lifespan(
555
607
  db.lock = asyncio.Lock() if db.dialect is SupportedSQLDialect.SQLITE else None
556
608
  async with AsyncExitStack() as stack:
557
609
  (
558
- enqueue,
559
- queue_span,
560
- queue_evaluation,
610
+ enqueue_annotations,
611
+ enqueue_span,
612
+ enqueue_evaluation,
561
613
  enqueue_operation,
562
614
  ) = await stack.enter_async_context(bulk_inserter)
615
+ interceptors = [
616
+ CapacityInterceptor(bulk_inserter),
617
+ *user_grpc_interceptors(),
618
+ *grpc_interceptors,
619
+ ]
563
620
  grpc_server = GrpcServer(
564
- queue_span,
621
+ enqueue_span,
565
622
  disabled=read_only,
566
623
  tracer_provider=tracer_provider,
567
624
  enable_prometheus=enable_prometheus,
568
625
  token_store=token_store,
569
- interceptors=user_grpc_interceptors() + list(grpc_interceptors),
626
+ interceptors=interceptors,
570
627
  )
571
628
  await stack.enter_async_context(grpc_server)
572
629
  await stack.enter_async_context(dml_event_handler)
@@ -578,17 +635,17 @@ def _lifespan(
578
635
  if scaffolder_config:
579
636
  scaffolder = Scaffolder(
580
637
  config=scaffolder_config,
581
- queue_span=queue_span,
582
- queue_evaluation=queue_evaluation,
638
+ enqueue_span=enqueue_span,
639
+ enqueue_evaluation=enqueue_evaluation,
583
640
  )
584
641
  await stack.enter_async_context(scaffolder)
585
642
  if isinstance(token_store, AbstractAsyncContextManager):
586
643
  await stack.enter_async_context(token_store)
587
644
  yield {
588
645
  "event_queue": dml_event_handler,
589
- "enqueue": enqueue,
590
- "queue_span_for_bulk_insert": queue_span,
591
- "queue_evaluation_for_bulk_insert": queue_evaluation,
646
+ "enqueue_annotations": enqueue_annotations,
647
+ "enqueue_span": enqueue_span,
648
+ "enqueue_evaluation": enqueue_evaluation,
592
649
  "enqueue_operation": enqueue_operation,
593
650
  }
594
651
  for callback in shutdown_callbacks:
@@ -663,9 +720,23 @@ def create_graphql_router(
663
720
  event_queue=event_queue,
664
721
  data_loaders=DataLoaders(
665
722
  annotation_configs_by_project=AnnotationConfigsByProjectDataLoader(db),
723
+ average_experiment_repeated_run_group_latency=AverageExperimentRepeatedRunGroupLatencyDataLoader(
724
+ db
725
+ ),
666
726
  average_experiment_run_latency=AverageExperimentRunLatencyDataLoader(db),
727
+ dataset_dataset_splits=DatasetDatasetSplitsDataLoader(db),
728
+ dataset_example_fields=TableFieldsDataLoader(db, models.DatasetExample),
667
729
  dataset_example_revisions=DatasetExampleRevisionsDataLoader(db),
668
730
  dataset_example_spans=DatasetExampleSpansDataLoader(db),
731
+ dataset_examples_and_versions_by_experiment_run=DatasetExamplesAndVersionsByExperimentRunDataLoader(
732
+ db
733
+ ),
734
+ dataset_example_splits=DatasetExampleSplitsDataLoader(db),
735
+ dataset_fields=TableFieldsDataLoader(db, models.Dataset),
736
+ dataset_split_fields=TableFieldsDataLoader(db, models.DatasetSplit),
737
+ dataset_version_fields=TableFieldsDataLoader(db, models.DatasetVersion),
738
+ dataset_labels=DatasetLabelsDataLoader(db),
739
+ dataset_label_fields=TableFieldsDataLoader(db, models.DatasetLabel),
669
740
  document_evaluation_summaries=DocumentEvaluationSummaryDataLoader(
670
741
  db,
671
742
  cache_map=(
@@ -674,6 +745,7 @@ def create_graphql_router(
674
745
  else None
675
746
  ),
676
747
  ),
748
+ document_annotation_fields=TableFieldsDataLoader(db, models.DocumentAnnotation),
677
749
  document_evaluations=DocumentEvaluationsDataLoader(db),
678
750
  document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(db),
679
751
  annotation_summaries=AnnotationSummaryDataLoader(
@@ -683,10 +755,24 @@ def create_graphql_router(
683
755
  ),
684
756
  ),
685
757
  experiment_annotation_summaries=ExperimentAnnotationSummaryDataLoader(db),
758
+ experiment_dataset_splits=ExperimentDatasetSplitsDataLoader(db),
686
759
  experiment_error_rates=ExperimentErrorRatesDataLoader(db),
760
+ experiment_fields=TableFieldsDataLoader(db, models.Experiment),
761
+ experiment_repeated_run_group_annotation_summaries=ExperimentRepeatedRunGroupAnnotationSummariesDataLoader(
762
+ db
763
+ ),
764
+ experiment_repeated_run_groups=ExperimentRepeatedRunGroupsDataLoader(db),
765
+ experiment_run_annotation_fields=TableFieldsDataLoader(
766
+ db, models.ExperimentRunAnnotation
767
+ ),
687
768
  experiment_run_annotations=ExperimentRunAnnotations(db),
688
769
  experiment_run_counts=ExperimentRunCountsDataLoader(db),
770
+ experiment_run_fields=TableFieldsDataLoader(db, models.ExperimentRun),
771
+ experiment_runs_by_experiment_and_example=ExperimentRunsByExperimentAndExampleDataLoader(
772
+ db
773
+ ),
689
774
  experiment_sequence_number=ExperimentSequenceNumberDataLoader(db),
775
+ generative_model_fields=TableFieldsDataLoader(db, models.GenerativeModel),
690
776
  last_used_times_by_generative_model_id=LastUsedTimesByGenerativeModelIdDataLoader(
691
777
  db
692
778
  ),
@@ -710,17 +796,26 @@ def create_graphql_router(
710
796
  projects_by_trace_retention_policy_id=ProjectIdsByTraceRetentionPolicyIdDataLoader(
711
797
  db
712
798
  ),
799
+ prompt_fields=TableFieldsDataLoader(db, models.Prompt),
800
+ prompt_label_fields=TableFieldsDataLoader(db, models.PromptLabel),
713
801
  prompt_version_sequence_number=PromptVersionSequenceNumberDataLoader(db),
802
+ prompt_version_tag_fields=TableFieldsDataLoader(db, models.PromptVersionTag),
803
+ project_session_annotation_fields=TableFieldsDataLoader(
804
+ db, models.ProjectSessionAnnotation
805
+ ),
806
+ project_session_fields=TableFieldsDataLoader(db, models.ProjectSession),
714
807
  record_counts=RecordCountDataLoader(
715
808
  db,
716
809
  cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
717
810
  ),
811
+ session_annotations_by_session=SessionAnnotationsBySessionDataLoader(db),
718
812
  session_first_inputs=SessionIODataLoader(db, "first_input"),
719
813
  session_last_outputs=SessionIODataLoader(db, "last_output"),
720
814
  session_num_traces=SessionNumTracesDataLoader(db),
721
815
  session_num_traces_with_error=SessionNumTracesWithErrorDataLoader(db),
722
816
  session_token_usages=SessionTokenUsagesDataLoader(db),
723
817
  session_trace_latency_ms_quantile=SessionTraceLatencyMsQuantileDataLoader(db),
818
+ span_annotation_fields=TableFieldsDataLoader(db, models.SpanAnnotation),
724
819
  span_annotations=SpanAnnotationsDataLoader(db),
725
820
  span_fields=TableFieldsDataLoader(db, models.Span),
726
821
  span_by_id=SpanByIdDataLoader(db),
@@ -740,6 +835,11 @@ def create_graphql_router(
740
835
  span_cost_details_by_span_cost=SpanCostDetailsBySpanCostDataLoader(db),
741
836
  span_cost_detail_fields=TableFieldsDataLoader(db, models.SpanCostDetail),
742
837
  span_cost_fields=TableFieldsDataLoader(db, models.SpanCost),
838
+ span_cost_summary_by_experiment=SpanCostSummaryByExperimentDataLoader(db),
839
+ span_cost_summary_by_experiment_repeated_run_group=SpanCostSummaryByExperimentRepeatedRunGroupDataLoader(
840
+ db
841
+ ),
842
+ span_cost_summary_by_experiment_run=SpanCostSummaryByExperimentRunDataLoader(db),
743
843
  span_cost_summary_by_generative_model=SpanCostSummaryByGenerativeModelDataLoader(
744
844
  db
745
845
  ),
@@ -756,6 +856,9 @@ def create_graphql_router(
756
856
  db,
757
857
  cache_map=cache_for_dataloaders.token_count if cache_for_dataloaders else None,
758
858
  ),
859
+ token_prices_by_model=TokenPricesByModelDataLoader(db),
860
+ trace_annotation_fields=TableFieldsDataLoader(db, models.TraceAnnotation),
861
+ trace_annotations_by_trace=TraceAnnotationsByTraceDataLoader(db),
759
862
  trace_by_trace_ids=TraceByTraceIdsDataLoader(db),
760
863
  trace_fields=TableFieldsDataLoader(db, models.Trace),
761
864
  trace_retention_policy_id_by_project_id=TraceRetentionPolicyIdByProjectIdDataLoader(
@@ -767,9 +870,9 @@ def create_graphql_router(
767
870
  trace_root_spans=TraceRootSpansDataLoader(db),
768
871
  project_by_name=ProjectByNameDataLoader(db),
769
872
  users=UsersDataLoader(db),
873
+ user_api_key_fields=TableFieldsDataLoader(db, models.ApiKey),
874
+ user_fields=TableFieldsDataLoader(db, models.User),
770
875
  user_roles=UserRolesDataLoader(db),
771
- span_cost_summary_by_experiment=SpanCostSummaryByExperimentDataLoader(db),
772
- span_cost_summary_by_experiment_run=SpanCostSummaryByExperimentRunDataLoader(db),
773
876
  ),
774
877
  cache_for_dataloaders=cache_for_dataloaders,
775
878
  read_only=read_only,
@@ -896,6 +999,7 @@ def create_app(
896
999
  scaffolder_config: Optional[ScaffolderConfig] = None,
897
1000
  email_sender: Optional[EmailSender] = None,
898
1001
  oauth2_client_configs: Optional[list[OAuth2ClientConfig]] = None,
1002
+ ldap_config: Optional["LDAPConfig"] = None,
899
1003
  basic_auth_disabled: bool = False,
900
1004
  bulk_inserter_factory: Optional[Callable[..., BulkInserter]] = None,
901
1005
  allowed_origins: Optional[list[str]] = None,
@@ -969,11 +1073,11 @@ def create_app(
969
1073
  span_cost_calculator = SpanCostCalculator(db, generative_model_store)
970
1074
  bulk_inserter = bulk_inserter_factory(
971
1075
  db,
972
- enable_prometheus=enable_prometheus,
973
1076
  span_cost_calculator=span_cost_calculator,
974
1077
  event_queue=dml_event_handler,
975
1078
  initial_batch_of_spans=initial_batch_of_spans,
976
1079
  initial_batch_of_evaluations=initial_batch_of_evaluations,
1080
+ max_spans_queue_size=get_env_max_spans_queue_size(),
977
1081
  )
978
1082
  tracer_provider = None
979
1083
  graphql_schema_extensions: list[Union[type[SchemaExtension], SchemaExtension]] = []
@@ -1054,7 +1158,8 @@ def create_app(
1054
1158
  app.include_router(router)
1055
1159
  app.include_router(graphql_router)
1056
1160
  if authentication_enabled:
1057
- app.include_router(auth_router)
1161
+ # Only register LDAP endpoint if LDAP is configured
1162
+ app.include_router(create_auth_router(ldap_enabled=ldap_config is not None))
1058
1163
  app.include_router(oauth2_router)
1059
1164
  app.add_middleware(GZipMiddleware)
1060
1165
  web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json"
@@ -1081,8 +1186,14 @@ def create_app(
1081
1186
  web_manifest_path=web_manifest_path,
1082
1187
  oauth2_idps=oauth2_idps,
1083
1188
  basic_auth_disabled=basic_auth_disabled,
1189
+ ldap_enabled=ldap_config is not None,
1190
+ # Disable manual user creation when LDAP disabled or no email attr
1191
+ ldap_manual_user_creation_enabled=(
1192
+ ldap_config.attr_email is not None if ldap_config else False
1193
+ ),
1084
1194
  auto_login_idp_name=auto_login_idp_name,
1085
1195
  fullstory_org=Settings.fullstory_org,
1196
+ scarf_sh_pixel_id=Settings.scarf_sh_pixel_id,
1086
1197
  management_url=management_url,
1087
1198
  support_email=get_env_support_email(),
1088
1199
  has_db_threshold=bool(
@@ -1090,6 +1201,7 @@ def create_app(
1090
1201
  and get_env_database_usage_insertion_blocking_threshold_percentage()
1091
1202
  ),
1092
1203
  allow_external_resources=get_env_allow_external_resources(),
1204
+ auth_error_messages=dict(AUTH_ERROR_MESSAGES) if authentication_enabled else {},
1093
1205
  ),
1094
1206
  ),
1095
1207
  name="static",
@@ -1101,9 +1213,15 @@ def create_app(
1101
1213
  app.state.access_token_expiry = access_token_expiry
1102
1214
  app.state.refresh_token_expiry = refresh_token_expiry
1103
1215
  app.state.oauth2_clients = OAuth2Clients.from_configs(oauth2_client_configs or [])
1216
+ # Cache LDAPAuthenticator to avoid re-parsing TLS config on every login
1217
+ if ldap_config:
1218
+ from phoenix.server.ldap import LDAPAuthenticator
1219
+
1220
+ app.state.ldap_authenticator = LDAPAuthenticator(ldap_config)
1104
1221
  app.state.db = db
1105
1222
  app.state.email_sender = email_sender
1106
1223
  app.state.span_cost_calculator = span_cost_calculator
1224
+ app.state.span_queue_is_full = lambda: bulk_inserter.is_full
1107
1225
  app = _add_get_secret_method(app=app, secret=secret)
1108
1226
  app = _add_get_token_store_method(app=app, token_store=token_store)
1109
1227
  if tracer_provider:
@@ -23,7 +23,6 @@ Usage:
23
23
  """
24
24
 
25
25
  from fastapi import HTTPException, Request
26
- from fastapi import status as fastapi_status
27
26
 
28
27
  from phoenix.config import get_env_support_email
29
28
  from phoenix.server.bearer_auth import PhoenixUser
@@ -43,13 +42,15 @@ def require_admin(request: Request) -> None:
43
42
  Behavior:
44
43
  - Allows access if the authenticated user is an admin or a system user.
45
44
  - Raises HTTP 403 Forbidden if the user is not authorized.
46
- - Expects authentication to be enabled and request.user to be set by the authentication.
45
+ - Allows access if authentication is not enabled.
47
46
  """
47
+ if not request.app.state.authentication_enabled:
48
+ return
48
49
  user = getattr(request, "user", None)
49
50
  # System users have all privileges
50
51
  if not (isinstance(user, PhoenixUser) and user.is_admin):
51
52
  raise HTTPException(
52
- status_code=fastapi_status.HTTP_403_FORBIDDEN,
53
+ status_code=403,
53
54
  detail="Only admin or system users can perform this action.",
54
55
  )
55
56
 
@@ -82,6 +83,6 @@ def is_not_locked(request: Request) -> None:
82
83
  if support_email := get_env_support_email():
83
84
  detail += f" Need help? Contact us at {support_email}"
84
85
  raise HTTPException(
85
- status_code=fastapi_status.HTTP_507_INSUFFICIENT_STORAGE,
86
+ status_code=507,
86
87
  detail=detail,
87
88
  )
@@ -9,7 +9,6 @@ from fastapi import HTTPException, Request, WebSocket, WebSocketException
9
9
  from grpc_interceptor import AsyncServerInterceptor
10
10
  from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
11
11
  from starlette.requests import HTTPConnection
12
- from starlette.status import HTTP_401_UNAUTHORIZED
13
12
  from typing_extensions import override
14
13
 
15
14
  from phoenix import config
@@ -76,11 +75,18 @@ class PhoenixUser(BaseUser):
76
75
  self._is_admin = (
77
76
  claims.status is ClaimSetStatus.VALID and claims.attributes.user_role == "ADMIN"
78
77
  )
78
+ self._is_viewer = (
79
+ claims.status is ClaimSetStatus.VALID and claims.attributes.user_role == "VIEWER"
80
+ )
79
81
 
80
82
  @cached_property
81
83
  def is_admin(self) -> bool:
82
84
  return self._is_admin
83
85
 
86
+ @cached_property
87
+ def is_viewer(self) -> bool:
88
+ return self._is_viewer
89
+
84
90
  @cached_property
85
91
  def identity(self) -> UserId:
86
92
  return self._user_id
@@ -93,6 +99,8 @@ class PhoenixUser(BaseUser):
93
99
  class PhoenixSystemUser(PhoenixUser):
94
100
  def __init__(self, user_id: UserId) -> None:
95
101
  self._user_id = user_id
102
+ self._is_admin = True # System users have admin privileges
103
+ self._is_viewer = False # System users are not viewers
96
104
 
97
105
  @property
98
106
  def is_admin(self) -> bool:
@@ -144,16 +152,16 @@ async def is_authenticated(
144
152
  """
145
153
  assert request or websocket
146
154
  if request and not isinstance((user := request.user), PhoenixUser):
147
- raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
155
+ raise HTTPException(status_code=401, detail="Invalid token")
148
156
  if websocket and not isinstance((user := websocket.user), PhoenixUser):
149
- raise WebSocketException(code=HTTP_401_UNAUTHORIZED, reason="Invalid token")
157
+ raise WebSocketException(code=401, reason="Invalid token")
150
158
  if isinstance(user, PhoenixSystemUser):
151
159
  return
152
160
  claims = user.claims
153
161
  if claims.status is ClaimSetStatus.EXPIRED:
154
- raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token")
162
+ raise HTTPException(status_code=401, detail="Expired token")
155
163
  if claims.status is not ClaimSetStatus.VALID:
156
- raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
164
+ raise HTTPException(status_code=401, detail="Invalid token")
157
165
 
158
166
 
159
167
  async def create_access_and_refresh_tokens(
@@ -1,4 +1,3 @@
1
- import re
2
1
  from datetime import datetime
3
2
  from typing import Any, Iterable, Mapping, Optional
4
3
 
@@ -20,24 +19,53 @@ class CostModelLookup:
20
19
  self,
21
20
  generative_models: Iterable[models.GenerativeModel] = (),
22
21
  ) -> None:
23
- self._models = tuple(generative_models)
22
+ self._models_by_id: dict[int, models.GenerativeModel] = {}
24
23
  self._model_priority: dict[
25
24
  int, tuple[_RegexSpecificityScore, float, _TieBreakerId]
26
25
  ] = {} # higher is better
27
- self._regex_specificity_score: dict[re.Pattern[str], _RegexSpecificityScore] = {}
28
26
 
29
- for m in self._models:
30
- self._regex_specificity_score[m.name_pattern] = regex_specificity.score(m.name_pattern)
27
+ for m in generative_models:
28
+ self._add_or_update_model(m)
31
29
 
32
- # For built-in models, use negative ID so that earlier IDs win
33
- # For user-defined models, use positive ID so later IDs win
34
- tie_breaker = -m.id if m.is_built_in else m.id
30
+ def _add_or_update_model(self, model: models.GenerativeModel) -> None:
31
+ """Add or update a single model in the lookup."""
32
+ self._models_by_id[model.id] = model
35
33
 
36
- self._model_priority[m.id] = (
37
- self._regex_specificity_score[m.name_pattern],
38
- m.start_time.timestamp() if m.start_time else 0.0,
39
- tie_breaker,
40
- )
34
+ specificity_score = regex_specificity.score(model.name_pattern)
35
+
36
+ # For built-in models, use negative ID so that earlier IDs win
37
+ # For user-defined models, use positive ID so later IDs win
38
+ tie_breaker = -model.id if model.is_built_in else model.id
39
+
40
+ self._model_priority[model.id] = (
41
+ specificity_score,
42
+ model.start_time.timestamp() if model.start_time else 0.0,
43
+ tie_breaker,
44
+ )
45
+
46
+ def _remove_model(self, model_id: int) -> None:
47
+ """Remove a model from the lookup."""
48
+ if model_id in self._models_by_id:
49
+ del self._models_by_id[model_id]
50
+ if model_id in self._model_priority:
51
+ del self._model_priority[model_id]
52
+
53
+ def merge(self, models: Iterable[models.GenerativeModel]) -> None:
54
+ """
55
+ Merge a collection of models into the existing lookup.
56
+
57
+ For each model:
58
+ - If deleted_at is set, remove it from the lookup
59
+ - Otherwise, add or update it in the lookup
60
+
61
+ Args:
62
+ models: An iterable of GenerativeModel objects to merge
63
+ """
64
+ for model in models:
65
+ if model.deleted_at is not None:
66
+ self._remove_model(model.id)
67
+ else:
68
+ self._add_or_update_model(model)
41
69
 
42
70
  def find_model(
43
71
  self,
@@ -107,7 +135,7 @@ class CostModelLookup:
107
135
  # 2. only include models that are active and match the regex pattern
108
136
  candidates = [
109
137
  model
110
- for model in self._models
138
+ for model in self._models_by_id.values()
111
139
  if (not model.start_time or model.start_time <= start_time)
112
140
  and model.name_pattern.search(model_name)
113
141
  ]