desdeo 2.1.0__py3-none-any.whl → 2.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/api/models/nimbus.py +8 -4
- desdeo/api/routers/emo.py +75 -104
- desdeo/api/routers/generic.py +26 -58
- desdeo/api/routers/nimbus.py +108 -247
- desdeo/api/routers/problem.py +69 -56
- desdeo/api/routers/reference_point_method.py +29 -27
- desdeo/api/routers/session.py +15 -11
- desdeo/api/routers/user_authentication.py +27 -5
- desdeo/api/routers/utils.py +42 -37
- desdeo/api/routers/utopia.py +11 -12
- desdeo/api/tests/test_routes.py +6 -5
- desdeo/emo/__init__.py +2 -0
- desdeo/emo/operators/__init__.py +1 -1
- desdeo/emo/operators/generator.py +153 -2
- desdeo/emo/options/__init__.py +4 -0
- desdeo/emo/options/generator.py +24 -0
- desdeo/problem/__init__.py +12 -11
- desdeo/problem/evaluator.py +4 -5
- desdeo/problem/gurobipy_evaluator.py +37 -12
- desdeo/problem/infix_parser.py +1 -16
- desdeo/problem/json_parser.py +7 -11
- desdeo/problem/schema.py +6 -9
- desdeo/problem/utils.py +1 -1
- desdeo/tools/pyomo_solver_interfaces.py +1 -1
- {desdeo-2.1.0.dist-info → desdeo-2.2.0.dist-info}/METADATA +21 -12
- {desdeo-2.1.0.dist-info → desdeo-2.2.0.dist-info}/RECORD +28 -28
- {desdeo-2.1.0.dist-info → desdeo-2.2.0.dist-info}/WHEEL +1 -1
- {desdeo-2.1.0.dist-info → desdeo-2.2.0.dist-info}/licenses/LICENSE +0 -0
desdeo/api/routers/problem.py
CHANGED
|
@@ -5,9 +5,8 @@ from typing import Annotated
|
|
|
5
5
|
|
|
6
6
|
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
|
7
7
|
from fastapi.responses import JSONResponse
|
|
8
|
-
from sqlmodel import
|
|
8
|
+
from sqlmodel import select
|
|
9
9
|
|
|
10
|
-
from desdeo.api.db import get_session
|
|
11
10
|
from desdeo.api.models import (
|
|
12
11
|
ForestProblemMetaData,
|
|
13
12
|
ProblemDB,
|
|
@@ -26,6 +25,8 @@ from desdeo.api.routers.user_authentication import get_current_user
|
|
|
26
25
|
from desdeo.problem import Problem
|
|
27
26
|
from desdeo.tools.utils import available_solvers
|
|
28
27
|
|
|
28
|
+
from .utils import SessionContext, get_session_context, get_session_context_without_request
|
|
29
|
+
|
|
29
30
|
router = APIRouter(prefix="/problem")
|
|
30
31
|
|
|
31
32
|
|
|
@@ -87,15 +88,13 @@ def get_problems_info(user: Annotated[User, Depends(get_current_user)]) -> list[
|
|
|
87
88
|
@router.post("/get")
|
|
88
89
|
def get_problem(
|
|
89
90
|
request: ProblemGetRequest,
|
|
90
|
-
|
|
91
|
-
session: Annotated[Session, Depends(get_session)],
|
|
91
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
92
92
|
) -> ProblemInfo:
|
|
93
93
|
"""Get the model of a specific problem.
|
|
94
94
|
|
|
95
95
|
Args:
|
|
96
96
|
request (ProblemGetRequest): the request containing the problem's id `problem_id`.
|
|
97
|
-
|
|
98
|
-
session (Annotated[Session, Depends): the database session.
|
|
97
|
+
context (Annotated[SessionContext, Depends): the session context.
|
|
99
98
|
|
|
100
99
|
Raises:
|
|
101
100
|
HTTPException: could not find a problem with the given id.
|
|
@@ -103,29 +102,28 @@ def get_problem(
|
|
|
103
102
|
Returns:
|
|
104
103
|
ProblemInfo: detailed information on the requested problem.
|
|
105
104
|
"""
|
|
106
|
-
|
|
105
|
+
problem_db = context.problem_db
|
|
107
106
|
|
|
108
|
-
|
|
107
|
+
# Ensure problem exists
|
|
108
|
+
if problem_db is None:
|
|
109
109
|
raise HTTPException(
|
|
110
110
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
111
111
|
detail=f"The problem with the requested id={request.problem_id} was not found.",
|
|
112
112
|
)
|
|
113
113
|
|
|
114
|
-
return
|
|
114
|
+
return problem_db
|
|
115
115
|
|
|
116
116
|
|
|
117
117
|
@router.post("/add")
|
|
118
118
|
def add_problem(
|
|
119
119
|
request: Annotated[Problem, Depends(parse_problem_json)],
|
|
120
|
-
|
|
121
|
-
session: Annotated[Session, Depends(get_session)],
|
|
120
|
+
context: Annotated[SessionContext, Depends(get_session_context_without_request)],
|
|
122
121
|
) -> ProblemInfo:
|
|
123
122
|
"""Add a newly defined problem to the database.
|
|
124
123
|
|
|
125
124
|
Args:
|
|
126
125
|
request (Problem): the JSON representation of the problem.
|
|
127
|
-
|
|
128
|
-
session (Annotated[Session, Depends): the database session.
|
|
126
|
+
context (Annotated[SessionContext, Depends): the session context.
|
|
129
127
|
|
|
130
128
|
Note:
|
|
131
129
|
Users with the role 'guest' may not add new problems.
|
|
@@ -136,10 +134,15 @@ def add_problem(
|
|
|
136
134
|
Returns:
|
|
137
135
|
ProblemInfo: the information about the problem added.
|
|
138
136
|
"""
|
|
137
|
+
user = context.user
|
|
138
|
+
db_session = context.db_session
|
|
139
|
+
|
|
139
140
|
if user.role == UserRole.guest:
|
|
140
141
|
raise HTTPException(
|
|
141
|
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
142
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
143
|
+
detail="Guest users are not allowed to add new problems.",
|
|
142
144
|
)
|
|
145
|
+
|
|
143
146
|
try:
|
|
144
147
|
problem_db = ProblemDB.from_problem(request, user=user)
|
|
145
148
|
except Exception as e:
|
|
@@ -148,9 +151,9 @@ def add_problem(
|
|
|
148
151
|
detail=f"Could not add problem. Possible reason: {e!r}",
|
|
149
152
|
) from e
|
|
150
153
|
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
+
db_session.add(problem_db)
|
|
155
|
+
db_session.commit()
|
|
156
|
+
db_session.refresh(problem_db)
|
|
154
157
|
|
|
155
158
|
return problem_db
|
|
156
159
|
|
|
@@ -158,15 +161,13 @@ def add_problem(
|
|
|
158
161
|
@router.post("/add_json")
|
|
159
162
|
def add_problem_json(
|
|
160
163
|
json_file: UploadFile,
|
|
161
|
-
|
|
162
|
-
session: Annotated[Session, Depends(get_session)],
|
|
164
|
+
context: Annotated[SessionContext, Depends(get_session_context_without_request)],
|
|
163
165
|
) -> ProblemInfo:
|
|
164
166
|
"""Adds a problem to the database based on its JSON definition.
|
|
165
167
|
|
|
166
168
|
Args:
|
|
167
169
|
json_file (UploadFile): a file in JSON format describing the problem.
|
|
168
|
-
|
|
169
|
-
session (Annotated[Session, Depends): the database session.
|
|
170
|
+
context (Annotated[SessionContext, Depends): the session context.
|
|
170
171
|
|
|
171
172
|
Raises:
|
|
172
173
|
HTTPException: if the provided `json_file` is empty.
|
|
@@ -175,23 +176,25 @@ def add_problem_json(
|
|
|
175
176
|
Returns:
|
|
176
177
|
ProblemInfo: a description of the added problem.
|
|
177
178
|
"""
|
|
179
|
+
user = context.user
|
|
180
|
+
db_session = context.db_session
|
|
181
|
+
|
|
178
182
|
raw = json_file.file.read()
|
|
179
183
|
|
|
180
184
|
if not raw:
|
|
181
|
-
raise HTTPException(400, "Empty upload.")
|
|
185
|
+
raise HTTPException(status_code=400, detail="Empty upload.")
|
|
182
186
|
|
|
183
187
|
try:
|
|
184
|
-
#
|
|
185
|
-
json.loads(raw)
|
|
188
|
+
json.loads(raw) # extra validation
|
|
186
189
|
except json.JSONDecodeError as e:
|
|
187
|
-
raise HTTPException(400, "Invalid JSON.") from e
|
|
190
|
+
raise HTTPException(status_code=400, detail="Invalid JSON.") from e
|
|
188
191
|
|
|
189
192
|
problem = Problem.model_validate_json(raw, by_name=True)
|
|
190
193
|
problem_db = ProblemDB.from_problem(problem, user=user)
|
|
191
194
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
+
db_session.add(problem_db)
|
|
196
|
+
db_session.commit()
|
|
197
|
+
db_session.refresh(problem_db)
|
|
195
198
|
|
|
196
199
|
return problem_db
|
|
197
200
|
|
|
@@ -199,8 +202,7 @@ def add_problem_json(
|
|
|
199
202
|
@router.post("/get_metadata")
|
|
200
203
|
def get_metadata(
|
|
201
204
|
request: ProblemMetaDataGetRequest,
|
|
202
|
-
|
|
203
|
-
session: Annotated[Session, Depends(get_session)],
|
|
205
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
204
206
|
) -> list[ForestProblemMetaData | RepresentativeNonDominatedSolutions | SolverSelectionMetadata]:
|
|
205
207
|
"""Fetch specific metadata for a specific problem.
|
|
206
208
|
|
|
@@ -210,19 +212,21 @@ def get_metadata(
|
|
|
210
212
|
|
|
211
213
|
Args:
|
|
212
214
|
request (MetaDataGetRequest): the requested metadata type.
|
|
213
|
-
|
|
214
|
-
session (Annotated[Session, Depends]): the database session.
|
|
215
|
+
context (Annotated[SessionContext, Depends]): the session context.
|
|
215
216
|
|
|
216
217
|
Returns:
|
|
217
218
|
list[ForestProblemMetadata | RepresentativeNonDominatedSolutions]: list containing all the metadata
|
|
218
219
|
defined for the problem with the requested metadata type. If no match is found,
|
|
219
220
|
returns an empty list.
|
|
220
221
|
"""
|
|
221
|
-
|
|
222
|
-
|
|
222
|
+
db_session = context.db_session
|
|
223
|
+
|
|
224
|
+
problem_from_db = db_session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first()
|
|
225
|
+
|
|
223
226
|
if problem_from_db is None:
|
|
224
227
|
raise HTTPException(
|
|
225
|
-
status_code=status.HTTP_404_NOT_FOUND,
|
|
228
|
+
status_code=status.HTTP_404_NOT_FOUND,
|
|
229
|
+
detail=f"Problem with ID {request.problem_id} not found!",
|
|
226
230
|
)
|
|
227
231
|
|
|
228
232
|
problem_metadata = problem_from_db.problem_metadata
|
|
@@ -230,7 +234,6 @@ def get_metadata(
|
|
|
230
234
|
if problem_metadata is None:
|
|
231
235
|
# no metadata define for the problem
|
|
232
236
|
return []
|
|
233
|
-
|
|
234
237
|
# metadata is defined, try to find matching types based on request
|
|
235
238
|
return [metadata for metadata in problem_metadata.all_metadata if metadata.metadata_type == request.metadata_type]
|
|
236
239
|
|
|
@@ -240,17 +243,17 @@ def get_available_solvers() -> list[str]:
|
|
|
240
243
|
"""Return the list of available solver names."""
|
|
241
244
|
return list(available_solvers.keys())
|
|
242
245
|
|
|
246
|
+
|
|
243
247
|
@router.post("/assign_solver")
|
|
244
248
|
def select_solver(
|
|
245
249
|
request: ProblemSelectSolverRequest,
|
|
246
|
-
|
|
247
|
-
session: Annotated[Session, Depends(get_session)],
|
|
250
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
248
251
|
) -> JSONResponse:
|
|
249
252
|
"""Assign a specific solver for a problem.
|
|
250
253
|
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
+
Args:
|
|
255
|
+
request: ProblemSelectSolverRequest: The request containing problem id and string representation of the solver
|
|
256
|
+
context: Annotated[SessionContext, Depends(get_session)]: The session context.
|
|
254
257
|
|
|
255
258
|
Raises:
|
|
256
259
|
HTTPException: Unknown solver, unauthorized user
|
|
@@ -258,50 +261,60 @@ def select_solver(
|
|
|
258
261
|
Returns:
|
|
259
262
|
JSONResponse: A simple confirmation.
|
|
260
263
|
"""
|
|
264
|
+
db_session = context.db_session
|
|
265
|
+
user = context.user
|
|
266
|
+
|
|
267
|
+
# Validate solver type
|
|
261
268
|
if request.solver_string_representation not in [x for x, _ in available_solvers.items()]:
|
|
262
269
|
raise HTTPException(
|
|
263
270
|
detail=f"Solver of unknown type: {request.solver_string_representation}",
|
|
264
271
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
265
272
|
)
|
|
266
273
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
274
|
+
# Fetch problem
|
|
275
|
+
problem_db = db_session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first()
|
|
276
|
+
|
|
270
277
|
if problem_db is None:
|
|
271
278
|
raise HTTPException(
|
|
272
279
|
detail=f"No problem with ID {request.problem_id}!",
|
|
273
280
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
274
281
|
)
|
|
282
|
+
|
|
275
283
|
# Auth the user
|
|
276
284
|
if user.id != problem_db.user_id:
|
|
277
|
-
raise HTTPException(
|
|
285
|
+
raise HTTPException(
|
|
286
|
+
detail="Unauthorized user!",
|
|
287
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
288
|
+
)
|
|
278
289
|
|
|
279
290
|
# All good, get on with it.
|
|
280
291
|
problem_metadata = problem_db.problem_metadata
|
|
281
292
|
if problem_metadata is None:
|
|
282
293
|
# There's no metadata for this problem! Create some.
|
|
283
294
|
problem_metadata = ProblemMetaDataDB(problem_id=problem_db.id, problem=problem_db)
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
295
|
+
db_session.add(problem_metadata)
|
|
296
|
+
db_session.commit()
|
|
297
|
+
db_session.refresh(problem_metadata)
|
|
287
298
|
|
|
299
|
+
# Remove existing solver selection metadata
|
|
288
300
|
if problem_metadata.solver_selection_metadata:
|
|
289
|
-
|
|
290
|
-
|
|
301
|
+
db_session.delete(problem_metadata.solver_selection_metadata[-1])
|
|
302
|
+
db_session.commit()
|
|
291
303
|
|
|
304
|
+
# Add new solver selection metadata
|
|
292
305
|
solver_selection_metadata = SolverSelectionMetadata(
|
|
293
306
|
metadata_id=problem_metadata.id,
|
|
294
307
|
solver_string_representation=request.solver_string_representation,
|
|
295
308
|
metadata_instance=problem_metadata,
|
|
296
309
|
)
|
|
297
310
|
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
311
|
+
db_session.add(solver_selection_metadata)
|
|
312
|
+
db_session.commit()
|
|
313
|
+
db_session.refresh(solver_selection_metadata)
|
|
301
314
|
|
|
302
315
|
problem_metadata.solver_selection_metadata.append(solver_selection_metadata)
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
316
|
+
db_session.add(problem_metadata)
|
|
317
|
+
db_session.commit()
|
|
318
|
+
db_session.refresh(problem_metadata)
|
|
306
319
|
|
|
307
320
|
return JSONResponse(content={"message": "OK"}, status_code=status.HTTP_200_OK)
|
|
@@ -2,26 +2,20 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Annotated
|
|
4
4
|
|
|
5
|
-
from fastapi import APIRouter, Depends
|
|
6
|
-
from sqlmodel import Session
|
|
5
|
+
from fastapi import APIRouter, Depends, HTTPException, status
|
|
7
6
|
|
|
8
|
-
from desdeo.api.db import get_session
|
|
9
7
|
from desdeo.api.models import (
|
|
10
|
-
InteractiveSessionDB,
|
|
11
8
|
PreferenceDB,
|
|
12
|
-
ProblemDB,
|
|
13
9
|
RPMSolveRequest,
|
|
14
10
|
RPMState,
|
|
15
11
|
StateDB,
|
|
16
|
-
User,
|
|
17
12
|
)
|
|
18
13
|
from desdeo.api.routers.problem import check_solver
|
|
19
|
-
from desdeo.api.routers.user_authentication import get_current_user
|
|
20
14
|
from desdeo.mcdm import rpm_solve_solutions
|
|
21
15
|
from desdeo.problem import Problem
|
|
22
16
|
from desdeo.tools import SolverResults
|
|
23
17
|
|
|
24
|
-
from .utils import
|
|
18
|
+
from .utils import SessionContext, get_session_context
|
|
25
19
|
|
|
26
20
|
router = APIRouter(prefix="/method/rpm")
|
|
27
21
|
|
|
@@ -29,31 +23,35 @@ router = APIRouter(prefix="/method/rpm")
|
|
|
29
23
|
@router.post("/solve")
|
|
30
24
|
def solve_solutions(
|
|
31
25
|
request: RPMSolveRequest,
|
|
32
|
-
|
|
33
|
-
session: Annotated[Session, Depends(get_session)],
|
|
26
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
34
27
|
) -> RPMState:
|
|
35
28
|
"""Runs an iteration of the reference point method.
|
|
36
29
|
|
|
37
30
|
Args:
|
|
38
31
|
request (RPMSolveRequest): a request with the needed information to run the method.
|
|
39
32
|
user (Annotated[User, Depends): the current user.
|
|
40
|
-
|
|
33
|
+
context (Annotated[SessionContext, Depends): the current session context.
|
|
41
34
|
|
|
42
35
|
Returns:
|
|
43
36
|
RPMState: a state with information on the results of iterating the reference point method
|
|
44
37
|
once.
|
|
45
38
|
"""
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
39
|
+
user = context.user
|
|
40
|
+
db_session = context.db_session
|
|
41
|
+
problem_db = context.problem_db
|
|
42
|
+
interactive_session = context.interactive_session
|
|
43
|
+
parent_state = context.parent_state
|
|
44
|
+
|
|
45
|
+
# ensure problem exists
|
|
46
|
+
if problem_db is None:
|
|
47
|
+
raise HTTPException(
|
|
48
|
+
status_code=status.HTTP_400_BAD_REQUEST,
|
|
49
|
+
detail="Problem context missing.",
|
|
50
|
+
)
|
|
51
51
|
|
|
52
52
|
solver = check_solver(problem_db=problem_db)
|
|
53
|
-
|
|
54
53
|
problem = Problem.from_problemdb(problem_db)
|
|
55
54
|
|
|
56
|
-
# optimize for solutions
|
|
57
55
|
solver_results: list[SolverResults] = rpm_solve_solutions(
|
|
58
56
|
problem,
|
|
59
57
|
request.preference.aspiration_levels,
|
|
@@ -63,11 +61,15 @@ def solve_solutions(
|
|
|
63
61
|
)
|
|
64
62
|
|
|
65
63
|
# create DB preference
|
|
66
|
-
preference_db = PreferenceDB(
|
|
64
|
+
preference_db = PreferenceDB(
|
|
65
|
+
user_id=user.id,
|
|
66
|
+
problem_id=problem_db.id,
|
|
67
|
+
preference=request.preference,
|
|
68
|
+
)
|
|
67
69
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
70
|
+
db_session.add(preference_db)
|
|
71
|
+
db_session.commit()
|
|
72
|
+
db_session.refresh(preference_db)
|
|
71
73
|
|
|
72
74
|
# create state and add to DB
|
|
73
75
|
rpm_state = RPMState(
|
|
@@ -81,13 +83,13 @@ def solve_solutions(
|
|
|
81
83
|
state = StateDB(
|
|
82
84
|
problem_id=problem_db.id,
|
|
83
85
|
preference_id=preference_db.id,
|
|
84
|
-
session_id=interactive_session.id if interactive_session
|
|
85
|
-
parent_id=parent_state.id if parent_state
|
|
86
|
+
session_id=interactive_session.id if interactive_session else None,
|
|
87
|
+
parent_id=parent_state.id if parent_state else None,
|
|
86
88
|
state=rpm_state,
|
|
87
89
|
)
|
|
88
90
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
91
|
+
db_session.add(state)
|
|
92
|
+
db_session.commit()
|
|
93
|
+
db_session.refresh(state)
|
|
92
94
|
|
|
93
95
|
return rpm_state
|
desdeo/api/routers/session.py
CHANGED
|
@@ -14,7 +14,7 @@ from desdeo.api.models import (
|
|
|
14
14
|
User,
|
|
15
15
|
)
|
|
16
16
|
from desdeo.api.routers.user_authentication import get_current_user
|
|
17
|
-
from desdeo.api.routers.utils import fetch_interactive_session
|
|
17
|
+
from desdeo.api.routers.utils import SessionContext, fetch_interactive_session, get_session_context_without_request
|
|
18
18
|
|
|
19
19
|
router = APIRouter(prefix="/session")
|
|
20
20
|
|
|
@@ -22,21 +22,25 @@ router = APIRouter(prefix="/session")
|
|
|
22
22
|
@router.post("/new")
|
|
23
23
|
def create_new_session(
|
|
24
24
|
request: CreateSessionRequest,
|
|
25
|
-
|
|
26
|
-
session: Annotated[Session, Depends(get_db_session)],
|
|
25
|
+
context: Annotated[SessionContext, Depends(get_session_context_without_request)],
|
|
27
26
|
) -> InteractiveSessionInfo:
|
|
28
|
-
"""."""
|
|
29
|
-
|
|
27
|
+
"""Creates a new interactive session."""
|
|
28
|
+
user = context.user
|
|
29
|
+
db_session = context.db_session
|
|
30
|
+
|
|
31
|
+
interactive_session = InteractiveSessionDB(
|
|
32
|
+
user_id=user.id,
|
|
33
|
+
info=request.info,
|
|
34
|
+
)
|
|
30
35
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
36
|
+
db_session.add(interactive_session)
|
|
37
|
+
db_session.commit()
|
|
38
|
+
db_session.refresh(interactive_session)
|
|
34
39
|
|
|
35
40
|
user.active_session_id = interactive_session.id
|
|
36
41
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
session.refresh(interactive_session)
|
|
42
|
+
db_session.add(user)
|
|
43
|
+
db_session.commit()
|
|
40
44
|
|
|
41
45
|
return interactive_session
|
|
42
46
|
|
|
@@ -9,8 +9,6 @@ from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, Securit
|
|
|
9
9
|
from fastapi.responses import JSONResponse
|
|
10
10
|
from fastapi.security import (
|
|
11
11
|
APIKeyCookie,
|
|
12
|
-
HTTPAuthorizationCredentials,
|
|
13
|
-
HTTPBearer,
|
|
14
12
|
OAuth2PasswordBearer,
|
|
15
13
|
OAuth2PasswordRequestForm,
|
|
16
14
|
)
|
|
@@ -119,7 +117,8 @@ def get_current_user(
|
|
|
119
117
|
This function is a dependency for other functions that need to get the current user.
|
|
120
118
|
|
|
121
119
|
Args:
|
|
122
|
-
|
|
120
|
+
header_token (Annotated[str, Depends(oauth2_scheme)]): The authentication token as part of the request header.
|
|
121
|
+
cookie_token (Annotated[str, Depends(cookie_scheme)]): The authentication token as part of request cookie.
|
|
123
122
|
session (Annotated[Session, Depends(get_db)]): A database session.
|
|
124
123
|
|
|
125
124
|
Returns:
|
|
@@ -447,8 +446,31 @@ def refresh_access_token(
|
|
|
447
446
|
|
|
448
447
|
# Generate a new access token for the user
|
|
449
448
|
access_token = create_access_token({"id": user.id, "sub": user.username})
|
|
449
|
+
response = JSONResponse(content={"access_token": access_token})
|
|
450
450
|
|
|
451
|
-
|
|
451
|
+
if AuthConfig.cookie_domain == "":
|
|
452
|
+
response.set_cookie(
|
|
453
|
+
key="access_token",
|
|
454
|
+
value=access_token,
|
|
455
|
+
httponly=True,
|
|
456
|
+
secure=False,
|
|
457
|
+
samesite="lax",
|
|
458
|
+
max_age=AuthConfig.authjwt_access_token_expires * 60,
|
|
459
|
+
path="/",
|
|
460
|
+
)
|
|
461
|
+
else:
|
|
462
|
+
response.set_cookie(
|
|
463
|
+
key="access_token",
|
|
464
|
+
value=access_token,
|
|
465
|
+
httponly=True,
|
|
466
|
+
secure=True,
|
|
467
|
+
samesite="none",
|
|
468
|
+
max_age=AuthConfig.authjwt_access_token_expires * 60,
|
|
469
|
+
path="/",
|
|
470
|
+
domain=AuthConfig.cookie_domain,
|
|
471
|
+
)
|
|
472
|
+
|
|
473
|
+
return response
|
|
452
474
|
|
|
453
475
|
|
|
454
476
|
@router.post("/add_new_dm")
|
|
@@ -502,7 +524,7 @@ def add_new_analyst(
|
|
|
502
524
|
|
|
503
525
|
"""
|
|
504
526
|
# Check if the user who tries to create the user is either an analyst or an admin.
|
|
505
|
-
if
|
|
527
|
+
if user.role not in (UserRole.analyst, UserRole.admin):
|
|
506
528
|
raise HTTPException(
|
|
507
529
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
508
530
|
detail="Logged in user has insufficient rights.",
|
desdeo/api/routers/utils.py
CHANGED
|
@@ -67,7 +67,7 @@ def fetch_interactive_session(user: User, request: RequestType, session: Session
|
|
|
67
67
|
return interactive_session
|
|
68
68
|
|
|
69
69
|
|
|
70
|
-
def fetch_user_problem(user: User, request: RequestType, session: Session) -> ProblemDB:
|
|
70
|
+
def fetch_user_problem(user: User, request: RequestType, session: Session) -> ProblemDB | None:
|
|
71
71
|
"""Fetches a user's `ProblemDB` based on the id in the given request.
|
|
72
72
|
|
|
73
73
|
Args:
|
|
@@ -81,19 +81,21 @@ def fetch_user_problem(user: User, request: RequestType, session: Session) -> Pr
|
|
|
81
81
|
Returns:
|
|
82
82
|
Problem: the instance of `ProblemDB` with the given id.
|
|
83
83
|
"""
|
|
84
|
-
|
|
85
|
-
|
|
84
|
+
if request.problem_id is None:
|
|
85
|
+
return None
|
|
86
86
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
return problem_db
|
|
87
|
+
statement = select(ProblemDB).where(
|
|
88
|
+
ProblemDB.user_id == user.id,
|
|
89
|
+
ProblemDB.id == request.problem_id,
|
|
90
|
+
)
|
|
91
|
+
return session.exec(statement).first()
|
|
93
92
|
|
|
94
93
|
|
|
95
94
|
def fetch_parent_state(
|
|
96
|
-
user: User,
|
|
95
|
+
user: User,
|
|
96
|
+
request: RequestType,
|
|
97
|
+
session: Session,
|
|
98
|
+
interactive_session: InteractiveSessionDB | None = None,
|
|
97
99
|
) -> StateDB | None:
|
|
98
100
|
"""Fetches the parent state, if an id is given, or if defined in the given interactive session.
|
|
99
101
|
|
|
@@ -105,18 +107,17 @@ def fetch_parent_state(
|
|
|
105
107
|
given `interactive_session`, if available. If neither source provides a
|
|
106
108
|
parent state, `None` is returned.
|
|
107
109
|
|
|
108
|
-
|
|
109
110
|
Args:
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
111
|
+
user (User): the user for which the parent state is fetched.
|
|
112
|
+
request (RequestType): request containing details about the parent state and optionally the
|
|
113
|
+
interactive session.
|
|
114
|
+
session (Session): the database session from which to fetch the parent state.
|
|
115
|
+
interactive_session (InteractiveSessionDB | None, optional): the interactive session containing
|
|
116
|
+
information about the parent state. Defaults to None.
|
|
116
117
|
|
|
117
118
|
Raises:
|
|
118
|
-
|
|
119
|
-
|
|
119
|
+
HTTPException: when `request.parent_state_id` is not `None` and a `StateDB` with this id cannot
|
|
120
|
+
be found in the given database session.
|
|
120
121
|
|
|
121
122
|
Returns:
|
|
122
123
|
StateDB | None: if `request.parent_state_id` is given, returns the corresponding `StateDB`.
|
|
@@ -126,23 +127,19 @@ def fetch_parent_state(
|
|
|
126
127
|
if request.parent_state_id is None:
|
|
127
128
|
# parent state is assumed to be the last sate added to the session.
|
|
128
129
|
# if `interactive_session` is None, then parent state is set to None.
|
|
129
|
-
|
|
130
|
-
interactive_session.states[-1]
|
|
131
|
-
if (interactive_session is not None and len(interactive_session.states) > 0)
|
|
132
|
-
else None
|
|
133
|
-
)
|
|
130
|
+
return interactive_session.states[-1] if interactive_session and interactive_session.states else None
|
|
134
131
|
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
parent_state = session.exec(statement).first()
|
|
132
|
+
# request.parent_state_id is not None
|
|
133
|
+
statement = select(StateDB).where(StateDB.id == request.parent_state_id)
|
|
134
|
+
parent_state = session.exec(statement).first()
|
|
139
135
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
136
|
+
# this error is raised because if a parent_state_id is given, it is assumed that the
|
|
137
|
+
# user wished to use that state explicitly as the parent.
|
|
138
|
+
if parent_state is None:
|
|
139
|
+
raise HTTPException(
|
|
140
|
+
status_code=status.HTTP_404_NOT_FOUND,
|
|
141
|
+
detail=f"Could not find state with id={request.parent_state_id}",
|
|
142
|
+
)
|
|
146
143
|
|
|
147
144
|
return parent_state
|
|
148
145
|
|
|
@@ -153,9 +150,9 @@ class SessionContext:
|
|
|
153
150
|
|
|
154
151
|
user: User
|
|
155
152
|
db_session: Session
|
|
156
|
-
problem_db: ProblemDB
|
|
157
|
-
interactive_session: InteractiveSessionDB | None
|
|
158
|
-
parent_state: StateDB | None
|
|
153
|
+
problem_db: ProblemDB | None = None
|
|
154
|
+
interactive_session: InteractiveSessionDB | None = None
|
|
155
|
+
parent_state: StateDB | None = None
|
|
159
156
|
|
|
160
157
|
|
|
161
158
|
def get_session_context(
|
|
@@ -185,3 +182,11 @@ def get_session_context(
|
|
|
185
182
|
interactive_session=interactive_session,
|
|
186
183
|
parent_state=parent_state,
|
|
187
184
|
)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def get_session_context_without_request(
|
|
188
|
+
user: Annotated[User, Depends(get_current_user)],
|
|
189
|
+
db_session: Annotated[Session, Depends(get_session)],
|
|
190
|
+
) -> SessionContext:
|
|
191
|
+
"""Gets the current session context. Should be used as a dep."""
|
|
192
|
+
return SessionContext(user=user, db_session=db_session)
|