@animalabs/membrane 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/context/index.d.ts +10 -0
- package/dist/context/index.d.ts.map +1 -0
- package/dist/context/index.js +9 -0
- package/dist/context/index.js.map +1 -0
- package/dist/context/process.d.ts +22 -0
- package/dist/context/process.d.ts.map +1 -0
- package/dist/context/process.js +369 -0
- package/dist/context/process.js.map +1 -0
- package/dist/context/types.d.ts +118 -0
- package/dist/context/types.d.ts.map +1 -0
- package/dist/context/types.js +60 -0
- package/dist/context/types.js.map +1 -0
- package/dist/index.d.ts +12 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +18 -0
- package/dist/index.js.map +1 -0
- package/dist/membrane.d.ts +96 -0
- package/dist/membrane.d.ts.map +1 -0
- package/dist/membrane.js +893 -0
- package/dist/membrane.js.map +1 -0
- package/dist/providers/anthropic.d.ts +36 -0
- package/dist/providers/anthropic.d.ts.map +1 -0
- package/dist/providers/anthropic.js +265 -0
- package/dist/providers/anthropic.js.map +1 -0
- package/dist/providers/index.d.ts +8 -0
- package/dist/providers/index.d.ts.map +1 -0
- package/dist/providers/index.js +8 -0
- package/dist/providers/index.js.map +1 -0
- package/dist/providers/openai-compatible.d.ts +74 -0
- package/dist/providers/openai-compatible.d.ts.map +1 -0
- package/dist/providers/openai-compatible.js +412 -0
- package/dist/providers/openai-compatible.js.map +1 -0
- package/dist/providers/openai.d.ts +69 -0
- package/dist/providers/openai.d.ts.map +1 -0
- package/dist/providers/openai.js +455 -0
- package/dist/providers/openai.js.map +1 -0
- package/dist/providers/openrouter.d.ts +76 -0
- package/dist/providers/openrouter.d.ts.map +1 -0
- package/dist/providers/openrouter.js +492 -0
- package/dist/providers/openrouter.js.map +1 -0
- package/dist/transforms/chat.d.ts +52 -0
- package/dist/transforms/chat.d.ts.map +1 -0
- package/dist/transforms/chat.js +136 -0
- package/dist/transforms/chat.js.map +1 -0
- package/dist/transforms/index.d.ts +6 -0
- package/dist/transforms/index.d.ts.map +1 -0
- package/dist/transforms/index.js +6 -0
- package/dist/transforms/index.js.map +1 -0
- package/dist/transforms/prefill.d.ts +89 -0
- package/dist/transforms/prefill.d.ts.map +1 -0
- package/dist/transforms/prefill.js +401 -0
- package/dist/transforms/prefill.js.map +1 -0
- package/dist/types/config.d.ts +103 -0
- package/dist/types/config.d.ts.map +1 -0
- package/dist/types/config.js +21 -0
- package/dist/types/config.js.map +1 -0
- package/dist/types/content.d.ts +81 -0
- package/dist/types/content.d.ts.map +1 -0
- package/dist/types/content.js +40 -0
- package/dist/types/content.js.map +1 -0
- package/dist/types/errors.d.ts +42 -0
- package/dist/types/errors.d.ts.map +1 -0
- package/dist/types/errors.js +208 -0
- package/dist/types/errors.js.map +1 -0
- package/dist/types/index.d.ts +18 -0
- package/dist/types/index.d.ts.map +1 -0
- package/dist/types/index.js +9 -0
- package/dist/types/index.js.map +1 -0
- package/dist/types/message.d.ts +46 -0
- package/dist/types/message.d.ts.map +1 -0
- package/dist/types/message.js +38 -0
- package/dist/types/message.js.map +1 -0
- package/dist/types/provider.d.ts +155 -0
- package/dist/types/provider.d.ts.map +1 -0
- package/dist/types/provider.js +5 -0
- package/dist/types/provider.js.map +1 -0
- package/dist/types/request.d.ts +78 -0
- package/dist/types/request.d.ts.map +1 -0
- package/dist/types/request.js +5 -0
- package/dist/types/request.js.map +1 -0
- package/dist/types/response.d.ts +131 -0
- package/dist/types/response.d.ts.map +1 -0
- package/dist/types/response.js +7 -0
- package/dist/types/response.js.map +1 -0
- package/dist/types/streaming.d.ts +164 -0
- package/dist/types/streaming.d.ts.map +1 -0
- package/dist/types/streaming.js +5 -0
- package/dist/types/streaming.js.map +1 -0
- package/dist/types/tools.d.ts +71 -0
- package/dist/types/tools.d.ts.map +1 -0
- package/dist/types/tools.js +5 -0
- package/dist/types/tools.js.map +1 -0
- package/dist/utils/index.d.ts +5 -0
- package/dist/utils/index.d.ts.map +1 -0
- package/dist/utils/index.js +5 -0
- package/dist/utils/index.js.map +1 -0
- package/dist/utils/stream-parser.d.ts +53 -0
- package/dist/utils/stream-parser.d.ts.map +1 -0
- package/dist/utils/stream-parser.js +359 -0
- package/dist/utils/stream-parser.js.map +1 -0
- package/dist/utils/tool-parser.d.ts +130 -0
- package/dist/utils/tool-parser.d.ts.map +1 -0
- package/dist/utils/tool-parser.js +571 -0
- package/dist/utils/tool-parser.js.map +1 -0
- package/package.json +37 -0
- package/src/context/index.ts +24 -0
- package/src/context/process.ts +520 -0
- package/src/context/types.ts +231 -0
- package/src/index.ts +23 -0
- package/src/membrane.ts +1174 -0
- package/src/providers/anthropic.ts +340 -0
- package/src/providers/index.ts +31 -0
- package/src/providers/openai-compatible.ts +570 -0
- package/src/providers/openai.ts +625 -0
- package/src/providers/openrouter.ts +662 -0
- package/src/transforms/chat.ts +212 -0
- package/src/transforms/index.ts +22 -0
- package/src/transforms/prefill.ts +585 -0
- package/src/types/config.ts +172 -0
- package/src/types/content.ts +181 -0
- package/src/types/errors.ts +277 -0
- package/src/types/index.ts +154 -0
- package/src/types/message.ts +89 -0
- package/src/types/provider.ts +249 -0
- package/src/types/request.ts +131 -0
- package/src/types/response.ts +223 -0
- package/src/types/streaming.ts +231 -0
- package/src/types/tools.ts +92 -0
- package/src/utils/index.ts +15 -0
- package/src/utils/stream-parser.ts +440 -0
- package/src/utils/tool-parser.ts +715 -0
|
@@ -0,0 +1,520 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Context processing - main entry point
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
import type { Membrane } from '../membrane.js';
|
|
6
|
+
import type { NormalizedMessage, NormalizedRequest } from '../types/index.js';
|
|
7
|
+
import type {
|
|
8
|
+
ContextInput,
|
|
9
|
+
ContextState,
|
|
10
|
+
ContextOutput,
|
|
11
|
+
ContextInfo,
|
|
12
|
+
ContextConfig,
|
|
13
|
+
ContextStreamOptions,
|
|
14
|
+
CacheMarker,
|
|
15
|
+
} from './types.js';
|
|
16
|
+
import {
|
|
17
|
+
createInitialState,
|
|
18
|
+
defaultTokenEstimator,
|
|
19
|
+
DEFAULT_CONTEXT_CONFIG,
|
|
20
|
+
} from './types.js';
|
|
21
|
+
|
|
22
|
+
// ============================================================================
|
|
23
|
+
// Main Entry Point
|
|
24
|
+
// ============================================================================
|
|
25
|
+
|
|
26
|
+
/**
|
|
27
|
+
* Process context and stream LLM response.
|
|
28
|
+
*
|
|
29
|
+
* This function handles:
|
|
30
|
+
* - Rolling/truncation based on thresholds
|
|
31
|
+
* - Cache marker placement for prompt caching
|
|
32
|
+
* - Hard limit enforcement
|
|
33
|
+
* - State management
|
|
34
|
+
*
|
|
35
|
+
* @param membrane - Configured Membrane instance
|
|
36
|
+
* @param input - Messages, config, and context settings
|
|
37
|
+
* @param state - Previous state (null for first call)
|
|
38
|
+
* @param options - Stream options
|
|
39
|
+
* @returns Response, updated state, and context info
|
|
40
|
+
*/
|
|
41
|
+
export async function processContext(
  membrane: Membrane,
  input: ContextInput,
  state: ContextState | null,
  options?: ContextStreamOptions
): Promise<ContextOutput> {
  // Merge config with defaults
  const contextConfig = mergeConfig(input.context);
  const tokenEstimator = contextConfig.tokenEstimator ?? defaultTokenEstimator;

  // Initialize or continue state
  let currentState = state ?? createInitialState();

  // Detect discontinuity (new conversation or branch switch)
  // NOTE(review): detectDiscontinuity compares getMessageId values across
  // calls; if messages lack metadata.sourceId the ids are randomly generated
  // per call, so overlap is always 0 — confirm callers always set sourceId.
  const isDiscontinuous = detectDiscontinuity(input.messages, currentState);
  if (isDiscontinuous) {
    currentState = createInitialState();
  }

  // Calculate tokens for all messages (estimate only — no provider call).
  const messageTokens = input.messages.map(m => ({
    message: m,
    tokens: tokenEstimator(m),
    id: getMessageId(m),
  }));

  const totalTokens = messageTokens.reduce((sum, m) => sum + m.tokens, 0);
  const totalCharacters = calculateCharacters(input.messages);

  // Determine if we should roll (soft threshold + grace, or hard limits)
  const rollDecision = shouldRoll(
    currentState,
    input.messages.length,
    totalTokens,
    totalCharacters,
    contextConfig
  );

  // Apply rolling/truncation if needed
  let keptMessages = input.messages;
  let messagesDropped = 0;
  let didRoll = false;
  let hardLimitHit = false;

  if (rollDecision.shouldRoll) {
    const truncateResult = truncateMessages(
      messageTokens,
      rollDecision.targetTokens,
      rollDecision.targetMessages,
      contextConfig
    );

    keptMessages = truncateResult.kept.map(m => m.message);
    messagesDropped = truncateResult.dropped;
    didRoll = true;
    hardLimitHit = rollDecision.reason === 'hard_limit';
  }

  // Recalculate tokens after truncation (the window may have shrunk)
  const keptTokens = keptMessages.map(m => ({
    message: m,
    tokens: tokenEstimator(m),
    id: getMessageId(m),
  }));
  const keptTotalTokens = keptTokens.reduce((sum, m) => sum + m.tokens, 0);

  // Place cache markers (stable across turns unless we rolled)
  const cacheMarkers = placeCacheMarkers(
    keptMessages,
    keptTokens,
    currentState,
    didRoll,
    contextConfig
  );

  // Apply cache markers to messages (stamps metadata.cacheControl)
  const messagesWithCache = applyCacheMarkers(keptMessages, cacheMarkers);

  // Calculate cached/uncached tokens: everything up to the LAST marker is
  // treated as cached, the remainder as uncached.
  const lastMarker = cacheMarkers[cacheMarkers.length - 1];
  const cachedTokens = lastMarker?.tokenEstimate ?? 0;
  const uncachedTokens = keptTotalTokens - cachedTokens;

  // Build request
  const request: NormalizedRequest = {
    messages: messagesWithCache,
    system: input.system,
    tools: input.tools,
    config: input.config,
  };

  // Stream response (forwards chunk callback and abort signal verbatim)
  const response = await membrane.stream(request, {
    onChunk: options?.onChunk,
    signal: options?.signal,
  });

  // Update state
  const newState: ContextState = {
    cacheMarkers,
    windowMessageIds: keptMessages.map(m => getMessageId(m)),
    messagesSinceRoll: didRoll ? 1 : currentState.messagesSinceRoll + 1,
    // NOTE(review): this adds the FULL current window size every turn, so
    // tokensSinceRoll grows by keptTotalTokens per call rather than by the
    // newly added tokens — confirm this is the intended threshold semantics.
    tokensSinceRoll: didRoll ? keptTotalTokens : currentState.tokensSinceRoll + keptTotalTokens,
    inGracePeriod: rollDecision.enteredGrace || (currentState.inGracePeriod && !didRoll),
    lastRollTime: didRoll ? new Date().toISOString() : currentState.lastRollTime,
  };

  // Build info (diagnostics returned to the caller alongside the response)
  const info: ContextInfo = {
    didRoll,
    messagesDropped,
    messagesKept: keptMessages.length,
    cacheMarkers,
    cachedTokens,
    uncachedTokens,
    totalTokens: keptTotalTokens,
    hardLimitHit,
  };

  return { response, state: newState, info };
}
|
|
162
|
+
|
|
163
|
+
// ============================================================================
|
|
164
|
+
// Helper Functions
|
|
165
|
+
// ============================================================================
|
|
166
|
+
|
|
167
|
+
function mergeConfig(config: ContextConfig): ContextConfig {
|
|
168
|
+
return {
|
|
169
|
+
rolling: {
|
|
170
|
+
...DEFAULT_CONTEXT_CONFIG.rolling,
|
|
171
|
+
...config.rolling,
|
|
172
|
+
},
|
|
173
|
+
limits: {
|
|
174
|
+
...DEFAULT_CONTEXT_CONFIG.limits,
|
|
175
|
+
...config.limits,
|
|
176
|
+
},
|
|
177
|
+
cache: {
|
|
178
|
+
...DEFAULT_CONTEXT_CONFIG.cache,
|
|
179
|
+
...config.cache,
|
|
180
|
+
},
|
|
181
|
+
tokenEstimator: config.tokenEstimator,
|
|
182
|
+
};
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
function getMessageId(message: NormalizedMessage): string {
|
|
186
|
+
return message.metadata?.sourceId ?? `msg-${Math.random().toString(36).slice(2)}`;
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
function detectDiscontinuity(
|
|
190
|
+
messages: NormalizedMessage[],
|
|
191
|
+
state: ContextState
|
|
192
|
+
): boolean {
|
|
193
|
+
if (state.windowMessageIds.length === 0) {
|
|
194
|
+
return false; // First call, not a discontinuity
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
const currentIds = new Set(messages.map(m => getMessageId(m)));
|
|
198
|
+
const overlap = state.windowMessageIds.filter(id => currentIds.has(id));
|
|
199
|
+
|
|
200
|
+
// If less than 50% overlap, consider it a new conversation
|
|
201
|
+
return overlap.length < state.windowMessageIds.length * 0.5;
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
function calculateCharacters(messages: NormalizedMessage[]): number {
|
|
205
|
+
let chars = 0;
|
|
206
|
+
for (const msg of messages) {
|
|
207
|
+
for (const block of msg.content) {
|
|
208
|
+
if (block.type === 'text') {
|
|
209
|
+
chars += block.text.length;
|
|
210
|
+
} else if (block.type === 'tool_result') {
|
|
211
|
+
const content = typeof block.content === 'string'
|
|
212
|
+
? block.content
|
|
213
|
+
: JSON.stringify(block.content);
|
|
214
|
+
chars += content.length;
|
|
215
|
+
}
|
|
216
|
+
// Images not counted for character limits
|
|
217
|
+
}
|
|
218
|
+
}
|
|
219
|
+
return chars;
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
/** Outcome of evaluating rolling thresholds and hard limits for one turn. */
interface RollDecision {
  // True when the window must be truncated before sending.
  shouldRoll: boolean;
  // Why the roll is happening; absent when shouldRoll is false.
  // NOTE(review): 'threshold' is declared but never produced by shouldRoll
  // in this file — confirm whether it is emitted elsewhere or is dead.
  reason?: 'threshold' | 'grace_exceeded' | 'hard_limit';
  // Truncation targets forwarded to truncateMessages; either may be omitted,
  // in which case truncation falls back to the configured rolling buffer.
  targetTokens?: number;
  targetMessages?: number;
  // True when this turn crossed the soft threshold but is still inside grace.
  enteredGrace: boolean;
}
|
|
229
|
+
|
|
230
|
+
function shouldRoll(
|
|
231
|
+
state: ContextState,
|
|
232
|
+
messageCount: number,
|
|
233
|
+
totalTokens: number,
|
|
234
|
+
totalCharacters: number,
|
|
235
|
+
config: ContextConfig
|
|
236
|
+
): RollDecision {
|
|
237
|
+
const { rolling, limits } = config;
|
|
238
|
+
const unit = rolling.unit ?? 'messages';
|
|
239
|
+
|
|
240
|
+
const threshold = rolling.threshold;
|
|
241
|
+
const grace = rolling.grace ?? 0;
|
|
242
|
+
const maxThreshold = threshold + grace;
|
|
243
|
+
|
|
244
|
+
// Check hard limits first (always enforced)
|
|
245
|
+
if (limits?.maxCharacters && totalCharacters > limits.maxCharacters) {
|
|
246
|
+
return {
|
|
247
|
+
shouldRoll: true,
|
|
248
|
+
reason: 'hard_limit',
|
|
249
|
+
targetTokens: limits.maxTokens,
|
|
250
|
+
targetMessages: limits.maxMessages,
|
|
251
|
+
enteredGrace: false,
|
|
252
|
+
};
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
if (limits?.maxTokens && totalTokens > limits.maxTokens) {
|
|
256
|
+
return {
|
|
257
|
+
shouldRoll: true,
|
|
258
|
+
reason: 'hard_limit',
|
|
259
|
+
targetTokens: limits.maxTokens,
|
|
260
|
+
targetMessages: limits.maxMessages,
|
|
261
|
+
enteredGrace: false,
|
|
262
|
+
};
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
if (limits?.maxMessages && messageCount > limits.maxMessages) {
|
|
266
|
+
return {
|
|
267
|
+
shouldRoll: true,
|
|
268
|
+
reason: 'hard_limit',
|
|
269
|
+
targetTokens: limits.maxTokens,
|
|
270
|
+
targetMessages: limits.maxMessages,
|
|
271
|
+
enteredGrace: false,
|
|
272
|
+
};
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
// Check rolling threshold
|
|
276
|
+
const current = unit === 'messages' ? state.messagesSinceRoll : state.tokensSinceRoll;
|
|
277
|
+
|
|
278
|
+
if (current >= maxThreshold) {
|
|
279
|
+
// Exceeded grace, must roll
|
|
280
|
+
return {
|
|
281
|
+
shouldRoll: true,
|
|
282
|
+
reason: 'grace_exceeded',
|
|
283
|
+
targetTokens: unit === 'tokens' ? threshold : undefined,
|
|
284
|
+
targetMessages: unit === 'messages' ? threshold : undefined,
|
|
285
|
+
enteredGrace: false,
|
|
286
|
+
};
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
if (!state.inGracePeriod && current >= threshold) {
|
|
290
|
+
// Just entered grace period
|
|
291
|
+
return {
|
|
292
|
+
shouldRoll: false,
|
|
293
|
+
enteredGrace: true,
|
|
294
|
+
};
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
return {
|
|
298
|
+
shouldRoll: false,
|
|
299
|
+
enteredGrace: false,
|
|
300
|
+
};
|
|
301
|
+
}
|
|
302
|
+
|
|
303
|
+
/** A message paired with its token estimate and resolved stable id. */
interface MessageWithTokens {
  message: NormalizedMessage;
  // Estimated token count for this single message (from the token estimator).
  tokens: number;
  // Stable identifier resolved via getMessageId.
  id: string;
}
|
|
308
|
+
|
|
309
|
+
function truncateMessages(
|
|
310
|
+
messages: MessageWithTokens[],
|
|
311
|
+
targetTokens?: number,
|
|
312
|
+
targetMessages?: number,
|
|
313
|
+
config?: ContextConfig
|
|
314
|
+
): { kept: MessageWithTokens[]; dropped: number } {
|
|
315
|
+
// Truncate from the beginning, keeping most recent
|
|
316
|
+
|
|
317
|
+
if (targetMessages && messages.length > targetMessages) {
|
|
318
|
+
const startIdx = messages.length - targetMessages;
|
|
319
|
+
return {
|
|
320
|
+
kept: messages.slice(startIdx),
|
|
321
|
+
dropped: startIdx,
|
|
322
|
+
};
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
if (targetTokens) {
|
|
326
|
+
let tokenSum = 0;
|
|
327
|
+
let startIdx = messages.length;
|
|
328
|
+
|
|
329
|
+
// Count from end backwards
|
|
330
|
+
for (let i = messages.length - 1; i >= 0; i--) {
|
|
331
|
+
tokenSum += messages[i]!.tokens;
|
|
332
|
+
if (tokenSum > targetTokens) {
|
|
333
|
+
startIdx = i + 1;
|
|
334
|
+
break;
|
|
335
|
+
}
|
|
336
|
+
startIdx = i;
|
|
337
|
+
}
|
|
338
|
+
|
|
339
|
+
return {
|
|
340
|
+
kept: messages.slice(startIdx),
|
|
341
|
+
dropped: startIdx,
|
|
342
|
+
};
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
// Default: use buffer from config
|
|
346
|
+
const buffer = config?.rolling.buffer ?? 20;
|
|
347
|
+
const unit = config?.rolling.unit ?? 'messages';
|
|
348
|
+
|
|
349
|
+
if (unit === 'messages') {
|
|
350
|
+
const targetCount = Math.max(buffer * 2, messages.length - buffer);
|
|
351
|
+
if (messages.length > targetCount) {
|
|
352
|
+
const startIdx = messages.length - targetCount;
|
|
353
|
+
return {
|
|
354
|
+
kept: messages.slice(startIdx),
|
|
355
|
+
dropped: startIdx,
|
|
356
|
+
};
|
|
357
|
+
}
|
|
358
|
+
}
|
|
359
|
+
|
|
360
|
+
return { kept: messages, dropped: 0 };
|
|
361
|
+
}
|
|
362
|
+
|
|
363
|
+
/**
 * Choose where to place prompt-cache markers within the kept window.
 *
 * When the window did not roll and previously placed markers still point at
 * messages present in the window, those markers are reused (with token
 * estimates recomputed) so the cache prefix stays stable. Otherwise markers
 * are distributed arithmetically across the cacheable prefix (everything
 * except the last `buffer` messages), optionally snapped backwards to the
 * nearest user message, and filtered by the provider minimum token count.
 */
function placeCacheMarkers(
  messages: NormalizedMessage[],
  messageTokens: MessageWithTokens[],
  state: ContextState,
  didRoll: boolean,
  config: ContextConfig
): CacheMarker[] {
  const cacheConfig = config.cache ?? {};

  // Caching explicitly disabled: place nothing.
  if (cacheConfig.enabled === false) {
    return [];
  }

  const numPoints = cacheConfig.points ?? 1;
  // 1024 matches common provider minimums for a cacheable prefix —
  // NOTE(review): confirm against the configured provider's documentation.
  const minTokens = cacheConfig.minTokens ?? 1024;
  const preferUser = cacheConfig.preferUserMessages ?? true;

  const totalTokens = messageTokens.reduce((sum, m) => sum + m.tokens, 0);

  // Not enough tokens for caching
  if (totalTokens < minTokens) {
    return [];
  }

  // If we didn't roll, try to keep existing markers stable
  if (!didRoll && state.cacheMarkers.length > 0) {
    const currentIds = new Set(messages.map(m => getMessageId(m)));
    const validMarkers = state.cacheMarkers.filter(m => currentIds.has(m.messageId));

    if (validMarkers.length > 0) {
      // Recalculate token estimates for valid markers (indices may have
      // shifted even though the marked messages are still present).
      return validMarkers.map(marker => {
        const idx = messages.findIndex(m => getMessageId(m) === marker.messageId);
        const tokenEstimate = messageTokens
          .slice(0, idx + 1)
          .reduce((sum, m) => sum + m.tokens, 0);

        return {
          messageId: marker.messageId,
          messageIndex: idx,
          tokenEstimate,
        };
      });
    }
  }

  // Place new markers using arithmetic positioning
  const markers: CacheMarker[] = [];
  const buffer = config.rolling.buffer ?? 20;

  // For single point: place at (length - buffer)
  // For multiple points: distribute evenly in cacheable portion
  const cacheableEnd = Math.max(0, messages.length - buffer);

  if (cacheableEnd === 0) {
    return []; // Nothing to cache
  }

  // Calculate step size for multiple cache points
  const step = Math.floor(cacheableEnd / numPoints);

  if (step === 0) {
    return []; // Not enough messages for requested cache points
  }

  // Running prefix-token accumulator shared across points; currentIdx only
  // ever moves forward, so each message's tokens are added exactly once.
  let runningTokens = 0;
  let currentIdx = 0;

  for (let point = 1; point <= numPoints; point++) {
    const targetIdx = Math.min(point * step - 1, cacheableEnd - 1);

    // Accumulate tokens up to target
    while (currentIdx <= targetIdx && currentIdx < messageTokens.length) {
      runningTokens += messageTokens[currentIdx]!.tokens;
      currentIdx++;
    }

    let markerIdx = targetIdx;
    let markerTokens = runningTokens;

    // Adjust to user message if preferred (may move the marker backwards,
    // with a correspondingly smaller prefix-token estimate)
    if (preferUser) {
      const adjusted = findNearestUserMessage(messages, markerIdx, messageTokens);
      if (adjusted) {
        markerIdx = adjusted.index;
        markerTokens = adjusted.tokens;
      }
    }

    // Skip if below minimum
    if (markerTokens < minTokens) {
      continue;
    }

    // Skip if duplicate (two points can snap to the same user message)
    if (markers.some(m => m.messageIndex === markerIdx)) {
      continue;
    }

    markers.push({
      messageId: getMessageId(messages[markerIdx]!),
      messageIndex: markerIdx,
      tokenEstimate: markerTokens,
    });
  }

  return markers;
}
|
|
471
|
+
|
|
472
|
+
function findNearestUserMessage(
|
|
473
|
+
messages: NormalizedMessage[],
|
|
474
|
+
startIdx: number,
|
|
475
|
+
messageTokens: MessageWithTokens[]
|
|
476
|
+
): { index: number; tokens: number } | null {
|
|
477
|
+
// Search backwards for a user message (non-assistant participant)
|
|
478
|
+
const maxSearch = 5;
|
|
479
|
+
|
|
480
|
+
let tokens = messageTokens.slice(0, startIdx + 1).reduce((sum, m) => sum + m.tokens, 0);
|
|
481
|
+
|
|
482
|
+
for (let i = startIdx; i >= Math.max(0, startIdx - maxSearch); i--) {
|
|
483
|
+
const msg = messages[i]!;
|
|
484
|
+
// Heuristic: if participant isn't a common assistant name, it's probably a user
|
|
485
|
+
const participant = msg.participant.toLowerCase();
|
|
486
|
+
const isUser = !['claude', 'assistant', 'bot', 'ai'].includes(participant);
|
|
487
|
+
|
|
488
|
+
if (isUser) {
|
|
489
|
+
return { index: i, tokens };
|
|
490
|
+
}
|
|
491
|
+
|
|
492
|
+
tokens -= messageTokens[i]!.tokens;
|
|
493
|
+
}
|
|
494
|
+
|
|
495
|
+
return null;
|
|
496
|
+
}
|
|
497
|
+
|
|
498
|
+
function applyCacheMarkers(
|
|
499
|
+
messages: NormalizedMessage[],
|
|
500
|
+
cacheMarkers: CacheMarker[]
|
|
501
|
+
): NormalizedMessage[] {
|
|
502
|
+
if (cacheMarkers.length === 0) {
|
|
503
|
+
return messages;
|
|
504
|
+
}
|
|
505
|
+
|
|
506
|
+
const markerIndices = new Set(cacheMarkers.map(m => m.messageIndex));
|
|
507
|
+
|
|
508
|
+
return messages.map((msg, idx) => {
|
|
509
|
+
if (markerIndices.has(idx)) {
|
|
510
|
+
return {
|
|
511
|
+
...msg,
|
|
512
|
+
metadata: {
|
|
513
|
+
...msg.metadata,
|
|
514
|
+
cacheControl: { type: 'ephemeral' as const },
|
|
515
|
+
},
|
|
516
|
+
};
|
|
517
|
+
}
|
|
518
|
+
return msg;
|
|
519
|
+
});
|
|
520
|
+
}
|