desdeo 1.1.3__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. desdeo/__init__.py +8 -8
  2. desdeo/api/README.md +73 -0
  3. desdeo/api/__init__.py +15 -0
  4. desdeo/api/app.py +40 -0
  5. desdeo/api/config.py +69 -0
  6. desdeo/api/config.toml +53 -0
  7. desdeo/api/db.py +25 -0
  8. desdeo/api/db_init.py +79 -0
  9. desdeo/api/db_models.py +164 -0
  10. desdeo/api/malaga_db_init.py +27 -0
  11. desdeo/api/models/__init__.py +66 -0
  12. desdeo/api/models/archive.py +34 -0
  13. desdeo/api/models/preference.py +90 -0
  14. desdeo/api/models/problem.py +507 -0
  15. desdeo/api/models/reference_point_method.py +18 -0
  16. desdeo/api/models/session.py +46 -0
  17. desdeo/api/models/state.py +96 -0
  18. desdeo/api/models/user.py +51 -0
  19. desdeo/api/routers/_NAUTILUS.py +245 -0
  20. desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
  21. desdeo/api/routers/_NIMBUS.py +762 -0
  22. desdeo/api/routers/__init__.py +5 -0
  23. desdeo/api/routers/problem.py +110 -0
  24. desdeo/api/routers/reference_point_method.py +117 -0
  25. desdeo/api/routers/session.py +76 -0
  26. desdeo/api/routers/test.py +16 -0
  27. desdeo/api/routers/user_authentication.py +366 -0
  28. desdeo/api/schema.py +94 -0
  29. desdeo/api/tests/__init__.py +0 -0
  30. desdeo/api/tests/conftest.py +59 -0
  31. desdeo/api/tests/test_models.py +701 -0
  32. desdeo/api/tests/test_routes.py +216 -0
  33. desdeo/api/utils/database.py +274 -0
  34. desdeo/api/utils/logger.py +29 -0
  35. desdeo/core.py +27 -0
  36. desdeo/emo/__init__.py +29 -0
  37. desdeo/emo/hooks/archivers.py +172 -0
  38. desdeo/emo/methods/EAs.py +418 -0
  39. desdeo/emo/methods/__init__.py +0 -0
  40. desdeo/emo/methods/bases.py +59 -0
  41. desdeo/emo/operators/__init__.py +1 -0
  42. desdeo/emo/operators/crossover.py +780 -0
  43. desdeo/emo/operators/evaluator.py +118 -0
  44. desdeo/emo/operators/generator.py +356 -0
  45. desdeo/emo/operators/mutation.py +1053 -0
  46. desdeo/emo/operators/selection.py +1036 -0
  47. desdeo/emo/operators/termination.py +178 -0
  48. desdeo/explanations/__init__.py +6 -0
  49. desdeo/explanations/explainer.py +100 -0
  50. desdeo/explanations/utils.py +90 -0
  51. desdeo/mcdm/__init__.py +19 -0
  52. desdeo/mcdm/nautili.py +345 -0
  53. desdeo/mcdm/nautilus.py +477 -0
  54. desdeo/mcdm/nautilus_navigator.py +655 -0
  55. desdeo/mcdm/nimbus.py +417 -0
  56. desdeo/mcdm/pareto_navigator.py +269 -0
  57. desdeo/mcdm/reference_point_method.py +116 -0
  58. desdeo/problem/__init__.py +79 -0
  59. desdeo/problem/evaluator.py +561 -0
  60. desdeo/problem/gurobipy_evaluator.py +562 -0
  61. desdeo/problem/infix_parser.py +341 -0
  62. desdeo/problem/json_parser.py +944 -0
  63. desdeo/problem/pyomo_evaluator.py +468 -0
  64. desdeo/problem/schema.py +1808 -0
  65. desdeo/problem/simulator_evaluator.py +298 -0
  66. desdeo/problem/sympy_evaluator.py +244 -0
  67. desdeo/problem/testproblems/__init__.py +73 -0
  68. desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
  69. desdeo/problem/testproblems/dtlz2_problem.py +102 -0
  70. desdeo/problem/testproblems/forest_problem.py +275 -0
  71. desdeo/problem/testproblems/knapsack_problem.py +163 -0
  72. desdeo/problem/testproblems/mcwb_problem.py +831 -0
  73. desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
  74. desdeo/problem/testproblems/momip_problem.py +172 -0
  75. desdeo/problem/testproblems/nimbus_problem.py +143 -0
  76. desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
  77. desdeo/problem/testproblems/re_problem.py +492 -0
  78. desdeo/problem/testproblems/river_pollution_problem.py +434 -0
  79. desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
  80. desdeo/problem/testproblems/simple_problem.py +351 -0
  81. desdeo/problem/testproblems/simulator_problem.py +92 -0
  82. desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
  83. desdeo/problem/testproblems/zdt_problem.py +271 -0
  84. desdeo/problem/utils.py +245 -0
  85. desdeo/tools/GenerateReferencePoints.py +181 -0
  86. desdeo/tools/__init__.py +102 -0
  87. desdeo/tools/generics.py +145 -0
  88. desdeo/tools/gurobipy_solver_interfaces.py +258 -0
  89. desdeo/tools/indicators_binary.py +11 -0
  90. desdeo/tools/indicators_unary.py +375 -0
  91. desdeo/tools/interaction_schema.py +38 -0
  92. desdeo/tools/intersection.py +54 -0
  93. desdeo/tools/iterative_pareto_representer.py +99 -0
  94. desdeo/tools/message.py +234 -0
  95. desdeo/tools/ng_solver_interfaces.py +199 -0
  96. desdeo/tools/non_dominated_sorting.py +133 -0
  97. desdeo/tools/patterns.py +281 -0
  98. desdeo/tools/proximal_solver.py +99 -0
  99. desdeo/tools/pyomo_solver_interfaces.py +464 -0
  100. desdeo/tools/reference_vectors.py +462 -0
  101. desdeo/tools/scalarization.py +3138 -0
  102. desdeo/tools/scipy_solver_interfaces.py +454 -0
  103. desdeo/tools/score_bands.py +464 -0
  104. desdeo/tools/utils.py +320 -0
  105. desdeo/utopia_stuff/__init__.py +0 -0
  106. desdeo/utopia_stuff/data/1.json +15 -0
  107. desdeo/utopia_stuff/data/2.json +13 -0
  108. desdeo/utopia_stuff/data/3.json +15 -0
  109. desdeo/utopia_stuff/data/4.json +17 -0
  110. desdeo/utopia_stuff/data/5.json +15 -0
  111. desdeo/utopia_stuff/from_json.py +40 -0
  112. desdeo/utopia_stuff/reinit_user.py +38 -0
  113. desdeo/utopia_stuff/utopia_db_init.py +212 -0
  114. desdeo/utopia_stuff/utopia_problem.py +403 -0
  115. desdeo/utopia_stuff/utopia_problem_old.py +415 -0
  116. desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
  117. desdeo-2.0.0.dist-info/LICENSE +21 -0
  118. desdeo-2.0.0.dist-info/METADATA +168 -0
  119. desdeo-2.0.0.dist-info/RECORD +120 -0
  120. {desdeo-1.1.3.dist-info → desdeo-2.0.0.dist-info}/WHEEL +1 -1
  121. desdeo-1.1.3.dist-info/METADATA +0 -18
  122. desdeo-1.1.3.dist-info/RECORD +0 -4
@@ -0,0 +1,5 @@
1
+ """Exports from routers."""
2
+
3
+ __all__ = ["get_current_user"]
4
+
5
+ from .user_authentication import get_current_user
@@ -0,0 +1,110 @@
1
+ """Defines end-points to access and manage problems."""
2
+
3
+ from typing import Annotated
4
+
5
+ from fastapi import APIRouter, Depends, HTTPException, status
6
+ from sqlmodel import Session
7
+
8
+ from desdeo.api.db import get_session
9
+ from desdeo.api.models import ProblemDB, ProblemGetRequest, ProblemInfo, ProblemInfoSmall, User, UserRole
10
+ from desdeo.api.routers.user_authentication import get_current_user
11
+ from desdeo.problem import Problem
12
+
13
router = APIRouter(prefix="/problem")


@router.get("/all")
def get_problems(user: Annotated[User, Depends(get_current_user)]) -> list[ProblemInfoSmall]:
    """List summary information for every problem owned by the current user.

    Args:
        user (Annotated[User, Depends]): the authenticated user, resolved by `get_current_user`.

    Returns:
        list[ProblemInfoSmall]: one summary entry per problem in `user.problems`.
    """
    owned_problems = user.problems
    return owned_problems
27
+
28
+
29
@router.get("/all_info")
def get_problems_info(user: Annotated[User, Depends(get_current_user)]) -> list[ProblemInfo]:
    """List the detailed information of every problem owned by the current user.

    Args:
        user (Annotated[User, Depends]): the authenticated user, resolved by `get_current_user`.

    Returns:
        list[ProblemInfo]: one detailed entry per problem in `user.problems`.
    """
    owned_problems = user.problems
    return owned_problems
40
+
41
+
42
@router.post("/get")
def get_problem(
    request: ProblemGetRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> ProblemInfo:
    """Fetch a single problem from the database by its id.

    Args:
        request (ProblemGetRequest): carries the id of the wanted problem in `problem_id`.
        user (Annotated[User, Depends]): the authenticated user.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: 404 when no problem with the given id exists.

    Returns:
        ProblemInfo: detailed information on the requested problem.

    NOTE(review): `user` only enforces authentication here; the lookup does not check
    that the problem belongs to `user`. Confirm that this is intended.
    """
    problem_db = session.get(ProblemDB, request.problem_id)

    if problem_db is not None:
        return problem_db

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"The problem with the requested id={request.problem_id} was not found.",
    )
70
+
71
+
72
@router.post("/add")
def add_problem(
    request: Problem,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> ProblemInfo:
    """Store a newly defined problem in the database on behalf of the current user.

    Args:
        request (Problem): the problem definition sent as JSON.
        user (Annotated[User, Depends]): the authenticated user; becomes the problem's owner.
        session (Annotated[Session, Depends]): the database session.

    Note:
        Users with the role 'guest' are refused. NOTE(review): the refusal uses 401;
        403 Forbidden may be the more fitting status for an authenticated guest.

    Raises:
        HTTPException: 401 for guest users; 500 when `ProblemDB.from_problem` fails.

    Returns:
        ProblemInfo: the stored problem, refreshed from the database.
    """
    is_guest = user.role == UserRole.guest
    if is_guest:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Guest users are not allowed to add new problems.",
        )

    try:
        new_problem = ProblemDB.from_problem(request, user=user)
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Could not add problem. Possible reason: {err!r}",
        ) from err

    session.add(new_problem)
    session.commit()
    # reload so that database-populated fields are present in the response
    session.refresh(new_problem)
    return new_problem
@@ -0,0 +1,117 @@
1
+ """Defines end-points to access functionalities related to the reference point method."""
2
+
3
+ from typing import Annotated
4
+
5
+ from fastapi import APIRouter, Depends, HTTPException, status
6
+ from sqlmodel import Session, select
7
+
8
+ from desdeo.api.db import get_session
9
+ from desdeo.api.models import (
10
+ InteractiveSessionDB,
11
+ PreferenceDB,
12
+ ProblemDB,
13
+ RPMSolveRequest,
14
+ RPMState,
15
+ StateDB,
16
+ User,
17
+ )
18
+ from desdeo.api.routers.user_authentication import get_current_user
19
+ from desdeo.mcdm import rpm_solve_solutions
20
+ from desdeo.problem import Problem
21
+ from desdeo.tools import SolverResults
22
+
23
router = APIRouter(prefix="/method/rpm")


@router.post("/solve")
def solve_solutions(
    request: RPMSolveRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> RPMState:
    """Solve a problem with the reference point method and store the outcome as a new state.

    The interactive session, the problem, and the parent state are resolved first.
    Then the problem is solved with `rpm_solve_solutions`. Finally, the preference and
    the resulting state are written to the database.

    Args:
        request (RPMSolveRequest): the problem id, the preference (its aspiration levels),
            scalarization and solver options, and optionally a session id and a parent state id.
        user (Annotated[User, Depends]): the authenticated user.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: 404 when the requested interactive session (of the current user),
            the problem (of the current user), or the requested parent state cannot be found.

    Returns:
        RPMState: the state holding the solver results.
    """
    # resolve the interactive session the new state will belong to
    if request.session_id is not None:
        statement = select(InteractiveSessionDB).where(
            InteractiveSessionDB.id == request.session_id, InteractiveSessionDB.user_id == user.id
        )
        # `.first()` is required: `session.exec` returns a result object, never None
        interactive_session = session.exec(statement).first()

        if interactive_session is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Could not find interactive session with id={request.session_id}.",
            )
    else:
        # no session given: use the user's active session instead (may resolve to None)
        statement = select(InteractiveSessionDB).where(InteractiveSessionDB.id == user.active_session_id)
        interactive_session = session.exec(statement).first()

    # fetch the problem from the DB
    statement = select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == request.problem_id)
    problem_db = session.exec(statement).first()

    if problem_db is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
        )

    # Resolve the parent state before solving or writing anything, so that an invalid
    # parent id does not leave an orphaned preference in the database.
    if request.parent_state_id is None:
        # the parent state is assumed to be the last state added to the session
        parent_state = (
            interactive_session.states[-1]
            if (interactive_session is not None and len(interactive_session.states) > 0)
            else None
        )
    else:
        statement = select(StateDB).where(StateDB.id == request.parent_state_id)
        parent_state = session.exec(statement).first()

        if parent_state is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
            )

    problem = Problem.from_problemdb(problem_db)

    # optimize for solutions
    solver_results: list[SolverResults] = rpm_solve_solutions(
        problem,
        request.preference.aspiration_levels,
        request.scalarization_options,
        request.solver,
        request.solver_options,
    )

    # create DB preference
    preference_db = PreferenceDB(user_id=user.id, problem_id=problem_db.id, preference=request.preference)

    session.add(preference_db)
    session.commit()
    session.refresh(preference_db)

    # create the method state
    rpm_state = RPMState(
        scalarization_options=request.scalarization_options,
        solver=request.solver,
        solver_options=request.solver_options,
        solver_results=solver_results,
    )

    # create DB state and add it to the DB
    state = StateDB(
        problem_id=problem_db.id,
        preference_id=preference_db.id,
        session_id=interactive_session.id if interactive_session is not None else None,
        parent_id=parent_state.id if parent_state is not None else None,
        state=rpm_state,
    )

    session.add(state)
    session.commit()
    session.refresh(state)

    return rpm_state
@@ -0,0 +1,76 @@
1
+ """Defines end-points to access and manage interactive sessions."""
2
+
3
+ from typing import Annotated
4
+
5
+ from fastapi import APIRouter, Depends, HTTPException, status
6
+ from sqlmodel import Session, select
7
+
8
+ from desdeo.api.db import get_session
9
+ from desdeo.api.models import (
10
+ CreateSessionRequest,
11
+ GetSessionRequest,
12
+ InteractiveSessionDB,
13
+ InteractiveSessionInfo,
14
+ User,
15
+ )
16
+ from desdeo.api.routers.user_authentication import get_current_user
17
+
18
router = APIRouter(prefix="/session")


@router.post("/new")
def create_new_session(
    request: CreateSessionRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> InteractiveSessionInfo:
    """Create a new interactive session for the current user and make it the active one.

    First the session is stored in the database so that it gets an id. Then that id is
    written to `user.active_session_id`.

    Args:
        request (CreateSessionRequest): the request; its `info` is stored on the new session.
        user (Annotated[User, Depends]): the authenticated user; owner of the new session.
        session (Annotated[Session, Depends]): the database session.

    Returns:
        InteractiveSessionInfo: information on the newly created interactive session.
    """
    interactive_session = InteractiveSessionDB(user_id=user.id, info=request.info)

    session.add(interactive_session)
    session.commit()
    # refresh so that the id (presumably database-assigned) is available below
    session.refresh(interactive_session)

    user.active_session_id = interactive_session.id

    session.add(user)
    session.commit()
    # reload the interactive session after the second commit before returning it
    session.refresh(interactive_session)

    return interactive_session
41
+
42
+
43
@router.post("/get")
def get_session(
    request: GetSessionRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> InteractiveSessionInfo:
    """Look up one of the current user's interactive sessions by id.

    Args:
        request (GetSessionRequest): carries the id of the wanted session in `session_id`.
        user (Annotated[User, Depends]): the authenticated user; only their sessions match.
        session (Annotated[Session, Depends]): the database session.

    Raises:
        HTTPException: 404 when the current user has no interactive session with that id.

    Returns:
        InteractiveSessionInfo: info on the requested interactive session.

    NOTE: this endpoint's name shadows the `get_session` database dependency imported at
    the top of the module. `Depends(get_session)` in the signature still refers to the
    dependency, because the annotation is evaluated when `def` runs, before the name is
    rebound (the module does not use `from __future__ import annotations`). Any later
    use of `get_session` in this module would get this endpoint instead.
    """
    owned_session = session.exec(
        select(InteractiveSessionDB).where(
            InteractiveSessionDB.id == request.session_id,
            InteractiveSessionDB.user_id == user.id,
        )
    ).first()

    if owned_session is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Could not find interactive session with id={request.session_id}.",
        )
    return owned_session
@@ -0,0 +1,16 @@
1
+ """A test router for the DESDEO API."""
2
+
3
+ from typing import Annotated
4
+
5
+ from fastapi import APIRouter, Depends
6
+
7
+ from desdeo.api.routers.user_authentication import get_current_user
8
+ from desdeo.api.schema import User
9
+
10
router = APIRouter(prefix="/test")


@router.get("/userdetails")
def get_user(current_user: Annotated[User, Depends(get_current_user)]) -> User:
    """Echo back the user resolved from the request's bearer token.

    Args:
        current_user (Annotated[User, Depends]): the authenticated user.

    Returns:
        User: the same user object.

    NOTE(review): `User` here is imported from `desdeo.api.schema`, whereas the
    authentication module works with `desdeo.api.models.User`. Confirm that the two
    models are compatible for the response.
    """
    user_details = current_user
    return user_details
@@ -0,0 +1,366 @@
1
+ """This module contains the functions for user authentication."""
2
+
3
+ import uuid
4
+ from datetime import UTC, datetime, timedelta
5
+ from typing import Annotated
6
+
7
+ import bcrypt
8
+ from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, status
9
+ from fastapi.responses import JSONResponse
10
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
11
+ from jose import JWTError, jwt
12
+ from pydantic import BaseModel
13
+ from sqlmodel import Session, select
14
+
15
+ from desdeo.api import SettingsConfig
16
+ from desdeo.api.db import get_session
17
+ from desdeo.api.models import User, UserPublic
18
+
19
# Select the authentication configuration. `AuthConfig` is read as a default argument
# value by several functions below, so it must be defined when this module is imported.
if SettingsConfig.debug:
    from desdeo.api import AuthDebugConfig

    AuthConfig = AuthDebugConfig
else:
    # Only the debug configuration exists. Without one, `AuthConfig` would be undefined
    # and the default arguments below would fail with an opaque NameError; fail
    # explicitly instead.
    raise RuntimeError(
        "No non-debug authentication configuration is available; "
        "enable debug mode in the API settings to use AuthDebugConfig."
    )

router = APIRouter()
28
+
29
+
30
class Tokens(BaseModel):
    """The token pair produced by `generate_tokens`.

    Holds an access token, a refresh token, and the token type string.
    """

    access_token: str  # JWT the client sends as its bearer token on authenticated requests
    refresh_token: str  # JWT exchanged at the refresh endpoint for a new access token
    token_type: str  # `generate_tokens` always sets this to "bearer"
36
+
37
+
38
# OAuth2PasswordBearer creates a dependency that extracts the bearer token from the request.
# `get_current_user` depends on it to authenticate the user.
# `tokenUrl` is a relative URL pointing at the `/login` endpoint defined in this module.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
41
+
42
+
43
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Tell whether `plain_password` matches the bcrypt hash `hashed_password`.

    Both strings are UTF-8 encoded before being passed to `bcrypt.checkpw`.

    Args:
        plain_password (str): the password as typed by the user.
        hashed_password (str): the stored bcrypt hash.

    Returns:
        bool: True when the password matches the hash.
    """
    candidate = plain_password.encode("utf-8")
    stored_hash = hashed_password.encode("utf-8")
    return bcrypt.checkpw(password=candidate, hashed_password=stored_hash)
56
+
57
+
58
def get_password_hash(password: str) -> str:
    """Produce a salted bcrypt hash of `password`.

    Args:
        password (str): the plain password to hash.

    Returns:
        str: the bcrypt hash, decoded as UTF-8 text.
    """
    raw = password.encode("utf-8")
    digest = bcrypt.hashpw(password=raw, salt=bcrypt.gensalt())
    return digest.decode("utf-8")
70
+
71
+
72
def get_user(session: Session, username: str) -> User | None:
    """Look up a user by username.

    Args:
        session (Session): database session.
        username (str): the username to search for.

    Returns:
        User | None: the first matching user, or None when no user has that username.
    """
    query = select(User).where(User.username == username)
    results = session.exec(query)
    return results.first()
87
+
88
+
89
def authenticate_user(session: Session, username: str, password: str) -> User | None:
    """Return the user when `username` exists and `password` matches its stored hash.

    Args:
        session (Session): database session.
        username (str): the username of the user.
        password (str): the plain password to check.

    Returns:
        User | None: the authenticated user, or None when the user does not exist or
        the password is wrong.

    NOTE(review): when the user does not exist, no bcrypt check is performed, so
    response timing may reveal whether a username exists.
    """
    candidate = get_user(session, username)
    if candidate and verify_password(password, candidate.password_hash):
        return candidate
    return None
109
+
110
+
111
def get_current_user(
    token: Annotated[str, Depends(oauth2_scheme)],
    session: Annotated[Session, Depends(get_session)],
) -> User:
    """Resolve the current user from a JWT bearer token.

    Used as a dependency by endpoints that need the authenticated user. The token is
    decoded with `AuthConfig`'s secret key and algorithm. Its `sub` claim is taken as
    the username, and its `exp` claim must lie in the future.

    Args:
        token (Annotated[str, Depends(oauth2_scheme)]): the bearer token from the request.
        session (Annotated[Session, Depends(get_session)]): a database session.

    Returns:
        User: the user named by the token's `sub` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, is expired, lacks the
            required claims, or names an unknown user.

    NOTE(review): only `sub` and `exp` are checked, and `generate_tokens` puts the same
    claims in access and refresh tokens. A refresh token is therefore also accepted here.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, AuthConfig.authjwt_secret_key, algorithms=[AuthConfig.authjwt_algorithm])
    except JWTError:
        # python-jose's ExpiredSignatureError and claim errors subclass JWTError, so expired
        # tokens land here too. (`jose.jwt` has no `exceptions` attribute; referencing
        # `jwt.exceptions` in an except clause raised AttributeError instead of a 401.)
        raise credentials_exception from None

    username: str | None = payload.get("sub")
    # `exp` is a numeric UNIX timestamp in the decoded payload
    expire_time: int | float | None = payload.get("exp")

    if username is None or expire_time is None or expire_time < datetime.now(UTC).timestamp():
        raise credentials_exception

    user = get_user(session, username=username)

    if user is None:
        raise credentials_exception

    return user
154
+
155
+
156
def create_jwt_token(
    data: dict,
    expires_delta: timedelta,
    algorithm: str = AuthConfig.authjwt_algorithm,
    secret_key: str = AuthConfig.authjwt_secret_key,
) -> str:
    """Encode `data` into a signed JWT that expires `expires_delta` from now.

    Adds two claims: `exp` (expiration time) and `jti` (a random UUID that makes
    every token unique). The caller's `data` dict is not modified.

    Args:
        data (dict): the claims to encode in the token.
        expires_delta (timedelta): how long the token stays valid.
        algorithm (str): the signing algorithm. Defaults to `AuthConfig.authjwt_algorithm`.
        secret_key (str): the signing key. Defaults to `AuthConfig.authjwt_secret_key`.

    Returns:
        str: the encoded JWT.
    """
    claims = {**data, "exp": datetime.now(UTC) + expires_delta, "jti": str(uuid.uuid4())}
    return jwt.encode(claims, secret_key, algorithm=algorithm)
181
+
182
+
183
def create_access_token(data: dict, expiration_time: int = AuthConfig.authjwt_access_token_expires) -> str:
    """Build a JWT access token carrying `data`.

    Args:
        data (dict): the claims to encode in the token.
        expiration_time (int): token lifetime in minutes.
            Defaults to `AuthConfig.authjwt_access_token_expires`.

    Returns:
        str: the JWT access token.
    """
    lifetime = timedelta(minutes=expiration_time)
    return create_jwt_token(data, lifetime)
198
+
199
+
200
def create_refresh_token(data: dict, expiration_time: int = AuthConfig.authjwt_refresh_token_expires) -> str:
    """Build a JWT refresh token carrying `data`.

    Args:
        data (dict): the claims to encode in the token.
        expiration_time (int): token lifetime in minutes.
            Defaults to `AuthConfig.authjwt_refresh_token_expires`.

    Returns:
        str: the JWT refresh token.
    """
    return create_jwt_token(data, timedelta(minutes=expiration_time))
216
+
217
+
218
def generate_tokens(data: dict) -> Tokens:
    """Generate an access token and a refresh token, both carrying `data`.

    Note:
        Both tokens use the default expiration times defined in `AuthConfig`.

    Args:
        data (dict): the claims to encode in both tokens.

    Returns:
        Tokens: the access and refresh tokens with token type "bearer".
    """
    return Tokens(
        access_token=create_access_token(data),
        refresh_token=create_refresh_token(data),
        token_type="bearer",  # noqa: S106
    )
234
+
235
+
236
def validate_refresh_token(
    refresh_token: str,
    session: Annotated[Session, Depends(get_session)],
    algorithm: str = AuthConfig.authjwt_algorithm,
    secret_key: str = AuthConfig.authjwt_secret_key,
) -> User:
    """Validate a refresh token and return the user it belongs to.

    The token must decode with `secret_key`/`algorithm`, carry a `sub` claim (the
    username) and an `exp` claim in the future, and name an existing user.

    Args:
        refresh_token (str): the refresh token to validate.
        session (Annotated[Session, Depends(get_session)]): the database session. The
            refresh endpoint calls this function directly and passes its own session.
        algorithm (str): the algorithm used to decode the JWT token.
            Defaults to `AuthConfig.authjwt_algorithm`.
        secret_key (str): the secret key used to decode the JWT token.
            Defaults to `AuthConfig.authjwt_secret_key`.

    Returns:
        User: the user associated with the valid refresh token.

    Raises:
        HTTPException: 401 if the refresh token is invalid, expired, or names an unknown user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired refresh token",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(refresh_token, secret_key, algorithms=[algorithm])
    except JWTError:
        # covers bad signatures, malformed tokens, and expired tokens (ExpiredSignatureError
        # is a JWTError subclass in python-jose)
        raise credentials_exception from None

    username: str | None = payload.get("sub")
    # `exp` is a numeric UNIX timestamp in the decoded payload
    expire_time: int | float | None = payload.get("exp")

    if username is None or expire_time is None or expire_time < datetime.now(UTC).timestamp():
        raise credentials_exception

    # Validate the user from the database
    user = get_user(session, username=username)
    if user is None:
        raise credentials_exception

    return user
282
+
283
+
284
@router.get("/user_info")
def get_current_user_info(user: Annotated[User, Depends(get_current_user)]) -> UserPublic:
    """Expose the public information of the authenticated user.

    The user object is returned as-is; the declared return type `UserPublic` serves as
    the response model.

    Args:
        user (Annotated[User, Depends]): the user resolved by `get_current_user`.

    Returns:
        UserPublic: public information about the current user.
    """
    current = user
    return current
295
+
296
+
297
@router.post("/login")
def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    session: Annotated[Session, Depends(get_session)],
    cookie_max_age: int = AuthConfig.authjwt_refresh_token_expires,
):
    """Log in with a username and password.

    On success, the access token is returned in the JSON body. The refresh token is
    stored in an HTTP-only cookie named `refresh_token`.

    Args:
        form_data (Annotated[OAuth2PasswordRequestForm, Depends()]):
            the form data (username and password) used to authenticate the user.
        session (Annotated[Session, Depends(get_session)]): the database session.
        cookie_max_age (int): lifetime of the refresh-token cookie in minutes.
            Defaults to `AuthConfig.authjwt_refresh_token_expires`.
            NOTE(review): with a default and no `Depends`, this is exposed as a query
            parameter, so clients can choose it. The refresh token's own `exp` claim
            still bounds its validity.

    Raises:
        HTTPException: 401 when the username or password is incorrect.

    Returns:
        JSONResponse: a body of the form `{"access_token": ...}` with the refresh-token
        cookie set.
    """
    user = authenticate_user(session, form_data.username, form_data.password)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    tokens = generate_tokens({"id": user.id, "sub": user.username})

    response = JSONResponse(content={"access_token": tokens.access_token})
    response.set_cookie(
        key="refresh_token",
        value=tokens.refresh_token,
        httponly=True,  # not readable from JavaScript, safer than keeping the refresh token in frontend code
        secure=False,  # allow plain http; NOTE(review): set to True when served over HTTPS
        samesite="lax",  # sent on same-site requests and top-level navigations, withheld on cross-site subrequests
        max_age=cookie_max_age * 60,  # cookie max_age is in seconds; convert from minutes
    )

    return response
335
+
336
+
337
@router.post("/refresh")
def refresh_access_token(
    request: Response,
    session: Annotated[Session, Depends(get_session)],
    refresh_token: Annotated[str | None, Cookie()] = None,
):
    """Issue a new access token based on the refresh token cookie.

    Args:
        request (Response): the response object injected by FastAPI; not used here.
        session (Annotated[Session, Depends(get_session)]): the database session.
        refresh_token (Annotated[str | None, Cookie()]): the refresh token, read from the
            `refresh_token` cookie of the incoming request.

    Raises:
        HTTPException: 401 when the cookie is missing, or (via `validate_refresh_token`)
            when the refresh token is invalid.

    Returns:
        dict: a dictionary of the form `{"access_token": ...}`.
    """
    if refresh_token:
        user = validate_refresh_token(refresh_token, session)
        # Generate a new access token for the user
        return {"access_token": create_access_token({"id": user.id, "sub": user.username})}

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Missing refresh token.",
        headers={"WWW-Authenticate": "Bearer"},
    )