pyRDDLGym-jax 2.3.tar.gz → 2.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/PKG-INFO +13 -18
  2. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/README.md +10 -16
  3. pyrddlgym_jax-2.5/pyRDDLGym_jax/__init__.py +1 -0
  4. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/compiler.py +10 -7
  5. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/logic.py +117 -66
  6. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/planner.py +585 -248
  7. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/simulator.py +37 -13
  8. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/tuning.py +52 -31
  9. pyrddlgym_jax-2.5/pyRDDLGym_jax/entry_point.py +59 -0
  10. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +1 -0
  11. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +1 -0
  12. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +1 -0
  13. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_plan.py +3 -3
  14. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_scipy.py +2 -2
  15. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_tune.py +8 -2
  16. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/PKG-INFO +13 -18
  17. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/setup.py +1 -1
  18. pyrddlgym_jax-2.3/pyRDDLGym_jax/__init__.py +0 -1
  19. pyrddlgym_jax-2.3/pyRDDLGym_jax/entry_point.py +0 -27
  20. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/LICENSE +0 -0
  21. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/__init__.py +0 -0
  22. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
  23. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
  24. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/visualization.py +0 -0
  25. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/__init__.py +0 -0
  26. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
  27. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
  28. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
  29. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
  30. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
  31. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
  32. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
  33. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
  34. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
  35. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
  36. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
  37. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
  38. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
  39. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
  40. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
  41. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
  42. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
  43. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
  44. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
  45. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  46. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
  47. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
  48. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
  49. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  50. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  51. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
  52. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  53. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
  54. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
  55. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  56. {pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/setup.cfg +0 -0
{pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/PKG-INFO

@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: pyRDDLGym-jax
- Version: 2.3
+ Version: 2.5
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner

@@ -39,6 +39,7 @@ Dynamic: description
  Dynamic: description-content-type
  Dynamic: home-page
  Dynamic: license
+ Dynamic: license-file
  Dynamic: provides-extra
  Dynamic: requires-dist
  Dynamic: requires-python

@@ -116,7 +117,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
  A basic run script is provided to train JaxPlan on any RDDL problem:

  ```shell
- jaxplan plan <domain> <instance> <method> <episodes>
+ jaxplan plan <domain> <instance> <method> --episodes <episodes>
  ```

  where:

@@ -241,7 +242,7 @@ More documentation about this and other new features will be coming soon.
  A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:

  ```shell
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+ jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
  ```

  where:

@@ -251,7 +252,8 @@ where:
  - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
  - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
  - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application
+ - ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.

  It is easy to tune a custom range of the planner's hyper-parameters efficiently.
  First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:

@@ -291,23 +293,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
  with open('path/to/config.cfg', 'r') as file:
      config_template = file.read()

- # map parameters in the config that will be tuned
+ # tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
  def power_10(x):
-     return 10.0 ** x
-
- hyperparams = [
-     Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),  # tune weight from 10^-1 ... 10^5
-     Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
- ]
+     return 10.0 ** x
+ hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
+                Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]

  # build the tuner and tune
  tuning = JaxParameterTuning(env=env,
-                             config_template=config_template,
-                             hyperparams=hyperparams,
-                             online=False,
-                             eval_trials=trials,
-                             num_workers=workers,
-                             gp_iters=iters)
+                             config_template=config_template, hyperparams=hyperparams,
+                             online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
  tuning.tune(key=42, log_file='path/to/log.csv')
  ```
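The hyper-parameter ranges in the tuning example above are specified on a log scale: the optimizer searches a raw interval and `power_10` maps each raw value to the setting the planner actually receives. A minimal sketch of that mapping, using only the bounds shown in the diff (plain Python, no pyRDDLGym-jax imports needed):

```python
# Log-scale mapping from the tuning example: the Bayesian optimizer samples a
# raw value x in [low, high] and the planner receives 10 ** x.
def power_10(x):
    return 10.0 ** x

# TUNABLE_WEIGHT is searched over [-1, 5], i.e. weights from 1e-1 to 1e5
print(power_10(-1.0), power_10(5.0))   # 0.1 100000.0
# TUNABLE_LEARNING_RATE is searched over [-5, 1], i.e. rates from 1e-5 to 1e1
print(power_10(-5.0), power_10(1.0))   # 1e-05 10.0
```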
 
{pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/README.md

@@ -70,7 +70,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
  A basic run script is provided to train JaxPlan on any RDDL problem:

  ```shell
- jaxplan plan <domain> <instance> <method> <episodes>
+ jaxplan plan <domain> <instance> <method> --episodes <episodes>
  ```

  where:

@@ -195,7 +195,7 @@ More documentation about this and other new features will be coming soon.
  A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:

  ```shell
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+ jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
  ```

  where:

@@ -205,7 +205,8 @@ where:
  - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
  - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
  - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application
+ - ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.

  It is easy to tune a custom range of the planner's hyper-parameters efficiently.
  First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:

@@ -245,23 +246,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
  with open('path/to/config.cfg', 'r') as file:
      config_template = file.read()

- # map parameters in the config that will be tuned
+ # tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
  def power_10(x):
-     return 10.0 ** x
-
- hyperparams = [
-     Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),  # tune weight from 10^-1 ... 10^5
-     Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
- ]
+     return 10.0 ** x
+ hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
+                Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]

  # build the tuner and tune
  tuning = JaxParameterTuning(env=env,
-                             config_template=config_template,
-                             hyperparams=hyperparams,
-                             online=False,
-                             eval_trials=trials,
-                             num_workers=workers,
-                             gp_iters=iters)
+                             config_template=config_template, hyperparams=hyperparams,
+                             online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
  tuning.tune(key=42, log_file='path/to/log.csv')
  ```
 
pyrddlgym_jax-2.5/pyRDDLGym_jax/__init__.py

@@ -0,0 +1 @@
+ __version__ = '2.5'
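With this new top-level `__init__.py`, the installed version can be read directly from the package. A trivial check, assuming the package is installed:

```python
import pyRDDLGym_jax

# the version string added in this release's top-level __init__.py
print(pyRDDLGym_jax.__version__)   # '2.5'
```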
{pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/compiler.py

@@ -430,7 +430,7 @@ class JaxRDDLCompiler:
          _jax_wrapped_single_step_policy,
          in_axes=(0, None, None, None, 0, None)
      )(keys, policy_params, hyperparams, step, subs, model_params)
-     model_params = jax.tree_map(partial(jnp.mean, axis=0), model_params)
+     model_params = jax.tree_util.tree_map(partial(jnp.mean, axis=0), model_params)
      carry = (key, policy_params, hyperparams, subs, model_params)
      return carry, log

@@ -440,7 +440,7 @@ class JaxRDDLCompiler:
      start = (key, policy_params, hyperparams, subs, model_params)
      steps = jnp.arange(n_steps)
      end, log = jax.lax.scan(_jax_wrapped_batched_step_policy, start, steps)
-     log = jax.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
+     log = jax.tree_util.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
      model_params = end[-1]
      return log, model_params
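Both hunks above are the same migration: the `jax.tree_map` alias was deprecated in recent JAX releases (and later removed) in favor of `jax.tree_util.tree_map`, which applies a function to every leaf of a pytree. A minimal sketch of the call pattern, using a toy pytree rather than the compiler's actual state:

```python
from functools import partial

import jax
import jax.numpy as jnp

# apply a per-leaf reduction across a pytree, as the compiler does for model_params
params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((3,))}
means = jax.tree_util.tree_map(partial(jnp.mean, axis=0), params)
print(means)   # {'b': Array(0., ...), 'w': Array([1., 1.], ...)}
```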
 
@@ -471,8 +471,7 @@ class JaxRDDLCompiler:
          return printed

      def model_parameter_info(self) -> Dict[str, Dict[str, Any]]:
-         '''Returns a dictionary of additional information about model
-         parameters.'''
+         '''Returns a dictionary of additional information about model parameters.'''
          result = {}
          for (id, value) in self.model_params.items():
              expr_id = int(str(id).split('_')[0])

@@ -708,7 +707,10 @@ class JaxRDDLCompiler:
      sample = jnp.asarray(value, dtype=self._fix_dtype(value))
      new_slices = [None] * len(jax_nested_expr)
      for (i, jax_expr) in enumerate(jax_nested_expr):
-         new_slices[i], key, err, params = jax_expr(x, params, key)
+         new_slice, key, err, params = jax_expr(x, params, key)
+         if not jnp.issubdtype(jnp.result_type(new_slice), jnp.integer):
+             new_slice = jnp.asarray(new_slice, dtype=self.INT)
+         new_slices[i] = new_slice
          error |= err
      new_slices = tuple(new_slices)
      sample = sample[new_slices]
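The new cast in the nested-indexing path guards against relaxed, float-valued index expressions: JAX accepts only integer or boolean arrays as indices. A minimal sketch of the failure mode and the fix, where `jnp.int32` stands in for `self.INT` (not shown in this diff):

```python
import jax.numpy as jnp

table = jnp.arange(5.0)
idx = jnp.asarray(2.0)   # a float-valued index, e.g. from a relaxed expression
# table[idx] here raises TypeError: indexers must have integer or boolean type

if not jnp.issubdtype(jnp.result_type(idx), jnp.integer):
    idx = jnp.asarray(idx, dtype=jnp.int32)   # same check-and-cast as the diff
print(table[idx])        # 2.0
```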
@@ -799,7 +801,7 @@ class JaxRDDLCompiler:
      elif n == 2 or (n >= 2 and op in {'*', '+'}):
          jax_exprs = [self._jax(arg, init_params) for arg in args]
          result = jax_exprs[0]
-         for i, jax_rhs in enumerate(jax_exprs[1:]):
+         for (i, jax_rhs) in enumerate(jax_exprs[1:]):
              jax_op = valid_ops[op](f'{expr.id}_{op}{i}', init_params)
              result = self._jax_binary(result, jax_rhs, jax_op, at_least_int=True)
          return result

@@ -987,7 +989,8 @@ class JaxRDDLCompiler:
      sample_cases = [None] * len(jax_cases)
      for (i, jax_case) in enumerate(jax_cases):
          sample_cases[i], key, err_case, params = jax_case(x, params, key)
-         err |= err_case
+         err |= err_case
+     sample_cases = jnp.asarray(sample_cases)
      sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))

      # predicate (enum) is an integer - use it to extract from case array
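The added `jnp.asarray(sample_cases)` stacks the per-case samples into a single array before the dtype fix, so the enum-valued predicate can index it. A minimal sketch of that select-by-predicate pattern, with toy values rather than the compiler's actual case samples:

```python
import jax.numpy as jnp

# one sample per switch case, stacked along a leading axis
sample_cases = [jnp.asarray(1.0), jnp.asarray(2.0), jnp.asarray(3.0)]
cases = jnp.asarray(sample_cases)   # shape (3,)

pred = jnp.asarray(1)               # the enum predicate is an integer
print(cases[pred])                  # 2.0
```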
{pyrddlgym_jax-2.3 → pyrddlgym_jax-2.5}/pyRDDLGym_jax/core/logic.py

@@ -29,6 +29,8 @@
  #
  # ***********************************************************************

+
+ from abc import ABCMeta, abstractmethod
  import traceback
  from typing import Callable, Dict, Tuple, Union

@@ -64,30 +66,35 @@ def enumerate_literals(shape: Tuple[int, ...], axis: int, dtype: type=jnp.int32)
  #
  # ===========================================================================

- class Comparison:
+ class Comparison(metaclass=ABCMeta):
      '''Base class for approximate comparison operations.'''

+     @abstractmethod
      def greater_equal(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def greater(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def equal(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def sgn(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def argmax(self, id, init_params):
-         raise NotImplementedError
+         pass


  class SigmoidComparison(Comparison):
      '''Comparison operations approximated using sigmoid functions.'''

      def __init__(self, weight: float=10.0) -> None:
-         self.weight = weight
+         self.weight = float(weight)

      # https://arxiv.org/abs/2110.05651
      def greater_equal(self, id, init_params):
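Every `raise NotImplementedError` stub in this module is rewritten the same way: the base classes become abstract via `ABCMeta` and `@abstractmethod`, so an incomplete subclass now fails at instantiation rather than when the missing method is first called. A minimal illustration of the difference, using a toy class rather than the library's actual `Comparison`:

```python
from abc import ABCMeta, abstractmethod

class Comparison(metaclass=ABCMeta):
    @abstractmethod
    def greater(self, id, init_params):
        pass

class Incomplete(Comparison):
    pass   # forgot to implement greater()

try:
    Incomplete()   # fails immediately, instead of at the first greater() call
except TypeError as err:
    print(err)     # Can't instantiate abstract class Incomplete ...
```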
@@ -139,21 +146,23 @@ class SigmoidComparison:
  #
  # ===========================================================================

- class Rounding:
+ class Rounding(metaclass=ABCMeta):
      '''Base class for approximate rounding operations.'''

+     @abstractmethod
      def floor(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def round(self, id, init_params):
-         raise NotImplementedError
+         pass


  class SoftRounding(Rounding):
      '''Rounding operations approximated using soft operations.'''

      def __init__(self, weight: float=10.0) -> None:
-         self.weight = weight
+         self.weight = float(weight)

      # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
      def floor(self, id, init_params):

@@ -189,11 +198,12 @@ class SoftRounding:
  #
  # ===========================================================================

- class Complement:
+ class Complement(metaclass=ABCMeta):
      '''Base class for approximate logical complement operations.'''

+     @abstractmethod
      def __call__(self, id, init_params):
-         raise NotImplementedError
+         pass


  class StandardComplement(Complement):

@@ -222,16 +232,18 @@ class StandardComplement:
  # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
  # ===========================================================================

- class TNorm:
+ class TNorm(metaclass=ABCMeta):
      '''Base class for fuzzy differentiable t-norms.'''

+     @abstractmethod
      def norm(self, id, init_params):
          '''Elementwise t-norm of x and y.'''
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def norms(self, id, init_params):
          '''T-norm computed for tensor x along axis.'''
-         raise NotImplementedError
+         pass


  class ProductTNorm(TNorm):

@@ -339,26 +351,32 @@ class YagerTNorm:
  #
  # ===========================================================================

- class RandomSampling:
+ class RandomSampling(metaclass=ABCMeta):
      '''Describes how non-reparameterizable random variables are sampled.'''

+     @abstractmethod
      def discrete(self, id, init_params, logic):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def poisson(self, id, init_params, logic):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def binomial(self, id, init_params, logic):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def negative_binomial(self, id, init_params, logic):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def geometric(self, id, init_params, logic):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def bernoulli(self, id, init_params, logic):
-         raise NotImplementedError
+         pass

      def __str__(self) -> str:
          return 'RandomSampling'

@@ -603,21 +621,23 @@ class Determinization(RandomSampling):
  #
  # ===========================================================================

- class ControlFlow:
+ class ControlFlow(metaclass=ABCMeta):
      '''A base class for control flow, including if and switch statements.'''

+     @abstractmethod
      def if_then_else(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def switch(self, id, init_params):
-         raise NotImplementedError
+         pass


  class SoftControlFlow(ControlFlow):
      '''Soft control flow using a probabilistic interpretation.'''

      def __init__(self, weight: float=10.0) -> None:
-         self.weight = weight
+         self.weight = float(weight)

      @staticmethod
      def _jax_wrapped_calc_if_then_else_soft(c, a, b, params):

@@ -651,15 +671,15 @@ class SoftControlFlow(ControlFlow):
  # ===========================================================================


- class Logic:
+ class Logic(metaclass=ABCMeta):
      '''A base class for representing logic computations in JAX.'''

      def __init__(self, use64bit: bool=False) -> None:
          self.set_use64bit(use64bit)

-     def summarize_hyperparameters(self) -> None:
-         print(f'model relaxation:\n'
-               f'    use_64_bit   ={self.use64bit}')
+     def summarize_hyperparameters(self) -> str:
+         return (f'model relaxation:\n'
+                 f'    use_64_bit   ={self.use64bit}')

      def set_use64bit(self, use64bit: bool) -> None:
          '''Toggles whether or not the JAX system will use 64 bit precision.'''
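`summarize_hyperparameters` now returns the summary string instead of printing it, leaving output to the caller. A minimal sketch of the new calling convention, assuming `FuzzyLogic` can be default-constructed (its constructor signature is not shown in this diff):

```python
from pyRDDLGym_jax.core.logic import FuzzyLogic

logic = FuzzyLogic()   # default construction assumed here
summary = logic.summarize_hyperparameters()   # now returns a str
print(summary)         # the caller decides whether to print or log it
```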
@@ -765,119 +785,150 @@ class Logic:
      # ===========================================================================
      # logical operators
      # ===========================================================================

+     @abstractmethod
      def logical_and(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def logical_not(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def logical_or(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def xor(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def implies(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def equiv(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def forall(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def exists(self, id, init_params):
-         raise NotImplementedError
+         pass

      # ===========================================================================
      # comparison operators
      # ===========================================================================

+     @abstractmethod
      def greater_equal(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def greater(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def less_equal(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def less(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def equal(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def not_equal(self, id, init_params):
-         raise NotImplementedError
+         pass

      # ===========================================================================
      # special functions
      # ===========================================================================

+     @abstractmethod
      def sgn(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def floor(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def round(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def ceil(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def div(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def mod(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def sqrt(self, id, init_params):
-         raise NotImplementedError
+         pass

      # ===========================================================================
      # indexing
      # ===========================================================================

+     @abstractmethod
      def argmax(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def argmin(self, id, init_params):
-         raise NotImplementedError
+         pass

      # ===========================================================================
      # control flow
      # ===========================================================================

+     @abstractmethod
      def control_if(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def control_switch(self, id, init_params):
-         raise NotImplementedError
+         pass

      # ===========================================================================
      # random variables
      # ===========================================================================

+     @abstractmethod
      def discrete(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def bernoulli(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def poisson(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def geometric(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def binomial(self, id, init_params):
-         raise NotImplementedError
+         pass

+     @abstractmethod
      def negative_binomial(self, id, init_params):
-         raise NotImplementedError
+         pass


  class ExactLogic(Logic):

@@ -1109,8 +1160,8 @@ class FuzzyLogic(Logic):
          f'    underflow_tol={self.eps}\n'
          f'    use_64_bit   ={self.use64bit}\n')

-     def summarize_hyperparameters(self) -> None:
-         print(self.__str__())
+     def summarize_hyperparameters(self) -> str:
+         return self.__str__()

      # ===========================================================================
      # logical operators