@xdarkicex/openclaw-memory-libravdb 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +46 -0
- package/docs/README.md +14 -0
- package/docs/architecture-decisions/README.md +6 -0
- package/docs/architecture-decisions/adr-001-onnx-over-ollama.md +21 -0
- package/docs/architecture-decisions/adr-002-libravdb-over-lancedb.md +19 -0
- package/docs/architecture-decisions/adr-003-convex-gating-over-threshold.md +27 -0
- package/docs/architecture-decisions/adr-004-sidecar-over-native-ts.md +21 -0
- package/docs/architecture.md +188 -0
- package/docs/contributing.md +76 -0
- package/docs/dependencies.md +38 -0
- package/docs/embedding-profiles.md +42 -0
- package/docs/gating.md +329 -0
- package/docs/implementation.md +381 -0
- package/docs/installation.md +272 -0
- package/docs/mathematics.md +695 -0
- package/docs/models.md +63 -0
- package/docs/problem.md +64 -0
- package/docs/security.md +86 -0
- package/openclaw.plugin.json +84 -0
- package/package.json +41 -0
- package/scripts/build-sidecar.sh +30 -0
- package/scripts/postinstall.js +169 -0
- package/scripts/setup.sh +20 -0
- package/scripts/setup.ts +505 -0
- package/scripts/sidecar-release.d.ts +4 -0
- package/scripts/sidecar-release.js +17 -0
- package/sidecar/cmd/inspect_onnx/main.go +105 -0
- package/sidecar/compact/gate.go +273 -0
- package/sidecar/compact/gate_test.go +85 -0
- package/sidecar/compact/summarize.go +345 -0
- package/sidecar/compact/summarize_test.go +319 -0
- package/sidecar/compact/tokens.go +11 -0
- package/sidecar/config/config.go +119 -0
- package/sidecar/config/config_test.go +75 -0
- package/sidecar/embed/engine.go +696 -0
- package/sidecar/embed/engine_test.go +349 -0
- package/sidecar/embed/matryoshka.go +93 -0
- package/sidecar/embed/matryoshka_test.go +150 -0
- package/sidecar/embed/onnx_local.go +319 -0
- package/sidecar/embed/onnx_local_test.go +159 -0
- package/sidecar/embed/profile_contract_test.go +71 -0
- package/sidecar/embed/profile_eval_test.go +923 -0
- package/sidecar/embed/profiles.go +39 -0
- package/sidecar/go.mod +21 -0
- package/sidecar/go.sum +30 -0
- package/sidecar/health/check.go +33 -0
- package/sidecar/health/check_test.go +55 -0
- package/sidecar/main.go +151 -0
- package/sidecar/model/encoder.go +222 -0
- package/sidecar/model/registry.go +262 -0
- package/sidecar/model/registry_test.go +102 -0
- package/sidecar/model/seq2seq.go +133 -0
- package/sidecar/server/rpc.go +343 -0
- package/sidecar/server/rpc_test.go +350 -0
- package/sidecar/server/transport.go +160 -0
- package/sidecar/store/libravdb.go +676 -0
- package/sidecar/store/libravdb_test.go +472 -0
- package/sidecar/summarize/engine.go +360 -0
- package/sidecar/summarize/engine_test.go +148 -0
- package/sidecar/summarize/onnx_local.go +494 -0
- package/sidecar/summarize/onnx_local_test.go +48 -0
- package/sidecar/summarize/profiles.go +52 -0
- package/sidecar/summarize/tokenizer.go +13 -0
- package/sidecar/summarize/tokenizer_hf.go +76 -0
- package/sidecar/summarize/util.go +13 -0
- package/src/cli.ts +205 -0
- package/src/context-engine.ts +195 -0
- package/src/index.ts +27 -0
- package/src/memory-provider.ts +24 -0
- package/src/openclaw-plugin-sdk.d.ts +53 -0
- package/src/plugin-runtime.ts +67 -0
- package/src/recall-cache.ts +34 -0
- package/src/recall-utils.ts +22 -0
- package/src/rpc.ts +84 -0
- package/src/scoring.ts +58 -0
- package/src/sidecar.ts +506 -0
- package/src/tokens.ts +36 -0
- package/src/types.ts +146 -0
- package/tsconfig.json +20 -0
- package/tsconfig.tests.json +12 -0
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
package embed
|
|
2
|
+
|
|
3
|
+
import "strings"
|
|
4
|
+
|
|
5
|
+
// DefaultEmbeddingProfile is the profile name "nomic-embed-text-v1.5", which
// is also a key in shippedProfiles. Judging by its name it is the profile
// preferred when none is configured; confirm against the callers.
const DefaultEmbeddingProfile = "nomic-embed-text-v1.5"

// FallbackEmbeddingProfile is the profile name "all-minilm-l6-v2", which is
// also a key in shippedProfiles. Judging by its name it is the profile used
// when the default one cannot be served; confirm against the callers.
const FallbackEmbeddingProfile = "all-minilm-l6-v2"
|
|
9
|
+
|
|
10
|
+
// modelProfile holds the static description of one embedding model that is
// registered in shippedProfiles.
type modelProfile struct {
	// Name and Family identify the profile. Both shipped entries set the two
	// fields to the same string.
	Name, Family string
	// Dimensions is the embedding width recorded for the model.
	Dimensions int
	// Normalize is the normalization flag recorded for the model.
	// NOTE(review): presumably requests vector normalization of the output;
	// confirm where the flag is consumed.
	Normalize bool
	// MaxContextTokens is the token limit recorded for the model.
	MaxContextTokens int
}
|
|
17
|
+
|
|
18
|
+
var shippedProfiles = map[string]modelProfile{
|
|
19
|
+
"all-minilm-l6-v2": {
|
|
20
|
+
Name: "all-minilm-l6-v2",
|
|
21
|
+
Family: "all-minilm-l6-v2",
|
|
22
|
+
Dimensions: 384,
|
|
23
|
+
Normalize: true,
|
|
24
|
+
MaxContextTokens: 128,
|
|
25
|
+
},
|
|
26
|
+
"nomic-embed-text-v1.5": {
|
|
27
|
+
Name: "nomic-embed-text-v1.5",
|
|
28
|
+
Family: "nomic-embed-text-v1.5",
|
|
29
|
+
Dimensions: 768,
|
|
30
|
+
Normalize: true,
|
|
31
|
+
MaxContextTokens: 8192,
|
|
32
|
+
},
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
func lookupProfile(name string) (modelProfile, bool) {
|
|
36
|
+
name = strings.TrimSpace(strings.ToLower(name))
|
|
37
|
+
profile, ok := shippedProfiles[name]
|
|
38
|
+
return profile, ok
|
|
39
|
+
}
|
package/sidecar/go.mod
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
module github.com/xDarkicex/openclaw-memory-libravdb/sidecar
|
|
2
|
+
|
|
3
|
+
go 1.25.1
|
|
4
|
+
|
|
5
|
+
require (
|
|
6
|
+
github.com/sugarme/tokenizer v0.3.0
|
|
7
|
+
github.com/yalue/onnxruntime_go v1.21.0
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
require (
|
|
11
|
+
github.com/emirpasic/gods v1.18.1 // indirect
|
|
12
|
+
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
|
|
13
|
+
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
|
14
|
+
github.com/rivo/uniseg v0.4.7 // indirect
|
|
15
|
+
github.com/schollz/progressbar/v2 v2.15.0 // indirect
|
|
16
|
+
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c // indirect
|
|
17
|
+
golang.org/x/text v0.25.0 // indirect
|
|
18
|
+
gopkg.in/yaml.v3 v3.0.1 // indirect
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
replace github.com/sugarme/tokenizer => github.com/clems4ever/tokenizer v0.0.0-20250926133620-9ddc80533c43
|
package/sidecar/go.sum
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
github.com/clems4ever/tokenizer v0.0.0-20250926133620-9ddc80533c43 h1:j8YQypEqa5OjqbGciCNb9hOcYbo1oTVuEjd/iu9U2SY=
|
|
2
|
+
github.com/clems4ever/tokenizer v0.0.0-20250926133620-9ddc80533c43/go.mod h1:VJ+DLK5ZEZwzvODOWwY0cw+B1dabTd3nCB5HuFCItCc=
|
|
3
|
+
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
|
4
|
+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
|
5
|
+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
|
6
|
+
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
|
7
|
+
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
|
8
|
+
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
|
|
9
|
+
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
|
|
10
|
+
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
|
11
|
+
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
|
12
|
+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
|
13
|
+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
|
14
|
+
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
|
15
|
+
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
|
16
|
+
github.com/schollz/progressbar/v2 v2.15.0 h1:dVzHQ8fHRmtPjD3K10jT3Qgn/+H+92jhPrhmxIJfDz8=
|
|
17
|
+
github.com/schollz/progressbar/v2 v2.15.0/go.mod h1:UdPq3prGkfQ7MOzZKlDRpYKcFqEMczbD7YmbPgpzKMI=
|
|
18
|
+
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
|
19
|
+
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
|
20
|
+
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
|
21
|
+
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
|
22
|
+
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c h1:pwb4kNSHb4K89ymCaN+5lPH/MwnfSVg4rzGDh4d+iy4=
|
|
23
|
+
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw=
|
|
24
|
+
github.com/yalue/onnxruntime_go v1.21.0 h1:DdtvfY7OP5gR8mwPDqAOAQckf+KcI30hPNJL8hQaYWI=
|
|
25
|
+
github.com/yalue/onnxruntime_go v1.21.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
|
|
26
|
+
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
|
27
|
+
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
|
28
|
+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
|
29
|
+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|
30
|
+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
package health
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/embed"
|
|
5
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/store"
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
// Status is the result of a health check as produced by Check.
type Status struct {
	// OK is true only when every check passed; serialized as "ok".
	OK bool `json:"ok"`
	// Message is "ok" on success, otherwise a description of the first
	// failing check; serialized as "message".
	Message string `json:"message"`
}
|
|
12
|
+
|
|
13
|
+
func Check(embedder embed.Embedder, st *store.Store) Status {
|
|
14
|
+
if embedder == nil {
|
|
15
|
+
return Status{OK: false, Message: "embedder unavailable"}
|
|
16
|
+
}
|
|
17
|
+
if !embedder.Ready() {
|
|
18
|
+
if embedder.Reason() != "" {
|
|
19
|
+
return Status{OK: false, Message: embedder.Reason()}
|
|
20
|
+
}
|
|
21
|
+
return Status{OK: false, Message: "embedder not ready"}
|
|
22
|
+
}
|
|
23
|
+
if embedder.Mode() == "fallback" {
|
|
24
|
+
if embedder.Reason() != "" {
|
|
25
|
+
return Status{OK: false, Message: embedder.Reason()}
|
|
26
|
+
}
|
|
27
|
+
return Status{OK: false, Message: "embedder running in deterministic fallback mode"}
|
|
28
|
+
}
|
|
29
|
+
if st == nil {
|
|
30
|
+
return Status{OK: false, Message: "store unavailable"}
|
|
31
|
+
}
|
|
32
|
+
return Status{OK: true, Message: "ok"}
|
|
33
|
+
}
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
package health
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"context"
|
|
5
|
+
"testing"
|
|
6
|
+
|
|
7
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/embed"
|
|
8
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/store"
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
type fakeEmbedder struct {
|
|
12
|
+
ready bool
|
|
13
|
+
reason string
|
|
14
|
+
mode string
|
|
15
|
+
profile embed.Profile
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
func (f fakeEmbedder) EmbedDocument(context.Context, string) ([]float32, error) {
|
|
19
|
+
return make([]float32, 1), nil
|
|
20
|
+
}
|
|
21
|
+
func (f fakeEmbedder) EmbedQuery(context.Context, string) ([]float32, error) {
|
|
22
|
+
return make([]float32, 1), nil
|
|
23
|
+
}
|
|
24
|
+
func (f fakeEmbedder) Dimensions() int { return 1 }
|
|
25
|
+
func (f fakeEmbedder) Profile() embed.Profile { return f.profile }
|
|
26
|
+
func (f fakeEmbedder) Ready() bool { return f.ready }
|
|
27
|
+
func (f fakeEmbedder) Reason() string { return f.reason }
|
|
28
|
+
func (f fakeEmbedder) Mode() string { return f.mode }
|
|
29
|
+
|
|
30
|
+
func TestCheckRejectsFallbackEmbedder(t *testing.T) {
|
|
31
|
+
status := Check(fakeEmbedder{ready: true, mode: "fallback", reason: "bundled embedder unavailable"}, &store.Store{})
|
|
32
|
+
if status.OK {
|
|
33
|
+
t.Fatalf("expected fallback embedder to fail health")
|
|
34
|
+
}
|
|
35
|
+
if status.Message != "bundled embedder unavailable" {
|
|
36
|
+
t.Fatalf("unexpected message %q", status.Message)
|
|
37
|
+
}
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
func TestCheckRejectsNotReadyEmbedderWithReason(t *testing.T) {
|
|
41
|
+
status := Check(fakeEmbedder{ready: false, reason: "missing onnx runtime"}, &store.Store{})
|
|
42
|
+
if status.OK {
|
|
43
|
+
t.Fatalf("expected not-ready embedder to fail health")
|
|
44
|
+
}
|
|
45
|
+
if status.Message != "missing onnx runtime" {
|
|
46
|
+
t.Fatalf("unexpected message %q", status.Message)
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
func TestCheckAcceptsReadyPrimaryEmbedder(t *testing.T) {
|
|
51
|
+
status := Check(fakeEmbedder{ready: true, mode: "primary"}, &store.Store{})
|
|
52
|
+
if !status.OK {
|
|
53
|
+
t.Fatalf("expected primary embedder to pass health, got %+v", status)
|
|
54
|
+
}
|
|
55
|
+
}
|
package/sidecar/main.go
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
package main
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"context"
|
|
5
|
+
"errors"
|
|
6
|
+
"fmt"
|
|
7
|
+
"os"
|
|
8
|
+
"os/signal"
|
|
9
|
+
"syscall"
|
|
10
|
+
|
|
11
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/compact"
|
|
12
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/config"
|
|
13
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/embed"
|
|
14
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/health"
|
|
15
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/model"
|
|
16
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/server"
|
|
17
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/store"
|
|
18
|
+
"github.com/xDarkicex/openclaw-memory-libravdb/sidecar/summarize"
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
func main() {
|
|
22
|
+
cfg := config.FromEnv()
|
|
23
|
+
if err := preflightONNXRuntime(cfg); err != nil {
|
|
24
|
+
fmt.Fprintln(os.Stderr, err.Error())
|
|
25
|
+
os.Exit(1)
|
|
26
|
+
}
|
|
27
|
+
embedder := embed.NewWithConfig(embed.Config{
|
|
28
|
+
Backend: cfg.EmbeddingBackend,
|
|
29
|
+
Profile: cfg.EmbeddingProfile,
|
|
30
|
+
FallbackProfile: cfg.FallbackProfile,
|
|
31
|
+
RuntimePath: cfg.ONNXRuntimePath,
|
|
32
|
+
ModelPath: cfg.EmbeddingModelPath,
|
|
33
|
+
TokenizerPath: cfg.EmbeddingTokenizerPath,
|
|
34
|
+
Dimensions: cfg.EmbeddingDimensions,
|
|
35
|
+
Normalize: cfg.EmbeddingNormalize,
|
|
36
|
+
})
|
|
37
|
+
summarizerRuntimePath := cfg.SummarizerRuntimePath
|
|
38
|
+
if summarizerRuntimePath == "" {
|
|
39
|
+
summarizerRuntimePath = cfg.ONNXRuntimePath
|
|
40
|
+
}
|
|
41
|
+
extractive := summarize.NewExtractive(embedder, "extractive")
|
|
42
|
+
configuredSummarizer := summarize.NewWithDeps(summarize.Config{
|
|
43
|
+
Backend: cfg.SummarizerBackend,
|
|
44
|
+
Profile: cfg.SummarizerProfile,
|
|
45
|
+
RuntimePath: summarizerRuntimePath,
|
|
46
|
+
ModelPath: cfg.SummarizerModelPath,
|
|
47
|
+
TokenizerPath: cfg.SummarizerTokenizerPath,
|
|
48
|
+
Model: cfg.SummarizerModel,
|
|
49
|
+
Endpoint: cfg.SummarizerEndpoint,
|
|
50
|
+
}, summarize.Dependencies{
|
|
51
|
+
Embedder: embedder,
|
|
52
|
+
Registry: model.DefaultRegistry(),
|
|
53
|
+
})
|
|
54
|
+
var abstractive summarize.Summarizer
|
|
55
|
+
if configuredSummarizer != nil && configuredSummarizer.Ready() && configuredSummarizer.Mode() != "extractive" {
|
|
56
|
+
abstractive = configuredSummarizer
|
|
57
|
+
}
|
|
58
|
+
st, err := store.Open(cfg.DBPath, embedder)
|
|
59
|
+
if err != nil {
|
|
60
|
+
fmt.Fprintln(os.Stderr, err.Error())
|
|
61
|
+
os.Exit(1)
|
|
62
|
+
}
|
|
63
|
+
if err := st.BackfillDirtyTiers(context.Background()); err != nil {
|
|
64
|
+
fmt.Fprintln(os.Stderr, err.Error())
|
|
65
|
+
os.Exit(1)
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
status := health.Check(embedder, st)
|
|
69
|
+
if !status.OK {
|
|
70
|
+
fmt.Fprintln(os.Stderr, status.Message)
|
|
71
|
+
os.Exit(1)
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
|
75
|
+
defer stop()
|
|
76
|
+
|
|
77
|
+
srv := server.New(embedder, extractive, abstractive, st, compact.GatingConfig{
|
|
78
|
+
W1c: cfg.GatingW1c,
|
|
79
|
+
W2c: cfg.GatingW2c,
|
|
80
|
+
W3c: cfg.GatingW3c,
|
|
81
|
+
W1t: cfg.GatingW1t,
|
|
82
|
+
W2t: cfg.GatingW2t,
|
|
83
|
+
W3t: cfg.GatingW3t,
|
|
84
|
+
TechNorm: cfg.GatingTechNorm,
|
|
85
|
+
Threshold: cfg.GatingThreshold,
|
|
86
|
+
})
|
|
87
|
+
listener, endpoint, cleanup, err := server.Listen()
|
|
88
|
+
if err != nil {
|
|
89
|
+
fmt.Fprintln(os.Stderr, err.Error())
|
|
90
|
+
os.Exit(1)
|
|
91
|
+
}
|
|
92
|
+
defer cleanup()
|
|
93
|
+
|
|
94
|
+
fmt.Println(endpoint)
|
|
95
|
+
if err := server.Serve(ctx, listener, srv); err != nil && !errors.Is(err, server.ErrServerClosed) {
|
|
96
|
+
fmt.Fprintln(os.Stderr, err.Error())
|
|
97
|
+
os.Exit(1)
|
|
98
|
+
}
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
func preflightONNXRuntime(cfg config.Config) error {
|
|
102
|
+
paths := make([]string, 0, 2)
|
|
103
|
+
|
|
104
|
+
embeddingPath, err := resolvedRuntimePath(cfg.EmbeddingBackend, cfg.ONNXRuntimePath)
|
|
105
|
+
if err != nil {
|
|
106
|
+
return err
|
|
107
|
+
}
|
|
108
|
+
if embeddingPath != "" {
|
|
109
|
+
paths = append(paths, embeddingPath)
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
summarizerRuntimePath := cfg.SummarizerRuntimePath
|
|
113
|
+
if summarizerRuntimePath == "" {
|
|
114
|
+
summarizerRuntimePath = cfg.ONNXRuntimePath
|
|
115
|
+
}
|
|
116
|
+
summarizerPath, err := resolvedRuntimePath(cfg.SummarizerBackend, summarizerRuntimePath)
|
|
117
|
+
if err != nil {
|
|
118
|
+
return err
|
|
119
|
+
}
|
|
120
|
+
if summarizerPath != "" {
|
|
121
|
+
paths = append(paths, summarizerPath)
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
seen := map[string]struct{}{}
|
|
125
|
+
for _, runtimePath := range paths {
|
|
126
|
+
if _, ok := seen[runtimePath]; ok {
|
|
127
|
+
continue
|
|
128
|
+
}
|
|
129
|
+
seen[runtimePath] = struct{}{}
|
|
130
|
+
if _, err := os.Stat(runtimePath); err != nil {
|
|
131
|
+
if os.IsNotExist(err) {
|
|
132
|
+
return fmt.Errorf("ONNX Runtime library not found at %s\nRun scripts/setup.sh to unpack it.", runtimePath)
|
|
133
|
+
}
|
|
134
|
+
return fmt.Errorf("failed to stat ONNX Runtime library %s: %w", runtimePath, err)
|
|
135
|
+
}
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
return nil
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
func resolvedRuntimePath(backend, explicit string) (string, error) {
|
|
142
|
+
switch backend {
|
|
143
|
+
case "", "bundled", "onnx-local":
|
|
144
|
+
return embed.ResolveRuntimePath(embed.Config{
|
|
145
|
+
Backend: backend,
|
|
146
|
+
RuntimePath: explicit,
|
|
147
|
+
})
|
|
148
|
+
default:
|
|
149
|
+
return "", nil
|
|
150
|
+
}
|
|
151
|
+
}
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
package model
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"fmt"
|
|
5
|
+
"os"
|
|
6
|
+
"strings"
|
|
7
|
+
|
|
8
|
+
"github.com/sugarme/tokenizer"
|
|
9
|
+
"github.com/sugarme/tokenizer/pretrained"
|
|
10
|
+
ort "github.com/yalue/onnxruntime_go"
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
type EncoderSpec struct {
|
|
14
|
+
Key string
|
|
15
|
+
Profile Profile
|
|
16
|
+
RuntimePath string
|
|
17
|
+
ModelPath string
|
|
18
|
+
TokenizerPath string
|
|
19
|
+
InputNames []string
|
|
20
|
+
OutputName string
|
|
21
|
+
Dimensions int
|
|
22
|
+
AddSpecialTokens bool
|
|
23
|
+
Pooling string
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
type EncoderModel struct {
|
|
27
|
+
key string
|
|
28
|
+
registry *Registry
|
|
29
|
+
tokenizer tokenizer.Tokenizer
|
|
30
|
+
session *ort.DynamicAdvancedSession
|
|
31
|
+
inputNames []string
|
|
32
|
+
dimensions int
|
|
33
|
+
addSpecialTokens bool
|
|
34
|
+
pooling string
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
func (r *Registry) LoadEncoder(spec EncoderSpec) (*EncoderModel, error) {
|
|
38
|
+
r.mu.Lock()
|
|
39
|
+
defer r.mu.Unlock()
|
|
40
|
+
|
|
41
|
+
if err := r.ensureRuntimeLocked(strings.TrimSpace(spec.RuntimePath)); err != nil {
|
|
42
|
+
return nil, fmt.Errorf("failed to initialize onnx runtime: %w", err)
|
|
43
|
+
}
|
|
44
|
+
if spec.Key == "" {
|
|
45
|
+
spec.Key = spec.Profile.Name
|
|
46
|
+
}
|
|
47
|
+
if loaded, ok := r.loaded[spec.Key]; ok && loaded.encoder != nil {
|
|
48
|
+
loaded.lastAccess = timeNow()
|
|
49
|
+
return loaded.encoder, nil
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
tk, err := pretrained.FromFile(spec.TokenizerPath)
|
|
53
|
+
if err != nil {
|
|
54
|
+
return nil, fmt.Errorf("failed to load tokenizer: %w", err)
|
|
55
|
+
}
|
|
56
|
+
session, err := ort.NewDynamicAdvancedSession(spec.ModelPath, spec.InputNames, []string{spec.OutputName}, nil)
|
|
57
|
+
if err != nil {
|
|
58
|
+
return nil, fmt.Errorf("failed to create onnx session: %w", err)
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
encoder := &EncoderModel{
|
|
62
|
+
key: spec.Key,
|
|
63
|
+
registry: r,
|
|
64
|
+
tokenizer: *tk,
|
|
65
|
+
session: session,
|
|
66
|
+
inputNames: append([]string(nil), spec.InputNames...),
|
|
67
|
+
dimensions: spec.Dimensions,
|
|
68
|
+
addSpecialTokens: spec.AddSpecialTokens,
|
|
69
|
+
pooling: spec.Pooling,
|
|
70
|
+
}
|
|
71
|
+
r.loaded[spec.Key] = &loadedModel{
|
|
72
|
+
key: spec.Key,
|
|
73
|
+
profile: spec.Profile,
|
|
74
|
+
lastAccess: timeNow(),
|
|
75
|
+
useCount: 0,
|
|
76
|
+
reservedBytes: fileSize(spec.ModelPath) + fileSize(spec.TokenizerPath),
|
|
77
|
+
closeFn: session.Destroy,
|
|
78
|
+
encoder: encoder,
|
|
79
|
+
}
|
|
80
|
+
if err := r.maybeEvictLocked(timeNow()); err != nil {
|
|
81
|
+
return nil, err
|
|
82
|
+
}
|
|
83
|
+
return encoder, nil
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
func (m *EncoderModel) EmbedText(text string) ([]float32, error) {
|
|
87
|
+
encoding, err := m.EncodeText(text, m.addSpecialTokens)
|
|
88
|
+
if err != nil {
|
|
89
|
+
return nil, err
|
|
90
|
+
}
|
|
91
|
+
return m.EmbedEncoding(*encoding)
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
func (m *EncoderModel) TokenCount(text string) (int, error) {
|
|
95
|
+
encoding, err := m.EncodeText(text, m.addSpecialTokens)
|
|
96
|
+
if err != nil {
|
|
97
|
+
return 0, err
|
|
98
|
+
}
|
|
99
|
+
return len(encoding.Ids), nil
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
func (m *EncoderModel) EncodeText(text string, addSpecialTokens bool) (*tokenizer.Encoding, error) {
|
|
103
|
+
m.registry.mu.Lock()
|
|
104
|
+
m.registry.touchLocked(m.key)
|
|
105
|
+
m.registry.mu.Unlock()
|
|
106
|
+
|
|
107
|
+
input := tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(text))
|
|
108
|
+
encoding, err := m.tokenizer.Encode(input, addSpecialTokens)
|
|
109
|
+
if err != nil {
|
|
110
|
+
return nil, fmt.Errorf("failed to tokenize sentence: %w", err)
|
|
111
|
+
}
|
|
112
|
+
if encoding == nil || len(encoding.Ids) == 0 {
|
|
113
|
+
return nil, fmt.Errorf("tokenizer returned no encodings")
|
|
114
|
+
}
|
|
115
|
+
return encoding, nil
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
func (m *EncoderModel) EmbedEncoding(encoding tokenizer.Encoding) ([]float32, error) {
|
|
119
|
+
batchSize := 1
|
|
120
|
+
seqLength := len(encoding.Ids)
|
|
121
|
+
inputShape := ort.NewShape(int64(batchSize), int64(seqLength))
|
|
122
|
+
inputs := make([]ort.Value, 0, len(m.inputNames))
|
|
123
|
+
encodings := []tokenizer.Encoding{encoding}
|
|
124
|
+
|
|
125
|
+
for _, name := range m.inputNames {
|
|
126
|
+
data, err := inputTensorData(name, encodings, seqLength)
|
|
127
|
+
if err != nil {
|
|
128
|
+
return nil, err
|
|
129
|
+
}
|
|
130
|
+
tensor, err := ort.NewTensor(inputShape, data)
|
|
131
|
+
if err != nil {
|
|
132
|
+
return nil, fmt.Errorf("failed creating %s tensor: %w", name, err)
|
|
133
|
+
}
|
|
134
|
+
defer tensor.Destroy()
|
|
135
|
+
inputs = append(inputs, tensor)
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
outputShape := ort.NewShape(int64(batchSize), int64(m.dimensions))
|
|
139
|
+
useMeanPooling := m.pooling == "mean"
|
|
140
|
+
if useMeanPooling {
|
|
141
|
+
outputShape = ort.NewShape(int64(batchSize), int64(seqLength), int64(m.dimensions))
|
|
142
|
+
}
|
|
143
|
+
outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
|
|
144
|
+
if err != nil {
|
|
145
|
+
return nil, fmt.Errorf("failed creating output tensor: %w", err)
|
|
146
|
+
}
|
|
147
|
+
defer outputTensor.Destroy()
|
|
148
|
+
|
|
149
|
+
if err := m.session.Run(inputs, []ort.Value{outputTensor}); err != nil {
|
|
150
|
+
return nil, fmt.Errorf("failed to run onnx session: %w", err)
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
flat := outputTensor.GetData()
|
|
154
|
+
expected := batchSize * m.dimensions
|
|
155
|
+
if useMeanPooling {
|
|
156
|
+
expected = batchSize * seqLength * m.dimensions
|
|
157
|
+
}
|
|
158
|
+
if len(flat) != expected {
|
|
159
|
+
return nil, fmt.Errorf("unexpected output tensor size: got %d elements, expected %d", len(flat), expected)
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
if useMeanPooling {
|
|
163
|
+
return meanPoolLastHiddenState(flat, encoding.AttentionMask, seqLength, m.dimensions), nil
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
vec := make([]float32, m.dimensions)
|
|
167
|
+
copy(vec, flat[:m.dimensions])
|
|
168
|
+
return vec, nil
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
func inputTensorData(name string, encodings []tokenizer.Encoding, seqLength int) ([]int64, error) {
|
|
172
|
+
data := make([]int64, len(encodings)*seqLength)
|
|
173
|
+
for batch := range encodings {
|
|
174
|
+
switch name {
|
|
175
|
+
case "input_ids":
|
|
176
|
+
for i, id := range encodings[batch].Ids {
|
|
177
|
+
data[batch*seqLength+i] = int64(id)
|
|
178
|
+
}
|
|
179
|
+
case "attention_mask":
|
|
180
|
+
for i, mask := range encodings[batch].AttentionMask {
|
|
181
|
+
data[batch*seqLength+i] = int64(mask)
|
|
182
|
+
}
|
|
183
|
+
case "token_type_ids":
|
|
184
|
+
for i, typeID := range encodings[batch].TypeIds {
|
|
185
|
+
data[batch*seqLength+i] = int64(typeID)
|
|
186
|
+
}
|
|
187
|
+
default:
|
|
188
|
+
return nil, fmt.Errorf("unsupported input tensor name %q", name)
|
|
189
|
+
}
|
|
190
|
+
}
|
|
191
|
+
return data, nil
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
// meanPoolLastHiddenState averages per-token vectors into a single vector of
// length dimensions.
//
// flat is read as seqLength consecutive rows of dimensions values each. Only
// tokens whose attentionMask entry is non-zero contribute; tokens beyond the
// end of attentionMask are ignored. When no token contributes, the result is
// the all-zero vector.
func meanPoolLastHiddenState(flat []float32, attentionMask []int, seqLength int, dimensions int) []float32 {
	pooled := make([]float32, dimensions)

	var attended float32
	for tokenIdx := 0; tokenIdx < seqLength && tokenIdx < len(attentionMask); tokenIdx++ {
		if attentionMask[tokenIdx] == 0 {
			continue
		}
		attended++
		offset := tokenIdx * dimensions
		for dim := range pooled {
			pooled[dim] += flat[offset+dim]
		}
	}

	// Divide only when at least one token was attended, avoiding 0/0.
	if attended != 0 {
		for dim := range pooled {
			pooled[dim] /= attended
		}
	}
	return pooled
}
|
|
215
|
+
|
|
216
|
+
// fileSize returns the size in bytes reported by os.Stat for path, or 0 when
// the file cannot be stat'ed (for example because it does not exist).
func fileSize(path string) int64 {
	if info, err := os.Stat(path); err == nil {
		return info.Size()
	}
	return 0
}
|