cbrkit 0.26.0__tar.gz → 0.26.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. {cbrkit-0.26.0 → cbrkit-0.26.2}/PKG-INFO +1 -1
  2. {cbrkit-0.26.0 → cbrkit-0.26.2}/pyproject.toml +1 -1
  3. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/retrieval/build.py +36 -21
  4. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/__init__.py +2 -0
  5. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/__init__.py +9 -1
  6. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/astar.py +1 -81
  7. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/common.py +150 -53
  8. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/dfs.py +4 -6
  9. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/greedy.py +9 -19
  10. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/lap.py +67 -14
  11. cbrkit-0.26.2/src/cbrkit/sim/graphs/precompute.py +80 -0
  12. cbrkit-0.26.2/src/cbrkit/sim/graphs/qap.py +145 -0
  13. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/vf2.py +143 -35
  14. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/wrappers.py +64 -1
  15. cbrkit-0.26.0/src/cbrkit/sim/graphs/precompute.py +0 -56
  16. cbrkit-0.26.0/src/cbrkit/sim/graphs/qap.py +0 -118
  17. {cbrkit-0.26.0 → cbrkit-0.26.2}/README.md +0 -0
  18. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/__init__.py +0 -0
  19. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/__main__.py +0 -0
  20. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/adapt/__init__.py +0 -0
  21. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/adapt/attribute_value.py +0 -0
  22. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/adapt/generic.py +0 -0
  23. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/adapt/numbers.py +0 -0
  24. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/adapt/strings.py +0 -0
  25. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/api.py +0 -0
  26. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/cli.py +0 -0
  27. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/constants.py +0 -0
  28. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/cycle.py +0 -0
  29. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/dumpers.py +0 -0
  30. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/eval/__init__.py +0 -0
  31. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/eval/common.py +0 -0
  32. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/eval/retrieval.py +0 -0
  33. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/helpers.py +0 -0
  34. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/loaders.py +0 -0
  35. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/model/__init__.py +0 -0
  36. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/model/graph.py +0 -0
  37. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/model/result.py +0 -0
  38. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/py.typed +0 -0
  39. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/retrieval/__init__.py +0 -0
  40. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/retrieval/apply.py +0 -0
  41. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/retrieval/rerank.py +0 -0
  42. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/reuse/__init__.py +0 -0
  43. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/reuse/apply.py +0 -0
  44. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/reuse/build.py +0 -0
  45. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/aggregator.py +0 -0
  46. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/attribute_value.py +0 -0
  47. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/collections.py +0 -0
  48. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/embed.py +0 -0
  49. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/generic.py +0 -0
  50. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/alignment.py +0 -0
  51. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/graphs/brute_force.py +0 -0
  52. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/numbers.py +0 -0
  53. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/strings.py +0 -0
  54. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/sim/taxonomy.py +0 -0
  55. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/__init__.py +0 -0
  56. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/apply.py +0 -0
  57. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/build.py +0 -0
  58. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/model.py +0 -0
  59. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/prompts.py +0 -0
  60. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/__init__.py +0 -0
  61. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/anthropic.py +0 -0
  62. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/cohere.py +0 -0
  63. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/google.py +0 -0
  64. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/instructor.py +0 -0
  65. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/model.py +0 -0
  66. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/ollama.py +0 -0
  67. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/openai.py +0 -0
  68. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/synthesis/providers/wrappers.py +0 -0
  69. {cbrkit-0.26.0 → cbrkit-0.26.2}/src/cbrkit/typing.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: cbrkit
3
- Version: 0.26.0
3
+ Version: 0.26.2
4
4
  Summary: Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI
5
5
  Keywords: cbr,case-based reasoning,api,similarity,nlp,retrieval,cli,tool,library
6
6
  Author: Mirko Lenz
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "cbrkit"
3
- version = "0.26.0"
3
+ version = "0.26.2"
4
4
  description = "Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI"
5
5
  authors = [{ name = "Mirko Lenz", email = "mirko@mirkolenz.com" }]
6
6
  readme = "README.md"
@@ -2,7 +2,7 @@ import itertools
2
2
  from collections.abc import Sequence
3
3
  from dataclasses import dataclass
4
4
  from multiprocessing.pool import Pool
5
- from typing import Any, Literal, override
5
+ from typing import Literal, override
6
6
 
7
7
  from ..helpers import (
8
8
  batchify_sim,
@@ -134,44 +134,59 @@ class combine[K, V, S: Float](RetrieverFunc[K, V, float]):
134
134
  A retriever function that combines the results from multiple retrievers.
135
135
  """
136
136
 
137
- retriever_funcs: list[RetrieverFunc[K, V, S]]
138
- aggregator: AggregatorFunc[Any, S] = default_aggregator
137
+ retriever_funcs: Sequence[RetrieverFunc[K, V, S]]
138
+ aggregator: AggregatorFunc[str, S] = default_aggregator
139
139
  strategy: Literal["intersection", "union"] = "union"
140
140
 
141
141
  @override
142
142
  def __call__(
143
143
  self, batches: Sequence[tuple[Casebase[K, V], V]]
144
144
  ) -> Sequence[SimMap[K, float]]:
145
- results = [retriever_func(batches) for retriever_func in self.retriever_funcs]
146
-
147
- return [
148
- self.__call_batch__(
149
- [
150
- results[retriever_idx][batch_idx]
151
- for retriever_idx in range(len(self.retriever_funcs))
152
- ]
153
- )
154
- for batch_idx in range(len(batches))
155
- ]
145
+ if isinstance(self.retriever_funcs, Sequence):
146
+ func_results = [
147
+ retriever_func(batches) for retriever_func in self.retriever_funcs
148
+ ]
149
+
150
+ return [
151
+ self.__call_batch__(
152
+ [batch_results[batch_idx] for batch_results in func_results]
153
+ )
154
+ for batch_idx in range(len(batches))
155
+ ]
156
+
157
+ # elif isinstance(self.retriever_funcs, Mapping):
158
+ # results = {
159
+ # func_key: retriever_func(batches)
160
+ # for func_key, retriever_func in self.retriever_funcs.items()
161
+ # }
162
+
163
+ # return [
164
+ # self.__call_batch__(
165
+ # {func_key: func_results[batch_idx] for func_key, func_results in results.items()}
166
+ # )
167
+ # for batch_idx in range(len(batches))
168
+ # ]
169
+
170
+ raise ValueError(f"Invalid retriever_funcs type: {type(self.retriever_funcs)}")
156
171
 
157
- def __call_batch__(self, results: list[SimMap[K, S]]) -> SimMap[K, float]:
172
+ def __call_batch__(self, results: Sequence[SimMap[K, S]]) -> SimMap[K, float]:
158
173
  if self.strategy == "intersection":
159
174
  return {
160
- key: self.aggregator(
161
- [result[key] for result in results if key in result]
175
+ case_key: self.aggregator(
176
+ [result[case_key] for result in results if case_key in result]
162
177
  )
163
- for key in set().intersection(
178
+ for case_key in set().intersection(
164
179
  *[set(result.keys()) for result in results]
165
180
  )
166
181
  }
167
182
 
168
183
  elif self.strategy == "union":
169
184
  return {
170
- key: self.aggregator(
171
- [result[key] for result in results if key in result]
185
+ case_key: self.aggregator(
186
+ [result[case_key] for result in results if case_key in result]
172
187
  )
173
188
  for result in results
174
- for key in result.keys()
189
+ for case_key in result.keys()
175
190
  }
176
191
 
177
192
  raise ValueError(f"Unknown strategy: {self.strategy}")
@@ -15,6 +15,7 @@ from .attribute_value import AttributeValueSim, attribute_value
15
15
  from .wrappers import (
16
16
  attribute_table,
17
17
  cache,
18
+ combine,
18
19
  dynamic_table,
19
20
  table,
20
21
  transpose,
@@ -26,6 +27,7 @@ __all__ = [
26
27
  "transpose",
27
28
  "transpose_value",
28
29
  "cache",
30
+ "combine",
29
31
  "table",
30
32
  "dynamic_table",
31
33
  "type_table",
@@ -7,12 +7,15 @@ from .common import (
7
7
  GraphSim,
8
8
  SearchGraphSimFunc,
9
9
  SearchState,
10
+ SearchStateInit,
10
11
  SemanticEdgeSim,
12
+ init_empty,
13
+ init_unique_matches,
11
14
  )
12
15
  from .greedy import greedy
13
16
  from .lap import lap
14
17
  from .precompute import precompute
15
- from .vf2 import vf2
18
+ from .vf2 import vf2, vf2_networkx, vf2_rustworkx
16
19
 
17
20
  with optional_dependencies():
18
21
  from .alignment import dtw
@@ -31,12 +34,17 @@ __all__ = [
31
34
  "lap",
32
35
  "precompute",
33
36
  "vf2",
37
+ "vf2_networkx",
38
+ "vf2_rustworkx",
34
39
  "dtw",
35
40
  "smith_waterman",
41
+ "init_empty",
42
+ "init_unique_matches",
36
43
  "GraphSim",
37
44
  "ElementMatcher",
38
45
  "SemanticEdgeSim",
39
46
  "BaseGraphSimFunc",
40
47
  "SearchGraphSimFunc",
41
48
  "SearchState",
49
+ "SearchStateInit",
42
50
  ]
@@ -1,12 +1,8 @@
1
1
  import heapq
2
- import itertools
3
- from collections import defaultdict
4
2
  from collections.abc import Mapping
5
3
  from dataclasses import dataclass, field
6
4
  from typing import Protocol
7
5
 
8
- from frozendict import frozendict
9
-
10
6
  from ...helpers import (
11
7
  get_logger,
12
8
  unpack_float,
@@ -18,25 +14,20 @@ from ...model.graph import (
18
14
  )
19
15
  from ...typing import SimFunc
20
16
  from .common import (
21
- ElementMatcher,
22
17
  GraphSim,
23
18
  SearchGraphSimFunc,
24
19
  SearchState,
25
- _induced_edge_mapping,
26
20
  )
27
21
 
28
22
  __all__ = [
29
23
  "HeuristicFunc",
30
24
  "SelectionFunc",
31
- "InitFunc",
32
25
  "h1",
33
26
  "h2",
34
27
  "h3",
35
28
  "select1",
36
29
  "select2",
37
30
  "select3",
38
- "init1",
39
- "init2",
40
31
  "build",
41
32
  ]
42
33
 
@@ -74,17 +65,6 @@ class SelectionFunc[K, N, E, G](Protocol):
74
65
  ) -> None | tuple[K, GraphElementType]: ...
75
66
 
76
67
 
77
- class InitFunc[K, N, E, G](Protocol):
78
- def __call__(
79
- self,
80
- x: Graph[K, N, E, G],
81
- y: Graph[K, N, E, G],
82
- node_matcher: ElementMatcher[N],
83
- edge_matcher: ElementMatcher[E],
84
- /,
85
- ) -> SearchState[K]: ...
86
-
87
-
88
68
  @dataclass(slots=True, frozen=True)
89
69
  class h1[K, N, E, G](HeuristicFunc[K, N, E, G]):
90
70
  def __call__(
@@ -274,61 +254,6 @@ class select3[K, N, E, G](SelectionFunc[K, N, E, G]):
274
254
  return selection_key, selection_type
275
255
 
276
256
 
277
- @dataclass(slots=True, frozen=True)
278
- class init1[K, N, E, G](InitFunc[K, N, E, G]):
279
- def __call__(
280
- self,
281
- x: Graph[K, N, E, G],
282
- y: Graph[K, N, E, G],
283
- node_matcher: ElementMatcher[N],
284
- edge_matcher: ElementMatcher[E],
285
- ) -> SearchState[K]:
286
- return SearchState(
287
- frozendict(),
288
- frozendict(),
289
- frozenset(y.nodes.keys()),
290
- frozenset(y.edges.keys()),
291
- frozenset(x.nodes.keys()),
292
- frozenset(x.edges.keys()),
293
- )
294
-
295
-
296
- @dataclass(slots=True, init=False)
297
- class init2[K, N, E, G](InitFunc[K, N, E, G]):
298
- def __call__(
299
- self,
300
- x: Graph[K, N, E, G],
301
- y: Graph[K, N, E, G],
302
- node_matcher: ElementMatcher[N],
303
- edge_matcher: ElementMatcher[E],
304
- ) -> SearchState[K]:
305
- # pre-populate the mapping with nodes/edges that only have one possible legal mapping
306
- possible_node_mappings: defaultdict[K, set[K]] = defaultdict(set)
307
-
308
- for y_key, x_key in itertools.product(y.nodes.keys(), x.nodes.keys()):
309
- if node_matcher(x.nodes[x_key].value, y.nodes[y_key].value):
310
- possible_node_mappings[y_key].add(x_key)
311
-
312
- node_mapping: frozendict[K, K] = frozendict(
313
- (y_key, next(iter(x_keys)))
314
- for y_key, x_keys in possible_node_mappings.items()
315
- if len(x_keys) == 1
316
- )
317
-
318
- edge_mapping: frozendict[K, K] = _induced_edge_mapping(
319
- x, y, node_mapping, edge_matcher
320
- )
321
-
322
- return SearchState(
323
- node_mapping,
324
- edge_mapping,
325
- frozenset(y.nodes.keys() - node_mapping.keys()),
326
- frozenset(y.edges.keys() - edge_mapping.keys()),
327
- frozenset(x.nodes.keys() - node_mapping.values()),
328
- frozenset(x.edges.keys() - edge_mapping.values()),
329
- )
330
-
331
-
332
257
  @dataclass(slots=True)
333
258
  class build[K, N, E, G](
334
259
  SearchGraphSimFunc[K, N, E, G], SimFunc[Graph[K, N, E, G], GraphSim[K]]
@@ -355,7 +280,6 @@ class build[K, N, E, G](
355
280
 
356
281
  heuristic_func: HeuristicFunc[K, N, E, G] = field(default_factory=h3)
357
282
  selection_func: SelectionFunc[K, N, E, G] = field(default_factory=select3)
358
- init_func: InitFunc[K, N, E, G] = field(default_factory=init1)
359
283
  beam_width: int = 0
360
284
  pathlength_weight: int = 0
361
285
 
@@ -439,14 +363,10 @@ class build[K, N, E, G](
439
363
  x: Graph[K, N, E, G],
440
364
  y: Graph[K, N, E, G],
441
365
  ) -> GraphSim[K]:
442
- # if len(y.nodes) + len(y.edges) > len(x.nodes) + len(x.edges):
443
- # self_inv = dataclasses.replace(self, _invert=True)
444
- # return self.invert_similarity(x, y, self_inv(x=y, y=x))
445
-
446
366
  node_pair_sims, edge_pair_sims = self.pair_similarities(x, y)
447
367
 
448
368
  open_set: list[PriorityState[K]] = []
449
- best_state = self.init_func(x, y, self.node_matcher, self.edge_matcher)
369
+ best_state = self.init_search_state(x, y)
450
370
  heapq.heappush(open_set, PriorityState(0, best_state))
451
371
 
452
372
  while open_set:
@@ -1,18 +1,19 @@
1
1
  import itertools
2
+ from collections import defaultdict
2
3
  from collections.abc import Mapping, Sequence
3
4
  from dataclasses import dataclass, field
4
- from typing import Any, Protocol
5
+ from typing import Any, Protocol, cast
5
6
 
6
7
  from frozendict import frozendict
7
8
 
8
9
  from ...helpers import (
9
10
  batchify_sim,
10
- reverse_batch_positional,
11
- reverse_positional,
11
+ total_params,
12
12
  unpack_float,
13
+ unpack_floats,
13
14
  )
14
- from ...model.graph import Edge, Graph, Node
15
- from ...typing import AnySimFunc, BatchSimFunc, Float, StructuredValue
15
+ from ...model.graph import Graph, Node
16
+ from ...typing import AnySimFunc, BatchSimFunc, Float, SimFunc, StructuredValue
16
17
  from ..wrappers import transpose_value
17
18
 
18
19
  type PairSim[K] = Mapping[tuple[K, K], float]
@@ -36,26 +37,38 @@ def default_element_matcher(x: Any, y: Any) -> bool:
36
37
 
37
38
  @dataclass(slots=True, frozen=True)
38
39
  class SemanticEdgeSim[K, N, E]:
39
- source_weight: float = 0.5
40
- target_weight: float = 0.5
40
+ source_weight: float = 1.0
41
+ target_weight: float = 1.0
42
+ edge_sim_func: AnySimFunc[E, Float] | None = None
41
43
 
42
44
  def __call__(
43
45
  self,
44
- batches: Sequence[tuple[Edge[K, N, E], Edge[K, N, E], PairSim[K]]],
46
+ batches: Sequence[tuple[E, E, float, float]],
45
47
  ) -> list[float]:
46
- source_sims = (
47
- node_pair_sims.get((y.source.key, x.source.key), 0.0)
48
- for x, y, node_pair_sims in batches
49
- )
50
- target_sims = (
51
- node_pair_sims.get((y.target.key, x.target.key), 0.0)
52
- for x, y, node_pair_sims in batches
53
- )
48
+ source_sims = (source_sim for _, _, source_sim, _ in batches)
49
+ target_sims = (target_sim for _, _, _, target_sim in batches)
50
+
51
+ if self.edge_sim_func is not None:
52
+ edge_sim_func = batchify_sim(self.edge_sim_func)
53
+ edge_sims = unpack_floats(
54
+ edge_sim_func(
55
+ [(x, y) for x, y, _, _ in batches],
56
+ )
57
+ )
58
+ else:
59
+ edge_sims = [1.0] * len(batches)
60
+
61
+ scaling_factor = self.source_weight + self.target_weight
62
+
63
+ if scaling_factor == 0:
64
+ return edge_sims
54
65
 
55
66
  return [
56
- (self.source_weight * source + self.target_weight * target)
57
- / (self.source_weight + self.target_weight)
58
- for source, target in zip(source_sims, target_sims, strict=True)
67
+ (edge * source * self.source_weight / scaling_factor)
68
+ + (edge * target * self.target_weight / scaling_factor)
69
+ for source, target, edge in zip(
70
+ source_sims, target_sims, edge_sims, strict=True
71
+ )
59
72
  ]
60
73
 
61
74
 
@@ -80,37 +93,14 @@ def _induced_edge_mapping[K, N, E, G](
80
93
  @dataclass(slots=True)
81
94
  class BaseGraphSimFunc[K, N, E, G]:
82
95
  node_sim_func: AnySimFunc[N, Float]
83
- edge_sim_func: AnySimFunc[Edge[K, N, E], Float] | SemanticEdgeSim[K, N, E] = (
84
- default_edge_sim
85
- )
96
+ edge_sim_func: SemanticEdgeSim[K, N, E] = default_edge_sim
86
97
  node_matcher: ElementMatcher[N] = default_element_matcher
87
98
  edge_matcher: ElementMatcher[E] = default_element_matcher
88
99
  batch_node_sim_func: BatchSimFunc[Node[K, N], Float] = field(init=False)
89
- batch_edge_sim_func: (
90
- BatchSimFunc[Edge[K, N, E], Float] | SemanticEdgeSim[K, N, E]
91
- ) = field(init=False)
92
- _invert: bool = False
93
100
 
94
101
  def __post_init__(self) -> None:
95
102
  self.batch_node_sim_func = batchify_sim(transpose_value(self.node_sim_func))
96
103
 
97
- if isinstance(self.edge_sim_func, SemanticEdgeSim):
98
- self.batch_edge_sim_func = self.edge_sim_func
99
- else:
100
- self.batch_edge_sim_func = batchify_sim(self.edge_sim_func)
101
-
102
- if self._invert:
103
- self.node_matcher = reverse_positional(self.node_matcher)
104
- self.edge_matcher = reverse_positional(self.edge_matcher)
105
- self.batch_node_sim_func = reverse_batch_positional(
106
- self.batch_node_sim_func
107
- )
108
- if not isinstance(self.batch_edge_sim_func, SemanticEdgeSim):
109
- # semantic edge sim is agnostic to order
110
- self.batch_edge_sim_func = reverse_batch_positional(
111
- self.batch_edge_sim_func
112
- )
113
-
114
104
  def induced_edge_mapping(
115
105
  self,
116
106
  x: Graph[K, N, E, G],
@@ -161,16 +151,17 @@ class BaseGraphSimFunc[K, N, E, G]:
161
151
  ]
162
152
 
163
153
  edge_pair_values = [(x.edges[x_key], y.edges[y_key]) for y_key, x_key in pairs]
164
-
165
- if isinstance(self.batch_edge_sim_func, SemanticEdgeSim):
166
- edge_pair_sims = self.batch_edge_sim_func(
167
- [
168
- (x_edge, y_edge, node_pair_sims)
169
- for x_edge, y_edge in edge_pair_values
170
- ]
171
- )
172
- else:
173
- edge_pair_sims = self.batch_edge_sim_func(edge_pair_values)
154
+ edge_pair_sims = self.edge_sim_func(
155
+ [
156
+ (
157
+ x_edge.value,
158
+ y_edge.value,
159
+ node_pair_sims[(y_edge.source.key, x_edge.source.key)],
160
+ node_pair_sims[(y_edge.target.key, x_edge.target.key)],
161
+ )
162
+ for x_edge, y_edge in edge_pair_values
163
+ ]
164
+ )
174
165
 
175
166
  return {
176
167
  (y_edge.key, x_edge.key): unpack_float(sim)
@@ -256,7 +247,113 @@ class SearchState[K]:
256
247
  open_x_edges: frozenset[K]
257
248
 
258
249
 
250
+ class SearchStateInit[K, N, E, G](Protocol):
251
+ def __call__(
252
+ self,
253
+ x: Graph[K, N, E, G],
254
+ y: Graph[K, N, E, G],
255
+ node_matcher: ElementMatcher[N],
256
+ edge_matcher: ElementMatcher[E],
257
+ /,
258
+ ) -> SearchState[K]: ...
259
+
260
+
261
+ @dataclass(slots=True, frozen=True)
262
+ class init_empty[K, N, E, G](SearchStateInit[K, N, E, G]):
263
+ def __call__(
264
+ self,
265
+ x: Graph[K, N, E, G],
266
+ y: Graph[K, N, E, G],
267
+ node_matcher: ElementMatcher[N],
268
+ edge_matcher: ElementMatcher[E],
269
+ ) -> SearchState[K]:
270
+ return SearchState(
271
+ frozendict(),
272
+ frozendict(),
273
+ frozenset(y.nodes.keys()),
274
+ frozenset(y.edges.keys()),
275
+ frozenset(x.nodes.keys()),
276
+ frozenset(x.edges.keys()),
277
+ )
278
+
279
+
280
+ @dataclass(slots=True, init=False)
281
+ class init_unique_matches[K, N, E, G](SearchStateInit[K, N, E, G]):
282
+ def __call__(
283
+ self,
284
+ x: Graph[K, N, E, G],
285
+ y: Graph[K, N, E, G],
286
+ node_matcher: ElementMatcher[N],
287
+ edge_matcher: ElementMatcher[E],
288
+ ) -> SearchState[K]:
289
+ # pre-populate the mapping with nodes/edges that only have one possible legal mapping
290
+ possible_node_mappings: defaultdict[K, set[K]] = defaultdict(set)
291
+
292
+ for y_key, x_key in itertools.product(y.nodes.keys(), x.nodes.keys()):
293
+ if node_matcher(x.nodes[x_key].value, y.nodes[y_key].value):
294
+ possible_node_mappings[y_key].add(x_key)
295
+
296
+ node_mapping: frozendict[K, K] = frozendict(
297
+ (y_key, next(iter(x_keys)))
298
+ for y_key, x_keys in possible_node_mappings.items()
299
+ if len(x_keys) == 1
300
+ )
301
+
302
+ edge_mapping: frozendict[K, K] = _induced_edge_mapping(
303
+ x, y, node_mapping, edge_matcher
304
+ )
305
+
306
+ return SearchState(
307
+ node_mapping,
308
+ edge_mapping,
309
+ frozenset(y.nodes.keys() - node_mapping.keys()),
310
+ frozenset(y.edges.keys() - edge_mapping.keys()),
311
+ frozenset(x.nodes.keys() - node_mapping.values()),
312
+ frozenset(x.edges.keys() - edge_mapping.values()),
313
+ )
314
+
315
+
316
+ @dataclass(slots=True)
259
317
  class SearchGraphSimFunc[K, N, E, G](BaseGraphSimFunc[K, N, E, G]):
318
+ init_func: (
319
+ SearchStateInit[K, N, E, G] | AnySimFunc[Graph[K, N, E, G], GraphSim[K]]
320
+ ) = field(default_factory=init_unique_matches)
321
+
322
+ def init_search_state(
323
+ self, x: Graph[K, N, E, G], y: Graph[K, N, E, G]
324
+ ) -> SearchState[K]:
325
+ init_func_params = total_params(self.init_func)
326
+ sim: GraphSim[K]
327
+
328
+ if init_func_params == 4:
329
+ init_func = cast(SearchStateInit[K, N, E, G], self.init_func)
330
+
331
+ return init_func(x, y, self.node_matcher, self.edge_matcher)
332
+
333
+ elif init_func_params == 2:
334
+ init_func = cast(SimFunc[Graph[K, N, E, G], GraphSim[K]], self.init_func)
335
+ sim = init_func(x, y)
336
+
337
+ elif init_func_params == 1:
338
+ init_func = cast(
339
+ BatchSimFunc[Graph[K, N, E, G], GraphSim[K]], self.init_func
340
+ )
341
+ sim = init_func([(x, y)])[0]
342
+
343
+ else:
344
+ raise ValueError(
345
+ f"Invalid number of parameters for init_func: {init_func_params}"
346
+ )
347
+
348
+ return SearchState(
349
+ node_mapping=sim.node_mapping,
350
+ edge_mapping=sim.edge_mapping,
351
+ open_y_nodes=frozenset(y.nodes.keys() - sim.node_mapping.keys()),
352
+ open_y_edges=frozenset(y.edges.keys() - sim.edge_mapping.keys()),
353
+ open_x_nodes=frozenset(x.nodes.keys() - sim.node_mapping.values()),
354
+ open_x_edges=frozenset(x.edges.keys() - sim.edge_mapping.values()),
355
+ )
356
+
260
357
  def finished(self, state: SearchState[K]) -> bool:
261
358
  # the following condition could save a few iterations, but needs to be tested
262
359
  # return (not state.open_y_nodes and not state.open_y_edges) or (
@@ -11,8 +11,6 @@ from .common import BaseGraphSimFunc, GraphSim
11
11
 
12
12
  logger = get_logger(__name__)
13
13
 
14
- __all__ = ["dfs"]
15
-
16
14
 
17
15
  class RootsFunc[K, N, E, G](Protocol):
18
16
  """Support for matching rooted graphs
@@ -37,10 +35,10 @@ with optional_dependencies():
37
35
  class dfs[K, N, E, G](
38
36
  BaseGraphSimFunc[K, N, E, G], SimFunc[Graph[K, N, E, G], GraphSim[K]]
39
37
  ):
40
- node_del_cost: float = 2.0
41
- node_ins_cost: float = 0.0
42
- edge_del_cost: float = 2.0
43
- edge_ins_cost: float = 0.0
38
+ node_del_cost: float = 1.0
39
+ node_ins_cost: float = 1.0
40
+ edge_del_cost: float = 1.0
41
+ edge_ins_cost: float = 1.0
44
42
  max_iterations: int = 0
45
43
  upper_bound: float | None = None
46
44
  strictly_decreasing: bool = True
@@ -1,7 +1,5 @@
1
1
  from dataclasses import dataclass
2
2
 
3
- from frozendict import frozendict
4
-
5
3
  from ...helpers import (
6
4
  get_logger,
7
5
  )
@@ -13,8 +11,6 @@ from .common import GraphSim, SearchGraphSimFunc, SearchState
13
11
 
14
12
  logger = get_logger(__name__)
15
13
 
16
- __all__ = ["greedy"]
17
-
18
14
 
19
15
  @dataclass(slots=True)
20
16
  class greedy[K, N, E, G](
@@ -37,23 +33,17 @@ class greedy[K, N, E, G](
37
33
  x: Graph[K, N, E, G],
38
34
  y: Graph[K, N, E, G],
39
35
  ) -> GraphSim[K]:
40
- # if len(y.nodes) + len(y.edges) > len(x.nodes) + len(x.edges):
41
- # self_inv = dataclasses.replace(self, _invert=True)
42
- # return self.invert_similarity(x, y, self_inv(x=y, y=x))
36
+ node_pair_sims, edge_pair_sims = self.pair_similarities(x, y)
43
37
 
44
- current_state = SearchState(
45
- frozendict(),
46
- frozendict(),
47
- frozenset(y.nodes.keys()),
48
- frozenset(y.edges.keys()),
49
- frozenset(x.nodes.keys()),
50
- frozenset(x.edges.keys()),
38
+ current_state = self.init_search_state(x, y)
39
+ current_sim = self.similarity(
40
+ x,
41
+ y,
42
+ current_state.node_mapping,
43
+ current_state.edge_mapping,
44
+ node_pair_sims,
45
+ edge_pair_sims,
51
46
  )
52
- current_sim = GraphSim(
53
- 0.0, frozendict(), frozendict(), frozendict(), frozendict()
54
- )
55
-
56
- node_pair_sims, edge_pair_sims = self.pair_similarities(x, y)
57
47
 
58
48
  while not self.finished(current_state):
59
49
  # Iterate over all open pairs and find the best pair