compiled-knowledge 4.0.0a24__cp313-cp313-win32.whl → 4.1.0a1__cp313-cp313-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (42) hide show
  1. ck/circuit/_circuit_cy.c +1 -1
  2. ck/circuit/_circuit_cy.cp313-win32.pyd +0 -0
  3. ck/circuit/tmp_const.py +5 -4
  4. ck/circuit_compiler/cython_vm_compiler/_compiler.c +152 -152
  5. ck/circuit_compiler/cython_vm_compiler/_compiler.cp313-win32.pyd +0 -0
  6. ck/circuit_compiler/interpret_compiler.py +2 -2
  7. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
  8. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp313-win32.pyd +0 -0
  9. ck/circuit_compiler/support/llvm_ir_function.py +4 -4
  10. ck/dataset/__init__.py +1 -0
  11. ck/dataset/cross_table.py +270 -0
  12. ck/dataset/cross_table_probabilities.py +53 -0
  13. ck/dataset/dataset.py +577 -0
  14. ck/dataset/dataset_compute.py +140 -0
  15. ck/dataset/dataset_from_crosstable.py +45 -0
  16. ck/dataset/dataset_from_csv.py +147 -0
  17. ck/dataset/sampled_dataset.py +96 -0
  18. ck/example/diamond_square.py +3 -1
  19. ck/example/triangle_square.py +3 -1
  20. ck/example/truss.py +3 -1
  21. ck/in_out/parse_net.py +21 -19
  22. ck/in_out/parser_utils.py +7 -3
  23. ck/learning/__init__.py +0 -0
  24. ck/learning/train_generative.py +149 -0
  25. ck/pgm.py +95 -84
  26. ck/pgm_circuit/mpe_program.py +3 -4
  27. ck/pgm_circuit/pgm_circuit.py +27 -18
  28. ck/pgm_circuit/program_with_slotmap.py +27 -46
  29. ck/pgm_circuit/support/compile_circuit.py +2 -4
  30. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
  31. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp313-win32.pyd +0 -0
  32. ck/probability/empirical_probability_space.py +1 -0
  33. ck/probability/probability_space.py +10 -11
  34. ck/program/raw_program.py +23 -16
  35. ck/sampling/sampler_support.py +5 -6
  36. ck/utils/iter_extras.py +3 -2
  37. ck/utils/local_config.py +16 -8
  38. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0a1.dist-info}/METADATA +1 -1
  39. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0a1.dist-info}/RECORD +42 -32
  40. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0a1.dist-info}/WHEEL +0 -0
  41. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0a1.dist-info}/licenses/LICENSE.txt +0 -0
  42. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0a1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,149 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Tuple, List
3
+
4
+ import numpy as np
5
+
6
+ from ck.dataset import SoftDataset, HardDataset
7
+ from ck.dataset.cross_table import CrossTable, cross_table_from_dataset
8
+ from ck.pgm import PGM, Instance, DensePotentialFunction, Shape, natural_key_idx, SparsePotentialFunction
9
+ from ck.utils.iter_extras import multiply
10
+ from ck.utils.np_extras import NDArrayFloat64
11
+
12
+
13
+ @dataclass
14
+ class ParameterValues:
15
+ """
16
+ A ParameterValues object represents learned parameter values of a PGM.
17
+ """
18
+ pgm: PGM
19
+ """
20
+ The PGM that the parameter values pertains to.
21
+ """
22
+
23
+ cpts: List[Dict[Instance, NDArrayFloat64]]
24
+ """
25
+ A list of CPTs co-indexed with `pgm.factors`. Each CPT is a dict
26
+ mapping from instances of the parent random variables (of the factors)
27
+ to the child conditional probability distribution (CPD).
28
+ """
29
+
30
+ def set_zero(self) -> None:
31
+ """
32
+ Set the potential function of each PGM factor to zero.
33
+ """
34
+ for factor in self.pgm.factors:
35
+ factor.set_zero()
36
+
37
+ def set_cpt(self) -> None:
38
+ """
39
+ Set the potential function of each PGM factor to a CPTPotentialFunction,
40
+ using our parameter values.
41
+ """
42
+ for factor, cpt in zip(self.pgm.factors, self.cpts):
43
+ factor.set_cpt().set(*cpt.items())
44
+
45
+ def set_dense(self) -> None:
46
+ """
47
+ Set the potential function of each PGM factor to a DensePotentialFunction,
48
+ using our parameter values.
49
+ """
50
+ for factor, cpt in zip(self.pgm.factors, self.cpts):
51
+ pot_function: DensePotentialFunction = factor.set_dense()
52
+ parent_shape: Shape = factor.shape[1:]
53
+ child_state: int
54
+ value: float
55
+ if len(parent_shape) == 0:
56
+ cpd: NDArrayFloat64 = cpt[()]
57
+ for child_state, value in enumerate(cpd):
58
+ pot_function[child_state] = value
59
+ else:
60
+ parent_space: int = multiply(parent_shape)
61
+ parent_states: Instance
62
+ cpd: NDArrayFloat64
63
+ for parent_states, cpd in cpt.items():
64
+ idx: int = natural_key_idx(parent_shape, parent_states)
65
+ for value in cpd:
66
+ pot_function[idx] = value
67
+ idx += parent_space
68
+
69
+ def set_sparse(self) -> None:
70
+ """
71
+ Set the potential function of each PGM factor to a SparsePotentialFunction,
72
+ using our parameter values.
73
+ """
74
+ for factor, cpt in zip(self.pgm.factors, self.cpts):
75
+ pot_function: SparsePotentialFunction = factor.set_sparse()
76
+ parent_states: Instance
77
+ child_state: int
78
+ cpd: NDArrayFloat64
79
+ value: float
80
+ for parent_states, cpd in cpt.items():
81
+ for child_state, value in enumerate(cpd):
82
+ key = (child_state,) + parent_states
83
+ pot_function[key] = value
84
+
85
+
86
+ def train_generative_bn(
87
+ pgm: PGM,
88
+ dataset: HardDataset | SoftDataset,
89
+ *,
90
+ dirichlet_prior: float = 0,
91
+ check_bayesian_network: bool = True,
92
+ ) -> ParameterValues:
93
+ """
94
+ Maximum-likelihood, generative training for a Bayesian network.
95
+
96
+ Args:
97
+ pgm: the probabilistic graphical model defining the model structure.
98
+ Potential function values are ignored and need not be set.
99
+ dataset: a dataset of random variable states.
100
+ dirichlet_prior: a real number >= 0. See `CrossTable` for an explanation.
101
+ check_bayesian_network: if true and not pgm.is_structure_bayesian an exception will be raised.
102
+
103
+ Returns:
104
+ a ParameterValues object that can be used to update the parameters of the given PGM.
105
+
106
+ Raises:
107
+ ValueError: if the given PGM does not have a Bayesian network structure, and check_bayesian_network is True.
108
+ """
109
+ if check_bayesian_network and not pgm.is_structure_bayesian:
110
+ raise ValueError('the given PGM is not a Bayesian network')
111
+ cpts: List[Dict[Instance, NDArrayFloat64]] = [
112
+ cpt_from_crosstab(cross_table_from_dataset(dataset, factor.rvs, dirichlet_prior=dirichlet_prior))
113
+ for factor in pgm.factors
114
+ ]
115
+ return ParameterValues(pgm, cpts)
116
+
117
+
118
+ def cpt_from_crosstab(crosstab: CrossTable) -> Dict[Instance, NDArrayFloat64]:
119
+ """
120
+ Make a conditional probability table (CPT) from a cross-table.
121
+
122
+ Args:
123
+ crosstab: a CrossTable representing the weight of unique instances.
124
+
125
+ Returns:
126
+ a mapping from instances of the parent random variables to the child
127
+ conditional probability distribution (CPD).
128
+
129
+ Assumes:
130
+ the first random variable in `crosstab.rvs` is the child random variable.
131
+ """
132
+ # Number of states for the child random variable.
133
+ child_size: int = len(crosstab.rvs[0])
134
+
135
+ # Get distribution over child states for seen parent states
136
+ parents_weights: Dict[Instance, NDArrayFloat64] = {}
137
+ for state, weight in crosstab.items():
138
+ parent_state: Tuple[int, ...] = state[1:]
139
+ child_state: int = state[0]
140
+ parent_weights = parents_weights.get(parent_state)
141
+ if parent_weights is None:
142
+ parents_weights[parent_state] = parent_weights = np.zeros(child_size, dtype=np.float64)
143
+ parent_weights[child_state] += weight
144
+
145
+ # Normalise
146
+ for parent_state, parent_weights in parents_weights.items():
147
+ parent_weights /= parent_weights.sum()
148
+
149
+ return parents_weights
ck/pgm.py CHANGED
@@ -15,33 +15,34 @@ from ck.utils.iter_extras import (
15
15
  from ck.utils.np_extras import NDArrayFloat64, NDArrayUInt8
16
16
 
17
17
  State: TypeAlias = Union[int, str, bool, float, None]
18
- State.__doc__ = \
19
- """
20
- The type for a possible state of a random variable.
21
- """
22
-
23
- Instance: TypeAlias = Sequence[int]
24
- Instance.__doc__ = \
25
- """
26
- An instance (of a sequence of random variables) is a sequence of integers
27
- that are state indexes, co-indexed with a known sequence of random variables.
28
- """
29
-
30
- Key: TypeAlias = Union[Instance, int]
31
- Key.__doc__ = \
32
- """
33
- A key identifies an instance, either as an instance itself or a
34
- single integer, representing an instance with one dimension.
35
- """
18
+ """
19
+ The type for a possible state of a random variable.
20
+ """
21
+
22
+ Instance: TypeAlias = Tuple[int, ...]
23
+ """
24
+ An instance (of a sequence of random variables) is a tuple of integers
25
+ that are state indexes, co-indexed with a known sequence of random variables.
26
+ """
27
+
28
+ Key: TypeAlias = Union[Sequence[int], int]
29
+ """
30
+ A key identifies an instance, either as a sequence of integers or a
31
+ single integer. The integers are state indexes, co-indexed with a known
32
+ sequence of random variables. A single integer represents an instance with
33
+ one dimension.
34
+ """
36
35
 
37
36
  Shape: TypeAlias = Sequence[int]
38
- Key.__doc__ = \
39
- """
40
- The type for the "shape" of a sequence of random variables.
41
- That is, the shape of (rv1, rv2, rv3) is (len(rv1), len(rv2), len(rv3)).
42
- """
37
+ """
38
+ The type for the "shape" of a sequence of random variables.
39
+ That is, the shape of (rv1, rv2, rv3) is (len(rv1), len(rv2), len(rv3)).
40
+ """
43
41
 
44
- DEFAULT_CPT_TOLERANCE: float = 0.000001 # A tolerance when checking CPT distributions sum to one (or zero).
42
+ DEFAULT_CPT_TOLERANCE: float = 0.000001
43
+ """
44
+ A tolerance when checking CPT distributions sum to one (or zero).
45
+ """
45
46
 
46
47
 
47
48
  class PGM:
@@ -214,14 +215,17 @@ class PGM:
214
215
  The returned random variable will have an `idx` equal to the value of
215
216
  `self.number_of_rvs` just prior to adding the new random variable.
216
217
 
218
+ The states of the random variable can be specified either as an integer
219
+ representing the number of states, or as a sequence of state values. If a
220
+ single integer, `n`, is provided then the states will be: 0, 1, ..., n-1.
221
+ If a sequence of states are provided then the states must be unique.
222
+
217
223
  Assumes:
218
224
  Provided states contain no duplicates.
219
225
 
220
226
  Args:
221
227
  name: a name for the random variable.
222
- states: either an integer number of states or a sequence of state values. If a
223
- single integer, `n`, is provided then the states will be 0, 1, ..., n-1.
224
- If a sequence of states are provided then the states must be unique.
228
+ states: either the number of states or a sequence of state values.
225
229
 
226
230
  Returns:
227
231
  a RandomVariable object belonging to this PGM.
@@ -241,10 +245,11 @@ class PGM:
241
245
 
242
246
  Assumes:
243
247
  The given random variables all belong to this PGM.
248
+
244
249
  The random variables contain no duplicates.
245
250
 
246
251
  Args:
247
- *rvs: the random variables.
252
+ rvs: the random variables.
248
253
 
249
254
  Returns:
250
255
  a Factor object belonging to this PGM.
@@ -336,17 +341,18 @@ class PGM:
336
341
  *input_rvs: RandomVariable
337
342
  ) -> Factor:
338
343
  """
339
- Add a sparse 0/1 factor to this PGM representing:
340
- result_rv == function(*rvs).
341
- That is:
344
+ Add a sparse 0/1 factor to this PGM representing `result_rv == function(*rvs)`.
345
+ That is::
346
+
342
347
  factor[result_s, *input_s] = 1, if result_s == function(*input_s);
343
348
  = 0, otherwise.
349
+
344
350
  Args:
345
351
  function: a function from state indexes of the input random variables to a state index
346
352
  of the result random variable. The function should take the same number of arguments
347
353
  as `input_rvs` and return a state index for `result_rv`.
348
354
  result_rv: the random variable defining result values.
349
- *input_rvs: the random variables defining input values.
355
+ input_rvs: the random variables defining input values.
350
356
 
351
357
  Returns:
352
358
  a Factor object belonging to this PGM, with a configured sparse potential function.
@@ -378,16 +384,17 @@ class PGM:
378
384
  """
379
385
  Render indicators as a string.
380
386
 
381
- For example:
387
+ For example::
382
388
  pgm = PGM()
383
389
  a = pgm.new_rv('A', ('x', 'y', 'z'))
384
390
  b = pgm.new_rv('B', (3, 5))
385
391
  print(pgm.indicator_str(a[0], b[1], a[2]))
386
- will print:
392
+
393
+ will print::
387
394
  A=x, B=5, A=z
388
395
 
389
396
  Args:
390
- *indicators: the indicators to render.
397
+ indicators: the indicators to render.
391
398
  sep: the separator to use between the random variable and its state.
392
399
  delim: the delimiter to used when rendering multiple indicators.
393
400
 
@@ -406,16 +413,17 @@ class PGM:
406
413
  """
407
414
  Render indicators as a string, grouping indicators by random variable.
408
415
 
409
- For example:
416
+ For example::
410
417
  pgm = PGM()
411
418
  a = pgm.new_rv('A', ('x', 'y', 'z'))
412
419
  b = pgm.new_rv('B', (3, 5))
413
420
  print(pgm.condition_str(a[0], b[1], a[2]))
414
- will print:
421
+
422
+ will print::
415
423
  A in {x, z}, B=5
416
424
 
417
425
  Args:
418
- *indicators: the indicators to render.
426
+ indicators: the indicators to render.
419
427
  Return:
420
428
  a string representation of the given indicators, as a condition.
421
429
  """
@@ -930,9 +938,9 @@ class RandomVariable(Sequence[Indicator]):
930
938
  in the random variable's PGM list of random variables.
931
939
 
932
940
  A random variable behaves like a sequence of Indicators, where each indicator represents a random
933
- variable being in a particular state. Specifically for a random variable rv, len(rv) is the
941
+ variable being in a particular state. Specifically for a random variable rv, `len(rv)` is the
934
942
  number of states of the random variable and rv[i] is the Indicators representing that
935
- rv is in the ith state. When sliced, the result is a tuple, i.e. rv[1:3] = (rv[1], rv[2]).
943
+ rv is in the ith state. When sliced, the result is a tuple, i.e. `rv[1:3] = (rv[1], rv[2])`.
936
944
 
937
945
  A RandomVariable has a name. This is for human convenience and has no functional purpose
938
946
  within a PGM.
@@ -942,15 +950,18 @@ class RandomVariable(Sequence[Indicator]):
942
950
  """
943
951
  Create a new random variable, in the given PGM.
944
952
 
953
+ The states of the random variable can be specified either as an integer
954
+ representing the number of states, or as a sequence of state values. If a
955
+ single integer, `n`, is provided then the states will be: 0, 1, ..., n-1.
956
+ If a sequence of states are provided then the states must be unique.
957
+
945
958
  Assumes:
946
959
  Provided states contain no duplicates.
947
960
 
948
961
  Args:
949
962
  pgm: the PGM that the random variable will belong to.
950
963
  name: a name for the random variable.
951
- states: either an integer number of states or a sequence of state values. If a
952
- single integer, `n`, is provided then the states will be 0, 1, ..., n-1.
953
- If a sequence of states are provided then the states must be unique.
964
+ states: either the number of states or a sequence of state values.
954
965
  """
955
966
  self._pgm: PGM = pgm
956
967
  self._name: str = name
@@ -1212,15 +1223,14 @@ class RVMap(Sequence[RandomVariable]):
1212
1223
  In addition to accessing a random variable by its index, an RVMap enables
1213
1224
  access to the PGM random variable via the name of each random variable.
1214
1225
 
1215
- For example, if `pgm.rvs[1]` is a random variable named `xray`, then:
1216
- ```
1217
- rvs = RVMap(pgm)
1226
+ For example, if `pgm.rvs[1]` is a random variable named `xray`, then::
1227
+
1228
+ rvs = RVMap(pgm)
1218
1229
 
1219
- # These all retrieve the same random variable object.
1220
- xray = rvs[1]
1221
- xray = rvs('xray')
1222
- xray = rvs.xray
1223
- ```
1230
+ # These all retrieve the same random variable object.
1231
+ xray = rvs[1]
1232
+ xray = rvs('xray')
1233
+ xray = rvs.xray
1224
1234
 
1225
1235
  To use an RVMap on a PGM, the random variable names must be unique across the PGM.
1226
1236
  """
@@ -1527,7 +1537,7 @@ class Factor:
1527
1537
  Set to the potential function to a new `ClausePotentialFunction` object.
1528
1538
 
1529
1539
  Args:
1530
- *key: defines the random variable states of the clause. The key is a sequence of
1540
+ key: defines the random variable states of the clause. The key is a sequence of
1531
1541
  random variable state indexes, co-indexed with `Factor.rvs`.
1532
1542
 
1533
1543
  Returns:
@@ -1556,7 +1566,7 @@ class Factor:
1556
1566
  return self._potential_function
1557
1567
 
1558
1568
 
1559
- @dataclass(frozen=True, eq=True)
1569
+ @dataclass(frozen=True, eq=True, slots=True)
1560
1570
  class ParamId:
1561
1571
  """
1562
1572
  A ParamId identifies a parameter of a potential function.
@@ -1863,7 +1873,7 @@ class PotentialFunction(ABC):
1863
1873
  a hypothetical parameter index assuming that every valid key has a unique parameter
1864
1874
  as per DensePotentialFunction.
1865
1875
  """
1866
- return _natural_key_idx(self._shape, key)
1876
+ return natural_key_idx(self._shape, key)
1867
1877
 
1868
1878
  def param_id(self, param_idx: int) -> ParamId:
1869
1879
  """
@@ -2021,7 +2031,7 @@ class ZeroPotentialFunction(PotentialFunction):
2021
2031
  return 0
2022
2032
 
2023
2033
  def param_idx(self, key: Key) -> int:
2024
- return _natural_key_idx(self._shape, key)
2034
+ return natural_key_idx(self._shape, key)
2025
2035
 
2026
2036
  def is_cpt(self, tolerance=DEFAULT_CPT_TOLERANCE) -> bool:
2027
2037
  return True
@@ -2164,7 +2174,7 @@ class DensePotentialFunction(PotentialFunction):
2164
2174
  """
2165
2175
  Set the values of the potential function using the given iterator.
2166
2176
 
2167
- Mapping instances to *values is as follows:
2177
+ Mapping instances to values is as follows:
2168
2178
  Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
2169
2179
  values[0] represents instance (0,0)
2170
2180
  values[1] represents instance (0,1)
@@ -2209,7 +2219,7 @@ class DensePotentialFunction(PotentialFunction):
2209
2219
  The order of values is the same as set_iter.
2210
2220
 
2211
2221
  Args:
2212
- *value: the values to use.
2222
+ value: the values to use.
2213
2223
 
2214
2224
  Returns:
2215
2225
  self
@@ -2414,7 +2424,7 @@ class SparsePotentialFunction(PotentialFunction):
2414
2424
  """
2415
2425
  Set the values of the potential function using the given iterator.
2416
2426
 
2417
- Mapping instances to *values is as follows:
2427
+ Mapping instances to values is as follows:
2418
2428
  Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
2419
2429
  values[0] represents instance (0,0)
2420
2430
  values[1] represents instance (0,1)
@@ -2636,7 +2646,7 @@ class CompactPotentialFunction(PotentialFunction):
2636
2646
  """
2637
2647
  Set the values of the potential function using the given iterator.
2638
2648
 
2639
- Mapping instances to *values is as follows:
2649
+ Mapping instances to `values` is as follows:
2640
2650
  Given Factor(rv1, rv2) where rv1 has 2 states, and rv2 has 3 states:
2641
2651
  values[0] represents instance (0,0)
2642
2652
  values[1] represents instance (0,1)
@@ -2679,7 +2689,7 @@ class CompactPotentialFunction(PotentialFunction):
2679
2689
  The order of values is the same as set_iter.
2680
2690
 
2681
2691
  Args:
2682
- *value: the values to use.
2692
+ value: the values to use.
2683
2693
 
2684
2694
  Returns:
2685
2695
  self
@@ -3071,7 +3081,8 @@ class CPTPotentialFunction(PotentialFunction):
3071
3081
  Calls self.set_cpd(parent_states, cpd) for each row (parent_states, cpd)
3072
3082
  in rows. Any unmentioned parent states will have zero probabilities.
3073
3083
 
3074
- Example usage, assuming three Boolean random variables:
3084
+ Example usage, assuming three Boolean random variables::
3085
+
3075
3086
  pgm.Factor(x, y, z).set_cpt().set(
3076
3087
  # y z x[0] x[1]
3077
3088
  ((0, 0), (0.1, 0.9)),
@@ -3079,9 +3090,9 @@ class CPTPotentialFunction(PotentialFunction):
3079
3090
  ((1, 0), (0.1, 0.9)),
3080
3091
  ((1, 1), (0.1, 0.9))
3081
3092
  )
3082
-
3093
+
3083
3094
  Args:
3084
- *rows: are tuples (key, cpd) used to set the potential function values.
3095
+ rows: are tuples (key, cpd) used to set the potential function values.
3085
3096
 
3086
3097
  Raises:
3087
3098
  ValueError: if a CPD is not valid.
@@ -3105,7 +3116,7 @@ class CPTPotentialFunction(PotentialFunction):
3105
3116
  Any list entry may be None, indicating 'guaranteed zero' for the associated parent states.
3106
3117
 
3107
3118
  Args:
3108
- *cpds: are the CPDs used to set the potential function values.
3119
+ cpds: are the CPDs used to set the potential function values.
3109
3120
 
3110
3121
  Raises:
3111
3122
  ValueError: if a CPD is not valid.
@@ -3355,26 +3366,7 @@ def rv_instances_as_indicators(*rvs: RandomVariable, flip: bool = False) -> Iter
3355
3366
  return _combos(rvs, flip=not flip)
3356
3367
 
3357
3368
 
3358
- def _key_to_instance(key: Key) -> Instance:
3359
- """
3360
- Convert a key to an instance.
3361
-
3362
- Args:
3363
- key: a key into a state space.
3364
-
3365
- Returns:
3366
- A instance from the state space, as a tuple of state indexes, co-indexed with the given shape.
3367
-
3368
- Assumes:
3369
- The key is valid for the implied state space.
3370
- """
3371
- if isinstance(key, int):
3372
- return (key,)
3373
- else:
3374
- return tuple(key)
3375
-
3376
-
3377
- def _natural_key_idx(shape: Shape, key: Key) -> int:
3369
+ def natural_key_idx(shape: Shape, key: Key) -> int:
3378
3370
  """
3379
3371
  What is the natural index of the given key, assuming the given shape.
3380
3372
 
@@ -3400,6 +3392,25 @@ def _natural_key_idx(shape: Shape, key: Key) -> int:
3400
3392
  return result
3401
3393
 
3402
3394
 
3395
+ def _key_to_instance(key: Key) -> Instance:
3396
+ """
3397
+ Convert a key to an instance.
3398
+
3399
+ Args:
3400
+ key: a key into a state space.
3401
+
3402
+ Returns:
3403
+ A instance from the state space, as a tuple of state indexes, co-indexed with the given shape.
3404
+
3405
+ Assumes:
3406
+ The key is valid for the implied state space.
3407
+ """
3408
+ if isinstance(key, int):
3409
+ return (key,)
3410
+ else:
3411
+ return tuple(key)
3412
+
3413
+
3403
3414
  def _zero_space(shape: Shape) -> int:
3404
3415
  """
3405
3416
  Return the size of the zero space of the given shape. This is the number
@@ -228,10 +228,9 @@ class MPEProgram(ProgramWithSlotmap):
228
228
  class MPEResult:
229
229
  """
230
230
  An MPE result is the result of MPE inference.
231
-
232
- Fields:
233
- wmc: the weighted model count value of the MPE solution.
234
- mpe: The MPE solution instance. If there are ties then this will just be once instance.
235
231
  """
236
232
  wmc: float
233
+ """the weighted model count value of the MPE solution."""
234
+
237
235
  mpe: Instance
236
+ """the MPE solution instance. If there are ties then this will just be once instance."""
@@ -16,33 +16,42 @@ class PGMCircuit:
16
16
  holds the values of the parameters. Specifically, given parameter id `param_id`, then
17
17
  `parameter_values[slot_map[param_id] - number_of_indicators]` is the value of the
18
18
  identified parameter as it was in the PGM.
19
-
20
- Fields:
21
- rvs: holds the random variables from the PGM as it was compiled, in order.
22
-
23
- conditions: any conditions on `rvs` that were compiled into the circuit.
24
-
25
- number_of_indicators: is the number of indicators in `rvs` which is
26
- `sum(len(rv) for rv in rvs`. Specifically, `circuit.vars[i]` is the circuit variable
27
- corresponding to the ith indicator, where `circuit` is `circuit_top.circuit` and
28
- indicators are ordered as per `rvs`.
29
-
30
- number_of_parameters: is the number of parameters from the PGM that are
31
- represented as circuit variables. This may be zero if parameters from the PGM
32
- were compiled as constants.
33
-
34
- slot_map[x]: gives the index of the circuit variable corresponding to x,
35
- where x is either a random variable indicator (Indicator) or a parameter id (ParamId).
36
-
37
19
  """
38
20
 
39
21
  rvs: Sequence[RandomVariable]
22
+ """holds the random variables from the PGM as it was compiled, in order."""
23
+
40
24
  conditions: Sequence[Indicator]
25
+ """any conditions on `rvs` that were compiled into the circuit."""
26
+
41
27
  circuit_top: CircuitNode
28
+ """the top circuit node defining the network function."""
29
+
42
30
  number_of_indicators: int
31
+ """
32
+ the number of indicators in `rvs` which is
33
+ `sum(len(rv) for rv in rvs`. Specifically, `circuit.vars[i]` is the circuit variable
34
+ corresponding to the ith indicator, where `circuit` is `circuit_top.circuit` and
35
+ indicators are ordered as per `rvs`.
36
+ """
37
+
43
38
  number_of_parameters: int
39
+ """
40
+ the number of parameters from the PGM that are
41
+ represented as circuit variables. This may be zero if parameters from the PGM
42
+ were compiled as constants.
43
+ """
44
+
44
45
  slot_map: SlotMap
46
+ """
47
+ gives the index of the circuit variable corresponding to x,
48
+ where x is either a random variable indicator (Indicator) or a parameter id (ParamId).
49
+ """
50
+
45
51
  parameter_values: NDArray
52
+ """
53
+ parameter values, co-indexed with the circuit variables, counting beyond `number_of_indicators`.
54
+ """
46
55
 
47
56
  def dump(self, *, prefix: str = '', indent: str = ' ') -> None:
48
57
  """