desdeo 2.1.1__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,11 +36,13 @@ class NIMBUSSaveRequest(SQLModel):
36
36
 
37
37
  solution_info: list[SolutionInfo]
38
38
 
39
+
39
40
  class NIMBUSDeleteSaveRequest(SQLModel):
40
41
  """Request model for deletion of a saved solution."""
41
42
 
42
- state_id : int = Field(description="The ID of the save state.")
43
+ state_id: int = Field(description="The ID of the save state.")
43
44
  solution_index: int = Field(description="The ID of the solution within the above state.")
45
+ problem_id: int = Field(description="The ID of the problem.")
44
46
 
45
47
 
46
48
  class NIMBUSFinalizeRequest(SQLModel):
@@ -50,7 +52,7 @@ class NIMBUSFinalizeRequest(SQLModel):
50
52
  session_id: int | None = Field(default=None)
51
53
  parent_state_id: int | None = Field(default=None)
52
54
 
53
- solution_info: SolutionInfo # the final solution
55
+ solution_info: SolutionInfo # the final solution
54
56
 
55
57
 
56
58
  class NIMBUSClassificationResponse(SQLModel):
@@ -98,12 +100,14 @@ class NIMBUSSaveResponse(SQLModel):
98
100
 
99
101
  state_id: int | None = Field(description="The id of the newest state")
100
102
 
103
+
101
104
  class NIMBUSDeleteSaveResponse(SQLModel):
102
105
  """Response of NIMBUS save deletion."""
103
106
 
104
107
  response_type: str = "nimbus.delete_save"
105
108
 
106
- message: str | None
109
+ message: str | None = None
110
+
107
111
 
108
112
  class NIMBUSFinalizeResponse(SQLModel):
109
113
  """The response from NIMBUS finish endpoint."""
@@ -144,7 +148,7 @@ class NIMBUSIntermediateSolutionResponse(SQLModel):
144
148
  reference_solution_1: dict[str, float] = Field(
145
149
  sa_column=Column(JSON), description="The first solution used when computing intermediate points."
146
150
  )
147
- reference_solution_2: dict[str, float]= Field(
151
+ reference_solution_2: dict[str, float] = Field(
148
152
  sa_column=Column(JSON), description="The second solution used when computing intermediate points."
149
153
  )
150
154
  current_solutions: list[SolutionReferenceResponse] = Field(
desdeo/api/routers/emo.py CHANGED
@@ -15,30 +15,24 @@ import polars as pl
15
15
  from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, status
16
16
  from fastapi.encoders import jsonable_encoder
17
17
  from fastapi.responses import StreamingResponse
18
- from sqlmodel import Session, select
18
+ from sqlmodel import select
19
19
  from websockets.asyncio.client import connect
20
20
 
21
21
  from desdeo.api.db import get_session
22
- from desdeo.api.models import InteractiveSessionDB, StateDB
22
+ from desdeo.api.models import StateDB
23
23
  from desdeo.api.models.emo import (
24
24
  EMOFetchRequest,
25
- EMOFetchResponse,
26
25
  EMOIterateRequest,
27
26
  EMOIterateResponse,
28
- EMOSaveRequest,
29
27
  EMOScoreRequest,
30
28
  EMOScoreResponse,
31
- Solution,
32
29
  )
33
- from desdeo.api.models.problem import ProblemDB
34
- from desdeo.api.models.state import EMOFetchState, EMOIterateState, EMOSaveState, EMOSCOREState
35
- from desdeo.api.models.user import User
36
- from desdeo.api.routers.user_authentication import get_current_user
30
+ from desdeo.api.models.state import EMOIterateState, EMOSCOREState
37
31
  from desdeo.emo.options.templates import EMOOptions, PreferenceOptions, TemplateOptions, emo_constructor
38
32
  from desdeo.problem import Problem
39
- from desdeo.tools.score_bands import SCOREBandsConfig, SCOREBandsResult, score_json
33
+ from desdeo.tools.score_bands import SCOREBandsConfig, score_json
40
34
 
41
- from .utils import fetch_interactive_session, fetch_user_problem
35
+ from .utils import SessionContext, get_session_context
42
36
 
43
37
  router = APIRouter(prefix="/method/emo", tags=["EMO"])
44
38
 
@@ -113,7 +107,6 @@ async def websocket_endpoint(
113
107
  try:
114
108
  while True:
115
109
  data = await websocket.receive_json()
116
- print(data)
117
110
  if "send_to" in data:
118
111
  try:
119
112
  await ws_manager.send_private_message(data, data["send_to"])
@@ -153,71 +146,55 @@ def get_templates() -> list[TemplateOptions]:
153
146
  @router.post("/iterate")
154
147
  def iterate(
155
148
  request: EMOIterateRequest,
156
- user: Annotated[User, Depends(get_current_user)],
157
- session: Annotated[Session, Depends(get_session)],
149
+ context: Annotated[SessionContext, Depends(get_session_context)],
158
150
  ) -> EMOIterateResponse:
159
- """Starts the EMO method.
151
+ """Fetches results from a completed EMO method.
160
152
 
161
- Args:
162
- request (EMOSolveRequest): The request object containing parameters for the EMO method.
163
- user (Annotated[User, Depends]): The current user.
164
- session (Annotated[Session, Depends]): The database session.
153
+ Args: request (EMOIterateRequest): The request object containing parameters for fetching results.
154
+ context (Annotated[SessionContext, Depends]): The session context.
165
155
 
166
- Raises:
167
- HTTPException: If the request is invalid or the EMO method fails.
168
-
169
- Returns:
170
- IterateResponse: A response object containing a list of IDs to be used for websocket communication.
171
- Also contains the StateDB id where the results will be stored.
156
+ Raises: HTTPException: If the request is invalid or the EMO method fails.
157
+ Returns: IterateResponse: A response object containing a list of IDs to be used for websocket communication.
158
+ Also contains the StateDB id where the results will be stored.
172
159
  """
173
- interactive_session: InteractiveSessionDB | None = fetch_interactive_session(user, request, session)
160
+ # Get context objects
161
+ db_session = context.db_session
162
+ interactive_session = context.interactive_session
163
+ parent_state = context.parent_state
174
164
 
175
- problem_db = fetch_user_problem(user, request, session)
165
+ # Ensure problem exists
166
+ if context.problem_db is None:
167
+ raise HTTPException(status_code=404, detail="Problem not found")
168
+
169
+ problem_db = context.problem_db
176
170
  problem = Problem.from_problemdb(problem_db)
177
171
 
178
- templates = request.template_options
172
+ # Templates
173
+ templates = request.template_options or get_templates()
179
174
 
180
- if templates is None:
181
- templates = get_templates()
175
+ web_socket_ids = [
176
+ f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" for template in templates
177
+ ]
182
178
 
183
- web_socket_ids = []
184
- for template in templates:
185
- # Ensure unique names
186
- web_socket_ids.append(f"{template.algorithm_name.lower()}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}")
187
179
  client_id = f"client_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
188
- client_id = "client"
189
-
190
- # Save request (incomplete and EAs have not finished running yet)
191
-
192
- # Handle parent state
193
- if request.parent_state_id is None:
194
- parent_state = None
195
- else:
196
- statement = select(StateDB).where(StateDB.id == request.parent_state_id)
197
- parent_state = session.exec(statement).first()
198
-
199
- if parent_state is None:
200
- raise HTTPException(
201
- status_code=status.HTTP_404_NOT_FOUND,
202
- detail=f"Could not find state with id={request.parent_state_id}",
203
- )
204
180
 
181
+ # 4) Create incomplete state
205
182
  emo_iterate_state = EMOIterateState(
206
183
  template_options=jsonable_encoder(templates),
207
184
  preference_options=jsonable_encoder(request.preference_options),
208
185
  )
209
186
 
210
187
  incomplete_db_state = StateDB.create(
211
- database_session=session,
188
+ database_session=db_session,
212
189
  problem_id=problem_db.id,
213
190
  session_id=interactive_session.id if interactive_session else None,
214
191
  parent_id=parent_state.id if parent_state else None,
215
192
  state=emo_iterate_state,
216
193
  )
217
194
 
218
- session.add(incomplete_db_state)
219
- session.commit()
220
- session.refresh(incomplete_db_state)
195
+ db_session.add(incomplete_db_state)
196
+ db_session.commit()
197
+ db_session.refresh(incomplete_db_state)
221
198
 
222
199
  state_id = incomplete_db_state.id
223
200
  if state_id is None:
@@ -225,10 +202,8 @@ def iterate(
225
202
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
226
203
  detail="Failed to create a new state in the database.",
227
204
  )
228
- # Close db session
229
- session.close()
230
205
 
231
- # Spawn a new process to handle EMO method creation
206
+ # Start process
232
207
  Process(
233
208
  target=_spawn_emo_process,
234
209
  args=(
@@ -319,7 +294,7 @@ def _spawn_emo_process(
319
294
  session.close()
320
295
 
321
296
 
322
- def _ea_sync( # noqa: PLR0913
297
+ def _ea_sync(
323
298
  problem: Problem,
324
299
  template: TemplateOptions,
325
300
  preference_options: PreferenceOptions | None,
@@ -352,7 +327,7 @@ def _ea_sync( # noqa: PLR0913
352
327
  )
353
328
 
354
329
 
355
- async def _ea_async( # noqa: PLR0913
330
+ async def _ea_async(
356
331
  problem: Problem,
357
332
  websocket_id: str,
358
333
  client_id: str,
@@ -388,33 +363,29 @@ async def _ea_async( # noqa: PLR0913
388
363
  @router.post("/fetch")
389
364
  async def fetch_results(
390
365
  request: EMOFetchRequest,
391
- user: Annotated[User, Depends(get_current_user)],
392
- session: Annotated[Session, Depends(get_session)],
366
+ context: Annotated[SessionContext, Depends(get_session_context)],
393
367
  ) -> StreamingResponse:
394
368
  """Fetches results from a completed EMO method.
395
369
 
396
370
  Args:
397
371
  request (EMOFetchRequest): The request object containing parameters for fetching results.
398
- user (Annotated[User, Depends]): The current user.
399
- session (Annotated[Session, Depends]): The database session.
372
+ context (Annotated[SessionContext, Depends]): The session context.
400
373
 
401
- Raises:
402
- HTTPException: If the request is invalid or the EMO method has not completed.
374
+ Raises: HTTPException: If the request is invalid or the EMO method has not completed.
403
375
 
404
- Returns:
405
- StreamingResponse: A streaming response containing the results of the EMO method.
376
+ Returns: StreamingResponse: A streaming response containing the results of the EMO method.
406
377
  """
407
- parent_state = request.parent_state_id
408
- statement = select(StateDB).where(StateDB.id == parent_state)
409
- state = session.exec(statement).first()
378
+ # Use context instead of manual fetch
379
+ state = context.parent_state
380
+
410
381
  if state is None:
411
382
  raise HTTPException(status_code=404, detail="Parent state not found.")
412
383
 
413
384
  if not isinstance(state.state, EMOIterateState):
414
- raise TypeError(f"State with id={parent_state} is not of type EMOIterateState.")
385
+ raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.")
415
386
 
416
387
  if not (state.state.objective_values and state.state.decision_variables):
417
- raise ValueError(f"State does not contain results yet.")
388
+ raise ValueError("State does not contain results yet.")
418
389
 
419
390
  # Convert objs: dict[str, list[float]] to objs: list[dict[str, float]]
420
391
  raw_objs: dict[str, list[float]] = state.state.objective_values
@@ -422,14 +393,15 @@ async def fetch_results(
422
393
  objs: list[dict[str, float]] = [{k: v[i] for k, v in raw_objs.items()} for i in range(n_solutions)]
423
394
 
424
395
  raw_decs: dict[str, list[float]] = state.state.decision_variables
425
-
426
396
  decs: list[dict[str, float]] = [{k: v[i] for k, v in raw_decs.items()} for i in range(n_solutions)]
427
397
 
428
- response: list[Solution] = []
429
-
430
398
  def result_stream():
431
399
  for i in range(n_solutions):
432
- item = {"solution_id": i, "objective_values": objs[i], "decision_variables": decs[i]}
400
+ item = {
401
+ "solution_id": i,
402
+ "objective_values": objs[i],
403
+ "decision_variables": decs[i],
404
+ }
433
405
  yield json.dumps(item) + "\n"
434
406
 
435
407
  return StreamingResponse(result_stream())
@@ -438,16 +410,13 @@ async def fetch_results(
438
410
  @router.post("/fetch_score")
439
411
  async def fetch_score_bands(
440
412
  request: EMOScoreRequest,
441
- user: Annotated[User, Depends(get_current_user)],
442
- session: Annotated[Session, Depends(get_session)],
413
+ context: Annotated[SessionContext, Depends(get_session_context)],
443
414
  ) -> EMOScoreResponse:
444
415
  """Fetches results from a completed EMO method.
445
416
 
446
- Args:
447
- request (EMOFetchRequest): The request object containing parameters for fetching results and of the SCORE bands
448
- visualization.
449
- user (Annotated[User, Depends]): The current user.
450
- session (Annotated[Session, Depends]): The database session.
417
+ Args: request (EMOFetchRequest): The request object containing parameters for fetching
418
+ results and of the SCORE bands visualization.
419
+ context (Annotated[SessionContext, Depends]): The session context.
451
420
 
452
421
  Raises:
453
422
  HTTPException: If the request is invalid or the EMO method has not completed.
@@ -455,24 +424,23 @@ async def fetch_score_bands(
455
424
  Returns:
456
425
  SCOREBandsResult: The results of the SCORE bands visualization.
457
426
  """
458
- if request.config is None:
459
- score_config = SCOREBandsConfig()
460
- else:
461
- score_config = request.config
462
- parent_state = request.parent_state_id
463
- statement = select(StateDB).where(StateDB.id == parent_state)
464
- state = session.exec(statement).first()
465
- if state is None:
427
+ # Use context instead of manual fetch
428
+ parent_state = context.parent_state
429
+ db_session = context.db_session
430
+ problem_db = context.problem_db
431
+
432
+ if parent_state is None:
466
433
  raise HTTPException(status_code=404, detail="Parent state not found.")
467
434
 
468
- if not isinstance(state.state, EMOIterateState):
469
- raise TypeError(f"State with id={parent_state} is not of type EMOIterateState.")
435
+ if not isinstance(parent_state.state, EMOIterateState):
436
+ raise TypeError(f"State with id={request.parent_state_id} is not of type EMOIterateState.")
470
437
 
471
- if not (state.state.objective_values and state.state.decision_variables):
472
- raise ValueError(f"State does not contain results yet.")
438
+ if not (parent_state.state.objective_values and parent_state.state.decision_variables):
439
+ raise ValueError("State does not contain results yet.")
473
440
 
474
- # Convert objs: dict[str, list[float]] to objs: list[dict[str, float]]
475
- raw_objs: dict[str, list[float]] = state.state.objective_values
441
+ score_config = SCOREBandsConfig() if request.config is None else request.config
442
+
443
+ raw_objs: dict[str, list[float]] = parent_state.state.objective_values
476
444
  objs = pl.DataFrame(raw_objs)
477
445
 
478
446
  results = score_json(
@@ -482,16 +450,19 @@ async def fetch_score_bands(
482
450
 
483
451
  score_state = EMOSCOREState(result=results.model_dump())
484
452
 
453
+ # Use the session + problem from context instead of request directly
485
454
  score_db_state = StateDB.create(
486
- database_session=session,
487
- problem_id=request.problem_id,
488
- session_id=request.session_id,
489
- parent_id=parent_state,
455
+ database_session=db_session,
456
+ problem_id=problem_db.id,
457
+ session_id=parent_state.session_id,
458
+ parent_id=parent_state.id,
490
459
  state=score_state,
491
460
  )
492
- session.add(score_db_state)
493
- session.commit()
494
- session.refresh(score_db_state)
461
+
462
+ db_session.add(score_db_state)
463
+ db_session.commit()
464
+ db_session.refresh(score_db_state)
465
+
495
466
  state_id = score_db_state.id
496
467
 
497
468
  return EMOScoreResponse(result=results, state_id=state_id)
@@ -5,65 +5,55 @@ from typing import Annotated
5
5
  import numpy as np
6
6
  import pandas as pd
7
7
  from fastapi import APIRouter, Depends, HTTPException, status
8
- from sqlmodel import Session, select
8
+ from sqlmodel import select
9
9
 
10
- from desdeo.api.db import get_session
11
10
  from desdeo.api.models import (
12
- InteractiveSessionDB,
13
11
  IntermediateSolutionRequest,
14
12
  IntermediateSolutionState,
15
- ProblemDB,
16
13
  ScoreBandsRequest,
17
14
  ScoreBandsResponse,
18
15
  SolutionReference,
19
16
  StateDB,
20
- User,
21
17
  )
22
18
  from desdeo.api.models.generic import GenericIntermediateSolutionResponse
23
- from desdeo.api.routers.user_authentication import get_current_user
24
19
  from desdeo.mcdm.nimbus import solve_intermediate_solutions
25
20
  from desdeo.problem import Problem
26
21
  from desdeo.tools import SolverResults
27
22
  from desdeo.tools.score_bands import calculate_axes_positions, cluster, order_dimensions
28
23
 
24
+ from .utils import SessionContext, get_session_context
25
+
29
26
  router = APIRouter(prefix="/method/generic")
30
27
 
31
28
 
32
29
  @router.post("/intermediate")
33
30
  def solve_intermediate(
34
31
  request: IntermediateSolutionRequest,
35
- user: Annotated[User, Depends(get_current_user)],
36
- session: Annotated[Session, Depends(get_session)],
32
+ context: Annotated[SessionContext, Depends(get_session_context)],
37
33
  ) -> GenericIntermediateSolutionResponse:
38
- """Solve intermediate solutions between given two solutions."""
39
- if request.session_id is not None:
40
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
41
- interactive_session = session.exec(statement)
42
-
43
- if interactive_session is None:
44
- raise HTTPException(
45
- status_code=status.HTTP_404_NOT_FOUND,
46
- detail=f"Could not find interactive session with id={request.session_id}.",
47
- )
48
- else:
49
- # request.session_id is None:
50
- # use active session instead
51
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
52
-
53
- interactive_session = session.exec(statement).first()
34
+ """Solve intermediate solutions between given two solutions.
35
+
36
+ Args:
37
+ request (IntermediateSolutionRequest): The request object containing parameters
38
+ for fetching results.
39
+ context (Annotated[SessionContext, Depends]): The session context.
40
+ """
41
+ db_session = context.db_session
42
+ problem_db = context.problem_db
43
+ interactive_session = context.interactive_session
44
+ parent_state = context.parent_state
54
45
 
55
46
  # query both reference solutions' variable values
56
- # stored as lit of tuples, first element of each tuple are variables values, second are objective function values
57
47
  var_and_obj_values_of_references: list[tuple[dict, dict]] = []
58
48
  reference_states = []
49
+
59
50
  for solution_info in [request.reference_solution_1, request.reference_solution_2]:
60
- solution_state = session.exec(select(StateDB).where(StateDB.id == solution_info.state_id)).first()
51
+ solution_state = db_session.exec(select(StateDB).where(StateDB.id == solution_info.state_id)).first()
61
52
 
62
53
  if solution_state is None:
63
- # no StateDB found with the given id
64
54
  raise HTTPException(
65
55
  status_code=status.HTTP_404_NOT_FOUND,
66
- detail=f"Could not find a state with the given id{solution_state.state_id}.",
56
+ detail=f"Could not find a state with id={solution_info.state_id}.",
67
57
  )
68
58
 
69
59
  reference_states.append(solution_state)
@@ -71,7 +61,6 @@ def solve_intermediate(
71
61
  try:
72
62
  _var_values = solution_state.state.result_variable_values
73
63
  var_values = _var_values[solution_info.solution_index]
74
-
75
64
  except IndexError as exc:
76
65
  raise HTTPException(
77
66
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -83,7 +72,6 @@ def solve_intermediate(
83
72
  try:
84
73
  _obj_values = solution_state.state.result_objective_values
85
74
  obj_values = _obj_values[solution_info.solution_index]
86
-
87
75
  except IndexError as exc:
88
76
  raise HTTPException(
89
77
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -94,10 +82,7 @@ def solve_intermediate(
94
82
 
95
83
  var_and_obj_values_of_references.append((var_values, obj_values))
96
84
 
97
- # fetch the problem from the DB
98
- statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
99
- problem_db = session.exec(statement).first()
100
-
85
+ # Problem is now already loaded via context
101
86
  if problem_db is None:
102
87
  raise HTTPException(
103
88
  status_code=status.HTTP_404_NOT_FOUND,
@@ -116,26 +101,9 @@ def solve_intermediate(
116
101
  solver_options=request.solver_options,
117
102
  )
118
103
 
119
- # fetch parent state
120
- if request.parent_state_id is None:
121
- # parent state is assumed to be the last state added to the session.
122
- parent_state = (
123
- interactive_session.states[-1]
124
- if (interactive_session is not None and len(interactive_session.states) > 0)
125
- else None
126
- )
127
-
128
- else:
129
- # request.parent_state_id is not None
130
- statement = session.select(StateDB).where(StateDB.id == request.parent_state_id)
131
- parent_state = session.exec(statement).first()
132
-
133
- if parent_state is None:
134
- raise HTTPException(
135
- status_code=status.HTTP_404_NOT_FOUND,
136
- detail=f"Could not find state with id={request.parent_state_id}",
137
- )
138
-
104
+ # --------------------------------------
105
+ # parent_state is already loaded in context
106
+ # --------------------------------------
139
107
  intermediate_state = IntermediateSolutionState(
140
108
  scalarization_options=request.scalarization_options,
141
109
  context=request.context,
@@ -149,16 +117,16 @@ def solve_intermediate(
149
117
 
150
118
  # create DB state and add it to the DB
151
119
  state = StateDB.create(
152
- database_session=session,
120
+ database_session=db_session,
153
121
  problem_id=problem_db.id,
154
122
  session_id=interactive_session.id if interactive_session is not None else None,
155
123
  parent_id=parent_state.id if parent_state is not None else None,
156
124
  state=intermediate_state,
157
125
  )
158
126
 
159
- session.add(state)
160
- session.commit()
161
- session.refresh(state)
127
+ db_session.add(state)
128
+ db_session.commit()
129
+ db_session.refresh(state)
162
130
 
163
131
  return GenericIntermediateSolutionResponse(
164
132
  state_id=state.id,