desdeo 2.0.0__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/adm/ADMAfsar.py +551 -0
- desdeo/adm/ADMChen.py +414 -0
- desdeo/adm/BaseADM.py +119 -0
- desdeo/adm/__init__.py +11 -0
- desdeo/api/__init__.py +6 -6
- desdeo/api/app.py +38 -28
- desdeo/api/config.py +65 -44
- desdeo/api/config.toml +23 -12
- desdeo/api/db.py +10 -8
- desdeo/api/db_init.py +12 -6
- desdeo/api/models/__init__.py +220 -20
- desdeo/api/models/archive.py +16 -27
- desdeo/api/models/emo.py +128 -0
- desdeo/api/models/enautilus.py +69 -0
- desdeo/api/models/gdm/gdm_aggregate.py +139 -0
- desdeo/api/models/gdm/gdm_base.py +69 -0
- desdeo/api/models/gdm/gdm_score_bands.py +114 -0
- desdeo/api/models/gdm/gnimbus.py +138 -0
- desdeo/api/models/generic.py +104 -0
- desdeo/api/models/generic_states.py +401 -0
- desdeo/api/models/nimbus.py +158 -0
- desdeo/api/models/preference.py +44 -6
- desdeo/api/models/problem.py +274 -64
- desdeo/api/models/session.py +4 -1
- desdeo/api/models/state.py +419 -52
- desdeo/api/models/user.py +7 -6
- desdeo/api/models/utopia.py +25 -0
- desdeo/api/routers/_EMO.backup +309 -0
- desdeo/api/routers/_NIMBUS.py +6 -3
- desdeo/api/routers/emo.py +497 -0
- desdeo/api/routers/enautilus.py +237 -0
- desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
- desdeo/api/routers/gdm/gdm_base.py +420 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
- desdeo/api/routers/generic.py +233 -0
- desdeo/api/routers/nimbus.py +705 -0
- desdeo/api/routers/problem.py +201 -4
- desdeo/api/routers/reference_point_method.py +20 -44
- desdeo/api/routers/session.py +50 -26
- desdeo/api/routers/user_authentication.py +180 -26
- desdeo/api/routers/utils.py +187 -0
- desdeo/api/routers/utopia.py +230 -0
- desdeo/api/schema.py +10 -4
- desdeo/api/tests/conftest.py +94 -2
- desdeo/api/tests/test_enautilus.py +330 -0
- desdeo/api/tests/test_models.py +550 -72
- desdeo/api/tests/test_routes.py +902 -43
- desdeo/api/utils/_database.py +263 -0
- desdeo/api/utils/database.py +28 -266
- desdeo/api/utils/emo_database.py +40 -0
- desdeo/core.py +7 -0
- desdeo/emo/__init__.py +154 -24
- desdeo/emo/hooks/archivers.py +18 -2
- desdeo/emo/methods/EAs.py +128 -5
- desdeo/emo/methods/bases.py +9 -56
- desdeo/emo/methods/templates.py +111 -0
- desdeo/emo/operators/crossover.py +544 -42
- desdeo/emo/operators/evaluator.py +10 -14
- desdeo/emo/operators/generator.py +127 -24
- desdeo/emo/operators/mutation.py +212 -41
- desdeo/emo/operators/scalar_selection.py +202 -0
- desdeo/emo/operators/selection.py +956 -214
- desdeo/emo/operators/termination.py +124 -16
- desdeo/emo/options/__init__.py +108 -0
- desdeo/emo/options/algorithms.py +435 -0
- desdeo/emo/options/crossover.py +164 -0
- desdeo/emo/options/generator.py +131 -0
- desdeo/emo/options/mutation.py +260 -0
- desdeo/emo/options/repair.py +61 -0
- desdeo/emo/options/scalar_selection.py +66 -0
- desdeo/emo/options/selection.py +127 -0
- desdeo/emo/options/templates.py +383 -0
- desdeo/emo/options/termination.py +143 -0
- desdeo/gdm/__init__.py +22 -0
- desdeo/gdm/gdmtools.py +45 -0
- desdeo/gdm/score_bands.py +114 -0
- desdeo/gdm/voting_rules.py +50 -0
- desdeo/mcdm/__init__.py +23 -1
- desdeo/mcdm/enautilus.py +338 -0
- desdeo/mcdm/gnimbus.py +484 -0
- desdeo/mcdm/nautilus_navigator.py +7 -6
- desdeo/mcdm/reference_point_method.py +70 -0
- desdeo/problem/__init__.py +16 -11
- desdeo/problem/evaluator.py +4 -5
- desdeo/problem/external/__init__.py +18 -0
- desdeo/problem/external/core.py +356 -0
- desdeo/problem/external/pymoo_provider.py +266 -0
- desdeo/problem/external/runtime.py +44 -0
- desdeo/problem/gurobipy_evaluator.py +37 -12
- desdeo/problem/infix_parser.py +1 -16
- desdeo/problem/json_parser.py +7 -11
- desdeo/problem/pyomo_evaluator.py +25 -6
- desdeo/problem/schema.py +73 -55
- desdeo/problem/simulator_evaluator.py +65 -15
- desdeo/problem/testproblems/__init__.py +26 -11
- desdeo/problem/testproblems/benchmarks_server.py +120 -0
- desdeo/problem/testproblems/cake_problem.py +185 -0
- desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
- desdeo/problem/testproblems/forest_problem.py +77 -69
- desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
- desdeo/problem/testproblems/{river_pollution_problem.py → river_pollution_problems.py} +28 -22
- desdeo/problem/testproblems/single_objective.py +289 -0
- desdeo/problem/testproblems/zdt_problem.py +4 -1
- desdeo/problem/utils.py +1 -1
- desdeo/tools/__init__.py +39 -21
- desdeo/tools/desc_gen.py +22 -0
- desdeo/tools/generics.py +22 -2
- desdeo/tools/group_scalarization.py +3090 -0
- desdeo/tools/indicators_binary.py +107 -1
- desdeo/tools/indicators_unary.py +3 -16
- desdeo/tools/message.py +33 -2
- desdeo/tools/non_dominated_sorting.py +4 -3
- desdeo/tools/patterns.py +9 -7
- desdeo/tools/pyomo_solver_interfaces.py +49 -36
- desdeo/tools/reference_vectors.py +118 -351
- desdeo/tools/scalarization.py +340 -1413
- desdeo/tools/score_bands.py +491 -328
- desdeo/tools/utils.py +117 -49
- desdeo/tools/visualizations.py +67 -0
- desdeo/utopia_stuff/utopia_problem.py +1 -1
- desdeo/utopia_stuff/utopia_problem_old.py +1 -1
- {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info}/METADATA +47 -30
- desdeo-2.1.1.dist-info/RECORD +180 -0
- {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info}/WHEEL +1 -1
- desdeo-2.0.0.dist-info/RECORD +0 -120
- /desdeo/api/utils/{logger.py → _logger.py} +0 -0
- {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info/licenses}/LICENSE +0 -0
desdeo/api/tests/test_routes.py
CHANGED
|
@@ -1,48 +1,60 @@
|
|
|
1
1
|
"""Tests related to routes and routers."""
|
|
2
2
|
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
|
|
3
6
|
from fastapi import status
|
|
4
7
|
from fastapi.testclient import TestClient
|
|
5
8
|
|
|
6
9
|
from desdeo.api.models import (
|
|
7
10
|
CreateSessionRequest,
|
|
8
|
-
|
|
11
|
+
EMOFetchRequest,
|
|
12
|
+
EMOIterateRequest,
|
|
13
|
+
EMOIterateResponse,
|
|
14
|
+
ForestProblemMetaData,
|
|
15
|
+
GDMSCOREBandsHistoryResponse,
|
|
16
|
+
GDMScoreBandsInitializationRequest,
|
|
17
|
+
GDMScoreBandsVoteRequest,
|
|
18
|
+
GenericIntermediateSolutionResponse,
|
|
19
|
+
GroupCreateRequest,
|
|
20
|
+
GroupInfoRequest,
|
|
21
|
+
GroupModifyRequest,
|
|
22
|
+
GroupPublic,
|
|
9
23
|
InteractiveSessionDB,
|
|
24
|
+
IntermediateSolutionRequest,
|
|
25
|
+
NIMBUSClassificationRequest,
|
|
26
|
+
NIMBUSClassificationResponse,
|
|
27
|
+
NIMBUSDeleteSaveRequest,
|
|
28
|
+
NIMBUSDeleteSaveResponse,
|
|
29
|
+
NIMBUSFinalizeRequest,
|
|
30
|
+
NIMBUSFinalizeResponse,
|
|
31
|
+
NIMBUSInitializationRequest,
|
|
32
|
+
NIMBUSIntermediateSolutionResponse,
|
|
33
|
+
NIMBUSSaveRequest,
|
|
34
|
+
NIMBUSSaveResponse,
|
|
35
|
+
ProblemDB,
|
|
10
36
|
ProblemGetRequest,
|
|
11
37
|
ProblemInfo,
|
|
38
|
+
ProblemSelectSolverRequest,
|
|
12
39
|
ReferencePoint,
|
|
13
40
|
RPMSolveRequest,
|
|
41
|
+
SolutionInfo,
|
|
42
|
+
SolverSelectionMetadata,
|
|
14
43
|
User,
|
|
44
|
+
UserPublic,
|
|
15
45
|
)
|
|
46
|
+
from desdeo.api.models.nimbus import NIMBUSInitializationResponse
|
|
16
47
|
from desdeo.api.routers.user_authentication import create_access_token
|
|
17
|
-
from desdeo.
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
data={"username": username, "password": password, "grant_type": "password"},
|
|
25
|
-
headers={"content-type": "application/x-www-form-urlencoded"},
|
|
26
|
-
).json()
|
|
48
|
+
from desdeo.emo.options.algorithms import rvea_options
|
|
49
|
+
from desdeo.emo.options.templates import ReferencePointOptions
|
|
50
|
+
from desdeo.gdm.score_bands import SCOREBandsGDMConfig
|
|
51
|
+
from desdeo.problem import Problem
|
|
52
|
+
from desdeo.problem.testproblems import dtlz2, simple_knapsack_vectors
|
|
53
|
+
from desdeo.tools.score_bands import KMeansOptions, SCOREBandsConfig
|
|
54
|
+
from desdeo.tools.utils import available_solvers
|
|
27
55
|
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def post_json(client: TestClient, endpoint: str, json: dict, access_token: str):
|
|
32
|
-
"""Makes a post request and returns the response."""
|
|
33
|
-
return client.post(
|
|
34
|
-
endpoint,
|
|
35
|
-
json=json,
|
|
36
|
-
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"},
|
|
37
|
-
)
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def get_json(client: TestClient, endpoint: str, access_token: str):
|
|
41
|
-
"""Makes a get request and returns the response."""
|
|
42
|
-
return client.get(
|
|
43
|
-
endpoint,
|
|
44
|
-
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"},
|
|
45
|
-
)
|
|
56
|
+
from .conftest import get_json, login, post_file_multipart, post_json
|
|
57
|
+
from .test_models import compare_models
|
|
46
58
|
|
|
47
59
|
|
|
48
60
|
def test_user_login(client: TestClient):
|
|
@@ -114,6 +126,7 @@ def test_get_problem(client: TestClient):
|
|
|
114
126
|
|
|
115
127
|
assert info.id == 1
|
|
116
128
|
assert info.name == "dtlz2"
|
|
129
|
+
assert info.problem_metadata is None
|
|
117
130
|
|
|
118
131
|
response = post_json(client, "problem/get", ProblemGetRequest(problem_id=2).model_dump(), access_token)
|
|
119
132
|
|
|
@@ -123,6 +136,7 @@ def test_get_problem(client: TestClient):
|
|
|
123
136
|
|
|
124
137
|
assert info.id == 2
|
|
125
138
|
assert info.name == "The river pollution problem"
|
|
139
|
+
assert isinstance(info.problem_metadata.forest_metadata[0], ForestProblemMetaData)
|
|
126
140
|
|
|
127
141
|
|
|
128
142
|
def test_add_problem(client: TestClient):
|
|
@@ -145,7 +159,25 @@ def test_add_problem(client: TestClient):
|
|
|
145
159
|
|
|
146
160
|
problems = response.json()
|
|
147
161
|
|
|
148
|
-
assert len(problems) ==
|
|
162
|
+
assert len(problems) == 4
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def test_add_problem_json(client: TestClient, session_and_user: dict):
|
|
166
|
+
"""Test that adding a problem to the database works with JSON files."""
|
|
167
|
+
session = session_and_user["session"]
|
|
168
|
+
access_token = login(client)
|
|
169
|
+
problem = dtlz2(5, 3)
|
|
170
|
+
|
|
171
|
+
payload = problem.model_dump()
|
|
172
|
+
raw = json.dumps(payload).encode("utf-8")
|
|
173
|
+
|
|
174
|
+
response = post_file_multipart(client, "/problem/add_json", raw, access_token)
|
|
175
|
+
|
|
176
|
+
assert response.status_code == 200
|
|
177
|
+
|
|
178
|
+
problem_from_db = session.get(ProblemDB, 4)
|
|
179
|
+
|
|
180
|
+
assert compare_models(problem, Problem.from_problemdb(problem_from_db))
|
|
149
181
|
|
|
150
182
|
|
|
151
183
|
def test_new_session(client: TestClient, session_and_user: dict):
|
|
@@ -170,37 +202,110 @@ def test_new_session(client: TestClient, session_and_user: dict):
|
|
|
170
202
|
|
|
171
203
|
|
|
172
204
|
def test_get_session(client: TestClient, session_and_user: dict):
|
|
173
|
-
"""Test that getting a session works as intended."""
|
|
205
|
+
"""Test that getting a session via GET works as intended."""
|
|
174
206
|
user: User = session_and_user["user"]
|
|
175
207
|
|
|
176
208
|
access_token = login(client)
|
|
177
209
|
|
|
178
210
|
# no sessions
|
|
179
|
-
|
|
180
|
-
|
|
211
|
+
response = client.get(
|
|
212
|
+
"/session/get/1",
|
|
213
|
+
headers={"Authorization": f"Bearer {access_token}"},
|
|
214
|
+
)
|
|
181
215
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
|
182
216
|
|
|
183
|
-
# add
|
|
184
|
-
request = CreateSessionRequest(info="
|
|
217
|
+
# add session 1
|
|
218
|
+
request = CreateSessionRequest(info="Session 1")
|
|
185
219
|
response = post_json(client, "/session/new", request.model_dump(), access_token)
|
|
186
220
|
assert response.status_code == status.HTTP_200_OK
|
|
187
|
-
|
|
188
221
|
assert user.active_session_id == 1
|
|
189
222
|
|
|
190
|
-
|
|
223
|
+
# add session 2
|
|
224
|
+
request = CreateSessionRequest(info="Session 2")
|
|
191
225
|
response = post_json(client, "/session/new", request.model_dump(), access_token)
|
|
192
226
|
assert response.status_code == status.HTTP_200_OK
|
|
193
|
-
|
|
194
227
|
assert user.active_session_id == 2
|
|
195
228
|
|
|
196
|
-
#
|
|
197
|
-
|
|
198
|
-
|
|
229
|
+
# fetch session 1
|
|
230
|
+
response = client.get(
|
|
231
|
+
"/session/get/1",
|
|
232
|
+
headers={"Authorization": f"Bearer {access_token}"},
|
|
233
|
+
)
|
|
199
234
|
assert response.status_code == status.HTTP_200_OK
|
|
235
|
+
assert response.json()["id"] == 1
|
|
200
236
|
|
|
201
|
-
|
|
202
|
-
response =
|
|
237
|
+
# fetch session 2
|
|
238
|
+
response = client.get(
|
|
239
|
+
"/session/get/2",
|
|
240
|
+
headers={"Authorization": f"Bearer {access_token}"},
|
|
241
|
+
)
|
|
203
242
|
assert response.status_code == status.HTTP_200_OK
|
|
243
|
+
assert response.json()["id"] == 2
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def test_get_all_sessions_success(client: TestClient, session_and_user: dict):
|
|
247
|
+
"""Test getting all sessions when sessions exist."""
|
|
248
|
+
access_token = login(client)
|
|
249
|
+
|
|
250
|
+
# create 2 test sessions
|
|
251
|
+
post_json(client, "/session/new", {"info": "S1"}, access_token)
|
|
252
|
+
post_json(client, "/session/new", {"info": "S2"}, access_token)
|
|
253
|
+
|
|
254
|
+
response = client.get(
|
|
255
|
+
"/session/get_all",
|
|
256
|
+
headers={"Authorization": f"Bearer {access_token}"},
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
assert response.status_code == status.HTTP_200_OK
|
|
260
|
+
data = response.json()
|
|
261
|
+
assert isinstance(data, list)
|
|
262
|
+
assert len(data) == 2
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def test_get_all_sessions_not_found(client: TestClient, session_and_user: dict):
|
|
266
|
+
"""Test get_all returns 404 if user has no sessions."""
|
|
267
|
+
access_token = login(client)
|
|
268
|
+
|
|
269
|
+
response = client.get(
|
|
270
|
+
"/session/get_all",
|
|
271
|
+
headers={"Authorization": f"Bearer {access_token}"},
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
assert response.status_code == status.HTTP_404_NOT_FOUND
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def test_delete_session_success(client: TestClient, session_and_user: dict):
|
|
278
|
+
"""Test deleting an existing session."""
|
|
279
|
+
access_token = login(client)
|
|
280
|
+
|
|
281
|
+
# create session
|
|
282
|
+
post_json(client, "/session/new", {"info": "To delete"}, access_token)
|
|
283
|
+
|
|
284
|
+
response = client.delete(
|
|
285
|
+
"/session/1",
|
|
286
|
+
headers={"Authorization": f"Bearer {access_token}"},
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
assert response.status_code == status.HTTP_204_NO_CONTENT
|
|
290
|
+
|
|
291
|
+
# verify it's gone
|
|
292
|
+
response = client.get(
|
|
293
|
+
"/session/get/1",
|
|
294
|
+
headers={"Authorization": f"Bearer {access_token}"},
|
|
295
|
+
)
|
|
296
|
+
assert response.status_code == status.HTTP_404_NOT_FOUND
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def test_delete_session_not_found(client: TestClient, session_and_user: dict):
|
|
300
|
+
"""Test deleting a non-existent session returns 404."""
|
|
301
|
+
access_token = login(client)
|
|
302
|
+
|
|
303
|
+
response = client.delete(
|
|
304
|
+
"/session/999",
|
|
305
|
+
headers={"Authorization": f"Bearer {access_token}"},
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
assert response.status_code == status.HTTP_404_NOT_FOUND
|
|
204
309
|
|
|
205
310
|
|
|
206
311
|
def test_rpm_solve(client: TestClient):
|
|
@@ -214,3 +319,757 @@ def test_rpm_solve(client: TestClient):
|
|
|
214
319
|
response = post_json(client, "/method/rpm/solve", request.model_dump(), access_token)
|
|
215
320
|
|
|
216
321
|
assert response.status_code == status.HTTP_200_OK
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def test_nimbus_solve(client: TestClient):
|
|
325
|
+
"""Test that using the NIMBUS method works as expected."""
|
|
326
|
+
access_token = login(client)
|
|
327
|
+
|
|
328
|
+
preference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.6, "f_3": 0.4})
|
|
329
|
+
|
|
330
|
+
request = NIMBUSClassificationRequest(
|
|
331
|
+
problem_id=1, preference=preference, current_objectives={"f_1": 0.6, "f_2": 0.4, "f_3": 0.5}, num_desired=3
|
|
332
|
+
)
|
|
333
|
+
|
|
334
|
+
response = post_json(client, "/method/nimbus/solve", request.model_dump(), access_token)
|
|
335
|
+
assert response.status_code == status.HTTP_200_OK
|
|
336
|
+
result: NIMBUSClassificationResponse = NIMBUSClassificationResponse.model_validate(
|
|
337
|
+
json.loads(response.content.decode("utf-8"))
|
|
338
|
+
)
|
|
339
|
+
assert result.previous_preference == preference
|
|
340
|
+
assert len(result.all_solutions) == 3
|
|
341
|
+
|
|
342
|
+
request = NIMBUSSaveRequest(
|
|
343
|
+
problem_id=1,
|
|
344
|
+
parent_state_id=result.state_id,
|
|
345
|
+
solution_info=[
|
|
346
|
+
SolutionInfo(state_id=1, solution_index=0, name="Test solution 1"),
|
|
347
|
+
SolutionInfo(state_id=1, solution_index=2, name="Test solution 3"),
|
|
348
|
+
SolutionInfo(state_id=1, solution_index=2, name="Test solution 34"), # saved twice!
|
|
349
|
+
],
|
|
350
|
+
)
|
|
351
|
+
|
|
352
|
+
response = post_json(client, "/method/nimbus/save", request.model_dump(), access_token)
|
|
353
|
+
assert response.status_code == status.HTTP_200_OK
|
|
354
|
+
result2: NIMBUSSaveResponse = NIMBUSSaveResponse.model_validate(json.loads(response.content.decode("utf-8")))
|
|
355
|
+
assert result2.state_id is not None
|
|
356
|
+
|
|
357
|
+
preference = ReferencePoint(aspiration_levels={"f_1": 0.1, "f_2": 0.1, "f_3": 0.9})
|
|
358
|
+
|
|
359
|
+
request = NIMBUSClassificationRequest(
|
|
360
|
+
problem_id=1,
|
|
361
|
+
preference=preference,
|
|
362
|
+
current_objectives=result.current_solutions[0].objective_values,
|
|
363
|
+
num_desired=3,
|
|
364
|
+
parent_state_id=result2.state_id,
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
response = post_json(client, "/method/nimbus/solve", request.model_dump(), access_token)
|
|
368
|
+
|
|
369
|
+
assert response.status_code == status.HTTP_200_OK
|
|
370
|
+
result: NIMBUSClassificationResponse = NIMBUSClassificationResponse.model_validate(
|
|
371
|
+
json.loads(response.content.decode("utf-8"))
|
|
372
|
+
)
|
|
373
|
+
assert result.previous_preference == preference
|
|
374
|
+
# We saved the same solution twice, so the filtering should remove one of those.
|
|
375
|
+
assert len(result.saved_solutions) == 2
|
|
376
|
+
assert len(result.all_solutions) == 6 # should not count saved solutions twice
|
|
377
|
+
|
|
378
|
+
# Save some more solutions!
|
|
379
|
+
request = NIMBUSSaveRequest(
|
|
380
|
+
problem_id=1,
|
|
381
|
+
parent_state_id=result.state_id,
|
|
382
|
+
solution_info=[SolutionInfo(state_id=result.state_id, solution_index=1, name="Test solution 2")],
|
|
383
|
+
)
|
|
384
|
+
|
|
385
|
+
response = post_json(client, "/method/nimbus/save", request.model_dump(), access_token)
|
|
386
|
+
assert response.status_code == status.HTTP_200_OK
|
|
387
|
+
result2: NIMBUSSaveResponse = NIMBUSSaveResponse.model_validate(json.loads(response.content.decode("utf-8")))
|
|
388
|
+
assert result2.state_id is not None
|
|
389
|
+
|
|
390
|
+
# Same preference as the first one, so presumably STOM and ASF yield the same solutions again,
|
|
391
|
+
# which should be reflected in the total number of solutions
|
|
392
|
+
preference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.6, "f_3": 0.4})
|
|
393
|
+
|
|
394
|
+
request = NIMBUSClassificationRequest(
|
|
395
|
+
problem_id=1,
|
|
396
|
+
preference=preference,
|
|
397
|
+
current_objectives=result.current_solutions[0].objective_values,
|
|
398
|
+
num_desired=3,
|
|
399
|
+
parent_state_id=result2.state_id,
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
response = post_json(client, "/method/nimbus/solve", request.model_dump(), access_token)
|
|
403
|
+
assert response.status_code == status.HTTP_200_OK
|
|
404
|
+
result3: NIMBUSClassificationResponse = NIMBUSClassificationResponse.model_validate(
|
|
405
|
+
json.loads(response.content.decode("utf-8"))
|
|
406
|
+
)
|
|
407
|
+
assert result3.previous_preference == preference
|
|
408
|
+
assert len(result3.saved_solutions) == 3
|
|
409
|
+
assert len(result3.all_solutions) == 7
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def test_intermediate_solve(client: TestClient):
|
|
413
|
+
"""Test that solving intermediate solutions works as expected."""
|
|
414
|
+
access_token = login(client)
|
|
415
|
+
|
|
416
|
+
preference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.6, "f_3": 0.4})
|
|
417
|
+
|
|
418
|
+
request = NIMBUSClassificationRequest(
|
|
419
|
+
problem_id=1, preference=preference, current_objectives={"f_1": 0.6, "f_2": 0.4, "f_3": 0.5}, num_desired=2
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
response = post_json(client, "/method/nimbus/solve", request.model_dump(), access_token)
|
|
423
|
+
assert response.status_code == status.HTTP_200_OK
|
|
424
|
+
result: NIMBUSClassificationResponse = NIMBUSClassificationResponse.model_validate(
|
|
425
|
+
json.loads(response.content.decode("utf-8"))
|
|
426
|
+
)
|
|
427
|
+
assert len(result.all_solutions) == 2
|
|
428
|
+
|
|
429
|
+
# Save some solutions!
|
|
430
|
+
solution_1 = SolutionInfo(state_id=result.state_id, solution_index=0)
|
|
431
|
+
solution_2 = SolutionInfo(state_id=result.state_id, solution_index=1, name="named solution")
|
|
432
|
+
|
|
433
|
+
# Create request for intermediate solutions using solutions created with nimbus solve
|
|
434
|
+
request = IntermediateSolutionRequest(
|
|
435
|
+
problem_id=1,
|
|
436
|
+
context="test",
|
|
437
|
+
reference_solution_1=solution_1,
|
|
438
|
+
reference_solution_2=solution_2,
|
|
439
|
+
num_desired=3,
|
|
440
|
+
)
|
|
441
|
+
|
|
442
|
+
# Test the generic intermediate endpoint
|
|
443
|
+
response = post_json(client, "/method/generic/intermediate", request.model_dump(), access_token)
|
|
444
|
+
assert response.status_code == status.HTTP_200_OK
|
|
445
|
+
result: GenericIntermediateSolutionResponse = GenericIntermediateSolutionResponse.model_validate(
|
|
446
|
+
json.loads(response.content.decode("utf-8"))
|
|
447
|
+
)
|
|
448
|
+
|
|
449
|
+
# Test the NIMBUS-specific intermediate endpoint
|
|
450
|
+
nimbus_request = IntermediateSolutionRequest(
|
|
451
|
+
problem_id=1,
|
|
452
|
+
context="nimbus",
|
|
453
|
+
reference_solution_1=solution_1,
|
|
454
|
+
reference_solution_2=solution_2,
|
|
455
|
+
num_desired=2,
|
|
456
|
+
)
|
|
457
|
+
|
|
458
|
+
nimbus_response = post_json(client, "/method/nimbus/intermediate", nimbus_request.model_dump(), access_token)
|
|
459
|
+
assert nimbus_response.status_code == status.HTTP_200_OK
|
|
460
|
+
nimbus_result = NIMBUSIntermediateSolutionResponse.model_validate(
|
|
461
|
+
json.loads(nimbus_response.content.decode("utf-8"))
|
|
462
|
+
)
|
|
463
|
+
|
|
464
|
+
# Verify the NIMBUS response contains expected fields
|
|
465
|
+
assert nimbus_result.state_id is not None
|
|
466
|
+
assert len(nimbus_result.current_solutions) == 2
|
|
467
|
+
assert len(nimbus_result.all_solutions) == 7
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def test_nimbus_initialize(client: TestClient):
|
|
471
|
+
"""Test that initializing NIMBUS works without specifying a solver."""
|
|
472
|
+
access_token = login(client)
|
|
473
|
+
|
|
474
|
+
# test with no starting point
|
|
475
|
+
request = NIMBUSInitializationRequest(problem_id=1, solver=None)
|
|
476
|
+
|
|
477
|
+
response = post_json(client, "/method/nimbus/initialize", request.model_dump(), access_token)
|
|
478
|
+
|
|
479
|
+
assert response.status_code == status.HTTP_200_OK
|
|
480
|
+
init_result = NIMBUSInitializationResponse.model_validate(json.loads(response.content))
|
|
481
|
+
|
|
482
|
+
assert init_result.state_id == 1
|
|
483
|
+
assert len(init_result.current_solutions) == 1
|
|
484
|
+
assert len(init_result.saved_solutions) == 0
|
|
485
|
+
assert len(init_result.all_solutions) == 1
|
|
486
|
+
|
|
487
|
+
# test with starting point given as solution info
|
|
488
|
+
request_w_info = NIMBUSInitializationRequest(
|
|
489
|
+
problem_id=1, starting_point=SolutionInfo(state_id=1, solution_index=0)
|
|
490
|
+
)
|
|
491
|
+
|
|
492
|
+
response_w_info = post_json(client, "/method/nimbus/initialize", request_w_info.model_dump(), access_token)
|
|
493
|
+
|
|
494
|
+
assert response_w_info.status_code == status.HTTP_200_OK
|
|
495
|
+
result_w_info = NIMBUSInitializationResponse.model_validate(json.loads(response_w_info.content))
|
|
496
|
+
|
|
497
|
+
assert result_w_info.state_id == 2
|
|
498
|
+
assert len(result_w_info.current_solutions) == 1
|
|
499
|
+
assert len(result_w_info.saved_solutions) == 0
|
|
500
|
+
assert len(result_w_info.all_solutions) == 1 # this is still one because the new solution will be a duplicate.
|
|
501
|
+
|
|
502
|
+
# test with starting point given as a reference point
|
|
503
|
+
request_w_ref = NIMBUSInitializationRequest(
|
|
504
|
+
problem_id=1, starting_point=ReferencePoint(aspiration_levels={"f_1": 0.2, "f_2": 0.8, "f_3": 0.4})
|
|
505
|
+
)
|
|
506
|
+
|
|
507
|
+
response_w_ref = post_json(client, "/method/nimbus/initialize", request_w_ref.model_dump(), access_token)
|
|
508
|
+
|
|
509
|
+
assert response_w_ref.status_code == status.HTTP_200_OK
|
|
510
|
+
result_w_ref = NIMBUSInitializationResponse.model_validate(json.loads(response_w_ref.content))
|
|
511
|
+
|
|
512
|
+
assert result_w_ref.state_id == 3
|
|
513
|
+
assert len(result_w_ref.current_solutions) == 1
|
|
514
|
+
assert len(result_w_ref.saved_solutions) == 0
|
|
515
|
+
assert len(result_w_ref.all_solutions) == 2 # we should have a new one
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def test_nimbus_finalize(client: TestClient):
|
|
519
|
+
"""Test for seeing if NIMBUS finalization works."""
|
|
520
|
+
access_token = login(client)
|
|
521
|
+
|
|
522
|
+
# create some previous iterations
|
|
523
|
+
request = NIMBUSInitializationRequest(problem_id=1)
|
|
524
|
+
response = post_json(client, "/method/nimbus/get-or-initialize", request.model_dump(), access_token)
|
|
525
|
+
assert response.status_code == status.HTTP_200_OK
|
|
526
|
+
init_response = NIMBUSInitializationResponse.model_validate(json.loads(response.content))
|
|
527
|
+
assert init_response.state_id == 1
|
|
528
|
+
assert len(init_response.current_solutions) == 1
|
|
529
|
+
assert len(init_response.saved_solutions) == 0
|
|
530
|
+
assert len(init_response.all_solutions) == 1
|
|
531
|
+
|
|
532
|
+
preference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.6, "f_3": 0.4})
|
|
533
|
+
|
|
534
|
+
request = NIMBUSClassificationRequest(
|
|
535
|
+
problem_id=1, preference=preference, current_objectives={"f_1": 0.6, "f_2": 0.4, "f_3": 0.5}, num_desired=3
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
response = post_json(client, "/method/nimbus/solve", request.model_dump(), access_token)
|
|
539
|
+
assert response.status_code == status.HTTP_200_OK
|
|
540
|
+
result: NIMBUSClassificationResponse = NIMBUSClassificationResponse.model_validate(
|
|
541
|
+
json.loads(response.content.decode("utf-8"))
|
|
542
|
+
)
|
|
543
|
+
assert result.previous_preference == preference
|
|
544
|
+
assert len(result.current_solutions) == 3
|
|
545
|
+
|
|
546
|
+
solution_index = 2
|
|
547
|
+
|
|
548
|
+
optim_obj = result.current_solutions[solution_index].objective_values
|
|
549
|
+
optim_var = result.current_solutions[solution_index].variable_values
|
|
550
|
+
|
|
551
|
+
state_id = result.state_id
|
|
552
|
+
|
|
553
|
+
request = NIMBUSInitializationRequest(problem_id=1)
|
|
554
|
+
response = post_json(client, "/method/nimbus/get-or-initialize", request.model_dump(), access_token)
|
|
555
|
+
assert response.status_code == status.HTTP_200_OK
|
|
556
|
+
classify_result = NIMBUSClassificationResponse.model_validate(json.loads(response.content))
|
|
557
|
+
assert classify_result.state_id == 2
|
|
558
|
+
assert len(classify_result.current_solutions) == 3
|
|
559
|
+
assert len(classify_result.saved_solutions) == 0
|
|
560
|
+
assert len(classify_result.all_solutions) == 4
|
|
561
|
+
|
|
562
|
+
request = NIMBUSFinalizeRequest(
|
|
563
|
+
problem_id=1,
|
|
564
|
+
solution_info=SolutionInfo(state_id=state_id, solution_index=solution_index),
|
|
565
|
+
)
|
|
566
|
+
|
|
567
|
+
# Finalize the process
|
|
568
|
+
response = post_json(client, "/method/nimbus/finalize", request.model_dump(), access_token)
|
|
569
|
+
assert response.status_code == status.HTTP_200_OK
|
|
570
|
+
result: NIMBUSFinalizeResponse = NIMBUSFinalizeResponse.model_validate(json.loads(response.content.decode("utf-8")))
|
|
571
|
+
assert result.response_type == "nimbus.finalize"
|
|
572
|
+
assert result.final_solution.objective_values == optim_obj
|
|
573
|
+
assert result.final_solution.variable_values == optim_var
|
|
574
|
+
assert result.final_solution.state_id != result.state_id
|
|
575
|
+
assert result.all_solutions is not None
|
|
576
|
+
|
|
577
|
+
request = NIMBUSInitializationRequest(problem_id=1)
|
|
578
|
+
|
|
579
|
+
# The last item in the pipe is a finalize state, so we should be getting a finalize response.
|
|
580
|
+
response = post_json(client, "/method/nimbus/get-or-initialize", request.model_dump(), access_token)
|
|
581
|
+
assert response.status_code == status.HTTP_200_OK
|
|
582
|
+
result: NIMBUSFinalizeResponse = NIMBUSFinalizeResponse.model_validate(json.loads(response.content.decode("utf-8")))
|
|
583
|
+
assert result.response_type == "nimbus.finalize"
|
|
584
|
+
assert result.final_solution.objective_values == optim_obj
|
|
585
|
+
assert result.final_solution.variable_values == optim_var
|
|
586
|
+
assert result.final_solution.state_id != result.state_id
|
|
587
|
+
assert result.all_solutions is not None
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def test_nimbus_save_and_delete_save(client: TestClient):
    """Test that NIMBUS saving and save deletion works.

    Walks through initialize -> solve -> save -> solve -> delete save -> solve.
    Each step is expected to create a new state whose id is one larger than the
    previous one. The solve responses must report a saved solution only while the
    save exists.

    Every request's HTTP status is asserted before its body is validated, so a
    failing endpoint is reported as a status mismatch rather than as an opaque
    validation error.
    """
    access_token = login(client)

    # 1. Initialize
    init_request = NIMBUSInitializationRequest(problem_id=1)
    response = post_json(client, "/method/nimbus/initialize", init_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK
    init_result = NIMBUSInitializationResponse.model_validate(json.loads(response.content))
    assert init_result.state_id == 1

    # 2. Iterate
    first_solve_request = NIMBUSClassificationRequest(
        problem_id=1,
        preference=ReferencePoint(
            aspiration_levels={
                "f_1": 0.1,
                "f_2": 0.8,
                "f_3": 0.5,
            }
        ),
        current_objectives=init_result.current_solutions[0].objective_values,
        parent_state_id=init_result.state_id,
        num_desired=3,
    )
    response = post_json(client, "/method/nimbus/solve", first_solve_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK
    first_solve = NIMBUSClassificationResponse.model_validate(json.loads(response.content))
    assert first_solve.state_id == 2

    # 3. Save the solution at index 1 of the solve above.
    save_request = NIMBUSSaveRequest(
        problem_id=1,
        parent_state_id=first_solve.state_id,
        solution_info=[SolutionInfo(state_id=first_solve.state_id, solution_index=1)],
    )
    response = post_json(client, "/method/nimbus/save", save_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK
    save_result = NIMBUSSaveResponse.model_validate(json.loads(response.content))
    assert save_result.state_id == 3

    # A solve after saving must report the saved solution.
    second_solve_request = NIMBUSClassificationRequest(
        problem_id=1,
        preference=ReferencePoint(
            aspiration_levels={
                "f_1": 0.9,
                "f_2": 0.1,
                "f_3": 0.5,
            }
        ),
        current_objectives=first_solve.current_solutions[0].objective_values,
        num_desired=1,
        parent_state_id=save_result.state_id,
    )
    response = post_json(client, "/method/nimbus/solve", second_solve_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK
    second_solve = NIMBUSClassificationResponse.model_validate(json.loads(response.content))
    assert second_solve.state_id == 4
    assert len(second_solve.saved_solutions) > 0

    # 4. Delete the save made in step 3 (same state id and solution index).
    delete_request = NIMBUSDeleteSaveRequest(state_id=first_solve.state_id, solution_index=1)
    response = post_json(client, "/method/nimbus/delete_save", delete_request.model_dump(), access_token)
    # Check the status explicitly: asserting on the validated response object
    # itself would pass for any successfully parsed body.
    assert response.status_code == status.HTTP_200_OK
    NIMBUSDeleteSaveResponse.model_validate(json.loads(response.content))

    # A solve after deleting the save must no longer report any saved solutions.
    third_solve_request = NIMBUSClassificationRequest(
        problem_id=1,
        preference=ReferencePoint(
            aspiration_levels={
                "f_1": 0.1,
                "f_2": 0.9,
                "f_3": 0.4,
            }
        ),
        current_objectives=second_solve.current_solutions[0].objective_values,
        num_desired=1,
        parent_state_id=second_solve.state_id,
    )
    response = post_json(client, "/method/nimbus/solve", third_solve_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK
    third_solve = NIMBUSClassificationResponse.model_validate(json.loads(response.content))
    assert third_solve.state_id == 5
    assert len(third_solve.saved_solutions) == 0
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
def test_add_new_dm(client: TestClient):
    """Test that registering a decision maker succeeds once and conflicts when repeated."""

    def register_dm():
        # Registration is a form post without an Authorization header; the same
        # credentials are deliberately reused for both attempts.
        return client.post(
            "/add_new_dm",
            data={"username": "new_dm", "password": "new_dm", "grant_type": "password"},
            headers={"content-type": "application/x-www-form-urlencoded"},
        )

    # The first registration creates the user.
    first_attempt = register_dm()
    assert first_attempt.status_code == status.HTTP_201_CREATED

    # The username is now taken, so a second registration must be rejected.
    second_attempt = register_dm()
    assert second_attempt.status_code == status.HTTP_409_CONFLICT
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
def test_add_new_analyst(client: TestClient):
    """Test that only an analyst may add a new analyst, and only under a free username."""

    def post_analyst(token=None):
        # Build fresh headers per call; the bearer token is only sent when given.
        headers = {"content-type": "application/x-www-form-urlencoded"}
        if token is not None:
            headers = {"Authorization": f"Bearer {token}", **headers}
        return client.post(
            "/add_new_analyst",
            data={"username": "new_analyst", "password": "new_analyst", "grant_type": "password"},
            headers=headers,
        )

    # Without any credentials the request is rejected.
    assert post_analyst().status_code == status.HTTP_401_UNAUTHORIZED

    # Register a decision maker so we can try the endpoint with a DM account.
    dm_signup = client.post(
        "/add_new_dm",
        data={"username": "new_dm", "password": "new_dm", "grant_type": "password"},
        headers={"content-type": "application/x-www-form-urlencoded"},
    )
    assert dm_signup.status_code == status.HTTP_201_CREATED

    dm_access_token = login(client, username="new_dm", password="new_dm")  # noqa: S106

    # A decision maker is not allowed to create analysts; the endpoint answers 401.
    assert post_analyst(dm_access_token).status_code == status.HTTP_401_UNAUTHORIZED

    # The default login is the analyst account, which may create analysts.
    analyst_access_token = login(client)
    assert post_analyst(analyst_access_token).status_code == status.HTTP_201_CREATED

    # Re-using a username that is already taken yields a conflict.
    assert post_analyst(analyst_access_token).status_code == status.HTTP_409_CONFLICT
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
def test_login_logout(client: TestClient):
    """Test that refreshing works while logged in and is refused after logging out."""
    # Logging in stores the refresh-token cookie on the test client.
    login(client=client, username="analyst", password="analyst")  # noqa: S106

    # With the cookie present the access token can be refreshed.
    refreshed = client.post("/refresh")
    assert refreshed.status_code == status.HTTP_200_OK

    # Logging out removes the refresh-token cookie.
    logged_out = client.post("/logout")
    assert logged_out.status_code == status.HTTP_200_OK

    # Without the cookie the refresh must now be refused.
    rejected = client.post(
        "/refresh",
        headers={"content-type": "application/x-www-form-urlencoded"},
    )
    assert rejected.status_code == status.HTTP_401_UNAUTHORIZED
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
def test_group_operations(client: TestClient):
    """Test the group management endpoints.

    Covers:
    - group creation, including the 404 for a nonexistent problem;
    - adding and removing a decision maker;
    - checking membership through both the group-info and user-info endpoints;
    - deleting the group.

    Websocket behaviour is not exercised yet (see the TODO below).
    """
    create_endpoint = "/gdm/create_group"
    delete_endpoint = "/gdm/delete_group"
    add_user_endpoint = "/gdm/add_to_group"
    remove_user_endpoint = "/gdm/remove_from_group"
    group_info_endpoint = "/gdm/get_group_info"

    # login to analyst
    access_token = login(client=client, username="analyst", password="analyst")  # noqa: S106

    def get_info(gid: int):
        """Fetch group ``gid`` from the group-info endpoint using the analyst's token."""
        return post_json(
            client=client,
            endpoint=group_info_endpoint,
            json=GroupInfoRequest(
                group_id=gid,
            ).model_dump(),
            access_token=access_token,
        )

    def get_user_info(token: str):
        """Fetch the user record associated with ``token``."""
        return client.get(
            "/user_info",
            headers={
                "Authorization": f"Bearer {token}",
                "content-type": "application/x-www-form-urlencoded",
            },
        )

    # Creating a group for a problem that does not exist fails.
    response = post_json(
        client=client,
        endpoint=create_endpoint,
        json=GroupCreateRequest(group_name="testGroup", problem_id=10).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_404_NOT_FOUND

    # Create group properly
    response = post_json(
        client=client,
        endpoint=create_endpoint,
        json=GroupCreateRequest(group_name="testGroup", problem_id=2).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_201_CREATED

    # Add a user to database
    response = client.post(
        "/add_new_dm",
        data={"username": "new_dm", "password": "new_dm", "grant_type": "password"},
        headers={"content-type": "application/x-www-form-urlencoded"},
    )
    assert response.status_code == status.HTTP_201_CREATED

    # Add user to a group. Assumes the new DM got user id 2 and the group id 1 — confirm
    # against the fixture's database state.
    response = post_json(
        client=client,
        endpoint=add_user_endpoint,
        json=GroupModifyRequest(group_id=1, user_id=2).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_200_OK
    response = get_info(1)
    assert response.status_code == status.HTTP_200_OK
    group: GroupPublic = GroupPublic.model_validate(json.loads(response.content.decode("utf-8")))
    assert 2 in group.user_ids
    # The creating analyst (presumably user 1) is not listed among the members...
    assert 1 not in group.user_ids

    # ...but the group still appears in the analyst's own group ids
    # (presumably because the analyst owns it — confirm).
    user_info = get_user_info(access_token)
    analyst_user: UserPublic = UserPublic.model_validate(json.loads(user_info.content.decode("utf-8")))
    assert 1 in analyst_user.group_ids

    dm_access_token = login(client, "new_dm", "new_dm")

    user_info = get_user_info(dm_access_token)
    dm_user: UserPublic = UserPublic.model_validate(json.loads(user_info.content.decode("utf-8")))
    assert 1 in dm_user.group_ids

    # TODO: websocket testing and result fetching?

    # Remove user from a group
    response = post_json(
        client=client,
        endpoint=remove_user_endpoint,
        json=GroupModifyRequest(group_id=1, user_id=2).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_200_OK
    response = get_info(1)
    assert response.status_code == status.HTTP_200_OK
    group_after_removal: GroupPublic = GroupPublic.model_validate(json.loads(response.content.decode("utf-8")))
    assert 2 not in group_after_removal.user_ids

    # The removed DM no longer sees the group, while the analyst still does.
    user_info = get_user_info(dm_access_token)
    dm_user_after_removal: UserPublic = UserPublic.model_validate(json.loads(user_info.content.decode("utf-8")))
    assert 1 not in dm_user_after_removal.group_ids

    user_info = get_user_info(access_token)
    analyst_after_removal: UserPublic = UserPublic.model_validate(json.loads(user_info.content.decode("utf-8")))
    assert 1 in analyst_after_removal.group_ids

    # Delete the group
    response = post_json(
        client=client,
        endpoint=delete_endpoint,
        json=GroupInfoRequest(
            group_id=1,
        ).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_200_OK
    response = get_info(1)
    assert response.status_code == status.HTTP_404_NOT_FOUND

    # After deletion the group is gone from the analyst's group ids as well.
    user_info = get_user_info(access_token)
    analyst_after_delete: UserPublic = UserPublic.model_validate(json.loads(user_info.content.decode("utf-8")))
    assert 1 not in analyst_after_delete.group_ids
|
|
904
|
+
|
|
905
|
+
|
|
906
|
+
def test_preferred_solver(client: TestClient):
    """Test assigning a preferred solver to a problem and reading it back from the metadata."""
    access_token = login(client)

    def assign(solver_name: str):
        # Ask the server to use `solver_name` for problem 1.
        payload = ProblemSelectSolverRequest(problem_id=1, solver_string_representation=solver_name)
        return post_json(client, "/problem/assign_solver", payload.model_dump(), access_token)

    # An unknown solver name is rejected.
    assert assign("THIS SOLVER DOESN'T EXIST").status_code == status.HTTP_404_NOT_FOUND

    # A known solver name is accepted.
    assert assign("pyomo_cbc").status_code == status.HTTP_200_OK

    # The assignment is stored as solver-selection metadata of the problem.
    metadata_query = {"problem_id": 1, "metadata_type": "solver_selection_metadata"}
    metadata_response = post_json(client, "/problem/get_metadata", metadata_query, access_token)
    assert metadata_response.status_code == status.HTTP_200_OK

    solver_metadata = SolverSelectionMetadata.model_validate(metadata_response.json()[0])
    assert solver_metadata.metadata_type == "solver_selection_metadata"
    assert solver_metadata.solver_string_representation == "pyomo_cbc"

    # Test that the solver is in use: initializing NIMBUS should go through the assigned
    # solver. A failure is the expected outcome here, so it is reported rather than
    # failing the test (the broad catch is deliberate).
    try:
        init_payload = NIMBUSInitializationRequest(problem_id=1)
        init_response = post_json(client, "/method/nimbus/initialize", init_payload.model_dump(), access_token)
        NIMBUSInitializationResponse.model_validate(init_response.json())
    except Exception as e:
        print(e)
        print("^ This outcome is expected since pyomo_cbc doesn't support nonlinear problems.")
        print(" As that solver is what we set it to be in the start, we can verify that they actually get used.")
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
def test_get_available_solvers(client: TestClient):
    """Test that the solver listing endpoint returns exactly the names in ``available_solvers``."""
    listing = client.get("/problem/assign/solver")
    assert listing.status_code == status.HTTP_200_OK

    solver_names = listing.json()
    assert isinstance(solver_names, list)

    # Order does not matter; only the set of names must agree with `available_solvers`.
    expected_names = set(available_solvers.keys())
    assert set(solver_names) == expected_names
|
|
949
|
+
|
|
950
|
+
|
|
951
|
+
def test_emo_solve_with_reference_point(client: TestClient):
    """Test that using EMO with reference point works as expected.

    Currently disabled: the function returns immediately, so nothing after the
    ``return`` runs until the websocket issues are fixed.
    """
    # TODO: This test fails because of websocket issues. Fix those and re-enable the test.
    return
    access_token = login(client)
    request = EMOIterateRequest(
        problem_id=1,
        template_options=[rvea_options.template],
        preference_options=ReferencePointOptions(preference={"f_1": 0.5, "f_2": 0.3, "f_3": 0.4}, method="Hakanen"),
    )

    response = post_json(client, "/method/emo/iterate", request.model_dump(), access_token)

    assert response.status_code == status.HTTP_200_OK

    # Validate the response structure
    emo_response = EMOIterateResponse.model_validate(response.json())
    assert emo_response.client_id is not None
    state_id = emo_response.state_id

    # Listen on the websocket until the first method reports it has finished. The
    # 10-second deadline is only checked between received messages.
    initial_time = time.time()
    with client.websocket_connect(f"/method/emo/ws/{emo_response.client_id}") as websocket:
        while time.time() - initial_time < 10:
            message = websocket.receive_json()
            if message.get("message") == f"Finished {emo_response.method_ids[0]}":
                break

    # Fetch the state to verify it worked
    fetch_request = EMOFetchRequest(problem_id=1, parent_state_id=state_id)
    response = post_json(client, "/method/emo/fetch", fetch_request.model_dump(), access_token)
    # Check that the fetch itself succeeded; otherwise this step verifies nothing.
    assert response.status_code == status.HTTP_200_OK
|
|
980
|
+
|
|
981
|
+
|
|
982
|
+
def test_get_problem_metadata(client: TestClient):
    """Test fetching forest metadata for a problem without it, one with it, and a missing problem."""
    access_token = login(client=client)

    def fetch_forest_metadata(problem_id: int):
        # All three cases query the same metadata type; only the problem differs.
        payload = {"problem_id": problem_id, "metadata_type": "forest_problem_metadata"}
        return post_json(client=client, endpoint="/problem/get_metadata", json=payload, access_token=access_token)

    # Problem 1 has no forest metadata, so an empty list comes back.
    without_metadata = fetch_forest_metadata(1)
    assert without_metadata.status_code == status.HTTP_200_OK
    assert without_metadata.json() == []

    # Problem 2 carries forest metadata.
    with_metadata = fetch_forest_metadata(2)
    assert with_metadata.status_code == status.HTTP_200_OK
    first_entry = with_metadata.json()[0]
    assert first_entry["metadata_type"] == "forest_problem_metadata"
    assert first_entry["schedule_dict"] == {"type": "dict"}

    # Problem 4 does not exist.
    missing_problem = fetch_forest_metadata(4)
    assert missing_problem.status_code == status.HTTP_404_NOT_FOUND
|
|
1003
|
+
|
|
1004
|
+
|
|
1005
|
+
def test_gdm_score_bands(client: TestClient):
    """Test the GDM SCORE bands endpoints: initialize, vote, confirm, and re-initialize."""
    analyst_token = login(client=client)

    # Create a group for problem 3 (the discrete representation problem).
    create_group = GroupCreateRequest(group_name="group", problem_id=3).model_dump()
    response = post_json(client=client, endpoint="/gdm/create_group", json=create_group, access_token=analyst_token)
    assert response.status_code == status.HTTP_201_CREATED

    # Register a decision maker and add them to the new group.
    response = client.post(
        "/add_new_dm",
        data={"username": "dm", "password": "dm", "grant_type": "password"},
        headers={"content-type": "application/x-www-form-urlencoded"},
    )
    assert response.status_code == status.HTTP_201_CREATED

    add_member = GroupModifyRequest(group_id=1, user_id=2).model_dump()
    response = post_json(client=client, endpoint="/gdm/add_to_group", json=add_member, access_token=analyst_token)
    assert response.status_code == status.HTTP_200_OK

    # From here on, act as the decision maker.
    dm_token = login(client=client, username="dm", password="dm")  # noqa: S106

    def initialize(from_iteration):
        # Get or initialize SCORE bands (k-means, 5 clusters) starting from `from_iteration`.
        payload = GDMScoreBandsInitializationRequest(
            group_id=1,
            score_bands_config=SCOREBandsGDMConfig(
                score_bands_config=SCOREBandsConfig(clustering_algorithm=KMeansOptions(n_clusters=5)),
                from_iteration=from_iteration,
            ),
        ).model_dump()
        result = post_json(
            client=client,
            endpoint="/gdm-score-bands/get-or-initialize",
            json=payload,
            access_token=dm_token,
        )
        assert result.status_code == status.HTTP_200_OK
        return GDMSCOREBandsHistoryResponse.model_validate(result.json())

    first_round = initialize(None)
    clusters_before = len(first_round.history[-1].result.clusters)

    # Cast vote 4 and confirm it.
    vote = GDMScoreBandsVoteRequest(group_id=1, vote=4).model_dump()
    response = post_json(client=client, endpoint="/gdm-score-bands/vote", json=vote, access_token=dm_token)
    assert response.status_code == status.HTTP_200_OK

    confirm = GroupInfoRequest(group_id=1).model_dump()
    response = post_json(client=client, endpoint="/gdm-score-bands/confirm", json=confirm, access_token=dm_token)
    assert response.status_code == status.HTTP_200_OK

    # Re-initialize from the latest iteration of the first round.
    second_round = initialize(first_round.history[-1].latest_iteration)
    clusters_after = len(second_round.history[-1].result.clusters)

    # After one iteration the clustering, and therefore the set of active indices,
    # should be smaller than the first time around.
    assert clusters_before > clusters_after

    # TODO: Test reverting, re-clustering
|