compiled-knowledge 4.0.0a9__cp312-cp312-win_amd64.whl → 4.0.0a10__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (30)
  1. ck/circuit/circuit.c +38860 -0
  2. ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
  3. ck/circuit/circuit.pyx +9 -3
  4. ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
  5. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
  6. ck/pgm.py +78 -37
  7. ck/pgm_circuit/pgm_circuit.py +13 -9
  8. ck/pgm_circuit/program_with_slotmap.py +6 -4
  9. ck/pgm_compiler/ace/ace.py +48 -4
  10. ck/pgm_compiler/factor_elimination.py +6 -4
  11. ck/pgm_compiler/recursive_conditioning.py +8 -3
  12. ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
  13. ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
  14. ck/pgm_compiler/support/clusters.py +1 -1
  15. ck/pgm_compiler/variable_elimination.py +3 -3
  16. ck/probability/empirical_probability_space.py +3 -0
  17. ck/probability/pgm_probability_space.py +32 -0
  18. ck/probability/probability_space.py +66 -12
  19. ck/sampling/sampler_support.py +1 -1
  20. ck/sampling/uniform_sampler.py +10 -4
  21. ck/sampling/wmc_direct_sampler.py +4 -2
  22. ck/sampling/wmc_gibbs_sampler.py +6 -0
  23. ck/sampling/wmc_metropolis_sampler.py +7 -1
  24. ck/sampling/wmc_rejection_sampler.py +2 -0
  25. ck/utils/iter_extras.py +9 -6
  26. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a10.dist-info}/METADATA +1 -1
  27. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a10.dist-info}/RECORD +30 -26
  28. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a10.dist-info}/WHEEL +0 -0
  29. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a10.dist-info}/licenses/LICENSE.txt +0 -0
  30. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a10.dist-info}/top_level.txt +0 -0
ck/pgm.py CHANGED
@@ -82,8 +82,8 @@ import math
82
82
  from abc import ABC, abstractmethod
83
83
  from dataclasses import dataclass
84
84
  from itertools import repeat as _repeat
85
- from typing import Sequence, Tuple, Dict, Optional, overload, Iterator, Set, Iterable, List, Union, Callable, \
86
- Collection, Any
85
+ from typing import Sequence, Tuple, Dict, Optional, overload, Set, Iterable, List, Union, Callable, \
86
+ Collection, Any, Iterator
87
87
 
88
88
  import numpy as np
89
89
 
@@ -462,7 +462,7 @@ class PGM:
462
462
  a string representation of the given indicators.
463
463
  """
464
464
  return delim.join(
465
- f'{rv}{sep}{state}'
465
+ f'{_clean_str(rv)}{sep}{_clean_str(state)}'
466
466
  for rv, state in (
467
467
  self.indicator_pair(indicator)
468
468
  for indicator in indicators
@@ -559,7 +559,7 @@ class PGM:
559
559
  assert len(instance) == len(rvs)
560
560
  return delim.join(str(rv.states[i]) for rv, i in zip(rvs, instance))
561
561
 
562
- def instances(self, flip: bool = False) -> Iterator[Instance]:
562
+ def instances(self, flip: bool = False) -> Iterable[Instance]:
563
563
  """
564
564
  Iterate over all possible instances of this PGM, in natural index
565
565
  order (i.e., last random variable changing most quickly).
@@ -573,7 +573,7 @@ class PGM:
573
573
  """
574
574
  return _combos_ranges(tuple(len(rv) for rv in self._rvs), flip=not flip)
575
575
 
576
- def instances_as_indicators(self, flip: bool = False) -> Iterator[Sequence[Indicator]]:
576
+ def instances_as_indicators(self, flip: bool = False) -> Iterable[Sequence[Indicator]]:
577
577
  """
578
578
  Iterate over all possible instances of this PGM, in natural index
579
579
  order (i.e., last random variable changing most quickly).
@@ -605,7 +605,7 @@ class PGM:
605
605
  """
606
606
  return tuple(rv[state] for rv, state in zip(self._rvs, instance))
607
607
 
608
- def factor_values(self, key: Key) -> Iterator[float]:
608
+ def factor_values(self, key: Key) -> Iterable[float]:
609
609
  """
610
610
  For a given instance key, each factor defines a single value. This method
611
611
  returns those values.
@@ -717,6 +717,11 @@ class PGM:
717
717
  If no indicators are provided, then the value of the partition function (z)
718
718
  is returned.
719
719
 
720
+ If multiple indicators are provided for the same random variable, then all matching
721
+ instances are summed.
722
+
723
+ This method has the same semantics as `ProbabilitySpace.wmc` without conditioning.
724
+
720
725
  Warning:
721
726
  this is potentially computationally expensive as it marginalised random
722
727
  variables not mentioned in the given indicators.
@@ -727,29 +732,51 @@ class PGM:
727
732
  Returns:
728
733
  the product of factors, conditioned on the given instance. This is the
729
734
  computed value of the PGM, conditioned on the given instance.
730
-
731
- Raises:
732
- RuntimeError: if a random variable is referenced multiple times in the given indicators.
733
735
  """
734
- # Create an instance from the indicators
735
- inst: List[int] = [-1] * self.number_of_rvs
736
+ # # Create a filter from the indicators
737
+ # inst_filter: List[Set[int]] = [set() for _ in range(self.number_of_rvs)]
738
+ # for indicator in indicators:
739
+ # rv_idx: int = indicator.rv_idx
740
+ # inst_filter[rv_idx].add(indicator.state_idx)
741
+ # # Collect rvs not mentioned - to marginalise
742
+ # for rv, rv_filter in zip(self.rvs, inst_filter):
743
+ # if len(rv_filter) == 0:
744
+ # rv_filter.update(rv.state_range())
745
+ #
746
+ # def _sum_inst(_instance: Instance) -> bool:
747
+ # return all(
748
+ # (_state in _rv_filter)
749
+ # for _state, _rv_filter in zip(_instance, inst_filter)
750
+ # )
751
+ #
752
+ # # Accumulate the result
753
+ # sum_value = 0
754
+ # for instance in self.instances():
755
+ # if _sum_inst(instance):
756
+ # sum_value += self.value_product(instance)
757
+ #
758
+ # return sum_value
759
+
760
+ # Work out the space to sum over
761
+ sum_space_set: List[Optional[Set[int]]] = [None] * self.number_of_rvs
736
762
  for indicator in indicators:
737
763
  rv_idx: int = indicator.rv_idx
738
- if inst[rv_idx] >= 0:
739
- raise RuntimeError(f'random variable mentioned multiple times: {self.rvs[rv_idx]}')
740
- inst[rv_idx] = indicator.state_idx
764
+ cur_set = sum_space_set[rv_idx]
765
+ if cur_set is None:
766
+ sum_space_set[rv_idx] = cur_set = set()
767
+ cur_set.add(indicator.state_idx)
741
768
 
742
- # Collect rvs not mentioned - to marginalise
743
- rvs = [rv for rv in self.rvs if inst[rv.idx] < 0]
769
+ # Convert to a list of states that we need to sum over.
770
+ sum_space_list: List[List[int]] = [
771
+ list(cur_set if cur_set is not None else rv.state_range())
772
+ for cur_set, rv in zip(sum_space_set, self.rvs)
773
+ ]
744
774
 
745
775
  # Accumulate the result
746
- sum_value = 0
747
- for instance in rv_instances_as_indicators(*rvs):
748
- for indicator in instance:
749
- inst[indicator.rv_idx] = indicator.state_idx
750
- sum_value += self.value_product(inst)
751
-
752
- return sum_value
776
+ return sum(
777
+ self.value_product(instance)
778
+ for instance in _combos(sum_space_list)
779
+ )
753
780
 
754
781
  def dump_synopsis(
755
782
  self,
@@ -937,8 +964,8 @@ class PGM:
937
964
  else:
938
965
  _cur_rv = sorted(cur_rv)
939
966
  rv = self._rvs[_cur_rv[0].rv_idx]
940
- states_str = sep.join(str(rv.states[ind.state_idx]) for ind in _cur_rv)
941
- cur_str += f'{rv}{elem}{{{states_str}}}'
967
+ states_str: str = sep.join(_clean_str(rv.states[ind.state_idx]) for ind in _cur_rv)
968
+ cur_str += f'{_clean_str(rv)}{elem}{{{states_str}}}'
942
969
  return cur_str
943
970
 
944
971
 
@@ -1095,7 +1122,7 @@ class RandomVariable(Sequence[Indicator]):
1095
1122
  """
1096
1123
  return range(len(self._states))
1097
1124
 
1098
- def factors(self) -> Iterator[Factor]:
1125
+ def factors(self) -> Iterable[Factor]:
1099
1126
  """
1100
1127
  Iterate over factors that this random variable participates in.
1101
1128
  This method performs a search through all `self.pgm.factors`.
@@ -1194,8 +1221,8 @@ class RandomVariable(Sequence[Indicator]):
1194
1221
  return self.idx == other.idx and len(self) == len(other)
1195
1222
  else:
1196
1223
  return (
1197
- len(indicators) == len(other) and
1198
- all(indicators[i] == other[i] for i in range(len(indicators)))
1224
+ len(indicators) == len(other) and
1225
+ all(indicators[i] == other[i] for i in range(len(indicators)))
1199
1226
  )
1200
1227
 
1201
1228
  def __len__(self) -> int:
@@ -1467,7 +1494,7 @@ class Factor:
1467
1494
  def __getitem__(self, index):
1468
1495
  return self._rvs[index]
1469
1496
 
1470
- def instances(self, flip: bool = False) -> Iterator[Instance]:
1497
+ def instances(self, flip: bool = False) -> Iterable[Instance]:
1471
1498
  """
1472
1499
  Iterate over all possible instances, in natural index order (i.e.,
1473
1500
  last random variable changing most quickly).
@@ -1481,7 +1508,7 @@ class Factor:
1481
1508
  """
1482
1509
  return self.function.instances(flip)
1483
1510
 
1484
- def parent_instances(self, flip: bool = False) -> Iterator[Instance]:
1511
+ def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
1485
1512
  """
1486
1513
  Iterate over all possible instances of parent random variable, in
1487
1514
  natural index order (i.e., last random variable changing most quickly).
@@ -1935,7 +1962,7 @@ class PotentialFunction(ABC):
1935
1962
  raise ValueError(f'invalid parameter index: {param_idx}')
1936
1963
  return ParamId(id(self), param_idx)
1937
1964
 
1938
- def items(self) -> Iterator[Tuple[Instance, float]]:
1965
+ def items(self) -> Iterable[Tuple[Instance, float]]:
1939
1966
  """
1940
1967
  Iterate over all keys and values of this potential function.
1941
1968
 
@@ -1946,7 +1973,7 @@ class PotentialFunction(ABC):
1946
1973
  for key in _combos_ranges(self._shape, flip=True):
1947
1974
  yield key, self[key]
1948
1975
 
1949
- def instances(self, flip: bool = False) -> Iterator[Instance]:
1976
+ def instances(self, flip: bool = False) -> Iterable[Instance]:
1950
1977
  """
1951
1978
  Iterate over all possible instances, in natural index order (i.e.,
1952
1979
  last random variable changing most quickly).
@@ -1960,7 +1987,7 @@ class PotentialFunction(ABC):
1960
1987
  """
1961
1988
  return _combos_ranges(self._shape, flip=not flip)
1962
1989
 
1963
- def parent_instances(self, flip: bool = False) -> Iterator[Instance]:
1990
+ def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
1964
1991
  """
1965
1992
  Iterate over all possible instances of parent random variable, in
1966
1993
  natural index order (i.e., last random variable changing most quickly).
@@ -2055,7 +2082,7 @@ class ZeroPotentialFunction(PotentialFunction):
2055
2082
  return self.number_of_states
2056
2083
 
2057
2084
  @property
2058
- def params(self) -> Iterator[Tuple[int, float]]:
2085
+ def params(self) -> Iterable[Tuple[int, float]]:
2059
2086
  for param_idx in range(self.number_of_parameters):
2060
2087
  yield param_idx, 0
2061
2088
 
@@ -3047,7 +3074,7 @@ class CPTPotentialFunction(PotentialFunction):
3047
3074
  else:
3048
3075
  return self._values[offset:offset + child_size]
3049
3076
 
3050
- def cpds(self) -> Iterator[Tuple[Instance, Sequence[float]]]:
3077
+ def cpds(self) -> Iterable[Tuple[Instance, Sequence[float]]]:
3051
3078
  """
3052
3079
  Iterate over (parent_states, cpd) tuples.
3053
3080
  This will exclude zero CPDs.
@@ -3358,7 +3385,7 @@ def number_of_states(*rvs: RandomVariable) -> int:
3358
3385
  return _multiply(len(rv) for rv in rvs)
3359
3386
 
3360
3387
 
3361
- def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterator[Instance]:
3388
+ def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterable[Instance]:
3362
3389
  """
3363
3390
  Enumerate instances of the given random variables.
3364
3391
 
@@ -3377,7 +3404,7 @@ def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterator[Instance]
3377
3404
  return _combos_ranges(shape, flip=not flip)
3378
3405
 
3379
3406
 
3380
- def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) -> Iterator[Sequence[Indicator]]:
3407
+ def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) -> Iterable[Sequence[Indicator]]:
3381
3408
  """
3382
3409
  Enumerate instances of the given random variables.
3383
3410
 
@@ -3492,3 +3519,17 @@ def _normalise_potential_function(
3492
3519
  total = group_sum[group]
3493
3520
  if total > 0:
3494
3521
  function.set_param_value(param_idx, param_value / total)
3522
+
3523
+
3524
+ _CLEAN_CHARS: Set[str] = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-+~?.')
3525
+
3526
+
3527
+ def _clean_str(s) -> str:
3528
+ """
3529
+ Quote a string if empty or not all characters are in _CLEAN_CHARS.
3530
+ """
3531
+ s = str(s)
3532
+ if len(s) == 0 or not all(c in _CLEAN_CHARS for c in s):
3533
+ return repr(s)
3534
+ else:
3535
+ return s
ck/pgm_circuit/pgm_circuit.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from dataclasses import dataclass
2
2
  from typing import Sequence, List, Dict
3
3
 
4
- from ck.circuit import CircuitNode
4
+ from ck.circuit import CircuitNode, Circuit
5
5
  from ck.pgm import RandomVariable, Indicator
6
6
  from ck.pgm_circuit.slot_map import SlotMap, SlotKey
7
7
  from ck.utils.np_extras import NDArray
@@ -46,19 +46,21 @@ class PGMCircuit:
46
46
 
47
47
  def dump(self, *, prefix: str = '', indent: str = ' ') -> None:
48
48
  """
49
- Print a dump of the Join Tree.
49
+ Print a dump of the circuit.
50
50
  This is intended for debugging and demonstration purposes.
51
51
 
52
- Each cluster is printed as: {separator rvs} | {non-separator rvs}.
53
-
54
52
  Args:
55
53
  prefix: optional prefix for indenting all lines.
56
54
  indent: additional prefix to use for extra indentation.
57
55
  """
58
- # Name the circuit variables
59
- circuit = self.circuit_top.circuit
56
+
57
+ # We infer names for the circuit variables, either as an indicator or as a parameter.
58
+ # The `var_names` will be passed to `circuit.dump`.
59
+
60
+ circuit: Circuit = self.circuit_top.circuit
60
61
  var_names: List[str] = [''] * circuit.number_of_vars
61
62
 
63
+ # Name the circuit variables that are indicators
62
64
  rvs_by_idx: Dict[int, RandomVariable] = {rv.idx: rv for rv in self.rvs}
63
65
  slot_key: SlotKey
64
66
  slot: int
@@ -68,8 +70,10 @@ class PGMCircuit:
68
70
  state_idx = slot_key.state_idx
69
71
  var_names[slot] = f'{rv.name!r}[{state_idx}] {rv.states[state_idx]!r}'
70
72
 
71
- for slot, param_value in enumerate(self.parameter_values, start=self.number_of_indicators):
72
- var_names[slot] = f'param {param_value}'
73
+ # Name the circuit variables that are parameters
74
+ for i, param_value in enumerate(self.parameter_values):
75
+ slot = i + self.number_of_indicators
76
+ var_names[slot] = f'param[{i}] = {param_value}'
73
77
 
74
- # Show all the slots
78
+ # Dump the circuit
75
79
  circuit.dump(prefix=prefix, indent=indent, var_names=var_names)
ck/pgm_circuit/program_with_slotmap.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Tuple, Iterator, Sequence, Dict, Iterable
1
+ from typing import Tuple, Sequence, Dict, Iterable
2
2
 
3
3
  from ck.pgm import RandomVariable, rv_instances, Instance, rv_instances_as_indicators, Indicator, ParamId
4
4
  from ck.pgm_circuit.slot_map import SlotMap, SlotKey
@@ -30,11 +30,13 @@ class ProgramWithSlotmap:
30
30
  has a length and rv[i] is a unique 'indicator' across all rvs.
31
31
  precondition: conditions on rvs that are compiled into the program.
32
32
 
33
+ Raises:
34
+ ValueError: if rvs contains duplicates.
33
35
  """
34
36
  self._program_buffer: ProgramBuffer = program_buffer
35
37
  self._slot_map: SlotMap = slot_map
36
38
  self._rvs: Tuple[RandomVariable, ...] = tuple(rvs)
37
- self._precondition: Sequence[Indicator] = precondition
39
+ self._precondition: Tuple[Indicator, ...] = tuple(precondition)
38
40
 
39
41
  if len(rvs) != len(set(rv.idx for rv in rvs)):
40
42
  raise ValueError('duplicate random variables provided')
@@ -67,7 +69,7 @@ class ProgramWithSlotmap:
67
69
  def slot_map(self) -> SlotMap:
68
70
  return self._slot_map
69
71
 
70
- def instances(self, flip: bool = False) -> Iterator[Instance]:
72
+ def instances(self, flip: bool = False) -> Iterable[Instance]:
71
73
  """
72
74
  Enumerate instances of the random variables.
73
75
 
@@ -84,7 +86,7 @@ class ProgramWithSlotmap:
84
86
  """
85
87
  return rv_instances(*self._rvs, flip=flip)
86
88
 
87
- def instances_as_indicators(self, flip: bool = False) -> Iterator[Sequence[Indicator]]:
89
+ def instances_as_indicators(self, flip: bool = False) -> Iterable[Sequence[Indicator]]:
88
90
  """
89
91
  Enumerate instances of the random variables.
90
92
 
ck/pgm_compiler/ace/ace.py CHANGED
@@ -5,7 +5,9 @@ from dataclasses import dataclass
5
5
  from pathlib import Path
6
6
  from typing import Optional, List, Tuple
7
7
 
8
- from ck.circuit import CircuitNode
8
+ import numpy as np
9
+
10
+ from ck.circuit import CircuitNode, Circuit
9
11
  from ck.in_out.parse_ace_lmap import read_lmap, LiteralMap
10
12
  from ck.in_out.parse_ace_nnf import read_nnf_with_literal_map
11
13
  from ck.in_out.render_net import render_bayesian_network
@@ -58,6 +60,22 @@ def compile_pgm(
58
60
  if check_is_bayesian_network and not pgm.check_is_bayesian_network():
59
61
  raise ValueError('the given PGM is not a Bayesian network')
60
62
 
63
+ # ACE cannot deal with the empty PGM even though it is a valid Bayesian network
64
+ if pgm.number_of_factors == 0:
65
+ circuit = Circuit()
66
+ circuit.new_vars(pgm.number_of_indicators)
67
+ parameter_values = np.array([], dtype=np.float64)
68
+ slot_map = {indicator: i for i, indicator in enumerate(pgm.indicators)}
69
+ return PGMCircuit(
70
+ rvs=pgm.rvs,
71
+ conditions=(),
72
+ circuit_top=circuit.const(1),
73
+ number_of_indicators=pgm.number_of_indicators,
74
+ number_of_parameters=0,
75
+ slot_map=slot_map,
76
+ parameter_values=parameter_values,
77
+ )
78
+
61
79
  java: str
62
80
  classpath_separator: str
63
81
  java, classpath_separator = _find_java()
@@ -69,7 +87,7 @@ def compile_pgm(
69
87
  )
70
88
  ace_cmd: List[str] = [
71
89
  java,
72
- f'-cp',
90
+ '-cp',
73
91
  class_path,
74
92
  f'-DACEC2D={files.c2d}',
75
93
  f'-Xmx{int(m_bytes)}m',
@@ -83,7 +101,7 @@ def compile_pgm(
83
101
  node_names: List[str] = render_bayesian_network(pgm, file, check_structure_bayesian=False)
84
102
 
85
103
  # Run Ace
86
- ace_result = subprocess.run(ace_cmd, capture_output=True, text=True)
104
+ ace_result: subprocess.CompletedProcess = subprocess.run(ace_cmd, capture_output=True, text=True)
87
105
  if print_output:
88
106
  print(ace_result.stdout)
89
107
  print(ace_result.stderr)
@@ -126,6 +144,29 @@ def compile_pgm(
126
144
  )
127
145
 
128
146
 
147
+ def ace_available(
148
+ ace_dir: Optional[Path | str] = None,
149
+ jar_dir: Optional[Path | str] = None,
150
+ ) -> bool:
151
+ """
152
+ Returns:
153
+ True if it looks like ACE is available, False otherwise.
154
+ ACE is available if ACE files are in the default location and Java is available.
155
+ """
156
+ try:
157
+ java: str
158
+ java, _ = _find_java()
159
+ _: _AceFiles = _find_ace_files(ace_dir, jar_dir)
160
+
161
+ java_cmd: List[str] = [java, '--version',]
162
+ java_result: subprocess.CompletedProcess = subprocess.run(java_cmd, capture_output=True, text=True)
163
+
164
+ return java_result.returncode == 0
165
+
166
+ except RuntimeError:
167
+ return False
168
+
169
+
129
170
  def copy_ace_to_default_location(
130
171
  ace_dir: Path | str,
131
172
  jar_dir: Optional[Path | str] = None,
@@ -179,6 +220,9 @@ def _find_java() -> Tuple[str, str]:
179
220
 
180
221
  Returns:
181
222
  (java, classpath_separator)
223
+
224
+ Raises:
225
+ RuntimeError: if not found, including a helpful message.
182
226
  """
183
227
  if sys.platform == 'win32':
184
228
  return 'java.exe', ';'
@@ -198,7 +242,7 @@ def _find_ace_files(
198
242
  Look for the needed Ace files.
199
243
 
200
244
  Raises:
201
- RuntimeError: if not found, including a helpful message .
245
+ RuntimeError: if not found, including a helpful message.
202
246
  """
203
247
  ace_dir: Path = default_ace_location() if ace_dir is None else Path(ace_dir)
204
248
  jar_dir: Path = ace_dir if jar_dir is None else Path(jar_dir)
ck/pgm_compiler/factor_elimination.py CHANGED
@@ -12,19 +12,21 @@ from ck.pgm_compiler.support.join_tree import *
12
12
 
13
13
  _NEG_INF = float('-inf')
14
14
 
15
+ DEFAULT_PRODUCT_SEARCH_LIMIT: int = 1000
16
+
15
17
 
16
18
  def compile_pgm(
17
19
  pgm: PGM,
18
20
  const_parameters: bool = True,
19
21
  *,
20
22
  algorithm: JoinTreeAlgorithm = MIN_FILL_THEN_DEGREE,
21
- limit_product_tree_search: int = 1000,
23
+ limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
22
24
  pre_prune_factor_tables: bool = True,
23
25
  ) -> PGMCircuit:
24
26
  """
25
27
  Compile the PGM to an arithmetic circuit, using factor elimination.
26
28
 
27
- When forming the product of factors withing a join tree nodes,
29
+ When forming the product of factors within join tree nodes,
28
30
  this method searches all practical binary trees for forming products,
29
31
  up to the given limit, `limit_product_tree_search`. The minimum is 1.
30
32
 
@@ -57,7 +59,7 @@ def compile_pgm_best_jointree(
57
59
  pgm: PGM,
58
60
  const_parameters: bool = True,
59
61
  *,
60
- limit_product_tree_search: int = 1000,
62
+ limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
61
63
  pre_prune_factor_tables: bool = True,
62
64
  ) -> PGMCircuit:
63
65
  """
@@ -111,7 +113,7 @@ def compile_pgm_best_jointree(
111
113
  def join_tree_to_circuit(
112
114
  join_tree: JoinTree,
113
115
  const_parameters: bool = True,
114
- limit_product_tree_search: int = 1000,
116
+ limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
115
117
  pre_prune_factor_tables: bool = True,
116
118
  ) -> PGMCircuit:
117
119
  """
ck/pgm_compiler/recursive_conditioning.py CHANGED
@@ -52,10 +52,15 @@ def compile_pgm(
52
52
  multiply_indicators=True,
53
53
  pre_prune_factor_tables=pre_prune_factor_tables,
54
54
  )
55
- dtree: _DTree = _make_dtree(elimination_order, factor_tables)
56
55
 
57
- states: List[Sequence[int]] = [tuple(range(len(rv))) for rv in pgm.rvs]
58
- top: CircuitNode = dtree.make_circuit(states, factor_tables.circuit)
56
+ if pgm.number_of_factors == 0:
57
+ # Deal with special case: no factors
58
+ top: CircuitNode = factor_tables.circuit.const(1)
59
+ else:
60
+ dtree: _DTree = _make_dtree(elimination_order, factor_tables)
61
+ states: List[Sequence[int]] = [tuple(range(len(rv))) for rv in pgm.rvs]
62
+ top: CircuitNode = dtree.make_circuit(states, factor_tables.circuit)
63
+
59
64
  top.circuit.remove_unreachable_op_nodes(top)
60
65
 
61
66
  return PGMCircuit(