compiled-knowledge 4.1.0a2__cp312-cp312-win_amd64.whl → 4.2.0a1__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (36) hide show
  1. ck/circuit/_circuit_cy.c +1 -1
  2. ck/circuit/_circuit_cy.cp312-win_amd64.pyd +0 -0
  3. ck/circuit_compiler/cython_vm_compiler/_compiler.c +152 -152
  4. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win_amd64.pyd +0 -0
  5. ck/circuit_compiler/llvm_compiler.py +4 -4
  6. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
  7. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win_amd64.pyd +0 -0
  8. ck/circuit_compiler/support/input_vars.py +4 -4
  9. ck/dataset/cross_table.py +143 -79
  10. ck/dataset/dataset.py +95 -7
  11. ck/dataset/dataset_builder.py +11 -4
  12. ck/dataset/dataset_from_crosstable.py +21 -2
  13. ck/learning/coalesce_cross_tables.py +403 -0
  14. ck/learning/model_from_cross_tables.py +296 -0
  15. ck/learning/parameters.py +117 -0
  16. ck/learning/train_generative_bn.py +198 -0
  17. ck/pgm.py +10 -8
  18. ck/pgm_circuit/marginals_program.py +5 -0
  19. ck/pgm_circuit/wmc_program.py +5 -0
  20. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
  21. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win_amd64.pyd +0 -0
  22. ck/probability/divergence.py +226 -0
  23. ck/probability/probability_space.py +43 -19
  24. ck/utils/map_dict.py +89 -0
  25. ck_demos/dataset/demo_dataset_from_sampler.py +18 -0
  26. ck_demos/learning/__init__.py +0 -0
  27. ck_demos/learning/demo_bayesian_network_from_cross_tables.py +70 -0
  28. ck_demos/learning/demo_simple_learning.py +55 -0
  29. ck_demos/sampling/demo_wmc_direct_sampler.py +2 -2
  30. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/METADATA +2 -1
  31. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/RECORD +35 -26
  32. ck/learning/train_generative.py +0 -149
  33. /ck/{dataset/cross_table_probabilities.py → probability/cross_table_probability_space.py} +0 -0
  34. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/WHEEL +0 -0
  35. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/licenses/LICENSE.txt +0 -0
  36. {compiled_knowledge-4.1.0a2.dist-info → compiled_knowledge-4.2.0a1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,117 @@
1
+ """
2
+ General functions for setting the parameter values of a PGM.
3
+ """
4
+ from typing import List, Tuple, TypeAlias
5
+
6
+ import numpy as np
7
+
8
+ from ck.dataset.cross_table import CrossTable
9
+ from ck.pgm import PGM, CPTPotentialFunction, Instance, SparsePotentialFunction, DensePotentialFunction, Factor
10
+ from ck.utils.map_list import MapList
11
+ from ck.utils.np_extras import NDArrayFloat64
12
+
13
+
14
+ ParameterValues: TypeAlias = List[CrossTable]
15
+
16
+
17
def make_factors(pgm: PGM, parameter_values: List[CrossTable]) -> None:
    """
    Add a new factor to `pgm` for each cross-table in `parameter_values`, then
    set the potential function of each new factor from its cross-table.

    Each new factor is created over the random variables of its cross-table
    and its potential function is chosen by `set_potential_function`.

    Only the factors created by this call are updated. Any factors that
    `pgm` already holds are left untouched.

    This function modifies `pgm` in-place.

    Args:
        pgm: the PGM to add factors to.
        parameter_values: the parameter values, one cross-table per factor to create.
    """
    # Remember how many factors exist before adding any, so the cross-tables
    # are paired with the factors created here and not with pre-existing ones.
    # NOTE(review): assumes `pgm.new_factor` appends to the end of `pgm.factors` - confirm.
    number_of_existing_factors: int = pgm.number_of_factors

    for cross_table in parameter_values:
        pgm.new_factor(*cross_table.rvs)

    new_factors = list(pgm.factors)[number_of_existing_factors:]
    for factor, cross_table in zip(new_factors, parameter_values):
        set_potential_function(factor, cross_table)
21
+
22
+
23
def set_potential_functions(pgm: PGM, parameter_values: List[CrossTable]) -> None:
    """
    Give every factor of `pgm` a potential function built from the
    corresponding entry of `parameter_values`.

    Factors and cross-tables are paired positionally: the i-th factor of
    `pgm.factors` receives the i-th cross-table. Pairing stops as soon as
    either sequence is exhausted. Each pair is handed to
    `set_potential_function`, which picks the kind of potential function
    and copies the parameter values into it.

    This function modifies `pgm` in-place.

    Args:
        pgm (PGM): the PGM to have its potential functions set.
        parameter_values: the parameter values, co-indexed with `pgm.factors`.
    """
    factor_value_pairs = zip(pgm.factors, parameter_values)
    for pair in factor_value_pairs:
        set_potential_function(*pair)
37
+
38
+
39
def set_potential_function(factor: Factor, parameter_values: CrossTable) -> None:
    """
    Replace the potential function of `factor` with one chosen from the
    number of given parameter values, then copy those values into it.

    The choice is:
        * no parameter values: the factor is set to zero,
        * fewer than 100 parameter values, or more than 90% of the
          factor's states: a dense potential function,
        * otherwise: a sparse potential function.

    This function modifies `factor` in-place.

    Args:
        factor: The factor to update.
        parameter_values: the parameter values, mapping instances to weights.
    """
    parameter_count: int = len(parameter_values)

    # Nothing to store: the factor is identically zero.
    if parameter_count == 0:
        factor.set_zero()
        return

    # Sparse only pays off for a large table that is mostly empty.
    use_sparse: bool = 100 <= parameter_count <= factor.number_of_states * 0.9
    target = factor.set_sparse() if use_sparse else factor.set_dense()

    for instance, weight in parameter_values.items():
        target[instance] = weight
64
+
65
+
66
def set_zero(pgm: PGM) -> None:
    """
    Call `set_zero()` on every factor of the given PGM.

    This function modifies `pgm` in-place.

    Args:
        pgm: the PGM whose factors are to be set to zero.
    """
    for each_factor in pgm.factors:
        each_factor.set_zero()
72
+
73
+
74
def set_dense(pgm: PGM, parameter_values: List[CrossTable]) -> None:
    """
    Give every factor of `pgm` a dense potential function (via
    `factor.set_dense()`), filled from the corresponding cross-table.

    Factors and cross-tables are paired positionally; pairing stops as soon
    as either sequence is exhausted.

    This function modifies `pgm` in-place.

    Args:
        pgm: the PGM to have its potential functions set.
        parameter_values: the parameter values, co-indexed with `pgm.factors`.
    """
    for factor, cross_table in zip(pgm.factors, parameter_values):
        dense_function = factor.set_dense()
        for instance, weight in cross_table.items():
            dense_function[instance] = weight
83
+
84
+
85
def set_sparse(pgm: PGM, parameter_values: List[CrossTable]) -> None:
    """
    Give every factor of `pgm` a sparse potential function (via
    `factor.set_sparse()`), filled from the corresponding cross-table.

    Factors and cross-tables are paired positionally; pairing stops as soon
    as either sequence is exhausted.

    This function modifies `pgm` in-place.

    Args:
        pgm: the PGM to have its potential functions set.
        parameter_values: the parameter values, co-indexed with `pgm.factors`.
    """
    for factor, cross_table in zip(pgm.factors, parameter_values):
        sparse_function = factor.set_sparse()
        for instance, weight in cross_table.items():
            sparse_function[instance] = weight
94
+
95
+
96
def set_cpt(pgm: PGM, parameter_values: List[CrossTable], normalise_cpds: bool = True) -> None:
    """
    Set the potential function of each PGM factor to a CPTPotentialFunction,
    using the given parameter values.

    For each factor, the first random variable of its cross-table is treated
    as the child and the remaining random variables as the parents. The
    cross-table weights are grouped by parent instance, and each group becomes
    one conditional probability distribution (CPD) over the child states.
    Child states with no entry in the cross-table get the value 0.

    Factors and cross-tables are paired positionally; pairing stops as soon
    as either sequence is exhausted.

    This function modifies `pgm` in-place.

    Args:
        pgm: the PGM to have its potential functions set.
        parameter_values: the parameter values, co-indexed with `pgm.factors`.
        normalise_cpds: if true, each CPD is divided by its sum before being
            set. A CPD whose sum is zero cannot be normalised and is set
            unchanged (all weights as given) instead of being turned into NaNs.
    """
    for factor, cpt in zip(pgm.factors, parameter_values):
        pot_function: CPTPotentialFunction = factor.set_cpt()

        # Group cpt values by parent instance
        cpds: MapList[Instance, Tuple[int, float]] = MapList()
        for instance, weight in cpt.items():
            cpds.append(instance[1:], (instance[0], weight))

        # Set the CPDs
        cpd_size = len(cpt.rvs[0])  # size of the child random variable
        for parent_instance, cpd in cpds.items():
            cpd_array: NDArrayFloat64 = np.zeros(cpd_size, dtype=np.float64)
            for child_state_index, weight in cpd:
                cpd_array[child_state_index] = weight
            if normalise_cpds:
                cpd_total = cpd_array.sum()
                # Dividing by a zero sum would fill the CPD with NaN.
                if cpd_total != 0:
                    cpd_array /= cpd_total
            pot_function.set_cpd(parent_instance, cpd_array)
@@ -0,0 +1,198 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Mapping, Tuple
4
+
5
+ from ck.dataset import SoftDataset, HardDataset
6
+ from ck.dataset.cross_table import CrossTable, cross_table_from_dataset
7
+ from ck.learning.parameters import set_potential_functions, ParameterValues
8
+ from ck.pgm import PGM
9
+
10
+
11
def train_generative_bn(
        pgm: PGM,
        dataset: HardDataset | SoftDataset,
        *,
        dirichlet_prior: float | Mapping[int, float | CrossTable] = 0,
        check_bayesian_network: bool = True,
) -> None:
    """
    Maximum-likelihood, generative training for a Bayesian network.

    A CPT is computed for each factor of `pgm` by `get_cpts`, and the
    potential functions of `pgm` are then replaced, in-place, using
    `set_potential_functions`.

    Args:
        pgm: the probabilistic graphical model defining the model structure.
            Potential function values are ignored and need not be set.
        dataset: a dataset of random variable states.
        dirichlet_prior: provides a Dirichlet prior for each factor in `pgm`.
            This can be represented in multiple ways:
            (a) as a uniform prior that is the same for all factors, represented as a float value,
            (b) as a mapping from a factor index to a uniform prior, i.e., a float value,
            (c) as a mapping from a factor index to an arbitrary Dirichlet prior, i.e., a cross-table.
            If there is no entry in the mapping for a factor, then the value 0 will be used for that factor.
            If a cross-table is provided as a prior, then it must have the same random variables as
            the factor it pertains to.
            The default value for `dirichlet_prior` is 0.
            See `CrossTable` for more explanation.
        check_bayesian_network: if true and not `pgm.is_structure_bayesian` an exception will be raised.

    Raises:
        ValueError: if the given PGM does not have a Bayesian network structure,
            and `check_bayesian_network` is True.
    """
    # The structure is only inspected when the caller asked for the check.
    structure_accepted: bool = (not check_bayesian_network) or pgm.is_structure_bayesian
    if not structure_accepted:
        raise ValueError('the given PGM is not a Bayesian network')

    learned_cpts: List[CrossTable] = get_cpts(pgm=pgm, dataset=dataset, dirichlet_prior=dirichlet_prior)
    set_potential_functions(pgm, learned_cpts)
51
+
52
+
53
def get_cpts(
        pgm: PGM,
        dataset: HardDataset | SoftDataset,
        *,
        dirichlet_prior: float | Mapping[int, float | CrossTable] = 0,
) -> ParameterValues:
    """
    Compute a CPT for each factor of the given PGM from the given dataset.

    The cross-table of each factor is obtained from `get_factor_cross_tables`
    and then converted to a CPT by `cpt_from_crosstab`. The resulting
    parameter values can be used directly to update the parameters of the
    given PGM, so long as it has a Bayesian network structure.

    To update the given PGM from the resulting `cpts` use `set_potential_functions(pgm, cpts)`.

    Args:
        pgm: the probabilistic graphical model defining the model structure.
            Potential function values are ignored and need not be set.
        dataset: a dataset of random variable states.
        dirichlet_prior: provides a Dirichlet prior for each factor in `pgm`.
            This can be represented in multiple ways:
            (a) as a uniform prior that is the same for all factors, represented as a float value,
            (b) as a mapping from a factor index to a uniform prior, i.e., a float value,
            (c) as a mapping from a factor index to an arbitrary Dirichlet prior, i.e., a cross-table.
            If there is no entry in the mapping for a factor, then the value 0 will be used for that factor.
            If a cross-table is provided as a prior, then it must have the same random variables as
            the factor it pertains to.
            The default value for `dirichlet_prior` is 0.
            See `CrossTable` for more explanation.

    Returns:
        ParameterValues object, a CPT for each factor in the given PGM, as a list of cross-tables,
        co-indexed with the PGM factors.
    """
    factor_cross_tables: List[CrossTable] = get_factor_cross_tables(
        pgm=pgm,
        dataset=dataset,
        dirichlet_prior=dirichlet_prior,
    )
    return [cpt_from_crosstab(cross_table) for cross_table in factor_cross_tables]
92
+
93
+
94
def get_factor_cross_tables(
        pgm: PGM,
        dataset: HardDataset | SoftDataset,
        *,
        dirichlet_prior: float | Mapping[int, float | CrossTable] = 0,
) -> ParameterValues:
    """
    Build one cross-table per factor of the given PGM, using the data from
    the given dataset.

    Each cross-table is produced by `cross_table_from_dataset` over the
    random variables of the factor, with the Dirichlet prior that applies
    to that factor.

    Args:
        pgm: the probabilistic graphical model defining the model structure.
            Potential function values are ignored and need not be set.
        dataset: a dataset of random variable states.
        dirichlet_prior: provides a Dirichlet prior for each factor in `pgm`.
            This can be represented in multiple ways:
            (a) as a uniform prior that is the same for all factors, represented as a float value,
            (b) as a mapping from a factor index to a uniform prior, i.e., a float value,
            (c) as a mapping from a factor index to an arbitrary Dirichlet prior, i.e., a cross-table.
            If there is no entry in the mapping for a factor, then the value 0 will be used for that factor.
            If a cross-table is provided as a prior, then it must have the same random variables as
            the factor it pertains to.
            The default value for `dirichlet_prior` is 0.
            See `CrossTable` for more explanation.

    Returns:
        ParameterValues object, a cross-table for each factor in the given PGM, as
        per `cross_table_from_dataset`, co-indexed with the PGM factors.

    Assumes:
        every random variable of the PGM is in the dataset.
    """
    # Normalise the two accepted forms of `dirichlet_prior` into a
    # per-factor lookup plus the value to use when a factor has no entry.
    per_factor_priors: Mapping[int, float | CrossTable]
    fallback_prior: float
    if isinstance(dirichlet_prior, (float, int)):
        per_factor_priors, fallback_prior = {}, dirichlet_prior
    else:
        per_factor_priors, fallback_prior = dirichlet_prior, 0

    result: List[CrossTable] = []
    for factor in pgm.factors:
        result.append(
            cross_table_from_dataset(
                dataset,
                factor.rvs,
                dirichlet_prior=per_factor_priors.get(factor.idx, fallback_prior),
            )
        )
    return result
144
+
145
+
146
def cpt_from_crosstab(crosstab: CrossTable) -> CrossTable:
    """
    Convert the given cross-table to a conditional probability table (CPT),
    treating the first random variable of the cross-table as the child
    and the remaining random variables as the parents.

    This is `cpt_and_parent_sums_from_crosstab` with the parent sums discarded.

    Args:
        crosstab: a CrossTable representing the weight of unique instances.

    Returns:
        A cross-table that is a conditional probability table.

    Assumes:
        the first random variable in `crosstab.rvs` is the child random variable.
    """
    cpt, _parent_sums = cpt_and_parent_sums_from_crosstab(crosstab)
    return cpt
162
+
163
+
164
def cpt_and_parent_sums_from_crosstab(crosstab: CrossTable) -> Tuple[CrossTable, CrossTable]:
    """
    Convert the given cross-table to a conditional probability table (CPT),
    treating the first random variable of the cross-table as the child
    and the remaining random variables as the parents.

    Every weight of `crosstab` is divided by the total weight recorded for
    its parent instance (the instance with the child state dropped).

    Args:
        crosstab: a CrossTable representing the weight of unique instances.

    Returns:
        A pair `(cpt, parent_sums)` where:
        `cpt` is a cross-table that is a conditional probability table, and
        `parent_sums` is a cross-table of the parent sums that were divided out of `crosstab`.

    Assumes:
        the first random variable in `crosstab.rvs` is the child random variable.
    """

    def _weights_by_parent_instance():
        # Drop the child state so each weight is keyed by its parent instance.
        for instance, weight in crosstab.items():
            yield instance[1:], weight

    # NOTE(review): relies on CrossTable accumulating the weights of repeated
    # instances supplied through `update` - confirm against CrossTable.
    parent_sums: CrossTable = CrossTable(
        rvs=crosstab.rvs[1:],
        update=_weights_by_parent_instance(),
    )

    def _normalised_weights():
        # Each weight divided by the total for its parent instance.
        for instance, weight in crosstab.items():
            yield instance, weight / parent_sums[instance[1:]]

    cpt = CrossTable(
        rvs=crosstab.rvs,
        update=_normalised_weights(),
    )

    return cpt, parent_sums
ck/pgm.py CHANGED
@@ -596,9 +596,11 @@ class PGM:
596
596
 
597
597
  # Factors form a DAG
598
598
  states: NDArrayUInt8 = np.zeros(self.number_of_factors, dtype=np.uint8)
599
- for factor in self._factors:
600
- if self._has_cycle(factor, child_to_factor, states):
601
- return False
599
+ if any(
600
+ self._has_cycle(factor, child_to_factor, states)
601
+ for factor in self._factors
602
+ ):
603
+ return False
602
604
 
603
605
  # All tests passed
604
606
  return True
@@ -778,7 +780,7 @@ class PGM:
778
780
  next_prefix: str = prefix + indent
779
781
  next_next_prefix: str = next_prefix + indent
780
782
 
781
- print(f'{prefix}PGM id={id(self)} name={self.name!r}')
783
+ print(f'{prefix}PGM id={id(self)}')
782
784
  self.dump_synopsis(prefix=next_prefix, precision=precision, max_state_digits=max_state_digits)
783
785
 
784
786
  print(f'{prefix}random variables ({self.number_of_rvs})')
@@ -792,16 +794,16 @@ class PGM:
792
794
 
793
795
  print(f'{prefix}factors ({self.number_of_factors})')
794
796
  for factor in self.factors:
795
- rv_idxs = [rv.idx for rv in factor.rvs]
797
+ factor_rvs = ', '.join(repr(rv.name) for rv in factor.rvs)
796
798
  if factor.is_zero:
797
- function_ref = '<zero>'
799
+ function_ref = '<ZeroPotentialFunction>'
798
800
  else:
799
801
  function = factor.function
800
802
  function_ref = f'{id(function)}: {function.__class__.__name__}'
801
803
 
802
- print(f'{next_prefix}{factor.idx:>3} rvs={rv_idxs} function={function_ref}')
804
+ print(f'{next_prefix}{factor.idx:>3} rvs=({factor_rvs}) function={function_ref}')
803
805
 
804
- print(f'{prefix}functions ({self.number_of_functions})')
806
+ print(f'{prefix}functions, excluding ZeroPotentialFunction ({sum(1 for _ in self.non_zero_functions)})')
805
807
  for function in sorted(self.non_zero_functions, key=lambda f: id(f)):
806
808
  print(f'{next_prefix}{id(function):>13}: {function.__class__.__name__}')
807
809
  function.dump(prefix=next_next_prefix, show_function_values=show_function_values, show_id_class=False)
@@ -308,6 +308,11 @@ class MarginalsProgram(ProgramWithSlotmap, ProbabilitySpace):
308
308
  The sampler will yield state lists, where the state
309
309
  values are co-indexed with rvs, or self.rvs if rvs is None.
310
310
 
311
+ For more information about this sampler, see the publication:
312
+ Suresh, S., Drake, B. (2025). Sampling of Large Probabilistic Graphical Models
313
+ Using Arithmetic Circuits. AI 2024: Advances in Artificial Intelligence. AI 2024.
314
+ Lecture Notes in Computer Science, vol 15443. https://doi.org/10.1007/978-981-96-0351-0_13.
315
+
311
316
  Args:
312
317
  rvs: the list of random variables to sample; the
313
318
  yielded state vectors are co-indexed with rvs; if None,
@@ -132,6 +132,11 @@ class WMCProgram(ProgramWithSlotmap, ProbabilitySpace):
132
132
  * calls rand.random() once and rand.randrange(...) n times,
133
133
  * calls self.program().compute_result() at least once and <= 1 + m.
134
134
 
135
+ For more information about this sampler, see the publication:
136
+ Suresh, S., Drake, B. (2025). Sampling of Large Probabilistic Graphical Models
137
+ Using Arithmetic Circuits. AI 2024: Advances in Artificial Intelligence. AI 2024.
138
+ Lecture Notes in Computer Science, vol 15443. https://doi.org/10.1007/978-981-96-0351-0_13.
139
+
135
140
  Args:
136
141
  rvs: the list of random variables to sample; the
137
142
  yielded state vectors are co-indexed with rvs; if None,
@@ -13,7 +13,7 @@
13
13
  "/O2"
14
14
  ],
15
15
  "include_dirs": [
16
- "C:\\Users\\runneradmin\\AppData\\Local\\Temp\\build-env-zvpv36cx\\Lib\\site-packages\\numpy\\_core\\include"
16
+ "C:\\Users\\runneradmin\\AppData\\Local\\Temp\\build-env-75r5_lk9\\Lib\\site-packages\\numpy\\_core\\include"
17
17
  ],
18
18
  "name": "ck.pgm_compiler.support.circuit_table._circuit_table_cy",
19
19
  "sources": [
@@ -0,0 +1,226 @@
1
+ """
2
+ This module implements several divergences which measure the difference
3
+ between two distributions.
4
+ """
5
+ import math
6
+ from typing import Sequence
7
+
8
+ import numpy as np
9
+
10
+ from ck.pgm import RandomVariable, rv_instances_as_indicators, PGM
11
+ from ck.probability.probability_space import ProbabilitySpace
12
+
13
+ _NAN: float = np.nan # Not-a-number (i.e., the result of an invalid calculation).
14
+
15
+
16
def kl(p: ProbabilitySpace, q: ProbabilitySpace) -> float:
    """
    Compute the Kullback-Leibler (KL) divergence between p & q,
    where p is the true distribution.

    This implementation uses logarithms, base 2.

    Args:
        p: a probability space to compare to.
        q: the other probability space.

    Returns:
        the KL divergence of p & q, where p is the true distribution.
        If any state has a probability that is zero or negative, in
        either `p` or `q`, then NaN is returned.

    Raises:
        ValueError: if `p` and `q` do not have compatible random variables. Specifically,
            compatible random variables satisfy:
            * `len(p.rvs) == len(q.rvs)`
            * `len(q.rvs[i]) == len(p.rvs[i])` for all `i`
            * `q.rvs[i].idx == p.rvs[i].idx` for all `i`.

    Warning:
        this method will enumerate the whole probability space.
    """
    if not _compatible_rvs(p.rvs, q.rvs):
        raise ValueError('incompatible random variables')

    # The random variables are compatible, so indicators built from the
    # random variables of `p` are used to query both spaces.
    divergence = 0.0
    for indicators in rv_instances_as_indicators(*p.rvs):
        pr_p = p.probability(*indicators)
        pr_q = q.probability(*indicators)
        if pr_p <= 0 or pr_q <= 0:
            # The log ratio is not defined for this state.
            return _NAN
        divergence += pr_p * math.log2(pr_p / pr_q)
    return divergence
51
+
52
+
53
def pseudo_kl(p: ProbabilitySpace, q: ProbabilitySpace) -> float:
    """
    A kind of KL divergence, factored by the structure of `p`.
    This is an experimental measure.

    For every factor of the PGM of `p`, and every possible instance `x` of
    the random variables of that factor, the term
    `p(x) * log2(p(x) / q(x))` is added to the result.

    This implementation uses logarithms, base 2.

    Args:
        p: a probability space to compare to.
        q: the other probability space.

    Returns:
        the sum of the terms described above, over all factors.
        NaN is returned if `p` has no random variables, or if any
        probability needed is zero or negative in either `p` or `q`.

    Raises:
        ValueError: if `p` and `q` do not have compatible random variables. Specifically,
            compatible random variables satisfy:
            * `len(p.rvs) == len(q.rvs)`
            * `len(q.rvs[i]) == len(p.rvs[i])` for all `i`
            * `q.rvs[i].idx == p.rvs[i].idx` for all `i`.
        ValueError: if not all random variables of `p` are from a single PGM, which must
            have a Bayesian network structure.
    """
    rvs_of_p: Sequence[RandomVariable] = p.rvs
    rvs_of_q: Sequence[RandomVariable] = q.rvs

    if not _compatible_rvs(rvs_of_p, rvs_of_q):
        raise ValueError('incompatible random variables')

    if len(rvs_of_p) == 0:
        return _NAN

    # The factorisation is taken from the PGM that owns the random variables of `p`.
    pgm: PGM = rvs_of_p[0].pgm
    for rv in rvs_of_p:
        if rv.pgm is not pgm:
            raise ValueError('p random variables are not from a single PGM.')
    if not pgm.is_structure_bayesian:
        raise ValueError('p does not have Bayesian network structure.')

    # The random variables are compatible, so indicators built from the
    # random variables of `p` are used to query both spaces.
    divergence: float = 0
    for factor in pgm.factors:
        # Every possible state of the factor's random variables.
        for indicators in rv_instances_as_indicators(*factor.rvs):
            pr_p = p.probability(*indicators)
            pr_q = q.probability(*indicators)
            if pr_p <= 0 or pr_q <= 0:
                # The log ratio is not defined for this state.
                return _NAN
            divergence += pr_p * math.log2(pr_p / pr_q)
    return divergence
103
+
104
+
105
def hi(p: ProbabilitySpace, q: ProbabilitySpace) -> float:
    """
    Compute the histogram intersection between two probability spaces.

    The histogram intersection between two probability spaces P and Q,
    with state space X, is defined as:
        HI(P, Q) = sum(min(P(x), Q(x)) for x in X)

    Args:
        p: a probability space to compare to.
        q: the other probability space.

    Returns:
        the histogram intersection between the two probability spaces.

    Raises:
        ValueError: if `p` and `q` do not have compatible random variables. Specifically,
            compatible random variables satisfy:
            * `len(p.rvs) == len(q.rvs)`
            * `len(q.rvs[i]) == len(p.rvs[i])` for all `i`
            * `q.rvs[i].idx == p.rvs[i].idx` for all `i`.

    Warning:
        this method will enumerate the whole probability space.
    """
    rvs_of_p: Sequence[RandomVariable] = p.rvs
    rvs_of_q: Sequence[RandomVariable] = q.rvs

    if not _compatible_rvs(rvs_of_p, rvs_of_q):
        raise ValueError('incompatible random variables')

    # The random variables are compatible, so indicators built from the
    # random variables of `p` are used to query both spaces.
    intersection = 0
    for indicators in rv_instances_as_indicators(*rvs_of_p):
        intersection += min(p.probability(*indicators), q.probability(*indicators))
    return intersection
144
+
145
+
146
def fhi(p: ProbabilitySpace, q: ProbabilitySpace) -> float:
    """
    Compute the factored histogram intersection between two probability spaces.

    The factored histogram intersection between two probability spaces P and Q,
    with factorisation F, is defined as:
        FHI(P, Q) = 1/n * sum(P(Y=y) * CHI(P, Q, X | Y=y))
    where the sum is over every CPT row, i.e., every factor in F with child X
    and parents Y, and every instance y of Y, and:
        CHI(P, Q, X | Y=y) = HI(P(X | Y=y), Q(X | Y=y))
        HI(P, Q) = sum(min(P(X=x), Q(X=x)) for x in X)

    The value of _n_ is the sum of P(Y=y) over all CPT rows. However,
    this always equals the number of CPTs, i.e., the number of random
    variables.

    The factorisation F is taken from the PGM of `p`. The first random
    variable of each factor is taken as the child and the remaining
    random variables as the parents.

    For more information about factored histogram intersection, see the publication:
    Suresh, S., Drake, B. (2025). Sampling of Large Probabilistic Graphical Models
    Using Arithmetic Circuits. AI 2024: Advances in Artificial Intelligence. AI 2024.
    Lecture Notes in Computer Science, vol 15443. https://doi.org/10.1007/978-981-96-0351-0_13.

    Args:
        p: a probability space to compare to.
        q: the other probability space.

    Returns:
        the factored histogram intersection between the two probability spaces,
        or 0 if `p` has no random variables.

    Raises:
        ValueError: if `p` and `q` do not have compatible random variables. Specifically,
            compatible random variables satisfy:
            * `len(p.rvs) == len(q.rvs)`
            * `len(q.rvs[i]) == len(p.rvs[i])` for all `i`
            * `q.rvs[i].idx == p.rvs[i].idx` for all `i`.
        ValueError: if not all random variables of `p` are from a single PGM, which must
            have a Bayesian network structure.
    """
    rvs_of_p: Sequence[RandomVariable] = p.rvs
    rvs_of_q: Sequence[RandomVariable] = q.rvs

    if not _compatible_rvs(rvs_of_p, rvs_of_q):
        raise ValueError('incompatible random variables')

    if len(rvs_of_p) == 0:
        return 0

    # The factorisation is taken from the PGM that owns the random variables of `p`.
    pgm: PGM = rvs_of_p[0].pgm
    for rv in rvs_of_p:
        if rv.pgm is not pgm:
            raise ValueError('p random variables are not from a single PGM.')
    if not pgm.is_structure_bayesian:
        raise ValueError('p does not have Bayesian network structure.')

    # The random variables are compatible, so indicators built from the
    # random variables of `p` are used to query both spaces.

    weighted_sum: float = 0
    for factor in pgm.factors:
        factor_rvs = factor.rvs
        child: RandomVariable = factor_rvs[0]
        parents: Sequence[RandomVariable] = factor_rvs[1:]

        # One iteration per row of the CPT, i.e., per parent instance.
        for parent_indicators in rv_instances_as_indicators(*parents):
            child_given_parents_p = p.marginal_distribution(child, condition=parent_indicators)
            child_given_parents_q = q.marginal_distribution(child, condition=parent_indicators)
            overlap = np.minimum(child_given_parents_p, child_given_parents_q)
            row_intersection = overlap.sum().item()
            # Weight the row by how probable its parent instance is under `p`.
            row_probability = p.probability(*parent_indicators)
            weighted_sum += row_probability * row_intersection

    return weighted_sum / len(rvs_of_p)
216
+
217
+
218
+ def _compatible_rvs(rvs1: Sequence[RandomVariable], rvs2: Sequence[RandomVariable]) -> bool:
219
+ """
220
+ The rvs are compatible if they have the same number of random variables
221
+ and the corresponding indicators are equal.
222
+ """
223
+ return (
224
+ len(rvs1) == len(rvs2)
225
+ and all(len(rv1) == len(rv2) and rv1.idx == rv2.idx for rv1, rv2 in zip(rvs1, rvs2))
226
+ )