recursive-llm-ts 4.8.0 → 5.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,859 @@
1
+ package rlm
2
+
3
+ import (
4
+ "encoding/json"
5
+ "fmt"
6
+ "math"
7
+ "net/http"
8
+ "net/http/httptest"
9
+ "strings"
10
+ "testing"
11
+ )
12
+
13
+ func useHeuristicTokenizerForTest(t *testing.T) {
14
+ t.Helper()
15
+ ResetDefaultTokenizer()
16
+ t.Cleanup(func() {
17
+ ResetDefaultTokenizer()
18
+ })
19
+ }
20
+ // ─── Token Tracking Unit Tests ──────────────────────────────────────────────
21
+
22
+ func TestTokenUsage_ParsedFromAPIResponse(t *testing.T) {
23
+ // Verify that CallChatCompletion correctly parses the usage field from API responses
24
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
25
+ resp := map[string]interface{}{
26
+ "choices": []map[string]interface{}{
27
+ {"message": map[string]string{"content": "Hello world"}},
28
+ },
29
+ "usage": map[string]interface{}{
30
+ "prompt_tokens": 150,
31
+ "completion_tokens": 25,
32
+ "total_tokens": 175,
33
+ },
34
+ }
35
+ _ = json.NewEncoder(w).Encode(resp)
36
+ }))
37
+ defer server.Close()
38
+
39
+ result, err := CallChatCompletion(ChatRequest{
40
+ Model: "test-model",
41
+ Messages: []Message{{Role: "user", Content: "test"}},
42
+ APIBase: server.URL,
43
+ })
44
+ if err != nil {
45
+ t.Fatalf("unexpected error: %v", err)
46
+ }
47
+
48
+ if result.Content != "Hello world" {
49
+ t.Errorf("expected content 'Hello world', got %q", result.Content)
50
+ }
51
+ if result.Usage == nil {
52
+ t.Fatal("expected usage to be non-nil")
53
+ }
54
+ if result.Usage.PromptTokens != 150 {
55
+ t.Errorf("expected 150 prompt tokens, got %d", result.Usage.PromptTokens)
56
+ }
57
+ if result.Usage.CompletionTokens != 25 {
58
+ t.Errorf("expected 25 completion tokens, got %d", result.Usage.CompletionTokens)
59
+ }
60
+ if result.Usage.TotalTokens != 175 {
61
+ t.Errorf("expected 175 total tokens, got %d", result.Usage.TotalTokens)
62
+ }
63
+ }
64
+
65
+ func TestTokenUsage_NilWhenAPIDoesNotReturnUsage(t *testing.T) {
66
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
67
+ resp := map[string]interface{}{
68
+ "choices": []map[string]interface{}{
69
+ {"message": map[string]string{"content": "Hello"}},
70
+ },
71
+ }
72
+ _ = json.NewEncoder(w).Encode(resp)
73
+ }))
74
+ defer server.Close()
75
+
76
+ result, err := CallChatCompletion(ChatRequest{
77
+ Model: "test-model",
78
+ Messages: []Message{{Role: "user", Content: "test"}},
79
+ APIBase: server.URL,
80
+ })
81
+ if err != nil {
82
+ t.Fatalf("unexpected error: %v", err)
83
+ }
84
+
85
+ if result.Usage != nil {
86
+ t.Errorf("expected usage to be nil when API doesn't return it, got %+v", result.Usage)
87
+ }
88
+ }
89
+
90
+ func TestRLMStats_TokenAccumulation(t *testing.T) {
91
+ // Test that token usage accumulates across multiple LLM calls
92
+ callCount := 0
93
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
94
+ callCount++
95
+ resp := map[string]interface{}{
96
+ "choices": []map[string]interface{}{
97
+ {"message": map[string]string{"content": fmt.Sprintf(`FINAL("answer from call %d")`, callCount)}},
98
+ },
99
+ "usage": map[string]interface{}{
100
+ "prompt_tokens": 100 * callCount,
101
+ "completion_tokens": 20 * callCount,
102
+ "total_tokens": 120 * callCount,
103
+ },
104
+ }
105
+ _ = json.NewEncoder(w).Encode(resp)
106
+ }))
107
+ defer server.Close()
108
+
109
+ engine := New("test-model", Config{
110
+ APIBase: server.URL,
111
+ MaxDepth: 5,
112
+ MaxIterations: 10,
113
+ })
114
+
115
+ _, stats, err := engine.Completion("test query", "test context")
116
+ if err != nil {
117
+ t.Fatalf("unexpected error: %v", err)
118
+ }
119
+
120
+ // First call should have returned FINAL, so 1 LLM call
121
+ if stats.LlmCalls != 1 {
122
+ t.Errorf("expected 1 LLM call, got %d", stats.LlmCalls)
123
+ }
124
+ if stats.TotalTokens != 120 {
125
+ t.Errorf("expected 120 total tokens, got %d", stats.TotalTokens)
126
+ }
127
+ if stats.PromptTokens != 100 {
128
+ t.Errorf("expected 100 prompt tokens, got %d", stats.PromptTokens)
129
+ }
130
+ if stats.CompletionTokens != 20 {
131
+ t.Errorf("expected 20 completion tokens, got %d", stats.CompletionTokens)
132
+ }
133
+ }
134
+
135
+ func TestRLMStats_TokenAccumulation_MultipleIterations(t *testing.T) {
136
+ // Simulates an RLM completion that takes 3 iterations before producing FINAL
137
+ callCount := 0
138
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
139
+ callCount++
140
+ content := "context.indexOf('test')"
141
+ if callCount >= 3 {
142
+ content = `FINAL("done after 3 calls")`
143
+ }
144
+ resp := map[string]interface{}{
145
+ "choices": []map[string]interface{}{
146
+ {"message": map[string]string{"content": content}},
147
+ },
148
+ "usage": map[string]interface{}{
149
+ "prompt_tokens": 200,
150
+ "completion_tokens": 50,
151
+ "total_tokens": 250,
152
+ },
153
+ }
154
+ _ = json.NewEncoder(w).Encode(resp)
155
+ }))
156
+ defer server.Close()
157
+
158
+ engine := New("test-model", Config{
159
+ APIBase: server.URL,
160
+ MaxDepth: 5,
161
+ MaxIterations: 10,
162
+ })
163
+
164
+ _, stats, err := engine.Completion("test query", "test context for searching")
165
+ if err != nil {
166
+ t.Fatalf("unexpected error: %v", err)
167
+ }
168
+
169
+ if stats.LlmCalls != 3 {
170
+ t.Errorf("expected 3 LLM calls, got %d", stats.LlmCalls)
171
+ }
172
+ // 3 calls * 250 tokens each = 750 total
173
+ if stats.TotalTokens != 750 {
174
+ t.Errorf("expected 750 total tokens (3 calls * 250), got %d", stats.TotalTokens)
175
+ }
176
+ if stats.PromptTokens != 600 {
177
+ t.Errorf("expected 600 prompt tokens (3 calls * 200), got %d", stats.PromptTokens)
178
+ }
179
+ if stats.CompletionTokens != 150 {
180
+ t.Errorf("expected 150 completion tokens (3 calls * 50), got %d", stats.CompletionTokens)
181
+ }
182
+ }
183
+
184
+ func TestRLMStats_TokensInJSONOutput(t *testing.T) {
185
+ // Verify token fields are serialized in the JSON output
186
+ stats := RLMStats{
187
+ LlmCalls: 3,
188
+ Iterations: 2,
189
+ Depth: 0,
190
+ TotalTokens: 750,
191
+ PromptTokens: 600,
192
+ CompletionTokens: 150,
193
+ }
194
+
195
+ data, err := json.Marshal(stats)
196
+ if err != nil {
197
+ t.Fatalf("failed to marshal stats: %v", err)
198
+ }
199
+
200
+ var parsed map[string]interface{}
201
+ if err := json.Unmarshal(data, &parsed); err != nil {
202
+ t.Fatalf("failed to unmarshal stats: %v", err)
203
+ }
204
+
205
+ if v, ok := parsed["total_tokens"].(float64); !ok || int(v) != 750 {
206
+ t.Errorf("expected total_tokens=750 in JSON, got %v", parsed["total_tokens"])
207
+ }
208
+ if v, ok := parsed["prompt_tokens"].(float64); !ok || int(v) != 600 {
209
+ t.Errorf("expected prompt_tokens=600 in JSON, got %v", parsed["prompt_tokens"])
210
+ }
211
+ if v, ok := parsed["completion_tokens"].(float64); !ok || int(v) != 150 {
212
+ t.Errorf("expected completion_tokens=150 in JSON, got %v", parsed["completion_tokens"])
213
+ }
214
+ }
215
+
216
+ func TestRLMStats_ZeroTokensOmittedFromJSON(t *testing.T) {
217
+ // When no tokens are tracked, fields should be omitted (omitempty)
218
+ stats := RLMStats{
219
+ LlmCalls: 1,
220
+ Iterations: 1,
221
+ Depth: 0,
222
+ }
223
+
224
+ data, err := json.Marshal(stats)
225
+ if err != nil {
226
+ t.Fatalf("failed to marshal stats: %v", err)
227
+ }
228
+
229
+ jsonStr := string(data)
230
+ if strings.Contains(jsonStr, "total_tokens") {
231
+ t.Errorf("expected total_tokens to be omitted when zero, got: %s", jsonStr)
232
+ }
233
+ if strings.Contains(jsonStr, "prompt_tokens") {
234
+ t.Errorf("expected prompt_tokens to be omitted when zero, got: %s", jsonStr)
235
+ }
236
+ if strings.Contains(jsonStr, "completion_tokens") {
237
+ t.Errorf("expected completion_tokens to be omitted when zero, got: %s", jsonStr)
238
+ }
239
+ }
240
+
241
+ func TestFormatStatsWithObservability_IncludesTokens(t *testing.T) {
242
+ stats := RLMStats{
243
+ LlmCalls: 2,
244
+ Iterations: 1,
245
+ Depth: 0,
246
+ TotalTokens: 500,
247
+ PromptTokens: 400,
248
+ CompletionTokens: 100,
249
+ }
250
+
251
+ obs := NewNoopObserver()
252
+ formatted := FormatStatsWithObservability(stats, obs)
253
+
254
+ if v, ok := formatted["total_tokens"].(int); !ok || v != 500 {
255
+ t.Errorf("expected total_tokens=500, got %v", formatted["total_tokens"])
256
+ }
257
+ if v, ok := formatted["prompt_tokens"].(int); !ok || v != 400 {
258
+ t.Errorf("expected prompt_tokens=400, got %v", formatted["prompt_tokens"])
259
+ }
260
+ if v, ok := formatted["completion_tokens"].(int); !ok || v != 100 {
261
+ t.Errorf("expected completion_tokens=100, got %v", formatted["completion_tokens"])
262
+ }
263
+ }
264
+
265
+ func TestFormatStatsWithObservability_OmitsZeroTokens(t *testing.T) {
266
+ stats := RLMStats{
267
+ LlmCalls: 1,
268
+ Iterations: 1,
269
+ Depth: 0,
270
+ }
271
+
272
+ obs := NewNoopObserver()
273
+ formatted := FormatStatsWithObservability(stats, obs)
274
+
275
+ if _, exists := formatted["total_tokens"]; exists {
276
+ t.Errorf("expected total_tokens to be absent when zero, got %v", formatted["total_tokens"])
277
+ }
278
+ }
279
+
280
+ // ─── Token Efficiency Tests ─────────────────────────────────────────────────
281
+ //
282
+ // These tests prove that RLM context reduction strategies process fewer tokens
283
+ // than passing an entire large document through as raw context.
284
+
285
// generateLargeContext creates a realistic document of approximately
// targetTokens tokens. It writes a fixed report header followed by numbered
// sections that cycle through five templates (database, API gateway,
// memory/CPU, errors, cost). The figures in each section are derived only from
// the section number, so repeated calls return identical text, and the
// numbered paragraphs make it easy to verify that reduction strategies preserve
// key information.
//
// Size is controlled in characters: sections are appended until the output
// reaches targetTokens*3.5 characters. The result can therefore overshoot by up
// to one section. Any target smaller than the header, including
// targetTokens <= 0, yields only the header.
func generateLargeContext(targetTokens int) string {
	// ~3.5 chars per token is our estimation ratio.
	const charsPerToken = 3.5
	targetChars := int(float64(targetTokens) * charsPerToken)

	var sb strings.Builder
	// Pre-size once instead of growing repeatedly; the extra 1 KiB roughly
	// covers the final section's overshoot. Guarded because Grow panics on a
	// negative size.
	if targetChars > 0 {
		sb.Grow(targetChars + 1024)
	}

	sb.WriteString("# Technical Report: System Performance Analysis\n\n")
	sb.WriteString("## Executive Summary\n\n")
	sb.WriteString("This comprehensive report analyzes the performance characteristics of the distributed system ")
	sb.WriteString("deployed across three data centers. Key findings include a 15% improvement in latency, ")
	sb.WriteString("23% reduction in error rates, and significant cost savings through resource optimization.\n\n")

	paragraphNum := 1
	for sb.Len() < targetChars {
		// Generate diverse paragraph types to simulate realistic documents.
		switch paragraphNum % 5 {
		case 0:
			fmt.Fprintf(&sb, "### Section %d: Database Performance Metrics\n\n", paragraphNum)
			fmt.Fprintf(&sb, "In quarter Q%d, the primary database cluster processed an average of %d,000 queries per second "+
				"with a p99 latency of %d.%d milliseconds. The read-to-write ratio was approximately %d:%d. "+
				"Connection pool utilization peaked at %d%% during high-traffic periods, with %d active connections "+
				"out of a configured maximum of %d. Index hit ratios remained above %d%% for all primary tables, "+
				"though the secondary indexes on the analytics tables showed degradation to %d%% during batch "+
				"processing windows. This resulted in an overall throughput improvement of %d.%d%% compared to "+
				"the previous quarter's baseline measurements.\n\n",
				paragraphNum%4+1, paragraphNum*12+50, paragraphNum%10+1, paragraphNum%99,
				paragraphNum%7+3, 1, paragraphNum%30+70, paragraphNum*3+100, paragraphNum*5+200,
				paragraphNum%5+95, paragraphNum%20+75, paragraphNum%15+5, paragraphNum%99)
		case 1:
			fmt.Fprintf(&sb, "### Section %d: API Gateway Statistics\n\n", paragraphNum)
			fmt.Fprintf(&sb, "The API gateway handled %d.%dM requests during the reporting period. Rate limiting "+
				"was triggered %d times for %d unique clients. The top 5 endpoints by traffic volume were: "+
				"/api/v2/users (%d.%d%%), /api/v2/products (%d.%d%%), /api/v2/orders (%d.%d%%), "+
				"/api/v2/analytics (%d.%d%%), and /api/v2/search (%d.%d%%). Authentication failures "+
				"decreased from %d to %d per day after implementing the new token refresh mechanism. "+
				"The overall API availability was %d.%d%% with %d minutes of total downtime.\n\n",
				paragraphNum*5+10, paragraphNum%99, paragraphNum*7+20, paragraphNum*3+5,
				paragraphNum%20+20, paragraphNum%99, paragraphNum%15+15, paragraphNum%99,
				paragraphNum%10+10, paragraphNum%99, paragraphNum%8+5, paragraphNum%99,
				paragraphNum%5+3, paragraphNum%99, paragraphNum*2+50, paragraphNum+10,
				99, paragraphNum%10+90, paragraphNum%30+5)
		case 2:
			fmt.Fprintf(&sb, "### Section %d: Memory and CPU Utilization\n\n", paragraphNum)
			fmt.Fprintf(&sb, "Across all %d nodes in the cluster, average memory utilization was %d.%d%%. "+
				"Node %d consistently showed the highest memory consumption at %d.%d%%, primarily due to "+
				"in-memory caching of frequently accessed data structures. CPU utilization averaged %d.%d%% "+
				"with peaks reaching %d.%d%% during the daily ETL batch processing window between "+
				"%d:00 and %d:00 UTC. Garbage collection pauses were reduced from an average of %dms to %dms "+
				"after tuning the JVM parameters. Thread pool saturation events decreased from %d per hour "+
				"to %d per hour following the implementation of adaptive thread pool sizing.\n\n",
				paragraphNum*2+20, paragraphNum%40+50, paragraphNum%99, paragraphNum%20+1,
				paragraphNum%15+80, paragraphNum%99, paragraphNum%30+40, paragraphNum%99,
				paragraphNum%20+75, paragraphNum%99, paragraphNum%6+2, paragraphNum%6+4,
				paragraphNum%50+100, paragraphNum%30+20, paragraphNum%10+5, paragraphNum%5+1)
		case 3:
			fmt.Fprintf(&sb, "### Section %d: Error Analysis and Incident Report\n\n", paragraphNum)
			fmt.Fprintf(&sb, "During the period, %d unique error types were observed across the system. "+
				"The most frequent error (ERR-%04d) was a transient connection timeout to the Redis cluster, "+
				"occurring %d times with a mean time to recovery of %d.%d seconds. Error category breakdown: "+
				"network errors (%d%%), application errors (%d%%), database errors (%d%%), "+
				"authentication errors (%d%%), and other (%d%%). The total error budget consumed was %d.%d%% "+
				"of the allocated %d.%d%% for the quarter. Two P2 incidents were recorded on days %d and %d, "+
				"with root causes traced to upstream provider instability and a misconfigured load balancer "+
				"health check interval respectively.\n\n",
				paragraphNum*3+15, paragraphNum+1000, paragraphNum*50+200, paragraphNum%10+1, paragraphNum%99,
				paragraphNum%30+30, paragraphNum%25+20, paragraphNum%20+15, paragraphNum%10+5,
				paragraphNum%10+5, paragraphNum%3, paragraphNum%99, paragraphNum%5, paragraphNum%99,
				paragraphNum%28+1, paragraphNum%28+15)
		case 4:
			fmt.Fprintf(&sb, "### Section %d: Cost Optimization Results\n\n", paragraphNum)
			fmt.Fprintf(&sb, "Infrastructure costs for the period totaled $%d,%03d.%02d, representing a "+
				"%d.%d%% decrease from the previous quarter. Key savings were achieved through: "+
				"reserved instance utilization (saving $%d,%03d), right-sizing %d underutilized instances "+
				"(saving $%d,%03d), implementing spot instances for batch workloads (saving $%d,%03d), "+
				"and optimizing data transfer routes (saving $%d,%03d). The cost per million API requests "+
				"decreased from $%d.%02d to $%d.%02d. Projected annual savings based on current trends: "+
				"$%d,%03d. Storage costs increased by %d.%d%% due to expanded logging retention requirements.\n\n",
				paragraphNum*100+500, paragraphNum%1000, paragraphNum%100, paragraphNum%15+5, paragraphNum%99,
				paragraphNum*20+100, paragraphNum%1000, paragraphNum*3+10, paragraphNum*10+50, paragraphNum%1000,
				paragraphNum*8+30, paragraphNum%1000, paragraphNum*5+20, paragraphNum%1000,
				paragraphNum%50+10, paragraphNum%100, paragraphNum%40+5, paragraphNum%100,
				paragraphNum*300+1000, paragraphNum%1000, paragraphNum%10+2, paragraphNum%99)
		}
		paragraphNum++
	}

	return sb.String()
}
375
+
376
+ func TestTokenEfficiency_TFIDFUsesFewerTokens(t *testing.T) {
377
+ useHeuristicTokenizerForTest(t)
378
+
379
+ // Generate a large context (~35,000 tokens, well over 32k)
380
+ largeContext := generateLargeContext(35000)
381
+ originalTokens := EstimateTokens(largeContext)
382
+ if originalTokens < 32000 {
383
+ t.Fatalf("generated context is too small: %d tokens, need at least 32000", originalTokens)
384
+ }
385
+ t.Logf("Original context: %d chars, ~%d estimated tokens", len(largeContext), originalTokens)
386
+
387
+ // Apply TF-IDF compression to fit within a 32k token budget
388
+ modelLimit := 32768
389
+ overhead := 1000 // System prompt + query overhead
390
+ availableTokens := modelLimit - overhead
391
+
392
+ compressed := CompressContextTFIDF(largeContext, availableTokens)
393
+ compressedTokens := EstimateTokens(compressed)
394
+
395
+ t.Logf("TF-IDF compressed: %d chars, ~%d estimated tokens", len(compressed), compressedTokens)
396
+ t.Logf("Token reduction: %d -> %d (%.1f%% reduction)",
397
+ originalTokens, compressedTokens,
398
+ (1.0-float64(compressedTokens)/float64(originalTokens))*100)
399
+
400
+ // Core assertion: TF-IDF MUST produce fewer tokens than the original
401
+ if compressedTokens >= originalTokens {
402
+ t.Errorf("TF-IDF failed to reduce tokens: original=%d, compressed=%d", originalTokens, compressedTokens)
403
+ }
404
+
405
+ // And it must fit within our budget
406
+ if compressedTokens > availableTokens {
407
+ t.Errorf("TF-IDF output exceeds budget: %d tokens > %d available", compressedTokens, availableTokens)
408
+ }
409
+
410
+ // Verify meaningful compression (at least 5% reduction for a context that's over budget)
411
+ reductionPct := (1.0 - float64(compressedTokens)/float64(originalTokens)) * 100
412
+ if reductionPct < 5.0 {
413
+ t.Errorf("TF-IDF compression too weak: only %.1f%% reduction", reductionPct)
414
+ }
415
+ }
416
+
417
+ func TestTokenEfficiency_TextRankUsesFewerTokens(t *testing.T) {
418
+ useHeuristicTokenizerForTest(t)
419
+
420
+ largeContext := generateLargeContext(35000)
421
+ originalTokens := EstimateTokens(largeContext)
422
+ if originalTokens < 32000 {
423
+ t.Fatalf("generated context is too small: %d tokens, need at least 32000", originalTokens)
424
+ }
425
+ t.Logf("Original context: %d chars, ~%d estimated tokens", len(largeContext), originalTokens)
426
+
427
+ modelLimit := 32768
428
+ overhead := 1000
429
+ availableTokens := modelLimit - overhead
430
+
431
+ compressed := CompressContextTextRank(largeContext, availableTokens)
432
+ compressedTokens := EstimateTokens(compressed)
433
+
434
+ t.Logf("TextRank compressed: %d chars, ~%d estimated tokens", len(compressed), compressedTokens)
435
+ t.Logf("Token reduction: %d -> %d (%.1f%% reduction)",
436
+ originalTokens, compressedTokens,
437
+ (1.0-float64(compressedTokens)/float64(originalTokens))*100)
438
+
439
+ if compressedTokens >= originalTokens {
440
+ t.Errorf("TextRank failed to reduce tokens: original=%d, compressed=%d", originalTokens, compressedTokens)
441
+ }
442
+
443
+ if compressedTokens > availableTokens {
444
+ t.Errorf("TextRank output exceeds budget: %d tokens > %d available", compressedTokens, availableTokens)
445
+ }
446
+
447
+ reductionPct := (1.0 - float64(compressedTokens)/float64(originalTokens)) * 100
448
+ if reductionPct < 5.0 {
449
+ t.Errorf("TextRank compression too weak: only %.1f%% reduction", reductionPct)
450
+ }
451
+ }
452
+
453
+ func TestTokenEfficiency_TruncateUsesFewerTokens(t *testing.T) {
454
+ useHeuristicTokenizerForTest(t)
455
+
456
+ largeContext := generateLargeContext(35000)
457
+ originalTokens := EstimateTokens(largeContext)
458
+
459
+ if originalTokens < 32000 {
460
+ t.Fatalf("generated context is too small: %d tokens, need at least 32000", originalTokens)
461
+ }
462
+
463
+ modelLimit := 32768
464
+ overhead := 1000
465
+
466
+ // Create a reducer with truncation strategy
467
+ engine := New("test-model", Config{
468
+ MaxDepth: 5,
469
+ MaxIterations: 10,
470
+ ContextOverflow: &ContextOverflowConfig{
471
+ Enabled: true,
472
+ Strategy: "truncate",
473
+ SafetyMargin: 0.15,
474
+ },
475
+ })
476
+
477
+ reducer := newContextReducer(engine, *engine.contextOverflow, NewNoopObserver())
478
+ truncated, err := reducer.reduceByTruncation(largeContext, modelLimit, overhead)
479
+ if err != nil {
480
+ t.Fatalf("truncation failed: %v", err)
481
+ }
482
+
483
+ truncatedTokens := EstimateTokens(truncated)
484
+
485
+ t.Logf("Truncate: %d -> %d estimated tokens (%.1f%% reduction)",
486
+ originalTokens, truncatedTokens,
487
+ (1.0-float64(truncatedTokens)/float64(originalTokens))*100)
488
+
489
+ if truncatedTokens >= originalTokens {
490
+ t.Errorf("truncation failed to reduce tokens: original=%d, truncated=%d", originalTokens, truncatedTokens)
491
+ }
492
+ }
493
+
494
+ func TestTokenEfficiency_ChunkingProducesSmallChunks(t *testing.T) {
495
+ largeContext := generateLargeContext(35000)
496
+ originalTokens := EstimateTokens(largeContext)
497
+
498
+ if originalTokens < 32000 {
499
+ t.Fatalf("generated context is too small: %d tokens, need at least 32000", originalTokens)
500
+ }
501
+
502
+ // Chunk with a 8k token budget per chunk
503
+ chunkBudget := 8000
504
+ chunks := ChunkContext(largeContext, chunkBudget)
505
+
506
+ t.Logf("Chunked %d tokens into %d chunks (budget: %d tokens/chunk)", originalTokens, len(chunks), chunkBudget)
507
+
508
+ if len(chunks) < 2 {
509
+ t.Errorf("expected multiple chunks for %d token context, got %d", originalTokens, len(chunks))
510
+ }
511
+
512
+ // Each chunk must be smaller than the original
513
+ for i, chunk := range chunks {
514
+ chunkTokens := EstimateTokens(chunk)
515
+ if chunkTokens >= originalTokens {
516
+ t.Errorf("chunk %d is not smaller than original: %d tokens >= %d", i, chunkTokens, originalTokens)
517
+ }
518
+ t.Logf(" Chunk %d: %d estimated tokens", i, chunkTokens)
519
+ }
520
+ }
521
+
522
+ func TestTokenEfficiency_PreemptiveReduction(t *testing.T) {
523
+ // Test that PreemptiveReduceContext actually reduces a large context
524
+ largeContext := generateLargeContext(35000)
525
+ originalTokens := EstimateTokens(largeContext)
526
+
527
+ engine := New("gpt-4o-mini", Config{
528
+ MaxDepth: 5,
529
+ MaxIterations: 10,
530
+ ContextOverflow: &ContextOverflowConfig{
531
+ Enabled: true,
532
+ Strategy: "tfidf",
533
+ SafetyMargin: 0.15,
534
+ },
535
+ })
536
+
537
+ reduced, wasReduced, err := engine.PreemptiveReduceContext("Summarize the key findings", largeContext, 0)
538
+ if err != nil {
539
+ t.Fatalf("preemptive reduction failed: %v", err)
540
+ }
541
+
542
+ // gpt-4o-mini has 128k limit, so 35k should NOT trigger reduction
543
+ if wasReduced {
544
+ t.Logf("context was unexpectedly reduced for 35k input with 128k model limit")
545
+ } else {
546
+ t.Logf("correctly skipped reduction: 35k tokens fits within gpt-4o-mini's 128k limit")
547
+ }
548
+
549
+ // Force a smaller model limit to ensure reduction triggers
550
+ engine2 := New("gpt-4", Config{
551
+ MaxDepth: 5,
552
+ MaxIterations: 10,
553
+ ContextOverflow: &ContextOverflowConfig{
554
+ Enabled: true,
555
+ Strategy: "tfidf",
556
+ SafetyMargin: 0.15,
557
+ MaxModelTokens: 16000, // Force small limit
558
+ },
559
+ })
560
+
561
+ reduced2, wasReduced2, err := engine2.PreemptiveReduceContext("Summarize the key findings", largeContext, 0)
562
+ if err != nil {
563
+ t.Fatalf("preemptive reduction failed: %v", err)
564
+ }
565
+
566
+ if !wasReduced2 {
567
+ t.Error("expected context to be reduced when model limit is 16k and context is 35k tokens")
568
+ }
569
+
570
+ reducedTokens := EstimateTokens(reduced2)
571
+ t.Logf("Preemptive TF-IDF: %d -> %d estimated tokens (%.1f%% reduction)",
572
+ originalTokens, reducedTokens,
573
+ (1.0-float64(reducedTokens)/float64(originalTokens))*100)
574
+
575
+ if reducedTokens >= originalTokens {
576
+ t.Errorf("preemptive reduction failed: original=%d, reduced=%d", originalTokens, reducedTokens)
577
+ }
578
+
579
+ _ = reduced // used above
580
+ }
581
+
582
+ func TestTokenEfficiency_AllStrategiesCompared(t *testing.T) {
583
+ useHeuristicTokenizerForTest(t)
584
+
585
+ // Generate a 40k token context (well over 32k limit)
586
+ largeContext := generateLargeContext(40000)
587
+ originalTokens := EstimateTokens(largeContext)
588
+ if originalTokens < 35000 {
589
+ t.Fatalf("generated context is too small: %d tokens, need at least 35000", originalTokens)
590
+ }
591
+
592
+ modelLimit := 32768
593
+ overhead := 1000
594
+
595
+ t.Logf("Original context: %d chars, ~%d estimated tokens", len(largeContext), originalTokens)
596
+ t.Logf("Model limit: %d tokens, overhead: %d, available: %d", modelLimit, overhead, modelLimit-overhead)
597
+
598
+ // Track results for each strategy
599
+ type strategyResult struct {
600
+ name string
601
+ reducedTokens int
602
+ reductionPct float64
603
+ requiresLLM bool
604
+ }
605
+ var results []strategyResult
606
+
607
+ availableTokens := modelLimit - overhead
608
+
609
+ // TF-IDF (pure algorithmic)
610
+ tfidfResult := CompressContextTFIDF(largeContext, availableTokens)
611
+ tfidfTokens := EstimateTokens(tfidfResult)
612
+ results = append(results, strategyResult{
613
+ name: "tfidf",
614
+ reducedTokens: tfidfTokens,
615
+ reductionPct: (1.0 - float64(tfidfTokens)/float64(originalTokens)) * 100,
616
+ requiresLLM: false,
617
+ })
618
+
619
+ // TextRank (pure algorithmic)
620
+ textRankResult := CompressContextTextRank(largeContext, availableTokens)
621
+ textRankTokens := EstimateTokens(textRankResult)
622
+ results = append(results, strategyResult{
623
+ name: "textrank",
624
+ reducedTokens: textRankTokens,
625
+ reductionPct: (1.0 - float64(textRankTokens)/float64(originalTokens)) * 100,
626
+ requiresLLM: false,
627
+ })
628
+
629
+ // Truncation
630
+ engine := New("test-model", Config{
631
+ MaxDepth: 5,
632
+ MaxIterations: 10,
633
+ ContextOverflow: &ContextOverflowConfig{
634
+ Enabled: true,
635
+ Strategy: "truncate",
636
+ SafetyMargin: 0.15,
637
+ },
638
+ })
639
+ reducer := newContextReducer(engine, *engine.contextOverflow, NewNoopObserver())
640
+ truncResult, _ := reducer.reduceByTruncation(largeContext, modelLimit, overhead)
641
+ truncTokens := EstimateTokens(truncResult)
642
+ results = append(results, strategyResult{
643
+ name: "truncate",
644
+ reducedTokens: truncTokens,
645
+ reductionPct: (1.0 - float64(truncTokens)/float64(originalTokens)) * 100,
646
+ requiresLLM: false,
647
+ })
648
+
649
+ // Print comparison table
650
+ t.Logf("\n--- Token Efficiency Comparison ---")
651
+ t.Logf("%-12s | %12s | %10s | %s", "Strategy", "Tokens Used", "Reduction", "Requires LLM")
652
+ t.Logf("%-12s | %12s | %10s | %s", "------------", "------------", "----------", "------------")
653
+ t.Logf("%-12s | %12d | %9s | %s", "raw (none)", originalTokens, "0.0%", "no")
654
+ for _, r := range results {
655
+ llmStr := "no"
656
+ if r.requiresLLM {
657
+ llmStr = "yes"
658
+ }
659
+ t.Logf("%-12s | %12d | %9.1f%% | %s", r.name, r.reducedTokens, r.reductionPct, llmStr)
660
+ }
661
+
662
+ // Assert ALL strategies use fewer tokens than raw
663
+ for _, r := range results {
664
+ if r.reducedTokens >= originalTokens {
665
+ t.Errorf("strategy %q failed: %d tokens >= original %d tokens", r.name, r.reducedTokens, originalTokens)
666
+ }
667
+ }
668
+
669
+ // Assert all strategies fit within the model limit
670
+ for _, r := range results {
671
+ if r.reducedTokens > availableTokens {
672
+ t.Errorf("strategy %q exceeds budget: %d tokens > %d available", r.name, r.reducedTokens, availableTokens)
673
+ }
674
+ }
675
+ }
676
+
677
+ func TestTokenEfficiency_VeryLargeContext_100kTokens(t *testing.T) {
678
+ useHeuristicTokenizerForTest(t)
679
+
680
+ // Test with a very large context (~100k tokens) to prove scaling
681
+ largeContext := generateLargeContext(100000)
682
+ originalTokens := EstimateTokens(largeContext)
683
+ if originalTokens < 90000 {
684
+ t.Fatalf("generated context is too small: %d tokens, need at least 90000", originalTokens)
685
+ }
686
+
687
+ modelLimit := 32768
688
+ overhead := 1000
689
+ availableTokens := modelLimit - overhead
690
+
691
+ t.Logf("Original: ~%d estimated tokens (3x over 32k limit)", originalTokens)
692
+
693
+ // TF-IDF
694
+ tfidfResult := CompressContextTFIDF(largeContext, availableTokens)
695
+ tfidfTokens := EstimateTokens(tfidfResult)
696
+
697
+ // TextRank
698
+ textRankResult := CompressContextTextRank(largeContext, availableTokens)
699
+ textRankTokens := EstimateTokens(textRankResult)
700
+
701
+ t.Logf("TF-IDF: %d tokens (%.1f%% reduction)", tfidfTokens, (1.0-float64(tfidfTokens)/float64(originalTokens))*100)
702
+ t.Logf("TextRank: %d tokens (%.1f%% reduction)", textRankTokens, (1.0-float64(textRankTokens)/float64(originalTokens))*100)
703
+
704
+ // Both must be significantly smaller
705
+ if tfidfTokens >= originalTokens/2 {
706
+ t.Errorf("TF-IDF should reduce 100k context by at least 50%%: got %d tokens", tfidfTokens)
707
+ }
708
+ if textRankTokens >= originalTokens/2 {
709
+ t.Errorf("TextRank should reduce 100k context by at least 50%%: got %d tokens", textRankTokens)
710
+ }
711
+
712
+ // Both must fit within budget
713
+ if tfidfTokens > availableTokens {
714
+ t.Errorf("TF-IDF exceeds budget: %d > %d", tfidfTokens, availableTokens)
715
+ }
716
+ if textRankTokens > availableTokens {
717
+ t.Errorf("TextRank exceeds budget: %d > %d", textRankTokens, availableTokens)
718
+ }
719
+ }
720
+
721
+ func TestTokenEfficiency_MapReduceTracksTokens(t *testing.T) {
722
+ // Test that mapreduce strategy properly accumulates token usage from multiple chunks
723
+ callCount := 0
724
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
725
+ callCount++
726
+ // Simulate summarization - return a short summary for each chunk
727
+ resp := map[string]interface{}{
728
+ "choices": []map[string]interface{}{
729
+ {"message": map[string]string{"content": fmt.Sprintf("Summary of chunk %d: key finding was performance improvement.", callCount)}},
730
+ },
731
+ "usage": map[string]interface{}{
732
+ "prompt_tokens": 500 + callCount*50,
733
+ "completion_tokens": 30,
734
+ "total_tokens": 530 + callCount*50,
735
+ },
736
+ }
737
+ _ = json.NewEncoder(w).Encode(resp)
738
+ }))
739
+ defer server.Close()
740
+
741
+ engine := New("test-model", Config{
742
+ APIBase: server.URL,
743
+ MaxDepth: 5,
744
+ MaxIterations: 10,
745
+ ContextOverflow: &ContextOverflowConfig{
746
+ Enabled: true,
747
+ Strategy: "mapreduce",
748
+ SafetyMargin: 0.15,
749
+ },
750
+ })
751
+
752
+ // Create a large context that will be split into multiple chunks
753
+ largeContext := generateLargeContext(40000)
754
+ query := "Summarize the key findings"
755
+
756
+ reducer := newContextReducer(engine, *engine.contextOverflow, NewNoopObserver())
757
+ reduced, err := reducer.ReduceForCompletion(query, largeContext, 16000)
758
+ if err != nil {
759
+ t.Fatalf("mapreduce reduction failed: %v", err)
760
+ }
761
+
762
+ // Verify that token usage was accumulated
763
+ if engine.stats.TotalTokens == 0 {
764
+ t.Error("expected total_tokens > 0 after mapreduce reduction, got 0")
765
+ }
766
+ if engine.stats.PromptTokens == 0 {
767
+ t.Error("expected prompt_tokens > 0 after mapreduce reduction, got 0")
768
+ }
769
+ if engine.stats.CompletionTokens == 0 {
770
+ t.Error("expected completion_tokens > 0 after mapreduce reduction, got 0")
771
+ }
772
+
773
+ t.Logf("MapReduce token tracking: %d total tokens (%d prompt, %d completion) across %d LLM calls",
774
+ engine.stats.TotalTokens, engine.stats.PromptTokens, engine.stats.CompletionTokens, engine.stats.LlmCalls)
775
+ t.Logf("Reduced context: %d chars", len(reduced))
776
+
777
+ // The reduced context should be much smaller than the original
778
+ if len(reduced) >= len(largeContext) {
779
+ t.Errorf("mapreduce failed to reduce context: %d chars >= original %d chars", len(reduced), len(largeContext))
780
+ }
781
+ }
782
+
783
+ func TestTokenEfficiency_StructuredCompletion_TracksTokens(t *testing.T) {
784
+ // Verify structured completion accumulates tokens
785
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
786
+ resp := map[string]interface{}{
787
+ "choices": []map[string]interface{}{
788
+ {"message": map[string]string{"content": `{"summary": "test result", "score": 8}`}},
789
+ },
790
+ "usage": map[string]interface{}{
791
+ "prompt_tokens": 300,
792
+ "completion_tokens": 15,
793
+ "total_tokens": 315,
794
+ },
795
+ }
796
+ _ = json.NewEncoder(w).Encode(resp)
797
+ }))
798
+ defer server.Close()
799
+
800
+ engine := New("test-model", Config{
801
+ APIBase: server.URL,
802
+ MaxDepth: 5,
803
+ MaxIterations: 10,
804
+ })
805
+
806
+ schema := &StructuredConfig{
807
+ Schema: &JSONSchema{
808
+ Type: "object",
809
+ Properties: map[string]*JSONSchema{
810
+ "summary": {Type: "string"},
811
+ "score": {Type: "number"},
812
+ },
813
+ Required: []string{"summary", "score"},
814
+ },
815
+ MaxRetries: 3,
816
+ }
817
+
818
+ result, stats, err := engine.StructuredCompletion("Analyze this", "Some test context", schema)
819
+ if err != nil {
820
+ t.Fatalf("structured completion failed: %v", err)
821
+ }
822
+
823
+ if result == nil {
824
+ t.Fatal("expected non-nil result")
825
+ }
826
+
827
+ if stats.TotalTokens == 0 {
828
+ t.Error("expected total_tokens > 0 after structured completion, got 0")
829
+ }
830
+ if stats.PromptTokens == 0 {
831
+ t.Error("expected prompt_tokens > 0 after structured completion")
832
+ }
833
+ if stats.CompletionTokens == 0 {
834
+ t.Error("expected completion_tokens > 0 after structured completion")
835
+ }
836
+
837
+ t.Logf("Structured completion: %d total tokens (%d prompt, %d completion)", stats.TotalTokens, stats.PromptTokens, stats.CompletionTokens)
838
+ }
839
+
840
+ // ─── Token Estimation Accuracy Tests ─────────────────────────────────────────
841
+
842
+ func TestEstimateTokens_AccuracyForLargeContent(t *testing.T) {
843
+ useHeuristicTokenizerForTest(t)
844
+
845
+ // Verify that our estimation stays reasonable for large content
846
+ content := generateLargeContext(32000)
847
+ estimated := EstimateTokens(content)
848
+ // Real tokenizer would give different results, but our estimation should be
849
+ // within a reasonable range. The key property: conservative (over-estimates slightly)
850
+ charToTokenRatio := float64(len(content)) / float64(estimated)
851
+
852
+ // Our estimator uses 3.5 chars/token, so ratio should be ~3.5
853
+ if math.Abs(charToTokenRatio-3.5) > 0.5 {
854
+ t.Errorf("char-to-token ratio %.2f deviates too far from expected ~3.5", charToTokenRatio)
855
+ }
856
+
857
+ t.Logf("Large content: %d chars, %d estimated tokens, ratio: %.2f chars/token",
858
+ len(content), estimated, charToTokenRatio)
859
+ }