arize-phoenix 5.5.1__py3-none-any.whl → 5.6.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Note: this version of arize-phoenix has been flagged as potentially problematic.

Files changed (172)
  1. {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/METADATA +8 -11
  2. {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/RECORD +171 -171
  3. phoenix/config.py +8 -8
  4. phoenix/core/model.py +3 -3
  5. phoenix/core/model_schema.py +41 -50
  6. phoenix/core/model_schema_adapter.py +17 -16
  7. phoenix/datetime_utils.py +2 -2
  8. phoenix/db/bulk_inserter.py +10 -20
  9. phoenix/db/engines.py +2 -1
  10. phoenix/db/enums.py +2 -2
  11. phoenix/db/helpers.py +8 -7
  12. phoenix/db/insertion/dataset.py +9 -19
  13. phoenix/db/insertion/document_annotation.py +14 -13
  14. phoenix/db/insertion/helpers.py +6 -16
  15. phoenix/db/insertion/span_annotation.py +14 -13
  16. phoenix/db/insertion/trace_annotation.py +14 -13
  17. phoenix/db/insertion/types.py +19 -30
  18. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +8 -8
  19. phoenix/db/models.py +28 -28
  20. phoenix/experiments/evaluators/base.py +2 -1
  21. phoenix/experiments/evaluators/code_evaluators.py +4 -5
  22. phoenix/experiments/evaluators/llm_evaluators.py +157 -4
  23. phoenix/experiments/evaluators/utils.py +3 -2
  24. phoenix/experiments/functions.py +10 -21
  25. phoenix/experiments/tracing.py +2 -1
  26. phoenix/experiments/types.py +20 -29
  27. phoenix/experiments/utils.py +2 -1
  28. phoenix/inferences/errors.py +6 -5
  29. phoenix/inferences/fixtures.py +6 -5
  30. phoenix/inferences/inferences.py +37 -37
  31. phoenix/inferences/schema.py +11 -10
  32. phoenix/inferences/validation.py +13 -14
  33. phoenix/logging/_formatter.py +3 -3
  34. phoenix/metrics/__init__.py +5 -4
  35. phoenix/metrics/binning.py +2 -1
  36. phoenix/metrics/metrics.py +2 -1
  37. phoenix/metrics/mixins.py +7 -6
  38. phoenix/metrics/retrieval_metrics.py +2 -1
  39. phoenix/metrics/timeseries.py +5 -4
  40. phoenix/metrics/wrappers.py +2 -2
  41. phoenix/pointcloud/clustering.py +3 -4
  42. phoenix/pointcloud/pointcloud.py +7 -5
  43. phoenix/pointcloud/umap_parameters.py +2 -1
  44. phoenix/server/api/dataloaders/annotation_summaries.py +12 -19
  45. phoenix/server/api/dataloaders/average_experiment_run_latency.py +2 -2
  46. phoenix/server/api/dataloaders/cache/two_tier_cache.py +3 -2
  47. phoenix/server/api/dataloaders/dataset_example_revisions.py +3 -8
  48. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -5
  49. phoenix/server/api/dataloaders/document_evaluation_summaries.py +12 -18
  50. phoenix/server/api/dataloaders/document_evaluations.py +3 -7
  51. phoenix/server/api/dataloaders/document_retrieval_metrics.py +6 -13
  52. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +4 -8
  53. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -5
  54. phoenix/server/api/dataloaders/experiment_run_annotations.py +3 -7
  55. phoenix/server/api/dataloaders/experiment_run_counts.py +1 -5
  56. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -5
  57. phoenix/server/api/dataloaders/latency_ms_quantile.py +21 -30
  58. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +7 -13
  59. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  60. phoenix/server/api/dataloaders/record_counts.py +11 -18
  61. phoenix/server/api/dataloaders/span_annotations.py +3 -7
  62. phoenix/server/api/dataloaders/span_dataset_examples.py +3 -8
  63. phoenix/server/api/dataloaders/span_descendants.py +3 -7
  64. phoenix/server/api/dataloaders/span_projects.py +2 -2
  65. phoenix/server/api/dataloaders/token_counts.py +12 -19
  66. phoenix/server/api/dataloaders/trace_row_ids.py +3 -7
  67. phoenix/server/api/dataloaders/user_roles.py +3 -3
  68. phoenix/server/api/dataloaders/users.py +3 -3
  69. phoenix/server/api/helpers/__init__.py +4 -3
  70. phoenix/server/api/helpers/dataset_helpers.py +10 -9
  71. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +2 -2
  72. phoenix/server/api/input_types/AddSpansToDatasetInput.py +2 -2
  73. phoenix/server/api/input_types/ChatCompletionMessageInput.py +13 -1
  74. phoenix/server/api/input_types/ClusterInput.py +2 -2
  75. phoenix/server/api/input_types/DeleteAnnotationsInput.py +1 -3
  76. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +2 -2
  77. phoenix/server/api/input_types/DeleteExperimentsInput.py +1 -3
  78. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  79. phoenix/server/api/input_types/Granularity.py +1 -1
  80. phoenix/server/api/input_types/InvocationParameters.py +2 -2
  81. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
  82. phoenix/server/api/mutations/dataset_mutations.py +4 -4
  83. phoenix/server/api/mutations/experiment_mutations.py +1 -2
  84. phoenix/server/api/mutations/export_events_mutations.py +7 -7
  85. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  86. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  87. phoenix/server/api/mutations/user_mutations.py +4 -4
  88. phoenix/server/api/openapi/schema.py +2 -2
  89. phoenix/server/api/queries.py +20 -20
  90. phoenix/server/api/routers/oauth2.py +4 -4
  91. phoenix/server/api/routers/v1/datasets.py +22 -36
  92. phoenix/server/api/routers/v1/evaluations.py +6 -5
  93. phoenix/server/api/routers/v1/experiment_evaluations.py +2 -2
  94. phoenix/server/api/routers/v1/experiment_runs.py +2 -2
  95. phoenix/server/api/routers/v1/experiments.py +4 -4
  96. phoenix/server/api/routers/v1/spans.py +13 -12
  97. phoenix/server/api/routers/v1/traces.py +5 -5
  98. phoenix/server/api/routers/v1/utils.py +5 -5
  99. phoenix/server/api/subscriptions.py +289 -167
  100. phoenix/server/api/types/AnnotationSummary.py +3 -3
  101. phoenix/server/api/types/Cluster.py +8 -7
  102. phoenix/server/api/types/Dataset.py +5 -4
  103. phoenix/server/api/types/Dimension.py +3 -3
  104. phoenix/server/api/types/DocumentEvaluationSummary.py +8 -7
  105. phoenix/server/api/types/EmbeddingDimension.py +6 -5
  106. phoenix/server/api/types/EvaluationSummary.py +3 -3
  107. phoenix/server/api/types/Event.py +7 -7
  108. phoenix/server/api/types/Experiment.py +3 -3
  109. phoenix/server/api/types/ExperimentComparison.py +2 -4
  110. phoenix/server/api/types/Inferences.py +9 -8
  111. phoenix/server/api/types/InferencesRole.py +2 -2
  112. phoenix/server/api/types/Model.py +2 -2
  113. phoenix/server/api/types/Project.py +11 -18
  114. phoenix/server/api/types/Segments.py +3 -3
  115. phoenix/server/api/types/Span.py +8 -7
  116. phoenix/server/api/types/TimeSeries.py +8 -7
  117. phoenix/server/api/types/Trace.py +2 -2
  118. phoenix/server/api/types/UMAPPoints.py +6 -6
  119. phoenix/server/api/types/User.py +3 -3
  120. phoenix/server/api/types/node.py +1 -3
  121. phoenix/server/api/types/pagination.py +4 -4
  122. phoenix/server/api/utils.py +2 -4
  123. phoenix/server/app.py +16 -25
  124. phoenix/server/bearer_auth.py +4 -10
  125. phoenix/server/dml_event.py +3 -3
  126. phoenix/server/dml_event_handler.py +10 -24
  127. phoenix/server/grpc_server.py +3 -2
  128. phoenix/server/jwt_store.py +22 -21
  129. phoenix/server/main.py +3 -3
  130. phoenix/server/oauth2.py +3 -2
  131. phoenix/server/rate_limiters.py +5 -8
  132. phoenix/server/static/.vite/manifest.json +31 -31
  133. phoenix/server/static/assets/components-C70HJiXz.js +1612 -0
  134. phoenix/server/static/assets/{index-BHfTZ6x_.js → index-DLe1Oo3l.js} +2 -2
  135. phoenix/server/static/assets/{pages-aAez_Ntk.js → pages-C8-Sl7JI.js} +269 -434
  136. phoenix/server/static/assets/{vendor-6IcPAw_j.js → vendor-CtqfhlbC.js} +6 -6
  137. phoenix/server/static/assets/{vendor-arizeai-DRZuoyuF.js → vendor-arizeai-C_3SBz56.js} +2 -2
  138. phoenix/server/static/assets/{vendor-codemirror-DVE2_WBr.js → vendor-codemirror-wfdk9cjp.js} +1 -1
  139. phoenix/server/static/assets/{vendor-recharts-DwrexFA4.js → vendor-recharts-BiVnSv90.js} +1 -1
  140. phoenix/server/thread_server.py +1 -1
  141. phoenix/server/types.py +17 -29
  142. phoenix/services.py +4 -3
  143. phoenix/session/client.py +12 -24
  144. phoenix/session/data_extractor.py +3 -3
  145. phoenix/session/evaluation.py +1 -2
  146. phoenix/session/session.py +11 -20
  147. phoenix/trace/attributes.py +16 -28
  148. phoenix/trace/dsl/filter.py +17 -21
  149. phoenix/trace/dsl/helpers.py +3 -3
  150. phoenix/trace/dsl/query.py +13 -22
  151. phoenix/trace/fixtures.py +11 -17
  152. phoenix/trace/otel.py +5 -15
  153. phoenix/trace/projects.py +3 -2
  154. phoenix/trace/schemas.py +2 -2
  155. phoenix/trace/span_evaluations.py +9 -8
  156. phoenix/trace/span_json_decoder.py +3 -3
  157. phoenix/trace/span_json_encoder.py +2 -2
  158. phoenix/trace/trace_dataset.py +6 -5
  159. phoenix/trace/utils.py +6 -6
  160. phoenix/utilities/deprecation.py +3 -2
  161. phoenix/utilities/error_handling.py +3 -2
  162. phoenix/utilities/json.py +2 -1
  163. phoenix/utilities/logging.py +2 -2
  164. phoenix/utilities/project.py +1 -1
  165. phoenix/utilities/re.py +3 -4
  166. phoenix/utilities/template_formatters.py +5 -4
  167. phoenix/version.py +1 -1
  168. phoenix/server/static/assets/components-mVBxvljU.js +0 -1428
  169. {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/WHEEL +0 -0
  170. {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/entry_points.txt +0 -0
  171. {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/licenses/IP_NOTICE +0 -0
  172. {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/licenses/LICENSE +0 -0
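
The excerpts below cover a subset of the 172 changed files. Nearly every hunk applies the same mechanical modernization: container annotations move from the `typing` aliases (`List`, `Dict`, `Tuple`, `DefaultDict`) to the builtin generics standardized by PEP 585 (Python 3.9+), and abstract collection types (`Iterable`, `Iterator`, `Mapping`) move from `typing` to `collections.abc`. A minimal sketch of the before/after pattern, using a hypothetical function rather than code from the package:

# Sketch of the typing migration applied across this release
# (hypothetical names; not from the phoenix codebase).
from collections import defaultdict
from collections.abc import Iterable  # previously: from typing import Iterable

# Before (Python < 3.9):
#     from typing import DefaultDict, List, Tuple
#     def bucket(pairs: List[Tuple[str, int]]) -> DefaultDict[str, List[int]]: ...

# After (builtin generics, Python >= 3.9):
def bucket(pairs: Iterable[tuple[str, int]]) -> defaultdict[str, list[int]]:
    out: defaultdict[str, list[int]] = defaultdict(list)
    for key, value in pairs:
        out[key].append(value)
    return out

print(bucket([("a", 1), ("a", 2), ("b", 3)]))
# defaultdict(<class 'list'>, {'a': [1, 2], 'b': [3]})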

phoenix/server/api/dataloaders/record_counts.py

@@ -1,13 +1,6 @@
 from collections import defaultdict
 from datetime import datetime
-from typing import (
-    Any,
-    DefaultDict,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-)
+from typing import Any, Literal, Optional
 
 from cachetools import LFUCache, TTLCache
 from sqlalchemy import Select, func, select
@@ -22,20 +15,20 @@ from phoenix.trace.dsl import SpanFilter
 
 Kind: TypeAlias = Literal["span", "trace"]
 ProjectRowId: TypeAlias = int
-TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
+TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
 FilterCondition: TypeAlias = Optional[str]
 SpanCount: TypeAlias = int
 
-Segment: TypeAlias = Tuple[Kind, TimeInterval, FilterCondition]
+Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition]
 Param: TypeAlias = ProjectRowId
 
-Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition]
+Key: TypeAlias = tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition]
 Result: TypeAlias = SpanCount
 ResultPosition: TypeAlias = int
 DEFAULT_VALUE: Result = 0
 
 
-def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
+def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
     kind, project_rowid, time_range, filter_condition = key
     interval = (
         (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
@@ -44,7 +37,7 @@ def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
 
 
 _Section: TypeAlias = ProjectRowId
-_SubKey: TypeAlias = Tuple[TimeInterval, FilterCondition, Kind]
+_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, Kind]
 
 
 class RecordCountCache(
@@ -59,7 +52,7 @@ class RecordCountCache(
         sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 2),
     )
 
-    def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+    def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
         (kind, interval, filter_condition), project_rowid = _cache_key_fn(key)
         return project_rowid, (interval, filter_condition, kind)
 
@@ -77,11 +70,11 @@ class RecordCountDataLoader(DataLoader[Key, Result]):
         )
         self._db = db
 
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
-        results: List[Result] = [DEFAULT_VALUE] * len(keys)
-        arguments: DefaultDict[
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        results: list[Result] = [DEFAULT_VALUE] * len(keys)
+        arguments: defaultdict[
            Segment,
-            DefaultDict[Param, List[ResultPosition]],
+            defaultdict[Param, list[ResultPosition]],
        ] = defaultdict(lambda: defaultdict(list))
        for position, key in enumerate(keys):
            segment, param = _cache_key_fn(key)

phoenix/server/api/dataloaders/span_annotations.py

@@ -1,8 +1,4 @@
 from collections import defaultdict
-from typing import (
-    DefaultDict,
-    List,
-)
 
 from sqlalchemy import select
 from strawberry.dataloader import DataLoader
@@ -12,7 +8,7 @@ from phoenix.db.models import SpanAnnotation as ORMSpanAnnotation
 from phoenix.server.types import DbSessionFactory
 
 Key: TypeAlias = int
-Result: TypeAlias = List[ORMSpanAnnotation]
+Result: TypeAlias = list[ORMSpanAnnotation]
 
 
 class SpanAnnotationsDataLoader(DataLoader[Key, Result]):
@@ -20,8 +16,8 @@ class SpanAnnotationsDataLoader(DataLoader[Key, Result]):
         super().__init__(load_fn=self._load_fn)
         self._db = db
 
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
-        span_annotations_by_id: DefaultDict[Key, Result] = defaultdict(list)
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        span_annotations_by_id: defaultdict[Key, Result] = defaultdict(list)
         async with self._db() as session:
             async for span_annotation in await session.stream_scalars(
                 select(ORMSpanAnnotation).where(ORMSpanAnnotation.span_rowid.in_(keys))

phoenix/server/api/dataloaders/span_dataset_examples.py

@@ -1,8 +1,3 @@
-from typing import (
-    Dict,
-    List,
-)
-
 from sqlalchemy import select
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
@@ -12,7 +7,7 @@ from phoenix.server.types import DbSessionFactory
 
 SpanID: TypeAlias = int
 Key: TypeAlias = SpanID
-Result: TypeAlias = List[models.DatasetExample]
+Result: TypeAlias = list[models.DatasetExample]
 
 
 class SpanDatasetExamplesDataLoader(DataLoader[Key, Result]):
@@ -20,10 +15,10 @@ class SpanDatasetExamplesDataLoader(DataLoader[Key, Result]):
         super().__init__(load_fn=self._load_fn)
         self._db = db
 
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
         span_rowids = keys
         async with self._db() as session:
-            dataset_examples: Dict[Key, List[models.DatasetExample]] = {
+            dataset_examples: dict[Key, list[models.DatasetExample]] = {
                 span_rowid: [] for span_rowid in span_rowids
             }
             async for span_rowid, dataset_example in await session.stream(

phoenix/server/api/dataloaders/span_descendants.py

@@ -1,8 +1,4 @@
 from random import randint
-from typing import (
-    Dict,
-    List,
-)
 
 from aioitertools.itertools import groupby
 from sqlalchemy import select
@@ -16,7 +12,7 @@ from phoenix.server.types import DbSessionFactory
 SpanId: TypeAlias = str
 
 Key: TypeAlias = SpanId
-Result: TypeAlias = List[models.Span]
+Result: TypeAlias = list[models.Span]
 
 
 class SpanDescendantsDataLoader(DataLoader[Key, Result]):
@@ -24,7 +20,7 @@ class SpanDescendantsDataLoader(DataLoader[Key, Result]):
         super().__init__(load_fn=self._load_fn)
         self._db = db
 
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
         root_ids = set(keys)
         root_id_label = f"root_id_{randint(0, 10**6):06}"
         descendant_ids = (
@@ -53,7 +49,7 @@ class SpanDescendantsDataLoader(DataLoader[Key, Result]):
            .options(joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id))
            .order_by(descendant_ids.c[root_id_label])
        )
-        results: Dict[SpanId, Result] = {key: [] for key in keys}
+        results: dict[SpanId, Result] = {key: [] for key in keys}
        async with self._db() as session:
            data = await session.stream(stmt)
            async for root_id, group in groupby(data, key=lambda d: d[0]):

phoenix/server/api/dataloaders/span_projects.py

@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import Union
 
 from sqlalchemy import select
 from strawberry.dataloader import DataLoader
@@ -17,7 +17,7 @@ class SpanProjectsDataLoader(DataLoader[Key, Result]):
         super().__init__(load_fn=self._load_fn)
         self._db = db
 
-    async def _load_fn(self, keys: List[Key]) -> List[Union[Result, ValueError]]:
+    async def _load_fn(self, keys: list[Key]) -> list[Union[Result, ValueError]]:
         span_ids = list(set(keys))
         async with self._db() as session:
             projects = {

phoenix/server/api/dataloaders/token_counts.py

@@ -1,13 +1,6 @@
 from collections import defaultdict
 from datetime import datetime
-from typing import (
-    Any,
-    DefaultDict,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-)
+from typing import Any, Literal, Optional
 
 from cachetools import LFUCache, TTLCache
 from sqlalchemy import Select, func, select
@@ -23,20 +16,20 @@ from phoenix.trace.dsl import SpanFilter
 
 Kind: TypeAlias = Literal["prompt", "completion", "total"]
 ProjectRowId: TypeAlias = int
-TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
+TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
 FilterCondition: TypeAlias = Optional[str]
 TokenCount: TypeAlias = int
 
-Segment: TypeAlias = Tuple[TimeInterval, FilterCondition]
-Param: TypeAlias = Tuple[ProjectRowId, Kind]
+Segment: TypeAlias = tuple[TimeInterval, FilterCondition]
+Param: TypeAlias = tuple[ProjectRowId, Kind]
 
-Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition]
+Key: TypeAlias = tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition]
 Result: TypeAlias = TokenCount
 ResultPosition: TypeAlias = int
 DEFAULT_VALUE: Result = 0
 
 
-def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
+def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
     kind, project_rowid, time_range, filter_condition = key
     interval = (
         (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
@@ -45,7 +38,7 @@ def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
 
 
 _Section: TypeAlias = ProjectRowId
-_SubKey: TypeAlias = Tuple[TimeInterval, FilterCondition, Kind]
+_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, Kind]
 
 
 class TokenCountCache(
@@ -60,7 +53,7 @@ class TokenCountCache(
         sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 3),
     )
 
-    def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+    def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
         (interval, filter_condition), (project_rowid, kind) = _cache_key_fn(key)
         return project_rowid, (interval, filter_condition, kind)
 
@@ -78,11 +71,11 @@ class TokenCountDataLoader(DataLoader[Key, Result]):
         )
         self._db = db
 
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
-        results: List[Result] = [DEFAULT_VALUE] * len(keys)
-        arguments: DefaultDict[
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        results: list[Result] = [DEFAULT_VALUE] * len(keys)
+        arguments: defaultdict[
            Segment,
-            DefaultDict[Param, List[ResultPosition]],
+            defaultdict[Param, list[ResultPosition]],
        ] = defaultdict(lambda: defaultdict(list))
        for position, key in enumerate(keys):
            segment, param = _cache_key_fn(key)

phoenix/server/api/dataloaders/trace_row_ids.py

@@ -1,8 +1,4 @@
-from typing import (
-    List,
-    Optional,
-    Tuple,
-)
+from typing import Optional
 
 from sqlalchemy import select
 from strawberry.dataloader import DataLoader
@@ -15,7 +11,7 @@ TraceId: TypeAlias = str
 Key: TypeAlias = TraceId
 TraceRowId: TypeAlias = int
 ProjectRowId: TypeAlias = int
-Result: TypeAlias = Optional[Tuple[TraceRowId, ProjectRowId]]
+Result: TypeAlias = Optional[tuple[TraceRowId, ProjectRowId]]
 
 
 class TraceRowIdsDataLoader(DataLoader[Key, Result]):
@@ -23,7 +19,7 @@ class TraceRowIdsDataLoader(DataLoader[Key, Result]):
         super().__init__(load_fn=self._load_fn)
         self._db = db
 
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
         stmt = select(
             models.Trace.trace_id,
             models.Trace.id,

phoenix/server/api/dataloaders/user_roles.py

@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import DefaultDict, List, Optional
+from typing import Optional
 
 from sqlalchemy import select
 from strawberry.dataloader import DataLoader
@@ -20,8 +20,8 @@ class UserRolesDataLoader(DataLoader[Key, Result]):
         super().__init__(load_fn=self._load_fn)
         self._db = db
 
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
-        user_roles_by_id: DefaultDict[Key, Result] = defaultdict(None)
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        user_roles_by_id: defaultdict[Key, Result] = defaultdict(None)
         async with self._db() as session:
             data = await session.stream_scalars(select(models.UserRole))
             async for user_role in data:

phoenix/server/api/dataloaders/users.py

@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import DefaultDict, List, Optional
+from typing import Optional
 
 from sqlalchemy import select
 from strawberry.dataloader import DataLoader
@@ -20,9 +20,9 @@ class UsersDataLoader(DataLoader[Key, Result]):
         super().__init__(load_fn=self._load_fn)
         self._db = db
 
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
         user_ids = list(set(keys))
-        users_by_id: DefaultDict[Key, Result] = defaultdict(None)
+        users_by_id: defaultdict[Key, Result] = defaultdict(None)
         async with self._db() as session:
             data = await session.stream_scalars(
                 select(models.User).where(models.User.id.in_(user_ids))

phoenix/server/api/helpers/__init__.py

@@ -1,10 +1,11 @@
-from typing import Iterable, List, Optional, TypeVar
+from collections.abc import Iterable
+from typing import Optional, TypeVar
 
 T = TypeVar("T")
 
 
-def ensure_list(obj: Optional[Iterable[T]]) -> List[T]:
-    if isinstance(obj, List):
+def ensure_list(obj: Optional[Iterable[T]]) -> list[T]:
+    if isinstance(obj, list):
         return obj
     if isinstance(obj, Iterable):
         return list(obj)
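
Beyond the import change, the hunk above also swaps the runtime check from `isinstance(obj, List)` to `isinstance(obj, list)`; checking against `typing.List` happened to work, but the builtin class is the idiomatic target. A self-contained sketch of the helper's visible behavior follows; the branch for inputs that are neither lists nor iterables (e.g. `None`) falls outside the hunk, so the final `return []` is an assumption:

from collections.abc import Iterable
from typing import Optional, TypeVar

T = TypeVar("T")

def ensure_list(obj: Optional[Iterable[T]]) -> list[T]:
    if isinstance(obj, list):
        return obj  # lists pass through by identity
    if isinstance(obj, Iterable):
        return list(obj)  # other iterables are materialized
    return []  # assumed fallback; this branch is not shown in the hunk

assert ensure_list([1, 2]) == [1, 2]
assert ensure_list(iter((1, 2))) == [1, 2]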

phoenix/server/api/helpers/dataset_helpers.py

@@ -1,5 +1,6 @@
 import json
-from typing import Any, Dict, Literal, Mapping, Optional, Protocol
+from collections.abc import Mapping
+from typing import Any, Literal, Optional, Protocol
 
 from openinference.semconv.trace import (
     MessageAttributes,
@@ -28,7 +29,7 @@ class HasSpanIO(Protocol):
     retrieval_documents: Any
 
 
-def get_dataset_example_input(span: HasSpanIO) -> Dict[str, Any]:
+def get_dataset_example_input(span: HasSpanIO) -> dict[str, Any]:
     """
     Extracts the input value from a span and returns it as a dictionary. Input
     values from LLM spans are extracted from the input messages and prompt
@@ -47,7 +48,7 @@ def get_dataset_example_input(span: HasSpanIO) -> Dict[str, Any]:
     return _get_generic_io_value(io_value=input_value, mime_type=input_mime_type, kind="input")
 
 
-def get_dataset_example_output(span: HasSpanIO) -> Dict[str, Any]:
+def get_dataset_example_output(span: HasSpanIO) -> dict[str, Any]:
     """
     Extracts the output value from a span and returns it as a dictionary. Output
     values from LLM spans are extracted from the output messages (if present).
@@ -78,13 +79,13 @@ def _get_llm_span_input(
     input_value: Any,
     input_mime_type: Optional[str],
     prompt_template_variables: Any,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Extracts the input value from an LLM span and returns it as a dictionary.
     The input is extracted from the input messages (if present) and prompt
     template variables (if present).
     """
-    input: Dict[str, Any] = {}
+    input: dict[str, Any] = {}
     if messages := [_get_message(m) for m in input_messages or ()]:
         input["messages"] = messages
     if not input:
@@ -98,7 +99,7 @@ def _get_llm_span_output(
     output_messages: Any,
     output_value: Any,
     output_mime_type: Optional[str],
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Extracts the output value from an LLM span and returns it as a dictionary.
     The output is extracted from the output messages (if present).
@@ -112,7 +113,7 @@ def _get_retriever_span_output(
     retrieval_documents: Any,
     output_value: Any,
     output_mime_type: Optional[str],
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Extracts the output value from a retriever span and returns it as a dictionary.
     The output is extracted from the retrieval documents (if present).
@@ -124,7 +125,7 @@ def _get_retriever_span_output(
 
 def _get_generic_io_value(
     io_value: Any, mime_type: Optional[str], kind: Literal["input", "output"]
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Makes a best-effort attempt to extract the input or output value from a span
     and returns it as a dictionary.
@@ -140,7 +141,7 @@ def _get_generic_io_value(
     return {}
 
 
-def _get_message(message: Mapping[str, Any]) -> Dict[str, Any]:
+def _get_message(message: Mapping[str, Any]) -> dict[str, Any]:
     content = get_attribute_value(message, MESSAGE_CONTENT)
     name = get_attribute_value(message, MESSAGE_NAME)
     function_call_name = get_attribute_value(message, MESSAGE_FUNCTION_CALL_NAME)

phoenix/server/api/input_types/AddExamplesToDatasetInput.py

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 
 import strawberry
 from strawberry import UNSET
@@ -11,6 +11,6 @@ from .DatasetExampleInput import DatasetExampleInput
 @strawberry.input
 class AddExamplesToDatasetInput:
     dataset_id: GlobalID
-    examples: List[DatasetExampleInput]
+    examples: list[DatasetExampleInput]
     dataset_version_description: Optional[str] = UNSET
     dataset_version_metadata: Optional[JSON] = UNSET

phoenix/server/api/input_types/AddSpansToDatasetInput.py

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 
 import strawberry
 from strawberry import UNSET
@@ -9,6 +9,6 @@ from strawberry.scalars import JSON
 @strawberry.input
 class AddSpansToDatasetInput:
     dataset_id: GlobalID
-    span_ids: List[GlobalID]
+    span_ids: list[GlobalID]
     dataset_version_description: Optional[str] = UNSET
     dataset_version_metadata: Optional[JSON] = UNSET

phoenix/server/api/input_types/ChatCompletionMessageInput.py

@@ -1,4 +1,7 @@
+from typing import Optional
+
 import strawberry
+from strawberry import UNSET
 from strawberry.scalars import JSON
 
 from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
@@ -8,5 +11,14 @@ from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMes
 class ChatCompletionMessageInput:
     role: ChatCompletionMessageRole
     content: JSON = strawberry.field(
-        description="The content of the message as JSON to support text and tools",
+        default="",
+        description="The content of the message as JSON to support various kinds of text",
+    )
+    tool_calls: Optional[list[JSON]] = strawberry.field(
+        description="The tool calls that were made in the message",
+        default=UNSET,
+    )
+    tool_call_id: Optional[str] = strawberry.field(
+        description="The ID that corresponds to a prior tool call. Used to link a tool message to a pre-existing tool call.",  # noqa: E501
+        default=UNSET,
     )
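
The two new optional fields let a chat message carry assistant tool calls and let a tool-result message point back at the call it answers. A hedged construction sketch in Python (strawberry input classes accept keyword arguments when instantiated directly; the `TOOL` role member and the example values are assumptions, not taken from the diff):

from phoenix.server.api.input_types.ChatCompletionMessageInput import (
    ChatCompletionMessageInput,
)
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole

# A tool-result message linked to a prior tool call (role member name and
# values are assumed for illustration).
message = ChatCompletionMessageInput(
    role=ChatCompletionMessageRole.TOOL,
    content="72 degrees and sunny",
    tool_call_id="call_abc123",
)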

phoenix/server/api/input_types/ClusterInput.py

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 
 import strawberry
 from strawberry import ID, UNSET
@@ -6,5 +6,5 @@ from strawberry import ID, UNSET
 
 @strawberry.input
 class ClusterInput:
-    event_ids: List[ID]
+    event_ids: list[ID]
     id: Optional[ID] = UNSET

phoenix/server/api/input_types/DeleteAnnotationsInput.py

@@ -1,9 +1,7 @@
-from typing import List
-
 import strawberry
 from strawberry.relay import GlobalID
 
 
 @strawberry.input
 class DeleteAnnotationsInput:
-    annotation_ids: List[GlobalID]
+    annotation_ids: list[GlobalID]

phoenix/server/api/input_types/DeleteDatasetExamplesInput.py

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 
 import strawberry
 from strawberry import UNSET
@@ -8,6 +8,6 @@ from strawberry.scalars import JSON
 
 @strawberry.input
 class DeleteDatasetExamplesInput:
-    example_ids: List[GlobalID]
+    example_ids: list[GlobalID]
     dataset_version_description: Optional[str] = UNSET
     dataset_version_metadata: Optional[JSON] = UNSET

phoenix/server/api/input_types/DeleteExperimentsInput.py

@@ -1,9 +1,7 @@
-from typing import List
-
 import strawberry
 from strawberry.relay import GlobalID
 
 
 @strawberry.input
 class DeleteExperimentsInput:
-    experiment_ids: List[GlobalID]
+    experiment_ids: list[GlobalID]

phoenix/server/api/input_types/DimensionFilter.py

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 
 import strawberry
 from strawberry import UNSET
@@ -52,9 +52,9 @@ class DimensionFilter:
 
     """
 
-    types: Optional[List[DimensionType]] = UNSET
-    shapes: Optional[List[DimensionShape]] = UNSET
-    data_types: Optional[List[DimensionDataType]] = UNSET
+    types: Optional[list[DimensionType]] = UNSET
+    shapes: Optional[list[DimensionShape]] = UNSET
+    data_types: Optional[list[DimensionDataType]] = UNSET
 
     def __post_init__(self) -> None:
         self.types = ensure_list(self.types)

phoenix/server/api/input_types/Granularity.py

@@ -1,6 +1,6 @@
+from collections.abc import Iterator
 from datetime import datetime, timedelta
 from itertools import accumulate, repeat, takewhile
-from typing import Iterator
 
 import strawberry
 

phoenix/server/api/input_types/InvocationParameters.py

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 
 import strawberry
 from strawberry import UNSET
@@ -15,6 +15,6 @@ class InvocationParameters:
     max_completion_tokens: Optional[int] = UNSET
     max_tokens: Optional[int] = UNSET
     top_p: Optional[float] = UNSET
-    stop: Optional[List[str]] = UNSET
+    stop: Optional[list[str]] = UNSET
     seed: Optional[int] = UNSET
     tool_choice: Optional[JSON] = UNSET

phoenix/server/api/input_types/PatchDatasetExamplesInput.py

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 
 import strawberry
 from strawberry import UNSET
@@ -30,6 +30,6 @@ class PatchDatasetExamplesInput:
     Input type to the patchDatasetExamples mutation.
     """
 
-    patches: List[DatasetExamplePatch]
+    patches: list[DatasetExamplePatch]
     version_description: Optional[str] = UNSET
     version_metadata: Optional[JSON] = UNSET

phoenix/server/api/mutations/dataset_mutations.py

@@ -1,6 +1,6 @@
 import asyncio
 from datetime import datetime
-from typing import Any, Dict
+from typing import Any
 
 import strawberry
 from openinference.semconv.trace import (
@@ -175,7 +175,7 @@ class DatasetMutationMixin:
             )
         ).all()
 
-        span_annotations_by_span: Dict[int, Dict[Any, Any]] = {span.id: {} for span in spans}
+        span_annotations_by_span: dict[int, dict[Any, Any]] = {span.id: {} for span in spans}
         for annotation in span_annotations:
             span_id = annotation.span_rowid
             if span_id not in span_annotations_by_span:
@@ -287,7 +287,7 @@ class DatasetMutationMixin:
             )
         ).all()
 
-        span_annotations_by_span: Dict[int, Dict[Any, Any]] = {span.id: {} for span in spans}
+        span_annotations_by_span: dict[int, dict[Any, Any]] = {span.id: {} for span in spans}
         for annotation in span_annotations:
             span_id = annotation.span_rowid
             if span_id not in span_annotations_by_span:
@@ -577,7 +577,7 @@ def _to_orm_revision(
     patch: DatasetExamplePatch,
     example_id: int,
     version_id: int,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Creates a new revision from an existing revision and a patch. The output is a
     dictionary suitable for insertion into the database using the sqlalchemy

phoenix/server/api/mutations/experiment_mutations.py

@@ -1,5 +1,4 @@
 import asyncio
-from typing import List
 
 import strawberry
 from sqlalchemy import delete
@@ -20,7 +19,7 @@ from phoenix.server.dml_event import ExperimentDeleteEvent
 
 @strawberry.type
 class ExperimentMutationPayload:
-    experiments: List[Experiment]
+    experiments: list[Experiment]
 
 
 @strawberry.type