desdeo 1.2__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/__init__.py +8 -8
- desdeo/adm/ADMAfsar.py +551 -0
- desdeo/adm/ADMChen.py +414 -0
- desdeo/adm/BaseADM.py +119 -0
- desdeo/adm/__init__.py +11 -0
- desdeo/api/README.md +73 -0
- desdeo/api/__init__.py +15 -0
- desdeo/api/app.py +50 -0
- desdeo/api/config.py +90 -0
- desdeo/api/config.toml +64 -0
- desdeo/api/db.py +27 -0
- desdeo/api/db_init.py +85 -0
- desdeo/api/db_models.py +164 -0
- desdeo/api/malaga_db_init.py +27 -0
- desdeo/api/models/__init__.py +266 -0
- desdeo/api/models/archive.py +23 -0
- desdeo/api/models/emo.py +128 -0
- desdeo/api/models/enautilus.py +69 -0
- desdeo/api/models/gdm/gdm_aggregate.py +139 -0
- desdeo/api/models/gdm/gdm_base.py +69 -0
- desdeo/api/models/gdm/gdm_score_bands.py +114 -0
- desdeo/api/models/gdm/gnimbus.py +138 -0
- desdeo/api/models/generic.py +104 -0
- desdeo/api/models/generic_states.py +401 -0
- desdeo/api/models/nimbus.py +158 -0
- desdeo/api/models/preference.py +128 -0
- desdeo/api/models/problem.py +717 -0
- desdeo/api/models/reference_point_method.py +18 -0
- desdeo/api/models/session.py +49 -0
- desdeo/api/models/state.py +463 -0
- desdeo/api/models/user.py +52 -0
- desdeo/api/models/utopia.py +25 -0
- desdeo/api/routers/_EMO.backup +309 -0
- desdeo/api/routers/_NAUTILUS.py +245 -0
- desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
- desdeo/api/routers/_NIMBUS.py +765 -0
- desdeo/api/routers/__init__.py +5 -0
- desdeo/api/routers/emo.py +497 -0
- desdeo/api/routers/enautilus.py +237 -0
- desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
- desdeo/api/routers/gdm/gdm_base.py +420 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
- desdeo/api/routers/generic.py +233 -0
- desdeo/api/routers/nimbus.py +705 -0
- desdeo/api/routers/problem.py +307 -0
- desdeo/api/routers/reference_point_method.py +93 -0
- desdeo/api/routers/session.py +100 -0
- desdeo/api/routers/test.py +16 -0
- desdeo/api/routers/user_authentication.py +520 -0
- desdeo/api/routers/utils.py +187 -0
- desdeo/api/routers/utopia.py +230 -0
- desdeo/api/schema.py +100 -0
- desdeo/api/tests/__init__.py +0 -0
- desdeo/api/tests/conftest.py +151 -0
- desdeo/api/tests/test_enautilus.py +330 -0
- desdeo/api/tests/test_models.py +1179 -0
- desdeo/api/tests/test_routes.py +1075 -0
- desdeo/api/utils/_database.py +263 -0
- desdeo/api/utils/_logger.py +29 -0
- desdeo/api/utils/database.py +36 -0
- desdeo/api/utils/emo_database.py +40 -0
- desdeo/core.py +34 -0
- desdeo/emo/__init__.py +159 -0
- desdeo/emo/hooks/archivers.py +188 -0
- desdeo/emo/methods/EAs.py +541 -0
- desdeo/emo/methods/__init__.py +0 -0
- desdeo/emo/methods/bases.py +12 -0
- desdeo/emo/methods/templates.py +111 -0
- desdeo/emo/operators/__init__.py +1 -0
- desdeo/emo/operators/crossover.py +1282 -0
- desdeo/emo/operators/evaluator.py +114 -0
- desdeo/emo/operators/generator.py +459 -0
- desdeo/emo/operators/mutation.py +1224 -0
- desdeo/emo/operators/scalar_selection.py +202 -0
- desdeo/emo/operators/selection.py +1778 -0
- desdeo/emo/operators/termination.py +286 -0
- desdeo/emo/options/__init__.py +108 -0
- desdeo/emo/options/algorithms.py +435 -0
- desdeo/emo/options/crossover.py +164 -0
- desdeo/emo/options/generator.py +131 -0
- desdeo/emo/options/mutation.py +260 -0
- desdeo/emo/options/repair.py +61 -0
- desdeo/emo/options/scalar_selection.py +66 -0
- desdeo/emo/options/selection.py +127 -0
- desdeo/emo/options/templates.py +383 -0
- desdeo/emo/options/termination.py +143 -0
- desdeo/explanations/__init__.py +6 -0
- desdeo/explanations/explainer.py +100 -0
- desdeo/explanations/utils.py +90 -0
- desdeo/gdm/__init__.py +22 -0
- desdeo/gdm/gdmtools.py +45 -0
- desdeo/gdm/score_bands.py +114 -0
- desdeo/gdm/voting_rules.py +50 -0
- desdeo/mcdm/__init__.py +41 -0
- desdeo/mcdm/enautilus.py +338 -0
- desdeo/mcdm/gnimbus.py +484 -0
- desdeo/mcdm/nautili.py +345 -0
- desdeo/mcdm/nautilus.py +477 -0
- desdeo/mcdm/nautilus_navigator.py +656 -0
- desdeo/mcdm/nimbus.py +417 -0
- desdeo/mcdm/pareto_navigator.py +269 -0
- desdeo/mcdm/reference_point_method.py +186 -0
- desdeo/problem/__init__.py +83 -0
- desdeo/problem/evaluator.py +561 -0
- desdeo/problem/external/__init__.py +18 -0
- desdeo/problem/external/core.py +356 -0
- desdeo/problem/external/pymoo_provider.py +266 -0
- desdeo/problem/external/runtime.py +44 -0
- desdeo/problem/gurobipy_evaluator.py +562 -0
- desdeo/problem/infix_parser.py +341 -0
- desdeo/problem/json_parser.py +944 -0
- desdeo/problem/pyomo_evaluator.py +487 -0
- desdeo/problem/schema.py +1829 -0
- desdeo/problem/simulator_evaluator.py +348 -0
- desdeo/problem/sympy_evaluator.py +244 -0
- desdeo/problem/testproblems/__init__.py +88 -0
- desdeo/problem/testproblems/benchmarks_server.py +120 -0
- desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
- desdeo/problem/testproblems/cake_problem.py +185 -0
- desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
- desdeo/problem/testproblems/dtlz2_problem.py +102 -0
- desdeo/problem/testproblems/forest_problem.py +283 -0
- desdeo/problem/testproblems/knapsack_problem.py +163 -0
- desdeo/problem/testproblems/mcwb_problem.py +831 -0
- desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
- desdeo/problem/testproblems/momip_problem.py +172 -0
- desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
- desdeo/problem/testproblems/nimbus_problem.py +143 -0
- desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
- desdeo/problem/testproblems/re_problem.py +492 -0
- desdeo/problem/testproblems/river_pollution_problems.py +440 -0
- desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
- desdeo/problem/testproblems/simple_problem.py +351 -0
- desdeo/problem/testproblems/simulator_problem.py +92 -0
- desdeo/problem/testproblems/single_objective.py +289 -0
- desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
- desdeo/problem/testproblems/zdt_problem.py +274 -0
- desdeo/problem/utils.py +245 -0
- desdeo/tools/GenerateReferencePoints.py +181 -0
- desdeo/tools/__init__.py +120 -0
- desdeo/tools/desc_gen.py +22 -0
- desdeo/tools/generics.py +165 -0
- desdeo/tools/group_scalarization.py +3090 -0
- desdeo/tools/gurobipy_solver_interfaces.py +258 -0
- desdeo/tools/indicators_binary.py +117 -0
- desdeo/tools/indicators_unary.py +362 -0
- desdeo/tools/interaction_schema.py +38 -0
- desdeo/tools/intersection.py +54 -0
- desdeo/tools/iterative_pareto_representer.py +99 -0
- desdeo/tools/message.py +265 -0
- desdeo/tools/ng_solver_interfaces.py +199 -0
- desdeo/tools/non_dominated_sorting.py +134 -0
- desdeo/tools/patterns.py +283 -0
- desdeo/tools/proximal_solver.py +99 -0
- desdeo/tools/pyomo_solver_interfaces.py +477 -0
- desdeo/tools/reference_vectors.py +229 -0
- desdeo/tools/scalarization.py +2065 -0
- desdeo/tools/scipy_solver_interfaces.py +454 -0
- desdeo/tools/score_bands.py +627 -0
- desdeo/tools/utils.py +388 -0
- desdeo/tools/visualizations.py +67 -0
- desdeo/utopia_stuff/__init__.py +0 -0
- desdeo/utopia_stuff/data/1.json +15 -0
- desdeo/utopia_stuff/data/2.json +13 -0
- desdeo/utopia_stuff/data/3.json +15 -0
- desdeo/utopia_stuff/data/4.json +17 -0
- desdeo/utopia_stuff/data/5.json +15 -0
- desdeo/utopia_stuff/from_json.py +40 -0
- desdeo/utopia_stuff/reinit_user.py +38 -0
- desdeo/utopia_stuff/utopia_db_init.py +212 -0
- desdeo/utopia_stuff/utopia_problem.py +403 -0
- desdeo/utopia_stuff/utopia_problem_old.py +415 -0
- desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
- desdeo-2.1.0.dist-info/METADATA +186 -0
- desdeo-2.1.0.dist-info/RECORD +180 -0
- {desdeo-1.2.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
- desdeo-2.1.0.dist-info/licenses/LICENSE +21 -0
- desdeo-1.2.dist-info/METADATA +0 -16
- desdeo-1.2.dist-info/RECORD +0 -4
|
@@ -0,0 +1,1075 @@
|
|
|
1
|
+
"""Tests related to routes and routers."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
|
|
6
|
+
from fastapi import status
|
|
7
|
+
from fastapi.testclient import TestClient
|
|
8
|
+
|
|
9
|
+
from desdeo.api.models import (
|
|
10
|
+
CreateSessionRequest,
|
|
11
|
+
EMOFetchRequest,
|
|
12
|
+
EMOIterateRequest,
|
|
13
|
+
EMOIterateResponse,
|
|
14
|
+
ForestProblemMetaData,
|
|
15
|
+
GDMSCOREBandsHistoryResponse,
|
|
16
|
+
GDMScoreBandsInitializationRequest,
|
|
17
|
+
GDMScoreBandsVoteRequest,
|
|
18
|
+
GenericIntermediateSolutionResponse,
|
|
19
|
+
GroupCreateRequest,
|
|
20
|
+
GroupInfoRequest,
|
|
21
|
+
GroupModifyRequest,
|
|
22
|
+
GroupPublic,
|
|
23
|
+
InteractiveSessionDB,
|
|
24
|
+
IntermediateSolutionRequest,
|
|
25
|
+
NIMBUSClassificationRequest,
|
|
26
|
+
NIMBUSClassificationResponse,
|
|
27
|
+
NIMBUSDeleteSaveRequest,
|
|
28
|
+
NIMBUSDeleteSaveResponse,
|
|
29
|
+
NIMBUSFinalizeRequest,
|
|
30
|
+
NIMBUSFinalizeResponse,
|
|
31
|
+
NIMBUSInitializationRequest,
|
|
32
|
+
NIMBUSIntermediateSolutionResponse,
|
|
33
|
+
NIMBUSSaveRequest,
|
|
34
|
+
NIMBUSSaveResponse,
|
|
35
|
+
ProblemDB,
|
|
36
|
+
ProblemGetRequest,
|
|
37
|
+
ProblemInfo,
|
|
38
|
+
ProblemSelectSolverRequest,
|
|
39
|
+
ReferencePoint,
|
|
40
|
+
RPMSolveRequest,
|
|
41
|
+
SolutionInfo,
|
|
42
|
+
SolverSelectionMetadata,
|
|
43
|
+
User,
|
|
44
|
+
UserPublic,
|
|
45
|
+
)
|
|
46
|
+
from desdeo.api.models.nimbus import NIMBUSInitializationResponse
|
|
47
|
+
from desdeo.api.routers.user_authentication import create_access_token
|
|
48
|
+
from desdeo.emo.options.algorithms import rvea_options
|
|
49
|
+
from desdeo.emo.options.templates import ReferencePointOptions
|
|
50
|
+
from desdeo.gdm.score_bands import SCOREBandsGDMConfig
|
|
51
|
+
from desdeo.problem import Problem
|
|
52
|
+
from desdeo.problem.testproblems import dtlz2, simple_knapsack_vectors
|
|
53
|
+
from desdeo.tools.score_bands import KMeansOptions, SCOREBandsConfig
|
|
54
|
+
from desdeo.tools.utils import available_solvers
|
|
55
|
+
|
|
56
|
+
from .conftest import get_json, login, post_file_multipart, post_json
|
|
57
|
+
from .test_models import compare_models
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def test_user_login(client: TestClient):
    """Valid credentials yield an access token; a wrong password is rejected with 401."""
    form_headers = {"content-type": "application/x-www-form-urlencoded"}

    def _attempt(password):
        # the login route expects an OAuth2-style password form
        return client.post(
            "/login",
            data={"username": "analyst", "password": password, "grant_type": "password"},
            headers=form_headers,
        )

    accepted = _attempt("analyst")
    assert accepted.status_code == 200
    assert "access_token" in accepted.json()

    # misspelled password must be refused
    rejected = _attempt("anallyst")
    assert rejected.status_code == 401
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def test_tokens():
    """Two access tokens minted from identical claims must differ."""
    tokens = [create_access_token({"id": 1, "sub": "analyst"}) for _ in range(2)]

    assert tokens[0] != tokens[1]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def test_refresh(client: TestClient):
    """Test that refreshing the access token works.

    A refresh attempt without the refresh-token cookie must be rejected with
    401. After a successful login, the client holds exactly one cookie,
    ``refresh_token``. A subsequent refresh must succeed and return an access
    token different from the one issued at login.
    """
    # check that no previous cookies exist
    assert len(client.cookies) == 0

    # no cookie -> refresh must be refused
    response_bad = client.post("/refresh")
    assert response_bad.status_code == status.HTTP_401_UNAUTHORIZED

    response_good = client.post(
        "/login",
        data={"username": "analyst", "password": "analyst", "grant_type": "password"},
        headers={"content-type": "application/x-www-form-urlencoded"},
    )
    assert response_good.status_code == status.HTTP_200_OK
    assert "access_token" in response_good.json()

    # the login response should have stored exactly the refresh-token cookie
    assert len(client.cookies) == 1
    assert "refresh_token" in client.cookies

    response_refresh = client.post("/refresh")
    # check the refresh itself succeeded before inspecting its body
    assert response_refresh.status_code == status.HTTP_200_OK
    assert "access_token" in response_refresh.json()

    assert response_good.json()["access_token"] != response_refresh.json()["access_token"]
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def test_get_problem(client: TestClient):
    """Test fetching specific problems based on their id.

    Problem 1 is expected to be ``dtlz2`` without metadata. Problem 2 is
    expected to be the river pollution problem, whose metadata contains
    ``ForestProblemMetaData`` entries.
    """
    access_token = login(client)

    response = post_json(client, "/problem/get", ProblemGetRequest(problem_id=1).model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK

    info = ProblemInfo.model_validate(response.json())
    assert info.id == 1
    assert info.name == "dtlz2"
    assert info.problem_metadata is None

    # same absolute route as the first request
    response = post_json(client, "/problem/get", ProblemGetRequest(problem_id=2).model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK

    info = ProblemInfo.model_validate(response.json())
    assert info.id == 2
    assert info.name == "The river pollution problem"
    # fail as an assertion, not an AttributeError, if metadata is missing
    assert info.problem_metadata is not None
    assert isinstance(info.problem_metadata.forest_metadata[0], ForestProblemMetaData)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def test_add_problem(client: TestClient):
    """Adding the simple knapsack problem succeeds and raises the problem count to four."""
    token = login(client)

    knapsack = simple_knapsack_vectors()
    add_resp = post_json(client, "/problem/add", knapsack.model_dump(), token)
    assert add_resp.status_code == status.HTTP_200_OK

    added: ProblemInfo = ProblemInfo.model_validate(add_resp.json())
    assert added.name == "Simple two-objective Knapsack problem"

    listing = get_json(client, "problem/all_info", token)
    assert listing.status_code == status.HTTP_200_OK
    assert len(listing.json()) == 4
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def test_add_problem_json(client: TestClient, session_and_user: dict):
    """Test that adding a problem to the database works with JSON files.

    The problem is dumped to JSON, encoded as UTF-8 bytes and uploaded as a
    multipart file. It is then read back from the database and compared with
    the original model.
    """
    session = session_and_user["session"]
    access_token = login(client)
    problem = dtlz2(5, 3)

    raw = json.dumps(problem.model_dump()).encode("utf-8")

    response = post_file_multipart(client, "/problem/add_json", raw, access_token)
    assert response.status_code == status.HTTP_200_OK

    # NOTE(review): assumes the uploaded problem receives id 4 in the test DB — confirm against fixtures
    problem_from_db = session.get(ProblemDB, 4)
    # give a clear failure instead of passing None into from_problemdb
    assert problem_from_db is not None, "uploaded problem was not stored with id 4"

    assert compare_models(problem, Problem.from_problemdb(problem_from_db))
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def test_new_session(client: TestClient, session_and_user: dict):
    """Creating a session persists it and makes it the user's active session."""
    user: User = session_and_user["user"]
    db_session = session_and_user["session"]

    # fresh user has no active session yet
    assert user.active_session_id is None

    token = login(client)
    payload = CreateSessionRequest(info="My session").model_dump()

    resp = post_json(client, "/session/new", payload, token)
    assert resp.status_code == status.HTTP_200_OK

    assert user.active_session_id == 1
    stored = db_session.get(InteractiveSessionDB, 1)
    assert stored.info == "My session"
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def test_get_session(client: TestClient, session_and_user: dict):
    """Sessions can be fetched by id once created; an unknown id gives 404."""
    user: User = session_and_user["user"]
    token = login(client)
    auth = {"Authorization": f"Bearer {token}"}

    # no sessions exist yet
    missing = client.get("/session/get/1", headers=auth)
    assert missing.status_code == status.HTTP_404_NOT_FOUND

    # each newly created session becomes the active one
    for expected_id, label in ((1, "Session 1"), (2, "Session 2")):
        created = post_json(client, "/session/new", CreateSessionRequest(info=label).model_dump(), token)
        assert created.status_code == status.HTTP_200_OK
        assert user.active_session_id == expected_id

    # both sessions are retrievable by their id
    for expected_id in (1, 2):
        fetched = client.get(f"/session/get/{expected_id}", headers=auth)
        assert fetched.status_code == status.HTTP_200_OK
        assert fetched.json()["id"] == expected_id
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def test_get_all_sessions_success(client: TestClient, session_and_user: dict):
    """Test getting all sessions when sessions exist.

    Two sessions are created for the logged-in user. ``/session/get_all``
    must then return a list with exactly two entries.
    """
    access_token = login(client)

    # create 2 test sessions; check each one so a creation failure is not
    # misreported as a wrong count below
    for info in ("S1", "S2"):
        created = post_json(client, "/session/new", {"info": info}, access_token)
        assert created.status_code == status.HTTP_200_OK

    response = client.get(
        "/session/get_all",
        headers={"Authorization": f"Bearer {access_token}"},
    )

    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    assert isinstance(data, list)
    assert len(data) == 2
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def test_get_all_sessions_not_found(client: TestClient, session_and_user: dict):
    """A user who has created no sessions gets 404 from ``/session/get_all``."""
    token = login(client)
    auth_headers = {"Authorization": f"Bearer {token}"}

    resp = client.get("/session/get_all", headers=auth_headers)

    assert resp.status_code == status.HTTP_404_NOT_FOUND
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def test_delete_session_success(client: TestClient, session_and_user: dict):
    """Test deleting an existing session.

    A session is created, then deleted via ``DELETE /session/1`` (expecting
    204). Fetching the same id afterwards must return 404.
    """
    access_token = login(client)
    headers = {"Authorization": f"Bearer {access_token}"}

    # create session; verify creation so the deletion check below is meaningful
    created = post_json(client, "/session/new", {"info": "To delete"}, access_token)
    assert created.status_code == status.HTTP_200_OK

    response = client.delete("/session/1", headers=headers)
    assert response.status_code == status.HTTP_204_NO_CONTENT

    # verify it's gone
    response = client.get("/session/get/1", headers=headers)
    assert response.status_code == status.HTTP_404_NOT_FOUND
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def test_delete_session_not_found(client: TestClient, session_and_user: dict):
    """Deleting a session id that was never created yields 404."""
    token = login(client)

    outcome = client.delete("/session/999", headers={"Authorization": f"Bearer {token}"})

    assert outcome.status_code == status.HTTP_404_NOT_FOUND
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def test_rpm_solve(client: TestClient):
    """Solving problem 1 with the reference point method returns HTTP 200."""
    token = login(client)

    reference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.3, "f_3": 0.4})
    payload = RPMSolveRequest(problem_id=1, preference=reference).model_dump()

    outcome = post_json(client, "/method/rpm/solve", payload, token)
    assert outcome.status_code == status.HTTP_200_OK
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def test_nimbus_solve(client: TestClient):
    """Test that using the NIMBUS method works as expected.

    Runs three classification rounds on problem 1, saving solutions in
    between. Checks that duplicate saves are filtered out and that
    ``all_solutions`` does not count saved solutions twice.
    """
    access_token = login(client)

    def _solve(request: NIMBUSClassificationRequest) -> NIMBUSClassificationResponse:
        # POST a classification request and parse the successful response
        response = post_json(client, "/method/nimbus/solve", request.model_dump(), access_token)
        assert response.status_code == status.HTTP_200_OK
        return NIMBUSClassificationResponse.model_validate(response.json())

    def _save(request: NIMBUSSaveRequest) -> NIMBUSSaveResponse:
        # POST a save request and parse the successful response
        response = post_json(client, "/method/nimbus/save", request.model_dump(), access_token)
        assert response.status_code == status.HTTP_200_OK
        return NIMBUSSaveResponse.model_validate(response.json())

    # Round 1: classify from an arbitrary starting point.
    preference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.6, "f_3": 0.4})
    solve_1 = _solve(
        NIMBUSClassificationRequest(
            problem_id=1,
            preference=preference,
            current_objectives={"f_1": 0.6, "f_2": 0.4, "f_3": 0.5},
            num_desired=3,
        )
    )
    assert solve_1.previous_preference == preference
    assert len(solve_1.all_solutions) == 3

    # Save solutions from round 1, referencing the state id the solve returned
    # rather than a hard-coded id. Index 2 is saved twice on purpose.
    saved_1 = _save(
        NIMBUSSaveRequest(
            problem_id=1,
            parent_state_id=solve_1.state_id,
            solution_info=[
                SolutionInfo(state_id=solve_1.state_id, solution_index=0, name="Test solution 1"),
                SolutionInfo(state_id=solve_1.state_id, solution_index=2, name="Test solution 3"),
                SolutionInfo(state_id=solve_1.state_id, solution_index=2, name="Test solution 34"),  # saved twice!
            ],
        )
    )
    assert saved_1.state_id is not None

    # Round 2: a different preference, continuing from the save state.
    preference = ReferencePoint(aspiration_levels={"f_1": 0.1, "f_2": 0.1, "f_3": 0.9})
    solve_2 = _solve(
        NIMBUSClassificationRequest(
            problem_id=1,
            preference=preference,
            current_objectives=solve_1.current_solutions[0].objective_values,
            num_desired=3,
            parent_state_id=saved_1.state_id,
        )
    )
    assert solve_2.previous_preference == preference
    # We saved the same solution twice, so the filtering should remove one of those.
    assert len(solve_2.saved_solutions) == 2
    assert len(solve_2.all_solutions) == 6  # should not count saved solutions twice

    # Save some more solutions!
    saved_2 = _save(
        NIMBUSSaveRequest(
            problem_id=1,
            parent_state_id=solve_2.state_id,
            solution_info=[SolutionInfo(state_id=solve_2.state_id, solution_index=1, name="Test solution 2")],
        )
    )
    assert saved_2.state_id is not None

    # Round 3: same preference as round 1. The original author noted that STOM
    # and ASF are believed to give the same solutions as before. That should
    # be reflected in the total number of solutions.
    preference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.6, "f_3": 0.4})
    solve_3 = _solve(
        NIMBUSClassificationRequest(
            problem_id=1,
            preference=preference,
            current_objectives=solve_2.current_solutions[0].objective_values,
            num_desired=3,
            parent_state_id=saved_2.state_id,
        )
    )
    assert solve_3.previous_preference == preference
    assert len(solve_3.saved_solutions) == 3
    assert len(solve_3.all_solutions) == 7
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def test_intermediate_solve(client: TestClient):
    """Test that solving intermediate solutions works as expected.

    Two solutions from a NIMBUS classification serve as references for both
    the generic endpoint and the NIMBUS-specific intermediate-solution
    endpoint.
    """
    access_token = login(client)

    preference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.6, "f_3": 0.4})
    request = NIMBUSClassificationRequest(
        problem_id=1, preference=preference, current_objectives={"f_1": 0.6, "f_2": 0.4, "f_3": 0.5}, num_desired=2
    )

    response = post_json(client, "/method/nimbus/solve", request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK
    result = NIMBUSClassificationResponse.model_validate(response.json())
    assert len(result.all_solutions) == 2

    # Refer to the two solutions just computed (one unnamed, one named);
    # nothing is saved here, these are only references for the requests below.
    solution_1 = SolutionInfo(state_id=result.state_id, solution_index=0)
    solution_2 = SolutionInfo(state_id=result.state_id, solution_index=1, name="named solution")

    # Test the generic intermediate endpoint
    generic_request = IntermediateSolutionRequest(
        problem_id=1,
        context="test",
        reference_solution_1=solution_1,
        reference_solution_2=solution_2,
        num_desired=3,
    )
    generic_response = post_json(client, "/method/generic/intermediate", generic_request.model_dump(), access_token)
    assert generic_response.status_code == status.HTTP_200_OK
    # Only validate the payload shape, without rebinding `result` to an unrelated type.
    GenericIntermediateSolutionResponse.model_validate(generic_response.json())

    # Test the NIMBUS-specific intermediate endpoint
    nimbus_request = IntermediateSolutionRequest(
        problem_id=1,
        context="nimbus",
        reference_solution_1=solution_1,
        reference_solution_2=solution_2,
        num_desired=2,
    )

    nimbus_response = post_json(client, "/method/nimbus/intermediate", nimbus_request.model_dump(), access_token)
    assert nimbus_response.status_code == status.HTTP_200_OK
    nimbus_result = NIMBUSIntermediateSolutionResponse.model_validate(nimbus_response.json())

    # Verify the NIMBUS response contains expected fields
    assert nimbus_result.state_id is not None
    assert len(nimbus_result.current_solutions) == 2
    assert len(nimbus_result.all_solutions) == 7
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def test_nimbus_initialize(client: TestClient):
    """Initialize NIMBUS three ways: without a starting point, from a stored solution, and from a reference point.

    No solver is specified in any of the requests.
    """
    token = login(client)

    def _initialize(req):
        # POST the initialization request and parse the successful response
        resp = post_json(client, "/method/nimbus/initialize", req.model_dump(), token)
        assert resp.status_code == status.HTTP_200_OK
        return NIMBUSInitializationResponse.model_validate(json.loads(resp.content))

    def _check(res, expected_state, expected_total):
        # every initialization yields one current and no saved solutions
        assert res.state_id == expected_state
        assert len(res.current_solutions) == 1
        assert len(res.saved_solutions) == 0
        assert len(res.all_solutions) == expected_total

    # no starting point
    _check(_initialize(NIMBUSInitializationRequest(problem_id=1, solver=None)), 1, 1)

    # starting point given as solution info; the resulting solution duplicates
    # the existing one, so the total stays at one
    from_info = NIMBUSInitializationRequest(problem_id=1, starting_point=SolutionInfo(state_id=1, solution_index=0))
    _check(_initialize(from_info), 2, 1)

    # starting point given as a reference point; this adds a new solution
    from_ref = NIMBUSInitializationRequest(
        problem_id=1, starting_point=ReferencePoint(aspiration_levels={"f_1": 0.2, "f_2": 0.8, "f_3": 0.4})
    )
    _check(_initialize(from_ref), 3, 2)
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def test_nimbus_finalize(client: TestClient):
    """Test for seeing if NIMBUS finalization works.

    Runs one NIMBUS classification iteration, finalizes with one of the produced
    solutions, and checks that both the ``finalize`` endpoint and a subsequent
    ``get-or-initialize`` call return a finalize response carrying that solution.
    """
    access_token = login(client)

    # create some previous iterations
    init_request = NIMBUSInitializationRequest(problem_id=1)
    response = post_json(client, "/method/nimbus/get-or-initialize", init_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK
    init_response = NIMBUSInitializationResponse.model_validate(json.loads(response.content))
    assert init_response.state_id == 1
    assert len(init_response.current_solutions) == 1
    assert len(init_response.saved_solutions) == 0
    assert len(init_response.all_solutions) == 1

    preference = ReferencePoint(aspiration_levels={"f_1": 0.5, "f_2": 0.6, "f_3": 0.4})

    classification_request = NIMBUSClassificationRequest(
        problem_id=1, preference=preference, current_objectives={"f_1": 0.6, "f_2": 0.4, "f_3": 0.5}, num_desired=3
    )

    response = post_json(client, "/method/nimbus/solve", classification_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK
    solve_result = NIMBUSClassificationResponse.model_validate(json.loads(response.content.decode("utf-8")))
    assert solve_result.previous_preference == preference
    assert len(solve_result.current_solutions) == 3

    # Remember the solution we are going to finalize with.
    solution_index = 2
    optim_obj = solve_result.current_solutions[solution_index].objective_values
    optim_var = solve_result.current_solutions[solution_index].variable_values
    state_id = solve_result.state_id

    # get-or-initialize now returns the latest (classification) state instead of a fresh one.
    response = post_json(client, "/method/nimbus/get-or-initialize", init_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK
    classify_result = NIMBUSClassificationResponse.model_validate(json.loads(response.content))
    assert classify_result.state_id == 2
    assert len(classify_result.current_solutions) == 3
    assert len(classify_result.saved_solutions) == 0
    assert len(classify_result.all_solutions) == 4

    finalize_request = NIMBUSFinalizeRequest(
        problem_id=1,
        solution_info=SolutionInfo(state_id=state_id, solution_index=solution_index),
    )

    def check_finalize_response(resp) -> None:
        """Assert that ``resp`` is a successful finalize response for the chosen solution."""
        assert resp.status_code == status.HTTP_200_OK
        finalize_result = NIMBUSFinalizeResponse.model_validate(json.loads(resp.content.decode("utf-8")))
        assert finalize_result.response_type == "nimbus.finalize"
        assert finalize_result.final_solution.objective_values == optim_obj
        assert finalize_result.final_solution.variable_values == optim_var
        assert finalize_result.final_solution.state_id != finalize_result.state_id
        assert finalize_result.all_solutions is not None

    # Finalize the process
    check_finalize_response(
        post_json(client, "/method/nimbus/finalize", finalize_request.model_dump(), access_token)
    )

    # The last item in the pipe is a finalize state, so we should be getting a finalize response.
    check_finalize_response(
        post_json(client, "/method/nimbus/get-or-initialize", init_request.model_dump(), access_token)
    )
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def test_nimbus_save_and_delete_save(client: TestClient):
    """Test that NIMBUS saving and save deletion works.

    Initializes NIMBUS, iterates once, saves one of the produced solutions, checks
    that the save shows up in the next iteration, deletes it, and checks that the
    following iteration no longer reports any saved solutions.
    """
    access_token = login(client)

    # 1. Initialize
    init_request = NIMBUSInitializationRequest(problem_id=1)
    response = post_json(client, "/method/nimbus/initialize", init_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK, response.text
    init_result = NIMBUSInitializationResponse.model_validate(json.loads(response.content))
    assert init_result.state_id == 1

    # 2. Iterate
    solve_request = NIMBUSClassificationRequest(
        problem_id=1,
        preference=ReferencePoint(
            aspiration_levels={
                "f_1": 0.1,
                "f_2": 0.8,
                "f_3": 0.5,
            }
        ),
        current_objectives=init_result.current_solutions[0].objective_values,
        parent_state_id=1,
        num_desired=3,
    )
    response = post_json(client, "/method/nimbus/solve", solve_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK, response.text
    solve_result = NIMBUSClassificationResponse.model_validate(json.loads(response.content))
    assert solve_result.state_id == 2

    # 3. Save the solution at index 1 of the iteration above (state 2)
    save_request = NIMBUSSaveRequest(
        problem_id=1, parent_state_id=2, solution_info=[SolutionInfo(state_id=2, solution_index=1)]
    )
    response = post_json(client, "/method/nimbus/save", save_request.model_dump(), access_token)
    # Only require a 2xx here: the exact success code of this endpoint is not pinned elsewhere.
    assert response.is_success, response.text
    save_result = NIMBUSSaveResponse.model_validate(json.loads(response.content))
    assert save_result.state_id == 3

    # 4. Iterate from the save state and assert that the saved solution is reported
    solve_request = NIMBUSClassificationRequest(
        problem_id=1,
        preference=ReferencePoint(
            aspiration_levels={
                "f_1": 0.9,
                "f_2": 0.1,
                "f_3": 0.5,
            }
        ),
        current_objectives=solve_result.current_solutions[0].objective_values,
        num_desired=1,
        parent_state_id=3,
    )
    response = post_json(client, "/method/nimbus/solve", solve_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK, response.text
    solve_result = NIMBUSClassificationResponse.model_validate(json.loads(response.content))
    assert solve_result.state_id == 4
    assert len(solve_result.saved_solutions) > 0

    # 5. Delete the save
    delete_request = NIMBUSDeleteSaveRequest(state_id=2, solution_index=1)
    response = post_json(client, "/method/nimbus/delete_save", delete_request.model_dump(), access_token)
    assert response.is_success, response.text
    delete_save_result = NIMBUSDeleteSaveResponse.model_validate(json.loads(response.content))
    assert delete_save_result

    # 6. Iterate once more and assert that the saved solution has been deleted
    solve_request = NIMBUSClassificationRequest(
        problem_id=1,
        preference=ReferencePoint(
            aspiration_levels={
                "f_1": 0.1,
                "f_2": 0.9,
                "f_3": 0.4,
            }
        ),
        current_objectives=solve_result.current_solutions[0].objective_values,
        num_desired=1,
        parent_state_id=4,
    )
    response = post_json(client, "/method/nimbus/solve", solve_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK, response.text
    solve_result = NIMBUSClassificationResponse.model_validate(json.loads(response.content))
    assert solve_result.state_id == 5
    assert len(solve_result.saved_solutions) == 0
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
def test_add_new_dm(client: TestClient):
    """Check that a decision maker can be registered exactly once.

    The first registration of ``new_dm`` must succeed with 201; a second attempt
    with the same username must be rejected with 409.
    """
    form_data = {"username": "new_dm", "password": "new_dm", "grant_type": "password"}
    form_headers = {"content-type": "application/x-www-form-urlencoded"}

    # First registration creates the user.
    first_attempt = client.post("/add_new_dm", data=form_data, headers=form_headers)
    assert first_attempt.status_code == status.HTTP_201_CREATED

    # The username is now taken, so registering it again must conflict.
    second_attempt = client.post("/add_new_dm", data=form_data, headers=form_headers)
    assert second_attempt.status_code == status.HTTP_409_CONFLICT
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
def test_add_new_analyst(client: TestClient):
    """Check who may register a new analyst account.

    Anonymous requests and requests made with a decision maker's token are rejected
    with 401. An analyst can create a new analyst (201). Reusing an existing analyst
    username is rejected with 409.
    """
    endpoint = "/add_new_analyst"
    analyst_form = {"username": "new_analyst", "password": "new_analyst", "grant_type": "password"}
    form_type = {"content-type": "application/x-www-form-urlencoded"}

    def bearer(token: str) -> dict:
        """Build form headers carrying ``token`` as a bearer credential."""
        return {"Authorization": f"Bearer {token}", **form_type}

    # Without any credentials the request is rejected.
    anonymous = client.post(endpoint, data=analyst_form, headers=form_type)
    assert anonymous.status_code == status.HTTP_401_UNAUTHORIZED

    # Register a decision maker so we can try creating an analyst with that account.
    dm_signup = client.post(
        "/add_new_dm",
        data={"username": "new_dm", "password": "new_dm", "grant_type": "password"},
        headers=form_type,
    )
    assert dm_signup.status_code == status.HTTP_201_CREATED

    dm_token = login(client, username="new_dm", password="new_dm")  # noqa: S106

    # A decision maker may not create analysts; the API answers 401.
    from_dm = client.post(endpoint, data=analyst_form, headers=bearer(dm_token))
    assert from_dm.status_code == status.HTTP_401_UNAUTHORIZED

    # Log in as the default (analyst) test user, who is allowed to create analysts.
    analyst_token = login(client)

    created = client.post(endpoint, data=analyst_form, headers=bearer(analyst_token))
    assert created.status_code == status.HTTP_201_CREATED

    # The username is now in use, so a second attempt conflicts.
    duplicate = client.post(endpoint, data=analyst_form, headers=bearer(analyst_token))
    assert duplicate.status_code == status.HTTP_409_CONFLICT
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
def test_login_logout(client: TestClient):
    """Check that logging out removes the refresh token so it can no longer be used."""
    # Logging in stores the refresh token as a cookie on the client.
    login(client=client, username="analyst", password="analyst")  # noqa: S106

    # While logged in, the refresh endpoint issues a new access token.
    refreshed = client.post("/refresh")
    assert refreshed.status_code == status.HTTP_200_OK

    # Logging out removes the refresh token cookie...
    logged_out = client.post("/logout")
    assert logged_out.status_code == status.HTTP_200_OK

    # ...so a subsequent refresh attempt is rejected.
    rejected = client.post("/refresh", headers={"content-type": "application/x-www-form-urlencoded"})
    assert rejected.status_code == status.HTTP_401_UNAUTHORIZED
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
def test_group_operations(client: TestClient):
    """Test the group management endpoints.

    Covers group creation (rejected for a nonexistent problem), adding a decision
    maker to a group and removing them again, checking membership from both the
    group's and the users' side, and deleting the group.
    """
    create_endpoint = "/gdm/create_group"
    delete_endpoint = "/gdm/delete_group"
    add_user_endpoint = "/gdm/add_to_group"
    remove_user_endpoint = "/gdm/remove_from_group"
    group_info_endpoint = "/gdm/get_group_info"

    # login to analyst
    access_token = login(client=client, username="analyst", password="analyst")  # noqa: S106

    def get_info(gid: int):
        """Request the info of group ``gid`` as the analyst and return the raw response."""
        return post_json(
            client=client,
            endpoint=group_info_endpoint,
            json=GroupInfoRequest(
                group_id=gid,
            ).model_dump(),
            access_token=access_token,
        )

    def get_group(gid: int) -> GroupPublic:
        """Fetch group ``gid``, asserting the request succeeded."""
        response = get_info(gid)
        assert response.status_code == status.HTTP_200_OK
        return GroupPublic.model_validate(json.loads(response.content.decode("utf-8")))

    def get_user_info(token: str):
        """Request the info of the user owning ``token`` and return the raw response."""
        return client.get(
            "/user_info",
            headers={
                "Authorization": f"Bearer {token}",
                "content-type": "application/x-www-form-urlencoded",
            },
        )

    def get_group_ids(token: str):
        """Return the group ids of the user owning ``token``, asserting the request succeeded."""
        response = get_user_info(token)
        assert response.status_code == status.HTTP_200_OK
        return UserPublic.model_validate(json.loads(response.content.decode("utf-8"))).group_ids

    # Creating a group for a problem that does not exist must fail.
    response = post_json(
        client=client,
        endpoint=create_endpoint,
        json=GroupCreateRequest(group_name="testGroup", problem_id=10).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_404_NOT_FOUND

    # Create group properly
    response = post_json(
        client=client,
        endpoint=create_endpoint,
        json=GroupCreateRequest(group_name="testGroup", problem_id=2).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_201_CREATED

    # Add a user to database (presumably assigned user id 2; the analyst is user 1).
    response = client.post(
        "/add_new_dm",
        data={"username": "new_dm", "password": "new_dm", "grant_type": "password"},
        headers={"content-type": "application/x-www-form-urlencoded"},
    )
    assert response.status_code == status.HTTP_201_CREATED

    # Add user to a group
    response = post_json(
        client=client,
        endpoint=add_user_endpoint,
        json=GroupModifyRequest(group_id=1, user_id=2).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_200_OK
    group = get_group(1)
    assert 2 in group.user_ids
    assert 1 not in group.user_ids

    assert 1 in get_group_ids(access_token)

    dm_access_token = login(client, "new_dm", "new_dm")  # noqa: S106
    assert 1 in get_group_ids(dm_access_token)

    # TODO: websocket testing and result fetching?

    # Remove user from a group
    response = post_json(
        client=client,
        endpoint=remove_user_endpoint,
        json=GroupModifyRequest(group_id=1, user_id=2).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_200_OK
    assert 2 not in get_group(1).user_ids

    assert 1 not in get_group_ids(dm_access_token)
    assert 1 in get_group_ids(access_token)

    # Delete the group
    response = post_json(
        client=client,
        endpoint=delete_endpoint,
        json=GroupInfoRequest(
            group_id=1,
        ).model_dump(),
        access_token=access_token,
    )
    assert response.status_code == status.HTTP_200_OK
    assert get_info(1).status_code == status.HTTP_404_NOT_FOUND

    assert 1 not in get_group_ids(access_token)
|
|
904
|
+
|
|
905
|
+
|
|
906
|
+
def test_preferred_solver(client: TestClient):
    """Test that setting a preferred solver for the problem is ok.

    Assigning an unknown solver must fail with 404. Assigning ``pyomo_cbc`` must
    succeed and show up in the problem's solver selection metadata. NIMBUS
    initialization must then fail, which shows the assigned solver is actually used.
    """
    access_token = login(client)

    request = ProblemSelectSolverRequest(problem_id=1, solver_string_representation="THIS SOLVER DOESN'T EXIST")
    response = post_json(client, "/problem/assign_solver", request.model_dump(), access_token)
    assert response.status_code == status.HTTP_404_NOT_FOUND

    request = ProblemSelectSolverRequest(problem_id=1, solver_string_representation="pyomo_cbc")
    response = post_json(client, "/problem/assign_solver", request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK

    metadata_request = {"problem_id": 1, "metadata_type": "solver_selection_metadata"}
    response = post_json(client, "/problem/get_metadata", metadata_request, access_token)
    assert response.status_code == status.HTTP_200_OK

    metadata = SolverSelectionMetadata.model_validate(response.json()[0])

    assert metadata.metadata_type == "solver_selection_metadata"
    assert metadata.solver_string_representation == "pyomo_cbc"

    # Test that the solver is in use. pyomo_cbc does not support nonlinear problems, so
    # initializing NIMBUS for this problem is expected to fail. Previously any outcome
    # passed; a successful initialization would mean the assigned solver was ignored.
    init_request = NIMBUSInitializationRequest(problem_id=1)
    try:
        response = post_json(client, "/method/nimbus/initialize", init_request.model_dump(), access_token)
        NIMBUSInitializationResponse.model_validate(response.json())
    except Exception as e:  # noqa: BLE001 - the failure itself is the expected outcome
        print(e)
        print("^ This outcome is expected since pyomo_cbc doesn't support nonlinear problems.")
        print(" As that solver is what we set it to be in the start, we can verify that they actually get used.")
    else:
        raise AssertionError(
            "NIMBUS initialization succeeded with pyomo_cbc assigned; the preferred solver was apparently not used."
        )
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
def test_get_available_solvers(client: TestClient):
    """Check that the API lists exactly the registered solvers."""
    response = client.get("/problem/assign/solver")
    assert response.status_code == status.HTTP_200_OK

    solver_names = response.json()
    assert isinstance(solver_names, list)

    # Compare as sets: only the membership matters, not the order.
    expected_names = set(available_solvers.keys())
    assert set(solver_names) == expected_names
|
|
949
|
+
|
|
950
|
+
|
|
951
|
+
def test_emo_solve_with_reference_point(client: TestClient):
    """Test that using EMO with reference point works as expected.

    Currently skipped rather than silently returning, so the test run reports it as
    disabled instead of as passing.
    """
    import pytest  # local import keeps this block self-contained; pytest is present whenever tests run

    # TODO: This test fails because of websocket issues. Fix those and re-enable the test.
    pytest.skip("EMO reference point test disabled until the EMO websocket issues are fixed.")

    access_token = login(client)
    request = EMOIterateRequest(
        problem_id=1,
        template_options=[rvea_options.template],
        preference_options=ReferencePointOptions(preference={"f_1": 0.5, "f_2": 0.3, "f_3": 0.4}, method="Hakanen"),
    )

    response = post_json(client, "/method/emo/iterate", request.model_dump(), access_token)

    assert response.status_code == status.HTTP_200_OK

    # Validate the response structure
    emo_response = EMOIterateResponse.model_validate(response.json())
    assert emo_response.client_id is not None
    state_id = emo_response.state_id

    initial_time = time.time()
    with client.websocket_connect(f"/method/emo/ws/{emo_response.client_id}") as websocket:
        # NOTE(review): receive_json() blocks, so the 10 s limit is only checked between messages.
        while time.time() - initial_time < 10:
            message = websocket.receive_json()
            if message.get("message") == f"Finished {emo_response.method_ids[0]}":
                break

    # Fetch the state to verify it worked
    fetch_request = EMOFetchRequest(problem_id=1, parent_state_id=state_id)
    response = post_json(client, "/method/emo/fetch", fetch_request.model_dump(), access_token)
    assert response.status_code == status.HTTP_200_OK
|
|
980
|
+
|
|
981
|
+
|
|
982
|
+
def test_get_problem_metadata(client: TestClient):
    """Check fetching forest metadata for a problem without it, one with it, and a missing problem."""
    access_token = login(client=client)

    def fetch_forest_metadata(problem_id: int):
        """Request the forest metadata of ``problem_id`` and return the raw response."""
        return post_json(
            client=client,
            endpoint="/problem/get_metadata",
            json={"problem_id": problem_id, "metadata_type": "forest_problem_metadata"},
            access_token=access_token,
        )

    # Problem 1 has no forest metadata, so the result list is empty.
    without_metadata = fetch_forest_metadata(1)
    assert without_metadata.status_code == status.HTTP_200_OK
    assert without_metadata.json() == []

    # Problem 2 carries forest metadata.
    with_metadata = fetch_forest_metadata(2)
    assert with_metadata.status_code == status.HTTP_200_OK
    first_entry = with_metadata.json()[0]
    assert first_entry["metadata_type"] == "forest_problem_metadata"
    assert first_entry["schedule_dict"] == {"type": "dict"}

    # There is no problem with id 4.
    missing_problem = fetch_forest_metadata(4)
    assert missing_problem.status_code == status.HTTP_404_NOT_FOUND
|
|
1003
|
+
|
|
1004
|
+
|
|
1005
|
+
def test_gdm_score_bands(client: TestClient):
|
|
1006
|
+
"""Test score bands endpoints."""
|
|
1007
|
+
access_token = login(client=client)
|
|
1008
|
+
|
|
1009
|
+
# create group
|
|
1010
|
+
req = GroupCreateRequest(
|
|
1011
|
+
group_name="group",
|
|
1012
|
+
problem_id=3, # The discrete representation problem
|
|
1013
|
+
).model_dump()
|
|
1014
|
+
response = post_json(client=client, endpoint="/gdm/create_group", json=req, access_token=access_token)
|
|
1015
|
+
assert response.status_code == 201
|
|
1016
|
+
|
|
1017
|
+
# Add a dm to the group
|
|
1018
|
+
# Create a new user to the database
|
|
1019
|
+
response = client.post(
|
|
1020
|
+
"/add_new_dm",
|
|
1021
|
+
data={"username": "dm", "password": "dm", "grant_type": "password"},
|
|
1022
|
+
headers={"content-type": "application/x-www-form-urlencoded"},
|
|
1023
|
+
)
|
|
1024
|
+
assert response.status_code == 201
|
|
1025
|
+
|
|
1026
|
+
req = GroupModifyRequest(group_id=1, user_id=2).model_dump()
|
|
1027
|
+
response = post_json(client=client, endpoint="/gdm/add_to_group", json=req, access_token=access_token)
|
|
1028
|
+
assert response.status_code == 200
|
|
1029
|
+
|
|
1030
|
+
access_token = login(client=client, username="dm", password="dm")
|
|
1031
|
+
|
|
1032
|
+
# Now we have a group, so let's get on with making stuff with gdm score bands.
|
|
1033
|
+
req = GDMScoreBandsInitializationRequest(
|
|
1034
|
+
group_id=1,
|
|
1035
|
+
score_bands_config=SCOREBandsGDMConfig(
|
|
1036
|
+
score_bands_config=SCOREBandsConfig(clustering_algorithm=KMeansOptions(n_clusters=5)), from_iteration=None
|
|
1037
|
+
),
|
|
1038
|
+
).model_dump()
|
|
1039
|
+
response = post_json(
|
|
1040
|
+
client=client, endpoint="/gdm-score-bands/get-or-initialize", json=req, access_token=access_token
|
|
1041
|
+
)
|
|
1042
|
+
assert response.status_code == 200
|
|
1043
|
+
response_innards = GDMSCOREBandsHistoryResponse.model_validate(response.json())
|
|
1044
|
+
cluster_size_1 = len(response_innards.history[-1].result.clusters)
|
|
1045
|
+
|
|
1046
|
+
# VOTE AND CONFIRM
|
|
1047
|
+
req = GDMScoreBandsVoteRequest(
|
|
1048
|
+
group_id=1,
|
|
1049
|
+
vote=4,
|
|
1050
|
+
).model_dump()
|
|
1051
|
+
response = post_json(client=client, endpoint="/gdm-score-bands/vote", json=req, access_token=access_token)
|
|
1052
|
+
assert response.status_code == 200
|
|
1053
|
+
req = GroupInfoRequest(group_id=1).model_dump()
|
|
1054
|
+
response = post_json(client=client, endpoint="/gdm-score-bands/confirm", json=req, access_token=access_token)
|
|
1055
|
+
assert response.status_code == 200
|
|
1056
|
+
|
|
1057
|
+
req = GDMScoreBandsInitializationRequest(
|
|
1058
|
+
group_id=1,
|
|
1059
|
+
score_bands_config=SCOREBandsGDMConfig(
|
|
1060
|
+
score_bands_config=SCOREBandsConfig(clustering_algorithm=KMeansOptions(n_clusters=5)),
|
|
1061
|
+
from_iteration=response_innards.history[-1].latest_iteration,
|
|
1062
|
+
),
|
|
1063
|
+
).model_dump()
|
|
1064
|
+
response = post_json(
|
|
1065
|
+
client=client, endpoint="/gdm-score-bands/get-or-initialize", json=req, access_token=access_token
|
|
1066
|
+
)
|
|
1067
|
+
assert response.status_code == 200
|
|
1068
|
+
response_innards = GDMSCOREBandsHistoryResponse.model_validate(response.json())
|
|
1069
|
+
cluster_size_2 = len(response_innards.history[-1].result.clusters)
|
|
1070
|
+
|
|
1071
|
+
# Since we've made one iteration, the length of the clustering and therefore the active
|
|
1072
|
+
# indices should be smaller the second time around than the first one.
|
|
1073
|
+
assert cluster_size_1 > cluster_size_2
|
|
1074
|
+
|
|
1075
|
+
# TODO: Test reverting, re-clustering
|