pyMOTO 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymoto/core_objects.py CHANGED
@@ -1,9 +1,11 @@
1
- from typing import Union, List, Any
1
+ import sys
2
2
  import warnings
3
3
  import inspect
4
4
  import time
5
- from .utils import _parse_to_list, _concatenate_to_array, _split_from_array
5
+ import copy
6
+ from typing import Union, List, Any
6
7
  from abc import ABC, abstractmethod
8
+ from .utils import _parse_to_list, _concatenate_to_array, _split_from_array
7
9
 
8
10
 
9
11
  # Local helper functions
@@ -11,7 +13,7 @@ def err_fmt(*args):
11
13
  """ Format error strings for locating Modules and Signals"""
12
14
  err_str = ""
13
15
  for a in args:
14
- err_str += f"\n\t[ {a} ]"
16
+ err_str += f"\n\t| {a}"
15
17
  return err_str
16
18
 
17
19
 
@@ -76,16 +78,20 @@ class Signal:
76
78
  >> Signal(tag='x2')
77
79
 
78
80
  """
79
- def __init__(self, tag: str = "", state: Any = None, sensitivity: Any = None):
81
+ def __init__(self, tag: str = "", state: Any = None, sensitivity: Any = None, min: Any = None, max: Any = None):
80
82
  """
81
-
82
- :param tag: The name of the signal (string)
83
- :param state: The initialized state (optional)
84
- :param sensitivity: The initialized sensitivity (optional)
83
+ Keyword Args:
84
+ tag: The name of the signal
85
+ state: The initialized state
86
+ sensitivity: The initialized sensitivity
87
+ min: Minimum allowed value
88
+ max: Maximum allowed value
85
89
  """
86
90
  self.tag = tag
87
91
  self.state = state
88
92
  self.sensitivity = sensitivity
93
+ self.min = min
94
+ self.max = max
89
95
  self.keep_alloc = sensitivity is not None
90
96
 
91
97
  # Save error string to location where it is initialized
@@ -95,11 +101,12 @@ class Signal:
95
101
  return err_fmt(f"Signal \'{self.tag}\', initialized in {self._init_loc}")
96
102
 
97
103
  def add_sensitivity(self, ds: Any):
104
+ """ Add a new term to internal sensitivity """
98
105
  try:
99
106
  if ds is None:
100
107
  return
101
108
  if self.sensitivity is None:
102
- self.sensitivity = ds
109
+ self.sensitivity = copy.deepcopy(ds)
103
110
  else:
104
111
  self.sensitivity += ds
105
112
  return self
@@ -116,8 +123,12 @@ class Signal:
116
123
  def reset(self, keep_alloc: bool = None):
117
124
  """ Reset the sensitivities to zero or None
118
125
  This must be called to clear internal memory of subsequent sensitivity calculations.
119
- :param keep_alloc: Keep the sensitivity allocation intact?
120
- :return: self
126
+
127
+ Args:
128
+ keep_alloc: Keep the sensitivity allocation intact?
129
+
130
+ Returns:
131
+ self
121
132
  """
122
133
  if self.sensitivity is None:
123
134
  return self
@@ -138,11 +149,34 @@ class Signal:
138
149
 
139
150
  def __getitem__(self, item):
140
151
  """ Obtain a sliced signal, for using its partial contents.
141
- :param item: Slice indices
142
- :return: Sliced signal (SignalSlice)
152
+
153
+ Args:
154
+ item: Slice indices
155
+
156
+ Returns:
157
+ Sliced signal (SignalSlice)
143
158
  """
144
159
  return SignalSlice(self, item)
145
160
 
161
+ def __str__(self):
162
+ state_msg = f"state {self.state}" if self.state is not None else "empty state"
163
+ state_msg = state_msg.split('\n')
164
+ if len(state_msg) > 1:
165
+ state_msg = state_msg[0] + ' ... ' + state_msg[-1]
166
+ else:
167
+ state_msg = state_msg[0]
168
+ return f"Signal \"{self.tag}\" with {state_msg}"
169
+
170
+ def __repr__(self):
171
+ state_msg = f"state {self.state}" if self.state is not None else "empty state"
172
+ state_msg = state_msg.split('\n')
173
+ if len(state_msg) > 1:
174
+ state_msg = state_msg[0] + ' ... ' + state_msg[-1]
175
+ else:
176
+ state_msg = state_msg[0]
177
+ sens_msg = 'empty sensitivity' if self.sensitivity is None else 'non-empty sensitivity'
178
+ return f"Signal \"{self.tag}\" with {state_msg} and {sens_msg} at {hex(id(self))}"
179
+
146
180
 
147
181
  class SignalSlice(Signal):
148
182
  """ Slice operator for a Signal
@@ -169,7 +203,8 @@ class SignalSlice(Signal):
169
203
  return None if self.orig_signal.state is None else self.orig_signal.state[self.slice]
170
204
  except Exception as e:
171
205
  # Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
172
- raise type(e)("SignalSlice.state (getter)" + self._err_str()) from e
206
+ raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.state (getter). Signal details:" +
207
+ self._err_str()).with_traceback(sys.exc_info()[2])
173
208
 
174
209
  @state.setter
175
210
  def state(self, new_state):
@@ -177,7 +212,8 @@ class SignalSlice(Signal):
177
212
  self.orig_signal.state[self.slice] = new_state
178
213
  except Exception as e:
179
214
  # Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
180
- raise type(e)("SignalSlice.state (setter)" + self._err_str()) from e
215
+ raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.state (setter). Signal details:" +
216
+ self._err_str()).with_traceback(sys.exc_info()[2])
181
217
 
182
218
  @property
183
219
  def sensitivity(self):
@@ -185,7 +221,8 @@ class SignalSlice(Signal):
185
221
  return None if self.orig_signal.sensitivity is None else self.orig_signal.sensitivity[self.slice]
186
222
  except Exception as e:
187
223
  # Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
188
- raise type(e)("SignalSlice.sensitivity (getter)" + self._err_str()) from e
224
+ raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.sensitivity (getter). Signal details:" +
225
+ self._err_str()).with_traceback(sys.exc_info()[2])
189
226
 
190
227
  @sensitivity.setter
191
228
  def sensitivity(self, new_sens):
@@ -207,7 +244,8 @@ class SignalSlice(Signal):
207
244
  self.orig_signal.sensitivity[self.slice] = new_sens
208
245
  except Exception as e:
209
246
  # Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
210
- raise type(e)("SignalSlice.sensitivity (setter)" + self._err_str()) from e
247
+ raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.sensitivity (setter). Signal details:" +
248
+ self._err_str()).with_traceback(sys.exc_info()[2])
211
249
 
212
250
  def reset(self, keep_alloc: bool = None):
213
251
  """ Reset the sensitivities to zero or None
@@ -231,7 +269,7 @@ def make_signals(*args):
231
269
  return ret
232
270
 
233
271
 
234
- def _check_valid_signal(sig: Any):
272
+ def _is_valid_signal(sig: Any):
235
273
  """ Checks if the argument is a valid Signal object
236
274
  :param sig: The object to check
237
275
  :return: True if it is a valid Signal
@@ -240,10 +278,10 @@ def _check_valid_signal(sig: Any):
240
278
  return True
241
279
  if all([hasattr(sig, f) for f in ["state", "sensitivity", "add_sensitivity", "reset"]]):
242
280
  return True
243
- raise TypeError(f"Given argument with type \'{type(sig).__name__}\' is not a valid Signal")
281
+ return False
244
282
 
245
283
 
246
- def _check_valid_module(mod: Any):
284
+ def _is_valid_module(mod: Any):
247
285
  """ Checks if the argument is a valid Module object
248
286
  :param mod: The object to check
249
287
  :return: True if it is a valid Module
@@ -252,7 +290,7 @@ def _check_valid_module(mod: Any):
252
290
  return True
253
291
  if hasattr(mod, "response") and hasattr(mod, "sensitivity") and hasattr(mod, "reset"):
254
292
  return True
255
- raise TypeError(f"Given argument with type \'{type(mod).__name__}\' is not a valid Module")
293
+ return False
256
294
 
257
295
 
258
296
  def _check_function_signature(fn, signals):
@@ -370,60 +408,48 @@ class Module(ABC, RegisteredClass):
370
408
  >> Module(sig_in=[inputs], sig_out=[outputs]
371
409
  """
372
410
 
373
- def _err_str(self, init: bool = True, add_signal: bool = True, fn=None):
411
+ def _err_str(self, module_signature: bool = True, init: bool = True, fn=None):
374
412
  str_list = []
375
- if init:
376
- str_list.append(f"Module \'{type(self).__name__}\', initialized in {self._init_loc}")
377
- if add_signal:
413
+
414
+ if module_signature:
378
415
  inp_str = "Inputs: " + ", ".join([s.tag if hasattr(s, 'tag') else 'N/A' for s in self.sig_in]) if len(self.sig_in) > 0 else "No inputs"
379
416
  out_str = "Outputs: " + ", ".join([s.tag if hasattr(s, 'tag') else 'N/A' for s in self.sig_out]) if len(self.sig_out) > 0 else "No outputs"
380
- str_list.append(inp_str + " --> " + out_str)
417
+ str_list.append(f"Module \'{type(self).__name__}\'( " + inp_str + " ) --> " + out_str)
418
+ if init:
419
+ str_list.append(f"Used in {self._init_loc}")
381
420
  if fn is not None:
382
421
  name = f"{fn.__self__.__class__.__name__}.{fn.__name__}{inspect.signature(fn)}"
383
422
  lineno = inspect.getsourcelines(fn)[1]
384
423
  filename = inspect.getfile(fn)
385
- str_list.append(f"Implemented in File \"{filename}\", line {lineno}, in {name}")
424
+ str_list.append(f"Implementation in File \"{filename}\", line {lineno}, in {name}")
386
425
  return err_fmt(*str_list)
387
426
 
388
427
  # flake8: noqa: C901
389
428
  def __init__(self, sig_in: Union[Signal, List[Signal]] = None, sig_out: Union[Signal, List[Signal]] = None,
390
429
  *args, **kwargs):
391
- # TODO: Reduce complexity of this init
392
430
  self._init_loc = get_init_str()
393
431
 
394
432
  self.sig_in = _parse_to_list(sig_in)
395
433
  self.sig_out = _parse_to_list(sig_out)
396
434
  for i, s in enumerate(self.sig_in):
397
- try:
398
- _check_valid_signal(s)
399
- except Exception as e:
400
- earg0 = e.args[0] if len(e.args) > 0 else ''
401
- earg1 = e.args[1:] if len(e.args) > 1 else ()
402
- raise type(e)(f"Invalid input signal #{i+1} - " + str(earg0) + self._err_str(), *earg1) from None
435
+ if not _is_valid_signal(s):
436
+ tag = f" (\'{s.tag}\')" if hasattr(s, 'tag') else ''
437
+ raise TypeError(f"Input {i}{tag} is not a valid signal, type=\'{type(s).__name__}\'.")
403
438
 
404
439
  for i, s in enumerate(self.sig_out):
405
- try:
406
- _check_valid_signal(s)
407
- except Exception as e:
408
- earg0 = e.args[0] if len(e.args) > 0 else ''
409
- earg1 = e.args[1:] if len(e.args) > 1 else ()
410
- raise type(e)(f"Invalid output signal #{i+1} - " + str(earg0) + self._err_str(), *earg1) from None
440
+ if not _is_valid_signal(s):
441
+ tag = f" (\'{s.tag}\')" if hasattr(s, 'tag') else ''
442
+ raise TypeError(f"Output {i}{tag} is not a valid signal, type=\'{type(s).__name__}\'.")
411
443
 
412
- try:
413
- # Call preparation of submodule with remaining arguments
414
- self._prepare(*args, **kwargs)
415
- except Exception as e:
416
- earg0 = e.args[0] if len(e.args) > 0 else ''
417
- earg1 = e.args[1:] if len(e.args) > 1 else ()
418
- raise type(e)("_prepare() - " + str(earg0) + self._err_str(fn=self._prepare), *earg1) from e
444
+ # Call preparation of submodule with remaining arguments
445
+ self._prepare(*args, **kwargs)
419
446
 
420
447
  try:
421
448
  # Check if the signals match _response() signature
422
449
  _check_function_signature(self._response, self.sig_in)
423
450
  except Exception as e:
424
- earg0 = e.args[0] if len(e.args) > 0 else ''
425
- earg1 = e.args[1:] if len(e.args) > 1 else ()
426
- raise type(e)(str(earg0) + self._err_str(fn=self._response), *earg1) from None
451
+ raise type(e)(str(e) + "\n\t| Module details:" +
452
+ self._err_str(fn=self._response)).with_traceback(sys.exc_info()[2])
427
453
 
428
454
  try:
429
455
  # If no output signals are given, but are required, try to initialize them here
@@ -441,9 +467,8 @@ class Module(ABC, RegisteredClass):
441
467
  # Check if signals match _sensitivity() signature
442
468
  _check_function_signature(self._sensitivity, self.sig_out)
443
469
  except Exception as e:
444
- earg0 = e.args[0] if len(e.args) > 0 else ''
445
- earg1 = e.args[1:] if len(e.args) > 1 else ()
446
- raise type(e)(str(earg0) + self._err_str(fn=self._sensitivity), *earg1) from None
470
+ raise type(e)(str(e) + "\n\t| Module details:" +
471
+ self._err_str(fn=self._sensitivity)).with_traceback(sys.exc_info()[2])
447
472
 
448
473
  def response(self):
449
474
  """ Calculate the response from sig_in and output this to sig_out """
@@ -461,9 +486,9 @@ class Module(ABC, RegisteredClass):
461
486
  self.sig_out[i].state = val
462
487
  return self
463
488
  except Exception as e:
464
- earg0 = e.args[0] if len(e.args) > 0 else ''
465
- earg1 = e.args[1:] if len(e.args) > 1 else ()
466
- raise type(e)("response() - " + str(earg0) + self._err_str(fn=self._response), *earg1) from e
489
+ # https://stackoverflow.com/questions/6062576/adding-information-to-an-exception
490
+ raise type(e)(str(e) + "\n\t| Above error was raised when calling response(). Module details:" +
491
+ self._err_str(fn=self._response)).with_traceback(sys.exc_info()[2])
467
492
 
468
493
  def __call__(self):
469
494
  return self.response()
@@ -494,9 +519,8 @@ class Module(ABC, RegisteredClass):
494
519
 
495
520
  return self
496
521
  except Exception as e:
497
- earg0 = e.args[0] if len(e.args) > 0 else ''
498
- earg1 = e.args[1:] if len(e.args) > 1 else ()
499
- raise type(e)("sensitivity() - " + str(earg0) + self._err_str(fn=self._sensitivity), *earg1) from e
522
+ raise type(e)(str(e) + "\n\t| Above error was raised when calling sensitivity(). Module details:" +
523
+ self._err_str(fn=self._sensitivity)).with_traceback(sys.exc_info()[2])
500
524
 
501
525
  def reset(self):
502
526
  """ Reset the state of the sensitivities (they are set to zero or to None) """
@@ -506,9 +530,8 @@ class Module(ABC, RegisteredClass):
506
530
  self._reset()
507
531
  return self
508
532
  except Exception as e:
509
- earg0 = e.args[0] if len(e.args) > 0 else ''
510
- earg1 = e.args[1:] if len(e.args) > 1 else ()
511
- raise type(e)("reset() - " + str(earg0) + self._err_str(fn=self._reset), *earg1) from e
533
+ raise type(e)(str(e) + "\n\t| Above error was raised when calling reset(). Module details:" +
534
+ self._err_str(fn=self._reset)).with_traceback(sys.exc_info()[2])
512
535
 
513
536
  # METHODS TO BE DEFINED BY USER
514
537
  def _prepare(self, *args, **kwargs):
@@ -542,35 +565,33 @@ class Network(Module):
542
565
  """
543
566
  def __init__(self, *args, print_timing=False):
544
567
  self._init_loc = get_init_str()
545
- try:
546
- # Obtain the internal blocks
547
- self.mods = _parse_to_list(*args)
548
-
549
- # Check if the blocks are initialized, else create them
550
- for i, b in enumerate(self.mods):
551
- if isinstance(b, dict):
552
- exclude_keys = ['type']
553
- b_ex = {k: b[k] for k in set(list(b.keys())) - set(exclude_keys)}
554
- self.mods[i] = Module.create(b['type'], **b_ex)
555
-
556
- # Check validity of modules
557
- [_check_valid_module(m) for m in self.mods]
558
-
559
- # Gather all the input and output signals of the internal blocks
560
- all_in = set()
561
- all_out = set()
562
- [all_in.update(b.sig_in) for b in self.mods]
563
- [all_out.update(b.sig_out) for b in self.mods]
564
- in_unique = all_in - all_out
565
-
566
- # Initialize the parent module, with correct inputs and outputs
567
- super().__init__(list(in_unique), list(all_out))
568
-
569
- self.print_timing = print_timing
570
- except Exception as e:
571
- earg0 = e.args[0] if len(e.args) > 0 else ''
572
- earg1 = e.args[1:] if len(e.args) > 1 else ()
573
- raise type(e)(str(earg0) + self._err_str(add_signal=False), *earg1) from None
568
+
569
+ # Obtain the internal blocks
570
+ self.mods = _parse_to_list(*args)
571
+
572
+ # Check if the blocks are initialized, else create them
573
+ for i, b in enumerate(self.mods):
574
+ if isinstance(b, dict):
575
+ exclude_keys = ['type']
576
+ b_ex = {k: b[k] for k in set(list(b.keys())) - set(exclude_keys)}
577
+ self.mods[i] = Module.create(b['type'], **b_ex)
578
+
579
+ # Check validity of modules
580
+ for m in self.mods:
581
+ if not _is_valid_module(m):
582
+ raise TypeError(f"Argument is not a valid Module, type=\'{type(m).__name__}\'.")
583
+
584
+ # Gather all the input and output signals of the internal blocks
585
+ all_in = set()
586
+ all_out = set()
587
+ [all_in.update(b.sig_in) for b in self.mods]
588
+ [all_out.update(b.sig_out) for b in self.mods]
589
+ in_unique = all_in - all_out
590
+
591
+ # Initialize the parent module, with correct inputs and outputs
592
+ super().__init__(list(in_unique), list(all_out))
593
+
594
+ self.print_timing = print_timing
574
595
 
575
596
  def timefn(self, fn):
576
597
  start_t = time.time()
@@ -611,12 +632,9 @@ class Network(Module):
611
632
  modlist = _parse_to_list(*newmods)
612
633
 
613
634
  # Check if the blocks are initialized, else create them
614
- for i, b in enumerate(modlist):
615
- try: # Check validity of modules
616
- _check_valid_module(b)
617
- except Exception as e:
618
- raise type(e)("append() - Trying to append invalid module " + str(e.args[0])
619
- + self._err_str(add_signal=False), *e.args[1:]) from None
635
+ for i, m in enumerate(modlist):
636
+ if not _is_valid_module(m):
637
+ raise TypeError(f"Argument #{i} is not a valid module, type=\'{type(m).__name__}\'.")
620
638
 
621
639
  # Obtain the internal blocks
622
640
  self.mods.extend(modlist)
@@ -624,25 +642,11 @@ class Network(Module):
624
642
  # Gather all the input and output signals of the internal blocks
625
643
  all_in = set()
626
644
  all_out = set()
627
- [all_in.update(b.sig_in) for b in self.mods]
628
- [all_out.update(b.sig_out) for b in self.mods]
645
+ [all_in.update(m.sig_in) for m in self.mods]
646
+ [all_out.update(m.sig_out) for m in self.mods]
629
647
  in_unique = all_in - all_out
630
648
 
631
649
  self.sig_in = _parse_to_list(in_unique)
632
- try:
633
- [_check_valid_signal(s) for s in self.sig_in]
634
- except Exception as e:
635
- earg0 = e.args[0] if len(e.args) > 0 else ''
636
- earg1 = e.args[1:] if len(e.args) > 1 else ()
637
- raise type(e)("append() - Invalid input signals " + str(earg0)
638
- + self._err_str(add_signal=False), *earg1) from None
639
650
  self.sig_out = _parse_to_list(all_out)
640
- try:
641
- [_check_valid_signal(s) for s in self.sig_out]
642
- except Exception as e:
643
- earg0 = e.args[0] if len(e.args) > 0 else ''
644
- earg1 = e.args[1:] if len(e.args) > 1 else ()
645
- raise type(e)("append() - Invalid output signals " + str(earg0)
646
- + self._err_str(add_signal=False), *earg1) from None
647
651
 
648
652
  return modlist[-1].sig_out[0] if len(modlist[-1].sig_out) == 1 else modlist[-1].sig_out # Returns the output signal
@@ -0,0 +1,209 @@
1
+ import warnings
2
+ import abc
3
+ import numpy as np
4
+ import scipy.special as spsp
5
+ from pymoto import Module
6
+
7
+
class AggActiveSet:
    """ Determine active set by discarding lower or upper fraction of a set of values

    Two criteria are combined: one based on the position of each value in the normalized range
    ``(x - min(x)) / (max(x) - min(x))`` (the ``*_rel`` arguments), and one based on the rank of each value in the
    sorted set (the ``*_amt`` arguments).

    Args:
        lower_rel: Discard values whose normalized value is below this threshold (based on value)
        upper_rel: Discard values whose normalized value is above this threshold (based on value)
        lower_amt: Fraction of lowest values to discard (based on sorting)
        upper_amt: Fraction of the sorted set to keep from below; the highest ``1 - upper_amt`` fraction of values is
            discarded (based on sorting)
    """
    def __init__(self, lower_rel=0.0, upper_rel=1.0, lower_amt=0.0, upper_amt=1.0):
        assert upper_rel > lower_rel, "Upper must be larger than lower to keep values in the set"
        assert upper_amt > lower_amt, "Upper must be larger than lower to keep values in the set"
        self.lower_rel, self.upper_rel = lower_rel, upper_rel
        self.lower_amt, self.upper_amt = lower_amt, upper_amt

    def __call__(self, x):
        """ Generate an active set for given array

        Args:
            x: Array of values (any shape; ranking is done over all entries)

        Returns:
            Boolean mask with the same shape as ``x`` (``True`` means active), or ``Ellipsis`` when all values are
            equal
        """
        xmin, xmax = np.min(x), np.max(x)
        if (xmax - xmin) == 0:  # All values are the same, so no active set can be taken
            return Ellipsis

        sel = np.ones_like(x, dtype=bool)

        # Select based on value
        xrel = (x - xmin) / (xmax - xmin)  # Normalize between 0 and 1
        if self.lower_rel > 0:
            sel = np.logical_and(sel, xrel >= self.lower_rel)
        if self.upper_rel < 1:
            sel = np.logical_and(sel, xrel <= self.upper_rel)

        # Remove lowest and highest N values. The flattened array is sorted and the mask is written through `.flat`
        # (both in C-order), so multi-dimensional input is ranked as a whole instead of per row.
        i_sort = np.argsort(x, axis=None)
        if self.lower_amt > 0:
            n_lower_amt = int(x.size * self.lower_amt)
            sel.flat[i_sort[:n_lower_amt]] = False

        if self.upper_amt < 1:
            n_upper_amt = int(x.size * (1 - self.upper_amt))
            # Slice from an explicit start index: `i_sort[-n:]` with n == 0 is `i_sort[0:]`, which would discard
            # every value instead of none
            sel.flat[i_sort[x.size - n_upper_amt:]] = False

        return sel
49
+
50
+
class AggScaling:
    """ Scaling strategy to absolute minimum or maximum

    The scaling factor is the ratio between the exact extremum of the input set and the aggregated approximation.
    With damping, each new factor is blended with the factor from the previous call.

    Args:
        which: Scale to `min` or `max`
        damping(optional): Damping factor between [0, 1), for a value of 0.0 the aggregation approximation is corrected
          to the exact maximum or minimum of the input set
    """
    def __init__(self, which: str, damping=0.0):
        self.damping = damping
        extremum = {'min': np.min, 'max': np.max}.get(which.lower())
        if extremum is None:
            raise ValueError("Argument `which` can only be 'min' or 'max'")
        self.f = extremum
        self.sf = None  # No factor computed yet; the first call is therefore undamped

    def __call__(self, x, fx_approx):
        """ Determine scaling factor

        Args:
            x: Set of values
            fx_approx: Approximated minimum / maximum

        Returns:
            Scaling factor (also stored in ``self.sf``)
        """
        factor = self.f(x) / fx_approx
        if self.sf is not None:
            factor = self.damping * self.sf + (1 - self.damping) * factor
        self.sf = factor
        return self.sf
86
+
87
+
class Aggregation(Module):
    """ Generic Aggregation module (cannot be used directly, but can only be used as superclass)

    Sub-classes implement :py:meth:`aggregation_function` and :py:meth:`aggregation_derivative`. The response is
    ``sf * aggregation_function(x[select])``, where ``select`` is the active set (or ``Ellipsis`` for all values) and
    ``sf`` is the scaling factor (1.0 when no scaling strategy is given).

    Keyword Args:
        scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
        active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
    """
    def _prepare(self, scaling: AggScaling = None, active_set: AggActiveSet = None):
        # This prepare function MUST be called in the _prepare function of sub-classes
        self.scaling = scaling
        self.active_set = active_set
        self.sf = 1.0

    @abc.abstractmethod
    def aggregation_function(self, x):
        """ Calculates f(x) """
        raise NotImplementedError()

    @abc.abstractmethod
    def aggregation_derivative(self, x):
        """ Calculates df(x) / dx """
        raise NotImplementedError()

    def _response(self, x):
        # Determine active set
        if self.active_set is not None:
            self.select = self.active_set(x)
        else:
            self.select = Ellipsis

        # Get aggregated value
        xagg = self.aggregation_function(x[self.select])

        # Scale
        if self.scaling is not None:
            self.sf = self.scaling(x[self.select], xagg)
        return self.sf * xagg

    def _sensitivity(self, dfdy):
        x = self.sig_in[0].state
        dydx = self.aggregation_derivative(x[self.select])
        # Allocate a floating-point buffer: `np.zeros_like(x)` would inherit an integer dtype from integer input and
        # silently truncate the sensitivities. Float (or complex) input keeps its own dtype.
        dx = np.zeros_like(x, dtype=np.result_type(x, 1.0))
        # Values outside the active set do not contribute to the aggregate, so their sensitivity stays zero.
        # The scaling factor is treated as a constant here.
        dx[self.select] += self.sf * dfdy * dydx
        return dx
132
+
133
+
class PNorm(Aggregation):
    r""" P-norm aggregation

    :math:`S_p(x_1, x_2, \dotsc, x_n) = \left( \sum_i (|x_i|^p) \right)^{1/p}`

    Only valid for positive :math:`x_i` when approximating the minimum or maximum

    Args:
        p: Power of the p-norm. Approximate maximum for `p>0` and minimum for `p<0`
        scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
        active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
    """
    def _prepare(self, p=2, scaling: AggScaling = None, active_set: AggActiveSet = None):
        self.p = p
        self.y = None
        super()._prepare(scaling, active_set)

    def aggregation_function(self, x):
        if np.min(x) < 0:
            warnings.warn("PNorm is only valid for positive x")

        # Get p-norm
        return np.sum(np.abs(x) ** self.p) ** (1/self.p)

    def aggregation_derivative(self, x):
        # d S_p / d x_i = (sum_j |x_j|^p)^(1/p - 1) * sign(x_i) * |x_i|^(p - 1)
        pval = np.sum(np.abs(x) ** self.p) ** (1 / self.p - 1)
        return pval * np.sign(x) * np.abs(x)**(self.p - 1)
161
+
162
+
class SoftMinMax(Aggregation):
    r""" Soft maximum/minimum function

    :math:`S_a(x_1, x_2, \dotsc, x_n) = \frac{\sum_i (x_i \exp(a x_i))}{\sum_i (\exp(a x_i))}`

    When using as maximum, it underestimates the maximum
    It is exact however when :math:`x_1=x_2=\dotsc=x_n`

    Args:
        alpha: Scaling factor of the soft function. Approximate maximum for `alpha>0` and minimum for `alpha<0`
        scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
        active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
    """
    def _prepare(self, alpha=1.0, scaling: AggScaling = None, active_set: AggActiveSet = None):
        self.alpha, self.y = alpha, None
        super()._prepare(scaling, active_set)

    def aggregation_function(self, x):
        # Softmax weights w_i = exp(a x_i) / sum_j exp(a x_j); the soft extremum is the w-weighted sum of x
        weights = spsp.softmax(self.alpha * x)
        self.y = np.sum(x * weights)
        return self.y

    def aggregation_derivative(self, x):
        # dS/dx_i = w_i * (1 + a * (x_i - S)), with S the value stored in self.y by aggregation_function
        weights = spsp.softmax(self.alpha * x)
        return weights * (1 + self.alpha * (x - self.y))
187
+
188
+
class KSFunction(Aggregation):
    r""" Kreisselmeier and Steinhauser function from 1979

    :math:`S_\rho(x_1, x_2, \dotsc, x_n) = \frac{1}{\rho} \ln \left( \sum_i \exp(\rho x_i) \right)`

    Args:
        rho: Scaling factor of the KS function. Approximate maximum for `rho>0` and minimum for `rho<0`
        scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
        active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
    """
    def _prepare(self, rho=1.0, scaling: AggScaling = None, active_set: AggActiveSet = None):
        self.rho = rho
        self.y = None
        super()._prepare(scaling, active_set)

    def aggregation_function(self, x):
        # logsumexp subtracts the maximum before exponentiating, so a large |rho * x| does not overflow to inf
        # (a naive np.log(np.sum(np.exp(rho * x))) returns inf once rho * x exceeds ~709 in float64)
        return spsp.logsumexp(self.rho * x) / self.rho

    def aggregation_derivative(self, x):
        # dS/dx_i = exp(rho x_i) / sum_j exp(rho x_j), i.e. the softmax of rho * x (computed overflow-safe)
        return spsp.softmax(self.rho * x)