cbrkit 0.26.0__tar.gz → 0.26.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cbrkit-0.26.0 → cbrkit-0.26.2}/PKG-INFO +1 -1
- {cbrkit-0.26.0 → cbrkit-0.26.2}/pyproject.toml +1 -1
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/retrieval/build.py +36 -21
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/__init__.py +2 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/__init__.py +9 -1
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/astar.py +1 -81
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/common.py +150 -53
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/dfs.py +4 -6
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/greedy.py +9 -19
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/lap.py +67 -14
- cbrkit-0.26.2/src/cbrkit/sim/graphs/precompute.py +80 -0
- cbrkit-0.26.2/src/cbrkit/sim/graphs/qap.py +145 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/vf2.py +143 -35
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/wrappers.py +64 -1
- cbrkit-0.26.0/src/cbrkit/sim/graphs/precompute.py +0 -56
- cbrkit-0.26.0/src/cbrkit/sim/graphs/qap.py +0 -118
- {cbrkit-0.26.0 → cbrkit-0.26.2}/README.md +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/__main__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/adapt/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/adapt/attribute_value.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/adapt/generic.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/adapt/numbers.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/adapt/strings.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/api.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/cli.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/constants.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/cycle.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/dumpers.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/eval/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/eval/common.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/eval/retrieval.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/helpers.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/loaders.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/model/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/model/graph.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/model/result.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/py.typed +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/retrieval/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/retrieval/apply.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/retrieval/rerank.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/reuse/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/reuse/apply.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/reuse/build.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/aggregator.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/attribute_value.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/collections.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/embed.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/generic.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/alignment.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/brute_force.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/numbers.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/strings.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/taxonomy.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/apply.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/build.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/model.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/prompts.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/__init__.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/anthropic.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/cohere.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/google.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/instructor.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/model.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/ollama.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/openai.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/wrappers.py +0 -0
- {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/typing.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: cbrkit
|
|
3
|
-
Version: 0.26.0
|
|
3
|
+
Version: 0.26.2
|
|
4
4
|
Summary: Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI
|
|
5
5
|
Keywords: cbr,case-based reasoning,api,similarity,nlp,retrieval,cli,tool,library
|
|
6
6
|
Author: Mirko Lenz
|
|
@@ -2,7 +2,7 @@ import itertools
|
|
|
2
2
|
from collections.abc import Sequence
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
from multiprocessing.pool import Pool
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Literal, override
|
|
6
6
|
|
|
7
7
|
from ..helpers import (
|
|
8
8
|
batchify_sim,
|
|
@@ -134,44 +134,59 @@ class combine[K, V, S: Float](RetrieverFunc[K, V, float]):
|
|
|
134
134
|
A retriever function that combines the results from multiple retrievers.
|
|
135
135
|
"""
|
|
136
136
|
|
|
137
|
-
retriever_funcs:
|
|
138
|
-
aggregator: AggregatorFunc[
|
|
137
|
+
retriever_funcs: Sequence[RetrieverFunc[K, V, S]]
|
|
138
|
+
aggregator: AggregatorFunc[str, S] = default_aggregator
|
|
139
139
|
strategy: Literal["intersection", "union"] = "union"
|
|
140
140
|
|
|
141
141
|
@override
|
|
142
142
|
def __call__(
|
|
143
143
|
self, batches: Sequence[tuple[Casebase[K, V], V]]
|
|
144
144
|
) -> Sequence[SimMap[K, float]]:
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
145
|
+
if isinstance(self.retriever_funcs, Sequence):
|
|
146
|
+
func_results = [
|
|
147
|
+
retriever_func(batches) for retriever_func in self.retriever_funcs
|
|
148
|
+
]
|
|
149
|
+
|
|
150
|
+
return [
|
|
151
|
+
self.__call_batch__(
|
|
152
|
+
[batch_results[batch_idx] for batch_results in func_results]
|
|
153
|
+
)
|
|
154
|
+
for batch_idx in range(len(batches))
|
|
155
|
+
]
|
|
156
|
+
|
|
157
|
+
# elif isinstance(self.retriever_funcs, Mapping):
|
|
158
|
+
# results = {
|
|
159
|
+
# func_key: retriever_func(batches)
|
|
160
|
+
# for func_key, retriever_func in self.retriever_funcs.items()
|
|
161
|
+
# }
|
|
162
|
+
|
|
163
|
+
# return [
|
|
164
|
+
# self.__call_batch__(
|
|
165
|
+
# {func_key: func_results[batch_idx] for func_key, func_results in results.items()}
|
|
166
|
+
# )
|
|
167
|
+
# for batch_idx in range(len(batches))
|
|
168
|
+
# ]
|
|
169
|
+
|
|
170
|
+
raise ValueError(f"Invalid retriever_funcs type: {type(self.retriever_funcs)}")
|
|
156
171
|
|
|
157
|
-
def __call_batch__(self, results:
|
|
172
|
+
def __call_batch__(self, results: Sequence[SimMap[K, S]]) -> SimMap[K, float]:
|
|
158
173
|
if self.strategy == "intersection":
|
|
159
174
|
return {
|
|
160
|
-
|
|
161
|
-
[result[
|
|
175
|
+
case_key: self.aggregator(
|
|
176
|
+
[result[case_key] for result in results if case_key in result]
|
|
162
177
|
)
|
|
163
|
-
for
|
|
178
|
+
for case_key in set().intersection(
|
|
164
179
|
*[set(result.keys()) for result in results]
|
|
165
180
|
)
|
|
166
181
|
}
|
|
167
182
|
|
|
168
183
|
elif self.strategy == "union":
|
|
169
184
|
return {
|
|
170
|
-
|
|
171
|
-
[result[
|
|
185
|
+
case_key: self.aggregator(
|
|
186
|
+
[result[case_key] for result in results if case_key in result]
|
|
172
187
|
)
|
|
173
188
|
for result in results
|
|
174
|
-
for
|
|
189
|
+
for case_key in result.keys()
|
|
175
190
|
}
|
|
176
191
|
|
|
177
192
|
raise ValueError(f"Unknown strategy: {self.strategy}")
|
|
@@ -15,6 +15,7 @@ from .attribute_value import AttributeValueSim, attribute_value
|
|
|
15
15
|
from .wrappers import (
|
|
16
16
|
attribute_table,
|
|
17
17
|
cache,
|
|
18
|
+
combine,
|
|
18
19
|
dynamic_table,
|
|
19
20
|
table,
|
|
20
21
|
transpose,
|
|
@@ -26,6 +27,7 @@ __all__ = [
|
|
|
26
27
|
"transpose",
|
|
27
28
|
"transpose_value",
|
|
28
29
|
"cache",
|
|
30
|
+
"combine",
|
|
29
31
|
"table",
|
|
30
32
|
"dynamic_table",
|
|
31
33
|
"type_table",
|
|
@@ -7,12 +7,15 @@ from .common import (
|
|
|
7
7
|
GraphSim,
|
|
8
8
|
SearchGraphSimFunc,
|
|
9
9
|
SearchState,
|
|
10
|
+
SearchStateInit,
|
|
10
11
|
SemanticEdgeSim,
|
|
12
|
+
init_empty,
|
|
13
|
+
init_unique_matches,
|
|
11
14
|
)
|
|
12
15
|
from .greedy import greedy
|
|
13
16
|
from .lap import lap
|
|
14
17
|
from .precompute import precompute
|
|
15
|
-
from .vf2 import vf2
|
|
18
|
+
from .vf2 import vf2, vf2_networkx, vf2_rustworkx
|
|
16
19
|
|
|
17
20
|
with optional_dependencies():
|
|
18
21
|
from .alignment import dtw
|
|
@@ -31,12 +34,17 @@ __all__ = [
|
|
|
31
34
|
"lap",
|
|
32
35
|
"precompute",
|
|
33
36
|
"vf2",
|
|
37
|
+
"vf2_networkx",
|
|
38
|
+
"vf2_rustworkx",
|
|
34
39
|
"dtw",
|
|
35
40
|
"smith_waterman",
|
|
41
|
+
"init_empty",
|
|
42
|
+
"init_unique_matches",
|
|
36
43
|
"GraphSim",
|
|
37
44
|
"ElementMatcher",
|
|
38
45
|
"SemanticEdgeSim",
|
|
39
46
|
"BaseGraphSimFunc",
|
|
40
47
|
"SearchGraphSimFunc",
|
|
41
48
|
"SearchState",
|
|
49
|
+
"SearchStateInit",
|
|
42
50
|
]
|
|
@@ -1,12 +1,8 @@
|
|
|
1
1
|
import heapq
|
|
2
|
-
import itertools
|
|
3
|
-
from collections import defaultdict
|
|
4
2
|
from collections.abc import Mapping
|
|
5
3
|
from dataclasses import dataclass, field
|
|
6
4
|
from typing import Protocol
|
|
7
5
|
|
|
8
|
-
from frozendict import frozendict
|
|
9
|
-
|
|
10
6
|
from ...helpers import (
|
|
11
7
|
get_logger,
|
|
12
8
|
unpack_float,
|
|
@@ -18,25 +14,20 @@ from ...model.graph import (
|
|
|
18
14
|
)
|
|
19
15
|
from ...typing import SimFunc
|
|
20
16
|
from .common import (
|
|
21
|
-
ElementMatcher,
|
|
22
17
|
GraphSim,
|
|
23
18
|
SearchGraphSimFunc,
|
|
24
19
|
SearchState,
|
|
25
|
-
_induced_edge_mapping,
|
|
26
20
|
)
|
|
27
21
|
|
|
28
22
|
__all__ = [
|
|
29
23
|
"HeuristicFunc",
|
|
30
24
|
"SelectionFunc",
|
|
31
|
-
"InitFunc",
|
|
32
25
|
"h1",
|
|
33
26
|
"h2",
|
|
34
27
|
"h3",
|
|
35
28
|
"select1",
|
|
36
29
|
"select2",
|
|
37
30
|
"select3",
|
|
38
|
-
"init1",
|
|
39
|
-
"init2",
|
|
40
31
|
"build",
|
|
41
32
|
]
|
|
42
33
|
|
|
@@ -74,17 +65,6 @@ class SelectionFunc[K, N, E, G](Protocol):
|
|
|
74
65
|
) -> None | tuple[K, GraphElementType]: ...
|
|
75
66
|
|
|
76
67
|
|
|
77
|
-
class InitFunc[K, N, E, G](Protocol):
|
|
78
|
-
def __call__(
|
|
79
|
-
self,
|
|
80
|
-
x: Graph[K, N, E, G],
|
|
81
|
-
y: Graph[K, N, E, G],
|
|
82
|
-
node_matcher: ElementMatcher[N],
|
|
83
|
-
edge_matcher: ElementMatcher[E],
|
|
84
|
-
/,
|
|
85
|
-
) -> SearchState[K]: ...
|
|
86
|
-
|
|
87
|
-
|
|
88
68
|
@dataclass(slots=True, frozen=True)
|
|
89
69
|
class h1[K, N, E, G](HeuristicFunc[K, N, E, G]):
|
|
90
70
|
def __call__(
|
|
@@ -274,61 +254,6 @@ class select3[K, N, E, G](SelectionFunc[K, N, E, G]):
|
|
|
274
254
|
return selection_key, selection_type
|
|
275
255
|
|
|
276
256
|
|
|
277
|
-
@dataclass(slots=True, frozen=True)
|
|
278
|
-
class init1[K, N, E, G](InitFunc[K, N, E, G]):
|
|
279
|
-
def __call__(
|
|
280
|
-
self,
|
|
281
|
-
x: Graph[K, N, E, G],
|
|
282
|
-
y: Graph[K, N, E, G],
|
|
283
|
-
node_matcher: ElementMatcher[N],
|
|
284
|
-
edge_matcher: ElementMatcher[E],
|
|
285
|
-
) -> SearchState[K]:
|
|
286
|
-
return SearchState(
|
|
287
|
-
frozendict(),
|
|
288
|
-
frozendict(),
|
|
289
|
-
frozenset(y.nodes.keys()),
|
|
290
|
-
frozenset(y.edges.keys()),
|
|
291
|
-
frozenset(x.nodes.keys()),
|
|
292
|
-
frozenset(x.edges.keys()),
|
|
293
|
-
)
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
@dataclass(slots=True, init=False)
|
|
297
|
-
class init2[K, N, E, G](InitFunc[K, N, E, G]):
|
|
298
|
-
def __call__(
|
|
299
|
-
self,
|
|
300
|
-
x: Graph[K, N, E, G],
|
|
301
|
-
y: Graph[K, N, E, G],
|
|
302
|
-
node_matcher: ElementMatcher[N],
|
|
303
|
-
edge_matcher: ElementMatcher[E],
|
|
304
|
-
) -> SearchState[K]:
|
|
305
|
-
# pre-populate the mapping with nodes/edges that only have one possible legal mapping
|
|
306
|
-
possible_node_mappings: defaultdict[K, set[K]] = defaultdict(set)
|
|
307
|
-
|
|
308
|
-
for y_key, x_key in itertools.product(y.nodes.keys(), x.nodes.keys()):
|
|
309
|
-
if node_matcher(x.nodes[x_key].value, y.nodes[y_key].value):
|
|
310
|
-
possible_node_mappings[y_key].add(x_key)
|
|
311
|
-
|
|
312
|
-
node_mapping: frozendict[K, K] = frozendict(
|
|
313
|
-
(y_key, next(iter(x_keys)))
|
|
314
|
-
for y_key, x_keys in possible_node_mappings.items()
|
|
315
|
-
if len(x_keys) == 1
|
|
316
|
-
)
|
|
317
|
-
|
|
318
|
-
edge_mapping: frozendict[K, K] = _induced_edge_mapping(
|
|
319
|
-
x, y, node_mapping, edge_matcher
|
|
320
|
-
)
|
|
321
|
-
|
|
322
|
-
return SearchState(
|
|
323
|
-
node_mapping,
|
|
324
|
-
edge_mapping,
|
|
325
|
-
frozenset(y.nodes.keys() - node_mapping.keys()),
|
|
326
|
-
frozenset(y.edges.keys() - edge_mapping.keys()),
|
|
327
|
-
frozenset(x.nodes.keys() - node_mapping.values()),
|
|
328
|
-
frozenset(x.edges.keys() - edge_mapping.values()),
|
|
329
|
-
)
|
|
330
|
-
|
|
331
|
-
|
|
332
257
|
@dataclass(slots=True)
|
|
333
258
|
class build[K, N, E, G](
|
|
334
259
|
SearchGraphSimFunc[K, N, E, G], SimFunc[Graph[K, N, E, G], GraphSim[K]]
|
|
@@ -355,7 +280,6 @@ class build[K, N, E, G](
|
|
|
355
280
|
|
|
356
281
|
heuristic_func: HeuristicFunc[K, N, E, G] = field(default_factory=h3)
|
|
357
282
|
selection_func: SelectionFunc[K, N, E, G] = field(default_factory=select3)
|
|
358
|
-
init_func: InitFunc[K, N, E, G] = field(default_factory=init1)
|
|
359
283
|
beam_width: int = 0
|
|
360
284
|
pathlength_weight: int = 0
|
|
361
285
|
|
|
@@ -439,14 +363,10 @@ class build[K, N, E, G](
|
|
|
439
363
|
x: Graph[K, N, E, G],
|
|
440
364
|
y: Graph[K, N, E, G],
|
|
441
365
|
) -> GraphSim[K]:
|
|
442
|
-
# if len(y.nodes) + len(y.edges) > len(x.nodes) + len(x.edges):
|
|
443
|
-
# self_inv = dataclasses.replace(self, _invert=True)
|
|
444
|
-
# return self.invert_similarity(x, y, self_inv(x=y, y=x))
|
|
445
|
-
|
|
446
366
|
node_pair_sims, edge_pair_sims = self.pair_similarities(x, y)
|
|
447
367
|
|
|
448
368
|
open_set: list[PriorityState[K]] = []
|
|
449
|
-
best_state = self.
|
|
369
|
+
best_state = self.init_search_state(x, y)
|
|
450
370
|
heapq.heappush(open_set, PriorityState(0, best_state))
|
|
451
371
|
|
|
452
372
|
while open_set:
|
|
@@ -1,18 +1,19 @@
|
|
|
1
1
|
import itertools
|
|
2
|
+
from collections import defaultdict
|
|
2
3
|
from collections.abc import Mapping, Sequence
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
|
-
from typing import Any, Protocol
|
|
5
|
+
from typing import Any, Protocol, cast
|
|
5
6
|
|
|
6
7
|
from frozendict import frozendict
|
|
7
8
|
|
|
8
9
|
from ...helpers import (
|
|
9
10
|
batchify_sim,
|
|
10
|
-
|
|
11
|
-
reverse_positional,
|
|
11
|
+
total_params,
|
|
12
12
|
unpack_float,
|
|
13
|
+
unpack_floats,
|
|
13
14
|
)
|
|
14
|
-
from ...model.graph import
|
|
15
|
-
from ...typing import AnySimFunc, BatchSimFunc, Float, StructuredValue
|
|
15
|
+
from ...model.graph import Graph, Node
|
|
16
|
+
from ...typing import AnySimFunc, BatchSimFunc, Float, SimFunc, StructuredValue
|
|
16
17
|
from ..wrappers import transpose_value
|
|
17
18
|
|
|
18
19
|
type PairSim[K] = Mapping[tuple[K, K], float]
|
|
@@ -36,26 +37,38 @@ def default_element_matcher(x: Any, y: Any) -> bool:
|
|
|
36
37
|
|
|
37
38
|
@dataclass(slots=True, frozen=True)
|
|
38
39
|
class SemanticEdgeSim[K, N, E]:
|
|
39
|
-
source_weight: float = 0
|
|
40
|
-
target_weight: float = 0
|
|
40
|
+
source_weight: float = 1.0
|
|
41
|
+
target_weight: float = 1.0
|
|
42
|
+
edge_sim_func: AnySimFunc[E, Float] | None = None
|
|
41
43
|
|
|
42
44
|
def __call__(
|
|
43
45
|
self,
|
|
44
|
-
batches: Sequence[tuple[
|
|
46
|
+
batches: Sequence[tuple[E, E, float, float]],
|
|
45
47
|
) -> list[float]:
|
|
46
|
-
source_sims = (
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
48
|
+
source_sims = (source_sim for _, _, source_sim, _ in batches)
|
|
49
|
+
target_sims = (target_sim for _, _, _, target_sim in batches)
|
|
50
|
+
|
|
51
|
+
if self.edge_sim_func is not None:
|
|
52
|
+
edge_sim_func = batchify_sim(self.edge_sim_func)
|
|
53
|
+
edge_sims = unpack_floats(
|
|
54
|
+
edge_sim_func(
|
|
55
|
+
[(x, y) for x, y, _, _ in batches],
|
|
56
|
+
)
|
|
57
|
+
)
|
|
58
|
+
else:
|
|
59
|
+
edge_sims = [1.0] * len(batches)
|
|
60
|
+
|
|
61
|
+
scaling_factor = self.source_weight + self.target_weight
|
|
62
|
+
|
|
63
|
+
if scaling_factor == 0:
|
|
64
|
+
return edge_sims
|
|
54
65
|
|
|
55
66
|
return [
|
|
56
|
-
(
|
|
57
|
-
|
|
58
|
-
for source, target in zip(
|
|
67
|
+
(edge * source * self.source_weight / scaling_factor)
|
|
68
|
+
+ (edge * target * self.target_weight / scaling_factor)
|
|
69
|
+
for source, target, edge in zip(
|
|
70
|
+
source_sims, target_sims, edge_sims, strict=True
|
|
71
|
+
)
|
|
59
72
|
]
|
|
60
73
|
|
|
61
74
|
|
|
@@ -80,37 +93,14 @@ def _induced_edge_mapping[K, N, E, G](
|
|
|
80
93
|
@dataclass(slots=True)
|
|
81
94
|
class BaseGraphSimFunc[K, N, E, G]:
|
|
82
95
|
node_sim_func: AnySimFunc[N, Float]
|
|
83
|
-
edge_sim_func:
|
|
84
|
-
default_edge_sim
|
|
85
|
-
)
|
|
96
|
+
edge_sim_func: SemanticEdgeSim[K, N, E] = default_edge_sim
|
|
86
97
|
node_matcher: ElementMatcher[N] = default_element_matcher
|
|
87
98
|
edge_matcher: ElementMatcher[E] = default_element_matcher
|
|
88
99
|
batch_node_sim_func: BatchSimFunc[Node[K, N], Float] = field(init=False)
|
|
89
|
-
batch_edge_sim_func: (
|
|
90
|
-
BatchSimFunc[Edge[K, N, E], Float] | SemanticEdgeSim[K, N, E]
|
|
91
|
-
) = field(init=False)
|
|
92
|
-
_invert: bool = False
|
|
93
100
|
|
|
94
101
|
def __post_init__(self) -> None:
|
|
95
102
|
self.batch_node_sim_func = batchify_sim(transpose_value(self.node_sim_func))
|
|
96
103
|
|
|
97
|
-
if isinstance(self.edge_sim_func, SemanticEdgeSim):
|
|
98
|
-
self.batch_edge_sim_func = self.edge_sim_func
|
|
99
|
-
else:
|
|
100
|
-
self.batch_edge_sim_func = batchify_sim(self.edge_sim_func)
|
|
101
|
-
|
|
102
|
-
if self._invert:
|
|
103
|
-
self.node_matcher = reverse_positional(self.node_matcher)
|
|
104
|
-
self.edge_matcher = reverse_positional(self.edge_matcher)
|
|
105
|
-
self.batch_node_sim_func = reverse_batch_positional(
|
|
106
|
-
self.batch_node_sim_func
|
|
107
|
-
)
|
|
108
|
-
if not isinstance(self.batch_edge_sim_func, SemanticEdgeSim):
|
|
109
|
-
# semantic edge sim is agnostic to order
|
|
110
|
-
self.batch_edge_sim_func = reverse_batch_positional(
|
|
111
|
-
self.batch_edge_sim_func
|
|
112
|
-
)
|
|
113
|
-
|
|
114
104
|
def induced_edge_mapping(
|
|
115
105
|
self,
|
|
116
106
|
x: Graph[K, N, E, G],
|
|
@@ -161,16 +151,17 @@ class BaseGraphSimFunc[K, N, E, G]:
|
|
|
161
151
|
]
|
|
162
152
|
|
|
163
153
|
edge_pair_values = [(x.edges[x_key], y.edges[y_key]) for y_key, x_key in pairs]
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
154
|
+
edge_pair_sims = self.edge_sim_func(
|
|
155
|
+
[
|
|
156
|
+
(
|
|
157
|
+
x_edge.value,
|
|
158
|
+
y_edge.value,
|
|
159
|
+
node_pair_sims[(y_edge.source.key, x_edge.source.key)],
|
|
160
|
+
node_pair_sims[(y_edge.target.key, x_edge.target.key)],
|
|
161
|
+
)
|
|
162
|
+
for x_edge, y_edge in edge_pair_values
|
|
163
|
+
]
|
|
164
|
+
)
|
|
174
165
|
|
|
175
166
|
return {
|
|
176
167
|
(y_edge.key, x_edge.key): unpack_float(sim)
|
|
@@ -256,7 +247,113 @@ class SearchState[K]:
|
|
|
256
247
|
open_x_edges: frozenset[K]
|
|
257
248
|
|
|
258
249
|
|
|
250
|
+
class SearchStateInit[K, N, E, G](Protocol):
|
|
251
|
+
def __call__(
|
|
252
|
+
self,
|
|
253
|
+
x: Graph[K, N, E, G],
|
|
254
|
+
y: Graph[K, N, E, G],
|
|
255
|
+
node_matcher: ElementMatcher[N],
|
|
256
|
+
edge_matcher: ElementMatcher[E],
|
|
257
|
+
/,
|
|
258
|
+
) -> SearchState[K]: ...
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
@dataclass(slots=True, frozen=True)
|
|
262
|
+
class init_empty[K, N, E, G](SearchStateInit[K, N, E, G]):
|
|
263
|
+
def __call__(
|
|
264
|
+
self,
|
|
265
|
+
x: Graph[K, N, E, G],
|
|
266
|
+
y: Graph[K, N, E, G],
|
|
267
|
+
node_matcher: ElementMatcher[N],
|
|
268
|
+
edge_matcher: ElementMatcher[E],
|
|
269
|
+
) -> SearchState[K]:
|
|
270
|
+
return SearchState(
|
|
271
|
+
frozendict(),
|
|
272
|
+
frozendict(),
|
|
273
|
+
frozenset(y.nodes.keys()),
|
|
274
|
+
frozenset(y.edges.keys()),
|
|
275
|
+
frozenset(x.nodes.keys()),
|
|
276
|
+
frozenset(x.edges.keys()),
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
@dataclass(slots=True, init=False)
|
|
281
|
+
class init_unique_matches[K, N, E, G](SearchStateInit[K, N, E, G]):
|
|
282
|
+
def __call__(
|
|
283
|
+
self,
|
|
284
|
+
x: Graph[K, N, E, G],
|
|
285
|
+
y: Graph[K, N, E, G],
|
|
286
|
+
node_matcher: ElementMatcher[N],
|
|
287
|
+
edge_matcher: ElementMatcher[E],
|
|
288
|
+
) -> SearchState[K]:
|
|
289
|
+
# pre-populate the mapping with nodes/edges that only have one possible legal mapping
|
|
290
|
+
possible_node_mappings: defaultdict[K, set[K]] = defaultdict(set)
|
|
291
|
+
|
|
292
|
+
for y_key, x_key in itertools.product(y.nodes.keys(), x.nodes.keys()):
|
|
293
|
+
if node_matcher(x.nodes[x_key].value, y.nodes[y_key].value):
|
|
294
|
+
possible_node_mappings[y_key].add(x_key)
|
|
295
|
+
|
|
296
|
+
node_mapping: frozendict[K, K] = frozendict(
|
|
297
|
+
(y_key, next(iter(x_keys)))
|
|
298
|
+
for y_key, x_keys in possible_node_mappings.items()
|
|
299
|
+
if len(x_keys) == 1
|
|
300
|
+
)
|
|
301
|
+
|
|
302
|
+
edge_mapping: frozendict[K, K] = _induced_edge_mapping(
|
|
303
|
+
x, y, node_mapping, edge_matcher
|
|
304
|
+
)
|
|
305
|
+
|
|
306
|
+
return SearchState(
|
|
307
|
+
node_mapping,
|
|
308
|
+
edge_mapping,
|
|
309
|
+
frozenset(y.nodes.keys() - node_mapping.keys()),
|
|
310
|
+
frozenset(y.edges.keys() - edge_mapping.keys()),
|
|
311
|
+
frozenset(x.nodes.keys() - node_mapping.values()),
|
|
312
|
+
frozenset(x.edges.keys() - edge_mapping.values()),
|
|
313
|
+
)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
@dataclass(slots=True)
|
|
259
317
|
class SearchGraphSimFunc[K, N, E, G](BaseGraphSimFunc[K, N, E, G]):
|
|
318
|
+
init_func: (
|
|
319
|
+
SearchStateInit[K, N, E, G] | AnySimFunc[Graph[K, N, E, G], GraphSim[K]]
|
|
320
|
+
) = field(default_factory=init_unique_matches)
|
|
321
|
+
|
|
322
|
+
def init_search_state(
|
|
323
|
+
self, x: Graph[K, N, E, G], y: Graph[K, N, E, G]
|
|
324
|
+
) -> SearchState[K]:
|
|
325
|
+
init_func_params = total_params(self.init_func)
|
|
326
|
+
sim: GraphSim[K]
|
|
327
|
+
|
|
328
|
+
if init_func_params == 4:
|
|
329
|
+
init_func = cast(SearchStateInit[K, N, E, G], self.init_func)
|
|
330
|
+
|
|
331
|
+
return init_func(x, y, self.node_matcher, self.edge_matcher)
|
|
332
|
+
|
|
333
|
+
elif init_func_params == 2:
|
|
334
|
+
init_func = cast(SimFunc[Graph[K, N, E, G], GraphSim[K]], self.init_func)
|
|
335
|
+
sim = init_func(x, y)
|
|
336
|
+
|
|
337
|
+
elif init_func_params == 1:
|
|
338
|
+
init_func = cast(
|
|
339
|
+
BatchSimFunc[Graph[K, N, E, G], GraphSim[K]], self.init_func
|
|
340
|
+
)
|
|
341
|
+
sim = init_func([(x, y)])[0]
|
|
342
|
+
|
|
343
|
+
else:
|
|
344
|
+
raise ValueError(
|
|
345
|
+
f"Invalid number of parameters for init_func: {init_func_params}"
|
|
346
|
+
)
|
|
347
|
+
|
|
348
|
+
return SearchState(
|
|
349
|
+
node_mapping=sim.node_mapping,
|
|
350
|
+
edge_mapping=sim.edge_mapping,
|
|
351
|
+
open_y_nodes=frozenset(y.nodes.keys() - sim.node_mapping.keys()),
|
|
352
|
+
open_y_edges=frozenset(y.edges.keys() - sim.edge_mapping.keys()),
|
|
353
|
+
open_x_nodes=frozenset(x.nodes.keys() - sim.node_mapping.values()),
|
|
354
|
+
open_x_edges=frozenset(x.edges.keys() - sim.edge_mapping.values()),
|
|
355
|
+
)
|
|
356
|
+
|
|
260
357
|
def finished(self, state: SearchState[K]) -> bool:
|
|
261
358
|
# the following condition could save a few iterations, but needs to be tested
|
|
262
359
|
# return (not state.open_y_nodes and not state.open_y_edges) or (
|
|
@@ -11,8 +11,6 @@ from .common import BaseGraphSimFunc, GraphSim
|
|
|
11
11
|
|
|
12
12
|
logger = get_logger(__name__)
|
|
13
13
|
|
|
14
|
-
__all__ = ["dfs"]
|
|
15
|
-
|
|
16
14
|
|
|
17
15
|
class RootsFunc[K, N, E, G](Protocol):
|
|
18
16
|
"""Support for matching rooted graphs
|
|
@@ -37,10 +35,10 @@ with optional_dependencies():
|
|
|
37
35
|
class dfs[K, N, E, G](
|
|
38
36
|
BaseGraphSimFunc[K, N, E, G], SimFunc[Graph[K, N, E, G], GraphSim[K]]
|
|
39
37
|
):
|
|
40
|
-
node_del_cost: float =
|
|
41
|
-
node_ins_cost: float =
|
|
42
|
-
edge_del_cost: float =
|
|
43
|
-
edge_ins_cost: float =
|
|
38
|
+
node_del_cost: float = 1.0
|
|
39
|
+
node_ins_cost: float = 1.0
|
|
40
|
+
edge_del_cost: float = 1.0
|
|
41
|
+
edge_ins_cost: float = 1.0
|
|
44
42
|
max_iterations: int = 0
|
|
45
43
|
upper_bound: float | None = None
|
|
46
44
|
strictly_decreasing: bool = True
|
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
|
|
3
|
-
from frozendict import frozendict
|
|
4
|
-
|
|
5
3
|
from ...helpers import (
|
|
6
4
|
get_logger,
|
|
7
5
|
)
|
|
@@ -13,8 +11,6 @@ from .common import GraphSim, SearchGraphSimFunc, SearchState
|
|
|
13
11
|
|
|
14
12
|
logger = get_logger(__name__)
|
|
15
13
|
|
|
16
|
-
__all__ = ["greedy"]
|
|
17
|
-
|
|
18
14
|
|
|
19
15
|
@dataclass(slots=True)
|
|
20
16
|
class greedy[K, N, E, G](
|
|
@@ -37,23 +33,17 @@ class greedy[K, N, E, G](
|
|
|
37
33
|
x: Graph[K, N, E, G],
|
|
38
34
|
y: Graph[K, N, E, G],
|
|
39
35
|
) -> GraphSim[K]:
|
|
40
|
-
|
|
41
|
-
# self_inv = dataclasses.replace(self, _invert=True)
|
|
42
|
-
# return self.invert_similarity(x, y, self_inv(x=y, y=x))
|
|
36
|
+
node_pair_sims, edge_pair_sims = self.pair_similarities(x, y)
|
|
43
37
|
|
|
44
|
-
current_state =
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
38
|
+
current_state = self.init_search_state(x, y)
|
|
39
|
+
current_sim = self.similarity(
|
|
40
|
+
x,
|
|
41
|
+
y,
|
|
42
|
+
current_state.node_mapping,
|
|
43
|
+
current_state.edge_mapping,
|
|
44
|
+
node_pair_sims,
|
|
45
|
+
edge_pair_sims,
|
|
51
46
|
)
|
|
52
|
-
current_sim = GraphSim(
|
|
53
|
-
0.0, frozendict(), frozendict(), frozendict(), frozendict()
|
|
54
|
-
)
|
|
55
|
-
|
|
56
|
-
node_pair_sims, edge_pair_sims = self.pair_similarities(x, y)
|
|
57
47
|
|
|
58
48
|
while not self.finished(current_state):
|
|
59
49
|
# Iterate over all open pairs and find the best pair
|