arize-phoenix 11.38.0__py3-none-any.whl → 12.0.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/METADATA +3 -3
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/RECORD +71 -50
- phoenix/config.py +1 -11
- phoenix/db/bulk_inserter.py +8 -0
- phoenix/db/facilitator.py +1 -1
- phoenix/db/helpers.py +202 -33
- phoenix/db/insertion/dataset.py +7 -0
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/types.py +30 -0
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +285 -46
- phoenix/server/api/context.py +13 -2
- phoenix/server/api/dataloaders/__init__.py +6 -2
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/helpers/playground_clients.py +65 -35
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/chat_mutations.py +8 -3
- phoenix/server/api/mutations/dataset_mutations.py +5 -0
- phoenix/server/api/mutations/dataset_split_mutations.py +387 -0
- phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
- phoenix/server/api/queries.py +32 -0
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/annotations.py +320 -0
- phoenix/server/api/routers/v1/datasets.py +5 -0
- phoenix/server/api/routers/v1/experiments.py +10 -3
- phoenix/server/api/routers/v1/sessions.py +111 -0
- phoenix/server/api/routers/v1/traces.py +1 -2
- phoenix/server/api/routers/v1/users.py +7 -0
- phoenix/server/api/subscriptions.py +5 -2
- phoenix/server/api/types/DatasetExample.py +11 -0
- phoenix/server/api/types/DatasetSplit.py +32 -0
- phoenix/server/api/types/Experiment.py +0 -4
- phoenix/server/api/types/Project.py +16 -0
- phoenix/server/api/types/ProjectSession.py +88 -3
- phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
- phoenix/server/api/types/Span.py +5 -5
- phoenix/server/api/types/Trace.py +61 -0
- phoenix/server/app.py +6 -2
- phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
- phoenix/server/dml_event.py +13 -0
- phoenix/server/static/.vite/manifest.json +39 -39
- phoenix/server/static/assets/{components-BQPHTBfv.js → components-Dl9SUw1U.js} +371 -327
- phoenix/server/static/assets/{index-BL5BMgJU.js → index-CqQS0dTo.js} +2 -2
- phoenix/server/static/assets/{pages-C0Y17J0T.js → pages-DKSjVA_E.js} +762 -514
- phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-CtbHQYl8.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-D-lWOwIS.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-BRBpy3_z.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts--KdSwB3m.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-CvRzZnIo.js} +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/context.py
CHANGED
@@ -23,6 +23,7 @@ from phoenix.server.api.dataloaders import (
     DatasetExampleRevisionsDataLoader,
     DatasetExamplesAndVersionsByExperimentRunDataLoader,
     DatasetExampleSpansDataLoader,
+    DatasetExampleSplitsDataLoader,
     DocumentEvaluationsDataLoader,
     DocumentEvaluationSummaryDataLoader,
     DocumentRetrievalMetricsDataLoader,
@@ -30,7 +31,6 @@ from phoenix.server.api.dataloaders import (
     ExperimentErrorRatesDataLoader,
     ExperimentRepeatedRunGroupAnnotationSummariesDataLoader,
     ExperimentRepeatedRunGroupsDataLoader,
-    ExperimentRepetitionCountsDataLoader,
     ExperimentRunAnnotations,
     ExperimentRunCountsDataLoader,
     ExperimentSequenceNumberDataLoader,
@@ -43,6 +43,7 @@ from phoenix.server.api.dataloaders import (
     ProjectIdsByTraceRetentionPolicyIdDataLoader,
     PromptVersionSequenceNumberDataLoader,
     RecordCountDataLoader,
+    SessionAnnotationsBySessionDataLoader,
     SessionIODataLoader,
     SessionNumTracesDataLoader,
     SessionNumTracesWithErrorDataLoader,
@@ -68,6 +69,7 @@ from phoenix.server.api.dataloaders import (
     SpanProjectsDataLoader,
     TableFieldsDataLoader,
     TokenCountDataLoader,
+    TraceAnnotationsByTraceDataLoader,
     TraceByTraceIdsDataLoader,
     TraceRetentionPolicyIdByProjectIdDataLoader,
     TraceRootSpansDataLoader,
@@ -100,6 +102,7 @@ class DataLoaders:
     dataset_examples_and_versions_by_experiment_run: (
         DatasetExamplesAndVersionsByExperimentRunDataLoader
     )
+    dataset_example_splits: DatasetExampleSplitsDataLoader
     document_evaluation_summaries: DocumentEvaluationSummaryDataLoader
     document_evaluations: DocumentEvaluationsDataLoader
     document_retrieval_metrics: DocumentRetrievalMetricsDataLoader
@@ -109,7 +112,6 @@ class DataLoaders:
         ExperimentRepeatedRunGroupAnnotationSummariesDataLoader
     )
     experiment_repeated_run_groups: ExperimentRepeatedRunGroupsDataLoader
-    experiment_repetition_counts: ExperimentRepetitionCountsDataLoader
     experiment_run_annotations: ExperimentRunAnnotations
     experiment_run_counts: ExperimentRunCountsDataLoader
     experiment_sequence_number: ExperimentSequenceNumberDataLoader
@@ -124,6 +126,7 @@ class DataLoaders:
     projects_by_trace_retention_policy_id: ProjectIdsByTraceRetentionPolicyIdDataLoader
     prompt_version_sequence_number: PromptVersionSequenceNumberDataLoader
     record_counts: RecordCountDataLoader
+    session_annotations_by_session: SessionAnnotationsBySessionDataLoader
     session_first_inputs: SessionIODataLoader
     session_last_outputs: SessionIODataLoader
     session_num_traces: SessionNumTracesDataLoader
@@ -158,6 +161,7 @@ class DataLoaders:
     span_fields: TableFieldsDataLoader
     span_projects: SpanProjectsDataLoader
     token_counts: TokenCountDataLoader
+    trace_annotations_by_trace: TraceAnnotationsByTraceDataLoader
     trace_by_trace_ids: TraceByTraceIdsDataLoader
     trace_fields: TableFieldsDataLoader
     trace_retention_policy_id_by_project_id: TraceRetentionPolicyIdByProjectIdDataLoader
@@ -237,3 +241,10 @@ class Context(BaseContext):
     @cached_property
     def user(self) -> PhoenixUser:
         return cast(PhoenixUser, self.get_request().user)
+
+    @cached_property
+    def user_id(self) -> Optional[int]:
+        try:
+            return int(self.user.identity)
+        except Exception:
+            return None
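The new Context.user_id property gives resolvers a best-effort integer user id, or None when the request is unauthenticated. A minimal, hypothetical resolver sketch (the query type and field name below are illustrative and not part of this diff):

from typing import Optional

import strawberry


@strawberry.type
class Query:
    @strawberry.field
    def viewer_id(self, info: strawberry.Info) -> Optional[int]:
        # Context.user_id wraps int(self.user.identity) in a try/except,
        # so unauthenticated requests simply resolve to None.
        return info.context.user_id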
phoenix/server/api/dataloaders/__init__.py
CHANGED
@@ -12,6 +12,7 @@ from .average_experiment_repeated_run_group_latency import (
 from .average_experiment_run_latency import AverageExperimentRunLatencyDataLoader
 from .dataset_example_revisions import DatasetExampleRevisionsDataLoader
 from .dataset_example_spans import DatasetExampleSpansDataLoader
+from .dataset_example_splits import DatasetExampleSplitsDataLoader
 from .dataset_examples_and_versions_by_experiment_run import (
     DatasetExamplesAndVersionsByExperimentRunDataLoader,
 )
@@ -27,7 +28,6 @@ from .experiment_repeated_run_group_annotation_summaries import (
     ExperimentRepeatedRunGroupAnnotationSummariesDataLoader,
 )
 from .experiment_repeated_run_groups import ExperimentRepeatedRunGroupsDataLoader
-from .experiment_repetition_counts import ExperimentRepetitionCountsDataLoader
 from .experiment_run_annotations import ExperimentRunAnnotations
 from .experiment_run_counts import ExperimentRunCountsDataLoader
 from .experiment_sequence_number import ExperimentSequenceNumberDataLoader
@@ -40,6 +40,7 @@ from .project_by_name import ProjectByNameDataLoader
 from .project_ids_by_trace_retention_policy_id import ProjectIdsByTraceRetentionPolicyIdDataLoader
 from .prompt_version_sequence_number import PromptVersionSequenceNumberDataLoader
 from .record_counts import RecordCountCache, RecordCountDataLoader
+from .session_annotations_by_session import SessionAnnotationsBySessionDataLoader
 from .session_io import SessionIODataLoader
 from .session_num_traces import SessionNumTracesDataLoader
 from .session_num_traces_with_error import SessionNumTracesWithErrorDataLoader
@@ -69,6 +70,7 @@ from .span_descendants import SpanDescendantsDataLoader
 from .span_projects import SpanProjectsDataLoader
 from .table_fields import TableFieldsDataLoader
 from .token_counts import TokenCountCache, TokenCountDataLoader
+from .trace_annotations_by_trace import TraceAnnotationsByTraceDataLoader
 from .trace_by_trace_ids import TraceByTraceIdsDataLoader
 from .trace_retention_policy_id_by_project_id import TraceRetentionPolicyIdByProjectIdDataLoader
 from .trace_root_spans import TraceRootSpansDataLoader
@@ -84,6 +86,7 @@ __all__ = [
     "DatasetExampleRevisionsDataLoader",
     "DatasetExampleSpansDataLoader",
     "DatasetExamplesAndVersionsByExperimentRunDataLoader",
+    "DatasetExampleSplitsDataLoader",
     "DocumentEvaluationSummaryDataLoader",
     "DocumentEvaluationsDataLoader",
     "DocumentRetrievalMetricsDataLoader",
@@ -91,7 +94,6 @@ __all__ = [
     "ExperimentErrorRatesDataLoader",
     "ExperimentRepeatedRunGroupsDataLoader",
     "ExperimentRepeatedRunGroupAnnotationSummariesDataLoader",
-    "ExperimentRepetitionCountsDataLoader",
     "ExperimentRunAnnotations",
     "ExperimentRunCountsDataLoader",
     "ExperimentSequenceNumberDataLoader",
@@ -104,6 +106,7 @@ __all__ = [
     "ProjectIdsByTraceRetentionPolicyIdDataLoader",
     "PromptVersionSequenceNumberDataLoader",
     "RecordCountDataLoader",
+    "SessionAnnotationsBySessionDataLoader",
     "SessionIODataLoader",
     "SessionNumTracesDataLoader",
     "SessionNumTracesWithErrorDataLoader",
@@ -130,6 +133,7 @@ __all__ = [
     "SpanProjectsDataLoader",
     "TableFieldsDataLoader",
     "TokenCountDataLoader",
+    "TraceAnnotationsByTraceDataLoader",
     "TraceByTraceIdsDataLoader",
     "TraceRetentionPolicyIdByProjectIdDataLoader",
     "TraceRootSpansDataLoader",
phoenix/server/api/dataloaders/dataset_example_splits.py
ADDED
@@ -0,0 +1,40 @@
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExampleID: TypeAlias = int
+Key: TypeAlias = ExampleID
+Result: TypeAlias = list[models.DatasetSplit]
+
+
+class DatasetExampleSplitsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(
+            load_fn=self._load_fn,
+        )
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        example_ids = keys
+        async with self._db() as session:
+            splits: dict[ExampleID, list[models.DatasetSplit]] = {}
+
+            async for example_id, split in await session.stream(
+                select(models.DatasetSplitDatasetExample.dataset_example_id, models.DatasetSplit)
+                .select_from(models.DatasetSplit)
+                .join(
+                    models.DatasetSplitDatasetExample,
+                    onclause=(
+                        models.DatasetSplit.id == models.DatasetSplitDatasetExample.dataset_split_id
+                    ),
+                )
+                .where(models.DatasetSplitDatasetExample.dataset_example_id.in_(example_ids))
+            ):
+                if example_id not in splits:
+                    splits[example_id] = []
+                splits[example_id].append(split)
+
+        return [sorted(splits.get(example_id, []), key=lambda x: x.name) for example_id in keys]
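A minimal usage sketch, assuming the DataLoaders instance is exposed on the GraphQL context as info.context.data_loaders (consistent with the dataset_example_splits attribute added above); example_rowid is a hypothetical variable holding a dataset example's integer rowid:

async def resolve_splits(info, example_rowid: int) -> list[str]:
    # All .load() calls issued in the same batch window are coalesced by the
    # DataLoader into a single SQL query; each result is sorted by split name.
    splits = await info.context.data_loaders.dataset_example_splits.load(example_rowid)
    return [split.name for split in splits]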
phoenix/server/api/dataloaders/session_annotations_by_session.py
ADDED
@@ -0,0 +1,29 @@
+from collections import defaultdict
+
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db.models import ProjectSessionAnnotation
+from phoenix.server.types import DbSessionFactory
+
+ProjectSessionId: TypeAlias = int
+Key: TypeAlias = ProjectSessionId
+Result: TypeAlias = list[ProjectSessionAnnotation]
+
+
+class SessionAnnotationsBySessionDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        annotations_by_id: defaultdict[Key, Result] = defaultdict(list)
+        async with self._db() as session:
+            async for annotation in await session.stream_scalars(
+                select(ProjectSessionAnnotation).where(
+                    ProjectSessionAnnotation.project_session_id.in_(keys)
+                )
+            ):
+                annotations_by_id[annotation.project_session_id].append(annotation)
+        return [annotations_by_id[key] for key in keys]
phoenix/server/api/dataloaders/table_fields.py
CHANGED
@@ -18,7 +18,7 @@ _AttrStrIdentifier: TypeAlias = str


 class TableFieldsDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: DbSessionFactory, table: type[models.
+    def __init__(self, db: DbSessionFactory, table: type[models.HasId]) -> None:
         super().__init__(load_fn=self._load_fn)
         self._db = db
         self._table = table
@@ -37,7 +37,7 @@ class TableFieldsDataLoader(DataLoader[Key, Result]):

     def _get_stmt(
         keys: Iterable[tuple[RowId, QueryableAttribute[Any]]],
-        table: type[models.
+        table: type[models.HasId],
     ) -> tuple[
         Select[Any],
         dict[_ResultColumnPosition, _AttrStrIdentifier],
phoenix/server/api/dataloaders/trace_annotations_by_trace.py
ADDED
@@ -0,0 +1,27 @@
+from collections import defaultdict
+
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db.models import TraceAnnotation
+from phoenix.server.types import DbSessionFactory
+
+TraceRowId: TypeAlias = int
+Key: TypeAlias = TraceRowId
+Result: TypeAlias = list[TraceAnnotation]
+
+
+class TraceAnnotationsByTraceDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        annotations_by_id: defaultdict[Key, Result] = defaultdict(list)
+        async with self._db() as session:
+            async for annotation in await session.stream_scalars(
+                select(TraceAnnotation).where(TraceAnnotation.trace_rowid.in_(keys))
+            ):
+                annotations_by_id[annotation.trace_rowid].append(annotation)
+        return [annotations_by_id[key] for key in keys]
phoenix/server/api/helpers/playground_clients.py
CHANGED
@@ -1705,11 +1705,6 @@ class AnthropicReasoningStreamingClient(AnthropicStreamingClient):
     provider_key=GenerativeProviderKey.GOOGLE,
     model_names=[
         PROVIDER_DEFAULT,
-        "gemini-2.5-flash",
-        "gemini-2.5-flash-lite",
-        "gemini-2.5-pro",
-        "gemini-2.5-pro-preview-03-25",
-        "gemini-2.0-flash",
         "gemini-2.0-flash-lite",
         "gemini-2.0-flash-001",
         "gemini-2.0-flash-thinking-exp-01-21",
@@ -1725,7 +1720,7 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
         model: GenerativeModelInput,
         credentials: Optional[list[PlaygroundClientCredential]] = None,
     ) -> None:
-        import google.
+        import google.genai as google_genai

         super().__init__(model=model, credentials=credentials)
         self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.GOOGLE.value
@@ -1742,12 +1737,12 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
         if not api_key:
             raise BadRequest("An API key is required for Gemini models")

-        google_genai.
+        self.client = google_genai.Client(api_key=api_key)
         self.model_name = model.name

     @classmethod
     def dependencies(cls) -> list[Dependency]:
-        return [Dependency(name="google-
+        return [Dependency(name="google-genai", module_name="google.genai")]

     @classmethod
     def supported_invocation_parameters(cls) -> list[InvocationParameter]:
@@ -1802,28 +1797,19 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
         tools: list[JSONScalarType],
         **invocation_parameters: Any,
     ) -> AsyncIterator[ChatCompletionChunk]:
-
+        contents, system_prompt = self._build_google_messages(messages)

-
-
-        )
-
-        model_args = {"model_name": self.model_name}
+        # Build config object for the new API
+        config = invocation_parameters
         if system_prompt:
-
-            client = google_genai.GenerativeModel(**model_args)
+            config["system_instruction"] = system_prompt

-
-
+        # Use the client's async models.generate_content_stream method
+        stream = await self.client.aio.models.generate_content_stream(
+            model=f"models/{self.model_name}",
+            contents=contents,
+            config=config if config else None,
         )
-        google_params = {
-            "content": current_message,
-            "generation_config": google_config,
-            "stream": True,
-        }
-
-        chat = client.start_chat(history=google_message_history)
-        stream = await chat.send_message_async(**google_params)
         async for event in stream:
             self._attributes.update(
                 {
@@ -1837,26 +1823,70 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
     def _build_google_messages(
         self,
         messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
-    ) -> tuple[list["ContentType"], str
-
+    ) -> tuple[list["ContentType"], str]:
+        """Build Google messages following the standard pattern - process ALL messages."""
+        google_messages: list["ContentType"] = []
         system_prompts = []
         for role, content, _tool_call_id, _tool_calls in messages:
             if role == ChatCompletionMessageRole.USER:
-
+                google_messages.append({"role": "user", "parts": [{"text": content}]})
             elif role == ChatCompletionMessageRole.AI:
-
+                google_messages.append({"role": "model", "parts": [{"text": content}]})
             elif role == ChatCompletionMessageRole.SYSTEM:
                 system_prompts.append(content)
             elif role == ChatCompletionMessageRole.TOOL:
                 raise NotImplementedError
             else:
                 assert_never(role)
-        if google_message_history:
-            prompt = google_message_history.pop()["parts"]
-        else:
-            prompt = ""

-        return
+        return google_messages, "\n".join(system_prompts)
+
+
+@register_llm_client(
+    provider_key=GenerativeProviderKey.GOOGLE,
+    model_names=[
+        PROVIDER_DEFAULT,
+        "gemini-2.5-pro",
+        "gemini-2.5-flash",
+        "gemini-2.5-flash-lite",
+        "gemini-2.5-pro-preview-03-25",
+    ],
+)
+class Gemini25GoogleStreamingClient(GoogleStreamingClient):
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            BoundedFloatInvocationParameter(
+                invocation_name="temperature",
+                canonical_name=CanonicalParameterName.TEMPERATURE,
+                label="Temperature",
+                default_value=1.0,
+                min_value=0.0,
+                max_value=2.0,
+            ),
+            IntInvocationParameter(
+                invocation_name="max_output_tokens",
+                canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                label="Max Output Tokens",
+            ),
+            StringListInvocationParameter(
+                invocation_name="stop_sequences",
+                canonical_name=CanonicalParameterName.STOP_SEQUENCES,
+                label="Stop Sequences",
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="top_p",
+                canonical_name=CanonicalParameterName.TOP_P,
+                label="Top P",
+                default_value=1.0,
+                min_value=0.0,
+                max_value=1.0,
+            ),
+            FloatInvocationParameter(
+                invocation_name="top_k",
+                label="Top K",
+            ),
+        ]


 def initialize_playground_clients() -> None:
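For reference, a standalone sketch of the streaming pattern the rewritten client relies on, mirroring the calls shown in the diff; the API key, model name, prompt, and config values below are placeholders, and passing config as a plain dict is just one way the google-genai SDK accepts generation settings:

import asyncio

import google.genai as google_genai


async def main() -> None:
    # google_genai.Client replaces the old module-level configure() pattern.
    client = google_genai.Client(api_key="YOUR_API_KEY")
    # The async surface lives under client.aio; contents use role/parts dicts,
    # matching what _build_google_messages produces above.
    stream = await client.aio.models.generate_content_stream(
        model="models/gemini-2.5-flash",
        contents=[{"role": "user", "parts": [{"text": "Hello"}]}],
        config={"system_instruction": "Be terse.", "temperature": 1.0},
    )
    async for event in stream:
        print(event.text or "", end="", flush=True)


asyncio.run(main())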
phoenix/server/api/helpers/playground_users.py
ADDED
@@ -0,0 +1,26 @@
+from typing import (
+    Optional,
+)
+
+from starlette.requests import Request
+from strawberry import Info
+
+from phoenix.server.api.context import Context
+from phoenix.server.bearer_auth import PhoenixUser
+
+
+def get_user(info: Info[Context, None]) -> Optional[int]:
+    user_id: Optional[int] = None
+    try:
+        assert isinstance(request := info.context.request, Request)
+
+        if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
+            user_id = int(user.identity)
+    except AssertionError:
+        # Request is not available, try to obtain user identify
+        # this will also throw an assertion error if auth is not available
+        # the finally block will continue execution returning None
+        if info.context.user.is_authenticated:
+            user_id = int(info.context.user.identity)
+    finally:
+        return user_id
phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py}
RENAMED
@@ -1,8 +1,9 @@
-from typing import Optional
+from typing import Optional, Union

 import strawberry
 from strawberry import UNSET
 from strawberry.relay import GlobalID
+from typing_extensions import TypeAlias

 from phoenix.db import models
 from phoenix.server.api.exceptions import BadRequest
@@ -11,7 +12,7 @@ from phoenix.server.api.types.node import from_global_id_with_expected_type


 @strawberry.input
-class
+class AnnotationFilterCondition:
     names: Optional[list[str]] = UNSET
     sources: Optional[list[AnnotationSource]] = UNSET
     user_ids: Optional[list[Optional[GlobalID]]] = UNSET
@@ -26,42 +27,49 @@ class SpanAnnotationFilterCondition:


 @strawberry.input
-class
-    include: Optional[
-    exclude: Optional[
+class AnnotationFilter:
+    include: Optional[AnnotationFilterCondition] = UNSET
+    exclude: Optional[AnnotationFilterCondition] = UNSET

     def __post_init__(self) -> None:
         if self.include is UNSET and self.exclude is UNSET:
             raise BadRequest("include and exclude cannot both be unset")


-
+_Annotation: TypeAlias = Union[
+    models.SpanAnnotation,
+    models.TraceAnnotation,
+    models.ProjectSessionAnnotation,
+]
+
+
+def satisfies_filter(annotation: _Annotation, filter: AnnotationFilter) -> bool:
     """
-    Returns true if the
+    Returns true if the annotation satisfies the filter and false otherwise.
     """
-
+    annotation_source = AnnotationSource(annotation.source)
     if include := filter.include:
-        if include.names and
+        if include.names and annotation.name not in include.names:
             return False
-        if include.sources and
+        if include.sources and annotation_source not in include.sources:
             return False
         if include.user_ids:
             user_rowids = [
                 from_global_id_with_expected_type(user_id, "User") if user_id is not None else None
                 for user_id in include.user_ids
             ]
-            if
+            if annotation.user_id not in user_rowids:
                 return False
     if exclude := filter.exclude:
-        if exclude.names and
+        if exclude.names and annotation.name in exclude.names:
             return False
-        if exclude.sources and
+        if exclude.sources and annotation_source in exclude.sources:
             return False
         if exclude.user_ids:
             user_rowids = [
                 from_global_id_with_expected_type(user_id, "User") if user_id is not None else None
                 for user_id in exclude.user_ids
             ]
-            if
+            if annotation.user_id in user_rowids:
                 return False
     return True
phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py
ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import strawberry
+from strawberry.relay import GlobalID
+from strawberry.scalars import JSON
+
+from phoenix.server.api.exceptions import BadRequest
+from phoenix.server.api.types.AnnotationSource import AnnotationSource
+from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
+
+
+@strawberry.input
+class CreateProjectSessionAnnotationInput:
+    project_session_id: GlobalID
+    name: str
+    annotator_kind: AnnotatorKind = AnnotatorKind.HUMAN
+    label: Optional[str] = None
+    score: Optional[float] = None
+    explanation: Optional[str] = None
+    metadata: JSON = strawberry.field(default_factory=dict)
+    source: AnnotationSource = AnnotationSource.APP
+    identifier: Optional[str] = strawberry.UNSET
+
+    def __post_init__(self) -> None:
+        self.name = self.name.strip()
+        if isinstance(self.label, str):
+            self.label = self.label.strip()
+            if not self.label:
+                self.label = None
+        if isinstance(self.explanation, str):
+            self.explanation = self.explanation.strip()
+            if not self.explanation:
+                self.explanation = None
+        if isinstance(self.identifier, str):
+            self.identifier = self.identifier.strip()
+        if self.score is None and not self.label and not self.explanation:
+            raise BadRequest("At least one of score, label, or explanation must be not null/empty.")
phoenix/server/api/input_types/UpdateAnnotationInput.py
ADDED
@@ -0,0 +1,34 @@
+from typing import Optional
+
+import strawberry
+from strawberry.relay import GlobalID
+from strawberry.scalars import JSON
+
+from phoenix.server.api.exceptions import BadRequest
+from phoenix.server.api.types.AnnotationSource import AnnotationSource
+from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
+
+
+@strawberry.input
+class UpdateAnnotationInput:
+    id: GlobalID
+    name: str
+    annotator_kind: AnnotatorKind = AnnotatorKind.HUMAN
+    label: Optional[str] = None
+    score: Optional[float] = None
+    explanation: Optional[str] = None
+    metadata: JSON = strawberry.field(default_factory=dict)
+    source: AnnotationSource = AnnotationSource.APP
+
+    def __post_init__(self) -> None:
+        self.name = self.name.strip()
+        if isinstance(self.label, str):
+            self.label = self.label.strip()
+            if not self.label:
+                self.label = None
+        if isinstance(self.explanation, str):
+            self.explanation = self.explanation.strip()
+            if not self.explanation:
+                self.explanation = None
+        if self.score is None and not self.label and not self.explanation:
+            raise BadRequest("At least one of score, label, or explanation must be not null/empty.")
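Both input types normalize their fields in __post_init__. An illustrative sketch of that behavior — the values are made up, and the inputs are instantiated directly only to show the validation (in practice they arrive through the GraphQL mutations):

from strawberry.relay import GlobalID

from phoenix.server.api.exceptions import BadRequest
from phoenix.server.api.input_types.CreateProjectSessionAnnotationInput import (
    CreateProjectSessionAnnotationInput,
)

# Whitespace is stripped; a non-empty label keeps the annotation valid.
ok = CreateProjectSessionAnnotationInput(
    project_session_id=GlobalID("ProjectSession", "1"),
    name="  quality  ",  # becomes "quality"
    label="  good  ",    # becomes "good"
)

# A whitespace-only label collapses to None; with no score or explanation
# either, __post_init__ raises BadRequest.
try:
    CreateProjectSessionAnnotationInput(
        project_session_id=GlobalID("ProjectSession", "1"),
        name="quality",
        label="   ",
    )
except BadRequest:
    pass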
phoenix/server/api/mutations/__init__.py
CHANGED
@@ -6,10 +6,14 @@ from phoenix.server.api.mutations.chat_mutations import (
     ChatCompletionMutationMixin,
 )
 from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
+from phoenix.server.api.mutations.dataset_split_mutations import DatasetSplitMutationMixin
 from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
 from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
 from phoenix.server.api.mutations.model_mutations import ModelMutationMixin
 from phoenix.server.api.mutations.project_mutations import ProjectMutationMixin
+from phoenix.server.api.mutations.project_session_annotations_mutations import (
+    ProjectSessionAnnotationMutationMixin,
+)
 from phoenix.server.api.mutations.project_trace_retention_policy_mutations import (
     ProjectTraceRetentionPolicyMutationMixin,
 )
@@ -28,6 +32,7 @@ class Mutation(
     ApiKeyMutationMixin,
     ChatCompletionMutationMixin,
     DatasetMutationMixin,
+    DatasetSplitMutationMixin,
     ExperimentMutationMixin,
     ExportEventsMutationMixin,
     ModelMutationMixin,
@@ -37,6 +42,7 @@ class Mutation(
     PromptVersionTagMutationMixin,
     PromptLabelMutationMixin,
     SpanAnnotationMutationMixin,
+    ProjectSessionAnnotationMutationMixin,
     TraceAnnotationMutationMixin,
     TraceMutationMixin,
     UserMutationMixin,
phoenix/server/api/mutations/chat_mutations.py
CHANGED
@@ -26,7 +26,10 @@ from typing_extensions import assert_never
 from phoenix.config import PLAYGROUND_PROJECT_NAME
 from phoenix.datetime_utils import local_now, normalize_datetime
 from phoenix.db import models
-from phoenix.db.helpers import
+from phoenix.db.helpers import (
+    get_dataset_example_revisions,
+    insert_experiment_with_examples_snapshot,
+)
 from phoenix.server.api.auth import IsLocked, IsNotReadOnly
 from phoenix.server.api.context import Context
 from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
@@ -46,6 +49,7 @@ from phoenix.server.api.helpers.playground_spans import (
     llm_tools,
     prompt_metadata,
 )
+from phoenix.server.api.helpers.playground_users import get_user
 from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
 from phoenix.server.api.input_types.ChatCompletionInput import (
     ChatCompletionInput,
@@ -192,6 +196,7 @@ class ChatCompletionMutationMixin:
         ]
         if not revisions:
             raise NotFound("No examples found for the given dataset and version")
+        user_id = get_user(info)
         experiment = models.Experiment(
             dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
             dataset_version_id=resolved_version_id,
@@ -201,9 +206,9 @@ class ChatCompletionMutationMixin:
             repetitions=1,
             metadata_=input.experiment_metadata or dict(),
             project_name=project_name,
+            user_id=user_id,
         )
-        session
-        await session.flush()
+        await insert_experiment_with_examples_snapshot(session, experiment)

         results: list[Union[ChatCompletionMutationPayload, BaseException]] = []
         batch_size = 3