desdeo 1.2__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. desdeo/__init__.py +8 -8
  2. desdeo/adm/ADMAfsar.py +551 -0
  3. desdeo/adm/ADMChen.py +414 -0
  4. desdeo/adm/BaseADM.py +119 -0
  5. desdeo/adm/__init__.py +11 -0
  6. desdeo/api/README.md +73 -0
  7. desdeo/api/__init__.py +15 -0
  8. desdeo/api/app.py +50 -0
  9. desdeo/api/config.py +90 -0
  10. desdeo/api/config.toml +64 -0
  11. desdeo/api/db.py +27 -0
  12. desdeo/api/db_init.py +85 -0
  13. desdeo/api/db_models.py +164 -0
  14. desdeo/api/malaga_db_init.py +27 -0
  15. desdeo/api/models/__init__.py +266 -0
  16. desdeo/api/models/archive.py +23 -0
  17. desdeo/api/models/emo.py +128 -0
  18. desdeo/api/models/enautilus.py +69 -0
  19. desdeo/api/models/gdm/gdm_aggregate.py +139 -0
  20. desdeo/api/models/gdm/gdm_base.py +69 -0
  21. desdeo/api/models/gdm/gdm_score_bands.py +114 -0
  22. desdeo/api/models/gdm/gnimbus.py +138 -0
  23. desdeo/api/models/generic.py +104 -0
  24. desdeo/api/models/generic_states.py +401 -0
  25. desdeo/api/models/nimbus.py +158 -0
  26. desdeo/api/models/preference.py +128 -0
  27. desdeo/api/models/problem.py +717 -0
  28. desdeo/api/models/reference_point_method.py +18 -0
  29. desdeo/api/models/session.py +49 -0
  30. desdeo/api/models/state.py +463 -0
  31. desdeo/api/models/user.py +52 -0
  32. desdeo/api/models/utopia.py +25 -0
  33. desdeo/api/routers/_EMO.backup +309 -0
  34. desdeo/api/routers/_NAUTILUS.py +245 -0
  35. desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
  36. desdeo/api/routers/_NIMBUS.py +765 -0
  37. desdeo/api/routers/__init__.py +5 -0
  38. desdeo/api/routers/emo.py +497 -0
  39. desdeo/api/routers/enautilus.py +237 -0
  40. desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
  41. desdeo/api/routers/gdm/gdm_base.py +420 -0
  42. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
  43. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
  44. desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
  45. desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
  46. desdeo/api/routers/generic.py +233 -0
  47. desdeo/api/routers/nimbus.py +705 -0
  48. desdeo/api/routers/problem.py +307 -0
  49. desdeo/api/routers/reference_point_method.py +93 -0
  50. desdeo/api/routers/session.py +100 -0
  51. desdeo/api/routers/test.py +16 -0
  52. desdeo/api/routers/user_authentication.py +520 -0
  53. desdeo/api/routers/utils.py +187 -0
  54. desdeo/api/routers/utopia.py +230 -0
  55. desdeo/api/schema.py +100 -0
  56. desdeo/api/tests/__init__.py +0 -0
  57. desdeo/api/tests/conftest.py +151 -0
  58. desdeo/api/tests/test_enautilus.py +330 -0
  59. desdeo/api/tests/test_models.py +1179 -0
  60. desdeo/api/tests/test_routes.py +1075 -0
  61. desdeo/api/utils/_database.py +263 -0
  62. desdeo/api/utils/_logger.py +29 -0
  63. desdeo/api/utils/database.py +36 -0
  64. desdeo/api/utils/emo_database.py +40 -0
  65. desdeo/core.py +34 -0
  66. desdeo/emo/__init__.py +159 -0
  67. desdeo/emo/hooks/archivers.py +188 -0
  68. desdeo/emo/methods/EAs.py +541 -0
  69. desdeo/emo/methods/__init__.py +0 -0
  70. desdeo/emo/methods/bases.py +12 -0
  71. desdeo/emo/methods/templates.py +111 -0
  72. desdeo/emo/operators/__init__.py +1 -0
  73. desdeo/emo/operators/crossover.py +1282 -0
  74. desdeo/emo/operators/evaluator.py +114 -0
  75. desdeo/emo/operators/generator.py +459 -0
  76. desdeo/emo/operators/mutation.py +1224 -0
  77. desdeo/emo/operators/scalar_selection.py +202 -0
  78. desdeo/emo/operators/selection.py +1778 -0
  79. desdeo/emo/operators/termination.py +286 -0
  80. desdeo/emo/options/__init__.py +108 -0
  81. desdeo/emo/options/algorithms.py +435 -0
  82. desdeo/emo/options/crossover.py +164 -0
  83. desdeo/emo/options/generator.py +131 -0
  84. desdeo/emo/options/mutation.py +260 -0
  85. desdeo/emo/options/repair.py +61 -0
  86. desdeo/emo/options/scalar_selection.py +66 -0
  87. desdeo/emo/options/selection.py +127 -0
  88. desdeo/emo/options/templates.py +383 -0
  89. desdeo/emo/options/termination.py +143 -0
  90. desdeo/explanations/__init__.py +6 -0
  91. desdeo/explanations/explainer.py +100 -0
  92. desdeo/explanations/utils.py +90 -0
  93. desdeo/gdm/__init__.py +22 -0
  94. desdeo/gdm/gdmtools.py +45 -0
  95. desdeo/gdm/score_bands.py +114 -0
  96. desdeo/gdm/voting_rules.py +50 -0
  97. desdeo/mcdm/__init__.py +41 -0
  98. desdeo/mcdm/enautilus.py +338 -0
  99. desdeo/mcdm/gnimbus.py +484 -0
  100. desdeo/mcdm/nautili.py +345 -0
  101. desdeo/mcdm/nautilus.py +477 -0
  102. desdeo/mcdm/nautilus_navigator.py +656 -0
  103. desdeo/mcdm/nimbus.py +417 -0
  104. desdeo/mcdm/pareto_navigator.py +269 -0
  105. desdeo/mcdm/reference_point_method.py +186 -0
  106. desdeo/problem/__init__.py +83 -0
  107. desdeo/problem/evaluator.py +561 -0
  108. desdeo/problem/external/__init__.py +18 -0
  109. desdeo/problem/external/core.py +356 -0
  110. desdeo/problem/external/pymoo_provider.py +266 -0
  111. desdeo/problem/external/runtime.py +44 -0
  112. desdeo/problem/gurobipy_evaluator.py +562 -0
  113. desdeo/problem/infix_parser.py +341 -0
  114. desdeo/problem/json_parser.py +944 -0
  115. desdeo/problem/pyomo_evaluator.py +487 -0
  116. desdeo/problem/schema.py +1829 -0
  117. desdeo/problem/simulator_evaluator.py +348 -0
  118. desdeo/problem/sympy_evaluator.py +244 -0
  119. desdeo/problem/testproblems/__init__.py +88 -0
  120. desdeo/problem/testproblems/benchmarks_server.py +120 -0
  121. desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
  122. desdeo/problem/testproblems/cake_problem.py +185 -0
  123. desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
  124. desdeo/problem/testproblems/dtlz2_problem.py +102 -0
  125. desdeo/problem/testproblems/forest_problem.py +283 -0
  126. desdeo/problem/testproblems/knapsack_problem.py +163 -0
  127. desdeo/problem/testproblems/mcwb_problem.py +831 -0
  128. desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
  129. desdeo/problem/testproblems/momip_problem.py +172 -0
  130. desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
  131. desdeo/problem/testproblems/nimbus_problem.py +143 -0
  132. desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
  133. desdeo/problem/testproblems/re_problem.py +492 -0
  134. desdeo/problem/testproblems/river_pollution_problems.py +440 -0
  135. desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
  136. desdeo/problem/testproblems/simple_problem.py +351 -0
  137. desdeo/problem/testproblems/simulator_problem.py +92 -0
  138. desdeo/problem/testproblems/single_objective.py +289 -0
  139. desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
  140. desdeo/problem/testproblems/zdt_problem.py +274 -0
  141. desdeo/problem/utils.py +245 -0
  142. desdeo/tools/GenerateReferencePoints.py +181 -0
  143. desdeo/tools/__init__.py +120 -0
  144. desdeo/tools/desc_gen.py +22 -0
  145. desdeo/tools/generics.py +165 -0
  146. desdeo/tools/group_scalarization.py +3090 -0
  147. desdeo/tools/gurobipy_solver_interfaces.py +258 -0
  148. desdeo/tools/indicators_binary.py +117 -0
  149. desdeo/tools/indicators_unary.py +362 -0
  150. desdeo/tools/interaction_schema.py +38 -0
  151. desdeo/tools/intersection.py +54 -0
  152. desdeo/tools/iterative_pareto_representer.py +99 -0
  153. desdeo/tools/message.py +265 -0
  154. desdeo/tools/ng_solver_interfaces.py +199 -0
  155. desdeo/tools/non_dominated_sorting.py +134 -0
  156. desdeo/tools/patterns.py +283 -0
  157. desdeo/tools/proximal_solver.py +99 -0
  158. desdeo/tools/pyomo_solver_interfaces.py +477 -0
  159. desdeo/tools/reference_vectors.py +229 -0
  160. desdeo/tools/scalarization.py +2065 -0
  161. desdeo/tools/scipy_solver_interfaces.py +454 -0
  162. desdeo/tools/score_bands.py +627 -0
  163. desdeo/tools/utils.py +388 -0
  164. desdeo/tools/visualizations.py +67 -0
  165. desdeo/utopia_stuff/__init__.py +0 -0
  166. desdeo/utopia_stuff/data/1.json +15 -0
  167. desdeo/utopia_stuff/data/2.json +13 -0
  168. desdeo/utopia_stuff/data/3.json +15 -0
  169. desdeo/utopia_stuff/data/4.json +17 -0
  170. desdeo/utopia_stuff/data/5.json +15 -0
  171. desdeo/utopia_stuff/from_json.py +40 -0
  172. desdeo/utopia_stuff/reinit_user.py +38 -0
  173. desdeo/utopia_stuff/utopia_db_init.py +212 -0
  174. desdeo/utopia_stuff/utopia_problem.py +403 -0
  175. desdeo/utopia_stuff/utopia_problem_old.py +415 -0
  176. desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
  177. desdeo-2.1.0.dist-info/METADATA +186 -0
  178. desdeo-2.1.0.dist-info/RECORD +180 -0
  179. {desdeo-1.2.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
  180. desdeo-2.1.0.dist-info/licenses/LICENSE +21 -0
  181. desdeo-1.2.dist-info/METADATA +0 -16
  182. desdeo-1.2.dist-info/RECORD +0 -4
@@ -0,0 +1,765 @@
1
+ """Router for NIMBUS."""
2
+
3
+ import json
4
+ from typing import Annotated, Any
5
+
6
+ from fastapi import APIRouter, Depends, HTTPException
7
+ from numpy import allclose
8
+ from pydantic import BaseModel, Field, ValidationError
9
+ from sqlalchemy.orm import Session
10
+
11
+ from desdeo.api.db import get_db
12
+ from desdeo.api.db_models import Method, Preference, SolutionArchive, Utopia
13
+ from desdeo.api.db_models import Problem as ProblemInDB
14
+ from desdeo.api.routers.user_authentication import get_current_user
15
+ from desdeo.api.schema import User
16
+ from desdeo.mcdm.nimbus import generate_starting_point, solve_intermediate_solutions, solve_sub_problems
17
+ from desdeo.problem.schema import Problem
18
+ from desdeo.tools.utils import available_solvers
19
+
20
# Router collecting the NIMBUS endpoints defined below; all of them are mounted under "/nimbus".
router = APIRouter(prefix="/nimbus")
21
+
22
+
23
class InitRequest(BaseModel):
    """Request body for starting, or resuming, a NIMBUS session on a problem."""

    problem_id: Annotated[int, Field(description="The ID of the problem to navigate.")]
    method_id: Annotated[int, Field(description="The ID of the method being used.")]
28
+
29
+
30
class NIMBUSResponse(BaseModel):
    """The response from most NIMBUS endpoints.

    The objective metadata lists (symbols, names, units, directions, bounds) describe the
    problem's objectives. The solution fields hold objective vectors taken from the
    user's solution archive.
    """

    objective_symbols: list[str] = Field(description="The symbols of the objectives.")
    objective_long_names: list[str] = Field(description="The names of the objectives.")
    units: list[str | None] | None = Field(description="The units of the objectives.")
    is_maximized: list[bool] = Field(description="Whether the objectives are to be maximized or minimized.")
    lower_bounds: list[float] = Field(description="The lower bounds of the objectives.")
    upper_bounds: list[float] = Field(description="The upper bounds of the objectives.")
    previous_preference: list[float] = Field(description="The previous preference used.")
    # Fixed typo ("interation") in the description published in the OpenAPI schema.
    current_solutions: list[list[float]] = Field(description="The solutions from the current iteration of nimbus.")
    saved_solutions: list[list[float]] = Field(description="The best candidate solutions saved by the decision maker.")
    all_solutions: list[list[float]] = Field(description="All solutions generated by NIMBUS in all iterations.")
43
+
44
+
45
class FakeNIMBUSResponse(BaseModel):
    """Message-only reply: originally a testing stand-in, also returned by ``/choose``."""

    message: Annotated[str, Field(description="A simple message.")]
49
+
50
+
51
class UtopiaResponse(BaseModel):
    """Answer to an ``UtopiaRequest``: the data the web UI needs to render the map."""

    is_utopia: Annotated[bool, Field(description="True if map exists for this problem.")]
    map_name: Annotated[str, Field(description="Name of the map.")]
    map_json: Annotated[dict[str, Any], Field(description="MapJSON representation of the geography.")]
    options: Annotated[
        dict[str, Any],
        Field(description="A dict with given years as keys containing options for each year."),
    ]
    description: Annotated[str, Field(description="Description shown above the map.")]
    years: Annotated[list[str], Field(description="A list of years for which the maps have been generated.")]
60
+
61
+
62
class UtopiaRequest(BaseModel):
    """Request body asking for the Utopia map of one archived solution."""

    problem_id: Annotated[int, Field(description="The ID of the problem to be solved.")]
    solution: Annotated[list[float], Field(description="The solution for which the map is generated.")]
67
+
68
+
69
class NIMBUSIterateRequest(BaseModel):
    """Request body for one NIMBUS iteration (the ``/iterate`` endpoint)."""

    problem_id: Annotated[int, Field(description="The ID of the problem to be solved.")]
    method_id: Annotated[int, Field(description="The ID of the method being used.")]
    preference: Annotated[
        list[float],
        Field(
            description=(
                "The preference as a reference point. Note, NIMBUS uses classification preference,"
                " we can construct it using this reference point and the reference solution."
            )
        ),
    ]
    reference_solution: Annotated[
        list[float],
        Field(description="The reference solution to be used in the classification preference."),
    ]
    num_solutions: Annotated[
        int | None,
        Field(description="The number of solutions to be generated in the iteration."),
    ] = 1
86
+
87
+
88
class NIMBUSIntermediateSolutionRequest(BaseModel):
    """The request to generate intermediate solutions between two solutions in NIMBUS."""

    problem_id: int = Field(description="The ID of the problem to be solved.")
    method_id: int = Field(description="The ID of the method being used.")

    reference_solution_1: list[float] = Field(
        description="The first reference solution to be used in the classification preference."
    )
    # The description was a copy of a generic text; it now names the second solution explicitly.
    reference_solution_2: list[float] = Field(
        description="The second reference solution to be used in the classification preference."
    )
    num_solutions: int | None = Field(
        description="The number of solutions to be generated in the iteration.", default=1
    )
103
+
104
+
105
class SaveRequest(BaseModel):
    """Request body listing archived solutions the user wants to mark as saved."""

    problem_id: Annotated[int, Field(description="The ID of the problem to be solved.")]
    method_id: Annotated[int, Field(description="The ID of the method being used.")]
    solutions: Annotated[list[list[float]], Field(description="The solutions to be saved.")]
111
+
112
+
113
class ChooseRequest(BaseModel):
    """Request body naming the solution the user picks as final."""

    problem_id: Annotated[int, Field(description="The ID of the problem to be solved.")]
    method_id: Annotated[int, Field(description="The ID of the method being used.")]
    solution: Annotated[list[float], Field(description="The chosen solution.")]
119
+
120
+
121
@router.post("/initialize")
def init_nimbus(
    init_request: InitRequest,
    user: Annotated[User, Depends(get_current_user)],
    db: Annotated[Session, Depends(get_db)],
) -> NIMBUSResponse | FakeNIMBUSResponse:
    """Start a NIMBUS session, or resume the one already archived for this user.

    If the user has no archived NIMBUS solutions for the problem yet, a starting point
    is generated, stored, and the archive is re-read.

    Args:
        init_request (InitRequest): Identifies the problem. Its ``method_id`` is
            overwritten with the NIMBUS method id looked up from the database.
        user (Annotated[User, Depends(get_current_user)]): The authenticated user.
        db (Annotated[Session, Depends(get_db)]): The database session.

    Returns:
        NIMBUSResponse: Objective metadata and bounds. The current solution (the one
        flagged ``current``, else the first archived one) is also reported as the
        previous preference.
    """
    pid = init_request.problem_id
    # The frontend does not send a usable method id, so always use the NIMBUS one.
    init_request.method_id = get_nimbus_method_id(db)
    mid = init_request.method_id
    uid = user.index

    problem, solver_name = read_problem_from_db(db=db, problem_id=pid, user_id=uid)
    archive = read_solutions_from_db(db, pid, uid, mid)

    # Computed up front so that badly defined objective bounds fail early.
    lower_bounds, upper_bounds = calculate_bounds(problem)

    if not archive:
        solver_factory = available_solvers[solver_name]["constructor"] if solver_name else None
        starting_point = generate_starting_point(problem=problem, solver=solver_factory)
        save_results_to_db(
            db=db,
            user_id=uid,
            request=init_request,
            results=[starting_point],
            previous_solutions=archive,
        )
        archive = read_solutions_from_db(db, pid, uid, mid)

    # Prefer the solution flagged as current; otherwise fall back to the first archived one.
    current = next(filter(lambda sol: sol.current, archive), archive[0])

    objectives = problem.objectives
    return NIMBUSResponse(
        objective_symbols=[o.symbol for o in objectives],
        objective_long_names=[o.name for o in objectives],
        units=[o.unit for o in objectives],
        is_maximized=[o.maximize for o in objectives],
        lower_bounds=lower_bounds,
        upper_bounds=upper_bounds,
        previous_preference=current.objectives,
        current_solutions=[current.objectives],
        saved_solutions=[sol.objectives for sol in archive if sol.saved],
        all_solutions=[sol.objectives for sol in archive],
    )
178
+
179
+
180
@router.post("/iterate")
def iterate(
    request: NIMBUSIterateRequest,
    user: Annotated[User, Depends(get_current_user)],
    db: Annotated[Session, Depends(get_db)],
) -> NIMBUSResponse | FakeNIMBUSResponse:
    """Run one NIMBUS iteration from the user's reference point and reference solution.

    Args:
        request: Problem id, preference (reference point), reference solution and the
            number of solutions wanted. Its ``method_id`` is replaced by the NIMBUS
            method id looked up from the database.
        user (Annotated[User, Depends(get_current_user)]): The authenticated user.
        db (Annotated[Session, Depends(get_db)]): The database session.

    Returns:
        NIMBUSResponse: Objective metadata, bounds, the submitted preference, and the
        re-read archive split into current, saved and all solutions.

    Raises:
        HTTPException: 404 when the user has no archived NIMBUS solutions for the problem.
    """
    pid = request.problem_id
    # The frontend does not send a usable method id, so always use the NIMBUS one.
    request.method_id = get_nimbus_method_id(db)
    mid = request.method_id
    uid = user.index

    problem, solver_name = read_problem_from_db(db=db, problem_id=pid, user_id=uid)

    archive = read_solutions_from_db(db, pid, uid, mid)
    if not archive:
        raise HTTPException(status_code=404, detail="Problem not found in the database.")

    # Computed before solving so that badly defined objective bounds fail early.
    lower_bounds, upper_bounds = calculate_bounds(problem)

    symbols = [obj.symbol for obj in problem.objectives]
    new_results = solve_sub_problems(
        problem=problem,
        current_objectives=dict(zip(symbols, request.reference_solution, strict=True)),
        reference_point=dict(zip(symbols, request.preference, strict=True)),
        num_desired=request.num_solutions,
        solver=None if not solver_name else available_solvers[solver_name]["constructor"],
        scalarization_options={"rho": 0.001, "delta": 0.001},
    )

    save_results_to_db(
        db=db,
        user_id=uid,
        request=request,
        results=new_results,
        previous_solutions=archive,
    )

    updated = read_solutions_from_db(db, pid, uid, mid)
    objectives = problem.objectives
    return NIMBUSResponse(
        objective_symbols=[o.symbol for o in objectives],
        objective_long_names=[o.name for o in objectives],
        units=[o.unit for o in objectives],
        is_maximized=[o.maximize for o in objectives],
        lower_bounds=lower_bounds,
        upper_bounds=upper_bounds,
        previous_preference=request.preference,
        current_solutions=[sol.objectives for sol in updated if sol.current],
        saved_solutions=[sol.objectives for sol in updated if sol.saved],
        all_solutions=[sol.objectives for sol in updated],
    )
243
+
244
+
245
@router.post("/intermediate")
def intermediate(
    request: NIMBUSIntermediateSolutionRequest,
    user: Annotated[User, Depends(get_current_user)],
    db: Annotated[Session, Depends(get_db)],
) -> NIMBUSResponse | FakeNIMBUSResponse:
    """Get solutions between two solutions using NIMBUS.

    Args:
        request: Problem id, the two reference solutions and the number of intermediate
            solutions wanted. Its ``method_id`` is replaced by the NIMBUS method id looked
            up from the database.
        user (Annotated[User, Depends(get_current_user)]): The current user.
        db (Annotated[Session, Depends(get_db)]): The database session.

    Returns:
        NIMBUSResponse: Objective metadata, bounds and the re-read solution archive.
        ``previous_preference`` holds the objectives of the solution that was current
        before this call (or of the first archived solution if none was flagged).

    Raises:
        HTTPException: 404 when the user has no archived NIMBUS solutions for the problem.
    """
    problem_id = request.problem_id
    # The request is supposed to contain method id, but the frontend does not send a usable one.
    request.method_id = get_nimbus_method_id(db)
    method_id = request.method_id

    problem, solver = read_problem_from_db(db=db, problem_id=problem_id, user_id=user.index)

    previous_solutions = read_solutions_from_db(db, problem_id, user.index, method_id)

    if not previous_solutions:
        raise HTTPException(status_code=404, detail="Problem not found in the database.")

    # Calculate bounds here, just to make sure that they have been properly defined in the problem
    lower_bounds, upper_bounds = calculate_bounds(problem)

    # This request has no `preference` field (reading it raised AttributeError). Report
    # the objectives of the previously current solution instead, falling back to the
    # first archived one. Snapshot them before saving, because save_results_to_db
    # receives these same archive rows and may update them.
    previous_current = next((sol for sol in previous_solutions if sol.current), previous_solutions[0])
    previous_preference = list(previous_current.objectives)

    # Key the solutions by objective symbol, as `iterate` does for solve_sub_problems.
    # The previous code keyed them by Objective objects.
    # NOTE(review): assumes solve_intermediate_solutions expects objective-symbol keys; confirm.
    symbols = [obj.symbol for obj in problem.objectives]
    results = solve_intermediate_solutions(
        problem=problem,
        solution_1=dict(zip(symbols, request.reference_solution_1, strict=True)),
        solution_2=dict(zip(symbols, request.reference_solution_2, strict=True)),
        num_desired=request.num_solutions,
        solver=available_solvers[solver]["constructor"] if solver else None,
    )

    save_results_to_db(
        db=db, user_id=user.index, request=request, results=results, previous_solutions=previous_solutions
    )

    solutions = read_solutions_from_db(db, problem_id, user.index, method_id)

    return NIMBUSResponse(
        objective_symbols=[obj.symbol for obj in problem.objectives],
        objective_long_names=[obj.name for obj in problem.objectives],
        units=[obj.unit for obj in problem.objectives],
        is_maximized=[obj.maximize for obj in problem.objectives],
        lower_bounds=lower_bounds,
        upper_bounds=upper_bounds,
        previous_preference=previous_preference,
        current_solutions=[sol.objectives for sol in solutions if sol.current],
        saved_solutions=[sol.objectives for sol in solutions if sol.saved],
        all_solutions=[sol.objectives for sol in solutions],
    )
305
+
306
+
307
@router.post("/save")
def save(
    request: SaveRequest,
    user: Annotated[User, Depends(get_current_user)],
    db: Annotated[Session, Depends(get_db)],
) -> NIMBUSResponse | FakeNIMBUSResponse:
    """Mark the requested archived solutions as saved.

    A requested objective vector marks every archived solution of the same length whose
    objectives are numerically close to it (``numpy.allclose``).

    Args:
        request: The request body for saving solutions.
        user (Annotated[User, Depends(get_current_user)]): The current user.
        db (Annotated[Session, Depends(get_db)]): The database session.

    Returns:
        NIMBUSResponse: Only the solution lists are filled; the objective metadata,
        bounds and preference fields are empty.

    Raises:
        HTTPException: 404 when the user has no archived NIMBUS solutions for the problem.
    """
    problem_id = request.problem_id
    method_id = get_nimbus_method_id(db)

    previous_solutions = read_solutions_from_db(db, problem_id, user.index, method_id)

    if not previous_solutions:
        raise HTTPException(status_code=404, detail="Problem not found in the database.")

    for sol in request.solutions:
        for prev in previous_solutions:
            # Compare equal-length vectors only. allclose broadcasts, so a one-element
            # vector would otherwise match every archived solution whose values are all
            # close to it, and other length mismatches would raise ValueError (HTTP 500).
            if len(sol) == len(prev.objectives) and allclose(sol, prev.objectives):
                prev.saved = True
    db.commit()

    return NIMBUSResponse(
        objective_symbols=[],
        objective_long_names=[],
        units=[],
        is_maximized=[],
        lower_bounds=[],
        upper_bounds=[],
        previous_preference=[],
        current_solutions=[sol.objectives for sol in previous_solutions if sol.current],
        saved_solutions=[sol.objectives for sol in previous_solutions if sol.saved],
        all_solutions=[sol.objectives for sol in previous_solutions],
    )
351
+
352
+
353
@router.post("/choose")
def choose(
    request: ChooseRequest,
    user: Annotated[User, Depends(get_current_user)],
    db: Annotated[Session, Depends(get_db)],
) -> FakeNIMBUSResponse:
    """Choose a solution as the final solution for NIMBUS.

    The first archived solution whose objective vector has the same length as
    ``request.solution`` and is numerically close to it is flagged as chosen.

    Args:
        request: The request body naming the chosen solution.
        user (Annotated[User, Depends(get_current_user)]): The current user.
        db (Annotated[Session, Depends(get_db)]): The database session.

    Returns:
        FakeNIMBUSResponse: A confirmation message.

    Raises:
        HTTPException: 404 when the user has no archived NIMBUS solutions for the
            problem, or when no archived solution matches ``request.solution``.
    """
    problem_id = request.problem_id
    method_id = get_nimbus_method_id(db)

    previous_solutions = read_solutions_from_db(db, problem_id, user.index, method_id)

    if not previous_solutions:
        raise HTTPException(status_code=404, detail="Problem not found in the database.")

    for prev in previous_solutions:
        # Length check first: allclose broadcasts, so a one-element vector could match a
        # longer solution, and other length mismatches would raise ValueError (HTTP 500).
        if len(request.solution) == len(prev.objectives) and allclose(request.solution, prev.objectives):
            prev.chosen = True
            db.commit()
            break
    else:
        raise HTTPException(status_code=404, detail="The chosen solution was not found in the database.")

    return FakeNIMBUSResponse(message="Solution chosen.")
388
+
389
+
390
@router.post("/utopia")
def utopia(  # noqa: C901, PLR0912
    request: UtopiaRequest,
    user: Annotated[User, Depends(get_current_user)],
    db: Annotated[Session, Depends(get_db)],
) -> UtopiaResponse:
    """Request information necessary to draw the forest treatment map of a solution.

    Args:
        request: The problem id and the objective vector of an archived solution.
        user (Annotated[User, Depends(get_current_user)]): The current user.
        db (Annotated[Session, Depends(get_db)]): The database session.

    Returns:
        UtopiaResponse: ``is_utopia=False`` with empty fields when no ``Utopia`` row exists
        for the problem. Otherwise it contains one chart-options dict per year in
        ``utopia_data.years``, the stored map JSON, and a text summary of harvesting
        incomes read from the ``P_1``, ``P_2``, ``P_3`` and ``V_end`` decision variables.

    Raises:
        HTTPException: 404 if ``request.solution`` matches no archived NIMBUS solution of the user.
    """
    method_id = get_nimbus_method_id(db)
    archived_solutions = read_solutions_from_db(db, request.problem_id, user.index, method_id)

    # Find the solution from the archive. The length check avoids allclose broadcasting
    # (spurious matches, or a ValueError on mismatched shapes).
    solution = next(
        (
            sol
            for sol in archived_solutions
            if len(request.solution) == len(sol.objectives) and allclose(request.solution, sol.objectives)
        ),
        None,
    )
    if solution is None:
        raise HTTPException(status_code=404, detail="The chosen solution was not found in the database.")

    decision_variables = json.loads(solution.decision_variables)

    # Get the user's map from the database
    utopia_data = db.query(Utopia).filter(Utopia.problem == request.problem_id).first()
    if not utopia_data:
        return UtopiaResponse(
            is_utopia=False,
            map_name="",
            options={},
            map_json={},
            description="",
            years=[],
        )

    # Treatment ids shown in the legend.
    description_dict = {
        0: "Do nothing",
        1: "Clearcut",
        2: "Thinning from below",
        3: "Thinning from above",
        4: "Even thinning",
        5: "First thinning",
    }

    def treatment_index(part: str) -> int:
        """Map one token of a schedule string to a treatment id; -1 if unrecognized."""
        if "clearcut" in part:
            return 1
        if "below" in part:
            return 2
        if "above" in part:
            return 3
        if "even" in part:
            return 4
        if "first" in part:
            return 5
        return -1

    # Only decision variables whose names start with "X" are treated as stand-level choices.
    stand_keys = [key for key in decision_variables if key.startswith("X")]

    treatments_dict = {}
    for key in stand_keys:
        # The dict keys get converted from ints to strings when it's loaded from database
        try:
            treatments = utopia_data.schedule_dict[key][str(decision_variables[key].index(1))]
        except ValueError:
            # If the optimization didn't choose any decision alternative, it's safe to assume
            # that nothing is being done at that forest stand
            treatments = utopia_data.schedule_dict[key]["0"]
        # One entry per configured year (previously hard-coded to exactly three years,
        # which crashed for any other count); 0 means "Do nothing".
        treatments_dict[key] = dict.fromkeys(utopia_data.years, 0)
        for year in treatments_dict[key]:
            if year in treatments:
                for part in treatments.split():
                    if year in part:
                        treatments_dict[key][year] = treatment_index(part)

    # Create the options for the webui
    treatment_colors = {
        0: "#4daf4a",
        1: "#e41a1c",
        2: "#984ea3",
        3: "#e3d802",
        4: "#ff7f00",
        5: "#377eb8",
    }

    map_name = "ForestMap"  # This isn't visible anywhere on the ui

    options = {}
    for year in utopia_data.years:
        options[year] = {
            "tooltip": {
                "trigger": "item",
                "showDelay": 0,
                "transitionDuration": 0.2,
            },
            "visualMap": {  # vis eg. stock levels
                "left": "right",
                "showLabel": True,
                "type": "piecewise",  # for different plans
                "pieces": [],
                "text": ["Management plans"],
                "calculable": True,
            },
            "toolbox": {
                "show": True,
                "left": "left",
                "top": "top",
                "feature": {
                    "dataView": {"readOnly": True},
                    "restore": {},
                    "saveAsImage": {},
                },
            },
            "series": [
                {
                    "name": year,
                    "type": "map",
                    "roam": True,
                    "map": map_name,
                    "nameProperty": utopia_data.stand_id_field,
                    "label": {
                        "show": False  # Hide text labels on the map
                    },
                    "data": [],
                    "nameMap": {},
                }
            ],
        }

        for key in stand_keys:
            stand = int(utopia_data.schedule_dict[key]["unit"])
            treatment_id = treatments_dict[key][year]
            piece = {
                "value": treatment_id,
                "symbol": "circle",
                "label": description_dict[treatment_id],
                "color": treatment_colors[treatment_id],
            }
            # One legend entry per distinct treatment.
            if piece not in options[year]["visualMap"]["pieces"]:
                options[year]["visualMap"]["pieces"].append(piece)
            if utopia_data.stand_descriptor:
                name = utopia_data.stand_descriptor[str(stand)] + description_dict[treatment_id]
            else:
                name = "Stand " + str(stand) + " " + description_dict[treatment_id]
            options[year]["series"][0]["data"].append(
                {
                    "name": name,
                    "value": treatment_id,
                }
            )
            options[year]["series"][0]["nameMap"][stand] = name

    # Single quotes inside the f-strings keep this valid on Python < 3.12 (pre-PEP 701).
    map_description = (
        f"Income from harvesting in the first period {int(decision_variables['P_1'])}€.\n"
        + f"Income from harvesting in the second period {int(decision_variables['P_2'])}€.\n"
        + f"Income from harvesting in the third period {int(decision_variables['P_3'])}€.\n"
        + f"The discounted value of the remaining forest at the end of the plan {int(decision_variables['V_end'])}€."
    )

    return UtopiaResponse(
        is_utopia=True,
        map_name=map_name,
        options=options,
        map_json=json.loads(utopia_data.map_json),
        description=map_description,
        years=utopia_data.years,
    )
575
+
576
+
577
def flatten(lst) -> list[float]:
    """Flatten arbitrarily nested lists into a single flat list.

    Only ``list`` instances are expanded. Any other item (tuple, string, number, ...)
    is kept as one element. Left-to-right order is preserved.

    Args:
        lst: The possibly nested list to flatten.

    Returns:
        A new list holding the non-list leaves of ``lst`` in order.
    """
    flat: list = []
    for element in lst:
        flat.extend(flatten(element) if isinstance(element, list) else (element,))
    return flat
593
+
594
+
595
def get_nimbus_method_id(db: Session) -> int:
    """Query the database for the id of the NIMBUS method.

    Args:
        db (Session): Database session.

    Returns:
        int: The id of the first ``Method`` row whose kind is NIMBUS.

    Raises:
        HTTPException: 500 if no NIMBUS method is registered in the database.
    """
    # `Methods` is not among this module's top-level imports, so referencing it raised
    # NameError. Import it locally to keep this fix self-contained.
    from desdeo.api.schema import Methods  # noqa: PLC0415

    nimbus_method = db.query(Method).filter(Method.kind == Methods.NIMBUS).first()
    if nimbus_method is None:
        # Previously this surfaced as an opaque AttributeError on `None.id`.
        raise HTTPException(status_code=500, detail="NIMBUS method not found in the database.")
    return nimbus_method.id
606
+
607
+
608
def read_problem_from_db(db: Session, problem_id: int, user_id: int) -> tuple[Problem, str | None]:
    """Read a problem from the database after checking that the user may access it.

    Args:
        db (Session): Database session to be used.
        problem_id (int): Id of the problem.
        user_id (int): Id of the user requesting the problem.

    Raises:
        HTTPException: 404 if no problem with ``problem_id`` exists.
        HTTPException: 403 if the problem has an owner other than ``user_id``.
        HTTPException: 500 if the stored problem cannot be validated as a ``Problem``.

    Returns:
        tuple[Problem, str | None]: The problem as a DESDEO ``Problem`` and the
        name of its solver, or ``None`` if the stored problem has no solver set.
    """
    problem_in_db = db.query(ProblemInDB).filter(ProblemInDB.id == problem_id).first()

    if problem_in_db is None:
        raise HTTPException(status_code=404, detail="Problem not found.")
    # A problem with no owner (owner is None) is accessible to any user.
    if problem_in_db.owner is not None and problem_in_db.owner != user_id:
        raise HTTPException(status_code=403, detail="Unauthorized to access chosen problem.")

    solver = problem_in_db.solver.value if problem_in_db.solver else None
    try:
        problem = Problem.model_validate(problem_in_db.value)
    except ValidationError as err:
        # Chain from the caught instance so the validation details are preserved.
        raise HTTPException(status_code=500, detail="Error in parsing the problem.") from err
    return problem, solver
637
+
638
+
639
def read_solutions_from_db(db: Session, problem_id: int, user_id: int, method_id: int) -> list[SolutionArchive]:
    """Fetch the archived solutions of one user, problem and method.

    Args:
        db (Session): Database session to query.
        problem_id (int): Id of the problem the solutions belong to.
        user_id (int): Id of the user who owns the solutions.
        method_id (int): Id of the method that produced the solutions.

    Returns:
        list[SolutionArchive]: All matching ``SolutionArchive`` rows (possibly empty).
    """
    # All three conditions must hold; passing them together to filter() ANDs them.
    conditions = (
        SolutionArchive.problem == problem_id,
        SolutionArchive.user == user_id,
        SolutionArchive.method == method_id,
    )
    query = db.query(SolutionArchive).filter(*conditions)
    return query.all()
658
+
659
+
660
def save_results_to_db(
    db: Session,
    user_id: int,
    request: InitRequest | NIMBUSIterateRequest | NIMBUSIntermediateSolutionRequest,
    results: list,
    previous_solutions: list[SolutionArchive],
):
    """Save solver results to the solution archive and mark them as current.

    For any request other than an ``InitRequest``, the request itself is first stored
    as a ``Preference`` row (kind ``"NIMBUS"`` for ``NIMBUSIterateRequest``, otherwise
    ``"NIMBUS_intermediate"``), and newly archived solutions reference it.

    Results whose objective values are ``allclose`` to an earlier result are dropped.
    Every solution of this user and problem currently flagged ``current`` is unflagged.
    A remaining result that matches an entry of ``previous_solutions`` re-flags that
    entry as current; any other result is inserted as a new ``SolutionArchive`` row
    with ``current=True``.

    Args:
        db (Session): Database session used for all reads and writes.
        user_id (int): Id of the user the results belong to.
        request: The request that produced ``results``; its ``problem_id`` and
            ``method_id`` are stored with the new rows.
        results (list): Solver results exposing an ``optimal_objectives`` mapping and
            ``optimal_variables``. NOTE: duplicates are removed from this list in place.
        previous_solutions (list[SolutionArchive]): Already archived solutions to match
            the results against.
    """
    problem_id = request.problem_id
    method_id = request.method_id

    if type(request) is InitRequest:
        pref = None
    else:
        pref = Preference(
            user=user_id,
            problem=problem_id,
            method=method_id,
            # Compare the request type directly; wrapping the comparison in type(...)
            # would always be truthy and make "NIMBUS_intermediate" unreachable.
            kind="NIMBUS" if type(request) is NIMBUSIterateRequest else "NIMBUS_intermediate",
            value=request.model_dump(mode="json"),
        )
        db.add(pref)
        # Committed here, presumably so pref.id is available for the solutions below.
        db.commit()

    # Mark every result that is close to an earlier result as a duplicate ...
    duplicate_indices = set()
    for i in range(len(results) - 1):
        for j in range(i + 1, len(results)):
            if allclose(list(results[i].optimal_objectives.values()), list(results[j].optimal_objectives.values())):
                duplicate_indices.add(j)

    # ... and remove them from highest index to lowest so earlier indices stay valid.
    for index in sorted(duplicate_indices, reverse=True):
        results.pop(index)

    old_current_solutions = (
        db.query(SolutionArchive)
        .filter(SolutionArchive.problem == problem_id, SolutionArchive.user == user_id, SolutionArchive.current)
        .all()
    )

    # Mark all the old solutions as not current
    for old in old_current_solutions:
        old.current = False

    for res in results:
        # Re-use an existing archived solution if one has (nearly) the same objectives.
        duplicate = False
        for prev in previous_solutions:
            if allclose(list(res.optimal_objectives.values()), list(prev.objectives)):
                prev.current = True
                duplicate = True
                break
        # If the solution was not found in the database, add it
        if not duplicate:
            db.add(
                SolutionArchive(
                    user=user_id,
                    problem=problem_id,
                    method=method_id,
                    preference=pref.id if pref is not None else None,
                    decision_variables=json.dumps(res.optimal_variables),
                    objectives=list(res.optimal_objectives.values()),
                    saved=False,
                    current=True,
                    chosen=False,
                )
            )
    db.commit()
736
+
737
+
738
def calculate_bounds(problem: Problem) -> tuple[list[float], list[float]]:
    """Calculate lower and upper bounds for each objective of a problem.

    For a minimized objective, the lower bound is its ideal value and the upper bound
    is its nadir value. For a maximized objective, the two are swapped.

    Args:
        problem (Problem): Problem whose ``get_ideal_point()`` and ``get_nadir_point()``
            return mappings from objective symbol to value.

    Raises:
        HTTPException: 500 if any objective lacks an ideal or nadir value.

    Returns:
        tuple[list[float], list[float]]: The lower bounds and the upper bounds, both in
        the order of ``problem.objectives``.
    """
    ideal = problem.get_ideal_point()
    nadir = problem.get_nadir_point()

    lower_bounds: list[float] = []
    upper_bounds: list[float] = []
    for objective in problem.objectives:
        # `None in ideal` would only test the mapping's keys, so inspect the values
        # looked up for each objective instead.
        ideal_value = ideal.get(objective.symbol)
        nadir_value = nadir.get(objective.symbol)
        if ideal_value is None or nadir_value is None:
            raise HTTPException(status_code=500, detail="Problem missing ideal or nadir value.")

        if objective.maximize:
            lower_bounds.append(nadir_value)
            upper_bounds.append(ideal_value)
        else:
            lower_bounds.append(ideal_value)
            upper_bounds.append(nadir_value)

    return lower_bounds, upper_bounds