arize-phoenix 4.4.4rc6__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; the original page links to more details.

Files changed (123)
  1. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/METADATA +8 -14
  2. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/RECORD +58 -122
  3. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +27 -0
  5. phoenix/config.py +7 -42
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +62 -64
  8. phoenix/core/model_schema_adapter.py +25 -27
  9. phoenix/datetime_utils.py +0 -4
  10. phoenix/db/bulk_inserter.py +14 -54
  11. phoenix/db/insertion/evaluation.py +10 -10
  12. phoenix/db/insertion/helpers.py +14 -17
  13. phoenix/db/insertion/span.py +3 -3
  14. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
  15. phoenix/db/models.py +4 -236
  16. phoenix/inferences/fixtures.py +23 -23
  17. phoenix/inferences/inferences.py +7 -7
  18. phoenix/inferences/validation.py +1 -1
  19. phoenix/server/api/context.py +0 -20
  20. phoenix/server/api/dataloaders/__init__.py +0 -20
  21. phoenix/server/api/dataloaders/span_descendants.py +3 -2
  22. phoenix/server/api/routers/v1/__init__.py +2 -77
  23. phoenix/server/api/routers/v1/evaluations.py +13 -8
  24. phoenix/server/api/routers/v1/spans.py +5 -9
  25. phoenix/server/api/routers/v1/traces.py +4 -1
  26. phoenix/server/api/schema.py +303 -2
  27. phoenix/server/api/types/Cluster.py +19 -19
  28. phoenix/server/api/types/Dataset.py +63 -282
  29. phoenix/server/api/types/DatasetRole.py +23 -0
  30. phoenix/server/api/types/Dimension.py +29 -30
  31. phoenix/server/api/types/EmbeddingDimension.py +34 -40
  32. phoenix/server/api/types/Event.py +16 -16
  33. phoenix/server/api/{mutations/export_events_mutations.py → types/ExportEventsMutation.py} +14 -17
  34. phoenix/server/api/types/Model.py +42 -43
  35. phoenix/server/api/types/Project.py +12 -26
  36. phoenix/server/api/types/Span.py +2 -79
  37. phoenix/server/api/types/TimeSeries.py +6 -6
  38. phoenix/server/api/types/Trace.py +4 -15
  39. phoenix/server/api/types/UMAPPoints.py +1 -1
  40. phoenix/server/api/types/node.py +111 -5
  41. phoenix/server/api/types/pagination.py +52 -10
  42. phoenix/server/app.py +49 -103
  43. phoenix/server/main.py +27 -49
  44. phoenix/server/openapi/docs.py +0 -3
  45. phoenix/server/static/index.js +1384 -2390
  46. phoenix/server/templates/index.html +0 -1
  47. phoenix/services.py +15 -15
  48. phoenix/session/client.py +23 -611
  49. phoenix/session/session.py +37 -47
  50. phoenix/trace/exporter.py +9 -14
  51. phoenix/trace/fixtures.py +7 -133
  52. phoenix/trace/schemas.py +2 -1
  53. phoenix/trace/span_evaluations.py +3 -3
  54. phoenix/trace/trace_dataset.py +6 -6
  55. phoenix/version.py +1 -1
  56. phoenix/db/insertion/dataset.py +0 -237
  57. phoenix/db/migrations/types.py +0 -29
  58. phoenix/db/migrations/versions/10460e46d750_datasets.py +0 -291
  59. phoenix/experiments/__init__.py +0 -6
  60. phoenix/experiments/evaluators/__init__.py +0 -29
  61. phoenix/experiments/evaluators/base.py +0 -153
  62. phoenix/experiments/evaluators/code_evaluators.py +0 -99
  63. phoenix/experiments/evaluators/llm_evaluators.py +0 -244
  64. phoenix/experiments/evaluators/utils.py +0 -189
  65. phoenix/experiments/functions.py +0 -616
  66. phoenix/experiments/tracing.py +0 -85
  67. phoenix/experiments/types.py +0 -722
  68. phoenix/experiments/utils.py +0 -9
  69. phoenix/server/api/dataloaders/average_experiment_run_latency.py +0 -54
  70. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -100
  71. phoenix/server/api/dataloaders/dataset_example_spans.py +0 -43
  72. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +0 -85
  73. phoenix/server/api/dataloaders/experiment_error_rates.py +0 -43
  74. phoenix/server/api/dataloaders/experiment_run_counts.py +0 -42
  75. phoenix/server/api/dataloaders/experiment_sequence_number.py +0 -49
  76. phoenix/server/api/dataloaders/project_by_name.py +0 -31
  77. phoenix/server/api/dataloaders/span_projects.py +0 -33
  78. phoenix/server/api/dataloaders/trace_row_ids.py +0 -39
  79. phoenix/server/api/helpers/dataset_helpers.py +0 -179
  80. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +0 -16
  81. phoenix/server/api/input_types/AddSpansToDatasetInput.py +0 -14
  82. phoenix/server/api/input_types/ClearProjectInput.py +0 -15
  83. phoenix/server/api/input_types/CreateDatasetInput.py +0 -12
  84. phoenix/server/api/input_types/DatasetExampleInput.py +0 -14
  85. phoenix/server/api/input_types/DatasetSort.py +0 -17
  86. phoenix/server/api/input_types/DatasetVersionSort.py +0 -16
  87. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +0 -13
  88. phoenix/server/api/input_types/DeleteDatasetInput.py +0 -7
  89. phoenix/server/api/input_types/DeleteExperimentsInput.py +0 -9
  90. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +0 -35
  91. phoenix/server/api/input_types/PatchDatasetInput.py +0 -14
  92. phoenix/server/api/mutations/__init__.py +0 -13
  93. phoenix/server/api/mutations/auth.py +0 -11
  94. phoenix/server/api/mutations/dataset_mutations.py +0 -520
  95. phoenix/server/api/mutations/experiment_mutations.py +0 -65
  96. phoenix/server/api/mutations/project_mutations.py +0 -47
  97. phoenix/server/api/openapi/__init__.py +0 -0
  98. phoenix/server/api/openapi/main.py +0 -6
  99. phoenix/server/api/openapi/schema.py +0 -16
  100. phoenix/server/api/queries.py +0 -503
  101. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  102. phoenix/server/api/routers/v1/datasets.py +0 -965
  103. phoenix/server/api/routers/v1/experiment_evaluations.py +0 -65
  104. phoenix/server/api/routers/v1/experiment_runs.py +0 -96
  105. phoenix/server/api/routers/v1/experiments.py +0 -174
  106. phoenix/server/api/types/AnnotatorKind.py +0 -10
  107. phoenix/server/api/types/CreateDatasetPayload.py +0 -8
  108. phoenix/server/api/types/DatasetExample.py +0 -85
  109. phoenix/server/api/types/DatasetExampleRevision.py +0 -34
  110. phoenix/server/api/types/DatasetVersion.py +0 -14
  111. phoenix/server/api/types/ExampleRevisionInterface.py +0 -14
  112. phoenix/server/api/types/Experiment.py +0 -147
  113. phoenix/server/api/types/ExperimentAnnotationSummary.py +0 -13
  114. phoenix/server/api/types/ExperimentComparison.py +0 -19
  115. phoenix/server/api/types/ExperimentRun.py +0 -91
  116. phoenix/server/api/types/ExperimentRunAnnotation.py +0 -57
  117. phoenix/server/api/types/Inferences.py +0 -80
  118. phoenix/server/api/types/InferencesRole.py +0 -23
  119. phoenix/utilities/json.py +0 -61
  120. phoenix/utilities/re.py +0 -50
  121. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/IP_NOTICE +0 -0
  122. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/LICENSE +0 -0
  123. /phoenix/server/api/{helpers/__init__.py → helpers.py} +0 -0
@@ -11,54 +11,34 @@ from typing_extensions import TypeAlias
11
11
 
12
12
  from phoenix.core.model_schema import Model
13
13
  from phoenix.server.api.dataloaders import (
14
- AverageExperimentRunLatencyDataLoader,
15
14
  CacheForDataLoaders,
16
- DatasetExampleRevisionsDataLoader,
17
- DatasetExampleSpansDataLoader,
18
15
  DocumentEvaluationsDataLoader,
19
16
  DocumentEvaluationSummaryDataLoader,
20
17
  DocumentRetrievalMetricsDataLoader,
21
18
  EvaluationSummaryDataLoader,
22
- ExperimentAnnotationSummaryDataLoader,
23
- ExperimentErrorRatesDataLoader,
24
- ExperimentRunCountsDataLoader,
25
- ExperimentSequenceNumberDataLoader,
26
19
  LatencyMsQuantileDataLoader,
27
20
  MinStartOrMaxEndTimeDataLoader,
28
- ProjectByNameDataLoader,
29
21
  RecordCountDataLoader,
30
22
  SpanDescendantsDataLoader,
31
23
  SpanEvaluationsDataLoader,
32
- SpanProjectsDataLoader,
33
24
  TokenCountDataLoader,
34
25
  TraceEvaluationsDataLoader,
35
- TraceRowIdsDataLoader,
36
26
  )
37
27
 
38
28
 
39
29
  @dataclass
40
30
  class DataLoaders:
41
- average_experiment_run_latency: AverageExperimentRunLatencyDataLoader
42
- dataset_example_revisions: DatasetExampleRevisionsDataLoader
43
- dataset_example_spans: DatasetExampleSpansDataLoader
44
31
  document_evaluation_summaries: DocumentEvaluationSummaryDataLoader
45
32
  document_evaluations: DocumentEvaluationsDataLoader
46
33
  document_retrieval_metrics: DocumentRetrievalMetricsDataLoader
47
34
  evaluation_summaries: EvaluationSummaryDataLoader
48
- experiment_annotation_summaries: ExperimentAnnotationSummaryDataLoader
49
- experiment_error_rates: ExperimentErrorRatesDataLoader
50
- experiment_run_counts: ExperimentRunCountsDataLoader
51
- experiment_sequence_number: ExperimentSequenceNumberDataLoader
52
35
  latency_ms_quantile: LatencyMsQuantileDataLoader
53
36
  min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
54
37
  record_counts: RecordCountDataLoader
55
38
  span_descendants: SpanDescendantsDataLoader
56
39
  span_evaluations: SpanEvaluationsDataLoader
57
- span_projects: SpanProjectsDataLoader
58
40
  token_counts: TokenCountDataLoader
59
41
  trace_evaluations: TraceEvaluationsDataLoader
60
- trace_row_ids: TraceRowIdsDataLoader
61
- project_by_name: ProjectByNameDataLoader
62
42
 
63
43
 
64
44
  ProjectRowId: TypeAlias = int
@@ -8,9 +8,6 @@ from phoenix.db.insertion.evaluation import (
8
8
  )
9
9
  from phoenix.db.insertion.span import ClearProjectSpansEvent, SpanInsertionEvent
10
10
 
11
- from .average_experiment_run_latency import AverageExperimentRunLatencyDataLoader
12
- from .dataset_example_revisions import DatasetExampleRevisionsDataLoader
13
- from .dataset_example_spans import DatasetExampleSpansDataLoader
14
11
  from .document_evaluation_summaries import (
15
12
  DocumentEvaluationSummaryCache,
16
13
  DocumentEvaluationSummaryDataLoader,
@@ -18,44 +15,27 @@ from .document_evaluation_summaries import (
18
15
  from .document_evaluations import DocumentEvaluationsDataLoader
19
16
  from .document_retrieval_metrics import DocumentRetrievalMetricsDataLoader
20
17
  from .evaluation_summaries import EvaluationSummaryCache, EvaluationSummaryDataLoader
21
- from .experiment_annotation_summaries import ExperimentAnnotationSummaryDataLoader
22
- from .experiment_error_rates import ExperimentErrorRatesDataLoader
23
- from .experiment_run_counts import ExperimentRunCountsDataLoader
24
- from .experiment_sequence_number import ExperimentSequenceNumberDataLoader
25
18
  from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLoader
26
19
  from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
27
- from .project_by_name import ProjectByNameDataLoader
28
20
  from .record_counts import RecordCountCache, RecordCountDataLoader
29
21
  from .span_descendants import SpanDescendantsDataLoader
30
22
  from .span_evaluations import SpanEvaluationsDataLoader
31
- from .span_projects import SpanProjectsDataLoader
32
23
  from .token_counts import TokenCountCache, TokenCountDataLoader
33
24
  from .trace_evaluations import TraceEvaluationsDataLoader
34
- from .trace_row_ids import TraceRowIdsDataLoader
35
25
 
36
26
  __all__ = [
37
27
  "CacheForDataLoaders",
38
- "AverageExperimentRunLatencyDataLoader",
39
- "DatasetExampleRevisionsDataLoader",
40
- "DatasetExampleSpansDataLoader",
41
28
  "DocumentEvaluationSummaryDataLoader",
42
29
  "DocumentEvaluationsDataLoader",
43
30
  "DocumentRetrievalMetricsDataLoader",
44
31
  "EvaluationSummaryDataLoader",
45
- "ExperimentAnnotationSummaryDataLoader",
46
- "ExperimentErrorRatesDataLoader",
47
- "ExperimentRunCountsDataLoader",
48
- "ExperimentSequenceNumberDataLoader",
49
32
  "LatencyMsQuantileDataLoader",
50
33
  "MinStartOrMaxEndTimeDataLoader",
51
34
  "RecordCountDataLoader",
52
35
  "SpanDescendantsDataLoader",
53
36
  "SpanEvaluationsDataLoader",
54
- "SpanProjectsDataLoader",
55
37
  "TokenCountDataLoader",
56
38
  "TraceEvaluationsDataLoader",
57
- "TraceRowIdsDataLoader",
58
- "ProjectByNameDataLoader",
59
39
  ]
60
40
 
61
41
 
@@ -9,7 +9,7 @@ from typing import (
9
9
  from aioitertools.itertools import groupby
10
10
  from sqlalchemy import select
11
11
  from sqlalchemy.ext.asyncio import AsyncSession
12
- from sqlalchemy.orm import joinedload
12
+ from sqlalchemy.orm import contains_eager
13
13
  from strawberry.dataloader import DataLoader
14
14
  from typing_extensions import TypeAlias
15
15
 
@@ -52,7 +52,8 @@ class SpanDescendantsDataLoader(DataLoader[Key, Result]):
52
52
  stmt = (
53
53
  select(descendant_ids.c[root_id_label], models.Span)
54
54
  .join(descendant_ids, models.Span.id == descendant_ids.c.id)
55
- .options(joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id))
55
+ .join(models.Trace)
56
+ .options(contains_eager(models.Span.trace))
56
57
  .order_by(descendant_ids.c[root_id_label])
57
58
  )
58
59
  results: Dict[SpanId, Result] = {key: [] for key in keys}
@@ -1,40 +1,6 @@
1
- from typing import Any, Awaitable, Callable, Mapping, Tuple
2
-
3
- import wrapt
4
- from starlette import routing
5
- from starlette.requests import Request
6
- from starlette.responses import Response
7
- from starlette.status import HTTP_403_FORBIDDEN
8
-
9
- from . import (
10
- datasets,
11
- evaluations,
12
- experiment_evaluations,
13
- experiment_runs,
14
- experiments,
15
- spans,
16
- traces,
17
- )
18
- from .dataset_examples import list_dataset_examples
19
-
20
-
21
- @wrapt.decorator # type: ignore
22
- async def forbid_if_readonly(
23
- wrapped: Callable[[Request], Awaitable[Response]],
24
- _: Any,
25
- args: Tuple[Request],
26
- kwargs: Mapping[str, Any],
27
- ) -> Response:
28
- request, *_ = args
29
- if request.app.state.read_only:
30
- return Response(status_code=HTTP_403_FORBIDDEN)
31
- return await wrapped(*args, **kwargs)
32
-
33
-
34
- class Route(routing.Route):
35
- def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
36
- super().__init__(path, forbid_if_readonly(endpoint), **kwargs)
1
+ from starlette.routing import Route
37
2
 
3
+ from . import evaluations, spans, traces
38
4
 
39
5
  V1_ROUTES = [
40
6
  Route("/v1/evaluations", evaluations.post_evaluations, methods=["POST"]),
@@ -42,45 +8,4 @@ V1_ROUTES = [
42
8
  Route("/v1/traces", traces.post_traces, methods=["POST"]),
43
9
  Route("/v1/spans", spans.query_spans_handler, methods=["POST"]),
44
10
  Route("/v1/spans", spans.get_spans_handler, methods=["GET"]),
45
- Route("/v1/datasets/upload", datasets.post_datasets_upload, methods=["POST"]),
46
- Route("/v1/datasets", datasets.list_datasets, methods=["GET"]),
47
- Route("/v1/datasets/{id:str}", datasets.get_dataset_by_id, methods=["GET"]),
48
- Route("/v1/datasets/{id:str}/csv", datasets.get_dataset_csv, methods=["GET"]),
49
- Route(
50
- "/v1/datasets/{id:str}/jsonl/openai_ft",
51
- datasets.get_dataset_jsonl_openai_ft,
52
- methods=["GET"],
53
- ),
54
- Route(
55
- "/v1/datasets/{id:str}/jsonl/openai_evals",
56
- datasets.get_dataset_jsonl_openai_evals,
57
- methods=["GET"],
58
- ),
59
- Route("/v1/datasets/{id:str}/examples", list_dataset_examples, methods=["GET"]),
60
- Route("/v1/datasets/{id:str}/versions", datasets.get_dataset_versions, methods=["GET"]),
61
- Route(
62
- "/v1/datasets/{dataset_id:str}/experiments",
63
- experiments.create_experiment,
64
- methods=["POST"],
65
- ),
66
- Route(
67
- "/v1/experiments/{experiment_id:str}",
68
- experiments.read_experiment,
69
- methods=["GET"],
70
- ),
71
- Route(
72
- "/v1/experiments/{experiment_id:str}/runs",
73
- experiment_runs.create_experiment_run,
74
- methods=["POST"],
75
- ),
76
- Route(
77
- "/v1/experiments/{experiment_id:str}/runs",
78
- experiment_runs.list_experiment_runs,
79
- methods=["GET"],
80
- ),
81
- Route(
82
- "/v1/experiment_evaluations",
83
- experiment_evaluations.upsert_experiment_evaluation,
84
- methods=["POST"],
85
- ),
86
11
  ]
@@ -44,7 +44,14 @@ async def post_evaluations(request: Request) -> Response:
44
44
  summary: Add evaluations to a span, trace, or document
45
45
  operationId: addEvaluations
46
46
  tags:
47
- - private
47
+ - evaluations
48
+ parameters:
49
+ - name: project-name
50
+ in: query
51
+ schema:
52
+ type: string
53
+ default: default
54
+ description: The project name to add the evaluation to
48
55
  requestBody:
49
56
  required: true
50
57
  content:
@@ -98,9 +105,9 @@ async def get_evaluations(request: Request) -> Response:
98
105
  summary: Get evaluations from Phoenix
99
106
  operationId: getEvaluation
100
107
  tags:
101
- - private
108
+ - evaluations
102
109
  parameters:
103
- - name: project_name
110
+ - name: project-name
104
111
  in: query
105
112
  schema:
106
113
  type: string
@@ -109,15 +116,13 @@ async def get_evaluations(request: Request) -> Response:
109
116
  responses:
110
117
  200:
111
118
  description: Success
112
- 403:
113
- description: Forbidden
114
119
  404:
115
120
  description: Not found
116
121
  """
117
122
  project_name = (
118
- request.query_params.get("project_name")
119
- or request.query_params.get("project-name") # for backward compatibility
120
- or request.headers.get("project-name") # read from headers for backwards compatibility
123
+ request.query_params.get("project-name")
124
+ # read from headers for backwards compatibility
125
+ or request.headers.get("project-name")
121
126
  or DEFAULT_PROJECT_NAME
122
127
  )
123
128
 
@@ -19,9 +19,9 @@ async def query_spans_handler(request: Request) -> Response:
19
19
  summary: Query spans using query DSL
20
20
  operationId: querySpans
21
21
  tags:
22
- - private
22
+ - spans
23
23
  parameters:
24
- - name: project_name
24
+ - name: project-name
25
25
  in: query
26
26
  schema:
27
27
  type: string
@@ -68,8 +68,6 @@ async def query_spans_handler(request: Request) -> Response:
68
68
  responses:
69
69
  200:
70
70
  description: Success
71
- 403:
72
- description: Forbidden
73
71
  404:
74
72
  description: Not found
75
73
  422:
@@ -78,11 +76,9 @@ async def query_spans_handler(request: Request) -> Response:
78
76
  payload = await request.json()
79
77
  queries = payload.pop("queries", [])
80
78
  project_name = (
81
- request.query_params.get("project_name")
82
- or request.query_params.get("project-name") # for backward compatibility
83
- or request.headers.get(
84
- "project-name"
85
- ) # read from headers/payload for backward-compatibility
79
+ request.query_params.get("project-name")
80
+ # read from headers/payload for backward-compatibility
81
+ or request.headers.get("project-name")
86
82
  or payload.get("project_name")
87
83
  or DEFAULT_PROJECT_NAME
88
84
  )
@@ -11,6 +11,7 @@ from starlette.datastructures import State
11
11
  from starlette.requests import Request
12
12
  from starlette.responses import Response
13
13
  from starlette.status import (
14
+ HTTP_403_FORBIDDEN,
14
15
  HTTP_415_UNSUPPORTED_MEDIA_TYPE,
15
16
  HTTP_422_UNPROCESSABLE_ENTITY,
16
17
  )
@@ -24,7 +25,7 @@ async def post_traces(request: Request) -> Response:
24
25
  summary: Send traces to Phoenix
25
26
  operationId: addTraces
26
27
  tags:
27
- - private
28
+ - traces
28
29
  requestBody:
29
30
  required: true
30
31
  content:
@@ -42,6 +43,8 @@ async def post_traces(request: Request) -> Response:
42
43
  422:
43
44
  description: Request body is invalid
44
45
  """
46
+ if request.app.state.read_only:
47
+ return Response(status_code=HTTP_403_FORBIDDEN)
45
48
  content_type = request.headers.get("content-type")
46
49
  if content_type != "application/x-protobuf":
47
50
  return Response(
@@ -1,7 +1,308 @@
1
+ from collections import defaultdict
2
+ from typing import Dict, List, Optional, Set, Union
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
1
6
  import strawberry
7
+ from sqlalchemy import delete, select
8
+ from sqlalchemy.orm import contains_eager, load_only
9
+ from strawberry import ID, UNSET
10
+ from strawberry.types import Info
11
+ from typing_extensions import Annotated
12
+
13
+ from phoenix.config import DEFAULT_PROJECT_NAME
14
+ from phoenix.db import models
15
+ from phoenix.db.insertion.span import ClearProjectSpansEvent
16
+ from phoenix.pointcloud.clustering import Hdbscan
17
+ from phoenix.server.api.context import Context
18
+ from phoenix.server.api.helpers import ensure_list
19
+ from phoenix.server.api.input_types.ClusterInput import ClusterInput
20
+ from phoenix.server.api.input_types.Coordinates import (
21
+ InputCoordinate2D,
22
+ InputCoordinate3D,
23
+ )
24
+ from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
25
+ from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
26
+ from phoenix.server.api.types.Dimension import to_gql_dimension
27
+ from phoenix.server.api.types.EmbeddingDimension import (
28
+ DEFAULT_CLUSTER_SELECTION_EPSILON,
29
+ DEFAULT_MIN_CLUSTER_SIZE,
30
+ DEFAULT_MIN_SAMPLES,
31
+ to_gql_embedding_dimension,
32
+ )
33
+ from phoenix.server.api.types.Event import create_event_id, unpack_event_id
34
+ from phoenix.server.api.types.ExportEventsMutation import ExportEventsMutation
35
+ from phoenix.server.api.types.Functionality import Functionality
36
+ from phoenix.server.api.types.Model import Model
37
+ from phoenix.server.api.types.node import (
38
+ GlobalID,
39
+ Node,
40
+ from_global_id,
41
+ from_global_id_with_expected_type,
42
+ )
43
+ from phoenix.server.api.types.pagination import (
44
+ Connection,
45
+ ConnectionArgs,
46
+ CursorString,
47
+ connection_from_list,
48
+ )
49
+ from phoenix.server.api.types.Project import Project
50
+ from phoenix.server.api.types.Span import to_gql_span
51
+ from phoenix.server.api.types.Trace import Trace
52
+
53
+
54
+ @strawberry.type
55
+ class Query:
56
+ @strawberry.field
57
+ async def projects(
58
+ self,
59
+ info: Info[Context, None],
60
+ first: Optional[int] = 50,
61
+ last: Optional[int] = UNSET,
62
+ after: Optional[CursorString] = UNSET,
63
+ before: Optional[CursorString] = UNSET,
64
+ ) -> Connection[Project]:
65
+ args = ConnectionArgs(
66
+ first=first,
67
+ after=after if isinstance(after, CursorString) else None,
68
+ last=last,
69
+ before=before if isinstance(before, CursorString) else None,
70
+ )
71
+ async with info.context.db() as session:
72
+ projects = await session.scalars(select(models.Project))
73
+ data = [
74
+ Project(
75
+ id_attr=project.id,
76
+ name=project.name,
77
+ gradient_start_color=project.gradient_start_color,
78
+ gradient_end_color=project.gradient_end_color,
79
+ )
80
+ for project in projects
81
+ ]
82
+ return connection_from_list(data=data, args=args)
83
+
84
+ @strawberry.field
85
+ async def functionality(self, info: Info[Context, None]) -> "Functionality":
86
+ has_model_inferences = not info.context.model.is_empty
87
+ async with info.context.db() as session:
88
+ has_traces = (await session.scalar(select(models.Trace).limit(1))) is not None
89
+ return Functionality(
90
+ model_inferences=has_model_inferences,
91
+ tracing=has_traces,
92
+ )
93
+
94
+ @strawberry.field
95
+ def model(self) -> Model:
96
+ return Model()
97
+
98
+ @strawberry.field
99
+ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
100
+ type_name, node_id = from_global_id(str(id))
101
+ if type_name == "Dimension":
102
+ dimension = info.context.model.scalar_dimensions[node_id]
103
+ return to_gql_dimension(node_id, dimension)
104
+ elif type_name == "EmbeddingDimension":
105
+ embedding_dimension = info.context.model.embedding_dimensions[node_id]
106
+ return to_gql_embedding_dimension(node_id, embedding_dimension)
107
+ elif type_name == "Project":
108
+ project_stmt = select(
109
+ models.Project.id,
110
+ models.Project.name,
111
+ models.Project.gradient_start_color,
112
+ models.Project.gradient_end_color,
113
+ ).where(models.Project.id == node_id)
114
+ async with info.context.db() as session:
115
+ project = (await session.execute(project_stmt)).first()
116
+ if project is None:
117
+ raise ValueError(f"Unknown project: {id}")
118
+ return Project(
119
+ id_attr=project.id,
120
+ name=project.name,
121
+ gradient_start_color=project.gradient_start_color,
122
+ gradient_end_color=project.gradient_end_color,
123
+ )
124
+ elif type_name == "Trace":
125
+ trace_stmt = select(models.Trace.id).where(models.Trace.id == node_id)
126
+ async with info.context.db() as session:
127
+ id_attr = await session.scalar(trace_stmt)
128
+ if id_attr is None:
129
+ raise ValueError(f"Unknown trace: {id}")
130
+ return Trace(id_attr=id_attr)
131
+ elif type_name == "Span":
132
+ span_stmt = (
133
+ select(models.Span)
134
+ .join(models.Trace)
135
+ .options(contains_eager(models.Span.trace))
136
+ .where(models.Span.id == node_id)
137
+ )
138
+ async with info.context.db() as session:
139
+ span = await session.scalar(span_stmt)
140
+ if span is None:
141
+ raise ValueError(f"Unknown span: {id}")
142
+ return to_gql_span(span)
143
+ raise Exception(f"Unknown node type: {type_name}")
144
+
145
+ @strawberry.field
146
+ def clusters(
147
+ self,
148
+ clusters: List[ClusterInput],
149
+ ) -> List[Cluster]:
150
+ clustered_events: Dict[str, Set[ID]] = defaultdict(set)
151
+ for i, cluster in enumerate(clusters):
152
+ clustered_events[cluster.id or str(i)].update(cluster.event_ids)
153
+ return to_gql_clusters(
154
+ clustered_events=clustered_events,
155
+ )
156
+
157
+ @strawberry.field
158
+ def hdbscan_clustering(
159
+ self,
160
+ info: Info[Context, None],
161
+ event_ids: Annotated[
162
+ List[ID],
163
+ strawberry.argument(
164
+ description="Event ID of the coordinates",
165
+ ),
166
+ ],
167
+ coordinates_2d: Annotated[
168
+ Optional[List[InputCoordinate2D]],
169
+ strawberry.argument(
170
+ description="Point coordinates. Must be either 2D or 3D.",
171
+ ),
172
+ ] = UNSET,
173
+ coordinates_3d: Annotated[
174
+ Optional[List[InputCoordinate3D]],
175
+ strawberry.argument(
176
+ description="Point coordinates. Must be either 2D or 3D.",
177
+ ),
178
+ ] = UNSET,
179
+ min_cluster_size: Annotated[
180
+ int,
181
+ strawberry.argument(
182
+ description="HDBSCAN minimum cluster size",
183
+ ),
184
+ ] = DEFAULT_MIN_CLUSTER_SIZE,
185
+ cluster_min_samples: Annotated[
186
+ int,
187
+ strawberry.argument(
188
+ description="HDBSCAN minimum samples",
189
+ ),
190
+ ] = DEFAULT_MIN_SAMPLES,
191
+ cluster_selection_epsilon: Annotated[
192
+ float,
193
+ strawberry.argument(
194
+ description="HDBSCAN cluster selection epsilon",
195
+ ),
196
+ ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
197
+ ) -> List[Cluster]:
198
+ coordinates_3d = ensure_list(coordinates_3d)
199
+ coordinates_2d = ensure_list(coordinates_2d)
200
+
201
+ if len(coordinates_3d) > 0 and len(coordinates_2d) > 0:
202
+ raise ValueError("must specify only one of 2D or 3D coordinates")
203
+
204
+ if len(coordinates_3d) > 0:
205
+ coordinates = list(
206
+ map(
207
+ lambda coord: np.array(
208
+ [coord.x, coord.y, coord.z],
209
+ ),
210
+ coordinates_3d,
211
+ )
212
+ )
213
+ else:
214
+ coordinates = list(
215
+ map(
216
+ lambda coord: np.array(
217
+ [coord.x, coord.y],
218
+ ),
219
+ coordinates_2d,
220
+ )
221
+ )
222
+
223
+ if len(event_ids) != len(coordinates):
224
+ raise ValueError(
225
+ f"length mismatch between "
226
+ f"event_ids ({len(event_ids)}) "
227
+ f"and coordinates ({len(coordinates)})"
228
+ )
229
+
230
+ if len(event_ids) == 0:
231
+ return []
232
+
233
+ grouped_event_ids: Dict[
234
+ Union[DatasetRole, AncillaryDatasetRole],
235
+ List[ID],
236
+ ] = defaultdict(list)
237
+ grouped_coordinates: Dict[
238
+ Union[DatasetRole, AncillaryDatasetRole],
239
+ List[npt.NDArray[np.float64]],
240
+ ] = defaultdict(list)
241
+
242
+ for event_id, coordinate in zip(event_ids, coordinates):
243
+ row_id, dataset_role = unpack_event_id(event_id)
244
+ grouped_coordinates[dataset_role].append(coordinate)
245
+ grouped_event_ids[dataset_role].append(create_event_id(row_id, dataset_role))
246
+
247
+ stacked_event_ids = (
248
+ grouped_event_ids[DatasetRole.primary]
249
+ + grouped_event_ids[DatasetRole.reference]
250
+ + grouped_event_ids[AncillaryDatasetRole.corpus]
251
+ )
252
+ stacked_coordinates = np.stack(
253
+ grouped_coordinates[DatasetRole.primary]
254
+ + grouped_coordinates[DatasetRole.reference]
255
+ + grouped_coordinates[AncillaryDatasetRole.corpus]
256
+ )
257
+
258
+ clusters = Hdbscan(
259
+ min_cluster_size=min_cluster_size,
260
+ min_samples=cluster_min_samples,
261
+ cluster_selection_epsilon=cluster_selection_epsilon,
262
+ ).find_clusters(stacked_coordinates)
263
+
264
+ clustered_events = {
265
+ str(i): {stacked_event_ids[row_idx] for row_idx in cluster}
266
+ for i, cluster in enumerate(clusters)
267
+ }
268
+
269
+ return to_gql_clusters(
270
+ clustered_events=clustered_events,
271
+ )
272
+
273
+
274
+ @strawberry.type
275
+ class Mutation(ExportEventsMutation):
276
+ @strawberry.mutation
277
+ async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
278
+ if info.context.read_only:
279
+ return Query()
280
+ node_id = from_global_id_with_expected_type(str(id), "Project")
281
+ async with info.context.db() as session:
282
+ project = await session.scalar(
283
+ select(models.Project)
284
+ .where(models.Project.id == node_id)
285
+ .options(load_only(models.Project.name))
286
+ )
287
+ if project is None:
288
+ raise ValueError(f"Unknown project: {id}")
289
+ if project.name == DEFAULT_PROJECT_NAME:
290
+ raise ValueError(f"Cannot delete the {DEFAULT_PROJECT_NAME} project")
291
+ await session.delete(project)
292
+ return Query()
293
+
294
+ @strawberry.mutation
295
+ async def clear_project(self, info: Info[Context, None], id: GlobalID) -> Query:
296
+ if info.context.read_only:
297
+ return Query()
298
+ project_id = from_global_id_with_expected_type(str(id), "Project")
299
+ delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
300
+ async with info.context.db() as session:
301
+ await session.execute(delete_statement)
302
+ if cache := info.context.cache_for_dataloaders:
303
+ cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
304
+ return Query()
2
305
 
3
- from phoenix.server.api.mutations import Mutation
4
- from phoenix.server.api.queries import Query
5
306
 
6
307
  # This is the schema for generating `schema.graphql`.
7
308
  # See https://strawberry.rocks/docs/guides/schema-export