desdeo 1.2__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. desdeo/__init__.py +8 -8
  2. desdeo/adm/ADMAfsar.py +551 -0
  3. desdeo/adm/ADMChen.py +414 -0
  4. desdeo/adm/BaseADM.py +119 -0
  5. desdeo/adm/__init__.py +11 -0
  6. desdeo/api/README.md +73 -0
  7. desdeo/api/__init__.py +15 -0
  8. desdeo/api/app.py +50 -0
  9. desdeo/api/config.py +90 -0
  10. desdeo/api/config.toml +64 -0
  11. desdeo/api/db.py +27 -0
  12. desdeo/api/db_init.py +85 -0
  13. desdeo/api/db_models.py +164 -0
  14. desdeo/api/malaga_db_init.py +27 -0
  15. desdeo/api/models/__init__.py +266 -0
  16. desdeo/api/models/archive.py +23 -0
  17. desdeo/api/models/emo.py +128 -0
  18. desdeo/api/models/enautilus.py +69 -0
  19. desdeo/api/models/gdm/gdm_aggregate.py +139 -0
  20. desdeo/api/models/gdm/gdm_base.py +69 -0
  21. desdeo/api/models/gdm/gdm_score_bands.py +114 -0
  22. desdeo/api/models/gdm/gnimbus.py +138 -0
  23. desdeo/api/models/generic.py +104 -0
  24. desdeo/api/models/generic_states.py +401 -0
  25. desdeo/api/models/nimbus.py +158 -0
  26. desdeo/api/models/preference.py +128 -0
  27. desdeo/api/models/problem.py +717 -0
  28. desdeo/api/models/reference_point_method.py +18 -0
  29. desdeo/api/models/session.py +49 -0
  30. desdeo/api/models/state.py +463 -0
  31. desdeo/api/models/user.py +52 -0
  32. desdeo/api/models/utopia.py +25 -0
  33. desdeo/api/routers/_EMO.backup +309 -0
  34. desdeo/api/routers/_NAUTILUS.py +245 -0
  35. desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
  36. desdeo/api/routers/_NIMBUS.py +765 -0
  37. desdeo/api/routers/__init__.py +5 -0
  38. desdeo/api/routers/emo.py +497 -0
  39. desdeo/api/routers/enautilus.py +237 -0
  40. desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
  41. desdeo/api/routers/gdm/gdm_base.py +420 -0
  42. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
  43. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
  44. desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
  45. desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
  46. desdeo/api/routers/generic.py +233 -0
  47. desdeo/api/routers/nimbus.py +705 -0
  48. desdeo/api/routers/problem.py +307 -0
  49. desdeo/api/routers/reference_point_method.py +93 -0
  50. desdeo/api/routers/session.py +100 -0
  51. desdeo/api/routers/test.py +16 -0
  52. desdeo/api/routers/user_authentication.py +520 -0
  53. desdeo/api/routers/utils.py +187 -0
  54. desdeo/api/routers/utopia.py +230 -0
  55. desdeo/api/schema.py +100 -0
  56. desdeo/api/tests/__init__.py +0 -0
  57. desdeo/api/tests/conftest.py +151 -0
  58. desdeo/api/tests/test_enautilus.py +330 -0
  59. desdeo/api/tests/test_models.py +1179 -0
  60. desdeo/api/tests/test_routes.py +1075 -0
  61. desdeo/api/utils/_database.py +263 -0
  62. desdeo/api/utils/_logger.py +29 -0
  63. desdeo/api/utils/database.py +36 -0
  64. desdeo/api/utils/emo_database.py +40 -0
  65. desdeo/core.py +34 -0
  66. desdeo/emo/__init__.py +159 -0
  67. desdeo/emo/hooks/archivers.py +188 -0
  68. desdeo/emo/methods/EAs.py +541 -0
  69. desdeo/emo/methods/__init__.py +0 -0
  70. desdeo/emo/methods/bases.py +12 -0
  71. desdeo/emo/methods/templates.py +111 -0
  72. desdeo/emo/operators/__init__.py +1 -0
  73. desdeo/emo/operators/crossover.py +1282 -0
  74. desdeo/emo/operators/evaluator.py +114 -0
  75. desdeo/emo/operators/generator.py +459 -0
  76. desdeo/emo/operators/mutation.py +1224 -0
  77. desdeo/emo/operators/scalar_selection.py +202 -0
  78. desdeo/emo/operators/selection.py +1778 -0
  79. desdeo/emo/operators/termination.py +286 -0
  80. desdeo/emo/options/__init__.py +108 -0
  81. desdeo/emo/options/algorithms.py +435 -0
  82. desdeo/emo/options/crossover.py +164 -0
  83. desdeo/emo/options/generator.py +131 -0
  84. desdeo/emo/options/mutation.py +260 -0
  85. desdeo/emo/options/repair.py +61 -0
  86. desdeo/emo/options/scalar_selection.py +66 -0
  87. desdeo/emo/options/selection.py +127 -0
  88. desdeo/emo/options/templates.py +383 -0
  89. desdeo/emo/options/termination.py +143 -0
  90. desdeo/explanations/__init__.py +6 -0
  91. desdeo/explanations/explainer.py +100 -0
  92. desdeo/explanations/utils.py +90 -0
  93. desdeo/gdm/__init__.py +22 -0
  94. desdeo/gdm/gdmtools.py +45 -0
  95. desdeo/gdm/score_bands.py +114 -0
  96. desdeo/gdm/voting_rules.py +50 -0
  97. desdeo/mcdm/__init__.py +41 -0
  98. desdeo/mcdm/enautilus.py +338 -0
  99. desdeo/mcdm/gnimbus.py +484 -0
  100. desdeo/mcdm/nautili.py +345 -0
  101. desdeo/mcdm/nautilus.py +477 -0
  102. desdeo/mcdm/nautilus_navigator.py +656 -0
  103. desdeo/mcdm/nimbus.py +417 -0
  104. desdeo/mcdm/pareto_navigator.py +269 -0
  105. desdeo/mcdm/reference_point_method.py +186 -0
  106. desdeo/problem/__init__.py +83 -0
  107. desdeo/problem/evaluator.py +561 -0
  108. desdeo/problem/external/__init__.py +18 -0
  109. desdeo/problem/external/core.py +356 -0
  110. desdeo/problem/external/pymoo_provider.py +266 -0
  111. desdeo/problem/external/runtime.py +44 -0
  112. desdeo/problem/gurobipy_evaluator.py +562 -0
  113. desdeo/problem/infix_parser.py +341 -0
  114. desdeo/problem/json_parser.py +944 -0
  115. desdeo/problem/pyomo_evaluator.py +487 -0
  116. desdeo/problem/schema.py +1829 -0
  117. desdeo/problem/simulator_evaluator.py +348 -0
  118. desdeo/problem/sympy_evaluator.py +244 -0
  119. desdeo/problem/testproblems/__init__.py +88 -0
  120. desdeo/problem/testproblems/benchmarks_server.py +120 -0
  121. desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
  122. desdeo/problem/testproblems/cake_problem.py +185 -0
  123. desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
  124. desdeo/problem/testproblems/dtlz2_problem.py +102 -0
  125. desdeo/problem/testproblems/forest_problem.py +283 -0
  126. desdeo/problem/testproblems/knapsack_problem.py +163 -0
  127. desdeo/problem/testproblems/mcwb_problem.py +831 -0
  128. desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
  129. desdeo/problem/testproblems/momip_problem.py +172 -0
  130. desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
  131. desdeo/problem/testproblems/nimbus_problem.py +143 -0
  132. desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
  133. desdeo/problem/testproblems/re_problem.py +492 -0
  134. desdeo/problem/testproblems/river_pollution_problems.py +440 -0
  135. desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
  136. desdeo/problem/testproblems/simple_problem.py +351 -0
  137. desdeo/problem/testproblems/simulator_problem.py +92 -0
  138. desdeo/problem/testproblems/single_objective.py +289 -0
  139. desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
  140. desdeo/problem/testproblems/zdt_problem.py +274 -0
  141. desdeo/problem/utils.py +245 -0
  142. desdeo/tools/GenerateReferencePoints.py +181 -0
  143. desdeo/tools/__init__.py +120 -0
  144. desdeo/tools/desc_gen.py +22 -0
  145. desdeo/tools/generics.py +165 -0
  146. desdeo/tools/group_scalarization.py +3090 -0
  147. desdeo/tools/gurobipy_solver_interfaces.py +258 -0
  148. desdeo/tools/indicators_binary.py +117 -0
  149. desdeo/tools/indicators_unary.py +362 -0
  150. desdeo/tools/interaction_schema.py +38 -0
  151. desdeo/tools/intersection.py +54 -0
  152. desdeo/tools/iterative_pareto_representer.py +99 -0
  153. desdeo/tools/message.py +265 -0
  154. desdeo/tools/ng_solver_interfaces.py +199 -0
  155. desdeo/tools/non_dominated_sorting.py +134 -0
  156. desdeo/tools/patterns.py +283 -0
  157. desdeo/tools/proximal_solver.py +99 -0
  158. desdeo/tools/pyomo_solver_interfaces.py +477 -0
  159. desdeo/tools/reference_vectors.py +229 -0
  160. desdeo/tools/scalarization.py +2065 -0
  161. desdeo/tools/scipy_solver_interfaces.py +454 -0
  162. desdeo/tools/score_bands.py +627 -0
  163. desdeo/tools/utils.py +388 -0
  164. desdeo/tools/visualizations.py +67 -0
  165. desdeo/utopia_stuff/__init__.py +0 -0
  166. desdeo/utopia_stuff/data/1.json +15 -0
  167. desdeo/utopia_stuff/data/2.json +13 -0
  168. desdeo/utopia_stuff/data/3.json +15 -0
  169. desdeo/utopia_stuff/data/4.json +17 -0
  170. desdeo/utopia_stuff/data/5.json +15 -0
  171. desdeo/utopia_stuff/from_json.py +40 -0
  172. desdeo/utopia_stuff/reinit_user.py +38 -0
  173. desdeo/utopia_stuff/utopia_db_init.py +212 -0
  174. desdeo/utopia_stuff/utopia_problem.py +403 -0
  175. desdeo/utopia_stuff/utopia_problem_old.py +415 -0
  176. desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
  177. desdeo-2.1.0.dist-info/METADATA +186 -0
  178. desdeo-2.1.0.dist-info/RECORD +180 -0
  179. {desdeo-1.2.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
  180. desdeo-2.1.0.dist-info/licenses/LICENSE +21 -0
  181. desdeo-1.2.dist-info/METADATA +0 -16
  182. desdeo-1.2.dist-info/RECORD +0 -4
@@ -0,0 +1,5 @@
1
+ """Exports from routers."""
2
+
3
+ __all__ = ["get_current_user"]
4
+
5
+ from .user_authentication import get_current_user
@@ -0,0 +1,497 @@
1
+ """Router for evolutionary multiobjective optimization (EMO) methods."""
2
+
3
+ import json
4
+ from asyncio import run
5
+ from collections.abc import Callable
6
+ from datetime import datetime
7
+ from multiprocessing import Manager as ProcessManager
8
+ from multiprocessing import Process
9
+ from multiprocessing.synchronize import Event as EventClass # only for typing, can be removed
10
+ from pathlib import Path
11
+ from typing import Annotated
12
+ from warnings import warn
13
+
14
+ import polars as pl
15
+ from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, status
16
+ from fastapi.encoders import jsonable_encoder
17
+ from fastapi.responses import StreamingResponse
18
+ from sqlmodel import Session, select
19
+ from websockets.asyncio.client import connect
20
+
21
+ from desdeo.api.db import get_session
22
+ from desdeo.api.models import InteractiveSessionDB, StateDB
23
+ from desdeo.api.models.emo import (
24
+ EMOFetchRequest,
25
+ EMOFetchResponse,
26
+ EMOIterateRequest,
27
+ EMOIterateResponse,
28
+ EMOSaveRequest,
29
+ EMOScoreRequest,
30
+ EMOScoreResponse,
31
+ Solution,
32
+ )
33
+ from desdeo.api.models.problem import ProblemDB
34
+ from desdeo.api.models.state import EMOFetchState, EMOIterateState, EMOSaveState, EMOSCOREState
35
+ from desdeo.api.models.user import User
36
+ from desdeo.api.routers.user_authentication import get_current_user
37
+ from desdeo.emo.options.templates import EMOOptions, PreferenceOptions, TemplateOptions, emo_constructor
38
+ from desdeo.problem import Problem
39
+ from desdeo.tools.score_bands import SCOREBandsConfig, SCOREBandsResult, score_json
40
+
41
+ from .utils import fetch_interactive_session, fetch_user_problem
42
+
# Every EMO endpoint below is mounted under /method/emo and grouped under the "EMO" tag in the OpenAPI docs.
router = APIRouter(tags=["EMO"], prefix="/method/emo")
44
+
45
+
class WSmanager:
    """Manages active WebSocket connections for EMO methods.

    Messages addressed to a client that is not currently connected are buffered and flushed to that
    client as soon as it connects.
    """

    def __init__(self):
        """Initializes the WebSocket manager."""
        # Live connections keyed by client identifier. The key is NOT `websocket.client` (which is just a
        # (host, port) tuple), nor the user id. Each new EA instance has its own unique identifier; the
        # webui client should get its id from the server.
        self.active_connections: dict[str, WebSocket] = {}
        # Messages buffered for clients that are not currently connected, keyed by client identifier.
        self.unsent_messages: dict[str, list[dict]] = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accepts a new WebSocket connection and flushes any messages buffered for ``client_id``."""
        await websocket.accept()
        self.active_connections[client_id] = websocket
        for message in self.unsent_messages.get(client_id, []):
            await websocket.send_json(message)
        # Only drop the buffer once everything in it has been sent.
        self.unsent_messages.pop(client_id, None)

    def disconnect(self, websocket: WebSocket):
        """Removes a WebSocket connection.

        Does nothing if ``websocket`` is not registered, e.g. because its client id has since been
        re-registered with a newer connection.
        """
        # Match on object identity: `websocket.client` is only (host, port), and comparing it could remove a
        # different (newer) socket. Collect first so the dict is not mutated while iterating it.
        stale_ids = [cid for cid, ws in self.active_connections.items() if ws is websocket]
        for cid in stale_ids:
            del self.active_connections[cid]

    async def send_private_message(self, message: dict, client_id: str):
        """Sends a private message to a specific WebSocket connection.

        If ``client_id`` is not connected, the message is buffered (and a warning is emitted) so that
        ``connect`` can deliver it later.
        """
        websocket = self.active_connections.get(client_id)
        if websocket:
            await websocket.send_json(message)
        else:
            self.unsent_messages.setdefault(client_id, []).append(message)
            warn(
                f"Client with id={client_id} is not connected. Message saved, will be sent upon connection.",
                stacklevel=2,
            )

    async def broadcast_message(self, message: dict):
        """Sends a message to all active WebSocket connections.

        Typically don't use this as this won't send messages
        to disconnected/unconnected clients.
        """
        # Snapshot the connections: other coroutines may (dis)connect while we await a send, and mutating a
        # dict during iteration raises RuntimeError.
        for websocket in list(self.active_connections.values()):
            await websocket.send_json(message)


# Module-level manager shared by every WebSocket endpoint of this router.
ws_manager = WSmanager()
103
+
104
+
105
+ @router.websocket("/ws/{client_id}")
106
+ async def websocket_endpoint(
107
+ websocket: WebSocket,
108
+ client_id: str,
109
+ # TODO(@light-weaver): Add authentication
110
+ ):
111
+ """WebSocket endpoint for EMO methods."""
112
+ await ws_manager.connect(websocket, client_id)
113
+ try:
114
+ while True:
115
+ data = await websocket.receive_json()
116
+ print(data)
117
+ if "send_to" in data:
118
+ try:
119
+ await ws_manager.send_private_message(data, data["send_to"])
120
+ except ValueError as e:
121
+ warn(f"ValueError in WebSocket communication: {e}", stacklevel=2)
122
+ except WebSocketDisconnect:
123
+ ws_manager.disconnect(websocket)
124
+
125
+
async def handle_stop_event(stop_event: EventClass, listener_id: str, base_url: str = "ws://localhost:8000"):
    """Handles the stop event for the WebSocket connections.

    Connects to the EMO WebSocket relay as ``listener_id`` and sets ``stop_event`` as soon as a JSON object
    ``{"message": "stop", ...}`` is received. Returns when that happens or when the connection closes normally.

    Args:
        stop_event (EventClass): Event to set when a stop message arrives.
        listener_id (str): Client id under which to register with the relay.
        base_url (str): Base WebSocket URL of the API server. Defaults to ``"ws://localhost:8000"``.
    """
    async with connect(f"{base_url}/method/emo/ws/{listener_id}") as websocket:
        # A `websockets` client connection yields raw text/bytes frames (it has no Starlette-style
        # `receive_json()`), so each frame is decoded here. Iteration ends when the connection closes normally.
        async for raw_message in websocket:
            try:
                data = json.loads(raw_message)
            except json.JSONDecodeError:
                continue
            if isinstance(data, dict) and data.get("message") == "stop":
                stop_event.set()
                break
134
+
135
+
def get_templates() -> list[TemplateOptions]:
    """Fetches available EMO templates.

    Each template is read from ``datasets/emoTemplates/<algorithm>.json``, resolved four directory levels above
    this module (i.e. the directory that contains the ``desdeo`` package). NOTE(review): this presumably only
    works from a source checkout — confirm whether the datasets directory is shipped with installed packages.

    Returns:
        list[TemplateOptions]: One template per algorithm in ``algos``, parsed via ``EMOOptions``.
    """
    # Should be a database lookup in the future
    template_dir = Path(__file__).parent.parent.parent.parent / "datasets" / "emoTemplates"

    algos = ["nsga3"]

    templates: list[TemplateOptions] = []
    for algo in algos:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows) must not decide how JSON is read.
        with (template_dir / f"{algo}.json").open(encoding="utf-8") as f:
            templates.append(EMOOptions.model_validate(json.load(f)).template)
    return templates
151
+
152
+
@router.post("/iterate")
def iterate(
    request: EMOIterateRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> EMOIterateResponse:
    """Starts the EMO method.

    Stores an (incomplete) ``EMOIterateState`` in the database, then spawns a background process that runs one
    evolutionary algorithm per template and writes the combined results into that state when all have finished.

    Args:
        request (EMOIterateRequest): The request object containing parameters for the EMO method.
        user (Annotated[User, Depends]): The current user.
        session (Annotated[Session, Depends]): The database session.

    Raises:
        HTTPException: 400 if no templates are available, 404 if the parent state does not exist, 500 if the new
            state could not be created. The session/problem lookups may raise their own errors.

    Returns:
        EMOIterateResponse: A response object containing a list of IDs to be used for websocket communication.
            Also contains the StateDB id where the results will be stored.
    """
    interactive_session: InteractiveSessionDB | None = fetch_interactive_session(user, request, session)

    problem_db = fetch_user_problem(user, request, session)
    problem = Problem.from_problemdb(problem_db)

    templates = request.template_options
    if templates is None:
        templates = get_templates()
    if not templates:
        # Without any EA there is nothing to run; the background process would fail on an empty result set.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="At least one EMO template is required.",
        )

    # One timestamp per request plus the template index: unique even when several templates share an algorithm
    # name or the system clock is too coarse to tell consecutive calls apart.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    web_socket_ids = [
        f"{template.algorithm_name.lower()}_{timestamp}_{idx}" for idx, template in enumerate(templates)
    ]
    # TODO(review): all requests share this fixed client id, so progress messages of concurrent runs go to the same
    # websocket client. Confirm whether the webui depends on the fixed value before making it unique per request.
    client_id = "client"

    # Save request (incomplete and EAs have not finished running yet)

    # Handle parent state
    if request.parent_state_id is None:
        parent_state = None
    else:
        statement = select(StateDB).where(StateDB.id == request.parent_state_id)
        parent_state = session.exec(statement).first()

        if parent_state is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find state with id={request.parent_state_id}",
            )

    emo_iterate_state = EMOIterateState(
        template_options=jsonable_encoder(templates),
        preference_options=jsonable_encoder(request.preference_options),
    )

    incomplete_db_state = StateDB.create(
        database_session=session,
        problem_id=problem_db.id,
        session_id=interactive_session.id if interactive_session else None,
        parent_id=parent_state.id if parent_state else None,
        state=emo_iterate_state,
    )

    session.add(incomplete_db_state)
    session.commit()
    session.refresh(incomplete_db_state)

    state_id = incomplete_db_state.id
    if state_id is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create a new state in the database.",
        )
    # Close db session; the background process opens its own.
    session.close()

    # Spawn a new process to handle EMO method creation
    Process(
        target=_spawn_emo_process,
        args=(
            problem,
            templates,
            request.preference_options,
            web_socket_ids,
            client_id,
            state_id,
        ),
    ).start()

    return EMOIterateResponse(method_ids=web_socket_ids, client_id=client_id, state_id=state_id)
245
+
246
+
def _spawn_emo_process(
    problem: Problem,
    templates: list[TemplateOptions],
    preference_options: PreferenceOptions | None,
    websocket_ids: list[str],
    client_id: str,
    state_id: int,
):
    """Spawns a new process to handle the EMO method.

    In turn, this will start one process per evolutionary algorithm, collect their results and update the
    database state ``state_id`` with the combined results.

    Args:
        problem (Problem): The problem object.
        templates (list[TemplateOptions]): The list of templates to use (one EA per template).
        preference_options (PreferenceOptions | None): The preference options to use.
        websocket_ids (list[str]): The list of WebSocket IDs, one per template.
        client_id (str): The client ID for WebSocket communication.
        state_id (int): The state ID in the database to update with results.

    Raises:
        RuntimeError: If none of the EAs produced results.
        ValueError: If the state ``state_id`` does not exist.
        TypeError: If the state ``state_id`` is not an ``EMOIterateState``.
    """
    results = _run_eas(problem, templates, preference_options, websocket_ids, client_id)

    missing = [w_id for w_id in websocket_ids if w_id not in results]
    if missing:
        # An EA process that crashed never writes its entry into the shared results.
        warn(f"No results were produced by: {', '.join(missing)}.", stacklevel=2)
    if not results:
        raise RuntimeError(f"None of the evolutionary algorithms produced results for state id={state_id}.")

    # Combine in launch order (not completion order) so the stored solutions are ordered deterministically.
    ordered = [results[w_id] for w_id in websocket_ids if w_id in results]
    optimal_variables = pl.concat([r.optimal_variables for r in ordered])
    optimal_outputs = pl.concat([r.optimal_outputs for r in ordered])

    _store_emo_results(problem, state_id, optimal_variables, optimal_outputs)


def _run_eas(
    problem: Problem,
    templates: list[TemplateOptions],
    preference_options: PreferenceOptions | None,
    websocket_ids: list[str],
    client_id: str,
) -> dict:
    """Runs one `_ea_sync` process per template, waits for all, and returns their results keyed by websocket id."""
    # The manager runs its own server process; the context manager shuts it down when we are done.
    with ProcessManager() as process_manager:
        stop_event = process_manager.Event()
        shared_results = process_manager.dict()
        processes = [
            Process(
                target=_ea_sync,
                args=(problem, template, preference_options, stop_event.is_set, w_id, client_id, shared_results),
            )
            for w_id, template in zip(websocket_ids, templates, strict=True)
        ]
        for p in processes:
            p.start()
        for p in processes:
            p.join()
        # Copy into a plain dict: the proxy stops working once the manager has shut down.
        return shared_results.copy()


def _store_emo_results(
    problem: Problem,
    state_id: int,
    optimal_variables: pl.DataFrame,
    optimal_outputs: pl.DataFrame,
) -> None:
    """Writes combined EA results into the ``EMOIterateState`` of the StateDB row with id ``state_id``."""
    var_names = [var.symbol for var in problem.get_flattened_variables()]
    obj_names = [obj.symbol for obj in problem.objectives]
    constr_names = [constr.symbol for constr in problem.constraints or []]
    extra_names = [extra.symbol for extra in problem.extra_funcs or []]

    session_gen = get_session()
    session = next(session_gen)
    try:
        state = session.exec(select(StateDB).where(StateDB.id == state_id)).first()
        if state is None:
            raise ValueError(f"Could not find state with id={state_id} to update with results.")
        emo_state = state.state
        if not isinstance(emo_state, EMOIterateState):
            raise TypeError(f"State with id={state_id} is not of type EMOIterateState.")
        # TODO(@light-weaver): Just a dirty way to handle this. Use non-dominated merge and also split dec and obj vars
        emo_state.decision_variables = optimal_variables[var_names].to_dict(as_series=False)
        emo_state.objective_values = optimal_outputs[obj_names].to_dict(as_series=False)
        emo_state.constraint_values = optimal_outputs[constr_names].to_dict(as_series=False) if constr_names else None
        emo_state.extra_func_values = optimal_outputs[extra_names].to_dict(as_series=False) if extra_names else None

        session.add(emo_state)
        session.commit()
    finally:
        # Close the session even on failure, and finish the dependency generator so its own cleanup runs.
        session.close()
        session_gen.close()
320
+
321
+
def _ea_sync(  # noqa: PLR0913
    problem: Problem,
    template: TemplateOptions,
    preference_options: PreferenceOptions | None,
    stop_event: Callable[[], bool],
    websocket_id: str,
    client_id: str,
    results_dict: dict,
):
    """Process entry point: drive ``_ea_async`` to completion on a fresh asyncio event loop.

    Args:
        problem (Problem): The problem object.
        template (TemplateOptions): The template options for the EMO method.
        preference_options (PreferenceOptions | None): The preference options for the EMO method.
        stop_event (Callable[[], bool]): A callable that returns True if the algorithm should stop.
        websocket_id (str): The WebSocket ID for the current EMO method for communication.
        client_id (str): The ID of the client to send websocket messages to.
        results_dict (dict): A shared ProcessManager dictionary to store results.
    """
    ea_coroutine = _ea_async(
        problem=problem,
        websocket_id=websocket_id,
        client_id=client_id,
        stop_event=stop_event,
        results_dict=results_dict,
        template=template,
        preference_options=preference_options,
    )
    run(ea_coroutine)
353
+
354
+
async def _ea_async(  # noqa: PLR0913
    problem: Problem,
    websocket_id: str,
    client_id: str,
    stop_event: Callable[[], bool],
    results_dict: dict,
    template: TemplateOptions,
    preference_options: PreferenceOptions | None = None,
):
    """Executes an evolutionary algorithm.

    Sends ``"Started <websocket_id>"`` to ``client_id`` through the EMO WebSocket relay, runs the EA, stores its
    results in ``results_dict[websocket_id]`` (the archive's results if the EA has an archive), and finally sends
    ``"Finished <websocket_id>"``.

    Args:
        problem (Problem): The problem object.
        websocket_id (str): The WebSocket ID for the current EMO method for communication.
        client_id (str): The ID of the client to send websocket messages to.
        stop_event (Callable[[], bool]): Returns True when the algorithm should stop.
        results_dict (dict): A shared ProcessManager dictionary to store results.
        template (TemplateOptions): The template options for the EMO method.
        preference_options (PreferenceOptions | None): The preference options for the EMO method.
    """
    from asyncio import to_thread  # noqa: PLC0415

    # TODO: the url should not be hardcoded
    async with connect(f"ws://localhost:8000/method/emo/ws/{websocket_id}") as ws:
        # Serialize with json.dumps: the id embeds the (request-supplied) algorithm name, and hand-built JSON
        # breaks, or can gain extra fields, if it contains quotes.
        await ws.send(json.dumps({"message": f"Started {websocket_id}", "send_to": client_id}))
        emo_options = EMOOptions(template=template, preference=preference_options)
        solver, extras = emo_constructor(emo_options, problem=problem, external_check=stop_event)
        # The solver is synchronous and may run for a long time. Running it in a worker thread keeps this event
        # loop free to answer WebSocket keepalive pings, so the connection is not dropped mid-run.
        results = await to_thread(solver)
        if extras.archive is not None:
            results = extras.archive.results
        # Store the results before notifying, so they are not lost if the final send fails.
        results_dict[websocket_id] = results
        await ws.send(json.dumps({"message": f"Finished {websocket_id}", "send_to": client_id}))
386
+
387
+
@router.post("/fetch")
async def fetch_results(
    request: EMOFetchRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> StreamingResponse:
    """Fetches results from a completed EMO method.

    Streams one JSON object per line (NDJSON), each with the keys ``solution_id``, ``objective_values`` and
    ``decision_variables``.

    Args:
        request (EMOFetchRequest): The request object containing parameters for fetching results.
        user (Annotated[User, Depends]): The current user.
        session (Annotated[Session, Depends]): The database session.

    Raises:
        HTTPException: 404 if the state does not exist, 400 if it is not an EMOIterateState, 409 if the EMO method
            has not stored its results yet.

    Returns:
        StreamingResponse: A streaming response containing the results of the EMO method.
    """
    # NOTE(review): `user` is not used to check that the requested state belongs to this user.
    parent_state = request.parent_state_id
    state = session.exec(select(StateDB).where(StateDB.id == parent_state)).first()
    if state is None:
        raise HTTPException(status_code=404, detail="Parent state not found.")

    emo_state = state.state
    if not isinstance(emo_state, EMOIterateState):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"State with id={parent_state} is not of type EMOIterateState.",
        )
    if not (emo_state.objective_values and emo_state.decision_variables):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="State does not contain results yet.")

    # Column-oriented storage: {name: [value per solution]}. Rows are assembled lazily while streaming.
    raw_objs: dict[str, list[float]] = emo_state.objective_values
    raw_decs: dict[str, list[float]] = emo_state.decision_variables
    n_solutions = len(next(iter(raw_objs.values())))

    def result_stream():
        for i in range(n_solutions):
            item = {
                "solution_id": i,
                "objective_values": {k: v[i] for k, v in raw_objs.items()},
                "decision_variables": {k: v[i] for k, v in raw_decs.items()},
            }
            yield json.dumps(item) + "\n"

    return StreamingResponse(result_stream(), media_type="application/x-ndjson")
436
+
437
+
@router.post("/fetch_score")
async def fetch_score_bands(
    request: EMOScoreRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> EMOScoreResponse:
    """Computes the SCORE bands visualization for the results of a completed EMO method.

    The result is also stored as a new ``EMOSCOREState`` whose parent is the EMO iterate state.

    Args:
        request (EMOScoreRequest): The request object containing parameters for fetching results and of the SCORE
            bands visualization (default ``SCOREBandsConfig()`` when ``config`` is None).
        user (Annotated[User, Depends]): The current user.
        session (Annotated[Session, Depends]): The database session.

    Raises:
        HTTPException: 404 if the state does not exist, 400 if it is not an EMOIterateState, 409 if the EMO method
            has not stored its results yet, 500 if the new state could not be created.

    Returns:
        EMOScoreResponse: The SCORE bands result and the id of the newly stored state.
    """
    # NOTE(review): `user` is not used to check that the requested state belongs to this user.
    score_config = SCOREBandsConfig() if request.config is None else request.config

    parent_state = request.parent_state_id
    state = session.exec(select(StateDB).where(StateDB.id == parent_state)).first()
    if state is None:
        raise HTTPException(status_code=404, detail="Parent state not found.")

    emo_state = state.state
    if not isinstance(emo_state, EMOIterateState):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"State with id={parent_state} is not of type EMOIterateState.",
        )
    if not (emo_state.objective_values and emo_state.decision_variables):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="State does not contain results yet.")

    # Objective values are stored column-oriented ({name: [value per solution]}), which maps directly to a frame.
    objs = pl.DataFrame(emo_state.objective_values)

    results = score_json(
        data=objs,
        options=score_config,
    )

    score_state = EMOSCOREState(result=results.model_dump())

    score_db_state = StateDB.create(
        database_session=session,
        problem_id=request.problem_id,
        session_id=request.session_id,
        parent_id=parent_state,
        state=score_state,
    )
    session.add(score_db_state)
    session.commit()
    session.refresh(score_db_state)

    state_id = score_db_state.id
    if state_id is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create a new state in the database.",
        )

    return EMOScoreResponse(result=results, state_id=state_id)