pyMOTO 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyMOTO-1.3.0.dist-info → pyMOTO-1.5.0.dist-info}/METADATA +7 -8
- pyMOTO-1.5.0.dist-info/RECORD +29 -0
- {pyMOTO-1.3.0.dist-info → pyMOTO-1.5.0.dist-info}/WHEEL +1 -1
- pymoto/__init__.py +17 -11
- pymoto/common/domain.py +61 -5
- pymoto/common/dyadcarrier.py +87 -29
- pymoto/common/mma.py +142 -129
- pymoto/core_objects.py +129 -117
- pymoto/modules/aggregation.py +209 -0
- pymoto/modules/assembly.py +250 -10
- pymoto/modules/complex.py +3 -3
- pymoto/modules/filter.py +171 -24
- pymoto/modules/generic.py +12 -1
- pymoto/modules/io.py +85 -12
- pymoto/modules/linalg.py +92 -120
- pymoto/modules/scaling.py +5 -4
- pymoto/routines.py +34 -9
- pymoto/solvers/__init__.py +14 -0
- pymoto/solvers/auto_determine.py +108 -0
- pymoto/{common/solvers_dense.py → solvers/dense.py} +90 -70
- pymoto/solvers/iterative.py +361 -0
- pymoto/solvers/matrix_checks.py +60 -0
- pymoto/solvers/solvers.py +253 -0
- pymoto/{common/solvers_sparse.py → solvers/sparse.py} +42 -29
- pyMOTO-1.3.0.dist-info/RECORD +0 -24
- pymoto/common/solvers.py +0 -236
- {pyMOTO-1.3.0.dist-info → pyMOTO-1.5.0.dist-info}/LICENSE +0 -0
- {pyMOTO-1.3.0.dist-info → pyMOTO-1.5.0.dist-info}/top_level.txt +0 -0
- {pyMOTO-1.3.0.dist-info → pyMOTO-1.5.0.dist-info}/zip-safe +0 -0
pymoto/core_objects.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1
|
-
|
1
|
+
import sys
|
2
2
|
import warnings
|
3
3
|
import inspect
|
4
4
|
import time
|
5
|
-
|
5
|
+
import copy
|
6
|
+
from typing import Union, List, Any
|
6
7
|
from abc import ABC, abstractmethod
|
8
|
+
from .utils import _parse_to_list, _concatenate_to_array, _split_from_array
|
7
9
|
|
8
10
|
|
9
11
|
# Local helper functions
|
@@ -11,7 +13,7 @@ def err_fmt(*args):
|
|
11
13
|
""" Format error strings for locating Modules and Signals"""
|
12
14
|
err_str = ""
|
13
15
|
for a in args:
|
14
|
-
err_str += f"\n\t
|
16
|
+
err_str += f"\n\t| {a}"
|
15
17
|
return err_str
|
16
18
|
|
17
19
|
|
@@ -76,16 +78,20 @@ class Signal:
|
|
76
78
|
>> Signal(tag='x2')
|
77
79
|
|
78
80
|
"""
|
79
|
-
def __init__(self, tag: str = "", state: Any = None, sensitivity: Any = None):
|
81
|
+
def __init__(self, tag: str = "", state: Any = None, sensitivity: Any = None, min: Any = None, max: Any = None):
|
80
82
|
"""
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
83
|
+
Keyword Args:
|
84
|
+
tag: The name of the signal
|
85
|
+
state: The initialized state
|
86
|
+
sensitivity: The initialized sensitivity
|
87
|
+
min: Minimum allowed value
|
88
|
+
max: Maximum allowed value
|
85
89
|
"""
|
86
90
|
self.tag = tag
|
87
91
|
self.state = state
|
88
92
|
self.sensitivity = sensitivity
|
93
|
+
self.min = min
|
94
|
+
self.max = max
|
89
95
|
self.keep_alloc = sensitivity is not None
|
90
96
|
|
91
97
|
# Save error string to location where it is initialized
|
@@ -95,11 +101,15 @@ class Signal:
|
|
95
101
|
return err_fmt(f"Signal \'{self.tag}\', initialized in {self._init_loc}")
|
96
102
|
|
97
103
|
def add_sensitivity(self, ds: Any):
|
104
|
+
""" Add a new term to internal sensitivity """
|
98
105
|
try:
|
99
106
|
if ds is None:
|
100
107
|
return
|
101
108
|
if self.sensitivity is None:
|
102
|
-
self.sensitivity = ds
|
109
|
+
self.sensitivity = copy.deepcopy(ds)
|
110
|
+
elif hasattr(self.sensitivity, "add_sensitivity"):
|
111
|
+
# Allow user to implement a custom add_sensitivity function instead of __iadd__
|
112
|
+
self.sensitivity.add_sensitivity(ds)
|
103
113
|
else:
|
104
114
|
self.sensitivity += ds
|
105
115
|
return self
|
@@ -116,8 +126,12 @@ class Signal:
|
|
116
126
|
def reset(self, keep_alloc: bool = None):
|
117
127
|
""" Reset the sensitivities to zero or None
|
118
128
|
This must be called to clear internal memory of subsequent sensitivity calculations.
|
119
|
-
|
120
|
-
:
|
129
|
+
|
130
|
+
Args:
|
131
|
+
keep_alloc: Keep the sensitivity allocation intact?
|
132
|
+
|
133
|
+
Returns:
|
134
|
+
self
|
121
135
|
"""
|
122
136
|
if self.sensitivity is None:
|
123
137
|
return self
|
@@ -138,11 +152,34 @@ class Signal:
|
|
138
152
|
|
139
153
|
def __getitem__(self, item):
|
140
154
|
""" Obtain a sliced signal, for using its partial contents.
|
141
|
-
|
142
|
-
:
|
155
|
+
|
156
|
+
Args:
|
157
|
+
item: Slice indices
|
158
|
+
|
159
|
+
Returns:
|
160
|
+
Sliced signal (SignalSlice)
|
143
161
|
"""
|
144
162
|
return SignalSlice(self, item)
|
145
163
|
|
164
|
+
def __str__(self):
|
165
|
+
state_msg = f"state {self.state}" if self.state is not None else "empty state"
|
166
|
+
state_msg = state_msg.split('\n')
|
167
|
+
if len(state_msg) > 1:
|
168
|
+
state_msg = state_msg[0] + ' ... ' + state_msg[-1]
|
169
|
+
else:
|
170
|
+
state_msg = state_msg[0]
|
171
|
+
return f"Signal \"{self.tag}\" with {state_msg}"
|
172
|
+
|
173
|
+
def __repr__(self):
|
174
|
+
state_msg = f"state {self.state}" if self.state is not None else "empty state"
|
175
|
+
state_msg = state_msg.split('\n')
|
176
|
+
if len(state_msg) > 1:
|
177
|
+
state_msg = state_msg[0] + ' ... ' + state_msg[-1]
|
178
|
+
else:
|
179
|
+
state_msg = state_msg[0]
|
180
|
+
sens_msg = 'empty sensitivity' if self.sensitivity is None else 'non-empty sensitivity'
|
181
|
+
return f"Signal \"{self.tag}\" with {state_msg} and {sens_msg} at {hex(id(self))}"
|
182
|
+
|
146
183
|
|
147
184
|
class SignalSlice(Signal):
|
148
185
|
""" Slice operator for a Signal
|
@@ -169,7 +206,8 @@ class SignalSlice(Signal):
|
|
169
206
|
return None if self.orig_signal.state is None else self.orig_signal.state[self.slice]
|
170
207
|
except Exception as e:
|
171
208
|
# Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
|
172
|
-
raise type(e)("SignalSlice.state (getter)" +
|
209
|
+
raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.state (getter). Signal details:" +
|
210
|
+
self._err_str()).with_traceback(sys.exc_info()[2])
|
173
211
|
|
174
212
|
@state.setter
|
175
213
|
def state(self, new_state):
|
@@ -177,7 +215,8 @@ class SignalSlice(Signal):
|
|
177
215
|
self.orig_signal.state[self.slice] = new_state
|
178
216
|
except Exception as e:
|
179
217
|
# Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
|
180
|
-
raise type(e)("SignalSlice.state (setter)" +
|
218
|
+
raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.state (setter). Signal details:" +
|
219
|
+
self._err_str()).with_traceback(sys.exc_info()[2])
|
181
220
|
|
182
221
|
@property
|
183
222
|
def sensitivity(self):
|
@@ -185,7 +224,8 @@ class SignalSlice(Signal):
|
|
185
224
|
return None if self.orig_signal.sensitivity is None else self.orig_signal.sensitivity[self.slice]
|
186
225
|
except Exception as e:
|
187
226
|
# Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
|
188
|
-
raise type(e)("SignalSlice.sensitivity (getter)" +
|
227
|
+
raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.sensitivity (getter). Signal details:" +
|
228
|
+
self._err_str()).with_traceback(sys.exc_info()[2])
|
189
229
|
|
190
230
|
@sensitivity.setter
|
191
231
|
def sensitivity(self, new_sens):
|
@@ -207,7 +247,8 @@ class SignalSlice(Signal):
|
|
207
247
|
self.orig_signal.sensitivity[self.slice] = new_sens
|
208
248
|
except Exception as e:
|
209
249
|
# Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
|
210
|
-
raise type(e)("SignalSlice.
|
250
|
+
raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.state (setter). Signal details:" +
|
251
|
+
self._err_str()).with_traceback(sys.exc_info()[2])
|
211
252
|
|
212
253
|
def reset(self, keep_alloc: bool = None):
|
213
254
|
""" Reset the sensitivities to zero or None
|
@@ -231,7 +272,7 @@ def make_signals(*args):
|
|
231
272
|
return ret
|
232
273
|
|
233
274
|
|
234
|
-
def
|
275
|
+
def _is_valid_signal(sig: Any):
|
235
276
|
""" Checks if the argument is a valid Signal object
|
236
277
|
:param sig: The object to check
|
237
278
|
:return: True if it is a valid Signal
|
@@ -240,10 +281,10 @@ def _check_valid_signal(sig: Any):
|
|
240
281
|
return True
|
241
282
|
if all([hasattr(sig, f) for f in ["state", "sensitivity", "add_sensitivity", "reset"]]):
|
242
283
|
return True
|
243
|
-
|
284
|
+
return False
|
244
285
|
|
245
286
|
|
246
|
-
def
|
287
|
+
def _is_valid_module(mod: Any):
|
247
288
|
""" Checks if the argument is a valid Module object
|
248
289
|
:param mod: The object to check
|
249
290
|
:return: True if it is a valid Module
|
@@ -252,7 +293,7 @@ def _check_valid_module(mod: Any):
|
|
252
293
|
return True
|
253
294
|
if hasattr(mod, "response") and hasattr(mod, "sensitivity") and hasattr(mod, "reset"):
|
254
295
|
return True
|
255
|
-
|
296
|
+
return False
|
256
297
|
|
257
298
|
|
258
299
|
def _check_function_signature(fn, signals):
|
@@ -370,60 +411,48 @@ class Module(ABC, RegisteredClass):
|
|
370
411
|
>> Module(sig_in=[inputs], sig_out=[outputs]
|
371
412
|
"""
|
372
413
|
|
373
|
-
def _err_str(self,
|
414
|
+
def _err_str(self, module_signature: bool = True, init: bool = True, fn=None):
|
374
415
|
str_list = []
|
375
|
-
|
376
|
-
|
377
|
-
if add_signal:
|
416
|
+
|
417
|
+
if module_signature:
|
378
418
|
inp_str = "Inputs: " + ", ".join([s.tag if hasattr(s, 'tag') else 'N/A' for s in self.sig_in]) if len(self.sig_in) > 0 else "No inputs"
|
379
419
|
out_str = "Outputs: " + ", ".join([s.tag if hasattr(s, 'tag') else 'N/A' for s in self.sig_out]) if len(self.sig_out) > 0 else "No outputs"
|
380
|
-
str_list.append(inp_str + " --> " + out_str)
|
420
|
+
str_list.append(f"Module \'{type(self).__name__}\'( " + inp_str + " ) --> " + out_str)
|
421
|
+
if init:
|
422
|
+
str_list.append(f"Used in {self._init_loc}")
|
381
423
|
if fn is not None:
|
382
424
|
name = f"{fn.__self__.__class__.__name__}.{fn.__name__}{inspect.signature(fn)}"
|
383
425
|
lineno = inspect.getsourcelines(fn)[1]
|
384
426
|
filename = inspect.getfile(fn)
|
385
|
-
str_list.append(f"
|
427
|
+
str_list.append(f"Implementation in File \"{filename}\", line {lineno}, in {name}")
|
386
428
|
return err_fmt(*str_list)
|
387
429
|
|
388
430
|
# flake8: noqa: C901
|
389
431
|
def __init__(self, sig_in: Union[Signal, List[Signal]] = None, sig_out: Union[Signal, List[Signal]] = None,
|
390
432
|
*args, **kwargs):
|
391
|
-
# TODO: Reduce complexity of this init
|
392
433
|
self._init_loc = get_init_str()
|
393
434
|
|
394
435
|
self.sig_in = _parse_to_list(sig_in)
|
395
436
|
self.sig_out = _parse_to_list(sig_out)
|
396
437
|
for i, s in enumerate(self.sig_in):
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
earg0 = e.args[0] if len(e.args) > 0 else ''
|
401
|
-
earg1 = e.args[1:] if len(e.args) > 1 else ()
|
402
|
-
raise type(e)(f"Invalid input signal #{i+1} - " + str(earg0) + self._err_str(), *earg1) from None
|
438
|
+
if not _is_valid_signal(s):
|
439
|
+
tag = f" (\'{s.tag}\')" if hasattr(s, 'tag') else ''
|
440
|
+
raise TypeError(f"Input {i}{tag} is not a valid signal, type=\'{type(s).__name__}\'.")
|
403
441
|
|
404
442
|
for i, s in enumerate(self.sig_out):
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
earg0 = e.args[0] if len(e.args) > 0 else ''
|
409
|
-
earg1 = e.args[1:] if len(e.args) > 1 else ()
|
410
|
-
raise type(e)(f"Invalid output signal #{i+1} - " + str(earg0) + self._err_str(), *earg1) from None
|
443
|
+
if not _is_valid_signal(s):
|
444
|
+
tag = f" (\'{s.tag}\')" if hasattr(s, 'tag') else ''
|
445
|
+
raise TypeError(f"Output {i}{tag} is not a valid signal, type=\'{type(s).__name__}\'.")
|
411
446
|
|
412
|
-
|
413
|
-
|
414
|
-
self._prepare(*args, **kwargs)
|
415
|
-
except Exception as e:
|
416
|
-
earg0 = e.args[0] if len(e.args) > 0 else ''
|
417
|
-
earg1 = e.args[1:] if len(e.args) > 1 else ()
|
418
|
-
raise type(e)("_prepare() - " + str(earg0) + self._err_str(fn=self._prepare), *earg1) from e
|
447
|
+
# Call preparation of submodule with remaining arguments
|
448
|
+
self._prepare(*args, **kwargs)
|
419
449
|
|
420
450
|
try:
|
421
451
|
# Check if the signals match _response() signature
|
422
452
|
_check_function_signature(self._response, self.sig_in)
|
423
453
|
except Exception as e:
|
424
|
-
|
425
|
-
|
426
|
-
raise type(e)(str(earg0) + self._err_str(fn=self._response), *earg1) from None
|
454
|
+
raise type(e)(str(e) + "\n\t| Module details:" +
|
455
|
+
self._err_str(fn=self._response)).with_traceback(sys.exc_info()[2])
|
427
456
|
|
428
457
|
try:
|
429
458
|
# If no output signals are given, but are required, try to initialize them here
|
@@ -441,9 +470,8 @@ class Module(ABC, RegisteredClass):
|
|
441
470
|
# Check if signals match _sensitivity() signature
|
442
471
|
_check_function_signature(self._sensitivity, self.sig_out)
|
443
472
|
except Exception as e:
|
444
|
-
|
445
|
-
|
446
|
-
raise type(e)(str(earg0) + self._err_str(fn=self._sensitivity), *earg1) from None
|
473
|
+
raise type(e)(str(e) + "\n\t| Module details:" +
|
474
|
+
self._err_str(fn=self._sensitivity)).with_traceback(sys.exc_info()[2])
|
447
475
|
|
448
476
|
def response(self):
|
449
477
|
""" Calculate the response from sig_in and output this to sig_out """
|
@@ -461,9 +489,9 @@ class Module(ABC, RegisteredClass):
|
|
461
489
|
self.sig_out[i].state = val
|
462
490
|
return self
|
463
491
|
except Exception as e:
|
464
|
-
|
465
|
-
|
466
|
-
|
492
|
+
# https://stackoverflow.com/questions/6062576/adding-information-to-an-exception
|
493
|
+
raise type(e)(str(e) + "\n\t| Above error was raised when calling response(). Module details:" +
|
494
|
+
self._err_str(fn=self._response)).with_traceback(sys.exc_info()[2])
|
467
495
|
|
468
496
|
def __call__(self):
|
469
497
|
return self.response()
|
@@ -494,9 +522,8 @@ class Module(ABC, RegisteredClass):
|
|
494
522
|
|
495
523
|
return self
|
496
524
|
except Exception as e:
|
497
|
-
|
498
|
-
|
499
|
-
raise type(e)("sensitivity() - " + str(earg0) + self._err_str(fn=self._sensitivity), *earg1) from e
|
525
|
+
raise type(e)(str(e) + "\n\t| Above error was raised when calling sensitivity(). Module details:" +
|
526
|
+
self._err_str(fn=self._sensitivity)).with_traceback(sys.exc_info()[2])
|
500
527
|
|
501
528
|
def reset(self):
|
502
529
|
""" Reset the state of the sensitivities (they are set to zero or to None) """
|
@@ -506,9 +533,8 @@ class Module(ABC, RegisteredClass):
|
|
506
533
|
self._reset()
|
507
534
|
return self
|
508
535
|
except Exception as e:
|
509
|
-
|
510
|
-
|
511
|
-
raise type(e)("reset() - " + str(earg0) + self._err_str(fn=self._reset), *earg1) from e
|
536
|
+
raise type(e)(str(e) + "\n\t| Above error was raised when calling reset(). Module details:" +
|
537
|
+
self._err_str(fn=self._response)).with_traceback(sys.exc_info()[2])
|
512
538
|
|
513
539
|
# METHODS TO BE DEFINED BY USER
|
514
540
|
def _prepare(self, *args, **kwargs):
|
@@ -542,49 +568,52 @@ class Network(Module):
|
|
542
568
|
"""
|
543
569
|
def __init__(self, *args, print_timing=False):
|
544
570
|
self._init_loc = get_init_str()
|
545
|
-
try:
|
546
|
-
# Obtain the internal blocks
|
547
|
-
self.mods = _parse_to_list(*args)
|
548
|
-
|
549
|
-
# Check if the blocks are initialized, else create them
|
550
|
-
for i, b in enumerate(self.mods):
|
551
|
-
if isinstance(b, dict):
|
552
|
-
exclude_keys = ['type']
|
553
|
-
b_ex = {k: b[k] for k in set(list(b.keys())) - set(exclude_keys)}
|
554
|
-
self.mods[i] = Module.create(b['type'], **b_ex)
|
555
|
-
|
556
|
-
# Check validity of modules
|
557
|
-
[_check_valid_module(m) for m in self.mods]
|
558
|
-
|
559
|
-
# Gather all the input and output signals of the internal blocks
|
560
|
-
all_in = set()
|
561
|
-
all_out = set()
|
562
|
-
[all_in.update(b.sig_in) for b in self.mods]
|
563
|
-
[all_out.update(b.sig_out) for b in self.mods]
|
564
|
-
in_unique = all_in - all_out
|
565
|
-
|
566
|
-
# Initialize the parent module, with correct inputs and outputs
|
567
|
-
super().__init__(list(in_unique), list(all_out))
|
568
|
-
|
569
|
-
self.print_timing = print_timing
|
570
|
-
except Exception as e:
|
571
|
-
earg0 = e.args[0] if len(e.args) > 0 else ''
|
572
|
-
earg1 = e.args[1:] if len(e.args) > 1 else ()
|
573
|
-
raise type(e)(str(earg0) + self._err_str(add_signal=False), *earg1) from None
|
574
571
|
|
575
|
-
|
572
|
+
# Obtain the internal blocks
|
573
|
+
self.mods = _parse_to_list(*args)
|
574
|
+
|
575
|
+
# Check if the blocks are initialized, else create them
|
576
|
+
for i, b in enumerate(self.mods):
|
577
|
+
if isinstance(b, dict):
|
578
|
+
exclude_keys = ['type']
|
579
|
+
b_ex = {k: b[k] for k in set(list(b.keys())) - set(exclude_keys)}
|
580
|
+
self.mods[i] = Module.create(b['type'], **b_ex)
|
581
|
+
|
582
|
+
# Check validity of modules
|
583
|
+
for m in self.mods:
|
584
|
+
if not _is_valid_module(m):
|
585
|
+
raise TypeError(f"Argument is not a valid Module, type=\'{type(mod).__name__}\'.")
|
586
|
+
|
587
|
+
# Gather all the input and output signals of the internal blocks
|
588
|
+
all_in = set()
|
589
|
+
all_out = set()
|
590
|
+
[all_in.update(b.sig_in) for b in self.mods]
|
591
|
+
[all_out.update(b.sig_out) for b in self.mods]
|
592
|
+
in_unique = all_in - all_out
|
593
|
+
|
594
|
+
# Initialize the parent module, with correct inputs and outputs
|
595
|
+
super().__init__(list(in_unique), list(all_out))
|
596
|
+
|
597
|
+
self.print_timing = print_timing
|
598
|
+
|
599
|
+
def timefn(self, fn, prefix='Evaluation'):
|
576
600
|
start_t = time.time()
|
577
601
|
fn()
|
578
|
-
|
602
|
+
duration = time.time() - start_t
|
603
|
+
if duration > .5:
|
604
|
+
print(f"{prefix} {fn} took {time.time() - start_t} s")
|
579
605
|
|
580
606
|
def response(self):
|
581
607
|
if self.print_timing:
|
582
|
-
[self.timefn(b.response) for b in self.mods]
|
608
|
+
[self.timefn(b.response, prefix='Response') for b in self.mods]
|
583
609
|
else:
|
584
610
|
[b.response() for b in self.mods]
|
585
611
|
|
586
612
|
def sensitivity(self):
|
587
|
-
|
613
|
+
if self.print_timing:
|
614
|
+
[self.timefn(b.sensitivity, 'Sensitivity') for b in reversed(self.mods)]
|
615
|
+
else:
|
616
|
+
[b.sensitivity() for b in reversed(self.mods)]
|
588
617
|
|
589
618
|
def reset(self):
|
590
619
|
[b.reset() for b in reversed(self.mods)]
|
@@ -611,12 +640,9 @@ class Network(Module):
|
|
611
640
|
modlist = _parse_to_list(*newmods)
|
612
641
|
|
613
642
|
# Check if the blocks are initialized, else create them
|
614
|
-
for i,
|
615
|
-
|
616
|
-
|
617
|
-
except Exception as e:
|
618
|
-
raise type(e)("append() - Trying to append invalid module " + str(e.args[0])
|
619
|
-
+ self._err_str(add_signal=False), *e.args[1:]) from None
|
643
|
+
for i, m in enumerate(modlist):
|
644
|
+
if not _is_valid_module(m):
|
645
|
+
raise TypeError(f"Argument #{i} is not a valid module, type=\'{type(mod).__name__}\'.")
|
620
646
|
|
621
647
|
# Obtain the internal blocks
|
622
648
|
self.mods.extend(modlist)
|
@@ -624,25 +650,11 @@ class Network(Module):
|
|
624
650
|
# Gather all the input and output signals of the internal blocks
|
625
651
|
all_in = set()
|
626
652
|
all_out = set()
|
627
|
-
[all_in.update(
|
628
|
-
[all_out.update(
|
653
|
+
[all_in.update(m.sig_in) for m in self.mods]
|
654
|
+
[all_out.update(m.sig_out) for m in self.mods]
|
629
655
|
in_unique = all_in - all_out
|
630
656
|
|
631
657
|
self.sig_in = _parse_to_list(in_unique)
|
632
|
-
try:
|
633
|
-
[_check_valid_signal(s) for s in self.sig_in]
|
634
|
-
except Exception as e:
|
635
|
-
earg0 = e.args[0] if len(e.args) > 0 else ''
|
636
|
-
earg1 = e.args[1:] if len(e.args) > 1 else ()
|
637
|
-
raise type(e)("append() - Invalid input signals " + str(earg0)
|
638
|
-
+ self._err_str(add_signal=False), *earg1) from None
|
639
658
|
self.sig_out = _parse_to_list(all_out)
|
640
|
-
try:
|
641
|
-
[_check_valid_signal(s) for s in self.sig_out]
|
642
|
-
except Exception as e:
|
643
|
-
earg0 = e.args[0] if len(e.args) > 0 else ''
|
644
|
-
earg1 = e.args[1:] if len(e.args) > 1 else ()
|
645
|
-
raise type(e)("append() - Invalid output signals " + str(earg0)
|
646
|
-
+ self._err_str(add_signal=False), *earg1) from None
|
647
659
|
|
648
660
|
return modlist[-1].sig_out[0] if len(modlist[-1].sig_out) == 1 else modlist[-1].sig_out # Returns the output signal
|
@@ -0,0 +1,209 @@
|
|
1
|
+
import warnings
|
2
|
+
import abc
|
3
|
+
import numpy as np
|
4
|
+
import scipy.special as spsp
|
5
|
+
from pymoto import Module
|
6
|
+
|
7
|
+
|
8
|
+
class AggActiveSet:
|
9
|
+
""" Determine active set by discarding lower or upper fraction of a set of values
|
10
|
+
|
11
|
+
Args:
|
12
|
+
lower_rel: Fraction of values closest to minimum to discard (based on value)
|
13
|
+
upper_rel: Fraction of values closest to maximum to discard (based on value)
|
14
|
+
lower_amt: Fraction of lowest values to discard (based on sorting)
|
15
|
+
upper_amt: Fraction of highest values to discard (based on sorting)
|
16
|
+
"""
|
17
|
+
def __init__(self, lower_rel=0.0, upper_rel=1.0, lower_amt=0.0, upper_amt=1.0):
|
18
|
+
assert upper_rel > lower_rel, "Upper must be larger than lower to keep values in the set"
|
19
|
+
assert upper_amt > lower_amt, "Upper must be larger than lower to keep values in the set"
|
20
|
+
self.lower_rel, self.upper_rel = lower_rel, upper_rel
|
21
|
+
self.lower_amt, self.upper_amt = lower_amt, upper_amt
|
22
|
+
|
23
|
+
def __call__(self, x):
|
24
|
+
""" Generate an active set for given array """
|
25
|
+
xmin, xmax = np.min(x), np.max(x)
|
26
|
+
if (xmax - xmin) == 0: # All values are the same, so no active set can be taken
|
27
|
+
return Ellipsis
|
28
|
+
|
29
|
+
sel = np.ones_like(x, dtype=bool)
|
30
|
+
|
31
|
+
# Select based on value
|
32
|
+
xrel = (x - xmin) / (xmax - xmin) # Normalize between 0 and 1
|
33
|
+
if self.lower_rel > 0:
|
34
|
+
sel = np.logical_and(sel, xrel >= self.lower_rel)
|
35
|
+
if self.upper_rel < 1:
|
36
|
+
sel = np.logical_and(sel, xrel <= self.upper_rel)
|
37
|
+
|
38
|
+
# Remove lowest and highest N values
|
39
|
+
i_sort = np.argsort(x)
|
40
|
+
if self.lower_amt > 0:
|
41
|
+
n_lower_amt = int(x.size * self.lower_amt)
|
42
|
+
sel[i_sort[:n_lower_amt]] = False
|
43
|
+
|
44
|
+
if self.upper_amt < 1:
|
45
|
+
n_upper_amt = int(x.size * (1 - self.upper_amt))
|
46
|
+
sel[i_sort[-n_upper_amt:]] = False
|
47
|
+
|
48
|
+
return sel
|
49
|
+
|
50
|
+
|
51
|
+
class AggScaling:
|
52
|
+
""" Scaling strategy to absolute minimum or maximum
|
53
|
+
|
54
|
+
Args:
|
55
|
+
which: Scale to `min` or `max`
|
56
|
+
damping(optional): Damping factor between [0, 1), for a value of 0.0 the aggregation approximation is corrected
|
57
|
+
to the exact maximum or minimum of the input set
|
58
|
+
"""
|
59
|
+
def __init__(self, which: str, damping=0.0):
|
60
|
+
self.damping = damping
|
61
|
+
if which.lower() == 'min':
|
62
|
+
self.f = np.min
|
63
|
+
elif which.lower() == 'max':
|
64
|
+
self.f = np.max
|
65
|
+
else:
|
66
|
+
raise ValueError("Argument `which` can only be 'min' or 'max'")
|
67
|
+
self.sf = None
|
68
|
+
|
69
|
+
def __call__(self, x, fx_approx):
|
70
|
+
""" Determine scaling factor
|
71
|
+
|
72
|
+
Args:
|
73
|
+
x: Set of values
|
74
|
+
fx_approx: Approximated minimum / maximum
|
75
|
+
|
76
|
+
Returns:
|
77
|
+
Scaling factor
|
78
|
+
"""
|
79
|
+
trueval = self.f(x)
|
80
|
+
scale = trueval / fx_approx
|
81
|
+
if self.sf is None:
|
82
|
+
self.sf = scale
|
83
|
+
else:
|
84
|
+
self.sf = self.damping * self.sf + (1 - self.damping) * scale
|
85
|
+
return self.sf
|
86
|
+
|
87
|
+
|
88
|
+
class Aggregation(Module):
|
89
|
+
""" Generic Aggregation module (cannot be used directly, but can only be used as superclass)
|
90
|
+
|
91
|
+
Keyword Args:
|
92
|
+
scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
|
93
|
+
active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
|
94
|
+
"""
|
95
|
+
def _prepare(self, scaling: AggScaling = None, active_set: AggActiveSet = None):
|
96
|
+
# This prepare function MUST be called in the _prepare function of sub-classes
|
97
|
+
self.scaling = scaling
|
98
|
+
self.active_set = active_set
|
99
|
+
self.sf = 1.0
|
100
|
+
|
101
|
+
@abc.abstractmethod
|
102
|
+
def aggregation_function(self, x):
|
103
|
+
""" Calculates f(x) """
|
104
|
+
raise NotImplementedError()
|
105
|
+
|
106
|
+
@abc.abstractmethod
|
107
|
+
def aggregation_derivative(self, x):
|
108
|
+
"""" Calculates df(x) / dx """
|
109
|
+
raise NotImplementedError()
|
110
|
+
|
111
|
+
def _response(self, x):
|
112
|
+
# Determine active set
|
113
|
+
if self.active_set is not None:
|
114
|
+
self.select = self.active_set(x)
|
115
|
+
else:
|
116
|
+
self.select = Ellipsis
|
117
|
+
|
118
|
+
# Get aggregated value
|
119
|
+
xagg = self.aggregation_function(x[self.select])
|
120
|
+
|
121
|
+
# Scale
|
122
|
+
if self.scaling is not None:
|
123
|
+
self.sf = self.scaling(x[self.select], xagg)
|
124
|
+
return self.sf * xagg
|
125
|
+
|
126
|
+
def _sensitivity(self, dfdy):
|
127
|
+
x = self.sig_in[0].state
|
128
|
+
dydx = self.aggregation_derivative(x[self.select])
|
129
|
+
dx = np.zeros_like(x)
|
130
|
+
dx[self.select] += self.sf * dfdy * dydx
|
131
|
+
return dx
|
132
|
+
|
133
|
+
|
134
|
+
class PNorm(Aggregation):
|
135
|
+
r""" P-norm aggregration
|
136
|
+
|
137
|
+
:math:`S_p(x_1, x_2, \dotsc, x_n) = \left( \sum_i (|x_i|^p) \right)^{1/p}
|
138
|
+
|
139
|
+
Only valid for positive :math:`x_i` when approximating the minimum or maximum
|
140
|
+
|
141
|
+
Args:
|
142
|
+
p: Power of the p-norm. Approximate maximum for `p>0` and minimum for `p<0`
|
143
|
+
scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
|
144
|
+
active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
|
145
|
+
"""
|
146
|
+
def _prepare(self, p=2, scaling: AggScaling = None, active_set: AggActiveSet = None):
|
147
|
+
self.p = p
|
148
|
+
self.y = None
|
149
|
+
super()._prepare(scaling, active_set)
|
150
|
+
|
151
|
+
def aggregation_function(self, x):
|
152
|
+
if np.min(x) < 0:
|
153
|
+
warnings.warn("PNorm is only valid for positive x")
|
154
|
+
|
155
|
+
# Get p-norm
|
156
|
+
return np.sum(np.abs(x) ** self.p) ** (1/self.p)
|
157
|
+
|
158
|
+
def aggregation_derivative(self, x):
|
159
|
+
pval = np.sum(np.abs(x) ** self.p) ** (1 / self.p - 1)
|
160
|
+
return pval * np.sign(x) * np.abs(x)**(self.p - 1)
|
161
|
+
|
162
|
+
|
163
|
+
class SoftMinMax(Aggregation):
|
164
|
+
r""" Soft maximum/minimum function
|
165
|
+
|
166
|
+
:math:`S_a(x_1, x_2, \dotsc, x_n) = \frac{\sum_i (x_i \exp(a x_i))}{\sum_i (\exp(a x_i))}`
|
167
|
+
|
168
|
+
When using as maximum, it underestimates the maximum
|
169
|
+
It is exact however when :math:`x_1=x_2=\dotsc=x_n`
|
170
|
+
|
171
|
+
Args:
|
172
|
+
alpha: Scaling factor of the soft function. Approximate maximum for `alpha>0` and minimum for `alpha<0`
|
173
|
+
scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
|
174
|
+
active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
|
175
|
+
"""
|
176
|
+
def _prepare(self, alpha=1.0, scaling: AggScaling = None, active_set: AggActiveSet = None):
|
177
|
+
self.alpha = alpha
|
178
|
+
self.y = None
|
179
|
+
super()._prepare(scaling, active_set)
|
180
|
+
|
181
|
+
def aggregation_function(self, x):
|
182
|
+
self.y = np.sum(x * spsp.softmax(self.alpha * x))
|
183
|
+
return self.y
|
184
|
+
|
185
|
+
def aggregation_derivative(self, x):
|
186
|
+
return spsp.softmax(self.alpha * x) * (1 + self.alpha * (x - self.y))
|
187
|
+
|
188
|
+
|
189
|
+
class KSFunction(Aggregation):
|
190
|
+
r""" Kreisselmeier and Steinhauser function from 1979
|
191
|
+
|
192
|
+
:math:`S_\rho(x_1, x_2, \dotsc, x_n) = \frac{1}{\rho} \ln \left( \sum_i \exp(\rho x_i) \right)`
|
193
|
+
|
194
|
+
Args:
|
195
|
+
rho: Scaling factor of the KS function. Approximate maximum for `rho>0` and minimum for `rho<0`
|
196
|
+
scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
|
197
|
+
active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
|
198
|
+
"""
|
199
|
+
def _prepare(self, rho=1.0, scaling: AggScaling = None, active_set: AggActiveSet = None):
|
200
|
+
self.rho = rho
|
201
|
+
self.y = None
|
202
|
+
super()._prepare(scaling, active_set)
|
203
|
+
|
204
|
+
def aggregation_function(self, x):
|
205
|
+
return 1/self.rho * np.log(np.sum(np.exp(self.rho * x)))
|
206
|
+
|
207
|
+
def aggregation_derivative(self, x):
|
208
|
+
erx = np.exp(self.rho * x)
|
209
|
+
return erx / np.sum(erx)
|