arize-phoenix 4.4.4rc5__py3-none-any.whl → 4.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/METADATA +5 -5
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/RECORD +56 -117
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/WHEEL +1 -1
- phoenix/__init__.py +27 -0
- phoenix/config.py +7 -21
- phoenix/core/model.py +25 -25
- phoenix/core/model_schema.py +62 -64
- phoenix/core/model_schema_adapter.py +25 -27
- phoenix/db/bulk_inserter.py +14 -54
- phoenix/db/insertion/evaluation.py +6 -6
- phoenix/db/insertion/helpers.py +2 -13
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
- phoenix/db/models.py +4 -236
- phoenix/inferences/fixtures.py +23 -23
- phoenix/inferences/inferences.py +7 -7
- phoenix/inferences/validation.py +1 -1
- phoenix/server/api/context.py +0 -18
- phoenix/server/api/dataloaders/__init__.py +0 -18
- phoenix/server/api/dataloaders/span_descendants.py +3 -2
- phoenix/server/api/routers/v1/__init__.py +2 -77
- phoenix/server/api/routers/v1/evaluations.py +2 -4
- phoenix/server/api/routers/v1/spans.py +1 -3
- phoenix/server/api/routers/v1/traces.py +4 -1
- phoenix/server/api/schema.py +303 -2
- phoenix/server/api/types/Cluster.py +19 -19
- phoenix/server/api/types/Dataset.py +63 -282
- phoenix/server/api/types/DatasetRole.py +23 -0
- phoenix/server/api/types/Dimension.py +29 -30
- phoenix/server/api/types/EmbeddingDimension.py +34 -40
- phoenix/server/api/types/Event.py +16 -16
- phoenix/server/api/{mutations/export_events_mutations.py → types/ExportEventsMutation.py} +14 -17
- phoenix/server/api/types/Model.py +42 -43
- phoenix/server/api/types/Project.py +12 -26
- phoenix/server/api/types/Span.py +2 -79
- phoenix/server/api/types/TimeSeries.py +6 -6
- phoenix/server/api/types/Trace.py +4 -15
- phoenix/server/api/types/UMAPPoints.py +1 -1
- phoenix/server/api/types/node.py +111 -5
- phoenix/server/api/types/pagination.py +52 -10
- phoenix/server/app.py +49 -101
- phoenix/server/main.py +27 -49
- phoenix/server/openapi/docs.py +0 -3
- phoenix/server/static/index.js +2595 -3523
- phoenix/server/templates/index.html +0 -1
- phoenix/services.py +15 -15
- phoenix/session/client.py +21 -438
- phoenix/session/session.py +37 -47
- phoenix/trace/exporter.py +9 -14
- phoenix/trace/fixtures.py +7 -133
- phoenix/trace/schemas.py +2 -1
- phoenix/trace/span_evaluations.py +3 -3
- phoenix/trace/trace_dataset.py +6 -6
- phoenix/version.py +1 -1
- phoenix/datasets/__init__.py +0 -0
- phoenix/datasets/evaluators/__init__.py +0 -18
- phoenix/datasets/evaluators/code_evaluators.py +0 -99
- phoenix/datasets/evaluators/llm_evaluators.py +0 -244
- phoenix/datasets/evaluators/utils.py +0 -292
- phoenix/datasets/experiments.py +0 -550
- phoenix/datasets/tracing.py +0 -85
- phoenix/datasets/types.py +0 -178
- phoenix/db/insertion/dataset.py +0 -237
- phoenix/db/migrations/types.py +0 -29
- phoenix/db/migrations/versions/10460e46d750_datasets.py +0 -291
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -100
- phoenix/server/api/dataloaders/dataset_example_spans.py +0 -43
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +0 -85
- phoenix/server/api/dataloaders/experiment_error_rates.py +0 -43
- phoenix/server/api/dataloaders/experiment_run_counts.py +0 -42
- phoenix/server/api/dataloaders/experiment_sequence_number.py +0 -49
- phoenix/server/api/dataloaders/project_by_name.py +0 -31
- phoenix/server/api/dataloaders/span_projects.py +0 -33
- phoenix/server/api/dataloaders/trace_row_ids.py +0 -39
- phoenix/server/api/helpers/dataset_helpers.py +0 -179
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +0 -16
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +0 -14
- phoenix/server/api/input_types/ClearProjectInput.py +0 -15
- phoenix/server/api/input_types/CreateDatasetInput.py +0 -12
- phoenix/server/api/input_types/DatasetExampleInput.py +0 -14
- phoenix/server/api/input_types/DatasetSort.py +0 -17
- phoenix/server/api/input_types/DatasetVersionSort.py +0 -16
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +0 -13
- phoenix/server/api/input_types/DeleteDatasetInput.py +0 -7
- phoenix/server/api/input_types/DeleteExperimentsInput.py +0 -9
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +0 -35
- phoenix/server/api/input_types/PatchDatasetInput.py +0 -14
- phoenix/server/api/mutations/__init__.py +0 -13
- phoenix/server/api/mutations/auth.py +0 -11
- phoenix/server/api/mutations/dataset_mutations.py +0 -520
- phoenix/server/api/mutations/experiment_mutations.py +0 -65
- phoenix/server/api/mutations/project_mutations.py +0 -47
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +0 -6
- phoenix/server/api/openapi/schema.py +0 -16
- phoenix/server/api/queries.py +0 -503
- phoenix/server/api/routers/v1/dataset_examples.py +0 -178
- phoenix/server/api/routers/v1/datasets.py +0 -965
- phoenix/server/api/routers/v1/experiment_evaluations.py +0 -66
- phoenix/server/api/routers/v1/experiment_runs.py +0 -108
- phoenix/server/api/routers/v1/experiments.py +0 -174
- phoenix/server/api/types/AnnotatorKind.py +0 -10
- phoenix/server/api/types/CreateDatasetPayload.py +0 -8
- phoenix/server/api/types/DatasetExample.py +0 -85
- phoenix/server/api/types/DatasetExampleRevision.py +0 -34
- phoenix/server/api/types/DatasetVersion.py +0 -14
- phoenix/server/api/types/ExampleRevisionInterface.py +0 -14
- phoenix/server/api/types/Experiment.py +0 -140
- phoenix/server/api/types/ExperimentAnnotationSummary.py +0 -13
- phoenix/server/api/types/ExperimentComparison.py +0 -19
- phoenix/server/api/types/ExperimentRun.py +0 -91
- phoenix/server/api/types/ExperimentRunAnnotation.py +0 -57
- phoenix/server/api/types/Inferences.py +0 -80
- phoenix/server/api/types/InferencesRole.py +0 -23
- phoenix/utilities/json.py +0 -61
- phoenix/utilities/re.py +0 -50
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/api/{helpers/__init__.py → helpers.py} +0 -0
phoenix/server/api/types/Span.py
CHANGED
|
@@ -1,33 +1,23 @@
|
|
|
1
1
|
import json
|
|
2
|
-
from dataclasses import dataclass
|
|
3
2
|
from datetime import datetime
|
|
4
3
|
from enum import Enum
|
|
5
|
-
from typing import
|
|
4
|
+
from typing import Any, List, Mapping, Optional, Sized, cast
|
|
6
5
|
|
|
7
6
|
import numpy as np
|
|
8
7
|
import strawberry
|
|
9
8
|
from openinference.semconv.trace import EmbeddingAttributes, SpanAttributes
|
|
10
9
|
from strawberry import ID, UNSET
|
|
11
|
-
from strawberry.relay import Node, NodeID
|
|
12
10
|
from strawberry.types import Info
|
|
13
|
-
from typing_extensions import Annotated
|
|
14
11
|
|
|
15
12
|
import phoenix.trace.schemas as trace_schema
|
|
16
13
|
from phoenix.db import models
|
|
17
14
|
from phoenix.server.api.context import Context
|
|
18
|
-
from phoenix.server.api.helpers.dataset_helpers import (
|
|
19
|
-
get_dataset_example_input,
|
|
20
|
-
get_dataset_example_output,
|
|
21
|
-
)
|
|
22
15
|
from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
|
|
23
16
|
from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation
|
|
24
|
-
from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision
|
|
25
17
|
from phoenix.server.api.types.MimeType import MimeType
|
|
18
|
+
from phoenix.server.api.types.node import Node
|
|
26
19
|
from phoenix.trace.attributes import get_attribute_value
|
|
27
20
|
|
|
28
|
-
if TYPE_CHECKING:
|
|
29
|
-
from phoenix.server.api.types.Project import Project
|
|
30
|
-
|
|
31
21
|
EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS
|
|
32
22
|
EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR
|
|
33
23
|
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
|
|
@@ -35,9 +25,6 @@ INPUT_VALUE = SpanAttributes.INPUT_VALUE
|
|
|
35
25
|
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
|
|
36
26
|
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
|
|
37
27
|
LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
|
|
38
|
-
LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
|
|
39
|
-
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
|
|
40
|
-
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
|
|
41
28
|
METADATA = SpanAttributes.METADATA
|
|
42
29
|
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
|
|
43
30
|
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
|
|
@@ -59,7 +46,6 @@ class SpanKind(Enum):
|
|
|
59
46
|
embedding = "EMBEDDING"
|
|
60
47
|
agent = "AGENT"
|
|
61
48
|
reranker = "RERANKER"
|
|
62
|
-
evaluator = "EVALUATOR"
|
|
63
49
|
unknown = "UNKNOWN"
|
|
64
50
|
|
|
65
51
|
@classmethod
|
|
@@ -115,14 +101,8 @@ class SpanEvent:
|
|
|
115
101
|
)
|
|
116
102
|
|
|
117
103
|
|
|
118
|
-
@strawberry.type
|
|
119
|
-
class SpanAsExampleRevision(ExampleRevision): ...
|
|
120
|
-
|
|
121
|
-
|
|
122
104
|
@strawberry.type
|
|
123
105
|
class Span(Node):
|
|
124
|
-
id_attr: NodeID[int]
|
|
125
|
-
db_span: strawberry.Private[models.Span]
|
|
126
106
|
name: str
|
|
127
107
|
status_code: SpanStatusCode
|
|
128
108
|
status_message: str
|
|
@@ -208,44 +188,6 @@ class Span(Node):
|
|
|
208
188
|
spans = await info.context.data_loaders.span_descendants.load(span_id)
|
|
209
189
|
return [to_gql_span(span) for span in spans]
|
|
210
190
|
|
|
211
|
-
@strawberry.field(
|
|
212
|
-
description="The span's attributes translated into an example revision for a dataset",
|
|
213
|
-
) # type: ignore
|
|
214
|
-
def as_example_revision(self) -> SpanAsExampleRevision:
|
|
215
|
-
db_span = self.db_span
|
|
216
|
-
attributes = db_span.attributes
|
|
217
|
-
span_io = _SpanIO(
|
|
218
|
-
span_kind=db_span.span_kind,
|
|
219
|
-
input_value=get_attribute_value(attributes, INPUT_VALUE),
|
|
220
|
-
input_mime_type=get_attribute_value(attributes, INPUT_MIME_TYPE),
|
|
221
|
-
output_value=get_attribute_value(attributes, OUTPUT_VALUE),
|
|
222
|
-
output_mime_type=get_attribute_value(attributes, OUTPUT_MIME_TYPE),
|
|
223
|
-
llm_prompt_template_variables=get_attribute_value(
|
|
224
|
-
attributes, LLM_PROMPT_TEMPLATE_VARIABLES
|
|
225
|
-
),
|
|
226
|
-
llm_input_messages=get_attribute_value(attributes, LLM_INPUT_MESSAGES),
|
|
227
|
-
llm_output_messages=get_attribute_value(attributes, LLM_OUTPUT_MESSAGES),
|
|
228
|
-
retrieval_documents=get_attribute_value(attributes, RETRIEVAL_DOCUMENTS),
|
|
229
|
-
)
|
|
230
|
-
return SpanAsExampleRevision(
|
|
231
|
-
input=get_dataset_example_input(span_io),
|
|
232
|
-
output=get_dataset_example_output(span_io),
|
|
233
|
-
metadata=attributes,
|
|
234
|
-
)
|
|
235
|
-
|
|
236
|
-
@strawberry.field(description="The project that this span belongs to.") # type: ignore
|
|
237
|
-
async def project(
|
|
238
|
-
self,
|
|
239
|
-
info: Info[Context, None],
|
|
240
|
-
) -> Annotated[
|
|
241
|
-
"Project", strawberry.lazy("phoenix.server.api.types.Project")
|
|
242
|
-
]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
|
|
243
|
-
from phoenix.server.api.types.Project import to_gql_project
|
|
244
|
-
|
|
245
|
-
span_id = self.id_attr
|
|
246
|
-
project = await info.context.data_loaders.span_projects.load(span_id)
|
|
247
|
-
return to_gql_project(project)
|
|
248
|
-
|
|
249
191
|
|
|
250
192
|
def to_gql_span(span: models.Span) -> Span:
|
|
251
193
|
events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events))
|
|
@@ -255,7 +197,6 @@ def to_gql_span(span: models.Span) -> Span:
|
|
|
255
197
|
num_documents = len(retrieval_documents) if isinstance(retrieval_documents, Sized) else None
|
|
256
198
|
return Span(
|
|
257
199
|
id_attr=span.id,
|
|
258
|
-
db_span=span,
|
|
259
200
|
name=span.name,
|
|
260
201
|
status_code=SpanStatusCode(span.status_code),
|
|
261
202
|
status_message=span.status_message,
|
|
@@ -361,21 +302,3 @@ def _convert_metadata_to_string(metadata: Any) -> Optional[str]:
|
|
|
361
302
|
return json.dumps(metadata)
|
|
362
303
|
except Exception:
|
|
363
304
|
return str(metadata)
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
@dataclass
|
|
367
|
-
class _SpanIO:
|
|
368
|
-
"""
|
|
369
|
-
An class that contains the information needed to extract dataset example
|
|
370
|
-
input and output values from a span.
|
|
371
|
-
"""
|
|
372
|
-
|
|
373
|
-
span_kind: Optional[str]
|
|
374
|
-
input_value: Any
|
|
375
|
-
input_mime_type: Optional[str]
|
|
376
|
-
output_value: Any
|
|
377
|
-
output_mime_type: Optional[str]
|
|
378
|
-
llm_prompt_template_variables: Any
|
|
379
|
-
llm_input_messages: Any
|
|
380
|
-
llm_output_messages: Any
|
|
381
|
-
retrieval_documents: Any
|
|
@@ -7,7 +7,7 @@ import pandas as pd
|
|
|
7
7
|
import strawberry
|
|
8
8
|
from strawberry import UNSET
|
|
9
9
|
|
|
10
|
-
from phoenix.core.model_schema import CONTINUOUS, PRIMARY, REFERENCE, Column,
|
|
10
|
+
from phoenix.core.model_schema import CONTINUOUS, PRIMARY, REFERENCE, Column, Dataset, Dimension
|
|
11
11
|
from phoenix.metrics import Metric, binning
|
|
12
12
|
from phoenix.metrics.mixins import UnaryOperator
|
|
13
13
|
from phoenix.metrics.timeseries import timeseries
|
|
@@ -15,7 +15,7 @@ from phoenix.server.api.input_types.Granularity import Granularity, to_timestamp
|
|
|
15
15
|
from phoenix.server.api.input_types.TimeRange import TimeRange
|
|
16
16
|
from phoenix.server.api.interceptor import GqlValueMediator
|
|
17
17
|
from phoenix.server.api.types.DataQualityMetric import DataQualityMetric
|
|
18
|
-
from phoenix.server.api.types.
|
|
18
|
+
from phoenix.server.api.types.DatasetRole import DatasetRole
|
|
19
19
|
from phoenix.server.api.types.ScalarDriftMetricEnum import ScalarDriftMetric
|
|
20
20
|
from phoenix.server.api.types.VectorDriftMetricEnum import VectorDriftMetric
|
|
21
21
|
|
|
@@ -97,7 +97,7 @@ def get_data_quality_timeseries_data(
|
|
|
97
97
|
metric: DataQualityMetric,
|
|
98
98
|
time_range: TimeRange,
|
|
99
99
|
granularity: Granularity,
|
|
100
|
-
|
|
100
|
+
dataset_role: DatasetRole,
|
|
101
101
|
) -> List[TimeSeriesDataPoint]:
|
|
102
102
|
metric_instance = metric.value()
|
|
103
103
|
if isinstance(metric_instance, UnaryOperator):
|
|
@@ -106,7 +106,7 @@ def get_data_quality_timeseries_data(
|
|
|
106
106
|
operand=Column(dimension.name),
|
|
107
107
|
)
|
|
108
108
|
df = pd.DataFrame(
|
|
109
|
-
{dimension.name: dimension[
|
|
109
|
+
{dimension.name: dimension[dataset_role.value]},
|
|
110
110
|
copy=False,
|
|
111
111
|
)
|
|
112
112
|
return get_timeseries_data(
|
|
@@ -160,12 +160,12 @@ class PerformanceTimeSeries(TimeSeries):
|
|
|
160
160
|
|
|
161
161
|
|
|
162
162
|
def ensure_timeseries_parameters(
|
|
163
|
-
|
|
163
|
+
dataset: Dataset,
|
|
164
164
|
time_range: Optional[TimeRange] = UNSET,
|
|
165
165
|
granularity: Optional[Granularity] = UNSET,
|
|
166
166
|
) -> Tuple[TimeRange, Granularity]:
|
|
167
167
|
if not isinstance(time_range, TimeRange):
|
|
168
|
-
start, stop =
|
|
168
|
+
start, stop = dataset.time_range
|
|
169
169
|
time_range = TimeRange(start=start, end=stop)
|
|
170
170
|
if not isinstance(granularity, Granularity):
|
|
171
171
|
total_minutes = int((time_range.end - time_range.start).total_seconds()) // 60
|
|
@@ -1,18 +1,17 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
1
|
from typing import List, Optional
|
|
4
2
|
|
|
5
3
|
import strawberry
|
|
6
4
|
from sqlalchemy import desc, select
|
|
7
5
|
from sqlalchemy.orm import contains_eager
|
|
8
|
-
from strawberry import UNSET
|
|
9
|
-
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
6
|
+
from strawberry import UNSET
|
|
10
7
|
from strawberry.types import Info
|
|
11
8
|
|
|
12
9
|
from phoenix.db import models
|
|
13
10
|
from phoenix.server.api.context import Context
|
|
14
11
|
from phoenix.server.api.types.Evaluation import TraceEvaluation
|
|
12
|
+
from phoenix.server.api.types.node import Node
|
|
15
13
|
from phoenix.server.api.types.pagination import (
|
|
14
|
+
Connection,
|
|
16
15
|
ConnectionArgs,
|
|
17
16
|
CursorString,
|
|
18
17
|
connection_from_list,
|
|
@@ -22,16 +21,6 @@ from phoenix.server.api.types.Span import Span, to_gql_span
|
|
|
22
21
|
|
|
23
22
|
@strawberry.type
|
|
24
23
|
class Trace(Node):
|
|
25
|
-
id_attr: NodeID[int]
|
|
26
|
-
project_rowid: Private[int]
|
|
27
|
-
trace_id: str
|
|
28
|
-
|
|
29
|
-
@strawberry.field
|
|
30
|
-
async def project_id(self) -> GlobalID:
|
|
31
|
-
from phoenix.server.api.types.Project import Project
|
|
32
|
-
|
|
33
|
-
return GlobalID(type_name=Project.__name__, node_id=str(self.project_rowid))
|
|
34
|
-
|
|
35
24
|
@strawberry.field
|
|
36
25
|
async def spans(
|
|
37
26
|
self,
|
|
@@ -51,7 +40,7 @@ class Trace(Node):
|
|
|
51
40
|
select(models.Span)
|
|
52
41
|
.join(models.Trace)
|
|
53
42
|
.where(models.Trace.id == self.id_attr)
|
|
54
|
-
.options(contains_eager(models.Span.trace)
|
|
43
|
+
.options(contains_eager(models.Span.trace))
|
|
55
44
|
# Sort descending because the root span tends to show up later
|
|
56
45
|
# in the ingestion process.
|
|
57
46
|
.order_by(desc(models.Span.id))
|
|
@@ -3,13 +3,13 @@ from typing import List, Union
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import numpy.typing as npt
|
|
5
5
|
import strawberry
|
|
6
|
-
from strawberry.relay.types import GlobalID
|
|
7
6
|
from strawberry.scalars import ID
|
|
8
7
|
|
|
9
8
|
from phoenix.server.api.types.Cluster import Cluster
|
|
10
9
|
|
|
11
10
|
from .EmbeddingMetadata import EmbeddingMetadata
|
|
12
11
|
from .EventMetadata import EventMetadata
|
|
12
|
+
from .node import GlobalID
|
|
13
13
|
from .Retrieval import Retrieval
|
|
14
14
|
|
|
15
15
|
|
phoenix/server/api/types/node.py
CHANGED
|
@@ -1,19 +1,36 @@
|
|
|
1
|
-
|
|
1
|
+
import base64
|
|
2
|
+
import dataclasses
|
|
3
|
+
from typing import Tuple, Union
|
|
2
4
|
|
|
3
|
-
|
|
5
|
+
import strawberry
|
|
6
|
+
from graphql import GraphQLID
|
|
7
|
+
from strawberry.custom_scalar import ScalarDefinition
|
|
8
|
+
from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY
|
|
4
9
|
|
|
5
10
|
|
|
6
|
-
def
|
|
11
|
+
def to_global_id(type_name: str, node_id: int) -> str:
|
|
12
|
+
"""
|
|
13
|
+
Encode the given id into a global id.
|
|
14
|
+
|
|
15
|
+
:param type_name: The type of the node.
|
|
16
|
+
:param node_id: The id of the node.
|
|
17
|
+
:return: A global id.
|
|
18
|
+
"""
|
|
19
|
+
return base64.b64encode(f"{type_name}:{node_id}".encode("utf-8")).decode()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def from_global_id(global_id: str) -> Tuple[str, int]:
|
|
7
23
|
"""
|
|
8
24
|
Decode the given global id into a type and id.
|
|
9
25
|
|
|
10
26
|
:param global_id: The global id to decode.
|
|
11
27
|
:return: A tuple of type and id.
|
|
12
28
|
"""
|
|
13
|
-
|
|
29
|
+
type_name, node_id = base64.b64decode(global_id).decode().split(":")
|
|
30
|
+
return type_name, int(node_id)
|
|
14
31
|
|
|
15
32
|
|
|
16
|
-
def from_global_id_with_expected_type(global_id:
|
|
33
|
+
def from_global_id_with_expected_type(global_id: str, expected_type_name: str) -> int:
|
|
17
34
|
"""
|
|
18
35
|
Decodes the given global id and return the id, checking that the type
|
|
19
36
|
matches the expected type.
|
|
@@ -25,3 +42,92 @@ def from_global_id_with_expected_type(global_id: GlobalID, expected_type_name: s
|
|
|
25
42
|
f"but instead corresponds to a node of type: {type_name}"
|
|
26
43
|
)
|
|
27
44
|
return node_id
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class GlobalIDValueError(ValueError):
|
|
48
|
+
"""GlobalID value error, usually related to parsing or serialization."""
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclasses.dataclass(frozen=True)
|
|
52
|
+
class GlobalID:
|
|
53
|
+
"""Global ID for relay types.
|
|
54
|
+
Different from `strawberry.ID`, this ID wraps the original object ID in a string
|
|
55
|
+
that contains both its GraphQL type name and the ID itself, and encodes it
|
|
56
|
+
to a base64_ string.
|
|
57
|
+
This object contains helpers to work with that, including method to retrieve
|
|
58
|
+
the python object type or even the encoded node itself.
|
|
59
|
+
Attributes:
|
|
60
|
+
type_name:
|
|
61
|
+
The type name part of the id
|
|
62
|
+
node_id:
|
|
63
|
+
The node id part of the id
|
|
64
|
+
.. _base64:
|
|
65
|
+
https://en.wikipedia.org/wiki/Base64
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
type_name: str
|
|
69
|
+
node_id: int
|
|
70
|
+
|
|
71
|
+
def __post_init__(self) -> None:
|
|
72
|
+
if not isinstance(self.type_name, str):
|
|
73
|
+
raise GlobalIDValueError(
|
|
74
|
+
f"type_name is expected to be a string, found {self.type_name}"
|
|
75
|
+
)
|
|
76
|
+
try:
|
|
77
|
+
# node_id could be numpy.int64, hence the need for coercion
|
|
78
|
+
object.__setattr__(self, "node_id", int(self.node_id))
|
|
79
|
+
except ValueError:
|
|
80
|
+
raise GlobalIDValueError(f"node_id is expected to be an int, found {self.node_id}")
|
|
81
|
+
|
|
82
|
+
def __str__(self) -> str:
|
|
83
|
+
return to_global_id(self.type_name, self.node_id)
|
|
84
|
+
|
|
85
|
+
@classmethod
|
|
86
|
+
def from_id(cls, value: Union[str, strawberry.ID]) -> "GlobalID":
|
|
87
|
+
"""Create a new GlobalID from parsing the given value.
|
|
88
|
+
Args:
|
|
89
|
+
value:
|
|
90
|
+
The value to be parsed, as a base64 string in the "TypeName:NodeID" format
|
|
91
|
+
Returns:
|
|
92
|
+
An instance of GLobalID
|
|
93
|
+
Raises:
|
|
94
|
+
GlobalIDValueError:
|
|
95
|
+
If the value is not in a GLobalID format
|
|
96
|
+
"""
|
|
97
|
+
try:
|
|
98
|
+
type_name, node_id = from_global_id(value)
|
|
99
|
+
except ValueError as e:
|
|
100
|
+
raise GlobalIDValueError(str(e)) from e
|
|
101
|
+
|
|
102
|
+
return cls(type_name=type_name, node_id=node_id)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@strawberry.interface(description="A node in the graph with a globally unique ID")
|
|
106
|
+
class Node:
|
|
107
|
+
"""
|
|
108
|
+
All types that are relay ready should inherit from this interface and
|
|
109
|
+
implement the following methods.
|
|
110
|
+
|
|
111
|
+
Attributes:
|
|
112
|
+
id_attr:
|
|
113
|
+
The raw id field of node. Typically a database id or index
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
id_attr: strawberry.Private[int]
|
|
117
|
+
|
|
118
|
+
@strawberry.field
|
|
119
|
+
def id(self) -> GlobalID:
|
|
120
|
+
return GlobalID(type(self).__name__, self.id_attr)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# Register our GlobalID scalar
|
|
124
|
+
DEFAULT_SCALAR_REGISTRY[GlobalID] = ScalarDefinition(
|
|
125
|
+
# Use the same name/description/parse_literal from GraphQLID
|
|
126
|
+
# specs expect this type to be "ID".
|
|
127
|
+
name="GlobalID",
|
|
128
|
+
description=GraphQLID.description,
|
|
129
|
+
parse_literal=lambda v, vars=None: GlobalID.from_id(GraphQLID.parse_literal(v, vars)),
|
|
130
|
+
parse_value=GlobalID.from_id,
|
|
131
|
+
serialize=str,
|
|
132
|
+
specified_by_url="https://relay.dev/graphql/objectidentification.htm",
|
|
133
|
+
)
|
|
@@ -2,18 +2,60 @@ import base64
|
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from datetime import datetime
|
|
4
4
|
from enum import Enum, auto
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import ClassVar, Generic, List, Optional, Tuple, TypeVar, Union
|
|
6
6
|
|
|
7
|
+
import strawberry
|
|
7
8
|
from strawberry import UNSET
|
|
8
|
-
from strawberry.relay.types import Connection, Edge, NodeType, PageInfo
|
|
9
9
|
from typing_extensions import TypeAlias, assert_never
|
|
10
10
|
|
|
11
11
|
ID: TypeAlias = int
|
|
12
|
+
GenericType = TypeVar("GenericType")
|
|
12
13
|
CursorSortColumnValue: TypeAlias = Union[str, int, float, datetime]
|
|
13
14
|
|
|
15
|
+
|
|
16
|
+
@strawberry.type
|
|
17
|
+
class Connection(Generic[GenericType]):
|
|
18
|
+
"""Represents a paginated relationship between two entities
|
|
19
|
+
|
|
20
|
+
This pattern is used when the relationship itself has attributes.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
page_info: "PageInfo"
|
|
24
|
+
edges: List["Edge[GenericType]"]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@strawberry.type
|
|
28
|
+
class PageInfo:
|
|
29
|
+
"""Pagination context to navigate objects with cursor-based pagination
|
|
30
|
+
|
|
31
|
+
Instead of classic offset pagination via `page` and `limit` parameters,
|
|
32
|
+
here we have a cursor of the last object and we fetch items starting from that one
|
|
33
|
+
|
|
34
|
+
Read more at:
|
|
35
|
+
- https://graphql.org/learn/pagination/#pagination-and-edges
|
|
36
|
+
- https://relay.dev/graphql/connections.htm
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
has_next_page: bool
|
|
40
|
+
has_previous_page: bool
|
|
41
|
+
start_cursor: Optional[str]
|
|
42
|
+
end_cursor: Optional[str]
|
|
43
|
+
|
|
44
|
+
|
|
14
45
|
# A type alias for the connection cursor implementation
|
|
15
46
|
CursorString = str
|
|
16
47
|
|
|
48
|
+
|
|
49
|
+
@strawberry.type
|
|
50
|
+
class Edge(Generic[GenericType]):
|
|
51
|
+
"""
|
|
52
|
+
An edge may contain additional information of the relationship. This is the trivial case
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
node: GenericType
|
|
56
|
+
cursor: str
|
|
57
|
+
|
|
58
|
+
|
|
17
59
|
# The hashing prefix for a connection cursor
|
|
18
60
|
CURSOR_PREFIX = "connection:"
|
|
19
61
|
|
|
@@ -176,9 +218,9 @@ class ConnectionArgs:
|
|
|
176
218
|
|
|
177
219
|
|
|
178
220
|
def connection_from_list(
|
|
179
|
-
data: List[
|
|
221
|
+
data: List[GenericType],
|
|
180
222
|
args: ConnectionArgs,
|
|
181
|
-
) -> Connection[
|
|
223
|
+
) -> Connection[GenericType]:
|
|
182
224
|
"""
|
|
183
225
|
A simple function that accepts a list and connection arguments, and returns
|
|
184
226
|
a connection object for use in GraphQL. It uses list offsets as pagination,
|
|
@@ -188,11 +230,11 @@ def connection_from_list(
|
|
|
188
230
|
|
|
189
231
|
|
|
190
232
|
def connection_from_list_slice(
|
|
191
|
-
list_slice: List[
|
|
233
|
+
list_slice: List[GenericType],
|
|
192
234
|
args: ConnectionArgs,
|
|
193
235
|
slice_start: int,
|
|
194
236
|
list_length: int,
|
|
195
|
-
) -> Connection[
|
|
237
|
+
) -> Connection[GenericType]:
|
|
196
238
|
"""
|
|
197
239
|
Given a slice (subset) of a list, returns a connection object for use in
|
|
198
240
|
GraphQL.
|
|
@@ -253,12 +295,12 @@ def connection_from_list_slice(
|
|
|
253
295
|
)
|
|
254
296
|
|
|
255
297
|
|
|
256
|
-
def
|
|
257
|
-
|
|
298
|
+
def connections(
|
|
299
|
+
data: List[Tuple[Cursor, GenericType]],
|
|
258
300
|
has_previous_page: bool,
|
|
259
301
|
has_next_page: bool,
|
|
260
|
-
) -> Connection[
|
|
261
|
-
edges = [Edge(node=node, cursor=str(cursor)) for cursor, node in
|
|
302
|
+
) -> Connection[GenericType]:
|
|
303
|
+
edges = [Edge(node=node, cursor=str(cursor)) for cursor, node in data]
|
|
262
304
|
has_edges = len(edges) > 0
|
|
263
305
|
first_edge = edges[0] if has_edges else None
|
|
264
306
|
last_edge = edges[-1] if has_edges else None
|