relio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relio/__init__.py +24 -0
- relio/agents.py +72 -0
- relio/ai.py +190 -0
- relio/backends/__init__.py +0 -0
- relio/backends/base.py +56 -0
- relio/backends/postgres.py +176 -0
- relio/backends/sqlite.py +184 -0
- relio/cli/__init__.py +0 -0
- relio/cli/check.py +96 -0
- relio/cli/dockerfile.py +32 -0
- relio/cli/main.py +185 -0
- relio/cli/scaffold.py +330 -0
- relio/embedding/__init__.py +0 -0
- relio/embedding/base.py +47 -0
- relio/embedding/cache.py +37 -0
- relio/embedding/local.py +31 -0
- relio/exposure.py +80 -0
- relio/graph.py +63 -0
- relio/interchange.py +55 -0
- relio/mcp_server.py +31 -0
- relio/memory.py +222 -0
- relio/recall.py +47 -0
- relio/record.py +64 -0
- relio/render.py +34 -0
- relio/sdkgen.py +380 -0
- relio/server/__init__.py +7 -0
- relio/server/agent.py +45 -0
- relio/server/app.py +43 -0
- relio/server/auth.py +45 -0
- relio/server/config.py +14 -0
- relio/server/llm/__init__.py +0 -0
- relio/server/llm/base.py +32 -0
- relio/server/llm/claude.py +57 -0
- relio/server/llm/fake.py +32 -0
- relio/server/routes/__init__.py +0 -0
- relio/server/routes/chat.py +44 -0
- relio/server/routes/graph.py +26 -0
- relio/server/routes/history.py +39 -0
- relio/server/routes/memory.py +64 -0
- relio/server/schemas.py +30 -0
- relio/server/scope.py +15 -0
- relio/server/static.py +29 -0
- relio/templates/desktop/README.md +24 -0
- relio/templates/desktop/package.json +24 -0
- relio/templates/desktop/src-tauri/Cargo.toml +12 -0
- relio/templates/desktop/src-tauri/build.rs +3 -0
- relio/templates/desktop/src-tauri/src/main.rs +8 -0
- relio/templates/desktop/src-tauri/tauri.conf.json +28 -0
- relio/templates/mobile/App.tsx +74 -0
- relio/templates/mobile/README.md +22 -0
- relio/templates/mobile/app.json +10 -0
- relio/templates/mobile/babel.config.js +6 -0
- relio/templates/mobile/package.json +22 -0
- relio/templates/mobile/tsconfig.json +6 -0
- relio/templates/web/index.html +12 -0
- relio/templates/web/package.json +24 -0
- relio/templates/web/src/App.tsx +20 -0
- relio/templates/web/src/components/ChatView.tsx +69 -0
- relio/templates/web/src/components/MemoryBrowser.tsx +84 -0
- relio/templates/web/src/main.tsx +10 -0
- relio/templates/web/src/styles.css +147 -0
- relio/templates/web/tsconfig.json +18 -0
- relio/templates/web/vite.config.ts +19 -0
- relio-0.1.0.dist-info/METADATA +302 -0
- relio-0.1.0.dist-info/RECORD +68 -0
- relio-0.1.0.dist-info/WHEEL +4 -0
- relio-0.1.0.dist-info/entry_points.txt +2 -0
- relio-0.1.0.dist-info/licenses/LICENSE +21 -0
relio/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from .ai import RelioAI
|
|
2
|
+
from .memory import Memory
|
|
3
|
+
from .record import MemoryRecord, MemoryType, Relation, Scope
|
|
4
|
+
from .interchange import (
|
|
5
|
+
export_records,
|
|
6
|
+
import_records,
|
|
7
|
+
import_record_objects,
|
|
8
|
+
from_mem0,
|
|
9
|
+
)
|
|
10
|
+
from .mcp_server import build_mcp_server
|
|
11
|
+
|
|
12
|
+
__all__ = [
    # high-level facade and the memory store it wraps
    "RelioAI",
    "Memory",
    # record model
    "MemoryRecord",
    "MemoryType",
    "Relation",
    "Scope",
    # import / export helpers (relio.interchange)
    "export_records",
    "import_records",
    "import_record_objects",
    "from_mem0",
    # MCP interop (relio.mcp_server)
    "build_mcp_server",
]
|
relio/agents.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# relio/agents.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any, Iterator, Optional
|
|
5
|
+
|
|
6
|
+
from .record import MemoryRecord, Scope
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Agent:
    """A bounded agent context: its own memory namespace, tool slice, config,
    and session. Private by default — it sees only its own space and the tools
    it was granted.

    Args:
        ai: The owning ``RelioAI`` (anything exposing ``remember``, ``recall``,
            ``memory``, ``tools``, ``call_tool`` and ``provider``).
        name: Agent name; also the default namespace ``Scope(agent=name)``.
        space: Memory scope this agent reads and writes. Defaults to
            ``Scope(agent=name)``.
        tools: Allowlist of tool names this agent may call. ``None`` grants
            every tool registered on ``ai``; an empty list grants none.
        system: Text passed to ``run_chat`` as ``system_prefix`` (followed by
            a blank line) on every ``chat`` call.
        model: Stored on the agent; see the note in ``__init__``.
        recall_limit: Default number of records for ``recall`` and ``chat``.
    """

    def __init__(
        self,
        ai: Any,
        name: str,
        *,
        space: Optional[Scope] = None,
        tools: Optional[list[str]] = None,
        system: str = "",
        model: Optional[str] = None,
        recall_limit: int = 5,
    ) -> None:
        self.ai = ai
        self.name = name
        # Its own memory namespace. Compare with None explicitly so a
        # caller-supplied scope is never discarded for evaluating falsy.
        self.space = space if space is not None else Scope(agent=name)
        # None = unrestricted; an empty set = no tools at all.
        self._allowed: Optional[set[str]] = set(tools) if tools is not None else None
        self.system = system
        # NOTE(review): stored but not forwarded to run_chat — confirm whether
        # run_chat should receive a model override.
        self.model = model
        self.recall_limit = recall_limit

    # --- memory namespace (isolated) ----------------------------------------

    def remember(self, content: str, **kwargs: Any) -> MemoryRecord:
        """Store ``content`` in this agent's space; kwargs go to ``ai.remember``."""
        return self.ai.remember(content, scope=self.space, **kwargs)

    def recall(self, query: str, limit: Optional[int] = None) -> list[MemoryRecord]:
        """Recall records for ``query`` from this agent's space only.

        ``limit=None`` uses ``recall_limit``. An explicit ``0`` is honoured
        instead of being silently replaced by the default.
        """
        if limit is None:
            limit = self.recall_limit
        return self.ai.recall(query, scope=self.space, limit=limit)

    def history(self, limit: int = 20) -> list[MemoryRecord]:
        """Return up to ``limit`` records from this agent's space via ``Memory.history``."""
        return self.ai.memory.history(self.space, limit=limit)

    # --- tool slice (granted subset of the exposure map) --------------------

    def tools(self) -> list[str]:
        """Names of the registered tools this agent is allowed to call."""
        names = self.ai.tools.names()
        return names if self._allowed is None else [n for n in names if n in self._allowed]

    def call_tool(self, name: str, **kwargs: Any) -> Any:
        """Invoke tool ``name`` through the owning AI.

        Raises:
            PermissionError: if ``name`` is outside this agent's allowlist.
        """
        if self._allowed is not None and name not in self._allowed:
            raise PermissionError(f"agent {self.name!r} may not call tool {name!r}")
        return self.ai.call_tool(name, **kwargs)

    # --- reasoning (scoped to this agent) -----------------------------------

    def chat(self, message: str, **kwargs: Any) -> Iterator[str]:
        """Chat within this agent's space, prefixing ``system`` when set.

        ``kwargs`` are forwarded to ``run_chat``. A caller-supplied ``limit``
        overrides ``recall_limit`` rather than clashing with it.

        Raises:
            RuntimeError: if the owning AI has no LLM provider.
        """
        if self.ai.provider is None:
            raise RuntimeError("agent chat needs an LLM provider")
        from .server.agent import run_chat

        prefix = (self.system + "\n\n") if self.system else ""
        # Passing limit= alongside **kwargs containing "limit" raised TypeError.
        kwargs.setdefault("limit", self.recall_limit)
        return run_chat(
            self.ai.memory,
            self.ai.provider,
            message,
            self.space,
            system_prefix=prefix,
            **kwargs,
        )
|
relio/ai.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
# relio/ai.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Iterator, Optional, Union
|
|
6
|
+
|
|
7
|
+
from .exposure import ExposureMap
|
|
8
|
+
from .memory import Memory
|
|
9
|
+
from .record import MemoryRecord, MemoryType, Relation, Scope
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RelioAI:
    """The called-in AI component: one seam composing the AI-system components
    (memory/RAG, embeddings, graph, structured query, reasoning, MCP interop).

    The LLM is optional — construct it with no provider for a pure
    memory/retrieval/data component, and add a provider when you need `chat`,
    `extract` or `extract_file`.

    Args:
        memory: Existing ``Memory`` to wrap. When omitted, one is built from
            ``path``, ``embedder`` and ``database_url``.
        provider: Optional LLM provider used by the reasoning/extraction methods.
        path: Forwarded to ``Memory`` when ``memory`` is omitted.
        embedder: Forwarded to ``Memory`` when ``memory`` is omitted.
        database_url: Forwarded to ``Memory`` when ``memory`` is omitted.
    """

    def __init__(
        self,
        memory: Optional[Memory] = None,
        provider: Optional[object] = None,
        *,
        path: str = "relio.db",
        embedder: Optional[object] = None,
        database_url: Optional[str] = None,
    ) -> None:
        # Test against None, not truthiness: a caller-supplied Memory must never
        # be silently swapped for a fresh one just because it evaluates falsy.
        self.memory = (
            memory
            if memory is not None
            else Memory(path=path, embedder=embedder, database_url=database_url)
        )
        self.provider = provider
        self.tools = ExposureMap()  # the governed surface the AI may call

    # --- knowledge & retrieval ----------------------------------------------

    def remember(
        self,
        content: str,
        type: MemoryType = MemoryType.SEMANTIC,
        scope: Optional[Scope] = None,
        **kwargs: Any,
    ) -> MemoryRecord:
        """Store ``content`` via ``Memory.add``; extra kwargs are forwarded."""
        return self.memory.add(content, type=type, scope=scope, **kwargs)

    def recall(
        self, query: str, scope: Optional[Scope] = None, limit: int = 5
    ) -> list[MemoryRecord]:
        """Return up to ``limit`` records for ``query`` via ``Memory.recall``."""
        return self.memory.recall(query, scope=scope, limit=limit)

    def recall_text(self, query: str, **kwargs: Any) -> str:
        """Text form of recall, delegated to ``Memory.recall_text``."""
        return self.memory.recall_text(query, **kwargs)

    def embed(self, texts: Union[str, list[str]]) -> Union[list[float], list[list[float]]]:
        """Embed one string (one vector) or a sequence of strings (a batch)."""
        if isinstance(texts, str):
            return self.memory.embedder.embed(texts)
        return self.memory.embedder.embed_batch(list(texts))

    # --- graph ---------------------------------------------------------------

    def add_node(self, content: str, **kwargs: Any) -> MemoryRecord:
        """Add a graph node via ``Memory.add_node``."""
        return self.memory.add_node(content, **kwargs)

    def add_edge(self, source_id: str, predicate: str, target_id: str) -> MemoryRecord:
        """Add a ``source -[predicate]-> target`` edge via ``Memory.add_edge``."""
        return self.memory.add_edge(source_id, predicate, target_id)

    def neighbors(self, node_id: str, predicate: Optional[str] = None) -> list[MemoryRecord]:
        """Outgoing neighbours of ``node_id``, optionally filtered by predicate."""
        return self.memory.neighbors(node_id, predicate=predicate)

    def in_neighbors(self, node_id: str, predicate: Optional[str] = None) -> list[MemoryRecord]:
        """Incoming neighbours of ``node_id``, optionally filtered by predicate."""
        return self.memory.in_neighbors(node_id, predicate=predicate)

    def traverse(self, start_id: str, depth: int = 1, predicate: Optional[str] = None) -> list[MemoryRecord]:
        """Walk the graph from ``start_id`` up to ``depth`` hops via ``Memory.traverse``."""
        return self.memory.traverse(start_id, depth=depth, predicate=predicate)

    # --- structured query ----------------------------------------------------

    def query(
        self,
        type: Optional[MemoryType] = None,
        scope: Optional[Scope] = None,
        where: Optional[dict[str, str]] = None,
        limit: int = 100,
    ) -> list[MemoryRecord]:
        """Structured (non-semantic) filter, delegated to ``Memory.query``."""
        return self.memory.query(type=type, scope=scope, where=where, limit=limit)

    def transaction(self):
        """Context manager grouping writes atomically (``Memory.transaction``)."""
        return self.memory.transaction()

    # --- reasoning -----------------------------------------------------------

    @property
    def has_llm(self) -> bool:
        """True when an LLM provider was supplied."""
        return self.provider is not None

    def chat(
        self, message: str, scope: Optional[Scope] = None, **kwargs: Any
    ) -> Iterator[str]:
        """Delegate to ``server.agent.run_chat`` with this memory and provider.

        ``scope`` defaults to ``Scope()``; extra kwargs go to ``run_chat``.

        Raises:
            RuntimeError: if no provider was configured.
        """
        if self.provider is None:
            raise RuntimeError(
                "RelioAI.chat needs an LLM provider; construct RelioAI(provider=...)"
            )
        from .server.agent import run_chat

        return run_chat(
            self.memory,
            self.provider,
            message,
            scope if scope is not None else Scope(),
            **kwargs,
        )

    # --- structured / multimodal extraction (D6) ----------------------------

    def _require_provider(self, what: str) -> None:
        """Raise RuntimeError naming ``what`` when no provider is configured."""
        if self.provider is None:
            raise RuntimeError(f"RelioAI.{what} needs an LLM provider")

    def extract(self, text: str, schema: Optional[dict] = None) -> dict:
        """Extract structured data from text into `schema`."""
        self._require_provider("extract")
        return self.provider.extract(text, schema=schema)

    def extract_file(
        self,
        file: Union[str, Path, bytes, bytearray],
        schema: Optional[dict] = None,
        media_type: str = "application/pdf",
    ) -> dict:
        """Extract structured data from a file (PDF/image) into `schema` — the
        path for "read this drawing/scan and give me a bill".

        ``file`` may be raw bytes or a filesystem path, which is read in full.
        """
        self._require_provider("extract_file")
        data = bytes(file) if isinstance(file, (bytes, bytearray)) else Path(file).read_bytes()
        return self.provider.extract("", schema=schema, image_bytes=data, media_type=media_type)

    # --- tools / exposure map (D3) ------------------------------------------

    def tool(self, fn=None, *, name: Optional[str] = None, description: Optional[str] = None):
        """Register an app operation the AI may call (decorator)."""
        return self.tools.tool(fn, name=name, description=description)

    def expose(self, obj: Any, fields: list[str]) -> dict[str, Any]:
        """Field allowlist: project `obj` to only `fields` for AI consumption."""
        return ExposureMap.project(obj, fields)

    def list_tools(self) -> list[dict[str, Any]]:
        """Describe every registered tool as ``{name, description, parameters}``."""
        return [
            {"name": s.name, "description": s.description, "parameters": s.parameters}
            for s in self.tools.list()
        ]

    def call_tool(self, name: str, **kwargs: Any) -> Any:
        """Invoke registered tool ``name`` through the exposure map."""
        return self.tools.call(name, **kwargs)

    # --- agents (D4) ---------------------------------------------------------

    def agent(
        self,
        name: str,
        *,
        space: Optional[Scope] = None,
        tools: Optional[list[str]] = None,
        system: str = "",
        model: Optional[str] = None,
        recall_limit: int = 5,
    ):
        """Construct a bounded agent: its own memory namespace + tool slice +
        config + session. Private by default."""
        from .agents import Agent

        return Agent(
            self,
            name,
            space=space,
            tools=tools,
            system=system,
            model=model,
            recall_limit=recall_limit,
        )

    # --- interop -------------------------------------------------------------

    def mcp_server(self, include_tools: bool = True):
        """The FastMCP server exposing this memory (add/recall) and — when
        `include_tools` — the exposure-map tools, to external agents.

        Returns ``(server, tools)`` where ``tools`` maps tool name -> callable.
        """
        from .mcp_server import build_mcp_server

        server, tools = build_mcp_server(self.memory)
        if include_tools:
            for spec in self.tools.list():
                # Register under the exposure-map name/description so a tool
                # declared with ``@ai.tool(name=...)`` is exposed under that
                # name, matching the key used in ``tools`` below. Previously,
                # the bare ``server.tool()`` fell back to ``fn.__name__``.
                server.tool(name=spec.name, description=spec.description)(spec.fn)
                tools[spec.name] = spec.fn
        return server, tools

    def close(self) -> None:
        """Close the underlying memory store."""
        self.memory.close()
|
|
File without changes
|
relio/backends/base.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# relio/backends/base.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from typing import ContextManager, Optional
|
|
6
|
+
|
|
7
|
+
from ..record import MemoryRecord, MemoryType, Scope
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class StorageBackend(ABC):
    """Abstract persistence contract.

    ``Memory`` and ``RecallEngine`` reach storage only through these methods;
    a concrete backend must implement every one of them.
    """

    @abstractmethod
    def add(self, record: MemoryRecord, embedding: list[float] | None) -> None:
        """Upsert ``record`` by id, keeping ``embedding`` with it when given."""

    @abstractmethod
    def get(self, record_id: str) -> MemoryRecord | None:
        """Fetch the record with ``record_id``, or ``None`` when absent."""

    @abstractmethod
    def delete(self, record_id: str) -> bool:
        """Remove ``record_id``; report whether a row actually went away."""

    @abstractmethod
    def search(self, embedding: list[float], k: int) -> list[tuple[MemoryRecord, float]]:
        """Nearest-neighbour lookup: at most ``k`` ``(record, distance)`` pairs,
        closest first."""

    @abstractmethod
    def all(self) -> list[MemoryRecord]:
        """Every stored record, oldest insertion first.

        History depends on this order, so implementations must keep it.
        """

    @abstractmethod
    def query(
        self,
        *,
        type: MemoryType | None = None,
        scope: Scope | None = None,
        metadata: dict[str, str] | None = None,
        limit: int = 100,
    ) -> list[MemoryRecord]:
        """Exact-match filter on type, scope fields and metadata values.

        Results come back in insertion order. Unlike ``search()``, this also
        returns records stored without an embedding.
        """

    @abstractmethod
    def transaction(self) -> ContextManager[None]:
        """Group writes into one atomic unit: commit them all when the block
        exits cleanly, roll them all back when it raises."""

    @abstractmethod
    def close(self) -> None:
        """Release whatever resources the backend holds."""
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from contextvars import ContextVar
|
|
6
|
+
from typing import Iterator, Optional
|
|
7
|
+
|
|
8
|
+
from ..record import MemoryRecord, MemoryType, Scope
|
|
9
|
+
from .base import StorageBackend
|
|
10
|
+
|
|
11
|
+
_KEY = re.compile(r"^\w+$") # guard interpolated json paths against injection
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _vector_literal(embedding: list[float]) -> str:
|
|
15
|
+
"""pgvector text form: [a,b,c]."""
|
|
16
|
+
return "[" + ",".join(str(float(x)) for x in embedding) + "]"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _to_record(doc) -> MemoryRecord:
    """Rebuild a ``MemoryRecord`` from a ``doc`` column value.

    psycopg normally hands JSONB back already parsed (dict/list); anything
    else is treated as JSON text.
    """
    already_parsed = isinstance(doc, (dict, list))
    if not already_parsed:
        return MemoryRecord.model_validate_json(doc)
    return MemoryRecord.model_validate(doc)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PostgresBackend(StorageBackend):
    """Postgres + pgvector backend — behaviour-identical to SQLiteBackend.

    The scale path for high write-concurrency or many millions of vectors. A
    connection **pool** lets independent requests run concurrently (no global
    lock). Within `transaction()`, one connection is bound to the current
    context so nested writes — and nested `transaction()` blocks — share it
    and commit atomically.

    Args:
        dsn: Connection string handed to the pool.
        dim: Embedding dimension; fixes the ``vector(dim)`` column type.
        pool_size: Maximum number of pooled connections.
    """

    def __init__(self, dsn: str, dim: int = 384, pool_size: int = 10) -> None:
        from psycopg_pool import ConnectionPool  # lazy: only when Postgres is used

        self.dim = dim
        # Holds the transaction's connection for the current context (else None).
        # NOTE(review): the contextvars docs advise creating ContextVars at module
        # level because Context objects keep strong references to them; a
        # per-instance var is fine for a few long-lived backends.
        self._active: ContextVar[Optional[object]] = ContextVar(
            "relio_pg_active", default=None
        )
        # autocommit: statements outside transaction() commit individually.
        self._pool = ConnectionPool(
            dsn, min_size=1, max_size=pool_size, kwargs={"autocommit": True}, open=True
        )
        try:
            self._init_schema()
        except BaseException:
            # The pool was opened eagerly; close it so a failed schema setup
            # (bad credentials, missing pgvector, ...) doesn't leak it.
            self._pool.close()
            raise

    @contextmanager
    def _conn(self) -> Iterator[object]:
        """Yield the transaction-bound connection if inside one, else a pooled one."""
        active = self._active.get()
        if active is not None:
            yield active
        else:
            with self._pool.connection() as conn:
                yield conn

    def _init_schema(self) -> None:
        """Create the pgvector extension, the ``records`` table and its GIN index."""
        with self._conn() as conn, conn.cursor() as cur:
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
            # rid (BIGSERIAL) provides the insertion order that all()/query() sort by.
            cur.execute(
                f"""
                CREATE TABLE IF NOT EXISTS records (
                    rid BIGSERIAL PRIMARY KEY,
                    id TEXT UNIQUE NOT NULL,
                    doc JSONB NOT NULL,
                    expires_at DOUBLE PRECISION,
                    embedding vector({self.dim})
                )
                """
            )
            # GIN index makes structured query() (Feature J) indexed on jsonb.
            cur.execute("CREATE INDEX IF NOT EXISTS idx_doc_gin ON records USING GIN (doc)")

    @staticmethod
    def _expires_at(record: MemoryRecord) -> float | None:
        """Absolute expiry as a POSIX timestamp (created_at + ttl), or None."""
        if record.ttl is None:
            return None
        return record.created_at.timestamp() + record.ttl

    def add(self, record: MemoryRecord, embedding: list[float] | None) -> None:
        """Upsert ``record`` by id, replacing doc, expiry and embedding."""
        doc = record.model_dump_json()
        vec = _vector_literal(embedding) if embedding is not None else None
        with self._conn() as conn, conn.cursor() as cur:
            cur.execute(
                """
                INSERT INTO records (id, doc, expires_at, embedding)
                VALUES (%s, %s::jsonb, %s, %s::vector)
                ON CONFLICT (id) DO UPDATE
                SET doc = EXCLUDED.doc,
                    expires_at = EXCLUDED.expires_at,
                    embedding = EXCLUDED.embedding
                """,
                (record.id, doc, self._expires_at(record), vec),
            )

    def get(self, record_id: str) -> MemoryRecord | None:
        """Fetch a record by id, or None if absent."""
        with self._conn() as conn, conn.cursor() as cur:
            cur.execute("SELECT doc FROM records WHERE id = %s", (record_id,))
            row = cur.fetchone()
        if row is None:
            return None
        return _to_record(row[0])

    def delete(self, record_id: str) -> bool:
        """Delete a record by id; True if a row was removed."""
        with self._conn() as conn, conn.cursor() as cur:
            cur.execute("DELETE FROM records WHERE id = %s", (record_id,))
            return cur.rowcount > 0

    def all(self) -> list[MemoryRecord]:
        """Every record in insertion (rid) order."""
        with self._conn() as conn, conn.cursor() as cur:
            cur.execute("SELECT doc FROM records ORDER BY rid")
            rows = cur.fetchall()
        return [_to_record(r[0]) for r in rows]

    def search(self, embedding: list[float], k: int) -> list[tuple[MemoryRecord, float]]:
        """Up to ``k`` embedded records nearest to ``embedding`` by pgvector's
        ``<->`` (L2) distance, ascending."""
        vec = _vector_literal(embedding)
        with self._conn() as conn, conn.cursor() as cur:
            cur.execute(
                """
                SELECT doc, embedding <-> %s::vector AS distance
                FROM records
                WHERE embedding IS NOT NULL
                ORDER BY distance
                LIMIT %s
                """,
                (vec, k),
            )
            rows = cur.fetchall()
        return [(_to_record(r[0]), float(r[1])) for r in rows]

    def query(
        self,
        *,
        type: Optional[MemoryType] = None,
        scope: Optional[Scope] = None,
        metadata: Optional[dict[str, str]] = None,
        limit: int = 100,
    ) -> list[MemoryRecord]:
        """Exact-match filter on type / non-None scope fields / metadata, in
        insertion order. Values are bound as parameters; scope field names are
        fixed and metadata keys are validated before being interpolated into
        the JSON path.

        Raises:
            ValueError: if a metadata key is not a plain word (``\\w+``).
        """
        clauses: list[str] = []
        params: list[object] = []
        if type is not None:
            clauses.append("doc->>'type' = %s")
            params.append(type.value)
        if scope is not None:
            for field in ("tenant", "user", "agent", "session"):
                value = getattr(scope, field)
                if value is not None:
                    clauses.append(f"doc#>>'{{scope,{field}}}' = %s")
                    params.append(value)
        for key, value in (metadata or {}).items():
            # fullmatch: the whole key must be word characters (no trailing "\n").
            if not _KEY.fullmatch(key):
                raise ValueError(f"invalid metadata key: {key!r}")
            clauses.append(f"doc#>>'{{metadata,{key}}}' = %s")
            params.append(value)
        where = (" WHERE " + " AND ".join(clauses)) if clauses else ""
        params.append(limit)
        with self._conn() as conn, conn.cursor() as cur:
            cur.execute(f"SELECT doc FROM records{where} ORDER BY rid LIMIT %s", params)
            rows = cur.fetchall()
        return [_to_record(r[0]) for r in rows]

    @contextmanager
    def transaction(self) -> Iterator[None]:
        """Run the block as one atomic unit on a single connection.

        Nested calls reuse the already-bound connection. psycopg implements a
        nested ``transaction()`` as a SAVEPOINT, so the inner block stays
        inside the outer atomic unit.
        """
        active = self._active.get()
        if active is not None:
            # Borrowing a second pooled connection here (previous behaviour)
            # would commit independently of the outer block and could block
            # on row locks held by the outer connection.
            with active.transaction():
                yield
            return
        # Borrow one connection, bind it for this context so nested add()/delete()
        # reuse it, and wrap the block in a single BEGIN/COMMIT.
        with self._pool.connection() as conn:
            token = self._active.set(conn)
            try:
                with conn.transaction():
                    yield
            finally:
                self._active.reset(token)

    def close(self) -> None:
        """Close the connection pool."""
        self._pool.close()
|