compiled-knowledge 4.0.0a22__cp312-cp312-macosx_11_0_arm64.whl → 4.0.0a24__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. (The original page linked to further details here; the link target is not preserved in this text.)

Files changed (33)
  1. ck/circuit/_circuit_cy.c +50 -60
  2. ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
  3. ck/circuit/_circuit_cy.pyx +1 -1
  4. ck/circuit/_circuit_py.py +1 -1
  5. ck/circuit_compiler/circuit_compiler.py +3 -2
  6. ck/circuit_compiler/cython_vm_compiler/_compiler.c +179 -170
  7. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
  8. ck/circuit_compiler/cython_vm_compiler/_compiler.pyx +3 -3
  9. ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +14 -7
  10. ck/circuit_compiler/interpret_compiler.py +35 -4
  11. ck/circuit_compiler/llvm_vm_compiler.py +9 -3
  12. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
  13. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
  14. ck/circuit_compiler/support/llvm_ir_function.py +18 -1
  15. ck/pgm.py +100 -102
  16. ck/pgm_compiler/pgm_compiler.py +1 -1
  17. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
  18. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
  19. ck/pgm_compiler/support/join_tree.py +3 -3
  20. ck/probability/empirical_probability_space.py +4 -3
  21. ck/probability/pgm_probability_space.py +7 -3
  22. ck/probability/probability_space.py +21 -15
  23. ck/program/raw_program.py +40 -7
  24. ck/sampling/sampler_support.py +8 -5
  25. ck_demos/ace/simple_ace_demo.py +18 -0
  26. ck_demos/getting_started/__init__.py +0 -0
  27. ck_demos/getting_started/simple_demo.py +18 -0
  28. ck_demos/programs/demo_raw_program_dump.py +17 -0
  29. {compiled_knowledge-4.0.0a22.dist-info → compiled_knowledge-4.0.0a24.dist-info}/METADATA +1 -1
  30. {compiled_knowledge-4.0.0a22.dist-info → compiled_knowledge-4.0.0a24.dist-info}/RECORD +33 -29
  31. {compiled_knowledge-4.0.0a22.dist-info → compiled_knowledge-4.0.0a24.dist-info}/WHEEL +0 -0
  32. {compiled_knowledge-4.0.0a22.dist-info → compiled_knowledge-4.0.0a24.dist-info}/licenses/LICENSE.txt +0 -0
  33. {compiled_knowledge-4.0.0a22.dist-info → compiled_knowledge-4.0.0a24.dist-info}/top_level.txt +0 -0
@@ -38,7 +38,7 @@ DTYPE_TO_CVM_TYPE: Dict[DTypeNumeric, str] = {
38
38
  }
39
39
 
40
40
 
41
- def make_function(analysis: CircuitAnalysis, dtype: DTypeNumeric) -> Tuple[RawProgramFunction, int]:
41
+ def make_function(analysis: CircuitAnalysis, dtype: DTypeNumeric) -> Tuple[RawProgramFunction, int, int]:
42
42
  """
43
43
  Make a RawProgram function that interprets the circuit.
44
44
 
@@ -47,7 +47,7 @@ def make_function(analysis: CircuitAnalysis, dtype: DTypeNumeric) -> Tuple[RawPr
47
47
  dtype: a numpy data type that must be a key in the dictionary, DTYPE_TO_CVM_TYPE.
48
48
 
49
49
  Returns:
50
- (function, number_of_tmps)
50
+ (function, number_of_tmps, number_of_instructions)
51
51
  """
52
52
 
53
53
  cdef Instructions instructions
@@ -128,7 +128,7 @@ def make_function(analysis: CircuitAnalysis, dtype: DTypeNumeric) -> Tuple[RawPr
128
128
  else:
129
129
  raise ValueError(f'cvm_type_name unexpected: {cvm_type_name!r}')
130
130
 
131
- return function, len(analysis.op_to_tmp)
131
+ return function, len(analysis.op_to_tmp), len(analysis.op_nodes)
132
132
 
133
133
  # VM instructions
134
134
  cdef int ADD = circuit.ADD
@@ -48,15 +48,16 @@ class CythonRawProgram(RawProgram):
48
48
  result: Sequence[CircuitNode],
49
49
  dtype: DTypeNumeric,
50
50
  ):
51
- self.in_vars = in_vars
52
- self.result = result
53
-
54
- function, number_of_tmps = _make_function(
51
+ function, number_of_tmps, number_of_instructions = _make_function(
55
52
  var_nodes=in_vars,
56
53
  result_nodes=result,
57
54
  dtype=dtype,
58
55
  )
59
56
 
57
+ self.in_vars = in_vars
58
+ self.result = result
59
+ self.number_of_instructions = number_of_instructions
60
+
60
61
  super().__init__(
61
62
  function=function,
62
63
  dtype=dtype,
@@ -66,6 +67,10 @@ class CythonRawProgram(RawProgram):
66
67
  var_indices=tuple(var.idx for var in in_vars),
67
68
  )
68
69
 
70
+ def dump(self, *, prefix: str = '', indent: str = ' ') -> None:
71
+ super().dump(prefix=prefix, indent=indent)
72
+ print(f'{prefix}number of instructions = {self.number_of_instructions}')
73
+
69
74
  def __getstate__(self):
70
75
  """
71
76
  Support for pickle.
@@ -94,7 +99,7 @@ class CythonRawProgram(RawProgram):
94
99
  self.in_vars = state['in_vars']
95
100
  self.result = state['result']
96
101
 
97
- self.function, _ = _make_function(
102
+ self.function, _, self.number_of_instructions = _make_function(
98
103
  var_nodes=self.in_vars,
99
104
  result_nodes=self.result,
100
105
  dtype=self.dtype,
@@ -105,7 +110,7 @@ def _make_function(
105
110
  var_nodes: Sequence[VarNode],
106
111
  result_nodes: Sequence[CircuitNode],
107
112
  dtype: DTypeNumeric,
108
- ) -> Tuple[RawProgramFunction, int]:
113
+ ) -> Tuple[RawProgramFunction, int, int]:
109
114
  """
110
115
  Make a RawProgram function that interprets the circuit.
111
116
 
@@ -115,7 +120,9 @@ def _make_function(
115
120
  dtype: a numpy data type that must be a key in the dictionary, DTYPE_TO_CVM_TYPE.
116
121
 
117
122
  Returns:
118
- (function, number_of_tmps)
123
+ (function, number_of_tmps, number_of_instructions)
119
124
  """
120
125
  analysis: CircuitAnalysis = analyze_circuit(var_nodes, result_nodes)
126
+ DEBUG = _compiler.make_function(analysis, dtype)
127
+
121
128
  return _compiler.make_function(analysis, dtype)
@@ -1,17 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import ctypes as ct
3
4
  from dataclasses import dataclass
4
5
  from typing import Sequence, Optional, Dict, List, Tuple, Callable
5
6
 
6
7
  import numpy as np
7
- import ctypes as ct
8
8
 
9
+ from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
10
+ from .support.input_vars import InputVars, InferVars, infer_input_vars
9
11
  from ..circuit import Circuit, CircuitNode, VarNode, OpNode, ADD, MUL
10
12
  from ..program.raw_program import RawProgram, RawProgramFunction
11
13
  from ..utils.iter_extras import multiply, first
12
14
  from ..utils.np_extras import NDArrayNumeric, DTypeNumeric
13
- from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
14
- from .support.input_vars import InputVars, InferVars, infer_input_vars
15
15
 
16
16
  # index to a value array
17
17
  _VARS = 0
@@ -85,6 +85,15 @@ class InterpreterRawProgram(RawProgram):
85
85
  var_indices=tuple(var.idx for var in in_vars),
86
86
  )
87
87
 
88
+ def dump(self, *, prefix: str = '', indent: str = ' ', show_instructions: bool = True) -> None:
89
+ super().dump(prefix=prefix, indent=indent)
90
+ print(f'{prefix}number of instructions = {len(self.instructions)}')
91
+ if show_instructions:
92
+ print(f'{prefix}instructions:')
93
+ next_prefix: str = prefix + indent
94
+ for instruction in self.instructions:
95
+ print(f'{next_prefix}{instruction.to_str(self.var_indices, self.np_consts)}')
96
+
88
97
  def __getstate__(self):
89
98
  """
90
99
  Support for pickle.
@@ -123,7 +132,6 @@ def _make_instructions(
123
132
  analysis: CircuitAnalysis,
124
133
  dtype: DTypeNumeric,
125
134
  ) -> Tuple[Sequence[_Instruction], NDArrayNumeric]:
126
-
127
135
  # Store const values in a numpy array
128
136
  node_to_const_idx: Dict[int, int] = {
129
137
  id(node): i
@@ -216,9 +224,32 @@ class _ElementID:
216
224
  array: int # VARS, TMPS, CONSTS, RESULT
217
225
  index: int # index into the array
218
226
 
227
+ def to_str(self, var_indices: Sequence[int], consts: NDArrayNumeric) -> str:
228
+ if self.array == _VARS:
229
+ return f'var[{var_indices[self.index]}]'
230
+ elif self.array == _TMPS:
231
+ return f'tmp[{self.index}]'
232
+ elif self.array == _CONSTS:
233
+ return str(consts.item(self.index))
234
+ elif self.array == _RESULT:
235
+ return f'result[{self.index}]'
236
+ else:
237
+ return f'?[{self.index}]'
238
+
219
239
 
220
240
  @dataclass
221
241
  class _Instruction:
222
242
  operation: Callable
223
243
  args: Sequence[_ElementID]
224
244
  dest: _ElementID
245
+
246
+ def to_str(self, var_indices: Sequence[int], consts: NDArrayNumeric) -> str:
247
+ symbol: str
248
+ if self.operation is multiply:
249
+ symbol = 'mul'
250
+ elif self.operation == sum:
251
+ symbol = 'sum'
252
+ else:
253
+ symbol = '<?>'
254
+ args: str = ' '.join(elem.to_str(var_indices, consts) for elem in self.args)
255
+ return f'{self.dest.to_str(var_indices, consts)} = {symbol} {args}'
@@ -1,12 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import ctypes as ct
3
4
  from dataclasses import dataclass
4
5
  from typing import Sequence, Optional, Tuple, List, Dict
5
6
 
6
7
  import llvmlite.binding as llvm
7
8
  import llvmlite.ir as ir
8
9
  import numpy as np
9
- import ctypes as ct
10
10
 
11
11
  from .support.circuit_analyser import CircuitAnalysis, analyze_circuit
12
12
  from .support.input_vars import InputVars, InferVars, infer_input_vars
@@ -126,6 +126,12 @@ class LLVMRawProgramWithArrays(LLVMRawProgram):
126
126
  instructions: np.ndarray
127
127
  consts: np.ndarray
128
128
 
129
+ def dump(self, *, prefix: str = '', indent: str = ' ', show_instructions: bool = True) -> None:
130
+ super().dump(prefix=prefix, indent=indent)
131
+ print(f'{prefix}LLVM byte code size = {len(self.instructions)}')
132
+ if show_instructions:
133
+ self.dump_llvm_program(prefix=prefix, indent=indent)
134
+
129
135
  def __post_init__(self):
130
136
  self._set_globals(self.instructions, _SET_INSTRUCTIONS_FUNCTION_NAME)
131
137
  self._set_globals(self.consts, _SET_CONSTS_FUNCTION_NAME)
@@ -200,7 +206,7 @@ def _make_llvm_program(
200
206
  consts_global.global_constant = True
201
207
  consts_global.initializer = ir.Constant(consts_array_type, const_values)
202
208
  data_idx_0 = ir.Constant(data_idx_type, 0)
203
- consts: ir.Value = builder.gep(consts_global, [data_idx_0, data_idx_0])
209
+ consts: ir.Value = builder.gep(consts_global, [data_idx_0, data_idx_0])
204
210
 
205
211
  # Put bytecode into the LLVM module
206
212
  instructions_array_type = ir.ArrayType(byte_type, len(byte_code))
@@ -218,7 +224,7 @@ def _make_llvm_program(
218
224
 
219
225
  instructions_ptr_type = byte_type.as_pointer()
220
226
  instructions_global = ir.GlobalVariable(module, instructions_ptr_type, name='instructions')
221
- instructions_global.initializer =ir.Constant(instructions_ptr_type, None)
227
+ instructions_global.initializer = ir.Constant(instructions_ptr_type, None)
222
228
  instructions: ir.Value = builder.load(instructions_global)
223
229
 
224
230
  interp = _InterpBuilder(builder, type_info, inst_idx_type, data_idx_bytes, num_args_bytes, consts, instructions)
@@ -15,7 +15,7 @@
15
15
  "-O3"
16
16
  ],
17
17
  "include_dirs": [
18
- "/private/var/folders/y6/nj790rtn62lfktb1sh__79hc0000gn/T/build-env-8d_gezt7/lib/python3.12/site-packages/numpy/_core/include"
18
+ "/private/var/folders/y6/nj790rtn62lfktb1sh__79hc0000gn/T/build-env-gdy6g4em/lib/python3.12/site-packages/numpy/_core/include"
19
19
  ],
20
20
  "name": "ck.circuit_compiler.support.circuit_analyser._circuit_analyser_cy",
21
21
  "sources": [
@@ -1,7 +1,7 @@
1
1
  import ctypes as ct
2
2
  from dataclasses import dataclass
3
3
  from enum import Enum
4
- from typing import Callable, Tuple, Optional
4
+ from typing import Callable, Tuple, Optional, List
5
5
 
6
6
  import llvmlite.binding as llvm
7
7
  import llvmlite.ir as ir
@@ -172,6 +172,23 @@ class LLVMRawProgram(RawProgram):
172
172
  'opt': self.opt,
173
173
  }
174
174
 
175
+ def dump(self, *, prefix: str = '', indent: str = ' ', show_instructions: bool = True) -> None:
176
+ super().dump(prefix=prefix, indent=indent)
177
+ print(f'{prefix}optimisation level = {self.opt}')
178
+ if show_instructions:
179
+ self.dump_llvm_program(prefix=prefix, indent=indent)
180
+
181
+ def dump_llvm_program(self, *, prefix: str = '', indent: str = ' ') -> None:
182
+ if self.llvm_program is None:
183
+ print(f'{prefix}LLVM program: unavailable')
184
+ else:
185
+ llvm_program: List[str] = self.llvm_program.split('\n')
186
+ print(f'{prefix}LLVM program size = {len(llvm_program)}')
187
+ print(f'{prefix}LLVM program:')
188
+ next_prefix: str = prefix + indent
189
+ for line in llvm_program:
190
+ print(f'{next_prefix}{line}')
191
+
175
192
  def __setstate__(self, state):
176
193
  """
177
194
  Support for pickle.
ck/pgm.py CHANGED
@@ -1,6 +1,3 @@
1
- """
2
- For more documentation on this module, refer to the Jupyter notebook docs/4_PGM_advanced.ipynb.
3
- """
4
1
  from __future__ import annotations
5
2
 
6
3
  import math
@@ -8,7 +5,7 @@ from abc import ABC, abstractmethod
8
5
  from dataclasses import dataclass
9
6
  from itertools import repeat as _repeat
10
7
  from typing import Sequence, Tuple, Dict, Optional, overload, Set, Iterable, List, Union, Callable, \
11
- Collection, Any, Iterator
8
+ Collection, Any, Iterator, TypeAlias
12
9
 
13
10
  import numpy as np
14
11
 
@@ -17,21 +14,34 @@ from ck.utils.iter_extras import (
17
14
  )
18
15
  from ck.utils.np_extras import NDArrayFloat64, NDArrayUInt8
19
16
 
20
- # What types are permitted as random variable states
21
- State = Union[int, str, bool, float, None]
17
+ State: TypeAlias = Union[int, str, bool, float, None]
18
+ State.__doc__ = \
19
+ """
20
+ The type for a possible state of a random variable.
21
+ """
22
22
 
23
- # An instance (of a sequence of random variables) is a tuple of integers
24
- # that are state indexes, co-indexed with a known sequence of random variables.
25
- Instance = Sequence[int]
23
+ Instance: TypeAlias = Sequence[int]
24
+ Instance.__doc__ = \
25
+ """
26
+ An instance (of a sequence of random variables) is a sequence of integers
27
+ that are state indexes, co-indexed with a known sequence of random variables.
28
+ """
26
29
 
27
- # A key identifies an instance.
28
- # A single integer is treated as an instance with one dimension.
29
- Key = Union[Sequence[int], int]
30
+ Key: TypeAlias = Union[Instance, int]
31
+ Key.__doc__ = \
32
+ """
33
+ A key identifies an instance, either as an instance itself or a
34
+ single integer, representing an instance with one dimension.
35
+ """
30
36
 
31
- # The shape of a sequence of random variables (e.g., a PGM, Factor or PotentialFunction).
32
- Shape = Sequence[int]
37
+ Shape: TypeAlias = Sequence[int]
38
+ Key.__doc__ = \
39
+ """
40
+ The type for the "shape" of a sequence of random variables.
41
+ That is, the shape of (rv1, rv2, rv3) is (len(rv1), len(rv2), len(rv3)).
42
+ """
33
43
 
34
- DEFAULT_TOLERANCE: float = 0.000001 # For checking CPT sums.
44
+ DEFAULT_CPT_TOLERANCE: float = 0.000001 # A tolerance when checking CPT distributions sum to one (or zero).
35
45
 
36
46
 
37
47
  class PGM:
@@ -39,11 +49,9 @@ class PGM:
39
49
  A probabilistic graphical model (PGM) represents a joint probability distribution over
40
50
  a set of random variables. Specifically, a PGM is a factor graph with discrete random variables.
41
51
 
42
- Add a random variable to a PGM, pgm, using `rv = pgm.new_rv(...)`.
43
-
44
- Add a factor to the PGM, pgm, using `factor = pgm.new_factor(...)`.
52
+ Add a random variable to a PGM, `pgm`, using `rv = pgm.new_rv(...)`.
45
53
 
46
- A PGM may be given a human-readable name.
54
+ Add a factor to the PGM, `pgm`, using `factor = pgm.new_factor(...)`.
47
55
  """
48
56
 
49
57
  def __init__(self, name: Optional[str] = None):
@@ -587,7 +595,7 @@ class PGM:
587
595
  # All tests passed
588
596
  return True
589
597
 
590
- def factors_are_cpts(self, tolerance: float = DEFAULT_TOLERANCE) -> bool:
598
+ def factors_are_cpts(self, tolerance: float = DEFAULT_CPT_TOLERANCE) -> bool:
591
599
  """
592
600
  Are all factor potential functions set with parameters values
593
601
  conforming to Conditional Probability Tables.
@@ -603,7 +611,7 @@ class PGM:
603
611
  """
604
612
  return all(function.is_cpt(tolerance) for function in self.functions)
605
613
 
606
- def check_is_bayesian_network(self, tolerance: float = DEFAULT_TOLERANCE) -> bool:
614
+ def check_is_bayesian_network(self, tolerance: float = DEFAULT_CPT_TOLERANCE) -> bool:
607
615
  """
608
616
  Is this PGM a Bayesian network.
609
617
 
@@ -648,7 +656,7 @@ class PGM:
648
656
  This method has the same semantics as `ProbabilitySpace.wmc` without conditioning.
649
657
 
650
658
  Warning:
651
- this is potentially computationally expensive as it marginalised random
659
+ this is potentially computationally expensive as it marginalises random
652
660
  variables not mentioned in the given indicators.
653
661
 
654
662
  Args:
@@ -658,29 +666,11 @@ class PGM:
658
666
  the product of factors, conditioned on the given instance. This is the
659
667
  computed value of the PGM, conditioned on the given instance.
660
668
  """
661
- # # Create a filter from the indicators
662
- # inst_filter: List[Set[int]] = [set() for _ in range(self.number_of_rvs)]
663
- # for indicator in indicators:
664
- # rv_idx: int = indicator.rv_idx
665
- # inst_filter[rv_idx].add(indicator.state_idx)
666
- # # Collect rvs not mentioned - to marginalise
667
- # for rv, rv_filter in zip(self.rvs, inst_filter):
668
- # if len(rv_filter) == 0:
669
- # rv_filter.update(rv.state_range())
670
- #
671
- # def _sum_inst(_instance: Instance) -> bool:
672
- # return all(
673
- # (_state in _rv_filter)
674
- # for _state, _rv_filter in zip(_instance, inst_filter)
675
- # )
676
- #
677
- # # Accumulate the result
678
- # sum_value = 0
679
- # for instance in self.instances():
680
- # if _sum_inst(instance):
681
- # sum_value += self.value_product(instance)
682
- #
683
- # return sum_value
669
+ # Rather than naively checking all possible states of the PGM random
670
+ # variables, this method works to define the state space that should
671
+ # be summed over, based on the given indicators. Thus, if the given
672
+ # indicators constrain the state space to a small number of possibilities,
673
+ # then the sum is only performed over those possibilities.
684
674
 
685
675
  # Work out the space to sum over
686
676
  sum_space_set: List[Optional[Set[int]]] = [None] * self.number_of_rvs
@@ -719,11 +709,10 @@ class PGM:
719
709
  precision: a limit on the render precision of floating point numbers.
720
710
  max_state_digits: a limit on the number of digits when showing number of states as an integer.
721
711
  """
722
- # limit the precision when displaying number of states
712
+ # Determine a limit to precision when displaying number of states
723
713
  num_states: int = self.number_of_states
724
714
  number_of_parameters = sum(function.number_of_parameters for function in self.functions)
725
715
  number_of_nz_parameters = sum(function.number_of_parameters for function in self.non_zero_functions)
726
-
727
716
  if math.log10(num_states) > max_state_digits:
728
717
  log_states = math.log10(num_states)
729
718
  exp = int(log_states)
@@ -731,7 +720,6 @@ class PGM:
731
720
  num_states_str = f'{man:,.{precision}f}e+{exp}'
732
721
  else:
733
722
  num_states_str = f'{num_states:,}'
734
-
735
723
  log_2_num_states = math.log2(num_states)
736
724
  if (
737
725
  log_2_num_states == 0
@@ -820,9 +808,9 @@ class PGM:
820
808
 
821
809
  For a factor `f` the value of states[f.idx] is the search state.
822
810
  Specifically:
823
- state 0 => the factor has not been seen yet,
824
- state 1 => the factor is seen but not fully processed,
825
- state 2 => the factor is fully processed.
811
+ state 0 => the factor has not been seen yet,
812
+ state 1 => the factor is seen but not fully processed,
813
+ state 2 => the factor is fully processed.
826
814
 
827
815
  Args:
828
816
  factor: the current Factor being checked.
@@ -1040,7 +1028,7 @@ class RandomVariable(Sequence[Indicator]):
1040
1028
 
1041
1029
  def state_range(self) -> Iterable[int]:
1042
1030
  """
1043
- Iterate over the state indexes of this random variable, in order.
1031
+ Iterate over the state indexes of this random variable, in ascending order.
1044
1032
 
1045
1033
  Returns:
1046
1034
  range(len(self))
@@ -1122,18 +1110,19 @@ class RandomVariable(Sequence[Indicator]):
1122
1110
 
1123
1111
  def __eq__(self, other) -> bool:
1124
1112
  """
1125
- Two random variable are equal if they are the same object.
1113
+ Two random variables are equal if they are the same object.
1126
1114
  """
1127
1115
  return self is other
1128
1116
 
1129
1117
  def equivalent(self, other: RandomVariable | Sequence[Indicator]) -> bool:
1130
1118
  """
1131
- Two random variable are equivalent if their indicators are equal. Only
1132
- random variable indexes and state indexes are checked.
1133
-
1119
+ Two random variable are equivalent if their indicators are equal.
1120
+ Only random variable indexes and state indexes are checked.
1134
1121
  This ignores the names of the random variable and the names of their states.
1135
- This means their indicators will work correctly in slot maps, even
1136
- if from different PGMs.
1122
+
1123
+ Slot maps operate across `equivalent` random variables.
1124
+ This means indicators of equivalent random variables will work
1125
+ correctly in slot maps, even if from different PGMs.
1137
1126
 
1138
1127
  Args:
1139
1128
  other: either a random variable or a sequence of Indicators.
@@ -1181,7 +1170,8 @@ class RandomVariable(Sequence[Indicator]):
1181
1170
  """
1182
1171
  Returns the first index of `value`.
1183
1172
  Raises ValueError if the value is not present.
1184
- Contracted by Sequence[Indicator].
1173
+
1174
+ This method is contracted by `Sequence[Indicator]`.
1185
1175
 
1186
1176
  Warning:
1187
1177
  This method is different to `self.idx`.
@@ -1198,7 +1188,10 @@ class RandomVariable(Sequence[Indicator]):
1198
1188
  def count(self, value: Any) -> int:
1199
1189
  """
1200
1190
  Returns the number of occurrences of `value`.
1201
- Contracted by Sequence[Indicator].
1191
+ That is, if `value` is an indicator of this random variable
1192
+ then 1 is returned, otherwise 0 is returned.
1193
+
1194
+ This method is contracted by `Sequence[Indicator]`.
1202
1195
  """
1203
1196
  if isinstance(value, Indicator):
1204
1197
  if value.rv_idx == self._idx and 0 <= value.state_idx < len(self):
@@ -1210,15 +1203,16 @@ class RVMap(Sequence[RandomVariable]):
1210
1203
  """
1211
1204
  Wrap a PGM to provide convenient access to PGM random variables.
1212
1205
 
1213
- An RVMap of a PGM behaves exactly like the PGM `rvs` property. That it, it
1214
- behaves like a sequence of RandomVariable objects.
1206
+ An RVMap of a PGM behaves like the PGM `rvs` property (sequence of
1207
+ RandomVariable objects), with additional access methods for the PGM's
1208
+ random variables.
1215
1209
 
1216
1210
  If the underlying PGM is updated, then the RVMap will automatically update.
1217
1211
 
1218
- Additionally, an RVMap enables access to the PGM random variable via the name
1219
- of each random variable.
1212
+ In addition to accessing a random variable by its index, an RVMap enables
1213
+ access to the PGM random variable via the name of each random variable.
1220
1214
 
1221
- for example, if `pgm.rvs[1]` is a random variable named `xray`, then
1215
+ For example, if `pgm.rvs[1]` is a random variable named `xray`, then:
1222
1216
  ```
1223
1217
  rvs = RVMap(pgm)
1224
1218
 
@@ -1228,7 +1222,7 @@ class RVMap(Sequence[RandomVariable]):
1228
1222
  xray = rvs.xray
1229
1223
  ```
1230
1224
 
1231
- To use an RVMap on a PGM, the variable names must be unique across the PGM.
1225
+ To use an RVMap on a PGM, the random variable names must be unique across the PGM.
1232
1226
  """
1233
1227
 
1234
1228
  def __init__(self, pgm: PGM, ignore_case: bool = False):
@@ -1248,28 +1242,6 @@ class RVMap(Sequence[RandomVariable]):
1248
1242
  # This may raise an exception.
1249
1243
  _ = self._rv_map
1250
1244
 
1251
- def _clean_name(self, name: str) -> str:
1252
- """
1253
- Adjust the case of the given name as needed.
1254
- """
1255
- return name.lower() if self._ignore_case else name
1256
-
1257
- @property
1258
- def _rv_map(self) -> Dict[str, RandomVariable]:
1259
- """
1260
- Get the cached rv map, updating as needed if the PGM changed.
1261
- Returns:
1262
- a mapping from random variable name to random variable
1263
- """
1264
- if len(self.__rv_map) != len(self._pgm.rvs):
1265
- # There is a difference between the map and the PGM - create a new map.
1266
- self.__rv_map = {self._clean_name(rv.name): rv for rv in self._pgm.rvs}
1267
- if len(self.__rv_map) != len(self._pgm.rvs):
1268
- raise RuntimeError(f'random variable names are not unique')
1269
- if not self._reserved_names.isdisjoint(self.__rv_map.keys()):
1270
- raise RuntimeError(f'random variable names clash with reserved names.')
1271
- return self.__rv_map
1272
-
1273
1245
  def new_rv(self, name: str, states: Union[int, Sequence[State]]) -> RandomVariable:
1274
1246
  """
1275
1247
  As per `PGM.new_rv`.
@@ -1304,6 +1276,29 @@ class RVMap(Sequence[RandomVariable]):
1304
1276
  def __getattr__(self, rv_name: str) -> RandomVariable:
1305
1277
  return self(rv_name)
1306
1278
 
1279
+ @property
1280
+ def _rv_map(self) -> Dict[str, RandomVariable]:
1281
+ """
1282
+ Get the cached random variable map, updating as needed if the PGM changed.
1283
+
1284
+ Returns:
1285
+ a mapping from random variable name to random variable
1286
+ """
1287
+ if len(self.__rv_map) != len(self._pgm.rvs):
1288
+ # There is a difference between the map and the PGM - create a new map.
1289
+ self.__rv_map = {self._clean_name(rv.name): rv for rv in self._pgm.rvs}
1290
+ if len(self.__rv_map) != len(self._pgm.rvs):
1291
+ raise RuntimeError(f'random variable names are not unique')
1292
+ if not self._reserved_names.isdisjoint(self.__rv_map.keys()):
1293
+ raise RuntimeError(f'random variable names clash with reserved names.')
1294
+ return self.__rv_map
1295
+
1296
+ def _clean_name(self, name: str) -> str:
1297
+ """
1298
+ Adjust the case of the given name as needed.
1299
+ """
1300
+ return name.lower() if self._ignore_case else name
1301
+
1307
1302
 
1308
1303
  class Factor:
1309
1304
  """
@@ -1544,7 +1539,7 @@ class Factor:
1544
1539
  self._potential_function = ClausePotentialFunction(self, key)
1545
1540
  return self._potential_function
1546
1541
 
1547
- def set_cpt(self, tolerance: float = DEFAULT_TOLERANCE) -> CPTPotentialFunction:
1542
+ def set_cpt(self, tolerance: float = DEFAULT_CPT_TOLERANCE) -> CPTPotentialFunction:
1548
1543
  """
1549
1544
  Set to the potential function to a new `CPTPotentialFunction` object.
1550
1545
 
@@ -1820,7 +1815,7 @@ class PotentialFunction(ABC):
1820
1815
  """
1821
1816
  ...
1822
1817
 
1823
- def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
1818
+ def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
1824
1819
  """
1825
1820
  Is the potential function set with parameters values conforming to a
1826
1821
  Conditional Probability Table.
@@ -2028,7 +2023,7 @@ class ZeroPotentialFunction(PotentialFunction):
2028
2023
  def param_idx(self, key: Key) -> int:
2029
2024
  return _natural_key_idx(self._shape, key)
2030
2025
 
2031
- def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
2026
+ def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
2032
2027
  return True
2033
2028
 
2034
2029
 
@@ -2836,7 +2831,7 @@ class ClausePotentialFunction(PotentialFunction):
2836
2831
  else:
2837
2832
  return None
2838
2833
 
2839
- def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
2834
+ def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
2840
2835
  """
2841
2836
  A ClausePotentialFunction can only be a CTP when all entries are zero.
2842
2837
  """
@@ -2930,7 +2925,7 @@ class CPTPotentialFunction(PotentialFunction):
2930
2925
  def number_of_parameters(self) -> int:
2931
2926
  return len(self._values)
2932
2927
 
2933
- def is_cpt(self, tolerance=DEFAULT_TOLERANCE) -> bool:
2928
+ def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
2934
2929
  if tolerance >= self._tolerance:
2935
2930
  return True
2936
2931
  else:
@@ -3015,12 +3010,11 @@ class CPTPotentialFunction(PotentialFunction):
3015
3010
 
3016
3011
  def cpds(self) -> Iterable[Tuple[Instance, Sequence[float]]]:
3017
3012
  """
3018
- Iterate over (parent_states, cpd) tuples.
3019
- This will exclude zero CPDs.
3020
- Do not change CPDs to (or from) zero while iterating over them.
3021
-
3022
- Get the CPD conditioned on parent states indicated by `parent_states`.
3023
-
3013
+ Iterate over (parent_states, cpd) tuples. This will exclude zero CPDs.
3014
+
3015
+ Warning:
3016
+ Do not change CPDs to (or from) zero while iterating over them.
3017
+
3024
3018
  Returns:
3025
3019
  an iterator over pairs (instance, cpd) where,
3026
3020
  instance: is indicates the state of the parent random variables.
@@ -3288,7 +3282,7 @@ def check_key(shape: Shape, key: Key) -> Instance:
3288
3282
  A instance from the state space, as a tuple of state indexes, co-indexed with the given shape.
3289
3283
 
3290
3284
  Raises:
3291
- KeyError if the key is not valid.
3285
+ KeyError if the key is not valid for the given shape.
3292
3286
  """
3293
3287
  _key: Instance = _key_to_instance(key)
3294
3288
  if len(_key) != len(shape):
@@ -3336,8 +3330,8 @@ def rv_instances(*rvs: RandomVariable, flip: bool = False) -> Iterable[Instance]
3336
3330
  flip: if true, then first random variable changes most quickly.
3337
3331
 
3338
3332
  Returns:
3339
- an iteration over tuples, each tuple holds state indexes
3340
- co-indexed with the given random variables.
3333
+ an iteration over instances, each instance is a tuple of state
3334
+ indexes, co-indexed with the given random variables.
3341
3335
  """
3342
3336
  shape = [len(rv) for rv in rvs]
3343
3337
  return _combos_ranges(shape, flip=not flip)
@@ -3384,6 +3378,10 @@ def _natural_key_idx(shape: Shape, key: Key) -> int:
3384
3378
  """
3385
3379
  What is the natural index of the given key, assuming the given shape.
3386
3380
 
3381
+ The natural index of an instance is defined as the index of the
3382
+ instance if all instances for the shape are enumerated as per
3383
+ `rv_instances`.
3384
+
3387
3385
  Args:
3388
3386
  shape: the shape defining the state space.
3389
3387
  key: a key into the state space.
@@ -7,7 +7,7 @@ from ck.pgm_circuit import PGMCircuit
7
7
  class PGMCompiler(Protocol):
8
8
  def __call__(self, pgm: PGM, *, const_parameters: bool = True) -> PGMCircuit:
9
9
  """
10
- A PGM compiler is a function with this signature.
10
+ A PGM compiler compiles a PGM to an arithmetic circuit.
11
11
 
12
12
  Args:
13
13
  pgm: The PGM to compile.