recursive-llm-ts 4.5.0 → 4.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +74 -4
- package/bin/rlm-go +0 -0
- package/dist/bridge-interface.d.ts +14 -0
- package/dist/errors.d.ts +10 -0
- package/dist/errors.js +25 -1
- package/dist/index.d.ts +2 -2
- package/dist/index.js +2 -1
- package/dist/rlm.d.ts +3 -1
- package/dist/rlm.js +5 -0
- package/go/README.md +9 -1
- package/go/rlm/context_overflow.go +572 -0
- package/go/rlm/context_overflow_test.go +901 -0
- package/go/rlm/errors.go +185 -1
- package/go/rlm/rlm.go +10 -0
- package/go/rlm/structured.go +60 -7
- package/go/rlm/textrank.go +273 -0
- package/go/rlm/textrank_test.go +335 -0
- package/go/rlm/tfidf.go +225 -0
- package/go/rlm/tfidf_test.go +272 -0
- package/go/rlm/types.go +25 -2
- package/package.json +1 -1
package/go/rlm/errors.go
CHANGED
|
@@ -1,6 +1,10 @@
|
|
|
1
1
|
package rlm
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
import (
|
|
4
|
+
"errors"
|
|
5
|
+
"fmt"
|
|
6
|
+
"strings"
|
|
7
|
+
)
|
|
4
8
|
|
|
5
9
|
// RLMError is the base error type for all RLM errors
|
|
6
10
|
type RLMError struct {
|
|
@@ -81,3 +85,183 @@ func NewAPIError(statusCode int, response string) *APIError {
|
|
|
81
85
|
},
|
|
82
86
|
}
|
|
83
87
|
}
|
|
88
|
+
|
|
89
|
+
// ContextOverflowError is returned when the request exceeds the model's context window
|
|
90
|
+
type ContextOverflowError struct {
|
|
91
|
+
ModelLimit int // Maximum tokens the model supports
|
|
92
|
+
RequestTokens int // Number of tokens in the request
|
|
93
|
+
*APIError
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
// NewContextOverflowError creates a ContextOverflowError from parsed API response details
|
|
97
|
+
func NewContextOverflowError(statusCode int, response string, modelLimit, requestTokens int) *ContextOverflowError {
|
|
98
|
+
return &ContextOverflowError{
|
|
99
|
+
ModelLimit: modelLimit,
|
|
100
|
+
RequestTokens: requestTokens,
|
|
101
|
+
APIError: &APIError{
|
|
102
|
+
StatusCode: statusCode,
|
|
103
|
+
Response: response,
|
|
104
|
+
RLMError: &RLMError{
|
|
105
|
+
Message: fmt.Sprintf("context overflow: model limit is %d tokens but request has %d tokens (overflow by %d)",
|
|
106
|
+
modelLimit, requestTokens, requestTokens-modelLimit),
|
|
107
|
+
},
|
|
108
|
+
},
|
|
109
|
+
}
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
// Unwrap returns the embedded APIError so errors.As can find it in the chain.
|
|
113
|
+
func (e *ContextOverflowError) Unwrap() error {
|
|
114
|
+
return e.APIError
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
// OverflowRatio returns how much the request exceeds the limit (e.g., 1.23 means 23% over)
|
|
118
|
+
func (e *ContextOverflowError) OverflowRatio() float64 {
|
|
119
|
+
if e.ModelLimit == 0 {
|
|
120
|
+
return 0
|
|
121
|
+
}
|
|
122
|
+
return float64(e.RequestTokens) / float64(e.ModelLimit)
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
// IsContextOverflow checks if an error is a context overflow error.
|
|
126
|
+
// It detects both explicit ContextOverflowError types and parses API error messages.
|
|
127
|
+
func IsContextOverflow(err error) (*ContextOverflowError, bool) {
|
|
128
|
+
// Direct type check
|
|
129
|
+
var coe *ContextOverflowError
|
|
130
|
+
if errors.As(err, &coe) {
|
|
131
|
+
return coe, true
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
// Parse from APIError message
|
|
135
|
+
var apiErr *APIError
|
|
136
|
+
if errors.As(err, &apiErr) {
|
|
137
|
+
if limit, request, ok := parseContextOverflowMessage(apiErr.Response); ok {
|
|
138
|
+
return NewContextOverflowError(apiErr.StatusCode, apiErr.Response, limit, request), true
|
|
139
|
+
}
|
|
140
|
+
// Also check the error message itself
|
|
141
|
+
if limit, request, ok := parseContextOverflowMessage(apiErr.Error()); ok {
|
|
142
|
+
return NewContextOverflowError(apiErr.StatusCode, apiErr.Response, limit, request), true
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
// Parse from generic error message
|
|
147
|
+
if limit, request, ok := parseContextOverflowMessage(err.Error()); ok {
|
|
148
|
+
return NewContextOverflowError(0, err.Error(), limit, request), true
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
return nil, false
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
// parseContextOverflowMessage extracts token limits from common API error message patterns.
|
|
155
|
+
// Supports OpenAI, Azure, vLLM, and other OpenAI-compatible API error formats.
|
|
156
|
+
func parseContextOverflowMessage(msg string) (modelLimit int, requestTokens int, ok bool) {
|
|
157
|
+
// Common patterns:
|
|
158
|
+
// OpenAI: "This model's maximum context length is 32768 tokens. However, your request has 40354 input tokens."
|
|
159
|
+
// Azure: "This model's maximum context length is 32768 tokens, however you requested 40354 tokens"
|
|
160
|
+
// vLLM: "This model's maximum context length is 32768 tokens. However, your request has 40354 input tokens."
|
|
161
|
+
// Anthropic: "max_tokens: ... exceeds the maximum"
|
|
162
|
+
|
|
163
|
+
lowerMsg := strings.ToLower(msg)
|
|
164
|
+
|
|
165
|
+
// Pattern 1: "maximum context length is X tokens"
|
|
166
|
+
if strings.Contains(lowerMsg, "maximum context length") {
|
|
167
|
+
limit := extractNumber(msg, "maximum context length is ", " tokens")
|
|
168
|
+
if limit > 0 {
|
|
169
|
+
// Try various patterns for the request size
|
|
170
|
+
request := extractNumber(msg, "your request has ", " input tokens")
|
|
171
|
+
if request == 0 {
|
|
172
|
+
request = extractNumber(msg, "your request has ", " tokens")
|
|
173
|
+
}
|
|
174
|
+
if request == 0 {
|
|
175
|
+
request = extractNumber(msg, "you requested ", " tokens")
|
|
176
|
+
}
|
|
177
|
+
if request == 0 {
|
|
178
|
+
request = extractNumber(msg, "requested ", " tokens")
|
|
179
|
+
}
|
|
180
|
+
if request > 0 && request > limit {
|
|
181
|
+
return limit, request, true
|
|
182
|
+
}
|
|
183
|
+
}
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
// Pattern 2: "context_length_exceeded" error code
|
|
187
|
+
if strings.Contains(lowerMsg, "context_length_exceeded") {
|
|
188
|
+
limit := extractNumber(msg, "maximum context length is ", " tokens")
|
|
189
|
+
request := extractNumber(msg, "resulted in ", " tokens")
|
|
190
|
+
if limit > 0 && request > 0 {
|
|
191
|
+
return limit, request, true
|
|
192
|
+
}
|
|
193
|
+
}
|
|
194
|
+
|
|
195
|
+
// Pattern 3: "max_tokens is too large" - response budget exceeds remaining capacity
|
|
196
|
+
// vLLM/OpenAI: "max_tokens' or 'max_completion_tokens' is too large: 10000.
|
|
197
|
+
// This model's maximum context length is 32768 tokens and your request has 30168 input tokens"
|
|
198
|
+
// In this case, input tokens < model limit, but input + max_tokens > model limit.
|
|
199
|
+
// We report the effective total (input + max_tokens) as requestTokens.
|
|
200
|
+
if strings.Contains(lowerMsg, "max_tokens") && strings.Contains(lowerMsg, "too large") {
|
|
201
|
+
limit := extractNumber(msg, "maximum context length is ", " tokens")
|
|
202
|
+
inputTokens := extractNumber(msg, "your request has ", " input tokens")
|
|
203
|
+
if inputTokens == 0 {
|
|
204
|
+
inputTokens = extractNumber(msg, "your request has ", " tokens")
|
|
205
|
+
}
|
|
206
|
+
maxTokens := extractNumber(msg, "too large: ", ".")
|
|
207
|
+
if maxTokens == 0 {
|
|
208
|
+
maxTokens = extractNumber(msg, "too large: ", " ")
|
|
209
|
+
}
|
|
210
|
+
if limit > 0 && inputTokens > 0 && maxTokens > 0 {
|
|
211
|
+
return limit, inputTokens + maxTokens, true
|
|
212
|
+
}
|
|
213
|
+
// Fallback: if we got limit and input tokens, treat input as the overflow
|
|
214
|
+
if limit > 0 && inputTokens > 0 {
|
|
215
|
+
return limit, inputTokens, true
|
|
216
|
+
}
|
|
217
|
+
}
|
|
218
|
+
|
|
219
|
+
// Pattern 4: "input too long" / "too many tokens" generic patterns
|
|
220
|
+
if strings.Contains(lowerMsg, "input too long") || strings.Contains(lowerMsg, "too many tokens") || strings.Contains(lowerMsg, "too many input tokens") {
|
|
221
|
+
limit := extractNumber(msg, "limit is ", " tokens")
|
|
222
|
+
if limit == 0 {
|
|
223
|
+
limit = extractNumber(msg, "maximum of ", " tokens")
|
|
224
|
+
}
|
|
225
|
+
request := extractNumber(msg, "has ", " tokens")
|
|
226
|
+
if request == 0 {
|
|
227
|
+
request = extractNumber(msg, "requested ", " tokens")
|
|
228
|
+
}
|
|
229
|
+
if limit > 0 && request > 0 {
|
|
230
|
+
return limit, request, true
|
|
231
|
+
}
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
return 0, 0, false
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
// extractNumber returns the integer found between the first occurrence of
// prefix in s and the next occurrence of suffix after it. Matching of prefix
// and suffix is case-insensitive. Surrounding whitespace is trimmed and
// thousands separators ("32,768") are removed before parsing; parsing reads
// the leading decimal integer and ignores anything after it.
//
// It returns 0 when prefix or suffix is not found or no integer can be read.
func extractNumber(s string, prefix string, suffix string) int {
	// Search and slice the same lower-cased copy. Lower-casing can change a
	// string's byte length (U+023A is 2 bytes, its lower-case form U+2C65 is
	// 3), so offsets found in the lowered text must not be applied to s:
	// doing so mis-slices and can panic with out-of-range bounds. Digits,
	// signs, commas and whitespace are unaffected by lower-casing.
	lowered := strings.ToLower(s)
	lowPrefix := strings.ToLower(prefix)

	at := strings.Index(lowered, lowPrefix)
	if at < 0 {
		return 0
	}
	rest := lowered[at+len(lowPrefix):]

	end := strings.Index(rest, strings.ToLower(suffix))
	if end < 0 {
		return 0
	}

	digits := strings.TrimSpace(rest[:end])
	// Remove commas from numbers like "32,768".
	digits = strings.ReplaceAll(digits, ",", "")

	var n int
	if _, err := fmt.Sscanf(digits, "%d", &n); err != nil {
		return 0
	}
	return n
}
|
package/go/rlm/rlm.go
CHANGED
|
@@ -20,6 +20,7 @@ type RLM struct {
|
|
|
20
20
|
stats RLMStats
|
|
21
21
|
observer *Observer
|
|
22
22
|
metaAgent *MetaAgent
|
|
23
|
+
contextOverflow *ContextOverflowConfig
|
|
23
24
|
}
|
|
24
25
|
|
|
25
26
|
func New(model string, config Config) *RLM {
|
|
@@ -57,6 +58,15 @@ func New(model string, config Config) *RLM {
|
|
|
57
58
|
r.metaAgent = NewMetaAgent(r, *config.MetaAgent, obs)
|
|
58
59
|
}
|
|
59
60
|
|
|
61
|
+
// Setup context overflow handling
|
|
62
|
+
if config.ContextOverflow != nil {
|
|
63
|
+
r.contextOverflow = config.ContextOverflow
|
|
64
|
+
} else {
|
|
65
|
+
// Enable by default with sensible defaults
|
|
66
|
+
defaultConfig := DefaultContextOverflowConfig()
|
|
67
|
+
r.contextOverflow = &defaultConfig
|
|
68
|
+
}
|
|
69
|
+
|
|
60
70
|
return r
|
|
61
71
|
}
|
|
62
72
|
|
package/go/rlm/structured.go
CHANGED
|
@@ -102,12 +102,65 @@ func (r *RLM) structuredCompletionDirect(query string, context string, config *S
|
|
|
102
102
|
{Role: "user", Content: prompt},
|
|
103
103
|
}
|
|
104
104
|
|
|
105
|
+
// Track whether we've already reduced context for overflow recovery
|
|
106
|
+
contextReduced := false
|
|
107
|
+
|
|
105
108
|
for attempt := 0; attempt < config.MaxRetries; attempt++ {
|
|
106
109
|
result, err := r.callLLM(messages)
|
|
107
110
|
stats.LlmCalls++
|
|
108
111
|
stats.Iterations++
|
|
109
112
|
|
|
110
113
|
if err != nil {
|
|
114
|
+
// Check for context overflow and attempt automatic recovery
|
|
115
|
+
if coe, isOverflow := IsContextOverflow(err); isOverflow && !contextReduced && r.contextOverflow != nil && r.contextOverflow.Enabled {
|
|
116
|
+
r.observer.Debug("structured", "Context overflow detected on attempt %d: model limit %d, request %d tokens",
|
|
117
|
+
attempt+1, coe.ModelLimit, coe.RequestTokens)
|
|
118
|
+
|
|
119
|
+
modelLimit := coe.ModelLimit
|
|
120
|
+
if r.contextOverflow.MaxModelTokens > 0 {
|
|
121
|
+
modelLimit = r.contextOverflow.MaxModelTokens
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
reducer := newContextReducer(r, *r.contextOverflow, r.observer)
|
|
125
|
+
reducedContext, reduceErr := reducer.ReduceForCompletion(query, context, modelLimit)
|
|
126
|
+
if reduceErr != nil {
|
|
127
|
+
r.observer.Error("structured", "Context reduction failed: %v", reduceErr)
|
|
128
|
+
lastErr = err
|
|
129
|
+
continue
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
r.observer.Debug("structured", "Context reduced: %d -> %d chars, rebuilding prompt", len(context), len(reducedContext))
|
|
133
|
+
context = reducedContext
|
|
134
|
+
contextReduced = true
|
|
135
|
+
|
|
136
|
+
// Rebuild the prompt with reduced context
|
|
137
|
+
prompt = fmt.Sprintf(
|
|
138
|
+
"You are a data extraction assistant. Extract information from the context and return it as JSON.\n\n"+
|
|
139
|
+
"Context:\n%s\n\n"+
|
|
140
|
+
"Task: %s\n\n"+
|
|
141
|
+
"Required JSON Schema:\n%s%s\n\n"+
|
|
142
|
+
"%s"+
|
|
143
|
+
"CRITICAL INSTRUCTIONS:\n"+
|
|
144
|
+
"1. Return ONLY valid JSON - no explanations, no markdown, no code blocks\n"+
|
|
145
|
+
"2. The JSON must match the schema EXACTLY\n"+
|
|
146
|
+
"3. Include ALL required fields (see list above)\n"+
|
|
147
|
+
"4. Use correct data types (strings in quotes, numbers without quotes, arrays in [], objects in {})\n"+
|
|
148
|
+
"5. For arrays, return actual JSON arrays [] not objects\n"+
|
|
149
|
+
"6. For enum fields, use ONLY the EXACT values listed - do not paraphrase or substitute\n"+
|
|
150
|
+
"7. For nested objects, ensure ALL required fields within those objects are included\n"+
|
|
151
|
+
"8. Start your response directly with { or [ depending on the schema\n\n"+
|
|
152
|
+
"JSON Response:",
|
|
153
|
+
reducedContext, query, string(schemaJSON), requiredFieldsHint, constraints,
|
|
154
|
+
)
|
|
155
|
+
messages = []Message{
|
|
156
|
+
{Role: "system", Content: "You are a data extraction assistant. Respond only with valid JSON objects."},
|
|
157
|
+
{Role: "user", Content: prompt},
|
|
158
|
+
}
|
|
159
|
+
// Don't count this as a "used" attempt since it was an overflow, not a validation failure
|
|
160
|
+
attempt--
|
|
161
|
+
continue
|
|
162
|
+
}
|
|
163
|
+
|
|
111
164
|
lastErr = err
|
|
112
165
|
continue
|
|
113
166
|
}
|
|
@@ -790,7 +843,7 @@ func buildValidationFeedback(validationErr error, schema *JSONSchema, previousRe
|
|
|
790
843
|
|
|
791
844
|
var feedback strings.Builder
|
|
792
845
|
feedback.WriteString("VALIDATION ERROR - Your previous response was invalid.\n\n")
|
|
793
|
-
|
|
846
|
+
fmt.Fprintf(&feedback, "ERROR: %s\n\n", errMsg)
|
|
794
847
|
|
|
795
848
|
// Extract what field caused the issue
|
|
796
849
|
if strings.Contains(errMsg, "missing required field:") {
|
|
@@ -799,17 +852,17 @@ func buildValidationFeedback(validationErr error, schema *JSONSchema, previousRe
|
|
|
799
852
|
fieldName = strings.TrimSpace(fieldName)
|
|
800
853
|
|
|
801
854
|
feedback.WriteString("SPECIFIC ISSUE:\n")
|
|
802
|
-
|
|
855
|
+
fmt.Fprintf(&feedback, "The field '%s' is REQUIRED but was not provided.\n\n", fieldName)
|
|
803
856
|
|
|
804
857
|
// Find the schema for this field and provide details
|
|
805
858
|
if schema.Type == "object" && schema.Properties != nil {
|
|
806
859
|
if fieldSchema, exists := schema.Properties[fieldName]; exists {
|
|
807
860
|
feedback.WriteString("FIELD REQUIREMENTS:\n")
|
|
808
|
-
|
|
809
|
-
|
|
861
|
+
fmt.Fprintf(&feedback, "- Field name: '%s'\n", fieldName)
|
|
862
|
+
fmt.Fprintf(&feedback, "- Type: %s\n", fieldSchema.Type)
|
|
810
863
|
|
|
811
864
|
if fieldSchema.Type == "object" && len(fieldSchema.Required) > 0 {
|
|
812
|
-
|
|
865
|
+
fmt.Fprintf(&feedback, "- This is an object with required fields: %s\n", strings.Join(fieldSchema.Required, ", "))
|
|
813
866
|
|
|
814
867
|
if fieldSchema.Properties != nil {
|
|
815
868
|
feedback.WriteString("\nNESTED FIELD DETAILS:\n")
|
|
@@ -819,13 +872,13 @@ func buildValidationFeedback(validationErr error, schema *JSONSchema, previousRe
|
|
|
819
872
|
if isRequired {
|
|
820
873
|
requiredMark = " [REQUIRED]"
|
|
821
874
|
}
|
|
822
|
-
|
|
875
|
+
fmt.Fprintf(&feedback, " - %s: %s%s\n", nestedField, nestedSchema.Type, requiredMark)
|
|
823
876
|
}
|
|
824
877
|
}
|
|
825
878
|
}
|
|
826
879
|
|
|
827
880
|
if fieldSchema.Type == "array" && fieldSchema.Items != nil {
|
|
828
|
-
|
|
881
|
+
fmt.Fprintf(&feedback, "- This is an array of: %s\n", fieldSchema.Items.Type)
|
|
829
882
|
}
|
|
830
883
|
}
|
|
831
884
|
}
|
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
package rlm
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"math"
|
|
5
|
+
"sort"
|
|
6
|
+
"strings"
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
// ─── TextRank Graph-Based Sentence Ranking ──────────────────────────────────
|
|
10
|
+
//
|
|
11
|
+
// Pure Go, zero external dependencies, zero API calls.
|
|
12
|
+
// Implements the TextRank algorithm (Mihalcea & Tarau, 2004):
|
|
13
|
+
// 1. Build TF-IDF vectors for each sentence
|
|
14
|
+
// 2. Compute cosine similarity between all sentence pairs
|
|
15
|
+
// 3. Run PageRank iteration on the similarity graph
|
|
16
|
+
// 4. Select top-ranked sentences that fit within token budget
|
|
17
|
+
// 5. Preserve original document order
|
|
18
|
+
|
|
19
|
+
// TextRankConfig holds the tunable parameters of the TextRank pipeline.
// The defaults quoted below are the values DefaultTextRankConfig returns.
type TextRankConfig struct {
	// DampingFactor is the damping factor used by PageRank (default: 0.85).
	DampingFactor float64
	// MaxIterations caps the number of PageRank rounds (default: 100).
	MaxIterations int
	// ConvergenceThreshold stops PageRank once the largest per-node score
	// change in a round falls below it (default: 0.0001).
	ConvergenceThreshold float64
	// MinSimilarity is the lowest similarity for which an edge is kept in
	// the sentence graph (default: 0.1).
	MinSimilarity float64
}
|
|
30
|
+
|
|
31
|
+
// DefaultTextRankConfig returns sensible defaults for TextRank.
|
|
32
|
+
func DefaultTextRankConfig() TextRankConfig {
|
|
33
|
+
return TextRankConfig{
|
|
34
|
+
DampingFactor: 0.85,
|
|
35
|
+
MaxIterations: 100,
|
|
36
|
+
ConvergenceThreshold: 0.0001,
|
|
37
|
+
MinSimilarity: 0.1,
|
|
38
|
+
}
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
// tfidfVector is the sparse TF-IDF representation of one sentence: a weight
// per term, plus the vector's L2 norm computed once up front.
type tfidfVector struct {
	terms map[string]float64 // term -> TF-IDF weight
	norm  float64            // precomputed L2 norm of the weights
}
|
|
46
|
+
|
|
47
|
+
// buildTFIDFVectors computes TF-IDF vectors for each sentence.
|
|
48
|
+
func buildTFIDFVectors(sentences []string) []tfidfVector {
|
|
49
|
+
n := len(sentences)
|
|
50
|
+
if n == 0 {
|
|
51
|
+
return nil
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
// Tokenize and filter
|
|
55
|
+
docWords := make([][]string, n)
|
|
56
|
+
for i, s := range sentences {
|
|
57
|
+
docWords[i] = FilterStopWords(TokenizeWords(s))
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
// Compute document frequency
|
|
61
|
+
df := make(map[string]int)
|
|
62
|
+
for _, words := range docWords {
|
|
63
|
+
seen := make(map[string]bool)
|
|
64
|
+
for _, w := range words {
|
|
65
|
+
if !seen[w] {
|
|
66
|
+
df[w]++
|
|
67
|
+
seen[w] = true
|
|
68
|
+
}
|
|
69
|
+
}
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
nf := float64(n)
|
|
73
|
+
|
|
74
|
+
// Build vectors
|
|
75
|
+
vectors := make([]tfidfVector, n)
|
|
76
|
+
for i, words := range docWords {
|
|
77
|
+
tf := make(map[string]int)
|
|
78
|
+
for _, w := range words {
|
|
79
|
+
tf[w]++
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
terms := make(map[string]float64)
|
|
83
|
+
normSq := 0.0
|
|
84
|
+
for word, freq := range tf {
|
|
85
|
+
idf := math.Log(nf / float64(df[word]))
|
|
86
|
+
val := float64(freq) * idf
|
|
87
|
+
terms[word] = val
|
|
88
|
+
normSq += val * val
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
vectors[i] = tfidfVector{
|
|
92
|
+
terms: terms,
|
|
93
|
+
norm: math.Sqrt(normSq),
|
|
94
|
+
}
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
return vectors
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
// cosineSimilarity computes the cosine similarity between two TF-IDF vectors.
|
|
101
|
+
func cosineSimilarity(a, b tfidfVector) float64 {
|
|
102
|
+
if a.norm == 0 || b.norm == 0 {
|
|
103
|
+
return 0
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
// Compute dot product using the smaller vector for efficiency
|
|
107
|
+
dot := 0.0
|
|
108
|
+
small, large := a.terms, b.terms
|
|
109
|
+
if len(a.terms) > len(b.terms) {
|
|
110
|
+
small, large = b.terms, a.terms
|
|
111
|
+
}
|
|
112
|
+
for term, val := range small {
|
|
113
|
+
if otherVal, ok := large[term]; ok {
|
|
114
|
+
dot += val * otherVal
|
|
115
|
+
}
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
return dot / (a.norm * b.norm)
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
// BuildSimilarityGraph creates a weighted adjacency matrix of sentence similarities.
|
|
122
|
+
// Only edges above the MinSimilarity threshold are kept.
|
|
123
|
+
func BuildSimilarityGraph(sentences []string, config TextRankConfig) [][]float64 {
|
|
124
|
+
n := len(sentences)
|
|
125
|
+
vectors := buildTFIDFVectors(sentences)
|
|
126
|
+
|
|
127
|
+
graph := make([][]float64, n)
|
|
128
|
+
for i := range graph {
|
|
129
|
+
graph[i] = make([]float64, n)
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
for i := 0; i < n; i++ {
|
|
133
|
+
for j := i + 1; j < n; j++ {
|
|
134
|
+
sim := cosineSimilarity(vectors[i], vectors[j])
|
|
135
|
+
if sim >= config.MinSimilarity {
|
|
136
|
+
graph[i][j] = sim
|
|
137
|
+
graph[j][i] = sim
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
return graph
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
// PageRank runs the PageRank algorithm on a weighted graph.
|
|
146
|
+
// Returns a score for each node (sentence).
|
|
147
|
+
func PageRank(graph [][]float64, config TextRankConfig) []float64 {
|
|
148
|
+
n := len(graph)
|
|
149
|
+
if n == 0 {
|
|
150
|
+
return nil
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
d := config.DampingFactor
|
|
154
|
+
scores := make([]float64, n)
|
|
155
|
+
newScores := make([]float64, n)
|
|
156
|
+
|
|
157
|
+
// Initialize with uniform scores
|
|
158
|
+
initial := 1.0 / float64(n)
|
|
159
|
+
for i := range scores {
|
|
160
|
+
scores[i] = initial
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
// Precompute outgoing weight sums for each node
|
|
164
|
+
outWeights := make([]float64, n)
|
|
165
|
+
for i := 0; i < n; i++ {
|
|
166
|
+
for j := 0; j < n; j++ {
|
|
167
|
+
outWeights[i] += graph[i][j]
|
|
168
|
+
}
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
// Iterate until convergence
|
|
172
|
+
for iter := 0; iter < config.MaxIterations; iter++ {
|
|
173
|
+
maxDelta := 0.0
|
|
174
|
+
|
|
175
|
+
for i := 0; i < n; i++ {
|
|
176
|
+
sum := 0.0
|
|
177
|
+
for j := 0; j < n; j++ {
|
|
178
|
+
if graph[j][i] > 0 && outWeights[j] > 0 {
|
|
179
|
+
sum += graph[j][i] / outWeights[j] * scores[j]
|
|
180
|
+
}
|
|
181
|
+
}
|
|
182
|
+
newScores[i] = (1-d)/float64(n) + d*sum
|
|
183
|
+
|
|
184
|
+
delta := math.Abs(newScores[i] - scores[i])
|
|
185
|
+
if delta > maxDelta {
|
|
186
|
+
maxDelta = delta
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
// Swap slices
|
|
191
|
+
scores, newScores = newScores, scores
|
|
192
|
+
|
|
193
|
+
// Check convergence
|
|
194
|
+
if maxDelta < config.ConvergenceThreshold {
|
|
195
|
+
break
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
return scores
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
// CompressContextTextRank reduces context to fit within a token budget using
|
|
203
|
+
// TextRank graph-based sentence ranking.
|
|
204
|
+
// Preserves original sentence order in the output.
|
|
205
|
+
func CompressContextTextRank(text string, targetTokens int) string {
|
|
206
|
+
return CompressContextTextRankWithConfig(text, targetTokens, DefaultTextRankConfig())
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
// CompressContextTextRankWithConfig is like CompressContextTextRank but with custom TextRank parameters.
|
|
210
|
+
func CompressContextTextRankWithConfig(text string, targetTokens int, config TextRankConfig) string {
|
|
211
|
+
if EstimateTokens(text) <= targetTokens {
|
|
212
|
+
return text
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
sentences := SplitSentences(text)
|
|
216
|
+
if len(sentences) == 0 {
|
|
217
|
+
return text
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
// Build similarity graph and run PageRank
|
|
221
|
+
graph := BuildSimilarityGraph(sentences, config)
|
|
222
|
+
scores := PageRank(graph, config)
|
|
223
|
+
|
|
224
|
+
// Create scored sentences with PageRank scores
|
|
225
|
+
ranked := make([]ScoredSentence, len(sentences))
|
|
226
|
+
for i, s := range sentences {
|
|
227
|
+
ranked[i] = ScoredSentence{
|
|
228
|
+
Text: s,
|
|
229
|
+
Score: scores[i],
|
|
230
|
+
Index: i,
|
|
231
|
+
}
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
// Sort by score descending
|
|
235
|
+
sort.Slice(ranked, func(i, j int) bool {
|
|
236
|
+
return ranked[i].Score > ranked[j].Score
|
|
237
|
+
})
|
|
238
|
+
|
|
239
|
+
// Greedily select top sentences until budget is reached
|
|
240
|
+
var selected []ScoredSentence
|
|
241
|
+
currentTokens := 0
|
|
242
|
+
for _, s := range ranked {
|
|
243
|
+
sentTokens := EstimateTokens(s.Text)
|
|
244
|
+
if currentTokens+sentTokens > targetTokens {
|
|
245
|
+
continue
|
|
246
|
+
}
|
|
247
|
+
selected = append(selected, s)
|
|
248
|
+
currentTokens += sentTokens
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
if len(selected) == 0 {
|
|
252
|
+
// Budget too small - truncate the top sentence
|
|
253
|
+
if len(ranked) > 0 {
|
|
254
|
+
maxChars := targetTokens * 3
|
|
255
|
+
if maxChars > len(ranked[0].Text) {
|
|
256
|
+
maxChars = len(ranked[0].Text)
|
|
257
|
+
}
|
|
258
|
+
return ranked[0].Text[:maxChars]
|
|
259
|
+
}
|
|
260
|
+
return text
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
// Re-sort by original index to preserve document order
|
|
264
|
+
sort.Slice(selected, func(i, j int) bool {
|
|
265
|
+
return selected[i].Index < selected[j].Index
|
|
266
|
+
})
|
|
267
|
+
|
|
268
|
+
parts := make([]string, len(selected))
|
|
269
|
+
for i, s := range selected {
|
|
270
|
+
parts[i] = s.Text
|
|
271
|
+
}
|
|
272
|
+
return strings.Join(parts, " ")
|
|
273
|
+
}
|