desdeo 2.0.0__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. desdeo/adm/ADMAfsar.py +551 -0
  2. desdeo/adm/ADMChen.py +414 -0
  3. desdeo/adm/BaseADM.py +119 -0
  4. desdeo/adm/__init__.py +11 -0
  5. desdeo/api/__init__.py +6 -6
  6. desdeo/api/app.py +38 -28
  7. desdeo/api/config.py +65 -44
  8. desdeo/api/config.toml +23 -12
  9. desdeo/api/db.py +10 -8
  10. desdeo/api/db_init.py +12 -6
  11. desdeo/api/models/__init__.py +220 -20
  12. desdeo/api/models/archive.py +16 -27
  13. desdeo/api/models/emo.py +128 -0
  14. desdeo/api/models/enautilus.py +69 -0
  15. desdeo/api/models/gdm/gdm_aggregate.py +139 -0
  16. desdeo/api/models/gdm/gdm_base.py +69 -0
  17. desdeo/api/models/gdm/gdm_score_bands.py +114 -0
  18. desdeo/api/models/gdm/gnimbus.py +138 -0
  19. desdeo/api/models/generic.py +104 -0
  20. desdeo/api/models/generic_states.py +401 -0
  21. desdeo/api/models/nimbus.py +158 -0
  22. desdeo/api/models/preference.py +44 -6
  23. desdeo/api/models/problem.py +274 -64
  24. desdeo/api/models/session.py +4 -1
  25. desdeo/api/models/state.py +419 -52
  26. desdeo/api/models/user.py +7 -6
  27. desdeo/api/models/utopia.py +25 -0
  28. desdeo/api/routers/_EMO.backup +309 -0
  29. desdeo/api/routers/_NIMBUS.py +6 -3
  30. desdeo/api/routers/emo.py +497 -0
  31. desdeo/api/routers/enautilus.py +237 -0
  32. desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
  33. desdeo/api/routers/gdm/gdm_base.py +420 -0
  34. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
  35. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
  36. desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
  37. desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
  38. desdeo/api/routers/generic.py +233 -0
  39. desdeo/api/routers/nimbus.py +705 -0
  40. desdeo/api/routers/problem.py +201 -4
  41. desdeo/api/routers/reference_point_method.py +20 -44
  42. desdeo/api/routers/session.py +50 -26
  43. desdeo/api/routers/user_authentication.py +180 -26
  44. desdeo/api/routers/utils.py +187 -0
  45. desdeo/api/routers/utopia.py +230 -0
  46. desdeo/api/schema.py +10 -4
  47. desdeo/api/tests/conftest.py +94 -2
  48. desdeo/api/tests/test_enautilus.py +330 -0
  49. desdeo/api/tests/test_models.py +550 -72
  50. desdeo/api/tests/test_routes.py +902 -43
  51. desdeo/api/utils/_database.py +263 -0
  52. desdeo/api/utils/database.py +28 -266
  53. desdeo/api/utils/emo_database.py +40 -0
  54. desdeo/core.py +7 -0
  55. desdeo/emo/__init__.py +154 -24
  56. desdeo/emo/hooks/archivers.py +18 -2
  57. desdeo/emo/methods/EAs.py +128 -5
  58. desdeo/emo/methods/bases.py +9 -56
  59. desdeo/emo/methods/templates.py +111 -0
  60. desdeo/emo/operators/crossover.py +544 -42
  61. desdeo/emo/operators/evaluator.py +10 -14
  62. desdeo/emo/operators/generator.py +127 -24
  63. desdeo/emo/operators/mutation.py +212 -41
  64. desdeo/emo/operators/scalar_selection.py +202 -0
  65. desdeo/emo/operators/selection.py +956 -214
  66. desdeo/emo/operators/termination.py +124 -16
  67. desdeo/emo/options/__init__.py +108 -0
  68. desdeo/emo/options/algorithms.py +435 -0
  69. desdeo/emo/options/crossover.py +164 -0
  70. desdeo/emo/options/generator.py +131 -0
  71. desdeo/emo/options/mutation.py +260 -0
  72. desdeo/emo/options/repair.py +61 -0
  73. desdeo/emo/options/scalar_selection.py +66 -0
  74. desdeo/emo/options/selection.py +127 -0
  75. desdeo/emo/options/templates.py +383 -0
  76. desdeo/emo/options/termination.py +143 -0
  77. desdeo/gdm/__init__.py +22 -0
  78. desdeo/gdm/gdmtools.py +45 -0
  79. desdeo/gdm/score_bands.py +114 -0
  80. desdeo/gdm/voting_rules.py +50 -0
  81. desdeo/mcdm/__init__.py +23 -1
  82. desdeo/mcdm/enautilus.py +338 -0
  83. desdeo/mcdm/gnimbus.py +484 -0
  84. desdeo/mcdm/nautilus_navigator.py +7 -6
  85. desdeo/mcdm/reference_point_method.py +70 -0
  86. desdeo/problem/__init__.py +16 -11
  87. desdeo/problem/evaluator.py +4 -5
  88. desdeo/problem/external/__init__.py +18 -0
  89. desdeo/problem/external/core.py +356 -0
  90. desdeo/problem/external/pymoo_provider.py +266 -0
  91. desdeo/problem/external/runtime.py +44 -0
  92. desdeo/problem/gurobipy_evaluator.py +37 -12
  93. desdeo/problem/infix_parser.py +1 -16
  94. desdeo/problem/json_parser.py +7 -11
  95. desdeo/problem/pyomo_evaluator.py +25 -6
  96. desdeo/problem/schema.py +73 -55
  97. desdeo/problem/simulator_evaluator.py +65 -15
  98. desdeo/problem/testproblems/__init__.py +26 -11
  99. desdeo/problem/testproblems/benchmarks_server.py +120 -0
  100. desdeo/problem/testproblems/cake_problem.py +185 -0
  101. desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
  102. desdeo/problem/testproblems/forest_problem.py +77 -69
  103. desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
  104. desdeo/problem/testproblems/{river_pollution_problem.py → river_pollution_problems.py} +28 -22
  105. desdeo/problem/testproblems/single_objective.py +289 -0
  106. desdeo/problem/testproblems/zdt_problem.py +4 -1
  107. desdeo/problem/utils.py +1 -1
  108. desdeo/tools/__init__.py +39 -21
  109. desdeo/tools/desc_gen.py +22 -0
  110. desdeo/tools/generics.py +22 -2
  111. desdeo/tools/group_scalarization.py +3090 -0
  112. desdeo/tools/indicators_binary.py +107 -1
  113. desdeo/tools/indicators_unary.py +3 -16
  114. desdeo/tools/message.py +33 -2
  115. desdeo/tools/non_dominated_sorting.py +4 -3
  116. desdeo/tools/patterns.py +9 -7
  117. desdeo/tools/pyomo_solver_interfaces.py +49 -36
  118. desdeo/tools/reference_vectors.py +118 -351
  119. desdeo/tools/scalarization.py +340 -1413
  120. desdeo/tools/score_bands.py +491 -328
  121. desdeo/tools/utils.py +117 -49
  122. desdeo/tools/visualizations.py +67 -0
  123. desdeo/utopia_stuff/utopia_problem.py +1 -1
  124. desdeo/utopia_stuff/utopia_problem_old.py +1 -1
  125. {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info}/METADATA +47 -30
  126. desdeo-2.1.1.dist-info/RECORD +180 -0
  127. {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info}/WHEEL +1 -1
  128. desdeo-2.0.0.dist-info/RECORD +0 -120
  129. /desdeo/api/utils/{logger.py → _logger.py} +0 -0
  130. {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info/licenses}/LICENSE +0 -0
@@ -5,24 +5,22 @@ from datetime import UTC, datetime, timedelta
5
5
  from typing import Annotated
6
6
 
7
7
  import bcrypt
8
- from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, status
8
+ from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, Security, status
9
9
  from fastapi.responses import JSONResponse
10
- from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
11
- from jose import JWTError, jwt
10
+ from fastapi.security import (
11
+ APIKeyCookie,
12
+ HTTPAuthorizationCredentials,
13
+ HTTPBearer,
14
+ OAuth2PasswordBearer,
15
+ OAuth2PasswordRequestForm,
16
+ )
17
+ from jose import ExpiredSignatureError, JWTError, jwt
12
18
  from pydantic import BaseModel
13
19
  from sqlmodel import Session, select
14
20
 
15
- from desdeo.api import SettingsConfig
21
+ from desdeo.api import AuthConfig
16
22
  from desdeo.api.db import get_session
17
- from desdeo.api.models import User, UserPublic
18
-
19
- # AuthConfig
20
- if SettingsConfig.debug:
21
- from desdeo.api import AuthDebugConfig
22
-
23
- AuthConfig = AuthDebugConfig
24
- else:
25
- pass
23
+ from desdeo.api.models import User, UserPublic, UserRole
26
24
 
27
25
# Router holding the endpoints of this module (user info, login, logout, token refresh, user creation).
router = APIRouter()
28
26
 
@@ -37,7 +35,9 @@ class Tokens(BaseModel):
37
35
 
38
36
  # OAuth2PasswordBearer is a class that creates a dependency that will be used to get the token from the request.
39
37
  # The token will be used to authenticate the user.
40
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
38
# Reads a bearer access token from the `Authorization` header. With `auto_error=False` a missing
# header yields `None` instead of an immediate 401, so `get_current_user` can fall back to the cookie.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login", auto_error=False)
# Same, but for getting the access_token from the cookies of the request.
# NOTE(review): `login` only sets a `refresh_token` cookie; confirm who sets `access_token`.
cookie_scheme = APIKeyCookie(name="access_token", auto_error=False)
41
41
 
42
42
 
43
43
  def verify_password(plain_password: str, hashed_password: str) -> bool:
@@ -108,9 +108,11 @@ def authenticate_user(session: Session, username: str, password: str) -> User |
108
108
  return user
109
109
 
110
110
 
111
+ # token: Annotated[str, Depends(oauth2_scheme)],
111
112
  def get_current_user(
112
- token: Annotated[str, Depends(oauth2_scheme)],
113
113
  session: Annotated[Session, Depends(get_session)],
114
+ header_token: Annotated[str | None, Security(oauth2_scheme)] = None,
115
+ cookie_token: Annotated[str | None, Security(cookie_scheme)] = None,
114
116
  ) -> User:
115
117
  """Get the current user based on a JWT token.
116
118
 
@@ -126,11 +128,16 @@ def get_current_user(
126
128
  Raises:
127
129
  HTTPException: If the token is invalid.
128
130
  """
131
+ token = header_token or cookie_token
132
+
129
133
  credentials_exception = HTTPException(
130
134
  status_code=status.HTTP_401_UNAUTHORIZED,
131
135
  detail="Could not validate credentials",
132
136
  headers={"WWW-Authenticate": "Bearer"},
133
137
  )
138
+
139
+ if not token:
140
+ raise credentials_exception
134
141
  try:
135
142
  payload = jwt.decode(token, AuthConfig.authjwt_secret_key, algorithms=[AuthConfig.authjwt_algorithm])
136
143
  username = payload.get("sub")
@@ -139,7 +146,7 @@ def get_current_user(
139
146
  if username is None or expire_time is None or expire_time < datetime.now(UTC).timestamp():
140
147
  raise credentials_exception
141
148
 
142
- except jwt.exceptions.ExpiredSignatureError:
149
+ except ExpiredSignatureError:
143
150
  raise credentials_exception from None
144
151
 
145
152
  except JWTError:
@@ -281,6 +288,54 @@ def validate_refresh_token(
281
288
  return user
282
289
 
283
290
 
291
def add_user_to_database(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    role: UserRole,
    session: Annotated[Session, Depends(get_session)],
) -> None:
    """Create a user with the given role from submitted credentials and persist it.

    Args:
        form_data: form carrying the `username` and plain-text `password` of the new user.
        role: the role assigned to the new user.
        session: the database session used for the lookups and the insert.

    Raises:
        HTTPException: 409 if the username is already taken; 500 if the new user cannot be
            found in the database after committing.
    """
    if get_user(session=session, username=form_data.username):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username already taken.")

    # The password is stored only as the hash produced by `get_password_hash`.
    user_row = User(
        username=form_data.username,
        password_hash=get_password_hash(password=form_data.password),
        role=role,
    )
    session.add(user_row)
    session.commit()
    session.refresh(user_row)

    # Sanity check: re-read the user to confirm the insert actually went through.
    if not get_user(session=session, username=form_data.username):
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to add user into database.",
        )
+
338
+
284
339
  @router.get("/user_info")
285
340
  def get_current_user_info(user: Annotated[User, Depends(get_current_user)]) -> UserPublic:
286
341
  """Return information about the current user.
@@ -294,7 +349,7 @@ def get_current_user_info(user: Annotated[User, Depends(get_current_user)]) -> U
294
349
  return user
295
350
 
296
351
 
297
- @router.post("/login")
352
+ @router.post("/login", response_model=Tokens)
298
353
  def login(
299
354
  form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
300
355
  session: Annotated[Session, Depends(get_session)],
@@ -321,19 +376,49 @@ def login(
321
376
 
322
377
  tokens = generate_tokens({"id": user.id, "sub": user.username})
323
378
 
324
- response = JSONResponse(content={"access_token": tokens.access_token})
325
- response.set_cookie(
326
- key="refresh_token",
327
- value=tokens.refresh_token,
328
- httponly=True, # HTTP only cookie, more secure than storing the refresh token in the frontend code.
329
- secure=False, # allow http
330
- samesite="lax", # cross-origin requests
331
- max_age=cookie_max_age * 60, # convert to minutes
332
- )
379
+ response = JSONResponse(content={"access_token": tokens.access_token, "refresh_token": tokens.refresh_token})
380
+
381
+ if AuthConfig.cookie_domain == "":
382
+ response.set_cookie(
383
+ key="refresh_token",
384
+ value=tokens.refresh_token,
385
+ httponly=True, # HTTP only cookie, more secure than storing the refresh token in the frontend code.
386
+ secure=False, # allow http
387
+ samesite="lax", # cross-origin requests
388
+ max_age=cookie_max_age * 60, # convert to minutes
389
+ path="/",
390
+ )
391
+ else:
392
+ response.set_cookie(
393
+ key="refresh_token",
394
+ value=tokens.refresh_token,
395
+ httponly=True, # keep this
396
+ secure=True, # MUST be true for HTTPS
397
+ samesite="none", # required for cross-site subdomains
398
+ max_age=cookie_max_age * 60,
399
+ path="/",
400
+ domain=AuthConfig.cookie_domain, # <- allow sharing between API + webui
401
+ )
333
402
 
334
403
  return response
335
404
 
336
405
 
406
@router.post("/logout")
def logout() -> JSONResponse:
    """Log the current user out by deleting the refresh-token cookie set by `login`.

    The cookie is deleted with the same `path`, `domain`, `secure` and `samesite` attributes
    that `login` uses when setting it. A browser identifies a cookie by its name together with
    its domain and path. A deletion that omits the `domain` attribute therefore leaves a cookie
    that was set with `domain=AuthConfig.cookie_domain` in place.

    Returns:
        JSONResponse: a 200 response with a "logged out" message whose headers expire the
        `refresh_token` cookie.
    """
    response = JSONResponse(content={"message": "logged out"}, status_code=status.HTTP_200_OK)

    if AuthConfig.cookie_domain == "":
        # Mirrors the attributes `login` uses when no cookie domain is configured.
        response.delete_cookie(key="refresh_token", path="/", secure=False, httponly=True, samesite="lax")
    else:
        # Mirrors the cross-subdomain attributes `login` uses; without `domain` the browser
        # would keep the cookie scoped to `AuthConfig.cookie_domain`.
        response.delete_cookie(
            key="refresh_token",
            path="/",
            domain=AuthConfig.cookie_domain,
            secure=True,
            httponly=True,
            samesite="none",
        )

    return response
420
+
421
+
337
422
  @router.post("/refresh")
338
423
  def refresh_access_token(
339
424
  request: Response,
@@ -364,3 +449,72 @@ def refresh_access_token(
364
449
  access_token = create_access_token({"id": user.id, "sub": user.username})
365
450
 
366
451
  return {"access_token": access_token}
452
+
453
+
454
@router.post("/add_new_dm")
def add_new_dm(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    session: Annotated[Session, Depends(get_session)],
) -> JSONResponse:
    """Register a new user with the decision maker role. No login is required.

    NOTE(review): this endpoint has no authentication dependency, so anyone who can reach it
    can create a decision-maker account; confirm this open registration is intended.

    Args:
        form_data (Annotated[OAuth2PasswordRequestForm, Depends()]): credentials of the new user.
        session (Annotated[Session, Depends(get_session)]): the database session.

    Returns:
        JSONResponse: a 201 response with a confirmation message.

    Raises:
        HTTPException: if the username is already in use or saving to the database fails.
    """
    add_user_to_database(form_data=form_data, role=UserRole.dm, session=session)

    created_message = 'User with role "decision maker" created.'
    return JSONResponse(content={"message": created_message}, status_code=status.HTTP_201_CREATED)
481
+
482
+
483
@router.post("/add_new_analyst")
def add_new_analyst(
    user: Annotated[User, Depends(get_current_user)],
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    session: Annotated[Session, Depends(get_session)],
) -> JSONResponse:
    """Add a new user of the role Analyst to the database. Requires a logged in analyst or an admin.

    Args:
        user (Annotated[User, Depends(get_current_user)]): the logged in user; must have the role
            "analyst" or "admin".
        form_data (Annotated[OAuth2PasswordRequestForm, Depends()]): credentials of the new user.
        session (Annotated[Session, Depends(get_session)]): the database session.

    Returns:
        JSONResponse: a 201 response with a confirmation message.

    Raises:
        HTTPException: 403 if the logged in user is neither an analyst nor an admin; otherwise
            the errors raised by `add_user_to_database` (username already in use, or saving to
            the database failed).
    """
    # The caller is already authenticated (get_current_user succeeded) but may lack the rights
    # to create analysts. That is "forbidden" (403), not "unauthenticated" (401): a 401 tells
    # clients to re-authenticate, which cannot help here.
    if user.role not in (UserRole.analyst, UserRole.admin):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Logged in user has insufficient rights.",
        )

    add_user_to_database(
        form_data=form_data,
        role=UserRole.analyst,
        session=session,
    )

    return JSONResponse(
        content={"message": 'User with role "analyst" created.'},
        status_code=status.HTTP_201_CREATED,
    )
@@ -0,0 +1,187 @@
1
+ """A selection of utilities for handling routers and data therein.
2
+
3
+ NOTE: No routers should be defined in this file!
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Annotated
8
+
9
+ from fastapi import Depends, HTTPException, status
10
+ from sqlmodel import Session, select
11
+
12
+ from desdeo.api.db import get_session
13
+ from desdeo.api.models import (
14
+ ENautilusStepRequest,
15
+ InteractiveSessionDB,
16
+ ProblemDB,
17
+ RPMSolveRequest,
18
+ StateDB,
19
+ User,
20
+ )
21
+ from desdeo.api.routers.user_authentication import get_current_user
22
+
23
# Request models accepted by the helpers in this module. The helpers read `problem_id`,
# `session_id` and `parent_state_id` from the request, so every member must provide them.
RequestType = RPMSolveRequest | ENautilusStepRequest
24
+
25
+
26
def fetch_interactive_session(user: User, request: RequestType, session: Session) -> InteractiveSessionDB | None:
    """Return the interactive session a request should operate in.

    The interactive session with id `request.session_id` is used when that id is given.
    Otherwise the user's currently active interactive session (id `user.active_session_id`)
    is used, which may legitimately be absent.

    Args:
        user (User): the user whose active interactive session is the fallback.
        request (RequestType): the request, possibly naming an interactive session explicitly.
        session (Session): the database session (not to be confused with the interactive
            session) used for the query.

    Raises:
        HTTPException: 404 when an explicitly requested interactive session does not exist.

    Returns:
        InteractiveSessionDB | None: the interactive session, or `None` when none was requested
        and no active one is found.
    """
    explicit_id = request.session_id
    # Fall back to the user's active session only when no id was requested explicitly.
    # NOTE(review): an explicit id is looked up by id alone; ownership by `user` is not checked here.
    target_id = user.active_session_id if explicit_id is None else explicit_id

    found = session.exec(select(InteractiveSessionDB).where(InteractiveSessionDB.id == target_id)).first()

    if found is None and explicit_id is not None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Could not find interactive session with id={request.session_id}.",
        )

    return found
68
+
69
+
70
def fetch_user_problem(user: User, request: RequestType, session: Session) -> ProblemDB:
    """Return the `ProblemDB` with id `request.problem_id` that belongs to `user`.

    Args:
        user (User): the owner of the problem; problems of other users are not matched.
        request (RequestType): request carrying the id of the wanted problem (`request.problem_id`).
        session (Session): the database session from which to fetch the problem.

    Raises:
        HTTPException: 404 when the user has no problem with the given id.

    Returns:
        ProblemDB: the matching problem.
    """
    wanted_id = request.problem_id
    owned_problem = session.exec(
        select(ProblemDB).where(ProblemDB.user_id == user.id, ProblemDB.id == wanted_id)
    ).first()

    if owned_problem is not None:
        return owned_problem

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND, detail=f"Problem with id={request.problem_id} could not be found."
    )
93
+
94
+
95
def fetch_parent_state(
    user: User, request: RequestType, session: Session, interactive_session: InteractiveSessionDB | None = None
) -> StateDB | None:
    """Resolve the parent `StateDB` for a new state or operation.

    Resolution order:
    1. If `request.parent_state_id` is set, that state is loaded and must exist.
    2. Otherwise, the last entry of `interactive_session.states` is used, if the
       interactive session is given and has any states.
    3. Otherwise, `None` is returned.

    Args:
        user (User): the user for which the parent state is fetched (not used in the lookup).
        request (RequestType): request optionally carrying `parent_state_id`.
        session (Session): the database session from which to fetch the parent state.
        interactive_session (InteractiveSessionDB | None, optional): the interactive session whose
            last state is the fallback parent. Defaults to None.

    Raises:
        HTTPException: 404 when `request.parent_state_id` is set but no `StateDB` with that id exists.

    Returns:
        StateDB | None: the explicitly requested state, else the last state of `interactive_session`,
        else `None`.
    """
    requested_id = request.parent_state_id

    if requested_id is not None:
        # An explicit parent id means the caller wants exactly that state, so a miss is an error.
        # NOTE(review): the lookup is by id only; ownership by `user` is not verified here.
        explicit_parent = session.exec(select(StateDB).where(StateDB.id == requested_id)).first()
        if explicit_parent is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find state with id={request.parent_state_id}"
            )
        return explicit_parent

    # No explicit parent: fall back to the last state of the interactive session, if any.
    if interactive_session is None or len(interactive_session.states) == 0:
        return None
    return interactive_session.states[-1]
148
+
149
+
150
@dataclass(frozen=True)
class SessionContext:
    """A generic context to be used in various endpoints.

    An immutable bundle of the objects an endpoint typically needs. Because the dataclass is
    frozen, fields cannot be reassigned after construction. Built by `get_session_context`.
    """

    # The current user.
    user: User
    # The database session (not an interactive session).
    db_session: Session
    # The problem referenced by the request (as fetched by `get_session_context`, owned by `user`).
    problem_db: ProblemDB
    # The interactive session in use, or None if none was requested and none is active.
    interactive_session: InteractiveSessionDB | None
    # The state to use as the parent of a new state, or None if there is none.
    parent_state: StateDB | None
159
+
160
+
161
def get_session_context(
    request: RequestType,
    user: Annotated[User, Depends(get_current_user)],
    db_session: Annotated[Session, Depends(get_session)],
) -> SessionContext:
    """Build the `SessionContext` for a request. Intended to be used as a FastAPI dependency.

    The problem is resolved first, then the interactive session, then the parent state, which
    may depend on the interactive session. Any of these lookups may raise a 404 `HTTPException`.

    Args:
        request (RequestType): the request from which the context is resolved.
        user (Annotated[User, Depends(get_current_user)]): the current user (dependency).
        db_session (Annotated[Session, Depends(get_session)]): the current database session (dependency).

    Returns:
        SessionContext: the resolved `User`, `Session`, `ProblemDB`, `InteractiveSessionDB`
        and `StateDB`.
    """
    problem = fetch_user_problem(user, request, db_session)
    active = fetch_interactive_session(user, request, db_session)

    return SessionContext(
        user=user,
        db_session=db_session,
        problem_db=problem,
        interactive_session=active,
        parent_state=fetch_parent_state(user, request, db_session, interactive_session=active),
    )
@@ -0,0 +1,230 @@
1
+ """Utopia router."""
2
+
3
+ import json
4
+ from typing import Annotated
5
+
6
+ from fastapi import APIRouter, Depends
7
+ from sqlmodel import Session, select
8
+
9
+ from desdeo.api.db import get_session
10
+ from desdeo.api.models import (
11
+ ForestProblemMetaData,
12
+ NIMBUSFinalState,
13
+ NIMBUSInitializationState,
14
+ NIMBUSSaveState,
15
+ ProblemMetaDataDB,
16
+ StateDB,
17
+ User,
18
+ UtopiaRequest,
19
+ UtopiaResponse,
20
+ )
21
+ from desdeo.api.routers.user_authentication import get_current_user
22
+
23
# All endpoints of this router are mounted under the `/utopia` prefix.
router = APIRouter(prefix="/utopia")
24
+
25
+
26
@router.post("/")
def get_utopia_data(
    request: UtopiaRequest,
    user: Annotated[User, Depends(get_current_user)],
    session: Annotated[Session, Depends(get_session)],
) -> UtopiaResponse:
    """Request and receive the Utopia map corresponding to the decision variables sent.

    Looks up the state and solution referenced by `request.solution` and extracts its decision
    variables. Each forest stand's chosen schedule is then mapped to a treatment per year, and
    map options are built for the web UI, one entry per year in the problem's forest metadata.

    Args:
        request (UtopiaRequest): the solution (state id and solution index) and the problem for
            which the utopia forest map is requested.
        user (Annotated[User, Depends(get_current_user)]): the current user.
        session (Annotated[Session, Depends(get_session)]): the current database session.

    Returns:
        UtopiaResponse: the map for the forest, to be rendered in the frontend. An "empty"
        response (`is_utopia=False`) is returned when the state, the solution, its decision
        variables, the problem metadata, or any forest metadata entry cannot be found.
    """
    # NOTE(review): neither the state nor the problem metadata is filtered by `user`; confirm
    # that any logged-in user may request maps for any problem.
    empty_response = UtopiaResponse(is_utopia=False, map_name="", map_json={}, options={}, description="", years=[])

    state = session.exec(select(StateDB).where(StateDB.id == request.solution.state_id)).first()
    if state is None or not hasattr(state, "state"):
        return empty_response

    actual_state = state.state

    if type(actual_state) is NIMBUSSaveState:
        decision_variables = actual_state.result_variable_values[0]

    elif type(actual_state) in [NIMBUSInitializationState, NIMBUSFinalState]:
        decision_variables = actual_state.solver_results.optimal_variables

    else:
        # Other states hold a list of solver results; use the requested one if it exists.
        if (
            not hasattr(actual_state, "solver_results")
            or request.solution.solution_index >= len(actual_state.solver_results)
            or actual_state.solver_results[request.solution.solution_index] is None
        ):
            return empty_response

        result = actual_state.solver_results[request.solution.solution_index]
        if not hasattr(result, "optimal_variables") or not result.optimal_variables:
            return empty_response
        # A mapping from variable name to value (indexed by key below).
        decision_variables = result.optimal_variables

    from_db_metadata = session.exec(
        select(ProblemMetaDataDB).where(ProblemMetaDataDB.problem_id == request.problem_id)
    ).first()
    if from_db_metadata is None:
        return empty_response

    # Use the last forest-related metadata entry; if there is more than one, the latest wins.
    # Without any such entry there is nothing to draw (indexing an empty list would raise IndexError).
    forest_candidates = [
        metadata for metadata in from_db_metadata.all_metadata if metadata.metadata_type == "forest_problem_metadata"
    ]
    if not forest_candidates:
        return empty_response
    forest_metadata: ForestProblemMetaData = forest_candidates[-1]
    years = forest_metadata.years

    def treatment_index(part: str) -> int:
        """Map a schedule token to a treatment code (a key of `description_dict`), or -1 if unknown."""
        if "clearcut" in part:
            return 1
        if "below" in part:
            return 2
        if "above" in part:
            return 3
        if "even" in part:
            return 4
        if "first" in part:
            return 5
        # NOTE(review): -1 has no entry in `description_dict`/`treatment_colors`, so an
        # unrecognised token raises KeyError when the options are built below.
        return -1

    # For each stand variable (name starting with "X"): the treatment code applied in each year.
    treatments_dict = {}
    for key in decision_variables:
        if not key.startswith("X"):
            continue
        try:
            chosen_schedule = decision_variables[key].index(1)
        except ValueError:
            # if the optimization didn't choose any decision alternative, it's safe to assume
            # that nothing is being done at that forest stand
            chosen_schedule = 0
        # The dict keys get converted from ints to strings when it's loaded from the database
        treatments = forest_metadata.schedule_dict[key][str(chosen_schedule)]

        # Every year of the plan starts as "do nothing" (0); covers all years, not just three.
        treatments_dict[key] = dict.fromkeys(years, 0)
        for year in treatments_dict[key]:
            if year in treatments:
                for part in treatments.split():
                    if year in part:
                        treatments_dict[key][year] = treatment_index(part)

    # Create the options for the webui
    treatment_colors = {
        0: "#4daf4a",
        1: "#e41a1c",
        2: "#984ea3",
        3: "#e3d802",
        4: "#ff7f00",
        5: "#377eb8",
    }

    description_dict = {
        0: "Do nothing",
        1: "Clearcut",
        2: "Thinning from below",
        3: "Thinning from above",
        4: "Even thinning",
        5: "First thinning",
    }

    map_name = "ForestMap"  # This isn't visible anywhere on the ui

    options = {}
    for year in years:
        options[year] = {
            "tooltip": {
                "trigger": "item",
                "showDelay": 0,
                "transitionDuration": 0.2,
            },
            "visualMap": {
                "left": "right",
                "showLabel": True,
                "type": "piecewise",  # one piece per management plan
                "pieces": [],
                "text": ["Management plans"],
                "calculable": True,
            },
            "toolbox": {
                "show": True,
                "left": "left",
                "top": "top",
                "feature": {
                    "dataView": {"readOnly": True},
                    "restore": {},
                    "saveAsImage": {},
                },
            },
            "series": [
                {
                    "name": year,
                    "type": "map",
                    "roam": True,
                    "map": map_name,
                    "nameProperty": forest_metadata.stand_id_field,
                    "label": {
                        "show": False  # Hide text labels on the map
                    },
                    "data": [],
                    "nameMap": {},
                }
            ],
        }

        for key in decision_variables:
            if not key.startswith("X"):
                continue
            stand = int(forest_metadata.schedule_dict[key]["unit"])
            treatment_id = treatments_dict[key][year]
            piece = {
                "value": treatment_id,
                "symbol": "circle",
                "label": description_dict[treatment_id],
                "color": treatment_colors[treatment_id],
            }
            # Each distinct treatment appears once in the legend.
            if piece not in options[year]["visualMap"]["pieces"]:
                options[year]["visualMap"]["pieces"].append(piece)
            if forest_metadata.stand_descriptor:
                # NOTE(review): unlike the fallback below, no separator is inserted between the
                # descriptor and the treatment description; confirm descriptors end with a space.
                name = forest_metadata.stand_descriptor[str(stand)] + description_dict[treatment_id]
            else:
                name = "Stand " + str(stand) + " " + description_dict[treatment_id]
            options[year]["series"][0]["data"].append(
                {
                    "name": name,
                    "value": treatment_id,
                }
            )
            options[year]["series"][0]["nameMap"][stand] = name

    # Let's also generate a nice description for the map
    map_description = (
        f"Income from harvesting in the first period {int(decision_variables['P_1'])}€.\n"
        + f"Income from harvesting in the second period {int(decision_variables['P_2'])}€.\n"
        + f"Income from harvesting in the third period {int(decision_variables['P_3'])}€.\n"
        + f"The discounted value of the remaining forest at the end of the plan {int(decision_variables['V_end'])}€."
    )

    return UtopiaResponse(
        is_utopia=True,
        map_name=map_name,
        options=options,
        map_json=json.loads(forest_metadata.map_json),
        description=map_description,
        years=years,
    )
+ )