spinekit 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spine_a2a/__init__.py +17 -0
- spine_a2a/client.py +101 -0
- spine_a2a/py.typed +0 -0
- spine_backends/__init__.py +27 -0
- spine_backends/embeddings.py +68 -0
- spine_backends/memory.py +99 -0
- spine_backends/migrations.py +39 -0
- spine_backends/pgvector.py +140 -0
- spine_backends/postgres.py +85 -0
- spine_backends/py.typed +0 -0
- spine_backends/redis.py +58 -0
- spine_backends/sqlite.py +103 -0
- spine_cli/__init__.py +5 -0
- spine_cli/app.py +363 -0
- spine_cli/builder.py +86 -0
- spine_cli/config.py +92 -0
- spine_cli/plugins.py +54 -0
- spine_cli/py.typed +0 -0
- spine_cli/templates.py +122 -0
- spine_core/__init__.py +117 -0
- spine_core/agent.py +540 -0
- spine_core/checkpoint.py +39 -0
- spine_core/control.py +25 -0
- spine_core/errors.py +23 -0
- spine_core/guards.py +45 -0
- spine_core/interrupt.py +24 -0
- spine_core/memory.py +48 -0
- spine_core/messages.py +123 -0
- spine_core/middleware.py +157 -0
- spine_core/provider.py +103 -0
- spine_core/py.typed +0 -0
- spine_core/registry.py +76 -0
- spine_core/result.py +55 -0
- spine_core/state.py +58 -0
- spine_core/testing.py +87 -0
- spine_core/tools.py +147 -0
- spine_core/trace.py +59 -0
- spine_eval/__init__.py +42 -0
- spine_eval/loader.py +37 -0
- spine_eval/models.py +120 -0
- spine_eval/py.typed +0 -0
- spine_eval/runner.py +71 -0
- spine_eval/scorers.py +132 -0
- spine_mcp/__init__.py +17 -0
- spine_mcp/py.typed +0 -0
- spine_mcp/toolset.py +126 -0
- spine_middleware/__init__.py +72 -0
- spine_middleware/cache.py +80 -0
- spine_middleware/compaction.py +39 -0
- spine_middleware/cost.py +35 -0
- spine_middleware/fallback.py +30 -0
- spine_middleware/guardrails.py +170 -0
- spine_middleware/loopguard.py +43 -0
- spine_middleware/memory.py +66 -0
- spine_middleware/multitenancy.py +52 -0
- spine_middleware/py.typed +0 -0
- spine_middleware/reliability.py +120 -0
- spine_middleware/replay.py +63 -0
- spine_middleware/retry.py +43 -0
- spine_middleware/sandbox.py +99 -0
- spine_middleware/structured.py +79 -0
- spine_middleware/tooling.py +43 -0
- spine_orchestration/__init__.py +7 -0
- spine_orchestration/patterns.py +106 -0
- spine_orchestration/py.typed +0 -0
- spine_otel/__init__.py +15 -0
- spine_otel/middleware.py +150 -0
- spine_otel/py.typed +0 -0
- spine_providers/__init__.py +15 -0
- spine_providers/anthropic.py +258 -0
- spine_providers/openai.py +273 -0
- spine_providers/py.typed +0 -0
- spinekit-0.2.0.dist-info/METADATA +149 -0
- spinekit-0.2.0.dist-info/RECORD +77 -0
- spinekit-0.2.0.dist-info/WHEEL +4 -0
- spinekit-0.2.0.dist-info/entry_points.txt +29 -0
- spinekit-0.2.0.dist-info/licenses/LICENSE +21 -0
spine_a2a/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""A2A (agent-to-agent) adapter for Spine.
|
|
2
|
+
|
|
3
|
+
```python
|
|
4
|
+
from spine_core import Agent
|
|
5
|
+
from spine_a2a import A2AAgent
|
|
6
|
+
|
|
7
|
+
async with A2AAgent("https://remote.example.com/a2a", name="researcher") as remote:
|
|
8
|
+
agent = Agent("anthropic:claude-sonnet-4-6", tools=[remote.as_tool()])
|
|
9
|
+
print((await agent.run("ask the researcher about otters")).answer)
|
|
10
|
+
```
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from spine_a2a.client import A2AAgent
|
|
16
|
+
|
|
17
|
+
# Public API of the package: the remote-agent handle re-exported from spine_a2a.client.
__all__ = ["A2AAgent"]
|
spine_a2a/client.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""A2A (agent-to-agent) client — call a remote agent and mount it as a tool.
|
|
2
|
+
|
|
3
|
+
Spine consumes the open A2A protocol rather than a proprietary handoff format.
|
|
4
|
+
A remote agent is reached over JSON-RPC (``message/send``); ``as_tool`` wraps it
|
|
5
|
+
so a local agent can delegate to it like any other tool.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import uuid
from typing import Any

import httpx

from spine_core import Tool, raw_tool
|
|
15
|
+
|
|
16
|
+
# JSON Schema for the tool produced by ``A2AAgent.as_tool``: a single required
# string argument named ``input`` that is forwarded to the remote agent.
_INPUT_SCHEMA = dict(
    type="object",
    properties=dict(
        input=dict(type="string", description="Message for the remote agent."),
    ),
    required=["input"],
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _parts_text(parts: Any) -> str:
    """Join the non-empty string ``text`` fields of an A2A ``parts`` list with newlines.

    Non-list ``parts``, non-dict entries and non-string ``text`` values are skipped,
    so parts without text (e.g. data or file parts) contribute nothing.
    """
    if not isinstance(parts, list):
        return ""
    texts = [part.get("text") for part in parts if isinstance(part, dict)]
    return "\n".join(t for t in texts if isinstance(t, str) and t)


def _extract_text(data: Any) -> str:
    """Pull text out of an A2A JSON-RPC response, tolerant of Message/Task shapes.

    Lookup order:

    1. a truthy JSON-RPC ``error`` -> ``"A2A error: <message>"`` (the whole error
       value is used when it is not an object with a ``message``);
    2. ``result.parts`` (a ``Message`` result);
    3. the first ``result.artifacts[*].parts`` that yields text (a ``Task`` result);
    4. ``result.status.message.parts`` (the task's status message).

    ``result`` defaults to the whole body when the ``result`` key is absent.

    Args:
        data: The decoded JSON response body.

    Returns:
        The extracted text, an ``"A2A error: ..."`` string, or ``""`` when the body
        is malformed or carries no text.
    """
    if not isinstance(data, dict):
        return ""
    error = data.get("error")
    if error:
        # JSON-RPC error objects carry a "message", but some servers send bare values.
        detail = error.get("message", error) if isinstance(error, dict) else error
        return f"A2A error: {detail}"

    result = data.get("result", data)
    if not isinstance(result, dict):
        return ""

    # Direct Message result.
    text = _parts_text(result.get("parts"))
    if text:
        return text

    # Task result: first artifact that carries text.
    artifacts = result.get("artifacts")
    for artifact in artifacts if isinstance(artifacts, list) else []:
        if isinstance(artifact, dict):
            text = _parts_text(artifact.get("parts"))
            if text:
                return text

    # Task status message as the last resort.
    status = result.get("status")
    message = status.get("message") if isinstance(status, dict) else None
    if isinstance(message, dict):
        return _parts_text(message.get("parts"))
    return ""
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class A2AAgent:
    """A handle to a remote A2A agent reached over JSON-RPC.

    Args:
        url: JSON-RPC endpoint of the remote agent.
        client: Optional ``httpx.AsyncClient`` to reuse. An injected client is
            never closed by this object. Without one, a client is created lazily
            (using ``timeout``) and closed by :meth:`aclose` or on ``async with`` exit.
        name: Tool name used by :meth:`as_tool` (default ``"remote_agent"``).
        description: Tool description used by :meth:`as_tool`.
        timeout: Timeout in seconds for the client created by this object.
    """

    def __init__(
        self,
        url: str,
        *,
        client: httpx.AsyncClient | None = None,
        name: str | None = None,
        description: str | None = None,
        timeout: float = 60.0,
    ) -> None:
        self.url = url
        self.name = name or "remote_agent"
        self.description = description or "Delegate a task to a remote A2A agent."
        self._client = client
        # Only clients created here are closed in aclose().
        self._owns_client = client is None
        self._timeout = timeout

    def _ensure_client(self) -> httpx.AsyncClient:
        """Return the HTTP client, creating an owned one on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    async def send(self, text: str) -> str:
        """Send ``text`` as a user message via ``message/send`` and return the reply text.

        JSON-RPC errors are returned as ``"A2A error: ..."`` text rather than raised,
        so a delegating agent sees the failure as a tool result.

        Raises:
            httpx.HTTPStatusError: if the server answers with a 4xx/5xx status.
        """
        message = {
            "kind": "message",
            # The A2A spec marks Message.messageId as required; generate a fresh one
            # per request.
            "messageId": uuid.uuid4().hex,
            "role": "user",
            "parts": [{"kind": "text", "text": text}],
        }
        payload = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "message/send",
            "params": {"message": message},
        }
        response = await self._ensure_client().post(self.url, json=payload)
        response.raise_for_status()
        return _extract_text(response.json())

    def as_tool(self, *, name: str | None = None, description: str | None = None) -> Tool:
        """Wrap :meth:`send` as a tool taking a single ``input`` string.

        ``name``/``description`` override the values given to the constructor.
        """

        async def call(input: str) -> str:  # noqa: A002 - name fixed by _INPUT_SCHEMA
            return await self.send(input)

        return raw_tool(name or self.name, description or self.description, _INPUT_SCHEMA, call)

    async def aclose(self) -> None:
        """Close the HTTP client if this object created it; injected clients are left open."""
        if self._client is not None and self._owns_client:
            await self._client.aclose()
            self._client = None

    async def __aenter__(self) -> A2AAgent:
        return self

    async def __aexit__(self, *exc: object) -> None:
        await self.aclose()
|
spine_a2a/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Spine storage backends.
|
|
2
|
+
|
|
3
|
+
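Importing the package registers every bundled backend by name: the ``sqlite``,
``postgres`` and ``redis`` checkpoints and the ``vector``, ``buffer`` and
``pgvector`` memories. Entries in ``spine.toml`` such as
``checkpoint = "sqlite"`` therefore resolve.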
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from spine_backends.embeddings import HashEmbedder, OpenAIEmbedder
|
|
10
|
+
from spine_backends.memory import BufferMemory, InMemoryVectorMemory
|
|
11
|
+
from spine_backends.migrations import register_migration
|
|
12
|
+
from spine_backends.pgvector import PgVectorMemory
|
|
13
|
+
from spine_backends.postgres import PostgresCheckpoint
|
|
14
|
+
from spine_backends.redis import RedisCheckpoint
|
|
15
|
+
from spine_backends.sqlite import SQLiteCheckpoint
|
|
16
|
+
|
|
17
|
+
# Public API: embedders, memory stores, checkpoint stores and the migration hook.
__all__ = [
    "BufferMemory",
    "HashEmbedder",
    "InMemoryVectorMemory",
    "OpenAIEmbedder",
    "PgVectorMemory",
    "PostgresCheckpoint",
    "RedisCheckpoint",
    "SQLiteCheckpoint",
    "register_migration",
]
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Embedders — pluggable text->vector backends implementing core ``Embedder``.
|
|
2
|
+
|
|
3
|
+
``HashEmbedder`` is dependency-free and offline (a good default and test double);
|
|
4
|
+
``OpenAIEmbedder`` calls a real embedding model. Any object with
|
|
5
|
+
``async def embed(text) -> list[float]`` works, so users can bring their own
|
|
6
|
+
(sentence-transformers, Cohere, local models, …).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import hashlib
|
|
12
|
+
import math
|
|
13
|
+
import re
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
# A run of word characters (letters, digits, underscore).
_WORD = re.compile(r"\w+")


def _features(text: str) -> list[str]:
    """Return hashing features for ``text``: lower-cased words, then character trigrams.

    Trigrams are taken over the whole lower-cased string (spaces included), so
    texts shorter than three characters contribute words only.
    """
    lowered = text.lower()
    features = _WORD.findall(lowered)
    features.extend(lowered[start : start + 3] for start in range(len(lowered) - 2))
    return features
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class HashEmbedder:
    """Deterministic, offline hashed bag-of-features embedding (L2-normalized).

    Each feature (lower-cased word or character trigram) is hashed with MD5 --
    used purely as a stable, non-cryptographic hash -- into one of ``dim``
    buckets. The bucket counts are then scaled to unit length. Text with no
    features (e.g. ``""``) yields the all-zero vector.

    Good for tests and small/offline deployments; not as expressive as a learned
    model. Swap for ``OpenAIEmbedder`` (or your own) in production.

    Args:
        dim: Number of buckets (vector length). Must be a positive integer.

    Raises:
        ValueError: if ``dim`` is not positive.
    """

    def __init__(self, dim: int = 256) -> None:
        # Fail fast: dim <= 0 would otherwise surface later as a ZeroDivisionError
        # (modulo by zero) or IndexError on the first embed() call.
        if dim <= 0:
            raise ValueError(f"dim must be a positive integer, got {dim!r}")
        self.dim = dim

    async def embed(self, text: str) -> list[float]:
        """Return the unit-length hashed feature vector for ``text`` (length ``dim``)."""
        vec = [0.0] * self.dim
        for feature in _features(text):
            # usedforsecurity=False keeps MD5 usable on FIPS-restricted builds; the
            # digest, and therefore every bucket index, is unchanged.
            digest = hashlib.md5(feature.encode(), usedforsecurity=False).hexdigest()
            vec[int(digest, 16) % self.dim] += 1.0
        norm = math.sqrt(sum(v * v for v in vec))
        return [v / norm for v in vec] if norm else vec
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class OpenAIEmbedder:
|
|
46
|
+
"""Embeds via the OpenAI embeddings API (lazy client; injectable for tests)."""
|
|
47
|
+
|
|
48
|
+
def __init__(
|
|
49
|
+
self,
|
|
50
|
+
model: str = "text-embedding-3-small",
|
|
51
|
+
*,
|
|
52
|
+
client: Any = None,
|
|
53
|
+
api_key: str | None = None,
|
|
54
|
+
) -> None:
|
|
55
|
+
self.model = model
|
|
56
|
+
self._client = client
|
|
57
|
+
self._api_key = api_key
|
|
58
|
+
|
|
59
|
+
def _ensure_client(self) -> Any:
|
|
60
|
+
if self._client is None:
|
|
61
|
+
import openai
|
|
62
|
+
|
|
63
|
+
self._client = openai.AsyncOpenAI(api_key=self._api_key)
|
|
64
|
+
return self._client
|
|
65
|
+
|
|
66
|
+
async def embed(self, text: str) -> list[float]:
|
|
67
|
+
response = await self._ensure_client().embeddings.create(model=self.model, input=text)
|
|
68
|
+
return list(response.data[0].embedding)
|
spine_backends/memory.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""In-process memory backends — semantic vector recall and simple recency buffer.
|
|
2
|
+
|
|
3
|
+
Both implement the core ``Memory`` protocol. ``InMemoryVectorMemory`` takes any
|
|
4
|
+
``Embedder`` (default :class:`HashEmbedder`), so users choose how text is
|
|
5
|
+
embedded. ``BufferMemory`` is non-semantic recency recall for the simple case.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import uuid
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from spine_backends.embeddings import HashEmbedder
|
|
14
|
+
from spine_core.memory import Embedder, MemoryHit, MemoryRecord
|
|
15
|
+
from spine_core.registry import register_memory
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _cosine(a: list[float], b: list[float]) -> float:
|
|
19
|
+
return sum(x * y for x, y in zip(a, b, strict=True)) # inputs are L2-normalized
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class InMemoryVectorMemory:
    """Process-local vector memory; recall by embedding cosine similarity.

    Records and their embeddings are kept in a list on the instance, so nothing
    survives the process and nothing is shared between processes.

    Args:
        embedder: Any object with ``async embed(text) -> list[float]``.
            Defaults to ``HashEmbedder(dim)``.
        dim: Dimension for the default embedder; unused when ``embedder`` is given.
    """

    def __init__(self, *, embedder: Embedder | None = None, dim: int = 256) -> None:
        self.embedder: Embedder = embedder or HashEmbedder(dim)
        self.dim = dim
        # (record, embedding) pairs in insertion order.
        self._records: list[tuple[MemoryRecord, list[float]]] = []

    async def save(
        self,
        content: str,
        *,
        session_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> MemoryRecord:
        """Embed ``content``, store it under a fresh random hex id and return the record."""
        record = MemoryRecord(
            id=uuid.uuid4().hex, content=content, session_id=session_id, metadata=metadata or {}
        )
        self._records.append((record, await self.embedder.embed(content)))
        return record

    async def search(
        self, query: str, *, k: int = 5, session_id: str | None = None
    ) -> list[MemoryHit]:
        """Return up to ``k`` hits ranked by similarity to ``query``, best first.

        ``session_id=None`` searches every session. A non-positive ``k`` returns
        no hits without embedding the query.
        """
        # Guard explicitly: hits[:k] with a negative k would drop only the worst matches.
        if k <= 0:
            return []
        qv = await self.embedder.embed(query)
        hits = [
            MemoryHit(record=record, score=_cosine(qv, vec))
            for record, vec in self._records
            if session_id is None or record.session_id == session_id
        ]
        hits.sort(key=lambda h: h.score, reverse=True)
        return hits[:k]

    async def load(self, session_id: str, *, limit: int = 20) -> list[MemoryRecord]:
        """Return the last ``limit`` records saved for ``session_id``, oldest first.

        A non-positive ``limit`` returns an empty list.
        """
        # records[-0:] would return everything, so limit <= 0 is handled up front.
        if limit <= 0:
            return []
        records = [r for r, _ in self._records if r.session_id == session_id]
        return records[-limit:]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class BufferMemory:
    """Non-semantic recency memory: ``search`` returns the most recent records.

    Cheap and predictable when similarity is not needed (e.g. a rolling notes
    buffer). ``search`` ignores the query text and scores every hit ``1.0``.
    """

    def __init__(self) -> None:
        # Records in insertion order (oldest first).
        self._records: list[MemoryRecord] = []

    async def save(
        self,
        content: str,
        *,
        session_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> MemoryRecord:
        """Store ``content`` under a fresh random hex id and return the record."""
        record = MemoryRecord(
            id=uuid.uuid4().hex, content=content, session_id=session_id, metadata=metadata or {}
        )
        self._records.append(record)
        return record

    async def search(
        self, query: str, *, k: int = 5, session_id: str | None = None
    ) -> list[MemoryHit]:
        """Return the ``k`` most recent records (newest first), ignoring ``query``.

        ``session_id=None`` considers every session; a non-positive ``k`` returns
        no hits.
        """
        # pool[-0:] would return the whole pool, so k <= 0 is handled up front.
        if k <= 0:
            return []
        pool = [r for r in self._records if session_id is None or r.session_id == session_id]
        return [MemoryHit(record=r, score=1.0) for r in reversed(pool[-k:])]

    async def load(self, session_id: str, *, limit: int = 20) -> list[MemoryRecord]:
        """Return the last ``limit`` records for ``session_id``, oldest first.

        A non-positive ``limit`` returns an empty list.
        """
        if limit <= 0:
            return []
        records = [r for r in self._records if r.session_id == session_id]
        return records[-limit:]
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def register() -> None:
    """Register the in-process memories by name.

    ``vector`` passes its config through as ``InMemoryVectorMemory`` keyword
    arguments; ``buffer`` ignores its config.
    """

    def _vector(**cfg: Any) -> InMemoryVectorMemory:
        return InMemoryVectorMemory(**cfg)

    def _buffer(**_cfg: Any) -> BufferMemory:
        return BufferMemory()

    register_memory("vector", _vector)
    register_memory("buffer", _buffer)


register()
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""State schema migration registry.
|
|
2
|
+
|
|
3
|
+
A checkpoint written by old code (``version=1``) may be resumed by new code
|
|
4
|
+
(``version=2``). Backends call :func:`migrate` on the raw dict before validating,
|
|
5
|
+
walking registered upgrade functions one version at a time.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from spine_core.state import STATE_VERSION
|
|
14
|
+
|
|
15
|
+
# An upgrade step: takes a raw state dict at version N and returns it at N + 1.
Migration = Callable[[dict[str, Any]], dict[str, Any]]

# Registered upgrade steps, keyed by the version they upgrade *from*.
_MIGRATIONS: dict[int, Migration] = {}


def register_migration(from_version: int, fn: Migration) -> None:
    """Install ``fn`` as the upgrade step from ``from_version`` to ``from_version + 1``.

    Registering again for the same ``from_version`` replaces the earlier function.
    """
    _MIGRATIONS.update({from_version: fn})
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def migrate(raw: dict[str, Any]) -> dict[str, Any]:
    """Upgrade a raw state dict to the current ``STATE_VERSION``.

    Starts from ``raw["version"]`` (``1`` when the key is absent) and applies the
    registered step for each version in turn. If the dict a step returns has no
    ``version`` key, that step counts as advancing exactly one version.

    Raises:
        ValueError: if a version below ``STATE_VERSION`` has no registered
            migration, or a migration reports a version that is not greater than
            the one it started from.
    """
    current = int(raw.get("version", 1))
    while current < STATE_VERSION:
        step = _MIGRATIONS.get(current)
        if step is None:
            raise ValueError(
                f"cannot resume: no migration from state version {current} "
                f"(current is {STATE_VERSION})"
            )
        raw = step(raw)
        reached = int(raw.get("version", current + 1))
        # Refuse a step that does not move forward; it would loop forever otherwise.
        if not reached > current:
            raise ValueError(f"migration from version {current} did not advance the version")
        current = reached
    return raw
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""pgvector memory backend — semantic recall backed by Postgres + pgvector.
|
|
2
|
+
|
|
3
|
+
Scales the vector memory beyond one process. ``asyncpg`` is imported lazily;
|
|
4
|
+
the ``pgvector`` extension must be installed in the database. Integration is
|
|
5
|
+
exercised when ``SPINE_TEST_PG_DSN`` is set.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import uuid
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from spine_backends.embeddings import HashEmbedder
|
|
15
|
+
from spine_core.memory import Embedder, MemoryHit, MemoryRecord
|
|
16
|
+
from spine_core.registry import register_memory
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PgVectorMemory:
    """``Memory`` over Postgres + pgvector with cosine-distance recall.

    Args:
        dsn: Connection string for ``asyncpg.create_pool`` (used only when no
            ``pool`` is injected).
        embedder: Any object with ``async embed(text) -> list[float]``;
            defaults to ``HashEmbedder(dim)``.
        dim: Width of the ``embedding vector(dim)`` column and of the default
            embedder; a custom ``embedder`` should produce vectors of this length.
        table: Table name. It is interpolated into SQL text, so it must come from
            trusted configuration, never from end-user input.
        pool: Optional pre-built ``asyncpg`` pool. When injected, the ``vector``
            extension and the table are assumed to exist already -- the
            ``CREATE`` statements only run for a pool created here.
    """

    def __init__(
        self,
        dsn: str,
        *,
        embedder: Embedder | None = None,
        dim: int = 256,
        table: str = "spine_memory",
        pool: Any = None,
    ) -> None:
        self.dsn = dsn
        self.embedder: Embedder = embedder or HashEmbedder(dim)
        self.dim = dim
        self.table = table
        self._pool = pool

    async def _ensure_pool(self) -> Any:
        """Return the pool, creating it (plus the extension and table) on first use."""
        if self._pool is None:
            import asyncpg

            self._pool = await asyncpg.create_pool(self.dsn)
            async with self._pool.acquire() as conn:
                await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
                await conn.execute(
                    f"""
                    CREATE TABLE IF NOT EXISTS {self.table} (
                        id TEXT PRIMARY KEY,
                        session_id TEXT,
                        content TEXT NOT NULL,
                        metadata JSONB NOT NULL DEFAULT '{{}}',
                        embedding vector({self.dim}) NOT NULL,
                        ts TIMESTAMPTZ NOT NULL DEFAULT now()
                    )
                    """
                )
        return self._pool

    @staticmethod
    def _vec_literal(vec: list[float]) -> str:
        """Render ``vec`` as a pgvector text literal such as ``[0.1,0.2]``.

        Elements go through ``float`` first: ``repr`` of ``Decimal`` or NumPy
        scalars (``Decimal('0.5')``, ``np.float64(0.5)``) is not a bare number and
        would produce an invalid literal.
        """
        return "[" + ",".join(repr(float(x)) for x in vec) + "]"

    @staticmethod
    def _row_to_record(row: Any) -> MemoryRecord:
        """Build a record from a result row; ``metadata`` may arrive as JSON text or decoded."""
        metadata = row["metadata"]
        if isinstance(metadata, str):
            metadata = json.loads(metadata)
        return MemoryRecord(
            id=row["id"],
            session_id=row["session_id"],
            content=row["content"],
            metadata=metadata,
        )

    async def save(
        self,
        content: str,
        *,
        session_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> MemoryRecord:
        """Embed ``content``, insert it under a fresh random hex id and return the record."""
        pool = await self._ensure_pool()
        record = MemoryRecord(
            id=uuid.uuid4().hex, content=content, session_id=session_id, metadata=metadata or {}
        )
        embedding = self._vec_literal(await self.embedder.embed(content))
        async with pool.acquire() as conn:
            await conn.execute(
                f"INSERT INTO {self.table} (id, session_id, content, metadata, embedding) "
                f"VALUES ($1, $2, $3, $4::jsonb, $5::vector)",
                record.id,
                session_id,
                content,
                json.dumps(record.metadata),
                embedding,
            )
        return record

    async def search(
        self, query: str, *, k: int = 5, session_id: str | None = None
    ) -> list[MemoryHit]:
        """Return up to ``k`` records nearest to ``query``, best first.

        ``score`` is ``1 - cosine distance``. ``session_id=None`` searches every
        session; a non-positive ``k`` returns no hits without touching the database.
        """
        if k <= 0:
            return []
        pool = await self._ensure_pool()
        embedding = self._vec_literal(await self.embedder.embed(query))
        where = "WHERE session_id = $2" if session_id is not None else ""
        args: list[Any] = [embedding] + ([session_id] if session_id is not None else [])
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT id, session_id, content, metadata, "
                f"1 - (embedding <=> $1::vector) AS score "
                f"FROM {self.table} {where} ORDER BY embedding <=> $1::vector LIMIT {int(k)}",
                *args,
            )
        return [
            MemoryHit(record=self._row_to_record(row), score=float(row["score"])) for row in rows
        ]

    async def load(self, session_id: str, *, limit: int = 20) -> list[MemoryRecord]:
        """Return the last ``limit`` records saved for ``session_id``, oldest first.

        A non-positive ``limit`` returns an empty list without touching the database.
        """
        if limit <= 0:
            return []
        pool = await self._ensure_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT id, session_id, content, metadata FROM {self.table} "
                f"WHERE session_id = $1 ORDER BY ts DESC LIMIT {int(limit)}",
                session_id,
            )
        # The query selects the newest rows; reverse them so the result is oldest-first,
        # the same order the in-process memories return from records[-limit:].
        return [self._row_to_record(row) for row in reversed(rows)]
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def register() -> None:
    """Register the ``pgvector`` memory factory.

    ``dsn`` is taken positionally (default ``""``); every other config key is
    forwarded to ``PgVectorMemory`` as a keyword argument.
    """

    def _factory(dsn: str = "", **cfg: Any) -> PgVectorMemory:
        return PgVectorMemory(dsn, **cfg)

    register_memory("pgvector", _factory)


register()
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Postgres checkpoint backend — durable state with a per-session revision counter (writes are unconditional upserts).
|
|
2
|
+
|
|
3
|
+
Uses ``asyncpg`` (imported lazily; inject a ``pool`` for tests). Integration is
|
|
4
|
+
exercised against a real database when ``SPINE_TEST_PG_DSN`` is set; the module
|
|
5
|
+
itself imports nothing heavy at top level.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from spine_backends.migrations import migrate
|
|
14
|
+
from spine_core.registry import register_checkpoint
|
|
15
|
+
from spine_core.state import State
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PostgresCheckpoint:
    """Durable :class:`~spine_core.checkpoint.CheckpointStore` over Postgres.

    One row per session holds the JSON-serialized state, its schema ``version``
    and a ``revision`` counter incremented on every overwrite.

    Args:
        dsn: Connection string for ``asyncpg.create_pool`` (used only when no
            ``pool`` is injected).
        pool: Optional pre-built pool. When injected, the table is assumed to
            exist already -- ``CREATE TABLE`` only runs for a pool created here.
        table: Table name. It is interpolated into SQL text, so it must come from
            trusted configuration.
    """

    def __init__(self, dsn: str, *, pool: Any = None, table: str = "spine_checkpoints") -> None:
        self.dsn = dsn
        self.table = table
        self._pool = pool

    async def _ensure_pool(self) -> Any:
        """Return the pool, creating it and the checkpoint table on first use."""
        if self._pool is not None:
            return self._pool
        import asyncpg

        self._pool = await asyncpg.create_pool(self.dsn)
        async with self._pool.acquire() as conn:
            await conn.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {self.table} (
                    session_id TEXT PRIMARY KEY,
                    version INTEGER NOT NULL,
                    revision BIGINT NOT NULL DEFAULT 1,
                    data JSONB NOT NULL,
                    updated TIMESTAMPTZ NOT NULL DEFAULT now()
                )
                """
            )
        return self._pool

    async def put(self, state: State) -> None:
        """Insert or overwrite the checkpoint for ``state.session_id``, bumping ``revision``."""
        pool = await self._ensure_pool()
        upsert = f"""
                INSERT INTO {self.table} (session_id, version, revision, data)
                VALUES ($1, $2, 1, $3::jsonb)
                ON CONFLICT (session_id) DO UPDATE SET
                    version = EXCLUDED.version,
                    revision = {self.table}.revision + 1,
                    data = EXCLUDED.data,
                    updated = now()
                """
        async with pool.acquire() as conn:
            await conn.execute(upsert, state.session_id, state.version, state.model_dump_json())

    async def get(self, session_id: str) -> State | None:
        """Load, migrate and validate the checkpoint for ``session_id``; ``None`` if absent."""
        pool = await self._ensure_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                f"SELECT data FROM {self.table} WHERE session_id = $1", session_id
            )
        if row is None:
            return None
        payload: Any = row["data"]
        # JSONB may come back as text (no codec registered) or already decoded.
        decoded = payload if not isinstance(payload, str) else json.loads(payload)
        return State.model_validate(migrate(decoded))

    async def delete(self, session_id: str) -> None:
        """Remove the checkpoint for ``session_id`` (no error if it does not exist)."""
        pool = await self._ensure_pool()
        async with pool.acquire() as conn:
            await conn.execute(f"DELETE FROM {self.table} WHERE session_id = $1", session_id)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def register() -> None:
    """Register the ``postgres`` checkpoint factory.

    The factory honours the ``dsn`` and ``table`` config keys. Other keys are
    accepted and ignored. (``table`` used to be dropped silently, so a configured
    table name never took effect.)
    """

    def _factory(dsn: str = "", table: str = "spine_checkpoints", **_: Any) -> PostgresCheckpoint:
        return PostgresCheckpoint(dsn, table=table)

    register_checkpoint("postgres", _factory)


register()
|
spine_backends/py.typed
ADDED
|
File without changes
|
spine_backends/redis.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Redis checkpoint backend — durable state for distributed workers.
|
|
2
|
+
|
|
3
|
+
The ``redis`` client is imported lazily and may be injected (tests use a fake),
|
|
4
|
+
so importing this module never requires the dependency or a server.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from spine_backends.migrations import migrate
|
|
13
|
+
from spine_core.registry import register_checkpoint
|
|
14
|
+
from spine_core.state import State
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RedisCheckpoint:
    """Durable :class:`~spine_core.checkpoint.CheckpointStore` over Redis.

    Each session's state is stored as a JSON string under ``prefix`` followed by
    the session id.

    Args:
        url: Redis URL used when no ``client`` is injected.
        client: Optional pre-built async client (e.g. a test fake).
        prefix: Key prefix for checkpoint entries.
    """

    def __init__(
        self,
        url: str = "redis://localhost:6379",
        *,
        client: Any = None,
        prefix: str = "spine:checkpoint:",
    ) -> None:
        self.url = url
        self.prefix = prefix
        self._client = client

    def _ensure_client(self) -> Any:
        """Return the client, building a ``redis.asyncio`` one from ``url`` on first use."""
        if self._client is not None:
            return self._client
        import redis.asyncio as redis

        self._client = redis.from_url(self.url, decode_responses=True)
        return self._client

    def _key(self, session_id: str) -> str:
        return "{}{}".format(self.prefix, session_id)

    async def put(self, state: State) -> None:
        """Overwrite the stored checkpoint for ``state.session_id``."""
        client = self._ensure_client()
        key = self._key(state.session_id)
        await client.set(key, state.model_dump_json())

    async def get(self, session_id: str) -> State | None:
        """Load, migrate and validate the checkpoint for ``session_id``; ``None`` if absent."""
        stored = await self._ensure_client().get(self._key(session_id))
        if stored is None:
            return None
        upgraded = migrate(json.loads(stored))
        return State.model_validate(upgraded)

    async def delete(self, session_id: str) -> None:
        """Remove the checkpoint for ``session_id``."""
        client = self._ensure_client()
        await client.delete(self._key(session_id))
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def register() -> None:
    """Register the ``redis`` checkpoint factory.

    The factory honours the ``url`` and ``prefix`` config keys. Other keys are
    accepted and ignored. (``prefix`` used to be dropped silently, so a configured
    key prefix never took effect.)
    """

    def _factory(
        url: str = "redis://localhost:6379", prefix: str = "spine:checkpoint:", **_: Any
    ) -> RedisCheckpoint:
        return RedisCheckpoint(url, prefix=prefix)

    register_checkpoint("redis", _factory)


register()
|