desdeo 2.0.0__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. desdeo/adm/ADMAfsar.py +551 -0
  2. desdeo/adm/ADMChen.py +414 -0
  3. desdeo/adm/BaseADM.py +119 -0
  4. desdeo/adm/__init__.py +11 -0
  5. desdeo/api/__init__.py +6 -6
  6. desdeo/api/app.py +38 -28
  7. desdeo/api/config.py +65 -44
  8. desdeo/api/config.toml +23 -12
  9. desdeo/api/db.py +10 -8
  10. desdeo/api/db_init.py +12 -6
  11. desdeo/api/models/__init__.py +220 -20
  12. desdeo/api/models/archive.py +16 -27
  13. desdeo/api/models/emo.py +128 -0
  14. desdeo/api/models/enautilus.py +69 -0
  15. desdeo/api/models/gdm/gdm_aggregate.py +139 -0
  16. desdeo/api/models/gdm/gdm_base.py +69 -0
  17. desdeo/api/models/gdm/gdm_score_bands.py +114 -0
  18. desdeo/api/models/gdm/gnimbus.py +138 -0
  19. desdeo/api/models/generic.py +104 -0
  20. desdeo/api/models/generic_states.py +401 -0
  21. desdeo/api/models/nimbus.py +158 -0
  22. desdeo/api/models/preference.py +44 -6
  23. desdeo/api/models/problem.py +274 -64
  24. desdeo/api/models/session.py +4 -1
  25. desdeo/api/models/state.py +419 -52
  26. desdeo/api/models/user.py +7 -6
  27. desdeo/api/models/utopia.py +25 -0
  28. desdeo/api/routers/_EMO.backup +309 -0
  29. desdeo/api/routers/_NIMBUS.py +6 -3
  30. desdeo/api/routers/emo.py +497 -0
  31. desdeo/api/routers/enautilus.py +237 -0
  32. desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
  33. desdeo/api/routers/gdm/gdm_base.py +420 -0
  34. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
  35. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
  36. desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
  37. desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
  38. desdeo/api/routers/generic.py +233 -0
  39. desdeo/api/routers/nimbus.py +705 -0
  40. desdeo/api/routers/problem.py +201 -4
  41. desdeo/api/routers/reference_point_method.py +20 -44
  42. desdeo/api/routers/session.py +50 -26
  43. desdeo/api/routers/user_authentication.py +180 -26
  44. desdeo/api/routers/utils.py +187 -0
  45. desdeo/api/routers/utopia.py +230 -0
  46. desdeo/api/schema.py +10 -4
  47. desdeo/api/tests/conftest.py +94 -2
  48. desdeo/api/tests/test_enautilus.py +330 -0
  49. desdeo/api/tests/test_models.py +550 -72
  50. desdeo/api/tests/test_routes.py +902 -43
  51. desdeo/api/utils/_database.py +263 -0
  52. desdeo/api/utils/database.py +28 -266
  53. desdeo/api/utils/emo_database.py +40 -0
  54. desdeo/core.py +7 -0
  55. desdeo/emo/__init__.py +154 -24
  56. desdeo/emo/hooks/archivers.py +18 -2
  57. desdeo/emo/methods/EAs.py +128 -5
  58. desdeo/emo/methods/bases.py +9 -56
  59. desdeo/emo/methods/templates.py +111 -0
  60. desdeo/emo/operators/crossover.py +544 -42
  61. desdeo/emo/operators/evaluator.py +10 -14
  62. desdeo/emo/operators/generator.py +127 -24
  63. desdeo/emo/operators/mutation.py +212 -41
  64. desdeo/emo/operators/scalar_selection.py +202 -0
  65. desdeo/emo/operators/selection.py +956 -214
  66. desdeo/emo/operators/termination.py +124 -16
  67. desdeo/emo/options/__init__.py +108 -0
  68. desdeo/emo/options/algorithms.py +435 -0
  69. desdeo/emo/options/crossover.py +164 -0
  70. desdeo/emo/options/generator.py +131 -0
  71. desdeo/emo/options/mutation.py +260 -0
  72. desdeo/emo/options/repair.py +61 -0
  73. desdeo/emo/options/scalar_selection.py +66 -0
  74. desdeo/emo/options/selection.py +127 -0
  75. desdeo/emo/options/templates.py +383 -0
  76. desdeo/emo/options/termination.py +143 -0
  77. desdeo/gdm/__init__.py +22 -0
  78. desdeo/gdm/gdmtools.py +45 -0
  79. desdeo/gdm/score_bands.py +114 -0
  80. desdeo/gdm/voting_rules.py +50 -0
  81. desdeo/mcdm/__init__.py +23 -1
  82. desdeo/mcdm/enautilus.py +338 -0
  83. desdeo/mcdm/gnimbus.py +484 -0
  84. desdeo/mcdm/nautilus_navigator.py +7 -6
  85. desdeo/mcdm/reference_point_method.py +70 -0
  86. desdeo/problem/__init__.py +16 -11
  87. desdeo/problem/evaluator.py +4 -5
  88. desdeo/problem/external/__init__.py +18 -0
  89. desdeo/problem/external/core.py +356 -0
  90. desdeo/problem/external/pymoo_provider.py +266 -0
  91. desdeo/problem/external/runtime.py +44 -0
  92. desdeo/problem/gurobipy_evaluator.py +37 -12
  93. desdeo/problem/infix_parser.py +1 -16
  94. desdeo/problem/json_parser.py +7 -11
  95. desdeo/problem/pyomo_evaluator.py +25 -6
  96. desdeo/problem/schema.py +73 -55
  97. desdeo/problem/simulator_evaluator.py +65 -15
  98. desdeo/problem/testproblems/__init__.py +26 -11
  99. desdeo/problem/testproblems/benchmarks_server.py +120 -0
  100. desdeo/problem/testproblems/cake_problem.py +185 -0
  101. desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
  102. desdeo/problem/testproblems/forest_problem.py +77 -69
  103. desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
  104. desdeo/problem/testproblems/{river_pollution_problem.py → river_pollution_problems.py} +28 -22
  105. desdeo/problem/testproblems/single_objective.py +289 -0
  106. desdeo/problem/testproblems/zdt_problem.py +4 -1
  107. desdeo/problem/utils.py +1 -1
  108. desdeo/tools/__init__.py +39 -21
  109. desdeo/tools/desc_gen.py +22 -0
  110. desdeo/tools/generics.py +22 -2
  111. desdeo/tools/group_scalarization.py +3090 -0
  112. desdeo/tools/indicators_binary.py +107 -1
  113. desdeo/tools/indicators_unary.py +3 -16
  114. desdeo/tools/message.py +33 -2
  115. desdeo/tools/non_dominated_sorting.py +4 -3
  116. desdeo/tools/patterns.py +9 -7
  117. desdeo/tools/pyomo_solver_interfaces.py +49 -36
  118. desdeo/tools/reference_vectors.py +118 -351
  119. desdeo/tools/scalarization.py +340 -1413
  120. desdeo/tools/score_bands.py +491 -328
  121. desdeo/tools/utils.py +117 -49
  122. desdeo/tools/visualizations.py +67 -0
  123. desdeo/utopia_stuff/utopia_problem.py +1 -1
  124. desdeo/utopia_stuff/utopia_problem_old.py +1 -1
  125. {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info}/METADATA +47 -30
  126. desdeo-2.1.1.dist-info/RECORD +180 -0
  127. {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info}/WHEEL +1 -1
  128. desdeo-2.0.0.dist-info/RECORD +0 -120
  129. /desdeo/api/utils/{logger.py → _logger.py} +0 -0
  130. {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info/licenses}/LICENSE +0 -0
@@ -0,0 +1,237 @@
1
+ """Defines end-points to access functionalities related to the E-NAUTILUS method."""
2
+
3
+ from typing import Annotated
4
+
5
+ import numpy as np
6
+ import polars as pl
7
+ from fastapi import APIRouter, Depends, HTTPException, status
8
+ from sqlmodel import Session, select
9
+
10
+ from desdeo.api.db import get_session
11
+ from desdeo.api.models import (
12
+ ENautilusRepresentativeSolutionsResponse,
13
+ ENautilusState,
14
+ ENautilusStateResponse,
15
+ ENautilusStepRequest,
16
+ ENautilusStepResponse,
17
+ ProblemDB,
18
+ RepresentativeNonDominatedSolutions,
19
+ StateDB,
20
+ )
21
+ from desdeo.mcdm import ENautilusResult, enautilus_get_representative_solutions, enautilus_step
22
+ from desdeo.problem import Problem
23
+
24
+ from .utils import (
25
+ SessionContext,
26
+ get_session_context,
27
+ )
28
+
29
router = APIRouter(prefix="/method/enautilus")


@router.post("/step")
def step(
    request: ENautilusStepRequest,
    context: Annotated[SessionContext, Depends(get_session_context)],
) -> ENautilusStepResponse:
    """Run one iteration of the E-NAUTILUS method and store the resulting state.

    On the first iteration (``request.current_iteration == 0``) the starting point is the
    nadir of the referenced representative solutions and every representative point is
    treated as reachable. On later iterations both are taken from the request.

    Args:
        request (ENautilusStepRequest): parameters of the step, including the id of the
            representative non-dominated solutions being navigated.
        context (SessionContext): database session, problem, interactive session and
            parent state resolved for this request.

    Raises:
        HTTPException: 404 when the requested representative solutions cannot be found.

    Returns:
        ENautilusStepResponse: id of the newly stored state and the results of the iteration.
    """
    db_session = context.db_session
    problem_db = context.problem_db
    problem = Problem.from_problemdb(problem_db)
    interactive_session = context.interactive_session
    parent_state = context.parent_state

    solutions_query = select(RepresentativeNonDominatedSolutions).where(
        RepresentativeNonDominatedSolutions.id == request.representative_solutions_id
    )
    representative_solutions = db_session.exec(solutions_query).first()
    if representative_solutions is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=(
                "Could not find the requested representative solutions for the problem with "
                f"id={request.representative_solutions_id}."
            ),
        )

    if request.current_iteration != 0:
        # Later iterations continue from the point chosen in the previous step.
        selected_point = request.selected_point
        reachable_point_indices = request.reachable_point_indices
    else:
        # First iteration: start from the nadir and treat every representative point as reachable.
        # The nadir is expected in true values, so maximized objectives are multiplied by -1
        # (the '<symbol>_min' columns appear to hold values in minimization form).
        solution_data = representative_solutions.solution_data
        selected_point = {}
        for objective in problem.objectives:
            sign = -1 if objective.maximize else 1
            selected_point[f"{objective.symbol}"] = sign * np.max(solution_data[f"{objective.symbol}_min"])
        point_count = len(solution_data[problem.objectives[0].symbol])
        reachable_point_indices = list(range(point_count))

    results: ENautilusResult = enautilus_step(
        problem=problem,
        non_dominated_points=representative_solutions.solution_data,
        current_iteration=request.current_iteration,
        iterations_left=request.iterations_left,
        selected_point=selected_point,
        number_of_intermediate_points=request.number_of_intermediate_points,
        reachable_point_indices=reachable_point_indices,
    )

    enautilus_state = ENautilusState(
        non_dominated_solutions_id=request.representative_solutions_id,
        current_iteration=request.current_iteration,
        iterations_left=request.iterations_left,
        selected_point=selected_point,
        reachable_point_indices=reachable_point_indices,
        number_of_intermediate_points=request.number_of_intermediate_points,
        enautilus_results=results,
    )

    # Persist the new state, linked to the problem, the interactive session and the parent state (if any).
    session_id = None if interactive_session is None else interactive_session.id
    parent_id = None if parent_state is None else parent_state.id
    state_db = StateDB.create(
        database_session=db_session,
        problem_id=problem_db.id,
        session_id=session_id,
        parent_id=parent_id,
        state=enautilus_state,
    )

    db_session.add(state_db)
    db_session.commit()
    db_session.refresh(state_db)

    return ENautilusStepResponse(
        state_id=state_db.id,
        current_iteration=results.current_iteration,
        iterations_left=results.iterations_left,
        intermediate_points=results.intermediate_points,
        reachable_best_bounds=results.reachable_best_bounds,
        reachable_worst_bounds=results.reachable_worst_bounds,
        closeness_measures=results.closeness_measures,
        reachable_point_indices=results.reachable_point_indices,
    )
121
+
122
+
123
@router.get("/get_state/{state_id}")
def get_state(
    state_id: int,
    db_session: Annotated[Session, Depends(get_session)],
) -> ENautilusStateResponse:
    """Fetch a previously stored state of the E-NAUTILUS method.

    The stored state is returned in two parts: the step request that would reproduce it,
    and the step response that was computed for it.

    Args:
        state_id (int): id of the `StateDB` to fetch.
        db_session (Session): the database session.

    Raises:
        HTTPException: 404 when no `StateDB` with the given id exists; 406 when the found
            state does not contain an `ENautilusState`.

    Returns:
        ENautilusStateResponse: the reconstructed request and response of the state.
    """
    state_db: StateDB | None = db_session.exec(select(StateDB).where(StateDB.id == state_id)).first()
    if state_db is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find 'StateDB' with id={state_id}"
        )

    enautilus_state = state_db.state
    if not isinstance(enautilus_state, ENautilusState):
        raise HTTPException(
            status_code=status.HTTP_406_NOT_ACCEPTABLE, detail="The requested state does not contain an ENautilusState."
        )

    stored_results: ENautilusResult = enautilus_state.enautilus_results

    # Rebuild the request from the stored inputs of the step.
    reconstructed_request = ENautilusStepRequest(
        problem_id=state_db.problem_id,
        session_id=state_db.session_id,
        parent_state_id=state_db.parent_id,
        representative_solutions_id=enautilus_state.non_dominated_solutions_id,
        current_iteration=enautilus_state.current_iteration,
        iterations_left=enautilus_state.iterations_left,
        selected_point=enautilus_state.selected_point,
        reachable_point_indices=enautilus_state.reachable_point_indices,
        number_of_intermediate_points=enautilus_state.number_of_intermediate_points,
    )

    # Rebuild the response from the stored outputs of the step.
    reconstructed_response = ENautilusStepResponse(
        state_id=state_db.id,
        current_iteration=stored_results.current_iteration,
        iterations_left=stored_results.iterations_left,
        intermediate_points=stored_results.intermediate_points,
        reachable_best_bounds=stored_results.reachable_best_bounds,
        reachable_worst_bounds=stored_results.reachable_worst_bounds,
        closeness_measures=stored_results.closeness_measures,
        reachable_point_indices=stored_results.reachable_point_indices,
    )

    return ENautilusStateResponse(request=reconstructed_request, response=reconstructed_response)
168
+
169
+
170
@router.get("/get_representative/{state_id}")
def get_representative(
    state_id: int, db_session: Annotated[Session, Depends(get_session)]
) -> ENautilusRepresentativeSolutionsResponse:
    """Find the representative solutions closest to the intermediate points of an E-NAUTILUS state.

    Use this endpoint to get actual solutions from the non-dominated representation used
    by E-NAUTILUS, typically after the method's last iteration (when no iterations are left).

    Args:
        state_id (int): id of the `StateDB` holding the intermediate points for which the
            representative solutions should be computed.
        db_session (Session): the database session.

    Raises:
        HTTPException: 404 when a `StateDB`, `RepresentativeNonDominatedSolutions`, or
            `ProblemDB` instance cannot be found. 406 when the substate of the referenced
            `StateDB` is not an instance of `ENautilusState`.

    Returns:
        ENautilusRepresentativeSolutionsResponse: the information on the representative solutions.
    """
    state_db: StateDB | None = db_session.exec(select(StateDB).where(StateDB.id == state_id)).first()
    if state_db is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find 'StateDB' with id={state_id}"
        )

    enautilus_state = state_db.state
    if not isinstance(enautilus_state, ENautilusState):
        raise HTTPException(
            status_code=status.HTTP_406_NOT_ACCEPTABLE, detail="The requested state does not contain an ENautilusState."
        )
    enautilus_result: ENautilusResult = enautilus_state.enautilus_results

    # Load the non-dominated representation the state was navigating.
    solutions_id = enautilus_state.non_dominated_solutions_id
    non_dom_solutions_db = db_session.exec(
        select(RepresentativeNonDominatedSolutions).where(RepresentativeNonDominatedSolutions.id == solutions_id)
    ).first()
    if non_dom_solutions_db is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Could not find 'RepresentativeNonDominatedSolutions' with id={solutions_id}",
        )
    non_dom_solutions = pl.DataFrame(non_dom_solutions_db.solution_data)

    # Load the problem the state belongs to.
    problem_id = state_db.problem_id
    problem_db = db_session.exec(select(ProblemDB).where(ProblemDB.id == problem_id)).first()
    if problem_db is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find 'ProblemDB' with id={problem_id}"
        )
    problem = Problem.from_problemdb(problem_db)

    return ENautilusRepresentativeSolutionsResponse(
        solutions=enautilus_get_representative_solutions(problem, enautilus_result, non_dom_solutions)
    )
@@ -0,0 +1,234 @@
1
+ """A structure for group decision making.
2
+
3
+ When preferences are sent through the websockets, they are validated.
4
+ Then, the preferences are saved into a database. When all group members have articulated their
5
+ preferences, the system begins optimization. The results are then saved into the database and the system notifies all
6
+ connected users that the solutions are ready to be fetched. If a user is not connected to the server, the server will
7
+ notify the user when they connect next time.
8
+
9
+ """
10
+
11
+ import asyncio
12
+ import logging
13
+ import sys
14
+ from datetime import UTC, datetime
15
+ from typing import Annotated
16
+
17
+ from fastapi import (
18
+ APIRouter,
19
+ Depends,
20
+ Query,
21
+ WebSocket,
22
+ WebSocketDisconnect,
23
+ )
24
+ from jose import ExpiredSignatureError, JWTError, jwt
25
+ from sqlmodel import Session, select
26
+
27
+ from desdeo.api import AuthConfig
28
+ from desdeo.api.db import get_session
29
+ from desdeo.api.models import (
30
+ Group,
31
+ User,
32
+ )
33
+ from desdeo.api.routers.gdm.gdm_base import GroupManager
34
+ from desdeo.api.routers.gdm.gdm_score_bands.gdm_score_bands_manager import GDMScoreBandsManager
35
+ from desdeo.api.routers.gdm.gnimbus.gnimbus_manager import GNIMBUSManager
36
+ from desdeo.api.routers.user_authentication import get_user
37
+
38
# NOTE(review): basicConfig() runs at import time and configures the process-wide root
# logger (it does nothing if the root logger already has handlers).
logging.basicConfig(
    level=logging.INFO,
    format="[%(filename)s:%(lineno)d] %(levelname)s: %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)

# Every group decision making endpoint of this module is mounted under "/gdm".
router = APIRouter(prefix="/gdm")
44
+
45
+
46
class ManagerManager:
    """A singleton class to manage group managers. Spawns them and deletes them.

    Managers are stored per group and per method name
    (``group_managers[group_id][method]``), so one group can hold, e.g., both a
    GNIMBUS and a GDM score bands manager at the same time. The registry is only
    read or modified while holding ``self.lock``.

    TODO: Also check on manager type! If a Group has a NIMBUSManager, but for
    example a RPMManager is requested, create it.
    """

    def __init__(self):
        """Class constructor."""
        # group_id -> {method name -> manager instance}
        self.group_managers: dict[int, dict[str, GroupManager]] = {}
        self.lock = asyncio.Lock()

    async def get_group_manager(
        self, group_id: int, method: str, db_session: Session
    ) -> GroupManager | GNIMBUSManager | GDMScoreBandsManager | None:
        """Return the correct group manager for the caller, creating it if needed.

        Args:
            group_id (int): The ID of the group of the mgr.
            method (str): The method of the group mgr ("gnimbus" or "gdm-score-bands").
            db_session (Session): the database session passed to a newly created manager.

        Returns:
            GroupManager | GNIMBUSManager | GDMScoreBandsManager | None: The existing or newly
                created manager, or None if ``method`` is not implemented. An unsupported
                method leaves the registry untouched (no empty group entry is created).
        """
        # Supported method names and the manager class implementing each of them.
        factories = {
            "gnimbus": GNIMBUSManager,
            "gdm-score-bands": GDMScoreBandsManager,
        }

        async with self.lock:
            managers = self.group_managers.get(group_id, {})
            if method in managers:
                return managers[method]

            factory = factories.get(method)
            if factory is None:
                return None

            new_manager = factory(group_id=group_id, db_session=db_session)
            # Only register the group once we actually have a manager to store for it.
            self.group_managers.setdefault(group_id, {})[method] = new_manager
            return new_manager

    async def check_disconnect(self, group_id: int, method: str):
        """Checks if a group manager has active connections. If no, delete it.

        Args:
            group_id (int): ID of the group
            method (str): method of the manager

        Returns:
            Nothing.
        """
        async with self.lock:
            managers = self.group_managers.get(group_id)
            if managers is None or method not in managers:
                return

            group_manager = managers[method]
            # Keep the manager while any of its websockets is still registered as connected.
            if any(socket is not None for socket in group_manager.sockets.values()):
                return

            # No active sockets: delete the manager. Its own lock is taken first, presumably
            # so that an operation currently holding it finishes before removal.
            async with group_manager.lock:
                del managers[method]
                # Drop the group entry once it has no managers left.
                if not managers:
                    del self.group_managers[group_id]


# Module-level registry shared by every websocket connection handled by this router.
manager = ManagerManager()
129
+
130
+
131
async def auth_user(token: str, session: Session, websocket: WebSocket) -> User | None:
    """Authenticate the user of a websocket connection from a JWT access token.

    On any authentication failure an error message is sent through the websocket, the
    websocket is closed, and ``None`` is returned, so callers must check the result.

    Args:
        token (str): the access token of the user.
        session (Session): the database session from where the user is received.
        websocket (WebSocket): the websocket that the user has connected with.

    Returns:
        User | None: the authenticated user, or ``None`` if authentication failed.
    """

    async def error_and_close() -> None:
        """Report the authentication failure to the client and close the websocket."""
        await websocket.send_text("Could not validate credencials. Try logging in again.")
        await websocket.close()

    try:
        payload = jwt.decode(token, AuthConfig.authjwt_secret_key, algorithms=[AuthConfig.authjwt_algorithm])
    except (ExpiredSignatureError, JWTError):
        return await error_and_close()

    username = payload.get("sub")
    # "exp" is a numeric timestamp (compared against .timestamp() below), not a datetime.
    expire_time: float | None = payload.get("exp")

    # Defensive re-check of the claims; jwt.decode normally rejects expired tokens itself.
    if username is None or expire_time is None or expire_time < datetime.now(UTC).timestamp():
        return await error_and_close()

    user = get_user(session, username=username)
    if user is None:
        return await error_and_close()

    return user
164
+
165
+
166
@router.websocket("/ws")
async def websocket_endpoint(
    session: Annotated[Session, Depends(get_session)],
    websocket: WebSocket,
    token: str = Query(),
    group_id: int = Query(),
    method: str = Query(),
):
    """The websocket endpoint to which the user connects.

    The access token, the group id and the method are all given as query parameters.
    The call to this endpoint looks like the following:

    ws://[DOMAIN]:[PORT]/gdm/ws?token=[TOKEN]&group_id=[GROUP_ID]&method=[METHOD]

    Text messages from group members are forwarded to the group manager's ``run_method``.
    Messages from a connected user who is not in ``group.user_ids`` (e.g. the group owner)
    are only logged. When the connection ends, the socket is unregistered from the group
    manager and the manager is deleted if it has no active connections left.

    Args:
        session (Session): the database session.
        websocket (WebSocket): the connecting websocket.
        token (str): the access token of the user.
        group_id (int): the id of the group the user wants to work in.
        method (str): the group method to use, e.g. "gnimbus" or "gdm-score-bands".

    See further details in the documentation. (Explanations -> GDM and websockets)
    """
    # Accept the websocket first so that error messages can be sent back to the client.
    await websocket.accept()

    user = await auth_user(token, session, websocket)
    if user is None:
        return

    group = session.exec(select(Group).where(Group.id == group_id)).first()
    if group is None:
        await websocket.send_text(f"There is no group with ID {group_id}.")
        await websocket.close()
        return

    # Compare ids by value: `is` on ints only works by accident for small cached integers.
    if not (user.id in group.user_ids or user.id == group.owner_id):
        await websocket.send_text(f"User {user.username} doesn't belong in group {group.name}")
        await websocket.close()
        return

    # Release the pooled connection held by this session. With SQLAlchemy's default
    # settings a Session can be used again after close(), so it can still be handed to a
    # newly created group manager below.
    session.close()

    # Get the group manager object from the manager of group managers.
    # `db_session` is a required argument of get_group_manager.
    group_manager = await manager.get_group_manager(group_id=group_id, method=method, db_session=session)
    if group_manager is None:
        await websocket.send_text(f"Unknown method: {method}")
        await websocket.close()
        return

    await group_manager.connect(user.id, websocket)
    logger.info(f"Group ID {group_id} manager's active connections {group_manager.sockets}")
    logger.info(f"Existing GroupManagers: {manager.group_managers}")
    try:
        while True:
            # Get data from socket
            data = await websocket.receive_text()
            # send data for preference setting
            if user.id in group.user_ids:
                await group_manager.run_method(user.id, data)
            else:
                logger.warning(
                    f"User {user.username} is not part of group {group.name}! They're likely the group owner."
                )
    except WebSocketDisconnect:
        pass
    except RuntimeError as e:
        logger.warning(f"RuntimeError: {e}")
    finally:
        # Always unregister the socket, whatever ended the loop, so that a dead socket
        # does not keep the group manager alive indefinitely.
        await group_manager.disconnect(user.id, websocket)
        await manager.check_disconnect(group_id=group_id, method=method)
        logger.info(f"Group ID {group_id} manager's active connections {group_manager.sockets}")
        logger.info(f"Existing GroupManagers: {manager.group_managers}")