arize-phoenix 4.35.2__py3-none-any.whl → 5.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/METADATA +10 -12
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/RECORD +92 -79
- phoenix/__init__.py +86 -0
- phoenix/auth.py +275 -14
- phoenix/config.py +369 -27
- phoenix/db/alembic.ini +0 -34
- phoenix/db/engines.py +27 -10
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +112 -0
- phoenix/db/insertion/dataset.py +0 -1
- phoenix/db/insertion/types.py +1 -1
- phoenix/db/migrate.py +3 -3
- phoenix/db/migrations/env.py +0 -7
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/models.py +145 -60
- phoenix/experiments/evaluators/code_evaluators.py +9 -3
- phoenix/experiments/functions.py +1 -4
- phoenix/inferences/fixtures.py +0 -1
- phoenix/inferences/inferences.py +0 -1
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +0 -1
- phoenix/otel/settings.py +4 -4
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +32 -0
- phoenix/server/api/context.py +50 -2
- phoenix/server/api/dataloaders/__init__.py +4 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +7 -0
- phoenix/server/api/mutations/__init__.py +0 -2
- phoenix/server/api/mutations/api_key_mutations.py +104 -86
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/project_mutations.py +3 -3
- phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/user_mutations.py +282 -42
- phoenix/server/api/openapi/schema.py +2 -2
- phoenix/server/api/queries.py +48 -39
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +456 -0
- phoenix/server/api/routers/v1/__init__.py +38 -16
- phoenix/server/api/routers/v1/datasets.py +0 -1
- phoenix/server/api/types/ApiKey.py +11 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/User.py +48 -4
- phoenix/server/api/types/UserApiKey.py +35 -1
- phoenix/server/api/types/UserRole.py +7 -0
- phoenix/server/app.py +105 -34
- phoenix/server/bearer_auth.py +161 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +26 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +6 -0
- phoenix/server/jwt_store.py +504 -0
- phoenix/server/main.py +61 -30
- phoenix/server/oauth2.py +51 -0
- phoenix/server/prometheus.py +20 -0
- phoenix/server/rate_limiters.py +191 -0
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/{components-Dte7_KRd.js → components-REunxTt6.js} +348 -286
- phoenix/server/static/assets/index-DAPJxlCw.js +101 -0
- phoenix/server/static/assets/{pages-CnTvEGEN.js → pages-1VrMk2pW.js} +559 -291
- phoenix/server/static/assets/{vendor-BC3OPQuM.js → vendor-B5IC0ivG.js} +5 -5
- phoenix/server/static/assets/{vendor-arizeai-NjB3cZzD.js → vendor-arizeai-aFbT4kl1.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-gE_JCOgX.js → vendor-codemirror-BEGorXSV.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-BXLYwcXF.js → vendor-recharts-6nUU7gU_.js} +1 -1
- phoenix/server/telemetry.py +2 -2
- phoenix/server/templates/index.html +1 -0
- phoenix/server/types.py +157 -1
- phoenix/services.py +0 -1
- phoenix/session/client.py +7 -3
- phoenix/session/evaluation.py +0 -1
- phoenix/session/session.py +0 -1
- phoenix/settings.py +9 -0
- phoenix/trace/exporter.py +0 -1
- phoenix/trace/fixtures.py +0 -2
- phoenix/utilities/client.py +16 -0
- phoenix/utilities/logging.py +9 -1
- phoenix/utilities/re.py +3 -3
- phoenix/version.py +1 -1
- phoenix/db/migrations/future_versions/README.md +0 -4
- phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +0 -293
- phoenix/db/migrations/versions/.gitignore +0 -1
- phoenix/server/api/mutations/auth.py +0 -18
- phoenix/server/api/mutations/auth_mutations.py +0 -65
- phoenix/server/static/assets/index-fq1-hCK4.js +0 -100
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -103
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -31
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import atexit
|
|
2
|
+
import logging
|
|
3
|
+
import logging.config
|
|
4
|
+
import logging.handlers
|
|
5
|
+
import queue
|
|
6
|
+
from sys import stderr, stdout
|
|
7
|
+
|
|
8
|
+
from typing_extensions import assert_never
|
|
9
|
+
|
|
10
|
+
from phoenix.config import LoggingMode
|
|
11
|
+
from phoenix.logging._filter import NonErrorFilter
|
|
12
|
+
from phoenix.settings import Settings
|
|
13
|
+
|
|
14
|
+
from ._formatter import PhoenixJSONFormatter
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def setup_logging() -> None:
    """
    Set up logging according to ``Settings.logging_mode``.

    ``LoggingMode.DEFAULT`` selects the library configuration and
    ``LoggingMode.STRUCTURED`` selects the application configuration.
    """
    mode = Settings.logging_mode
    if mode is LoggingMode.DEFAULT:
        _setup_library_logging()
        return
    if mode is LoggingMode.STRUCTURED:
        _setup_application_logging()
        return
    # Only reached for a mode that is not handled above.
    assert_never(mode)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _setup_library_logging() -> None:
    """
    Configure logging for when Phoenix is used as a library.

    Only logger levels are adjusted here: the "phoenix" logger takes
    ``Settings.logging_level`` and the "sqlalchemy" logger takes
    ``Settings.db_logging_level``. No handlers are added or removed.
    """
    phoenix_logger = logging.getLogger("phoenix")
    phoenix_logger.setLevel(Settings.logging_level)
    logging.getLogger("sqlalchemy").setLevel(Settings.db_logging_level)
    phoenix_logger.info("Default logging ready")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _setup_application_logging() -> None:
    """
    Configure logging for when Phoenix is run as an application.

    The "phoenix" and "sqlalchemy" loggers stop propagating to the root logger and
    instead push their records onto an in-process queue. A ``QueueListener`` drains
    the queue and writes each record as JSON to stdout (subject to ``NonErrorFilter``)
    and, for WARNING and above, to stderr. The listener is stopped at interpreter exit.
    """
    sql_engine_logger = logging.getLogger("sqlalchemy.engine.Engine")
    # Remove (and close) all existing handlers
    for handler in sql_engine_logger.handlers[:]:
        sql_engine_logger.removeHandler(handler)
        handler.close()

    phoenix_logger = logging.getLogger("phoenix")
    phoenix_logger.setLevel(Settings.logging_level)
    phoenix_logger.propagate = False  # Do not pass records to the root logger
    sql_logger = logging.getLogger("sqlalchemy")
    sql_logger.setLevel(Settings.db_logging_level)
    sql_logger.propagate = False  # Do not pass records to the root logger

    # Both loggers share one queue so that formatting and stream writes happen on
    # the listener's thread rather than on the thread that emitted the record.
    log_queue: "queue.Queue[logging.LogRecord]" = queue.Queue()
    queue_handler = logging.handlers.QueueHandler(log_queue)
    phoenix_logger.addHandler(queue_handler)
    sql_logger.addHandler(queue_handler)

    # Maps output JSON keys to LogRecord attribute names.
    fmt_keys = {
        "level": "levelname",
        "message": "message",
        "timestamp": "timestamp",
        "logger": "name",
        "module": "module",
        "function": "funcName",
        "line": "lineno",
        "thread_name": "threadName",
    }
    formatter = PhoenixJSONFormatter(fmt_keys=fmt_keys)

    # stdout handler
    stdout_handler = logging.StreamHandler(stdout)
    stdout_handler.setFormatter(formatter)
    stdout_handler.setLevel(Settings.logging_level)
    stdout_handler.addFilter(NonErrorFilter())

    # stderr handler
    stderr_handler = logging.StreamHandler(stderr)
    stderr_handler.setFormatter(formatter)
    stderr_handler.setLevel(logging.WARNING)

    # QueueListener ignores handler levels unless told otherwise; without
    # respect_handler_level=True the stderr handler would receive every record,
    # not just WARNING and above.
    queue_listener = logging.handlers.QueueListener(
        log_queue,
        stdout_handler,
        stderr_handler,
        respect_handler_level=True,
    )
    queue_listener.start()
    atexit.register(queue_listener.stop)
    phoenix_logger.info("Structured logging ready")
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import datetime as dt
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Dict, Optional
|
|
5
|
+
|
|
6
|
+
# Names of the attributes a LogRecord carries by default. Any other attribute found
# on a record (e.g. one supplied through ``extra=``) is copied into the JSON output.
LOG_RECORD_BUILTIN_ATTRS = {
    "args",
    "asctime",
    "created",
    "exc_info",
    "exc_text",
    "filename",
    "funcName",
    "levelname",
    "levelno",
    "lineno",
    "module",
    "msecs",
    "message",
    "msg",
    "name",
    "pathname",
    "process",
    "processName",
    "relativeCreated",
    "stack_info",
    "thread",
    "threadName",
    "taskName",
}


class PhoenixJSONFormatter(logging.Formatter):
    """
    Formatter that renders each log record as a single JSON object.

    ``fmt_keys`` maps output JSON keys to ``LogRecord`` attribute names, or to the
    computed ``"message"`` / ``"timestamp"`` fields. Attributes on the record that
    are not listed in ``LOG_RECORD_BUILTIN_ATTRS`` are added under their own names.
    """

    def __init__(
        self,
        *,
        fmt_keys: Optional[Dict[str, str]] = None,
    ):
        super().__init__()
        self.fmt_keys = fmt_keys if fmt_keys is not None else {}

    def format(self, record: logging.LogRecord) -> str:
        """Return ``record`` serialized as a JSON string."""
        message = self._prepare_log_dict(record)
        # default=str stringifies any value that json cannot encode natively.
        return json.dumps(message, default=str)

    def _prepare_log_dict(self, record: logging.LogRecord) -> Dict[str, object]:
        """
        Build the dictionary that ``format`` serializes.

        Raises:
            AttributeError: if ``fmt_keys`` names an attribute the record lacks.
        """
        always_fields = {
            "message": record.getMessage(),
            "timestamp": dt.datetime.fromtimestamp(record.created, tz=dt.timezone.utc).isoformat(),
        }
        # Truthiness (not ``is not None``): a call such as ``logger.info(..., exc_info=False)``
        # stores ``False`` on the record, which formatException cannot handle.
        if record.exc_info:
            always_fields["exc_info"] = self.formatException(record.exc_info)

        if record.stack_info:
            always_fields["stack_info"] = self.formatStack(record.stack_info)

        # A computed field requested through fmt_keys is moved under the requested
        # key; whatever is left in always_fields is added under its default name.
        message: Dict[str, object] = {
            key: msg_val
            if (msg_val := always_fields.pop(val, None)) is not None
            else getattr(record, val)
            for key, val in self.fmt_keys.items()
        }
        message.update(always_fields)

        for key, val in record.__dict__.items():
            if key not in LOG_RECORD_BUILTIN_ATTRS:
                message[key] = val

        return message
|
phoenix/metrics/__init__.py
CHANGED
phoenix/otel/settings.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import os
|
|
2
3
|
import urllib
|
|
3
|
-
from logging import getLogger
|
|
4
4
|
from re import compile
|
|
5
5
|
from typing import Dict, List, Optional
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
logger = logging.getLogger(__name__)
|
|
8
8
|
|
|
9
9
|
# Environment variables specific to the subpackage
|
|
10
10
|
ENV_PHOENIX_COLLECTOR_ENDPOINT = "PHOENIX_COLLECTOR_ENDPOINT"
|
|
@@ -72,13 +72,13 @@ def parse_env_headers(s: str) -> Dict[str, str]:
|
|
|
72
72
|
encoded_header = f"{urllib.parse.quote(name)}={urllib.parse.quote(value)}"
|
|
73
73
|
match = _HEADER_PATTERN.fullmatch(encoded_header.strip())
|
|
74
74
|
if not match:
|
|
75
|
-
|
|
75
|
+
logger.warning(
|
|
76
76
|
"Header format invalid! Header values in environment variables must be "
|
|
77
77
|
"URL encoded: %s",
|
|
78
78
|
f"{name}: ****",
|
|
79
79
|
)
|
|
80
80
|
continue
|
|
81
|
-
|
|
81
|
+
logger.warning(
|
|
82
82
|
"Header values in environment variables should be URL encoded, attempting to "
|
|
83
83
|
"URL encode header: {name}: ****"
|
|
84
84
|
)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# Permission Matrix for GraphQL API
|
|
2
|
+
|
|
3
|
+
## Mutations
|
|
4
|
+
|
|
5
|
+
| Action | Admin | Member |
|
|
6
|
+
|:-----------------------------|:-----:|:------:|
|
|
7
|
+
| Create User | Yes | No |
|
|
8
|
+
| Delete User | Yes | No |
|
|
9
|
+
| Change Own Password | Yes | Yes |
|
|
10
|
+
| Change Other's Password | Yes | No |
|
|
11
|
+
| Change Own Username | Yes | Yes |
|
|
12
|
+
| Change Other's Username | Yes | No |
|
|
13
|
+
| Change Own Email | No | No |
|
|
14
|
+
| Change Other's Email | No | No |
|
|
15
|
+
| Create System API Keys | Yes | No |
|
|
16
|
+
| Delete System API Keys | Yes | No |
|
|
17
|
+
| Create Own User API Keys | Yes | Yes |
|
|
18
|
+
| Delete Own User API Keys | Yes | Yes |
|
|
19
|
+
| Delete Other's User API Keys | Yes | No |
|
|
20
|
+
|
|
21
|
+
## Queries
|
|
22
|
+
|
|
23
|
+
| Action | Admin | Member |
|
|
24
|
+
|:-------------------------------------|:-----:|:------:|
|
|
25
|
+
| List All System API Keys | Yes | No |
|
|
26
|
+
| List All User API Keys | Yes | No |
|
|
27
|
+
| List All Users | Yes | No |
|
|
28
|
+
| Fetch Other User's Info, e.g. emails | Yes | No |
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from strawberry import Info
|
|
5
|
+
from strawberry.permission import BasePermission
|
|
6
|
+
|
|
7
|
+
from phoenix.server.api.exceptions import Unauthorized
|
|
8
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Authorization(BasePermission, ABC):
    """
    Common base for the permission classes defined in this module.

    ``on_unauthorized`` raises ``Unauthorized`` carrying the permission's ``message``.
    """

    def on_unauthorized(self) -> None:
        error = Unauthorized(self.message)
        raise error
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class IsNotReadOnly(Authorization):
    """Permission granted only when the request context is not flagged read-only."""

    message = "Application is read-only"

    def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
        read_only = info.context.read_only
        return not read_only
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
MSG_ADMIN_ONLY = "Only admin can perform this action"


class IsAdmin(Authorization):
    """
    Permission granted only to an admin user.

    It is always denied when authentication is disabled on the context.
    """

    message = MSG_ADMIN_ONLY

    def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
        context = info.context
        if not context.auth_enabled:
            return False
        user = context.user
        # Anything other than a PhoenixUser is rejected before is_admin is consulted.
        return isinstance(user, PhoenixUser) and user.is_admin
|
phoenix/server/api/context.py
CHANGED
|
@@ -1,11 +1,18 @@
|
|
|
1
|
+
from asyncio import get_running_loop
|
|
1
2
|
from dataclasses import dataclass
|
|
3
|
+
from functools import cached_property, partial
|
|
2
4
|
from pathlib import Path
|
|
3
|
-
from typing import Any, Optional
|
|
5
|
+
from typing import Any, Optional, cast
|
|
4
6
|
|
|
7
|
+
from starlette.requests import Request as StarletteRequest
|
|
5
8
|
from starlette.responses import Response as StarletteResponse
|
|
6
9
|
from strawberry.fastapi import BaseContext
|
|
7
10
|
|
|
11
|
+
from phoenix.auth import (
|
|
12
|
+
compute_password_hash,
|
|
13
|
+
)
|
|
8
14
|
from phoenix.core.model_schema import Model
|
|
15
|
+
from phoenix.db import models
|
|
9
16
|
from phoenix.server.api.dataloaders import (
|
|
10
17
|
AnnotationSummaryDataLoader,
|
|
11
18
|
AverageExperimentRunLatencyDataLoader,
|
|
@@ -30,9 +37,18 @@ from phoenix.server.api.dataloaders import (
|
|
|
30
37
|
SpanProjectsDataLoader,
|
|
31
38
|
TokenCountDataLoader,
|
|
32
39
|
TraceRowIdsDataLoader,
|
|
40
|
+
UserRolesDataLoader,
|
|
41
|
+
UsersDataLoader,
|
|
33
42
|
)
|
|
43
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
34
44
|
from phoenix.server.dml_event import DmlEvent
|
|
35
|
-
from phoenix.server.types import
|
|
45
|
+
from phoenix.server.types import (
|
|
46
|
+
CanGetLastUpdatedAt,
|
|
47
|
+
CanPutItem,
|
|
48
|
+
DbSessionFactory,
|
|
49
|
+
TokenStore,
|
|
50
|
+
UserId,
|
|
51
|
+
)
|
|
36
52
|
|
|
37
53
|
|
|
38
54
|
@dataclass
|
|
@@ -59,6 +75,8 @@ class DataLoaders:
|
|
|
59
75
|
token_counts: TokenCountDataLoader
|
|
60
76
|
trace_row_ids: TraceRowIdsDataLoader
|
|
61
77
|
project_by_name: ProjectByNameDataLoader
|
|
78
|
+
users: UsersDataLoader
|
|
79
|
+
user_roles: UserRolesDataLoader
|
|
62
80
|
|
|
63
81
|
|
|
64
82
|
class _NoOp:
|
|
@@ -77,7 +95,9 @@ class Context(BaseContext):
|
|
|
77
95
|
event_queue: CanPutItem[DmlEvent] = _NoOp()
|
|
78
96
|
corpus: Optional[Model] = None
|
|
79
97
|
read_only: bool = False
|
|
98
|
+
auth_enabled: bool = False
|
|
80
99
|
secret: Optional[str] = None
|
|
100
|
+
token_store: Optional[TokenStore] = None
|
|
81
101
|
|
|
82
102
|
def get_secret(self) -> str:
|
|
83
103
|
"""A type-safe way to get the application secret. Throws an error if the secret is not set.
|
|
@@ -92,6 +112,14 @@ class Context(BaseContext):
|
|
|
92
112
|
)
|
|
93
113
|
return self.secret
|
|
94
114
|
|
|
115
|
+
def get_request(self) -> StarletteRequest:
    """
    Return the request attached to this context in a type-safe way.

    Raises:
        ValueError: if ``self.request`` is not a Starlette ``Request``.
    """
    request = self.request
    if isinstance(request, StarletteRequest):
        return request
    raise ValueError("no request is set")
|
|
122
|
+
|
|
95
123
|
def get_response(self) -> StarletteResponse:
|
|
96
124
|
"""
|
|
97
125
|
A type-safe way to get the response object. Throws an error if the response is not set.
|
|
@@ -99,3 +127,23 @@ class Context(BaseContext):
|
|
|
99
127
|
if (response := self.response) is None:
|
|
100
128
|
raise ValueError("no response is set")
|
|
101
129
|
return response
|
|
130
|
+
|
|
131
|
+
async def is_valid_password(self, password: str, user: models.User) -> bool:
    """
    Check ``password`` against the hash stored on ``user``.

    Returns ``False`` without hashing when the user has no stored hash or salt.
    Otherwise the password is hashed with the user's salt and compared to the
    stored hash in constant time.
    """
    from hmac import compare_digest

    stored_hash = user.password_hash
    salt = user.password_salt
    if stored_hash is None or salt is None:
        return False
    candidate = await self.hash_password(password, salt)
    # compare_digest avoids leaking, through timing, how many leading bytes matched.
    # NOTE(review): assumes password_hash is bytes-like, as hash_password's return type is.
    return compare_digest(stored_hash, candidate)
|
|
137
|
+
|
|
138
|
+
@staticmethod
async def hash_password(password: str, salt: bytes) -> bytes:
    """
    Compute the hash of ``password`` with ``salt`` off the event loop.

    The call to ``compute_password_hash`` is handed to the running loop's default
    executor and awaited.
    """
    job = partial(compute_password_hash, password=password, salt=salt)
    loop = get_running_loop()
    return await loop.run_in_executor(None, job)
|
|
142
|
+
|
|
143
|
+
async def log_out(self, user_id: int) -> None:
    """
    Ask the token store to log out the user with the given id.

    A token store must be set on the context; this is enforced with an assertion.
    """
    store = self.token_store
    assert store is not None
    await store.log_out(UserId(user_id))
|
|
146
|
+
|
|
147
|
+
@cached_property
def user(self) -> PhoenixUser:
    """
    The ``user`` attribute of the current request, typed as ``PhoenixUser``.

    ``cast`` performs no runtime check, so callers that need certainty should still
    use ``isinstance``.
    """
    request = self.get_request()
    return cast(PhoenixUser, request.user)
|
|
@@ -25,6 +25,8 @@ from .span_descendants import SpanDescendantsDataLoader
|
|
|
25
25
|
from .span_projects import SpanProjectsDataLoader
|
|
26
26
|
from .token_counts import TokenCountCache, TokenCountDataLoader
|
|
27
27
|
from .trace_row_ids import TraceRowIdsDataLoader
|
|
28
|
+
from .user_roles import UserRolesDataLoader
|
|
29
|
+
from .users import UsersDataLoader
|
|
28
30
|
|
|
29
31
|
__all__ = [
|
|
30
32
|
"CacheForDataLoaders",
|
|
@@ -50,6 +52,8 @@ __all__ = [
|
|
|
50
52
|
"TraceRowIdsDataLoader",
|
|
51
53
|
"ProjectByNameDataLoader",
|
|
52
54
|
"SpanAnnotationsDataLoader",
|
|
55
|
+
"UsersDataLoader",
|
|
56
|
+
"UserRolesDataLoader",
|
|
53
57
|
]
|
|
54
58
|
|
|
55
59
|
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import DefaultDict, List, Optional
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import select
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
UserRoleId: TypeAlias = int
Key: TypeAlias = UserRoleId
Result: TypeAlias = Optional[models.UserRole]


class UserRolesDataLoader(DataLoader[Key, Result]):
    """DataLoader that batches together user roles by their ids."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """
        Fetch the roles for ``keys`` with a single query.

        Returns one entry per key, in the order given, with ``None`` where no role
        has that id.
        """
        # Query only the requested ids (deduplicated) instead of the whole table.
        role_ids = list(set(keys))
        async with self._db() as session:
            rows = await session.stream_scalars(
                select(models.UserRole).where(models.UserRole.id.in_(role_ids))
            )
            roles_by_id = {role.id: role async for role in rows}

        return [roles_by_id.get(role_id) for role_id in keys]
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import DefaultDict, List, Optional
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import select
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
UserId: TypeAlias = int
Key: TypeAlias = UserId
Result: TypeAlias = Optional[models.User]


class UsersDataLoader(DataLoader[Key, Result]):
    """DataLoader that batches together users by their ids."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """
        Fetch the users for ``keys`` with a single query.

        Returns one entry per key, in the order given, with ``None`` where no user
        has that id.
        """
        unique_ids = list(set(keys))
        query = select(models.User).where(models.User.id.in_(unique_ids))
        async with self._db() as session:
            rows = await session.stream_scalars(query)
            found = {row.id: row async for row in rows}

        return [found.get(key) for key in keys]
|
phoenix/server/api/exceptions.py
CHANGED
|
@@ -27,6 +27,13 @@ class Unauthorized(CustomGraphQLError):
|
|
|
27
27
|
"""
|
|
28
28
|
|
|
29
29
|
|
|
30
|
+
class Conflict(CustomGraphQLError):
    """
    An error raised when a mutation cannot be completed because it conflicts with
    the current state of one or more resources.
    """
|
|
35
|
+
|
|
36
|
+
|
|
30
37
|
def get_mask_errors_extension() -> MaskErrors:
|
|
31
38
|
return MaskErrors(
|
|
32
39
|
should_mask_error=_should_mask_error,
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import strawberry
|
|
2
2
|
|
|
3
3
|
from phoenix.server.api.mutations.api_key_mutations import ApiKeyMutationMixin
|
|
4
|
-
from phoenix.server.api.mutations.auth_mutations import AuthMutationMixin
|
|
5
4
|
from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
|
|
6
5
|
from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
|
|
7
6
|
from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
|
|
@@ -14,7 +13,6 @@ from phoenix.server.api.mutations.user_mutations import UserMutationMixin
|
|
|
14
13
|
@strawberry.type
|
|
15
14
|
class Mutation(
|
|
16
15
|
ApiKeyMutationMixin,
|
|
17
|
-
AuthMutationMixin,
|
|
18
16
|
DatasetMutationMixin,
|
|
19
17
|
ExperimentMutationMixin,
|
|
20
18
|
ExportEventsMutationMixin,
|