cbrkit 0.27.2__tar.gz → 0.28.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. {cbrkit-0.27.2 → cbrkit-0.28.0}/PKG-INFO +4 -1
  2. {cbrkit-0.27.2 → cbrkit-0.28.0}/pyproject.toml +4 -1
  3. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/__init__.py +2 -0
  4. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/api.py +1 -1
  5. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/cli.py +25 -1
  6. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/cycle.py +15 -0
  7. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/helpers.py +15 -10
  8. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/retrieval/rerank.py +12 -9
  9. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/embed.py +46 -44
  10. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/providers/__init__.py +8 -0
  11. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/providers/google.py +0 -1
  12. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/providers/model.py +26 -19
  13. cbrkit-0.28.0/src/cbrkit/synthesis/providers/openai_agents.py +65 -0
  14. cbrkit-0.28.0/src/cbrkit/synthesis/providers/pydantic_ai.py +43 -0
  15. cbrkit-0.28.0/src/cbrkit/system.py +157 -0
  16. {cbrkit-0.27.2 → cbrkit-0.28.0}/README.md +0 -0
  17. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/__main__.py +0 -0
  18. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/adapt/__init__.py +0 -0
  19. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/adapt/attribute_value.py +0 -0
  20. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/adapt/generic.py +0 -0
  21. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/adapt/numbers.py +0 -0
  22. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/adapt/strings.py +0 -0
  23. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/constants.py +0 -0
  24. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/dumpers.py +0 -0
  25. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/eval/__init__.py +0 -0
  26. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/eval/common.py +0 -0
  27. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/eval/retrieval.py +0 -0
  28. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/loaders.py +0 -0
  29. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/model/__init__.py +0 -0
  30. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/model/graph.py +0 -0
  31. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/model/result.py +0 -0
  32. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/py.typed +0 -0
  33. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/retrieval/__init__.py +0 -0
  34. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/retrieval/apply.py +0 -0
  35. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/retrieval/build.py +0 -0
  36. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/reuse/__init__.py +0 -0
  37. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/reuse/apply.py +0 -0
  38. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/reuse/build.py +0 -0
  39. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/__init__.py +0 -0
  40. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/aggregator.py +0 -0
  41. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/attribute_value.py +0 -0
  42. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/collections.py +0 -0
  43. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/generic.py +0 -0
  44. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/graphs/__init__.py +0 -0
  45. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/graphs/alignment.py +0 -0
  46. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/graphs/astar.py +0 -0
  47. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/graphs/brute_force.py +0 -0
  48. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/graphs/common.py +0 -0
  49. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/graphs/dfs.py +0 -0
  50. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/graphs/greedy.py +0 -0
  51. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/graphs/lap.py +0 -0
  52. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/graphs/precompute.py +0 -0
  53. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/graphs/qap.py +0 -0
  54. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/graphs/vf2.py +0 -0
  55. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/numbers.py +0 -0
  56. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/pooling.py +0 -0
  57. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/strings.py +0 -0
  58. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/taxonomy.py +0 -0
  59. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/sim/wrappers.py +0 -0
  60. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/__init__.py +0 -0
  61. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/apply.py +0 -0
  62. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/build.py +0 -0
  63. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/model.py +0 -0
  64. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/prompts.py +0 -0
  65. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/providers/anthropic.py +0 -0
  66. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/providers/cohere.py +0 -0
  67. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/providers/instructor.py +0 -0
  68. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/providers/ollama.py +0 -0
  69. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/providers/openai.py +0 -0
  70. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/synthesis/providers/wrappers.py +0 -0
  71. {cbrkit-0.27.2 → cbrkit-0.28.0}/src/cbrkit/typing.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: cbrkit
3
- Version: 0.27.2
3
+ Version: 0.28.0
4
4
  Summary: Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI
5
5
  Keywords: cbr,case-based reasoning,api,similarity,nlp,retrieval,cli,tool,library
6
6
  Author: Mirko Lenz
@@ -39,6 +39,7 @@ Requires-Dist: fastapi>=0.100,<1 ; extra == 'api'
39
39
  Requires-Dist: pydantic-settings>=2,<3 ; extra == 'api'
40
40
  Requires-Dist: python-multipart>=0.0.15,<1 ; extra == 'api'
41
41
  Requires-Dist: uvicorn[standard]>=0.30,<1 ; extra == 'api'
42
+ Requires-Dist: fastmcp>=2,<3 ; extra == 'api'
42
43
  Requires-Dist: chonkie>=1,<2 ; extra == 'chunking'
43
44
  Requires-Dist: typer>=0.9,<1 ; extra == 'cli'
44
45
  Requires-Dist: ranx>=0.3,<1 ; extra == 'eval'
@@ -53,6 +54,8 @@ Requires-Dist: tiktoken>=0.8,<1 ; extra == 'llm'
53
54
  Requires-Dist: anthropic>=0.40,<1 ; extra == 'llm'
54
55
  Requires-Dist: google-genai>=1,<2 ; extra == 'llm'
55
56
  Requires-Dist: instructor>=1,<2 ; extra == 'llm'
57
+ Requires-Dist: openai-agents>=0.2,<1 ; extra == 'llm'
58
+ Requires-Dist: pydantic-ai-slim>=0.4,<1 ; extra == 'llm'
56
59
  Requires-Dist: bm25s[core,stem]>=0.2,<1 ; extra == 'nlp'
57
60
  Requires-Dist: levenshtein>=0.23,<0.26 ; platform_machine == 'x86_64' and sys_platform == 'darwin' and extra == 'nlp'
58
61
  Requires-Dist: levenshtein>=0.26,<1 ; (platform_machine == 'arm64' and sys_platform == 'darwin' and extra == 'nlp') or (sys_platform == 'linux' and extra == 'nlp')
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "cbrkit"
3
- version = "0.27.2"
3
+ version = "0.28.0"
4
4
  description = "Customizable Case-Based Reasoning (CBR) toolkit for Python with a built-in API and CLI"
5
5
  authors = [{ name = "Mirko Lenz", email = "mirko@mirkolenz.com" }]
6
6
  readme = "README.md"
@@ -57,6 +57,7 @@ api = [
57
57
  "pydantic-settings>=2,<3",
58
58
  "python-multipart>=0.0.15,<1",
59
59
  "uvicorn[standard]>=0.30,<1",
60
+ "fastmcp>=2,<3",
60
61
  ]
61
62
  chunking = ["chonkie>=1,<2"]
62
63
  cli = ["typer>=0.9,<1"]
@@ -72,6 +73,8 @@ llm = [
72
73
  "anthropic>=0.40,<1",
73
74
  "google-genai>=1,<2",
74
75
  "instructor>=1,<2",
76
+ "openai-agents>=0.2,<1",
77
+ "pydantic-ai-slim>=0.4,<1",
75
78
  ]
76
79
  nlp = [
77
80
  "bm25s[core,stem]>=0.2,<1",
@@ -22,6 +22,7 @@ from . import (
22
22
  reuse,
23
23
  sim,
24
24
  synthesis,
25
+ system,
25
26
  typing,
26
27
  )
27
28
 
@@ -39,6 +40,7 @@ __all__ = [
39
40
  "synthesis",
40
41
  "typing",
41
42
  "constants",
43
+ "system",
42
44
  ]
43
45
 
44
46
  logging.getLogger(__name__).addHandler(logging.NullHandler())
@@ -56,7 +56,7 @@ def parse_dataset(obj: CasebaseSpec) -> Mapping[str, Any]:
56
56
  loader = cbrkit.loaders.structured_loaders[f".{obj.content_type}"]
57
57
  data = loader(obj.file)
58
58
  elif isinstance(obj, Path):
59
- data = cbrkit.loaders.path(obj)
59
+ data = cbrkit.loaders.file(obj)
60
60
 
61
61
  if not all(isinstance(key, str) for key in data.keys()):
62
62
  return {str(key): value for key, value in data.items()}
@@ -154,6 +154,8 @@ def serve(
154
154
  ) -> None:
155
155
  import uvicorn
156
156
 
157
+ from cbrkit.api import app
158
+
157
159
  sys.path.extend(str(x) for x in search_path)
158
160
 
159
161
  os.environ["CBRKIT_RETRIEVER"] = ",".join(retriever)
@@ -161,7 +163,29 @@ def serve(
161
163
  os.environ["CBRKIT_SYNTHESIZER"] = ",".join(synthesizer)
162
164
 
163
165
  uvicorn.run(
164
- "cbrkit.api:app",
166
+ app,
167
+ host=host,
168
+ port=port,
169
+ reload=reload,
170
+ root_path=root_path,
171
+ )
172
+
173
+
174
+ @app.command()
175
+ def uvicorn(
176
+ app: str,
177
+ search_path: Annotated[list[Path], typer.Option(default_factory=list)],
178
+ host: str = "0.0.0.0",
179
+ port: int = 8080,
180
+ reload: bool = False,
181
+ root_path: str = "",
182
+ ) -> None:
183
+ import uvicorn
184
+
185
+ sys.path.extend(str(x) for x in search_path)
186
+
187
+ uvicorn.run(
188
+ app,
165
189
  host=host,
166
190
  port=port,
167
191
  reload=reload,
@@ -10,6 +10,7 @@ from .typing import Float, MaybeFactories, RetrieverFunc, ReuserFunc
10
10
  __all__ = [
11
11
  "apply_queries",
12
12
  "apply_batches",
13
+ "apply_query",
13
14
  "Result",
14
15
  ]
15
16
 
@@ -43,3 +44,17 @@ def apply_queries[Q, C, V, S: Float](
43
44
  return Result(
44
45
  retrieval=retrieval_result, reuse=reuse_result, duration=end_time - start_time
45
46
  )
47
+
48
+
49
+ def apply_query[K, V, S: Float](
50
+ casebase: Mapping[K, V],
51
+ query: V,
52
+ retrievers: MaybeFactories[RetrieverFunc[K, V, S]],
53
+ reusers: MaybeFactories[ReuserFunc[K, V, S]],
54
+ ) -> Result[str, K, V, S]:
55
+ return apply_queries(
56
+ casebase,
57
+ {"default": query},
58
+ retrievers,
59
+ reusers,
60
+ )
@@ -115,19 +115,24 @@ class EventLoop:
115
115
 
116
116
  def close(self) -> None:
117
117
  if self._instance is not None:
118
- tasks = asyncio.all_tasks(self._instance)
119
- for task in tasks:
120
- task.cancel()
121
-
122
118
  try:
123
- self._instance.run_until_complete(
124
- asyncio.gather(*tasks, return_exceptions=True)
125
- )
126
- except asyncio.CancelledError:
119
+ tasks = asyncio.all_tasks(self._instance)
120
+ for task in tasks:
121
+ task.cancel()
122
+
123
+ if tasks:
124
+ self._instance.run_until_complete(
125
+ asyncio.gather(*tasks, return_exceptions=True)
126
+ )
127
+ except (RuntimeError, asyncio.CancelledError):
128
+ # Event loop may already be closed or unavailable
127
129
  pass
128
-
129
130
  finally:
130
- self._instance.close()
131
+ try:
132
+ self._instance.close()
133
+ except RuntimeError:
134
+ # Event loop may already be closed
135
+ pass
131
136
  self._instance = None
132
137
 
133
138
 
@@ -250,6 +250,7 @@ with optional_dependencies():
250
250
 
251
251
  language: str
252
252
  stopwords: list[str] | None = None
253
+ auto_index: bool = False
253
254
  _indexed_retriever: bm25s.BM25 | None = field(
254
255
  default=None, init=False, repr=False
255
256
  )
@@ -265,7 +266,7 @@ with optional_dependencies():
265
266
  def _stemmer(self) -> Callable[..., Any]:
266
267
  return Stemmer.Stemmer(self.language)
267
268
 
268
- def index(self, casebase: Casebase[K, str]) -> None:
269
+ def _build_retriever(self, casebase: Casebase[K, str]) -> bm25s.BM25:
269
270
  cases_tokens = bm25s.tokenize(
270
271
  list(casebase.values()),
271
272
  stemmer=self._stemmer,
@@ -273,6 +274,11 @@ with optional_dependencies():
273
274
  )
274
275
  retriever = bm25s.BM25()
275
276
  retriever.index(cases_tokens)
277
+
278
+ return retriever
279
+
280
+ def index(self, casebase: Casebase[K, str]) -> None:
281
+ retriever = self._build_retriever(casebase)
276
282
  self._indexed_retriever = retriever
277
283
  self._indexed_casebase = dict(casebase)
278
284
 
@@ -306,14 +312,11 @@ with optional_dependencies():
306
312
  if self._indexed_retriever and self._indexed_casebase == casebase:
307
313
  retriever = self._indexed_retriever
308
314
  else:
309
- cases_tokens = bm25s.tokenize(
310
- list(casebase.values()),
311
- stemmer=self._stemmer,
312
- stopwords=self._stopwords,
313
- )
314
- retriever = bm25s.BM25()
315
- retriever.index(cases_tokens)
316
- # TODO: maybe there should be an option to auto-persist on-demand indexing
315
+ retriever = self._build_retriever(casebase)
316
+
317
+ if self.auto_index:
318
+ self._indexed_retriever = retriever
319
+ self._indexed_casebase = dict(casebase)
317
320
 
318
321
  queries_tokens = bm25s.tokenize(
319
322
  cast(list[str], queries),
@@ -1,7 +1,8 @@
1
1
  import asyncio
2
2
  import itertools
3
+ import sqlite3
3
4
  from collections.abc import Iterator, MutableMapping, Sequence
4
- from contextlib import AbstractContextManager
5
+ from contextlib import AbstractContextManager, contextmanager
5
6
  from dataclasses import dataclass, field
6
7
  from pathlib import Path
7
8
  from typing import Literal, cast, override
@@ -170,71 +171,72 @@ class build[V, S: Float](BatchSimFunc[V, S]):
170
171
  class cache(BatchConversionFunc[str, NumpyArray]):
171
172
  func: BatchConversionFunc[str, NumpyArray] | None
172
173
  path: Path | None
173
- autodump: bool
174
- autoload: bool
174
+ table: str | None
175
175
  store: MutableMapping[str, NumpyArray] = field(repr=False)
176
- mtime: float = field(repr=False)
177
176
 
178
177
  def __init__(
179
178
  self,
180
179
  func: AnyConversionFunc[str, NumpyArray] | None,
181
180
  path: FilePath | None = None,
182
- autodump: bool = False,
183
- autoload: bool = False,
181
+ table: str | None = None,
184
182
  ):
185
183
  self.func = batchify_conversion(func) if func is not None else None
186
184
  self.path = Path(path) if isinstance(path, str) else path
187
- self.autodump = autodump
188
- self.autoload = autoload
189
- self.mtime = 0
185
+ self.table = table
186
+ self.store = {}
190
187
 
191
- if self.path and self.path.exists():
192
- self.load()
193
- else:
194
- self.store = {}
188
+ if self.path is not None:
189
+ if self.table is None:
190
+ raise ValueError("Table name must be specified for disk cache")
195
191
 
196
- def dump(self) -> None:
197
- if not self.path:
198
- raise ValueError("Path not provided")
192
+ self.path.parent.mkdir(parents=True, exist_ok=True)
199
193
 
200
- if not self.store:
201
- logger.warning("Cache is empty, skipping dump")
202
- return
203
-
204
- if self.path.exists() and self.mtime < self.path.stat().st_mtime:
205
- logger.warning("Cache file has been modified, skipping dump")
206
- return
207
-
208
- np.savez_compressed(self.path, **self.store)
209
- self.mtime = self.path.stat().st_mtime
194
+ with self.connect() as connection:
195
+ connection.execute(f"""
196
+ CREATE TABLE IF NOT EXISTS "{self.table}" (
197
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
198
+ text TEXT NOT NULL UNIQUE,
199
+ vector BLOB NOT NULL
200
+ )
201
+ """)
210
202
 
211
- def load(self) -> None:
212
- if not self.path:
213
- raise ValueError("Path not provided")
203
+ cursor = connection.execute(f'SELECT text, vector FROM "{self.table}"')
214
204
 
215
- if not self.path.exists():
216
- raise FileNotFoundError(f"Cache file '{self.path}' does not exist")
205
+ for text, vector_blob in cursor:
206
+ self.store[text] = np.frombuffer(vector_blob, dtype=np.float64)
217
207
 
218
- mtime = self.path.stat().st_mtime
208
+ @contextmanager
209
+ def connect(self) -> Iterator[sqlite3.Connection]:
210
+ if self.path is None:
211
+ raise ValueError("Path must be set to use the cache")
219
212
 
220
- if self.mtime < mtime:
221
- self.mtime = mtime
213
+ connection = sqlite3.connect(self.path)
222
214
 
223
- with np.load(self.path) as data:
224
- self.store = dict(data)
215
+ try:
216
+ yield connection
217
+ finally:
218
+ connection.close()
225
219
 
220
+ @override
226
221
  def __call__(self, texts: Sequence[str]) -> Sequence[NumpyArray]:
227
- new_texts = [text for text in texts if text not in self.store]
222
+ # remove store entries and duplicates
223
+ new_texts = list({text for text in texts if text not in self.store})
228
224
 
229
225
  if new_texts:
230
- if self.autoload:
231
- self.load()
232
-
233
226
  if self.func:
234
- self.store.update(zip(new_texts, self.func(new_texts), strict=True))
235
-
236
- if self.autodump:
237
- self.dump()
227
+ new_vectors = self.func(new_texts)
228
+
229
+ for text, vector in zip(new_texts, new_vectors, strict=True):
230
+ self.store[text] = vector
231
+
232
+ if self.path is not None:
233
+ with self.connect() as connection:
234
+ for text, vector in zip(new_texts, new_vectors, strict=True):
235
+ vector_blob = vector.astype(np.float64).tobytes()
236
+ connection.execute(
237
+ f'INSERT OR IGNORE INTO "{self.table}" (text, vector) VALUES (?, ?)',
238
+ (text, vector_blob),
239
+ )
238
240
 
239
241
  return [self.store[text] for text in texts]
240
242
 
@@ -1,5 +1,6 @@
1
1
  from ...helpers import optional_dependencies
2
2
  from .model import (
3
+ AsyncProvider,
3
4
  BaseProvider,
4
5
  ChatMessage,
5
6
  ChatPrompt,
@@ -20,6 +21,10 @@ with optional_dependencies():
20
21
  from .anthropic import anthropic
21
22
  with optional_dependencies():
22
23
  from .instructor import instructor
24
+ with optional_dependencies():
25
+ from .pydantic_ai import pydantic_ai
26
+ with optional_dependencies():
27
+ from .openai_agents import openai_agents
23
28
 
24
29
  __all__ = [
25
30
  "openai",
@@ -27,6 +32,7 @@ __all__ = [
27
32
  "cohere",
28
33
  "conversation",
29
34
  "pipe",
35
+ "AsyncProvider",
30
36
  "BaseProvider",
31
37
  "ChatProvider",
32
38
  "ChatMessage",
@@ -36,4 +42,6 @@ __all__ = [
36
42
  "Usage",
37
43
  "anthropic",
38
44
  "instructor",
45
+ "pydantic_ai",
46
+ "openai_agents",
39
47
  ]
@@ -14,7 +14,6 @@ with optional_dependencies():
14
14
 
15
15
  @dataclass(slots=True)
16
16
  class google[R: BaseModel | str](BaseProvider[GooglePrompt, R]):
17
- system_message: str | None = None
18
17
  client: Client = field(default_factory=Client, repr=False)
19
18
  config: GenerateContentConfig = field(init=False)
20
19
  base_config: InitVar[GenerateContentConfig | None] = None
@@ -2,7 +2,7 @@ import asyncio
2
2
  from abc import ABC, abstractmethod
3
3
  from collections.abc import Mapping, Sequence
4
4
  from dataclasses import dataclass, field
5
- from typing import Any, Literal
5
+ from typing import Any, Literal, override
6
6
 
7
7
  from pydantic import Field
8
8
 
@@ -44,19 +44,12 @@ class Response[T](StructuredValue[T]):
44
44
 
45
45
 
46
46
  @dataclass(slots=True, kw_only=True)
47
- class BaseProvider[P, R](BatchConversionFunc[P, Response[R]], ABC):
48
- model: str
49
- response_type: type[R]
50
- delay: float = 0
51
- retries: int = 0
52
- default_response: R | None = None
53
- extra_kwargs: Mapping[str, Any] = field(default_factory=dict)
54
-
55
- def __call__(self, batches: Sequence[P]) -> Sequence[Response[R]]:
47
+ class AsyncProvider[P, R](BatchConversionFunc[P, R], ABC):
48
+ def __call__(self, batches: Sequence[P]) -> Sequence[R]:
56
49
  return event_loop.get().run_until_complete(self.__call_batches__(batches))
57
50
 
58
- async def __call_batches__(self, batches: Sequence[P]) -> Sequence[Response[R]]:
59
- logger.info(f"Processing {len(batches)} batches with {self.model}")
51
+ async def __call_batches__(self, batches: Sequence[P]) -> Sequence[R]:
52
+ logger.info(f"Processing {len(batches)} batches")
60
53
 
61
54
  return await asyncio.gather(
62
55
  *(
@@ -65,6 +58,26 @@ class BaseProvider[P, R](BatchConversionFunc[P, Response[R]], ABC):
65
58
  )
66
59
  )
67
60
 
61
+ async def __call_batch_wrapper__(self, prompt: P, idx: int) -> R:
62
+ result = await self.__call_batch__(prompt)
63
+ logger.debug(f"Result of batch {idx + 1}: {result}")
64
+ return result
65
+
66
+ @abstractmethod
67
+ async def __call_batch__(self, prompt: P) -> R: ...
68
+
69
+
70
+ @dataclass(slots=True, kw_only=True)
71
+ class BaseProvider[P, R](AsyncProvider[P, Response[R]], ABC):
72
+ model: str
73
+ response_type: type[R]
74
+ system_message: str | None = None
75
+ delay: float = 0
76
+ retries: int = 0
77
+ default_response: R | None = None
78
+ extra_kwargs: Mapping[str, Any] = field(default_factory=dict)
79
+
80
+ @override
68
81
  async def __call_batch_wrapper__(
69
82
  self, prompt: P, idx: int, retry: int = 0
70
83
  ) -> Response[R]:
@@ -72,9 +85,7 @@ class BaseProvider[P, R](BatchConversionFunc[P, Response[R]], ABC):
72
85
  await asyncio.sleep(idx * self.delay)
73
86
 
74
87
  try:
75
- result = await self.__call_batch__(prompt)
76
- logger.debug(f"Result of batch {idx + 1}: {result}")
77
- return result
88
+ return await super(BaseProvider, self).__call_batch_wrapper__(prompt, idx)
78
89
 
79
90
  except Exception as e:
80
91
  if retry < self.retries:
@@ -87,11 +98,7 @@ class BaseProvider[P, R](BatchConversionFunc[P, Response[R]], ABC):
87
98
 
88
99
  raise e
89
100
 
90
- @abstractmethod
91
- async def __call_batch__(self, prompt: P) -> Response[R]: ...
92
-
93
101
 
94
102
  @dataclass(slots=True, kw_only=True)
95
103
  class ChatProvider[P, R](BaseProvider[P, R], ABC):
96
- system_message: str | None = None
97
104
  messages: Sequence[ChatMessage] = field(default_factory=tuple)
@@ -0,0 +1,65 @@
1
+ from dataclasses import dataclass
2
+ from functools import partial
3
+ from typing import cast, override
4
+ from uuid import uuid1
5
+
6
+ from cbrkit import helpers
7
+ from cbrkit.typing import MaybeSequence
8
+
9
+ from ...helpers import get_logger, optional_dependencies
10
+ from .model import AsyncProvider
11
+
12
+ logger = get_logger(__name__)
13
+
14
+ with optional_dependencies():
15
+ from agents import (
16
+ Agent,
17
+ RunConfig,
18
+ RunHooks,
19
+ Runner,
20
+ RunResult,
21
+ SQLiteSession,
22
+ TResponseInputItem,
23
+ )
24
+ from agents.run import DEFAULT_MAX_TURNS
25
+
26
+ type OpenaiAgentsPrompt = str | list[TResponseInputItem]
27
+
28
+ # the output is any in the base class, so we override it here
29
+ class TypedRunResult[R](RunResult):
30
+ final_output: R
31
+
32
+ @dataclass(slots=True)
33
+ class openai_agents[T, R](AsyncProvider[OpenaiAgentsPrompt, TypedRunResult[R]]):
34
+ agents: MaybeSequence[Agent[T]]
35
+ context: T | None = None
36
+ max_turns: int = DEFAULT_MAX_TURNS
37
+ hooks: RunHooks[T] | None = None
38
+ run_config: RunConfig | None = None
39
+
40
+ @override
41
+ async def __call_batch__(self, prompt: OpenaiAgentsPrompt) -> TypedRunResult[R]:
42
+ agents = helpers.produce_sequence(self.agents)
43
+
44
+ if not agents:
45
+ raise ValueError("No agents given.")
46
+
47
+ head_agent, *tail_agents = agents
48
+
49
+ session = SQLiteSession(uuid1().hex) if len(agents) > 1 else None
50
+
51
+ run = partial(
52
+ Runner.run,
53
+ context=self.context,
54
+ max_turns=self.max_turns,
55
+ hooks=self.hooks,
56
+ run_config=self.run_config,
57
+ session=session,
58
+ )
59
+
60
+ response: RunResult = await run(head_agent, prompt)
61
+
62
+ for agent in tail_agents:
63
+ response = await run(agent, [])
64
+
65
+ return cast(TypedRunResult[R], response)
@@ -0,0 +1,43 @@
1
+ from collections.abc import Sequence
2
+ from dataclasses import dataclass
3
+ from typing import override
4
+
5
+ from cbrkit import helpers
6
+ from cbrkit.typing import MaybeSequence
7
+
8
+ from ...helpers import get_logger, optional_dependencies
9
+ from .model import AsyncProvider
10
+
11
+ logger = get_logger(__name__)
12
+
13
+ with optional_dependencies():
14
+ from pydantic_ai.agent import Agent, AgentRunResult
15
+ from pydantic_ai.messages import UserContent
16
+
17
+ type PydanticAiPrompt = str | Sequence[UserContent]
18
+
19
+ @dataclass(slots=True)
20
+ class pydantic_ai[T, R](AsyncProvider[PydanticAiPrompt, AgentRunResult[R]]):
21
+ agents: MaybeSequence[Agent[T, R]]
22
+ deps: T
23
+
24
+ @override
25
+ async def __call_batch__(self, prompt: PydanticAiPrompt) -> AgentRunResult[R]:
26
+ agents = helpers.produce_sequence(self.agents)
27
+
28
+ if not agents:
29
+ raise ValueError("No agents given.")
30
+
31
+ head_agent, *tail_agents = agents
32
+
33
+ response: AgentRunResult[R] = await head_agent.run(prompt, deps=self.deps)
34
+
35
+ for agent in tail_agents:
36
+ response = await agent.run(
37
+ # inject the system prompt because the default is not used if message history is provided
38
+ agent._system_prompts,
39
+ deps=self.deps,
40
+ message_history=response.all_messages() if response else None,
41
+ )
42
+
43
+ return response
@@ -0,0 +1,157 @@
1
+ from collections.abc import Callable, Mapping, Sequence
2
+ from dataclasses import dataclass, field
3
+
4
+ from pydantic import BaseModel
5
+ from typing_extensions import Any
6
+
7
+ import cbrkit
8
+ from cbrkit.helpers import produce_sequence
9
+ from cbrkit.typing import Float, MaybeSequence
10
+
11
+ __all__ = [
12
+ "System",
13
+ "to_fastapi",
14
+ "to_fastmcp",
15
+ "to_pydantic_ai",
16
+ ]
17
+
18
+
19
+ @dataclass(slots=True, frozen=True)
20
+ class System[K: str | int, V: BaseModel, S: Float]:
21
+ casebase: cbrkit.typing.Casebase[K, V]
22
+ model: type[V]
23
+ retriever_pipelines: Mapping[
24
+ str, MaybeSequence[cbrkit.typing.RetrieverFunc[K, V, S]]
25
+ ] = field(default_factory=dict)
26
+ reuser_pipelines: Mapping[str, MaybeSequence[cbrkit.typing.ReuserFunc[K, V, S]]] = (
27
+ field(default_factory=dict)
28
+ )
29
+
30
+ def get_retriever_pipeline(
31
+ self, name: str, limit: int | None
32
+ ) -> Sequence[cbrkit.typing.RetrieverFunc[K, V, S]]:
33
+ retrievers = produce_sequence(self.retriever_pipelines[name])
34
+
35
+ if limit is not None:
36
+ *head_retrievers, tail_retriever = retrievers
37
+ retrievers = head_retrievers + [
38
+ cbrkit.retrieval.dropout(tail_retriever, limit=limit)
39
+ ]
40
+
41
+ return retrievers
42
+
43
+ def retrieve(
44
+ self,
45
+ query: V,
46
+ retriever_pipeline: str,
47
+ limit: int | None = None,
48
+ ) -> cbrkit.retrieval.QueryResultStep[K, V, S]:
49
+ return cbrkit.retrieval.apply_query(
50
+ self.casebase,
51
+ query,
52
+ self.get_retriever_pipeline(retriever_pipeline, limit),
53
+ ).default_query
54
+
55
+ def reuse(
56
+ self,
57
+ query: V,
58
+ reuser_pipeline: str,
59
+ ) -> cbrkit.retrieval.QueryResultStep[K, V, S]:
60
+ return cbrkit.reuse.apply_query(
61
+ self.casebase,
62
+ query,
63
+ self.reuser_pipelines[reuser_pipeline],
64
+ ).default_query
65
+
66
+ def cycle(
67
+ self,
68
+ query: V,
69
+ retriever_pipeline: str,
70
+ reuser_pipeline: str,
71
+ limit: int | None = None,
72
+ ) -> cbrkit.retrieval.QueryResultStep[K, V, S]:
73
+ return cbrkit.cycle.apply_query(
74
+ self.casebase,
75
+ query,
76
+ self.get_retriever_pipeline(retriever_pipeline, limit),
77
+ self.reuser_pipelines[reuser_pipeline],
78
+ ).final_step.default_query
79
+
80
+ @property
81
+ def tools(self) -> list[Callable[..., Any]]:
82
+ res: list[Callable[..., Any]] = []
83
+
84
+ if self.retriever_pipelines:
85
+ res.append(self.retrieve)
86
+
87
+ if self.reuser_pipelines:
88
+ res.append(self.reuse)
89
+
90
+ if self.retriever_pipelines and self.reuser_pipelines:
91
+ res.append(self.cycle)
92
+
93
+ return res
94
+
95
+ def get_case(self, name: K) -> V:
96
+ return self.casebase[name]
97
+
98
+ def get_retriever_names(self) -> list[str]:
99
+ return list(self.retriever_pipelines.keys())
100
+
101
+ def get_reuser_names(self) -> list[str]:
102
+ return list(self.reuser_pipelines.keys())
103
+
104
+ @property
105
+ def resources(self) -> dict[str, Callable[..., Any]]:
106
+ return {
107
+ "casebase/{name}": self.get_case,
108
+ "pipelines/retrieve": self.get_retriever_names,
109
+ "pipelines/reuse": self.get_reuser_names,
110
+ }
111
+
112
+ @property
113
+ def prompts(self) -> list[Callable[..., Any]]:
114
+ return []
115
+
116
+
117
+ with cbrkit.helpers.optional_dependencies():
118
+ from fastapi import FastAPI
119
+
120
+ def to_fastapi(system: System) -> FastAPI:
121
+ app = FastAPI()
122
+
123
+ for value in system.tools:
124
+ app.post(f"/tool/{value.__name__}")(value)
125
+
126
+ for key, value in system.resources.items():
127
+ app.get(f"/resource/{key}")(value)
128
+
129
+ for value in system.prompts:
130
+ app.post(f"/prompt/{value.__name__}")(value)
131
+
132
+ return app
133
+
134
+
135
+ with cbrkit.helpers.optional_dependencies():
136
+ from fastmcp import FastMCP
137
+
138
+ def to_fastmcp(system: System) -> FastMCP[Any]:
139
+ app = FastMCP()
140
+
141
+ for value in system.tools:
142
+ app.tool(value)
143
+
144
+ for key, value in system.resources.items():
145
+ app.resource(f"cbrkit://{key}")(value)
146
+
147
+ for value in system.prompts:
148
+ app.prompt(value)
149
+
150
+ return app
151
+
152
+
153
+ with cbrkit.helpers.optional_dependencies():
154
+ from pydantic_ai.toolsets import FunctionToolset
155
+
156
+ def to_pydantic_ai(system: System) -> FunctionToolset[Any]:
157
+ return FunctionToolset(system.tools)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes