pyRDDLGym-jax: pyRDDLGym_jax-0.1-py3-none-any.whl → pyRDDLGym_jax-0.2-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (35)
  1. pyRDDLGym_jax/core/compiler.py +445 -221
  2. pyRDDLGym_jax/core/logic.py +129 -62
  3. pyRDDLGym_jax/core/planner.py +699 -332
  4. pyRDDLGym_jax/core/simulator.py +5 -7
  5. pyRDDLGym_jax/core/tuning.py +23 -12
  6. pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg} +2 -3
  7. pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg} +2 -2
  8. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +19 -0
  9. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +18 -0
  10. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +18 -0
  11. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -1
  12. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -1
  13. pyRDDLGym_jax/examples/configs/default_drp.cfg +19 -0
  14. pyRDDLGym_jax/examples/configs/default_replan.cfg +20 -0
  15. pyRDDLGym_jax/examples/configs/default_slp.cfg +19 -0
  16. pyRDDLGym_jax/examples/run_gradient.py +1 -1
  17. pyRDDLGym_jax/examples/run_gym.py +1 -2
  18. pyRDDLGym_jax/examples/run_plan.py +7 -0
  19. pyRDDLGym_jax/examples/run_tune.py +6 -0
  20. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/METADATA +1 -1
  21. pyRDDLGym_jax-0.2.dist-info/RECORD +46 -0
  22. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/WHEEL +1 -1
  23. pyRDDLGym_jax-0.1.dist-info/RECORD +0 -40
  24. /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_replan.cfg → Cartpole_Continuous_gym_replan.cfg} +0 -0
  25. /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_slp.cfg → Cartpole_Continuous_gym_slp.cfg} +0 -0
  26. /pyRDDLGym_jax/examples/configs/{HVAC_slp.cfg → HVAC_ippc2023_slp.cfg} +0 -0
  27. /pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg} +0 -0
  28. /pyRDDLGym_jax/examples/configs/{MarsRover_slp.cfg → MarsRover_ippc2023_slp.cfg} +0 -0
  29. /pyRDDLGym_jax/examples/configs/{MountainCar_slp.cfg → MountainCar_Continuous_gym_slp.cfg} +0 -0
  30. /pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg} +0 -0
  31. /pyRDDLGym_jax/examples/configs/{PowerGen_drp.cfg → PowerGen_Continuous_drp.cfg} +0 -0
  32. /pyRDDLGym_jax/examples/configs/{PowerGen_replan.cfg → PowerGen_Continuous_replan.cfg} +0 -0
  33. /pyRDDLGym_jax/examples/configs/{PowerGen_slp.cfg → PowerGen_Continuous_slp.cfg} +0 -0
  34. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/LICENSE +0 -0
  35. {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/top_level.txt +0 -0
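Only the diff for pyRDDLGym_jax/core/logic.py is expanded below. Its headline changes are a new pluggable Comparison operator (with a SigmoidComparison default), an optional 64-bit arithmetic mode, and a rename of the FuzzyLogic factory methods to snake_case: And → logical_and, Or → logical_or, Not → logical_not, greaterEqual → greater_equal, lessEqual → less_equal, notEqual → not_equal, signum → sgn, floorDiv → div, If → control_if, Switch → control_switch.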
pyRDDLGym_jax/core/logic.py
@@ -1,7 +1,7 @@
  import jax
  import jax.numpy as jnp
  import jax.random as random
- from typing import Set
+ from typing import Optional, Set
 
  from pyRDDLGym.core.debug.exception import raise_warning
 
@@ -20,6 +20,32 @@ class StandardComplement(Complement):
          return 1.0 - x
 
 
+ class Comparison:
+     '''Base class for approximate comparison operations.'''
+ 
+     def greater_equal(self, x, y, param):
+         raise NotImplementedError
+ 
+     def greater(self, x, y, param):
+         raise NotImplementedError
+ 
+     def equal(self, x, y, param):
+         raise NotImplementedError
+ 
+ 
+ class SigmoidComparison(Comparison):
+     '''Comparison operations approximated using sigmoid functions.'''
+ 
+     def greater_equal(self, x, y, param):
+         return jax.nn.sigmoid(param * (x - y))
+ 
+     def greater(self, x, y, param):
+         return jax.nn.sigmoid(param * (x - y))
+ 
+     def equal(self, x, y, param):
+         return 1.0 - jnp.square(jnp.tanh(param * (y - x)))
+ 
+ 
  class TNorm:
      '''Base class for fuzzy differentiable t-norms.'''
 
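Note: as a usage sketch (not part of the package), the new SigmoidComparison can be exercised standalone; the last argument is the sharpness weight, and larger values approach the hard comparison:

    import jax.numpy as jnp
    from pyRDDLGym_jax.core.logic import SigmoidComparison

    comp = SigmoidComparison()
    x = jnp.array([0.0, 1.0, 2.0])
    y = jnp.array([1.0, 1.0, 1.0])
    print(comp.greater_equal(x, y, 10.0))  # smooth step, roughly [0.00, 0.50, 1.00]
    print(comp.equal(x, y, 10.0))          # sech^2 bump, roughly [0.00, 1.00, 0.00]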
@@ -71,40 +97,61 @@ class FuzzyLogic:
 
      def __init__(self, tnorm: TNorm=ProductTNorm(),
                   complement: Complement=StandardComplement(),
+                  comparison: Comparison=SigmoidComparison(),
                   weight: float=10.0,
-                  debias: Set[str]={},
+                  debias: Optional[Set[str]]=None,
                   eps: float=1e-10,
-                  verbose: bool=False):
+                  verbose: bool=False,
+                  use64bit: bool=False) -> None:
          '''Creates a new fuzzy logic in Jax.
 
          :param tnorm: fuzzy operator for logical AND
          :param complement: fuzzy operator for logical NOT
-         :param weight: a concentration parameter (larger means better accuracy)
+         :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
+         :param weight: a sharpness parameter for sigmoid and softmax activations
          :param error: an error parameter (e.g. floor) (smaller means better accuracy)
          :param debias: which functions to de-bias approximate on forward pass
          :param eps: small positive float to mitigate underflow
          :param verbose: whether to dump replacements and other info to console
+         :param use64bit: whether to perform arithmetic in 64 bit
          '''
          self.tnorm = tnorm
          self.complement = complement
+         self.comparison = comparison
          self.weight = float(weight)
+         if debias is None:
+             debias = set()
          self.debias = debias
          self.eps = eps
          self.verbose = verbose
-
-     def summarize_hyperparameters(self):
+         self.set_use64bit(use64bit)
+
+     def set_use64bit(self, use64bit: bool) -> None:
+         self.use64bit = use64bit
+         if use64bit:
+             self.REAL = jnp.float64
+             self.INT = jnp.int64
+             jax.config.update('jax_enable_x64', True)
+         else:
+             self.REAL = jnp.float32
+             self.INT = jnp.int32
+             jax.config.update('jax_enable_x64', False)
+
+     def summarize_hyperparameters(self) -> None:
          print(f'model relaxation:\n'
                f'    tnorm         ={type(self.tnorm).__name__}\n'
                f'    complement    ={type(self.complement).__name__}\n'
+               f'    comparison    ={type(self.comparison).__name__}\n'
                f'    sigmoid_weight={self.weight}\n'
                f'    cpfs_to_debias={self.debias}\n'
-               f'    underflow_tol ={self.eps}')
+               f'    underflow_tol ={self.eps}\n'
+               f'    use64bit      ={self.use64bit}')
 
      # ===========================================================================
      # logical operators
      # ===========================================================================
 
-     def And(self):
+     def logical_and(self):
          if self.verbose:
              raise_warning('Using the replacement rule: a ^ b --> tnorm(a, b).')
 
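Note: a minimal construction sketch using only arguments visible in this hunk (weight=10.0 mirrors the default; the debias tag is illustrative):

    from pyRDDLGym_jax.core.logic import FuzzyLogic, SigmoidComparison

    logic = FuzzyLogic(comparison=SigmoidComparison(),
                       weight=10.0,
                       debias={'argmax'},  # de-bias the soft argmax on the forward pass
                       use64bit=True)      # flips jax_enable_x64 globally
    logic.summarize_hyperparameters()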
@@ -115,9 +162,9 @@ class FuzzyLogic:
 
          return _jax_wrapped_calc_and_approx, None
 
-     def Not(self):
+     def logical_not(self):
          if self.verbose:
-             raise_warning('Using the replacement rule: ~a --> 1 - a')
+             raise_warning('Using the replacement rule: ~a --> complement(a)')
 
          _not = self.complement
 
@@ -126,9 +173,9 @@ class FuzzyLogic:
 
          return _jax_wrapped_calc_not_approx, None
 
-     def Or(self):
+     def logical_or(self):
          if self.verbose:
-             raise_warning('Using the replacement rule: a or b --> tconorm(a, b).')
+             raise_warning('Using the replacement rule: a | b --> tconorm(a, b).')
 
          _not = self.complement
          _and = self.tnorm.norm
@@ -141,7 +188,7 @@ class FuzzyLogic:
      def xor(self):
          if self.verbose:
              raise_warning('Using the replacement rule: '
-                           'a xor b --> (a or b) ^ ~(a ^ b).')
+                           'a ~ b --> (a | b) ^ ~(a ^ b).')
 
          _not = self.complement
          _and = self.tnorm.norm
@@ -182,7 +229,7 @@ class FuzzyLogic:
      def forall(self):
          if self.verbose:
              raise_warning('Using the replacement rule: '
-                           'forall(a) --> tnorm(a[1], tnorm(a[2], ...))')
+                           'forall(a) --> a[1] ^ a[2] ^ ...')
 
          _forall = self.tnorm.norms
 
@@ -204,31 +251,35 @@
      # comparison operators
      # ===========================================================================
 
-     def greaterEqual(self):
+     def greater_equal(self):
          if self.verbose:
-             raise_warning('Using the replacement rule: a >= b --> sigmoid(a - b)')
-
-         debias = 'greaterEqual' in self.debias
+             raise_warning('Using the replacement rule: '
+                           'a >= b --> comparison.greater_equal(a, b)')
+
+         greater_equal_op = self.comparison.greater_equal
+         debias = 'greater_equal' in self.debias
 
          def _jax_wrapped_calc_geq_approx(a, b, param):
-             sample = jax.nn.sigmoid(param * (a - b))
+             sample = greater_equal_op(a, b, param)
              if debias:
                  hard_sample = jnp.greater_equal(a, b)
                  sample += jax.lax.stop_gradient(hard_sample - sample)
              return sample
 
-         tags = ('weight', 'greaterEqual')
+         tags = ('weight', 'greater_equal')
          new_param = (tags, self.weight)
          return _jax_wrapped_calc_geq_approx, new_param
 
      def greater(self):
          if self.verbose:
-             raise_warning('Using the replacement rule: a > b --> sigmoid(a - b)')
-
+             raise_warning('Using the replacement rule: '
+                           'a > b --> comparison.greater(a, b)')
+
+         greater_op = self.comparison.greater
          debias = 'greater' in self.debias
 
          def _jax_wrapped_calc_gre_approx(a, b, param):
-             sample = jax.nn.sigmoid(param * (a - b))
+             sample = greater_op(a, b, param)
              if debias:
                  hard_sample = jnp.greater(a, b)
                  sample += jax.lax.stop_gradient(hard_sample - sample)
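Note: the debias branch is a straight-through estimator: the forward pass returns the hard comparison, while gradients flow through the soft surrogate. An illustrative standalone sketch (geq_ste is a hypothetical helper, not a package function):

    import jax
    import jax.numpy as jnp

    def geq_ste(a, b, weight=10.0):
        soft = jax.nn.sigmoid(weight * (a - b))            # differentiable surrogate
        hard = jnp.greater_equal(a, b).astype(soft.dtype)  # exact 0/1 comparison
        return soft + jax.lax.stop_gradient(hard - soft)   # value: hard, gradient: soft

    print(geq_ste(0.4, 0.5))            # 0.0, the exact comparison result
    print(jax.grad(geq_ste)(0.4, 0.5))  # nonzero gradient of the sigmoid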
@@ -238,8 +289,8 @@ class FuzzyLogic:
          new_param = (tags, self.weight)
          return _jax_wrapped_calc_gre_approx, new_param
 
-     def lessEqual(self):
-         jax_geq, jax_param = self.greaterEqual()
+     def less_equal(self):
+         jax_geq, jax_param = self.greater_equal()
 
          def _jax_wrapped_calc_leq_approx(a, b, param):
              return jax_geq(-a, -b, param)
@@ -256,12 +307,14 @@
 
      def equal(self):
          if self.verbose:
-             raise_warning('Using the replacement rule: a == b --> sech^2(b - a)')
-
+             raise_warning('Using the replacement rule: '
+                           'a == b --> comparison.equal(a, b)')
+
+         equal_op = self.comparison.equal
          debias = 'equal' in self.debias
 
          def _jax_wrapped_calc_equal_approx(a, b, param):
-             sample = 1.0 - jnp.square(jnp.tanh(param * (b - a)))
+             sample = equal_op(a, b, param)
              if debias:
                  hard_sample = jnp.equal(a, b)
                  sample += jax.lax.stop_gradient(hard_sample - sample)
@@ -271,7 +324,7 @@
          new_param = (tags, self.weight)
          return _jax_wrapped_calc_equal_approx, new_param
 
-     def notEqual(self):
+     def not_equal(self):
          _not = self.complement
          jax_eq, jax_param = self.equal()
 
@@ -284,31 +337,32 @@
      # special functions
      # ===========================================================================
 
-     def signum(self):
+     def sgn(self):
          if self.verbose:
-             raise_warning('Using the replacement rule: signum(x) --> tanh(x)')
+             raise_warning('Using the replacement rule: sgn(x) --> tanh(x)')
 
-         debias = 'signum' in self.debias
+         debias = 'sgn' in self.debias
 
-         def _jax_wrapped_calc_signum_approx(x, param):
+         def _jax_wrapped_calc_sgn_approx(x, param):
              sample = jnp.tanh(param * x)
              if debias:
                  hard_sample = jnp.sign(x)
                  sample += jax.lax.stop_gradient(hard_sample - sample)
              return sample
 
-         tags = ('weight', 'signum')
+         tags = ('weight', 'sgn')
          new_param = (tags, self.weight)
-         return _jax_wrapped_calc_signum_approx, new_param
+         return _jax_wrapped_calc_sgn_approx, new_param
 
      def floor(self):
          if self.verbose:
              raise_warning('Using the replacement rule: '
                            'floor(x) --> x - atan(-1.0 / tan(pi * x)) / pi - 0.5')
-
+
          def _jax_wrapped_calc_floor_approx(x, param):
              sawtooth_part = jnp.arctan(-1.0 / jnp.tan(x * jnp.pi)) / jnp.pi + 0.5
-             return x - jax.lax.stop_gradient(sawtooth_part)
+             sample = x - jax.lax.stop_gradient(sawtooth_part)
+             return sample
 
          return _jax_wrapped_calc_floor_approx, None
 
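Note: subtracting the stop-gradiented sawtooth makes the relaxed floor return the exact floor value away from integer points, while its gradient is exactly 1 everywhere. A hypothetical standalone copy of the rule above:

    import jax
    import jax.numpy as jnp

    def soft_floor(x):
        sawtooth = jnp.arctan(-1.0 / jnp.tan(x * jnp.pi)) / jnp.pi + 0.5
        return x - jax.lax.stop_gradient(sawtooth)

    print(soft_floor(2.7))            # 2.0
    print(jax.grad(soft_floor)(2.7))  # 1.0, since only the identity term carries gradient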
@@ -324,8 +378,14 @@
          if self.verbose:
              raise_warning('Using the replacement rule: round(x) --> x')
 
+         debias = 'round' in self.debias
+
          def _jax_wrapped_calc_round_approx(x, param):
-             return x
+             sample = x
+             if debias:
+                 hard_sample = jnp.round(x)
+                 sample += jax.lax.stop_gradient(hard_sample - sample)
+             return sample
 
          return _jax_wrapped_calc_round_approx, None
 
@@ -337,13 +397,13 @@
 
          return _jax_wrapped_calc_mod_approx, jax_param
 
-     def floorDiv(self):
+     def div(self):
          jax_floor, jax_param = self.floor()
 
-         def _jax_wrapped_calc_mod_approx(x, y, param):
+         def _jax_wrapped_calc_div_approx(x, y, param):
              return jax_floor(x / y, param)
 
-         return _jax_wrapped_calc_mod_approx, jax_param
+         return _jax_wrapped_calc_div_approx, jax_param
 
      def sqrt(self):
          if self.verbose:
@@ -369,14 +429,14 @@
      def argmax(self):
          if self.verbose:
              raise_warning('Using the replacement rule: '
-                           f'argmax(x) --> sum(i * softmax(x[i]))')
+                           'argmax(x) --> sum(i * softmax(x[i]))')
 
          debias = 'argmax' in self.debias
 
          def _jax_wrapped_calc_argmax_approx(x, axis, param):
-             prob_max = jax.nn.softmax(param * x, axis=axis)
-             literals = FuzzyLogic._literals(prob_max.shape, axis=axis)
-             sample = jnp.sum(literals * prob_max, axis=axis)
+             literals = FuzzyLogic._literals(x.shape, axis=axis)
+             soft_max = jax.nn.softmax(param * x, axis=axis)
+             sample = jnp.sum(literals * soft_max, axis=axis)
              if debias:
                  hard_sample = jnp.argmax(x, axis=axis)
                  sample += jax.lax.stop_gradient(hard_sample - sample)
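Note: the reordering above is cosmetic; the estimator itself is unchanged: softmax turns the scores into a distribution over indices, and the expected index acts as a differentiable argmax. A standalone sketch of the same idea (soft_argmax is hypothetical):

    import jax
    import jax.numpy as jnp

    def soft_argmax(x, weight=100.0):
        index = jnp.arange(x.shape[-1])              # the 'literals' 0..n-1
        probs = jax.nn.softmax(weight * x, axis=-1)  # concentrates on the maximum
        return jnp.sum(index * probs, axis=-1)       # expected index

    x = jnp.array([0.1, 0.9, 0.3])
    print(soft_argmax(x))  # approximately 1.0, matching jnp.argmax(x)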
@@ -398,26 +458,33 @@
      # control flow
      # ===========================================================================
 
-     def If(self):
+     def control_if(self):
          if self.verbose:
              raise_warning('Using the replacement rule: '
                            'if c then a else b --> c * a + (1 - c) * b')
 
+         debias = 'if' in self.debias
+
          def _jax_wrapped_calc_if_approx(c, a, b, param):
-             return c * a + (1.0 - c) * b
+             sample = c * a + (1.0 - c) * b
+             if debias:
+                 hard_sample = jnp.select([c, ~c], [a, b])
+                 sample += jax.lax.stop_gradient(hard_sample - sample)
+             return sample
 
          return _jax_wrapped_calc_if_approx, None
 
-     def Switch(self):
+     def control_switch(self):
          if self.verbose:
              raise_warning('Using the replacement rule: '
-                           'switch(pred) { cases } --> sum(cases[i] * (pred == i))')
+                           'switch(pred) { cases } --> '
+                           'sum(cases[i] * softmax(-abs(pred - i)))')
 
-         debias = 'Switch' in self.debias
+         debias = 'switch' in self.debias
 
          def _jax_wrapped_calc_switch_approx(pred, cases, param):
-             pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
              literals = FuzzyLogic._literals(cases.shape, axis=0)
+             pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
              proximity = -jnp.abs(pred - literals)
              soft_case = jax.nn.softmax(param * proximity, axis=0)
              sample = jnp.sum(cases * soft_case, axis=0)
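Note: control_switch generalizes the soft if: the softmax over negated distances to each case label yields nearly one-hot weights, so the weighted sum picks out the matching case. A standalone sketch (soft_switch is hypothetical):

    import jax
    import jax.numpy as jnp

    def soft_switch(pred, cases, weight=100.0):
        literals = jnp.arange(cases.shape[0])   # case labels 0..k-1
        proximity = -jnp.abs(pred - literals)   # peaks at the matching case
        soft_case = jax.nn.softmax(weight * proximity)
        return jnp.sum(cases * soft_case)

    cases = jnp.array([-10.0, 0.0, 10.0])
    print(soft_switch(2.0, cases))  # approximately 10.0, the case at index 2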
@@ -436,8 +503,8 @@
      # ===========================================================================
 
      def _gumbel_softmax(self, key, prob):
-         Gumbel01 = random.gumbel(key=key, shape=prob.shape)
-         sample = Gumbel01 + jnp.log1p(prob + self.eps - 1.0)
+         Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=self.REAL)
+         sample = Gumbel01 + jnp.log(prob + self.eps)
          return sample
 
      def bernoulli(self):
@@ -448,13 +515,13 @@
          jax_gs = self._gumbel_softmax
          jax_argmax, jax_param = self.argmax()
 
-         def _jax_wrapped_calc_switch_approx(key, prob, param):
+         def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
              prob = jnp.stack([1.0 - prob, prob], axis=-1)
              sample = jax_gs(key, prob)
-             sample = jax_argmax(sample, -1, param)
+             sample = jax_argmax(sample, axis=-1, param=param)
              return sample
 
-         return _jax_wrapped_calc_switch_approx, jax_param
+         return _jax_wrapped_calc_bernoulli_approx, jax_param
 
      def discrete(self):
          if self.verbose:
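Note: the relaxed Bernoulli chains the two pieces above: Gumbel noise added to log-probabilities (the Gumbel-max trick), followed by the soft argmax over the two outcomes. A self-contained sketch under those assumptions (eps and weight values are illustrative):

    import jax
    import jax.numpy as jnp
    import jax.random as random

    def soft_bernoulli(key, prob, weight=100.0, eps=1e-10):
        probs = jnp.stack([1.0 - prob, prob], axis=-1)
        scores = random.gumbel(key=key, shape=probs.shape) + jnp.log(probs + eps)
        soft = jax.nn.softmax(weight * scores, axis=-1)  # soft argmax over {0, 1}
        return jnp.sum(jnp.arange(2) * soft, axis=-1)    # near 1 with probability prob

    print(soft_bernoulli(random.PRNGKey(42), jnp.float32(0.7)))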
@@ -466,7 +533,7 @@
 
          def _jax_wrapped_calc_discrete_approx(key, prob, param):
              sample = jax_gs(key, prob)
-             sample = jax_argmax(sample, -1, param)
+             sample = jax_argmax(sample, axis=-1, param=param)
              return sample
 
          return _jax_wrapped_calc_discrete_approx, jax_param
@@ -479,11 +546,11 @@ w = 100.0
 
  def _test_logical():
      print('testing logical')
-     _and, _ = logic.And()
-     _not, _ = logic.Not()
+     _and, _ = logic.logical_and()
+     _not, _ = logic.logical_not()
      _gre, _ = logic.greater()
-     _or, _ = logic.Or()
-     _if, _ = logic.If()
+     _or, _ = logic.logical_or()
+     _if, _ = logic.control_if()
 
      # https://towardsdatascience.com/emulating-logical-gates-with-a-neural-network-75c229ec4cc9
      def test_logic(x1, x2):
@@ -516,7 +583,7 @@ def _test_indexing():
 
  def _test_control():
      print('testing control')
-     _switch, _ = logic.Switch()
+     _switch, _ = logic.control_switch()
 
      pred = jnp.asarray(jnp.linspace(0, 2, 10))
      case1 = jnp.asarray([-10.] * 10)