cbrkit 0.26.2__tar.gz → 0.26.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. {cbrkit-0.26.2 → cbrkit-0.26.4}/PKG-INFO +1 -1
  2. {cbrkit-0.26.2 → cbrkit-0.26.4}/pyproject.toml +1 -1
  3. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/retrieval/build.py +52 -25
  4. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/retrieval/rerank.py +4 -1
  5. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/graphs/astar.py +133 -65
  6. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/graphs/brute_force.py +10 -2
  7. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/graphs/common.py +6 -4
  8. {cbrkit-0.26.2 → cbrkit-0.26.4}/README.md +0 -0
  9. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/__init__.py +0 -0
  10. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/__main__.py +0 -0
  11. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/adapt/__init__.py +0 -0
  12. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/adapt/attribute_value.py +0 -0
  13. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/adapt/generic.py +0 -0
  14. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/adapt/numbers.py +0 -0
  15. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/adapt/strings.py +0 -0
  16. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/api.py +0 -0
  17. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/cli.py +0 -0
  18. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/constants.py +0 -0
  19. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/cycle.py +0 -0
  20. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/dumpers.py +0 -0
  21. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/eval/__init__.py +0 -0
  22. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/eval/common.py +0 -0
  23. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/eval/retrieval.py +0 -0
  24. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/helpers.py +0 -0
  25. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/loaders.py +0 -0
  26. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/model/__init__.py +0 -0
  27. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/model/graph.py +0 -0
  28. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/model/result.py +0 -0
  29. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/py.typed +0 -0
  30. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/retrieval/__init__.py +0 -0
  31. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/retrieval/apply.py +0 -0
  32. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/reuse/__init__.py +0 -0
  33. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/reuse/apply.py +0 -0
  34. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/reuse/build.py +0 -0
  35. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/__init__.py +0 -0
  36. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/aggregator.py +0 -0
  37. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/attribute_value.py +0 -0
  38. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/collections.py +0 -0
  39. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/embed.py +0 -0
  40. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/generic.py +0 -0
  41. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/graphs/__init__.py +0 -0
  42. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/graphs/alignment.py +0 -0
  43. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/graphs/dfs.py +0 -0
  44. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/graphs/greedy.py +0 -0
  45. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/graphs/lap.py +0 -0
  46. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/graphs/precompute.py +0 -0
  47. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/graphs/qap.py +0 -0
  48. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/graphs/vf2.py +0 -0
  49. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/numbers.py +0 -0
  50. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/strings.py +0 -0
  51. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/taxonomy.py +0 -0
  52. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/sim/wrappers.py +0 -0
  53. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/__init__.py +0 -0
  54. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/apply.py +0 -0
  55. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/build.py +0 -0
  56. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/model.py +0 -0
  57. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/prompts.py +0 -0
  58. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/providers/__init__.py +0 -0
  59. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/providers/anthropic.py +0 -0
  60. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/providers/cohere.py +0 -0
  61. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/providers/google.py +0 -0
  62. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/providers/instructor.py +0 -0
  63. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/providers/model.py +0 -0
  64. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/providers/ollama.py +0 -0
  65. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/providers/openai.py +0 -0
  66. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/synthesis/providers/wrappers.py +0 -0
  67. {cbrkit-0.26.2 → cbrkit-0.26.4}/src/cbrkit/typing.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: cbrkit
3
- Version: 0.26.2
3
+ Version: 0.26.4
4
4
  Summary: Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI
5
5
  Keywords: cbr,case-based reasoning,api,similarity,nlp,retrieval,cli,tool,library
6
6
  Author: Mirko Lenz
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "cbrkit"
3
- version = "0.26.2"
3
+ version = "0.26.4"
4
4
  description = "Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI"
5
5
  authors = [{ name = "Mirko Lenz", email = "mirko@mirkolenz.com" }]
6
6
  readme = "README.md"
@@ -1,5 +1,5 @@
1
1
  import itertools
2
- from collections.abc import Sequence
2
+ from collections.abc import Mapping, Sequence
3
3
  from dataclasses import dataclass
4
4
  from multiprocessing.pool import Pool
5
5
  from typing import Literal, override
@@ -129,14 +129,18 @@ class combine[K, V, S: Float](RetrieverFunc[K, V, float]):
129
129
  retriever_funcs: A list of retriever functions to be combined.
130
130
  aggregator: A function to aggregate the results from the retriever functions.
131
131
  strategy: The strategy to combine the results. Either "intersection" or "union".
132
+ default_sim: The default similarity value to use for strategy "union" when a case is not found in one of the retriever results.
132
133
 
133
134
  Returns:
134
135
  A retriever function that combines the results from multiple retrievers.
135
136
  """
136
137
 
137
- retriever_funcs: Sequence[RetrieverFunc[K, V, S]]
138
- aggregator: AggregatorFunc[str, S] = default_aggregator
138
+ retriever_funcs: (
139
+ Sequence[RetrieverFunc[K, V, S]] | Mapping[str, RetrieverFunc[K, V, S]]
140
+ )
141
+ aggregator: AggregatorFunc[str, S | float] = default_aggregator
139
142
  strategy: Literal["intersection", "union"] = "union"
143
+ default_sim: float = 0.0
140
144
 
141
145
  @override
142
146
  def __call__(
@@ -154,42 +158,65 @@ class combine[K, V, S: Float](RetrieverFunc[K, V, float]):
154
158
  for batch_idx in range(len(batches))
155
159
  ]
156
160
 
157
- # elif isinstance(self.retriever_funcs, Mapping):
158
- # results = {
159
- # func_key: retriever_func(batches)
160
- # for func_key, retriever_func in self.retriever_funcs.items()
161
- # }
161
+ elif isinstance(self.retriever_funcs, Mapping):
162
+ results = {
163
+ func_key: retriever_func(batches)
164
+ for func_key, retriever_func in self.retriever_funcs.items()
165
+ }
162
166
 
163
- # return [
164
- # self.__call_batch__(
165
- # {func_key: func_results[batch_idx] for func_key, func_results in results.items()}
166
- # )
167
- # for batch_idx in range(len(batches))
168
- # ]
167
+ return [
168
+ self.__call_batch__(
169
+ {
170
+ func_key: func_results[batch_idx]
171
+ for func_key, func_results in results.items()
172
+ }
173
+ )
174
+ for batch_idx in range(len(batches))
175
+ ]
169
176
 
170
177
  raise ValueError(f"Invalid retriever_funcs type: {type(self.retriever_funcs)}")
171
178
 
172
- def __call_batch__(self, results: Sequence[SimMap[K, S]]) -> SimMap[K, float]:
173
- if self.strategy == "intersection":
179
+ def __call_batch__(
180
+ self, results: Sequence[SimMap[K, S]] | Mapping[str, SimMap[K, S]]
181
+ ) -> SimMap[K, float]:
182
+ case_keys: set[K]
183
+
184
+ if isinstance(results, Sequence):
185
+ if self.strategy == "intersection":
186
+ case_keys = set().intersection(*(result.keys() for result in results))
187
+ elif self.strategy == "union":
188
+ case_keys = set().union(*(result.keys() for result in results))
189
+ else:
190
+ raise ValueError(f"Unknown strategy: {self.strategy}")
191
+
174
192
  return {
175
193
  case_key: self.aggregator(
176
- [result[case_key] for result in results if case_key in result]
177
- )
178
- for case_key in set().intersection(
179
- *[set(result.keys()) for result in results]
194
+ [result.get(case_key, self.default_sim) for result in results]
180
195
  )
196
+ for case_key in case_keys
181
197
  }
182
198
 
183
- elif self.strategy == "union":
199
+ elif isinstance(results, Mapping):
200
+ if self.strategy == "intersection":
201
+ case_keys = set().intersection(
202
+ *(result.keys() for result in results.values())
203
+ )
204
+ elif self.strategy == "union":
205
+ case_keys = set().union(*(result.keys() for result in results.values()))
206
+ else:
207
+ raise ValueError(f"Unknown strategy: {self.strategy}")
208
+
184
209
  return {
185
210
  case_key: self.aggregator(
186
- [result[case_key] for result in results if case_key in result]
211
+ {
212
+ func_key: result.get(case_key, self.default_sim)
213
+ for func_key, result in results.items()
214
+ }
187
215
  )
188
- for result in results
189
- for case_key in result.keys()
216
+ for case_key in case_keys
190
217
  }
191
218
 
192
- raise ValueError(f"Unknown strategy: {self.strategy}")
219
+ raise ValueError(f"Invalid results type: {type(results)}")
193
220
 
194
221
 
195
222
  @dataclass(slots=True, frozen=True)
@@ -311,12 +311,15 @@ with optional_dependencies():
311
311
  k=len(casebase),
312
312
  )
313
313
  max_score = np.max(scores)
314
+ min_score = np.min(scores)
314
315
 
315
316
  key_index = {idx: key for idx, key in enumerate(casebase)}
316
317
 
317
318
  return [
318
319
  {
319
- key_index[case_id]: float(score / max_score)
320
+ key_index[case_id]: float(
321
+ (score - min_score) / (max_score - min_score)
322
+ )
320
323
  for case_id, score in zip(
321
324
  results[query_id], scores[query_id], strict=True
322
325
  )
@@ -1,7 +1,7 @@
1
1
  import heapq
2
- from collections.abc import Mapping
2
+ from collections.abc import Collection, Mapping
3
3
  from dataclasses import dataclass, field
4
- from typing import Protocol
4
+ from typing import Any, Protocol, cast
5
5
 
6
6
  from ...helpers import (
7
7
  get_logger,
@@ -34,6 +34,33 @@ __all__ = [
34
34
  logger = get_logger(__name__)
35
35
 
36
36
 
37
+ def next_elem[K](elements: Collection[K]) -> K:
38
+ """Select the next element from a set deterministically.
39
+
40
+ If elements are sortable, returns the smallest one.
41
+ Otherwise, returns the first element from iteration.
42
+
43
+ Args:
44
+ elements: Set of elements to choose from
45
+
46
+ Returns:
47
+ A single element from the set
48
+
49
+ Raises:
50
+ ValueError: If the set is empty
51
+ """
52
+ if not elements:
53
+ raise ValueError("Cannot select from empty set")
54
+
55
+ if len(elements) == 1:
56
+ return next(iter(elements))
57
+
58
+ try:
59
+ return min(cast(Collection[Any], elements))
60
+ except TypeError:
61
+ return next(iter(elements))
62
+
63
+
37
64
  @dataclass(slots=True, frozen=True, order=True)
38
65
  class PriorityState[K]:
39
66
  priority: float
@@ -60,7 +87,6 @@ class SelectionFunc[K, N, E, G](Protocol):
60
87
  s: SearchState[K],
61
88
  node_pair_sims: Mapping[tuple[K, K], float],
62
89
  edge_pair_sims: Mapping[tuple[K, K], float],
63
- heuristic_func: HeuristicFunc[K, N, E, G],
64
90
  /,
65
91
  ) -> None | tuple[K, GraphElementType]: ...
66
92
 
@@ -127,7 +153,7 @@ class h3[K, N, E, G](HeuristicFunc[K, N, E, G]):
127
153
  default=0.0,
128
154
  )
129
155
 
130
- def mapping_possible(x_node: Node[K, N], y_node: Node[K, N]) -> bool:
156
+ def can_map(x_node: Node[K, N], y_node: Node[K, N]) -> bool:
131
157
  return x_node.key == s.node_mapping.get(y_node.key) or (
132
158
  y_node.key in s.open_y_nodes and x_node.key in s.open_x_nodes
133
159
  )
@@ -137,8 +163,8 @@ class h3[K, N, E, G](HeuristicFunc[K, N, E, G]):
137
163
  (
138
164
  edge_pair_sims.get((y_key, x_key), 0.0)
139
165
  for x_key in s.open_x_edges
140
- if mapping_possible(x.edges[x_key].source, y.edges[y_key].source)
141
- and mapping_possible(x.edges[x_key].target, y.edges[y_key].target)
166
+ if can_map(x.edges[x_key].source, y.edges[y_key].source)
167
+ and can_map(x.edges[x_key].target, y.edges[y_key].target)
142
168
  ),
143
169
  default=0.0,
144
170
  )
@@ -155,19 +181,14 @@ class select1[K, N, E, G](SelectionFunc[K, N, E, G]):
155
181
  s: SearchState[K],
156
182
  node_pair_sims: Mapping[tuple[K, K], float],
157
183
  edge_pair_sims: Mapping[tuple[K, K], float],
158
- heuristic_func: HeuristicFunc[K, N, E, G],
159
184
  ) -> None | tuple[K, GraphElementType]:
160
185
  """Select the next node or edge to be mapped"""
161
186
 
162
- try:
163
- return next(iter(s.open_y_nodes)), "node"
164
- except StopIteration:
165
- pass
187
+ if s.open_y_nodes:
188
+ return next_elem(s.open_y_nodes), "node"
166
189
 
167
- try:
168
- return next(iter(s.open_y_edges)), "edge"
169
- except StopIteration:
170
- pass
190
+ if s.open_y_edges:
191
+ return next_elem(s.open_y_edges), "edge"
171
192
 
172
193
  return None
173
194
 
@@ -181,24 +202,21 @@ class select2[K, N, E, G](SelectionFunc[K, N, E, G]):
181
202
  s: SearchState[K],
182
203
  node_pair_sims: Mapping[tuple[K, K], float],
183
204
  edge_pair_sims: Mapping[tuple[K, K], float],
184
- heuristic_func: HeuristicFunc[K, N, E, G],
185
205
  ) -> None | tuple[K, GraphElementType]:
186
206
  """Select the next node or edge to be mapped"""
187
207
 
188
- try:
189
- return next(
190
- key
191
- for key in s.open_y_edges
192
- if y.edges[key].source.key not in s.open_y_nodes
193
- and y.edges[key].target.key not in s.open_y_nodes
194
- ), "edge"
195
- except StopIteration:
196
- pass
208
+ edge_candidates = {
209
+ key
210
+ for key in s.open_y_edges
211
+ if y.edges[key].source.key not in s.open_y_nodes
212
+ and y.edges[key].target.key not in s.open_y_nodes
213
+ }
197
214
 
198
- try:
199
- return next(iter(s.open_y_nodes)), "node"
200
- except StopIteration:
201
- pass
215
+ if edge_candidates:
216
+ return next_elem(edge_candidates), "edge"
217
+
218
+ if s.open_y_nodes:
219
+ return next_elem(s.open_y_nodes), "node"
202
220
 
203
221
  return None
204
222
 
@@ -212,36 +230,75 @@ class select3[K, N, E, G](SelectionFunc[K, N, E, G]):
212
230
  s: SearchState[K],
213
231
  node_pair_sims: Mapping[tuple[K, K], float],
214
232
  edge_pair_sims: Mapping[tuple[K, K], float],
215
- heuristic_func: HeuristicFunc[K, N, E, G],
216
233
  ) -> None | tuple[K, GraphElementType]:
217
234
  """Select the next node or edge to be mapped"""
218
235
 
219
- heuristic_scores: list[tuple[K, GraphElementType, float]] = []
236
+ mapping_options: dict[tuple[K, GraphElementType], int] = {}
237
+ heuristic_scores: dict[tuple[K, GraphElementType], float] = {}
220
238
 
221
239
  for y_key in s.open_y_nodes:
222
- heuristic_scores.append(
223
- (
224
- y_key,
225
- "node",
226
- heuristic_func(x, y, s, node_pair_sims, edge_pair_sims),
227
- )
240
+ h_vals = [
241
+ node_pair_sims[(y_key, x_key)]
242
+ for x_key in s.open_x_nodes
243
+ if (y_key, x_key) in node_pair_sims
244
+ ]
245
+
246
+ if h_vals:
247
+ mapping_options[(y_key, "node")] = len(h_vals)
248
+ heuristic_scores[(y_key, "node")] = max(h_vals)
249
+
250
+ def can_map(x_node: Node[K, N], y_node: Node[K, N]) -> bool:
251
+ return x_node.key == s.node_mapping.get(y_node.key) or (
252
+ y_node.key in s.open_y_nodes and x_node.key in s.open_x_nodes
228
253
  )
229
254
 
230
255
  for y_key in s.open_y_edges:
231
- heuristic_scores.append(
232
- (
233
- y_key,
234
- "edge",
235
- heuristic_func(x, y, s, node_pair_sims, edge_pair_sims),
236
- )
237
- )
256
+ h_vals = [
257
+ edge_pair_sims[(y_key, x_key)]
258
+ for x_key in s.open_x_edges
259
+ if (y_key, x_key) in edge_pair_sims
260
+ and can_map(x.edges[x_key].source, y.edges[y_key].source)
261
+ and can_map(x.edges[x_key].target, y.edges[y_key].target)
262
+ ]
263
+
264
+ if h_vals:
265
+ mapping_options[(y_key, "edge")] = len(h_vals)
266
+ heuristic_scores[(y_key, "edge")] = max(h_vals)
238
267
 
239
268
  if not heuristic_scores:
269
+ # Fallback: select any remaining node or edge for null mapping
270
+ # Use sorted to ensure deterministic selection
271
+ if s.open_y_nodes:
272
+ return next_elem(s.open_y_nodes), "node"
273
+ elif s.open_y_edges:
274
+ return next_elem(s.open_y_edges), "edge"
240
275
  return None
241
276
 
242
- best_selection = max(heuristic_scores, key=lambda x: x[2])
277
+ # Find the maximum heuristic score
278
+ max_score = max(heuristic_scores.values())
279
+ best_selections = {
280
+ key for key, value in heuristic_scores.items() if value == max_score
281
+ }
243
282
 
244
- selection_key, selection_type, _ = best_selection
283
+ # if multiple selections have the same score, select the one with the lowest number of possible mappings
284
+ if len(best_selections) > 1:
285
+ min_mapping_options = min(mapping_options[key] for key in best_selections)
286
+ best_selections = {
287
+ key
288
+ for key in best_selections
289
+ if mapping_options[key] == min_mapping_options
290
+ }
291
+
292
+ # select the one with the lowest key
293
+ try:
294
+ best_selection = min(
295
+ best_selections,
296
+ key=lambda item: cast(Any, item[0]),
297
+ )
298
+ except TypeError:
299
+ best_selection = next(iter(best_selections))
300
+
301
+ selection_key, selection_type = best_selection
245
302
 
246
303
  if selection_type == "edge":
247
304
  edge = y.edges[selection_key]
@@ -271,7 +328,7 @@ class build[K, N, E, G](
271
328
  beam_width: Limits the queue size which prunes the search space.
272
329
  This leads to a faster search and less memory usage but also introduces a similarity error.
273
330
  Disabled by default. Based on [Neuhaus et al. (2006)](https://doi.org/10.1007/11815921_17).
274
- pathlength_weight: Add a penalty for states with few mapped elements that already have a low similarity.
331
+ pathlength_weight: Favor long partial edit paths over shorter ones.
275
332
  Disabled by default. Based on [Neuhaus et al. (2006)](https://doi.org/10.1007/11815921_17).
276
333
 
277
334
  Returns:
@@ -300,7 +357,6 @@ class build[K, N, E, G](
300
357
  state,
301
358
  node_pair_sims,
302
359
  edge_pair_sims,
303
- self.heuristic_func,
304
360
  )
305
361
 
306
362
  if selection is None:
@@ -338,22 +394,11 @@ class build[K, N, E, G](
338
394
  prio = 1 - (past_sim + future_sim)
339
395
 
340
396
  if self.pathlength_weight > 0:
341
- node_null_mapping = (
342
- set(y.nodes.keys())
343
- - set(state.node_mapping.keys())
344
- - set(state.open_y_nodes)
345
- )
346
- edge_null_mapping = (
347
- set(y.edges.keys())
348
- - set(state.edge_mapping.keys())
349
- - set(state.open_y_edges)
350
- )
351
- num_paths = (
352
- len(state.node_mapping)
353
- + len(state.edge_mapping)
354
- + len(node_null_mapping)
355
- + len(edge_null_mapping)
356
- )
397
+ # Calculate the number of mapping decisions made so far (partial edit path length)
398
+ # This includes actual mappings plus null mappings (elements processed but not mapped)
399
+ total_y_elements = len(y.nodes) + len(y.edges)
400
+ open_y_elements = len(state.open_y_nodes) + len(state.open_y_edges)
401
+ num_paths = total_y_elements - open_y_elements
357
402
  return prio / (self.pathlength_weight**num_paths)
358
403
 
359
404
  return prio
@@ -367,11 +412,33 @@ class build[K, N, E, G](
367
412
 
368
413
  open_set: list[PriorityState[K]] = []
369
414
  best_state = self.init_search_state(x, y)
415
+ # best_similarity = self.similarity(
416
+ # x,
417
+ # y,
418
+ # best_state.node_mapping,
419
+ # best_state.edge_mapping,
420
+ # node_pair_sims,
421
+ # edge_pair_sims,
422
+ # )
370
423
  heapq.heappush(open_set, PriorityState(0, best_state))
371
424
 
372
425
  while open_set:
373
426
  first_elem = heapq.heappop(open_set)
374
427
  current_state = first_elem.state
428
+ # current_similarity = self.similarity(
429
+ # x,
430
+ # y,
431
+ # current_state.node_mapping,
432
+ # current_state.edge_mapping,
433
+ # node_pair_sims,
434
+ # edge_pair_sims,
435
+ # )
436
+
437
+ # not needed because we add null mappings and
438
+ # the first item of the queue is always the best one
439
+ # if current_similarity.value > best_similarity.value:
440
+ # best_state = current_state
441
+ # best_similarity = current_similarity
375
442
 
376
443
  if self.finished(current_state):
377
444
  best_state = current_state
@@ -392,7 +459,8 @@ class build[K, N, E, G](
392
459
  heapq.heappush(open_set, PriorityState(next_prio, next_state))
393
460
 
394
461
  if self.beam_width > 0 and len(open_set) > self.beam_width:
395
- open_set = open_set[: self.beam_width]
462
+ open_set = heapq.nsmallest(self.beam_width, open_set)
463
+ heapq.heapify(open_set)
396
464
 
397
465
  return self.similarity(
398
466
  x,
@@ -74,8 +74,16 @@ class brute_force[K, N, E, G](
74
74
 
75
75
  if next_sim and (
76
76
  next_sim.value > best_sim.value
77
- or len(next_sim.node_mapping) > len(best_sim.node_mapping)
78
- or len(next_sim.edge_mapping) > len(best_sim.edge_mapping)
77
+ or (
78
+ next_sim.value >= best_sim.value
79
+ and (
80
+ len(next_sim.node_mapping) > len(best_sim.node_mapping)
81
+ or (
82
+ len(next_sim.edge_mapping)
83
+ > len(best_sim.edge_mapping)
84
+ )
85
+ )
86
+ )
79
87
  ):
80
88
  best_sim = next_sim
81
89
 
@@ -287,16 +287,18 @@ class init_unique_matches[K, N, E, G](SearchStateInit[K, N, E, G]):
287
287
  edge_matcher: ElementMatcher[E],
288
288
  ) -> SearchState[K]:
289
289
  # pre-populate the mapping with nodes/edges that only have one possible legal mapping
290
- possible_node_mappings: defaultdict[K, set[K]] = defaultdict(set)
290
+ y2x_map: defaultdict[K, set[K]] = defaultdict(set)
291
+ x2y_map: defaultdict[K, set[K]] = defaultdict(set)
291
292
 
292
293
  for y_key, x_key in itertools.product(y.nodes.keys(), x.nodes.keys()):
293
294
  if node_matcher(x.nodes[x_key].value, y.nodes[y_key].value):
294
- possible_node_mappings[y_key].add(x_key)
295
+ y2x_map[y_key].add(x_key)
296
+ x2y_map[x_key].add(y_key)
295
297
 
296
298
  node_mapping: frozendict[K, K] = frozendict(
297
299
  (y_key, next(iter(x_keys)))
298
- for y_key, x_keys in possible_node_mappings.items()
299
- if len(x_keys) == 1
300
+ for y_key, x_keys in y2x_map.items()
301
+ if len(x_keys) == 1 and len(x2y_map[next(iter(x_keys))]) == 1
300
302
  )
301
303
 
302
304
  edge_mapping: frozendict[K, K] = _induced_edge_mapping(
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes