arize-phoenix 11.38.0__py3-none-any.whl → 12.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; consult the registry's advisory page for more details.

Files changed (84) hide show
  1. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/METADATA +3 -3
  2. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/RECORD +83 -58
  3. phoenix/config.py +1 -11
  4. phoenix/db/bulk_inserter.py +8 -0
  5. phoenix/db/facilitator.py +1 -1
  6. phoenix/db/helpers.py +202 -33
  7. phoenix/db/insertion/dataset.py +7 -0
  8. phoenix/db/insertion/document_annotation.py +1 -1
  9. phoenix/db/insertion/helpers.py +2 -2
  10. phoenix/db/insertion/session_annotation.py +176 -0
  11. phoenix/db/insertion/span_annotation.py +1 -1
  12. phoenix/db/insertion/trace_annotation.py +1 -1
  13. phoenix/db/insertion/types.py +29 -3
  14. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  15. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  16. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  17. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  18. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  19. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  20. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  21. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  22. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  23. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  24. phoenix/db/models.py +306 -46
  25. phoenix/server/api/context.py +15 -2
  26. phoenix/server/api/dataloaders/__init__.py +8 -2
  27. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  28. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  29. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  30. phoenix/server/api/dataloaders/table_fields.py +2 -2
  31. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  32. phoenix/server/api/helpers/playground_clients.py +66 -35
  33. phoenix/server/api/helpers/playground_users.py +26 -0
  34. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  35. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  36. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  37. phoenix/server/api/mutations/__init__.py +8 -0
  38. phoenix/server/api/mutations/chat_mutations.py +8 -3
  39. phoenix/server/api/mutations/dataset_label_mutations.py +291 -0
  40. phoenix/server/api/mutations/dataset_mutations.py +5 -0
  41. phoenix/server/api/mutations/dataset_split_mutations.py +423 -0
  42. phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
  43. phoenix/server/api/queries.py +53 -0
  44. phoenix/server/api/routers/auth.py +5 -5
  45. phoenix/server/api/routers/oauth2.py +5 -23
  46. phoenix/server/api/routers/v1/__init__.py +2 -0
  47. phoenix/server/api/routers/v1/annotations.py +320 -0
  48. phoenix/server/api/routers/v1/datasets.py +5 -0
  49. phoenix/server/api/routers/v1/experiments.py +10 -3
  50. phoenix/server/api/routers/v1/sessions.py +111 -0
  51. phoenix/server/api/routers/v1/traces.py +1 -2
  52. phoenix/server/api/routers/v1/users.py +7 -0
  53. phoenix/server/api/subscriptions.py +5 -2
  54. phoenix/server/api/types/Dataset.py +8 -0
  55. phoenix/server/api/types/DatasetExample.py +18 -0
  56. phoenix/server/api/types/DatasetLabel.py +23 -0
  57. phoenix/server/api/types/DatasetSplit.py +32 -0
  58. phoenix/server/api/types/Experiment.py +0 -4
  59. phoenix/server/api/types/Project.py +16 -0
  60. phoenix/server/api/types/ProjectSession.py +88 -3
  61. phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
  62. phoenix/server/api/types/Prompt.py +18 -1
  63. phoenix/server/api/types/Span.py +5 -5
  64. phoenix/server/api/types/Trace.py +61 -0
  65. phoenix/server/app.py +13 -14
  66. phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
  67. phoenix/server/dml_event.py +13 -0
  68. phoenix/server/static/.vite/manifest.json +39 -39
  69. phoenix/server/static/assets/{components-BQPHTBfv.js → components-BG6v0EM8.js} +705 -385
  70. phoenix/server/static/assets/{index-BL5BMgJU.js → index-CSVcULw1.js} +13 -13
  71. phoenix/server/static/assets/{pages-C0Y17J0T.js → pages-DgaM7kpM.js} +1356 -1155
  72. phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-BqTEkGQU.js} +183 -183
  73. phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-DlOj0PQQ.js} +15 -24
  74. phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-B2PHH5yZ.js} +3 -3
  75. phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts-CKsi4IjN.js} +1 -1
  76. phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-DN26BkKE.js} +1 -1
  77. phoenix/server/utils.py +74 -0
  78. phoenix/session/session.py +25 -5
  79. phoenix/version.py +1 -1
  80. phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
  81. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/WHEEL +0 -0
  82. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/entry_points.txt +0 -0
  83. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/IP_NOTICE +0 -0
  84. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -23,6 +23,7 @@ from phoenix.server.api.dataloaders import (
23
23
  DatasetExampleRevisionsDataLoader,
24
24
  DatasetExamplesAndVersionsByExperimentRunDataLoader,
25
25
  DatasetExampleSpansDataLoader,
26
+ DatasetExampleSplitsDataLoader,
26
27
  DocumentEvaluationsDataLoader,
27
28
  DocumentEvaluationSummaryDataLoader,
28
29
  DocumentRetrievalMetricsDataLoader,
@@ -30,7 +31,6 @@ from phoenix.server.api.dataloaders import (
30
31
  ExperimentErrorRatesDataLoader,
31
32
  ExperimentRepeatedRunGroupAnnotationSummariesDataLoader,
32
33
  ExperimentRepeatedRunGroupsDataLoader,
33
- ExperimentRepetitionCountsDataLoader,
34
34
  ExperimentRunAnnotations,
35
35
  ExperimentRunCountsDataLoader,
36
36
  ExperimentSequenceNumberDataLoader,
@@ -43,6 +43,7 @@ from phoenix.server.api.dataloaders import (
43
43
  ProjectIdsByTraceRetentionPolicyIdDataLoader,
44
44
  PromptVersionSequenceNumberDataLoader,
45
45
  RecordCountDataLoader,
46
+ SessionAnnotationsBySessionDataLoader,
46
47
  SessionIODataLoader,
47
48
  SessionNumTracesDataLoader,
48
49
  SessionNumTracesWithErrorDataLoader,
@@ -68,12 +69,14 @@ from phoenix.server.api.dataloaders import (
68
69
  SpanProjectsDataLoader,
69
70
  TableFieldsDataLoader,
70
71
  TokenCountDataLoader,
72
+ TraceAnnotationsByTraceDataLoader,
71
73
  TraceByTraceIdsDataLoader,
72
74
  TraceRetentionPolicyIdByProjectIdDataLoader,
73
75
  TraceRootSpansDataLoader,
74
76
  UserRolesDataLoader,
75
77
  UsersDataLoader,
76
78
  )
79
+ from phoenix.server.api.dataloaders.dataset_labels import DatasetLabelsDataLoader
77
80
  from phoenix.server.bearer_auth import PhoenixUser
78
81
  from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
79
82
  from phoenix.server.dml_event import DmlEvent
@@ -97,9 +100,11 @@ class DataLoaders:
97
100
  average_experiment_run_latency: AverageExperimentRunLatencyDataLoader
98
101
  dataset_example_revisions: DatasetExampleRevisionsDataLoader
99
102
  dataset_example_spans: DatasetExampleSpansDataLoader
103
+ dataset_labels: DatasetLabelsDataLoader
100
104
  dataset_examples_and_versions_by_experiment_run: (
101
105
  DatasetExamplesAndVersionsByExperimentRunDataLoader
102
106
  )
107
+ dataset_example_splits: DatasetExampleSplitsDataLoader
103
108
  document_evaluation_summaries: DocumentEvaluationSummaryDataLoader
104
109
  document_evaluations: DocumentEvaluationsDataLoader
105
110
  document_retrieval_metrics: DocumentRetrievalMetricsDataLoader
@@ -109,7 +114,6 @@ class DataLoaders:
109
114
  ExperimentRepeatedRunGroupAnnotationSummariesDataLoader
110
115
  )
111
116
  experiment_repeated_run_groups: ExperimentRepeatedRunGroupsDataLoader
112
- experiment_repetition_counts: ExperimentRepetitionCountsDataLoader
113
117
  experiment_run_annotations: ExperimentRunAnnotations
114
118
  experiment_run_counts: ExperimentRunCountsDataLoader
115
119
  experiment_sequence_number: ExperimentSequenceNumberDataLoader
@@ -124,6 +128,7 @@ class DataLoaders:
124
128
  projects_by_trace_retention_policy_id: ProjectIdsByTraceRetentionPolicyIdDataLoader
125
129
  prompt_version_sequence_number: PromptVersionSequenceNumberDataLoader
126
130
  record_counts: RecordCountDataLoader
131
+ session_annotations_by_session: SessionAnnotationsBySessionDataLoader
127
132
  session_first_inputs: SessionIODataLoader
128
133
  session_last_outputs: SessionIODataLoader
129
134
  session_num_traces: SessionNumTracesDataLoader
@@ -158,6 +163,7 @@ class DataLoaders:
158
163
  span_fields: TableFieldsDataLoader
159
164
  span_projects: SpanProjectsDataLoader
160
165
  token_counts: TokenCountDataLoader
166
+ trace_annotations_by_trace: TraceAnnotationsByTraceDataLoader
161
167
  trace_by_trace_ids: TraceByTraceIdsDataLoader
162
168
  trace_fields: TableFieldsDataLoader
163
169
  trace_retention_policy_id_by_project_id: TraceRetentionPolicyIdByProjectIdDataLoader
@@ -237,3 +243,10 @@ class Context(BaseContext):
237
243
  @cached_property
238
244
  def user(self) -> PhoenixUser:
239
245
  return cast(PhoenixUser, self.get_request().user)
246
+
247
+ @cached_property
248
+ def user_id(self) -> Optional[int]:
249
+ try:
250
+ return int(self.user.identity)
251
+ except Exception:
252
+ return None
@@ -12,9 +12,11 @@ from .average_experiment_repeated_run_group_latency import (
12
12
  from .average_experiment_run_latency import AverageExperimentRunLatencyDataLoader
13
13
  from .dataset_example_revisions import DatasetExampleRevisionsDataLoader
14
14
  from .dataset_example_spans import DatasetExampleSpansDataLoader
15
+ from .dataset_example_splits import DatasetExampleSplitsDataLoader
15
16
  from .dataset_examples_and_versions_by_experiment_run import (
16
17
  DatasetExamplesAndVersionsByExperimentRunDataLoader,
17
18
  )
19
+ from .dataset_labels import DatasetLabelsDataLoader
18
20
  from .document_evaluation_summaries import (
19
21
  DocumentEvaluationSummaryCache,
20
22
  DocumentEvaluationSummaryDataLoader,
@@ -27,7 +29,6 @@ from .experiment_repeated_run_group_annotation_summaries import (
27
29
  ExperimentRepeatedRunGroupAnnotationSummariesDataLoader,
28
30
  )
29
31
  from .experiment_repeated_run_groups import ExperimentRepeatedRunGroupsDataLoader
30
- from .experiment_repetition_counts import ExperimentRepetitionCountsDataLoader
31
32
  from .experiment_run_annotations import ExperimentRunAnnotations
32
33
  from .experiment_run_counts import ExperimentRunCountsDataLoader
33
34
  from .experiment_sequence_number import ExperimentSequenceNumberDataLoader
@@ -40,6 +41,7 @@ from .project_by_name import ProjectByNameDataLoader
40
41
  from .project_ids_by_trace_retention_policy_id import ProjectIdsByTraceRetentionPolicyIdDataLoader
41
42
  from .prompt_version_sequence_number import PromptVersionSequenceNumberDataLoader
42
43
  from .record_counts import RecordCountCache, RecordCountDataLoader
44
+ from .session_annotations_by_session import SessionAnnotationsBySessionDataLoader
43
45
  from .session_io import SessionIODataLoader
44
46
  from .session_num_traces import SessionNumTracesDataLoader
45
47
  from .session_num_traces_with_error import SessionNumTracesWithErrorDataLoader
@@ -69,6 +71,7 @@ from .span_descendants import SpanDescendantsDataLoader
69
71
  from .span_projects import SpanProjectsDataLoader
70
72
  from .table_fields import TableFieldsDataLoader
71
73
  from .token_counts import TokenCountCache, TokenCountDataLoader
74
+ from .trace_annotations_by_trace import TraceAnnotationsByTraceDataLoader
72
75
  from .trace_by_trace_ids import TraceByTraceIdsDataLoader
73
76
  from .trace_retention_policy_id_by_project_id import TraceRetentionPolicyIdByProjectIdDataLoader
74
77
  from .trace_root_spans import TraceRootSpansDataLoader
@@ -84,6 +87,8 @@ __all__ = [
84
87
  "DatasetExampleRevisionsDataLoader",
85
88
  "DatasetExampleSpansDataLoader",
86
89
  "DatasetExamplesAndVersionsByExperimentRunDataLoader",
90
+ "DatasetExampleSplitsDataLoader",
91
+ "DatasetLabelsDataLoader",
87
92
  "DocumentEvaluationSummaryDataLoader",
88
93
  "DocumentEvaluationsDataLoader",
89
94
  "DocumentRetrievalMetricsDataLoader",
@@ -91,7 +96,6 @@ __all__ = [
91
96
  "ExperimentErrorRatesDataLoader",
92
97
  "ExperimentRepeatedRunGroupsDataLoader",
93
98
  "ExperimentRepeatedRunGroupAnnotationSummariesDataLoader",
94
- "ExperimentRepetitionCountsDataLoader",
95
99
  "ExperimentRunAnnotations",
96
100
  "ExperimentRunCountsDataLoader",
97
101
  "ExperimentSequenceNumberDataLoader",
@@ -104,6 +108,7 @@ __all__ = [
104
108
  "ProjectIdsByTraceRetentionPolicyIdDataLoader",
105
109
  "PromptVersionSequenceNumberDataLoader",
106
110
  "RecordCountDataLoader",
111
+ "SessionAnnotationsBySessionDataLoader",
107
112
  "SessionIODataLoader",
108
113
  "SessionNumTracesDataLoader",
109
114
  "SessionNumTracesWithErrorDataLoader",
@@ -130,6 +135,7 @@ __all__ = [
130
135
  "SpanProjectsDataLoader",
131
136
  "TableFieldsDataLoader",
132
137
  "TokenCountDataLoader",
138
+ "TraceAnnotationsByTraceDataLoader",
133
139
  "TraceByTraceIdsDataLoader",
134
140
  "TraceRetentionPolicyIdByProjectIdDataLoader",
135
141
  "TraceRootSpansDataLoader",
from collections import defaultdict

from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

ExampleID: TypeAlias = int
Key: TypeAlias = ExampleID
Result: TypeAlias = list[models.DatasetSplit]


class DatasetExampleSplitsDataLoader(DataLoader[Key, Result]):
    """Batch-loads the dataset splits associated with each dataset example.

    Keys are dataset-example row IDs; each result is the list of
    ``models.DatasetSplit`` rows linked to that example through the
    ``DatasetSplitDatasetExample`` association table, sorted by split name.
    Examples with no splits map to an empty list.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(
            load_fn=self._load_fn,
        )
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # Group splits by example ID as the rows stream in. Using defaultdict
        # keeps this consistent with the sibling annotation dataloaders.
        splits_by_example: defaultdict[ExampleID, list[models.DatasetSplit]] = defaultdict(list)
        async with self._db() as session:
            async for example_id, split in await session.stream(
                select(models.DatasetSplitDatasetExample.dataset_example_id, models.DatasetSplit)
                .select_from(models.DatasetSplit)
                .join(
                    models.DatasetSplitDatasetExample,
                    onclause=(
                        models.DatasetSplit.id == models.DatasetSplitDatasetExample.dataset_split_id
                    ),
                )
                .where(models.DatasetSplitDatasetExample.dataset_example_id.in_(keys))
            ):
                splits_by_example[example_id].append(split)

        # Sort by name so the API response ordering is deterministic.
        return [
            sorted(splits_by_example[example_id], key=lambda split: split.name)
            for example_id in keys
        ]
from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

DatasetID: TypeAlias = int
Key: TypeAlias = DatasetID
Result: TypeAlias = list[models.DatasetLabel]


class DatasetLabelsDataLoader(DataLoader[Key, Result]):
    """Batch-loads the labels attached to each dataset.

    Keys are dataset row IDs; each result is the list of
    ``models.DatasetLabel`` rows linked to that dataset via the
    ``DatasetsDatasetLabel`` association table, sorted by label name.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = (
            select(models.DatasetsDatasetLabel.dataset_id, models.DatasetLabel)
            .select_from(models.DatasetLabel)
            .join(
                models.DatasetsDatasetLabel,
                models.DatasetLabel.id == models.DatasetsDatasetLabel.dataset_label_id,
            )
            .where(models.DatasetsDatasetLabel.dataset_id.in_(keys))
        )
        labels_by_dataset: dict[Key, Result] = {}
        async with self._db() as session:
            for dataset_id, label in await session.execute(stmt):
                labels_by_dataset.setdefault(dataset_id, []).append(label)
        # Deterministic ordering: labels come back sorted by name.
        return [
            sorted(labels_by_dataset.get(dataset_id, []), key=lambda label: label.name)
            for dataset_id in keys
        ]
from collections import defaultdict

from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db.models import ProjectSessionAnnotation
from phoenix.server.types import DbSessionFactory

ProjectSessionId: TypeAlias = int
Key: TypeAlias = ProjectSessionId
Result: TypeAlias = list[ProjectSessionAnnotation]


class SessionAnnotationsBySessionDataLoader(DataLoader[Key, Result]):
    """Batch-loads all annotations for each project session row ID."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = select(ProjectSessionAnnotation).where(
            ProjectSessionAnnotation.project_session_id.in_(keys)
        )
        grouped: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            async for record in await session.stream_scalars(stmt):
                grouped[record.project_session_id].append(record)
        # Preserve key order; sessions without annotations yield empty lists.
        return [grouped[key] for key in keys]
@@ -18,7 +18,7 @@ _AttrStrIdentifier: TypeAlias = str
18
18
 
19
19
 
20
20
  class TableFieldsDataLoader(DataLoader[Key, Result]):
21
- def __init__(self, db: DbSessionFactory, table: type[models.Base]) -> None:
21
+ def __init__(self, db: DbSessionFactory, table: type[models.HasId]) -> None:
22
22
  super().__init__(load_fn=self._load_fn)
23
23
  self._db = db
24
24
  self._table = table
@@ -37,7 +37,7 @@ class TableFieldsDataLoader(DataLoader[Key, Result]):
37
37
 
38
38
  def _get_stmt(
39
39
  keys: Iterable[tuple[RowId, QueryableAttribute[Any]]],
40
- table: type[models.Base],
40
+ table: type[models.HasId],
41
41
  ) -> tuple[
42
42
  Select[Any],
43
43
  dict[_ResultColumnPosition, _AttrStrIdentifier],
from collections import defaultdict

from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db.models import TraceAnnotation
from phoenix.server.types import DbSessionFactory

TraceRowId: TypeAlias = int
Key: TypeAlias = TraceRowId
Result: TypeAlias = list[TraceAnnotation]


class TraceAnnotationsByTraceDataLoader(DataLoader[Key, Result]):
    """Batch-loads every annotation attached to each trace row ID."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = select(TraceAnnotation).where(TraceAnnotation.trace_rowid.in_(keys))
        grouped: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            async for record in await session.stream_scalars(stmt):
                grouped[record.trace_rowid].append(record)
        # Preserve key order; traces without annotations yield empty lists.
        return [grouped[key] for key in keys]
@@ -1677,6 +1677,7 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
1677
1677
  @register_llm_client(
1678
1678
  provider_key=GenerativeProviderKey.ANTHROPIC,
1679
1679
  model_names=[
1680
+ "claude-sonnet-4-5",
1680
1681
  "claude-sonnet-4-0",
1681
1682
  "claude-sonnet-4-20250514",
1682
1683
  "claude-opus-4-1",
@@ -1705,11 +1706,6 @@ class AnthropicReasoningStreamingClient(AnthropicStreamingClient):
1705
1706
  provider_key=GenerativeProviderKey.GOOGLE,
1706
1707
  model_names=[
1707
1708
  PROVIDER_DEFAULT,
1708
- "gemini-2.5-flash",
1709
- "gemini-2.5-flash-lite",
1710
- "gemini-2.5-pro",
1711
- "gemini-2.5-pro-preview-03-25",
1712
- "gemini-2.0-flash",
1713
1709
  "gemini-2.0-flash-lite",
1714
1710
  "gemini-2.0-flash-001",
1715
1711
  "gemini-2.0-flash-thinking-exp-01-21",
@@ -1725,7 +1721,7 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
1725
1721
  model: GenerativeModelInput,
1726
1722
  credentials: Optional[list[PlaygroundClientCredential]] = None,
1727
1723
  ) -> None:
1728
- import google.generativeai as google_genai
1724
+ import google.genai as google_genai
1729
1725
 
1730
1726
  super().__init__(model=model, credentials=credentials)
1731
1727
  self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.GOOGLE.value
@@ -1742,12 +1738,12 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
1742
1738
  if not api_key:
1743
1739
  raise BadRequest("An API key is required for Gemini models")
1744
1740
 
1745
- google_genai.configure(api_key=api_key)
1741
+ self.client = google_genai.Client(api_key=api_key)
1746
1742
  self.model_name = model.name
1747
1743
 
1748
1744
  @classmethod
1749
1745
  def dependencies(cls) -> list[Dependency]:
1750
- return [Dependency(name="google-generativeai", module_name="google.generativeai")]
1746
+ return [Dependency(name="google-genai", module_name="google.genai")]
1751
1747
 
1752
1748
  @classmethod
1753
1749
  def supported_invocation_parameters(cls) -> list[InvocationParameter]:
@@ -1802,28 +1798,19 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
1802
1798
  tools: list[JSONScalarType],
1803
1799
  **invocation_parameters: Any,
1804
1800
  ) -> AsyncIterator[ChatCompletionChunk]:
1805
- import google.generativeai as google_genai
1801
+ contents, system_prompt = self._build_google_messages(messages)
1806
1802
 
1807
- google_message_history, current_message, system_prompt = self._build_google_messages(
1808
- messages
1809
- )
1810
-
1811
- model_args = {"model_name": self.model_name}
1803
+ # Build config object for the new API
1804
+ config = invocation_parameters
1812
1805
  if system_prompt:
1813
- model_args["system_instruction"] = system_prompt
1814
- client = google_genai.GenerativeModel(**model_args)
1806
+ config["system_instruction"] = system_prompt
1815
1807
 
1816
- google_config = google_genai.GenerationConfig(
1817
- **invocation_parameters,
1808
+ # Use the client's async models.generate_content_stream method
1809
+ stream = await self.client.aio.models.generate_content_stream(
1810
+ model=f"models/{self.model_name}",
1811
+ contents=contents,
1812
+ config=config if config else None,
1818
1813
  )
1819
- google_params = {
1820
- "content": current_message,
1821
- "generation_config": google_config,
1822
- "stream": True,
1823
- }
1824
-
1825
- chat = client.start_chat(history=google_message_history)
1826
- stream = await chat.send_message_async(**google_params)
1827
1814
  async for event in stream:
1828
1815
  self._attributes.update(
1829
1816
  {
@@ -1837,26 +1824,70 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
1837
1824
  def _build_google_messages(
1838
1825
  self,
1839
1826
  messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
1840
- ) -> tuple[list["ContentType"], str, str]:
1841
- google_message_history: list["ContentType"] = []
1827
+ ) -> tuple[list["ContentType"], str]:
1828
+ """Build Google messages following the standard pattern - process ALL messages."""
1829
+ google_messages: list["ContentType"] = []
1842
1830
  system_prompts = []
1843
1831
  for role, content, _tool_call_id, _tool_calls in messages:
1844
1832
  if role == ChatCompletionMessageRole.USER:
1845
- google_message_history.append({"role": "user", "parts": content})
1833
+ google_messages.append({"role": "user", "parts": [{"text": content}]})
1846
1834
  elif role == ChatCompletionMessageRole.AI:
1847
- google_message_history.append({"role": "model", "parts": content})
1835
+ google_messages.append({"role": "model", "parts": [{"text": content}]})
1848
1836
  elif role == ChatCompletionMessageRole.SYSTEM:
1849
1837
  system_prompts.append(content)
1850
1838
  elif role == ChatCompletionMessageRole.TOOL:
1851
1839
  raise NotImplementedError
1852
1840
  else:
1853
1841
  assert_never(role)
1854
- if google_message_history:
1855
- prompt = google_message_history.pop()["parts"]
1856
- else:
1857
- prompt = ""
1858
1842
 
1859
- return google_message_history, prompt, "\n".join(system_prompts)
1843
+ return google_messages, "\n".join(system_prompts)
1844
+
1845
+
1846
+ @register_llm_client(
1847
+ provider_key=GenerativeProviderKey.GOOGLE,
1848
+ model_names=[
1849
+ PROVIDER_DEFAULT,
1850
+ "gemini-2.5-pro",
1851
+ "gemini-2.5-flash",
1852
+ "gemini-2.5-flash-lite",
1853
+ "gemini-2.5-pro-preview-03-25",
1854
+ ],
1855
+ )
1856
+ class Gemini25GoogleStreamingClient(GoogleStreamingClient):
1857
+ @classmethod
1858
+ def supported_invocation_parameters(cls) -> list[InvocationParameter]:
1859
+ return [
1860
+ BoundedFloatInvocationParameter(
1861
+ invocation_name="temperature",
1862
+ canonical_name=CanonicalParameterName.TEMPERATURE,
1863
+ label="Temperature",
1864
+ default_value=1.0,
1865
+ min_value=0.0,
1866
+ max_value=2.0,
1867
+ ),
1868
+ IntInvocationParameter(
1869
+ invocation_name="max_output_tokens",
1870
+ canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
1871
+ label="Max Output Tokens",
1872
+ ),
1873
+ StringListInvocationParameter(
1874
+ invocation_name="stop_sequences",
1875
+ canonical_name=CanonicalParameterName.STOP_SEQUENCES,
1876
+ label="Stop Sequences",
1877
+ ),
1878
+ BoundedFloatInvocationParameter(
1879
+ invocation_name="top_p",
1880
+ canonical_name=CanonicalParameterName.TOP_P,
1881
+ label="Top P",
1882
+ default_value=1.0,
1883
+ min_value=0.0,
1884
+ max_value=1.0,
1885
+ ),
1886
+ FloatInvocationParameter(
1887
+ invocation_name="top_k",
1888
+ label="Top K",
1889
+ ),
1890
+ ]
1860
1891
 
1861
1892
 
1862
1893
  def initialize_playground_clients() -> None:
from typing import (
    Optional,
)

from starlette.requests import Request
from strawberry import Info

from phoenix.server.api.context import Context
from phoenix.server.bearer_auth import PhoenixUser


def get_user(info: Info[Context, None]) -> Optional[int]:
    """Best-effort lookup of the authenticated user's integer row ID.

    Returns ``None`` whenever a user cannot be determined — e.g. auth is
    disabled, no HTTP request is attached to the context (such as in
    subscriptions), or the identity is not an integer. Never raises.
    """
    try:
        request = info.context.request
        if isinstance(request, Request):
            # HTTP path: only trust the user if the auth middleware
            # populated the request scope with one.
            if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
                return int(user.identity)
        else:
            # Request is not available; try to obtain the user identity
            # directly from the context (this raises if auth is not
            # configured, which we treat as "no user").
            if info.context.user.is_authenticated:
                return int(info.context.user.identity)
    except Exception:
        # Deliberate best-effort contract: any failure means "no user".
        # (The previous implementation achieved this via `return` inside a
        # `finally` block, which silently suppressed every exception and
        # relied on `assert` — stripped under `python -O` — for control flow.)
        return None
    return None
@@ -1,8 +1,9 @@
1
- from typing import Optional
1
+ from typing import Optional, Union
2
2
 
3
3
  import strawberry
4
4
  from strawberry import UNSET
5
5
  from strawberry.relay import GlobalID
6
+ from typing_extensions import TypeAlias
6
7
 
7
8
  from phoenix.db import models
8
9
  from phoenix.server.api.exceptions import BadRequest
@@ -11,7 +12,7 @@ from phoenix.server.api.types.node import from_global_id_with_expected_type
11
12
 
12
13
 
13
14
  @strawberry.input
14
- class SpanAnnotationFilterCondition:
15
+ class AnnotationFilterCondition:
15
16
  names: Optional[list[str]] = UNSET
16
17
  sources: Optional[list[AnnotationSource]] = UNSET
17
18
  user_ids: Optional[list[Optional[GlobalID]]] = UNSET
@@ -26,42 +27,49 @@ class SpanAnnotationFilterCondition:
26
27
 
27
28
 
28
29
  @strawberry.input
29
- class SpanAnnotationFilter:
30
- include: Optional[SpanAnnotationFilterCondition] = UNSET
31
- exclude: Optional[SpanAnnotationFilterCondition] = UNSET
30
+ class AnnotationFilter:
31
+ include: Optional[AnnotationFilterCondition] = UNSET
32
+ exclude: Optional[AnnotationFilterCondition] = UNSET
32
33
 
33
34
  def __post_init__(self) -> None:
34
35
  if self.include is UNSET and self.exclude is UNSET:
35
36
  raise BadRequest("include and exclude cannot both be unset")
36
37
 
37
38
 
38
- def satisfies_filter(span_annotation: models.SpanAnnotation, filter: SpanAnnotationFilter) -> bool:
39
+ _Annotation: TypeAlias = Union[
40
+ models.SpanAnnotation,
41
+ models.TraceAnnotation,
42
+ models.ProjectSessionAnnotation,
43
+ ]
44
+
45
+
46
+ def satisfies_filter(annotation: _Annotation, filter: AnnotationFilter) -> bool:
39
47
  """
40
- Returns true if the span annotation satisfies the filter and false otherwise.
48
+ Returns true if the annotation satisfies the filter and false otherwise.
41
49
  """
42
- span_annotation_source = AnnotationSource(span_annotation.source)
50
+ annotation_source = AnnotationSource(annotation.source)
43
51
  if include := filter.include:
44
- if include.names and span_annotation.name not in include.names:
52
+ if include.names and annotation.name not in include.names:
45
53
  return False
46
- if include.sources and span_annotation_source not in include.sources:
54
+ if include.sources and annotation_source not in include.sources:
47
55
  return False
48
56
  if include.user_ids:
49
57
  user_rowids = [
50
58
  from_global_id_with_expected_type(user_id, "User") if user_id is not None else None
51
59
  for user_id in include.user_ids
52
60
  ]
53
- if span_annotation.user_id not in user_rowids:
61
+ if annotation.user_id not in user_rowids:
54
62
  return False
55
63
  if exclude := filter.exclude:
56
- if exclude.names and span_annotation.name in exclude.names:
64
+ if exclude.names and annotation.name in exclude.names:
57
65
  return False
58
- if exclude.sources and span_annotation_source in exclude.sources:
66
+ if exclude.sources and annotation_source in exclude.sources:
59
67
  return False
60
68
  if exclude.user_ids:
61
69
  user_rowids = [
62
70
  from_global_id_with_expected_type(user_id, "User") if user_id is not None else None
63
71
  for user_id in exclude.user_ids
64
72
  ]
65
- if span_annotation.user_id in user_rowids:
73
+ if annotation.user_id in user_rowids:
66
74
  return False
67
75
  return True
from typing import Optional

import strawberry
from strawberry.relay import GlobalID
from strawberry.scalars import JSON

from phoenix.server.api.exceptions import BadRequest
from phoenix.server.api.types.AnnotationSource import AnnotationSource
from phoenix.server.api.types.AnnotatorKind import AnnotatorKind


@strawberry.input
class CreateProjectSessionAnnotationInput:
    """GraphQL input for creating an annotation on a project session."""

    project_session_id: GlobalID
    name: str
    annotator_kind: AnnotatorKind = AnnotatorKind.HUMAN
    label: Optional[str] = None
    score: Optional[float] = None
    explanation: Optional[str] = None
    metadata: JSON = strawberry.field(default_factory=dict)
    source: AnnotationSource = AnnotationSource.APP
    identifier: Optional[str] = strawberry.UNSET

    def __post_init__(self) -> None:
        # Normalize whitespace; blank label/explanation collapse to None,
        # while identifier is only stripped (an empty identifier is kept).
        self.name = self.name.strip()
        if isinstance(self.label, str):
            self.label = self.label.strip() or None
        if isinstance(self.explanation, str):
            self.explanation = self.explanation.strip() or None
        if isinstance(self.identifier, str):
            self.identifier = self.identifier.strip()
        # An annotation with no score, label, or explanation carries no data.
        if self.score is None and not self.label and not self.explanation:
            raise BadRequest("At least one of score, label, or explanation must be not null/empty.")
from typing import Optional

import strawberry
from strawberry.relay import GlobalID
from strawberry.scalars import JSON

from phoenix.server.api.exceptions import BadRequest
from phoenix.server.api.types.AnnotationSource import AnnotationSource
from phoenix.server.api.types.AnnotatorKind import AnnotatorKind


@strawberry.input
class UpdateAnnotationInput:
    """GraphQL input for updating an existing annotation by global ID."""

    id: GlobalID
    name: str
    annotator_kind: AnnotatorKind = AnnotatorKind.HUMAN
    label: Optional[str] = None
    score: Optional[float] = None
    explanation: Optional[str] = None
    metadata: JSON = strawberry.field(default_factory=dict)
    source: AnnotationSource = AnnotationSource.APP

    def __post_init__(self) -> None:
        # Normalize whitespace; blank label/explanation collapse to None.
        self.name = self.name.strip()
        if isinstance(self.label, str):
            self.label = self.label.strip() or None
        if isinstance(self.explanation, str):
            self.explanation = self.explanation.strip() or None
        # An annotation with no score, label, or explanation carries no data.
        if self.score is None and not self.label and not self.explanation:
            raise BadRequest("At least one of score, label, or explanation must be not null/empty.")