@xdarkicex/openclaw-memory-libravdb 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +46 -0
- package/docs/README.md +14 -0
- package/docs/architecture-decisions/README.md +6 -0
- package/docs/architecture-decisions/adr-001-onnx-over-ollama.md +21 -0
- package/docs/architecture-decisions/adr-002-libravdb-over-lancedb.md +19 -0
- package/docs/architecture-decisions/adr-003-convex-gating-over-threshold.md +27 -0
- package/docs/architecture-decisions/adr-004-sidecar-over-native-ts.md +21 -0
- package/docs/architecture.md +188 -0
- package/docs/contributing.md +76 -0
- package/docs/dependencies.md +38 -0
- package/docs/embedding-profiles.md +42 -0
- package/docs/gating.md +329 -0
- package/docs/implementation.md +381 -0
- package/docs/installation.md +272 -0
- package/docs/mathematics.md +695 -0
- package/docs/models.md +63 -0
- package/docs/problem.md +64 -0
- package/docs/security.md +86 -0
- package/openclaw.plugin.json +84 -0
- package/package.json +41 -0
- package/scripts/build-sidecar.sh +30 -0
- package/scripts/postinstall.js +169 -0
- package/scripts/setup.sh +20 -0
- package/scripts/setup.ts +505 -0
- package/scripts/sidecar-release.d.ts +4 -0
- package/scripts/sidecar-release.js +17 -0
- package/sidecar/cmd/inspect_onnx/main.go +105 -0
- package/sidecar/compact/gate.go +273 -0
- package/sidecar/compact/gate_test.go +85 -0
- package/sidecar/compact/summarize.go +345 -0
- package/sidecar/compact/summarize_test.go +319 -0
- package/sidecar/compact/tokens.go +11 -0
- package/sidecar/config/config.go +119 -0
- package/sidecar/config/config_test.go +75 -0
- package/sidecar/embed/engine.go +696 -0
- package/sidecar/embed/engine_test.go +349 -0
- package/sidecar/embed/matryoshka.go +93 -0
- package/sidecar/embed/matryoshka_test.go +150 -0
- package/sidecar/embed/onnx_local.go +319 -0
- package/sidecar/embed/onnx_local_test.go +159 -0
- package/sidecar/embed/profile_contract_test.go +71 -0
- package/sidecar/embed/profile_eval_test.go +923 -0
- package/sidecar/embed/profiles.go +39 -0
- package/sidecar/go.mod +21 -0
- package/sidecar/go.sum +30 -0
- package/sidecar/health/check.go +33 -0
- package/sidecar/health/check_test.go +55 -0
- package/sidecar/main.go +151 -0
- package/sidecar/model/encoder.go +222 -0
- package/sidecar/model/registry.go +262 -0
- package/sidecar/model/registry_test.go +102 -0
- package/sidecar/model/seq2seq.go +133 -0
- package/sidecar/server/rpc.go +343 -0
- package/sidecar/server/rpc_test.go +350 -0
- package/sidecar/server/transport.go +160 -0
- package/sidecar/store/libravdb.go +676 -0
- package/sidecar/store/libravdb_test.go +472 -0
- package/sidecar/summarize/engine.go +360 -0
- package/sidecar/summarize/engine_test.go +148 -0
- package/sidecar/summarize/onnx_local.go +494 -0
- package/sidecar/summarize/onnx_local_test.go +48 -0
- package/sidecar/summarize/profiles.go +52 -0
- package/sidecar/summarize/tokenizer.go +13 -0
- package/sidecar/summarize/tokenizer_hf.go +76 -0
- package/sidecar/summarize/util.go +13 -0
- package/src/cli.ts +205 -0
- package/src/context-engine.ts +195 -0
- package/src/index.ts +27 -0
- package/src/memory-provider.ts +24 -0
- package/src/openclaw-plugin-sdk.d.ts +53 -0
- package/src/plugin-runtime.ts +67 -0
- package/src/recall-cache.ts +34 -0
- package/src/recall-utils.ts +22 -0
- package/src/rpc.ts +84 -0
- package/src/scoring.ts +58 -0
- package/src/sidecar.ts +506 -0
- package/src/tokens.ts +36 -0
- package/src/types.ts +146 -0
- package/tsconfig.json +20 -0
- package/tsconfig.tests.json +12 -0
|
@@ -0,0 +1,494 @@
|
|
|
1
|
+
package summarize
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"context"
|
|
5
|
+
"crypto/sha256"
|
|
6
|
+
"encoding/hex"
|
|
7
|
+
"encoding/json"
|
|
8
|
+
"fmt"
|
|
9
|
+
"math"
|
|
10
|
+
"os"
|
|
11
|
+
"path/filepath"
|
|
12
|
+
"strings"
|
|
13
|
+
"sync"
|
|
14
|
+
|
|
15
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/model"
|
|
16
|
+
ort "github.com/yalue/onnxruntime_go"
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
const defaultManifestName = "summarizer.json"
|
|
20
|
+
|
|
21
|
+
const (
|
|
22
|
+
t5SmallHiddenDim = 512
|
|
23
|
+
t5SmallVocabSize = 32128
|
|
24
|
+
|
|
25
|
+
t5EncoderInputIDs = "input_ids"
|
|
26
|
+
t5EncoderAttnMask = "attention_mask"
|
|
27
|
+
t5EncoderHiddenState = "last_hidden_state"
|
|
28
|
+
|
|
29
|
+
t5DecoderInputIDs = "input_ids"
|
|
30
|
+
t5DecoderEncoderHidden = "encoder_hidden_states"
|
|
31
|
+
t5DecoderEncoderMask = "encoder_attention_mask"
|
|
32
|
+
t5DecoderLogits = "logits"
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
// --- ONNX Graph Record ---
|
|
36
|
+
// Model: encoder_model.onnx
|
|
37
|
+
// Inspected: 2026-03-28 22:23:28 PDT
|
|
38
|
+
//
|
|
39
|
+
// INPUTS:
|
|
40
|
+
// input_ids ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 [-1 -1]
|
|
41
|
+
// attention_mask ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 [-1 -1]
|
|
42
|
+
//
|
|
43
|
+
// OUTPUTS:
|
|
44
|
+
// last_hidden_state ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 512]
|
|
45
|
+
// -------------------------
|
|
46
|
+
//
|
|
47
|
+
// --- ONNX Graph Record ---
|
|
48
|
+
// Model: decoder_model.onnx
|
|
49
|
+
// Inspected: 2026-03-28 22:23:29 PDT
|
|
50
|
+
//
|
|
51
|
+
// INPUTS:
|
|
52
|
+
// encoder_attention_mask ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 [-1 -1]
|
|
53
|
+
// input_ids ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 [-1 -1]
|
|
54
|
+
// encoder_hidden_states ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 512]
|
|
55
|
+
//
|
|
56
|
+
// OUTPUTS:
|
|
57
|
+
// logits ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 32128]
|
|
58
|
+
// present.0.decoder.key ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
59
|
+
// present.0.decoder.value ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
60
|
+
// present.0.encoder.key ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
61
|
+
// present.0.encoder.value ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
62
|
+
// present.1.decoder.key ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
63
|
+
// present.1.decoder.value ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
64
|
+
// present.1.encoder.key ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
65
|
+
// present.1.encoder.value ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
66
|
+
// present.2.decoder.key ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
67
|
+
// present.2.decoder.value ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
68
|
+
// present.2.encoder.key ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
69
|
+
// present.2.encoder.value ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
70
|
+
// present.3.decoder.key ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
71
|
+
// present.3.decoder.value ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
72
|
+
// present.3.encoder.key ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
73
|
+
// present.3.encoder.value ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
74
|
+
// present.4.decoder.key ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
75
|
+
// present.4.decoder.value ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
76
|
+
// present.4.encoder.key ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
77
|
+
// present.4.encoder.value ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
78
|
+
// present.5.decoder.key ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
79
|
+
// present.5.decoder.value ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
80
|
+
// present.5.encoder.key ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
81
|
+
// present.5.encoder.value ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 -1 -1]
|
|
82
|
+
// encoder_last_hidden_state ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT [-1 -1 512]
|
|
83
|
+
// -------------------------
|
|
84
|
+
|
|
85
|
+
// summaryManifest mirrors the on-disk summarizer.json that describes a
// locally installed summarization model: which files make it up and how
// it should be run. Every field is optional in JSON; resolveONNXLocalSpec
// validates the combinations it actually needs (tokenizer plus at least
// one of model/encoder/decoder).
type summaryManifest struct {
	Backend string `json:"backend,omitempty"` // backend hint, e.g. "onnx-local"
	Profile string `json:"profile,omitempty"` // shipped profile name (see profiles.go)
	Family  string `json:"family,omitempty"`  // model family label carried into the Profile
	Model   string `json:"model,omitempty"`   // single-file model path, relative to the manifest dir unless absolute
	Encoder string `json:"encoder,omitempty"` // encoder ONNX path for split encoder/decoder models
	Decoder string `json:"decoder,omitempty"` // decoder ONNX path for split encoder/decoder models
	Tokenizer string `json:"tokenizer,omitempty"` // tokenizer.json path; required overall
	MaxContextTokens int `json:"maxContextTokens,omitempty"` // 0 -> fall back to the shipped profile's default
}
|
|
95
|
+
|
|
96
|
+
// onnxLocalSpec is the fully resolved configuration for the onnx-local
// summarizer backend: concrete asset paths plus the derived, fingerprinted
// Profile. Produced by resolveONNXLocalSpec from Config + manifest.
type onnxLocalSpec struct {
	RuntimePath string // path to the onnxruntime shared library
	ModelPath string // combined model file, when the manifest uses a single file
	EncoderPath string // encoder ONNX file (split-model layout)
	DecoderPath string // decoder ONNX file (split-model layout)
	TokenizerPath string // tokenizer.json; always required
	MaxContextTokens int // encoder context budget; 0 means unknown
	Profile Profile // fingerprinted profile, used as the registry key
}
|
|
105
|
+
|
|
106
|
+
// onnxLocalBackend runs summarization against a local ONNX seq2seq model.
// The model is loaded lazily on first use (Warmup); loaded == nil means
// "not warmed up yet, or unloaded".
type onnxLocalBackend struct {
	mu sync.Mutex // guards loaded (taken by Warmup and Unload)
	registry *model.Registry // shared model cache keyed by profile fingerprint
	spec onnxLocalSpec // resolved paths and profile
	tok Tokenizer // text <-> token-ID boundary
	loaded *model.Seq2SeqModel // nil until Warmup succeeds
}
|
|
113
|
+
|
|
114
|
+
func resolveONNXLocalSpec(cfg Config) (onnxLocalSpec, error) {
|
|
115
|
+
manifestPath, err := resolveManifestPath(cfg.ModelPath)
|
|
116
|
+
if err != nil {
|
|
117
|
+
return onnxLocalSpec{}, err
|
|
118
|
+
}
|
|
119
|
+
manifest, err := readManifest(manifestPath)
|
|
120
|
+
if err != nil {
|
|
121
|
+
return onnxLocalSpec{}, err
|
|
122
|
+
}
|
|
123
|
+
selectedProfile, hasProfile := lookupProfile(firstNonEmpty(cfg.Profile, manifest.Profile))
|
|
124
|
+
|
|
125
|
+
baseDir := filepath.Dir(manifestPath)
|
|
126
|
+
modelPath := resolveManifestAsset(baseDir, manifest.Model)
|
|
127
|
+
encoderPath := resolveManifestAsset(baseDir, manifest.Encoder)
|
|
128
|
+
decoderPath := resolveManifestAsset(baseDir, manifest.Decoder)
|
|
129
|
+
tokenizerPath := resolveManifestAsset(baseDir, cfg.TokenizerPath, manifest.Tokenizer)
|
|
130
|
+
if tokenizerPath == "" {
|
|
131
|
+
return onnxLocalSpec{}, fmt.Errorf("onnx-local summarizer manifest missing tokenizer path")
|
|
132
|
+
}
|
|
133
|
+
if modelPath == "" && encoderPath == "" && decoderPath == "" {
|
|
134
|
+
return onnxLocalSpec{}, fmt.Errorf("onnx-local summarizer manifest missing model path")
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
maxCtx := manifest.MaxContextTokens
|
|
138
|
+
if maxCtx <= 0 && hasProfile {
|
|
139
|
+
maxCtx = selectedProfile.MaxContextTokens
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
profile := buildProfile(Profile{
|
|
143
|
+
Backend: "onnx-local",
|
|
144
|
+
Family: firstNonEmpty(strings.TrimSpace(manifest.Family), selectedProfile.Family, "onnx-local"),
|
|
145
|
+
Model: firstNonEmpty(filepath.Base(modelPath), filepath.Base(encoderPath), selectedProfile.Name),
|
|
146
|
+
ModelPath: firstNonEmpty(modelPath, encoderPath, filepath.Dir(manifestPath)),
|
|
147
|
+
})
|
|
148
|
+
|
|
149
|
+
return onnxLocalSpec{
|
|
150
|
+
RuntimePath: strings.TrimSpace(cfg.RuntimePath),
|
|
151
|
+
ModelPath: modelPath,
|
|
152
|
+
EncoderPath: encoderPath,
|
|
153
|
+
DecoderPath: decoderPath,
|
|
154
|
+
TokenizerPath: tokenizerPath,
|
|
155
|
+
MaxContextTokens: maxCtx,
|
|
156
|
+
Profile: profile,
|
|
157
|
+
}, nil
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
func newONNXLocalBackend(cfg Config, deps Dependencies) (summarizerBackend, error) {
|
|
161
|
+
if strings.TrimSpace(cfg.RuntimePath) == "" {
|
|
162
|
+
return nil, fmt.Errorf("onnx-local summarizer requires ONNX runtime path")
|
|
163
|
+
}
|
|
164
|
+
spec, err := resolveONNXLocalSpec(cfg)
|
|
165
|
+
if err != nil {
|
|
166
|
+
return nil, err
|
|
167
|
+
}
|
|
168
|
+
registry := deps.Registry
|
|
169
|
+
if registry == nil {
|
|
170
|
+
registry = model.DefaultRegistry()
|
|
171
|
+
}
|
|
172
|
+
tokenizerLoader := deps.TokenizerLoader
|
|
173
|
+
if tokenizerLoader == nil {
|
|
174
|
+
tokenizerLoader = newTokenizer
|
|
175
|
+
}
|
|
176
|
+
tok, err := tokenizerLoader(spec.TokenizerPath)
|
|
177
|
+
if err != nil {
|
|
178
|
+
return nil, err
|
|
179
|
+
}
|
|
180
|
+
return &onnxLocalBackend{
|
|
181
|
+
registry: registry,
|
|
182
|
+
spec: spec,
|
|
183
|
+
tok: tok,
|
|
184
|
+
}, nil
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
// Summarize produces an abstractive summary of the given turns with the
// local T5-style encoder/decoder ONNX model.
//
// Pipeline: validate opts and turn count -> lazily load the model
// (Warmup) -> tokenize with the "summarize: " task prefix -> one encoder
// pass -> greedy token-by-token decode -> detokenize. Confidence is the
// geometric mean of the per-token probabilities computed in decode.
//
// NOTE(review): Warmup takes b.mu, but runEncoder/decode afterwards read
// b.loaded without the lock, so a concurrent Unload could nil the model
// mid-run — confirm callers serialize Summarize against Unload/Close.
func (b *onnxLocalBackend) Summarize(ctx context.Context, turns []Turn, opts SummaryOpts) (Summary, error) {
	opts = normalizeSummaryOpts(opts)
	if len(turns) == 0 {
		return Summary{}, fmt.Errorf("no turns to summarize")
	}
	if len(turns) < opts.MinInputTurns {
		return Summary{}, fmt.Errorf("need at least %d turns for summarization, got %d", opts.MinInputTurns, len(turns))
	}
	if err := b.Warmup(ctx); err != nil {
		return Summary{}, err
	}

	// T5 is trained with a task prefix; "summarize: " selects summarization.
	text := summarizeInput(turns)
	inputIDs, err := b.tok.Encode("summarize: " + text)
	if err != nil {
		return Summary{}, fmt.Errorf("tokenize input: %w", err)
	}
	if len(inputIDs) == 0 {
		return Summary{}, fmt.Errorf("tokenizer returned no input ids")
	}

	// No padding is used, so the attention mask is all ones.
	encMask := make([]int64, len(inputIDs))
	for i := range encMask {
		encMask[i] = 1
	}

	encHidden, err := b.runEncoder(inputIDs, encMask, len(inputIDs))
	if err != nil {
		return Summary{}, err
	}

	decodedIDs, confidence, err := b.decode(ctx, encHidden, encMask, len(inputIDs), opts.MaxOutputTokens)
	if err != nil {
		return Summary{}, err
	}
	summaryText, err := b.tok.Decode(decodedIDs)
	if err != nil {
		return Summary{}, fmt.Errorf("decode tokens: %w", err)
	}
	summaryText = strings.TrimSpace(summaryText)
	if summaryText == "" {
		return Summary{}, fmt.Errorf("summarizer produced empty output")
	}

	// Record which turns the summary was built from, in input order.
	sourceIDs := make([]string, 0, len(turns))
	for _, turn := range turns {
		sourceIDs = append(sourceIDs, turn.ID)
	}
	return Summary{
		Text:       summaryText,
		SourceIDs:  sourceIDs,
		Method:     "onnx-t5",
		TokenCount: len(decodedIDs),
		Confidence: confidence,
	}, nil
}
|
|
243
|
+
|
|
244
|
+
// Warmup lazily loads the seq2seq model through the registry; it is a
// no-op once loaded. The registry key is the profile fingerprint, so
// identical configurations share a single loaded model.
//
// The encoder/decoder tensor-name lists must match the inspected ONNX
// graphs recorded at the top of this file. In particular the decoder
// input order is (encoder_attention_mask, input_ids,
// encoder_hidden_states) and must line up with the tensor order passed
// in runDecoderStep.
func (b *onnxLocalBackend) Warmup(_ context.Context) error {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.loaded != nil {
		return nil
	}
	loaded, err := b.registry.LoadSeq2Seq(model.Seq2SeqSpec{
		Key: b.spec.Profile.Fingerprint,
		Profile: model.Profile{
			Name:          b.spec.Profile.Fingerprint,
			Family:        b.spec.Profile.Family,
			Task:          model.TaskSummarization,
			MaxCtxTokens:  b.spec.MaxContextTokens,
			ModelPath:     b.spec.ModelPath,
			TokenizerPath: b.spec.TokenizerPath,
			OrtLibPath:    b.spec.RuntimePath,
		},
		RuntimePath:   b.spec.RuntimePath,
		ModelPath:     b.spec.ModelPath,
		EncoderPath:   b.spec.EncoderPath,
		DecoderPath:   b.spec.DecoderPath,
		TokenizerPath: b.spec.TokenizerPath,
		// Tensor names taken from the ONNX graph records above.
		EncoderInputs:  []string{t5EncoderInputIDs, t5EncoderAttnMask},
		EncoderOutputs: []string{t5EncoderHiddenState},
		DecoderInputs:  []string{t5DecoderEncoderMask, t5DecoderInputIDs, t5DecoderEncoderHidden},
		DecoderOutputs: []string{t5DecoderLogits},
	})
	if err != nil {
		return err
	}
	b.loaded = loaded
	return nil
}
|
|
277
|
+
|
|
278
|
+
func (b *onnxLocalBackend) Unload() {
|
|
279
|
+
b.mu.Lock()
|
|
280
|
+
defer b.mu.Unlock()
|
|
281
|
+
if b.loaded == nil {
|
|
282
|
+
return
|
|
283
|
+
}
|
|
284
|
+
_ = b.registry.Unload(b.spec.Profile.Fingerprint)
|
|
285
|
+
b.loaded = nil
|
|
286
|
+
}
|
|
287
|
+
|
|
288
|
+
// Close shuts the backend down by unloading any loaded model. It always
// returns nil; registry unload errors are swallowed inside Unload.
func (b *onnxLocalBackend) Close() error {
	b.Unload()
	return nil
}
|
|
292
|
+
|
|
293
|
+
// Profile returns the fingerprinted profile resolved at construction.
func (b *onnxLocalBackend) Profile() Profile { return b.spec.Profile }

// Ready always reports true: construction already validated paths and
// loaded the tokenizer, and the model itself loads lazily in Warmup.
// NOTE(review): this does not reflect whether Warmup has succeeded.
func (b *onnxLocalBackend) Ready() bool { return true }

// Reason returns "" because Ready is always true.
func (b *onnxLocalBackend) Reason() string { return "" }

// Mode identifies this backend as the local ONNX implementation.
func (b *onnxLocalBackend) Mode() string { return "onnx-local" }
|
|
297
|
+
|
|
298
|
+
func resolveManifestPath(modelPath string) (string, error) {
|
|
299
|
+
raw := strings.TrimSpace(modelPath)
|
|
300
|
+
if raw == "" {
|
|
301
|
+
return "", fmt.Errorf("onnx-local summarizer requires summarizerModelPath pointing to a model directory or summarizer.json")
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
info, err := os.Stat(raw)
|
|
305
|
+
if err == nil && info.IsDir() {
|
|
306
|
+
return filepath.Join(raw, defaultManifestName), nil
|
|
307
|
+
}
|
|
308
|
+
if err == nil && strings.EqualFold(filepath.Base(raw), defaultManifestName) {
|
|
309
|
+
return raw, nil
|
|
310
|
+
}
|
|
311
|
+
if err == nil {
|
|
312
|
+
return filepath.Join(filepath.Dir(raw), defaultManifestName), nil
|
|
313
|
+
}
|
|
314
|
+
return "", err
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
func readManifest(path string) (summaryManifest, error) {
|
|
318
|
+
data, err := os.ReadFile(path)
|
|
319
|
+
if err != nil {
|
|
320
|
+
return summaryManifest{}, fmt.Errorf("failed to read %s: %w", path, err)
|
|
321
|
+
}
|
|
322
|
+
var manifest summaryManifest
|
|
323
|
+
if err := json.Unmarshal(data, &manifest); err != nil {
|
|
324
|
+
return summaryManifest{}, fmt.Errorf("failed to parse %s: %w", path, err)
|
|
325
|
+
}
|
|
326
|
+
return manifest, nil
|
|
327
|
+
}
|
|
328
|
+
|
|
329
|
+
func resolveManifestAsset(baseDir string, values ...string) string {
|
|
330
|
+
asset := firstNonEmpty(values...)
|
|
331
|
+
if asset == "" {
|
|
332
|
+
return ""
|
|
333
|
+
}
|
|
334
|
+
if filepath.IsAbs(asset) {
|
|
335
|
+
return asset
|
|
336
|
+
}
|
|
337
|
+
return filepath.Join(baseDir, asset)
|
|
338
|
+
}
|
|
339
|
+
|
|
340
|
+
func buildProfile(profile Profile) Profile {
|
|
341
|
+
hash := sha256.Sum256([]byte(strings.Join([]string{
|
|
342
|
+
profile.Backend,
|
|
343
|
+
profile.Family,
|
|
344
|
+
profile.Model,
|
|
345
|
+
profile.ModelPath,
|
|
346
|
+
}, "|")))
|
|
347
|
+
profile.Fingerprint = hex.EncodeToString(hash[:8])
|
|
348
|
+
return profile
|
|
349
|
+
}
|
|
350
|
+
|
|
351
|
+
// runEncoder executes one encoder pass over the full input sequence and
// returns the [1, seqLen, 512] hidden states flattened to a []float32.
//
// Tensors are created per call and destroyed via defer; the output data
// is copied out because the backing buffer dies with the tensor.
func (b *onnxLocalBackend) runEncoder(inputIDs, attnMask []int64, seqLen int) ([]float32, error) {
	if b.loaded == nil {
		return nil, fmt.Errorf("onnx summarizer model not loaded")
	}
	// Shapes match the inspected encoder graph: inputs [1, seqLen],
	// output [1, seqLen, t5SmallHiddenDim].
	shape2D := ort.NewShape(1, int64(seqLen))
	shape3D := ort.NewShape(1, int64(seqLen), t5SmallHiddenDim)

	inIDs, err := ort.NewTensor(shape2D, inputIDs)
	if err != nil {
		return nil, fmt.Errorf("create encoder input_ids tensor: %w", err)
	}
	defer inIDs.Destroy()
	inMask, err := ort.NewTensor(shape2D, attnMask)
	if err != nil {
		return nil, fmt.Errorf("create encoder attention_mask tensor: %w", err)
	}
	defer inMask.Destroy()
	outHidden, err := ort.NewEmptyTensor[float32](shape3D)
	if err != nil {
		return nil, fmt.Errorf("create encoder output tensor: %w", err)
	}
	defer outHidden.Destroy()

	// Input order matches EncoderInputs configured in Warmup.
	if err := b.loaded.RunEncoder([]ort.Value{inIDs, inMask}, []ort.Value{outHidden}); err != nil {
		return nil, fmt.Errorf("encoder run: %w", err)
	}

	// Copy out: the tensor (and its buffer) is destroyed on return.
	data := outHidden.GetData()
	out := make([]float32, len(data))
	copy(out, data)
	return out, nil
}
|
|
383
|
+
|
|
384
|
+
// runDecoderStep runs the decoder once over the full prefix decIDs,
// conditioned on the encoder hidden states, and returns the flattened
// [1, len(decIDs), vocab] logits.
//
// The tensor order (mask, ids, hidden) must match the DecoderInputs list
// configured in Warmup, which in turn matches the inspected decoder
// graph's input order.
func (b *onnxLocalBackend) runDecoderStep(decIDs []int64, encHidden []float32, encMask []int64, seqLen int) ([]float32, error) {
	if b.loaded == nil {
		return nil, fmt.Errorf("onnx summarizer model not loaded")
	}
	decLen := len(decIDs)
	shape2DDec := ort.NewShape(1, int64(decLen))
	shape2DEnc := ort.NewShape(1, int64(seqLen))
	shape3D := ort.NewShape(1, int64(seqLen), t5SmallHiddenDim)
	shapeLogits := ort.NewShape(1, int64(decLen), t5SmallVocabSize)

	inIDs, err := ort.NewTensor(shape2DDec, decIDs)
	if err != nil {
		return nil, fmt.Errorf("create decoder input_ids tensor: %w", err)
	}
	defer inIDs.Destroy()
	inMask, err := ort.NewTensor(shape2DEnc, encMask)
	if err != nil {
		return nil, fmt.Errorf("create decoder encoder_attention_mask tensor: %w", err)
	}
	defer inMask.Destroy()
	inHidden, err := ort.NewTensor(shape3D, encHidden)
	if err != nil {
		return nil, fmt.Errorf("create decoder encoder_hidden_states tensor: %w", err)
	}
	defer inHidden.Destroy()
	outLogits, err := ort.NewEmptyTensor[float32](shapeLogits)
	if err != nil {
		return nil, fmt.Errorf("create decoder logits tensor: %w", err)
	}
	defer outLogits.Destroy()

	// Only logits are requested; the graph's present.* KV outputs are unused.
	if err := b.loaded.RunDecoder([]ort.Value{inMask, inIDs, inHidden}, []ort.Value{outLogits}); err != nil {
		return nil, fmt.Errorf("decoder step: %w", err)
	}

	// Copy out: the tensor's buffer is freed by the deferred Destroy.
	data := outLogits.GetData()
	out := make([]float32, len(data))
	copy(out, data)
	return out, nil
}
|
|
424
|
+
|
|
425
|
+
// decode greedily generates up to maxTokens tokens from the encoder
// hidden states, honoring ctx cancellation between steps. It returns the
// generated IDs (leading BOS stripped) and a confidence score: the
// geometric mean of the greedy tokens' softmax probabilities,
// exp(mean log p).
//
// NOTE(review): each step re-runs the decoder over the whole prefix with
// no KV cache (the graph's present.* outputs are not wired up), so
// generation is O(steps^2) in decoder compute — presumably acceptable
// for short summaries, but worth revisiting for long outputs.
func (b *onnxLocalBackend) decode(ctx context.Context, encHidden []float32, encMask []int64, seqLen int, maxTokens int) (ids []int64, confidence float64, err error) {
	decInput := []int64{b.tok.BOS()} // generation starts from the BOS token
	var logProbSum float64
	var tokenCount int

	for step := 0; step < maxTokens; step++ {
		// Non-blocking cancellation check between decoder steps.
		select {
		case <-ctx.Done():
			return nil, 0, ctx.Err()
		default:
		}

		logits, err := b.runDecoderStep(decInput, encHidden, encMask, seqLen)
		if err != nil {
			return nil, 0, err
		}

		// Only the last position's logits determine the next token.
		offset := (len(decInput) - 1) * t5SmallVocabSize
		lastLogits := logits[offset : offset+t5SmallVocabSize]

		nextToken, logProb := greedySelect(lastLogits)
		if nextToken == b.tok.EOS() {
			break
		}
		decInput = append(decInput, nextToken)
		logProbSum += logProb
		tokenCount++
	}

	// Geometric mean of token probabilities; 0 when nothing was generated.
	if tokenCount > 0 {
		confidence = math.Exp(logProbSum / float64(tokenCount))
	}
	// Strip the leading BOS; may be empty if EOS was the first prediction.
	return decInput[1:], confidence, nil
}
|
|
459
|
+
|
|
460
|
+
// greedySelect returns the argmax token of a single logits row together
// with its log-probability under softmax: logProb = logit - logSumExp.
// The log-sum-exp is shifted by the max logit for numerical stability so
// large positive logits cannot overflow math.Exp.
//
// Panics on an empty slice (callers always pass a vocab-sized row).
func greedySelect(logits []float32) (token int64, logProb float64) {
	// One pass finds the argmax; the max logit doubles as the stability
	// shift below. (The original scanned the slice three times: max,
	// sum-exp, then argmax again — the first and third passes computed
	// the same maximum.)
	best := int64(0)
	maxV := logits[0]
	for i, v := range logits {
		if v > maxV {
			maxV = v
			best = int64(i)
		}
	}

	var sumExp float64
	for _, v := range logits {
		sumExp += math.Exp(float64(v) - float64(maxV))
	}
	logSumExp := float64(maxV) + math.Log(sumExp)

	return best, float64(maxV) - logSumExp
}
|
|
484
|
+
|
|
485
|
+
func summarizeInput(turns []Turn) string {
|
|
486
|
+
parts := make([]string, 0, len(turns))
|
|
487
|
+
for _, turn := range turns {
|
|
488
|
+
text := strings.TrimSpace(turn.Text)
|
|
489
|
+
if text != "" {
|
|
490
|
+
parts = append(parts, text)
|
|
491
|
+
}
|
|
492
|
+
}
|
|
493
|
+
return strings.Join(parts, "\n")
|
|
494
|
+
}
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
package summarize
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"context"
|
|
5
|
+
"os"
|
|
6
|
+
"path/filepath"
|
|
7
|
+
"testing"
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
// TestDecodeIsDeterministic verifies that two Summarize calls on the same
// input produce byte-identical text and identical confidence — greedy
// decoding must be fully deterministic.
//
// The test skips unless the t5-small model and the ONNX runtime library
// are present under .models/.
// NOTE(review): the runtime path is hardcoded to the osx-arm64 build, so
// this test silently skips on every other platform — consider deriving
// the path from GOOS/GOARCH.
func TestDecodeIsDeterministic(t *testing.T) {
	modelDir := filepath.Clean(filepath.Join("..", "..", ".models", "t5-small"))
	runtimePath := filepath.Clean(filepath.Join("..", "..", ".models", "onnxruntime", "onnxruntime-osx-arm64-1.23.0", "lib", "libonnxruntime.dylib"))

	if _, err := os.Stat(filepath.Join(modelDir, "summarizer.json")); os.IsNotExist(err) {
		t.Skip("t5 summarizer model not present")
	}
	if _, err := os.Stat(runtimePath); os.IsNotExist(err) {
		t.Skip("onnx runtime not present")
	}

	engine := NewWithDeps(Config{
		Backend:     "onnx-local",
		Profile:     "t5-small",
		RuntimePath: runtimePath,
		ModelPath:   modelDir,
	}, Dependencies{})

	input := []Turn{
		{ID: "turn-1", Text: "The tower is 324 metres tall and located in Paris."},
	}
	opts := SummaryOpts{MinInputTurns: 1, MaxOutputTokens: 32}

	// Run the identical request twice and compare outputs.
	r1, err := engine.Summarize(context.Background(), input, opts)
	if err != nil {
		t.Fatal(err)
	}
	r2, err := engine.Summarize(context.Background(), input, opts)
	if err != nil {
		t.Fatal(err)
	}

	if r1.Text != r2.Text {
		t.Fatalf("non-deterministic output:\n run1: %q\n run2: %q", r1.Text, r2.Text)
	}
	if r1.Confidence != r2.Confidence {
		t.Fatalf("non-deterministic confidence: %f vs %f", r1.Confidence, r2.Confidence)
	}
}
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
package summarize
|
|
2
|
+
|
|
3
|
+
import "strings"
|
|
4
|
+
|
|
5
|
+
// modelProfile describes a summarization model this plugin ships support
// for: its identity, context budget, and (optionally) where to download
// its artifacts from.
type modelProfile struct {
	Name string // canonical lowercase profile name, also the map key
	Family string // model family label carried into the runtime Profile
	MaxContextTokens int // encoder context budget used when the manifest omits one
	Source modelSource // download info; zero value means not auto-downloadable
}
|
|
11
|
+
|
|
12
|
+
// modelSource points at the downloadable artifacts for a profile.
type modelSource struct {
	BaseURL string // base URL each file name is appended to
	Files []string // artifact file names to fetch
	SHA256 map[string]string // file name -> expected SHA-256 hex digest for verification
}
|
|
17
|
+
|
|
18
|
+
// shippedProfiles enumerates the summarization models this plugin knows
// about, keyed by lowercase profile name (see lookupProfile).
//
// The SHA256 map pins the exact artifacts served from BaseURL so
// downloads can be verified. distilbart-cnn-12-6 has no Source: it is
// usable when its files are supplied locally but cannot be
// auto-downloaded.
var shippedProfiles = map[string]modelProfile{
	"t5-small": {
		Name:             "t5-small",
		Family:           "t5-small",
		MaxContextTokens: 512,
		Source: modelSource{
			BaseURL: "https://huggingface.co/optimum/t5-small/resolve/main",
			Files: []string{
				"encoder_model.onnx",
				"decoder_model.onnx",
				"tokenizer.json",
				"tokenizer_config.json",
				"config.json",
			},
			SHA256: map[string]string{
				"encoder_model.onnx":    "41d326633f1b85f526508cc0db78a5d40877c292c1b6dccae2eacd7d2a53480d",
				"decoder_model.onnx":    "0a1451011d61bcc796a87b7306c503562e910f110f884d0cc08532972c2cc584",
				"tokenizer.json":        "5f0ed8ab5b8cfa9812bb73752f1d80c292e52bcf5a87a144dc9ab2d251056cbb",
				"tokenizer_config.json": "4969f8d76ef05a16553bd2b07b3501673ae8d36972aea88a0f78ad31a3ff2de9",
				"config.json":           "d112428e703aa7ea0d6b17a77e9739fcc15b87653779d9b7942d5ecbc61c00ed",
			},
		},
	},
	"distilbart-cnn-12-6": {
		Name:             "distilbart-cnn-12-6",
		Family:           "distilbart-cnn-12-6",
		MaxContextTokens: 1024,
	},
}
|
|
47
|
+
|
|
48
|
+
func lookupProfile(name string) (modelProfile, bool) {
|
|
49
|
+
name = strings.TrimSpace(strings.ToLower(name))
|
|
50
|
+
profile, ok := shippedProfiles[name]
|
|
51
|
+
return profile, ok
|
|
52
|
+
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
package summarize
|
|
2
|
+
|
|
3
|
+
// Tokenizer is the boundary between text and token IDs.
// All summarizer backends must operate through this interface.
// No backend may call a tokenizer implementation directly.
type Tokenizer interface {
	// Encode converts text into a sequence of token IDs.
	Encode(text string) ([]int64, error)
	// Decode converts token IDs back into text.
	Decode(ids []int64) (string, error)
	// VocabSize reports the tokenizer's vocabulary size.
	VocabSize() int
	// BOS returns the token ID used to start decoder generation.
	BOS() int64
	// EOS returns the token ID that terminates generation.
	EOS() int64
	// PAD returns the padding token ID.
	PAD() int64
}
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
package summarize
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"fmt"
|
|
5
|
+
|
|
6
|
+
"github.com/sugarme/tokenizer"
|
|
7
|
+
"github.com/sugarme/tokenizer/pretrained"
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
// hfTokenizer adapts the sugarme HuggingFace tokenizer to this package's
// Tokenizer interface, caching the special-token IDs resolved at load
// time.
type hfTokenizer struct {
	tk tokenizer.Tokenizer // underlying HF tokenizer, held by value
	bos int64 // decoder start token ID
	eos int64 // end-of-sequence token ID
	pad int64 // padding token ID
}
|
|
16
|
+
|
|
17
|
+
func newTokenizer(path string) (Tokenizer, error) {
|
|
18
|
+
tk, err := pretrained.FromFile(path)
|
|
19
|
+
if err != nil {
|
|
20
|
+
return nil, fmt.Errorf("failed to load tokenizer: %w", err)
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
pad, ok := firstTokenID(*tk, []string{"<pad>", "[PAD]"})
|
|
24
|
+
if !ok {
|
|
25
|
+
pad = 0
|
|
26
|
+
}
|
|
27
|
+
bos, ok := firstTokenID(*tk, []string{"<s>", "[CLS]"})
|
|
28
|
+
if !ok {
|
|
29
|
+
bos = pad
|
|
30
|
+
}
|
|
31
|
+
eos, ok := firstTokenID(*tk, []string{"</s>", "[SEP]", "<eos>"})
|
|
32
|
+
if !ok {
|
|
33
|
+
eos = 1
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
return &hfTokenizer{
|
|
37
|
+
tk: *tk,
|
|
38
|
+
bos: bos,
|
|
39
|
+
eos: eos,
|
|
40
|
+
pad: pad,
|
|
41
|
+
}, nil
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
func (t *hfTokenizer) Encode(text string) ([]int64, error) {
|
|
45
|
+
encoding, err := t.tk.EncodeSingle(text, true)
|
|
46
|
+
if err != nil {
|
|
47
|
+
return nil, err
|
|
48
|
+
}
|
|
49
|
+
out := make([]int64, len(encoding.Ids))
|
|
50
|
+
for i, id := range encoding.Ids {
|
|
51
|
+
out[i] = int64(id)
|
|
52
|
+
}
|
|
53
|
+
return out, nil
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
func (t *hfTokenizer) Decode(ids []int64) (string, error) {
|
|
57
|
+
raw := make([]int, len(ids))
|
|
58
|
+
for i, id := range ids {
|
|
59
|
+
raw[i] = int(id)
|
|
60
|
+
}
|
|
61
|
+
return t.tk.Decode(raw, true), nil
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
// VocabSize reports the underlying tokenizer's vocabulary size.
func (t *hfTokenizer) VocabSize() int { return t.tk.GetVocabSize(true) }

// BOS returns the decoder start token ID resolved in newTokenizer.
func (t *hfTokenizer) BOS() int64 { return t.bos }

// EOS returns the end-of-sequence token ID resolved in newTokenizer.
func (t *hfTokenizer) EOS() int64 { return t.eos }

// PAD returns the padding token ID resolved in newTokenizer.
func (t *hfTokenizer) PAD() int64 { return t.pad }
|
|
68
|
+
|
|
69
|
+
func firstTokenID(tk tokenizer.Tokenizer, names []string) (int64, bool) {
|
|
70
|
+
for _, name := range names {
|
|
71
|
+
if id, ok := tk.TokenToId(name); ok {
|
|
72
|
+
return int64(id), true
|
|
73
|
+
}
|
|
74
|
+
}
|
|
75
|
+
return 0, false
|
|
76
|
+
}
|