arize-phoenix 4.4.4rc6__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix has been flagged as potentially problematic; see the registry's advisory for this release for more details.

Files changed (123):
  1. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/METADATA +8 -14
  2. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/RECORD +58 -122
  3. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +27 -0
  5. phoenix/config.py +7 -42
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +62 -64
  8. phoenix/core/model_schema_adapter.py +25 -27
  9. phoenix/datetime_utils.py +0 -4
  10. phoenix/db/bulk_inserter.py +14 -54
  11. phoenix/db/insertion/evaluation.py +10 -10
  12. phoenix/db/insertion/helpers.py +14 -17
  13. phoenix/db/insertion/span.py +3 -3
  14. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
  15. phoenix/db/models.py +4 -236
  16. phoenix/inferences/fixtures.py +23 -23
  17. phoenix/inferences/inferences.py +7 -7
  18. phoenix/inferences/validation.py +1 -1
  19. phoenix/server/api/context.py +0 -20
  20. phoenix/server/api/dataloaders/__init__.py +0 -20
  21. phoenix/server/api/dataloaders/span_descendants.py +3 -2
  22. phoenix/server/api/routers/v1/__init__.py +2 -77
  23. phoenix/server/api/routers/v1/evaluations.py +13 -8
  24. phoenix/server/api/routers/v1/spans.py +5 -9
  25. phoenix/server/api/routers/v1/traces.py +4 -1
  26. phoenix/server/api/schema.py +303 -2
  27. phoenix/server/api/types/Cluster.py +19 -19
  28. phoenix/server/api/types/Dataset.py +63 -282
  29. phoenix/server/api/types/DatasetRole.py +23 -0
  30. phoenix/server/api/types/Dimension.py +29 -30
  31. phoenix/server/api/types/EmbeddingDimension.py +34 -40
  32. phoenix/server/api/types/Event.py +16 -16
  33. phoenix/server/api/{mutations/export_events_mutations.py → types/ExportEventsMutation.py} +14 -17
  34. phoenix/server/api/types/Model.py +42 -43
  35. phoenix/server/api/types/Project.py +12 -26
  36. phoenix/server/api/types/Span.py +2 -79
  37. phoenix/server/api/types/TimeSeries.py +6 -6
  38. phoenix/server/api/types/Trace.py +4 -15
  39. phoenix/server/api/types/UMAPPoints.py +1 -1
  40. phoenix/server/api/types/node.py +111 -5
  41. phoenix/server/api/types/pagination.py +52 -10
  42. phoenix/server/app.py +49 -103
  43. phoenix/server/main.py +27 -49
  44. phoenix/server/openapi/docs.py +0 -3
  45. phoenix/server/static/index.js +1384 -2390
  46. phoenix/server/templates/index.html +0 -1
  47. phoenix/services.py +15 -15
  48. phoenix/session/client.py +23 -611
  49. phoenix/session/session.py +37 -47
  50. phoenix/trace/exporter.py +9 -14
  51. phoenix/trace/fixtures.py +7 -133
  52. phoenix/trace/schemas.py +2 -1
  53. phoenix/trace/span_evaluations.py +3 -3
  54. phoenix/trace/trace_dataset.py +6 -6
  55. phoenix/version.py +1 -1
  56. phoenix/db/insertion/dataset.py +0 -237
  57. phoenix/db/migrations/types.py +0 -29
  58. phoenix/db/migrations/versions/10460e46d750_datasets.py +0 -291
  59. phoenix/experiments/__init__.py +0 -6
  60. phoenix/experiments/evaluators/__init__.py +0 -29
  61. phoenix/experiments/evaluators/base.py +0 -153
  62. phoenix/experiments/evaluators/code_evaluators.py +0 -99
  63. phoenix/experiments/evaluators/llm_evaluators.py +0 -244
  64. phoenix/experiments/evaluators/utils.py +0 -189
  65. phoenix/experiments/functions.py +0 -616
  66. phoenix/experiments/tracing.py +0 -85
  67. phoenix/experiments/types.py +0 -722
  68. phoenix/experiments/utils.py +0 -9
  69. phoenix/server/api/dataloaders/average_experiment_run_latency.py +0 -54
  70. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -100
  71. phoenix/server/api/dataloaders/dataset_example_spans.py +0 -43
  72. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +0 -85
  73. phoenix/server/api/dataloaders/experiment_error_rates.py +0 -43
  74. phoenix/server/api/dataloaders/experiment_run_counts.py +0 -42
  75. phoenix/server/api/dataloaders/experiment_sequence_number.py +0 -49
  76. phoenix/server/api/dataloaders/project_by_name.py +0 -31
  77. phoenix/server/api/dataloaders/span_projects.py +0 -33
  78. phoenix/server/api/dataloaders/trace_row_ids.py +0 -39
  79. phoenix/server/api/helpers/dataset_helpers.py +0 -179
  80. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +0 -16
  81. phoenix/server/api/input_types/AddSpansToDatasetInput.py +0 -14
  82. phoenix/server/api/input_types/ClearProjectInput.py +0 -15
  83. phoenix/server/api/input_types/CreateDatasetInput.py +0 -12
  84. phoenix/server/api/input_types/DatasetExampleInput.py +0 -14
  85. phoenix/server/api/input_types/DatasetSort.py +0 -17
  86. phoenix/server/api/input_types/DatasetVersionSort.py +0 -16
  87. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +0 -13
  88. phoenix/server/api/input_types/DeleteDatasetInput.py +0 -7
  89. phoenix/server/api/input_types/DeleteExperimentsInput.py +0 -9
  90. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +0 -35
  91. phoenix/server/api/input_types/PatchDatasetInput.py +0 -14
  92. phoenix/server/api/mutations/__init__.py +0 -13
  93. phoenix/server/api/mutations/auth.py +0 -11
  94. phoenix/server/api/mutations/dataset_mutations.py +0 -520
  95. phoenix/server/api/mutations/experiment_mutations.py +0 -65
  96. phoenix/server/api/mutations/project_mutations.py +0 -47
  97. phoenix/server/api/openapi/__init__.py +0 -0
  98. phoenix/server/api/openapi/main.py +0 -6
  99. phoenix/server/api/openapi/schema.py +0 -16
  100. phoenix/server/api/queries.py +0 -503
  101. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  102. phoenix/server/api/routers/v1/datasets.py +0 -965
  103. phoenix/server/api/routers/v1/experiment_evaluations.py +0 -65
  104. phoenix/server/api/routers/v1/experiment_runs.py +0 -96
  105. phoenix/server/api/routers/v1/experiments.py +0 -174
  106. phoenix/server/api/types/AnnotatorKind.py +0 -10
  107. phoenix/server/api/types/CreateDatasetPayload.py +0 -8
  108. phoenix/server/api/types/DatasetExample.py +0 -85
  109. phoenix/server/api/types/DatasetExampleRevision.py +0 -34
  110. phoenix/server/api/types/DatasetVersion.py +0 -14
  111. phoenix/server/api/types/ExampleRevisionInterface.py +0 -14
  112. phoenix/server/api/types/Experiment.py +0 -147
  113. phoenix/server/api/types/ExperimentAnnotationSummary.py +0 -13
  114. phoenix/server/api/types/ExperimentComparison.py +0 -19
  115. phoenix/server/api/types/ExperimentRun.py +0 -91
  116. phoenix/server/api/types/ExperimentRunAnnotation.py +0 -57
  117. phoenix/server/api/types/Inferences.py +0 -80
  118. phoenix/server/api/types/InferencesRole.py +0 -23
  119. phoenix/utilities/json.py +0 -61
  120. phoenix/utilities/re.py +0 -50
  121. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/IP_NOTICE +0 -0
  122. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/LICENSE +0 -0
  123. /phoenix/server/api/{helpers/__init__.py → helpers.py} +0 -0
@@ -1,33 +1,23 @@
1
1
  import json
2
- from dataclasses import dataclass
3
2
  from datetime import datetime
4
3
  from enum import Enum
5
- from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sized, cast
4
+ from typing import Any, List, Mapping, Optional, Sized, cast
6
5
 
7
6
  import numpy as np
8
7
  import strawberry
9
8
  from openinference.semconv.trace import EmbeddingAttributes, SpanAttributes
10
9
  from strawberry import ID, UNSET
11
- from strawberry.relay import Node, NodeID
12
10
  from strawberry.types import Info
13
- from typing_extensions import Annotated
14
11
 
15
12
  import phoenix.trace.schemas as trace_schema
16
13
  from phoenix.db import models
17
14
  from phoenix.server.api.context import Context
18
- from phoenix.server.api.helpers.dataset_helpers import (
19
- get_dataset_example_input,
20
- get_dataset_example_output,
21
- )
22
15
  from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
23
16
  from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation
24
- from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision
25
17
  from phoenix.server.api.types.MimeType import MimeType
18
+ from phoenix.server.api.types.node import Node
26
19
  from phoenix.trace.attributes import get_attribute_value
27
20
 
28
- if TYPE_CHECKING:
29
- from phoenix.server.api.types.Project import Project
30
-
31
21
  EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS
32
22
  EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR
33
23
  INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
@@ -35,9 +25,6 @@ INPUT_VALUE = SpanAttributes.INPUT_VALUE
35
25
  LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
36
26
  LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
37
27
  LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
38
- LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
39
- LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
40
- LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
41
28
  METADATA = SpanAttributes.METADATA
42
29
  OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
43
30
  OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
@@ -59,7 +46,6 @@ class SpanKind(Enum):
59
46
  embedding = "EMBEDDING"
60
47
  agent = "AGENT"
61
48
  reranker = "RERANKER"
62
- evaluator = "EVALUATOR"
63
49
  unknown = "UNKNOWN"
64
50
 
65
51
  @classmethod
@@ -115,14 +101,8 @@ class SpanEvent:
115
101
  )
116
102
 
117
103
 
118
- @strawberry.type
119
- class SpanAsExampleRevision(ExampleRevision): ...
120
-
121
-
122
104
  @strawberry.type
123
105
  class Span(Node):
124
- id_attr: NodeID[int]
125
- db_span: strawberry.Private[models.Span]
126
106
  name: str
127
107
  status_code: SpanStatusCode
128
108
  status_message: str
@@ -208,44 +188,6 @@ class Span(Node):
208
188
  spans = await info.context.data_loaders.span_descendants.load(span_id)
209
189
  return [to_gql_span(span) for span in spans]
210
190
 
211
- @strawberry.field(
212
- description="The span's attributes translated into an example revision for a dataset",
213
- ) # type: ignore
214
- def as_example_revision(self) -> SpanAsExampleRevision:
215
- db_span = self.db_span
216
- attributes = db_span.attributes
217
- span_io = _SpanIO(
218
- span_kind=db_span.span_kind,
219
- input_value=get_attribute_value(attributes, INPUT_VALUE),
220
- input_mime_type=get_attribute_value(attributes, INPUT_MIME_TYPE),
221
- output_value=get_attribute_value(attributes, OUTPUT_VALUE),
222
- output_mime_type=get_attribute_value(attributes, OUTPUT_MIME_TYPE),
223
- llm_prompt_template_variables=get_attribute_value(
224
- attributes, LLM_PROMPT_TEMPLATE_VARIABLES
225
- ),
226
- llm_input_messages=get_attribute_value(attributes, LLM_INPUT_MESSAGES),
227
- llm_output_messages=get_attribute_value(attributes, LLM_OUTPUT_MESSAGES),
228
- retrieval_documents=get_attribute_value(attributes, RETRIEVAL_DOCUMENTS),
229
- )
230
- return SpanAsExampleRevision(
231
- input=get_dataset_example_input(span_io),
232
- output=get_dataset_example_output(span_io),
233
- metadata=attributes,
234
- )
235
-
236
- @strawberry.field(description="The project that this span belongs to.") # type: ignore
237
- async def project(
238
- self,
239
- info: Info[Context, None],
240
- ) -> Annotated[
241
- "Project", strawberry.lazy("phoenix.server.api.types.Project")
242
- ]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
243
- from phoenix.server.api.types.Project import to_gql_project
244
-
245
- span_id = self.id_attr
246
- project = await info.context.data_loaders.span_projects.load(span_id)
247
- return to_gql_project(project)
248
-
249
191
 
250
192
  def to_gql_span(span: models.Span) -> Span:
251
193
  events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events))
@@ -255,7 +197,6 @@ def to_gql_span(span: models.Span) -> Span:
255
197
  num_documents = len(retrieval_documents) if isinstance(retrieval_documents, Sized) else None
256
198
  return Span(
257
199
  id_attr=span.id,
258
- db_span=span,
259
200
  name=span.name,
260
201
  status_code=SpanStatusCode(span.status_code),
261
202
  status_message=span.status_message,
@@ -361,21 +302,3 @@ def _convert_metadata_to_string(metadata: Any) -> Optional[str]:
361
302
  return json.dumps(metadata)
362
303
  except Exception:
363
304
  return str(metadata)
364
-
365
-
366
- @dataclass
367
- class _SpanIO:
368
- """
369
- An class that contains the information needed to extract dataset example
370
- input and output values from a span.
371
- """
372
-
373
- span_kind: Optional[str]
374
- input_value: Any
375
- input_mime_type: Optional[str]
376
- output_value: Any
377
- output_mime_type: Optional[str]
378
- llm_prompt_template_variables: Any
379
- llm_input_messages: Any
380
- llm_output_messages: Any
381
- retrieval_documents: Any
@@ -7,7 +7,7 @@ import pandas as pd
7
7
  import strawberry
8
8
  from strawberry import UNSET
9
9
 
10
- from phoenix.core.model_schema import CONTINUOUS, PRIMARY, REFERENCE, Column, Dimension, Inferences
10
+ from phoenix.core.model_schema import CONTINUOUS, PRIMARY, REFERENCE, Column, Dataset, Dimension
11
11
  from phoenix.metrics import Metric, binning
12
12
  from phoenix.metrics.mixins import UnaryOperator
13
13
  from phoenix.metrics.timeseries import timeseries
@@ -15,7 +15,7 @@ from phoenix.server.api.input_types.Granularity import Granularity, to_timestamp
15
15
  from phoenix.server.api.input_types.TimeRange import TimeRange
16
16
  from phoenix.server.api.interceptor import GqlValueMediator
17
17
  from phoenix.server.api.types.DataQualityMetric import DataQualityMetric
18
- from phoenix.server.api.types.InferencesRole import InferencesRole
18
+ from phoenix.server.api.types.DatasetRole import DatasetRole
19
19
  from phoenix.server.api.types.ScalarDriftMetricEnum import ScalarDriftMetric
20
20
  from phoenix.server.api.types.VectorDriftMetricEnum import VectorDriftMetric
21
21
 
@@ -97,7 +97,7 @@ def get_data_quality_timeseries_data(
97
97
  metric: DataQualityMetric,
98
98
  time_range: TimeRange,
99
99
  granularity: Granularity,
100
- inferences_role: InferencesRole,
100
+ dataset_role: DatasetRole,
101
101
  ) -> List[TimeSeriesDataPoint]:
102
102
  metric_instance = metric.value()
103
103
  if isinstance(metric_instance, UnaryOperator):
@@ -106,7 +106,7 @@ def get_data_quality_timeseries_data(
106
106
  operand=Column(dimension.name),
107
107
  )
108
108
  df = pd.DataFrame(
109
- {dimension.name: dimension[inferences_role.value]},
109
+ {dimension.name: dimension[dataset_role.value]},
110
110
  copy=False,
111
111
  )
112
112
  return get_timeseries_data(
@@ -160,12 +160,12 @@ class PerformanceTimeSeries(TimeSeries):
160
160
 
161
161
 
162
162
  def ensure_timeseries_parameters(
163
- inferences: Inferences,
163
+ dataset: Dataset,
164
164
  time_range: Optional[TimeRange] = UNSET,
165
165
  granularity: Optional[Granularity] = UNSET,
166
166
  ) -> Tuple[TimeRange, Granularity]:
167
167
  if not isinstance(time_range, TimeRange):
168
- start, stop = inferences.time_range
168
+ start, stop = dataset.time_range
169
169
  time_range = TimeRange(start=start, end=stop)
170
170
  if not isinstance(granularity, Granularity):
171
171
  total_minutes = int((time_range.end - time_range.start).total_seconds()) // 60
@@ -1,18 +1,17 @@
1
- from __future__ import annotations
2
-
3
1
  from typing import List, Optional
4
2
 
5
3
  import strawberry
6
4
  from sqlalchemy import desc, select
7
5
  from sqlalchemy.orm import contains_eager
8
- from strawberry import UNSET, Private
9
- from strawberry.relay import Connection, GlobalID, Node, NodeID
6
+ from strawberry import UNSET
10
7
  from strawberry.types import Info
11
8
 
12
9
  from phoenix.db import models
13
10
  from phoenix.server.api.context import Context
14
11
  from phoenix.server.api.types.Evaluation import TraceEvaluation
12
+ from phoenix.server.api.types.node import Node
15
13
  from phoenix.server.api.types.pagination import (
14
+ Connection,
16
15
  ConnectionArgs,
17
16
  CursorString,
18
17
  connection_from_list,
@@ -22,16 +21,6 @@ from phoenix.server.api.types.Span import Span, to_gql_span
22
21
 
23
22
  @strawberry.type
24
23
  class Trace(Node):
25
- id_attr: NodeID[int]
26
- project_rowid: Private[int]
27
- trace_id: str
28
-
29
- @strawberry.field
30
- async def project_id(self) -> GlobalID:
31
- from phoenix.server.api.types.Project import Project
32
-
33
- return GlobalID(type_name=Project.__name__, node_id=str(self.project_rowid))
34
-
35
24
  @strawberry.field
36
25
  async def spans(
37
26
  self,
@@ -51,7 +40,7 @@ class Trace(Node):
51
40
  select(models.Span)
52
41
  .join(models.Trace)
53
42
  .where(models.Trace.id == self.id_attr)
54
- .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
43
+ .options(contains_eager(models.Span.trace))
55
44
  # Sort descending because the root span tends to show up later
56
45
  # in the ingestion process.
57
46
  .order_by(desc(models.Span.id))
@@ -3,13 +3,13 @@ from typing import List, Union
3
3
  import numpy as np
4
4
  import numpy.typing as npt
5
5
  import strawberry
6
- from strawberry.relay.types import GlobalID
7
6
  from strawberry.scalars import ID
8
7
 
9
8
  from phoenix.server.api.types.Cluster import Cluster
10
9
 
11
10
  from .EmbeddingMetadata import EmbeddingMetadata
12
11
  from .EventMetadata import EventMetadata
12
+ from .node import GlobalID
13
13
  from .Retrieval import Retrieval
14
14
 
15
15
 
@@ -1,19 +1,36 @@
1
- from typing import Tuple
1
+ import base64
2
+ import dataclasses
3
+ from typing import Tuple, Union
2
4
 
3
- from strawberry.relay import GlobalID
5
+ import strawberry
6
+ from graphql import GraphQLID
7
+ from strawberry.custom_scalar import ScalarDefinition
8
+ from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY
4
9
 
5
10
 
6
- def from_global_id(global_id: GlobalID) -> Tuple[str, int]:
11
+ def to_global_id(type_name: str, node_id: int) -> str:
12
+ """
13
+ Encode the given id into a global id.
14
+
15
+ :param type_name: The type of the node.
16
+ :param node_id: The id of the node.
17
+ :return: A global id.
18
+ """
19
+ return base64.b64encode(f"{type_name}:{node_id}".encode("utf-8")).decode()
20
+
21
+
22
+ def from_global_id(global_id: str) -> Tuple[str, int]:
7
23
  """
8
24
  Decode the given global id into a type and id.
9
25
 
10
26
  :param global_id: The global id to decode.
11
27
  :return: A tuple of type and id.
12
28
  """
13
- return global_id.type_name, int(global_id.node_id)
29
+ type_name, node_id = base64.b64decode(global_id).decode().split(":")
30
+ return type_name, int(node_id)
14
31
 
15
32
 
16
- def from_global_id_with_expected_type(global_id: GlobalID, expected_type_name: str) -> int:
33
+ def from_global_id_with_expected_type(global_id: str, expected_type_name: str) -> int:
17
34
  """
18
35
  Decodes the given global id and return the id, checking that the type
19
36
  matches the expected type.
@@ -25,3 +42,92 @@ def from_global_id_with_expected_type(global_id: GlobalID, expected_type_name: s
25
42
  f"but instead corresponds to a node of type: {type_name}"
26
43
  )
27
44
  return node_id
45
+
46
+
47
+ class GlobalIDValueError(ValueError):
48
+ """GlobalID value error, usually related to parsing or serialization."""
49
+
50
+
51
+ @dataclasses.dataclass(frozen=True)
52
+ class GlobalID:
53
+ """Global ID for relay types.
54
+ Different from `strawberry.ID`, this ID wraps the original object ID in a string
55
+ that contains both its GraphQL type name and the ID itself, and encodes it
56
+ to a base64_ string.
57
+ This object contains helpers to work with that, including method to retrieve
58
+ the python object type or even the encoded node itself.
59
+ Attributes:
60
+ type_name:
61
+ The type name part of the id
62
+ node_id:
63
+ The node id part of the id
64
+ .. _base64:
65
+ https://en.wikipedia.org/wiki/Base64
66
+ """
67
+
68
+ type_name: str
69
+ node_id: int
70
+
71
+ def __post_init__(self) -> None:
72
+ if not isinstance(self.type_name, str):
73
+ raise GlobalIDValueError(
74
+ f"type_name is expected to be a string, found {self.type_name}"
75
+ )
76
+ try:
77
+ # node_id could be numpy.int64, hence the need for coercion
78
+ object.__setattr__(self, "node_id", int(self.node_id))
79
+ except ValueError:
80
+ raise GlobalIDValueError(f"node_id is expected to be an int, found {self.node_id}")
81
+
82
+ def __str__(self) -> str:
83
+ return to_global_id(self.type_name, self.node_id)
84
+
85
+ @classmethod
86
+ def from_id(cls, value: Union[str, strawberry.ID]) -> "GlobalID":
87
+ """Create a new GlobalID from parsing the given value.
88
+ Args:
89
+ value:
90
+ The value to be parsed, as a base64 string in the "TypeName:NodeID" format
91
+ Returns:
92
+ An instance of GLobalID
93
+ Raises:
94
+ GlobalIDValueError:
95
+ If the value is not in a GLobalID format
96
+ """
97
+ try:
98
+ type_name, node_id = from_global_id(value)
99
+ except ValueError as e:
100
+ raise GlobalIDValueError(str(e)) from e
101
+
102
+ return cls(type_name=type_name, node_id=node_id)
103
+
104
+
105
+ @strawberry.interface(description="A node in the graph with a globally unique ID")
106
+ class Node:
107
+ """
108
+ All types that are relay ready should inherit from this interface and
109
+ implement the following methods.
110
+
111
+ Attributes:
112
+ id_attr:
113
+ The raw id field of node. Typically a database id or index
114
+ """
115
+
116
+ id_attr: strawberry.Private[int]
117
+
118
+ @strawberry.field
119
+ def id(self) -> GlobalID:
120
+ return GlobalID(type(self).__name__, self.id_attr)
121
+
122
+
123
+ # Register our GlobalID scalar
124
+ DEFAULT_SCALAR_REGISTRY[GlobalID] = ScalarDefinition(
125
+ # Use the same name/description/parse_literal from GraphQLID
126
+ # specs expect this type to be "ID".
127
+ name="GlobalID",
128
+ description=GraphQLID.description,
129
+ parse_literal=lambda v, vars=None: GlobalID.from_id(GraphQLID.parse_literal(v, vars)),
130
+ parse_value=GlobalID.from_id,
131
+ serialize=str,
132
+ specified_by_url="https://relay.dev/graphql/objectidentification.htm",
133
+ )
@@ -2,18 +2,60 @@ import base64
2
2
  from dataclasses import dataclass
3
3
  from datetime import datetime
4
4
  from enum import Enum, auto
5
- from typing import Any, ClassVar, List, Optional, Tuple, Union
5
+ from typing import ClassVar, Generic, List, Optional, Tuple, TypeVar, Union
6
6
 
7
+ import strawberry
7
8
  from strawberry import UNSET
8
- from strawberry.relay.types import Connection, Edge, NodeType, PageInfo
9
9
  from typing_extensions import TypeAlias, assert_never
10
10
 
11
11
  ID: TypeAlias = int
12
+ GenericType = TypeVar("GenericType")
12
13
  CursorSortColumnValue: TypeAlias = Union[str, int, float, datetime]
13
14
 
15
+
16
+ @strawberry.type
17
+ class Connection(Generic[GenericType]):
18
+ """Represents a paginated relationship between two entities
19
+
20
+ This pattern is used when the relationship itself has attributes.
21
+ """
22
+
23
+ page_info: "PageInfo"
24
+ edges: List["Edge[GenericType]"]
25
+
26
+
27
+ @strawberry.type
28
+ class PageInfo:
29
+ """Pagination context to navigate objects with cursor-based pagination
30
+
31
+ Instead of classic offset pagination via `page` and `limit` parameters,
32
+ here we have a cursor of the last object and we fetch items starting from that one
33
+
34
+ Read more at:
35
+ - https://graphql.org/learn/pagination/#pagination-and-edges
36
+ - https://relay.dev/graphql/connections.htm
37
+ """
38
+
39
+ has_next_page: bool
40
+ has_previous_page: bool
41
+ start_cursor: Optional[str]
42
+ end_cursor: Optional[str]
43
+
44
+
14
45
  # A type alias for the connection cursor implementation
15
46
  CursorString = str
16
47
 
48
+
49
+ @strawberry.type
50
+ class Edge(Generic[GenericType]):
51
+ """
52
+ An edge may contain additional information of the relationship. This is the trivial case
53
+ """
54
+
55
+ node: GenericType
56
+ cursor: str
57
+
58
+
17
59
  # The hashing prefix for a connection cursor
18
60
  CURSOR_PREFIX = "connection:"
19
61
 
@@ -176,9 +218,9 @@ class ConnectionArgs:
176
218
 
177
219
 
178
220
  def connection_from_list(
179
- data: List[NodeType],
221
+ data: List[GenericType],
180
222
  args: ConnectionArgs,
181
- ) -> Connection[NodeType]:
223
+ ) -> Connection[GenericType]:
182
224
  """
183
225
  A simple function that accepts a list and connection arguments, and returns
184
226
  a connection object for use in GraphQL. It uses list offsets as pagination,
@@ -188,11 +230,11 @@ def connection_from_list(
188
230
 
189
231
 
190
232
  def connection_from_list_slice(
191
- list_slice: List[NodeType],
233
+ list_slice: List[GenericType],
192
234
  args: ConnectionArgs,
193
235
  slice_start: int,
194
236
  list_length: int,
195
- ) -> Connection[NodeType]:
237
+ ) -> Connection[GenericType]:
196
238
  """
197
239
  Given a slice (subset) of a list, returns a connection object for use in
198
240
  GraphQL.
@@ -253,12 +295,12 @@ def connection_from_list_slice(
253
295
  )
254
296
 
255
297
 
256
- def connection_from_cursors_and_nodes(
257
- cursors_and_nodes: List[Tuple[Any, NodeType]],
298
+ def connections(
299
+ data: List[Tuple[Cursor, GenericType]],
258
300
  has_previous_page: bool,
259
301
  has_next_page: bool,
260
- ) -> Connection[NodeType]:
261
- edges = [Edge(node=node, cursor=str(cursor)) for cursor, node in cursors_and_nodes]
302
+ ) -> Connection[GenericType]:
303
+ edges = [Edge(node=node, cursor=str(cursor)) for cursor, node in data]
262
304
  has_edges = len(edges) > 0
263
305
  first_edge = edges[0] if has_edges else None
264
306
  last_edge = edges[-1] if has_edges else None