pyRDDLGym-jax 2.3__py3-none-any.whl → 2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +2 -3
- pyRDDLGym_jax/core/logic.py +117 -66
- pyRDDLGym_jax/core/planner.py +489 -218
- pyRDDLGym_jax/core/tuning.py +28 -22
- pyRDDLGym_jax/examples/run_plan.py +2 -2
- pyRDDLGym_jax/examples/run_scipy.py +2 -2
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.4.dist-info}/METADATA +1 -1
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.4.dist-info}/RECORD +13 -13
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.4.dist-info}/LICENSE +0 -0
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.4.dist-info}/WHEEL +0 -0
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.4.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.4.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '2.3'
|
|
1
|
+
__version__ = '2.4'
|
pyRDDLGym_jax/core/compiler.py
CHANGED
|
@@ -471,8 +471,7 @@ class JaxRDDLCompiler:
|
|
|
471
471
|
return printed
|
|
472
472
|
|
|
473
473
|
def model_parameter_info(self) -> Dict[str, Dict[str, Any]]:
|
|
474
|
-
'''Returns a dictionary of additional information about model
|
|
475
|
-
parameters.'''
|
|
474
|
+
'''Returns a dictionary of additional information about model parameters.'''
|
|
476
475
|
result = {}
|
|
477
476
|
for (id, value) in self.model_params.items():
|
|
478
477
|
expr_id = int(str(id).split('_')[0])
|
|
@@ -799,7 +798,7 @@ class JaxRDDLCompiler:
|
|
|
799
798
|
elif n == 2 or (n >= 2 and op in {'*', '+'}):
|
|
800
799
|
jax_exprs = [self._jax(arg, init_params) for arg in args]
|
|
801
800
|
result = jax_exprs[0]
|
|
802
|
-
for i, jax_rhs in enumerate(jax_exprs[1:]):
|
|
801
|
+
for (i, jax_rhs) in enumerate(jax_exprs[1:]):
|
|
803
802
|
jax_op = valid_ops[op](f'{expr.id}_{op}{i}', init_params)
|
|
804
803
|
result = self._jax_binary(result, jax_rhs, jax_op, at_least_int=True)
|
|
805
804
|
return result
|
pyRDDLGym_jax/core/logic.py
CHANGED
|
@@ -29,6 +29,8 @@
|
|
|
29
29
|
#
|
|
30
30
|
# ***********************************************************************
|
|
31
31
|
|
|
32
|
+
|
|
33
|
+
from abc import ABCMeta, abstractmethod
|
|
32
34
|
import traceback
|
|
33
35
|
from typing import Callable, Dict, Tuple, Union
|
|
34
36
|
|
|
@@ -64,30 +66,35 @@ def enumerate_literals(shape: Tuple[int, ...], axis: int, dtype: type=jnp.int32)
|
|
|
64
66
|
#
|
|
65
67
|
# ===========================================================================
|
|
66
68
|
|
|
67
|
-
class Comparison:
|
|
69
|
+
class Comparison(metaclass=ABCMeta):
|
|
68
70
|
'''Base class for approximate comparison operations.'''
|
|
69
71
|
|
|
72
|
+
@abstractmethod
|
|
70
73
|
def greater_equal(self, id, init_params):
|
|
71
|
-
|
|
74
|
+
pass
|
|
72
75
|
|
|
76
|
+
@abstractmethod
|
|
73
77
|
def greater(self, id, init_params):
|
|
74
|
-
|
|
78
|
+
pass
|
|
75
79
|
|
|
80
|
+
@abstractmethod
|
|
76
81
|
def equal(self, id, init_params):
|
|
77
|
-
|
|
82
|
+
pass
|
|
78
83
|
|
|
84
|
+
@abstractmethod
|
|
79
85
|
def sgn(self, id, init_params):
|
|
80
|
-
|
|
86
|
+
pass
|
|
81
87
|
|
|
88
|
+
@abstractmethod
|
|
82
89
|
def argmax(self, id, init_params):
|
|
83
|
-
|
|
90
|
+
pass
|
|
84
91
|
|
|
85
92
|
|
|
86
93
|
class SigmoidComparison(Comparison):
|
|
87
94
|
'''Comparison operations approximated using sigmoid functions.'''
|
|
88
95
|
|
|
89
96
|
def __init__(self, weight: float=10.0) -> None:
|
|
90
|
-
self.weight = weight
|
|
97
|
+
self.weight = float(weight)
|
|
91
98
|
|
|
92
99
|
# https://arxiv.org/abs/2110.05651
|
|
93
100
|
def greater_equal(self, id, init_params):
|
|
@@ -139,21 +146,23 @@ class SigmoidComparison(Comparison):
|
|
|
139
146
|
#
|
|
140
147
|
# ===========================================================================
|
|
141
148
|
|
|
142
|
-
class Rounding:
|
|
149
|
+
class Rounding(metaclass=ABCMeta):
|
|
143
150
|
'''Base class for approximate rounding operations.'''
|
|
144
151
|
|
|
152
|
+
@abstractmethod
|
|
145
153
|
def floor(self, id, init_params):
|
|
146
|
-
|
|
154
|
+
pass
|
|
147
155
|
|
|
156
|
+
@abstractmethod
|
|
148
157
|
def round(self, id, init_params):
|
|
149
|
-
|
|
158
|
+
pass
|
|
150
159
|
|
|
151
160
|
|
|
152
161
|
class SoftRounding(Rounding):
|
|
153
162
|
'''Rounding operations approximated using soft operations.'''
|
|
154
163
|
|
|
155
164
|
def __init__(self, weight: float=10.0) -> None:
|
|
156
|
-
self.weight = weight
|
|
165
|
+
self.weight = float(weight)
|
|
157
166
|
|
|
158
167
|
# https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
|
|
159
168
|
def floor(self, id, init_params):
|
|
@@ -189,11 +198,12 @@ class SoftRounding(Rounding):
|
|
|
189
198
|
#
|
|
190
199
|
# ===========================================================================
|
|
191
200
|
|
|
192
|
-
class Complement:
|
|
201
|
+
class Complement(metaclass=ABCMeta):
|
|
193
202
|
'''Base class for approximate logical complement operations.'''
|
|
194
203
|
|
|
204
|
+
@abstractmethod
|
|
195
205
|
def __call__(self, id, init_params):
|
|
196
|
-
|
|
206
|
+
pass
|
|
197
207
|
|
|
198
208
|
|
|
199
209
|
class StandardComplement(Complement):
|
|
@@ -222,16 +232,18 @@ class StandardComplement(Complement):
|
|
|
222
232
|
# https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
|
|
223
233
|
# ===========================================================================
|
|
224
234
|
|
|
225
|
-
class TNorm:
|
|
235
|
+
class TNorm(metaclass=ABCMeta):
|
|
226
236
|
'''Base class for fuzzy differentiable t-norms.'''
|
|
227
237
|
|
|
238
|
+
@abstractmethod
|
|
228
239
|
def norm(self, id, init_params):
|
|
229
240
|
'''Elementwise t-norm of x and y.'''
|
|
230
|
-
|
|
241
|
+
pass
|
|
231
242
|
|
|
243
|
+
@abstractmethod
|
|
232
244
|
def norms(self, id, init_params):
|
|
233
245
|
'''T-norm computed for tensor x along axis.'''
|
|
234
|
-
|
|
246
|
+
pass
|
|
235
247
|
|
|
236
248
|
|
|
237
249
|
class ProductTNorm(TNorm):
|
|
@@ -339,26 +351,32 @@ class YagerTNorm(TNorm):
|
|
|
339
351
|
#
|
|
340
352
|
# ===========================================================================
|
|
341
353
|
|
|
342
|
-
class RandomSampling:
|
|
354
|
+
class RandomSampling(metaclass=ABCMeta):
|
|
343
355
|
'''Describes how non-reparameterizable random variables are sampled.'''
|
|
344
356
|
|
|
357
|
+
@abstractmethod
|
|
345
358
|
def discrete(self, id, init_params, logic):
|
|
346
|
-
|
|
359
|
+
pass
|
|
347
360
|
|
|
361
|
+
@abstractmethod
|
|
348
362
|
def poisson(self, id, init_params, logic):
|
|
349
|
-
|
|
363
|
+
pass
|
|
350
364
|
|
|
365
|
+
@abstractmethod
|
|
351
366
|
def binomial(self, id, init_params, logic):
|
|
352
|
-
|
|
367
|
+
pass
|
|
353
368
|
|
|
369
|
+
@abstractmethod
|
|
354
370
|
def negative_binomial(self, id, init_params, logic):
|
|
355
|
-
|
|
371
|
+
pass
|
|
356
372
|
|
|
373
|
+
@abstractmethod
|
|
357
374
|
def geometric(self, id, init_params, logic):
|
|
358
|
-
|
|
375
|
+
pass
|
|
359
376
|
|
|
377
|
+
@abstractmethod
|
|
360
378
|
def bernoulli(self, id, init_params, logic):
|
|
361
|
-
|
|
379
|
+
pass
|
|
362
380
|
|
|
363
381
|
def __str__(self) -> str:
|
|
364
382
|
return 'RandomSampling'
|
|
@@ -603,21 +621,23 @@ class Determinization(RandomSampling):
|
|
|
603
621
|
#
|
|
604
622
|
# ===========================================================================
|
|
605
623
|
|
|
606
|
-
class ControlFlow:
|
|
624
|
+
class ControlFlow(metaclass=ABCMeta):
|
|
607
625
|
'''A base class for control flow, including if and switch statements.'''
|
|
608
626
|
|
|
627
|
+
@abstractmethod
|
|
609
628
|
def if_then_else(self, id, init_params):
|
|
610
|
-
|
|
629
|
+
pass
|
|
611
630
|
|
|
631
|
+
@abstractmethod
|
|
612
632
|
def switch(self, id, init_params):
|
|
613
|
-
|
|
633
|
+
pass
|
|
614
634
|
|
|
615
635
|
|
|
616
636
|
class SoftControlFlow(ControlFlow):
|
|
617
637
|
'''Soft control flow using a probabilistic interpretation.'''
|
|
618
638
|
|
|
619
639
|
def __init__(self, weight: float=10.0) -> None:
|
|
620
|
-
self.weight = weight
|
|
640
|
+
self.weight = float(weight)
|
|
621
641
|
|
|
622
642
|
@staticmethod
|
|
623
643
|
def _jax_wrapped_calc_if_then_else_soft(c, a, b, params):
|
|
@@ -651,15 +671,15 @@ class SoftControlFlow(ControlFlow):
|
|
|
651
671
|
# ===========================================================================
|
|
652
672
|
|
|
653
673
|
|
|
654
|
-
class Logic:
|
|
674
|
+
class Logic(metaclass=ABCMeta):
|
|
655
675
|
'''A base class for representing logic computations in JAX.'''
|
|
656
676
|
|
|
657
677
|
def __init__(self, use64bit: bool=False) -> None:
|
|
658
678
|
self.set_use64bit(use64bit)
|
|
659
679
|
|
|
660
|
-
def summarize_hyperparameters(self) ->
|
|
661
|
-
|
|
662
|
-
|
|
680
|
+
def summarize_hyperparameters(self) -> str:
|
|
681
|
+
return (f'model relaxation:\n'
|
|
682
|
+
f' use_64_bit ={self.use64bit}')
|
|
663
683
|
|
|
664
684
|
def set_use64bit(self, use64bit: bool) -> None:
|
|
665
685
|
'''Toggles whether or not the JAX system will use 64 bit precision.'''
|
|
@@ -765,119 +785,150 @@ class Logic:
|
|
|
765
785
|
# ===========================================================================
|
|
766
786
|
# logical operators
|
|
767
787
|
# ===========================================================================
|
|
768
|
-
|
|
788
|
+
|
|
789
|
+
@abstractmethod
|
|
769
790
|
def logical_and(self, id, init_params):
|
|
770
|
-
|
|
791
|
+
pass
|
|
771
792
|
|
|
793
|
+
@abstractmethod
|
|
772
794
|
def logical_not(self, id, init_params):
|
|
773
|
-
|
|
795
|
+
pass
|
|
774
796
|
|
|
797
|
+
@abstractmethod
|
|
775
798
|
def logical_or(self, id, init_params):
|
|
776
|
-
|
|
799
|
+
pass
|
|
777
800
|
|
|
801
|
+
@abstractmethod
|
|
778
802
|
def xor(self, id, init_params):
|
|
779
|
-
|
|
803
|
+
pass
|
|
780
804
|
|
|
805
|
+
@abstractmethod
|
|
781
806
|
def implies(self, id, init_params):
|
|
782
|
-
|
|
807
|
+
pass
|
|
783
808
|
|
|
809
|
+
@abstractmethod
|
|
784
810
|
def equiv(self, id, init_params):
|
|
785
|
-
|
|
811
|
+
pass
|
|
786
812
|
|
|
813
|
+
@abstractmethod
|
|
787
814
|
def forall(self, id, init_params):
|
|
788
|
-
|
|
815
|
+
pass
|
|
789
816
|
|
|
817
|
+
@abstractmethod
|
|
790
818
|
def exists(self, id, init_params):
|
|
791
|
-
|
|
819
|
+
pass
|
|
792
820
|
|
|
793
821
|
# ===========================================================================
|
|
794
822
|
# comparison operators
|
|
795
823
|
# ===========================================================================
|
|
796
824
|
|
|
825
|
+
@abstractmethod
|
|
797
826
|
def greater_equal(self, id, init_params):
|
|
798
|
-
|
|
827
|
+
pass
|
|
799
828
|
|
|
829
|
+
@abstractmethod
|
|
800
830
|
def greater(self, id, init_params):
|
|
801
|
-
|
|
831
|
+
pass
|
|
802
832
|
|
|
833
|
+
@abstractmethod
|
|
803
834
|
def less_equal(self, id, init_params):
|
|
804
|
-
|
|
835
|
+
pass
|
|
805
836
|
|
|
837
|
+
@abstractmethod
|
|
806
838
|
def less(self, id, init_params):
|
|
807
|
-
|
|
839
|
+
pass
|
|
808
840
|
|
|
841
|
+
@abstractmethod
|
|
809
842
|
def equal(self, id, init_params):
|
|
810
|
-
|
|
843
|
+
pass
|
|
811
844
|
|
|
845
|
+
@abstractmethod
|
|
812
846
|
def not_equal(self, id, init_params):
|
|
813
|
-
|
|
847
|
+
pass
|
|
814
848
|
|
|
815
849
|
# ===========================================================================
|
|
816
850
|
# special functions
|
|
817
851
|
# ===========================================================================
|
|
818
852
|
|
|
853
|
+
@abstractmethod
|
|
819
854
|
def sgn(self, id, init_params):
|
|
820
|
-
|
|
855
|
+
pass
|
|
821
856
|
|
|
857
|
+
@abstractmethod
|
|
822
858
|
def floor(self, id, init_params):
|
|
823
|
-
|
|
859
|
+
pass
|
|
824
860
|
|
|
861
|
+
@abstractmethod
|
|
825
862
|
def round(self, id, init_params):
|
|
826
|
-
|
|
863
|
+
pass
|
|
827
864
|
|
|
865
|
+
@abstractmethod
|
|
828
866
|
def ceil(self, id, init_params):
|
|
829
|
-
|
|
867
|
+
pass
|
|
830
868
|
|
|
869
|
+
@abstractmethod
|
|
831
870
|
def div(self, id, init_params):
|
|
832
|
-
|
|
871
|
+
pass
|
|
833
872
|
|
|
873
|
+
@abstractmethod
|
|
834
874
|
def mod(self, id, init_params):
|
|
835
|
-
|
|
875
|
+
pass
|
|
836
876
|
|
|
877
|
+
@abstractmethod
|
|
837
878
|
def sqrt(self, id, init_params):
|
|
838
|
-
|
|
879
|
+
pass
|
|
839
880
|
|
|
840
881
|
# ===========================================================================
|
|
841
882
|
# indexing
|
|
842
883
|
# ===========================================================================
|
|
843
|
-
|
|
884
|
+
|
|
885
|
+
@abstractmethod
|
|
844
886
|
def argmax(self, id, init_params):
|
|
845
|
-
|
|
887
|
+
pass
|
|
846
888
|
|
|
889
|
+
@abstractmethod
|
|
847
890
|
def argmin(self, id, init_params):
|
|
848
|
-
|
|
891
|
+
pass
|
|
849
892
|
|
|
850
893
|
# ===========================================================================
|
|
851
894
|
# control flow
|
|
852
895
|
# ===========================================================================
|
|
853
896
|
|
|
897
|
+
@abstractmethod
|
|
854
898
|
def control_if(self, id, init_params):
|
|
855
|
-
|
|
899
|
+
pass
|
|
856
900
|
|
|
901
|
+
@abstractmethod
|
|
857
902
|
def control_switch(self, id, init_params):
|
|
858
|
-
|
|
903
|
+
pass
|
|
859
904
|
|
|
860
905
|
# ===========================================================================
|
|
861
906
|
# random variables
|
|
862
907
|
# ===========================================================================
|
|
863
908
|
|
|
909
|
+
@abstractmethod
|
|
864
910
|
def discrete(self, id, init_params):
|
|
865
|
-
|
|
911
|
+
pass
|
|
866
912
|
|
|
913
|
+
@abstractmethod
|
|
867
914
|
def bernoulli(self, id, init_params):
|
|
868
|
-
|
|
915
|
+
pass
|
|
869
916
|
|
|
917
|
+
@abstractmethod
|
|
870
918
|
def poisson(self, id, init_params):
|
|
871
|
-
|
|
919
|
+
pass
|
|
872
920
|
|
|
921
|
+
@abstractmethod
|
|
873
922
|
def geometric(self, id, init_params):
|
|
874
|
-
|
|
923
|
+
pass
|
|
875
924
|
|
|
925
|
+
@abstractmethod
|
|
876
926
|
def binomial(self, id, init_params):
|
|
877
|
-
|
|
927
|
+
pass
|
|
878
928
|
|
|
929
|
+
@abstractmethod
|
|
879
930
|
def negative_binomial(self, id, init_params):
|
|
880
|
-
|
|
931
|
+
pass
|
|
881
932
|
|
|
882
933
|
|
|
883
934
|
class ExactLogic(Logic):
|
|
@@ -1109,8 +1160,8 @@ class FuzzyLogic(Logic):
|
|
|
1109
1160
|
f' underflow_tol={self.eps}\n'
|
|
1110
1161
|
f' use_64_bit ={self.use64bit}\n')
|
|
1111
1162
|
|
|
1112
|
-
def summarize_hyperparameters(self) ->
|
|
1113
|
-
|
|
1163
|
+
def summarize_hyperparameters(self) -> str:
|
|
1164
|
+
return self.__str__()
|
|
1114
1165
|
|
|
1115
1166
|
# ===========================================================================
|
|
1116
1167
|
# logical operators
|