desdeo 2.1.1__py3-none-any.whl → 2.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/api/models/nimbus.py +8 -4
- desdeo/api/routers/emo.py +75 -104
- desdeo/api/routers/generic.py +26 -58
- desdeo/api/routers/nimbus.py +108 -247
- desdeo/api/routers/problem.py +69 -56
- desdeo/api/routers/reference_point_method.py +29 -27
- desdeo/api/routers/session.py +15 -11
- desdeo/api/routers/user_authentication.py +27 -5
- desdeo/api/routers/utils.py +42 -37
- desdeo/api/routers/utopia.py +11 -12
- desdeo/api/tests/test_routes.py +6 -5
- desdeo/emo/__init__.py +2 -0
- desdeo/emo/operators/__init__.py +1 -1
- desdeo/emo/operators/generator.py +153 -2
- desdeo/emo/options/__init__.py +4 -0
- desdeo/emo/options/generator.py +24 -0
- {desdeo-2.1.1.dist-info → desdeo-2.2.0.dist-info}/METADATA +20 -10
- {desdeo-2.1.1.dist-info → desdeo-2.2.0.dist-info}/RECORD +20 -20
- {desdeo-2.1.1.dist-info → desdeo-2.2.0.dist-info}/WHEEL +1 -1
- {desdeo-2.1.1.dist-info → desdeo-2.2.0.dist-info}/licenses/LICENSE +0 -0
desdeo/api/models/nimbus.py
CHANGED
|
@@ -36,11 +36,13 @@ class NIMBUSSaveRequest(SQLModel):
|
|
|
36
36
|
|
|
37
37
|
solution_info: list[SolutionInfo]
|
|
38
38
|
|
|
39
|
+
|
|
39
40
|
class NIMBUSDeleteSaveRequest(SQLModel):
|
|
40
41
|
"""Request model for deletion of a saved solution."""
|
|
41
42
|
|
|
42
|
-
state_id
|
|
43
|
+
state_id: int = Field(description="The ID of the save state.")
|
|
43
44
|
solution_index: int = Field(description="The ID of the solution within the above state.")
|
|
45
|
+
problem_id: int = Field(description="The ID of the problem.")
|
|
44
46
|
|
|
45
47
|
|
|
46
48
|
class NIMBUSFinalizeRequest(SQLModel):
|
|
@@ -50,7 +52,7 @@ class NIMBUSFinalizeRequest(SQLModel):
|
|
|
50
52
|
session_id: int | None = Field(default=None)
|
|
51
53
|
parent_state_id: int | None = Field(default=None)
|
|
52
54
|
|
|
53
|
-
solution_info: SolutionInfo
|
|
55
|
+
solution_info: SolutionInfo # the final solution
|
|
54
56
|
|
|
55
57
|
|
|
56
58
|
class NIMBUSClassificationResponse(SQLModel):
|
|
@@ -98,12 +100,14 @@ class NIMBUSSaveResponse(SQLModel):
|
|
|
98
100
|
|
|
99
101
|
state_id: int | None = Field(description="The id of the newest state")
|
|
100
102
|
|
|
103
|
+
|
|
101
104
|
class NIMBUSDeleteSaveResponse(SQLModel):
|
|
102
105
|
"""Response of NIMBUS save deletion."""
|
|
103
106
|
|
|
104
107
|
response_type: str = "nimbus.delete_save"
|
|
105
108
|
|
|
106
|
-
message: str | None
|
|
109
|
+
message: str | None = None
|
|
110
|
+
|
|
107
111
|
|
|
108
112
|
class NIMBUSFinalizeResponse(SQLModel):
|
|
109
113
|
"""The response from NIMBUS finish endpoint."""
|
|
@@ -144,7 +148,7 @@ class NIMBUSIntermediateSolutionResponse(SQLModel):
|
|
|
144
148
|
reference_solution_1: dict[str, float] = Field(
|
|
145
149
|
sa_column=Column(JSON), description="The first solution used when computing intermediate points."
|
|
146
150
|
)
|
|
147
|
-
reference_solution_2: dict[str, float]= Field(
|
|
151
|
+
reference_solution_2: dict[str, float] = Field(
|
|
148
152
|
sa_column=Column(JSON), description="The second solution used when computing intermediate points."
|
|
149
153
|
)
|
|
150
154
|
current_solutions: list[SolutionReferenceResponse] = Field(
|
desdeo/api/routers/emo.py
CHANGED
|
@@ -15,30 +15,24 @@ import polars as pl
|
|
|
15
15
|
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, status
|
|
16
16
|
from fastapi.encoders import jsonable_encoder
|
|
17
17
|
from fastapi.responses import StreamingResponse
|
|
18
|
-
from sqlmodel import
|
|
18
|
+
from sqlmodel import select
|
|
19
19
|
from websockets.asyncio.client import connect
|
|
20
20
|
|
|
21
21
|
from desdeo.api.db import get_session
|
|
22
|
-
from desdeo.api.models import
|
|
22
|
+
from desdeo.api.models import StateDB
|
|
23
23
|
from desdeo.api.models.emo import (
|
|
24
24
|
EMOFetchRequest,
|
|
25
|
-
EMOFetchResponse,
|
|
26
25
|
EMOIterateRequest,
|
|
27
26
|
EMOIterateResponse,
|
|
28
|
-
EMOSaveRequest,
|
|
29
27
|
EMOScoreRequest,
|
|
30
28
|
EMOScoreResponse,
|
|
31
|
-
Solution,
|
|
32
29
|
)
|
|
33
|
-
from desdeo.api.models.
|
|
34
|
-
from desdeo.api.models.state import EMOFetchState, EMOIterateState, EMOSaveState, EMOSCOREState
|
|
35
|
-
from desdeo.api.models.user import User
|
|
36
|
-
from desdeo.api.routers.user_authentication import get_current_user
|
|
30
|
+
from desdeo.api.models.state import EMOIterateState, EMOSCOREState
|
|
37
31
|
from desdeo.emo.options.templates import EMOOptions, PreferenceOptions, TemplateOptions, emo_constructor
|
|
38
32
|
from desdeo.problem import Problem
|
|
39
|
-
from desdeo.tools.score_bands import SCOREBandsConfig,
|
|
33
|
+
from desdeo.tools.score_bands import SCOREBandsConfig, score_json
|
|
40
34
|
|
|
41
|
-
from .utils import
|
|
35
|
+
from .utils import SessionContext, get_session_context
|
|
42
36
|
|
|
43
37
|
router = APIRouter(prefix="/method/emo", tags=["EMO"])
|
|
44
38
|
|
|
@@ -113,7 +107,6 @@ async def websocket_endpoint(
|
|
|
113
107
|
try:
|
|
114
108
|
while True:
|
|
115
109
|
data = await websocket.receive_json()
|
|
116
|
-
print(data)
|
|
117
110
|
if "send_to" in data:
|
|
118
111
|
try:
|
|
119
112
|
await ws_manager.send_private_message(data, data["send_to"])
|
|
@@ -153,71 +146,55 @@ def get_templates() -> list[TemplateOptions]:
|
|
|
153
146
|
@router.post("/iterate")
|
|
154
147
|
def iterate(
|
|
155
148
|
request: EMOIterateRequest,
|
|
156
|
-
|
|
157
|
-
session: Annotated[Session, Depends(get_session)],
|
|
149
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
158
150
|
) -> EMOIterateResponse:
|
|
159
|
-
"""
|
|
151
|
+
"""Fetches results from a completed EMO method.
|
|
160
152
|
|
|
161
|
-
Args:
|
|
162
|
-
|
|
163
|
-
user (Annotated[User, Depends]): The current user.
|
|
164
|
-
session (Annotated[Session, Depends]): The database session.
|
|
153
|
+
Args: request (EMOIterateRequest): The request object containing parameters for fetching results.
|
|
154
|
+
context (Annotated[SessionContext, Depends]): The session context.
|
|
165
155
|
|
|
166
|
-
Raises:
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
Returns:
|
|
170
|
-
IterateResponse: A response object containing a list of IDs to be used for websocket communication.
|
|
171
|
-
Also contains the StateDB id where the results will be stored.
|
|
156
|
+
Raises: HTTPException: If the request is invalid or the EMO method fails.
|
|
157
|
+
Returns: IterateResponse: A response object containing a list of IDs to be used for websocket communication.
|
|
158
|
+
Also contains the StateDB id where the results will be stored.
|
|
172
159
|
"""
|
|
173
|
-
|
|
160
|
+
# Get context objects
|
|
161
|
+
db_session = context.db_session
|
|
162
|
+
interactive_session = context.interactive_session
|
|
163
|
+
parent_state = context.parent_state
|
|
174
164
|
|
|
175
|
-
|
|
165
|
+
# Ensure problem exists
|
|
166
|
+
if context.problem_db is None:
|
|
167
|
+
raise HTTPException(status_code=404, detail="Problem not found")
|
|
168
|
+
|
|
169
|
+
problem_db = context.problem_db
|
|
176
170
|
problem = Problem.from_problemdb(problem_db)
|
|
177
171
|
|
|
178
|
-
|
|
172
|
+
# Templates
|
|
173
|
+
templates = request.template_options or get_templates()
|
|
179
174
|
|
|
180
|
-
|
|
181
|
-
|
|
175
|
+
web_socket_ids = [
|
|
176
|
+
f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" for template in templates
|
|
177
|
+
]
|
|
182
178
|
|
|
183
|
-
web_socket_ids = []
|
|
184
|
-
for template in templates:
|
|
185
|
-
# Ensure unique names
|
|
186
|
-
web_socket_ids.append(f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}")
|
|
187
179
|
client_id = f"client_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
|
188
|
-
client_id = "client"
|
|
189
|
-
|
|
190
|
-
# Save request (incomplete and EAs have not finished running yet)
|
|
191
|
-
|
|
192
|
-
# Handle parent state
|
|
193
|
-
if request.parent_state_id is None:
|
|
194
|
-
parent_state = None
|
|
195
|
-
else:
|
|
196
|
-
statement = select(StateDB).where(StateDB.id == request.parent_state_id)
|
|
197
|
-
parent_state = session.exec(statement).first()
|
|
198
|
-
|
|
199
|
-
if parent_state is None:
|
|
200
|
-
raise HTTPException(
|
|
201
|
-
status_code=status.HTTP_404_NOT_FOUND,
|
|
202
|
-
detail=f"Could not find state with id={request.parent_state_id}",
|
|
203
|
-
)
|
|
204
180
|
|
|
181
|
+
# 4) Create incomplete state
|
|
205
182
|
emo_iterate_state = EMOIterateState(
|
|
206
183
|
template_options=jsonable_encoder(templates),
|
|
207
184
|
preference_options=jsonable_encoder(request.preference_options),
|
|
208
185
|
)
|
|
209
186
|
|
|
210
187
|
incomplete_db_state = StateDB.create(
|
|
211
|
-
database_session=
|
|
188
|
+
database_session=db_session,
|
|
212
189
|
problem_id=problem_db.id,
|
|
213
190
|
session_id=interactive_session.id if interactive_session else None,
|
|
214
191
|
parent_id=parent_state.id if parent_state else None,
|
|
215
192
|
state=emo_iterate_state,
|
|
216
193
|
)
|
|
217
194
|
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
195
|
+
db_session.add(incomplete_db_state)
|
|
196
|
+
db_session.commit()
|
|
197
|
+
db_session.refresh(incomplete_db_state)
|
|
221
198
|
|
|
222
199
|
state_id = incomplete_db_state.id
|
|
223
200
|
if state_id is None:
|
|
@@ -225,10 +202,8 @@ def iterate(
|
|
|
225
202
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
226
203
|
detail="Failed to create a new state in the database.",
|
|
227
204
|
)
|
|
228
|
-
# Close db session
|
|
229
|
-
session.close()
|
|
230
205
|
|
|
231
|
-
#
|
|
206
|
+
# Start process
|
|
232
207
|
Process(
|
|
233
208
|
target=_spawn_emo_process,
|
|
234
209
|
args=(
|
|
@@ -319,7 +294,7 @@ def _spawn_emo_process(
|
|
|
319
294
|
session.close()
|
|
320
295
|
|
|
321
296
|
|
|
322
|
-
def _ea_sync(
|
|
297
|
+
def _ea_sync(
|
|
323
298
|
problem: Problem,
|
|
324
299
|
template: TemplateOptions,
|
|
325
300
|
preference_options: PreferenceOptions | None,
|
|
@@ -352,7 +327,7 @@ def _ea_sync( # noqa: PLR0913
|
|
|
352
327
|
)
|
|
353
328
|
|
|
354
329
|
|
|
355
|
-
async def _ea_async(
|
|
330
|
+
async def _ea_async(
|
|
356
331
|
problem: Problem,
|
|
357
332
|
websocket_id: str,
|
|
358
333
|
client_id: str,
|
|
@@ -388,33 +363,29 @@ async def _ea_async( # noqa: PLR0913
|
|
|
388
363
|
@router.post("/fetch")
|
|
389
364
|
async def fetch_results(
|
|
390
365
|
request: EMOFetchRequest,
|
|
391
|
-
|
|
392
|
-
session: Annotated[Session, Depends(get_session)],
|
|
366
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
393
367
|
) -> StreamingResponse:
|
|
394
368
|
"""Fetches results from a completed EMO method.
|
|
395
369
|
|
|
396
370
|
Args:
|
|
397
371
|
request (EMOFetchRequest): The request object containing parameters for fetching results.
|
|
398
|
-
|
|
399
|
-
session (Annotated[Session, Depends]): The database session.
|
|
372
|
+
context (Annotated[SessionContext, Depends]): The session context.
|
|
400
373
|
|
|
401
|
-
Raises:
|
|
402
|
-
HTTPException: If the request is invalid or the EMO method has not completed.
|
|
374
|
+
Raises: HTTPException: If the request is invalid or the EMO method has not completed.
|
|
403
375
|
|
|
404
|
-
Returns:
|
|
405
|
-
StreamingResponse: A streaming response containing the results of the EMO method.
|
|
376
|
+
Returns: StreamingResponse: A streaming response containing the results of the EMO method.
|
|
406
377
|
"""
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
378
|
+
# Use context instead of manual fetch
|
|
379
|
+
state = context.parent_state
|
|
380
|
+
|
|
410
381
|
if state is None:
|
|
411
382
|
raise HTTPException(status_code=404, detail="Parent state not found.")
|
|
412
383
|
|
|
413
384
|
if not isinstance(state.state, EMOIterateState):
|
|
414
|
-
raise TypeError(f"State with id={
|
|
385
|
+
raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.")
|
|
415
386
|
|
|
416
387
|
if not (state.state.objective_values and state.state.decision_variables):
|
|
417
|
-
raise ValueError(
|
|
388
|
+
raise ValueError("State does not contain results yet.")
|
|
418
389
|
|
|
419
390
|
# Convert objs: dict[str, list[float]] to objs: list[dict[str, float]]
|
|
420
391
|
raw_objs: dict[str, list[float]] = state.state.objective_values
|
|
@@ -422,14 +393,15 @@ async def fetch_results(
|
|
|
422
393
|
objs: list[dict[str, float]] = [{k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions)]
|
|
423
394
|
|
|
424
395
|
raw_decs: dict[str, list[float]] = state.state.decision_variables
|
|
425
|
-
|
|
426
396
|
decs: list[dict[str, float]] = [{k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions)]
|
|
427
397
|
|
|
428
|
-
response: list[Solution] = []
|
|
429
|
-
|
|
430
398
|
def result_stream():
|
|
431
399
|
for i in range(n_solutions):
|
|
432
|
-
item = {
|
|
400
|
+
item = {
|
|
401
|
+
"solution_id": i,
|
|
402
|
+
"objective_values": objs[i],
|
|
403
|
+
"decision_variables": decs[i],
|
|
404
|
+
}
|
|
433
405
|
yield json.dumps(item) + "\n"
|
|
434
406
|
|
|
435
407
|
return StreamingResponse(result_stream())
|
|
@@ -438,16 +410,13 @@ async def fetch_results(
|
|
|
438
410
|
@router.post("/fetch_score")
|
|
439
411
|
async def fetch_score_bands(
|
|
440
412
|
request: EMOScoreRequest,
|
|
441
|
-
|
|
442
|
-
session: Annotated[Session, Depends(get_session)],
|
|
413
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
443
414
|
) -> EMOScoreResponse:
|
|
444
415
|
"""Fetches results from a completed EMO method.
|
|
445
416
|
|
|
446
|
-
Args:
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
user (Annotated[User, Depends]): The current user.
|
|
450
|
-
session (Annotated[Session, Depends]): The database session.
|
|
417
|
+
Args: request (EMOFetchRequest): The request object containing parameters for fetching
|
|
418
|
+
results and of the SCORE bands visualization.
|
|
419
|
+
context (Annotated[SessionContext, Depends]): The session context.
|
|
451
420
|
|
|
452
421
|
Raises:
|
|
453
422
|
HTTPException: If the request is invalid or the EMO method has not completed.
|
|
@@ -455,24 +424,23 @@ async def fetch_score_bands(
|
|
|
455
424
|
Returns:
|
|
456
425
|
SCOREBandsResult: The results of the SCORE bands visualization.
|
|
457
426
|
"""
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
state = session.exec(statement).first()
|
|
465
|
-
if state is None:
|
|
427
|
+
# Use context instead of manual fetch
|
|
428
|
+
parent_state = context.parent_state
|
|
429
|
+
db_session = context.db_session
|
|
430
|
+
problem_db = context.problem_db
|
|
431
|
+
|
|
432
|
+
if parent_state is None:
|
|
466
433
|
raise HTTPException(status_code=404, detail="Parent state not found.")
|
|
467
434
|
|
|
468
|
-
if not isinstance(
|
|
469
|
-
raise TypeError(f"State with id={
|
|
435
|
+
if not isinstance(parent_state.state, EMOIterateState):
|
|
436
|
+
raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.")
|
|
470
437
|
|
|
471
|
-
if not (
|
|
472
|
-
raise ValueError(
|
|
438
|
+
if not (parent_state.state.objective_values and parent_state.state.decision_variables):
|
|
439
|
+
raise ValueError("State does not contain results yet.")
|
|
473
440
|
|
|
474
|
-
|
|
475
|
-
|
|
441
|
+
score_config = SCOREBandsConfig() if request.config is None else request.config
|
|
442
|
+
|
|
443
|
+
raw_objs: dict[str, list[float]] = parent_state.state.objective_values
|
|
476
444
|
objs = pl.DataFrame(raw_objs)
|
|
477
445
|
|
|
478
446
|
results = score_json(
|
|
@@ -482,16 +450,19 @@ async def fetch_score_bands(
|
|
|
482
450
|
|
|
483
451
|
score_state = EMOSCOREState(result=results.model_dump())
|
|
484
452
|
|
|
453
|
+
# Use the session + problem from context instead of request directly
|
|
485
454
|
score_db_state = StateDB.create(
|
|
486
|
-
database_session=
|
|
487
|
-
problem_id=
|
|
488
|
-
session_id=
|
|
489
|
-
parent_id=parent_state,
|
|
455
|
+
database_session=db_session,
|
|
456
|
+
problem_id=problem_db.id,
|
|
457
|
+
session_id=parent_state.session_id,
|
|
458
|
+
parent_id=parent_state.id,
|
|
490
459
|
state=score_state,
|
|
491
460
|
)
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
461
|
+
|
|
462
|
+
db_session.add(score_db_state)
|
|
463
|
+
db_session.commit()
|
|
464
|
+
db_session.refresh(score_db_state)
|
|
465
|
+
|
|
495
466
|
state_id = score_db_state.id
|
|
496
467
|
|
|
497
468
|
return EMOScoreResponse(result=results, state_id=state_id)
|
desdeo/api/routers/generic.py
CHANGED
|
@@ -5,65 +5,55 @@ from typing import Annotated
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
7
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
8
|
-
from sqlmodel import
|
|
8
|
+
from sqlmodel import select
|
|
9
9
|
|
|
10
|
-
from desdeo.api.db import get_session
|
|
11
10
|
from desdeo.api.models import (
|
|
12
|
-
InteractiveSessionDB,
|
|
13
11
|
IntermediateSolutionRequest,
|
|
14
12
|
IntermediateSolutionState,
|
|
15
|
-
ProblemDB,
|
|
16
13
|
ScoreBandsRequest,
|
|
17
14
|
ScoreBandsResponse,
|
|
18
15
|
SolutionReference,
|
|
19
16
|
StateDB,
|
|
20
|
-
User,
|
|
21
17
|
)
|
|
22
18
|
from desdeo.api.models.generic import GenericIntermediateSolutionResponse
|
|
23
|
-
from desdeo.api.routers.user_authentication import get_current_user
|
|
24
19
|
from desdeo.mcdm.nimbus import solve_intermediate_solutions
|
|
25
20
|
from desdeo.problem import Problem
|
|
26
21
|
from desdeo.tools import SolverResults
|
|
27
22
|
from desdeo.tools.score_bands import calculate_axes_positions, cluster, order_dimensions
|
|
28
23
|
|
|
24
|
+
from .utils import SessionContext, get_session_context
|
|
25
|
+
|
|
29
26
|
router = APIRouter(prefix="/method/generic")
|
|
30
27
|
|
|
31
28
|
|
|
32
29
|
@router.post("/intermediate")
|
|
33
30
|
def solve_intermediate(
|
|
34
31
|
request: IntermediateSolutionRequest,
|
|
35
|
-
|
|
36
|
-
session: Annotated[Session, Depends(get_session)],
|
|
32
|
+
context: Annotated[SessionContext, Depends(get_session_context)],
|
|
37
33
|
) -> GenericIntermediateSolutionResponse:
|
|
38
|
-
"""Solve intermediate solutions between given two solutions.
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
# request.session_id is None:
|
|
50
|
-
# use active session instead
|
|
51
|
-
statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
|
|
52
|
-
|
|
53
|
-
interactive_session = session.exec(statement).first()
|
|
34
|
+
"""Solve intermediate solutions between given two solutions.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
request (IntermediateSolutionRequest): The request object containing parameters
|
|
38
|
+
for fetching results.
|
|
39
|
+
context (Annotated[SessionContext, Depends]): The session context.
|
|
40
|
+
"""
|
|
41
|
+
db_session = context.db_session
|
|
42
|
+
problem_db = context.problem_db
|
|
43
|
+
interactive_session = context.interactive_session
|
|
44
|
+
parent_state = context.parent_state
|
|
54
45
|
|
|
55
46
|
# query both reference solutions' variable values
|
|
56
|
-
# stored as lit of tuples, first element of each tuple are variables values, second are objective function values
|
|
57
47
|
var_and_obj_values_of_references: list[tuple[dict, dict]] = []
|
|
58
48
|
reference_states = []
|
|
49
|
+
|
|
59
50
|
for solution_info in [request.reference_solution_1, request.reference_solution_2]:
|
|
60
|
-
solution_state =
|
|
51
|
+
solution_state = db_session.exec(select(StateDB).where(StateDB.id == solution_info.state_id)).first()
|
|
61
52
|
|
|
62
53
|
if solution_state is None:
|
|
63
|
-
# no StateDB found with the given id
|
|
64
54
|
raise HTTPException(
|
|
65
55
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
66
|
-
detail=f"Could not find a state with
|
|
56
|
+
detail=f"Could not find a state with id={solution_info.state_id}.",
|
|
67
57
|
)
|
|
68
58
|
|
|
69
59
|
reference_states.append(solution_state)
|
|
@@ -71,7 +61,6 @@ def solve_intermediate(
|
|
|
71
61
|
try:
|
|
72
62
|
_var_values = solution_state.state.result_variable_values
|
|
73
63
|
var_values = _var_values[solution_info.solution_index]
|
|
74
|
-
|
|
75
64
|
except IndexError as exc:
|
|
76
65
|
raise HTTPException(
|
|
77
66
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
@@ -83,7 +72,6 @@ def solve_intermediate(
|
|
|
83
72
|
try:
|
|
84
73
|
_obj_values = solution_state.state.result_objective_values
|
|
85
74
|
obj_values = _obj_values[solution_info.solution_index]
|
|
86
|
-
|
|
87
75
|
except IndexError as exc:
|
|
88
76
|
raise HTTPException(
|
|
89
77
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
@@ -94,10 +82,7 @@ def solve_intermediate(
|
|
|
94
82
|
|
|
95
83
|
var_and_obj_values_of_references.append((var_values, obj_values))
|
|
96
84
|
|
|
97
|
-
#
|
|
98
|
-
statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
|
|
99
|
-
problem_db = session.exec(statement).first()
|
|
100
|
-
|
|
85
|
+
# Problem is now already loaded via context
|
|
101
86
|
if problem_db is None:
|
|
102
87
|
raise HTTPException(
|
|
103
88
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
@@ -116,26 +101,9 @@ def solve_intermediate(
|
|
|
116
101
|
solver_options=request.solver_options,
|
|
117
102
|
)
|
|
118
103
|
|
|
119
|
-
#
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
parent_state = (
|
|
123
|
-
interactive_session.states[-1]
|
|
124
|
-
if (interactive_session is not None and len(interactive_session.states) > 0)
|
|
125
|
-
else None
|
|
126
|
-
)
|
|
127
|
-
|
|
128
|
-
else:
|
|
129
|
-
# request.parent_state_id is not None
|
|
130
|
-
statement = session.select(StateDB).where(StateDB.id == request.parent_state_id)
|
|
131
|
-
parent_state = session.exec(statement).first()
|
|
132
|
-
|
|
133
|
-
if parent_state is None:
|
|
134
|
-
raise HTTPException(
|
|
135
|
-
status_code=status.HTTP_404_NOT_FOUND,
|
|
136
|
-
detail=f"Could not find state with id={request.parent_state_id}",
|
|
137
|
-
)
|
|
138
|
-
|
|
104
|
+
# --------------------------------------
|
|
105
|
+
# parent_state is already loaded in context
|
|
106
|
+
# --------------------------------------
|
|
139
107
|
intermediate_state = IntermediateSolutionState(
|
|
140
108
|
scalarization_options=request.scalarization_options,
|
|
141
109
|
context=request.context,
|
|
@@ -149,16 +117,16 @@ def solve_intermediate(
|
|
|
149
117
|
|
|
150
118
|
# create DB state and add it to the DB
|
|
151
119
|
state = StateDB.create(
|
|
152
|
-
database_session=
|
|
120
|
+
database_session=db_session,
|
|
153
121
|
problem_id=problem_db.id,
|
|
154
122
|
session_id=interactive_session.id if interactive_session is not None else None,
|
|
155
123
|
parent_id=parent_state.id if parent_state is not None else None,
|
|
156
124
|
state=intermediate_state,
|
|
157
125
|
)
|
|
158
126
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
127
|
+
db_session.add(state)
|
|
128
|
+
db_session.commit()
|
|
129
|
+
db_session.refresh(state)
|
|
162
130
|
|
|
163
131
|
return GenericIntermediateSolutionResponse(
|
|
164
132
|
state_id=state.id,
|