compiled-knowledge 4.1.0a2__cp312-cp312-win_amd64.whl → 4.2.0a1__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (36) hide show
  1. ck/circuit/_circuit_cy.c +1 -1
  2. ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
  3. ck/circuit_compiler/cython_vm_compiler/_compiler.c +152 -152
  4. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
  5. ck/circuit_compiler/llvm_compiler.py +4 -4
  6. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
  7. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win_amd64.pyd +0 -0
  8. ck/circuit_compiler/support/input_vars.py +4 -4
  9. ck/dataset/cross_table.py +143 -79
  10. ck/dataset/dataset.py +95 -7
  11. ck/dataset/dataset_builder.py +11 -4
  12. ck/dataset/dataset_from_crosstable.py +21 -2
  13. ck/learning/coalesce_cross_tables.py +403 -0
  14. ck/learning/model_from_cross_tables.py +296 -0
  15. ck/learning/parameters.py +117 -0
  16. ck/learning/train_generative_bn.py +198 -0
  17. ck/pgm.py +10 -8
  18. ck/pgm_circuit/marginals_program.py +5 -0
  19. ck/pgm_circuit/wmc_program.py +5 -0
  20. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
  21. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
  22. ck/probability/divergence.py +226 -0
  23. ck/probability/probability_space.py +43 -19
  24. ck/utils/map_dict.py +89 -0
  25. ck_demos/dataset/demo_dataset_from_sampler.py +18 -0
  26. ck_demos/learning/__init__.py +0 -0
  27. ck_demos/learning/demo_bayesian_network_from_cross_tables.py +70 -0
  28. ck_demos/learning/demo_simple_learning.py +55 -0
  29. ck_demos/sampling/demo_wmc_direct_sampler.py +2 -2
  30. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/METADATA +2 -1
  31. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/RECORD +35 -26
  32. ck/learning/train_generative.py +0 -149
  33. /ck/{dataset/cross_table_probabilities.py → probability/cross_table_probability_space.py} +0 -0
  34. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/WHEEL +0 -0
  35. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/licenses/LICENSE.txt +0 -0
  36. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,403 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Tuple, List, Sequence, Dict
5
+
6
+ import numpy as np
7
+ from scipy.sparse import dok_matrix
8
+ from scipy.sparse.linalg import lsqr
9
+
10
+ from ck.dataset.cross_table import CrossTable
11
+ from ck.pgm import RandomVariable, Instance
12
+ from ck.utils.iter_extras import combos
13
+ from ck.utils.np_extras import NDArrayFloat64
14
+
15
+
16
def coalesce_cross_tables(crosstabs: Sequence[CrossTable], rvs: Sequence[RandomVariable]) -> CrossTable:
    """
    Rationalise multiple cross-tables into a single cross-table.

    This method seeks the best vector `a` that solves `b = m a`, subject to `a[i] >= 0`.

    `a` is a column vector with one entry for each instance of `rvs` that is
    consistent with at least one entry of the cross-tables.

    `b` is a column vector containing all probabilities from all cross-tables.

    `m` is a sparse 0/1 matrix of shape (len(b), len(a)) with m[i, j] = 1 where
    a[j] is in the sum that produces the cross-table probability b[i].

    "Best" is intended to mean the vector `a` with:
    * `a[i] >= 0` for all i, then
    * `b - m a` having the smallest L2 norm, then
    * `a` having the smallest L2 norm.
    The solver used is iterative, so the result approximates this objective.

    The given cross-tables are used to form `b`. The entries of each cross-table
    are normalised to represent a probability distribution over that
    cross-table's instances (keys).

    Args:
        crosstabs: a collection of cross-tables to coalesce.
        rvs: the random variables that cross-tables will be projected on to.

    Returns:
        a cross-table defined for the given `rvs`, with values inferred from `crosstabs`.
    """
    if len(crosstabs) == 0:
        return CrossTable(rvs)

    m: dok_matrix
    b: np.ndarray
    a_keys: Sequence[Instance]
    m, b, a_keys = _make_matrix(crosstabs, rvs)

    if len(a_keys) == 0:
        # No candidate instances were found (e.g. every cross-table is empty),
        # so there is nothing to solve for; avoid calling the solver on an
        # empty system.
        return CrossTable(rvs)

    a = _solve(m, b)

    # a_keys is co-indexed with the solution vector `a`.
    return CrossTable(rvs, update=zip(a_keys, a))
61
+
62
+
63
def _solve(m: dok_matrix, b: np.ndarray) -> np.ndarray:
    """
    Solve `b = m a` for the best non-negative vector `a`.

    Args:
        m: sparse matrix with one row per element of `b`.
        b: one-dimensional vector of target values.

    Returns:
        the solution vector `a`, currently produced by the SAM solver.
    """
    assert len(b.shape) == 1, 'b should be a vector'
    assert b.shape[0] == m.shape[0], 'b and m must be compatible'

    # Other solvers available in this module: `_solve_lsqr(m, b)`, and the
    # commented-out L1 formulation `_solve_pulp_l1(m, b)`.
    solution = _solve_sam(m, b)
    return solution
73
+
74
+
75
def _solve_sam(m: dok_matrix, b: np.ndarray) -> np.ndarray:
    """
    Solve `b = m a` for `a`, subject to `a[i] >= 0`, using the custom
    'split and mean' (SAM) iterative method.

    Runs at most 100 iterations. It stops earlier once the SAM error is at most
    1e-6, or changes by at most 1e-12 between iterations. The final error value
    is discarded.
    """
    solver = _SAM(m, b)
    solution, _final_error = solver.solve(
        max_iterations=100,
        tolerance=1e-6,
        change_tolerance=1e-12,
    )
    return solution
88
+
89
+
90
def _solve_lsqr(m: dok_matrix, b: np.ndarray) -> np.ndarray:
    """
    Approximately solve `b = m a` subject to `a[i] >= 0`.

    Computes an unconstrained least-squares solution with scipy's `lsqr`.
    If any entry is negative, every entry is shifted up by the magnitude of the
    most negative one, so the minimum entry becomes zero.
    """
    # PyCharm's type checker mis-infers the return signature of `lsqr`.
    # noinspection PyTypeChecker
    result = lsqr(m, b)
    # Only the solution is needed; the other fields are diagnostics
    # (istop, itn, r1norm, r2norm, ...).
    a: np.ndarray = result[0]

    # lsqr does not enforce non-negativity. Rather than truncating negative
    # entries to zero, shift all entries up; empirically this gives better results.
    if a.size > 0:
        lowest = a.min()
        if lowest < 0:
            a -= lowest
    return a
112
+
113
+
114
+ # This approach is unsatisfactory as we should minimise the L2 norm
115
+ # rather than the L1 norm.
116
+ #
117
+ # def _solve_pulp_l1(m: dok_matrix, b: np.ndarray) -> np.ndarray:
118
+ # """
119
+ # Find the best 'a' for `b = m a` subject to `a[i] >= 0`.
120
+ #
121
+ # Uses pulp LpProblem to minimise the L1 norm of `a`.
122
+ #
123
+ # This method will only work if there is an exact solution.
124
+ # If not, then we call _solve_sam as a fallback.
125
+ # """
126
+ # import pulp
127
+ #
128
+ # a_size = m.shape[1]
129
+ # b_size = b.shape[0]
130
+ #
131
+ # prob = pulp.LpProblem('solver', pulp.LpMinimize)
132
+ # x = [pulp.LpVariable(f'x{i}', lowBound=0) for i in range(a_size)]
133
+ #
134
+ # # The objective: minimise the L1 norm of x.
135
+ # # The sum(x) is the L1 norm because each element of x is constrained >= 0.
136
+ # prob.setObjective(pulp.lpSum(x))
137
+ #
138
+ # # The constraints
139
+ # constraints = [pulp.LpAffineExpression() for _ in range(b_size)]
140
+ # for row, col in m.keys():
141
+ # constraints[row].addterm(x[col], 1)
142
+ # for c, b_i in zip(constraints, b):
143
+ # prob.addConstraint(c == b_i)
144
+ #
145
+ # _PULP_TIMEOUT = 60 # seconds
146
+ # status = prob.solve(pulp.PULP_CBC_CMD(msg=False, timeLimit=_PULP_TIMEOUT))
147
+ #
148
+ # if status == pulp.LpStatusOptimal:
149
+ # return np.array([pulp.value(x_var) for x_var in x])
150
+ # else:
151
+ # return _solve_sam(m, b)
152
+
153
+
154
def _sum_out_unneeded_rvs(crosstab: CrossTable, rvs: Sequence[RandomVariable]) -> CrossTable:
    """
    Project the given cross-table as needed so that every random variable in
    the result is in `rvs`.

    The kept random variables retain their order from `crosstab.rvs`, so the
    result does not depend on set iteration order.

    Args:
        crosstab: the cross-table to (possibly) project.
        rvs: the random variables that may be kept.

    Returns:
        `crosstab` itself if all its random variables are already in `rvs`,
        otherwise `crosstab.project(...)` onto the shared random variables.
    """
    wanted = set(rvs)
    # dict.fromkeys de-duplicates while preserving the cross-table's rv order.
    project_rvs = list(dict.fromkeys(rv for rv in crosstab.rvs if rv in wanted))
    if len(project_rvs) == len(set(crosstab.rvs)):
        # No projection is required.
        return crosstab
    return crosstab.project(project_rvs)
166
+
167
+
168
def _make_matrix(
        crosstabs: Sequence[CrossTable],
        rvs: Sequence[RandomVariable],
) -> Tuple[dok_matrix, np.ndarray, Sequence[Instance]]:
    """
    Create the `m` matrix and `b` vector for solving `b = m a`.

    Each entry of each (projected) cross-table contributes one row `i` of `m`.
    The matching element b[i] is the entry's weight divided by the cross-table's
    total weight.

    Each instance of `rvs` that matches at least one cross-table entry
    contributes one column `j` of `m`. m[i, j] = 1 when instance j matches the
    cross-table entry of row i.

    Cross-tables with a total weight of zero are skipped, as they cannot be normalised.

    Args:
        crosstabs: a collection of cross-tables to coalesce.
        rvs: the random variables that cross-tables will be projected on to.

    Returns:
        the tuple (m, b, a_keys) where
        'm' is a sparse matrix of shape (len(b), len(a_keys)),
        'b' is a numpy array of crosstab probabilities (normalised per cross-table),
        'a_keys' are the keys for the solution probabilities, co-indexed with `a`.
    """
    # Sum out any unneeded random variables, so every crosstab rv is in `rvs`.
    crosstabs: Sequence[CrossTable] = [
        _sum_out_unneeded_rvs(crosstab, rvs)
        for crosstab in crosstabs
    ]

    m_cols: Dict[Instance, _MCol] = {}
    b_list: List[float] = []
    a_keys: List[Instance] = []

    rv_index: Dict[RandomVariable, int] = {rv: i for i, rv in enumerate(rvs)}

    # instance_template[i] is a list of the possible states of rvs[i].
    # Iterate `rvs` directly so that `len(rv)` is the number of states of `rv`.
    # Iterating `enumerate(rvs)` would instead take the length of an (index, rv)
    # pair, which is always 2.
    instance_template: List[List[int]] = [list(range(len(rv))) for rv in rvs]

    for crosstab in crosstabs:
        total = crosstab.total_weight()
        if total == 0:
            # Cannot normalise; an all-zero cross-table provides no evidence.
            continue

        # to_rv[k] is the position in `rvs` of crosstab.rvs[k]
        # (all present, thanks to the projection above).
        to_rv: List[int] = [rv_index[rv] for rv in crosstab.rvs]

        # instance_options is a shallow copy of instance_template, but with a
        # fresh singleton list for each rv this crosstab has. The singleton
        # state is overwritten for each crosstab entry below.
        instance_options: List[List[int]] = list(instance_template)
        for i in to_rv:
            instance_options[i] = [-1]

        for crosstab_instance, weight in crosstab.items():
            # Fix the crosstab's rvs to the states of this entry.
            for state, i in zip(crosstab_instance, to_rv):
                instance_options[i][0] = state

            # Grow the b list with this entry's probability.
            b_i = len(b_list)
            b_list.append(weight / total)

            # Record `b_i` in the column of every instance of `rvs` that
            # matches the current crosstab_instance.
            for instance in combos(instance_options):
                m_col = m_cols.get(instance)
                if m_col is None:
                    m_cols[instance] = _MCol(instance, len(a_keys), [b_i])
                    a_keys.append(instance)
                else:
                    m_col.col.append(b_i)

    # Construct the m matrix from m_cols.
    m = dok_matrix((len(b_list), len(a_keys)), dtype=np.double)
    for m_col in m_cols.values():
        j = m_col.column_index
        for i in m_col.col:
            m[i, j] = 1

    # Construct the b vector.
    b = np.array(b_list, dtype=np.double)

    return m, b, a_keys
248
+
249
+
250
@dataclass
class _MCol:
    """
    One column of the `m` matrix while it is being built by `_make_matrix`.
    """

    # The instance of the target random variables that this column represents.
    key: Instance
    # Index of this column in `m` (and of its instance in `a_keys`).
    column_index: int
    # Row indices i (into `b`) where m[i, column_index] == 1.
    col: List[int]
255
+
256
+
257
@dataclass
class _SM:
    """
    A non-zero cell of the `m` matrix, as used by the SAM solver.
    """

    # Portion of the cell's row `b` value currently allocated to this cell.
    split: float
    # Current solution value for the cell's column of `a`.
    a: float

    def diff(self) -> float:
        """Return how much the allocated split exceeds the current `a` value."""
        excess = self.split - self.a
        return excess
264
+
265
+
266
class _SAM:
    """
    Split and Mean (SAM) method for finding `a` in `b = m a`, subject to `a[i] >= 0`.

    Assumes all elements of `m` are either zero or one.

    Each non-zero cell of `m` is represented by an `_SM` object holding:
    * `split`: the portion of its row's `b` value currently allocated to the cell;
    * `a`: the current value of the cell's column in `a`.

    Solving alternates between two steps:
    * mean step: each a[j] becomes the mean of the splits in column j;
    * split step: within each row, splits above the current `a` value are reduced
      and splits below it are increased, by equal total amounts. Each row's splits
      therefore keep the same sum.
    """

    def __init__(self, m: dok_matrix, b: NDArrayFloat64, use_lsqr: bool = True):
        """
        Allocate the memory required for a SAM solver.

        Args:
            m: the summation matrix (entries assumed to be zero or one).
            b: the vector of resulting probabilities, one per row of `m`.
            use_lsqr: whether to use LSQR or not to initialise the solution.
        """
        # Replicate the sparse m matrix as lists of _SM cells. Cells are indexed
        # both by column (a_idx) and by row (b_idx); each cell object is shared
        # between its column list and its row list.
        a_size: int = m.shape[1]
        b_size: int = m.shape[0]
        a_idx: List[List[_SM]] = [[] for _ in range(a_size)]
        b_idx: List[List[_SM]] = [[] for _ in range(b_size)]
        for (i, j), m_val in m.items():
            if m_val != 0:
                sm = _SM(0, 0)
                a_idx[j].append(sm)
                b_idx[i].append(sm)

        self._a: NDArrayFloat64 = np.zeros(a_size, dtype=np.double)
        self._b: NDArrayFloat64 = b
        self._m: dok_matrix = m
        self._a_idx: List[List[_SM]] = a_idx
        self._b_idx: List[List[_SM]] = b_idx
        self._use_lsqr = use_lsqr

    def solve(self, max_iterations: int, tolerance: float, change_tolerance: float) -> Tuple[np.ndarray, float]:
        """
        Initialize split values then iterate (mean step, split step).

        At least one mean step is always performed.

        Args:
            max_iterations: maximum number of iterations.
            tolerance: terminate iterations if error <= tolerance.
            change_tolerance: terminate iterations if change in error <= change_tolerance.

        Returns:
            tuple ('a', 'error') where 'a' is the current solution after this step
            and 'error' is the sum of absolute errors. 'a' is the solver's internal
            array, so it is overwritten if the solver is used again.
        """
        self._initialize_split_values()
        iteration = 0
        # There is no previous error before the first iteration. Using +inf
        # stops the change-in-error test from terminating the first iteration
        # merely because the first error happens to be close to zero.
        prev_error = float('inf')
        while True:
            iteration += 1
            a, error = self._mean_step()
            if (
                    error <= tolerance
                    or abs(error - prev_error) <= change_tolerance
                    or iteration >= max_iterations
            ):
                return a, error
            prev_error = error
            self._split_step()

    def _initialize_split_values(self):
        """
        Take each 'b' value and split it evenly across its SM cells.

        If 'self._use_lsqr' is True, the even split is then refined by a split
        step towards a solution found using scipy lsqr.
        """
        for b_val, b_list in zip(self._b, self._b_idx):
            len_b_list = len(b_list)
            if len_b_list > 0:
                split_val = b_val / len_b_list
                for sm in b_list:
                    sm.split = split_val

        if self._use_lsqr:
            a = _solve_lsqr(self._m, self._b)
            assert len(a) == len(self._a)
            for a_val, a_list in zip(a, self._a_idx):
                for sm in a_list:
                    sm.a = a_val
            self._split_step()

    def _mean_step(self) -> Tuple[np.ndarray, float]:
        """
        Take the current split values to determine the 'a' values
        as the mean across relevant SM cells.

        Assumes the previous step was either 'initialize_split_values'
        or a 'split step'.

        Returns:
            tuple ('a', 'error') where 'a' is the current solution after this step
            and 'error' is the sum of absolute errors.
        """
        error = 0.0
        a = self._a
        for i, a_list in enumerate(self._a_idx):
            if len(a_list) == 0:
                # This column of `m` is all zero, so a[i] is unconstrained.
                # Use zero rather than dividing by zero.
                a[i] = 0.0
                continue
            sum_val = sum(sm.split for sm in a_list)
            a_val = sum_val / len(a_list)
            a[i] = a_val
            for sm in a_list:
                sm.a = a_val
                error += abs(sm.diff())
        return a, error

    def _split_step(self):
        """
        Take the difference between the split 'b' values and current 'a' values to
        redistribute the split 'b' values.

        Assumes the previous step was a 'mean step'.
        """
        for b_val, b_list in zip(self._b, self._b_idx):
            if len(b_list) <= 1:
                # Too small to split
                continue
            pos = 0
            neg = 0
            for sm in b_list:
                diff = sm.diff()
                if diff >= 0:
                    pos += diff
                else:
                    neg -= diff
            # The amount to move from over-allocated cells to under-allocated
            # cells; moving equal totals keeps the row's splits summing to b_val.
            mass = min(pos, neg)
            if mass == 0:
                # No mass to redistribute
                continue
            pos_scale = mass / pos
            neg_scale = mass / neg
            for sm in b_list:
                diff = sm.diff()
                if diff >= 0:
                    moved = diff * pos_scale
                else:
                    moved = diff * neg_scale
                sm.split -= moved
@@ -0,0 +1,296 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from enum import Enum, auto
5
+ from itertools import chain
6
+ from typing import Iterable, List, Tuple, Dict, Sequence, Set, Optional
7
+
8
+ from ck.dataset.cross_table import CrossTable
9
+ from ck.learning.coalesce_cross_tables import coalesce_cross_tables
10
+ from ck.learning.parameters import make_factors, ParameterValues
11
+ from ck.learning.train_generative_bn import cpt_and_parent_sums_from_crosstab
12
+ from ck.pgm import PGM, RandomVariable
13
+ from ck.utils.map_list import MapList
14
+
15
+
16
class ParameterInference(Enum):
    """
    Selects how the parameter values of a CPT are inferred from the available
    cross-tables.
    """

    first = 1
    """
    Use only the first cross-table that covers all of the CPT's random variables.
    When no such cross-table exists, behave like `all`.
    The fastest option.
    """

    sum = 2
    """
    Add together every cross-table that covers all of the CPT's random variables.
    When there are none, behave like `all`.
    The second fastest option.
    """

    all = 3
    """
    Project every cross-table onto the needed random variables, then solve for the
    best parameter values with `coalesce_cross_tables`.
    """
40
+
41
+
42
def model_from_cross_tables(
        pgm: PGM,
        cross_tables: Iterable[CrossTable],
        method: ParameterInference = ParameterInference.sum,
) -> None:
    """
    Populate an empty PGM with a Bayesian network inferred, with best efforts,
    from the evidence in the supplied cross-tables.

    The CPTs are obtained from `get_cpts` and installed into `pgm` using
    `make_factors`.

    Args:
        pgm: the PGM to add factors and potential functions to.
        cross_tables: available cross-tables to build a model from.
        method: what parameter inference method to use.

    Raises:
        ValueError: If `pgm` has any existing factors.
    """
    already_has_factors = len(pgm.factors) > 0
    if already_has_factors:
        raise ValueError('the given PGM should have no factors')

    make_factors(
        pgm,
        get_cpts(rvs=pgm.rvs, cross_tables=cross_tables, method=method),
    )
71
+
72
+
73
def get_cpts(
        rvs: Sequence[RandomVariable],
        cross_tables: Iterable[CrossTable],
        method: ParameterInference = ParameterInference.sum,
) -> ParameterValues:
    """
    Infer, with best efforts, the CPTs of a Bayesian network over `rvs` using
    only the evidence in the supplied cross-tables.

    The structure is chosen heuristically. Random variables are ranked (see
    `_get_rv_order`), and each cross-table is allocated to its highest-ranked
    random variable. That random variable becomes a child of the cross-table's
    other random variables. CPT parameter values are then inferred using `method`.

    Cross-tables are assumed to be mutually consistent. That is, for any two
    cross-tables `x` and `y` in `cross_tables` with common random variables
    `common`, `x.project(common) == y.project(common)`. If that does not hold, a
    model is still produced, but its parameter values may be suboptimal.

    NOTE(review): factors are looked up by `rv.idx`, so `rvs` presumably must be
    the full list of a PGM's random variables in index order (as passed by
    `model_from_cross_tables`) — confirm for other callers.

    Args:
        rvs: the random variables to define a network structure over.
        cross_tables: available cross-tables to build a model from.
        method: what parameter inference method to use.

    Returns:
        ParameterValues object as a list of CPTs, co-indexed with `rvs`. Each CPT
        can be used to create a new factor in the PGM to make a Bayesian network.
    """
    # Materialise the cross-tables, since they are traversed more than once.
    cross_tables: Tuple[CrossTable, ...] = tuple(cross_tables)

    # Heuristic rank of each random variable, used to orient the BN edges.
    rv_order: Dict[RandomVariable, int] = _get_rv_order(cross_tables)

    # One initially empty model factor per random variable.
    model_factors: List[_ModelFactor] = [_ModelFactor(rv) for rv in rvs]

    # Allocate each non-empty cross-table to exactly one factor: that of its
    # highest-ranked random variable (the child). The remaining random
    # variables become parents of that child.
    for crosstab in cross_tables:
        if len(crosstab.rvs) == 0:
            continue
        child, *parents = sorted(crosstab.rvs, key=(lambda _rv: rv_order[_rv]), reverse=True)
        target: _ModelFactor = model_factors[child.idx]
        target.parent_rvs.update(parents)
        target.cross_tables.append(crosstab)

    # Tell each parent's factor about the factors that depend on it, since a
    # factor's children must be inferred before the factor itself.
    for factor in model_factors:
        for parent_rv in factor.parent_rvs:
            model_factors[parent_rv.idx].child_factors.append(factor)

    # Infer every CPT, depth first (children before parents).
    done: Set[int] = set()
    for factor in model_factors:
        _infer_cpt(factor, done, method)

    return [factor.cpt for factor in model_factors]
135
+
136
+
137
def _infer_cpt(
        model_factor: _ModelFactor,
        done: Set[int],
        method: ParameterInference,
) -> None:
    """
    Depth-first infer CPTs for `model_factor` and, transitively, its child factors.

    A factor's CPT is inferred only after those of all its child factors, because
    each child factor's `parent_sums` is used as extra evidence. This sets
    `model_factor.cpt` and `model_factor.parent_sums` for every factor visited.

    The traversal uses an explicit stack rather than recursion, so long chains of
    factors do not exceed Python's recursion limit. The visiting order is the
    same as a recursive depth-first post-order.

    Args:
        model_factor: the factor to start from.
        done: `child_rv.idx` of every factor already visited; updated in place.
        method: the parameter inference method to use.
    """
    # Only visit a model factor once.
    start_idx = model_factor.child_rv.idx
    if start_idx in done:
        return
    done.add(start_idx)

    # Each stack entry is a factor plus an iterator over its child factors
    # that still need to be considered.
    stack: List[Tuple[_ModelFactor, Iterable[_ModelFactor]]] = [
        (model_factor, iter(model_factor.child_factors))
    ]
    while stack:
        factor, remaining_children = stack[-1]
        for child_factor in remaining_children:
            child_idx = child_factor.child_rv.idx
            if child_idx not in done:
                # Descend into the unvisited child first.
                done.add(child_idx)
                stack.append((child_factor, iter(child_factor.child_factors)))
                break
        else:
            # All child factors are done; infer this factor's parameters.
            stack.pop()
            child_rv: RandomVariable = factor.child_rv
            crosstabs: Sequence[CrossTable] = factor.cross_tables
            child_crosstabs: Sequence[CrossTable] = [
                child_model_factor.parent_sums
                for child_model_factor in factor.child_factors
            ]
            rvs = [child_rv] + list(factor.parent_rvs)
            crosstab: CrossTable = _infer_parameter_values(rvs, crosstabs, child_crosstabs, method)
            cpt, parent_sums = cpt_and_parent_sums_from_crosstab(crosstab)
            factor.cpt = cpt
            factor.parent_sums = parent_sums
170
+
171
+
172
def _infer_parameter_values(
        rvs: Sequence[RandomVariable],
        crosstabs: Sequence[CrossTable],
        child_crosstabs: Sequence[CrossTable],
        method: ParameterInference,
) -> CrossTable:
    """
    Make best efforts to infer a probability distribution over the given random variables,
    with evidence from the given cross-tables.

    Args:
        rvs: the random variables of the required distribution.
        crosstabs: cross-tables allocated to the model factor.
        child_crosstabs: parent-sum cross-tables from the model factor's child factors.
        method: the parameter inference method to use.

    Returns:
        a CrossTable representing the inferred probability distribution
        (not normalised to sum to one).

    Assumes:
        `rvs` has no duplicates.
    """
    available_crosstabs: List[CrossTable] = list(chain(crosstabs, child_crosstabs))

    if method == ParameterInference.all:
        # Forced to use all cross-tables with `coalesce_cross_tables`.
        # It sums out any random variables not in `rvs` itself, and copes with
        # cross-tables that only partially cover `rvs`. So cross-tables are passed
        # through unprojected. Calling `project(rvs)` here would ask a partially
        # covering cross-table to project onto random variables it does not have.
        return coalesce_cross_tables(available_crosstabs, rvs)

    # Split the cross-tables into complete and partial coverage of `rvs`.
    # Completely covering cross-tables are projected onto `rvs` and summed into
    # `complete_crosstab`; the others are appended to `partial_crosstabs`.
    complete_crosstab: Optional[CrossTable] = None
    partial_crosstabs: List[CrossTable] = []
    for available_crosstab in available_crosstabs:
        available_rvs: Set[RandomVariable] = set(available_crosstab.rvs)
        if available_rvs.issuperset(rvs):
            to_add: CrossTable = available_crosstab.project(rvs)
            if method == ParameterInference.first:
                # Take the first available solution.
                return to_add
            if complete_crosstab is None:
                # NOTE(review): if `project` can return its receiver when no
                # projection is needed, the `add_all` below would mutate a
                # caller-supplied cross-table. Confirm `project` always returns
                # a new object.
                complete_crosstab = to_add
            else:
                complete_crosstab.add_all(to_add.items())
        else:
            partial_crosstabs.append(available_crosstab)

    if complete_crosstab is not None:
        # A direct solution was found.
        # Ignore any partially covering cross-tables.
        return complete_crosstab

    # If there are no cross-tables available, the result is empty.
    if len(partial_crosstabs) == 0:
        return CrossTable(rvs)

    # No cross-table completely covers the given random variables, so make
    # best attempts to coalesce the partially covering cross-tables.
    return coalesce_cross_tables(partial_crosstabs, rvs)
231
+
232
+
233
def _crostab_str(crosstab: CrossTable) -> str:
    """
    Render the names of a cross-table's random variables as a tuple-like
    string of their reprs, e.g. "('a', 'b')".
    """
    names = [repr(rv.name) for rv in crosstab.rvs]
    return f"({', '.join(names)})"
235
+
236
+
237
def _get_rv_order(cross_tables: Sequence[CrossTable]) -> Dict[RandomVariable, int]:
    """
    Heuristically rank the random variables appearing in `cross_tables`.

    Within each cross-table, every random variable is treated as depending on
    the random variables listed after it. Ranks are assigned in depth-first
    post-order over those dependencies, so a random variable is ranked above
    the ones it depends on. Cycles between cross-tables are broken by visit
    order.

    Returns:
        a map from random variable to its rank (0 is the lowest).
    """
    child_parent_map: MapList[RandomVariable, RandomVariable] = MapList()
    for crosstab in cross_tables:
        crosstab_rvs = crosstab.rvs
        for position, rv in enumerate(crosstab_rvs):
            child_parent_map.extend(rv, crosstab_rvs[position + 1:])

    ranks: Dict[RandomVariable, int] = {}
    visited: Set[RandomVariable] = set()
    for rv in child_parent_map.keys():
        _get_rv_order_r(rv, child_parent_map, visited, ranks)
    return ranks
254
+
255
+
256
def _get_rv_order_r(
        child: RandomVariable,
        child_parent_map: MapList[RandomVariable, RandomVariable],
        seen: Set[RandomVariable],
        order: Dict[RandomVariable, int],
) -> None:
    """
    Depth-first post-order traversal from `child` through `child_parent_map`.

    Each newly reached random variable is added to `seen`. It is assigned the
    next rank in `order` (`len(order)`) only after all its parents have been
    ranked. Random variables already in `seen` are skipped, which also breaks
    cycles.

    An explicit stack is used instead of recursion, so long dependency chains do
    not exceed Python's recursion limit. The resulting order is identical to the
    recursive formulation.

    Args:
        child: the random variable to start from.
        child_parent_map: maps each random variable to its parents.
        seen: random variables already visited; updated in place.
        order: random variable to rank; updated in place.
    """
    if child in seen:
        return
    seen.add(child)

    # Each stack entry is a node plus an iterator over its parents that are
    # still to be considered.
    stack = [(child, iter(child_parent_map[child]))]
    while stack:
        node, remaining_parents = stack[-1]
        for parent in remaining_parents:
            if parent not in seen:
                # Descend into the unvisited parent first.
                seen.add(parent)
                stack.append((parent, iter(child_parent_map[parent])))
                break
        else:
            # All parents are ranked; rank this node next.
            stack.pop()
            order[node] = len(order)
268
+
269
+
270
@dataclass
class _ModelFactor:
    """
    One factor of the Bayesian network structure being built: the CPT of
    `child_rv` given `parent_rvs`.

    `cross_tables` holds the cross-tables allocated to this factor, which are
    the evidence used to infer its parameters.
    """

    child_rv: RandomVariable
    parent_rvs: Set[RandomVariable] = field(default_factory=set)
    cross_tables: List[CrossTable] = field(default_factory=list)
    # Factors whose parents include `child_rv`.
    child_factors: List[_ModelFactor] = field(default_factory=list)

    # Set once the parameter values have been inferred.
    cpt: Optional[CrossTable] = None
    parent_sums: Optional[CrossTable] = None

    def dump(self, prefix=''):
        """
        Print the child, the parents and each cross-table's random variables
        (for debugging).
        """
        parents_text = self.parent_rvs if self.parent_rvs else '{}'
        print(f'{prefix}child:', self.child_rv)
        print(f'{prefix}parents:', parents_text)
        for cross_table in self.cross_tables:
            print(f'{prefix}cross-table:', *cross_table.rvs)