infinite-context-gateway 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- infinite_context_gateway-0.1.0/PKG-INFO +108 -0
- infinite_context_gateway-0.1.0/README.md +81 -0
- infinite_context_gateway-0.1.0/infinite_context/__init__.py +10 -0
- infinite_context_gateway-0.1.0/infinite_context/cloud/__init__.py +6 -0
- infinite_context_gateway-0.1.0/infinite_context/cloud/client.py +127 -0
- infinite_context_gateway-0.1.0/infinite_context/cloud/compressor.py +44 -0
- infinite_context_gateway-0.1.0/infinite_context/gateway.py +145 -0
- infinite_context_gateway-0.1.0/infinite_context/local/__init__.py +5 -0
- infinite_context_gateway-0.1.0/infinite_context/local/chunk_manager.py +41 -0
- infinite_context_gateway-0.1.0/infinite_context/local/engine.py +195 -0
- infinite_context_gateway-0.1.0/infinite_context/local/ttt_module.py +133 -0
- infinite_context_gateway-0.1.0/infinite_context_gateway.egg-info/PKG-INFO +108 -0
- infinite_context_gateway-0.1.0/infinite_context_gateway.egg-info/SOURCES.txt +16 -0
- infinite_context_gateway-0.1.0/infinite_context_gateway.egg-info/dependency_links.txt +1 -0
- infinite_context_gateway-0.1.0/infinite_context_gateway.egg-info/requires.txt +19 -0
- infinite_context_gateway-0.1.0/infinite_context_gateway.egg-info/top_level.txt +1 -0
- infinite_context_gateway-0.1.0/pyproject.toml +44 -0
- infinite_context_gateway-0.1.0/setup.cfg +4 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: infinite-context-gateway
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A hybrid AI context gateway combining Headroom cloud compression and In-Place Local Test-Time Training (TTT).
|
|
5
|
+
Author-email: Dev Team <dev@example.com>
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
8
|
+
Classifier: Operating System :: OS Independent
|
|
9
|
+
Requires-Python: >=3.9
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
Requires-Dist: httpx
|
|
12
|
+
Provides-Extra: cloud
|
|
13
|
+
Requires-Dist: headroom-ai; extra == "cloud"
|
|
14
|
+
Provides-Extra: local
|
|
15
|
+
Requires-Dist: torch; extra == "local"
|
|
16
|
+
Requires-Dist: transformers; extra == "local"
|
|
17
|
+
Requires-Dist: peft; extra == "local"
|
|
18
|
+
Requires-Dist: accelerate; extra == "local"
|
|
19
|
+
Requires-Dist: bitsandbytes; extra == "local"
|
|
20
|
+
Provides-Extra: all
|
|
21
|
+
Requires-Dist: headroom-ai; extra == "all"
|
|
22
|
+
Requires-Dist: torch; extra == "all"
|
|
23
|
+
Requires-Dist: transformers; extra == "all"
|
|
24
|
+
Requires-Dist: peft; extra == "all"
|
|
25
|
+
Requires-Dist: accelerate; extra == "all"
|
|
26
|
+
Requires-Dist: bitsandbytes; extra == "all"
|
|
27
|
+
|
|
28
|
+
# Infinite Context
|
|
29
|
+
|
|
30
|
+
A hybrid AI context gateway combining **Headroom cloud compression** and **In-Place Local Test-Time Training (TTT)**.
|
|
31
|
+
|
|
32
|
+
## Installation
|
|
33
|
+
|
|
34
|
+
To install the base gateway:
|
|
35
|
+
```bash
|
|
36
|
+
pip install infinite-context-gateway
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
To install with Cloud Phase 1 dependencies (Headroom API):
|
|
40
|
+
```bash
|
|
41
|
+
pip install "infinite-context-gateway[cloud]"
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
To install with Local Phase 2 dependencies (PyTorch, Transformers, PEFT):
|
|
45
|
+
```bash
|
|
46
|
+
pip install "infinite-context-gateway[local]"
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
## Usage
|
|
50
|
+
|
|
51
|
+
### Phase 1: Cloud Compression
|
|
52
|
+
Uses the Headroom API to semantically compress a massive context and send it to cloud providers (OpenAI, Anthropic) while keeping costs low.
|
|
53
|
+
|
|
54
|
+
```python
|
|
55
|
+
from infinite_context import ContextGateway
|
|
56
|
+
|
|
57
|
+
gateway = ContextGateway(
|
|
58
|
+
engine="cloud",
|
|
59
|
+
model_id="claude-3-5-sonnet-20240620",
|
|
60
|
+
api_key="your_anthropic_api_key",
|
|
61
|
+
compression_ratio=0.8
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
# You can pass in conversational history (rolling window)
|
|
65
|
+
history = [
|
|
66
|
+
{"role": "user", "content": "What is the codebase about?"},
|
|
67
|
+
{"role": "assistant", "content": "It is a Python application..."}
|
|
68
|
+
]
|
|
69
|
+
|
|
70
|
+
response = gateway.chat("How does the failover protocol work?", massive_context, history=history)
|
|
71
|
+
print(response)
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
### Phase 2: Local Test-Time Training (TTT)
|
|
75
|
+
Bypasses the KV-Cache entirely by injecting a PEFT LoRA adapter and baking the context directly into the model's neural weights on your local GPU. Includes Early Stopping latency optimizations and Generation repetition penalties.
|
|
76
|
+
|
|
77
|
+
```python
|
|
78
|
+
from infinite_context import ContextGateway
|
|
79
|
+
|
|
80
|
+
gateway = ContextGateway(
|
|
81
|
+
engine="local",
|
|
82
|
+
model_id="Qwen/Qwen2.5-0.5B-Instruct",
|
|
83
|
+
load_in_4bit=True
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
response = gateway.chat("What is the failover protocol?", massive_context)
|
|
87
|
+
print(response)
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
### Checkpoint Persistence
|
|
91
|
+
Save trained Fast Weights to disk and resume later without re-reading the document:
|
|
92
|
+
|
|
93
|
+
```python
|
|
94
|
+
from infinite_context import ContextGateway
|
|
95
|
+
|
|
96
|
+
# Train and keep state
|
|
97
|
+
gateway = ContextGateway(engine="local", model_id="Qwen/Qwen2.5-0.5B-Instruct")
|
|
98
|
+
gateway.chat("Summarise the document.", massive_context, keep_state=True)
|
|
99
|
+
gateway.save_state("./my_checkpoint")
|
|
100
|
+
|
|
101
|
+
del gateway # Free all GPU memory
|
|
102
|
+
|
|
103
|
+
# Resume later — zero re-training
|
|
104
|
+
gateway = ContextGateway(engine="local", model_id="Qwen/Qwen2.5-0.5B-Instruct")
|
|
105
|
+
gateway.load_state("./my_checkpoint")
|
|
106
|
+
response = gateway.chat("What was the protocol?", context="")
|
|
107
|
+
print(response)
|
|
108
|
+
```
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# Infinite Context
|
|
2
|
+
|
|
3
|
+
A hybrid AI context gateway combining **Headroom cloud compression** and **In-Place Local Test-Time Training (TTT)**.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
To install the base gateway:
|
|
8
|
+
```bash
|
|
9
|
+
pip install infinite-context-gateway
|
|
10
|
+
```
|
|
11
|
+
|
|
12
|
+
To install with Cloud Phase 1 dependencies (Headroom API):
|
|
13
|
+
```bash
|
|
14
|
+
pip install "infinite-context-gateway[cloud]"
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
To install with Local Phase 2 dependencies (PyTorch, Transformers, PEFT):
|
|
18
|
+
```bash
|
|
19
|
+
pip install "infinite-context-gateway[local]"
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
## Usage
|
|
23
|
+
|
|
24
|
+
### Phase 1: Cloud Compression
|
|
25
|
+
Uses the Headroom API to semantically compress a massive context and send it to cloud providers (OpenAI, Anthropic) while keeping costs low.
|
|
26
|
+
|
|
27
|
+
```python
|
|
28
|
+
from infinite_context import ContextGateway
|
|
29
|
+
|
|
30
|
+
gateway = ContextGateway(
|
|
31
|
+
engine="cloud",
|
|
32
|
+
model_id="claude-3-5-sonnet-20240620",
|
|
33
|
+
api_key="your_anthropic_api_key",
|
|
34
|
+
compression_ratio=0.8
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
# You can pass in conversational history (rolling window)
|
|
38
|
+
history = [
|
|
39
|
+
{"role": "user", "content": "What is the codebase about?"},
|
|
40
|
+
{"role": "assistant", "content": "It is a Python application..."}
|
|
41
|
+
]
|
|
42
|
+
|
|
43
|
+
response = gateway.chat("How does the failover protocol work?", massive_context, history=history)
|
|
44
|
+
print(response)
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
### Phase 2: Local Test-Time Training (TTT)
|
|
48
|
+
Bypasses the KV-Cache entirely by injecting a PEFT LoRA adapter and baking the context directly into the model's neural weights on your local GPU. Includes Early Stopping latency optimizations and Generation repetition penalties.
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
from infinite_context import ContextGateway
|
|
52
|
+
|
|
53
|
+
gateway = ContextGateway(
|
|
54
|
+
engine="local",
|
|
55
|
+
model_id="Qwen/Qwen2.5-0.5B-Instruct",
|
|
56
|
+
load_in_4bit=True
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
response = gateway.chat("What is the failover protocol?", massive_context)
|
|
60
|
+
print(response)
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
### Checkpoint Persistence
|
|
64
|
+
Save trained Fast Weights to disk and resume later without re-reading the document:
|
|
65
|
+
|
|
66
|
+
```python
|
|
67
|
+
from infinite_context import ContextGateway
|
|
68
|
+
|
|
69
|
+
# Train and keep state
|
|
70
|
+
gateway = ContextGateway(engine="local", model_id="Qwen/Qwen2.5-0.5B-Instruct")
|
|
71
|
+
gateway.chat("Summarise the document.", massive_context, keep_state=True)
|
|
72
|
+
gateway.save_state("./my_checkpoint")
|
|
73
|
+
|
|
74
|
+
del gateway # Free all GPU memory
|
|
75
|
+
|
|
76
|
+
# Resume later — zero re-training
|
|
77
|
+
gateway = ContextGateway(engine="local", model_id="Qwen/Qwen2.5-0.5B-Instruct")
|
|
78
|
+
gateway.load_state("./my_checkpoint")
|
|
79
|
+
response = gateway.chat("What was the protocol?", context="")
|
|
80
|
+
print(response)
|
|
81
|
+
```
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cloud API client for OpenAI, Anthropic, and OpenAI-compatible endpoints.
|
|
3
|
+
|
|
4
|
+
Handles payload formatting, header construction, and SSE streaming for
|
|
5
|
+
providers like Groq, Together, etc.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import httpx
|
|
10
|
+
from typing import Dict, Generator, List
|
|
11
|
+
|
|
12
|
+
__all__ = ["CloudClient"]
|
|
13
|
+
|
|
14
|
+
# Shared timeout: 30 s connect, 120 s read (large prompts can be slow).
|
|
15
|
+
_DEFAULT_TIMEOUT = httpx.Timeout(connect=30.0, read=120.0, write=30.0, pool=10.0)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CloudClient:
    """HTTP client for cloud LLM providers.

    Two wire formats are supported: the Anthropic Messages API (selected when
    the model id contains ``claude`` and no ``base_url`` is given) and the
    OpenAI chat-completions format (everything else, including any custom
    ``base_url``).
    """

    def __init__(self, api_key: str, model_id: str, base_url: str = None):
        """
        Args:
            api_key: Bearer / x-api-key token.
            model_id: Model identifier (e.g. ``claude-3-5-sonnet-20240620``).
            base_url: Custom endpoint URL. When set, the OpenAI chat-
                completions format is assumed.
        """
        self.api_key = api_key
        self.model_id = model_id

        # Determine endpoint and provider format.
        if base_url:
            self.base_url = base_url
            self.provider = "openai"
        elif "claude" in self.model_id.lower():
            self.base_url = "https://api.anthropic.com/v1/messages"
            self.provider = "anthropic"
        else:
            self.base_url = "https://api.openai.com/v1/chat/completions"
            self.provider = "openai"

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _prepare_headers(self) -> Dict[str, str]:
        """Return the authentication and content-type headers for the provider."""
        if self.provider == "anthropic":
            return {
                "x-api-key": self.api_key,
                "anthropic-version": "2023-06-01",
                "content-type": "application/json",
            }
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _prepare_payload(
        self,
        messages: List[Dict[str, str]],
        stream: bool = False,
    ) -> dict:
        """Build the JSON request body for the configured provider.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            stream: Whether to request a streamed response.

        Returns:
            The request payload dictionary.
        """
        if self.provider == "anthropic":
            # Anthropic splits the system prompt out of the messages array.
            # Every non-empty system message is kept (joined with blank
            # lines) so none is silently dropped.
            system_parts = [
                m["content"]
                for m in messages
                if m["role"] == "system" and m["content"]
            ]
            payload = {
                "model": self.model_id,
                "messages": [m for m in messages if m["role"] != "system"],
                "max_tokens": 4096,
                "stream": stream,
            }
            # Only send ``system`` when there is something to say.
            if system_parts:
                payload["system"] = "\n\n".join(system_parts)
            return payload
        return {
            "model": self.model_id,
            "messages": messages,
            "stream": stream,
        }

    def _extract_delta_text(self, chunk: dict) -> str:
        """Return the text carried by one decoded SSE event, or ``""``.

        Anthropic streams text in ``content_block_delta`` events under
        ``delta.text``; OpenAI-style streams use ``choices[0].delta.content``.
        Any other event shape (pings, stop events, malformed data) yields an
        empty string.
        """
        if not isinstance(chunk, dict):
            return ""
        if self.provider == "anthropic":
            if chunk.get("type") != "content_block_delta":
                return ""
            delta = chunk.get("delta") or {}
            return delta.get("text") or ""
        choices = chunk.get("choices") or []
        if not choices:
            return ""
        delta = choices[0].get("delta") or {}
        return delta.get("content") or ""

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def send_request(self, messages: List[Dict[str, str]]) -> str:
        """Send a non-streaming chat completion request.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The text of the first content block (Anthropic) or of the first
            choice's message (OpenAI format).

        Raises:
            httpx.HTTPStatusError: If the provider answers with a 4xx/5xx.
        """
        payload = self._prepare_payload(messages, stream=False)
        headers = self._prepare_headers()

        with httpx.Client(timeout=_DEFAULT_TIMEOUT) as client:
            response = client.post(self.base_url, json=payload, headers=headers)
            response.raise_for_status()
            data = response.json()

        if self.provider == "anthropic":
            return data["content"][0]["text"]
        return data["choices"][0]["message"]["content"]

    def stream_request(
        self, messages: List[Dict[str, str]]
    ) -> Generator[str, None, None]:
        """Yield text deltas from an SSE streaming response.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.

        Yields:
            Non-empty text fragments in arrival order.

        Raises:
            httpx.HTTPStatusError: If the provider answers with a 4xx/5xx.
        """
        payload = self._prepare_payload(messages, stream=True)
        headers = self._prepare_headers()

        with httpx.Client(timeout=_DEFAULT_TIMEOUT) as client:
            with client.stream(
                "POST", self.base_url, json=payload, headers=headers
            ) as response:
                response.raise_for_status()
                for line in response.iter_lines():
                    # Only ``data:`` lines carry payloads; skip ``event:``
                    # lines, comments and keep-alive blanks.
                    if not line or not line.startswith("data:"):
                        continue
                    data_str = line[len("data:"):].strip()
                    if data_str == "[DONE]":
                        return
                    try:
                        chunk = json.loads(data_str)
                    except json.JSONDecodeError:
                        # Malformed chunk — skip gracefully.
                        continue
                    text = self._extract_delta_text(chunk)
                    if text:
                        yield text
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Headroom semantic compression wrapper.
|
|
3
|
+
|
|
4
|
+
Compresses massive context payloads locally before sending them to a cloud
|
|
5
|
+
LLM, reducing token counts and API costs.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
__all__ = ["HeadroomCompressor"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class HeadroomCompressor:
    """Wraps the ``headroom-ai`` library with a graceful fallback.

    When ``headroom`` cannot be imported, :meth:`compress` falls back to a
    simple head + tail truncation.

    Args:
        target_ratio: Requested compression ratio. It is stored on the
            instance; this class itself does not read it again.
    """

    # Fallback simulator geometry: keep the first 1000 chars and the last
    # 2000 chars (where needles usually are).
    _HEAD_CHARS = 1000
    _TAIL_CHARS = 2000
    _MARKER = "\n\n...[CONTENT COMPRESSED BY HEADROOM SIMULATOR]...\n\n"

    def __init__(self, target_ratio: float = 0.8):
        self.target_ratio = target_ratio
        self._compress_fn = None

        try:
            from headroom import compress
            self._compress_fn = compress
        except ImportError:
            print(
                "Warning: headroom-ai not installed. "
                "Using fallback semantic compression simulator."
            )

    def compress(self, context: str) -> str:
        """Compress *context* using Headroom's algorithms.

        Falls back to a head + tail truncation when the library is absent.
        The fallback never returns a string longer than its input: text that
        would not shrink once the marker is inserted is returned unchanged.

        Args:
            context: The raw context text.

        Returns:
            The compressed (or, when already short, unchanged) text.
        """
        if self._compress_fn is not None:
            return self._compress_fn(context)

        # Truncating only pays off when the dropped middle is longer than the
        # marker that replaces it.
        floor = self._HEAD_CHARS + self._TAIL_CHARS + len(self._MARKER)
        if len(context) <= floor:
            return context

        return (
            context[: self._HEAD_CHARS]
            + self._MARKER
            + context[-self._TAIL_CHARS:]
        )
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Context Gateway — unified interface for the cloud and local engines.
|
|
3
|
+
|
|
4
|
+
Users interact exclusively through :class:`ContextGateway`, selecting either
|
|
5
|
+
``engine="cloud"`` (Headroom compression + API) or ``engine="local"``
|
|
6
|
+
(In-Place Test-Time Training on a local GPU).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Any, Generator, List, Optional
|
|
10
|
+
|
|
11
|
+
__all__ = ["ContextGateway"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ContextGateway:
    """Hybrid gateway that routes queries through either a cloud API or a
    local TTT engine depending on configuration.

    Args:
        engine: ``"cloud"`` for Headroom API compression, or ``"local"``
            for In-Place TTT.
        model_id: Target model (e.g. ``"claude-3-5-sonnet-20240620"`` or a
            local HuggingFace repo id).
        api_key: API key for cloud routing (if applicable).
        base_url: Custom base URL for OpenAI-compatible providers like Groq.
        compression_ratio: Target compression ratio for Headroom (cloud only).
        skip_compression: If *True*, skips Headroom compression (useful for
            testing alternative APIs directly).
        **kwargs: Additional parameters forwarded to :class:`LocalEngine`
            (e.g. ``device``, ``load_in_4bit``).

    Raises:
        ValueError: If *engine* is neither ``"cloud"`` nor ``"local"``.
    """

    def __init__(
        self,
        engine: str = "cloud",
        model_id: str = "claude-3-5-sonnet-20240620",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        compression_ratio: float = 0.8,
        skip_compression: bool = False,
        **kwargs: Any,
    ):
        if engine not in ("cloud", "local"):
            raise ValueError("Engine must be either 'cloud' or 'local'.")

        self.engine = engine
        self.model_id = model_id

        # Engine-specific imports stay local so that each engine's optional
        # dependencies are only needed when that engine is selected.
        if self.engine == "cloud":
            from .cloud.compressor import HeadroomCompressor
            from .cloud.client import CloudClient

            self.compressor = (
                HeadroomCompressor(target_ratio=compression_ratio)
                if not skip_compression
                else None
            )
            self.client = CloudClient(
                api_key=api_key, model_id=self.model_id, base_url=base_url,
            )
        else:  # local
            from .local.engine import LocalEngine
            self.local_engine = LocalEngine(model_id=self.model_id, **kwargs)

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _build_messages(
        self,
        question: str,
        context: str,
        history: Optional[List[dict]],
        max_history: int,
    ) -> List[dict]:
        """Assemble the cloud message list: context, recent history, question.

        The (optionally compressed) context becomes the system message. At
        most *max_history* trailing history entries are included; zero or a
        negative value includes none.
        """
        compressed = (
            self.compressor.compress(context)
            if self.compressor is not None
            else context
        )

        messages: List[dict] = [{"role": "system", "content": compressed}]
        # ``history[-0:]`` is the whole list, so a non-positive window has to
        # be handled explicitly.
        if history and max_history > 0:
            messages.extend(history[-max_history:])
        messages.append({"role": "user", "content": question})
        return messages

    # ------------------------------------------------------------------
    # Chat
    # ------------------------------------------------------------------

    def chat(
        self,
        question: str,
        context: str,
        history: Optional[List[dict]] = None,
        max_history: int = 4,
        keep_state: bool = False,
    ) -> str:
        """Process a massive context payload and answer the question.

        Includes a rolling-window history to keep token costs low.

        Args:
            question: The user question.
            context: The context document text.
            history: Prior chat messages (used by the cloud engine only).
            max_history: Number of trailing history messages to send (cloud
                engine only).
            keep_state: Forwarded to the local engine's ``generate`` (local
                engine only).

        Returns:
            The model's answer text.
        """
        if self.engine == "cloud":
            messages = self._build_messages(
                question, context, history, max_history,
            )
            return self.client.send_request(messages)

        # local
        return self.local_engine.generate(
            question, context, keep_state=keep_state,
        )

    def stream_chat(
        self,
        question: str,
        context: str,
        history: Optional[List[dict]] = None,
        max_history: int = 4,
    ) -> Generator[str, None, None]:
        """Yield response tokens as a stream (cloud engine only).

        Because this is a generator, the ``NotImplementedError`` for the
        local engine is raised on first iteration, not at call time.

        Raises:
            NotImplementedError: If the gateway uses the local engine.
        """
        if self.engine == "cloud":
            messages = self._build_messages(
                question, context, history, max_history,
            )
            yield from self.client.stream_request(messages)
        else:
            raise NotImplementedError(
                "Streaming is not supported by the local TTT engine."
            )

    # ------------------------------------------------------------------
    # Persistence (local engine only)
    # ------------------------------------------------------------------

    def save_state(self, path: str) -> None:
        """Save trained Fast Weights to disk for later resumption.

        Raises:
            NotImplementedError: If the gateway uses the cloud engine.
        """
        if self.engine == "local":
            self.local_engine.save_state(path)
        else:
            raise NotImplementedError(
                "Cloud engine does not support local state saving."
            )

    def load_state(self, path: str) -> None:
        """Reload previously-saved Fast Weights from disk.

        Raises:
            NotImplementedError: If the gateway uses the cloud engine.
        """
        if self.engine == "local":
            self.local_engine.load_state(path)
        else:
            raise NotImplementedError(
                "Cloud engine does not support local state loading."
            )
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Token chunk manager for the local TTT engine.
|
|
3
|
+
|
|
4
|
+
Splits massive token sequences into fixed-size windows so they fit inside
|
|
5
|
+
the model's KV-Cache limits during the Test-Time Training reading phase.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import List
|
|
9
|
+
|
|
10
|
+
__all__ = ["ChunkManager"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ChunkManager:
    """Splits token lists into overlapping fixed-size chunks.

    Args:
        chunk_size: Number of tokens per chunk.
        overlap: Number of overlapping tokens between consecutive chunks
            (helps preserve context continuity).
    """

    def __init__(self, chunk_size: int = 1024, overlap: int = 0):
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk_tokens(self, tokens: List[int]) -> List[List[int]]:
        """Split *tokens* into smaller manageable chunks.

        Consecutive chunks start ``chunk_size - overlap`` tokens apart.
        Chunking stops as soon as a chunk reaches the end of *tokens*, so no
        trailing chunk that is only a suffix of the previous one is emitted.

        Args:
            tokens: The full token id sequence.

        Returns:
            The list of chunks, each at most ``chunk_size`` tokens long; an
            empty list when *tokens* is empty.

        Raises:
            ValueError: If ``overlap`` is negative or not smaller than
                ``chunk_size``.
        """
        if not tokens:
            return []

        # A negative overlap would widen the stride beyond the chunk size and
        # silently skip tokens between chunks.
        if self.overlap < 0:
            raise ValueError("Overlap cannot be negative.")

        step = self.chunk_size - self.overlap

        # Ensure we always move forward.
        if step <= 0:
            raise ValueError("Overlap cannot be greater than or equal to chunk_size.")

        total = len(tokens)
        chunks = []
        for start in range(0, total, step):
            end = start + self.chunk_size
            chunks.append(tokens[start:end])
            # Everything after this point is already inside this chunk.
            if end >= total:
                break

        return chunks
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local Test-Time Training (TTT) engine.
|
|
3
|
+
|
|
4
|
+
Orchestrates the full pipeline: model loading → chunking → LoRA injection →
|
|
5
|
+
reading phase (training) → answering phase (inference) → state management.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import torch
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from .chunk_manager import ChunkManager
|
|
13
|
+
from .ttt_module import InPlaceTTT
|
|
14
|
+
|
|
15
|
+
__all__ = ["LocalEngine"]
|
|
16
|
+
|
|
17
|
+
# Lazy import guard — resolved once at first use, not at import time.
# The annotation is quoted on purpose: module-level annotations are evaluated
# when the module is imported, and an unquoted ``bool | None`` raises
# TypeError on Python 3.9.
_transformers_available: "bool | None" = None


def _ensure_transformers():
    """Import transformers on first use and raise a clear error if missing.

    The outcome of the first import attempt is cached in the module-level
    ``_transformers_available`` flag, so later calls do not retry the import.

    Raises:
        ImportError: If ``transformers`` cannot be imported.
    """
    global _transformers_available
    if _transformers_available is None:
        try:
            import transformers  # noqa: F401
            _transformers_available = True
        except ImportError:
            _transformers_available = False
    if not _transformers_available:
        raise ImportError(
            "Transformers is required for the local engine. "
            'Install via: pip install "infinite-context-gateway[local]"'
        )
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class LocalEngine:
    """Orchestrates the Local Test-Time Training (TTT) engine.

    Loads the model in constrained memory (4-bit) and manages the Reading
    and Answering phases.
    """

    def __init__(
        self,
        model_id: str,
        device: str = "cuda",
        load_in_4bit: bool = True,
        **kwargs: Any,
    ):
        """
        Args:
            model_id: Model identifier passed to ``from_pretrained``.
            device: Requested device; replaced by ``"cpu"`` when CUDA is not
                available.
            load_in_4bit: Request 4-bit quantization (only applied when the
                resolved device is ``"cuda"``).
            **kwargs: Accepted for call compatibility; not used here.

        Raises:
            ImportError: If ``transformers`` is not installed.
        """
        _ensure_transformers()

        self.device = device if torch.cuda.is_available() else "cpu"
        self.model_id = model_id

        print(f"Loading base model {model_id} via Transformers into {self.device}...")

        from transformers import AutoModelForCausalLM, AutoTokenizer

        quantization_kwargs: dict = {}
        if load_in_4bit and self.device == "cuda":
            try:
                from transformers import BitsAndBytesConfig
                quantization_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                )
            except ImportError:
                # Best effort: continue unquantized, but say so instead of
                # silently ignoring the caller's load_in_4bit request.
                print(
                    "Warning: 4-bit quantization unavailable (ImportError). "
                    "Loading the model without quantization."
                )

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto" if self.device == "cuda" else None,
            **quantization_kwargs,
        )

        self.chunker = ChunkManager(chunk_size=1024, overlap=128)

    # ------------------------------------------------------------------
    # Core pipeline
    # ------------------------------------------------------------------

    def generate(
        self,
        question: str,
        context: str,
        epochs_per_chunk: int = 15,
        target_loss: float = 0.05,
        keep_state: bool = False,
    ) -> str:
        """Run the full In-Place TTT pipeline.

        1. Break massive context into chunks.
        2. Inject Fast Weights (LoRA) into the model.
        3. Reading Phase — train the Fast Weights on the chunks.
        4. Answering Phase — generate the response to the question.
        5. State Management — optionally retain or destroy Fast Weights.

        Args:
            question: The question to answer.
            context: The document text to train on.
            epochs_per_chunk: Maximum training passes per chunk. Zero or a
                negative value skips training for every chunk.
            target_loss: Training on a chunk stops early once the reported
                loss drops below this value.
            keep_state: Keep the Fast Weights attached after answering.

        Returns:
            The decoded text of the newly generated tokens.
        """
        # 1. Chunking
        tokens = self.tokenizer.encode(context, add_special_tokens=False)
        chunks = self.chunker.chunk_tokens(tokens)

        print(f"Bypassing KV-Cache: Splitting {len(tokens)} tokens into {len(chunks)} chunks.")

        # 2. Attach TTT Fast Weights (skip if a checkpoint was loaded)
        is_peft = hasattr(self.model, "peft_config")
        ttt: InPlaceTTT | None = None
        if not is_peft:
            print("Injecting TTT Fast Weights into down_proj layers...")
            ttt = InPlaceTTT(self.model)
        else:
            print("Model already has Fast Weights loaded from disk. Reusing them...")

        # 3. The 'Reading' Phase
        if chunks:
            print("Starting Reading Phase (Test-Time Training)...")
            if ttt is None:
                print("Warning: Cannot train loaded Fast Weights in-place. Skipping training.")
            else:
                for i, chunk in enumerate(chunks):
                    # Hoist tensor creation outside the epoch loop.
                    input_ids = torch.tensor([chunk], device=self.device)
                    # Track progress explicitly so the report below is
                    # well-defined even when no epoch runs.
                    epochs_run = 0
                    loss = None
                    for _ in range(epochs_per_chunk):
                        loss = ttt.train_on_chunk(input_ids)
                        epochs_run += 1
                        if loss < target_loss:
                            break
                    if loss is None:
                        print(f" Skipped Chunk {i + 1}/{len(chunks)} - no training epochs requested")
                    else:
                        print(f" Processed Chunk {i + 1}/{len(chunks)} - Epochs: {epochs_run} - Loss: {loss:.4f}")

        # KV Cache is implicitly cleared here because we are not passing
        # past_key_values!

        # 4. The 'Answering' Phase
        print("Starting Answering Phase...")
        self.model.eval()

        prompt = f"Based on the codebase I just showed you, answer this question:\n{question}\nAnswer:"
        prompt_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                prompt_ids,
                max_new_tokens=75,
                do_sample=False,
                repetition_penalty=1.1,
            )

        # Extract just the newly generated tokens.
        response_tokens = output_ids[0][prompt_ids.shape[1]:]
        response = self.tokenizer.decode(response_tokens, skip_special_tokens=True)

        # 5. State Management
        if not keep_state:
            if ttt is not None:
                print("Resetting model state (deleting Fast Weights)...")
                ttt.remove_fast_weights()
            else:
                print("Cannot reset loaded Fast Weights. Retaining state.")
        else:
            print("Keeping model state (Fast Weights retained)...")
            if ttt is not None:
                # Update the reference so save_pretrained saves the adapter.
                self.model = ttt.model

        return response

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_state(self, path: str) -> None:
        """Save the current Fast Weights (LoRA adapter) to disk.

        Args:
            path: Destination directory passed to ``save_pretrained``.

        Raises:
            RuntimeError: If the current model carries no ``peft_config``,
                i.e. there are no Fast Weights to save.
        """
        if hasattr(self.model, "peft_config"):
            self.model.save_pretrained(path)
            print(f"State saved to {path}")
        else:
            raise RuntimeError(
                "No active Fast Weights to save. "
                "Did you forget keep_state=True when calling generate()?"
            )

    def load_state(self, path: str) -> None:
        """Load Fast Weights (LoRA adapter) from disk.

        Looks for a ``ttt_fast_weights`` sub-directory under *path* first and
        falls back to *path* itself.

        Args:
            path: Directory previously written by :meth:`save_state`.
        """
        from peft import PeftModel

        adapter_path = os.path.join(path, "ttt_fast_weights")
        if not os.path.exists(adapter_path):
            # Fallback for standard PEFT saves.
            adapter_path = path
        # NOTE(review): if self.model is already a PEFT model this wraps it a
        # second time — confirm whether that case needs a guard.
        self.model = PeftModel.from_pretrained(
            self.model, adapter_path, is_trainable=True,
        )
        print(f"State loaded from {adapter_path}")
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""
|
|
2
|
+
In-Place Test-Time Training (TTT) module.
|
|
3
|
+
|
|
4
|
+
Attaches transient LoRA fast-weight adapters to an LLM's MLP layers and
|
|
5
|
+
trains them on incoming context chunks so the model "memorises" the document
|
|
6
|
+
into its weights at inference time.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
from typing import List, Optional
|
|
12
|
+
|
|
13
|
+
__all__ = ["InPlaceTTT"]
|
|
14
|
+
|
|
15
|
+
# Lazy guard — resolved once at first use, not at import time.
_peft_available: Optional[bool] = None


def _ensure_peft():
    """Import peft on first use and raise a clear error if missing.

    The outcome of the first import attempt is cached in the module-level
    ``_peft_available`` flag, so later calls do not retry the import.

    Raises:
        ImportError: If ``peft`` cannot be imported.
    """
    global _peft_available
    if _peft_available is None:
        try:
            import peft  # noqa: F401
            _peft_available = True
        except ImportError:
            _peft_available = False
    if not _peft_available:
        raise ImportError(
            "PEFT is required for the local engine. "
            'Install via: pip install "infinite-context-gateway[local]"'
        )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class InPlaceTTT:
    """Manages LoRA adapter injection, test-time training, and teardown.

    On construction the given model is wrapped with a PEFT LoRA adapter and
    an AdamW optimizer is created over the parameters that require gradients.
    """

    def __init__(
        self,
        model: nn.Module,
        target_modules: Optional[List[str]] = None,
        lr: float = 5e-3,
        adapter_name: str = "ttt_fast_weights",
    ):
        """
        Args:
            model: The Hugging Face PreTrainedModel.
            target_modules: MLP layer names to attach fast weights to.
                Defaults to ``["down_proj"]``.
            lr: Learning rate for the test-time training update.
            adapter_name: Name of the PEFT adapter used for the fast weights.

        Raises:
            ImportError: If ``peft`` is not installed.
        """
        _ensure_peft()

        self.model = model
        self.target_modules = target_modules or ["down_proj"]
        self.lr = lr
        self.adapter_name = adapter_name
        self.optimizer: Optional[torch.optim.Optimizer] = None

        self._setup_lora()

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _setup_lora(self) -> None:
        """Inject Fast Weights (LoRA) into the base model.

        A rank-8 DoRA-enabled LoRA adapter is attached to every module named
        in ``self.target_modules`` (``down_proj`` by default), and an AdamW
        optimizer is built over the resulting trainable parameters.
        """
        from peft import LoraConfig, get_peft_model

        config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=self.target_modules,
            lora_dropout=0.0,
            bias="none",
            task_type="CAUSAL_LM",
            use_dora=True,
        )
        self.model = get_peft_model(self.model, config, adapter_name=self.adapter_name)

        # Train ONLY the LoRA parameters — not the frozen base weights.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(trainable_params, lr=self.lr)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def train_on_chunk(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> float:
        """Run one gradient step on *input_ids* (the 'Reading Phase').

        Uses the standard causal next-token prediction loss (the original
        author notes this stands in for the paper's Conv1D LM-aligned
        objective).

        Args:
            input_ids: Token ids of the chunk to train on.
            attention_mask: Optional mask matching *input_ids*; positions
                equal to 0 are treated as padding. When omitted, every
                position is attended to and contributes to the loss.

        Returns:
            The loss of this step as a Python float.
        """
        self.model.train()
        self.optimizer.zero_grad(set_to_none=True)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
            labels = input_ids
        else:
            # Padded positions must not contribute to the loss; -100 is the
            # ignore index of the Hugging Face causal-LM loss. ``masked_fill``
            # returns a new tensor, so the caller's ``input_ids`` is untouched.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

        # Mixed-precision forward pass (enabled on CUDA only).
        device_type = input_ids.device.type
        with torch.amp.autocast(device_type, enabled=(device_type == "cuda")):
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss

        # Backward + update ONLY the LoRA adapters.
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def remove_fast_weights(self) -> None:
        """Delete the LoRA adapter ('State Destruction').

        Puts the model in eval mode and deletes the adapter registered under
        ``self.adapter_name``, leaving the model ready for the next request.
        """
        self.model.eval()
        self.model.delete_adapter(self.adapter_name)
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: infinite-context-gateway
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A hybrid AI context gateway combining Headroom cloud compression and In-Place Local Test-Time Training (TTT).
|
|
5
|
+
Author-email: Dev Team <dev@example.com>
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
8
|
+
Classifier: Operating System :: OS Independent
|
|
9
|
+
Requires-Python: >=3.9
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
Requires-Dist: httpx
|
|
12
|
+
Provides-Extra: cloud
|
|
13
|
+
Requires-Dist: headroom-ai; extra == "cloud"
|
|
14
|
+
Provides-Extra: local
|
|
15
|
+
Requires-Dist: torch; extra == "local"
|
|
16
|
+
Requires-Dist: transformers; extra == "local"
|
|
17
|
+
Requires-Dist: peft; extra == "local"
|
|
18
|
+
Requires-Dist: accelerate; extra == "local"
|
|
19
|
+
Requires-Dist: bitsandbytes; extra == "local"
|
|
20
|
+
Provides-Extra: all
|
|
21
|
+
Requires-Dist: headroom-ai; extra == "all"
|
|
22
|
+
Requires-Dist: torch; extra == "all"
|
|
23
|
+
Requires-Dist: transformers; extra == "all"
|
|
24
|
+
Requires-Dist: peft; extra == "all"
|
|
25
|
+
Requires-Dist: accelerate; extra == "all"
|
|
26
|
+
Requires-Dist: bitsandbytes; extra == "all"
|
|
27
|
+
|
|
28
|
+
# Infinite Context
|
|
29
|
+
|
|
30
|
+
A hybrid AI context gateway combining **Headroom cloud compression** and **In-Place Local Test-Time Training (TTT)**.
|
|
31
|
+
|
|
32
|
+
## Installation
|
|
33
|
+
|
|
34
|
+
To install the base gateway:
|
|
35
|
+
```bash
|
|
36
|
+
pip install infinite-context-gateway
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
To install with Cloud Phase 1 dependencies (Headroom API):
|
|
40
|
+
```bash
|
|
41
|
+
pip install infinite-context-gateway[cloud]
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
To install with Local Phase 2 dependencies (PyTorch, Transformers, PEFT):
|
|
45
|
+
```bash
|
|
46
|
+
pip install infinite-context-gateway[local]
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
## Usage
|
|
50
|
+
|
|
51
|
+
### Phase 1: Cloud Compression
|
|
52
|
+
Uses the Headroom API to semantically compress a massive context and send it to cloud providers (OpenAI, Anthropic) while keeping costs low.
|
|
53
|
+
|
|
54
|
+
```python
|
|
55
|
+
from infinite_context import ContextGateway
|
|
56
|
+
|
|
57
|
+
gateway = ContextGateway(
|
|
58
|
+
engine="cloud",
|
|
59
|
+
model_id="claude-3-5-sonnet-20240620",
|
|
60
|
+
api_key="your_anthropic_api_key",
|
|
61
|
+
compression_ratio=0.8
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
# You can pass in conversational history (rolling window)
|
|
65
|
+
history = [
|
|
66
|
+
{"role": "user", "content": "What is the codebase about?"},
|
|
67
|
+
{"role": "assistant", "content": "It is a Python application..."}
|
|
68
|
+
]
|
|
69
|
+
|
|
70
|
+
response = gateway.chat("How does the failover protocol work?", massive_context, history=history)
|
|
71
|
+
print(response)
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
### Phase 2: Local Test-Time Training (TTT)
|
|
75
|
+
Bypasses the KV-Cache entirely by injecting a PEFT LoRA adapter and baking the context directly into the model's neural weights on your local GPU. Includes Early Stopping latency optimizations and Generation repetition penalties.
|
|
76
|
+
|
|
77
|
+
```python
|
|
78
|
+
from infinite_context import ContextGateway
|
|
79
|
+
|
|
80
|
+
gateway = ContextGateway(
|
|
81
|
+
engine="local",
|
|
82
|
+
model_id="Qwen/Qwen2.5-0.5B-Instruct",
|
|
83
|
+
load_in_4bit=True
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
response = gateway.chat("What is the failover protocol?", massive_context)
|
|
87
|
+
print(response)
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
### Checkpoint Persistence
|
|
91
|
+
Save trained Fast Weights to disk and resume later without re-reading the document:
|
|
92
|
+
|
|
93
|
+
```python
|
|
94
|
+
from infinite_context import ContextGateway
|
|
95
|
+
|
|
96
|
+
# Train and keep state
|
|
97
|
+
gateway = ContextGateway(engine="local", model_id="Qwen/Qwen2.5-0.5B-Instruct")
|
|
98
|
+
gateway.chat("Summarise the document.", massive_context, keep_state=True)
|
|
99
|
+
gateway.save_state("./my_checkpoint")
|
|
100
|
+
|
|
101
|
+
del gateway # Free all GPU memory
|
|
102
|
+
|
|
103
|
+
# Resume later — zero re-training
|
|
104
|
+
gateway = ContextGateway(engine="local", model_id="Qwen/Qwen2.5-0.5B-Instruct")
|
|
105
|
+
gateway.load_state("./my_checkpoint")
|
|
106
|
+
response = gateway.chat("What was the protocol?", context="")
|
|
107
|
+
print(response)
|
|
108
|
+
```
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
infinite_context/__init__.py
|
|
4
|
+
infinite_context/gateway.py
|
|
5
|
+
infinite_context/cloud/__init__.py
|
|
6
|
+
infinite_context/cloud/client.py
|
|
7
|
+
infinite_context/cloud/compressor.py
|
|
8
|
+
infinite_context/local/__init__.py
|
|
9
|
+
infinite_context/local/chunk_manager.py
|
|
10
|
+
infinite_context/local/engine.py
|
|
11
|
+
infinite_context/local/ttt_module.py
|
|
12
|
+
infinite_context_gateway.egg-info/PKG-INFO
|
|
13
|
+
infinite_context_gateway.egg-info/SOURCES.txt
|
|
14
|
+
infinite_context_gateway.egg-info/dependency_links.txt
|
|
15
|
+
infinite_context_gateway.egg-info/requires.txt
|
|
16
|
+
infinite_context_gateway.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
infinite_context
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "infinite-context-gateway"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
authors = [
|
|
9
|
+
{ name="Dev Team", email="dev@example.com" },
|
|
10
|
+
]
|
|
11
|
+
description = "A hybrid AI context gateway combining Headroom cloud compression and In-Place Local Test-Time Training (TTT)."
|
|
12
|
+
readme = "README.md"
|
|
13
|
+
requires-python = ">=3.9"
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Programming Language :: Python :: 3",
|
|
16
|
+
"License :: OSI Approved :: MIT License",
|
|
17
|
+
"Operating System :: OS Independent",
|
|
18
|
+
]
|
|
19
|
+
dependencies = [
|
|
20
|
+
"httpx",
|
|
21
|
+
]
|
|
22
|
+
|
|
23
|
+
[project.optional-dependencies]
|
|
24
|
+
cloud = [
|
|
25
|
+
"headroom-ai",
|
|
26
|
+
]
|
|
27
|
+
local = [
|
|
28
|
+
"torch",
|
|
29
|
+
"transformers",
|
|
30
|
+
"peft",
|
|
31
|
+
"accelerate",
|
|
32
|
+
"bitsandbytes",
|
|
33
|
+
]
|
|
34
|
+
all = [
|
|
35
|
+
"headroom-ai",
|
|
36
|
+
"torch",
|
|
37
|
+
"transformers",
|
|
38
|
+
"peft",
|
|
39
|
+
"accelerate",
|
|
40
|
+
"bitsandbytes",
|
|
41
|
+
]
|
|
42
|
+
|
|
43
|
+
[tool.setuptools.packages.find]
|
|
44
|
+
include = ["infinite_context*"]
|