compiled-knowledge 4.1.0a2__cp313-cp313-macosx_11_0_arm64.whl → 4.1.0a3__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (33) hide show
  1. ck/circuit/_circuit_cy.c +1 -1
  2. ck/circuit/_circuit_cy.cpython-313-darwin.so +0 -0
  3. ck/circuit_compiler/cython_vm_compiler/_compiler.c +152 -152
  4. ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
  5. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
  6. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cpython-313-darwin.so +0 -0
  7. ck/dataset/cross_table.py +143 -79
  8. ck/dataset/dataset.py +95 -7
  9. ck/dataset/dataset_builder.py +11 -4
  10. ck/dataset/dataset_from_crosstable.py +21 -2
  11. ck/learning/coalesce_cross_tables.py +395 -0
  12. ck/learning/model_from_cross_tables.py +242 -0
  13. ck/learning/parameters.py +117 -0
  14. ck/learning/train_generative_bn.py +198 -0
  15. ck/pgm.py +10 -8
  16. ck/pgm_circuit/marginals_program.py +5 -0
  17. ck/pgm_circuit/wmc_program.py +5 -0
  18. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
  19. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cpython-313-darwin.so +0 -0
  20. ck/probability/divergence.py +226 -0
  21. ck/probability/probability_space.py +43 -19
  22. ck_demos/dataset/demo_dataset_from_sampler.py +18 -0
  23. ck_demos/learning/__init__.py +0 -0
  24. ck_demos/learning/demo_bayesian_network_from_cross_tables.py +71 -0
  25. ck_demos/learning/demo_simple_learning.py +55 -0
  26. ck_demos/sampling/demo_wmc_direct_sampler.py +2 -2
  27. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/METADATA +2 -1
  28. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/RECORD +32 -24
  29. ck/learning/train_generative.py +0 -149
  30. /ck/{dataset/cross_table_probabilities.py → probability/cross_table_probability_space.py} +0 -0
  31. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/WHEEL +0 -0
  32. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/licenses/LICENSE.txt +0 -0
  33. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.1.0a3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,395 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Tuple, List, Sequence, Dict
5
+
6
+ import numpy as np
7
+ from scipy.sparse import dok_matrix
8
+ from scipy.sparse.linalg import lsqr
9
+
10
+ from ck.dataset.cross_table import CrossTable
11
+ from ck.pgm import RandomVariable, Instance
12
+ from ck.utils.iter_extras import combos
13
+ from ck.utils.np_extras import NDArrayFloat64
14
+
15
+
16
def coalesce_cross_tables(crosstabs: Sequence[CrossTable], rvs: Sequence[RandomVariable]) -> CrossTable:
    """
    Merge several cross-tables into a single cross-table over `rvs`.

    The merged weights are the vector `a` that best solves `b = m a`,
    subject to `a[i] >= 0`, where:

    `a` is a column vector with one entry for each instance of `rvs`
    seen in the cross-tables,

    `b` is a column vector holding every probability (normalised weight)
    from every cross-table,

    `m` is a sparse matrix with `m[i, j] = 1` where `a[j]` is one of the
    terms that are summed to give `b[i]`.

    Args:
        crosstabs: a collection of cross-tables to coalesce.
        rvs: the random variables that cross-tables will be projected on to.

    Returns:
        a cross-table defined for the given `rvs`, with values inferred from `crosstabs`.
        When `crosstabs` is empty the result is `CrossTable(rvs)`.
    """
    if len(crosstabs) == 0:
        # Nothing to merge.
        return CrossTable(rvs)

    matrix, probabilities, instances = _make_matrix(crosstabs, rvs)
    weights = _solve(matrix, probabilities)

    # Pair each solved weight with the instance that its column stands for.
    return CrossTable(rvs, update=zip(instances, weights))
53
+
54
+
55
def _solve(m: dok_matrix, b: np.ndarray) -> np.ndarray:
    """
    Find the best 'a' for `b = m a` subject to `a[i] >= 0`.

    The shapes of `m` and `b` are checked, then the work is delegated
    to the 'split and mean' solver.
    """
    b_shape = b.shape
    assert len(b_shape) == 1, 'b should be a vector'
    assert m.shape[0] == b_shape[0], 'b and m must be compatible'

    # Alternative solvers that have been tried here:
    #   _solve_lsqr(m, b)
    #   _solve_pulp_l1(m, b)
    return _solve_sam(m, b)
65
+
66
+
67
def _solve_sam(m: dok_matrix, b: np.ndarray) -> np.ndarray:
    """
    Find the best 'a' for `b = m a` subject to `a[i] >= 0`.

    Uses the custom 'split and mean' (SAM) method implemented by `_SAM`,
    with fixed iteration and tolerance settings. The final error reported
    by the solver is discarded.
    """
    solution, _error = _SAM(m, b).solve(
        max_iterations=100,
        tolerance=1e-6,
        change_tolerance=1e-12,
    )
    return solution
80
+
81
+
82
def _solve_lsqr(m: dok_matrix, b: np.ndarray) -> np.ndarray:
    """
    Find the best 'a' for `b = m a` subject to `a[i] >= 0`.

    Uses the scipy `lsqr` method, followed by a heuristic that removes
    negative values from `a`.
    """
    # Only the solution vector, the first element of the `lsqr` result, is used.
    # Pycharm type checker incorrectly infers the type signature of `lsqr`
    # noinspection PyTypeChecker
    solution: np.ndarray = lsqr(m, b)[0]

    if len(solution) == 0:
        return solution

    # Negative values are not a valid solution, so apply a heuristic fix up.
    lowest = np.min(solution)
    if lowest < 0:
        # We could just let the negative values get truncated to zero, but
        # empirically the results seem better when we shift all parameters up.
        solution -= lowest

    return solution
104
+
105
+
106
+ # This approach is unsatisfactory as we should minimise the L2 norm
107
+ # rather than the L1 norm.
108
+ #
109
+ # def _solve_pulp_l1(m: dok_matrix, b: np.ndarray) -> np.ndarray:
110
+ # """
111
+ # Find the best 'a' for `b = m a` subject to `a[i] >= 0`.
112
+ #
113
+ # Uses pulp LpProblem to minimise the L1 norm of `a`.
114
+ #
115
+ # This method will only work if there is an exact solution.
116
+ # If not, then we call _solve_sam as a fallback.
117
+ # """
118
+ # import pulp
119
+ #
120
+ # a_size = m.shape[1]
121
+ # b_size = b.shape[0]
122
+ #
123
+ # prob = pulp.LpProblem('solver', pulp.LpMinimize)
124
+ # x = [pulp.LpVariable(f'x{i}', lowBound=0) for i in range(a_size)]
125
+ #
126
+ # # The objective: minimise the L1 norm of x.
127
+ # # The sum(x) is the L1 norm because each element of x is constrained >= 0.
128
+ # prob.setObjective(pulp.lpSum(x))
129
+ #
130
+ # # The constraints
131
+ # constraints = [pulp.LpAffineExpression() for _ in range(b_size)]
132
+ # for row, col in m.keys():
133
+ # constraints[row].addterm(x[col], 1)
134
+ # for c, b_i in zip(constraints, b):
135
+ # prob.addConstraint(c == b_i)
136
+ #
137
+ # _PULP_TIMEOUT = 60 # seconds
138
+ # status = prob.solve(pulp.PULP_CBC_CMD(msg=False, timeLimit=_PULP_TIMEOUT))
139
+ #
140
+ # if status == pulp.LpStatusOptimal:
141
+ # return np.array([pulp.value(x_var) for x_var in x])
142
+ # else:
143
+ # return _solve_sam(m, b)
144
+
145
+
146
def _sum_out_unneeded_rvs(crosstab: CrossTable, rvs: Sequence[RandomVariable]) -> CrossTable:
    """
    Return a cross-table whose random variables are all in `rvs`.

    If every random variable of `crosstab` is already in `rvs`, then
    `crosstab` itself is returned. Otherwise, the result is `crosstab`
    projected on to those of its random variables that are in `rvs`.
    """
    own_rvs = set(crosstab.rvs)
    wanted_rvs = own_rvs.intersection(rvs)
    if len(wanted_rvs) != len(own_rvs):
        return crosstab.project(list(wanted_rvs))

    # No projection is required
    return crosstab
158
+
159
+
160
def _make_matrix(
        crosstabs: Sequence[CrossTable],
        rvs: Sequence[RandomVariable],
) -> Tuple[dok_matrix, np.ndarray, Sequence[Instance]]:
    """
    Create the `m` matrix and `b` vector for solving `b = m a`.

    Each entry of each cross-table contributes one row: its normalised
    weight is appended to `b`, and the row of `m` has a one in the column
    of every instance of `rvs` that is consistent with the entry.

    Args:
        crosstabs: a collection of cross-tables to coalesce.
        rvs: the random variables that cross-tables will be projected on to.

    Returns:
        the tuple (m, b, a_keys) where
        'm' is a sparse matrix,
        'b' is a numpy array of crosstab probabilities,
        'a_keys' are the keys for the solution probabilities,
        i.e. `a_keys[j]` is the instance of `rvs` for column `j` of `m`.
    """

    # Sum out any unneeded random variables
    projected: Sequence[CrossTable] = [
        _sum_out_unneeded_rvs(crosstab, rvs)
        for crosstab in crosstabs
    ]

    m_cols: Dict[Instance, _MCol] = {}
    b_list: List[float] = []
    a_keys: List[Instance] = []

    rv_index: Dict[RandomVariable, int] = {rv: i for i, rv in enumerate(rvs)}

    # instance_template[i] is a list of the possible states of rvs[i].
    # Iterate `rvs` directly: iterating `enumerate(rvs)` would take the length
    # of an (index, rv) pair, which is always 2, instead of the length of the rv.
    instance_template: List[List[int]] = [list(range(len(rv))) for rv in rvs]

    for crosstab in projected:

        # get `to_rv` such that crosstab.rvs[i] = rvs[to_rv[i]]
        to_rv = [rv_index.get(rv) for rv in crosstab.rvs]

        # Make instance_options which is a clone of instance_template but with
        # a singleton list replacing the rvs that this crosstab has.
        # For now the state in each singleton is set to -1, however, later
        # they will be set to the actual states of instances in the current crosstab.
        instance_options = list(instance_template)
        for i in to_rv:
            instance_options[i] = [-1]

        total = crosstab.total_weight()
        for crosstab_instance, weight in crosstab.items():

            # Work out what instances get summed to create the crosstab_instance weight.
            # This just overrides the singleton states of `instance_options` with the
            # actual state of the crosstab_instance.
            for state, i in zip(crosstab_instance, to_rv):
                instance_options[i][0] = state

            # Grow the b list with our instance probability
            b_i = len(b_list)
            b_list.append(weight / total)

            # Iterate over all states of `rvs` that match the current crosstab_instance,
            # recording `b_i` in the column for those matching instances.
            for instance in combos(instance_options):
                m_col = m_cols.get(instance)
                if m_col is None:
                    # First time this instance is seen: allocate the next column.
                    m_cols[instance] = _MCol(instance, len(a_keys), [b_i])
                    a_keys.append(instance)
                else:
                    m_col.col.append(b_i)

    # Construct the m matrix from m_cols
    m = dok_matrix((len(b_list), len(a_keys)), dtype=np.double)
    for m_col in m_cols.values():
        j = m_col.column_index
        for i in m_col.col:
            m[i, j] = 1

    # Construct the b vector
    b = np.array(b_list, dtype=np.double)

    return m, b, a_keys
240
+
241
+
242
@dataclass
class _MCol:
    """
    Book-keeping for one column of the `m` matrix while it is being built.
    """

    # The instance that this column stands for.
    key: Instance

    # The index of this column in the `m` matrix.
    column_index: int

    # The indices of the rows of this column that hold a one.
    col: List[int]
247
+
248
+
249
@dataclass
class _SM:
    """
    One non-zero cell of the summation matrix, as used by `_SAM`.
    """

    # The share of a `b` value that is currently allocated to this cell.
    split: float

    # The current `a` value of the column that this cell belongs to.
    a: float

    def diff(self) -> float:
        """
        Return the amount by which the split value exceeds the `a` value.
        """
        excess = self.split - self.a
        return excess
256
+
257
+
258
class _SAM:
    """
    Split and Mean method for finding 'a'
    in b = m a
    subject to a[i] >= 0.

    Assumes all elements of `m` are either zero or one.
    """

    def __init__(self, m: dok_matrix, b: NDArrayFloat64, use_lsqr: bool = True):
        """
        Allocate the memory required for a SAM solver.

        Args:
            m: the summation matrix
            b: the vector of resulting probabilities
            use_lsqr: whether to use LSQR or not to initialise the solution.
        """
        # Replicate the sparse m matrix, as a list of lists of _SM objects,
        # where we have both row major and column major representations.
        # Each non-zero cell is a single _SM object that is shared by both views.
        a_size: int = m.shape[1]
        b_size: int = m.shape[0]
        a_idx: List[List[_SM]] = [[] for _ in range(a_size)]
        b_idx: List[List[_SM]] = [[] for _ in range(b_size)]
        for (i, j), m_val in m.items():
            if m_val != 0:
                sm = _SM(0, 0)
                a_idx[j].append(sm)
                b_idx[i].append(sm)

        self._a: NDArrayFloat64 = np.zeros(a_size, dtype=np.double)
        self._b: NDArrayFloat64 = b
        self._m: dok_matrix = m
        self._a_idx: List[List[_SM]] = a_idx
        self._b_idx: List[List[_SM]] = b_idx
        self._use_lsqr = use_lsqr

    def solve(self, max_iterations: int, tolerance: float, change_tolerance: float) -> Tuple[np.ndarray, float]:
        """
        Initialize split values then iterate (mean step, split step).

        At least one mean step is always performed.

        Args:
            max_iterations: maximum number of iterations.
            tolerance: terminate iterations if error <= tolerance.
            change_tolerance: terminate iterations if change in error <= change_tolerance.

        Returns:
            tuple ('a', 'error') where 'a' is the current solution after the last
            mean step and 'error' is the sum of absolute errors.
        """
        self._initialize_split_values()
        iteration = 0
        prev_error = 0
        while True:
            iteration += 1
            a, error = self._mean_step()
            if error <= tolerance or abs(error - prev_error) <= change_tolerance or iteration >= max_iterations:
                return a, error
            prev_error = error
            self._split_step()

    def _initialize_split_values(self):
        """
        Take each 'b' value and split it evenly across its SM cells.

        If 'self._use_lsqr' is True then a solution from scipy lsqr is
        written into the 'a' value of every cell and one split step is
        performed, so that the even split is adjusted towards that solution.
        """
        for b_val, b_list in zip(self._b, self._b_idx):
            len_b_list = len(b_list)
            if len_b_list > 0:
                split_val = b_val / len_b_list
                for sm in b_list:
                    sm.split = split_val

        if self._use_lsqr:
            a = _solve_lsqr(self._m, self._b)
            assert len(a) == len(self._a)
            for a_val, a_list in zip(a, self._a_idx):
                for sm in a_list:
                    sm.a = a_val
            self._split_step()

    def _mean_step(self) -> Tuple[np.ndarray, float]:
        """
        Take the current split values to determine the 'a' values
        as the mean across relevant SM cells.

        Assumes the previous step was either 'initialize_split_values'
        or a 'split step'.

        Returns:
            tuple ('a', 'error') where 'a' is the current solution after this step
            and 'error' is the sum of absolute errors.
        """
        error = 0.0
        a = self._a
        for i, a_list in enumerate(self._a_idx):
            if not a_list:
                # A column of `m` with no non-zero entries has nothing to take
                # the mean of. Use zero rather than dividing by zero.
                a[i] = 0.0
                continue
            sum_val = sum(sm.split for sm in a_list)
            a_val = sum_val / len(a_list)
            a[i] = a_val
            for sm in a_list:
                sm.a = a_val
                error += abs(sm.diff())
        return a, error

    def _split_step(self):
        """
        Take the difference between the split 'b' values and current 'a' values to
        redistribute the split 'b' values.

        Within each row, the same total amount is removed from cells whose split
        is above their 'a' value as is added to cells whose split is below it,
        so the sum of the split values of the row is unchanged.

        Assumes the previous step was a 'mean step'.
        """
        for b_val, b_list in zip(self._b, self._b_idx):
            if len(b_list) <= 1:
                # Too small to split
                continue

            # Total excess (split above 'a') and total deficit (split below 'a').
            pos = 0
            neg = 0
            for sm in b_list:
                diff = sm.diff()
                if diff >= 0:
                    pos += diff
                else:
                    neg -= diff

            mass = min(pos, neg)
            if mass == 0:
                # No mass to redistribute
                continue

            # Scale factors so that exactly `mass` is moved out of the cells
            # in excess and exactly `mass` is moved in to the cells in deficit.
            pos_scale = mass / pos
            neg_scale = mass / neg
            for sm in b_list:
                diff = sm.diff()
                if diff >= 0:
                    shift = diff * pos_scale
                else:
                    shift = diff * neg_scale
                sm.split -= shift
@@ -0,0 +1,242 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from itertools import chain
5
+ from typing import Iterable, List, Tuple, Dict, Sequence, Set, Optional
6
+
7
+ from ck.dataset.cross_table import CrossTable
8
+ from ck.learning.coalesce_cross_tables import coalesce_cross_tables
9
+ from ck.learning.parameters import make_factors, ParameterValues
10
+ from ck.learning.train_generative_bn import cpt_and_parent_sums_from_crosstab
11
+ from ck.pgm import PGM, RandomVariable
12
+ from ck.utils.map_list import MapList
13
+
14
+
15
def model_from_cross_tables(pgm: PGM, cross_tables: Iterable[CrossTable]) -> None:
    """
    Make best efforts to construct a Bayesian network model given only the
    evidence from the supplied cross-tables.

    The parameters that define the Bayesian network are obtained from
    `get_cpts`, using the random variables of `pgm`, and are then applied
    to `pgm` using `make_factors`.

    Args:
        pgm: the PGM to add factors and potential function to.
        cross_tables: available cross-tables to build a model from.

    Raises:
        ValueError: If `pgm` has any existing factors.
    """
    existing_factor_count = len(pgm.factors)
    if existing_factor_count > 0:
        raise ValueError('the given PGM should have no factors')

    make_factors(pgm, get_cpts(rvs=pgm.rvs, cross_tables=cross_tables))
38
+
39
+
40
def get_cpts(rvs: Sequence[RandomVariable], cross_tables: Iterable[CrossTable]) -> ParameterValues:
    """
    Make best efforts to define a Bayesian network model given only the
    evidence from the supplied cross-tables.

    This function infers CPTs for `rvs` using the given `cross_tables`.

    For any two cross-tables `x` and `y` in `cross_tables`, with common random
    variables `rvs` then this function assumes `x.project(rvs) == y.project(rvs)`.
    If this condition does not hold, then best efforts will still be made to
    define a Bayesian network model, however, the resulting parameter values
    may be suboptimal.

    Model factors are looked up using the `idx` of a random variable as an
    index into the list of model factors, which has one entry per element of `rvs`.

    Args:
        rvs: the random variables to define a network structure over.
        cross_tables: available cross-tables to build a model from.

    Returns:
        ParameterValues object as a list of CPTs, each CPT can be used to create
        a new factor in the given PGM to make a Bayesian network.
    """
    # The cross-tables are iterated more than once, so hold them in a tuple.
    tables: Tuple[CrossTable, ...] = tuple(cross_tables)

    # Heuristically determine a rank for each random variable,
    # which will be used to form the BN structure.
    rank: Dict[RandomVariable, int] = _get_rv_order(tables)

    # Start with an empty model factor for each random variable.
    factors: List[_ModelFactor] = [_ModelFactor(rv) for rv in rvs]

    # Define the Bayesian network structure by allocating each cross-table to
    # exactly one random variable, the one with the highest rank (i.e. the child).
    # The other random variables of the cross-table become parents of the child.
    for table in tables:
        if len(table.rvs) == 0:
            continue
        child, *parents = sorted(table.rvs, key=rank.__getitem__, reverse=True)
        owner: _ModelFactor = factors[child.idx]
        owner.parent_rvs.update(parents)
        owner.cross_tables.append(table)

    # Tell each parent's model factor about its child factors, as the
    # child factors need to be defined before the parent factor is.
    for factor in factors:
        for parent_rv in factor.parent_rvs:
            factors[parent_rv.idx].child_factors.append(factor)

    # Infer the CPTs, depth first, visiting each model factor once.
    done: Set[int] = set()
    for factor in factors:
        _infer_cpt(factor, done)

    # The CPTs define the structure
    return [factor.cpt for factor in factors]
97
+
98
+
99
def _infer_cpt(model_factor: _ModelFactor, done: Set[int]) -> None:
    """
    Depth-first recursively infer the model factors as CPTs.
    This sets `model_factor.cpt` and `model_factor.parent_sums` for
    the given `model_factor` and its children.

    Args:
        model_factor: the model factor to infer a CPT for.
        done: the `idx` of the child random variable of every model factor
            already visited. It is updated by this function.
    """
    child_rv: RandomVariable = model_factor.child_rv
    visit_key = child_rv.idx

    # Only visit a model factor once.
    if visit_key in done:
        return
    done.add(visit_key)

    # The child factors are inferred first, as their parent sums are
    # evidence for this model factor.
    children = model_factor.child_factors
    for child_factor in children:
        _infer_cpt(child_factor, done)
    child_crosstabs: Sequence[CrossTable] = [
        child_factor.parent_sums
        for child_factor in children
    ]

    # Infer the parameters over the child random variable followed by its parents.
    rvs = [child_rv, *model_factor.parent_rvs]
    crosstab: CrossTable = _infer_parameters(rvs, model_factor.cross_tables, child_crosstabs)
    model_factor.cpt, model_factor.parent_sums = cpt_and_parent_sums_from_crosstab(crosstab)
128
+
129
+
130
def _infer_parameters(
        rvs: Sequence[RandomVariable],
        crosstabs: Sequence[CrossTable],
        child_crosstabs: Sequence[CrossTable],
) -> CrossTable:
    """
    Make best efforts to infer a probability distribution over the given random variables,
    with evidence from the given cross-tables.

    Cross-tables that cover all of `rvs` are projected on to `rvs` and summed together.
    If there is at least one of those, that sum is the result and all other
    cross-tables are ignored. Otherwise, the partially covering cross-tables
    are coalesced using `coalesce_cross_tables`.

    Args:
        rvs: the random variables to infer a distribution over.
        crosstabs: cross-tables providing evidence.
        child_crosstabs: further cross-tables providing evidence, treated
            in the same way as `crosstabs`.

    Returns:
        a CrossTable representing the inferred probability distribution
        (not normalised to sum to one).

    Assumes:
        `rvs` has no duplicates.
    """

    # Split the cross-tables into complete and partial coverage of `rvs`.
    # Completely covering cross-tables are projected on to `rvs` and accumulated into
    # `complete_crosstab` while others are appended to `partial_crosstabs`.
    complete_crosstab: Optional[CrossTable] = None
    partial_crosstabs: List[CrossTable] = []
    for available_crosstab in chain(crosstabs, child_crosstabs):
        available_rvs: Set[RandomVariable] = set(available_crosstab.rvs)
        if available_rvs.issuperset(rvs):
            to_add: CrossTable = available_crosstab.project(rvs)
            if complete_crosstab is None:
                # NOTE(review): this assumes `project` returns a new cross-table,
                # as later matches are added in to it - confirm.
                complete_crosstab = to_add
            else:
                # `to_add` is already over `rvs`, no further projection is needed.
                complete_crosstab.add_all(to_add.items())
        else:
            partial_crosstabs.append(available_crosstab)

    if complete_crosstab is not None:
        # A direct solution was found.
        # Ignore any partially covering cross-tables.
        return complete_crosstab

    # If there are no cross-tables available, no evidence is added to the result
    if len(partial_crosstabs) == 0:
        return CrossTable(rvs)

    # There were no cross-tables that completely cover the given random variables.
    # Make best attempts to coalesce the partially covering cross-tables.
    return coalesce_cross_tables(partial_crosstabs, rvs)
177
+
178
+
179
def _crostab_str(crosstab: CrossTable) -> str:
    """
    Render the names of the random variables of a cross-table as a
    parenthesised, comma separated list of their `repr` forms.
    """
    quoted_names = [repr(rv.name) for rv in crosstab.rvs]
    return f"({', '.join(quoted_names)})"
181
+
182
+
183
def _get_rv_order(cross_tables: Sequence[CrossTable]) -> Dict[RandomVariable, int]:
    """
    Determine an order over the random variables of the given cross-tables.

    Within a cross-table, each random variable is treated as a child of
    all the random variables that come after it in that cross-table.
    A random variable is only ranked once all of its parents are ranked.

    Returns:
        a map from rv to its rank in the order.
    """
    child_parent_map: MapList[RandomVariable, RandomVariable] = MapList()
    for crosstab in cross_tables:
        rvs = crosstab.rvs
        for position, child in enumerate(rvs):
            child_parent_map.extend(child, rvs[position + 1:])

    order: Dict[RandomVariable, int] = {}
    seen: Set[RandomVariable] = set()
    for child in child_parent_map.keys():
        _get_rv_order_r(child, child_parent_map, seen, order)
    return order
200
+
201
+
202
def _get_rv_order_r(
        child: RandomVariable,
        child_parent_map: MapList[RandomVariable, RandomVariable],
        seen: Set[RandomVariable],
        order: Dict[RandomVariable, int],
):
    """
    Recursive helper for `_get_rv_order`.

    Rank the parents of `child` (those not already seen) and then rank
    `child` itself, where a rank is the number of random variables ranked so far.

    Args:
        child: the random variable to rank.
        child_parent_map: maps a random variable to its parents.
        seen: the random variables already visited; updated by this function.
        order: maps a random variable to its rank; updated by this function.
    """
    if child in seen:
        # Already visited
        return
    seen.add(child)

    for parent in child_parent_map[child]:
        _get_rv_order_r(parent, child_parent_map, seen, order)
    order[child] = len(order)
214
+
215
+
216
@dataclass
class _ModelFactor:
    """
    A collection of model factors defines a PGM structure,
    each model factor representing a needed PGM factor.

    Associated with a model factor is a list of cross-tables
    that have been allocated to the model factor.
    """

    # The child random variable of the factor.
    child_rv: RandomVariable

    # The parent random variables of the factor.
    parent_rvs: set[RandomVariable] = field(default_factory=set)

    # The cross-tables allocated to this model factor.
    cross_tables: List[CrossTable] = field(default_factory=list)

    # The model factors that have `child_rv` as one of their parents.
    child_factors: List[_ModelFactor] = field(default_factory=list)

    # These are set once the parameter values are inferred...
    cpt: Optional[CrossTable] = None
    parent_sums: Optional[CrossTable] = None

    def dump(self, prefix=''):
        """
        For debugging.

        Prints the child, the parents and the random variables of each
        cross-table, with every line starting with `prefix`.
        """
        parents = self.parent_rvs if self.parent_rvs else '{}'
        print(f'{prefix}child:', self.child_rv)
        print(f'{prefix}parents:', parents)
        for cross_table in self.cross_tables:
            print(f'{prefix}cross-table:', *cross_table.rvs)