compiled-knowledge 4.0.0a24__cp312-cp312-win32.whl → 4.1.0__cp312-cp312-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/circuit/_circuit_cy.c +1 -1
- ck/circuit/_circuit_cy.cp312-win32.pyd +0 -0
- ck/circuit/tmp_const.py +5 -4
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +152 -152
- ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win32.pyd +0 -0
- ck/circuit_compiler/interpret_compiler.py +2 -2
- ck/circuit_compiler/llvm_compiler.py +4 -4
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win32.pyd +0 -0
- ck/circuit_compiler/support/input_vars.py +4 -4
- ck/circuit_compiler/support/llvm_ir_function.py +4 -4
- ck/dataset/__init__.py +1 -0
- ck/dataset/cross_table.py +334 -0
- ck/dataset/dataset.py +682 -0
- ck/dataset/dataset_builder.py +519 -0
- ck/dataset/dataset_compute.py +140 -0
- ck/dataset/dataset_from_crosstable.py +64 -0
- ck/dataset/dataset_from_csv.py +151 -0
- ck/dataset/sampled_dataset.py +96 -0
- ck/example/diamond_square.py +3 -1
- ck/example/triangle_square.py +3 -1
- ck/example/truss.py +3 -1
- ck/in_out/parse_net.py +21 -19
- ck/in_out/parser_utils.py +7 -3
- ck/learning/__init__.py +0 -0
- ck/learning/coalesce_cross_tables.py +403 -0
- ck/learning/model_from_cross_tables.py +296 -0
- ck/learning/parameters.py +117 -0
- ck/learning/train_generative_bn.py +198 -0
- ck/pgm.py +105 -92
- ck/pgm_circuit/marginals_program.py +5 -0
- ck/pgm_circuit/mpe_program.py +3 -4
- ck/pgm_circuit/pgm_circuit.py +27 -18
- ck/pgm_circuit/program_with_slotmap.py +27 -46
- ck/pgm_circuit/support/compile_circuit.py +2 -4
- ck/pgm_circuit/wmc_program.py +5 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win32.pyd +0 -0
- ck/probability/cross_table_probability_space.py +53 -0
- ck/probability/divergence.py +226 -0
- ck/probability/empirical_probability_space.py +1 -0
- ck/probability/probability_space.py +53 -30
- ck/program/raw_program.py +23 -16
- ck/sampling/sampler_support.py +5 -6
- ck/utils/iter_extras.py +3 -2
- ck/utils/local_config.py +16 -8
- ck_demos/dataset/__init__.py +0 -0
- ck_demos/dataset/demo_dataset_builder.py +37 -0
- ck_demos/dataset/demo_dataset_from_sampler.py +18 -0
- ck_demos/learning/__init__.py +0 -0
- ck_demos/learning/demo_bayesian_network_from_cross_tables.py +70 -0
- ck_demos/learning/demo_simple_learning.py +55 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +2 -2
- {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/METADATA +2 -1
- {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/RECORD +58 -37
- {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/WHEEL +0 -0
- {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/top_level.txt +0 -0
|
Binary file
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import ctypes as ct
|
|
4
4
|
from dataclasses import dataclass
|
|
5
|
-
from typing import Sequence, Optional, Dict, List, Tuple, Callable
|
|
5
|
+
from typing import Sequence, Optional, Dict, List, Tuple, Callable, assert_never
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
|
|
@@ -174,7 +174,7 @@ def _make_instructions(
|
|
|
174
174
|
elif op_node.symbol == ADD:
|
|
175
175
|
operation = sum
|
|
176
176
|
else:
|
|
177
|
-
|
|
177
|
+
assert_never('not reached')
|
|
178
178
|
|
|
179
179
|
instructions.append(_Instruction(operation, args, dest))
|
|
180
180
|
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from enum import Enum
|
|
4
|
+
from enum import Enum, auto
|
|
5
5
|
from typing import Sequence, Optional, Tuple, Dict, Protocol, assert_never
|
|
6
6
|
|
|
7
7
|
import llvmlite.binding as llvm
|
|
@@ -15,9 +15,9 @@ from ..program.raw_program import RawProgramFunction
|
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
class Flavour(Enum):
|
|
18
|
-
STACK =
|
|
19
|
-
TMPS =
|
|
20
|
-
FUNCS =
|
|
18
|
+
STACK = auto() # No working temporary memory requested - all on stack.
|
|
19
|
+
TMPS = auto() # Working temporary memory used for op node calculations.
|
|
20
|
+
FUNCS = auto() # Working temporary memory used for op node calculations, one sub-function per op-node.
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
DEFAULT_TYPE_INFO: TypeInfo = DataType.FLOAT_64.value
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
"/O2"
|
|
14
14
|
],
|
|
15
15
|
"include_dirs": [
|
|
16
|
-
"C:\\Users\\runneradmin\\AppData\\Local\\Temp\\build-env-
|
|
16
|
+
"C:\\Users\\runneradmin\\AppData\\Local\\Temp\\build-env-gp1o6j1g\\Lib\\site-packages\\numpy\\_core\\include"
|
|
17
17
|
],
|
|
18
18
|
"name": "ck.circuit_compiler.support.circuit_analyser._circuit_analyser_cy",
|
|
19
19
|
"sources": [
|
|
Binary file
|
|
@@ -3,7 +3,7 @@ This module supports circuit compilers and interpreters by inferring and checkin
|
|
|
3
3
|
that are explicitly or implicitly referred to by a client.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
-
from enum import Enum
|
|
6
|
+
from enum import Enum, auto
|
|
7
7
|
from itertools import chain
|
|
8
8
|
from typing import Sequence, Optional, Set, Iterable, List
|
|
9
9
|
|
|
@@ -15,9 +15,9 @@ class InferVars(Enum):
|
|
|
15
15
|
An enum specifying how to automatically infer a program's input variables.
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
|
-
ALL =
|
|
19
|
-
REF =
|
|
20
|
-
LOW =
|
|
18
|
+
ALL = auto() # all circuit vars are input vars
|
|
19
|
+
REF = auto() # only referenced vars are input vars
|
|
20
|
+
LOW = auto() # input vars are circuit vars[0 : max_referenced + 1]
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
# Type for specifying input circuit vars
|
|
@@ -213,10 +213,10 @@ def compile_llvm_program(
|
|
|
213
213
|
Compile the given LLVM program.
|
|
214
214
|
|
|
215
215
|
Returns:
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
function
|
|
216
|
+
`engine` an LLVM execution engine, which must remain
|
|
217
|
+
in memory for the returned function to be valid,
|
|
218
|
+
|
|
219
|
+
`function` the raw Python callable for the compiled function.
|
|
220
220
|
"""
|
|
221
221
|
_init_llvm()
|
|
222
222
|
|
ck/dataset/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .dataset import Dataset, HardDataset, SoftDataset
|
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import List, Tuple, Sequence, Iterator, Iterable, Optional, MutableMapping, Dict, assert_never
|
|
4
|
+
|
|
5
|
+
from ck.dataset import SoftDataset, HardDataset
|
|
6
|
+
from ck.pgm import RandomVariable, rv_instances, Instance
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CrossTable(MutableMapping[Instance, float]):
|
|
10
|
+
"""
|
|
11
|
+
A cross-table records the total weight for possible combinations
|
|
12
|
+
of states for some random variables, i.e., the weight of unique instances.
|
|
13
|
+
|
|
14
|
+
A cross-table is a dictionary mapping from state indices of the cross-table
|
|
15
|
+
random variables (an instance, as a tuple) to a weight (as a float).
|
|
16
|
+
|
|
17
|
+
Given a cross-table `ct`, then for each `instance in ct.keys()`:
|
|
18
|
+
`len(instance) == len(ct.rvs)`,
|
|
19
|
+
and `0 <= instance[i] < len(ct.rvs[i])` for each index `i`,
|
|
20
|
+
and `0 < ct[instance]`.
|
|
21
|
+
|
|
22
|
+
Zero weighted instances are not explicitly represented in a cross-table.
|
|
23
|
+
Given a cross-table `ct` then the following is always true.
|
|
24
|
+
`x in ct.keys()` is true if and only if `ct[x] != 0`.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def __init__(
|
|
28
|
+
self,
|
|
29
|
+
rvs: Sequence[RandomVariable],
|
|
30
|
+
dirichlet_prior: float | CrossTable = 0,
|
|
31
|
+
update: Iterable[Tuple[Instance, float]] = (),
|
|
32
|
+
):
|
|
33
|
+
"""
|
|
34
|
+
Construct a cross-table for the given random variables.
|
|
35
|
+
|
|
36
|
+
The cross-table can be initialised with a Dirichlet prior, x. Practically
|
|
37
|
+
this amounts to adding a weight of x to each possible combination of
|
|
38
|
+
random variable states. That is, a Dirichlet prior of x results in x pseudocounts
|
|
39
|
+
for each possible combination of states.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
rvs: the random variables that this cross-table records weights for. Instances
|
|
43
|
+
in this cross-table are tuples of state indexes, co-indexed with `rvs`.
|
|
44
|
+
dirichlet_prior: provides a prior for `rvs`. This can be represented either:
|
|
45
|
+
(a) as a uniform prior, represented as a float value,
|
|
46
|
+
(b) as an arbitrary prior, represented as a cross-table.
|
|
47
|
+
If a cross-table is provided as a prior, then it must have the same random variables as `rvs`.
|
|
48
|
+
The default value for `dirichlet_prior` is 0.
|
|
49
|
+
update: an optional iterable of (instance, weight) tuples to add to
|
|
50
|
+
the cross-table at construction time.
|
|
51
|
+
"""
|
|
52
|
+
self._rvs: Tuple[RandomVariable, ...] = tuple(rvs)
|
|
53
|
+
self._dict: Dict[Instance, float]
|
|
54
|
+
|
|
55
|
+
if isinstance(dirichlet_prior, CrossTable):
|
|
56
|
+
# rv_map[i] is where rvs[i] appears in the dirichlet_prior cross-table
|
|
57
|
+
# It will be used to map instances of the prior to instances of self.
|
|
58
|
+
rv_map: List[int] = [
|
|
59
|
+
dirichlet_prior.rvs.index(rv)
|
|
60
|
+
for rv in rvs
|
|
61
|
+
]
|
|
62
|
+
|
|
63
|
+
# Copy items from the prior to self, mapping the instances as needed
|
|
64
|
+
self._dict = {
|
|
65
|
+
tuple(prior_instance[select] for select in rv_map): weight
|
|
66
|
+
for prior_instance, weight in dirichlet_prior.items()
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
elif isinstance(dirichlet_prior, (float, int)):
|
|
70
|
+
if dirichlet_prior != 0:
|
|
71
|
+
# Initialise self with every possible combination of rvs states.
|
|
72
|
+
instance: Instance
|
|
73
|
+
self._dict = {
|
|
74
|
+
instance: dirichlet_prior
|
|
75
|
+
for instance in rv_instances(*self._rvs)
|
|
76
|
+
}
|
|
77
|
+
else:
|
|
78
|
+
self._dict = {}
|
|
79
|
+
else:
|
|
80
|
+
assert_never('not reached')
|
|
81
|
+
|
|
82
|
+
# Apply any provided updates
|
|
83
|
+
self.add_all(update)
|
|
84
|
+
|
|
85
|
+
def __eq__(self, other) -> bool:
|
|
86
|
+
"""
|
|
87
|
+
Two cross-tables are equal if they have the same sequence of random variables
|
|
88
|
+
and their instance weights are equal.
|
|
89
|
+
"""
|
|
90
|
+
return isinstance(other, CrossTable) and self._rvs == other._rvs and self._dict == other._dict
|
|
91
|
+
|
|
92
|
+
def __setitem__(self, key: Instance, value) -> None:
|
|
93
|
+
if value == 0:
|
|
94
|
+
self._dict.pop(key, None)
|
|
95
|
+
else:
|
|
96
|
+
self._dict[key] = value
|
|
97
|
+
|
|
98
|
+
def __delitem__(self, key: Instance) -> None:
|
|
99
|
+
del self._dict[key]
|
|
100
|
+
|
|
101
|
+
def __getitem__(self, key: Instance) -> float:
|
|
102
|
+
"""
|
|
103
|
+
Returns:
|
|
104
|
+
the weight of the given instance.
|
|
105
|
+
This will always return a value; an instance that is not in the underlying dictionary has weight 0.
|
|
106
|
+
"""
|
|
107
|
+
return self._dict.get(key, 0)
|
|
108
|
+
|
|
109
|
+
def __len__(self) -> int:
|
|
110
|
+
"""
|
|
111
|
+
Returns:
|
|
112
|
+
the number of instances in the cross-table with non-zero weight.
|
|
113
|
+
"""
|
|
114
|
+
return len(self._dict)
|
|
115
|
+
|
|
116
|
+
def __iter__(self) -> Iterator[Instance]:
|
|
117
|
+
"""
|
|
118
|
+
Returns:
|
|
119
|
+
an iterator over the cross-table instances with non-zero weight.
|
|
120
|
+
"""
|
|
121
|
+
return iter(self._dict)
|
|
122
|
+
|
|
123
|
+
def items(self) -> Iterable[Tuple[Instance, float]]:
|
|
124
|
+
"""
|
|
125
|
+
Returns:
|
|
126
|
+
an iterable over (instance, weight) pairs.
|
|
127
|
+
"""
|
|
128
|
+
return self._dict.items()
|
|
129
|
+
|
|
130
|
+
@property
|
|
131
|
+
def rvs(self) -> Sequence[RandomVariable]:
|
|
132
|
+
"""
|
|
133
|
+
The random variables that this cross-table refers to.
|
|
134
|
+
"""
|
|
135
|
+
return self._rvs
|
|
136
|
+
|
|
137
|
+
def add(self, instance: Instance, weight: float) -> None:
|
|
138
|
+
"""
|
|
139
|
+
Add the given weighted instance to the cross-table.
|
|
140
|
+
|
|
141
|
+
Args:
|
|
142
|
+
instance: a tuple of state indices, co-indexed with `self.rvs`.
|
|
143
|
+
weight: the weight (generalised count) to add to the cross-table. Normally the
|
|
144
|
+
weight will be > 0.
|
|
145
|
+
"""
|
|
146
|
+
self[instance] = self._dict.get(instance, 0) + weight
|
|
147
|
+
|
|
148
|
+
def add_all(self, to_add: Iterable[Tuple[Instance, float]]) -> None:
|
|
149
|
+
"""
|
|
150
|
+
Add the given weighted instances to the cross-table.
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
to_add: an iterable of (instance, weight) tuples to add to the cross-table.
|
|
154
|
+
"""
|
|
155
|
+
for instance, weight in to_add:
|
|
156
|
+
self.add(instance, weight)
|
|
157
|
+
|
|
158
|
+
def mul(self, multiplier: float) -> None:
|
|
159
|
+
"""
|
|
160
|
+
Multiply all weights by the given multiplier.
|
|
161
|
+
"""
|
|
162
|
+
if multiplier == 0:
|
|
163
|
+
self._dict.clear()
|
|
164
|
+
elif multiplier == 1:
|
|
165
|
+
pass
|
|
166
|
+
else:
|
|
167
|
+
for instance in self._dict.keys():
|
|
168
|
+
self._dict[instance] *= multiplier
|
|
169
|
+
|
|
170
|
+
def total_weight(self) -> float:
|
|
171
|
+
"""
|
|
172
|
+
Calculate the total weight of this cross-table.
|
|
173
|
+
"""
|
|
174
|
+
return sum(self.values())
|
|
175
|
+
|
|
176
|
+
def project(self, rvs: Sequence[RandomVariable]) -> CrossTable:
|
|
177
|
+
"""
|
|
178
|
+
Project this cross-table onto the given set of random variables.
|
|
179
|
+
|
|
180
|
+
If successful, this method will always return a new CrossTable object.
|
|
181
|
+
|
|
182
|
+
Returns:
|
|
183
|
+
a CrossTable with the given sequence of random variables.
|
|
184
|
+
|
|
185
|
+
Assumes:
|
|
186
|
+
`rvs` is a subset of the cross-table's random variables.
|
|
187
|
+
"""
|
|
188
|
+
# Mapping rv_map[i] is the index into `self.rvs` for `rvs[i]`.
|
|
189
|
+
rv_map: List[int] = [self.rvs.index(rv) for rv in rvs]
|
|
190
|
+
|
|
191
|
+
return CrossTable(
|
|
192
|
+
rvs=rvs,
|
|
193
|
+
update=(
|
|
194
|
+
(tuple(instance[i] for i in rv_map), weight)
|
|
195
|
+
for instance, weight in self._dict.items()
|
|
196
|
+
),
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
def dump(self, *, show_rvs: bool = True, show_weights: bool = True, as_states: bool = False) -> None:
|
|
200
|
+
"""
|
|
201
|
+
Dump the cross-table in a human-readable format.
|
|
202
|
+
If as_states is true, then instance states are dumped instead of just state indexes.
|
|
203
|
+
|
|
204
|
+
Args:
|
|
205
|
+
show_rvs: If `True`, the random variables are dumped.
|
|
206
|
+
show_weights: If `True`, the instance weights are dumped.
|
|
207
|
+
as_states: If `True`, the states are dumped instead of just state indexes.
|
|
208
|
+
"""
|
|
209
|
+
if show_rvs:
|
|
210
|
+
rvs = ', '.join(str(rv) for rv in self.rvs)
|
|
211
|
+
print(f'rvs: [{rvs}]')
|
|
212
|
+
print(f'instances ({len(self)}, with total weight {self.total_weight()}):')
|
|
213
|
+
for instance, weight in self.items():
|
|
214
|
+
if as_states:
|
|
215
|
+
instance_str = ', '.join(repr(rv.states[idx]) for idx, rv in zip(instance, self.rvs))
|
|
216
|
+
else:
|
|
217
|
+
instance_str = ', '.join(str(idx) for idx in instance)
|
|
218
|
+
if show_weights:
|
|
219
|
+
print(f'({instance_str}) * {weight}')
|
|
220
|
+
else:
|
|
221
|
+
print(f'({instance_str})')
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def cross_table_from_dataset(
|
|
225
|
+
dataset: HardDataset | SoftDataset,
|
|
226
|
+
rvs: Optional[Sequence[RandomVariable]] = None,
|
|
227
|
+
*,
|
|
228
|
+
dirichlet_prior: float | CrossTable = 0,
|
|
229
|
+
) -> CrossTable:
|
|
230
|
+
"""
|
|
231
|
+
Generate a cross-table for the given random variables, using the given dataset, represented
|
|
232
|
+
as a dictionary, mapping instances to weights.
|
|
233
|
+
|
|
234
|
+
Args:
|
|
235
|
+
dataset: The dataset to use to compute the cross-table.
|
|
236
|
+
rvs: The random variables to compute the cross-table for. If omitted
|
|
237
|
+
then `dataset.rvs` will be used.
|
|
238
|
+
dirichlet_prior: provides a Dirichlet prior for `rvs`. This can be represented either:
|
|
239
|
+
(a) as a uniform prior, represented as a float value,
|
|
240
|
+
(b) as an arbitrary Dirichlet prior, represented as a cross-table.
|
|
241
|
+
If a cross-table is provided as a prior, then it must have the same random variables as `rvs`.
|
|
242
|
+
The default value for `dirichlet_prior` is 0.
|
|
243
|
+
See `CrossTable` for more explanation.
|
|
244
|
+
|
|
245
|
+
Returns:
|
|
246
|
+
The cross-table for the given random variables, using the given dataset,
|
|
247
|
+
represented as a dictionary mapping instances to weights.
|
|
248
|
+
An instance is a tuple of state indexes, co-indexed with rvs.
|
|
249
|
+
|
|
250
|
+
Raises:
|
|
251
|
+
KeyError: If any random variable in `rvs` does not appear in the dataset.
|
|
252
|
+
"""
|
|
253
|
+
if isinstance(dataset, HardDataset):
|
|
254
|
+
return cross_table_from_hard_dataset(dataset, rvs, dirichlet_prior=dirichlet_prior)
|
|
255
|
+
if isinstance(dataset, SoftDataset):
|
|
256
|
+
return cross_table_from_soft_dataset(dataset, rvs, dirichlet_prior=dirichlet_prior)
|
|
257
|
+
raise TypeError('dataset must be either a SoftDataset or HardDataset')
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def cross_table_from_hard_dataset(
|
|
261
|
+
dataset: HardDataset,
|
|
262
|
+
rvs: Optional[Sequence[RandomVariable]] = None,
|
|
263
|
+
*,
|
|
264
|
+
dirichlet_prior: float | CrossTable = 0
|
|
265
|
+
) -> CrossTable:
|
|
266
|
+
"""
|
|
267
|
+
Generate a cross-table for the given random variables, using the given dataset, represented
|
|
268
|
+
as a dictionary, mapping instances to weights.
|
|
269
|
+
|
|
270
|
+
Args:
|
|
271
|
+
dataset: The dataset to use to compute the cross-table.
|
|
272
|
+
rvs: The random variables to compute the cross-table for. If omitted
|
|
273
|
+
then `dataset.rvs` will be used.
|
|
274
|
+
dirichlet_prior: provides a Dirichlet prior for `rvs`. This can be represented either:
|
|
275
|
+
(a) as a uniform prior, represented as a float value,
|
|
276
|
+
(b) as an arbitrary Dirichlet prior, represented as a cross-table.
|
|
277
|
+
If a cross-table is provided as a prior, then it must have the same random variables as `rvs`.
|
|
278
|
+
The default value for `dirichlet_prior` is 0.
|
|
279
|
+
See `CrossTable` for more explanation.
|
|
280
|
+
|
|
281
|
+
Returns:
|
|
282
|
+
The cross-table for the given random variables, using the given dataset,
|
|
283
|
+
represented as a dictionary mapping instances to weights.
|
|
284
|
+
An instance is a tuple of state indexes, co-indexed with rvs.
|
|
285
|
+
|
|
286
|
+
Raises:
|
|
287
|
+
KeyError: If any random variable in `rvs` does not appear in the dataset.
|
|
288
|
+
"""
|
|
289
|
+
if rvs is None:
|
|
290
|
+
rvs = dataset.rvs
|
|
291
|
+
return CrossTable(
|
|
292
|
+
rvs=rvs,
|
|
293
|
+
dirichlet_prior=dirichlet_prior,
|
|
294
|
+
update=dataset.instances(rvs)
|
|
295
|
+
)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def cross_table_from_soft_dataset(
|
|
299
|
+
dataset: SoftDataset,
|
|
300
|
+
rvs: Optional[Sequence[RandomVariable]] = None,
|
|
301
|
+
*,
|
|
302
|
+
dirichlet_prior: float | CrossTable = 0
|
|
303
|
+
) -> CrossTable:
|
|
304
|
+
"""
|
|
305
|
+
Generate a cross-table for the given random variables, using the given dataset, represented
|
|
306
|
+
as a dictionary, mapping instances to weights.
|
|
307
|
+
|
|
308
|
+
Args:
|
|
309
|
+
dataset: The dataset to use to compute the cross-table.
|
|
310
|
+
rvs: The random variables to compute the cross-table for. If omitted
|
|
311
|
+
then `dataset.rvs` will be used.
|
|
312
|
+
dirichlet_prior: provides a Dirichlet prior for `rvs`. This can be represented either:
|
|
313
|
+
(a) as a uniform prior, represented as a float value,
|
|
314
|
+
(b) as an arbitrary Dirichlet prior, represented as a cross-table.
|
|
315
|
+
If a cross-table is provided as a prior, then it must have the same random variables as `rvs`.
|
|
316
|
+
The default value for `dirichlet_prior` is 0.
|
|
317
|
+
See `CrossTable` for more explanation.
|
|
318
|
+
|
|
319
|
+
Returns:
|
|
320
|
+
The cross-table for the given random variables, using the given dataset,
|
|
321
|
+
represented as a dictionary mapping instances to weights.
|
|
322
|
+
An instance is a tuple of state indexes, co-indexed with rvs.
|
|
323
|
+
|
|
324
|
+
Raises:
|
|
325
|
+
KeyError: If any random variable in `rvs` does not appear in the dataset.
|
|
326
|
+
"""
|
|
327
|
+
if rvs is None:
|
|
328
|
+
rvs = dataset.rvs
|
|
329
|
+
|
|
330
|
+
return CrossTable(
|
|
331
|
+
rvs=rvs,
|
|
332
|
+
dirichlet_prior=dirichlet_prior,
|
|
333
|
+
update=dataset.hard_instances(rvs)
|
|
334
|
+
)
|