pgmnemo-mcp 0.5.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pgmnemo_mcp-0.5.2/PKG-INFO +10 -0
- pgmnemo_mcp-0.5.2/pgmnemo_mcp/__init__.py +7 -0
- pgmnemo_mcp-0.5.2/pgmnemo_mcp/__main__.py +56 -0
- pgmnemo_mcp-0.5.2/pgmnemo_mcp/config.py +16 -0
- pgmnemo_mcp-0.5.2/pgmnemo_mcp/server.py +109 -0
- pgmnemo_mcp-0.5.2/pgmnemo_mcp.egg-info/PKG-INFO +10 -0
- pgmnemo_mcp-0.5.2/pgmnemo_mcp.egg-info/SOURCES.txt +13 -0
- pgmnemo_mcp-0.5.2/pgmnemo_mcp.egg-info/dependency_links.txt +1 -0
- pgmnemo_mcp-0.5.2/pgmnemo_mcp.egg-info/entry_points.txt +2 -0
- pgmnemo_mcp-0.5.2/pgmnemo_mcp.egg-info/requires.txt +3 -0
- pgmnemo_mcp-0.5.2/pgmnemo_mcp.egg-info/top_level.txt +1 -0
- pgmnemo_mcp-0.5.2/pyproject.toml +26 -0
- pgmnemo_mcp-0.5.2/setup.cfg +4 -0
- pgmnemo_mcp-0.5.2/tests/test_import.py +3 -0
- pgmnemo_mcp-0.5.2/tests/test_server.py +234 -0
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pgmnemo-mcp
|
|
3
|
+
Version: 0.5.2
|
|
4
|
+
Summary: MCP server wrapping pgmnemo ingest and recall for AI agent memory
|
|
5
|
+
License: Apache-2.0
|
|
6
|
+
Requires-Python: >=3.11
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: mcp>=1.0
|
|
9
|
+
Requires-Dist: psycopg2-binary>=2.9
|
|
10
|
+
Requires-Dist: pydantic>=2.0
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""python -m pgmnemo_mcp — CLI entry point with --smoke flag."""
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def main() -> None:
    """Command-line entry point for ``python -m pgmnemo_mcp``.

    With ``--smoke`` it runs the database connectivity check; otherwise it
    starts the MCP server.
    """
    cli = argparse.ArgumentParser(prog="pgmnemo-mcp")
    cli.add_argument(
        "--smoke",
        action="store_true",
        help="Run a connectivity smoke test: connect to DB and call recall_lessons().",
    )
    options = cli.parse_args()

    if options.smoke:
        _run_smoke()
        return

    # Deferred import: the --smoke path never imports .server (or the mcp package).
    from .server import run

    run()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _run_smoke() -> None:
    """Connectivity smoke test: query pgmnemo.recall_lessons() once and exit.

    Prints the (password-redacted) DATABASE_URL, borrows a connection from
    the shared pool, and runs a count(*) over recall_lessons() with a NULL
    query vector.

    Exits the process with status 0 and an OK line on success. Any error
    (including an unexpected empty result) prints a FAIL line to stderr and
    exits with status 1.
    """
    from .config import DATABASE_URL, get_pool

    print(f"pgmnemo-mcp smoke: connecting to {_redact(DATABASE_URL)} …")
    try:
        pool = get_pool()
        conn = pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT count(*) FROM pgmnemo.recall_lessons("
                    "NULL::vector(1024), 5, NULL, NULL, 'test')"
                )
                # Unpacking inside the try: a missing row is reported as FAIL
                # rather than escaping as an uncaught TypeError later.
                (count,) = cur.fetchone()
        finally:
            # Single cleanup path: the connection goes back to the pool
            # whether the query succeeded or raised.
            pool.putconn(conn)
    except Exception as exc:  # CLI boundary: report any failure and exit non-zero
        print(f"pgmnemo-mcp smoke: FAIL — {exc}", file=sys.stderr)
        sys.exit(1)

    print(f"pgmnemo-mcp smoke: OK (recall_lessons returned {count} rows)")
    sys.exit(0)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _redact(url: str) -> str:
    """Hide passwords in a connection string before printing it.

    Covers the forms psycopg2/libpq accept for DATABASE_URL:

    * URI userinfo: ``postgresql://user:secret@host/db`` -> ``user:***@``.
      The user part may be empty, e.g. ``postgresql://:secret@host``.
    * ``password=`` / ``sslpassword=`` appearing either in a URI query
      string (``...?password=secret&...``) or in key/value form
      (``host=db password=secret``). Single-quoted values are covered too.

    Anything else is returned unchanged.
    """
    import re

    # URI userinfo. The user part may be empty; the password runs to the next
    # '@' (erring toward over-redaction, which is harmless for printing).
    redacted = re.sub(r"://([^:@]*):([^@]+)@", r"://\1:***@", url)
    # Query-string parameter or key/value pair. A value ends at whitespace or
    # '&' unless it is single-quoted (libpq key/value quoting with backslash
    # escapes).
    return re.sub(
        r"\b((?:ssl)?password\s*=\s*)('(?:[^'\\]|\\.)*'|[^\s&]+)",
        r"\1***",
        redacted,
        flags=re.IGNORECASE,
    )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Runs when executed as `python -m pgmnemo_mcp` (this file is the package's __main__).
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from psycopg2 import pool
|
|
3
|
+
|
|
4
|
+
# Public API of this module: the two configuration constants and the shared pool accessor.
__all__ = ["DATABASE_URL", "MCP_PORT", "get_pool"]
|
|
5
|
+
|
|
6
|
+
def _int_from_env(name: str, default: int) -> int:
    """Return environment variable *name* parsed as an int.

    Unset or blank values yield *default*. A non-integer value raises a
    ValueError that names the variable, instead of int()'s generic
    "invalid literal" message, so misconfiguration is easy to spot at startup.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from err


# Connection string handed to psycopg2 as the pool's dsn (see get_pool()).
DATABASE_URL: str = os.environ.get("DATABASE_URL", "postgresql://localhost/pgmnemo")
# Port number from $MCP_PORT (default 8765), parsed once at import time.
MCP_PORT: int = _int_from_env("MCP_PORT", 8765)
|
|
8
|
+
|
|
9
|
+
# Lazily created on the first get_pool() call and reused afterwards.
_pool: pool.SimpleConnectionPool | None = None


def get_pool() -> pool.SimpleConnectionPool:
    """Return the module's shared psycopg2 connection pool, creating it on first use.

    The pool is built once with minconn=1 and maxconn=5, using DATABASE_URL as
    the dsn. Later calls return that same pool object.
    """
    global _pool
    if _pool is not None:
        return _pool
    _pool = pool.SimpleConnectionPool(minconn=1, maxconn=5, dsn=DATABASE_URL)
    return _pool
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""pgmnemo MCP server — exposes ingest and recall as MCP tools.
|
|
2
|
+
|
|
3
|
+
Transport note (BUG-3 resolution):
|
|
4
|
+
FastMCP uses MCP protocol transport (stdio by default, SSE/streamable-http
|
|
5
|
+
optionally). It does NOT expose REST endpoints at /ingest or /recall.
|
|
6
|
+
Clients must use the MCP JSON-RPC protocol (stdio pipe or SSE at /sse).
|
|
7
|
+
The --smoke command in __main__.py exercises the DB layer directly
|
|
8
|
+
without going through the MCP transport.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations

import json
from typing import Any

from mcp.server.fastmcp import FastMCP

from .config import MCP_PORT, get_pool
|
|
19
|
+
|
|
20
|
+
# Module-level FastMCP app; the tools below register on it via @mcp.tool.
# The port comes from config.MCP_PORT (env MCP_PORT, default 8765) rather than
# a hard-coded literal, so the configured value is actually honored.
mcp = FastMCP("pgmnemo", port=MCP_PORT)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@mcp.tool(name="pgmnemo.ingest", description="Ingest a lesson into pgmnemo agent memory.")
def ingest(
    text: str,
    role: str = "mcp_agent",
    topic: str = "general",
    importance: int = 3,
    project_id: int = 1,
    commit_sha: str | None = None,
    artifact_hash: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Persist one lesson by calling the pgmnemo.ingest() stored procedure.

    The stored procedure is used rather than a raw INSERT so that:
      - gate enforcement (provenance checks) runs inside the SP;
      - verified_at is stamped automatically when commit_sha/artifact_hash
        are present;
      - embedding dimension validation fires before the INSERT.

    No embedding is supplied: the vector argument is always NULL::vector(1024).
    *metadata* is sent as JSON text cast to jsonb; None becomes ``{}``.
    The transaction is committed before returning.

    Returns:
        ``{"id": <value returned by pgmnemo.ingest()>}``.
    """
    connection_pool = get_pool()
    conn = connection_pool.getconn()
    try:
        with conn.cursor() as cur:
            metadata_json = "{}" if metadata is None else json.dumps(metadata)
            # Placeholder order: role, project_id, topic, lesson text, importance,
            # then (after the NULL vector) commit_sha, artifact_hash, metadata.
            params = (
                role,
                project_id,
                topic,
                text,
                importance,
                commit_sha,
                artifact_hash,
                metadata_json,
            )
            cur.execute(
                """
                SELECT pgmnemo.ingest(
                    %s, %s, %s, %s, %s::smallint,
                    NULL::vector(1024), %s, %s, %s::jsonb
                )
                """,
                params,
            )
            lesson_id = cur.fetchone()[0]
        conn.commit()
        return {"id": lesson_id}
    finally:
        # Runs on success and on error alike, so the pooled connection is never leaked.
        connection_pool.putconn(conn)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@mcp.tool(name="pgmnemo.recall", description="Recall lessons from pgmnemo agent memory.")
def recall(query: str, top_k: int = 5) -> list[dict[str, Any]]:
    """Fetch at most *top_k* lessons matching *query* via pgmnemo.recall_lessons().

    No query embedding is supplied (NULL::vector(1024)), so matching relies on
    the query_text argument. Role and project filters are both passed as NULL.

    recall_lessons() RETURNS TABLE (lesson_id bigint, score, role, ...). The
    id column is exposed as 'lesson_id' (an alias), not the physical column 'id'.

    Returns:
        One dict per row, keyed by the cursor's column names (lesson_id, role,
        topic, lesson_text, importance, created_at for this SELECT).
    """
    connection_pool = get_pool()
    conn = connection_pool.getconn()
    try:
        with conn.cursor() as cur:
            # Argument order: (query_vec, top_k, role_filter, project_id_filter, query_text).
            # The vector is NULL; query_text drives the BM25/hybrid keyword match.
            cur.execute(
                """
                SELECT lesson_id, role, topic, lesson_text, importance, created_at
                FROM pgmnemo.recall_lessons(
                    NULL::vector(1024), %s, NULL, NULL, %s
                )
                """,
                (top_k, query),
            )
            column_names = [column[0] for column in cur.description]
            records = cur.fetchall()
            return [dict(zip(column_names, record)) for record in records]
    finally:
        connection_pool.putconn(conn)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def run() -> None:
    """Start the FastMCP server using FastMCP's default transport (stdio).

    Reached from three places:
      - ``python -m pgmnemo_mcp``, when ``--smoke`` is not given;
      - ``python -m pgmnemo_mcp.server``, via main();
      - the ``pgmnemo-mcp`` console script, which pyproject.toml maps
        directly to ``pgmnemo_mcp.server:run``.

    NOTE(review): run() does not parse command-line arguments. As a result,
    ``pgmnemo-mcp --smoke`` starts the server instead of running the smoke
    test. The flag is only handled by ``python -m pgmnemo_mcp --smoke``.
    """
    mcp.run()


def main() -> None:
    """Start the server; a thin alias for run() used by this module's ``__main__`` guard.

    Despite the name, this is not the console-script target: pyproject.toml
    maps ``pgmnemo-mcp`` to ``pgmnemo_mcp.server:run``.
    """
    run()
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# Allows `python -m pgmnemo_mcp.server` to start the server via main().
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pgmnemo-mcp
|
|
3
|
+
Version: 0.5.2
|
|
4
|
+
Summary: MCP server wrapping pgmnemo ingest and recall for AI agent memory
|
|
5
|
+
License: Apache-2.0
|
|
6
|
+
Requires-Python: >=3.11
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: mcp>=1.0
|
|
9
|
+
Requires-Dist: psycopg2-binary>=2.9
|
|
10
|
+
Requires-Dist: pydantic>=2.0
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
pyproject.toml
|
|
2
|
+
pgmnemo_mcp/__init__.py
|
|
3
|
+
pgmnemo_mcp/__main__.py
|
|
4
|
+
pgmnemo_mcp/config.py
|
|
5
|
+
pgmnemo_mcp/server.py
|
|
6
|
+
pgmnemo_mcp.egg-info/PKG-INFO
|
|
7
|
+
pgmnemo_mcp.egg-info/SOURCES.txt
|
|
8
|
+
pgmnemo_mcp.egg-info/dependency_links.txt
|
|
9
|
+
pgmnemo_mcp.egg-info/entry_points.txt
|
|
10
|
+
pgmnemo_mcp.egg-info/requires.txt
|
|
11
|
+
pgmnemo_mcp.egg-info/top_level.txt
|
|
12
|
+
tests/test_import.py
|
|
13
|
+
tests/test_server.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
pgmnemo_mcp
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=42", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "pgmnemo-mcp"
|
|
7
|
+
version = "0.5.2"
|
|
8
|
+
description = "MCP server wrapping pgmnemo ingest and recall for AI agent memory"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = { text = "Apache-2.0" }
|
|
11
|
+
requires-python = ">=3.11"
|
|
12
|
+
dependencies = [
|
|
13
|
+
"mcp>=1.0",
|
|
14
|
+
"psycopg2-binary>=2.9",
|
|
15
|
+
"pydantic>=2.0",
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
[project.scripts]
|
|
19
|
+
pgmnemo-mcp = "pgmnemo_mcp.server:run"
|
|
20
|
+
|
|
21
|
+
[tool.setuptools.packages.find]
|
|
22
|
+
where = ["."]
|
|
23
|
+
include = ["pgmnemo_mcp*"]
|
|
24
|
+
|
|
25
|
+
[tool.pytest.ini_options]
|
|
26
|
+
pythonpath = ["."]
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""Unit tests for pgmnemo_mcp.server — ingest() and recall() tools.
|
|
2
|
+
|
|
3
|
+
Uses unittest.mock to patch get_pool() so no live PostgreSQL is required.
|
|
4
|
+
All tests verify the SQL / SP call generated and the return shape of each function.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import importlib
|
|
9
|
+
import json
|
|
10
|
+
import sys
|
|
11
|
+
import types
|
|
12
|
+
import unittest
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
from unittest.mock import MagicMock, patch, call
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# ---------------------------------------------------------------------------
|
|
18
|
+
# Stub the `mcp` package so tests run without `mcp` installed
|
|
19
|
+
# ---------------------------------------------------------------------------
|
|
20
|
+
|
|
21
|
+
def _stub_mcp() -> None:
    """Install a minimal fake ``mcp`` package unless ``mcp`` is already imported.

    Registers ``mcp``, ``mcp.server`` and ``mcp.server.fastmcp`` in
    ``sys.modules``. The check is on ``sys.modules``, not on whether the real
    package is installed. If ``mcp`` has not been imported yet, the stub is
    used even when the real package is available.

    The fake ``FastMCP`` behaves as follows:
      - the constructor accepts any settings;
      - its ``tool()`` decorator returns the function unchanged, so tools
        stay directly callable;
      - ``run()`` accepts and ignores any arguments.
    """
    if "mcp" in sys.modules:
        return

    mcp_pkg = types.ModuleType("mcp")
    server_pkg = types.ModuleType("mcp.server")
    fastmcp_mod = types.ModuleType("mcp.server.fastmcp")

    class _FastMCP:
        def __init__(self, name: str, **kw: object) -> None:
            self.name = name

        def tool(self, **kw: object):
            def decorator(fn):
                return fn

            return decorator

        def run(self, *args: object, **kwargs: object) -> None:
            # Mirrors FastMCP.run(transport=...): accept arguments, do nothing.
            return None

    fastmcp_mod.FastMCP = _FastMCP
    # Link submodules as attributes on their parents. The import system does
    # this itself only when it loads a module, not when it finds one already
    # in sys.modules, so `import mcp.server.fastmcp; mcp.server...` would
    # otherwise fail.
    mcp_pkg.server = server_pkg
    server_pkg.fastmcp = fastmcp_mod
    sys.modules["mcp"] = mcp_pkg
    sys.modules["mcp.server"] = server_pkg
    sys.modules["mcp.server.fastmcp"] = fastmcp_mod
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
_stub_mcp()

# Force reimport after stub: evict any pgmnemo_mcp modules already in
# sys.modules so the import below re-executes server.py and picks up whatever
# `mcp` package is now registered (the stub, if _stub_mcp() installed it).
for mod in list(sys.modules):
    if mod.startswith("pgmnemo_mcp"):
        del sys.modules[mod]

from pgmnemo_mcp import server  # noqa: E402
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# ---------------------------------------------------------------------------
|
|
57
|
+
# Helpers
|
|
58
|
+
# ---------------------------------------------------------------------------
|
|
59
|
+
|
|
60
|
+
def _make_pool(rows=None, description=None):
    """Build mocked (pool, connection, cursor) objects mimicking psycopg2.

    ``pool.getconn()`` yields the connection, and ``with conn.cursor() as c``
    yields the cursor.

    When *rows* is given, ``fetchall()`` returns it and ``fetchone()``
    returns its first row (or None if *rows* is empty). When *description*
    is given, it is exposed as ``cursor.description`` in DB-API shape (one
    tuple per column name).
    """
    fake_pool, fake_conn, fake_cur = MagicMock(), MagicMock(), MagicMock()
    fake_pool.getconn.return_value = fake_conn

    cursor_cm = fake_conn.cursor.return_value
    cursor_cm.__enter__.return_value = fake_cur
    # Falsy __exit__ so exceptions raised inside the `with` still propagate.
    cursor_cm.__exit__.return_value = False

    if rows is not None:
        fake_cur.fetchall.return_value = rows
        fake_cur.fetchone.return_value = rows[0] if rows else None
    if description is not None:
        fake_cur.description = [(name,) for name in description]

    return fake_pool, fake_conn, fake_cur
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# ---------------------------------------------------------------------------
|
|
80
|
+
# ingest() tests
|
|
81
|
+
# ---------------------------------------------------------------------------
|
|
82
|
+
|
|
83
|
+
class TestIngest(unittest.TestCase):
    """ingest(): stored-procedure call shape, argument order, and pool hygiene."""

    def _run_ingest(self, fetched, **kwargs):
        """Call server.ingest(**kwargs) with get_pool() patched to a mock pool.

        *fetched* is what the mocked cursor's fetchone() returns. Returns
        (result, pool, conn, cur) so each test can inspect what it needs.
        """
        pool, conn, cur = _make_pool()
        cur.fetchone.return_value = fetched
        with patch.object(server, "get_pool", return_value=pool):
            result = server.ingest(**kwargs)
        return result, pool, conn, cur

    @staticmethod
    def _sql(cur):
        """SQL text passed to the (single) cur.execute() call."""
        return cur.execute.call_args[0][0]

    @staticmethod
    def _params(cur):
        """Parameter tuple passed to the (single) cur.execute() call."""
        return cur.execute.call_args[0][1]

    def test_returns_id(self):
        result, _, _, _ = self._run_ingest((42,), text="hello world")
        self.assertEqual(result["id"], 42)

    def test_calls_ingest_sp_not_raw_insert(self):
        """Regression: must call pgmnemo.ingest() SP, not raw INSERT."""
        _, _, _, cur = self._run_ingest((1,), text="test lesson")
        statement = self._sql(cur)
        self.assertIn("pgmnemo.ingest", statement)
        self.assertNotIn("INSERT INTO", statement)

    def test_default_params_sent_to_sp(self):
        """SP receives args in positional order: role, project_id, topic, text, importance, ..."""
        _, _, _, cur = self._run_ingest((1,), text="test lesson")
        params = self._params(cur)
        # Positional order matches SP signature: role, project_id, topic, lesson_text, importance
        expected = ("mcp_agent", 1, "general", "test lesson", 3)
        for position, value in enumerate(expected):
            self.assertEqual(params[position], value)

    def test_project_id_passed_through(self):
        _, _, _, cur = self._run_ingest((7,), text="x", project_id=99)
        self.assertEqual(self._params(cur)[1], 99)

    def test_commit_sha_and_artifact_hash_forwarded(self):
        _, _, _, cur = self._run_ingest(
            (5,), text="x", commit_sha="abc123", artifact_hash="def456"
        )
        params = self._params(cur)
        self.assertEqual(params[5], "abc123")  # commit_sha
        self.assertEqual(params[6], "def456")  # artifact_hash

    def test_metadata_serialized_to_json(self):
        _, _, _, cur = self._run_ingest((3,), text="x", metadata={"key": "val"})
        self.assertEqual(json.loads(self._params(cur)[7]), {"key": "val"})

    def test_none_metadata_defaults_to_empty_json_object(self):
        _, _, _, cur = self._run_ingest((3,), text="x")
        self.assertEqual(self._params(cur)[7], "{}")

    def test_jsonb_cast_present_in_sql(self):
        """Ensures metadata arg is cast to ::jsonb so PostgreSQL accepts it."""
        _, _, _, cur = self._run_ingest((1,), text="x")
        self.assertIn("::jsonb", self._sql(cur))

    def test_conn_returned_to_pool(self):
        _, pool, conn, _ = self._run_ingest((1,), text="x")
        pool.putconn.assert_called_once_with(conn)

    def test_conn_returned_on_exception(self):
        pool, conn, cur = _make_pool()
        cur.execute.side_effect = RuntimeError("db down")
        with patch.object(server, "get_pool", return_value=pool), self.assertRaises(RuntimeError):
            server.ingest(text="boom")
        pool.putconn.assert_called_once_with(conn)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# ---------------------------------------------------------------------------
|
|
175
|
+
# recall() tests
|
|
176
|
+
# ---------------------------------------------------------------------------
|
|
177
|
+
|
|
178
|
+
class TestRecall(unittest.TestCase):
    """recall(): SQL parameters, row-to-dict mapping, and pool hygiene."""

    def _run_recall(self, rows=None, columns=None, **kwargs):
        """Call server.recall(**kwargs) with get_pool() patched to a mock pool.

        *rows* / *columns* configure the mocked cursor. Both default to an
        empty result whose only column is 'lesson_id'. Returns
        (result, pool, conn, cur).
        """
        pool, conn, cur = _make_pool(
            rows=[] if rows is None else rows,
            description=["lesson_id"] if columns is None else columns,
        )
        with patch.object(server, "get_pool", return_value=pool):
            result = server.recall(**kwargs)
        return result, pool, conn, cur

    def test_returns_list_of_dicts(self):
        created = datetime(2026, 5, 17)
        result, _, _, _ = self._run_recall(
            rows=[(10, "agent", "memory", "lesson A", 3, created)],
            columns=["lesson_id", "role", "topic", "lesson_text", "importance", "created_at"],
            query="memory test",
        )
        self.assertEqual(len(result), 1)
        first = result[0]
        self.assertEqual(first["lesson_id"], 10)
        self.assertEqual(first["lesson_text"], "lesson A")

    def test_default_top_k_is_5(self):
        _, _, _, cur = self._run_recall(query="q")
        self.assertEqual(cur.execute.call_args[0][1][0], 5)  # top_k

    def test_top_k_passed_through(self):
        _, _, _, cur = self._run_recall(query="q", top_k=10)
        self.assertEqual(cur.execute.call_args[0][1][0], 10)

    def test_query_text_forwarded(self):
        _, _, _, cur = self._run_recall(query="find lessons about memory")
        self.assertEqual(cur.execute.call_args[0][1][1], "find lessons about memory")

    def test_empty_result(self):
        result, _, _, _ = self._run_recall(query="nothing")
        self.assertEqual(result, [])

    def test_conn_returned_to_pool(self):
        _, pool, conn, _ = self._run_recall(query="x")
        pool.putconn.assert_called_once_with(conn)

    def test_conn_returned_on_exception(self):
        pool, conn, cur = _make_pool()
        cur.execute.side_effect = RuntimeError("db error")
        with patch.object(server, "get_pool", return_value=pool), self.assertRaises(RuntimeError):
            server.recall(query="x")
        pool.putconn.assert_called_once_with(conn)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# Allows running this file directly with the unittest runner, besides pytest.
if __name__ == "__main__":
    unittest.main()
|