memahead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memahead/__init__.py +72 -0
- memahead/_embeddings.py +117 -0
- memahead/compressor.py +310 -0
- memahead/context.py +187 -0
- memahead/plan.py +279 -0
- memahead/scorer.py +171 -0
- memahead/tool_filter.py +218 -0
- memahead-0.1.0.dist-info/METADATA +147 -0
- memahead-0.1.0.dist-info/RECORD +12 -0
- memahead-0.1.0.dist-info/WHEEL +4 -0
- memahead-0.1.0.dist-info/licenses/LICENSE +201 -0
- memahead-0.1.0.dist-info/licenses/NOTICE +49 -0
memahead/__init__.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""memahead — agent memory optimized for what's ahead.
|
|
2
|
+
|
|
3
|
+
memahead compresses an LLM agent's context at each step of a multi-step
|
|
4
|
+
workflow using *forward-looking* plan awareness. Instead of compressing
|
|
5
|
+
greedily based on what already happened, memahead scores each chunk of context
|
|
6
|
+
against the *remaining* steps of the plan and drops what future steps won't
|
|
7
|
+
need — far fewer tokens per call without losing what matters downstream.
|
|
8
|
+
|
|
9
|
+
It builds on Headroom (``pip install headroom-ai``) for the underlying
|
|
10
|
+
compression mechanics and adds the plan-aware retention scoring layer on top.
|
|
11
|
+
|
|
12
|
+
Academic foundations:
|
|
13
|
+
- PAACE: Yuksel et al., arXiv:2512.16970 (Dec 2025)
|
|
14
|
+
- ACON: Kang et al., Microsoft, arXiv:2510.00615 (2025)
|
|
15
|
+
|
|
16
|
+
Quick start::
|
|
17
|
+
|
|
18
|
+
from memahead import Plan, Step, PlanAwareCompressor
|
|
19
|
+
|
|
20
|
+
plan = Plan([
|
|
21
|
+
Step("research", "Search and gather raw facts about the topic"),
|
|
22
|
+
Step("synthesize", "Identify key themes across the research"),
|
|
23
|
+
Step("draft", "Write a structured first draft"),
|
|
24
|
+
Step("revise", "Produce the final polished output"),
|
|
25
|
+
])
|
|
26
|
+
|
|
27
|
+
compressor = PlanAwareCompressor(quality=0.85)
|
|
28
|
+
compressed = compressor.compress(
|
|
29
|
+
history=prior_messages,
|
|
30
|
+
tools=all_tool_schemas,
|
|
31
|
+
plan=plan,
|
|
32
|
+
current_step="synthesize",
|
|
33
|
+
)
|
|
34
|
+
print(compressed.report)
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
from __future__ import annotations
|
|
38
|
+
|
|
39
|
+
from .compressor import PlanAwareCompressor
|
|
40
|
+
from .context import (
|
|
41
|
+
CompressedContext,
|
|
42
|
+
DroppedChunk,
|
|
43
|
+
TokenReport,
|
|
44
|
+
count_tokens,
|
|
45
|
+
)
|
|
46
|
+
from .plan import Plan, PlanGraph, Step
|
|
47
|
+
from .scorer import ChunkScore, RetentionScorer
|
|
48
|
+
from .tool_filter import ToolFilter, ToolMatch, filter_tools
|
|
49
|
+
|
|
50
|
+
# Package version; kept in sync with the wheel metadata (memahead-0.1.0).
__version__ = "0.1.0"

# Public names exported by ``from memahead import *``, grouped by the
# submodule each one is imported from above.
__all__ = [
    "__version__",
    # plan
    "Step",
    "Plan",
    "PlanGraph",
    # scoring (core novelty)
    "RetentionScorer",
    "ChunkScore",
    # tool filtering
    "ToolFilter",
    "ToolMatch",
    "filter_tools",
    # compression pipeline
    "PlanAwareCompressor",
    # result containers
    "CompressedContext",
    "TokenReport",
    "DroppedChunk",
    "count_tokens",
]
|
memahead/_embeddings.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""Internal embedding utilities shared by the scorer and the tool filter.
|
|
2
|
+
|
|
3
|
+
This module isolates the (optional, heavyweight) ``sentence-transformers``
|
|
4
|
+
dependency behind a tiny, swappable interface. Anything that produces a 2-D
|
|
5
|
+
array of row vectors from a list of strings can be used as an *embedder*,
|
|
6
|
+
which keeps the rest of the library testable without downloading a model.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Callable, List, Sequence, Union
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
__all__ = [
    "Embedder",
    "SentenceTransformerEmbedder",
    "default_embedder",
    "resolve_embedder",
    "cosine_similarity_matrix",
]

# An embedder is any callable that maps a list of texts to a (n, dim) matrix,
# one row per input text.
Embedder = Callable[[Sequence[str]], np.ndarray]

# sentence-transformers model used by SentenceTransformerEmbedder and
# default_embedder when no model name is given.
DEFAULT_MODEL = "all-MiniLM-L6-v2"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class SentenceTransformerEmbedder:
    """Embed texts with a ``sentence-transformers`` model loaded on demand.

    Construction only records ``model_name``. The library is imported, and
    the model instantiated, the first time the embedder is called. This keeps
    ``import memahead`` cheap and offline-friendly. The default model is
    ``all-MiniLM-L6-v2`` as described in the PAACE/ACON-inspired design.
    """

    def __init__(self, model_name: str = DEFAULT_MODEL) -> None:
        self.model_name = model_name
        # Filled in lazily by ``_ensure_model``.
        self._model = None

    def _ensure_model(self):
        """Return the underlying model, importing and loading it on first use.

        Raises:
            ImportError: if ``sentence-transformers`` is not installed.
        """
        if self._model is not None:
            return self._model
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as exc:  # pragma: no cover - environment dependent
            raise ImportError(
                "sentence-transformers is required for the default embedder. "
                "Install it with `pip install sentence-transformers`, or pass a "
                "custom `embedder` callable to RetentionScorer / the tool filter."
            ) from exc
        self._model = SentenceTransformer(self.model_name)
        return self._model

    def __call__(self, texts: Sequence[str]) -> np.ndarray:
        """Encode ``texts`` and return the embeddings as a float32 array.

        Embeddings are requested un-normalized. Cosine similarity is computed
        later by :func:`cosine_similarity_matrix`, which normalizes its inputs.
        """
        embeddings = self._ensure_model().encode(
            list(texts),
            convert_to_numpy=True,
            normalize_embeddings=False,
        )
        return np.asarray(embeddings, dtype=np.float32)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Process-wide cache so repeated scorers reuse the loaded weights.
# Holds at most one embedder: default_embedder() replaces it whenever a
# different model name is requested.
_DEFAULT_EMBEDDER: SentenceTransformerEmbedder | None = None
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def default_embedder(model_name: str = DEFAULT_MODEL) -> SentenceTransformerEmbedder:
    """Return the cached process-wide embedder for ``model_name``.

    A single :class:`SentenceTransformerEmbedder` is kept in a module-level
    slot and reused while callers ask for the same model. Requesting a
    different ``model_name`` swaps in a fresh instance, so only one model is
    cached at a time.
    """

    global _DEFAULT_EMBEDDER
    cached = _DEFAULT_EMBEDDER
    if cached is not None and cached.model_name == model_name:
        return cached
    _DEFAULT_EMBEDDER = SentenceTransformerEmbedder(model_name)
    return _DEFAULT_EMBEDDER
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def resolve_embedder(
    embedder: Union[Embedder, "SentenceTransformerEmbedder", None],
) -> Embedder:
    """Normalize the many ways a caller can supply an embedder.

    Accepts:

    * ``None``: use the cached default model (see :func:`default_embedder`).
    * Any object exposing a callable ``encode`` method, e.g. a raw
      ``SentenceTransformer``. It is wrapped so the result is a float32 array.
    * A plain callable mapping a sequence of strings to a 2-D array. It is
      returned unchanged.

    ``encode`` is checked *before* plain callability. A raw
    ``SentenceTransformer`` is a ``torch.nn.Module`` and is therefore
    callable, but its ``__call__`` runs the tensor-level ``forward`` pass,
    not text encoding. Returning it as-is would fail on the first list of
    strings.

    Raises:
        TypeError: if ``embedder`` is not ``None``, has no callable
            ``encode``, and is not itself callable.
    """

    if embedder is None:
        return default_embedder()
    encode = getattr(embedder, "encode", None)
    if callable(encode):
        return lambda texts: np.asarray(encode(list(texts)), dtype=np.float32)
    if callable(embedder):
        return embedder
    raise TypeError(
        "embedder must be None, a callable, or expose an `encode` method; "
        f"got {type(embedder)!r}"
    )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _l2_normalize(matrix: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale every row of ``matrix`` to unit Euclidean length (float32).

    A 1-D input is treated as one row vector and comes back with shape
    ``(1, dim)``. Row norms are floored at ``eps``, so an all-zero row stays
    all-zero instead of turning into NaNs.
    """
    rows = np.asarray(matrix, dtype=np.float32)
    if rows.ndim == 1:
        rows = rows[np.newaxis, :]
    lengths = np.linalg.norm(rows, axis=1, keepdims=True)
    return rows / np.maximum(lengths, eps)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def cosine_similarity_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Compute pairwise cosine similarities between the rows of ``a`` and ``b``.

    The result has shape ``(len(a), len(b))``. Entry ``[i, j]`` compares row
    ``i`` of ``a`` (e.g. a context chunk) with row ``j`` of ``b`` (e.g. a
    remaining plan step). Both inputs are row-normalized here, so callers
    may pass raw, un-normalized embeddings.
    """

    rows_a, rows_b = (_l2_normalize(m) for m in (a, b))
    return np.matmul(rows_a, rows_b.T)
|
memahead/compressor.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
"""The plan-aware compression pipeline.
|
|
2
|
+
|
|
3
|
+
:class:`PlanAwareCompressor` ties the pieces together:
|
|
4
|
+
|
|
5
|
+
history + tools + plan + current_step
|
|
6
|
+
-> split history into chunks
|
|
7
|
+
-> score chunks against the *remaining* plan steps (RetentionScorer)
|
|
8
|
+
-> drop chunks future steps won't need (plan-aware retention)
|
|
9
|
+
-> filter tool schemas to the current step (tool_filter, no LLM)
|
|
10
|
+
-> hand survivors to Headroom for the actual compression mechanics
|
|
11
|
+
-> return a CompressedContext (+ TokenReport)
|
|
12
|
+
|
|
13
|
+
memahead owns the *retention policy*; Headroom owns the *compression
|
|
14
|
+
mechanics*. If Headroom is not installed the pipeline still works — it simply
|
|
15
|
+
skips the mechanical compression step and relies on retention alone.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
|
21
|
+
|
|
22
|
+
from .context import (
|
|
23
|
+
CompressedContext,
|
|
24
|
+
DroppedChunk,
|
|
25
|
+
TokenReport,
|
|
26
|
+
_message_text,
|
|
27
|
+
count_message_tokens,
|
|
28
|
+
count_tool_tokens,
|
|
29
|
+
count_tokens,
|
|
30
|
+
)
|
|
31
|
+
from .plan import Plan, PlanGraph, Step
|
|
32
|
+
from .scorer import RetentionScorer
|
|
33
|
+
from .tool_filter import ToolFilter
|
|
34
|
+
|
|
35
|
+
__all__ = ["PlanAwareCompressor"]

# Either plan flavour accepted by PlanAwareCompressor. In this module both are
# used only through ``get(name)`` and ``remaining_from(name)``.
PlanLike = Union[Plan, PlanGraph]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class PlanAwareCompressor:
    """Compress agent context using forward-looking, plan-aware retention.

    Args:
        quality: Information-retention dial in ``[0.0, 1.0]``. Higher keeps
            more context (gentler compression); lower is more aggressive.
            Defaults to ``0.85``.
        retention_threshold: Optional absolute score cutoff in ``[0.0, 1.0]``.
            When set, chunks scoring below it are dropped, overriding the
            ``quality``-derived relative policy. Useful for reproducible runs.
        tool_threshold: Match cutoff for keeping a tool schema.
        scorer: A custom :class:`RetentionScorer` (e.g. with an injected
            embedder). If ``None``, one is created.
        tool_filter: A custom :class:`ToolFilter`. If ``None``, one is created.
        embedder: Convenience way to inject one embedder into both the scorer
            and the tool filter (ignored if explicit ``scorer``/``tool_filter``
            are given).
        use_headroom: Whether to run survivors through Headroom for mechanical
            compression. Defaults to ``True``; silently no-ops if Headroom is
            unavailable.
        model: Optional model name forwarded to Headroom and the tokenizer.
        keep_system: Always retain ``system`` role messages. Defaults to True.
        keep_last: Always retain the final message (the current turn's input).
            Defaults to True.

    Raises:
        ValueError: if ``quality`` or ``retention_threshold`` lies outside
            ``[0.0, 1.0]``.
    """

    def __init__(
        self,
        quality: float = 0.85,
        *,
        retention_threshold: Optional[float] = None,
        tool_threshold: float = 0.3,
        scorer: Optional[RetentionScorer] = None,
        tool_filter: Optional[ToolFilter] = None,
        embedder: Optional[Any] = None,
        use_headroom: bool = True,
        model: Optional[str] = None,
        keep_system: bool = True,
        keep_last: bool = True,
    ) -> None:
        # Validate first so bad configuration fails before any scorer/filter
        # (and possibly an embedding model) is constructed.
        if not 0.0 <= quality <= 1.0:
            raise ValueError("quality must be in [0.0, 1.0]")
        if retention_threshold is not None and not 0.0 <= retention_threshold <= 1.0:
            raise ValueError("retention_threshold must be in [0.0, 1.0]")

        self.quality = quality
        self.retention_threshold = retention_threshold
        self.use_headroom = use_headroom
        self.model = model
        self.keep_system = keep_system
        self.keep_last = keep_last

        # ``is not None`` rather than truthiness, so a caller-supplied object
        # that happens to be falsy (e.g. defines __len__) is still honoured.
        self.scorer = (
            scorer if scorer is not None else RetentionScorer(embedder=embedder)
        )
        self.tool_filter = (
            tool_filter
            if tool_filter is not None
            else ToolFilter(embedder=embedder, threshold=tool_threshold)
        )

    # -- helpers ------------------------------------------------------------

    @staticmethod
    def _resolve_step(plan: PlanLike, current_step: Union[Step, str]) -> Step:
        """Return ``current_step`` as a :class:`Step`, looking names up in ``plan``."""
        if isinstance(current_step, Step):
            return current_step
        return plan.get(current_step)

    @staticmethod
    def _is_system(message: Any) -> bool:
        """True for dict messages whose ``role`` is ``"system"``."""
        return isinstance(message, dict) and message.get("role") == "system"

    def _always_keep_mask(self, history: List[Any]) -> List[bool]:
        """Flag messages that must survive regardless of their score.

        System messages are flagged when ``keep_system`` is set, and the final
        message is flagged when ``keep_last`` is set.
        """
        n = len(history)
        mask = [False] * n
        for i, msg in enumerate(history):
            if self.keep_system and self._is_system(msg):
                mask[i] = True
        if self.keep_last and n > 0:
            mask[n - 1] = True
        return mask

    def _decide_retention(
        self,
        scores: List[float],
        always_keep: List[bool],
        has_future: bool,
    ) -> List[bool]:
        """Return a keep/drop flag per scored chunk from scores + policy.

        The result has one entry per element of ``scores``. A ``True`` in
        ``always_keep`` forces the chunk at that index to be kept. Entries
        past the end of either list are ignored rather than raising.
        """

        n = len(scores)
        if n == 0:
            return []

        # No future steps -> nothing to prune against; keep everything.
        if not has_future:
            return [True] * n

        if self.retention_threshold is not None:
            keep = [s >= self.retention_threshold for s in scores]
        else:
            # Relative policy: min-max normalize, then keep the top band as
            # governed by `quality`. quality=0.85 -> keep normalized >= 0.15.
            lo = min(scores)
            hi = max(scores)
            cutoff = 1.0 - self.quality
            if hi - lo < 1e-9:
                # All equal: a flat horizon. Keep them all rather than guess.
                keep = [True] * n
            else:
                span = hi - lo
                keep = [((s - lo) / span) >= cutoff for s in scores]

        # Slice so a short ``always_keep`` cannot cause an IndexError.
        for i, forced in enumerate(always_keep[:n]):
            if forced:
                keep[i] = True
        return keep

    def _apply_headroom(self, messages: List[Any]) -> List[Any]:
        """Run messages through Headroom's ``compress`` if available.

        Defensive by design: any import error, signature mismatch, or
        unexpected return shape falls back to the input unchanged so that
        retention-only compression still works. When ``model`` is set and
        Headroom rejects the ``model`` keyword with a ``TypeError``, the call
        is retried once without it. Without a ``model`` there is nothing to
        strip, so a failure is not retried with the identical call.
        """

        if not self.use_headroom or not messages:
            return messages
        try:
            from headroom import compress  # type: ignore
        except Exception:
            return messages

        try:
            if self.model:
                try:
                    result = compress(messages, model=self.model)
                except TypeError:
                    # This Headroom version may not accept ``model``.
                    result = compress(messages)
            else:
                result = compress(messages)
        except Exception:
            return messages

        return self._normalize_headroom_result(result, fallback=messages)

    @staticmethod
    def _normalize_headroom_result(result: Any, fallback: List[Any]) -> List[Any]:
        """Extract a message list from Headroom's return value, else ``fallback``.

        Accepts a bare list, an object with a list-valued ``messages`` /
        ``compressed`` / ``output`` / ``result`` attribute, or a dict with one
        of those keys.
        """
        if result is None:
            return fallback
        if isinstance(result, list):
            return result
        # Common attribute names across compression libraries.
        for attr in ("messages", "compressed", "output", "result"):
            value = getattr(result, attr, None)
            if isinstance(value, list):
                return value
        if isinstance(result, dict):
            for key in ("messages", "compressed", "output", "result"):
                value = result.get(key)
                if isinstance(value, list):
                    return value
        return fallback

    # -- public API ---------------------------------------------------------

    def compress(
        self,
        history: Sequence[Any],
        tools: Sequence[Any],
        plan: PlanLike,
        current_step: Union[Step, str],
    ) -> CompressedContext:
        """Compress ``history`` and ``tools`` for the given step of ``plan``.

        Args:
            history: Prior chat messages (dicts with ``role``/``content``, or
                plain strings). Each message is treated as one context chunk.
                ``None`` is treated as an empty history.
            tools: The full catalog of tool schemas available to the agent.
                ``None`` is treated as no tools.
            plan: The :class:`Plan` (or :class:`PlanGraph`) being executed.
            current_step: The step about to run — the pivot for "what's ahead".

        Returns:
            A :class:`CompressedContext` with lean ``messages``, filtered
            ``tools``, and a :class:`TokenReport`.
        """

        history = list(history) if history is not None else []
        tools = list(tools or [])

        step = self._resolve_step(plan, current_step)
        step_key = step.name if isinstance(current_step, Step) else str(current_step)
        remaining_steps = plan.remaining_from(step_key)
        has_future = len(remaining_steps) > 0

        before_tokens = count_message_tokens(history, self.model) + count_tool_tokens(
            tools, self.model
        )

        # 1) chunk + score against the forward horizon.
        chunk_texts = [_message_text(m) for m in history]
        always_keep = self._always_keep_mask(history)

        if history:
            chunk_scores = self.scorer.score(chunk_texts, remaining_steps)
            scores = [cs.score for cs in chunk_scores]
        else:
            scores = []

        # 2) decide retention.
        keep_flags = self._decide_retention(scores, always_keep, has_future)
        # A chunk the scorer returned no score for cannot be judged against
        # the plan, so keep it instead of indexing past ``keep_flags``.
        keep_flags += [True] * (len(history) - len(keep_flags))

        retained_messages: List[Any] = []
        retained_scores: Dict[str, float] = {}
        dropped: List[DroppedChunk] = []
        for i, msg in enumerate(history):
            source = f"message[{i}]"
            tok = count_tokens(chunk_texts[i], self.model)
            score = scores[i] if i < len(scores) else None
            if keep_flags[i]:
                retained_messages.append(msg)
                if score is not None:
                    retained_scores[source] = round(score, 4)
            else:
                dropped.append(
                    DroppedChunk(
                        source=source,
                        kind="message",
                        tokens_before=tok,
                        tokens_after=0,
                        score=round(score, 4) if score is not None else None,
                        reason="below retention threshold for remaining plan steps",
                    )
                )

        # 3) filter tools to the current step (deterministic, no LLM call).
        tool_matches = self.tool_filter.match(tools, step)
        kept_tools = [m.tool for m in tool_matches if m.kept]
        for m in tool_matches:
            if not m.kept:
                dropped.append(
                    DroppedChunk(
                        source=f"tool:{m.name or '?'}",
                        kind="tool",
                        tokens_before=count_tokens(_tool_schema_text(m.tool), self.model),
                        tokens_after=0,
                        score=round(m.score, 4),
                        reason="tool not relevant to current step",
                    )
                )

        # 4) hand survivors to Headroom for mechanical compression.
        compressed_messages = self._apply_headroom(retained_messages)

        after_tokens = count_message_tokens(
            compressed_messages, self.model
        ) + count_tool_tokens(kept_tools, self.model)

        report = TokenReport(
            before=before_tokens,
            after=after_tokens,
            dropped=dropped,
        )

        return CompressedContext(
            messages=compressed_messages,
            tools=kept_tools,
            report=report,
            retained_scores=retained_scores,
        )
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def _tool_schema_text(tool: Any) -> str:
    """Return the text used to count a tool schema's tokens.

    Delegates to ``context._tool_text`` so that per-tool counts for dropped
    tools match the totals produced by ``count_tool_tokens``. The import is
    local because ``_tool_text`` is not among this module's top-level imports.
    """
    from . import context

    return context._tool_text(tool)
|
memahead/context.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""Result containers for compression: :class:`CompressedContext` and
|
|
2
|
+
:class:`TokenReport`, plus a small token-estimation helper.
|
|
3
|
+
|
|
4
|
+
These are intentionally dependency-light dataclasses so they can be passed
|
|
5
|
+
around, serialized, and inspected without importing heavy ML packages.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Any, Dict, List, Optional
|
|
12
|
+
|
|
13
|
+
# Public API of this module. count_message_tokens, count_tool_tokens and the
# underscore helpers are importable but not listed here.
__all__ = ["count_tokens", "DroppedChunk", "TokenReport", "CompressedContext"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Resolved tiktoken encodings keyed by the ``model`` argument. A cached
# ``None`` records that tiktoken could not be used (not installed, or the
# encoding could not be loaded). The failing import/load is therefore not
# retried on every call, and every count in the process comes from the same
# estimator, which keeps before/after comparisons consistent.
_ENCODING_CACHE: Dict[Optional[str], Any] = {}


def _get_encoding(model: Optional[str]) -> Any:
    """Return a tiktoken encoding for ``model`` (``cl100k_base`` fallback), or ``None``."""

    if model in _ENCODING_CACHE:
        return _ENCODING_CACHE[model]
    encoding: Any = None
    try:
        import tiktoken

        try:
            encoding = (
                tiktoken.encoding_for_model(model)
                if model
                else tiktoken.get_encoding("cl100k_base")
            )
        except KeyError:
            # Unknown model name: use the default encoding instead.
            encoding = tiktoken.get_encoding("cl100k_base")
    except Exception:
        encoding = None
    _ENCODING_CACHE[model] = encoding
    return encoding


def count_tokens(text: str, model: Optional[str] = None) -> int:
    """Estimate the number of tokens in ``text``.

    Uses :mod:`tiktoken` when available for accuracy; otherwise falls back to
    a fast heuristic (~4 characters per token). The heuristic keeps the
    library usable with zero extra dependencies while still giving a stable,
    monotonic measure for reporting savings. The encoding lookup (or its
    failure) is cached per ``model``.

    Special-token markers such as ``"<|endoftext|>"`` are counted as ordinary
    text. By default tiktoken raises on them, which would otherwise push just
    those texts onto the heuristic.

    Args:
        text: The text to measure.
        model: Optional model name used to pick a tiktoken encoding.

    Returns:
        The estimated token count (``0`` for empty text).
    """

    if not text:
        return 0
    encoding = _get_encoding(model)
    if encoding is not None:
        try:
            return len(encoding.encode(text, disallowed_special=()))
        except Exception:  # deliberate best effort: fall back to the heuristic
            pass
    # Heuristic fallback: ~4 chars/token, with a floor of one token per word.
    char_estimate = (len(text) + 3) // 4
    word_estimate = len(text.split())
    return max(char_estimate, word_estimate)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _message_text(message: Any) -> str:
    """Extract the textual content from a chat message (dict or str).

    Handles plain strings and dict messages whose ``content`` is a string,
    ``None``, or a list of parts (OpenAI-style multimodal blocks: strings, or
    dicts with a ``text`` key). Any other value is rendered with ``str()``.

    ``content: None`` (e.g. an assistant message that only carries tool
    calls) and parts with ``text: None`` yield ``""``. Without this, the
    literal ``"None"`` would be token-counted and scored as real context.
    """

    if isinstance(message, str):
        return message
    if not isinstance(message, dict):
        return str(message)
    content = message.get("content", "")
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    # Content can be a list of parts (OpenAI-style multimodal blocks).
    if isinstance(content, list):
        parts: List[str] = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict):
                text = part.get("text", "")
                parts.append("" if text is None else str(text))
        return "\n".join(parts)
    return str(content)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _tool_text(tool: Any) -> str:
    """Serialize a tool schema into stable text for token counting.

    Dicts are rendered as JSON with sorted keys, so the output is
    deterministic. This covers both bare schemas and OpenAI-style
    ``{"type", "function": {...}}`` envelopes, so counts reflect what is
    actually sent. Strings pass through untouched; anything else falls back
    to ``str()``.
    """

    if isinstance(tool, dict):
        import json

        # default=str stringifies values json cannot encode natively (e.g. sets).
        return json.dumps(tool, sort_keys=True, default=str)
    return tool if isinstance(tool, str) else str(tool)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def count_message_tokens(messages: List[Any], model: Optional[str] = None) -> int:
    """Total estimated tokens across a list of chat messages.

    ``None`` (or an empty list) counts as zero tokens, matching
    :func:`count_tool_tokens`.
    """

    return sum(count_tokens(_message_text(m), model) for m in (messages or []))
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def count_tool_tokens(tools: List[Any], model: Optional[str] = None) -> int:
    """Sum the estimated token cost of every tool schema in ``tools``.

    A falsy ``tools`` (``None`` or empty) costs zero tokens.
    """

    if not tools:
        return 0
    return sum(count_tokens(_tool_text(tool), model) for tool in tools)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass
class DroppedChunk:
    """One piece of context that compression removed or shrank.

    Attributes:
        source: Human-readable origin (e.g. ``"message[3]"`` or ``"tool:search"``).
        kind: ``"message"`` or ``"tool"``.
        tokens_before: Tokens the chunk occupied before compression.
        tokens_after: Tokens left afterwards; ``0`` when fully dropped.
        score: Forward-looking retention score (0.0–1.0), when one exists.
        reason: Short explanation of why the chunk was dropped.
    """

    source: str
    kind: str
    tokens_before: int
    tokens_after: int = 0
    score: Optional[float] = None
    reason: str = ""

    @property
    def tokens_saved(self) -> int:
        """Tokens reclaimed for this chunk, clamped so it is never negative."""
        delta = self.tokens_before - self.tokens_after
        return max(delta, 0)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@dataclass
class TokenReport:
    """Before/after token totals for one compression call.

    Attributes:
        before: Total tokens before compression (messages + tools).
        after: Total tokens after compression (messages + tools).
        dropped: Per-chunk records of what was removed or shrunk.
    """

    before: int
    after: int
    dropped: List[DroppedChunk] = field(default_factory=list)

    @property
    def saved(self) -> int:
        """Absolute number of tokens saved, clamped at zero."""

        return max(self.before - self.after, 0)

    @property
    def compression_ratio(self) -> float:
        """Share of the original tokens that were removed, in ``[0.0, 1.0]``.

        ``0.0`` means nothing was saved; ``0.75`` means 75% fewer tokens.
        Rounded to four decimals; ``0.0`` when ``before`` is not positive.
        """

        if self.before <= 0:
            return 0.0
        ratio = self.saved / self.before
        return round(ratio, 4)

    def dropped_sources(self) -> List[str]:
        """List the ``source`` of every chunk that was dropped entirely."""

        return [chunk.source for chunk in self.dropped if chunk.tokens_after == 0]

    def __repr__(self) -> str:
        summary = (
            f"before={self.before}, after={self.after}, "
            f"saved={self.saved}, compression_ratio={self.compression_ratio}"
        )
        return f"TokenReport({summary})"
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@dataclass
class CompressedContext:
    """Lean, ready-to-send context returned by the compressor.

    Attributes:
        messages: Compressed chat messages, ready to pass to an LLM call.
        tools: Tool schemas kept for the current step.
        report: A :class:`TokenReport` describing what was saved.
        retained_scores: Retained message source -> retention score, handy
            for debugging and evaluation.
    """

    messages: List[Any]
    tools: List[Any]
    report: TokenReport
    retained_scores: Dict[str, float] = field(default_factory=dict)

    def __repr__(self) -> str:
        n_messages, n_tools = len(self.messages), len(self.tools)
        return (
            f"CompressedContext(messages={n_messages}, "
            f"tools={n_tools}, report={self.report!r})"
        )
|