desdeo 2.1.0__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,9 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
6
6
  from numpy import allclose
7
7
  from sqlmodel import Session, select
8
8
 
9
- from desdeo.api.db import get_session
10
9
  from desdeo.api.models import (
11
- InteractiveSessionDB,
12
10
  IntermediateSolutionRequest,
13
11
  NIMBUSClassificationRequest,
14
12
  NIMBUSClassificationResponse,
@@ -25,7 +23,6 @@ from desdeo.api.models import (
25
23
  NIMBUSSaveRequest,
26
24
  NIMBUSSaveResponse,
27
25
  NIMBUSSaveState,
28
- ProblemDB,
29
26
  ReferencePoint,
30
27
  SavedSolutionReference,
31
28
  SolutionReference,
@@ -38,11 +35,12 @@ from desdeo.api.models.generic import SolutionInfo
38
35
  from desdeo.api.models.state import IntermediateSolutionState
39
36
  from desdeo.api.routers.generic import solve_intermediate
40
37
  from desdeo.api.routers.problem import check_solver
41
- from desdeo.api.routers.user_authentication import get_current_user
42
38
  from desdeo.mcdm.nimbus import generate_starting_point, solve_sub_problems
43
39
  from desdeo.problem import Problem
44
40
  from desdeo.tools import SolverResults
45
41
 
42
+ from .utils import SessionContext, get_session_context
43
+
46
44
  router = APIRouter(prefix="/method/nimbus")
47
45
 
48
46
 
@@ -50,7 +48,7 @@ router = APIRouter(prefix="/method/nimbus")
50
48
  def filter_duplicates(solutions: list[SavedSolutionReference]) -> list[SavedSolutionReference]:
51
49
  """Filters out the duplicate values of objectives."""
52
50
  # No solutions or only one solution. There can not be any duplicates.
53
- if len(solutions) < 2:
51
+ if len(solutions) < 2: # noqa: PLR2004
54
52
  return solutions
55
53
 
56
54
  # Get the objective values
@@ -110,58 +108,24 @@ def collect_all_solutions(user: User, problem_id: int, session: Session) -> list
110
108
  @router.post("/solve")
111
109
  def solve_solutions(
112
110
  request: NIMBUSClassificationRequest,
113
- user: Annotated[User, Depends(get_current_user)],
114
- session: Annotated[Session, Depends(get_session)],
111
+ context: Annotated[SessionContext, Depends(get_session_context)],
115
112
  ) -> NIMBUSClassificationResponse:
116
113
  """Solve the problem using the NIMBUS method."""
117
- if request.session_id is not None:
118
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
119
- interactive_session = session.exec(statement)
120
-
121
- if interactive_session is None:
122
- raise HTTPException(
123
- status_code=status.HTTP_404_NOT_FOUND,
124
- detail=f"Could not find interactive session with id={request.session_id}.",
125
- )
126
- else:
127
- # request.session_id is None:
128
- # use active session instead
129
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
130
-
131
- interactive_session = session.exec(statement).first()
132
-
133
- # fetch the problem from the DB
134
- statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
135
- problem_db = session.exec(statement).first()
114
+ db_session = context.db_session
115
+ user = context.user
116
+ problem_db = context.problem_db
117
+ interactive_session = context.interactive_session
118
+ parent_state = context.parent_state
136
119
 
120
+ # Ensure problem exists
137
121
  if problem_db is None:
138
122
  raise HTTPException(
139
123
  status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
140
124
  )
141
125
 
142
126
  solver = check_solver(problem_db=problem_db)
143
-
144
127
  problem = Problem.from_problemdb(problem_db)
145
128
 
146
- # fetch parent state
147
- if request.parent_state_id is None:
148
- # parent state is assumed to be the last state added to the session.
149
- parent_state = (
150
- interactive_session.states[-1]
151
- if (interactive_session is not None and len(interactive_session.states) > 0)
152
- else None
153
- )
154
-
155
- else:
156
- # request.parent_state_id is not None
157
- statement = select(StateDB).where(StateDB.id == request.parent_state_id)
158
- parent_state = session.exec(statement).first()
159
-
160
- if parent_state is None:
161
- raise HTTPException(
162
- status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
163
- )
164
-
165
129
  solver_results: list[SolverResults] = solve_sub_problems(
166
130
  problem=problem,
167
131
  current_objectives=request.current_objectives,
@@ -185,24 +149,24 @@ def solve_solutions(
185
149
 
186
150
  # create DB state and add it to the DB
187
151
  state = StateDB.create(
188
- database_session=session,
152
+ database_session=db_session,
189
153
  problem_id=problem_db.id,
190
154
  session_id=interactive_session.id if interactive_session is not None else None,
191
155
  parent_id=parent_state.id if parent_state is not None else None,
192
156
  state=nimbus_state,
193
157
  )
194
158
 
195
- session.add(state)
196
- session.commit()
197
- session.refresh(state)
159
+ db_session.add(state)
160
+ db_session.commit()
161
+ db_session.refresh(state)
198
162
 
199
163
  # Collect all current solutions
200
164
  current_solutions: list[SolutionReference] = []
201
165
  for i, _ in enumerate(solver_results):
202
166
  current_solutions.append(SolutionReference(state=state, solution_index=i))
203
167
 
204
- saved_solutions = collect_saved_solutions(user, request.problem_id, session)
205
- all_solutions = collect_all_solutions(user, request.problem_id, session)
168
+ saved_solutions = collect_saved_solutions(user, request.problem_id, db_session)
169
+ all_solutions = collect_all_solutions(user, request.problem_id, db_session)
206
170
 
207
171
  return NIMBUSClassificationResponse(
208
172
  state_id=state.id,
@@ -217,31 +181,14 @@ def solve_solutions(
217
181
  @router.post("/initialize")
218
182
  def initialize(
219
183
  request: NIMBUSInitializationRequest,
220
- user: Annotated[User, Depends(get_current_user)],
221
- session: Annotated[Session, Depends(get_session)],
184
+ context: Annotated[SessionContext, Depends(get_session_context)],
222
185
  ) -> NIMBUSInitializationResponse:
223
186
  """Initialize the problem for the NIMBUS method."""
224
- if request.session_id is not None:
225
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
226
- interactive_session = session.exec(statement)
227
-
228
- if interactive_session is None:
229
- raise HTTPException(
230
- status_code=status.HTTP_404_NOT_FOUND,
231
- detail=f"Could not find interactive session with id={request.session_id}.",
232
- )
233
- else:
234
- # request.session_id is None:
235
- # use active session instead
236
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
237
-
238
- interactive_session = session.exec(statement).first()
239
-
240
- print(interactive_session)
241
-
242
- # fetch the problem from the DB
243
- statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
244
- problem_db = session.exec(statement).first()
187
+ db_session = context.db_session
188
+ user = context.user
189
+ problem_db = context.problem_db
190
+ interactive_session = context.interactive_session
191
+ parent_state = context.parent_state
245
192
 
246
193
  if problem_db is None:
247
194
  raise HTTPException(
@@ -249,18 +196,15 @@ def initialize(
249
196
  )
250
197
 
251
198
  solver = check_solver(problem_db=problem_db)
252
-
253
199
  problem = Problem.from_problemdb(problem_db)
254
200
 
255
201
  if isinstance(ref_point := request.starting_point, ReferencePoint):
256
- # ReferencePoint
257
202
  starting_point = ref_point.aspiration_levels
258
203
 
259
204
  elif isinstance(info := request.starting_point, SolutionInfo):
260
- # SolutionInfo
261
205
  # fetch the solution
262
206
  statement = select(StateDB).where(StateDB.id == info.state_id)
263
- state = session.exec(statement).first()
207
+ state = db_session.exec(statement).first()
264
208
 
265
209
  if state is None:
266
210
  raise HTTPException(
@@ -270,7 +214,6 @@ def initialize(
270
214
  starting_point = state.state.result_objective_values[info.solution_index]
271
215
 
272
216
  else:
273
- # if not starting point is provided, generate it
274
217
  starting_point = None
275
218
 
276
219
  start_result = generate_starting_point(
@@ -281,18 +224,6 @@ def initialize(
281
224
  solver_options=request.solver_options,
282
225
  )
283
226
 
284
- # fetch parent state if it is given
285
- if request.parent_state_id is None:
286
- parent_state = None
287
- else:
288
- statement = session.select(StateDB).where(StateDB.id == request.parent_state_id)
289
- parent_state = session.exec(statement).first()
290
-
291
- if parent_state is None:
292
- raise HTTPException(
293
- status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
294
- )
295
-
296
227
  initialization_state = NIMBUSInitializationState(
297
228
  reference_point=starting_point,
298
229
  scalarization_options=request.scalarization_options,
@@ -302,20 +233,20 @@ def initialize(
302
233
 
303
234
  # create DB state and add it to the DB
304
235
  state = StateDB.create(
305
- database_session=session,
236
+ database_session=db_session,
306
237
  problem_id=problem_db.id,
307
- session_id=interactive_session.id if interactive_session is not None else None,
308
- parent_id=parent_state.id if parent_state is not None else None,
238
+ session_id=interactive_session.id if interactive_session else None,
239
+ parent_id=parent_state.id if parent_state else None,
309
240
  state=initialization_state,
310
241
  )
311
242
 
312
- session.add(state)
313
- session.commit()
314
- session.refresh(state)
243
+ db_session.add(state)
244
+ db_session.commit()
245
+ db_session.refresh(state)
315
246
 
316
247
  current_solutions = [SolutionReference(state=state, solution_index=0)]
317
- saved_solutions = collect_saved_solutions(user, request.problem_id, session)
318
- all_solutions = collect_all_solutions(user, request.problem_id, session)
248
+ saved_solutions = collect_saved_solutions(user, request.problem_id, db_session)
249
+ all_solutions = collect_all_solutions(user, request.problem_id, db_session)
319
250
 
320
251
  return NIMBUSInitializationResponse(
321
252
  state_id=state.id,
@@ -327,40 +258,22 @@ def initialize(
327
258
 
328
259
  @router.post("/save")
329
260
  def save(
330
- request: NIMBUSSaveRequest,
331
- user: Annotated[User, Depends(get_current_user)],
332
- session: Annotated[Session, Depends(get_session)],
261
+ request: NIMBUSSaveRequest, context: Annotated[SessionContext, Depends(get_session_context)]
333
262
  ) -> NIMBUSSaveResponse:
334
263
  """Save solutions."""
335
- if request.session_id is not None:
336
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
337
- interactive_session = session.exec(statement)
338
-
339
- if interactive_session is None:
340
- raise HTTPException(
341
- status_code=status.HTTP_404_NOT_FOUND,
342
- detail=f"Could not find interactive session with id={request.session_id}.",
343
- )
344
- else:
345
- # request.session_id is None:
346
- # use active session instead
347
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
264
+ db_session = context.db_session
265
+ user = context.user
266
+ interactive_session = context.interactive_session
267
+ parent_state = context.parent_state
348
268
 
349
- interactive_session = session.exec(statement).first()
350
-
351
- # fetch parent state
352
269
  if request.parent_state_id is None:
353
- # parent state is assumed to be the last state added to the session.
354
270
  parent_state = (
355
271
  interactive_session.states[-1]
356
272
  if (interactive_session is not None and len(interactive_session.states) > 0)
357
273
  else None
358
274
  )
359
-
360
275
  else:
361
- # request.parent_state_id is not None
362
- statement = select(StateDB).where(StateDB.id == request.parent_state_id)
363
- parent_state = session.exec(statement).first()
276
+ parent_state = db_session.exec(select(StateDB).where(StateDB.id == request.parent_state_id)).first()
364
277
 
365
278
  if parent_state is None:
366
279
  raise HTTPException(
@@ -372,7 +285,7 @@ def save(
372
285
  new_solutions: list[UserSavedSolutionDB] = []
373
286
 
374
287
  for info in request.solution_info:
375
- existing_solution = session.exec(
288
+ existing_solution = db_session.exec(
376
289
  select(UserSavedSolutionDB).where(
377
290
  UserSavedSolutionDB.origin_state_id == info.state_id,
378
291
  UserSavedSolutionDB.solution_index == info.solution_index,
@@ -380,42 +293,38 @@ def save(
380
293
  ).first()
381
294
 
382
295
  if existing_solution is not None:
383
- # Update the name of the existing solution
384
296
  existing_solution.name = info.name
385
-
386
- session.add(existing_solution)
387
-
297
+ db_session.add(existing_solution)
388
298
  updated_solutions.append(existing_solution)
299
+
389
300
  else:
390
- # This is a new solution
391
301
  new_solution = UserSavedSolutionDB.from_state_info(
392
- session, user.id, request.problem_id, info.state_id, info.solution_index, info.name
302
+ db_session, user.id, request.problem_id, info.state_id, info.solution_index, info.name
393
303
  )
394
304
 
395
- session.add(new_solution)
396
-
305
+ db_session.add(new_solution)
397
306
  new_solutions.append(new_solution)
398
307
 
399
308
  # Commit existing and new solutions
400
- if updated_solutions or new_solution:
401
- session.commit()
402
- [session.refresh(row) for row in updated_solutions + new_solutions]
309
+ if updated_solutions or new_solutions:
310
+ db_session.commit()
311
+ [db_session.refresh(row) for row in updated_solutions + new_solutions]
403
312
 
404
- # save solver results for state in SolverResults format just for consistency (dont save name field to state)
313
+ # save solver results for state in SolverResults format just for consistency
405
314
  save_state = NIMBUSSaveState(solutions=updated_solutions + new_solutions)
406
315
 
407
316
  # create DB state
408
317
  state = StateDB.create(
409
- database_session=session,
318
+ database_session=db_session,
410
319
  problem_id=request.problem_id,
411
320
  session_id=interactive_session.id if interactive_session is not None else None,
412
321
  parent_id=parent_state.id if parent_state is not None else None,
413
322
  state=save_state,
414
323
  )
415
324
 
416
- session.add(state)
417
- session.commit()
418
- session.refresh(state)
325
+ db_session.add(state)
326
+ db_session.commit()
327
+ db_session.refresh(state)
419
328
 
420
329
  return NIMBUSSaveResponse(state_id=state.id)
421
330
 
@@ -423,20 +332,22 @@ def save(
423
332
  @router.post("/intermediate")
424
333
  def solve_nimbus_intermediate(
425
334
  request: IntermediateSolutionRequest,
426
- user: Annotated[User, Depends(get_current_user)],
427
- session: Annotated[Session, Depends(get_session)],
335
+ context: Annotated[SessionContext, Depends(get_session_context)],
428
336
  ) -> NIMBUSIntermediateSolutionResponse:
429
337
  """Solve intermediate solutions by forwarding the request to generic intermediate endpoint with context nimbus."""
338
+ db_session = context.db_session
339
+ user = context.user
340
+
430
341
  # Add NIMBUS context to request
431
342
  request.context = "nimbus"
343
+
432
344
  # Forward to generic endpoint
433
- intermediate_response = solve_intermediate(request, user, session)
345
+ intermediate_response = solve_intermediate(request, context)
434
346
 
435
347
  # Get saved solutions for this user and problem
436
- saved_solutions = collect_saved_solutions(user, request.problem_id, session)
437
-
348
+ saved_solutions = collect_saved_solutions(user, request.problem_id, db_session)
438
349
  # Get all solutions including the newly generated intermediate ones
439
- all_solutions = collect_all_solutions(user, request.problem_id, session)
350
+ all_solutions = collect_all_solutions(user, request.problem_id, db_session)
440
351
 
441
352
  return NIMBUSIntermediateSolutionResponse(
442
353
  state_id=intermediate_response.state_id,
@@ -451,24 +362,17 @@ def solve_nimbus_intermediate(
451
362
  @router.post("/get-or-initialize")
452
363
  def get_or_initialize(
453
364
  request: NIMBUSInitializationRequest,
454
- user: Annotated[User, Depends(get_current_user)],
455
- session: Annotated[Session, Depends(get_session)],
456
- ) -> NIMBUSInitializationResponse | NIMBUSClassificationResponse | \
457
- NIMBUSIntermediateSolutionResponse | NIMBUSFinalizeResponse:
365
+ context: Annotated[SessionContext, Depends(get_session_context)],
366
+ ) -> (
367
+ NIMBUSInitializationResponse
368
+ | NIMBUSClassificationResponse
369
+ | NIMBUSIntermediateSolutionResponse
370
+ | NIMBUSFinalizeResponse
371
+ ):
458
372
  """Get the latest NIMBUS state if it exists, or initialize a new one if it doesn't."""
459
- if request.session_id is not None:
460
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
461
- interactive_session = session.exec(statement)
462
-
463
- if interactive_session is None:
464
- raise HTTPException(
465
- status_code=status.HTTP_404_NOT_FOUND,
466
- detail=f"Could not find interactive session with id={request.session_id}.",
467
- )
468
- else:
469
- # use active session instead
470
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
471
- interactive_session = session.exec(statement).first()
373
+ db_session = context.db_session
374
+ user = context.user
375
+ interactive_session = context.interactive_session
472
376
 
473
377
  # Look for latest relevant state in the session
474
378
  statement = (
@@ -479,7 +383,7 @@ def get_or_initialize(
479
383
  )
480
384
  .order_by(StateDB.id.desc())
481
385
  )
482
- states = session.exec(statement).all()
386
+ states = db_session.exec(statement).all()
483
387
 
484
388
  # Find the latest relevant state (NIMBUS classification, initialization, or intermediate with NIMBUS context)
485
389
  latest_state = None
@@ -491,17 +395,15 @@ def get_or_initialize(
491
395
  break
492
396
 
493
397
  if latest_state is not None:
494
- saved_solutions = collect_saved_solutions(user, request.problem_id, session)
495
- all_solutions = collect_all_solutions(user, request.problem_id, session)
496
- # Handle both single result and list of results cases
398
+ saved_solutions = collect_saved_solutions(user, request.problem_id, db_session)
399
+ all_solutions = collect_all_solutions(user, request.problem_id, db_session)
400
+
497
401
  solver_results = latest_state.state.solver_results
498
- if isinstance(solver_results, list):
499
- current_solutions = [
500
- SolutionReference(state=latest_state, solution_index=i) for i in range(len(solver_results))
501
- ]
502
- else:
503
- # Single result case (NIMBUSInitializationState)
504
- current_solutions = [SolutionReference(state=latest_state, solution_index=0)]
402
+ current_solutions = (
403
+ [SolutionReference(state=latest_state, solution_index=i) for i in range(len(solver_results))]
404
+ if isinstance(solver_results, list)
405
+ else [SolutionReference(state=latest_state, solution_index=0)]
406
+ )
505
407
 
506
408
  if isinstance(latest_state.state, NIMBUSClassificationState):
507
409
  return NIMBUSClassificationResponse(
@@ -524,7 +426,6 @@ def get_or_initialize(
524
426
  )
525
427
 
526
428
  if isinstance(latest_state.state, NIMBUSFinalState):
527
-
528
429
  solution_index = latest_state.state.solution_result_index
529
430
  origin_state_id = latest_state.state.solution_origin_state_id
530
431
 
@@ -532,7 +433,7 @@ def get_or_initialize(
532
433
  solution_index=solution_index,
533
434
  state_id=origin_state_id,
534
435
  objective_values=latest_state.state.solver_results.optimal_objectives,
535
- variable_values=latest_state.state.solver_results.optimal_variables
436
+ variable_values=latest_state.state.solver_results.optimal_variables,
536
437
  )
537
438
 
538
439
  return NIMBUSFinalizeResponse(
@@ -541,7 +442,6 @@ def get_or_initialize(
541
442
  saved_solutions=saved_solutions,
542
443
  all_solutions=all_solutions,
543
444
  )
544
-
545
445
  # NIMBUSInitializationState
546
446
  return NIMBUSInitializationResponse(
547
447
  state_id=latest_state.id,
@@ -551,21 +451,18 @@ def get_or_initialize(
551
451
  )
552
452
 
553
453
  # No relevant state found, initialize a new one
554
- return initialize(request, user, session)
454
+ return initialize(request, context)
555
455
 
556
456
 
557
457
  @router.post("/finalize")
558
458
  def finalize_nimbus(
559
- request: NIMBUSFinalizeRequest,
560
- user: Annotated[User, Depends(get_current_user)],
561
- session: Annotated[Session, Depends(get_session)]
459
+ request: NIMBUSFinalizeRequest, context: Annotated[SessionContext, Depends(get_session_context)]
562
460
  ) -> NIMBUSFinalizeResponse:
563
461
  """An endpoint for finishing up the nimbus process.
564
462
 
565
463
  Args:
566
464
  request (NIMBUSFinalizeRequest): The request containing the final solution, etc.
567
- user (Annotated[User, Depends): The current user.
568
- session (Annotated[Session, Depends): The database session.
465
+ context (Annotated[User, get_session_context): The current context.
569
466
 
570
467
  Raises:
571
468
  HTTPException
@@ -573,47 +470,17 @@ def finalize_nimbus(
573
470
  Returns:
574
471
  NIMBUSFinalizeResponse: Response containing info on the final solution.
575
472
  """
576
- if request.session_id is not None:
577
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
578
- interactive_session = session.exec(statement)
579
-
580
- if interactive_session is None:
581
- raise HTTPException(
582
- status_code=status.HTTP_404_NOT_FOUND,
583
- detail=f"Could not find interactive session with id={request.session_id}.",
584
- )
585
- else:
586
- # request.session_id is None:
587
- # use active session instead
588
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
589
-
590
- interactive_session = session.exec(statement).first()
591
-
592
- if request.parent_state_id is None:
593
- parent_state = None
594
- else:
595
- statement = session.select(StateDB).where(StateDB.id == request.parent_state_id)
596
- parent_state = session.exec(statement).first()
597
-
598
- if parent_state is None:
599
- raise HTTPException(
600
- status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
601
- )
602
-
603
- # fetch the problem from the DB
604
- statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
605
- problem_db = session.exec(statement).first()
606
-
607
- if problem_db is None:
608
- raise HTTPException(
609
- status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
610
- )
473
+ db_session = context.db_session
474
+ user = context.user
475
+ interactive_session = context.interactive_session
476
+ parent_state = context.parent_state
477
+ problem_db = context.problem_db
611
478
 
612
479
  solution_state_id = request.solution_info.state_id
613
480
  solution_index = request.solution_info.solution_index
614
481
 
615
- statement = select(StateDB).where(StateDB.id == solution_state_id)
616
- actual_state = session.exec(statement).first().state
482
+ state = db_session.exec(select(StateDB).where(StateDB.id == solution_state_id)).first()
483
+ actual_state = state.state if state else None
617
484
  if actual_state is None:
618
485
  raise HTTPException(
619
486
  detail="No concrete substate!",
@@ -623,22 +490,22 @@ def finalize_nimbus(
623
490
  final_state = NIMBUSFinalState(
624
491
  solution_origin_state_id=solution_state_id,
625
492
  solution_result_index=solution_index,
626
- solver_results=actual_state.solver_results[solution_index]
493
+ solver_results=actual_state.solver_results[solution_index],
627
494
  )
628
495
 
629
496
  state = StateDB.create(
630
- database_session=session,
497
+ database_session=db_session,
631
498
  problem_id=problem_db.id,
632
499
  session_id=interactive_session.id if interactive_session is not None else None,
633
500
  parent_id=parent_state.id if parent_state is not None else None,
634
501
  state=final_state,
635
502
  )
636
503
 
637
- session.add(state)
638
- session.commit()
639
- session.refresh(state)
504
+ db_session.add(state)
505
+ db_session.commit()
506
+ db_session.refresh(state)
640
507
 
641
- solution_reference_response=SolutionReferenceResponse(
508
+ solution_reference_response = SolutionReferenceResponse(
642
509
  solution_index=solution_index,
643
510
  state_id=solution_state_id,
644
511
  objective_values=final_state.solver_results.optimal_objectives,
@@ -648,22 +515,21 @@ def finalize_nimbus(
648
515
  return NIMBUSFinalizeResponse(
649
516
  state_id=state.id,
650
517
  final_solution=solution_reference_response,
651
- saved_solutions=collect_saved_solutions(user=user, problem_id=problem_db.id, session=session),
652
- all_solutions=collect_all_solutions(user=user, problem_id=problem_db.id, session=session),
518
+ saved_solutions=collect_saved_solutions(user=user, problem_id=problem_db.id, session=db_session),
519
+ all_solutions=collect_all_solutions(user=user, problem_id=problem_db.id, session=db_session),
653
520
  )
654
521
 
522
+
655
523
  @router.post("/delete_save")
656
524
  def delete_save(
657
525
  request: NIMBUSDeleteSaveRequest,
658
- user: Annotated[User, Depends(get_current_user)],
659
- session: Annotated[Session, Depends(get_session)]
526
+ context: Annotated[SessionContext, Depends(get_session_context)],
660
527
  ) -> NIMBUSDeleteSaveResponse:
661
528
  """Endpoint for deleting saved solutions.
662
529
 
663
530
  Args:
664
531
  request (NIMBUSDeleteSaveRequest): request containing necessary information for deleting a save
665
- user (Annotated[User, Depends): the current (logged in) user
666
- session (Annotated[Session, Depends): database session
532
+ context (Annotated[SessionContext, Depends): session context
667
533
 
668
534
  Raises:
669
535
  HTTPException
@@ -671,7 +537,9 @@ def delete_save(
671
537
  Returns:
672
538
  NIMBUSDeleteSaveResponse: Response acknowledging the deletion of save and other useful info.
673
539
  """
674
- to_be_deleted = session.exec(
540
+ db_session = context.db_session
541
+
542
+ to_be_deleted = db_session.exec(
675
543
  select(UserSavedSolutionDB).where(
676
544
  UserSavedSolutionDB.origin_state_id == request.state_id,
677
545
  UserSavedSolutionDB.solution_index == request.solution_index,
@@ -679,15 +547,12 @@ def delete_save(
679
547
  ).first()
680
548
 
681
549
  if to_be_deleted is None:
682
- raise HTTPException(
683
- detail="Unable to find a saved solution!",
684
- status_code=status.HTTP_404_NOT_FOUND
685
- )
550
+ raise HTTPException(detail="Unable to find a saved solution!", status_code=status.HTTP_404_NOT_FOUND)
686
551
 
687
- session.delete(to_be_deleted)
688
- session.commit()
552
+ db_session.delete(to_be_deleted)
553
+ db_session.commit()
689
554
 
690
- to_be_deleted = session.exec(
555
+ to_be_deleted = db_session.exec(
691
556
  select(UserSavedSolutionDB).where(
692
557
  UserSavedSolutionDB.origin_state_id == request.state_id,
693
558
  UserSavedSolutionDB.solution_index == request.solution_index,
@@ -696,10 +561,6 @@ def delete_save(
696
561
 
697
562
  if to_be_deleted is not None:
698
563
  raise HTTPException(
699
- detail="Could not delete the saved solution!",
700
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
564
+ detail="Could not delete the saved solution!", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
701
565
  )
702
-
703
- return NIMBUSDeleteSaveResponse(
704
- message="Save deleted."
705
- )
566
+ return NIMBUSDeleteSaveResponse(message="Save deleted.")