arize-phoenix 4.4.3__py3-none-any.whl → 4.4.4rc0__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/METADATA +4 -4
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/RECORD +108 -55
- phoenix/__init__.py +0 -27
- phoenix/config.py +21 -7
- phoenix/core/model.py +25 -25
- phoenix/core/model_schema.py +64 -62
- phoenix/core/model_schema_adapter.py +27 -25
- phoenix/datasets/__init__.py +0 -0
- phoenix/datasets/evaluators.py +275 -0
- phoenix/datasets/experiments.py +469 -0
- phoenix/datasets/tracing.py +66 -0
- phoenix/datasets/types.py +212 -0
- phoenix/db/bulk_inserter.py +54 -14
- phoenix/db/insertion/dataset.py +234 -0
- phoenix/db/insertion/evaluation.py +6 -6
- phoenix/db/insertion/helpers.py +13 -2
- phoenix/db/migrations/types.py +29 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
- phoenix/db/models.py +230 -3
- phoenix/inferences/fixtures.py +23 -23
- phoenix/inferences/inferences.py +7 -7
- phoenix/inferences/validation.py +1 -1
- phoenix/server/api/context.py +16 -0
- phoenix/server/api/dataloaders/__init__.py +16 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +2 -3
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
- phoenix/server/api/helpers/dataset_helpers.py +178 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/mutations/__init__.py +13 -0
- phoenix/server/api/mutations/auth.py +11 -0
- phoenix/server/api/mutations/dataset_mutations.py +520 -0
- phoenix/server/api/mutations/experiment_mutations.py +65 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
- phoenix/server/api/mutations/project_mutations.py +42 -0
- phoenix/server/api/queries.py +503 -0
- phoenix/server/api/routers/v1/__init__.py +77 -2
- phoenix/server/api/routers/v1/dataset_examples.py +178 -0
- phoenix/server/api/routers/v1/datasets.py +861 -0
- phoenix/server/api/routers/v1/evaluations.py +4 -2
- phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
- phoenix/server/api/routers/v1/experiment_runs.py +108 -0
- phoenix/server/api/routers/v1/experiments.py +174 -0
- phoenix/server/api/routers/v1/spans.py +3 -1
- phoenix/server/api/routers/v1/traces.py +1 -4
- phoenix/server/api/schema.py +2 -303
- phoenix/server/api/types/AnnotatorKind.py +10 -0
- phoenix/server/api/types/Cluster.py +19 -19
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/Dataset.py +282 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +30 -29
- phoenix/server/api/types/EmbeddingDimension.py +40 -34
- phoenix/server/api/types/Event.py +16 -16
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +135 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +19 -0
- phoenix/server/api/types/ExperimentRun.py +91 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/Model.py +43 -42
- phoenix/server/api/types/Project.py +26 -12
- phoenix/server/api/types/Span.py +78 -2
- phoenix/server/api/types/TimeSeries.py +6 -6
- phoenix/server/api/types/Trace.py +15 -4
- phoenix/server/api/types/UMAPPoints.py +1 -1
- phoenix/server/api/types/node.py +5 -111
- phoenix/server/api/types/pagination.py +10 -52
- phoenix/server/app.py +99 -49
- phoenix/server/main.py +49 -27
- phoenix/server/openapi/docs.py +3 -0
- phoenix/server/static/index.js +2246 -1368
- phoenix/server/templates/index.html +1 -0
- phoenix/services.py +15 -15
- phoenix/session/client.py +316 -21
- phoenix/session/session.py +47 -37
- phoenix/trace/exporter.py +14 -9
- phoenix/trace/fixtures.py +133 -7
- phoenix/trace/span_evaluations.py +3 -3
- phoenix/trace/trace_dataset.py +6 -6
- phoenix/utilities/json.py +61 -0
- phoenix/utilities/re.py +50 -0
- phoenix/version.py +1 -1
- phoenix/server/api/types/DatasetRole.py +0 -23
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/inferences/fixtures.py
CHANGED
@@ -9,7 +9,7 @@ from urllib.parse import quote, urljoin
 
 from pandas import read_parquet
 
-from phoenix.config import DATASET_DIR
+from phoenix.config import INFERENCES_DIR
 from phoenix.inferences.inferences import Inferences
 from phoenix.inferences.schema import (
     EmbeddingColumnNames,
@@ -20,7 +20,7 @@ from phoenix.inferences.schema import (
 logger = logging.getLogger(__name__)
 
 
-class DatasetRole(Enum):
+class InferencesRole(Enum):
     PRIMARY = auto()
     REFERENCE = auto()
     CORPUS = auto()
@@ -39,11 +39,11 @@ class Fixture:
     corpus_file_name: Optional[str] = None
     corpus_schema: Optional[Schema] = None
 
-    def paths(self) -> Iterator[Tuple[DatasetRole, Path]]:
+    def paths(self) -> Iterator[Tuple[InferencesRole, Path]]:
         return (
             (role, Path(self.prefix) / name)
             for role, name in zip(
-                DatasetRole,
+                InferencesRole,
                 (
                     self.primary_file_name,
                     self.reference_file_name,
@@ -413,41 +413,41 @@ FIXTURES: Tuple[Fixture, ...] = (
 NAME_TO_FIXTURE = {fixture.name: fixture for fixture in FIXTURES}
 
 
-def get_datasets(
+def get_inferences(
     fixture_name: str,
     no_internet: bool = False,
 ) -> Tuple[Inferences, Optional[Inferences], Optional[Inferences]]:
     """
-    Downloads primary and reference datasets for a fixture if they are not found
+    Downloads primary and reference inferences for a fixture if they are not found
     locally.
     """
     fixture = _get_fixture_by_name(fixture_name=fixture_name)
     if no_internet:
-        paths = {role: DATASET_DIR / path for role, path in fixture.paths()}
+        paths = {role: INFERENCES_DIR / path for role, path in fixture.paths()}
     else:
-        paths = dict(_download(fixture, DATASET_DIR))
-    primary_dataset = Inferences(
-        read_parquet(paths[DatasetRole.PRIMARY]),
+        paths = dict(_download(fixture, INFERENCES_DIR))
+    primary_inferences = Inferences(
+        read_parquet(paths[InferencesRole.PRIMARY]),
         fixture.primary_schema,
         "production",
     )
-    reference_dataset = None
+    reference_inferences = None
     if fixture.reference_file_name is not None:
-        reference_dataset = Inferences(
-            read_parquet(paths[DatasetRole.REFERENCE]),
+        reference_inferences = Inferences(
+            read_parquet(paths[InferencesRole.REFERENCE]),
             fixture.reference_schema
             if fixture.reference_schema is not None
             else fixture.primary_schema,
             "training",
         )
-    corpus_dataset = None
+    corpus_inferences = None
     if fixture.corpus_file_name is not None:
-        corpus_dataset = Inferences(
-            read_parquet(paths[DatasetRole.CORPUS]),
+        corpus_inferences = Inferences(
+            read_parquet(paths[InferencesRole.CORPUS]),
             fixture.corpus_schema,
             "knowledge_base",
         )
-    return primary_dataset, reference_dataset, corpus_dataset
+    return primary_inferences, reference_inferences, corpus_inferences
 
 
 def _get_fixture_by_name(fixture_name: str) -> Fixture:
@@ -496,14 +496,14 @@ def load_example(use_case: str) -> ExampleInferences:
 
     """
     fixture = _get_fixture_by_name(use_case)
-    primary_dataset, reference_dataset, corpus_dataset = get_datasets(use_case)
+    primary_inferences, reference_inferences, corpus_inferences = get_inferences(use_case)
     print(f"📥 Loaded {use_case} example datasets.")
     print("ℹ️ About this use-case:")
     print(fixture.description)
     return ExampleInferences(
-        primary=primary_dataset,
-        reference=reference_dataset,
-        corpus=corpus_dataset,
+        primary=primary_inferences,
+        reference=reference_inferences,
+        corpus=corpus_inferences,
     )
 
 
@@ -544,7 +544,7 @@ class GCSAssets(NamedTuple):
     )
 
 
-def _download(fixture: Fixture, location: Path) -> Iterator[Tuple[DatasetRole, Path]]:
+def _download(fixture: Fixture, location: Path) -> Iterator[Tuple[InferencesRole, Path]]:
     for role, path in fixture.paths():
         yield role, GCSAssets().metadata(path).save_artifact(location)
 
@@ -556,5 +556,5 @@ if __name__ == "__main__":
     for fixture in FIXTURES:
         start_time = time.time()
         print(f"getting {fixture.name}", end="...")
-        dict(_download(fixture, DATASET_DIR))
+        dict(_download(fixture, INFERENCES_DIR))
        print(f"done ({time.time() - start_time:.2f}s)")
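
A minimal usage sketch of the renamed loader follows; the fixture name is illustrative (any key of NAME_TO_FIXTURE works), and the Inferences names come from the constructor calls in the diff above:

    from phoenix.inferences.fixtures import get_inferences

    # Downloads the fixture's parquet files on first use, then loads them.
    primary, reference, corpus = get_inferences("sentiment_classification_language_drift")
    print(primary.name)  # "production", per the constructor call above
    if reference is not None:  # None when the fixture defines no reference file
        print(reference.name)  # "training"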
phoenix/inferences/inferences.py
CHANGED
@@ -15,7 +15,7 @@ from pandas.api.types import (
 )
 from typing_extensions import TypeAlias
 
-from phoenix.config import DATASET_DIR, GENERATED_DATASET_NAME_PREFIX
+from phoenix.config import GENERATED_INFERENCES_NAME_PREFIX, INFERENCES_DIR
 from phoenix.datetime_utils import normalize_timestamps
 from phoenix.utilities.deprecation import deprecated
 
@@ -31,7 +31,7 @@ from .schema import (
     SchemaFieldName,
     SchemaFieldValue,
 )
-from .validation import validate_dataset_inputs
+from .validation import validate_inferences_inputs
 
 logger = logging.getLogger(__name__)
 
@@ -62,7 +62,7 @@ class Inferences:
 
     Examples
     --------
-    >>> primary_dataset = px.Inferences(
+    >>> primary_inferences = px.Inferences(
     >>>     dataframe=production_dataframe, schema=schema, name="primary"
     >>> )
     """
@@ -81,7 +81,7 @@ class Inferences:
         # allow for schema like objects
         if not isinstance(schema, Schema):
             schema = _get_schema_from_unknown_schema_param(schema)
-        errors = validate_dataset_inputs(
+        errors = validate_inferences_inputs(
             dataframe=dataframe,
             schema=schema,
         )
@@ -95,7 +95,7 @@ class Inferences:
         self.__dataframe: DataFrame = dataframe
         self.__schema: Schema = schema
         self.__name: str = (
-            name if name is not None else f"{GENERATED_DATASET_NAME_PREFIX}{str(uuid.uuid4())}"
+            name if name is not None else f"{GENERATED_INFERENCES_NAME_PREFIX}{str(uuid.uuid4())}"
         )
         self._is_empty = self.dataframe.empty
         logger.info(f"""Dataset: {self.__name} initialized""")
@@ -118,7 +118,7 @@ class Inferences:
     @classmethod
     def from_name(cls, name: str) -> "Inferences":
         """Retrieves a dataset by name from the file system"""
-        directory = DATASET_DIR / name
+        directory = INFERENCES_DIR / name
         df = read_parquet(directory / cls._data_file_name)
         with open(directory / cls._schema_file_name) as schema_file:
             schema_json = schema_file.read()
@@ -127,7 +127,7 @@ class Inferences:
 
     def to_disc(self) -> None:
        """writes the data and schema to disc"""
-        directory = DATASET_DIR / self.name
+        directory = INFERENCES_DIR / self.name
        directory.mkdir(parents=True, exist_ok=True)
        self.dataframe.to_parquet(
            directory / self._data_file_name,
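
The rename does not alter the persistence contract shown above: to_disc() writes the dataframe and schema under INFERENCES_DIR / name, and from_name() reads them back. A small round-trip sketch, with illustrative column and dataset names:

    import pandas as pd
    import phoenix as px

    inferences = px.Inferences(
        dataframe=pd.DataFrame({"prediction_id": ["a", "b"]}),
        schema=px.Schema(prediction_id_column_name="prediction_id"),
        name="demo",
    )
    inferences.to_disc()                        # writes INFERENCES_DIR / "demo"
    restored = px.Inferences.from_name("demo")  # reads the same directory back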
phoenix/inferences/validation.py
CHANGED
@@ -34,7 +34,7 @@ def _check_valid_schema(schema: Schema) -> List[err.ValidationError]:
     return []
 
 
-def validate_dataset_inputs(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
+def validate_inferences_inputs(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
     errors = _check_missing_columns(dataframe, schema)
     if errors:
        return errors
phoenix/server/api/context.py
CHANGED
@@ -12,33 +12,49 @@ from typing_extensions import TypeAlias
 from phoenix.core.model_schema import Model
 from phoenix.server.api.dataloaders import (
     CacheForDataLoaders,
+    DatasetExampleRevisionsDataLoader,
+    DatasetExampleSpansDataLoader,
     DocumentEvaluationsDataLoader,
     DocumentEvaluationSummaryDataLoader,
     DocumentRetrievalMetricsDataLoader,
     EvaluationSummaryDataLoader,
+    ExperimentAnnotationSummaryDataLoader,
+    ExperimentErrorRatesDataLoader,
+    ExperimentSequenceNumberDataLoader,
     LatencyMsQuantileDataLoader,
     MinStartOrMaxEndTimeDataLoader,
+    ProjectByNameDataLoader,
     RecordCountDataLoader,
     SpanDescendantsDataLoader,
     SpanEvaluationsDataLoader,
+    SpanProjectsDataLoader,
     TokenCountDataLoader,
     TraceEvaluationsDataLoader,
+    TraceRowIdsDataLoader,
 )
 
 
 @dataclass
 class DataLoaders:
+    dataset_example_revisions: DatasetExampleRevisionsDataLoader
+    dataset_example_spans: DatasetExampleSpansDataLoader
     document_evaluation_summaries: DocumentEvaluationSummaryDataLoader
     document_evaluations: DocumentEvaluationsDataLoader
     document_retrieval_metrics: DocumentRetrievalMetricsDataLoader
     evaluation_summaries: EvaluationSummaryDataLoader
+    experiment_annotation_summaries: ExperimentAnnotationSummaryDataLoader
+    experiment_error_rates: ExperimentErrorRatesDataLoader
+    experiment_sequence_number: ExperimentSequenceNumberDataLoader
     latency_ms_quantile: LatencyMsQuantileDataLoader
     min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
     record_counts: RecordCountDataLoader
     span_descendants: SpanDescendantsDataLoader
     span_evaluations: SpanEvaluationsDataLoader
+    span_projects: SpanProjectsDataLoader
     token_counts: TokenCountDataLoader
     trace_evaluations: TraceEvaluationsDataLoader
+    trace_row_ids: TraceRowIdsDataLoader
+    project_by_name: ProjectByNameDataLoader
 
 
 ProjectRowId: TypeAlias = int
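
Each DataLoaders field holds one request-scoped, batched loader. A hypothetical resolver body showing how one of them would be reached, assuming the surrounding Context object exposes the dataclass as data_loaders:

    from typing import Optional

    import strawberry
    from strawberry.types import Info

    @strawberry.type
    class Query:
        @strawberry.field
        async def project_row_id(self, info: Info, name: str) -> Optional[int]:
            # Batched lookup: concurrent resolvers share one SELECT per batch.
            project = await info.context.data_loaders.project_by_name.load(name)
            return project.id if project is not None else None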
phoenix/server/api/dataloaders/__init__.py
CHANGED
@@ -8,6 +8,8 @@ from phoenix.db.insertion.evaluation import (
 )
 from phoenix.db.insertion.span import ClearProjectSpansEvent, SpanInsertionEvent
 
+from .dataset_example_revisions import DatasetExampleRevisionsDataLoader
+from .dataset_example_spans import DatasetExampleSpansDataLoader
 from .document_evaluation_summaries import (
     DocumentEvaluationSummaryCache,
     DocumentEvaluationSummaryDataLoader,
@@ -15,27 +17,41 @@ from .document_evaluation_summaries import (
 from .document_evaluations import DocumentEvaluationsDataLoader
 from .document_retrieval_metrics import DocumentRetrievalMetricsDataLoader
 from .evaluation_summaries import EvaluationSummaryCache, EvaluationSummaryDataLoader
+from .experiment_annotation_summaries import ExperimentAnnotationSummaryDataLoader
+from .experiment_error_rates import ExperimentErrorRatesDataLoader
+from .experiment_sequence_number import ExperimentSequenceNumberDataLoader
 from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLoader
 from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
+from .project_by_name import ProjectByNameDataLoader
 from .record_counts import RecordCountCache, RecordCountDataLoader
 from .span_descendants import SpanDescendantsDataLoader
 from .span_evaluations import SpanEvaluationsDataLoader
+from .span_projects import SpanProjectsDataLoader
 from .token_counts import TokenCountCache, TokenCountDataLoader
 from .trace_evaluations import TraceEvaluationsDataLoader
+from .trace_row_ids import TraceRowIdsDataLoader
 
 __all__ = [
     "CacheForDataLoaders",
+    "DatasetExampleRevisionsDataLoader",
+    "DatasetExampleSpansDataLoader",
     "DocumentEvaluationSummaryDataLoader",
     "DocumentEvaluationsDataLoader",
     "DocumentRetrievalMetricsDataLoader",
     "EvaluationSummaryDataLoader",
+    "ExperimentAnnotationSummaryDataLoader",
+    "ExperimentErrorRatesDataLoader",
+    "ExperimentSequenceNumberDataLoader",
     "LatencyMsQuantileDataLoader",
     "MinStartOrMaxEndTimeDataLoader",
     "RecordCountDataLoader",
     "SpanDescendantsDataLoader",
     "SpanEvaluationsDataLoader",
+    "SpanProjectsDataLoader",
     "TokenCountDataLoader",
     "TraceEvaluationsDataLoader",
+    "TraceRowIdsDataLoader",
+    "ProjectByNameDataLoader",
 ]
 
 
phoenix/server/api/dataloaders/dataset_example_revisions.py
ADDED
@@ -0,0 +1,100 @@
+from typing import (
+    AsyncContextManager,
+    Callable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
+
+from sqlalchemy import Integer, case, func, literal, or_, select, union
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
+
+ExampleID: TypeAlias = int
+VersionID: TypeAlias = Optional[int]
+Key: TypeAlias = Tuple[ExampleID, Optional[VersionID]]
+Result: TypeAlias = DatasetExampleRevision
+
+
+class DatasetExampleRevisionsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Union[Result, ValueError]]:
+        # sqlalchemy has limited SQLite support for VALUES, so use UNION ALL instead.
+        # For details, see https://github.com/sqlalchemy/sqlalchemy/issues/7228
+        keys_subquery = union(
+            *(
+                select(
+                    literal(example_id, Integer).label("example_id"),
+                    literal(version_id, Integer).label("version_id"),
+                )
+                for example_id, version_id in keys
+            )
+        ).subquery()
+        revision_ids = (
+            select(
+                keys_subquery.c.example_id,
+                keys_subquery.c.version_id,
+                func.max(models.DatasetExampleRevision.id).label("revision_id"),
+            )
+            .select_from(keys_subquery)
+            .join(
+                models.DatasetExampleRevision,
+                onclause=keys_subquery.c.example_id
+                == models.DatasetExampleRevision.dataset_example_id,
+            )
+            .where(
+                or_(
+                    keys_subquery.c.version_id.is_(None),
+                    models.DatasetExampleRevision.dataset_version_id <= keys_subquery.c.version_id,
+                )
+            )
+            .group_by(keys_subquery.c.example_id, keys_subquery.c.version_id)
+        ).subquery()
+        query = (
+            select(
+                revision_ids.c.example_id,
+                revision_ids.c.version_id,
+                case(
+                    (
+                        or_(
+                            revision_ids.c.version_id.is_(None),
+                            models.DatasetVersion.id.is_not(None),
+                        ),
+                        True,
+                    ),
+                    else_=False,
+                ).label("is_valid_version"),  # check that non-null versions exist
+                models.DatasetExampleRevision,
+            )
+            .select_from(revision_ids)
+            .join(
+                models.DatasetExampleRevision,
+                onclause=revision_ids.c.revision_id == models.DatasetExampleRevision.id,
+            )
+            .join(
+                models.DatasetVersion,
+                onclause=revision_ids.c.version_id == models.DatasetVersion.id,
+                isouter=True,  # keep rows where the version id is null
+            )
+            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+        )
+        async with self._db() as session:
+            results = {
+                (example_id, version_id): DatasetExampleRevision.from_orm_revision(revision)
+                async for (
+                    example_id,
+                    version_id,
+                    is_valid_version,
+                    revision,
+                ) in await session.stream(query)
+                if is_valid_version
+            }
+        return [results.get(key, ValueError("Could not find revision.")) for key in keys]
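
The union() construction at the top of _load_fn is the noteworthy trick: sqlalchemy's VALUES support on SQLite is limited (see the linked issue), so every requested key becomes a one-row SELECT of literals and their union acts as an inline table the revisions table can join against. Reduced to its essentials, as a standalone sketch:

    from sqlalchemy import Integer, literal, select, union

    keys = [(1, None), (2, 7)]  # (example_id, version_id) pairs, as in the loader
    keys_subquery = union(
        *(
            select(
                literal(example_id, Integer).label("example_id"),
                literal(version_id, Integer).label("version_id"),
            )
            for example_id, version_id in keys
        )
    ).subquery()
    # keys_subquery now behaves like a VALUES table with columns
    # example_id and version_id that other tables can JOIN against.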
phoenix/server/api/dataloaders/dataset_example_spans.py
ADDED
@@ -0,0 +1,43 @@
+from typing import (
+    AsyncContextManager,
+    Callable,
+    List,
+    Optional,
+)
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import joinedload
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+ExampleID: TypeAlias = int
+Key: TypeAlias = ExampleID
+Result: TypeAlias = Optional[models.Span]
+
+
+class DatasetExampleSpansDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        example_ids = keys
+        async with self._db() as session:
+            spans = {
+                example_id: span
+                async for example_id, span in await session.stream(
+                    select(models.DatasetExample.id, models.Span)
+                    .select_from(models.DatasetExample)
+                    .join(models.Span, models.DatasetExample.span_rowid == models.Span.id)
+                    .where(models.DatasetExample.id.in_(example_ids))
+                    .options(
+                        joinedload(models.Span.trace, innerjoin=True).load_only(
+                            models.Trace.trace_id
+                        )
+                    )
+                )
+            }
+        return [spans.get(example_id) for example_id in example_ids]
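
All of the new loaders follow the same strawberry DataLoader contract: _load_fn receives one batch of keys and must return exactly one result per key, in key order, which is why each loader builds a dict keyed by input and finishes with a .get() comprehension. A self-contained sketch of that contract:

    import asyncio
    from typing import List

    from strawberry.dataloader import DataLoader

    async def load_fn(keys: List[int]) -> List[int]:
        # Called once per batch; must preserve one-result-per-key ordering.
        return [key * 2 for key in keys]

    async def main() -> None:
        loader = DataLoader(load_fn=load_fn)
        print(await asyncio.gather(loader.load(1), loader.load(2)))  # [2, 4]

    asyncio.run(main())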
phoenix/server/api/dataloaders/experiment_annotation_summaries.py
ADDED
@@ -0,0 +1,85 @@
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import (
+    AsyncContextManager,
+    Callable,
+    DefaultDict,
+    List,
+    Optional,
+)
+
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import AbstractCache, DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+
+@dataclass
+class ExperimentAnnotationSummary:
+    annotation_name: str
+    min_score: float
+    max_score: float
+    mean_score: float
+    count: int
+    error_count: int
+
+
+ExperimentID: TypeAlias = int
+Key: TypeAlias = ExperimentID
+Result: TypeAlias = List[ExperimentAnnotationSummary]
+
+
+class ExperimentAnnotationSummaryDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: Callable[[], AsyncContextManager[AsyncSession]],
+        cache_map: Optional[AbstractCache[Key, Result]] = None,
+    ) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        experiment_ids = keys
+        summaries: DefaultDict[ExperimentID, Result] = defaultdict(list)
+        async with self._db() as session:
+            async for (
+                experiment_id,
+                annotation_name,
+                min_score,
+                max_score,
+                mean_score,
+                count,
+                error_count,
+            ) in await session.stream(
+                select(
+                    models.ExperimentRun.experiment_id,
+                    models.ExperimentRunAnnotation.name,
+                    func.min(models.ExperimentRunAnnotation.score),
+                    func.max(models.ExperimentRunAnnotation.score),
+                    func.avg(models.ExperimentRunAnnotation.score),
+                    func.count(),
+                    func.count(models.ExperimentRunAnnotation.error),
+                )
+                .join(
+                    models.ExperimentRun,
+                    models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
+                )
+                .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
+                .group_by(models.ExperimentRun.experiment_id, models.ExperimentRunAnnotation.name)
+            ):
+                summaries[experiment_id].append(
+                    ExperimentAnnotationSummary(
+                        annotation_name=annotation_name,
+                        min_score=min_score,
+                        max_score=max_score,
+                        mean_score=mean_score,
+                        count=count,
+                        error_count=error_count,
+                    )
+                )
+        return [
+            sorted(summaries[experiment_id], key=lambda summary: summary.annotation_name)
+            for experiment_id in experiment_ids
+        ]
phoenix/server/api/dataloaders/experiment_error_rates.py
ADDED
@@ -0,0 +1,43 @@
+from typing import (
+    AsyncContextManager,
+    Callable,
+    List,
+    Optional,
+)
+
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+ExperimentID: TypeAlias = int
+ErrorRate: TypeAlias = float
+Key: TypeAlias = ExperimentID
+Result: TypeAlias = Optional[ErrorRate]
+
+
+class ExperimentErrorRatesDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: Callable[[], AsyncContextManager[AsyncSession]],
+    ) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        experiment_ids = keys
+        async with self._db() as session:
+            error_rates = {
+                experiment_id: error_rate
+                async for experiment_id, error_rate in await session.stream(
+                    select(
+                        models.ExperimentRun.experiment_id,
+                        func.count(models.ExperimentRun.error) / func.count(),
+                    )
+                    .group_by(models.ExperimentRun.experiment_id)
+                    .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
+                )
+            }
+        return [error_rates.get(experiment_id) for experiment_id in experiment_ids]
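
The error-rate expression leans on standard SQL COUNT semantics: COUNT(column) counts only rows where the column is non-NULL, while a bare COUNT() counts every row, so the quotient is the fraction of runs with a populated error column. The expression in isolation, reusing the models module from the diff:

    from sqlalchemy import func

    from phoenix.db import models

    # COUNT(error) / COUNT(*): failed runs over all runs. Note that on backends
    # with integer division this would need an explicit cast to float.
    error_rate = func.count(models.ExperimentRun.error) / func.count()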
phoenix/server/api/dataloaders/experiment_sequence_number.py
ADDED
@@ -0,0 +1,49 @@
+from typing import (
+    AsyncContextManager,
+    Callable,
+    List,
+    Optional,
+)
+
+from sqlalchemy import distinct, func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+ExperimentId: TypeAlias = int
+Key: TypeAlias = ExperimentId
+Result: TypeAlias = Optional[int]
+
+
+class ExperimentSequenceNumberDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        experiment_ids = keys
+        dataset_ids = (
+            select(distinct(models.Experiment.dataset_id))
+            .where(models.Experiment.id.in_(experiment_ids))
+            .scalar_subquery()
+        )
+        row_number = (
+            func.row_number().over(
+                partition_by=models.Experiment.dataset_id,
+                order_by=models.Experiment.id,
+            )
+        ).label("row_number")
+        subq = (
+            select(models.Experiment.id, row_number)
+            .where(models.Experiment.dataset_id.in_(dataset_ids))
+            .subquery()
+        )
+        stmt = select(subq).where(subq.c.id.in_(experiment_ids))
+        async with self._db() as session:
+            result = {
+                experiment_id: sequence_number
+                async for experiment_id, sequence_number in await session.stream(stmt)
+            }
+        return [result.get(experiment_id) for experiment_id in experiment_ids]
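
The sequence number is a windowed row count: experiments are numbered within their dataset in ascending id order. Since a window function cannot be filtered in its own WHERE clause, the query computes row_number() inside a subquery and filters outside it. The core of the pattern, standalone, with illustrative ids:

    from sqlalchemy import func, select

    from phoenix.db import models

    row_number = func.row_number().over(
        partition_by=models.Experiment.dataset_id,  # restart numbering per dataset
        order_by=models.Experiment.id,              # i.e. insertion order
    ).label("sequence_number")
    numbered = select(models.Experiment.id, row_number).subquery()
    stmt = select(numbered).where(numbered.c.id.in_([101, 102]))  # ids are illustrative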
phoenix/server/api/dataloaders/project_by_name.py
ADDED
@@ -0,0 +1,31 @@
+from collections import defaultdict
+from typing import AsyncContextManager, Callable, DefaultDict, List, Optional
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+ProjectName: TypeAlias = str
+Key: TypeAlias = ProjectName
+Result: TypeAlias = Optional[models.Project]
+
+
+class ProjectByNameDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        project_names = list(set(keys))
+        projects_by_name: DefaultDict[Key, Result] = defaultdict(None)
+        async with self._db() as session:
+            data = await session.stream_scalars(
+                select(models.Project).where(models.Project.name.in_(project_names))
+            )
+            async for project in data:
+                projects_by_name[project.name] = project
+
+        return [projects_by_name[project_name] for project_name in project_names]
phoenix/server/api/dataloaders/span_descendants.py
CHANGED
@@ -9,7 +9,7 @@ from typing import (
 from aioitertools.itertools import groupby
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import contains_eager
+from sqlalchemy.orm import joinedload
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
@@ -52,8 +52,7 @@ class SpanDescendantsDataLoader(DataLoader[Key, Result]):
         stmt = (
             select(descendant_ids.c[root_id_label], models.Span)
             .join(descendant_ids, models.Span.id == descendant_ids.c.id)
-            .
-            .options(contains_eager(models.Span.trace))
+            .options(joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id))
             .order_by(descendant_ids.c[root_id_label])
         )
        results: Dict[SpanId, Result] = {key: [] for key in keys}
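
The replacement collapses a manual JOIN plus contains_eager (the deleted join line is truncated in the source diff) into a single loader option: joinedload(..., innerjoin=True) keeps the INNER JOIN against the trace table, and load_only(models.Trace.trace_id) restricts the eagerly loaded Trace columns to the one the resolvers need. The option in isolation:

    from sqlalchemy import select
    from sqlalchemy.orm import joinedload

    from phoenix.db import models

    stmt = select(models.Span).options(
        # Eager-load each span's trace in the same SELECT, fetching only trace_id.
        joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
    )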