arize-phoenix 4.4.4rc6__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (123)
  1. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/METADATA +8 -14
  2. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/RECORD +58 -122
  3. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +27 -0
  5. phoenix/config.py +7 -42
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +62 -64
  8. phoenix/core/model_schema_adapter.py +25 -27
  9. phoenix/datetime_utils.py +0 -4
  10. phoenix/db/bulk_inserter.py +14 -54
  11. phoenix/db/insertion/evaluation.py +10 -10
  12. phoenix/db/insertion/helpers.py +14 -17
  13. phoenix/db/insertion/span.py +3 -3
  14. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
  15. phoenix/db/models.py +4 -236
  16. phoenix/inferences/fixtures.py +23 -23
  17. phoenix/inferences/inferences.py +7 -7
  18. phoenix/inferences/validation.py +1 -1
  19. phoenix/server/api/context.py +0 -20
  20. phoenix/server/api/dataloaders/__init__.py +0 -20
  21. phoenix/server/api/dataloaders/span_descendants.py +3 -2
  22. phoenix/server/api/routers/v1/__init__.py +2 -77
  23. phoenix/server/api/routers/v1/evaluations.py +13 -8
  24. phoenix/server/api/routers/v1/spans.py +5 -9
  25. phoenix/server/api/routers/v1/traces.py +4 -1
  26. phoenix/server/api/schema.py +303 -2
  27. phoenix/server/api/types/Cluster.py +19 -19
  28. phoenix/server/api/types/Dataset.py +63 -282
  29. phoenix/server/api/types/DatasetRole.py +23 -0
  30. phoenix/server/api/types/Dimension.py +29 -30
  31. phoenix/server/api/types/EmbeddingDimension.py +34 -40
  32. phoenix/server/api/types/Event.py +16 -16
  33. phoenix/server/api/{mutations/export_events_mutations.py → types/ExportEventsMutation.py} +14 -17
  34. phoenix/server/api/types/Model.py +42 -43
  35. phoenix/server/api/types/Project.py +12 -26
  36. phoenix/server/api/types/Span.py +2 -79
  37. phoenix/server/api/types/TimeSeries.py +6 -6
  38. phoenix/server/api/types/Trace.py +4 -15
  39. phoenix/server/api/types/UMAPPoints.py +1 -1
  40. phoenix/server/api/types/node.py +111 -5
  41. phoenix/server/api/types/pagination.py +52 -10
  42. phoenix/server/app.py +49 -103
  43. phoenix/server/main.py +27 -49
  44. phoenix/server/openapi/docs.py +0 -3
  45. phoenix/server/static/index.js +1384 -2390
  46. phoenix/server/templates/index.html +0 -1
  47. phoenix/services.py +15 -15
  48. phoenix/session/client.py +23 -611
  49. phoenix/session/session.py +37 -47
  50. phoenix/trace/exporter.py +9 -14
  51. phoenix/trace/fixtures.py +7 -133
  52. phoenix/trace/schemas.py +2 -1
  53. phoenix/trace/span_evaluations.py +3 -3
  54. phoenix/trace/trace_dataset.py +6 -6
  55. phoenix/version.py +1 -1
  56. phoenix/db/insertion/dataset.py +0 -237
  57. phoenix/db/migrations/types.py +0 -29
  58. phoenix/db/migrations/versions/10460e46d750_datasets.py +0 -291
  59. phoenix/experiments/__init__.py +0 -6
  60. phoenix/experiments/evaluators/__init__.py +0 -29
  61. phoenix/experiments/evaluators/base.py +0 -153
  62. phoenix/experiments/evaluators/code_evaluators.py +0 -99
  63. phoenix/experiments/evaluators/llm_evaluators.py +0 -244
  64. phoenix/experiments/evaluators/utils.py +0 -189
  65. phoenix/experiments/functions.py +0 -616
  66. phoenix/experiments/tracing.py +0 -85
  67. phoenix/experiments/types.py +0 -722
  68. phoenix/experiments/utils.py +0 -9
  69. phoenix/server/api/dataloaders/average_experiment_run_latency.py +0 -54
  70. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -100
  71. phoenix/server/api/dataloaders/dataset_example_spans.py +0 -43
  72. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +0 -85
  73. phoenix/server/api/dataloaders/experiment_error_rates.py +0 -43
  74. phoenix/server/api/dataloaders/experiment_run_counts.py +0 -42
  75. phoenix/server/api/dataloaders/experiment_sequence_number.py +0 -49
  76. phoenix/server/api/dataloaders/project_by_name.py +0 -31
  77. phoenix/server/api/dataloaders/span_projects.py +0 -33
  78. phoenix/server/api/dataloaders/trace_row_ids.py +0 -39
  79. phoenix/server/api/helpers/dataset_helpers.py +0 -179
  80. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +0 -16
  81. phoenix/server/api/input_types/AddSpansToDatasetInput.py +0 -14
  82. phoenix/server/api/input_types/ClearProjectInput.py +0 -15
  83. phoenix/server/api/input_types/CreateDatasetInput.py +0 -12
  84. phoenix/server/api/input_types/DatasetExampleInput.py +0 -14
  85. phoenix/server/api/input_types/DatasetSort.py +0 -17
  86. phoenix/server/api/input_types/DatasetVersionSort.py +0 -16
  87. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +0 -13
  88. phoenix/server/api/input_types/DeleteDatasetInput.py +0 -7
  89. phoenix/server/api/input_types/DeleteExperimentsInput.py +0 -9
  90. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +0 -35
  91. phoenix/server/api/input_types/PatchDatasetInput.py +0 -14
  92. phoenix/server/api/mutations/__init__.py +0 -13
  93. phoenix/server/api/mutations/auth.py +0 -11
  94. phoenix/server/api/mutations/dataset_mutations.py +0 -520
  95. phoenix/server/api/mutations/experiment_mutations.py +0 -65
  96. phoenix/server/api/mutations/project_mutations.py +0 -47
  97. phoenix/server/api/openapi/__init__.py +0 -0
  98. phoenix/server/api/openapi/main.py +0 -6
  99. phoenix/server/api/openapi/schema.py +0 -16
  100. phoenix/server/api/queries.py +0 -503
  101. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  102. phoenix/server/api/routers/v1/datasets.py +0 -965
  103. phoenix/server/api/routers/v1/experiment_evaluations.py +0 -65
  104. phoenix/server/api/routers/v1/experiment_runs.py +0 -96
  105. phoenix/server/api/routers/v1/experiments.py +0 -174
  106. phoenix/server/api/types/AnnotatorKind.py +0 -10
  107. phoenix/server/api/types/CreateDatasetPayload.py +0 -8
  108. phoenix/server/api/types/DatasetExample.py +0 -85
  109. phoenix/server/api/types/DatasetExampleRevision.py +0 -34
  110. phoenix/server/api/types/DatasetVersion.py +0 -14
  111. phoenix/server/api/types/ExampleRevisionInterface.py +0 -14
  112. phoenix/server/api/types/Experiment.py +0 -147
  113. phoenix/server/api/types/ExperimentAnnotationSummary.py +0 -13
  114. phoenix/server/api/types/ExperimentComparison.py +0 -19
  115. phoenix/server/api/types/ExperimentRun.py +0 -91
  116. phoenix/server/api/types/ExperimentRunAnnotation.py +0 -57
  117. phoenix/server/api/types/Inferences.py +0 -80
  118. phoenix/server/api/types/InferencesRole.py +0 -23
  119. phoenix/utilities/json.py +0 -61
  120. phoenix/utilities/re.py +0 -50
  121. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/IP_NOTICE +0 -0
  122. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/LICENSE +0 -0
  123. /phoenix/server/api/{helpers/__init__.py → helpers.py} +0 -0
@@ -1,9 +0,0 @@
1
- from phoenix.config import get_web_base_url
2
-
3
-
4
def get_experiment_url(*, dataset_id: str, experiment_id: str) -> str:
    """Build the URL of the compare page for one experiment of a dataset.

    The result is the value of ``get_web_base_url()`` followed by
    ``datasets/<dataset_id>/compare`` with the experiment passed as the
    ``experimentId`` query parameter. Both identifiers are interpolated
    verbatim (no escaping is applied).
    """
    base_url = get_web_base_url()
    compare_path = f"datasets/{dataset_id}/compare"
    return f"{base_url}{compare_path}?experimentId={experiment_id}"
6
-
7
-
8
def get_dataset_experiments_url(*, dataset_id: str) -> str:
    """Build the URL of the experiments page of a dataset.

    The result is the value of ``get_web_base_url()`` followed by
    ``datasets/<dataset_id>/experiments``; the identifier is interpolated
    verbatim (no escaping is applied).
    """
    base_url = get_web_base_url()
    return f"{base_url}datasets/{dataset_id}/experiments"
@@ -1,54 +0,0 @@
1
- from typing import (
2
- AsyncContextManager,
3
- Callable,
4
- List,
5
- )
6
-
7
- from sqlalchemy import func, select
8
- from sqlalchemy.ext.asyncio import AsyncSession
9
- from strawberry.dataloader import DataLoader
10
- from typing_extensions import TypeAlias
11
-
12
- from phoenix.db import models
13
-
14
ExperimentID: TypeAlias = int
RunLatency: TypeAlias = float
Key: TypeAlias = ExperimentID
Result: TypeAlias = RunLatency


class AverageExperimentRunLatencyDataLoader(DataLoader[Key, Result]):
    """Batch loader for the average run duration of each requested experiment.

    The duration of a run is ``extract('epoch', end_time) - extract('epoch',
    start_time)``; the loader averages it per ``experiment_id``.
    """

    def __init__(
        self,
        db: Callable[[], AsyncContextManager[AsyncSession]],
    ) -> None:
        """Store the session factory and register ``_load_fn`` as the batch function."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """Return one entry per key, in key order.

        An experiment id for which the query yields no row (i.e. no
        ``ExperimentRun`` rows reference it) is answered with a ``ValueError``
        instance in its slot rather than a number.
        """
        run = models.ExperimentRun
        duration = func.extract("epoch", run.end_time) - func.extract("epoch", run.start_time)
        stmt = (
            select(run.experiment_id, func.avg(duration))
            .where(run.experiment_id.in_(set(keys)))
            .group_by(run.experiment_id)
        )
        latency_by_experiment = {}
        async with self._db() as session:
            async for experiment_id, avg_latency in await session.stream(stmt):
                latency_by_experiment[experiment_id] = avg_latency
        return [
            latency_by_experiment.get(
                experiment_id, ValueError(f"Unknown experiment: {experiment_id}")
            )
            for experiment_id in keys
        ]
@@ -1,100 +0,0 @@
1
- from typing import (
2
- AsyncContextManager,
3
- Callable,
4
- List,
5
- Optional,
6
- Tuple,
7
- Union,
8
- )
9
-
10
- from sqlalchemy import Integer, case, func, literal, or_, select, union
11
- from sqlalchemy.ext.asyncio import AsyncSession
12
- from strawberry.dataloader import DataLoader
13
- from typing_extensions import TypeAlias
14
-
15
- from phoenix.db import models
16
- from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
17
-
18
ExampleID: TypeAlias = int
VersionID: TypeAlias = Optional[int]
Key: TypeAlias = Tuple[ExampleID, Optional[VersionID]]
Result: TypeAlias = DatasetExampleRevision


class DatasetExampleRevisionsDataLoader(DataLoader[Key, Result]):
    """Batch loader resolving ``(example_id, version_id)`` keys to a revision.

    For each key the loader picks the revision with the greatest id among the
    example's revisions whose ``dataset_version_id`` is at most the requested
    version (or among all of them when the version id is ``None``).
    """

    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
        """Store the session factory and register ``_load_fn`` as the batch function."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Union[Result, ValueError]]:
        """Return one entry per key, in key order.

        A key gets ``ValueError("Could not find revision.")`` when no matching
        revision row exists, when the selected revision has kind ``"DELETE"``,
        or when a non-null version id matches no ``DatasetVersion`` row.
        """
        revision = models.DatasetExampleRevision
        version = models.DatasetVersion

        # sqlalchemy has limited SQLite support for VALUES, so the keys are
        # materialized as a UNION of single-row SELECTs instead.
        # For details, see https://github.com/sqlalchemy/sqlalchemy/issues/7228
        key_rows = [
            select(
                literal(example_id, Integer).label("example_id"),
                literal(version_id, Integer).label("version_id"),
            )
            for example_id, version_id in keys
        ]
        requested = union(*key_rows).subquery()

        # Per key: id of the newest revision visible at the requested version.
        latest = (
            select(
                requested.c.example_id,
                requested.c.version_id,
                func.max(revision.id).label("revision_id"),
            )
            .select_from(requested)
            .join(
                revision,
                onclause=requested.c.example_id == revision.dataset_example_id,
            )
            .where(
                or_(
                    requested.c.version_id.is_(None),
                    revision.dataset_version_id <= requested.c.version_id,
                )
            )
            .group_by(requested.c.example_id, requested.c.version_id)
        ).subquery()

        # check that non-null versions exist
        is_valid_version = case(
            (
                or_(
                    latest.c.version_id.is_(None),
                    version.id.is_not(None),
                ),
                True,
            ),
            else_=False,
        ).label("is_valid_version")

        stmt = (
            select(
                latest.c.example_id,
                latest.c.version_id,
                is_valid_version,
                revision,
            )
            .select_from(latest)
            .join(
                revision,
                onclause=latest.c.revision_id == revision.id,
            )
            .join(
                version,
                onclause=latest.c.version_id == version.id,
                isouter=True,  # keep rows where the version id is null
            )
            .where(revision.revision_kind != "DELETE")
        )

        found = {}
        async with self._db() as session:
            async for example_id, version_id, valid, orm_revision in await session.stream(stmt):
                if valid:
                    found[(example_id, version_id)] = DatasetExampleRevision.from_orm_revision(
                        orm_revision
                    )
        return [found.get(key, ValueError("Could not find revision.")) for key in keys]
@@ -1,43 +0,0 @@
1
- from typing import (
2
- AsyncContextManager,
3
- Callable,
4
- List,
5
- Optional,
6
- )
7
-
8
- from sqlalchemy import select
9
- from sqlalchemy.ext.asyncio import AsyncSession
10
- from sqlalchemy.orm import joinedload
11
- from strawberry.dataloader import DataLoader
12
- from typing_extensions import TypeAlias
13
-
14
- from phoenix.db import models
15
-
16
ExampleID: TypeAlias = int
Key: TypeAlias = ExampleID
Result: TypeAlias = Optional[models.Span]


class DatasetExampleSpansDataLoader(DataLoader[Key, Result]):
    """Batch loader returning the span joined to each dataset example.

    Examples are joined to spans on ``DatasetExample.span_rowid == Span.id``.
    """

    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
        """Store the session factory and register ``_load_fn`` as the batch function."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """Return one span (or ``None`` when the join yields no row) per key, in key order."""
        # Eagerly load each span's trace, restricted to the trace_id column.
        trace_loading = joinedload(models.Span.trace, innerjoin=True).load_only(
            models.Trace.trace_id
        )
        stmt = (
            select(models.DatasetExample.id, models.Span)
            .select_from(models.DatasetExample)
            .join(models.Span, models.DatasetExample.span_rowid == models.Span.id)
            .where(models.DatasetExample.id.in_(keys))
            .options(trace_loading)
        )
        span_by_example = {}
        async with self._db() as session:
            async for example_id, span in await session.stream(stmt):
                span_by_example[example_id] = span
        return [span_by_example.get(example_id) for example_id in keys]
@@ -1,85 +0,0 @@
1
- from collections import defaultdict
2
- from dataclasses import dataclass
3
- from typing import (
4
- AsyncContextManager,
5
- Callable,
6
- DefaultDict,
7
- List,
8
- Optional,
9
- )
10
-
11
- from sqlalchemy import func, select
12
- from sqlalchemy.ext.asyncio import AsyncSession
13
- from strawberry.dataloader import AbstractCache, DataLoader
14
- from typing_extensions import TypeAlias
15
-
16
- from phoenix.db import models
17
-
18
-
19
@dataclass
class ExperimentAnnotationSummary:
    """Aggregate of one annotation name across the runs of an experiment."""

    # Name shared by the aggregated annotations.
    annotation_name: str
    # Minimum, maximum and mean of the aggregated annotation scores.
    min_score: float
    max_score: float
    mean_score: float
    # Number of aggregated annotation rows.
    count: int
    # Number of those rows whose error field is set.
    error_count: int
27
-
28
-
29
ExperimentID: TypeAlias = int
Key: TypeAlias = ExperimentID
Result: TypeAlias = List[ExperimentAnnotationSummary]


class ExperimentAnnotationSummaryDataLoader(DataLoader[Key, Result]):
    """Batch loader producing per-annotation-name score summaries per experiment.

    Annotations are joined to their runs and grouped by
    ``(experiment_id, annotation name)``; each group yields the min, max and
    average score, the row count, and the count of rows with a non-NULL error.
    """

    def __init__(
        self,
        db: Callable[[], AsyncContextManager[AsyncSession]],
        cache_map: Optional[AbstractCache[Key, Result]] = None,
    ) -> None:
        """Register ``_load_fn`` as the batch function.

        Args:
            db: Factory returning an async context manager that yields a session.
            cache_map: Optional custom cache. It is forwarded to the base
                ``DataLoader`` (previously it was accepted but ignored);
                ``None`` keeps the base class's default cache.
        """
        super().__init__(load_fn=self._load_fn, cache_map=cache_map)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """Return, per key and in key order, its summaries sorted by annotation name.

        An experiment without any annotation rows gets an empty list.
        """
        run = models.ExperimentRun
        annotation = models.ExperimentRunAnnotation
        stmt = (
            select(
                run.experiment_id,
                annotation.name,
                func.min(annotation.score),
                func.max(annotation.score),
                func.avg(annotation.score),
                func.count(),
                # COUNT(column) skips NULLs, so this counts rows with an error set.
                func.count(annotation.error),
            )
            .join(run, annotation.experiment_run_id == run.id)
            .where(run.experiment_id.in_(keys))
            .group_by(run.experiment_id, annotation.name)
        )
        summaries: DefaultDict[ExperimentID, Result] = defaultdict(list)
        async with self._db() as session:
            async for (
                experiment_id,
                annotation_name,
                min_score,
                max_score,
                mean_score,
                count,
                error_count,
            ) in await session.stream(stmt):
                summaries[experiment_id].append(
                    ExperimentAnnotationSummary(
                        annotation_name=annotation_name,
                        min_score=min_score,
                        max_score=max_score,
                        mean_score=mean_score,
                        count=count,
                        error_count=error_count,
                    )
                )
        return [
            sorted(summaries[experiment_id], key=lambda summary: summary.annotation_name)
            for experiment_id in keys
        ]
@@ -1,43 +0,0 @@
1
- from typing import (
2
- AsyncContextManager,
3
- Callable,
4
- List,
5
- Optional,
6
- )
7
-
8
- from sqlalchemy import func, select
9
- from sqlalchemy.ext.asyncio import AsyncSession
10
- from strawberry.dataloader import DataLoader
11
- from typing_extensions import TypeAlias
12
-
13
- from phoenix.db import models
14
-
15
ExperimentID: TypeAlias = int
ErrorRate: TypeAlias = float
Key: TypeAlias = ExperimentID
Result: TypeAlias = Optional[ErrorRate]


class ExperimentErrorRatesDataLoader(DataLoader[Key, Result]):
    """Batch loader for the fraction of each experiment's runs that have an error."""

    def __init__(
        self,
        db: Callable[[], AsyncContextManager[AsyncSession]],
    ) -> None:
        """Store the session factory and register ``_load_fn`` as the batch function."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """Return one error rate per key, in key order.

        Experiments for which the query yields no row get ``None``.
        """
        run = models.ExperimentRun
        # COUNT(error) skips NULLs; COUNT(*) counts every run of the group.
        error_rate = func.count(run.error) / func.count()
        stmt = (
            select(run.experiment_id, error_rate)
            .group_by(run.experiment_id)
            .where(run.experiment_id.in_(keys))
        )
        rate_by_experiment = {}
        async with self._db() as session:
            async for experiment_id, rate in await session.stream(stmt):
                rate_by_experiment[experiment_id] = rate
        return [rate_by_experiment.get(experiment_id) for experiment_id in keys]
@@ -1,42 +0,0 @@
1
- from typing import (
2
- AsyncContextManager,
3
- Callable,
4
- List,
5
- )
6
-
7
- from sqlalchemy import func, select
8
- from sqlalchemy.ext.asyncio import AsyncSession
9
- from strawberry.dataloader import DataLoader
10
- from typing_extensions import TypeAlias
11
-
12
- from phoenix.db import models
13
-
14
ExperimentID: TypeAlias = int
RunCount: TypeAlias = int
Key: TypeAlias = ExperimentID
Result: TypeAlias = RunCount


class ExperimentRunCountsDataLoader(DataLoader[Key, Result]):
    """Batch loader for the number of runs recorded for each experiment."""

    def __init__(
        self,
        db: Callable[[], AsyncContextManager[AsyncSession]],
    ) -> None:
        """Store the session factory and register ``_load_fn`` as the batch function."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """Return one run count per key, in key order.

        An experiment that exists but has no runs yields ``0``. Only ids that
        match no ``Experiment`` row are answered with a ``ValueError`` instance.
        (Counting ``ExperimentRun`` rows alone could not tell these two cases
        apart and reported run-less experiments as unknown.)
        """
        experiment_ids = keys
        # Resolve the requested ids against the experiments table first so that
        # experiments without runs still produce a row.
        # NOTE(review): assumes ExperimentRun.experiment_id references
        # Experiment.id — confirm against the models.
        known_experiments = (
            select(models.Experiment.id)
            .where(models.Experiment.id.in_(set(experiment_ids)))
            .subquery()
        )
        stmt = (
            select(
                known_experiments.c.id,
                # COUNT(column) skips the NULL produced by the outer join when
                # an experiment has no runs, giving 0 instead of 1.
                func.count(models.ExperimentRun.experiment_id),
            )
            .outerjoin_from(
                known_experiments,
                models.ExperimentRun,
                known_experiments.c.id == models.ExperimentRun.experiment_id,
            )
            .group_by(known_experiments.c.id)
        )
        async with self._db() as session:
            run_counts = {
                experiment_id: run_count
                async for experiment_id, run_count in await session.stream(stmt)
            }
        return [
            run_counts.get(experiment_id, ValueError(f"Unknown experiment: {experiment_id}"))
            for experiment_id in experiment_ids
        ]
@@ -1,49 +0,0 @@
1
- from typing import (
2
- AsyncContextManager,
3
- Callable,
4
- List,
5
- Optional,
6
- )
7
-
8
- from sqlalchemy import distinct, func, select
9
- from sqlalchemy.ext.asyncio import AsyncSession
10
- from strawberry.dataloader import DataLoader
11
- from typing_extensions import TypeAlias
12
-
13
- from phoenix.db import models
14
-
15
ExperimentId: TypeAlias = int
Key: TypeAlias = ExperimentId
Result: TypeAlias = Optional[int]


class ExperimentSequenceNumberDataLoader(DataLoader[Key, Result]):
    """Batch loader for each experiment's row number within its dataset.

    Row numbers are assigned per ``dataset_id`` in order of experiment id.
    """

    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
        """Store the session factory and register ``_load_fn`` as the batch function."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """Return one sequence number per key, in key order (``None`` if no row matched)."""
        experiment = models.Experiment
        # Datasets owning at least one of the requested experiments.
        owning_datasets = (
            select(distinct(experiment.dataset_id))
            .where(experiment.id.in_(keys))
            .scalar_subquery()
        )
        position = (
            func.row_number().over(
                partition_by=experiment.dataset_id,
                order_by=experiment.id,
            )
        ).label("row_number")
        # Number every experiment of those datasets, then keep the requested ones.
        numbered = (
            select(experiment.id, position)
            .where(experiment.dataset_id.in_(owning_datasets))
            .subquery()
        )
        stmt = select(numbered).where(numbered.c.id.in_(keys))
        number_by_experiment = {}
        async with self._db() as session:
            async for experiment_id, sequence_number in await session.stream(stmt):
                number_by_experiment[experiment_id] = sequence_number
        return [number_by_experiment.get(experiment_id) for experiment_id in keys]
@@ -1,31 +0,0 @@
1
- from collections import defaultdict
2
- from typing import AsyncContextManager, Callable, DefaultDict, List, Optional
3
-
4
- from sqlalchemy import select
5
- from sqlalchemy.ext.asyncio import AsyncSession
6
- from strawberry.dataloader import DataLoader
7
- from typing_extensions import TypeAlias
8
-
9
- from phoenix.db import models
10
-
11
ProjectName: TypeAlias = str
Key: TypeAlias = ProjectName
Result: TypeAlias = Optional[models.Project]


class ProjectByNameDataLoader(DataLoader[Key, Result]):
    """Batch loader looking up projects by name."""

    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
        """Store the session factory and register ``_load_fn`` as the batch function."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """Return one project (or ``None`` when no project has that name) per key.

        The result is aligned with ``keys``: same length, same order. The
        previous implementation iterated ``list(set(keys))``, whose order and
        length can differ from ``keys``, and indexed a ``defaultdict(None)``,
        which raises ``KeyError`` for names that matched no project.
        """
        # De-duplicate only for the IN clause; the output still follows ``keys``.
        stmt = select(models.Project).where(models.Project.name.in_(set(keys)))
        async with self._db() as session:
            projects_by_name = {
                project.name: project async for project in await session.stream_scalars(stmt)
            }
        return [projects_by_name.get(project_name) for project_name in keys]
@@ -1,33 +0,0 @@
1
- from typing import AsyncContextManager, Callable, List, Union
2
-
3
- from sqlalchemy import select
4
- from sqlalchemy.ext.asyncio import AsyncSession
5
- from strawberry.dataloader import DataLoader
6
- from typing_extensions import TypeAlias
7
-
8
- from phoenix.db import models
9
-
10
SpanID: TypeAlias = int
Key: TypeAlias = SpanID
Result: TypeAlias = models.Project


class SpanProjectsDataLoader(DataLoader[Key, Result]):
    """Batch loader resolving span row ids to the project of the span's trace.

    Spans are joined to traces on ``Span.trace_rowid == Trace.id`` and traces
    to projects on ``Trace.project_rowid == Project.id``.
    """

    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
        """Store the session factory and register ``_load_fn`` as the batch function."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Union[Result, ValueError]]:
        """Return one project per key, aligned with ``keys``.

        Span ids for which the joins yield no row get a
        ``ValueError("Invalid span ID")`` instance in their slot. The previous
        implementation built the result over ``list(set(keys))``, whose order
        and length can differ from ``keys``.
        """
        stmt = (
            select(models.Span.id, models.Project)
            .select_from(models.Span)
            .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
            .join(models.Project, models.Trace.project_rowid == models.Project.id)
            # De-duplicate only for the IN clause; the output still follows ``keys``.
            .where(models.Span.id.in_(set(keys)))
        )
        async with self._db() as session:
            projects = {
                span_id: project async for span_id, project in await session.stream(stmt)
            }
        return [projects.get(span_id) or ValueError("Invalid span ID") for span_id in keys]
@@ -1,39 +0,0 @@
1
- from typing import (
2
- AsyncContextManager,
3
- Callable,
4
- List,
5
- Optional,
6
- Tuple,
7
- )
8
-
9
- from sqlalchemy import select
10
- from sqlalchemy.ext.asyncio import AsyncSession
11
- from strawberry.dataloader import DataLoader
12
- from typing_extensions import TypeAlias
13
-
14
- from phoenix.db import models
15
-
16
TraceId: TypeAlias = str
Key: TypeAlias = TraceId
TraceRowId: TypeAlias = int
ProjectRowId: TypeAlias = int
Result: TypeAlias = Optional[Tuple[TraceRowId, ProjectRowId]]


class TraceRowIdsDataLoader(DataLoader[Key, Result]):
    """Batch loader mapping a ``trace_id`` to ``(Trace.id, Trace.project_rowid)``."""

    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
        """Store the session factory and register ``_load_fn`` as the batch function."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """Return one ``(trace row id, project row id)`` pair per key, in key order.

        Trace ids that match no row get ``None``.
        """
        trace = models.Trace
        stmt = select(
            trace.trace_id,
            trace.id,
            trace.project_rowid,
        ).where(trace.trace_id.in_(keys))
        row_ids = {}
        async with self._db() as session:
            async for trace_id, trace_rowid, project_rowid in await session.stream(stmt):
                row_ids[trace_id] = (trace_rowid, project_rowid)
        return [row_ids.get(trace_id) for trace_id in keys]