cbrkit 0.26.1__tar.gz → 0.26.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cbrkit-0.26.1 → cbrkit-0.26.3}/PKG-INFO +1 -1
- {cbrkit-0.26.1 → cbrkit-0.26.3}/pyproject.toml +1 -1
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/retrieval/build.py +52 -25
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/retrieval/rerank.py +4 -1
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/graphs/__init__.py +3 -1
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/graphs/astar.py +64 -28
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/graphs/brute_force.py +10 -2
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/graphs/common.py +46 -55
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/graphs/dfs.py +4 -6
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/graphs/greedy.py +0 -6
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/graphs/lap.py +67 -14
- cbrkit-0.26.3/src/cbrkit/sim/graphs/precompute.py +80 -0
- cbrkit-0.26.3/src/cbrkit/sim/graphs/qap.py +145 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/graphs/vf2.py +143 -35
- cbrkit-0.26.1/src/cbrkit/sim/graphs/precompute.py +0 -56
- cbrkit-0.26.1/src/cbrkit/sim/graphs/qap.py +0 -118
- {cbrkit-0.26.1 → cbrkit-0.26.3}/README.md +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/__init__.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/__main__.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/adapt/__init__.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/adapt/attribute_value.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/adapt/generic.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/adapt/numbers.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/adapt/strings.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/api.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/cli.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/constants.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/cycle.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/dumpers.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/eval/__init__.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/eval/common.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/eval/retrieval.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/helpers.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/loaders.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/model/__init__.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/model/graph.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/model/result.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/py.typed +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/retrieval/__init__.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/retrieval/apply.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/reuse/__init__.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/reuse/apply.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/reuse/build.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/__init__.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/aggregator.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/attribute_value.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/collections.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/embed.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/generic.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/graphs/alignment.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/numbers.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/strings.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/taxonomy.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/sim/wrappers.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/__init__.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/apply.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/build.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/model.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/prompts.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/providers/__init__.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/providers/anthropic.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/providers/cohere.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/providers/google.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/providers/instructor.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/providers/model.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/providers/ollama.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/providers/openai.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/synthesis/providers/wrappers.py +0 -0
- {cbrkit-0.26.1 → cbrkit-0.26.3}/src/cbrkit/typing.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: cbrkit
|
|
3
|
-
Version: 0.26.1
|
|
3
|
+
Version: 0.26.3
|
|
4
4
|
Summary: Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI
|
|
5
5
|
Keywords: cbr,case-based reasoning,api,similarity,nlp,retrieval,cli,tool,library
|
|
6
6
|
Author: Mirko Lenz
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import itertools
|
|
2
|
-
from collections.abc import Sequence
|
|
2
|
+
from collections.abc import Mapping, Sequence
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
from multiprocessing.pool import Pool
|
|
5
5
|
from typing import Literal, override
|
|
@@ -129,14 +129,18 @@ class combine[K, V, S: Float](RetrieverFunc[K, V, float]):
|
|
|
129
129
|
retriever_funcs: A list of retriever functions to be combined.
|
|
130
130
|
aggregator: A function to aggregate the results from the retriever functions.
|
|
131
131
|
strategy: The strategy to combine the results. Either "intersection" or "union".
|
|
132
|
+
default_sim: The default similarity value to use for strategy "union" when a case is not found in one of the retriever results.
|
|
132
133
|
|
|
133
134
|
Returns:
|
|
134
135
|
A retriever function that combines the results from multiple retrievers.
|
|
135
136
|
"""
|
|
136
137
|
|
|
137
|
-
retriever_funcs:
|
|
138
|
-
|
|
138
|
+
retriever_funcs: (
|
|
139
|
+
Sequence[RetrieverFunc[K, V, S]] | Mapping[str, RetrieverFunc[K, V, S]]
|
|
140
|
+
)
|
|
141
|
+
aggregator: AggregatorFunc[str, S | float] = default_aggregator
|
|
139
142
|
strategy: Literal["intersection", "union"] = "union"
|
|
143
|
+
default_sim: float = 0.0
|
|
140
144
|
|
|
141
145
|
@override
|
|
142
146
|
def __call__(
|
|
@@ -154,42 +158,65 @@ class combine[K, V, S: Float](RetrieverFunc[K, V, float]):
|
|
|
154
158
|
for batch_idx in range(len(batches))
|
|
155
159
|
]
|
|
156
160
|
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
161
|
+
elif isinstance(self.retriever_funcs, Mapping):
|
|
162
|
+
results = {
|
|
163
|
+
func_key: retriever_func(batches)
|
|
164
|
+
for func_key, retriever_func in self.retriever_funcs.items()
|
|
165
|
+
}
|
|
162
166
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
167
|
+
return [
|
|
168
|
+
self.__call_batch__(
|
|
169
|
+
{
|
|
170
|
+
func_key: func_results[batch_idx]
|
|
171
|
+
for func_key, func_results in results.items()
|
|
172
|
+
}
|
|
173
|
+
)
|
|
174
|
+
for batch_idx in range(len(batches))
|
|
175
|
+
]
|
|
169
176
|
|
|
170
177
|
raise ValueError(f"Invalid retriever_funcs type: {type(self.retriever_funcs)}")
|
|
171
178
|
|
|
172
|
-
def __call_batch__(
|
|
173
|
-
|
|
179
|
+
def __call_batch__(
|
|
180
|
+
self, results: Sequence[SimMap[K, S]] | Mapping[str, SimMap[K, S]]
|
|
181
|
+
) -> SimMap[K, float]:
|
|
182
|
+
case_keys: set[K]
|
|
183
|
+
|
|
184
|
+
if isinstance(results, Sequence):
|
|
185
|
+
if self.strategy == "intersection":
|
|
186
|
+
case_keys = set().intersection(*(result.keys() for result in results))
|
|
187
|
+
elif self.strategy == "union":
|
|
188
|
+
case_keys = set().union(*(result.keys() for result in results))
|
|
189
|
+
else:
|
|
190
|
+
raise ValueError(f"Unknown strategy: {self.strategy}")
|
|
191
|
+
|
|
174
192
|
return {
|
|
175
193
|
case_key: self.aggregator(
|
|
176
|
-
[result
|
|
177
|
-
)
|
|
178
|
-
for case_key in set().intersection(
|
|
179
|
-
*[set(result.keys()) for result in results]
|
|
194
|
+
[result.get(case_key, self.default_sim) for result in results]
|
|
180
195
|
)
|
|
196
|
+
for case_key in case_keys
|
|
181
197
|
}
|
|
182
198
|
|
|
183
|
-
elif
|
|
199
|
+
elif isinstance(results, Mapping):
|
|
200
|
+
if self.strategy == "intersection":
|
|
201
|
+
case_keys = set().intersection(
|
|
202
|
+
*(result.keys() for result in results.values())
|
|
203
|
+
)
|
|
204
|
+
elif self.strategy == "union":
|
|
205
|
+
case_keys = set().union(*(result.keys() for result in results.values()))
|
|
206
|
+
else:
|
|
207
|
+
raise ValueError(f"Unknown strategy: {self.strategy}")
|
|
208
|
+
|
|
184
209
|
return {
|
|
185
210
|
case_key: self.aggregator(
|
|
186
|
-
|
|
211
|
+
{
|
|
212
|
+
func_key: result.get(case_key, self.default_sim)
|
|
213
|
+
for func_key, result in results.items()
|
|
214
|
+
}
|
|
187
215
|
)
|
|
188
|
-
for
|
|
189
|
-
for case_key in result.keys()
|
|
216
|
+
for case_key in case_keys
|
|
190
217
|
}
|
|
191
218
|
|
|
192
|
-
raise ValueError(f"
|
|
219
|
+
raise ValueError(f"Invalid results type: {type(results)}")
|
|
193
220
|
|
|
194
221
|
|
|
195
222
|
@dataclass(slots=True, frozen=True)
|
|
@@ -311,12 +311,15 @@ with optional_dependencies():
|
|
|
311
311
|
k=len(casebase),
|
|
312
312
|
)
|
|
313
313
|
max_score = np.max(scores)
|
|
314
|
+
min_score = np.min(scores)
|
|
314
315
|
|
|
315
316
|
key_index = {idx: key for idx, key in enumerate(casebase)}
|
|
316
317
|
|
|
317
318
|
return [
|
|
318
319
|
{
|
|
319
|
-
key_index[case_id]: float(
|
|
320
|
+
key_index[case_id]: float(
|
|
321
|
+
(score - min_score) / (max_score - min_score)
|
|
322
|
+
)
|
|
320
323
|
for case_id, score in zip(
|
|
321
324
|
results[query_id], scores[query_id], strict=True
|
|
322
325
|
)
|
|
@@ -15,7 +15,7 @@ from .common import (
|
|
|
15
15
|
from .greedy import greedy
|
|
16
16
|
from .lap import lap
|
|
17
17
|
from .precompute import precompute
|
|
18
|
-
from .vf2 import vf2
|
|
18
|
+
from .vf2 import vf2, vf2_networkx, vf2_rustworkx
|
|
19
19
|
|
|
20
20
|
with optional_dependencies():
|
|
21
21
|
from .alignment import dtw
|
|
@@ -34,6 +34,8 @@ __all__ = [
|
|
|
34
34
|
"lap",
|
|
35
35
|
"precompute",
|
|
36
36
|
"vf2",
|
|
37
|
+
"vf2_networkx",
|
|
38
|
+
"vf2_rustworkx",
|
|
37
39
|
"dtw",
|
|
38
40
|
"smith_waterman",
|
|
39
41
|
"init_empty",
|
|
@@ -60,7 +60,6 @@ class SelectionFunc[K, N, E, G](Protocol):
|
|
|
60
60
|
s: SearchState[K],
|
|
61
61
|
node_pair_sims: Mapping[tuple[K, K], float],
|
|
62
62
|
edge_pair_sims: Mapping[tuple[K, K], float],
|
|
63
|
-
heuristic_func: HeuristicFunc[K, N, E, G],
|
|
64
63
|
/,
|
|
65
64
|
) -> None | tuple[K, GraphElementType]: ...
|
|
66
65
|
|
|
@@ -127,7 +126,7 @@ class h3[K, N, E, G](HeuristicFunc[K, N, E, G]):
|
|
|
127
126
|
default=0.0,
|
|
128
127
|
)
|
|
129
128
|
|
|
130
|
-
def
|
|
129
|
+
def can_map(x_node: Node[K, N], y_node: Node[K, N]) -> bool:
|
|
131
130
|
return x_node.key == s.node_mapping.get(y_node.key) or (
|
|
132
131
|
y_node.key in s.open_y_nodes and x_node.key in s.open_x_nodes
|
|
133
132
|
)
|
|
@@ -137,8 +136,8 @@ class h3[K, N, E, G](HeuristicFunc[K, N, E, G]):
|
|
|
137
136
|
(
|
|
138
137
|
edge_pair_sims.get((y_key, x_key), 0.0)
|
|
139
138
|
for x_key in s.open_x_edges
|
|
140
|
-
if
|
|
141
|
-
and
|
|
139
|
+
if can_map(x.edges[x_key].source, y.edges[y_key].source)
|
|
140
|
+
and can_map(x.edges[x_key].target, y.edges[y_key].target)
|
|
142
141
|
),
|
|
143
142
|
default=0.0,
|
|
144
143
|
)
|
|
@@ -155,7 +154,6 @@ class select1[K, N, E, G](SelectionFunc[K, N, E, G]):
|
|
|
155
154
|
s: SearchState[K],
|
|
156
155
|
node_pair_sims: Mapping[tuple[K, K], float],
|
|
157
156
|
edge_pair_sims: Mapping[tuple[K, K], float],
|
|
158
|
-
heuristic_func: HeuristicFunc[K, N, E, G],
|
|
159
157
|
) -> None | tuple[K, GraphElementType]:
|
|
160
158
|
"""Select the next node or edge to be mapped"""
|
|
161
159
|
|
|
@@ -181,7 +179,6 @@ class select2[K, N, E, G](SelectionFunc[K, N, E, G]):
|
|
|
181
179
|
s: SearchState[K],
|
|
182
180
|
node_pair_sims: Mapping[tuple[K, K], float],
|
|
183
181
|
edge_pair_sims: Mapping[tuple[K, K], float],
|
|
184
|
-
heuristic_func: HeuristicFunc[K, N, E, G],
|
|
185
182
|
) -> None | tuple[K, GraphElementType]:
|
|
186
183
|
"""Select the next node or edge to be mapped"""
|
|
187
184
|
|
|
@@ -212,36 +209,58 @@ class select3[K, N, E, G](SelectionFunc[K, N, E, G]):
|
|
|
212
209
|
s: SearchState[K],
|
|
213
210
|
node_pair_sims: Mapping[tuple[K, K], float],
|
|
214
211
|
edge_pair_sims: Mapping[tuple[K, K], float],
|
|
215
|
-
heuristic_func: HeuristicFunc[K, N, E, G],
|
|
216
212
|
) -> None | tuple[K, GraphElementType]:
|
|
217
213
|
"""Select the next node or edge to be mapped"""
|
|
218
214
|
|
|
219
|
-
|
|
215
|
+
mapping_options: dict[tuple[K, GraphElementType], int] = {}
|
|
216
|
+
heuristic_scores: dict[tuple[K, GraphElementType], float] = {}
|
|
220
217
|
|
|
221
218
|
for y_key in s.open_y_nodes:
|
|
222
|
-
|
|
223
|
-
(
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
219
|
+
h_vals = [
|
|
220
|
+
node_pair_sims[(y_key, x_key)]
|
|
221
|
+
for x_key in s.open_x_nodes
|
|
222
|
+
if (y_key, x_key) in node_pair_sims
|
|
223
|
+
]
|
|
224
|
+
|
|
225
|
+
if h_vals:
|
|
226
|
+
mapping_options[(y_key, "node")] = len(h_vals)
|
|
227
|
+
heuristic_scores[(y_key, "node")] = max(h_vals)
|
|
228
|
+
|
|
229
|
+
def can_map(x_node: Node[K, N], y_node: Node[K, N]) -> bool:
|
|
230
|
+
return x_node.key == s.node_mapping.get(y_node.key) or (
|
|
231
|
+
y_node.key in s.open_y_nodes and x_node.key in s.open_x_nodes
|
|
228
232
|
)
|
|
229
233
|
|
|
230
234
|
for y_key in s.open_y_edges:
|
|
231
|
-
|
|
232
|
-
(
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
)
|
|
237
|
-
|
|
235
|
+
h_vals = [
|
|
236
|
+
edge_pair_sims[(y_key, x_key)]
|
|
237
|
+
for x_key in s.open_x_edges
|
|
238
|
+
if (y_key, x_key) in edge_pair_sims
|
|
239
|
+
and can_map(x.edges[x_key].source, y.edges[y_key].source)
|
|
240
|
+
and can_map(x.edges[x_key].target, y.edges[y_key].target)
|
|
241
|
+
]
|
|
242
|
+
|
|
243
|
+
if h_vals:
|
|
244
|
+
mapping_options[(y_key, "edge")] = len(h_vals)
|
|
245
|
+
heuristic_scores[(y_key, "edge")] = max(h_vals)
|
|
238
246
|
|
|
239
247
|
if not heuristic_scores:
|
|
248
|
+
# Fallback: select any remaining node or edge for null mapping
|
|
249
|
+
if s.open_y_nodes:
|
|
250
|
+
return next(iter(s.open_y_nodes)), "node"
|
|
251
|
+
elif s.open_y_edges:
|
|
252
|
+
return next(iter(s.open_y_edges)), "edge"
|
|
240
253
|
return None
|
|
241
254
|
|
|
242
|
-
|
|
255
|
+
max_score = max(heuristic_scores.values())
|
|
256
|
+
best_selections = [
|
|
257
|
+
key for key, value in heuristic_scores.items() if value == max_score
|
|
258
|
+
]
|
|
243
259
|
|
|
244
|
-
|
|
260
|
+
# if multiple selections have the same score, select the one with the lowest number of possible mappings
|
|
261
|
+
best_selection = min(best_selections, key=lambda key: mapping_options[key])
|
|
262
|
+
|
|
263
|
+
selection_key, selection_type = best_selection
|
|
245
264
|
|
|
246
265
|
if selection_type == "edge":
|
|
247
266
|
edge = y.edges[selection_key]
|
|
@@ -300,7 +319,6 @@ class build[K, N, E, G](
|
|
|
300
319
|
state,
|
|
301
320
|
node_pair_sims,
|
|
302
321
|
edge_pair_sims,
|
|
303
|
-
self.heuristic_func,
|
|
304
322
|
)
|
|
305
323
|
|
|
306
324
|
if selection is None:
|
|
@@ -363,19 +381,37 @@ class build[K, N, E, G](
|
|
|
363
381
|
x: Graph[K, N, E, G],
|
|
364
382
|
y: Graph[K, N, E, G],
|
|
365
383
|
) -> GraphSim[K]:
|
|
366
|
-
# if len(y.nodes) + len(y.edges) > len(x.nodes) + len(x.edges):
|
|
367
|
-
# self_inv = dataclasses.replace(self, _invert=True)
|
|
368
|
-
# return self.invert_similarity(x, y, self_inv(x=y, y=x))
|
|
369
|
-
|
|
370
384
|
node_pair_sims, edge_pair_sims = self.pair_similarities(x, y)
|
|
371
385
|
|
|
372
386
|
open_set: list[PriorityState[K]] = []
|
|
373
387
|
best_state = self.init_search_state(x, y)
|
|
388
|
+
# best_similarity = self.similarity(
|
|
389
|
+
# x,
|
|
390
|
+
# y,
|
|
391
|
+
# best_state.node_mapping,
|
|
392
|
+
# best_state.edge_mapping,
|
|
393
|
+
# node_pair_sims,
|
|
394
|
+
# edge_pair_sims,
|
|
395
|
+
# )
|
|
374
396
|
heapq.heappush(open_set, PriorityState(0, best_state))
|
|
375
397
|
|
|
376
398
|
while open_set:
|
|
377
399
|
first_elem = heapq.heappop(open_set)
|
|
378
400
|
current_state = first_elem.state
|
|
401
|
+
# current_similarity = self.similarity(
|
|
402
|
+
# x,
|
|
403
|
+
# y,
|
|
404
|
+
# current_state.node_mapping,
|
|
405
|
+
# current_state.edge_mapping,
|
|
406
|
+
# node_pair_sims,
|
|
407
|
+
# edge_pair_sims,
|
|
408
|
+
# )
|
|
409
|
+
|
|
410
|
+
# not needed because we add null mappings and
|
|
411
|
+
# the first item of the queue is always the best one
|
|
412
|
+
# if current_similarity.value > best_similarity.value:
|
|
413
|
+
# best_state = current_state
|
|
414
|
+
# best_similarity = current_similarity
|
|
379
415
|
|
|
380
416
|
if self.finished(current_state):
|
|
381
417
|
best_state = current_state
|
|
@@ -74,8 +74,16 @@ class brute_force[K, N, E, G](
|
|
|
74
74
|
|
|
75
75
|
if next_sim and (
|
|
76
76
|
next_sim.value > best_sim.value
|
|
77
|
-
or
|
|
78
|
-
|
|
77
|
+
or (
|
|
78
|
+
next_sim.value >= best_sim.value
|
|
79
|
+
and (
|
|
80
|
+
len(next_sim.node_mapping) > len(best_sim.node_mapping)
|
|
81
|
+
or (
|
|
82
|
+
len(next_sim.edge_mapping)
|
|
83
|
+
> len(best_sim.edge_mapping)
|
|
84
|
+
)
|
|
85
|
+
)
|
|
86
|
+
)
|
|
79
87
|
):
|
|
80
88
|
best_sim = next_sim
|
|
81
89
|
|
|
@@ -8,12 +8,11 @@ from frozendict import frozendict
|
|
|
8
8
|
|
|
9
9
|
from ...helpers import (
|
|
10
10
|
batchify_sim,
|
|
11
|
-
reverse_batch_positional,
|
|
12
|
-
reverse_positional,
|
|
13
11
|
total_params,
|
|
14
12
|
unpack_float,
|
|
13
|
+
unpack_floats,
|
|
15
14
|
)
|
|
16
|
-
from ...model.graph import
|
|
15
|
+
from ...model.graph import Graph, Node
|
|
17
16
|
from ...typing import AnySimFunc, BatchSimFunc, Float, SimFunc, StructuredValue
|
|
18
17
|
from ..wrappers import transpose_value
|
|
19
18
|
|
|
@@ -38,26 +37,38 @@ def default_element_matcher(x: Any, y: Any) -> bool:
|
|
|
38
37
|
|
|
39
38
|
@dataclass(slots=True, frozen=True)
|
|
40
39
|
class SemanticEdgeSim[K, N, E]:
|
|
41
|
-
source_weight: float = 0
|
|
42
|
-
target_weight: float = 0
|
|
40
|
+
source_weight: float = 1.0
|
|
41
|
+
target_weight: float = 1.0
|
|
42
|
+
edge_sim_func: AnySimFunc[E, Float] | None = None
|
|
43
43
|
|
|
44
44
|
def __call__(
|
|
45
45
|
self,
|
|
46
|
-
batches: Sequence[tuple[
|
|
46
|
+
batches: Sequence[tuple[E, E, float, float]],
|
|
47
47
|
) -> list[float]:
|
|
48
|
-
source_sims = (
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
48
|
+
source_sims = (source_sim for _, _, source_sim, _ in batches)
|
|
49
|
+
target_sims = (target_sim for _, _, _, target_sim in batches)
|
|
50
|
+
|
|
51
|
+
if self.edge_sim_func is not None:
|
|
52
|
+
edge_sim_func = batchify_sim(self.edge_sim_func)
|
|
53
|
+
edge_sims = unpack_floats(
|
|
54
|
+
edge_sim_func(
|
|
55
|
+
[(x, y) for x, y, _, _ in batches],
|
|
56
|
+
)
|
|
57
|
+
)
|
|
58
|
+
else:
|
|
59
|
+
edge_sims = [1.0] * len(batches)
|
|
60
|
+
|
|
61
|
+
scaling_factor = self.source_weight + self.target_weight
|
|
62
|
+
|
|
63
|
+
if scaling_factor == 0:
|
|
64
|
+
return edge_sims
|
|
56
65
|
|
|
57
66
|
return [
|
|
58
|
-
(
|
|
59
|
-
|
|
60
|
-
for source, target in zip(
|
|
67
|
+
(edge * source * self.source_weight / scaling_factor)
|
|
68
|
+
+ (edge * target * self.target_weight / scaling_factor)
|
|
69
|
+
for source, target, edge in zip(
|
|
70
|
+
source_sims, target_sims, edge_sims, strict=True
|
|
71
|
+
)
|
|
61
72
|
]
|
|
62
73
|
|
|
63
74
|
|
|
@@ -82,37 +93,14 @@ def _induced_edge_mapping[K, N, E, G](
|
|
|
82
93
|
@dataclass(slots=True)
|
|
83
94
|
class BaseGraphSimFunc[K, N, E, G]:
|
|
84
95
|
node_sim_func: AnySimFunc[N, Float]
|
|
85
|
-
edge_sim_func:
|
|
86
|
-
default_edge_sim
|
|
87
|
-
)
|
|
96
|
+
edge_sim_func: SemanticEdgeSim[K, N, E] = default_edge_sim
|
|
88
97
|
node_matcher: ElementMatcher[N] = default_element_matcher
|
|
89
98
|
edge_matcher: ElementMatcher[E] = default_element_matcher
|
|
90
99
|
batch_node_sim_func: BatchSimFunc[Node[K, N], Float] = field(init=False)
|
|
91
|
-
batch_edge_sim_func: (
|
|
92
|
-
BatchSimFunc[Edge[K, N, E], Float] | SemanticEdgeSim[K, N, E]
|
|
93
|
-
) = field(init=False)
|
|
94
|
-
_invert: bool = False
|
|
95
100
|
|
|
96
101
|
def __post_init__(self) -> None:
|
|
97
102
|
self.batch_node_sim_func = batchify_sim(transpose_value(self.node_sim_func))
|
|
98
103
|
|
|
99
|
-
if isinstance(self.edge_sim_func, SemanticEdgeSim):
|
|
100
|
-
self.batch_edge_sim_func = self.edge_sim_func
|
|
101
|
-
else:
|
|
102
|
-
self.batch_edge_sim_func = batchify_sim(self.edge_sim_func)
|
|
103
|
-
|
|
104
|
-
if self._invert:
|
|
105
|
-
self.node_matcher = reverse_positional(self.node_matcher)
|
|
106
|
-
self.edge_matcher = reverse_positional(self.edge_matcher)
|
|
107
|
-
self.batch_node_sim_func = reverse_batch_positional(
|
|
108
|
-
self.batch_node_sim_func
|
|
109
|
-
)
|
|
110
|
-
if not isinstance(self.batch_edge_sim_func, SemanticEdgeSim):
|
|
111
|
-
# semantic edge sim is agnostic to order
|
|
112
|
-
self.batch_edge_sim_func = reverse_batch_positional(
|
|
113
|
-
self.batch_edge_sim_func
|
|
114
|
-
)
|
|
115
|
-
|
|
116
104
|
def induced_edge_mapping(
|
|
117
105
|
self,
|
|
118
106
|
x: Graph[K, N, E, G],
|
|
@@ -163,16 +151,17 @@ class BaseGraphSimFunc[K, N, E, G]:
|
|
|
163
151
|
]
|
|
164
152
|
|
|
165
153
|
edge_pair_values = [(x.edges[x_key], y.edges[y_key]) for y_key, x_key in pairs]
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
154
|
+
edge_pair_sims = self.edge_sim_func(
|
|
155
|
+
[
|
|
156
|
+
(
|
|
157
|
+
x_edge.value,
|
|
158
|
+
y_edge.value,
|
|
159
|
+
node_pair_sims[(y_edge.source.key, x_edge.source.key)],
|
|
160
|
+
node_pair_sims[(y_edge.target.key, x_edge.target.key)],
|
|
161
|
+
)
|
|
162
|
+
for x_edge, y_edge in edge_pair_values
|
|
163
|
+
]
|
|
164
|
+
)
|
|
176
165
|
|
|
177
166
|
return {
|
|
178
167
|
(y_edge.key, x_edge.key): unpack_float(sim)
|
|
@@ -298,16 +287,18 @@ class init_unique_matches[K, N, E, G](SearchStateInit[K, N, E, G]):
|
|
|
298
287
|
edge_matcher: ElementMatcher[E],
|
|
299
288
|
) -> SearchState[K]:
|
|
300
289
|
# pre-populate the mapping with nodes/edges that only have one possible legal mapping
|
|
301
|
-
|
|
290
|
+
y2x_map: defaultdict[K, set[K]] = defaultdict(set)
|
|
291
|
+
x2y_map: defaultdict[K, set[K]] = defaultdict(set)
|
|
302
292
|
|
|
303
293
|
for y_key, x_key in itertools.product(y.nodes.keys(), x.nodes.keys()):
|
|
304
294
|
if node_matcher(x.nodes[x_key].value, y.nodes[y_key].value):
|
|
305
|
-
|
|
295
|
+
y2x_map[y_key].add(x_key)
|
|
296
|
+
x2y_map[x_key].add(y_key)
|
|
306
297
|
|
|
307
298
|
node_mapping: frozendict[K, K] = frozendict(
|
|
308
299
|
(y_key, next(iter(x_keys)))
|
|
309
|
-
for y_key, x_keys in
|
|
310
|
-
if len(x_keys) == 1
|
|
300
|
+
for y_key, x_keys in y2x_map.items()
|
|
301
|
+
if len(x_keys) == 1 and len(x2y_map[next(iter(x_keys))]) == 1
|
|
311
302
|
)
|
|
312
303
|
|
|
313
304
|
edge_mapping: frozendict[K, K] = _induced_edge_mapping(
|
|
@@ -11,8 +11,6 @@ from .common import BaseGraphSimFunc, GraphSim
|
|
|
11
11
|
|
|
12
12
|
logger = get_logger(__name__)
|
|
13
13
|
|
|
14
|
-
__all__ = ["dfs"]
|
|
15
|
-
|
|
16
14
|
|
|
17
15
|
class RootsFunc[K, N, E, G](Protocol):
|
|
18
16
|
"""Support for matching rooted graphs
|
|
@@ -37,10 +35,10 @@ with optional_dependencies():
|
|
|
37
35
|
class dfs[K, N, E, G](
|
|
38
36
|
BaseGraphSimFunc[K, N, E, G], SimFunc[Graph[K, N, E, G], GraphSim[K]]
|
|
39
37
|
):
|
|
40
|
-
node_del_cost: float =
|
|
41
|
-
node_ins_cost: float =
|
|
42
|
-
edge_del_cost: float =
|
|
43
|
-
edge_ins_cost: float =
|
|
38
|
+
node_del_cost: float = 1.0
|
|
39
|
+
node_ins_cost: float = 1.0
|
|
40
|
+
edge_del_cost: float = 1.0
|
|
41
|
+
edge_ins_cost: float = 1.0
|
|
44
42
|
max_iterations: int = 0
|
|
45
43
|
upper_bound: float | None = None
|
|
46
44
|
strictly_decreasing: bool = True
|
|
@@ -11,8 +11,6 @@ from .common import GraphSim, SearchGraphSimFunc, SearchState
|
|
|
11
11
|
|
|
12
12
|
logger = get_logger(__name__)
|
|
13
13
|
|
|
14
|
-
__all__ = ["greedy"]
|
|
15
|
-
|
|
16
14
|
|
|
17
15
|
@dataclass(slots=True)
|
|
18
16
|
class greedy[K, N, E, G](
|
|
@@ -35,10 +33,6 @@ class greedy[K, N, E, G](
|
|
|
35
33
|
x: Graph[K, N, E, G],
|
|
36
34
|
y: Graph[K, N, E, G],
|
|
37
35
|
) -> GraphSim[K]:
|
|
38
|
-
# if len(y.nodes) + len(y.edges) > len(x.nodes) + len(x.edges):
|
|
39
|
-
# self_inv = dataclasses.replace(self, _invert=True)
|
|
40
|
-
# return self.invert_similarity(x, y, self_inv(x=y, y=x))
|
|
41
|
-
|
|
42
36
|
node_pair_sims, edge_pair_sims = self.pair_similarities(x, y)
|
|
43
37
|
|
|
44
38
|
current_state = self.init_search_state(x, y)
|