compiled-knowledge 4.1.0a2-cp312-cp312-macosx_11_0_arm64.whl → 4.1.0a3-cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic.
- ck/circuit/_circuit_cy.c +1 -1
- ck/circuit/_circuit_cy.cpython-312-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +152 -152
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-312-darwin.so +0 -0
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
- ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-312-darwin.so +0 -0
- ck/dataset/cross_table.py +143 -79
- ck/dataset/dataset.py +95 -7
- ck/dataset/dataset_builder.py +11 -4
- ck/dataset/dataset_from_crosstable.py +21 -2
- ck/learning/coalesce_cross_tables.py +395 -0
- ck/learning/model_from_cross_tables.py +242 -0
- ck/learning/parameters.py +117 -0
- ck/learning/train_generative_bn.py +198 -0
- ck/pgm.py +10 -8
- ck/pgm_circuit/marginals_program.py +5 -0
- ck/pgm_circuit/wmc_program.py +5 -0
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
- ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-312-darwin.so +0 -0
- ck/probability/divergence.py +226 -0
- ck/probability/probability_space.py +43 -19
- ck_demos/dataset/demo_dataset_from_sampler.py +18 -0
- ck_demos/learning/__init__.py +0 -0
- ck_demos/learning/demo_bayesian_network_from_cross_tables.py +71 -0
- ck_demos/learning/demo_simple_learning.py +55 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +2 -2
- {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/METADATA +2 -1
- {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/RECORD +32 -24
- ck/learning/train_generative.py +0 -149
- /ck/{dataset/cross_table_probabilities.py → probability/cross_table_probability_space.py} +0 -0
- {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/WHEEL +0 -0
- {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/licenses/LICENSE.txt +0 -0
- {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,395 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Tuple, List, Sequence, Dict
+
+import numpy as np
+from scipy.sparse import dok_matrix
+from scipy.sparse.linalg import lsqr
+
+from ck.dataset.cross_table import CrossTable
+from ck.pgm import RandomVariable, Instance
+from ck.utils.iter_extras import combos
+from ck.utils.np_extras import NDArrayFloat64
+
+
+def coalesce_cross_tables(crosstabs: Sequence[CrossTable], rvs: Sequence[RandomVariable]) -> CrossTable:
+    """
+    Rationalise multiple cross-tables into a single cross-table.
+
+    This method implements a solution finding the best vector `a`
+    that solves `b = m a`, subject to `a[i] >= 0`.
+
+    `a` is a column vector with one entry for each instance of `rvs` seen in the cross-tables.
+
+    `b` is a column vector containing all probabilities from all cross-tables (normalised weights).
+
+    `m` is a sparse matrix with m[i, j] = 1 where a[j] contributes to the sum for probability b[i].
+
+    Args:
+        crosstabs: a collection of cross-tables to coalesce.
+        rvs: the random variables that cross-tables will be projected on to.
+
+    Returns:
+        a cross-table defined for the given `rvs`, with values inferred from `crosstabs`.
+    """
+    if len(crosstabs) == 0:
+        return CrossTable(rvs)
+
+    m: dok_matrix
+    b: np.ndarray
+    a_keys: Sequence[Instance]
+    m, b, a_keys = _make_matrix(crosstabs, rvs)
+
+    a = _solve(m, b)
+
+    return CrossTable(
+        rvs,
+        update=(
+            (instance, weight)
+            for instance, weight in zip(a_keys, a)
+        ),
+    )
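
To make the `b = m a` formulation above concrete, here is a standalone sketch with toy numbers, using plain NumPy/SciPy rather than the package's `CrossTable` API: two single-variable cross-tables over binary rvs X and Y contribute four marginal probabilities `b`, and each row of `m` marks the joint instances whose masses sum to that marginal.

import numpy as np
from scipy.sparse import dok_matrix
from scipy.sparse.linalg import lsqr

# Columns of m: joint instances (X, Y) in the order (0,0), (0,1), (1,0), (1,1).
# Rows of m: marginal probabilities P(X=0), P(X=1), P(Y=0), P(Y=1).
b = np.array([0.6, 0.4, 0.7, 0.3])
m = dok_matrix((4, 4), dtype=np.double)
for row, cols in enumerate([(0, 1), (2, 3), (0, 2), (1, 3)]):
    for col in cols:
        m[row, col] = 1.0  # b[row] is the sum of a[col] over these columns

a = lsqr(m, b)[0]  # least-squares solution of b = m a
print(np.round(a, 3))  # approximately [0.4 0.2 0.3 0.1]
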
+
+
+def _solve(m: dok_matrix, b: np.ndarray) -> np.ndarray:
+    """
+    Find the best `a` for `b = m a` subject to `a[i] >= 0`.
+    """
+    assert len(b.shape) == 1, 'b should be a vector'
+    assert b.shape[0] == m.shape[0], 'b and m must be compatible'
+
+    return _solve_sam(m, b)
+    # return _solve_lsqr(m, b)
+    # return _solve_pulp_l1(m, b)
+
+
+def _solve_sam(m: dok_matrix, b: np.ndarray) -> np.ndarray:
+    """
+    Find the best `a` for `b = m a` subject to `a[i] >= 0`.
+
+    Uses a custom 'split and mean' (SAM) method.
+    """
+    sam = _SAM(m, b)
+    a, error = sam.solve(
+        max_iterations=100,
+        tolerance=1e-6,
+        change_tolerance=1e-12
+    )
+    return a
+
+
+def _solve_lsqr(m: dok_matrix, b: np.ndarray) -> np.ndarray:
+    """
+    Find the best `a` for `b = m a` subject to `a[i] >= 0`.
+
+    Uses the scipy `lsqr` method, with a heuristic to fix negative values in `a`.
+    """
+    a: np.ndarray
+    # The PyCharm type checker incorrectly infers the type signature of `lsqr`.
+    # noinspection PyTypeChecker
+    a, istop, itn, r1norm, r2norm, _, _, _, _, _ = lsqr(m, b)
+
+    # Negative values or values > 1 are not a valid solution.
+
+    # Heuristic fix up...
+    if len(a) > 0:
+        min_val = np.min(a)
+        if min_val < 0:
+            # We could just let the negative values get truncated to zero, but
+            # empirically the results seem better when we shift all parameters up.
+            a[:] -= min_val
+
+    return a
+
+
+# This approach is unsatisfactory as we should minimise the L2 norm
+# rather than the L1 norm.
+#
+# def _solve_pulp_l1(m: dok_matrix, b: np.ndarray) -> np.ndarray:
+#     """
+#     Find the best `a` for `b = m a` subject to `a[i] >= 0`.
+#
+#     Uses a pulp LpProblem to minimise the L1 norm of `a`.
+#
+#     This method will only work if there is an exact solution.
+#     If not, then we call _solve_sam as a fallback.
+#     """
+#     import pulp
+#
+#     a_size = m.shape[1]
+#     b_size = b.shape[0]
+#
+#     prob = pulp.LpProblem('solver', pulp.LpMinimize)
+#     x = [pulp.LpVariable(f'x{i}', lowBound=0) for i in range(a_size)]
+#
+#     # The objective: minimise the L1 norm of x.
+#     # The sum(x) is the L1 norm because each element of x is constrained >= 0.
+#     prob.setObjective(pulp.lpSum(x))
+#
+#     # The constraints
+#     constraints = [pulp.LpAffineExpression() for _ in range(b_size)]
+#     for row, col in m.keys():
+#         constraints[row].addterm(x[col], 1)
+#     for c, b_i in zip(constraints, b):
+#         prob.addConstraint(c == b_i)
+#
+#     _PULP_TIMEOUT = 60  # seconds
+#     status = prob.solve(pulp.PULP_CBC_CMD(msg=False, timeLimit=_PULP_TIMEOUT))
+#
+#     if status == pulp.LpStatusOptimal:
+#         return np.array([pulp.value(x_var) for x_var in x])
+#     else:
+#         return _solve_sam(m, b)
+
+
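
For illustration, the shift heuristic in `_solve_lsqr` can be seen on a deliberately inconsistent toy system (hypothetical numbers): least squares may return a small negative mass, which is removed by shifting the whole vector up.

import numpy as np
from scipy.sparse import dok_matrix
from scipy.sparse.linalg import lsqr

# Contradictory evidence: P(X=0) = 0.0 and P(X=1) = 0.9, yet the total is 0.7.
m = dok_matrix((3, 2), dtype=np.double)
m[0, 0] = 1.0
m[1, 1] = 1.0
m[2, 0] = m[2, 1] = 1.0
b = np.array([0.0, 0.9, 0.7])

a = lsqr(m, b)[0]
print(np.round(a, 3))  # approximately [-0.067  0.833]: invalid as weights
if a.size > 0 and a.min() < 0:
    a -= a.min()  # the heuristic: shift every parameter up
print(np.round(a, 3))  # approximately [0.  0.9]
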
+def _sum_out_unneeded_rvs(crosstab: CrossTable, rvs: Sequence[RandomVariable]) -> CrossTable:
+    """
+    Project the given cross-table as needed to ensure all random
+    variables in the result are in `rvs`.
+    """
+    available_rvs = set(crosstab.rvs)
+    project_rvs = available_rvs.intersection(rvs)
+    if len(project_rvs) == len(available_rvs):
+        # No projection is required
+        return crosstab
+    else:
+        return crosstab.project(list(project_rvs))
+
+
+def _make_matrix(
+        crosstabs: Sequence[CrossTable],
+        rvs: Sequence[RandomVariable],
+) -> Tuple[dok_matrix, np.ndarray, Sequence[Instance]]:
+    """
+    Create the `m` matrix and `b` vector for solving `b = m a`.
+
+    Args:
+        crosstabs: a collection of cross-tables to coalesce.
+        rvs: the random variables that cross-tables will be projected on to.
+
+    Returns:
+        the tuple (m, b, a_keys) where
+        'm' is a sparse matrix,
+        'b' is a numpy array of crosstab probabilities,
+        'a_keys' are the keys for the solution probabilities.
+    """
+
+    # Sum out any unneeded random variables
+    crosstabs = [
+        _sum_out_unneeded_rvs(crosstab, rvs)
+        for crosstab in crosstabs
+    ]
+
+    m_cols: Dict[Instance, _MCol] = {}
+    b_list: List[float] = []
+    a_keys: List[Instance] = []
+
+    rv_index: Dict[RandomVariable, int] = {rv: i for i, rv in enumerate(rvs)}
+
+    # instance_template[i] is a list of the possible states of rvs[i]
+    instance_template: List[List[int]] = [list(range(len(rv))) for rv in rvs]
+
+    for crosstab in crosstabs:
+
+        # get `to_rv` such that crosstab.rvs[i] = rvs[to_rv[i]]
+        to_rv = [rv_index[rv] for rv in crosstab.rvs]
+
+        # Make instance_options, which is a clone of instance_template but with
+        # a singleton list replacing each rv that this crosstab has.
+        # For now the state in each singleton is set to -1; later it will be
+        # set to the actual state of each instance in the current crosstab.
+        instance_options = list(instance_template)
+        for i in to_rv:
+            instance_options[i] = [-1]
+
+        total = crosstab.total_weight()
+        for crosstab_instance, weight in crosstab.items():
+
+            # Work out what instances get summed to create the crosstab_instance weight.
+            # This just overrides the singleton states of `instance_options` with the
+            # actual state of the crosstab_instance.
+            for state, i in zip(crosstab_instance, to_rv):
+                instance_options[i][0] = state
+
+            # Grow the b list with our instance probability
+            b_i = len(b_list)
+            b_list.append(weight / total)
+
+            # Iterate over all states of `rvs` that match the current crosstab_instance,
+            # recording `b_i` in the column for each matching instance.
+            for instance in combos(instance_options):
+                m_col = m_cols.get(instance)
+                if m_col is None:
+                    m_cols[instance] = _MCol(instance, len(a_keys), [b_i])
+                    a_keys.append(instance)
+                else:
+                    m_col.col.append(b_i)
+
+    # Construct the m matrix from m_cols
+    m = dok_matrix((len(b_list), len(a_keys)), dtype=np.double)
+    for m_col in m_cols.values():
+        j = m_col.column_index
+        for i in m_col.col:
+            m[i, j] = 1
+
+    # Construct the b vector
+    b = np.array(b_list, dtype=np.double)
+
+    return m, b, a_keys
+
+
+@dataclass
+class _MCol:
+    key: Instance
+    column_index: int
+    col: List[int]
+
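
The enumeration step in `_make_matrix` can be pictured with `itertools.product` standing in for `ck.utils.iter_extras.combos` (an assumption: its use here suggests `combos` enumerates the cross product; toy data below): one marginal entry expands to every joint instance that agrees with it.

from itertools import product

# rvs = (X, Y, Z), each with two states; a crosstab that covers only Y
# contributes the entry P(Y=1).
instance_template = [[0, 1], [0, 1], [0, 1]]  # all states of X, Y, Z

instance_options = list(instance_template)
instance_options[1] = [1]  # the crosstab instance fixes Y = 1

print(list(product(*instance_options)))
# [(0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1)]
# P(Y=1) is the sum of a[j] over exactly these joint instances, so the
# row of m for this entry holds a 1 in each of their columns.
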
+
+@dataclass
+class _SM:
+    split: float
+    a: float
+
+    def diff(self) -> float:
+        return self.split - self.a
+
+
+class _SAM:
+    """
+    Split and Mean method for finding `a`
+    in `b = m a`
+    subject to `a[i] >= 0`.
+
+    Assumes all elements of `m` are either zero or one.
+    """
+
+    def __init__(self, m: dok_matrix, b: NDArrayFloat64, use_lsqr: bool = True):
+        """
+        Allocate the memory required for a SAM solver.
+
+        Args:
+            m: the summation matrix.
+            b: the vector of resulting probabilities.
+            use_lsqr: whether or not to use LSQR to initialise the solution.
+        """
+        # Replicate the sparse m matrix as lists of _SM objects,
+        # giving both row-major and column-major representations.
+        a_size: int = m.shape[1]
+        b_size: int = m.shape[0]
+        a_idx: List[List[_SM]] = [[] for _ in range(a_size)]
+        b_idx: List[List[_SM]] = [[] for _ in range(b_size)]
+        for (i, j), m_val in m.items():
+            if m_val != 0:
+                sm = _SM(0, 0)
+                a_idx[j].append(sm)
+                b_idx[i].append(sm)
+
+        self._a: NDArrayFloat64 = np.zeros(a_size, dtype=np.double)
+        self._b: NDArrayFloat64 = b
+        self._m: dok_matrix = m
+        self._a_idx: List[List[_SM]] = a_idx
+        self._b_idx: List[List[_SM]] = b_idx
+        self._use_lsqr = use_lsqr
+
+    def solve(self, max_iterations: int, tolerance: float, change_tolerance: float) -> Tuple[np.ndarray, float]:
+        """
+        Initialize split values, then iterate (mean step, split step).
+
+        Args:
+            max_iterations: maximum number of iterations.
+            tolerance: terminate iterations if error <= tolerance.
+            change_tolerance: terminate iterations if change in error <= change_tolerance.
+
+        Returns:
+            tuple ('a', 'error') where 'a' is the solution when iteration stops
+            and 'error' is the sum of absolute errors.
+        """
+        self._initialize_split_values()
+        iteration = 0
+        prev_error = 0
+        while True:
+            iteration += 1
+            a, error = self._mean_step()
+            if error <= tolerance or abs(error - prev_error) <= change_tolerance or iteration >= max_iterations:
+                return a, error
+            prev_error = error
+            self._split_step()
+
+    def _initialize_split_values(self):
+        """
+        Take each 'b' value and split it across its SM cells.
+
+        If `self._use_lsqr` is True then the split is based on a solution
+        using scipy lsqr, otherwise each 'b' value is split evenly.
+        """
+        for b_val, b_list in zip(self._b, self._b_idx):
+            len_b_list = len(b_list)
+            if len_b_list > 0:
+                split_val = b_val / len_b_list
+                for sm in b_list:
+                    sm.split = split_val
+
+        if self._use_lsqr:
+            a = _solve_lsqr(self._m, self._b)
+            assert len(a) == len(self._a)
+            for a_val, a_list in zip(a, self._a_idx):
+                for sm in a_list:
+                    sm.a = a_val
+            self._split_step()
+
+    def _mean_step(self) -> Tuple[np.ndarray, float]:
+        """
+        Take the current split values to determine the 'a' values
+        as the mean across relevant SM cells.
+
+        Assumes the previous step was either 'initialize_split_values'
+        or a 'split step'.
+
+        Returns:
+            tuple ('a', 'error') where 'a' is the current solution after this step
+            and 'error' is the sum of absolute errors.
+        """
+        error = 0.0
+        a = self._a
+        for i, a_list in enumerate(self._a_idx):
+            sum_val = sum(sm.split for sm in a_list)
+            a_val = sum_val / len(a_list)
+            a[i] = a_val
+            for sm in a_list:
+                sm.a = a_val
+                error += abs(sm.diff())
+        return a, error
+
+    def _split_step(self):
+        """
+        Take the difference between the split 'b' values and current 'a' values to
+        redistribute the split 'b' values.
+
+        Assumes the previous step was a 'mean step'.
+        """
+        for b_val, b_list in zip(self._b, self._b_idx):
+            if len(b_list) <= 1:
+                # Too small to split
+                continue
+            pos = 0.0
+            neg = 0.0
+            for sm in b_list:
+                diff = sm.diff()
+                if diff >= 0:
+                    pos += diff
+                else:
+                    neg -= diff
+            mass = min(pos, neg)
+            if mass == 0:
+                # No mass to redistribute
+                continue
+            pos_scale = mass / pos
+            neg_scale = mass / neg
+            for sm in b_list:
+                diff = sm.diff()
+                if diff >= 0:
+                    sm.split -= diff * pos_scale
+                else:
+                    sm.split -= diff * neg_scale
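
A dense NumPy re-creation of the split-and-mean iteration may help picture `_SAM`. This is a sketch under the assumption that `m` contains only zeros and ones, not the package's implementation:

import numpy as np

# Consistent toy system: rows of m are the four marginals of two binary rvs.
m = np.array([[1, 1, 0, 0],
              [0, 0, 1, 1],
              [1, 0, 1, 0],
              [0, 1, 0, 1]], dtype=float)
b = np.array([0.6, 0.4, 0.7, 0.3])

# Initialise: split each b[i] evenly across the cells of its row.
split = np.where(m > 0, (b / m.sum(axis=1))[:, None], 0.0)
counts = m.sum(axis=0)

for _ in range(100):
    # Mean step: a[j] is the mean of column j's split values.
    a = split.sum(axis=0) / counts
    diff = np.where(m > 0, split - a[None, :], 0.0)
    error = np.abs(diff).sum()
    if error <= 1e-9:
        break
    # Split step: within each row, move mass from cells above their
    # column mean to cells below it, conserving the row sum b[i].
    pos = np.clip(diff, 0.0, None).sum(axis=1)
    neg = np.clip(-diff, 0.0, None).sum(axis=1)
    mass = np.minimum(pos, neg)
    pos_scale = np.divide(mass, pos, out=np.zeros_like(pos), where=pos > 0)
    neg_scale = np.divide(mass, neg, out=np.zeros_like(neg), where=neg > 0)
    scale = np.where(diff >= 0, pos_scale[:, None], neg_scale[:, None])
    split = split - diff * scale

print(np.round(a, 3), error)  # a nonnegative solution; error shrinks towards zero
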
@@ -0,0 +1,242 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from itertools import chain
+from typing import Iterable, List, Tuple, Dict, Sequence, Set, Optional
+
+from ck.dataset.cross_table import CrossTable
+from ck.learning.coalesce_cross_tables import coalesce_cross_tables
+from ck.learning.parameters import make_factors, ParameterValues
+from ck.learning.train_generative_bn import cpt_and_parent_sums_from_crosstab
+from ck.pgm import PGM, RandomVariable
+from ck.utils.map_list import MapList
+
+
+def model_from_cross_tables(pgm: PGM, cross_tables: Iterable[CrossTable]) -> None:
+    """
+    Make best efforts to construct a Bayesian network model given only the
+    evidence from the supplied cross-tables.
+
+    This function calls `get_cpts`, which provides parameters to define
+    a Bayesian network model. These are then applied to the PGM using
+    `make_factors`.
+
+    Args:
+        pgm: the PGM to add factors and potential functions to.
+        cross_tables: available cross-tables to build a model from.
+
+    Raises:
+        ValueError: if `pgm` has any existing factors.
+    """
+    if len(pgm.factors) > 0:
+        raise ValueError('the given PGM should have no factors')
+    cpts: List[CrossTable] = get_cpts(
+        rvs=pgm.rvs,
+        cross_tables=cross_tables,
+    )
+    make_factors(pgm, cpts)
+
+
+def get_cpts(rvs: Sequence[RandomVariable], cross_tables: Iterable[CrossTable]) -> ParameterValues:
+    """
+    Make best efforts to define a Bayesian network model given only the
+    evidence from the supplied cross-tables.
+
+    This function infers CPTs for `rvs` using the given `cross_tables`.
+
+    For any two cross-tables `x` and `y` in `cross_tables` with common random
+    variables `rvs`, this function assumes `x.project(rvs) == y.project(rvs)`.
+    If this condition does not hold, then best efforts will still be made to
+    define a Bayesian network model; however, the resulting parameter values
+    may be suboptimal.
+
+    Args:
+        rvs: the random variables to define a network structure over.
+        cross_tables: available cross-tables to build a model from.
+
+    Returns:
+        a ParameterValues object as a list of CPTs; each CPT can be used to create
+        a new factor in a PGM to make a Bayesian network.
+    """
+    # Stabilise the given cross-tables
+    cross_tables = tuple(cross_tables)
+
+    # Heuristically determine an ordering over the random variables
+    # which will be used to form the BN structure.
+    rv_order: Dict[RandomVariable, int] = _get_rv_order(cross_tables)
+
+    # Make an empty model factor for each random variable.
+    model_factors: List[_ModelFactor] = [_ModelFactor(rv) for rv in rvs]
+
+    # Define a Bayesian network structure.
+    # Allocate each crosstab to exactly one random variable, the
+    # one with the highest rank (i.e. the child).
+    for crosstab in cross_tables:
+        if len(crosstab.rvs) == 0:
+            continue
+        sorted_rvs: List[RandomVariable] = sorted(crosstab.rvs, key=(lambda _rv: rv_order[_rv]), reverse=True)
+        child: RandomVariable = sorted_rvs[0]
+        parents: List[RandomVariable] = sorted_rvs[1:]
+        model_factor: _ModelFactor = model_factors[child.idx]
+        model_factor.parent_rvs.update(parents)
+        model_factor.cross_tables.append(crosstab)
+
+    # Link child factors.
+    # When defining a factor, we need to define the child factors first.
+    for model_factor in model_factors:
+        for parent_rv in model_factor.parent_rvs:
+            model_factors[parent_rv.idx].child_factors.append(model_factor)
+
+    # Make the factors, depth first.
+    done: Set[int] = set()
+    for model_factor in model_factors:
+        _infer_cpt(model_factor, done)
+
+    # Return the CPTs that define the structure
+    return [model_factor.cpt for model_factor in model_factors]
+
+
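
The structure rule in `get_cpts` — each cross-table is assigned to its highest-ranked random variable, which becomes the child while the remaining rvs become parents — can be sketched with strings standing in for `RandomVariable` (hypothetical ranks and scopes):

# rv_order as _get_rv_order might rank three rvs (toy values).
rv_order = {'A': 2, 'B': 1, 'C': 0}
cross_table_scopes = [('A', 'B'), ('B', 'C'), ('A', 'C')]

for scope in cross_table_scopes:
    ranked = sorted(scope, key=rv_order.__getitem__, reverse=True)
    child, parents = ranked[0], ranked[1:]
    print(f'{child} <- {parents}')
# A <- ['B']
# B <- ['C']
# A <- ['C']
# The factor for A therefore collects two cross-tables and parents {B, C}.
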
+def _infer_cpt(model_factor: _ModelFactor, done: Set[int]) -> None:
+    """
+    Depth-first, recursively infer the model factors as CPTs.
+    This sets `model_factor.cpt` and `model_factor.parent_sums` for
+    the given `model_factor` and its children.
+    """
+    # Only visit a model factor once.
+    child_rv: RandomVariable = model_factor.child_rv
+    if child_rv.idx in done:
+        return
+    done.add(child_rv.idx)
+
+    # Recursively visit the child factors
+    for child_model_factor in model_factor.child_factors:
+        _infer_cpt(child_model_factor, done)
+
+    # Get all relevant cross-tables to set the parameters
+    crosstabs: Sequence[CrossTable] = model_factor.cross_tables
+    child_crosstabs: Sequence[CrossTable] = [
+        child_model_factor.parent_sums
+        for child_model_factor in model_factor.child_factors
+    ]
+
+    # Get the parameters
+    rvs = [child_rv] + list(model_factor.parent_rvs)
+    crosstab: CrossTable = _infer_parameters(rvs, crosstabs, child_crosstabs)
+    cpt, parent_sums = cpt_and_parent_sums_from_crosstab(crosstab)
+    model_factor.cpt = cpt
+    model_factor.parent_sums = parent_sums
+
+
+def _infer_parameters(
+        rvs: Sequence[RandomVariable],
+        crosstabs: Sequence[CrossTable],
+        child_crosstabs: Sequence[CrossTable],
+) -> CrossTable:
+    """
+    Make best efforts to infer a probability distribution over the given random
+    variables, with evidence from the given cross-tables.
+
+    Returns:
+        a CrossTable representing the inferred probability distribution
+        (not normalised to sum to one).
+
+    Assumes:
+        `rvs` has no duplicates.
+    """
+
+    # Project cross-tables onto rvs, splitting them into complete and partial coverage of `rvs`.
+    # Completely covering cross-tables are accumulated into `complete_crosstab`, while others
+    # are appended to `partial_crosstabs`.
+    complete_crosstab: Optional[CrossTable] = None
+    partial_crosstabs: List[CrossTable] = []
+    for available_crosstab in chain(crosstabs, child_crosstabs):
+        available_rvs: Set[RandomVariable] = set(available_crosstab.rvs)
+        if available_rvs.issuperset(rvs):
+            to_add: CrossTable = available_crosstab.project(rvs)
+            if complete_crosstab is None:
+                complete_crosstab = to_add
+            else:
+                complete_crosstab.add_all(to_add.items())
+        else:
+            partial_crosstabs.append(available_crosstab)
+
+    if complete_crosstab is not None:
+        # A direct solution was found.
+        # Ignore any partially covering cross-tables.
+        return complete_crosstab
+
+    # If there are no cross-tables available, the result is empty
+    if len(partial_crosstabs) == 0:
+        return CrossTable(rvs)
+
+    # There were no cross-tables that completely cover the given random variables.
+    # The following algorithm makes best attempts to coalesce a collection of
+    # partially covering cross-tables.
+    return coalesce_cross_tables(partial_crosstabs, rvs)
+
+
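
The coverage split at the top of `_infer_parameters` amounts to a superset test per candidate table; a minimal sketch with sets of names in place of `RandomVariable` scopes (hypothetical data):

rvs = {'X', 'Y'}
scopes = [{'X', 'Y', 'Z'}, {'X'}, {'Y', 'Z'}]

complete = [s for s in scopes if s.issuperset(rvs)]  # projected and accumulated
partial = [s for s in scopes if not s.issuperset(rvs)]  # coalesced only as a fallback
# complete == [{'X', 'Y', 'Z'}]: projecting it onto rvs answers directly,
# and the partial scopes {'X'} and {'Y', 'Z'} are then ignored.
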
+def _crosstab_str(crosstab: CrossTable) -> str:
+    return '(' + ', '.join(repr(rv.name) for rv in crosstab.rvs) + ')'
+
+
+def _get_rv_order(cross_tables: Sequence[CrossTable]) -> Dict[RandomVariable, int]:
+    """
+    Determine an order over the given random variables.
+    Returns a map from rv to its rank in the order.
+    """
+    child_parent_map: MapList[RandomVariable, RandomVariable] = MapList()
+    for crosstab in cross_tables:
+        rvs = crosstab.rvs
+        for i in range(len(rvs)):
+            child = rvs[i]
+            parents = rvs[i + 1:]
+            child_parent_map.extend(child, parents)
+    order: Dict[RandomVariable, int] = {}
+    seen: Set[RandomVariable] = set()
+    for child in child_parent_map.keys():
+        _get_rv_order_r(child, child_parent_map, seen, order)
+    return order
+
+
+def _get_rv_order_r(
+        child: RandomVariable,
+        child_parent_map: MapList[RandomVariable, RandomVariable],
+        seen: Set[RandomVariable],
+        order: Dict[RandomVariable, int],
+) -> None:
+    if child not in seen:
+        seen.add(child)
+        parents = child_parent_map[child]
+        for parent in parents:
+            _get_rv_order_r(parent, child_parent_map, seen, order)
+        order[child] = len(order)
+
+
+@dataclass
+class _ModelFactor:
+    """
+    A collection of model factors defines a PGM structure,
+    each model factor representing a needed PGM factor.
+
+    Associated with a model factor is a list of crosstabs
+    whose rvs overlap with the rvs of the model factor.
+    """
+
+    child_rv: RandomVariable
+    parent_rvs: Set[RandomVariable] = field(default_factory=set)
+    cross_tables: List[CrossTable] = field(default_factory=list)
+    child_factors: List[_ModelFactor] = field(default_factory=list)
+
+    # These are set once the parameter values are inferred...
+    cpt: Optional[CrossTable] = None
+    parent_sums: Optional[CrossTable] = None
+
+    def dump(self, prefix: str = '') -> None:
+        """
+        For debugging.
+        """
+        print(f'{prefix}child:', self.child_rv)
+        print(f'{prefix}parents:', '{}' if not self.parent_rvs else self.parent_rvs)
+        for cross_table in self.cross_tables:
+            print(f'{prefix}cross-table:', *cross_table.rvs)
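
Finally, the ranking produced by `_get_rv_order` is a depth-first post-order: every rv is ranked after all of its parents. A minimal re-implementation on toy scopes (strings in place of `RandomVariable`, a dict of lists in place of `MapList`):

from collections import defaultdict
from typing import Dict, List, Set

scopes = [('A', 'B'), ('B', 'C')]  # within a scope, later rvs are treated as parents
child_parents: Dict[str, List[str]] = defaultdict(list)
for scope in scopes:
    for i, child in enumerate(scope):
        child_parents[child].extend(scope[i + 1:])

order: Dict[str, int] = {}
seen: Set[str] = set()

def visit(rv: str) -> None:
    if rv in seen:
        return
    seen.add(rv)
    for parent in child_parents[rv]:
        visit(parent)  # parents are visited first, so they get lower ranks
    order[rv] = len(order)

for rv in list(child_parents):
    visit(rv)

print(order)  # {'C': 0, 'B': 1, 'A': 2}: 'A' ranks highest, so it becomes a child
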