arize-phoenix 9.6.1__py3-none-any.whl → 10.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix may be problematic; see the registry's advisory for this release for more details.
- {arize_phoenix-9.6.1.dist-info → arize_phoenix-10.0.0.dist-info}/METADATA +1 -1
- {arize_phoenix-9.6.1.dist-info → arize_phoenix-10.0.0.dist-info}/RECORD +27 -26
- phoenix/auth.py +6 -4
- phoenix/config.py +150 -25
- phoenix/db/enums.py +1 -2
- phoenix/db/facilitator.py +20 -11
- phoenix/db/migrations/versions/6a88424799fe_update_users_with_auth_method.py +179 -0
- phoenix/db/models.py +66 -37
- phoenix/server/api/context.py +5 -4
- phoenix/server/api/mutations/user_mutations.py +58 -26
- phoenix/server/api/routers/auth.py +16 -4
- phoenix/server/api/routers/oauth2.py +196 -15
- phoenix/server/app.py +36 -10
- phoenix/server/bearer_auth.py +5 -7
- phoenix/server/jwt_store.py +5 -4
- phoenix/server/main.py +11 -4
- phoenix/server/oauth2.py +47 -3
- phoenix/server/static/.vite/manifest.json +9 -9
- phoenix/server/static/assets/{components-CDvTuTqd.js → components-CjGpmneV.js} +192 -184
- phoenix/server/static/assets/{index-DpcxdHu4.js → index-C57g4e_o.js} +11 -11
- phoenix/server/static/assets/{pages-Bcs41-Zv.js → pages-fQ2s7TFY.js} +343 -337
- phoenix/server/templates/index.html +1 -0
- phoenix/version.py +1 -1
- {arize_phoenix-9.6.1.dist-info → arize_phoenix-10.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-9.6.1.dist-info → arize_phoenix-10.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-9.6.1.dist-info → arize_phoenix-10.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-9.6.1.dist-info → arize_phoenix-10.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
"""Add auth_method column to users table and migrate existing authentication data.

This migration:
1. Adds a new 'auth_method' column to the users table that indicates whether a user
   authenticates via local password ('LOCAL') or external OAuth2 ('OAUTH2')
2. Migrates existing authentication data to populate the new column:
   - Sets 'LOCAL' for users with a password_hash
   - Sets 'OAUTH2' for all other users (under the legacy 'exactly_one_auth_method'
     constraint, a user without a password_hash has OAuth2 credentials)
3. Adds appropriate constraints to ensure data integrity:
   - NOT NULL constraint on auth_method
   - 'valid_auth_method': ensures only 'LOCAL' or 'OAUTH2' values
   - 'local_auth_has_password_no_oauth': ensures LOCAL users have password credentials and
     do not have OAuth2 credentials
   - 'non_local_auth_has_no_password': ensures OAUTH2 users do not have password credentials
4. Removes legacy CHECK constraints that are superseded by auth_method and its constraints:
   - 'password_hash_and_salt': password_hash and password_salt both set or both NULL
   - 'exactly_one_auth_method': exactly one of password_hash / oauth2_client_id set
   - 'oauth2_client_id_and_user_id': oauth2_client_id and oauth2_user_id both set or both NULL
5. Drops redundant single column indices:
   - 'ix_users_oauth2_client_id' and 'ix_users_oauth2_user_id' are removed as they are
     redundant with the unique constraint 'uq_users_oauth2_client_id_oauth2_user_id',
     which already provides the necessary composite index for lookups

The migration uses batch_alter_table to ensure compatibility with both SQLite and PostgreSQL.
This approach allows us to:
- Add the column as nullable initially
- Update the values based on existing authentication data
- Make the column NOT NULL after populating
- Add appropriate constraints
- Remove legacy constraints
- Drop redundant indices

The downgrade path:
1. Drops the CHECK constraints added by the upgrade ('non_local_auth_has_no_password',
   'local_auth_has_password_no_oauth', 'valid_auth_method')
2. Recreates the single column indices 'ix_users_oauth2_user_id' and
   'ix_users_oauth2_client_id'
3. Recreates the legacy constraints 'oauth2_client_id_and_user_id',
   'exactly_one_auth_method' and 'password_hash_and_salt'
4. Removes the auth_method column

NOTE: the new constraints accept OAUTH2 users whose oauth2_client_id / oauth2_user_id are
still NULL, but the legacy 'exactly_one_auth_method' constraint rejects such rows. The
downgrade therefore cannot succeed while such users exist.

Revision ID: 6a88424799fe
Revises: 8a3764fe7f1a
Create Date: 2025-05-01 08:08:22.700715

"""  # noqa: E501

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "6a88424799fe"
down_revision: Union[str, None] = "8a3764fe7f1a"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def upgrade() -> None:
    """Add and backfill ``users.auth_method``, then swap legacy constraints for new ones.

    Steps, in order:

    1. Add ``auth_method`` as a nullable string column, in its own batch, so that the
       column exists before it is backfilled.
    2. Backfill it: ``'LOCAL'`` where ``password_hash`` is set, ``'OAUTH2'`` otherwise.
    3. Mark the column NOT NULL.
    4. Drop the legacy CHECK constraints ``password_hash_and_salt``,
       ``exactly_one_auth_method`` and ``oauth2_client_id_and_user_id``.
    5. Drop the single column indices ``ix_users_oauth2_client_id`` and
       ``ix_users_oauth2_user_id``. The composite unique constraint
       ``uq_users_oauth2_client_id_oauth2_user_id`` already covers those lookups.
    6. Create the CHECK constraints ``valid_auth_method``,
       ``local_auth_has_password_no_oauth`` and ``non_local_auth_has_no_password``.

    ``batch_alter_table`` is used throughout so the same script works on both SQLite
    and PostgreSQL.

    Raises:
        sqlalchemy.exc.SQLAlchemyError: If database operations fail
    """  # noqa: E501
    superseded_checks = (
        "password_hash_and_salt",
        "exactly_one_auth_method",
        "oauth2_client_id_and_user_id",
    )
    # Covered by the composite (oauth2_client_id, oauth2_user_id) unique constraint.
    redundant_indices = (
        "ix_users_oauth2_client_id",
        "ix_users_oauth2_user_id",
    )
    replacement_checks = (
        (
            "valid_auth_method",
            "auth_method IN ('LOCAL', 'OAUTH2')",
        ),
        (
            "local_auth_has_password_no_oauth",
            "auth_method != 'LOCAL' "
            "OR (password_hash IS NOT NULL AND password_salt IS NOT NULL "
            "AND oauth2_client_id IS NULL AND oauth2_user_id IS NULL)",
        ),
        (
            "non_local_auth_has_no_password",
            "auth_method = 'LOCAL' OR (password_hash IS NULL AND password_salt IS NULL)",
        ),
    )

    # Phase 1: existing rows have no value yet, so the column starts out nullable.
    with op.batch_alter_table("users") as batch_op:
        batch_op.add_column(sa.Column("auth_method", sa.String, nullable=True))

    # Phase 2: backfill, tighten the column, and replace constraints and indices.
    with op.batch_alter_table("users") as batch_op:
        batch_op.execute("""
            UPDATE users
            SET auth_method = CASE
                WHEN password_hash IS NOT NULL THEN 'LOCAL' ELSE 'OAUTH2' END
        """)
        batch_op.alter_column("auth_method", nullable=False, existing_nullable=True)

        for check_name in superseded_checks:
            batch_op.drop_constraint(check_name, type_="check")

        for index_name in redundant_indices:
            batch_op.drop_index(index_name)

        for check_name, condition in replacement_checks:
            batch_op.create_check_constraint(check_name, condition)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _ensure_users_satisfy(checks: Sequence[tuple[str, str]]) -> None:
    """Fail fast if any ``users`` row would violate one of the given CHECK conditions.

    Runs before any DDL in ``downgrade``. Rows that violate a recreated constraint
    would otherwise only surface as an integrity error in the middle of the batch
    schema change.

    Args:
        checks: ``(constraint_name, sql_condition)`` pairs. The conditions are trusted,
            constant SQL fragments defined in this migration, never user input.

    Raises:
        RuntimeError: If at least one row violates one of the conditions.
    """
    # In offline (--sql) mode there is no live data to inspect; emit DDL only.
    if op.get_context().as_sql:
        return
    bind = op.get_bind()
    for name, condition in checks:
        offending = bind.execute(
            sa.text(f"SELECT COUNT(*) FROM users WHERE NOT ({condition})")
        ).scalar()
        if offending:
            raise RuntimeError(
                f"Cannot downgrade: {offending} row(s) in 'users' violate the legacy "
                f"CHECK constraint {name!r} ({condition}). For example, OAUTH2 users "
                "without an oauth2_client_id are valid after this migration but not "
                "before it. Fix or remove these users before downgrading."
            )


def downgrade() -> None:
    """Downgrade the database schema by removing the auth_method column.

    This function:
    1. Verifies, when connected to a live database, that every users row satisfies the
       legacy CHECK constraints about to be recreated, raising a descriptive error
       otherwise. Rows that are valid under the new schema can violate them, e.g. an
       OAUTH2 user whose oauth2_client_id is still NULL.
    2. Drops the CHECK constraints added by the upgrade:
       'non_local_auth_has_no_password', 'local_auth_has_password_no_oauth',
       'valid_auth_method'.
    3. Recreates the single column indices 'ix_users_oauth2_user_id' and
       'ix_users_oauth2_client_id'.
    4. Recreates the legacy CHECK constraints:
       - 'oauth2_client_id_and_user_id': oauth2_client_id and oauth2_user_id both set or both NULL
       - 'exactly_one_auth_method': exactly one of password_hash / oauth2_client_id set
       - 'password_hash_and_salt': password_hash and password_salt both set or both NULL
    5. Drops the auth_method column.

    The implementation uses batch_alter_table to ensure compatibility with both
    SQLite and PostgreSQL databases.

    Raises:
        RuntimeError: If existing users rows would violate a recreated legacy constraint.
        sqlalchemy.exc.SQLAlchemyError: If database operations fail
    """  # noqa: E501
    legacy_checks = (
        (
            "oauth2_client_id_and_user_id",
            "(oauth2_client_id IS NULL) = (oauth2_user_id IS NULL)",
        ),
        (
            "exactly_one_auth_method",
            "(password_hash IS NULL) != (oauth2_client_id IS NULL)",
        ),
        (
            "password_hash_and_salt",
            "(password_hash IS NULL) = (password_salt IS NULL)",
        ),
    )
    # Check the data before touching the schema so a failure leaves it unchanged.
    _ensure_users_satisfy(legacy_checks)

    # Use batch_alter_table for SQLite compatibility
    # This ensures the downgrade works on both SQLite and PostgreSQL
    with op.batch_alter_table("users") as batch_op:
        # Drop the three CHECK constraints introduced by the upgrade
        batch_op.drop_constraint("non_local_auth_has_no_password", type_="check")
        batch_op.drop_constraint("local_auth_has_password_no_oauth", type_="check")
        batch_op.drop_constraint("valid_auth_method", type_="check")

        # Recreate single column indices
        batch_op.create_index("ix_users_oauth2_user_id", ["oauth2_user_id"])
        batch_op.create_index("ix_users_oauth2_client_id", ["oauth2_client_id"])

        # Recreate the legacy constraints that were dropped in upgrade
        for name, condition in legacy_checks:
            batch_op.create_check_constraint(name, condition)

        # Remove added column
        batch_op.drop_column("auth_method")
|
phoenix/db/models.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
from datetime import datetime, timezone
|
|
2
|
-
from enum import Enum
|
|
3
2
|
from typing import Any, Iterable, Literal, Optional, Sequence, TypedDict, cast
|
|
4
3
|
|
|
5
4
|
import sqlalchemy.sql as sql
|
|
@@ -23,7 +22,6 @@ from sqlalchemy import (
|
|
|
23
22
|
case,
|
|
24
23
|
func,
|
|
25
24
|
insert,
|
|
26
|
-
not_,
|
|
27
25
|
select,
|
|
28
26
|
text,
|
|
29
27
|
)
|
|
@@ -42,6 +40,7 @@ from sqlalchemy.orm import (
|
|
|
42
40
|
from sqlalchemy.sql import Values, column, compiler, expression, literal, roles, union_all
|
|
43
41
|
from sqlalchemy.sql.compiler import SQLCompiler
|
|
44
42
|
from sqlalchemy.sql.functions import coalesce
|
|
43
|
+
from typing_extensions import TypeAlias
|
|
45
44
|
|
|
46
45
|
from phoenix.config import get_env_database_schema
|
|
47
46
|
from phoenix.datetime_utils import normalize_datetime
|
|
@@ -147,9 +146,7 @@ def render_values_w_union(
|
|
|
147
146
|
return compiler.process(subquery, from_linter=from_linter, **kw)
|
|
148
147
|
|
|
149
148
|
|
|
150
|
-
|
|
151
|
-
LOCAL = "LOCAL"
|
|
152
|
-
OAUTH2 = "OAUTH2"
|
|
149
|
+
# The allowed values of ``User.auth_method``. They match the "valid_auth_method"
# CHECK constraint ("auth_method IN ('LOCAL', 'OAUTH2')") declared on that column.
AuthMethod: TypeAlias = Literal["LOCAL", "OAUTH2"]
|
|
153
150
|
|
|
154
151
|
|
|
155
152
|
class JSONB(JSON):
|
|
@@ -1152,8 +1149,11 @@ class User(Base):
|
|
|
1152
1149
|
password_hash: Mapped[Optional[bytes]]
|
|
1153
1150
|
password_salt: Mapped[Optional[bytes]]
|
|
1154
1151
|
reset_password: Mapped[bool]
|
|
1155
|
-
oauth2_client_id: Mapped[Optional[str]]
|
|
1156
|
-
oauth2_user_id: Mapped[Optional[str]]
|
|
1152
|
+
oauth2_client_id: Mapped[Optional[str]]
|
|
1153
|
+
oauth2_user_id: Mapped[Optional[str]]
|
|
1154
|
+
auth_method: Mapped[AuthMethod] = mapped_column(
|
|
1155
|
+
CheckConstraint("auth_method IN ('LOCAL', 'OAUTH2')", name="valid_auth_method")
|
|
1156
|
+
)
|
|
1157
1157
|
created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
|
|
1158
1158
|
updated_at: Mapped[datetime] = mapped_column(
|
|
1159
1159
|
UtcTimeStamp, server_default=func.now(), onupdate=func.now()
|
|
@@ -1169,41 +1169,21 @@ class User(Base):
|
|
|
1169
1169
|
)
|
|
1170
1170
|
api_keys: Mapped[list["ApiKey"]] = relationship("ApiKey", back_populates="user")
|
|
1171
1171
|
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
elif self.oauth2_client_id is not None:
|
|
1177
|
-
return AuthMethod.OAUTH2.value
|
|
1178
|
-
return None
|
|
1179
|
-
|
|
1180
|
-
@auth_method.inplace.expression
|
|
1181
|
-
@classmethod
|
|
1182
|
-
def _auth_method_expression(cls) -> ColumnElement[Optional[str]]:
|
|
1183
|
-
return case(
|
|
1184
|
-
(
|
|
1185
|
-
not_(cls.password_hash.is_(None)),
|
|
1186
|
-
AuthMethod.LOCAL.value,
|
|
1187
|
-
),
|
|
1188
|
-
(
|
|
1189
|
-
not_(cls.oauth2_client_id.is_(None)),
|
|
1190
|
-
AuthMethod.OAUTH2.value,
|
|
1191
|
-
),
|
|
1192
|
-
else_=None,
|
|
1193
|
-
)
|
|
1172
|
+
__mapper_args__ = {
|
|
1173
|
+
"polymorphic_on": "auth_method",
|
|
1174
|
+
"polymorphic_identity": None, # Base class is abstract
|
|
1175
|
+
}
|
|
1194
1176
|
|
|
1195
1177
|
__table_args__ = (
|
|
1196
1178
|
CheckConstraint(
|
|
1197
|
-
"
|
|
1198
|
-
|
|
1179
|
+
"auth_method != 'LOCAL' "
|
|
1180
|
+
"OR (password_hash IS NOT NULL AND password_salt IS NOT NULL "
|
|
1181
|
+
"AND oauth2_client_id IS NULL AND oauth2_user_id IS NULL)",
|
|
1182
|
+
name="local_auth_has_password_no_oauth",
|
|
1199
1183
|
),
|
|
1200
1184
|
CheckConstraint(
|
|
1201
|
-
"(
|
|
1202
|
-
name="
|
|
1203
|
-
),
|
|
1204
|
-
CheckConstraint(
|
|
1205
|
-
"(password_hash IS NULL) != (oauth2_client_id IS NULL)",
|
|
1206
|
-
name="exactly_one_auth_method",
|
|
1185
|
+
"auth_method = 'LOCAL' OR (password_hash IS NULL AND password_salt IS NULL)",
|
|
1186
|
+
name="non_local_auth_has_no_password",
|
|
1207
1187
|
),
|
|
1208
1188
|
UniqueConstraint(
|
|
1209
1189
|
"oauth2_client_id",
|
|
@@ -1213,6 +1193,55 @@ class User(Base):
|
|
|
1213
1193
|
)
|
|
1214
1194
|
|
|
1215
1195
|
|
|
1196
|
+
class LocalUser(User):
    # A User that signs in with a locally stored password hash and salt. It is
    # mapped through single-table polymorphism on ``auth_method == "LOCAL"``.
    __mapper_args__ = {
        "polymorphic_identity": "LOCAL",
    }

    def __init__(
        self,
        *,
        email: str,
        username: str,
        password_hash: bytes,
        password_salt: bytes,
        reset_password: bool = True,
        user_role_id: Optional[int] = None,
    ) -> None:
        """Build a local-auth user with whitespace-stripped ``email`` and ``username``.

        Args:
            email: The user's email address.
            username: The user's display name.
            password_hash: The precomputed password hash. Must be non-empty.
            password_salt: The salt used to compute ``password_hash``. Must be non-empty.
            reset_password: Whether the user must change the password. Defaults to True.
            user_role_id: Optional role foreign key.

        Raises:
            ValueError: If ``password_hash`` or ``password_salt`` is missing or empty.
        """
        if not (password_hash and password_salt):
            raise ValueError("password_hash and password_salt are required for LocalUser")
        clean_email = email.strip()
        clean_username = username.strip()
        super().__init__(
            email=clean_email,
            username=clean_username,
            user_role_id=user_role_id,
            password_hash=password_hash,
            password_salt=password_salt,
            reset_password=reset_password,
            auth_method="LOCAL",
        )
|
|
1222
|
+
|
|
1223
|
+
|
|
1224
|
+
class OAuth2User(User):
    # A User that signs in through an external OAuth2 provider. It is mapped through
    # single-table polymorphism on ``auth_method == "OAUTH2"``. No password fields
    # are set, and reset_password is always False.
    __mapper_args__ = {"polymorphic_identity": "OAUTH2"}

    def __init__(
        self,
        *,
        email: str,
        username: str,
        user_role_id: Optional[int] = None,
    ) -> None:
        """Build an OAuth2 user with whitespace-stripped ``email`` and ``username``.

        The OAuth2 identifiers (``oauth2_client_id`` / ``oauth2_user_id``) are not set
        by this constructor.

        Args:
            email: The user's email address.
            username: The user's display name.
            user_role_id: Optional role foreign key.
        """
        clean_email = email.strip()
        clean_username = username.strip()
        super().__init__(
            email=clean_email,
            username=clean_username,
            user_role_id=user_role_id,
            reset_password=False,
            auth_method="OAUTH2",
        )
|
|
1243
|
+
|
|
1244
|
+
|
|
1216
1245
|
class PasswordResetToken(Base):
|
|
1217
1246
|
__tablename__ = "password_reset_tokens"
|
|
1218
1247
|
user_id: Mapped[int] = mapped_column(
|
phoenix/server/api/context.py
CHANGED
|
@@ -4,6 +4,7 @@ from functools import cached_property, partial
|
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
from typing import Any, Optional, cast
|
|
6
6
|
|
|
7
|
+
from starlette.datastructures import Secret
|
|
7
8
|
from starlette.requests import Request as StarletteRequest
|
|
8
9
|
from starlette.responses import Response as StarletteResponse
|
|
9
10
|
from strawberry.fastapi import BaseContext
|
|
@@ -128,11 +129,11 @@ class Context(BaseContext):
|
|
|
128
129
|
read_only: bool = False
|
|
129
130
|
locked: bool = False
|
|
130
131
|
auth_enabled: bool = False
|
|
131
|
-
secret: Optional[
|
|
132
|
+
secret: Optional[Secret] = None
|
|
132
133
|
token_store: Optional[TokenStore] = None
|
|
133
134
|
email_sender: Optional[EmailSender] = None
|
|
134
135
|
|
|
135
|
-
def get_secret(self) ->
|
|
136
|
+
def get_secret(self) -> Secret:
|
|
136
137
|
"""A type-safe way to get the application secret. Throws an error if the secret is not set.
|
|
137
138
|
|
|
138
139
|
Returns:
|
|
@@ -161,7 +162,7 @@ class Context(BaseContext):
|
|
|
161
162
|
raise ValueError("no response is set")
|
|
162
163
|
return response
|
|
163
164
|
|
|
164
|
-
async def is_valid_password(self, password:
|
|
165
|
+
async def is_valid_password(self, password: Secret, user: models.User) -> bool:
|
|
165
166
|
return (
|
|
166
167
|
(hash_ := user.password_hash) is not None
|
|
167
168
|
and (salt := user.password_salt) is not None
|
|
@@ -169,7 +170,7 @@ class Context(BaseContext):
|
|
|
169
170
|
)
|
|
170
171
|
|
|
171
172
|
@staticmethod
async def hash_password(password: Secret, salt: bytes) -> bytes:
    """Return ``compute_password_hash(password=password, salt=salt)`` without blocking.

    The hash is computed in the running event loop's default executor (the ``None``
    argument to ``run_in_executor``). Presumably this is because hashing is CPU-heavy.
    """
    job = partial(compute_password_hash, password=password, salt=salt)
    loop = get_running_loop()
    return await loop.run_in_executor(None, job)
|
|
175
176
|
|
|
@@ -9,6 +9,7 @@ from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func
|
|
|
9
9
|
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
10
10
|
from sqlalchemy.orm import joinedload
|
|
11
11
|
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
12
|
+
from starlette.datastructures import Secret
|
|
12
13
|
from strawberry import UNSET
|
|
13
14
|
from strawberry.relay import GlobalID
|
|
14
15
|
from strawberry.types import Info
|
|
@@ -23,11 +24,13 @@ from phoenix.auth import (
|
|
|
23
24
|
validate_email_format,
|
|
24
25
|
validate_password_format,
|
|
25
26
|
)
|
|
27
|
+
from phoenix.config import get_env_disable_basic_auth
|
|
26
28
|
from phoenix.db import enums, models
|
|
27
29
|
from phoenix.server.api.auth import IsAdmin, IsLocked, IsNotReadOnly
|
|
28
30
|
from phoenix.server.api.context import Context
|
|
29
|
-
from phoenix.server.api.exceptions import Conflict, NotFound, Unauthorized
|
|
31
|
+
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound, Unauthorized
|
|
30
32
|
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
|
|
33
|
+
from phoenix.server.api.types.AuthMethod import AuthMethod
|
|
31
34
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
32
35
|
from phoenix.server.api.types.User import User, to_gql_user
|
|
33
36
|
from phoenix.server.bearer_auth import PhoenixUser
|
|
@@ -40,9 +43,19 @@ logger = logging.getLogger(__name__)
|
|
|
40
43
|
class CreateUserInput:
    # GraphQL input for creating a user. ``auth_method`` selects local (password)
    # or OAuth2 authentication; ``password`` is only meaningful for local users.
    email: str
    username: str
    password: Optional[str] = UNSET
    role: UserRoleInput
    send_welcome_email: Optional[bool] = False
    auth_method: Optional[AuthMethod] = AuthMethod.LOCAL

    def __post_init__(self) -> None:
        """Reject inconsistent auth_method / password combinations.

        Any auth_method other than OAUTH2 (including an explicit null) is treated as
        local authentication.

        Raises:
            BadRequest: If an OAuth2 user is given a password, or if a local user is
                requested while basic auth is disabled or without a password.
        """
        if self.auth_method is not AuthMethod.OAUTH2:
            if get_env_disable_basic_auth():
                raise BadRequest("Basic auth is disabled: OAuth2 authentication only")
            if not self.password:
                raise BadRequest("Password is required for local authentication")
        elif self.password:
            raise BadRequest("Password is not allowed for OAuth2 authentication")
|
|
46
59
|
|
|
47
60
|
|
|
48
61
|
@strawberry.input
|
|
@@ -53,11 +66,16 @@ class PatchViewerInput:
|
|
|
53
66
|
|
|
54
67
|
def __post_init__(self) -> None:
    """Validate a PatchViewerInput.

    Raises:
        BadRequest: If neither ``new_username`` nor ``new_password`` is set. Also
            raised when a new password is requested while basic auth is disabled,
            without ``current_password``, or when it fails ``PASSWORD_REQUIREMENTS``.
    """
    if not self.new_username and not self.new_password:
        raise BadRequest("At least one field must be set")
    if self.new_password:
        if get_env_disable_basic_auth():
            raise BadRequest("Basic auth is disabled: OAuth2 authentication only")
        if not self.current_password:
            raise BadRequest("current_password is required when modifying password")
        try:
            PASSWORD_REQUIREMENTS.validate(self.new_password)
        except ValueError as e:
            # Chain explicitly so the failing requirement remains visible as the cause.
            raise BadRequest(str(e)) from e
|
|
61
79
|
|
|
62
80
|
|
|
63
81
|
@strawberry.input
|
|
@@ -69,9 +87,14 @@ class PatchUserInput:
|
|
|
69
87
|
|
|
70
88
|
def __post_init__(self) -> None:
    """Validate a PatchUserInput.

    Raises:
        BadRequest: If none of ``new_role``, ``new_username`` or ``new_password`` is
            set. Also raised when a new password is requested while basic auth is
            disabled, or when it fails ``PASSWORD_REQUIREMENTS``.
    """
    if not self.new_role and not self.new_username and not self.new_password:
        raise BadRequest("At least one field must be set")
    if self.new_password:
        if get_env_disable_basic_auth():
            raise BadRequest("Basic auth is disabled: OAuth2 authentication only")
        try:
            PASSWORD_REQUIREMENTS.validate(self.new_password)
        except ValueError as e:
            # Chain explicitly so the failing requirement remains visible as the cause.
            raise BadRequest(str(e)) from e
|
|
75
98
|
|
|
76
99
|
|
|
77
100
|
@strawberry.input
|
|
@@ -92,17 +115,24 @@ class UserMutationMixin:
|
|
|
92
115
|
info: Info[Context, None],
|
|
93
116
|
input: CreateUserInput,
|
|
94
117
|
) -> UserMutationPayload:
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
118
|
+
user: models.User
|
|
119
|
+
if input.auth_method is AuthMethod.OAUTH2:
|
|
120
|
+
user = models.OAuth2User(
|
|
121
|
+
email=input.email,
|
|
122
|
+
username=input.username,
|
|
123
|
+
)
|
|
124
|
+
else:
|
|
125
|
+
assert input.password
|
|
126
|
+
validate_email_format(input.email)
|
|
127
|
+
validate_password_format(input.password)
|
|
128
|
+
salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
|
|
129
|
+
password_hash = await info.context.hash_password(Secret(input.password), salt)
|
|
130
|
+
user = models.LocalUser(
|
|
131
|
+
email=input.email,
|
|
132
|
+
username=input.username,
|
|
133
|
+
password_hash=password_hash,
|
|
134
|
+
password_salt=salt,
|
|
135
|
+
)
|
|
106
136
|
async with AsyncExitStack() as stack:
|
|
107
137
|
session = await stack.enter_async_context(info.context.db())
|
|
108
138
|
user_role_id = await session.scalar(_select_role_id_by_name(input.role.value))
|
|
@@ -150,11 +180,12 @@ class UserMutationMixin:
|
|
|
150
180
|
raise NotFound(f"Role {input.new_role.value} not found")
|
|
151
181
|
user.user_role_id = user_role_id
|
|
152
182
|
if password := input.new_password:
|
|
153
|
-
if user.auth_method !=
|
|
183
|
+
if user.auth_method != "LOCAL":
|
|
154
184
|
raise Conflict("Cannot modify password for non-local user")
|
|
155
185
|
validate_password_format(password)
|
|
156
|
-
|
|
157
|
-
user.
|
|
186
|
+
salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
|
|
187
|
+
user.password_salt = salt
|
|
188
|
+
user.password_hash = await info.context.hash_password(Secret(password), salt)
|
|
158
189
|
user.reset_password = True
|
|
159
190
|
if username := input.new_username:
|
|
160
191
|
user.username = username
|
|
@@ -183,15 +214,16 @@ class UserMutationMixin:
|
|
|
183
214
|
raise NotFound("User not found")
|
|
184
215
|
stack.enter_context(session.no_autoflush)
|
|
185
216
|
if password := input.new_password:
|
|
186
|
-
if user.auth_method !=
|
|
217
|
+
if user.auth_method != "LOCAL":
|
|
187
218
|
raise Conflict("Cannot modify password for non-local user")
|
|
188
219
|
if not (
|
|
189
220
|
current_password := input.current_password
|
|
190
|
-
) or not await info.context.is_valid_password(current_password, user):
|
|
221
|
+
) or not await info.context.is_valid_password(Secret(current_password), user):
|
|
191
222
|
raise Conflict("Valid current password is required to modify password")
|
|
192
223
|
validate_password_format(password)
|
|
193
|
-
|
|
194
|
-
user.
|
|
224
|
+
salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
|
|
225
|
+
user.password_salt = salt
|
|
226
|
+
user.password_hash = await info.context.hash_password(Secret(password), salt)
|
|
195
227
|
user.reset_password = False
|
|
196
228
|
if username := input.new_username:
|
|
197
229
|
user.username = username
|
|
@@ -11,6 +11,7 @@ from sqlalchemy.orm import joinedload
|
|
|
11
11
|
from starlette.status import (
|
|
12
12
|
HTTP_204_NO_CONTENT,
|
|
13
13
|
HTTP_401_UNAUTHORIZED,
|
|
14
|
+
HTTP_403_FORBIDDEN,
|
|
14
15
|
HTTP_404_NOT_FOUND,
|
|
15
16
|
HTTP_422_UNPROCESSABLE_ENTITY,
|
|
16
17
|
HTTP_503_SERVICE_UNAVAILABLE,
|
|
@@ -31,8 +32,13 @@ from phoenix.auth import (
|
|
|
31
32
|
set_refresh_token_cookie,
|
|
32
33
|
validate_password_format,
|
|
33
34
|
)
|
|
34
|
-
from phoenix.config import
|
|
35
|
-
|
|
35
|
+
from phoenix.config import (
|
|
36
|
+
get_base_url,
|
|
37
|
+
get_env_disable_basic_auth,
|
|
38
|
+
get_env_disable_rate_limit,
|
|
39
|
+
get_env_host_root_path,
|
|
40
|
+
)
|
|
41
|
+
from phoenix.db import models
|
|
36
42
|
from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens
|
|
37
43
|
from phoenix.server.email.types import EmailSender
|
|
38
44
|
from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_ip_rate_limiter
|
|
@@ -68,6 +74,8 @@ router = APIRouter(prefix="/auth", include_in_schema=False, dependencies=auth_de
|
|
|
68
74
|
|
|
69
75
|
@router.post("/login")
|
|
70
76
|
async def login(request: Request) -> Response:
|
|
77
|
+
if get_env_disable_basic_auth():
|
|
78
|
+
raise HTTPException(status_code=HTTP_403_FORBIDDEN)
|
|
71
79
|
assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
|
|
72
80
|
assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
|
|
73
81
|
token_store: TokenStore = request.app.state.get_token_store()
|
|
@@ -192,6 +200,8 @@ async def refresh_tokens(request: Request) -> Response:
|
|
|
192
200
|
|
|
193
201
|
@router.post("/password-reset-email")
|
|
194
202
|
async def initiate_password_reset(request: Request) -> Response:
|
|
203
|
+
if get_env_disable_basic_auth():
|
|
204
|
+
raise HTTPException(status_code=HTTP_403_FORBIDDEN)
|
|
195
205
|
data = await request.json()
|
|
196
206
|
if not (email := data.get("email")):
|
|
197
207
|
raise MISSING_EMAIL
|
|
@@ -207,7 +217,7 @@ async def initiate_password_reset(request: Request) -> Response:
|
|
|
207
217
|
joinedload(models.User.password_reset_token).load_only(models.PasswordResetToken.id)
|
|
208
218
|
)
|
|
209
219
|
)
|
|
210
|
-
if user is None or user.auth_method !=
|
|
220
|
+
if user is None or user.auth_method != "LOCAL":
|
|
211
221
|
# Withold privileged information
|
|
212
222
|
return Response(status_code=HTTP_204_NO_CONTENT)
|
|
213
223
|
token_store: TokenStore = request.app.state.get_token_store()
|
|
@@ -230,6 +240,8 @@ async def initiate_password_reset(request: Request) -> Response:
|
|
|
230
240
|
|
|
231
241
|
@router.post("/password-reset")
|
|
232
242
|
async def reset_password(request: Request) -> Response:
|
|
243
|
+
if get_env_disable_basic_auth():
|
|
244
|
+
raise HTTPException(status_code=HTTP_403_FORBIDDEN)
|
|
233
245
|
data = await request.json()
|
|
234
246
|
if not (password := data.get("password")):
|
|
235
247
|
raise MISSING_PASSWORD
|
|
@@ -244,7 +256,7 @@ async def reset_password(request: Request) -> Response:
|
|
|
244
256
|
assert (user_id := claims.subject)
|
|
245
257
|
async with request.app.state.db() as session:
|
|
246
258
|
user = await session.scalar(select(models.User).filter_by(id=int(user_id)))
|
|
247
|
-
if user is None or user.auth_method !=
|
|
259
|
+
if user is None or user.auth_method != "LOCAL":
|
|
248
260
|
# Withold privileged information
|
|
249
261
|
return Response(status_code=HTTP_204_NO_CONTENT)
|
|
250
262
|
validate_password_format(password)
|