desdeo 1.2__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. desdeo/__init__.py +8 -8
  2. desdeo/adm/ADMAfsar.py +551 -0
  3. desdeo/adm/ADMChen.py +414 -0
  4. desdeo/adm/BaseADM.py +119 -0
  5. desdeo/adm/__init__.py +11 -0
  6. desdeo/api/README.md +73 -0
  7. desdeo/api/__init__.py +15 -0
  8. desdeo/api/app.py +50 -0
  9. desdeo/api/config.py +90 -0
  10. desdeo/api/config.toml +64 -0
  11. desdeo/api/db.py +27 -0
  12. desdeo/api/db_init.py +85 -0
  13. desdeo/api/db_models.py +164 -0
  14. desdeo/api/malaga_db_init.py +27 -0
  15. desdeo/api/models/__init__.py +266 -0
  16. desdeo/api/models/archive.py +23 -0
  17. desdeo/api/models/emo.py +128 -0
  18. desdeo/api/models/enautilus.py +69 -0
  19. desdeo/api/models/gdm/gdm_aggregate.py +139 -0
  20. desdeo/api/models/gdm/gdm_base.py +69 -0
  21. desdeo/api/models/gdm/gdm_score_bands.py +114 -0
  22. desdeo/api/models/gdm/gnimbus.py +138 -0
  23. desdeo/api/models/generic.py +104 -0
  24. desdeo/api/models/generic_states.py +401 -0
  25. desdeo/api/models/nimbus.py +158 -0
  26. desdeo/api/models/preference.py +128 -0
  27. desdeo/api/models/problem.py +717 -0
  28. desdeo/api/models/reference_point_method.py +18 -0
  29. desdeo/api/models/session.py +49 -0
  30. desdeo/api/models/state.py +463 -0
  31. desdeo/api/models/user.py +52 -0
  32. desdeo/api/models/utopia.py +25 -0
  33. desdeo/api/routers/_EMO.backup +309 -0
  34. desdeo/api/routers/_NAUTILUS.py +245 -0
  35. desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
  36. desdeo/api/routers/_NIMBUS.py +765 -0
  37. desdeo/api/routers/__init__.py +5 -0
  38. desdeo/api/routers/emo.py +497 -0
  39. desdeo/api/routers/enautilus.py +237 -0
  40. desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
  41. desdeo/api/routers/gdm/gdm_base.py +420 -0
  42. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
  43. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
  44. desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
  45. desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
  46. desdeo/api/routers/generic.py +233 -0
  47. desdeo/api/routers/nimbus.py +705 -0
  48. desdeo/api/routers/problem.py +307 -0
  49. desdeo/api/routers/reference_point_method.py +93 -0
  50. desdeo/api/routers/session.py +100 -0
  51. desdeo/api/routers/test.py +16 -0
  52. desdeo/api/routers/user_authentication.py +520 -0
  53. desdeo/api/routers/utils.py +187 -0
  54. desdeo/api/routers/utopia.py +230 -0
  55. desdeo/api/schema.py +100 -0
  56. desdeo/api/tests/__init__.py +0 -0
  57. desdeo/api/tests/conftest.py +151 -0
  58. desdeo/api/tests/test_enautilus.py +330 -0
  59. desdeo/api/tests/test_models.py +1179 -0
  60. desdeo/api/tests/test_routes.py +1075 -0
  61. desdeo/api/utils/_database.py +263 -0
  62. desdeo/api/utils/_logger.py +29 -0
  63. desdeo/api/utils/database.py +36 -0
  64. desdeo/api/utils/emo_database.py +40 -0
  65. desdeo/core.py +34 -0
  66. desdeo/emo/__init__.py +159 -0
  67. desdeo/emo/hooks/archivers.py +188 -0
  68. desdeo/emo/methods/EAs.py +541 -0
  69. desdeo/emo/methods/__init__.py +0 -0
  70. desdeo/emo/methods/bases.py +12 -0
  71. desdeo/emo/methods/templates.py +111 -0
  72. desdeo/emo/operators/__init__.py +1 -0
  73. desdeo/emo/operators/crossover.py +1282 -0
  74. desdeo/emo/operators/evaluator.py +114 -0
  75. desdeo/emo/operators/generator.py +459 -0
  76. desdeo/emo/operators/mutation.py +1224 -0
  77. desdeo/emo/operators/scalar_selection.py +202 -0
  78. desdeo/emo/operators/selection.py +1778 -0
  79. desdeo/emo/operators/termination.py +286 -0
  80. desdeo/emo/options/__init__.py +108 -0
  81. desdeo/emo/options/algorithms.py +435 -0
  82. desdeo/emo/options/crossover.py +164 -0
  83. desdeo/emo/options/generator.py +131 -0
  84. desdeo/emo/options/mutation.py +260 -0
  85. desdeo/emo/options/repair.py +61 -0
  86. desdeo/emo/options/scalar_selection.py +66 -0
  87. desdeo/emo/options/selection.py +127 -0
  88. desdeo/emo/options/templates.py +383 -0
  89. desdeo/emo/options/termination.py +143 -0
  90. desdeo/explanations/__init__.py +6 -0
  91. desdeo/explanations/explainer.py +100 -0
  92. desdeo/explanations/utils.py +90 -0
  93. desdeo/gdm/__init__.py +22 -0
  94. desdeo/gdm/gdmtools.py +45 -0
  95. desdeo/gdm/score_bands.py +114 -0
  96. desdeo/gdm/voting_rules.py +50 -0
  97. desdeo/mcdm/__init__.py +41 -0
  98. desdeo/mcdm/enautilus.py +338 -0
  99. desdeo/mcdm/gnimbus.py +484 -0
  100. desdeo/mcdm/nautili.py +345 -0
  101. desdeo/mcdm/nautilus.py +477 -0
  102. desdeo/mcdm/nautilus_navigator.py +656 -0
  103. desdeo/mcdm/nimbus.py +417 -0
  104. desdeo/mcdm/pareto_navigator.py +269 -0
  105. desdeo/mcdm/reference_point_method.py +186 -0
  106. desdeo/problem/__init__.py +83 -0
  107. desdeo/problem/evaluator.py +561 -0
  108. desdeo/problem/external/__init__.py +18 -0
  109. desdeo/problem/external/core.py +356 -0
  110. desdeo/problem/external/pymoo_provider.py +266 -0
  111. desdeo/problem/external/runtime.py +44 -0
  112. desdeo/problem/gurobipy_evaluator.py +562 -0
  113. desdeo/problem/infix_parser.py +341 -0
  114. desdeo/problem/json_parser.py +944 -0
  115. desdeo/problem/pyomo_evaluator.py +487 -0
  116. desdeo/problem/schema.py +1829 -0
  117. desdeo/problem/simulator_evaluator.py +348 -0
  118. desdeo/problem/sympy_evaluator.py +244 -0
  119. desdeo/problem/testproblems/__init__.py +88 -0
  120. desdeo/problem/testproblems/benchmarks_server.py +120 -0
  121. desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
  122. desdeo/problem/testproblems/cake_problem.py +185 -0
  123. desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
  124. desdeo/problem/testproblems/dtlz2_problem.py +102 -0
  125. desdeo/problem/testproblems/forest_problem.py +283 -0
  126. desdeo/problem/testproblems/knapsack_problem.py +163 -0
  127. desdeo/problem/testproblems/mcwb_problem.py +831 -0
  128. desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
  129. desdeo/problem/testproblems/momip_problem.py +172 -0
  130. desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
  131. desdeo/problem/testproblems/nimbus_problem.py +143 -0
  132. desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
  133. desdeo/problem/testproblems/re_problem.py +492 -0
  134. desdeo/problem/testproblems/river_pollution_problems.py +440 -0
  135. desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
  136. desdeo/problem/testproblems/simple_problem.py +351 -0
  137. desdeo/problem/testproblems/simulator_problem.py +92 -0
  138. desdeo/problem/testproblems/single_objective.py +289 -0
  139. desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
  140. desdeo/problem/testproblems/zdt_problem.py +274 -0
  141. desdeo/problem/utils.py +245 -0
  142. desdeo/tools/GenerateReferencePoints.py +181 -0
  143. desdeo/tools/__init__.py +120 -0
  144. desdeo/tools/desc_gen.py +22 -0
  145. desdeo/tools/generics.py +165 -0
  146. desdeo/tools/group_scalarization.py +3090 -0
  147. desdeo/tools/gurobipy_solver_interfaces.py +258 -0
  148. desdeo/tools/indicators_binary.py +117 -0
  149. desdeo/tools/indicators_unary.py +362 -0
  150. desdeo/tools/interaction_schema.py +38 -0
  151. desdeo/tools/intersection.py +54 -0
  152. desdeo/tools/iterative_pareto_representer.py +99 -0
  153. desdeo/tools/message.py +265 -0
  154. desdeo/tools/ng_solver_interfaces.py +199 -0
  155. desdeo/tools/non_dominated_sorting.py +134 -0
  156. desdeo/tools/patterns.py +283 -0
  157. desdeo/tools/proximal_solver.py +99 -0
  158. desdeo/tools/pyomo_solver_interfaces.py +477 -0
  159. desdeo/tools/reference_vectors.py +229 -0
  160. desdeo/tools/scalarization.py +2065 -0
  161. desdeo/tools/scipy_solver_interfaces.py +454 -0
  162. desdeo/tools/score_bands.py +627 -0
  163. desdeo/tools/utils.py +388 -0
  164. desdeo/tools/visualizations.py +67 -0
  165. desdeo/utopia_stuff/__init__.py +0 -0
  166. desdeo/utopia_stuff/data/1.json +15 -0
  167. desdeo/utopia_stuff/data/2.json +13 -0
  168. desdeo/utopia_stuff/data/3.json +15 -0
  169. desdeo/utopia_stuff/data/4.json +17 -0
  170. desdeo/utopia_stuff/data/5.json +15 -0
  171. desdeo/utopia_stuff/from_json.py +40 -0
  172. desdeo/utopia_stuff/reinit_user.py +38 -0
  173. desdeo/utopia_stuff/utopia_db_init.py +212 -0
  174. desdeo/utopia_stuff/utopia_problem.py +403 -0
  175. desdeo/utopia_stuff/utopia_problem_old.py +415 -0
  176. desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
  177. desdeo-2.1.0.dist-info/METADATA +186 -0
  178. desdeo-2.1.0.dist-info/RECORD +180 -0
  179. {desdeo-1.2.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
  180. desdeo-2.1.0.dist-info/licenses/LICENSE +21 -0
  181. desdeo-1.2.dist-info/METADATA +0 -16
  182. desdeo-1.2.dist-info/RECORD +0 -4
@@ -0,0 +1,309 @@
1
+ """Router for evolutionary multiobjective optimization (EMO) methods."""
2
+
3
+ from datetime import datetime
4
+ from typing import Annotated, Dict, List, Optional
5
+
6
+ from fastapi import APIRouter, Depends, HTTPException, status
7
+ from sqlalchemy.orm import Session
8
+ from sqlmodel import select
9
+
10
+ from desdeo.api.db import get_session
11
+ from desdeo.api.models.archive import (
12
+ UserSavedEMOResults,
13
+ )
14
+ from desdeo.api.models.EMO import (
15
+ EMOSaveRequest,
16
+ EMOSolveRequest,
17
+ )
18
+ from desdeo.api.models.preference import (
19
+ NonPreferredSolutions,
20
+ PreferenceBase,
21
+ PreferenceDB,
22
+ PreferredRanges,
23
+ PreferredSolutions,
24
+ ReferencePoint,
25
+ )
26
+ from desdeo.api.models.problem import ProblemDB
27
+ from desdeo.api.models.session import InteractiveSessionDB
28
+ from desdeo.api.models.state import EMOSaveState, EMOState, StateDB
29
+ from desdeo.api.models.user import User
30
+ from desdeo.api.routers.user_authentication import get_current_user
31
+ from desdeo.api.utils.database import user_save_solutions
32
+ from desdeo.api.utils.emo_database import _convert_dataframe_to_dict_list
33
+ from desdeo.emo.hooks.archivers import NonDominatedArchive
34
+ from desdeo.emo.methods.EAs import nsga3, rvea
35
+ from desdeo.problem import Problem
36
+
37
# Router collecting the EMO endpoints; every route below is mounted under "/method/emo".
router = APIRouter(
    tags=["evolutionary"],
    prefix="/method/emo",
)
38
+
39
+
40
@router.post("/solve")
def start_emo_optimization(
    request: EMOSolveRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> EMOState:
    """Start interactive evolutionary multiobjective optimization.

    Resolves the interactive session, the problem and the parent state, runs the
    requested evolutionary method ("RVEA" or "NSGA3") on the problem, and stores the
    preference and the resulting state in the database.

    Args:
        request: The solve request (method, problem id, preference, run options and
            optional session / parent state ids).
        user: The currently authenticated user.
        session: The database session.

    Returns:
        EMOState: The state holding the solutions and outputs of the run.

    Raises:
        HTTPException: 404 if the requested interactive session, the problem, or the
            requested parent state cannot be found; 400 if the method or the
            preference type is not supported.
    """
    # Resolve the interactive session: an explicitly requested one must exist,
    # otherwise fall back to the user's active session (which may resolve to None).
    if request.session_id is not None:
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
        interactive_session = session.exec(statement).first()

        if interactive_session is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find interactive session with id={request.session_id}.",
            )
    else:
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
        interactive_session = session.exec(statement).first()

    # Only problems owned by the current user are accessible.
    statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
    problem_db = session.exec(statement).first()

    if problem_db is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Problem with id={request.problem_id} could not be found.",
        )

    # Resolve the parent state BEFORE running the optimizer or writing anything, so an
    # invalid parent_state_id neither wastes a full optimization run nor leaves an
    # already-committed PreferenceDB row behind the 404.
    if request.parent_state_id is None:
        # Default parent: the last state of the session, if there is one.
        parent_state = (
            interactive_session.states[-1]
            if (interactive_session is not None and len(interactive_session.states) > 0)
            else None
        )
    else:
        statement = select(StateDB).where(StateDB.id == request.parent_state_id)
        parent_state = session.exec(statement).first()

        if parent_state is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find state with id={request.parent_state_id}",
            )

    problem = Problem.from_problemdb(problem_db)

    reference_vector_options = _build_reference_vector_options(request.preference, request.number_of_vectors)

    # NOTE(review): request.max_evaluations is recorded on the state below but is not
    # passed to the solver factories here — confirm how the run is meant to terminate.
    if request.method == "RVEA":
        solver, publisher = rvea(problem=problem, reference_vector_options=reference_vector_options)
    elif request.method == "NSGA3":
        solver, publisher = nsga3(problem=problem, reference_vector_options=reference_vector_options)
    else:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported method: {request.method}. Supported methods are 'NSGA3' and 'RVEA'.",
        )

    if request.use_archive:
        # The archive is only subscribed to the publisher; its contents are not read
        # back in this endpoint.
        archive = NonDominatedArchive(problem=problem, publisher=publisher)
        publisher.auto_subscribe(archive)

    emo_results = solver()

    # Missing result attributes are passed to the converter as None.
    solutions_dict = _convert_dataframe_to_dict_list(getattr(emo_results, "solutions", None))
    outputs_dict = _convert_dataframe_to_dict_list(getattr(emo_results, "outputs", None))

    preference_db = PreferenceDB(user_id=user.id, problem_id=problem_db.id, preference=request.preference)
    session.add(preference_db)
    session.commit()
    # Refresh so that preference_db.id is loaded for the state created below.
    session.refresh(preference_db)

    emo_state = EMOState(
        method=request.method,  # already checked above to be "RVEA" or "NSGA3"
        max_evaluations=request.max_evaluations,
        number_of_vectors=request.number_of_vectors,
        use_archive=request.use_archive,
        solutions=solutions_dict,
        outputs=outputs_dict,
    )

    state = StateDB(
        problem_id=problem_db.id,
        preference_id=preference_db.id,
        session_id=interactive_session.id if interactive_session is not None else None,
        parent_id=parent_state.id if parent_state is not None else None,
        state=emo_state,
    )

    session.add(state)
    session.commit()
    session.refresh(state)

    return emo_state
153
+
154
+
155
@router.post("/save")
def save(
    request: EMOSaveRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> EMOSaveState:
    """Save EMO solutions to the user's archive and record a save state.

    Args:
        request: The save request (problem id, solutions to save and optional
            session / parent state ids).
        user: The currently authenticated user.
        session: The database session.

    Returns:
        EMOSaveState: The state describing the saved solutions.

    Raises:
        HTTPException: 404 if the requested interactive session or the requested
            parent state cannot be found.
    """
    if request.session_id is not None:
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
        # `.first()` is required: the raw result object returned by `exec` is never None,
        # so without it the 404 below could never fire and `.states` / `.id` would fail.
        interactive_session = session.exec(statement).first()

        if interactive_session is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find interactive session with id={request.session_id}.",
            )
    else:
        # No session requested: use the user's active session (may resolve to None).
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
        interactive_session = session.exec(statement).first()

    if request.parent_state_id is None:
        # The parent state is assumed to be the last state added to the session.
        parent_state = (
            interactive_session.states[-1]
            if (interactive_session is not None and len(interactive_session.states) > 0)
            else None
        )
    else:
        # Use the sqlmodel `select` function; the Session object has no `select` method.
        statement = select(StateDB).where(StateDB.id == request.parent_state_id)
        parent_state = session.exec(statement).first()

        if parent_state is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find state with id={request.parent_state_id}",
            )

    # Run settings are inherited from an EMOState parent; otherwise these defaults apply.
    max_evaluations = 1000
    number_of_vectors = 20
    use_archive = True

    if parent_state is not None and isinstance(parent_state.state, EMOState):
        max_evaluations = parent_state.state.max_evaluations
        number_of_vectors = parent_state.state.number_of_vectors
        use_archive = parent_state.state.use_archive

    # Parent states of other kinds need not carry a `method` attribute; fall back to
    # "EMO" instead of raising AttributeError.
    method = getattr(parent_state.state, "method", "EMO") if parent_state else "EMO"

    # NOTE(review): unlike /solve, request.problem_id is not checked against the
    # current user here — confirm whether ownership should be enforced.
    save_state = EMOSaveState(
        method=method,
        max_evaluations=max_evaluations,
        number_of_vectors=number_of_vectors,
        use_archive=use_archive,
        problem_id=request.problem_id,
        saved_solutions=[solution.to_emo_results() for solution in request.solutions],
        solutions=[solution.model_dump() for solution in request.solutions],  # original solutions from request
    )

    state = StateDB(
        problem_id=request.problem_id,
        session_id=interactive_session.id if interactive_session is not None else None,
        parent_id=parent_state.id if parent_state is not None else None,
        state=save_state,
    )
    # Save solutions to the user's archive and add the state to the DB.
    user_save_solutions(state, request.solutions, user.id, session)

    return save_state
230
+
231
+
232
@router.get("/saved-solutions")
def get_saved_solutions(
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
):
    """Return every solution saved by the current user as a list of plain dicts.

    Each dict carries the solution's id, name, variable/objective/constraint/extra
    function values, and the id of the problem it belongs to.
    """
    from desdeo.api.models.archive import UserSavedSolutionDB

    owned_by_user = select(UserSavedSolutionDB).where(UserSavedSolutionDB.user_id == user.id)
    rows = session.exec(owned_by_user).all()

    return [
        {
            "id": row.id,
            "name": row.name,
            "variable_values": row.variable_values,
            "objective_values": row.objective_values,
            "constraint_values": row.constraint_values,
            "extra_func_values": row.extra_func_values,
            "problem_id": row.problem_id,
        }
        for row in rows
    ]
260
+
261
+
262
+ # Helper functions
263
def _build_reference_vector_options(preference: PreferenceBase, number_of_vectors: int) -> Dict:
    """Build the reference-vector options for an EMO solver from a preference.

    Args:
        preference: The decision maker's preference, either a preference model
            instance or a raw dict carrying a ``preference_type`` key.
        number_of_vectors: Number of reference vectors to request.

    Returns:
        dict: ``number_of_vectors`` plus ``interactive_adaptation`` naming the
        preference kind, and the preference payload stored under that same key.

    Raises:
        HTTPException: 400 if a dict preference fails model validation, or if the
            preference is of an unsupported type.
    """
    base_options = {
        "number_of_vectors": number_of_vectors,
    }

    # Raw dicts are converted to their model class (all imported at module level).
    preference_models = {
        "reference_point": ReferencePoint,
        "preferred_solutions": PreferredSolutions,
        "non_preferred_solutions": NonPreferredSolutions,
        "preferred_ranges": PreferredRanges,
    }
    if isinstance(preference, dict):
        model = preference_models.get(preference.get("preference_type"))
        if model is not None:
            try:
                preference = model.model_validate(preference)
            except ValueError as e:
                # pydantic's ValidationError subclasses ValueError; a malformed client
                # payload is a bad request, not a server error.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Invalid preference: {e}",
                ) from e

    # (attribute on the preference object, adaptation name / options key), checked in
    # priority order; the first attribute present decides the adaptation.
    adaptations = (
        ("aspiration_levels", "reference_point"),
        ("preferred_solutions", "preferred_solutions"),
        ("non_preferred_solutions", "non_preferred_solutions"),
        ("preferred_ranges", "preferred_ranges"),
    )
    for attribute, adaptation in adaptations:
        if hasattr(preference, attribute):
            base_options["interactive_adaptation"] = adaptation
            base_options[adaptation] = getattr(preference, attribute)
            return base_options

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=f"Unsupported preference type: {type(preference)} with preference_type: {getattr(preference, 'preference_type', 'unknown')}",
    )
@@ -0,0 +1,245 @@
1
+ """Endpoints for NAUTILUS ."""
2
+
3
+ from typing import Annotated
4
+
5
+ from fastapi import APIRouter, Depends, HTTPException
6
+ from pydantic import BaseModel, Field, ValidationError
7
+ from sqlalchemy.orm import Session
8
+
9
+ from desdeo.api.db import get_db
10
+ from desdeo.api.db_models import Problem as ProblemInDB
11
+ from desdeo.api.db_models import Results
12
+ from desdeo.api.routers.user_authentication import get_current_user
13
+ from desdeo.api.schema import User
14
+ from desdeo.mcdm.nautilus import (
15
+ NAUTILUS_Response,
16
+ get_current_path,
17
+ nautilus_init,
18
+ nautilus_step,
19
+ points_to_weights,
20
+ ranks_to_weights,
21
+ step_back_index,
22
+ )
23
+ from desdeo.problem.schema import Problem
24
+
25
# Router for the NAUTILUS endpoints; all routes below live under "/nautilus".
router = APIRouter(
    prefix="/nautilus",
)
26
+
27
+
28
class InitRequest(BaseModel):
    """The request to initialize the NAUTILUS."""

    problem_id: int = Field(description="The ID of the problem to navigate.")
    # TODO: IS total_steps needed for NAUTILUS, what is good default? now its 5.
    # Typed as plain `int`: the value is echoed into InitialResponse.total_steps, which is
    # a non-optional int, so an explicit null could previously only fail later at
    # response validation (a 500) instead of being rejected up front.
    total_steps: int = Field(
        default=5,
        description="The total number of steps in the NAUTILUS. The default value is 5.",
    )
36
+
37
+
38
class NavigateRequest(BaseModel):
    """The request to navigate the NAUTILUS.

    Exactly one of ``points`` or ``ranks`` is expected; the iterate endpoint rejects
    requests that give both or neither.
    """

    problem_id: int = Field(description="The ID of the problem to navigate.")
    # Both preference fields default to None so a client only sends the one it uses;
    # without a default, pydantic v2 treats them as required keys even when nullable.
    points: dict[str, float] | None = Field(
        default=None,
        description=(
            "Preference in the form of points given to the objectives."
            " Higher is better. Must sum up to 100. Only one of points or ranks can be given."
        ),
    )
    ranks: dict[str, int] | None = Field(
        default=None,
        description=(
            "Preference in the form of ranks given to the objectives. Higher is better."
            " Must be integers between 1 and the number of objectives. Ranks need not be unique or consecutive."
        ),
    )
    calculate_step: int = Field(description="The step index to calculate. Starts from 1. Max = total_steps.")
    steps_remaining: int = Field(
        description="The number of steps remaining. Should be total_steps - calculate_step + 1."
    )
58
+
59
+
60
class InitialResponse(BaseModel):
    """The response from the initial endpoint of NAUTILUS."""

    objective_symbols: list[str] = Field(description="The symbols/short names of the objectives.")
    objective_long_names: list[str] = Field(description="Long/descriptive names of the objectives.")
    # NOTE(review): the endpoints fill this from each objective's `unit`; if a unit can be
    # None, this should be `list[str | None] | None` — confirm against the Objective schema.
    units: list[str] | None = Field(description="The units of the objectives.")
    is_maximized: list[bool] = Field(description="Whether the objectives are to be maximized or minimized.")
    ideal: list[float] = Field(description="The ideal values of the objectives.")
    nadir: list[float] = Field(description="The nadir values of the objectives.")
    total_steps: int = Field(description="The total number of steps in the NAUTILUS.")
    distance_to_front: float | None = Field(description="The distance to the front of the reachable region.")
71
+
72
+
73
class Response(InitialResponse):
    """Response returned by most NAUTILUS endpoints.

    Extends the initial response with per-objective histories of the reachable
    region's bounds and of the preferences, i.e. the full navigation process.
    """

    lower_bounds: Annotated[
        dict[str, list[float]],
        Field(description="The lower bounds of the reachable region."),
    ]
    upper_bounds: Annotated[
        dict[str, list[float]],
        Field(description="The upper bounds of the reachable region."),
    ]
    preferences: Annotated[
        dict[str, list[float]],
        Field(description="The preferences used in each step."),
    ]
82
+
83
+ # TODO: ALL ABOVE SHOULD BE FINE
84
+
85
+
86
@router.post("/initialize")
def init_nautilus(
    init_request: InitRequest,
    user: Annotated[User, Depends(get_current_user)],
    db: Annotated[Session, Depends(get_db)],
) -> InitialResponse:
    """Initialize the NAUTILUS.

    Loads and parses the problem, runs the NAUTILUS initialization, replaces any
    results stored for this user and problem by earlier runs with the new initial
    result, and returns the problem's objective information.

    Args:
        init_request (InitRequest): The request to initialize the NAUTILUS.
        user (Annotated[User, Depends(get_current_user)]): The current user.
        db (Annotated[Session, Depends(get_db)]): The database session.

    Returns:
        InitialResponse: The initial response from the NAUTILUS.

    Raises:
        HTTPException: 404 if the problem does not exist, 403 if it is owned by
            another user, 500 if its stored definition cannot be parsed.
    """
    problem_id = init_request.problem_id
    problem_in_db = db.query(ProblemInDB).filter(ProblemInDB.id == problem_id).first()

    if problem_in_db is None:
        raise HTTPException(status_code=404, detail="Problem not found.")
    # Problems with no owner are accessible to everyone; owned ones only to the owner.
    if problem_in_db.owner != user.index and problem_in_db.owner is not None:
        raise HTTPException(status_code=403, detail="Unauthorized to access chosen problem.")
    try:
        problem = Problem.model_validate(problem_in_db.value)
    except ValidationError as e:
        # Chain the caught instance (not the exception class) to keep the real cause.
        raise HTTPException(status_code=500, detail="Error in parsing the problem.") from e

    response = nautilus_init(problem)

    # Replace results from previous NAUTILUS runs with the new initial result in one
    # transaction, so the old results are only removed together with storing the new one.
    previous_results = (
        db.query(Results).filter(Results.problem == problem_id).filter(Results.user == user.index).all()
    )
    for result in previous_results:
        db.delete(result)

    db.add(
        Results(
            user=user.index,
            problem=problem_id,
            value=response.model_dump(mode="json"),
        )
    )
    db.commit()

    return InitialResponse(
        objective_symbols=[obj.symbol for obj in problem.objectives],
        objective_long_names=[obj.name for obj in problem.objectives],
        units=[obj.unit for obj in problem.objectives],
        is_maximized=[obj.maximize for obj in problem.objectives],
        ideal=[obj.ideal for obj in problem.objectives],
        nadir=[obj.nadir for obj in problem.objectives],
        total_steps=init_request.total_steps,
        distance_to_front=response.distance_to_front,
    )
140
+
141
+
142
@router.post("/iterate")
def iterate(
    request: NavigateRequest,
    user: Annotated[User, Depends(get_current_user)],
    db: Annotated[Session, Depends(get_db)],
) -> Response:
    """Navigate the NAUTILUS.

    Runs the NAUTILUS algorithm one step at a time: computes the requested step from
    the stored history of responses, stores the new response, and returns the bounds
    and preferences along the currently active path.

    Args:
        request (NavigateRequest): The step to calculate, the remaining steps, and the
            preference given either as ranks or as points.
        user (Annotated[User, Depends(get_current_user)]): The current user.
        db (Annotated[Session, Depends(get_db)]): The database session.

    Raises:
        HTTPException: 400 if both or neither of ranks and points are given, or if
            `nautilus_step` raises IndexError; 404 if the problem does not exist or
            NAUTILUS has not been initialized for it; 403 if the problem is owned by
            another user; 500 if the problem definition cannot be parsed.

    Returns:
        Response: Objective information plus, for each objective, the lower and upper
        bounds and the preferences over the steps of the active path.
    """
    problem_id, ranks, points, calculate_step, steps_remaining = (
        request.problem_id,
        request.ranks,
        request.points,
        request.calculate_step,
        request.steps_remaining,
    )

    if ranks is not None and points is not None:
        raise HTTPException(status_code=400, detail="Both ranks and points cannot be given.")
    if ranks is None and points is None:
        raise HTTPException(status_code=400, detail="Either ranks or points must be given.")

    problem_in_db = db.query(ProblemInDB).filter(ProblemInDB.id == problem_id).first()
    if problem_in_db is None:
        raise HTTPException(status_code=404, detail="Problem not found.")
    # Problems with no owner are accessible to everyone; owned ones only to the owner.
    if problem_in_db.owner != user.index and problem_in_db.owner is not None:
        raise HTTPException(status_code=403, detail="Unauthorized to access chosen problem.")
    try:
        problem = Problem.model_validate(problem_in_db.value)
    except ValidationError as e:
        # Chain the caught instance (not the exception class) to keep the real cause.
        raise HTTPException(status_code=500, detail="Error in parsing the problem.") from e

    # NOTE(review): this query has no ORDER BY, yet the logic below relies on the results
    # being in insertion order — confirm, e.g. by ordering on the primary key.
    results = db.query(Results).filter(Results.problem == problem_id).filter(Results.user == user.index).all()
    if not results:
        raise HTTPException(status_code=404, detail="NAUTILUS 1 not initialized.")

    responses = [NAUTILUS_Response.model_validate(result.value) for result in results]

    # If the response to continue from is not already the last one, re-append it so that
    # responses[-1] (used as the navigation point below) is that response.
    step_to_append_index = step_back_index(responses, calculate_step - 1)
    if step_to_append_index < len(responses) - 1:
        responses.append(responses[step_to_append_index])

    try:
        new_response = nautilus_step(
            problem,
            step_number=calculate_step,
            steps_remaining=steps_remaining,
            nav_point=responses[-1].navigation_point,
            ranks=ranks,
            points=points,
        )
    except IndexError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    new_result = Results(
        user=user.index,
        problem=problem_id,
        value=new_response.model_dump(mode="json"),
    )
    db.add(new_result)
    db.commit()

    responses = [*responses, new_response]
    current_path = get_current_path(responses)
    active_responses = [responses[i] for i in current_path]

    lower_bounds = {}
    upper_bounds = {}
    preferences = {}
    for obj in problem.objectives:
        lower_bounds[obj.symbol] = [
            response.reachable_bounds["lower_bounds"][obj.symbol] for response in active_responses
        ]
        upper_bounds[obj.symbol] = [
            response.reachable_bounds["upper_bounds"][obj.symbol] for response in active_responses
        ]
        # The first response on the path is skipped when collecting preferences.
        preferences[obj.symbol] = [response.preference[obj.symbol] for response in active_responses[1:]]

    return Response(
        objective_symbols=[obj.symbol for obj in problem.objectives],
        objective_long_names=[obj.name for obj in problem.objectives],
        units=[obj.unit for obj in problem.objectives],
        is_maximized=[obj.maximize for obj in problem.objectives],
        ideal=[obj.ideal for obj in problem.objectives],
        nadir=[obj.nadir for obj in problem.objectives],
        lower_bounds=lower_bounds,
        upper_bounds=upper_bounds,
        preferences=preferences,
        total_steps=len(active_responses) - 1,
        distance_to_front=active_responses[-1].distance_to_front,
    )