rtc-tools 2.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. rtc_tools-2.7.3.dist-info/METADATA +53 -0
  2. rtc_tools-2.7.3.dist-info/RECORD +50 -0
  3. rtc_tools-2.7.3.dist-info/WHEEL +5 -0
  4. rtc_tools-2.7.3.dist-info/entry_points.txt +3 -0
  5. rtc_tools-2.7.3.dist-info/licenses/COPYING.LESSER +165 -0
  6. rtc_tools-2.7.3.dist-info/top_level.txt +1 -0
  7. rtctools/__init__.py +5 -0
  8. rtctools/_internal/__init__.py +0 -0
  9. rtctools/_internal/alias_tools.py +188 -0
  10. rtctools/_internal/caching.py +25 -0
  11. rtctools/_internal/casadi_helpers.py +99 -0
  12. rtctools/_internal/debug_check_helpers.py +41 -0
  13. rtctools/_version.py +21 -0
  14. rtctools/data/__init__.py +4 -0
  15. rtctools/data/csv.py +150 -0
  16. rtctools/data/interpolation/__init__.py +3 -0
  17. rtctools/data/interpolation/bspline.py +31 -0
  18. rtctools/data/interpolation/bspline1d.py +169 -0
  19. rtctools/data/interpolation/bspline2d.py +54 -0
  20. rtctools/data/netcdf.py +467 -0
  21. rtctools/data/pi.py +1236 -0
  22. rtctools/data/rtc.py +228 -0
  23. rtctools/data/storage.py +343 -0
  24. rtctools/optimization/__init__.py +0 -0
  25. rtctools/optimization/collocated_integrated_optimization_problem.py +3208 -0
  26. rtctools/optimization/control_tree_mixin.py +221 -0
  27. rtctools/optimization/csv_lookup_table_mixin.py +462 -0
  28. rtctools/optimization/csv_mixin.py +300 -0
  29. rtctools/optimization/goal_programming_mixin.py +769 -0
  30. rtctools/optimization/goal_programming_mixin_base.py +1094 -0
  31. rtctools/optimization/homotopy_mixin.py +165 -0
  32. rtctools/optimization/initial_state_estimation_mixin.py +89 -0
  33. rtctools/optimization/io_mixin.py +320 -0
  34. rtctools/optimization/linearization_mixin.py +33 -0
  35. rtctools/optimization/linearized_order_goal_programming_mixin.py +235 -0
  36. rtctools/optimization/min_abs_goal_programming_mixin.py +385 -0
  37. rtctools/optimization/modelica_mixin.py +482 -0
  38. rtctools/optimization/netcdf_mixin.py +177 -0
  39. rtctools/optimization/optimization_problem.py +1302 -0
  40. rtctools/optimization/pi_mixin.py +292 -0
  41. rtctools/optimization/planning_mixin.py +19 -0
  42. rtctools/optimization/single_pass_goal_programming_mixin.py +676 -0
  43. rtctools/optimization/timeseries.py +56 -0
  44. rtctools/rtctoolsapp.py +131 -0
  45. rtctools/simulation/__init__.py +0 -0
  46. rtctools/simulation/csv_mixin.py +171 -0
  47. rtctools/simulation/io_mixin.py +195 -0
  48. rtctools/simulation/pi_mixin.py +255 -0
  49. rtctools/simulation/simulation_problem.py +1293 -0
  50. rtctools/util.py +241 -0
@@ -0,0 +1,221 @@
1
+ import logging
2
+ from typing import Dict, List, Tuple, Union
3
+
4
+ import numpy as np
5
+
6
+ from .optimization_problem import OptimizationProblem
7
+
8
+ logger = logging.getLogger("rtctools")
9
+
10
+
11
class ControlTreeMixin(OptimizationProblem):
    """
    Adds a stochastic control tree to your optimization problem.

    Ensemble members are grouped into the branches of a ``k``-ary tree. Members that
    share a branch share the same control variable indices on the time segment covered
    by that branch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Maps a branch id (tuple of child indices) to the ensemble members on that branch.
        self.__branches = {}

    def control_tree_options(self) -> Dict[str, Union[List[str], List[float], int]]:
        """
        Returns a dictionary of options controlling the creation of a k-ary stochastic tree.

        +------------------------+---------------------+-----------------------+
        | Option                 | Type                | Default value         |
        +========================+=====================+=======================+
        | ``forecast_variables`` | ``list`` of strings | All constant inputs   |
        +------------------------+---------------------+-----------------------+
        | ``branching_times``    | ``list`` of floats  | ``self.times()[1:]``  |
        +------------------------+---------------------+-----------------------+
        | ``k``                  | ``int``             | ``2``                 |
        +------------------------+---------------------+-----------------------+

        A ``k``-ary tree is generated, branching at every interior branching time.
        Ensemble members are clustered to paths through the tree based on average
        distance over all forecast variables.

        :returns: A dictionary of control tree generation options.
        """

        forecast_variables = [var.name() for var in self.dae_variables["constant_inputs"]]

        return {
            "forecast_variables": forecast_variables,
            "branching_times": self.times()[1:],
            "k": 2,
        }

    def discretize_control(self, variable, ensemble_member, times, offset):
        """
        Build the control index vector of ``variable`` for one ensemble member.

        For every branch containing ``ensemble_member``, the time steps falling inside the
        branch's segment receive the indices cached for ``(variable, branch)``. When no
        entry is cached yet, fresh consecutive indices starting at ``offset`` are handed
        out and stored, so that other members on the same branch reuse them.

        :param variable: Control variable being discretized (used as cache key).
        :param ensemble_member: Ensemble member index.
        :param times: Array of time stamps; compared element-wise to the branching times.
        :param offset: First free control index.
        :returns: ``numpy`` array of ``int64`` control indices, one per entry of ``times``.
        """
        indices = np.zeros(len(times), dtype=np.int64)
        cache = self.__discretize_controls_cache

        for branch_id, branch_members in self.__branches.items():
            if ensemble_member not in branch_members:
                continue

            # A branch of depth d covers the segment [branching_times[d], branching_times[d + 1])
            depth = len(branch_id)
            segment_start = self.__branching_times[depth + 0]
            segment_end = self.__branching_times[depth + 1]
            in_segment = np.logical_and(times >= segment_start, times < segment_end)
            n_in_segment = np.count_nonzero(in_segment)

            cache_key = (variable, branch_id)
            if cache_key in cache:
                indices[in_segment] = cache[cache_key]
            else:
                indices[in_segment] = list(range(offset, offset + n_in_segment))
                cache[cache_key] = indices[in_segment]
                offset += n_in_segment

        return indices

    def discretize_controls(self, resolved_bounds):
        """
        Build the control tree and then delegate to the parent discretization.

        The tree is stored in ``control_tree_branches``. The parent implementation is relied
        upon to call :py:meth:`discretize_control` for each (control variable, ensemble
        member) pair.

        :param resolved_bounds: Passed through unchanged to the parent implementation.
        :returns: Whatever the parent ``discretize_controls()`` returns.
        :raises Exception: If more branching times are specified than there are time steps
            after the first.
        """
        self.__discretize_controls_cache = {}

        # Collect options
        options = self.control_tree_options()

        # Pad the branching times with the initial time in front and infinity at the end.
        # The presence of these two sentinels is assumed below.
        times = self.times()
        t0 = self.initial_time
        self.__branching_times = options["branching_times"]
        n_branching_times = len(self.__branching_times)
        if n_branching_times > len(times) - 1:
            raise Exception("Too many branching points specified")
        self.__branching_times = np.concatenate(([t0], self.__branching_times, [np.inf]))

        logger.debug("ControlTreeMixin: Branching times:")
        logger.debug(self.__branching_times)

        # Avoid calling constant_inputs() many times
        constant_inputs = [
            self.constant_inputs(ensemble_member=member) for member in range(self.ensemble_size)
        ]

        # Branches start at branching times, so that the tree looks like the following:
        #
        #         *-----
        #   *-----
        #         *-----
        #
        #   t0    t1
        #
        # with branching time t1.
        branches = {}

        def split(parent: Tuple[int]):
            """Distribute the members of ``parent`` over its children, then recurse."""
            depth = len(parent)
            if depth >= n_branching_times:
                return

            parent_members = branches[parent]
            n_members = len(parent_members)
            if n_members == 0:
                # Nothing to do
                return

            # Branching is decided on one segment of the time horizon
            segment_start = self.__branching_times[depth + 1]
            segment_end = self.__branching_times[depth + 2]

            # Map from ensemble member to its row/column in the distance matrix
            position = {member: pos for pos, member in enumerate(parent_members)}

            # Pairwise distances between members, summed over all forecast variables
            distances = np.zeros((n_members, n_members))
            for name in options["forecast_variables"]:
                # The time stamps of the forecasts are assumed to be identical in all
                # ensemble members, so member 0 provides the mask.
                reference = constant_inputs[0][name]
                mask = np.logical_and(
                    reference.times >= segment_start, reference.times < segment_end
                )

                for row, member_row in enumerate(parent_members):
                    series_row = constant_inputs[member_row][name]
                    for col, member_col in enumerate(parent_members):
                        series_col = constant_inputs[member_col][name]
                        distances[row, col] += np.linalg.norm(
                            series_row.values[mask] - series_col.values[mask]
                        )

            # Members that have not yet been allocated to a child branch
            unassigned = set(parent_members)

            # The first seed is the scenario with the largest distance to any other one
            seed = np.argmax(np.amax(distances, axis=0))

            for child in range(options["k"]):
                child_id = parent + (child,)

                if seed < 0:
                    # No distinct scenario left; this child stays empty
                    branches[child_id] = []
                    continue

                seed_member = parent_members[seed]
                branches[child_id] = [seed_member]
                unassigned.remove(seed_member)

                # Next seed: the unassigned scenario whose smallest distance to the
                # already assigned scenarios is largest.
                scores = []
                for col, member_col in enumerate(parent_members):
                    smallest = np.inf
                    if member_col in unassigned:
                        for row, member_row in enumerate(parent_members):
                            if member_row not in unassigned:
                                smallest = min(smallest, distances[row, col])
                    scores.append(smallest)
                min_distances = np.array(scores, dtype=np.float64)

                # Entries still at +inf are not candidates; push them to the bottom
                min_distances[min_distances == np.inf] = -np.inf

                seed = np.argmax(min_distances)
                if min_distances[seed] <= 0:
                    seed = -1

            # Cluster remaining ensemble members to the child with the nearest seed
            for member in unassigned:
                nearest_child = 0
                nearest_distance = np.inf
                for child in range(options["k"]):
                    child_members = branches[parent + (child,)]
                    if child_members:
                        candidate = distances[position[member], position[child_members[0]]]
                        if candidate < nearest_distance:
                            nearest_distance = candidate
                            nearest_child = child
                branches[parent + (nearest_child,)].append(member)

            # Recurse
            for child in range(options["k"]):
                split(parent + (child,))

        root = ()
        branches[root] = list(range(self.ensemble_size))
        split(root)

        logger.debug("ControlTreeMixin: Control tree is:")
        logger.debug(branches)

        self.__branches = branches

        # By now, the tree branches have been set up.  We now rely
        # on the default discretization logic to call discretize_control()
        # for each (control variable, ensemble member) pair.
        return super().discretize_controls(resolved_bounds)

    @property
    def control_tree_branches(self) -> Dict[Tuple[int], List[int]]:
        """
        Returns a dictionary mapping the branch id (a Tuple of ints) to a list
        of ensemble members in said branch.

        Note that the root branch is an empty tuple containing all ensemble
        members. The returned dictionary is a shallow copy.
        """

        return self.__branches.copy()
@@ -0,0 +1,462 @@
1
+ import configparser
2
+ import glob
3
+ import logging
4
+ import os
5
+ from typing import Iterable, List, Tuple, Union
6
+
7
+ import casadi as ca
8
+ import numpy as np
9
+ from scipy.interpolate import bisplev, bisplrep, splev
10
+ from scipy.optimize import brentq
11
+
12
+ import rtctools.data.csv as csv
13
+ from rtctools._internal.caching import cached
14
+ from rtctools.data.interpolation.bspline1d import BSpline1D
15
+ from rtctools.data.interpolation.bspline2d import BSpline2D
16
+ from rtctools.optimization.timeseries import Timeseries
17
+
18
+ from .optimization_problem import LookupTable as LookupTableBase
19
+ from .optimization_problem import OptimizationProblem
20
+
21
+ logger = logging.getLogger("rtctools")
22
+
23
+
24
class LookupTable(LookupTableBase):
    """
    Lookup table.

    Wraps a CasADi function together with, optionally, the spline ``tck`` data it was
    built from. The ``tck`` data is only used to derive domain and range information.
    """

    def __init__(self, inputs: List[ca.MX], function: ca.Function, tck: Tuple = None):
        """
        Create a new lookup table object.

        :param inputs: List of lookup table input variables.
        :param function: Lookup table CasADi :class:`Function`.
        :param tck: Optional spline data. A sequence of length 3 is unpacked as
            ``(t, c, k)``; a sequence of length 5 is split into ``t = tck[:2]``,
            ``c = tck[2]`` and ``k = tck[3:]``. Any other length is ignored.
        """
        self.__inputs = inputs
        self.__function = function

        self.__t, self.__c, self.__k = [None] * 3

        if tck is not None:
            if len(tck) == 3:
                self.__t, self.__c, self.__k = tck
            elif len(tck) == 5:
                # scipy's bisplrep() returns a *list*, so slicing it yields lists. Normalise
                # to tuples so that the 2D check in ``domain`` recognises this case.
                self.__t = tuple(tck[:2])
                self.__c = tck[2]
                self.__k = tuple(tck[3:])

    @property
    @cached
    def domain(self) -> Tuple:
        """
        Pair ``(lower, upper)`` lying just inside the first and last knot.

        :raises AttributeError: If the table was created without ``tck`` data.
        :raises NotImplementedError: If the table holds two knot vectors (2D table).
        """
        t = self.__t
        if t is None:
            raise AttributeError(
                "This lookup table was not instantiated with tck metadata. "
                "Domain/Range information is unavailable."
            )
        if isinstance(t, tuple) and len(t) == 2:
            raise NotImplementedError(
                "Domain/Range information is not yet implemented for 2D LookupTables"
            )

        # Step one representable float inwards from the outermost knots
        return np.nextafter(t[0], np.inf), np.nextafter(t[-1], -np.inf)

    @property
    @cached
    def range(self) -> Tuple:
        """
        Lookup table values at the two ends of :py:attr:`domain`, in domain order.

        For a decreasing table the first entry is therefore larger than the second.
        """
        return self(self.domain[0]), self(self.domain[1])

    @property
    def inputs(self) -> List[ca.MX]:
        """
        List of lookup table input variables.
        """
        return self.__inputs

    @property
    def function(self) -> ca.Function:
        """
        Lookup table CasADi :class:`Function`.
        """
        return self.__function

    @property
    @cached
    def __numeric_function_evaluator(self):
        # Vectorized numeric evaluation of the CasADi function; NaN inputs map to NaN
        # without calling the function.
        return np.vectorize(
            lambda *args: np.nan if np.any(np.isnan(args)) else float(self.function(*args))
        )

    def __call__(
        self, *args: Union[float, Iterable, Timeseries]
    ) -> Union[float, np.ndarray, Timeseries]:
        """
        Evaluate the lookup table.

        :param args: Input values.
        :type args: Float, iterable of floats, or :class:`.Timeseries`
        :returns: Lookup table evaluated at input values.
        :raises TypeError: If more than one argument is given and any of them is a
            :class:`.Timeseries` or an iterable.

        Example use::

            y = lookup_table(1.0)
            [y1, y2] = lookup_table([1.0, 2.0])

        """
        evaluator = self.__numeric_function_evaluator
        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, Timeseries):
                # Evaluate on the values, keep the time stamps
                return Timeseries(arg.times, self(arg.values))
            else:
                if hasattr(arg, "__iter__"):
                    arg = np.fromiter(arg, dtype=float)
                    return evaluator(arg)
                else:
                    arg = float(arg)
                    return evaluator(arg).item()
        else:
            if any(isinstance(arg, Timeseries) for arg in args):
                raise TypeError(
                    "Higher-order LookupTable calls do not yet support Timeseries parameters"
                )
            elif any(hasattr(arg, "__iter__") for arg in args):
                raise TypeError(
                    "Higher-order LookupTable calls do not yet support vector parameters"
                )
            else:
                args = np.fromiter(args, dtype=float)
                return evaluator(*args)

    def reverse_call(
        self,
        y: Union[float, Iterable, Timeseries],
        domain: Tuple[float, float] = (None, None),
        detect_range_error: bool = True,
    ) -> Union[float, np.ndarray, Timeseries]:
        """Do an inverted call on this LookupTable

        Uses SciPy brentq optimizer to simulate a reversed call.
        Note: Method does not work with higher-order LookupTables

        :param y: Output value(s) for which the input is sought. NaN entries yield NaN.
        :param domain: Search interval ``(lower, upper)`` for the root finder. A bound of
            ``None`` is replaced by the corresponding bound of :py:attr:`domain`.
        :param detect_range_error: If ``True``, raise when a value lies outside the
            interval spanned by :py:attr:`range`.
        :returns: Input value(s), of the same kind as ``y`` (float, array or
            :class:`.Timeseries`).
        :raises ValueError: If ``detect_range_error`` is set and a value is out of range.
        """
        if isinstance(y, Timeseries):
            # Recurse on the values, forwarding the caller's search domain and range check
            return Timeseries(
                y.times,
                self.reverse_call(y.values, domain=domain, detect_range_error=detect_range_error),
            )

        # Get domain information
        l_d, u_d = domain
        if l_d is None:
            l_d = self.domain[0]
        if u_d is None:
            u_d = self.domain[1]

        # Cast y to array of float
        if hasattr(y, "__iter__"):
            y_array = np.fromiter(y, dtype=float)
        else:
            y_array = np.array([y], dtype=float)

        # Find not np.nan
        is_not_nan = ~np.isnan(y_array)
        y_array_not_nan = y_array[is_not_nan]

        # Detect if there is a range violation
        if detect_range_error:
            # ``range`` is given in domain order, so for a decreasing table its first
            # entry is the larger one. Order the bounds before comparing.
            r_first, r_last = self.range
            l_r, u_r = min(r_first, r_last), max(r_first, r_last)
            lb_viol = y_array_not_nan < l_r
            ub_viol = y_array_not_nan > u_r
            all_viol = y_array_not_nan[lb_viol | ub_viol]
            if all_viol.size > 0:
                raise ValueError(
                    "Values {} are not in lookup table range ({}, {})".format(all_viol, l_r, u_r)
                )

        # Construct function to do inverse evaluation
        evaluator = self.__numeric_function_evaluator

        def inv_evaluator(y_target):
            """inverse evaluator function"""
            return brentq(lambda x: evaluator(x) - y_target, l_d, u_d)

        inv_evaluator = np.vectorize(inv_evaluator)

        # Calculate x_array; NaN inputs stay NaN
        x_array = np.full_like(y_array, np.nan, dtype=float)
        if y_array_not_nan.size != 0:
            x_array[is_not_nan] = inv_evaluator(y_array_not_nan)

        # Return x
        if hasattr(y, "__iter__"):
            return x_array
        else:
            return x_array.item()
194
+
195
+
196
class CSVLookupTableMixin(OptimizationProblem):
    """
    Adds lookup tables to your optimization problem.

    During preprocessing, the CSV files located inside the ``lookup_tables`` subfolder are read. In
    every CSV file, the first column contains the output of the lookup table.  Subsequent columns
    contain the input variables.

    Cubic B-Splines are used to turn the data points into continuous lookup tables.

    Optionally, a file ``curvefit_options.ini`` may be included inside the ``lookup_tables`` folder.
    This file contains, grouped per lookup table, the following options:

    * monotonicity:
        * is an integer, magnitude is ignored
        * if positive, causes spline to be monotonically increasing
        * if negative, causes spline to be monotonically decreasing
        * if 0, leaves spline monotonicity unconstrained

    * curvature:
        * is an integer, magnitude is ignored
        * if positive, causes spline curvature to be positive (convex)
        * if negative, causes spline curvature to be negative (concave)
        * if 0, leaves spline curvature unconstrained

    The fitted spline data is cached next to each CSV file (``.npz`` and ``.ca`` files) and
    reused as long as the cache is newer than the CSV file and, when present, the
    ``curvefit_options.ini`` file.

    .. note::

        Currently only one-dimensional lookup tables are fully supported.  Support for two-
        dimensional lookup tables is experimental.

    :cvar csv_delimiter:                 Column delimiter used in CSV files.  Default is ``,``.
    :cvar csv_lookup_table_debug:        Whether to generate plots of the spline fits.
                                         Default is ``False``.
    :cvar csv_lookup_table_debug_points: Number of evaluation points for plots.
                                         Default is ``100``.
    """

    #: Column delimiter used in CSV files
    csv_delimiter = ","

    #: Debug settings
    csv_lookup_table_debug = False
    csv_lookup_table_debug_points = 100

    def __init__(self, **kwargs):
        """
        Determine the lookup table folder and initialise the parent class.

        :param kwargs: Must contain either ``input_folder`` (the tables are then read from
            its ``lookup_tables`` subfolder) or ``lookup_table_folder``, but not both. All
            keyword arguments are forwarded to the parent constructor.
        """
        # Check arguments
        if "input_folder" in kwargs:
            assert "lookup_table_folder" not in kwargs

            self.__lookup_table_folder = os.path.join(kwargs["input_folder"], "lookup_tables")
        else:
            self.__lookup_table_folder = kwargs["lookup_table_folder"]

        # Call parent
        super().__init__(**kwargs)

    def pre(self):
        """
        Read all ``*.csv`` lookup tables and turn them into spline-based lookup tables.

        :raises Exception: If a table has too few data points, has an unsupported number of
            columns, duplicates the name of another table, or if a curve fit option in
            ``curvefit_options.ini`` is not an integer.
        """
        # Call parent class first for default behaviour.
        super().pre()

        # Get curve fitting options from curvefit_options.ini file.
        # RawConfigParser.read() does not raise for a missing file: it skips it and returns
        # the list of files that were actually parsed. An empty list therefore means that
        # there is no options file.
        ini_path = os.path.join(self.__lookup_table_folder, "curvefit_options.ini")
        ini_config = configparser.RawConfigParser()
        try:
            parsed_files = ini_config.read(ini_path)
        except IOError:
            parsed_files = []
        no_curvefit_options = not parsed_files
        if no_curvefit_options:
            logger.info(
                "CSVLookupTableMixin: No curvefit_options.ini file found. Using default values."
            )

        def get_curvefit_options(curve_name, no_curvefit_options=no_curvefit_options):
            """Return ``(monotonicity, monotonicity2, curvature)`` for ``curve_name``."""
            if no_curvefit_options:
                return 0, 0, 0

            curvefit_options = []

            def get_property(prop_name):
                # Missing section or option means "unconstrained" (0)
                try:
                    prop = int(ini_config.get(curve_name, prop_name))
                except configparser.NoSectionError:
                    prop = 0
                except configparser.NoOptionError:
                    prop = 0
                except ValueError:
                    raise Exception(
                        "CSVLookupTableMixin: Invalid {0} constraint for {1}. {0} should "
                        "be either -1, 0, or 1.".format(prop_name, curve_name)
                    )
                return prop

            for prop_name in ["monotonicity", "monotonicity2", "curvature"]:
                curvefit_options.append(get_property(prop_name))

            logger.debug(
                "CSVLookupTableMixin: Curve fit option for {}:({},{},{})".format(
                    curve_name, *curvefit_options
                )
            )
            return tuple(curvefit_options)

        def check_lookup_table(lookup_table):
            """Raise if a lookup table with this name has already been registered."""
            if lookup_table in self.__lookup_tables:
                raise Exception(
                    "Cannot add lookup table {},since there is already one with this name.".format(
                        lookup_table
                    )
                )

        # Read CSV files
        logger.info("CSVLookupTableMixin: Generating Splines from lookup table data.")
        self.__lookup_tables = {}
        for filename in glob.glob(os.path.join(self.__lookup_table_folder, "*.csv")):
            logger.debug("CSVLookupTableMixin: Reading lookup table from {}".format(filename))

            csvinput = csv.load(filename, delimiter=self.csv_delimiter)
            output = csvinput.dtype.names[0]
            inputs = csvinput.dtype.names[1:]

            # Get monotonicity and curvature from ini file
            mono, mono2, curv = get_curvefit_options(output)

            logger.debug("CSVLookupTableMixin: Output is {}, inputs are {}.".format(output, inputs))

            tck = None
            function = None

            # If tck file is newer than the csv file, first try to load the cached values from
            # the tck file
            tck_filename = filename.replace(".csv", ".npz")
            valid_cache = False
            if os.path.exists(tck_filename):
                if no_curvefit_options:
                    valid_cache = os.path.getmtime(filename) < os.path.getmtime(tck_filename)
                else:
                    # The cache must also be newer than the curve fit options
                    valid_cache = (
                        os.path.getmtime(filename) < os.path.getmtime(tck_filename)
                    ) and (os.path.getmtime(ini_path) < os.path.getmtime(tck_filename))
                if valid_cache:
                    logger.debug(
                        "CSVLookupTableMixin: Attempting to use cached tck values for {}".format(
                            output
                        )
                    )
                    try:
                        # NOTE(review): only three arrays are read back, while a 2D fit
                        # saves five; a 2D cache presumably always ends up in the except
                        # branch and is recalculated -- confirm.
                        with np.load(filename.replace(".csv", ".npz")) as data:
                            tck = (data["arr_0"], data["arr_1"], int(data["arr_2"]))
                        function = ca.Function.load(filename.replace(".csv", ".ca"))
                    except Exception:
                        # Any problem with the cache files: fall back to a fresh fit
                        valid_cache = False

            if not valid_cache:
                logger.info("CSVLookupTableMixin: Recalculating tck values for {}".format(output))

            if len(csvinput.dtype.names) == 2:
                if not valid_cache:
                    k = 3  # default value
                    # 1D spline fitting needs k+1 data points
                    if len(csvinput[output]) >= k + 1:
                        tck = BSpline1D.fit(
                            csvinput[inputs[0]],
                            csvinput[output],
                            k=k,
                            monotonicity=mono,
                            curvature=curv,
                            ipopt_options={"nlp_scaling_method": "none"},
                        )
                    else:
                        raise Exception(
                            "CSVLookupTableMixin: Too few data points in {} to do spline fitting. "
                            "Need at least {} points.".format(filename, k + 1)
                        )

                if self.csv_lookup_table_debug:
                    import pylab

                    i = np.linspace(
                        csvinput[inputs[0]][0],
                        csvinput[inputs[0]][-1],
                        self.csv_lookup_table_debug_points,
                    )
                    o = splev(i, tck)
                    pylab.clf()
                    # TODO: Figure out why cross-section B0607 in NZV does not
                    # conform to constraints!
                    pylab.plot(i, o)
                    pylab.plot(
                        csvinput[inputs[0]],
                        csvinput[output],
                        linestyle="",
                        marker="x",
                        markersize=10,
                    )
                    figure_filename = filename.replace(".csv", ".png")
                    pylab.savefig(figure_filename)

                symbols = [ca.SX.sym(inputs[0])]
                if not valid_cache:
                    function = ca.Function("f", symbols, [BSpline1D(*tck)(symbols[0])])
                check_lookup_table(output)
                self.__lookup_tables[output] = LookupTable(symbols, function, tck)

            elif len(csvinput.dtype.names) == 3:
                if tck is None:
                    kx = 3  # default value
                    ky = 3  # default value

                    # 2D spline fitting needs (kx+1)*(ky+1) data points
                    if len(csvinput[output]) >= (kx + 1) * (ky + 1):
                        # TODO: add curvature parameters from curvefit_options.ini
                        # once 2d fit is implemented
                        tck = bisplrep(
                            csvinput[inputs[0]], csvinput[inputs[1]], csvinput[output], kx=kx, ky=ky
                        )
                    else:
                        raise Exception(
                            "CSVLookupTableMixin: Too few data points in {} to do spline fitting. "
                            "Need at least {} points.".format(filename, (kx + 1) * (ky + 1))
                        )

                if self.csv_lookup_table_debug:
                    import pylab

                    i1 = np.linspace(
                        csvinput[inputs[0]][0],
                        csvinput[inputs[0]][-1],
                        self.csv_lookup_table_debug_points,
                    )
                    i2 = np.linspace(
                        csvinput[inputs[1]][0],
                        csvinput[inputs[1]][-1],
                        self.csv_lookup_table_debug_points,
                    )
                    i1, i2 = np.meshgrid(i1, i2)
                    i1 = i1.flatten()
                    i2 = i2.flatten()
                    o = bisplev(i1, i2, tck)
                    pylab.clf()
                    pylab.plot_surface(i1, i2, o)
                    figure_filename = filename.replace(".csv", ".png")
                    pylab.savefig(figure_filename)
                symbols = [ca.SX.sym(inputs[0]), ca.SX.sym(inputs[1])]
                if not valid_cache:
                    function = ca.Function("f", symbols, [BSpline2D(*tck)(symbols[0], symbols[1])])
                check_lookup_table(output)
                self.__lookup_tables[output] = LookupTable(symbols, function, tck)

            else:
                raise Exception(
                    "CSVLookupTableMixin: {}-dimensional lookup tables not implemented yet.".format(
                        len(csvinput.dtype.names)
                    )
                )

            # Refresh the cache files after a new fit
            if not valid_cache:
                np.savez(filename.replace(".csv", ".npz"), *tck)
                function.save(filename.replace(".csv", ".ca"))

    def lookup_tables(self, ensemble_member):
        """
        Return the parent's lookup tables extended with the tables read from CSV.

        :param ensemble_member: Ensemble member index, forwarded to the parent class. The
            CSV tables themselves are the same for every member.
        :returns: Dictionary mapping lookup table names to lookup tables.
        """
        # Call parent class first for default values.
        lookup_tables = super().lookup_tables(ensemble_member)

        # Update lookup_tables with imported csv lookup tables
        lookup_tables.update(self.__lookup_tables)

        return lookup_tables