desdeo 2.1.0__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,9 +5,8 @@ from typing import Annotated
5
5
 
6
6
  from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
7
7
  from fastapi.responses import JSONResponse
8
- from sqlmodel import Session, select
8
+ from sqlmodel import select
9
9
 
10
- from desdeo.api.db import get_session
11
10
  from desdeo.api.models import (
12
11
  ForestProblemMetaData,
13
12
  ProblemDB,
@@ -26,6 +25,8 @@ from desdeo.api.routers.user_authentication import get_current_user
26
25
  from desdeo.problem import Problem
27
26
  from desdeo.tools.utils import available_solvers
28
27
 
28
+ from .utils import SessionContext, get_session_context, get_session_context_without_request
29
+
29
30
  router = APIRouter(prefix="/problem")
30
31
 
31
32
 
@@ -87,15 +88,13 @@ def get_problems_info(user: Annotated[User, Depends(get_current_user)]) -> list[
87
88
  @router.post("/get")
88
89
  def get_problem(
89
90
  request: ProblemGetRequest,
90
- user: Annotated[User, Depends(get_current_user)],
91
- session: Annotated[Session, Depends(get_session)],
91
+ context: Annotated[SessionContext, Depends(get_session_context)],
92
92
  ) -> ProblemInfo:
93
93
  """Get the model of a specific problem.
94
94
 
95
95
  Args:
96
96
  request (ProblemGetRequest): the request containing the problem's id `problem_id`.
97
- user (Annotated[User, Depends): the current user.
98
- session (Annotated[Session, Depends): the database session.
97
+ context (Annotated[SessionContext, Depends): the session context.
99
98
 
100
99
  Raises:
101
100
  HTTPException: could not find a problem with the given id.
@@ -103,29 +102,28 @@ def get_problem(
103
102
  Returns:
104
103
  ProblemInfo: detailed information on the requested problem.
105
104
  """
106
- problem = session.get(ProblemDB, request.problem_id)
105
+ problem_db = context.problem_db
107
106
 
108
- if problem is None:
107
+ # Ensure problem exists
108
+ if problem_db is None:
109
109
  raise HTTPException(
110
110
  status_code=status.HTTP_404_NOT_FOUND,
111
111
  detail=f"The problem with the requested id={request.problem_id} was not found.",
112
112
  )
113
113
 
114
- return problem
114
+ return problem_db
115
115
 
116
116
 
117
117
  @router.post("/add")
118
118
  def add_problem(
119
119
  request: Annotated[Problem, Depends(parse_problem_json)],
120
- user: Annotated[User, Depends(get_current_user)],
121
- session: Annotated[Session, Depends(get_session)],
120
+ context: Annotated[SessionContext, Depends(get_session_context_without_request)],
122
121
  ) -> ProblemInfo:
123
122
  """Add a newly defined problem to the database.
124
123
 
125
124
  Args:
126
125
  request (Problem): the JSON representation of the problem.
127
- user (Annotated[User, Depends): the current user.
128
- session (Annotated[Session, Depends): the database session.
126
+ context (Annotated[SessionContext, Depends): the session context.
129
127
 
130
128
  Note:
131
129
  Users with the role 'guest' may not add new problems.
@@ -136,10 +134,15 @@ def add_problem(
136
134
  Returns:
137
135
  ProblemInfo: the information about the problem added.
138
136
  """
137
+ user = context.user
138
+ db_session = context.db_session
139
+
139
140
  if user.role == UserRole.guest:
140
141
  raise HTTPException(
141
- status_code=status.HTTP_401_UNAUTHORIZED, detail="Guest users are not allowed to add new problems."
142
+ status_code=status.HTTP_401_UNAUTHORIZED,
143
+ detail="Guest users are not allowed to add new problems.",
142
144
  )
145
+
143
146
  try:
144
147
  problem_db = ProblemDB.from_problem(request, user=user)
145
148
  except Exception as e:
@@ -148,9 +151,9 @@ def add_problem(
148
151
  detail=f"Could not add problem. Possible reason: {e!r}",
149
152
  ) from e
150
153
 
151
- session.add(problem_db)
152
- session.commit()
153
- session.refresh(problem_db)
154
+ db_session.add(problem_db)
155
+ db_session.commit()
156
+ db_session.refresh(problem_db)
154
157
 
155
158
  return problem_db
156
159
 
@@ -158,15 +161,13 @@ def add_problem(
158
161
  @router.post("/add_json")
159
162
  def add_problem_json(
160
163
  json_file: UploadFile,
161
- user: Annotated[User, Depends(get_current_user)],
162
- session: Annotated[Session, Depends(get_session)],
164
+ context: Annotated[SessionContext, Depends(get_session_context_without_request)],
163
165
  ) -> ProblemInfo:
164
166
  """Adds a problem to the database based on its JSON definition.
165
167
 
166
168
  Args:
167
169
  json_file (UploadFile): a file in JSON format describing the problem.
168
- user (Annotated[User, Depends): the usr for which the problem is added.
169
- session (Annotated[Session, Depends): the database session.
170
+ context (Annotated[SessionContext, Depends): the session context.
170
171
 
171
172
  Raises:
172
173
  HTTPException: if the provided `json_file` is empty.
@@ -175,23 +176,25 @@ def add_problem_json(
175
176
  Returns:
176
177
  ProblemInfo: a description of the added problem.
177
178
  """
179
+ user = context.user
180
+ db_session = context.db_session
181
+
178
182
  raw = json_file.file.read()
179
183
 
180
184
  if not raw:
181
- raise HTTPException(400, "Empty upload.")
185
+ raise HTTPException(status_code=400, detail="Empty upload.")
182
186
 
183
187
  try:
184
- # for extra validation
185
- json.loads(raw)
188
+ json.loads(raw) # extra validation
186
189
  except json.JSONDecodeError as e:
187
- raise HTTPException(400, "Invalid JSON.") from e
190
+ raise HTTPException(status_code=400, detail="Invalid JSON.") from e
188
191
 
189
192
  problem = Problem.model_validate_json(raw, by_name=True)
190
193
  problem_db = ProblemDB.from_problem(problem, user=user)
191
194
 
192
- session.add(problem_db)
193
- session.commit()
194
- session.refresh(problem_db)
195
+ db_session.add(problem_db)
196
+ db_session.commit()
197
+ db_session.refresh(problem_db)
195
198
 
196
199
  return problem_db
197
200
 
@@ -199,8 +202,7 @@ def add_problem_json(
199
202
  @router.post("/get_metadata")
200
203
  def get_metadata(
201
204
  request: ProblemMetaDataGetRequest,
202
- user: Annotated[User, Depends(get_current_user)],
203
- session: Annotated[Session, Depends(get_session)],
205
+ context: Annotated[SessionContext, Depends(get_session_context)],
204
206
  ) -> list[ForestProblemMetaData | RepresentativeNonDominatedSolutions | SolverSelectionMetadata]:
205
207
  """Fetch specific metadata for a specific problem.
206
208
 
@@ -210,19 +212,21 @@ def get_metadata(
210
212
 
211
213
  Args:
212
214
  request (MetaDataGetRequest): the requested metadata type.
213
- user (Annotated[User, Depends]): the current user.
214
- session (Annotated[Session, Depends]): the database session.
215
+ context (Annotated[SessionContext, Depends]): the session context.
215
216
 
216
217
  Returns:
217
218
  list[ForestProblemMetadata | RepresentativeNonDominatedSolutions]: list containing all the metadata
218
219
  defined for the problem with the requested metadata type. If no match is found,
219
220
  returns an empty list.
220
221
  """
221
- statement = select(ProblemDB).where(ProblemDB.id == request.problem_id)
222
- problem_from_db = session.exec(statement).first()
222
+ db_session = context.db_session
223
+
224
+ problem_from_db = db_session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first()
225
+
223
226
  if problem_from_db is None:
224
227
  raise HTTPException(
225
- status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with ID {request.problem_id} not found!"
228
+ status_code=status.HTTP_404_NOT_FOUND,
229
+ detail=f"Problem with ID {request.problem_id} not found!",
226
230
  )
227
231
 
228
232
  problem_metadata = problem_from_db.problem_metadata
@@ -230,7 +234,6 @@ def get_metadata(
230
234
  if problem_metadata is None:
231
235
  # no metadata define for the problem
232
236
  return []
233
-
234
237
  # metadata is defined, try to find matching types based on request
235
238
  return [metadata for metadata in problem_metadata.all_metadata if metadata.metadata_type == request.metadata_type]
236
239
 
@@ -240,17 +243,17 @@ def get_available_solvers() -> list[str]:
240
243
  """Return the list of available solver names."""
241
244
  return list(available_solvers.keys())
242
245
 
246
+
243
247
  @router.post("/assign_solver")
244
248
  def select_solver(
245
249
  request: ProblemSelectSolverRequest,
246
- user: Annotated[User, Depends(get_current_user)],
247
- session: Annotated[Session, Depends(get_session)],
250
+ context: Annotated[SessionContext, Depends(get_session_context)],
248
251
  ) -> JSONResponse:
249
252
  """Assign a specific solver for a problem.
250
253
 
251
- request: ProblemSelectSolverRequest: The request containing problem id and string representation of the solver
252
- user: Annotated[User, Depends(get_current_user): The user that is logged in.
253
- session: Annotated[Session, Depends(get_session)]: The database session.
254
+ Args:
255
+ request: ProblemSelectSolverRequest: The request containing problem id and string representation of the solver
256
+ context: Annotated[SessionContext, Depends(get_session)]: The session context.
254
257
 
255
258
  Raises:
256
259
  HTTPException: Unknown solver, unauthorized user
@@ -258,50 +261,60 @@ def select_solver(
258
261
  Returns:
259
262
  JSONResponse: A simple confirmation.
260
263
  """
264
+ db_session = context.db_session
265
+ user = context.user
266
+
267
+ # Validate solver type
261
268
  if request.solver_string_representation not in [x for x, _ in available_solvers.items()]:
262
269
  raise HTTPException(
263
270
  detail=f"Solver of unknown type: {request.solver_string_representation}",
264
271
  status_code=status.HTTP_404_NOT_FOUND,
265
272
  )
266
273
 
267
- """Set a specific solver for a specific problem."""
268
- # Get the problem
269
- problem_db = session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first()
274
+ # Fetch problem
275
+ problem_db = db_session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first()
276
+
270
277
  if problem_db is None:
271
278
  raise HTTPException(
272
279
  detail=f"No problem with ID {request.problem_id}!",
273
280
  status_code=status.HTTP_404_NOT_FOUND,
274
281
  )
282
+
275
283
  # Auth the user
276
284
  if user.id != problem_db.user_id:
277
- raise HTTPException(detail="Unauthorized user!", status_code=status.HTTP_401_UNAUTHORIZED)
285
+ raise HTTPException(
286
+ detail="Unauthorized user!",
287
+ status_code=status.HTTP_401_UNAUTHORIZED,
288
+ )
278
289
 
279
290
  # All good, get on with it.
280
291
  problem_metadata = problem_db.problem_metadata
281
292
  if problem_metadata is None:
282
293
  # There's no metadata for this problem! Create some.
283
294
  problem_metadata = ProblemMetaDataDB(problem_id=problem_db.id, problem=problem_db)
284
- session.add(problem_metadata)
285
- session.commit()
286
- session.refresh(problem_metadata)
295
+ db_session.add(problem_metadata)
296
+ db_session.commit()
297
+ db_session.refresh(problem_metadata)
287
298
 
299
+ # Remove existing solver selection metadata
288
300
  if problem_metadata.solver_selection_metadata:
289
- session.delete(problem_metadata.solver_selection_metadata[-1])
290
- session.commit()
301
+ db_session.delete(problem_metadata.solver_selection_metadata[-1])
302
+ db_session.commit()
291
303
 
304
+ # Add new solver selection metadata
292
305
  solver_selection_metadata = SolverSelectionMetadata(
293
306
  metadata_id=problem_metadata.id,
294
307
  solver_string_representation=request.solver_string_representation,
295
308
  metadata_instance=problem_metadata,
296
309
  )
297
310
 
298
- session.add(solver_selection_metadata)
299
- session.commit()
300
- session.refresh(solver_selection_metadata)
311
+ db_session.add(solver_selection_metadata)
312
+ db_session.commit()
313
+ db_session.refresh(solver_selection_metadata)
301
314
 
302
315
  problem_metadata.solver_selection_metadata.append(solver_selection_metadata)
303
- session.add(problem_metadata)
304
- session.commit()
305
- session.refresh(problem_metadata)
316
+ db_session.add(problem_metadata)
317
+ db_session.commit()
318
+ db_session.refresh(problem_metadata)
306
319
 
307
320
  return JSONResponse(content={"message": "OK"}, status_code=status.HTTP_200_OK)
@@ -2,26 +2,20 @@
2
2
 
3
3
  from typing import Annotated
4
4
 
5
- from fastapi import APIRouter, Depends
6
- from sqlmodel import Session
5
+ from fastapi import APIRouter, Depends, HTTPException, status
7
6
 
8
- from desdeo.api.db import get_session
9
7
  from desdeo.api.models import (
10
- InteractiveSessionDB,
11
8
  PreferenceDB,
12
- ProblemDB,
13
9
  RPMSolveRequest,
14
10
  RPMState,
15
11
  StateDB,
16
- User,
17
12
  )
18
13
  from desdeo.api.routers.problem import check_solver
19
- from desdeo.api.routers.user_authentication import get_current_user
20
14
  from desdeo.mcdm import rpm_solve_solutions
21
15
  from desdeo.problem import Problem
22
16
  from desdeo.tools import SolverResults
23
17
 
24
- from .utils import fetch_interactive_session, fetch_parent_state, fetch_user_problem
18
+ from .utils import SessionContext, get_session_context
25
19
 
26
20
  router = APIRouter(prefix="/method/rpm")
27
21
 
@@ -29,31 +23,35 @@ router = APIRouter(prefix="/method/rpm")
29
23
  @router.post("/solve")
30
24
  def solve_solutions(
31
25
  request: RPMSolveRequest,
32
- user: Annotated[User, Depends(get_current_user)],
33
- session: Annotated[Session, Depends(get_session)],
26
+ context: Annotated[SessionContext, Depends(get_session_context)],
34
27
  ) -> RPMState:
35
28
  """Runs an iteration of the reference point method.
36
29
 
37
30
  Args:
38
31
  request (RPMSolveRequest): a request with the needed information to run the method.
39
32
  user (Annotated[User, Depends): the current user.
40
- session (Annotated[Session, Depends): the current database session.
33
+ context (Annotated[SessionContext, Depends): the current session context.
41
34
 
42
35
  Returns:
43
36
  RPMState: a state with information on the results of iterating the reference point method
44
37
  once.
45
38
  """
46
- # fetch interactive session, parent state, and ProblemDB
47
- interactive_session: InteractiveSessionDB = fetch_interactive_session(user, request, session)
48
- parent_state = fetch_parent_state(user, request, session, interactive_session)
49
-
50
- problem_db: ProblemDB = fetch_user_problem(user, request, session)
39
+ user = context.user
40
+ db_session = context.db_session
41
+ problem_db = context.problem_db
42
+ interactive_session = context.interactive_session
43
+ parent_state = context.parent_state
44
+
45
+ # ensure problem exists
46
+ if problem_db is None:
47
+ raise HTTPException(
48
+ status_code=status.HTTP_400_BAD_REQUEST,
49
+ detail="Problem context missing.",
50
+ )
51
51
 
52
52
  solver = check_solver(problem_db=problem_db)
53
-
54
53
  problem = Problem.from_problemdb(problem_db)
55
54
 
56
- # optimize for solutions
57
55
  solver_results: list[SolverResults] = rpm_solve_solutions(
58
56
  problem,
59
57
  request.preference.aspiration_levels,
@@ -63,11 +61,15 @@ def solve_solutions(
63
61
  )
64
62
 
65
63
  # create DB preference
66
- preference_db = PreferenceDB(user_id=user.id, problem_id=problem_db.id, preference=request.preference)
64
+ preference_db = PreferenceDB(
65
+ user_id=user.id,
66
+ problem_id=problem_db.id,
67
+ preference=request.preference,
68
+ )
67
69
 
68
- session.add(preference_db)
69
- session.commit()
70
- session.refresh(preference_db)
70
+ db_session.add(preference_db)
71
+ db_session.commit()
72
+ db_session.refresh(preference_db)
71
73
 
72
74
  # create state and add to DB
73
75
  rpm_state = RPMState(
@@ -81,13 +83,13 @@ def solve_solutions(
81
83
  state = StateDB(
82
84
  problem_id=problem_db.id,
83
85
  preference_id=preference_db.id,
84
- session_id=interactive_session.id if interactive_session is not None else None,
85
- parent_id=parent_state.id if parent_state is not None else None,
86
+ session_id=interactive_session.id if interactive_session else None,
87
+ parent_id=parent_state.id if parent_state else None,
86
88
  state=rpm_state,
87
89
  )
88
90
 
89
- session.add(state)
90
- session.commit()
91
- session.refresh(state)
91
+ db_session.add(state)
92
+ db_session.commit()
93
+ db_session.refresh(state)
92
94
 
93
95
  return rpm_state
@@ -14,7 +14,7 @@ from desdeo.api.models import (
14
14
  User,
15
15
  )
16
16
  from desdeo.api.routers.user_authentication import get_current_user
17
- from desdeo.api.routers.utils import fetch_interactive_session
17
+ from desdeo.api.routers.utils import SessionContext, fetch_interactive_session, get_session_context_without_request
18
18
 
19
19
  router = APIRouter(prefix="/session")
20
20
 
@@ -22,21 +22,25 @@ router = APIRouter(prefix="/session")
22
22
  @router.post("/new")
23
23
  def create_new_session(
24
24
  request: CreateSessionRequest,
25
- user: Annotated[User, Depends(get_current_user)],
26
- session: Annotated[Session, Depends(get_db_session)],
25
+ context: Annotated[SessionContext, Depends(get_session_context_without_request)],
27
26
  ) -> InteractiveSessionInfo:
28
- """."""
29
- interactive_session = InteractiveSessionDB(user_id=user.id, info=request.info)
27
+ """Creates a new interactive session."""
28
+ user = context.user
29
+ db_session = context.db_session
30
+
31
+ interactive_session = InteractiveSessionDB(
32
+ user_id=user.id,
33
+ info=request.info,
34
+ )
30
35
 
31
- session.add(interactive_session)
32
- session.commit()
33
- session.refresh(interactive_session)
36
+ db_session.add(interactive_session)
37
+ db_session.commit()
38
+ db_session.refresh(interactive_session)
34
39
 
35
40
  user.active_session_id = interactive_session.id
36
41
 
37
- session.add(user)
38
- session.commit()
39
- session.refresh(interactive_session)
42
+ db_session.add(user)
43
+ db_session.commit()
40
44
 
41
45
  return interactive_session
42
46
 
@@ -9,8 +9,6 @@ from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, Securit
9
9
  from fastapi.responses import JSONResponse
10
10
  from fastapi.security import (
11
11
  APIKeyCookie,
12
- HTTPAuthorizationCredentials,
13
- HTTPBearer,
14
12
  OAuth2PasswordBearer,
15
13
  OAuth2PasswordRequestForm,
16
14
  )
@@ -119,7 +117,8 @@ def get_current_user(
119
117
  This function is a dependency for other functions that need to get the current user.
120
118
 
121
119
  Args:
122
- token (Annotated[str, Depends(oauth2_scheme)]): The authentication token.
120
+ header_token (Annotated[str, Depends(oauth2_scheme)]): The authentication token as part of the request header.
121
+ cookie_token (Annotated[str, Depends(cookie_scheme)]): The authentication token as part of request cookie.
123
122
  session (Annotated[Session, Depends(get_db)]): A database session.
124
123
 
125
124
  Returns:
@@ -447,8 +446,31 @@ def refresh_access_token(
447
446
 
448
447
  # Generate a new access token for the user
449
448
  access_token = create_access_token({"id": user.id, "sub": user.username})
449
+ response = JSONResponse(content={"access_token": access_token})
450
450
 
451
- return {"access_token": access_token}
451
+ if AuthConfig.cookie_domain == "":
452
+ response.set_cookie(
453
+ key="access_token",
454
+ value=access_token,
455
+ httponly=True,
456
+ secure=False,
457
+ samesite="lax",
458
+ max_age=AuthConfig.authjwt_access_token_expires * 60,
459
+ path="/",
460
+ )
461
+ else:
462
+ response.set_cookie(
463
+ key="access_token",
464
+ value=access_token,
465
+ httponly=True,
466
+ secure=True,
467
+ samesite="none",
468
+ max_age=AuthConfig.authjwt_access_token_expires * 60,
469
+ path="/",
470
+ domain=AuthConfig.cookie_domain,
471
+ )
472
+
473
+ return response
452
474
 
453
475
 
454
476
  @router.post("/add_new_dm")
@@ -502,7 +524,7 @@ def add_new_analyst(
502
524
 
503
525
  """
504
526
  # Check if the user who tries to create the user is either an analyst or an admin.
505
- if not (user.role == UserRole.analyst or user.role == UserRole.admin):
527
+ if user.role not in (UserRole.analyst, UserRole.admin):
506
528
  raise HTTPException(
507
529
  status_code=status.HTTP_401_UNAUTHORIZED,
508
530
  detail="Logged in user has insufficient rights.",
@@ -67,7 +67,7 @@ def fetch_interactive_session(user: User, request: RequestType, session: Session
67
67
  return interactive_session
68
68
 
69
69
 
70
- def fetch_user_problem(user: User, request: RequestType, session: Session) -> ProblemDB:
70
+ def fetch_user_problem(user: User, request: RequestType, session: Session) -> ProblemDB | None:
71
71
  """Fetches a user's `ProblemDB` based on the id in the given request.
72
72
 
73
73
  Args:
@@ -81,19 +81,21 @@ def fetch_user_problem(user: User, request: RequestType, session: Session) -> Pr
81
81
  Returns:
82
82
  Problem: the instance of `ProblemDB` with the given id.
83
83
  """
84
- statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
85
- problem_db = session.exec(statement).first()
84
+ if request.problem_id is None:
85
+ return None
86
86
 
87
- if problem_db is None:
88
- raise HTTPException(
89
- status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
90
- )
91
-
92
- return problem_db
87
+ statement = select(ProblemDB).where(
88
+ ProblemDB.user_id == user.id,
89
+ ProblemDB.id == request.problem_id,
90
+ )
91
+ return session.exec(statement).first()
93
92
 
94
93
 
95
94
  def fetch_parent_state(
96
- user: User, request: RequestType, session: Session, interactive_session: InteractiveSessionDB | None = None
95
+ user: User,
96
+ request: RequestType,
97
+ session: Session,
98
+ interactive_session: InteractiveSessionDB | None = None,
97
99
  ) -> StateDB | None:
98
100
  """Fetches the parent state, if an id is given, or if defined in the given interactive session.
99
101
 
@@ -105,18 +107,17 @@ def fetch_parent_state(
105
107
  given `interactive_session`, if available. If neither source provides a
106
108
  parent state, `None` is returned.
107
109
 
108
-
109
110
  Args:
110
- user (User): the user for which the parent state is fetched.
111
- request (RequestType): request containing details about the parent state and optionally the
112
- interactive session.
113
- session (Session): the database session from which to fetch the parent state.
114
- interactive_session (InteractiveSessionDB | None, optional): the interactive session containing
115
- information about the parent state. Defaults to None.
111
+ user (User): the user for which the parent state is fetched.
112
+ request (RequestType): request containing details about the parent state and optionally the
113
+ interactive session.
114
+ session (Session): the database session from which to fetch the parent state.
115
+ interactive_session (InteractiveSessionDB | None, optional): the interactive session containing
116
+ information about the parent state. Defaults to None.
116
117
 
117
118
  Raises:
118
- HTTPException: when `request.parent_state_id` is not `None` and a `StateDB` with this id cannot
119
- be found in the given database session.
119
+ HTTPException: when `request.parent_state_id` is not `None` and a `StateDB` with this id cannot
120
+ be found in the given database session.
120
121
 
121
122
  Returns:
122
123
  StateDB | None: if `request.parent_state_id` is given, returns the corresponding `StateDB`.
@@ -126,23 +127,19 @@ def fetch_parent_state(
126
127
  if request.parent_state_id is None:
127
128
  # parent state is assumed to be the last sate added to the session.
128
129
  # if `interactive_session` is None, then parent state is set to None.
129
- parent_state = (
130
- interactive_session.states[-1]
131
- if (interactive_session is not None and len(interactive_session.states) > 0)
132
- else None
133
- )
130
+ return interactive_session.states[-1] if interactive_session and interactive_session.states else None
134
131
 
135
- else:
136
- # request.parent_state_id is not None
137
- statement = select(StateDB).where(StateDB.id == request.parent_state_id)
138
- parent_state = session.exec(statement).first()
132
+ # request.parent_state_id is not None
133
+ statement = select(StateDB).where(StateDB.id == request.parent_state_id)
134
+ parent_state = session.exec(statement).first()
139
135
 
140
- # this error is raised because if a parent_state_id is given, it is assumed that the
141
- # user wished to use that state explicitly as the parent.
142
- if parent_state is None:
143
- raise HTTPException(
144
- status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
145
- )
136
+ # this error is raised because if a parent_state_id is given, it is assumed that the
137
+ # user wished to use that state explicitly as the parent.
138
+ if parent_state is None:
139
+ raise HTTPException(
140
+ status_code=status.HTTP_404_NOT_FOUND,
141
+ detail=f"Could not find state with id={request.parent_state_id}",
142
+ )
146
143
 
147
144
  return parent_state
148
145
 
@@ -153,9 +150,9 @@ class SessionContext:
153
150
 
154
151
  user: User
155
152
  db_session: Session
156
- problem_db: ProblemDB
157
- interactive_session: InteractiveSessionDB | None
158
- parent_state: StateDB | None
153
+ problem_db: ProblemDB | None = None
154
+ interactive_session: InteractiveSessionDB | None = None
155
+ parent_state: StateDB | None = None
159
156
 
160
157
 
161
158
  def get_session_context(
@@ -185,3 +182,11 @@ def get_session_context(
185
182
  interactive_session=interactive_session,
186
183
  parent_state=parent_state,
187
184
  )
185
+
186
+
187
+ def get_session_context_without_request(
188
+ user: Annotated[User, Depends(get_current_user)],
189
+ db_session: Annotated[Session, Depends(get_session)],
190
+ ) -> SessionContext:
191
+ """Gets the current session context. Should be used as a dep."""
192
+ return SessionContext(user=user, db_session=db_session)