scipplan 0.2.1a0__py2.py3-none-any.whl → 0.2.2a0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- build/lib/scipplan/__init__.py +5 -0
- build/lib/scipplan/config.py +163 -0
- build/lib/scipplan/helpers.py +31 -0
- build/lib/scipplan/parse_model.py +284 -0
- build/lib/scipplan/plan_model.py +339 -0
- build/lib/scipplan/scipplan.py +218 -0
- build/lib/scipplan/variables.py +91 -0
- build/lib/scipplan/zero_crossing.py +29 -0
- scipplan/__init__.py +3 -3
- scipplan/parse_model.py +3 -2
- scipplan/scipplan.py +18 -16
- scipplan/zero_crossing.py +2 -1
- {scipplan-0.2.1a0.dist-info → scipplan-0.2.2a0.dist-info}/METADATA +1 -1
- scipplan-0.2.2a0.dist-info/RECORD +28 -0
- scipplan-0.2.2a0.dist-info/top_level.txt +5 -0
- scipplan-0.2.1a0.dist-info/RECORD +0 -20
- scipplan-0.2.1a0.dist-info/top_level.txt +0 -1
- {scipplan-0.2.1a0.dist-info → scipplan-0.2.2a0.dist-info}/LICENSE +0 -0
- {scipplan-0.2.1a0.dist-info → scipplan-0.2.2a0.dist-info}/WHEEL +0 -0
- {scipplan-0.2.1a0.dist-info → scipplan-0.2.2a0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,339 @@
|
|
1
|
+
from .config import Config
|
2
|
+
from .variables import Variable, VarType
|
3
|
+
from .parse_model import ParseModel as PM, EvalParams
|
4
|
+
from .helpers import list_accessible_files
|
5
|
+
|
6
|
+
import math
|
7
|
+
import os
|
8
|
+
import re
|
9
|
+
|
10
|
+
from pyscipopt.scip import Model
|
11
|
+
from pkg_resources import parse_version
|
12
|
+
from importlib.metadata import version
|
13
|
+
from sympy import Eq, Function, Derivative as dd, Symbol, parse_expr
|
14
|
+
from sympy.solvers.ode.systems import dsolve_system
|
15
|
+
|
16
|
+
|
17
|
+
# The basic expression helpers are available in every supported PySCIPOpt
# release; sin/cos are only exported from PySCIPOpt 4.3.0 onwards, so they are
# imported conditionally and `allow_trig_funcs` records whether they exist.
from pyscipopt import quicksum, exp, log, sqrt

allow_trig_funcs = parse_version(version("pyscipopt")) >= parse_version("4.3.0")
if allow_trig_funcs:
    from pyscipopt import sin, cos
|
23
|
+
|
24
|
+
class PlanModel:
    """Encode one domain instance as a SCIP model over ``config.horizon`` steps.

    On construction the instance's translation file is read and split into
    sections. The constants, pvariables, constraints and reward are then
    encoded into ``self.model``, and every step's ``dt_var`` is bounded to
    ``[0, config.bigM]``.
    """

    def __init__(self, config: Config) -> None:
        self.config: Config = config
        self.var_names: set[str] = set()

        self.model = Model(f"{config.domain}_{config.instance}_{config.horizon}")

        # Translation -> line_num -> horizon -> aux
        self.aux_vars: dict[str, list[list[list]]] = {}
        # Closed-form ODE solutions (function name -> expression string). Only
        # filled in by encode_constraints when config.provide_sols is False;
        # initialised here so the attribute always exists.
        self.ode_functions: dict[str, str] = {}

        self.file_translations = self.read_translations()

        self.constants = self.encode_constants()
        self.variables = self.encode_pvariables()
        self.translations = self.encode_constraints()

        self.rewards = self.encode_reward()

        # Encode bounds on Dt
        for h in range(self.config.horizon):
            dt_var = self.variables[(self.config.dt_var, h)].model_var
            self.model.addCons(dt_var >= 0.0, f"dt_{h}_lower_bound")
            self.model.addCons(dt_var <= self.config.bigM, f"dt_{h}_upper_bound")

    def read_translations(self) -> dict[str, list[str]]:
        """Read the instance translation file and split it into named sections.

        The file is ``solutions_*`` when ``config.provide_sols`` is set and
        ``odes_*`` otherwise. Sections are separated by ``---`` lines. The first
        non-empty line of a section is its name (a trailing ``:`` is dropped),
        and each later non-empty line is appended to that section.

        Returns:
            Mapping of section name to its stripped, non-empty lines.
        """
        file_kind = "solutions" if self.config.provide_sols else "odes"
        translations: dict[str, list[str]] = {}
        new_sec = True
        with open(self.get_file_path(file_kind), encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line == "":
                    continue
                if line == "---":
                    new_sec = True
                elif new_sec:
                    translation = line.removesuffix(":")
                    translations[translation] = []
                    new_sec = False
                else:
                    translations[translation].append(line)

        return translations

    def encode_constants(self) -> dict[str, float]:
        """Parse the ``constants`` section into a name -> float mapping.

        Each line has the form ``name = value``. The value is either a number or
        one of the placeholders ``config_epsilon``, ``config_gap`` or
        ``config_bigM``, which are taken from the config. A ``bigM`` constant is
        always added from the config. Every constant name is registered in
        ``self.var_names``.

        Raises:
            ValueError: if a line has no ``=`` or its value is not a float.
        """
        constants: dict[str, float] = {}
        config_vals = {
            # Since horizon can be incremented without this value being updated, it will be removed for the time being
            # "config_horizon": self.config.horizon,
            "config_epsilon": self.config.epsilon,
            "config_gap": self.config.gap,
            "config_bigM": self.config.bigM,
        }

        for line in self.file_translations["constants"]:
            var, sep, val = line.partition("=")
            if not sep:
                raise ValueError(f"Constants must be of the form 'name = value', please reconfigure: {line}")
            var, val = var.strip(), val.strip()

            val = config_vals.get(val, val)

            try:
                val = float(val)
            except ValueError as err:
                raise ValueError(f"Constants can only be floats, please reconfigure: {line}") from err

            constants[var] = val
            self.var_names.add(var)

        constants["bigM"] = self.config.bigM
        self.var_names.add("bigM")
        return constants

    def encode_pvariables(self) -> dict[tuple[str, int], Variable]:
        """Create a Variable for every constant and declared pvariable.

        Constants and non-global pvariables get one entry per step
        ``0 .. horizon - 1``. State variables get an extra entry at step
        ``horizon`` holding the final state. A ``global`` pvariable is created
        once and the same object is shared by steps ``0 .. horizon``.

        Returns:
            Mapping of ``(name, step)`` to its Variable.
        """
        variables: dict[tuple[str, int], Variable] = {}
        for t in range(self.config.horizon):
            for constant in self.constants:
                variables[(constant, t)] = Variable.create_var(self.model, constant, "constant", t, self.constants)

        for line in self.file_translations["pvariables"]:
            declaration = line.strip()
            if declaration == "":
                continue
            vtype, name = declaration.split(": ")
            vtype, name = vtype.strip(), name.strip()

            self.var_names.add(name)

            if vtype.startswith("global"):
                global_var = Variable.create_var(self.model, name, vtype, "global", self.constants)
                for t in range(self.config.horizon + 1):
                    variables[(name, t)] = global_var
                continue

            for t in range(self.config.horizon):
                variables[(name, t)] = Variable.create_var(self.model, name, vtype, t, self.constants)
            # The variable type depends only on vtype, so inspecting step 0 is
            # enough. State variables need one extra copy for the post-horizon state.
            if self.config.horizon > 0 and variables[(name, 0)].var_type is VarType.STATE:
                variables[(name, self.config.horizon)] = Variable.create_var(
                    self.model, name, vtype, self.config.horizon, self.constants
                )

        return variables

    @staticmethod
    def _substitute_ode_solutions(constraint: str, ode_functions: dict[str, str]) -> str:
        """Replace every whole-word occurrence of an ODE function name with its solution.

        Each solution is wrapped in parentheses so operator precedence is
        preserved (``2*x`` -> ``2*(sol)``). Word boundaries stop a name from
        matching inside a longer identifier (e.g. ``x`` inside ``x_max`` or
        ``exp``). With no solutions the constraint is returned unchanged.
        """
        if not ode_functions:
            return constraint
        pattern = r"\b(?:" + "|".join(re.escape(func_name) for func_name in ode_functions) + r")\b"
        return re.sub(pattern, lambda match: f"({ode_functions[match.group(0)]})", constraint)

    def _encode_constraint_at(
        self,
        translation: str,
        idx: int,
        n_constraints: int,
        constraint: str,
        t: int,
        is_goal: bool = False,
    ) -> None:
        """Parse one constraint at step ``t``, record its aux vars and add it to the model.

        Aux vars are stored at ``self.aux_vars[translation][idx][t]``. Model
        constraints are named ``{translation}_{idx}_{t}_{eqtn_idx}`` so they are
        unique across steps.
        """
        expr_name = f"{translation}_{idx}_{t}"
        exprs = PM(self.get_parser_params(horizon=t, is_goal=is_goal, add_aux_vars=True)).evaluate(
            constraint, horizon=t, expr_name=expr_name
        )

        per_constraint = self.aux_vars.setdefault(translation, [None] * n_constraints)
        if per_constraint[idx] is None:
            per_constraint[idx] = [None] * self.config.horizon
        per_constraint[idx][t] = exprs.aux_vars

        for eqtn_idx, eqtn in enumerate(exprs):
            self.model.addCons(eqtn, f"{expr_name}_{eqtn_idx}")

    def encode_constraints(self) -> dict[str, list[str]]:
        """Encode the initial, instantaneous, temporal, goal and transition constraints.

        Initials are encoded at step 0, goals at step ``horizon - 1`` (with
        states bound to their final values), and every other section at each
        step. When ``config.provide_sols`` is False, the ``odes`` section is
        solved with solve_odes and turned into ``<name>_dash == <solution>``
        transitions. The solutions are also substituted into the temporal
        constraints (the stored strings are updated in place).

        Returns:
            Mapping of section name to its (possibly rewritten) constraint strings.
        """
        translation_names = [
            "initials",
            "instantaneous_constraints",
            "temporal_constraints",
            "goals",
            "odes" if self.config.provide_sols is False else "transitions",
        ]
        translations: dict[str, list[str]] = {
            translation: [line.strip() for line in self.file_translations[translation] if line.strip() != ""]
            for translation in translation_names
        }

        if self.config.provide_sols is False:
            self.ode_functions = self.solve_odes(translations["odes"])
            translations["transitions"] = [
                f"{func_name}_dash == {func}" for func_name, func in self.ode_functions.items()
            ]
            del translations["odes"]

        # Encode constraints into model
        final_step = self.config.horizon - 1
        for translation, constraints in translations.items():
            n_constraints = len(constraints)
            for idx, constraint in enumerate(constraints):
                if self.config.provide_sols is False and translation == "temporal_constraints":
                    constraint = self._substitute_ode_solutions(constraint, self.ode_functions)
                    constraints[idx] = constraint

                if translation == "initials":
                    self._encode_constraint_at(translation, idx, n_constraints, constraint, t=0)
                elif translation == "goals":
                    # horizon - 1 is because the final action time is horizon - 1
                    self._encode_constraint_at(translation, idx, n_constraints, constraint, t=final_step, is_goal=True)
                else:
                    for t in range(self.config.horizon):
                        self._encode_constraint_at(translation, idx, n_constraints, constraint, t=t)

        return translations

    def encode_reward(self):
        """Add one continuous objective variable per step and maximise their sum.

        Each step's variable is constrained to equal the (single) reward
        expression evaluated at that step.

        Returns:
            The list of per-step objective variables.
        """
        objectives = [None] * self.config.horizon
        reward = self.file_translations["reward"][0]
        for t in range(self.config.horizon):
            objectives[t] = self.model.addVar(f"Obj_{t}", vtype="C", lb=None, ub=None)
            # For the sake of similarity the reward is similar to constraint parsing, however, only one reward function is allowed
            exprs = PM(self.get_parser_params(t)).evaluate(reward)
            for expr_idx, expr in enumerate(exprs):
                self.model.addCons(objectives[t] == expr, f"Obj_{t}_{expr_idx}")

        self.model.setObjective(quicksum(objectives), "maximize")

        return objectives

    def get_parser_params(self, horizon: int, is_goal: bool = False, add_aux_vars: bool = False) -> EvalParams:
        """Build parser parameters binding variable names to SCIP model variables at ``horizon``.

        Normally every name is bound to its step-``horizon`` variable, and each
        state ``s`` is additionally exposed as ``s_dash`` bound to its
        step-``horizon + 1`` variable. For goals (``is_goal``), state names are
        bound directly to their step-``horizon + 1`` variable.
        """
        functions = {
            "exp": exp,
            "log": log,
            "sqrt": sqrt,
        }
        if allow_trig_funcs:
            functions["sin"] = sin
            functions["cos"] = cos

        variables = {}
        operators = {}
        for name in self.var_names:
            var = self.variables[(name, horizon)]
            if is_goal:
                if var.var_type is VarType.STATE:
                    var = self.variables[(name, horizon + 1)]
                variables[var.name] = var.model_var
            else:
                variables[var.name] = var.model_var
                if var.var_type is VarType.STATE:
                    next_var = self.variables[(name, horizon + 1)]
                    variables[f"{next_var.name}_dash"] = next_var.model_var

        return EvalParams.as_parser(variables, functions, operators, self.model, add_aux_vars)

    def get_calc_params(self, horizon, dt) -> EvalParams:
        """Build calculator parameters from the current solution at step ``horizon``.

        Constants keep their numeric value. Other variables take their solved
        value from ``self.model.getVal``. States also get ``<name>_dash`` set to
        their next-step solved value, and ``config.dt_var`` is overridden with ``dt``.
        """
        functions = {
            "exp": math.exp,
            "log": math.log,
            "sqrt": math.sqrt,
        }
        if allow_trig_funcs:
            functions["sin"] = math.sin
            functions["cos"] = math.cos

        variables = {}
        operators = {}
        for name in self.var_names:
            var = self.variables[(name, horizon)]
            if var.var_type is VarType.CONSTANT:
                variables[var.name] = var.model_var
            else:
                variables[var.name] = self.model.getVal(var.model_var)

            if var.var_type is VarType.STATE:
                var = self.variables[(name, horizon + 1)]
                variables[f"{var.name}_dash"] = self.model.getVal(var.model_var)

        variables[self.config.dt_var] = dt

        return EvalParams.as_calculator(variables, functions, operators, self.model)

    def get_file_path(self, translation: str) -> str:
        """Locate ``<translation>_<domain>_<instance>.txt``.

        ``./translation`` (relative to the working directory) is searched
        first, then the ``translation`` directory next to this module.

        Raises:
            Exception: if the file is found in neither directory.
        """
        path = f"{translation}_{self.config.domain}_{self.config.instance}.txt"

        usr_files_path = os.path.join("./", "translation")
        if path in list_accessible_files(usr_files_path):
            return os.path.join(usr_files_path, path)

        pkg_files_path = os.path.join(os.path.dirname(__file__), "translation")
        if path in list_accessible_files(pkg_files_path):
            return os.path.join(pkg_files_path, path)

        raise Exception(f"Unknown file name, please enter a configuration for a valid domain instance in translation: {path}")

    def solve_odes(self, ode_system: list[str]) -> dict[str, str]:
        """Solve the ``lhs == rhs`` ODE system symbolically with respect to ``config.dt_var``.

        States become functions of ``dt`` with initial value equal to their own
        name. Constants are substituted by value. Actions and aux variables
        become functions of a dummy symbol so they behave as constants in time.

        Returns:
            Mapping of function name to the string form of its solution, with
            the dummy-symbol suffix stripped.
        """
        dt_var = self.config.dt_var

        dt = Symbol(dt_var)
        # Used to represent constant variables
        temp_var = Symbol("ODES_TEMP_VAR")

        variables = {}
        states = []

        for var_name in self.var_names:
            var = self.variables[(var_name, 0)]
            if var.var_type is VarType.STATE:
                states.append(var.name)
                variables[var.name] = Function(var.name)(dt)
            elif var.var_type is VarType.CONSTANT:
                variables[var.name] = self.constants[var.name]
            else:  # the variable is an action or aux variable which is encoded as a function of some unused variable as workaround to not being able to use symbols for constants
                variables[var.name] = Function(var.name)(temp_var)

        variables[dt_var] = dt

        # NOTE(review): sympy's parse_expr evaluates its input with eval(), so
        # the ODE translation file must come from a trusted source.
        system = []
        for eqtn in ode_system:
            lhs, rhs = eqtn.split("==")
            lhs = parse_expr(lhs.strip(), local_dict=variables | {"dd": dd})
            rhs = parse_expr(rhs.strip(), local_dict=variables | {"dd": dd})
            system.append(Eq(lhs, rhs))
        results = dsolve_system(system, ics={variables[state].subs(dt, 0): state for state in states})

        functions: dict[str, str] = {}
        for eqtn in results[0]:
            new_eqtn = eqtn.doit()
            func_name = new_eqtn.lhs.name.replace(f"({temp_var.name})", "").replace(f"({self.config.dt_var})", "_dash")
            functions[func_name] = str(new_eqtn.rhs).replace(f"({temp_var.name})", "").replace(f"({self.config.dt_var})", "_dash")

        return functions
|
338
|
+
|
339
|
+
|
@@ -0,0 +1,218 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import time
|
4
|
+
|
5
|
+
from .config import Config
|
6
|
+
from .plan_model import PlanModel
|
7
|
+
from .variables import VarType
|
8
|
+
from .zero_crossing import ZeroCrossing
|
9
|
+
from .parse_model import ParseModel as PM
|
10
|
+
from .helpers import InfeasibilityError, iterate, write_to_csv
|
11
|
+
|
12
|
+
from importlib.metadata import version
|
13
|
+
from pyscipopt.scip import Model
|
14
|
+
|
15
|
+
class SCIPPlan:
    """
    SCIPPlan is a planner which optimises mixed integer non-linear programming problems over hybrid domains

    In order to use SCIPPlan you must pass in as input a Config object which contains the configuration of the problem.
    Then you may either use the optimize or solve methods.

    The optimize method attempts to optimise the problem for the provided horizon.
    If there are no feasible solutions, then the optimize method will raise InfeasibilityError.

    The solve method attempts to solve the problem starting with the provided horizon.
    If there are no feasible solutions for the current horizon then the configs horizon will be incremented.
    After which the solve method will attempt to optimize the problem again.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        self.plan = PlanModel(self.config)
        self.scip_model = self.plan.model

        self.scip_model.setRealParam("limits/gap", self.config.gap)

        if config.show_output is False:
            self.scip_model.hideOutput()

        self.results_table = []
        self.new_constraints = []

        self.state_actions = []

    def optimize(self):
        """Solve the model, adding zero-crossing constraints until none are violated.

        After every solve, the temporal constraints are checked inside each Dt
        interval (check_violated_constraints). For the first violation found,
        the violated constraint is re-encoded with ``dt_var`` scaled by the
        crossing's coefficient, and the model is solved again.

        Raises:
            InfeasibilityError: if SCIP returns no solutions.
        """
        iteration = 0

        while True:
            self.scip_model.optimize()

            if len(self.scip_model.getSols()) == 0:
                raise InfeasibilityError

            zero_cross = self.check_violated_constraints(iteration)

            self.save_values(iteration)

            if zero_cross.is_violated is False:
                return None

            self.new_constraints.append({
                "interval_start": zero_cross.start,
                "interval_end": zero_cross.end,
                "dt_interval": zero_cross.dt_interval,
                "zero_crossing_coefficient": zero_cross.coef,
                "new_dt_val": zero_cross.new_dt_val,
                "horizon": zero_cross.horizon,
                "iteration": zero_cross.iteration,
                "constraint_idx": zero_cross.constraint_idx
            })

            if self.config.show_output is True:
                print("\n\n")
                print("New Constraints: \n")
                for new_constraint in self.new_constraints:
                    print(new_constraint, end="\n\n")
                print("\n\n")

            self.scip_model.freeTransform()

            t = zero_cross.horizon
            idx = zero_cross.constraint_idx
            constraint = self.plan.translations["temporal_constraints"][idx]
            aux_vars = self.plan.aux_vars["temporal_constraints"][idx][t]
            # Only add aux vars if there are no aux vars added for the specific constraint
            params = self.plan.get_parser_params(horizon=t, add_aux_vars=aux_vars is None)
            # Re-encode the constraint at the crossing point inside the interval.
            params.variables[self.config.dt_var] *= zero_cross.coef
            exprs = PM(params).evaluate(constraint, aux_vars=aux_vars)

            for eqtn_idx, eqtn in enumerate(exprs):
                self.plan.model.addCons(eqtn, f"{constraint}_{idx}_{eqtn_idx}")

            iteration += 1

    def check_violated_constraints(self, iteration: int) -> ZeroCrossing:
        """Return the first temporal-constraint violation found in the current solution.

        For every step ``h`` and temporal constraint, the constraint is
        evaluated at each time point yielded by ``iterate(0, dt, epsilon)``. The
        first contiguous run of points where some equation evaluates to False
        is reported as a ZeroCrossing. ``ZeroCrossing(is_violated=False)`` is
        returned when nothing is violated.
        """
        epsilon = self.config.epsilon
        cross_interval = [-1.0 * epsilon, -1.0 * epsilon]

        for h in range(self.config.horizon):
            dt = self.scip_model.getVal(self.plan.variables[(self.config.dt_var, h)].model_var)

            for idx, constraint in enumerate(self.plan.translations["temporal_constraints"]):
                is_violated = False

                # Named `t_point` so the `time` module is not shadowed.
                for t_point in iterate(0, dt, epsilon):
                    exprs = PM(self.plan.get_calc_params(horizon=h, dt=t_point)).evaluate(constraint)

                    for constraint_eval in exprs:
                        if constraint_eval is False:
                            if not is_violated:
                                # Set interval start when first part of zero crossing is found
                                is_violated = True
                                cross_interval[0] = t_point

                            # Keep updating end point until end of zero crossing or end of dt interval
                            cross_interval[1] = t_point

                        if is_violated and (constraint_eval is True or t_point + epsilon > dt):
                            return ZeroCrossing(
                                is_violated=True,
                                horizon=h,
                                iteration=iteration,
                                start=cross_interval[0],
                                end=cross_interval[1],
                                dt_interval=dt,
                                constraint_idx=idx,
                            )

                # If `iterate` stopped before the end-of-interval check fired,
                # a violation lasting to the last sampled point must still be reported.
                if is_violated:
                    return ZeroCrossing(
                        is_violated=True,
                        horizon=h,
                        iteration=iteration,
                        start=cross_interval[0],
                        end=cross_interval[1],
                        dt_interval=dt,
                        constraint_idx=idx,
                    )

        return ZeroCrossing(is_violated=False)

    @classmethod
    def solve(cls, config: Config) -> tuple[SCIPPlan, float]:
        """Optimise, incrementing ``config.horizon`` after each infeasible attempt.

        If ``config.get_defaults().get("horizon") is False``, infeasibility is
        re-raised instead of incrementing the horizon.

        Returns:
            The solved planner and the total wall-clock time across all attempts.

        Raises:
            InfeasibilityError: when incrementing the horizon is not allowed.
        """
        # Time total solve time including incrementing horizon
        start_time = time.time()
        while True:
            model = cls(config)
            try:
                print(f"Encoding the problem over horizon h={config.horizon}.")
                print("Solving the problem.")
                model.optimize()

                solve_time = (time.time() - start_time)
                print("Problem solved. \n")
                return model, solve_time

            except InfeasibilityError:
                if config.get_defaults().get("horizon") is False:
                    print(f"Horizon of h={model.config.horizon} is infeasible.")

                    solve_time = (time.time() - start_time)
                    print(f"Total time: {solve_time:.3f}")

                    raise InfeasibilityError

                print(f"Horizon of h={model.config.horizon} is infeasible, incrementing to h={model.config.horizon + 1}.")
                config.increment_horizon()
                if config.show_output is True:
                    print(f"Horizon Time: {(time.time() - start_time): .3f} seconds.")

    def save_values(self, iteration: int):
        """Append each variable's ``to_dict()`` row, tagged with ``iteration``, to ``results_table``."""
        self.results_table.extend(
            var.to_dict() | {"iteration": iteration} for var in self.plan.variables.values()
        )
|
177
|
+
|
178
|
+
|
179
|
+
def main():
    """Command-line entry point: load the config, solve it and print the plan.

    Prints the SCIP/PySCIPOpt versions and the config, then solves with
    SCIPPlan.solve. Returns None silently if the problem is infeasible. When
    ``config.save_sols`` is set, the new constraints and results are written
    to CSV. Finally prints every action value and the ``config.dt_var`` value
    per step, followed by the total reward and time.
    """
    print(f"SCIP Version: {Model().version()}")
    print(f"PySCIPOpt Version: {version('pyscipopt')}\n")
    config = Config.get_config()
    print(config)

    try:
        plan, solve_time = SCIPPlan.solve(config)
    except InfeasibilityError:
        return None

    if config.save_sols is True:
        write_to_csv("new_constraints", plan.new_constraints, config)
        write_to_csv("results", plan.results_table, config)
        print("Solutions saved: \n")

    print("Plan: ")

    dt_var = config.dt_var
    # Action variable names, sorted so the printed plan has a stable order.
    action_names = sorted(
        var_name for var_name in plan.plan.var_names
        if plan.plan.variables[(var_name, 0)].var_type is VarType.ACTION
    )

    for step in range(plan.config.horizon):
        for action_name in action_names:
            if action_name == dt_var:
                continue
            print(f"{action_name} at step {step} by value {plan.scip_model.getVal(plan.plan.variables[(action_name, step)].model_var):.3f}.")

        # Look up the configured dt variable rather than a hard-coded "Dt".
        print(f"{dt_var} at step {step} by value {plan.scip_model.getVal(plan.plan.variables[(dt_var, step)].model_var):.3f}. \n")

    print(f"Total reward: {(plan.scip_model.getObjVal()):.3f}")
    print(f"Total time: {solve_time:.3f}")


if __name__ == "__main__":
    main()
|
@@ -0,0 +1,91 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from enum import Enum
|
5
|
+
from typing import Union
|
6
|
+
|
7
|
+
from pyscipopt.scip import Model, Variable as SCIPVariable
|
8
|
+
|
9
|
+
@dataclass
class Variable:
    """A planning variable paired with its SCIP model variable, or its value for constants."""

    name: str
    var_type: VarType
    # None for constants, whose value type does not matter to the model.
    val_type: Union[ValType, None]
    # Time step index, or the string "global" for variables shared by every step.
    time: Union[int, str]
    model: Model
    # SCIP variable for decision variables; the numeric value for constants.
    model_var: Union[SCIPVariable, float]

    @classmethod
    def create_var(cls, model: Model, name: str, vtype: str, time: int, const_vals: dict[str, float]) -> Variable:
        """Create a Variable from a declared type string such as ``state_continuous``.

        The variable role is chosen by the first of ``action`` / ``state`` /
        ``auxiliary`` / ``constant`` found as a substring of ``vtype``. The
        value type is chosen by ``continuous`` / ``integer`` / ``boolean``.
        Constants take their value from ``const_vals[name]``. Every other
        variable becomes a new SCIP variable named ``{name}_{time}`` with no bounds.

        Raises:
            Exception: if the role or value type cannot be recognised in ``vtype``.
        """
        if "action" in vtype:
            var_type = VarType.ACTION
        elif "state" in vtype:
            var_type = VarType.STATE
        elif "auxiliary" in vtype:
            var_type = VarType.AUX
        elif "constant" in vtype:
            var_type = VarType.CONSTANT
        else:  # var type isn't recognised
            raise Exception(f"Unknown variable type: {vtype}")

        if "continuous" in vtype:
            val_type = ValType.CONTINUOUS
        elif "integer" in vtype:
            val_type = ValType.INTEGER
        elif "boolean" in vtype:
            val_type = ValType.BOOLEAN
        elif var_type is VarType.CONSTANT:
            # Special case for constants as the value type doesn't matter for the model as it is numeric
            val_type = None
        else:  # val type isn't recognised
            raise Exception(f"Unknown value type: {vtype}")

        if var_type is VarType.CONSTANT:
            model_var = const_vals[name]
        else:
            model_var = model.addVar(name=f"{name}_{time}", vtype=val_type.value, lb=None, ub=None)

        return cls(
            name=name,
            var_type=var_type,
            val_type=val_type,
            time=time,
            model=model,
            model_var=model_var,
        )

    def to_dict(self):
        """Return a flat row (name, types, horizon, value) describing this variable.

        For non-constants the value is read from the model's current solution.
        It is None if ``getVal`` raises a Warning.
        """
        if self.var_type is VarType.CONSTANT:
            var_val = self.model_var
            val_type = None
        else:
            val_type = self.val_type.name
            try:
                var_val = self.model.getVal(self.model_var)
            except Warning:
                # NOTE(review): presumably PySCIPOpt signals a missing solution value this way; confirm.
                var_val = None

        return {
            "name": self.name,
            "variable_type": self.var_type.name,
            "value_type": val_type,
            "horizon": self.time,
            "variable_value": var_val
        }
|
79
|
+
|
80
|
+
|
81
|
+
|
82
|
+
class VarType(Enum):
    """Role of a planning variable.

    Variable.create_var picks the member whose value appears as a substring
    of the declared type string (e.g. ``state_continuous`` -> STATE).
    """

    ACTION = "action"
    STATE = "state"
    AUX = "auxiliary"
    CONSTANT = "constant"
|
87
|
+
|
88
|
+
class ValType(Enum):
    """Value domain of a decision variable.

    Each member's value is the SCIP ``vtype`` code passed to ``Model.addVar``.
    """

    CONTINUOUS = "C"
    INTEGER = "I"
    BOOLEAN = "B"
|
@@ -0,0 +1,29 @@
|
|
1
|
+
from dataclasses import field, dataclass
|
2
|
+
from textwrap import dedent
|
3
|
+
|
4
|
+
@dataclass
class ZeroCrossing:
    """A temporal-constraint violation detected inside one step's Dt interval.

    When ``is_violated`` is True, ``start`` and ``end`` bound the violating
    time span within step ``horizon`` and ``dt_interval`` is that step's
    duration. ``coef`` is the midpoint of the span as a fraction of
    ``dt_interval``, and ``new_dt_val`` is the same midpoint in absolute time.
    """

    is_violated: bool = field(repr=False)
    iteration: int = None
    horizon: int = None
    start: float = None
    end: float = None
    dt_interval: float = None
    coef: float = field(init=False, default=None)
    new_dt_val: float = field(init=False, default=None)
    constraint_idx: int = None

    def __post_init__(self):
        """Validate a reported violation and derive ``coef`` / ``new_dt_val``.

        Raises:
            ValueError: if a violation lacks ``start``, ``end`` or ``dt_interval``.
        """
        if self.is_violated is not True:
            return

        if self.start is None or self.end is None or self.dt_interval is None:
            raise ValueError(
                dedent(f"""
                Incorrect input values, start, end and dt_interval have to be specified when zero crossing exists

                {self.start = }
                {self.end = }
                {self.dt_interval = }
                """))

        avg_interval = (self.start + self.end) / 2.0
        if self.dt_interval == 0:
            # A zero-length interval can only be violated at its start; avoid dividing by zero.
            self.coef = 0.0
        else:
            self.coef = avg_interval / self.dt_interval
        self.new_dt_val = self.coef * self.dt_interval
|
scipplan/__init__.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
|
-
# Package metadata for the 0.2.2a0 release. The span previously held truncated
# removed-line fragments (unterminated string literals) plus duplicated assignments.
__version__ = "0.2.2alpha0"
# Announce the package version whenever the package is imported.
print(f"SCIPPlan Version: {__version__}")
__release__ = "v0.2.2a0"
__author__ = "Ari Gestetner, Buser Say"
__email__ = "ari.gestetner@monash.edu, buser.say@monash.edu"
|