desdeo 1.2__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/__init__.py +8 -8
- desdeo/adm/ADMAfsar.py +551 -0
- desdeo/adm/ADMChen.py +414 -0
- desdeo/adm/BaseADM.py +119 -0
- desdeo/adm/__init__.py +11 -0
- desdeo/api/README.md +73 -0
- desdeo/api/__init__.py +15 -0
- desdeo/api/app.py +50 -0
- desdeo/api/config.py +90 -0
- desdeo/api/config.toml +64 -0
- desdeo/api/db.py +27 -0
- desdeo/api/db_init.py +85 -0
- desdeo/api/db_models.py +164 -0
- desdeo/api/malaga_db_init.py +27 -0
- desdeo/api/models/__init__.py +266 -0
- desdeo/api/models/archive.py +23 -0
- desdeo/api/models/emo.py +128 -0
- desdeo/api/models/enautilus.py +69 -0
- desdeo/api/models/gdm/gdm_aggregate.py +139 -0
- desdeo/api/models/gdm/gdm_base.py +69 -0
- desdeo/api/models/gdm/gdm_score_bands.py +114 -0
- desdeo/api/models/gdm/gnimbus.py +138 -0
- desdeo/api/models/generic.py +104 -0
- desdeo/api/models/generic_states.py +401 -0
- desdeo/api/models/nimbus.py +158 -0
- desdeo/api/models/preference.py +128 -0
- desdeo/api/models/problem.py +717 -0
- desdeo/api/models/reference_point_method.py +18 -0
- desdeo/api/models/session.py +49 -0
- desdeo/api/models/state.py +463 -0
- desdeo/api/models/user.py +52 -0
- desdeo/api/models/utopia.py +25 -0
- desdeo/api/routers/_EMO.backup +309 -0
- desdeo/api/routers/_NAUTILUS.py +245 -0
- desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
- desdeo/api/routers/_NIMBUS.py +765 -0
- desdeo/api/routers/__init__.py +5 -0
- desdeo/api/routers/emo.py +497 -0
- desdeo/api/routers/enautilus.py +237 -0
- desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
- desdeo/api/routers/gdm/gdm_base.py +420 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
- desdeo/api/routers/generic.py +233 -0
- desdeo/api/routers/nimbus.py +705 -0
- desdeo/api/routers/problem.py +307 -0
- desdeo/api/routers/reference_point_method.py +93 -0
- desdeo/api/routers/session.py +100 -0
- desdeo/api/routers/test.py +16 -0
- desdeo/api/routers/user_authentication.py +520 -0
- desdeo/api/routers/utils.py +187 -0
- desdeo/api/routers/utopia.py +230 -0
- desdeo/api/schema.py +100 -0
- desdeo/api/tests/__init__.py +0 -0
- desdeo/api/tests/conftest.py +151 -0
- desdeo/api/tests/test_enautilus.py +330 -0
- desdeo/api/tests/test_models.py +1179 -0
- desdeo/api/tests/test_routes.py +1075 -0
- desdeo/api/utils/_database.py +263 -0
- desdeo/api/utils/_logger.py +29 -0
- desdeo/api/utils/database.py +36 -0
- desdeo/api/utils/emo_database.py +40 -0
- desdeo/core.py +34 -0
- desdeo/emo/__init__.py +159 -0
- desdeo/emo/hooks/archivers.py +188 -0
- desdeo/emo/methods/EAs.py +541 -0
- desdeo/emo/methods/__init__.py +0 -0
- desdeo/emo/methods/bases.py +12 -0
- desdeo/emo/methods/templates.py +111 -0
- desdeo/emo/operators/__init__.py +1 -0
- desdeo/emo/operators/crossover.py +1282 -0
- desdeo/emo/operators/evaluator.py +114 -0
- desdeo/emo/operators/generator.py +459 -0
- desdeo/emo/operators/mutation.py +1224 -0
- desdeo/emo/operators/scalar_selection.py +202 -0
- desdeo/emo/operators/selection.py +1778 -0
- desdeo/emo/operators/termination.py +286 -0
- desdeo/emo/options/__init__.py +108 -0
- desdeo/emo/options/algorithms.py +435 -0
- desdeo/emo/options/crossover.py +164 -0
- desdeo/emo/options/generator.py +131 -0
- desdeo/emo/options/mutation.py +260 -0
- desdeo/emo/options/repair.py +61 -0
- desdeo/emo/options/scalar_selection.py +66 -0
- desdeo/emo/options/selection.py +127 -0
- desdeo/emo/options/templates.py +383 -0
- desdeo/emo/options/termination.py +143 -0
- desdeo/explanations/__init__.py +6 -0
- desdeo/explanations/explainer.py +100 -0
- desdeo/explanations/utils.py +90 -0
- desdeo/gdm/__init__.py +22 -0
- desdeo/gdm/gdmtools.py +45 -0
- desdeo/gdm/score_bands.py +114 -0
- desdeo/gdm/voting_rules.py +50 -0
- desdeo/mcdm/__init__.py +41 -0
- desdeo/mcdm/enautilus.py +338 -0
- desdeo/mcdm/gnimbus.py +484 -0
- desdeo/mcdm/nautili.py +345 -0
- desdeo/mcdm/nautilus.py +477 -0
- desdeo/mcdm/nautilus_navigator.py +656 -0
- desdeo/mcdm/nimbus.py +417 -0
- desdeo/mcdm/pareto_navigator.py +269 -0
- desdeo/mcdm/reference_point_method.py +186 -0
- desdeo/problem/__init__.py +83 -0
- desdeo/problem/evaluator.py +561 -0
- desdeo/problem/external/__init__.py +18 -0
- desdeo/problem/external/core.py +356 -0
- desdeo/problem/external/pymoo_provider.py +266 -0
- desdeo/problem/external/runtime.py +44 -0
- desdeo/problem/gurobipy_evaluator.py +562 -0
- desdeo/problem/infix_parser.py +341 -0
- desdeo/problem/json_parser.py +944 -0
- desdeo/problem/pyomo_evaluator.py +487 -0
- desdeo/problem/schema.py +1829 -0
- desdeo/problem/simulator_evaluator.py +348 -0
- desdeo/problem/sympy_evaluator.py +244 -0
- desdeo/problem/testproblems/__init__.py +88 -0
- desdeo/problem/testproblems/benchmarks_server.py +120 -0
- desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
- desdeo/problem/testproblems/cake_problem.py +185 -0
- desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
- desdeo/problem/testproblems/dtlz2_problem.py +102 -0
- desdeo/problem/testproblems/forest_problem.py +283 -0
- desdeo/problem/testproblems/knapsack_problem.py +163 -0
- desdeo/problem/testproblems/mcwb_problem.py +831 -0
- desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
- desdeo/problem/testproblems/momip_problem.py +172 -0
- desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
- desdeo/problem/testproblems/nimbus_problem.py +143 -0
- desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
- desdeo/problem/testproblems/re_problem.py +492 -0
- desdeo/problem/testproblems/river_pollution_problems.py +440 -0
- desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
- desdeo/problem/testproblems/simple_problem.py +351 -0
- desdeo/problem/testproblems/simulator_problem.py +92 -0
- desdeo/problem/testproblems/single_objective.py +289 -0
- desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
- desdeo/problem/testproblems/zdt_problem.py +274 -0
- desdeo/problem/utils.py +245 -0
- desdeo/tools/GenerateReferencePoints.py +181 -0
- desdeo/tools/__init__.py +120 -0
- desdeo/tools/desc_gen.py +22 -0
- desdeo/tools/generics.py +165 -0
- desdeo/tools/group_scalarization.py +3090 -0
- desdeo/tools/gurobipy_solver_interfaces.py +258 -0
- desdeo/tools/indicators_binary.py +117 -0
- desdeo/tools/indicators_unary.py +362 -0
- desdeo/tools/interaction_schema.py +38 -0
- desdeo/tools/intersection.py +54 -0
- desdeo/tools/iterative_pareto_representer.py +99 -0
- desdeo/tools/message.py +265 -0
- desdeo/tools/ng_solver_interfaces.py +199 -0
- desdeo/tools/non_dominated_sorting.py +134 -0
- desdeo/tools/patterns.py +283 -0
- desdeo/tools/proximal_solver.py +99 -0
- desdeo/tools/pyomo_solver_interfaces.py +477 -0
- desdeo/tools/reference_vectors.py +229 -0
- desdeo/tools/scalarization.py +2065 -0
- desdeo/tools/scipy_solver_interfaces.py +454 -0
- desdeo/tools/score_bands.py +627 -0
- desdeo/tools/utils.py +388 -0
- desdeo/tools/visualizations.py +67 -0
- desdeo/utopia_stuff/__init__.py +0 -0
- desdeo/utopia_stuff/data/1.json +15 -0
- desdeo/utopia_stuff/data/2.json +13 -0
- desdeo/utopia_stuff/data/3.json +15 -0
- desdeo/utopia_stuff/data/4.json +17 -0
- desdeo/utopia_stuff/data/5.json +15 -0
- desdeo/utopia_stuff/from_json.py +40 -0
- desdeo/utopia_stuff/reinit_user.py +38 -0
- desdeo/utopia_stuff/utopia_db_init.py +212 -0
- desdeo/utopia_stuff/utopia_problem.py +403 -0
- desdeo/utopia_stuff/utopia_problem_old.py +415 -0
- desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
- desdeo-2.1.0.dist-info/METADATA +186 -0
- desdeo-2.1.0.dist-info/RECORD +180 -0
- {desdeo-1.2.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
- desdeo-2.1.0.dist-info/licenses/LICENSE +21 -0
- desdeo-1.2.dist-info/METADATA +0 -16
- desdeo-1.2.dist-info/RECORD +0 -4
|
@@ -0,0 +1,420 @@
|
|
|
1
|
+
"""A base group manager structure for group decision making."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import sys
|
|
6
|
+
from typing import Annotated
|
|
7
|
+
|
|
8
|
+
from fastapi import (
|
|
9
|
+
APIRouter,
|
|
10
|
+
Depends,
|
|
11
|
+
HTTPException,
|
|
12
|
+
WebSocket,
|
|
13
|
+
WebSocketDisconnect,
|
|
14
|
+
status,
|
|
15
|
+
)
|
|
16
|
+
from fastapi.responses import JSONResponse
|
|
17
|
+
from sqlmodel import Session, select
|
|
18
|
+
|
|
19
|
+
from desdeo.api.db import get_session
|
|
20
|
+
from desdeo.api.models import (
|
|
21
|
+
Group,
|
|
22
|
+
GroupCreateRequest,
|
|
23
|
+
GroupInfoRequest,
|
|
24
|
+
GroupIteration,
|
|
25
|
+
GroupModifyRequest,
|
|
26
|
+
GroupPublic,
|
|
27
|
+
ProblemDB,
|
|
28
|
+
User,
|
|
29
|
+
)
|
|
30
|
+
from desdeo.api.routers.user_authentication import get_current_user
|
|
31
|
+
|
|
32
|
+
# Route log records to stdout, tagged with the originating file and line.
logging.basicConfig(
    level=logging.INFO,
    format="[%(filename)s:%(lineno)d] %(levelname)s: %(message)s",
    stream=sys.stdout,
)

# All group-management endpoints below are mounted under /gdm.
router = APIRouter(tags=["GDM"], prefix="/gdm")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ManagerError(Exception):
    """Raised when a group manager cannot be set up or operated correctly."""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class GroupManager:
    """A group manager. Manages connections, disconnections, optimization and communication to users.

    One instance is bound to a single group. It keeps a mapping from user ID to the
    websocket currently attached for that user, or ``None`` when the user is not connected.
    """

    def __init__(self, group_id: int, db_session: Session):
        """Initializes the instance of the group manager.

        Args:
            group_id (int): ID of the group this manager is responsible for.
            db_session (Session): database session used to look up the group. It is
                closed before this method returns or raises.

        Raises:
            ManagerError: if there is no group with the given ID.
        """
        self.lock = asyncio.Lock()
        # user ID -> attached websocket, or None while the user is not connected.
        self.sockets: dict[int, WebSocket | None] = {}
        self.group_id: int = group_id

        try:
            # Make sure the group exists
            group = db_session.exec(select(Group).where(Group.id == group_id)).first()
            if group is None:
                raise ManagerError(f"No group with ID {group_id} found!")

            # Pre-populate the socket dict with every member (at the very least to avoid KeyErrors)
            for user_id in group.user_ids:
                self.sockets[user_id] = None
        finally:
            db_session.close()

    async def send_message(self, message: str, websocket: WebSocket):
        """Send a text message to a single websocket.

        A ``WebSocketDisconnect`` raised while sending is swallowed so that a client
        that has gone away does not interrupt the caller.

        Args:
            message (str): the text to send.
            websocket (WebSocket): the recipient's websocket.
        """
        try:
            await websocket.send_text(message)
        except WebSocketDisconnect:
            return

    async def connect(self, user_id: int, websocket: WebSocket, db_session: Session):
        """Connect to websocket.

        The connection has been accepted beforehand for sending error messages
        back to user, but here we attach it to the manager instance.

        After attaching, if the iteration preceding the group's head iteration has not
        yet been flagged as notified for this user, the user is asked to fetch results
        and the flag is persisted. Failures in that step are logged and otherwise
        ignored, so the websocket stays attached either way.

        Args:
            user_id (int): ID of the connecting user.
            websocket (WebSocket): the already accepted websocket of the user.
            db_session (Session): database session; always closed before returning.
        """
        self.sockets[user_id] = websocket

        try:
            group = db_session.exec(select(Group).where(Group.id == self.group_id)).first()
            if group is None:
                return
            head_iter = db_session.exec(
                select(GroupIteration).where(GroupIteration.id == group.head_iteration_id)
            ).first()
            if head_iter is None:
                return
            prev_iter = head_iter.parent
            if prev_iter is None:
                return

            # `notified` is read with stringified user IDs, so write back under the same key type;
            # otherwise the flag lands under a different key than the one that is checked.
            key = str(user_id)
            # A user missing from the dict was not part of that iteration: nothing pending.
            if prev_iter.notified.get(key, True):
                return

            await self.send_message("Please fetch results.", websocket)
            # Assign a modified copy rather than mutating in place, presumably so the ORM
            # detects the change to the column. TODO confirm the column type.
            notified = prev_iter.notified.copy()
            notified[key] = True
            prev_iter.notified = notified
            db_session.add(prev_iter)
            db_session.commit()
        except Exception:
            # Best effort: a failed pending notification must not break the connection.
            logging.exception(
                "Could not deliver pending notification to user %s in group %s.", user_id, self.group_id
            )
        finally:
            db_session.close()

    async def disconnect(self, user_id: int, websocket: WebSocket):
        """Disconnect from websocket.

        The connection has been closed beforehand, but here we detach the WebSocket
        object from the manager instance. Nothing happens if a different websocket is
        registered for the user (e.g. the user has already reconnected) or if the user
        is unknown to this manager.

        Args:
            user_id (int): ID of the disconnecting user.
            websocket (WebSocket): the websocket that was closed.
        """
        if self.sockets.get(user_id) is websocket:
            self.sockets[user_id] = None

    async def broadcast(self, message: str):
        """Send message to all connected websockets.

        Args:
            message (str): the text to send to every attached websocket.
        """
        # Iterate over a snapshot: `connect` may add entries while this coroutine awaits a
        # send, and a dict that changes size during iteration raises RuntimeError.
        for socket in list(self.sockets.values()):
            if socket is None:
                continue
            try:
                await socket.send_text(message)
            except WebSocketDisconnect:
                continue

    async def notify(
        self,
        user_ids: list[int],
        message: str,
    ) -> dict[int, bool]:
        """Send [message] to the given users.

        Args:
            user_ids (list[int]): IDs of the users to notify.
            message (str): the text to send.

        Returns:
            dict[int, bool]: for each requested user ID, True if a websocket was attached
                and a send was attempted, False if the user is unknown or not connected.
        """
        notified: dict[int, bool] = {}
        for user_id in user_ids:
            socket = self.sockets.get(user_id)
            if socket is None:
                notified[user_id] = False
                continue
            await self.send_message(message, socket)
            notified[user_id] = True
        return notified

    async def run_method(
        self,
        user_id: int,
        data: str,
    ):
        """The function to run the method.

        One could derive different managers from this GroupManager
        class and implement method and manager-specific "run_method" functions.
        This base implementation does nothing.

        Args:
            user_id (int): ID of the user that sent the data.
            data (str): the raw data received from the user.
        """
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@router.post("/create_group")
def create_group(
    request: GroupCreateRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> JSONResponse:
    """Create group.

    The current user becomes the owner of a new, initially empty group, and the new
    group's ID is appended to the owner's ``group_ids``. Both changes are committed in
    one transaction, so a failure cannot leave a group that its owner does not reference.

    Args:
        request (GroupCreateRequest): a request that holds information to be used in creation of the group.
        user (Annotated[User, Depends(get_current_user)]): the current user.
        session (Annotated[Session, Depends(get_session)]): the database session.

    Returns:
        JSONResponse: Acknowledgement (HTTP 201) that the group was created

    Raises:
        HTTPException: 404 if there is no problem with the requested problem ID.
    """
    problem = session.exec(select(ProblemDB).where(ProblemDB.id == request.problem_id)).first()
    if problem is None:
        raise HTTPException(
            detail=f"There's no problem with ID {request.problem_id}!", status_code=status.HTTP_404_NOT_FOUND
        )

    group = Group(owner_id=user.id, user_ids=[], problem_id=request.problem_id, name=request.group_name)
    session.add(group)
    # Flush (not commit) so the database assigns group.id while staying in one transaction.
    session.flush()

    # Assign a fresh list instead of appending in place, presumably so the ORM detects the change.
    group_ids = list(user.group_ids) if user.group_ids is not None else []
    group_ids.append(group.id)
    user.group_ids = group_ids
    session.add(user)

    session.commit()
    session.refresh(group)

    return JSONResponse(content={"message": f"Group with ID {group.id} created."}, status_code=status.HTTP_201_CREATED)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
@router.post("/delete_group")
def delete_group(
    request: GroupInfoRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> JSONResponse:
    """Delete the group with given ID.

    The group's ID is removed from the ``group_ids`` of every member and of the owner.
    If the group has iterations, its root iteration is deleted (the remaining iterations
    go with it via cascades), and finally the group itself is deleted.

    Args:
        request (GroupInfoRequest): Contains the ID of the group to be deleted
        user (Annotated[User, Depends(get_current_user)]): The user (in this case must be owner for anything to happen)
        session (Annotated[Session, Depends(get_session)]): The database session

    Returns:
        JSONResponse: Acknowledgement of the deletion

    Raises:
        HTTPException: 404 if the group does not exist, 401 if the user is not the owner,
            500 if the group is still present after deletion.
    """
    group: Group = session.exec(select(Group).where(Group.id == request.group_id)).first()
    if group is None:
        raise HTTPException(detail=f"No group with ID {request.group_id} found.", status_code=status.HTTP_404_NOT_FOUND)

    if user.id != group.owner_id:
        raise HTTPException(
            detail="Only the owner of a group may delete the group.", status_code=status.HTTP_401_UNAUTHORIZED
        )

    group_id = group.id

    # Remove the group from its members and from the owner. Every occurrence is dropped
    # (the owner may hold the ID both as owner and as member), and missing users or
    # missing entries are tolerated so that stale data cannot block the deletion.
    for uid in group.user_ids:
        group_user = session.exec(select(User).where(User.id == uid)).first()
        if group_user is None:
            continue
        group_user.group_ids = [gid for gid in (group_user.group_ids or []) if gid != group_id]
        session.add(group_user)

    user.group_ids = [gid for gid in (user.group_ids or []) if gid != group_id]
    session.add(user)
    session.commit()
    session.refresh(user)

    # Walk up to the root iteration; iter_count is the number of parent links followed.
    # TODO: Adapt this to the new cascade with multiple children
    head: GroupIteration | None = session.exec(
        select(GroupIteration).where(GroupIteration.id == group.head_iteration_id)
    ).first()
    iter_count = 0
    if head is not None:
        while head.parent is not None:
            head = head.parent
            iter_count += 1

        # Delete the root iteration first; this deletes the rest of the iterations due to cascades.
        # Only done when iterations exist: session.delete(None) would raise.
        session.delete(head)
        session.commit()

    # Then delete the group
    session.delete(group)
    session.commit()

    # Make sure that the group IS deleted!
    group = session.exec(select(Group).where(Group.id == request.group_id)).first()
    if group is not None:
        raise HTTPException(
            detail="Couldn't delete group from the database!", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )

    return JSONResponse(
        content={"message": f"Group with ID {request.group_id} and its {iter_count} iterations have been deleted."},
        status_code=status.HTTP_200_OK,
    )
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
@router.post("/add_to_group")
def add_to_group(
    request: GroupModifyRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> JSONResponse:
    """Add a user to a group.

    Only the group's owner may add users. The user ID is appended to the group's
    ``user_ids`` and the group ID to the user's ``group_ids``; both sides are committed
    together so they cannot end up disagreeing about membership.

    Args:
        request (GroupModifyRequest): Request object that has group and user IDs.
        user (Annotated[User, Depends(get_current_user)]): the current user.
        session (Annotated[Session, Depends(get_session)]): the database session.

    Returns:
        JSONResponse: Acknowledge that user has been added to the group

    Raises:
        HTTPException: 404 if the group or the user does not exist, 401 if the current
            user is not the owner, 400 if the user is already in the group.
    """
    group: Group = session.exec(select(Group).where(Group.id == request.group_id)).first()
    # Make sure the group exists
    if group is None:
        raise HTTPException(
            detail=f"There's no group with ID {request.group_id}", status_code=status.HTTP_404_NOT_FOUND
        )
    # Make sure of proper authorization
    if group.owner_id != user.id:
        raise HTTPException(detail="Unauthorized user", status_code=status.HTTP_401_UNAUTHORIZED)

    if request.user_id in group.user_ids:
        raise HTTPException(
            detail=f"User with ID {request.user_id} already in this group!", status_code=status.HTTP_400_BAD_REQUEST
        )

    addee = session.exec(select(User).where(User.id == request.user_id)).first()
    # Make sure the user to be added exists
    if addee is None:
        raise HTTPException(
            detail=f"There is no user with ID {request.user_id}!", status_code=status.HTTP_404_NOT_FOUND
        )

    # Fresh lists are assigned (rather than appending in place), presumably so the ORM detects the change.
    group.user_ids = [*group.user_ids, request.user_id]
    addee.group_ids = [*(addee.group_ids or []), group.id]
    session.add(group)
    session.add(addee)
    session.commit()
    session.refresh(group)
    session.refresh(addee)

    return JSONResponse(
        content={"message": f"Added user {request.user_id} to group {group.id}."}, status_code=status.HTTP_200_OK
    )
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
@router.post("/remove_from_group")
def remove_from_group(
    request: GroupModifyRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> JSONResponse:
    """Remove user from group.

    Either the group's owner or the user being removed may do this. The user ID is
    removed from the group's ``user_ids`` and one occurrence of the group ID from the
    user's ``group_ids``; both sides are committed together.

    Args:
        request (GroupModifyRequest): Request object that has group and user IDs.
        user (Annotated[User, Depends(get_current_user)]): the current user.
        session (Annotated[Session, Depends(get_session)]): the database session.

    Returns:
        JSONResponse: Acknowledge that user has been removed from the group.

    Raises:
        HTTPException: 404 if the group does not exist, 401 if unauthorized, 400 if the
            user is not in the group, 500 if the user is still listed after removal.
    """
    group: Group = session.exec(select(Group).where(Group.id == request.group_id)).first()
    # Make sure the group exists
    if group is None:
        raise HTTPException(detail=f"No group with ID {request.group_id} found.", status_code=status.HTTP_404_NOT_FOUND)
    # Make sure of proper authorization: the owner, or the user removing themselves
    authorized = user.id in (group.owner_id, request.user_id)

    if not authorized:
        raise HTTPException(detail="Unauthorized user", status_code=status.HTTP_401_UNAUTHORIZED)

    if request.user_id not in group.user_ids:
        raise HTTPException(
            detail=f"User with ID {request.user_id} is not in this group!", status_code=status.HTTP_400_BAD_REQUEST
        )

    group_id = group.id

    user_ids = list(group.user_ids)
    user_ids.remove(request.user_id)
    group.user_ids = user_ids
    session.add(group)

    removed_user = session.exec(select(User).where(User.id == request.user_id)).first()
    # Tolerate a missing user or a stale group_ids list so stale data cannot make the removal fail.
    # Only one occurrence is removed: an owner may hold the ID both as owner and as member.
    if removed_user is not None:
        ugids = list(removed_user.group_ids or [])
        if group_id in ugids:
            ugids.remove(group_id)
        removed_user.group_ids = ugids
        session.add(removed_user)

    session.commit()
    session.refresh(group)

    if request.user_id in group.user_ids:
        raise HTTPException(
            detail=f"Could not remove User {request.user_id} from group {request.group_id}.",
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    return JSONResponse(
        content={"message": f"User {request.user_id} removed from group {request.group_id}."},
        status_code=status.HTTP_200_OK,
    )
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
@router.post("/get_group_info")
def get_group_info(
    request: GroupInfoRequest,
    session: Annotated[Session, Depends(get_session)],
) -> GroupPublic:
    """Return the public information of a single group.

    Args:
        request (GroupInfoRequest): carries the ID of the group to look up
        session (Annotated[Session, Depends(get_session)]): the database session

    Returns:
        GroupPublic: the group, serialized through its public model

    Raises:
        HTTPException: 404 when no group has the requested ID
    """
    statement = select(Group).where(Group.id == request.group_id)
    found = session.exec(statement).first()

    if found is not None:
        return found

    raise HTTPException(
        detail=f"No group with ID {request.group_id} found!",
        status_code=status.HTTP_404_NOT_FOUND,
    )
|