pyRDDLGym-jax 2.7__py3-none-any.whl → 3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +1080 -906
- pyRDDLGym_jax/core/logic.py +1537 -1369
- pyRDDLGym_jax/core/model.py +75 -86
- pyRDDLGym_jax/core/planner.py +883 -935
- pyRDDLGym_jax/core/simulator.py +20 -17
- pyRDDLGym_jax/core/tuning.py +11 -7
- pyRDDLGym_jax/core/visualization.py +115 -78
- pyRDDLGym_jax/entry_point.py +2 -1
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
- pyRDDLGym_jax/examples/run_plan.py +2 -33
- pyRDDLGym_jax/examples/run_tune.py +2 -2
- {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
- pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
- {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/run_gradient.py +0 -102
- pyrddlgym_jax-2.7.dist-info/RECORD +0 -50
- {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
- {pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
{pyrddlgym_jax-2.7.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyRDDLGym-jax
-Version: 2.7
+Version: 3.0
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -76,6 +76,10 @@ Some demos of solved problems by JaxPlan:
 <img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/reservoir.gif" width="120" height="120" margin=0/>
 </p>
 
+> [!WARNING]
+> Starting in version 3.0 (major release), the structure of the config files and internal API have changed.
+> Please make sure your config files follow the new format. See [examples here](https://github.com/pyrddlgym-project/pyRDDLGym-jax/tree/main/pyRDDLGym_jax/examples/configs).
+
 > [!WARNING]
 > Starting in version 1.0 (major release), the ``weight`` parameter in the config file was removed,
 and was moved to the individual logic components which have their own unique weight parameter assigned.
````
````diff
@@ -168,35 +172,32 @@ The simplest way to configure the planner is to write and pass a configuration f
 The basic structure of a configuration file is provided below for a straight-line planner:
 
 ```ini
-[…
-…
-…
-rounding_kwargs={'weight': 20}
-control_kwargs={'weight': 20}
+[Compiler]
+method='DefaultJaxRDDLCompilerWithGrad'
+sigmoid_weight=20
 
-[…
+[Planner]
 method='JaxStraightLinePlan'
 method_kwargs={}
 optimizer='rmsprop'
 optimizer_kwargs={'learning_rate': 0.001}
 
-[…
+[Optimize]
 key=42
 epochs=5000
 train_seconds=30
 ```
 
 The configuration file contains three sections:
-- ``[…
-…
-- ``[…
-- ``[Training]`` specifies computation limits, such as total training time and number of iterations, and options for printing or visualizing information from the planner.
+- ``[Compiler]`` specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations; the ``weight`` dictates the quality of the approximation
+- ``[Planner]`` generally specify the optimizer and plan settings; the ``method`` specifies the plan/policy representation (e.g. ``JaxStraightLinePlan``, ``JaxDeepReactivePolicy``), the gradient descent settings, learning rate, batch size, etc.
+- ``[Optimize]`` specifies computation limits, such as total training time and number of iterations, and options for printing or visualizing information from the planner.
 
-For a policy network approach, simply change the ``[…
+For a policy network approach, simply change the ``[Planner]`` settings like so:
 
 ```ini
 ...
-[…
+[Planner]
 method='JaxDeepReactivePolicy'
 method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
 ...
````
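For orientation, the new ``[Compiler]``/``[Planner]``/``[Optimize]`` config shown in this hunk is consumed by the planner roughly as in the sketch below. It follows the 2.x README pattern (``load_config``, ``JaxBackpropPlanner``, ``JaxOfflineController``); since this release warns that the internal API changed, treat those names, and the placeholder domain, instance, and config path, as assumptions rather than the definitive 3.0 API.

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxOfflineController, load_config)

# placeholder domain/instance; any RDDL environment works here
env = pyRDDLGym.make('Cartpole_Continuous_gym', '0', vectorized=True)

# load_config parses the config sections into keyword-argument dicts
# (pattern per the 2.x README; 3.0 section names differ as shown above)
planner_args, _, train_args = load_config('Cartpole_Continuous_gym_slp.cfg')

# build the planner and train an open-loop (straight-line) controller
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)

# evaluate the trained plan in the environment
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()
```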
````diff
@@ -224,7 +225,7 @@ and visualization of the policy or model, and other useful debugging features. T
 
 ```ini
 ...
-[…
+[Planner]
 dashboard=True
 ...
 ```
````
````diff
@@ -251,19 +252,17 @@ It is easy to tune a custom range of the planner's hyper-parameters efficiently.
 First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
 
 ```ini
-[…
-…
-…
-rounding_kwargs={'weight': TUNABLE_WEIGHT}
-control_kwargs={'weight': TUNABLE_WEIGHT}
+[Compiler]
+method='DefaultJaxRDDLCompilerWithGrad'
+sigmoid_weight=TUNABLE_WEIGHT
 
-[…
+[Planner]
 method='JaxStraightLinePlan'
 method_kwargs={}
 optimizer='rmsprop'
 optimizer_kwargs={'learning_rate': TUNABLE_LEARNING_RATE}
 
-[…
+[Optimize]
 train_seconds=30
 print_summary=False
 print_progress=False
@@ -272,7 +271,7 @@ train_on_reset=True
 
 would allow to tune the sharpness of model relaxations, and the learning rate of the optimizer.
 
-Next, …
+Next, link the hyperparameters in the config with concrete ranges that the optimizer will use, and run the optimizer:
 
 ```python
 import pyRDDLGym
````
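The ``import pyRDDLGym`` line above is where the diff cuts the README's tuning snippet off. A minimal sketch of how that snippet typically continues, assuming the 2.x tuning API (``JaxParameterTuning`` and ``Hyperparameter`` from ``pyRDDLGym_jax.core.tuning``) and placeholder domain, ranges, and worker settings:

```python
import pyRDDLGym
from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter

# placeholder environment to tune on
env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)

# read the config template that contains the TUNABLE_* patterns
with open('tuning_slp.cfg', 'r') as file:
    config_template = file.read()

# search both hyper-parameters in log10 space
def power_10(x):
    return 10.0 ** x

# names must match the patterns in the template; ranges are illustrative
hyperparams = [
    Hyperparameter('TUNABLE_WEIGHT', -1.0, 5.0, power_10),
    Hyperparameter('TUNABLE_LEARNING_RATE', -5.0, 1.0, power_10),
]

# Gaussian-process search over the ranges (signature per the 2.x README;
# the 3.0 API may differ)
tuning = JaxParameterTuning(env=env,
                            config_template=config_template,
                            hyperparams=hyperparams,
                            online=False,
                            eval_trials=5,
                            num_workers=4,
                            gp_iters=20)
tuning.tune(key=42, log_file='tuning.csv')
```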
pyrddlgym_jax-3.0.dist-info/RECORD (added)

````diff
@@ -0,0 +1,51 @@
+pyRDDLGym_jax/__init__.py,sha256=_dAhU0a8aQT1JEIj5Ct20spMA83IHwekJD1WC_JjeuA,19
+pyRDDLGym_jax/entry_point.py,sha256=cZzGpt_wTGwamVnN91uAJBW1ReSlB4BpyyMIZfoHnVE,3370
+pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyRDDLGym_jax/core/compiler.py,sha256=Pb4D1mNwNv6zM13VcbU-w-zwX0GfRrbpwANSFKygfnA,95865
+pyRDDLGym_jax/core/logic.py,sha256=syL10om8YD8mF-Jst-6byhXoIxByHfDV5b1lCdjTSeQ,74736
+pyRDDLGym_jax/core/model.py,sha256=vBGkYRNr8ieUf_NUBXUO2G5DC8hqA0Mya0AA_d9nhjg,26926
+pyRDDLGym_jax/core/planner.py,sha256=79ZBd8dB7uOOwpT5L0T96KBB2kcfDOkt45VQopZYxzY,145055
+pyRDDLGym_jax/core/simulator.py,sha256=b4cRBl9XcOJZpOUi5A0OEQ3KbOqfb64qegxJI8V6X_8,11070
+pyRDDLGym_jax/core/tuning.py,sha256=EyEs9TmjNhDnvfgZyOU9y7BsF7WBkDXO5Fbq9ss_Mao,25284
+pyRDDLGym_jax/core/visualization.py,sha256=iBTad49pQ8Eu6zyyGZsjhIumBCmd713FgkkQQsWmlsU,72332
+pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2GmbA1l0kM,175085
+pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
+pyRDDLGym_jax/examples/run_plan.py,sha256=6F6Pbd748_J60sQ7Qkr2wyRXOK7J6I_GSZz4Rszw8ug,2810
+pyRDDLGym_jax/examples/run_scipy.py,sha256=7uVnDXb7D3NTJqA2L8nrcYDJP-k0ba9dl9YqA2CD9ac,2301
+pyRDDLGym_jax/examples/run_tune.py,sha256=Uxapboqn-L7iqzV61oShb8Dh9blrUfv8fF70y07_W8E,4116
+pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=X-8cweQOPrcQl8V3EMP_DFqh2aDM33STMepPnEkqdeA,291
+pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=kzOcAixaZxdPEQ5dC4HEtDhyQ7-0poLdvv4bGdjiY0M,314
+pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=_ycv7u78kbHQ8-qfMAn8Fcc1uK0K5P_-ffSVjkA4Uzw,397
+pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg,sha256=yZ5HcOrJQoY5Dv3Az3rs8zumZ8Z6zKkp5UGgzxT-EMo,336
+pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg,sha256=lbQtjpvPIbND8j8x55nlfpO6q0LpbqDqpIfmVH74urQ,313
+pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg,sha256=bDDcjeXl1UxSqkV24zWrpM0tE6ugKBH9aub9Mfg0DD4,300
+pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg,sha256=oMYDcWS4erqoat_CHggy43le1zqdTka0XpNiQFXrzV4,300
+pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg,sha256=jkGAxvIeC9rhNXJAPABIo5r202x4DQYCD8AQGs4rCeI,325
+pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg,sha256=EWbwZym_WvaftKTkL5v3pzOS4l4Vr_V-29SIR_z4EoI,337
+pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg,sha256=lMLGIa-s4ZBM9_YQxPkBwUOi5fCDyxvYGx5-Mr4I4M8,300
+pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg,sha256=Hpb6_BAnJpZkHbvjX5O-RByrWRdvllO4dV-tHdVAQz0,347
+pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg,sha256=Hpb6_BAnJpZkHbvjX5O-RByrWRdvllO4dV-tHdVAQz0,347
+pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg,sha256=TEZjBW_h8_wLwzq5jbHO-EDSPteve1GlCpBapYaE_OI,300
+pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg,sha256=TEZjBW_h8_wLwzq5jbHO-EDSPteve1GlCpBapYaE_OI,300
+pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg,sha256=A0TC8BdebQjtwMs2ds4PZFpd_37aB56nDq4bcAgYd0k,304
+pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg,sha256=5iZNO2LEdM97da9DDi4tiDMRrlZdD5UYCbpU0bPh2AU,317
+pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg,sha256=RhJCVN3z4182nV2nQCUhNKeWJvlILFf3Gs3Bu1r1F3I,279
+pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg,sha256=tdB6GB2HTIa5sJ0zzRDLYiSqJA2QW1gRbbIFvzY29D8,300
+pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg,sha256=SRKLX43BJpDZZJ3rc-ALrXZcCD6ZmLtBnUIhQl7SmB4,354
+pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg,sha256=hMkY0yboNA5g3o7SGLevCmiRmWJXWVKtE46MXw05IhA,423
+pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg,sha256=_7ZQL4NhjFKNUeBqf_Ms45sDqxjKBR5R866dnLoscC0,385
+pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=xpQZwBPhlUPCtwsDIoGOr-jYo05Da8ZBbc4vOOp8FcU,298
+pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=43eEofpuFIs6ZPvFjwYv_qxlIzvVfuSCMB8h79RyY1s,332
+pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=CZPEx9eOIN4QcQFp9kP14JFfRpFDn1yBt75iSdE7g1I,294
+pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=fPGNIES81yD_V0_YA6SYbLm7MEqYLmSfk5hgn7ZsYCM,440
+pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=ZWFeohYYEY8a4-VqWd8o7eLvPwGZpDcmYmG2qIpzZtg,437
+pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=9GxnahqQmUkaIIZ6yvM_qFRjANOcsBLv9BFj95XFIKw,400
+pyrddlgym_jax-3.0.dist-info/licenses/LICENSE,sha256=2a-BZEY7aEZW-DkmmOQsuUDU0pc6ovQy3QnYFZ4baq4,1095
+pyrddlgym_jax-3.0.dist-info/METADATA,sha256=tABhmjuI3NLSOt2tHiJ0cMZ_mkgAV15BYHsoSa-pECs,16711
+pyrddlgym_jax-3.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+pyrddlgym_jax-3.0.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
+pyrddlgym_jax-3.0.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+pyrddlgym_jax-3.0.dist-info/RECORD,,
````
pyRDDLGym_jax/examples/run_gradient.py (removed)

````diff
@@ -1,102 +0,0 @@
-'''In this simple example, gradient of the return for a simple MDP is computed.
-
-Setting:
-    The policy is linear in the state:
-        action = p * state
-
-    The state evolves as state' = state + action + 1, so for 3-step problem:
-        s0 = 0
-        s1 = s0 + p * s0 + 1 = 1
-        s2 = s1 + p * s1 + 1 = 2 + p
-        s3 = s2 + p * s2 + 1 = (1 + p) * (2 + p) + 1 = 3 + 3 * p + p ^ 2
-
-    The total return is:
-        return = 1 + 2 + p + 3 + 3 * p + p ^ 2
-               = 6 + 4 * p + p ^ 2
-
-    The gradient of the return is:
-        gradient = 4 + 2 * p
-
-    The example given uses p = 2, so it should be:
-        return = 18, gradient = 8
-
-    For p = 3, it should be:
-        return = 27, gradient = 10
-'''
-
-import os
-import sys
-import jax
-
-import pyRDDLGym
-
-from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
-
-# a simple domain with state' = state + action
-DOMAIN = """
-domain test {
-    pvariables {
-        nf : { non-fluent, real, default = 1.0 };
-        state : { state-fluent, real, default = 0.0 };
-        action : { action-fluent, real, default = 0.0 };
-    };
-    cpfs {
-        state' = state + action + nf;
-    };
-    reward = state';
-}
-"""
-
-INSTANCE = """
-non-fluents test_nf {
-    domain = test;
-}
-instance inst_test {
-    domain = test;
-    non-fluents = test_nf;
-    max-nondef-actions = pos-inf;
-    horizon = 5;
-    discount = 1.0;
-}
-"""
-
-
-def main():
-
-    # create the environment
-    abs_path = os.path.dirname(os.path.abspath(__file__))
-    with open(os.path.join(abs_path, 'domain.rddl'), 'w') as dom_file:
-        dom_file.write(DOMAIN)
-    with open(os.path.join(abs_path, 'instance.rddl'), 'w') as inst_file:
-        inst_file.write(INSTANCE)
-
-    env = pyRDDLGym.make(os.path.join(abs_path, 'domain.rddl'),
-                         os.path.join(abs_path, 'instance.rddl'))
-
-    # policy is slope * state
-    def policy(key, policy_params, hyperparams, step, states):
-        return {'action': policy_params['slope'] * states['state']}
-
-    # compile the return objective
-    compiler = JaxRDDLCompilerWithGrad(env.model)
-    compiler.compile()
-    step_fn = compiler.compile_rollouts(policy, 3, 1)
-
-    def sum_of_rewards(*args):
-        return jax.numpy.sum(step_fn(*args)['reward'])
-
-    # prepare the arguments (note that batching requires new axis at index 0)
-    subs = {k: v[None, ...] for (k, v) in compiler.init_values.items()}
-    params = {'slope': 2.0}
-    my_args = [jax.random.PRNGKey(42), params, None, subs, compiler.model_params]
-
-    # print the fluents over the trajectory, return and gradient
-    print(step_fn(*my_args)['fluents'])
-    print(sum_of_rewards(*my_args))
-    print(jax.grad(sum_of_rewards, argnums=1)(*my_args))
-
-    env.close()
-
-
-if __name__ == "__main__":
-    main()
````
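Although ``run_gradient.py`` is removed in this release, the arithmetic in its docstring is easy to confirm with plain JAX, with no pyRDDLGym-jax API involved; a self-contained check:

```python
import jax

# return of the 3-step MDP from the removed docstring:
# state' = state + p * state + 1, reward = state', s0 = 0
def total_return(p):
    s, ret = 0.0, 0.0
    for _ in range(3):
        s = s + p * s + 1.0  # apply the linear policy action = p * state
        ret = ret + s        # accumulate reward = state'
    return ret               # closed form: 6 + 4*p + p**2

print(total_return(2.0))            # 18.0
print(jax.grad(total_return)(2.0))  # 8.0 = 4 + 2*p at p = 2
print(total_return(3.0))            # 27.0
print(jax.grad(total_return)(3.0))  # 10.0
```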
pyrddlgym_jax-2.7.dist-info/RECORD (removed)

````diff
@@ -1,50 +0,0 @@
-pyRDDLGym_jax/__init__.py,sha256=nHQztRWlKCpxZgvKkxsGQax5-clS2XguHhAvmBZt0sA,19
-pyRDDLGym_jax/entry_point.py,sha256=K0zy1oe66jfBHkHHCM6aGHbbiVqnQvDhDb8se4uaKHE,3319
-pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/core/compiler.py,sha256=DS4G5f5U83cOUQsUe6RsyyJnLPDuHaqjxM7bHSWMCtM,88040
-pyRDDLGym_jax/core/logic.py,sha256=9rRpKJCx4Us_2c6BiSWRN9k2sM_iYsAK1B7zcgwu3ZA,56290
-pyRDDLGym_jax/core/model.py,sha256=4WfmtUVN1EKCD-7eWeQByWk8_zKyDcMABAMdlxN1LOU,27215
-pyRDDLGym_jax/core/planner.py,sha256=cvl3JS1tLQqj8KJ5ATkHUfIzCzcYJWOCoWJYwLxMDSg,146835
-pyRDDLGym_jax/core/simulator.py,sha256=D-yLxDFw67DvFHdb_kJjZHujSBSmiFA1J3osel-KOvY,10799
-pyRDDLGym_jax/core/tuning.py,sha256=BWcQZk02TMLexTz1Sw4lX2EQKvmPbp7biC51M-IiNUw,25153
-pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
-pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2GmbA1l0kM,175085
-pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
-pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
-pyRDDLGym_jax/examples/run_plan.py,sha256=uScTTUSdwohhaqvmSf9zvOjQn4xZ97qU1xYezZTIIHg,3745
-pyRDDLGym_jax/examples/run_scipy.py,sha256=7uVnDXb7D3NTJqA2L8nrcYDJP-k0ba9dl9YqA2CD9ac,2301
-pyRDDLGym_jax/examples/run_tune.py,sha256=F5KWgtoCPbf7XHB6HW9LjxarD57U2LvuGdTz67OL1DY,4114
-pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
-pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=nFFYHCKQUMn8x-OpJwu2pwe1tycNSJ8iAIwSkCBn33E,370
-pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=eJ3HvHjODoKdtX7u-AM51xQaHJnYgzEy2t3omNG2oCs,340
-pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg,sha256=9-QMZPZuecAEaerD79ZAbGX-tgfL8Y2W-tfkAyD15Cw,362
-pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg,sha256=BiY6wwSYkR9-T46AA4n3okJ1Qvj8Iu-y1V5BrfCbqrM,340
-pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg,sha256=VBlTiHFQG72D1wpebMsuzSokwqlPVD99WjPp4YoWs84,356
-pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg,sha256=bH_5O13-Y6ztvN_qranlomsmjdj_8CsaA0Bg0hA-FaQ,356
-pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg,sha256=Pq6E9RYksue7X2cWjdWyUsV0LqQTjTvq6p0aLBVKWfY,370
-pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg,sha256=SGVQAOqrOjEsZEtxL_Z6aGbLR19h5gKCcy0oz2vtQp8,382
-pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg,sha256=6obQik2FBldoJ3VwoVfGhQqKpKdnYox770cF-SGRi3Q,345
-pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg,sha256=rs-CzOAyZV_NvwSh2f6Fm9XNw5Z8WIYgpAOzgTm_Gv8,403
-pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg,sha256=EtSCTjd8gWm7akQdfHFxdpGnQvHzjo2IHbAuVxTAX4U,356
-pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg,sha256=7nPOJCo3eaZuq1pCyIJJJkDM0jjJThDuDECJDZzX-uc,379
-pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg,sha256=V3jzPGuNq2IAxYy_EeZWin4Y_uf0HvGhzg06ODNSY-I,381
-pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg,sha256=SYAJmoUIUhhvAej3XOzC5boGxKVHnSiVi5-ZGj2S29M,354
-pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg,sha256=osoIPfrldPw7oJF2AaAw0-ke6YHQNdrslFBCTytsqmo,354
-pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg,sha256=oNX8uW8Bw2uG9zHX1zeLF3mHWDHRIlJXYvbFcY0pfCI,382
-pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg,sha256=exCfGI3WU7IFO7n5rRe5cO1ZHAdFwttRYzjIdD4Pz2Y,451
-pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg,sha256=e6Ikgv2uBbKuXHfVKt4KQ01LDUBGbc31D28bCcztJ58,413
-pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=XeMWAAG_OFZo7JAMxS5-XXroZaeVMzfM0NswmEobIns,373
-pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=CK4cEz8ReXyAZPLaLG9clIIRXAqM3IplUCxbLt_V2lY,407
-pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qGIJDIw73XCe6pyIPtg,369
-pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=zocZn_cVarH5i0hOlt2Zu0NwmXYBmTTghLaXLtQOGto,526
-pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=9oIhtw9cuikmlbDgCgbrTc5G7hUio-HeAv_3CEGVclY,523
-pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=QqnyR__5-HhKeCDfGDel8VIlqsjxRHk4SSH089zJP8s,486
-pyrddlgym_jax-2.7.dist-info/licenses/LICENSE,sha256=2a-BZEY7aEZW-DkmmOQsuUDU0pc6ovQy3QnYFZ4baq4,1095
-pyrddlgym_jax-2.7.dist-info/METADATA,sha256=xN_SB6x-qiC9cj8O0VvF9HIEDpK79i7FQgn8D3og2xQ,16770
-pyrddlgym_jax-2.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-pyrddlgym_jax-2.7.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
-pyrddlgym_jax-2.7.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
-pyrddlgym_jax-2.7.dist-info/RECORD,,
````