desdeo 2.1.0__py3-none-any.whl → 2.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/api/models/nimbus.py +8 -4
- desdeo/api/routers/emo.py +75 -104
- desdeo/api/routers/generic.py +26 -58
- desdeo/api/routers/nimbus.py +108 -247
- desdeo/api/routers/problem.py +69 -56
- desdeo/api/routers/reference_point_method.py +29 -27
- desdeo/api/routers/session.py +15 -11
- desdeo/api/routers/user_authentication.py +27 -5
- desdeo/api/routers/utils.py +42 -37
- desdeo/api/routers/utopia.py +11 -12
- desdeo/api/tests/test_routes.py +6 -5
- desdeo/emo/__init__.py +2 -0
- desdeo/emo/operators/__init__.py +1 -1
- desdeo/emo/operators/generator.py +153 -2
- desdeo/emo/options/__init__.py +4 -0
- desdeo/emo/options/generator.py +24 -0
- desdeo/problem/__init__.py +12 -11
- desdeo/problem/evaluator.py +4 -5
- desdeo/problem/gurobipy_evaluator.py +37 -12
- desdeo/problem/infix_parser.py +1 -16
- desdeo/problem/json_parser.py +7 -11
- desdeo/problem/schema.py +6 -9
- desdeo/problem/utils.py +1 -1
- desdeo/tools/pyomo_solver_interfaces.py +1 -1
- {desdeo-2.1.0.dist-info → desdeo-2.2.0.dist-info}/METADATA +21 -12
- {desdeo-2.1.0.dist-info → desdeo-2.2.0.dist-info}/RECORD +28 -28
- {desdeo-2.1.0.dist-info → desdeo-2.2.0.dist-info}/WHEEL +1 -1
- {desdeo-2.1.0.dist-info → desdeo-2.2.0.dist-info}/licenses/LICENSE +0 -0
desdeo/api/routers/nimbus.py
CHANGED
|
@@ -6,9 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
|
|
6
6
|
from numpy import allclose
|
|
7
7
|
from sqlmodel import Session, select
|
|
8
8
|
|
|
9
|
-
from desdeo.api.db import get_session
|
|
10
9
|
from desdeo.api.models import (
|
|
11
|
-
InteractiveSessionDB,
|
|
12
10
|
IntermediateSolutionRequest,
|
|
13
11
|
NIMBUSClassificationRequest,
|
|
14
12
|
NIMBUSClassificationResponse,
|
|
@@ -25,7 +23,6 @@ from desdeo.api.models import (
|
|
|
25
23
|
NIMBUSSaveRequest,
|
|
26
24
|
NIMBUSSaveResponse,
|
|
27
25
|
NIMBUSSaveState,
|
|
28
|
-
ProblemDB,
|
|
29
26
|
ReferencePoint,
|
|
30
27
|
SavedSolutionReference,
|
|
31
28
|
SolutionReference,
|
|
@@ -38,11 +35,12 @@ from desdeo.api.models.generic import SolutionInfo
|
|
|
38
35
|
from desdeo.api.models.state import IntermediateSolutionState
|
|
39
36
|
from desdeo.api.routers.generic import solve_intermediate
|
|
40
37
|
from desdeo.api.routers.problem import check_solver
|
|
41
|
-
from desdeo.api.routers.user_authentication import get_current_user
|
|
42
38
|
from desdeo.mcdm.nimbus import generate_starting_point, solve_sub_problems
|
|
43
39
|
from desdeo.problem import Problem
|
|
44
40
|
from desdeo.tools import SolverResults
|
|
45
41
|
|
|
42
|
+
from .utils import SessionContext, get_session_context
|
|
43
|
+
|
|
46
44
|
router = APIRouter(prefix="/method/nimbus")
|
|
47
45
|
|
|
48
46
|
|
|
@@ -50,7 +48,7 @@ router = APIRouter(prefix="/method/nimbus")
|
|
|
50
48
|
def filter_duplicates(solutions: list[SavedSolutionReference]) -> list[SavedSolutionReference]:
|
|
51
49
|
"""Filters out the duplicate values of objectives."""
|
|
52
50
|
# No solutions or only one solution. There can not be any duplicates.
|
|
53
|
-
if len(solutions) < 2:
|
|
51
|
+
if len(solutions) < 2: # noqa: PLR2004
|
|
54
52
|
return solutions
|
|
55
53
|
|
|
56
54
|
# Get the objective values
|
|
@@ -110,58 +108,24 @@ def collect_all_solutions(user: User, problem_id: int, session: Session) -> list
|
|
|
110
108
|
@router.post("/solve")
|
|
111
109
|
def solve_solutions(
|
|
112
110
|
request: NIMBUSClassificationRequest,
|
|
113
|
-
|
|
114
|
-
session: Annotated[Session, Depends(get_session)],
|
|
111
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
115
112
|
) -> NIMBUSClassificationResponse:
|
|
116
113
|
"""Solve the problem using the NIMBUS method."""
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
raise HTTPException(
|
|
123
|
-
status_code=status.HTTP_404_NOT_FOUND,
|
|
124
|
-
detail=f"Could not find interactive session with id={request.session_id}.",
|
|
125
|
-
)
|
|
126
|
-
else:
|
|
127
|
-
# request.session_id is None:
|
|
128
|
-
# use active session instead
|
|
129
|
-
statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
|
|
130
|
-
|
|
131
|
-
interactive_session = session.exec(statement).first()
|
|
132
|
-
|
|
133
|
-
# fetch the problem from the DB
|
|
134
|
-
statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
|
|
135
|
-
problem_db = session.exec(statement).first()
|
|
114
|
+
db_session = context.db_session
|
|
115
|
+
user = context.user
|
|
116
|
+
problem_db = context.problem_db
|
|
117
|
+
interactive_session = context.interactive_session
|
|
118
|
+
parent_state = context.parent_state
|
|
136
119
|
|
|
120
|
+
# Ensure problem exists
|
|
137
121
|
if problem_db is None:
|
|
138
122
|
raise HTTPException(
|
|
139
123
|
status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
|
|
140
124
|
)
|
|
141
125
|
|
|
142
126
|
solver = check_solver(problem_db=problem_db)
|
|
143
|
-
|
|
144
127
|
problem = Problem.from_problemdb(problem_db)
|
|
145
128
|
|
|
146
|
-
# fetch parent state
|
|
147
|
-
if request.parent_state_id is None:
|
|
148
|
-
# parent state is assumed to be the last state added to the session.
|
|
149
|
-
parent_state = (
|
|
150
|
-
interactive_session.states[-1]
|
|
151
|
-
if (interactive_session is not None and len(interactive_session.states) > 0)
|
|
152
|
-
else None
|
|
153
|
-
)
|
|
154
|
-
|
|
155
|
-
else:
|
|
156
|
-
# request.parent_state_id is not None
|
|
157
|
-
statement = select(StateDB).where(StateDB.id == request.parent_state_id)
|
|
158
|
-
parent_state = session.exec(statement).first()
|
|
159
|
-
|
|
160
|
-
if parent_state is None:
|
|
161
|
-
raise HTTPException(
|
|
162
|
-
status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
|
|
163
|
-
)
|
|
164
|
-
|
|
165
129
|
solver_results: list[SolverResults] = solve_sub_problems(
|
|
166
130
|
problem=problem,
|
|
167
131
|
current_objectives=request.current_objectives,
|
|
@@ -185,24 +149,24 @@ def solve_solutions(
|
|
|
185
149
|
|
|
186
150
|
# create DB state and add it to the DB
|
|
187
151
|
state = StateDB.create(
|
|
188
|
-
database_session=
|
|
152
|
+
database_session=db_session,
|
|
189
153
|
problem_id=problem_db.id,
|
|
190
154
|
session_id=interactive_session.id if interactive_session is not None else None,
|
|
191
155
|
parent_id=parent_state.id if parent_state is not None else None,
|
|
192
156
|
state=nimbus_state,
|
|
193
157
|
)
|
|
194
158
|
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
159
|
+
db_session.add(state)
|
|
160
|
+
db_session.commit()
|
|
161
|
+
db_session.refresh(state)
|
|
198
162
|
|
|
199
163
|
# Collect all current solutions
|
|
200
164
|
current_solutions: list[SolutionReference] = []
|
|
201
165
|
for i, _ in enumerate(solver_results):
|
|
202
166
|
current_solutions.append(SolutionReference(state=state, solution_index=i))
|
|
203
167
|
|
|
204
|
-
saved_solutions = collect_saved_solutions(user, request.problem_id,
|
|
205
|
-
all_solutions = collect_all_solutions(user, request.problem_id,
|
|
168
|
+
saved_solutions = collect_saved_solutions(user, request.problem_id, db_session)
|
|
169
|
+
all_solutions = collect_all_solutions(user, request.problem_id, db_session)
|
|
206
170
|
|
|
207
171
|
return NIMBUSClassificationResponse(
|
|
208
172
|
state_id=state.id,
|
|
@@ -217,31 +181,14 @@ def solve_solutions(
|
|
|
217
181
|
@router.post("/initialize")
|
|
218
182
|
def initialize(
|
|
219
183
|
request: NIMBUSInitializationRequest,
|
|
220
|
-
|
|
221
|
-
session: Annotated[Session, Depends(get_session)],
|
|
184
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
222
185
|
) -> NIMBUSInitializationResponse:
|
|
223
186
|
"""Initialize the problem for the NIMBUS method."""
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
raise HTTPException(
|
|
230
|
-
status_code=status.HTTP_404_NOT_FOUND,
|
|
231
|
-
detail=f"Could not find interactive session with id={request.session_id}.",
|
|
232
|
-
)
|
|
233
|
-
else:
|
|
234
|
-
# request.session_id is None:
|
|
235
|
-
# use active session instead
|
|
236
|
-
statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
|
|
237
|
-
|
|
238
|
-
interactive_session = session.exec(statement).first()
|
|
239
|
-
|
|
240
|
-
print(interactive_session)
|
|
241
|
-
|
|
242
|
-
# fetch the problem from the DB
|
|
243
|
-
statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
|
|
244
|
-
problem_db = session.exec(statement).first()
|
|
187
|
+
db_session = context.db_session
|
|
188
|
+
user = context.user
|
|
189
|
+
problem_db = context.problem_db
|
|
190
|
+
interactive_session = context.interactive_session
|
|
191
|
+
parent_state = context.parent_state
|
|
245
192
|
|
|
246
193
|
if problem_db is None:
|
|
247
194
|
raise HTTPException(
|
|
@@ -249,18 +196,15 @@ def initialize(
|
|
|
249
196
|
)
|
|
250
197
|
|
|
251
198
|
solver = check_solver(problem_db=problem_db)
|
|
252
|
-
|
|
253
199
|
problem = Problem.from_problemdb(problem_db)
|
|
254
200
|
|
|
255
201
|
if isinstance(ref_point := request.starting_point, ReferencePoint):
|
|
256
|
-
# ReferencePoint
|
|
257
202
|
starting_point = ref_point.aspiration_levels
|
|
258
203
|
|
|
259
204
|
elif isinstance(info := request.starting_point, SolutionInfo):
|
|
260
|
-
# SolutionInfo
|
|
261
205
|
# fetch the solution
|
|
262
206
|
statement = select(StateDB).where(StateDB.id == info.state_id)
|
|
263
|
-
state =
|
|
207
|
+
state = db_session.exec(statement).first()
|
|
264
208
|
|
|
265
209
|
if state is None:
|
|
266
210
|
raise HTTPException(
|
|
@@ -270,7 +214,6 @@ def initialize(
|
|
|
270
214
|
starting_point = state.state.result_objective_values[info.solution_index]
|
|
271
215
|
|
|
272
216
|
else:
|
|
273
|
-
# if not starting point is provided, generate it
|
|
274
217
|
starting_point = None
|
|
275
218
|
|
|
276
219
|
start_result = generate_starting_point(
|
|
@@ -281,18 +224,6 @@ def initialize(
|
|
|
281
224
|
solver_options=request.solver_options,
|
|
282
225
|
)
|
|
283
226
|
|
|
284
|
-
# fetch parent state if it is given
|
|
285
|
-
if request.parent_state_id is None:
|
|
286
|
-
parent_state = None
|
|
287
|
-
else:
|
|
288
|
-
statement = session.select(StateDB).where(StateDB.id == request.parent_state_id)
|
|
289
|
-
parent_state = session.exec(statement).first()
|
|
290
|
-
|
|
291
|
-
if parent_state is None:
|
|
292
|
-
raise HTTPException(
|
|
293
|
-
status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
|
|
294
|
-
)
|
|
295
|
-
|
|
296
227
|
initialization_state = NIMBUSInitializationState(
|
|
297
228
|
reference_point=starting_point,
|
|
298
229
|
scalarization_options=request.scalarization_options,
|
|
@@ -302,20 +233,20 @@ def initialize(
|
|
|
302
233
|
|
|
303
234
|
# create DB state and add it to the DB
|
|
304
235
|
state = StateDB.create(
|
|
305
|
-
database_session=
|
|
236
|
+
database_session=db_session,
|
|
306
237
|
problem_id=problem_db.id,
|
|
307
|
-
session_id=interactive_session.id if interactive_session
|
|
308
|
-
parent_id=parent_state.id if parent_state
|
|
238
|
+
session_id=interactive_session.id if interactive_session else None,
|
|
239
|
+
parent_id=parent_state.id if parent_state else None,
|
|
309
240
|
state=initialization_state,
|
|
310
241
|
)
|
|
311
242
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
243
|
+
db_session.add(state)
|
|
244
|
+
db_session.commit()
|
|
245
|
+
db_session.refresh(state)
|
|
315
246
|
|
|
316
247
|
current_solutions = [SolutionReference(state=state, solution_index=0)]
|
|
317
|
-
saved_solutions = collect_saved_solutions(user, request.problem_id,
|
|
318
|
-
all_solutions = collect_all_solutions(user, request.problem_id,
|
|
248
|
+
saved_solutions = collect_saved_solutions(user, request.problem_id, db_session)
|
|
249
|
+
all_solutions = collect_all_solutions(user, request.problem_id, db_session)
|
|
319
250
|
|
|
320
251
|
return NIMBUSInitializationResponse(
|
|
321
252
|
state_id=state.id,
|
|
@@ -327,40 +258,22 @@ def initialize(
|
|
|
327
258
|
|
|
328
259
|
@router.post("/save")
|
|
329
260
|
def save(
|
|
330
|
-
request: NIMBUSSaveRequest,
|
|
331
|
-
user: Annotated[User, Depends(get_current_user)],
|
|
332
|
-
session: Annotated[Session, Depends(get_session)],
|
|
261
|
+
request: NIMBUSSaveRequest, context: Annotated[SessionContext, Depends(get_session_context)]
|
|
333
262
|
) -> NIMBUSSaveResponse:
|
|
334
263
|
"""Save solutions."""
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
if interactive_session is None:
|
|
340
|
-
raise HTTPException(
|
|
341
|
-
status_code=status.HTTP_404_NOT_FOUND,
|
|
342
|
-
detail=f"Could not find interactive session with id={request.session_id}.",
|
|
343
|
-
)
|
|
344
|
-
else:
|
|
345
|
-
# request.session_id is None:
|
|
346
|
-
# use active session instead
|
|
347
|
-
statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
|
|
264
|
+
db_session = context.db_session
|
|
265
|
+
user = context.user
|
|
266
|
+
interactive_session = context.interactive_session
|
|
267
|
+
parent_state = context.parent_state
|
|
348
268
|
|
|
349
|
-
interactive_session = session.exec(statement).first()
|
|
350
|
-
|
|
351
|
-
# fetch parent state
|
|
352
269
|
if request.parent_state_id is None:
|
|
353
|
-
# parent state is assumed to be the last state added to the session.
|
|
354
270
|
parent_state = (
|
|
355
271
|
interactive_session.states[-1]
|
|
356
272
|
if (interactive_session is not None and len(interactive_session.states) > 0)
|
|
357
273
|
else None
|
|
358
274
|
)
|
|
359
|
-
|
|
360
275
|
else:
|
|
361
|
-
|
|
362
|
-
statement = select(StateDB).where(StateDB.id == request.parent_state_id)
|
|
363
|
-
parent_state = session.exec(statement).first()
|
|
276
|
+
parent_state = db_session.exec(select(StateDB).where(StateDB.id == request.parent_state_id)).first()
|
|
364
277
|
|
|
365
278
|
if parent_state is None:
|
|
366
279
|
raise HTTPException(
|
|
@@ -372,7 +285,7 @@ def save(
|
|
|
372
285
|
new_solutions: list[UserSavedSolutionDB] = []
|
|
373
286
|
|
|
374
287
|
for info in request.solution_info:
|
|
375
|
-
existing_solution =
|
|
288
|
+
existing_solution = db_session.exec(
|
|
376
289
|
select(UserSavedSolutionDB).where(
|
|
377
290
|
UserSavedSolutionDB.origin_state_id == info.state_id,
|
|
378
291
|
UserSavedSolutionDB.solution_index == info.solution_index,
|
|
@@ -380,42 +293,38 @@ def save(
|
|
|
380
293
|
).first()
|
|
381
294
|
|
|
382
295
|
if existing_solution is not None:
|
|
383
|
-
# Update the name of the existing solution
|
|
384
296
|
existing_solution.name = info.name
|
|
385
|
-
|
|
386
|
-
session.add(existing_solution)
|
|
387
|
-
|
|
297
|
+
db_session.add(existing_solution)
|
|
388
298
|
updated_solutions.append(existing_solution)
|
|
299
|
+
|
|
389
300
|
else:
|
|
390
|
-
# This is a new solution
|
|
391
301
|
new_solution = UserSavedSolutionDB.from_state_info(
|
|
392
|
-
|
|
302
|
+
db_session, user.id, request.problem_id, info.state_id, info.solution_index, info.name
|
|
393
303
|
)
|
|
394
304
|
|
|
395
|
-
|
|
396
|
-
|
|
305
|
+
db_session.add(new_solution)
|
|
397
306
|
new_solutions.append(new_solution)
|
|
398
307
|
|
|
399
308
|
# Commit existing and new solutions
|
|
400
|
-
if updated_solutions or
|
|
401
|
-
|
|
402
|
-
[
|
|
309
|
+
if updated_solutions or new_solutions:
|
|
310
|
+
db_session.commit()
|
|
311
|
+
[db_session.refresh(row) for row in updated_solutions + new_solutions]
|
|
403
312
|
|
|
404
|
-
# save solver results for state in SolverResults format just for consistency
|
|
313
|
+
# save solver results for state in SolverResults format just for consistency
|
|
405
314
|
save_state = NIMBUSSaveState(solutions=updated_solutions + new_solutions)
|
|
406
315
|
|
|
407
316
|
# create DB state
|
|
408
317
|
state = StateDB.create(
|
|
409
|
-
database_session=
|
|
318
|
+
database_session=db_session,
|
|
410
319
|
problem_id=request.problem_id,
|
|
411
320
|
session_id=interactive_session.id if interactive_session is not None else None,
|
|
412
321
|
parent_id=parent_state.id if parent_state is not None else None,
|
|
413
322
|
state=save_state,
|
|
414
323
|
)
|
|
415
324
|
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
325
|
+
db_session.add(state)
|
|
326
|
+
db_session.commit()
|
|
327
|
+
db_session.refresh(state)
|
|
419
328
|
|
|
420
329
|
return NIMBUSSaveResponse(state_id=state.id)
|
|
421
330
|
|
|
@@ -423,20 +332,22 @@ def save(
|
|
|
423
332
|
@router.post("/intermediate")
|
|
424
333
|
def solve_nimbus_intermediate(
|
|
425
334
|
request: IntermediateSolutionRequest,
|
|
426
|
-
|
|
427
|
-
session: Annotated[Session, Depends(get_session)],
|
|
335
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
428
336
|
) -> NIMBUSIntermediateSolutionResponse:
|
|
429
337
|
"""Solve intermediate solutions by forwarding the request to generic intermediate endpoint with context nimbus."""
|
|
338
|
+
db_session = context.db_session
|
|
339
|
+
user = context.user
|
|
340
|
+
|
|
430
341
|
# Add NIMBUS context to request
|
|
431
342
|
request.context = "nimbus"
|
|
343
|
+
|
|
432
344
|
# Forward to generic endpoint
|
|
433
|
-
intermediate_response = solve_intermediate(request,
|
|
345
|
+
intermediate_response = solve_intermediate(request, context)
|
|
434
346
|
|
|
435
347
|
# Get saved solutions for this user and problem
|
|
436
|
-
saved_solutions = collect_saved_solutions(user, request.problem_id,
|
|
437
|
-
|
|
348
|
+
saved_solutions = collect_saved_solutions(user, request.problem_id, db_session)
|
|
438
349
|
# Get all solutions including the newly generated intermediate ones
|
|
439
|
-
all_solutions = collect_all_solutions(user, request.problem_id,
|
|
350
|
+
all_solutions = collect_all_solutions(user, request.problem_id, db_session)
|
|
440
351
|
|
|
441
352
|
return NIMBUSIntermediateSolutionResponse(
|
|
442
353
|
state_id=intermediate_response.state_id,
|
|
@@ -451,24 +362,17 @@ def solve_nimbus_intermediate(
|
|
|
451
362
|
@router.post("/get-or-initialize")
|
|
452
363
|
def get_or_initialize(
|
|
453
364
|
request: NIMBUSInitializationRequest,
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
365
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
366
|
+
) -> (
|
|
367
|
+
NIMBUSInitializationResponse
|
|
368
|
+
| NIMBUSClassificationResponse
|
|
369
|
+
| NIMBUSIntermediateSolutionResponse
|
|
370
|
+
| NIMBUSFinalizeResponse
|
|
371
|
+
):
|
|
458
372
|
"""Get the latest NIMBUS state if it exists, or initialize a new one if it doesn't."""
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
if interactive_session is None:
|
|
464
|
-
raise HTTPException(
|
|
465
|
-
status_code=status.HTTP_404_NOT_FOUND,
|
|
466
|
-
detail=f"Could not find interactive session with id={request.session_id}.",
|
|
467
|
-
)
|
|
468
|
-
else:
|
|
469
|
-
# use active session instead
|
|
470
|
-
statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
|
|
471
|
-
interactive_session = session.exec(statement).first()
|
|
373
|
+
db_session = context.db_session
|
|
374
|
+
user = context.user
|
|
375
|
+
interactive_session = context.interactive_session
|
|
472
376
|
|
|
473
377
|
# Look for latest relevant state in the session
|
|
474
378
|
statement = (
|
|
@@ -479,7 +383,7 @@ def get_or_initialize(
|
|
|
479
383
|
)
|
|
480
384
|
.order_by(StateDB.id.desc())
|
|
481
385
|
)
|
|
482
|
-
states =
|
|
386
|
+
states = db_session.exec(statement).all()
|
|
483
387
|
|
|
484
388
|
# Find the latest relevant state (NIMBUS classification, initialization, or intermediate with NIMBUS context)
|
|
485
389
|
latest_state = None
|
|
@@ -491,17 +395,15 @@ def get_or_initialize(
|
|
|
491
395
|
break
|
|
492
396
|
|
|
493
397
|
if latest_state is not None:
|
|
494
|
-
saved_solutions = collect_saved_solutions(user, request.problem_id,
|
|
495
|
-
all_solutions = collect_all_solutions(user, request.problem_id,
|
|
496
|
-
|
|
398
|
+
saved_solutions = collect_saved_solutions(user, request.problem_id, db_session)
|
|
399
|
+
all_solutions = collect_all_solutions(user, request.problem_id, db_session)
|
|
400
|
+
|
|
497
401
|
solver_results = latest_state.state.solver_results
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
]
|
|
502
|
-
|
|
503
|
-
# Single result case (NIMBUSInitializationState)
|
|
504
|
-
current_solutions = [SolutionReference(state=latest_state, solution_index=0)]
|
|
402
|
+
current_solutions = (
|
|
403
|
+
[SolutionReference(state=latest_state, solution_index=i) for i in range(len(solver_results))]
|
|
404
|
+
if isinstance(solver_results, list)
|
|
405
|
+
else [SolutionReference(state=latest_state, solution_index=0)]
|
|
406
|
+
)
|
|
505
407
|
|
|
506
408
|
if isinstance(latest_state.state, NIMBUSClassificationState):
|
|
507
409
|
return NIMBUSClassificationResponse(
|
|
@@ -524,7 +426,6 @@ def get_or_initialize(
|
|
|
524
426
|
)
|
|
525
427
|
|
|
526
428
|
if isinstance(latest_state.state, NIMBUSFinalState):
|
|
527
|
-
|
|
528
429
|
solution_index = latest_state.state.solution_result_index
|
|
529
430
|
origin_state_id = latest_state.state.solution_origin_state_id
|
|
530
431
|
|
|
@@ -532,7 +433,7 @@ def get_or_initialize(
|
|
|
532
433
|
solution_index=solution_index,
|
|
533
434
|
state_id=origin_state_id,
|
|
534
435
|
objective_values=latest_state.state.solver_results.optimal_objectives,
|
|
535
|
-
variable_values=latest_state.state.solver_results.optimal_variables
|
|
436
|
+
variable_values=latest_state.state.solver_results.optimal_variables,
|
|
536
437
|
)
|
|
537
438
|
|
|
538
439
|
return NIMBUSFinalizeResponse(
|
|
@@ -541,7 +442,6 @@ def get_or_initialize(
|
|
|
541
442
|
saved_solutions=saved_solutions,
|
|
542
443
|
all_solutions=all_solutions,
|
|
543
444
|
)
|
|
544
|
-
|
|
545
445
|
# NIMBUSInitializationState
|
|
546
446
|
return NIMBUSInitializationResponse(
|
|
547
447
|
state_id=latest_state.id,
|
|
@@ -551,21 +451,18 @@ def get_or_initialize(
|
|
|
551
451
|
)
|
|
552
452
|
|
|
553
453
|
# No relevant state found, initialize a new one
|
|
554
|
-
return initialize(request,
|
|
454
|
+
return initialize(request, context)
|
|
555
455
|
|
|
556
456
|
|
|
557
457
|
@router.post("/finalize")
|
|
558
458
|
def finalize_nimbus(
|
|
559
|
-
request: NIMBUSFinalizeRequest,
|
|
560
|
-
user: Annotated[User, Depends(get_current_user)],
|
|
561
|
-
session: Annotated[Session, Depends(get_session)]
|
|
459
|
+
request: NIMBUSFinalizeRequest, context: Annotated[SessionContext, Depends(get_session_context)]
|
|
562
460
|
) -> NIMBUSFinalizeResponse:
|
|
563
461
|
"""An endpoint for finishing up the nimbus process.
|
|
564
462
|
|
|
565
463
|
Args:
|
|
566
464
|
request (NIMBUSFinalizeRequest): The request containing the final solution, etc.
|
|
567
|
-
|
|
568
|
-
session (Annotated[Session, Depends): The database session.
|
|
465
|
+
context (Annotated[User, get_session_context): The current context.
|
|
569
466
|
|
|
570
467
|
Raises:
|
|
571
468
|
HTTPException
|
|
@@ -573,47 +470,17 @@ def finalize_nimbus(
|
|
|
573
470
|
Returns:
|
|
574
471
|
NIMBUSFinalizeResponse: Response containing info on the final solution.
|
|
575
472
|
"""
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
raise HTTPException(
|
|
582
|
-
status_code=status.HTTP_404_NOT_FOUND,
|
|
583
|
-
detail=f"Could not find interactive session with id={request.session_id}.",
|
|
584
|
-
)
|
|
585
|
-
else:
|
|
586
|
-
# request.session_id is None:
|
|
587
|
-
# use active session instead
|
|
588
|
-
statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
|
|
589
|
-
|
|
590
|
-
interactive_session = session.exec(statement).first()
|
|
591
|
-
|
|
592
|
-
if request.parent_state_id is None:
|
|
593
|
-
parent_state = None
|
|
594
|
-
else:
|
|
595
|
-
statement = session.select(StateDB).where(StateDB.id == request.parent_state_id)
|
|
596
|
-
parent_state = session.exec(statement).first()
|
|
597
|
-
|
|
598
|
-
if parent_state is None:
|
|
599
|
-
raise HTTPException(
|
|
600
|
-
status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
|
|
601
|
-
)
|
|
602
|
-
|
|
603
|
-
# fetch the problem from the DB
|
|
604
|
-
statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
|
|
605
|
-
problem_db = session.exec(statement).first()
|
|
606
|
-
|
|
607
|
-
if problem_db is None:
|
|
608
|
-
raise HTTPException(
|
|
609
|
-
status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
|
|
610
|
-
)
|
|
473
|
+
db_session = context.db_session
|
|
474
|
+
user = context.user
|
|
475
|
+
interactive_session = context.interactive_session
|
|
476
|
+
parent_state = context.parent_state
|
|
477
|
+
problem_db = context.problem_db
|
|
611
478
|
|
|
612
479
|
solution_state_id = request.solution_info.state_id
|
|
613
480
|
solution_index = request.solution_info.solution_index
|
|
614
481
|
|
|
615
|
-
|
|
616
|
-
actual_state =
|
|
482
|
+
state = db_session.exec(select(StateDB).where(StateDB.id == solution_state_id)).first()
|
|
483
|
+
actual_state = state.state if state else None
|
|
617
484
|
if actual_state is None:
|
|
618
485
|
raise HTTPException(
|
|
619
486
|
detail="No concrete substate!",
|
|
@@ -623,22 +490,22 @@ def finalize_nimbus(
|
|
|
623
490
|
final_state = NIMBUSFinalState(
|
|
624
491
|
solution_origin_state_id=solution_state_id,
|
|
625
492
|
solution_result_index=solution_index,
|
|
626
|
-
solver_results=actual_state.solver_results[solution_index]
|
|
493
|
+
solver_results=actual_state.solver_results[solution_index],
|
|
627
494
|
)
|
|
628
495
|
|
|
629
496
|
state = StateDB.create(
|
|
630
|
-
database_session=
|
|
497
|
+
database_session=db_session,
|
|
631
498
|
problem_id=problem_db.id,
|
|
632
499
|
session_id=interactive_session.id if interactive_session is not None else None,
|
|
633
500
|
parent_id=parent_state.id if parent_state is not None else None,
|
|
634
501
|
state=final_state,
|
|
635
502
|
)
|
|
636
503
|
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
504
|
+
db_session.add(state)
|
|
505
|
+
db_session.commit()
|
|
506
|
+
db_session.refresh(state)
|
|
640
507
|
|
|
641
|
-
solution_reference_response=SolutionReferenceResponse(
|
|
508
|
+
solution_reference_response = SolutionReferenceResponse(
|
|
642
509
|
solution_index=solution_index,
|
|
643
510
|
state_id=solution_state_id,
|
|
644
511
|
objective_values=final_state.solver_results.optimal_objectives,
|
|
@@ -648,22 +515,21 @@ def finalize_nimbus(
|
|
|
648
515
|
return NIMBUSFinalizeResponse(
|
|
649
516
|
state_id=state.id,
|
|
650
517
|
final_solution=solution_reference_response,
|
|
651
|
-
saved_solutions=collect_saved_solutions(user=user, problem_id=problem_db.id, session=
|
|
652
|
-
all_solutions=collect_all_solutions(user=user, problem_id=problem_db.id, session=
|
|
518
|
+
saved_solutions=collect_saved_solutions(user=user, problem_id=problem_db.id, session=db_session),
|
|
519
|
+
all_solutions=collect_all_solutions(user=user, problem_id=problem_db.id, session=db_session),
|
|
653
520
|
)
|
|
654
521
|
|
|
522
|
+
|
|
655
523
|
@router.post("/delete_save")
|
|
656
524
|
def delete_save(
|
|
657
525
|
request: NIMBUSDeleteSaveRequest,
|
|
658
|
-
|
|
659
|
-
session: Annotated[Session, Depends(get_session)]
|
|
526
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
660
527
|
) -> NIMBUSDeleteSaveResponse:
|
|
661
528
|
"""Endpoint for deleting saved solutions.
|
|
662
529
|
|
|
663
530
|
Args:
|
|
664
531
|
request (NIMBUSDeleteSaveRequest): request containing necessary information for deleting a save
|
|
665
|
-
|
|
666
|
-
session (Annotated[Session, Depends): database session
|
|
532
|
+
context (Annotated[SessionContext, Depends): session context
|
|
667
533
|
|
|
668
534
|
Raises:
|
|
669
535
|
HTTPException
|
|
@@ -671,7 +537,9 @@ def delete_save(
|
|
|
671
537
|
Returns:
|
|
672
538
|
NIMBUSDeleteSaveResponse: Response acknowledging the deletion of save and other useful info.
|
|
673
539
|
"""
|
|
674
|
-
|
|
540
|
+
db_session = context.db_session
|
|
541
|
+
|
|
542
|
+
to_be_deleted = db_session.exec(
|
|
675
543
|
select(UserSavedSolutionDB).where(
|
|
676
544
|
UserSavedSolutionDB.origin_state_id == request.state_id,
|
|
677
545
|
UserSavedSolutionDB.solution_index == request.solution_index,
|
|
@@ -679,15 +547,12 @@ def delete_save(
|
|
|
679
547
|
).first()
|
|
680
548
|
|
|
681
549
|
if to_be_deleted is None:
|
|
682
|
-
raise HTTPException(
|
|
683
|
-
detail="Unable to find a saved solution!",
|
|
684
|
-
status_code=status.HTTP_404_NOT_FOUND
|
|
685
|
-
)
|
|
550
|
+
raise HTTPException(detail="Unable to find a saved solution!", status_code=status.HTTP_404_NOT_FOUND)
|
|
686
551
|
|
|
687
|
-
|
|
688
|
-
|
|
552
|
+
db_session.delete(to_be_deleted)
|
|
553
|
+
db_session.commit()
|
|
689
554
|
|
|
690
|
-
to_be_deleted =
|
|
555
|
+
to_be_deleted = db_session.exec(
|
|
691
556
|
select(UserSavedSolutionDB).where(
|
|
692
557
|
UserSavedSolutionDB.origin_state_id == request.state_id,
|
|
693
558
|
UserSavedSolutionDB.solution_index == request.solution_index,
|
|
@@ -696,10 +561,6 @@ def delete_save(
|
|
|
696
561
|
|
|
697
562
|
if to_be_deleted is not None:
|
|
698
563
|
raise HTTPException(
|
|
699
|
-
detail="Could not delete the saved solution!",
|
|
700
|
-
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
|
564
|
+
detail="Could not delete the saved solution!", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
|
701
565
|
)
|
|
702
|
-
|
|
703
|
-
return NIMBUSDeleteSaveResponse(
|
|
704
|
-
message="Save deleted."
|
|
705
|
-
)
|
|
566
|
+
return NIMBUSDeleteSaveResponse(message="Save deleted.")
|