desdeo 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126) hide show
  1. desdeo/adm/ADMAfsar.py +551 -0
  2. desdeo/adm/ADMChen.py +414 -0
  3. desdeo/adm/BaseADM.py +119 -0
  4. desdeo/adm/__init__.py +11 -0
  5. desdeo/api/__init__.py +6 -6
  6. desdeo/api/app.py +38 -28
  7. desdeo/api/config.py +65 -44
  8. desdeo/api/config.toml +23 -12
  9. desdeo/api/db.py +10 -8
  10. desdeo/api/db_init.py +12 -6
  11. desdeo/api/models/__init__.py +220 -20
  12. desdeo/api/models/archive.py +16 -27
  13. desdeo/api/models/emo.py +128 -0
  14. desdeo/api/models/enautilus.py +69 -0
  15. desdeo/api/models/gdm/gdm_aggregate.py +139 -0
  16. desdeo/api/models/gdm/gdm_base.py +69 -0
  17. desdeo/api/models/gdm/gdm_score_bands.py +114 -0
  18. desdeo/api/models/gdm/gnimbus.py +138 -0
  19. desdeo/api/models/generic.py +104 -0
  20. desdeo/api/models/generic_states.py +401 -0
  21. desdeo/api/models/nimbus.py +158 -0
  22. desdeo/api/models/preference.py +44 -6
  23. desdeo/api/models/problem.py +274 -64
  24. desdeo/api/models/session.py +4 -1
  25. desdeo/api/models/state.py +419 -52
  26. desdeo/api/models/user.py +7 -6
  27. desdeo/api/models/utopia.py +25 -0
  28. desdeo/api/routers/_EMO.backup +309 -0
  29. desdeo/api/routers/_NIMBUS.py +6 -3
  30. desdeo/api/routers/emo.py +497 -0
  31. desdeo/api/routers/enautilus.py +237 -0
  32. desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
  33. desdeo/api/routers/gdm/gdm_base.py +420 -0
  34. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
  35. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
  36. desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
  37. desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
  38. desdeo/api/routers/generic.py +233 -0
  39. desdeo/api/routers/nimbus.py +705 -0
  40. desdeo/api/routers/problem.py +201 -4
  41. desdeo/api/routers/reference_point_method.py +20 -44
  42. desdeo/api/routers/session.py +50 -26
  43. desdeo/api/routers/user_authentication.py +180 -26
  44. desdeo/api/routers/utils.py +187 -0
  45. desdeo/api/routers/utopia.py +230 -0
  46. desdeo/api/schema.py +10 -4
  47. desdeo/api/tests/conftest.py +94 -2
  48. desdeo/api/tests/test_enautilus.py +330 -0
  49. desdeo/api/tests/test_models.py +550 -72
  50. desdeo/api/tests/test_routes.py +902 -43
  51. desdeo/api/utils/_database.py +263 -0
  52. desdeo/api/utils/database.py +28 -266
  53. desdeo/api/utils/emo_database.py +40 -0
  54. desdeo/core.py +7 -0
  55. desdeo/emo/__init__.py +154 -24
  56. desdeo/emo/hooks/archivers.py +18 -2
  57. desdeo/emo/methods/EAs.py +128 -5
  58. desdeo/emo/methods/bases.py +9 -56
  59. desdeo/emo/methods/templates.py +111 -0
  60. desdeo/emo/operators/crossover.py +544 -42
  61. desdeo/emo/operators/evaluator.py +10 -14
  62. desdeo/emo/operators/generator.py +127 -24
  63. desdeo/emo/operators/mutation.py +212 -41
  64. desdeo/emo/operators/scalar_selection.py +202 -0
  65. desdeo/emo/operators/selection.py +956 -214
  66. desdeo/emo/operators/termination.py +124 -16
  67. desdeo/emo/options/__init__.py +108 -0
  68. desdeo/emo/options/algorithms.py +435 -0
  69. desdeo/emo/options/crossover.py +164 -0
  70. desdeo/emo/options/generator.py +131 -0
  71. desdeo/emo/options/mutation.py +260 -0
  72. desdeo/emo/options/repair.py +61 -0
  73. desdeo/emo/options/scalar_selection.py +66 -0
  74. desdeo/emo/options/selection.py +127 -0
  75. desdeo/emo/options/templates.py +383 -0
  76. desdeo/emo/options/termination.py +143 -0
  77. desdeo/gdm/__init__.py +22 -0
  78. desdeo/gdm/gdmtools.py +45 -0
  79. desdeo/gdm/score_bands.py +114 -0
  80. desdeo/gdm/voting_rules.py +50 -0
  81. desdeo/mcdm/__init__.py +23 -1
  82. desdeo/mcdm/enautilus.py +338 -0
  83. desdeo/mcdm/gnimbus.py +484 -0
  84. desdeo/mcdm/nautilus_navigator.py +7 -6
  85. desdeo/mcdm/reference_point_method.py +70 -0
  86. desdeo/problem/__init__.py +5 -1
  87. desdeo/problem/external/__init__.py +18 -0
  88. desdeo/problem/external/core.py +356 -0
  89. desdeo/problem/external/pymoo_provider.py +266 -0
  90. desdeo/problem/external/runtime.py +44 -0
  91. desdeo/problem/infix_parser.py +2 -2
  92. desdeo/problem/pyomo_evaluator.py +25 -6
  93. desdeo/problem/schema.py +69 -48
  94. desdeo/problem/simulator_evaluator.py +65 -15
  95. desdeo/problem/testproblems/__init__.py +26 -11
  96. desdeo/problem/testproblems/benchmarks_server.py +120 -0
  97. desdeo/problem/testproblems/cake_problem.py +185 -0
  98. desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
  99. desdeo/problem/testproblems/forest_problem.py +77 -69
  100. desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
  101. desdeo/problem/testproblems/{river_pollution_problem.py → river_pollution_problems.py} +28 -22
  102. desdeo/problem/testproblems/single_objective.py +289 -0
  103. desdeo/problem/testproblems/zdt_problem.py +4 -1
  104. desdeo/tools/__init__.py +39 -21
  105. desdeo/tools/desc_gen.py +22 -0
  106. desdeo/tools/generics.py +22 -2
  107. desdeo/tools/group_scalarization.py +3090 -0
  108. desdeo/tools/indicators_binary.py +107 -1
  109. desdeo/tools/indicators_unary.py +3 -16
  110. desdeo/tools/message.py +33 -2
  111. desdeo/tools/non_dominated_sorting.py +4 -3
  112. desdeo/tools/patterns.py +9 -7
  113. desdeo/tools/pyomo_solver_interfaces.py +48 -35
  114. desdeo/tools/reference_vectors.py +118 -351
  115. desdeo/tools/scalarization.py +340 -1413
  116. desdeo/tools/score_bands.py +491 -328
  117. desdeo/tools/utils.py +117 -49
  118. desdeo/tools/visualizations.py +67 -0
  119. desdeo/utopia_stuff/utopia_problem.py +1 -1
  120. desdeo/utopia_stuff/utopia_problem_old.py +1 -1
  121. {desdeo-2.0.0.dist-info → desdeo-2.1.0.dist-info}/METADATA +46 -28
  122. desdeo-2.1.0.dist-info/RECORD +180 -0
  123. {desdeo-2.0.0.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
  124. desdeo-2.0.0.dist-info/RECORD +0 -120
  125. /desdeo/api/utils/{logger.py → _logger.py} +0 -0
  126. {desdeo-2.0.0.dist-info → desdeo-2.1.0.dist-info/licenses}/LICENSE +0 -0
@@ -0,0 +1,104 @@
1
+ """Generic models for the DESDEO API."""
2
+
3
+ from pydantic import ConfigDict
4
+ from sqlmodel import JSON, Column, Field, SQLModel
5
+
6
+ from desdeo.tools.score_bands import SCOREBandsConfig, SCOREBandsResult
7
+
8
+ from .generic_states import SolutionReferenceResponse
9
+
10
+
11
class SolutionInfo(SQLModel):
    """Used when we wish to reference a solution in some `StateDB` stored in the database.

    A solution is addressed by the state that holds it plus the solution's index within
    that state's results, and may optionally be given a name.
    """

    # ID of the state holding the solution.
    state_id: int = Field(description="State of the desired solution.")
    # Position of the solution among the results stored in that state.
    solution_index: int = Field(description="Index of the desired solution.")
    name: str | None = Field(description="Name to be given to the solution. Optional.", default=None)
17
+
18
+
19
class IntermediateSolutionRequest(SQLModel):
    """Model of the request to solve intermediate solutions between two solutions.

    The two end points are given as `SolutionInfo` references to solutions already stored
    in states. `num_desired` sets how many intermediate solutions are requested.
    """

    problem_id: int  # database ID of the problem
    session_id: int | None = Field(default=None)
    parent_state_id: int | None = Field(default=None)
    context: str | None = None  # Method context (nimbus, rpm, etc.)
    # Optional scalarization / solver settings; values restricted to JSON-friendly scalars.
    scalarization_options: dict[str, float | str | bool] | None = Field(sa_column=Column(JSON), default=None)
    solver: str | None = Field(default=None)
    solver_options: dict[str, float | str | bool] | None = Field(sa_column=Column(JSON), default=None)

    num_desired: int | None = Field(default=1)

    # The two solutions between which intermediate solutions are computed.
    reference_solution_1: SolutionInfo
    reference_solution_2: SolutionInfo
34
+
35
+
36
class GenericIntermediateSolutionResponse(SQLModel):
    """The response from computing intermediate values.

    Contains the two reference solutions that were used and the computed intermediate
    solutions, all as `SolutionReferenceResponse` objects.
    """

    # No default, so the field must be supplied, but it may be None.
    state_id: int | None = Field(description="The newly created state id")
    reference_solution_1: SolutionReferenceResponse = Field(
        sa_column=Column(JSON),
        description="The first solution used when computing intermediate solutions.",
    )
    reference_solution_2: SolutionReferenceResponse = Field(
        sa_column=Column(JSON),
        description="The second solution used when computing intermediate solutions.",
    )
    intermediate_solutions: list[SolutionReferenceResponse] = Field(
        sa_column=Column(JSON), description="The intermediate solutions computed."
    )
51
+
52
+
53
class ScoreBandsRequest(SQLModel):
    """Model of the request to calculate SCORE bands parameters.

    `data` is a matrix of objective values whose columns are named, in order, by `objs`
    (presumably one row per solution — confirm with the caller).
    """

    data: list[list[float]] = Field(description="Matrix of objective values")
    objs: list[str] = Field(description="Array of objective names for each column")

    # Optional parameters with defaults matching the score_bands.py functions
    dist_parameter: float = Field(default=0.05, description="Distance parameter for axis positioning")
    use_absolute_corr: bool = Field(default=False, description="Use absolute correlation values")
    distance_formula: int = Field(default=1, description="Distance formula (1 or 2)")
    flip_axes: bool = Field(default=True, description="Whether to flip axes based on correlation signs")
    clustering_algorithm: str = Field(default="DBSCAN", description="Clustering algorithm (DBSCAN or GMM)")
    # NOTE(review): the default "silhoutte" is misspelled ("silhouette"); it is a runtime
    # value, so confirm what desdeo.tools.score_bands expects before correcting it.
    clustering_score: str = Field(default="silhoutte", description="Clustering score metric")
66
+
67
+
68
class ScoreBandsResponse(SQLModel):
    """Model of the response containing SCORE bands parameters.

    All fields are required. `axis_signs` may be None, but must still be supplied.
    """

    groups: list[int] = Field(description="Cluster group assignments for each data point")
    axis_dist: list[float] = Field(description="Normalized axis positions")
    axis_signs: list[int] | None = Field(description="Axis direction signs (1 or -1)")
    obj_order: list[int] = Field(description="Optimal order of objectives")
75
+
76
+
77
class GroupScoreRequest(SQLModel):
    """A generic model for requesting SCORE Bands for a state.

    Field descriptions come from the attribute docstrings below
    (``use_attribute_docstrings=True``). Each docstring must therefore directly follow
    the field it describes, otherwise it is attached to the wrong field.
    """

    model_config = ConfigDict(use_attribute_docstrings=True)

    problem_id: int
    """Database ID of the problem to solve."""
    group_id: int
    """Database ID of the group."""
    session_id: int | None = Field(default=None)
    """Database ID of the session, if any."""
    parent_state_id: int | None = Field(default=None)
    """State ID of the parent state, if any."""

    config: SCOREBandsConfig | None = Field(default=None)
    """Configuration for the SCORE bands visualization."""

    solution_ids: list[int] = Field()
    """List of solution IDs to score."""
94
+
95
+
96
class GroupScoreResponse(SQLModel):
    """Model of the response to a group SCORE bands request (`GroupScoreRequest`)."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    group_iteration_id: int | None = Field(default=None)
    """The state ID of the newly created group iteration."""

    # The computed SCORE bands result.
    result: SCOREBandsResult
@@ -0,0 +1,401 @@
1
+ """Defines models for representing the state of various interactive methods."""
2
+
3
+ from enum import Enum
4
+ from typing import TYPE_CHECKING
5
+
6
+ from pydantic import ConfigDict, computed_field
7
+ from sqlalchemy.orm import object_session
8
+ from sqlmodel import (
9
+ JSON,
10
+ Column,
11
+ Field,
12
+ Relationship,
13
+ Session,
14
+ SQLModel,
15
+ select,
16
+ )
17
+
18
+ from desdeo.problem import Tensor, VariableType
19
+
20
+ from .state import (
21
+ EMOFetchState,
22
+ EMOIterateState,
23
+ EMOSaveState,
24
+ EMOSCOREState,
25
+ ENautilusState,
26
+ GNIMBUSEndState,
27
+ GNIMBUSOptimizationState,
28
+ GNIMBUSVotingState,
29
+ IntermediateSolutionState,
30
+ NIMBUSClassificationState,
31
+ NIMBUSFinalState,
32
+ NIMBUSInitializationState,
33
+ NIMBUSSaveState,
34
+ RPMState,
35
+ )
36
+ from .user import User
37
+
38
+ if TYPE_CHECKING:
39
+ from .problem import ProblemDB
40
+ from .session import InteractiveSessionDB
41
+
42
+
43
class StateKind(str, Enum):
    """Stores the normalized kinds `{method}.{phase}` of supported states.

    When a state is created, each value is split on its first ``.`` to fill the
    ``method`` and ``phase`` columns of `State`.

    Note:
        Update this when adding new states. Be sure to update `KIND_TO_TABLE`
        and `SUBSTATE_TO_KIND` in this file as well.
    """

    RPM_SOLVE = "reference_point_method.solve_candidates"
    NIMBUS_SOLVE = "nimbus.solve_candidates"
    NIMBUS_SAVE = "nimbus.save_solutions"
    NIMBUS_INIT = "nimbus.initialize"
    NIMBUS_FINAL = "nimbus.final"
    GNIMBUS_OPTIMIZE = "gnimbus.optimize"
    GNIMBUS_VOTE = "gnimbus.vote"
    GNIMBUS_END = "gnimbus.end"
    EMO_RUN = "emo.run"
    EMO_SAVE = "emo.save_solutions"
    EMO_FETCH = "emo.fetch_solutions"
    EMO_SCORE = "emo.score_bands"
    GENERIC_INTERMEDIATE = "generic.solve_intermediate"
    ENAUTILUS_STEP = "e-nautilus.stepping"
65
+
66
+
67
class State(SQLModel, table=True):
    """The 'polymorphic' state to store method information.

    A row records only which method and phase a state belongs to. The method-specific
    data lives in a separate substate table whose row reuses this row's primary key
    (see `_attach_substate`). `kind` selects which substate table to query.
    """

    __tablename__ = "states"

    id: int | None = Field(default=None, primary_key=True)

    # The state is polymorphic on these.
    # TODO (@gialmisi): once SQLModel supports polymorphic table types, refactor this.
    method: str = Field(index=True)  # the method name
    phase: str = Field(index=True)  # the phase of the method
    kind: StateKind = Field(index=True)  # normalized "{method}.{phase}", lowercase
79
+
80
+
81
class StateDB(SQLModel, table=True):
    """State holder with a single relationship to the base State.

    Each row places a method state in the state tree (``parent``/``children``) and can
    link it to a problem and a session. The method-specific data lives in a substate
    table whose primary key equals ``base_state.id``; use :attr:`state` to load it.
    """

    __tablename__ = "statedb"

    id: int | None = Field(primary_key=True, default=None)

    # Optional cross-links (keep as strings in other modules to avoid circulars)
    problem_id: int | None = Field(foreign_key="problemdb.id", default=None)
    session_id: int | None = Field(foreign_key="interactivesessiondb.id", default=None)

    # lineage
    parent_id: int | None = Field(foreign_key="statedb.id", default=None)

    # one-to-one to base state
    state_id: int | None = Field(foreign_key="states.id", default=None, index=True)
    base_state: State | None = Relationship(
        sa_relationship_kwargs={
            "uselist": False,
            "single_parent": True,
            "cascade": "all, delete-orphan",
            "foreign_keys": lambda: [StateDB.state_id],
        }
    )

    parent: "StateDB" = Relationship(
        back_populates="children",
        sa_relationship_kwargs={"remote_side": lambda: StateDB.id},
    )
    children: list["StateDB"] = Relationship(
        back_populates="parent",
        sa_relationship_kwargs={"cascade": "all, delete-orphan"},
    )

    session: "InteractiveSessionDB" = Relationship(back_populates="states")
    problem: "ProblemDB" = Relationship()

    @classmethod
    def create(
        cls,
        database_session: Session,
        *,
        problem_id: int | None = None,
        session_id: int | None = None,
        parent_id: int | None = None,
        state: SQLModel | None = None,
    ) -> "StateDB":
        """Build a StateDB + base State with a concrete substate.

        The substate's kind is resolved from its type, or the nearest base class, via
        `SUBSTATE_TO_KIND`. The new row is added to `database_session` and the session is
        flushed, so the row, its base `State`, and the substate get their ids. Committing
        is left to the caller.

        Args:
            database_session: Session the new objects are added to.
            problem_id: Optional ID of the related problem.
            session_id: Optional ID of the related interactive session.
            parent_id: Optional ID of the parent `StateDB` in the state tree.
            state: The concrete substate instance (e.g., `NIMBUSSaveState`). Required:
                without it no `StateKind` can be determined.

        Returns:
            The newly created `StateDB` row.

        Raises:
            ValueError: If `state` is None or its type has no `StateKind` mapping.
        """
        if state is None:
            # A base `State` needs a `kind`, which can only be derived from a substate.
            raise ValueError("StateDB.create requires a substate instance; got None.")

        sub_cls = type(state)
        kind: StateKind | None = None

        # Walk the MRO so subclasses of a registered substate resolve to its kind.
        for cls_in_mro in sub_cls.mro():
            if cls_in_mro in SUBSTATE_TO_KIND:
                kind = SUBSTATE_TO_KIND[cls_in_mro]
                break

        if kind is None:
            raise ValueError(f"No StateKind mapping for substate type {sub_cls!r}")

        method, phase = _method_phase_from_kind(kind)
        base = State(method=method, phase=phase, kind=kind)

        row = cls(
            problem_id=problem_id,
            session_id=session_id,
            parent_id=parent_id,
            base_state=base,
        )
        database_session.add(row)

        # Persist base and link substate PK=FK
        _attach_substate(database_session, base, state)

        return row

    @property
    def state(self) -> SQLModel | None:
        """Return the concrete substate instance (e.g., `NIMBUSSaveState`).

        The substate table is chosen from ``base_state.kind`` via `KIND_TO_TABLE`, and the
        row whose id equals ``base_state.id`` is loaded. Each access runs a query.

        Returns:
            The substate, or None if there is no base state, the kind has no table, or
            no matching row exists.

        Raises:
            RuntimeError: If this object is not attached to a database session.
        """
        if self.base_state is None:
            return None
        table: type[SQLModel] | None = KIND_TO_TABLE.get(self.base_state.kind)

        if table is None:
            return None

        db_session = object_session(self)

        if db_session is None:
            # Detached object: there is no session to query the substate with.
            raise RuntimeError("StateDB.state accessed without a bound Session")

        return db_session.exec(select(table).where(table.id == self.base_state.id)).first()
177
+
178
+
179
# Maps each state kind to the substate *class* (table) holding its method-specific data.
# Used by `StateDB.state` to pick the table to query.
KIND_TO_TABLE: dict[StateKind, type[SQLModel]] = {
    StateKind.RPM_SOLVE: RPMState,
    StateKind.NIMBUS_SOLVE: NIMBUSClassificationState,
    StateKind.NIMBUS_SAVE: NIMBUSSaveState,
    StateKind.NIMBUS_INIT: NIMBUSInitializationState,
    StateKind.NIMBUS_FINAL: NIMBUSFinalState,
    StateKind.EMO_RUN: EMOIterateState,
    StateKind.GNIMBUS_OPTIMIZE: GNIMBUSOptimizationState,
    StateKind.GNIMBUS_VOTE: GNIMBUSVotingState,
    StateKind.GNIMBUS_END: GNIMBUSEndState,
    StateKind.EMO_SAVE: EMOSaveState,
    StateKind.EMO_FETCH: EMOFetchState,
    StateKind.EMO_SCORE: EMOSCOREState,
    StateKind.GENERIC_INTERMEDIATE: IntermediateSolutionState,
    StateKind.ENAUTILUS_STEP: ENautilusState,
}
195
+
196
# Reverse lookup used by `StateDB.create` to find the kind of a substate instance.
# Derived from `KIND_TO_TABLE` so the two mappings cannot drift apart. This requires
# every substate class to appear only once in `KIND_TO_TABLE`, which is checked below.
SUBSTATE_TO_KIND: dict[type[SQLModel], StateKind] = {table: kind for kind, table in KIND_TO_TABLE.items()}

if len(SUBSTATE_TO_KIND) != len(KIND_TO_TABLE):
    raise RuntimeError("KIND_TO_TABLE maps several StateKinds to the same substate table; the reverse lookup is ambiguous.")
212
+
213
+
214
def _method_phase_from_kind(kind: StateKind) -> tuple[str, str]:
    """Recover the ``(method, phase)`` pair encoded in a ``StateKind`` value.

    Values have the form ``"{method}.{phase}"``. Only the first ``.`` separates the two
    parts, so the phase may itself contain dots.

    Raises:
        ValueError: If the value contains no ``.`` (the two-name unpacking fails).
    """
    encoded: str = kind.value
    head, tail = encoded.split(sep=".", maxsplit=1)
    return head, tail
218
+
219
+
220
def _attach_substate(session, base: State, sub: SQLModel | None) -> None:
    """Add `base` to `session`, then give `sub` the same primary key and add it too.

    The first flush assigns ``base.id``. When `sub` is given, it reuses that id and the
    session is flushed a second time.
    """
    session.add(base)
    session.flush()  # base.id is assigned here

    if sub is None:
        return

    sub.id = base.id
    session.add(sub)
    session.flush()
229
+
230
+
231
class UserSavedSolutionDB(SQLModel, table=True):
    """Database model of an archived solution.

    Stores a copy of one solution's objective and variable values, with links to the
    owning user, the problem, and the state(s) the solution is associated with.
    """

    id: int | None = Field(primary_key=True, default=None)

    name: str | None = Field(default=None, nullable=True)
    objective_values: dict[str, float] = Field(sa_column=Column(JSON))
    variable_values: dict[str, VariableType] = Field(sa_column=Column(JSON))
    solution_index: int | None

    # Links
    user_id: int | None = Field(foreign_key="user.id", default=None)
    problem_id: int | None = Field(foreign_key="problemdb.id", default=None)
    # NOTE(review): both foreign keys below reference `states.id`, but the comments and
    # `from_state_info` store a `StateDB` id (`statedb.id`). The two tables have
    # independent id sequences; confirm which table is intended.
    origin_state_id: int | None = Field(
        foreign_key="states.id", default=None
    )  # the StateDB where this solution was generated
    save_state_id: int | None = Field(
        foreign_key="states.id", default=None
    )  # the StateDB that explicitly created the save

    # Back populates
    user: "User" = Relationship(back_populates="archive")
    problem: "ProblemDB" = Relationship(back_populates="solutions")

    @classmethod
    def from_state_info(
        cls,
        database_session: Session,
        user_id: int,
        problem_id: int,
        state_id: int,
        solution_index: int,
        name: str | None,
    ) -> "UserSavedSolutionDB | None":
        """Build (but do not add to the session) a saved solution copied from a state.

        Args:
            database_session: Session used to look up the state.
            user_id: ID of the user saving the solution.
            problem_id: ID of the problem the state must belong to.
            state_id: ID of the `StateDB` holding the solution.
            solution_index: Index of the solution in the state's results.
            name: Optional name for the saved solution.

        Returns:
            A new, unsaved `UserSavedSolutionDB`. Returns None if no `StateDB` with
            `state_id` belongs to `problem_id`, or if that state's substate cannot be
            resolved.

        Note:
            Errors from indexing the substate's result lists (e.g. an out-of-range
            `solution_index`) propagate to the caller.
        """
        state = database_session.exec(
            select(StateDB).where(
                StateDB.id == state_id,
                StateDB.problem_id == problem_id,
            )
        ).first()

        if state is None:
            return None

        # `StateDB.state` runs a query on every access; resolve it once.
        substate = state.state
        if substate is None:
            return None

        objective_values = substate.result_objective_values[solution_index]
        variable_values = substate.result_variable_values[solution_index]

        return cls(
            name=name,
            objective_values=objective_values,
            variable_values=variable_values,
            solution_index=solution_index,
            user_id=user_id,
            problem_id=problem_id,
            origin_state_id=state_id,
        )
287
+
288
+
289
class SolutionReferenceBase(SQLModel):
    """A model that functions as a reference to solutions existing in the database.

    Referenced solutions are not necessarily solutions that the user has saved explicitly. For
    referencing those, see `SavedSolutionReference`.
    """

    name: str | None = Field(description="Optional name to help identify the solution if, e.g., saved.", default=None)
    solution_index: int | None = Field(
        description="The index of the referenced solution, if multiple solutions exist in the reference state.",
        default=None,
    )
    state: StateDB = Field(description="The reference state with the solution information.")

    @computed_field
    @property
    def state_id(self) -> int:
        """ID of the referenced `StateDB` row."""
        return self.state.id

    @computed_field
    @property
    def num_solutions(self) -> int:
        """Number of entries in the referenced substate's `solver_results`."""
        # NOTE(review): assumes every referenced substate type defines `solver_results`.
        return len(self.state.state.solver_results)
312
+
313
+
314
class SolutionReference(SolutionReferenceBase):
    """Full reference to stored solutions, exposing both objective and decision-variable values."""

    @computed_field
    @property
    def objective_values_all(self) -> list[dict[str, float]]:
        """Objective values of every solution in the referenced state."""
        substate = self.state.state
        return substate.result_objective_values

    @computed_field
    @property
    def variable_values_all(self) -> list[dict[str, VariableType | Tensor]]:
        """Decision-variable values of every solution in the referenced state."""
        substate = self.state.state
        return substate.result_variable_values

    @computed_field
    @property
    def objective_values(self) -> dict[str, float] | None:
        """Objective values of the solution at `solution_index`, or None when no index is set."""
        idx = self.solution_index
        return None if idx is None else self.state.state.result_objective_values[idx]

    @computed_field
    @property
    def variable_values(self) -> dict[str, VariableType | Tensor] | None:
        """Decision-variable values of the solution at `solution_index`, or None when no index is set."""
        idx = self.solution_index
        return None if idx is None else self.state.state.result_variable_values[idx]
342
+
343
+
344
class SolutionReferenceLite(SolutionReferenceBase):
    """The same as SolutionReference, but without decision variables for more efficient transport over the internet."""

    @computed_field
    @property
    def objective_values(self) -> dict[str, float] | None:
        """Objective values of the solution at `solution_index`, or None when no index is set."""
        idx = self.solution_index
        if idx is None:
            return None
        return self.state.state.result_objective_values[idx]
354
+
355
+
356
class SolutionReferenceResponse(SQLModel):
    """The response information provided when `SolutionReference` objects are returned to the client."""

    # Allows validating instances directly from objects' attributes (e.g. a `SolutionReference`).
    model_config = ConfigDict(from_attributes=True)

    name: str | None = Field(default=None)
    solution_index: int | None  # required, may be None
    state_id: int
    objective_values: dict[str, float] | None
    variable_values: dict[str, "VariableType | Tensor"] | None
366
+
367
+
368
class SavedSolutionReference(SQLModel):
    """A model that functions as a reference to solutions that users have chosen to explicitly save in the database."""

    saved_solution: UserSavedSolutionDB = Field(description="The reference object with the solution information.")

    @computed_field
    @property
    def name(self) -> str | None:
        """Name of the saved solution, if any."""
        return self.saved_solution.name

    @computed_field
    @property
    def objective_values(self) -> dict[str, float]:
        """Objective values stored with the saved solution."""
        return self.saved_solution.objective_values

    @computed_field
    @property
    def variable_values(self) -> dict[str, VariableType | Tensor]:
        """Decision-variable values stored with the saved solution."""
        return self.saved_solution.variable_values

    @computed_field
    @property
    def solution_index(self) -> int | None:
        """Index the solution had in its originating state."""
        return self.saved_solution.solution_index

    @computed_field
    @property
    def saved_solution_id(self) -> int:
        """Database ID of the `UserSavedSolutionDB` row."""
        return self.saved_solution.id

    @computed_field
    @property
    def state_id(self) -> int | None:
        """ID of the state the solution originated from (`origin_state_id`)."""
        return self.saved_solution.origin_state_id
@@ -0,0 +1,158 @@
1
+ """Models specific to the nimbus method."""
2
+
3
+ from typing import Literal
4
+
5
+ from sqlmodel import JSON, Column, Field, SQLModel
6
+
7
+ from .generic import SolutionInfo
8
+ from .generic_states import SolutionReferenceResponse
9
+ from .preference import ReferencePoint
10
+
11
+
12
class NIMBUSClassificationRequest(SQLModel):
    """Model of the request to the nimbus method.

    Carries the decision maker's preference (a reference point) relative to the current
    objective values, plus optional scalarization and solver settings.
    """

    problem_id: int  # database ID of the problem
    session_id: int | None = Field(default=None)
    parent_state_id: int | None = Field(default=None)

    scalarization_options: dict[str, float | str | bool] | None = Field(sa_column=Column(JSON), default=None)
    solver: str | None = Field(default=None)
    solver_options: dict[str, float | str | bool] | None = Field(sa_column=Column(JSON), default=None)
    # `sa_column` must be passed by keyword. The first positional parameter of `Field`
    # is `default`, so `Field(Column(JSON))` made this field optional, with a `Column`
    # object as its default value.
    preference: ReferencePoint = Field(sa_column=Column(JSON))

    current_objectives: dict[str, float] = Field(
        sa_column=Column(JSON), description="The objectives used for iteration."
    )
    num_desired: int | None = Field(default=1)  # presumably the number of solutions to compute
28
+
29
+
30
class NIMBUSSaveRequest(SQLModel):
    """Request model for saving solutions from any method's state.

    Each entry of `solution_info` identifies one solution (state id + index, optional name).
    """

    problem_id: int
    session_id: int | None = Field(default=None)
    parent_state_id: int | None = Field(default=None)

    solution_info: list[SolutionInfo]  # the solutions to save
38
+
39
class NIMBUSDeleteSaveRequest(SQLModel):
    """Request model for deletion of a saved solution.

    The solution is identified by its save state and its position within that state.
    """

    state_id: int = Field(description="The ID of the save state.")
    # Despite the description's wording, this is an index, not a database ID.
    solution_index: int = Field(description="The ID of the solution within the above state.")
44
+
45
+
46
class NIMBUSFinalizeRequest(SQLModel):
    """Request model for finalizing the NIMBUS procedure.

    Identifies the solution the decision maker selected as the final one.
    """

    problem_id: int
    session_id: int | None = Field(default=None)
    parent_state_id: int | None = Field(default=None)

    solution_info: SolutionInfo  # the final solution
54
+
55
+
56
class NIMBUSClassificationResponse(SQLModel):
    """The response from NIMBUS classification endpoint.

    Returns the preference and objectives that were used, the newly computed solutions,
    the decision maker's saved solutions, and all solutions generated so far.
    """

    # Constant tag identifying this response type.
    response_type: Literal["nimbus.classification"] = "nimbus.classification"

    # Required (no default), may be None.
    state_id: int | None = Field(description="The newly created state id")
    previous_preference: ReferencePoint = Field(description="The previous preference used.")
    previous_objectives: dict[str, float] = Field(
        sa_column=Column(JSON), description="The previous solutions objectives used for iteration."
    )
    current_solutions: list[SolutionReferenceResponse] = Field(
        description="The solutions from the current iteration of nimbus."
    )
    saved_solutions: list[SolutionReferenceResponse] = Field(
        description="The best candidate solutions saved by the decision maker."
    )
    all_solutions: list[SolutionReferenceResponse] = Field(
        description="All solutions generated by NIMBUS in all iterations."
    )
75
+
76
+
77
class NIMBUSInitializationResponse(SQLModel):
    """The response from NIMBUS initialization endpoint.

    Contains the initial solution(s), the decision maker's saved solutions, and all
    solutions generated so far.
    """

    # Constant tag identifying this response type.
    response_type: Literal["nimbus.initialization"] = "nimbus.initialization"

    # Required (no default), may be None.
    state_id: int | None = Field(description="The newly created state id")
    current_solutions: list[SolutionReferenceResponse] = Field(
        description="The solutions from the current iteration of nimbus."
    )
    saved_solutions: list[SolutionReferenceResponse] = Field(
        description="The best candidate solutions saved by the decision maker."
    )
    all_solutions: list[SolutionReferenceResponse] = Field(
        description="All solutions generated by NIMBUS in all iterations."
    )
92
+
93
+
94
class NIMBUSSaveResponse(SQLModel):
    """The response from NIMBUS save endpoint."""

    # Constant tag identifying this response type.
    response_type: Literal["nimbus.save"] = "nimbus.save"

    # Required (no default), may be None.
    state_id: int | None = Field(description="The id of the newest state")
100
+
101
class NIMBUSDeleteSaveResponse(SQLModel):
    """Response of NIMBUS save deletion."""

    # Annotated as a `Literal` tag, consistent with the other NIMBUS responses in this module.
    response_type: Literal["nimbus.delete_save"] = "nimbus.delete_save"

    message: str | None  # required (no default), may be None
107
+
108
class NIMBUSFinalizeResponse(SQLModel):
    """The response from NIMBUS finish endpoint.

    Returns the chosen final solution together with the saved solutions and all solutions
    generated during the procedure.
    """

    # Constant tag identifying this response type.
    response_type: Literal["nimbus.finalize"] = "nimbus.finalize"

    # Required (no default), may be None.
    state_id: int | None = Field(description="The newly created state id")
    final_solution: SolutionReferenceResponse = Field(
        description="The final solution. We do not need the other current solutions."
    )
    saved_solutions: list[SolutionReferenceResponse] = Field(
        description="The best candidate solutions saved by the decision maker."
    )
    all_solutions: list[SolutionReferenceResponse] = Field(
        description="All solutions generated by NIMBUS in all iterations."
    )
123
+
124
+
125
class NIMBUSInitializationRequest(SQLModel):
    """Model of the request to initialize the nimbus method.

    The optional `starting_point` is either a reference point or a reference to an
    existing solution (`SolutionInfo`).
    """

    problem_id: int
    session_id: int | None = Field(default=None)
    parent_state_id: int | None = Field(default=None)

    starting_point: ReferencePoint | SolutionInfo | None = Field(sa_column=Column(JSON), default=None)
    scalarization_options: dict[str, float | str | bool] | None = Field(sa_column=Column(JSON), default=None)
    solver: str | None = Field(default=None)
    solver_options: dict[str, float | str | bool] | None = Field(sa_column=Column(JSON), default=None)
136
+
137
+
138
class NIMBUSIntermediateSolutionResponse(SQLModel):
    """The response from NIMBUS intermediate solution endpoint.

    Returns the objective values of the two reference solutions, the computed intermediate
    solutions, the saved solutions, and all solutions generated so far.
    """

    # Constant tag identifying this response type.
    response_type: Literal["nimbus.intermediate"] = "nimbus.intermediate"

    # Required (no default), may be None.
    state_id: int | None = Field(description="The newly created state id")
    reference_solution_1: dict[str, float] = Field(
        sa_column=Column(JSON), description="The first solution used when computing intermediate points."
    )
    reference_solution_2: dict[str, float] = Field(
        sa_column=Column(JSON), description="The second solution used when computing intermediate points."
    )
    current_solutions: list[SolutionReferenceResponse] = Field(
        description="The solutions from the current iteration of NIMBUS."
    )
    saved_solutions: list[SolutionReferenceResponse] = Field(
        description="The best candidate solutions saved by the decision maker."
    )
    all_solutions: list[SolutionReferenceResponse] = Field(
        description="All solutions generated by NIMBUS in all iterations."
    )
+ )