rtc-tools 2.5.2rc3__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rtc-tools might be problematic.

Files changed (47)
  1. {rtc_tools-2.5.2rc3.dist-info → rtc_tools-2.6.0.dist-info}/METADATA +7 -7
  2. rtc_tools-2.6.0.dist-info/RECORD +50 -0
  3. {rtc_tools-2.5.2rc3.dist-info → rtc_tools-2.6.0.dist-info}/WHEEL +1 -1
  4. rtctools/__init__.py +2 -1
  5. rtctools/_internal/alias_tools.py +12 -10
  6. rtctools/_internal/caching.py +5 -3
  7. rtctools/_internal/casadi_helpers.py +11 -32
  8. rtctools/_internal/debug_check_helpers.py +1 -1
  9. rtctools/_version.py +3 -3
  10. rtctools/data/__init__.py +2 -2
  11. rtctools/data/csv.py +54 -33
  12. rtctools/data/interpolation/bspline.py +3 -3
  13. rtctools/data/interpolation/bspline1d.py +42 -29
  14. rtctools/data/interpolation/bspline2d.py +10 -4
  15. rtctools/data/netcdf.py +137 -93
  16. rtctools/data/pi.py +304 -210
  17. rtctools/data/rtc.py +64 -53
  18. rtctools/data/storage.py +91 -51
  19. rtctools/optimization/collocated_integrated_optimization_problem.py +1244 -696
  20. rtctools/optimization/control_tree_mixin.py +68 -66
  21. rtctools/optimization/csv_lookup_table_mixin.py +107 -74
  22. rtctools/optimization/csv_mixin.py +83 -52
  23. rtctools/optimization/goal_programming_mixin.py +239 -148
  24. rtctools/optimization/goal_programming_mixin_base.py +204 -111
  25. rtctools/optimization/homotopy_mixin.py +36 -27
  26. rtctools/optimization/initial_state_estimation_mixin.py +8 -8
  27. rtctools/optimization/io_mixin.py +48 -43
  28. rtctools/optimization/linearization_mixin.py +3 -1
  29. rtctools/optimization/linearized_order_goal_programming_mixin.py +57 -28
  30. rtctools/optimization/min_abs_goal_programming_mixin.py +72 -29
  31. rtctools/optimization/modelica_mixin.py +135 -81
  32. rtctools/optimization/netcdf_mixin.py +32 -18
  33. rtctools/optimization/optimization_problem.py +181 -127
  34. rtctools/optimization/pi_mixin.py +68 -36
  35. rtctools/optimization/planning_mixin.py +19 -0
  36. rtctools/optimization/single_pass_goal_programming_mixin.py +159 -112
  37. rtctools/optimization/timeseries.py +4 -6
  38. rtctools/rtctoolsapp.py +18 -18
  39. rtctools/simulation/csv_mixin.py +37 -30
  40. rtctools/simulation/io_mixin.py +9 -5
  41. rtctools/simulation/pi_mixin.py +62 -32
  42. rtctools/simulation/simulation_problem.py +471 -180
  43. rtctools/util.py +84 -56
  44. rtc_tools-2.5.2rc3.dist-info/RECORD +0 -49
  45. {rtc_tools-2.5.2rc3.dist-info → rtc_tools-2.6.0.dist-info}/COPYING.LESSER +0 -0
  46. {rtc_tools-2.5.2rc3.dist-info → rtc_tools-2.6.0.dist-info}/entry_points.txt +0 -0
  47. {rtc_tools-2.5.2rc3.dist-info → rtc_tools-2.6.0.dist-info}/top_level.txt +0 -0
@@ -41,14 +41,35 @@ class ControlTreeMixin(OptimizationProblem):
 
         options = {}
 
-        options['forecast_variables'] = [var.name()
-                                         for var in self.dae_variables['constant_inputs']]
-        options['branching_times'] = self.times()[1:]
-        options['k'] = 2
+        options["forecast_variables"] = [
+            var.name() for var in self.dae_variables["constant_inputs"]
+        ]
+        options["branching_times"] = self.times()[1:]
+        options["k"] = 2
 
         return options
 
+    def discretize_control(self, variable, ensemble_member, times, offset):
+        control_indices = np.zeros(len(times), dtype=np.int16)
+        for branch, members in self.__branches.items():
+            if ensemble_member not in members:
+                continue
+
+            branching_time_0 = self.__branching_times[len(branch) + 0]
+            branching_time_1 = self.__branching_times[len(branch) + 1]
+            els = np.logical_and(times >= branching_time_0, times < branching_time_1)
+            nnz = np.count_nonzero(els)
+            try:
+                control_indices[els] = self.__discretize_controls_cache[(variable, branch)]
+            except KeyError:
+                control_indices[els] = list(range(offset, offset + nnz))
+                self.__discretize_controls_cache[(variable, branch)] = control_indices[els]
+                offset += nnz
+        return control_indices
+
     def discretize_controls(self, resolved_bounds):
+        self.__discretize_controls_cache = {}
+
         # Collect options
         options = self.control_tree_options()
 
@@ -56,14 +77,14 @@ class ControlTreeMixin(OptimizationProblem):
         # presence of these is assumed below.
         times = self.times()
         t0 = self.initial_time
-        branching_times = options['branching_times']
-        n_branching_times = len(branching_times)
+        self.__branching_times = options["branching_times"]
+        n_branching_times = len(self.__branching_times)
         if n_branching_times > len(times) - 1:
             raise Exception("Too many branching points specified")
-        branching_times = np.concatenate(([t0], branching_times, [np.inf]))
+        self.__branching_times = np.concatenate(([t0], self.__branching_times, [np.inf]))
 
         logger.debug("ControlTreeMixin: Branching times:")
-        logger.debug(branching_times)
+        logger.debug(self.__branching_times)
 
         # Branches start at branching times, so that the tree looks like the following:
         #
@@ -88,8 +109,8 @@ class ControlTreeMixin(OptimizationProblem):
             distances = np.zeros((n_branch_members, n_branch_members))
 
             # Decide branching on a segment of the time horizon
-            branching_time_0 = branching_times[len(current_branch) + 1]
-            branching_time_1 = branching_times[len(current_branch) + 2]
+            branching_time_0 = self.__branching_times[len(current_branch) + 1]
+            branching_time_1 = self.__branching_times[len(current_branch) + 2]
 
             # Compute reverse ensemble member index-to-distance index map.
             reverse = {}
@@ -98,23 +119,24 @@ class ControlTreeMixin(OptimizationProblem):
 
             # Compute distances between ensemble members, summed for all
             # forecast variables
-            for forecast_variable in options['forecast_variables']:
+            for forecast_variable in options["forecast_variables"]:
                 # We assume the time stamps of the forecasts in all ensemble
                 # members to be identical
-                timeseries = self.constant_inputs(ensemble_member=0)[
-                    forecast_variable]
+                timeseries = self.constant_inputs(ensemble_member=0)[forecast_variable]
                 els = np.logical_and(
-                    timeseries.times >= branching_time_0, timeseries.times < branching_time_1)
+                    timeseries.times >= branching_time_0, timeseries.times < branching_time_1
+                )
 
                 # Compute distance between ensemble members
                 for i, member_i in enumerate(branches[current_branch]):
-                    timeseries_i = self.constant_inputs(ensemble_member=member_i)[
-                        forecast_variable]
+                    timeseries_i = self.constant_inputs(ensemble_member=member_i)[forecast_variable]
                     for j, member_j in enumerate(branches[current_branch]):
                         timeseries_j = self.constant_inputs(ensemble_member=member_j)[
-                            forecast_variable]
-                        distances[
-                            i, j] += np.linalg.norm(timeseries_i.values[els] - timeseries_j.values[els])
+                            forecast_variable
+                        ]
+                        distances[i, j] += np.linalg.norm(
+                            timeseries_i.values[els] - timeseries_j.values[els]
+                        )
 
             # Keep track of ensemble members that have not yet been allocated
             # to a new branch
@@ -123,45 +145,51 @@ class ControlTreeMixin(OptimizationProblem):
             # We first select the scenario with the max distance to any other branch
             idx = np.argmax(np.amax(distances, axis=0))
 
-            for i in range(options['k']):
+            for i in range(options["k"]):
                 if idx >= 0:
-                    branches[current_branch +
-                             (i, )] = [branches[current_branch][idx]]
+                    branches[current_branch + (i,)] = [branches[current_branch][idx]]
 
                     available.remove(branches[current_branch][idx])
 
                     # We select the scenario with the max min distance to the other branches
-                    min_distances = np.array([
-                        min([np.inf] + [distances[j, k]
-                                        for j, member_j in enumerate(branches[current_branch])
-                                        if member_j not in available and member_k in available])
-                        for k, member_k in enumerate(branches[current_branch])
-                    ], dtype=np.float64)
+                    min_distances = np.array(
+                        [
+                            min(
+                                [np.inf]
+                                + [
+                                    distances[j, k]
+                                    for j, member_j in enumerate(branches[current_branch])
+                                    if member_j not in available and member_k in available
+                                ]
+                            )
+                            for k, member_k in enumerate(branches[current_branch])
+                        ],
+                        dtype=np.float64,
+                    )
                     min_distances[np.where(min_distances == np.inf)] = -np.inf
 
                     idx = np.argmax(min_distances)
                     if min_distances[idx] <= 0:
                         idx = -1
                 else:
-                    branches[current_branch + (i, )] = []
+                    branches[current_branch + (i,)] = []
 
             # Cluster remaining ensemble members to branches
             for member_i in available:
                 min_i = 0
                 min_distance = np.inf
-                for i in range(options['k']):
-                    branch2 = branches[current_branch + (i, )]
+                for i in range(options["k"]):
+                    branch2 = branches[current_branch + (i,)]
                     if len(branch2) > 0:
-                        distance = distances[
-                            reverse[member_i], reverse[branch2[0]]]
+                        distance = distances[reverse[member_i], reverse[branch2[0]]]
                         if distance < min_distance:
                             min_distance = distance
                             min_i = i
-                branches[current_branch + (min_i, )].append(member_i)
+                branches[current_branch + (min_i,)].append(member_i)
 
             # Recurse
-            for i in range(options['k']):
-                branch(current_branch + (i, ))
+            for i in range(options["k"]):
+                branch(current_branch + (i,))
 
         current_branch = ()
         branches[current_branch] = list(range(self.ensemble_size))
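
The branching logic above first seeds a branch with the ensemble member that lies farthest from any other member, then adds the available member whose minimum distance to the already-chosen branch heads is largest, and finally clusters the remaining members to the nearest head. A minimal standalone sketch of that selection rule on a toy distance matrix (illustrative only; the values and variable names are not part of the diff):

import numpy as np

# Hypothetical pairwise distances between four ensemble members.
distances = np.array([
    [0.0, 1.0, 4.0, 5.0],
    [1.0, 0.0, 3.0, 6.0],
    [4.0, 3.0, 0.0, 2.0],
    [5.0, 6.0, 2.0, 0.0],
])
members = list(range(4))
available = set(members)

# First head: the member with the largest distance to any other member.
heads = [int(np.argmax(np.amax(distances, axis=0)))]
available.remove(heads[0])

# Second head: the available member with the largest minimum distance to the
# heads chosen so far (the "max min distance" rule in the hunk above).
min_dist = [min(distances[h, k] for h in heads) if k in available else -np.inf for k in members]
heads.append(int(np.argmax(min_dist)))
print(heads)  # -> [1, 3]
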
@@ -172,36 +200,10 @@ class ControlTreeMixin(OptimizationProblem):
 
         self.__branches = branches
 
-        # Map ensemble members to control inputs
-        # (variable, (ensemble member, step)) -> control_index
-        self.__control_indices = [{} for ensemble_member in range(self.ensemble_size)]
-        count = 0
-        for control_input in self.controls:
-            times = self.times(control_input)
-            for member in range(self.ensemble_size):
-                self.__control_indices[member][control_input] = np.zeros(
-                    len(times), dtype=np.int16)
-            for branch, members in branches.items():
-                if not members:
-                    # Avoid making free variables by skipping branches which have no members
-                    continue
-
-                branching_time_0 = branching_times[len(branch) + 0]
-                branching_time_1 = branching_times[len(branch) + 1]
-                els = np.logical_and(
-                    times >= branching_time_0, times < branching_time_1)
-                nnz = np.count_nonzero(els)
-                for member in members:
-                    self.__control_indices[member][control_input][els] = \
-                        list(range(count, count + nnz))
-                count += nnz
-
-        discrete = self._collint_get_discrete(count, self.__control_indices)
-        lbx, ubx = self._collint_get_lbx_ubx(count, self.__control_indices)
-        x0 = self._collint_get_x0(count, self.__control_indices)
-
-        # Return number of control variables
-        return count, discrete, lbx, ubx, x0, self.__control_indices
+        # By now, the tree branches have been set up. We now rely
+        # on the default discretization logic to call discretize_control()
+        # for each (control variable, ensemble member) pair.
+        return super().discretize_controls(resolved_bounds)
 
     @property
     def control_tree_branches(self) -> Dict[Tuple[int], List[int]]:
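
The options consumed by the branching code come from control_tree_options(). A minimal sketch of how a problem class might override it (illustrative only; the class name MyProblem, the forecast variable name, and the option values are hypothetical and not part of the diff):

from rtctools.optimization.collocated_integrated_optimization_problem import (
    CollocatedIntegratedOptimizationProblem,
)
from rtctools.optimization.control_tree_mixin import ControlTreeMixin


class MyProblem(ControlTreeMixin, CollocatedIntegratedOptimizationProblem):
    def control_tree_options(self):
        options = super().control_tree_options()
        options["forecast_variables"] = ["precipitation"]  # hypothetical forecast input
        options["branching_times"] = self.times()[1:4]  # branch only at the first few steps
        options["k"] = 2  # two child branches per node
        return options
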
@@ -6,8 +6,9 @@ import pickle
 from typing import Iterable, List, Tuple, Union
 
 import casadi as ca
-
 import numpy as np
+from scipy.interpolate import bisplev, bisplrep, splev
+from scipy.optimize import brentq
 
 import rtctools.data.csv as csv
 from rtctools._internal.caching import cached
@@ -15,9 +16,6 @@ from rtctools.data.interpolation.bspline1d import BSpline1D
 from rtctools.data.interpolation.bspline2d import BSpline2D
 from rtctools.optimization.timeseries import Timeseries
 
-from scipy.interpolate import bisplev, bisplrep, splev
-from scipy.optimize import brentq
-
 from .optimization_problem import LookupTable as LookupTableBase
 from .optimization_problem import OptimizationProblem
 
@@ -54,10 +52,14 @@ class LookupTable(LookupTableBase):
     def domain(self) -> Tuple:
         t = self.__t
         if t is None:
-            raise AttributeError('This lookup table was not instantiated with tck metadata. \
-                Domain/Range information is unavailable.')
+            raise AttributeError(
+                "This lookup table was not instantiated with tck metadata. \
+                Domain/Range information is unavailable."
+            )
         if type(t) == tuple and len(t) == 2:
-            raise NotImplementedError('Domain/Range information is not yet implemented for 2D LookupTables')
+            raise NotImplementedError(
+                "Domain/Range information is not yet implemented for 2D LookupTables"
+            )
 
         return np.nextafter(t[0], np.inf), np.nextafter(t[-1], -np.inf)
 
@@ -84,9 +86,7 @@ class LookupTable(LookupTableBase):
     @cached
     def __numeric_function_evaluator(self):
         return np.vectorize(
-            lambda *args: np.nan
-            if np.any(np.isnan(args))
-            else float(self.function(*args))
+            lambda *args: np.nan if np.any(np.isnan(args)) else float(self.function(*args))
         )
 
     def __call__(
@@ -170,9 +170,7 @@ class LookupTable(LookupTableBase):
            all_viol = y_array_not_nan[lb_viol | ub_viol]
            if all_viol.size > 0:
                raise ValueError(
-                    "Values {} are not in lookup table range ({}, {})".format(
-                        all_viol, l_r, u_r
-                    )
+                    "Values {} are not in lookup table range ({}, {})".format(all_viol, l_r, u_r)
                )
 
        # Construct function to do inverse evaluation
@@ -200,9 +198,9 @@ class CSVLookupTableMixin(OptimizationProblem):
     """
     Adds lookup tables to your optimization problem.
 
-    During preprocessing, the CSV files located inside the ``lookup_tables`` subfolder are read.
-    In every CSV file, the first column contains the output of the lookup table. Subsequent columns contain
-    the input variables.
+    During preprocessing, the CSV files located inside the ``lookup_tables`` subfolder are read. In
+    every CSV file, the first column contains the output of the lookup table. Subsequent columns
+    contain the input variables.
 
     Cubic B-Splines are used to turn the data points into continuous lookup tables.
 
@@ -226,13 +224,15 @@ class CSVLookupTableMixin(OptimizationProblem):
     Currently only one-dimensional lookup tables are fully supported. Support for two-
     dimensional lookup tables is experimental.
 
-    :cvar csv_delimiter:                 Column delimiter used in CSV files. Default is ``,``.
-    :cvar csv_lookup_table_debug:        Whether to generate plots of the spline fits. Default is ``false``.
-    :cvar csv_lookup_table_debug_points: Number of evaluation points for plots. Default is ``100``.
+    :cvar csv_delimiter: Column delimiter used in CSV files. Default is ``,``.
+    :cvar csv_lookup_table_debug: Whether to generate plots of the spline fits.
+        Default is ``False``.
+    :cvar csv_lookup_table_debug_points: Number of evaluation points for plots.
+        Default is ``100``.
     """
 
     #: Column delimiter used in CSV files
-    csv_delimiter = ','
+    csv_delimiter = ","
 
     #: Debug settings
     csv_lookup_table_debug = False
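
As the docstring describes, each CSV in the ``lookup_tables`` folder holds the lookup-table output in its first column and the inputs in the remaining columns, separated by ``csv_delimiter``. A hypothetical ``lookup_tables/storage.csv`` for a one-dimensional table (file and column names are illustrative only, not part of the diff):

storage,level
0.0,0.0
100.0,0.5
250.0,1.0
450.0,1.5

Here ``storage`` becomes the spline output and ``level`` its input; with ``csv_lookup_table_debug = True`` the fit would also be plotted to ``storage.png`` next to the CSV file, using ``csv_lookup_table_debug_points`` evaluation points.
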
@@ -240,13 +240,12 @@ class CSVLookupTableMixin(OptimizationProblem):
 
     def __init__(self, **kwargs):
         # Check arguments
-        if 'input_folder' in kwargs:
-            assert ('lookup_table_folder' not in kwargs)
+        if "input_folder" in kwargs:
+            assert "lookup_table_folder" not in kwargs
 
-            self.__lookup_table_folder = os.path.join(
-                kwargs['input_folder'], 'lookup_tables')
+            self.__lookup_table_folder = os.path.join(kwargs["input_folder"], "lookup_tables")
         else:
-            self.__lookup_table_folder = kwargs['lookup_table_folder']
+            self.__lookup_table_folder = kwargs["lookup_table_folder"]
 
         # Call parent
         super().__init__(**kwargs)
@@ -256,15 +255,15 @@ class CSVLookupTableMixin(OptimizationProblem):
         super().pre()
 
         # Get curve fitting options from curvefit_options.ini file
-        ini_path = os.path.join(
-            self.__lookup_table_folder, 'curvefit_options.ini')
+        ini_path = os.path.join(self.__lookup_table_folder, "curvefit_options.ini")
         try:
            ini_config = configparser.RawConfigParser()
            ini_config.read(ini_path)
            no_curvefit_options = False
        except IOError:
            logger.info(
-                "CSVLookupTableMixin: No curvefit_options.ini file found. Using default values.")
+                "CSVLookupTableMixin: No curvefit_options.ini file found. Using default values."
+            )
            no_curvefit_options = True
 
        def get_curvefit_options(curve_name, no_curvefit_options=no_curvefit_options):
@@ -282,30 +281,33 @@ class CSVLookupTableMixin(OptimizationProblem):
                    prop = 0
                except ValueError:
                    raise Exception(
-                        'CSVLookupTableMixin: Invalid {0} constraint for {1}. {0} should '
-                        'be either -1, 0, or 1.'.format(prop_name, curve_name))
+                        "CSVLookupTableMixin: Invalid {0} constraint for {1}. {0} should "
+                        "be either -1, 0, or 1.".format(prop_name, curve_name)
+                    )
                return prop
 
-            for prop_name in ['monotonicity', 'monotonicity2', 'curvature']:
+            for prop_name in ["monotonicity", "monotonicity2", "curvature"]:
                curvefit_options.append(get_property(prop_name))
 
-            logger.debug("CSVLookupTableMixin: Curve fit option for {}:({},{},{})".format(
-                curve_name, *curvefit_options))
+            logger.debug(
+                "CSVLookupTableMixin: Curve fit option for {}:({},{},{})".format(
+                    curve_name, *curvefit_options
+                )
+            )
            return tuple(curvefit_options)
 
        def check_lookup_table(lookup_table):
            if lookup_table in self.__lookup_tables:
                raise Exception(
                    "Cannot add lookup table {},"
-                    "since there is already one with this name.".format(lookup_table))
+                    "since there is already one with this name.".format(lookup_table)
+                )
+
        # Read CSV files
-        logger.info(
-            "CSVLookupTableMixin: Generating Splines from lookup table data.")
+        logger.info("CSVLookupTableMixin: Generating Splines from lookup table data.")
        self.__lookup_tables = {}
        for filename in glob.glob(os.path.join(self.__lookup_table_folder, "*.csv")):
-
-            logger.debug(
-                "CSVLookupTableMixin: Reading lookup table from {}".format(filename))
+            logger.debug("CSVLookupTableMixin: Reading lookup table from {}".format(filename))
 
            csvinput = csv.load(filename, delimiter=self.csv_delimiter)
            output = csvinput.dtype.names[0]
@@ -314,62 +316,80 @@ class CSVLookupTableMixin(OptimizationProblem):
            # Get monotonicity and curvature from ini file
            mono, mono2, curv = get_curvefit_options(output)
 
-            logger.debug(
-                "CSVLookupTableMixin: Output is {}, inputs are {}.".format(output, inputs))
+            logger.debug("CSVLookupTableMixin: Output is {}, inputs are {}.".format(output, inputs))
 
            tck = None
            function = None
 
-            # If tck file is newer than the csv file, first try to load the cached values from the tck file
-            tck_filename = filename.replace('.csv', '.tck')
+            # If tck file is newer than the csv file, first try to load the cached values from
+            # the tck file
+            tck_filename = filename.replace(".csv", ".tck")
            valid_cache = False
            if os.path.exists(tck_filename):
                if no_curvefit_options:
                    valid_cache = os.path.getmtime(filename) < os.path.getmtime(tck_filename)
                else:
-                    valid_cache = (os.path.getmtime(filename) < os.path.getmtime(tck_filename)) and \
-                        (os.path.getmtime(ini_path) < os.path.getmtime(tck_filename))
+                    valid_cache = (
+                        os.path.getmtime(filename) < os.path.getmtime(tck_filename)
+                    ) and (os.path.getmtime(ini_path) < os.path.getmtime(tck_filename))
                if valid_cache:
                    logger.debug(
-                        'CSVLookupTableMixin: Attempting to use cached tck values for {}'.format(output))
-                    with open(tck_filename, 'rb') as f:
+                        "CSVLookupTableMixin: Attempting to use cached tck values for {}".format(
+                            output
+                        )
+                    )
+                    with open(tck_filename, "rb") as f:
                        try:
                            tck, function = pickle.load(f)
                        except Exception:
                            valid_cache = False
            if not valid_cache:
-                logger.info(
-                    'CSVLookupTableMixin: Recalculating tck values for {}'.format(output))
+                logger.info("CSVLookupTableMixin: Recalculating tck values for {}".format(output))
 
            if len(csvinput.dtype.names) == 2:
                if not valid_cache:
                    k = 3  # default value
                    # 1D spline fitting needs k+1 data points
                    if len(csvinput[output]) >= k + 1:
-                        tck = BSpline1D.fit(csvinput[inputs[0]], csvinput[
-                            output], k=k, monotonicity=mono, curvature=curv)
+                        tck = BSpline1D.fit(
+                            csvinput[inputs[0]],
+                            csvinput[output],
+                            k=k,
+                            monotonicity=mono,
+                            curvature=curv,
+                        )
                    else:
                        raise Exception(
-                            'CSVLookupTableMixin: Too few data points in {} to do spline fitting. '
-                            'Need at least {} points.'.format(filename, k + 1))
+                            "CSVLookupTableMixin: Too few data points in {} to do spline fitting. "
+                            "Need at least {} points.".format(filename, k + 1)
+                        )
 
                if self.csv_lookup_table_debug:
                    import pylab
-                    i = np.linspace(csvinput[inputs[0]][0], csvinput[
-                        inputs[0]][-1], self.csv_lookup_table_debug_points)
+
+                    i = np.linspace(
+                        csvinput[inputs[0]][0],
+                        csvinput[inputs[0]][-1],
+                        self.csv_lookup_table_debug_points,
+                    )
                    o = splev(i, tck)
                    pylab.clf()
                    # TODO: Figure out why cross-section B0607 in NZV does not
                    # conform to constraints!
                    pylab.plot(i, o)
-                    pylab.plot(csvinput[inputs[0]], csvinput[
-                        output], linestyle='', marker='x', markersize=10)
-                    figure_filename = filename.replace('.csv', '.png')
+                    pylab.plot(
+                        csvinput[inputs[0]],
+                        csvinput[output],
+                        linestyle="",
+                        marker="x",
+                        markersize=10,
+                    )
+                    figure_filename = filename.replace(".csv", ".png")
                    pylab.savefig(figure_filename)
 
                symbols = [ca.SX.sym(inputs[0])]
                if not valid_cache:
-                    function = ca.Function('f', symbols, [BSpline1D(*tck)(symbols[0])])
+                    function = ca.Function("f", symbols, [BSpline1D(*tck)(symbols[0])])
                check_lookup_table(output)
                self.__lookup_tables[output] = LookupTable(symbols, function, tck)
 
@@ -382,42 +402,55 @@ class CSVLookupTableMixin(OptimizationProblem):
                    if len(csvinput[output]) >= (kx + 1) * (ky + 1):
                        # TODO: add curvature parameters from curvefit_options.ini
                        # once 2d fit is implemented
-                        tck = bisplrep(csvinput[inputs[0]], csvinput[
-                            inputs[1]], csvinput[output], kx=kx, ky=ky)
+                        tck = bisplrep(
+                            csvinput[inputs[0]], csvinput[inputs[1]], csvinput[output], kx=kx, ky=ky
+                        )
                    else:
                        raise Exception(
-                            'CSVLookupTableMixin: Too few data points in {} to do spline fitting. '
-                            'Need at least {} points.'.format(filename, (kx + 1) * (ky + 1)))
+                            "CSVLookupTableMixin: Too few data points in {} to do spline fitting. "
+                            "Need at least {} points.".format(filename, (kx + 1) * (ky + 1))
+                        )
 
                if self.csv_lookup_table_debug:
                    import pylab
-                    i1 = np.linspace(csvinput[inputs[0]][0], csvinput[
-                        inputs[0]][-1], self.csv_lookup_table_debug_points)
-                    i2 = np.linspace(csvinput[inputs[1]][0], csvinput[
-                        inputs[1]][-1], self.csv_lookup_table_debug_points)
+
+                    i1 = np.linspace(
+                        csvinput[inputs[0]][0],
+                        csvinput[inputs[0]][-1],
+                        self.csv_lookup_table_debug_points,
+                    )
+                    i2 = np.linspace(
+                        csvinput[inputs[1]][0],
+                        csvinput[inputs[1]][-1],
+                        self.csv_lookup_table_debug_points,
+                    )
                    i1, i2 = np.meshgrid(i1, i2)
                    i1 = i1.flatten()
                    i2 = i2.flatten()
                    o = bisplev(i1, i2, tck)
                    pylab.clf()
                    pylab.plot_surface(i1, i2, o)
-                    figure_filename = filename.replace('.csv', '.png')
+                    figure_filename = filename.replace(".csv", ".png")
                    pylab.savefig(figure_filename)
                symbols = [ca.SX.sym(inputs[0]), ca.SX.sym(inputs[1])]
                if not valid_cache:
-                    function = ca.Function('f', symbols, [BSpline2D(*tck)(symbols[0], symbols[1])])
+                    function = ca.Function("f", symbols, [BSpline2D(*tck)(symbols[0], symbols[1])])
                check_lookup_table(output)
                self.__lookup_tables[output] = LookupTable(symbols, function, tck)
 
            else:
                raise Exception(
-                    'CSVLookupTableMixin: {}-dimensional lookup tables not implemented yet.'.format(
-                        len(csvinput.dtype.names)))
+                    "CSVLookupTableMixin: {}-dimensional lookup tables not implemented yet.".format(
+                        len(csvinput.dtype.names)
+                    )
+                )
 
            if not valid_cache:
-                pickle.dump((tck, function),
-                            open(filename.replace('.csv', '.tck'), 'wb'),
-                            protocol=pickle.HIGHEST_PROTOCOL)
+                pickle.dump(
+                    (tck, function),
+                    open(filename.replace(".csv", ".tck"), "wb"),
+                    protocol=pickle.HIGHEST_PROTOCOL,
+                )
 
    def lookup_tables(self, ensemble_member):
        # Call parent class first for default values.
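
The reworked caching above only reuses a ``.tck`` pickle when it is newer than the source CSV and, when curve-fit options are configured, also newer than ``curvefit_options.ini``. A standalone sketch of that freshness check (illustrative only; the helper name and paths are hypothetical):

import os


def tck_cache_is_valid(csv_path, tck_path, ini_path=None):
    """Return True if a cached .tck spline file is newer than the files it was built from."""
    if not os.path.exists(tck_path):
        return False
    valid = os.path.getmtime(csv_path) < os.path.getmtime(tck_path)
    if ini_path is not None:
        # Changed curve-fit options also invalidate the cache.
        valid = valid and os.path.getmtime(ini_path) < os.path.getmtime(tck_path)
    return valid


# Example (hypothetical paths):
# tck_cache_is_valid("lookup_tables/storage.csv", "lookup_tables/storage.tck",
#                    "lookup_tables/curvefit_options.ini")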