pyRDDLGym-jax 0.3__py3-none-any.whl → 0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,24 +1,18 @@
1
+ from typing import Optional, Set
2
+
1
3
  import jax
2
4
  import jax.numpy as jnp
3
5
  import jax.random as random
4
- from typing import Optional, Set
5
6
 
6
7
  from pyRDDLGym.core.debug.exception import raise_warning
7
8
 
8
9
 
9
- class Complement:
10
- '''Base class for approximate logical complement operations.'''
11
-
12
- def __call__(self, x):
13
- raise NotImplementedError
14
-
15
-
16
- class StandardComplement(Complement):
17
- '''The standard approximate logical complement given by x -> 1 - x.'''
18
-
19
- def __call__(self, x):
20
- return 1.0 - x
21
-
10
+ # ===========================================================================
11
+ # RELATIONAL OPERATIONS
12
+ # - abstract class
13
+ # - sigmoid comparison
14
+ #
15
+ # ===========================================================================
22
16
 
23
17
  class Comparison:
24
18
  '''Base class for approximate comparison operations.'''
@@ -32,10 +26,14 @@ class Comparison:
32
26
  def equal(self, x, y, param):
33
27
  raise NotImplementedError
34
28
 
29
+ def sgn(self, x, param):
30
+ raise NotImplementedError
31
+
35
32
 
36
33
  class SigmoidComparison(Comparison):
37
34
  '''Comparison operations approximated using sigmoid functions.'''
38
35
 
36
+ # https://arxiv.org/abs/2110.05651
39
37
  def greater_equal(self, x, y, param):
40
38
  return jax.nn.sigmoid(param * (x - y))
41
39
 
@@ -44,7 +42,75 @@ class SigmoidComparison(Comparison):
44
42
 
45
43
  def equal(self, x, y, param):
46
44
  return 1.0 - jnp.square(jnp.tanh(param * (y - x)))
47
-
45
+
46
+ def sgn(self, x, param):
47
+ return jnp.tanh(param * x)
48
+
49
+
50
+ # ===========================================================================
51
+ # ROUNDING OPERATIONS
52
+ # - abstract class
53
+ # - soft rounding
54
+ #
55
+ # ===========================================================================
56
+
57
class Rounding:
    '''Interface for differentiable stand-ins of floor and round.

    Implementations map (x, param) to an approximation of the exact
    operation, where param is a sharpness parameter.'''

    def floor(self, x, param):
        '''Approximation of floor(x); must be provided by a subclass.'''
        raise NotImplementedError

    def round(self, x, param):
        '''Approximation of round(x); must be provided by a subclass.'''
        raise NotImplementedError
65
+
66
+
67
class SoftRounding(Rounding):
    '''Rounding operations approximated using soft operations.

    Both operations are continuous staircases whose steps sharpen as param
    grows; as param -> inf they approach the exact floor / round everywhere
    except at the step points themselves.'''

    # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
    def floor(self, x, param):
        '''Soft floor: one rescaled sigmoid step centred at each integer k.

        On [k - 0.5, k + 0.5] the output rises continuously from k - 1 to k,
        passing through k - 0.5 at x = k. It therefore equals floor(x) exactly
        at half-integers and is continuous at the integers.'''
        # Centre of the step nearest to x. Steps sit at integers, so
        # x - centre lies in [-0.5, 0.5). Using floor(x) as the centre only
        # works for fractional parts >= 0.5 and leaves a jump at every integer.
        centre = jnp.floor(x + 0.5)
        # sigmoid(param/2) - sigmoid(-param/2) == tanh(param/4); dividing by it
        # rescales the rise over one step to exactly [0, 1].
        denom = jnp.tanh(param / 4.0)
        rise = (jax.nn.sigmoid(param * (x - centre)) -
                jax.nn.sigmoid(-param / 2.0)) / denom
        return centre - 1.0 + rise

    # https://arxiv.org/abs/2006.09952
    def round(self, x, param):
        '''Soft round: tanh steps centred at floor(x) + 0.5.

        The steps are rescaled so the output equals floor(x) and floor(x) + 1
        at the integers bounding x.'''
        m = jnp.floor(x) + 0.5
        return m + 0.5 * jnp.tanh(param * (x - m)) / jnp.tanh(param / 2.0)
80
+
81
+
82
+ # ===========================================================================
83
+ # LOGICAL COMPLEMENT
84
+ # - abstract class
85
+ # - standard complement
86
+ #
87
+ # ===========================================================================
88
+
89
class Complement:
    '''Interface for fuzzy negation operators.

    Instances are callables mapping a truth value x to its approximate
    complement; subclasses supply the formula.'''

    def __call__(self, x):
        '''Return the complement of x; must be provided by a subclass.'''
        raise NotImplementedError
94
+
95
+
96
class StandardComplement(Complement):
    '''Standard fuzzy negation: a truth value x is mapped to 1 - x.'''

    # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
    def __call__(self, x):
        negated = 1.0 - x
        return negated
102
+
103
+
104
+ # ===========================================================================
105
+ # TNORMS
106
+ # - abstract tnorm
107
+ # - product tnorm
108
+ # - Godel tnorm
109
+ # - Lukasiewicz tnorm
110
+ # - Yager(p) tnorm
111
+ #
112
+ # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
113
+ # ===========================================================================
48
114
 
49
115
  class TNorm:
50
116
  '''Base class for fuzzy differentiable t-norms.'''
@@ -86,8 +152,134 @@ class LukasiewiczTNorm(TNorm):
86
152
 
87
153
  def norms(self, x, axis):
88
154
  return jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
155
+
156
+
157
class YagerTNorm(TNorm):
    '''Yager t-norm given by the expression
    (x, y) -> max(0, 1 - ((1 - x)^p + (1 - y)^p)^(1/p)).

    p = inf is accepted and gives relu(1 - max(1 - x, 1 - y)), i.e. min(x, y)
    on [0, 1].'''

    def __init__(self, p=2.0):
        self.p = float(p)

    def _reduce(self, x, axis):
        '''Yager t-norm of the entries of x along axis (int, tuple or None).'''
        comp = jax.nn.relu(1.0 - x)
        if self.p == float('inf'):
            dist = jnp.max(comp, axis=axis)
        else:
            # jnp.sum handles int / tuple / negative axes directly, unlike
            # reducing one axis at a time over sorted(axis).
            powered = jnp.sum(comp ** self.p, axis=axis)
            # Double-where: the p-th root has infinite slope at 0. Applying
            # it to 0 (all inputs == 1) would produce NaN gradients, as
            # jnp.linalg.norm does, so route the zero case around the root.
            nonzero = powered > 0.0
            safe = jnp.where(nonzero, powered, 1.0)
            dist = jnp.where(nonzero, safe ** (1.0 / self.p), 0.0)
        return jax.nn.relu(1.0 - dist)

    def norm(self, x, y):
        '''Yager t-norm of x and y; x and y are broadcast against each other.'''
        stacked = jnp.stack(jnp.broadcast_arrays(x, y), axis=0)
        return self._reduce(stacked, axis=0)

    def norms(self, x, axis):
        '''Yager t-norm of x reduced over axis (int or tuple of ints).'''
        return self._reduce(x, axis)
174
+
175
+
176
+ # ===========================================================================
177
+ # RANDOM SAMPLING
178
+ # - abstract sampler
179
+ # - Gumbel-softmax sampler
180
+ # - determinization
181
+ #
182
+ # ===========================================================================
183
+
184
class RandomSampling:
    '''An abstract class that describes how discrete and non-reparameterizable
    random variables are sampled.

    Every method takes the owning logic object and returns a pair
    (sampler, param). sampler(key, arg, param) draws a sample. param is the
    parameter spec of the underlying logic operation the sampler needs, or
    None when it needs none.'''

    def discrete(self, logic):
        '''Return (sampler, param) for Discrete(p) over the last axis of p.'''
        raise NotImplementedError

    def bernoulli(self, logic):
        '''Bernoulli(p) sampled as Discrete([1 - p, p]) via self.discrete.'''
        jax_discrete, jax_param = self.discrete(logic)

        def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
            prob = jnp.stack([1.0 - prob, prob], axis=-1)
            sample = jax_discrete(key, prob, param)
            return sample

        return _jax_wrapped_calc_bernoulli_approx, jax_param

    def poisson(self, logic):
        '''Poisson(rate) drawn exactly with jax.random.poisson (dtype logic.INT).'''

        def _jax_wrapped_calc_poisson_exact(key, rate, param):
            return random.poisson(key=key, lam=rate, dtype=logic.INT)

        return _jax_wrapped_calc_poisson_exact, None

    def geometric(self, logic):
        '''Geometric(p) on {1, 2, ...} by inversion, using the logic's floor.'''
        if logic.verbose:
            raise_warning('Using the replacement rule: '
                          'Geometric(p) --> floor(log(U) / log(1 - p)) + 1')

        jax_floor, jax_param = logic.floor()

        def _jax_wrapped_calc_geometric_approx(key, prob, param):
            U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
            # jax.random.uniform samples [0, 1). log1p(-U) = log(1 - U) with
            # 1 - U in (0, 1], which has the same distribution as log(U) but is
            # never log(0) = -inf. log1p(-prob) stays accurate for small prob,
            # where log(1.0 - prob) rounds to log(1) = 0 in float32.
            sample = jax_floor(jnp.log1p(-U) / jnp.log1p(-prob), param) + 1
            return sample

        return _jax_wrapped_calc_geometric_approx, jax_param
221
+
222
+
223
class GumbelSoftmax(RandomSampling):
    '''Samples discrete variables with the Gumbel-softmax relaxation.

    Gumbel(0, 1) noise is added to log(p + eps). The logic's soft argmax is
    then taken over the last axis.'''

    def discrete(self, logic):
        if logic.verbose:
            raise_warning('Using the replacement rule: '
                          'Discrete(p) --> Gumbel-Softmax(p)')

        soft_argmax, argmax_param = logic.argmax()

        # https://arxiv.org/pdf/1611.01144
        def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, param):
            noise = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
            perturbed_logits = noise + jnp.log(prob + logic.eps)
            return soft_argmax(perturbed_logits, axis=-1, param=param)

        return _jax_wrapped_calc_discrete_gumbel_softmax, argmax_param
241
+
242
+
243
class Determinization(RandomSampling):
    '''Replaces each random variable by a deterministic mean estimate.

    The key and param arguments of the returned samplers are ignored.
    Bernoulli is inherited and routes through discrete, giving
    0 * (1 - p) + 1 * p.'''

    def discrete(self, logic):
        '''Discrete(p) -> sum_i i * p[i] over the last axis.'''
        if logic.verbose:
            raise_warning('Using the replacement rule: '
                          'Discrete(p) --> sum(i * p[i])')

        def _jax_wrapped_calc_discrete_determinized(key, prob, param):
            indices = FuzzyLogic.enumerate_literals(prob.shape, axis=-1)
            return jnp.sum(indices * prob, axis=-1)

        return _jax_wrapped_calc_discrete_determinized, None

    def poisson(self, logic):
        '''Poisson(rate) -> rate.'''
        if logic.verbose:
            raise_warning('Using the replacement rule: Poisson(rate) --> rate')

        def _jax_wrapped_calc_poisson_determinized(key, rate, param):
            return rate

        return _jax_wrapped_calc_poisson_determinized, None

    def geometric(self, logic):
        '''Geometric(p) -> 1 / p.'''
        if logic.verbose:
            raise_warning('Using the replacement rule: Geometric(p) --> 1 / p')

        def _jax_wrapped_calc_geometric_determinized(key, prob, param):
            return 1.0 / prob

        return _jax_wrapped_calc_geometric_determinized, None
276
+
277
+
278
+ # ===========================================================================
279
+ # FUZZY LOGIC
280
+ #
281
+ # ===========================================================================
282
+
91
283
  class FuzzyLogic:
92
284
  '''A class representing fuzzy logic in JAX.
93
285
 
@@ -98,9 +290,11 @@ class FuzzyLogic:
98
290
  def __init__(self, tnorm: TNorm=ProductTNorm(),
99
291
  complement: Complement=StandardComplement(),
100
292
  comparison: Comparison=SigmoidComparison(),
293
+ sampling: RandomSampling=GumbelSoftmax(),
294
+ rounding: Rounding=SoftRounding(),
101
295
  weight: float=10.0,
102
296
  debias: Optional[Set[str]]=None,
103
- eps: float=1e-10,
297
+ eps: float=1e-15,
104
298
  verbose: bool=False,
105
299
  use64bit: bool=False) -> None:
106
300
  '''Creates a new fuzzy logic in Jax.
@@ -108,8 +302,9 @@ class FuzzyLogic:
108
302
  :param tnorm: fuzzy operator for logical AND
109
303
  :param complement: fuzzy operator for logical NOT
110
304
  :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
305
+ :param sampling: random sampling of non-reparameterizable distributions
306
+ :param rounding: rounding floating values to integers
111
307
  :param weight: a sharpness parameter for sigmoid and softmax activations
112
- :param error: an error parameter (e.g. floor) (smaller means better accuracy)
113
308
  :param debias: which functions to de-bias approximate on forward pass
114
309
  :param eps: small positive float to mitigate underflow
115
310
  :param verbose: whether to dump replacements and other info to console
@@ -118,6 +313,8 @@ class FuzzyLogic:
118
313
  self.tnorm = tnorm
119
314
  self.complement = complement
120
315
  self.comparison = comparison
316
+ self.sampling = sampling
317
+ self.rounding = rounding
121
318
  self.weight = float(weight)
122
319
  if debias is None:
123
320
  debias = set()
@@ -142,10 +339,12 @@ class FuzzyLogic:
142
339
  f' tnorm ={type(self.tnorm).__name__}\n'
143
340
  f' complement ={type(self.complement).__name__}\n'
144
341
  f' comparison ={type(self.comparison).__name__}\n'
342
+ f' sampling ={type(self.sampling).__name__}\n'
343
+ f' rounding ={type(self.rounding).__name__}\n'
145
344
  f' sigmoid_weight={self.weight}\n'
146
345
  f' cpfs_to_debias={self.debias}\n'
147
346
  f' underflow_tol ={self.eps}\n'
148
- f' use64bit ={self.use64bit}')
347
+ f' use_64_bit ={self.use64bit}')
149
348
 
150
349
  # ===========================================================================
151
350
  # logical operators
@@ -339,12 +538,14 @@ class FuzzyLogic:
339
538
 
340
539
  def sgn(self):
341
540
  if self.verbose:
342
- raise_warning('Using the replacement rule: sgn(x) --> tanh(x)')
343
-
541
+ raise_warning('Using the replacement rule: '
542
+ 'sgn(x) --> comparison.sgn(x)')
543
+
544
+ sgn_op = self.comparison.sgn
344
545
  debias = 'sgn' in self.debias
345
546
 
346
547
  def _jax_wrapped_calc_sgn_approx(x, param):
347
- sample = jnp.tanh(param * x)
548
+ sample = sgn_op(x, param)
348
549
  if debias:
349
550
  hard_sample = jnp.sign(x)
350
551
  sample += jax.lax.stop_gradient(hard_sample - sample)
@@ -357,37 +558,48 @@ class FuzzyLogic:
357
558
  def floor(self):
358
559
  if self.verbose:
359
560
  raise_warning('Using the replacement rule: '
360
- 'floor(x) --> x - atan(-1.0 / tan(pi * x)) / pi - 0.5')
561
+ 'floor(x) --> rounding.floor(x)')
562
+
563
+ floor_op = self.rounding.floor
564
+ debias = 'floor' in self.debias
361
565
 
362
566
  def _jax_wrapped_calc_floor_approx(x, param):
363
- sawtooth_part = jnp.arctan(-1.0 / jnp.tan(x * jnp.pi)) / jnp.pi + 0.5
364
- sample = x - jax.lax.stop_gradient(sawtooth_part)
567
+ sample = floor_op(x, param)
568
+ if debias:
569
+ hard_sample = jnp.floor(x)
570
+ sample += jax.lax.stop_gradient(hard_sample - sample)
365
571
  return sample
366
572
 
367
- return _jax_wrapped_calc_floor_approx, None
368
-
369
- def ceil(self):
370
- jax_floor, jax_param = self.floor()
371
-
372
- def _jax_wrapped_calc_ceil_approx(x, param):
373
- return -jax_floor(-x, param)
573
+ tags = ('weight', 'floor')
574
+ new_param = (tags, self.weight)
575
+ return _jax_wrapped_calc_floor_approx, new_param
374
576
 
375
- return _jax_wrapped_calc_ceil_approx, jax_param
376
-
377
577
  def round(self):
378
578
  if self.verbose:
379
- raise_warning('Using the replacement rule: round(x) --> x')
579
+ raise_warning('Using the replacement rule: '
580
+ 'round(x) --> rounding.round(x)')
380
581
 
582
+ round_op = self.rounding.round
381
583
  debias = 'round' in self.debias
382
584
 
383
585
  def _jax_wrapped_calc_round_approx(x, param):
384
- sample = x
586
+ sample = round_op(x, param)
385
587
  if debias:
386
588
  hard_sample = jnp.round(x)
387
589
  sample += jax.lax.stop_gradient(hard_sample - sample)
388
590
  return sample
389
591
 
390
- return _jax_wrapped_calc_round_approx, None
592
+ tags = ('weight', 'round')
593
+ new_param = (tags, self.weight)
594
+ return _jax_wrapped_calc_round_approx, new_param
595
+
596
+ def ceil(self):
597
+ jax_floor, jax_param = self.floor()
598
+
599
+ def _jax_wrapped_calc_ceil_approx(x, param):
600
+ return -jax_floor(-x, param)
601
+
602
+ return _jax_wrapped_calc_ceil_approx, jax_param
391
603
 
392
604
  def mod(self):
393
605
  jax_floor, jax_param = self.floor()
@@ -419,7 +631,7 @@ class FuzzyLogic:
419
631
  # ===========================================================================
420
632
 
421
633
  @staticmethod
422
- def _literals(shape, axis):
634
+ def enumerate_literals(shape, axis):
423
635
  literals = jnp.arange(shape[axis])
424
636
  literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
425
637
  literals = jnp.moveaxis(literals, source=0, destination=axis)
@@ -433,8 +645,9 @@ class FuzzyLogic:
433
645
 
434
646
  debias = 'argmax' in self.debias
435
647
 
648
+ # https://arxiv.org/abs/2110.05651
436
649
  def _jax_wrapped_calc_argmax_approx(x, axis, param):
437
- literals = FuzzyLogic._literals(x.shape, axis=axis)
650
+ literals = FuzzyLogic.enumerate_literals(x.shape, axis=axis)
438
651
  soft_max = jax.nn.softmax(param * x, axis=axis)
439
652
  sample = jnp.sum(literals * soft_max, axis=axis)
440
653
  if debias:
@@ -468,7 +681,7 @@ class FuzzyLogic:
468
681
  def _jax_wrapped_calc_if_approx(c, a, b, param):
469
682
  sample = c * a + (1.0 - c) * b
470
683
  if debias:
471
- hard_sample = jnp.select([c, ~c], [a, b])
684
+ hard_sample = jnp.where(c > 0.5, a, b)
472
685
  sample += jax.lax.stop_gradient(hard_sample - sample)
473
686
  return sample
474
687
 
@@ -478,14 +691,14 @@ class FuzzyLogic:
478
691
  if self.verbose:
479
692
  raise_warning('Using the replacement rule: '
480
693
  'switch(pred) { cases } --> '
481
- 'sum(cases[i] * softmax(-abs(pred - i)))')
694
+ 'sum(cases[i] * softmax(-(pred - i)^2))')
482
695
 
483
696
  debias = 'switch' in self.debias
484
697
 
485
698
  def _jax_wrapped_calc_switch_approx(pred, cases, param):
486
- literals = FuzzyLogic._literals(cases.shape, axis=0)
699
+ literals = FuzzyLogic.enumerate_literals(cases.shape, axis=0)
487
700
  pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
488
- proximity = -jnp.abs(pred - literals)
701
+ proximity = -jnp.square(pred - literals)
489
702
  soft_case = jax.nn.softmax(param * proximity, axis=0)
490
703
  sample = jnp.sum(cases * soft_case, axis=0)
491
704
  if debias:
@@ -502,46 +715,26 @@ class FuzzyLogic:
502
715
  # random variables
503
716
  # ===========================================================================
504
717
 
505
- def _gumbel_softmax(self, key, prob):
506
- Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=self.REAL)
507
- sample = Gumbel01 + jnp.log(prob + self.eps)
508
- return sample
509
-
718
+ def discrete(self):
719
+ return self.sampling.discrete(self)
720
+
510
721
  def bernoulli(self):
511
- if self.verbose:
512
- raise_warning('Using the replacement rule: '
513
- 'Bernoulli(p) --> Gumbel-softmax(p)')
514
-
515
- jax_gs = self._gumbel_softmax
516
- jax_argmax, jax_param = self.argmax()
517
-
518
- def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
519
- prob = jnp.stack([1.0 - prob, prob], axis=-1)
520
- sample = jax_gs(key, prob)
521
- sample = jax_argmax(sample, axis=-1, param=param)
522
- return sample
523
-
524
- return _jax_wrapped_calc_bernoulli_approx, jax_param
722
+ return self.sampling.bernoulli(self)
525
723
 
526
- def discrete(self):
527
- if self.verbose:
528
- raise_warning('Using the replacement rule: '
529
- 'Discrete(p) --> Gumbel-softmax(p)')
530
-
531
- jax_gs = self._gumbel_softmax
532
- jax_argmax, jax_param = self.argmax()
533
-
534
- def _jax_wrapped_calc_discrete_approx(key, prob, param):
535
- sample = jax_gs(key, prob)
536
- sample = jax_argmax(sample, axis=-1, param=param)
537
- return sample
538
-
539
- return _jax_wrapped_calc_discrete_approx, jax_param
540
-
724
+ def poisson(self):
725
+ return self.sampling.poisson(self)
726
+
727
+ def geometric(self):
728
+ return self.sampling.geometric(self)
541
729
 
730
+
731
+ # ===========================================================================
542
732
  # UNIT TESTS
733
+ #
734
+ # ===========================================================================
735
+
543
736
  logic = FuzzyLogic()
544
- w = 100.0
737
+ w = 1000.0
545
738
 
546
739
 
547
740
  def _test_logical():
@@ -568,7 +761,7 @@ def _test_logical():
568
761
  def _test_indexing():
569
762
  print('testing indexing')
570
763
  _argmax, _ = logic.argmax()
571
- _argmin, _ = logic.argmax()
764
+ _argmin, _ = logic.argmin()
572
765
 
573
766
  def argmaxmin(x):
574
767
  amax = _argmax(x, 0, w)
@@ -598,13 +791,14 @@ def _test_random():
598
791
  key = random.PRNGKey(42)
599
792
  _bernoulli, _ = logic.bernoulli()
600
793
  _discrete, _ = logic.discrete()
794
+ _geometric, _ = logic.geometric()
601
795
 
602
796
  def bern(n):
603
797
  prob = jnp.asarray([0.3] * n)
604
798
  sample = _bernoulli(key, prob, w)
605
799
  return sample
606
800
 
607
- samples = bern(5000)
801
+ samples = bern(50000)
608
802
  print(jnp.mean(samples))
609
803
 
610
804
  def disc(n):
@@ -613,20 +807,30 @@ def _test_random():
613
807
  sample = _discrete(key, prob, w)
614
808
  return sample
615
809
 
616
- samples = disc(5000)
810
+ samples = disc(50000)
617
811
  samples = jnp.round(samples)
618
812
  print([jnp.mean(samples == i) for i in range(3)])
619
-
813
+
814
+ def geom(n):
815
+ prob = jnp.asarray([0.3] * n)
816
+ sample = _geometric(key, prob, w)
817
+ return sample
818
+
819
+ samples = geom(50000)
820
+ print(jnp.mean(samples))
821
+
620
822
 
621
823
  def _test_rounding():
622
824
  print('testing rounding')
623
825
  _floor, _ = logic.floor()
624
826
  _ceil, _ = logic.ceil()
827
+ _round, _ = logic.round()
625
828
  _mod, _ = logic.mod()
626
829
 
627
- x = jnp.asarray([2.1, 0.5001, 1.99, -2.01, -3.2, -0.1, -1.01, 23.01, -101.99, 200.01])
830
+ x = jnp.asarray([2.1, 0.6, 1.99, -2.01, -3.2, -0.1, -1.01, 23.01, -101.99, 200.01])
628
831
  print(_floor(x, w))
629
832
  print(_ceil(x, w))
833
+ print(_round(x, w))
630
834
  print(_mod(x, 2.0, w))
631
835
 
632
836