arize-phoenix 4.10.2rc2__py3-none-any.whl → 4.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (30):
  1. {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/METADATA +3 -4
  2. {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/RECORD +29 -29
  3. phoenix/server/api/context.py +7 -3
  4. phoenix/server/api/openapi/main.py +2 -18
  5. phoenix/server/api/openapi/schema.py +12 -12
  6. phoenix/server/api/routers/v1/__init__.py +83 -36
  7. phoenix/server/api/routers/v1/dataset_examples.py +123 -102
  8. phoenix/server/api/routers/v1/datasets.py +507 -389
  9. phoenix/server/api/routers/v1/evaluations.py +66 -73
  10. phoenix/server/api/routers/v1/experiment_evaluations.py +91 -67
  11. phoenix/server/api/routers/v1/experiment_runs.py +155 -97
  12. phoenix/server/api/routers/v1/experiments.py +181 -131
  13. phoenix/server/api/routers/v1/spans.py +173 -143
  14. phoenix/server/api/routers/v1/traces.py +128 -114
  15. phoenix/server/api/types/Span.py +1 -0
  16. phoenix/server/app.py +176 -148
  17. phoenix/server/openapi/docs.py +221 -0
  18. phoenix/server/static/index.js +574 -573
  19. phoenix/server/thread_server.py +2 -2
  20. phoenix/session/client.py +5 -0
  21. phoenix/session/data_extractor.py +20 -1
  22. phoenix/session/session.py +4 -0
  23. phoenix/trace/attributes.py +2 -1
  24. phoenix/trace/schemas.py +1 -0
  25. phoenix/trace/span_json_decoder.py +1 -1
  26. phoenix/version.py +1 -1
  27. phoenix/server/api/routers/v1/utils.py +0 -94
  28. {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/WHEEL +0 -0
  29. {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/licenses/IP_NOTICE +0 -0
  30. {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/app.py CHANGED
@@ -19,24 +19,25 @@ from typing import (
19
19
  )
20
20
 
21
21
  import strawberry
22
- from fastapi import APIRouter, FastAPI
23
- from fastapi.middleware.gzip import GZipMiddleware
24
- from fastapi.responses import FileResponse
25
- from fastapi.utils import is_body_allowed_for_status_code
26
22
  from sqlalchemy.ext.asyncio import (
27
23
  AsyncEngine,
28
24
  AsyncSession,
29
25
  async_sessionmaker,
30
26
  )
27
+ from starlette.applications import Starlette
28
+ from starlette.datastructures import QueryParams
29
+ from starlette.endpoints import HTTPEndpoint
31
30
  from starlette.exceptions import HTTPException
32
31
  from starlette.middleware import Middleware
33
32
  from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
34
33
  from starlette.requests import Request
35
- from starlette.responses import PlainTextResponse, Response
34
+ from starlette.responses import FileResponse, PlainTextResponse, Response
35
+ from starlette.routing import Mount, Route
36
36
  from starlette.staticfiles import StaticFiles
37
37
  from starlette.templating import Jinja2Templates
38
38
  from starlette.types import Scope, StatefulLifespan
39
- from strawberry.fastapi import GraphQLRouter
39
+ from starlette.websockets import WebSocket
40
+ from strawberry.asgi import GraphQL
40
41
  from strawberry.schema import BaseSchema
41
42
  from typing_extensions import TypeAlias
42
43
 
@@ -79,10 +80,11 @@ from phoenix.server.api.dataloaders import (
79
80
  TraceEvaluationsDataLoader,
80
81
  TraceRowIdsDataLoader,
81
82
  )
82
- from phoenix.server.api.routers.v1 import REST_API_VERSION
83
- from phoenix.server.api.routers.v1 import router as v1_router
83
+ from phoenix.server.api.openapi.schema import OPENAPI_SCHEMA_GENERATOR
84
+ from phoenix.server.api.routers.v1 import V1_ROUTES
84
85
  from phoenix.server.api.schema import schema
85
86
  from phoenix.server.grpc_server import GrpcServer
87
+ from phoenix.server.openapi.docs import get_swagger_ui_html
86
88
  from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
87
89
  from phoenix.trace.schemas import Span
88
90
 
@@ -91,8 +93,6 @@ if TYPE_CHECKING:
91
93
 
92
94
  logger = logging.getLogger(__name__)
93
95
 
94
- router = APIRouter(include_in_schema=False)
95
-
96
96
  templates = Jinja2Templates(directory=SERVER_DIR / "templates")
97
97
 
98
98
 
@@ -157,20 +157,116 @@ class HeadersMiddleware(BaseHTTPMiddleware):
157
157
  ProjectRowId: TypeAlias = int
158
158
 
159
159
 
160
- @router.get("/exports")
161
- async def download_exported_file(request: Request, filename: str) -> FileResponse:
162
- file = request.app.state.export_path / (filename + ".parquet")
163
- if not file.is_file():
164
- raise HTTPException(status_code=404)
165
- return FileResponse(
166
- path=file,
167
- filename=file.name,
168
- media_type="application/x-octet-stream",
169
- )
160
class GraphQLWithContext(GraphQL):  # type: ignore
    """Strawberry ASGI GraphQL app that builds a new ``Context`` for every request.

    The database session factory, model/corpus, export directory, streaming
    timestamp callback, dataloader caches and read-only flag are captured once at
    construction time; ``get_context`` combines them with the incoming
    request/response and a freshly created set of dataloaders.
    """

    def __init__(
        self,
        schema: BaseSchema,
        db: Callable[[], AsyncContextManager[AsyncSession]],
        model: Model,
        export_path: Path,
        graphiql: bool = False,
        corpus: Optional[Model] = None,
        streaming_last_updated_at: Callable[[ProjectRowId], Optional[datetime]] = lambda _: None,
        cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
        read_only: bool = False,
    ) -> None:
        # Application-wide state reused by every call to ``get_context``.
        self.db = db
        self.model = model
        self.corpus = corpus
        self.export_path = export_path
        self.streaming_last_updated_at = streaming_last_updated_at
        self.cache_for_dataloaders = cache_for_dataloaders
        self.read_only = read_only
        super().__init__(schema, graphiql=graphiql)

    async def get_context(
        self,
        request: Union[Request, WebSocket],
        response: Optional[Response] = None,
    ) -> Context:
        """Assemble the per-request ``Context`` handed to GraphQL resolvers.

        New dataloader instances are created on each call. Loaders that accept a
        ``cache_map`` receive the matching attribute of ``cache_for_dataloaders``
        when that object is truthy, and ``None`` otherwise.
        """
        db = self.db
        caches = self.cache_for_dataloaders
        loaders = DataLoaders(
            average_experiment_run_latency=AverageExperimentRunLatencyDataLoader(db),
            dataset_example_revisions=DatasetExampleRevisionsDataLoader(db),
            dataset_example_spans=DatasetExampleSpansDataLoader(db),
            document_evaluation_summaries=DocumentEvaluationSummaryDataLoader(
                db,
                cache_map=caches.document_evaluation_summary if caches else None,
            ),
            document_evaluations=DocumentEvaluationsDataLoader(db),
            document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(db),
            evaluation_summaries=EvaluationSummaryDataLoader(
                db,
                cache_map=caches.evaluation_summary if caches else None,
            ),
            experiment_annotation_summaries=ExperimentAnnotationSummaryDataLoader(db),
            experiment_error_rates=ExperimentErrorRatesDataLoader(db),
            experiment_run_counts=ExperimentRunCountsDataLoader(db),
            experiment_sequence_number=ExperimentSequenceNumberDataLoader(db),
            latency_ms_quantile=LatencyMsQuantileDataLoader(
                db,
                cache_map=caches.latency_ms_quantile if caches else None,
            ),
            min_start_or_max_end_times=MinStartOrMaxEndTimeDataLoader(
                db,
                cache_map=caches.min_start_or_max_end_time if caches else None,
            ),
            record_counts=RecordCountDataLoader(
                db,
                cache_map=caches.record_count if caches else None,
            ),
            span_descendants=SpanDescendantsDataLoader(db),
            span_evaluations=SpanEvaluationsDataLoader(db),
            span_projects=SpanProjectsDataLoader(db),
            token_counts=TokenCountDataLoader(
                db,
                cache_map=caches.token_count if caches else None,
            ),
            trace_evaluations=TraceEvaluationsDataLoader(db),
            trace_row_ids=TraceRowIdsDataLoader(db),
            project_by_name=ProjectByNameDataLoader(db),
            span_annotations=SpanAnnotationsDataLoader(db),
        )
        return Context(
            request=request,
            response=response,
            db=db,
            model=self.model,
            corpus=self.corpus,
            export_path=self.export_path,
            streaming_last_updated_at=self.streaming_last_updated_at,
            data_loaders=loaders,
            cache_for_dataloaders=caches,
            read_only=self.read_only,
        )
252
+
170
253
 
254
class Download(HTTPEndpoint):
    """Serve a previously exported ``.parquet`` file as a download.

    ``path`` is the directory files are served from; ``create_app`` binds it by
    building a subclass with ``type("DownloadExports", (Download,), {"path": ...})``.
    """

    # Directory containing the exported files (set on the generated subclass).
    path: Path

    async def get(self, request: Request) -> FileResponse:
        """Return ``<path>/<filename>.parquet`` for the ``filename`` query parameter.

        Raises:
            HTTPException: 404 when the target does not exist, is not a regular
                file, cannot be resolved, or lies outside ``path``.
        """
        params = QueryParams(request.query_params)
        filename = params.get("filename", "")
        try:
            export_dir = self.path.resolve()
            # ``filename`` is client-controlled: resolve ``..`` segments and
            # absolute paths (which replace the base when joined) so that the
            # containment check below sees the real target.
            file = (export_dir / (filename + ".parquet")).resolve()
        except (OSError, ValueError, RuntimeError):
            # e.g. embedded NUL bytes, names the OS rejects, or symlink loops
            raise HTTPException(status_code=404) from None
        if export_dir not in file.parents or not file.is_file():
            raise HTTPException(status_code=404)
        return FileResponse(
            path=file,
            filename=file.name,
            media_type="application/x-octet-stream",
        )
267
+
268
+
269
async def version(_: Request) -> PlainTextResponse:
    """Report the installed arize-phoenix version as a plain-text body."""
    version_text = f"{phoenix.__version__}"
    return PlainTextResponse(version_text)
175
271
 
176
272
 
@@ -192,9 +288,9 @@ def _lifespan(
192
288
  enable_prometheus: bool = False,
193
289
  clean_ups: Iterable[Callable[[], None]] = (),
194
290
  read_only: bool = False,
195
- ) -> StatefulLifespan[FastAPI]:
291
+ ) -> StatefulLifespan[Starlette]:
196
292
  @contextlib.asynccontextmanager
197
- async def lifespan(_: FastAPI) -> AsyncIterator[Dict[str, Any]]:
293
+ async def lifespan(_: Starlette) -> AsyncIterator[Dict[str, Any]]:
198
294
  async with bulk_inserter as (
199
295
  queue_span,
200
296
  queue_evaluation,
@@ -216,89 +312,16 @@ def _lifespan(
216
312
  return lifespan
217
313
 
218
314
 
219
- @router.get("/healthz")
220
315
async def check_healthz(_: Request) -> PlainTextResponse:
    """Health-check endpoint: unconditionally answers with the text ``OK``."""
    status_text = "OK"
    return PlainTextResponse(status_text)
222
317
 
223
318
 
224
- def create_graphql_router(
225
- *,
226
- schema: BaseSchema,
227
- db: Callable[[], AsyncContextManager[AsyncSession]],
228
- model: Model,
229
- export_path: Path,
230
- corpus: Optional[Model] = None,
231
- streaming_last_updated_at: Callable[[ProjectRowId], Optional[datetime]] = lambda _: None,
232
- cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
233
- read_only: bool = False,
234
- ) -> GraphQLRouter: # type: ignore[type-arg]
235
- context = Context(
236
- db=db,
237
- model=model,
238
- corpus=corpus,
239
- export_path=export_path,
240
- streaming_last_updated_at=streaming_last_updated_at,
241
- data_loaders=DataLoaders(
242
- average_experiment_run_latency=AverageExperimentRunLatencyDataLoader(db),
243
- dataset_example_revisions=DatasetExampleRevisionsDataLoader(db),
244
- dataset_example_spans=DatasetExampleSpansDataLoader(db),
245
- document_evaluation_summaries=DocumentEvaluationSummaryDataLoader(
246
- db,
247
- cache_map=cache_for_dataloaders.document_evaluation_summary
248
- if cache_for_dataloaders
249
- else None,
250
- ),
251
- document_evaluations=DocumentEvaluationsDataLoader(db),
252
- document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(db),
253
- evaluation_summaries=EvaluationSummaryDataLoader(
254
- db,
255
- cache_map=cache_for_dataloaders.evaluation_summary
256
- if cache_for_dataloaders
257
- else None,
258
- ),
259
- experiment_annotation_summaries=ExperimentAnnotationSummaryDataLoader(db),
260
- experiment_error_rates=ExperimentErrorRatesDataLoader(db),
261
- experiment_run_counts=ExperimentRunCountsDataLoader(db),
262
- experiment_sequence_number=ExperimentSequenceNumberDataLoader(db),
263
- latency_ms_quantile=LatencyMsQuantileDataLoader(
264
- db,
265
- cache_map=cache_for_dataloaders.latency_ms_quantile
266
- if cache_for_dataloaders
267
- else None,
268
- ),
269
- min_start_or_max_end_times=MinStartOrMaxEndTimeDataLoader(
270
- db,
271
- cache_map=cache_for_dataloaders.min_start_or_max_end_time
272
- if cache_for_dataloaders
273
- else None,
274
- ),
275
- record_counts=RecordCountDataLoader(
276
- db,
277
- cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
278
- ),
279
- span_annotations=SpanAnnotationsDataLoader(db),
280
- span_descendants=SpanDescendantsDataLoader(db),
281
- span_evaluations=SpanEvaluationsDataLoader(db),
282
- span_projects=SpanProjectsDataLoader(db),
283
- token_counts=TokenCountDataLoader(
284
- db,
285
- cache_map=cache_for_dataloaders.token_count if cache_for_dataloaders else None,
286
- ),
287
- trace_evaluations=TraceEvaluationsDataLoader(db),
288
- trace_row_ids=TraceRowIdsDataLoader(db),
289
- project_by_name=ProjectByNameDataLoader(db),
290
- ),
291
- cache_for_dataloaders=cache_for_dataloaders,
292
- read_only=read_only,
293
- )
319
async def openapi_schema(request: Request) -> Response:
    """Serve the OpenAPI document produced by ``OPENAPI_SCHEMA_GENERATOR`` for this request."""
    schema_response = OPENAPI_SCHEMA_GENERATOR.OpenAPIResponse(request=request)
    return schema_response
294
321
 
295
- return GraphQLRouter(
296
- schema,
297
- graphiql=True,
298
- context_getter=lambda: context,
299
- include_in_schema=False,
300
- prefix="/graphql",
301
- )
322
+
323
async def api_docs(request: Request) -> Response:
    """Serve the Swagger UI page for the REST API, loading its schema from ``/schema``."""
    return get_swagger_ui_html(
        openapi_url="/schema",
        title="arize-phoenix API",
        # Hide the "Schemas" (models) section of the Swagger UI, preserving the
        # behavior of the previous FastAPI configuration.
        swagger_ui_parameters={"defaultModelsExpandDepth": -1},
    )
302
325
 
303
326
 
304
327
  class SessionFactory:
@@ -349,18 +372,6 @@ def instrument_engine_if_enabled(engine: AsyncEngine) -> List[Callable[[], None]
349
372
  return instrumentation_cleanups
350
373
 
351
374
 
352
- async def plain_text_http_exception_handler(request: Request, exc: HTTPException) -> Response:
353
- """
354
- Overrides the default handler for HTTPExceptions to return a plain text
355
- response instead of a JSON response. For the original source code, see
356
- https://github.com/tiangolo/fastapi/blob/d3cdd3bbd14109f3b268df7ca496e24bb64593aa/fastapi/exception_handlers.py#L11
357
- """
358
- headers = getattr(exc, "headers", None)
359
- if not is_body_allowed_for_status_code(exc.status_code):
360
- return Response(status_code=exc.status_code, headers=headers)
361
- return PlainTextResponse(str(exc.detail), status_code=exc.status_code, headers=headers)
362
-
363
-
364
375
  def create_app(
365
376
  db: SessionFactory,
366
377
  export_path: Path,
@@ -374,7 +385,7 @@ def create_app(
374
385
  initial_evaluations: Optional[Iterable[pb.Evaluation]] = None,
375
386
  serve_ui: bool = True,
376
387
  clean_up_callbacks: List[Callable[[], None]] = [],
377
- ) -> FastAPI:
388
+ ) -> Starlette:
378
389
  clean_ups: List[Callable[[], None]] = clean_up_callbacks # To be called at app shutdown.
379
390
  initial_batch_of_spans: Iterable[Tuple[Span, str]] = (
380
391
  ()
@@ -399,6 +410,7 @@ def create_app(
399
410
  tracer_provider = None
400
411
  strawberry_extensions = schema.get_extensions()
401
412
  if server_instrumentation_is_enabled():
413
+ tracer_provider = initialize_opentelemetry_tracer_provider()
402
414
  from opentelemetry.trace import TracerProvider
403
415
  from strawberry.extensions.tracing import OpenTelemetryExtension
404
416
 
@@ -416,7 +428,7 @@ def create_app(
416
428
 
417
429
  strawberry_extensions.append(_OpenTelemetryExtension)
418
430
 
419
- graphql_router = create_graphql_router(
431
+ graphql = GraphQLWithContext(
420
432
  db=db,
421
433
  schema=strawberry.Schema(
422
434
  query=schema.query,
@@ -427,6 +439,7 @@ def create_app(
427
439
  model=model,
428
440
  corpus=corpus,
429
441
  export_path=export_path,
442
+ graphiql=True,
430
443
  streaming_last_updated_at=bulk_inserter.last_updated_at,
431
444
  cache_for_dataloaders=cache_for_dataloaders,
432
445
  read_only=read_only,
@@ -437,9 +450,7 @@ def create_app(
437
450
  prometheus_middlewares = [Middleware(PrometheusMiddleware)]
438
451
  else:
439
452
  prometheus_middlewares = []
440
- app = FastAPI(
441
- title="Arize-Phoenix REST API",
442
- version=REST_API_VERSION,
453
+ app = Starlette(
443
454
  lifespan=_lifespan(
444
455
  read_only=read_only,
445
456
  bulk_inserter=bulk_inserter,
@@ -451,39 +462,56 @@ def create_app(
451
462
  Middleware(HeadersMiddleware),
452
463
  *prometheus_middlewares,
453
464
  ],
454
- exception_handlers={HTTPException: plain_text_http_exception_handler},
455
465
  debug=debug,
456
- swagger_ui_parameters={
457
- "defaultModelsExpandDepth": -1, # hides the schema section in the Swagger UI
458
- },
459
- )
460
- app.state.read_only = read_only
461
- app.state.export_path = export_path
462
- app.include_router(v1_router)
463
- app.include_router(router)
464
- app.include_router(graphql_router)
465
- app.add_middleware(GZipMiddleware)
466
- if serve_ui:
467
- app.mount(
468
- "/",
469
- app=Static(
470
- directory=SERVER_DIR / "static",
471
- app_config=AppConfig(
472
- has_inferences=model.is_empty is not True,
473
- has_corpus=corpus is not None,
474
- min_dist=umap_params.min_dist,
475
- n_neighbors=umap_params.n_neighbors,
476
- n_samples=umap_params.n_samples,
466
+ routes=V1_ROUTES
467
+ + [
468
+ Route("/schema", endpoint=openapi_schema, include_in_schema=False),
469
+ Route("/arize_phoenix_version", version),
470
+ Route("/healthz", check_healthz),
471
+ Route(
472
+ "/exports",
473
+ type(
474
+ "DownloadExports",
475
+ (Download,),
476
+ {"path": export_path},
477
477
  ),
478
478
  ),
479
- name="static",
480
- )
481
-
479
+ Route(
480
+ "/docs",
481
+ api_docs,
482
+ ),
483
+ Route(
484
+ "/graphql",
485
+ graphql,
486
+ ),
487
+ ]
488
+ + (
489
+ [
490
+ Mount(
491
+ "/",
492
+ app=Static(
493
+ directory=SERVER_DIR / "static",
494
+ app_config=AppConfig(
495
+ has_inferences=model.is_empty is not True,
496
+ has_corpus=corpus is not None,
497
+ min_dist=umap_params.min_dist,
498
+ n_neighbors=umap_params.n_neighbors,
499
+ n_samples=umap_params.n_samples,
500
+ ),
501
+ ),
502
+ name="static",
503
+ ),
504
+ ]
505
+ if serve_ui
506
+ else []
507
+ ),
508
+ )
509
+ app.state.read_only = read_only
482
510
  app.state.db = db
483
511
  if tracer_provider:
484
- from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
512
+ from opentelemetry.instrumentation.starlette import StarletteInstrumentor
485
513
 
486
- FastAPIInstrumentor().instrument(tracer_provider=tracer_provider)
487
- FastAPIInstrumentor.instrument_app(app, tracer_provider=tracer_provider)
488
- clean_ups.append(FastAPIInstrumentor().uninstrument)
514
+ StarletteInstrumentor().instrument(tracer_provider=tracer_provider)
515
+ StarletteInstrumentor.instrument_app(app, tracer_provider=tracer_provider)
516
+ clean_ups.append(StarletteInstrumentor().uninstrument)
489
517
  return app
@@ -0,0 +1,221 @@
1
+ import json
2
+ from typing import Any, Dict, Optional
3
+
4
+ from starlette.responses import HTMLResponse
5
+
6
# Baseline SwaggerUIBundle options; get_swagger_ui_html copies these and then
# overlays any caller-supplied swagger_ui_parameters on top.
swagger_ui_default_parameters: Dict[str, Any] = dict(
    dom_id="#swagger-ui",
    layout="BaseLayout",
    deepLinking=True,
    showExtensions=True,
    showCommonExtensions=True,
)
13
+
14
+
15
def get_swagger_ui_html(
    *,
    openapi_url: str = "/schema",
    title: str,
    swagger_js_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui-bundle.js",
    swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui.css",
    swagger_favicon_url: str = "/favicon.ico",
    oauth2_redirect_url: Optional[str] = None,
    init_oauth: Optional[Dict[str, Any]] = None,
    swagger_ui_parameters: Optional[Dict[str, Any]] = None,
) -> HTMLResponse:
    """
    Generate and return the HTML that loads Swagger UI for the interactive API
    docs (normally served at `/docs`).

    Args:
        openapi_url: URL of the OpenAPI schema Swagger UI should load.
        title: Page title (HTML-escaped before insertion).
        swagger_js_url: URL of the Swagger UI JavaScript bundle.
        swagger_css_url: URL of the Swagger UI stylesheet.
        swagger_favicon_url: URL of the page favicon.
        oauth2_redirect_url: Path appended to ``window.location.origin`` to form
            the OAuth2 redirect URL; omitted from the config when falsy.
        init_oauth: Options passed to ``ui.initOAuth``; the call is emitted only
            when this is truthy. It is serialized with ``json.dumps``, so it must
            be JSON-serializable.
        swagger_ui_parameters: Extra SwaggerUIBundle options layered on top of
            ``swagger_ui_default_parameters``.

    Returns:
        An ``HTMLResponse`` containing the Swagger UI page.
    """
    from html import escape

    # Caller overrides win; overridden keys keep the defaults' order and new
    # keys are appended, exactly like ``dict.copy()`` followed by ``update()``.
    parameters: Dict[str, Any] = {
        **swagger_ui_default_parameters,
        **(swagger_ui_parameters or {}),
    }
    # Everything placed inside the JavaScript config is JSON-encoded, so quotes
    # or backslashes in a value cannot terminate a string literal early.
    config = "".join(
        f"{json.dumps(key)}: {json.dumps(value)},\n" for key, value in parameters.items()
    )
    if oauth2_redirect_url:
        config += (
            f"oauth2RedirectUrl: window.location.origin + {json.dumps(oauth2_redirect_url)},\n"
        )
    oauth_init = f"ui.initOAuth({json.dumps(init_oauth)})" if init_oauth else ""

    page = f"""
    <!DOCTYPE html>
    <html>
    <head>
    <link type="text/css" rel="stylesheet" href="{escape(swagger_css_url)}">
    <link rel="shortcut icon" href="{escape(swagger_favicon_url)}">
    <title>{escape(title)}</title>
    </head>
    <body>
    <div id="swagger-ui">
    </div>
    <script src="{escape(swagger_js_url)}"></script>
    <style type="text/css">
    div[id^="operations-private"]{{display:none}} #operations-tag-private{{display:none}}
    </style>
    <!-- `SwaggerUIBundle` is now available on the page -->
    <script>
    const ui = SwaggerUIBundle({{
        url: {json.dumps(openapi_url)},
    {config}
        presets: [
            SwaggerUIBundle.presets.apis,
            SwaggerUIBundle.SwaggerUIStandalonePreset
        ],
    }})
    {oauth_init}
    </script>
    </body>
    </html>
    """
    return HTMLResponse(page)
79
+
80
+
81
def get_redoc_html(
    *,
    openapi_url: str,
    title: str,
    redoc_js_url: str = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js",
    redoc_favicon_url: str = "/favicon.ico",
    with_google_fonts: bool = True,
) -> HTMLResponse:
    """
    Generate and return the HTML response that loads ReDoc for the alternative
    API docs (normally served at `/redoc`).

    Args:
        openapi_url: URL of the OpenAPI schema ReDoc should render.
        title: Page title.
        redoc_js_url: URL of the ReDoc standalone bundle.
        redoc_favicon_url: URL of the page favicon.
        with_google_fonts: Whether to include the Montserrat/Roboto stylesheet
            from Google Fonts.

    All caller-supplied strings are HTML-escaped before being placed in the page.
    """
    from html import escape

    page = f"""
    <!DOCTYPE html>
    <html>
    <head>
    <title>{escape(title)}</title>
    <!-- needed for adaptive design -->
    <meta charset="utf-8"/>
    <meta name="viewport" content="width=device-width, initial-scale=1">
    """
    if with_google_fonts:
        page += """
    <link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
    """  # noqa: E501
    page += f"""
    <link rel="shortcut icon" href="{escape(redoc_favicon_url)}">
    <!--
    ReDoc doesn't change outer page styles
    -->
    <style>
      body {{
        margin: 0;
        padding: 0;
      }}
    </style>
    </head>
    <body>
    <noscript>
        ReDoc requires Javascript to function. Please enable it to browse the documentation.
    </noscript>
    <redoc spec-url="{escape(openapi_url)}"></redoc>
    <script src="{escape(redoc_js_url)}"> </script>
    </body>
    </html>
    """
    return HTMLResponse(page)
130
+
131
+
132
# Currently unused; kept for reference in case OAuth2 support is added to the
# Swagger UI docs page.
def get_swagger_ui_oauth2_redirect_html() -> HTMLResponse:
    """Return the static page Swagger UI uses as its OAuth2 redirect target.

    The embedded script reads the authorization response from the URL fragment
    (or the query string), passes it back through the opener window's
    ``swaggerUIRedirectOauth2`` callbacks, and then closes the window.
    """
    # Markup copied from
    # https://github.com/swagger-api/swagger-ui/blob/v4.14.0/dist/oauth2-redirect.html
    opening = """
<!doctype html>
<html lang="en-US">
<head>
    <title>Swagger UI: OAuth2 Redirect</title>
</head>
<body>
"""
    script = """<script>
    'use strict';
    function run () {
        var oauth2 = window.opener.swaggerUIRedirectOauth2;
        var sentState = oauth2.state;
        var redirectUrl = oauth2.redirectUrl;
        var isValid, qp, arr;

        if (/code|token|error/.test(window.location.hash)) {
            qp = window.location.hash.substring(1).replace('?', '&');
        } else {
            qp = location.search.substring(1);
        }

        arr = qp.split("&");
        arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';});
        qp = qp ? JSON.parse('{' + arr.join() + '}',
                function (key, value) {
                    return key === "" ? value : decodeURIComponent(value);
                }
        ) : {};

        isValid = qp.state === sentState;

        if ((
          oauth2.auth.schema.get("flow") === "accessCode" ||
          oauth2.auth.schema.get("flow") === "authorizationCode" ||
          oauth2.auth.schema.get("flow") === "authorization_code"
        ) && !oauth2.auth.code) {
            if (!isValid) {
                oauth2.errCb({
                    authId: oauth2.auth.name,
                    source: "auth",
                    level: "warning",
                    message: "Authorization may be unsafe, passed state was changed in server. The passed state wasn't returned from auth server."
                });
            }

            if (qp.code) {
                delete oauth2.state;
                oauth2.auth.code = qp.code;
                oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
            } else {
                let oauthErrorMsg;
                if (qp.error) {
                    oauthErrorMsg = "["+qp.error+"]: " +
                        (qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
                        (qp.error_uri ? "More info: "+qp.error_uri : "");
                }

                oauth2.errCb({
                    authId: oauth2.auth.name,
                    source: "auth",
                    level: "error",
                    message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server."
                });
            }
        } else {
            oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
        }
        window.close();
    }

    if (document.readyState !== 'loading') {
        run();
    } else {
        document.addEventListener('DOMContentLoaded', function () {
            run();
        });
    }
</script>
"""  # noqa: E501
    closing = """</body>
</html>
"""
    return HTMLResponse(content="".join((opening, script, closing)))