compiled-knowledge 4.0.0a23__cp312-cp312-win32.whl → 4.0.0a25__cp312-cp312-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/_circuit_cy.c +1 -1
- ck/circuit/_circuit_cy.cp312-win32.pyd +0 -0
- ck/circuit/tmp_const.py +5 -4
- ck/circuit_compiler/circuit_compiler.py +3 -2
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +152 -152
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win32.pyd +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win32.pyd +0 -0
- ck/circuit_compiler/support/llvm_ir_function.py +4 -4
- ck/example/diamond_square.py +3 -1
- ck/example/triangle_square.py +3 -1
- ck/example/truss.py +3 -1
- ck/in_out/parse_net.py +21 -19
- ck/in_out/parser_utils.py +7 -3
- ck/pgm.py +146 -139
- ck/pgm_circuit/mpe_program.py +3 -4
- ck/pgm_circuit/pgm_circuit.py +27 -18
- ck/pgm_circuit/program_with_slotmap.py +4 -1
- ck/pgm_compiler/pgm_compiler.py +1 -1
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win32.pyd +0 -0
- ck/pgm_compiler/support/join_tree.py +3 -3
- ck/probability/empirical_probability_space.py +4 -3
- ck/probability/pgm_probability_space.py +7 -3
- ck/probability/probability_space.py +20 -15
- ck/program/raw_program.py +23 -16
- ck/sampling/sampler_support.py +7 -5
- ck/utils/iter_extras.py +3 -2
- ck/utils/local_config.py +16 -8
- {compiled_knowledge-4.0.0a23.dist-info → compiled_knowledge-4.0.0a25.dist-info}/METADATA +1 -1
- {compiled_knowledge-4.0.0a23.dist-info → compiled_knowledge-4.0.0a25.dist-info}/RECORD +34 -34
- {compiled_knowledge-4.0.0a23.dist-info → compiled_knowledge-4.0.0a25.dist-info}/WHEEL +0 -0
- {compiled_knowledge-4.0.0a23.dist-info → compiled_knowledge-4.0.0a25.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a23.dist-info → compiled_knowledge-4.0.0a25.dist-info}/top_level.txt +0 -0
ck/pgm.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
|
1
|
-
"""
|
|
2
|
-
For more documentation on this module, refer to the Jupyter notebook docs/4_PGM_advanced.ipynb.
|
|
3
|
-
"""
|
|
4
1
|
from __future__ import annotations
|
|
5
2
|
|
|
6
3
|
import math
|
|
@@ -8,7 +5,7 @@ from abc import ABC, abstractmethod
|
|
|
8
5
|
from dataclasses import dataclass
|
|
9
6
|
from itertools import repeat as _repeat
|
|
10
7
|
from typing import Sequence, Tuple, Dict, Optional, overload, Set, Iterable, List, Union, Callable, \
|
|
11
|
-
Collection, Any, Iterator
|
|
8
|
+
Collection, Any, Iterator, TypeAlias
|
|
12
9
|
|
|
13
10
|
import numpy as np
|
|
14
11
|
|
|
@@ -17,21 +14,33 @@ from ck.utils.iter_extras import (
|
|
|
17
14
|
)
|
|
18
15
|
from ck.utils.np_extras import NDArrayFloat64, NDArrayUInt8
|
|
19
16
|
|
|
20
|
-
|
|
21
|
-
|
|
17
|
+
State: TypeAlias = Union[int, str, bool, float, None]
|
|
18
|
+
"""
|
|
19
|
+
The type for a possible state of a random variable.
|
|
20
|
+
"""
|
|
22
21
|
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
22
|
+
Instance: TypeAlias = Sequence[int]
|
|
23
|
+
"""
|
|
24
|
+
An instance (of a sequence of random variables) is a sequence of integers
|
|
25
|
+
that are state indexes, co-indexed with a known sequence of random variables.
|
|
26
|
+
"""
|
|
26
27
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
28
|
+
Key: TypeAlias = Union[Instance, int]
|
|
29
|
+
"""
|
|
30
|
+
A key identifies an instance, either as an instance itself or a
|
|
31
|
+
single integer, representing an instance with one dimension.
|
|
32
|
+
"""
|
|
30
33
|
|
|
31
|
-
|
|
32
|
-
|
|
34
|
+
Shape: TypeAlias = Sequence[int]
|
|
35
|
+
"""
|
|
36
|
+
The type for the "shape" of a sequence of random variables.
|
|
37
|
+
That is, the shape of (rv1, rv2, rv3) is (len(rv1), len(rv2), len(rv3)).
|
|
38
|
+
"""
|
|
33
39
|
|
|
34
|
-
|
|
40
|
+
DEFAULT_CPT_TOLERANCE: float = 0.000001
|
|
41
|
+
"""
|
|
42
|
+
A tolerance when checking CPT distributions sum to one (or zero).
|
|
43
|
+
"""
|
|
35
44
|
|
|
36
45
|
|
|
37
46
|
class PGM:
|
|
@@ -39,11 +48,9 @@ class PGM:
|
|
|
39
48
|
A probabilistic graphical model (PGM) represents a joint probability distribution over
|
|
40
49
|
a set of random variables. Specifically, a PGM is a factor graph with discrete random variables.
|
|
41
50
|
|
|
42
|
-
Add a random variable to a PGM, pgm
|
|
43
|
-
|
|
44
|
-
Add a factor to the PGM, pgm, using `factor = pgm.new_factor(...)`.
|
|
51
|
+
Add a random variable to a PGM, `pgm`, using `rv = pgm.new_rv(...)`.
|
|
45
52
|
|
|
46
|
-
|
|
53
|
+
Add a factor to the PGM, `pgm`, using `factor = pgm.new_factor(...)`.
|
|
47
54
|
"""
|
|
48
55
|
|
|
49
56
|
def __init__(self, name: Optional[str] = None):
|
|
@@ -206,14 +213,17 @@ class PGM:
|
|
|
206
213
|
The returned random variable will have an `idx` equal to the value of
|
|
207
214
|
`self.number_of_rvs` just prior to adding the new random variable.
|
|
208
215
|
|
|
216
|
+
The states of the random variable can be specified either as an integer
|
|
217
|
+
representing the number of states, or as a sequence of state values. If a
|
|
218
|
+
single integer, `n`, is provided then the states will be: 0, 1, ..., n-1.
|
|
219
|
+
If a sequence of states are provided then the states must be unique.
|
|
220
|
+
|
|
209
221
|
Assumes:
|
|
210
222
|
Provided states contain no duplicates.
|
|
211
223
|
|
|
212
224
|
Args:
|
|
213
225
|
name: a name for the random variable.
|
|
214
|
-
states: either
|
|
215
|
-
single integer, `n`, is provided then the states will be 0, 1, ..., n-1.
|
|
216
|
-
If a sequence of states are provided then the states must be unique.
|
|
226
|
+
states: either the number of states or a sequence of state values.
|
|
217
227
|
|
|
218
228
|
Returns:
|
|
219
229
|
a RandomVariable object belonging to this PGM.
|
|
@@ -233,10 +243,11 @@ class PGM:
|
|
|
233
243
|
|
|
234
244
|
Assumes:
|
|
235
245
|
The given random variables all belong to this PGM.
|
|
246
|
+
|
|
236
247
|
The random variables contain no duplicates.
|
|
237
248
|
|
|
238
249
|
Args:
|
|
239
|
-
|
|
250
|
+
rvs: the random variables.
|
|
240
251
|
|
|
241
252
|
Returns:
|
|
242
253
|
a Factor object belonging to this PGM.
|
|
@@ -328,17 +339,18 @@ class PGM:
|
|
|
328
339
|
*input_rvs: RandomVariable
|
|
329
340
|
) -> Factor:
|
|
330
341
|
"""
|
|
331
|
-
Add a sparse 0/1 factor to this PGM representing
|
|
332
|
-
|
|
333
|
-
|
|
342
|
+
Add a sparse 0/1 factor to this PGM representing `result_rv == function(*rvs)`.
|
|
343
|
+
That is::
|
|
344
|
+
|
|
334
345
|
factor[result_s, *input_s] = 1, if result_s == function(*input_s);
|
|
335
346
|
= 0, otherwise.
|
|
347
|
+
|
|
336
348
|
Args:
|
|
337
349
|
function: a function from state indexes of the input random variables to a state index
|
|
338
350
|
of the result random variable. The function should take the same number of arguments
|
|
339
351
|
as `input_rvs` and return a state index for `result_rv`.
|
|
340
352
|
result_rv: the random variable defining result values.
|
|
341
|
-
|
|
353
|
+
input_rvs: the random variables defining input values.
|
|
342
354
|
|
|
343
355
|
Returns:
|
|
344
356
|
a Factor object belonging to this PGM, with a configured sparse potential function.
|
|
@@ -370,16 +382,17 @@ class PGM:
|
|
|
370
382
|
"""
|
|
371
383
|
Render indicators as a string.
|
|
372
384
|
|
|
373
|
-
For example
|
|
385
|
+
For example::
|
|
374
386
|
pgm = PGM()
|
|
375
387
|
a = pgm.new_rv('A', ('x', 'y', 'z'))
|
|
376
388
|
b = pgm.new_rv('B', (3, 5))
|
|
377
389
|
print(pgm.indicator_str(a[0], b[1], a[2]))
|
|
378
|
-
|
|
390
|
+
|
|
391
|
+
will print::
|
|
379
392
|
A=x, B=5, A=z
|
|
380
393
|
|
|
381
394
|
Args:
|
|
382
|
-
|
|
395
|
+
indicators: the indicators to render.
|
|
383
396
|
sep: the separator to use between the random variable and its state.
|
|
384
397
|
delim: the delimiter to used when rendering multiple indicators.
|
|
385
398
|
|
|
@@ -398,16 +411,17 @@ class PGM:
|
|
|
398
411
|
"""
|
|
399
412
|
Render indicators as a string, grouping indicators by random variable.
|
|
400
413
|
|
|
401
|
-
For example
|
|
414
|
+
For example::
|
|
402
415
|
pgm = PGM()
|
|
403
416
|
a = pgm.new_rv('A', ('x', 'y', 'z'))
|
|
404
417
|
b = pgm.new_rv('B', (3, 5))
|
|
405
418
|
print(pgm.condition_str(a[0], b[1], a[2]))
|
|
406
|
-
|
|
419
|
+
|
|
420
|
+
will print::
|
|
407
421
|
A in {x, z}, B=5
|
|
408
422
|
|
|
409
423
|
Args:
|
|
410
|
-
|
|
424
|
+
indicators: the indicators to render.
|
|
411
425
|
Return:
|
|
412
426
|
a string representation of the given indicators, as a condition.
|
|
413
427
|
"""
|
|
@@ -587,7 +601,7 @@ class PGM:
|
|
|
587
601
|
# All tests passed
|
|
588
602
|
return True
|
|
589
603
|
|
|
590
|
-
def factors_are_cpts(self, tolerance: float =
|
|
604
|
+
def factors_are_cpts(self, tolerance: float = DEFAULT_CPT_TOLERANCE) -> bool:
|
|
591
605
|
"""
|
|
592
606
|
Are all factor potential functions set with parameters values
|
|
593
607
|
conforming to Conditional Probability Tables.
|
|
@@ -603,7 +617,7 @@ class PGM:
|
|
|
603
617
|
"""
|
|
604
618
|
return all(function.is_cpt(tolerance) for function in self.functions)
|
|
605
619
|
|
|
606
|
-
def check_is_bayesian_network(self, tolerance: float =
|
|
620
|
+
def check_is_bayesian_network(self, tolerance: float = DEFAULT_CPT_TOLERANCE) -> bool:
|
|
607
621
|
"""
|
|
608
622
|
Is this PGM a Bayesian network.
|
|
609
623
|
|
|
@@ -648,7 +662,7 @@ class PGM:
|
|
|
648
662
|
This method has the same semantics as `ProbabilitySpace.wmc` without conditioning.
|
|
649
663
|
|
|
650
664
|
Warning:
|
|
651
|
-
this is potentially computationally expensive as it
|
|
665
|
+
this is potentially computationally expensive as it marginalises random
|
|
652
666
|
variables not mentioned in the given indicators.
|
|
653
667
|
|
|
654
668
|
Args:
|
|
@@ -658,29 +672,11 @@ class PGM:
|
|
|
658
672
|
the product of factors, conditioned on the given instance. This is the
|
|
659
673
|
computed value of the PGM, conditioned on the given instance.
|
|
660
674
|
"""
|
|
661
|
-
#
|
|
662
|
-
#
|
|
663
|
-
#
|
|
664
|
-
#
|
|
665
|
-
#
|
|
666
|
-
# # Collect rvs not mentioned - to marginalise
|
|
667
|
-
# for rv, rv_filter in zip(self.rvs, inst_filter):
|
|
668
|
-
# if len(rv_filter) == 0:
|
|
669
|
-
# rv_filter.update(rv.state_range())
|
|
670
|
-
#
|
|
671
|
-
# def _sum_inst(_instance: Instance) -> bool:
|
|
672
|
-
# return all(
|
|
673
|
-
# (_state in _rv_filter)
|
|
674
|
-
# for _state, _rv_filter in zip(_instance, inst_filter)
|
|
675
|
-
# )
|
|
676
|
-
#
|
|
677
|
-
# # Accumulate the result
|
|
678
|
-
# sum_value = 0
|
|
679
|
-
# for instance in self.instances():
|
|
680
|
-
# if _sum_inst(instance):
|
|
681
|
-
# sum_value += self.value_product(instance)
|
|
682
|
-
#
|
|
683
|
-
# return sum_value
|
|
675
|
+
# Rather than naively checking all possible states of the PGM random
|
|
676
|
+
# variables, this method works to define the state space that should
|
|
677
|
+
# be summed over, based on the given indicators. Thus, if the given
|
|
678
|
+
# indicators constrain the state space to a small number of possibilities,
|
|
679
|
+
# then the sum is only performed over those possibilities.
|
|
684
680
|
|
|
685
681
|
# Work out the space to sum over
|
|
686
682
|
sum_space_set: List[Optional[Set[int]]] = [None] * self.number_of_rvs
|
|
@@ -719,11 +715,10 @@ class PGM:
|
|
|
719
715
|
precision: a limit on the render precision of floating point numbers.
|
|
720
716
|
max_state_digits: a limit on the number of digits when showing number of states as an integer.
|
|
721
717
|
"""
|
|
722
|
-
# limit
|
|
718
|
+
# Determine a limit to precision when displaying number of states
|
|
723
719
|
num_states: int = self.number_of_states
|
|
724
720
|
number_of_parameters = sum(function.number_of_parameters for function in self.functions)
|
|
725
721
|
number_of_nz_parameters = sum(function.number_of_parameters for function in self.non_zero_functions)
|
|
726
|
-
|
|
727
722
|
if math.log10(num_states) > max_state_digits:
|
|
728
723
|
log_states = math.log10(num_states)
|
|
729
724
|
exp = int(log_states)
|
|
@@ -731,7 +726,6 @@ class PGM:
|
|
|
731
726
|
num_states_str = f'{man:,.{precision}f}e+{exp}'
|
|
732
727
|
else:
|
|
733
728
|
num_states_str = f'{num_states:,}'
|
|
734
|
-
|
|
735
729
|
log_2_num_states = math.log2(num_states)
|
|
736
730
|
if (
|
|
737
731
|
log_2_num_states == 0
|
|
@@ -820,9 +814,9 @@ class PGM:
|
|
|
820
814
|
|
|
821
815
|
For a factor `f` the value of states[f.idx] is the search state.
|
|
822
816
|
Specifically:
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
817
|
+
state 0 => the factor has not been seen yet,
|
|
818
|
+
state 1 => the factor is seen but not fully processed,
|
|
819
|
+
state 2 => the factor is fully processed.
|
|
826
820
|
|
|
827
821
|
Args:
|
|
828
822
|
factor: the current Factor being checked.
|
|
@@ -942,9 +936,9 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
942
936
|
in the random variable's PGM list of random variables.
|
|
943
937
|
|
|
944
938
|
A random variable behaves like a sequence of Indicators, where each indicator represents a random
|
|
945
|
-
variable being in a particular state. Specifically for a random variable rv, len(rv) is the
|
|
939
|
+
variable being in a particular state. Specifically for a random variable rv, `len(rv)` is the
|
|
946
940
|
number of states of the random variable and rv[i] is the Indicators representing that
|
|
947
|
-
rv is in the ith state. When sliced, the result is a tuple, i.e. rv[1:3] = (rv[1], rv[2])
|
|
941
|
+
rv is in the ith state. When sliced, the result is a tuple, i.e. `rv[1:3] = (rv[1], rv[2])`.
|
|
948
942
|
|
|
949
943
|
A RandomVariable has a name. This is for human convenience and has no functional purpose
|
|
950
944
|
within a PGM.
|
|
@@ -954,15 +948,18 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
954
948
|
"""
|
|
955
949
|
Create a new random variable, in the given PGM.
|
|
956
950
|
|
|
951
|
+
The states of the random variable can be specified either as an integer
|
|
952
|
+
representing the number of states, or as a sequence of state values. If a
|
|
953
|
+
single integer, `n`, is provided then the states will be: 0, 1, ..., n-1.
|
|
954
|
+
If a sequence of states are provided then the states must be unique.
|
|
955
|
+
|
|
957
956
|
Assumes:
|
|
958
957
|
Provided states contain no duplicates.
|
|
959
958
|
|
|
960
959
|
Args:
|
|
961
960
|
pgm: the PGM that the random variable will belong to.
|
|
962
961
|
name: a name for the random variable.
|
|
963
|
-
states: either
|
|
964
|
-
single integer, `n`, is provided then the states will be 0, 1, ..., n-1.
|
|
965
|
-
If a sequence of states are provided then the states must be unique.
|
|
962
|
+
states: either the number of states or a sequence of state values.
|
|
966
963
|
"""
|
|
967
964
|
self._pgm: PGM = pgm
|
|
968
965
|
self._name: str = name
|
|
@@ -1040,7 +1037,7 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
1040
1037
|
|
|
1041
1038
|
def state_range(self) -> Iterable[int]:
|
|
1042
1039
|
"""
|
|
1043
|
-
Iterate over the state indexes of this random variable, in order.
|
|
1040
|
+
Iterate over the state indexes of this random variable, in ascending order.
|
|
1044
1041
|
|
|
1045
1042
|
Returns:
|
|
1046
1043
|
range(len(self))
|
|
@@ -1122,18 +1119,19 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
1122
1119
|
|
|
1123
1120
|
def __eq__(self, other) -> bool:
|
|
1124
1121
|
"""
|
|
1125
|
-
Two random
|
|
1122
|
+
Two random variables are equal if they are the same object.
|
|
1126
1123
|
"""
|
|
1127
1124
|
return self is other
|
|
1128
1125
|
|
|
1129
1126
|
def equivalent(self, other: RandomVariable | Sequence[Indicator]) -> bool:
|
|
1130
1127
|
"""
|
|
1131
|
-
Two random variable are equivalent if their indicators are equal.
|
|
1132
|
-
random variable indexes and state indexes are checked.
|
|
1133
|
-
|
|
1128
|
+
Two random variable are equivalent if their indicators are equal.
|
|
1129
|
+
Only random variable indexes and state indexes are checked.
|
|
1134
1130
|
This ignores the names of the random variable and the names of their states.
|
|
1135
|
-
|
|
1136
|
-
|
|
1131
|
+
|
|
1132
|
+
Slot maps operate across `equivalent` random variables.
|
|
1133
|
+
This means indicators of equivalent random variables will work
|
|
1134
|
+
correctly in slot maps, even if from different PGMs.
|
|
1137
1135
|
|
|
1138
1136
|
Args:
|
|
1139
1137
|
other: either a random variable or a sequence of Indicators.
|
|
@@ -1181,7 +1179,8 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
1181
1179
|
"""
|
|
1182
1180
|
Returns the first index of `value`.
|
|
1183
1181
|
Raises ValueError if the value is not present.
|
|
1184
|
-
|
|
1182
|
+
|
|
1183
|
+
This method is contracted by `Sequence[Indicator]`.
|
|
1185
1184
|
|
|
1186
1185
|
Warning:
|
|
1187
1186
|
This method is different to `self.idx`.
|
|
@@ -1198,7 +1197,10 @@ class RandomVariable(Sequence[Indicator]):
|
|
|
1198
1197
|
def count(self, value: Any) -> int:
|
|
1199
1198
|
"""
|
|
1200
1199
|
Returns the number of occurrences of `value`.
|
|
1201
|
-
|
|
1200
|
+
That is, if `value` is an indicator of this random variable
|
|
1201
|
+
then 1 is returned, otherwise 0 is returned.
|
|
1202
|
+
|
|
1203
|
+
This method is contracted by `Sequence[Indicator]`.
|
|
1202
1204
|
"""
|
|
1203
1205
|
if isinstance(value, Indicator):
|
|
1204
1206
|
if value.rv_idx == self._idx and 0 <= value.state_idx < len(self):
|
|
@@ -1210,25 +1212,25 @@ class RVMap(Sequence[RandomVariable]):
|
|
|
1210
1212
|
"""
|
|
1211
1213
|
Wrap a PGM to provide convenient access to PGM random variables.
|
|
1212
1214
|
|
|
1213
|
-
An RVMap of a PGM behaves
|
|
1214
|
-
|
|
1215
|
+
An RVMap of a PGM behaves like the PGM `rvs` property (sequence of
|
|
1216
|
+
RandomVariable objects), with additional access methods for the PGM's
|
|
1217
|
+
random variables.
|
|
1215
1218
|
|
|
1216
1219
|
If the underlying PGM is updated, then the RVMap will automatically update.
|
|
1217
1220
|
|
|
1218
|
-
|
|
1219
|
-
of each random variable.
|
|
1221
|
+
In addition to accessing a random variable by its index, an RVMap enables
|
|
1222
|
+
access to the PGM random variable via the name of each random variable.
|
|
1223
|
+
|
|
1224
|
+
For example, if `pgm.rvs[1]` is a random variable named `xray`, then::
|
|
1220
1225
|
|
|
1221
|
-
|
|
1222
|
-
```
|
|
1223
|
-
rvs = RVMap(pgm)
|
|
1226
|
+
rvs = RVMap(pgm)
|
|
1224
1227
|
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
```
|
|
1228
|
+
# These all retrieve the same random variable object.
|
|
1229
|
+
xray = rvs[1]
|
|
1230
|
+
xray = rvs('xray')
|
|
1231
|
+
xray = rvs.xray
|
|
1230
1232
|
|
|
1231
|
-
To use an RVMap on a PGM, the variable names must be unique across the PGM.
|
|
1233
|
+
To use an RVMap on a PGM, the random variable names must be unique across the PGM.
|
|
1232
1234
|
"""
|
|
1233
1235
|
|
|
1234
1236
|
def __init__(self, pgm: PGM, ignore_case: bool = False):
|
|
@@ -1248,28 +1250,6 @@ class RVMap(Sequence[RandomVariable]):
|
|
|
1248
1250
|
# This may raise an exception.
|
|
1249
1251
|
_ = self._rv_map
|
|
1250
1252
|
|
|
1251
|
-
def _clean_name(self, name: str) -> str:
|
|
1252
|
-
"""
|
|
1253
|
-
Adjust the case of the given name as needed.
|
|
1254
|
-
"""
|
|
1255
|
-
return name.lower() if self._ignore_case else name
|
|
1256
|
-
|
|
1257
|
-
@property
|
|
1258
|
-
def _rv_map(self) -> Dict[str, RandomVariable]:
|
|
1259
|
-
"""
|
|
1260
|
-
Get the cached rv map, updating as needed if the PGM changed.
|
|
1261
|
-
Returns:
|
|
1262
|
-
a mapping from random variable name to random variable
|
|
1263
|
-
"""
|
|
1264
|
-
if len(self.__rv_map) != len(self._pgm.rvs):
|
|
1265
|
-
# There is a difference between the map and the PGM - create a new map.
|
|
1266
|
-
self.__rv_map = {self._clean_name(rv.name): rv for rv in self._pgm.rvs}
|
|
1267
|
-
if len(self.__rv_map) != len(self._pgm.rvs):
|
|
1268
|
-
raise RuntimeError(f'random variable names are not unique')
|
|
1269
|
-
if not self._reserved_names.isdisjoint(self.__rv_map.keys()):
|
|
1270
|
-
raise RuntimeError(f'random variable names clash with reserved names.')
|
|
1271
|
-
return self.__rv_map
|
|
1272
|
-
|
|
1273
1253
|
def new_rv(self, name: str, states: Union[int, Sequence[State]]) -> RandomVariable:
|
|
1274
1254
|
"""
|
|
1275
1255
|
As per `PGM.new_rv`.
|
|
@@ -1304,6 +1284,29 @@ class RVMap(Sequence[RandomVariable]):
|
|
|
1304
1284
|
def __getattr__(self, rv_name: str) -> RandomVariable:
|
|
1305
1285
|
return self(rv_name)
|
|
1306
1286
|
|
|
1287
|
+
@property
|
|
1288
|
+
def _rv_map(self) -> Dict[str, RandomVariable]:
|
|
1289
|
+
"""
|
|
1290
|
+
Get the cached random variable map, updating as needed if the PGM changed.
|
|
1291
|
+
|
|
1292
|
+
Returns:
|
|
1293
|
+
a mapping from random variable name to random variable
|
|
1294
|
+
"""
|
|
1295
|
+
if len(self.__rv_map) != len(self._pgm.rvs):
|
|
1296
|
+
# There is a difference between the map and the PGM - create a new map.
|
|
1297
|
+
self.__rv_map = {self._clean_name(rv.name): rv for rv in self._pgm.rvs}
|
|
1298
|
+
if len(self.__rv_map) != len(self._pgm.rvs):
|
|
1299
|
+
raise RuntimeError(f'random variable names are not unique')
|
|
1300
|
+
if not self._reserved_names.isdisjoint(self.__rv_map.keys()):
|
|
1301
|
+
raise RuntimeError(f'random variable names clash with reserved names.')
|
|
1302
|
+
return self.__rv_map
|
|
1303
|
+
|
|
1304
|
+
def _clean_name(self, name: str) -> str:
|
|
1305
|
+
"""
|
|
1306
|
+
Adjust the case of the given name as needed.
|
|
1307
|
+
"""
|
|
1308
|
+
return name.lower() if self._ignore_case else name
|
|
1309
|
+
|
|
1307
1310
|
|
|
1308
1311
|
class Factor:
|
|
1309
1312
|
"""
|
|
@@ -1532,7 +1535,7 @@ class Factor:
|
|
|
1532
1535
|
Set to the potential function to a new `ClausePotentialFunction` object.
|
|
1533
1536
|
|
|
1534
1537
|
Args:
|
|
1535
|
-
|
|
1538
|
+
key: defines the random variable states of the clause. The key is a sequence of
|
|
1536
1539
|
random variable state indexes, co-indexed with `Factor.rvs`.
|
|
1537
1540
|
|
|
1538
1541
|
Returns:
|
|
@@ -1544,7 +1547,7 @@ class Factor:
|
|
|
1544
1547
|
self._potential_function = ClausePotentialFunction(self, key)
|
|
1545
1548
|
return self._potential_function
|
|
1546
1549
|
|
|
1547
|
-
def set_cpt(self, tolerance: float =
|
|
1550
|
+
def set_cpt(self, tolerance: float = DEFAULT_CPT_TOLERANCE) -> CPTPotentialFunction:
|
|
1548
1551
|
"""
|
|
1549
1552
|
Set to the potential function to a new `CPTPotentialFunction` object.
|
|
1550
1553
|
|
|
@@ -1561,7 +1564,7 @@ class Factor:
|
|
|
1561
1564
|
return self._potential_function
|
|
1562
1565
|
|
|
1563
1566
|
|
|
1564
|
-
@dataclass(frozen=True, eq=True)
|
|
1567
|
+
@dataclass(frozen=True, eq=True, slots=True)
|
|
1565
1568
|
class ParamId:
|
|
1566
1569
|
"""
|
|
1567
1570
|
A ParamId identifies a parameter of a potential function.
|
|
@@ -1820,7 +1823,7 @@ class PotentialFunction(ABC):
|
|
|
1820
1823
|
"""
|
|
1821
1824
|
...
|
|
1822
1825
|
|
|
1823
|
-
def is_cpt(self, tolerance=
|
|
1826
|
+
def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
|
|
1824
1827
|
"""
|
|
1825
1828
|
Is the potential function set with parameters values conforming to a
|
|
1826
1829
|
Conditional Probability Table.
|
|
@@ -2028,7 +2031,7 @@ class ZeroPotentialFunction(PotentialFunction):
|
|
|
2028
2031
|
def param_idx(self, key: Key) -> int:
|
|
2029
2032
|
return _natural_key_idx(self._shape, key)
|
|
2030
2033
|
|
|
2031
|
-
def is_cpt(self, tolerance=
|
|
2034
|
+
def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
|
|
2032
2035
|
return True
|
|
2033
2036
|
|
|
2034
2037
|
|
|
@@ -2169,7 +2172,7 @@ class DensePotentialFunction(PotentialFunction):
|
|
|
2169
2172
|
"""
|
|
2170
2173
|
Set the values of the potential function using the given iterator.
|
|
2171
2174
|
|
|
2172
|
-
Mapping instances to
|
|
2175
|
+
Mapping instances to values is as follows:
|
|
2173
2176
|
Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
|
|
2174
2177
|
values[0] represents instance (0,0)
|
|
2175
2178
|
values[1] represents instance (0,1)
|
|
@@ -2214,7 +2217,7 @@ class DensePotentialFunction(PotentialFunction):
|
|
|
2214
2217
|
The order of values is the same as set_iter.
|
|
2215
2218
|
|
|
2216
2219
|
Args:
|
|
2217
|
-
|
|
2220
|
+
value: the values to use.
|
|
2218
2221
|
|
|
2219
2222
|
Returns:
|
|
2220
2223
|
self
|
|
@@ -2419,7 +2422,7 @@ class SparsePotentialFunction(PotentialFunction):
|
|
|
2419
2422
|
"""
|
|
2420
2423
|
Set the values of the potential function using the given iterator.
|
|
2421
2424
|
|
|
2422
|
-
Mapping instances to
|
|
2425
|
+
Mapping instances to values is as follows:
|
|
2423
2426
|
Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
|
|
2424
2427
|
values[0] represents instance (0,0)
|
|
2425
2428
|
values[1] represents instance (0,1)
|
|
@@ -2641,7 +2644,7 @@ class CompactPotentialFunction(PotentialFunction):
|
|
|
2641
2644
|
"""
|
|
2642
2645
|
Set the values of the potential function using the given iterator.
|
|
2643
2646
|
|
|
2644
|
-
Mapping instances to
|
|
2647
|
+
Mapping instances to `values` is as follows:
|
|
2645
2648
|
Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
|
|
2646
2649
|
values[0] represents instance (0,0)
|
|
2647
2650
|
values[1] represents instance (0,1)
|
|
@@ -2684,7 +2687,7 @@ class CompactPotentialFunction(PotentialFunction):
|
|
|
2684
2687
|
The order of values is the same as set_iter.
|
|
2685
2688
|
|
|
2686
2689
|
Args:
|
|
2687
|
-
|
|
2690
|
+
value: the values to use.
|
|
2688
2691
|
|
|
2689
2692
|
Returns:
|
|
2690
2693
|
self
|
|
@@ -2836,7 +2839,7 @@ class ClausePotentialFunction(PotentialFunction):
|
|
|
2836
2839
|
else:
|
|
2837
2840
|
return None
|
|
2838
2841
|
|
|
2839
|
-
def is_cpt(self, tolerance=
|
|
2842
|
+
def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
|
|
2840
2843
|
"""
|
|
2841
2844
|
A ClausePotentialFunction can only be a CTP when all entries are zero.
|
|
2842
2845
|
"""
|
|
@@ -2930,7 +2933,7 @@ class CPTPotentialFunction(PotentialFunction):
|
|
|
2930
2933
|
def number_of_parameters(self) -> int:
|
|
2931
2934
|
return len(self._values)
|
|
2932
2935
|
|
|
2933
|
-
def is_cpt(self, tolerance=
|
|
2936
|
+
def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
|
|
2934
2937
|
if tolerance >= self._tolerance:
|
|
2935
2938
|
return True
|
|
2936
2939
|
else:
|
|
@@ -3015,12 +3018,11 @@ class CPTPotentialFunction(PotentialFunction):
|
|
|
3015
3018
|
|
|
3016
3019
|
def cpds(self) -> Iterable[Tuple[Instance, Sequence[float]]]:
|
|
3017
3020
|
"""
|
|
3018
|
-
Iterate over (parent_states, cpd) tuples.
|
|
3019
|
-
|
|
3020
|
-
|
|
3021
|
-
|
|
3022
|
-
|
|
3023
|
-
|
|
3021
|
+
Iterate over (parent_states, cpd) tuples. This will exclude zero CPDs.
|
|
3022
|
+
|
|
3023
|
+
Warning:
|
|
3024
|
+
Do not change CPDs to (or from) zero while iterating over them.
|
|
3025
|
+
|
|
3024
3026
|
Returns:
|
|
3025
3027
|
an iterator over pairs (instance, cpd) where,
|
|
3026
3028
|
instance: is indicates the state of the parent random variables.
|
|
@@ -3077,7 +3079,8 @@ class CPTPotentialFunction(PotentialFunction):
|
|
|
3077
3079
|
Calls self.set_cpd(parent_states, cpd) for each row (parent_states, cpd)
|
|
3078
3080
|
in rows. Any unmentioned parent states will have zero probabilities.
|
|
3079
3081
|
|
|
3080
|
-
Example usage, assuming three Boolean random variables
|
|
3082
|
+
Example usage, assuming three Boolean random variables::
|
|
3083
|
+
|
|
3081
3084
|
pgm.Factor(x, y, z).set_cpt().set(
|
|
3082
3085
|
# y z x[0] x[1]
|
|
3083
3086
|
((0, 0), (0.1, 0.9)),
|
|
@@ -3085,9 +3088,9 @@ class CPTPotentialFunction(PotentialFunction):
|
|
|
3085
3088
|
((1, 0), (0.1, 0.9)),
|
|
3086
3089
|
((1, 1), (0.1, 0.9))
|
|
3087
3090
|
)
|
|
3088
|
-
|
|
3091
|
+
|
|
3089
3092
|
Args:
|
|
3090
|
-
|
|
3093
|
+
rows: are tuples (key, cpd) used to set the potential function values.
|
|
3091
3094
|
|
|
3092
3095
|
Raises:
|
|
3093
3096
|
ValueError: if a CPD is not valid.
|
|
@@ -3111,7 +3114,7 @@ class CPTPotentialFunction(PotentialFunction):
|
|
|
3111
3114
|
Any list entry may be None, indicating 'guaranteed zero' for the associated parent states.
|
|
3112
3115
|
|
|
3113
3116
|
Args:
|
|
3114
|
-
|
|
3117
|
+
cpds: are the CPDs used to set the potential function values.
|
|
3115
3118
|
|
|
3116
3119
|
Raises:
|
|
3117
3120
|
ValueError: if a CPD is not valid.
|
|
@@ -3288,7 +3291,7 @@ def check_key(shape: Shape, key: Key) -> Instance:
|
|
|
3288
3291
|
A instance from the state space, as a tuple of state indexes, co-indexed with the given shape.
|
|
3289
3292
|
|
|
3290
3293
|
Raises:
|
|
3291
|
-
KeyError if the key is not valid.
|
|
3294
|
+
KeyError if the key is not valid for the given shape.
|
|
3292
3295
|
"""
|
|
3293
3296
|
_key: Instance = _key_to_instance(key)
|
|
3294
3297
|
if len(_key) != len(shape):
|
|
@@ -3336,8 +3339,8 @@ def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterable[Instance]
|
|
|
3336
3339
|
flip: if true, then first random variable changes most quickly.
|
|
3337
3340
|
|
|
3338
3341
|
Returns:
|
|
3339
|
-
an iteration over
|
|
3340
|
-
co-indexed with the given random variables.
|
|
3342
|
+
an iteration over instances, each instance is a tuple of state
|
|
3343
|
+
indexes, co-indexed with the given random variables.
|
|
3341
3344
|
"""
|
|
3342
3345
|
shape = [len(rv) for rv in rvs]
|
|
3343
3346
|
return _combos_ranges(shape, flip=not flip)
|
|
@@ -3384,6 +3387,10 @@ def _natural_key_idx(shape: Shape, key: Key) -> int:
|
|
|
3384
3387
|
"""
|
|
3385
3388
|
What is the natural index of the given key, assuming the given shape.
|
|
3386
3389
|
|
|
3390
|
+
The natural index of an instance is defined as the index of the
|
|
3391
|
+
instance if all instances for the shape are enumerated as per
|
|
3392
|
+
`rv_instances`.
|
|
3393
|
+
|
|
3387
3394
|
Args:
|
|
3388
3395
|
shape: the shape defining the state space.
|
|
3389
3396
|
key: a key into the state space.
|
ck/pgm_circuit/mpe_program.py
CHANGED
|
@@ -228,10 +228,9 @@ class MPEProgram(ProgramWithSlotmap):
|
|
|
228
228
|
class MPEResult:
|
|
229
229
|
"""
|
|
230
230
|
An MPE result is the result of MPE inference.
|
|
231
|
-
|
|
232
|
-
Fields:
|
|
233
|
-
wmc: the weighted model count value of the MPE solution.
|
|
234
|
-
mpe: The MPE solution instance. If there are ties then this will just be once instance.
|
|
235
231
|
"""
|
|
236
232
|
wmc: float
|
|
233
|
+
"""the weighted model count value of the MPE solution."""
|
|
234
|
+
|
|
237
235
|
mpe: Instance
|
|
236
|
+
"""the MPE solution instance. If there are ties then this will just be once instance."""
|