desdeo 2.0.0__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/adm/ADMAfsar.py +551 -0
- desdeo/adm/ADMChen.py +414 -0
- desdeo/adm/BaseADM.py +119 -0
- desdeo/adm/__init__.py +11 -0
- desdeo/api/__init__.py +6 -6
- desdeo/api/app.py +38 -28
- desdeo/api/config.py +65 -44
- desdeo/api/config.toml +23 -12
- desdeo/api/db.py +10 -8
- desdeo/api/db_init.py +12 -6
- desdeo/api/models/__init__.py +220 -20
- desdeo/api/models/archive.py +16 -27
- desdeo/api/models/emo.py +128 -0
- desdeo/api/models/enautilus.py +69 -0
- desdeo/api/models/gdm/gdm_aggregate.py +139 -0
- desdeo/api/models/gdm/gdm_base.py +69 -0
- desdeo/api/models/gdm/gdm_score_bands.py +114 -0
- desdeo/api/models/gdm/gnimbus.py +138 -0
- desdeo/api/models/generic.py +104 -0
- desdeo/api/models/generic_states.py +401 -0
- desdeo/api/models/nimbus.py +158 -0
- desdeo/api/models/preference.py +44 -6
- desdeo/api/models/problem.py +274 -64
- desdeo/api/models/session.py +4 -1
- desdeo/api/models/state.py +419 -52
- desdeo/api/models/user.py +7 -6
- desdeo/api/models/utopia.py +25 -0
- desdeo/api/routers/_EMO.backup +309 -0
- desdeo/api/routers/_NIMBUS.py +6 -3
- desdeo/api/routers/emo.py +497 -0
- desdeo/api/routers/enautilus.py +237 -0
- desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
- desdeo/api/routers/gdm/gdm_base.py +420 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
- desdeo/api/routers/generic.py +233 -0
- desdeo/api/routers/nimbus.py +705 -0
- desdeo/api/routers/problem.py +201 -4
- desdeo/api/routers/reference_point_method.py +20 -44
- desdeo/api/routers/session.py +50 -26
- desdeo/api/routers/user_authentication.py +180 -26
- desdeo/api/routers/utils.py +187 -0
- desdeo/api/routers/utopia.py +230 -0
- desdeo/api/schema.py +10 -4
- desdeo/api/tests/conftest.py +94 -2
- desdeo/api/tests/test_enautilus.py +330 -0
- desdeo/api/tests/test_models.py +550 -72
- desdeo/api/tests/test_routes.py +902 -43
- desdeo/api/utils/_database.py +263 -0
- desdeo/api/utils/database.py +28 -266
- desdeo/api/utils/emo_database.py +40 -0
- desdeo/core.py +7 -0
- desdeo/emo/__init__.py +154 -24
- desdeo/emo/hooks/archivers.py +18 -2
- desdeo/emo/methods/EAs.py +128 -5
- desdeo/emo/methods/bases.py +9 -56
- desdeo/emo/methods/templates.py +111 -0
- desdeo/emo/operators/crossover.py +544 -42
- desdeo/emo/operators/evaluator.py +10 -14
- desdeo/emo/operators/generator.py +127 -24
- desdeo/emo/operators/mutation.py +212 -41
- desdeo/emo/operators/scalar_selection.py +202 -0
- desdeo/emo/operators/selection.py +956 -214
- desdeo/emo/operators/termination.py +124 -16
- desdeo/emo/options/__init__.py +108 -0
- desdeo/emo/options/algorithms.py +435 -0
- desdeo/emo/options/crossover.py +164 -0
- desdeo/emo/options/generator.py +131 -0
- desdeo/emo/options/mutation.py +260 -0
- desdeo/emo/options/repair.py +61 -0
- desdeo/emo/options/scalar_selection.py +66 -0
- desdeo/emo/options/selection.py +127 -0
- desdeo/emo/options/templates.py +383 -0
- desdeo/emo/options/termination.py +143 -0
- desdeo/gdm/__init__.py +22 -0
- desdeo/gdm/gdmtools.py +45 -0
- desdeo/gdm/score_bands.py +114 -0
- desdeo/gdm/voting_rules.py +50 -0
- desdeo/mcdm/__init__.py +23 -1
- desdeo/mcdm/enautilus.py +338 -0
- desdeo/mcdm/gnimbus.py +484 -0
- desdeo/mcdm/nautilus_navigator.py +7 -6
- desdeo/mcdm/reference_point_method.py +70 -0
- desdeo/problem/__init__.py +16 -11
- desdeo/problem/evaluator.py +4 -5
- desdeo/problem/external/__init__.py +18 -0
- desdeo/problem/external/core.py +356 -0
- desdeo/problem/external/pymoo_provider.py +266 -0
- desdeo/problem/external/runtime.py +44 -0
- desdeo/problem/gurobipy_evaluator.py +37 -12
- desdeo/problem/infix_parser.py +1 -16
- desdeo/problem/json_parser.py +7 -11
- desdeo/problem/pyomo_evaluator.py +25 -6
- desdeo/problem/schema.py +73 -55
- desdeo/problem/simulator_evaluator.py +65 -15
- desdeo/problem/testproblems/__init__.py +26 -11
- desdeo/problem/testproblems/benchmarks_server.py +120 -0
- desdeo/problem/testproblems/cake_problem.py +185 -0
- desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
- desdeo/problem/testproblems/forest_problem.py +77 -69
- desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
- desdeo/problem/testproblems/{river_pollution_problem.py → river_pollution_problems.py} +28 -22
- desdeo/problem/testproblems/single_objective.py +289 -0
- desdeo/problem/testproblems/zdt_problem.py +4 -1
- desdeo/problem/utils.py +1 -1
- desdeo/tools/__init__.py +39 -21
- desdeo/tools/desc_gen.py +22 -0
- desdeo/tools/generics.py +22 -2
- desdeo/tools/group_scalarization.py +3090 -0
- desdeo/tools/indicators_binary.py +107 -1
- desdeo/tools/indicators_unary.py +3 -16
- desdeo/tools/message.py +33 -2
- desdeo/tools/non_dominated_sorting.py +4 -3
- desdeo/tools/patterns.py +9 -7
- desdeo/tools/pyomo_solver_interfaces.py +49 -36
- desdeo/tools/reference_vectors.py +118 -351
- desdeo/tools/scalarization.py +340 -1413
- desdeo/tools/score_bands.py +491 -328
- desdeo/tools/utils.py +117 -49
- desdeo/tools/visualizations.py +67 -0
- desdeo/utopia_stuff/utopia_problem.py +1 -1
- desdeo/utopia_stuff/utopia_problem_old.py +1 -1
- {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info}/METADATA +47 -30
- desdeo-2.1.1.dist-info/RECORD +180 -0
- {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info}/WHEEL +1 -1
- desdeo-2.0.0.dist-info/RECORD +0 -120
- /desdeo/api/utils/{logger.py → _logger.py} +0 -0
- {desdeo-2.0.0.dist-info → desdeo-2.1.1.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""Generic models for the DESDEO API."""
|
|
2
|
+
|
|
3
|
+
from pydantic import ConfigDict
|
|
4
|
+
from sqlmodel import JSON, Column, Field, SQLModel
|
|
5
|
+
|
|
6
|
+
from desdeo.tools.score_bands import SCOREBandsConfig, SCOREBandsResult
|
|
7
|
+
|
|
8
|
+
from .generic_states import SolutionReferenceResponse
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SolutionInfo(SQLModel):
    """Pointer to a single solution held inside a `StateDB` row.

    `state_id` selects the state and `solution_index` the solution within it; `name` is an
    optional label to attach to the solution.
    """

    state_id: int = Field(description="State of the desired solution.")
    solution_index: int = Field(description="Index of the desired solution.")
    name: str | None = Field(default=None, description="Name to be given to the solution. Optional.")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class IntermediateSolutionRequest(SQLModel):
    """Request body for computing intermediate solutions between two reference solutions."""

    # Where the request belongs: the problem, plus optionally a session and a parent state.
    problem_id: int
    session_id: int | None = Field(default=None)
    parent_state_id: int | None = Field(default=None)
    context: str | None = None  # Method context (nimbus, rpm, etc.)

    # Optional scalarization / solver configuration (declared with JSON columns).
    scalarization_options: dict[str, float | str | bool] | None = Field(default=None, sa_column=Column(JSON))
    solver: str | None = Field(default=None)
    solver_options: dict[str, float | str | bool] | None = Field(default=None, sa_column=Column(JSON))

    # How many intermediate solutions are wanted; defaults to one.
    num_desired: int | None = Field(default=1)

    # The pair of solutions between which intermediate solutions are computed.
    reference_solution_1: SolutionInfo
    reference_solution_2: SolutionInfo
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class GenericIntermediateSolutionResponse(SQLModel):
    """Result returned once intermediate solutions have been computed."""

    state_id: int | None = Field(description="The newly created state id")

    # The two solutions the computation started from.
    reference_solution_1: SolutionReferenceResponse = Field(
        description="The first solution used when computing intermediate solutions.",
        sa_column=Column(JSON),
    )
    reference_solution_2: SolutionReferenceResponse = Field(
        description="The second solution used when computing intermediate solutions.",
        sa_column=Column(JSON),
    )

    # The solutions produced between them.
    intermediate_solutions: list[SolutionReferenceResponse] = Field(
        description="The intermediate solutions computed.",
        sa_column=Column(JSON),
    )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ScoreBandsRequest(SQLModel):
    """Request body carrying the data and options used to compute SCORE bands parameters."""

    # Input: a matrix of objective values whose columns are named by `objs`.
    data: list[list[float]] = Field(description="Matrix of objective values")
    objs: list[str] = Field(description="Array of objective names for each column")

    # Tuning options; the defaults mirror those of the score_bands.py functions.
    dist_parameter: float = Field(default=0.05, description="Distance parameter for axis positioning")
    use_absolute_corr: bool = Field(default=False, description="Use absolute correlation values")
    distance_formula: int = Field(default=1, description="Distance formula (1 or 2)")
    flip_axes: bool = Field(default=True, description="Whether to flip axes based on correlation signs")
    clustering_algorithm: str = Field(default="DBSCAN", description="Clustering algorithm (DBSCAN or GMM)")
    # NOTE(review): "silhoutte" is spelled as in the original default; presumably it must match the
    # identifier the SCORE bands backend expects, so confirm there before changing the spelling.
    clustering_score: str = Field(default="silhoutte", description="Clustering score metric")
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ScoreBandsResponse(SQLModel):
    """Computed SCORE bands parameters returned to the caller."""

    groups: list[int] = Field(description="Cluster group assignments for each data point")
    axis_dist: list[float] = Field(description="Normalized axis positions")
    # Nullable, but it has no default, so it must still be supplied explicitly.
    axis_signs: list[int] | None = Field(description="Axis direction signs (1 or -1)")
    obj_order: list[int] = Field(description="Optimal order of objectives")
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class GroupScoreRequest(SQLModel):
    """A generic model for requesting SCORE Bands for a state."""

    # With `use_attribute_docstrings`, the string literal placed directly after a field becomes
    # that field's description, so every docstring must follow the field it documents.
    model_config = ConfigDict(use_attribute_docstrings=True)

    problem_id: int
    """Database ID of the problem to solve."""

    group_id: int
    """Database ID of the group."""

    session_id: int | None = Field(default=None)
    """Database ID of the session, if any."""

    parent_state_id: int | None = Field(default=None)
    """State ID of the parent state, if any."""

    config: SCOREBandsConfig | None = Field(default=None)
    """Configuration for the SCORE bands visualization."""

    solution_ids: list[int] = Field()
    """List of solution IDs to score."""
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class GroupScoreResponse(SQLModel):
    """Model of the response to a group SCORE bands request."""

    # Attribute docstrings below become the fields' descriptions.
    model_config = ConfigDict(use_attribute_docstrings=True)

    group_iteration_id: int | None = Field(default=None)
    """The state ID of the newly created group iteration."""

    result: SCOREBandsResult
    """The computed SCORE bands result."""
|
|
@@ -0,0 +1,401 @@
|
|
|
1
|
+
"""Defines models for representing the state of various interactive methods."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from pydantic import ConfigDict, computed_field
|
|
7
|
+
from sqlalchemy.orm import object_session
|
|
8
|
+
from sqlmodel import (
|
|
9
|
+
JSON,
|
|
10
|
+
Column,
|
|
11
|
+
Field,
|
|
12
|
+
Relationship,
|
|
13
|
+
Session,
|
|
14
|
+
SQLModel,
|
|
15
|
+
select,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from desdeo.problem import Tensor, VariableType
|
|
19
|
+
|
|
20
|
+
from .state import (
|
|
21
|
+
EMOFetchState,
|
|
22
|
+
EMOIterateState,
|
|
23
|
+
EMOSaveState,
|
|
24
|
+
EMOSCOREState,
|
|
25
|
+
ENautilusState,
|
|
26
|
+
GNIMBUSEndState,
|
|
27
|
+
GNIMBUSOptimizationState,
|
|
28
|
+
GNIMBUSVotingState,
|
|
29
|
+
IntermediateSolutionState,
|
|
30
|
+
NIMBUSClassificationState,
|
|
31
|
+
NIMBUSFinalState,
|
|
32
|
+
NIMBUSInitializationState,
|
|
33
|
+
NIMBUSSaveState,
|
|
34
|
+
RPMState,
|
|
35
|
+
)
|
|
36
|
+
from .user import User
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from .problem import ProblemDB
|
|
40
|
+
from .session import InteractiveSessionDB
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class StateKind(str, Enum):
    """Stores the normalized kinds `{method}.{phase}` of supported states.

    Each member's value has the form `"{method}.{phase}"`; `_method_phase_from_kind`
    splits it back into its two parts at the first ".".

    Note:
        Update this when adding new states. Be sure to update `KIND_TO_TABLE`
        and `SUBSTATE_TO_KIND` in this file as well.
    """

    RPM_SOLVE = "reference_point_method.solve_candidates"
    NIMBUS_SOLVE = "nimbus.solve_candidates"
    NIMBUS_SAVE = "nimbus.save_solutions"
    NIMBUS_INIT = "nimbus.initialize"
    NIMBUS_FINAL = "nimbus.final"
    GNIMBUS_OPTIMIZE = "gnimbus.optimize"
    GNIMBUS_VOTE = "gnimbus.vote"
    GNIMBUS_END = "gnimbus.end"
    EMO_RUN = "emo.run"
    EMO_SAVE = "emo.save_solutions"
    EMO_FETCH = "emo.fetch_solutions"
    EMO_SCORE = "emo.score_bands"
    GENERIC_INTERMEDIATE = "generic.solve_intermediate"
    ENAUTILUS_STEP = "e-nautilus.stepping"
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class State(SQLModel, table=True):
    """Base row of the polymorphic state storage.

    A concrete substate lives in its own table and reuses this row's primary key
    (see `StateDB.create`); `kind` identifies which substate table that is.
    """

    __tablename__ = "states"

    id: int | None = Field(primary_key=True, default=None)

    # Discriminator columns: polymorphism is emulated through these fields.
    # TODO (@gialmisi): once SQLModel supports polymorphic table types, refactor this.
    method: str = Field(index=True)  # method name, e.g. "nimbus"
    phase: str = Field(index=True)  # phase within the method, e.g. "save_solutions"
    kind: StateKind = Field(index=True)  # normalized "{method}.{phase}", lowercase
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class StateDB(SQLModel, table=True):
    """State holder with a single relationship to the base State.

    A `StateDB` row carries lineage (`parent` / `children`) and the problem/session links of a
    state. The method-specific payload lives in a substate table whose primary key equals
    `base_state.id`; it is resolved through the `state` property.
    """

    __tablename__ = "statedb"

    id: int | None = Field(primary_key=True, default=None)

    # Optional cross-links (keep as strings in other modules to avoid circulars)
    problem_id: int | None = Field(foreign_key="problemdb.id", default=None)
    session_id: int | None = Field(foreign_key="interactivesessiondb.id", default=None)

    # lineage
    parent_id: int | None = Field(foreign_key="statedb.id", default=None)

    # one-to-one to base state
    state_id: int | None = Field(foreign_key="states.id", default=None, index=True)
    base_state: State | None = Relationship(
        sa_relationship_kwargs={
            "uselist": False,
            "single_parent": True,
            "cascade": "all, delete-orphan",
            "foreign_keys": lambda: [StateDB.state_id],
        }
    )

    parent: "StateDB" = Relationship(
        back_populates="children",
        sa_relationship_kwargs={"remote_side": lambda: StateDB.id},
    )
    children: list["StateDB"] = Relationship(
        back_populates="parent",
        sa_relationship_kwargs={"cascade": "all, delete-orphan"},
    )

    session: "InteractiveSessionDB" = Relationship(back_populates="states")
    problem: "ProblemDB" = Relationship()

    @classmethod
    def create(
        cls,
        database_session: Session,
        *,
        problem_id: int | None = None,
        session_id: int | None = None,
        parent_id: int | None = None,
        state: SQLModel | None = None,
    ) -> "StateDB":
        """Build a StateDB + base State with a concrete substate.

        The `StateKind` is taken from the first class in the substate's MRO that appears in
        `SUBSTATE_TO_KIND`. A base `State` is created for that kind, and the substate is
        linked to it by sharing its primary key. The rows are added to and flushed in
        `database_session`; nothing is committed here.

        Args:
            database_session: Session the new rows are added to.
            problem_id: Optional ID of the problem the state belongs to.
            session_id: Optional ID of the interactive session.
            parent_id: Optional ID of the parent `StateDB`.
            state: The concrete substate instance to store. Despite the `None` default it is
                required, because no kind can be derived without it.

        Returns:
            The new (flushed, uncommitted) `StateDB` row.

        Raises:
            ValueError: If `state` is `None` or its type has no `StateKind` mapping.
        """
        if state is None:
            # type(None) can never be mapped to a kind; report the real problem instead of a
            # missing mapping for NoneType.
            raise ValueError("StateDB.create requires a substate instance; got None")

        sub_cls = type(state)
        kind: StateKind | None = None

        # Walk the MRO so subclasses of a registered substate resolve to its kind.
        for cls_in_mro in sub_cls.mro():
            if cls_in_mro in SUBSTATE_TO_KIND:
                kind = SUBSTATE_TO_KIND[cls_in_mro]
                break

        if kind is None:
            raise ValueError(f"No StateKind mapping for substate type {sub_cls!r}")

        method, phase = _method_phase_from_kind(kind)
        base = State(method=method, phase=phase, kind=kind)

        row = cls(
            problem_id=problem_id,
            session_id=session_id,
            parent_id=parent_id,
            base_state=base,
        )
        database_session.add(row)

        # Persist base and link substate PK=FK
        _attach_substate(database_session, base, state)

        return row

    @property
    def state(self) -> SQLModel | None:
        """Return the concrete substate instance (e.g., NIMBUSSaveState) behind `base_state`.

        The substate table is chosen from `KIND_TO_TABLE` by `base_state.kind`, and the row
        sharing `base_state`'s primary key is fetched. Every access runs a query.

        Returns:
            The substate, or `None` when there is no base state, the kind has no table
            mapping, or no matching row exists.

        Raises:
            RuntimeError: If this object is not bound to a database session.
        """
        base = self.base_state
        if base is None:
            return None

        table: type[SQLModel] | None = KIND_TO_TABLE.get(base.kind)
        if table is None:
            return None

        db_session = object_session(self)
        if db_session is None:
            # No bound session to query through.
            raise RuntimeError("StateDB.state accessed without a bound Session")

        return db_session.exec(select(table).where(table.id == base.id)).first()
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
# Maps every StateKind to the table class that stores its concrete substate.
# The values are classes, not instances, hence `type[SQLModel]`.
KIND_TO_TABLE: dict[StateKind, type[SQLModel]] = {
    StateKind.RPM_SOLVE: RPMState,
    StateKind.NIMBUS_SOLVE: NIMBUSClassificationState,
    StateKind.NIMBUS_SAVE: NIMBUSSaveState,
    StateKind.NIMBUS_INIT: NIMBUSInitializationState,
    StateKind.NIMBUS_FINAL: NIMBUSFinalState,
    StateKind.EMO_RUN: EMOIterateState,
    StateKind.GNIMBUS_OPTIMIZE: GNIMBUSOptimizationState,
    StateKind.GNIMBUS_VOTE: GNIMBUSVotingState,
    StateKind.GNIMBUS_END: GNIMBUSEndState,
    StateKind.EMO_SAVE: EMOSaveState,
    StateKind.EMO_FETCH: EMOFetchState,
    StateKind.EMO_SCORE: EMOSCOREState,
    StateKind.GENERIC_INTERMEDIATE: IntermediateSolutionState,
    StateKind.ENAUTILUS_STEP: ENautilusState,
}
|
|
195
|
+
|
|
196
|
+
# Maps a substate class to its StateKind; it should stay the exact reverse of KIND_TO_TABLE.
# `StateDB.create` walks a substate's MRO and uses the first class found here.
# The keys are classes, not instances, hence `type[SQLModel]`.
SUBSTATE_TO_KIND: dict[type[SQLModel], StateKind] = {
    RPMState: StateKind.RPM_SOLVE,
    NIMBUSClassificationState: StateKind.NIMBUS_SOLVE,
    NIMBUSSaveState: StateKind.NIMBUS_SAVE,
    NIMBUSInitializationState: StateKind.NIMBUS_INIT,
    NIMBUSFinalState: StateKind.NIMBUS_FINAL,
    EMOIterateState: StateKind.EMO_RUN,
    GNIMBUSOptimizationState: StateKind.GNIMBUS_OPTIMIZE,
    GNIMBUSVotingState: StateKind.GNIMBUS_VOTE,
    GNIMBUSEndState: StateKind.GNIMBUS_END,
    EMOSaveState: StateKind.EMO_SAVE,
    EMOFetchState: StateKind.EMO_FETCH,
    EMOSCOREState: StateKind.EMO_SCORE,
    IntermediateSolutionState: StateKind.GENERIC_INTERMEDIATE,
    ENautilusState: StateKind.ENAUTILUS_STEP,
}
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _method_phase_from_kind(kind: StateKind) -> tuple[str, str]:
    """Recover the `(method, phase)` pair encoded in a `StateKind` value.

    The value is split at its first "."; everything after that dot belongs to the phase.
    A value with no "." makes the tuple unpacking raise `ValueError`.
    """
    head, tail = kind.value.split(".", maxsplit=1)
    return head, tail
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def _attach_substate(session, base: State, sub: SQLModel | None) -> None:
    """Flush `base` so it receives an id, then store `sub` under that same id.

    Args:
        session: Database session the rows are added to.
        base: The base `State` row; it is added and flushed first.
        sub: Optional concrete substate. When given, its `id` is set to `base.id`, and it
            is added and flushed too.
    """
    session.add(base)
    session.flush()  # populates base.id

    if sub is None:
        return

    sub.id = base.id  # the substate shares the base row's primary key
    session.add(sub)
    session.flush()
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class UserSavedSolutionDB(SQLModel, table=True):
    """Database model of an archived solution."""

    id: int | None = Field(primary_key=True, default=None)

    name: str | None = Field(default=None, nullable=True)
    objective_values: dict[str, float] = Field(sa_column=Column(JSON))
    variable_values: dict[str, VariableType] = Field(sa_column=Column(JSON))
    solution_index: int | None

    # Links
    user_id: int | None = Field(foreign_key="user.id", default=None)
    problem_id: int | None = Field(foreign_key="problemdb.id", default=None)
    # NOTE(review): both state links below use a foreign key to `states.id` (the base `State`
    # table), yet their comments describe StateDB rows, and `from_state_info` stores a
    # `StateDB.id` in `origin_state_id`. Confirm which table these are meant to reference.
    origin_state_id: int | None = Field(
        foreign_key="states.id", default=None
    )  # the StateDB where this solution was generated
    save_state_id: int | None = Field(
        foreign_key="states.id", default=None
    )  # the StateDB that explicitly created the save

    # Back populates
    user: "User" = Relationship(back_populates="archive")
    problem: "ProblemDB" = Relationship(back_populates="solutions")

    @classmethod
    def from_state_info(
        cls,
        database_session: Session,
        user_id: int,
        problem_id: int,
        state_id: int,
        solution_index: int,
        name: str | None,
    ) -> "UserSavedSolutionDB | None":
        """Build (but do not persist) a saved-solution row from a solution held in a state.

        Args:
            database_session: Session used to look up the state.
            user_id: ID of the user saving the solution.
            problem_id: ID of the problem; the state must belong to it.
            state_id: ID of the `StateDB` holding the solution.
            solution_index: Index of the solution within the state's results.
            name: Optional name for the saved solution.

        Returns:
            A new, unsaved `UserSavedSolutionDB`, or `None` if no matching state exists or
            its concrete substate cannot be resolved. Errors from indexing the results with
            `solution_index` propagate unchanged.
        """
        state_db = database_session.exec(
            select(StateDB).where(
                StateDB.id == state_id,
                StateDB.problem_id == problem_id,
            )
        ).first()

        if state_db is None:
            return None

        # `StateDB.state` runs a query on every access, so resolve the substate only once.
        # It can be None (no base state / unmapped kind / missing row); treat that like a
        # missing state instead of failing with an AttributeError.
        substate = state_db.state
        if substate is None:
            return None

        objective_values = substate.result_objective_values[solution_index]
        variable_values = substate.result_variable_values[solution_index]

        return cls(
            name=name,
            objective_values=objective_values,
            variable_values=variable_values,
            solution_index=solution_index,
            user_id=user_id,
            problem_id=problem_id,
            origin_state_id=state_id,
        )
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class SolutionReferenceBase(SQLModel):
    """A model that functions as a reference to solutions existing in the database.

    The referenced solutions need not be ones the user saved explicitly; for those, see
    `SavedSolutionReference`. All solution data is read through `state`, the `StateDB`
    that holds it.
    """

    name: str | None = Field(default=None, description="Optional name to help identify the solution if, e.g., saved.")
    solution_index: int | None = Field(
        default=None,
        description="The index of the referenced solution, if multiple solutions exist in the reference state.",
    )
    state: StateDB = Field(description="The reference state with the solution information.")

    @computed_field
    @property
    def state_id(self) -> int:
        """ID of the referenced `StateDB`."""
        referenced = self.state
        return referenced.id

    @computed_field
    @property
    def num_solutions(self) -> int:
        """Number of solver results stored in the referenced state's substate."""
        substate = self.state.state
        return len(substate.solver_results)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
class SolutionReference(SolutionReferenceBase):
    """A full solution reference exposing both objective and decision-variable values."""

    @computed_field
    @property
    def objective_values_all(self) -> list[dict[str, float]]:
        """Objective values of every solution in the referenced state."""
        substate = self.state.state
        return substate.result_objective_values

    @computed_field
    @property
    def variable_values_all(self) -> list[dict[str, VariableType | Tensor]]:
        """Variable values of every solution in the referenced state."""
        substate = self.state.state
        return substate.result_variable_values

    @computed_field
    @property
    def objective_values(self) -> dict[str, float] | None:
        """Objective values of the solution at `solution_index`, or `None` if no index is set."""
        idx = self.solution_index
        if idx is None:
            return None
        return self.state.state.result_objective_values[idx]

    @computed_field
    @property
    def variable_values(self) -> dict[str, VariableType | Tensor] | None:
        """Variable values of the solution at `solution_index`, or `None` if no index is set."""
        idx = self.solution_index
        if idx is None:
            return None
        return self.state.state.result_variable_values[idx]
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
class SolutionReferenceLite(SolutionReferenceBase):
    """Like `SolutionReference`, but without decision variables, to keep payloads small."""

    @computed_field
    @property
    def objective_values(self) -> dict[str, float] | None:
        """Objective values of the solution at `solution_index`, or `None` if no index is set."""
        idx = self.solution_index
        if idx is None:
            return None
        return self.state.state.result_objective_values[idx]
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
class SolutionReferenceResponse(SQLModel):
    """Serializable view of a solution reference, as returned to API clients.

    `from_attributes=True` allows validating it directly from objects such as
    `SolutionReference` that expose attributes with these names.
    """

    model_config = ConfigDict(from_attributes=True)

    name: str | None = Field(default=None)
    solution_index: int | None
    state_id: int
    objective_values: dict[str, float] | None
    variable_values: dict[str, "VariableType | Tensor"] | None
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
class SavedSolutionReference(SQLModel):
    """Reference to a solution the user explicitly saved, backed by a `UserSavedSolutionDB` row.

    Every exposed value is read directly from `saved_solution`.
    """

    saved_solution: UserSavedSolutionDB = Field(description="The reference object with the solution information.")

    @computed_field
    @property
    def name(self) -> str | None:
        """Name given to the saved solution, if any."""
        saved = self.saved_solution
        return saved.name

    @computed_field
    @property
    def objective_values(self) -> dict[str, float]:
        """Objective values of the saved solution."""
        saved = self.saved_solution
        return saved.objective_values

    @computed_field
    @property
    def variable_values(self) -> dict[str, VariableType | Tensor]:
        """Variable values of the saved solution."""
        saved = self.saved_solution
        return saved.variable_values

    @computed_field
    @property
    def solution_index(self) -> int | None:
        """Index the solution had in its originating state."""
        saved = self.saved_solution
        return saved.solution_index

    @computed_field
    @property
    def saved_solution_id(self) -> int:
        """Primary key of the saved-solution row."""
        saved = self.saved_solution
        return saved.id

    @computed_field
    @property
    def state_id(self) -> int | None:
        """The origin state ID recorded on the saved solution."""
        saved = self.saved_solution
        return saved.origin_state_id
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""Models specific to the nimbus method."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from sqlmodel import JSON, Column, Field, SQLModel
|
|
6
|
+
|
|
7
|
+
from .generic import SolutionInfo
|
|
8
|
+
from .generic_states import SolutionReferenceResponse
|
|
9
|
+
from .preference import ReferencePoint
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class NIMBUSClassificationRequest(SQLModel):
    """Model of the request to the nimbus method (one classification iteration)."""

    problem_id: int
    session_id: int | None = Field(default=None)
    parent_state_id: int | None = Field(default=None)

    scalarization_options: dict[str, float | str | bool] | None = Field(sa_column=Column(JSON), default=None)
    solver: str | None = Field(default=None)
    solver_options: dict[str, float | str | bool] | None = Field(sa_column=Column(JSON), default=None)
    # `Column(JSON)` must be passed as `sa_column`: the first positional parameter of `Field` is
    # `default`, so `Field(Column(JSON))` silently made this field optional with a Column object
    # as its default value.
    preference: ReferencePoint = Field(sa_column=Column(JSON))

    current_objectives: dict[str, float] = Field(
        sa_column=Column(JSON), description="The objectives used for iteration."
    )
    num_desired: int | None = Field(default=1)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class NIMBUSSaveRequest(SQLModel):
    """Request body for saving solutions, each located by state and index, from any method's state."""

    # Context of the request.
    problem_id: int
    session_id: int | None = Field(default=None)
    parent_state_id: int | None = Field(default=None)

    # The solutions to save.
    solution_info: list[SolutionInfo]
|
|
38
|
+
|
|
39
|
+
class NIMBUSDeleteSaveRequest(SQLModel):
    """Request body identifying a single saved solution to delete."""

    state_id: int = Field(description="The ID of the save state.")
    solution_index: int = Field(description="The ID of the solution within the above state.")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class NIMBUSFinalizeRequest(SQLModel):
    """Request body that ends the NIMBUS procedure by choosing a final solution."""

    # Context of the request.
    problem_id: int
    session_id: int | None = Field(default=None)
    parent_state_id: int | None = Field(default=None)

    # The solution selected as final.
    solution_info: SolutionInfo
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class NIMBUSClassificationResponse(SQLModel):
    """Payload returned by the NIMBUS classification endpoint."""

    # Fixed tag identifying this response type.
    response_type: Literal["nimbus.classification"] = "nimbus.classification"

    state_id: int | None = Field(description="The newly created state id")

    # The inputs of this iteration.
    previous_preference: ReferencePoint = Field(description="The previous preference used.")
    previous_objectives: dict[str, float] = Field(
        description="The previous solutions objectives used for iteration.",
        sa_column=Column(JSON),
    )

    # Solution collections: this iteration, saved by the DM, and the full history.
    current_solutions: list[SolutionReferenceResponse] = Field(
        description="The solutions from the current iteration of nimbus.",
    )
    saved_solutions: list[SolutionReferenceResponse] = Field(
        description="The best candidate solutions saved by the decision maker.",
    )
    all_solutions: list[SolutionReferenceResponse] = Field(
        description="All solutions generated by NIMBUS in all iterations.",
    )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class NIMBUSInitializationResponse(SQLModel):
    """The response from NIMBUS initialization endpoint."""

    # Fixed tag identifying this response type.
    response_type: Literal["nimbus.initialization"] = "nimbus.initialization"

    state_id: int | None = Field(description="The newly created state id")
    current_solutions: list[SolutionReferenceResponse] = Field(
        description="The solutions from the current iteration of nimbus."
    )
    saved_solutions: list[SolutionReferenceResponse] = Field(
        description="The best candidate solutions saved by the decision maker."
    )
    all_solutions: list[SolutionReferenceResponse] = Field(
        description="All solutions generated by NIMBUS in all iterations."
    )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class NIMBUSSaveResponse(SQLModel):
    """Payload returned by the NIMBUS save endpoint."""

    # Fixed tag identifying this response type.
    response_type: Literal["nimbus.save"] = "nimbus.save"

    state_id: int | None = Field(description="The id of the newest state")
|
|
100
|
+
|
|
101
|
+
class NIMBUSDeleteSaveResponse(SQLModel):
    """Response of NIMBUS save deletion."""

    # A Literal tag, like the other NIMBUS responses, so the response type is fixed and
    # clients can rely on it.
    response_type: Literal["nimbus.delete_save"] = "nimbus.delete_save"

    # Nullable, but has no default, so it must be supplied.
    message: str | None
|
|
107
|
+
|
|
108
|
+
class NIMBUSFinalizeResponse(SQLModel):
    """Payload returned when the NIMBUS procedure is finalized."""

    # Fixed tag identifying this response type.
    response_type: Literal["nimbus.finalize"] = "nimbus.finalize"

    state_id: int | None = Field(description="The newly created state id")

    # Only the chosen solution is returned, along with the saved and historical collections.
    final_solution: SolutionReferenceResponse = Field(
        description="The final solution. We do not need the other current solutions.",
    )
    saved_solutions: list[SolutionReferenceResponse] = Field(
        description="The best candidate solutions saved by the decision maker.",
    )
    all_solutions: list[SolutionReferenceResponse] = Field(
        description="All solutions generated by NIMBUS in all iterations.",
    )
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class NIMBUSInitializationRequest(SQLModel):
    """Request body for initializing the NIMBUS method."""

    # Context of the request.
    problem_id: int
    session_id: int | None = Field(default=None)
    parent_state_id: int | None = Field(default=None)

    # Optional starting point: either a reference point or a pointer to an existing solution.
    starting_point: ReferencePoint | SolutionInfo | None = Field(default=None, sa_column=Column(JSON))

    # Optional scalarization / solver configuration.
    scalarization_options: dict[str, float | str | bool] | None = Field(default=None, sa_column=Column(JSON))
    solver: str | None = Field(default=None)
    solver_options: dict[str, float | str | bool] | None = Field(default=None, sa_column=Column(JSON))
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class NIMBUSIntermediateSolutionResponse(SQLModel):
    """The response from computing NIMBUS intermediate solutions."""

    # Fixed tag identifying this response type.
    response_type: Literal["nimbus.intermediate"] = "nimbus.intermediate"

    state_id: int | None = Field(description="The newly created state id")

    # Objective values of the two solutions the intermediate points were computed between.
    reference_solution_1: dict[str, float] = Field(
        sa_column=Column(JSON), description="The first solution used when computing intermediate points."
    )
    reference_solution_2: dict[str, float] = Field(
        sa_column=Column(JSON), description="The second solution used when computing intermediate points."
    )

    current_solutions: list[SolutionReferenceResponse] = Field(
        description="The solutions from the current iteration of NIMBUS."
    )
    saved_solutions: list[SolutionReferenceResponse] = Field(
        description="The best candidate solutions saved by the decision maker."
    )
    all_solutions: list[SolutionReferenceResponse] = Field(
        description="All solutions generated by NIMBUS in all iterations."
    )
|