desdeo 2.0.0__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/adm/ADMAfsar.py +551 -0
- desdeo/adm/ADMChen.py +414 -0
- desdeo/adm/BaseADM.py +119 -0
- desdeo/adm/__init__.py +11 -0
- desdeo/api/__init__.py +6 -6
- desdeo/api/app.py +38 -28
- desdeo/api/config.py +65 -44
- desdeo/api/config.toml +23 -12
- desdeo/api/db.py +10 -8
- desdeo/api/db_init.py +12 -6
- desdeo/api/models/__init__.py +220 -20
- desdeo/api/models/archive.py +16 -27
- desdeo/api/models/emo.py +128 -0
- desdeo/api/models/enautilus.py +69 -0
- desdeo/api/models/gdm/gdm_aggregate.py +139 -0
- desdeo/api/models/gdm/gdm_base.py +69 -0
- desdeo/api/models/gdm/gdm_score_bands.py +114 -0
- desdeo/api/models/gdm/gnimbus.py +138 -0
- desdeo/api/models/generic.py +104 -0
- desdeo/api/models/generic_states.py +401 -0
- desdeo/api/models/nimbus.py +158 -0
- desdeo/api/models/preference.py +44 -6
- desdeo/api/models/problem.py +274 -64
- desdeo/api/models/session.py +4 -1
- desdeo/api/models/state.py +419 -52
- desdeo/api/models/user.py +7 -6
- desdeo/api/models/utopia.py +25 -0
- desdeo/api/routers/_EMO.backup +309 -0
- desdeo/api/routers/_NIMBUS.py +6 -3
- desdeo/api/routers/emo.py +497 -0
- desdeo/api/routers/enautilus.py +237 -0
- desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
- desdeo/api/routers/gdm/gdm_base.py +420 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
- desdeo/api/routers/generic.py +233 -0
- desdeo/api/routers/nimbus.py +705 -0
- desdeo/api/routers/problem.py +201 -4
- desdeo/api/routers/reference_point_method.py +20 -44
- desdeo/api/routers/session.py +50 -26
- desdeo/api/routers/user_authentication.py +180 -26
- desdeo/api/routers/utils.py +187 -0
- desdeo/api/routers/utopia.py +230 -0
- desdeo/api/schema.py +10 -4
- desdeo/api/tests/conftest.py +94 -2
- desdeo/api/tests/test_enautilus.py +330 -0
- desdeo/api/tests/test_models.py +550 -72
- desdeo/api/tests/test_routes.py +902 -43
- desdeo/api/utils/_database.py +263 -0
- desdeo/api/utils/database.py +28 -266
- desdeo/api/utils/emo_database.py +40 -0
- desdeo/core.py +7 -0
- desdeo/emo/__init__.py +154 -24
- desdeo/emo/hooks/archivers.py +18 -2
- desdeo/emo/methods/EAs.py +128 -5
- desdeo/emo/methods/bases.py +9 -56
- desdeo/emo/methods/templates.py +111 -0
- desdeo/emo/operators/crossover.py +544 -42
- desdeo/emo/operators/evaluator.py +10 -14
- desdeo/emo/operators/generator.py +127 -24
- desdeo/emo/operators/mutation.py +212 -41
- desdeo/emo/operators/scalar_selection.py +202 -0
- desdeo/emo/operators/selection.py +956 -214
- desdeo/emo/operators/termination.py +124 -16
- desdeo/emo/options/__init__.py +108 -0
- desdeo/emo/options/algorithms.py +435 -0
- desdeo/emo/options/crossover.py +164 -0
- desdeo/emo/options/generator.py +131 -0
- desdeo/emo/options/mutation.py +260 -0
- desdeo/emo/options/repair.py +61 -0
- desdeo/emo/options/scalar_selection.py +66 -0
- desdeo/emo/options/selection.py +127 -0
- desdeo/emo/options/templates.py +383 -0
- desdeo/emo/options/termination.py +143 -0
- desdeo/gdm/__init__.py +22 -0
- desdeo/gdm/gdmtools.py +45 -0
- desdeo/gdm/score_bands.py +114 -0
- desdeo/gdm/voting_rules.py +50 -0
- desdeo/mcdm/__init__.py +23 -1
- desdeo/mcdm/enautilus.py +338 -0
- desdeo/mcdm/gnimbus.py +484 -0
- desdeo/mcdm/nautilus_navigator.py +7 -6
- desdeo/mcdm/reference_point_method.py +70 -0
- desdeo/problem/__init__.py +16 -11
- desdeo/problem/evaluator.py +4 -5
- desdeo/problem/external/__init__.py +18 -0
- desdeo/problem/external/core.py +356 -0
- desdeo/problem/external/pymoo_provider.py +266 -0
- desdeo/problem/external/runtime.py +44 -0
- desdeo/problem/gurobipy_evaluator.py +37 -12
- desdeo/problem/infix_parser.py +1 -16
- desdeo/problem/json_parser.py +7 -11
- desdeo/problem/pyomo_evaluator.py +25 -6
- desdeo/problem/schema.py +73 -55
- desdeo/problem/simulator_evaluator.py +65 -15
- desdeo/problem/testproblems/__init__.py +26 -11
- desdeo/problem/testproblems/benchmarks_server.py +120 -0
- desdeo/problem/testproblems/cake_problem.py +185 -0
- desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
- desdeo/problem/testproblems/forest_problem.py +77 -69
- desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
- desdeo/problem/testproblems/{river_pollution_problem.py → river_pollution_problems.py} +28 -22
- desdeo/problem/testproblems/single_objective.py +289 -0
- desdeo/problem/testproblems/zdt_problem.py +4 -1
- desdeo/problem/utils.py +1 -1
- desdeo/tools/__init__.py +39 -21
- desdeo/tools/desc_gen.py +22 -0
- desdeo/tools/generics.py +22 -2
- desdeo/tools/group_scalarization.py +3090 -0
- desdeo/tools/indicators_binary.py +107 -1
- desdeo/tools/indicators_unary.py +3 -16
- desdeo/tools/message.py +33 -2
- desdeo/tools/non_dominated_sorting.py +4 -3
- desdeo/tools/patterns.py +9 -7
- desdeo/tools/pyomo_solver_interfaces.py +49 -36
- desdeo/tools/reference_vectors.py +118 -351
- desdeo/tools/scalarization.py +340 -1413
- desdeo/tools/score_bands.py +491 -328
- desdeo/tools/utils.py +117 -49
- desdeo/tools/visualizations.py +67 -0
- desdeo/utopia_stuff/utopia_problem.py +1 -1
- desdeo/utopia_stuff/utopia_problem_old.py +1 -1
- {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info}/METADATA +47 -30
- desdeo-2.1.1.dist-info/RECORD +180 -0
- {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info}/WHEEL +1 -1
- desdeo-2.0.0.dist-info/RECORD +0 -120
- /desdeo/api/utils/{logger.py → _logger.py} +0 -0
- {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info/licenses}/LICENSE +0 -0
desdeo/api/routers/problem.py
CHANGED
|
@@ -1,18 +1,63 @@
|
|
|
1
1
|
"""Defines end-points to access and manage problems."""
|
|
2
2
|
|
|
3
|
+
import json
|
|
3
4
|
from typing import Annotated
|
|
4
5
|
|
|
5
|
-
from fastapi import APIRouter, Depends, HTTPException, status
|
|
6
|
-
from
|
|
6
|
+
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
|
7
|
+
from fastapi.responses import JSONResponse
|
|
8
|
+
from sqlmodel import Session, select
|
|
7
9
|
|
|
8
10
|
from desdeo.api.db import get_session
|
|
9
|
-
from desdeo.api.models import
|
|
11
|
+
from desdeo.api.models import (
|
|
12
|
+
ForestProblemMetaData,
|
|
13
|
+
ProblemDB,
|
|
14
|
+
ProblemGetRequest,
|
|
15
|
+
ProblemInfo,
|
|
16
|
+
ProblemInfoSmall,
|
|
17
|
+
ProblemMetaDataDB,
|
|
18
|
+
ProblemMetaDataGetRequest,
|
|
19
|
+
ProblemSelectSolverRequest,
|
|
20
|
+
RepresentativeNonDominatedSolutions,
|
|
21
|
+
SolverSelectionMetadata,
|
|
22
|
+
User,
|
|
23
|
+
UserRole,
|
|
24
|
+
)
|
|
10
25
|
from desdeo.api.routers.user_authentication import get_current_user
|
|
11
26
|
from desdeo.problem import Problem
|
|
27
|
+
from desdeo.tools.utils import available_solvers
|
|
12
28
|
|
|
13
29
|
router = APIRouter(prefix="/problem")
|
|
14
30
|
|
|
15
31
|
|
|
32
|
+
def check_solver(problem_db: ProblemDB):
    """Return the constructor of the solver selected for a problem, if any.

    Looks through the problem's metadata for entries whose ``metadata_type`` is
    ``"solver_selection_metadata"``. When at least one exists, the last matching
    entry is used: its ``solver_string_representation`` is looked up in
    ``available_solvers`` and the ``"constructor"`` stored there is returned.

    Args:
        problem_db (ProblemDB): the problem whose metadata is inspected.

    Returns:
        The solver constructor from ``available_solvers``, or ``None`` when the
        problem has no metadata or no solver selection has been stored for it.

    Raises:
        KeyError: if the stored solver name is not a key of ``available_solvers``.
    """
    problem_metadata: ProblemMetaDataDB = problem_db.problem_metadata
    if problem_metadata is None:
        return None

    # Use a dedicated loop name so the comprehension does not shadow the
    # metadata container it iterates over.
    solver_entries = [
        entry for entry in problem_metadata.all_metadata if entry.metadata_type == "solver_selection_metadata"
    ]
    if not solver_entries:
        return None

    # The last matching entry determines the solver.
    return available_solvers[solver_entries[-1].solver_string_representation]["constructor"]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# This is needed, because otherwise fields ending in an underscore fail to parse.
async def parse_problem_json(request: Request) -> Problem:
    """Parse the request body into a ``Problem``, validating fields by name.

    Used as a FastAPI dependency instead of declaring ``Problem`` directly as the
    body type, so that ``by_name=True`` can be passed to ``model_validate``.

    Exceptions other than ``HTTPException`` raised inside a dependency are not
    turned into client errors by FastAPI, so malformed or invalid bodies would
    otherwise surface as 500 responses. They are converted explicitly here.

    Args:
        request (Request): the incoming HTTP request whose body holds the problem as JSON.

    Returns:
        Problem: the validated problem.

    Raises:
        HTTPException: 400 if the body is not valid JSON; 422 if the JSON does not
            validate as a ``Problem``.
    """
    try:
        data: dict = await request.json()
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # The body is decoded with json.loads; undecodable bytes raise UnicodeDecodeError.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON.") from e

    try:
        return Problem.model_validate(data, by_name=True)
    except ValueError as e:
        # pydantic's ValidationError is a ValueError subclass.
        raise HTTPException(status_code=422, detail=str(e)) from e
|
|
59
|
+
|
|
60
|
+
|
|
16
61
|
@router.get("/all")
|
|
17
62
|
def get_problems(user: Annotated[User, Depends(get_current_user)]) -> list[ProblemInfoSmall]:
|
|
18
63
|
"""Get information on all the current user's problems.
|
|
@@ -71,7 +116,7 @@ def get_problem(
|
|
|
71
116
|
|
|
72
117
|
@router.post("/add")
|
|
73
118
|
def add_problem(
|
|
74
|
-
request: Problem,
|
|
119
|
+
request: Annotated[Problem, Depends(parse_problem_json)],
|
|
75
120
|
user: Annotated[User, Depends(get_current_user)],
|
|
76
121
|
session: Annotated[Session, Depends(get_session)],
|
|
77
122
|
) -> ProblemInfo:
|
|
@@ -108,3 +153,155 @@ def add_problem(
|
|
|
108
153
|
session.refresh(problem_db)
|
|
109
154
|
|
|
110
155
|
return problem_db
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@router.post("/add_json")
def add_problem_json(
    json_file: UploadFile,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> ProblemInfo:
    """Add a problem to the database based on its JSON definition.

    Args:
        json_file (UploadFile): a file in JSON format describing the problem.
        user (Annotated[User, Depends]): the user for whom the problem is added.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: 400 if the provided `json_file` is empty.
        HTTPException: 400 if the content of `json_file` is not valid JSON, including
            bytes that cannot be decoded as JSON text.
        HTTPException: 422 if the JSON does not describe a valid `Problem`.

    Returns:
        ProblemInfo: a description of the added problem.
    """
    raw = json_file.file.read()

    if not raw:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Empty upload.")

    try:
        # Parse first so malformed JSON gets a clear 400 rather than a validation error.
        json.loads(raw)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # json.loads on bytes raises UnicodeDecodeError (not JSONDecodeError) for undecodable input.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON.") from e

    try:
        problem = Problem.model_validate_json(raw, by_name=True)
    except ValueError as e:
        # pydantic's ValidationError subclasses ValueError; uncaught it would become a 500.
        raise HTTPException(status_code=422, detail=str(e)) from e

    problem_db = ProblemDB.from_problem(problem, user=user)

    session.add(problem_db)
    session.commit()
    session.refresh(problem_db)

    return problem_db
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
@router.post("/get_metadata")
def get_metadata(
    request: ProblemMetaDataGetRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> list[ForestProblemMetaData | RepresentativeNonDominatedSolutions | SolverSelectionMetadata]:
    """Return every metadata entry of a given type attached to a problem.

    The available metadata types are defined in the "Problem Metadata" section of
    ``desdeo/api/models/problem.py``.

    Args:
        request (ProblemMetaDataGetRequest): identifies the problem (``problem_id``)
            and the wanted ``metadata_type``.
        user (Annotated[User, Depends]): the current user.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: 404 if no problem with ``request.problem_id`` exists.

    Returns:
        list[ForestProblemMetaData | RepresentativeNonDominatedSolutions | SolverSelectionMetadata]:
            the problem's metadata entries whose ``metadata_type`` equals
            ``request.metadata_type``. The list is empty when nothing matches or
            when the problem has no metadata at all.
    """
    # NOTE(review): ``user`` is not used, so any authenticated user can read the
    # metadata of any problem. ``select_solver`` checks ownership; confirm whether
    # this endpoint should as well (shared-access flows may rely on the current behaviour).
    query = select(ProblemDB).where(ProblemDB.id == request.problem_id)
    problem = session.exec(query).first()

    if problem is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with ID {request.problem_id} not found!"
        )

    metadata_container = problem.problem_metadata
    if metadata_container is None:
        # Nothing has been attached to this problem yet.
        return []

    wanted_type = request.metadata_type
    return [entry for entry in metadata_container.all_metadata if entry.metadata_type == wanted_type]
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@router.get("/assign/solver", response_model=list[str])
def get_available_solvers() -> list[str]:
    """List the names under which solvers are registered in ``available_solvers``.

    Returns:
        list[str]: the registered solver names, in the mapping's iteration order.
        Each one is accepted as ``solver_string_representation`` by ``/assign_solver``.
    """
    solver_names = [*available_solvers]
    return solver_names
|
|
242
|
+
|
|
243
|
+
@router.post("/assign_solver")
def select_solver(
    request: ProblemSelectSolverRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> JSONResponse:
    """Assign a specific solver to a problem.

    Every solver selection previously stored for the problem is removed first, so
    exactly one selection remains afterwards. ``check_solver`` reads that
    selection back.

    Args:
        request (ProblemSelectSolverRequest): the problem id and the string
            representation of the solver (a key of ``available_solvers``).
        user (Annotated[User, Depends]): the user that is logged in.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: 404 if the solver name is unknown or the problem does not
            exist; 401 if the problem does not belong to the current user.

    Returns:
        JSONResponse: a simple ``{"message": "OK"}`` confirmation.
    """
    # Membership test directly on the mapping; no need to build a list of its keys.
    if request.solver_string_representation not in available_solvers:
        raise HTTPException(
            detail=f"Solver of unknown type: {request.solver_string_representation}",
            status_code=status.HTTP_404_NOT_FOUND,
        )

    # Get the problem
    problem_db = session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first()
    if problem_db is None:
        raise HTTPException(
            detail=f"No problem with ID {request.problem_id}!",
            status_code=status.HTTP_404_NOT_FOUND,
        )

    # Only the owner of the problem may change its solver.
    # NOTE(review): 403 Forbidden would describe this better than 401; kept for compatibility.
    if user.id != problem_db.user_id:
        raise HTTPException(detail="Unauthorized user!", status_code=status.HTTP_401_UNAUTHORIZED)

    problem_metadata = problem_db.problem_metadata
    if problem_metadata is None:
        # There's no metadata for this problem yet; create a container for it.
        problem_metadata = ProblemMetaDataDB(problem_id=problem_db.id, problem=problem_db)
        session.add(problem_metadata)
        session.commit()
        session.refresh(problem_metadata)

    # Remove every earlier selection so no stale one is left behind for check_solver.
    previous_selections = list(problem_metadata.solver_selection_metadata)
    if previous_selections:
        for old_selection in previous_selections:
            session.delete(old_selection)
        session.commit()

    solver_selection_metadata = SolverSelectionMetadata(
        metadata_id=problem_metadata.id,
        solver_string_representation=request.solver_string_representation,
        metadata_instance=problem_metadata,
    )

    session.add(solver_selection_metadata)
    session.commit()
    session.refresh(solver_selection_metadata)

    # Setting ``metadata_instance`` may already have placed the entry in the collection
    # (depending on how the relationship back-populates); avoid appending it twice.
    if all(entry is not solver_selection_metadata for entry in problem_metadata.solver_selection_metadata):
        problem_metadata.solver_selection_metadata.append(solver_selection_metadata)
    session.add(problem_metadata)
    session.commit()
    session.refresh(problem_metadata)

    return JSONResponse(content={"message": "OK"}, status_code=status.HTTP_200_OK)
|
|
@@ -2,8 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Annotated
|
|
4
4
|
|
|
5
|
-
from fastapi import APIRouter, Depends
|
|
6
|
-
from sqlmodel import Session
|
|
5
|
+
from fastapi import APIRouter, Depends
|
|
6
|
+
from sqlmodel import Session
|
|
7
7
|
|
|
8
8
|
from desdeo.api.db import get_session
|
|
9
9
|
from desdeo.api.models import (
|
|
@@ -15,11 +15,14 @@ from desdeo.api.models import (
|
|
|
15
15
|
StateDB,
|
|
16
16
|
User,
|
|
17
17
|
)
|
|
18
|
+
from desdeo.api.routers.problem import check_solver
|
|
18
19
|
from desdeo.api.routers.user_authentication import get_current_user
|
|
19
20
|
from desdeo.mcdm import rpm_solve_solutions
|
|
20
21
|
from desdeo.problem import Problem
|
|
21
22
|
from desdeo.tools import SolverResults
|
|
22
23
|
|
|
24
|
+
from .utils import fetch_interactive_session, fetch_parent_state, fetch_user_problem
|
|
25
|
+
|
|
23
26
|
router = APIRouter(prefix="/method/rpm")
|
|
24
27
|
|
|
25
28
|
|
|
@@ -29,32 +32,24 @@ def solve_solutions(
|
|
|
29
32
|
user: Annotated[User, Depends(get_current_user)],
|
|
30
33
|
session: Annotated[Session, Depends(get_session)],
|
|
31
34
|
) -> RPMState:
|
|
32
|
-
""".
|
|
33
|
-
|
|
34
|
-
if request.session_id is not None:
|
|
35
|
-
statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
|
|
36
|
-
interactive_session = session.exec(statement)
|
|
35
|
+
"""Runs an iteration of the reference point method.
|
|
37
36
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
)
|
|
43
|
-
else:
|
|
44
|
-
# request.session_id is None:
|
|
45
|
-
# use active session instead
|
|
46
|
-
statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
|
|
37
|
+
Args:
|
|
38
|
+
request (RPMSolveRequest): a request with the needed information to run the method.
|
|
39
|
+
user (Annotated[User, Depends): the current user.
|
|
40
|
+
session (Annotated[Session, Depends): the current database session.
|
|
47
41
|
|
|
48
|
-
|
|
42
|
+
Returns:
|
|
43
|
+
RPMState: a state with information on the results of iterating the reference point method
|
|
44
|
+
once.
|
|
45
|
+
"""
|
|
46
|
+
# fetch interactive session, parent state, and ProblemDB
|
|
47
|
+
interactive_session: InteractiveSessionDB = fetch_interactive_session(user, request, session)
|
|
48
|
+
parent_state = fetch_parent_state(user, request, session, interactive_session)
|
|
49
49
|
|
|
50
|
-
|
|
51
|
-
statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
|
|
52
|
-
problem_db = session.exec(statement).first()
|
|
50
|
+
problem_db: ProblemDB = fetch_user_problem(user, request, session)
|
|
53
51
|
|
|
54
|
-
|
|
55
|
-
raise HTTPException(
|
|
56
|
-
status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
|
|
57
|
-
)
|
|
52
|
+
solver = check_solver(problem_db=problem_db)
|
|
58
53
|
|
|
59
54
|
problem = Problem.from_problemdb(problem_db)
|
|
60
55
|
|
|
@@ -63,7 +58,7 @@ def solve_solutions(
|
|
|
63
58
|
problem,
|
|
64
59
|
request.preference.aspiration_levels,
|
|
65
60
|
request.scalarization_options,
|
|
66
|
-
|
|
61
|
+
solver,
|
|
67
62
|
request.solver_options,
|
|
68
63
|
)
|
|
69
64
|
|
|
@@ -74,25 +69,6 @@ def solve_solutions(
|
|
|
74
69
|
session.commit()
|
|
75
70
|
session.refresh(preference_db)
|
|
76
71
|
|
|
77
|
-
# fetch parent state
|
|
78
|
-
if request.parent_state_id is None:
|
|
79
|
-
# parent state is assumed to be the last sate added to the session.
|
|
80
|
-
parent_state = (
|
|
81
|
-
interactive_session.states[-1]
|
|
82
|
-
if (interactive_session is not None and len(interactive_session.states) > 0)
|
|
83
|
-
else None
|
|
84
|
-
)
|
|
85
|
-
|
|
86
|
-
else:
|
|
87
|
-
# request.parent_state_id is not None
|
|
88
|
-
statement = session.select(StateDB).where(StateDB.id == request.parent_state_id)
|
|
89
|
-
parent_state = session.exec(statement).first()
|
|
90
|
-
|
|
91
|
-
if parent_state is None:
|
|
92
|
-
raise HTTPException(
|
|
93
|
-
status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
|
|
94
|
-
)
|
|
95
|
-
|
|
96
72
|
# create state and add to DB
|
|
97
73
|
rpm_state = RPMState(
|
|
98
74
|
scalarization_options=request.scalarization_options,
|
desdeo/api/routers/session.py
CHANGED
|
@@ -5,7 +5,7 @@ from typing import Annotated
|
|
|
5
5
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
6
6
|
from sqlmodel import Session, select
|
|
7
7
|
|
|
8
|
-
from desdeo.api.db import get_session
|
|
8
|
+
from desdeo.api.db import get_session as get_db_session
|
|
9
9
|
from desdeo.api.models import (
|
|
10
10
|
CreateSessionRequest,
|
|
11
11
|
GetSessionRequest,
|
|
@@ -14,6 +14,7 @@ from desdeo.api.models import (
|
|
|
14
14
|
User,
|
|
15
15
|
)
|
|
16
16
|
from desdeo.api.routers.user_authentication import get_current_user
|
|
17
|
+
from desdeo.api.routers.utils import fetch_interactive_session
|
|
17
18
|
|
|
18
19
|
router = APIRouter(prefix="/session")
|
|
19
20
|
|
|
@@ -22,7 +23,7 @@ router = APIRouter(prefix="/session")
|
|
|
22
23
|
def create_new_session(
|
|
23
24
|
request: CreateSessionRequest,
|
|
24
25
|
user: Annotated[User, Depends(get_current_user)],
|
|
25
|
-
session: Annotated[Session, Depends(
|
|
26
|
+
session: Annotated[Session, Depends(get_db_session)],
|
|
26
27
|
) -> InteractiveSessionInfo:
|
|
27
28
|
"""."""
|
|
28
29
|
interactive_session = InteractiveSessionDB(user_id=user.id, info=request.info)
|
|
@@ -40,37 +41,60 @@ def create_new_session(
|
|
|
40
41
|
return interactive_session
|
|
41
42
|
|
|
42
43
|
|
|
43
|
-
@router.
|
|
44
|
+
@router.get("/get/{session_id}")
def get_session(
    session_id: int,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_db_session)],
) -> InteractiveSessionInfo:
    """Return the interactive session with id ``session_id`` for the current user.

    The lookup, and any error handling for a missing session, is delegated to
    ``fetch_interactive_session``.

    Args:
        session_id (int): id of the interactive session, taken from the URL path.
        user (Annotated[User, Depends]): the current user.
        session (Annotated[Session, Depends]): the database session.

    Returns:
        InteractiveSessionInfo: the requested interactive session.
    """
    session_request = GetSessionRequest(session_id=session_id)
    found = fetch_interactive_session(user=user, request=session_request, session=session)
    return found
|
|
69
58
|
|
|
70
|
-
|
|
59
|
+
@router.get("/get_all", status_code=status.HTTP_200_OK)
def get_all_sessions(
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_db_session)],
) -> list[InteractiveSessionInfo]:
    """Return all interactive sessions owned by the current user.

    Args:
        user (Annotated[User, Depends]): the current user.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: 404 when the user has no interactive sessions at all.

    Returns:
        list[InteractiveSessionInfo]: the user's interactive sessions.
    """
    owned_sessions = session.exec(
        select(InteractiveSessionDB).where(InteractiveSessionDB.user_id == user.id)
    ).all()

    if owned_sessions:
        return owned_sessions

    # NOTE(review): an empty collection is reported as 404 rather than as an empty
    # list; clients may depend on this, so it is preserved.
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="No interactive sessions found for the user.",
    )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@router.delete("/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_session(
    session_id: int,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_db_session)],
) -> None:
    """Delete one of the current user's interactive sessions.

    Whether the session's related states are removed as well depends on the
    cascade configured on ``InteractiveSessionDB``. TODO: confirm.

    Args:
        session_id (int): id of the interactive session, taken from the URL path.
        user (Annotated[User, Depends]): the current user.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: whatever ``fetch_interactive_session`` raises when the
            session cannot be found (a 404, per the inline note below); 500 if the
            delete or commit fails, after the transaction has been rolled back.
    """
    target = fetch_interactive_session(
        user=user,
        request=GetSessionRequest(session_id=session_id),
        session=session,
    )  # raises 404 if not found

    try:
        session.delete(target)
        session.commit()
    except Exception as exc:
        # Undo the partial transaction before reporting the failure to the client.
        session.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete interactive session.",
        ) from exc
|