arize-phoenix 4.10.2rc2__py3-none-any.whl → 4.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/METADATA +3 -4
- {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/RECORD +29 -29
- phoenix/server/api/context.py +7 -3
- phoenix/server/api/openapi/main.py +2 -18
- phoenix/server/api/openapi/schema.py +12 -12
- phoenix/server/api/routers/v1/__init__.py +83 -36
- phoenix/server/api/routers/v1/dataset_examples.py +123 -102
- phoenix/server/api/routers/v1/datasets.py +507 -389
- phoenix/server/api/routers/v1/evaluations.py +66 -73
- phoenix/server/api/routers/v1/experiment_evaluations.py +91 -67
- phoenix/server/api/routers/v1/experiment_runs.py +155 -97
- phoenix/server/api/routers/v1/experiments.py +181 -131
- phoenix/server/api/routers/v1/spans.py +173 -143
- phoenix/server/api/routers/v1/traces.py +128 -114
- phoenix/server/api/types/Span.py +1 -0
- phoenix/server/app.py +176 -148
- phoenix/server/openapi/docs.py +221 -0
- phoenix/server/static/index.js +574 -573
- phoenix/server/thread_server.py +2 -2
- phoenix/session/client.py +5 -0
- phoenix/session/data_extractor.py +20 -1
- phoenix/session/session.py +4 -0
- phoenix/trace/attributes.py +2 -1
- phoenix/trace/schemas.py +1 -0
- phoenix/trace/span_json_decoder.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/routers/v1/utils.py +0 -94
- {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/app.py
CHANGED
|
@@ -19,24 +19,25 @@ from typing import (
|
|
|
19
19
|
)
|
|
20
20
|
|
|
21
21
|
import strawberry
|
|
22
|
-
from fastapi import APIRouter, FastAPI
|
|
23
|
-
from fastapi.middleware.gzip import GZipMiddleware
|
|
24
|
-
from fastapi.responses import FileResponse
|
|
25
|
-
from fastapi.utils import is_body_allowed_for_status_code
|
|
26
22
|
from sqlalchemy.ext.asyncio import (
|
|
27
23
|
AsyncEngine,
|
|
28
24
|
AsyncSession,
|
|
29
25
|
async_sessionmaker,
|
|
30
26
|
)
|
|
27
|
+
from starlette.applications import Starlette
|
|
28
|
+
from starlette.datastructures import QueryParams
|
|
29
|
+
from starlette.endpoints import HTTPEndpoint
|
|
31
30
|
from starlette.exceptions import HTTPException
|
|
32
31
|
from starlette.middleware import Middleware
|
|
33
32
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
34
33
|
from starlette.requests import Request
|
|
35
|
-
from starlette.responses import PlainTextResponse, Response
|
|
34
|
+
from starlette.responses import FileResponse, PlainTextResponse, Response
|
|
35
|
+
from starlette.routing import Mount, Route
|
|
36
36
|
from starlette.staticfiles import StaticFiles
|
|
37
37
|
from starlette.templating import Jinja2Templates
|
|
38
38
|
from starlette.types import Scope, StatefulLifespan
|
|
39
|
-
from
|
|
39
|
+
from starlette.websockets import WebSocket
|
|
40
|
+
from strawberry.asgi import GraphQL
|
|
40
41
|
from strawberry.schema import BaseSchema
|
|
41
42
|
from typing_extensions import TypeAlias
|
|
42
43
|
|
|
@@ -79,10 +80,11 @@ from phoenix.server.api.dataloaders import (
|
|
|
79
80
|
TraceEvaluationsDataLoader,
|
|
80
81
|
TraceRowIdsDataLoader,
|
|
81
82
|
)
|
|
82
|
-
from phoenix.server.api.
|
|
83
|
-
from phoenix.server.api.routers.v1 import
|
|
83
|
+
from phoenix.server.api.openapi.schema import OPENAPI_SCHEMA_GENERATOR
|
|
84
|
+
from phoenix.server.api.routers.v1 import V1_ROUTES
|
|
84
85
|
from phoenix.server.api.schema import schema
|
|
85
86
|
from phoenix.server.grpc_server import GrpcServer
|
|
87
|
+
from phoenix.server.openapi.docs import get_swagger_ui_html
|
|
86
88
|
from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
|
|
87
89
|
from phoenix.trace.schemas import Span
|
|
88
90
|
|
|
@@ -91,8 +93,6 @@ if TYPE_CHECKING:
|
|
|
91
93
|
|
|
92
94
|
logger = logging.getLogger(__name__)
|
|
93
95
|
|
|
94
|
-
router = APIRouter(include_in_schema=False)
|
|
95
|
-
|
|
96
96
|
templates = Jinja2Templates(directory=SERVER_DIR / "templates")
|
|
97
97
|
|
|
98
98
|
|
|
@@ -157,20 +157,116 @@ class HeadersMiddleware(BaseHTTPMiddleware):
|
|
|
157
157
|
ProjectRowId: TypeAlias = int
|
|
158
158
|
|
|
159
159
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
160
|
+
class GraphQLWithContext(GraphQL): # type: ignore
|
|
161
|
+
def __init__(
|
|
162
|
+
self,
|
|
163
|
+
schema: BaseSchema,
|
|
164
|
+
db: Callable[[], AsyncContextManager[AsyncSession]],
|
|
165
|
+
model: Model,
|
|
166
|
+
export_path: Path,
|
|
167
|
+
graphiql: bool = False,
|
|
168
|
+
corpus: Optional[Model] = None,
|
|
169
|
+
streaming_last_updated_at: Callable[[ProjectRowId], Optional[datetime]] = lambda _: None,
|
|
170
|
+
cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
|
|
171
|
+
read_only: bool = False,
|
|
172
|
+
) -> None:
|
|
173
|
+
self.db = db
|
|
174
|
+
self.model = model
|
|
175
|
+
self.corpus = corpus
|
|
176
|
+
self.export_path = export_path
|
|
177
|
+
self.streaming_last_updated_at = streaming_last_updated_at
|
|
178
|
+
self.cache_for_dataloaders = cache_for_dataloaders
|
|
179
|
+
self.read_only = read_only
|
|
180
|
+
super().__init__(schema, graphiql=graphiql)
|
|
181
|
+
|
|
182
|
+
async def get_context(
|
|
183
|
+
self,
|
|
184
|
+
request: Union[Request, WebSocket],
|
|
185
|
+
response: Optional[Response] = None,
|
|
186
|
+
) -> Context:
|
|
187
|
+
return Context(
|
|
188
|
+
request=request,
|
|
189
|
+
response=response,
|
|
190
|
+
db=self.db,
|
|
191
|
+
model=self.model,
|
|
192
|
+
corpus=self.corpus,
|
|
193
|
+
export_path=self.export_path,
|
|
194
|
+
streaming_last_updated_at=self.streaming_last_updated_at,
|
|
195
|
+
data_loaders=DataLoaders(
|
|
196
|
+
average_experiment_run_latency=AverageExperimentRunLatencyDataLoader(self.db),
|
|
197
|
+
dataset_example_revisions=DatasetExampleRevisionsDataLoader(self.db),
|
|
198
|
+
dataset_example_spans=DatasetExampleSpansDataLoader(self.db),
|
|
199
|
+
document_evaluation_summaries=DocumentEvaluationSummaryDataLoader(
|
|
200
|
+
self.db,
|
|
201
|
+
cache_map=self.cache_for_dataloaders.document_evaluation_summary
|
|
202
|
+
if self.cache_for_dataloaders
|
|
203
|
+
else None,
|
|
204
|
+
),
|
|
205
|
+
document_evaluations=DocumentEvaluationsDataLoader(self.db),
|
|
206
|
+
document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(self.db),
|
|
207
|
+
evaluation_summaries=EvaluationSummaryDataLoader(
|
|
208
|
+
self.db,
|
|
209
|
+
cache_map=self.cache_for_dataloaders.evaluation_summary
|
|
210
|
+
if self.cache_for_dataloaders
|
|
211
|
+
else None,
|
|
212
|
+
),
|
|
213
|
+
experiment_annotation_summaries=ExperimentAnnotationSummaryDataLoader(self.db),
|
|
214
|
+
experiment_error_rates=ExperimentErrorRatesDataLoader(self.db),
|
|
215
|
+
experiment_run_counts=ExperimentRunCountsDataLoader(self.db),
|
|
216
|
+
experiment_sequence_number=ExperimentSequenceNumberDataLoader(self.db),
|
|
217
|
+
latency_ms_quantile=LatencyMsQuantileDataLoader(
|
|
218
|
+
self.db,
|
|
219
|
+
cache_map=self.cache_for_dataloaders.latency_ms_quantile
|
|
220
|
+
if self.cache_for_dataloaders
|
|
221
|
+
else None,
|
|
222
|
+
),
|
|
223
|
+
min_start_or_max_end_times=MinStartOrMaxEndTimeDataLoader(
|
|
224
|
+
self.db,
|
|
225
|
+
cache_map=self.cache_for_dataloaders.min_start_or_max_end_time
|
|
226
|
+
if self.cache_for_dataloaders
|
|
227
|
+
else None,
|
|
228
|
+
),
|
|
229
|
+
record_counts=RecordCountDataLoader(
|
|
230
|
+
self.db,
|
|
231
|
+
cache_map=self.cache_for_dataloaders.record_count
|
|
232
|
+
if self.cache_for_dataloaders
|
|
233
|
+
else None,
|
|
234
|
+
),
|
|
235
|
+
span_descendants=SpanDescendantsDataLoader(self.db),
|
|
236
|
+
span_evaluations=SpanEvaluationsDataLoader(self.db),
|
|
237
|
+
span_projects=SpanProjectsDataLoader(self.db),
|
|
238
|
+
token_counts=TokenCountDataLoader(
|
|
239
|
+
self.db,
|
|
240
|
+
cache_map=self.cache_for_dataloaders.token_count
|
|
241
|
+
if self.cache_for_dataloaders
|
|
242
|
+
else None,
|
|
243
|
+
),
|
|
244
|
+
trace_evaluations=TraceEvaluationsDataLoader(self.db),
|
|
245
|
+
trace_row_ids=TraceRowIdsDataLoader(self.db),
|
|
246
|
+
project_by_name=ProjectByNameDataLoader(self.db),
|
|
247
|
+
span_annotations=SpanAnnotationsDataLoader(self.db),
|
|
248
|
+
),
|
|
249
|
+
cache_for_dataloaders=self.cache_for_dataloaders,
|
|
250
|
+
read_only=self.read_only,
|
|
251
|
+
)
|
|
252
|
+
|
|
170
253
|
|
|
254
|
+
class Download(HTTPEndpoint):
|
|
255
|
+
path: Path
|
|
171
256
|
|
|
172
|
-
|
|
173
|
-
|
|
257
|
+
async def get(self, request: Request) -> FileResponse:
|
|
258
|
+
params = QueryParams(request.query_params)
|
|
259
|
+
file = self.path / (params.get("filename", "") + ".parquet")
|
|
260
|
+
if not file.is_file():
|
|
261
|
+
raise HTTPException(status_code=404)
|
|
262
|
+
return FileResponse(
|
|
263
|
+
path=file,
|
|
264
|
+
filename=file.name,
|
|
265
|
+
media_type="application/x-octet-stream",
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
async def version(_: Request) -> PlainTextResponse:
|
|
174
270
|
return PlainTextResponse(f"{phoenix.__version__}")
|
|
175
271
|
|
|
176
272
|
|
|
@@ -192,9 +288,9 @@ def _lifespan(
|
|
|
192
288
|
enable_prometheus: bool = False,
|
|
193
289
|
clean_ups: Iterable[Callable[[], None]] = (),
|
|
194
290
|
read_only: bool = False,
|
|
195
|
-
) -> StatefulLifespan[
|
|
291
|
+
) -> StatefulLifespan[Starlette]:
|
|
196
292
|
@contextlib.asynccontextmanager
|
|
197
|
-
async def lifespan(_:
|
|
293
|
+
async def lifespan(_: Starlette) -> AsyncIterator[Dict[str, Any]]:
|
|
198
294
|
async with bulk_inserter as (
|
|
199
295
|
queue_span,
|
|
200
296
|
queue_evaluation,
|
|
@@ -216,89 +312,16 @@ def _lifespan(
|
|
|
216
312
|
return lifespan
|
|
217
313
|
|
|
218
314
|
|
|
219
|
-
@router.get("/healthz")
|
|
220
315
|
async def check_healthz(_: Request) -> PlainTextResponse:
|
|
221
316
|
return PlainTextResponse("OK")
|
|
222
317
|
|
|
223
318
|
|
|
224
|
-
def
|
|
225
|
-
|
|
226
|
-
schema: BaseSchema,
|
|
227
|
-
db: Callable[[], AsyncContextManager[AsyncSession]],
|
|
228
|
-
model: Model,
|
|
229
|
-
export_path: Path,
|
|
230
|
-
corpus: Optional[Model] = None,
|
|
231
|
-
streaming_last_updated_at: Callable[[ProjectRowId], Optional[datetime]] = lambda _: None,
|
|
232
|
-
cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
|
|
233
|
-
read_only: bool = False,
|
|
234
|
-
) -> GraphQLRouter: # type: ignore[type-arg]
|
|
235
|
-
context = Context(
|
|
236
|
-
db=db,
|
|
237
|
-
model=model,
|
|
238
|
-
corpus=corpus,
|
|
239
|
-
export_path=export_path,
|
|
240
|
-
streaming_last_updated_at=streaming_last_updated_at,
|
|
241
|
-
data_loaders=DataLoaders(
|
|
242
|
-
average_experiment_run_latency=AverageExperimentRunLatencyDataLoader(db),
|
|
243
|
-
dataset_example_revisions=DatasetExampleRevisionsDataLoader(db),
|
|
244
|
-
dataset_example_spans=DatasetExampleSpansDataLoader(db),
|
|
245
|
-
document_evaluation_summaries=DocumentEvaluationSummaryDataLoader(
|
|
246
|
-
db,
|
|
247
|
-
cache_map=cache_for_dataloaders.document_evaluation_summary
|
|
248
|
-
if cache_for_dataloaders
|
|
249
|
-
else None,
|
|
250
|
-
),
|
|
251
|
-
document_evaluations=DocumentEvaluationsDataLoader(db),
|
|
252
|
-
document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(db),
|
|
253
|
-
evaluation_summaries=EvaluationSummaryDataLoader(
|
|
254
|
-
db,
|
|
255
|
-
cache_map=cache_for_dataloaders.evaluation_summary
|
|
256
|
-
if cache_for_dataloaders
|
|
257
|
-
else None,
|
|
258
|
-
),
|
|
259
|
-
experiment_annotation_summaries=ExperimentAnnotationSummaryDataLoader(db),
|
|
260
|
-
experiment_error_rates=ExperimentErrorRatesDataLoader(db),
|
|
261
|
-
experiment_run_counts=ExperimentRunCountsDataLoader(db),
|
|
262
|
-
experiment_sequence_number=ExperimentSequenceNumberDataLoader(db),
|
|
263
|
-
latency_ms_quantile=LatencyMsQuantileDataLoader(
|
|
264
|
-
db,
|
|
265
|
-
cache_map=cache_for_dataloaders.latency_ms_quantile
|
|
266
|
-
if cache_for_dataloaders
|
|
267
|
-
else None,
|
|
268
|
-
),
|
|
269
|
-
min_start_or_max_end_times=MinStartOrMaxEndTimeDataLoader(
|
|
270
|
-
db,
|
|
271
|
-
cache_map=cache_for_dataloaders.min_start_or_max_end_time
|
|
272
|
-
if cache_for_dataloaders
|
|
273
|
-
else None,
|
|
274
|
-
),
|
|
275
|
-
record_counts=RecordCountDataLoader(
|
|
276
|
-
db,
|
|
277
|
-
cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
|
|
278
|
-
),
|
|
279
|
-
span_annotations=SpanAnnotationsDataLoader(db),
|
|
280
|
-
span_descendants=SpanDescendantsDataLoader(db),
|
|
281
|
-
span_evaluations=SpanEvaluationsDataLoader(db),
|
|
282
|
-
span_projects=SpanProjectsDataLoader(db),
|
|
283
|
-
token_counts=TokenCountDataLoader(
|
|
284
|
-
db,
|
|
285
|
-
cache_map=cache_for_dataloaders.token_count if cache_for_dataloaders else None,
|
|
286
|
-
),
|
|
287
|
-
trace_evaluations=TraceEvaluationsDataLoader(db),
|
|
288
|
-
trace_row_ids=TraceRowIdsDataLoader(db),
|
|
289
|
-
project_by_name=ProjectByNameDataLoader(db),
|
|
290
|
-
),
|
|
291
|
-
cache_for_dataloaders=cache_for_dataloaders,
|
|
292
|
-
read_only=read_only,
|
|
293
|
-
)
|
|
319
|
+
async def openapi_schema(request: Request) -> Response:
|
|
320
|
+
return OPENAPI_SCHEMA_GENERATOR.OpenAPIResponse(request=request)
|
|
294
321
|
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
context_getter=lambda: context,
|
|
299
|
-
include_in_schema=False,
|
|
300
|
-
prefix="/graphql",
|
|
301
|
-
)
|
|
322
|
+
|
|
323
|
+
async def api_docs(request: Request) -> Response:
|
|
324
|
+
return get_swagger_ui_html(openapi_url="/schema", title="arize-phoenix API")
|
|
302
325
|
|
|
303
326
|
|
|
304
327
|
class SessionFactory:
|
|
@@ -349,18 +372,6 @@ def instrument_engine_if_enabled(engine: AsyncEngine) -> List[Callable[[], None]
|
|
|
349
372
|
return instrumentation_cleanups
|
|
350
373
|
|
|
351
374
|
|
|
352
|
-
async def plain_text_http_exception_handler(request: Request, exc: HTTPException) -> Response:
|
|
353
|
-
"""
|
|
354
|
-
Overrides the default handler for HTTPExceptions to return a plain text
|
|
355
|
-
response instead of a JSON response. For the original source code, see
|
|
356
|
-
https://github.com/tiangolo/fastapi/blob/d3cdd3bbd14109f3b268df7ca496e24bb64593aa/fastapi/exception_handlers.py#L11
|
|
357
|
-
"""
|
|
358
|
-
headers = getattr(exc, "headers", None)
|
|
359
|
-
if not is_body_allowed_for_status_code(exc.status_code):
|
|
360
|
-
return Response(status_code=exc.status_code, headers=headers)
|
|
361
|
-
return PlainTextResponse(str(exc.detail), status_code=exc.status_code, headers=headers)
|
|
362
|
-
|
|
363
|
-
|
|
364
375
|
def create_app(
|
|
365
376
|
db: SessionFactory,
|
|
366
377
|
export_path: Path,
|
|
@@ -374,7 +385,7 @@ def create_app(
|
|
|
374
385
|
initial_evaluations: Optional[Iterable[pb.Evaluation]] = None,
|
|
375
386
|
serve_ui: bool = True,
|
|
376
387
|
clean_up_callbacks: List[Callable[[], None]] = [],
|
|
377
|
-
) ->
|
|
388
|
+
) -> Starlette:
|
|
378
389
|
clean_ups: List[Callable[[], None]] = clean_up_callbacks # To be called at app shutdown.
|
|
379
390
|
initial_batch_of_spans: Iterable[Tuple[Span, str]] = (
|
|
380
391
|
()
|
|
@@ -399,6 +410,7 @@ def create_app(
|
|
|
399
410
|
tracer_provider = None
|
|
400
411
|
strawberry_extensions = schema.get_extensions()
|
|
401
412
|
if server_instrumentation_is_enabled():
|
|
413
|
+
tracer_provider = initialize_opentelemetry_tracer_provider()
|
|
402
414
|
from opentelemetry.trace import TracerProvider
|
|
403
415
|
from strawberry.extensions.tracing import OpenTelemetryExtension
|
|
404
416
|
|
|
@@ -416,7 +428,7 @@ def create_app(
|
|
|
416
428
|
|
|
417
429
|
strawberry_extensions.append(_OpenTelemetryExtension)
|
|
418
430
|
|
|
419
|
-
|
|
431
|
+
graphql = GraphQLWithContext(
|
|
420
432
|
db=db,
|
|
421
433
|
schema=strawberry.Schema(
|
|
422
434
|
query=schema.query,
|
|
@@ -427,6 +439,7 @@ def create_app(
|
|
|
427
439
|
model=model,
|
|
428
440
|
corpus=corpus,
|
|
429
441
|
export_path=export_path,
|
|
442
|
+
graphiql=True,
|
|
430
443
|
streaming_last_updated_at=bulk_inserter.last_updated_at,
|
|
431
444
|
cache_for_dataloaders=cache_for_dataloaders,
|
|
432
445
|
read_only=read_only,
|
|
@@ -437,9 +450,7 @@ def create_app(
|
|
|
437
450
|
prometheus_middlewares = [Middleware(PrometheusMiddleware)]
|
|
438
451
|
else:
|
|
439
452
|
prometheus_middlewares = []
|
|
440
|
-
app =
|
|
441
|
-
title="Arize-Phoenix REST API",
|
|
442
|
-
version=REST_API_VERSION,
|
|
453
|
+
app = Starlette(
|
|
443
454
|
lifespan=_lifespan(
|
|
444
455
|
read_only=read_only,
|
|
445
456
|
bulk_inserter=bulk_inserter,
|
|
@@ -451,39 +462,56 @@ def create_app(
|
|
|
451
462
|
Middleware(HeadersMiddleware),
|
|
452
463
|
*prometheus_middlewares,
|
|
453
464
|
],
|
|
454
|
-
exception_handlers={HTTPException: plain_text_http_exception_handler},
|
|
455
465
|
debug=debug,
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
app.mount(
|
|
468
|
-
"/",
|
|
469
|
-
app=Static(
|
|
470
|
-
directory=SERVER_DIR / "static",
|
|
471
|
-
app_config=AppConfig(
|
|
472
|
-
has_inferences=model.is_empty is not True,
|
|
473
|
-
has_corpus=corpus is not None,
|
|
474
|
-
min_dist=umap_params.min_dist,
|
|
475
|
-
n_neighbors=umap_params.n_neighbors,
|
|
476
|
-
n_samples=umap_params.n_samples,
|
|
466
|
+
routes=V1_ROUTES
|
|
467
|
+
+ [
|
|
468
|
+
Route("/schema", endpoint=openapi_schema, include_in_schema=False),
|
|
469
|
+
Route("/arize_phoenix_version", version),
|
|
470
|
+
Route("/healthz", check_healthz),
|
|
471
|
+
Route(
|
|
472
|
+
"/exports",
|
|
473
|
+
type(
|
|
474
|
+
"DownloadExports",
|
|
475
|
+
(Download,),
|
|
476
|
+
{"path": export_path},
|
|
477
477
|
),
|
|
478
478
|
),
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
479
|
+
Route(
|
|
480
|
+
"/docs",
|
|
481
|
+
api_docs,
|
|
482
|
+
),
|
|
483
|
+
Route(
|
|
484
|
+
"/graphql",
|
|
485
|
+
graphql,
|
|
486
|
+
),
|
|
487
|
+
]
|
|
488
|
+
+ (
|
|
489
|
+
[
|
|
490
|
+
Mount(
|
|
491
|
+
"/",
|
|
492
|
+
app=Static(
|
|
493
|
+
directory=SERVER_DIR / "static",
|
|
494
|
+
app_config=AppConfig(
|
|
495
|
+
has_inferences=model.is_empty is not True,
|
|
496
|
+
has_corpus=corpus is not None,
|
|
497
|
+
min_dist=umap_params.min_dist,
|
|
498
|
+
n_neighbors=umap_params.n_neighbors,
|
|
499
|
+
n_samples=umap_params.n_samples,
|
|
500
|
+
),
|
|
501
|
+
),
|
|
502
|
+
name="static",
|
|
503
|
+
),
|
|
504
|
+
]
|
|
505
|
+
if serve_ui
|
|
506
|
+
else []
|
|
507
|
+
),
|
|
508
|
+
)
|
|
509
|
+
app.state.read_only = read_only
|
|
482
510
|
app.state.db = db
|
|
483
511
|
if tracer_provider:
|
|
484
|
-
from opentelemetry.instrumentation.
|
|
512
|
+
from opentelemetry.instrumentation.starlette import StarletteInstrumentor
|
|
485
513
|
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
clean_ups.append(
|
|
514
|
+
StarletteInstrumentor().instrument(tracer_provider=tracer_provider)
|
|
515
|
+
StarletteInstrumentor.instrument_app(app, tracer_provider=tracer_provider)
|
|
516
|
+
clean_ups.append(StarletteInstrumentor().uninstrument)
|
|
489
517
|
return app
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
from starlette.responses import HTMLResponse
|
|
5
|
+
|
|
6
|
+
swagger_ui_default_parameters: Dict[str, Any] = {
|
|
7
|
+
"dom_id": "#swagger-ui",
|
|
8
|
+
"layout": "BaseLayout",
|
|
9
|
+
"deepLinking": True,
|
|
10
|
+
"showExtensions": True,
|
|
11
|
+
"showCommonExtensions": True,
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_swagger_ui_html(
|
|
16
|
+
*,
|
|
17
|
+
openapi_url: str = "/schema",
|
|
18
|
+
title: str,
|
|
19
|
+
swagger_js_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui-bundle.js",
|
|
20
|
+
swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui.css",
|
|
21
|
+
swagger_favicon_url: str = "/favicon.ico",
|
|
22
|
+
oauth2_redirect_url: Optional[str] = None,
|
|
23
|
+
init_oauth: Optional[str] = None,
|
|
24
|
+
swagger_ui_parameters: Optional[Dict[str, Any]] = None,
|
|
25
|
+
) -> HTMLResponse:
|
|
26
|
+
"""
|
|
27
|
+
Generate and return the HTML that loads Swagger UI for the interactive API
|
|
28
|
+
docs (normally served at `/docs`).
|
|
29
|
+
"""
|
|
30
|
+
current_swagger_ui_parameters = swagger_ui_default_parameters.copy()
|
|
31
|
+
if swagger_ui_parameters:
|
|
32
|
+
current_swagger_ui_parameters.update(swagger_ui_parameters)
|
|
33
|
+
|
|
34
|
+
html = f"""
|
|
35
|
+
<!DOCTYPE html>
|
|
36
|
+
<html>
|
|
37
|
+
<head>
|
|
38
|
+
<link type="text/css" rel="stylesheet" href="{swagger_css_url}">
|
|
39
|
+
<link rel="shortcut icon" href="{swagger_favicon_url}">
|
|
40
|
+
<title>{title}</title>
|
|
41
|
+
</head>
|
|
42
|
+
<body>
|
|
43
|
+
<div id="swagger-ui">
|
|
44
|
+
</div>
|
|
45
|
+
<script src="{swagger_js_url}"></script>
|
|
46
|
+
<style type="text/css">
|
|
47
|
+
div[id^="operations-private"]{{display:none}} #operations-tag-private{{display:none}}
|
|
48
|
+
</style>
|
|
49
|
+
<!-- `SwaggerUIBundle` is now available on the page -->
|
|
50
|
+
<script>
|
|
51
|
+
const ui = SwaggerUIBundle({{
|
|
52
|
+
url: '{openapi_url}',
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
for key, value in current_swagger_ui_parameters.items():
|
|
56
|
+
html += f"{json.dumps(key)}: {json.dumps(value)},\n"
|
|
57
|
+
|
|
58
|
+
if oauth2_redirect_url:
|
|
59
|
+
html += f"oauth2RedirectUrl: window.location.origin + '{oauth2_redirect_url}',"
|
|
60
|
+
|
|
61
|
+
html += """
|
|
62
|
+
presets: [
|
|
63
|
+
SwaggerUIBundle.presets.apis,
|
|
64
|
+
SwaggerUIBundle.SwaggerUIStandalonePreset
|
|
65
|
+
],
|
|
66
|
+
})"""
|
|
67
|
+
|
|
68
|
+
if init_oauth:
|
|
69
|
+
html += f"""
|
|
70
|
+
ui.initOAuth({json.dumps(init_oauth)})
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
html += """
|
|
74
|
+
</script>
|
|
75
|
+
</body>
|
|
76
|
+
</html>
|
|
77
|
+
"""
|
|
78
|
+
return HTMLResponse(html)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_redoc_html(
|
|
82
|
+
*,
|
|
83
|
+
openapi_url: str,
|
|
84
|
+
title: str,
|
|
85
|
+
redoc_js_url: str = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js",
|
|
86
|
+
redoc_favicon_url: str = "/favicon.ico",
|
|
87
|
+
with_google_fonts: bool = True,
|
|
88
|
+
) -> HTMLResponse:
|
|
89
|
+
"""
|
|
90
|
+
Generate and return the HTML response that loads ReDoc for the alternative
|
|
91
|
+
API docs (normally served at `/redoc`).
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
"""
|
|
95
|
+
html = f"""
|
|
96
|
+
<!DOCTYPE html>
|
|
97
|
+
<html>
|
|
98
|
+
<head>
|
|
99
|
+
<title>{title}</title>
|
|
100
|
+
<!-- needed for adaptive design -->
|
|
101
|
+
<meta charset="utf-8"/>
|
|
102
|
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
|
103
|
+
"""
|
|
104
|
+
if with_google_fonts:
|
|
105
|
+
html += """
|
|
106
|
+
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
|
|
107
|
+
""" # noqa: E501
|
|
108
|
+
html += f"""
|
|
109
|
+
<link rel="shortcut icon" href="{redoc_favicon_url}">
|
|
110
|
+
<!--
|
|
111
|
+
ReDoc doesn't change outer page styles
|
|
112
|
+
-->
|
|
113
|
+
<style>
|
|
114
|
+
body {{
|
|
115
|
+
margin: 0;
|
|
116
|
+
padding: 0;
|
|
117
|
+
}}
|
|
118
|
+
</style>
|
|
119
|
+
</head>
|
|
120
|
+
<body>
|
|
121
|
+
<noscript>
|
|
122
|
+
ReDoc requires Javascript to function. Please enable it to browse the documentation.
|
|
123
|
+
</noscript>
|
|
124
|
+
<redoc spec-url="{openapi_url}"></redoc>
|
|
125
|
+
<script src="{redoc_js_url}"> </script>
|
|
126
|
+
</body>
|
|
127
|
+
</html>
|
|
128
|
+
"""
|
|
129
|
+
return HTMLResponse(html)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# Not needed now but copy-pasting for future reference
|
|
133
|
+
def get_swagger_ui_oauth2_redirect_html() -> HTMLResponse:
|
|
134
|
+
"""
|
|
135
|
+
Generate the HTML response with the OAuth2 redirection for Swagger UI.
|
|
136
|
+
|
|
137
|
+
You normally don't need to use or change this.
|
|
138
|
+
"""
|
|
139
|
+
# copied from https://github.com/swagger-api/swagger-ui/blob/v4.14.0/dist/oauth2-redirect.html
|
|
140
|
+
html = """
|
|
141
|
+
<!doctype html>
|
|
142
|
+
<html lang="en-US">
|
|
143
|
+
<head>
|
|
144
|
+
<title>Swagger UI: OAuth2 Redirect</title>
|
|
145
|
+
</head>
|
|
146
|
+
<body>
|
|
147
|
+
<script>
|
|
148
|
+
'use strict';
|
|
149
|
+
function run () {
|
|
150
|
+
var oauth2 = window.opener.swaggerUIRedirectOauth2;
|
|
151
|
+
var sentState = oauth2.state;
|
|
152
|
+
var redirectUrl = oauth2.redirectUrl;
|
|
153
|
+
var isValid, qp, arr;
|
|
154
|
+
|
|
155
|
+
if (/code|token|error/.test(window.location.hash)) {
|
|
156
|
+
qp = window.location.hash.substring(1).replace('?', '&');
|
|
157
|
+
} else {
|
|
158
|
+
qp = location.search.substring(1);
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
arr = qp.split("&");
|
|
162
|
+
arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';});
|
|
163
|
+
qp = qp ? JSON.parse('{' + arr.join() + '}',
|
|
164
|
+
function (key, value) {
|
|
165
|
+
return key === "" ? value : decodeURIComponent(value);
|
|
166
|
+
}
|
|
167
|
+
) : {};
|
|
168
|
+
|
|
169
|
+
isValid = qp.state === sentState;
|
|
170
|
+
|
|
171
|
+
if ((
|
|
172
|
+
oauth2.auth.schema.get("flow") === "accessCode" ||
|
|
173
|
+
oauth2.auth.schema.get("flow") === "authorizationCode" ||
|
|
174
|
+
oauth2.auth.schema.get("flow") === "authorization_code"
|
|
175
|
+
) && !oauth2.auth.code) {
|
|
176
|
+
if (!isValid) {
|
|
177
|
+
oauth2.errCb({
|
|
178
|
+
authId: oauth2.auth.name,
|
|
179
|
+
source: "auth",
|
|
180
|
+
level: "warning",
|
|
181
|
+
message: "Authorization may be unsafe, passed state was changed in server. The passed state wasn't returned from auth server."
|
|
182
|
+
});
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
if (qp.code) {
|
|
186
|
+
delete oauth2.state;
|
|
187
|
+
oauth2.auth.code = qp.code;
|
|
188
|
+
oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
|
|
189
|
+
} else {
|
|
190
|
+
let oauthErrorMsg;
|
|
191
|
+
if (qp.error) {
|
|
192
|
+
oauthErrorMsg = "["+qp.error+"]: " +
|
|
193
|
+
(qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
|
|
194
|
+
(qp.error_uri ? "More info: "+qp.error_uri : "");
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
oauth2.errCb({
|
|
198
|
+
authId: oauth2.auth.name,
|
|
199
|
+
source: "auth",
|
|
200
|
+
level: "error",
|
|
201
|
+
message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server."
|
|
202
|
+
});
|
|
203
|
+
}
|
|
204
|
+
} else {
|
|
205
|
+
oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
|
|
206
|
+
}
|
|
207
|
+
window.close();
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
if (document.readyState !== 'loading') {
|
|
211
|
+
run();
|
|
212
|
+
} else {
|
|
213
|
+
document.addEventListener('DOMContentLoaded', function () {
|
|
214
|
+
run();
|
|
215
|
+
});
|
|
216
|
+
}
|
|
217
|
+
</script>
|
|
218
|
+
</body>
|
|
219
|
+
</html>
|
|
220
|
+
""" # noqa: E501
|
|
221
|
+
return HTMLResponse(content=html)
|