arize-phoenix: arize_phoenix-8.0.1-py3-none-any.whl → arize_phoenix-8.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; the original page linked to further details.

Files changed (38):
  1. {arize_phoenix-8.0.1.dist-info → arize_phoenix-8.2.0.dist-info}/METADATA +1 -2
  2. {arize_phoenix-8.0.1.dist-info → arize_phoenix-8.2.0.dist-info}/RECORD +37 -36
  3. phoenix/config.py +32 -6
  4. phoenix/db/models.py +151 -11
  5. phoenix/server/api/context.py +4 -0
  6. phoenix/server/api/dataloaders/__init__.py +4 -0
  7. phoenix/server/api/dataloaders/span_by_id.py +29 -0
  8. phoenix/server/api/dataloaders/span_descendants.py +24 -15
  9. phoenix/server/api/dataloaders/span_fields.py +76 -0
  10. phoenix/server/api/dataloaders/trace_root_spans.py +9 -10
  11. phoenix/server/api/mutations/chat_mutations.py +10 -7
  12. phoenix/server/api/queries.py +2 -2
  13. phoenix/server/api/subscriptions.py +3 -3
  14. phoenix/server/api/types/Annotation.py +4 -1
  15. phoenix/server/api/types/DatasetExample.py +2 -2
  16. phoenix/server/api/types/Project.py +8 -10
  17. phoenix/server/api/types/ProjectSession.py +2 -2
  18. phoenix/server/api/types/Span.py +377 -120
  19. phoenix/server/api/types/SpanIOValue.py +39 -6
  20. phoenix/server/api/types/Trace.py +17 -15
  21. phoenix/server/app.py +4 -0
  22. phoenix/server/prometheus.py +113 -7
  23. phoenix/server/static/.vite/manifest.json +36 -36
  24. phoenix/server/static/assets/{components-B-qgPyHv.js → components-C48uRczp.js} +1 -1
  25. phoenix/server/static/assets/{index-D4KO1IcF.js → index-5klIR86Z.js} +2 -2
  26. phoenix/server/static/assets/{pages-DdcuL3Rh.js → pages-sERWyBWu.js} +326 -326
  27. phoenix/server/static/assets/{vendor-DQp7CrDA.js → vendor-Cqfydjep.js} +117 -117
  28. phoenix/server/static/assets/{vendor-arizeai-C1nEIEQq.js → vendor-arizeai-WnerlUPN.js} +1 -1
  29. phoenix/server/static/assets/{vendor-codemirror-BZXYUIkP.js → vendor-codemirror-D-ZZKLFq.js} +1 -1
  30. phoenix/server/static/assets/{vendor-recharts-BUFpwCVD.js → vendor-recharts-KY97ZPfK.js} +1 -1
  31. phoenix/server/static/assets/{vendor-shiki-C8L-c9jT.js → vendor-shiki-D5K9GnFn.js} +1 -1
  32. phoenix/trace/attributes.py +7 -2
  33. phoenix/version.py +1 -1
  34. phoenix/server/api/helpers/jsonschema.py +0 -135
  35. {arize_phoenix-8.0.1.dist-info → arize_phoenix-8.2.0.dist-info}/WHEEL +0 -0
  36. {arize_phoenix-8.0.1.dist-info → arize_phoenix-8.2.0.dist-info}/entry_points.txt +0 -0
  37. {arize_phoenix-8.0.1.dist-info → arize_phoenix-8.2.0.dist-info}/licenses/IP_NOTICE +0 -0
  38. {arize_phoenix-8.0.1.dist-info → arize_phoenix-8.2.0.dist-info}/licenses/LICENSE +0 -0
File: phoenix/server/api/dataloaders/span_descendants.py
@@ -1,18 +1,18 @@
1
- from random import randint
1
+ from secrets import token_hex
2
+ from typing import Iterable
2
3
 
3
4
  from aioitertools.itertools import groupby
4
5
  from sqlalchemy import select
5
- from sqlalchemy.orm import joinedload
6
6
  from strawberry.dataloader import DataLoader
7
7
  from typing_extensions import TypeAlias
8
8
 
9
9
  from phoenix.db import models
10
10
  from phoenix.server.types import DbSessionFactory
11
11
 
12
- SpanId: TypeAlias = str
12
+ SpanRowId: TypeAlias = int
13
13
 
14
- Key: TypeAlias = SpanId
15
- Result: TypeAlias = list[models.Span]
14
+ Key: TypeAlias = SpanRowId
15
+ Result: TypeAlias = list[SpanRowId]
16
16
 
17
17
 
18
18
  class SpanDescendantsDataLoader(DataLoader[Key, Result]):
@@ -20,16 +20,25 @@ class SpanDescendantsDataLoader(DataLoader[Key, Result]):
20
20
  super().__init__(load_fn=self._load_fn)
21
21
  self._db = db
22
22
 
23
- async def _load_fn(self, keys: list[Key]) -> list[Result]:
24
- root_ids = set(keys)
25
- root_id_label = f"root_id_{randint(0, 10**6):06}"
23
+ async def _load_fn(self, keys: Iterable[Key]) -> list[Result]:
24
+ root_ids = (
25
+ select(
26
+ models.Span.span_id,
27
+ models.Span.id,
28
+ )
29
+ .where(models.Span.id.in_(set(keys)))
30
+ .subquery()
31
+ )
32
+ root_id_label = f"root_id_{token_hex(8)}"
33
+ root_rowid_label = f"root_rowid_{token_hex(8)}"
26
34
  descendant_ids = (
27
35
  select(
28
36
  models.Span.id,
29
37
  models.Span.span_id,
30
38
  models.Span.parent_id.label(root_id_label),
39
+ root_ids.c.id.label(root_rowid_label),
31
40
  )
32
- .where(models.Span.parent_id.in_(root_ids))
41
+ .join(root_ids, models.Span.parent_id == root_ids.c.span_id)
33
42
  .cte(recursive=True)
34
43
  )
35
44
  parent_ids = descendant_ids.alias()
@@ -38,20 +47,20 @@ class SpanDescendantsDataLoader(DataLoader[Key, Result]):
38
47
  models.Span.id,
39
48
  models.Span.span_id,
40
49
  parent_ids.c[root_id_label],
50
+ parent_ids.c[root_rowid_label],
41
51
  ).join(
42
52
  parent_ids,
43
53
  models.Span.parent_id == parent_ids.c.span_id,
44
54
  )
45
55
  )
46
56
  stmt = (
47
- select(descendant_ids.c[root_id_label], models.Span)
57
+ select(descendant_ids.c[root_rowid_label], models.Span.id)
48
58
  .join(descendant_ids, models.Span.id == descendant_ids.c.id)
49
- .options(joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id))
50
- .order_by(descendant_ids.c[root_id_label])
59
+ .order_by(descendant_ids.c[root_rowid_label])
51
60
  )
52
- results: dict[SpanId, Result] = {key: [] for key in keys}
61
+ results: dict[Key, Result] = {key: [] for key in keys}
53
62
  async with self._db() as session:
54
63
  data = await session.stream(stmt)
55
- async for root_id, group in groupby(data, key=lambda d: d[0]):
56
- results[root_id].extend(span for _, span in group)
64
+ async for key, group in groupby(data, key=lambda d: d[0]):
65
+ results[key].extend(span_rowid for _, span_rowid in group)
57
66
  return [results[key].copy() for key in keys]
File: phoenix/server/api/dataloaders/span_fields.py (new file)
@@ -0,0 +1,76 @@
1
+ from typing import Any, Iterable, Union
2
+
3
+ from sqlalchemy import Select, select
4
+ from sqlalchemy.orm import QueryableAttribute
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
+ SpanRowId: TypeAlias = int
12
+
13
+ Key: TypeAlias = tuple[SpanRowId, QueryableAttribute[Any]]
14
+ Result: TypeAlias = Any
15
+
16
+
17
+ _ResultColumnPosition: TypeAlias = int
18
+ _AttrStrIdentifier: TypeAlias = str
19
+
20
+
21
+ class SpanFieldsDataLoader(DataLoader[Key, Result]):
22
+ def __init__(self, db: DbSessionFactory) -> None:
23
+ super().__init__(load_fn=self._load_fn)
24
+ self._db = db
25
+
26
+ async def _load_fn(self, keys: Iterable[Key]) -> list[Union[Result, ValueError]]:
27
+ result: dict[tuple[SpanRowId, _AttrStrIdentifier], Result] = {}
28
+ stmt, attr_strs = _get_stmt(keys)
29
+ async with self._db() as session:
30
+ data = await session.stream(stmt)
31
+ async for row in data:
32
+ span_rowid: SpanRowId = row[0] # models.Span's primary key
33
+ for i, value in enumerate(row[1:]):
34
+ result[span_rowid, attr_strs[i]] = value
35
+ return [result.get((span_rowid, str(attr))) for span_rowid, attr in keys]
36
+
37
+
38
+ def _get_stmt(
39
+ keys: Iterable[Key],
40
+ ) -> tuple[
41
+ Select[Any],
42
+ dict[_ResultColumnPosition, _AttrStrIdentifier],
43
+ ]:
44
+ """
45
+ Generate a SQLAlchemy Select statement and a mapping of attribute identifiers (from their
46
+ column positions in the query result starting at the second column).
47
+
48
+ This function constructs a SQLAlchemy Select statement to query the `Span` model
49
+ based on the provided keys. It also creates a mapping of attribute identifiers
50
+ to their positions in the query result (starting at the second column as the zero-th
51
+ position).
52
+
53
+ Args:
54
+ keys (list[Key]): A list of tuples, where each tuple contains an integer ID, i.e. the
55
+ primary key of models.Span, and a QueryableAttribute.
56
+
57
+ Returns:
58
+ tuple: A tuple containing:
59
+ - Select[Any]: A SQLAlchemy Select statement with `Span` ID and attributes.
60
+ - dict[int, str]: A dictionary mapping the column position--where 0-th position starts
61
+ at the second column (because the first column is the span's primary key)--in the
62
+ result to the attribute's string identifier.
63
+ """
64
+ span_rowids: set[SpanRowId] = set()
65
+ attrs: dict[_AttrStrIdentifier, QueryableAttribute[Any]] = {}
66
+ joins = set()
67
+ for span_rowid, attr in keys:
68
+ span_rowids.add(span_rowid)
69
+ attrs[str(attr)] = attr
70
+ if (entity := attr.parent.entity) is not models.Span:
71
+ joins.add(entity)
72
+ stmt = select(models.Span.id).where(models.Span.id.in_(span_rowids))
73
+ for table in joins:
74
+ stmt = stmt.join(table)
75
+ identifiers, columns = zip(*attrs.items())
76
+ return stmt.add_columns(*columns), dict(enumerate(identifiers))
File: phoenix/server/api/dataloaders/trace_root_spans.py
@@ -1,15 +1,17 @@
1
- from typing import List, Optional
1
+ from typing import Iterable, Optional
2
2
 
3
3
  from sqlalchemy import select
4
- from sqlalchemy.orm import contains_eager
5
4
  from strawberry.dataloader import DataLoader
6
5
  from typing_extensions import TypeAlias
7
6
 
8
7
  from phoenix.db import models
9
8
  from phoenix.server.types import DbSessionFactory
10
9
 
11
- Key: TypeAlias = int
12
- Result: TypeAlias = Optional[models.Span]
10
+ TraceRowId: TypeAlias = int
11
+ SpanRowId: TypeAlias = int
12
+
13
+ Key: TypeAlias = TraceRowId
14
+ Result: TypeAlias = Optional[SpanRowId]
13
15
 
14
16
 
15
17
  class TraceRootSpansDataLoader(DataLoader[Key, Result]):
@@ -17,16 +19,13 @@ class TraceRootSpansDataLoader(DataLoader[Key, Result]):
17
19
  super().__init__(load_fn=self._load_fn)
18
20
  self._db = db
19
21
 
20
- async def _load_fn(self, keys: List[Key]) -> List[Result]:
22
+ async def _load_fn(self, keys: Iterable[Key]) -> list[Result]:
21
23
  stmt = (
22
- select(models.Span)
24
+ select(models.Trace.id, models.Span.id)
23
25
  .join(models.Trace)
24
26
  .where(models.Span.parent_id.is_(None))
25
27
  .where(models.Trace.id.in_(keys))
26
- .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
27
28
  )
28
29
  async with self._db() as session:
29
- result: dict[Key, models.Span] = {
30
- span.trace_rowid: span async for span in await session.stream_scalars(stmt)
31
- }
30
+ result: dict[Key, int] = {k: v async for k, v in await session.stream(stmt)}
32
31
  return [result.get(key) for key in keys]
File: phoenix/server/api/mutations/chat_mutations.py
@@ -60,7 +60,7 @@ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
60
60
  from phoenix.server.api.types.Dataset import Dataset
61
61
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
62
62
  from phoenix.server.api.types.node import from_global_id_with_expected_type
63
- from phoenix.server.api.types.Span import Span, to_gql_span
63
+ from phoenix.server.api.types.Span import Span
64
64
  from phoenix.server.dml_event import SpanInsertEvent
65
65
  from phoenix.trace.attributes import unflatten
66
66
  from phoenix.trace.schemas import SpanException
@@ -91,6 +91,7 @@ class ChatCompletionToolCall:
91
91
 
92
92
  @strawberry.type
93
93
  class ChatCompletionMutationPayload:
94
+ db_span: strawberry.Private[models.Span]
94
95
  content: Optional[str]
95
96
  tool_calls: List[ChatCompletionToolCall]
96
97
  span: Span
@@ -188,7 +189,7 @@ class ChatCompletionMutationMixin:
188
189
  session.add(experiment)
189
190
  await session.flush()
190
191
 
191
- results = []
192
+ results: list[Union[ChatCompletionMutationPayload, BaseException]] = []
192
193
  batch_size = 3
193
194
  start_time = datetime.now(timezone.utc)
194
195
  for batch in _get_batches(revisions, batch_size):
@@ -234,19 +235,19 @@ class ChatCompletionMutationMixin:
234
235
  error=str(result),
235
236
  )
236
237
  else:
237
- db_span = result.span.db_span
238
+ db_span: models.Span = result.db_span
238
239
  experiment_run = models.ExperimentRun(
239
240
  experiment_id=experiment.id,
240
241
  dataset_example_id=revision.dataset_example_id,
241
- trace_id=str(result.span.context.trace_id),
242
+ trace_id=db_span.trace.trace_id,
242
243
  output=models.ExperimentRunOutput(
243
244
  task_output=get_dataset_example_output(db_span),
244
245
  ),
245
246
  prompt_token_count=db_span.cumulative_llm_token_count_prompt,
246
247
  completion_token_count=db_span.cumulative_llm_token_count_completion,
247
248
  repetition_number=1,
248
- start_time=result.span.start_time,
249
- end_time=result.span.end_time,
249
+ start_time=db_span.start_time,
250
+ end_time=db_span.end_time,
250
251
  error=str(result.error_message) if result.error_message else None,
251
252
  )
252
253
  experiment_runs.append(experiment_run)
@@ -433,12 +434,13 @@ class ChatCompletionMutationMixin:
433
434
  session.add(span)
434
435
  await session.flush()
435
436
 
436
- gql_span = to_gql_span(span)
437
+ gql_span = Span(span_rowid=span.id, db_span=span)
437
438
 
438
439
  info.context.event_queue.put(SpanInsertEvent(ids=(project_id,)))
439
440
 
440
441
  if status_code is StatusCode.ERROR:
441
442
  return ChatCompletionMutationPayload(
443
+ db_span=span,
442
444
  content=None,
443
445
  tool_calls=[],
444
446
  span=gql_span,
@@ -446,6 +448,7 @@ class ChatCompletionMutationMixin:
446
448
  )
447
449
  else:
448
450
  return ChatCompletionMutationPayload(
451
+ db_span=span,
449
452
  content=text_content if text_content else None,
450
453
  tool_calls=list(tool_calls.values()),
451
454
  span=gql_span,
File: phoenix/server/api/queries.py
@@ -64,7 +64,7 @@ from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
64
64
  from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
65
65
  from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version
66
66
  from phoenix.server.api.types.SortDir import SortDir
67
- from phoenix.server.api.types.Span import Span, to_gql_span
67
+ from phoenix.server.api.types.Span import Span
68
68
  from phoenix.server.api.types.SystemApiKey import SystemApiKey
69
69
  from phoenix.server.api.types.Trace import to_gql_trace
70
70
  from phoenix.server.api.types.User import User, to_gql_user
@@ -483,7 +483,7 @@ class Query:
483
483
  span = await session.scalar(span_stmt)
484
484
  if span is None:
485
485
  raise NotFound(f"Unknown span: {id}")
486
- return to_gql_span(span)
486
+ return Span(span_rowid=span.id, db_span=span)
487
487
  elif type_name == Dataset.__name__:
488
488
  dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
489
489
  async with info.context.db() as session:
File: phoenix/server/api/subscriptions.py
@@ -59,7 +59,7 @@ from phoenix.server.api.types.DatasetVersion import DatasetVersion
59
59
  from phoenix.server.api.types.Experiment import to_gql_experiment
60
60
  from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
61
61
  from phoenix.server.api.types.node import from_global_id_with_expected_type
62
- from phoenix.server.api.types.Span import to_gql_span
62
+ from phoenix.server.api.types.Span import Span
63
63
  from phoenix.server.dml_event import SpanInsertEvent
64
64
  from phoenix.server.types import DbSessionFactory
65
65
  from phoenix.utilities.template_formatters import (
@@ -165,7 +165,7 @@ class Subscription:
165
165
  session.add(db_span)
166
166
  await session.flush()
167
167
  info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
168
- yield ChatCompletionSubscriptionResult(span=to_gql_span(db_span))
168
+ yield ChatCompletionSubscriptionResult(span=Span(span_rowid=db_span.id, db_span=db_span))
169
169
 
170
170
  @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
171
171
  async def chat_completion_over_dataset(
@@ -457,7 +457,7 @@ async def _chat_completion_result_payloads(
457
457
  await session.flush()
458
458
  for example_id, span, run in results:
459
459
  yield ChatCompletionSubscriptionResult(
460
- span=to_gql_span(span) if span else None,
460
+ span=Span(span_rowid=span.id, db_span=span) if span else None,
461
461
  experiment_run=to_gql_experiment_run(run),
462
462
  dataset_example_id=example_id,
463
463
  )
File: phoenix/server/api/types/Annotation.py
@@ -2,6 +2,8 @@ from typing import Optional
2
2
 
3
3
  import strawberry
4
4
 
5
+ from phoenix.server.api.interceptor import GqlValueMediator
6
+
5
7
 
6
8
  @strawberry.interface
7
9
  class Annotation:
@@ -9,7 +11,8 @@ class Annotation:
9
11
  description="Name of the annotation, e.g. 'helpfulness' or 'relevance'."
10
12
  )
11
13
  score: Optional[float] = strawberry.field(
12
- description="Value of the annotation in the form of a numeric score."
14
+ description="Value of the annotation in the form of a numeric score.",
15
+ default=GqlValueMediator(),
13
16
  )
14
17
  label: Optional[str] = strawberry.field(
15
18
  description="Value of the annotation in the form of a string, e.g. "
File: phoenix/server/api/types/DatasetExample.py
@@ -19,7 +19,7 @@ from phoenix.server.api.types.pagination import (
19
19
  CursorString,
20
20
  connection_from_list,
21
21
  )
22
- from phoenix.server.api.types.Span import Span, to_gql_span
22
+ from phoenix.server.api.types.Span import Span
23
23
 
24
24
 
25
25
  @strawberry.type
@@ -52,7 +52,7 @@ class DatasetExample(Node):
52
52
  info: Info[Context, None],
53
53
  ) -> Optional[Span]:
54
54
  return (
55
- to_gql_span(span)
55
+ Span(span_rowid=span.id, db_span=span)
56
56
  if (span := await info.context.data_loaders.dataset_example_spans.load(self.id_attr))
57
57
  else None
58
58
  )
File: phoenix/server/api/types/Project.py
@@ -6,7 +6,6 @@ import strawberry
6
6
  from aioitertools.itertools import islice
7
7
  from openinference.semconv.trace import SpanAttributes
8
8
  from sqlalchemy import desc, distinct, func, or_, select
9
- from sqlalchemy.orm import contains_eager
10
9
  from sqlalchemy.sql.elements import ColumnElement
11
10
  from sqlalchemy.sql.expression import tuple_
12
11
  from strawberry import ID, UNSET
@@ -33,7 +32,7 @@ from phoenix.server.api.types.pagination import (
33
32
  )
34
33
  from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
35
34
  from phoenix.server.api.types.SortDir import SortDir
36
- from phoenix.server.api.types.Span import Span, to_gql_span
35
+ from phoenix.server.api.types.Span import Span
37
36
  from phoenix.server.api.types.Trace import Trace, to_gql_trace
38
37
  from phoenix.server.api.types.ValidationResult import ValidationResult
39
38
  from phoenix.trace.dsl import SpanFilter
@@ -184,10 +183,9 @@ class Project(Node):
184
183
  filter_condition: Optional[str] = UNSET,
185
184
  ) -> Connection[Span]:
186
185
  stmt = (
187
- select(models.Span)
186
+ select(models.Span.id)
188
187
  .join(models.Trace)
189
188
  .where(models.Trace.project_rowid == self.id_attr)
190
- .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
191
189
  )
192
190
  if time_range:
193
191
  if time_range.start:
@@ -232,21 +230,21 @@ class Project(Node):
232
230
  stmt = stmt.order_by(cursor_rowid_column)
233
231
  cursors_and_nodes = []
234
232
  async with info.context.db() as session:
235
- span_records = await session.execute(stmt)
233
+ span_records = await session.stream(stmt)
236
234
  async for span_record in islice(span_records, first):
237
- span = span_record[0]
238
- cursor = Cursor(rowid=span.id)
235
+ span_rowid: int = span_record[0]
236
+ cursor = Cursor(rowid=span_rowid)
239
237
  if sort_config:
240
238
  assert len(span_record) > 1
241
239
  cursor.sort_column = CursorSortColumn(
242
240
  type=sort_config.column_data_type,
243
241
  value=span_record[1],
244
242
  )
245
- cursors_and_nodes.append((cursor, to_gql_span(span)))
243
+ cursors_and_nodes.append((cursor, Span(span_rowid=span_rowid)))
246
244
  has_next_page = True
247
245
  try:
248
- next(span_records)
249
- except StopIteration:
246
+ await span_records.__anext__()
247
+ except StopAsyncIteration:
250
248
  has_next_page = False
251
249
 
252
250
  return connection_from_cursors_and_nodes(
File: phoenix/server/api/types/ProjectSession.py
@@ -57,7 +57,7 @@ class ProjectSession(Node):
57
57
  return None
58
58
  return SpanIOValue(
59
59
  mime_type=MimeType(record.mime_type.value),
60
- value=record.value,
60
+ cached_value=record.value,
61
61
  )
62
62
 
63
63
  @strawberry.field
@@ -70,7 +70,7 @@ class ProjectSession(Node):
70
70
  return None
71
71
  return SpanIOValue(
72
72
  mime_type=MimeType(record.mime_type.value),
73
- value=record.value,
73
+ cached_value=record.value,
74
74
  )
75
75
 
76
76
  @strawberry.field