desdeo 1.2__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182) hide show
  1. desdeo/__init__.py +8 -8
  2. desdeo/adm/ADMAfsar.py +551 -0
  3. desdeo/adm/ADMChen.py +414 -0
  4. desdeo/adm/BaseADM.py +119 -0
  5. desdeo/adm/__init__.py +11 -0
  6. desdeo/api/README.md +73 -0
  7. desdeo/api/__init__.py +15 -0
  8. desdeo/api/app.py +50 -0
  9. desdeo/api/config.py +90 -0
  10. desdeo/api/config.toml +64 -0
  11. desdeo/api/db.py +27 -0
  12. desdeo/api/db_init.py +85 -0
  13. desdeo/api/db_models.py +164 -0
  14. desdeo/api/malaga_db_init.py +27 -0
  15. desdeo/api/models/__init__.py +266 -0
  16. desdeo/api/models/archive.py +23 -0
  17. desdeo/api/models/emo.py +128 -0
  18. desdeo/api/models/enautilus.py +69 -0
  19. desdeo/api/models/gdm/gdm_aggregate.py +139 -0
  20. desdeo/api/models/gdm/gdm_base.py +69 -0
  21. desdeo/api/models/gdm/gdm_score_bands.py +114 -0
  22. desdeo/api/models/gdm/gnimbus.py +138 -0
  23. desdeo/api/models/generic.py +104 -0
  24. desdeo/api/models/generic_states.py +401 -0
  25. desdeo/api/models/nimbus.py +158 -0
  26. desdeo/api/models/preference.py +128 -0
  27. desdeo/api/models/problem.py +717 -0
  28. desdeo/api/models/reference_point_method.py +18 -0
  29. desdeo/api/models/session.py +49 -0
  30. desdeo/api/models/state.py +463 -0
  31. desdeo/api/models/user.py +52 -0
  32. desdeo/api/models/utopia.py +25 -0
  33. desdeo/api/routers/_EMO.backup +309 -0
  34. desdeo/api/routers/_NAUTILUS.py +245 -0
  35. desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
  36. desdeo/api/routers/_NIMBUS.py +765 -0
  37. desdeo/api/routers/__init__.py +5 -0
  38. desdeo/api/routers/emo.py +497 -0
  39. desdeo/api/routers/enautilus.py +237 -0
  40. desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
  41. desdeo/api/routers/gdm/gdm_base.py +420 -0
  42. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
  43. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
  44. desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
  45. desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
  46. desdeo/api/routers/generic.py +233 -0
  47. desdeo/api/routers/nimbus.py +705 -0
  48. desdeo/api/routers/problem.py +307 -0
  49. desdeo/api/routers/reference_point_method.py +93 -0
  50. desdeo/api/routers/session.py +100 -0
  51. desdeo/api/routers/test.py +16 -0
  52. desdeo/api/routers/user_authentication.py +520 -0
  53. desdeo/api/routers/utils.py +187 -0
  54. desdeo/api/routers/utopia.py +230 -0
  55. desdeo/api/schema.py +100 -0
  56. desdeo/api/tests/__init__.py +0 -0
  57. desdeo/api/tests/conftest.py +151 -0
  58. desdeo/api/tests/test_enautilus.py +330 -0
  59. desdeo/api/tests/test_models.py +1179 -0
  60. desdeo/api/tests/test_routes.py +1075 -0
  61. desdeo/api/utils/_database.py +263 -0
  62. desdeo/api/utils/_logger.py +29 -0
  63. desdeo/api/utils/database.py +36 -0
  64. desdeo/api/utils/emo_database.py +40 -0
  65. desdeo/core.py +34 -0
  66. desdeo/emo/__init__.py +159 -0
  67. desdeo/emo/hooks/archivers.py +188 -0
  68. desdeo/emo/methods/EAs.py +541 -0
  69. desdeo/emo/methods/__init__.py +0 -0
  70. desdeo/emo/methods/bases.py +12 -0
  71. desdeo/emo/methods/templates.py +111 -0
  72. desdeo/emo/operators/__init__.py +1 -0
  73. desdeo/emo/operators/crossover.py +1282 -0
  74. desdeo/emo/operators/evaluator.py +114 -0
  75. desdeo/emo/operators/generator.py +459 -0
  76. desdeo/emo/operators/mutation.py +1224 -0
  77. desdeo/emo/operators/scalar_selection.py +202 -0
  78. desdeo/emo/operators/selection.py +1778 -0
  79. desdeo/emo/operators/termination.py +286 -0
  80. desdeo/emo/options/__init__.py +108 -0
  81. desdeo/emo/options/algorithms.py +435 -0
  82. desdeo/emo/options/crossover.py +164 -0
  83. desdeo/emo/options/generator.py +131 -0
  84. desdeo/emo/options/mutation.py +260 -0
  85. desdeo/emo/options/repair.py +61 -0
  86. desdeo/emo/options/scalar_selection.py +66 -0
  87. desdeo/emo/options/selection.py +127 -0
  88. desdeo/emo/options/templates.py +383 -0
  89. desdeo/emo/options/termination.py +143 -0
  90. desdeo/explanations/__init__.py +6 -0
  91. desdeo/explanations/explainer.py +100 -0
  92. desdeo/explanations/utils.py +90 -0
  93. desdeo/gdm/__init__.py +22 -0
  94. desdeo/gdm/gdmtools.py +45 -0
  95. desdeo/gdm/score_bands.py +114 -0
  96. desdeo/gdm/voting_rules.py +50 -0
  97. desdeo/mcdm/__init__.py +41 -0
  98. desdeo/mcdm/enautilus.py +338 -0
  99. desdeo/mcdm/gnimbus.py +484 -0
  100. desdeo/mcdm/nautili.py +345 -0
  101. desdeo/mcdm/nautilus.py +477 -0
  102. desdeo/mcdm/nautilus_navigator.py +656 -0
  103. desdeo/mcdm/nimbus.py +417 -0
  104. desdeo/mcdm/pareto_navigator.py +269 -0
  105. desdeo/mcdm/reference_point_method.py +186 -0
  106. desdeo/problem/__init__.py +83 -0
  107. desdeo/problem/evaluator.py +561 -0
  108. desdeo/problem/external/__init__.py +18 -0
  109. desdeo/problem/external/core.py +356 -0
  110. desdeo/problem/external/pymoo_provider.py +266 -0
  111. desdeo/problem/external/runtime.py +44 -0
  112. desdeo/problem/gurobipy_evaluator.py +562 -0
  113. desdeo/problem/infix_parser.py +341 -0
  114. desdeo/problem/json_parser.py +944 -0
  115. desdeo/problem/pyomo_evaluator.py +487 -0
  116. desdeo/problem/schema.py +1829 -0
  117. desdeo/problem/simulator_evaluator.py +348 -0
  118. desdeo/problem/sympy_evaluator.py +244 -0
  119. desdeo/problem/testproblems/__init__.py +88 -0
  120. desdeo/problem/testproblems/benchmarks_server.py +120 -0
  121. desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
  122. desdeo/problem/testproblems/cake_problem.py +185 -0
  123. desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
  124. desdeo/problem/testproblems/dtlz2_problem.py +102 -0
  125. desdeo/problem/testproblems/forest_problem.py +283 -0
  126. desdeo/problem/testproblems/knapsack_problem.py +163 -0
  127. desdeo/problem/testproblems/mcwb_problem.py +831 -0
  128. desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
  129. desdeo/problem/testproblems/momip_problem.py +172 -0
  130. desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
  131. desdeo/problem/testproblems/nimbus_problem.py +143 -0
  132. desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
  133. desdeo/problem/testproblems/re_problem.py +492 -0
  134. desdeo/problem/testproblems/river_pollution_problems.py +440 -0
  135. desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
  136. desdeo/problem/testproblems/simple_problem.py +351 -0
  137. desdeo/problem/testproblems/simulator_problem.py +92 -0
  138. desdeo/problem/testproblems/single_objective.py +289 -0
  139. desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
  140. desdeo/problem/testproblems/zdt_problem.py +274 -0
  141. desdeo/problem/utils.py +245 -0
  142. desdeo/tools/GenerateReferencePoints.py +181 -0
  143. desdeo/tools/__init__.py +120 -0
  144. desdeo/tools/desc_gen.py +22 -0
  145. desdeo/tools/generics.py +165 -0
  146. desdeo/tools/group_scalarization.py +3090 -0
  147. desdeo/tools/gurobipy_solver_interfaces.py +258 -0
  148. desdeo/tools/indicators_binary.py +117 -0
  149. desdeo/tools/indicators_unary.py +362 -0
  150. desdeo/tools/interaction_schema.py +38 -0
  151. desdeo/tools/intersection.py +54 -0
  152. desdeo/tools/iterative_pareto_representer.py +99 -0
  153. desdeo/tools/message.py +265 -0
  154. desdeo/tools/ng_solver_interfaces.py +199 -0
  155. desdeo/tools/non_dominated_sorting.py +134 -0
  156. desdeo/tools/patterns.py +283 -0
  157. desdeo/tools/proximal_solver.py +99 -0
  158. desdeo/tools/pyomo_solver_interfaces.py +477 -0
  159. desdeo/tools/reference_vectors.py +229 -0
  160. desdeo/tools/scalarization.py +2065 -0
  161. desdeo/tools/scipy_solver_interfaces.py +454 -0
  162. desdeo/tools/score_bands.py +627 -0
  163. desdeo/tools/utils.py +388 -0
  164. desdeo/tools/visualizations.py +67 -0
  165. desdeo/utopia_stuff/__init__.py +0 -0
  166. desdeo/utopia_stuff/data/1.json +15 -0
  167. desdeo/utopia_stuff/data/2.json +13 -0
  168. desdeo/utopia_stuff/data/3.json +15 -0
  169. desdeo/utopia_stuff/data/4.json +17 -0
  170. desdeo/utopia_stuff/data/5.json +15 -0
  171. desdeo/utopia_stuff/from_json.py +40 -0
  172. desdeo/utopia_stuff/reinit_user.py +38 -0
  173. desdeo/utopia_stuff/utopia_db_init.py +212 -0
  174. desdeo/utopia_stuff/utopia_problem.py +403 -0
  175. desdeo/utopia_stuff/utopia_problem_old.py +415 -0
  176. desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
  177. desdeo-2.1.0.dist-info/METADATA +186 -0
  178. desdeo-2.1.0.dist-info/RECORD +180 -0
  179. {desdeo-1.2.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
  180. desdeo-2.1.0.dist-info/licenses/LICENSE +21 -0
  181. desdeo-1.2.dist-info/METADATA +0 -16
  182. desdeo-1.2.dist-info/RECORD +0 -4
@@ -0,0 +1,520 @@
1
+ """This module contains the functions for user authentication."""
2
+
3
+ import uuid
4
+ from datetime import UTC, datetime, timedelta
5
+ from typing import Annotated
6
+
7
+ import bcrypt
8
+ from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, Security, status
9
+ from fastapi.responses import JSONResponse
10
+ from fastapi.security import (
11
+ APIKeyCookie,
12
+ HTTPAuthorizationCredentials,
13
+ HTTPBearer,
14
+ OAuth2PasswordBearer,
15
+ OAuth2PasswordRequestForm,
16
+ )
17
+ from jose import ExpiredSignatureError, JWTError, jwt
18
+ from pydantic import BaseModel
19
+ from sqlmodel import Session, select
20
+
21
+ from desdeo.api import AuthConfig
22
+ from desdeo.api.db import get_session
23
+ from desdeo.api.models import User, UserPublic, UserRole
24
+
25
+ router = APIRouter()
26
+
27
+
28
class Tokens(BaseModel):
    """Token payload handed to a client after it authenticates.

    Attributes are documented inline below.
    """

    # JWT presented by the client to authenticate requests.
    access_token: str
    # JWT accepted by the `/refresh` endpoint to mint a new access token.
    refresh_token: str
    # OAuth2 token type; `generate_tokens` sets this to "bearer".
    token_type: str
34
+
35
+
36
# Security dependency that reads a bearer token from the `Authorization` header.
# `tokenUrl` names the `/login` route of this router. With `auto_error=False` the
# dependency yields None instead of raising when no header is present, which lets
# `get_current_user` fall back to the cookie below.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login", auto_error=False)
# Security dependency that reads the token from an `access_token` cookie (None if absent).
cookie_scheme = APIKeyCookie(name="access_token", auto_error=False)
41
+
42
+
43
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check if a password matches a bcrypt hash.

    Args:
        plain_password (str): the plain password.
        hashed_password (str): the stored bcrypt hash as a UTF-8 string.

    Returns:
        bool: whether the plain password matches the hashed one. ``False`` is also
        returned when bcrypt rejects its input with a ``ValueError`` (e.g. a malformed
        stored hash), so a bad record results in a failed login instead of an
        unhandled server error.
    """
    try:
        return bcrypt.checkpw(
            password=plain_password.encode("utf-8"),
            hashed_password=hashed_password.encode("utf-8"),
        )
    except ValueError:
        # bcrypt signals unusable input (such as an invalid salt/hash) via ValueError.
        return False
56
+
57
+
58
def get_password_hash(password: str) -> str:
    """Produce a salted bcrypt hash for storage.

    Args:
        password (str): the plain password to hash.

    Returns:
        str: the bcrypt hash, decoded to a UTF-8 string.
    """
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password=password.encode("utf-8"), salt=salt)
    return digest.decode("utf-8")
70
+
71
+
72
def get_user(session: Session, username: str) -> User | None:
    """Look up a user by username.

    Args:
        session (Session): database session.
        username (str): the username to search for.

    Returns:
        User | None: the first matching user, or None when there is no match.
    """
    return session.exec(select(User).where(User.username == username)).first()
87
+
88
+
89
def authenticate_user(session: Session, username: str, password: str) -> User | None:
    """Return the user if the username exists and the password matches.

    Args:
        session (Session): database session.
        username (str): the username of the user.
        password (str): the plain password supplied by the client.

    Returns:
        User | None: the authenticated user, or None if the user does not exist or the
        password does not match its stored hash.
    """
    candidate = get_user(session, username)

    if candidate and verify_password(password, candidate.password_hash):
        return candidate

    return None
109
+
110
+
111
def get_current_user(
    session: Annotated[Session, Depends(get_session)],
    header_token: Annotated[str | None, Security(oauth2_scheme)] = None,
    cookie_token: Annotated[str | None, Security(cookie_scheme)] = None,
) -> User:
    """Get the current user based on a JWT token.

    This function is a dependency for other functions that need to get the current user.
    The token is taken from the ``Authorization: Bearer`` header when present, otherwise
    from the ``access_token`` cookie.

    Args:
        session (Annotated[Session, Depends(get_session)]): A database session.
        header_token (str | None): the bearer token from the ``Authorization`` header, if any.
        cookie_token (str | None): the token from the ``access_token`` cookie, if any.

    Returns:
        User: The information of the current user.

    Raises:
        HTTPException: 401 if no token is supplied, the token cannot be decoded or has
            expired, it lacks a ``sub`` or ``exp`` claim, or no user matches its ``sub``.
    """
    token = header_token or cookie_token

    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    if not token:
        raise credentials_exception

    try:
        payload = jwt.decode(token, AuthConfig.authjwt_secret_key, algorithms=[AuthConfig.authjwt_algorithm])
    except ExpiredSignatureError:
        raise credentials_exception from None
    except JWTError as err:
        # Chain the caught error instance (not the JWTError class) so server logs show the real cause.
        raise credentials_exception from err

    username = payload.get("sub")
    # `exp` is compared against a POSIX timestamp below, i.e. it is numeric, not a datetime.
    expire_time = payload.get("exp")

    if username is None or expire_time is None or expire_time < datetime.now(UTC).timestamp():
        raise credentials_exception

    user = get_user(session, username=username)

    if user is None:
        raise credentials_exception

    return user
161
+
162
+
163
def create_jwt_token(
    data: dict,
    expires_delta: timedelta,
    algorithm: str = AuthConfig.authjwt_algorithm,
    secret_key: str = AuthConfig.authjwt_secret_key,
) -> str:
    """Encode ``data`` as a signed JWT that expires after ``expires_delta``.

    The caller's dict is not modified. An ``exp`` claim (now + ``expires_delta``, in UTC)
    and a random ``jti`` claim are added to a copy of the claims before encoding.

    Args:
        data (dict): the claims to encode in the token.
        expires_delta (timedelta): how long the token stays valid.
        algorithm (str): the signing algorithm. Defaults to `AuthConfig.authjwt_algorithm`.
        secret_key (str): the signing key. Defaults to `AuthConfig.authjwt_secret_key`.

    Returns:
        str: the encoded JWT.
    """
    issued_claims = {
        **data,
        "exp": datetime.now(UTC) + expires_delta,
        # unique token identifier, so two tokens built from the same data still differ
        "jti": str(uuid.uuid4()),
    }
    return jwt.encode(issued_claims, secret_key, algorithm=algorithm)
188
+
189
+
190
def create_access_token(data: dict, expiration_time: int = AuthConfig.authjwt_access_token_expires) -> str:
    """Create a JWT access token carrying ``data``.

    Args:
        data (dict): the claims to encode in the token.
        expiration_time (int): lifetime of the access token in minutes.
            Defaults to `AuthConfig.authjwt_access_token_expires`.

    Returns:
        str: the encoded JWT access token.
    """
    lifetime = timedelta(minutes=expiration_time)
    return create_jwt_token(data, lifetime)
205
+
206
+
207
def create_refresh_token(data: dict, expiration_time: int = AuthConfig.authjwt_refresh_token_expires) -> str:
    """Create a JWT refresh token carrying ``data``.

    Args:
        data (dict): the claims to encode in the token.
        expiration_time (int): lifetime of the refresh token in minutes.
            Defaults to `AuthConfig.authjwt_refresh_token_expires`.

    Returns:
        str: the encoded JWT refresh token.
    """
    return create_jwt_token(data, expires_delta=timedelta(minutes=expiration_time))
223
+
224
+
225
def generate_tokens(data: dict) -> Tokens:
    """Generate an access token and a refresh token, both carrying ``data``.

    Note:
        Token lifetimes come from the defaults of `create_access_token` and
        `create_refresh_token`, i.e. from `AuthConfig`.

    Args:
        data (dict): the claims to encode in both tokens.

    Returns:
        Tokens: the access and refresh tokens with token type "bearer".
    """
    return Tokens(
        access_token=create_access_token(data),
        refresh_token=create_refresh_token(data),
        token_type="bearer",  # noqa: S106
    )
241
+
242
+
243
def validate_refresh_token(
    refresh_token: str,
    session: Annotated[Session, Depends(get_session)],
    algorithm: str = AuthConfig.authjwt_algorithm,
    secret_key: str = AuthConfig.authjwt_secret_key,
) -> User:
    """Validate a refresh token and return the associated user if valid.

    Args:
        refresh_token (str): The refresh token to validate.
        session (Annotated[Session, Depends(get_session)]): The database session.
        algorithm (str): the algorithm used to decode the JWT token.
            Defaults to `AuthConfig.authjwt_algorithm`.
        secret_key (str): the secret key used to decode the JWT token.
            Defaults to `AuthConfig.authjwt_secret_key`.

    Returns:
        User: The user associated with the valid refresh token.

    Raises:
        HTTPException: 401 if the refresh token cannot be decoded, is expired, lacks a
            ``sub`` or ``exp`` claim, or refers to a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired refresh token",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(refresh_token, secret_key, algorithms=[algorithm])
    except JWTError:
        # python-jose reports malformed tokens, bad signatures and failed claim checks
        # (including expiry) as JWTError subclasses; all map to a plain 401.
        raise credentials_exception from None

    username = payload.get("sub")
    # `exp` is compared against a POSIX timestamp, i.e. it is numeric, not a datetime.
    expire_time = payload.get("exp")

    if username is None or expire_time is None or expire_time < datetime.now(UTC).timestamp():
        raise credentials_exception

    # Validate the user from the database
    user = get_user(session, username=username)
    if user is None:
        raise credentials_exception

    return user
289
+
290
+
291
def add_user_to_database(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    role: UserRole,
    session: Annotated[Session, Depends(get_session)],
) -> None:
    """Create a user with the given role from submitted credentials.

    Args:
        form_data (Annotated[OAuth2PasswordRequestForm, Depends()]): form holding the
            username and password of the new user.
        role (UserRole): role assigned to the new user.
        session (Annotated[Session, Depends(get_session)]): database session.

    Returns:
        None

    Raises:
        HTTPException: 409 if the username is already taken; 500 if the user cannot be
            found again after committing.
    """
    username, password = form_data.username, form_data.password

    # Refuse duplicate usernames up front.
    if get_user(session=session, username=username):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Username already taken.",
        )

    hashed = get_password_hash(password=password)
    new_user = User(username=username, password_hash=hashed, role=role)

    session.add(new_user)
    session.commit()
    session.refresh(new_user)

    # Sanity check: the freshly committed user must be retrievable.
    if not get_user(session=session, username=username):
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to add user into database.",
        )
337
+
338
+
339
@router.get("/user_info")
def get_current_user_info(user: Annotated[User, Depends(get_current_user)]) -> UserPublic:
    """Return the public profile of the authenticated user.

    The return annotation makes FastAPI serialize the user through `UserPublic`.

    Args:
        user (Annotated[User, Depends(get_current_user)]): the authenticated user.

    Returns:
        UserPublic: public information about the current user.
    """
    return user
350
+
351
+
352
@router.post("/login", response_model=Tokens)
def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    session: Annotated[Session, Depends(get_session)],
    cookie_max_age: int = AuthConfig.authjwt_refresh_token_expires,
):
    """Login to get authentication tokens.

    On success, the response body contains the access token, the refresh token and the
    token type (``"bearer"``), matching the declared ``Tokens`` response model. The refresh
    token is also stored in an HTTP-only ``refresh_token`` cookie.

    Args:
        form_data (Annotated[OAuth2PasswordRequestForm, Depends()]):
            The form data to authenticate the user.
        session (Annotated[Session, Depends(get_session)]): The database session.
        cookie_max_age (int): the lifetime of the refresh-token cookie, in minutes. It is a
            plain defaulted parameter, so FastAPI exposes it as an optional query parameter.

    Returns:
        JSONResponse: the token payload, with the refresh-token cookie set.

    Raises:
        HTTPException: 401 if the username or password is incorrect.
    """
    user = authenticate_user(session, form_data.username, form_data.password)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    tokens = generate_tokens({"id": user.id, "sub": user.username})

    # Returning a JSONResponse bypasses `response_model` validation, so `token_type`
    # (part of `Tokens` and of an OAuth2 token response) must be included explicitly.
    response = JSONResponse(
        content={
            "access_token": tokens.access_token,
            "refresh_token": tokens.refresh_token,
            "token_type": tokens.token_type,
        }
    )

    cookie_options = {
        "key": "refresh_token",
        "value": tokens.refresh_token,
        "httponly": True,  # not readable by frontend JavaScript
        "max_age": cookie_max_age * 60,  # minutes -> seconds
        "path": "/",
    }
    if AuthConfig.cookie_domain == "":
        # No shared cookie domain configured: same-site cookie, also sent over plain HTTP.
        cookie_options.update(secure=False, samesite="lax")
    else:
        # Cookie shared between API and web UI subdomains: SameSite=None is required for
        # cross-site use, and it must be marked Secure (HTTPS only).
        cookie_options.update(secure=True, samesite="none", domain=AuthConfig.cookie_domain)

    response.set_cookie(**cookie_options)

    return response
404
+
405
+
406
@router.post("/logout")
def logout() -> JSONResponse:
    """Log the current user out by deleting the ``refresh_token`` cookie set by logging in.

    A browser only removes a cookie when the expiring ``Set-Cookie`` header targets the
    same name, domain and path the cookie was set with. When ``AuthConfig.cookie_domain`` is
    non-empty, ``/login`` sets the cookie with that domain, ``Secure`` and ``SameSite=None``.
    The deletion therefore repeats those attributes; otherwise the cookie would survive logout.

    Returns:
        JSONResponse: a ``{"message": "logged out"}`` response that expires the cookie.
    """
    response = JSONResponse(content={"message": "logged out"}, status_code=status.HTTP_200_OK)

    if AuthConfig.cookie_domain == "":
        # Host-only cookie (no Domain attribute), matching how /login set it.
        response.delete_cookie("refresh_token")
    else:
        response.delete_cookie(
            "refresh_token",
            path="/",
            domain=AuthConfig.cookie_domain,
            secure=True,
            httponly=True,
            samesite="none",
        )

    return response
420
+
421
+
422
@router.post("/refresh")
def refresh_access_token(
    request: Response,
    session: Annotated[Session, Depends(get_session)],
    refresh_token: Annotated[str | None, Cookie()] = None,
):
    """Issue a new access token from the refresh token in the ``refresh_token`` cookie.

    Args:
        request (Response): annotated as ``Response``, so FastAPI injects the outgoing
            response object; it is not used by this endpoint.
        session (Annotated[Session, Depends(get_session)]): the database session.
        refresh_token (Annotated[str | None, Cookie()]): the refresh token, read from the
            ``refresh_token`` cookie of the incoming request.

    Returns:
        dict: ``{"access_token": <new access token>}``.

    Raises:
        HTTPException: 401 if the cookie is missing or the refresh token is not valid.
    """
    if not refresh_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing refresh token.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    owner = validate_refresh_token(refresh_token, session)
    new_access_token = create_access_token({"id": owner.id, "sub": owner.username})

    return {"access_token": new_access_token}
452
+
453
+
454
@router.post("/add_new_dm")
def add_new_dm(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    session: Annotated[Session, Depends(get_session)],
) -> JSONResponse:
    """Register a new decision maker. This endpoint does not require a login.

    Args:
        form_data (Annotated[OAuth2PasswordRequestForm, Depends()]): credentials of the new user.
        session (Annotated[Session, Depends(get_session)]): the database session.

    Returns:
        JSONResponse: a 201 response with a confirmation message.

    Raises:
        HTTPException: propagated from `add_user_to_database` if the username is taken
            or the user could not be stored.
    """
    add_user_to_database(form_data=form_data, role=UserRole.dm, session=session)

    message = 'User with role "decision maker" created.'
    return JSONResponse(content={"message": message}, status_code=status.HTTP_201_CREATED)
481
+
482
+
483
@router.post("/add_new_analyst")
def add_new_analyst(
    user: Annotated[User, Depends(get_current_user)],
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    session: Annotated[Session, Depends(get_session)],
) -> JSONResponse:
    """Register a new analyst. Only a logged-in analyst or admin may do this.

    Args:
        user (Annotated[User, Depends(get_current_user)]): the logged-in user making the request.
        form_data (Annotated[OAuth2PasswordRequestForm, Depends()]): credentials of the new user.
        session (Annotated[Session, Depends(get_session)]): the database session.

    Returns:
        JSONResponse: a 201 response with a confirmation message.

    Raises:
        HTTPException: 401 if the requesting user is neither an analyst nor an admin;
            otherwise propagated from `add_user_to_database` (username taken, storage failure).
    """
    permitted_roles = (UserRole.analyst, UserRole.admin)
    if user.role not in permitted_roles:
        # NOTE(review): 403 Forbidden is the conventional status for an authenticated user
        # lacking rights; 401 is kept here so existing clients are unaffected.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Logged in user has insufficient rights.",
        )

    add_user_to_database(form_data=form_data, role=UserRole.analyst, session=session)

    message = 'User with role "analyst" created.'
    return JSONResponse(content={"message": message}, status_code=status.HTTP_201_CREATED)
@@ -0,0 +1,187 @@
1
+ """A selection of utilities for handling routers and data therein.
2
+
3
+ NOTE: No routers should be defined in this file!
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Annotated
8
+
9
+ from fastapi import Depends, HTTPException, status
10
+ from sqlmodel import Session, select
11
+
12
+ from desdeo.api.db import get_session
13
+ from desdeo.api.models import (
14
+ ENautilusStepRequest,
15
+ InteractiveSessionDB,
16
+ ProblemDB,
17
+ RPMSolveRequest,
18
+ StateDB,
19
+ User,
20
+ )
21
+ from desdeo.api.routers.user_authentication import get_current_user
22
+
23
# Request models accepted by the helpers below. Each is read for `problem_id`,
# `session_id` and `parent_state_id`.
RequestType = RPMSolveRequest | ENautilusStepRequest
24
+
25
+
26
def fetch_interactive_session(user: User, request: RequestType, session: Session) -> InteractiveSessionDB | None:
    """Resolve the interactive session a request refers to.

    Args:
        user (User): the user making the request.
        request (RequestType): the request, optionally carrying an explicit `session_id`.
        session (Session): the database session (not the interactive session) used for the query.

    Note:
        Without an explicit `request.session_id`, the user's active interactive session
        (`user.active_session_id`) is looked up instead. If none is found, None is returned.

    Raises:
        HTTPException: 404 when an explicit `session_id` is given but no such interactive
            session exists.

    Returns:
        InteractiveSessionDB | None: the interactive session, or None.
    """
    if request.session_id is None:
        # Fall back to the active session; a missing one is not an error.
        fallback = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
        return session.exec(fallback).first()

    # NOTE(review): unlike `fetch_user_problem`, this lookup is not restricted to
    # interactive sessions owned by `user`; consider adding an ownership filter.
    explicit = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
    found = session.exec(explicit).first()

    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Could not find interactive session with id={request.session_id}.",
        )

    return found
68
+
69
+
70
def fetch_user_problem(user: User, request: RequestType, session: Session) -> ProblemDB:
    """Fetch the `ProblemDB` with id `request.problem_id` that belongs to `user`.

    Args:
        user (User): the owner of the problem.
        request (RequestType): request holding the wanted `problem_id`.
        session (Session): the database session to query.

    Raises:
        HTTPException: 404 if the user has no problem with that id.

    Returns:
        ProblemDB: the matching problem.
    """
    owned_problem = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
    problem_db = session.exec(owned_problem).first()

    if problem_db is not None:
        return problem_db

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
    )
93
+
94
+
95
def fetch_parent_state(
    user: User, request: RequestType, session: Session, interactive_session: InteractiveSessionDB | None = None
) -> StateDB | None:
    """Determine the parent `StateDB` for a new state.

    An explicit `request.parent_state_id` takes precedence and must exist. Without one, the
    most recently added state of `interactive_session` is used, if there is any. Otherwise
    None is returned.

    Args:
        user (User): the user for which the parent state is fetched.
        request (RequestType): request that may carry an explicit `parent_state_id`.
        session (Session): the database session to query.
        interactive_session (InteractiveSessionDB | None, optional): interactive session whose
            last state serves as the default parent. Defaults to None.

    Raises:
        HTTPException: 404 when `request.parent_state_id` is given but no such state exists.

    Returns:
        StateDB | None: the explicitly requested state, else the last state of
        `interactive_session`, else None.
    """
    if request.parent_state_id is not None:
        lookup = select(StateDB).where(StateDB.id == request.parent_state_id)
        explicit_parent = session.exec(lookup).first()

        # An explicitly named parent that does not exist is an error, not a silent fallback.
        if explicit_parent is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
            )
        return explicit_parent

    # No explicit parent: default to the last state added to the interactive session.
    if interactive_session is None or not interactive_session.states:
        return None
    return interactive_session.states[-1]
148
+
149
+
150
@dataclass(frozen=True)
class SessionContext:
    """Immutable bundle of request-scoped objects shared by many endpoints.

    Instances are built by `get_session_context`.
    """

    # the authenticated user making the request
    user: User
    # the database session (not to be confused with an interactive session)
    db_session: Session
    # the user's problem referenced by the request
    problem_db: ProblemDB
    # the interactive session the request belongs to, if one was resolved
    interactive_session: InteractiveSessionDB | None
    # the state a new state should be attached to, if any
    parent_state: StateDB | None
159
+
160
+
161
def get_session_context(
    request: RequestType,
    user: Annotated[User, Depends(get_current_user)],
    db_session: Annotated[Session, Depends(get_session)],
) -> SessionContext:
    """Assemble the `SessionContext` for a request. Intended for use as a dependency.

    The problem, interactive session and parent state are resolved in that order; any
    `HTTPException` raised by the individual fetchers propagates unchanged.

    Args:
        request (RequestType): request from which the context is derived.
        user (Annotated[User, Depends(get_current_user)]): the current user (dependency).
        db_session (Annotated[Session, Depends(get_session)]): the current database session (dependency).

    Returns:
        SessionContext: context holding the `User`, `Session`, `ProblemDB`,
        `InteractiveSessionDB`, and `StateDB` relevant to the request.
    """
    problem_db = fetch_user_problem(user, request, db_session)
    interactive_session = fetch_interactive_session(user, request, db_session)

    return SessionContext(
        user=user,
        db_session=db_session,
        problem_db=problem_db,
        interactive_session=interactive_session,
        parent_state=fetch_parent_state(user, request, db_session, interactive_session=interactive_session),
    )