desdeo 1.2__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/__init__.py +8 -8
- desdeo/adm/ADMAfsar.py +551 -0
- desdeo/adm/ADMChen.py +414 -0
- desdeo/adm/BaseADM.py +119 -0
- desdeo/adm/__init__.py +11 -0
- desdeo/api/README.md +73 -0
- desdeo/api/__init__.py +15 -0
- desdeo/api/app.py +50 -0
- desdeo/api/config.py +90 -0
- desdeo/api/config.toml +64 -0
- desdeo/api/db.py +27 -0
- desdeo/api/db_init.py +85 -0
- desdeo/api/db_models.py +164 -0
- desdeo/api/malaga_db_init.py +27 -0
- desdeo/api/models/__init__.py +266 -0
- desdeo/api/models/archive.py +23 -0
- desdeo/api/models/emo.py +128 -0
- desdeo/api/models/enautilus.py +69 -0
- desdeo/api/models/gdm/gdm_aggregate.py +139 -0
- desdeo/api/models/gdm/gdm_base.py +69 -0
- desdeo/api/models/gdm/gdm_score_bands.py +114 -0
- desdeo/api/models/gdm/gnimbus.py +138 -0
- desdeo/api/models/generic.py +104 -0
- desdeo/api/models/generic_states.py +401 -0
- desdeo/api/models/nimbus.py +158 -0
- desdeo/api/models/preference.py +128 -0
- desdeo/api/models/problem.py +717 -0
- desdeo/api/models/reference_point_method.py +18 -0
- desdeo/api/models/session.py +49 -0
- desdeo/api/models/state.py +463 -0
- desdeo/api/models/user.py +52 -0
- desdeo/api/models/utopia.py +25 -0
- desdeo/api/routers/_EMO.backup +309 -0
- desdeo/api/routers/_NAUTILUS.py +245 -0
- desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
- desdeo/api/routers/_NIMBUS.py +765 -0
- desdeo/api/routers/__init__.py +5 -0
- desdeo/api/routers/emo.py +497 -0
- desdeo/api/routers/enautilus.py +237 -0
- desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
- desdeo/api/routers/gdm/gdm_base.py +420 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
- desdeo/api/routers/generic.py +233 -0
- desdeo/api/routers/nimbus.py +705 -0
- desdeo/api/routers/problem.py +307 -0
- desdeo/api/routers/reference_point_method.py +93 -0
- desdeo/api/routers/session.py +100 -0
- desdeo/api/routers/test.py +16 -0
- desdeo/api/routers/user_authentication.py +520 -0
- desdeo/api/routers/utils.py +187 -0
- desdeo/api/routers/utopia.py +230 -0
- desdeo/api/schema.py +100 -0
- desdeo/api/tests/__init__.py +0 -0
- desdeo/api/tests/conftest.py +151 -0
- desdeo/api/tests/test_enautilus.py +330 -0
- desdeo/api/tests/test_models.py +1179 -0
- desdeo/api/tests/test_routes.py +1075 -0
- desdeo/api/utils/_database.py +263 -0
- desdeo/api/utils/_logger.py +29 -0
- desdeo/api/utils/database.py +36 -0
- desdeo/api/utils/emo_database.py +40 -0
- desdeo/core.py +34 -0
- desdeo/emo/__init__.py +159 -0
- desdeo/emo/hooks/archivers.py +188 -0
- desdeo/emo/methods/EAs.py +541 -0
- desdeo/emo/methods/__init__.py +0 -0
- desdeo/emo/methods/bases.py +12 -0
- desdeo/emo/methods/templates.py +111 -0
- desdeo/emo/operators/__init__.py +1 -0
- desdeo/emo/operators/crossover.py +1282 -0
- desdeo/emo/operators/evaluator.py +114 -0
- desdeo/emo/operators/generator.py +459 -0
- desdeo/emo/operators/mutation.py +1224 -0
- desdeo/emo/operators/scalar_selection.py +202 -0
- desdeo/emo/operators/selection.py +1778 -0
- desdeo/emo/operators/termination.py +286 -0
- desdeo/emo/options/__init__.py +108 -0
- desdeo/emo/options/algorithms.py +435 -0
- desdeo/emo/options/crossover.py +164 -0
- desdeo/emo/options/generator.py +131 -0
- desdeo/emo/options/mutation.py +260 -0
- desdeo/emo/options/repair.py +61 -0
- desdeo/emo/options/scalar_selection.py +66 -0
- desdeo/emo/options/selection.py +127 -0
- desdeo/emo/options/templates.py +383 -0
- desdeo/emo/options/termination.py +143 -0
- desdeo/explanations/__init__.py +6 -0
- desdeo/explanations/explainer.py +100 -0
- desdeo/explanations/utils.py +90 -0
- desdeo/gdm/__init__.py +22 -0
- desdeo/gdm/gdmtools.py +45 -0
- desdeo/gdm/score_bands.py +114 -0
- desdeo/gdm/voting_rules.py +50 -0
- desdeo/mcdm/__init__.py +41 -0
- desdeo/mcdm/enautilus.py +338 -0
- desdeo/mcdm/gnimbus.py +484 -0
- desdeo/mcdm/nautili.py +345 -0
- desdeo/mcdm/nautilus.py +477 -0
- desdeo/mcdm/nautilus_navigator.py +656 -0
- desdeo/mcdm/nimbus.py +417 -0
- desdeo/mcdm/pareto_navigator.py +269 -0
- desdeo/mcdm/reference_point_method.py +186 -0
- desdeo/problem/__init__.py +83 -0
- desdeo/problem/evaluator.py +561 -0
- desdeo/problem/external/__init__.py +18 -0
- desdeo/problem/external/core.py +356 -0
- desdeo/problem/external/pymoo_provider.py +266 -0
- desdeo/problem/external/runtime.py +44 -0
- desdeo/problem/gurobipy_evaluator.py +562 -0
- desdeo/problem/infix_parser.py +341 -0
- desdeo/problem/json_parser.py +944 -0
- desdeo/problem/pyomo_evaluator.py +487 -0
- desdeo/problem/schema.py +1829 -0
- desdeo/problem/simulator_evaluator.py +348 -0
- desdeo/problem/sympy_evaluator.py +244 -0
- desdeo/problem/testproblems/__init__.py +88 -0
- desdeo/problem/testproblems/benchmarks_server.py +120 -0
- desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
- desdeo/problem/testproblems/cake_problem.py +185 -0
- desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
- desdeo/problem/testproblems/dtlz2_problem.py +102 -0
- desdeo/problem/testproblems/forest_problem.py +283 -0
- desdeo/problem/testproblems/knapsack_problem.py +163 -0
- desdeo/problem/testproblems/mcwb_problem.py +831 -0
- desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
- desdeo/problem/testproblems/momip_problem.py +172 -0
- desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
- desdeo/problem/testproblems/nimbus_problem.py +143 -0
- desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
- desdeo/problem/testproblems/re_problem.py +492 -0
- desdeo/problem/testproblems/river_pollution_problems.py +440 -0
- desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
- desdeo/problem/testproblems/simple_problem.py +351 -0
- desdeo/problem/testproblems/simulator_problem.py +92 -0
- desdeo/problem/testproblems/single_objective.py +289 -0
- desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
- desdeo/problem/testproblems/zdt_problem.py +274 -0
- desdeo/problem/utils.py +245 -0
- desdeo/tools/GenerateReferencePoints.py +181 -0
- desdeo/tools/__init__.py +120 -0
- desdeo/tools/desc_gen.py +22 -0
- desdeo/tools/generics.py +165 -0
- desdeo/tools/group_scalarization.py +3090 -0
- desdeo/tools/gurobipy_solver_interfaces.py +258 -0
- desdeo/tools/indicators_binary.py +117 -0
- desdeo/tools/indicators_unary.py +362 -0
- desdeo/tools/interaction_schema.py +38 -0
- desdeo/tools/intersection.py +54 -0
- desdeo/tools/iterative_pareto_representer.py +99 -0
- desdeo/tools/message.py +265 -0
- desdeo/tools/ng_solver_interfaces.py +199 -0
- desdeo/tools/non_dominated_sorting.py +134 -0
- desdeo/tools/patterns.py +283 -0
- desdeo/tools/proximal_solver.py +99 -0
- desdeo/tools/pyomo_solver_interfaces.py +477 -0
- desdeo/tools/reference_vectors.py +229 -0
- desdeo/tools/scalarization.py +2065 -0
- desdeo/tools/scipy_solver_interfaces.py +454 -0
- desdeo/tools/score_bands.py +627 -0
- desdeo/tools/utils.py +388 -0
- desdeo/tools/visualizations.py +67 -0
- desdeo/utopia_stuff/__init__.py +0 -0
- desdeo/utopia_stuff/data/1.json +15 -0
- desdeo/utopia_stuff/data/2.json +13 -0
- desdeo/utopia_stuff/data/3.json +15 -0
- desdeo/utopia_stuff/data/4.json +17 -0
- desdeo/utopia_stuff/data/5.json +15 -0
- desdeo/utopia_stuff/from_json.py +40 -0
- desdeo/utopia_stuff/reinit_user.py +38 -0
- desdeo/utopia_stuff/utopia_db_init.py +212 -0
- desdeo/utopia_stuff/utopia_problem.py +403 -0
- desdeo/utopia_stuff/utopia_problem_old.py +415 -0
- desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
- desdeo-2.1.0.dist-info/METADATA +186 -0
- desdeo-2.1.0.dist-info/RECORD +180 -0
- {desdeo-1.2.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
- desdeo-2.1.0.dist-info/licenses/LICENSE +21 -0
- desdeo-1.2.dist-info/METADATA +0 -16
- desdeo-1.2.dist-info/RECORD +0 -4
|
@@ -0,0 +1,520 @@
|
|
|
1
|
+
"""This module contains the functions for user authentication."""
|
|
2
|
+
|
|
3
|
+
import uuid
|
|
4
|
+
from datetime import UTC, datetime, timedelta
|
|
5
|
+
from typing import Annotated
|
|
6
|
+
|
|
7
|
+
import bcrypt
|
|
8
|
+
from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, Security, status
|
|
9
|
+
from fastapi.responses import JSONResponse
|
|
10
|
+
from fastapi.security import (
|
|
11
|
+
APIKeyCookie,
|
|
12
|
+
HTTPAuthorizationCredentials,
|
|
13
|
+
HTTPBearer,
|
|
14
|
+
OAuth2PasswordBearer,
|
|
15
|
+
OAuth2PasswordRequestForm,
|
|
16
|
+
)
|
|
17
|
+
from jose import ExpiredSignatureError, JWTError, jwt
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from sqlmodel import Session, select
|
|
20
|
+
|
|
21
|
+
from desdeo.api import AuthConfig
|
|
22
|
+
from desdeo.api.db import get_session
|
|
23
|
+
from desdeo.api.models import User, UserPublic, UserRole
|
|
24
|
+
|
|
25
|
+
# Router that collects every authentication endpoint defined in this module.
router = APIRouter()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Tokens(BaseModel):
    """Pair of JWT tokens handed to a client after successful authentication."""

    # Token presented on API requests to identify the user.
    access_token: str
    # Token that can be exchanged at `/refresh` for a new access token.
    refresh_token: str
    # Authentication scheme of the tokens; `generate_tokens` sets this to "bearer".
    token_type: str
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Security dependencies that pull a raw access token out of an incoming request.
# Both use auto_error=False: a missing credential yields None instead of an immediate
# error, so `get_current_user` can fall back from the header to the cookie and raise
# its own 401 when neither is present.

# Reads "Authorization: Bearer <token>"; "login" is the token endpoint advertised in OpenAPI.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login", auto_error=False)
# Reads the token from a cookie named "access_token".
# NOTE(review): `login` in this module only sets a "refresh_token" cookie — confirm where
# an "access_token" cookie is expected to come from.
cookie_scheme = APIKeyCookie(name="access_token", auto_error=False)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Tell whether `plain_password` corresponds to the bcrypt hash `hashed_password`.

    Both strings are UTF-8 encoded before being passed to `bcrypt.checkpw`.

    Args:
        plain_password (str): the candidate password in clear text.
        hashed_password (str): a bcrypt hash, such as one produced by `get_password_hash`.

    Returns:
        bool: True if the password matches the hash, False otherwise.
    """
    candidate = plain_password.encode("utf-8")
    stored = hashed_password.encode("utf-8")
    return bcrypt.checkpw(password=candidate, hashed_password=stored)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_password_hash(password: str) -> str:
    """Produce a salted bcrypt hash of `password`.

    A new salt is generated on every call, so hashing the same password twice gives
    different strings; compare passwords with `verify_password`, not by hashing again.

    NOTE(review): bcrypt only considers the first 72 bytes of its input; depending on the
    installed bcrypt version, longer passwords are truncated or rejected — confirm.

    Args:
        password (str): the clear-text password.

    Returns:
        str: the bcrypt hash, decoded from bytes as UTF-8.
    """
    encoded = password.encode("utf-8")
    hashed = bcrypt.hashpw(password=encoded, salt=bcrypt.gensalt())
    return hashed.decode("utf-8")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_user(session: Session, username: str) -> User | None:
    """Look up a user by username.

    Args:
        session (Session): the database session to query.
        username (str): the username to search for.

    Returns:
        User | None: the first `User` whose username equals `username`, or None if
            there is no such user.
    """
    return session.exec(select(User).where(User.username == username)).first()
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def authenticate_user(session: Session, username: str, password: str) -> User | None:
    """Return the user named `username` if `password` is correct for them.

    NOTE(review): when the username is unknown no bcrypt check runs, so such requests
    return faster than wrong-password requests; this may allow username enumeration
    by timing — consider whether that matters for this deployment.

    Args:
        session (Session): database session.
        username (str): the username of the user.
        password (str): the clear-text password to check against the stored hash.

    Returns:
        User | None: the matching user, or None if the username is unknown or the
            password does not match.
    """
    user = get_user(session, username)

    if user and verify_password(password, user.password_hash):
        return user

    return None
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def get_current_user(
    session: Annotated[Session, Depends(get_session)],
    header_token: Annotated[str | None, Security(oauth2_scheme)] = None,
    cookie_token: Annotated[str | None, Security(cookie_scheme)] = None,
) -> User:
    """Resolve the user that a JWT access token belongs to.

    Used as a FastAPI dependency by endpoints that need the current user. The token is
    taken from the `Authorization: Bearer` header if present, otherwise from the
    `access_token` cookie.

    NOTE(review): `generate_tokens` puts identical claims into the access and refresh
    tokens, so a refresh token is also accepted here. Consider adding a token-type claim.

    Args:
        session (Annotated[Session, Depends(get_session)]): a database session.
        header_token (str | None): token extracted by `oauth2_scheme` from the request header.
        cookie_token (str | None): token extracted by `cookie_scheme` from the request cookies.

    Returns:
        User: the user named by the token's `sub` claim.

    Raises:
        HTTPException: 401 in any of these cases:
            - no token was supplied;
            - the token cannot be decoded or has expired;
            - it lacks the `sub` or `exp` claim;
            - no such user exists.
    """
    token = header_token or cookie_token

    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    if not token:
        raise credentials_exception

    try:
        payload = jwt.decode(token, AuthConfig.authjwt_secret_key, algorithms=[AuthConfig.authjwt_algorithm])
    except ExpiredSignatureError:
        raise credentials_exception from None
    except JWTError as err:
        # Chain the caught error instance (not the JWTError class) so the cause is preserved.
        raise credentials_exception from err

    username = payload.get("sub")
    # After decoding, `exp` is a numeric POSIX timestamp, not a datetime.
    expire_time: float | None = payload.get("exp")

    # Reject tokens without the required claims. The expiry re-check also covers tokens
    # that carry no `exp` claim at all.
    if username is None or expire_time is None or expire_time < datetime.now(UTC).timestamp():
        raise credentials_exception

    user = get_user(session, username=username)

    if user is None:
        raise credentials_exception

    return user
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def create_jwt_token(
    data: dict,
    expires_delta: timedelta,
    algorithm: str = AuthConfig.authjwt_algorithm,
    secret_key: str = AuthConfig.authjwt_secret_key,
) -> str:
    """Encode `data` as a signed JWT that expires `expires_delta` from now.

    `data` itself is not modified. A copy is extended with an `exp` claim and a random
    `jti` (unique token id) claim, which override any keys of the same name in `data`.
    The defaults for `algorithm` and `secret_key` are read from `AuthConfig` once, when
    this function is defined.

    Args:
        data (dict): the claims to encode in the token.
        expires_delta (timedelta): how long from now until the token expires.
        algorithm (str): the signing algorithm. Defaults to `AuthConfig.authjwt_algorithm`.
        secret_key (str): the signing key. Defaults to `AuthConfig.authjwt_secret_key`.

    Returns:
        str: the encoded JWT.
    """
    claims = {
        **data,
        "exp": datetime.now(UTC) + expires_delta,
        # Random unique identifier so that every issued token is distinct.
        "jti": str(uuid.uuid4()),
    }
    return jwt.encode(claims, secret_key, algorithm=algorithm)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def create_access_token(data: dict, expiration_time: int = AuthConfig.authjwt_access_token_expires) -> str:
    """Encode `data` into a JWT access token.

    Args:
        data (dict): the claims to encode in the token.
        expiration_time (int): lifetime of the access token in minutes.
            Defaults to `AuthConfig.authjwt_access_token_expires`.

    Returns:
        str: the encoded access token.
    """
    lifetime = timedelta(minutes=expiration_time)
    return create_jwt_token(data, lifetime)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def create_refresh_token(data: dict, expiration_time: int = AuthConfig.authjwt_refresh_token_expires) -> str:
    """Encode `data` into a JWT refresh token.

    Args:
        data (dict): the claims to encode in the token.
        expiration_time (int): lifetime of the refresh token in minutes.
            Defaults to `AuthConfig.authjwt_refresh_token_expires`.

    Returns:
        str: the encoded refresh token.
    """
    return create_jwt_token(data, timedelta(minutes=expiration_time))
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def generate_tokens(data: dict) -> Tokens:
    """Create an access token and a refresh token that both carry `data`.

    Note:
        Lifetimes are the defaults of `create_access_token` and `create_refresh_token`,
        which come from `AuthConfig`.

    Args:
        data (dict): the claims to encode in both tokens.

    Returns:
        Tokens: the access and refresh tokens, with token type "bearer".
    """
    return Tokens(
        access_token=create_access_token(data),
        refresh_token=create_refresh_token(data),
        token_type="bearer",  # noqa: S106
    )
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def validate_refresh_token(
    refresh_token: str,
    session: Annotated[Session, Depends(get_session)],
    algorithm: str = AuthConfig.authjwt_algorithm,
    secret_key: str = AuthConfig.authjwt_secret_key,
) -> User:
    """Validate a refresh token and return the user it belongs to.

    Args:
        refresh_token (str): the refresh token to validate.
        session (Annotated[Session, Depends(get_session)]): the database session.
        algorithm (str): the algorithm used to decode the JWT token.
            Defaults to `AuthConfig.authjwt_algorithm`.
        secret_key (str): the secret key used to decode the JWT token.
            Defaults to `AuthConfig.authjwt_secret_key`.

    Returns:
        User: the user named by the token's `sub` claim.

    Raises:
        HTTPException: 401 in any of these cases:
            - the token cannot be decoded or has expired;
            - it lacks the `sub` or `exp` claim;
            - the user does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired refresh token",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(refresh_token, secret_key, algorithms=[algorithm])
    except JWTError:
        # Covers malformed tokens, bad signatures and expired tokens
        # (ExpiredSignatureError subclasses JWTError). Other, unexpected errors are
        # no longer disguised as a 401. The cause is deliberately not exposed.
        raise credentials_exception from None

    username = payload.get("sub")
    # After decoding, `exp` is a numeric POSIX timestamp, not a datetime.
    expire_time: float | None = payload.get("exp")

    if username is None or expire_time is None or expire_time < datetime.now(UTC).timestamp():
        raise credentials_exception

    # The token is only valid if its user still exists.
    user = get_user(session, username=username)
    if user is None:
        raise credentials_exception

    return user
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def add_user_to_database(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    role: UserRole,
    session: Annotated[Session, Depends(get_session)],
) -> None:
    """Store a new user with the given role, saving only a hash of the submitted password.

    NOTE(review): the duplicate check and the insert are separate steps, so two concurrent
    requests could both pass the check. Whether the database then rejects the second insert
    depends on the `User` schema — confirm a unique constraint on `username`.

    Args:
        form_data (Annotated[OAuth2PasswordRequestForm, Depends()]): form carrying the
            username and password of the new user.
        role (UserRole): role assigned to the new user.
        session (Annotated[Session, Depends(get_session)]): database session.

    Returns:
        None

    Raises:
        HTTPException: 409 if the username is already taken. 500 if the user cannot be
            found in the database after committing.
    """
    username, password = form_data.username, form_data.password

    # Refuse duplicate usernames up front.
    if get_user(session=session, username=username):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username already taken.")

    new_user = User(username=username, password_hash=get_password_hash(password=password), role=role)
    session.add(new_user)
    session.commit()
    session.refresh(new_user)

    # Sanity check: the freshly committed user must be retrievable again.
    if not get_user(session=session, username=username):
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to add user into database.",
        )
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
@router.get("/user_info")
def get_current_user_info(user: Annotated[User, Depends(get_current_user)]) -> UserPublic:
    """Return the public profile of the authenticated user.

    Args:
        user (Annotated[User, Depends(get_current_user)]): the authenticated user,
            resolved by `get_current_user`.

    Returns:
        UserPublic: the user. FastAPI uses the `UserPublic` return annotation as the
            response model.
    """
    return user
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
@router.post("/login", response_model=Tokens)
def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    session: Annotated[Session, Depends(get_session)],
    cookie_max_age: int = AuthConfig.authjwt_refresh_token_expires,
):
    """Authenticate a user and issue an access token and a refresh token.

    The JSON body contains both tokens plus `token_type`, matching the declared `Tokens`
    response model. The refresh token is also stored in an HTTP-only `refresh_token` cookie.

    Cookie attributes depend on `AuthConfig.cookie_domain`:
      * empty: host-only, not `Secure`, `SameSite=Lax` (works over plain HTTP);
      * otherwise: `Secure`, `SameSite=None`, scoped to that domain so the API and the
        web UI can share the cookie.

    NOTE(review): returning the refresh token in the JSON body makes it readable by
    frontend code, which partly defeats the HTTP-only cookie.

    Args:
        form_data (Annotated[OAuth2PasswordRequestForm, Depends()]):
            The form data to authenticate the user.
        session (Annotated[Session, Depends(get_session)]): The database session.
        cookie_max_age (int): lifetime of the refresh-token cookie in minutes.
            Defaults to `AuthConfig.authjwt_refresh_token_expires`.
            NOTE(review): as an endpoint parameter this is exposed as a query parameter.

    Returns:
        JSONResponse: the tokens, with the refresh-token cookie set.

    Raises:
        HTTPException: 401 if the username or password is incorrect.
    """
    user = authenticate_user(session, form_data.username, form_data.password)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    tokens = generate_tokens({"id": user.id, "sub": user.username})

    response = JSONResponse(
        content={
            "access_token": tokens.access_token,
            "refresh_token": tokens.refresh_token,
            "token_type": tokens.token_type,
        }
    )

    # A configured cookie domain means the API and web UI run on different (sub)domains.
    shared_domain = AuthConfig.cookie_domain != ""
    response.set_cookie(
        key="refresh_token",
        value=tokens.refresh_token,
        httponly=True,  # not readable from frontend JavaScript
        secure=shared_domain,  # HTTPS-only when shared; plain HTTP allowed for local setups
        samesite="none" if shared_domain else "lax",  # "none" requires secure=True
        max_age=cookie_max_age * 60,  # minutes -> seconds
        path="/",
        domain=AuthConfig.cookie_domain if shared_domain else None,
    )

    return response
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
@router.post("/logout")
def logout() -> JSONResponse:
    """Log the current user out by expiring the `refresh_token` cookie set by `login`.

    Browsers only replace a cookie whose name, domain and path match. The deletion
    therefore uses the same attributes `login` uses when setting the cookie. Without
    that, a cookie scoped to `AuthConfig.cookie_domain` would survive logout.

    Args:
        None

    Returns:
        JSONResponse: a 200 response with {"message": "logged out"} that clears the cookie.
    """
    response = JSONResponse(content={"message": "logged out"}, status_code=status.HTTP_200_OK)

    shared_domain = AuthConfig.cookie_domain != ""
    response.delete_cookie(
        key="refresh_token",
        path="/",
        domain=AuthConfig.cookie_domain if shared_domain else None,
        secure=shared_domain,
        httponly=True,
        samesite="none" if shared_domain else "lax",
    )

    return response
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
@router.post("/refresh")
def refresh_access_token(
    request: Response,
    session: Annotated[Session, Depends(get_session)],
    refresh_token: Annotated[str | None, Cookie()] = None,
):
    """Issue a new access token from the refresh token sent in the `refresh_token` cookie.

    Args:
        request (Response): unused. NOTE(review): despite its name this parameter is
            annotated as `Response`, so FastAPI injects the outgoing response object,
            not the incoming request.
        session (Annotated[Session, Depends(get_session)]): the database session.
        refresh_token (Annotated[str | None, Cookie()]): value of the `refresh_token`
            cookie sent with the request, if any.

    Returns:
        dict: {"access_token": <newly created access token>}.

    Raises:
        HTTPException: 401 if the cookie is missing. `validate_refresh_token` also raises
            401 if the token is invalid or expired, or its user is unknown.
    """
    if refresh_token:
        user = validate_refresh_token(refresh_token, session)
        claims = {"id": user.id, "sub": user.username}
        return {"access_token": create_access_token(claims)}

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Missing refresh token.",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
@router.post("/add_new_dm")
def add_new_dm(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    session: Annotated[Session, Depends(get_session)],
) -> JSONResponse:
    """Register a new user with the decision maker role. No login is required.

    Args:
        form_data (Annotated[OAuth2PasswordRequestForm, Depends()]): username and password
            of the user to create.
        session (Annotated[Session, Depends(get_session)]): the database session.

    Returns:
        JSONResponse: a 201 response confirming the creation.

    Raises:
        HTTPException: raised by `add_user_to_database` if the username is already in use
            or the user could not be saved.
    """
    add_user_to_database(form_data=form_data, role=UserRole.dm, session=session)

    message = 'User with role "decision maker" created.'
    return JSONResponse(content={"message": message}, status_code=status.HTTP_201_CREATED)
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
@router.post("/add_new_analyst")
def add_new_analyst(
    user: Annotated[User, Depends(get_current_user)],
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    session: Annotated[Session, Depends(get_session)],
) -> JSONResponse:
    """Register a new user with the analyst role. Only analysts and admins may do this.

    Args:
        user (Annotated[User, Depends(get_current_user)]): the logged-in user making the
            request. Must have the role "analyst" or "admin".
        form_data (Annotated[OAuth2PasswordRequestForm, Depends()]): username and password
            of the user to create.
        session (Annotated[Session, Depends(get_session)]): the database session.

    Returns:
        JSONResponse: a 201 response confirming the creation.

    Raises:
        HTTPException: 401 if the logged-in user is neither an analyst nor an admin.
            NOTE(review): 403 Forbidden would be the conventional status here; it is kept
            as 401 for compatibility with existing clients. `add_user_to_database` also
            raises if the username is taken or the user could not be saved.
    """
    permitted_roles = (UserRole.analyst, UserRole.admin)
    if user.role not in permitted_roles:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Logged in user has insufficient rights.",
        )

    add_user_to_database(form_data=form_data, role=UserRole.analyst, session=session)

    message = 'User with role "analyst" created.'
    return JSONResponse(content={"message": message}, status_code=status.HTTP_201_CREATED)
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""A selection of utilities for handling routers and data therein.
|
|
2
|
+
|
|
3
|
+
NOTE: No routers should be defined in this file!
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Annotated
|
|
8
|
+
|
|
9
|
+
from fastapi import Depends, HTTPException, status
|
|
10
|
+
from sqlmodel import Session, select
|
|
11
|
+
|
|
12
|
+
from desdeo.api.db import get_session
|
|
13
|
+
from desdeo.api.models import (
|
|
14
|
+
ENautilusStepRequest,
|
|
15
|
+
InteractiveSessionDB,
|
|
16
|
+
ProblemDB,
|
|
17
|
+
RPMSolveRequest,
|
|
18
|
+
StateDB,
|
|
19
|
+
User,
|
|
20
|
+
)
|
|
21
|
+
from desdeo.api.routers.user_authentication import get_current_user
|
|
22
|
+
|
|
23
|
+
# Request models accepted by the fetch_* helpers below. They are expected to carry the
# `session_id` and `problem_id` fields those helpers read.
RequestType = RPMSolveRequest | ENautilusStepRequest
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def fetch_interactive_session(user: User, request: RequestType, session: Session) -> InteractiveSessionDB | None:
    """Look up the interactive session a request refers to.

    The id comes from `request.session_id` when given. Otherwise the user's currently
    active interactive session (`user.active_session_id`) is used.

    NOTE(review): the lookup filters on the session id only. An explicitly requested
    session is not checked to belong to `user` — confirm whether callers enforce ownership.

    Args:
        user (User): the user whose active interactive session is the fallback.
        request (RequestType): request that may name a specific interactive session.
        session (Session): the database session (not the interactive session) to query.

    Raises:
        HTTPException: 404 if an explicitly requested interactive session does not exist.

    Returns:
        InteractiveSessionDB | None: the interactive session. Returns None if no id was
            requested and the user's active session cannot be found (e.g. it is None).
    """
    explicitly_requested = request.session_id is not None
    target_id = request.session_id if explicitly_requested else user.active_session_id

    lookup = select(InteractiveSessionDB).where(InteractiveSessionDB.id == target_id)
    interactive_session = session.exec(lookup).first()

    # Only a missing *explicitly requested* session is an error; a missing active one is fine.
    if explicitly_requested and interactive_session is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Could not find interactive session with id={request.session_id}.",
        )

    return interactive_session
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def fetch_user_problem(user: User, request: RequestType, session: Session) -> ProblemDB:
    """Look up the `ProblemDB` owned by `user` whose id is `request.problem_id`.

    Args:
        user (User): owner of the problem; only problems with `user_id == user.id` are considered.
        request (RequestType): request carrying the id of the wanted problem in `request.problem_id`.
        session (Session): database session used for the query.

    Raises:
        HTTPException: 404 when `user` owns no problem with id `request.problem_id`.

    Returns:
        ProblemDB: the matching problem row.
    """
    # Both conditions must hold: the problem belongs to this user AND has the requested id.
    owned_by_user = ProblemDB.user_id == user.id
    has_requested_id = ProblemDB.id == request.problem_id
    problem = session.exec(select(ProblemDB).where(owned_by_user, has_requested_id)).first()

    if problem is not None:
        return problem

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
    )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def fetch_parent_state(
    user: User, request: RequestType, session: Session, interactive_session: InteractiveSessionDB | None = None
) -> StateDB | None:
    """Resolve which `StateDB` should act as the parent of a new state or operation.

    Resolution order:

    1. If `request.parent_state_id` is set, the `StateDB` with that id is loaded from the
       database. Not finding it is an error, because the caller asked for that state explicitly.
    2. Otherwise, if `interactive_session` is given and has at least one state, the last
       entry of `interactive_session.states` is used.
    3. Otherwise, there is no parent and `None` is returned.

    Args:
        user (User): the user for which the parent state is fetched. It is not referenced by
            the lookup itself.
        request (RequestType): request that may carry an explicit `parent_state_id`.
        session (Session): database session used to load an explicitly requested state.
        interactive_session (InteractiveSessionDB | None, optional): interactive session whose
            latest state is the fallback parent. Defaults to None.

    Raises:
        HTTPException: 404 when `request.parent_state_id` is set but no `StateDB` with that id exists.

    Returns:
        StateDB | None: the explicitly requested state, else the session's latest state, else `None`.
    """
    # NOTE(review): the explicit-id lookup below is not filtered by `user`; confirm whether
    # states of other users should be reachable here.
    if request.parent_state_id is not None:
        requested = session.exec(select(StateDB).where(StateDB.id == request.parent_state_id)).first()
        if requested is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
            )
        return requested

    # No explicit id: fall back to the most recently added state of the interactive session, if any.
    if interactive_session is None or len(interactive_session.states) == 0:
        return None

    return interactive_session.states[-1]
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@dataclass(frozen=True)
class SessionContext:
    """Immutable bundle of the objects an endpoint typically needs to serve a request.

    Instances are built by `get_session_context`; being frozen, their fields cannot be
    reassigned after construction.
    """

    # The user making the request.
    user: User
    # Database session (not the interactive session) used for queries.
    db_session: Session
    # The problem the request refers to.
    problem_db: ProblemDB
    # The interactive session in use, or None if none could be determined.
    interactive_session: InteractiveSessionDB | None
    # The state to use as parent of a new state, or None if there is none.
    parent_state: StateDB | None
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def get_session_context(
    request: RequestType,
    user: Annotated[User, Depends(get_current_user)],
    db_session: Annotated[Session, Depends(get_session)],
) -> SessionContext:
    """Build the `SessionContext` for the current request; intended to be used as a dependency.

    The problem is fetched first, then the interactive session, and finally the parent
    state. The parent state is resolved against the interactive session found in the
    previous step.

    Args:
        request (RequestType): request from which problem, interactive session, and parent state ids are read.
        user (Annotated[User, Depends(get_current_user)]): the current user (injected dependency).
        db_session (Annotated[Session, Depends(get_session)]): the current database session (injected dependency).

    Returns:
        SessionContext: context holding the `User`, `Session`, `ProblemDB`,
            `InteractiveSessionDB` (or None), and parent `StateDB` (or None).
    """
    problem = fetch_user_problem(user, request, db_session)
    active_session = fetch_interactive_session(user, request, db_session)

    # Keyword arguments are evaluated left to right, so the parent state is fetched last,
    # after the problem and interactive session lookups above.
    return SessionContext(
        user=user,
        db_session=db_session,
        problem_db=problem,
        interactive_session=active_session,
        parent_state=fetch_parent_state(user, request, db_session, interactive_session=active_session),
    )
|