infinite-context-gateway 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,108 @@
1
+ Metadata-Version: 2.4
2
+ Name: infinite-context-gateway
3
+ Version: 0.1.0
4
+ Summary: A hybrid AI context gateway combining Headroom cloud compression and In-Place Local Test-Time Training (TTT).
5
+ Author-email: Dev Team <dev@example.com>
6
+ Classifier: Programming Language :: Python :: 3
7
+ Classifier: License :: OSI Approved :: MIT License
8
+ Classifier: Operating System :: OS Independent
9
+ Requires-Python: >=3.9
10
+ Description-Content-Type: text/markdown
11
+ Requires-Dist: httpx
12
+ Provides-Extra: cloud
13
+ Requires-Dist: headroom-ai; extra == "cloud"
14
+ Provides-Extra: local
15
+ Requires-Dist: torch; extra == "local"
16
+ Requires-Dist: transformers; extra == "local"
17
+ Requires-Dist: peft; extra == "local"
18
+ Requires-Dist: accelerate; extra == "local"
19
+ Requires-Dist: bitsandbytes; extra == "local"
20
+ Provides-Extra: all
21
+ Requires-Dist: headroom-ai; extra == "all"
22
+ Requires-Dist: torch; extra == "all"
23
+ Requires-Dist: transformers; extra == "all"
24
+ Requires-Dist: peft; extra == "all"
25
+ Requires-Dist: accelerate; extra == "all"
26
+ Requires-Dist: bitsandbytes; extra == "all"
27
+
28
+ # Infinite Context
29
+
30
+ A hybrid AI context gateway combining **Headroom cloud compression** and **In-Place Local Test-Time Training (TTT)**.
31
+
32
+ ## Installation
33
+
34
+ To install the base gateway:
35
+ ```bash
36
+ pip install infinite-context-gateway
37
+ ```
38
+
39
+ To install with Cloud Phase 1 dependencies (Headroom API):
40
+ ```bash
41
+ pip install "infinite-context-gateway[cloud]"
42
+ ```
43
+
44
+ To install with Local Phase 2 dependencies (PyTorch, Transformers, PEFT):
45
+ ```bash
46
+ pip install "infinite-context-gateway[local]"
47
+ ```
48
+
49
+ ## Usage
50
+
51
+ ### Phase 1: Cloud Compression
52
+ Uses the Headroom API to semantically compress a massive context and send it to cloud providers (OpenAI, Anthropic) while keeping costs low.
53
+
54
+ ```python
55
+ from infinite_context import ContextGateway
56
+
57
+ gateway = ContextGateway(
58
+ engine="cloud",
59
+ model_id="claude-3-5-sonnet-20240620",
60
+ api_key="your_anthropic_api_key",
61
+ compression_ratio=0.8
62
+ )
63
+
64
+ # You can pass in conversational history (rolling window)
65
+ history = [
66
+ {"role": "user", "content": "What is the codebase about?"},
67
+ {"role": "assistant", "content": "It is a Python application..."}
68
+ ]
69
+
70
+ response = gateway.chat("How does the failover protocol work?", massive_context, history=history)
71
+ print(response)
72
+ ```
73
+
74
+ ### Phase 2: Local Test-Time Training (TTT)
75
+ Bypasses the KV-Cache entirely by injecting a PEFT LoRA adapter and baking the context directly into the model's neural weights on your local GPU. Includes Early Stopping latency optimizations and Generation repetition penalties.
76
+
77
+ ```python
78
+ from infinite_context import ContextGateway
79
+
80
+ gateway = ContextGateway(
81
+ engine="local",
82
+ model_id="Qwen/Qwen2.5-0.5B-Instruct",
83
+ load_in_4bit=True
84
+ )
85
+
86
+ response = gateway.chat("What is the failover protocol?", massive_context)
87
+ print(response)
88
+ ```
89
+
90
+ ### Checkpoint Persistence
91
+ Save trained Fast Weights to disk and resume later without re-reading the document:
92
+
93
+ ```python
94
+ from infinite_context import ContextGateway
95
+
96
+ # Train and keep state
97
+ gateway = ContextGateway(engine="local", model_id="Qwen/Qwen2.5-0.5B-Instruct")
98
+ gateway.chat("Summarise the document.", massive_context, keep_state=True)
99
+ gateway.save_state("./my_checkpoint")
100
+
101
+ del gateway # Free all GPU memory
102
+
103
+ # Resume later — zero re-training
104
+ gateway = ContextGateway(engine="local", model_id="Qwen/Qwen2.5-0.5B-Instruct")
105
+ gateway.load_state("./my_checkpoint")
106
+ response = gateway.chat("What was the protocol?", context="")
107
+ print(response)
108
+ ```
@@ -0,0 +1,81 @@
1
+ # Infinite Context
2
+
3
+ A hybrid AI context gateway combining **Headroom cloud compression** and **In-Place Local Test-Time Training (TTT)**.
4
+
5
+ ## Installation
6
+
7
+ To install the base gateway:
8
+ ```bash
9
+ pip install infinite-context-gateway
10
+ ```
11
+
12
+ To install with Cloud Phase 1 dependencies (Headroom API):
13
+ ```bash
14
+ pip install "infinite-context-gateway[cloud]"
15
+ ```
16
+
17
+ To install with Local Phase 2 dependencies (PyTorch, Transformers, PEFT):
18
+ ```bash
19
+ pip install "infinite-context-gateway[local]"
20
+ ```
21
+
22
+ ## Usage
23
+
24
+ ### Phase 1: Cloud Compression
25
+ Uses the Headroom API to semantically compress a massive context and send it to cloud providers (OpenAI, Anthropic) while keeping costs low.
26
+
27
+ ```python
28
+ from infinite_context import ContextGateway
29
+
30
+ gateway = ContextGateway(
31
+ engine="cloud",
32
+ model_id="claude-3-5-sonnet-20240620",
33
+ api_key="your_anthropic_api_key",
34
+ compression_ratio=0.8
35
+ )
36
+
37
+ # You can pass in conversational history (rolling window)
38
+ history = [
39
+ {"role": "user", "content": "What is the codebase about?"},
40
+ {"role": "assistant", "content": "It is a Python application..."}
41
+ ]
42
+
43
+ response = gateway.chat("How does the failover protocol work?", massive_context, history=history)
44
+ print(response)
45
+ ```
46
+
47
+ ### Phase 2: Local Test-Time Training (TTT)
48
+ Bypasses the KV-Cache entirely by injecting a PEFT LoRA adapter and baking the context directly into the model's neural weights on your local GPU. Includes Early Stopping latency optimizations and Generation repetition penalties.
49
+
50
+ ```python
51
+ from infinite_context import ContextGateway
52
+
53
+ gateway = ContextGateway(
54
+ engine="local",
55
+ model_id="Qwen/Qwen2.5-0.5B-Instruct",
56
+ load_in_4bit=True
57
+ )
58
+
59
+ response = gateway.chat("What is the failover protocol?", massive_context)
60
+ print(response)
61
+ ```
62
+
63
+ ### Checkpoint Persistence
64
+ Save trained Fast Weights to disk and resume later without re-reading the document:
65
+
66
+ ```python
67
+ from infinite_context import ContextGateway
68
+
69
+ # Train and keep state
70
+ gateway = ContextGateway(engine="local", model_id="Qwen/Qwen2.5-0.5B-Instruct")
71
+ gateway.chat("Summarise the document.", massive_context, keep_state=True)
72
+ gateway.save_state("./my_checkpoint")
73
+
74
+ del gateway # Free all GPU memory
75
+
76
+ # Resume later — zero re-training
77
+ gateway = ContextGateway(engine="local", model_id="Qwen/Qwen2.5-0.5B-Instruct")
78
+ gateway.load_state("./my_checkpoint")
79
+ response = gateway.chat("What was the protocol?", context="")
80
+ print(response)
81
+ ```
@@ -0,0 +1,10 @@
1
+ """
2
+ Infinite Context — A hybrid AI context gateway combining Headroom cloud
3
+ compression and In-Place Local Test-Time Training (TTT).
4
+ """
5
+
6
+ __version__ = "0.1.0"
7
+
8
+ from .gateway import ContextGateway
9
+
10
+ __all__ = ["ContextGateway"]
@@ -0,0 +1,6 @@
1
+ """Cloud engine components — API client and Headroom compression."""
2
+
3
+ from .client import CloudClient
4
+ from .compressor import HeadroomCompressor
5
+
6
+ __all__ = ["CloudClient", "HeadroomCompressor"]
@@ -0,0 +1,127 @@
1
+ """
2
+ Cloud API client for OpenAI, Anthropic, and OpenAI-compatible endpoints.
3
+
4
+ Handles payload formatting, header construction, and SSE streaming for
5
+ providers like Groq, Together, etc.
6
+ """
7
+
8
+ import json
9
+ import httpx
10
+ from typing import Dict, Generator, List
11
+
12
+ __all__ = ["CloudClient"]
13
+
14
+ # Shared timeout: 30 s connect, 120 s read (large prompts can be slow).
15
+ _DEFAULT_TIMEOUT = httpx.Timeout(connect=30.0, read=120.0, write=30.0, pool=10.0)
16
+
17
+
18
class CloudClient:
    """HTTP client for cloud LLM providers.

    Two wire formats are supported: Anthropic's Messages API and the OpenAI
    chat-completions format (also used for any custom ``base_url``).
    """

    def __init__(
        self,
        api_key: str,
        model_id: str,
        base_url: str = None,
        max_tokens: int = 4096,
    ):
        """
        Args:
            api_key: Bearer / x-api-key token.
            model_id: Model identifier (e.g. ``claude-3-5-sonnet-20240620``).
            base_url: Custom endpoint URL. When set, the OpenAI chat-
                completions format is assumed.
            max_tokens: Completion token budget sent with Anthropic
                requests (the OpenAI-format payload does not include it).
        """
        self.api_key = api_key
        self.model_id = model_id
        self.max_tokens = max_tokens

        # Determine endpoint and provider format. A custom URL always wins;
        # otherwise the provider is guessed from the model name.
        if base_url:
            self.base_url = base_url
            self.provider = "openai"
        elif "claude" in self.model_id.lower():
            self.base_url = "https://api.anthropic.com/v1/messages"
            self.provider = "anthropic"
        else:
            self.base_url = "https://api.openai.com/v1/chat/completions"
            self.provider = "openai"

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _prepare_headers(self) -> Dict[str, str]:
        """Return the authentication and content-type headers for the provider."""
        if self.provider == "anthropic":
            return {
                "x-api-key": self.api_key,
                "anthropic-version": "2023-06-01",
                "content-type": "application/json",
            }
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _prepare_payload(
        self,
        messages: List[Dict[str, str]],
        stream: bool = False,
    ) -> dict:
        """Build the JSON request body in the provider's format.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            stream: Whether to request a streaming (SSE) response.
        """
        if self.provider != "anthropic":
            return {
                "model": self.model_id,
                "messages": messages,
                "stream": stream,
            }

        # Anthropic splits the system prompt out of the messages array.
        # Every system message is kept (joined in order) rather than only
        # the first, so no caller-supplied instructions are dropped.
        system_parts = [m["content"] for m in messages if m["role"] == "system"]
        payload = {
            "model": self.model_id,
            "messages": [m for m in messages if m["role"] != "system"],
            "max_tokens": self.max_tokens,
            "stream": stream,
        }
        system_msg = "\n\n".join(part for part in system_parts if part)
        if system_msg:
            # Only send the field when there is something to say.
            payload["system"] = system_msg
        return payload

    def _extract_delta(self, data_str: str) -> str:
        """Return the text fragment carried by one SSE ``data:`` payload.

        Returns an empty string for events without text and for malformed
        chunks, so the caller can skip them.
        """
        try:
            chunk = json.loads(data_str)
            if self.provider == "anthropic":
                # Text arrives in ``content_block_delta`` events; the other
                # event types (message_start, ping, ...) carry no text.
                if chunk.get("type") != "content_block_delta":
                    return ""
                return chunk["delta"].get("text") or ""
            delta = chunk["choices"][0].get("delta", {})
            return delta.get("content") or ""
        except (json.JSONDecodeError, KeyError, IndexError, AttributeError, TypeError):
            # Malformed chunk — skip gracefully.
            return ""

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def send_request(self, messages: List[Dict[str, str]]) -> str:
        """Send a non-streaming chat completion request.

        Returns:
            The text of the first content block (Anthropic) or of the first
            choice's message (OpenAI format).

        Raises:
            httpx.HTTPStatusError: If the provider answers with an error
                status (via ``raise_for_status``).
        """
        payload = self._prepare_payload(messages, stream=False)
        headers = self._prepare_headers()

        with httpx.Client(timeout=_DEFAULT_TIMEOUT) as client:
            response = client.post(self.base_url, json=payload, headers=headers)
            response.raise_for_status()
            data = response.json()

        if self.provider == "anthropic":
            return data["content"][0]["text"]
        return data["choices"][0]["message"]["content"]

    def stream_request(
        self, messages: List[Dict[str, str]]
    ) -> Generator[str, None, None]:
        """Yield text deltas from an SSE streaming response.

        Both the OpenAI chunk format and Anthropic's ``content_block_delta``
        events are understood; lines that are not ``data:`` payloads and
        chunks without text are skipped.
        """
        payload = self._prepare_payload(messages, stream=True)
        headers = self._prepare_headers()

        with httpx.Client(timeout=_DEFAULT_TIMEOUT) as client:
            with client.stream(
                "POST", self.base_url, json=payload, headers=headers
            ) as response:
                response.raise_for_status()
                for line in response.iter_lines():
                    if not line or not line.startswith("data:"):
                        continue
                    data_str = line[len("data:"):].strip()
                    if data_str == "[DONE]":
                        # End-of-stream sentinel of the OpenAI format.
                        return
                    text = self._extract_delta(data_str)
                    if text:
                        yield text
@@ -0,0 +1,44 @@
1
+ """
2
+ Headroom semantic compression wrapper.
3
+
4
+ Compresses massive context payloads locally before sending them to a cloud
5
+ LLM, reducing token counts and API costs.
6
+ """
7
+
8
+ __all__ = ["HeadroomCompressor"]
9
+
10
+
11
class HeadroomCompressor:
    """Wraps the ``headroom-ai`` library with a graceful fallback.

    Args:
        target_ratio: Desired compression ratio. It is stored on the
            instance but is not currently passed to the Headroom function
            nor used by the fallback truncation.
    """

    # Fallback simulator budget: characters kept from each end, and the
    # marker spliced in between them.
    _FALLBACK_HEAD_CHARS = 1000
    _FALLBACK_TAIL_CHARS = 2000
    _FALLBACK_MARKER = "\n\n...[CONTENT COMPRESSED BY HEADROOM SIMULATOR]...\n\n"

    def __init__(self, target_ratio: float = 0.8):
        self.target_ratio = target_ratio
        self._compress_fn = None

        try:
            from headroom import compress
            self._compress_fn = compress
        except ImportError:
            print(
                "Warning: headroom-ai not installed. "
                "Using fallback semantic compression simulator."
            )

    def compress(self, context: str) -> str:
        """Compress *context* using Headroom's algorithms.

        Falls back to a head + tail truncation when the library is absent.
        The fallback never returns something longer than its input: if
        head + marker + tail would not be shorter than *context*, the text
        is returned unchanged.
        """
        if self._compress_fn is not None:
            return self._compress_fn(context)

        # Fallback simulator: keep the first 1000 chars and the last 2000
        # chars (where needles usually are).
        head = self._FALLBACK_HEAD_CHARS
        tail = self._FALLBACK_TAIL_CHARS
        marker = self._FALLBACK_MARKER

        # Splicing the marker in only pays off once more than its own
        # length is cut from the middle.
        if len(context) <= head + tail + len(marker):
            return context

        return context[:head] + marker + context[-tail:]
@@ -0,0 +1,145 @@
1
+ """
2
+ Context Gateway — unified interface for the cloud and local engines.
3
+
4
+ Users interact exclusively through :class:`ContextGateway`, selecting either
5
+ ``engine="cloud"`` (Headroom compression + API) or ``engine="local"``
6
+ (In-Place Test-Time Training on a local GPU).
7
+ """
8
+
9
+ from typing import Any, Generator, List, Optional
10
+
11
+ __all__ = ["ContextGateway"]
12
+
13
+
14
class ContextGateway:
    """Hybrid gateway that routes queries through either a cloud API or a
    local TTT engine depending on configuration.

    Args:
        engine: ``"cloud"`` for Headroom API compression, or ``"local"``
            for In-Place TTT.
        model_id: Target model (e.g. ``"claude-3-5-sonnet-20240620"`` or a
            local HuggingFace repo id).
        api_key: API key for cloud routing (if applicable).
        base_url: Custom base URL for OpenAI-compatible providers like Groq.
        compression_ratio: Target compression ratio for Headroom (cloud only).
        skip_compression: If *True*, skips Headroom compression (useful for
            testing alternative APIs directly).
        **kwargs: Additional parameters forwarded to :class:`LocalEngine`
            (e.g. ``device``, ``load_in_4bit``). Not used by the cloud
            engine.

    Raises:
        ValueError: If *engine* is neither ``"cloud"`` nor ``"local"``.
    """

    def __init__(
        self,
        engine: str = "cloud",
        model_id: str = "claude-3-5-sonnet-20240620",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        compression_ratio: float = 0.8,
        skip_compression: bool = False,
        **kwargs: Any,
    ):
        if engine not in ("cloud", "local"):
            raise ValueError("Engine must be either 'cloud' or 'local'.")

        self.engine = engine
        self.model_id = model_id

        # Engine modules are imported lazily so that each engine's
        # dependencies are only needed when that engine is selected.
        if self.engine == "cloud":
            from .cloud.compressor import HeadroomCompressor
            from .cloud.client import CloudClient

            self.compressor = (
                HeadroomCompressor(target_ratio=compression_ratio)
                if not skip_compression
                else None
            )
            self.client = CloudClient(
                api_key=api_key, model_id=self.model_id, base_url=base_url,
            )
        else:  # local
            from .local.engine import LocalEngine
            self.local_engine = LocalEngine(model_id=self.model_id, **kwargs)

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _build_messages(
        self,
        question: str,
        context: str,
        history: Optional[List[dict]],
        max_history: int,
    ) -> List[dict]:
        """Assemble the cloud message list: system context, history, question.

        The context is compressed first unless compression was disabled.
        Only the last *max_history* history entries are kept; zero or a
        negative value keeps none (a bare ``history[-0:]`` slice would
        return the whole list).
        """
        compressed = (
            self.compressor.compress(context)
            if self.compressor is not None
            else context
        )

        messages = [{"role": "system", "content": compressed}]
        if history and max_history > 0:
            messages.extend(history[-max_history:])
        messages.append({"role": "user", "content": question})
        return messages

    # ------------------------------------------------------------------
    # Chat
    # ------------------------------------------------------------------

    def chat(
        self,
        question: str,
        context: str,
        history: Optional[List[dict]] = None,
        max_history: int = 4,
        keep_state: bool = False,
    ) -> str:
        """Process a massive context payload and answer the question.

        Includes a rolling-window history to keep token costs low.

        Args:
            question: The user question.
            context: The (possibly very large) context document.
            history: Previous chat messages; used by the cloud engine only.
            max_history: Number of trailing history entries to send; used
                by the cloud engine only.
            keep_state: Forwarded to the local engine; ignored by the cloud
                engine.
        """
        if self.engine == "cloud":
            messages = self._build_messages(question, context, history, max_history)
            return self.client.send_request(messages)

        # local
        return self.local_engine.generate(
            question, context, keep_state=keep_state,
        )

    def stream_chat(
        self,
        question: str,
        context: str,
        history: Optional[List[dict]] = None,
        max_history: int = 4,
    ) -> Generator[str, None, None]:
        """Yield response tokens as a stream (cloud engine only).

        Raises:
            NotImplementedError: Immediately at call time (not on first
                iteration) when the gateway uses the local engine.
        """
        # This is a plain function returning the client's generator rather
        # than a generator function, so the engine check fails fast.
        if self.engine != "cloud":
            raise NotImplementedError(
                "Streaming is not supported by the local TTT engine."
            )

        messages = self._build_messages(question, context, history, max_history)
        return self.client.stream_request(messages)

    # ------------------------------------------------------------------
    # Persistence (local engine only)
    # ------------------------------------------------------------------

    def save_state(self, path: str) -> None:
        """Save trained Fast Weights to disk for later resumption.

        Raises:
            NotImplementedError: If the gateway uses the cloud engine.
        """
        if self.engine == "local":
            self.local_engine.save_state(path)
        else:
            raise NotImplementedError(
                "Cloud engine does not support local state saving."
            )

    def load_state(self, path: str) -> None:
        """Reload previously-saved Fast Weights from disk.

        Raises:
            NotImplementedError: If the gateway uses the cloud engine.
        """
        if self.engine == "local":
            self.local_engine.load_state(path)
        else:
            raise NotImplementedError(
                "Cloud engine does not support local state loading."
            )
@@ -0,0 +1,5 @@
1
+ """Local engine components — In-Place Test-Time Training (TTT) pipeline."""
2
+
3
+ from .engine import LocalEngine
4
+
5
+ __all__ = ["LocalEngine"]
@@ -0,0 +1,41 @@
1
+ """
2
+ Token chunk manager for the local TTT engine.
3
+
4
+ Splits massive token sequences into fixed-size windows so they fit inside
5
+ the model's KV-Cache limits during the Test-Time Training reading phase.
6
+ """
7
+
8
+ from typing import List
9
+
10
+ __all__ = ["ChunkManager"]
11
+
12
+
13
class ChunkManager:
    """Splits token lists into overlapping fixed-size chunks.

    Args:
        chunk_size: Number of tokens per chunk.
        overlap: Number of overlapping tokens between consecutive chunks
            (helps preserve context continuity).
    """

    def __init__(self, chunk_size: int = 1024, overlap: int = 0):
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk_tokens(self, tokens: List[int]) -> List[List[int]]:
        """Split *tokens* into smaller manageable chunks.

        Consecutive chunks start ``chunk_size - overlap`` tokens apart.
        Chunking stops as soon as a chunk reaches the end of *tokens*, so
        no trailing chunk made up only of already-covered overlap tokens
        is produced.

        Returns:
            The list of chunks; empty when *tokens* is empty.

        Raises:
            ValueError: If ``overlap`` is negative, or is greater than or
                equal to ``chunk_size``.
        """
        if not tokens:
            return []

        if self.overlap < 0:
            # A negative overlap would make the stride exceed the chunk
            # size and silently skip tokens between chunks.
            raise ValueError("Overlap cannot be negative.")

        step = self.chunk_size - self.overlap

        # Ensure we always move forward.
        if step <= 0:
            raise ValueError("Overlap cannot be greater than or equal to chunk_size.")

        total = len(tokens)
        chunks = []
        for start in range(0, total, step):
            end = start + self.chunk_size
            chunks.append(tokens[start:end])
            if end >= total:
                # This chunk already covers the tail of the sequence.
                break

        return chunks
@@ -0,0 +1,195 @@
1
+ """
2
+ Local Test-Time Training (TTT) engine.
3
+
4
+ Orchestrates the full pipeline: model loading → chunking → LoRA injection →
5
+ reading phase (training) → answering phase (inference) → state management.
6
+ """
7
+
8
+ import os
9
+ import torch
10
+ from typing import Any
11
+
12
+ from .chunk_manager import ChunkManager
13
+ from .ttt_module import InPlaceTTT
14
+
15
+ __all__ = ["LocalEngine"]
16
+
17
# Lazy import guard — resolved once at first use, not at import time.
# Tri-state: None = not probed yet, True/False = cached probe result.
# Deliberately left unannotated: a module-level ``bool | None`` annotation
# is evaluated at import time and raises TypeError on Python 3.9, which the
# package metadata declares as supported.
_transformers_available = None


def _ensure_transformers():
    """Import transformers on first use and raise a clear error if missing.

    The probe result is cached in the module-level flag, so the import is
    attempted at most once.

    Raises:
        ImportError: If ``transformers`` cannot be imported.
    """
    global _transformers_available
    if _transformers_available is None:
        try:
            import transformers  # noqa: F401
        except ImportError:
            _transformers_available = False
        else:
            _transformers_available = True
    if not _transformers_available:
        # The message names the published distribution
        # (infinite-context-gateway), not the import package.
        raise ImportError(
            "Transformers is required for the local engine. "
            'Install via: pip install "infinite-context-gateway[local]"'
        )
35
+
36
+
37
class LocalEngine:
    """Orchestrates the Local Test-Time Training (TTT) engine.

    Loads the model in constrained memory (4-bit) and manages the Reading
    and Answering phases.
    """

    def __init__(
        self,
        model_id: str,
        device: str = "cuda",
        load_in_4bit: bool = True,
        **kwargs: Any,
    ):
        """
        Args:
            model_id: HuggingFace repo id (or path) of the base model.
            device: Requested device; replaced by ``"cpu"`` whenever CUDA
                is not available.
            load_in_4bit: Request 4-bit quantization (only applied when
                running on ``"cuda"``).
            **kwargs: Accepted for forward compatibility; currently unused.

        Raises:
            ImportError: If ``transformers`` is not installed.
        """
        _ensure_transformers()

        self.device = device if torch.cuda.is_available() else "cpu"
        self.model_id = model_id

        print(f"Loading base model {model_id} via Transformers into {self.device}...")

        from transformers import AutoModelForCausalLM, AutoTokenizer

        quantization_kwargs: dict = {}
        if load_in_4bit and self.device == "cuda":
            try:
                from transformers import BitsAndBytesConfig
                quantization_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                )
            except ImportError:
                # Best effort: fall back to an unquantized load, but say so
                # instead of silently ignoring the request.
                print(
                    "Warning: 4-bit quantization support is unavailable. "
                    "Loading the model without quantization."
                )

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto" if self.device == "cuda" else None,
            **quantization_kwargs,
        )

        self.chunker = ChunkManager(chunk_size=1024, overlap=128)

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _has_fast_weights(self) -> bool:
        """Return True when ``self.model`` is a PEFT-wrapped model.

        An explicit ``isinstance`` check is used instead of probing for a
        ``peft_config`` attribute.
        NOTE(review): PEFT appears to also set ``peft_config`` on the
        wrapped base model, which would make an attribute probe report
        adapters that were already deleted — confirm against the installed
        PEFT version.
        """
        try:
            from peft import PeftModel
        except ImportError:
            # Without PEFT there cannot be any adapter attached.
            return False
        return isinstance(self.model, PeftModel)

    # ------------------------------------------------------------------
    # Core pipeline
    # ------------------------------------------------------------------

    def generate(
        self,
        question: str,
        context: str,
        epochs_per_chunk: int = 15,
        target_loss: float = 0.05,
        keep_state: bool = False,
    ) -> str:
        """Run the full In-Place TTT pipeline.

        1. Break massive context into chunks.
        2. Inject Fast Weights (LoRA) into the model.
        3. Reading Phase — train the Fast Weights on the chunks.
        4. Answering Phase — generate the response to the question.
        5. State Management — optionally retain or destroy Fast Weights.

        Args:
            question: The question to answer.
            context: Document text to train the Fast Weights on.
            epochs_per_chunk: Maximum training steps per chunk.
            target_loss: Training on a chunk stops early once the loss
                drops below this value.
            keep_state: Keep the trained adapter attached after answering.

        Returns:
            The decoded newly generated tokens.
        """
        # 1. Chunking
        tokens = self.tokenizer.encode(context, add_special_tokens=False)
        chunks = self.chunker.chunk_tokens(tokens)

        print(f"Bypassing KV-Cache: Splitting {len(tokens)} tokens into {len(chunks)} chunks.")

        # 2. Attach TTT Fast Weights (skip if a checkpoint was loaded)
        is_peft = self._has_fast_weights()
        ttt = None
        if not is_peft:
            print("Injecting TTT Fast Weights into down_proj layers...")
            ttt = InPlaceTTT(self.model)
        else:
            print("Model already has Fast Weights loaded from disk. Reusing them...")

        # 3. The 'Reading' Phase
        if chunks:
            print("Starting Reading Phase (Test-Time Training)...")
            if ttt is None:
                print("Warning: Cannot train loaded Fast Weights in-place. Skipping training.")
            else:
                for i, chunk in enumerate(chunks):
                    # Hoist tensor creation outside the epoch loop.
                    input_ids = torch.tensor([chunk], device=self.device)
                    # Pre-initialised so the progress line below is valid
                    # even when epochs_per_chunk is 0.
                    loss = float("nan")
                    epochs_run = 0
                    for epoch in range(epochs_per_chunk):
                        loss = ttt.train_on_chunk(input_ids)
                        epochs_run = epoch + 1
                        if loss < target_loss:
                            break
                    print(f" Processed Chunk {i + 1}/{len(chunks)} - Epochs: {epochs_run} - Loss: {loss:.4f}")

                    # KV Cache is implicitly cleared here because we are not passing
                    # past_key_values!

        # 4. The 'Answering' Phase
        print("Starting Answering Phase...")
        self.model.eval()

        prompt = f"Based on the codebase I just showed you, answer this question:\n{question}\nAnswer:"
        prompt_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                prompt_ids,
                max_new_tokens=75,
                do_sample=False,
                repetition_penalty=1.1,
            )

        # Extract just the newly generated tokens.
        response_tokens = output_ids[0][prompt_ids.shape[1]:]
        response = self.tokenizer.decode(response_tokens, skip_special_tokens=True)

        # 5. State Management
        if not keep_state:
            if ttt is not None:
                print("Resetting model state (deleting Fast Weights)...")
                ttt.remove_fast_weights()
            else:
                print("Cannot reset loaded Fast Weights. Retaining state.")
        else:
            print("Keeping model state (Fast Weights retained)...")
            if ttt is not None:
                # Update the reference so save_pretrained saves the adapter.
                self.model = ttt.model

        return response

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_state(self, path: str) -> None:
        """Save the current Fast Weights (LoRA adapter) to disk.

        Raises:
            RuntimeError: If no adapter is currently attached.
        """
        if self._has_fast_weights():
            self.model.save_pretrained(path)
            print(f"State saved to {path}")
        else:
            raise RuntimeError(
                "No active Fast Weights to save. "
                "Did you forget keep_state=True when calling generate()?"
            )

    def load_state(self, path: str) -> None:
        """Load Fast Weights (LoRA adapter) from disk.

        The ``ttt_fast_weights`` sub-directory of *path* is used when it
        exists; otherwise *path* itself is treated as the adapter directory.
        """
        from peft import PeftModel

        adapter_path = os.path.join(path, "ttt_fast_weights")
        if not os.path.exists(adapter_path):
            # Fallback for standard PEFT saves.
            adapter_path = path
        self.model = PeftModel.from_pretrained(
            self.model, adapter_path, is_trainable=True,
        )
        print(f"State loaded from {adapter_path}")
@@ -0,0 +1,133 @@
1
+ """
2
+ In-Place Test-Time Training (TTT) module.
3
+
4
+ Attaches transient LoRA fast-weight adapters to an LLM's MLP layers and
5
+ trains them on incoming context chunks so the model "memorises" the document
6
+ into its weights at inference time.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from typing import List, Optional
12
+
13
+ __all__ = ["InPlaceTTT"]
14
+
15
# Lazy guard — resolved once at first use, not at import time.
# Tri-state: None = not probed yet, True/False = cached probe result.
_peft_available: Optional[bool] = None


def _ensure_peft():
    """Import peft on first use and raise a clear error if missing.

    The probe result is cached in the module-level flag, so the import is
    attempted at most once.

    Raises:
        ImportError: If ``peft`` cannot be imported.
    """
    global _peft_available
    if _peft_available is None:
        try:
            import peft  # noqa: F401
        except ImportError:
            _peft_available = False
        else:
            _peft_available = True
    if not _peft_available:
        # The message names the published distribution
        # (infinite-context-gateway), not the import package.
        raise ImportError(
            "PEFT is required for the local engine. "
            'Install via: pip install "infinite-context-gateway[local]"'
        )
33
+
34
+
35
class InPlaceTTT:
    """Manages LoRA adapter injection, test-time training, and teardown."""

    def __init__(
        self,
        model: nn.Module,
        target_modules: Optional[List[str]] = None,
        lr: float = 5e-3,
        adapter_name: str = "ttt_fast_weights",
    ):
        """
        Args:
            model: The Hugging Face PreTrainedModel.
            target_modules: MLP layer names to attach fast weights to.
                Defaults to ``["down_proj"]``.
            lr: Learning rate for the test-time training update.
            adapter_name: Name of the PEFT adapter used for the fast weights.

        Raises:
            ImportError: If ``peft`` is not installed.
        """
        _ensure_peft()

        self.model = model
        self.target_modules = target_modules or ["down_proj"]
        self.lr = lr
        self.adapter_name = adapter_name
        self.optimizer: Optional[torch.optim.Optimizer] = None

        self._setup_lora()

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _setup_lora(self) -> None:
        """Inject Fast Weights (LoRA) into the base model.

        By setting ``r=8`` and targeting ``down_proj``, we turn the static
        MLP into an active, online memory bank. ``self.model`` is replaced
        by the PEFT-wrapped model, and an AdamW optimizer is created over
        the parameters that require gradients.
        """
        from peft import LoraConfig, get_peft_model
        config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=self.target_modules,
            lora_dropout=0.0,
            bias="none",
            task_type="CAUSAL_LM",
            use_dora=True,
        )
        self.model = get_peft_model(self.model, config, adapter_name=self.adapter_name)

        # Train ONLY the LoRA parameters — not the frozen base weights.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(trainable_params, lr=self.lr)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def train_on_chunk(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> float:
        """Run one gradient step on *input_ids* (the 'Reading Phase').

        Uses the standard causal next-token prediction loss which achieves
        the same goal as the paper's Conv1D LM-Aligned objective: baking
        the chunk's knowledge into the ``down_proj`` weights.

        Args:
            input_ids: Token ids for the chunk; also used as the labels.
            attention_mask: Optional mask; defaults to all ones.

        Returns:
            The loss of this step as a Python float. When the loss is not
            finite the parameter update is skipped and the non-finite
            value is returned as-is.

        Raises:
            RuntimeError: If the fast weights were already removed.
        """
        if self.optimizer is None:
            raise RuntimeError(
                "Fast Weights have been removed; create a new InPlaceTTT "
                "instance to train again."
            )

        self.model.train()
        self.optimizer.zero_grad(set_to_none=True)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Mixed-precision forward pass (autocast enabled on CUDA only).
        device_type = input_ids.device.type
        with torch.amp.autocast(device_type, enabled=(device_type == "cuda")):
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
            )
            loss = outputs.loss

        # A NaN/inf loss would produce non-finite gradients and poison the
        # adapter weights, so skip the update for this step.
        if not torch.isfinite(loss):
            return loss.item()

        # Backward + update ONLY the LoRA adapters.
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def remove_fast_weights(self) -> None:
        """Delete the LoRA adapter ('State Destruction').

        Puts the model in eval mode, deletes the adapter named
        ``adapter_name`` and drops the optimizer, so the optimizer state
        no longer keeps the deleted adapter parameters alive.
        """
        self.model.eval()
        self.model.delete_adapter(self.adapter_name)
        self.optimizer = None
@@ -0,0 +1,108 @@
1
+ Metadata-Version: 2.4
2
+ Name: infinite-context-gateway
3
+ Version: 0.1.0
4
+ Summary: A hybrid AI context gateway combining Headroom cloud compression and In-Place Local Test-Time Training (TTT).
5
+ Author-email: Dev Team <dev@example.com>
6
+ Classifier: Programming Language :: Python :: 3
7
+ Classifier: License :: OSI Approved :: MIT License
8
+ Classifier: Operating System :: OS Independent
9
+ Requires-Python: >=3.9
10
+ Description-Content-Type: text/markdown
11
+ Requires-Dist: httpx
12
+ Provides-Extra: cloud
13
+ Requires-Dist: headroom-ai; extra == "cloud"
14
+ Provides-Extra: local
15
+ Requires-Dist: torch; extra == "local"
16
+ Requires-Dist: transformers; extra == "local"
17
+ Requires-Dist: peft; extra == "local"
18
+ Requires-Dist: accelerate; extra == "local"
19
+ Requires-Dist: bitsandbytes; extra == "local"
20
+ Provides-Extra: all
21
+ Requires-Dist: headroom-ai; extra == "all"
22
+ Requires-Dist: torch; extra == "all"
23
+ Requires-Dist: transformers; extra == "all"
24
+ Requires-Dist: peft; extra == "all"
25
+ Requires-Dist: accelerate; extra == "all"
26
+ Requires-Dist: bitsandbytes; extra == "all"
27
+
28
+ # Infinite Context
29
+
30
+ A hybrid AI context gateway combining **Headroom cloud compression** and **In-Place Local Test-Time Training (TTT)**.
31
+
32
+ ## Installation
33
+
34
+ To install the base gateway:
35
+ ```bash
36
+ pip install infinite-context-gateway
37
+ ```
38
+
39
+ To install with Cloud Phase 1 dependencies (Headroom API):
40
+ ```bash
41
+ pip install "infinite-context-gateway[cloud]"
42
+ ```
43
+
44
+ To install with Local Phase 2 dependencies (PyTorch, Transformers, PEFT):
45
+ ```bash
46
+ pip install "infinite-context-gateway[local]"
47
+ ```
48
+
49
+ ## Usage
50
+
51
+ ### Phase 1: Cloud Compression
52
+ Uses the Headroom API to semantically compress a massive context and send it to cloud providers (OpenAI, Anthropic) while keeping costs low.
53
+
54
+ ```python
55
+ from infinite_context import ContextGateway
56
+
57
+ gateway = ContextGateway(
58
+ engine="cloud",
59
+ model_id="claude-3-5-sonnet-20240620",
60
+ api_key="your_anthropic_api_key",
61
+ compression_ratio=0.8
62
+ )
63
+
64
+ # You can pass in conversational history (rolling window)
65
+ history = [
66
+ {"role": "user", "content": "What is the codebase about?"},
67
+ {"role": "assistant", "content": "It is a Python application..."}
68
+ ]
69
+
70
+ response = gateway.chat("How does the failover protocol work?", massive_context, history=history)
71
+ print(response)
72
+ ```
73
+
74
+ ### Phase 2: Local Test-Time Training (TTT)
75
+ Bypasses the KV-Cache entirely by injecting a PEFT LoRA adapter and baking the context directly into the model's neural weights on your local GPU. Includes Early Stopping latency optimizations and Generation repetition penalties.
76
+
77
+ ```python
78
+ from infinite_context import ContextGateway
79
+
80
+ gateway = ContextGateway(
81
+ engine="local",
82
+ model_id="Qwen/Qwen2.5-0.5B-Instruct",
83
+ load_in_4bit=True
84
+ )
85
+
86
+ response = gateway.chat("What is the failover protocol?", massive_context)
87
+ print(response)
88
+ ```
89
+
90
+ ### Checkpoint Persistence
91
+ Save trained Fast Weights to disk and resume later without re-reading the document:
92
+
93
+ ```python
94
+ from infinite_context import ContextGateway
95
+
96
+ # Train and keep state
97
+ gateway = ContextGateway(engine="local", model_id="Qwen/Qwen2.5-0.5B-Instruct")
98
+ gateway.chat("Summarise the document.", massive_context, keep_state=True)
99
+ gateway.save_state("./my_checkpoint")
100
+
101
+ del gateway # Free all GPU memory
102
+
103
+ # Resume later — zero re-training
104
+ gateway = ContextGateway(engine="local", model_id="Qwen/Qwen2.5-0.5B-Instruct")
105
+ gateway.load_state("./my_checkpoint")
106
+ response = gateway.chat("What was the protocol?", context="")
107
+ print(response)
108
+ ```
@@ -0,0 +1,16 @@
1
+ README.md
2
+ pyproject.toml
3
+ infinite_context/__init__.py
4
+ infinite_context/gateway.py
5
+ infinite_context/cloud/__init__.py
6
+ infinite_context/cloud/client.py
7
+ infinite_context/cloud/compressor.py
8
+ infinite_context/local/__init__.py
9
+ infinite_context/local/chunk_manager.py
10
+ infinite_context/local/engine.py
11
+ infinite_context/local/ttt_module.py
12
+ infinite_context_gateway.egg-info/PKG-INFO
13
+ infinite_context_gateway.egg-info/SOURCES.txt
14
+ infinite_context_gateway.egg-info/dependency_links.txt
15
+ infinite_context_gateway.egg-info/requires.txt
16
+ infinite_context_gateway.egg-info/top_level.txt
@@ -0,0 +1,19 @@
1
+ httpx
2
+
3
+ [all]
4
+ headroom-ai
5
+ torch
6
+ transformers
7
+ peft
8
+ accelerate
9
+ bitsandbytes
10
+
11
+ [cloud]
12
+ headroom-ai
13
+
14
+ [local]
15
+ torch
16
+ transformers
17
+ peft
18
+ accelerate
19
+ bitsandbytes
@@ -0,0 +1,44 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "infinite-context-gateway"
7
+ version = "0.1.0"
8
+ authors = [
9
+ { name="Dev Team", email="dev@example.com" },
10
+ ]
11
+ description = "A hybrid AI context gateway combining Headroom cloud compression and In-Place Local Test-Time Training (TTT)."
12
+ readme = "README.md"
13
+ requires-python = ">=3.9"
14
+ classifiers = [
15
+ "Programming Language :: Python :: 3",
16
+ "License :: OSI Approved :: MIT License",
17
+ "Operating System :: OS Independent",
18
+ ]
19
+ dependencies = [
20
+ "httpx",
21
+ ]
22
+
23
+ [project.optional-dependencies]
24
+ cloud = [
25
+ "headroom-ai",
26
+ ]
27
+ local = [
28
+ "torch",
29
+ "transformers",
30
+ "peft",
31
+ "accelerate",
32
+ "bitsandbytes",
33
+ ]
34
+ all = [
35
+ "headroom-ai",
36
+ "torch",
37
+ "transformers",
38
+ "peft",
39
+ "accelerate",
40
+ "bitsandbytes",
41
+ ]
42
+
43
+ [tool.setuptools.packages.find]
44
+ include = ["infinite_context*"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+