arize-phoenix 4.12.1rc1__py3-none-any.whl → 4.15.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries and is provided for informational purposes only.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (73)
  1. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/METADATA +10 -6
  2. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/RECORD +70 -68
  3. phoenix/db/bulk_inserter.py +5 -4
  4. phoenix/db/engines.py +2 -1
  5. phoenix/experiments/evaluators/base.py +4 -0
  6. phoenix/experiments/evaluators/code_evaluators.py +80 -0
  7. phoenix/experiments/evaluators/llm_evaluators.py +77 -1
  8. phoenix/experiments/evaluators/utils.py +70 -21
  9. phoenix/experiments/functions.py +17 -16
  10. phoenix/server/api/context.py +5 -3
  11. phoenix/server/api/dataloaders/__init__.py +2 -0
  12. phoenix/server/api/dataloaders/average_experiment_run_latency.py +25 -25
  13. phoenix/server/api/dataloaders/dataset_example_revisions.py +2 -4
  14. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -4
  15. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -4
  16. phoenix/server/api/dataloaders/document_evaluations.py +2 -4
  17. phoenix/server/api/dataloaders/document_retrieval_metrics.py +2 -4
  18. phoenix/server/api/dataloaders/evaluation_summaries.py +2 -4
  19. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +2 -4
  20. phoenix/server/api/dataloaders/experiment_error_rates.py +32 -14
  21. phoenix/server/api/dataloaders/experiment_run_counts.py +20 -9
  22. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -4
  23. phoenix/server/api/dataloaders/latency_ms_quantile.py +2 -3
  24. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +2 -4
  25. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  26. phoenix/server/api/dataloaders/record_counts.py +2 -4
  27. phoenix/server/api/dataloaders/span_annotations.py +2 -4
  28. phoenix/server/api/dataloaders/span_dataset_examples.py +36 -0
  29. phoenix/server/api/dataloaders/span_descendants.py +2 -4
  30. phoenix/server/api/dataloaders/span_evaluations.py +2 -4
  31. phoenix/server/api/dataloaders/span_projects.py +3 -3
  32. phoenix/server/api/dataloaders/token_counts.py +2 -4
  33. phoenix/server/api/dataloaders/trace_evaluations.py +2 -4
  34. phoenix/server/api/dataloaders/trace_row_ids.py +2 -4
  35. phoenix/server/api/input_types/{CreateSpanAnnotationsInput.py → CreateSpanAnnotationInput.py} +4 -2
  36. phoenix/server/api/input_types/{CreateTraceAnnotationsInput.py → CreateTraceAnnotationInput.py} +4 -2
  37. phoenix/server/api/input_types/{PatchAnnotationsInput.py → PatchAnnotationInput.py} +4 -2
  38. phoenix/server/api/mutations/span_annotations_mutations.py +20 -9
  39. phoenix/server/api/mutations/trace_annotations_mutations.py +20 -9
  40. phoenix/server/api/routers/v1/datasets.py +132 -10
  41. phoenix/server/api/routers/v1/evaluations.py +3 -5
  42. phoenix/server/api/routers/v1/experiments.py +1 -1
  43. phoenix/server/api/types/Experiment.py +2 -2
  44. phoenix/server/api/types/Inferences.py +1 -2
  45. phoenix/server/api/types/Model.py +1 -2
  46. phoenix/server/api/types/Span.py +5 -0
  47. phoenix/server/api/utils.py +4 -4
  48. phoenix/server/app.py +21 -18
  49. phoenix/server/grpc_server.py +2 -2
  50. phoenix/server/main.py +5 -9
  51. phoenix/server/static/.vite/manifest.json +31 -31
  52. phoenix/server/static/assets/{components-C8sm_r1F.js → components-kGgeFkHp.js} +150 -110
  53. phoenix/server/static/assets/index-BctFO6S7.js +100 -0
  54. phoenix/server/static/assets/{pages-bN7juCjh.js → pages-DabDCmVd.js} +432 -255
  55. phoenix/server/static/assets/{vendor-CUDAPm8e.js → vendor-CP0b0YG0.js} +2 -2
  56. phoenix/server/static/assets/{vendor-arizeai-Do2HOmcL.js → vendor-arizeai-B5Hti8OB.js} +27 -27
  57. phoenix/server/static/assets/vendor-codemirror-DtdPDzrv.js +15 -0
  58. phoenix/server/static/assets/{vendor-recharts-PKRvByVe.js → vendor-recharts-A0DA1O99.js} +1 -1
  59. phoenix/server/types.py +18 -0
  60. phoenix/session/client.py +9 -6
  61. phoenix/session/session.py +2 -2
  62. phoenix/trace/dsl/filter.py +40 -25
  63. phoenix/trace/fixtures.py +17 -23
  64. phoenix/trace/utils.py +23 -0
  65. phoenix/utilities/client.py +116 -0
  66. phoenix/utilities/project.py +1 -1
  67. phoenix/version.py +1 -1
  68. phoenix/server/api/routers/v1/dataset_examples.py +0 -157
  69. phoenix/server/static/assets/index-BEKPzgQs.js +0 -100
  70. phoenix/server/static/assets/vendor-codemirror-CrdxOlMs.js +0 -12
  71. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/WHEEL +0 -0
  72. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/licenses/IP_NOTICE +0 -0
  73. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/licenses/LICENSE +0 -0

phoenix/server/api/dataloaders/experiment_error_rates.py
@@ -1,16 +1,14 @@
 from typing import (
-    AsyncContextManager,
-    Callable,
     List,
     Optional,
 )
 
-from sqlalchemy import func, select
-from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import case, func, select
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
 from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
 
 ExperimentID: TypeAlias = int
 ErrorRate: TypeAlias = float
@@ -21,23 +19,43 @@ Result: TypeAlias = Optional[ErrorRate]
 class ExperimentErrorRatesDataLoader(DataLoader[Key, Result]):
     def __init__(
         self,
-        db: Callable[[], AsyncContextManager[AsyncSession]],
+        db: DbSessionFactory,
     ) -> None:
         super().__init__(load_fn=self._load_fn)
         self._db = db
 
     async def _load_fn(self, keys: List[Key]) -> List[Result]:
         experiment_ids = keys
+        resolved_experiment_ids = (
+            select(models.Experiment.id)
+            .where(models.Experiment.id.in_(set(experiment_ids)))
+            .subquery()
+        )
+        query = (
+            select(
+                resolved_experiment_ids.c.id,
+                case(
+                    (
+                        func.count(models.ExperimentRun.id) != 0,
+                        func.count(models.ExperimentRun.error)
+                        / func.count(models.ExperimentRun.id),
+                    ),
+                    else_=None,
+                ),
+            )
+            .outerjoin_from(
+                from_=resolved_experiment_ids,
+                target=models.ExperimentRun,
+                onclause=resolved_experiment_ids.c.id == models.ExperimentRun.experiment_id,
+            )
+            .group_by(resolved_experiment_ids.c.id)
+        )
         async with self._db() as session:
             error_rates = {
                 experiment_id: error_rate
-                async for experiment_id, error_rate in await session.stream(
-                    select(
-                        models.ExperimentRun.experiment_id,
-                        func.count(models.ExperimentRun.error) / func.count(),
-                    )
-                    .group_by(models.ExperimentRun.experiment_id)
-                    .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
-                )
+                async for experiment_id, error_rate in await session.stream(query)
             }
-        return [error_rates.get(experiment_id) for experiment_id in experiment_ids]
+        return [
+            error_rates.get(experiment_id, ValueError(f"Unknown experiment ID: {experiment_id}"))
+            for experiment_id in experiment_ids
+        ]
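
Every dataloader touched in this release swaps the old constructor parameter type Callable[[], AsyncContextManager[AsyncSession]] for the new DbSessionFactory alias imported from phoenix.server.types (phoenix/server/types.py gains 18 lines in this release, but its contents are not shown in this diff). A minimal sketch of what such a factory type could look like, assuming it is simply a callable protocol that yields an async session context manager:

    # Hypothetical sketch only: the real DbSessionFactory is defined in phoenix/server/types.py
    # and may carry additional attributes; this models just the call shape the dataloaders use.
    from typing import AsyncContextManager, Protocol

    from sqlalchemy.ext.asyncio import AsyncSession


    class DbSessionFactory(Protocol):
        def __call__(self) -> AsyncContextManager[AsyncSession]: ...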

phoenix/server/api/dataloaders/experiment_run_counts.py
@@ -1,15 +1,13 @@
 from typing import (
-    AsyncContextManager,
-    Callable,
     List,
 )
 
 from sqlalchemy import func, select
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
 from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
 
 ExperimentID: TypeAlias = int
 RunCount: TypeAlias = int
@@ -20,21 +18,34 @@ Result: TypeAlias = RunCount
 class ExperimentRunCountsDataLoader(DataLoader[Key, Result]):
     def __init__(
         self,
-        db: Callable[[], AsyncContextManager[AsyncSession]],
+        db: DbSessionFactory,
     ) -> None:
         super().__init__(load_fn=self._load_fn)
         self._db = db
 
     async def _load_fn(self, keys: List[Key]) -> List[Result]:
         experiment_ids = keys
+        resolved_experiment_ids = (
+            select(models.Experiment.id)
+            .where(models.Experiment.id.in_(set(experiment_ids)))
+            .subquery()
+        )
+        query = (
+            select(
+                resolved_experiment_ids.c.id,
+                func.count(models.ExperimentRun.experiment_id),
+            )
+            .outerjoin_from(
+                from_=resolved_experiment_ids,
+                target=models.ExperimentRun,
+                onclause=resolved_experiment_ids.c.id == models.ExperimentRun.experiment_id,
+            )
+            .group_by(resolved_experiment_ids.c.id)
+        )
         async with self._db() as session:
             run_counts = {
                 experiment_id: run_count
-                async for experiment_id, run_count in await session.stream(
-                    select(models.ExperimentRun.experiment_id, func.count())
-                    .where(models.ExperimentRun.experiment_id.in_(set(experiment_ids)))
-                    .group_by(models.ExperimentRun.experiment_id)
-                )
+                async for experiment_id, run_count in await session.stream(query)
             }
         return [
             run_counts.get(experiment_id, ValueError(f"Unknown experiment: {experiment_id}"))
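
Both experiment dataloaders above now fall back to a ValueError instance (instead of None) for any requested experiment ID that the query did not return. strawberry's DataLoader treats an exception object in the batch result as a per-key failure and raises it when that particular key is awaited, so one unknown ID no longer silently degrades the rest of the batch. A standalone sketch of that pattern with a stub load function (the data is made up):

    # Illustrative stub only, not the database-backed loader from the diff.
    import asyncio
    from typing import List, Union

    from strawberry.dataloader import DataLoader

    KNOWN_RUN_COUNTS = {1: 3, 2: 7}  # hypothetical experiment_id -> run count


    async def load_run_counts(keys: List[int]) -> List[Union[int, ValueError]]:
        # Returning (not raising) an exception marks only that key as failed.
        return [
            KNOWN_RUN_COUNTS.get(key, ValueError(f"Unknown experiment: {key}"))
            for key in keys
        ]


    async def main() -> None:
        loader = DataLoader(load_fn=load_run_counts)
        print(await loader.load(1))  # 3
        try:
            await loader.load(99)
        except ValueError as error:
            print(error)  # Unknown experiment: 99


    asyncio.run(main())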

phoenix/server/api/dataloaders/experiment_sequence_number.py
@@ -1,16 +1,14 @@
 from typing import (
-    AsyncContextManager,
-    Callable,
     List,
     Optional,
 )
 
 from sqlalchemy import distinct, func, select
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
 from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
 
 ExperimentId: TypeAlias = int
 Key: TypeAlias = ExperimentId
@@ -18,7 +16,7 @@ Result: TypeAlias = Optional[int]
 
 
 class ExperimentSequenceNumberDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+    def __init__(self, db: DbSessionFactory) -> None:
         super().__init__(load_fn=self._load_fn)
         self._db = db
 

phoenix/server/api/dataloaders/latency_ms_quantile.py
@@ -2,9 +2,7 @@ from collections import defaultdict
 from datetime import datetime
 from typing import (
     Any,
-    AsyncContextManager,
     AsyncIterator,
-    Callable,
     DefaultDict,
     List,
     Literal,
@@ -36,6 +34,7 @@ from phoenix.db import models
 from phoenix.db.helpers import SupportedSQLDialect
 from phoenix.server.api.dataloaders.cache import TwoTierCache
 from phoenix.server.api.input_types.TimeRange import TimeRange
+from phoenix.server.types import DbSessionFactory
 from phoenix.trace.dsl import SpanFilter
 
 Kind: TypeAlias = Literal["span", "trace"]
@@ -88,7 +87,7 @@ class LatencyMsQuantileCache(
 class LatencyMsQuantileDataLoader(DataLoader[Key, Result]):
     def __init__(
         self,
-        db: Callable[[], AsyncContextManager[AsyncSession]],
+        db: DbSessionFactory,
         cache_map: Optional[AbstractCache[Key, Result]] = None,
     ) -> None:
         super().__init__(

phoenix/server/api/dataloaders/min_start_or_max_end_times.py
@@ -1,8 +1,6 @@
 from collections import defaultdict
 from datetime import datetime
 from typing import (
-    AsyncContextManager,
-    Callable,
     DefaultDict,
     List,
     Literal,
@@ -12,12 +10,12 @@ from typing import (
 
 from cachetools import LFUCache
 from sqlalchemy import func, select
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.dataloader import AbstractCache, DataLoader
 from typing_extensions import TypeAlias, assert_never
 
 from phoenix.db import models
 from phoenix.server.api.dataloaders.cache import TwoTierCache
+from phoenix.server.types import DbSessionFactory
 
 Kind: TypeAlias = Literal["start", "end"]
 ProjectRowId: TypeAlias = int
@@ -50,7 +48,7 @@ class MinStartOrMaxEndTimeCache(
 class MinStartOrMaxEndTimeDataLoader(DataLoader[Key, Result]):
     def __init__(
         self,
-        db: Callable[[], AsyncContextManager[AsyncSession]],
+        db: DbSessionFactory,
         cache_map: Optional[AbstractCache[Key, Result]] = None,
     ) -> None:
         super().__init__(

phoenix/server/api/dataloaders/project_by_name.py
@@ -1,12 +1,12 @@
 from collections import defaultdict
-from typing import AsyncContextManager, Callable, DefaultDict, List, Optional
+from typing import DefaultDict, List, Optional
 
 from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
 from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
 
 ProjectName: TypeAlias = str
 Key: TypeAlias = ProjectName
@@ -14,7 +14,7 @@ Result: TypeAlias = Optional[models.Project]
 
 
 class ProjectByNameDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+    def __init__(self, db: DbSessionFactory) -> None:
         super().__init__(load_fn=self._load_fn)
         self._db = db
 

phoenix/server/api/dataloaders/record_counts.py
@@ -2,8 +2,6 @@ from collections import defaultdict
 from datetime import datetime
 from typing import (
     Any,
-    AsyncContextManager,
-    Callable,
     DefaultDict,
     List,
     Literal,
@@ -13,13 +11,13 @@ from typing import (
 
 from cachetools import LFUCache, TTLCache
 from sqlalchemy import Select, func, select
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.dataloader import AbstractCache, DataLoader
 from typing_extensions import TypeAlias, assert_never
 
 from phoenix.db import models
 from phoenix.server.api.dataloaders.cache import TwoTierCache
 from phoenix.server.api.input_types.TimeRange import TimeRange
+from phoenix.server.types import DbSessionFactory
 from phoenix.trace.dsl import SpanFilter
 
 Kind: TypeAlias = Literal["span", "trace"]
@@ -69,7 +67,7 @@ class RecordCountCache(
 class RecordCountDataLoader(DataLoader[Key, Result]):
     def __init__(
         self,
-        db: Callable[[], AsyncContextManager[AsyncSession]],
+        db: DbSessionFactory,
         cache_map: Optional[AbstractCache[Key, Result]] = None,
     ) -> None:
         super().__init__(

phoenix/server/api/dataloaders/span_annotations.py
@@ -1,25 +1,23 @@
 from collections import defaultdict
 from typing import (
-    AsyncContextManager,
-    Callable,
     DefaultDict,
     List,
 )
 
 from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
 from phoenix.db import models
 from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
+from phoenix.server.types import DbSessionFactory
 
 Key: TypeAlias = int
 Result: TypeAlias = List[SpanAnnotation]
 
 
 class SpanAnnotationsDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+    def __init__(self, db: DbSessionFactory) -> None:
         super().__init__(load_fn=self._load_fn)
         self._db = db
 

phoenix/server/api/dataloaders/span_dataset_examples.py (new file)
@@ -0,0 +1,36 @@
+from typing import (
+    Dict,
+    List,
+)
+
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+SpanID: TypeAlias = int
+Key: TypeAlias = SpanID
+Result: TypeAlias = List[models.DatasetExample]
+
+
+class SpanDatasetExamplesDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        span_rowids = keys
+        async with self._db() as session:
+            dataset_examples: Dict[Key, List[models.DatasetExample]] = {
+                span_rowid: [] for span_rowid in span_rowids
+            }
+            async for span_rowid, dataset_example in await session.stream(
+                select(models.Span.id, models.DatasetExample)
+                .select_from(models.Span)
+                .join(models.DatasetExample, models.DatasetExample.span_rowid == models.Span.id)
+                .where(models.Span.id.in_(span_rowids))
+            ):
+                dataset_examples[span_rowid].append(dataset_example)
+        return [dataset_examples.get(span_rowid, []) for span_rowid in span_rowids]
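
span_dataset_examples.py is a new module; the loader presumably backs a new field on the GraphQL Span type (Span.py gains five lines in this release) that lists the dataset examples created from a span. It follows the same batching strategy as the other loaders: one JOIN query per batch of span rowids, with results pre-seeded to empty lists so rowids without linked examples resolve to []. A throwaway sketch of that query shape against stand-in tables (spans and dataset_examples here are not the real ORM models):

    # Standalone sketch with an in-memory SQLite database and stand-in tables.
    from sqlalchemy import (
        Column,
        ForeignKey,
        Integer,
        MetaData,
        String,
        Table,
        create_engine,
        select,
    )

    metadata = MetaData()
    spans = Table("spans", metadata, Column("id", Integer, primary_key=True))
    dataset_examples = Table(
        "dataset_examples",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("span_rowid", ForeignKey("spans.id")),
        Column("name", String),
    )

    engine = create_engine("sqlite://")
    metadata.create_all(engine)

    with engine.begin() as conn:
        conn.execute(spans.insert(), [{"id": 1}, {"id": 2}, {"id": 3}])
        conn.execute(
            dataset_examples.insert(),
            [{"id": 10, "span_rowid": 1, "name": "a"}, {"id": 11, "span_rowid": 1, "name": "b"}],
        )
        keys = [1, 2, 3]  # the batch of span rowids collected by the DataLoader
        results = {key: [] for key in keys}  # pre-seed so missing keys resolve to []
        query = (
            select(spans.c.id, dataset_examples.c.name)
            .select_from(spans)
            .join(dataset_examples, dataset_examples.c.span_rowid == spans.c.id)
            .where(spans.c.id.in_(keys))
        )
        for span_rowid, example_name in conn.execute(query):
            results[span_rowid].append(example_name)
        print([results[key] for key in keys])  # [['a', 'b'], [], []]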

phoenix/server/api/dataloaders/span_descendants.py
@@ -1,19 +1,17 @@
 from random import randint
 from typing import (
-    AsyncContextManager,
-    Callable,
     Dict,
     List,
 )
 
 from aioitertools.itertools import groupby
 from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import joinedload
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
 from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
 
 SpanId: TypeAlias = str
 
@@ -22,7 +20,7 @@ Result: TypeAlias = List[models.Span]
 
 
 class SpanDescendantsDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+    def __init__(self, db: DbSessionFactory) -> None:
         super().__init__(load_fn=self._load_fn)
         self._db = db
 

phoenix/server/api/dataloaders/span_evaluations.py
@@ -1,25 +1,23 @@
 from collections import defaultdict
 from typing import (
-    AsyncContextManager,
-    Callable,
     DefaultDict,
     List,
 )
 
 from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
 from phoenix.db import models
 from phoenix.server.api.types.Evaluation import SpanEvaluation
+from phoenix.server.types import DbSessionFactory
 
 Key: TypeAlias = int
 Result: TypeAlias = List[SpanEvaluation]
 
 
 class SpanEvaluationsDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db
 

phoenix/server/api/dataloaders/span_projects.py
@@ -1,11 +1,11 @@
-from typing import AsyncContextManager, Callable, List, Union
+from typing import List, Union
 
 from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
 from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
 
 SpanID: TypeAlias = int
 Key: TypeAlias = SpanID
@@ -13,7 +13,7 @@ Result: TypeAlias = models.Project
 
 
 class SpanProjectsDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+    def __init__(self, db: DbSessionFactory) -> None:
         super().__init__(load_fn=self._load_fn)
         self._db = db
 

phoenix/server/api/dataloaders/token_counts.py
@@ -2,8 +2,6 @@ from collections import defaultdict
 from datetime import datetime
 from typing import (
     Any,
-    AsyncContextManager,
-    Callable,
     DefaultDict,
     List,
     Literal,
@@ -14,7 +12,6 @@ from typing import (
 from cachetools import LFUCache, TTLCache
 from openinference.semconv.trace import SpanAttributes
 from sqlalchemy import Select, func, select
-from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.sql.functions import coalesce
 from strawberry.dataloader import AbstractCache, DataLoader
 from typing_extensions import TypeAlias
@@ -22,6 +19,7 @@ from typing_extensions import TypeAlias
 from phoenix.db import models
 from phoenix.server.api.dataloaders.cache import TwoTierCache
 from phoenix.server.api.input_types.TimeRange import TimeRange
+from phoenix.server.types import DbSessionFactory
 from phoenix.trace.dsl import SpanFilter
 
 Kind: TypeAlias = Literal["prompt", "completion", "total"]
@@ -71,7 +69,7 @@ class TokenCountCache(
 class TokenCountDataLoader(DataLoader[Key, Result]):
     def __init__(
         self,
-        db: Callable[[], AsyncContextManager[AsyncSession]],
+        db: DbSessionFactory,
         cache_map: Optional[AbstractCache[Key, Result]] = None,
     ) -> None:
         super().__init__(

phoenix/server/api/dataloaders/trace_evaluations.py
@@ -1,25 +1,23 @@
 from collections import defaultdict
 from typing import (
-    AsyncContextManager,
-    Callable,
     DefaultDict,
     List,
 )
 
 from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
 from phoenix.db import models
 from phoenix.server.api.types.Evaluation import TraceEvaluation
+from phoenix.server.types import DbSessionFactory
 
 Key: TypeAlias = int
 Result: TypeAlias = List[TraceEvaluation]
 
 
 class TraceEvaluationsDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+    def __init__(self, db: DbSessionFactory) -> None:
         super().__init__(load_fn=self._load_fn)
         self._db = db
 

phoenix/server/api/dataloaders/trace_row_ids.py
@@ -1,17 +1,15 @@
 from typing import (
-    AsyncContextManager,
-    Callable,
     List,
     Optional,
     Tuple,
 )
 
 from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
 from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
 
 TraceId: TypeAlias = str
 Key: TypeAlias = TraceId
@@ -21,7 +19,7 @@ Result: TypeAlias = Optional[Tuple[TraceRowId, ProjectRowId]]
 
 
 class TraceRowIdsDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+    def __init__(self, db: DbSessionFactory) -> None:
         super().__init__(load_fn=self._load_fn)
         self._db = db
 

phoenix/server/api/input_types/{CreateSpanAnnotationsInput.py → CreateSpanAnnotationInput.py}
@@ -4,12 +4,14 @@ import strawberry
 from strawberry.relay import GlobalID
 from strawberry.scalars import JSON
 
+from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
+
 
 @strawberry.input
-class CreateSpanAnnotationsInput:
+class CreateSpanAnnotationInput:
     span_id: GlobalID
     name: str
-    annotator_kind: str
+    annotator_kind: AnnotatorKind
     label: Optional[str] = None
     score: Optional[float] = None
     explanation: Optional[str] = None

phoenix/server/api/input_types/{CreateTraceAnnotationsInput.py → CreateTraceAnnotationInput.py}
@@ -4,12 +4,14 @@ import strawberry
 from strawberry.relay import GlobalID
 from strawberry.scalars import JSON
 
+from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
+
 
 @strawberry.input
-class CreateTraceAnnotationsInput:
+class CreateTraceAnnotationInput:
     trace_id: GlobalID
     name: str
-    annotator_kind: str
+    annotator_kind: AnnotatorKind
     label: Optional[str] = None
     score: Optional[float] = None
     explanation: Optional[str] = None

phoenix/server/api/input_types/{PatchAnnotationsInput.py → PatchAnnotationInput.py}
@@ -5,12 +5,14 @@ from strawberry import UNSET
 from strawberry.relay import GlobalID
 from strawberry.scalars import JSON
 
+from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
+
 
 @strawberry.input
-class PatchAnnotationsInput:
+class PatchAnnotationInput:
     annotation_id: GlobalID
     name: Optional[str] = UNSET
-    annotator_kind: Optional[str] = UNSET
+    annotator_kind: Optional[AnnotatorKind] = UNSET
     label: Optional[str] = UNSET
     score: Optional[float] = UNSET
     explanation: Optional[str] = UNSET
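
All three annotation inputs now type annotator_kind with the AnnotatorKind enum imported from phoenix.server.api.types.AnnotatorKind instead of accepting a bare string, so the GraphQL schema rejects arbitrary values at validation time. The enum's definition is not part of this diff; a plausible sketch (the member names are an assumption) that would explain why the mutations that follow unwrap .value before writing to the database:

    # Assumed shape only: the real enum lives in phoenix/server/api/types/AnnotatorKind.py
    # and its members are not shown in this diff.
    from enum import Enum

    import strawberry


    @strawberry.enum
    class AnnotatorKind(Enum):
        LLM = "LLM"
        HUMAN = "HUMAN"

Because the input now carries an enum member rather than a string, the mutation code passes annotation.annotator_kind.value so the models layer keeps receiving the same plain string it stored before.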

phoenix/server/api/mutations/span_annotations_mutations.py
@@ -7,10 +7,11 @@ from strawberry.types import Info
 
 from phoenix.db import models
 from phoenix.server.api.context import Context
-from phoenix.server.api.input_types.CreateSpanAnnotationsInput import CreateSpanAnnotationsInput
+from phoenix.server.api.input_types.CreateSpanAnnotationInput import CreateSpanAnnotationInput
 from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
-from phoenix.server.api.input_types.PatchAnnotationsInput import PatchAnnotationsInput
+from phoenix.server.api.input_types.PatchAnnotationInput import PatchAnnotationInput
 from phoenix.server.api.mutations.auth import IsAuthenticated
+from phoenix.server.api.queries import Query
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
 
@@ -18,13 +19,14 @@ from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_
 @strawberry.type
 class SpanAnnotationMutationPayload:
     span_annotations: List[SpanAnnotation]
+    query: Query
 
 
 @strawberry.type
 class SpanAnnotationMutationMixin:
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
     async def create_span_annotations(
-        self, info: Info[Context, None], input: List[CreateSpanAnnotationsInput]
+        self, info: Info[Context, None], input: List[CreateSpanAnnotationInput]
     ) -> SpanAnnotationMutationPayload:
         inserted_annotations: Sequence[models.SpanAnnotation] = []
         async with info.context.db() as session:
@@ -35,7 +37,7 @@ class SpanAnnotationMutationMixin:
                     label=annotation.label,
                     score=annotation.score,
                     explanation=annotation.explanation,
-                    annotator_kind=annotation.annotator_kind,
+                    annotator_kind=annotation.annotator_kind.value,
                     metadata_=annotation.metadata,
                 )
                 for annotation in input
@@ -49,12 +51,13 @@ class SpanAnnotationMutationMixin:
         return SpanAnnotationMutationPayload(
             span_annotations=[
                 to_gql_span_annotation(annotation) for annotation in inserted_annotations
-            ]
+            ],
+            query=Query(),
         )
 
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
     async def patch_span_annotations(
-        self, info: Info[Context, None], input: List[PatchAnnotationsInput]
+        self, info: Info[Context, None], input: List[PatchAnnotationInput]
     ) -> SpanAnnotationMutationPayload:
         patched_annotations = []
         async with info.context.db() as session:
@@ -66,7 +69,13 @@ class SpanAnnotationMutationMixin:
                     column.key: patch_value
                     for column, patch_value, column_is_nullable in (
                         (models.SpanAnnotation.name, annotation.name, False),
-                        (models.SpanAnnotation.annotator_kind, annotation.annotator_kind, False),
+                        (
+                            models.SpanAnnotation.annotator_kind,
+                            annotation.annotator_kind.value
+                            if annotation.annotator_kind is not None
+                            else None,
+                            False,
+                        ),
                         (models.SpanAnnotation.label, annotation.label, True),
                         (models.SpanAnnotation.score, annotation.score, True),
                         (models.SpanAnnotation.explanation, annotation.explanation, True),
@@ -83,7 +92,7 @@ class SpanAnnotationMutationMixin:
                 if span_annotation is not None:
                     patched_annotations.append(to_gql_span_annotation(span_annotation))
 
-        return SpanAnnotationMutationPayload(span_annotations=patched_annotations)
+        return SpanAnnotationMutationPayload(span_annotations=patched_annotations, query=Query())
 
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
     async def delete_span_annotations(
@@ -105,4 +114,6 @@ class SpanAnnotationMutationMixin:
         deleted_annotations_gql = [
             to_gql_span_annotation(annotation) for annotation in deleted_annotations
         ]
-        return SpanAnnotationMutationPayload(span_annotations=deleted_annotations_gql)
+        return SpanAnnotationMutationPayload(
+            span_annotations=deleted_annotations_gql, query=Query()
+        )
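
Every span annotation mutation payload now carries a query: Query field populated with Query(), a Relay-style convention that lets the client refetch any root field in the same round trip as the mutation (for example, to refresh a span's annotation list right after creating annotations). A minimal self-contained strawberry sketch of the pattern; the types here are placeholders, not Phoenix's schema:

    # Placeholder types to illustrate the payload-with-query pattern; not Phoenix's schema.
    from typing import List

    import strawberry


    @strawberry.type
    class Query:
        @strawberry.field
        def annotation_names(self) -> List[str]:
            return ["relevance", "toxicity"]  # stand-in data


    @strawberry.type
    class AnnotationMutationPayload:
        created: int
        query: Query  # lets the client re-read any root field after the mutation


    @strawberry.type
    class Mutation:
        @strawberry.mutation
        def create_annotations(self, names: List[str]) -> AnnotationMutationPayload:
            return AnnotationMutationPayload(created=len(names), query=Query())


    schema = strawberry.Schema(query=Query, mutation=Mutation)
    result = schema.execute_sync(
        'mutation { createAnnotations(names: ["relevance"]) { created query { annotationNames } } }'
    )
    print(result.data)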