devmem-agents 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- devmem/__init__.py +5 -0
- devmem/api.py +257 -0
- devmem/config.py +34 -0
- devmem/embeddings.py +119 -0
- devmem/ingest.py +184 -0
- devmem/live_backend.py +344 -0
- devmem/main.py +11 -0
- devmem/models.py +157 -0
- devmem/retrieval_eval.py +145 -0
- devmem/service.py +280 -0
- devmem/storage/__init__.py +4 -0
- devmem/storage/milvus_store.py +321 -0
- devmem/storage/neptune_store.py +194 -0
- devmem/storage/record_store.py +974 -0
- devmem_agents-0.1.0.dist-info/METADATA +100 -0
- devmem_agents-0.1.0.dist-info/RECORD +19 -0
- devmem_agents-0.1.0.dist-info/WHEEL +5 -0
- devmem_agents-0.1.0.dist-info/licenses/LICENSE +21 -0
- devmem_agents-0.1.0.dist-info/top_level.txt +1 -0
devmem/live_backend.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
"""Live backend connectivity config and helpers.
|
|
2
|
+
|
|
3
|
+
This module is used by integration tests that probe live Aurora, Neptune,
|
|
4
|
+
and Milvus systems. Values come from DEVMEM_* environment variables and can
|
|
5
|
+
optionally fall back to a TOML reference file (``devmem.toml`` by default).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
import socket
|
|
12
|
+
from contextlib import contextmanager
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any, Iterator
|
|
16
|
+
from urllib.parse import quote
|
|
17
|
+
from urllib.parse import urlparse
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _env_bool(name: str, default: bool = False) -> bool:
    """Interpret the environment variable ``name`` as a boolean flag.

    Returns ``default`` when the variable is unset. Otherwise the value is
    true only if, after trimming and lower-casing, it is one of
    ``1``, ``true``, ``yes`` or ``on``.
    """
    value = os.environ.get(name)
    if value is not None:
        return value.strip().lower() in {"1", "true", "yes", "on"}
    return default
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _coerce_int(value: Any, default: int) -> int:
    """Convert ``value`` to ``int``, returning ``default`` when that fails.

    Args:
        value: Anything accepted by ``int()`` (string, number, ...), or a
            value that is not convertible at all (``None``, a table, ...).
        default: Result to use when the conversion is impossible.

    Returns:
        The converted integer, or ``default``.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers float infinities, which TOML can express
        # (``PORT = inf``) and which int() rejects with that exception.
        return default
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _coerce_str(value: Any) -> str | None:
    """Return ``value`` as trimmed text, or ``None`` when nothing is left.

    ``None`` stays ``None``; any other value is passed through ``str()`` and
    stripped, and an empty result is reported as ``None``.
    """
    if value is None:
        return None
    stripped = str(value).strip()
    if stripped:
        return stripped
    return None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _parse_bool(value: Any, default: bool) -> bool:
    """Interpret a loosely typed configuration value as a boolean.

    ``None`` yields ``default`` and real booleans are returned unchanged.
    Everything else is stringified, trimmed and lower-cased, and counts as
    true only when it equals ``1``, ``true``, ``yes`` or ``on``.
    """
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    normalised = str(value).strip().lower()
    return normalised in {"1", "true", "yes", "on"}
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _load_toml(path: Path) -> dict[str, Any]:
|
|
50
|
+
if not path.exists():
|
|
51
|
+
return {}
|
|
52
|
+
import tomllib
|
|
53
|
+
|
|
54
|
+
with path.open("rb") as handle:
|
|
55
|
+
parsed = tomllib.load(handle)
|
|
56
|
+
return parsed if isinstance(parsed, dict) else {}
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _nested_get(payload: dict[str, Any], *keys: str) -> Any:
    """Follow ``keys`` through nested dicts and return the value reached.

    Returns ``None`` as soon as a level is not a dict or a key is absent.
    With no keys, ``payload`` itself is returned.
    """
    node: Any = payload
    for name in keys:
        if isinstance(node, dict):
            node = node.get(name)
        else:
            return None
    return node
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass(frozen=True)
class LiveBackendConfig:
    """Immutable bundle of resolved live-backend connection settings."""

    # General switches.
    namespace: str
    enable_live_connectivity: bool
    allow_writes: bool

    # Aurora (PostgreSQL DSN).
    aurora_dsn: str | None

    # Neptune endpoint, auth and optional SSH tunnel settings.
    neptune_endpoint: str | None
    neptune_region: str | None
    neptune_port: int
    neptune_use_https: bool
    neptune_iam_auth: bool
    neptune_query_language: str
    neptune_timeout: float
    neptune_use_ssh_tunnel: bool
    neptune_ssh_host: str | None
    neptune_ssh_user: str | None
    neptune_ssh_key_path: str | None
    neptune_ssh_key_passphrase: str | None

    # Milvus: either a full URI or a host/port pair, plus credentials.
    milvus_uri: str | None
    milvus_host: str | None
    milvus_port: int
    milvus_secure: bool
    milvus_user: str | None
    milvus_password: str | None
    milvus_token: str | None
    milvus_timeout: float

    # Timeout for raw TCP probes, in seconds.
    tcp_timeout_seconds: float

    @property
    def has_any_target(self) -> bool:
        """Whether at least one backend address (Aurora, Neptune, Milvus) is set."""
        addresses = (
            self.aurora_dsn,
            self.neptune_endpoint,
            self.milvus_uri,
            self.milvus_host,
        )
        return any(addresses)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _normalise_aurora_dsn(raw: str | None) -> str | None:
    """Return ``raw`` as a ``postgresql://`` URI with URL-safe credentials.

    Args:
        raw: DSN from the environment or a reference file; may be ``None``
            or empty.

    Returns:
        ``None`` for empty input. Otherwise the DSN with the SQLAlchemy
        ``postgresql+psycopg2://`` prefix rewritten to ``postgresql://`` and,
        when the URI does not parse cleanly, with the user and password
        percent-encoded. DSNs that already parse cleanly are returned as is.
    """
    if not raw:
        return None
    # Allow SQLAlchemy-style DSNs from reference files.
    uri = raw.replace("postgresql+psycopg2://", "postgresql://")

    # Without a scheme separator or an "@" there is no userinfo to repair.
    if "://" not in uri or "@" not in uri:
        return uri
    try:
        parsed = urlparse(uri)
        # An "@" inside the fragment means an unescaped "#" in the password
        # cut the authority short, so the parse cannot be trusted.
        if parsed.hostname and "@" not in parsed.fragment:
            return uri
    except ValueError:
        # urlparse rejects some malformed authorities; fall through and
        # rebuild the userinfo by hand.
        pass

    scheme, rest = uri.split("://", 1)
    # Split on the LAST "@" so an "@" inside the password stays in userinfo.
    userinfo, at_sign, hostpart = rest.rpartition("@")
    if not at_sign:
        # The only "@" sits before "://", so there is no userinfo section.
        # Previously this case failed with an unpacking ValueError.
        return uri
    user, colon, password = userinfo.partition(":")
    if colon:
        encoded_userinfo = f"{quote(user, safe='')}:{quote(password, safe='')}"
    else:
        encoded_userinfo = quote(userinfo, safe="")
    return f"{scheme}://{encoded_userinfo}@{hostpart}"
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _resolve_reference_toml() -> Path | None:
    """Pick the TOML reference file to read, if any.

    ``DEVMEM_LIVE_TOML_PATH`` wins when set (with ``~`` expanded; existence
    is not checked here). Otherwise the first existing file among
    ``./devmem.toml`` and ``~/.config/devmem/devmem.toml`` is returned, or
    ``None`` when neither exists.
    """
    override = _coerce_str(os.getenv("DEVMEM_LIVE_TOML_PATH"))
    if override:
        return Path(override).expanduser()

    search_order = (
        Path.cwd() / "devmem.toml",
        Path.home() / ".config" / "devmem" / "devmem.toml",
    )
    return next((location for location in search_order if location.exists()), None)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _coerce_float(value: Any, default: float) -> float:
    """Return ``float(value)``, or ``default`` when the conversion fails."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def load_live_backend_config() -> LiveBackendConfig:
    """Load connectivity config from env with optional TOML fallback.

    For each setting the ``DEVMEM_*`` environment variable is consulted
    first, then the reference TOML file (top-level keys for namespace and
    Aurora, ``[NEPTUNE]`` / ``[MILVUS]`` tables for the rest), then a
    built-in default. ``DEVMEM_LIVE_CONNECTIVITY``, ``DEVMEM_LIVE_ALLOW_WRITES``
    and ``DEVMEM_LIVE_TCP_TIMEOUT`` are read from the environment only.

    Unparseable timeout values fall back to their default instead of
    raising, matching how ports are handled via ``_coerce_int``.

    Returns:
        A fully populated, frozen ``LiveBackendConfig``.
    """
    reference: dict[str, Any] = {}
    toml_path = _resolve_reference_toml()
    if toml_path:
        reference = _load_toml(toml_path)

    def env_text(name: str) -> str | None:
        # Trimmed environment value; unset or blank becomes None.
        return _coerce_str(os.getenv(name))

    def text(env_name: str, section: str, key: str) -> str | None:
        # String setting: environment first, then the TOML table.
        return env_text(env_name) or _coerce_str(_nested_get(reference, section, key))

    def flag(env_name: str, section: str, key: str, default: bool) -> bool:
        # The raw (untrimmed) environment value is used on purpose so a
        # whitespace-only variable still overrides the TOML value.
        return _parse_bool(
            os.getenv(env_name) or _nested_get(reference, section, key),
            default=default,
        )

    def integer(env_name: str, section: str, key: str, default: int) -> int:
        return _coerce_int(
            env_text(env_name) or _nested_get(reference, section, key),
            default=default,
        )

    def seconds(env_name: str, section: str, key: str, default: float) -> float:
        # A falsy TOML value (e.g. 0) selects the default, as before.
        return _coerce_float(
            env_text(env_name) or _nested_get(reference, section, key) or default,
            default,
        )

    namespace = (
        env_text("DEVMEM_NAMESPACE")
        or _coerce_str(reference.get("DEVMEM_NAMESPACE"))
        or "devlib_v1"
    )

    aurora_dsn = _normalise_aurora_dsn(
        env_text("DEVMEM_AURORA_DSN")
        or _coerce_str(reference.get("AURORA_DSN"))
        or _coerce_str(reference.get("DATABASE_URL"))
    )

    # The Neptune region additionally falls back to the standard AWS variables.
    neptune_region = (
        text("DEVMEM_NEPTUNE_REGION", "NEPTUNE", "REGION")
        or env_text("AWS_REGION")
        or env_text("AWS_DEFAULT_REGION")
    )

    return LiveBackendConfig(
        namespace=namespace,
        enable_live_connectivity=_env_bool("DEVMEM_LIVE_CONNECTIVITY", default=False),
        allow_writes=_env_bool("DEVMEM_LIVE_ALLOW_WRITES", default=False),
        aurora_dsn=aurora_dsn,
        neptune_endpoint=text("DEVMEM_NEPTUNE_ENDPOINT", "NEPTUNE", "ENDPOINT"),
        neptune_region=neptune_region,
        neptune_port=integer("DEVMEM_NEPTUNE_PORT", "NEPTUNE", "PORT", 8182),
        neptune_use_https=flag("DEVMEM_NEPTUNE_USE_HTTPS", "NEPTUNE", "USE_HTTPS", True),
        neptune_iam_auth=flag("DEVMEM_NEPTUNE_IAM_AUTH", "NEPTUNE", "IAM_AUTH", True),
        neptune_query_language=(
            text("DEVMEM_NEPTUNE_QUERY_LANGUAGE", "NEPTUNE", "QUERY_LANGUAGE")
            or "opencypher"
        ),
        neptune_timeout=seconds("DEVMEM_NEPTUNE_TIMEOUT", "NEPTUNE", "TIMEOUT", 10.0),
        neptune_use_ssh_tunnel=flag(
            "DEVMEM_NEPTUNE_USE_SSH_TUNNEL", "NEPTUNE", "USE_SSH_TUNNEL", False
        ),
        neptune_ssh_host=text("DEVMEM_NEPTUNE_SSH_HOST", "NEPTUNE", "SSH_HOST"),
        neptune_ssh_user=text("DEVMEM_NEPTUNE_SSH_USER", "NEPTUNE", "SSH_USER"),
        neptune_ssh_key_path=text("DEVMEM_NEPTUNE_SSH_KEY_PATH", "NEPTUNE", "SSH_KEY_PATH"),
        neptune_ssh_key_passphrase=text(
            "DEVMEM_NEPTUNE_SSH_KEY_PASSPHRASE", "NEPTUNE", "SSH_KEY_PASSPHRASE"
        ),
        milvus_uri=text("DEVMEM_MILVUS_URI", "MILVUS", "URI"),
        milvus_host=text("DEVMEM_MILVUS_HOST", "MILVUS", "HOST"),
        milvus_port=integer("DEVMEM_MILVUS_PORT", "MILVUS", "PORT", 19530),
        milvus_secure=flag("DEVMEM_MILVUS_SECURE", "MILVUS", "SECURE", False),
        milvus_user=text("DEVMEM_MILVUS_USER", "MILVUS", "USER"),
        milvus_password=text("DEVMEM_MILVUS_PASSWORD", "MILVUS", "PASSWORD"),
        milvus_token=text("DEVMEM_MILVUS_TOKEN", "MILVUS", "TOKEN"),
        milvus_timeout=seconds("DEVMEM_MILVUS_TIMEOUT", "MILVUS", "TIMEOUT", 10.0),
        tcp_timeout_seconds=_coerce_float(
            env_text("DEVMEM_LIVE_TCP_TIMEOUT") or 5.0, 5.0
        ),
    )
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def parse_host_port_from_url(url: str, default_port: int) -> tuple[str, int]:
    """Extract ``(host, port)`` from a URL or a bare ``host[:port]`` string.

    Args:
        url: Full URL (``https://host:443/path``) or a scheme-less
            ``host`` / ``host:port`` value.
        default_port: Port to report when the URL carries none, carries
            port ``0``, or carries one that is not a valid port number.

    Returns:
        The lower-cased host name (empty string when none can be found) and
        the port.
    """
    # urlparse only recognises a network location after "//"; a bare
    # "host:port" would otherwise be read as scheme + path with no host.
    if "://" not in url and not url.startswith("//"):
        url = f"//{url}"
    parsed = urlparse(url)
    host = parsed.hostname or ""
    try:
        port = parsed.port or default_port
    except ValueError:
        # .port raises for non-numeric or out-of-range ports.
        port = default_port
    return host, port
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def parse_postgres_host_port(dsn: str) -> tuple[str, int]:
    """Return the ``(host, port)`` a PostgreSQL DSN points at.

    The DSN is normalised first and parsed as a URL. If that yields no host,
    or the port is not a valid number, a plain string split is used instead.
    The port defaults to 5432 throughout.
    """
    normalized = _normalise_aurora_dsn(dsn) or dsn
    parts = urlparse(normalized)
    hostname = parts.hostname or ""
    try:
        port_number = parts.port or 5432
    except ValueError:
        # Invalid port: distrust the parsed host too and use the fallback.
        hostname, port_number = "", 5432

    if hostname:
        return hostname, port_number

    # Fallback parser for malformed but common DSN patterns.
    head, separator, tail = normalized.partition("://")
    after_scheme = tail if separator else head
    authority = after_scheme.rpartition("@")[2].partition("/")[0]
    if ":" not in authority:
        return authority, 5432
    host_value, _, port_raw = authority.rpartition(":")
    return host_value, _coerce_int(port_raw, default=5432)
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def tcp_connect(host: str, port: int, timeout: float) -> None:
    """Open a TCP connection to ``host:port`` and close it again.

    Returns ``None`` on success; any error from
    ``socket.create_connection`` propagates to the caller.
    """
    connection = socket.create_connection((host, port), timeout=timeout)
    connection.close()
    return None
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
@contextmanager
def neptune_tunnel(cfg: LiveBackendConfig) -> Iterator[tuple[str, int]]:
    """Yield the ``(host, port)`` to use for reaching Neptune.

    Without an SSH tunnel this is the configured endpoint and port. With a
    tunnel, a forwarder to the endpoint is started through the SSH host on
    port 22 and its local ``127.0.0.1`` address is yielded; the forwarder is
    stopped when the context exits.

    Raises:
        ValueError: If the endpoint, or the SSH host/user/key path needed
            for the tunnel, is missing.
    """
    if not cfg.neptune_use_ssh_tunnel:
        # Direct access: hand back the endpoint unchanged.
        if cfg.neptune_endpoint:
            yield cfg.neptune_endpoint, cfg.neptune_port
            return
        raise ValueError("Neptune endpoint is not configured")

    ssh_ready = cfg.neptune_ssh_host and cfg.neptune_ssh_user and cfg.neptune_ssh_key_path
    if not ssh_ready:
        raise ValueError("SSH tunnel enabled but host/user/key path are missing")
    if not cfg.neptune_endpoint:
        raise ValueError("Neptune endpoint is required when SSH tunnel is enabled")

    # Imported dynamically, only on the tunnel path.
    sshtunnel = __import__("sshtunnel")
    paramiko = __import__("paramiko")
    # NOTE(review): presumably a shim for paramiko builds without DSSKey,
    # which sshtunnel appears to reference; confirm.
    if not hasattr(paramiko, "DSSKey"):
        paramiko.DSSKey = paramiko.RSAKey

    private_key_file = str(Path(cfg.neptune_ssh_key_path).expanduser())
    tunnel_options = {
        "ssh_username": cfg.neptune_ssh_user,
        "ssh_pkey": private_key_file,
        "ssh_private_key_password": cfg.neptune_ssh_key_passphrase,
        "remote_bind_address": (cfg.neptune_endpoint, cfg.neptune_port),
        # Port 0 lets the OS choose a free local port.
        "local_bind_address": ("127.0.0.1", 0),
        "allow_agent": False,
        "host_pkey_directories": [],
    }
    forwarder = sshtunnel.SSHTunnelForwarder((cfg.neptune_ssh_host, 22), **tunnel_options)
    forwarder.start()
    try:
        yield "127.0.0.1", int(forwarder.local_bind_port)
    finally:
        forwarder.stop()
|
devmem/main.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Application entrypoint for devmem."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from fastapi import FastAPI
|
|
6
|
+
|
|
7
|
+
from devmem.api import router
|
|
8
|
+
from devmem.config import settings
|
|
9
|
+
|
|
10
|
+
# Module-level FastAPI application; title and version come from settings,
# and the routes come from devmem.api.
app = FastAPI(
    title=settings.service_name,
    version=settings.service_version,
)
app.include_router(router)
|
devmem/models.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""API models for devmem."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from typing import Any
|
|
7
|
+
from uuid import uuid4
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ApiResponse(BaseModel):
    # Generic response envelope: an ``ok`` flag plus a free-form ``data`` mapping.
    ok: bool = True  # defaults to True
    data: dict[str, Any] = Field(default_factory=dict)  # fresh empty dict per instance
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class NamespaceScoped(BaseModel):
    # Base model for the request payloads below that carry a required ``namespace``.
    namespace: str
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SessionStartRequest(NamespaceScoped):
    # Session-start payload: all fields are required strings, in addition to
    # the inherited ``namespace``.
    project: str
    repo: str
    branch: str
    agent: str
    task: str
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class SessionStartResponse(BaseModel):
    # Session-start result; both fields are generated per instance when omitted.
    session_id: str = Field(default_factory=lambda: str(uuid4()))  # random UUID4 as text
    started_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())  # current UTC time, ISO 8601
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ContextPullRequest(NamespaceScoped):
    # Context-pull payload scoped to a session, project and repo.
    session_id: str
    project: str
    repo: str
    task: str
    top_k: int = 8  # defaults to 8; no bounds are enforced by this model
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class HybridSearchRequest(NamespaceScoped):
    # Search payload: query text ``q`` scoped to a project and repo.
    q: str
    project: str
    repo: str
    top_k: int = 8  # defaults to 8; no bounds are enforced by this model
    filters: dict[str, Any] = Field(default_factory=dict)  # free-form; schema not defined here
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class ArtifactUpsertRequest(NamespaceScoped):
    # Artifact upsert payload tied to a session.
    session_id: str
    artifact_type: str
    title: str
    content: str
    metadata: dict[str, Any] = Field(default_factory=dict)  # free-form; empty by default
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class FactUpsertRequest(NamespaceScoped):
    # Fact upsert payload: a subject / predicate / object triple tied to a session.
    session_id: str
    subject: str
    predicate: str
    # NOTE(review): this field name shadows the builtin ``object`` inside the
    # class body; renaming it would change the payload schema.
    object: str
    metadata: dict[str, Any] = Field(default_factory=dict)  # free-form; empty by default
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class DecisionUpsertRequest(NamespaceScoped):
    # Decision upsert payload: the decision, its rationale and any alternatives.
    session_id: str
    title: str
    decision: str
    rationale: str
    alternatives: list[str] = Field(default_factory=list)  # empty by default
    metadata: dict[str, Any] = Field(default_factory=dict)  # free-form; empty by default
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class HandoffCreateRequest(NamespaceScoped):
    # Handoff payload from one agent to another, tied to a session.
    session_id: str
    from_agent: str
    to_agent: str
    summary: str
    next_steps: list[str] = Field(default_factory=list)  # empty by default
    blockers: list[str] = Field(default_factory=list)  # empty by default
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class TaskUpdateRequest(NamespaceScoped):
    # Task status update tied to a session.
    session_id: str
    status: str  # free-form string; allowed values are not constrained here
    note: str = ""  # optional; empty by default
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class FeedbackRecordRequest(NamespaceScoped):
    # Feedback record tied to a session.
    session_id: str
    outcome: str  # free-form string; allowed values are not constrained here
    score: float | None = None  # optional; range is not constrained here
    notes: str = ""  # optional; empty by default
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class TaskSimilarRequest(NamespaceScoped):
    # Similar-task lookup payload: query text ``q`` with optional project/repo.
    q: str
    project: str | None = None  # optional
    repo: str | None = None  # optional
    top_k: int = 5  # defaults to 5; no bounds are enforced by this model
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# --------------------------------------------------------------------------
|
|
105
|
+
# Session commit (Pass 2 — atomic multi-write finalization)
|
|
106
|
+
# --------------------------------------------------------------------------
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class CommitArtifact(BaseModel):
    # Element type of ``SessionCommitRequest.artifacts``. Carries no
    # ``namespace`` or ``session_id`` of its own.
    artifact_type: str
    title: str
    content: str
    project: str | None = None  # optional
    repo: str | None = None  # optional
    metadata: dict[str, Any] = Field(default_factory=dict)  # free-form; empty by default
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class CommitDecision(BaseModel):
    # Element type of ``SessionCommitRequest.decisions``. Carries no
    # ``namespace`` or ``session_id`` of its own.
    title: str
    decision: str
    rationale: str
    alternatives: list[str] = Field(default_factory=list)  # empty by default
    project: str | None = None  # optional
    repo: str | None = None  # optional
    metadata: dict[str, Any] = Field(default_factory=dict)  # free-form; empty by default
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class CommitHandoff(BaseModel):
    # Type of the optional ``SessionCommitRequest.handoff`` field.
    from_agent: str
    to_agent: str
    summary: str
    next_steps: list[str] = Field(default_factory=list)  # empty by default
    blockers: list[str] = Field(default_factory=list)  # empty by default
    metadata: dict[str, Any] = Field(default_factory=dict)  # free-form; empty by default
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class CommitTaskUpdate(BaseModel):
    # Type of the optional ``SessionCommitRequest.task_update`` field.
    status: str  # free-form string; allowed values are not constrained here
    note: str = ""  # optional; empty by default
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class SessionCommitRequest(NamespaceScoped):
    """Finalize a session with one atomic multi-write.

    Replaces the 3-call ritual of `/v1/artifacts/upsert` +
    `/v1/decisions/upsert` + `/v1/tasks/update`.

    `client_commit_id`, when provided, enables safe retries — a second commit
    with the same id returns the original result without re-inserting rows.
    """

    session_id: str
    artifacts: list[CommitArtifact] = Field(default_factory=list)  # empty by default
    decisions: list[CommitDecision] = Field(default_factory=list)  # empty by default
    handoff: CommitHandoff | None = None  # optional
    task_update: CommitTaskUpdate | None = None  # optional
    # Optional retry key (see docstring); the retry handling itself is not
    # implemented in this model.
    client_commit_id: str | None = None
|
devmem/retrieval_eval.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Evaluate retrieval quality against live ingested project memory."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Iterable
|
|
8
|
+
|
|
9
|
+
from devmem.live_backend import LiveBackendConfig
|
|
10
|
+
from devmem.storage.milvus_store import MilvusStore
|
|
11
|
+
from devmem.storage.neptune_store import NeptuneStore
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True)
class RetrievalCase:
    """Expected retrieval behavior for a natural-language query."""

    case_id: str  # taken from the "id" key of the cases file by load_cases
    query: str  # query text passed to the search
    # Path fragments; a result path matches when it contains ANY of them
    # (case-insensitive, see _path_matches).
    expected_path_contains: tuple[str, ...]
    # Terms that must ALL occur in a row's content (case-insensitive, see
    # _row_has_terms); empty means no content check.
    expected_terms: tuple[str, ...] = ()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def load_cases(path: Path) -> list[RetrievalCase]:
    """Load retrieval cases from JSON.

    Args:
        path: UTF-8 JSON file holding an array of objects with the keys
            ``id``, ``query``, ``expected_path_contains`` (non-empty list)
            and, optionally, ``expected_terms`` (list).

    Returns:
        One ``RetrievalCase`` per object, in file order. Array elements that
        are not objects are skipped.

    Raises:
        ValueError: If the top level is not an array, or an object lacks an
            id, a query, or at least one non-blank path fragment, or gives
            ``expected_path_contains`` / ``expected_terms`` as a non-list.
    """
    import json

    payload = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(payload, list):
        raise ValueError("cases file must contain a JSON array")

    cases: list[RetrievalCase] = []
    for item in payload:
        if not isinstance(item, dict):
            continue
        case_id = str(item.get("id") or "").strip()
        query = str(item.get("query") or "").strip()
        raw_fragments = item.get("expected_path_contains") or []
        raw_terms = item.get("expected_terms") or []
        # A bare string would otherwise be iterated character by character.
        if not isinstance(raw_fragments, list) or not isinstance(raw_terms, list):
            raise ValueError(f"invalid retrieval case: {item}")
        # Blank fragments are dropped: an empty string is contained in every
        # path and would make the case match any row.
        fragments = tuple(str(x) for x in raw_fragments if str(x).strip())
        terms = tuple(str(x) for x in raw_terms if str(x).strip())
        if not case_id or not query or not fragments:
            raise ValueError(f"invalid retrieval case: {item}")
        cases.append(
            RetrievalCase(
                case_id=case_id,
                query=query,
                expected_path_contains=fragments,
                expected_terms=terms,
            )
        )
    return cases
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _path_matches(path: str, expected_fragments: Iterable[str]) -> bool:
|
|
54
|
+
path_lc = path.lower()
|
|
55
|
+
return any(fragment.lower() in path_lc for fragment in expected_fragments)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _row_has_terms(row: dict[str, Any], expected_terms: tuple[str, ...]) -> bool:
    """Tell whether the row's ``content`` contains every expected term.

    Matching ignores case. With no expected terms every row qualifies; a
    missing or empty ``content`` is treated as an empty string.
    """
    if not expected_terms:
        return True
    body = str(row.get("content") or "").lower()
    for term in expected_terms:
        if term.lower() not in body:
            return False
    return True
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _evaluate_single_case(
    *,
    case: RetrievalCase,
    rows: list[dict[str, Any]],
    kg_paths: set[str],
) -> dict[str, Any]:
    """Score one retrieval case against search rows and known graph paths.

    A row is a hit when it has a non-empty ``path`` containing one of the
    case's fragments and its content holds all expected terms. The case
    passes with at least one hit; ``kg_match`` reports whether any path in
    ``kg_paths`` contains one of the fragments. At most five hits are
    included in the result.
    """

    def is_hit(row: dict[str, Any]) -> bool:
        row_path = str(row.get("path") or "")
        return (
            bool(row_path)
            and _path_matches(row_path, case.expected_path_contains)
            and _row_has_terms(row, case.expected_terms)
        )

    hits = [row for row in rows if is_hit(row)]
    graph_covers_case = any(
        _path_matches(known_path, case.expected_path_contains) for known_path in kg_paths
    )
    reported_hits = [
        {
            "path": str(hit.get("path") or ""),
            "chunk_index": hit.get("chunk_index"),
            "score": hit.get("score"),
        }
        for hit in hits[:5]
    ]
    return {
        "case_id": case.case_id,
        "query": case.query,
        "passed": bool(hits),
        "kg_match": graph_covers_case,
        "matches": reported_hits,
    }
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def evaluate_retrieval(
    *,
    cfg: LiveBackendConfig,
    namespace: str,
    project_id: str,
    repo_id: str,
    cases: list[RetrievalCase],
    top_k: int = 8,
) -> dict[str, Any]:
    """Run retrieval evaluation cases against Milvus and Neptune.

    Args:
        cfg: Live-backend connection settings passed to both stores.
        namespace: Namespace passed to both stores.
        project_id: Project whose files are listed and searched.
        repo_id: Repository whose files are listed and searched.
        cases: Cases to evaluate; each query is run as a Milvus lexical search.
        top_k: Number of rows requested per query.

    Returns:
        A dict with a ``summary`` (counts, pass and graph-coverage rates,
        identifiers, collection and endpoint) and per-case ``results``.
        Both rates are ``0.0`` when there are no cases.
    """
    milvus = MilvusStore(cfg, namespace=namespace, project_id=project_id, repo_id=repo_id)
    neptune = NeptuneStore(cfg, namespace=namespace)

    try:
        milvus.connect()
        neptune.connect()
        neptune.health_check()

        # Graph file paths are fetched once and reused for every case.
        kg_paths = set(neptune.list_project_files(project_id=project_id, repo_id=repo_id, limit=100000))

        case_results: list[dict[str, Any]] = []
        for case in cases:
            rows = milvus.lexical_search(query_text=case.query, top_k=top_k)
            case_results.append(_evaluate_single_case(case=case, rows=rows, kg_paths=kg_paths))

        passed = sum(1 for item in case_results if item.get("passed"))
        kg_covered = sum(1 for item in case_results if item.get("kg_match"))
        total = len(case_results)

        return {
            "summary": {
                "total_cases": total,
                "passed_cases": passed,
                "pass_rate": (float(passed) / float(total)) if total else 0.0,
                "kg_covered_cases": kg_covered,
                "kg_coverage_rate": (float(kg_covered) / float(total)) if total else 0.0,
                "namespace": namespace,
                "project_id": project_id,
                "repo_id": repo_id,
                "milvus_collection": milvus.collection_name,
                "neptune_endpoint": neptune.endpoint_summary(),
            },
            "results": case_results,
        }
    finally:
        # Nested so Milvus is still closed when closing Neptune raises.
        try:
            neptune.close()
        finally:
            milvus.close()
|