arize-phoenix 4.4.3__py3-none-any.whl → 4.4.4rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/METADATA +4 -4
  2. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/RECORD +108 -55
  3. phoenix/__init__.py +0 -27
  4. phoenix/config.py +21 -7
  5. phoenix/core/model.py +25 -25
  6. phoenix/core/model_schema.py +64 -62
  7. phoenix/core/model_schema_adapter.py +27 -25
  8. phoenix/datasets/__init__.py +0 -0
  9. phoenix/datasets/evaluators.py +275 -0
  10. phoenix/datasets/experiments.py +469 -0
  11. phoenix/datasets/tracing.py +66 -0
  12. phoenix/datasets/types.py +212 -0
  13. phoenix/db/bulk_inserter.py +54 -14
  14. phoenix/db/insertion/dataset.py +234 -0
  15. phoenix/db/insertion/evaluation.py +6 -6
  16. phoenix/db/insertion/helpers.py +13 -2
  17. phoenix/db/migrations/types.py +29 -0
  18. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  19. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  20. phoenix/db/models.py +230 -3
  21. phoenix/inferences/fixtures.py +23 -23
  22. phoenix/inferences/inferences.py +7 -7
  23. phoenix/inferences/validation.py +1 -1
  24. phoenix/server/api/context.py +16 -0
  25. phoenix/server/api/dataloaders/__init__.py +16 -0
  26. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  27. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  28. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  29. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  30. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  31. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  32. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  33. phoenix/server/api/dataloaders/span_projects.py +33 -0
  34. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  35. phoenix/server/api/helpers/dataset_helpers.py +178 -0
  36. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  37. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  38. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  39. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  40. phoenix/server/api/input_types/DatasetSort.py +17 -0
  41. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  42. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  43. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  44. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  45. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  46. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  47. phoenix/server/api/mutations/__init__.py +13 -0
  48. phoenix/server/api/mutations/auth.py +11 -0
  49. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  50. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  51. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  52. phoenix/server/api/mutations/project_mutations.py +42 -0
  53. phoenix/server/api/queries.py +503 -0
  54. phoenix/server/api/routers/v1/__init__.py +77 -2
  55. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  56. phoenix/server/api/routers/v1/datasets.py +861 -0
  57. phoenix/server/api/routers/v1/evaluations.py +4 -2
  58. phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
  59. phoenix/server/api/routers/v1/experiment_runs.py +108 -0
  60. phoenix/server/api/routers/v1/experiments.py +174 -0
  61. phoenix/server/api/routers/v1/spans.py +3 -1
  62. phoenix/server/api/routers/v1/traces.py +1 -4
  63. phoenix/server/api/schema.py +2 -303
  64. phoenix/server/api/types/AnnotatorKind.py +10 -0
  65. phoenix/server/api/types/Cluster.py +19 -19
  66. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  67. phoenix/server/api/types/Dataset.py +282 -63
  68. phoenix/server/api/types/DatasetExample.py +85 -0
  69. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  70. phoenix/server/api/types/DatasetVersion.py +14 -0
  71. phoenix/server/api/types/Dimension.py +30 -29
  72. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  73. phoenix/server/api/types/Event.py +16 -16
  74. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  75. phoenix/server/api/types/Experiment.py +135 -0
  76. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  77. phoenix/server/api/types/ExperimentComparison.py +19 -0
  78. phoenix/server/api/types/ExperimentRun.py +91 -0
  79. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  80. phoenix/server/api/types/Inferences.py +80 -0
  81. phoenix/server/api/types/InferencesRole.py +23 -0
  82. phoenix/server/api/types/Model.py +43 -42
  83. phoenix/server/api/types/Project.py +26 -12
  84. phoenix/server/api/types/Span.py +78 -2
  85. phoenix/server/api/types/TimeSeries.py +6 -6
  86. phoenix/server/api/types/Trace.py +15 -4
  87. phoenix/server/api/types/UMAPPoints.py +1 -1
  88. phoenix/server/api/types/node.py +5 -111
  89. phoenix/server/api/types/pagination.py +10 -52
  90. phoenix/server/app.py +99 -49
  91. phoenix/server/main.py +49 -27
  92. phoenix/server/openapi/docs.py +3 -0
  93. phoenix/server/static/index.js +2246 -1368
  94. phoenix/server/templates/index.html +1 -0
  95. phoenix/services.py +15 -15
  96. phoenix/session/client.py +316 -21
  97. phoenix/session/session.py +47 -37
  98. phoenix/trace/exporter.py +14 -9
  99. phoenix/trace/fixtures.py +133 -7
  100. phoenix/trace/span_evaluations.py +3 -3
  101. phoenix/trace/trace_dataset.py +6 -6
  102. phoenix/utilities/json.py +61 -0
  103. phoenix/utilities/re.py +50 -0
  104. phoenix/version.py +1 -1
  105. phoenix/server/api/types/DatasetRole.py +0 -23
  106. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/WHEEL +0 -0
  107. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/IP_NOTICE +0 -0
  108. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/LICENSE +0 -0
  109. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0

phoenix/inferences/fixtures.py
@@ -9,7 +9,7 @@ from urllib.parse import quote, urljoin
 
 from pandas import read_parquet
 
-from phoenix.config import DATASET_DIR
+from phoenix.config import INFERENCES_DIR
 from phoenix.inferences.inferences import Inferences
 from phoenix.inferences.schema import (
     EmbeddingColumnNames,
@@ -20,7 +20,7 @@ from phoenix.inferences.schema import (
 logger = logging.getLogger(__name__)
 
 
-class DatasetRole(Enum):
+class InferencesRole(Enum):
     PRIMARY = auto()
     REFERENCE = auto()
     CORPUS = auto()
@@ -39,11 +39,11 @@ class Fixture:
     corpus_file_name: Optional[str] = None
     corpus_schema: Optional[Schema] = None
 
-    def paths(self) -> Iterator[Tuple[DatasetRole, Path]]:
+    def paths(self) -> Iterator[Tuple[InferencesRole, Path]]:
         return (
             (role, Path(self.prefix) / name)
             for role, name in zip(
-                DatasetRole,
+                InferencesRole,
                 (
                     self.primary_file_name,
                     self.reference_file_name,
@@ -413,41 +413,41 @@ FIXTURES: Tuple[Fixture, ...] = (
 NAME_TO_FIXTURE = {fixture.name: fixture for fixture in FIXTURES}
 
 
-def get_datasets(
+def get_inferences(
     fixture_name: str,
     no_internet: bool = False,
 ) -> Tuple[Inferences, Optional[Inferences], Optional[Inferences]]:
     """
-    Downloads primary and reference datasets for a fixture if they are not found
+    Downloads primary and reference inferences for a fixture if they are not found
     locally.
     """
     fixture = _get_fixture_by_name(fixture_name=fixture_name)
     if no_internet:
-        paths = {role: DATASET_DIR / path for role, path in fixture.paths()}
+        paths = {role: INFERENCES_DIR / path for role, path in fixture.paths()}
     else:
-        paths = dict(_download(fixture, DATASET_DIR))
-    primary_dataset = Inferences(
-        read_parquet(paths[DatasetRole.PRIMARY]),
+        paths = dict(_download(fixture, INFERENCES_DIR))
+    primary_inferences = Inferences(
+        read_parquet(paths[InferencesRole.PRIMARY]),
         fixture.primary_schema,
         "production",
     )
-    reference_dataset = None
+    reference_inferences = None
     if fixture.reference_file_name is not None:
-        reference_dataset = Inferences(
-            read_parquet(paths[DatasetRole.REFERENCE]),
+        reference_inferences = Inferences(
+            read_parquet(paths[InferencesRole.REFERENCE]),
             fixture.reference_schema
             if fixture.reference_schema is not None
             else fixture.primary_schema,
             "training",
         )
-    corpus_dataset = None
+    corpus_inferences = None
     if fixture.corpus_file_name is not None:
-        corpus_dataset = Inferences(
-            read_parquet(paths[DatasetRole.CORPUS]),
+        corpus_inferences = Inferences(
+            read_parquet(paths[InferencesRole.CORPUS]),
             fixture.corpus_schema,
             "knowledge_base",
         )
-    return primary_dataset, reference_dataset, corpus_dataset
+    return primary_inferences, reference_inferences, corpus_inferences
 
 
 def _get_fixture_by_name(fixture_name: str) -> Fixture:
@@ -496,14 +496,14 @@ def load_example(use_case: str) -> ExampleInferences:
 
     """
     fixture = _get_fixture_by_name(use_case)
-    primary_dataset, reference_dataset, corpus_dataset = get_datasets(use_case)
+    primary_inferences, reference_inferences, corpus_inferences = get_inferences(use_case)
     print(f"📥 Loaded {use_case} example datasets.")
     print("ℹ️ About this use-case:")
     print(fixture.description)
     return ExampleInferences(
-        primary=primary_dataset,
-        reference=reference_dataset,
-        corpus=corpus_dataset,
+        primary=primary_inferences,
+        reference=reference_inferences,
+        corpus=corpus_inferences,
     )
 
 
@@ -544,7 +544,7 @@ class GCSAssets(NamedTuple):
         )
 
 
-def _download(fixture: Fixture, location: Path) -> Iterator[Tuple[DatasetRole, Path]]:
+def _download(fixture: Fixture, location: Path) -> Iterator[Tuple[InferencesRole, Path]]:
     for role, path in fixture.paths():
         yield role, GCSAssets().metadata(path).save_artifact(location)
 
@@ -556,5 +556,5 @@ if __name__ == "__main__":
     for fixture in FIXTURES:
         start_time = time.time()
         print(f"getting {fixture.name}", end="...")
-        dict(_download(fixture, DATASET_DIR))
+        dict(_download(fixture, INFERENCES_DIR))
         print(f"done ({time.time() - start_time:.2f}s)")

phoenix/inferences/inferences.py
@@ -15,7 +15,7 @@ from pandas.api.types import (
 )
 from typing_extensions import TypeAlias
 
-from phoenix.config import DATASET_DIR, GENERATED_DATASET_NAME_PREFIX
+from phoenix.config import GENERATED_INFERENCES_NAME_PREFIX, INFERENCES_DIR
 from phoenix.datetime_utils import normalize_timestamps
 from phoenix.utilities.deprecation import deprecated
 
@@ -31,7 +31,7 @@ from .schema import (
     SchemaFieldName,
     SchemaFieldValue,
 )
-from .validation import validate_dataset_inputs
+from .validation import validate_inferences_inputs
 
 logger = logging.getLogger(__name__)
 
@@ -62,7 +62,7 @@ class Inferences:
 
     Examples
     --------
-    >>> primary_dataset = px.Inferences(
+    >>> primary_inferences = px.Inferences(
     >>>     dataframe=production_dataframe, schema=schema, name="primary"
     >>> )
     """
@@ -81,7 +81,7 @@ class Inferences:
         # allow for schema like objects
         if not isinstance(schema, Schema):
             schema = _get_schema_from_unknown_schema_param(schema)
-        errors = validate_dataset_inputs(
+        errors = validate_inferences_inputs(
             dataframe=dataframe,
             schema=schema,
         )
@@ -95,7 +95,7 @@ class Inferences:
         self.__dataframe: DataFrame = dataframe
         self.__schema: Schema = schema
         self.__name: str = (
-            name if name is not None else f"{GENERATED_DATASET_NAME_PREFIX}{str(uuid.uuid4())}"
+            name if name is not None else f"{GENERATED_INFERENCES_NAME_PREFIX}{str(uuid.uuid4())}"
         )
         self._is_empty = self.dataframe.empty
         logger.info(f"""Dataset: {self.__name} initialized""")
@@ -118,7 +118,7 @@ class Inferences:
     @classmethod
     def from_name(cls, name: str) -> "Inferences":
         """Retrieves a dataset by name from the file system"""
-        directory = DATASET_DIR / name
+        directory = INFERENCES_DIR / name
         df = read_parquet(directory / cls._data_file_name)
         with open(directory / cls._schema_file_name) as schema_file:
             schema_json = schema_file.read()
@@ -127,7 +127,7 @@
 
     def to_disc(self) -> None:
         """writes the data and schema to disc"""
-        directory = DATASET_DIR / self.name
+        directory = INFERENCES_DIR / self.name
         directory.mkdir(parents=True, exist_ok=True)
         self.dataframe.to_parquet(
             directory / self._data_file_name,
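
For reference, a minimal sketch of the to_disc/from_name round trip shown above, assuming px.Inferences and px.Schema are exported at the package root and that the dataframe/schema pair passes validation; files are written under INFERENCES_DIR/<name>:

    import pandas as pd
    import phoenix as px

    df = pd.DataFrame({"prediction_label": ["cat", "dog"], "actual_label": ["cat", "cat"]})
    schema = px.Schema(
        prediction_label_column_name="prediction_label",
        actual_label_column_name="actual_label",
    )
    inferences = px.Inferences(dataframe=df, schema=schema, name="demo")
    inferences.to_disc()  # writes the parquet file and schema JSON under INFERENCES_DIR/demo
    restored = px.Inferences.from_name("demo")  # reads them back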

phoenix/inferences/validation.py
@@ -34,7 +34,7 @@ def _check_valid_schema(schema: Schema) -> List[err.ValidationError]:
     return []
 
 
-def validate_dataset_inputs(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
+def validate_inferences_inputs(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
     errors = _check_missing_columns(dataframe, schema)
     if errors:
         return errors

phoenix/server/api/context.py
@@ -12,33 +12,49 @@ from typing_extensions import TypeAlias
 from phoenix.core.model_schema import Model
 from phoenix.server.api.dataloaders import (
     CacheForDataLoaders,
+    DatasetExampleRevisionsDataLoader,
+    DatasetExampleSpansDataLoader,
     DocumentEvaluationsDataLoader,
     DocumentEvaluationSummaryDataLoader,
     DocumentRetrievalMetricsDataLoader,
     EvaluationSummaryDataLoader,
+    ExperimentAnnotationSummaryDataLoader,
+    ExperimentErrorRatesDataLoader,
+    ExperimentSequenceNumberDataLoader,
     LatencyMsQuantileDataLoader,
     MinStartOrMaxEndTimeDataLoader,
+    ProjectByNameDataLoader,
     RecordCountDataLoader,
     SpanDescendantsDataLoader,
     SpanEvaluationsDataLoader,
+    SpanProjectsDataLoader,
     TokenCountDataLoader,
     TraceEvaluationsDataLoader,
+    TraceRowIdsDataLoader,
 )
 
 
 @dataclass
 class DataLoaders:
+    dataset_example_revisions: DatasetExampleRevisionsDataLoader
+    dataset_example_spans: DatasetExampleSpansDataLoader
     document_evaluation_summaries: DocumentEvaluationSummaryDataLoader
     document_evaluations: DocumentEvaluationsDataLoader
     document_retrieval_metrics: DocumentRetrievalMetricsDataLoader
     evaluation_summaries: EvaluationSummaryDataLoader
+    experiment_annotation_summaries: ExperimentAnnotationSummaryDataLoader
+    experiment_error_rates: ExperimentErrorRatesDataLoader
+    experiment_sequence_number: ExperimentSequenceNumberDataLoader
     latency_ms_quantile: LatencyMsQuantileDataLoader
     min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
     record_counts: RecordCountDataLoader
     span_descendants: SpanDescendantsDataLoader
     span_evaluations: SpanEvaluationsDataLoader
+    span_projects: SpanProjectsDataLoader
     token_counts: TokenCountDataLoader
     trace_evaluations: TraceEvaluationsDataLoader
+    trace_row_ids: TraceRowIdsDataLoader
+    project_by_name: ProjectByNameDataLoader
 
 
 ProjectRowId: TypeAlias = int
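
For reference, a minimal sketch of how a GraphQL resolver would use one of the new loaders, assuming the per-request context exposes this DataLoaders instance as info.context.data_loaders (the resolver itself is hypothetical):

    async def resolve_latest_revision(info, example_id: int):
        # DatasetExampleRevisionsDataLoader keys are (example_id, version_id) tuples;
        # passing None for version_id requests the example's latest non-deleted revision.
        # Loads issued while resolving the same request are batched into one query.
        return await info.context.data_loaders.dataset_example_revisions.load((example_id, None))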

phoenix/server/api/dataloaders/__init__.py
@@ -8,6 +8,8 @@ from phoenix.db.insertion.evaluation import (
 )
 from phoenix.db.insertion.span import ClearProjectSpansEvent, SpanInsertionEvent
 
+from .dataset_example_revisions import DatasetExampleRevisionsDataLoader
+from .dataset_example_spans import DatasetExampleSpansDataLoader
 from .document_evaluation_summaries import (
     DocumentEvaluationSummaryCache,
     DocumentEvaluationSummaryDataLoader,
@@ -15,27 +17,41 @@ from .document_evaluation_summaries import (
 from .document_evaluations import DocumentEvaluationsDataLoader
 from .document_retrieval_metrics import DocumentRetrievalMetricsDataLoader
 from .evaluation_summaries import EvaluationSummaryCache, EvaluationSummaryDataLoader
+from .experiment_annotation_summaries import ExperimentAnnotationSummaryDataLoader
+from .experiment_error_rates import ExperimentErrorRatesDataLoader
+from .experiment_sequence_number import ExperimentSequenceNumberDataLoader
 from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLoader
 from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
+from .project_by_name import ProjectByNameDataLoader
 from .record_counts import RecordCountCache, RecordCountDataLoader
 from .span_descendants import SpanDescendantsDataLoader
 from .span_evaluations import SpanEvaluationsDataLoader
+from .span_projects import SpanProjectsDataLoader
 from .token_counts import TokenCountCache, TokenCountDataLoader
 from .trace_evaluations import TraceEvaluationsDataLoader
+from .trace_row_ids import TraceRowIdsDataLoader
 
 __all__ = [
     "CacheForDataLoaders",
+    "DatasetExampleRevisionsDataLoader",
+    "DatasetExampleSpansDataLoader",
     "DocumentEvaluationSummaryDataLoader",
     "DocumentEvaluationsDataLoader",
     "DocumentRetrievalMetricsDataLoader",
     "EvaluationSummaryDataLoader",
+    "ExperimentAnnotationSummaryDataLoader",
+    "ExperimentErrorRatesDataLoader",
+    "ExperimentSequenceNumberDataLoader",
     "LatencyMsQuantileDataLoader",
     "MinStartOrMaxEndTimeDataLoader",
     "RecordCountDataLoader",
     "SpanDescendantsDataLoader",
     "SpanEvaluationsDataLoader",
+    "SpanProjectsDataLoader",
     "TokenCountDataLoader",
     "TraceEvaluationsDataLoader",
+    "TraceRowIdsDataLoader",
+    "ProjectByNameDataLoader",
 ]
 
 

phoenix/server/api/dataloaders/dataset_example_revisions.py
@@ -0,0 +1,100 @@
+from typing import (
+    AsyncContextManager,
+    Callable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
+
+from sqlalchemy import Integer, case, func, literal, or_, select, union
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
+
+ExampleID: TypeAlias = int
+VersionID: TypeAlias = Optional[int]
+Key: TypeAlias = Tuple[ExampleID, Optional[VersionID]]
+Result: TypeAlias = DatasetExampleRevision
+
+
+class DatasetExampleRevisionsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Union[Result, ValueError]]:
+        # sqlalchemy has limited SQLite support for VALUES, so use UNION ALL instead.
+        # For details, see https://github.com/sqlalchemy/sqlalchemy/issues/7228
+        keys_subquery = union(
+            *(
+                select(
+                    literal(example_id, Integer).label("example_id"),
+                    literal(version_id, Integer).label("version_id"),
+                )
+                for example_id, version_id in keys
+            )
+        ).subquery()
+        revision_ids = (
+            select(
+                keys_subquery.c.example_id,
+                keys_subquery.c.version_id,
+                func.max(models.DatasetExampleRevision.id).label("revision_id"),
+            )
+            .select_from(keys_subquery)
+            .join(
+                models.DatasetExampleRevision,
+                onclause=keys_subquery.c.example_id
+                == models.DatasetExampleRevision.dataset_example_id,
+            )
+            .where(
+                or_(
+                    keys_subquery.c.version_id.is_(None),
+                    models.DatasetExampleRevision.dataset_version_id <= keys_subquery.c.version_id,
+                )
+            )
+            .group_by(keys_subquery.c.example_id, keys_subquery.c.version_id)
+        ).subquery()
+        query = (
+            select(
+                revision_ids.c.example_id,
+                revision_ids.c.version_id,
+                case(
+                    (
+                        or_(
+                            revision_ids.c.version_id.is_(None),
+                            models.DatasetVersion.id.is_not(None),
+                        ),
+                        True,
+                    ),
+                    else_=False,
+                ).label("is_valid_version"),  # check that non-null versions exist
+                models.DatasetExampleRevision,
+            )
+            .select_from(revision_ids)
+            .join(
+                models.DatasetExampleRevision,
+                onclause=revision_ids.c.revision_id == models.DatasetExampleRevision.id,
+            )
+            .join(
+                models.DatasetVersion,
+                onclause=revision_ids.c.version_id == models.DatasetVersion.id,
+                isouter=True,  # keep rows where the version id is null
+            )
+            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+        )
+        async with self._db() as session:
+            results = {
+                (example_id, version_id): DatasetExampleRevision.from_orm_revision(revision)
+                async for (
+                    example_id,
+                    version_id,
+                    is_valid_version,
+                    revision,
+                ) in await session.stream(query)
+                if is_valid_version
+            }
+        return [results.get(key, ValueError("Could not find revision.")) for key in keys]
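
For reference, a standalone sketch of the pattern used in _load_fn above: the batch of (example_id, version_id) keys becomes an inline table built from a UNION of one-row SELECTs, since SQLAlchemy's VALUES construct has limited SQLite support (see the issue linked in the code):

    from sqlalchemy import Integer, literal, select, union

    keys = [(1, None), (2, 7)]  # hypothetical (example_id, version_id) pairs
    keys_subquery = union(
        *(
            select(
                literal(example_id, Integer).label("example_id"),
                literal(version_id, Integer).label("version_id"),
            )
            for example_id, version_id in keys
        )
    ).subquery()
    # keys_subquery now acts like a small table and can be joined against the ORM models.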

phoenix/server/api/dataloaders/dataset_example_spans.py
@@ -0,0 +1,43 @@
+from typing import (
+    AsyncContextManager,
+    Callable,
+    List,
+    Optional,
+)
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import joinedload
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+ExampleID: TypeAlias = int
+Key: TypeAlias = ExampleID
+Result: TypeAlias = Optional[models.Span]
+
+
+class DatasetExampleSpansDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        example_ids = keys
+        async with self._db() as session:
+            spans = {
+                example_id: span
+                async for example_id, span in await session.stream(
+                    select(models.DatasetExample.id, models.Span)
+                    .select_from(models.DatasetExample)
+                    .join(models.Span, models.DatasetExample.span_rowid == models.Span.id)
+                    .where(models.DatasetExample.id.in_(example_ids))
+                    .options(
+                        joinedload(models.Span.trace, innerjoin=True).load_only(
+                            models.Trace.trace_id
+                        )
+                    )
+                )
+            }
+        return [spans.get(example_id) for example_id in example_ids]

phoenix/server/api/dataloaders/experiment_annotation_summaries.py
@@ -0,0 +1,85 @@
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import (
+    AsyncContextManager,
+    Callable,
+    DefaultDict,
+    List,
+    Optional,
+)
+
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import AbstractCache, DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+
+@dataclass
+class ExperimentAnnotationSummary:
+    annotation_name: str
+    min_score: float
+    max_score: float
+    mean_score: float
+    count: int
+    error_count: int
+
+
+ExperimentID: TypeAlias = int
+Key: TypeAlias = ExperimentID
+Result: TypeAlias = List[ExperimentAnnotationSummary]
+
+
+class ExperimentAnnotationSummaryDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: Callable[[], AsyncContextManager[AsyncSession]],
+        cache_map: Optional[AbstractCache[Key, Result]] = None,
+    ) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        experiment_ids = keys
+        summaries: DefaultDict[ExperimentID, Result] = defaultdict(list)
+        async with self._db() as session:
+            async for (
+                experiment_id,
+                annotation_name,
+                min_score,
+                max_score,
+                mean_score,
+                count,
+                error_count,
+            ) in await session.stream(
+                select(
+                    models.ExperimentRun.experiment_id,
+                    models.ExperimentRunAnnotation.name,
+                    func.min(models.ExperimentRunAnnotation.score),
+                    func.max(models.ExperimentRunAnnotation.score),
+                    func.avg(models.ExperimentRunAnnotation.score),
+                    func.count(),
+                    func.count(models.ExperimentRunAnnotation.error),
+                )
+                .join(
+                    models.ExperimentRun,
+                    models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
+                )
+                .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
+                .group_by(models.ExperimentRun.experiment_id, models.ExperimentRunAnnotation.name)
+            ):
+                summaries[experiment_id].append(
+                    ExperimentAnnotationSummary(
+                        annotation_name=annotation_name,
+                        min_score=min_score,
+                        max_score=max_score,
+                        mean_score=mean_score,
+                        count=count,
+                        error_count=error_count,
+                    )
+                )
+        return [
+            sorted(summaries[experiment_id], key=lambda summary: summary.annotation_name)
+            for experiment_id in experiment_ids
+        ]

phoenix/server/api/dataloaders/experiment_error_rates.py
@@ -0,0 +1,43 @@
+from typing import (
+    AsyncContextManager,
+    Callable,
+    List,
+    Optional,
+)
+
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+ExperimentID: TypeAlias = int
+ErrorRate: TypeAlias = float
+Key: TypeAlias = ExperimentID
+Result: TypeAlias = Optional[ErrorRate]
+
+
+class ExperimentErrorRatesDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: Callable[[], AsyncContextManager[AsyncSession]],
+    ) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        experiment_ids = keys
+        async with self._db() as session:
+            error_rates = {
+                experiment_id: error_rate
+                async for experiment_id, error_rate in await session.stream(
+                    select(
+                        models.ExperimentRun.experiment_id,
+                        func.count(models.ExperimentRun.error) / func.count(),
+                    )
+                    .group_by(models.ExperimentRun.experiment_id)
+                    .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
+                )
+            }
+        return [error_rates.get(experiment_id) for experiment_id in experiment_ids]
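
For reference, a pure-Python sketch of what the ratio above computes: SQL COUNT(column) counts only non-NULL values, so count(error) / count() is "runs with an error" over "total runs" for each experiment:

    from typing import Optional, Sequence

    def error_rate(errors: Sequence[Optional[str]]) -> Optional[float]:
        """One entry per run; None means the run completed without an error."""
        if not errors:
            return None  # the loader likewise yields None for experiments with no runs
        return sum(error is not None for error in errors) / len(errors)

    assert error_rate(["boom", None, None, "boom"]) == 0.5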

phoenix/server/api/dataloaders/experiment_sequence_number.py
@@ -0,0 +1,49 @@
+from typing import (
+    AsyncContextManager,
+    Callable,
+    List,
+    Optional,
+)
+
+from sqlalchemy import distinct, func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+ExperimentId: TypeAlias = int
+Key: TypeAlias = ExperimentId
+Result: TypeAlias = Optional[int]
+
+
+class ExperimentSequenceNumberDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        experiment_ids = keys
+        dataset_ids = (
+            select(distinct(models.Experiment.dataset_id))
+            .where(models.Experiment.id.in_(experiment_ids))
+            .scalar_subquery()
+        )
+        row_number = (
+            func.row_number().over(
+                partition_by=models.Experiment.dataset_id,
+                order_by=models.Experiment.id,
+            )
+        ).label("row_number")
+        subq = (
+            select(models.Experiment.id, row_number)
+            .where(models.Experiment.dataset_id.in_(dataset_ids))
+            .subquery()
+        )
+        stmt = select(subq).where(subq.c.id.in_(experiment_ids))
+        async with self._db() as session:
+            result = {
+                experiment_id: sequence_number
+                async for experiment_id, sequence_number in await session.stream(stmt)
+            }
+        return [result.get(experiment_id) for experiment_id in experiment_ids]
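
For reference, a pure-Python sketch of what the row_number() window above assigns: experiments are numbered 1, 2, 3, ... within their dataset, ordered by primary key:

    from collections import defaultdict
    from typing import DefaultDict, Dict, List, Tuple

    def sequence_numbers(experiments: List[Tuple[int, int]]) -> Dict[int, int]:
        """experiments holds (experiment_id, dataset_id) pairs."""
        counter: DefaultDict[int, int] = defaultdict(int)
        result: Dict[int, int] = {}
        for experiment_id, dataset_id in sorted(experiments):  # ordered by experiment id
            counter[dataset_id] += 1
            result[experiment_id] = counter[dataset_id]
        return result

    assert sequence_numbers([(10, 1), (11, 2), (12, 1)]) == {10: 1, 11: 1, 12: 2}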

phoenix/server/api/dataloaders/project_by_name.py
@@ -0,0 +1,31 @@
+from collections import defaultdict
+from typing import AsyncContextManager, Callable, DefaultDict, List, Optional
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+ProjectName: TypeAlias = str
+Key: TypeAlias = ProjectName
+Result: TypeAlias = Optional[models.Project]
+
+
+class ProjectByNameDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        project_names = list(set(keys))
+        projects_by_name: DefaultDict[Key, Result] = defaultdict(None)
+        async with self._db() as session:
+            data = await session.stream_scalars(
+                select(models.Project).where(models.Project.name.in_(project_names))
+            )
+            async for project in data:
+                projects_by_name[project.name] = project
+
+        return [projects_by_name[project_name] for project_name in project_names]
9
9
  from aioitertools.itertools import groupby
10
10
  from sqlalchemy import select
11
11
  from sqlalchemy.ext.asyncio import AsyncSession
12
- from sqlalchemy.orm import contains_eager
12
+ from sqlalchemy.orm import joinedload
13
13
  from strawberry.dataloader import DataLoader
14
14
  from typing_extensions import TypeAlias
15
15
 
@@ -52,8 +52,7 @@ class SpanDescendantsDataLoader(DataLoader[Key, Result]):
52
52
  stmt = (
53
53
  select(descendant_ids.c[root_id_label], models.Span)
54
54
  .join(descendant_ids, models.Span.id == descendant_ids.c.id)
55
- .join(models.Trace)
56
- .options(contains_eager(models.Span.trace))
55
+ .options(joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id))
57
56
  .order_by(descendant_ids.c[root_id_label])
58
57
  )
59
58
  results: Dict[SpanId, Result] = {key: [] for key in keys}
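
For reference, a minimal sketch contrasting the two eager-loading styles swapped in this hunk: contains_eager populates Span.trace from a JOIN written explicitly in the query, while joinedload emits its own JOIN (an inner one with innerjoin=True) and load_only limits the columns fetched from Trace:

    from sqlalchemy import select
    from sqlalchemy.orm import contains_eager, joinedload

    from phoenix.db import models

    # Before: the query joins Trace itself and contains_eager reuses that join.
    before = (
        select(models.Span)
        .join(models.Trace)
        .options(contains_eager(models.Span.trace))
    )

    # After: joinedload adds the join on its own and loads only Trace.trace_id for Span.trace.
    after = select(models.Span).options(
        joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
    )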