@pentatonic-ai/ai-agent-sdk 0.9.4 → 0.9.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +39 -72
- package/dist/index.js +36 -69
- package/package.json +9 -2
- package/packages/memory/package-lock.json +49 -33
- package/packages/memory/package.json +4 -1
- package/packages/memory/src/__tests__/engine.test.js +40 -5
- package/packages/memory/src/engine.js +38 -3
- package/packages/memory-engine/docker-compose.yml +24 -2
- package/packages/memory-engine/engine/services/_shared/embed_provider.py +125 -31
- package/packages/memory-engine/engine/services/l2/Dockerfile +7 -0
- package/packages/memory-engine/engine/services/l2/l2-hybridrag-proxy.py +233 -60
- package/packages/memory-engine/tests/test_embed_provider.py +201 -0
- package/packages/memory-engine/tests/test_l2_qmd_vec_search.py +280 -0
|
@@ -18,6 +18,7 @@ import json
|
|
|
18
18
|
import logging
|
|
19
19
|
import os
|
|
20
20
|
import sqlite3
|
|
21
|
+
import struct
|
|
21
22
|
import sys
|
|
22
23
|
import time
|
|
23
24
|
from contextlib import asynccontextmanager
|
|
@@ -34,6 +35,11 @@ from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate
|
|
|
34
35
|
from pydantic import BaseModel
|
|
35
36
|
import uvicorn
|
|
36
37
|
|
|
38
|
+
try:
|
|
39
|
+
import sqlite_vec # 0.1.9 — native KNN MATCH over packed-f32 vec0 tables
|
|
40
|
+
except ImportError:
|
|
41
|
+
sqlite_vec = None # Caller logs loudly if helpers can't load the extension
|
|
42
|
+
|
|
37
43
|
# Shared embed client lives at engine/services/_shared/.
|
|
38
44
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
39
45
|
from _shared.embed_provider import EmbedClient # noqa: E402
|
|
@@ -101,6 +107,59 @@ QMD_DB_PATH = _resolve_qmd_db()
|
|
|
101
107
|
OLLAMA_URL = os.getenv("PME_OLLAMA_URL", "http://localhost:11434/api/embeddings")
EMBEDDING_MODEL = os.getenv("PME_EMBED_MODEL", "nomic-embed-text")

# Dimension of the vec0 virtual table's embedding column. The production
# gateway (lambda-gateway.pentatonic.com/v1/embed via the pentatonic-gateway
# provider) returns NV-Embed-v2 vectors of 4096 floats. vec0 fixes the
# dimension at DDL time and every writer must match it, so keep this in
# lockstep with the gateway / EmbedClient configuration.
_embed_dim_raw = os.getenv("PME_EMBED_DIM", "4096")
EMBED_DIM = int(_embed_dim_raw)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _open_qmd_conn() -> sqlite3.Connection:
    """Open qmd.sqlite with the sqlite-vec extension loaded.

    Falls back to a plain sqlite3 connection if the extension can't load.
    In that case MATCH-form queries against ``vec_index`` fail loudly at
    execute time, which is the right signal (a loud error beats silent
    degradation). Callers that only need scalar columns (chunks.path,
    chunks.text) work fine without the extension.

    Extension loading is switched back off in a ``finally``. A failed
    ``sqlite_vec.load`` therefore can't leave the returned connection able
    to load arbitrary shared libraries through the ``load_extension()``
    SQL function.

    ``check_same_thread=False`` is intentional. The connection may be used
    from a thread-pool worker (e.g. via ``asyncio.to_thread``) rather than
    from the thread that opened it. Callers must make sure only one thread
    touches it at a time.

    Returns:
        An open connection. The caller owns it and must close it.
    """
    conn = sqlite3.connect(QMD_DB_PATH, timeout=10, check_same_thread=False)
    if sqlite_vec is None:
        log.error("sqlite_vec module not importable — qmd vec_index unavailable")
        return conn
    try:
        # May raise AttributeError on Python builds without extension
        # support; the outer handler covers that case too.
        conn.enable_load_extension(True)
        try:
            sqlite_vec.load(conn)
        finally:
            # Never leave extension loading enabled, even when load() raised.
            conn.enable_load_extension(False)
    except Exception as e:
        log.error(f"sqlite-vec load failed: {e} — qmd search will be degraded")
    return conn
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _ensure_vec_index(conn: sqlite3.Connection) -> None:
    """Create the ``vec_index`` vec0 KNN table unless it already exists.

    Safe to call repeatedly: the DDL uses ``IF NOT EXISTS``.

    ``distance_metric=cosine`` is not sqlite-vec's default (which is L2 /
    Euclidean). With cosine, the reported distance is ``1 - cos_sim``.
    ``id`` is declared ``INTEGER PRIMARY KEY`` so rows can be joined back
    to ``chunks`` on the chunk's integer id.

    Note: ``IF NOT EXISTS`` does not check an existing table's dimension.
    If EMBED_DIM changes, a pre-existing table keeps its old dimension.
    """
    embedding_column = f"embedding float[{EMBED_DIM}] distance_metric=cosine"
    conn.execute(
        f"""
        CREATE VIRTUAL TABLE IF NOT EXISTS vec_index USING vec0(
            id INTEGER PRIMARY KEY,
            {embedding_column}
        )
        """
    )
|
|
162
|
+
|
|
104
163
|
# NV-Embed-v2 service (primary, 4096-dim). URL/auth/path/body/response are
|
|
105
164
|
# managed by the shared EmbedClient; PME_EMBED_PROVIDER (default openai)
|
|
106
165
|
# selects auth scheme (Bearer vs X-API-Key) and request shape.
|
|
@@ -177,13 +236,25 @@ def get_http_client() -> httpx.AsyncClient:
|
|
|
177
236
|
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
|
178
237
|
"""Open the neo4j driver + HTTP client at process startup, close on
|
|
179
238
|
shutdown. Without this, the first request pays driver-open latency
|
|
180
|
-
and the driver is never properly closed on SIGTERM (leaking conns).
|
|
239
|
+
and the driver is never properly closed on SIGTERM (leaking conns).
|
|
240
|
+
|
|
241
|
+
Also schedules the vec_index backfill as a background task so the
|
|
242
|
+
proxy can start serving immediately while older chunks copy across
|
|
243
|
+
into the KNN index — first-time migration of ~450k rows takes
|
|
244
|
+
minutes and would otherwise block /health.
|
|
245
|
+
"""
|
|
181
246
|
global _neo4j_driver, _http_client
|
|
182
247
|
_neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=NEO4J_AUTH)
|
|
183
248
|
_http_client = httpx.AsyncClient(timeout=30.0)
|
|
249
|
+
backfill_task = asyncio.create_task(_backfill_vec_index())
|
|
184
250
|
try:
|
|
185
251
|
yield
|
|
186
252
|
finally:
|
|
253
|
+
backfill_task.cancel()
|
|
254
|
+
try:
|
|
255
|
+
await backfill_task
|
|
256
|
+
except (asyncio.CancelledError, Exception):
|
|
257
|
+
pass
|
|
187
258
|
if _neo4j_driver is not None:
|
|
188
259
|
await _neo4j_driver.close()
|
|
189
260
|
_neo4j_driver = None
|
|
@@ -192,6 +263,82 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
|
|
192
263
|
_http_client = None
|
|
193
264
|
|
|
194
265
|
|
|
266
|
+
async def _backfill_vec_index() -> None:
    """One-time migration: copy existing chunks.embedding (JSON) into
    vec_index (packed f32 bytes).

    Idempotent. Runs at proxy startup and does work only when vec_index has
    fewer rows than chunks. All SQLite work (open, counts, reads, inserts,
    close) runs on a single worker thread via ``asyncio.to_thread``, so
    nothing here blocks the event loop and /health and /search keep serving.
    Until the backfill finishes, search returns partial-corpus results: any
    chunk already mirrored into vec_index is findable, and the rest are
    invisible to vector search but still in L0/L3/L5/L6.

    Rows whose JSON can't be parsed, or whose length != EMBED_DIM, are
    skipped and never copied. A corpus containing such rows therefore keeps
    vec_index below chunks, and the keyset-paged scan re-runs on every
    restart instead of no-opping.

    Previously measured: ~16 min at 450k rows and ~460 rows/s on a cold
    prod instance.

    On cancellation (lifespan shutdown), the worker is asked to stop after
    its in-flight batch. It is awaited so the connection is committed and
    closed by the thread that owns it, and then CancelledError propagates.
    """
    if sqlite_vec is None:
        log.error("sqlite_vec module missing — backfill skipped, search will be degraded")
        return
    if not os.path.exists(QMD_DB_PATH):
        log.info("vec_index backfill skipped — qmd.sqlite does not exist yet")
        return

    import threading  # local: only the backfill needs a cross-thread stop flag

    stop = threading.Event()
    worker = asyncio.ensure_future(asyncio.to_thread(_backfill_vec_index_sync, stop))
    try:
        # shield(): cancelling this task must not abandon the worker thread
        # mid-transaction; it is stopped cooperatively below instead.
        await asyncio.shield(worker)
    except asyncio.CancelledError:
        log.info("vec_index backfill cancelled during shutdown")
        stop.set()
        await asyncio.wait({worker})
        # Retrieve any worker failure so it is logged, not reported as
        # "exception was never retrieved".
        if not worker.cancelled() and worker.exception() is not None:
            log.error(f"vec_index backfill failed: {worker.exception()}")
        raise
    except Exception as e:
        log.error(f"vec_index backfill failed: {e}")


def _backfill_vec_index_sync(stop: "threading.Event") -> None:
    """Blocking body of the vec_index backfill; must run on a worker thread.

    Opens, uses and closes its own connection on the calling thread.

    Missing rows are read in keyset pages ordered by ``chunks.id``, and
    each page is fully drained before anything is written. No read
    statement is left open across a write. A long-lived SELECT cursor would
    pin a WAL read snapshot: any write after a concurrent commit (e.g. from
    index_internal_batch) would fail with SQLITE_BUSY_SNAPSHOT, and WAL
    checkpoints would stall for the whole run.

    ``stop`` is checked before each page, so shutdown waits for at most one
    batch.
    """
    batch_size = 500
    conn = _open_qmd_conn()
    try:
        _ensure_vec_index(conn)
        chunks_n = conn.execute(
            "SELECT count(*) FROM chunks WHERE embedding IS NOT NULL"
        ).fetchone()[0]
        vec_n = conn.execute("SELECT count(*) FROM vec_index").fetchone()[0]
        if vec_n >= chunks_n:
            log.info(f"vec_index backfill skipped — already in sync ({vec_n}/{chunks_n})")
            return
        missing = chunks_n - vec_n
        log.info(f"vec_index backfill starting — {missing} rows to copy")

        copied = 0
        # Smallest SQLite INTEGER: the first page starts at the beginning.
        last_id = -(1 << 63)
        while not stop.is_set():
            page = conn.execute(
                """
                SELECT c.id, c.embedding
                FROM chunks c
                LEFT JOIN vec_index v ON v.id = c.id
                WHERE v.id IS NULL AND c.embedding IS NOT NULL AND c.id > ?
                ORDER BY c.id
                LIMIT ?
                """,
                (last_id, batch_size),
            ).fetchall()
            if not page:
                log.info(f"vec_index backfill done — {copied} rows copied")
                return
            # Advance past every row in this page, including skipped bad
            # rows, so each chunk is visited at most once per run.
            last_id = page[-1][0]
            copied += _insert_vec_rows(conn, page)
            log.info(f"vec_index backfill progress: {copied}/{missing}")
    finally:
        conn.close()


def _insert_vec_rows(conn: sqlite3.Connection, rows: list) -> int:
    """Insert ``(id, embedding_json)`` rows into vec_index in one transaction.

    A row is skipped when its JSON is unparseable, when it isn't a list of
    exactly EMBED_DIM entries, or when it can't be packed as f32 (non-numeric
    or out of range). Returns how many rows were inserted.
    """
    inserted = 0
    with conn:  # commit on success; roll back this page on a SQLite error
        for cid, emb_json in rows:
            try:
                vec = json.loads(emb_json)
                if not isinstance(vec, list) or len(vec) != EMBED_DIM:
                    continue
                blob = struct.pack(f"{EMBED_DIM}f", *vec)
            except (TypeError, ValueError, OverflowError, struct.error):
                continue
            conn.execute(
                "INSERT INTO vec_index(id, embedding) VALUES (?, ?)",
                (cid, blob),
            )
            inserted += 1
    return inserted
|
|
340
|
+
|
|
341
|
+
|
|
195
342
|
# ASGI application; `lifespan` owns startup/shutdown of shared clients.
app = FastAPI(
    title="Sequential HybridRAG Proxy",
    version="1.0.0",
    lifespan=lifespan,
)
|
|
196
343
|
|
|
197
344
|
# ---------------------------------------------------------------------------
|
|
@@ -613,7 +760,15 @@ def cross_encoder_rerank(query: str, results: List[Dict], top_k: int = 16) -> Li
|
|
|
613
760
|
return scored[:top_k] + remaining
|
|
614
761
|
|
|
615
762
|
def search_qmd_informed(query: str, graph_context: Dict, limit: int = 12) -> List[Dict]:
|
|
616
|
-
"""Phase 2: QMD vector search
|
|
763
|
+
"""Phase 2: QMD vector search via sqlite-vec MATCH.
|
|
764
|
+
|
|
765
|
+
Replaces the legacy Python cosine loop over JSON-serialised embeddings
|
|
766
|
+
(which also had an `ORDER BY id LIMIT 2000` bug — only the OLDEST
|
|
767
|
+
2000 rows were ever considered, so 99%+ of the corpus was invisible to
|
|
768
|
+
search at production scale). Now: native KNN over the vec0 index,
|
|
769
|
+
full-corpus top-k. Wall time at 450k rows: ~50ms native MATCH vs
|
|
770
|
+
~15s timeout previously.
|
|
771
|
+
"""
|
|
617
772
|
if not os.path.exists(QMD_DB_PATH):
|
|
618
773
|
return []
|
|
619
774
|
|
|
@@ -621,69 +776,64 @@ def search_qmd_informed(query: str, graph_context: Dict, limit: int = 12) -> Lis
|
|
|
621
776
|
if not query_embedding:
|
|
622
777
|
return []
|
|
623
778
|
|
|
624
|
-
# Enhance query with graph entities for better vector search
|
|
625
779
|
enhanced_query = query
|
|
626
780
|
if graph_context["graph_entities"]:
|
|
627
781
|
enhanced_query += " " + " ".join(graph_context["graph_entities"][:3])
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
if
|
|
631
|
-
|
|
782
|
+
enhanced_embedding = get_embedding(enhanced_query) or query_embedding
|
|
783
|
+
|
|
784
|
+
if len(enhanced_embedding) != EMBED_DIM:
|
|
785
|
+
# Dim mismatch vs vec0 DDL — the MATCH would error inside sqlite-vec.
|
|
786
|
+
# Bail with a loud log; an embedding-model mismatch in prod is the
|
|
787
|
+
# likely root cause and silent degradation would hide it.
|
|
788
|
+
log.error(
|
|
789
|
+
f"QMD search: query dim {len(enhanced_embedding)} != vec_index dim "
|
|
790
|
+
f"{EMBED_DIM} — embedding model mismatch?"
|
|
791
|
+
)
|
|
792
|
+
return []
|
|
793
|
+
qbytes = struct.pack(f"{len(enhanced_embedding)}f", *enhanced_embedding)
|
|
632
794
|
|
|
633
795
|
try:
|
|
634
|
-
conn =
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
rows = conn.execute(
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
796
|
+
conn = _open_qmd_conn()
|
|
797
|
+
# Pull a candidate pool larger than `limit` so entity-boost
|
|
798
|
+
# re-ranking has material to work with — 4× limit, floor 50.
|
|
799
|
+
k_pool = max(limit * 4, 50)
|
|
800
|
+
rows = conn.execute(
|
|
801
|
+
"""
|
|
802
|
+
SELECT c.id, c.path, c.text, v.distance
|
|
803
|
+
FROM vec_index v
|
|
804
|
+
JOIN chunks c ON c.id = v.id
|
|
805
|
+
WHERE v.embedding MATCH ? AND k = ?
|
|
806
|
+
ORDER BY v.distance
|
|
807
|
+
""",
|
|
808
|
+
(qbytes, k_pool),
|
|
809
|
+
).fetchall()
|
|
810
|
+
conn.close()
|
|
645
811
|
|
|
646
812
|
results = []
|
|
647
|
-
for
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
final_score = (similarity * VECTOR_BASE_WEIGHT) + entity_boost
|
|
673
|
-
|
|
674
|
-
if similarity > 0.2: # Threshold for inclusion
|
|
675
|
-
results.append({
|
|
676
|
-
"path": row["path"],
|
|
677
|
-
"text": row["text"][:600],
|
|
678
|
-
"score": final_score,
|
|
679
|
-
"source": "vector",
|
|
680
|
-
"base_similarity": similarity,
|
|
681
|
-
"entity_boost": entity_boost
|
|
682
|
-
})
|
|
683
|
-
except Exception as e:
|
|
684
|
-
logging.debug(f"Suppressed: {e}")
|
|
685
|
-
|
|
686
|
-
conn.close()
|
|
813
|
+
for row_id, path, text, distance in rows:
|
|
814
|
+
# vec0 distance_metric=cosine returns `1 - cos_sim` —
|
|
815
|
+
# invert to align with the rest of the codebase's `similarity`
|
|
816
|
+
# convention (1.0 = identical, 0.0 = orthogonal).
|
|
817
|
+
similarity = 1.0 - distance
|
|
818
|
+
if similarity <= 0.2:
|
|
819
|
+
continue
|
|
820
|
+
entity_boost = 0
|
|
821
|
+
path_lower = (path or "").lower()
|
|
822
|
+
text_lower = (text or "").lower()
|
|
823
|
+
for entity in graph_context["graph_entities"]:
|
|
824
|
+
el = entity.lower()
|
|
825
|
+
if el in path_lower or el in text_lower:
|
|
826
|
+
entity_boost = GRAPH_PRIORITY_BOOST
|
|
827
|
+
break
|
|
828
|
+
final_score = (similarity * VECTOR_BASE_WEIGHT) + entity_boost
|
|
829
|
+
results.append({
|
|
830
|
+
"path": path,
|
|
831
|
+
"text": (text or "")[:600],
|
|
832
|
+
"score": final_score,
|
|
833
|
+
"source": "vector",
|
|
834
|
+
"base_similarity": similarity,
|
|
835
|
+
"entity_boost": entity_boost,
|
|
836
|
+
})
|
|
687
837
|
results.sort(key=lambda x: x["score"], reverse=True)
|
|
688
838
|
return results[:limit]
|
|
689
839
|
|
|
@@ -1598,7 +1748,11 @@ async def index_internal_batch(req: IndexInternalBatchRequest) -> dict:
|
|
|
1598
1748
|
log.warning(f"L4 embed count mismatch: {len(embeddings)} != {len(norm)}")
|
|
1599
1749
|
qmd_db = Path(QMD_DB_PATH)
|
|
1600
1750
|
qmd_db.parent.mkdir(parents=True, exist_ok=True)
|
|
1601
|
-
|
|
1751
|
+
# Open with sqlite-vec loaded so we can dual-write to vec_index
|
|
1752
|
+
# below. If extension load fails, vec_index inserts silently no-op
|
|
1753
|
+
# via the try/except — chunks (JSON) still gets the write so the
|
|
1754
|
+
# corpus stays whole. Those rows are invisible to vector search (which now
# queries only vec_index) until the startup backfill mirrors them in.
|
|
1755
|
+
conn = _open_qmd_conn()
|
|
1602
1756
|
conn.execute("PRAGMA journal_mode=WAL")
|
|
1603
1757
|
conn.execute("""
|
|
1604
1758
|
CREATE TABLE IF NOT EXISTS chunks (
|
|
@@ -1612,14 +1766,33 @@ async def index_internal_batch(req: IndexInternalBatchRequest) -> dict:
|
|
|
1612
1766
|
created_at TEXT
|
|
1613
1767
|
)
|
|
1614
1768
|
""")
|
|
1769
|
+
try:
|
|
1770
|
+
_ensure_vec_index(conn)
|
|
1771
|
+
except Exception as e:
|
|
1772
|
+
log.error(f"vec_index DDL failed: {e} — falling back to chunks-only write")
|
|
1615
1773
|
for n, vec in zip(norm, embeddings):
|
|
1616
1774
|
if not vec:
|
|
1617
1775
|
continue
|
|
1618
|
-
conn.execute(
|
|
1776
|
+
cur = conn.execute(
|
|
1619
1777
|
"INSERT INTO chunks (path, text, embedding, embedding_model, embedding_dim, chunk_index, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
|
1620
1778
|
(f"bench/{arena}/{n['path']}.md", n["content"],
|
|
1621
1779
|
json.dumps(vec), "nv-embed-v2", len(vec), 0, now_iso),
|
|
1622
1780
|
)
|
|
1781
|
+
# Mirror into the vec0 KNN index so search_qmd_informed can
|
|
1782
|
+
# MATCH on the f32-packed vector. Dim must match the vec0 DDL
|
|
1783
|
+
# (EMBED_DIM); skip rows where the embedding shape disagrees
|
|
1784
|
+
# so a single bad row doesn't poison the batch insert.
|
|
1785
|
+
if cur.lastrowid is not None and len(vec) == EMBED_DIM:
|
|
1786
|
+
try:
|
|
1787
|
+
conn.execute(
|
|
1788
|
+
"INSERT INTO vec_index(id, embedding) VALUES (?, ?)",
|
|
1789
|
+
(cur.lastrowid, struct.pack(f"{len(vec)}f", *vec)),
|
|
1790
|
+
)
|
|
1791
|
+
except Exception as e:
|
|
1792
|
+
# vec_index dual-write is defensive — the JSON column
|
|
1793
|
+
# in chunks is still the source of truth until the
|
|
1794
|
+
# backfill task confirms vec_index is in sync.
|
|
1795
|
+
log.debug(f"vec_index insert skipped for row {cur.lastrowid}: {e}")
|
|
1623
1796
|
l4_inserted += 1
|
|
1624
1797
|
conn.commit()
|
|
1625
1798
|
conn.close()
|
|
@@ -268,6 +268,9 @@ def test_autodetect_all_fail_raises(recorder):
|
|
|
268
268
|
# ----------------------------------------------------------------------
|
|
269
269
|
|
|
270
270
|
def test_non_401_http_error_does_not_trigger_autodetect(recorder):
|
|
271
|
+
# max_retries=0 isolates this test to autodetect behaviour. With
|
|
272
|
+
# retries enabled (default), 503 triggers the retry path which is
|
|
273
|
+
# exercised separately in the retry tests below.
|
|
271
274
|
recorder.respond(
|
|
272
275
|
"https://gw/v1/embeddings",
|
|
273
276
|
_FakeResponse(503, "upstream down"),
|
|
@@ -277,6 +280,7 @@ def test_non_401_http_error_does_not_trigger_autodetect(recorder):
|
|
|
277
280
|
api_key="k",
|
|
278
281
|
model="m",
|
|
279
282
|
provider=PROVIDERS["openai"],
|
|
283
|
+
max_retries=0,
|
|
280
284
|
)
|
|
281
285
|
with pytest.raises(EmbedHTTPError) as exc:
|
|
282
286
|
client.embed_batch(["x"])
|
|
@@ -490,3 +494,200 @@ def test_from_env_default_max_batch_is_five(monkeypatch):
|
|
|
490
494
|
client.embed_batch([f"t{i}" for i in range(10)])
|
|
491
495
|
# 10 with default chunk=5 → [5, 5] → 2 calls
|
|
492
496
|
assert len(stub.calls) == 2
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
# ----------------------------------------------------------------------
|
|
500
|
+
# Retry-with-jitter on transient gateway saturation (502/503/504/429)
|
|
501
|
+
# ----------------------------------------------------------------------
|
|
502
|
+
#
|
|
503
|
+
# These tests exercise the retry path added 2026-05-15. Motivation:
|
|
504
|
+
# the Pentatonic AI Gateway has a K≈10 concurrency cap and 502s under
|
|
505
|
+
# saturation; without retry, a single 502 cascades through the engine's
|
|
506
|
+
# per-layer fallback path and amplifies load instead of damping it.
|
|
507
|
+
# See the prod incident note on EmbedClient.__init__ for context.
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
class _SequencedRecorder:
    """Stand-in for ``httpx.post`` that replays a per-URL script of responses.

    The default ``_Recorder`` answers every call identically, so it can't
    express "first call 502, then succeed on retry". This recorder keeps one
    queue per URL:

    - Each call consumes the head of that URL's queue.
    - The final queued response is never consumed; it is replayed for every
      later call. A "persistent failure" test therefore needs to queue only
      one response.
    - A URL with nothing queued gets a 401.
    """

    def __init__(self):
        self.queues: dict[str, list[_FakeResponse]] = {}
        self.calls: list[dict] = []

    def queue(self, url: str, responses: list[_FakeResponse]) -> None:
        self.queues[url] = [*responses]

    def __call__(self, url, *, json, headers, timeout):
        self.calls.append({"url": url, "json": json})
        pending = self.queues.get(url, [])
        if not pending:
            return _FakeResponse(401, "no responses queued")
        if len(pending) == 1:
            # Tail is sticky so "all attempts fail" tests don't need N copies.
            return pending[0]
        return pending.pop(0)
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
@pytest.fixture
def sequenced(monkeypatch):
    """Route ``httpx.post`` through a fresh ``_SequencedRecorder``.

    Also turns ``time.sleep`` into a no-op so retry backoff costs no wall
    time. The jitter computation still runs; only the delay is skipped.
    """
    import time as _time

    recorder = _SequencedRecorder()
    monkeypatch.setattr(httpx, "post", recorder)
    monkeypatch.setattr(_time, "sleep", lambda _seconds: None)
    return recorder
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
def test_retries_on_502_and_succeeds(sequenced):
    url = "https://gw/v1/embeddings"
    script = [
        _FakeResponse(502, "bad gateway"),
        _FakeResponse(200, {"data": [{"embedding": [0.1, 0.2]}]}),
    ]
    sequenced.queue(url, script)

    client = EmbedClient(
        url=url,
        api_key="k",
        model="m",
        provider=PROVIDERS["openai"],
        max_retries=3,
    )
    assert client.embed_batch(["hello"]) == [[0.1, 0.2]]
    # One 502, then a 200 on the retry: two HTTP attempts in total.
    assert len(sequenced.calls) == 2
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
def test_retries_on_503_504_429(sequenced):
    """503, 504 and 429 are each retried exactly like a 502."""
    url = "https://gw/v1/embeddings"
    for status in (503, 504, 429):
        sequenced.calls.clear()
        sequenced.queue(
            url,
            [
                _FakeResponse(status, "transient"),
                _FakeResponse(200, {"data": [{"embedding": [0.0]}]}),
            ],
        )
        client = EmbedClient(
            url=url,
            api_key="k",
            model="m",
            provider=PROVIDERS["openai"],
            max_retries=3,
        )
        result = client.embed_batch(["x"])
        assert result == [[0.0]], f"retry failed for status {status}"
        assert len(sequenced.calls) == 2, f"wrong call count for status {status}"
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def test_does_not_retry_on_500(sequenced):
    """A 500 signals a server-side bug, not transient saturation: fail fast."""
    url = "https://gw/v1/embeddings"
    sequenced.queue(url, [_FakeResponse(500, "internal server error")])
    client = EmbedClient(
        url=url,
        api_key="k",
        model="m",
        provider=PROVIDERS["openai"],
        max_retries=3,
    )

    with pytest.raises(EmbedHTTPError) as exc:
        client.embed_batch(["x"])

    assert exc.value.status == 500
    # A single attempt: 500 never enters the retry path.
    assert len(sequenced.calls) == 1
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
def test_does_not_retry_on_400(sequenced):
    """A 4xx other than 401 (autodetect) or 429 is a caller error: no retry."""
    url = "https://gw/v1/embeddings"
    sequenced.queue(url, [_FakeResponse(400, "bad request")])
    client = EmbedClient(
        url=url,
        api_key="k",
        model="m",
        provider=PROVIDERS["openai"],
        max_retries=3,
    )

    with pytest.raises(EmbedHTTPError) as exc:
        client.embed_batch(["x"])

    assert exc.value.status == 400
    assert len(sequenced.calls) == 1
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
def test_max_retries_exhausted_raises(sequenced):
    """A 502 that never clears raises after max_retries + 1 attempts."""
    url = "https://gw/v1/embeddings"
    retries = 3
    sequenced.queue(url, [_FakeResponse(502, "still down")])
    client = EmbedClient(
        url=url,
        api_key="k",
        model="m",
        provider=PROVIDERS["openai"],
        max_retries=retries,
    )

    with pytest.raises(EmbedHTTPError) as exc:
        client.embed_batch(["x"])

    assert exc.value.status == 502
    # The original attempt plus every retry: 1 + 3 = 4 calls.
    assert len(sequenced.calls) == retries + 1
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
def test_max_retries_zero_disables_retry(sequenced):
    """max_retries=0 restores the pre-retry behaviour, for callers that
    implement their own retry loop."""
    url = "https://gw/v1/embeddings"
    sequenced.queue(url, [_FakeResponse(502, "down")])
    client = EmbedClient(
        url=url,
        api_key="k",
        model="m",
        provider=PROVIDERS["openai"],
        max_retries=0,
    )

    with pytest.raises(EmbedHTTPError):
        client.embed_batch(["x"])

    assert len(sequenced.calls) == 1
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
def test_from_env_reads_retry_config(monkeypatch):
    """{prefix}EMBED_MAX_RETRIES, EMBED_RETRY_BASE_DELAY and
    EMBED_RETRY_MAX_DELAY each override their default."""
    env = {
        "L4_NV_EMBED_URL": "https://gw/v1/embeddings",
        "L4_EMBED_API_KEY": "k",
        "L4_EMBED_MAX_RETRIES": "5",
        "L4_EMBED_RETRY_BASE_DELAY": "0.25",
        "L4_EMBED_RETRY_MAX_DELAY": "2.5",
    }
    for name, value in env.items():
        monkeypatch.setenv(name, value)

    client = EmbedClient.from_env(prefix="L4_")

    assert client._max_retries == 5
    assert client._retry_base_delay == 0.25
    assert client._retry_max_delay == 2.5
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
def test_from_env_default_retry_config(monkeypatch):
    """Defaults: 3 retries, 100ms base, 1s cap — tuned for a K≈10
    gateway under burst load.

    The retry overrides are cleared first. A value exported in the
    developer's or CI shell would otherwise mask the defaults under test
    and make this test depend on the host environment.
    """
    for name in (
        "L4_EMBED_MAX_RETRIES",
        "L4_EMBED_RETRY_BASE_DELAY",
        "L4_EMBED_RETRY_MAX_DELAY",
    ):
        monkeypatch.delenv(name, raising=False)
    monkeypatch.setenv("L4_NV_EMBED_URL", "https://gw/v1/embeddings")
    monkeypatch.setenv("L4_EMBED_API_KEY", "k")

    client = EmbedClient.from_env(prefix="L4_")

    assert client._max_retries == 3
    assert client._retry_base_delay == 0.1
    assert client._retry_max_delay == 1.0
|