desdeo 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126):
  1. desdeo/adm/ADMAfsar.py +551 -0
  2. desdeo/adm/ADMChen.py +414 -0
  3. desdeo/adm/BaseADM.py +119 -0
  4. desdeo/adm/__init__.py +11 -0
  5. desdeo/api/__init__.py +6 -6
  6. desdeo/api/app.py +38 -28
  7. desdeo/api/config.py +65 -44
  8. desdeo/api/config.toml +23 -12
  9. desdeo/api/db.py +10 -8
  10. desdeo/api/db_init.py +12 -6
  11. desdeo/api/models/__init__.py +220 -20
  12. desdeo/api/models/archive.py +16 -27
  13. desdeo/api/models/emo.py +128 -0
  14. desdeo/api/models/enautilus.py +69 -0
  15. desdeo/api/models/gdm/gdm_aggregate.py +139 -0
  16. desdeo/api/models/gdm/gdm_base.py +69 -0
  17. desdeo/api/models/gdm/gdm_score_bands.py +114 -0
  18. desdeo/api/models/gdm/gnimbus.py +138 -0
  19. desdeo/api/models/generic.py +104 -0
  20. desdeo/api/models/generic_states.py +401 -0
  21. desdeo/api/models/nimbus.py +158 -0
  22. desdeo/api/models/preference.py +44 -6
  23. desdeo/api/models/problem.py +274 -64
  24. desdeo/api/models/session.py +4 -1
  25. desdeo/api/models/state.py +419 -52
  26. desdeo/api/models/user.py +7 -6
  27. desdeo/api/models/utopia.py +25 -0
  28. desdeo/api/routers/_EMO.backup +309 -0
  29. desdeo/api/routers/_NIMBUS.py +6 -3
  30. desdeo/api/routers/emo.py +497 -0
  31. desdeo/api/routers/enautilus.py +237 -0
  32. desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
  33. desdeo/api/routers/gdm/gdm_base.py +420 -0
  34. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
  35. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
  36. desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
  37. desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
  38. desdeo/api/routers/generic.py +233 -0
  39. desdeo/api/routers/nimbus.py +705 -0
  40. desdeo/api/routers/problem.py +201 -4
  41. desdeo/api/routers/reference_point_method.py +20 -44
  42. desdeo/api/routers/session.py +50 -26
  43. desdeo/api/routers/user_authentication.py +180 -26
  44. desdeo/api/routers/utils.py +187 -0
  45. desdeo/api/routers/utopia.py +230 -0
  46. desdeo/api/schema.py +10 -4
  47. desdeo/api/tests/conftest.py +94 -2
  48. desdeo/api/tests/test_enautilus.py +330 -0
  49. desdeo/api/tests/test_models.py +550 -72
  50. desdeo/api/tests/test_routes.py +902 -43
  51. desdeo/api/utils/_database.py +263 -0
  52. desdeo/api/utils/database.py +28 -266
  53. desdeo/api/utils/emo_database.py +40 -0
  54. desdeo/core.py +7 -0
  55. desdeo/emo/__init__.py +154 -24
  56. desdeo/emo/hooks/archivers.py +18 -2
  57. desdeo/emo/methods/EAs.py +128 -5
  58. desdeo/emo/methods/bases.py +9 -56
  59. desdeo/emo/methods/templates.py +111 -0
  60. desdeo/emo/operators/crossover.py +544 -42
  61. desdeo/emo/operators/evaluator.py +10 -14
  62. desdeo/emo/operators/generator.py +127 -24
  63. desdeo/emo/operators/mutation.py +212 -41
  64. desdeo/emo/operators/scalar_selection.py +202 -0
  65. desdeo/emo/operators/selection.py +956 -214
  66. desdeo/emo/operators/termination.py +124 -16
  67. desdeo/emo/options/__init__.py +108 -0
  68. desdeo/emo/options/algorithms.py +435 -0
  69. desdeo/emo/options/crossover.py +164 -0
  70. desdeo/emo/options/generator.py +131 -0
  71. desdeo/emo/options/mutation.py +260 -0
  72. desdeo/emo/options/repair.py +61 -0
  73. desdeo/emo/options/scalar_selection.py +66 -0
  74. desdeo/emo/options/selection.py +127 -0
  75. desdeo/emo/options/templates.py +383 -0
  76. desdeo/emo/options/termination.py +143 -0
  77. desdeo/gdm/__init__.py +22 -0
  78. desdeo/gdm/gdmtools.py +45 -0
  79. desdeo/gdm/score_bands.py +114 -0
  80. desdeo/gdm/voting_rules.py +50 -0
  81. desdeo/mcdm/__init__.py +23 -1
  82. desdeo/mcdm/enautilus.py +338 -0
  83. desdeo/mcdm/gnimbus.py +484 -0
  84. desdeo/mcdm/nautilus_navigator.py +7 -6
  85. desdeo/mcdm/reference_point_method.py +70 -0
  86. desdeo/problem/__init__.py +5 -1
  87. desdeo/problem/external/__init__.py +18 -0
  88. desdeo/problem/external/core.py +356 -0
  89. desdeo/problem/external/pymoo_provider.py +266 -0
  90. desdeo/problem/external/runtime.py +44 -0
  91. desdeo/problem/infix_parser.py +2 -2
  92. desdeo/problem/pyomo_evaluator.py +25 -6
  93. desdeo/problem/schema.py +69 -48
  94. desdeo/problem/simulator_evaluator.py +65 -15
  95. desdeo/problem/testproblems/__init__.py +26 -11
  96. desdeo/problem/testproblems/benchmarks_server.py +120 -0
  97. desdeo/problem/testproblems/cake_problem.py +185 -0
  98. desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
  99. desdeo/problem/testproblems/forest_problem.py +77 -69
  100. desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
  101. desdeo/problem/testproblems/{river_pollution_problem.py → river_pollution_problems.py} +28 -22
  102. desdeo/problem/testproblems/single_objective.py +289 -0
  103. desdeo/problem/testproblems/zdt_problem.py +4 -1
  104. desdeo/tools/__init__.py +39 -21
  105. desdeo/tools/desc_gen.py +22 -0
  106. desdeo/tools/generics.py +22 -2
  107. desdeo/tools/group_scalarization.py +3090 -0
  108. desdeo/tools/indicators_binary.py +107 -1
  109. desdeo/tools/indicators_unary.py +3 -16
  110. desdeo/tools/message.py +33 -2
  111. desdeo/tools/non_dominated_sorting.py +4 -3
  112. desdeo/tools/patterns.py +9 -7
  113. desdeo/tools/pyomo_solver_interfaces.py +48 -35
  114. desdeo/tools/reference_vectors.py +118 -351
  115. desdeo/tools/scalarization.py +340 -1413
  116. desdeo/tools/score_bands.py +491 -328
  117. desdeo/tools/utils.py +117 -49
  118. desdeo/tools/visualizations.py +67 -0
  119. desdeo/utopia_stuff/utopia_problem.py +1 -1
  120. desdeo/utopia_stuff/utopia_problem_old.py +1 -1
  121. {desdeo-2.0.0.dist-info → desdeo-2.1.0.dist-info}/METADATA +46 -28
  122. desdeo-2.1.0.dist-info/RECORD +180 -0
  123. {desdeo-2.0.0.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
  124. desdeo-2.0.0.dist-info/RECORD +0 -120
  125. /desdeo/api/utils/{logger.py → _logger.py} +0 -0
  126. {desdeo-2.0.0.dist-info → desdeo-2.1.0.dist-info/licenses}/LICENSE +0 -0
@@ -1,18 +1,63 @@
1
1
  """Defines end-points to access and manage problems."""
2
2
 
3
+ import json
3
4
  from typing import Annotated
4
5
 
5
- from fastapi import APIRouter, Depends, HTTPException, status
6
- from sqlmodel import Session
6
+ from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
7
+ from fastapi.responses import JSONResponse
8
+ from sqlmodel import Session, select
7
9
 
8
10
  from desdeo.api.db import get_session
9
- from desdeo.api.models import ProblemDB, ProblemGetRequest, ProblemInfo, ProblemInfoSmall, User, UserRole
11
+ from desdeo.api.models import (
12
+ ForestProblemMetaData,
13
+ ProblemDB,
14
+ ProblemGetRequest,
15
+ ProblemInfo,
16
+ ProblemInfoSmall,
17
+ ProblemMetaDataDB,
18
+ ProblemMetaDataGetRequest,
19
+ ProblemSelectSolverRequest,
20
+ RepresentativeNonDominatedSolutions,
21
+ SolverSelectionMetadata,
22
+ User,
23
+ UserRole,
24
+ )
10
25
  from desdeo.api.routers.user_authentication import get_current_user
11
26
  from desdeo.problem import Problem
27
+ from desdeo.tools.utils import available_solvers
12
28
 
13
29
  router = APIRouter(prefix="/problem")
14
30
 
15
31
 
32
def check_solver(problem_db: ProblemDB):
    """Return the constructor of the solver preferred for a problem, if one is set.

    The problem's metadata is searched for entries whose ``metadata_type`` is
    ``"solver_selection_metadata"``. If several exist, the last one listed in
    ``all_metadata`` is used. Its ``solver_string_representation`` is then
    looked up in ``available_solvers``.

    Args:
        problem_db (ProblemDB): the problem whose metadata is inspected.

    Returns:
        The ``"constructor"`` entry of ``available_solvers`` for the selected
        solver, or ``None`` if the problem has no metadata or no solver
        selection metadata.

    Raises:
        KeyError: if the stored solver string is not a key of ``available_solvers``.
    """
    metadata: ProblemMetaDataDB | None = problem_db.problem_metadata
    if metadata is None:
        return None

    # Keep the last matching entry; iterating forward works for any iterable.
    solver_metadata = None
    for entry in metadata.all_metadata:
        if entry.metadata_type == "solver_selection_metadata":
            solver_metadata = entry

    if solver_metadata is None:
        return None

    return available_solvers[solver_metadata.solver_string_representation]["constructor"]
52
+
53
+
54
# This is needed, because otherwise fields ending in an underscore fail to parse.
async def parse_problem_json(request: Request) -> Problem:
    """Parse the raw request body into a `Problem`, passing ``by_name=True``.

    Passing ``by_name=True`` to ``Problem.model_validate`` is what allows
    fields whose names end in an underscore to be parsed (see the comment
    above this function).

    Args:
        request (Request): the incoming request whose body holds the problem JSON.

    Raises:
        HTTPException: 400 if the body is not valid JSON.
        HTTPException: 422 if the JSON does not describe a valid `Problem`.

    Returns:
        Problem: the validated problem.
    """
    try:
        data: dict = await request.json()
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON.") from e

    try:
        return Problem.model_validate(data, by_name=True)
    except ValueError as e:
        # pydantic's ValidationError subclasses ValueError. If it escaped a
        # dependency unhandled, the client would receive a 500 instead of a 4xx.
        raise HTTPException(status_code=422, detail=str(e)) from e
59
+
60
+
16
61
  @router.get("/all")
17
62
  def get_problems(user: Annotated[User, Depends(get_current_user)]) -> list[ProblemInfoSmall]:
18
63
  """Get information on all the current user's problems.
@@ -71,7 +116,7 @@ def get_problem(
71
116
 
72
117
  @router.post("/add")
73
118
  def add_problem(
74
- request: Problem,
119
+ request: Annotated[Problem, Depends(parse_problem_json)],
75
120
  user: Annotated[User, Depends(get_current_user)],
76
121
  session: Annotated[Session, Depends(get_session)],
77
122
  ) -> ProblemInfo:
@@ -108,3 +153,155 @@ def add_problem(
108
153
  session.refresh(problem_db)
109
154
 
110
155
  return problem_db
156
+
157
+
158
@router.post("/add_json")
def add_problem_json(
    json_file: UploadFile,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> ProblemInfo:
    """Add a problem to the database from an uploaded JSON definition.

    Args:
        json_file (UploadFile): an uploaded file whose content is a JSON
            definition of a `Problem`.
        user (Annotated[User, Depends]): the user for whom the problem is added.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: 400 if the uploaded file is empty.
        HTTPException: 400 if the file content is not valid JSON.
        HTTPException: 422 if the JSON does not describe a valid `Problem`.

    Returns:
        ProblemInfo: a description of the added problem.
    """
    raw = json_file.file.read()

    if not raw:
        raise HTTPException(400, "Empty upload.")

    try:
        # Parsed only to tell malformed JSON (400) apart from an invalid problem (422).
        # ValueError also covers UnicodeDecodeError raised for undecodable bytes.
        json.loads(raw)
    except ValueError as e:
        raise HTTPException(400, "Invalid JSON.") from e

    try:
        problem = Problem.model_validate_json(raw, by_name=True)
    except ValueError as e:
        # pydantic's ValidationError subclasses ValueError; without this it surfaces as a 500.
        raise HTTPException(422, str(e)) from e

    problem_db = ProblemDB.from_problem(problem, user=user)

    session.add(problem_db)
    session.commit()
    session.refresh(problem_db)

    return problem_db
197
+
198
+
199
@router.post("/get_metadata")
def get_metadata(
    request: ProblemMetaDataGetRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> list[ForestProblemMetaData | RepresentativeNonDominatedSolutions | SolverSelectionMetadata]:
    """Fetch all metadata of a requested type defined for a problem.

    The possible metadata types are listed in the "Problem Metadata" section
    of DESDEO/desdeo/api/models/problem.py.

    Args:
        request (ProblemMetaDataGetRequest): the id of the problem and the
            requested metadata type.
        user (Annotated[User, Depends]): the current user.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: 404 if no problem with ``request.problem_id`` exists.

    Returns:
        list[ForestProblemMetaData | RepresentativeNonDominatedSolutions | SolverSelectionMetadata]:
            every metadata entry of the problem whose ``metadata_type`` equals
            ``request.metadata_type``. The list is empty if the problem has no
            metadata or no entry matches.
    """
    # NOTE(review): unlike `select_solver`, the problem is not checked to belong
    # to `user` (which is otherwise unused here). Confirm this is intended.
    statement = select(ProblemDB).where(ProblemDB.id == request.problem_id)
    problem_from_db = session.exec(statement).first()
    if problem_from_db is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with ID {request.problem_id} not found!"
        )

    problem_metadata = problem_from_db.problem_metadata
    if problem_metadata is None:
        # No metadata has been defined for the problem.
        return []

    # Metadata exists: keep only the entries of the requested type.
    return [entry for entry in problem_metadata.all_metadata if entry.metadata_type == request.metadata_type]
236
+
237
+
238
@router.get("/assign/solver", response_model=list[str])
def get_available_solvers() -> list[str]:
    """List the names under which solvers can be assigned to a problem.

    Returns:
        list[str]: the keys of ``available_solvers``, in their iteration order.
    """
    solver_names = [*available_solvers]
    return solver_names
242
+
243
@router.post("/assign_solver")
def select_solver(
    request: ProblemSelectSolverRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> JSONResponse:
    """Assign a specific solver to a problem, replacing any earlier assignment.

    Args:
        request (ProblemSelectSolverRequest): the id of the problem and the
            string representation of the solver (a key of ``available_solvers``).
        user (Annotated[User, Depends]): the user that is logged in; must own the problem.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: 404 if the solver is unknown or the problem does not exist.
        HTTPException: 401 if the problem does not belong to the user.

    Returns:
        JSONResponse: a simple confirmation.
    """
    if request.solver_string_representation not in available_solvers:
        raise HTTPException(
            detail=f"Solver of unknown type: {request.solver_string_representation}",
            status_code=status.HTTP_404_NOT_FOUND,
        )

    # Get the problem
    problem_db = session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first()
    if problem_db is None:
        raise HTTPException(
            detail=f"No problem with ID {request.problem_id}!",
            status_code=status.HTTP_404_NOT_FOUND,
        )

    # Only the owner of the problem may change its solver.
    if user.id != problem_db.user_id:
        raise HTTPException(detail="Unauthorized user!", status_code=status.HTTP_401_UNAUTHORIZED)

    problem_metadata = problem_db.problem_metadata
    if problem_metadata is None:
        # There's no metadata for this problem yet; create an empty container for it.
        problem_metadata = ProblemMetaDataDB(problem_id=problem_db.id, problem=problem_db)
        session.add(problem_metadata)
        session.commit()
        session.refresh(problem_metadata)

    # Remove every earlier selection (not just the last one) so that at most one remains.
    previous_selections = list(problem_metadata.solver_selection_metadata or [])
    if previous_selections:
        for old_selection in previous_selections:
            session.delete(old_selection)
        session.commit()

    solver_selection_metadata = SolverSelectionMetadata(
        metadata_id=problem_metadata.id,
        solver_string_representation=request.solver_string_representation,
        metadata_instance=problem_metadata,
    )

    session.add(solver_selection_metadata)
    session.commit()
    session.refresh(solver_selection_metadata)

    problem_metadata.solver_selection_metadata.append(solver_selection_metadata)
    session.add(problem_metadata)
    session.commit()
    session.refresh(problem_metadata)

    return JSONResponse(content={"message": "OK"}, status_code=status.HTTP_200_OK)
@@ -2,8 +2,8 @@
2
2
 
3
3
  from typing import Annotated
4
4
 
5
- from fastapi import APIRouter, Depends, HTTPException, status
6
- from sqlmodel import Session, select
5
+ from fastapi import APIRouter, Depends
6
+ from sqlmodel import Session
7
7
 
8
8
  from desdeo.api.db import get_session
9
9
  from desdeo.api.models import (
@@ -15,11 +15,14 @@ from desdeo.api.models import (
15
15
  StateDB,
16
16
  User,
17
17
  )
18
+ from desdeo.api.routers.problem import check_solver
18
19
  from desdeo.api.routers.user_authentication import get_current_user
19
20
  from desdeo.mcdm import rpm_solve_solutions
20
21
  from desdeo.problem import Problem
21
22
  from desdeo.tools import SolverResults
22
23
 
24
+ from .utils import fetch_interactive_session, fetch_parent_state, fetch_user_problem
25
+
23
26
  router = APIRouter(prefix="/method/rpm")
24
27
 
25
28
 
@@ -29,32 +32,24 @@ def solve_solutions(
29
32
  user: Annotated[User, Depends(get_current_user)],
30
33
  session: Annotated[Session, Depends(get_session)],
31
34
  ) -> RPMState:
32
- """."""
33
-
34
- if request.session_id is not None:
35
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == request.session_id)
36
- interactive_session = session.exec(statement)
35
+ """Runs an iteration of the reference point method.
37
36
 
38
- if interactive_session is None:
39
- raise HTTPException(
40
- status_code=status.HTTP_404_NOT_FOUND,
41
- detail=f"Could not find interactive session with id={request.session_id}.",
42
- )
43
- else:
44
- # request.session_id is None:
45
- # use active session instead
46
- statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
37
+ Args:
38
+ request (RPMSolveRequest): a request with the needed information to run the method.
39
+ user (Annotated[User, Depends): the current user.
40
+ session (Annotated[Session, Depends): the current database session.
47
41
 
48
- interactive_session = session.exec(statement).first()
42
+ Returns:
43
+ RPMState: a state with information on the results of iterating the reference point method
44
+ once.
45
+ """
46
+ # fetch interactive session, parent state, and ProblemDB
47
+ interactive_session: InteractiveSessionDB = fetch_interactive_session(user, request, session)
48
+ parent_state = fetch_parent_state(user, request, session, interactive_session)
49
49
 
50
- # fetch the problem from the DB
51
- statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
52
- problem_db = session.exec(statement).first()
50
+ problem_db: ProblemDB = fetch_user_problem(user, request, session)
53
51
 
54
- if problem_db is None:
55
- raise HTTPException(
56
- status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
57
- )
52
+ solver = check_solver(problem_db=problem_db)
58
53
 
59
54
  problem = Problem.from_problemdb(problem_db)
60
55
 
@@ -63,7 +58,7 @@ def solve_solutions(
63
58
  problem,
64
59
  request.preference.aspiration_levels,
65
60
  request.scalarization_options,
66
- request.solver,
61
+ solver,
67
62
  request.solver_options,
68
63
  )
69
64
 
@@ -74,25 +69,6 @@ def solve_solutions(
74
69
  session.commit()
75
70
  session.refresh(preference_db)
76
71
 
77
- # fetch parent state
78
- if request.parent_state_id is None:
79
- # parent state is assumed to be the last sate added to the session.
80
- parent_state = (
81
- interactive_session.states[-1]
82
- if (interactive_session is not None and len(interactive_session.states) > 0)
83
- else None
84
- )
85
-
86
- else:
87
- # request.parent_state_id is not None
88
- statement = session.select(StateDB).where(StateDB.id == request.parent_state_id)
89
- parent_state = session.exec(statement).first()
90
-
91
- if parent_state is None:
92
- raise HTTPException(
93
- status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
94
- )
95
-
96
72
  # create state and add to DB
97
73
  rpm_state = RPMState(
98
74
  scalarization_options=request.scalarization_options,
@@ -5,7 +5,7 @@ from typing import Annotated
5
5
  from fastapi import APIRouter, Depends, HTTPException, status
6
6
  from sqlmodel import Session, select
7
7
 
8
- from desdeo.api.db import get_session
8
+ from desdeo.api.db import get_session as get_db_session
9
9
  from desdeo.api.models import (
10
10
  CreateSessionRequest,
11
11
  GetSessionRequest,
@@ -14,6 +14,7 @@ from desdeo.api.models import (
14
14
  User,
15
15
  )
16
16
  from desdeo.api.routers.user_authentication import get_current_user
17
+ from desdeo.api.routers.utils import fetch_interactive_session
17
18
 
18
19
  router = APIRouter(prefix="/session")
19
20
 
@@ -22,7 +23,7 @@ router = APIRouter(prefix="/session")
22
23
  def create_new_session(
23
24
  request: CreateSessionRequest,
24
25
  user: Annotated[User, Depends(get_current_user)],
25
- session: Annotated[Session, Depends(get_session)],
26
+ session: Annotated[Session, Depends(get_db_session)],
26
27
  ) -> InteractiveSessionInfo:
27
28
  """."""
28
29
  interactive_session = InteractiveSessionDB(user_id=user.id, info=request.info)
@@ -40,37 +41,60 @@ def create_new_session(
40
41
  return interactive_session
41
42
 
42
43
 
43
@router.get("/get/{session_id}")
def get_session(
    session_id: int,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_db_session)],
) -> InteractiveSessionInfo:
    """Return the interactive session with the given id for the current user.

    The lookup is delegated to ``fetch_interactive_session``. It receives the
    current user, a ``GetSessionRequest`` built from ``session_id``, and the
    database session.

    Args:
        session_id (int): id of the interactive session, taken from the path.
        user (Annotated[User, Depends]): the current user.
        session (Annotated[Session, Depends]): the database session.

    Returns:
        InteractiveSessionInfo: info on the requested interactive session.
        ``fetch_interactive_session`` is expected to raise a 404 when the
        session cannot be found (TODO confirm).
    """
    lookup = GetSessionRequest(session_id=session_id)
    found = fetch_interactive_session(user=user, request=lookup, session=session)
    return found
59
@router.get("/get_all", status_code=status.HTTP_200_OK)
def get_all_sessions(
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_db_session)],
) -> list[InteractiveSessionInfo]:
    """List every interactive session whose ``user_id`` is the current user's id.

    Args:
        user (Annotated[User, Depends]): the current user.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: 404 if the user has no interactive sessions.

    Returns:
        list[InteractiveSessionInfo]: the user's interactive sessions.
    """
    # NOTE(review): an empty list is arguably a valid result for a collection
    # endpoint; the 404 is kept here for compatibility with existing clients.
    owned_sessions = session.exec(
        select(InteractiveSessionDB).where(InteractiveSessionDB.user_id == user.id)
    ).all()

    if owned_sessions:
        return owned_sessions

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="No interactive sessions found for the user.",
    )
75
+
76
+
77
@router.delete("/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_session(
    session_id: int,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_db_session)],
) -> None:
    """Delete one of the current user's interactive sessions.

    Whether the session's states are removed along with it depends on the
    cascade configuration of ``InteractiveSessionDB``, which is not visible
    here (TODO confirm).

    Args:
        session_id (int): id of the interactive session, taken from the path.
        user (Annotated[User, Depends]): the current user.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: 404 from ``fetch_interactive_session`` if the session is not found.
        HTTPException: 500 if deleting or committing fails; the transaction is rolled back first.
    """
    target_session = fetch_interactive_session(
        user=user,
        request=GetSessionRequest(session_id=session_id),
        session=session,
    )  # raises 404 if not found

    try:
        session.delete(target_session)
        session.commit()
    except Exception as exc:
        # Undo the partial transaction before reporting the failure to the client.
        session.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete interactive session.",
        ) from exc