contextpilot 0.3.3__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. {contextpilot-0.3.3 → contextpilot-0.3.4}/PKG-INFO +17 -11
  2. {contextpilot-0.3.3 → contextpilot-0.3.4}/README.md +16 -10
  3. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/__init__.py +1 -1
  4. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/pipeline/rag_pipeline.py +4 -3
  5. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/server/http_server.py +43 -23
  6. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/server/live_index.py +10 -18
  7. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot.egg-info/PKG-INFO +17 -11
  8. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot.egg-info/SOURCES.txt +2 -1
  9. {contextpilot-0.3.3 → contextpilot-0.3.4}/pyproject.toml +1 -1
  10. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_live_index.py +23 -0
  11. contextpilot-0.3.4/tests/test_vllm_patch.py +493 -0
  12. {contextpilot-0.3.3 → contextpilot-0.3.4}/LICENSE +0 -0
  13. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/context_index/__init__.py +0 -0
  14. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/context_index/compute_distance_cpu.py +0 -0
  15. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/context_index/compute_distance_gpu.py +0 -0
  16. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/context_index/index_construction.py +0 -0
  17. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/context_index/tree_nodes.py +0 -0
  18. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/context_ordering/__init__.py +0 -0
  19. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/context_ordering/inter_scheduler.py +0 -0
  20. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/context_ordering/intra_ordering.py +0 -0
  21. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/pipeline/__init__.py +0 -0
  22. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/pipeline/components.py +0 -0
  23. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/pipeline/multi_turn.py +0 -0
  24. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/retriever/__init__.py +0 -0
  25. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/retriever/bm25.py +0 -0
  26. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/retriever/faiss_embedding.py +0 -0
  27. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/retriever/mem0_retriever.py +0 -0
  28. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/retriever/pageindex_retriever.py +0 -0
  29. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/server/__init__.py +0 -0
  30. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/server/conversation_tracker.py +0 -0
  31. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/server/eviction_heap.py +0 -0
  32. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/server/http_client.py +0 -0
  33. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/server/metadata.py +0 -0
  34. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/utils/__init__.py +0 -0
  35. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/utils/eval_metrics.py +0 -0
  36. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/utils/prompt_generator.py +0 -0
  37. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot/utils/tools.py +0 -0
  38. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot.egg-info/dependency_links.txt +0 -0
  39. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot.egg-info/requires.txt +0 -0
  40. {contextpilot-0.3.3 → contextpilot-0.3.4}/contextpilot.egg-info/top_level.txt +0 -0
  41. {contextpilot-0.3.3 → contextpilot-0.3.4}/requirements.txt +0 -0
  42. {contextpilot-0.3.3 → contextpilot-0.3.4}/setup.cfg +0 -0
  43. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_context_index.py +0 -0
  44. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_context_ordering.py +0 -0
  45. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_cpu_distances.py +0 -0
  46. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_deduplication.py +0 -0
  47. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_gpu_distance_performance.py +0 -0
  48. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_gpu_distances.py +0 -0
  49. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_group_prefix_sharing.py +0 -0
  50. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_incremental_build.py +0 -0
  51. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_mem0_integration.py +0 -0
  52. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_multi_turn.py +0 -0
  53. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_multi_turn_e2e.py +0 -0
  54. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_pageindex_integration.py +0 -0
  55. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_performance.py +0 -0
  56. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_pipeline.py +0 -0
  57. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_server_integration.py +0 -0
  58. {contextpilot-0.3.3 → contextpilot-0.3.4}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: contextpilot
3
- Version: 0.3.3
3
+ Version: 0.3.4
4
4
  Summary: Efficient Retrieval-Augmented Generation with Accuracy-Preserving Context Reuse
5
5
  Author: Yinsicheng Jiang, Chivier Humber
6
6
  License: Apache-2.0
@@ -42,7 +42,7 @@ Dynamic: license-file
42
42
  <div align="center">
43
43
  <img src="assets/about.png" alt="ContextPilot Logo" width="800"/>
44
44
 
45
- <h1><strong>ContextPilot: Efficient Long Context Inference with Context Reuse</strong></h1>
45
+ <h1><strong>ContextPilot: Fast Long-Context Inference via Context Reuse</strong></h1>
46
46
 
47
47
  [![Python](https://img.shields.io/badge/python-≥3.10-blue)](https://www.python.org/)
48
48
  [![PyPI](https://img.shields.io/pypi/v/contextpilot)](https://pypi.org/project/contextpilot/)
@@ -80,7 +80,7 @@ ContextPilot is a fast optimization system on context engineering layer for agen
80
80
  ### System Performance
81
81
 
82
82
  <div align="center">
83
- <img src="assets/deepseek_r1_results.png" alt="Benchmark Results" width="600"/>
83
+ <img src="assets/ds_r1_result_horizontal.png" alt="Benchmark Results" width="800"/>
84
84
  </div>
85
85
 
86
86
  ContextPilot (Stateless) on DeepSeek-R1 maintains accuracy compared to SGLang, achieving 64.68% vs 64.15% F1 on MultihopRAG and 41.08% vs 40.20% F1 on NarrativeQA.
@@ -146,15 +146,18 @@ queries = ["What are transformers?", "How do RNNs compare?", "Explain attention
146
146
 
147
147
  for turn_idx, (query, mems) in enumerate(zip(queries, turn_memories)):
148
148
  # 1. Reorder for prefix sharing (handles cold start & incremental)
149
- [ctx], order = cp_live.reorder([mems]) # single request per turn
149
+ # .reorder() accepts a single list or list-of-lists
150
+ reordered, indices = cp_live.reorder(mems)
151
+ ctx = reordered[0] # single context per turn
150
152
  # Turn 2: "GPT is based on transformers" ← moved to prefix (shared with turn 1)
151
153
  # Turn 3: "Transformers …", "GPT …" ← both moved to prefix
152
154
 
153
155
  # 2. Generate answer with reordered context
154
156
  docs_section = "\n".join(f"[{i+1}] {doc}" for i, doc in enumerate(ctx))
155
- importance_ranking = ">".join(
156
- str(ctx.index(doc) + 1) for doc in mems if doc in ctx
157
- )
157
+ # Map original importance order (mems) → 1-based positions in reordered ctx
158
+ pos = {doc: i + 1 for i, doc in enumerate(ctx)}
159
+ importance_ranking = ">".join(str(pos[doc]) for doc in mems if doc in pos)
160
+ # System prompt = documents + importance ranking (after </documents>, doesn't affect prefix sharing)
158
161
  response = client.chat.completions.create(
159
162
  model="Qwen/Qwen3-4B",
160
163
  messages=[
@@ -171,7 +174,7 @@ for turn_idx, (query, mems) in enumerate(zip(queries, turn_memories)):
171
174
  print(f"A: {response.choices[0].message.content}\n")
172
175
  ```
173
176
 
174
- > **Note:** Stateful mode works without eviction sync — `ContextPilot` tracks the previous ordering and reorders new contexts to maximize prefix cache hits. For production deployments with limited KV-cache capacity, install the [SGLang eviction patch](docs/guides/online_usage.md#sglang-integration) to keep the index in sync. See the [online usage guide](docs/guides/online_usage.md) for HTTP server setup.
177
+ > **Note:** Stateful mode works without eviction sync — `ContextPilot` tracks the previous ordering and reorders new contexts to maximize prefix cache hits. For production deployments with limited KV-cache capacity, install the eviction patch for your inference engine ([SGLang](docs/guides/online_usage.md#sglang-integration) or [vLLM](docs/guides/online_usage.md#vllm-integration)) to keep the index in sync. See the [online usage guide](docs/guides/online_usage.md) for HTTP server setup.
175
178
 
176
179
  **Offline / Online Stateless** — same API, just pass the full batch at once:
177
180
 
@@ -190,15 +193,18 @@ all_contexts = [
190
193
  ]
191
194
 
192
195
  # One call: builds index, reorders docs for prefix sharing, and schedules execution order
193
- reordered, order = cp_batch.reorder(all_contexts)
196
+ # .reorder() returns (reordered_contexts, original_indices)
197
+ reordered_ctx, order = cp_batch.reorder(all_contexts)
194
198
 
195
199
  # Build all prompts in optimized order
196
200
  messages_batch = []
197
- for ctx, orig_idx in zip(reordered, order):
201
+ for ctx, orig_idx in zip(reordered_ctx, order):
198
202
  docs_section = "\n".join(f"[{i+1}] {doc}" for i, doc in enumerate(ctx))
203
+ pos = {doc: i + 1 for i, doc in enumerate(ctx)}
199
204
  importance_ranking = ">".join(
200
- str(ctx.index(doc) + 1) for doc in all_contexts[orig_idx] if doc in ctx
205
+ str(pos[doc]) for doc in all_contexts[orig_idx] if doc in pos
201
206
  )
207
+ # System prompt = documents + importance ranking (after </documents>, doesn't affect prefix sharing)
202
208
  messages_batch.append({
203
209
  "model": "Qwen/Qwen3-4B",
204
210
  "messages": [
@@ -1,7 +1,7 @@
1
1
  <div align="center">
2
2
  <img src="assets/about.png" alt="ContextPilot Logo" width="800"/>
3
3
 
4
- <h1><strong>ContextPilot: Efficient Long Context Inference with Context Reuse</strong></h1>
4
+ <h1><strong>ContextPilot: Fast Long-Context Inference via Context Reuse</strong></h1>
5
5
 
6
6
  [![Python](https://img.shields.io/badge/python-≥3.10-blue)](https://www.python.org/)
7
7
  [![PyPI](https://img.shields.io/pypi/v/contextpilot)](https://pypi.org/project/contextpilot/)
@@ -39,7 +39,7 @@ ContextPilot is a fast optimization system on context engineering layer for agen
39
39
  ### System Performance
40
40
 
41
41
  <div align="center">
42
- <img src="assets/deepseek_r1_results.png" alt="Benchmark Results" width="600"/>
42
+ <img src="assets/ds_r1_result_horizontal.png" alt="Benchmark Results" width="800"/>
43
43
  </div>
44
44
 
45
45
  ContextPilot (Stateless) on DeepSeek-R1 maintains accuracy compared to SGLang, achieving 64.68% vs 64.15% F1 on MultihopRAG and 41.08% vs 40.20% F1 on NarrativeQA.
@@ -105,15 +105,18 @@ queries = ["What are transformers?", "How do RNNs compare?", "Explain attention
105
105
 
106
106
  for turn_idx, (query, mems) in enumerate(zip(queries, turn_memories)):
107
107
  # 1. Reorder for prefix sharing (handles cold start & incremental)
108
- [ctx], order = cp_live.reorder([mems]) # single request per turn
108
+ # .reorder() accepts a single list or list-of-lists
109
+ reordered, indices = cp_live.reorder(mems)
110
+ ctx = reordered[0] # single context per turn
109
111
  # Turn 2: "GPT is based on transformers" ← moved to prefix (shared with turn 1)
110
112
  # Turn 3: "Transformers …", "GPT …" ← both moved to prefix
111
113
 
112
114
  # 2. Generate answer with reordered context
113
115
  docs_section = "\n".join(f"[{i+1}] {doc}" for i, doc in enumerate(ctx))
114
- importance_ranking = ">".join(
115
- str(ctx.index(doc) + 1) for doc in mems if doc in ctx
116
- )
116
+ # Map original importance order (mems) → 1-based positions in reordered ctx
117
+ pos = {doc: i + 1 for i, doc in enumerate(ctx)}
118
+ importance_ranking = ">".join(str(pos[doc]) for doc in mems if doc in pos)
119
+ # System prompt = documents + importance ranking (after </documents>, doesn't affect prefix sharing)
117
120
  response = client.chat.completions.create(
118
121
  model="Qwen/Qwen3-4B",
119
122
  messages=[
@@ -130,7 +133,7 @@ for turn_idx, (query, mems) in enumerate(zip(queries, turn_memories)):
130
133
  print(f"A: {response.choices[0].message.content}\n")
131
134
  ```
132
135
 
133
- > **Note:** Stateful mode works without eviction sync — `ContextPilot` tracks the previous ordering and reorders new contexts to maximize prefix cache hits. For production deployments with limited KV-cache capacity, install the [SGLang eviction patch](docs/guides/online_usage.md#sglang-integration) to keep the index in sync. See the [online usage guide](docs/guides/online_usage.md) for HTTP server setup.
136
+ > **Note:** Stateful mode works without eviction sync — `ContextPilot` tracks the previous ordering and reorders new contexts to maximize prefix cache hits. For production deployments with limited KV-cache capacity, install the eviction patch for your inference engine ([SGLang](docs/guides/online_usage.md#sglang-integration) or [vLLM](docs/guides/online_usage.md#vllm-integration)) to keep the index in sync. See the [online usage guide](docs/guides/online_usage.md) for HTTP server setup.
134
137
 
135
138
  **Offline / Online Stateless** — same API, just pass the full batch at once:
136
139
 
@@ -149,15 +152,18 @@ all_contexts = [
149
152
  ]
150
153
 
151
154
  # One call: builds index, reorders docs for prefix sharing, and schedules execution order
152
- reordered, order = cp_batch.reorder(all_contexts)
155
+ # .reorder() returns (reordered_contexts, original_indices)
156
+ reordered_ctx, order = cp_batch.reorder(all_contexts)
153
157
 
154
158
  # Build all prompts in optimized order
155
159
  messages_batch = []
156
- for ctx, orig_idx in zip(reordered, order):
160
+ for ctx, orig_idx in zip(reordered_ctx, order):
157
161
  docs_section = "\n".join(f"[{i+1}] {doc}" for i, doc in enumerate(ctx))
162
+ pos = {doc: i + 1 for i, doc in enumerate(ctx)}
158
163
  importance_ranking = ">".join(
159
- str(ctx.index(doc) + 1) for doc in all_contexts[orig_idx] if doc in ctx
164
+ str(pos[doc]) for doc in all_contexts[orig_idx] if doc in pos
160
165
  )
166
+ # System prompt = documents + importance ranking (after </documents>, doesn't affect prefix sharing)
161
167
  messages_batch.append({
162
168
  "model": "Qwen/Qwen3-4B",
163
169
  "messages": [
@@ -47,7 +47,7 @@ from .retriever import (
47
47
  MEM0_AVAILABLE,
48
48
  )
49
49
 
50
- __version__ = "0.3.3"
50
+ __version__ = "0.3.4"
51
51
 
52
52
  __all__ = [
53
53
  # High-level pipeline API
@@ -566,10 +566,11 @@ class RAGPipeline:
566
566
  **extra_request_body,
567
567
  }
568
568
 
569
- # Add rid for request tracking in SGLang's radix cache
570
- # SGLang uses 'rid' field to identify requests
571
569
  if request_id:
572
- payload["rid"] = request_id
570
+ if self.inference_config.backend == "vllm":
571
+ payload["request_id"] = request_id
572
+ else:
573
+ payload["rid"] = request_id # SGLang field name
573
574
 
574
575
  output = {
575
576
  "generated_text": "",
@@ -20,6 +20,7 @@ import logging
20
20
  import time
21
21
  import asyncio
22
22
  import os
23
+ import re
23
24
  import uuid
24
25
  from typing import List, Dict, Any, Optional
25
26
  from contextlib import asynccontextmanager
@@ -70,6 +71,19 @@ _str_to_id: Dict[str, int] = {}
70
71
  _id_to_str: Dict[int, str] = {}
71
72
  _next_str_id: int = 0
72
73
 
74
# Request ID normalization: map IDs reported back by an inference engine
# onto the canonical request IDs that ContextPilot handed out.
# Leading engine-added tag (stripped once, only at the start of the ID).
_ENGINE_REQ_ID_PREFIX = re.compile(r"^(cmpl-|chatcmpl-|batch-)")
# "req-<token>-<digits>-<lowercase hex>" -> capture "req-<token>".
# NOTE(review): presumably the per-request suffix vLLM appends; confirm
# against the vLLM version the eviction patch targets.
_VLLM_REQ_SUFFIX = re.compile(r"^(req-[^-]+)-\d+-[0-9a-f]+$")


def _normalize_request_id(request_id: str) -> str:
    """Return the ContextPilot canonical form of an engine request ID.

    Two rewrites are applied in order:

    1. A single leading ``cmpl-``, ``chatcmpl-`` or ``batch-`` tag is removed.
    2. If what remains looks like ``req-<token>-<digits>-<hex>``, only the
       ``req-<token>`` part is kept.

    Args:
        request_id: ID as reported by the engine. A falsy value (``None`` or
            ``""``) is treated as the empty string.

    Returns:
        The normalized ID; ``""`` for falsy input.
    """
    stripped = _ENGINE_REQ_ID_PREFIX.sub("", request_id if request_id else "")
    suffix_match = _VLLM_REQ_SUFFIX.match(stripped)
    return suffix_match.group(1) if suffix_match else stripped
86
+
73
87
 
74
88
  def _init_config():
75
89
  """Initialize config from environment variables."""
@@ -607,24 +621,16 @@ async def evict(request: EvictRequest):
607
621
 
608
622
  THIS IS THE MAIN ENDPOINT THAT THE INFERENCE ENGINE'S EVICTION CALLBACK SHOULD CALL.
609
623
 
610
- When the inference engine's cache evicts nodes, it collects the request_ids
611
- from the evicted nodes and invokes the registered callback. That callback
612
- should call this endpoint to remove the corresponding entries from ContextPilot.
613
-
614
- Integration example (SGLang):
615
- def eviction_callback(evicted_request_ids: set):
616
- if evicted_request_ids:
617
- try:
618
- requests.post(
619
- "http://localhost:8765/evict",
620
- json={"request_ids": list(evicted_request_ids)},
621
- timeout=1.0
622
- )
623
- except Exception as e:
624
- logger.warning(f"ContextPilot eviction sync failed: {e}")
625
-
626
- # Register callback when initializing radix cache
627
- tree_cache.set_eviction_callback(eviction_callback)
624
+ When the inference engine's cache evicts entries, it collects the request_ids
625
+ from the evicted entries and invokes the registered callback. That callback
626
+ calls this endpoint to remove the corresponding entries from ContextPilot.
627
+
628
+ Supported engines:
629
+ - SGLang: patches/sglang/ patches the radix cache to fire callbacks on eviction
630
+ - vLLM: patches/vllm/ patches the block pool to fire callbacks on eviction
631
+
632
+ Both use the same protocol:
633
+ POST /evict {"request_ids": ["req-1", "req-2", ...]}
628
634
  """
629
635
  # Check if index is initialized
630
636
  if _index is None:
@@ -633,14 +639,25 @@ async def evict(request: EvictRequest):
633
639
  )
634
640
 
635
641
  try:
642
+ normalized_ids = [
643
+ _normalize_request_id(rid)
644
+ for rid in request.request_ids
645
+ ]
646
+ normalized_ids = [
647
+ rid for rid in normalized_ids
648
+ if rid and not rid.startswith("HEALTH_CHECK")
649
+ ]
650
+ # Deduplicate while preserving order for deterministic logs/responses.
651
+ normalized_ids = list(dict.fromkeys(normalized_ids))
652
+
636
653
  # Remove the evicted requests from our index
637
- result = _index.remove_requests(request.request_ids)
654
+ result = _index.remove_requests(normalized_ids)
638
655
 
639
656
  # Also clear conversation history for evicted requests
640
657
  # This ensures ConversationTracker stays in sync with the engine's cache
641
658
  tracker = get_conversation_tracker()
642
659
  conversations_cleared = 0
643
- for req_id in request.request_ids:
660
+ for req_id in normalized_ids:
644
661
  cleared = tracker.clear_conversation(req_id)
645
662
  conversations_cleared += cleared
646
663
 
@@ -648,12 +665,14 @@ async def evict(request: EvictRequest):
648
665
  logger.info(
649
666
  f"Eviction: removed {result['removed_count']} requests from index, "
650
667
  f"cleared {conversations_cleared} conversation entries, "
651
- f"not_found={len(result['not_found'])}"
668
+ f"not_found={len(result['not_found'])}, "
669
+ f"incoming={len(request.request_ids)}, normalized={len(normalized_ids)}"
652
670
  )
653
671
 
654
672
  return {
655
673
  "status": "success",
656
674
  "conversations_cleared": conversations_cleared,
675
+ "normalized_request_ids": normalized_ids,
657
676
  **result,
658
677
  }
659
678
 
@@ -916,8 +935,9 @@ async def proxy_completions(request: Request):
916
935
  # Pass request_id to inference engine so it can use the same ID for request tracking
917
936
  # Engine will notify ContextPilot via /evict callback when this request is evicted
918
937
  if request_id:
919
- body["rid"] = request_id
920
- logger.info(f"Proxy: forwarding request with rid={request_id}")
938
+ body["rid"] = request_id # SGLang
939
+ body["request_id"] = request_id # vLLM
940
+ logger.info(f"Proxy: forwarding request with request_id={request_id}")
921
941
  else:
922
942
  logger.info("Proxy: forwarding request without rid (no ContextPilot tracking)")
923
943
 
@@ -224,8 +224,10 @@ class ContextPilot(ContextIndex):
224
224
  transparently — callers never need to distinguish between them.
225
225
 
226
226
  Args:
227
- contexts: ``List[List[int]]`` or ``List[List[str]]`` — each
228
- inner list is one context (document IDs or text strings).
227
+ contexts: A single context (``List[int]`` / ``List[str]``)
228
+ or a batch of contexts (``List[List[int]]`` /
229
+ ``List[List[str]]``). A single list is automatically
230
+ wrapped into ``[contexts]``.
229
231
  initial_tokens_per_context: Initial token budget per context
230
232
  (used for eviction tracking; 0 to ignore).
231
233
  conversation_id: Conversation key for multi-turn
@@ -243,6 +245,10 @@ class ContextPilot(ContextIndex):
243
245
  ``reordered_contexts[i]`` corresponds to
244
246
  ``contexts[original_indices[i]]``.
245
247
  """
248
+ # Accept a single list and wrap it
249
+ if contexts and not isinstance(contexts[0], list):
250
+ contexts = [contexts]
251
+
246
252
  result = self.build_incremental(contexts, initial_tokens_per_context)
247
253
  reordered = result["reordered_contexts"]
248
254
 
@@ -958,25 +964,11 @@ class ContextPilot(ContextIndex):
958
964
  return request_id_mapping, request_ids_ordered
959
965
 
960
966
  # =========================================================================
961
- # Request Eviction (Called by SGLang's radix cache callback)
967
+ # Request Eviction (Called by inference engine's eviction callback)
962
968
  # =========================================================================
963
969
 
964
970
  def remove_requests(self, request_ids: Set[str]) -> Dict[str, Any]:
965
- """
966
- Remove requests from the context index.
967
-
968
- THIS IS THE METHOD CALLED BY SGLANG'S EVICTION CALLBACK.
969
-
970
- When SGLang's radix cache evicts requests, it calls a callback
971
- with the set of evicted request_ids. That callback should invoke
972
- this method to keep the context index in sync.
973
-
974
- Args:
975
- request_ids: Set of request IDs to remove (from SGLang callback)
976
-
977
- Returns:
978
- Dictionary with eviction results
979
- """
971
+ """Remove requests from the context index (called by engine eviction callback)."""
980
972
  evicted_nodes = []
981
973
  not_found = []
982
974
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: contextpilot
3
- Version: 0.3.3
3
+ Version: 0.3.4
4
4
  Summary: Efficient Retrieval-Augmented Generation with Accuracy-Preserving Context Reuse
5
5
  Author: Yinsicheng Jiang, Chivier Humber
6
6
  License: Apache-2.0
@@ -42,7 +42,7 @@ Dynamic: license-file
42
42
  <div align="center">
43
43
  <img src="assets/about.png" alt="ContextPilot Logo" width="800"/>
44
44
 
45
- <h1><strong>ContextPilot: Efficient Long Context Inference with Context Reuse</strong></h1>
45
+ <h1><strong>ContextPilot: Fast Long-Context Inference via Context Reuse</strong></h1>
46
46
 
47
47
  [![Python](https://img.shields.io/badge/python-≥3.10-blue)](https://www.python.org/)
48
48
  [![PyPI](https://img.shields.io/pypi/v/contextpilot)](https://pypi.org/project/contextpilot/)
@@ -80,7 +80,7 @@ ContextPilot is a fast optimization system on context engineering layer for agen
80
80
  ### System Performance
81
81
 
82
82
  <div align="center">
83
- <img src="assets/deepseek_r1_results.png" alt="Benchmark Results" width="600"/>
83
+ <img src="assets/ds_r1_result_horizontal.png" alt="Benchmark Results" width="800"/>
84
84
  </div>
85
85
 
86
86
  ContextPilot (Stateless) on DeepSeek-R1 maintains accuracy compared to SGLang, achieving 64.68% vs 64.15% F1 on MultihopRAG and 41.08% vs 40.20% F1 on NarrativeQA.
@@ -146,15 +146,18 @@ queries = ["What are transformers?", "How do RNNs compare?", "Explain attention
146
146
 
147
147
  for turn_idx, (query, mems) in enumerate(zip(queries, turn_memories)):
148
148
  # 1. Reorder for prefix sharing (handles cold start & incremental)
149
- [ctx], order = cp_live.reorder([mems]) # single request per turn
149
+ # .reorder() accepts a single list or list-of-lists
150
+ reordered, indices = cp_live.reorder(mems)
151
+ ctx = reordered[0] # single context per turn
150
152
  # Turn 2: "GPT is based on transformers" ← moved to prefix (shared with turn 1)
151
153
  # Turn 3: "Transformers …", "GPT …" ← both moved to prefix
152
154
 
153
155
  # 2. Generate answer with reordered context
154
156
  docs_section = "\n".join(f"[{i+1}] {doc}" for i, doc in enumerate(ctx))
155
- importance_ranking = ">".join(
156
- str(ctx.index(doc) + 1) for doc in mems if doc in ctx
157
- )
157
+ # Map original importance order (mems) → 1-based positions in reordered ctx
158
+ pos = {doc: i + 1 for i, doc in enumerate(ctx)}
159
+ importance_ranking = ">".join(str(pos[doc]) for doc in mems if doc in pos)
160
+ # System prompt = documents + importance ranking (after </documents>, doesn't affect prefix sharing)
158
161
  response = client.chat.completions.create(
159
162
  model="Qwen/Qwen3-4B",
160
163
  messages=[
@@ -171,7 +174,7 @@ for turn_idx, (query, mems) in enumerate(zip(queries, turn_memories)):
171
174
  print(f"A: {response.choices[0].message.content}\n")
172
175
  ```
173
176
 
174
- > **Note:** Stateful mode works without eviction sync — `ContextPilot` tracks the previous ordering and reorders new contexts to maximize prefix cache hits. For production deployments with limited KV-cache capacity, install the [SGLang eviction patch](docs/guides/online_usage.md#sglang-integration) to keep the index in sync. See the [online usage guide](docs/guides/online_usage.md) for HTTP server setup.
177
+ > **Note:** Stateful mode works without eviction sync — `ContextPilot` tracks the previous ordering and reorders new contexts to maximize prefix cache hits. For production deployments with limited KV-cache capacity, install the eviction patch for your inference engine ([SGLang](docs/guides/online_usage.md#sglang-integration) or [vLLM](docs/guides/online_usage.md#vllm-integration)) to keep the index in sync. See the [online usage guide](docs/guides/online_usage.md) for HTTP server setup.
175
178
 
176
179
  **Offline / Online Stateless** — same API, just pass the full batch at once:
177
180
 
@@ -190,15 +193,18 @@ all_contexts = [
190
193
  ]
191
194
 
192
195
  # One call: builds index, reorders docs for prefix sharing, and schedules execution order
193
- reordered, order = cp_batch.reorder(all_contexts)
196
+ # .reorder() returns (reordered_contexts, original_indices)
197
+ reordered_ctx, order = cp_batch.reorder(all_contexts)
194
198
 
195
199
  # Build all prompts in optimized order
196
200
  messages_batch = []
197
- for ctx, orig_idx in zip(reordered, order):
201
+ for ctx, orig_idx in zip(reordered_ctx, order):
198
202
  docs_section = "\n".join(f"[{i+1}] {doc}" for i, doc in enumerate(ctx))
203
+ pos = {doc: i + 1 for i, doc in enumerate(ctx)}
199
204
  importance_ranking = ">".join(
200
- str(ctx.index(doc) + 1) for doc in all_contexts[orig_idx] if doc in ctx
205
+ str(pos[doc]) for doc in all_contexts[orig_idx] if doc in pos
201
206
  )
207
+ # System prompt = documents + importance ranking (after </documents>, doesn't affect prefix sharing)
202
208
  messages_batch.append({
203
209
  "model": "Qwen/Qwen3-4B",
204
210
  "messages": [
@@ -52,4 +52,5 @@ tests/test_pageindex_integration.py
52
52
  tests/test_performance.py
53
53
  tests/test_pipeline.py
54
54
  tests/test_server_integration.py
55
- tests/test_utils.py
55
+ tests/test_utils.py
56
+ tests/test_vllm_patch.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "contextpilot"
7
- version = "0.3.3"
7
+ version = "0.3.4"
8
8
  description = "Efficient Retrieval-Augmented Generation with Accuracy-Preserving Context Reuse"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -255,6 +255,29 @@ class TestLiveIndexRequestTracking:
255
255
  assert 'request_ids' in result
256
256
  assert len(result['request_ids']) == len(contexts)
257
257
 
258
def test_reorder_single_list(self):
    """A flat list of ints passed to reorder() is treated as a batch of one."""
    from contextpilot import ContextPilot

    pilot = ContextPilot(use_gpu=False)
    # Deliberately a flat list, not a list of lists.
    contexts_out, order = pilot.reorder([1, 2, 3])

    assert len(contexts_out) == 1
    assert set(contexts_out[0]) == {1, 2, 3}
    assert order == [0]
269
+
270
def test_reorder_single_list_strings(self):
    """A flat list of strings passed to reorder() is treated as a batch of one."""
    from contextpilot import ContextPilot

    pilot = ContextPilot(use_gpu=False)
    docs = ["doc_a", "doc_b", "doc_c"]
    contexts_out, order = pilot.reorder(docs)

    assert len(contexts_out) == 1
    assert set(contexts_out[0]) == {"doc_a", "doc_b", "doc_c"}
    assert order == [0]
280
+
258
281
 
259
282
  class TestDeduplication:
260
283
  """Test ContextPilot.deduplicate() for multi-turn deduplication."""
@@ -0,0 +1,493 @@
1
+ """
2
+ Tests for vLLM block_pool.py eviction sync patch.
3
+
4
+ Tests the ContextPilot tracking dicts and eviction callback logic
5
+ without requiring a vLLM installation — all vLLM internals are mocked.
6
+ """
7
+
8
+ import pytest
9
+ from unittest.mock import MagicMock, patch
10
+ from dataclasses import dataclass, field
11
+ from typing import Optional
12
+
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # Mock vLLM types so we can import/test block_pool logic without vLLM
16
+ # ---------------------------------------------------------------------------
17
+
18
@dataclass
class MockKVCacheBlock:
    """Minimal stand-in for vLLM's ``KVCacheBlock``.

    Only the fields touched by the tracking/eviction logic under test are
    modelled. The free-list neighbour pointers are excluded from ``repr``
    and from equality: comparing two blocks must not walk the free-list
    chain they happen to be linked into (two linked pairs would otherwise
    recurse forever inside the generated ``__eq__``).
    """

    block_id: int
    block_hash: Optional[bytes] = None
    ref_cnt: int = 0
    is_null: bool = False
    # Doubly-linked free-list pointers; compare=False keeps __eq__ from
    # following them.
    prev_free_block: Optional["MockKVCacheBlock"] = field(
        default=None, repr=False, compare=False
    )
    next_free_block: Optional["MockKVCacheBlock"] = field(
        default=None, repr=False, compare=False
    )

    def reset_hash(self):
        """Clear the block's hash so it no longer counts as cached."""
        self.block_hash = None
29
+
30
+
31
class MockFreeKVCacheBlockQueue:
    """Simplified free block queue for testing.

    Free blocks are kept in a plain list (front = next block handed out) and
    ``num_free_blocks`` is kept equal to the list's length. Unlike vLLM's
    doubly-linked queue, pops/removes are O(n), which is fine for the handful
    of blocks these tests use.
    """

    def __init__(self, blocks):
        self._blocks = list(blocks)
        self.num_free_blocks = len(self._blocks)

    def popleft(self):
        """Remove and return the block at the front of the queue.

        Raises:
            IndexError: if the queue is empty; the counter is left unchanged.
        """
        # Pop first so an empty queue raises before the counter is touched.
        block = self._blocks.pop(0)
        self.num_free_blocks -= 1
        return block

    def popleft_n(self, n):
        """Remove and return up to ``n`` blocks from the front of the queue."""
        result = self._blocks[:n]
        self._blocks = self._blocks[n:]
        # Subtract what was actually taken: a short queue yields fewer than n.
        self.num_free_blocks -= len(result)
        return result

    def remove(self, block):
        """Remove ``block`` from the queue; silently ignore it if absent."""
        try:
            self._blocks.remove(block)
        except ValueError:
            return
        self.num_free_blocks -= 1

    def append_n(self, blocks):
        """Append ``blocks`` (any iterable) to the back of the queue."""
        # Materialize first so one-shot iterables can be both stored and counted.
        blocks = list(blocks)
        self._blocks.extend(blocks)
        self.num_free_blocks += len(blocks)
56
+
57
+
58
class MockRequest:
    """Bare-bones request exposing the attributes block-pool code reads."""

    def __init__(self, request_id, block_hashes, all_token_ids=None):
        self.request_id = request_id
        self.block_hashes = block_hashes
        # Any falsy value (None or an empty sequence) becomes a fresh list.
        self.all_token_ids = all_token_ids if all_token_ids else []
        self.lora_request = None
64
+
65
+
66
class MockBlockHashToBlockMap:
    """Test double mirroring vLLM's ``BlockHashToBlockMap``.

    Each hash maps to a single block until a second block with the same hash
    is inserted; from then on it maps to a ``{block_id: block}`` dict holding
    every copy.
    """

    def __init__(self):
        self._cache = {}

    def get_one_block(self, key):
        """Return any one block cached under ``key``, or ``None`` if absent."""
        entry = self._cache.get(key)
        if entry is None:
            return None
        if isinstance(entry, MockKVCacheBlock):
            return entry
        if not isinstance(entry, dict):
            raise AssertionError(f"Invalid cache block type: {type(entry)}")
        return next(iter(entry.values()))

    def insert(self, key, block):
        """Cache ``block`` under ``key``, switching to a dict on the second copy."""
        entry = self._cache.get(key)
        if entry is None:
            self._cache[key] = block
            return
        if isinstance(entry, MockKVCacheBlock):
            self._cache[key] = {entry.block_id: entry, block.block_id: block}
            return
        if not isinstance(entry, dict):
            raise AssertionError(f"Invalid cache block type: {type(entry)}")
        entry[block.block_id] = block

    def pop(self, key, block_id):
        """Remove and return the block with ``block_id`` cached under ``key``.

        Returns ``None`` when the key is absent or holds no block with that
        id; in that case the stored entry is put back.
        """
        entry = self._cache.pop(key, None)
        if entry is None:
            return None

        if isinstance(entry, dict):
            removed = entry.pop(block_id, None)
            # Only keep the key while at least one copy remains.
            if entry:
                self._cache[key] = entry
            return removed

        if isinstance(entry, MockKVCacheBlock) and entry.block_id == block_id:
            return entry

        # Different block id (or unexpected type): restore and report a miss.
        self._cache[key] = entry
        return None

    def __len__(self):
        """Number of distinct hashes currently cached."""
        return len(self._cache)
118
+
119
+
120
+ # ---------------------------------------------------------------------------
121
+ # BlockPool under test — extracted logic (no vLLM imports needed)
122
+ # ---------------------------------------------------------------------------
123
+
124
class TestableBlockPool:
    """ContextPilot-patched BlockPool logic, mocked for testing.

    Reproduces the bookkeeping the eviction-sync patch layers on top of the
    block pool:

    * ``_block_to_requests``: block hash -> ids of requests that cached it.
    * ``_request_to_blocks``: request id -> block hashes it still has cached.

    A request is passed to ``eviction_callback`` once none of its tracked
    hashes remain in the prefix cache. Tracking is only maintained when a
    callback is configured.
    """

    # The name matches pytest's ``Test*`` pattern, but this is a fixture,
    # not a test class; opt out of collection (avoids PytestCollectionWarning).
    __test__ = False

    def __init__(self, num_blocks=10, eviction_callback=None):
        self.blocks = [MockKVCacheBlock(i) for i in range(num_blocks)]
        self.free_block_queue = MockFreeKVCacheBlockQueue(list(self.blocks))
        self.cached_block_hash_to_block = MockBlockHashToBlockMap()
        self.enable_caching = True
        self.num_gpu_blocks = num_blocks
        self.metrics_collector = None

        # The first block is taken out of the free queue as the null block.
        self.null_block = self.free_block_queue.popleft()
        self.null_block.is_null = True

        # ContextPilot tracking
        self._block_to_requests: dict[bytes, set[str]] = {}
        self._request_to_blocks: dict[str, set[bytes]] = {}
        self.eviction_callback = eviction_callback

    def _notify_evicted(self, request_ids):
        """Report ``request_ids`` to the eviction callback, best effort.

        Does nothing when there is nothing to report or no callback is set.
        Callback failures are deliberately not propagated, so cache operations
        never fail because the notification side channel did; they are
        logged at debug level instead of being silently discarded.
        """
        if not request_ids or self.eviction_callback is None:
            return
        try:
            self.eviction_callback(request_ids)
        except Exception:
            import logging

            logging.getLogger(__name__).debug(
                "eviction callback failed", exc_info=True
            )

    def cache_full_blocks_simple(self, request_id, block_indices, block_hashes):
        """Cache ``self.blocks[i]`` under the paired hash on behalf of ``request_id``.

        Raises:
            ValueError: if ``block_indices`` and ``block_hashes`` differ in
                length. This is checked before any state changes, instead of
                silently dropping the unmatched tail.
        """
        pairs = list(zip(block_indices, block_hashes, strict=True))
        track = self.eviction_callback is not None
        for idx, bh in pairs:
            blk = self.blocks[idx]
            blk.block_hash = bh
            self.cached_block_hash_to_block.insert(bh, blk)

            if track:
                self._block_to_requests.setdefault(bh, set()).add(request_id)
                self._request_to_blocks.setdefault(request_id, set()).add(bh)

    def _maybe_evict_cached_block(self, block) -> set:
        """Remove ``block`` from the prefix cache and update request tracking.

        Returns:
            Ids of requests left with no cached block hash by this removal.
            Empty if the block was not cached under its hash.
        """
        fully_evicted = set()
        block_hash = block.block_hash
        if block_hash is None:
            return fully_evicted

        # Not stored under this id: nothing changes and the hash is kept.
        if self.cached_block_hash_to_block.pop(block_hash, block.block_id) is None:
            return fully_evicted

        # Requests only lose this hash once its *last* cached copy is gone.
        if self.cached_block_hash_to_block.get_one_block(block_hash) is None:
            for rid in self._block_to_requests.pop(block_hash, ()):
                blocks_set = self._request_to_blocks.get(rid)
                if blocks_set is None:
                    continue
                blocks_set.discard(block_hash)
                if not blocks_set:
                    fully_evicted.add(rid)
                    del self._request_to_blocks[rid]

        block.reset_hash()
        return fully_evicted

    def get_new_blocks(self, num_blocks):
        """Take up to ``num_blocks`` blocks from the free queue and reference them.

        With caching enabled, each reused block first has its cached hash
        evicted; requests left with nothing cached are reported once.
        """
        ret = self.free_block_queue.popleft_n(num_blocks)
        fully_evicted = set()

        for block in ret:
            if self.enable_caching:
                fully_evicted |= self._maybe_evict_cached_block(block)
            block.ref_cnt += 1

        self._notify_evicted(fully_evicted)
        return ret

    def free_blocks(self, blocks):
        """Drop one reference from each block; unreferenced non-null ones go back to the free queue."""
        blocks_list = list(blocks)
        for block in blocks_list:
            block.ref_cnt -= 1
        self.free_block_queue.append_n(
            [b for b in blocks_list if b.ref_cnt == 0 and not b.is_null]
        )

    def touch(self, blocks):
        """Add a reference to each block, pulling free ones out of the free queue.

        Accepts either a flat sequence of blocks or a sequence of block groups
        (the grouped shape used upstream).
        """
        if not blocks:
            return
        if isinstance(blocks[0], MockKVCacheBlock):
            block_iter = blocks
        else:
            block_iter = (b for group in blocks for b in group)

        for block in block_iter:
            if block.ref_cnt == 0 and not block.is_null:
                self.free_block_queue.remove(block)
            block.ref_cnt += 1

    def evict_blocks(self, block_ids):
        """Evict the given block ids from the prefix cache and report fully evicted requests."""
        fully_evicted = set()
        for block_id in block_ids:
            fully_evicted |= self._maybe_evict_cached_block(self.blocks[block_id])
        self._notify_evicted(fully_evicted)

    def reset_prefix_cache(self):
        """Drop the whole prefix cache, reporting every tracked request as evicted."""
        self._notify_evicted(set(self._request_to_blocks))
        self._block_to_requests.clear()
        self._request_to_blocks.clear()
        self.cached_block_hash_to_block = MockBlockHashToBlockMap()
        for block in self.blocks:
            block.reset_hash()

    def get_tracked_request_ids(self):
        """Return a snapshot of the request ids that still have cached blocks."""
        return set(self._request_to_blocks.keys())

    def is_request_in_cache(self, request_id):
        """Return whether ``request_id`` still has at least one tracked cached hash."""
        return request_id in self._request_to_blocks
249
+
250
+
251
+ # ---------------------------------------------------------------------------
252
+ # Tests
253
+ # ---------------------------------------------------------------------------
254
+
255
class TestTrackingDicts:
    """Caching blocks populates the request <-> block-hash tracking maps."""

    def test_cache_records_mapping(self):
        pool = TestableBlockPool(eviction_callback=MagicMock())
        hashes = [b"h1", b"h2", b"h3"]

        pool.cache_full_blocks_simple("req-1", [1, 2, 3], hashes)

        assert pool.is_request_in_cache("req-1")
        assert pool._request_to_blocks["req-1"] == set(hashes)
        for h in hashes:
            assert "req-1" in pool._block_to_requests[h]

    def test_shared_blocks_track_both_requests(self):
        pool = TestableBlockPool(num_blocks=20, eviction_callback=MagicMock())

        # Hash h1 is cached twice, via two different block ids.
        pool.cache_full_blocks_simple("req-A", [1, 2], [b"h1", b"h2"])
        pool.cache_full_blocks_simple("req-B", [3, 4], [b"h1", b"h3"])

        expected_blocks = {
            "req-A": {b"h1", b"h2"},
            "req-B": {b"h1", b"h3"},
        }
        assert pool._block_to_requests[b"h1"] == {"req-A", "req-B"}
        for rid, hs in expected_blocks.items():
            assert pool._request_to_blocks[rid] == hs

    def test_no_tracking_when_callback_is_none(self):
        pool = TestableBlockPool(eviction_callback=None)

        pool.cache_full_blocks_simple("req-1", [1, 2], [b"h1", b"h2"])

        assert not pool._block_to_requests
        assert not pool._request_to_blocks
290
+
291
+
292
class TestEvictionCallback:
    """The eviction callback fires exactly when a request loses its last cached hash."""

    def test_full_eviction_fires_callback(self):
        callback = MagicMock()
        pool = TestableBlockPool(num_blocks=10, eviction_callback=callback)

        # Cache 3 blocks for req-1 using blocks 1,2,3
        pool.cache_full_blocks_simple("req-1", [1, 2, 3], [b"h1", b"h2", b"h3"])

        # Evict all 3 blocks
        pool.evict_blocks({1, 2, 3})

        callback.assert_called_once()
        assert callback.call_args[0][0] == {"req-1"}
        assert not pool.is_request_in_cache("req-1")

    def test_partial_eviction_does_not_fire_callback(self):
        callback = MagicMock()
        pool = TestableBlockPool(num_blocks=10, eviction_callback=callback)

        pool.cache_full_blocks_simple("req-1", [1, 2, 3], [b"h1", b"h2", b"h3"])

        # Evict only 2 of 3 blocks — request still has h3
        pool.evict_blocks({1, 2})

        callback.assert_not_called()
        assert pool.is_request_in_cache("req-1")
        assert pool._request_to_blocks["req-1"] == {b"h3"}

    def test_evict_last_block_fires_callback(self):
        callback = MagicMock()
        pool = TestableBlockPool(num_blocks=10, eviction_callback=callback)

        pool.cache_full_blocks_simple("req-1", [1, 2, 3], [b"h1", b"h2", b"h3"])

        # Evict 2, then the last 1
        pool.evict_blocks({1, 2})
        callback.assert_not_called()

        pool.evict_blocks({3})
        callback.assert_called_once()
        assert callback.call_args[0][0] == {"req-1"}
        assert not pool.is_request_in_cache("req-1")

    def test_shared_hash_not_evicted_until_last_copy_removed(self):
        callback = MagicMock()
        pool = TestableBlockPool(num_blocks=12, eviction_callback=callback)

        # req-A: shared + unique, req-B: shared only
        pool.cache_full_blocks_simple("req-A", [1, 2], [b"h_shared", b"h_a"])
        pool.cache_full_blocks_simple("req-B", [3], [b"h_shared"])

        # Remove one shared copy + req-A unique block.
        # h_shared is still available via req-B's block.
        pool.evict_blocks({1, 2})

        callback.assert_not_called()
        assert pool.is_request_in_cache("req-A")
        assert pool.is_request_in_cache("req-B")

        # Remove final shared copy: now both requests are fully evicted.
        pool.evict_blocks({3})
        callback.assert_called_once()
        assert callback.call_args[0][0] == {"req-A", "req-B"}

    def test_multiple_requests_evicted_together(self):
        callback = MagicMock()
        pool = TestableBlockPool(num_blocks=10, eviction_callback=callback)

        pool.cache_full_blocks_simple("req-A", [1], [b"hA"])
        pool.cache_full_blocks_simple("req-B", [2], [b"hB"])

        pool.evict_blocks({1, 2})

        # One batched notification, not one call per request.
        callback.assert_called_once()
        assert callback.call_args[0][0] == {"req-A", "req-B"}

    def test_callback_not_called_when_none(self):
        pool = TestableBlockPool(eviction_callback=None)

        pool.cache_full_blocks_simple("req-1", [1], [b"h1"])
        # Should not raise
        pool.evict_blocks({1})

        # Eviction itself still happens even though nothing is tracked.
        assert pool.cached_block_hash_to_block.get_one_block(b"h1") is None
        assert pool.blocks[1].block_hash is None
        assert not pool._request_to_blocks

    def test_callback_exception_is_swallowed(self):
        callback = MagicMock(side_effect=Exception("network error"))
        pool = TestableBlockPool(num_blocks=10, eviction_callback=callback)

        pool.cache_full_blocks_simple("req-1", [1], [b"h1"])
        # Should not raise even though callback throws
        pool.evict_blocks({1})
        callback.assert_called_once()
        # Tracking is updated before notifying, so a failing callback
        # must not leave the request marked as cached.
        assert not pool.is_request_in_cache("req-1")
385
+
386
+
387
class TestGetNewBlocksEviction:
    """Reusing cached blocks via get_new_blocks() evicts them from tracking."""

    def test_allocating_cached_blocks_fires_callback(self):
        """When get_new_blocks pops cached blocks, eviction callback fires."""
        callback = MagicMock()
        pool = TestableBlockPool(num_blocks=10, eviction_callback=callback)

        # Allocate, cache and release three blocks: they return to the free
        # queue still carrying their hashes, i.e. as eviction candidates.
        allocated = pool.get_new_blocks(3)
        pool.cache_full_blocks_simple(
            "req-X", [blk.block_id for blk in allocated], [b"h1", b"h2", b"h3"]
        )
        pool.free_blocks(allocated)
        assert pool.is_request_in_cache("req-X")

        # Drain the whole free queue so the cached blocks are necessarily reused.
        remaining = pool.free_block_queue.num_free_blocks
        pool.get_new_blocks(remaining)

        callback.assert_called_once()
        evicted_ids = callback.call_args.args[0]
        assert "req-X" in evicted_ids
        assert not pool.is_request_in_cache("req-X")
407
+
408
+
409
class TestTouchCompatibility:
    """touch() accepts the grouped-blocks argument shape used upstream."""

    def test_touch_accepts_grouped_blocks(self):
        pool = TestableBlockPool(num_blocks=8, eviction_callback=None)
        pair = pool.get_new_blocks(2)
        pool.free_blocks(pair)

        # Upstream style: tuple[Sequence[KVCacheBlock], ...]
        pool.touch((pair,))

        assert [blk.ref_cnt for blk in pair] == [1, 1]
420
+
421
+
422
class TestResetPrefixCache:
    """reset_prefix_cache() reports every tracked request and clears tracking."""

    def test_reset_fires_callback_for_all(self):
        callback = MagicMock()
        pool = TestableBlockPool(num_blocks=10, eviction_callback=callback)

        entries = [(1, "req-A", b"hA"), (2, "req-B", b"hB"), (3, "req-C", b"hC")]
        for block_idx, rid, bh in entries:
            pool.cache_full_blocks_simple(rid, [block_idx], [bh])

        pool.reset_prefix_cache()

        callback.assert_called_once()
        assert callback.call_args[0][0] == {rid for _, rid, _ in entries}

        # Tracking should be cleared
        assert not pool._block_to_requests
        assert not pool._request_to_blocks

    def test_reset_with_no_tracked_requests(self):
        callback = MagicMock()

        TestableBlockPool(num_blocks=10, eviction_callback=callback).reset_prefix_cache()

        callback.assert_not_called()
449
+
450
+
451
class TestCallbackPrefixStripping:
    """Request-id prefix stripping expected of the eviction callback.

    NOTE(review): the pattern is declared once in ``_strip`` rather than being
    imported from the patched callback, so these tests pin the expected regex
    behaviour only; keep it in sync with the callback's pattern.
    """

    @staticmethod
    def _strip(request_ids):
        """Return ``request_ids`` with one leading vLLM id prefix removed from each."""
        import re

        # Anchored alternation: at most one prefix, and only at the start.
        prefix_re = re.compile(r"^(cmpl-|chatcmpl-|batch-)")
        return {prefix_re.sub("", rid) for rid in request_ids}

    def test_strips_cmpl_prefix(self):
        ids = {"cmpl-req-123", "chatcmpl-req-456", "batch-req-789", "plain-id"}
        assert self._strip(ids) == {"req-123", "req-456", "req-789", "plain-id"}

    def test_no_prefix_unchanged(self):
        ids = {"my-request-1", "another-req"}
        assert self._strip(ids) == {"my-request-1", "another-req"}

    def test_only_leading_prefix_is_stripped(self):
        # A prefix in the middle, or a second stacked prefix, is kept.
        ids = {"req-cmpl-1", "cmpl-chatcmpl-2"}
        assert self._strip(ids) == {"req-cmpl-1", "chatcmpl-2"}
470
+
471
+
472
class TestHelperMethods:
    """Introspection helpers exposed by the patched pool."""

    def test_get_tracked_request_ids(self):
        pool = TestableBlockPool(num_blocks=10, eviction_callback=MagicMock())

        for block_idx, (rid, bh) in enumerate([("req-A", b"hA"), ("req-B", b"hB")], start=1):
            pool.cache_full_blocks_simple(rid, [block_idx], [bh])

        assert pool.get_tracked_request_ids() == {"req-A", "req-B"}

    def test_is_request_in_cache(self):
        pool = TestableBlockPool(num_blocks=10, eviction_callback=MagicMock())
        pool.cache_full_blocks_simple("req-A", [1], [b"hA"])

        assert pool.is_request_in_cache("req-A")
        assert not pool.is_request_in_cache("req-B")

        # Evicting the request's only block drops it from tracking.
        pool.evict_blocks({1})
        assert not pool.is_request_in_cache("req-A")
File without changes
File without changes