desdeo 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/adm/ADMAfsar.py +551 -0
- desdeo/adm/ADMChen.py +414 -0
- desdeo/adm/BaseADM.py +119 -0
- desdeo/adm/__init__.py +11 -0
- desdeo/api/__init__.py +6 -6
- desdeo/api/app.py +38 -28
- desdeo/api/config.py +65 -44
- desdeo/api/config.toml +23 -12
- desdeo/api/db.py +10 -8
- desdeo/api/db_init.py +12 -6
- desdeo/api/models/__init__.py +220 -20
- desdeo/api/models/archive.py +16 -27
- desdeo/api/models/emo.py +128 -0
- desdeo/api/models/enautilus.py +69 -0
- desdeo/api/models/gdm/gdm_aggregate.py +139 -0
- desdeo/api/models/gdm/gdm_base.py +69 -0
- desdeo/api/models/gdm/gdm_score_bands.py +114 -0
- desdeo/api/models/gdm/gnimbus.py +138 -0
- desdeo/api/models/generic.py +104 -0
- desdeo/api/models/generic_states.py +401 -0
- desdeo/api/models/nimbus.py +158 -0
- desdeo/api/models/preference.py +44 -6
- desdeo/api/models/problem.py +274 -64
- desdeo/api/models/session.py +4 -1
- desdeo/api/models/state.py +419 -52
- desdeo/api/models/user.py +7 -6
- desdeo/api/models/utopia.py +25 -0
- desdeo/api/routers/_EMO.backup +309 -0
- desdeo/api/routers/_NIMBUS.py +6 -3
- desdeo/api/routers/emo.py +497 -0
- desdeo/api/routers/enautilus.py +237 -0
- desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
- desdeo/api/routers/gdm/gdm_base.py +420 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
- desdeo/api/routers/generic.py +233 -0
- desdeo/api/routers/nimbus.py +705 -0
- desdeo/api/routers/problem.py +201 -4
- desdeo/api/routers/reference_point_method.py +20 -44
- desdeo/api/routers/session.py +50 -26
- desdeo/api/routers/user_authentication.py +180 -26
- desdeo/api/routers/utils.py +187 -0
- desdeo/api/routers/utopia.py +230 -0
- desdeo/api/schema.py +10 -4
- desdeo/api/tests/conftest.py +94 -2
- desdeo/api/tests/test_enautilus.py +330 -0
- desdeo/api/tests/test_models.py +550 -72
- desdeo/api/tests/test_routes.py +902 -43
- desdeo/api/utils/_database.py +263 -0
- desdeo/api/utils/database.py +28 -266
- desdeo/api/utils/emo_database.py +40 -0
- desdeo/core.py +7 -0
- desdeo/emo/__init__.py +154 -24
- desdeo/emo/hooks/archivers.py +18 -2
- desdeo/emo/methods/EAs.py +128 -5
- desdeo/emo/methods/bases.py +9 -56
- desdeo/emo/methods/templates.py +111 -0
- desdeo/emo/operators/crossover.py +544 -42
- desdeo/emo/operators/evaluator.py +10 -14
- desdeo/emo/operators/generator.py +127 -24
- desdeo/emo/operators/mutation.py +212 -41
- desdeo/emo/operators/scalar_selection.py +202 -0
- desdeo/emo/operators/selection.py +956 -214
- desdeo/emo/operators/termination.py +124 -16
- desdeo/emo/options/__init__.py +108 -0
- desdeo/emo/options/algorithms.py +435 -0
- desdeo/emo/options/crossover.py +164 -0
- desdeo/emo/options/generator.py +131 -0
- desdeo/emo/options/mutation.py +260 -0
- desdeo/emo/options/repair.py +61 -0
- desdeo/emo/options/scalar_selection.py +66 -0
- desdeo/emo/options/selection.py +127 -0
- desdeo/emo/options/templates.py +383 -0
- desdeo/emo/options/termination.py +143 -0
- desdeo/gdm/__init__.py +22 -0
- desdeo/gdm/gdmtools.py +45 -0
- desdeo/gdm/score_bands.py +114 -0
- desdeo/gdm/voting_rules.py +50 -0
- desdeo/mcdm/__init__.py +23 -1
- desdeo/mcdm/enautilus.py +338 -0
- desdeo/mcdm/gnimbus.py +484 -0
- desdeo/mcdm/nautilus_navigator.py +7 -6
- desdeo/mcdm/reference_point_method.py +70 -0
- desdeo/problem/__init__.py +5 -1
- desdeo/problem/external/__init__.py +18 -0
- desdeo/problem/external/core.py +356 -0
- desdeo/problem/external/pymoo_provider.py +266 -0
- desdeo/problem/external/runtime.py +44 -0
- desdeo/problem/infix_parser.py +2 -2
- desdeo/problem/pyomo_evaluator.py +25 -6
- desdeo/problem/schema.py +69 -48
- desdeo/problem/simulator_evaluator.py +65 -15
- desdeo/problem/testproblems/__init__.py +26 -11
- desdeo/problem/testproblems/benchmarks_server.py +120 -0
- desdeo/problem/testproblems/cake_problem.py +185 -0
- desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
- desdeo/problem/testproblems/forest_problem.py +77 -69
- desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
- desdeo/problem/testproblems/{river_pollution_problem.py → river_pollution_problems.py} +28 -22
- desdeo/problem/testproblems/single_objective.py +289 -0
- desdeo/problem/testproblems/zdt_problem.py +4 -1
- desdeo/tools/__init__.py +39 -21
- desdeo/tools/desc_gen.py +22 -0
- desdeo/tools/generics.py +22 -2
- desdeo/tools/group_scalarization.py +3090 -0
- desdeo/tools/indicators_binary.py +107 -1
- desdeo/tools/indicators_unary.py +3 -16
- desdeo/tools/message.py +33 -2
- desdeo/tools/non_dominated_sorting.py +4 -3
- desdeo/tools/patterns.py +9 -7
- desdeo/tools/pyomo_solver_interfaces.py +48 -35
- desdeo/tools/reference_vectors.py +118 -351
- desdeo/tools/scalarization.py +340 -1413
- desdeo/tools/score_bands.py +491 -328
- desdeo/tools/utils.py +117 -49
- desdeo/tools/visualizations.py +67 -0
- desdeo/utopia_stuff/utopia_problem.py +1 -1
- desdeo/utopia_stuff/utopia_problem_old.py +1 -1
- {desdeo-2.0.0.dist-info → desdeo-2.1.0.dist-info}/METADATA +46 -28
- desdeo-2.1.0.dist-info/RECORD +180 -0
- {desdeo-2.0.0.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
- desdeo-2.0.0.dist-info/RECORD +0 -120
- /desdeo/api/utils/{logger.py → _logger.py} +0 -0
- {desdeo-2.0.0.dist-info → desdeo-2.1.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""Defines end-points to access functionalities related to the E-NAUTILUS method."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import polars as pl
|
|
7
|
+
from fastapi import APIRouter, Depends, HTTPException, status
|
|
8
|
+
from sqlmodel import Session, select
|
|
9
|
+
|
|
10
|
+
from desdeo.api.db import get_session
|
|
11
|
+
from desdeo.api.models import (
|
|
12
|
+
ENautilusRepresentativeSolutionsResponse,
|
|
13
|
+
ENautilusState,
|
|
14
|
+
ENautilusStateResponse,
|
|
15
|
+
ENautilusStepRequest,
|
|
16
|
+
ENautilusStepResponse,
|
|
17
|
+
ProblemDB,
|
|
18
|
+
RepresentativeNonDominatedSolutions,
|
|
19
|
+
StateDB,
|
|
20
|
+
)
|
|
21
|
+
from desdeo.mcdm import ENautilusResult, enautilus_get_representative_solutions, enautilus_step
|
|
22
|
+
from desdeo.problem import Problem
|
|
23
|
+
|
|
24
|
+
from .utils import (
|
|
25
|
+
SessionContext,
|
|
26
|
+
get_session_context,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
# All E-NAUTILUS endpoints below are registered on this router under the "/method/enautilus" prefix.
router = APIRouter(prefix="/method/enautilus")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@router.post("/step")
def step(
    request: ENautilusStepRequest,
    context: Annotated[SessionContext, Depends(get_session_context)],
) -> ENautilusStepResponse:
    """Steps the E-NAUTILUS method.

    Loads the representative non-dominated solutions referenced by the request, runs one
    E-NAUTILUS iteration on them with ``enautilus_step``, and stores the outcome as an
    ``ENautilusState`` inside a new ``StateDB`` row linked to the problem, interactive session
    and parent state found in the session context.

    Args:
        request (ENautilusStepRequest): the step request (representative solutions id,
            iteration counters, selected point, reachable point indices and number of
            intermediate points).
        context (SessionContext): the session context resolved by ``get_session_context``.

    Raises:
        HTTPException: 404 when no ``RepresentativeNonDominatedSolutions`` matches
            ``request.representative_solutions_id``.

    Returns:
        ENautilusStepResponse: the id of the stored state together with the iteration results.
    """
    db = context.db_session
    problem_row = context.problem_db
    problem = Problem.from_problemdb(problem_row)
    interactive = context.interactive_session
    parent = context.parent_state

    solutions_id = request.representative_solutions_id
    query = select(RepresentativeNonDominatedSolutions).where(RepresentativeNonDominatedSolutions.id == solutions_id)
    solutions_row = db.exec(query).first()
    if solutions_row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=(
                "Could not find the requested representative solutions for the problem with "
                f"id={solutions_id}."
            ),
        )

    data = solutions_row.solution_data
    if request.current_iteration != 0:
        # Later iterations continue from the point and the reachable subset sent by the client.
        point = request.selected_point
        indices = request.reachable_point_indices
    else:
        # First iteration: the nadir of the representation is the selected point and every
        # point is reachable. The nadir is read from the "_min" columns but must be expressed
        # in true values, hence the sign flip for maximized objectives.
        point = {}
        for objective in problem.objectives:
            sign = -1 if objective.maximize else 1
            point[f"{objective.symbol}"] = sign * np.max(data[f"{objective.symbol}_min"])
        indices = list(range(len(data[problem.objectives[0].symbol])))

    # One E-NAUTILUS iteration.
    result: ENautilusResult = enautilus_step(
        problem=problem,
        non_dominated_points=data,
        current_iteration=request.current_iteration,
        iterations_left=request.iterations_left,
        selected_point=point,
        number_of_intermediate_points=request.number_of_intermediate_points,
        reachable_point_indices=indices,
    )

    new_state = ENautilusState(
        non_dominated_solutions_id=solutions_id,
        current_iteration=request.current_iteration,
        iterations_left=request.iterations_left,
        selected_point=point,
        reachable_point_indices=indices,
        number_of_intermediate_points=request.number_of_intermediate_points,
        enautilus_results=result,
    )

    # Persist the new state and reload it so that its database id is available.
    state_row = StateDB.create(
        database_session=db,
        problem_id=problem_row.id,
        session_id=None if interactive is None else interactive.id,
        parent_id=None if parent is None else parent.id,
        state=new_state,
    )
    db.add(state_row)
    db.commit()
    db.refresh(state_row)

    return ENautilusStepResponse(
        state_id=state_row.id,
        current_iteration=result.current_iteration,
        iterations_left=result.iterations_left,
        intermediate_points=result.intermediate_points,
        reachable_best_bounds=result.reachable_best_bounds,
        reachable_worst_bounds=result.reachable_worst_bounds,
        closeness_measures=result.closeness_measures,
        reachable_point_indices=result.reachable_point_indices,
    )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@router.get("/get_state/{state_id}")
def get_state(
    state_id: int,
    db_session: Annotated[Session, Depends(get_session)],
) -> ENautilusStateResponse:
    """Fetch a previous state of the E-NAUTILUS method.

    Reconstructs the step request from the stored ``ENautilusState`` and the step response
    from the results stored in it.

    Args:
        state_id (int): id of the ``StateDB`` to fetch.
        db_session (Session): the database session.

    Raises:
        HTTPException: 404 when no ``StateDB`` with ``state_id`` exists; 406 when its
            substate is not an ``ENautilusState``.

    Returns:
        ENautilusStateResponse: the reconstructed request and response of that step.
    """
    found: StateDB | None = db_session.exec(select(StateDB).where(StateDB.id == state_id)).first()
    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find 'StateDB' with id={state_id}"
        )
    if not isinstance(found.state, ENautilusState):
        raise HTTPException(
            status_code=status.HTTP_406_NOT_ACCEPTABLE, detail="The requested state does not contain an ENautilusState."
        )

    saved: ENautilusState = found.state
    outcome: ENautilusResult = saved.enautilus_results

    return ENautilusStateResponse(
        request=ENautilusStepRequest(
            problem_id=found.problem_id,
            session_id=found.session_id,
            parent_state_id=found.parent_id,
            representative_solutions_id=saved.non_dominated_solutions_id,
            current_iteration=saved.current_iteration,
            iterations_left=saved.iterations_left,
            selected_point=saved.selected_point,
            reachable_point_indices=saved.reachable_point_indices,
            number_of_intermediate_points=saved.number_of_intermediate_points,
        ),
        response=ENautilusStepResponse(
            state_id=found.id,
            current_iteration=outcome.current_iteration,
            iterations_left=outcome.iterations_left,
            intermediate_points=outcome.intermediate_points,
            reachable_best_bounds=outcome.reachable_best_bounds,
            reachable_worst_bounds=outcome.reachable_worst_bounds,
            closeness_measures=outcome.closeness_measures,
            reachable_point_indices=outcome.reachable_point_indices,
        ),
    )
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@router.get("/get_representative/{state_id}")
def get_representative(
    state_id: int, db_session: Annotated[Session, Depends(get_session)]
) -> ENautilusRepresentativeSolutionsResponse:
    """Computes the representative solutions that are closest to the intermediate solutions computed by E-NAUTILUS.

    Use this endpoint to obtain actual solutions from the non-dominated representation that
    E-NAUTILUS worked on, typically after the last iteration (when no iterations are left).

    Args:
        state_id (int): id of the `StateDB` holding the intermediate points for which the
            representative solutions are computed.
        db_session (Session): the database session.

    Raises:
        HTTPException: 404 when a `StateDB`, `ProblemDB`, or
            `RepresentativeNonDominatedSolutions` instance cannot be found; 406 when the
            substate of the referenced `StateDB` is not an `ENautilusState`.

    Returns:
        ENautilusRepresentativeSolutionsResponse: the information on the representative solutions.
    """
    state_row: StateDB | None = db_session.exec(select(StateDB).where(StateDB.id == state_id)).first()
    if state_row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find 'StateDB' with id={state_id}"
        )
    if not isinstance(state_row.state, ENautilusState):
        raise HTTPException(
            status_code=status.HTTP_406_NOT_ACCEPTABLE, detail="The requested state does not contain an ENautilusState."
        )

    stored: ENautilusState = state_row.state
    stored_result: ENautilusResult = stored.enautilus_results
    solutions_id = stored.non_dominated_solutions_id

    solutions_row = db_session.exec(
        select(RepresentativeNonDominatedSolutions).where(RepresentativeNonDominatedSolutions.id == solutions_id)
    ).first()
    if solutions_row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Could not find 'RepresentativeNonDominatedSolutions' with id={solutions_id}",
        )

    solutions_frame = pl.DataFrame(solutions_row.solution_data)

    problem_row = db_session.exec(select(ProblemDB).where(ProblemDB.id == state_row.problem_id)).first()
    if problem_row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find 'ProblemDB' with id={state_row.problem_id}"
        )

    closest = enautilus_get_representative_solutions(
        Problem.from_problemdb(problem_row), stored_result, solutions_frame
    )
    return ENautilusRepresentativeSolutionsResponse(solutions=closest)
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""A structure for group decision making.
|
|
2
|
+
|
|
3
|
+
When preferences are sent through the websockets, they are validated.
|
|
4
|
+
Then, the preferences are saved into a database. When all group members have articulated their
|
|
5
|
+
preferences, system begins optimization. The results are then saved into the database and the system notifies all
|
|
6
|
+
connected users that the solutions are ready to be fetched. If a user is not connected to the server, the server will
|
|
7
|
+
notify the user when they connect next time.
|
|
8
|
+
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import logging
|
|
13
|
+
import sys
|
|
14
|
+
from datetime import UTC, datetime
|
|
15
|
+
from typing import Annotated
|
|
16
|
+
|
|
17
|
+
from fastapi import (
|
|
18
|
+
APIRouter,
|
|
19
|
+
Depends,
|
|
20
|
+
Query,
|
|
21
|
+
WebSocket,
|
|
22
|
+
WebSocketDisconnect,
|
|
23
|
+
)
|
|
24
|
+
from jose import ExpiredSignatureError, JWTError, jwt
|
|
25
|
+
from sqlmodel import Session, select
|
|
26
|
+
|
|
27
|
+
from desdeo.api import AuthConfig
|
|
28
|
+
from desdeo.api.db import get_session
|
|
29
|
+
from desdeo.api.models import (
|
|
30
|
+
Group,
|
|
31
|
+
User,
|
|
32
|
+
)
|
|
33
|
+
from desdeo.api.routers.gdm.gdm_base import GroupManager
|
|
34
|
+
from desdeo.api.routers.gdm.gdm_score_bands.gdm_score_bands_manager import GDMScoreBandsManager
|
|
35
|
+
from desdeo.api.routers.gdm.gnimbus.gnimbus_manager import GNIMBUSManager
|
|
36
|
+
from desdeo.api.routers.user_authentication import get_user
|
|
37
|
+
|
|
38
|
+
# NOTE(review): basicConfig configures the root logger at import time, which affects every
# logger in the process; consider moving this to the application entry point.
logging.basicConfig(
    format="[%(filename)s:%(lineno)d] %(levelname)s: %(message)s",
    level=logging.INFO,
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)

# Group decision making endpoints are registered under the "/gdm" prefix.
router = APIRouter(prefix="/gdm")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ManagerManager:
    """A registry of group managers. Spawns them and deletes them.

    Managers are stored per group id and per method name, so a group has at most one
    manager for each method at a time. Access to the registry is serialized with an
    ``asyncio.Lock``.

    TODO: Also check on manager type! If a Group has a NIMBUSManager, but for
    example a RPMManager is requested, create it.
    """

    # Maps the ``method`` query parameter to the manager class that implements it.
    _MANAGER_CLASSES: dict[str, type] = {
        "gnimbus": GNIMBUSManager,
        "gdm-score-bands": GDMScoreBandsManager,
    }

    def __init__(self):
        """Class constructor."""
        # group id -> method name -> manager
        self.group_managers: dict[int, dict[str, GroupManager]] = {}
        self.lock = asyncio.Lock()

    async def get_group_manager(
        self, group_id: int, method: str, db_session: Session
    ) -> GroupManager | GNIMBUSManager | GDMScoreBandsManager | None:
        """Return the correct group manager for the caller, creating it if needed.

        Args:
            group_id (int): The ID of the group of the mgr
            method (str): The method of the group mgr ("gnimbus" or "gdm-score-bands").
            db_session (Session): the database session passed to a newly created manager.

        Returns:
            GroupManager | GNIMBUSManager | GDMScoreBandsManager | None: The existing or newly
            created manager, or None if the method is not implemented.
        """
        async with self.lock:
            existing = self.group_managers.get(group_id, {})
            if method in existing:
                return existing[method]

            manager_class = self._MANAGER_CLASSES.get(method)
            if manager_class is None:
                # Unknown method: do not leave an empty, never-cleaned entry for the group.
                return None

            group_manager = manager_class(group_id=group_id, db_session=db_session)
            self.group_managers.setdefault(group_id, {})[method] = group_manager
            return group_manager

    async def check_disconnect(self, group_id: int, method: str):
        """Delete a group manager if it has no active connections left.

        When the last manager of a group is deleted, the group's entry is removed as well.

        Args:
            group_id (int): ID of the group
            method (str): method of the manager

        Returns:
            Nothing.
        """
        async with self.lock:
            managers = self.group_managers.get(group_id)
            if managers is None or method not in managers:
                return

            group_manager = managers[method]
            # A non-None socket entry means some user is still connected.
            if any(socket is not None for socket in group_manager.sockets.values()):
                return

            # No active sockets, delete the manager (while it is not in use).
            async with group_manager.lock:
                del managers[method]

            # If group has no managers, delete the group entry.
            if not managers:
                del self.group_managers[group_id]
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# Module-level registry of group managers, shared by every websocket connection handled by this router.
manager = ManagerManager()
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
async def auth_user(token: str, session: Session, websocket: WebSocket) -> User | None:
    """Authenticate the user of a websocket connection from a JWT access token.

    On any authentication failure an error message is sent over ``websocket``, the socket is
    closed, and ``None`` is returned, so callers must check the result.

    Args:
        token (str): the access token of the user.
        session (Session): the database session from where the user is received.
        websocket (WebSocket): the websocket that the user has connected with.

    Returns:
        User | None: the authenticated user, or ``None`` if authentication failed.
    """

    async def error_and_close() -> None:
        await websocket.send_text("Could not validate credencials. Try logging in again.")
        await websocket.close()

    try:
        payload = jwt.decode(token, AuthConfig.authjwt_secret_key, algorithms=[AuthConfig.authjwt_algorithm])
    except (ExpiredSignatureError, JWTError):
        await error_and_close()
        return None

    username = payload.get("sub")
    # "exp" is a numeric timestamp (it is compared against ``datetime.timestamp()``), not a datetime.
    expire_time: float | int | None = payload.get("exp")

    if username is None or expire_time is None or expire_time < datetime.now(UTC).timestamp():
        await error_and_close()
        return None

    user = get_user(session, username=username)

    if user is None:
        await error_and_close()
        return None

    return user
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@router.websocket("/ws")
async def websocket_endpoint(
    session: Annotated[Session, Depends(get_session)],
    websocket: WebSocket,
    token: str = Query(),
    group_id: int = Query(),
    method: str = Query(),
):
    """The websocket endpoint to which the user connects.

    The access token, the group id and the method are all given as query parameters.
    The call to this endpoint looks like the following:

        ws://[DOMAIN]:[PORT]/gdm/ws?token=[TOKEN]&group_id=[GROUP_ID]&method=[METHOD]

    After authentication and a group-membership check, the connection is registered with the
    group manager for ``method``, and every text message received from a group member is passed
    to ``run_method``. When the receive loop ends, the connection is always unregistered and the
    manager is deleted if it has no active sockets left.

    See further details in the documentation. (Explanations -> GDM and websockets)
    """
    # Accept the websocket (to send back stuff if something goes wrong)
    await websocket.accept()

    user = await auth_user(token, session, websocket)
    if user is None:
        return

    group = session.exec(select(Group).where(Group.id == group_id)).first()
    if group is None:
        await websocket.send_text(f"There is no group with ID {group_id}.")
        await websocket.close()
        return

    # Compare ids by value: `is` on ints only works by accident for small cached integers.
    if not (user.id in group.user_ids or user.id == group.owner_id):
        await websocket.send_text(f"User {user.username} doesn't belong in group {group.name}")
        await websocket.close()
        return

    # Release the session's connection back to the pool while the socket sits idle. A
    # SQLAlchemy Session can be used again after close(), so it is still handed to the group
    # manager below, which requires a ``db_session``.
    session.close()

    # Get the group manager object from the manager of group managers
    group_manager = await manager.get_group_manager(group_id=group_id, method=method, db_session=session)
    if group_manager is None:
        await websocket.send_text(f"Unknown method: {method}")
        await websocket.close()
        return

    await group_manager.connect(user.id, websocket)
    logger.info(f"Group ID {group_id} manager's active connections {group_manager.sockets}")
    logger.info(f"Existing GroupManagers: {manager.group_managers}")
    try:
        while True:
            # Get data from socket
            data = await websocket.receive_text()
            # send data for preference setting
            if user.id in group.user_ids:
                await group_manager.run_method(user.id, data)
            else:
                logger.warning(
                    f"User {user.username} is not part of group {group.name}! They're likely the group owner."
                )
    except WebSocketDisconnect:
        pass  # Normal client disconnect; cleanup happens in `finally`.
    except RuntimeError as e:
        logger.warning(f"RuntimeError: {e}")
    finally:
        # Always unregister the socket so the manager does not keep a dead connection around.
        await group_manager.disconnect(user.id, websocket)
        await manager.check_disconnect(group_id=group_id, method=method)
        logger.info(f"Group ID {group_id} manager's active connections {group_manager.sockets}")
        logger.info(f"Existing GroupManagers: {manager.group_managers}")
|