desdeo 1.2__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. desdeo/__init__.py +8 -8
  2. desdeo/adm/ADMAfsar.py +551 -0
  3. desdeo/adm/ADMChen.py +414 -0
  4. desdeo/adm/BaseADM.py +119 -0
  5. desdeo/adm/__init__.py +11 -0
  6. desdeo/api/README.md +73 -0
  7. desdeo/api/__init__.py +15 -0
  8. desdeo/api/app.py +50 -0
  9. desdeo/api/config.py +90 -0
  10. desdeo/api/config.toml +64 -0
  11. desdeo/api/db.py +27 -0
  12. desdeo/api/db_init.py +85 -0
  13. desdeo/api/db_models.py +164 -0
  14. desdeo/api/malaga_db_init.py +27 -0
  15. desdeo/api/models/__init__.py +266 -0
  16. desdeo/api/models/archive.py +23 -0
  17. desdeo/api/models/emo.py +128 -0
  18. desdeo/api/models/enautilus.py +69 -0
  19. desdeo/api/models/gdm/gdm_aggregate.py +139 -0
  20. desdeo/api/models/gdm/gdm_base.py +69 -0
  21. desdeo/api/models/gdm/gdm_score_bands.py +114 -0
  22. desdeo/api/models/gdm/gnimbus.py +138 -0
  23. desdeo/api/models/generic.py +104 -0
  24. desdeo/api/models/generic_states.py +401 -0
  25. desdeo/api/models/nimbus.py +158 -0
  26. desdeo/api/models/preference.py +128 -0
  27. desdeo/api/models/problem.py +717 -0
  28. desdeo/api/models/reference_point_method.py +18 -0
  29. desdeo/api/models/session.py +49 -0
  30. desdeo/api/models/state.py +463 -0
  31. desdeo/api/models/user.py +52 -0
  32. desdeo/api/models/utopia.py +25 -0
  33. desdeo/api/routers/_EMO.backup +309 -0
  34. desdeo/api/routers/_NAUTILUS.py +245 -0
  35. desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
  36. desdeo/api/routers/_NIMBUS.py +765 -0
  37. desdeo/api/routers/__init__.py +5 -0
  38. desdeo/api/routers/emo.py +497 -0
  39. desdeo/api/routers/enautilus.py +237 -0
  40. desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
  41. desdeo/api/routers/gdm/gdm_base.py +420 -0
  42. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
  43. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
  44. desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
  45. desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
  46. desdeo/api/routers/generic.py +233 -0
  47. desdeo/api/routers/nimbus.py +705 -0
  48. desdeo/api/routers/problem.py +307 -0
  49. desdeo/api/routers/reference_point_method.py +93 -0
  50. desdeo/api/routers/session.py +100 -0
  51. desdeo/api/routers/test.py +16 -0
  52. desdeo/api/routers/user_authentication.py +520 -0
  53. desdeo/api/routers/utils.py +187 -0
  54. desdeo/api/routers/utopia.py +230 -0
  55. desdeo/api/schema.py +100 -0
  56. desdeo/api/tests/__init__.py +0 -0
  57. desdeo/api/tests/conftest.py +151 -0
  58. desdeo/api/tests/test_enautilus.py +330 -0
  59. desdeo/api/tests/test_models.py +1179 -0
  60. desdeo/api/tests/test_routes.py +1075 -0
  61. desdeo/api/utils/_database.py +263 -0
  62. desdeo/api/utils/_logger.py +29 -0
  63. desdeo/api/utils/database.py +36 -0
  64. desdeo/api/utils/emo_database.py +40 -0
  65. desdeo/core.py +34 -0
  66. desdeo/emo/__init__.py +159 -0
  67. desdeo/emo/hooks/archivers.py +188 -0
  68. desdeo/emo/methods/EAs.py +541 -0
  69. desdeo/emo/methods/__init__.py +0 -0
  70. desdeo/emo/methods/bases.py +12 -0
  71. desdeo/emo/methods/templates.py +111 -0
  72. desdeo/emo/operators/__init__.py +1 -0
  73. desdeo/emo/operators/crossover.py +1282 -0
  74. desdeo/emo/operators/evaluator.py +114 -0
  75. desdeo/emo/operators/generator.py +459 -0
  76. desdeo/emo/operators/mutation.py +1224 -0
  77. desdeo/emo/operators/scalar_selection.py +202 -0
  78. desdeo/emo/operators/selection.py +1778 -0
  79. desdeo/emo/operators/termination.py +286 -0
  80. desdeo/emo/options/__init__.py +108 -0
  81. desdeo/emo/options/algorithms.py +435 -0
  82. desdeo/emo/options/crossover.py +164 -0
  83. desdeo/emo/options/generator.py +131 -0
  84. desdeo/emo/options/mutation.py +260 -0
  85. desdeo/emo/options/repair.py +61 -0
  86. desdeo/emo/options/scalar_selection.py +66 -0
  87. desdeo/emo/options/selection.py +127 -0
  88. desdeo/emo/options/templates.py +383 -0
  89. desdeo/emo/options/termination.py +143 -0
  90. desdeo/explanations/__init__.py +6 -0
  91. desdeo/explanations/explainer.py +100 -0
  92. desdeo/explanations/utils.py +90 -0
  93. desdeo/gdm/__init__.py +22 -0
  94. desdeo/gdm/gdmtools.py +45 -0
  95. desdeo/gdm/score_bands.py +114 -0
  96. desdeo/gdm/voting_rules.py +50 -0
  97. desdeo/mcdm/__init__.py +41 -0
  98. desdeo/mcdm/enautilus.py +338 -0
  99. desdeo/mcdm/gnimbus.py +484 -0
  100. desdeo/mcdm/nautili.py +345 -0
  101. desdeo/mcdm/nautilus.py +477 -0
  102. desdeo/mcdm/nautilus_navigator.py +656 -0
  103. desdeo/mcdm/nimbus.py +417 -0
  104. desdeo/mcdm/pareto_navigator.py +269 -0
  105. desdeo/mcdm/reference_point_method.py +186 -0
  106. desdeo/problem/__init__.py +83 -0
  107. desdeo/problem/evaluator.py +561 -0
  108. desdeo/problem/external/__init__.py +18 -0
  109. desdeo/problem/external/core.py +356 -0
  110. desdeo/problem/external/pymoo_provider.py +266 -0
  111. desdeo/problem/external/runtime.py +44 -0
  112. desdeo/problem/gurobipy_evaluator.py +562 -0
  113. desdeo/problem/infix_parser.py +341 -0
  114. desdeo/problem/json_parser.py +944 -0
  115. desdeo/problem/pyomo_evaluator.py +487 -0
  116. desdeo/problem/schema.py +1829 -0
  117. desdeo/problem/simulator_evaluator.py +348 -0
  118. desdeo/problem/sympy_evaluator.py +244 -0
  119. desdeo/problem/testproblems/__init__.py +88 -0
  120. desdeo/problem/testproblems/benchmarks_server.py +120 -0
  121. desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
  122. desdeo/problem/testproblems/cake_problem.py +185 -0
  123. desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
  124. desdeo/problem/testproblems/dtlz2_problem.py +102 -0
  125. desdeo/problem/testproblems/forest_problem.py +283 -0
  126. desdeo/problem/testproblems/knapsack_problem.py +163 -0
  127. desdeo/problem/testproblems/mcwb_problem.py +831 -0
  128. desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
  129. desdeo/problem/testproblems/momip_problem.py +172 -0
  130. desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
  131. desdeo/problem/testproblems/nimbus_problem.py +143 -0
  132. desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
  133. desdeo/problem/testproblems/re_problem.py +492 -0
  134. desdeo/problem/testproblems/river_pollution_problems.py +440 -0
  135. desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
  136. desdeo/problem/testproblems/simple_problem.py +351 -0
  137. desdeo/problem/testproblems/simulator_problem.py +92 -0
  138. desdeo/problem/testproblems/single_objective.py +289 -0
  139. desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
  140. desdeo/problem/testproblems/zdt_problem.py +274 -0
  141. desdeo/problem/utils.py +245 -0
  142. desdeo/tools/GenerateReferencePoints.py +181 -0
  143. desdeo/tools/__init__.py +120 -0
  144. desdeo/tools/desc_gen.py +22 -0
  145. desdeo/tools/generics.py +165 -0
  146. desdeo/tools/group_scalarization.py +3090 -0
  147. desdeo/tools/gurobipy_solver_interfaces.py +258 -0
  148. desdeo/tools/indicators_binary.py +117 -0
  149. desdeo/tools/indicators_unary.py +362 -0
  150. desdeo/tools/interaction_schema.py +38 -0
  151. desdeo/tools/intersection.py +54 -0
  152. desdeo/tools/iterative_pareto_representer.py +99 -0
  153. desdeo/tools/message.py +265 -0
  154. desdeo/tools/ng_solver_interfaces.py +199 -0
  155. desdeo/tools/non_dominated_sorting.py +134 -0
  156. desdeo/tools/patterns.py +283 -0
  157. desdeo/tools/proximal_solver.py +99 -0
  158. desdeo/tools/pyomo_solver_interfaces.py +477 -0
  159. desdeo/tools/reference_vectors.py +229 -0
  160. desdeo/tools/scalarization.py +2065 -0
  161. desdeo/tools/scipy_solver_interfaces.py +454 -0
  162. desdeo/tools/score_bands.py +627 -0
  163. desdeo/tools/utils.py +388 -0
  164. desdeo/tools/visualizations.py +67 -0
  165. desdeo/utopia_stuff/__init__.py +0 -0
  166. desdeo/utopia_stuff/data/1.json +15 -0
  167. desdeo/utopia_stuff/data/2.json +13 -0
  168. desdeo/utopia_stuff/data/3.json +15 -0
  169. desdeo/utopia_stuff/data/4.json +17 -0
  170. desdeo/utopia_stuff/data/5.json +15 -0
  171. desdeo/utopia_stuff/from_json.py +40 -0
  172. desdeo/utopia_stuff/reinit_user.py +38 -0
  173. desdeo/utopia_stuff/utopia_db_init.py +212 -0
  174. desdeo/utopia_stuff/utopia_problem.py +403 -0
  175. desdeo/utopia_stuff/utopia_problem_old.py +415 -0
  176. desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
  177. desdeo-2.1.0.dist-info/METADATA +186 -0
  178. desdeo-2.1.0.dist-info/RECORD +180 -0
  179. {desdeo-1.2.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
  180. desdeo-2.1.0.dist-info/licenses/LICENSE +21 -0
  181. desdeo-1.2.dist-info/METADATA +0 -16
  182. desdeo-1.2.dist-info/RECORD +0 -4
@@ -0,0 +1,1075 @@
1
+ """Tests related to routes and routers."""
2
+
3
+ import json
4
+ import time
5
+
6
+ from fastapi import status
7
+ from fastapi.testclient import TestClient
8
+
9
+ from desdeo.api.models import (
10
+ CreateSessionRequest,
11
+ EMOFetchRequest,
12
+ EMOIterateRequest,
13
+ EMOIterateResponse,
14
+ ForestProblemMetaData,
15
+ GDMSCOREBandsHistoryResponse,
16
+ GDMScoreBandsInitializationRequest,
17
+ GDMScoreBandsVoteRequest,
18
+ GenericIntermediateSolutionResponse,
19
+ GroupCreateRequest,
20
+ GroupInfoRequest,
21
+ GroupModifyRequest,
22
+ GroupPublic,
23
+ InteractiveSessionDB,
24
+ IntermediateSolutionRequest,
25
+ NIMBUSClassificationRequest,
26
+ NIMBUSClassificationResponse,
27
+ NIMBUSDeleteSaveRequest,
28
+ NIMBUSDeleteSaveResponse,
29
+ NIMBUSFinalizeRequest,
30
+ NIMBUSFinalizeResponse,
31
+ NIMBUSInitializationRequest,
32
+ NIMBUSIntermediateSolutionResponse,
33
+ NIMBUSSaveRequest,
34
+ NIMBUSSaveResponse,
35
+ ProblemDB,
36
+ ProblemGetRequest,
37
+ ProblemInfo,
38
+ ProblemSelectSolverRequest,
39
+ ReferencePoint,
40
+ RPMSolveRequest,
41
+ SolutionInfo,
42
+ SolverSelectionMetadata,
43
+ User,
44
+ UserPublic,
45
+ )
46
+ from desdeo.api.models.nimbus import NIMBUSInitializationResponse
47
+ from desdeo.api.routers.user_authentication import create_access_token
48
+ from desdeo.emo.options.algorithms import rvea_options
49
+ from desdeo.emo.options.templates import ReferencePointOptions
50
+ from desdeo.gdm.score_bands import SCOREBandsGDMConfig
51
+ from desdeo.problem import Problem
52
+ from desdeo.problem.testproblems import dtlz2, simple_knapsack_vectors
53
+ from desdeo.tools.score_bands import KMeansOptions, SCOREBandsConfig
54
+ from desdeo.tools.utils import available_solvers
55
+
56
+ from .conftest import get_json, login, post_file_multipart, post_json
57
+ from .test_models import compare_models
58
+
59
+
60
def test_user_login(client: TestClient):
    """Check that /login issues an access token for valid credentials and rejects invalid ones."""
    form_headers = {"content-type": "application/x-www-form-urlencoded"}

    def _attempt(password: str):
        # Submit the OAuth2 password-grant form for the "analyst" account.
        return client.post(
            "/login",
            data={"username": "analyst", "password": password, "grant_type": "password"},
            headers=form_headers,
        )

    accepted = _attempt("analyst")
    assert accepted.status_code == 200
    assert "access_token" in accepted.json()

    # A misspelled password must be refused.
    rejected = _attempt("anallyst")
    assert rejected.status_code == 401
79
+
80
+
81
def test_tokens():
    """Two access tokens minted from identical claims must still be distinct."""
    claims = {"id": 1, "sub": "analyst"}
    # Hand each call its own copy of the claims, as the original did with two literals.
    first, second = (create_access_token(dict(claims)) for _ in range(2))
    assert first != second
87
+
88
+
89
def test_refresh(client: TestClient):
    """Refreshing requires the cookie set by /login and then yields a new access token."""
    # Start from a cookie-free client.
    assert len(client.cookies) == 0

    # Attempt a refresh before logging in, i.e. without the refresh cookie.
    no_cookie = client.post("/refresh")

    login_reply = client.post(
        "/login",
        data={"username": "analyst", "password": "analyst", "grant_type": "password"},
        headers={"content-type": "application/x-www-form-urlencoded"},
    )

    assert no_cookie.status_code == 401
    assert login_reply.status_code == 200

    login_body = login_reply.json()
    assert "access_token" in login_body
    # Logging in leaves exactly one cookie behind: the refresh token.
    assert len(client.cookies) == 1
    assert "refresh_token" in client.cookies

    refreshed_body = client.post("/refresh").json()
    assert "access_token" in refreshed_body

    assert login_body["access_token"] != refreshed_body["access_token"]
115
+
116
+
117
def test_get_problem(client: TestClient):
    """Test fetching specific problems based on their id.

    Problem 1 is expected to be DTLZ2 without metadata. Problem 2 is expected
    to carry forest metadata. Both are presumably seeded by the conftest fixtures.
    """
    access_token = login(client)

    def _fetch(problem_id: int) -> ProblemInfo:
        # Use the same absolute path for every request (the original mixed "/problem/get" and "problem/get").
        response = post_json(
            client, "/problem/get", ProblemGetRequest(problem_id=problem_id).model_dump(), access_token
        )
        assert response.status_code == 200
        return ProblemInfo.model_validate(response.json())

    info = _fetch(1)
    assert info.id == 1
    assert info.name == "dtlz2"
    assert info.problem_metadata is None

    info = _fetch(2)
    assert info.id == 2
    assert info.name == "The river pollution problem"
    # Fail with a clear assertion, not an AttributeError/TypeError, if the metadata is missing.
    assert info.problem_metadata is not None
    assert info.problem_metadata.forest_metadata
    assert isinstance(info.problem_metadata.forest_metadata[0], ForestProblemMetaData)
140
+
141
+
142
def test_add_problem(client: TestClient):
    """Test that adding a problem to the database works.

    After the knapsack problem is added, ``/problem/all_info`` is expected to
    list four problems (presumably three seeded by the fixtures plus the new one).
    """
    access_token = login(client)

    problem = simple_knapsack_vectors()

    response = post_json(client, "/problem/add", problem.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK

    problem_info: ProblemInfo = ProblemInfo.model_validate(response.json())
    assert problem_info.name == "Simple two-objective Knapsack problem"

    # Absolute path, consistent with every other request in this module.
    response = get_json(client, "/problem/all_info", access_token)
    assert response.status_code == status.HTTP_200_OK

    problems = response.json()
    assert len(problems) == 4
163
+
164
+
165
def test_add_problem_json(client: TestClient, session_and_user: dict):
    """Upload a problem as a JSON file and check that the stored copy matches the original."""
    db_session = session_and_user["session"]
    access_token = login(client)

    original = dtlz2(5, 3)
    file_bytes = json.dumps(original.model_dump()).encode("utf-8")

    upload = post_file_multipart(client, "/problem/add_json", file_bytes, access_token)
    assert upload.status_code == 200

    # The uploaded problem is expected to receive database id 4.
    stored = Problem.from_problemdb(db_session.get(ProblemDB, 4))
    assert compare_models(original, stored)
181
+
182
+
183
def test_new_session(client: TestClient, session_and_user: dict):
    """Creating a session via /session/new makes it the user's active session and stores its info."""
    user: User = session_and_user["user"]
    db_session = session_and_user["session"]

    # No session is active before the request.
    assert user.active_session_id is None

    token = login(client)
    payload = CreateSessionRequest(info="My session").model_dump()
    reply = post_json(client, "/session/new", payload, token)
    assert reply.status_code == status.HTTP_200_OK

    assert user.active_session_id == 1
    created = db_session.get(InteractiveSessionDB, 1)
    assert created.info == "My session"
202
+
203
+
204
def test_get_session(client: TestClient, session_and_user: dict):
    """GET /session/get/{id} returns 404 before any session exists and the right session afterwards."""
    user: User = session_and_user["user"]
    token = login(client)
    auth = {"Authorization": f"Bearer {token}"}

    def fetch(session_id: int):
        return client.get(f"/session/get/{session_id}", headers=auth)

    # Nothing has been created yet.
    assert fetch(1).status_code == status.HTTP_404_NOT_FOUND

    # Each newly created session becomes the active one.
    for expected_id, info in enumerate(("Session 1", "Session 2"), start=1):
        created = post_json(client, "/session/new", CreateSessionRequest(info=info).model_dump(), token)
        assert created.status_code == status.HTTP_200_OK
        assert user.active_session_id == expected_id

    for session_id in (1, 2):
        fetched = fetch(session_id)
        assert fetched.status_code == status.HTTP_200_OK
        assert fetched.json()["id"] == session_id
244
+
245
+
246
def test_get_all_sessions_success(client: TestClient, session_and_user: dict):
    """Test getting all sessions when sessions exist."""
    access_token = login(client)

    # Create 2 test sessions. Check each creation so a setup failure is reported
    # here rather than as a confusing length/404 mismatch below.
    for info in ("S1", "S2"):
        created = post_json(client, "/session/new", {"info": info}, access_token)
        assert created.status_code == status.HTTP_200_OK

    response = client.get(
        "/session/get_all",
        headers={"Authorization": f"Bearer {access_token}"},
    )

    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    assert isinstance(data, list)
    assert len(data) == 2
263
+
264
+
265
def test_get_all_sessions_not_found(client: TestClient, session_and_user: dict):
    """A user who owns no sessions gets 404 from /session/get_all."""
    token = login(client)
    reply = client.get("/session/get_all", headers={"Authorization": f"Bearer {token}"})
    assert reply.status_code == status.HTTP_404_NOT_FOUND
275
+
276
+
277
def test_delete_session_success(client: TestClient, session_and_user: dict):
    """Test deleting an existing session.

    The session created here is expected to get id 1 because the user starts
    without any sessions.
    """
    access_token = login(client)
    auth_headers = {"Authorization": f"Bearer {access_token}"}

    # Create the session. Fail here if setup did not work, instead of at the delete below.
    created = post_json(client, "/session/new", {"info": "To delete"}, access_token)
    assert created.status_code == status.HTTP_200_OK

    response = client.delete("/session/1", headers=auth_headers)
    assert response.status_code == status.HTTP_204_NO_CONTENT

    # verify it's gone
    response = client.get("/session/get/1", headers=auth_headers)
    assert response.status_code == status.HTTP_404_NOT_FOUND
297
+
298
+
299
def test_delete_session_not_found(client: TestClient, session_and_user: dict):
    """Deleting a session id that does not exist yields 404."""
    token = login(client)
    missing_id = 999
    reply = client.delete(f"/session/{missing_id}", headers={"Authorization": f"Bearer {token}"})
    assert reply.status_code == status.HTTP_404_NOT_FOUND
309
+
310
+
311
def test_rpm_solve(client: TestClient):
    """Solving problem 1 with the reference point method responds with HTTP 200."""
    token = login(client)
    reference_point = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.3, "f_3": 0.4})
    payload = RPMSolveRequest(problem_id=1, preference=reference_point).model_dump()
    reply = post_json(client, "/method/rpm/solve", payload, token)
    assert reply.status_code == status.HTTP_200_OK
322
+
323
+
324
def test_nimbus_solve(client: TestClient):
    """Test that using the NIMBUS method works as expected.

    Runs three NIMBUS classification rounds with saves in between. It checks
    that duplicate saves are filtered and that ``all_solutions`` does not
    double-count solutions that are both saved and generated.
    """
    access_token = login(client)

    def _solve(request: NIMBUSClassificationRequest) -> NIMBUSClassificationResponse:
        response = post_json(client, "/method/nimbus/solve", request.model_dump(), access_token)
        assert response.status_code == status.HTTP_200_OK
        return NIMBUSClassificationResponse.model_validate(json.loads(response.content.decode("utf-8")))

    def _save(request: NIMBUSSaveRequest) -> NIMBUSSaveResponse:
        response = post_json(client, "/method/nimbus/save", request.model_dump(), access_token)
        assert response.status_code == status.HTTP_200_OK
        return NIMBUSSaveResponse.model_validate(json.loads(response.content.decode("utf-8")))

    # Round 1
    preference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.6, "f_3": 0.4})
    first = _solve(
        NIMBUSClassificationRequest(
            problem_id=1, preference=preference, current_objectives={"f_1": 0.6, "f_2": 0.4, "f_3": 0.5}, num_desired=3
        )
    )
    assert first.previous_preference == preference
    assert len(first.all_solutions) == 3

    # Save solutions 0 and 2 of round 1. The solutions must reference round 1's
    # state (not a hard-coded id). Index 2 is deliberately saved twice.
    first_save = _save(
        NIMBUSSaveRequest(
            problem_id=1,
            parent_state_id=first.state_id,
            solution_info=[
                SolutionInfo(state_id=first.state_id, solution_index=0, name="Test solution 1"),
                SolutionInfo(state_id=first.state_id, solution_index=2, name="Test solution 3"),
                SolutionInfo(state_id=first.state_id, solution_index=2, name="Test solution 34"),  # saved twice!
            ],
        )
    )
    assert first_save.state_id is not None

    # Round 2
    preference = ReferencePoint(aspiration_levels={"f_1": 0.1, "f_2": 0.1, "f_3": 0.9})
    second = _solve(
        NIMBUSClassificationRequest(
            problem_id=1,
            preference=preference,
            current_objectives=first.current_solutions[0].objective_values,
            num_desired=3,
            parent_state_id=first_save.state_id,
        )
    )
    assert second.previous_preference == preference
    # We saved the same solution twice, so the filtering should remove one of those.
    assert len(second.saved_solutions) == 2
    assert len(second.all_solutions) == 6  # should not count saved solutions twice

    # Save some more solutions!
    second_save = _save(
        NIMBUSSaveRequest(
            problem_id=1,
            parent_state_id=second.state_id,
            solution_info=[SolutionInfo(state_id=second.state_id, solution_index=1, name="Test solution 2")],
        )
    )
    assert second_save.state_id is not None

    # Round 3: same preference as round 1. Presumably STOM and ASF give the same
    # solutions here, which should be reflected in the number of all solutions.
    preference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.6, "f_3": 0.4})
    third = _solve(
        NIMBUSClassificationRequest(
            problem_id=1,
            preference=preference,
            current_objectives=second.current_solutions[0].objective_values,
            num_desired=3,
            parent_state_id=second_save.state_id,
        )
    )
    assert third.previous_preference == preference
    assert len(third.saved_solutions) == 3
    assert len(third.all_solutions) == 7
410
+
411
+
412
def test_intermediate_solve(client: TestClient):
    """Test that solving intermediate solutions works as expected.

    Two NIMBUS solutions are generated first. They are then used as the
    reference solutions for both the generic and the NIMBUS-specific
    intermediate endpoints.
    """
    access_token = login(client)

    preference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.6, "f_3": 0.4})

    request = NIMBUSClassificationRequest(
        problem_id=1, preference=preference, current_objectives={"f_1": 0.6, "f_2": 0.4, "f_3": 0.5}, num_desired=2
    )

    response = post_json(client, "/method/nimbus/solve", request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK
    solve_result: NIMBUSClassificationResponse = NIMBUSClassificationResponse.model_validate(
        json.loads(response.content.decode("utf-8"))
    )
    assert len(solve_result.all_solutions) == 2

    # Reference the two solutions produced above (nothing is saved here).
    solution_1 = SolutionInfo(state_id=solve_result.state_id, solution_index=0)
    solution_2 = SolutionInfo(state_id=solve_result.state_id, solution_index=1, name="named solution")

    # Create request for intermediate solutions using solutions created with nimbus solve
    generic_request = IntermediateSolutionRequest(
        problem_id=1,
        context="test",
        reference_solution_1=solution_1,
        reference_solution_2=solution_2,
        num_desired=3,
    )

    # Test the generic intermediate endpoint; schema validation is the check here.
    generic_response = post_json(client, "/method/generic/intermediate", generic_request.model_dump(), access_token)
    assert generic_response.status_code == status.HTTP_200_OK
    GenericIntermediateSolutionResponse.model_validate(json.loads(generic_response.content.decode("utf-8")))

    # Test the NIMBUS-specific intermediate endpoint
    nimbus_request = IntermediateSolutionRequest(
        problem_id=1,
        context="nimbus",
        reference_solution_1=solution_1,
        reference_solution_2=solution_2,
        num_desired=2,
    )

    nimbus_response = post_json(client, "/method/nimbus/intermediate", nimbus_request.model_dump(), access_token)
    assert nimbus_response.status_code == status.HTTP_200_OK
    nimbus_result = NIMBUSIntermediateSolutionResponse.model_validate(
        json.loads(nimbus_response.content.decode("utf-8"))
    )

    # Verify the NIMBUS response contains expected fields
    assert nimbus_result.state_id is not None
    assert len(nimbus_result.current_solutions) == 2
    assert len(nimbus_result.all_solutions) == 7
468
+
469
+
470
def test_nimbus_initialize(client: TestClient):
    """NIMBUS initialization works without a solver, from a SolutionInfo and from a reference point."""
    token = login(client)

    def initialize(request: NIMBUSInitializationRequest) -> NIMBUSInitializationResponse:
        reply = post_json(client, "/method/nimbus/initialize", request.model_dump(), token)
        assert reply.status_code == status.HTTP_200_OK
        return NIMBUSInitializationResponse.model_validate(json.loads(reply.content))

    # Each case: (request, expected state id, expected number of all_solutions).
    cases = [
        # no starting point and no explicit solver
        (NIMBUSInitializationRequest(problem_id=1, solver=None), 1, 1),
        # starting point given as solution info; still one solution because the new one is a duplicate
        (NIMBUSInitializationRequest(problem_id=1, starting_point=SolutionInfo(state_id=1, solution_index=0)), 2, 1),
        # starting point given as a reference point; this should add a new solution
        (
            NIMBUSInitializationRequest(
                problem_id=1, starting_point=ReferencePoint(aspiration_levels={"f_1": 0.2, "f_2": 0.8, "f_3": 0.4})
            ),
            3,
            2,
        ),
    ]

    for request, expected_state_id, expected_all in cases:
        outcome = initialize(request)
        assert outcome.state_id == expected_state_id
        assert len(outcome.current_solutions) == 1
        assert len(outcome.saved_solutions) == 0
        assert len(outcome.all_solutions) == expected_all
516
+
517
+
518
def test_nimbus_finalize(client: TestClient):
    """Test for seeing if NIMBUS finalization works.

    After one initialization and one classification round, a solution from the
    classification round is finalized. Two calls must then return a finalize
    response describing that solution:
    - the finalize endpoint itself;
    - a later ``get-or-initialize`` call, whose latest state is then the finalize state.
    """
    access_token = login(client)

    def _get_or_initialize():
        response = post_json(
            client, "/method/nimbus/get-or-initialize", NIMBUSInitializationRequest(problem_id=1).model_dump(), access_token
        )
        assert response.status_code == status.HTTP_200_OK
        return response

    def _check_finalized(response, objectives, variables) -> None:
        # Shared checks for every response that should describe the finalized solution.
        assert response.status_code == status.HTTP_200_OK
        finalized = NIMBUSFinalizeResponse.model_validate(json.loads(response.content.decode("utf-8")))
        assert finalized.response_type == "nimbus.finalize"
        assert finalized.final_solution.objective_values == objectives
        assert finalized.final_solution.variable_values == variables
        assert finalized.final_solution.state_id != finalized.state_id
        assert finalized.all_solutions is not None

    # create some previous iterations
    init_response = NIMBUSInitializationResponse.model_validate(json.loads(_get_or_initialize().content))
    assert init_response.state_id == 1
    assert len(init_response.current_solutions) == 1
    assert len(init_response.saved_solutions) == 0
    assert len(init_response.all_solutions) == 1

    preference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.6, "f_3": 0.4})
    request = NIMBUSClassificationRequest(
        problem_id=1, preference=preference, current_objectives={"f_1": 0.6, "f_2": 0.4, "f_3": 0.5}, num_desired=3
    )
    response = post_json(client, "/method/nimbus/solve", request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK
    classification = NIMBUSClassificationResponse.model_validate(json.loads(response.content.decode("utf-8")))
    assert classification.previous_preference == preference
    assert len(classification.current_solutions) == 3

    solution_index = 2
    chosen = classification.current_solutions[solution_index]
    optim_obj = chosen.objective_values
    optim_var = chosen.variable_values

    # get-or-initialize now returns the latest (classification) state.
    latest = NIMBUSClassificationResponse.model_validate(json.loads(_get_or_initialize().content))
    assert latest.state_id == 2
    assert len(latest.current_solutions) == 3
    assert len(latest.saved_solutions) == 0
    assert len(latest.all_solutions) == 4

    finalize_request = NIMBUSFinalizeRequest(
        problem_id=1,
        solution_info=SolutionInfo(state_id=classification.state_id, solution_index=solution_index),
    )

    # Finalize the process
    _check_finalized(
        post_json(client, "/method/nimbus/finalize", finalize_request.model_dump(), access_token),
        optim_obj,
        optim_var,
    )

    # The last item in the pipe is a finalize state, so we should be getting a finalize response.
    _check_finalized(_get_or_initialize(), optim_obj, optim_var)
588
+
589
+
590
def test_nimbus_save_and_delete_save(client: TestClient):
    """Test that NIMBUS saving and save deletion works.

    Flow: initialize -> solve -> save a solution -> solve again (the save must be
    visible) -> delete the save -> solve again (no saved solutions may remain).
    The state ids are checked after every step.
    """
    access_token = login(client)

    def parse(response, model):
        """Assert that ``response`` succeeded, then validate its JSON body as ``model``."""
        # Checking the status first makes a failing endpoint report its own error
        # body instead of an unrelated pydantic validation error.
        assert response.status_code < status.HTTP_300_MULTIPLE_CHOICES, response.text
        return model.model_validate(json.loads(response.content))

    # 1. Initialize
    init_request = NIMBUSInitializationRequest(problem_id=1)
    response = post_json(client, "/method/nimbus/initialize", init_request.model_dump(), access_token)
    init_result: NIMBUSInitializationResponse = parse(response, NIMBUSInitializationResponse)
    assert init_result.state_id == 1

    # 2. Iterate
    first_solve_request = NIMBUSClassificationRequest(
        problem_id=1,
        preference=ReferencePoint(
            aspiration_levels={
                "f_1": 0.1,
                "f_2": 0.8,
                "f_3": 0.5,
            }
        ),
        current_objectives=init_result.current_solutions[0].objective_values,
        parent_state_id=1,
        num_desired=3,
    )
    response = post_json(client, "/method/nimbus/solve", first_solve_request.model_dump(), access_token)
    solve_result: NIMBUSClassificationResponse = parse(response, NIMBUSClassificationResponse)
    assert solve_result.state_id == 2

    # 3. Save solution 1 of state 2
    save_request = NIMBUSSaveRequest(
        problem_id=1, parent_state_id=2, solution_info=[SolutionInfo(state_id=2, solution_index=1)]
    )
    response = post_json(client, "/method/nimbus/save", save_request.model_dump(), access_token)
    save_result: NIMBUSSaveResponse = parse(response, NIMBUSSaveResponse)
    assert save_result.state_id == 3

    # Assert that the solution is now reported as saved
    second_solve_request = NIMBUSClassificationRequest(
        problem_id=1,
        preference=ReferencePoint(
            aspiration_levels={
                "f_1": 0.9,
                "f_2": 0.1,
                "f_3": 0.5,
            }
        ),
        current_objectives=solve_result.current_solutions[0].objective_values,
        num_desired=1,
        parent_state_id=3,
    )
    response = post_json(client, "/method/nimbus/solve", second_solve_request.model_dump(), access_token)
    solve_result = parse(response, NIMBUSClassificationResponse)
    assert solve_result.state_id == 4
    assert len(solve_result.saved_solutions) > 0

    # 4. Delete the save
    delete_request = NIMBUSDeleteSaveRequest(state_id=2, solution_index=1)
    response = post_json(client, "/method/nimbus/delete_save", delete_request.model_dump(), access_token)
    delete_save_result: NIMBUSDeleteSaveResponse = parse(response, NIMBUSDeleteSaveResponse)

    # Only checks that a response body was parsed; the real check is the solve below.
    assert delete_save_result

    # Assert that the deleted solution is no longer reported as saved
    third_solve_request = NIMBUSClassificationRequest(
        problem_id=1,
        preference=ReferencePoint(
            aspiration_levels={
                "f_1": 0.1,
                "f_2": 0.9,
                "f_3": 0.4,
            }
        ),
        current_objectives=solve_result.current_solutions[0].objective_values,
        num_desired=1,
        parent_state_id=4,
    )
    response = post_json(client, "/method/nimbus/solve", third_solve_request.model_dump(), access_token)
    solve_result = parse(response, NIMBUSClassificationResponse)
    assert solve_result.state_id == 5
    assert len(solve_result.saved_solutions) == 0
+
681
+
682
def test_add_new_dm(client: TestClient):
    """Check that a decision maker account can be registered exactly once."""
    dm_form = {"username": "new_dm", "password": "new_dm", "grant_type": "password"}
    form_headers = {"content-type": "application/x-www-form-urlencoded"}

    # First registration of the username creates the user ...
    first_attempt = client.post("/add_new_dm", data=dm_form, headers=form_headers)
    assert first_attempt.status_code == status.HTTP_201_CREATED

    # ... and registering the same username again is refused as a conflict.
    second_attempt = client.post("/add_new_dm", data=dm_form, headers=form_headers)
    assert second_attempt.status_code == status.HTTP_409_CONFLICT
+
700
+
701
def test_add_new_analyst(client: TestClient):
    """Check who may create analyst accounts, and that duplicate usernames are refused."""
    analyst_form = {"username": "new_analyst", "password": "new_analyst", "grant_type": "password"}
    form_type = {"content-type": "application/x-www-form-urlencoded"}

    def create_analyst(token=None):
        """POST the analyst form, optionally authenticated with a bearer ``token``."""
        headers = dict(form_type) if token is None else {"Authorization": f"Bearer {token}", **form_type}
        return client.post("/add_new_analyst", data=analyst_form, headers=headers)

    # Anonymous callers are rejected outright.
    assert create_analyst().status_code == status.HTTP_401_UNAUTHORIZED

    # Register a decision maker and try again with that account.
    dm_signup = client.post(
        "/add_new_dm",
        data={"username": "new_dm", "password": "new_dm", "grant_type": "password"},
        headers=form_type,
    )
    assert dm_signup.status_code == status.HTTP_201_CREATED

    dm_access_token = login(client, username="new_dm", password="new_dm")  # noqa: S106
    # A decision maker is not allowed to create analysts either.
    assert create_analyst(dm_access_token).status_code == status.HTTP_401_UNAUTHORIZED

    # The default login is an analyst, which may create new analysts ...
    analyst_access_token = login(client)
    assert create_analyst(analyst_access_token).status_code == status.HTTP_201_CREATED

    # ... but not a second one with an already used username.
    assert create_analyst(analyst_access_token).status_code == status.HTTP_409_CONFLICT
+
759
+
760
def test_login_logout(client: TestClient):
    """Check that logging out prevents further access token refreshes."""
    # Logging in sets the refresh token cookie on the test client.
    login(client=client, username="analyst", password="analyst")  # noqa: S106

    # While the cookie is present, refreshing the access token succeeds.
    assert client.post("/refresh").status_code == status.HTTP_200_OK

    # Logging out removes the refresh token cookie.
    assert client.post("/logout").status_code == status.HTTP_200_OK

    # With the cookie gone, the refresh request must be refused.
    refused = client.post("/refresh", headers={"content-type": "application/x-www-form-urlencoded"})
    assert refused.status_code == status.HTTP_401_UNAUTHORIZED
+
784
+
785
def test_group_operations(client: TestClient):
    """Test group creation, membership changes and deletion through the GDM endpoints.

    Websocket behaviour is not exercised here yet (see the TODO below).
    """
    create_endpoint = "/gdm/create_group"
    delete_endpoint = "/gdm/delete_group"
    add_user_endpoint = "/gdm/add_to_group"
    remove_user_endpoint = "/gdm/remove_from_group"
    group_info_endpoint = "/gdm/get_group_info"

    # login to analyst
    access_token = login(client=client, username="analyst", password="analyst")  # noqa: S106

    def get_info(gid: int):
        """Request the group info of group ``gid`` as the analyst; return the raw response."""
        return post_json(
            client=client,
            endpoint=group_info_endpoint,
            json=GroupInfoRequest(
                group_id=gid,
            ).model_dump(),
            access_token=access_token,
        )

    def get_user(token: str) -> UserPublic:
        """Fetch ``/user_info`` for ``token``, assert success and parse it as ``UserPublic``."""
        response = client.get(
            "/user_info",
            headers={
                "Authorization": f"Bearer {token}",
                "content-type": "application/x-www-form-urlencoded",
            },
        )
        # Fail with the server's message rather than a pydantic validation error.
        assert response.status_code == status.HTTP_200_OK, response.text
        return UserPublic.model_validate(json.loads(response.content.decode("utf-8")))

    # try to create group with no problem
    response = post_json(
        client=client,
        endpoint=create_endpoint,
        json=GroupCreateRequest(group_name="testGroup", problem_id=10).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_404_NOT_FOUND

    # Create group properly
    response = post_json(
        client=client,
        endpoint=create_endpoint,
        json=GroupCreateRequest(group_name="testGroup", problem_id=2).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_201_CREATED

    # Add a user to database
    response = client.post(
        "/add_new_dm",
        data={"username": "new_dm", "password": "new_dm", "grant_type": "password"},
        headers={"content-type": "application/x-www-form-urlencoded"},
    )
    assert response.status_code == status.HTTP_201_CREATED

    # Add user to a group
    response = post_json(
        client=client,
        endpoint=add_user_endpoint,
        json=GroupModifyRequest(group_id=1, user_id=2).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_200_OK
    response = get_info(1)
    assert response.status_code == status.HTTP_200_OK
    group: GroupPublic = GroupPublic.model_validate(json.loads(response.content.decode("utf-8")))
    assert 2 in group.user_ids
    # The creating analyst (user 1) is not listed among the group's user_ids ...
    assert 1 not in group.user_ids

    # ... but the group still appears in the analyst's own group_ids.
    assert 1 in get_user(access_token).group_ids

    dm_access_token = login(client, "new_dm", "new_dm")  # noqa: S106

    assert 1 in get_user(dm_access_token).group_ids

    # TODO: websocket testing and result fetching?

    # Remove user from a group
    response = post_json(
        client=client,
        endpoint=remove_user_endpoint,
        json=GroupModifyRequest(group_id=1, user_id=2).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_200_OK
    response = get_info(1)
    assert response.status_code == status.HTTP_200_OK
    group = GroupPublic.model_validate(json.loads(response.content.decode("utf-8")))
    assert 2 not in group.user_ids

    assert 1 not in get_user(dm_access_token).group_ids
    assert 1 in get_user(access_token).group_ids

    # Delete the group
    response = post_json(
        client=client,
        endpoint=delete_endpoint,
        json=GroupInfoRequest(
            group_id=1,
        ).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_200_OK
    response = get_info(1)
    assert response.status_code == status.HTTP_404_NOT_FOUND

    assert 1 not in get_user(access_token).group_ids
+
905
+
906
def test_preferred_solver(client: TestClient):
    """Check that a preferred solver can be assigned to a problem and read back as metadata."""
    access_token = login(client)

    def assign(solver_name):
        """Ask the API to use ``solver_name`` for problem 1 and return the raw response."""
        assignment = ProblemSelectSolverRequest(problem_id=1, solver_string_representation=solver_name)
        return post_json(client, "/problem/assign_solver", assignment.model_dump(), access_token)

    # Unknown solver names are rejected, known ones are accepted.
    assert assign("THIS SOLVER DOESN'T EXIST").status_code == 404
    assert assign("pyomo_cbc").status_code == 200

    metadata_query = {"problem_id": 1, "metadata_type": "solver_selection_metadata"}
    metadata_response = post_json(client, "/problem/get_metadata", metadata_query, access_token)
    assert metadata_response.status_code == 200

    selection = SolverSelectionMetadata.model_validate(metadata_response.json()[0])
    assert selection.metadata_type == "solver_selection_metadata"
    assert selection.solver_string_representation == "pyomo_cbc"

    # Test that the solver is in use.
    # NOTE(review): this section only prints; it asserts nothing, so the test passes
    # whether or not the assigned solver is actually used.
    try:
        init_request = NIMBUSInitializationRequest(problem_id=1)
        init_response = post_json(client, "/method/nimbus/initialize", init_request.model_dump(), access_token)
        selection = NIMBUSInitializationResponse.model_validate(init_response.json())
    except Exception as err:
        print(err)
        print("^ This outcome is expected since pyomo_cbc doesn't support nonlinear problems.")
        print(" As that solver is what we set it to be in the start, we can verify that they actually get used.")
+
937
+
938
def test_get_available_solvers(client: TestClient):
    """Check that the solver listing endpoint reports exactly the available solvers."""
    listing = client.get("/problem/assign/solver")
    assert listing.status_code == 200

    solver_names = listing.json()
    # The body is a JSON list ...
    assert isinstance(solver_names, list)
    # ... whose entries match the keys of ``available_solvers`` (order ignored).
    assert set(solver_names) == set(available_solvers.keys())
+
950
+
951
def test_emo_solve_with_reference_point(client: TestClient):
    """Test that using EMO with reference point works as expected.

    Currently disabled: the function returns immediately, so none of the code
    after the ``return`` runs.
    """
    # TODO: This test fails because of websocket issues. Fix those and re-enable the test.
    return
    access_token = login(client)
    request = EMOIterateRequest(
        problem_id=1,
        template_options=[rvea_options.template],
        preference_options=ReferencePointOptions(preference={"f_1": 0.5, "f_2": 0.3, "f_3": 0.4}, method="Hakanen"),
    )

    response = post_json(client, "/method/emo/iterate", request.model_dump(), access_token)

    assert response.status_code == status.HTTP_200_OK

    # Validate the response structure
    emo_response = EMOIterateResponse.model_validate(response.json())
    assert emo_response.client_id is not None
    state_id = emo_response.state_id

    initial_time = time.time()
    with client.websocket_connect(f"/method/emo/ws/{emo_response.client_id}") as websocket:
        # NOTE(review): receive_json() blocks, so the 10 s limit is only checked
        # between messages; a silent socket would hang here rather than time out.
        while time.time() - initial_time < 10:
            message = websocket.receive_json()
            if message.get("message") == f"Finished {emo_response.method_ids[0]}":
                break
    # Fetch the state to verify it worked
    fetch_request = EMOFetchRequest(problem_id=1, parent_state_id=state_id)
    response = post_json(client, "/method/emo/fetch", fetch_request.model_dump(), access_token)
    # The fetch result was previously never checked; at least require success.
    assert response.status_code == status.HTTP_200_OK
+
981
+
982
def test_get_problem_metadata(client: TestClient):
    """Check forest metadata lookups for a bare problem, a forest problem, and a missing problem."""
    access_token = login(client=client)

    def fetch_forest_metadata(problem_id: int):
        """Request forest problem metadata for ``problem_id`` and return the raw response."""
        payload = {"problem_id": problem_id, "metadata_type": "forest_problem_metadata"}
        return post_json(client=client, endpoint="/problem/get_metadata", json=payload, access_token=access_token)

    # Problem 1 has no forest metadata, so the result list is empty.
    bare = fetch_forest_metadata(1)
    assert bare.status_code == 200
    assert bare.json() == []

    # Problem 2 carries forest metadata.
    forest = fetch_forest_metadata(2)
    assert forest.status_code == 200
    first_entry = forest.json()[0]
    assert first_entry["metadata_type"] == "forest_problem_metadata"
    assert first_entry["schedule_dict"] == {"type": "dict"}

    # Problem 4 does not exist at all.
    assert fetch_forest_metadata(4).status_code == 404
+
1004
+
1005
+ def test_gdm_score_bands(client: TestClient):
1006
+ """Test score bands endpoints."""
1007
+ access_token = login(client=client)
1008
+
1009
+ # create group
1010
+ req = GroupCreateRequest(
1011
+ group_name="group",
1012
+ problem_id=3, # The discrete representation problem
1013
+ ).model_dump()
1014
+ response = post_json(client=client, endpoint="/gdm/create_group", json=req, access_token=access_token)
1015
+ assert response.status_code == 201
1016
+
1017
+ # Add a dm to the group
1018
+ # Create a new user to the database
1019
+ response = client.post(
1020
+ "/add_new_dm",
1021
+ data={"username": "dm", "password": "dm", "grant_type": "password"},
1022
+ headers={"content-type": "application/x-www-form-urlencoded"},
1023
+ )
1024
+ assert response.status_code == 201
1025
+
1026
+ req = GroupModifyRequest(group_id=1, user_id=2).model_dump()
1027
+ response = post_json(client=client, endpoint="/gdm/add_to_group", json=req, access_token=access_token)
1028
+ assert response.status_code == 200
1029
+
1030
+ access_token = login(client=client, username="dm", password="dm")
1031
+
1032
+ # Now we have a group, so let's get on with making stuff with gdm score bands.
1033
+ req = GDMScoreBandsInitializationRequest(
1034
+ group_id=1,
1035
+ score_bands_config=SCOREBandsGDMConfig(
1036
+ score_bands_config=SCOREBandsConfig(clustering_algorithm=KMeansOptions(n_clusters=5)), from_iteration=None
1037
+ ),
1038
+ ).model_dump()
1039
+ response = post_json(
1040
+ client=client, endpoint="/gdm-score-bands/get-or-initialize", json=req, access_token=access_token
1041
+ )
1042
+ assert response.status_code == 200
1043
+ response_innards = GDMSCOREBandsHistoryResponse.model_validate(response.json())
1044
+ cluster_size_1 = len(response_innards.history[-1].result.clusters)
1045
+
1046
+ # VOTE AND CONFIRM
1047
+ req = GDMScoreBandsVoteRequest(
1048
+ group_id=1,
1049
+ vote=4,
1050
+ ).model_dump()
1051
+ response = post_json(client=client, endpoint="/gdm-score-bands/vote", json=req, access_token=access_token)
1052
+ assert response.status_code == 200
1053
+ req = GroupInfoRequest(group_id=1).model_dump()
1054
+ response = post_json(client=client, endpoint="/gdm-score-bands/confirm", json=req, access_token=access_token)
1055
+ assert response.status_code == 200
1056
+
1057
+ req = GDMScoreBandsInitializationRequest(
1058
+ group_id=1,
1059
+ score_bands_config=SCOREBandsGDMConfig(
1060
+ score_bands_config=SCOREBandsConfig(clustering_algorithm=KMeansOptions(n_clusters=5)),
1061
+ from_iteration=response_innards.history[-1].latest_iteration,
1062
+ ),
1063
+ ).model_dump()
1064
+ response = post_json(
1065
+ client=client, endpoint="/gdm-score-bands/get-or-initialize", json=req, access_token=access_token
1066
+ )
1067
+ assert response.status_code == 200
1068
+ response_innards = GDMSCOREBandsHistoryResponse.model_validate(response.json())
1069
+ cluster_size_2 = len(response_innards.history[-1].result.clusters)
1070
+
1071
+ # Since we've made one iteration, the length of the clustering and therefore the active
1072
+ # indices should be smaller the second time around than the first one.
1073
+ assert cluster_size_1 > cluster_size_2
1074
+
1075
+ # TODO: Test reverting, re-clustering