cbrkit 0.26.0__tar.gz → 0.26.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. {cbrkit-0.26.0 → cbrkit-0.26.1}/PKG-INFO +1 -1
  2. {cbrkit-0.26.0 → cbrkit-0.26.1}/pyproject.toml +1 -1
  3. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/retrieval/build.py +36 -21
  4. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/__init__.py +2 -0
  5. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/__init__.py +6 -0
  6. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/astar.py +1 -77
  7. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/common.py +110 -2
  8. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/greedy.py +10 -14
  9. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/wrappers.py +64 -1
  10. {cbrkit-0.26.0 → cbrkit-0.26.1}/README.md +0 -0
  11. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/__init__.py +0 -0
  12. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/__main__.py +0 -0
  13. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/adapt/__init__.py +0 -0
  14. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/adapt/attribute_value.py +0 -0
  15. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/adapt/generic.py +0 -0
  16. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/adapt/numbers.py +0 -0
  17. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/adapt/strings.py +0 -0
  18. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/api.py +0 -0
  19. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/cli.py +0 -0
  20. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/constants.py +0 -0
  21. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/cycle.py +0 -0
  22. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/dumpers.py +0 -0
  23. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/eval/__init__.py +0 -0
  24. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/eval/common.py +0 -0
  25. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/eval/retrieval.py +0 -0
  26. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/helpers.py +0 -0
  27. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/loaders.py +0 -0
  28. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/model/__init__.py +0 -0
  29. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/model/graph.py +0 -0
  30. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/model/result.py +0 -0
  31. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/py.typed +0 -0
  32. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/retrieval/__init__.py +0 -0
  33. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/retrieval/apply.py +0 -0
  34. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/retrieval/rerank.py +0 -0
  35. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/reuse/__init__.py +0 -0
  36. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/reuse/apply.py +0 -0
  37. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/reuse/build.py +0 -0
  38. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/aggregator.py +0 -0
  39. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/attribute_value.py +0 -0
  40. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/collections.py +0 -0
  41. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/embed.py +0 -0
  42. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/generic.py +0 -0
  43. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/alignment.py +0 -0
  44. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/brute_force.py +0 -0
  45. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/dfs.py +0 -0
  46. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/lap.py +0 -0
  47. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/precompute.py +0 -0
  48. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/qap.py +0 -0
  49. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/graphs/vf2.py +0 -0
  50. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/numbers.py +0 -0
  51. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/strings.py +0 -0
  52. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/sim/taxonomy.py +0 -0
  53. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/__init__.py +0 -0
  54. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/apply.py +0 -0
  55. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/build.py +0 -0
  56. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/model.py +0 -0
  57. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/prompts.py +0 -0
  58. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/__init__.py +0 -0
  59. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/anthropic.py +0 -0
  60. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/cohere.py +0 -0
  61. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/google.py +0 -0
  62. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/instructor.py +0 -0
  63. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/model.py +0 -0
  64. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/ollama.py +0 -0
  65. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/openai.py +0 -0
  66. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/synthesis/providers/wrappers.py +0 -0
  67. {cbrkit-0.26.0 → cbrkit-0.26.1}/src/cbrkit/typing.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: cbrkit
3
- Version: 0.26.0
3
+ Version: 0.26.1
4
4
  Summary: Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI
5
5
  Keywords: cbr,case-based reasoning,api,similarity,nlp,retrieval,cli,tool,library
6
6
  Author: Mirko Lenz
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "cbrkit"
3
- version = "0.26.0"
3
+ version = "0.26.1"
4
4
  description = "Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI"
5
5
  authors = [{ name = "Mirko Lenz", email = "mirko@mirkolenz.com" }]
6
6
  readme = "README.md"
@@ -2,7 +2,7 @@ import itertools
2
2
  from collections.abc import Sequence
3
3
  from dataclasses import dataclass
4
4
  from multiprocessing.pool import Pool
5
- from typing import Any, Literal, override
5
+ from typing import Literal, override
6
6
 
7
7
  from ..helpers import (
8
8
  batchify_sim,
@@ -134,44 +134,59 @@ class combine[K, V, S: Float](RetrieverFunc[K, V, float]):
134
134
  A retriever function that combines the results from multiple retrievers.
135
135
  """
136
136
 
137
- retriever_funcs: list[RetrieverFunc[K, V, S]]
138
- aggregator: AggregatorFunc[Any, S] = default_aggregator
137
+ retriever_funcs: Sequence[RetrieverFunc[K, V, S]]
138
+ aggregator: AggregatorFunc[str, S] = default_aggregator
139
139
  strategy: Literal["intersection", "union"] = "union"
140
140
 
141
141
  @override
142
142
  def __call__(
143
143
  self, batches: Sequence[tuple[Casebase[K, V], V]]
144
144
  ) -> Sequence[SimMap[K, float]]:
145
- results = [retriever_func(batches) for retriever_func in self.retriever_funcs]
146
-
147
- return [
148
- self.__call_batch__(
149
- [
150
- results[retriever_idx][batch_idx]
151
- for retriever_idx in range(len(self.retriever_funcs))
152
- ]
153
- )
154
- for batch_idx in range(len(batches))
155
- ]
145
+ if isinstance(self.retriever_funcs, Sequence):
146
+ func_results = [
147
+ retriever_func(batches) for retriever_func in self.retriever_funcs
148
+ ]
149
+
150
+ return [
151
+ self.__call_batch__(
152
+ [batch_results[batch_idx] for batch_results in func_results]
153
+ )
154
+ for batch_idx in range(len(batches))
155
+ ]
156
+
157
+ # elif isinstance(self.retriever_funcs, Mapping):
158
+ # results = {
159
+ # func_key: retriever_func(batches)
160
+ # for func_key, retriever_func in self.retriever_funcs.items()
161
+ # }
162
+
163
+ # return [
164
+ # self.__call_batch__(
165
+ # {func_key: func_results[batch_idx] for func_key, func_results in results.items()}
166
+ # )
167
+ # for batch_idx in range(len(batches))
168
+ # ]
169
+
170
+ raise ValueError(f"Invalid retriever_funcs type: {type(self.retriever_funcs)}")
156
171
 
157
- def __call_batch__(self, results: list[SimMap[K, S]]) -> SimMap[K, float]:
172
+ def __call_batch__(self, results: Sequence[SimMap[K, S]]) -> SimMap[K, float]:
158
173
  if self.strategy == "intersection":
159
174
  return {
160
- key: self.aggregator(
161
- [result[key] for result in results if key in result]
175
+ case_key: self.aggregator(
176
+ [result[case_key] for result in results if case_key in result]
162
177
  )
163
- for key in set().intersection(
178
+ for case_key in set().intersection(
164
179
  *[set(result.keys()) for result in results]
165
180
  )
166
181
  }
167
182
 
168
183
  elif self.strategy == "union":
169
184
  return {
170
- key: self.aggregator(
171
- [result[key] for result in results if key in result]
185
+ case_key: self.aggregator(
186
+ [result[case_key] for result in results if case_key in result]
172
187
  )
173
188
  for result in results
174
- for key in result.keys()
189
+ for case_key in result.keys()
175
190
  }
176
191
 
177
192
  raise ValueError(f"Unknown strategy: {self.strategy}")
@@ -15,6 +15,7 @@ from .attribute_value import AttributeValueSim, attribute_value
15
15
  from .wrappers import (
16
16
  attribute_table,
17
17
  cache,
18
+ combine,
18
19
  dynamic_table,
19
20
  table,
20
21
  transpose,
@@ -26,6 +27,7 @@ __all__ = [
26
27
  "transpose",
27
28
  "transpose_value",
28
29
  "cache",
30
+ "combine",
29
31
  "table",
30
32
  "dynamic_table",
31
33
  "type_table",
@@ -7,7 +7,10 @@ from .common import (
7
7
  GraphSim,
8
8
  SearchGraphSimFunc,
9
9
  SearchState,
10
+ SearchStateInit,
10
11
  SemanticEdgeSim,
12
+ init_empty,
13
+ init_unique_matches,
11
14
  )
12
15
  from .greedy import greedy
13
16
  from .lap import lap
@@ -33,10 +36,13 @@ __all__ = [
33
36
  "vf2",
34
37
  "dtw",
35
38
  "smith_waterman",
39
+ "init_empty",
40
+ "init_unique_matches",
36
41
  "GraphSim",
37
42
  "ElementMatcher",
38
43
  "SemanticEdgeSim",
39
44
  "BaseGraphSimFunc",
40
45
  "SearchGraphSimFunc",
41
46
  "SearchState",
47
+ "SearchStateInit",
42
48
  ]
@@ -1,12 +1,8 @@
1
1
  import heapq
2
- import itertools
3
- from collections import defaultdict
4
2
  from collections.abc import Mapping
5
3
  from dataclasses import dataclass, field
6
4
  from typing import Protocol
7
5
 
8
- from frozendict import frozendict
9
-
10
6
  from ...helpers import (
11
7
  get_logger,
12
8
  unpack_float,
@@ -18,25 +14,20 @@ from ...model.graph import (
18
14
  )
19
15
  from ...typing import SimFunc
20
16
  from .common import (
21
- ElementMatcher,
22
17
  GraphSim,
23
18
  SearchGraphSimFunc,
24
19
  SearchState,
25
- _induced_edge_mapping,
26
20
  )
27
21
 
28
22
  __all__ = [
29
23
  "HeuristicFunc",
30
24
  "SelectionFunc",
31
- "InitFunc",
32
25
  "h1",
33
26
  "h2",
34
27
  "h3",
35
28
  "select1",
36
29
  "select2",
37
30
  "select3",
38
- "init1",
39
- "init2",
40
31
  "build",
41
32
  ]
42
33
 
@@ -74,17 +65,6 @@ class SelectionFunc[K, N, E, G](Protocol):
74
65
  ) -> None | tuple[K, GraphElementType]: ...
75
66
 
76
67
 
77
- class InitFunc[K, N, E, G](Protocol):
78
- def __call__(
79
- self,
80
- x: Graph[K, N, E, G],
81
- y: Graph[K, N, E, G],
82
- node_matcher: ElementMatcher[N],
83
- edge_matcher: ElementMatcher[E],
84
- /,
85
- ) -> SearchState[K]: ...
86
-
87
-
88
68
  @dataclass(slots=True, frozen=True)
89
69
  class h1[K, N, E, G](HeuristicFunc[K, N, E, G]):
90
70
  def __call__(
@@ -274,61 +254,6 @@ class select3[K, N, E, G](SelectionFunc[K, N, E, G]):
274
254
  return selection_key, selection_type
275
255
 
276
256
 
277
- @dataclass(slots=True, frozen=True)
278
- class init1[K, N, E, G](InitFunc[K, N, E, G]):
279
- def __call__(
280
- self,
281
- x: Graph[K, N, E, G],
282
- y: Graph[K, N, E, G],
283
- node_matcher: ElementMatcher[N],
284
- edge_matcher: ElementMatcher[E],
285
- ) -> SearchState[K]:
286
- return SearchState(
287
- frozendict(),
288
- frozendict(),
289
- frozenset(y.nodes.keys()),
290
- frozenset(y.edges.keys()),
291
- frozenset(x.nodes.keys()),
292
- frozenset(x.edges.keys()),
293
- )
294
-
295
-
296
- @dataclass(slots=True, init=False)
297
- class init2[K, N, E, G](InitFunc[K, N, E, G]):
298
- def __call__(
299
- self,
300
- x: Graph[K, N, E, G],
301
- y: Graph[K, N, E, G],
302
- node_matcher: ElementMatcher[N],
303
- edge_matcher: ElementMatcher[E],
304
- ) -> SearchState[K]:
305
- # pre-populate the mapping with nodes/edges that only have one possible legal mapping
306
- possible_node_mappings: defaultdict[K, set[K]] = defaultdict(set)
307
-
308
- for y_key, x_key in itertools.product(y.nodes.keys(), x.nodes.keys()):
309
- if node_matcher(x.nodes[x_key].value, y.nodes[y_key].value):
310
- possible_node_mappings[y_key].add(x_key)
311
-
312
- node_mapping: frozendict[K, K] = frozendict(
313
- (y_key, next(iter(x_keys)))
314
- for y_key, x_keys in possible_node_mappings.items()
315
- if len(x_keys) == 1
316
- )
317
-
318
- edge_mapping: frozendict[K, K] = _induced_edge_mapping(
319
- x, y, node_mapping, edge_matcher
320
- )
321
-
322
- return SearchState(
323
- node_mapping,
324
- edge_mapping,
325
- frozenset(y.nodes.keys() - node_mapping.keys()),
326
- frozenset(y.edges.keys() - edge_mapping.keys()),
327
- frozenset(x.nodes.keys() - node_mapping.values()),
328
- frozenset(x.edges.keys() - edge_mapping.values()),
329
- )
330
-
331
-
332
257
  @dataclass(slots=True)
333
258
  class build[K, N, E, G](
334
259
  SearchGraphSimFunc[K, N, E, G], SimFunc[Graph[K, N, E, G], GraphSim[K]]
@@ -355,7 +280,6 @@ class build[K, N, E, G](
355
280
 
356
281
  heuristic_func: HeuristicFunc[K, N, E, G] = field(default_factory=h3)
357
282
  selection_func: SelectionFunc[K, N, E, G] = field(default_factory=select3)
358
- init_func: InitFunc[K, N, E, G] = field(default_factory=init1)
359
283
  beam_width: int = 0
360
284
  pathlength_weight: int = 0
361
285
 
@@ -446,7 +370,7 @@ class build[K, N, E, G](
446
370
  node_pair_sims, edge_pair_sims = self.pair_similarities(x, y)
447
371
 
448
372
  open_set: list[PriorityState[K]] = []
449
- best_state = self.init_func(x, y, self.node_matcher, self.edge_matcher)
373
+ best_state = self.init_search_state(x, y)
450
374
  heapq.heappush(open_set, PriorityState(0, best_state))
451
375
 
452
376
  while open_set:
@@ -1,7 +1,8 @@
1
1
  import itertools
2
+ from collections import defaultdict
2
3
  from collections.abc import Mapping, Sequence
3
4
  from dataclasses import dataclass, field
4
- from typing import Any, Protocol
5
+ from typing import Any, Protocol, cast
5
6
 
6
7
  from frozendict import frozendict
7
8
 
@@ -9,10 +10,11 @@ from ...helpers import (
9
10
  batchify_sim,
10
11
  reverse_batch_positional,
11
12
  reverse_positional,
13
+ total_params,
12
14
  unpack_float,
13
15
  )
14
16
  from ...model.graph import Edge, Graph, Node
15
- from ...typing import AnySimFunc, BatchSimFunc, Float, StructuredValue
17
+ from ...typing import AnySimFunc, BatchSimFunc, Float, SimFunc, StructuredValue
16
18
  from ..wrappers import transpose_value
17
19
 
18
20
  type PairSim[K] = Mapping[tuple[K, K], float]
@@ -256,7 +258,113 @@ class SearchState[K]:
256
258
  open_x_edges: frozenset[K]
257
259
 
258
260
 
261
+ class SearchStateInit[K, N, E, G](Protocol):
262
+ def __call__(
263
+ self,
264
+ x: Graph[K, N, E, G],
265
+ y: Graph[K, N, E, G],
266
+ node_matcher: ElementMatcher[N],
267
+ edge_matcher: ElementMatcher[E],
268
+ /,
269
+ ) -> SearchState[K]: ...
270
+
271
+
272
+ @dataclass(slots=True, frozen=True)
273
+ class init_empty[K, N, E, G](SearchStateInit[K, N, E, G]):
274
+ def __call__(
275
+ self,
276
+ x: Graph[K, N, E, G],
277
+ y: Graph[K, N, E, G],
278
+ node_matcher: ElementMatcher[N],
279
+ edge_matcher: ElementMatcher[E],
280
+ ) -> SearchState[K]:
281
+ return SearchState(
282
+ frozendict(),
283
+ frozendict(),
284
+ frozenset(y.nodes.keys()),
285
+ frozenset(y.edges.keys()),
286
+ frozenset(x.nodes.keys()),
287
+ frozenset(x.edges.keys()),
288
+ )
289
+
290
+
291
+ @dataclass(slots=True, init=False)
292
+ class init_unique_matches[K, N, E, G](SearchStateInit[K, N, E, G]):
293
+ def __call__(
294
+ self,
295
+ x: Graph[K, N, E, G],
296
+ y: Graph[K, N, E, G],
297
+ node_matcher: ElementMatcher[N],
298
+ edge_matcher: ElementMatcher[E],
299
+ ) -> SearchState[K]:
300
+ # pre-populate the mapping with nodes/edges that only have one possible legal mapping
301
+ possible_node_mappings: defaultdict[K, set[K]] = defaultdict(set)
302
+
303
+ for y_key, x_key in itertools.product(y.nodes.keys(), x.nodes.keys()):
304
+ if node_matcher(x.nodes[x_key].value, y.nodes[y_key].value):
305
+ possible_node_mappings[y_key].add(x_key)
306
+
307
+ node_mapping: frozendict[K, K] = frozendict(
308
+ (y_key, next(iter(x_keys)))
309
+ for y_key, x_keys in possible_node_mappings.items()
310
+ if len(x_keys) == 1
311
+ )
312
+
313
+ edge_mapping: frozendict[K, K] = _induced_edge_mapping(
314
+ x, y, node_mapping, edge_matcher
315
+ )
316
+
317
+ return SearchState(
318
+ node_mapping,
319
+ edge_mapping,
320
+ frozenset(y.nodes.keys() - node_mapping.keys()),
321
+ frozenset(y.edges.keys() - edge_mapping.keys()),
322
+ frozenset(x.nodes.keys() - node_mapping.values()),
323
+ frozenset(x.edges.keys() - edge_mapping.values()),
324
+ )
325
+
326
+
327
+ @dataclass(slots=True)
259
328
  class SearchGraphSimFunc[K, N, E, G](BaseGraphSimFunc[K, N, E, G]):
329
+ init_func: (
330
+ SearchStateInit[K, N, E, G] | AnySimFunc[Graph[K, N, E, G], GraphSim[K]]
331
+ ) = field(default_factory=init_unique_matches)
332
+
333
+ def init_search_state(
334
+ self, x: Graph[K, N, E, G], y: Graph[K, N, E, G]
335
+ ) -> SearchState[K]:
336
+ init_func_params = total_params(self.init_func)
337
+ sim: GraphSim[K]
338
+
339
+ if init_func_params == 4:
340
+ init_func = cast(SearchStateInit[K, N, E, G], self.init_func)
341
+
342
+ return init_func(x, y, self.node_matcher, self.edge_matcher)
343
+
344
+ elif init_func_params == 2:
345
+ init_func = cast(SimFunc[Graph[K, N, E, G], GraphSim[K]], self.init_func)
346
+ sim = init_func(x, y)
347
+
348
+ elif init_func_params == 1:
349
+ init_func = cast(
350
+ BatchSimFunc[Graph[K, N, E, G], GraphSim[K]], self.init_func
351
+ )
352
+ sim = init_func([(x, y)])[0]
353
+
354
+ else:
355
+ raise ValueError(
356
+ f"Invalid number of parameters for init_func: {init_func_params}"
357
+ )
358
+
359
+ return SearchState(
360
+ node_mapping=sim.node_mapping,
361
+ edge_mapping=sim.edge_mapping,
362
+ open_y_nodes=frozenset(y.nodes.keys() - sim.node_mapping.keys()),
363
+ open_y_edges=frozenset(y.edges.keys() - sim.edge_mapping.keys()),
364
+ open_x_nodes=frozenset(x.nodes.keys() - sim.node_mapping.values()),
365
+ open_x_edges=frozenset(x.edges.keys() - sim.edge_mapping.values()),
366
+ )
367
+
260
368
  def finished(self, state: SearchState[K]) -> bool:
261
369
  # the following condition could save a few iterations, but needs to be tested
262
370
  # return (not state.open_y_nodes and not state.open_y_edges) or (
@@ -1,7 +1,5 @@
1
1
  from dataclasses import dataclass
2
2
 
3
- from frozendict import frozendict
4
-
5
3
  from ...helpers import (
6
4
  get_logger,
7
5
  )
@@ -41,20 +39,18 @@ class greedy[K, N, E, G](
41
39
  # self_inv = dataclasses.replace(self, _invert=True)
42
40
  # return self.invert_similarity(x, y, self_inv(x=y, y=x))
43
41
 
44
- current_state = SearchState(
45
- frozendict(),
46
- frozendict(),
47
- frozenset(y.nodes.keys()),
48
- frozenset(y.edges.keys()),
49
- frozenset(x.nodes.keys()),
50
- frozenset(x.edges.keys()),
51
- )
52
- current_sim = GraphSim(
53
- 0.0, frozendict(), frozendict(), frozendict(), frozendict()
54
- )
55
-
56
42
  node_pair_sims, edge_pair_sims = self.pair_similarities(x, y)
57
43
 
44
+ current_state = self.init_search_state(x, y)
45
+ current_sim = self.similarity(
46
+ x,
47
+ y,
48
+ current_state.node_mapping,
49
+ current_state.edge_mapping,
50
+ node_pair_sims,
51
+ edge_pair_sims,
52
+ )
53
+
58
54
  while not self.finished(current_state):
59
55
  # Iterate over all open pairs and find the best pair
60
56
  next_states: list[SearchState[K]] = []
@@ -1,10 +1,11 @@
1
1
  from collections import defaultdict
2
2
  from collections.abc import Callable, Mapping, MutableMapping, Sequence
3
- from dataclasses import dataclass, field
3
+ from dataclasses import InitVar, dataclass, field
4
4
  from typing import Any, cast, override
5
5
 
6
6
  from ..helpers import batchify_sim, get_metadata, get_value, getitem_or_getattr
7
7
  from ..typing import (
8
+ AggregatorFunc,
8
9
  AnySimFunc,
9
10
  BatchSimFunc,
10
11
  ConversionFunc,
@@ -14,6 +15,7 @@ from ..typing import (
14
15
  SimSeq,
15
16
  StructuredValue,
16
17
  )
18
+ from .aggregator import default_aggregator
17
19
  from .generic import static
18
20
 
19
21
 
@@ -59,6 +61,67 @@ def transpose_value[V, S: Float](
59
61
  return transpose(func, get_value)
60
62
 
61
63
 
64
+ @dataclass(slots=True)
65
+ class combine[V, S: Float](BatchSimFunc[V, float]):
66
+ """Combines multiple similarity functions into one.
67
+
68
+ Args:
69
+ sim_funcs: A list of similarity functions to be combined.
70
+ aggregator: A function to aggregate the results from the similarity functions.
71
+
72
+ Returns:
73
+ A similarity function that combines the results from multiple similarity functions.
74
+ """
75
+
76
+ sim_funcs: InitVar[Sequence[AnySimFunc[V, S]] | Mapping[str, AnySimFunc[V, S]]]
77
+ aggregator: AggregatorFunc[str, S] = default_aggregator
78
+ batch_sim_funcs: Sequence[BatchSimFunc[V, S]] | Mapping[str, BatchSimFunc[V, S]] = (
79
+ field(init=False, repr=False)
80
+ )
81
+
82
+ def __post_init__(
83
+ self, sim_funcs: Sequence[AnySimFunc[V, S]] | Mapping[str, AnySimFunc[V, S]]
84
+ ):
85
+ if isinstance(sim_funcs, Sequence):
86
+ self.batch_sim_funcs = [batchify_sim(func) for func in sim_funcs]
87
+ elif isinstance(sim_funcs, Mapping):
88
+ self.batch_sim_funcs = {
89
+ key: batchify_sim(func) for key, func in sim_funcs.items()
90
+ }
91
+ else:
92
+ raise ValueError(f"Invalid sim_funcs type: {type(sim_funcs)}")
93
+
94
+ @override
95
+ def __call__(self, batches: Sequence[tuple[V, V]]) -> Sequence[float]:
96
+ if isinstance(self.batch_sim_funcs, Sequence):
97
+ func_results = [func(batches) for func in self.batch_sim_funcs]
98
+
99
+ return [
100
+ self.aggregator(
101
+ [batch_results[batch_idx] for batch_results in func_results]
102
+ )
103
+ for batch_idx in range(len(batches))
104
+ ]
105
+
106
+ elif isinstance(self.batch_sim_funcs, Mapping):
107
+ func_results = {
108
+ func_key: func(batches)
109
+ for func_key, func in self.batch_sim_funcs.items()
110
+ }
111
+
112
+ return [
113
+ self.aggregator(
114
+ {
115
+ func_key: batch_results[batch_idx]
116
+ for func_key, batch_results in func_results.items()
117
+ }
118
+ )
119
+ for batch_idx in range(len(batches))
120
+ ]
121
+
122
+ raise ValueError(f"Invalid batch_sim_funcs type: {type(self.batch_sim_funcs)}")
123
+
124
+
62
125
  @dataclass(slots=True)
63
126
  class cache[V, U, S: Float](BatchSimFunc[V, S]):
64
127
  similarity_func: BatchSimFunc[V, S]
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes