compiled-knowledge 4.0.0a5__cp313-cp313-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of compiled-knowledge might be problematic. Click here for more details.
- ck/__init__.py +0 -0
- ck/circuit/__init__.py +13 -0
- ck/circuit/circuit.c +38749 -0
- ck/circuit/circuit.cpython-313-darwin.so +0 -0
- ck/circuit/circuit_py.py +807 -0
- ck/circuit/tmp_const.py +74 -0
- ck/circuit_compiler/__init__.py +2 -0
- ck/circuit_compiler/circuit_compiler.py +26 -0
- ck/circuit_compiler/cython_vm_compiler/__init__.py +1 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.c +17373 -0
- ck/circuit_compiler/cython_vm_compiler/_compiler.cpython-313-darwin.so +0 -0
- ck/circuit_compiler/cython_vm_compiler/cython_vm_compiler.py +96 -0
- ck/circuit_compiler/interpret_compiler.py +223 -0
- ck/circuit_compiler/llvm_compiler.py +388 -0
- ck/circuit_compiler/llvm_vm_compiler.py +546 -0
- ck/circuit_compiler/named_circuit_compilers.py +57 -0
- ck/circuit_compiler/support/__init__.py +0 -0
- ck/circuit_compiler/support/circuit_analyser.py +81 -0
- ck/circuit_compiler/support/input_vars.py +148 -0
- ck/circuit_compiler/support/llvm_ir_function.py +234 -0
- ck/example/__init__.py +53 -0
- ck/example/alarm.py +366 -0
- ck/example/asia.py +28 -0
- ck/example/binary_clique.py +32 -0
- ck/example/bow_tie.py +33 -0
- ck/example/cancer.py +37 -0
- ck/example/chain.py +38 -0
- ck/example/child.py +199 -0
- ck/example/clique.py +33 -0
- ck/example/cnf_pgm.py +39 -0
- ck/example/diamond_square.py +68 -0
- ck/example/earthquake.py +36 -0
- ck/example/empty.py +10 -0
- ck/example/hailfinder.py +539 -0
- ck/example/hepar2.py +628 -0
- ck/example/insurance.py +504 -0
- ck/example/loop.py +40 -0
- ck/example/mildew.py +38161 -0
- ck/example/munin.py +22982 -0
- ck/example/pathfinder.py +53674 -0
- ck/example/rain.py +39 -0
- ck/example/rectangle.py +161 -0
- ck/example/run.py +30 -0
- ck/example/sachs.py +129 -0
- ck/example/sprinkler.py +30 -0
- ck/example/star.py +44 -0
- ck/example/stress.py +64 -0
- ck/example/student.py +43 -0
- ck/example/survey.py +46 -0
- ck/example/triangle_square.py +54 -0
- ck/example/truss.py +49 -0
- ck/in_out/__init__.py +3 -0
- ck/in_out/parse_ace_lmap.py +216 -0
- ck/in_out/parse_ace_nnf.py +288 -0
- ck/in_out/parse_net.py +480 -0
- ck/in_out/parser_utils.py +185 -0
- ck/in_out/pgm_pickle.py +42 -0
- ck/in_out/pgm_python.py +268 -0
- ck/in_out/render_bugs.py +111 -0
- ck/in_out/render_net.py +177 -0
- ck/in_out/render_pomegranate.py +184 -0
- ck/pgm.py +3494 -0
- ck/pgm_circuit/__init__.py +1 -0
- ck/pgm_circuit/marginals_program.py +352 -0
- ck/pgm_circuit/mpe_program.py +237 -0
- ck/pgm_circuit/pgm_circuit.py +75 -0
- ck/pgm_circuit/program_with_slotmap.py +234 -0
- ck/pgm_circuit/slot_map.py +35 -0
- ck/pgm_circuit/support/__init__.py +0 -0
- ck/pgm_circuit/support/compile_circuit.py +83 -0
- ck/pgm_circuit/target_marginals_program.py +103 -0
- ck/pgm_circuit/wmc_program.py +323 -0
- ck/pgm_compiler/__init__.py +2 -0
- ck/pgm_compiler/ace/__init__.py +1 -0
- ck/pgm_compiler/ace/ace.py +252 -0
- ck/pgm_compiler/factor_elimination.py +383 -0
- ck/pgm_compiler/named_pgm_compilers.py +63 -0
- ck/pgm_compiler/pgm_compiler.py +19 -0
- ck/pgm_compiler/recursive_conditioning.py +226 -0
- ck/pgm_compiler/support/__init__.py +0 -0
- ck/pgm_compiler/support/circuit_table/__init__.py +9 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.c +16042 -0
- ck/pgm_compiler/support/circuit_table/circuit_table.cpython-313-darwin.so +0 -0
- ck/pgm_compiler/support/circuit_table/circuit_table_py.py +269 -0
- ck/pgm_compiler/support/clusters.py +556 -0
- ck/pgm_compiler/support/factor_tables.py +398 -0
- ck/pgm_compiler/support/join_tree.py +275 -0
- ck/pgm_compiler/support/named_compiler_maker.py +33 -0
- ck/pgm_compiler/variable_elimination.py +89 -0
- ck/probability/__init__.py +0 -0
- ck/probability/empirical_probability_space.py +47 -0
- ck/probability/probability_space.py +568 -0
- ck/program/__init__.py +3 -0
- ck/program/program.py +129 -0
- ck/program/program_buffer.py +180 -0
- ck/program/raw_program.py +61 -0
- ck/sampling/__init__.py +0 -0
- ck/sampling/forward_sampler.py +211 -0
- ck/sampling/marginals_direct_sampler.py +113 -0
- ck/sampling/sampler.py +62 -0
- ck/sampling/sampler_support.py +232 -0
- ck/sampling/uniform_sampler.py +66 -0
- ck/sampling/wmc_direct_sampler.py +169 -0
- ck/sampling/wmc_gibbs_sampler.py +147 -0
- ck/sampling/wmc_metropolis_sampler.py +159 -0
- ck/sampling/wmc_rejection_sampler.py +113 -0
- ck/utils/__init__.py +0 -0
- ck/utils/iter_extras.py +153 -0
- ck/utils/map_list.py +128 -0
- ck/utils/map_set.py +128 -0
- ck/utils/np_extras.py +51 -0
- ck/utils/random_extras.py +64 -0
- ck/utils/tmp_dir.py +94 -0
- ck_demos/__init__.py +0 -0
- ck_demos/ace/__init__.py +0 -0
- ck_demos/ace/copy_ace_to_ck.py +15 -0
- ck_demos/ace/demo_ace.py +44 -0
- ck_demos/all_demos.py +88 -0
- ck_demos/circuit/__init__.py +0 -0
- ck_demos/circuit/demo_circuit_dump.py +22 -0
- ck_demos/circuit/demo_derivatives.py +43 -0
- ck_demos/circuit_compiler/__init__.py +0 -0
- ck_demos/circuit_compiler/compare_circuit_compilers.py +32 -0
- ck_demos/circuit_compiler/show_llvm_program.py +26 -0
- ck_demos/pgm/__init__.py +0 -0
- ck_demos/pgm/demo_pgm_dump.py +18 -0
- ck_demos/pgm/demo_pgm_dump_stress.py +18 -0
- ck_demos/pgm/demo_pgm_string_rendering.py +15 -0
- ck_demos/pgm/show_examples.py +25 -0
- ck_demos/pgm_compiler/__init__.py +0 -0
- ck_demos/pgm_compiler/compare_pgm_compilers.py +50 -0
- ck_demos/pgm_compiler/demo_compiler_dump.py +50 -0
- ck_demos/pgm_compiler/demo_factor_elimination.py +47 -0
- ck_demos/pgm_compiler/demo_join_tree.py +25 -0
- ck_demos/pgm_compiler/demo_marginals_program.py +53 -0
- ck_demos/pgm_compiler/demo_mpe_program.py +55 -0
- ck_demos/pgm_compiler/demo_pgm_compiler.py +38 -0
- ck_demos/pgm_compiler/demo_recursive_conditioning.py +33 -0
- ck_demos/pgm_compiler/demo_variable_elimination.py +33 -0
- ck_demos/pgm_compiler/demo_wmc_program.py +29 -0
- ck_demos/pgm_inference/__init__.py +0 -0
- ck_demos/pgm_inference/demo_inferencing_basic.py +188 -0
- ck_demos/pgm_inference/demo_inferencing_mpe_cancer.py +45 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_and_mpe_sprinkler.py +154 -0
- ck_demos/pgm_inference/demo_inferencing_wmc_student.py +110 -0
- ck_demos/programs/__init__.py +0 -0
- ck_demos/programs/demo_program_buffer.py +24 -0
- ck_demos/programs/demo_program_multi.py +24 -0
- ck_demos/programs/demo_program_none.py +19 -0
- ck_demos/programs/demo_program_single.py +23 -0
- ck_demos/programs/demo_raw_program_interpreted.py +21 -0
- ck_demos/programs/demo_raw_program_llvm.py +21 -0
- ck_demos/sampling/__init__.py +0 -0
- ck_demos/sampling/check_sampler.py +71 -0
- ck_demos/sampling/demo_marginal_direct_sampler.py +40 -0
- ck_demos/sampling/demo_uniform_sampler.py +38 -0
- ck_demos/sampling/demo_wmc_direct_sampler.py +40 -0
- ck_demos/utils/__init__.py +0 -0
- ck_demos/utils/compare.py +88 -0
- ck_demos/utils/convert_network.py +45 -0
- ck_demos/utils/sample_model.py +216 -0
- ck_demos/utils/stop_watch.py +384 -0
- compiled_knowledge-4.0.0a5.dist-info/METADATA +50 -0
- compiled_knowledge-4.0.0a5.dist-info/RECORD +167 -0
- compiled_knowledge-4.0.0a5.dist-info/WHEEL +5 -0
- compiled_knowledge-4.0.0a5.dist-info/licenses/LICENSE.txt +21 -0
- compiled_knowledge-4.0.0a5.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
from typing import Collection, Iterator, List, Sequence
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ck.pgm import Instance
|
|
6
|
+
from ck.probability.probability_space import dtype_for_state_indexes
|
|
7
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
8
|
+
from ck.program.raw_program import RawProgram
|
|
9
|
+
from ck.sampling.sampler import Sampler
|
|
10
|
+
from ck.sampling.sampler_support import SampleRV, YieldF, uniform_random_sample, SamplerInfo
|
|
11
|
+
from ck.utils.np_extras import NDArrayStates, DTypeStates
|
|
12
|
+
from ck.utils.random_extras import Random
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WMCMetropolisSampler(Sampler):
    """
    A sampler that walks a Markov chain over the states of the sample
    random variables. Each step proposes a new state for one randomly chosen
    random variable and accepts or rejects the proposal by comparing the
    program result (weight) before and after the change.
    """

    def __init__(
            self,
            sampler_info: SamplerInfo,
            raw_program: RawProgram,
            rand: Random,
            skip: int,
            burn_in: int,
            pr_restart: float,
    ):
        """
        Construct the sampler.

        `sampler_info` supplies the rvs, condition, yield function, sample rvs
        and the input slots to hold at 0 (`slots_0`) and at 1 (`slots_1`).
        `raw_program` is wrapped in a ProgramBuffer; its computed result is
        used as the weight of the current state.
        `skip` is the number of chain steps discarded before each yielded sample.
        `burn_in` is the number of chain steps run after each (re)initialisation.
        `pr_restart` is the probability of reinitialising the chain after a
        yielded sample; no restart is ever attempted when it is <= 0.
        """
        super().__init__(sampler_info.rvs, sampler_info.condition)
        self._yield_f: YieldF = sampler_info.yield_f
        self._rand: Random = rand
        self._program_buffer = ProgramBuffer(raw_program)
        self._sample_rvs: Sequence[SampleRV] = tuple(sampler_info.sample_rvs)
        self._state_dtype: DTypeStates = dtype_for_state_indexes(self.rvs)
        self._slots_0: Collection[int] = sampler_info.slots_0
        self._slots_1: Collection[int] = sampler_info.slots_1
        self._skip: int = skip
        self._burn_in: int = burn_in
        self._pr_restart: float = pr_restart

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Yield an endless stream of samples, each produced by `yield_f(state)`.
        """
        sample_rvs = self._sample_rvs
        rand = self._rand
        yield_f = self._yield_f
        program_buffer = self._program_buffer
        prog_vars = program_buffer.vars
        skip = self._skip
        burn_in = self._burn_in
        pr_restart = self._pr_restart
        step = self._next_sample_metropolis

        # Working memory holding one state index per sample rv.
        state: NDArrayStates = np.zeros(len(sample_rvs), dtype=self._state_dtype)

        # Set the input slots to respect the conditioning.
        for slot in self._slots_0:
            prog_vars[slot] = 0
        for slot in self._slots_1:
            prog_vars[slot] = 1

        # For each sample rv record (index, slots, [(state, slot), ...]) where
        # the last element lists only the states whose slot is currently 1.
        possibles = [
            (
                sample_rv.index,
                sample_rv.slots,
                [
                    (slot_state, slot)
                    for slot_state, slot in enumerate(sample_rv.slots)
                    if prog_vars[slot] == 1
                ],
            )
            for sample_rv in sample_rvs
        ]

        def restart() -> float:
            # Pick a fresh starting state, then run the burn in.
            weight: float = self._init_sample_metropolis(state)
            for _ in range(burn_in):
                weight = step(possibles, program_buffer, state, weight, rand)
            return weight

        w: float = restart()

        # When pr_restart <= 0 no random number is drawn for the restart test,
        # so the random stream is used only by the chain steps.
        may_restart = not (pr_restart <= 0)

        while True:
            for _ in range(skip):
                w = step(possibles, program_buffer, state, w, rand)
            w = step(possibles, program_buffer, state, w, rand)
            yield yield_f(state)

            if may_restart and rand.random() < pr_restart:
                w = restart()

    def _init_sample_metropolis(self, state: NDArrayStates) -> float:
        """
        Fill `state` (and the matching program inputs) with a uniformly drawn
        system state and return the weight the program computes for it.
        """
        program_buffer = self._program_buffer
        prog_vars = program_buffer.vars

        while True:
            uniform_random_sample(
                self._sample_rvs, self._slots_0, self._slots_1, prog_vars, state, self._rand
            )
            weight: float = program_buffer.compute().item()
            # NOTE(review): a weight of exactly 0 passes this test, so the loop
            # only repeats for a negative (or NaN) weight - confirm whether
            # `> 0` was intended.
            if weight >= 0:
                return weight

    @staticmethod
    def _next_sample_metropolis(
            possibles,
            program_buffer: ProgramBuffer,
            state,
            cur_w: float,
            rand: Random,
    ) -> float:
        """
        Perform one chain step: propose a new state for one random variable,
        then either accept it (updating `state`) or restore the program inputs.

        Returns the weight of the state the chain is in after the step.
        """
        prog_in = program_buffer.vars

        # Proposal: a random rv, then a random one of its possible states.
        rv_index, rv_slots, rv_possibles = possibles[rand.randrange(0, len(possibles))]
        old_slot = rv_slots[state[rv_index]]
        new_state, new_slot = rv_possibles[rand.randrange(0, len(rv_possibles))]

        # Switch the program inputs over to the proposed state.
        prog_in[old_slot] = 0
        prog_in[new_slot] = 1

        new_w: float = program_buffer.compute().item()
        if rand.random() * cur_w < new_w:
            # Accepted: the program inputs already match, just record the state.
            state[rv_index] = new_state
            return new_w

        # Rejected: put the program inputs back as they were.
        prog_in[new_slot] = 0
        prog_in[old_slot] = 1
        return cur_w
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
from typing import Collection, Iterator, Sequence
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ck.pgm import Instance
|
|
6
|
+
from ck.probability.probability_space import dtype_for_state_indexes
|
|
7
|
+
from ck.program.program_buffer import ProgramBuffer
|
|
8
|
+
from ck.program.raw_program import RawProgram
|
|
9
|
+
from ck.sampling.sampler import Sampler
|
|
10
|
+
from ck.sampling.sampler_support import SampleRV, YieldF, uniform_random_sample, SamplerInfo
|
|
11
|
+
from ck.utils.np_extras import NDArrayNumeric
|
|
12
|
+
from ck.utils.random_extras import Random
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WMCRejectionSampler(Sampler):
    """
    A sampler that proposes states with `uniform_random_sample` and yields a
    proposal of weight w when `rand.random() * w_max < w`, where w_max is an
    estimate of the largest weight of any single state. The estimate is
    tightened as distinct states are observed.
    """

    def __init__(
            self,
            sampler_info: SamplerInfo,
            raw_program: RawProgram,
            rand: Random,
            z: float,
    ):
        """
        Construct the sampler.

        `sampler_info` supplies the rvs, condition, yield function, sample rvs
        and the input slots to hold at 0 (`slots_0`) and at 1 (`slots_1`).
        `raw_program` is wrapped in a ProgramBuffer; its computed result is
        used as the weight of the current state.
        `z` is the starting value for the weight not yet seen.
        """
        super().__init__(sampler_info.rvs, sampler_info.condition)
        self._yield_f: YieldF = sampler_info.yield_f
        self._rand: Random = rand
        self._program_buffer = ProgramBuffer(raw_program)
        self._sample_rvs: Sequence[SampleRV] = tuple(sampler_info.sample_rvs)
        self._state_dtype = dtype_for_state_indexes(self.rvs)
        self._slots_0: Collection[int] = sampler_info.slots_0
        self._slots_1: Collection[int] = sampler_info.slots_1

        # Fields used to track and tighten w_max.
        self._w_max = None  # estimated maximum weight for any one world
        self._w_not_seen = z  # z minus the weight of every distinct sample seen
        self._w_high = 0.0  # highest instance weight seen so far
        self._samples = set()  # distinct samples seen; None once tracking stops

    def __iter__(self) -> Iterator[Instance] | Iterator[int]:
        """
        Yield an endless stream of accepted samples, each produced by `yield_f(state)`.
        """
        sample_rvs = self._sample_rvs
        rand = self._rand
        yield_f = self._yield_f
        slots_0 = self._slots_0
        slots_1 = self._slots_1
        program_buffer = self._program_buffer
        prog_vars: NDArrayNumeric = program_buffer.vars

        # wmc() gives the program result for the current state of the input slots.
        def wmc() -> float:
            return program_buffer.compute().item()

        # Working memory to store a possible world.
        state: NDArrayNumeric = np.zeros(len(sample_rvs), dtype=self._state_dtype)

        # The initial estimate is computed once and kept on the sampler.
        if self._w_max is None:
            self._w_max = self._initial_w_max(prog_vars, wmc)

        while True:
            uniform_random_sample(sample_rvs, slots_0, slots_1, prog_vars, state, rand)
            w: float = wmc()

            if rand.random() * self._w_max < w:
                yield yield_f(state)

            # Every proposal, accepted or not, can tighten w_max while
            # samples are still being tracked.
            if self._samples is not None:
                self._track_sample(tuple(state), w)

    def _initial_w_max(self, prog_vars, wmc) -> float:
        """
        Return the starting estimate for w_max: the smallest, over the sample
        rvs, of the largest program result obtained with exactly one
        permitted state slot of that rv set to 1.
        """
        slots_1 = self._slots_1
        bound = self._w_not_seen  # initially z, so a 'large' weight

        # Set the input slots to 0 or 1 to respect the conditioning.
        for slot in self._slots_0:
            prog_vars[slot] = 0
        for slot in slots_1:
            prog_vars[slot] = 1

        for sample_rv in self._sample_rvs:
            rv_slots = sample_rv.slots

            # Switch every slot of this rv off...
            for slot in rv_slots:
                prog_vars[slot] = 0

            # ...then switch on, one at a time, each slot that may be 1.
            permitted = [slot for slot in rv_slots if slot in slots_1]
            max_for_rv = 0
            for slot in permitted:
                prog_vars[slot] = 1
                max_for_rv = max(max_for_rv, wmc())
                prog_vars[slot] = 0

            # Restore the permitted slots, ready for the next rv.
            for slot in permitted:
                prog_vars[slot] = 1

            bound = min(bound, max_for_rv)

        return bound

    def _track_sample(self, world, w: float) -> None:
        """
        Record a proposed world of weight `w` and, if it has not been seen
        before, use it to tighten w_max. Tracking is switched off once
        w_not_seen and w_high are close enough or too many samples are held.
        """
        seen = self._samples
        if world in seen:
            return

        seen.add(world)
        self._w_not_seen -= w
        self._w_high = max(self._w_high, w)
        w_max_tracked = max(self._w_high, self._w_not_seen)
        self._w_max = min(w_max_tracked, self._w_max)

        close_enough = self._w_not_seen - self._w_high < 0.001
        if close_enough or len(seen) > 1000000:
            self._samples = None
|
ck/utils/__init__.py
ADDED
|
File without changes
|
ck/utils/iter_extras.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A module with extra iteration functions.
|
|
3
|
+
"""
|
|
4
|
+
from functools import reduce as _reduce
|
|
5
|
+
from itertools import combinations, chain
|
|
6
|
+
from operator import mul as _mul
|
|
7
|
+
from typing import Iterable, Tuple, Iterator, Sequence, TypeVar
|
|
8
|
+
|
|
9
|
+
_T = TypeVar('_T')
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def flatten(iterables: Iterable[Iterable[_T]]) -> Iterator[_T]:
    """
    Iterate, in order, over every element of every iterable in `iterables`
    (one level of nesting only).
    """
    return chain.from_iterable(iterables)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def deep_flatten(iterables: Iterable) -> Iterator:
    """
    Iterate over the leaves of arbitrarily nested iterables, depth first.

    Strings are treated as leaves, not as iterables of characters.
    """
    for element in iterables:
        is_leaf = isinstance(element, str) or not isinstance(element, Iterable)
        if is_leaf:
            yield element
            continue
        yield from deep_flatten(element)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def combos(list_of_lists: Sequence[Sequence[_T]], flip=False) -> Iterator[Tuple[_T, ...]]:
    """
    Iterate over all combinations of taking one element from each of the lists.

    The order of results has the first element changing most rapidly.
    For example, given [[1,2,3],[4,5],[6,7]], combos yields the following:
        (1,4,6), (2,4,6), (3,4,6), (1,5,6), (2,5,6), (3,5,6),
        (1,4,7), (2,4,7), (3,4,7), (1,5,7), (2,5,7), (3,5,7).

    If flip, then the last changes most rapidly.

    Given no lists at all, the single empty combination () is yielded.
    If any of the lists is empty then there is no way to take an element
    from it, so nothing is yielded.
    """
    num = len(list_of_lists)
    if num == 0:
        yield ()
        return

    # Without this guard the first yield would index into an empty list.
    if any(len(choices) == 0 for choices in list_of_lists):
        return

    # Positions in the order they are advanced: the fastest changing first.
    positions = range(num - 1, -1, -1) if flip else range(num)
    indexes = [0] * num

    while True:
        yield tuple(choices[index] for choices, index in zip(list_of_lists, indexes))

        # Advance like an odometer: bump the fastest position and carry
        # into the next one whenever a position wraps around.
        for pos in positions:
            indexes[pos] += 1
            if indexes[pos] < len(list_of_lists[pos]):
                break
            indexes[pos] = 0
        else:
            # Every position wrapped around: all combinations were produced.
            return
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def combos_ranges(list_of_lens: Sequence[int], flip=False) -> Iterator[Tuple[int, ...]]:
    """
    Iterate over all combinations of indexes, where the index at position i
    is taken from range(list_of_lens[i]).

    The order of results has the first element changing most rapidly.
    If flip, then the last changes most rapidly.

    Given no lengths at all, the single empty combination () is yielded.
    If any length is zero (or negative) the corresponding range is empty,
    so nothing is yielded.
    """
    num = len(list_of_lens)
    if num == 0:
        yield ()
        return

    # Without this guard, index 0 would be reported for an empty range.
    if any(length <= 0 for length in list_of_lens):
        return

    # Positions in the order they are advanced: the fastest changing first.
    positions = range(num - 1, -1, -1) if flip else range(num)
    indexes = [0] * num

    while True:
        yield tuple(indexes)

        # Advance like an odometer: bump the fastest position and carry
        # into the next one whenever a position wraps around.
        for pos in positions:
            indexes[pos] += 1
            if indexes[pos] < list_of_lens[pos]:
                break
            indexes[pos] = 0
        else:
            # Every position wrapped around: all combinations were produced.
            return
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def pairs(elements: Iterable[_T]) -> Iterable[Tuple[_T, _T]]:
    """
    Iterate over every unordered pair of the given elements, each pair
    keeping the elements in their original relative order.
    """
    pair_size = 2
    return combinations(elements, pair_size)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def sequential_pairs(elements: Sequence[_T]) -> Iterator[Tuple[_T, _T]]:
    """
    Iterate over each element paired with its immediate successor,
    i.e., (e0, e1), (e1, e2), (e2, e3), ...
    """
    for following in range(1, len(elements)):
        yield elements[following - 1], elements[following]
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def powerset(iterable: Iterable[_T], min_size: int = 0, max_size: int = None) -> Iterable[Tuple[_T, ...]]:
    """
    Iterate over the subsets of the given elements, smallest subsets first.

    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)

    Only subsets with between min_size and max_size elements (inclusive) are
    produced. A min_size of None means 0 and a max_size of None means the
    number of elements.
    """
    # A one-shot iterable must be materialised as it is traversed once per size.
    pool = iterable if isinstance(iterable, (list, tuple)) else list(iterable)
    smallest = 0 if min_size is None else min_size
    largest = len(pool) if max_size is None else max_size
    return chain.from_iterable(
        combinations(pool, size)
        for size in range(smallest, largest + 1)
    )
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def unzip(xs: Iterable[Tuple[_T]]) -> Tuple[Iterable[_T]]:
    """
    Undo a zip: regroup an iterable of tuples into one tuple per position.

    E.g., a, b, c = unzip(zip(a, b, c))
    """
    regrouped = zip(*xs)
    return regrouped
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def multiply(items: Iterable[_T], initial: _T = 1) -> _T:
    """
    Return `initial` multiplied, from left to right, by each of the given items.
    With no items the result is `initial`.
    """
    product = initial
    for item in items:
        product = product * item
    return product
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def first(items: Iterable[_T]) -> _T:
    """
    Return the element that iterating over `items` produces first.
    StopIteration propagates if there are no elements.
    """
    iterator = iter(items)
    return next(iterator)
|
ck/utils/map_list.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module defines a class "MapList" for mapping keys to lists.
|
|
3
|
+
"""
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import TypeVar, Generic, List, Dict, MutableMapping, KeysView, ValuesView, ItemsView, Iterable, Iterator
|
|
7
|
+
|
|
8
|
+
_K = TypeVar('_K')
_V = TypeVar('_V')


class MapList(Generic[_K, _V], MutableMapping[_K, List[_V]]):
    """
    A mapping from keys to lists: where a dict holds one element per key,
    a MapList holds a list of elements per key.
    """
    __slots__ = ('_map', )

    def __init__(self, *args, **kwargs):
        # The arguments are interpreted as for dict(); each value must be a list.
        self._map: Dict[_K, List[_V]] = {}
        self.update(*args, extend=False, **kwargs)

    def __str__(self) -> str:
        return str(self._map)

    def __repr__(self) -> str:
        parts = [f'{key!r}:{val!r}' for key, val in self.items()]
        class_name = self.__class__.__name__
        return f'{class_name}(' + ', '.join(parts) + ')'

    def __len__(self) -> int:
        return len(self._map)

    def __bool__(self) -> bool:
        return bool(len(self))

    def __getitem__(self, key) -> List[_V]:
        return self._map[key]

    def __setitem__(self, key: _K, val: List[_V]):
        if isinstance(val, list):
            self._map[key] = val
            return
        class_name = self.__class__.__name__
        raise RuntimeError(f'every {class_name} value must be a list')

    def __delitem__(self, key: _K):
        del self._map[key]

    def __iter__(self) -> Iterator[_K]:
        return iter(self._map)

    def __contains__(self, key: _K) -> bool:
        return key in self._map

    def update(self, *args, extend=False, **kwargs):
        """
        Update from arguments interpreted as for dict().

        If extend, each given list is appended onto the list already held
        for its key; otherwise each given list replaces it.
        """
        store = self.extend if extend else self.__setitem__
        for key, val in dict(*args, **kwargs).items():
            store(key, val)

    def keys(self) -> KeysView[_K]:
        return self._map.keys()

    def values(self) -> ValuesView[List[_V]]:
        return self._map.values()

    def items(self) -> ItemsView[_K, List[_V]]:
        return self._map.items()

    def get(self, key: _K, default=None):
        """
        Return the list held for the given key, or the supplied default
        when the key is not in the MapList. The MapList is not changed.
        """
        return self._map.get(key, default)

    def get_list(self, key: _K) -> List[_V]:
        """
        Return the list held for the given key, first adding an empty
        list to the MapList when the key is not yet present.

        The returned list is the one held by this MapList, so modifying
        it modifies this MapList.
        """
        return self._map.setdefault(key, [])

    def append(self, key: _K, item: _V):
        """
        Append the given item to the list identified by the given key.
        """
        target = self.get_list(key)
        target.append(item)

    def extend(self, key: _K, items: Iterable[_V]):
        """
        Extend the list identified by the given key with the given items.
        """
        target = self.get_list(key)
        return target.extend(items)

    def extend_map_list(self, map_list: "MapList[_K, _V]"):
        """
        For each key of the given MapList, extend the list identified
        by that key with the items the given MapList holds for it.
        """
        for key, items in map_list.items():
            self.extend(key, items)

    def clear(self):
        """
        Remove every key and its list.
        """
        return self._map.clear()

    def clear_empty(self):
        """
        Remove every key whose list is empty.
        """
        # Collect the keys first: a dict cannot change size while it is iterated.
        for key in [key for key, held in self._map.items() if len(held) == 0]:
            del self._map[key]
|
ck/utils/map_set.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module defines a class "MapSet" for mapping keys to sets.
|
|
3
|
+
"""
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import TypeVar, Generic, Set, Dict, MutableMapping, Iterator, KeysView, ValuesView, ItemsView, Iterable
|
|
7
|
+
|
|
8
|
+
_K = TypeVar('_K')
_V = TypeVar('_V')


class MapSet(Generic[_K, _V], MutableMapping[_K, Set[_V]]):
    """
    A mapping from keys to sets: where a dict holds one element per key,
    a MapSet holds a set of elements per key.
    """
    __slots__ = ('_map',)

    def __init__(self, *args, **kwargs):
        # The arguments are interpreted as for dict(); each value must be a set.
        self._map: Dict[_K, Set[_V]] = {}
        self.update(*args, add_all=False, **kwargs)

    def __str__(self) -> str:
        return str(self._map)

    def __repr__(self) -> str:
        parts = [f'{key!r}:{val!r}' for key, val in self.items()]
        class_name = self.__class__.__name__
        return f'{class_name}(' + ', '.join(parts) + ')'

    def __len__(self) -> int:
        return len(self._map)

    def __bool__(self) -> bool:
        return bool(len(self))

    def __getitem__(self, key: _K) -> Set[_V]:
        return self._map[key]

    def __setitem__(self, key: _K, val: Set[_V]):
        if isinstance(val, set):
            self._map[key] = val
            return
        class_name = self.__class__.__name__
        raise RuntimeError(f'every {class_name} value must be a set')

    def __delitem__(self, key: _K):
        del self._map[key]

    def __iter__(self) -> Iterator[_K]:
        return iter(self._map)

    def __contains__(self, key: _K) -> bool:
        return key in self._map

    def update(self, *args, add_all=True, **kwargs):
        """
        Update from arguments interpreted as for dict().

        If add_all, the given items are added to the set already held
        for their key; otherwise each given set replaces it.
        """
        store = self.add_all if add_all else self.__setitem__
        for key, val in dict(*args, **kwargs).items():
            store(key, val)

    def keys(self) -> KeysView[_K]:
        return self._map.keys()

    def values(self) -> ValuesView[Set[_V]]:
        return self._map.values()

    def items(self) -> ItemsView[_K, Set[_V]]:
        return self._map.items()

    def get(self, key: _K, default=None):
        """
        Return the set held for the given key, or the supplied default
        when the key is not in the MapSet. The MapSet is not changed.
        """
        return self._map.get(key, default)

    def get_set(self, key: _K) -> Set[_V]:
        """
        Return the set held for the given key, first adding an empty
        set to the MapSet when the key is not yet present.

        The returned set is the one held by this MapSet, so modifying
        it modifies this MapSet.
        """
        return self._map.setdefault(key, set())

    def add(self, key: _K, item: _V):
        """
        Add the given item to the set identified by the given key.
        """
        target = self.get_set(key)
        target.add(item)

    def add_all(self, key: _K, items: Iterable[_V]):
        """
        Add all the given items to the set identified by the given key.
        """
        target = self.get_set(key)
        return target.update(items)

    def add_map_set(self, map_set: "MapSet[_K, _V]"):
        """
        For each key of the given MapSet, add the items the given MapSet
        holds for it to the set identified by that key.
        """
        for key, items in map_set.items():
            self.add_all(key, items)

    def clear(self):
        """
        Remove every key and its set.
        """
        return self._map.clear()

    def clear_empty(self):
        """
        Remove every key whose set is empty.
        """
        # Collect the keys first: a dict cannot change size while it is iterated.
        for key in [key for key, held in self._map.items() if len(held) == 0]:
            del self._map[key]
|