desdeo 2.0.0__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. desdeo/adm/ADMAfsar.py +551 -0
  2. desdeo/adm/ADMChen.py +414 -0
  3. desdeo/adm/BaseADM.py +119 -0
  4. desdeo/adm/__init__.py +11 -0
  5. desdeo/api/__init__.py +6 -6
  6. desdeo/api/app.py +38 -28
  7. desdeo/api/config.py +65 -44
  8. desdeo/api/config.toml +23 -12
  9. desdeo/api/db.py +10 -8
  10. desdeo/api/db_init.py +12 -6
  11. desdeo/api/models/__init__.py +220 -20
  12. desdeo/api/models/archive.py +16 -27
  13. desdeo/api/models/emo.py +128 -0
  14. desdeo/api/models/enautilus.py +69 -0
  15. desdeo/api/models/gdm/gdm_aggregate.py +139 -0
  16. desdeo/api/models/gdm/gdm_base.py +69 -0
  17. desdeo/api/models/gdm/gdm_score_bands.py +114 -0
  18. desdeo/api/models/gdm/gnimbus.py +138 -0
  19. desdeo/api/models/generic.py +104 -0
  20. desdeo/api/models/generic_states.py +401 -0
  21. desdeo/api/models/nimbus.py +158 -0
  22. desdeo/api/models/preference.py +44 -6
  23. desdeo/api/models/problem.py +274 -64
  24. desdeo/api/models/session.py +4 -1
  25. desdeo/api/models/state.py +419 -52
  26. desdeo/api/models/user.py +7 -6
  27. desdeo/api/models/utopia.py +25 -0
  28. desdeo/api/routers/_EMO.backup +309 -0
  29. desdeo/api/routers/_NIMBUS.py +6 -3
  30. desdeo/api/routers/emo.py +497 -0
  31. desdeo/api/routers/enautilus.py +237 -0
  32. desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
  33. desdeo/api/routers/gdm/gdm_base.py +420 -0
  34. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
  35. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
  36. desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
  37. desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
  38. desdeo/api/routers/generic.py +233 -0
  39. desdeo/api/routers/nimbus.py +705 -0
  40. desdeo/api/routers/problem.py +201 -4
  41. desdeo/api/routers/reference_point_method.py +20 -44
  42. desdeo/api/routers/session.py +50 -26
  43. desdeo/api/routers/user_authentication.py +180 -26
  44. desdeo/api/routers/utils.py +187 -0
  45. desdeo/api/routers/utopia.py +230 -0
  46. desdeo/api/schema.py +10 -4
  47. desdeo/api/tests/conftest.py +94 -2
  48. desdeo/api/tests/test_enautilus.py +330 -0
  49. desdeo/api/tests/test_models.py +550 -72
  50. desdeo/api/tests/test_routes.py +902 -43
  51. desdeo/api/utils/_database.py +263 -0
  52. desdeo/api/utils/database.py +28 -266
  53. desdeo/api/utils/emo_database.py +40 -0
  54. desdeo/core.py +7 -0
  55. desdeo/emo/__init__.py +154 -24
  56. desdeo/emo/hooks/archivers.py +18 -2
  57. desdeo/emo/methods/EAs.py +128 -5
  58. desdeo/emo/methods/bases.py +9 -56
  59. desdeo/emo/methods/templates.py +111 -0
  60. desdeo/emo/operators/crossover.py +544 -42
  61. desdeo/emo/operators/evaluator.py +10 -14
  62. desdeo/emo/operators/generator.py +127 -24
  63. desdeo/emo/operators/mutation.py +212 -41
  64. desdeo/emo/operators/scalar_selection.py +202 -0
  65. desdeo/emo/operators/selection.py +956 -214
  66. desdeo/emo/operators/termination.py +124 -16
  67. desdeo/emo/options/__init__.py +108 -0
  68. desdeo/emo/options/algorithms.py +435 -0
  69. desdeo/emo/options/crossover.py +164 -0
  70. desdeo/emo/options/generator.py +131 -0
  71. desdeo/emo/options/mutation.py +260 -0
  72. desdeo/emo/options/repair.py +61 -0
  73. desdeo/emo/options/scalar_selection.py +66 -0
  74. desdeo/emo/options/selection.py +127 -0
  75. desdeo/emo/options/templates.py +383 -0
  76. desdeo/emo/options/termination.py +143 -0
  77. desdeo/gdm/__init__.py +22 -0
  78. desdeo/gdm/gdmtools.py +45 -0
  79. desdeo/gdm/score_bands.py +114 -0
  80. desdeo/gdm/voting_rules.py +50 -0
  81. desdeo/mcdm/__init__.py +23 -1
  82. desdeo/mcdm/enautilus.py +338 -0
  83. desdeo/mcdm/gnimbus.py +484 -0
  84. desdeo/mcdm/nautilus_navigator.py +7 -6
  85. desdeo/mcdm/reference_point_method.py +70 -0
  86. desdeo/problem/__init__.py +16 -11
  87. desdeo/problem/evaluator.py +4 -5
  88. desdeo/problem/external/__init__.py +18 -0
  89. desdeo/problem/external/core.py +356 -0
  90. desdeo/problem/external/pymoo_provider.py +266 -0
  91. desdeo/problem/external/runtime.py +44 -0
  92. desdeo/problem/gurobipy_evaluator.py +37 -12
  93. desdeo/problem/infix_parser.py +1 -16
  94. desdeo/problem/json_parser.py +7 -11
  95. desdeo/problem/pyomo_evaluator.py +25 -6
  96. desdeo/problem/schema.py +73 -55
  97. desdeo/problem/simulator_evaluator.py +65 -15
  98. desdeo/problem/testproblems/__init__.py +26 -11
  99. desdeo/problem/testproblems/benchmarks_server.py +120 -0
  100. desdeo/problem/testproblems/cake_problem.py +185 -0
  101. desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
  102. desdeo/problem/testproblems/forest_problem.py +77 -69
  103. desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
  104. desdeo/problem/testproblems/{river_pollution_problem.py → river_pollution_problems.py} +28 -22
  105. desdeo/problem/testproblems/single_objective.py +289 -0
  106. desdeo/problem/testproblems/zdt_problem.py +4 -1
  107. desdeo/problem/utils.py +1 -1
  108. desdeo/tools/__init__.py +39 -21
  109. desdeo/tools/desc_gen.py +22 -0
  110. desdeo/tools/generics.py +22 -2
  111. desdeo/tools/group_scalarization.py +3090 -0
  112. desdeo/tools/indicators_binary.py +107 -1
  113. desdeo/tools/indicators_unary.py +3 -16
  114. desdeo/tools/message.py +33 -2
  115. desdeo/tools/non_dominated_sorting.py +4 -3
  116. desdeo/tools/patterns.py +9 -7
  117. desdeo/tools/pyomo_solver_interfaces.py +49 -36
  118. desdeo/tools/reference_vectors.py +118 -351
  119. desdeo/tools/scalarization.py +340 -1413
  120. desdeo/tools/score_bands.py +491 -328
  121. desdeo/tools/utils.py +117 -49
  122. desdeo/tools/visualizations.py +67 -0
  123. desdeo/utopia_stuff/utopia_problem.py +1 -1
  124. desdeo/utopia_stuff/utopia_problem_old.py +1 -1
  125. {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info}/METADATA +47 -30
  126. desdeo-2.1.1.dist-info/RECORD +180 -0
  127. {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info}/WHEEL +1 -1
  128. desdeo-2.0.0.dist-info/RECORD +0 -120
  129. /desdeo/api/utils/{logger.py → _logger.py} +0 -0
  130. {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info/licenses}/LICENSE +0 -0
@@ -0,0 +1,705 @@
1
+ """Defines end-points to access functionalities related to the NIMBUS method."""
2
+
3
+ from typing import Annotated
4
+
5
+ from fastapi import APIRouter, Depends, HTTPException, status
6
+ from numpy import allclose
7
+ from sqlmodel import Session, select
8
+
9
+ from desdeo.api.db import get_session
10
+ from desdeo.api.models import (
11
+ InteractiveSessionDB,
12
+ IntermediateSolutionRequest,
13
+ NIMBUSClassificationRequest,
14
+ NIMBUSClassificationResponse,
15
+ NIMBUSClassificationState,
16
+ NIMBUSDeleteSaveRequest,
17
+ NIMBUSDeleteSaveResponse,
18
+ NIMBUSFinalizeRequest,
19
+ NIMBUSFinalizeResponse,
20
+ NIMBUSFinalState,
21
+ NIMBUSInitializationRequest,
22
+ NIMBUSInitializationResponse,
23
+ NIMBUSInitializationState,
24
+ NIMBUSIntermediateSolutionResponse,
25
+ NIMBUSSaveRequest,
26
+ NIMBUSSaveResponse,
27
+ NIMBUSSaveState,
28
+ ProblemDB,
29
+ ReferencePoint,
30
+ SavedSolutionReference,
31
+ SolutionReference,
32
+ SolutionReferenceResponse,
33
+ StateDB,
34
+ User,
35
+ UserSavedSolutionDB,
36
+ )
37
+ from desdeo.api.models.generic import SolutionInfo
38
+ from desdeo.api.models.state import IntermediateSolutionState
39
+ from desdeo.api.routers.generic import solve_intermediate
40
+ from desdeo.api.routers.problem import check_solver
41
+ from desdeo.api.routers.user_authentication import get_current_user
42
+ from desdeo.mcdm.nimbus import generate_starting_point, solve_sub_problems
43
+ from desdeo.problem import Problem
44
+ from desdeo.tools import SolverResults
45
+
46
+ router = APIRouter(prefix="/method/nimbus")
47
+
48
+
49
# helper for collecting solutions
def filter_duplicates(solutions: list[SavedSolutionReference]) -> list[SavedSolutionReference]:
    """Filters out the duplicate values of objectives.

    Two solutions are duplicates when all of their objective values are (nearly) equal
    according to ``numpy.allclose`` with its default tolerances. A solution is dropped
    when any *later* solution in ``solutions`` duplicates it, so of each group of
    duplicates only the last occurrence survives. The relative order of the kept
    solutions is preserved.

    Args:
        solutions: the solutions to filter. Each must expose an ``objective_values``
            mapping from objective symbol to value. The symbols of the first solution
            are used to order the values of every solution.

    Returns:
        A list without the duplicates. When ``solutions`` has fewer than two elements,
        the input list itself is returned.
    """
    # No solutions or only one solution. There can not be any duplicates.
    if len(solutions) < 2:
        return solutions

    # Read each solution's objective values once.
    objective_values_list = [sol.objective_values for sol in solutions]
    # Order every solution's values by the first solution's symbols so they compare element-wise.
    objective_keys = list(objective_values_list[0])
    value_lists = [[values[key] for key in objective_keys] for values in objective_values_list]

    # An index is a duplicate if any later solution has (nearly) identical objective values.
    # A set gives O(1) membership tests below; any() stops at the first match.
    duplicate_indices = {
        i
        for i, values in enumerate(value_lists)
        # TODO: "similarity tolerance" from problem metadata
        if any(allclose(values, later) for later in value_lists[i + 1 :])
    }

    return [sol for i, sol in enumerate(solutions) if i not in duplicate_indices]
77
+
78
+
79
# for collecting solutions for responses in iterate and initialize endpoints
def collect_saved_solutions(user: User, problem_id: int, session: Session) -> list[SavedSolutionReference]:
    """Return the saved solutions of ``user`` for the given problem, without duplicates.

    Args:
        user: the user whose saved solutions are collected.
        problem_id: id of the problem the saved solutions belong to.
        session: the database session used for the query.

    Returns:
        ``SavedSolutionReference`` objects wrapping the matching ``UserSavedSolutionDB``
        rows, passed through ``filter_duplicates``.
    """
    query = select(UserSavedSolutionDB).where(
        UserSavedSolutionDB.problem_id == problem_id,
        UserSavedSolutionDB.user_id == user.id,
    )
    rows = session.exec(query).all()

    references = [SavedSolutionReference(saved_solution=row) for row in rows]
    return filter_duplicates(references)
91
+
92
+
93
# for collecting solutions for responses in iterate and initialize endpoints
def collect_all_solutions(user: User, problem_id: int, session: Session) -> list[SolutionReference]:
    """Return references to every solution in the user's active session for a problem.

    States are visited newest first (descending id). Each state contributes one
    ``SolutionReference`` per solution it reports via ``state.state.num_solutions``. The
    result is passed through ``filter_duplicates``.

    Args:
        user: the user whose active session is searched.
        problem_id: id of the problem whose states are collected.
        session: the database session used for the query.

    Returns:
        The de-duplicated list of solution references.
    """
    newest_first = (
        select(StateDB)
        .where(StateDB.problem_id == problem_id, StateDB.session_id == user.active_session_id)
        .order_by(StateDB.id.desc())
    )
    state_rows = session.exec(newest_first).all()

    references = [
        SolutionReference(state=state_row, solution_index=index)
        for state_row in state_rows
        for index in range(state_row.state.num_solutions)
    ]
    return filter_duplicates(references)
108
+
109
+
110
@router.post("/solve")
def solve_solutions(
    request: NIMBUSClassificationRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> NIMBUSClassificationResponse:
    """Solve the problem using the NIMBUS method.

    Solves the NIMBUS sub-problems for the classification in ``request``. The results
    are stored as a new ``NIMBUSClassificationState``. The response holds references
    to the new solutions, the user's saved solutions, and all solutions of the active
    session.

    Args:
        request (NIMBUSClassificationRequest): the problem, preference, current objectives,
            number of desired solutions and solver/scalarization options.
        user (User): the current user.
        session (Session): the database session.

    Raises:
        HTTPException: 404 if the requested interactive session, the problem, or the
            requested parent state cannot be found.

    Returns:
        NIMBUSClassificationResponse: the id of the new state and the collected solutions.
    """
    if request.session_id is not None:
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
        # `.first()` yields the row or None; the raw result object itself is never None.
        interactive_session = session.exec(statement).first()

        if interactive_session is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find interactive session with id={request.session_id}.",
            )
    else:
        # request.session_id is None: use the active session instead (may be None)
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
        interactive_session = session.exec(statement).first()

    # fetch the problem from the DB (only problems owned by the user)
    statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
    problem_db = session.exec(statement).first()

    if problem_db is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
        )

    solver = check_solver(problem_db=problem_db)

    problem = Problem.from_problemdb(problem_db)

    # fetch parent state
    if request.parent_state_id is None:
        # parent state is assumed to be the last state added to the session.
        parent_state = (
            interactive_session.states[-1]
            if (interactive_session is not None and len(interactive_session.states) > 0)
            else None
        )

    else:
        # request.parent_state_id is not None
        statement = select(StateDB).where(StateDB.id == request.parent_state_id)
        parent_state = session.exec(statement).first()

        if parent_state is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
            )

    solver_results: list[SolverResults] = solve_sub_problems(
        problem=problem,
        current_objectives=request.current_objectives,
        reference_point=request.preference.aspiration_levels,
        num_desired=request.num_desired,
        scalarization_options=request.scalarization_options,
        solver=solver,
        solver_options=request.solver_options,
    )

    nimbus_state = NIMBUSClassificationState(
        preferences=request.preference,
        scalarization_options=request.scalarization_options,
        solver=request.solver,
        solver_options=request.solver_options,
        solver_results=solver_results,
        current_objectives=request.current_objectives,
        num_desired=request.num_desired,
        # NOTE(review): same value as `preferences`; confirm this is intended.
        previous_preferences=request.preference,
    )

    # create DB state and add it to the DB
    state = StateDB.create(
        database_session=session,
        problem_id=problem_db.id,
        session_id=interactive_session.id if interactive_session is not None else None,
        parent_id=parent_state.id if parent_state is not None else None,
        state=nimbus_state,
    )

    session.add(state)
    session.commit()
    session.refresh(state)

    # Collect all current solutions
    current_solutions: list[SolutionReference] = [
        SolutionReference(state=state, solution_index=i) for i, _ in enumerate(solver_results)
    ]

    saved_solutions = collect_saved_solutions(user, request.problem_id, session)
    all_solutions = collect_all_solutions(user, request.problem_id, session)

    return NIMBUSClassificationResponse(
        state_id=state.id,
        previous_preference=request.preference,
        previous_objectives=request.current_objectives,
        current_solutions=current_solutions,
        saved_solutions=saved_solutions,
        all_solutions=all_solutions,
    )
215
+
216
+
217
@router.post("/initialize")
def initialize(
    request: NIMBUSInitializationRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> NIMBUSInitializationResponse:
    """Initialize the problem for the NIMBUS method.

    Computes a starting solution with ``generate_starting_point`` and stores it as a new
    ``NIMBUSInitializationState``. The reference point passed to it depends on
    ``request.starting_point``:

    - a ``ReferencePoint``: its aspiration levels;
    - a ``SolutionInfo``: the objective values of that earlier solution;
    - anything else: ``None``.

    Args:
        request (NIMBUSInitializationRequest): problem id, optional starting point, parent
            state and solver/scalarization options.
        user (User): the current user.
        session (Session): the database session.

    Raises:
        HTTPException: 404 if the requested interactive session, the problem, the state
            referenced by a ``SolutionInfo`` starting point, or the requested parent state
            cannot be found.

    Returns:
        NIMBUSInitializationResponse: the id of the new state and the collected solutions.
    """
    if request.session_id is not None:
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
        # `.first()` yields the row or None; the raw result object itself is never None.
        interactive_session = session.exec(statement).first()

        if interactive_session is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find interactive session with id={request.session_id}.",
            )
    else:
        # request.session_id is None: use the active session instead (may be None)
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
        interactive_session = session.exec(statement).first()

    # fetch the problem from the DB (only problems owned by the user)
    statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
    problem_db = session.exec(statement).first()

    if problem_db is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
        )

    solver = check_solver(problem_db=problem_db)

    problem = Problem.from_problemdb(problem_db)

    if isinstance(ref_point := request.starting_point, ReferencePoint):
        # ReferencePoint
        starting_point = ref_point.aspiration_levels

    elif isinstance(info := request.starting_point, SolutionInfo):
        # SolutionInfo: use the objective values of the referenced earlier solution
        statement = select(StateDB).where(StateDB.id == info.state_id)
        state = session.exec(statement).first()

        if state is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"StateDB with index {info.state_id} could not be found."
            )

        starting_point = state.state.result_objective_values[info.solution_index]

    else:
        # if no starting point is provided, generate_starting_point runs without one
        starting_point = None

    start_result = generate_starting_point(
        problem=problem,
        reference_point=starting_point,
        scalarization_options=request.scalarization_options,
        solver=solver,
        solver_options=request.solver_options,
    )

    # fetch parent state if it is given
    if request.parent_state_id is None:
        parent_state = None
    else:
        # `select` is the sqlmodel function imported at module level; Session has no `.select`.
        statement = select(StateDB).where(StateDB.id == request.parent_state_id)
        parent_state = session.exec(statement).first()

        if parent_state is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
            )

    initialization_state = NIMBUSInitializationState(
        reference_point=starting_point,
        scalarization_options=request.scalarization_options,
        solver=request.solver,
        solver_results=start_result,
    )

    # create DB state and add it to the DB
    state = StateDB.create(
        database_session=session,
        problem_id=problem_db.id,
        session_id=interactive_session.id if interactive_session is not None else None,
        parent_id=parent_state.id if parent_state is not None else None,
        state=initialization_state,
    )

    session.add(state)
    session.commit()
    session.refresh(state)

    # The initialization state holds a single result, hence the single reference.
    current_solutions = [SolutionReference(state=state, solution_index=0)]
    saved_solutions = collect_saved_solutions(user, request.problem_id, session)
    all_solutions = collect_all_solutions(user, request.problem_id, session)

    return NIMBUSInitializationResponse(
        state_id=state.id,
        current_solutions=current_solutions,
        saved_solutions=saved_solutions,
        all_solutions=all_solutions,
    )
326
+
327
+
328
@router.post("/save")
def save(
    request: NIMBUSSaveRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> NIMBUSSaveResponse:
    """Save solutions.

    Each entry of ``request.solution_info`` is handled as follows:

    - if the current user already saved the solution with the same origin state and
      solution index, that save is renamed;
    - otherwise a new ``UserSavedSolutionDB`` is created.

    A ``NIMBUSSaveState`` listing the updated and new saves is then stored.

    Args:
        request (NIMBUSSaveRequest): problem id, optional session/parent state, and the
            solutions (state id, solution index, name) to save.
        user (User): the current user.
        session (Session): the database session.

    Raises:
        HTTPException: 404 if the requested interactive session or parent state cannot be found.

    Returns:
        NIMBUSSaveResponse: the id of the new save state.
    """
    if request.session_id is not None:
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
        # `.first()` yields the row or None; the raw result object itself is never None.
        interactive_session = session.exec(statement).first()

        if interactive_session is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find interactive session with id={request.session_id}.",
            )
    else:
        # request.session_id is None: use the active session instead (may be None)
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
        interactive_session = session.exec(statement).first()

    # fetch parent state
    if request.parent_state_id is None:
        # parent state is assumed to be the last state added to the session.
        parent_state = (
            interactive_session.states[-1]
            if (interactive_session is not None and len(interactive_session.states) > 0)
            else None
        )

    else:
        # request.parent_state_id is not None
        statement = select(StateDB).where(StateDB.id == request.parent_state_id)
        parent_state = session.exec(statement).first()

        if parent_state is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
            )

    # Check for duplicate solutions and update names instead of saving duplicates
    updated_solutions: list[UserSavedSolutionDB] = []
    new_solutions: list[UserSavedSolutionDB] = []

    for info in request.solution_info:
        existing_solution = session.exec(
            select(UserSavedSolutionDB).where(
                # Scope to the current user so another user's save is never renamed.
                UserSavedSolutionDB.user_id == user.id,
                UserSavedSolutionDB.origin_state_id == info.state_id,
                UserSavedSolutionDB.solution_index == info.solution_index,
            )
        ).first()

        if existing_solution is not None:
            # Update the name of the existing solution
            existing_solution.name = info.name
            session.add(existing_solution)
            updated_solutions.append(existing_solution)
        else:
            # This is a new solution
            new_solution = UserSavedSolutionDB.from_state_info(
                session, user.id, request.problem_id, info.state_id, info.solution_index, info.name
            )
            session.add(new_solution)
            new_solutions.append(new_solution)

    touched_solutions = updated_solutions + new_solutions

    # Commit existing and new solutions; test the lists (not the loop variable, which is
    # unbound when no new solution was created).
    if touched_solutions:
        session.commit()
        for row in touched_solutions:
            session.refresh(row)

    # save solver results for state in SolverResults format just for consistency (dont save name field to state)
    save_state = NIMBUSSaveState(solutions=touched_solutions)

    # create DB state
    state = StateDB.create(
        database_session=session,
        problem_id=request.problem_id,
        session_id=interactive_session.id if interactive_session is not None else None,
        parent_id=parent_state.id if parent_state is not None else None,
        state=save_state,
    )

    session.add(state)
    session.commit()
    session.refresh(state)

    return NIMBUSSaveResponse(state_id=state.id)
421
+
422
+
423
@router.post("/intermediate")
def solve_nimbus_intermediate(
    request: IntermediateSolutionRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> NIMBUSIntermediateSolutionResponse:
    """Solve intermediate solutions by forwarding the request to generic intermediate endpoint with context nimbus.

    Args:
        request (IntermediateSolutionRequest): the intermediate-solution request. Its
            ``context`` is overwritten with ``"nimbus"`` before forwarding.
        user (User): the current user.
        session (Session): the database session.

    Returns:
        NIMBUSIntermediateSolutionResponse: the generic endpoint's result, extended with
        the user's saved solutions and all solutions of the active session.
    """
    # Tag the request for NIMBUS. get_or_initialize later looks for intermediate
    # states whose context is "nimbus" (presumably set from this field — verify in generic).
    request.context = "nimbus"
    generic_response = solve_intermediate(request, user, session)

    problem_id = request.problem_id
    saved = collect_saved_solutions(user, problem_id, session)
    # Collected after solving, so the newly generated intermediate solutions are included.
    everything = collect_all_solutions(user, problem_id, session)

    first_reference = generic_response.reference_solution_1
    second_reference = generic_response.reference_solution_2
    return NIMBUSIntermediateSolutionResponse(
        state_id=generic_response.state_id,
        reference_solution_1=first_reference.objective_values,
        reference_solution_2=second_reference.objective_values,
        current_solutions=generic_response.intermediate_solutions,
        saved_solutions=saved,
        all_solutions=everything,
    )
449
+
450
+
451
@router.post("/get-or-initialize")
def get_or_initialize(
    request: NIMBUSInitializationRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> (
    NIMBUSInitializationResponse | NIMBUSClassificationResponse | NIMBUSIntermediateSolutionResponse | NIMBUSFinalizeResponse
):
    """Get the latest NIMBUS state if it exists, or initialize a new one if it doesn't.

    Looks through the session's states for the problem, newest first. The first state
    that is a NIMBUS classification, initialization or final state, or an intermediate
    state with context ``"nimbus"``, is returned as the matching response type. If no
    such state exists, the request is handed to ``initialize``.

    Args:
        request (NIMBUSInitializationRequest): problem id, optional session id, and the
            initialization options used when a new state must be created.
        user (User): the current user.
        session (Session): the database session.

    Raises:
        HTTPException: 404 if the requested interactive session cannot be found. Any error
            raised by ``initialize`` is also propagated.

    Returns:
        A response matching the type of the latest relevant state, or the response of
        ``initialize``.
    """
    if request.session_id is not None:
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
        # `.first()` yields the row or None; the raw result object itself is never None.
        interactive_session = session.exec(statement).first()

        if interactive_session is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find interactive session with id={request.session_id}.",
            )
    else:
        # use active session instead
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
        interactive_session = session.exec(statement).first()

    # Look for latest relevant state in the session
    statement = (
        select(StateDB)
        .where(
            StateDB.problem_id == request.problem_id,
            StateDB.session_id == (interactive_session.id if interactive_session else user.active_session_id),
        )
        .order_by(StateDB.id.desc())
    )
    states = session.exec(statement).all()

    # Find the latest relevant state (NIMBUS classification, initialization, final, or intermediate with NIMBUS context)
    latest_state = None
    for state in states:
        if isinstance(state.state, (NIMBUSClassificationState, NIMBUSInitializationState, NIMBUSFinalState)) or (
            isinstance(state.state, IntermediateSolutionState) and state.state.context == "nimbus"
        ):
            latest_state = state
            break

    if latest_state is None:
        # No relevant state found, initialize a new one
        return initialize(request, user, session)

    saved_solutions = collect_saved_solutions(user, request.problem_id, session)
    all_solutions = collect_all_solutions(user, request.problem_id, session)

    # Handle both single result and list of results cases
    solver_results = latest_state.state.solver_results
    if isinstance(solver_results, list):
        current_solutions = [SolutionReference(state=latest_state, solution_index=i) for i in range(len(solver_results))]
    else:
        # Single result case (e.g. NIMBUSInitializationState)
        current_solutions = [SolutionReference(state=latest_state, solution_index=0)]

    if isinstance(latest_state.state, NIMBUSClassificationState):
        return NIMBUSClassificationResponse(
            state_id=latest_state.id,
            previous_preference=latest_state.state.preferences,
            previous_objectives=latest_state.state.current_objectives,
            current_solutions=current_solutions,
            saved_solutions=saved_solutions,
            all_solutions=all_solutions,
        )

    if isinstance(latest_state.state, IntermediateSolutionState):
        return NIMBUSIntermediateSolutionResponse(
            state_id=latest_state.id,
            reference_solution_1=latest_state.state.reference_solution_1,
            reference_solution_2=latest_state.state.reference_solution_2,
            current_solutions=current_solutions,
            saved_solutions=saved_solutions,
            all_solutions=all_solutions,
        )

    if isinstance(latest_state.state, NIMBUSFinalState):
        # The final state points back at the state/index the final solution came from.
        final_solution_ref_res = SolutionReferenceResponse(
            solution_index=latest_state.state.solution_result_index,
            state_id=latest_state.state.solution_origin_state_id,
            objective_values=latest_state.state.solver_results.optimal_objectives,
            variable_values=latest_state.state.solver_results.optimal_variables,
        )

        return NIMBUSFinalizeResponse(
            state_id=latest_state.id,
            final_solution=final_solution_ref_res,
            saved_solutions=saved_solutions,
            all_solutions=all_solutions,
        )

    # NIMBUSInitializationState
    return NIMBUSInitializationResponse(
        state_id=latest_state.id,
        current_solutions=current_solutions,
        saved_solutions=saved_solutions,
        all_solutions=all_solutions,
    )
555
+
556
+
557
@router.post("/finalize")
def finalize_nimbus(
    request: NIMBUSFinalizeRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> NIMBUSFinalizeResponse:
    """An endpoint for finishing up the nimbus process.

    Stores the solution identified by ``request.solution_info`` (origin state id and
    solution index) as a new ``NIMBUSFinalState``.

    Args:
        request (NIMBUSFinalizeRequest): The request containing the final solution, etc.
        user (User): The current user.
        session (Session): The database session.

    Raises:
        HTTPException: 404 if any of these cannot be found:
            - the interactive session;
            - the parent state;
            - the problem;
            - the state holding the chosen solution;
            - the chosen solution within that state.

    Returns:
        NIMBUSFinalizeResponse: Response containing info on the final solution.
    """
    if request.session_id is not None:
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
        # `.first()` yields the row or None; the raw result object itself is never None.
        interactive_session = session.exec(statement).first()

        if interactive_session is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find interactive session with id={request.session_id}.",
            )
    else:
        # request.session_id is None: use the active session instead (may be None)
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
        interactive_session = session.exec(statement).first()

    if request.parent_state_id is None:
        parent_state = None
    else:
        # `select` is the sqlmodel function imported at module level; Session has no `.select`.
        statement = select(StateDB).where(StateDB.id == request.parent_state_id)
        parent_state = session.exec(statement).first()

        if parent_state is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
            )

    # fetch the problem from the DB (only problems owned by the user)
    statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
    problem_db = session.exec(statement).first()

    if problem_db is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
        )

    solution_state_id = request.solution_info.state_id
    solution_index = request.solution_info.solution_index

    # A missing row must become a 404, not an AttributeError on `None.state`.
    statement = select(StateDB).where(StateDB.id == solution_state_id)
    solution_state_db = session.exec(statement).first()
    actual_state = solution_state_db.state if solution_state_db is not None else None
    if actual_state is None:
        raise HTTPException(
            detail="No concrete substate!",
            status_code=status.HTTP_404_NOT_FOUND,
        )

    # Some states hold a list of results; others (e.g. NIMBUSInitializationState) hold a
    # single SolverResults object, which cannot be indexed.
    available_results = getattr(actual_state, "solver_results", None)
    if isinstance(available_results, list):
        chosen_result = available_results[solution_index] if 0 <= solution_index < len(available_results) else None
    elif solution_index == 0:
        chosen_result = available_results
    else:
        chosen_result = None

    if chosen_result is None:
        raise HTTPException(
            detail=f"State with id={solution_state_id} has no solution with index {solution_index}.",
            status_code=status.HTTP_404_NOT_FOUND,
        )

    final_state = NIMBUSFinalState(
        solution_origin_state_id=solution_state_id,
        solution_result_index=solution_index,
        solver_results=chosen_result,
    )

    state = StateDB.create(
        database_session=session,
        problem_id=problem_db.id,
        session_id=interactive_session.id if interactive_session is not None else None,
        parent_id=parent_state.id if parent_state is not None else None,
        state=final_state,
    )

    session.add(state)
    session.commit()
    session.refresh(state)

    solution_reference_response = SolutionReferenceResponse(
        solution_index=solution_index,
        state_id=solution_state_id,
        objective_values=final_state.solver_results.optimal_objectives,
        variable_values=final_state.solver_results.optimal_variables,
    )

    return NIMBUSFinalizeResponse(
        state_id=state.id,
        final_solution=solution_reference_response,
        saved_solutions=collect_saved_solutions(user=user, problem_id=problem_db.id, session=session),
        all_solutions=collect_all_solutions(user=user, problem_id=problem_db.id, session=session),
    )
654
+
655
@router.post("/delete_save")
def delete_save(
    request: NIMBUSDeleteSaveRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> NIMBUSDeleteSaveResponse:
    """Endpoint for deleting saved solutions.

    Only the current user's own save (matched by origin state id and solution index) can
    be deleted.

    Args:
        request (NIMBUSDeleteSaveRequest): request containing necessary information for deleting a save
        user (User): the current (logged in) user
        session (Session): database session

    Raises:
        HTTPException: 404 if the user has no matching saved solution. 500 if the saved
            solution still exists after the delete was committed.

    Returns:
        NIMBUSDeleteSaveResponse: Response acknowledging the deletion of save and other useful info.
    """
    # Scoping by user prevents deleting another user's saved solution.
    saved_solution_query = select(UserSavedSolutionDB).where(
        UserSavedSolutionDB.user_id == user.id,
        UserSavedSolutionDB.origin_state_id == request.state_id,
        UserSavedSolutionDB.solution_index == request.solution_index,
    )

    to_be_deleted = session.exec(saved_solution_query).first()

    if to_be_deleted is None:
        raise HTTPException(detail="Unable to find a saved solution!", status_code=status.HTTP_404_NOT_FOUND)

    session.delete(to_be_deleted)
    session.commit()

    # Verify that the row is really gone.
    if session.exec(saved_solution_query).first() is not None:
        raise HTTPException(
            detail="Could not delete the saved solution!", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )

    return NIMBUSDeleteSaveResponse(message="Save deleted.")