arize-phoenix 3.16.0__py3-none-any.whl → 7.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- arize_phoenix-7.7.0.dist-info/METADATA +261 -0
- arize_phoenix-7.7.0.dist-info/RECORD +345 -0
- {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/WHEEL +1 -1
- arize_phoenix-7.7.0.dist-info/entry_points.txt +3 -0
- phoenix/__init__.py +86 -14
- phoenix/auth.py +309 -0
- phoenix/config.py +675 -45
- phoenix/core/model.py +32 -30
- phoenix/core/model_schema.py +102 -109
- phoenix/core/model_schema_adapter.py +48 -45
- phoenix/datetime_utils.py +24 -3
- phoenix/db/README.md +54 -0
- phoenix/db/__init__.py +4 -0
- phoenix/db/alembic.ini +85 -0
- phoenix/db/bulk_inserter.py +294 -0
- phoenix/db/engines.py +208 -0
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +113 -0
- phoenix/db/helpers.py +159 -0
- phoenix/db/insertion/constants.py +2 -0
- phoenix/db/insertion/dataset.py +227 -0
- phoenix/db/insertion/document_annotation.py +171 -0
- phoenix/db/insertion/evaluation.py +191 -0
- phoenix/db/insertion/helpers.py +98 -0
- phoenix/db/insertion/span.py +193 -0
- phoenix/db/insertion/span_annotation.py +158 -0
- phoenix/db/insertion/trace_annotation.py +158 -0
- phoenix/db/insertion/types.py +256 -0
- phoenix/db/migrate.py +86 -0
- phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
- phoenix/db/migrations/env.py +114 -0
- phoenix/db/migrations/script.py.mako +26 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
- phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
- phoenix/db/models.py +807 -0
- phoenix/exceptions.py +5 -1
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +158 -0
- phoenix/experiments/evaluators/code_evaluators.py +184 -0
- phoenix/experiments/evaluators/llm_evaluators.py +473 -0
- phoenix/experiments/evaluators/utils.py +236 -0
- phoenix/experiments/functions.py +772 -0
- phoenix/experiments/tracing.py +86 -0
- phoenix/experiments/types.py +726 -0
- phoenix/experiments/utils.py +25 -0
- phoenix/inferences/__init__.py +0 -0
- phoenix/{datasets → inferences}/errors.py +6 -5
- phoenix/{datasets → inferences}/fixtures.py +49 -42
- phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
- phoenix/{datasets → inferences}/schema.py +11 -11
- phoenix/{datasets → inferences}/validation.py +13 -14
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +4 -3
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +9 -3
- phoenix/pointcloud/clustering.py +5 -5
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/projectors.py +5 -6
- phoenix/pointcloud/umap_parameters.py +53 -52
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +44 -0
- phoenix/server/api/context.py +152 -9
- phoenix/server/api/dataloaders/__init__.py +91 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/cache/__init__.py +3 -0
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
- phoenix/server/api/dataloaders/document_evaluations.py +31 -0
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
- phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/record_counts.py +116 -0
- phoenix/server/api/dataloaders/session_io.py +79 -0
- phoenix/server/api/dataloaders/session_num_traces.py +30 -0
- phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
- phoenix/server/api/dataloaders/session_token_usages.py +41 -0
- phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
- phoenix/server/api/dataloaders/span_annotations.py +26 -0
- phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +57 -0
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/token_counts.py +124 -0
- phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
- phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +48 -0
- phoenix/server/api/helpers/__init__.py +12 -0
- phoenix/server/api/helpers/dataset_helpers.py +217 -0
- phoenix/server/api/helpers/experiment_run_filters.py +763 -0
- phoenix/server/api/helpers/playground_clients.py +948 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +455 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
- phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +162 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
- phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
- phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
- phoenix/server/api/input_types/SpanSort.py +134 -69
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
- phoenix/server/api/input_types/UserRoleInput.py +9 -0
- phoenix/server/api/mutations/__init__.py +28 -0
- phoenix/server/api/mutations/api_key_mutations.py +167 -0
- phoenix/server/api/mutations/chat_mutations.py +593 -0
- phoenix/server/api/mutations/dataset_mutations.py +591 -0
- phoenix/server/api/mutations/experiment_mutations.py +75 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
- phoenix/server/api/mutations/project_mutations.py +57 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
- phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
- phoenix/server/api/mutations/user_mutations.py +329 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +17 -0
- phoenix/server/api/openapi/schema.py +16 -0
- phoenix/server/api/queries.py +738 -0
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +488 -0
- phoenix/server/api/routers/v1/__init__.py +64 -0
- phoenix/server/api/routers/v1/datasets.py +1017 -0
- phoenix/server/api/routers/v1/evaluations.py +362 -0
- phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
- phoenix/server/api/routers/v1/experiment_runs.py +167 -0
- phoenix/server/api/routers/v1/experiments.py +308 -0
- phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
- phoenix/server/api/routers/v1/spans.py +267 -0
- phoenix/server/api/routers/v1/traces.py +208 -0
- phoenix/server/api/routers/v1/utils.py +95 -0
- phoenix/server/api/schema.py +44 -247
- phoenix/server/api/subscriptions.py +597 -0
- phoenix/server/api/types/Annotation.py +21 -0
- phoenix/server/api/types/AnnotationSummary.py +55 -0
- phoenix/server/api/types/AnnotatorKind.py +16 -0
- phoenix/server/api/types/ApiKey.py +27 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
- phoenix/server/api/types/Cluster.py +25 -24
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/DataQualityMetric.py +31 -13
- phoenix/server/api/types/Dataset.py +288 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +32 -31
- phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
- phoenix/server/api/types/EmbeddingDimension.py +56 -49
- phoenix/server/api/types/Evaluation.py +25 -31
- phoenix/server/api/types/EvaluationSummary.py +30 -50
- phoenix/server/api/types/Event.py +20 -20
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +152 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +17 -0
- phoenix/server/api/types/ExperimentRun.py +119 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
- phoenix/server/api/types/GenerativeModel.py +9 -0
- phoenix/server/api/types/GenerativeProvider.py +85 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/LabelFraction.py +7 -0
- phoenix/server/api/types/MimeType.py +2 -2
- phoenix/server/api/types/Model.py +54 -54
- phoenix/server/api/types/PerformanceMetric.py +8 -5
- phoenix/server/api/types/Project.py +407 -142
- phoenix/server/api/types/ProjectSession.py +139 -0
- phoenix/server/api/types/Segments.py +4 -4
- phoenix/server/api/types/Span.py +221 -176
- phoenix/server/api/types/SpanAnnotation.py +43 -0
- phoenix/server/api/types/SpanIOValue.py +15 -0
- phoenix/server/api/types/SystemApiKey.py +9 -0
- phoenix/server/api/types/TemplateLanguage.py +10 -0
- phoenix/server/api/types/TimeSeries.py +19 -15
- phoenix/server/api/types/TokenUsage.py +11 -0
- phoenix/server/api/types/Trace.py +154 -0
- phoenix/server/api/types/TraceAnnotation.py +45 -0
- phoenix/server/api/types/UMAPPoints.py +7 -7
- phoenix/server/api/types/User.py +60 -0
- phoenix/server/api/types/UserApiKey.py +45 -0
- phoenix/server/api/types/UserRole.py +15 -0
- phoenix/server/api/types/node.py +13 -107
- phoenix/server/api/types/pagination.py +156 -57
- phoenix/server/api/utils.py +34 -0
- phoenix/server/app.py +864 -115
- phoenix/server/bearer_auth.py +163 -0
- phoenix/server/dml_event.py +136 -0
- phoenix/server/dml_event_handler.py +256 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +97 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +102 -0
- phoenix/server/jwt_store.py +505 -0
- phoenix/server/main.py +305 -116
- phoenix/server/oauth2.py +52 -0
- phoenix/server/openapi/__init__.py +0 -0
- phoenix/server/prometheus.py +111 -0
- phoenix/server/rate_limiters.py +188 -0
- phoenix/server/static/.vite/manifest.json +87 -0
- phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
- phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
- phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
- phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
- phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
- phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
- phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
- phoenix/server/telemetry.py +68 -0
- phoenix/server/templates/index.html +82 -23
- phoenix/server/thread_server.py +3 -3
- phoenix/server/types.py +275 -0
- phoenix/services.py +27 -18
- phoenix/session/client.py +743 -68
- phoenix/session/data_extractor.py +31 -7
- phoenix/session/evaluation.py +3 -9
- phoenix/session/session.py +263 -219
- phoenix/settings.py +22 -0
- phoenix/trace/__init__.py +2 -22
- phoenix/trace/attributes.py +338 -0
- phoenix/trace/dsl/README.md +116 -0
- phoenix/trace/dsl/filter.py +663 -213
- phoenix/trace/dsl/helpers.py +73 -21
- phoenix/trace/dsl/query.py +574 -201
- phoenix/trace/exporter.py +24 -19
- phoenix/trace/fixtures.py +368 -32
- phoenix/trace/otel.py +71 -219
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +33 -11
- phoenix/trace/span_evaluations.py +21 -16
- phoenix/trace/span_json_decoder.py +6 -4
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +47 -32
- phoenix/trace/utils.py +21 -4
- phoenix/utilities/__init__.py +0 -26
- phoenix/utilities/client.py +132 -0
- phoenix/utilities/deprecation.py +31 -0
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +109 -0
- phoenix/utilities/logging.py +8 -0
- phoenix/utilities/project.py +2 -2
- phoenix/utilities/re.py +49 -0
- phoenix/utilities/span_store.py +0 -23
- phoenix/utilities/template_formatters.py +99 -0
- phoenix/version.py +1 -1
- arize_phoenix-3.16.0.dist-info/METADATA +0 -495
- arize_phoenix-3.16.0.dist-info/RECORD +0 -178
- phoenix/core/project.py +0 -617
- phoenix/core/traces.py +0 -100
- phoenix/experimental/evals/__init__.py +0 -73
- phoenix/experimental/evals/evaluators.py +0 -413
- phoenix/experimental/evals/functions/__init__.py +0 -4
- phoenix/experimental/evals/functions/classify.py +0 -453
- phoenix/experimental/evals/functions/executor.py +0 -353
- phoenix/experimental/evals/functions/generate.py +0 -138
- phoenix/experimental/evals/functions/processing.py +0 -76
- phoenix/experimental/evals/models/__init__.py +0 -14
- phoenix/experimental/evals/models/anthropic.py +0 -175
- phoenix/experimental/evals/models/base.py +0 -170
- phoenix/experimental/evals/models/bedrock.py +0 -221
- phoenix/experimental/evals/models/litellm.py +0 -134
- phoenix/experimental/evals/models/openai.py +0 -448
- phoenix/experimental/evals/models/rate_limiters.py +0 -246
- phoenix/experimental/evals/models/vertex.py +0 -173
- phoenix/experimental/evals/models/vertexai.py +0 -186
- phoenix/experimental/evals/retrievals.py +0 -96
- phoenix/experimental/evals/templates/__init__.py +0 -50
- phoenix/experimental/evals/templates/default_templates.py +0 -472
- phoenix/experimental/evals/templates/template.py +0 -195
- phoenix/experimental/evals/utils/__init__.py +0 -172
- phoenix/experimental/evals/utils/threads.py +0 -27
- phoenix/server/api/helpers.py +0 -11
- phoenix/server/api/routers/evaluation_handler.py +0 -109
- phoenix/server/api/routers/span_handler.py +0 -70
- phoenix/server/api/routers/trace_handler.py +0 -60
- phoenix/server/api/types/DatasetRole.py +0 -23
- phoenix/server/static/index.css +0 -6
- phoenix/server/static/index.js +0 -7447
- phoenix/storage/span_store/__init__.py +0 -23
- phoenix/storage/span_store/text_file.py +0 -85
- phoenix/trace/dsl/missing.py +0 -60
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/{datasets → db/insertion}/__init__.py +0 -0
- /phoenix/{experimental → db/migrations}/__init__.py +0 -0
- /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from sqlalchemy import func, select
|
|
2
|
+
from sqlalchemy.sql.functions import coalesce
|
|
3
|
+
from strawberry.dataloader import DataLoader
|
|
4
|
+
from typing_extensions import TypeAlias
|
|
5
|
+
|
|
6
|
+
from phoenix.db import models
|
|
7
|
+
from phoenix.server.types import DbSessionFactory
|
|
8
|
+
from phoenix.trace.schemas import TokenUsage
|
|
9
|
+
|
|
10
|
+
Key: TypeAlias = int
Result: TypeAlias = TokenUsage


class SessionTokenUsagesDataLoader(DataLoader[Key, Result]):
    """Batch-loads aggregate LLM token usage for project sessions.

    Keys are project-session rowids. For each session, the loader sums the
    cumulative prompt/completion token counts over the root spans
    (``parent_id IS NULL``) of the session's traces. Sessions with no data
    resolve to an empty ``TokenUsage``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # NULL token counts contribute zero to the per-session sums.
        prompt_sum = func.sum(coalesce(models.Span.cumulative_llm_token_count_prompt, 0))
        completion_sum = func.sum(
            coalesce(models.Span.cumulative_llm_token_count_completion, 0)
        )
        stmt = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                prompt_sum.label("prompt"),
                completion_sum.label("completion"),
            )
            .join_from(models.Span, models.Trace)
            .where(models.Span.parent_id.is_(None))
            .where(models.Trace.project_session_rowid.in_(keys))
            .group_by(models.Trace.project_session_rowid)
        )
        usage_by_session: dict[Key, TokenUsage] = {}
        async with self._db() as session:
            async for session_rowid, prompt, completion in await session.stream(stmt):
                # Traces without a session come back with a NULL rowid; skip them.
                if session_rowid is None:
                    continue
                usage_by_session[session_rowid] = TokenUsage(
                    prompt=prompt, completion=completion
                )
        return [usage_by_session.get(key, TokenUsage()) for key in keys]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from aioitertools.itertools import groupby
|
|
6
|
+
from sqlalchemy import select
|
|
7
|
+
from strawberry.dataloader import DataLoader
|
|
8
|
+
from typing_extensions import TypeAlias
|
|
9
|
+
|
|
10
|
+
from phoenix.db import models
|
|
11
|
+
from phoenix.server.types import DbSessionFactory
|
|
12
|
+
|
|
13
|
+
SessionId: TypeAlias = int
Probability: TypeAlias = float
QuantileValue: TypeAlias = float

Key: TypeAlias = tuple[SessionId, Probability]
Result: TypeAlias = Optional[QuantileValue]
ResultPosition: TypeAlias = int

DEFAULT_VALUE: Result = None


class SessionTraceLatencyMsQuantileDataLoader(DataLoader[Key, Result]):
    """Batch-computes trace-latency quantiles per project session.

    Each key is a ``(session rowid, probability)`` pair. The loader fetches the
    latencies of all traces belonging to the requested sessions in one query,
    then computes each requested quantile in memory with ``np.quantile``.
    Sessions with no traces resolve to ``None``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        answers: list[Result] = [DEFAULT_VALUE] * len(keys)
        # Map session -> probability -> positions in `keys`, so one quantile
        # computation can fill every position that asked for it.
        positions_by_session: defaultdict[
            SessionId, defaultdict[Probability, list[ResultPosition]]
        ] = defaultdict(lambda: defaultdict(list))
        wanted_sessions = {session_id for session_id, _ in keys}
        for idx, (session_id, probability) in enumerate(keys):
            positions_by_session[session_id][probability].append(idx)
        # ORDER BY the session rowid so groupby sees contiguous runs per session.
        stmt = (
            select(
                models.Trace.project_session_rowid,
                models.Trace.latency_ms,
            )
            .where(models.Trace.project_session_rowid.in_(wanted_sessions))
            .order_by(models.Trace.project_session_rowid)
        )
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for session_rowid, group in groupby(
                rows, lambda row: row.project_session_rowid
            ):
                latencies = [row.latency_ms for row in group]
                for probability, positions in positions_by_session[session_rowid].items():
                    value = np.quantile(latencies, probability)
                    for idx in positions:
                        answers[idx] = value
        return answers
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db.models import SpanAnnotation as ORMSpanAnnotation
|
|
8
|
+
from phoenix.server.types import DbSessionFactory
|
|
9
|
+
|
|
10
|
+
Key: TypeAlias = int
Result: TypeAlias = list[ORMSpanAnnotation]


class SpanAnnotationsDataLoader(DataLoader[Key, Result]):
    """Batch-loads the ``SpanAnnotation`` rows attached to each span rowid.

    Keys are span rowids; each result is the (possibly empty) list of
    annotations whose ``span_rowid`` matches the key.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        grouped: defaultdict[Key, Result] = defaultdict(list)
        stmt = select(ORMSpanAnnotation).where(ORMSpanAnnotation.span_rowid.in_(keys))
        async with self._db() as session:
            annotations = await session.stream_scalars(stmt)
            async for annotation in annotations:
                grouped[annotation.span_rowid].append(annotation)
        # defaultdict yields [] for spans that have no annotations.
        return [grouped[key] for key in keys]
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from sqlalchemy import select
|
|
2
|
+
from strawberry.dataloader import DataLoader
|
|
3
|
+
from typing_extensions import TypeAlias
|
|
4
|
+
|
|
5
|
+
from phoenix.db import models
|
|
6
|
+
from phoenix.server.types import DbSessionFactory
|
|
7
|
+
|
|
8
|
+
SpanID: TypeAlias = int
Key: TypeAlias = SpanID
Result: TypeAlias = list[models.DatasetExample]


class SpanDatasetExamplesDataLoader(DataLoader[Key, Result]):
    """Batch-loads the dataset examples derived from each span.

    Keys are span rowids; each result lists the ``DatasetExample`` rows whose
    ``span_rowid`` points back at that span (empty when none exist).
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # Pre-seed every requested span so misses resolve to an empty list.
        examples_by_span: dict[Key, list[models.DatasetExample]] = {
            key: [] for key in keys
        }
        stmt = (
            select(models.Span.id, models.DatasetExample)
            .select_from(models.Span)
            .join(models.DatasetExample, models.DatasetExample.span_rowid == models.Span.id)
            .where(models.Span.id.in_(keys))
        )
        async with self._db() as session:
            async for span_rowid, example in await session.stream(stmt):
                examples_by_span[span_rowid].append(example)
        return [examples_by_span.get(key, []) for key in keys]
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from random import randint
|
|
2
|
+
|
|
3
|
+
from aioitertools.itertools import groupby
|
|
4
|
+
from sqlalchemy import select
|
|
5
|
+
from sqlalchemy.orm import joinedload
|
|
6
|
+
from strawberry.dataloader import DataLoader
|
|
7
|
+
from typing_extensions import TypeAlias
|
|
8
|
+
|
|
9
|
+
from phoenix.db import models
|
|
10
|
+
from phoenix.server.types import DbSessionFactory
|
|
11
|
+
|
|
12
|
+
# Keys are OpenTelemetry span IDs (hex strings), not database rowids.
SpanId: TypeAlias = str

Key: TypeAlias = SpanId
Result: TypeAlias = list[models.Span]


class SpanDescendantsDataLoader(DataLoader[Key, Result]):
    """Batch-loads all transitive descendant spans of each requested span.

    For every key (a span's OTel span_id) the result is the flat list of spans
    reachable by following ``parent_id`` links downward, gathered with a single
    recursive CTE query per batch.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        root_ids = set(keys)
        # Randomized label to avoid colliding with any real column name in the
        # CTE's result set; the :06 pad keeps the label a fixed width.
        root_id_label = f"root_id_{randint(0, 10**6):06}"
        # Anchor member of the recursive CTE: the direct children of the
        # requested roots, each tagged with the root span_id it descends from.
        descendant_ids = (
            select(
                models.Span.id,
                models.Span.span_id,
                models.Span.parent_id.label(root_id_label),
            )
            .where(models.Span.parent_id.in_(root_ids))
            .cte(recursive=True)
        )
        parent_ids = descendant_ids.alias()
        # Recursive member: spans whose parent is already in the CTE inherit
        # the same root tag, so the tag propagates down the whole subtree.
        descendant_ids = descendant_ids.union_all(
            select(
                models.Span.id,
                models.Span.span_id,
                parent_ids.c[root_id_label],
            ).join(
                parent_ids,
                models.Span.parent_id == parent_ids.c.span_id,
            )
        )
        # Join back to the spans table to materialize full Span rows; eagerly
        # load each span's trace (trace_id only) to avoid N+1 lazy loads.
        stmt = (
            select(descendant_ids.c[root_id_label], models.Span)
            .join(descendant_ids, models.Span.id == descendant_ids.c.id)
            .options(joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id))
            .order_by(descendant_ids.c[root_id_label])
        )
        results: dict[SpanId, Result] = {key: [] for key in keys}
        async with self._db() as session:
            data = await session.stream(stmt)
            # Rows are ordered by root tag, so groupby yields one contiguous
            # group of descendants per requested root.
            async for root_id, group in groupby(data, key=lambda d: d[0]):
                results[root_id].extend(span for _, span in group)
        # Copy so duplicate keys in the batch don't share one mutable list.
        return [results[key].copy() for key in keys]
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.types import DbSessionFactory
|
|
9
|
+
|
|
10
|
+
SpanID: TypeAlias = int
Key: TypeAlias = SpanID
Result: TypeAlias = models.Project


class SpanProjectsDataLoader(DataLoader[Key, Result]):
    """Batch-loads the owning ``Project`` for each span rowid.

    Keys are span rowids. Unknown rowids resolve to a ``ValueError`` instance
    in the result list, which the strawberry DataLoader surfaces to the caller
    as a per-key error.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Union[Result, ValueError]]:
        # Deduplicate to keep the IN clause minimal for batches with repeats.
        span_ids = list(set(keys))
        async with self._db() as session:
            projects = {
                span_id: project
                async for span_id, project in await session.stream(
                    select(models.Span.id, models.Project)
                    .select_from(models.Span)
                    .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
                    .join(models.Project, models.Trace.project_rowid == models.Project.id)
                    .where(models.Span.id.in_(span_ids))
                )
            }
        # Explicit `is not None` check: the previous `get(...) or ValueError(...)`
        # tested the ORM object's truthiness, so any object evaluating falsy
        # would be misreported as an invalid span ID.
        return [
            project
            if (project := projects.get(span_id)) is not None
            else ValueError("Invalid span ID")
            for span_id in keys
        ]
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from typing import Any, Literal, Optional
|
|
4
|
+
|
|
5
|
+
from cachetools import LFUCache, TTLCache
|
|
6
|
+
from sqlalchemy import Select, func, select
|
|
7
|
+
from sqlalchemy.sql.functions import coalesce
|
|
8
|
+
from strawberry.dataloader import AbstractCache, DataLoader
|
|
9
|
+
from typing_extensions import TypeAlias
|
|
10
|
+
|
|
11
|
+
from phoenix.db import models
|
|
12
|
+
from phoenix.server.api.dataloaders.cache import TwoTierCache
|
|
13
|
+
from phoenix.server.api.input_types.TimeRange import TimeRange
|
|
14
|
+
from phoenix.server.types import DbSessionFactory
|
|
15
|
+
from phoenix.trace.dsl import SpanFilter
|
|
16
|
+
|
|
17
|
+
Kind: TypeAlias = Literal["prompt", "completion", "total"]
ProjectRowId: TypeAlias = int
TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
FilterCondition: TypeAlias = Optional[str]
TokenCount: TypeAlias = int

Segment: TypeAlias = tuple[TimeInterval, FilterCondition]
Param: TypeAlias = tuple[ProjectRowId, Kind]

Key: TypeAlias = tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition]
Result: TypeAlias = TokenCount
ResultPosition: TypeAlias = int
DEFAULT_VALUE: Result = 0


def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
    """Normalize a loader key into a ``(segment, param)`` pair.

    The segment (time interval + filter condition) identifies one database
    query; the param (project rowid + kind) selects a column of that query's
    result. A missing time range normalizes to an open ``(None, None)``
    interval so equivalent keys hash identically.
    """
    kind, project_rowid, time_range, filter_condition = key
    if isinstance(time_range, TimeRange):
        interval: TimeInterval = (time_range.start, time_range.end)
    else:
        interval = (None, None)
    return (interval, filter_condition), (project_rowid, kind)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
_Section: TypeAlias = ProjectRowId
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, Kind]


class TokenCountCache(
    TwoTierCache[Key, Result, _Section, _SubKey],
):
    """Two-tier cache for token counts, sectioned by project rowid."""

    def __init__(self) -> None:
        super().__init__(
            # One-hour TTL on the main cache: the UI rounds interval endpoints
            # down to the hour and intervals keep moving forward, so entries
            # older than an hour are unlikely to be hit again anyway.
            main_cache=TTLCache(maxsize=64, ttl=3600),
            sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 3),
        )

    def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
        segment, param = _cache_key_fn(key)
        interval, filter_condition = segment
        project_rowid, kind = param
        return project_rowid, (interval, filter_condition, kind)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class TokenCountDataLoader(DataLoader[Key, Result]):
    """Batch-loads per-project LLM token counts.

    Keys combine a kind ("prompt" / "completion" / "total"), a project rowid,
    an optional time range, and an optional span filter. Keys sharing the same
    segment (time interval + filter) are answered by one aggregation query;
    each key then picks its column from that query's row for its project.
    """

    def __init__(
        self,
        db: DbSessionFactory,
        cache_map: Optional[AbstractCache[Key, Result]] = None,
    ) -> None:
        super().__init__(
            load_fn=self._load_fn,
            cache_key_fn=_cache_key_fn,
            cache_map=cache_map,
        )
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        answers: list[Result] = [DEFAULT_VALUE] * len(keys)
        # Group requested positions by segment, then by (project, kind), so
        # each segment costs exactly one database round trip.
        grouped: defaultdict[
            Segment,
            defaultdict[Param, list[ResultPosition]],
        ] = defaultdict(lambda: defaultdict(list))
        for idx, key in enumerate(keys):
            segment, param = _cache_key_fn(key)
            grouped[segment][param].append(idx)
        async with self._db() as session:
            for segment, params in grouped.items():
                rows = await session.stream(_get_stmt(segment, *params.keys()))
                async for project_rowid, prompt, completion, total in rows:
                    by_kind = {
                        "prompt": prompt,
                        "completion": completion,
                        "total": total,
                    }
                    for kind, value in by_kind.items():
                        for idx in params[(project_rowid, kind)]:
                            answers[idx] = value
        return answers
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _get_stmt(
    segment: Segment,
    *params: Param,
) -> Select[Any]:
    """Build one aggregate query for a (time interval, filter) segment.

    Returns per-project sums of prompt and completion token counts (plus
    their total), grouped by project rowid and restricted to the project
    rowids appearing in *params.
    """
    (start_time, end_time), filter_condition = segment
    # COALESCE(SUM(...), 0) so projects whose spans have no token counts
    # yield 0 instead of NULL.
    prompt = coalesce(func.sum(models.Span.llm_token_count_prompt), 0)
    completion = coalesce(func.sum(models.Span.llm_token_count_completion), 0)
    total = prompt + completion
    pid = models.Trace.project_rowid
    stmt: Select[Any] = (
        select(
            pid,
            prompt.label("prompt"),
            completion.label("completion"),
            total.label("total"),
        )
        .join_from(models.Trace, models.Span)
        .group_by(pid)
    )
    # Half-open interval on span start time: [start_time, end_time).
    # Either endpoint may be absent (falsy), in which case that bound is
    # left unconstrained.
    if start_time:
        stmt = stmt.where(start_time <= models.Span.start_time)
    if end_time:
        stmt = stmt.where(models.Span.start_time < end_time)
    if filter_condition:
        # Apply the user-supplied span filter expression to the statement.
        sf = SpanFilter(filter_condition)
        stmt = sf(stmt)
    # Only query the projects actually requested (param = (rowid, kind)).
    stmt = stmt.where(pid.in_([rowid for rowid, _ in params]))
    return stmt
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from typing import List, Optional
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.types import DbSessionFactory
|
|
9
|
+
|
|
10
|
+
# A trace id string, matched against models.Trace.trace_id.
Key: TypeAlias = str
# None when no trace with the given trace id exists.
Result: TypeAlias = Optional[models.Trace]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TraceByTraceIdsDataLoader(DataLoader[Key, Result]):
    """DataLoader that batch-fetches traces by their trace id strings.

    Keys with no matching trace resolve to None.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        query = select(models.Trace).where(models.Trace.trace_id.in_(keys))
        traces_by_id: dict[Key, models.Trace] = {}
        async with self._db() as session:
            async for trace in await session.stream_scalars(query):
                traces_by_id[trace.trace_id] = trace
        return [traces_by_id.get(trace_id) for trace_id in keys]
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from typing import List, Optional
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from sqlalchemy.orm import contains_eager
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
# A trace rowid (primary key), matched against models.Trace.id.
Key: TypeAlias = int
# None when no root span is found for the given trace rowid.
Result: TypeAlias = Optional[models.Span]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TraceRootSpansDataLoader(DataLoader[Key, Result]):
    """DataLoader that batch-fetches each trace's root span.

    A root span is one whose parent_id is NULL; the trace relationship is
    eagerly populated (trace_id only) so resolvers can read it without a
    second query. Keys with no root span resolve to None.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        query = (
            select(models.Span)
            .join(models.Trace)
            .where(models.Span.parent_id.is_(None))
            .where(models.Trace.id.in_(keys))
            .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
        )
        root_spans: dict[Key, models.Span] = {}
        async with self._db() as session:
            async for span in await session.stream_scalars(query):
                root_spans[span.trace_rowid] = span
        return [root_spans.get(key) for key in keys]
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import select
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
# A user role rowid (primary key), matched against models.UserRole.id.
UserRoleId: TypeAlias = int
Key: TypeAlias = UserRoleId
# None when no role with the given id exists.
Result: TypeAlias = Optional[models.UserRole]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class UserRolesDataLoader(DataLoader[Key, Result]):
    """DataLoader that batches together user roles by their ids.

    Keys with no matching role resolve to None.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # A plain dict suffices here: the original defaultdict(None) had no
        # default factory and behaved identically for .get() lookups.
        user_roles_by_id: dict[Key, models.UserRole] = {}
        async with self._db() as session:
            # Fetches the whole roles table rather than filtering by keys —
            # presumably intentional since the table is tiny; verify if the
            # role set ever grows large.
            data = await session.stream_scalars(select(models.UserRole))
            async for user_role in data:
                user_roles_by_id[user_role.id] = user_role

        return [user_roles_by_id.get(role_id) for role_id in keys]
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import select
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
# A user rowid (primary key), matched against models.User.id.
UserId: TypeAlias = int
Key: TypeAlias = UserId
# None when no user with the given id exists.
Result: TypeAlias = Optional[models.User]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class UsersDataLoader(DataLoader[Key, Result]):
    """DataLoader that batches together users by their ids.

    Keys with no matching user resolve to None.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # De-duplicate so the IN clause doesn't repeat ids.
        user_ids = list(set(keys))
        # A plain dict suffices here: the original defaultdict(None) had no
        # default factory and behaved identically for .get() lookups.
        users_by_id: dict[Key, models.User] = {}
        async with self._db() as session:
            data = await session.stream_scalars(
                select(models.User).where(models.User.id.in_(user_ids))
            )
            async for user in data:
                users_by_id[user.id] = user

        return [users_by_id.get(user_id) for user_id in keys]
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from graphql.error import GraphQLError
|
|
2
|
+
from strawberry.extensions import MaskErrors
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class CustomGraphQLError(Exception):
    """
    Base class for expected error scenarios raised from GraphQL resolvers.

    Subclasses of this error are exempt from masking (their messages reach
    the client); all other resolver errors are masked with a generic
    message.
    """
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BadRequest(CustomGraphQLError):
    """
    An error raised when a request is malformed or contains invalid input.
    """
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class NotFound(CustomGraphQLError):
    """
    An error raised when the requested resource does not exist.
    """
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Unauthorized(CustomGraphQLError):
    """
    An error raised when login fails or a user or other entity is not
    authorized to access a resource.
    """
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Conflict(CustomGraphQLError):
    """
    An error raised when a mutation cannot be completed due to a conflict
    with the current state of one or more resources.
    """
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_mask_errors_extension() -> MaskErrors:
    """
    Return a strawberry MaskErrors extension that replaces the message of
    every unexpected resolver error (anything not rooted at
    CustomGraphQLError) with a generic one, so internal details are not
    leaked to clients.
    """
    return MaskErrors(
        should_mask_error=_should_mask_error,
        error_message="an unexpected error occurred",
    )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _should_mask_error(error: GraphQLError) -> bool:
    """
    Decide whether a resolver error should be masked.

    Errors rooted at CustomGraphQLError are expected and pass through
    unmasked; all other errors are considered unexpected and are masked.
    """
    is_expected = isinstance(error.original_error, CustomGraphQLError)
    return not is_expected
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from collections.abc import Iterable
|
|
2
|
+
from typing import Optional, TypeVar
|
|
3
|
+
|
|
4
|
+
# Element type of the iterable handled by ensure_list.
T = TypeVar("T")
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def ensure_list(obj: Optional[Iterable[T]]) -> list[T]:
    """
    Coerce *obj* to a list.

    Returns *obj* itself when it is already a list, a new list of its
    elements when it is any other iterable, and an empty list otherwise
    (including when *obj* is None).
    """
    if obj is None:
        return []
    if isinstance(obj, list):
        return obj
    if isinstance(obj, Iterable):
        return list(obj)
    return []
|