desdeo 1.1.3__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/__init__.py +8 -8
- desdeo/api/README.md +73 -0
- desdeo/api/__init__.py +15 -0
- desdeo/api/app.py +40 -0
- desdeo/api/config.py +69 -0
- desdeo/api/config.toml +53 -0
- desdeo/api/db.py +25 -0
- desdeo/api/db_init.py +79 -0
- desdeo/api/db_models.py +164 -0
- desdeo/api/malaga_db_init.py +27 -0
- desdeo/api/models/__init__.py +66 -0
- desdeo/api/models/archive.py +34 -0
- desdeo/api/models/preference.py +90 -0
- desdeo/api/models/problem.py +507 -0
- desdeo/api/models/reference_point_method.py +18 -0
- desdeo/api/models/session.py +46 -0
- desdeo/api/models/state.py +96 -0
- desdeo/api/models/user.py +51 -0
- desdeo/api/routers/_NAUTILUS.py +245 -0
- desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
- desdeo/api/routers/_NIMBUS.py +762 -0
- desdeo/api/routers/__init__.py +5 -0
- desdeo/api/routers/problem.py +110 -0
- desdeo/api/routers/reference_point_method.py +117 -0
- desdeo/api/routers/session.py +76 -0
- desdeo/api/routers/test.py +16 -0
- desdeo/api/routers/user_authentication.py +366 -0
- desdeo/api/schema.py +94 -0
- desdeo/api/tests/__init__.py +0 -0
- desdeo/api/tests/conftest.py +59 -0
- desdeo/api/tests/test_models.py +701 -0
- desdeo/api/tests/test_routes.py +216 -0
- desdeo/api/utils/database.py +274 -0
- desdeo/api/utils/logger.py +29 -0
- desdeo/core.py +27 -0
- desdeo/emo/__init__.py +29 -0
- desdeo/emo/hooks/archivers.py +172 -0
- desdeo/emo/methods/EAs.py +418 -0
- desdeo/emo/methods/__init__.py +0 -0
- desdeo/emo/methods/bases.py +59 -0
- desdeo/emo/operators/__init__.py +1 -0
- desdeo/emo/operators/crossover.py +780 -0
- desdeo/emo/operators/evaluator.py +118 -0
- desdeo/emo/operators/generator.py +356 -0
- desdeo/emo/operators/mutation.py +1053 -0
- desdeo/emo/operators/selection.py +1036 -0
- desdeo/emo/operators/termination.py +178 -0
- desdeo/explanations/__init__.py +6 -0
- desdeo/explanations/explainer.py +100 -0
- desdeo/explanations/utils.py +90 -0
- desdeo/mcdm/__init__.py +19 -0
- desdeo/mcdm/nautili.py +345 -0
- desdeo/mcdm/nautilus.py +477 -0
- desdeo/mcdm/nautilus_navigator.py +655 -0
- desdeo/mcdm/nimbus.py +417 -0
- desdeo/mcdm/pareto_navigator.py +269 -0
- desdeo/mcdm/reference_point_method.py +116 -0
- desdeo/problem/__init__.py +79 -0
- desdeo/problem/evaluator.py +561 -0
- desdeo/problem/gurobipy_evaluator.py +562 -0
- desdeo/problem/infix_parser.py +341 -0
- desdeo/problem/json_parser.py +944 -0
- desdeo/problem/pyomo_evaluator.py +468 -0
- desdeo/problem/schema.py +1808 -0
- desdeo/problem/simulator_evaluator.py +298 -0
- desdeo/problem/sympy_evaluator.py +244 -0
- desdeo/problem/testproblems/__init__.py +73 -0
- desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
- desdeo/problem/testproblems/dtlz2_problem.py +102 -0
- desdeo/problem/testproblems/forest_problem.py +275 -0
- desdeo/problem/testproblems/knapsack_problem.py +163 -0
- desdeo/problem/testproblems/mcwb_problem.py +831 -0
- desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
- desdeo/problem/testproblems/momip_problem.py +172 -0
- desdeo/problem/testproblems/nimbus_problem.py +143 -0
- desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
- desdeo/problem/testproblems/re_problem.py +492 -0
- desdeo/problem/testproblems/river_pollution_problem.py +434 -0
- desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
- desdeo/problem/testproblems/simple_problem.py +351 -0
- desdeo/problem/testproblems/simulator_problem.py +92 -0
- desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
- desdeo/problem/testproblems/zdt_problem.py +271 -0
- desdeo/problem/utils.py +245 -0
- desdeo/tools/GenerateReferencePoints.py +181 -0
- desdeo/tools/__init__.py +102 -0
- desdeo/tools/generics.py +145 -0
- desdeo/tools/gurobipy_solver_interfaces.py +258 -0
- desdeo/tools/indicators_binary.py +11 -0
- desdeo/tools/indicators_unary.py +375 -0
- desdeo/tools/interaction_schema.py +38 -0
- desdeo/tools/intersection.py +54 -0
- desdeo/tools/iterative_pareto_representer.py +99 -0
- desdeo/tools/message.py +234 -0
- desdeo/tools/ng_solver_interfaces.py +199 -0
- desdeo/tools/non_dominated_sorting.py +133 -0
- desdeo/tools/patterns.py +281 -0
- desdeo/tools/proximal_solver.py +99 -0
- desdeo/tools/pyomo_solver_interfaces.py +464 -0
- desdeo/tools/reference_vectors.py +462 -0
- desdeo/tools/scalarization.py +3138 -0
- desdeo/tools/scipy_solver_interfaces.py +454 -0
- desdeo/tools/score_bands.py +464 -0
- desdeo/tools/utils.py +320 -0
- desdeo/utopia_stuff/__init__.py +0 -0
- desdeo/utopia_stuff/data/1.json +15 -0
- desdeo/utopia_stuff/data/2.json +13 -0
- desdeo/utopia_stuff/data/3.json +15 -0
- desdeo/utopia_stuff/data/4.json +17 -0
- desdeo/utopia_stuff/data/5.json +15 -0
- desdeo/utopia_stuff/from_json.py +40 -0
- desdeo/utopia_stuff/reinit_user.py +38 -0
- desdeo/utopia_stuff/utopia_db_init.py +212 -0
- desdeo/utopia_stuff/utopia_problem.py +403 -0
- desdeo/utopia_stuff/utopia_problem_old.py +415 -0
- desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
- desdeo-2.0.0.dist-info/LICENSE +21 -0
- desdeo-2.0.0.dist-info/METADATA +168 -0
- desdeo-2.0.0.dist-info/RECORD +120 -0
- {desdeo-1.1.3.dist-info → desdeo-2.0.0.dist-info}/WHEEL +1 -1
- desdeo-1.1.3.dist-info/METADATA +0 -18
- desdeo-1.1.3.dist-info/RECORD +0 -4
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""Defines end-points to access and manage problems."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated
|
|
4
|
+
|
|
5
|
+
from fastapi import APIRouter, Depends, HTTPException, status
|
|
6
|
+
from sqlmodel import Session
|
|
7
|
+
|
|
8
|
+
from desdeo.api.db import get_session
|
|
9
|
+
from desdeo.api.models import ProblemDB, ProblemGetRequest, ProblemInfo, ProblemInfoSmall, User, UserRole
|
|
10
|
+
from desdeo.api.routers.user_authentication import get_current_user
|
|
11
|
+
from desdeo.problem import Problem
|
|
12
|
+
|
|
13
|
+
# Router grouping the problem-management end-points under the "/problem" prefix.
router = APIRouter(
    prefix="/problem",
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@router.get("/all")
def get_problems(user: Annotated[User, Depends(get_current_user)]) -> list[ProblemInfoSmall]:
    """Return summary information on every problem owned by the current user.

    Args:
        user (Annotated[User, Depends): the authenticated user, resolved by `get_current_user`.

    Returns:
        list[ProblemInfoSmall]: one summary entry per problem in `user.problems`.
    """
    owned_problems = user.problems
    return owned_problems
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@router.get("/all_info")
def get_problems_info(user: Annotated[User, Depends(get_current_user)]) -> list[ProblemInfo]:
    """Return detailed information on every problem owned by the current user.

    Args:
        user (Annotated[User, Depends): the authenticated user, resolved by `get_current_user`.

    Returns:
        list[ProblemInfo]: one detailed entry per problem in `user.problems`.
    """
    owned_problems = user.problems
    return owned_problems
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@router.post("/get")
def get_problem(
    request: ProblemGetRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> ProblemInfo:
    """Get detailed information on one of the current user's problems.

    Args:
        request (ProblemGetRequest): the request containing the problem's id `problem_id`.
        user (Annotated[User, Depends): the current user.
        session (Annotated[Session, Depends): the database session.

    Raises:
        HTTPException: 404 if no problem with the given id exists, or if the problem
            belongs to another user. Both cases are reported identically so that
            problem ids cannot be used to probe other users' data.

    Returns:
        ProblemInfo: detailed information on the requested problem.
    """
    problem = session.get(ProblemDB, request.problem_id)

    # Only the owner may read a problem; other users' problems are treated as missing.
    # NOTE(review): if some roles (e.g. analysts) should see every problem, extend this check.
    if problem is None or problem.user_id != user.id:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"The problem with the requested id={request.problem_id} was not found.",
        )

    return problem
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@router.post("/add")
def add_problem(
    request: Problem,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> ProblemInfo:
    """Store a newly defined problem in the database, owned by the current user.

    Args:
        request (Problem): the JSON representation of the problem to store.
        user (Annotated[User, Depends): the current user, who becomes the problem's owner.
        session (Annotated[Session, Depends): the database session.

    Note:
        Users with the role 'guest' are rejected and may not add problems.

    Raises:
        HTTPException: 401 for guest users; 500 when converting the problem into its
            database representation fails.

    Returns:
        ProblemInfo: the stored problem, refreshed from the database.
    """
    is_guest = user.role == UserRole.guest
    if is_guest:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Guest users are not allowed to add new problems.",
        )

    # Any failure while building the DB model is reported back to the client as a 500.
    try:
        new_problem = ProblemDB.from_problem(request, user=user)
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Could not add problem. Possible reason: {err!r}",
        ) from err

    session.add(new_problem)
    session.commit()
    session.refresh(new_problem)

    return new_problem
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""Defines end-points to access functionalities related to the reference point method."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated
|
|
4
|
+
|
|
5
|
+
from fastapi import APIRouter, Depends, HTTPException, status
|
|
6
|
+
from sqlmodel import Session, select
|
|
7
|
+
|
|
8
|
+
from desdeo.api.db import get_session
|
|
9
|
+
from desdeo.api.models import (
|
|
10
|
+
InteractiveSessionDB,
|
|
11
|
+
PreferenceDB,
|
|
12
|
+
ProblemDB,
|
|
13
|
+
RPMSolveRequest,
|
|
14
|
+
RPMState,
|
|
15
|
+
StateDB,
|
|
16
|
+
User,
|
|
17
|
+
)
|
|
18
|
+
from desdeo.api.routers.user_authentication import get_current_user
|
|
19
|
+
from desdeo.mcdm import rpm_solve_solutions
|
|
20
|
+
from desdeo.problem import Problem
|
|
21
|
+
from desdeo.tools import SolverResults
|
|
22
|
+
|
|
23
|
+
# Router grouping the reference point method end-points under "/method/rpm".
router = APIRouter(
    prefix="/method/rpm",
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@router.post("/solve")
def solve_solutions(
    request: RPMSolveRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> RPMState:
    """Solve a problem with the reference point method and record the result.

    The aspiration levels in the request's preference are passed to `rpm_solve_solutions`.
    The preference and the resulting RPM state are then stored in the database. They are
    linked to the problem, to the interactive session (if any), and to a parent state (if any).

    Args:
        request (RPMSolveRequest): the problem id, the preference, the scalarization and
            solver settings, and optionally an interactive session id and a parent state id.
        user (Annotated[User, Depends): the current user.
        session (Annotated[Session, Depends): the database session.

    Raises:
        HTTPException: 404 if the requested interactive session does not exist or does not
            belong to the current user; if the problem does not exist or does not belong
            to the current user; or if the requested parent state does not exist.

    Returns:
        RPMState: the state holding the scalarization/solver settings and solver results.
    """
    # Resolve the interactive session: either the explicitly requested one (which must
    # belong to the user) or the user's active session (which may be None).
    if request.session_id is not None:
        statement = select(InteractiveSessionDB).where(
            InteractiveSessionDB.id == request.session_id,
            InteractiveSessionDB.user_id == user.id,
        )
        interactive_session = session.exec(statement).first()

        if interactive_session is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find interactive session with id={request.session_id}.",
            )
    else:
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
        interactive_session = session.exec(statement).first()

    # Fetch the problem; only the user's own problems can be solved.
    statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
    problem_db = session.exec(statement).first()

    if problem_db is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
        )

    # Resolve the parent state before solving or writing anything. An invalid parent id
    # then neither wastes a solver run nor leaves an orphaned preference in the database.
    if request.parent_state_id is None:
        # The parent state is assumed to be the last state added to the session, if any.
        parent_state = (
            interactive_session.states[-1]
            if (interactive_session is not None and len(interactive_session.states) > 0)
            else None
        )
    else:
        statement = select(StateDB).where(StateDB.id == request.parent_state_id)
        parent_state = session.exec(statement).first()

        if parent_state is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
            )

    problem = Problem.from_problemdb(problem_db)

    # Optimize for solutions.
    solver_results: list[SolverResults] = rpm_solve_solutions(
        problem,
        request.preference.aspiration_levels,
        request.scalarization_options,
        request.solver,
        request.solver_options,
    )

    # Persist the preference so that the state below can reference its id.
    preference_db = PreferenceDB(user_id=user.id, problem_id=problem_db.id, preference=request.preference)

    session.add(preference_db)
    session.commit()
    session.refresh(preference_db)

    rpm_state = RPMState(
        scalarization_options=request.scalarization_options,
        solver=request.solver,
        solver_options=request.solver_options,
        solver_results=solver_results,
    )

    # Create the DB state and add it to the DB.
    state = StateDB(
        problem_id=problem_db.id,
        preference_id=preference_db.id,
        session_id=interactive_session.id if interactive_session is not None else None,
        parent_id=parent_state.id if parent_state is not None else None,
        state=rpm_state,
    )

    session.add(state)
    session.commit()
    session.refresh(state)

    return rpm_state
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""Defines end-points to access and manage interactive sessions."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated
|
|
4
|
+
|
|
5
|
+
from fastapi import APIRouter, Depends, HTTPException, status
|
|
6
|
+
from sqlmodel import Session, select
|
|
7
|
+
|
|
8
|
+
from desdeo.api.db import get_session
|
|
9
|
+
from desdeo.api.models import (
|
|
10
|
+
CreateSessionRequest,
|
|
11
|
+
GetSessionRequest,
|
|
12
|
+
InteractiveSessionDB,
|
|
13
|
+
InteractiveSessionInfo,
|
|
14
|
+
User,
|
|
15
|
+
)
|
|
16
|
+
from desdeo.api.routers.user_authentication import get_current_user
|
|
17
|
+
|
|
18
|
+
# Router grouping the interactive-session end-points under the "/session" prefix.
router = APIRouter(
    prefix="/session",
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@router.post("/new")
def create_new_session(
    request: CreateSessionRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> InteractiveSessionInfo:
    """Create a new interactive session for the current user and make it active.

    Args:
        request (CreateSessionRequest): the request carrying the session's `info`.
        user (Annotated[User, Depends): the current user, who will own the session.
        session (Annotated[Session, Depends): the database session.

    Returns:
        InteractiveSessionInfo: the newly created interactive session.
    """
    new_session = InteractiveSessionDB(user_id=user.id, info=request.info)
    session.add(new_session)
    session.commit()
    session.refresh(new_session)

    # The session needs a database id before it can be set as the user's active one.
    user.active_session_id = new_session.id
    session.add(user)
    session.commit()
    session.refresh(new_session)

    return new_session
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# NOTE(review): this route function shadows the `get_session` database dependency imported
# from `desdeo.api.db`. `Depends(get_session)` in the signature still refers to the
# imported dependency, because annotations are evaluated before the `def` rebinds the name.
# Any later module-level use of `get_session` would, however, get this route instead.
@router.post("/get")
def get_session(
    request: GetSessionRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> InteractiveSessionInfo:
    """Fetch one of the current user's interactive sessions by id.

    Args:
        request (GetSessionRequest): the request holding the wanted `session_id`.
        user (Annotated[User, Depends): the current user; only their sessions are searched.
        session (Annotated[Session, Depends): the database session.

    Raises:
        HTTPException: 404 if the current user has no interactive session with that id.

    Returns:
        InteractiveSessionInfo: info on the requested interactive session.
    """
    query = select(InteractiveSessionDB).where(
        InteractiveSessionDB.id == request.session_id,
        InteractiveSessionDB.user_id == user.id,
    )
    interactive_session = session.exec(query).first()

    if interactive_session is not None:
        return interactive_session

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Could not find interactive session with id={request.session_id}.",
    )
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""A test router for the DESDEO API."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated
|
|
4
|
+
|
|
5
|
+
from fastapi import APIRouter, Depends
|
|
6
|
+
|
|
7
|
+
from desdeo.api.routers.user_authentication import get_current_user
|
|
8
|
+
from desdeo.api.schema import User
|
|
9
|
+
|
|
10
|
+
# Router for test end-points, mounted under the "/test" prefix.
router = APIRouter(
    prefix="/test",
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@router.get("/userdetails")
def get_user(current_user: Annotated[User, Depends(get_current_user)]) -> User:
    """Echo back the authenticated user.

    NOTE(review): `User` here is imported from `desdeo.api.schema`, while
    `get_current_user` is typed to return `desdeo.api.models.User`. Confirm that the
    response model is the intended one and does not expose sensitive fields.

    Args:
        current_user (Annotated[User, Depends(get_current_user)]): the authenticated user.

    Returns:
        User: the same user object.
    """
    authenticated = current_user
    return authenticated
|
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
"""This module contains the functions for user authentication."""
|
|
2
|
+
|
|
3
|
+
import uuid
|
|
4
|
+
from datetime import UTC, datetime, timedelta
|
|
5
|
+
from typing import Annotated
|
|
6
|
+
|
|
7
|
+
import bcrypt
|
|
8
|
+
from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, status
|
|
9
|
+
from fastapi.responses import JSONResponse
|
|
10
|
+
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
11
|
+
from jose import JWTError, jwt
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
from sqlmodel import Session, select
|
|
14
|
+
|
|
15
|
+
from desdeo.api import SettingsConfig
|
|
16
|
+
from desdeo.api.db import get_session
|
|
17
|
+
from desdeo.api.models import User, UserPublic
|
|
18
|
+
|
|
19
|
+
# AuthConfig: the settings object providing the JWT secret key, algorithm and token
# lifetimes (the `authjwt_*` attributes) used throughout this module.
if SettingsConfig.debug:
    from desdeo.api import AuthDebugConfig

    AuthConfig = AuthDebugConfig
else:
    # No non-debug authentication settings are wired up here. Without this error,
    # `AuthConfig` would be undefined and the module would fail later with an obscure
    # NameError when the function defaults below reference it.
    raise RuntimeError(
        "No authentication configuration is available: AuthConfig is only defined "
        "when SettingsConfig.debug is True."
    )
|
|
26
|
+
|
|
27
|
+
# Router for the authentication end-points; it is created without a path prefix.
router = APIRouter(
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Tokens(BaseModel):
    """The pair of JWTs handed out on authentication, plus their scheme."""

    # JWT presented by clients as a bearer token on authenticated requests.
    access_token: str
    # JWT that can be exchanged for a new access token.
    refresh_token: str
    # Authentication scheme of the tokens; `generate_tokens` sets it to "bearer".
    token_type: str
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Dependency that extracts the bearer token from incoming requests so that it can be
# used to authenticate the user. `tokenUrl` names the route where clients obtain a
# token (the "login" end-point below).
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="login",
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Tell whether a plain-text password matches a stored bcrypt hash.

    Both arguments are UTF-8 encoded before being handed to `bcrypt.checkpw`.

    Args:
        plain_password (str): the password as typed by the user.
        hashed_password (str): the stored bcrypt hash.

    Returns:
        bool: True if the password matches the hash, otherwise False.
    """
    return bcrypt.checkpw(
        password=plain_password.encode("utf-8"),
        hashed_password=hashed_password.encode("utf-8"),
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_password_hash(password: str) -> str:
    """Hash a password with bcrypt, using a freshly generated salt.

    NOTE(review): bcrypt only considers the first 72 bytes of a password, and recent
    bcrypt releases reject longer inputs. Confirm the behaviour of the pinned version.

    Args:
        password (str): the plain-text password.

    Returns:
        str: the bcrypt hash, decoded as UTF-8 text.
    """
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password=password.encode("utf-8"), salt=salt)
    return digest.decode("utf-8")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_user(session: Session, username: str) -> User | None:
    """Look up a user by username.

    Args:
        session (Session): the database session.
        username (str): the username to search for.

    Returns:
        User | None: the first user with that username, or None if there is none.
    """
    by_username = select(User).where(User.username == username)
    matches = session.exec(by_username)
    return matches.first()
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def authenticate_user(session: Session, username: str, password: str) -> User | None:
    """Return the user identified by the credentials, if they are valid.

    Args:
        session (Session): the database session.
        username (str): the username of the user.
        password (str): the plain-text password supplied for the user.

    Returns:
        User | None: the user if it exists and the password matches its stored hash,
        otherwise None.
    """
    candidate = get_user(session, username)

    if candidate and verify_password(password, candidate.password_hash):
        return candidate

    return None
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def get_current_user(
    token: Annotated[str, Depends(oauth2_scheme)],
    session: Annotated[Session, Depends(get_session)],
) -> User:
    """Get the current user based on a JWT access token.

    This function is a dependency for other end-points that need the current user.

    Args:
        token (Annotated[str, Depends(oauth2_scheme)]): the bearer token of the request.
        session (Annotated[Session, Depends(get_session)]): the database session.

    Returns:
        User: the user named in the token's `sub` claim.

    Raises:
        HTTPException: 401 if the token cannot be decoded or verified, has expired,
            lacks the `sub`/`exp` claims, or names a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(token, AuthConfig.authjwt_secret_key, algorithms=[AuthConfig.authjwt_algorithm])
    except JWTError as err:
        # A single handler covers malformed tokens, bad signatures and expired tokens,
        # because python-jose's ExpiredSignatureError is a subclass of JWTError.
        raise credentials_exception from err

    username = payload.get("sub")
    # `exp` is a numeric timestamp (seconds since the epoch), not a datetime.
    expire_time: int | None = payload.get("exp")

    if username is None or expire_time is None or expire_time < datetime.now(UTC).timestamp():
        raise credentials_exception

    user = get_user(session, username=username)

    if user is None:
        raise credentials_exception

    return user
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def create_jwt_token(
    data: dict,
    expires_delta: timedelta,
    algorithm: str = AuthConfig.authjwt_algorithm,
    secret_key: str = AuthConfig.authjwt_secret_key,
) -> str:
    """Encode `data` into a signed JWT that expires after `expires_delta`.

    The caller's dict is not modified. An `exp` claim (now, in UTC, plus `expires_delta`)
    and a random `jti` claim are added; they override any same-named keys in `data`.

    Args:
        data (dict): the claims to encode in the token.
        expires_delta (timedelta): how long the token stays valid.
        algorithm (str): the signing algorithm. Defaults to `AuthConfig.authjwt_algorithm`.
        secret_key (str): the signing key. Defaults to `AuthConfig.authjwt_secret_key`.

    Returns:
        str: the encoded JWT.
    """
    claims = {
        **data,
        "exp": datetime.now(UTC) + expires_delta,
        # A unique token identifier, so every issued token is distinguishable.
        "jti": str(uuid.uuid4()),
    }
    return jwt.encode(claims, secret_key, algorithm=algorithm)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def create_access_token(data: dict, expiration_time: int = AuthConfig.authjwt_access_token_expires) -> str:
    """Create a JWT access token carrying `data`.

    Args:
        data (dict): the claims to encode in the token.
        expiration_time (int): the token's lifetime in minutes.
            Defaults to `AuthConfig.authjwt_access_token_expires`.

    Returns:
        str: the JWT access token.
    """
    lifetime = timedelta(minutes=expiration_time)
    return create_jwt_token(data, lifetime)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def create_refresh_token(data: dict, expiration_time: int = AuthConfig.authjwt_refresh_token_expires) -> str:
    """Create a JWT refresh token carrying `data`.

    Args:
        data (dict): the claims to encode in the token.
        expiration_time (int): the token's lifetime in minutes.
            Defaults to `AuthConfig.authjwt_refresh_token_expires`.

    Returns:
        str: the JWT refresh token.
    """
    lifetime = timedelta(minutes=expiration_time)
    return create_jwt_token(data, lifetime)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def generate_tokens(data: dict) -> Tokens:
    """Generate an access token and a refresh token, both carrying `data`.

    Note:
        The lifetimes of the tokens come from the defaults of `create_access_token`
        and `create_refresh_token`, i.e. from `AuthConfig`.

    Args:
        data (dict): the claims to encode in both tokens.

    Returns:
        Tokens: the access and refresh tokens, with token type "bearer".
    """
    return Tokens(
        access_token=create_access_token(data),
        refresh_token=create_refresh_token(data),
        token_type="bearer",  # noqa: S106
    )
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def validate_refresh_token(
    refresh_token: str,
    session: Annotated[Session, Depends(get_session)],
    algorithm: str = AuthConfig.authjwt_algorithm,
    secret_key: str = AuthConfig.authjwt_secret_key,
) -> User:
    """Check a refresh token and return the user it belongs to.

    Args:
        refresh_token (str): the refresh token to validate.
        session (Annotated[Session, Depends(get_session)]): the database session.
        algorithm (str): the algorithm used to decode the token.
            Defaults to `AuthConfig.authjwt_algorithm`.
        secret_key (str): the key used to decode the token.
            Defaults to `AuthConfig.authjwt_secret_key`.

    Returns:
        User: the user named in the token's `sub` claim.

    Raises:
        HTTPException: 401 if the token cannot be decoded, lacks the `sub`/`exp` claims,
            has expired, or names a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired refresh token",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(refresh_token, secret_key, algorithms=[algorithm])
        username = claims.get("sub")
        expire_time = claims.get("exp")  # numeric timestamp
    except Exception:
        # Every decoding failure is deliberately reported as the same 401.
        raise credentials_exception from None

    if username is None or expire_time is None:
        raise credentials_exception
    if expire_time < datetime.now(UTC).timestamp():
        raise credentials_exception

    owner = get_user(session, username=username)
    if owner is None:
        raise credentials_exception

    return owner
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
@router.get("/user_info")
def get_current_user_info(user: Annotated[User, Depends(get_current_user)]) -> UserPublic:
    """Return the public information of the authenticated user.

    Args:
        user (Annotated[User, Depends): the current user, resolved by `get_current_user`.

    Returns:
        UserPublic: the user; the return annotation makes FastAPI serialize it as `UserPublic`.
    """
    current = user
    return current
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
@router.post("/login")
def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    session: Annotated[Session, Depends(get_session)],
    cookie_max_age: int = AuthConfig.authjwt_refresh_token_expires,
):
    """Log in to obtain an access token and a refresh-token cookie.

    The access token is returned in the JSON body. The refresh token is set as an
    HTTP-only cookie named `refresh_token`.

    Args:
        form_data (Annotated[OAuth2PasswordRequestForm, Depends()]):
            the username and password submitted by the client.
        session (Annotated[Session, Depends(get_session)]): the database session.
        cookie_max_age (int): the lifetime of the refresh-token cookie in minutes.
            Defaults to `AuthConfig.authjwt_refresh_token_expires`.

    Raises:
        HTTPException: 401 if the username or password is incorrect.

    Returns:
        JSONResponse: a body of the form `{"access_token": ..., "token_type": "bearer"}`,
        with the refresh-token cookie attached.
    """
    user = authenticate_user(session, form_data.username, form_data.password)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    tokens = generate_tokens({"id": user.id, "sub": user.username})

    # `token_type` is a required member of an OAuth2 token response (RFC 6749, section 5.1).
    response = JSONResponse(content={"access_token": tokens.access_token, "token_type": tokens.token_type})
    response.set_cookie(
        key="refresh_token",
        value=tokens.refresh_token,
        httponly=True,  # not readable from JavaScript, unlike a token stored in frontend code
        secure=False,  # allow plain HTTP; NOTE(review): consider True when served over HTTPS
        samesite="lax",  # sent on same-site requests and top-level navigations, not cross-site subrequests
        max_age=cookie_max_age * 60,  # minutes -> seconds (max_age is expressed in seconds)
    )

    return response
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
@router.post("/refresh")
def refresh_access_token(
    request: Response,
    session: Annotated[Session, Depends(get_session)],
    refresh_token: Annotated[str | None, Cookie()] = None,
):
    """Issue a new access token using the refresh token stored in a cookie.

    Args:
        request (Response): not used by this end-point.
        session (Annotated[Session, Depends(get_session)]): the database session.
        refresh_token (Annotated[str | None, Cookie()]): the refresh token, read from
            the `refresh_token` cookie sent with the request.

    Raises:
        HTTPException: 401 if the cookie is missing or empty. `validate_refresh_token`
            raises 401 for an invalid or expired token.

    Returns:
        dict: a dictionary containing the new access token under "access_token".
    """
    if not refresh_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing refresh token.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    owner = validate_refresh_token(refresh_token, session)
    new_access_token = create_access_token({"id": owner.id, "sub": owner.username})

    return {"access_token": new_access_token}
|