pyMOTO 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymoto/core_objects.py CHANGED
@@ -1,9 +1,11 @@
1
- from typing import Union, List, Any
1
+ import sys
2
2
  import warnings
3
3
  import inspect
4
4
  import time
5
- from .utils import _parse_to_list, _concatenate_to_array, _split_from_array
5
+ import copy
6
+ from typing import Union, List, Any
6
7
  from abc import ABC, abstractmethod
8
+ from .utils import _parse_to_list, _concatenate_to_array, _split_from_array
7
9
 
8
10
 
9
11
  # Local helper functions
@@ -11,7 +13,7 @@ def err_fmt(*args):
11
13
  """ Format error strings for locating Modules and Signals"""
12
14
  err_str = ""
13
15
  for a in args:
14
- err_str += f"\n\t[ {a} ]"
16
+ err_str += f"\n\t| {a}"
15
17
  return err_str
16
18
 
17
19
 
@@ -76,16 +78,20 @@ class Signal:
76
78
  >> Signal(tag='x2')
77
79
 
78
80
  """
79
- def __init__(self, tag: str = "", state: Any = None, sensitivity: Any = None):
81
+ def __init__(self, tag: str = "", state: Any = None, sensitivity: Any = None, min: Any = None, max: Any = None):
80
82
  """
81
-
82
- :param tag: The name of the signal (string)
83
- :param state: The initialized state (optional)
84
- :param sensitivity: The initialized sensitivity (optional)
83
+ Keyword Args:
84
+ tag: The name of the signal
85
+ state: The initialized state
86
+ sensitivity: The initialized sensitivity
87
+ min: Minimum allowed value
88
+ max: Maximum allowed value
85
89
  """
86
90
  self.tag = tag
87
91
  self.state = state
88
92
  self.sensitivity = sensitivity
93
+ self.min = min
94
+ self.max = max
89
95
  self.keep_alloc = sensitivity is not None
90
96
 
91
97
  # Save error string to location where it is initialized
@@ -95,11 +101,15 @@ class Signal:
95
101
  return err_fmt(f"Signal \'{self.tag}\', initialized in {self._init_loc}")
96
102
 
97
103
  def add_sensitivity(self, ds: Any):
104
+ """ Add a new term to internal sensitivity """
98
105
  try:
99
106
  if ds is None:
100
107
  return
101
108
  if self.sensitivity is None:
102
- self.sensitivity = ds
109
+ self.sensitivity = copy.deepcopy(ds)
110
+ elif hasattr(self.sensitivity, "add_sensitivity"):
111
+ # Allow user to implement a custom add_sensitivity function instead of __iadd__
112
+ self.sensitivity.add_sensitivity(ds)
103
113
  else:
104
114
  self.sensitivity += ds
105
115
  return self
@@ -116,8 +126,12 @@ class Signal:
116
126
  def reset(self, keep_alloc: bool = None):
117
127
  """ Reset the sensitivities to zero or None
118
128
  This must be called to clear internal memory of subsequent sensitivity calculations.
119
- :param keep_alloc: Keep the sensitivity allocation intact?
120
- :return: self
129
+
130
+ Args:
131
+ keep_alloc: Keep the sensitivity allocation intact?
132
+
133
+ Returns:
134
+ self
121
135
  """
122
136
  if self.sensitivity is None:
123
137
  return self
@@ -138,11 +152,34 @@ class Signal:
138
152
 
139
153
  def __getitem__(self, item):
140
154
  """ Obtain a sliced signal, for using its partial contents.
141
- :param item: Slice indices
142
- :return: Sliced signal (SignalSlice)
155
+
156
+ Args:
157
+ item: Slice indices
158
+
159
+ Returns:
160
+ Sliced signal (SignalSlice)
143
161
  """
144
162
  return SignalSlice(self, item)
145
163
 
164
def _state_summary(self):
    """ Describe the state in one line: ``state <value>``, or ``empty state`` if the state is None

    Multi-line string representations (e.g. of large arrays) are shortened to their first and last line,
    joined by ``' ... '``.
    """
    if self.state is None:
        return "empty state"
    lines = f"state {self.state}".split('\n')
    if len(lines) > 1:
        return lines[0] + ' ... ' + lines[-1]
    return lines[0]

def __str__(self):
    """ Human-readable description: tag and (shortened) state """
    return f"Signal \"{self.tag}\" with {self._state_summary()}"

def __repr__(self):
    """ Debug description: tag, (shortened) state, whether a sensitivity is present, and the object address """
    sens_msg = 'empty sensitivity' if self.sensitivity is None else 'non-empty sensitivity'
    return f"Signal \"{self.tag}\" with {self._state_summary()} and {sens_msg} at {hex(id(self))}"
182
+
146
183
 
147
184
  class SignalSlice(Signal):
148
185
  """ Slice operator for a Signal
@@ -169,7 +206,8 @@ class SignalSlice(Signal):
169
206
  return None if self.orig_signal.state is None else self.orig_signal.state[self.slice]
170
207
  except Exception as e:
171
208
  # Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
172
- raise type(e)("SignalSlice.state (getter)" + self._err_str()) from e
209
+ raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.state (getter). Signal details:" +
210
+ self._err_str()).with_traceback(sys.exc_info()[2])
173
211
 
174
212
  @state.setter
175
213
  def state(self, new_state):
@@ -177,7 +215,8 @@ class SignalSlice(Signal):
177
215
  self.orig_signal.state[self.slice] = new_state
178
216
  except Exception as e:
179
217
  # Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
180
- raise type(e)("SignalSlice.state (setter)" + self._err_str()) from e
218
+ raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.state (setter). Signal details:" +
219
+ self._err_str()).with_traceback(sys.exc_info()[2])
181
220
 
182
221
  @property
183
222
  def sensitivity(self):
@@ -185,7 +224,8 @@ class SignalSlice(Signal):
185
224
  return None if self.orig_signal.sensitivity is None else self.orig_signal.sensitivity[self.slice]
186
225
  except Exception as e:
187
226
  # Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
188
- raise type(e)("SignalSlice.sensitivity (getter)" + self._err_str()) from e
227
+ raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.sensitivity (getter). Signal details:" +
228
+ self._err_str()).with_traceback(sys.exc_info()[2])
189
229
 
190
230
  @sensitivity.setter
191
231
  def sensitivity(self, new_sens):
@@ -207,7 +247,8 @@ class SignalSlice(Signal):
207
247
  self.orig_signal.sensitivity[self.slice] = new_sens
208
248
  except Exception as e:
209
249
  # Possibilities: Unslicable object (TypeError) or Wrong dimensions or out of range (IndexError)
210
- raise type(e)("SignalSlice.sensitivity (setter)" + self._err_str()) from e
250
+ raise type(e)(str(e) + "\n\t| Above error was raised in SignalSlice.state (setter). Signal details:" +
251
+ self._err_str()).with_traceback(sys.exc_info()[2])
211
252
 
212
253
  def reset(self, keep_alloc: bool = None):
213
254
  """ Reset the sensitivities to zero or None
@@ -231,7 +272,7 @@ def make_signals(*args):
231
272
  return ret
232
273
 
233
274
 
234
- def _check_valid_signal(sig: Any):
275
+ def _is_valid_signal(sig: Any):
235
276
  """ Checks if the argument is a valid Signal object
236
277
  :param sig: The object to check
237
278
  :return: True if it is a valid Signal
@@ -240,10 +281,10 @@ def _check_valid_signal(sig: Any):
240
281
  return True
241
282
  if all([hasattr(sig, f) for f in ["state", "sensitivity", "add_sensitivity", "reset"]]):
242
283
  return True
243
- raise TypeError(f"Given argument with type \'{type(sig).__name__}\' is not a valid Signal")
284
+ return False
244
285
 
245
286
 
246
- def _check_valid_module(mod: Any):
287
+ def _is_valid_module(mod: Any):
247
288
  """ Checks if the argument is a valid Module object
248
289
  :param mod: The object to check
249
290
  :return: True if it is a valid Module
@@ -252,7 +293,7 @@ def _check_valid_module(mod: Any):
252
293
  return True
253
294
  if hasattr(mod, "response") and hasattr(mod, "sensitivity") and hasattr(mod, "reset"):
254
295
  return True
255
- raise TypeError(f"Given argument with type \'{type(mod).__name__}\' is not a valid Module")
296
+ return False
256
297
 
257
298
 
258
299
  def _check_function_signature(fn, signals):
@@ -370,60 +411,48 @@ class Module(ABC, RegisteredClass):
370
411
  >> Module(sig_in=[inputs], sig_out=[outputs]
371
412
  """
372
413
 
373
- def _err_str(self, init: bool = True, add_signal: bool = True, fn=None):
414
def _err_str(self, module_signature: bool = True, init: bool = True, fn=None):
    """ Format a description of this module, to be appended to error messages

    This is called while another exception is being reported, so it must not raise itself: missing
    attributes or unavailable source information are replaced by placeholders instead.

    Args:
        module_signature (optional): Add the module type and the tags of its input and output signals
        init (optional): Add the location where the module was initialized
        fn (optional): Bound method whose implementation location is added (e.g. ``self._response``)

    Returns:
        String formatted by ``err_fmt``
    """
    str_list = []

    if module_signature:
        # getattr defaults: the attributes may not exist yet if the error occurs early in initialization
        in_tags = [s.tag if hasattr(s, 'tag') else 'N/A' for s in getattr(self, 'sig_in', [])]
        out_tags = [s.tag if hasattr(s, 'tag') else 'N/A' for s in getattr(self, 'sig_out', [])]
        inp_str = "Inputs: " + ", ".join(in_tags) if len(in_tags) > 0 else "No inputs"
        out_str = "Outputs: " + ", ".join(out_tags) if len(out_tags) > 0 else "No outputs"
        str_list.append(f"Module \'{type(self).__name__}\'( " + inp_str + " ) --> " + out_str)
    if init:
        str_list.append(f"Used in {getattr(self, '_init_loc', 'unknown location')}")
    if fn is not None:
        try:
            owner = getattr(fn, '__self__', None)
            cls_prefix = f"{owner.__class__.__name__}." if owner is not None else ""
            name = f"{cls_prefix}{fn.__name__}{inspect.signature(fn)}"
            lineno = inspect.getsourcelines(fn)[1]
            filename = inspect.getfile(fn)
            str_list.append(f"Implementation in File \"{filename}\", line {lineno}, in {name}")
        except (AttributeError, TypeError, ValueError, OSError):
            # Source information unavailable (builtin/C function, interactive session, ...);
            # do not let this hide the error that is actually being reported
            str_list.append(f"Implementation: {fn!r}")
    return err_fmt(*str_list)
387
429
 
388
430
  # flake8: noqa: C901
389
431
  def __init__(self, sig_in: Union[Signal, List[Signal]] = None, sig_out: Union[Signal, List[Signal]] = None,
390
432
  *args, **kwargs):
391
- # TODO: Reduce complexity of this init
392
433
  self._init_loc = get_init_str()
393
434
 
394
435
  self.sig_in = _parse_to_list(sig_in)
395
436
  self.sig_out = _parse_to_list(sig_out)
396
437
  for i, s in enumerate(self.sig_in):
397
- try:
398
- _check_valid_signal(s)
399
- except Exception as e:
400
- earg0 = e.args[0] if len(e.args) > 0 else ''
401
- earg1 = e.args[1:] if len(e.args) > 1 else ()
402
- raise type(e)(f"Invalid input signal #{i+1} - " + str(earg0) + self._err_str(), *earg1) from None
438
+ if not _is_valid_signal(s):
439
+ tag = f" (\'{s.tag}\')" if hasattr(s, 'tag') else ''
440
+ raise TypeError(f"Input {i}{tag} is not a valid signal, type=\'{type(s).__name__}\'.")
403
441
 
404
442
  for i, s in enumerate(self.sig_out):
405
- try:
406
- _check_valid_signal(s)
407
- except Exception as e:
408
- earg0 = e.args[0] if len(e.args) > 0 else ''
409
- earg1 = e.args[1:] if len(e.args) > 1 else ()
410
- raise type(e)(f"Invalid output signal #{i+1} - " + str(earg0) + self._err_str(), *earg1) from None
443
+ if not _is_valid_signal(s):
444
+ tag = f" (\'{s.tag}\')" if hasattr(s, 'tag') else ''
445
+ raise TypeError(f"Output {i}{tag} is not a valid signal, type=\'{type(s).__name__}\'.")
411
446
 
412
- try:
413
- # Call preparation of submodule with remaining arguments
414
- self._prepare(*args, **kwargs)
415
- except Exception as e:
416
- earg0 = e.args[0] if len(e.args) > 0 else ''
417
- earg1 = e.args[1:] if len(e.args) > 1 else ()
418
- raise type(e)("_prepare() - " + str(earg0) + self._err_str(fn=self._prepare), *earg1) from e
447
+ # Call preparation of submodule with remaining arguments
448
+ self._prepare(*args, **kwargs)
419
449
 
420
450
  try:
421
451
  # Check if the signals match _response() signature
422
452
  _check_function_signature(self._response, self.sig_in)
423
453
  except Exception as e:
424
- earg0 = e.args[0] if len(e.args) > 0 else ''
425
- earg1 = e.args[1:] if len(e.args) > 1 else ()
426
- raise type(e)(str(earg0) + self._err_str(fn=self._response), *earg1) from None
454
+ raise type(e)(str(e) + "\n\t| Module details:" +
455
+ self._err_str(fn=self._response)).with_traceback(sys.exc_info()[2])
427
456
 
428
457
  try:
429
458
  # If no output signals are given, but are required, try to initialize them here
@@ -441,9 +470,8 @@ class Module(ABC, RegisteredClass):
441
470
  # Check if signals match _sensitivity() signature
442
471
  _check_function_signature(self._sensitivity, self.sig_out)
443
472
  except Exception as e:
444
- earg0 = e.args[0] if len(e.args) > 0 else ''
445
- earg1 = e.args[1:] if len(e.args) > 1 else ()
446
- raise type(e)(str(earg0) + self._err_str(fn=self._sensitivity), *earg1) from None
473
+ raise type(e)(str(e) + "\n\t| Module details:" +
474
+ self._err_str(fn=self._sensitivity)).with_traceback(sys.exc_info()[2])
447
475
 
448
476
  def response(self):
449
477
  """ Calculate the response from sig_in and output this to sig_out """
@@ -461,9 +489,9 @@ class Module(ABC, RegisteredClass):
461
489
  self.sig_out[i].state = val
462
490
  return self
463
491
  except Exception as e:
464
- earg0 = e.args[0] if len(e.args) > 0 else ''
465
- earg1 = e.args[1:] if len(e.args) > 1 else ()
466
- raise type(e)("response() - " + str(earg0) + self._err_str(fn=self._response), *earg1) from e
492
+ # https://stackoverflow.com/questions/6062576/adding-information-to-an-exception
493
+ raise type(e)(str(e) + "\n\t| Above error was raised when calling response(). Module details:" +
494
+ self._err_str(fn=self._response)).with_traceback(sys.exc_info()[2])
467
495
 
468
496
  def __call__(self):
469
497
  return self.response()
@@ -494,9 +522,8 @@ class Module(ABC, RegisteredClass):
494
522
 
495
523
  return self
496
524
  except Exception as e:
497
- earg0 = e.args[0] if len(e.args) > 0 else ''
498
- earg1 = e.args[1:] if len(e.args) > 1 else ()
499
- raise type(e)("sensitivity() - " + str(earg0) + self._err_str(fn=self._sensitivity), *earg1) from e
525
+ raise type(e)(str(e) + "\n\t| Above error was raised when calling sensitivity(). Module details:" +
526
+ self._err_str(fn=self._sensitivity)).with_traceback(sys.exc_info()[2])
500
527
 
501
528
  def reset(self):
502
529
  """ Reset the state of the sensitivities (they are set to zero or to None) """
@@ -506,9 +533,8 @@ class Module(ABC, RegisteredClass):
506
533
  self._reset()
507
534
  return self
508
535
  except Exception as e:
509
- earg0 = e.args[0] if len(e.args) > 0 else ''
510
- earg1 = e.args[1:] if len(e.args) > 1 else ()
511
- raise type(e)("reset() - " + str(earg0) + self._err_str(fn=self._reset), *earg1) from e
536
+ raise type(e)(str(e) + "\n\t| Above error was raised when calling reset(). Module details:" +
537
+ self._err_str(fn=self._response)).with_traceback(sys.exc_info()[2])
512
538
 
513
539
  # METHODS TO BE DEFINED BY USER
514
540
  def _prepare(self, *args, **kwargs):
@@ -542,49 +568,52 @@ class Network(Module):
542
568
  """
543
569
  def __init__(self, *args, print_timing=False):
544
570
  self._init_loc = get_init_str()
545
- try:
546
- # Obtain the internal blocks
547
- self.mods = _parse_to_list(*args)
548
-
549
- # Check if the blocks are initialized, else create them
550
- for i, b in enumerate(self.mods):
551
- if isinstance(b, dict):
552
- exclude_keys = ['type']
553
- b_ex = {k: b[k] for k in set(list(b.keys())) - set(exclude_keys)}
554
- self.mods[i] = Module.create(b['type'], **b_ex)
555
-
556
- # Check validity of modules
557
- [_check_valid_module(m) for m in self.mods]
558
-
559
- # Gather all the input and output signals of the internal blocks
560
- all_in = set()
561
- all_out = set()
562
- [all_in.update(b.sig_in) for b in self.mods]
563
- [all_out.update(b.sig_out) for b in self.mods]
564
- in_unique = all_in - all_out
565
-
566
- # Initialize the parent module, with correct inputs and outputs
567
- super().__init__(list(in_unique), list(all_out))
568
-
569
- self.print_timing = print_timing
570
- except Exception as e:
571
- earg0 = e.args[0] if len(e.args) > 0 else ''
572
- earg1 = e.args[1:] if len(e.args) > 1 else ()
573
- raise type(e)(str(earg0) + self._err_str(add_signal=False), *earg1) from None
574
571
 
575
- def timefn(self, fn):
572
+ # Obtain the internal blocks
573
+ self.mods = _parse_to_list(*args)
574
+
575
+ # Check if the blocks are initialized, else create them
576
+ for i, b in enumerate(self.mods):
577
+ if isinstance(b, dict):
578
+ exclude_keys = ['type']
579
+ b_ex = {k: b[k] for k in set(list(b.keys())) - set(exclude_keys)}
580
+ self.mods[i] = Module.create(b['type'], **b_ex)
581
+
582
+ # Check validity of modules
583
+ for m in self.mods:
584
+ if not _is_valid_module(m):
585
+ raise TypeError(f"Argument is not a valid Module, type=\'{type(mod).__name__}\'.")
586
+
587
+ # Gather all the input and output signals of the internal blocks
588
+ all_in = set()
589
+ all_out = set()
590
+ [all_in.update(b.sig_in) for b in self.mods]
591
+ [all_out.update(b.sig_out) for b in self.mods]
592
+ in_unique = all_in - all_out
593
+
594
+ # Initialize the parent module, with correct inputs and outputs
595
+ super().__init__(list(in_unique), list(all_out))
596
+
597
+ self.print_timing = print_timing
598
+
599
def timefn(self, fn, prefix='Evaluation', min_duration=0.5):
    """ Call ``fn()`` and print its wall-clock duration if it took longer than ``min_duration`` seconds

    Args:
        fn: Callable without arguments, e.g. a bound ``response`` or ``sensitivity`` method
        prefix (optional): Text printed in front of the timing message
        min_duration (optional): Only durations (in seconds) strictly above this value are printed
    """
    start_t = time.perf_counter()  # monotonic, high-resolution clock intended for measuring durations
    fn()
    duration = time.perf_counter() - start_t
    if duration > min_duration:
        print(f"{prefix} {fn} took {duration} s")
579
605
 
580
606
  def response(self):
581
607
  if self.print_timing:
582
- [self.timefn(b.response) for b in self.mods]
608
+ [self.timefn(b.response, prefix='Response') for b in self.mods]
583
609
  else:
584
610
  [b.response() for b in self.mods]
585
611
 
586
612
  def sensitivity(self):
587
- [b.sensitivity() for b in reversed(self.mods)]
613
+ if self.print_timing:
614
+ [self.timefn(b.sensitivity, 'Sensitivity') for b in reversed(self.mods)]
615
+ else:
616
+ [b.sensitivity() for b in reversed(self.mods)]
588
617
 
589
618
  def reset(self):
590
619
  [b.reset() for b in reversed(self.mods)]
@@ -611,12 +640,9 @@ class Network(Module):
611
640
  modlist = _parse_to_list(*newmods)
612
641
 
613
642
  # Check if the blocks are initialized, else create them
614
- for i, b in enumerate(modlist):
615
- try: # Check validity of modules
616
- _check_valid_module(b)
617
- except Exception as e:
618
- raise type(e)("append() - Trying to append invalid module " + str(e.args[0])
619
- + self._err_str(add_signal=False), *e.args[1:]) from None
643
+ for i, m in enumerate(modlist):
644
+ if not _is_valid_module(m):
645
+ raise TypeError(f"Argument #{i} is not a valid module, type=\'{type(mod).__name__}\'.")
620
646
 
621
647
  # Obtain the internal blocks
622
648
  self.mods.extend(modlist)
@@ -624,25 +650,11 @@ class Network(Module):
624
650
  # Gather all the input and output signals of the internal blocks
625
651
  all_in = set()
626
652
  all_out = set()
627
- [all_in.update(b.sig_in) for b in self.mods]
628
- [all_out.update(b.sig_out) for b in self.mods]
653
+ [all_in.update(m.sig_in) for m in self.mods]
654
+ [all_out.update(m.sig_out) for m in self.mods]
629
655
  in_unique = all_in - all_out
630
656
 
631
657
  self.sig_in = _parse_to_list(in_unique)
632
- try:
633
- [_check_valid_signal(s) for s in self.sig_in]
634
- except Exception as e:
635
- earg0 = e.args[0] if len(e.args) > 0 else ''
636
- earg1 = e.args[1:] if len(e.args) > 1 else ()
637
- raise type(e)("append() - Invalid input signals " + str(earg0)
638
- + self._err_str(add_signal=False), *earg1) from None
639
658
  self.sig_out = _parse_to_list(all_out)
640
- try:
641
- [_check_valid_signal(s) for s in self.sig_out]
642
- except Exception as e:
643
- earg0 = e.args[0] if len(e.args) > 0 else ''
644
- earg1 = e.args[1:] if len(e.args) > 1 else ()
645
- raise type(e)("append() - Invalid output signals " + str(earg0)
646
- + self._err_str(add_signal=False), *earg1) from None
647
659
 
648
660
  return modlist[-1].sig_out[0] if len(modlist[-1].sig_out) == 1 else modlist[-1].sig_out # Returns the output signal
@@ -0,0 +1,209 @@
1
+ import warnings
2
+ import abc
3
+ import numpy as np
4
+ import scipy.special as spsp
5
+ from pymoto import Module
6
+
7
+
8
class AggActiveSet:
    """ Determine active set by discarding lower or upper fraction of a set of values

    Args:
        lower_rel: Discard values whose normalized value ``(x - min) / (max - min)`` is below ``lower_rel``
        upper_rel: Discard values whose normalized value ``(x - min) / (max - min)`` is above ``upper_rel``
        lower_amt: Discard the lowest ``lower_amt`` fraction of all values (based on sorting)
        upper_amt: Discard the highest ``1 - upper_amt`` fraction of all values (based on sorting)
    """
    def __init__(self, lower_rel=0.0, upper_rel=1.0, lower_amt=0.0, upper_amt=1.0):
        assert upper_rel > lower_rel, "Upper must be larger than lower to keep values in the set"
        assert upper_amt > lower_amt, "Upper must be larger than lower to keep values in the set"
        self.lower_rel, self.upper_rel = lower_rel, upper_rel
        self.lower_amt, self.upper_amt = lower_amt, upper_amt

    def __call__(self, x):
        """ Generate an active set for given array

        Args:
            x: Array of values (any shape)

        Returns:
            Boolean mask with the shape of ``x``, or ``Ellipsis`` (select all) if all values are equal
        """
        x = np.asarray(x)
        xmin, xmax = np.min(x), np.max(x)
        if (xmax - xmin) == 0:  # All values are the same, so no active set can be taken
            return Ellipsis

        sel = np.ones(x.shape, dtype=bool)  # C-contiguous, so reshape(-1) below is a view

        # Select based on value
        xrel = (x - xmin) / (xmax - xmin)  # Normalize between 0 and 1
        if self.lower_rel > 0:
            sel &= xrel >= self.lower_rel
        if self.upper_rel < 1:
            sel &= xrel <= self.upper_rel

        # Remove lowest and highest N values; sort over the flattened array so any shape is supported
        if self.lower_amt > 0 or self.upper_amt < 1:
            i_sort = np.argsort(x, axis=None)
            sel_flat = sel.reshape(-1)
            if self.lower_amt > 0:
                n_lower_amt = int(x.size * self.lower_amt)
                sel_flat[i_sort[:n_lower_amt]] = False
            if self.upper_amt < 1:
                n_upper_amt = int(x.size * (1 - self.upper_amt))
                # Guard n == 0: i_sort[-0:] would select every index and discard the whole set
                if n_upper_amt > 0:
                    sel_flat[i_sort[x.size - n_upper_amt:]] = False

        return sel
49
+
50
+
51
class AggScaling:
    """ Scaling strategy that corrects an aggregated value towards the true minimum or maximum

    Args:
        which: Scale to `min` or `max` (case-insensitive)
        damping(optional): Damping factor in [0, 1). Each call updates the factor to
          ``damping * previous + (1 - damping) * (true value / approximation)``; with 0.0 the scaled
          aggregate equals the exact minimum or maximum of the input set
    """
    def __init__(self, which: str, damping=0.0):
        self.damping = damping
        reducers = {'min': np.min, 'max': np.max}
        try:
            self.f = reducers[which.lower()]
        except KeyError:
            raise ValueError("Argument `which` can only be 'min' or 'max'") from None
        self.sf = None

    def __call__(self, x, fx_approx):
        """ Update and return the scaling factor

        Args:
            x: Set of values
            fx_approx: Approximated minimum / maximum of ``x``

        Returns:
            Scaling factor (also stored in ``self.sf``)
        """
        ratio = self.f(x) / fx_approx
        if self.sf is not None:
            # Blend with the previous factor to damp changes between successive calls
            ratio = self.damping * self.sf + (1 - self.damping) * ratio
        self.sf = ratio
        return self.sf
86
+
87
+
88
class Aggregation(Module):
    """ Generic Aggregation module (cannot be used directly, but can only be used as superclass)

    Sub-classes implement :py:meth:`aggregation_function` and :py:meth:`aggregation_derivative`; this class
    applies the optional active set and scaling strategies around them.

    Keyword Args:
        scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
        active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
    """
    def _prepare(self, scaling: AggScaling = None, active_set: AggActiveSet = None):
        # This prepare function MUST be called in the _prepare function of sub-classes
        self.scaling = scaling
        self.active_set = active_set
        self.sf = 1.0

    @abc.abstractmethod
    def aggregation_function(self, x):
        """ Calculates f(x) """
        raise NotImplementedError()

    @abc.abstractmethod
    def aggregation_derivative(self, x):
        """ Calculates df(x) / dx """
        raise NotImplementedError()

    def _response(self, x):
        # Determine active set (Ellipsis selects all values)
        self.select = Ellipsis if self.active_set is None else self.active_set(x)
        x_active = x[self.select]

        # Get aggregated value
        xagg = self.aggregation_function(x_active)

        # Scale; the factor is kept in self.sf so _sensitivity uses the same value
        if self.scaling is not None:
            self.sf = self.scaling(x_active, xagg)
        return self.sf * xagg

    def _sensitivity(self, dfdy):
        x = self.sig_in[0].state
        dydx = self.aggregation_derivative(x[self.select])
        # Values outside the active set do not contribute, so their sensitivity stays zero
        dx = np.zeros_like(x)
        dx[self.select] += self.sf * dfdy * dydx
        return dx
132
+
133
+
134
class PNorm(Aggregation):
    r""" P-norm aggregation

    :math:`S_p(x_1, x_2, \dotsc, x_n) = \left( \sum_i |x_i|^p \right)^{1/p}`

    Only valid for positive :math:`x_i` when approximating the minimum or maximum. For positive values it
    overestimates the maximum when `p>0` and underestimates the minimum when `p<0`.

    Args:
        p: Power of the p-norm. Approximate maximum for `p>0` and minimum for `p<0`
        scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
        active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
    """
    def _prepare(self, p=2, scaling: AggScaling = None, active_set: AggActiveSet = None):
        self.p = p
        self.y = None
        super()._prepare(scaling, active_set)

    def aggregation_function(self, x):
        if np.min(x) < 0:
            warnings.warn("PNorm is only valid for positive x")

        # Get p-norm
        return np.sum(np.abs(x) ** self.p) ** (1 / self.p)

    def aggregation_derivative(self, x):
        # dS/dx_j = (sum_i |x_i|^p)^(1/p - 1) * sign(x_j) * |x_j|^(p - 1)
        absx = np.abs(x)
        pval = np.sum(absx ** self.p) ** (1 / self.p - 1)
        return pval * np.sign(x) * absx ** (self.p - 1)
161
+
162
+
163
class SoftMinMax(Aggregation):
    r""" Soft maximum/minimum function

    :math:`S_a(x_1, x_2, \dotsc, x_n) = \frac{\sum_i (x_i \exp(a x_i))}{\sum_i (\exp(a x_i))}`

    This is a softmax-weighted average of the values. Used as maximum it underestimates the maximum, and used
    as minimum it overestimates the minimum. It is exact when :math:`x_1=x_2=\dotsc=x_n`.

    Args:
        alpha: Scaling factor of the soft function. Approximate maximum for `alpha>0` and minimum for `alpha<0`
        scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
        active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
    """
    def _prepare(self, alpha=1.0, scaling: AggScaling = None, active_set: AggActiveSet = None):
        self.alpha = alpha
        self.y = None
        super()._prepare(scaling, active_set)

    def aggregation_function(self, x):
        self.y = np.sum(x * spsp.softmax(self.alpha * x))
        return self.y

    def aggregation_derivative(self, x):
        # dS/dx_j = s_j * (1 + a * (x_j - S)) with s = softmax(a x). S is recomputed from the given x instead
        # of reusing self.y, so the derivative is valid even if aggregation_function was not called on this x
        s = spsp.softmax(self.alpha * x)
        y = np.sum(x * s)
        return s * (1 + self.alpha * (x - y))
187
+
188
+
189
class KSFunction(Aggregation):
    r""" Kreisselmeier and Steinhauser function from 1979

    :math:`S_\rho(x_1, x_2, \dotsc, x_n) = \frac{1}{\rho} \ln \left( \sum_i \exp(\rho x_i) \right)`

    For `rho>0` it overestimates the maximum by at most :math:`\ln(n)/\rho`; for `rho<0` it underestimates the
    minimum.

    Args:
        rho: Scaling factor of the KS function. Approximate maximum for `rho>0` and minimum for `rho<0`
        scaling(optional): Scaling strategy to improve approximation :py:class:`pymoto.AggScaling`
        active_set(optional): Active set strategy to improve approximation :py:class:`pymoto.AggActiveSet`
    """
    def _prepare(self, rho=1.0, scaling: AggScaling = None, active_set: AggActiveSet = None):
        self.rho = rho
        self.y = None
        super()._prepare(scaling, active_set)

    def aggregation_function(self, x):
        # logsumexp shifts by the largest exponent internally, so large rho * x does not overflow np.exp
        return spsp.logsumexp(self.rho * x) / self.rho

    def aggregation_derivative(self, x):
        # dS/dx_j = exp(rho x_j) / sum_i exp(rho x_i), i.e. softmax(rho x), evaluated in an overflow-safe way
        return spsp.softmax(self.rho * x)