cbrkit 0.23.2__tar.gz → 0.24.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {cbrkit-0.23.2/src/cbrkit.egg-info → cbrkit-0.24.0}/PKG-INFO +1 -1
  2. {cbrkit-0.23.2 → cbrkit-0.24.0}/pyproject.toml +1 -1
  3. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/loaders.py +2 -2
  4. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/retrieval/__init__.py +2 -1
  5. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/retrieval/build.py +35 -20
  6. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/retrieval/rerank.py +1 -1
  7. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/embed.py +1 -1
  8. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/graphs/__init__.py +9 -2
  9. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/graphs/astar.py +23 -15
  10. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/graphs/common.py +4 -0
  11. cbrkit-0.24.0/src/cbrkit/sim/graphs/greedy.py +363 -0
  12. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/graphs/isomorphism.py +2 -2
  13. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/providers/anthropic.py +5 -6
  14. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/providers/openai.py +5 -6
  15. {cbrkit-0.23.2 → cbrkit-0.24.0/src/cbrkit.egg-info}/PKG-INFO +1 -1
  16. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit.egg-info/SOURCES.txt +1 -0
  17. {cbrkit-0.23.2 → cbrkit-0.24.0}/tests/test_graph.py +1 -0
  18. {cbrkit-0.23.2 → cbrkit-0.24.0}/LICENSE +0 -0
  19. {cbrkit-0.23.2 → cbrkit-0.24.0}/README.md +0 -0
  20. {cbrkit-0.23.2 → cbrkit-0.24.0}/setup.cfg +0 -0
  21. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/__init__.py +0 -0
  22. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/__main__.py +0 -0
  23. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/adapt/__init__.py +0 -0
  24. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/adapt/attribute_value.py +0 -0
  25. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/adapt/generic.py +0 -0
  26. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/adapt/numbers.py +0 -0
  27. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/adapt/strings.py +0 -0
  28. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/api.py +0 -0
  29. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/cli.py +0 -0
  30. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/constants.py +0 -0
  31. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/cycle.py +0 -0
  32. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/dumpers.py +0 -0
  33. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/eval/__init__.py +0 -0
  34. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/eval/common.py +0 -0
  35. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/eval/retrieval.py +0 -0
  36. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/helpers.py +0 -0
  37. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/model/__init__.py +0 -0
  38. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/model/graph.py +0 -0
  39. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/model/result.py +0 -0
  40. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/py.typed +0 -0
  41. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/retrieval/apply.py +0 -0
  42. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/reuse/__init__.py +0 -0
  43. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/reuse/apply.py +0 -0
  44. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/reuse/build.py +0 -0
  45. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/__init__.py +0 -0
  46. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/aggregator.py +0 -0
  47. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/attribute_value.py +0 -0
  48. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/collections.py +0 -0
  49. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/generic.py +0 -0
  50. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/graphs/alignment.py +0 -0
  51. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/graphs/brute_force.py +0 -0
  52. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/graphs/precompute.py +0 -0
  53. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/numbers.py +0 -0
  54. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/strings.py +0 -0
  55. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/taxonomy.py +0 -0
  56. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/sim/wrappers.py +0 -0
  57. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/__init__.py +0 -0
  58. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/apply.py +0 -0
  59. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/build.py +0 -0
  60. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/model.py +0 -0
  61. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/prompts.py +0 -0
  62. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/providers/__init__.py +0 -0
  63. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/providers/cohere.py +0 -0
  64. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/providers/google.py +0 -0
  65. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/providers/instructor.py +0 -0
  66. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/providers/model.py +0 -0
  67. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/providers/ollama.py +0 -0
  68. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/synthesis/providers/wrappers.py +0 -0
  69. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit/typing.py +0 -0
  70. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit.egg-info/dependency_links.txt +0 -0
  71. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit.egg-info/entry_points.txt +0 -0
  72. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit.egg-info/requires.txt +0 -0
  73. {cbrkit-0.23.2 → cbrkit-0.24.0}/src/cbrkit.egg-info/top_level.txt +0 -0
  74. {cbrkit-0.23.2 → cbrkit-0.24.0}/tests/test_cycle.py +0 -0
  75. {cbrkit-0.23.2 → cbrkit-0.24.0}/tests/test_retrieve.py +0 -0
  76. {cbrkit-0.23.2 → cbrkit-0.24.0}/tests/test_reuse.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cbrkit
3
- Version: 0.23.2
3
+ Version: 0.24.0
4
4
  Summary: Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI
5
5
  Author-email: Mirko Lenz <mirko@mirkolenz.com>
6
6
  Project-URL: Repository, https://github.com/wi2trier/cbrkit
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "cbrkit"
3
- version = "0.23.2"
3
+ version = "0.24.0"
4
4
  description = "Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI"
5
5
  authors = [{ name = "Mirko Lenz", email = "mirko@mirkolenz.com" }]
6
6
  readme = "README.md"
@@ -46,7 +46,7 @@ def read(data: ReadableType) -> str:
46
46
  elif isinstance(data, bytes | bytearray):
47
47
  return data.decode("utf-8")
48
48
 
49
- return read(data.read())
49
+ return read(data.read()) # pyright: ignore
50
50
 
51
51
 
52
52
  @dataclass(slots=True, frozen=True)
@@ -102,7 +102,7 @@ class csv(ConversionFunc[Iterable[str] | ReadableType, dict[int, dict[str, str]]
102
102
  if isinstance(source, ReadableType):
103
103
  source = read(source).splitlines()
104
104
 
105
- reader = csvlib.DictReader(source)
105
+ reader = csvlib.DictReader(source) # pyright: ignore
106
106
  data: dict[int, dict[str, str]] = {}
107
107
  row: dict[str, str]
108
108
 
@@ -1,7 +1,7 @@
1
1
  from ..helpers import optional_dependencies
2
2
  from ..model import QueryResultStep, Result, ResultStep
3
3
  from .apply import apply_batches, apply_queries, apply_query
4
- from .build import build, dropout, transpose, transpose_value
4
+ from .build import build, distribute, dropout, transpose, transpose_value
5
5
 
6
6
  with optional_dependencies():
7
7
  from .rerank import cohere
@@ -17,6 +17,7 @@ __all__ = [
17
17
  "transpose",
18
18
  "transpose_value",
19
19
  "dropout",
20
+ "distribute",
20
21
  "apply_batches",
21
22
  "apply_queries",
22
23
  "apply_query",
@@ -1,4 +1,4 @@
1
- import math
1
+ import itertools
2
2
  from collections.abc import Sequence
3
3
  from dataclasses import dataclass
4
4
  from multiprocessing.pool import Pool
@@ -119,6 +119,33 @@ def transpose_value[K, V, S: Float](
119
119
  return transpose(retriever_func, get_value)
120
120
 
121
121
 
122
+ @dataclass(slots=True, frozen=True)
123
+ class distribute[K, V, S: Float](RetrieverFunc[K, V, S]):
124
+ """Distributes the retrieval process by passing each batch separately to the retriever function.
125
+
126
+ Args:
127
+ retriever_func: The retriever function to be used.
128
+ Typically constructed with the `build` function.
129
+ multiprocessing: Either a boolean to enable multiprocessing with all cores
130
+ or an integer to specify the number of processes to use or a multiprocessing.Pool object.
131
+
132
+ Returns:
133
+ A retriever function that distributes the retrieval process.
134
+ """
135
+
136
+ retriever_func: RetrieverFunc[K, V, S]
137
+ multiprocessing: Pool | int | bool
138
+
139
+ def __call_batch__(self, x: Casebase[K, V], y: V) -> SimMap[K, S]:
140
+ return self.retriever_func([(x, y)])[0]
141
+
142
+ @override
143
+ def __call__(
144
+ self, batches: Sequence[tuple[Casebase[K, V], V]]
145
+ ) -> Sequence[SimMap[K, S]]:
146
+ return mp_starmap(self.__call_batch__, batches, self.multiprocessing, logger)
147
+
148
+
122
149
  @dataclass(slots=True, frozen=True)
123
150
  class build[K, V, S: Float](RetrieverFunc[K, V, S]):
124
151
  """Based on the similarity function this function creates a retriever function.
@@ -153,24 +180,12 @@ class build[K, V, S: Float](RetrieverFunc[K, V, S]):
153
180
 
154
181
  similarity_func: MaybeFactory[AnySimFunc[V, S]]
155
182
  multiprocessing: Pool | int | bool = False
156
- chunksize: int | None = None
157
-
158
- def __call_single__(self, x: Casebase[K, V], y: V) -> SimMap[K, S]:
159
- sim_func = batchify_sim(self.similarity_func)
160
- flat_batches = [(x, y) for x in x.values()]
161
- flat_sims: Sequence[S] = sim_func(flat_batches)
162
-
163
- return {key: sim for key, sim in zip(x.keys(), flat_sims, strict=True)}
183
+ chunksize: int = 0
164
184
 
165
185
  @override
166
186
  def __call__(
167
187
  self, batches: Sequence[tuple[Casebase[K, V], V]]
168
188
  ) -> Sequence[SimMap[K, S]]:
169
- if self.chunksize is None:
170
- return mp_starmap(
171
- self.__call_single__, batches, self.multiprocessing, logger
172
- )
173
-
174
189
  sim_func = batchify_sim(self.similarity_func)
175
190
  similarities: list[dict[K, S]] = [{} for _ in range(len(batches))]
176
191
 
@@ -183,15 +198,15 @@ class build[K, V, S: Float](RetrieverFunc[K, V, S]):
183
198
  flat_batches_index.append((idx, key))
184
199
  flat_batches.append((case, query))
185
200
 
186
- if use_mp(self.multiprocessing):
187
- chunksize = self.chunksize or math.ceil(
188
- len(flat_batches) / mp_count(self.multiprocessing)
201
+ if use_mp(self.multiprocessing) or self.chunksize > 0:
202
+ chunksize = (
203
+ self.chunksize
204
+ if self.chunksize > 0
205
+ else len(flat_batches) // mp_count(self.multiprocessing)
189
206
  )
190
207
  batch_chunks = list(chunkify(flat_batches, chunksize))
191
208
  sim_chunks = mp_map(sim_func, batch_chunks, self.multiprocessing, logger)
192
-
193
- for sim_chunk in sim_chunks:
194
- flat_sims.extend(sim_chunk)
209
+ flat_sims = list(itertools.chain.from_iterable(sim_chunks))
195
210
 
196
211
  else:
197
212
  flat_sims = sim_func(flat_batches)
@@ -150,7 +150,7 @@ with optional_dependencies():
150
150
  batches: Sequence[tuple[Casebase[K, str], str]],
151
151
  ) -> Sequence[dict[K, float]]:
152
152
  if isinstance(self.model, str):
153
- model = SentenceTransformer(self.model, device=self.device)
153
+ model = SentenceTransformer(self.model, device=self.device) # pyright: ignore
154
154
  else:
155
155
  model = self.model
156
156
 
@@ -357,7 +357,7 @@ with optional_dependencies():
357
357
  self._metadata = {}
358
358
 
359
359
  if isinstance(model, str):
360
- self.model = SentenceTransformer(model)
360
+ self.model = SentenceTransformer(model) # pyright: ignore
361
361
  self._metadata["model"] = model
362
362
  else:
363
363
  self.model = model
@@ -1,7 +1,12 @@
1
1
  from ...helpers import optional_dependencies
2
- from . import astar
2
+ from . import astar, greedy
3
3
  from .brute_force import brute_force
4
- from .common import ElementMatcher, GraphSim, default_element_matcher
4
+ from .common import (
5
+ ElementMatcher,
6
+ GraphSim,
7
+ default_element_matcher,
8
+ type_element_matcher,
9
+ )
5
10
  from .isomorphism import isomorphism
6
11
  from .precompute import precompute
7
12
 
@@ -13,6 +18,7 @@ with optional_dependencies():
13
18
 
14
19
  __all__ = [
15
20
  "astar",
21
+ "greedy",
16
22
  "brute_force",
17
23
  "isomorphism",
18
24
  "precompute",
@@ -21,4 +27,5 @@ __all__ = [
21
27
  "GraphSim",
22
28
  "ElementMatcher",
23
29
  "default_element_matcher",
30
+ "type_element_matcher",
24
31
  ]
@@ -2,7 +2,7 @@ import heapq
2
2
  import itertools
3
3
  from collections import defaultdict
4
4
  from collections.abc import Mapping, Sequence
5
- from dataclasses import dataclass, field
5
+ from dataclasses import InitVar, dataclass, field
6
6
  from typing import Protocol, override
7
7
 
8
8
  from frozendict import frozendict
@@ -406,14 +406,18 @@ class init1[K, N, E, G](InitFunc[K, N, E, G]):
406
406
  )
407
407
 
408
408
 
409
- @dataclass(slots=True, frozen=True)
409
+ @dataclass(slots=True, init=False)
410
410
  class init2[K, N, E, G](InitFunc[K, N, E, G]):
411
- legal_node_mapping: LegalMappingFunc[K, N, E, G] = field(
412
- default_factory=lambda: legal_node_mapping(default_element_matcher)
413
- )
414
- legal_edge_mapping: LegalMappingFunc[K, N, E, G] = field(
415
- default_factory=lambda: legal_edge_mapping(default_element_matcher)
416
- )
411
+ legal_node_mapping: LegalMappingFunc[K, N, E, G]
412
+ legal_edge_mapping: LegalMappingFunc[K, N, E, G]
413
+
414
+ def __init__(
415
+ self,
416
+ node_matcher: ElementMatcher[N],
417
+ edge_matcher: ElementMatcher[E] = default_element_matcher,
418
+ ):
419
+ self.legal_node_mapping = legal_node_mapping(node_matcher)
420
+ self.legal_edge_mapping = legal_edge_mapping(edge_matcher)
417
421
 
418
422
  def __call__(
419
423
  self,
@@ -525,7 +529,7 @@ class legal_edge_mapping[K, N, E, G](LegalMappingFunc[K, N, E, G]):
525
529
  )
526
530
 
527
531
 
528
- @dataclass(slots=True, frozen=True)
532
+ @dataclass(slots=True)
529
533
  class build[K, N, E, G](SimFunc[Graph[K, N, E, G], GraphSim[K]]):
530
534
  """
531
535
  Performs the A* algorithm proposed by [Bergmann and Gil (2014)](https://doi.org/10.1016/j.is.2012.07.005) to compute the similarity between a query graph and the graphs in the casebase.
@@ -546,18 +550,22 @@ class build[K, N, E, G](SimFunc[Graph[K, N, E, G], GraphSim[K]]):
546
550
 
547
551
  past_cost_func: PastSimFunc[K, N, E, G]
548
552
  future_cost_func: FutureSimFunc[K, N, E, G]
553
+ node_matcher: InitVar[ElementMatcher[N]]
554
+ edge_matcher: InitVar[ElementMatcher[E]] = default_element_matcher
549
555
  selection_func: SelectionFunc[K, N, E, G] = field(default_factory=select2)
550
556
  init_func: InitFunc[K, N, E, G] = field(default_factory=init1)
551
- legal_node_mapping: LegalMappingFunc[K, N, E, G] = field(
552
- default_factory=lambda: legal_node_mapping(default_element_matcher)
553
- )
554
- legal_edge_mapping: LegalMappingFunc[K, N, E, G] = field(
555
- default_factory=lambda: legal_edge_mapping(default_element_matcher)
556
- )
557
557
  queue_limit: int = 10000
558
+ legal_node_mapping: LegalMappingFunc[K, N, E, G] = field(init=False)
559
+ legal_edge_mapping: LegalMappingFunc[K, N, E, G] = field(init=False)
558
560
  # TODO: Currently not implemented as described in the paper, needs further investigation
559
561
  allow_case_oriented_mapping: bool = False
560
562
 
563
+ def __post_init__(
564
+ self, node_matcher: ElementMatcher[N], edge_matcher: ElementMatcher[E]
565
+ ) -> None:
566
+ self.legal_node_mapping = legal_node_mapping(node_matcher)
567
+ self.legal_edge_mapping = legal_edge_mapping(edge_matcher)
568
+
561
569
  def expand_node(
562
570
  self,
563
571
  x: Graph[K, N, E, G],
@@ -17,4 +17,8 @@ class ElementMatcher[T](Protocol):
17
17
 
18
18
 
19
19
  def default_element_matcher(x: Any, y: Any) -> bool:
20
+ return True
21
+
22
+
23
+ def type_element_matcher(x: Any, y: Any) -> bool:
20
24
  return type(x) is type(y)
@@ -0,0 +1,363 @@
1
+ import itertools
2
+ from collections.abc import Callable, Mapping, Sequence
3
+ from dataclasses import dataclass
4
+ from typing import Literal, Protocol, override
5
+
6
+ from frozendict import frozendict
7
+
8
+ from ...helpers import (
9
+ batchify_sim,
10
+ get_logger,
11
+ unpack_float,
12
+ unpack_floats,
13
+ )
14
+ from ...model.graph import (
15
+ Edge,
16
+ Graph,
17
+ Node,
18
+ )
19
+ from ...typing import AnySimFunc, BatchSimFunc, Float, SimFunc, StructuredValue
20
+ from ..wrappers import transpose_value
21
+ from .common import ElementMatcher, GraphSim, default_element_matcher
22
+
23
+ logger = get_logger(__name__)
24
+
25
+ __all__ = [
26
+ "build",
27
+ "init1",
28
+ "default_edge_sim",
29
+ "State",
30
+ "StateSim",
31
+ "InitFunc",
32
+ "LegalMappingFunc",
33
+ "build",
34
+ "legal_node_mapping",
35
+ "legal_edge_mapping",
36
+ ]
37
+
38
+
39
+ @dataclass(slots=True, frozen=True)
40
+ class State[K]:
41
+ # mappings are from y to x
42
+ mapped_nodes: frozendict[K, K]
43
+ mapped_edges: frozendict[K, K]
44
+ open_node_pairs: frozenset[tuple[K, K]]
45
+ open_edge_pairs: frozenset[tuple[K, K]]
46
+
47
+
48
+ @dataclass(slots=True, frozen=True)
49
+ class StateSim[K](StructuredValue[float]):
50
+ node_similarities: Mapping[K, float]
51
+ edge_similarities: Mapping[K, float]
52
+
53
+
54
+ class InitFunc[K, N, E, G](Protocol):
55
+ def __call__(
56
+ self,
57
+ x: Graph[K, N, E, G],
58
+ y: Graph[K, N, E, G],
59
+ /,
60
+ ) -> State[K]: ...
61
+
62
+
63
+ class LegalMappingFunc[K, N, E, G](Protocol):
64
+ def __call__(
65
+ self,
66
+ x: Graph[K, N, E, G],
67
+ y: Graph[K, N, E, G],
68
+ state: State[K],
69
+ x_key: K,
70
+ y_key: K,
71
+ ) -> bool: ...
72
+
73
+
74
+ @dataclass(slots=True, frozen=True)
75
+ class init1[K, N, E, G](InitFunc[K, N, E, G]):
76
+ def __call__(
77
+ self,
78
+ x: Graph[K, N, E, G],
79
+ y: Graph[K, N, E, G],
80
+ ) -> State[K]:
81
+ return State(
82
+ frozendict(),
83
+ frozendict(),
84
+ frozenset(itertools.product(y.nodes.keys(), x.nodes.keys())),
85
+ frozenset(itertools.product(y.edges.keys(), x.edges.keys())),
86
+ )
87
+
88
+
89
+ @dataclass(slots=True, frozen=True)
90
+ class default_edge_sim[K, N, E](BatchSimFunc[Edge[K, N, E], Float]):
91
+ node_sim_func: BatchSimFunc[Node[K, N], Float]
92
+
93
+ @override
94
+ def __call__(
95
+ self, batches: Sequence[tuple[Edge[K, N, E], Edge[K, N, E]]]
96
+ ) -> list[float]:
97
+ source_sims = self.node_sim_func([(x.source, y.source) for x, y in batches])
98
+ target_sims = self.node_sim_func([(x.target, y.target) for x, y in batches])
99
+
100
+ return [
101
+ 0.5 * (unpack_float(source) + unpack_float(target))
102
+ for source, target in zip(source_sims, target_sims, strict=True)
103
+ ]
104
+
105
+
106
+ @dataclass(slots=True, frozen=True)
107
+ class legal_node_mapping[K, N, E, G](LegalMappingFunc[K, N, E, G]):
108
+ matcher: ElementMatcher[N]
109
+
110
+ def __call__(
111
+ self,
112
+ x: Graph[K, N, E, G],
113
+ y: Graph[K, N, E, G],
114
+ state: State[K],
115
+ x_key: K,
116
+ y_key: K,
117
+ ) -> bool:
118
+ return (
119
+ y_key not in state.mapped_nodes.keys()
120
+ and x_key not in state.mapped_nodes.values()
121
+ and self.matcher(x.nodes[x_key].value, y.nodes[y_key].value)
122
+ )
123
+
124
+
125
+ @dataclass(slots=True, frozen=True)
126
+ class legal_edge_mapping[K, N, E, G](LegalMappingFunc[K, N, E, G]):
127
+ matcher: ElementMatcher[E]
128
+
129
+ def __call__(
130
+ self,
131
+ x: Graph[K, N, E, G],
132
+ y: Graph[K, N, E, G],
133
+ state: State[K],
134
+ x_key: K,
135
+ y_key: K,
136
+ ) -> bool:
137
+ x_value = x.edges[x_key]
138
+ y_value = y.edges[y_key]
139
+ mapped_x_source_key = state.mapped_nodes.get(y_value.source.key)
140
+ mapped_x_target_key = state.mapped_nodes.get(y_value.target.key)
141
+
142
+ return (
143
+ y_key not in state.mapped_edges.keys()
144
+ and x_key not in state.mapped_edges.values()
145
+ and self.matcher(x_value.value, y_value.value)
146
+ # if the nodes are already mapped, check if they are mapped legally
147
+ and (
148
+ mapped_x_source_key is None or x_value.source.key == mapped_x_source_key
149
+ )
150
+ and (
151
+ mapped_x_target_key is None or x_value.target.key == mapped_x_target_key
152
+ )
153
+ )
154
+
155
+
156
+ @dataclass(slots=True, init=False)
157
+ class build[K, N, E, G](SimFunc[Graph[K, N, E, G], GraphSim[K]]):
158
+ """
159
+ Performs the A* algorithm proposed by [Bergmann and Gil (2014)](https://doi.org/10.1016/j.is.2012.07.005) to compute the similarity between a query graph and the graphs in the casebase.
160
+
161
+ Args:
162
+ past_cost_func: A heuristic function to compute the costs of all previous steps.
163
+ future_cost_func: A heuristic function to compute the future costs.
164
+ selection_func: A function to select the next node or edge to be mapped.
165
+ init_func: A function to initialize the state.
166
+ node_matcher: A function that returns true if two nodes can be mapped legally.
167
+ edge_matcher: A function that returns true if two edges can be mapped legally.
168
+ queue_limit: Limits the queue size which prunes the search space.
169
+ This leads to a faster search and less memory usage but also introduces a similarity error.
170
+
171
+ Returns:
172
+ The similarity between the query graph and the most similar graph in the casebase.
173
+ """
174
+
175
+ node_sim_func: BatchSimFunc[Node[K, N], Float]
176
+ edge_sim_func: BatchSimFunc[Edge[K, N, E], Float]
177
+ init_func: InitFunc[K, N, E, G]
178
+ legal_node_mapping: LegalMappingFunc[K, N, E, G]
179
+ legal_edge_mapping: LegalMappingFunc[K, N, E, G]
180
+ start_with: Literal["nodes", "edges"]
181
+
182
+ def __init__(
183
+ self,
184
+ node_sim_func: AnySimFunc[N, Float],
185
+ node_matcher: ElementMatcher[N],
186
+ edge_sim_func: AnySimFunc[Edge[K, N, E], Float] | None = None,
187
+ edge_matcher: ElementMatcher[E] = default_element_matcher,
188
+ init_func: InitFunc[K, N, E, G] | None = None,
189
+ start_with: Literal["nodes", "edges"] = "nodes",
190
+ ) -> None:
191
+ self.legal_node_mapping = legal_node_mapping(node_matcher)
192
+ self.legal_edge_mapping = legal_edge_mapping(edge_matcher)
193
+
194
+ self.node_sim_func = batchify_sim(transpose_value(node_sim_func))
195
+ self.edge_sim_func = (
196
+ default_edge_sim(self.node_sim_func)
197
+ if edge_sim_func is None
198
+ else batchify_sim(edge_sim_func)
199
+ )
200
+ self.init_func = init_func if init_func else init1()
201
+
202
+ self.start_with = start_with
203
+
204
+ def compute_similarity(
205
+ self,
206
+ x: Graph[K, N, E, G],
207
+ y: Graph[K, N, E, G],
208
+ s: State[K],
209
+ ) -> StateSim[K]:
210
+ """Function to compute the similarity based on the current state"""
211
+
212
+ node_sims = unpack_floats(
213
+ self.node_sim_func(
214
+ [
215
+ (x.nodes[x_key], y.nodes[y_key])
216
+ for y_key, x_key in s.mapped_nodes.items()
217
+ ]
218
+ )
219
+ )
220
+
221
+ edge_sims = unpack_floats(
222
+ self.edge_sim_func(
223
+ [
224
+ (x.edges[x_key], y.edges[y_key])
225
+ for y_key, x_key in s.mapped_edges.items()
226
+ ]
227
+ )
228
+ )
229
+
230
+ all_sims = itertools.chain(node_sims, edge_sims)
231
+ total_elements = len(y.nodes) + len(y.edges)
232
+
233
+ return StateSim(
234
+ sum(all_sims) / total_elements,
235
+ dict(zip(s.mapped_nodes.keys(), node_sims, strict=True)),
236
+ dict(zip(s.mapped_edges.keys(), edge_sims, strict=True)),
237
+ )
238
+
239
+ def expand_edges(
240
+ self,
241
+ x: Graph[K, N, E, G],
242
+ y: Graph[K, N, E, G],
243
+ state: State[K],
244
+ ) -> list[State[K]]:
245
+ """Expand the current state by adding all possible edge mappings"""
246
+
247
+ new_states = []
248
+
249
+ for y_key, x_key in state.open_edge_pairs:
250
+ if self.legal_edge_mapping(x, y, state, x_key, y_key):
251
+ new_state = State(
252
+ state.mapped_nodes,
253
+ state.mapped_edges.set(y_key, x_key),
254
+ state.open_node_pairs,
255
+ state.open_edge_pairs
256
+ - {
257
+ (y, x)
258
+ for y, x in itertools.product(y.edges.keys(), x.edges.keys())
259
+ if y == y_key or x == x_key
260
+ },
261
+ )
262
+ new_states.append(new_state)
263
+
264
+ return new_states
265
+
266
+ def expand_nodes(
267
+ self,
268
+ x: Graph[K, N, E, G],
269
+ y: Graph[K, N, E, G],
270
+ state: State[K],
271
+ ) -> list[State[K]]:
272
+ """Expand the current state by adding all possible node mappings"""
273
+
274
+ new_states = []
275
+
276
+ for y_key, x_key in state.open_node_pairs:
277
+ if self.legal_node_mapping(x, y, state, x_key, y_key):
278
+ new_state = State(
279
+ state.mapped_nodes.set(y_key, x_key),
280
+ state.mapped_edges,
281
+ state.open_node_pairs
282
+ - {
283
+ (y, x)
284
+ for y, x in itertools.product(y.nodes.keys(), x.nodes.keys())
285
+ if y == y_key or x == x_key
286
+ },
287
+ state.open_edge_pairs,
288
+ )
289
+ new_states.append(new_state)
290
+
291
+ return new_states
292
+
293
+ def expand(
294
+ self,
295
+ x: Graph[K, N, E, G],
296
+ y: Graph[K, N, E, G],
297
+ current_state: State[K],
298
+ current_sim: StateSim[K],
299
+ expand_func: Callable[
300
+ [Graph[K, N, E, G], Graph[K, N, E, G], State[K]], list[State[K]]
301
+ ],
302
+ ) -> tuple[State[K], StateSim[K]]:
303
+ """Expand the current state by adding all possible mappings"""
304
+
305
+ while True:
306
+ # Iterate over all open pairs and find the best pair
307
+ new_states = expand_func(x, y, current_state)
308
+ new_sims = [self.compute_similarity(x, y, state) for state in new_states]
309
+
310
+ best_sim, best_state = max(
311
+ zip(new_sims, new_states, strict=True), key=lambda item: item[0].value
312
+ )
313
+
314
+ # If no better pair is found, break the loop
315
+ if best_state == current_state and best_sim == current_sim:
316
+ break
317
+
318
+ # Update the current state and similarity
319
+ current_state = best_state
320
+ current_sim = best_sim
321
+
322
+ return current_state, current_sim
323
+
324
+ def __call__(
325
+ self,
326
+ x: Graph[K, N, E, G],
327
+ y: Graph[K, N, E, G],
328
+ ) -> GraphSim[K]:
329
+ """Perform greedy graph matching of the query y against the case x"""
330
+
331
+ current_state = self.init_func(x, y)
332
+ current_sim = self.compute_similarity(x, y, current_state)
333
+
334
+ if self.start_with == "edges":
335
+ current_state, current_sim = self.expand(
336
+ x, y, current_state, current_sim, self.expand_edges
337
+ )
338
+ current_state, current_sim = self.expand(
339
+ x, y, current_state, current_sim, self.expand_nodes
340
+ )
341
+ elif self.start_with == "nodes":
342
+ current_state, current_sim = self.expand(
343
+ x, y, current_state, current_sim, self.expand_nodes
344
+ )
345
+ current_state, current_sim = self.expand(
346
+ x, y, current_state, current_sim, self.expand_edges
347
+ )
348
+ else:
349
+ raise ValueError(
350
+ f"Invalid start_with value: {self.start_with}. Expected 'nodes' or 'edges'."
351
+ )
352
+
353
+ return GraphSim(
354
+ current_sim.value,
355
+ dict(current_state.mapped_nodes),
356
+ dict(current_state.mapped_edges),
357
+ dict(current_sim.node_similarities)
358
+ if isinstance(current_sim, StateSim)
359
+ else {},
360
+ dict(current_sim.edge_similarities)
361
+ if isinstance(current_sim, StateSim)
362
+ else {},
363
+ )
@@ -31,9 +31,9 @@ class isomorphism[K, N, E, G](SimFunc[Graph[K, N, E, G], GraphSim[K]]):
31
31
  """
32
32
 
33
33
  node_sim_func: AnySimFunc[N, Float]
34
- aggregator: AggregatorFunc[Any, Float] = default_aggregator
35
- node_matcher: ElementMatcher[N] = default_element_matcher
34
+ node_matcher: ElementMatcher[N]
36
35
  edge_matcher: ElementMatcher[E] = default_element_matcher
36
+ aggregator: AggregatorFunc[Any, Float] = default_aggregator
37
37
  id_order: bool = True
38
38
  subgraph: bool = True
39
39
  induced: bool = True
@@ -1,6 +1,6 @@
1
1
  from collections.abc import Iterable, Sequence
2
2
  from dataclasses import dataclass, field
3
- from typing import Literal, cast, override
3
+ from typing import Any, Literal, cast, override
4
4
 
5
5
  from pydantic import BaseModel
6
6
 
@@ -40,8 +40,7 @@ def pydantic_to_anthropic_schema(model: type[BaseModel], description: str = "")
40
40
 
41
41
 
42
42
  with optional_dependencies():
43
- from anthropic import AsyncAnthropic
44
- from anthropic._types import NOT_GIVEN, Body, Headers, NotGiven, Query
43
+ from anthropic import NOT_GIVEN, AsyncAnthropic, NotGiven
45
44
  from anthropic.types import (
46
45
  MessageParam,
47
46
  MetadataParam,
@@ -68,9 +67,9 @@ with optional_dependencies():
68
67
  tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN
69
68
  top_k: int | NotGiven = NOT_GIVEN
70
69
  top_p: float | NotGiven = NOT_GIVEN
71
- extra_headers: Headers | None = None
72
- extra_query: Query | None = None
73
- extra_body: Body | None = None
70
+ extra_headers: Any | None = None
71
+ extra_query: Any | None = None
72
+ extra_body: Any | None = None
74
73
  timeout: float | Timeout | NotGiven | None = NOT_GIVEN
75
74
 
76
75
  @override
@@ -1,7 +1,7 @@
1
1
  from collections.abc import Sequence
2
2
  from dataclasses import dataclass, field
3
3
  from types import UnionType
4
- from typing import Literal, Union, cast, get_args, get_origin, override
4
+ from typing import Any, Literal, Union, cast, get_args, get_origin, override
5
5
 
6
6
  from pydantic import BaseModel, ValidationError
7
7
 
@@ -12,8 +12,7 @@ logger = get_logger(__name__)
12
12
 
13
13
  with optional_dependencies():
14
14
  from httpx import Timeout
15
- from openai import AsyncOpenAI, pydantic_function_tool
16
- from openai._types import NOT_GIVEN, Body, Headers, NotGiven, Query
15
+ from openai import NOT_GIVEN, AsyncOpenAI, NotGiven, pydantic_function_tool
17
16
  from openai.types.chat import (
18
17
  ChatCompletionMessageParam,
19
18
  ChatCompletionNamedToolChoiceParam,
@@ -43,9 +42,9 @@ with optional_dependencies():
43
42
  temperature: float | None = None
44
43
  top_logprobs: int | None = None
45
44
  top_p: float | None = None
46
- extra_headers: Headers | None = None
47
- extra_query: Query | None = None
48
- extra_body: Body | None = None
45
+ extra_headers: Any | None = None
46
+ extra_query: Any | None = None
47
+ extra_body: Any | None = None
49
48
  timeout: float | Timeout | None = None
50
49
 
51
50
  @override
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cbrkit
3
- Version: 0.23.2
3
+ Version: 0.24.0
4
4
  Summary: Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI
5
5
  Author-email: Mirko Lenz <mirko@mirkolenz.com>
6
6
  Project-URL: Repository, https://github.com/wi2trier/cbrkit
@@ -51,6 +51,7 @@ src/cbrkit/sim/graphs/alignment.py
51
51
  src/cbrkit/sim/graphs/astar.py
52
52
  src/cbrkit/sim/graphs/brute_force.py
53
53
  src/cbrkit/sim/graphs/common.py
54
+ src/cbrkit/sim/graphs/greedy.py
54
55
  src/cbrkit/sim/graphs/isomorphism.py
55
56
  src/cbrkit/sim/graphs/precompute.py
56
57
  src/cbrkit/synthesis/__init__.py
@@ -35,6 +35,7 @@ def test_astar():
35
35
  graph_sim = cbrkit.sim.graphs.astar.build(
36
36
  cbrkit.sim.graphs.astar.g1(node_sim),
37
37
  cbrkit.sim.graphs.astar.h2(node_sim),
38
+ node_matcher=cbrkit.sim.graphs.type_element_matcher,
38
39
  )
39
40
  retriever = cbrkit.retrieval.build(graph_sim)
40
41
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes