pyMOTO 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyMOTO-1.2.1.dist-info → pyMOTO-1.4.0.dist-info}/METADATA +7 -8
- pyMOTO-1.4.0.dist-info/RECORD +29 -0
- {pyMOTO-1.2.1.dist-info → pyMOTO-1.4.0.dist-info}/WHEEL +1 -1
- pymoto/__init__.py +19 -13
- pymoto/common/domain.py +75 -0
- pymoto/common/dyadcarrier.py +33 -4
- pymoto/common/mma.py +83 -53
- pymoto/core_objects.py +117 -113
- pymoto/modules/aggregation.py +209 -0
- pymoto/modules/assembly.py +202 -41
- pymoto/modules/complex.py +3 -3
- pymoto/modules/filter.py +171 -24
- pymoto/modules/generic.py +12 -1
- pymoto/modules/io.py +22 -11
- pymoto/modules/linalg.py +24 -118
- pymoto/modules/scaling.py +4 -4
- pymoto/routines.py +32 -15
- pymoto/solvers/__init__.py +14 -0
- pymoto/solvers/auto_determine.py +108 -0
- pymoto/{common/solvers_dense.py → solvers/dense.py} +90 -70
- pymoto/solvers/iterative.py +361 -0
- pymoto/solvers/matrix_checks.py +56 -0
- pymoto/solvers/solvers.py +253 -0
- pymoto/{common/solvers_sparse.py → solvers/sparse.py} +41 -29
- pyMOTO-1.2.1.dist-info/RECORD +0 -24
- pymoto/common/solvers.py +0 -236
- {pyMOTO-1.2.1.dist-info → pyMOTO-1.4.0.dist-info}/LICENSE +0 -0
- {pyMOTO-1.2.1.dist-info → pyMOTO-1.4.0.dist-info}/top_level.txt +0 -0
- {pyMOTO-1.2.1.dist-info → pyMOTO-1.4.0.dist-info}/zip-safe +0 -0
pymoto/core_objects.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1
|
-
|
1
|
+
import sys
|
2
2
|
import warnings
|
3
3
|
import inspect
|
4
4
|
import time
|
5
|
-
|
5
|
+
import copy
|
6
|
+
from typing import Union, List, Any
|
6
7
|
from abc import ABC, abstractmethod
|
8
|
+
from .utils import _parse_to_list, _concatenate_to_array, _split_from_array
|
7
9
|
|
8
10
|
|
9
11
|
# Local helper functions
|
@@ -11,7 +13,7 @@ def err_fmt(*args):
|
|
11
13
|
""" Format error strings for locating Modules and Signals"""
|
12
14
|
err_str = ""
|
13
15
|
for a in args:
|
14
|
-
err_str += f"\n\t
|
16
|
+
err_str += f"\n\t| {a}"
|
15
17
|
return err_str
|
16
18
|
|
17
19
|
|
@@ -76,16 +78,20 @@ class Signal:
|
|
76
78
|
>> Signal(tag='x2')
|
77
79
|
|
78
80
|
"""
|
79
|
-
def __init__(self, tag: str = "", state: Any = None, sensitivity: Any = None):
|
81
|
+
def __init__(self, tag: str = "", state: Any = None, sensitivity: Any = None, min: Any = None, max: Any = None):
|
80
82
|
"""
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
83
|
+
Keyword Args:
|
84
|
+
tag: The name of the signal
|
85
|
+
state: The initialized state
|
86
|
+
sensitivity: The initialized sensitivity
|
87
|
+
min: Minimum allowed value
|
88
|
+
max: Maximum allowed value
|
85
89
|
"""
|
86
90
|
self.tag = tag
|
87
91
|
self.state = state
|
88
92
|
self.sensitivity = sensitivity
|
93
|
+
self.min = min
|
94
|
+
self.max = max
|
89
95
|
self.keep_alloc = sensitivity is not None
|
90
96
|
|
91
97
|
# Save error string to location where it is initialized
|
@@ -95,11 +101,12 @@ class Signal:
|
|
95
101
|
return err_fmt(f"Signal \'{self.tag}\', initialized in {self._init_loc}")
|
96
102
|
|
97
103
|
def add_sensitivity(self, ds: Any):
|
104
|
+
""" Add a new term to internal sensitivity """
|
98
105
|
try:
|
99
106
|
if ds is None:
|
100
107
|
return
|
101
108
|
if self.sensitivity is None:
|
102
|
-
self.sensitivity = ds
|
109
|
+
self.sensitivity = copy.deepcopy(ds)
|
103
110
|
else:
|
104
111
|
self.sensitivity += ds
|
105
112
|
return self
|
@@ -116,8 +123,12 @@ class Signal:
|
|
116
123
|
def reset(self, keep_alloc: bool = None):
|
117
124
|
""" Reset the sensitivities to zero or None
|
118
125
|
This must be called to clear internal memory of subsequent sensitivity calculations.
|
119
|
-
|
120
|
-
:
|
126
|
+
|
127
|
+
Args:
|
128
|
+
keep_alloc: Keep the sensitivity allocation intact?
|
129
|
+
|
130
|
+
Returns:
|
131
|
+
self
|
121
132
|
"""
|
122
133
|
if self.sensitivity is None:
|
123
134
|
return self
|
@@ -138,11 +149,34 @@ class Signal:
|
|
138
149
|
|
139
150
|
def __getitem__(self, item):
    """ Create a signal that refers to part of this signal's contents

    Args:
        item: Slice indices

    Returns:
        A :py:class:`SignalSlice` wrapping this signal and ``item``
    """
    sliced_signal = SignalSlice(self, item)
    return sliced_signal
|
145
160
|
|
161
|
+
def _state_summary(self):
    """ One-line description of the state; multi-line representations are shortened to 'first ... last' line """
    if self.state is None:
        full_msg = "empty state"
    else:
        full_msg = f"state {self.state}"
    pieces = full_msg.split('\n')
    if len(pieces) == 1:
        return pieces[0]
    return pieces[0] + ' ... ' + pieces[-1]

def __str__(self):
    """ Short human-readable description: tag and (abbreviated) state """
    return f"Signal \"{self.tag}\" with {self._state_summary()}"

def __repr__(self):
    """ Debug description: tag, (abbreviated) state, whether a sensitivity is present and the object id """
    summary = self._state_summary()
    sens_msg = 'non-empty sensitivity' if self.sensitivity is not None else 'empty sensitivity'
    return f"Signal \"{self.tag}\" with {summary} and {sens_msg} at {hex(id(self))}"
|
179
|
+
|
146
180
|
|
147
181
|
class SignalSlice(Signal):
|
148
182
|
""" Slice operator for a Signal
|
@@ -169,7 +203,8 @@ class SignalSlice(Signal):
|
|
169
203
|
return None if self.orig_signal.state is None else self.orig_signal.state[self.slice]
|
170
204
|
except Exception as e:
|
171
205
|
# Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
|
172
|
-
raise type(e)("SignalSlice.state (getter)" +
|
206
|
+
raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.state (getter). Signal details:" +
|
207
|
+
self._err_str()).with_traceback(sys.exc_info()[2])
|
173
208
|
|
174
209
|
@state.setter
|
175
210
|
def state(self, new_state):
|
@@ -177,7 +212,8 @@ class SignalSlice(Signal):
|
|
177
212
|
self.orig_signal.state[self.slice] = new_state
|
178
213
|
except Exception as e:
|
179
214
|
# Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
|
180
|
-
raise type(e)("SignalSlice.state (setter)" +
|
215
|
+
raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.state (setter). Signal details:" +
|
216
|
+
self._err_str()).with_traceback(sys.exc_info()[2])
|
181
217
|
|
182
218
|
@property
|
183
219
|
def sensitivity(self):
|
@@ -185,7 +221,8 @@ class SignalSlice(Signal):
|
|
185
221
|
return None if self.orig_signal.sensitivity is None else self.orig_signal.sensitivity[self.slice]
|
186
222
|
except Exception as e:
|
187
223
|
# Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
|
188
|
-
raise type(e)("SignalSlice.sensitivity (getter)" +
|
224
|
+
raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.sensitivity (getter). Signal details:" +
|
225
|
+
self._err_str()).with_traceback(sys.exc_info()[2])
|
189
226
|
|
190
227
|
@sensitivity.setter
|
191
228
|
def sensitivity(self, new_sens):
|
@@ -207,7 +244,8 @@ class SignalSlice(Signal):
|
|
207
244
|
self.orig_signal.sensitivity[self.slice] = new_sens
|
208
245
|
except Exception as e:
|
209
246
|
# Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
|
210
|
-
raise type(e)("SignalSlice.
|
247
|
+
raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.state (setter). Signal details:" +
|
248
|
+
self._err_str()).with_traceback(sys.exc_info()[2])
|
211
249
|
|
212
250
|
def reset(self, keep_alloc: bool = None):
|
213
251
|
""" Reset the sensitivities to zero or None
|
@@ -231,7 +269,7 @@ def make_signals(*args):
|
|
231
269
|
return ret
|
232
270
|
|
233
271
|
|
234
|
-
def
|
272
|
+
def _is_valid_signal(sig: Any):
|
235
273
|
""" Checks if the argument is a valid Signal object
|
236
274
|
:param sig: The object to check
|
237
275
|
:return: True if it is a valid Signal
|
@@ -240,10 +278,10 @@ def _check_valid_signal(sig: Any):
|
|
240
278
|
return True
|
241
279
|
if all([hasattr(sig, f) for f in ["state", "sensitivity", "add_sensitivity", "reset"]]):
|
242
280
|
return True
|
243
|
-
|
281
|
+
return False
|
244
282
|
|
245
283
|
|
246
|
-
def
|
284
|
+
def _is_valid_module(mod: Any):
|
247
285
|
""" Checks if the argument is a valid Module object
|
248
286
|
:param mod: The object to check
|
249
287
|
:return: True if it is a valid Module
|
@@ -252,7 +290,7 @@ def _check_valid_module(mod: Any):
|
|
252
290
|
return True
|
253
291
|
if hasattr(mod, "response") and hasattr(mod, "sensitivity") and hasattr(mod, "reset"):
|
254
292
|
return True
|
255
|
-
|
293
|
+
return False
|
256
294
|
|
257
295
|
|
258
296
|
def _check_function_signature(fn, signals):
|
@@ -370,60 +408,48 @@ class Module(ABC, RegisteredClass):
|
|
370
408
|
>> Module(sig_in=[inputs], sig_out=[outputs]
|
371
409
|
"""
|
372
410
|
|
373
|
-
def _err_str(self,
|
411
|
+
def _err_str(self, module_signature: bool = True, init: bool = True, fn=None):
|
374
412
|
str_list = []
|
375
|
-
|
376
|
-
|
377
|
-
if add_signal:
|
413
|
+
|
414
|
+
if module_signature:
|
378
415
|
inp_str = "Inputs: " + ", ".join([s.tag if hasattr(s, 'tag') else 'N/A' for s in self.sig_in]) if len(self.sig_in) > 0 else "No inputs"
|
379
416
|
out_str = "Outputs: " + ", ".join([s.tag if hasattr(s, 'tag') else 'N/A' for s in self.sig_out]) if len(self.sig_out) > 0 else "No outputs"
|
380
|
-
str_list.append(inp_str + " --> " + out_str)
|
417
|
+
str_list.append(f"Module \'{type(self).__name__}\'( " + inp_str + " ) --> " + out_str)
|
418
|
+
if init:
|
419
|
+
str_list.append(f"Used in {self._init_loc}")
|
381
420
|
if fn is not None:
|
382
421
|
name = f"{fn.__self__.__class__.__name__}.{fn.__name__}{inspect.signature(fn)}"
|
383
422
|
lineno = inspect.getsourcelines(fn)[1]
|
384
423
|
filename = inspect.getfile(fn)
|
385
|
-
str_list.append(f"
|
424
|
+
str_list.append(f"Implementation in File \"{filename}\", line {lineno}, in {name}")
|
386
425
|
return err_fmt(*str_list)
|
387
426
|
|
388
427
|
# flake8: noqa: C901
|
389
428
|
def __init__(self, sig_in: Union[Signal, List[Signal]] = None, sig_out: Union[Signal, List[Signal]] = None,
|
390
429
|
*args, **kwargs):
|
391
|
-
# TODO: Reduce complexity of this init
|
392
430
|
self._init_loc = get_init_str()
|
393
431
|
|
394
432
|
self.sig_in = _parse_to_list(sig_in)
|
395
433
|
self.sig_out = _parse_to_list(sig_out)
|
396
434
|
for i, s in enumerate(self.sig_in):
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
earg0 = e.args[0] if len(e.args) > 0 else ''
|
401
|
-
earg1 = e.args[1:] if len(e.args) > 1 else ()
|
402
|
-
raise type(e)(f"Invalid input signal #{i+1} - " + str(earg0) + self._err_str(), *earg1) from None
|
435
|
+
if not _is_valid_signal(s):
|
436
|
+
tag = f" (\'{s.tag}\')" if hasattr(s, 'tag') else ''
|
437
|
+
raise TypeError(f"Input {i}{tag} is not a valid signal, type=\'{type(s).__name__}\'.")
|
403
438
|
|
404
439
|
for i, s in enumerate(self.sig_out):
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
earg0 = e.args[0] if len(e.args) > 0 else ''
|
409
|
-
earg1 = e.args[1:] if len(e.args) > 1 else ()
|
410
|
-
raise type(e)(f"Invalid output signal #{i+1} - " + str(earg0) + self._err_str(), *earg1) from None
|
440
|
+
if not _is_valid_signal(s):
|
441
|
+
tag = f" (\'{s.tag}\')" if hasattr(s, 'tag') else ''
|
442
|
+
raise TypeError(f"Output {i}{tag} is not a valid signal, type=\'{type(s).__name__}\'.")
|
411
443
|
|
412
|
-
|
413
|
-
|
414
|
-
self._prepare(*args, **kwargs)
|
415
|
-
except Exception as e:
|
416
|
-
earg0 = e.args[0] if len(e.args) > 0 else ''
|
417
|
-
earg1 = e.args[1:] if len(e.args) > 1 else ()
|
418
|
-
raise type(e)("_prepare() - " + str(earg0) + self._err_str(fn=self._prepare), *earg1) from e
|
444
|
+
# Call preparation of submodule with remaining arguments
|
445
|
+
self._prepare(*args, **kwargs)
|
419
446
|
|
420
447
|
try:
|
421
448
|
# Check if the signals match _response() signature
|
422
449
|
_check_function_signature(self._response, self.sig_in)
|
423
450
|
except Exception as e:
|
424
|
-
|
425
|
-
|
426
|
-
raise type(e)(str(earg0) + self._err_str(fn=self._response), *earg1) from None
|
451
|
+
raise type(e)(str(e) + "\n\t| Module details:" +
|
452
|
+
self._err_str(fn=self._response)).with_traceback(sys.exc_info()[2])
|
427
453
|
|
428
454
|
try:
|
429
455
|
# If no output signals are given, but are required, try to initialize them here
|
@@ -441,9 +467,8 @@ class Module(ABC, RegisteredClass):
|
|
441
467
|
# Check if signals match _sensitivity() signature
|
442
468
|
_check_function_signature(self._sensitivity, self.sig_out)
|
443
469
|
except Exception as e:
|
444
|
-
|
445
|
-
|
446
|
-
raise type(e)(str(earg0) + self._err_str(fn=self._sensitivity), *earg1) from None
|
470
|
+
raise type(e)(str(e) + "\n\t| Module details:" +
|
471
|
+
self._err_str(fn=self._sensitivity)).with_traceback(sys.exc_info()[2])
|
447
472
|
|
448
473
|
def response(self):
|
449
474
|
""" Calculate the response from sig_in and output this to sig_out """
|
@@ -461,9 +486,9 @@ class Module(ABC, RegisteredClass):
|
|
461
486
|
self.sig_out[i].state = val
|
462
487
|
return self
|
463
488
|
except Exception as e:
|
464
|
-
|
465
|
-
|
466
|
-
|
489
|
+
# https://stackoverflow.com/questions/6062576/adding-information-to-an-exception
|
490
|
+
raise type(e)(str(e) + "\n\t| Above error was raised when calling response(). Module details:" +
|
491
|
+
self._err_str(fn=self._response)).with_traceback(sys.exc_info()[2])
|
467
492
|
|
468
493
|
def __call__(self):
|
469
494
|
return self.response()
|
@@ -494,9 +519,8 @@ class Module(ABC, RegisteredClass):
|
|
494
519
|
|
495
520
|
return self
|
496
521
|
except Exception as e:
|
497
|
-
|
498
|
-
|
499
|
-
raise type(e)("sensitivity() - " + str(earg0) + self._err_str(fn=self._sensitivity), *earg1) from e
|
522
|
+
raise type(e)(str(e) + "\n\t| Above error was raised when calling sensitivity(). Module details:" +
|
523
|
+
self._err_str(fn=self._sensitivity)).with_traceback(sys.exc_info()[2])
|
500
524
|
|
501
525
|
def reset(self):
|
502
526
|
""" Reset the state of the sensitivities (they are set to zero or to None) """
|
@@ -506,9 +530,8 @@ class Module(ABC, RegisteredClass):
|
|
506
530
|
self._reset()
|
507
531
|
return self
|
508
532
|
except Exception as e:
|
509
|
-
|
510
|
-
|
511
|
-
raise type(e)("reset() - " + str(earg0) + self._err_str(fn=self._reset), *earg1) from e
|
533
|
+
raise type(e)(str(e) + "\n\t| Above error was raised when calling reset(). Module details:" +
|
534
|
+
self._err_str(fn=self._response)).with_traceback(sys.exc_info()[2])
|
512
535
|
|
513
536
|
# METHODS TO BE DEFINED BY USER
|
514
537
|
def _prepare(self, *args, **kwargs):
|
@@ -542,35 +565,33 @@ class Network(Module):
|
|
542
565
|
"""
|
543
566
|
def __init__(self, *args, print_timing=False):
    """ Build a network from a sequence of modules

    Args:
        *args: Modules (in any form accepted by ``_parse_to_list``). An entry may also be a dict with a ``'type'``
          key; it is replaced by ``Module.create(type, **remaining_keys)``
        print_timing (optional): Stored in ``self.print_timing``

    Raises:
        TypeError: If one of the (created) entries is not a valid module
    """
    self._init_loc = get_init_str()

    # Obtain the internal blocks
    self.mods = _parse_to_list(*args)

    # Check if the blocks are initialized, else create them from their dict description
    for i, b in enumerate(self.mods):
        if isinstance(b, dict):
            b_ex = {k: v for k, v in b.items() if k != 'type'}
            self.mods[i] = Module.create(b['type'], **b_ex)

    # Check validity of modules (report the type of the offending object itself)
    for i, m in enumerate(self.mods):
        if not _is_valid_module(m):
            raise TypeError(f"Argument #{i} is not a valid Module, type=\'{type(m).__name__}\'.")

    # Gather all the input and output signals of the internal blocks
    all_in = set()
    all_out = set()
    for m in self.mods:
        all_in.update(m.sig_in)
        all_out.update(m.sig_out)
    in_unique = all_in - all_out  # Signals consumed inside the network, but not produced by any of its modules

    # Initialize the parent module, with correct inputs and outputs
    super().__init__(list(in_unique), list(all_out))

    self.print_timing = print_timing
|
574
595
|
|
575
596
|
def timefn(self, fn):
|
576
597
|
start_t = time.time()
|
@@ -611,12 +632,9 @@ class Network(Module):
|
|
611
632
|
modlist = _parse_to_list(*newmods)
|
612
633
|
|
613
634
|
# Check if the blocks are initialized, else create them
|
614
|
-
for i,
|
615
|
-
|
616
|
-
|
617
|
-
except Exception as e:
|
618
|
-
raise type(e)("append() - Trying to append invalid module " + str(e.args[0])
|
619
|
-
+ self._err_str(add_signal=False), *e.args[1:]) from None
|
635
|
+
for i, m in enumerate(modlist):
|
636
|
+
if not _is_valid_module(m):
|
637
|
+
raise TypeError(f"Argument #{i} is not a valid module, type=\'{type(mod).__name__}\'.")
|
620
638
|
|
621
639
|
# Obtain the internal blocks
|
622
640
|
self.mods.extend(modlist)
|
@@ -624,25 +642,11 @@ class Network(Module):
|
|
624
642
|
# Gather all the input and output signals of the internal blocks
|
625
643
|
all_in = set()
|
626
644
|
all_out = set()
|
627
|
-
[all_in.update(
|
628
|
-
[all_out.update(
|
645
|
+
[all_in.update(m.sig_in) for m in self.mods]
|
646
|
+
[all_out.update(m.sig_out) for m in self.mods]
|
629
647
|
in_unique = all_in - all_out
|
630
648
|
|
631
649
|
self.sig_in = _parse_to_list(in_unique)
|
632
|
-
try:
|
633
|
-
[_check_valid_signal(s) for s in self.sig_in]
|
634
|
-
except Exception as e:
|
635
|
-
earg0 = e.args[0] if len(e.args) > 0 else ''
|
636
|
-
earg1 = e.args[1:] if len(e.args) > 1 else ()
|
637
|
-
raise type(e)("append() - Invalid input signals " + str(earg0)
|
638
|
-
+ self._err_str(add_signal=False), *earg1) from None
|
639
650
|
self.sig_out = _parse_to_list(all_out)
|
640
|
-
try:
|
641
|
-
[_check_valid_signal(s) for s in self.sig_out]
|
642
|
-
except Exception as e:
|
643
|
-
earg0 = e.args[0] if len(e.args) > 0 else ''
|
644
|
-
earg1 = e.args[1:] if len(e.args) > 1 else ()
|
645
|
-
raise type(e)("append() - Invalid output signals " + str(earg0)
|
646
|
-
+ self._err_str(add_signal=False), *earg1) from None
|
647
651
|
|
648
652
|
return modlist[-1].sig_out[0] if len(modlist[-1].sig_out) == 1 else modlist[-1].sig_out # Returns the output signal
|
@@ -0,0 +1,209 @@
|
|
1
|
+
import warnings
|
2
|
+
import abc
|
3
|
+
import numpy as np
|
4
|
+
import scipy.special as spsp
|
5
|
+
from pymoto import Module
|
6
|
+
|
7
|
+
|
8
|
+
class AggActiveSet:
    """ Determine active set by discarding lower or upper fraction of a set of values

    The criteria are combined: a value stays in the active set only if it passes all of them.

    Args:
        lower_rel: Discard values whose normalized value ``(x - min) / (max - min)`` is below this threshold
        upper_rel: Discard values whose normalized value ``(x - min) / (max - min)`` is above this threshold
        lower_amt: Fraction of the number of values to discard, starting from the lowest (based on sorting)
        upper_amt: Fraction of the number of values to keep, counted from the lowest; the remaining
          ``1 - upper_amt`` fraction of highest values is discarded (based on sorting)
    """
    def __init__(self, lower_rel=0.0, upper_rel=1.0, lower_amt=0.0, upper_amt=1.0):
        assert upper_rel > lower_rel, "Upper must be larger than lower to keep values in the set"
        assert upper_amt > lower_amt, "Upper must be larger than lower to keep values in the set"
        self.lower_rel, self.upper_rel = lower_rel, upper_rel
        self.lower_amt, self.upper_amt = lower_amt, upper_amt

    @staticmethod
    def _count(size, fraction):
        """ Number of entries corresponding to ``fraction`` of ``size``, rounded down

        The product is first rounded to 9 decimals, so that e.g. ``10 * (1 - 0.8) = 1.9999999999999996`` counts
        as 2 instead of being truncated to 1 by floating-point error.
        """
        return int(round(size * fraction, 9))

    def __call__(self, x):
        """ Generate an active set for given array

        Args:
            x: Array of values

        Returns:
            Boolean mask with the same shape as ``x`` (``True`` = in the active set), or ``Ellipsis`` when all values
            are equal and no selection can be made
        """
        x = np.asarray(x)
        xmin, xmax = np.min(x), np.max(x)
        if (xmax - xmin) == 0:  # All values are the same, so no active set can be taken
            return Ellipsis

        sel = np.ones_like(x, dtype=bool)

        # Select based on value
        xrel = (x - xmin) / (xmax - xmin)  # Normalize between 0 and 1
        if self.lower_rel > 0:
            sel = np.logical_and(sel, xrel >= self.lower_rel)
        if self.upper_rel < 1:
            sel = np.logical_and(sel, xrel <= self.upper_rel)

        # Remove lowest and highest N values; sort the flattened array so multi-dimensional input is handled too
        i_sort = np.argsort(x, axis=None)
        if self.lower_amt > 0:
            n_lower_amt = self._count(x.size, self.lower_amt)
            sel.flat[i_sort[:n_lower_amt]] = False

        if self.upper_amt < 1:
            n_upper_amt = self._count(x.size, 1 - self.upper_amt)
            if n_upper_amt > 0:  # Guard: i_sort[-0:] would select (and thus discard) ALL values
                sel.flat[i_sort[-n_upper_amt:]] = False

        return sel
|
49
|
+
|
50
|
+
|
51
|
+
class AggScaling:
    """ Scaling strategy that corrects an aggregation approximation towards the exact minimum or maximum

    Args:
        which: Scale to `min` or `max` (case-insensitive)
        damping(optional): Damping factor between [0, 1). For a value of 0.0 the aggregation approximation is
          corrected to the exact maximum or minimum of the input set; larger values blend in the previous factor
    """
    def __init__(self, which: str, damping=0.0):
        self.damping = damping
        reducers = {'min': np.min, 'max': np.max}
        key = which.lower()
        if key not in reducers:
            raise ValueError("Argument `which` can only be 'min' or 'max'")
        self.f = reducers[key]
        self.sf = None  # No scaling factor known until the first call

    def __call__(self, x, fx_approx):
        """ Determine scaling factor

        Args:
            x: Set of values
            fx_approx: Approximated minimum / maximum

        Returns:
            Scaling factor: ratio of exact to approximated value, damped with the previously returned factor
        """
        ratio = self.f(x) / fx_approx
        if self.sf is not None:
            ratio = self.damping * self.sf + (1 - self.damping) * ratio
        self.sf = ratio
        return self.sf
|
86
|
+
|
87
|
+
|
88
|
+
class Aggregation(Module):
    """ Generic Aggregation module (cannot be used directly, but can only be used as superclass)

    Sub-classes implement :py:meth:`aggregation_function` and :py:meth:`aggregation_derivative`; this class applies
    the optional active-set selection and scaling around them.

    Keyword Args:
        scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
        active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
    """
    def _prepare(self, scaling: AggScaling = None, active_set: AggActiveSet = None):
        # Sub-classes MUST forward to this method from their own _prepare
        self.scaling = scaling
        self.active_set = active_set
        self.sf = 1.0  # Current scaling factor; remains 1.0 when no scaling strategy is given

    @abc.abstractmethod
    def aggregation_function(self, x):
        """ Calculates f(x) for the (active) values x """
        raise NotImplementedError()

    @abc.abstractmethod
    def aggregation_derivative(self, x):
        """ Calculates df(x) / dx for the (active) values x """
        raise NotImplementedError()

    def _response(self, x):
        # Active set; Ellipsis selects every entry
        self.select = Ellipsis if self.active_set is None else self.active_set(x)

        # Unscaled aggregated value of the active entries
        y_agg = self.aggregation_function(x[self.select])

        if self.scaling is not None:
            self.sf = self.scaling(x[self.select], y_agg)
        return self.sf * y_agg

    def _sensitivity(self, dfdy):
        x = self.sig_in[0].state
        grad_active = self.aggregation_derivative(x[self.select])
        # Entries outside the active set receive zero sensitivity
        dx = np.zeros_like(x)
        dx[self.select] += self.sf * dfdy * grad_active
        return dx
|
132
|
+
|
133
|
+
|
134
|
+
class PNorm(Aggregation):
    r""" P-norm aggregation

    :math:`S_p(x_1, x_2, \dotsc, x_n) = \left( \sum_i (|x_i|^p) \right)^{1/p}`

    Only valid for positive :math:`x_i` when approximating the minimum or maximum

    The norm is evaluated in the mathematically equivalent form
    :math:`S_p = m \left( \sum_i (|x_i| / m)^p \right)^{1/p}`, with :math:`m = \max_i |x_i|` for `p>0` and
    :math:`m = \min_i |x_i|` for `p<0`, so that large powers do not overflow (or underflow).

    Args:
        p: Power of the p-norm. Approximate maximum for `p>0` and minimum for `p<0`
        scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
        active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
    """
    def _prepare(self, p=2, scaling: AggScaling = None, active_set: AggActiveSet = None):
        self.p = p
        self.y = None
        super()._prepare(scaling, active_set)

    def _norm_reference(self, absx):
        """ Magnitude used to normalize |x| before exponentiation (1.0 in degenerate cases: zero or non-finite) """
        ref = np.max(absx) if self.p > 0 else np.min(absx)
        if not np.isfinite(ref) or ref == 0:
            return 1.0  # Falls back to the plain, unscaled formula
        return ref

    def aggregation_function(self, x):
        if np.min(x) < 0:
            warnings.warn("PNorm is only valid for positive x")

        # Get p-norm: (sum |x|^p)^(1/p) = ref * (sum (|x|/ref)^p)^(1/p)
        absx = np.abs(x)
        ref = self._norm_reference(absx)
        return ref * np.sum((absx / ref) ** self.p) ** (1 / self.p)

    def aggregation_derivative(self, x):
        # (sum |x|^p)^(1/p - 1) * sign(x) |x|^(p-1); the reference magnitude cancels out of this expression
        absx = np.abs(x)
        r = absx / self._norm_reference(absx)
        pval = np.sum(r ** self.p) ** (1 / self.p - 1)
        return pval * np.sign(x) * r ** (self.p - 1)
|
161
|
+
|
162
|
+
|
163
|
+
class SoftMinMax(Aggregation):
    r""" Soft maximum/minimum function

    :math:`S_a(x_1, x_2, \dotsc, x_n) = \frac{\sum_i (x_i \exp(a x_i))}{\sum_i (\exp(a x_i))}`

    When used as a maximum, it underestimates the maximum.
    It is exact, however, when :math:`x_1=x_2=\dotsc=x_n`

    Args:
        alpha: Scaling factor of the soft function. Approximate maximum for `alpha>0` and minimum for `alpha<0`
        scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
        active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
    """
    def _prepare(self, alpha=1.0, scaling: AggScaling = None, active_set: AggActiveSet = None):
        self.alpha = alpha
        self.y = None  # Last aggregated value; reused by aggregation_derivative
        super()._prepare(scaling, active_set)

    def aggregation_function(self, x):
        weights = spsp.softmax(self.alpha * x)
        self.y = np.sum(x * weights)
        return self.y

    def aggregation_derivative(self, x):
        # d/dx_i of sum_j x_j w_j = w_i * (1 + alpha * (x_i - y)), with w the softmax weights
        weights = spsp.softmax(self.alpha * x)
        return weights * (1 + self.alpha * (x - self.y))
|
187
|
+
|
188
|
+
|
189
|
+
class KSFunction(Aggregation):
    r""" Kreisselmeier and Steinhauser function from 1979

    :math:`S_\rho(x_1, x_2, \dotsc, x_n) = \frac{1}{\rho} \ln \left( \sum_i \exp(\rho x_i) \right)`

    Evaluated with the log-sum-exp trick, so large values of :math:`\rho x_i` do not overflow

    Args:
        rho: Scaling factor of the KS function. Approximate maximum for `rho>0` and minimum for `rho<0`
        scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
        active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
    """
    def _prepare(self, rho=1.0, scaling: AggScaling = None, active_set: AggActiveSet = None):
        self.rho = rho
        self.y = None
        super()._prepare(scaling, active_set)

    def aggregation_function(self, x):
        # Equivalent to 1/rho * log(sum(exp(rho * x))), without overflowing exp()
        return spsp.logsumexp(self.rho * x) / self.rho

    def aggregation_derivative(self, x):
        # exp(rho x_i) / sum_j exp(rho x_j), i.e. the (numerically stable) softmax of rho * x
        return spsp.softmax(self.rho * x)
|