desdeo 1.2__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/__init__.py +8 -8
- desdeo/adm/ADMAfsar.py +551 -0
- desdeo/adm/ADMChen.py +414 -0
- desdeo/adm/BaseADM.py +119 -0
- desdeo/adm/__init__.py +11 -0
- desdeo/api/README.md +73 -0
- desdeo/api/__init__.py +15 -0
- desdeo/api/app.py +50 -0
- desdeo/api/config.py +90 -0
- desdeo/api/config.toml +64 -0
- desdeo/api/db.py +27 -0
- desdeo/api/db_init.py +85 -0
- desdeo/api/db_models.py +164 -0
- desdeo/api/malaga_db_init.py +27 -0
- desdeo/api/models/__init__.py +266 -0
- desdeo/api/models/archive.py +23 -0
- desdeo/api/models/emo.py +128 -0
- desdeo/api/models/enautilus.py +69 -0
- desdeo/api/models/gdm/gdm_aggregate.py +139 -0
- desdeo/api/models/gdm/gdm_base.py +69 -0
- desdeo/api/models/gdm/gdm_score_bands.py +114 -0
- desdeo/api/models/gdm/gnimbus.py +138 -0
- desdeo/api/models/generic.py +104 -0
- desdeo/api/models/generic_states.py +401 -0
- desdeo/api/models/nimbus.py +158 -0
- desdeo/api/models/preference.py +128 -0
- desdeo/api/models/problem.py +717 -0
- desdeo/api/models/reference_point_method.py +18 -0
- desdeo/api/models/session.py +49 -0
- desdeo/api/models/state.py +463 -0
- desdeo/api/models/user.py +52 -0
- desdeo/api/models/utopia.py +25 -0
- desdeo/api/routers/_EMO.backup +309 -0
- desdeo/api/routers/_NAUTILUS.py +245 -0
- desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
- desdeo/api/routers/_NIMBUS.py +765 -0
- desdeo/api/routers/__init__.py +5 -0
- desdeo/api/routers/emo.py +497 -0
- desdeo/api/routers/enautilus.py +237 -0
- desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
- desdeo/api/routers/gdm/gdm_base.py +420 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
- desdeo/api/routers/generic.py +233 -0
- desdeo/api/routers/nimbus.py +705 -0
- desdeo/api/routers/problem.py +307 -0
- desdeo/api/routers/reference_point_method.py +93 -0
- desdeo/api/routers/session.py +100 -0
- desdeo/api/routers/test.py +16 -0
- desdeo/api/routers/user_authentication.py +520 -0
- desdeo/api/routers/utils.py +187 -0
- desdeo/api/routers/utopia.py +230 -0
- desdeo/api/schema.py +100 -0
- desdeo/api/tests/__init__.py +0 -0
- desdeo/api/tests/conftest.py +151 -0
- desdeo/api/tests/test_enautilus.py +330 -0
- desdeo/api/tests/test_models.py +1179 -0
- desdeo/api/tests/test_routes.py +1075 -0
- desdeo/api/utils/_database.py +263 -0
- desdeo/api/utils/_logger.py +29 -0
- desdeo/api/utils/database.py +36 -0
- desdeo/api/utils/emo_database.py +40 -0
- desdeo/core.py +34 -0
- desdeo/emo/__init__.py +159 -0
- desdeo/emo/hooks/archivers.py +188 -0
- desdeo/emo/methods/EAs.py +541 -0
- desdeo/emo/methods/__init__.py +0 -0
- desdeo/emo/methods/bases.py +12 -0
- desdeo/emo/methods/templates.py +111 -0
- desdeo/emo/operators/__init__.py +1 -0
- desdeo/emo/operators/crossover.py +1282 -0
- desdeo/emo/operators/evaluator.py +114 -0
- desdeo/emo/operators/generator.py +459 -0
- desdeo/emo/operators/mutation.py +1224 -0
- desdeo/emo/operators/scalar_selection.py +202 -0
- desdeo/emo/operators/selection.py +1778 -0
- desdeo/emo/operators/termination.py +286 -0
- desdeo/emo/options/__init__.py +108 -0
- desdeo/emo/options/algorithms.py +435 -0
- desdeo/emo/options/crossover.py +164 -0
- desdeo/emo/options/generator.py +131 -0
- desdeo/emo/options/mutation.py +260 -0
- desdeo/emo/options/repair.py +61 -0
- desdeo/emo/options/scalar_selection.py +66 -0
- desdeo/emo/options/selection.py +127 -0
- desdeo/emo/options/templates.py +383 -0
- desdeo/emo/options/termination.py +143 -0
- desdeo/explanations/__init__.py +6 -0
- desdeo/explanations/explainer.py +100 -0
- desdeo/explanations/utils.py +90 -0
- desdeo/gdm/__init__.py +22 -0
- desdeo/gdm/gdmtools.py +45 -0
- desdeo/gdm/score_bands.py +114 -0
- desdeo/gdm/voting_rules.py +50 -0
- desdeo/mcdm/__init__.py +41 -0
- desdeo/mcdm/enautilus.py +338 -0
- desdeo/mcdm/gnimbus.py +484 -0
- desdeo/mcdm/nautili.py +345 -0
- desdeo/mcdm/nautilus.py +477 -0
- desdeo/mcdm/nautilus_navigator.py +656 -0
- desdeo/mcdm/nimbus.py +417 -0
- desdeo/mcdm/pareto_navigator.py +269 -0
- desdeo/mcdm/reference_point_method.py +186 -0
- desdeo/problem/__init__.py +83 -0
- desdeo/problem/evaluator.py +561 -0
- desdeo/problem/external/__init__.py +18 -0
- desdeo/problem/external/core.py +356 -0
- desdeo/problem/external/pymoo_provider.py +266 -0
- desdeo/problem/external/runtime.py +44 -0
- desdeo/problem/gurobipy_evaluator.py +562 -0
- desdeo/problem/infix_parser.py +341 -0
- desdeo/problem/json_parser.py +944 -0
- desdeo/problem/pyomo_evaluator.py +487 -0
- desdeo/problem/schema.py +1829 -0
- desdeo/problem/simulator_evaluator.py +348 -0
- desdeo/problem/sympy_evaluator.py +244 -0
- desdeo/problem/testproblems/__init__.py +88 -0
- desdeo/problem/testproblems/benchmarks_server.py +120 -0
- desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
- desdeo/problem/testproblems/cake_problem.py +185 -0
- desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
- desdeo/problem/testproblems/dtlz2_problem.py +102 -0
- desdeo/problem/testproblems/forest_problem.py +283 -0
- desdeo/problem/testproblems/knapsack_problem.py +163 -0
- desdeo/problem/testproblems/mcwb_problem.py +831 -0
- desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
- desdeo/problem/testproblems/momip_problem.py +172 -0
- desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
- desdeo/problem/testproblems/nimbus_problem.py +143 -0
- desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
- desdeo/problem/testproblems/re_problem.py +492 -0
- desdeo/problem/testproblems/river_pollution_problems.py +440 -0
- desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
- desdeo/problem/testproblems/simple_problem.py +351 -0
- desdeo/problem/testproblems/simulator_problem.py +92 -0
- desdeo/problem/testproblems/single_objective.py +289 -0
- desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
- desdeo/problem/testproblems/zdt_problem.py +274 -0
- desdeo/problem/utils.py +245 -0
- desdeo/tools/GenerateReferencePoints.py +181 -0
- desdeo/tools/__init__.py +120 -0
- desdeo/tools/desc_gen.py +22 -0
- desdeo/tools/generics.py +165 -0
- desdeo/tools/group_scalarization.py +3090 -0
- desdeo/tools/gurobipy_solver_interfaces.py +258 -0
- desdeo/tools/indicators_binary.py +117 -0
- desdeo/tools/indicators_unary.py +362 -0
- desdeo/tools/interaction_schema.py +38 -0
- desdeo/tools/intersection.py +54 -0
- desdeo/tools/iterative_pareto_representer.py +99 -0
- desdeo/tools/message.py +265 -0
- desdeo/tools/ng_solver_interfaces.py +199 -0
- desdeo/tools/non_dominated_sorting.py +134 -0
- desdeo/tools/patterns.py +283 -0
- desdeo/tools/proximal_solver.py +99 -0
- desdeo/tools/pyomo_solver_interfaces.py +477 -0
- desdeo/tools/reference_vectors.py +229 -0
- desdeo/tools/scalarization.py +2065 -0
- desdeo/tools/scipy_solver_interfaces.py +454 -0
- desdeo/tools/score_bands.py +627 -0
- desdeo/tools/utils.py +388 -0
- desdeo/tools/visualizations.py +67 -0
- desdeo/utopia_stuff/__init__.py +0 -0
- desdeo/utopia_stuff/data/1.json +15 -0
- desdeo/utopia_stuff/data/2.json +13 -0
- desdeo/utopia_stuff/data/3.json +15 -0
- desdeo/utopia_stuff/data/4.json +17 -0
- desdeo/utopia_stuff/data/5.json +15 -0
- desdeo/utopia_stuff/from_json.py +40 -0
- desdeo/utopia_stuff/reinit_user.py +38 -0
- desdeo/utopia_stuff/utopia_db_init.py +212 -0
- desdeo/utopia_stuff/utopia_problem.py +403 -0
- desdeo/utopia_stuff/utopia_problem_old.py +415 -0
- desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
- desdeo-2.1.0.dist-info/METADATA +186 -0
- desdeo-2.1.0.dist-info/RECORD +180 -0
- {desdeo-1.2.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
- desdeo-2.1.0.dist-info/licenses/LICENSE +21 -0
- desdeo-1.2.dist-info/METADATA +0 -16
- desdeo-1.2.dist-info/RECORD +0 -4
|
@@ -0,0 +1,497 @@
|
|
|
1
|
+
"""Router for evolutionary multiobjective optimization (EMO) methods."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from asyncio import run
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from multiprocessing import Manager as ProcessManager
|
|
8
|
+
from multiprocessing import Process
|
|
9
|
+
from multiprocessing.synchronize import Event as EventClass # only for typing, can be removed
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Annotated
|
|
12
|
+
from warnings import warn
|
|
13
|
+
|
|
14
|
+
import polars as pl
|
|
15
|
+
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, status
|
|
16
|
+
from fastapi.encoders import jsonable_encoder
|
|
17
|
+
from fastapi.responses import StreamingResponse
|
|
18
|
+
from sqlmodel import Session, select
|
|
19
|
+
from websockets.asyncio.client import connect
|
|
20
|
+
|
|
21
|
+
from desdeo.api.db import get_session
|
|
22
|
+
from desdeo.api.models import InteractiveSessionDB, StateDB
|
|
23
|
+
from desdeo.api.models.emo import (
|
|
24
|
+
EMOFetchRequest,
|
|
25
|
+
EMOFetchResponse,
|
|
26
|
+
EMOIterateRequest,
|
|
27
|
+
EMOIterateResponse,
|
|
28
|
+
EMOSaveRequest,
|
|
29
|
+
EMOScoreRequest,
|
|
30
|
+
EMOScoreResponse,
|
|
31
|
+
Solution,
|
|
32
|
+
)
|
|
33
|
+
from desdeo.api.models.problem import ProblemDB
|
|
34
|
+
from desdeo.api.models.state import EMOFetchState, EMOIterateState, EMOSaveState, EMOSCOREState
|
|
35
|
+
from desdeo.api.models.user import User
|
|
36
|
+
from desdeo.api.routers.user_authentication import get_current_user
|
|
37
|
+
from desdeo.emo.options.templates import EMOOptions, PreferenceOptions, TemplateOptions, emo_constructor
|
|
38
|
+
from desdeo.problem import Problem
|
|
39
|
+
from desdeo.tools.score_bands import SCOREBandsConfig, SCOREBandsResult, score_json
|
|
40
|
+
|
|
41
|
+
from .utils import fetch_interactive_session, fetch_user_problem
|
|
42
|
+
|
|
43
|
+
# Every endpoint in this module is mounted under /method/emo and grouped under the "EMO" tag.
router = APIRouter(tags=["EMO"], prefix="/method/emo")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class WSmanager:
    """Manages active WebSocket connections for EMO methods.

    Messages addressed to a client that is not currently connected are queued and
    sent, in the order they were queued, when that client connects.
    """

    def __init__(self):
        """Initializes the WebSocket manager with no connections and no queued messages."""
        # Active WebSocket connections keyed by client identifier. The identifier is not
        # `websocket.client` (which is just a (host, port) tuple), nor the user id. Each
        # new EA instance has its own unique identifier; the webui client should get its
        # id from the server.
        self.active_connections: dict[str, WebSocket] = {}
        # Messages addressed to clients that are not currently connected, keyed by
        # client identifier.
        self.unsent_messages: dict[str, list[dict]] = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accepts a new WebSocket connection and delivers any messages queued for it.

        Args:
            websocket (WebSocket): The incoming connection.
            client_id (str): The identifier the connection is registered under.
        """
        await websocket.accept()
        self.active_connections[client_id] = websocket
        if client_id in self.unsent_messages:
            for message in self.unsent_messages[client_id]:
                await websocket.send_json(message)
            self.unsent_messages.pop(client_id, None)

    def disconnect(self, websocket: WebSocket):
        """Removes a WebSocket connection.

        The connection is matched by identity rather than by `websocket.client`: the
        (host, port) pair is not unique (e.g. every Starlette TestClient connection
        reports the same one). Unknown connections are ignored.

        Args:
            websocket (WebSocket): The connection to unregister.
        """
        stale_ids = [cid for cid, ws in self.active_connections.items() if ws is websocket]
        for cid in stale_ids:
            del self.active_connections[cid]

    async def send_private_message(self, message: dict, client_id: str):
        """Sends a private message to a specific WebSocket connection.

        If no connection is registered under `client_id`, the message is queued and a
        warning is issued; queued messages are sent when the client connects.

        Args:
            message (dict): The JSON-serializable message.
            client_id (str): The recipient's client identifier.
        """
        websocket = self.active_connections.get(client_id)
        if websocket is not None:
            await websocket.send_json(message)
            return
        self.unsent_messages.setdefault(client_id, []).append(message)
        warn(
            f"Client with id={client_id} is not connected. Message saved, will be sent upon connection.",
            stacklevel=2,
        )

    async def broadcast_message(self, message: dict):
        """Sends a message to all active WebSocket connections.

        Typically don't use this as this won't send messages
        to disconnected/unconnected clients.
        """
        # Iterate over a snapshot: connections may be added or removed while we await.
        for websocket in list(self.active_connections.values()):
            await websocket.send_json(message)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Module-level connection manager shared by the WebSocket endpoint in this module; it holds
# the active connections and the messages queued for clients that are not connected yet.
ws_manager = WSmanager()
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@router.websocket("/ws/{client_id}")
async def websocket_endpoint(
    websocket: WebSocket,
    client_id: str,
    # TODO(@light-weaver): Add authentication
):
    """WebSocket endpoint for EMO methods.

    Registers the connection under `client_id`. Every received JSON object that has a
    "send_to" key is forwarded to the client with that identifier (queued if that client
    is not connected). The connection is always unregistered when the loop ends, whether
    the peer disconnected or an error occurred, so no dead socket stays registered.

    Args:
        websocket (WebSocket): The incoming connection.
        client_id (str): The identifier to register the connection under.
    """
    await ws_manager.connect(websocket, client_id)
    try:
        while True:
            data = await websocket.receive_json()
            # Only JSON objects can name a recipient; ignore anything else (e.g. arrays).
            if not isinstance(data, dict) or "send_to" not in data:
                continue
            try:
                await ws_manager.send_private_message(data, data["send_to"])
            except ValueError as e:
                warn(f"ValueError in WebSocket communication: {e}", stacklevel=2)
    except WebSocketDisconnect:
        pass  # Normal client disconnect; cleanup happens in `finally`.
    finally:
        ws_manager.disconnect(websocket)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
async def handle_stop_event(
    stop_event: EventClass,
    listener_id: str,
    base_url: str = "ws://localhost:8000/method/emo/ws",
):
    """Handles the stop event for the WebSocket connections.

    Connects to the EMO WebSocket endpoint as `listener_id` and sets `stop_event` when a
    JSON object with "message" equal to "stop" arrives. Returns without setting the event
    if the connection is closed normally first.

    Args:
        stop_event (EventClass): The event to set when a stop request is received.
        listener_id (str): The client identifier to listen under.
        base_url (str): Base URL of the EMO WebSocket endpoint; `listener_id` is appended.
    """
    async with connect(f"{base_url}/{listener_id}") as websocket:
        # The `websockets` client has no `receive_json`: frames arrive as str/bytes, and
        # iteration ends when the connection is closed normally.
        async for raw_message in websocket:
            try:
                data = json.loads(raw_message)
            except ValueError:  # includes json.JSONDecodeError and undecodable bytes
                continue
            if isinstance(data, dict) and data.get("message") == "stop":
                stop_event.set()
                break
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def get_templates() -> list[TemplateOptions]:
    """Fetches available EMO templates.

    Each template is read from ``datasets/emoTemplates/<algorithm>.json``, located four
    directory levels above this file (i.e. next to the top-level `desdeo` package
    directory). The file is validated as `EMOOptions`, and only its `template` part is
    returned.

    Returns:
        list[TemplateOptions]: One template per default algorithm.

    Raises:
        FileNotFoundError: If a template file is missing (e.g. the `datasets`
            directory is not shipped alongside the package).
    """
    # Should be a database lookup in the future
    templates_dir = Path(__file__).parents[3] / "datasets" / "emoTemplates"

    algos = ["nsga3"]

    templates = []
    for algo in algos:
        # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
        with (templates_dir / f"{algo}.json").open("r", encoding="utf-8") as f:
            data = json.load(f)
        templates.append(EMOOptions.model_validate(data).template)
    return templates
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@router.post("/iterate")
def iterate(
    request: EMOIterateRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> EMOIterateResponse:
    """Starts the EMO method.

    Stores an incomplete `EMOIterateState` in the database, then spawns a background
    process that runs one evolutionary algorithm per template and later writes the
    results into that state.

    Args:
        request (EMOIterateRequest): The request object containing parameters for the EMO method.
        user (Annotated[User, Depends]): The current user.
        session (Annotated[Session, Depends]): The database session.

    Raises:
        HTTPException: 400 if an empty list of templates is given, 404 if the parent state
            does not exist, 500 if the new state could not be created.

    Returns:
        EMOIterateResponse: A response object containing a list of IDs to be used for websocket
            communication. Also contains the StateDB id where the results will be stored.
    """
    interactive_session: InteractiveSessionDB | None = fetch_interactive_session(user, request, session)

    problem_db = fetch_user_problem(user, request, session)
    problem = Problem.from_problemdb(problem_db)

    templates = request.template_options
    if templates is None:
        templates = get_templates()
    if not templates:
        # Without templates no EA would run, and the background job could never
        # produce results for the state created below.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="At least one EMO template is required.",
        )

    # A single timestamp plus the template's position keeps the ids unique even when
    # several templates share an algorithm name and the clock resolution is coarse.
    # The ids key the per-EA results, so duplicates would silently drop a result.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    web_socket_ids = [
        f"{template.algorithm_name.lower()}_{timestamp}_{index}" for index, template in enumerate(templates)
    ]
    # NOTE(review): every run reports to the same fixed client id, so concurrent runs share
    # one message channel. Kept as-is for compatibility with existing clients.
    client_id = "client"

    # Save request (incomplete and EAs have not finished running yet)

    # Handle parent state
    if request.parent_state_id is None:
        parent_state = None
    else:
        statement = select(StateDB).where(StateDB.id == request.parent_state_id)
        parent_state = session.exec(statement).first()

        if parent_state is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find state with id={request.parent_state_id}",
            )

    emo_iterate_state = EMOIterateState(
        template_options=jsonable_encoder(templates),
        preference_options=jsonable_encoder(request.preference_options),
    )

    incomplete_db_state = StateDB.create(
        database_session=session,
        problem_id=problem_db.id,
        session_id=interactive_session.id if interactive_session else None,
        parent_id=parent_state.id if parent_state else None,
        state=emo_iterate_state,
    )

    session.add(incomplete_db_state)
    session.commit()
    session.refresh(incomplete_db_state)

    state_id = incomplete_db_state.id
    if state_id is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create a new state in the database.",
        )
    # Release the db session before spawning; the background process opens its own.
    session.close()

    # Spawn a new process to handle EMO method creation
    Process(
        target=_spawn_emo_process,
        args=(
            problem,
            templates,
            request.preference_options,
            web_socket_ids,
            client_id,
            state_id,
        ),
    ).start()

    return EMOIterateResponse(method_ids=web_socket_ids, client_id=client_id, state_id=state_id)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _spawn_emo_process(
    problem: Problem,
    templates: list[TemplateOptions],
    preference_options: PreferenceOptions | None,
    websocket_ids: list[str],
    client_id: str,
    state_id: int,
):
    """Spawns a new process to handle the EMO method.

    In turn, this will start one process per evolutionary algorithm (template). It
    waits for all of them, concatenates their results and writes them into the
    `EMOIterateState` stored under `state_id`.

    Args:
        problem (Problem): The problem object.
        templates (List[TemplateOptions]): The list of templates to use.
        preference_options (PreferenceOptions | None): The preference options to use.
        websocket_ids (list[str]): The list of WebSocket IDs, one per template.
        client_id (str): The client ID for WebSocket communication.
        state_id (int): The state ID in the database to update with results.

    Raises:
        RuntimeError: If no evolutionary algorithm produced results.
        ValueError: If the state with `state_id` does not exist.
        TypeError: If the state with `state_id` is not an `EMOIterateState`.
    """
    # The manager runs its own server process. The context manager shuts it down once
    # the results have been copied out of the shared dictionary.
    with ProcessManager() as process_manager:
        stop_event = process_manager.Event()
        results_dict = process_manager.dict()

        # Spawn one process per evolutionary algorithm.
        processes = []
        for w_id, template in zip(websocket_ids, templates, strict=True):
            p = Process(
                target=_ea_sync,
                args=(problem, template, preference_options, stop_event.is_set, w_id, client_id, results_dict),
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        failed_ids = [w_id for w_id, p in zip(websocket_ids, processes, strict=True) if p.exitcode != 0]
        # Copy the results out while the manager (and thus the proxy) is still alive.
        collected_results = list(results_dict.values())

    if failed_ids:
        warn(f"Evolutionary algorithm process(es) {failed_ids} exited with an error.", stacklevel=2)
    if not collected_results:
        raise RuntimeError(f"No evolutionary algorithm produced results; state with id={state_id} was not updated.")

    # Combine results
    optimal_variables = pl.concat([results.optimal_variables for results in collected_results])
    optimal_outputs = pl.concat([results.optimal_outputs for results in collected_results])

    # update DB
    session_generator = get_session()
    session = next(session_generator)
    try:
        statement = select(StateDB).where(StateDB.id == state_id)
        state = session.exec(statement).first()
        if state is None:
            raise ValueError(f"Could not find state with id={state_id} to update with results.")
        emo_state = state.state
        if not isinstance(emo_state, EMOIterateState):
            raise TypeError(f"State with id={state_id} is not of type EMOIterateState.")
        # TODO(@light-weaver): Just a dirty way to handle this. Use non-dominated merge and also split dec and obj vars
        var_names = [var.symbol for var in problem.get_flattened_variables()]
        obj_names = [obj.symbol for obj in problem.objectives]
        constr_names = [constr.symbol for constr in problem.constraints] if problem.constraints is not None else []
        extra_names = [extra.symbol for extra in problem.extra_funcs] if problem.extra_funcs is not None else []

        emo_state.decision_variables = optimal_variables[var_names].to_dict(as_series=False)
        emo_state.objective_values = optimal_outputs[obj_names].to_dict(as_series=False)
        emo_state.constraint_values = optimal_outputs[constr_names].to_dict(as_series=False) if constr_names else None
        emo_state.extra_func_values = optimal_outputs[extra_names].to_dict(as_series=False) if extra_names else None

        session.add(emo_state)
        session.commit()
    finally:
        session.close()
        # Finalize the dependency generator now so its own cleanup runs deterministically.
        session_generator.close()
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def _ea_sync(  # noqa: PLR0913
    problem: Problem,
    template: TemplateOptions,
    preference_options: PreferenceOptions | None,
    stop_event: Callable[[], bool],
    websocket_id: str,
    client_id: str,
    results_dict: dict,
):
    """Drive `_ea_async` to completion with `asyncio.run` (used as a `Process` target).

    Args:
        problem (Problem): The problem to optimize.
        template (TemplateOptions): Template options of the EMO method to run.
        preference_options (PreferenceOptions | None): Optional preference options for the method.
        stop_event (Callable[[], bool]): Returns True when the algorithm should stop early.
        websocket_id (str): WebSocket identifier of this EMO method.
        client_id (str): Identifier of the client that receives the status messages.
        results_dict (dict): Shared ProcessManager dictionary the results are written into.
    """
    # A Process target must be a plain callable, so build the coroutine here and run it.
    ea_coroutine = _ea_async(
        problem=problem,
        template=template,
        preference_options=preference_options,
        stop_event=stop_event,
        websocket_id=websocket_id,
        client_id=client_id,
        results_dict=results_dict,
    )
    run(ea_coroutine)
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
async def _ea_async(  # noqa: PLR0913
    problem: Problem,
    websocket_id: str,
    client_id: str,
    stop_event: Callable[[], bool],
    results_dict: dict,
    template: TemplateOptions,
    preference_options: PreferenceOptions | None = None,
):
    """Executes an evolutionary algorithm.

    Connects to the EMO WebSocket endpoint as `websocket_id`, sends a "Started" message
    to `client_id`, runs the solver, stores its results in `results_dict[websocket_id]`
    and finally sends a "Finished" message.

    Args:
        problem (Problem): The problem object.
        websocket_id (str): The WebSocket ID for the current EMO method for communication.
        client_id (str): The ID of the client to send websocket messages to.
        stop_event (Callable[[], bool]): Returns True when the algorithm should stop.
        results_dict (dict): A shared ProcessManager dictionary to store results.
        template (TemplateOptions): The template options for the EMO method.
        preference_options (PreferenceOptions | None): The preference options for the EMO method.
    """
    import asyncio  # only `asyncio.run` is imported at module level

    # TODO: the url should not be hardcoded
    async with connect(f"ws://localhost:8000/method/emo/ws/{websocket_id}") as ws:
        # json.dumps escapes quotes/backslashes; the algorithm name in the id comes from the request.
        await ws.send(json.dumps({"message": f"Started {websocket_id}", "send_to": client_id}))
        emo_options = EMOOptions(template=template, preference=preference_options)
        solver, extras = emo_constructor(emo_options, problem=problem, external_check=stop_event)
        # Run the CPU-bound solver in a worker thread so the event loop keeps answering
        # WebSocket keepalive pings; a blocked loop lets long runs lose the connection.
        results = await asyncio.to_thread(solver)
        if extras.archive is not None:
            results = extras.archive.results
        # Store the results before notifying, so a failed notification cannot lose them.
        results_dict[websocket_id] = results
        await ws.send(json.dumps({"message": f"Finished {websocket_id}", "send_to": client_id}))
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
@router.post("/fetch")
async def fetch_results(
    request: EMOFetchRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> StreamingResponse:
    """Fetches results from a completed EMO method.

    The solutions are streamed as newline-delimited JSON: one object per solution with
    the keys "solution_id", "objective_values" and "decision_variables".

    Args:
        request (EMOFetchRequest): The request object containing parameters for fetching results.
        user (Annotated[User, Depends]): The current user.
        session (Annotated[Session, Depends]): The database session.

    Raises:
        HTTPException: 404 if the parent state does not exist, 400 if it is not an EMO
            iterate state, 409 if the EMO method has not produced results yet.

    Returns:
        StreamingResponse: A streaming response containing the results of the EMO method.
    """
    # NOTE(review): `user` is not used to check that the state belongs to them.
    parent_state = request.parent_state_id
    statement = select(StateDB).where(StateDB.id == parent_state)
    state = session.exec(statement).first()
    if state is None:
        raise HTTPException(status_code=404, detail="Parent state not found.")

    emo_state = state.state
    if not isinstance(emo_state, EMOIterateState):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"State with id={parent_state} is not of type EMOIterateState.",
        )

    if not (emo_state.objective_values and emo_state.decision_variables):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="State does not contain results yet.",
        )

    # Convert column-oriented dict[str, list[float]] into one dict[str, float] per solution.
    raw_objs: dict[str, list[float]] = emo_state.objective_values
    raw_decs: dict[str, list[float]] = emo_state.decision_variables
    n_solutions = len(next(iter(raw_objs.values())))
    objs: list[dict[str, float]] = [{k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions)]
    decs: list[dict[str, float]] = [{k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions)]

    def result_stream():
        for i, (obj, dec) in enumerate(zip(objs, decs)):
            item = {"solution_id": i, "objective_values": obj, "decision_variables": dec}
            yield json.dumps(item) + "\n"

    return StreamingResponse(result_stream(), media_type="application/x-ndjson")
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
@router.post("/fetch_score")
async def fetch_score_bands(
    request: EMOScoreRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> EMOScoreResponse:
    """Computes the SCORE bands visualization for the results of a completed EMO method.

    The result is stored as a new `EMOSCOREState` whose parent is the EMO iterate state.

    Args:
        request (EMOScoreRequest): The request object containing parameters for fetching results and of the SCORE
            bands visualization. If `request.config` is None, a default `SCOREBandsConfig` is used.
        user (Annotated[User, Depends]): The current user.
        session (Annotated[Session, Depends]): The database session.

    Raises:
        HTTPException: 404 if the parent state does not exist, 400 if it is not an EMO
            iterate state, 409 if the EMO method has not produced results yet.

    Returns:
        EMOScoreResponse: The SCORE bands result and the id of the newly stored state.
    """
    # NOTE(review): `user` is not used to check that the state belongs to them.
    score_config = SCOREBandsConfig() if request.config is None else request.config

    parent_state = request.parent_state_id
    statement = select(StateDB).where(StateDB.id == parent_state)
    state = session.exec(statement).first()
    if state is None:
        raise HTTPException(status_code=404, detail="Parent state not found.")

    emo_state = state.state
    if not isinstance(emo_state, EMOIterateState):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"State with id={parent_state} is not of type EMOIterateState.",
        )

    if not (emo_state.objective_values and emo_state.decision_variables):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="State does not contain results yet.",
        )

    # Column-oriented objective values (one list per objective) become a DataFrame.
    raw_objs: dict[str, list[float]] = emo_state.objective_values
    objs = pl.DataFrame(raw_objs)

    results = score_json(
        data=objs,
        options=score_config,
    )

    score_state = EMOSCOREState(result=results.model_dump())

    score_db_state = StateDB.create(
        database_session=session,
        problem_id=request.problem_id,
        session_id=request.session_id,
        parent_id=parent_state,
        state=score_state,
    )
    session.add(score_db_state)
    session.commit()
    session.refresh(score_db_state)

    return EMOScoreResponse(result=results, state_id=score_db_state.id)
|