compiled-knowledge 4.0.0a9__cp312-cp312-win_amd64.whl → 4.0.0a11__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (30) hide show
  1. ck/circuit/circuit.cp312-win_amd64.pyd +0 -0
  2. ck/circuit/circuit.pyx +20 -8
  3. ck/circuit/circuit_py.py +40 -19
  4. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
  5. ck/pgm.py +111 -130
  6. ck/pgm_circuit/pgm_circuit.py +13 -9
  7. ck/pgm_circuit/program_with_slotmap.py +6 -4
  8. ck/pgm_compiler/ace/ace.py +48 -4
  9. ck/pgm_compiler/factor_elimination.py +6 -4
  10. ck/pgm_compiler/recursive_conditioning.py +8 -3
  11. ck/pgm_compiler/support/circuit_table/circuit_table.cp312-win_amd64.pyd +0 -0
  12. ck/pgm_compiler/support/clusters.py +1 -1
  13. ck/pgm_compiler/variable_elimination.py +3 -3
  14. ck/probability/empirical_probability_space.py +3 -0
  15. ck/probability/pgm_probability_space.py +32 -0
  16. ck/probability/probability_space.py +66 -12
  17. ck/program/program.py +9 -1
  18. ck/program/raw_program.py +9 -3
  19. ck/sampling/sampler_support.py +1 -1
  20. ck/sampling/uniform_sampler.py +10 -4
  21. ck/sampling/wmc_direct_sampler.py +4 -2
  22. ck/sampling/wmc_gibbs_sampler.py +6 -0
  23. ck/sampling/wmc_metropolis_sampler.py +7 -1
  24. ck/sampling/wmc_rejection_sampler.py +2 -0
  25. ck/utils/iter_extras.py +9 -6
  26. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/METADATA +16 -12
  27. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/RECORD +30 -29
  28. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/WHEEL +0 -0
  29. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/licenses/LICENSE.txt +0 -0
  30. {compiled_knowledge-4.0.0a9.dist-info → compiled_knowledge-4.0.0a11.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  from dataclasses import dataclass
2
2
  from typing import Sequence, List, Dict
3
3
 
4
- from ck.circuit import CircuitNode
4
+ from ck.circuit import CircuitNode, Circuit
5
5
  from ck.pgm import RandomVariable, Indicator
6
6
  from ck.pgm_circuit.slot_map import SlotMap, SlotKey
7
7
  from ck.utils.np_extras import NDArray
@@ -46,19 +46,21 @@ class PGMCircuit:
46
46
 
47
47
  def dump(self, *, prefix: str = '', indent: str = ' ') -> None:
48
48
  """
49
- Print a dump of the Join Tree.
49
+ Print a dump of the circuit.
50
50
  This is intended for debugging and demonstration purposes.
51
51
 
52
- Each cluster is printed as: {separator rvs} | {non-separator rvs}.
53
-
54
52
  Args:
55
53
  prefix: optional prefix for indenting all lines.
56
54
  indent: additional prefix to use for extra indentation.
57
55
  """
58
- # Name the circuit variables
59
- circuit = self.circuit_top.circuit
56
+
57
+ # We infer names for the circuit variables, either as an indicator or as a parameter.
58
+ # The `var_names` will be passed to `circuit.dump`.
59
+
60
+ circuit: Circuit = self.circuit_top.circuit
60
61
  var_names: List[str] = [''] * circuit.number_of_vars
61
62
 
63
+ # Name the circuit variables that are indicators
62
64
  rvs_by_idx: Dict[int, RandomVariable] = {rv.idx: rv for rv in self.rvs}
63
65
  slot_key: SlotKey
64
66
  slot: int
@@ -68,8 +70,10 @@ class PGMCircuit:
68
70
  state_idx = slot_key.state_idx
69
71
  var_names[slot] = f'{rv.name!r}[{state_idx}] {rv.states[state_idx]!r}'
70
72
 
71
- for slot, param_value in enumerate(self.parameter_values, start=self.number_of_indicators):
72
- var_names[slot] = f'param {param_value}'
73
+ # Name the circuit variables that are parameters
74
+ for i, param_value in enumerate(self.parameter_values):
75
+ slot = i + self.number_of_indicators
76
+ var_names[slot] = f'param[{i}] = {param_value}'
73
77
 
74
- # Show all the slots
78
+ # Dump the circuit
75
79
  circuit.dump(prefix=prefix, indent=indent, var_names=var_names)
@@ -1,4 +1,4 @@
1
- from typing import Tuple, Iterator, Sequence, Dict, Iterable
1
+ from typing import Tuple, Sequence, Dict, Iterable
2
2
 
3
3
  from ck.pgm import RandomVariable, rv_instances, Instance, rv_instances_as_indicators, Indicator, ParamId
4
4
  from ck.pgm_circuit.slot_map import SlotMap, SlotKey
@@ -30,11 +30,13 @@ class ProgramWithSlotmap:
30
30
  has a length and rv[i] is a unique 'indicator' across all rvs.
31
31
  precondition: conditions on rvs that are compiled into the program.
32
32
 
33
+ Raises:
34
+ ValueError: if rvs contains duplicates.
33
35
  """
34
36
  self._program_buffer: ProgramBuffer = program_buffer
35
37
  self._slot_map: SlotMap = slot_map
36
38
  self._rvs: Tuple[RandomVariable, ...] = tuple(rvs)
37
- self._precondition: Sequence[Indicator] = precondition
39
+ self._precondition: Tuple[Indicator, ...] = tuple(precondition)
38
40
 
39
41
  if len(rvs) != len(set(rv.idx for rv in rvs)):
40
42
  raise ValueError('duplicate random variables provided')
@@ -67,7 +69,7 @@ class ProgramWithSlotmap:
67
69
  def slot_map(self) -> SlotMap:
68
70
  return self._slot_map
69
71
 
70
- def instances(self, flip: bool = False) -> Iterator[Instance]:
72
+ def instances(self, flip: bool = False) -> Iterable[Instance]:
71
73
  """
72
74
  Enumerate instances of the random variables.
73
75
 
@@ -84,7 +86,7 @@ class ProgramWithSlotmap:
84
86
  """
85
87
  return rv_instances(*self._rvs, flip=flip)
86
88
 
87
- def instances_as_indicators(self, flip: bool = False) -> Iterator[Sequence[Indicator]]:
89
+ def instances_as_indicators(self, flip: bool = False) -> Iterable[Sequence[Indicator]]:
88
90
  """
89
91
  Enumerate instances of the random variables.
90
92
 
@@ -5,7 +5,9 @@ from dataclasses import dataclass
5
5
  from pathlib import Path
6
6
  from typing import Optional, List, Tuple
7
7
 
8
- from ck.circuit import CircuitNode
8
+ import numpy as np
9
+
10
+ from ck.circuit import CircuitNode, Circuit
9
11
  from ck.in_out.parse_ace_lmap import read_lmap, LiteralMap
10
12
  from ck.in_out.parse_ace_nnf import read_nnf_with_literal_map
11
13
  from ck.in_out.render_net import render_bayesian_network
@@ -58,6 +60,22 @@ def compile_pgm(
58
60
  if check_is_bayesian_network and not pgm.check_is_bayesian_network():
59
61
  raise ValueError('the given PGM is not a Bayesian network')
60
62
 
63
+ # ACE cannot deal with the empty PGM even though it is a valid Bayesian network
64
+ if pgm.number_of_factors == 0:
65
+ circuit = Circuit()
66
+ circuit.new_vars(pgm.number_of_indicators)
67
+ parameter_values = np.array([], dtype=np.float64)
68
+ slot_map = {indicator: i for i, indicator in enumerate(pgm.indicators)}
69
+ return PGMCircuit(
70
+ rvs=pgm.rvs,
71
+ conditions=(),
72
+ circuit_top=circuit.const(1),
73
+ number_of_indicators=pgm.number_of_indicators,
74
+ number_of_parameters=0,
75
+ slot_map=slot_map,
76
+ parameter_values=parameter_values,
77
+ )
78
+
61
79
  java: str
62
80
  classpath_separator: str
63
81
  java, classpath_separator = _find_java()
@@ -69,7 +87,7 @@ def compile_pgm(
69
87
  )
70
88
  ace_cmd: List[str] = [
71
89
  java,
72
- f'-cp',
90
+ '-cp',
73
91
  class_path,
74
92
  f'-DACEC2D={files.c2d}',
75
93
  f'-Xmx{int(m_bytes)}m',
@@ -83,7 +101,7 @@ def compile_pgm(
83
101
  node_names: List[str] = render_bayesian_network(pgm, file, check_structure_bayesian=False)
84
102
 
85
103
  # Run Ace
86
- ace_result = subprocess.run(ace_cmd, capture_output=True, text=True)
104
+ ace_result: subprocess.CompletedProcess = subprocess.run(ace_cmd, capture_output=True, text=True)
87
105
  if print_output:
88
106
  print(ace_result.stdout)
89
107
  print(ace_result.stderr)
@@ -126,6 +144,29 @@ def compile_pgm(
126
144
  )
127
145
 
128
146
 
147
+ def ace_available(
148
+ ace_dir: Optional[Path | str] = None,
149
+ jar_dir: Optional[Path | str] = None,
150
+ ) -> bool:
151
+ """
152
+ Returns:
153
+ True if it looks like ACE is available, False otherwise.
154
+ ACE is available if ACE files are in the default location and Java is available.
155
+ """
156
+ try:
157
+ java: str
158
+ java, _ = _find_java()
159
+ _: _AceFiles = _find_ace_files(ace_dir, jar_dir)
160
+
161
+ java_cmd: List[str] = [java, '--version',]
162
+ java_result: subprocess.CompletedProcess = subprocess.run(java_cmd, capture_output=True, text=True)
163
+
164
+ return java_result.returncode == 0
165
+
166
+ except RuntimeError:
167
+ return False
168
+
169
+
129
170
  def copy_ace_to_default_location(
130
171
  ace_dir: Path | str,
131
172
  jar_dir: Optional[Path | str] = None,
@@ -179,6 +220,9 @@ def _find_java() -> Tuple[str, str]:
179
220
 
180
221
  Returns:
181
222
  (java, classpath_separator)
223
+
224
+ Raises:
225
+ RuntimeError: if not found, including a helpful message.
182
226
  """
183
227
  if sys.platform == 'win32':
184
228
  return 'java.exe', ';'
@@ -198,7 +242,7 @@ def _find_ace_files(
198
242
  Look for the needed Ace files.
199
243
 
200
244
  Raises:
201
- RuntimeError: if not found, including a helpful message .
245
+ RuntimeError: if not found, including a helpful message.
202
246
  """
203
247
  ace_dir: Path = default_ace_location() if ace_dir is None else Path(ace_dir)
204
248
  jar_dir: Path = ace_dir if jar_dir is None else Path(jar_dir)
@@ -12,19 +12,21 @@ from ck.pgm_compiler.support.join_tree import *
12
12
 
13
13
  _NEG_INF = float('-inf')
14
14
 
15
+ DEFAULT_PRODUCT_SEARCH_LIMIT: int = 1000
16
+
15
17
 
16
18
  def compile_pgm(
17
19
  pgm: PGM,
18
20
  const_parameters: bool = True,
19
21
  *,
20
22
  algorithm: JoinTreeAlgorithm = MIN_FILL_THEN_DEGREE,
21
- limit_product_tree_search: int = 1000,
23
+ limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
22
24
  pre_prune_factor_tables: bool = True,
23
25
  ) -> PGMCircuit:
24
26
  """
25
27
  Compile the PGM to an arithmetic circuit, using factor elimination.
26
28
 
27
- When forming the product of factors withing a join tree nodes,
29
+ When forming the product of factors within join tree nodes,
28
30
  this method searches all practical binary trees for forming products,
29
31
  up to the given limit, `limit_product_tree_search`. The minimum is 1.
30
32
 
@@ -57,7 +59,7 @@ def compile_pgm_best_jointree(
57
59
  pgm: PGM,
58
60
  const_parameters: bool = True,
59
61
  *,
60
- limit_product_tree_search: int = 1000,
62
+ limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
61
63
  pre_prune_factor_tables: bool = True,
62
64
  ) -> PGMCircuit:
63
65
  """
@@ -111,7 +113,7 @@ def compile_pgm_best_jointree(
111
113
  def join_tree_to_circuit(
112
114
  join_tree: JoinTree,
113
115
  const_parameters: bool = True,
114
- limit_product_tree_search: int = 1000,
116
+ limit_product_tree_search: int = DEFAULT_PRODUCT_SEARCH_LIMIT,
115
117
  pre_prune_factor_tables: bool = True,
116
118
  ) -> PGMCircuit:
117
119
  """
@@ -52,10 +52,15 @@ def compile_pgm(
52
52
  multiply_indicators=True,
53
53
  pre_prune_factor_tables=pre_prune_factor_tables,
54
54
  )
55
- dtree: _DTree = _make_dtree(elimination_order, factor_tables)
56
55
 
57
- states: List[Sequence[int]] = [tuple(range(len(rv))) for rv in pgm.rvs]
58
- top: CircuitNode = dtree.make_circuit(states, factor_tables.circuit)
56
+ if pgm.number_of_factors == 0:
57
+ # Deal with special case: no factors
58
+ top: CircuitNode = factor_tables.circuit.const(1)
59
+ else:
60
+ dtree: _DTree = _make_dtree(elimination_order, factor_tables)
61
+ states: List[Sequence[int]] = [tuple(range(len(rv))) for rv in pgm.rvs]
62
+ top: CircuitNode = dtree.make_circuit(states, factor_tables.circuit)
63
+
59
64
  top.circuit.remove_unreachable_op_nodes(top)
60
65
 
61
66
  return PGMCircuit(
@@ -9,7 +9,7 @@ from ck.pgm import PGM
9
9
 
10
10
  # A VEObjective is a variable elimination objective function.
11
11
  # An objective function is a function from a random variable index (int)
12
- # to an objected value (float or int). This is used to select
12
+ # to an objective value (float or int). This is used to select
13
13
  # a random variable to eliminate in `ve_greedy_min`.
14
14
  VEObjective = Callable[[int], int | float]
15
15
 
@@ -69,13 +69,13 @@ def compile_pgm(
69
69
  tables_with_rv.append(product(x, y))
70
70
  next_tables.append(sum_out(tables_with_rv[0], (rv_idx,)))
71
71
  cur_tables = next_tables
72
- # All rvs are now eliminated
72
+
73
+ # All rvs are now eliminated - all tables should have a single top.
73
74
  tops: List[CircuitNode] = [
74
75
  table.top()
75
76
  for table in cur_tables
76
- if len(table) > 0
77
77
  ]
78
- top = factor_tables.circuit.optimised_add(tops)
78
+ top: CircuitNode = factor_tables.circuit.optimised_mul(tops)
79
79
  top.circuit.remove_unreachable_op_nodes(top)
80
80
 
81
81
  return PGMCircuit(
@@ -7,6 +7,9 @@ from ck.probability.probability_space import ProbabilitySpace, Condition, check_
7
7
  class EmpiricalProbabilitySpace(ProbabilitySpace):
8
8
  def __init__(self, rvs: Sequence[RandomVariable], samples: Iterable[Instance]):
9
9
  """
10
+ Enable probabilistic queries over a sample from a sample space.
11
+ Note that this is not necessarily an efficient approach to calculating probabilities and statistics.
12
+
10
13
  Assumes:
11
14
  len(sample) == len(rvs), for each sample in samples.
12
15
  0 <= sample[i] < len(rvs[i]), for each sample in samples, for i in range(len(rvs)).
@@ -0,0 +1,32 @@
1
+ from typing import Sequence, Iterable, Tuple, Dict, List
2
+
3
+ from ck.pgm import RandomVariable, Indicator, Instance, PGM
4
+ from ck.probability.probability_space import ProbabilitySpace, Condition, check_condition
5
+
6
+
7
+ class PGMProbabilitySpace(ProbabilitySpace):
8
+ def __init__(self, pgm: PGM):
9
+ """
10
+ Enable probabilistic queries directly on a PGM.
11
+ Note that this is not necessarily an efficient approach to calculating probabilities and statistics.
12
+
13
+ Args:
14
+ pgm: The PGM to query.
15
+ """
16
+ self._pgm = pgm
17
+ self._z = None
18
+
19
+ @property
20
+ def rvs(self) -> Sequence[RandomVariable]:
21
+ return self._pgm.rvs
22
+
23
+ def wmc(self, *condition: Condition) -> float:
24
+ condition: Tuple[Indicator, ...] = check_condition(condition)
25
+ return self._pgm.value_product_indicators(*condition)
26
+
27
+ @property
28
+ def z(self) -> float:
29
+ if self._z is None:
30
+ self._z = self._pgm.value_product_indicators()
31
+ return self._z
32
+
@@ -3,6 +3,7 @@ An abstract class for object providing probabilities.
3
3
  """
4
4
  import math
5
5
  from abc import ABC, abstractmethod
6
+ from itertools import chain
6
7
  from typing import Sequence, Tuple, Iterable, Callable
7
8
 
8
9
  import numpy as np
@@ -75,15 +76,38 @@ class ProbabilitySpace(ABC):
75
76
  the probability of the given indicators, conditioned on the given conditions.
76
77
  """
77
78
  condition: Tuple[Indicator, ...] = check_condition(condition)
79
+
78
80
  if len(condition) == 0:
79
81
  z = self.z
82
+ if z <= 0:
83
+ return np.nan
80
84
  else:
81
85
  z = self.wmc(*condition)
82
- indicators += condition # append the condition
83
- if z > 0:
84
- return self.wmc(*indicators) / z
85
- else:
86
- return np.nan
86
+ if z <= 0:
87
+ return np.nan
88
+
89
+ # Combine the indicators with the condition
90
+ # If a variable is mentioned in both the indicators and condition, then
91
+ # we need to take the intersection, and check for contradictions.
92
+ # If a variable is mentioned in the condition but not indicators, then
93
+ # the rv condition needs to be added to the indicators.
94
+ indicator_groups: MapSet[int, Indicator] = _group_indicators(indicators)
95
+ condition_groups: MapSet[int, Indicator] = _group_indicators(condition)
96
+
97
+ for rv_idx, indicators in condition_groups.items():
98
+ indicator_group = indicator_groups.get(rv_idx)
99
+ if indicator_group is None:
100
+ indicator_groups.add_all(rv_idx, indicators)
101
+ else:
102
+ indicator_group.intersection_update(indicators)
103
+ if len(indicator_group) == 0:
104
+ # A contradiction between the indicators and conditions
105
+ return 0.0
106
+
107
+ # Collect all the indicators from the updated indicator_groups
108
+ indicators = chain(*indicator_groups.values())
109
+
110
+ return self.wmc(*indicators) / z
87
111
 
88
112
  def marginal_distribution(self, *rvs: RandomVariable, condition: Condition = ()) -> NDArrayNumeric:
89
113
  """
@@ -160,9 +184,7 @@ class ProbabilitySpace(ABC):
160
184
  assert len(rv_indexes) == len(rvs), 'duplicated random variables not allowed'
161
185
 
162
186
  # Group conditioning indicators by random variable.
163
- conditions_by_rvs = MapSet()
164
- for ind in condition:
165
- conditions_by_rvs.get_set(ind.rv_idx).add(ind.state_idx)
187
+ conditions_by_rvs = _group_states(condition)
166
188
 
167
189
  # See if any MAP random variable is also conditioned.
168
190
  # Reduce the state space of any conditioned MAP rv.
@@ -195,12 +217,12 @@ class ProbabilitySpace(ABC):
195
217
  # Loop over the state space of the 'loop' rvs
196
218
  best_probability = float('-inf')
197
219
  best_states = None
198
- inds: Tuple[Indicator, ...]
199
- for inds in _combos(loop_rvs):
200
- probability = self.wmc(*(inds + new_conditions))
220
+ indicators: Tuple[Indicator, ...]
221
+ for indicators in _combos(loop_rvs):
222
+ probability = self.wmc(*(indicators + new_conditions))
201
223
  if probability > best_probability:
202
224
  best_probability = probability
203
- best_states = tuple(ind.state_idx for ind in inds)
225
+ best_states = tuple(ind.state_idx for ind in indicators)
204
226
  condition_probability = self.wmc(*condition)
205
227
  return best_probability / condition_probability, best_states
206
228
 
@@ -551,6 +573,38 @@ def dtype_for_state_indexes(rvs: Iterable[RandomVariable]) -> DTypeStates:
551
573
  return dtype_for_number_of_states(max((len(rv) for rv in rvs), default=0))
552
574
 
553
575
 
576
+ def _group_indicators(indicators: Iterable[Indicator]) -> MapSet[int, Indicator]:
577
+ """
578
+ Group the given indicators by rv_idx.
579
+
580
+ Args:
581
+ indicators: the indicators to group.
582
+
583
+ Returns:
584
+ A mapping from rv_idx to set of indicators.
585
+ """
586
+ groups: MapSet[int, Indicator] = MapSet()
587
+ for indicator in indicators:
588
+ groups.add(indicator.rv_idx, indicator)
589
+ return groups
590
+
591
+
592
+ def _group_states(indicators: Iterable[Indicator]) -> MapSet[int, int]:
593
+ """
594
+ Group the given indicator states by rv_idx.
595
+
596
+ Args:
597
+ indicators: the indicators to group.
598
+
599
+ Returns:
600
+ A mapping from rv_idx to set of state indexes.
601
+ """
602
+ groups: MapSet[int, int] = MapSet()
603
+ for indicator in indicators:
604
+ groups.add(indicator.rv_idx, indicator.state_idx)
605
+ return groups
606
+
607
+
554
608
  def _normalise_marginal(distribution: NDArrayFloat64) -> None:
555
609
  """
556
610
  Update the values in the given distribution to
ck/program/program.py CHANGED
@@ -1,3 +1,6 @@
1
+ """
2
+ For more documentation on this module, refer to the Jupyter notebook docs/6_circuits_and_programs.ipynb.
3
+ """
1
4
  from typing import Callable, Sequence
2
5
 
3
6
  import numpy as np
@@ -8,7 +11,12 @@ from ck.utils.np_extras import DTypeNumeric, NDArrayNumeric
8
11
 
9
12
  class Program:
10
13
  """
11
- A Program wraps a RawProgram to make a convenient callable.
14
+ A program represents an arithmetic a function from input values to output values.
15
+
16
+ Internally a `Program` wraps a `RawProgram` which is the object returned by a circuit compiler.
17
+
18
+ Every `Program` has a numpy `dtype` which defines the numeric data type for input and output values.
19
+ Typically, the `dtype` of a program is a C style double.
12
20
  """
13
21
 
14
22
  def __init__(self, raw_program: RawProgram):
ck/program/raw_program.py CHANGED
@@ -9,8 +9,8 @@ from ck.utils.np_extras import NDArrayNumeric, DTypeNumeric
9
9
 
10
10
  # RawProgramFunction is a function of three ctypes arrays, returning nothing.
11
11
  # Args:
12
- # [0]: input parameter values,
13
- # [1]: working memory,
12
+ # [0]: input values,
13
+ # [1]: temporary working memory,
14
14
  # [2]: output values.
15
15
  RawProgramFunction = Callable[[ct.POINTER, ct.POINTER, ct.POINTER], None]
16
16
 
@@ -18,10 +18,16 @@ RawProgramFunction = Callable[[ct.POINTER, ct.POINTER, ct.POINTER], None]
18
18
  @dataclass
19
19
  class RawProgram:
20
20
  """
21
+ A raw program is returned by a circuit compiler to provide execution of
22
+ the function defined by a compiled circuit.
23
+
24
+ A `RawProgram` is a `Callable` with the signature:
25
+
26
+
21
27
  Fields:
22
28
  function: is a function of three ctypes arrays, returning nothing.
23
29
  dtype: the numpy data type of the array values.
24
- number_of_vars: the number of input parameter values (first function argument).
30
+ number_of_vars: the number of input values (first function argument).
25
31
  number_of_tmps: the number of working memory values (second function argument).
26
32
  number_of_results: the number of result values (third function argument).
27
33
  var_indices: maps the index of inputs (from 0 to self.number_of_vars - 1) to the index
@@ -13,7 +13,7 @@ from ck.utils.random_extras import Random
13
13
  # Type of a yield function. Support for a sampler.
14
14
  # A yield function may be used to implement a sampler's iterator, thus
15
15
  # it provides an Instance or single state index.
16
- YieldF = Callable[[NDArrayStates], Instance | int]
16
+ YieldF = Callable[[NDArrayStates], int] | Callable[[NDArrayStates], Instance]
17
17
 
18
18
 
19
19
  @dataclass
@@ -1,15 +1,15 @@
1
- from typing import Set, List, Iterator, Optional, Sequence
2
1
  import random
2
+ from typing import Set, List, Iterator, Optional, Sequence
3
3
 
4
4
  import numpy as np
5
5
 
6
6
  from ck.pgm import Instance, RandomVariable, Indicator
7
7
  from ck.probability.probability_space import dtype_for_state_indexes, Condition, check_condition
8
- from .sampler import Sampler
9
- from .sampler_support import YieldF
10
8
  from ck.utils.map_set import MapSet
11
9
  from ck.utils.np_extras import DType
12
10
  from ck.utils.random_extras import Random
11
+ from .sampler import Sampler
12
+ from .sampler_support import YieldF
13
13
 
14
14
 
15
15
  class UniformSampler(Sampler):
@@ -39,11 +39,15 @@ class UniformSampler(Sampler):
39
39
  conditioned_rvs.add(ind.rv_idx, ind.state_idx)
40
40
 
41
41
  def get_possible_states(_rv: RandomVariable) -> List[int]:
42
+ """
43
+ Get the allowable states for a given random variable, given
44
+ conditions in `conditioned_rvs`.
45
+ """
42
46
  condition_states: Optional[Set[int]] = conditioned_rvs.get(_rv.idx)
43
47
  if condition_states is None:
44
48
  return list(range(len(_rv)))
45
49
  else:
46
- return [state_idx for state_idx in range(len(_rv)) if state_idx not in condition_states]
50
+ return list(condition_states)
47
51
 
48
52
  possible_states: List[List[int]] = [
49
53
  get_possible_states(rv)
@@ -63,4 +67,6 @@ class UniformSampler(Sampler):
63
67
  for i, l in enumerate(possible_states):
64
68
  state_idx = rand.randrange(0, len(l))
65
69
  state[i] = l[state_idx]
70
+ # We know the yield function will always provide either ints or Instances
71
+ # noinspection PyTypeChecker
66
72
  yield yield_f(state)
@@ -8,7 +8,7 @@ from ck.program.program_buffer import ProgramBuffer
8
8
  from ck.program.raw_program import RawProgram
9
9
  from ck.sampling.sampler import Sampler
10
10
  from ck.sampling.sampler_support import SampleRV, YieldF, SamplerInfo
11
- from ck.utils.np_extras import NDArrayNumeric
11
+ from ck.utils.np_extras import NDArrayNumeric, NDArrayStates
12
12
  from ck.utils.random_extras import Random
13
13
 
14
14
 
@@ -52,7 +52,7 @@ class WMCDirectSampler(Sampler):
52
52
  return program_buffer.compute().item()
53
53
 
54
54
  # Set up working memory buffers
55
- states = np.zeros(len(sample_rvs), dtype=self._state_dtype)
55
+ states: NDArrayStates = np.zeros(len(sample_rvs), dtype=self._state_dtype)
56
56
  buff_slots = np.zeros(self._max_number_of_states, dtype=np.uintp)
57
57
  buff_states = np.zeros(self._max_number_of_states, dtype=self._state_dtype)
58
58
 
@@ -153,6 +153,8 @@ class WMCDirectSampler(Sampler):
153
153
  slots[slot] = 1
154
154
  states[sample_rv.index] = state
155
155
 
156
+ # We know the yield function will always provide either ints or Instances
157
+ # noinspection PyTypeChecker
156
158
  yield yield_f(states)
157
159
 
158
160
  # Reset the one slots for the next iteration.
@@ -65,12 +65,16 @@ class WMCGibbsSampler(Sampler):
65
65
  if skip == 0:
66
66
  while True:
67
67
  self._next_sample_gibbs(sample_rvs, slots_1, program_buffer, prs, state, rand)
68
+ # We know the yield function will always provide either ints or Instances
69
+ # noinspection PyTypeChecker
68
70
  yield yield_f(state)
69
71
  else:
70
72
  while True:
71
73
  for _ in range(skip):
72
74
  self._next_sample_gibbs(sample_rvs, slots_1, program_buffer, prs, state, rand)
73
75
  self._next_sample_gibbs(sample_rvs, slots_1, program_buffer, prs, state, rand)
76
+ # We know the yield function will always provide either ints or Instances
77
+ # noinspection PyTypeChecker
74
78
  yield yield_f(state)
75
79
 
76
80
  else:
@@ -79,6 +83,8 @@ class WMCGibbsSampler(Sampler):
79
83
  for _ in range(skip):
80
84
  self._next_sample_gibbs(sample_rvs, slots_1, program_buffer, prs, state, rand)
81
85
  self._next_sample_gibbs(sample_rvs, slots_1, program_buffer, prs, state, rand)
86
+ # We know the yield function will always provide either ints or Instances
87
+ # noinspection PyTypeChecker
82
88
  yield yield_f(state)
83
89
  if rand.random() < pr_restart:
84
90
  # Set an initial system state
@@ -1,4 +1,4 @@
1
- from typing import Collection, Iterator, List, Sequence
1
+ from typing import Collection, Iterator, Sequence
2
2
 
3
3
  import numpy as np
4
4
 
@@ -78,12 +78,16 @@ class WMCMetropolisSampler(Sampler):
78
78
  if skip == 0:
79
79
  while True:
80
80
  w = self._next_sample_metropolis(possibles, program_buffer, state, w, rand)
81
+ # We know the yield function will always provide either ints or Instances
82
+ # noinspection PyTypeChecker
81
83
  yield yield_f(state)
82
84
  else:
83
85
  while True:
84
86
  for _ in range(skip):
85
87
  w = self._next_sample_metropolis(possibles, program_buffer, state, w, rand)
86
88
  w = self._next_sample_metropolis(possibles, program_buffer, state, w, rand)
89
+ # We know the yield function will always provide either ints or Instances
90
+ # noinspection PyTypeChecker
87
91
  yield yield_f(state)
88
92
 
89
93
  else:
@@ -92,6 +96,8 @@ class WMCMetropolisSampler(Sampler):
92
96
  for _ in range(skip):
93
97
  w = self._next_sample_metropolis(possibles, program_buffer, state, w, rand)
94
98
  w = self._next_sample_metropolis(possibles, program_buffer, state, w, rand)
99
+ # We know the yield function will always provide either ints or Instances
100
+ # noinspection PyTypeChecker
95
101
  yield yield_f(state)
96
102
 
97
103
  if rand.random() < pr_restart:
@@ -91,6 +91,8 @@ class WMCRejectionSampler(Sampler):
91
91
  w: float = wmc()
92
92
 
93
93
  if rand.random() * self._w_max < w:
94
+ # We know the yield function will always provide either ints or Instances
95
+ # noinspection PyTypeChecker
94
96
  yield yield_f(state)
95
97
 
96
98
  # Update w_not_seen and w_high to adapt w_max.