compiled-knowledge 4.0.0a24__cp312-cp312-win32.whl → 4.1.0__cp312-cp312-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compiled-knowledge might be problematic. Click here for more details.

Files changed (58)
  1. ck/circuit/_circuit_cy.c +1 -1
  2. ck/circuit/_circuit_cy.cp312-win32.pyd +0 -0
  3. ck/circuit/tmp_const.py +5 -4
  4. ck/circuit_compiler/cython_vm_compiler/_compiler.c +152 -152
  5. ck/circuit_compiler/cython_vm_compiler/_compiler.cp312-win32.pyd +0 -0
  6. ck/circuit_compiler/interpret_compiler.py +2 -2
  7. ck/circuit_compiler/llvm_compiler.py +4 -4
  8. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.c +1 -1
  9. ck/circuit_compiler/support/circuit_analyser/_circuit_analyser_cy.cp312-win32.pyd +0 -0
  10. ck/circuit_compiler/support/input_vars.py +4 -4
  11. ck/circuit_compiler/support/llvm_ir_function.py +4 -4
  12. ck/dataset/__init__.py +1 -0
  13. ck/dataset/cross_table.py +334 -0
  14. ck/dataset/dataset.py +682 -0
  15. ck/dataset/dataset_builder.py +519 -0
  16. ck/dataset/dataset_compute.py +140 -0
  17. ck/dataset/dataset_from_crosstable.py +64 -0
  18. ck/dataset/dataset_from_csv.py +151 -0
  19. ck/dataset/sampled_dataset.py +96 -0
  20. ck/example/diamond_square.py +3 -1
  21. ck/example/triangle_square.py +3 -1
  22. ck/example/truss.py +3 -1
  23. ck/in_out/parse_net.py +21 -19
  24. ck/in_out/parser_utils.py +7 -3
  25. ck/learning/__init__.py +0 -0
  26. ck/learning/coalesce_cross_tables.py +403 -0
  27. ck/learning/model_from_cross_tables.py +296 -0
  28. ck/learning/parameters.py +117 -0
  29. ck/learning/train_generative_bn.py +198 -0
  30. ck/pgm.py +105 -92
  31. ck/pgm_circuit/marginals_program.py +5 -0
  32. ck/pgm_circuit/mpe_program.py +3 -4
  33. ck/pgm_circuit/pgm_circuit.py +27 -18
  34. ck/pgm_circuit/program_with_slotmap.py +27 -46
  35. ck/pgm_circuit/support/compile_circuit.py +2 -4
  36. ck/pgm_circuit/wmc_program.py +5 -0
  37. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.c +1 -1
  38. ck/pgm_compiler/support/circuit_table/_circuit_table_cy.cp312-win32.pyd +0 -0
  39. ck/probability/cross_table_probability_space.py +53 -0
  40. ck/probability/divergence.py +226 -0
  41. ck/probability/empirical_probability_space.py +1 -0
  42. ck/probability/probability_space.py +53 -30
  43. ck/program/raw_program.py +23 -16
  44. ck/sampling/sampler_support.py +5 -6
  45. ck/utils/iter_extras.py +3 -2
  46. ck/utils/local_config.py +16 -8
  47. ck_demos/dataset/__init__.py +0 -0
  48. ck_demos/dataset/demo_dataset_builder.py +37 -0
  49. ck_demos/dataset/demo_dataset_from_sampler.py +18 -0
  50. ck_demos/learning/__init__.py +0 -0
  51. ck_demos/learning/demo_bayesian_network_from_cross_tables.py +70 -0
  52. ck_demos/learning/demo_simple_learning.py +55 -0
  53. ck_demos/sampling/demo_wmc_direct_sampler.py +2 -2
  54. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/METADATA +2 -1
  55. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/RECORD +58 -37
  56. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/WHEEL +0 -0
  57. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/licenses/LICENSE.txt +0 -0
  58. {compiled_knowledge-4.0.0a24.dist-info → compiled_knowledge-4.1.0.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import ctypes as ct
4
4
  from dataclasses import dataclass
5
- from typing import Sequence, Optional, Dict, List, Tuple, Callable
5
+ from typing import Sequence, Optional, Dict, List, Tuple, Callable, assert_never
6
6
 
7
7
  import numpy as np
8
8
 
@@ -174,7 +174,7 @@ def _make_instructions(
174
174
  elif op_node.symbol == ADD:
175
175
  operation = sum
176
176
  else:
177
- assert False, 'symbol not understood'
177
+ assert_never('not reached')
178
178
 
179
179
  instructions.append(_Instruction(operation, args, dest))
180
180
 
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from enum import Enum
4
+ from enum import Enum, auto
5
5
  from typing import Sequence, Optional, Tuple, Dict, Protocol, assert_never
6
6
 
7
7
  import llvmlite.binding as llvm
@@ -15,9 +15,9 @@ from ..program.raw_program import RawProgramFunction
15
15
 
16
16
 
17
17
  class Flavour(Enum):
18
- STACK = 0 # No working temporary memory requested - all on stack.
19
- TMPS = 1 # Working temporary memory used for op node calculations.
20
- FUNCS = 2 # Working temporary memory used for op node calculations, one sub-function per op-node.
18
+ STACK = auto() # No working temporary memory requested - all on stack.
19
+ TMPS = auto() # Working temporary memory used for op node calculations.
20
+ FUNCS = auto() # Working temporary memory used for op node calculations, one sub-function per op-node.
21
21
 
22
22
 
23
23
  DEFAULT_TYPE_INFO: TypeInfo = DataType.FLOAT_64.value
@@ -13,7 +13,7 @@
13
13
  "/O2"
14
14
  ],
15
15
  "include_dirs": [
16
- "C:\\Users\\runneradmin\\AppData\\Local\\Temp\\build-env-b7kylfcw\\Lib\\site-packages\\numpy\\_core\\include"
16
+ "C:\\Users\\runneradmin\\AppData\\Local\\Temp\\build-env-gp1o6j1g\\Lib\\site-packages\\numpy\\_core\\include"
17
17
  ],
18
18
  "name": "ck.circuit_compiler.support.circuit_analyser._circuit_analyser_cy",
19
19
  "sources": [
@@ -3,7 +3,7 @@ This module supports circuit compilers and interpreters by inferring and checkin
3
3
  that are explicitly or implicitly referred to by a client.
4
4
  """
5
5
 
6
- from enum import Enum
6
+ from enum import Enum, auto
7
7
  from itertools import chain
8
8
  from typing import Sequence, Optional, Set, Iterable, List
9
9
 
@@ -15,9 +15,9 @@ class InferVars(Enum):
15
15
  An enum specifying how to automatically infer a program's input variables.
16
16
  """
17
17
 
18
- ALL = 'all' # all circuit vars are input vars
19
- REF = 'ref' # only referenced vars are input vars
20
- LOW = 'low' # input vars are circuit vars[0 : max_referenced + 1]
18
+ ALL = auto() # all circuit vars are input vars
19
+ REF = auto() # only referenced vars are input vars
20
+ LOW = auto() # input vars are circuit vars[0 : max_referenced + 1]
21
21
 
22
22
 
23
23
  # Type for specifying input circuit vars
@@ -213,10 +213,10 @@ def compile_llvm_program(
213
213
  Compile the given LLVM program.
214
214
 
215
215
  Returns:
216
- (engine, function) where
217
- engine: is an LLVM execution engine, which must remain
218
- in memory for the returned function to be valid.
219
- function: is the raw Python callable for the compiled function.
216
+ `engine` an LLVM execution engine, which must remain
217
+ in memory for the returned function to be valid,
218
+
219
+ `function` the raw Python callable for the compiled function.
220
220
  """
221
221
  _init_llvm()
222
222
 
ck/dataset/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .dataset import Dataset, HardDataset, SoftDataset
@@ -0,0 +1,334 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Tuple, Sequence, Iterator, Iterable, Optional, MutableMapping, Dict, assert_never
4
+
5
+ from ck.dataset import SoftDataset, HardDataset
6
+ from ck.pgm import RandomVariable, rv_instances, Instance
7
+
8
+
9
class CrossTable(MutableMapping[Instance, float]):
    """
    A cross-table records the total weight for possible combinations
    of states for some random variables, i.e., the weight of unique instances.

    A cross-table is a dictionary mapping from state indices of the cross-table
    random variables (an instance, as a tuple) to a weight (as a float).

    Given a cross-table `ct`, then for each `instance in ct.keys()`:
        `len(instance) == len(ct.rvs)`,
        and `0 <= instance[i] < len(ct.rvs[i])` for each index `i`,
        and `ct[instance] != 0` (normally the weight is positive).

    Zero weighted instances are not explicitly represented in a cross-table.
    Given a cross-table `ct` then the following is always true.
    `x in ct` (equivalently `x in ct.keys()`) is true if and only if `ct[x] != 0`.
    Looking up an unrepresented instance, `ct[x]`, returns 0 rather than raising KeyError.
    """

    def __init__(
        self,
        rvs: Sequence[RandomVariable],
        dirichlet_prior: float | CrossTable = 0,
        update: Iterable[Tuple[Instance, float]] = (),
    ):
        """
        Construct a cross-table for the given random variables.

        The cross-table can be initialised with a Dirichlet prior, x. Practically
        this amounts to adding a weight of x to each possible combination of
        random variable states. That is, a Dirichlet prior of x results in x pseudocounts
        for each possible combination of states.

        Args:
            rvs: the random variables that this cross-table records weights for. Instances
                in this cross-table are tuples of state indexes, co-indexed with `rvs`.
            dirichlet_prior: provides a prior for `rvs`. This can be represented either:
                (a) as a uniform prior, represented as a float value,
                (b) as an arbitrary prior, represented as a cross-table.
                If a cross-table is provided as a prior, then it must have the same random variables
                as `rvs` (possibly in a different order).
                The default value for `dirichlet_prior` is 0.
            update: an optional iterable of (instance, weight) tuples to add to
                the cross-table at construction time.

        Raises:
            ValueError: if a cross-table prior does not have the same random variables as `rvs`.
            TypeError: if `dirichlet_prior` is neither a number nor a CrossTable.
        """
        self._rvs: Tuple[RandomVariable, ...] = tuple(rvs)
        self._dict: Dict[Instance, float]

        if isinstance(dirichlet_prior, CrossTable):
            prior_rvs = dirichlet_prior.rvs
            # Reject priors over a different set of random variables. Without this check a prior
            # with extra random variables would silently collapse several prior instances onto
            # one key, keeping only the last weight seen.
            if len(prior_rvs) != len(self._rvs) or any(rv not in prior_rvs for rv in self._rvs):
                raise ValueError('a cross-table prior must have the same random variables as the cross-table')

            # rv_map[i] is where self.rvs[i] appears in the prior's random variables.
            # It is used to map instances of the prior to instances of self.
            rv_map: List[int] = [prior_rvs.index(rv) for rv in self._rvs]

            # Copy items from the prior to self, re-ordering instance states as needed.
            self._dict = {
                tuple(prior_instance[select] for select in rv_map): weight
                for prior_instance, weight in dirichlet_prior.items()
            }

        elif isinstance(dirichlet_prior, (float, int)):
            if dirichlet_prior != 0:
                # Initialise self with every possible combination of rvs states.
                instance: Instance
                self._dict = {
                    instance: dirichlet_prior
                    for instance in rv_instances(*self._rvs)
                }
            else:
                self._dict = {}
        else:
            raise TypeError('dirichlet_prior must be a number or a CrossTable')

        # Apply any provided updates
        self.add_all(update)

    def __eq__(self, other) -> bool:
        """
        Two cross-tables are equal if they have the same sequence of random variables
        and their instance weights are equal.
        """
        return isinstance(other, CrossTable) and self._rvs == other._rvs and self._dict == other._dict

    def __setitem__(self, key: Instance, value) -> None:
        # Setting a weight of zero removes the instance, preserving the invariant
        # that zero weighted instances are not explicitly represented.
        if value == 0:
            self._dict.pop(key, None)
        else:
            self._dict[key] = value

    def __delitem__(self, key: Instance) -> None:
        del self._dict[key]

    def __getitem__(self, key: Instance) -> float:
        """
        Returns:
            the weight of the given instance.
            This will always return a value, even if the key is not in the underlying dictionary.
        """
        return self._dict.get(key, 0)

    def __contains__(self, key: object) -> bool:
        """
        Returns:
            True if and only if the given instance has a non-zero weight.
        """
        # Must be explicit: the inherited Mapping.__contains__ decides membership by whether
        # __getitem__ raises KeyError, which never happens here, so it would always return True.
        return key in self._dict

    def __len__(self) -> int:
        """
        Returns:
            the number of instances in the cross-table with non-zero weight.
        """
        return len(self._dict)

    def __iter__(self) -> Iterator[Instance]:
        """
        Returns:
            an iterator over the cross-table instances with non-zero weight.
        """
        return iter(self._dict)

    def items(self) -> Iterable[Tuple[Instance, float]]:
        """
        Returns:
            an iterable over (instance, weight) pairs.
        """
        return self._dict.items()

    @property
    def rvs(self) -> Sequence[RandomVariable]:
        """
        The random variables that this cross-table refers to.
        """
        return self._rvs

    def add(self, instance: Instance, weight: float) -> None:
        """
        Add the given weighted instance to the cross-table.

        Args:
            instance: a tuple of state indices, co-indexed with `self.rvs`.
            weight: the weight (generalised count) to add to the cross-table. Normally the
                weight will be > 0.
        """
        self[instance] = self._dict.get(instance, 0) + weight

    def add_all(self, to_add: Iterable[Tuple[Instance, float]]) -> None:
        """
        Add the given weighted instances to the cross-table.

        Args:
            to_add: an iterable of (instance, weight) tuples to add to the cross-table.
        """
        for instance, weight in to_add:
            self.add(instance, weight)

    def mul(self, multiplier: float) -> None:
        """
        Multiply all weights by the given multiplier.

        Any weight that becomes zero (e.g., through floating-point underflow) is removed,
        so that zero weighted instances remain unrepresented.
        """
        if multiplier == 0:
            self._dict.clear()
        elif multiplier == 1:
            pass
        else:
            # Iterate over a snapshot, as zero products are deleted from the dictionary.
            for instance, weight in list(self._dict.items()):
                product = weight * multiplier
                if product == 0:
                    del self._dict[instance]
                else:
                    self._dict[instance] = product

    def total_weight(self) -> float:
        """
        Calculate the total weight of this cross-table.
        """
        return sum(self._dict.values())

    def project(self, rvs: Sequence[RandomVariable]) -> CrossTable:
        """
        Project this cross-table onto the given set of random variables.

        If successful, this method will always return a new CrossTable object.
        Weights of instances that agree on `rvs` are summed.

        Returns:
            a CrossTable with the given sequence of random variables.

        Assumes:
            `rvs` is a subset of the cross-table's random variables.
        """
        # Mapping rv_map[i] is the index into `self.rvs` for `rvs[i]`.
        rv_map: List[int] = [self.rvs.index(rv) for rv in rvs]

        return CrossTable(
            rvs=rvs,
            update=(
                (tuple(instance[i] for i in rv_map), weight)
                for instance, weight in self._dict.items()
            ),
        )

    def dump(self, *, show_rvs: bool = True, show_weights: bool = True, as_states: bool = False) -> None:
        """
        Dump the cross-table in a human-readable format.
        If as_states is true, then instance states are dumped instead of just state indexes.

        Args:
            show_rvs: If `True`, the random variables are dumped.
            show_weights: If `True`, the instance weights are dumped.
            as_states: If `True`, the states are dumped instead of just state indexes.
        """
        if show_rvs:
            rvs = ', '.join(str(rv) for rv in self.rvs)
            print(f'rvs: [{rvs}]')
        print(f'instances ({len(self)}, with total weight {self.total_weight()}):')
        for instance, weight in self.items():
            if as_states:
                instance_str = ', '.join(repr(rv.states[idx]) for idx, rv in zip(instance, self.rvs))
            else:
                instance_str = ', '.join(str(idx) for idx in instance)
            if show_weights:
                print(f'({instance_str}) * {weight}')
            else:
                print(f'({instance_str})')
222
+
223
+
224
+ def cross_table_from_dataset(
225
+ dataset: HardDataset | SoftDataset,
226
+ rvs: Optional[Sequence[RandomVariable]] = None,
227
+ *,
228
+ dirichlet_prior: float | CrossTable = 0,
229
+ ) -> CrossTable:
230
+ """
231
+ Generate a cross-table for the given random variables, using the given dataset, represented
232
+ as a dictionary, mapping instances to weights.
233
+
234
+ Args:
235
+ dataset: The dataset to use to compute the cross-table.
236
+ rvs: The random variables to compute the cross-table for. If omitted
237
+ then `dataset.rvs` will be used.
238
+ dirichlet_prior: provides a Dirichlet prior for `rvs`. This can be represented either:
239
+ (a) as a uniform prior, represented as a float value,
240
+ (b) as an arbitrary Dirichlet prior, represented as a cross-table.
241
+ If a cross-table is provided as a prior, then it must have the same random variables as `rvs`.
242
+ The default value for `dirichlet_prior` is 0.
243
+ See `CrossTable` for more explanation.
244
+
245
+ Returns:
246
+ The cross-table for the given random variables, using the given dataset,
247
+ represented as a dictionary mapping instances to weights.
248
+ An instance is a tuple of state indexes, co-indexed with rvs.
249
+
250
+ Raises:
251
+ KeyError: If any random variable in `rvs` does not appear in the dataset.
252
+ """
253
+ if isinstance(dataset, HardDataset):
254
+ return cross_table_from_hard_dataset(dataset, rvs, dirichlet_prior=dirichlet_prior)
255
+ if isinstance(dataset, SoftDataset):
256
+ return cross_table_from_soft_dataset(dataset, rvs, dirichlet_prior=dirichlet_prior)
257
+ raise TypeError('dataset must be either a SoftDataset or HardDataset')
258
+
259
+
260
def cross_table_from_hard_dataset(
    dataset: HardDataset,
    rvs: Optional[Sequence[RandomVariable]] = None,
    *,
    dirichlet_prior: float | CrossTable = 0
) -> CrossTable:
    """
    Tabulate a hard dataset into a cross-table over the chosen random variables.

    Args:
        dataset: The dataset to use to compute the cross-table.
        rvs: The random variables to tabulate. When None, `dataset.rvs` is used.
        dirichlet_prior: a Dirichlet prior for `rvs`, given either
            (a) as a float, meaning a uniform prior, or
            (b) as a cross-table, meaning an arbitrary prior; it must have the same
                random variables as `rvs`.
            Defaults to 0. See `CrossTable` for more explanation.

    Returns:
        A CrossTable mapping instances (tuples of state indexes, co-indexed with `rvs`)
        to weights, accumulated from the (instance, weight) pairs produced by
        `dataset.instances(rvs)`.

    Raises:
        KeyError: If any random variable in `rvs` does not appear in the dataset.
    """
    selected_rvs = dataset.rvs if rvs is None else rvs
    weighted_instances = dataset.instances(selected_rvs)
    return CrossTable(rvs=selected_rvs, dirichlet_prior=dirichlet_prior, update=weighted_instances)
296
+
297
+
298
def cross_table_from_soft_dataset(
    dataset: SoftDataset,
    rvs: Optional[Sequence[RandomVariable]] = None,
    *,
    dirichlet_prior: float | CrossTable = 0
) -> CrossTable:
    """
    Tabulate a soft dataset into a cross-table over the chosen random variables.

    Args:
        dataset: The dataset to use to compute the cross-table.
        rvs: The random variables to tabulate. When None, `dataset.rvs` is used.
        dirichlet_prior: a Dirichlet prior for `rvs`, given either
            (a) as a float, meaning a uniform prior, or
            (b) as a cross-table, meaning an arbitrary prior; it must have the same
                random variables as `rvs`.
            Defaults to 0. See `CrossTable` for more explanation.

    Returns:
        A CrossTable mapping instances (tuples of state indexes, co-indexed with `rvs`)
        to weights, accumulated from the (instance, weight) pairs produced by
        `dataset.hard_instances(rvs)`.

    Raises:
        KeyError: If any random variable in `rvs` does not appear in the dataset.
    """
    selected_rvs = dataset.rvs if rvs is None else rvs
    # NOTE(review): how soft weights are turned into hard (instance, weight) pairs is
    # decided by SoftDataset.hard_instances, which is not visible here.
    weighted_instances = dataset.hard_instances(selected_rvs)
    return CrossTable(rvs=selected_rvs, dirichlet_prior=dirichlet_prior, update=weighted_instances)