@pentatonic-ai/ai-agent-sdk 0.9.4 → 0.9.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,6 +18,7 @@ import json
18
18
  import logging
19
19
  import os
20
20
  import sqlite3
21
+ import struct
21
22
  import sys
22
23
  import time
23
24
  from contextlib import asynccontextmanager
@@ -34,6 +35,11 @@ from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate
34
35
  from pydantic import BaseModel
35
36
  import uvicorn
36
37
 
38
+ try:
39
+ import sqlite_vec # 0.1.9 — native KNN MATCH over packed-f32 vec0 tables
40
+ except ImportError:
41
+ sqlite_vec = None # Caller logs loudly if helpers can't load the extension
42
+
37
43
  # Shared embed client lives at engine/services/_shared/.
38
44
  sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
39
45
  from _shared.embed_provider import EmbedClient # noqa: E402
@@ -101,6 +107,59 @@ QMD_DB_PATH = _resolve_qmd_db()
101
107
  OLLAMA_URL = os.environ.get("PME_OLLAMA_URL", "http://localhost:11434/api/embeddings")
102
108
  EMBEDDING_MODEL = os.environ.get("PME_EMBED_MODEL", "nomic-embed-text")
103
109
 
110
+ # Embedding dimension for the vec0 virtual table. Production gateway
111
+ # (lambda-gateway.pentatonic.com/v1/embed via pentatonic-gateway provider)
112
+ # returns NV-Embed-v2 4096-dim vectors. The vec0 schema requires the dim
113
+ # at DDL time and writers must match — keep this in lockstep with the
114
+ # gateway / EmbedClient config.
115
+ EMBED_DIM = int(os.environ.get("PME_EMBED_DIM", "4096"))
116
+
117
+
118
def _open_qmd_conn() -> sqlite3.Connection:
    """Open qmd.sqlite and try to load the sqlite-vec extension into it.

    Falls back to a plain sqlite3 connection if the extension can't load —
    MATCH-form queries will then fail loudly at execute time, which is the
    right signal (loud error > silent degradation back to Python cosine).
    Callers that only need scalar columns (chunks.path, chunks.text) work
    fine without the extension.

    ``check_same_thread=False`` is intentional: the async backfill yields
    via ``asyncio.to_thread`` to keep /search responsive, which means the
    connection is handed off between event-loop / thread-pool workers.
    sqlite's default thread-safety check would otherwise reject the
    cross-thread reuse even though only one worker touches it at a time.

    Returns:
        An open connection to ``QMD_DB_PATH``. A connection is returned
        even when the extension could not be loaded; that failure is
        logged at error level rather than raised.
    """
    conn = sqlite3.connect(QMD_DB_PATH, timeout=10, check_same_thread=False)
    if sqlite_vec is None:
        log.error("sqlite_vec module not importable — qmd vec_index unavailable")
        return conn
    try:
        conn.enable_load_extension(True)
        try:
            sqlite_vec.load(conn)
        finally:
            # Re-lock extension loading even when load() raises, so a
            # failed load never hands back a connection that can still
            # load arbitrary shared libraries.
            conn.enable_load_extension(False)
    except Exception as e:
        # Deliberate best-effort boundary: any failure here (missing
        # extension support in this Python build, bad shared object, ...)
        # degrades search instead of crashing the caller.
        log.error(f"sqlite-vec load failed: {e} — qmd search will be degraded")
    return conn
144
+
145
+
146
def _ensure_vec_index(conn: sqlite3.Connection) -> None:
    """Create the vec0 KNN index if not already present. Idempotent.

    `distance_metric=cosine` is non-default — sqlite-vec defaults to L2
    (Euclidean). Probe confirmed cosine returns `1 - cos_sim` as the
    distance. The id column is a regular INTEGER PRIMARY KEY so we can
    JOIN back to `chunks` on the row's autoinc id.

    Args:
        conn: Connection on which the DDL is executed. The statement uses
            the ``vec0`` module, so it presumably fails when sqlite-vec was
            not loaded on this connection — callers should be prepared for
            an exception from ``conn.execute``.
    """
    # The vector width is baked into the table definition from the
    # module-level EMBED_DIM; IF NOT EXISTS makes repeat calls a no-op.
    conn.execute(
        f"""
        CREATE VIRTUAL TABLE IF NOT EXISTS vec_index USING vec0(
            id INTEGER PRIMARY KEY,
            embedding float[{EMBED_DIM}] distance_metric=cosine
        )
        """
    )
162
+
104
163
  # NV-Embed-v2 service (primary, 4096-dim). URL/auth/path/body/response are
105
164
  # managed by the shared EmbedClient; PME_EMBED_PROVIDER (default openai)
106
165
  # selects auth scheme (Bearer vs X-API-Key) and request shape.
@@ -177,13 +236,25 @@ def get_http_client() -> httpx.AsyncClient:
177
236
  async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
178
237
  """Open the neo4j driver + HTTP client at process startup, close on
179
238
  shutdown. Without this, the first request pays driver-open latency
180
- and the driver is never properly closed on SIGTERM (leaking conns)."""
239
+ and the driver is never properly closed on SIGTERM (leaking conns).
240
+
241
+ Also schedules the vec_index backfill as a background task so the
242
+ proxy can start serving immediately while older chunks copy across
243
+ into the KNN index — first-time migration of ~450k rows takes
244
+ minutes and would otherwise block /health.
245
+ """
181
246
  global _neo4j_driver, _http_client
182
247
  _neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=NEO4J_AUTH)
183
248
  _http_client = httpx.AsyncClient(timeout=30.0)
249
+ backfill_task = asyncio.create_task(_backfill_vec_index())
184
250
  try:
185
251
  yield
186
252
  finally:
253
+ backfill_task.cancel()
254
+ try:
255
+ await backfill_task
256
+ except (asyncio.CancelledError, Exception):
257
+ pass
187
258
  if _neo4j_driver is not None:
188
259
  await _neo4j_driver.close()
189
260
  _neo4j_driver = None
@@ -192,6 +263,82 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
192
263
  _http_client = None
193
264
 
194
265
 
266
async def _backfill_vec_index() -> None:
    """One-time migration: copy existing chunks.embedding (JSON) into
    vec_index (f32 bytes).

    Idempotent. Runs at proxy startup if vec_index has fewer rows than
    chunks. Async so it doesn't block /health — the proxy serves
    requests in parallel and search degrades to partial-corpus results
    until the backfill finishes (any chunk already mirrored into
    vec_index is findable; the rest are invisible to search but still
    in L0/L3/L5/L6).

    Every SQLite call (row counts, the scan query, fetches and inserts)
    runs in a worker thread via ``asyncio.to_thread`` so none of them can
    stall the event loop.

    Rows whose embedding is not valid JSON, is not a sequence of numbers,
    or whose length differs from ``EMBED_DIM`` are skipped; the number of
    skipped rows is logged at the end. Skipped rows keep vec_index smaller
    than chunks, so later restarts will re-run the scan for them.

    At 450k rows + ~460 rows/s insert rate this takes ~16 min on a
    cold prod instance.

    The connection is closed on completion and on failure. On
    cancellation it is intentionally left open: a worker thread may still
    be executing a statement on it and must not have it closed underneath.

    Raises:
        asyncio.CancelledError: re-raised after logging when the task is
            cancelled (e.g. during shutdown). All other exceptions are
            logged and swallowed so a failed backfill never takes the
            proxy down.
    """
    if sqlite_vec is None:
        log.error("sqlite_vec module missing — backfill skipped, search will be degraded")
        return
    if not os.path.exists(QMD_DB_PATH):
        log.info("vec_index backfill skipped — qmd.sqlite does not exist yet")
        return

    BATCH = 500

    def _count_rows(db):
        # Returns (chunks with an embedding, rows already in vec_index).
        chunks_total = db.execute(
            "SELECT count(*) FROM chunks WHERE embedding IS NOT NULL"
        ).fetchone()[0]
        vec_total = db.execute("SELECT count(*) FROM vec_index").fetchone()[0]
        return chunks_total, vec_total

    def _open_scan(db):
        # Cursor over every embedded chunk that has no vec_index row yet.
        return db.execute(
            """
            SELECT c.id, c.embedding
            FROM chunks c
            LEFT JOIN vec_index v ON v.id = c.id
            WHERE v.id IS NULL AND c.embedding IS NOT NULL
            """
        )

    def _insert_batch(db, rows):
        # Packs and inserts one batch in a single transaction; returns
        # (inserted, skipped).
        pack_vec = struct.Struct(f"{EMBED_DIM}f").pack
        inserted = 0
        skipped = 0
        with db:
            for cid, emb_json in rows:
                try:
                    vec = json.loads(emb_json)
                    if len(vec) != EMBED_DIM:
                        raise ValueError("embedding dimension mismatch")
                    blob = pack_vec(*vec)
                except (TypeError, ValueError, struct.error):
                    # Malformed JSON, non-sequence payload, wrong width or
                    # non-numeric element: skip this row, keep the batch.
                    skipped += 1
                    continue
                db.execute(
                    "INSERT INTO vec_index(id, embedding) VALUES (?, ?)",
                    (cid, blob),
                )
                inserted += 1
        return inserted, skipped

    conn = None
    close_conn = True
    try:
        conn = await asyncio.to_thread(_open_qmd_conn)
        await asyncio.to_thread(_ensure_vec_index, conn)
        chunks_n, vec_n = await asyncio.to_thread(_count_rows, conn)
        if vec_n >= chunks_n:
            log.info(f"vec_index backfill skipped — already in sync ({vec_n}/{chunks_n})")
            return
        missing = chunks_n - vec_n
        log.info(f"vec_index backfill starting — {missing} rows to copy")
        cursor = await asyncio.to_thread(_open_scan, conn)
        copied = 0
        skipped = 0
        while True:
            batch = await asyncio.to_thread(cursor.fetchmany, BATCH)
            if not batch:
                break
            done, bad = await asyncio.to_thread(_insert_batch, conn, batch)
            copied += done
            skipped += bad
            log.info(f"vec_index backfill progress: {copied}/{missing}")
            # Explicit yield point between batches so /search + writers
            # get a turn on the event loop.
            await asyncio.sleep(0)
        log.info(f"vec_index backfill done — {copied} rows copied")
        if skipped:
            log.warning(
                f"vec_index backfill skipped {skipped} rows with malformed or "
                f"wrong-dimension embeddings (expected dim {EMBED_DIM})"
            )
    except asyncio.CancelledError:
        # A worker thread may still be running on conn; leave it open and
        # let process teardown reclaim it.
        close_conn = False
        log.info("vec_index backfill cancelled during shutdown")
        raise
    except Exception as e:
        log.error(f"vec_index backfill failed: {e}")
    finally:
        if conn is not None and close_conn:
            try:
                conn.close()
            except sqlite3.Error as e:
                log.debug(f"vec_index backfill: closing qmd connection failed: {e}")
340
+
341
+
195
342
  app = FastAPI(title="Sequential HybridRAG Proxy", version="1.0.0", lifespan=lifespan)
196
343
 
197
344
  # ---------------------------------------------------------------------------
@@ -613,7 +760,15 @@ def cross_encoder_rerank(query: str, results: List[Dict], top_k: int = 16) -> Li
613
760
  return scored[:top_k] + remaining
614
761
 
615
762
  def search_qmd_informed(query: str, graph_context: Dict, limit: int = 12) -> List[Dict]:
616
- """Phase 2: QMD vector search informed by graph results."""
763
+ """Phase 2: QMD vector search via sqlite-vec MATCH.
764
+
765
+ Replaces the legacy Python cosine loop over JSON-serialised embeddings
766
+ (which also had an `ORDER BY id LIMIT 2000` bug — only the OLDEST
767
+ 2000 rows were ever considered, so 99%+ of the corpus was invisible to
768
+ search at production scale). Now: native KNN over the vec0 index,
769
+ full-corpus top-k. Wall time at 450k rows: ~50ms native MATCH vs
770
+ ~15s timeout previously.
771
+ """
617
772
  if not os.path.exists(QMD_DB_PATH):
618
773
  return []
619
774
 
@@ -621,69 +776,64 @@ def search_qmd_informed(query: str, graph_context: Dict, limit: int = 12) -> Lis
621
776
  if not query_embedding:
622
777
  return []
623
778
 
624
- # Enhance query with graph entities for better vector search
625
779
  enhanced_query = query
626
780
  if graph_context["graph_entities"]:
627
781
  enhanced_query += " " + " ".join(graph_context["graph_entities"][:3])
628
-
629
- enhanced_embedding = get_embedding(enhanced_query)
630
- if not enhanced_embedding:
631
- enhanced_embedding = query_embedding
782
+ enhanced_embedding = get_embedding(enhanced_query) or query_embedding
783
+
784
+ if len(enhanced_embedding) != EMBED_DIM:
785
+ # Dim mismatch vs vec0 DDL — the MATCH would error inside sqlite-vec.
786
+ # Bail with a loud log; an embedding-model mismatch in prod is the
787
+ # likely root cause and silent degradation would hide it.
788
+ log.error(
789
+ f"QMD search: query dim {len(enhanced_embedding)} != vec_index dim "
790
+ f"{EMBED_DIM} — embedding model mismatch?"
791
+ )
792
+ return []
793
+ qbytes = struct.pack(f"{len(enhanced_embedding)}f", *enhanced_embedding)
632
794
 
633
795
  try:
634
- conn = sqlite3.connect(QMD_DB_PATH, timeout=5)
635
- conn.row_factory = sqlite3.Row
636
-
637
- # Get vectors and compute similarity
638
- rows = conn.execute("""
639
- SELECT id, path, text, embedding
640
- FROM chunks
641
- WHERE embedding IS NOT NULL
642
- ORDER BY id
643
- LIMIT 2000
644
- """).fetchall()
796
+ conn = _open_qmd_conn()
797
+ # Pull a candidate pool larger than `limit` so entity-boost
798
+ # re-ranking has material to work with — 4× limit, floor 50.
799
+ k_pool = max(limit * 4, 50)
800
+ rows = conn.execute(
801
+ """
802
+ SELECT c.id, c.path, c.text, v.distance
803
+ FROM vec_index v
804
+ JOIN chunks c ON c.id = v.id
805
+ WHERE v.embedding MATCH ? AND k = ?
806
+ ORDER BY v.distance
807
+ """,
808
+ (qbytes, k_pool),
809
+ ).fetchall()
810
+ conn.close()
645
811
 
646
812
  results = []
647
- for row in rows:
648
- try:
649
- # Deserialize embedding
650
- embedding_data = row["embedding"]
651
- if isinstance(embedding_data, str):
652
- embedding = json.loads(embedding_data)
653
- else:
654
- embedding = list(embedding_data)
655
-
656
- # Cosine similarity with enhanced query
657
- dot = sum(a * b for a, b in zip(enhanced_embedding, embedding))
658
- norm_q = sum(x * x for x in enhanced_embedding) ** 0.5
659
- norm_e = sum(x * x for x in embedding) ** 0.5
660
-
661
- if norm_q > 0 and norm_e > 0:
662
- similarity = dot / (norm_q * norm_e)
663
-
664
- # Boost score if path contains graph entities
665
- entity_boost = 0
666
- path_lower = row["path"].lower()
667
- for entity in graph_context["graph_entities"]:
668
- if entity.lower() in path_lower or entity.lower() in row["text"].lower():
669
- entity_boost = GRAPH_PRIORITY_BOOST
670
- break
671
-
672
- final_score = (similarity * VECTOR_BASE_WEIGHT) + entity_boost
673
-
674
- if similarity > 0.2: # Threshold for inclusion
675
- results.append({
676
- "path": row["path"],
677
- "text": row["text"][:600],
678
- "score": final_score,
679
- "source": "vector",
680
- "base_similarity": similarity,
681
- "entity_boost": entity_boost
682
- })
683
- except Exception as e:
684
- logging.debug(f"Suppressed: {e}")
685
-
686
- conn.close()
813
+ for row_id, path, text, distance in rows:
814
+ # vec0 distance_metric=cosine returns `1 - cos_sim` —
815
+ # invert to align with the rest of the codebase's `similarity`
816
+ # convention (1.0 = identical, 0.0 = orthogonal).
817
+ similarity = 1.0 - distance
818
+ if similarity <= 0.2:
819
+ continue
820
+ entity_boost = 0
821
+ path_lower = (path or "").lower()
822
+ text_lower = (text or "").lower()
823
+ for entity in graph_context["graph_entities"]:
824
+ el = entity.lower()
825
+ if el in path_lower or el in text_lower:
826
+ entity_boost = GRAPH_PRIORITY_BOOST
827
+ break
828
+ final_score = (similarity * VECTOR_BASE_WEIGHT) + entity_boost
829
+ results.append({
830
+ "path": path,
831
+ "text": (text or "")[:600],
832
+ "score": final_score,
833
+ "source": "vector",
834
+ "base_similarity": similarity,
835
+ "entity_boost": entity_boost,
836
+ })
687
837
  results.sort(key=lambda x: x["score"], reverse=True)
688
838
  return results[:limit]
689
839
 
@@ -1598,7 +1748,11 @@ async def index_internal_batch(req: IndexInternalBatchRequest) -> dict:
1598
1748
  log.warning(f"L4 embed count mismatch: {len(embeddings)} != {len(norm)}")
1599
1749
  qmd_db = Path(QMD_DB_PATH)
1600
1750
  qmd_db.parent.mkdir(parents=True, exist_ok=True)
1601
- conn = sqlite3.connect(str(qmd_db), timeout=10)
1751
+ # Open with sqlite-vec loaded so we can dual-write to vec_index
1752
+ # below. If extension load fails, vec_index inserts silently no-op
1753
+ # via the try/except — chunks (JSON) still gets the write so the
1754
+ # corpus stays whole; search just degrades to the old path.
1755
+ conn = _open_qmd_conn()
1602
1756
  conn.execute("PRAGMA journal_mode=WAL")
1603
1757
  conn.execute("""
1604
1758
  CREATE TABLE IF NOT EXISTS chunks (
@@ -1612,14 +1766,33 @@ async def index_internal_batch(req: IndexInternalBatchRequest) -> dict:
1612
1766
  created_at TEXT
1613
1767
  )
1614
1768
  """)
1769
+ try:
1770
+ _ensure_vec_index(conn)
1771
+ except Exception as e:
1772
+ log.error(f"vec_index DDL failed: {e} — falling back to chunks-only write")
1615
1773
  for n, vec in zip(norm, embeddings):
1616
1774
  if not vec:
1617
1775
  continue
1618
- conn.execute(
1776
+ cur = conn.execute(
1619
1777
  "INSERT INTO chunks (path, text, embedding, embedding_model, embedding_dim, chunk_index, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
1620
1778
  (f"bench/{arena}/{n['path']}.md", n["content"],
1621
1779
  json.dumps(vec), "nv-embed-v2", len(vec), 0, now_iso),
1622
1780
  )
1781
+ # Mirror into the vec0 KNN index so search_qmd_informed can
1782
+ # MATCH on the f32-packed vector. Dim must match the vec0 DDL
1783
+ # (EMBED_DIM); skip rows where the embedding shape disagrees
1784
+ # so a single bad row doesn't poison the batch insert.
1785
+ if cur.lastrowid is not None and len(vec) == EMBED_DIM:
1786
+ try:
1787
+ conn.execute(
1788
+ "INSERT INTO vec_index(id, embedding) VALUES (?, ?)",
1789
+ (cur.lastrowid, struct.pack(f"{len(vec)}f", *vec)),
1790
+ )
1791
+ except Exception as e:
1792
+ # vec_index dual-write is defensive — the JSON column
1793
+ # in chunks is still the source of truth until the
1794
+ # backfill task confirms vec_index is in sync.
1795
+ log.debug(f"vec_index insert skipped for row {cur.lastrowid}: {e}")
1623
1796
  l4_inserted += 1
1624
1797
  conn.commit()
1625
1798
  conn.close()
@@ -268,6 +268,9 @@ def test_autodetect_all_fail_raises(recorder):
268
268
  # ----------------------------------------------------------------------
269
269
 
270
270
  def test_non_401_http_error_does_not_trigger_autodetect(recorder):
271
+ # max_retries=0 isolates this test to autodetect behaviour. With
272
+ # retries enabled (default), 503 triggers the retry path which is
273
+ # exercised separately in the retry tests below.
271
274
  recorder.respond(
272
275
  "https://gw/v1/embeddings",
273
276
  _FakeResponse(503, "upstream down"),
@@ -277,6 +280,7 @@ def test_non_401_http_error_does_not_trigger_autodetect(recorder):
277
280
  api_key="k",
278
281
  model="m",
279
282
  provider=PROVIDERS["openai"],
283
+ max_retries=0,
280
284
  )
281
285
  with pytest.raises(EmbedHTTPError) as exc:
282
286
  client.embed_batch(["x"])
@@ -490,3 +494,200 @@ def test_from_env_default_max_batch_is_five(monkeypatch):
490
494
  client.embed_batch([f"t{i}" for i in range(10)])
491
495
  # 10 with default chunk=5 → [5, 5] → 2 calls
492
496
  assert len(stub.calls) == 2
497
+
498
+
499
+ # ----------------------------------------------------------------------
500
+ # Retry-with-jitter on transient gateway saturation (502/503/504/429)
501
+ # ----------------------------------------------------------------------
502
+ #
503
+ # These tests exercise the retry path added 2026-05-15. Motivation:
504
+ # the Pentatonic AI Gateway has a K≈10 concurrency cap and 502s under
505
+ # saturation; without retry, a single 502 cascades through the engine's
506
+ # per-layer fallback path and amplifies load instead of damping it.
507
+ # See the prod incident note on EmbedClient.__init__ for context.
508
+
509
+
510
class _SequencedRecorder:
    """Stand-in for ``httpx.post`` that replays a scripted list of
    responses per URL.

    Retry tests need "first call fails, next call succeeds", which a
    recorder returning one fixed response cannot express. Each call hands
    back the next queued response for the URL; the final queued response
    is never removed, so it keeps being returned for every further call
    (handy for "every attempt fails" cases). A URL with nothing queued
    gets a 401 response.
    """

    def __init__(self):
        # One entry per call, in call order.
        self.calls: list[dict] = []
        # Remaining scripted responses, keyed by URL.
        self.queues: dict[str, list[_FakeResponse]] = {}

    def queue(self, url: str, responses: list[_FakeResponse]) -> None:
        """Script the responses for ``url`` (stored as a private copy)."""
        self.queues[url] = [*responses]

    def __call__(self, url, *, json, headers, timeout):
        """Record the call, then return the next scripted response."""
        self.calls.append({"url": url, "json": json})
        pending = self.queues.get(url)
        if not pending:
            return _FakeResponse(401, "no responses queued")
        if len(pending) == 1:
            # Last scripted response: keep it so it repeats indefinitely.
            return pending[0]
        return pending.pop(0)
535
+
536
+
537
@pytest.fixture
def sequenced(monkeypatch):
    """Install a ``_SequencedRecorder`` in place of ``httpx.post`` and
    return it.

    ``time.sleep`` is replaced with a no-op so backoff delays cost no
    wall-clock time; the delay computation itself still executes.
    """
    import time as _time

    recorder = _SequencedRecorder()
    monkeypatch.setattr(httpx, "post", recorder)
    monkeypatch.setattr(_time, "sleep", lambda _s: None)
    return recorder
547
+
548
+
549
def test_retries_on_502_and_succeeds(sequenced):
    """A 502 followed by a 200 returns the embedding after one retry."""
    url = "https://gw/v1/embeddings"
    scripted = [
        _FakeResponse(502, "bad gateway"),
        _FakeResponse(200, {"data": [{"embedding": [0.1, 0.2]}]}),
    ]
    sequenced.queue(url, scripted)
    client = EmbedClient(
        url=url,
        api_key="k",
        model="m",
        provider=PROVIDERS["openai"],
        max_retries=3,
    )
    vectors = client.embed_batch(["hello"])
    assert vectors == [[0.1, 0.2]]
    # One failed attempt plus one successful retry.
    assert len(sequenced.calls) == 2
568
+
569
+
570
def test_retries_on_503_504_429(sequenced):
    """503, 504 and 429 are each retried once and then succeed."""
    url = "https://gw/v1/embeddings"
    for code in (503, 504, 429):
        # Start every status code from a clean call log and a fresh script.
        sequenced.calls.clear()
        scripted = [
            _FakeResponse(code, "transient"),
            _FakeResponse(200, {"data": [{"embedding": [0.0]}]}),
        ]
        sequenced.queue(url, scripted)
        client = EmbedClient(
            url=url,
            api_key="k",
            model="m",
            provider=PROVIDERS["openai"],
            max_retries=3,
        )
        vectors = client.embed_batch(["x"])
        assert vectors == [[0.0]], f"retry failed for status {code}"
        assert len(sequenced.calls) == 2, f"wrong call count for status {code}"
591
+
592
+
593
def test_does_not_retry_on_500(sequenced):
    """A 500 is treated as a server bug rather than saturation: the
    error surfaces after a single attempt."""
    url = "https://gw/v1/embeddings"
    sequenced.queue(url, [_FakeResponse(500, "internal server error")])
    client = EmbedClient(
        url=url,
        api_key="k",
        model="m",
        provider=PROVIDERS["openai"],
        max_retries=3,
    )
    with pytest.raises(EmbedHTTPError) as exc:
        client.embed_batch(["x"])
    assert exc.value.status == 500
    # No retry happened: the recorder saw exactly one request.
    assert len(sequenced.calls) == 1
611
+
612
+
613
def test_does_not_retry_on_400(sequenced):
    """A 400 signals a caller error and is raised after one attempt
    (only 401-autodetect and 429 get special 4xx handling)."""
    url = "https://gw/v1/embeddings"
    sequenced.queue(url, [_FakeResponse(400, "bad request")])
    client = EmbedClient(
        url=url,
        api_key="k",
        model="m",
        provider=PROVIDERS["openai"],
        max_retries=3,
    )
    with pytest.raises(EmbedHTTPError) as exc:
        client.embed_batch(["x"])
    assert exc.value.status == 400
    assert len(sequenced.calls) == 1
630
+
631
+
632
def test_max_retries_exhausted_raises(sequenced):
    """A 502 that never clears is raised once max_retries+1 attempts
    have been made."""
    url = "https://gw/v1/embeddings"
    # A single scripted response repeats forever in the recorder.
    sequenced.queue(url, [_FakeResponse(502, "still down")])
    client = EmbedClient(
        url=url,
        api_key="k",
        model="m",
        provider=PROVIDERS["openai"],
        max_retries=3,
    )
    with pytest.raises(EmbedHTTPError) as exc:
        client.embed_batch(["x"])
    assert exc.value.status == 502
    # 1 initial attempt + 3 retries.
    assert len(sequenced.calls) == 4
650
+
651
+
652
def test_max_retries_zero_disables_retry(sequenced):
    """With max_retries=0 a 502 is raised after the first attempt, for
    callers that run their own retry loop."""
    url = "https://gw/v1/embeddings"
    sequenced.queue(url, [_FakeResponse(502, "down")])
    client = EmbedClient(
        url=url,
        api_key="k",
        model="m",
        provider=PROVIDERS["openai"],
        max_retries=0,
    )
    with pytest.raises(EmbedHTTPError):
        client.embed_batch(["x"])
    assert len(sequenced.calls) == 1
669
+
670
+
671
def test_from_env_reads_retry_config(monkeypatch):
    """The prefixed EMBED_MAX_RETRIES, EMBED_RETRY_BASE_DELAY and
    EMBED_RETRY_MAX_DELAY variables replace the built-in retry settings."""
    env = {
        "L4_NV_EMBED_URL": "https://gw/v1/embeddings",
        "L4_EMBED_API_KEY": "k",
        "L4_EMBED_MAX_RETRIES": "5",
        "L4_EMBED_RETRY_BASE_DELAY": "0.25",
        "L4_EMBED_RETRY_MAX_DELAY": "2.5",
    }
    for name, value in env.items():
        monkeypatch.setenv(name, value)
    client = EmbedClient.from_env(prefix="L4_")
    observed = (
        client._max_retries,
        client._retry_base_delay,
        client._retry_max_delay,
    )
    assert observed == (5, 0.25, 2.5)
683
+
684
+
685
def test_from_env_default_retry_config(monkeypatch):
    """Without retry variables set, from_env yields 3 retries, a 0.1s
    base delay and a 1.0s delay cap."""
    for name, value in (
        ("L4_NV_EMBED_URL", "https://gw/v1/embeddings"),
        ("L4_EMBED_API_KEY", "k"),
    ):
        monkeypatch.setenv(name, value)
    client = EmbedClient.from_env(prefix="L4_")
    observed = (
        client._max_retries,
        client._retry_base_delay,
        client._retry_max_delay,
    )
    assert observed == (3, 0.1, 1.0)