pyRDDLGym-jax 0.3 → 0.4 (py3-none-any.whl)
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +90 -67
- pyRDDLGym_jax/core/logic.py +188 -46
- pyRDDLGym_jax/core/planner.py +59 -47
- pyRDDLGym_jax/core/simulator.py +2 -1
- pyRDDLGym_jax/core/tuning.py +7 -7
- pyRDDLGym_jax-0.4.dist-info/METADATA +276 -0
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.4.dist-info}/RECORD +11 -11
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.4.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax-0.3.dist-info/METADATA +0 -26
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.4.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.4.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
@@ -1 +1 @@
-__version__ = '0.3'
+__version__ = '0.4'
pyRDDLGym_jax/core/compiler.py
CHANGED
@@ -1,22 +1,11 @@
 from functools import partial
+import traceback
+from typing import Any, Callable, Dict, List, Optional
+
 import jax
 import jax.numpy as jnp
 import jax.random as random
 import jax.scipy as scipy
-import traceback
-from typing import Any, Callable, Dict, List, Optional
-
-from pyRDDLGym.core.debug.exception import raise_warning
-
-# more robust approach - if user does not have this or broken try to continue
-try:
-    from tensorflow_probability.substrates import jax as tfp
-except Exception:
-    raise_warning('Failed to import tensorflow-probability: '
-                  'compilation of some complex distributions '
-                  '(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
-    traceback.print_exc()
-    tfp = None
 
 from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
 from pyRDDLGym.core.compiler.levels import RDDLLevelAnalysis
@@ -25,12 +14,23 @@ from pyRDDLGym.core.compiler.tracer import RDDLObjectsTracer
 from pyRDDLGym.core.constraints import RDDLConstraints
 from pyRDDLGym.core.debug.exception import (
     print_stack_trace,
+    raise_warning,
     RDDLInvalidNumberOfArgumentsError,
     RDDLNotImplementedError
 )
 from pyRDDLGym.core.debug.logger import Logger
 from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
 
+# more robust approach - if user does not have this or broken try to continue
+try:
+    from tensorflow_probability.substrates import jax as tfp
+except Exception:
+    raise_warning('Failed to import tensorflow-probability: '
+                  'compilation of some complex distributions '
+                  '(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
+    traceback.print_exc()
+    tfp = None
+
 
 # ===========================================================================
 # EXACT RDDL TO JAX COMPILATION RULES
@@ -87,7 +87,7 @@ def _function_aggregation_exact_named(op, name):
 def _function_if_exact_named():
 
     def _jax_wrapped_if_exact(c, a, b, param):
-        return jnp.where(c, a, b)
+        return jnp.where(c > 0.5, a, b)
 
     return _jax_wrapped_if_exact
 
@@ -114,16 +114,27 @@ def _function_bernoulli_exact_named():
 def _function_discrete_exact_named():
 
     def _jax_wrapped_discrete_exact(key, prob, param):
-        sample = random.categorical(key=key, logits=logits, axis=-1)
-        out_of_bounds = jnp.logical_not(jnp.logical_and(
-            jnp.all(prob >= 0),
-            jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
-        return sample, out_of_bounds
+        return random.categorical(key=key, logits=jnp.log(prob), axis=-1)
 
     return _jax_wrapped_discrete_exact
 
 
+def _function_poisson_exact_named():
+
+    def _jax_wrapped_poisson_exact(key, rate, param):
+        return random.poisson(key=key, lam=rate, dtype=jnp.int64)
+
+    return _jax_wrapped_poisson_exact
+
+
+def _function_geometric_exact_named():
+
+    def _jax_wrapped_geometric_exact(key, prob, param):
+        return random.geometric(key=key, p=prob, dtype=jnp.int64)
+
+    return _jax_wrapped_geometric_exact
+
+
 class JaxRDDLCompiler:
     '''Compiles a RDDL AST representation into an equivalent JAX representation.
     All operations are identical to their numpy equivalents.
@@ -210,12 +221,12 @@ class JaxRDDLCompiler:
     }
 
     EXACT_RDDL_TO_JAX_IF = _function_if_exact_named()
-
     EXACT_RDDL_TO_JAX_SWITCH = _function_switch_exact_named()
 
     EXACT_RDDL_TO_JAX_BERNOULLI = _function_bernoulli_exact_named()
-
     EXACT_RDDL_TO_JAX_DISCRETE = _function_discrete_exact_named()
+    EXACT_RDDL_TO_JAX_POISSON = _function_poisson_exact_named()
+    EXACT_RDDL_TO_JAX_GEOMETRIC = _function_geometric_exact_named()
 
     def __init__(self, rddl: RDDLLiftedModel,
                  allow_synchronous_state: bool=True,
@@ -289,6 +300,8 @@ class JaxRDDLCompiler:
        self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
        self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
        self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
+        self.POISSON_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
+        self.GEOMETRIC_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
 
    # ===========================================================================
    # main compilation subroutines
@@ -996,13 +1009,14 @@ class JaxRDDLCompiler:
            jax_op, jax_param = self._unwrap(negative_op, expr.id, info)
            return self._jax_unary(jax_expr, jax_op, jax_param, at_least_int=True)
 
-        elif n == 2:
-            jax_lhs = self._jax(lhs, info)
-            jax_rhs = self._jax(rhs, info)
+        elif n == 2 or (n >= 2 and op in {'*', '+'}):
+            jax_exprs = [self._jax(arg, info) for arg in args]
            jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
+            result = jax_exprs[0]
+            for jax_rhs in jax_exprs[1:]:
+                result = self._jax_binary(
+                    result, jax_rhs, jax_op, jax_param, at_least_int=True)
+            return result
 
        JaxRDDLCompiler._check_num_args(expr, 2)
@@ -1046,13 +1060,14 @@ class JaxRDDLCompiler:
            jax_op, jax_param = self._unwrap(logical_not_op, expr.id, info)
            return self._jax_unary(jax_expr, jax_op, jax_param, check_dtype=bool)
 
-        elif n == 2:
-            jax_lhs = self._jax(lhs, info)
-            jax_rhs = self._jax(rhs, info)
+        elif n == 2 or (n >= 2 and op in {'^', '&', '|'}):
+            jax_exprs = [self._jax(arg, info) for arg in args]
            jax_op, jax_param = self._unwrap(valid_ops[op], expr.id, info)
+            result = jax_exprs[0]
+            for jax_rhs in jax_exprs[1:]:
+                result = self._jax_binary(
+                    result, jax_rhs, jax_op, jax_param, check_dtype=bool)
+            return result
 
        JaxRDDLCompiler._check_num_args(expr, 2)
@@ -1165,16 +1180,17 @@ class JaxRDDLCompiler:
        return _jax_wrapped_if_then_else
 
    def _jax_switch(self, expr, info):
+        pred, *_ = expr.args
 
-        # if
+        # if predicate is non-fluent, always use the exact operation
+        # case conditions are currently only literals so they are non-fluent
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
            switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
        else:
            switch_op = self.SWITCH_HELPER
        jax_switch, jax_param = self._unwrap(switch_op, expr.id, info)
 
        # recursively compile predicate
-        pred, *_ = expr.args
        jax_pred = self._jax(pred, info)
 
        # recursively compile cases
@@ -1426,15 +1442,24 @@ class JaxRDDLCompiler:
    def _jax_poisson(self, expr, info):
        ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_POISSON']
        JaxRDDLCompiler._check_num_args(expr, 1)
        arg_rate, = expr.args
+
+        # if rate is non-fluent, always use the exact operation
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_rate):
+            poisson_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
+        else:
+            poisson_op = self.POISSON_HELPER
+        jax_poisson, jax_param = self._unwrap(poisson_op, expr.id, info)
+
+        # recursively compile arguments
        jax_rate = self._jax(arg_rate, info)
 
        # uses the implicit JAX subroutine
        def _jax_wrapped_distribution_poisson(x, params, key):
            rate, key, err = jax_rate(x, params, key)
            key, subkey = random.split(key)
+            param = params.get(jax_param, None)
+            sample = jax_poisson(subkey, rate, param).astype(self.INT)
            out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
            err |= (out_of_bounds * ERR)
            return sample, key, err
@@ -1535,33 +1560,25 @@ class JaxRDDLCompiler:
        ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_GEOMETRIC']
        JaxRDDLCompiler._check_num_args(expr, 1)
        arg_prob, = expr.args
+
+        # if prob is non-fluent, always use the exact operation
+        if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
+            geom_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
+        else:
+            geom_op = self.GEOMETRIC_HELPER
+        jax_geom, jax_param = self._unwrap(geom_op, expr.id, info)
+
+        # recursively compile arguments
        jax_prob = self._jax(arg_prob, info)
 
-                err |= (out_of_bounds * ERR)
-                return sample, key, err
-
-        else:
-            floor_op, jax_param = self._unwrap(
-                self.KNOWN_UNARY['floor'], expr.id, info)
-
-            # reparameterization trick Geom(p) = floor(ln(U(0, 1)) / ln(p)) + 1
-            def _jax_wrapped_distribution_geometric(x, params, key):
-                prob, key, err = jax_prob(x, params, key)
-                key, subkey = random.split(key)
-                U = random.uniform(key=subkey, shape=jnp.shape(prob), dtype=self.REAL)
-                param = params.get(jax_param, None)
-                sample = floor_op(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
-                out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
-                err |= (out_of_bounds * ERR)
-                return sample, key, err
+        def _jax_wrapped_distribution_geometric(x, params, key):
+            prob, key, err = jax_prob(x, params, key)
+            key, subkey = random.split(key)
+            param = params.get(jax_param, None)
+            sample = jax_geom(subkey, prob, param).astype(self.INT)
+            out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
+            err |= (out_of_bounds * ERR)
+            return sample, key, err
 
        return _jax_wrapped_distribution_geometric
@@ -1770,7 +1787,10 @@ class JaxRDDLCompiler:
            # dispatch to sampling subroutine
            key, subkey = random.split(key)
            param = params.get(jax_param, None)
-            sample
+            sample = jax_discrete(subkey, prob, param)
+            out_of_bounds = jnp.logical_not(jnp.logical_and(
+                jnp.all(prob >= 0),
+                jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
            error |= (out_of_bounds * ERR)
            return sample, key, error
@@ -1803,7 +1823,10 @@ class JaxRDDLCompiler:
            # dispatch to sampling subroutine
            key, subkey = random.split(key)
            param = params.get(jax_param, None)
-            sample
+            sample = jax_discrete(subkey, prob, param)
+            out_of_bounds = jnp.logical_not(jnp.logical_and(
+                jnp.all(prob >= 0),
+                jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
            error |= (out_of_bounds * ERR)
            return sample, key, error
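Note on the new exact rules: when the distribution parameter is non-fluent, Poisson and Geometric sampling now dispatch directly to JAX's native samplers instead of a relaxation. A minimal standalone sketch of what these helpers compute (assuming a JAX version that provides `jax.random.geometric`, which the new code itself relies on; default integer dtypes are used here instead of `jnp.int64`):

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)
key_poisson, key_geom = random.split(key)

# exact integer samples, mirroring EXACT_RDDL_TO_JAX_POISSON / _GEOMETRIC above
poisson_sample = random.poisson(key=key_poisson, lam=jnp.full((4,), 3.0))
geometric_sample = random.geometric(key=key_geom, p=jnp.full((4,), 0.3))
print(poisson_sample, geometric_sample)
```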
pyRDDLGym_jax/core/logic.py
CHANGED
@@ -1,11 +1,19 @@
+from typing import Optional, Set
+
 import jax
 import jax.numpy as jnp
 import jax.random as random
-from typing import Optional, Set
 
 from pyRDDLGym.core.debug.exception import raise_warning
 
 
+# ===========================================================================
+# LOGICAL COMPLEMENT
+# - abstract class
+# - standard complement
+#
+# ===========================================================================
+
 class Complement:
     '''Base class for approximate logical complement operations.'''
 
@@ -20,6 +28,13 @@ class StandardComplement(Complement):
         return 1.0 - x
 
 
+# ===========================================================================
+# RELATIONAL OPERATIONS
+# - abstract class
+# - sigmoid comparison
+#
+# ===========================================================================
+
 class Comparison:
     '''Base class for approximate comparison operations.'''
 
@@ -44,7 +59,17 @@ class SigmoidComparison(Comparison):
 
     def equal(self, x, y, param):
         return 1.0 - jnp.square(jnp.tanh(param * (y - x)))
-
+
+
+# ===========================================================================
+# TNORMS
+# - abstract tnorm
+# - product tnorm
+# - Godel tnorm
+# - Lukasiewicz tnorm
+# - Yager(p) tnorm
+#
+# ===========================================================================
 
 class TNorm:
     '''Base class for fuzzy differentiable t-norms.'''
@@ -86,8 +111,133 @@ class LukasiewiczTNorm(TNorm):
 
     def norms(self, x, axis):
         return jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
+
+
+class YagerTNorm(TNorm):
+    '''Yager t-norm given by the expression
+    (x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
+
+    def __init__(self, p=2.0):
+        self.p = p
+
+    def norm(self, x, y):
+        base_x = jax.nn.relu(1.0 - x)
+        base_y = jax.nn.relu(1.0 - y)
+        arg = jnp.power(base_x ** self.p + base_y ** self.p, 1.0 / self.p)
+        return jax.nn.relu(1.0 - arg)
+
+    def norms(self, x, axis):
+        base = jax.nn.relu(1.0 - x)
+        arg = jnp.power(jnp.sum(base ** self.p, axis=axis), 1.0 / self.p)
+        return jax.nn.relu(1.0 - arg)
+
+
+# ===========================================================================
+# RANDOM SAMPLING
+# - abstract sampler
+# - Gumbel-softmax sampler
+# - determinization
+#
+# ===========================================================================
+
+class RandomSampling:
+    '''An abstract class that describes how discrete and non-reparameterizable
+    random variables are sampled.'''
+
+    def discrete(self, logic):
+        raise NotImplementedError
+
+    def bernoulli(self, logic):
+        jax_discrete, jax_param = self.discrete(logic)
+
+        def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
+            prob = jnp.stack([1.0 - prob, prob], axis=-1)
+            sample = jax_discrete(key, prob, param)
+            return sample
+
+        return _jax_wrapped_calc_bernoulli_approx, jax_param
+
+    def poisson(self, logic):
+
+        def _jax_wrapped_calc_poisson_exact(key, rate, param):
+            return random.poisson(key=key, lam=rate, dtype=logic.INT)
+
+        return _jax_wrapped_calc_poisson_exact, None
+
+    def geometric(self, logic):
+        if logic.verbose:
+            raise_warning('Using the replacement rule: '
+                          'Geometric(p) --> floor(log(U) / log(1 - p)) + 1')
+
+        jax_floor, jax_param = logic.floor()
+
+        def _jax_wrapped_calc_geometric_approx(key, prob, param):
+            U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
+            sample = jax_floor(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
+            return sample
+
+        return _jax_wrapped_calc_geometric_approx, jax_param
+
+
+class GumbelSoftmax(RandomSampling):
+    '''Random sampling of discrete variables using Gumbel-softmax trick.'''
 
+    def discrete(self, logic):
+        if logic.verbose:
+            raise_warning('Using the replacement rule: '
+                          'Discrete(p) --> Gumbel-softmax(p)')
+
+        jax_argmax, jax_param = logic.argmax()
+
+        def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, param):
+            Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
+            sample = Gumbel01 + jnp.log(prob + logic.eps)
+            sample = jax_argmax(sample, axis=-1, param=param)
+            return sample
+
+        return _jax_wrapped_calc_discrete_gumbel_softmax, jax_param
+
+
+class Determinization(RandomSampling):
+    '''Random sampling of variables using their deterministic mean estimate.'''
 
+    def discrete(self, logic):
+        if logic.verbose:
+            raise_warning('Using the replacement rule: '
+                          'Discrete(p) --> sum(i * p[i])')
+
+        def _jax_wrapped_calc_discrete_determinized(key, prob, param):
+            literals = FuzzyLogic.enumerate_literals(prob.shape, axis=-1)
+            sample = jnp.sum(literals * prob, axis=-1)
+            return sample
+
+        return _jax_wrapped_calc_discrete_determinized, None
+
+    def poisson(self, logic):
+        if logic.verbose:
+            raise_warning('Using the replacement rule: Poisson(rate) --> rate')
+
+        def _jax_wrapped_calc_poisson_determinized(key, rate, param):
+            return rate
+
+        return _jax_wrapped_calc_poisson_determinized, None
+
+    def geometric(self, logic):
+        if logic.verbose:
+            raise_warning('Using the replacement rule: Geometric(p) --> 1 / p')
+
+        def _jax_wrapped_calc_geometric_determinized(key, prob, param):
+            sample = 1.0 / prob
+            return sample
+
+        return _jax_wrapped_calc_geometric_determinized, None
+
+
+# ===========================================================================
+# FUZZY LOGIC
+#
+# ===========================================================================
+
 class FuzzyLogic:
     '''A class representing fuzzy logic in JAX.
 
@@ -98,9 +248,10 @@ class FuzzyLogic:
     def __init__(self, tnorm: TNorm=ProductTNorm(),
                  complement: Complement=StandardComplement(),
                  comparison: Comparison=SigmoidComparison(),
+                 sampling: RandomSampling=GumbelSoftmax(),
                  weight: float=10.0,
                  debias: Optional[Set[str]]=None,
-                 eps: float=1e-
+                 eps: float=1e-15,
                  verbose: bool=False,
                  use64bit: bool=False) -> None:
         '''Creates a new fuzzy logic in Jax.
@@ -108,8 +259,8 @@ class FuzzyLogic:
         :param tnorm: fuzzy operator for logical AND
         :param complement: fuzzy operator for logical NOT
         :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
+        :param sampling: random sampling of non-reparameterizable distributions
         :param weight: a sharpness parameter for sigmoid and softmax activations
-        :param error: an error parameter (e.g. floor) (smaller means better accuracy)
         :param debias: which functions to de-bias approximate on forward pass
         :param eps: small positive float to mitigate underflow
         :param verbose: whether to dump replacements and other info to console
@@ -118,6 +269,7 @@ class FuzzyLogic:
         self.tnorm = tnorm
         self.complement = complement
         self.comparison = comparison
+        self.sampling = sampling
         self.weight = float(weight)
         if debias is None:
             debias = set()
@@ -142,10 +294,11 @@ class FuzzyLogic:
               f'    tnorm         ={type(self.tnorm).__name__}\n'
               f'    complement    ={type(self.complement).__name__}\n'
               f'    comparison    ={type(self.comparison).__name__}\n'
+              f'    sampling      ={type(self.sampling).__name__}\n'
               f'    sigmoid_weight={self.weight}\n'
               f'    cpfs_to_debias={self.debias}\n'
               f'    underflow_tol ={self.eps}\n'
-              f'
+              f'    use_64_bit    ={self.use64bit}')
 
     # ===========================================================================
     # logical operators
@@ -419,7 +572,7 @@ class FuzzyLogic:
     # ===========================================================================
 
     @staticmethod
-    def 
+    def enumerate_literals(shape, axis):
         literals = jnp.arange(shape[axis])
         literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
         literals = jnp.moveaxis(literals, source=0, destination=axis)
@@ -434,7 +587,7 @@ class FuzzyLogic:
         debias = 'argmax' in self.debias
 
         def _jax_wrapped_calc_argmax_approx(x, axis, param):
-            literals = FuzzyLogic.
+            literals = FuzzyLogic.enumerate_literals(x.shape, axis=axis)
             soft_max = jax.nn.softmax(param * x, axis=axis)
             sample = jnp.sum(literals * soft_max, axis=axis)
             if debias:
@@ -468,7 +621,7 @@ class FuzzyLogic:
         def _jax_wrapped_calc_if_approx(c, a, b, param):
             sample = c * a + (1.0 - c) * b
             if debias:
-                hard_sample = jnp.
+                hard_sample = jnp.where(c > 0.5, a, b)
                 sample += jax.lax.stop_gradient(hard_sample - sample)
             return sample
 
@@ -483,7 +636,7 @@ class FuzzyLogic:
         debias = 'switch' in self.debias
 
         def _jax_wrapped_calc_switch_approx(pred, cases, param):
-            literals = FuzzyLogic.
+            literals = FuzzyLogic.enumerate_literals(cases.shape, axis=0)
             pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
             proximity = -jnp.abs(pred - literals)
             soft_case = jax.nn.softmax(param * proximity, axis=0)
@@ -502,44 +655,24 @@ class FuzzyLogic:
     # random variables
     # ===========================================================================
 
-    def 
-        return sample
-
+    def discrete(self):
+        return self.sampling.discrete(self)
+
     def bernoulli(self):
-        raise_warning('Using the replacement rule: '
-                      'Bernoulli(p) --> Gumbel-softmax(p)')
-
-        jax_gs = self._gumbel_softmax
-        jax_argmax, jax_param = self.argmax()
-
-        def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
-            prob = jnp.stack([1.0 - prob, prob], axis=-1)
-            sample = jax_gs(key, prob)
-            sample = jax_argmax(sample, axis=-1, param=param)
-            return sample
-
-        return _jax_wrapped_calc_bernoulli_approx, jax_param
+        return self.sampling.bernoulli(self)
 
-    def 
-        jax_argmax, jax_param = self.argmax()
-
-        def _jax_wrapped_calc_discrete_approx(key, prob, param):
-            sample = jax_gs(key, prob)
-            sample = jax_argmax(sample, axis=-1, param=param)
-            return sample
-
-        return _jax_wrapped_calc_discrete_approx, jax_param
-
+    def poisson(self):
+        return self.sampling.poisson(self)
+
+    def geometric(self):
+        return self.sampling.geometric(self)
+
 
+# ===========================================================================
 # UNIT TESTS
+#
+# ===========================================================================
+
 logic = FuzzyLogic()
 w = 100.0
@@ -598,13 +731,14 @@ def _test_random():
     key = random.PRNGKey(42)
     _bernoulli, _ = logic.bernoulli()
     _discrete, _ = logic.discrete()
+    _geometric, _ = logic.geometric()
 
     def bern(n):
         prob = jnp.asarray([0.3] * n)
         sample = _bernoulli(key, prob, w)
         return sample
 
-    samples = bern(
+    samples = bern(50000)
     print(jnp.mean(samples))
 
     def disc(n):
@@ -613,10 +747,18 @@ def _test_random():
         sample = _discrete(key, prob, w)
         return sample
 
-    samples = disc(
+    samples = disc(50000)
     samples = jnp.round(samples)
     print([jnp.mean(samples == i) for i in range(3)])
-
+
+    def geom(n):
+        prob = jnp.asarray([0.3] * n)
+        sample = _geometric(key, prob, w)
+        return sample
+
+    samples = geom(50000)
+    print(jnp.mean(samples))
+
 
 def _test_rounding():
     print('testing rounding')
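With `logic.py` now routing all non-reparameterizable sampling through a pluggable `RandomSampling` strategy, the relaxation can be swapped without touching the planner. A minimal sketch of the new constructor surface, using only class and argument names that appear in the diff above (assumes pyRDDLGym-jax 0.4 is installed):

```python
from pyRDDLGym_jax.core.logic import (
    Determinization, FuzzyLogic, GumbelSoftmax, YagerTNorm)

# default behaviour: Discrete/Bernoulli are relaxed with the Gumbel-softmax trick
relaxed = FuzzyLogic(sampling=GumbelSoftmax(), weight=10.0)

# alternative: determinize Discrete/Poisson/Geometric to their mean estimates,
# and use the new Yager(p) t-norm for logical AND
determinized = FuzzyLogic(tnorm=YagerTNorm(p=2.0), sampling=Determinization())

# each accessor returns a (sampler, model-parameter-name) pair, as in the diff
jax_geometric, _ = determinized.geometric()
```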
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -2,54 +2,51 @@ from ast import literal_eval
 from collections import deque
 import configparser
 from enum import Enum
+import os
+import sys
+import time
+import traceback
+from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
+
 import haiku as hk
 import jax
+import jax.nn.initializers as initializers
 import jax.numpy as jnp
 import jax.random as random
-import jax.nn.initializers as initializers
 import numpy as np
 import optax
-import os
-import sys
 import termcolor
-import time
-import traceback
 from tqdm import tqdm
-from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
 
-Activation = Callable[[jnp.ndarray], jnp.ndarray]
-Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
-Kwargs = Dict[str, Any]
-Pytree = Any
-
-from pyRDDLGym.core.debug.exception import raise_warning
-
-from pyRDDLGym_jax import __version__
-
-# try to import matplotlib, if failed then skip plotting
-try:
-    import matplotlib
-    import matplotlib.pyplot as plt
-    matplotlib.use('TkAgg')
-except Exception:
-    raise_warning('failed to import matplotlib: '
-                  'plotting functionality will be disabled.', 'red')
-    traceback.print_exc()
-    plt = None
-
 from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
 from pyRDDLGym.core.debug.logger import Logger
 from pyRDDLGym.core.debug.exception import (
+    raise_warning,
     RDDLNotImplementedError,
     RDDLUndefinedVariableError,
     RDDLTypeError
 )
 from pyRDDLGym.core.policy import BaseAgent
 
-from pyRDDLGym_jax
+from pyRDDLGym_jax import __version__
 from pyRDDLGym_jax.core import logic
+from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
 from pyRDDLGym_jax.core.logic import FuzzyLogic
 
+# try to import matplotlib, if failed then skip plotting
+try:
+    import matplotlib.pyplot as plt
+except Exception:
+    raise_warning('failed to import matplotlib: '
+                  'plotting functionality will be disabled.', 'red')
+    traceback.print_exc()
+    plt = None
+
+Activation = Callable[[jnp.ndarray], jnp.ndarray]
+Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
+Kwargs = Dict[str, Any]
+Pytree = Any
+
 
 # ***********************************************************************
 # CONFIG FILE MANAGEMENT
@@ -104,9 +101,12 @@ def _load_config(config, args):
     comp_kwargs = model_args.get('complement_kwargs', {})
     compare_name = model_args.get('comparison', 'SigmoidComparison')
     compare_kwargs = model_args.get('comparison_kwargs', {})
+    sampling_name = model_args.get('sampling', 'GumbelSoftmax')
+    sampling_kwargs = model_args.get('sampling_kwargs', {})
     logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
     logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
     logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
+    logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
 
     # read the policy settings
     plan_method = planner_args.pop('method')
@@ -184,18 +184,6 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
 #
 # ***********************************************************************
 
-def _function_discrete_approx_named(logic):
-    jax_discrete, jax_param = logic.discrete()
-
-    def _jax_wrapped_discrete_calc_approx(key, prob, params):
-        sample = jax_discrete(key, prob, params)
-        out_of_bounds = jnp.logical_not(jnp.logical_and(
-            jnp.all(prob >= 0),
-            jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
-        return sample, out_of_bounds
-
-    return _jax_wrapped_discrete_calc_approx, jax_param
-
 
 class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
     '''Compiles a RDDL AST representation to an equivalent JAX representation.
@@ -271,7 +259,9 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         self.IF_HELPER = logic.control_if()
         self.SWITCH_HELPER = logic.control_switch()
         self.BERNOULLI_HELPER = logic.bernoulli()
-        self.DISCRETE_HELPER = 
+        self.DISCRETE_HELPER = logic.discrete()
+        self.POISSON_HELPER = logic.poisson()
+        self.GEOMETRIC_HELPER = logic.geometric()
 
     def _jax_stop_grad(self, jax_expr):
@@ -469,7 +459,8 @@ class JaxStraightLinePlan(JaxPlan):
               f'    wrap_non_bool ={self._wrap_non_bool}\n'
               f'constraint-sat strategy (complex):\n'
               f'    wrap_softmax  ={self._wrap_softmax}\n'
-              f'    use_new_projection ={self._use_new_projection}'
+              f'    use_new_projection   ={self._use_new_projection}\n'
+              f'    max_projection_iters ={self._max_constraint_iter}')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
@@ -1348,8 +1339,18 @@ class JaxBackpropPlanner:
                 map(str, jax._src.xla_bridge.devices())).replace('\n', '')
         except Exception as _:
             devices_short = 'N/A'
+        LOGO = \
+        """
+ __ ______ __ __ ______ __ ______ __ __
+ /\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
+ _\_\ \ \ \ __ \ \/_/\_\/_ \ \ _-/\ \ \____ \ \ __ \ \ \ \-. \
+ /\_____\ \ \_\ \_\ /\_\/\_\ \ \_\ \ \_____\ \ \_\ \_\ \ \_\\"\_\
+ \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
+        """
+        
         print('\n'
-              f'
+              f'{LOGO}\n'
+              f'Version {__version__}\n'
               f'Python {sys.version}\n'
               f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
               f'optax {optax.__version__}, haiku {hk.__version__}, '
@@ -1711,6 +1712,14 @@ class JaxBackpropPlanner:
             hyperparam_value = float(policy_hyperparams)
             policy_hyperparams = {action: hyperparam_value
                                   for action in self.rddl.action_fluents}
+
+        # fill in missing entries
+        elif isinstance(policy_hyperparams, dict):
+            for action in self.rddl.action_fluents:
+                if action not in policy_hyperparams:
+                    raise_warning(f'policy_hyperparams[{action}] is not set, '
+                                  'setting 1.0 which could be suboptimal.')
+                    policy_hyperparams[action] = 1.0
 
         # print summary of parameters:
         if print_summary:
@@ -1772,6 +1781,7 @@ class JaxBackpropPlanner:
         rolling_test_loss = RollingMean(test_rolling_window)
         log = {}
         status = JaxPlannerStatus.NORMAL
+        is_all_zero_fn = lambda x: np.allclose(x, 0)
 
         # initialize plot area
         if plot_step is None or plot_step <= 0 or plt is None:
@@ -1786,6 +1796,7 @@ class JaxBackpropPlanner:
         iters = range(epochs)
         if print_progress:
             iters = tqdm(iters, total=100, position=tqdm_position)
+        position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
 
         for it in iters:
             status = JaxPlannerStatus.NORMAL
@@ -1799,7 +1810,7 @@ class JaxBackpropPlanner:
 
             # no progress
             grad_norm_zero, _ = jax.tree_util.tree_flatten(
-                jax.tree_map(
+                jax.tree_map(is_all_zero_fn, train_log['grad']))
             if np.all(grad_norm_zero):
                 status = JaxPlannerStatus.NO_PROGRESS
@@ -1843,8 +1854,9 @@ class JaxBackpropPlanner:
             if print_progress:
                 iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
                 iters.set_description(
-                    f'
-                    f'{-test_loss:14.6f} test / {-best_loss:14.6f} best'
+                    f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
+                    f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
+                    f'{status.value} status')
 
             # reached computation budget
             if elapsed >= train_seconds:
@@ -1904,7 +1916,7 @@ class JaxBackpropPlanner:
               f'    iterations    ={it}\n'
               f'    best_objective={-best_loss}\n'
               f'    best_grad_norm={grad_norm}\n'
-              f'diagnosis: {diagnosis}\n')
+              f'    diagnosis: {diagnosis}\n')
 
     def _perform_diagnosis(self, last_iter_improve,
                            train_return, test_return, best_return, grad_norm):
@@ -2116,7 +2128,7 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
     @jax.jit
     def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
         return (-1.0 / beta) * jax.scipy.special.logsumexp(
-
+            -beta * returns, b=1.0 / returns.size)
 
     @jax.jit
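The `_load_config` change above means the `[Model]` section of a planner config now accepts `sampling` and `sampling_kwargs` keys (defaulting to `GumbelSoftmax`). A hedged sketch of loading such a config through `load_config_from_string`, the helper named in the hunk header; the three-value return shape is assumed to mirror `load_config` from the README and may differ:

```python
from pyRDDLGym_jax.core.planner import load_config_from_string

CONFIG = """
[Model]
logic='FuzzyLogic'
logic_kwargs={'weight': 20}
tnorm='ProductTNorm'
tnorm_kwargs={}
sampling='GumbelSoftmax'
sampling_kwargs={}

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}
batch_size_train=1
batch_size_test=1

[Training]
key=42
epochs=5000
train_seconds=30
"""

# assumed to return (planner_kwargs, plan_kwargs, train_kwargs) like load_config
planner_args, _, train_args = load_config_from_string(CONFIG)
```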
pyRDDLGym_jax/core/simulator.py
CHANGED
pyRDDLGym_jax/core/tuning.py
CHANGED
@@ -1,20 +1,18 @@
-from bayes_opt import BayesianOptimization
-from bayes_opt.util import UtilityFunction
 from copy import deepcopy
 import csv
 import datetime
-import jax
 from multiprocessing import get_context
-import numpy as np
 import os
 import time
 from typing import Any, Callable, Dict, Optional, Tuple
-
-Kwargs = Dict[str, Any]
-
 import warnings
 warnings.filterwarnings("ignore")
 
+from bayes_opt import BayesianOptimization
+from bayes_opt.util import UtilityFunction
+import jax
+import numpy as np
+
 from pyRDDLGym.core.debug.exception import raise_warning
 from pyRDDLGym.core.env import RDDLEnv
 
@@ -26,6 +24,8 @@ from pyRDDLGym_jax.core.planner import (
     JaxOnlineController
 )
 
+Kwargs = Dict[str, Any]
+
 
 # ===============================================================================
 #
pyRDDLGym_jax-0.4.dist-info/METADATA
ADDED
@@ -0,0 +1,276 @@
Metadata-Version: 2.1
Name: pyRDDLGym-jax
Version: 0.4
Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
License: MIT License
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Natural Language :: English
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: pyRDDLGym >=2.0
Requires-Dist: tqdm >=4.66
Requires-Dist: bayesian-optimization >=1.4.3
Requires-Dist: jax >=0.4.12
Requires-Dist: optax >=0.1.9
Requires-Dist: dm-haiku >=0.0.10
Requires-Dist: tensorflow-probability >=0.21.0

# pyRDDLGym-jax

Author: [Mike Gimelfarb](https://mike-gimelfarb.github.io)

This directory provides:
1. automated translation and compilation of RDDL description files into [JAX](https://github.com/google/jax), converting any RDDL domain into a differentiable simulator!
2. powerful, fast and scalable gradient-based planning algorithms, with extensible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!

> [!NOTE]
> While the JAX planner can support some discrete state/action problems through model relaxations, it can perform poorly on some discrete problems (though there is an ongoing effort to remedy this!).
> If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).

## Contents

- [Installation](#installation)
- [Running from the Command Line](#running-from-the-command-line)
- [Running from within Python](#running-from-within-python)
- [Configuring the Planner](#configuring-the-planner)
- [Simulation](#simulation)
- [Manual Gradient Calculation](#manual-gradient-calculation)
- [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)

## Installation

To use the compiler or planner without the automated hyper-parameter tuning, you will need the following packages installed:
- ``pyRDDLGym>=2.0``
- ``tqdm>=4.66``
- ``jax>=0.4.12``
- ``optax>=0.1.9``
- ``dm-haiku>=0.0.10``
- ``tensorflow-probability>=0.21.0``

Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
To run the automated tuning optimization, you will also need ``bayesian-optimization>=1.4.3``.

You can install this package, together with all of its requirements, via pip:

```shell
pip install rddlrepository pyRDDLGym-jax
```

## Running from the Command Line

A basic run script is provided to run the JAX planner on any domain in ``rddlrepository``; it can be launched from the command line in the install directory of pyRDDLGym-jax:

```shell
python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
```

where:
- ``domain`` is the domain identifier as specified in rddlrepository (e.g. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
- ``instance`` is the instance identifier (e.g. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
- ``method`` is the planning method to use (one of drp, slp, replan)
- ``episodes`` is the (optional) number of episodes to evaluate the learned policy.

The ``method`` parameter supports three possible modes:
- ``slp`` is the basic straight-line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
- ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
- ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.

A basic run script is also provided to run the automatic hyper-parameter tuning:

```shell
python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
```

where:
- ``domain`` is the domain identifier as specified in rddlrepository (e.g. Wildfire_MDP_ippc2014)
- ``instance`` is the instance identifier (e.g. 1, 2, ... 10)
- ``method`` is the planning method to use (one of drp, slp, replan)
- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, so the total number of evaluations is ``iters * workers``.

For example, the following will train the JAX planner on the Quadcopter domain with 4 drones:

```shell
python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
```

After several minutes of optimization, you should get a visualization as follows:

<p align="center">
<img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
</p>

## Running from within Python

To run the JAX planner from within a Python application, refer to the following example:

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController

# set up the environment (note the vectorized option must be True)
env = pyRDDLGym.make("domain", "instance", vectorized=True)

# create the planning algorithm
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)

# evaluate the planner
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()
```

Here, we have used the straight-line controller, although you can configure the combination of planner and policy representation if you wish.
All controllers are instances of pyRDDLGym's ``BaseAgent`` class, so they provide the ``evaluate()`` function to streamline interaction with the environment.
The ``**planner_args`` and ``**train_args`` are keyword argument parameters to pass during initialization, but we strongly recommend creating and loading a config file as discussed in the next section.

## Configuring the Planner

The simplest way to configure the planner is to write and pass a configuration file with the necessary [hyper-parameters](https://pyrddlgym.readthedocs.io/en/latest/jax.html#configuring-pyrddlgym-jax).
The basic structure of a configuration file is provided below for a straight-line planner:

```ini
[Model]
logic='FuzzyLogic'
logic_kwargs={'weight': 20}
tnorm='ProductTNorm'
tnorm_kwargs={}

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}
batch_size_train=1
batch_size_test=1

[Training]
key=42
epochs=5000
train_seconds=30
```

The configuration file contains three sections:
- ``[Model]`` specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations; the ``weight`` dictates the quality of the approximation (a short illustrative sketch follows this list), and ``tnorm`` specifies the type of [fuzzy logic](https://en.wikipedia.org/wiki/T-norm_fuzzy_logics) for relaxing logical operations in RDDL (e.g. ``ProductTNorm``, ``GodelTNorm``, ``LukasiewiczTNorm``)
- ``[Optimizer]`` specifies the optimizer and plan settings; the ``method`` specifies the plan/policy representation (e.g. ``JaxStraightLinePlan``, ``JaxDeepReactivePolicy``), the gradient descent settings, learning rate, batch size, etc.
- ``[Training]`` specifies computation limits, such as total training time and number of iterations, and options for printing or visualizing information from the planner.
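To make the role of ``weight`` concrete, here is a small illustrative sketch (not the library's internal code) of how a sigmoid with a sharpness parameter stands in for a hard comparison; larger weights track the true 0/1 answer more closely but give steeper, less informative gradients:

```python
import jax
import jax.numpy as jnp

def soft_greater(x, y, weight):
    # differentiable stand-in for the hard comparison x > y
    return jax.nn.sigmoid(weight * (x - y))

x, y = jnp.float32(1.2), jnp.float32(1.0)
print(float(soft_greater(x, y, 10.0)))   # ~0.88: smooth but biased
print(float(soft_greater(x, y, 100.0)))  # ~1.00: sharper, close to the exact answer
```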

For a policy network approach, simply change the ``[Optimizer]`` settings like so:

```ini
...
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
...
```

The configuration file must then be passed to the planner during initialization.
For example, the [previous script here](#running-from-within-python) can be modified to set parameters from a config file:

```python
from pyRDDLGym_jax.core.planner import load_config

# load the config file with planner settings
planner_args, _, train_args = load_config("/path/to/config.cfg")

# create the planning algorithm
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)
...
```

## Simulation

The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:

```python
import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

# create the environment
env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)

# evaluate the random policy
agent = RandomAgent(action_space=env.action_space,
                    num_actions=env.max_allowed_actions)
agent.evaluate(env, verbose=True, render=True)
```

For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
In any event, the simulation results using the JAX backend should (almost) always match the numpy backend.

## Manual Gradient Calculation

For custom applications, it can be desirable to compute gradients of the compiled model that can then be optimized downstream.
Fortunately, we provide a convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad

# set up the environment
env = pyRDDLGym.make("domain", "instance", vectorized=True)

# create the step function
compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
compiled.compile()
step_fn = compiled.compile_transition()
```

This will return a JAX compiled (pure) function requiring the following inputs:
- ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
- ``actions`` is the dictionary of action fluent tensors
- ``subs`` is the dictionary of state-fluent and non-fluent tensors
- ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``

The function returns a dictionary containing a variety of variables, such as the updated pvariables including next-state fluents (``pvar``), the reward obtained (``reward``), and error codes (``error``).
It is thus possible to apply any JAX transformation to the output of the function, such as computing a gradient using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
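For instance, a gradient of the one-step reward with respect to the action tensors can be taken as follows (a minimal sketch continuing the snippet above; the argument order and the ``'reward'`` key follow the description in this section and may differ slightly across versions):

```python
import jax

def one_step_reward(actions, key, subs, model_params):
    # scalar reward of a single compiled transition
    return step_fn(key, actions, subs, model_params)['reward']

# d(reward)/d(actions): a dictionary of gradients matching the action tensors
reward_grad_fn = jax.grad(one_step_reward)
```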

Compilation of entire rollouts is also possible by calling the ``compile_rollouts`` function.
An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).

## Citing pyRDDLGym-jax

The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of the framework. Please cite it if you found it useful:

```
@inproceedings{gimelfarb2024jaxplan,
    title={JaxPlan and GurobiPlan: Optimization Baselines for Replanning in Discrete and Mixed Discrete and Continuous Probabilistic Domains},
    author={Michael Gimelfarb and Ayal Taitler and Scott Sanner},
    booktitle={34th International Conference on Automated Planning and Scheduling},
    year={2024},
    url={https://openreview.net/forum?id=7IKtmUpLEH}
}
```

The utility optimization is discussed in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/21226):

```
@inproceedings{patton2022distributional,
    title={A distributional framework for risk-sensitive end-to-end planning in continuous mdps},
    author={Patton, Noah and Jeong, Jihwan and Gimelfarb, Mike and Sanner, Scott},
    booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
    volume={36},
    number={9},
    pages={9894--9901},
    year={2022}
}
```

Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
- [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
- [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
{pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.4.dist-info}/RECORD
CHANGED
@@ -1,10 +1,10 @@
-pyRDDLGym_jax/__init__.py,sha256=
+pyRDDLGym_jax/__init__.py,sha256=rexmxcBiCOcwctw4wGvk7UxS9MfZn_1CYXp53SoLKlU,19
 pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/core/compiler.py,sha256=
-pyRDDLGym_jax/core/logic.py,sha256=
-pyRDDLGym_jax/core/planner.py,sha256=
-pyRDDLGym_jax/core/simulator.py,sha256=
-pyRDDLGym_jax/core/tuning.py,sha256=
+pyRDDLGym_jax/core/compiler.py,sha256=SnDN3-J84Wv_YVHoDmfM_U4Ob8uaFLGX4vEaeWC-ERY,90037
+pyRDDLGym_jax/core/logic.py,sha256=o1YAjMnXfi8gwb42kAigBmaf9uIYUWal9__FEkWohrk,26733
+pyRDDLGym_jax/core/planner.py,sha256=Hrwfn88bUu1LNZcnFC5psHPzcIUbPeF4Rn1pFO6_qH0,102655
+pyRDDLGym_jax/core/simulator.py,sha256=hWv6pr-4V-SSCzBYgdIPmKdUDMalft-Zh6dzOo5O9-0,8331
+pyRDDLGym_jax/core/tuning.py,sha256=D_kD8wjqMroCdtjE9eksR2UqrqXJqazsAKrMEHwPxYM,29589
 pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
 pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
@@ -37,8 +37,8 @@ pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5
 pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=S2-5hPZtgAwUAFpiCAgSi-cnGhYHSDzMGMmatwhbM78,344
 pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=VWWPhOYBRq4cWwtrChw5pPqRmlX_nHbMvwciHd9hoLc,357
 pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=TG3mtHUnCA7J2Gm9SczENpqAymTnzCE9dj1Z_R-FnVk,340
-pyRDDLGym_jax-0.
-pyRDDLGym_jax-0.
-pyRDDLGym_jax-0.
-pyRDDLGym_jax-0.
-pyRDDLGym_jax-0.
+pyRDDLGym_jax-0.4.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+pyRDDLGym_jax-0.4.dist-info/METADATA,sha256=-Kf8PLxf_7MiiYXzlZAf31kV1pT-Rurc7QY7dT3Fwk0,12857
+pyRDDLGym_jax-0.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+pyRDDLGym_jax-0.4.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+pyRDDLGym_jax-0.4.dist-info/RECORD,,

pyRDDLGym_jax-0.3.dist-info/METADATA
REMOVED
@@ -1,26 +0,0 @@
-Metadata-Version: 2.1
-Name: pyRDDLGym-jax
-Version: 0.3
-Summary: pyRDDLGym-jax: JAX compilation of RDDL description files, and a differentiable planner in JAX.
-Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
-Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
-Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
-License: MIT License
-Classifier: Development Status :: 3 - Alpha
-Classifier: Intended Audience :: Science/Research
-Classifier: License :: OSI Approved :: MIT License
-Classifier: Natural Language :: English
-Classifier: Operating System :: OS Independent
-Classifier: Programming Language :: Python :: 3
-Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Requires-Python: >=3.8
-License-File: LICENSE
-Requires-Dist: pyRDDLGym >=2.0
-Requires-Dist: tqdm >=4.66
-Requires-Dist: bayesian-optimization >=1.4.3
-Requires-Dist: jax >=0.4.12
-Requires-Dist: optax >=0.1.9
-Requires-Dist: dm-haiku >=0.0.10
-Requires-Dist: tensorflow >=2.13.0
-Requires-Dist: tensorflow-probability >=0.21.0
-

{pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.4.dist-info}/LICENSE
File without changes

{pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.4.dist-info}/top_level.txt
File without changes