cbrkit 0.26.0__tar.gz → 0.26.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cbrkit-0.26.0 → cbrkit-0.26.1}/PKG-INFO +1 -1
- {cbrkit-0.26.0 → cbrkit-0.26.1}/pyproject.toml +1 -1
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/retrieval/build.py +36 -21
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/__init__.py +2 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/__init__.py +6 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/astar.py +1 -77
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/common.py +110 -2
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/greedy.py +10 -14
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/wrappers.py +64 -1
- {cbrkit-0.26.0 → cbrkit-0.26.1}/README.md +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/__main__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/adapt/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/adapt/attribute_value.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/adapt/generic.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/adapt/numbers.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/adapt/strings.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/api.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/cli.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/constants.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/cycle.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/dumpers.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/eval/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/eval/common.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/eval/retrieval.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/helpers.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/loaders.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/model/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/model/graph.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/model/result.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/py.typed +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/retrieval/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/retrieval/apply.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/retrieval/rerank.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/reuse/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/reuse/apply.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/reuse/build.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/aggregator.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/attribute_value.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/collections.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/embed.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/generic.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/alignment.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/brute_force.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/dfs.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/lap.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/precompute.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/qap.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/vf2.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/numbers.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/strings.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/taxonomy.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/apply.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/build.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/model.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/prompts.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/anthropic.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/cohere.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/google.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/instructor.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/model.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/ollama.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/openai.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/wrappers.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/typing.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: cbrkit
|
|
3
|
-
Version: 0.26.0
|
|
3
|
+
Version: 0.26.1
|
|
4
4
|
Summary: Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI
|
|
5
5
|
Keywords: cbr,case-based reasoning,api,similarity,nlp,retrieval,cli,tool,library
|
|
6
6
|
Author: Mirko Lenz
|
|
@@ -2,7 +2,7 @@ import itertools
|
|
|
2
2
|
from collections.abc import Sequence
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
from multiprocessing.pool import Pool
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Literal, override
|
|
6
6
|
|
|
7
7
|
from ..helpers import (
|
|
8
8
|
batchify_sim,
|
|
@@ -134,44 +134,59 @@ class combine[K, V, S: Float](RetrieverFunc[K, V, float]):
|
|
|
134
134
|
A retriever function that combines the results from multiple retrievers.
|
|
135
135
|
"""
|
|
136
136
|
|
|
137
|
-
retriever_funcs:
|
|
138
|
-
aggregator: AggregatorFunc[
|
|
137
|
+
retriever_funcs: Sequence[RetrieverFunc[K, V, S]]
|
|
138
|
+
aggregator: AggregatorFunc[str, S] = default_aggregator
|
|
139
139
|
strategy: Literal["intersection", "union"] = "union"
|
|
140
140
|
|
|
141
141
|
@override
|
|
142
142
|
def __call__(
|
|
143
143
|
self, batches: Sequence[tuple[Casebase[K, V], V]]
|
|
144
144
|
) -> Sequence[SimMap[K, float]]:
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
145
|
+
if isinstance(self.retriever_funcs, Sequence):
|
|
146
|
+
func_results = [
|
|
147
|
+
retriever_func(batches) for retriever_func in self.retriever_funcs
|
|
148
|
+
]
|
|
149
|
+
|
|
150
|
+
return [
|
|
151
|
+
self.__call_batch__(
|
|
152
|
+
[batch_results[batch_idx] for batch_results in func_results]
|
|
153
|
+
)
|
|
154
|
+
for batch_idx in range(len(batches))
|
|
155
|
+
]
|
|
156
|
+
|
|
157
|
+
# elif isinstance(self.retriever_funcs, Mapping):
|
|
158
|
+
# results = {
|
|
159
|
+
# func_key: retriever_func(batches)
|
|
160
|
+
# for func_key, retriever_func in self.retriever_funcs.items()
|
|
161
|
+
# }
|
|
162
|
+
|
|
163
|
+
# return [
|
|
164
|
+
# self.__call_batch__(
|
|
165
|
+
# {func_key: func_results[batch_idx] for func_key, func_results in results.items()}
|
|
166
|
+
# )
|
|
167
|
+
# for batch_idx in range(len(batches))
|
|
168
|
+
# ]
|
|
169
|
+
|
|
170
|
+
raise ValueError(f"Invalid retriever_funcs type: {type(self.retriever_funcs)}")
|
|
156
171
|
|
|
157
|
-
def __call_batch__(self, results:
|
|
172
|
+
def __call_batch__(self, results: Sequence[SimMap[K, S]]) -> SimMap[K, float]:
|
|
158
173
|
if self.strategy == "intersection":
|
|
159
174
|
return {
|
|
160
|
-
|
|
161
|
-
[result[
|
|
175
|
+
case_key: self.aggregator(
|
|
176
|
+
[result[case_key] for result in results if case_key in result]
|
|
162
177
|
)
|
|
163
|
-
for
|
|
178
|
+
for case_key in set().intersection(
|
|
164
179
|
*[set(result.keys()) for result in results]
|
|
165
180
|
)
|
|
166
181
|
}
|
|
167
182
|
|
|
168
183
|
elif self.strategy == "union":
|
|
169
184
|
return {
|
|
170
|
-
|
|
171
|
-
[result[
|
|
185
|
+
case_key: self.aggregator(
|
|
186
|
+
[result[case_key] for result in results if case_key in result]
|
|
172
187
|
)
|
|
173
188
|
for result in results
|
|
174
|
-
for
|
|
189
|
+
for case_key in result.keys()
|
|
175
190
|
}
|
|
176
191
|
|
|
177
192
|
raise ValueError(f"Unknown strategy: {self.strategy}")
|
|
@@ -15,6 +15,7 @@ from .attribute_value import AttributeValueSim, attribute_value
|
|
|
15
15
|
from .wrappers import (
|
|
16
16
|
attribute_table,
|
|
17
17
|
cache,
|
|
18
|
+
combine,
|
|
18
19
|
dynamic_table,
|
|
19
20
|
table,
|
|
20
21
|
transpose,
|
|
@@ -26,6 +27,7 @@ __all__ = [
|
|
|
26
27
|
"transpose",
|
|
27
28
|
"transpose_value",
|
|
28
29
|
"cache",
|
|
30
|
+
"combine",
|
|
29
31
|
"table",
|
|
30
32
|
"dynamic_table",
|
|
31
33
|
"type_table",
|
|
@@ -7,7 +7,10 @@ from .common import (
|
|
|
7
7
|
GraphSim,
|
|
8
8
|
SearchGraphSimFunc,
|
|
9
9
|
SearchState,
|
|
10
|
+
SearchStateInit,
|
|
10
11
|
SemanticEdgeSim,
|
|
12
|
+
init_empty,
|
|
13
|
+
init_unique_matches,
|
|
11
14
|
)
|
|
12
15
|
from .greedy import greedy
|
|
13
16
|
from .lap import lap
|
|
@@ -33,10 +36,13 @@ __all__ = [
|
|
|
33
36
|
"vf2",
|
|
34
37
|
"dtw",
|
|
35
38
|
"smith_waterman",
|
|
39
|
+
"init_empty",
|
|
40
|
+
"init_unique_matches",
|
|
36
41
|
"GraphSim",
|
|
37
42
|
"ElementMatcher",
|
|
38
43
|
"SemanticEdgeSim",
|
|
39
44
|
"BaseGraphSimFunc",
|
|
40
45
|
"SearchGraphSimFunc",
|
|
41
46
|
"SearchState",
|
|
47
|
+
"SearchStateInit",
|
|
42
48
|
]
|
|
@@ -1,12 +1,8 @@
|
|
|
1
1
|
import heapq
|
|
2
|
-
import itertools
|
|
3
|
-
from collections import defaultdict
|
|
4
2
|
from collections.abc import Mapping
|
|
5
3
|
from dataclasses import dataclass, field
|
|
6
4
|
from typing import Protocol
|
|
7
5
|
|
|
8
|
-
from frozendict import frozendict
|
|
9
|
-
|
|
10
6
|
from ...helpers import (
|
|
11
7
|
get_logger,
|
|
12
8
|
unpack_float,
|
|
@@ -18,25 +14,20 @@ from ...model.graph import (
|
|
|
18
14
|
)
|
|
19
15
|
from ...typing import SimFunc
|
|
20
16
|
from .common import (
|
|
21
|
-
ElementMatcher,
|
|
22
17
|
GraphSim,
|
|
23
18
|
SearchGraphSimFunc,
|
|
24
19
|
SearchState,
|
|
25
|
-
_induced_edge_mapping,
|
|
26
20
|
)
|
|
27
21
|
|
|
28
22
|
__all__ = [
|
|
29
23
|
"HeuristicFunc",
|
|
30
24
|
"SelectionFunc",
|
|
31
|
-
"InitFunc",
|
|
32
25
|
"h1",
|
|
33
26
|
"h2",
|
|
34
27
|
"h3",
|
|
35
28
|
"select1",
|
|
36
29
|
"select2",
|
|
37
30
|
"select3",
|
|
38
|
-
"init1",
|
|
39
|
-
"init2",
|
|
40
31
|
"build",
|
|
41
32
|
]
|
|
42
33
|
|
|
@@ -74,17 +65,6 @@ class SelectionFunc[K, N, E, G](Protocol):
|
|
|
74
65
|
) -> None | tuple[K, GraphElementType]: ...
|
|
75
66
|
|
|
76
67
|
|
|
77
|
-
class InitFunc[K, N, E, G](Protocol):
|
|
78
|
-
def __call__(
|
|
79
|
-
self,
|
|
80
|
-
x: Graph[K, N, E, G],
|
|
81
|
-
y: Graph[K, N, E, G],
|
|
82
|
-
node_matcher: ElementMatcher[N],
|
|
83
|
-
edge_matcher: ElementMatcher[E],
|
|
84
|
-
/,
|
|
85
|
-
) -> SearchState[K]: ...
|
|
86
|
-
|
|
87
|
-
|
|
88
68
|
@dataclass(slots=True, frozen=True)
|
|
89
69
|
class h1[K, N, E, G](HeuristicFunc[K, N, E, G]):
|
|
90
70
|
def __call__(
|
|
@@ -274,61 +254,6 @@ class select3[K, N, E, G](SelectionFunc[K, N, E, G]):
|
|
|
274
254
|
return selection_key, selection_type
|
|
275
255
|
|
|
276
256
|
|
|
277
|
-
@dataclass(slots=True, frozen=True)
|
|
278
|
-
class init1[K, N, E, G](InitFunc[K, N, E, G]):
|
|
279
|
-
def __call__(
|
|
280
|
-
self,
|
|
281
|
-
x: Graph[K, N, E, G],
|
|
282
|
-
y: Graph[K, N, E, G],
|
|
283
|
-
node_matcher: ElementMatcher[N],
|
|
284
|
-
edge_matcher: ElementMatcher[E],
|
|
285
|
-
) -> SearchState[K]:
|
|
286
|
-
return SearchState(
|
|
287
|
-
frozendict(),
|
|
288
|
-
frozendict(),
|
|
289
|
-
frozenset(y.nodes.keys()),
|
|
290
|
-
frozenset(y.edges.keys()),
|
|
291
|
-
frozenset(x.nodes.keys()),
|
|
292
|
-
frozenset(x.edges.keys()),
|
|
293
|
-
)
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
@dataclass(slots=True, init=False)
|
|
297
|
-
class init2[K, N, E, G](InitFunc[K, N, E, G]):
|
|
298
|
-
def __call__(
|
|
299
|
-
self,
|
|
300
|
-
x: Graph[K, N, E, G],
|
|
301
|
-
y: Graph[K, N, E, G],
|
|
302
|
-
node_matcher: ElementMatcher[N],
|
|
303
|
-
edge_matcher: ElementMatcher[E],
|
|
304
|
-
) -> SearchState[K]:
|
|
305
|
-
# pre-populate the mapping with nodes/edges that only have one possible legal mapping
|
|
306
|
-
possible_node_mappings: defaultdict[K, set[K]] = defaultdict(set)
|
|
307
|
-
|
|
308
|
-
for y_key, x_key in itertools.product(y.nodes.keys(), x.nodes.keys()):
|
|
309
|
-
if node_matcher(x.nodes[x_key].value, y.nodes[y_key].value):
|
|
310
|
-
possible_node_mappings[y_key].add(x_key)
|
|
311
|
-
|
|
312
|
-
node_mapping: frozendict[K, K] = frozendict(
|
|
313
|
-
(y_key, next(iter(x_keys)))
|
|
314
|
-
for y_key, x_keys in possible_node_mappings.items()
|
|
315
|
-
if len(x_keys) == 1
|
|
316
|
-
)
|
|
317
|
-
|
|
318
|
-
edge_mapping: frozendict[K, K] = _induced_edge_mapping(
|
|
319
|
-
x, y, node_mapping, edge_matcher
|
|
320
|
-
)
|
|
321
|
-
|
|
322
|
-
return SearchState(
|
|
323
|
-
node_mapping,
|
|
324
|
-
edge_mapping,
|
|
325
|
-
frozenset(y.nodes.keys() - node_mapping.keys()),
|
|
326
|
-
frozenset(y.edges.keys() - edge_mapping.keys()),
|
|
327
|
-
frozenset(x.nodes.keys() - node_mapping.values()),
|
|
328
|
-
frozenset(x.edges.keys() - edge_mapping.values()),
|
|
329
|
-
)
|
|
330
|
-
|
|
331
|
-
|
|
332
257
|
@dataclass(slots=True)
|
|
333
258
|
class build[K, N, E, G](
|
|
334
259
|
SearchGraphSimFunc[K, N, E, G], SimFunc[Graph[K, N, E, G], GraphSim[K]]
|
|
@@ -355,7 +280,6 @@ class build[K, N, E, G](
|
|
|
355
280
|
|
|
356
281
|
heuristic_func: HeuristicFunc[K, N, E, G] = field(default_factory=h3)
|
|
357
282
|
selection_func: SelectionFunc[K, N, E, G] = field(default_factory=select3)
|
|
358
|
-
init_func: InitFunc[K, N, E, G] = field(default_factory=init1)
|
|
359
283
|
beam_width: int = 0
|
|
360
284
|
pathlength_weight: int = 0
|
|
361
285
|
|
|
@@ -446,7 +370,7 @@ class build[K, N, E, G](
|
|
|
446
370
|
node_pair_sims, edge_pair_sims = self.pair_similarities(x, y)
|
|
447
371
|
|
|
448
372
|
open_set: list[PriorityState[K]] = []
|
|
449
|
-
best_state = self.
|
|
373
|
+
best_state = self.init_search_state(x, y)
|
|
450
374
|
heapq.heappush(open_set, PriorityState(0, best_state))
|
|
451
375
|
|
|
452
376
|
while open_set:
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import itertools
|
|
2
|
+
from collections import defaultdict
|
|
2
3
|
from collections.abc import Mapping, Sequence
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
|
-
from typing import Any, Protocol
|
|
5
|
+
from typing import Any, Protocol, cast
|
|
5
6
|
|
|
6
7
|
from frozendict import frozendict
|
|
7
8
|
|
|
@@ -9,10 +10,11 @@ from ...helpers import (
|
|
|
9
10
|
batchify_sim,
|
|
10
11
|
reverse_batch_positional,
|
|
11
12
|
reverse_positional,
|
|
13
|
+
total_params,
|
|
12
14
|
unpack_float,
|
|
13
15
|
)
|
|
14
16
|
from ...model.graph import Edge, Graph, Node
|
|
15
|
-
from ...typing import AnySimFunc, BatchSimFunc, Float, StructuredValue
|
|
17
|
+
from ...typing import AnySimFunc, BatchSimFunc, Float, SimFunc, StructuredValue
|
|
16
18
|
from ..wrappers import transpose_value
|
|
17
19
|
|
|
18
20
|
type PairSim[K] = Mapping[tuple[K, K], float]
|
|
@@ -256,7 +258,113 @@ class SearchState[K]:
|
|
|
256
258
|
open_x_edges: frozenset[K]
|
|
257
259
|
|
|
258
260
|
|
|
261
|
+
class SearchStateInit[K, N, E, G](Protocol):
|
|
262
|
+
def __call__(
|
|
263
|
+
self,
|
|
264
|
+
x: Graph[K, N, E, G],
|
|
265
|
+
y: Graph[K, N, E, G],
|
|
266
|
+
node_matcher: ElementMatcher[N],
|
|
267
|
+
edge_matcher: ElementMatcher[E],
|
|
268
|
+
/,
|
|
269
|
+
) -> SearchState[K]: ...
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
@dataclass(slots=True, frozen=True)
|
|
273
|
+
class init_empty[K, N, E, G](SearchStateInit[K, N, E, G]):
|
|
274
|
+
def __call__(
|
|
275
|
+
self,
|
|
276
|
+
x: Graph[K, N, E, G],
|
|
277
|
+
y: Graph[K, N, E, G],
|
|
278
|
+
node_matcher: ElementMatcher[N],
|
|
279
|
+
edge_matcher: ElementMatcher[E],
|
|
280
|
+
) -> SearchState[K]:
|
|
281
|
+
return SearchState(
|
|
282
|
+
frozendict(),
|
|
283
|
+
frozendict(),
|
|
284
|
+
frozenset(y.nodes.keys()),
|
|
285
|
+
frozenset(y.edges.keys()),
|
|
286
|
+
frozenset(x.nodes.keys()),
|
|
287
|
+
frozenset(x.edges.keys()),
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
@dataclass(slots=True, init=False)
|
|
292
|
+
class init_unique_matches[K, N, E, G](SearchStateInit[K, N, E, G]):
|
|
293
|
+
def __call__(
|
|
294
|
+
self,
|
|
295
|
+
x: Graph[K, N, E, G],
|
|
296
|
+
y: Graph[K, N, E, G],
|
|
297
|
+
node_matcher: ElementMatcher[N],
|
|
298
|
+
edge_matcher: ElementMatcher[E],
|
|
299
|
+
) -> SearchState[K]:
|
|
300
|
+
# pre-populate the mapping with nodes/edges that only have one possible legal mapping
|
|
301
|
+
possible_node_mappings: defaultdict[K, set[K]] = defaultdict(set)
|
|
302
|
+
|
|
303
|
+
for y_key, x_key in itertools.product(y.nodes.keys(), x.nodes.keys()):
|
|
304
|
+
if node_matcher(x.nodes[x_key].value, y.nodes[y_key].value):
|
|
305
|
+
possible_node_mappings[y_key].add(x_key)
|
|
306
|
+
|
|
307
|
+
node_mapping: frozendict[K, K] = frozendict(
|
|
308
|
+
(y_key, next(iter(x_keys)))
|
|
309
|
+
for y_key, x_keys in possible_node_mappings.items()
|
|
310
|
+
if len(x_keys) == 1
|
|
311
|
+
)
|
|
312
|
+
|
|
313
|
+
edge_mapping: frozendict[K, K] = _induced_edge_mapping(
|
|
314
|
+
x, y, node_mapping, edge_matcher
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
return SearchState(
|
|
318
|
+
node_mapping,
|
|
319
|
+
edge_mapping,
|
|
320
|
+
frozenset(y.nodes.keys() - node_mapping.keys()),
|
|
321
|
+
frozenset(y.edges.keys() - edge_mapping.keys()),
|
|
322
|
+
frozenset(x.nodes.keys() - node_mapping.values()),
|
|
323
|
+
frozenset(x.edges.keys() - edge_mapping.values()),
|
|
324
|
+
)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
@dataclass(slots=True)
|
|
259
328
|
class SearchGraphSimFunc[K, N, E, G](BaseGraphSimFunc[K, N, E, G]):
|
|
329
|
+
init_func: (
|
|
330
|
+
SearchStateInit[K, N, E, G] | AnySimFunc[Graph[K, N, E, G], GraphSim[K]]
|
|
331
|
+
) = field(default_factory=init_unique_matches)
|
|
332
|
+
|
|
333
|
+
def init_search_state(
|
|
334
|
+
self, x: Graph[K, N, E, G], y: Graph[K, N, E, G]
|
|
335
|
+
) -> SearchState[K]:
|
|
336
|
+
init_func_params = total_params(self.init_func)
|
|
337
|
+
sim: GraphSim[K]
|
|
338
|
+
|
|
339
|
+
if init_func_params == 4:
|
|
340
|
+
init_func = cast(SearchStateInit[K, N, E, G], self.init_func)
|
|
341
|
+
|
|
342
|
+
return init_func(x, y, self.node_matcher, self.edge_matcher)
|
|
343
|
+
|
|
344
|
+
elif init_func_params == 2:
|
|
345
|
+
init_func = cast(SimFunc[Graph[K, N, E, G], GraphSim[K]], self.init_func)
|
|
346
|
+
sim = init_func(x, y)
|
|
347
|
+
|
|
348
|
+
elif init_func_params == 1:
|
|
349
|
+
init_func = cast(
|
|
350
|
+
BatchSimFunc[Graph[K, N, E, G], GraphSim[K]], self.init_func
|
|
351
|
+
)
|
|
352
|
+
sim = init_func([(x, y)])[0]
|
|
353
|
+
|
|
354
|
+
else:
|
|
355
|
+
raise ValueError(
|
|
356
|
+
f"Invalid number of parameters for init_func: {init_func_params}"
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
return SearchState(
|
|
360
|
+
node_mapping=sim.node_mapping,
|
|
361
|
+
edge_mapping=sim.edge_mapping,
|
|
362
|
+
open_y_nodes=frozenset(y.nodes.keys() - sim.node_mapping.keys()),
|
|
363
|
+
open_y_edges=frozenset(y.edges.keys() - sim.edge_mapping.keys()),
|
|
364
|
+
open_x_nodes=frozenset(x.nodes.keys() - sim.node_mapping.values()),
|
|
365
|
+
open_x_edges=frozenset(x.edges.keys() - sim.edge_mapping.values()),
|
|
366
|
+
)
|
|
367
|
+
|
|
260
368
|
def finished(self, state: SearchState[K]) -> bool:
|
|
261
369
|
# the following condition could save a few iterations, but needs to be tested
|
|
262
370
|
# return (not state.open_y_nodes and not state.open_y_edges) or (
|
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
|
|
3
|
-
from frozendict import frozendict
|
|
4
|
-
|
|
5
3
|
from ...helpers import (
|
|
6
4
|
get_logger,
|
|
7
5
|
)
|
|
@@ -41,20 +39,18 @@ class greedy[K, N, E, G](
|
|
|
41
39
|
# self_inv = dataclasses.replace(self, _invert=True)
|
|
42
40
|
# return self.invert_similarity(x, y, self_inv(x=y, y=x))
|
|
43
41
|
|
|
44
|
-
current_state = SearchState(
|
|
45
|
-
frozendict(),
|
|
46
|
-
frozendict(),
|
|
47
|
-
frozenset(y.nodes.keys()),
|
|
48
|
-
frozenset(y.edges.keys()),
|
|
49
|
-
frozenset(x.nodes.keys()),
|
|
50
|
-
frozenset(x.edges.keys()),
|
|
51
|
-
)
|
|
52
|
-
current_sim = GraphSim(
|
|
53
|
-
0.0, frozendict(), frozendict(), frozendict(), frozendict()
|
|
54
|
-
)
|
|
55
|
-
|
|
56
42
|
node_pair_sims, edge_pair_sims = self.pair_similarities(x, y)
|
|
57
43
|
|
|
44
|
+
current_state = self.init_search_state(x, y)
|
|
45
|
+
current_sim = self.similarity(
|
|
46
|
+
x,
|
|
47
|
+
y,
|
|
48
|
+
current_state.node_mapping,
|
|
49
|
+
current_state.edge_mapping,
|
|
50
|
+
node_pair_sims,
|
|
51
|
+
edge_pair_sims,
|
|
52
|
+
)
|
|
53
|
+
|
|
58
54
|
while not self.finished(current_state):
|
|
59
55
|
# Iterate over all open pairs and find the best pair
|
|
60
56
|
next_states: list[SearchState[K]] = []
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from collections import defaultdict
|
|
2
2
|
from collections.abc import Callable, Mapping, MutableMapping, Sequence
|
|
3
|
-
from dataclasses import dataclass, field
|
|
3
|
+
from dataclasses import InitVar, dataclass, field
|
|
4
4
|
from typing import Any, cast, override
|
|
5
5
|
|
|
6
6
|
from ..helpers import batchify_sim, get_metadata, get_value, getitem_or_getattr
|
|
7
7
|
from ..typing import (
|
|
8
|
+
AggregatorFunc,
|
|
8
9
|
AnySimFunc,
|
|
9
10
|
BatchSimFunc,
|
|
10
11
|
ConversionFunc,
|
|
@@ -14,6 +15,7 @@ from ..typing import (
|
|
|
14
15
|
SimSeq,
|
|
15
16
|
StructuredValue,
|
|
16
17
|
)
|
|
18
|
+
from .aggregator import default_aggregator
|
|
17
19
|
from .generic import static
|
|
18
20
|
|
|
19
21
|
|
|
@@ -59,6 +61,67 @@ def transpose_value[V, S: Float](
|
|
|
59
61
|
return transpose(func, get_value)
|
|
60
62
|
|
|
61
63
|
|
|
64
|
+
@dataclass(slots=True)
|
|
65
|
+
class combine[V, S: Float](BatchSimFunc[V, float]):
|
|
66
|
+
"""Combines multiple similarity functions into one.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
sim_funcs: A list of similarity functions to be combined.
|
|
70
|
+
aggregator: A function to aggregate the results from the similarity functions.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
A similarity function that combines the results from multiple similarity functions.
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
sim_funcs: InitVar[Sequence[AnySimFunc[V, S]] | Mapping[str, AnySimFunc[V, S]]]
|
|
77
|
+
aggregator: AggregatorFunc[str, S] = default_aggregator
|
|
78
|
+
batch_sim_funcs: Sequence[BatchSimFunc[V, S]] | Mapping[str, BatchSimFunc[V, S]] = (
|
|
79
|
+
field(init=False, repr=False)
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
def __post_init__(
|
|
83
|
+
self, sim_funcs: Sequence[AnySimFunc[V, S]] | Mapping[str, AnySimFunc[V, S]]
|
|
84
|
+
):
|
|
85
|
+
if isinstance(sim_funcs, Sequence):
|
|
86
|
+
self.batch_sim_funcs = [batchify_sim(func) for func in sim_funcs]
|
|
87
|
+
elif isinstance(sim_funcs, Mapping):
|
|
88
|
+
self.batch_sim_funcs = {
|
|
89
|
+
key: batchify_sim(func) for key, func in sim_funcs.items()
|
|
90
|
+
}
|
|
91
|
+
else:
|
|
92
|
+
raise ValueError(f"Invalid sim_funcs type: {type(sim_funcs)}")
|
|
93
|
+
|
|
94
|
+
@override
|
|
95
|
+
def __call__(self, batches: Sequence[tuple[V, V]]) -> Sequence[float]:
|
|
96
|
+
if isinstance(self.batch_sim_funcs, Sequence):
|
|
97
|
+
func_results = [func(batches) for func in self.batch_sim_funcs]
|
|
98
|
+
|
|
99
|
+
return [
|
|
100
|
+
self.aggregator(
|
|
101
|
+
[batch_results[batch_idx] for batch_results in func_results]
|
|
102
|
+
)
|
|
103
|
+
for batch_idx in range(len(batches))
|
|
104
|
+
]
|
|
105
|
+
|
|
106
|
+
elif isinstance(self.batch_sim_funcs, Mapping):
|
|
107
|
+
func_results = {
|
|
108
|
+
func_key: func(batches)
|
|
109
|
+
for func_key, func in self.batch_sim_funcs.items()
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
return [
|
|
113
|
+
self.aggregator(
|
|
114
|
+
{
|
|
115
|
+
func_key: batch_results[batch_idx]
|
|
116
|
+
for func_key, batch_results in func_results.items()
|
|
117
|
+
}
|
|
118
|
+
)
|
|
119
|
+
for batch_idx in range(len(batches))
|
|
120
|
+
]
|
|
121
|
+
|
|
122
|
+
raise ValueError(f"Invalid batch_sim_funcs type: {type(self.batch_sim_funcs)}")
|
|
123
|
+
|
|
124
|
+
|
|
62
125
|
@dataclass(slots=True)
|
|
63
126
|
class cache[V, U, S: Float](BatchSimFunc[V, S]):
|
|
64
127
|
similarity_func: BatchSimFunc[V, S]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|