arize-phoenix 6.1.0__py3-none-any.whl → 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (54):
  1. {arize_phoenix-6.1.0.dist-info → arize_phoenix-7.0.0.dist-info}/METADATA +9 -8
  2. {arize_phoenix-6.1.0.dist-info → arize_phoenix-7.0.0.dist-info}/RECORD +52 -38
  3. phoenix/config.py +4 -1
  4. phoenix/db/engines.py +1 -1
  5. phoenix/db/insertion/span.py +65 -30
  6. phoenix/db/migrate.py +4 -1
  7. phoenix/db/migrations/data_migration_scripts/__init__.py +0 -0
  8. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  9. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  10. phoenix/db/models.py +27 -0
  11. phoenix/metrics/wrappers.py +7 -1
  12. phoenix/server/api/context.py +15 -2
  13. phoenix/server/api/dataloaders/__init__.py +14 -2
  14. phoenix/server/api/dataloaders/session_io.py +75 -0
  15. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  16. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  17. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  18. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  19. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  20. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  21. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  22. phoenix/server/api/mutations/chat_mutations.py +5 -0
  23. phoenix/server/api/mutations/project_mutations.py +12 -2
  24. phoenix/server/api/queries.py +14 -9
  25. phoenix/server/api/subscriptions.py +6 -0
  26. phoenix/server/api/types/EmbeddingDimension.py +1 -1
  27. phoenix/server/api/types/ExperimentRun.py +3 -4
  28. phoenix/server/api/types/ExperimentRunAnnotation.py +3 -4
  29. phoenix/server/api/types/Project.py +150 -12
  30. phoenix/server/api/types/ProjectSession.py +139 -0
  31. phoenix/server/api/types/Span.py +6 -19
  32. phoenix/server/api/types/SpanIOValue.py +15 -0
  33. phoenix/server/api/types/TokenUsage.py +11 -0
  34. phoenix/server/api/types/Trace.py +59 -2
  35. phoenix/server/app.py +15 -2
  36. phoenix/server/static/.vite/manifest.json +40 -31
  37. phoenix/server/static/assets/{components-CdiZ1Osh.js → components-DKH6AzJw.js} +410 -351
  38. phoenix/server/static/assets/index-DLV87qiO.js +93 -0
  39. phoenix/server/static/assets/{pages-FArMEfgg.js → pages-CVY3Nv4Z.js} +638 -316
  40. phoenix/server/static/assets/vendor-Cb3zlNNd.js +894 -0
  41. phoenix/server/static/assets/{vendor-arizeai-BG6iwyLC.js → vendor-arizeai-Buo4e1A6.js} +2 -2
  42. phoenix/server/static/assets/{vendor-codemirror-BotnVFFX.js → vendor-codemirror-BuAQiUVf.js} +5 -5
  43. phoenix/server/static/assets/{vendor-recharts-Dy5gEFzQ.js → vendor-recharts-Cl9dK5tC.js} +1 -1
  44. phoenix/server/static/assets/{vendor-Bnv1dNRQ.js → vendor-shiki-CazYUixL.js} +5 -898
  45. phoenix/session/client.py +13 -4
  46. phoenix/trace/fixtures.py +8 -0
  47. phoenix/trace/schemas.py +16 -0
  48. phoenix/version.py +1 -1
  49. phoenix/server/api/dataloaders/trace_row_ids.py +0 -33
  50. phoenix/server/static/assets/index-D_sCOjlG.js +0 -101
  51. {arize_phoenix-6.1.0.dist-info → arize_phoenix-7.0.0.dist-info}/WHEEL +0 -0
  52. {arize_phoenix-6.1.0.dist-info → arize_phoenix-7.0.0.dist-info}/entry_points.txt +0 -0
  53. {arize_phoenix-6.1.0.dist-info → arize_phoenix-7.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  54. {arize_phoenix-6.1.0.dist-info → arize_phoenix-7.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,32 @@
1
+ from typing import List, Optional
2
+
3
+ from sqlalchemy import select
4
+ from sqlalchemy.orm import contains_eager
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
Key: TypeAlias = int
Result: TypeAlias = Optional[models.Span]


class TraceRootSpansDataLoader(DataLoader[Key, Result]):
    """Batch-load the root span (the span whose ``parent_id`` is NULL) of each trace.

    Keys are trace row ids (``models.Trace.id``). Each result is the matching
    root span, or ``None`` when no parentless span was found for that trace.
    The span's ``trace`` relationship is populated from the join itself,
    loading only ``Trace.trace_id``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        query = (
            select(models.Span)
            .join(models.Trace)
            .where(models.Span.parent_id.is_(None))
            .where(models.Trace.id.in_(keys))
            .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
        )
        root_span_by_trace_rowid: dict[Key, models.Span] = {}
        async with self._db() as session:
            async for span in await session.stream_scalars(query):
                # If a trace has several parentless spans, the last one streamed wins.
                root_span_by_trace_rowid[span.trace_rowid] = span
        # Preserve the order of ``keys``; traces without a root span map to None.
        return [root_span_by_trace_rowid.get(key) for key in keys]
@@ -0,0 +1,29 @@
1
+ from enum import Enum, auto
2
+
3
+ import strawberry
4
+ from typing_extensions import assert_never
5
+
6
+ from phoenix.server.api.types.pagination import CursorSortColumnDataType
7
+ from phoenix.server.api.types.SortDir import SortDir
8
+
9
+
10
@strawberry.enum
class ProjectSessionColumn(Enum):
    """Columns by which a project's sessions can be sorted."""

    startTime = auto()
    endTime = auto()
    tokenCountTotal = auto()
    numTraces = auto()

    @property
    def data_type(self) -> CursorSortColumnDataType:
        """Return the cursor data type used to encode this column's sort value."""
        # Timestamp columns are encoded as datetimes in the cursor.
        if self is ProjectSessionColumn.startTime or self is ProjectSessionColumn.endTime:
            return CursorSortColumnDataType.DATETIME
        # Aggregated counts are encoded as integers in the cursor.
        if self is ProjectSessionColumn.tokenCountTotal or self is ProjectSessionColumn.numTraces:
            return CursorSortColumnDataType.INT
        assert_never(self)


@strawberry.input(description="The sort key and direction for ProjectSession connections.")
class ProjectSessionSort:
    # Column to sort sessions by.
    col: ProjectSessionColumn
    # Sort direction (see ``SortDir``).
    dir: SortDir
@@ -318,6 +318,9 @@ class ChatCompletionMutationMixin:
318
318
  ]
319
319
  if template_options := input.template:
320
320
  messages = list(_formatted_messages(messages, template_options))
321
+ attributes.update(
322
+ {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(template_options.variables)}
323
+ )
321
324
 
322
325
  invocation_parameters = llm_client.construct_invocation_parameters(
323
326
  input.invocation_parameters
@@ -584,5 +587,7 @@ TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
584
587
  TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
585
588
 
586
589
  TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
590
+ PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
591
+
587
592
 
588
593
  PLAYGROUND_PROJECT_NAME = "playground"
@@ -38,10 +38,20 @@ class ProjectMutationMixin:
38
38
  project_id = from_global_id_with_expected_type(
39
39
  global_id=input.id, expected_type_name="Project"
40
40
  )
41
- delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
41
+ delete_statement = (
42
+ delete(models.Trace)
43
+ .where(models.Trace.project_rowid == project_id)
44
+ .returning(models.Trace.project_session_rowid)
45
+ )
42
46
  if input.end_time:
43
47
  delete_statement = delete_statement.where(models.Trace.start_time < input.end_time)
44
48
  async with info.context.db() as session:
45
- await session.execute(delete_statement)
49
+ deleted_trace_project_session_ids = await session.scalars(delete_statement)
50
+ if deleted_trace_project_session_ids:
51
+ await session.execute(
52
+ delete(models.ProjectSession).where(
53
+ models.ProjectSession.id.in_(set(deleted_trace_project_session_ids))
54
+ )
55
+ )
46
56
  info.context.event_queue.put(SpanDeleteEvent((project_id,)))
47
57
  return Query()
@@ -78,10 +78,11 @@ from phoenix.server.api.types.pagination import (
78
78
  connection_from_list,
79
79
  )
80
80
  from phoenix.server.api.types.Project import Project
81
+ from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
81
82
  from phoenix.server.api.types.SortDir import SortDir
82
83
  from phoenix.server.api.types.Span import Span, to_gql_span
83
84
  from phoenix.server.api.types.SystemApiKey import SystemApiKey
84
- from phoenix.server.api.types.Trace import Trace
85
+ from phoenix.server.api.types.Trace import to_gql_trace
85
86
  from phoenix.server.api.types.User import User, to_gql_user
86
87
  from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
87
88
  from phoenix.server.api.types.UserRole import UserRole
@@ -445,17 +446,12 @@ class Query:
445
446
  gradient_end_color=project.gradient_end_color,
446
447
  )
447
448
  elif type_name == "Trace":
448
- trace_stmt = select(
449
- models.Trace.id,
450
- models.Trace.project_rowid,
451
- ).where(models.Trace.id == node_id)
449
+ trace_stmt = select(models.Trace).filter_by(id=node_id)
452
450
  async with info.context.db() as session:
453
- trace = (await session.execute(trace_stmt)).first()
451
+ trace = await session.scalar(trace_stmt)
454
452
  if trace is None:
455
453
  raise NotFound(f"Unknown trace: {id}")
456
- return Trace(
457
- id_attr=trace.id, trace_id=trace.trace_id, project_rowid=trace.project_rowid
458
- )
454
+ return to_gql_trace(trace)
459
455
  elif type_name == Span.__name__:
460
456
  span_stmt = (
461
457
  select(models.Span)
@@ -544,6 +540,15 @@ class Query:
544
540
  ):
545
541
  raise NotFound(f"Unknown user: {id}")
546
542
  return to_gql_user(user)
543
+ elif type_name == ProjectSession.__name__:
544
+ async with info.context.db() as session:
545
+ if not (
546
+ project_session := await session.scalar(
547
+ select(models.ProjectSession).filter_by(id=node_id)
548
+ )
549
+ ):
550
+ raise NotFound(f"Unknown user: {id}")
551
+ return to_gql_project_session(project_session)
547
552
  raise NotFound(f"Unknown node type: {type_name}")
548
553
 
549
554
  @strawberry.field
@@ -15,6 +15,7 @@ from typing import (
15
15
  )
16
16
 
17
17
  import strawberry
18
+ from openinference.instrumentation import safe_json_dumps
18
19
  from openinference.semconv.trace import SpanAttributes
19
20
  from sqlalchemy import and_, func, insert, select
20
21
  from sqlalchemy.orm import load_only
@@ -118,6 +119,7 @@ class Subscription:
118
119
  )
119
120
  for message in input.messages
120
121
  ]
122
+ attributes = None
121
123
  if template_options := input.template:
122
124
  messages = list(
123
125
  _formatted_messages(
@@ -126,6 +128,7 @@ class Subscription:
126
128
  template_variables=template_options.variables,
127
129
  )
128
130
  )
131
+ attributes = {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(template_options.variables)}
129
132
  invocation_parameters = llm_client.construct_invocation_parameters(
130
133
  input.invocation_parameters
131
134
  )
@@ -133,6 +136,7 @@ class Subscription:
133
136
  input=input,
134
137
  messages=messages,
135
138
  invocation_parameters=invocation_parameters,
139
+ attributes=attributes,
136
140
  ) as span:
137
141
  async for chunk in llm_client.chat_completion_create(
138
142
  messages=messages, tools=input.tools or [], **invocation_parameters
@@ -424,6 +428,7 @@ async def _stream_chat_completion_over_dataset_example(
424
428
  input=input,
425
429
  messages=messages,
426
430
  invocation_parameters=invocation_parameters,
431
+ attributes={PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(revision.input)},
427
432
  ) as span:
428
433
  async for chunk in llm_client.chat_completion_create(
429
434
  messages=messages, tools=input.tools or [], **invocation_parameters
@@ -589,3 +594,4 @@ def _default_playground_experiment_metadata(
589
594
  LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
590
595
  LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
591
596
  LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
597
+ PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
@@ -477,7 +477,7 @@ def _row_indices(
477
477
  return
478
478
  shuffled_indices = np.arange(start, stop)
479
479
  np.random.shuffle(shuffled_indices)
480
- yield from shuffled_indices
480
+ yield from shuffled_indices # type: ignore[misc,unused-ignore]
481
481
 
482
482
 
483
483
  def to_gql_embedding_dimension(
@@ -20,7 +20,7 @@ from phoenix.server.api.types.pagination import (
20
20
  CursorString,
21
21
  connection_from_list,
22
22
  )
23
- from phoenix.server.api.types.Trace import Trace
23
+ from phoenix.server.api.types.Trace import Trace, to_gql_trace
24
24
 
25
25
  if TYPE_CHECKING:
26
26
  from phoenix.server.api.types.DatasetExample import DatasetExample
@@ -61,11 +61,10 @@ class ExperimentRun(Node):
61
61
  async def trace(self, info: Info) -> Optional[Trace]:
62
62
  if not self.trace_id:
63
63
  return None
64
- dataloader = info.context.data_loaders.trace_row_ids
64
+ dataloader = info.context.data_loaders.trace_by_trace_ids
65
65
  if (trace := await dataloader.load(self.trace_id)) is None:
66
66
  return None
67
- trace_rowid, project_rowid = trace
68
- return Trace(id_attr=trace_rowid, trace_id=self.trace_id, project_rowid=project_rowid)
67
+ return to_gql_trace(trace)
69
68
 
70
69
  @strawberry.field
71
70
  async def example(
@@ -8,7 +8,7 @@ from strawberry.scalars import JSON
8
8
 
9
9
  from phoenix.db import models
10
10
  from phoenix.server.api.types.AnnotatorKind import ExperimentRunAnnotatorKind
11
- from phoenix.server.api.types.Trace import Trace
11
+ from phoenix.server.api.types.Trace import Trace, to_gql_trace
12
12
 
13
13
 
14
14
  @strawberry.type
@@ -29,11 +29,10 @@ class ExperimentRunAnnotation(Node):
29
29
  async def trace(self, info: Info) -> Optional[Trace]:
30
30
  if not self.trace_id:
31
31
  return None
32
- dataloader = info.context.data_loaders.trace_row_ids
32
+ dataloader = info.context.data_loaders.trace_by_trace_ids
33
33
  if (trace := await dataloader.load(self.trace_id)) is None:
34
34
  return None
35
- trace_row_id, project_row_id = trace
36
- return Trace(id_attr=trace_row_id, trace_id=self.trace_id, project_rowid=project_row_id)
35
+ return to_gql_trace(trace)
37
36
 
38
37
 
39
38
  def to_gql_experiment_run_annotation(
@@ -1,23 +1,26 @@
1
1
  import operator
2
2
  from datetime import datetime
3
- from typing import (
4
- Any,
5
- ClassVar,
6
- Optional,
7
- )
3
+ from typing import Any, ClassVar, Optional
8
4
 
9
5
  import strawberry
10
6
  from aioitertools.itertools import islice
11
- from sqlalchemy import and_, desc, distinct, select
7
+ from openinference.semconv.trace import SpanAttributes
8
+ from sqlalchemy import and_, desc, distinct, func, or_, select
12
9
  from sqlalchemy.orm import contains_eager
10
+ from sqlalchemy.sql.elements import ColumnElement
13
11
  from sqlalchemy.sql.expression import tuple_
14
12
  from strawberry import ID, UNSET
15
13
  from strawberry.relay import Connection, Node, NodeID
16
14
  from strawberry.types import Info
15
+ from typing_extensions import assert_never
17
16
 
18
17
  from phoenix.datetime_utils import right_open_time_range
19
18
  from phoenix.db import models
20
19
  from phoenix.server.api.context import Context
20
+ from phoenix.server.api.input_types.ProjectSessionSort import (
21
+ ProjectSessionColumn,
22
+ ProjectSessionSort,
23
+ )
21
24
  from phoenix.server.api.input_types.SpanSort import SpanSort, SpanSortConfig
22
25
  from phoenix.server.api.input_types.TimeRange import TimeRange
23
26
  from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
@@ -28,9 +31,10 @@ from phoenix.server.api.types.pagination import (
28
31
  CursorString,
29
32
  connection_from_cursors_and_nodes,
30
33
  )
34
+ from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
31
35
  from phoenix.server.api.types.SortDir import SortDir
32
36
  from phoenix.server.api.types.Span import Span, to_gql_span
33
- from phoenix.server.api.types.Trace import Trace
37
+ from phoenix.server.api.types.Trace import Trace, to_gql_trace
34
38
  from phoenix.server.api.types.ValidationResult import ValidationResult
35
39
  from phoenix.trace.dsl import SpanFilter
36
40
 
@@ -127,7 +131,13 @@ class Project(Node):
127
131
  time_range: Optional[TimeRange] = UNSET,
128
132
  ) -> Optional[float]:
129
133
  return await info.context.data_loaders.latency_ms_quantile.load(
130
- ("trace", self.id_attr, time_range, None, probability),
134
+ (
135
+ "trace",
136
+ self.id_attr,
137
+ time_range,
138
+ None,
139
+ probability,
140
+ ),
131
141
  )
132
142
 
133
143
  @strawberry.field
@@ -139,20 +149,26 @@ class Project(Node):
139
149
  filter_condition: Optional[str] = UNSET,
140
150
  ) -> Optional[float]:
141
151
  return await info.context.data_loaders.latency_ms_quantile.load(
142
- ("span", self.id_attr, time_range, filter_condition, probability),
152
+ (
153
+ "span",
154
+ self.id_attr,
155
+ time_range,
156
+ filter_condition,
157
+ probability,
158
+ ),
143
159
  )
144
160
 
145
161
  @strawberry.field
146
162
  async def trace(self, trace_id: ID, info: Info[Context, None]) -> Optional[Trace]:
147
163
  stmt = (
148
- select(models.Trace.id)
164
+ select(models.Trace)
149
165
  .where(models.Trace.trace_id == str(trace_id))
150
166
  .where(models.Trace.project_rowid == self.id_attr)
151
167
  )
152
168
  async with info.context.db() as session:
153
- if (id_attr := await session.scalar(stmt)) is None:
169
+ if (trace := await session.scalar(stmt)) is None:
154
170
  return None
155
- return Trace(id_attr=id_attr, trace_id=trace_id, project_rowid=self.id_attr)
171
+ return to_gql_trace(trace)
156
172
 
157
173
  @strawberry.field
158
174
  async def spans(
@@ -241,6 +257,124 @@ class Project(Node):
241
257
  has_next_page=has_next_page,
242
258
  )
243
259
 
260
+ @strawberry.field
261
+ async def sessions(
262
+ self,
263
+ info: Info[Context, None],
264
+ time_range: Optional[TimeRange] = UNSET,
265
+ first: Optional[int] = 50,
266
+ after: Optional[CursorString] = UNSET,
267
+ sort: Optional[ProjectSessionSort] = UNSET,
268
+ filter_io_substring: Optional[str] = UNSET,
269
+ ) -> Connection[ProjectSession]:
270
+ table = models.ProjectSession
271
+ stmt = select(table).filter_by(project_id=self.id_attr)
272
+ if time_range:
273
+ if time_range.start:
274
+ stmt = stmt.where(time_range.start <= table.start_time)
275
+ if time_range.end:
276
+ stmt = stmt.where(table.start_time < time_range.end)
277
+ if filter_io_substring:
278
+ filter_subq = (
279
+ stmt.with_only_columns(distinct(table.id).label("id"))
280
+ .join_from(table, models.Trace)
281
+ .join_from(models.Trace, models.Span)
282
+ .where(models.Span.parent_id.is_(None))
283
+ .where(
284
+ or_(
285
+ models.TextContains(
286
+ models.Span.attributes[INPUT_VALUE].as_string(),
287
+ filter_io_substring,
288
+ ),
289
+ models.TextContains(
290
+ models.Span.attributes[OUTPUT_VALUE].as_string(),
291
+ filter_io_substring,
292
+ ),
293
+ )
294
+ )
295
+ ).subquery()
296
+ stmt = stmt.join(filter_subq, table.id == filter_subq.c.id)
297
+ if sort:
298
+ key: ColumnElement[Any]
299
+ if sort.col is ProjectSessionColumn.startTime:
300
+ key = table.start_time.label("key")
301
+ elif sort.col is ProjectSessionColumn.endTime:
302
+ key = table.end_time.label("key")
303
+ elif (
304
+ sort.col is ProjectSessionColumn.tokenCountTotal
305
+ or sort.col is ProjectSessionColumn.numTraces
306
+ ):
307
+ if sort.col is ProjectSessionColumn.tokenCountTotal:
308
+ sort_subq = (
309
+ select(
310
+ models.Trace.project_session_rowid.label("id"),
311
+ func.sum(models.Span.cumulative_llm_token_count_total).label("key"),
312
+ )
313
+ .join_from(models.Trace, models.Span)
314
+ .where(models.Span.parent_id.is_(None))
315
+ .group_by(models.Trace.project_session_rowid)
316
+ ).subquery()
317
+ elif sort.col is ProjectSessionColumn.numTraces:
318
+ sort_subq = (
319
+ select(
320
+ models.Trace.project_session_rowid.label("id"),
321
+ func.count(models.Trace.id).label("key"),
322
+ ).group_by(models.Trace.project_session_rowid)
323
+ ).subquery()
324
+ else:
325
+ assert_never(sort.col)
326
+ key = sort_subq.c.key
327
+ stmt = stmt.join(sort_subq, table.id == sort_subq.c.id)
328
+ else:
329
+ assert_never(sort.col)
330
+ stmt = stmt.add_columns(key)
331
+ if sort.dir is SortDir.asc:
332
+ stmt = stmt.order_by(key.asc(), table.id.asc())
333
+ else:
334
+ stmt = stmt.order_by(key.desc(), table.id.desc())
335
+ if after:
336
+ cursor = Cursor.from_string(after)
337
+ assert cursor.sort_column is not None
338
+ compare = operator.lt if sort.dir is SortDir.desc else operator.gt
339
+ stmt = stmt.where(
340
+ compare(
341
+ tuple_(key, table.id),
342
+ (cursor.sort_column.value, cursor.rowid),
343
+ )
344
+ )
345
+ else:
346
+ stmt = stmt.order_by(table.id.desc())
347
+ if after:
348
+ cursor = Cursor.from_string(after)
349
+ stmt = stmt.where(table.id < cursor.rowid)
350
+ if first:
351
+ stmt = stmt.limit(
352
+ first + 1 # over-fetch by one to determine whether there's a next page
353
+ )
354
+ cursors_and_nodes = []
355
+ async with info.context.db() as session:
356
+ records = await session.stream(stmt)
357
+ async for record in islice(records, first):
358
+ project_session = record[0]
359
+ cursor = Cursor(rowid=project_session.id)
360
+ if sort:
361
+ assert len(record) > 1
362
+ cursor.sort_column = CursorSortColumn(
363
+ type=sort.col.data_type,
364
+ value=record[1],
365
+ )
366
+ cursors_and_nodes.append((cursor, to_gql_project_session(project_session)))
367
+ has_next_page = True
368
+ try:
369
+ await records.__anext__()
370
+ except StopAsyncIteration:
371
+ has_next_page = False
372
+ return connection_from_cursors_and_nodes(
373
+ cursors_and_nodes,
374
+ has_previous_page=False,
375
+ has_next_page=has_next_page,
376
+ )
377
+
244
378
  @strawberry.field(
245
379
  description="Names of all available annotations for traces. "
246
380
  "(The list contains no duplicates.)"
@@ -363,3 +497,7 @@ def to_gql_project(project: models.Project) -> Project:
363
497
  gradient_start_color=project.gradient_start_color,
364
498
  gradient_end_color=project.gradient_end_color,
365
499
  )
500
+
501
+
502
+ INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
503
+ OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
@@ -0,0 +1,139 @@
1
+ from datetime import datetime
2
+ from typing import TYPE_CHECKING, Annotated, ClassVar, Optional, Type
3
+
4
+ import strawberry
5
+ from openinference.semconv.trace import SpanAttributes
6
+ from sqlalchemy import select
7
+ from strawberry import UNSET, Info, Private, lazy
8
+ from strawberry.relay import Connection, GlobalID, Node, NodeID
9
+
10
+ from phoenix.db import models
11
+ from phoenix.server.api.context import Context
12
+ from phoenix.server.api.types.MimeType import MimeType
13
+ from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
14
+ from phoenix.server.api.types.SpanIOValue import SpanIOValue
15
+ from phoenix.server.api.types.TokenUsage import TokenUsage
16
+
17
+ if TYPE_CHECKING:
18
+ from phoenix.server.api.types.Trace import Trace
19
+
20
+
21
@strawberry.type
class ProjectSession(Node):
    """GraphQL node for a session within a project.

    Scalar fields are copied from ``models.ProjectSession`` (see
    ``to_gql_project_session``). Aggregate fields are resolved on demand through
    the request's data loaders, or by a direct query in the case of ``traces``.
    """

    _table: ClassVar[Type[models.ProjectSession]] = models.ProjectSession
    id_attr: NodeID[int]
    # Row id of the owning project; clients see it only via ``project_id``.
    project_rowid: Private[int]
    session_id: str
    start_time: datetime
    end_time: datetime

    @strawberry.field
    async def project_id(self) -> GlobalID:
        """Return the global id of the project that owns this session."""
        # Project's module imports this one, so import lazily to avoid a cycle.
        from phoenix.server.api.types.Project import Project

        return GlobalID(type_name=Project.__name__, node_id=str(self.project_rowid))

    @strawberry.field
    async def num_traces(
        self,
        info: Info[Context, None],
    ) -> int:
        """Return the trace count for this session (``session_num_traces`` loader)."""
        return await info.context.data_loaders.session_num_traces.load(self.id_attr)

    @strawberry.field
    async def num_traces_with_error(
        self,
        info: Info[Context, None],
    ) -> int:
        """Return the errored-trace count (``session_num_traces_with_error`` loader)."""
        return await info.context.data_loaders.session_num_traces_with_error.load(self.id_attr)

    @strawberry.field
    async def first_input(
        self,
        info: Info[Context, None],
    ) -> Optional[SpanIOValue]:
        """Return the value from ``session_first_inputs``, or ``None`` if it has none."""
        record = await info.context.data_loaders.session_first_inputs.load(self.id_attr)
        if record is None:
            return None
        return SpanIOValue(
            mime_type=MimeType(record.mime_type.value),
            value=record.value,
        )

    @strawberry.field
    async def last_output(
        self,
        info: Info[Context, None],
    ) -> Optional[SpanIOValue]:
        """Return the value from ``session_last_outputs``, or ``None`` if it has none."""
        record = await info.context.data_loaders.session_last_outputs.load(self.id_attr)
        if record is None:
            return None
        return SpanIOValue(
            mime_type=MimeType(record.mime_type.value),
            value=record.value,
        )

    @strawberry.field
    async def token_usage(
        self,
        info: Info[Context, None],
    ) -> TokenUsage:
        """Return prompt/completion token counts from ``session_token_usages``."""
        usage = await info.context.data_loaders.session_token_usages.load(self.id_attr)
        return TokenUsage(
            prompt=usage.prompt,
            completion=usage.completion,
        )

    @strawberry.field
    async def traces(
        self,
        info: Info[Context, None],
        first: Optional[int] = 50,
        last: Optional[int] = UNSET,
        after: Optional[CursorString] = UNSET,
        before: Optional[CursorString] = UNSET,
    ) -> Connection[Annotated["Trace", lazy(".Trace")]]:
        """Paginate this session's traces, ordered by start time.

        ``first``/``last``/``after``/``before`` are applied by
        ``connection_from_list`` over the session's complete, ordered list of
        traces.
        """
        from phoenix.server.api.types.Trace import to_gql_trace

        args = ConnectionArgs(
            first=first,
            after=after if isinstance(after, CursorString) else None,
            last=last,
            before=before if isinstance(before, CursorString) else None,
        )
        # Do NOT limit the query to ``first`` rows. ``connection_from_list``
        # slices the list it is given using ``args``, so a pre-truncated list
        # makes later pages unreachable via ``after``, always reports no next
        # page, and applies ``last`` to the wrong rows.
        stmt = (
            select(models.Trace)
            .filter_by(project_session_rowid=self.id_attr)
            # Row id breaks start-time ties so list offsets are stable across requests.
            .order_by(models.Trace.start_time, models.Trace.id)
        )
        async with info.context.db() as session:
            traces = await session.stream_scalars(stmt)
            data = [to_gql_trace(trace) async for trace in traces]
        return connection_from_list(data=data, args=args)

    @strawberry.field
    async def trace_latency_ms_quantile(
        self,
        info: Info[Context, None],
        probability: float,
    ) -> Optional[float]:
        """Return the ``probability`` latency quantile (ms) of this session's traces."""
        return await info.context.data_loaders.session_trace_latency_ms_quantile.load(
            (self.id_attr, probability)
        )
124
+
125
+
126
+ def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSession:
127
+ return ProjectSession(
128
+ id_attr=project_session.id,
129
+ session_id=project_session.session_id,
130
+ start_time=project_session.start_time,
131
+ project_rowid=project_session.project_id,
132
+ end_time=project_session.end_time,
133
+ )
134
+
135
+
136
# Span-attribute key paths: each dotted semantic-convention name split into its segments.
INPUT_VALUE, INPUT_MIME_TYPE, OUTPUT_VALUE, OUTPUT_MIME_TYPE = (
    attribute_name.split(".")
    for attribute_name in (
        SpanAttributes.INPUT_VALUE,
        SpanAttributes.INPUT_MIME_TYPE,
        SpanAttributes.OUTPUT_VALUE,
        SpanAttributes.OUTPUT_MIME_TYPE,
    )
)
@@ -24,17 +24,16 @@ from phoenix.server.api.input_types.SpanAnnotationSort import (
24
24
  SpanAnnotationColumn,
25
25
  SpanAnnotationSort,
26
26
  )
27
+ from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
28
+ from phoenix.server.api.types.Evaluation import DocumentEvaluation
29
+ from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision
27
30
  from phoenix.server.api.types.GenerativeProvider import GenerativeProvider
31
+ from phoenix.server.api.types.MimeType import MimeType
28
32
  from phoenix.server.api.types.SortDir import SortDir
29
- from phoenix.server.api.types.SpanAnnotation import to_gql_span_annotation
33
+ from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
34
+ from phoenix.server.api.types.SpanIOValue import SpanIOValue
30
35
  from phoenix.trace.attributes import get_attribute_value
31
36
 
32
- from .DocumentRetrievalMetrics import DocumentRetrievalMetrics
33
- from .Evaluation import DocumentEvaluation
34
- from .ExampleRevisionInterface import ExampleRevision
35
- from .MimeType import MimeType
36
- from .SpanAnnotation import SpanAnnotation
37
-
38
37
  if TYPE_CHECKING:
39
38
  from phoenix.server.api.types.Project import Project
40
39
 
@@ -71,18 +70,6 @@ class SpanContext:
71
70
  span_id: ID
72
71
 
73
72
 
74
- @strawberry.type
75
- class SpanIOValue:
76
- mime_type: MimeType
77
- value: str
78
-
79
- @strawberry.field(
80
- description="Truncate value up to `chars` characters, appending '...' if truncated.",
81
- ) # type: ignore
82
- def truncated_value(self, chars: int = 100) -> str:
83
- return f"{self.value[: max(0, chars - 3)]}..." if len(self.value) > chars else self.value
84
-
85
-
86
73
  @strawberry.enum
87
74
  class SpanStatusCode(Enum):
88
75
  OK = "OK"
@@ -0,0 +1,15 @@
1
+ import strawberry
2
+
3
+ from phoenix.server.api.types.MimeType import MimeType
4
+
5
+
6
@strawberry.type
class SpanIOValue:
    """A span input or output payload paired with its MIME type."""

    mime_type: MimeType
    value: str

    @strawberry.field(
        description="Truncate value up to `chars` characters, appending '...' if truncated.",
    )  # type: ignore
    def truncated_value(self, chars: int = 100) -> str:
        """Return ``value`` unchanged if it fits in ``chars``, else a truncated copy.

        A truncated copy keeps the first ``chars - 3`` characters and appends
        ``'...'``. When ``chars`` is below 3, a truncated result is just ``'...'``.
        """
        if len(self.value) <= chars:
            return self.value
        kept = max(0, chars - 3)
        return self.value[:kept] + "..."