desdeo 1.2__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182) hide show
  1. desdeo/__init__.py +8 -8
  2. desdeo/adm/ADMAfsar.py +551 -0
  3. desdeo/adm/ADMChen.py +414 -0
  4. desdeo/adm/BaseADM.py +119 -0
  5. desdeo/adm/__init__.py +11 -0
  6. desdeo/api/README.md +73 -0
  7. desdeo/api/__init__.py +15 -0
  8. desdeo/api/app.py +50 -0
  9. desdeo/api/config.py +90 -0
  10. desdeo/api/config.toml +64 -0
  11. desdeo/api/db.py +27 -0
  12. desdeo/api/db_init.py +85 -0
  13. desdeo/api/db_models.py +164 -0
  14. desdeo/api/malaga_db_init.py +27 -0
  15. desdeo/api/models/__init__.py +266 -0
  16. desdeo/api/models/archive.py +23 -0
  17. desdeo/api/models/emo.py +128 -0
  18. desdeo/api/models/enautilus.py +69 -0
  19. desdeo/api/models/gdm/gdm_aggregate.py +139 -0
  20. desdeo/api/models/gdm/gdm_base.py +69 -0
  21. desdeo/api/models/gdm/gdm_score_bands.py +114 -0
  22. desdeo/api/models/gdm/gnimbus.py +138 -0
  23. desdeo/api/models/generic.py +104 -0
  24. desdeo/api/models/generic_states.py +401 -0
  25. desdeo/api/models/nimbus.py +158 -0
  26. desdeo/api/models/preference.py +128 -0
  27. desdeo/api/models/problem.py +717 -0
  28. desdeo/api/models/reference_point_method.py +18 -0
  29. desdeo/api/models/session.py +49 -0
  30. desdeo/api/models/state.py +463 -0
  31. desdeo/api/models/user.py +52 -0
  32. desdeo/api/models/utopia.py +25 -0
  33. desdeo/api/routers/_EMO.backup +309 -0
  34. desdeo/api/routers/_NAUTILUS.py +245 -0
  35. desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
  36. desdeo/api/routers/_NIMBUS.py +765 -0
  37. desdeo/api/routers/__init__.py +5 -0
  38. desdeo/api/routers/emo.py +497 -0
  39. desdeo/api/routers/enautilus.py +237 -0
  40. desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
  41. desdeo/api/routers/gdm/gdm_base.py +420 -0
  42. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
  43. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
  44. desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
  45. desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
  46. desdeo/api/routers/generic.py +233 -0
  47. desdeo/api/routers/nimbus.py +705 -0
  48. desdeo/api/routers/problem.py +307 -0
  49. desdeo/api/routers/reference_point_method.py +93 -0
  50. desdeo/api/routers/session.py +100 -0
  51. desdeo/api/routers/test.py +16 -0
  52. desdeo/api/routers/user_authentication.py +520 -0
  53. desdeo/api/routers/utils.py +187 -0
  54. desdeo/api/routers/utopia.py +230 -0
  55. desdeo/api/schema.py +100 -0
  56. desdeo/api/tests/__init__.py +0 -0
  57. desdeo/api/tests/conftest.py +151 -0
  58. desdeo/api/tests/test_enautilus.py +330 -0
  59. desdeo/api/tests/test_models.py +1179 -0
  60. desdeo/api/tests/test_routes.py +1075 -0
  61. desdeo/api/utils/_database.py +263 -0
  62. desdeo/api/utils/_logger.py +29 -0
  63. desdeo/api/utils/database.py +36 -0
  64. desdeo/api/utils/emo_database.py +40 -0
  65. desdeo/core.py +34 -0
  66. desdeo/emo/__init__.py +159 -0
  67. desdeo/emo/hooks/archivers.py +188 -0
  68. desdeo/emo/methods/EAs.py +541 -0
  69. desdeo/emo/methods/__init__.py +0 -0
  70. desdeo/emo/methods/bases.py +12 -0
  71. desdeo/emo/methods/templates.py +111 -0
  72. desdeo/emo/operators/__init__.py +1 -0
  73. desdeo/emo/operators/crossover.py +1282 -0
  74. desdeo/emo/operators/evaluator.py +114 -0
  75. desdeo/emo/operators/generator.py +459 -0
  76. desdeo/emo/operators/mutation.py +1224 -0
  77. desdeo/emo/operators/scalar_selection.py +202 -0
  78. desdeo/emo/operators/selection.py +1778 -0
  79. desdeo/emo/operators/termination.py +286 -0
  80. desdeo/emo/options/__init__.py +108 -0
  81. desdeo/emo/options/algorithms.py +435 -0
  82. desdeo/emo/options/crossover.py +164 -0
  83. desdeo/emo/options/generator.py +131 -0
  84. desdeo/emo/options/mutation.py +260 -0
  85. desdeo/emo/options/repair.py +61 -0
  86. desdeo/emo/options/scalar_selection.py +66 -0
  87. desdeo/emo/options/selection.py +127 -0
  88. desdeo/emo/options/templates.py +383 -0
  89. desdeo/emo/options/termination.py +143 -0
  90. desdeo/explanations/__init__.py +6 -0
  91. desdeo/explanations/explainer.py +100 -0
  92. desdeo/explanations/utils.py +90 -0
  93. desdeo/gdm/__init__.py +22 -0
  94. desdeo/gdm/gdmtools.py +45 -0
  95. desdeo/gdm/score_bands.py +114 -0
  96. desdeo/gdm/voting_rules.py +50 -0
  97. desdeo/mcdm/__init__.py +41 -0
  98. desdeo/mcdm/enautilus.py +338 -0
  99. desdeo/mcdm/gnimbus.py +484 -0
  100. desdeo/mcdm/nautili.py +345 -0
  101. desdeo/mcdm/nautilus.py +477 -0
  102. desdeo/mcdm/nautilus_navigator.py +656 -0
  103. desdeo/mcdm/nimbus.py +417 -0
  104. desdeo/mcdm/pareto_navigator.py +269 -0
  105. desdeo/mcdm/reference_point_method.py +186 -0
  106. desdeo/problem/__init__.py +83 -0
  107. desdeo/problem/evaluator.py +561 -0
  108. desdeo/problem/external/__init__.py +18 -0
  109. desdeo/problem/external/core.py +356 -0
  110. desdeo/problem/external/pymoo_provider.py +266 -0
  111. desdeo/problem/external/runtime.py +44 -0
  112. desdeo/problem/gurobipy_evaluator.py +562 -0
  113. desdeo/problem/infix_parser.py +341 -0
  114. desdeo/problem/json_parser.py +944 -0
  115. desdeo/problem/pyomo_evaluator.py +487 -0
  116. desdeo/problem/schema.py +1829 -0
  117. desdeo/problem/simulator_evaluator.py +348 -0
  118. desdeo/problem/sympy_evaluator.py +244 -0
  119. desdeo/problem/testproblems/__init__.py +88 -0
  120. desdeo/problem/testproblems/benchmarks_server.py +120 -0
  121. desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
  122. desdeo/problem/testproblems/cake_problem.py +185 -0
  123. desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
  124. desdeo/problem/testproblems/dtlz2_problem.py +102 -0
  125. desdeo/problem/testproblems/forest_problem.py +283 -0
  126. desdeo/problem/testproblems/knapsack_problem.py +163 -0
  127. desdeo/problem/testproblems/mcwb_problem.py +831 -0
  128. desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
  129. desdeo/problem/testproblems/momip_problem.py +172 -0
  130. desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
  131. desdeo/problem/testproblems/nimbus_problem.py +143 -0
  132. desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
  133. desdeo/problem/testproblems/re_problem.py +492 -0
  134. desdeo/problem/testproblems/river_pollution_problems.py +440 -0
  135. desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
  136. desdeo/problem/testproblems/simple_problem.py +351 -0
  137. desdeo/problem/testproblems/simulator_problem.py +92 -0
  138. desdeo/problem/testproblems/single_objective.py +289 -0
  139. desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
  140. desdeo/problem/testproblems/zdt_problem.py +274 -0
  141. desdeo/problem/utils.py +245 -0
  142. desdeo/tools/GenerateReferencePoints.py +181 -0
  143. desdeo/tools/__init__.py +120 -0
  144. desdeo/tools/desc_gen.py +22 -0
  145. desdeo/tools/generics.py +165 -0
  146. desdeo/tools/group_scalarization.py +3090 -0
  147. desdeo/tools/gurobipy_solver_interfaces.py +258 -0
  148. desdeo/tools/indicators_binary.py +117 -0
  149. desdeo/tools/indicators_unary.py +362 -0
  150. desdeo/tools/interaction_schema.py +38 -0
  151. desdeo/tools/intersection.py +54 -0
  152. desdeo/tools/iterative_pareto_representer.py +99 -0
  153. desdeo/tools/message.py +265 -0
  154. desdeo/tools/ng_solver_interfaces.py +199 -0
  155. desdeo/tools/non_dominated_sorting.py +134 -0
  156. desdeo/tools/patterns.py +283 -0
  157. desdeo/tools/proximal_solver.py +99 -0
  158. desdeo/tools/pyomo_solver_interfaces.py +477 -0
  159. desdeo/tools/reference_vectors.py +229 -0
  160. desdeo/tools/scalarization.py +2065 -0
  161. desdeo/tools/scipy_solver_interfaces.py +454 -0
  162. desdeo/tools/score_bands.py +627 -0
  163. desdeo/tools/utils.py +388 -0
  164. desdeo/tools/visualizations.py +67 -0
  165. desdeo/utopia_stuff/__init__.py +0 -0
  166. desdeo/utopia_stuff/data/1.json +15 -0
  167. desdeo/utopia_stuff/data/2.json +13 -0
  168. desdeo/utopia_stuff/data/3.json +15 -0
  169. desdeo/utopia_stuff/data/4.json +17 -0
  170. desdeo/utopia_stuff/data/5.json +15 -0
  171. desdeo/utopia_stuff/from_json.py +40 -0
  172. desdeo/utopia_stuff/reinit_user.py +38 -0
  173. desdeo/utopia_stuff/utopia_db_init.py +212 -0
  174. desdeo/utopia_stuff/utopia_problem.py +403 -0
  175. desdeo/utopia_stuff/utopia_problem_old.py +415 -0
  176. desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
  177. desdeo-2.1.0.dist-info/METADATA +186 -0
  178. desdeo-2.1.0.dist-info/RECORD +180 -0
  179. {desdeo-1.2.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
  180. desdeo-2.1.0.dist-info/licenses/LICENSE +21 -0
  181. desdeo-1.2.dist-info/METADATA +0 -16
  182. desdeo-1.2.dist-info/RECORD +0 -4
@@ -0,0 +1,1179 @@
1
+ """Tests related to the SQLModels."""
2
+
3
+ import numpy as np
4
+ from sqlmodel import Session, select
5
+
6
+ from desdeo.api.models import (
7
+ Bounds,
8
+ ConstantDB,
9
+ ConstraintDB,
10
+ DiscreteRepresentationDB,
11
+ ENautilusState,
12
+ ExtraFunctionDB,
13
+ ForestProblemMetaData,
14
+ Group,
15
+ GroupIteration,
16
+ InteractiveSessionDB,
17
+ NIMBUSClassificationState,
18
+ NIMBUSFinalState,
19
+ NIMBUSInitializationState,
20
+ NIMBUSSaveState,
21
+ ObjectiveDB,
22
+ PreferenceDB,
23
+ ProblemDB,
24
+ ProblemMetaDataDB,
25
+ ReferencePoint,
26
+ RepresentativeNonDominatedSolutions,
27
+ RPMState,
28
+ ScalarizationFunctionDB,
29
+ SimulatorDB,
30
+ StateDB,
31
+ TensorConstantDB,
32
+ TensorVariableDB,
33
+ User,
34
+ UserSavedSolutionDB,
35
+ VariableDB,
36
+ )
37
+ from desdeo.api.models.gdm.gnimbus import (
38
+ OptimizationPreference,
39
+ )
40
+ from desdeo.mcdm import enautilus_step, generate_starting_point, rpm_solve_solutions, solve_sub_problems
41
+ from desdeo.problem.schema import (
42
+ Constant,
43
+ Constraint,
44
+ ConstraintTypeEnum,
45
+ DiscreteRepresentation,
46
+ ExtraFunction,
47
+ Objective,
48
+ ObjectiveTypeEnum,
49
+ Problem,
50
+ ScalarizationFunction,
51
+ Simulator,
52
+ TensorConstant,
53
+ TensorVariable,
54
+ Variable,
55
+ VariableTypeEnum,
56
+ )
57
+ from desdeo.problem.testproblems import (
58
+ binh_and_korn,
59
+ dtlz2,
60
+ momip_ti2,
61
+ momip_ti7,
62
+ multi_valued_constraint_problem,
63
+ nimbus_test_problem,
64
+ pareto_navigator_test_problem,
65
+ re21,
66
+ re22,
67
+ re23,
68
+ re24,
69
+ river_pollution_problem,
70
+ river_pollution_problem_discrete,
71
+ simple_data_problem,
72
+ simple_knapsack,
73
+ simple_knapsack_vectors,
74
+ simple_linear_test_problem,
75
+ simple_scenario_test_problem,
76
+ simple_test_problem,
77
+ spanish_sustainability_problem,
78
+ zdt1,
79
+ )
80
+ from desdeo.tools import PyomoBonminSolver, available_solvers
81
+
82
+
83
def compare_models(
    model_1,
    model_2,
    unordered_fields=None,
) -> bool:
    """Compares two Pydantic models.

    Fields listed in ``unordered_fields`` whose values are lists in both dumps are compared as
    multisets: they must have the same length and contain the same elements, in any order.
    All remaining fields are compared with plain equality.

    Args:
        model_1 (Any): Pydantic model 1.
        model_2 (Any): Pydantic model 2.
        unordered_fields (list[str]): field names that are unordered and should be compared for
            having the same contents.

    Returns:
        bool: Whether the two models have identical contents.
    """
    if unordered_fields is None:
        unordered_fields = [
            "variables",
            "constants",
            "objectives",
            "constraints",
            "extra_funcs",
            "simulators",
            "scenario_keys",
        ]

    dict_1 = model_1.model_dump()
    dict_2 = model_2.model_dump()

    for field in unordered_fields:
        list_1 = dict_1.get(field)
        list_2 = dict_2.get(field)

        # Only list-valued fields present in both dumps get order-insensitive treatment; anything
        # else is left in place and checked by the exact comparison at the end.
        if not (isinstance(list_1, list) and isinstance(list_2, list)):
            continue

        if len(list_1) != len(list_2):
            return False

        # Elements are typically dicts (unhashable), so match them one by one against a copy of the
        # other list, consuming each match so duplicates are counted correctly.
        remaining = list(list_2)
        for item in list_1:
            if item not in remaining:
                return False
            remaining.remove(item)

        del dict_1[field], dict_2[field]

    return dict_1 == dict_2
131
+
132
+
133
def test_tensor_constant(session_and_user: dict[str, Session | list[User]]):
    """Test that a tensor constant can be transformed to an SQLModel and back after adding it to the database.

    The pydantic ``TensorConstant`` is dumped, tagged with ``problem_id`` 1, validated into a
    ``TensorConstantDB`` row, stored, fetched back by primary key, and validated back into a
    ``TensorConstant`` that must equal the original.
    """
    session = session_and_user["session"]

    t_tensor = TensorConstant(name="tensor", symbol="T", shape=[2, 2], values=[[1, 2], [3, 4]])
    t_tensor_dump = t_tensor.model_dump()
    t_tensor_dump["problem_id"] = 1

    db_tensor = TensorConstantDB.model_validate(t_tensor_dump)

    session.add(db_tensor)
    session.commit()
    session.refresh(db_tensor)

    # Fetch by primary key: filtering on problem_id == 1 and taking .first() could return a different
    # row if another tensor constant with that problem_id already exists in the database.
    from_db_tensor = session.get(TensorConstantDB, db_tensor.id)

    # check that original added TensorConstant and fetched match
    assert db_tensor == from_db_tensor

    # Drop the database-only columns before validating back into the pydantic model.
    from_db_tensor_dump = from_db_tensor.model_dump(exclude={"id", "problem_id"})
    t_tensor_validated = TensorConstant.model_validate(from_db_tensor_dump)

    assert t_tensor_validated == t_tensor
156
+
157
+
158
def test_constant(session_and_user: dict[str, Session | list[User]]):
    """Test that a scalar constant can be transformed to an SQLModel and back after adding it to the database.

    The pydantic ``Constant`` is dumped, tagged with ``problem_id`` 1, validated into a ``ConstantDB``
    row, stored, fetched back by primary key, and validated back into a ``Constant`` that must equal
    the original.
    """
    session = session_and_user["session"]

    constant = Constant(name="constant", symbol="c", value=69.420)
    constant_dump = constant.model_dump()
    constant_dump["problem_id"] = 1

    db_constant = ConstantDB.model_validate(constant_dump)

    session.add(db_constant)
    session.commit()
    session.refresh(db_constant)

    # Fetch by primary key: filtering on problem_id == 1 and taking .first() could return a different
    # row if another constant with that problem_id already exists in the database.
    from_db_constant = session.get(ConstantDB, db_constant.id)

    assert db_constant == from_db_constant

    # Drop the database-only columns before validating back into the pydantic model.
    from_db_constant_dump = from_db_constant.model_dump(exclude={"id", "problem_id"})
    constant_validated = Constant.model_validate(from_db_constant_dump)

    assert constant_validated == constant
180
+
181
+
182
def test_variable(session_and_user: dict[str, Session | list[User]]):
    """Round-trip a scalar ``Variable`` through ``VariableDB`` and the database.

    The variable is stored under problem id 1, read back by primary key, and converted back into a
    pydantic ``Variable``, which must equal the one we started from.
    """
    db_session = session_and_user["session"]

    original = Variable(
        name="test variable",
        symbol="x_1",
        initial_value=69,
        lowerbound=42,
        upperbound=420,
        variable_type=VariableTypeEnum.integer,
    )

    stored = VariableDB.model_validate({**original.model_dump(), "problem_id": 1})

    db_session.add(stored)
    db_session.commit()
    db_session.refresh(stored)

    fetched = db_session.get(VariableDB, stored.id)
    assert stored == fetched

    # The id/problem_id columns only exist on the table model, so strip them before re-validating.
    restored = Variable.model_validate(fetched.model_dump(exclude={"id", "problem_id"}))
    assert restored == original
212
+
213
+
214
def test_tensor_variable(session_and_user: dict[str, Session | list[User]]):
    """Round-trip a ``TensorVariable`` through ``TensorVariableDB`` and the database.

    The tensor variable is stored under problem id 69, read back by primary key, and converted back
    into a pydantic ``TensorVariable``, which must equal the original.
    """
    db_session = session_and_user["session"]

    original = TensorVariable(
        name="test variable",
        symbol="X",
        shape=[2, 2],
        initial_values=[[1, 2], [3, 4]],
        lowerbounds=[[0, 1], [1, 0]],
        upperbounds=[[99, 89], [88, 77]],
        variable_type=VariableTypeEnum.integer,
    )

    row = TensorVariableDB.model_validate({**original.model_dump(), "problem_id": 69})

    db_session.add(row)
    db_session.commit()
    db_session.refresh(row)

    fetched_row = db_session.get(TensorVariableDB, row.id)
    assert row == fetched_row

    # Strip the table-only columns so the dump validates as the plain pydantic model.
    round_tripped = TensorVariable.model_validate(fetched_row.model_dump(exclude={"id", "problem_id"}))
    assert round_tripped == original
245
+
246
+
247
def test_objective(session_and_user: dict[str, Session | list[User]]):
    """Round-trip an ``Objective`` through ``ObjectiveDB`` and the database.

    The objective is stored under problem id 420, read back by primary key, and converted back into
    a pydantic ``Objective``, which must equal the original.
    """
    db_session = session_and_user["session"]

    source_objective = Objective(
        name="Test Objective",
        symbol="f_1",
        func="x_1 + x_2 + Sin(y)",
        objective_type=ObjectiveTypeEnum.analytical,
        ideal=10.5,
        nadir=20.0,
        maximize=False,
        scenario_keys=["s_1", "s_2"],
        unit="m",
        is_convex=False,
        is_linear=True,
        is_twice_differentiable=True,
        simulator_path="/dev/null",
        surrogates=["/var/log", "/dev/sda/sda1"],
    )

    objective_row = ObjectiveDB.model_validate({**source_objective.model_dump(), "problem_id": 420})

    db_session.add(objective_row)
    db_session.commit()
    db_session.refresh(objective_row)

    fetched_row = db_session.get(ObjectiveDB, objective_row.id)
    assert objective_row == fetched_row

    # Remove database bookkeeping columns before converting back to the pydantic schema.
    rebuilt = Objective.model_validate(fetched_row.model_dump(exclude={"id", "problem_id"}))
    assert rebuilt == source_objective
285
+
286
+
287
def test_constraint(session_and_user: dict[str, Session | list[User]]):
    """Round-trip a ``Constraint`` through ``ConstraintDB`` and the database.

    The constraint is stored under problem id 72, read back by primary key, and converted back into
    a pydantic ``Constraint``, which must equal the original.
    """
    db_session = session_and_user["session"]

    source_constraint = Constraint(
        name="Test Constraint",
        symbol="g_1",
        func="x_1 + x_1 + x_1 - 10",
        cons_type=ConstraintTypeEnum.LTE,
        is_convex=True,
        is_linear=False,
        is_twice_differentiable=False,
        scenario_keys=["Abloy", "MasterLock", "MasterLockToOpenMasterLock"],
        simulator_path="/dev/null/aaaaaaaaaa",
        surrogates=["/var/log", "/dev/sda/sda1/no"],
    )

    constraint_row = ConstraintDB.model_validate({**source_constraint.model_dump(), "problem_id": 72})

    db_session.add(constraint_row)
    db_session.commit()
    db_session.refresh(constraint_row)

    fetched_row = db_session.get(ConstraintDB, constraint_row.id)
    assert constraint_row == fetched_row

    # Remove database bookkeeping columns before converting back to the pydantic schema.
    rebuilt = Constraint.model_validate(fetched_row.model_dump(exclude={"id", "problem_id"}))
    assert rebuilt == source_constraint
321
+
322
+
323
def test_scalarization_function(session_and_user: dict[str, Session | list[User]]):
    """Round-trip a ``ScalarizationFunction`` through ``ScalarizationFunctionDB`` and the database.

    The function is stored under problem id 2, read back by primary key, and converted back into a
    pydantic ``ScalarizationFunction``, which must equal the original.
    """
    db_session = session_and_user["session"]

    source_func = ScalarizationFunction(
        name="Test ScalarizationFunction",
        symbol="s_1",
        func="x_1 + x_1 + x_1 - 10 - 99999 + Sin(y_3)",
        is_convex=True,
        is_linear=True,
        is_twice_differentiable=False,
        scenario_keys=["Abloy", "MasterLock", "MasterLockToOpenMasterLock", "MyHandsHurt"],
    )

    func_row = ScalarizationFunctionDB.model_validate({**source_func.model_dump(), "problem_id": 2})

    db_session.add(func_row)
    db_session.commit()
    db_session.refresh(func_row)

    fetched_row = db_session.get(ScalarizationFunctionDB, func_row.id)
    assert func_row == fetched_row

    # Remove database bookkeeping columns before converting back to the pydantic schema.
    rebuilt = ScalarizationFunction.model_validate(fetched_row.model_dump(exclude={"id", "problem_id"}))
    assert rebuilt == source_func
354
+
355
+
356
def test_extra_function(session_and_user: dict[str, Session | list[User]]):
    """Round-trip an ``ExtraFunction`` through ``ExtraFunctionDB`` and the database.

    The function is stored under problem id 5, read back by primary key, and converted back into a
    pydantic ``ExtraFunction``, which must equal the original.
    """
    db_session = session_and_user["session"]

    source_extra = ExtraFunction(
        name="Test ExtraFunction",
        symbol="extra_1",
        func="x_1 + x_2 + x_9000 - 10 - 99999 + Cos(y_3)",
        is_convex=False,
        is_linear=False,
        is_twice_differentiable=True,
        scenario_keys=["Abloy", "MasterLock", "MasterLockToOpenMasterLock", "MyHandsHurt", "RunningOutOfIdeas"],
    )

    extra_row = ExtraFunctionDB.model_validate({**source_extra.model_dump(), "problem_id": 5})

    db_session.add(extra_row)
    db_session.commit()
    db_session.refresh(extra_row)

    fetched_row = db_session.get(ExtraFunctionDB, extra_row.id)
    assert extra_row == fetched_row

    # Remove database bookkeeping columns before converting back to the pydantic schema.
    rebuilt = ExtraFunction.model_validate(fetched_row.model_dump(exclude={"id", "problem_id"}))
    assert rebuilt == source_extra
387
+
388
+
389
def test_discrete_representation(session_and_user: dict[str, Session | list[User]]):
    """Round-trip a ``DiscreteRepresentation`` through ``DiscreteRepresentationDB`` and the database.

    The representation is stored under problem id 3, read back by primary key, and converted back
    into a pydantic ``DiscreteRepresentation``, which must equal the original.
    """
    db_session = session_and_user["session"]

    source_repr = DiscreteRepresentation(
        variable_values={"x_1": [1, 2, 3, 4, 5], "x_2": [6, 7, 8, 9, 10]},
        objective_values={"f_1": [0.5, 1.0, 2.0, 3.5, 9], "f_2": [-1, -2, -3, -4, -5]},
        non_dominated=True,
    )

    repr_row = DiscreteRepresentationDB.model_validate({**source_repr.model_dump(), "problem_id": 3})

    db_session.add(repr_row)
    db_session.commit()
    db_session.refresh(repr_row)

    fetched_row = db_session.get(DiscreteRepresentationDB, repr_row.id)
    assert repr_row == fetched_row

    # Remove database bookkeeping columns before converting back to the pydantic schema.
    rebuilt = DiscreteRepresentation.model_validate(fetched_row.model_dump(exclude={"id", "problem_id"}))
    assert rebuilt == source_repr
416
+
417
+
418
def test_simulator(session_and_user: dict[str, Session | list[User]]):
    """Round-trip a ``Simulator`` through ``SimulatorDB`` and the database.

    The simulator is stored under problem id 2, read back by primary key, and converted back into a
    pydantic ``Simulator``, which must equal the original.
    """
    db_session = session_and_user["session"]

    source_simulator = Simulator(
        file="/my/favorite/simulator.exe",
        name="simulator",
        symbol="simu",
        parameter_options={"param1": 69, "nice": True},
    )

    simulator_row = SimulatorDB.model_validate({**source_simulator.model_dump(), "problem_id": 2})

    db_session.add(simulator_row)
    db_session.commit()
    db_session.refresh(simulator_row)

    fetched_row = db_session.get(SimulatorDB, simulator_row.id)
    assert simulator_row == fetched_row

    # Remove database bookkeeping columns before converting back to the pydantic schema.
    rebuilt = Simulator.model_validate(fetched_row.model_dump(exclude={"id", "problem_id"}))
    assert rebuilt == source_simulator
446
+
447
+
448
def test_from_pydantic(session_and_user: dict[str, Session | list[User]]):
    """Store the Binh and Korn problem as a ``ProblemDB`` and check that it reads back unchanged."""
    db_session = session_and_user["session"]
    owner = session_and_user["user"]

    stored_problem = ProblemDB.from_problem(binh_and_korn(), user=owner)

    db_session.add(stored_problem)
    db_session.commit()
    db_session.refresh(stored_problem)

    assert compare_models(stored_problem, db_session.get(ProblemDB, stored_problem.id))
463
+
464
+
465
def test_from_problem_to_d_and_back(session_and_user: dict[str, Session | list[User]]):
    """Test that Problem converts to ProblemDB and back.

    Each test problem is converted to a ``ProblemDB``, stored, fetched back by primary key, converted
    back into a pydantic ``Problem`` and compared against the original with ``compare_models``.
    """
    session = session_and_user["session"]
    user = session_and_user["user"]

    problems = [
        binh_and_korn(),
        river_pollution_problem(),
        simple_knapsack(),
        simple_data_problem(),
        simple_scenario_test_problem(),
        re24(),
        simple_knapsack_vectors(),
        spanish_sustainability_problem(),
        zdt1(10),
        dtlz2(5, 3),
        multi_valued_constraint_problem(),
        momip_ti2(),
        momip_ti7(),
        nimbus_test_problem(),
        pareto_navigator_test_problem(),
        river_pollution_problem_discrete(),
        simple_test_problem(),
        simple_linear_test_problem(),
        re21(),
        re22(),
        re23(),
    ]

    for problem in problems:
        # convert to SQLModel
        problem_db = ProblemDB.from_problem(problem, user=user)

        session.add(problem_db)
        session.commit()
        session.refresh(problem_db)

        from_db = session.get(ProblemDB, problem_db.id)

        # Back to pure pydantic; kept under its own name rather than rebinding ``problem_db``
        # (a ProblemDB) to a Problem.
        restored = Problem.from_problemdb(from_db)

        # check that problems are equal; the message identifies which of the looped problems failed
        assert compare_models(problem, restored), f"Database round trip changed problem {problem.name!r}"
509
+
510
+
511
def test_user_save_solutions(session_and_user: dict[str, Session | list[User]]):
    """Test that user_save_solutions correctly saves solutions to the usersavedsolutiondb in the database.

    Two ``UserSavedSolutionDB`` entries are wrapped in a ``NIMBUSSaveState`` and stored through
    ``StateDB.create``. The test then checks the stored rows, the state contents, and that the
    solutions are reachable through the user's archive and the problem's solutions.
    """
    session = session_and_user["session"]
    user = session_and_user["user"]

    # Create test solutions with proper dictionary values
    objective_values = {"f_1": 1.2, "f_2": 0.9}
    variable_values = {"x_1": 5.2, "x_2": 8.0, "x_3": -4.2}

    user_id = user.id
    problem_id = 1

    test_solutions = [
        UserSavedSolutionDB(
            name="Solution 1",
            objective_values=objective_values,
            variable_values=variable_values,
            user_id=user_id,
            problem_id=problem_id,
            origin_state_id=1,
        ),
        UserSavedSolutionDB(
            name="Solution 2",
            objective_values=objective_values,
            variable_values=variable_values,
            solution_index=2,
            user_id=user_id,
            problem_id=problem_id,
            origin_state_id=2,
        ),
    ]

    num_test_solutions = len(test_solutions)

    # Create NIMBUSSaveState
    save_state = NIMBUSSaveState(solutions=test_solutions)

    # Create StateDB with NIMBUSSaveState
    state_db = StateDB.create(session, problem_id=problem_id, state=save_state)

    session.add(state_db)
    session.commit()
    session.refresh(state_db)

    # Verify the solutions were saved. A plain SELECT has no guaranteed row order, so sort by name
    # before checking the individual solutions positionally.
    saved_solutions = sorted(session.exec(select(UserSavedSolutionDB)).all(), key=lambda sol: sol.name)
    assert len(saved_solutions) == num_test_solutions

    # Verify the content of the first solution
    first_solution = saved_solutions[0]
    assert first_solution.name == "Solution 1"
    assert first_solution.objective_values == objective_values
    assert first_solution.variable_values == variable_values
    assert first_solution.origin_state_id == 1
    assert first_solution.solution_index is None
    assert first_solution.user_id == user.id
    assert first_solution.problem_id == problem_id

    # Verify the content of the second solution
    second_solution = saved_solutions[1]
    assert second_solution.name == "Solution 2"
    assert second_solution.objective_values == objective_values
    assert second_solution.variable_values == variable_values
    assert second_solution.origin_state_id == 2
    assert second_solution.solution_index == 2
    assert second_solution.user_id == user.id
    assert second_solution.problem_id == problem_id

    # Verify state relationship
    saved_state = session.exec(select(StateDB).where(StateDB.id == state_db.id)).first()
    assert isinstance(saved_state.state, NIMBUSSaveState)
    assert len(saved_state.state.solutions) == num_test_solutions

    # Check that relationships are formed
    session.refresh(user)

    assert len(user.archive) == num_test_solutions
    assert len(session.get(ProblemDB, problem_id).solutions) == num_test_solutions
589
+
590
+
591
def test_preference_models(session_and_user: dict[str, Session | list[User]]):
    """Test that preference models can be stored in and read back from the database.

    A ``ReferencePoint`` and a ``Bounds`` preference are stored as ``PreferenceDB`` rows for a DTLZ2
    problem. The test checks that their values survive the round trip and that the problem and user
    relationships are populated.
    """
    session = session_and_user["session"]
    user = session_and_user["user"]

    problem = ProblemDB.from_problem(dtlz2(5, 3), user=user)

    session.add(problem)
    session.commit()
    session.refresh(problem)

    aspiration_levels = {"f_1": 0.1, "f_2": 5, "f_3": -3.1}
    lower_bounds = {"f_1": -4.1, "f_2": 0, "f_3": 2.2}
    upper_bounds = {"f_1": 2.1, "f_2": 0.1, "f_3": 12.2}

    reference_point = ReferencePoint(aspiration_levels=aspiration_levels)
    bounds = Bounds(lower_bounds=lower_bounds, upper_bounds=upper_bounds)

    reference_point_db = PreferenceDB(user_id=user.id, problem_id=problem.id, preference=reference_point)
    bounds_db = PreferenceDB(user_id=user.id, problem_id=problem.id, preference=bounds)

    session.add(reference_point_db)
    session.add(bounds_db)
    session.commit()
    session.refresh(reference_point_db)
    session.refresh(bounds_db)

    from_db_ref_point = session.get(PreferenceDB, reference_point_db.id)
    from_db_bounds = session.get(PreferenceDB, bounds_db.id)

    # preference payloads survive the round trip
    assert from_db_ref_point.preference.aspiration_levels == aspiration_levels
    assert from_db_bounds.preference.lower_bounds == lower_bounds
    assert from_db_bounds.preference.upper_bounds == upper_bounds

    # relationships point back at the problem and the user
    assert from_db_ref_point.problem == problem
    assert from_db_bounds.problem == problem

    assert from_db_bounds.user == user
    assert from_db_ref_point.user == user
631
+
632
+
633
def test_rpm_state(session_and_user: dict[str, Session | list[User]]):
    """Test the RPM state that it works correctly.

    Two reference point method iterations are solved for the user's first problem. Each is stored as
    a ``StateDB`` in the same interactive session, with the second state's parent set to the first.
    The test then checks the parent/child links, the stored preferences, and the problem and session
    relationships.
    """
    session = session_and_user["session"]
    user = session_and_user["user"]
    problem_db = user.problems[0]

    # create interactive session
    isession = InteractiveSessionDB(user_id=user.id)

    session.add(isession)
    session.commit()
    session.refresh(isession)

    problem = Problem.from_problemdb(problem_db)

    # Solver settings shared by both iterations.
    scalarization_options = None
    solver = "pyomo_bonmin"
    solver_options = None

    # First iteration of the reference point method.
    asp_levels_1 = {"f_1": 0.4, "f_2": 0.8, "f_3": 0.6}

    results = rpm_solve_solutions(
        problem,
        asp_levels_1,
        scalarization_options=scalarization_options,
        solver=available_solvers[solver]["constructor"],
        solver_options=solver_options,
    )

    rp_1 = ReferencePoint(aspiration_levels=asp_levels_1)

    rpm_state = RPMState(
        preferences=rp_1,
        scalarization_options=scalarization_options,
        solver=solver,
        solver_options=solver_options,
        solver_results=results,
    )

    state_1 = StateDB.create(
        database_session=session, problem_id=problem_db.id, session_id=isession.id, state=rpm_state
    )

    session.add(state_1)
    session.commit()
    session.refresh(state_1)

    # Second iteration, stored as a child of the first state.
    asp_levels_2 = {"f_1": 0.6, "f_2": 0.4, "f_3": 0.5}

    results = rpm_solve_solutions(
        problem,
        asp_levels_2,
        scalarization_options=scalarization_options,
        solver=available_solvers[solver]["constructor"],
        solver_options=solver_options,
    )

    rp_2 = ReferencePoint(aspiration_levels=asp_levels_2)

    rpm_state = RPMState(
        preferences=rp_2,
        scalarization_options=scalarization_options,
        solver=solver,
        solver_options=solver_options,
        solver_results=results,
    )

    state_2 = StateDB.create(
        database_session=session,
        problem_id=problem_db.id,
        session_id=isession.id,
        parent_id=state_1.id,
        state=rpm_state,
    )

    session.add(state_2)
    session.commit()
    session.refresh(state_2)

    # parent/child tree structure
    assert state_1.parent is None
    assert state_2.parent == state_1
    assert len(state_1.children) == 1
    assert state_1.children[0] == state_2

    # preferences stored with each state
    assert state_1.state.preferences == rp_1
    assert state_2.state.preferences == rp_2

    # problem and session relationships
    assert state_2.problem == problem_db
    assert state_2.session.user == user

    assert state_2.children == []
    assert state_2.parent.problem == problem_db
    assert state_2.parent.session.user == user
748
+
749
+
750
def test_problem_metadata(session_and_user: dict[str, Session | list[User]]):
    """Test that problem metadata can be put into the database and brought back.

    A ``ProblemMetaDataDB`` row is attached to a freshly stored DTLZ2 problem.
    One ``ForestProblemMetaData`` entry and one
    ``RepresentativeNonDominatedSolutions`` entry are then attached to that
    metadata.

    The metadata is queried back by problem id, and every written field of
    both entries is compared with the original values, including the
    representative set's description. Finally, the problem's
    ``problem_metadata`` relationship must resolve to the same row.
    """
    session = session_and_user["session"]
    user = session_and_user["user"]

    # Just some test problem to attach the metadata to; the metadata contents
    # below are not derived from this problem.
    problem = ProblemDB.from_problem(dtlz2(5, 3), user=user)

    session.add(problem)
    session.commit()
    session.refresh(problem)

    representative_name = "Test solutions"
    representative_description = "These solutions are used for testing"
    representative_variables = {"x_1": [1.1, 2.2, 3.3], "x_2": [-1.1, -2.2, -3.3]}
    representative_objectives = {"f_1": [0.1, 0.5, 0.9], "f_2": [-0.1, 0.2, 199.2], "f_1_min": [], "f_2_min": []}
    solution_data = representative_variables | representative_objectives
    representative_ideal = {"f_1": 0.1, "f_2": -0.1}
    representative_nadir = {"f_1": 0.9, "f_2": 199.2}

    # The metadata row must exist (and have an id) before entries can point to it.
    metadata = ProblemMetaDataDB(
        problem_id=problem.id,
    )

    session.add(metadata)
    session.commit()
    session.refresh(metadata)

    forest_metadata = ForestProblemMetaData(
        metadata_id=metadata.id,
        map_json="type: string",
        schedule_dict={"type": "dict"},
        years=["type:", "list", "of", "strings"],
        stand_id_field="type: string",
    )

    repr_metadata = RepresentativeNonDominatedSolutions(
        metadata_id=metadata.id,
        name=representative_name,
        description=representative_description,
        solution_data=solution_data,
        ideal=representative_ideal,
        nadir=representative_nadir,
    )

    session.add(forest_metadata)
    session.add(repr_metadata)
    session.commit()
    session.refresh(forest_metadata)
    session.refresh(repr_metadata)

    statement = select(ProblemMetaDataDB).where(ProblemMetaDataDB.problem_id == problem.id)
    from_db_metadata = session.exec(statement).first()

    assert from_db_metadata.id is not None
    assert from_db_metadata.problem_id == problem.id

    # Forest metadata round trip.
    metadata_forest = from_db_metadata.forest_metadata[0]

    assert isinstance(metadata_forest, ForestProblemMetaData)
    assert metadata_forest.map_json == "type: string"
    assert metadata_forest.schedule_dict == {"type": "dict"}
    assert metadata_forest.years == ["type:", "list", "of", "strings"]
    assert metadata_forest.stand_id_field == "type: string"

    # Representative non-dominated solutions round trip.
    metadata_representative = from_db_metadata.representative_nd_metadata[0]

    assert isinstance(metadata_representative, RepresentativeNonDominatedSolutions)
    assert metadata_representative.name == representative_name
    assert metadata_representative.description == representative_description
    assert metadata_representative.solution_data == solution_data
    assert metadata_representative.ideal == representative_ideal
    assert metadata_representative.nadir == representative_nadir

    assert problem.problem_metadata == from_db_metadata
824
+
825
+
826
def test_group(session_and_user: dict[str, Session | list[User]]):
    """Test that a group can be persisted and read back.

    A group owned by the fixture user, with that user as its only member, is
    saved and refreshed. Its id, first member id, and name are then checked.
    """
    db_session: Session = session_and_user["session"]
    owner: User = session_and_user["user"]

    # NOTE(review): problem_id=1 and the id == 1 check assume a fresh database
    # in which the fixture stored problem 1 and no other group exists — confirm.
    new_group = Group(
        name="TestGroup",
        owner_id=owner.id,
        user_ids=[owner.id],
        problem_id=1,
    )

    db_session.add(new_group)
    db_session.commit()
    db_session.refresh(new_group)

    assert new_group.id == 1
    assert new_group.user_ids[0] == owner.id
    assert new_group.name == "TestGroup"
845
+
846
+
847
def test_gnimbus_datas(session_and_user: dict[str, Session | list[User]]):
    """Test that a group iteration holding an optimization preference can be stored.

    A group is created for the fixture user. A root ``GroupIteration`` (no
    parent, no child) is then saved; it carries an empty
    ``OptimizationPreference`` as its info container. After refreshing, the
    container's type and the iteration's problem and group ids are checked.
    """
    db_session: Session = session_and_user["session"]
    owner: User = session_and_user["user"]

    new_group = Group(
        name="TestGroup",
        owner_id=owner.id,
        user_ids=[owner.id],
        problem_id=1,
    )
    db_session.add(new_group)
    db_session.commit()
    db_session.refresh(new_group)

    # Root iteration of the group: nothing before it and nothing after it yet.
    first_iteration = GroupIteration(
        problem_id=1,
        group_id=new_group.id,
        info_container=OptimizationPreference(set_preferences={}),
        notified={},
        parent_id=None,
        parent=None,
        child=None,
    )
    db_session.add(first_iteration)
    db_session.commit()
    db_session.refresh(first_iteration)

    # Exact type check: the container read back should be the same class that
    # was written (presumably info_container accepts several preference types).
    assert type(first_iteration.info_container) is OptimizationPreference
    assert first_iteration.problem_id == 1
    assert first_iteration.group_id == new_group.id
880
+
881
+
882
def test_enautilus_state(session_and_user: dict[str, Session | list[User]]):
    """Test that E-NAUTILUS states can be stored, chained, and read back.

    A synthetic four-objective problem is stored together with a small set of
    representative non-dominated solutions. Objectives f1 and f3 are minimized;
    f2 and f4 are maximized.

    Two ``enautilus_step`` iterations are computed over that data. Each one is
    wrapped in an ``ENautilusState`` and persisted with ``StateDB.create``, and
    the second state points to the first state's ``StateDB`` row as its parent.

    The test then checks the problem and session ids and the parent/child links.
    """
    session = session_and_user["session"]
    user = session_and_user["user"]

    # create interactive session
    isession = InteractiveSessionDB(user_id=user.id)

    session.add(isession)
    session.commit()
    session.refresh(isession)

    # Dummy problem: the objectives define no functions here; objective values
    # are supplied as precomputed representative data below.
    dummy_problem = Problem(
        name="Synthetic-4D",
        description="Unit-test Problem for E-NAUTILUS",
        variables=[Variable(name="x", symbol="x", variable_type=VariableTypeEnum.real)],
        objectives=[
            Objective(name="f1", symbol="f1", maximize=False),
            Objective(name="f2", symbol="f2", maximize=True),
            Objective(name="f3", symbol="f3", maximize=False),
            Objective(name="f4", symbol="f4", maximize=True),
        ],
    )

    x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    f1 = np.array([0.40, 0.60, 0.50, 0.70, 0.45, 0.55, 0.65, 0.48])
    f2 = np.array([4.00, 3.80, 4.10, 3.70, 4.05, 3.90, 3.60, 4.20])
    f3 = np.array([1.00, 1.30, 1.10, 1.40, 1.05, 1.20, 1.35, 1.15])
    f4 = np.array([2.50, 2.30, 2.60, 2.20, 2.55, 2.40, 2.10, 2.65])

    # Nadir = worst value and ideal = best value of each objective, taking the
    # minimize/maximize direction into account.
    nadir = {"f1": np.max(f1), "f2": np.min(f2), "f3": np.max(f3), "f4": np.min(f4)}
    ideal = {"f1": np.min(f1), "f2": np.max(f2), "f3": np.min(f3), "f4": np.max(f4)}

    # The "_min" columns hold each objective in minimization form: maximized
    # objectives are negated.
    non_dom_data = {
        "x": x.tolist(),
        "f1": f1.tolist(),
        "f1_min": f1.tolist(),
        "f2": f2.tolist(),
        "f2_min": (-f2).tolist(),
        "f3": f3.tolist(),
        "f3_min": f3.tolist(),
        "f4": f4.tolist(),
        "f4_min": (-f4).tolist(),
    }

    # add problem to DB and refresh it
    problemdb = ProblemDB.from_problem(dummy_problem, user)

    session.add(problemdb)
    session.commit()
    session.refresh(problemdb)

    metadata = ProblemMetaDataDB(
        problem_id=problemdb.id,
    )

    # add metadata to DB
    session.add(metadata)
    session.commit()
    session.refresh(metadata)

    reprdata = RepresentativeNonDominatedSolutions(
        metadata_id=metadata.id,
        name="Dummy data",
        description="Dummy data for a problem",
        solution_data=non_dom_data,
        ideal=ideal,
        nadir=nadir,
    )

    # add reprdata to DB
    session.add(reprdata)
    session.commit()
    session.refresh(reprdata)

    # First iteration starts from the nadir point, with the whole front reachable.
    selected_point = nadir
    reachable_indices = list(range(len(x)))

    total_iters = 2  # DM first thinks 2 iterations are enough
    n_points = 3  # DM wants to see 3 points at first
    current = 0

    # First iteration
    res = enautilus_step(
        problem=dummy_problem,
        non_dominated_points=non_dom_data,
        current_iteration=current,
        iterations_left=total_iters - current,
        selected_point=selected_point,
        reachable_point_indices=reachable_indices,
        number_of_intermediate_points=n_points,
    )

    enautilus_state = ENautilusState(
        non_dominated_solutions_id=reprdata.id,
        current_iteration=current,
        iterations_left=total_iters - current,
        selected_point=selected_point,
        reachable_point_indices=reachable_indices,
        number_of_intermediate_points=n_points,
        enautilus_results=res,
    )

    state_1 = StateDB.create(
        database_session=session,
        problem_id=problemdb.id,
        session_id=isession.id,
        parent_id=None,
        state=enautilus_state,
    )

    session.add(state_1)
    session.commit()
    session.refresh(state_1)

    # Second iteration: the DM picks the first intermediate point of iteration one.
    res_2 = enautilus_step(
        problem=dummy_problem,
        non_dominated_points=non_dom_data,
        current_iteration=res.current_iteration,
        iterations_left=res.iterations_left,
        selected_point=res.intermediate_points[0],
        reachable_point_indices=res.reachable_point_indices[0],
        number_of_intermediate_points=n_points,
    )

    enautilus_state_2 = ENautilusState(
        non_dominated_solutions_id=reprdata.id,
        current_iteration=res.current_iteration,
        iterations_left=res.iterations_left,
        selected_point=res.intermediate_points[0],
        reachable_point_indices=res.reachable_point_indices[0],
        number_of_intermediate_points=n_points,
        enautilus_results=res_2,
    )

    # The parent must be the StateDB row of the first iteration (state_1.id),
    # not the id of the wrapped ENautilusState model.
    state_2 = StateDB.create(
        database_session=session,
        problem_id=problemdb.id,
        session_id=isession.id,
        parent_id=state_1.id,
        state=enautilus_state_2,
    )

    session.add(state_2)
    session.commit()
    session.refresh(state_2)

    assert state_1.problem_id == problemdb.id
    assert state_2.problem_id == problemdb.id

    assert state_1.session_id == isession.id
    assert state_2.session_id == isession.id

    assert state_1.parent is None
    assert state_2.parent == state_1

    assert len(state_1.children) == 1
    assert state_1.children[0] == state_2
    assert state_2.children == []
1045
+
1046
+
1047
def test_nimbus_models(session_and_user: dict[str, Session | list[User]]):
    """Test that the NIMBUS state models can be stored and read back intact.

    The test walks through three states for the fixture user's first problem:
    an initialization state, a classification state, and a final state. Each
    is persisted through ``StateDB.create``. The state read back must have the
    same type, and its numeric results must be close to the solver output that
    was saved.
    """
    db_session: Session = session_and_user["session"]
    owner: User = session_and_user["user"]
    problem_db = owner.problems[0]
    problem = Problem.from_problemdb(problem_db)

    isession = InteractiveSessionDB(user_id=owner.id)

    db_session.add(isession)
    db_session.commit()
    db_session.refresh(isession)

    def _store(nimbus_state):
        """Wrap *nimbus_state* in a StateDB row, commit it and return the refreshed row."""
        row = StateDB.create(
            database_session=db_session,
            problem_id=problem_db.id,
            session_id=isession.id,
            state=nimbus_state,
        )
        db_session.add(row)
        db_session.commit()
        db_session.refresh(row)
        return row

    # 1. Initialize the NIMBUS problem (NIMBUSInitializationState)
    start_results = generate_starting_point(problem=problem)
    init_row = _store(NIMBUSInitializationState(solver_results=start_results))

    stored_init: NIMBUSInitializationState = init_row.state

    assert type(stored_init) is NIMBUSInitializationState
    assert np.allclose(
        list(stored_init.solver_results.optimal_objectives.values()),
        list(start_results.optimal_objectives.values()),
        rtol=0.001,
    )

    # 2. Solve sub problems (NIMBUSClassificationState)
    aspirations = {"f_1": 0.1, "f_2": 0.9, "f_3": 0.6}

    sub_results = solve_sub_problems(
        problem=problem,
        current_objectives=start_results.optimal_objectives,
        reference_point=aspirations,
        num_desired=4,
    )
    classification_row = _store(
        NIMBUSClassificationState(
            preferences=ReferencePoint(aspiration_levels=aspirations),
            current_objectives=start_results.optimal_objectives,
            previous_preferences=ReferencePoint(aspiration_levels=aspirations),
            solver_results=sub_results,
        )
    )

    stored_classification: NIMBUSClassificationState = classification_row.state

    assert type(stored_classification) is NIMBUSClassificationState
    assert np.allclose(
        list(stored_classification.preferences.aspiration_levels.values()),
        list(aspirations.values()),
        rtol=0.001,
    )
    assert np.allclose(
        list(stored_classification.solver_results[0].optimal_objectives.values()),
        list(sub_results[0].optimal_objectives.values()),
        rtol=0.001,
    )

    # 3. (TODO) Save a found solution (NIMBUSSaveState)
    # 4. Finalize the NIMBUS process (NIMBUSFinalState) with the first sub-problem solution.
    # NOTE(review): solution_origin_state_id receives the id of the wrapped state
    # model (classification_row.state.id), not the StateDB row id — confirm intent.
    chosen = sub_results[0]
    final_row = _store(
        NIMBUSFinalState(
            solution_origin_state_id=classification_row.state.id,
            solution_result_index=0,
            solver_results=chosen,
        )
    )

    stored_final: NIMBUSFinalState = final_row.state
    assert type(stored_final) is NIMBUSFinalState
    assert np.allclose(
        list(stored_final.solver_results.optimal_objectives.values()),
        list(chosen.optimal_objectives.values()),
        rtol=0.01,
    )
    assert np.allclose(
        list(stored_final.solver_results.optimal_variables.values()),
        list(chosen.optimal_variables.values()),
        rtol=0.01,
    )
1144
+
1145
+
1146
def test_nimbus_initialize_w_multidimensional_constraints(session_and_user: dict[str, Session | list[User]]):
    """Test that the NIMBUS initialization model works with multidimensional constraints.

    The problem returned by ``multi_valued_constraint_problem`` is stored. A
    starting point is then generated for it with ``PyomoBonminSolver`` and
    persisted as a ``NIMBUSInitializationState``. The state read back must
    have the same type and, within rtol=0.001, the same optimal objective
    values.
    """
    db_session: Session = session_and_user["session"]
    owner: User = session_and_user["user"]
    problem = multi_valued_constraint_problem()

    problem_db = ProblemDB.from_problem(problem, owner)
    isession = InteractiveSessionDB(user_id=owner.id)

    # Both rows are committed together so their ids are available below.
    db_session.add(isession)
    db_session.add(problem_db)
    db_session.commit()
    db_session.refresh(isession)
    db_session.refresh(problem_db)

    start_results = generate_starting_point(problem=problem, solver=PyomoBonminSolver)

    init_row = StateDB.create(
        database_session=db_session,
        problem_id=problem_db.id,
        session_id=isession.id,
        state=NIMBUSInitializationState(solver_results=start_results),
    )

    db_session.add(init_row)
    db_session.commit()
    db_session.refresh(init_row)

    stored: NIMBUSInitializationState = init_row.state

    assert type(stored) is NIMBUSInitializationState
    stored_objectives = list(stored.solver_results.optimal_objectives.values())
    expected_objectives = list(start_results.optimal_objectives.values())
    assert np.allclose(stored_objectives, expected_objectives, rtol=0.001)