compiled-knowledge 4.0.0a9__cp312-cp312-win_amd64.whl → 4.0.0a10__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/circuit.c +38860 -0
- ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
- ck/circuit/circuit.pyx +9 -3
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
- ck/pgm.py +78 -37
- ck/pgm_circuit/pgm_circuit.py +13 -9
- ck/pgm_circuit/program_with_slotmap.py +6 -4
- ck/pgm_compiler/ace/ace.py +48 -4
- ck/pgm_compiler/factor_elimination.py +6 -4
- ck/pgm_compiler/recursive_conditioning.py +8 -3
- ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
- ck/pgm_compiler/support/clusters.py +1 -1
- ck/pgm_compiler/variable_elimination.py +3 -3
- ck/probability/empirical_probability_space.py +3 -0
- ck/probability/pgm_probability_space.py +32 -0
- ck/probability/probability_space.py +66 -12
- ck/sampling/sampler_support.py +1 -1
- ck/sampling/uniform_sampler.py +10 -4
- ck/sampling/wmc_direct_sampler.py +4 -2
- ck/sampling/wmc_gibbs_sampler.py +6 -0
- ck/sampling/wmc_metropolis_sampler.py +7 -1
- ck/sampling/wmc_rejection_sampler.py +2 -0
- ck/utils/iter_extras.py +9 -6
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a10.dist-info}/METADATA +1 -1
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a10.dist-info}/RECORD +30 -26
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a10.dist-info}/WHEEL +0 -0
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a10.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a10.dist-info}/top_level.txt +0 -0
|
Binary file
|
ck/pgm.py
CHANGED
|
@@ -82,8 +82,8 @@ import math
|
|
|
82
82
|
from abc import ABC, abstractmethod
|
|
83
83
|
from dataclasses import dataclass
|
|
84
84
|
from itertools import repeat as _repeat
|
|
85
|
-
from typing import Sequence, Tuple, Dict, Optional, overload,
|
|
86
|
-
Collection, Any
|
|
85
|
+
from typing import Sequence, Tuple, Dict, Optional, overload, Set, Iterable, List, Union, Callable, \
|
|
86
|
+
Collection, Any, Iterator
|
|
87
87
|
|
|
88
88
|
import numpy as np
|
|
89
89
|
|
|
@@ -462,7 +462,7 @@ class PGM:
|
|
|
462
462
|
a string representation of the given indicators.
|
|
463
463
|
"""
|
|
464
464
|
return delim.join(
|
|
465
|
-
f'{rv}{sep}{state}'
|
|
465
|
+
f'{_clean_str(rv)}{sep}{_clean_str(state)}'
|
|
466
466
|
for rv, state in (
|
|
467
467
|
self.indicator_pair(indicator)
|
|
468
468
|
for indicator in indicators
|
|
@@ -559,7 +559,7 @@ class PGM:
|
|
|
559
559
|
assert len(instance) == len(rvs)
|
|
560
560
|
return delim.join(str(rv.states[i]) for rv, i in zip(rvs, instance))
|
|
561
561
|
|
|
562
|
-
def instances(self, flip: bool = False) ->
|
|
562
|
+
def instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
563
563
|
"""
|
|
564
564
|
Iterate over all possible instances of this PGM, in natural index
|
|
565
565
|
order (i.e., last random variable changing most quickly).
|
|
@@ -573,7 +573,7 @@ class PGM:
|
|
|
573
573
|
"""
|
|
574
574
|
return _combos_ranges(tuple(len(rv) for rv in self._rvs), flip=not flip)
|
|
575
575
|
|
|
576
|
-
def instances_as_indicators(self, flip: bool = False) ->
|
|
576
|
+
def instances_as_indicators(self, flip: bool = False) -> Iterable[Sequence[Indicator]]:
|
|
577
577
|
"""
|
|
578
578
|
Iterate over all possible instances of this PGM, in natural index
|
|
579
579
|
order (i.e., last random variable changing most quickly).
|
|
@@ -605,7 +605,7 @@ class PGM:
|
|
|
605
605
|
"""
|
|
606
606
|
return tuple(rv[state] for rv, state in zip(self._rvs, instance))
|
|
607
607
|
|
|
608
|
-
def factor_values(self, key: Key) ->
|
|
608
|
+
def factor_values(self, key: Key) -> Iterable[float]:
|
|
609
609
|
"""
|
|
610
610
|
For a given instance key, each factor defines a single value. This method
|
|
611
611
|
returns those values.
|
|
@@ -717,6 +717,11 @@ class PGM:
|
|
|
717
717
|
If no indicators are provided, then the value of the partition function (z)
|
|
718
718
|
is returned.
|
|
719
719
|
|
|
720
|
+
If multiple indicators are provided for the same random variable, then all matching
|
|
721
|
+
instances are summed.
|
|
722
|
+
|
|
723
|
+
This method has the same semantics as `ProbabilitySpace.wmc` without conditioning.
|
|
724
|
+
|
|
720
725
|
Warning:
|
|
721
726
|
this is potentially computationally expensive as it marginalised random
|
|
722
727
|
variables not mentioned in the given indicators.
|
|
@@ -727,29 +732,51 @@ class PGM:
|
|
|
727
732
|
Returns:
|
|
728
733
|
the product of factors, conditioned on the given instance. This is the
|
|
729
734
|
computed value of the PGM, conditioned on the given instance.
|
|
730
|
-
|
|
731
|
-
Raises:
|
|
732
|
-
RuntimeError: if a random variable is referenced multiple times in the given indicators.
|
|
733
735
|
"""
|
|
734
|
-
# Create
|
|
735
|
-
|
|
736
|
+
# # Create a filter from the indicators
|
|
737
|
+
# inst_filter: List[Set[int]] = [set() for _ in range(self.number_of_rvs)]
|
|
738
|
+
# for indicator in indicators:
|
|
739
|
+
# rv_idx: int = indicator.rv_idx
|
|
740
|
+
# inst_filter[rv_idx].add(indicator.state_idx)
|
|
741
|
+
# # Collect rvs not mentioned - to marginalise
|
|
742
|
+
# for rv, rv_filter in zip(self.rvs, inst_filter):
|
|
743
|
+
# if len(rv_filter) == 0:
|
|
744
|
+
# rv_filter.update(rv.state_range())
|
|
745
|
+
#
|
|
746
|
+
# def _sum_inst(_instance: Instance) -> bool:
|
|
747
|
+
# return all(
|
|
748
|
+
# (_state in _rv_filter)
|
|
749
|
+
# for _state, _rv_filter in zip(_instance, inst_filter)
|
|
750
|
+
# )
|
|
751
|
+
#
|
|
752
|
+
# # Accumulate the result
|
|
753
|
+
# sum_value = 0
|
|
754
|
+
# for instance in self.instances():
|
|
755
|
+
# if _sum_inst(instance):
|
|
756
|
+
# sum_value += self.value_product(instance)
|
|
757
|
+
#
|
|
758
|
+
# return sum_value
|
|
759
|
+
|
|
760
|
+
# Work out the space to sum over
|
|
761
|
+
sum_space_set: List[Optional[Set[int]]] = [None] * self.number_of_rvs
|
|
736
762
|
for indicator in indicators:
|
|
737
763
|
rv_idx: int = indicator.rv_idx
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
764
|
+
cur_set = sum_space_set[rv_idx]
|
|
765
|
+
if cur_set is None:
|
|
766
|
+
sum_space_set[rv_idx] = cur_set = set()
|
|
767
|
+
cur_set.add(indicator.state_idx)
|
|
741
768
|
|
|
742
|
-
#
|
|
743
|
-
|
|
769
|
+
# Convert to a list of states that we need to sum over.
|
|
770
|
+
sum_space_list: List[List[int]] = [
|
|
771
|
+
list(cur_set if cur_set is not None else rv.state_range())
|
|
772
|
+
for cur_set, rv in zip(sum_space_set, self.rvs)
|
|
773
|
+
]
|
|
744
774
|
|
|
745
775
|
# Accumulate the result
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
for
|
|
749
|
-
|
|
750
|
-
sum_value += self.value_product(inst)
|
|
751
|
-
|
|
752
|
-
return sum_value
|
|
776
|
+
return sum(
|
|
777
|
+
self.value_product(instance)
|
|
778
|
+
for instance in _combos(sum_space_list)
|
|
779
|
+
)
|
|
753
780
|
|
|
754
781
|
def dump_synopsis(
|
|
755
782
|
self,
|
|
@@ -937,8 +964,8 @@ class PGM:
|
|
|
937
964
|
else:
|
|
938
965
|
_cur_rv = sorted(cur_rv)
|
|
939
966
|
rv = self._rvs[_cur_rv[0].rv_idx]
|
|
940
|
-
states_str = sep.join(
|
|
941
|
-
cur_str += f'{rv}{elem}{{{states_str}}}'
|
|
967
|
+
states_str: str = sep.join(_clean_str(rv.states[ind.state_idx]) for ind in _cur_rv)
|
|
968
|
+
cur_str += f'{_clean_str(rv)}{elem}{{{states_str}}}'
|
|
942
969
|
return cur_str
|
|
943
970
|
|
|
944
971
|
|
|
@@ -1095,7 +1122,7 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
1095
1122
|
"""
|
|
1096
1123
|
return range(len(self._states))
|
|
1097
1124
|
|
|
1098
|
-
def factors(self) ->
|
|
1125
|
+
def factors(self) -> Iterable[Factor]:
|
|
1099
1126
|
"""
|
|
1100
1127
|
Iterate over factors that this random variable participates in.
|
|
1101
1128
|
This method performs a search through all `self.pgm.factors`.
|
|
@@ -1194,8 +1221,8 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
1194
1221
|
return self.idx == other.idx and len(self) == len(other)
|
|
1195
1222
|
else:
|
|
1196
1223
|
return (
|
|
1197
|
-
|
|
1198
|
-
|
|
1224
|
+
len(indicators) == len(other) and
|
|
1225
|
+
all(indicators[i] == other[i] for i in range(len(indicators)))
|
|
1199
1226
|
)
|
|
1200
1227
|
|
|
1201
1228
|
def __len__(self) -> int:
|
|
@@ -1467,7 +1494,7 @@ class Factor:
|
|
|
1467
1494
|
def __getitem__(self, index):
|
|
1468
1495
|
return self._rvs[index]
|
|
1469
1496
|
|
|
1470
|
-
def instances(self, flip: bool = False) ->
|
|
1497
|
+
def instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
1471
1498
|
"""
|
|
1472
1499
|
Iterate over all possible instances, in natural index order (i.e.,
|
|
1473
1500
|
last random variable changing most quickly).
|
|
@@ -1481,7 +1508,7 @@ class Factor:
|
|
|
1481
1508
|
"""
|
|
1482
1509
|
return self.function.instances(flip)
|
|
1483
1510
|
|
|
1484
|
-
def parent_instances(self, flip: bool = False) ->
|
|
1511
|
+
def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
1485
1512
|
"""
|
|
1486
1513
|
Iterate over all possible instances of parent random variable, in
|
|
1487
1514
|
natural index order (i.e., last random variable changing most quickly).
|
|
@@ -1935,7 +1962,7 @@ class PotentialFunction(ABC):
|
|
|
1935
1962
|
raise ValueError(f'invalid parameter index: {param_idx}')
|
|
1936
1963
|
return ParamId(id(self), param_idx)
|
|
1937
1964
|
|
|
1938
|
-
def items(self) ->
|
|
1965
|
+
def items(self) -> Iterable[Tuple[Instance, float]]:
|
|
1939
1966
|
"""
|
|
1940
1967
|
Iterate over all keys and values of this potential function.
|
|
1941
1968
|
|
|
@@ -1946,7 +1973,7 @@ class PotentialFunction(ABC):
|
|
|
1946
1973
|
for key in _combos_ranges(self._shape, flip=True):
|
|
1947
1974
|
yield key, self[key]
|
|
1948
1975
|
|
|
1949
|
-
def instances(self, flip: bool = False) ->
|
|
1976
|
+
def instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
1950
1977
|
"""
|
|
1951
1978
|
Iterate over all possible instances, in natural index order (i.e.,
|
|
1952
1979
|
last random variable changing most quickly).
|
|
@@ -1960,7 +1987,7 @@ class PotentialFunction(ABC):
|
|
|
1960
1987
|
"""
|
|
1961
1988
|
return _combos_ranges(self._shape, flip=not flip)
|
|
1962
1989
|
|
|
1963
|
-
def parent_instances(self, flip: bool = False) ->
|
|
1990
|
+
def parent_instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
1964
1991
|
"""
|
|
1965
1992
|
Iterate over all possible instances of parent random variable, in
|
|
1966
1993
|
natural index order (i.e., last random variable changing most quickly).
|
|
@@ -2055,7 +2082,7 @@ class ZeroPotentialFunction(PotentialFunction):
|
|
|
2055
2082
|
return self.number_of_states
|
|
2056
2083
|
|
|
2057
2084
|
@property
|
|
2058
|
-
def params(self) ->
|
|
2085
|
+
def params(self) -> Iterable[Tuple[int, float]]:
|
|
2059
2086
|
for param_idx in range(self.number_of_parameters):
|
|
2060
2087
|
yield param_idx, 0
|
|
2061
2088
|
|
|
@@ -3047,7 +3074,7 @@ class CPTPotentialFunction(PotentialFunction):
|
|
|
3047
3074
|
else:
|
|
3048
3075
|
return self._values[offset:offset + child_size]
|
|
3049
3076
|
|
|
3050
|
-
def cpds(self) ->
|
|
3077
|
+
def cpds(self) -> Iterable[Tuple[Instance, Sequence[float]]]:
|
|
3051
3078
|
"""
|
|
3052
3079
|
Iterate over (parent_states, cpd) tuples.
|
|
3053
3080
|
This will exclude zero CPDs.
|
|
@@ -3358,7 +3385,7 @@ def number_of_states(*rvs: RandomVariable) -> int:
|
|
|
3358
3385
|
return _multiply(len(rv) for rv in rvs)
|
|
3359
3386
|
|
|
3360
3387
|
|
|
3361
|
-
def rv_instances(*rvs: RandomVariable, flip: bool = False) ->
|
|
3388
|
+
def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterable[Instance]:
|
|
3362
3389
|
"""
|
|
3363
3390
|
Enumerate instances of the given random variables.
|
|
3364
3391
|
|
|
@@ -3377,7 +3404,7 @@ def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterator[Instance]
|
|
|
3377
3404
|
return _combos_ranges(shape, flip=not flip)
|
|
3378
3405
|
|
|
3379
3406
|
|
|
3380
|
-
def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) ->
|
|
3407
|
+
def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) -> Iterable[Sequence[Indicator]]:
|
|
3381
3408
|
"""
|
|
3382
3409
|
Enumerate instances of the given random variables.
|
|
3383
3410
|
|
|
@@ -3492,3 +3519,17 @@ def _normalise_potential_function(
|
|
|
3492
3519
|
total = group_sum[group]
|
|
3493
3520
|
if total > 0:
|
|
3494
3521
|
function.set_param_value(param_idx, param_value / total)
|
|
3522
|
+
|
|
3523
|
+
|
|
3524
|
+
_CLEAN_CHARS: Set[str] = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-+~?.')
|
|
3525
|
+
|
|
3526
|
+
|
|
3527
|
+
def _clean_str(s) -> str:
|
|
3528
|
+
"""
|
|
3529
|
+
Quote a string if empty or not all characters are in _CLEAN_CHARS.
|
|
3530
|
+
"""
|
|
3531
|
+
s = str(s)
|
|
3532
|
+
if len(s) == 0 or not all(c in _CLEAN_CHARS for c in s):
|
|
3533
|
+
return repr(s)
|
|
3534
|
+
else:
|
|
3535
|
+
return s
|
ck/pgm_circuit/pgm_circuit.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
from typing import Sequence, List, Dict
|
|
3
3
|
|
|
4
|
-
from ck.circuit import CircuitNode
|
|
4
|
+
from ck.circuit import CircuitNode, Circuit
|
|
5
5
|
from ck.pgm import RandomVariable, Indicator
|
|
6
6
|
from ck.pgm_circuit.slot_map import SlotMap, SlotKey
|
|
7
7
|
from ck.utils.np_extras import NDArray
|
|
@@ -46,19 +46,21 @@ class PGMCircuit:
|
|
|
46
46
|
|
|
47
47
|
def dump(self, *, prefix: str = '', indent: str = ' ') -> None:
|
|
48
48
|
"""
|
|
49
|
-
Print a dump of the
|
|
49
|
+
Print a dump of the circuit.
|
|
50
50
|
This is intended for debugging and demonstration purposes.
|
|
51
51
|
|
|
52
|
-
Each cluster is printed as: {separator rvs} | {non-separator rvs}.
|
|
53
|
-
|
|
54
52
|
Args:
|
|
55
53
|
prefix: optional prefix for indenting all lines.
|
|
56
54
|
indent: additional prefix to use for extra indentation.
|
|
57
55
|
"""
|
|
58
|
-
|
|
59
|
-
circuit
|
|
56
|
+
|
|
57
|
+
# We infer names for the circuit variables, either as an indicator or as a parameter.
|
|
58
|
+
# The `var_names` will be passed to `circuit.dump`.
|
|
59
|
+
|
|
60
|
+
circuit: Circuit = self.circuit_top.circuit
|
|
60
61
|
var_names: List[str] = [''] * circuit.number_of_vars
|
|
61
62
|
|
|
63
|
+
# Name the circuit variables that are indicators
|
|
62
64
|
rvs_by_idx: Dict[int, RandomVariable] = {rv.idx: rv for rv in self.rvs}
|
|
63
65
|
slot_key: SlotKey
|
|
64
66
|
slot: int
|
|
@@ -68,8 +70,10 @@ class PGMCircuit:
|
|
|
68
70
|
state_idx = slot_key.state_idx
|
|
69
71
|
var_names[slot] = f'{rv.name!r}[{state_idx}] {rv.states[state_idx]!r}'
|
|
70
72
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
+
# Name the circuit variables that are parameters
|
|
74
|
+
for i, param_value in enumerate(self.parameter_values):
|
|
75
|
+
slot = i + self.number_of_indicators
|
|
76
|
+
var_names[slot] = f'param[{i}] = {param_value}'
|
|
73
77
|
|
|
74
|
-
#
|
|
78
|
+
# Dump the circuit
|
|
75
79
|
circuit.dump(prefix=prefix, indent=indent, var_names=var_names)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Tuple,
|
|
1
|
+
from typing import Tuple, Sequence, Dict, Iterable
|
|
2
2
|
|
|
3
3
|
from ck.pgm import RandomVariable, rv_instances, Instance, rv_instances_as_indicators, Indicator, ParamId
|
|
4
4
|
from ck.pgm_circuit.slot_map import SlotMap, SlotKey
|
|
@@ -30,11 +30,13 @@ class ProgramWithSlotmap:
|
|
|
30
30
|
has a length and rv[i] is a unique 'indicator' across all rvs.
|
|
31
31
|
precondition: conditions on rvs that are compiled into the program.
|
|
32
32
|
|
|
33
|
+
Raises:
|
|
34
|
+
ValueError: if rvs contains duplicates.
|
|
33
35
|
"""
|
|
34
36
|
self._program_buffer: ProgramBuffer = program_buffer
|
|
35
37
|
self._slot_map: SlotMap = slot_map
|
|
36
38
|
self._rvs: Tuple[RandomVariable, ...] = tuple(rvs)
|
|
37
|
-
self._precondition:
|
|
39
|
+
self._precondition: Tuple[Indicator, ...] = tuple(precondition)
|
|
38
40
|
|
|
39
41
|
if len(rvs) != len(set(rv.idx for rv in rvs)):
|
|
40
42
|
raise ValueError('duplicate random variables provided')
|
|
@@ -67,7 +69,7 @@ class ProgramWithSlotmap:
|
|
|
67
69
|
def slot_map(self) -> SlotMap:
|
|
68
70
|
return self._slot_map
|
|
69
71
|
|
|
70
|
-
def instances(self, flip: bool = False) ->
|
|
72
|
+
def instances(self, flip: bool = False) -> Iterable[Instance]:
|
|
71
73
|
"""
|
|
72
74
|
Enumerate instances of the random variables.
|
|
73
75
|
|
|
@@ -84,7 +86,7 @@ class ProgramWithSlotmap:
|
|
|
84
86
|
"""
|
|
85
87
|
return rv_instances(*self._rvs, flip=flip)
|
|
86
88
|
|
|
87
|
-
def instances_as_indicators(self, flip: bool = False) ->
|
|
89
|
+
def instances_as_indicators(self, flip: bool = False) -> Iterable[Sequence[Indicator]]:
|
|
88
90
|
"""
|
|
89
91
|
Enumerate instances of the random variables.
|
|
90
92
|
|
ck/pgm_compiler/ace/ace.py
CHANGED
|
@@ -5,7 +5,9 @@ from dataclasses import dataclass
|
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from typing import Optional, List, Tuple
|
|
7
7
|
|
|
8
|
-
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from ck.circuit import CircuitNode, Circuit
|
|
9
11
|
from ck.in_out.parse_ace_lmap import read_lmap, LiteralMap
|
|
10
12
|
from ck.in_out.parse_ace_nnf import read_nnf_with_literal_map
|
|
11
13
|
from ck.in_out.render_net import render_bayesian_network
|
|
@@ -58,6 +60,22 @@ def compile_pgm(
|
|
|
58
60
|
if check_is_bayesian_network and not pgm.check_is_bayesian_network():
|
|
59
61
|
raise ValueError('the given PGM is not a Bayesian network')
|
|
60
62
|
|
|
63
|
+
# ACE cannot deal with the empty PGM even though it is a valid Bayesian network
|
|
64
|
+
if pgm.number_of_factors == 0:
|
|
65
|
+
circuit = Circuit()
|
|
66
|
+
circuit.new_vars(pgm.number_of_indicators)
|
|
67
|
+
parameter_values = np.array([], dtype=np.float64)
|
|
68
|
+
slot_map = {indicator: i for i, indicator in enumerate(pgm.indicators)}
|
|
69
|
+
return PGMCircuit(
|
|
70
|
+
rvs=pgm.rvs,
|
|
71
|
+
conditions=(),
|
|
72
|
+
circuit_top=circuit.const(1),
|
|
73
|
+
number_of_indicators=pgm.number_of_indicators,
|
|
74
|
+
number_of_parameters=0,
|
|
75
|
+
slot_map=slot_map,
|
|
76
|
+
parameter_values=parameter_values,
|
|
77
|
+
)
|
|
78
|
+
|
|
61
79
|
java: str
|
|
62
80
|
classpath_separator: str
|
|
63
81
|
java, classpath_separator = _find_java()
|
|
@@ -69,7 +87,7 @@ def compile_pgm(
|
|
|
69
87
|
)
|
|
70
88
|
ace_cmd: List[str] = [
|
|
71
89
|
java,
|
|
72
|
-
|
|
90
|
+
'-cp',
|
|
73
91
|
class_path,
|
|
74
92
|
f'-DACEC2D={files.c2d}',
|
|
75
93
|
f'-Xmx{int(m_bytes)}m',
|
|
@@ -83,7 +101,7 @@ def compile_pgm(
|
|
|
83
101
|
node_names: List[str] = render_bayesian_network(pgm, file, check_structure_bayesian=False)
|
|
84
102
|
|
|
85
103
|
# Run Ace
|
|
86
|
-
ace_result = subprocess.run(ace_cmd, capture_output=True, text=True)
|
|
104
|
+
ace_result: subprocess.CompletedProcess = subprocess.run(ace_cmd, capture_output=True, text=True)
|
|
87
105
|
if print_output:
|
|
88
106
|
print(ace_result.stdout)
|
|
89
107
|
print(ace_result.stderr)
|
|
@@ -126,6 +144,29 @@ def compile_pgm(
|
|
|
126
144
|
)
|
|
127
145
|
|
|
128
146
|
|
|
147
|
+
def ace_available(
|
|
148
|
+
ace_dir: Optional[Path | str] = None,
|
|
149
|
+
jar_dir: Optional[Path | str] = None,
|
|
150
|
+
) -> bool:
|
|
151
|
+
"""
|
|
152
|
+
Returns:
|
|
153
|
+
True if it looks like ACE is available, False otherwise.
|
|
154
|
+
ACE is available if ACE files are in the default location and Java is available.
|
|
155
|
+
"""
|
|
156
|
+
try:
|
|
157
|
+
java: str
|
|
158
|
+
java, _ = _find_java()
|
|
159
|
+
_: _AceFiles = _find_ace_files(ace_dir, jar_dir)
|
|
160
|
+
|
|
161
|
+
java_cmd: List[str] = [java, '--version',]
|
|
162
|
+
java_result: subprocess.CompletedProcess = subprocess.run(java_cmd, capture_output=True, text=True)
|
|
163
|
+
|
|
164
|
+
return java_result.returncode == 0
|
|
165
|
+
|
|
166
|
+
except RuntimeError:
|
|
167
|
+
return False
|
|
168
|
+
|
|
169
|
+
|
|
129
170
|
def copy_ace_to_default_location(
|
|
130
171
|
ace_dir: Path | str,
|
|
131
172
|
jar_dir: Optional[Path | str] = None,
|
|
@@ -179,6 +220,9 @@ def _find_java() -> Tuple[str, str]:
|
|
|
179
220
|
|
|
180
221
|
Returns:
|
|
181
222
|
(java, classpath_separator)
|
|
223
|
+
|
|
224
|
+
Raises:
|
|
225
|
+
RuntimeError: if not found, including a helpful message.
|
|
182
226
|
"""
|
|
183
227
|
if sys.platform == 'win32':
|
|
184
228
|
return 'java.exe', ';'
|
|
@@ -198,7 +242,7 @@ def _find_ace_files(
|
|
|
198
242
|
Look for the needed Ace files.
|
|
199
243
|
|
|
200
244
|
Raises:
|
|
201
|
-
RuntimeError: if not found, including a helpful message
|
|
245
|
+
RuntimeError: if not found, including a helpful message.
|
|
202
246
|
"""
|
|
203
247
|
ace_dir: Path = default_ace_location() if ace_dir is None else Path(ace_dir)
|
|
204
248
|
jar_dir: Path = ace_dir if jar_dir is None else Path(jar_dir)
|
|
@@ -12,19 +12,21 @@ from ck.pgm_compiler.support.join_tree import *
|
|
|
12
12
|
|
|
13
13
|
_NEG_INF = float('-inf')
|
|
14
14
|
|
|
15
|
+
DEFAULT_PRODUCT_SEARCH_LIMIT: int = 1000
|
|
16
|
+
|
|
15
17
|
|
|
16
18
|
def compile_pgm(
|
|
17
19
|
pgm: PGM,
|
|
18
20
|
const_parameters: bool = True,
|
|
19
21
|
*,
|
|
20
22
|
algorithm: JoinTreeAlgorithm = MIN_FILL_THEN_DEGREE,
|
|
21
|
-
limit_product_tree_search: int =
|
|
23
|
+
limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
|
|
22
24
|
pre_prune_factor_tables: bool = True,
|
|
23
25
|
) -> PGMCircuit:
|
|
24
26
|
"""
|
|
25
27
|
Compile the PGM to an arithmetic circuit, using factor elimination.
|
|
26
28
|
|
|
27
|
-
When forming the product of factors
|
|
29
|
+
When forming the product of factors within join tree nodes,
|
|
28
30
|
this method searches all practical binary trees for forming products,
|
|
29
31
|
up to the given limit, `limit_product_tree_search`. The minimum is 1.
|
|
30
32
|
|
|
@@ -57,7 +59,7 @@ def compile_pgm_best_jointree(
|
|
|
57
59
|
pgm: PGM,
|
|
58
60
|
const_parameters: bool = True,
|
|
59
61
|
*,
|
|
60
|
-
limit_product_tree_search: int =
|
|
62
|
+
limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
|
|
61
63
|
pre_prune_factor_tables: bool = True,
|
|
62
64
|
) -> PGMCircuit:
|
|
63
65
|
"""
|
|
@@ -111,7 +113,7 @@ def compile_pgm_best_jointree(
|
|
|
111
113
|
def join_tree_to_circuit(
|
|
112
114
|
join_tree: JoinTree,
|
|
113
115
|
const_parameters: bool = True,
|
|
114
|
-
limit_product_tree_search: int =
|
|
116
|
+
limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
|
|
115
117
|
pre_prune_factor_tables: bool = True,
|
|
116
118
|
) -> PGMCircuit:
|
|
117
119
|
"""
|
|
@@ -52,10 +52,15 @@ def compile_pgm(
|
|
|
52
52
|
multiply_indicators=True,
|
|
53
53
|
pre_prune_factor_tables=pre_prune_factor_tables,
|
|
54
54
|
)
|
|
55
|
-
dtree: _DTree = _make_dtree(elimination_order, factor_tables)
|
|
56
55
|
|
|
57
|
-
|
|
58
|
-
|
|
56
|
+
if pgm.number_of_factors == 0:
|
|
57
|
+
# Deal with special case: no factors
|
|
58
|
+
top: CircuitNode = factor_tables.circuit.const(1)
|
|
59
|
+
else:
|
|
60
|
+
dtree: _DTree = _make_dtree(elimination_order, factor_tables)
|
|
61
|
+
states: List[Sequence[int]] = [tuple(range(len(rv))) for rv in pgm.rvs]
|
|
62
|
+
top: CircuitNode = dtree.make_circuit(states, factor_tables.circuit)
|
|
63
|
+
|
|
59
64
|
top.circuit.remove_unreachable_op_nodes(top)
|
|
60
65
|
|
|
61
66
|
return PGMCircuit(
|