scipplan 0.2.1a0__py2.py3-none-any.whl → 0.2.2a0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,339 @@
1
+ from .config import Config
2
+ from .variables import Variable, VarType
3
+ from .parse_model import ParseModel as PM, EvalParams
4
+ from .helpers import list_accessible_files
5
+
6
+ import math
7
+ import os
8
+ import re
9
+
10
+ from pyscipopt.scip import Model
11
+ from pkg_resources import parse_version
12
+ from importlib.metadata import version
13
+ from sympy import Eq, Function, Derivative as dd, Symbol, parse_expr
14
+ from sympy.solvers.ode.systems import dsolve_system
15
+
16
+
17
# sin/cos are only imported for PySCIPOpt >= 4.3.0 (presumably the first release that
# provides them); the flag tells the parser builders whether to expose them.
allow_trig_funcs = parse_version(version("pyscipopt")) >= parse_version("4.3.0")
if allow_trig_funcs:
    from pyscipopt import quicksum, exp, log, sqrt, sin, cos
else:
    from pyscipopt import quicksum, exp, log, sqrt
23
+
24
class PlanModel:
    """Translate a domain/instance description into a SCIP model.

    Construction reads the translation file selected by ``config`` and encodes,
    in order: constants, p-variables, constraints, the reward/objective and
    bounds ``0 <= Dt_h <= bigM`` for every step ``h``. The resulting pyscipopt
    ``Model`` is available as ``self.model``.
    """

    def __init__(self, config: Config) -> None:
        self.config: Config = config
        self.var_names: set[str] = set()

        self.model = Model(f"{config.domain}_{config.instance}_{config.horizon}")

        # Translation -> line_num -> horizon -> aux
        self.aux_vars: dict[str, list[list[list]]] = {}

        self.file_translations = self.read_translations()

        self.constants = self.encode_constants()
        self.variables = self.encode_pvariables()
        self.translations = self.encode_constraints()

        self.rewards = self.encode_reward()

        # Encode bounds on Dt
        for h in range(self.config.horizon):
            dt_var = self.variables[(self.config.dt_var, h)].model_var
            self.model.addCons(dt_var >= 0.0, f"dt_{h}_lower_bound")
            self.model.addCons(dt_var <= self.config.bigM, f"dt_{h}_upper_bound")

    def read_translations(self) -> dict[str, list[str]]:
        """Read the translation file into ``{section_name: [non-empty lines]}``.

        Sections are separated by ``---``; the first non-empty line after a
        separator (or at the top of the file) is the section name, with any
        trailing ``:`` removed. The "solutions" file is read when
        ``config.provide_sols`` is truthy, otherwise the "odes" file.
        """
        file_path = self.get_file_path("solutions" if self.config.provide_sols else "odes")
        translations: dict[str, list[str]] = {}
        new_sec = True
        with open(file_path, encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if line == "":
                    continue
                if line == "---":
                    new_sec = True
                elif new_sec:
                    translation = line.removesuffix(":")
                    translations[translation] = []
                    new_sec = False
                else:
                    translations[translation].append(line)
        return translations

    def encode_constants(self) -> dict[str, float]:
        """Parse ``name = value`` lines of the "constants" section into floats.

        A value may also name a config entry (``config_epsilon``, ``config_gap``,
        ``config_bigM``). ``bigM`` is always added from the config. Every
        constant name is registered in ``self.var_names``.

        Raises:
            ValueError: if a line is not ``name = <float or config key>``.
        """
        constants = {}
        config_vals = {
            # Since horizon can be incremented without this value being updated, it will be removed for the time being
            # "config_horizon": self.config.horizon,
            "config_epsilon": self.config.epsilon,
            "config_gap": self.config.gap,
            "config_bigM": self.config.bigM,
        }

        for line in self.file_translations["constants"]:
            # partition tolerates a missing "=" (the float conversion below then reports it)
            var, _, val = line.partition("=")
            var, val = var.strip(), val.strip()

            val = config_vals.get(val, val)

            try:
                val = float(val)
            except ValueError as err:
                raise ValueError(f"Constants can only be floats, please reconfigure: {line}") from err

            constants[var] = val
            self.var_names.add(var)

        constants["bigM"] = self.config.bigM
        self.var_names.add("bigM")
        return constants

    def encode_pvariables(self) -> dict[tuple[str, int], Variable]:
        """Create model variables keyed by ``(name, step)``.

        Constants get one entry per step in ``range(horizon)``. A "global"
        declaration creates one variable shared by steps ``0..horizon``. Other
        declarations create one variable per step; state variables also get an
        extra variable at step ``horizon`` (the state after the last step).
        """
        variables: dict[tuple[str, int], Variable] = {}
        horizon = self.config.horizon

        for t in range(horizon):
            for constant in self.constants:
                variables[(constant, t)] = Variable.create_var(self.model, constant, "constant", t, self.constants)

        for line in self.file_translations["pvariables"]:
            declaration = line.strip()
            if declaration == "":
                continue
            vtype, name = declaration.split(": ")
            vtype, name = vtype.strip(), name.strip()

            self.var_names.add(name)

            if vtype.startswith("global"):
                shared = Variable.create_var(self.model, name, vtype, "global", self.constants)
                for t in range(horizon + 1):
                    variables[(name, t)] = shared
                continue

            last_var = None
            for t in range(horizon):
                last_var = Variable.create_var(self.model, name, vtype, t, self.constants)
                variables[(name, t)] = last_var
            # States need one more copy for the state reached after the final step.
            if last_var is not None and last_var.var_type is VarType.STATE:
                variables[(name, horizon)] = Variable.create_var(self.model, name, vtype, horizon, self.constants)

        return variables

    def encode_constraints(self) -> dict[str, list[str]]:
        """Encode initials, instantaneous/temporal constraints, goals and transitions.

        When ``config.provide_sols`` is False the "odes" section is solved
        symbolically (see ``solve_odes``) and turned into ``<state>_dash == ...``
        transitions; each state name in a temporal constraint is replaced by its
        closed-form solution. Initials are encoded at step 0, goals at step
        ``horizon - 1`` against the final state, everything else at every step.

        Returns:
            The (possibly substituted) constraint strings per translation name.
        """
        translation_names = [
            "initials",
            "instantaneous_constraints",
            "temporal_constraints",
            "goals",
            "odes" if self.config.provide_sols is False else "transitions",
        ]
        translations: dict[str, list[str]] = {}
        for translation in translation_names:
            stripped = (line.strip() for line in self.file_translations[translation])
            # If line is empty don't append
            translations[translation] = [expr for expr in stripped if expr != ""]

        if self.config.provide_sols is False:
            self.ode_functions = self.solve_odes(translations["odes"])
            translations["transitions"] = [
                f"{func_name}_dash == {func}" for func_name, func in self.ode_functions.items()
            ]
            del translations["odes"]
            translations["temporal_constraints"] = [
                self._substitute_ode_solutions(constraint)
                for constraint in translations["temporal_constraints"]
            ]

        # Encode constraints into model
        for translation, constraints in translations.items():
            if translation == "initials":
                steps, is_goal = [0], False
            elif translation == "goals":
                # horizon - 1 is because the final action time is horizon - 1
                steps, is_goal = [self.config.horizon - 1], True
            else:
                steps, is_goal = range(self.config.horizon), False

            for idx, constraint in enumerate(constraints):
                for t in steps:
                    params = self.get_parser_params(horizon=t, is_goal=is_goal, add_aux_vars=True)
                    exprs = PM(params).evaluate(constraint, horizon=t, expr_name=f"{translation}_{idx}_{t}")
                    self._store_aux_vars(translation, idx, len(constraints), t, exprs.aux_vars)
                    for eqtn_idx, eqtn in enumerate(exprs):
                        self.model.addCons(eqtn, f"{translation}_{idx}_{eqtn_idx}")

        return translations

    def _substitute_ode_solutions(self, constraint: str) -> str:
        """Replace every whole-word state name in ``constraint`` with its parenthesised ODE solution."""
        if not self.ode_functions:
            return constraint
        # Longest names first so alternation prefers e.g. "xy" over "x"; word
        # boundaries keep "x" from matching inside "x_dash" or "exp".
        names = sorted(self.ode_functions, key=len, reverse=True)
        pattern = r"\b(?:" + "|".join(re.escape(name) for name in names) + r")\b"
        # Parentheses preserve operator precedence, e.g. "2*x" -> "2*(v*Dt + x)".
        return re.sub(pattern, lambda match: f"({self.ode_functions[match.group(0)]})", constraint)

    def _store_aux_vars(self, translation: str, idx: int, n_constraints: int, t: int, aux_vars) -> None:
        """Record ``aux_vars`` at ``self.aux_vars[translation][idx][t]``, creating the nested lists lazily."""
        per_constraint = self.aux_vars.setdefault(translation, [None] * n_constraints)
        if per_constraint[idx] is None:
            per_constraint[idx] = [None] * self.config.horizon
        per_constraint[idx][t] = aux_vars

    def encode_reward(self):
        """Add one objective variable per step tied to the reward and maximise their sum.

        Only the first line of the "reward" section is used.
        """
        reward = self.file_translations["reward"][0]
        objectives = []
        for t in range(self.config.horizon):
            objective = self.model.addVar(f"Obj_{t}", vtype="C", lb=None, ub=None)
            objectives.append(objective)
            # For the sake of similarity the reward is similar to constraint parsing, however, only one reward function is allowed
            exprs = PM(self.get_parser_params(t)).evaluate(reward)
            for expr_idx, expr in enumerate(exprs):
                self.model.addCons(objective == expr, f"Obj_{t}_{expr_idx}")

        self.model.setObjective(quicksum(objectives), "maximize")

        return objectives

    def get_parser_params(self, horizon: int, is_goal: bool = False, add_aux_vars: bool = False) -> EvalParams:
        """Build parser parameters that map names to SCIP variables at step ``horizon``.

        Normally each state ``s`` maps to step ``horizon`` and ``s_dash`` to step
        ``horizon + 1``. For goals, ``s`` itself maps to step ``horizon + 1``.
        """
        functions = {
            "exp": exp,
            "log": log,
            "sqrt": sqrt,
        }
        if allow_trig_funcs:
            functions["sin"] = sin
            functions["cos"] = cos

        variables = {}
        operators = {}
        for name in self.var_names:
            var = self.variables[(name, horizon)]
            if is_goal:
                if var.var_type is VarType.STATE:
                    var = self.variables[(name, horizon + 1)]
                variables[var.name] = var.model_var
            else:
                variables[var.name] = var.model_var
                if var.var_type is VarType.STATE:
                    next_var = self.variables[(name, horizon + 1)]
                    variables[f"{next_var.name}_dash"] = next_var.model_var

        return EvalParams.as_parser(variables, functions, operators, self.model, add_aux_vars)

    def get_calc_params(self, horizon, dt) -> EvalParams:
        """Build calculator parameters from the current solution values at step ``horizon``.

        Constants map to their numeric value, other variables to ``getVal``;
        the Dt variable is overridden with ``dt``.
        """
        functions = {
            "exp": math.exp,
            "log": math.log,
            "sqrt": math.sqrt,
        }
        if allow_trig_funcs:
            functions["sin"] = math.sin
            functions["cos"] = math.cos

        variables = {}
        operators = {}
        for name in self.var_names:
            var = self.variables[(name, horizon)]
            if var.var_type is VarType.CONSTANT:
                variables[var.name] = var.model_var
            else:
                variables[var.name] = self.model.getVal(var.model_var)

            if var.var_type is VarType.STATE:
                next_var = self.variables[(name, horizon + 1)]
                variables[f"{next_var.name}_dash"] = self.model.getVal(next_var.model_var)

        variables[self.config.dt_var] = dt

        return EvalParams.as_calculator(variables, functions, operators, self.model)

    def get_file_path(self, translation: str) -> str:
        """Locate ``<translation>_<domain>_<instance>.txt``.

        ``./translation`` is searched first, then the package's bundled
        ``translation`` directory.

        Raises:
            FileNotFoundError: if neither directory contains the file.
        """
        file_name = f"{translation}_{self.config.domain}_{self.config.instance}.txt"
        search_dirs = (
            os.path.join("./", "translation"),
            os.path.join(os.path.dirname(__file__), "translation"),
        )
        for directory in search_dirs:
            if file_name in list_accessible_files(directory):
                return os.path.join(directory, file_name)
        raise FileNotFoundError(
            f"Unknown file name, please enter a configuration for a valid domain instance in translation: {file_name}"
        )

    def solve_odes(self, ode_system: list[str]) -> dict[str, str]:
        """Solve the ODE system symbolically over Dt and return ``{state: solution_str}``.

        States become functions of Dt with initial condition ``state(0) = state``.
        Actions/auxiliaries are wrapped as functions of a dummy symbol so sympy
        treats them as constant over the interval; constants are substituted by value.
        """
        dt_var = self.config.dt_var

        dt = Symbol(dt_var)
        # Used to represent constant variables
        temp_var = Symbol("ODES_TEMP_VAR")

        variables = {}
        states = []

        for var_name in self.var_names:
            var = self.variables[(var_name, 0)]
            if var.var_type is VarType.STATE:
                states.append(var.name)
                variables[var.name] = Function(var.name)(dt)
            elif var.var_type is VarType.CONSTANT:
                variables[var.name] = self.constants[var.name]
            else:  # the variable is an action or aux variable which is encoded as a function of some unused variable as workaround to not being able to use symbols for constants
                variables[var.name] = Function(var.name)(temp_var)

        variables[dt_var] = dt

        system = []
        for eqtn in ode_system:
            lhs, rhs = eqtn.split("==")
            lhs = parse_expr(lhs.strip(), local_dict=variables | {"dd": dd})
            rhs = parse_expr(rhs.strip(), local_dict=variables | {"dd": dd})
            system.append(Eq(lhs, rhs))
        results = dsolve_system(system, ics={variables[state].subs(dt, 0): state for state in states})

        functions: dict[str, str] = {}
        for eqtn in results[0]:
            new_eqtn = eqtn.doit()
            func_name = new_eqtn.lhs.name.replace(f"({temp_var.name})", "").replace(f"({self.config.dt_var})", "_dash")
            functions[func_name] = str(new_eqtn.rhs).replace(f"({temp_var.name})", "").replace(f"({self.config.dt_var})", "_dash")

        return functions
338
+
339
+
@@ -0,0 +1,218 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+
5
+ from .config import Config
6
+ from .plan_model import PlanModel
7
+ from .variables import VarType
8
+ from .zero_crossing import ZeroCrossing
9
+ from .parse_model import ParseModel as PM
10
+ from .helpers import InfeasibilityError, iterate, write_to_csv
11
+
12
+ from importlib.metadata import version
13
+ from pyscipopt.scip import Model
14
+
15
class SCIPPlan:
    """
    SCIPPlan is a planner which optimises mixed integer non-linear programming problems over hybrid domains.

    In order to use SCIPPlan you must pass in as input a Config object which contains the configuration of the problem.
    Then you may either use the optimize or solve methods.

    The optimize method attempts to optimise the problem for the provided horizon.
    If there are no feasible solutions, then the optimize method will raise InfeasibilityError.

    The solve method attempts to solve the problem starting with the provided horizon.
    If there are no feasible solutions for the current horizon then the config's horizon will be incremented,
    after which the solve method will attempt to optimize the problem again.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        self.plan = PlanModel(self.config)
        self.scip_model = self.plan.model

        self.scip_model.setRealParam("limits/gap", self.config.gap)

        if config.show_output is False:
            self.scip_model.hideOutput()

        # One row per variable per optimisation round (filled by save_values).
        self.results_table = []
        # One record per constraint added to cut off a detected zero crossing.
        self.new_constraints = []

        self.state_actions = []

    def optimize(self):
        """Solve, check temporal constraints inside each Dt interval, and re-solve until none is violated.

        On each violation the offending temporal constraint is re-encoded for
        that step with Dt scaled by the zero crossing coefficient and added to
        the model.

        Raises:
            InfeasibilityError: if SCIP returns no solution.
        """
        iteration = 0

        while True:
            self.scip_model.optimize()

            if len(self.scip_model.getSols()) == 0:
                raise InfeasibilityError

            zero_cross = self.check_violated_constraints(iteration)

            self.save_values(iteration)

            if zero_cross.is_violated is False:
                return None

            self.new_constraints.append({
                "interval_start": zero_cross.start,
                "interval_end": zero_cross.end,
                "dt_interval": zero_cross.dt_interval,
                "zero_crossing_coefficient": zero_cross.coef,
                "new_dt_val": zero_cross.new_dt_val,
                "horizon": zero_cross.horizon,
                "iteration": zero_cross.iteration,
                "constraint_idx": zero_cross.constraint_idx,
            })

            if self.config.show_output is True:
                print("\n\n")
                print("New Constraints: \n")
                for new_constraint in self.new_constraints:
                    print(new_constraint, end="\n\n")
                print("\n\n")

            # The model must leave the transformed stage before constraints can be added.
            self.scip_model.freeTransform()

            t = zero_cross.horizon
            idx = zero_cross.constraint_idx
            constraint = self.plan.translations["temporal_constraints"][idx]
            aux_vars = self.plan.aux_vars["temporal_constraints"][idx][t]
            # Only add aux vars if there are no aux vars added for the specific constraint
            params = self.plan.get_parser_params(horizon=t, add_aux_vars=aux_vars is None)
            params.variables[self.config.dt_var] *= zero_cross.coef
            exprs = PM(params).evaluate(constraint, aux_vars=aux_vars)

            for eqtn_idx, eqtn in enumerate(exprs):
                self.plan.model.addCons(eqtn, f"{constraint}_{idx}_{eqtn_idx}")

            iteration += 1

    def check_violated_constraints(self, iteration: int) -> ZeroCrossing:
        """Scan each step's Dt interval in epsilon increments for temporal constraint violations.

        Returns the first violated interval found as a ZeroCrossing, or
        ``ZeroCrossing(is_violated=False)`` when every constraint holds.
        """
        epsilon = self.config.epsilon
        cross_interval = [-1.0 * epsilon, -1.0 * epsilon]
        temporal_constraints = self.plan.translations["temporal_constraints"]

        for h in range(self.config.horizon):
            dt = self.scip_model.getVal(self.plan.variables[(self.config.dt_var, h)].model_var)

            for idx, constraint in enumerate(temporal_constraints):
                is_violated = False

                # "elapsed" avoids shadowing the imported time module.
                for elapsed in iterate(0, dt, epsilon):
                    pm = PM(self.plan.get_calc_params(horizon=h, dt=elapsed))
                    exprs = pm.evaluate(constraint)

                    for constraint_eval in exprs:
                        if constraint_eval is False:
                            if not is_violated:
                                # Set interval start when first part of zero crossing is found
                                is_violated = True
                                cross_interval[0] = elapsed

                            # Keep updating end point until end of zero crossing or end of dt interval
                            cross_interval[1] = elapsed

                        if is_violated and (constraint_eval is True or elapsed + epsilon > dt):
                            return ZeroCrossing(
                                is_violated=True,
                                horizon=h,
                                iteration=iteration,
                                start=cross_interval[0],
                                end=cross_interval[1],
                                dt_interval=dt,
                                constraint_idx=idx,
                            )

        return ZeroCrossing(is_violated=False)

    @classmethod
    def solve(cls, config: Config) -> tuple[SCIPPlan, float]:
        """Optimise, incrementing the horizon after each infeasible attempt.

        Returns:
            The solved planner and the total wall-clock time in seconds.

        Raises:
            InfeasibilityError: re-raised when ``config.get_defaults().get("horizon") is False``.
        """
        # Time total solve time including incrementing horizon
        start_time = time.time()
        while True:
            model = cls(config)
            try:
                print(f"Encoding the problem over horizon h={config.horizon}.")
                print("Solving the problem.")
                model.optimize()
            except InfeasibilityError:
                # NOTE(review): presumably get_defaults() reports whether the horizon was left at its default — confirm in Config.
                if config.get_defaults().get("horizon") is False:
                    print(f"Horizon of h={model.config.horizon} is infeasible.")

                    solve_time = (time.time() - start_time)
                    print(f"Total time: {solve_time:.3f}")

                    raise

                print(f"Horizon of h={model.config.horizon} is infeasible, incrementing to h={model.config.horizon + 1}.")
                config.increment_horizon()
                if config.show_output is True:
                    print(f"Horizon Time: {(time.time() - start_time): .3f} seconds.")
                continue

            solve_time = (time.time() - start_time)
            print("Problem solved. \n")
            return model, solve_time

    def save_values(self, iteration: int):
        """Append every plan variable's ``to_dict()`` row, tagged with ``iteration``, to results_table."""
        for var in self.plan.variables.values():
            self.results_table.append(var.to_dict() | {"iteration": iteration})
177
+
178
+
179
def main():
    """Command-line entry point: solve the configured problem and print the resulting plan."""
    print(f"SCIP Version: {Model().version()}")
    print(f"PySCIPOpt Version: {version('pyscipopt')}\n")
    config = Config.get_config()
    print(config)

    try:
        plan, solve_time = SCIPPlan.solve(config)
    except InfeasibilityError:
        return None

    if config.save_sols is True:
        write_to_csv("new_constraints", plan.new_constraints, config)
        write_to_csv("results", plan.results_table, config)
        print("Solutions saved: \n")

    print("Plan: ")

    scip_model = plan.scip_model
    variables = plan.plan.variables

    # Action names, identified from their step-0 variables; the Dt action is printed separately.
    action_names = sorted(
        name for name in plan.plan.var_names
        if variables[(name, 0)].var_type is VarType.ACTION and name != config.dt_var
    )

    for step in range(plan.config.horizon):
        for action_name in action_names:
            value = scip_model.getVal(variables[(action_name, step)].model_var)
            print(f"{action_name} at step {step} by value {value:.3f}.")

        # Look Dt up by the configured name rather than assuming it is literally "Dt".
        dt_value = scip_model.getVal(variables[(config.dt_var, step)].model_var)
        print(f"Dt at step {step} by value {dt_value:.3f}. \n")

    print(f"Total reward: {(scip_model.getObjVal()):.3f}")
    print(f"Total time: {solve_time:.3f}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,91 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Union
6
+
7
+ from pyscipopt.scip import Model, Variable as SCIPVariable
8
+
9
@dataclass
class Variable:
    """A named planning variable and the value that backs it in the model.

    For action, state and auxiliary variables ``model_var`` is the SCIP
    variable created by ``create_var``; for constants it is the numeric value.
    """

    name: str
    var_type: VarType
    # None for constants, whose value type does not matter to the model.
    val_type: Union[ValType, None]
    # Step index, or the string "global" for a variable shared by all steps.
    time: Union[int, str]
    model: Model
    model_var: Union[SCIPVariable, float]

    @classmethod
    def create_var(cls, model: Model, name: str, vtype: str, time: Union[int, str], const_vals: dict[str, float]) -> Variable:
        """Create a Variable from a textual type declaration.

        ``vtype`` is matched by substring: one of action/state/auxiliary/constant,
        plus (except for constants) one of continuous/integer/boolean.
        Non-constants get a SCIP variable named ``f"{name}_{time}"`` with no
        bounds; constants take their value from ``const_vals[name]``.

        Raises:
            Exception: if ``vtype`` names no known variable or value type.
        """
        if "action" in vtype:
            var_type = VarType.ACTION
        elif "state" in vtype:
            var_type = VarType.STATE
        elif "auxiliary" in vtype:
            var_type = VarType.AUX
        elif "constant" in vtype:
            var_type = VarType.CONSTANT
        else:  # var type isn't recognised
            raise Exception(f"Unknown variable type: {vtype}")

        if "continuous" in vtype:
            val_type = ValType.CONTINUOUS
        elif "integer" in vtype:
            val_type = ValType.INTEGER
        elif "boolean" in vtype:
            val_type = ValType.BOOLEAN
        elif var_type is VarType.CONSTANT:
            # Special case for constants as the value type doesn't matter for the model as it is numeric
            val_type = None
        else:  # val type isn't recognised
            raise Exception(f"Unknown value type: {vtype}")

        if var_type is VarType.CONSTANT:
            model_var = const_vals[name]
        else:
            model_var = model.addVar(name=f"{name}_{time}", vtype=val_type.value, lb=None, ub=None)

        return cls(
            name=name,
            var_type=var_type,
            val_type=val_type,
            time=time,
            model=model,
            model_var=model_var,
        )

    def to_dict(self):
        """Return a flat record of this variable for result tables.

        The value is ``None`` when ``model.getVal`` raises ``Warning``
        (presumably when no solution value is available — confirm for the
        installed pyscipopt).
        """
        if self.var_type is VarType.CONSTANT:
            var_val = self.model_var
            val_type = None
        else:
            val_type = self.val_type.name
            try:
                var_val = self.model.getVal(self.model_var)
            except Warning:
                var_val = None

        return {
            "name": self.name,
            "variable_type": self.var_type.name,
            "value_type": val_type,
            "horizon": self.time,
            "variable_value": var_val,
        }
79
+
80
+
81
+
82
class VarType(Enum):
    """Role a planning variable plays in the encoded problem."""

    # Each value mirrors the keyword Variable.create_var searches for in a declaration.
    ACTION = "action"
    STATE = "state"
    AUX = "auxiliary"
    CONSTANT = "constant"
87
+
88
class ValType(Enum):
    """Value domain of a model variable.

    Each value is the ``vtype`` code handed to SCIP's ``Model.addVar``.
    """

    CONTINUOUS = "C"
    INTEGER = "I"
    BOOLEAN = "B"
@@ -0,0 +1,29 @@
1
+ from dataclasses import field, dataclass
2
+ from textwrap import dedent
3
+
4
@dataclass
class ZeroCrossing:
    """A temporal-constraint violation detected inside one step's Dt interval.

    When ``is_violated`` is True, ``start``/``end`` bound the violated
    sub-interval and ``dt_interval`` is the step's Dt. ``coef`` is then the
    midpoint of ``[start, end]`` as a fraction of ``dt_interval``, and
    ``new_dt_val`` is ``coef * dt_interval``. Both stay None otherwise.

    Raises:
        ValueError: if a violation is reported without start, end and
            dt_interval, or with a zero-length dt_interval.
    """

    is_violated: bool = field(repr=False)
    iteration: int = None
    horizon: int = None
    start: float = None
    end: float = None
    dt_interval: float = None
    coef: float = field(init=False, default=None)
    new_dt_val: float = field(init=False, default=None)
    constraint_idx: int = None

    def __post_init__(self):
        if self.is_violated is not True:
            return

        if self.start is None or self.end is None or self.dt_interval is None:
            raise ValueError(
                dedent(f"""
                Incorrect input values, start, end and dt_interval have to be specified when zero crossing exists

                {self.start = }
                {self.end = }
                {self.dt_interval = }
                """))

        # Guard the division below; a zero-length interval has no meaningful coefficient.
        if self.dt_interval == 0:
            raise ValueError(
                f"dt_interval must be non-zero to compute the zero crossing coefficient "
                f"({self.start = }, {self.end = }, {self.dt_interval = })"
            )

        avg_interval = (self.start + self.end) / 2.0
        self.coef = avg_interval / self.dt_interval
        self.new_dt_val = self.coef * self.dt_interval
scipplan/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
- __version__ = "0.2.1alpha0"
1
+ __version__ = "0.2.2alpha0"
2
2
  print(f"SCIPPlan Version: {__version__}")
3
- __release__ = "v0.2.1a0"
3
+ __release__ = "v0.2.2a0"
4
4
  __author__ = "Ari Gestetner, Buser Say"
5
- __email__ = "ari.gestetner@monash.edu, buser.say@monash.edu"
5
+ __email__ = "ari.gestetner@monash.edu, buser.say@monash.edu"