@xdarkicex/openclaw-memory-libravdb 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. package/README.md +46 -0
  2. package/docs/README.md +14 -0
  3. package/docs/architecture-decisions/README.md +6 -0
  4. package/docs/architecture-decisions/adr-001-onnx-over-ollama.md +21 -0
  5. package/docs/architecture-decisions/adr-002-libravdb-over-lancedb.md +19 -0
  6. package/docs/architecture-decisions/adr-003-convex-gating-over-threshold.md +27 -0
  7. package/docs/architecture-decisions/adr-004-sidecar-over-native-ts.md +21 -0
  8. package/docs/architecture.md +188 -0
  9. package/docs/contributing.md +76 -0
  10. package/docs/dependencies.md +38 -0
  11. package/docs/embedding-profiles.md +42 -0
  12. package/docs/gating.md +329 -0
  13. package/docs/implementation.md +381 -0
  14. package/docs/installation.md +272 -0
  15. package/docs/mathematics.md +695 -0
  16. package/docs/models.md +63 -0
  17. package/docs/problem.md +64 -0
  18. package/docs/security.md +86 -0
  19. package/openclaw.plugin.json +84 -0
  20. package/package.json +41 -0
  21. package/scripts/build-sidecar.sh +30 -0
  22. package/scripts/postinstall.js +169 -0
  23. package/scripts/setup.sh +20 -0
  24. package/scripts/setup.ts +505 -0
  25. package/scripts/sidecar-release.d.ts +4 -0
  26. package/scripts/sidecar-release.js +17 -0
  27. package/sidecar/cmd/inspect_onnx/main.go +105 -0
  28. package/sidecar/compact/gate.go +273 -0
  29. package/sidecar/compact/gate_test.go +85 -0
  30. package/sidecar/compact/summarize.go +345 -0
  31. package/sidecar/compact/summarize_test.go +319 -0
  32. package/sidecar/compact/tokens.go +11 -0
  33. package/sidecar/config/config.go +119 -0
  34. package/sidecar/config/config_test.go +75 -0
  35. package/sidecar/embed/engine.go +696 -0
  36. package/sidecar/embed/engine_test.go +349 -0
  37. package/sidecar/embed/matryoshka.go +93 -0
  38. package/sidecar/embed/matryoshka_test.go +150 -0
  39. package/sidecar/embed/onnx_local.go +319 -0
  40. package/sidecar/embed/onnx_local_test.go +159 -0
  41. package/sidecar/embed/profile_contract_test.go +71 -0
  42. package/sidecar/embed/profile_eval_test.go +923 -0
  43. package/sidecar/embed/profiles.go +39 -0
  44. package/sidecar/go.mod +21 -0
  45. package/sidecar/go.sum +30 -0
  46. package/sidecar/health/check.go +33 -0
  47. package/sidecar/health/check_test.go +55 -0
  48. package/sidecar/main.go +151 -0
  49. package/sidecar/model/encoder.go +222 -0
  50. package/sidecar/model/registry.go +262 -0
  51. package/sidecar/model/registry_test.go +102 -0
  52. package/sidecar/model/seq2seq.go +133 -0
  53. package/sidecar/server/rpc.go +343 -0
  54. package/sidecar/server/rpc_test.go +350 -0
  55. package/sidecar/server/transport.go +160 -0
  56. package/sidecar/store/libravdb.go +676 -0
  57. package/sidecar/store/libravdb_test.go +472 -0
  58. package/sidecar/summarize/engine.go +360 -0
  59. package/sidecar/summarize/engine_test.go +148 -0
  60. package/sidecar/summarize/onnx_local.go +494 -0
  61. package/sidecar/summarize/onnx_local_test.go +48 -0
  62. package/sidecar/summarize/profiles.go +52 -0
  63. package/sidecar/summarize/tokenizer.go +13 -0
  64. package/sidecar/summarize/tokenizer_hf.go +76 -0
  65. package/sidecar/summarize/util.go +13 -0
  66. package/src/cli.ts +205 -0
  67. package/src/context-engine.ts +195 -0
  68. package/src/index.ts +27 -0
  69. package/src/memory-provider.ts +24 -0
  70. package/src/openclaw-plugin-sdk.d.ts +53 -0
  71. package/src/plugin-runtime.ts +67 -0
  72. package/src/recall-cache.ts +34 -0
  73. package/src/recall-utils.ts +22 -0
  74. package/src/rpc.ts +84 -0
  75. package/src/scoring.ts +58 -0
  76. package/src/sidecar.ts +506 -0
  77. package/src/tokens.ts +36 -0
  78. package/src/types.ts +146 -0
  79. package/tsconfig.json +20 -0
  80. package/tsconfig.tests.json +12 -0
@@ -0,0 +1,39 @@
1
+ package embed
2
+
3
+ import "strings"
4
+
5
// DefaultEmbeddingProfile names the primary embedding profile; it matches a
// key of shippedProfiles.
const DefaultEmbeddingProfile = "nomic-embed-text-v1.5"

// FallbackEmbeddingProfile names the secondary, smaller embedding profile; it
// also matches a key of shippedProfiles.
const FallbackEmbeddingProfile = "all-minilm-l6-v2"
9
+
10
// modelProfile is the static description of one embedding model profile.
// Field meanings below are inferred from their names — NOTE(review): confirm
// against the engine code that consumes them.
type modelProfile struct {
	Name, Family     string // profile identifier and model family
	Dimensions       int    // presumably the output vector length
	Normalize        bool   // presumably whether output vectors are normalized
	MaxContextTokens int    // presumably the model's token window
}
17
+
18
+ var shippedProfiles = map[string]modelProfile{
19
+ "all-minilm-l6-v2": {
20
+ Name: "all-minilm-l6-v2",
21
+ Family: "all-minilm-l6-v2",
22
+ Dimensions: 384,
23
+ Normalize: true,
24
+ MaxContextTokens: 128,
25
+ },
26
+ "nomic-embed-text-v1.5": {
27
+ Name: "nomic-embed-text-v1.5",
28
+ Family: "nomic-embed-text-v1.5",
29
+ Dimensions: 768,
30
+ Normalize: true,
31
+ MaxContextTokens: 8192,
32
+ },
33
+ }
34
+
35
+ func lookupProfile(name string) (modelProfile, bool) {
36
+ name = strings.TrimSpace(strings.ToLower(name))
37
+ profile, ok := shippedProfiles[name]
38
+ return profile, ok
39
+ }
package/sidecar/go.mod ADDED
@@ -0,0 +1,21 @@
1
+ module github.com/xDarkicex/openclaw-memory-libravdb/sidecar
2
+
3
+ go 1.25.1
4
+
5
+ require (
6
+ github.com/sugarme/tokenizer v0.3.0
7
+ github.com/yalue/onnxruntime_go v1.21.0
8
+ )
9
+
10
+ require (
11
+ github.com/emirpasic/gods v1.18.1 // indirect
12
+ github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
13
+ github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
14
+ github.com/rivo/uniseg v0.4.7 // indirect
15
+ github.com/schollz/progressbar/v2 v2.15.0 // indirect
16
+ github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c // indirect
17
+ golang.org/x/text v0.25.0 // indirect
18
+ gopkg.in/yaml.v3 v3.0.1 // indirect
19
+ )
20
+
21
+ replace github.com/sugarme/tokenizer => github.com/clems4ever/tokenizer v0.0.0-20250926133620-9ddc80533c43
package/sidecar/go.sum ADDED
@@ -0,0 +1,30 @@
1
+ github.com/clems4ever/tokenizer v0.0.0-20250926133620-9ddc80533c43 h1:j8YQypEqa5OjqbGciCNb9hOcYbo1oTVuEjd/iu9U2SY=
2
+ github.com/clems4ever/tokenizer v0.0.0-20250926133620-9ddc80533c43/go.mod h1:VJ+DLK5ZEZwzvODOWwY0cw+B1dabTd3nCB5HuFCItCc=
3
+ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4
+ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
5
+ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
6
+ github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
7
+ github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
8
+ github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
9
+ github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
10
+ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
11
+ github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
12
+ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
13
+ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
14
+ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
15
+ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
16
+ github.com/schollz/progressbar/v2 v2.15.0 h1:dVzHQ8fHRmtPjD3K10jT3Qgn/+H+92jhPrhmxIJfDz8=
17
+ github.com/schollz/progressbar/v2 v2.15.0/go.mod h1:UdPq3prGkfQ7MOzZKlDRpYKcFqEMczbD7YmbPgpzKMI=
18
+ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
19
+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
20
+ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
21
+ github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
22
+ github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c h1:pwb4kNSHb4K89ymCaN+5lPH/MwnfSVg4rzGDh4d+iy4=
23
+ github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw=
24
+ github.com/yalue/onnxruntime_go v1.21.0 h1:DdtvfY7OP5gR8mwPDqAOAQckf+KcI30hPNJL8hQaYWI=
25
+ github.com/yalue/onnxruntime_go v1.21.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
26
+ golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
27
+ golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
28
+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
29
+ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
30
+ gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
@@ -0,0 +1,33 @@
1
+ package health
2
+
3
+ import (
4
+ "github.com/xDarkicex/openclaw-memory-libravdb/sidecar/embed"
5
+ "github.com/xDarkicex/openclaw-memory-libravdb/sidecar/store"
6
+ )
7
+
8
// Status is the outcome of a sidecar health check. It is serialized to JSON
// as {"ok": ..., "message": ...}.
type Status struct {
	OK      bool   `json:"ok"`      // true only when every dependency Check inspects is usable
	Message string `json:"message"` // "ok" on success, otherwise the failure reason
}
12
+
13
+ func Check(embedder embed.Embedder, st *store.Store) Status {
14
+ if embedder == nil {
15
+ return Status{OK: false, Message: "embedder unavailable"}
16
+ }
17
+ if !embedder.Ready() {
18
+ if embedder.Reason() != "" {
19
+ return Status{OK: false, Message: embedder.Reason()}
20
+ }
21
+ return Status{OK: false, Message: "embedder not ready"}
22
+ }
23
+ if embedder.Mode() == "fallback" {
24
+ if embedder.Reason() != "" {
25
+ return Status{OK: false, Message: embedder.Reason()}
26
+ }
27
+ return Status{OK: false, Message: "embedder running in deterministic fallback mode"}
28
+ }
29
+ if st == nil {
30
+ return Status{OK: false, Message: "store unavailable"}
31
+ }
32
+ return Status{OK: true, Message: "ok"}
33
+ }
@@ -0,0 +1,55 @@
1
+ package health
2
+
3
+ import (
4
+ "context"
5
+ "testing"
6
+
7
+ "github.com/xDarkicex/openclaw-memory-libravdb/sidecar/embed"
8
+ "github.com/xDarkicex/openclaw-memory-libravdb/sidecar/store"
9
+ )
10
+
11
+ type fakeEmbedder struct {
12
+ ready bool
13
+ reason string
14
+ mode string
15
+ profile embed.Profile
16
+ }
17
+
18
+ func (f fakeEmbedder) EmbedDocument(context.Context, string) ([]float32, error) {
19
+ return make([]float32, 1), nil
20
+ }
21
+ func (f fakeEmbedder) EmbedQuery(context.Context, string) ([]float32, error) {
22
+ return make([]float32, 1), nil
23
+ }
24
+ func (f fakeEmbedder) Dimensions() int { return 1 }
25
+ func (f fakeEmbedder) Profile() embed.Profile { return f.profile }
26
+ func (f fakeEmbedder) Ready() bool { return f.ready }
27
+ func (f fakeEmbedder) Reason() string { return f.reason }
28
+ func (f fakeEmbedder) Mode() string { return f.mode }
29
+
30
+ func TestCheckRejectsFallbackEmbedder(t *testing.T) {
31
+ status := Check(fakeEmbedder{ready: true, mode: "fallback", reason: "bundled embedder unavailable"}, &store.Store{})
32
+ if status.OK {
33
+ t.Fatalf("expected fallback embedder to fail health")
34
+ }
35
+ if status.Message != "bundled embedder unavailable" {
36
+ t.Fatalf("unexpected message %q", status.Message)
37
+ }
38
+ }
39
+
40
+ func TestCheckRejectsNotReadyEmbedderWithReason(t *testing.T) {
41
+ status := Check(fakeEmbedder{ready: false, reason: "missing onnx runtime"}, &store.Store{})
42
+ if status.OK {
43
+ t.Fatalf("expected not-ready embedder to fail health")
44
+ }
45
+ if status.Message != "missing onnx runtime" {
46
+ t.Fatalf("unexpected message %q", status.Message)
47
+ }
48
+ }
49
+
50
+ func TestCheckAcceptsReadyPrimaryEmbedder(t *testing.T) {
51
+ status := Check(fakeEmbedder{ready: true, mode: "primary"}, &store.Store{})
52
+ if !status.OK {
53
+ t.Fatalf("expected primary embedder to pass health, got %+v", status)
54
+ }
55
+ }
@@ -0,0 +1,151 @@
1
+ package main
2
+
3
+ import (
4
+ "context"
5
+ "errors"
6
+ "fmt"
7
+ "os"
8
+ "os/signal"
9
+ "syscall"
10
+
11
+ "github.com/xDarkicex/openclaw-memory-libravdb/sidecar/compact"
12
+ "github.com/xDarkicex/openclaw-memory-libravdb/sidecar/config"
13
+ "github.com/xDarkicex/openclaw-memory-libravdb/sidecar/embed"
14
+ "github.com/xDarkicex/openclaw-memory-libravdb/sidecar/health"
15
+ "github.com/xDarkicex/openclaw-memory-libravdb/sidecar/model"
16
+ "github.com/xDarkicex/openclaw-memory-libravdb/sidecar/server"
17
+ "github.com/xDarkicex/openclaw-memory-libravdb/sidecar/store"
18
+ "github.com/xDarkicex/openclaw-memory-libravdb/sidecar/summarize"
19
+ )
20
+
21
+ func main() {
22
+ cfg := config.FromEnv()
23
+ if err := preflightONNXRuntime(cfg); err != nil {
24
+ fmt.Fprintln(os.Stderr, err.Error())
25
+ os.Exit(1)
26
+ }
27
+ embedder := embed.NewWithConfig(embed.Config{
28
+ Backend: cfg.EmbeddingBackend,
29
+ Profile: cfg.EmbeddingProfile,
30
+ FallbackProfile: cfg.FallbackProfile,
31
+ RuntimePath: cfg.ONNXRuntimePath,
32
+ ModelPath: cfg.EmbeddingModelPath,
33
+ TokenizerPath: cfg.EmbeddingTokenizerPath,
34
+ Dimensions: cfg.EmbeddingDimensions,
35
+ Normalize: cfg.EmbeddingNormalize,
36
+ })
37
+ summarizerRuntimePath := cfg.SummarizerRuntimePath
38
+ if summarizerRuntimePath == "" {
39
+ summarizerRuntimePath = cfg.ONNXRuntimePath
40
+ }
41
+ extractive := summarize.NewExtractive(embedder, "extractive")
42
+ configuredSummarizer := summarize.NewWithDeps(summarize.Config{
43
+ Backend: cfg.SummarizerBackend,
44
+ Profile: cfg.SummarizerProfile,
45
+ RuntimePath: summarizerRuntimePath,
46
+ ModelPath: cfg.SummarizerModelPath,
47
+ TokenizerPath: cfg.SummarizerTokenizerPath,
48
+ Model: cfg.SummarizerModel,
49
+ Endpoint: cfg.SummarizerEndpoint,
50
+ }, summarize.Dependencies{
51
+ Embedder: embedder,
52
+ Registry: model.DefaultRegistry(),
53
+ })
54
+ var abstractive summarize.Summarizer
55
+ if configuredSummarizer != nil && configuredSummarizer.Ready() && configuredSummarizer.Mode() != "extractive" {
56
+ abstractive = configuredSummarizer
57
+ }
58
+ st, err := store.Open(cfg.DBPath, embedder)
59
+ if err != nil {
60
+ fmt.Fprintln(os.Stderr, err.Error())
61
+ os.Exit(1)
62
+ }
63
+ if err := st.BackfillDirtyTiers(context.Background()); err != nil {
64
+ fmt.Fprintln(os.Stderr, err.Error())
65
+ os.Exit(1)
66
+ }
67
+
68
+ status := health.Check(embedder, st)
69
+ if !status.OK {
70
+ fmt.Fprintln(os.Stderr, status.Message)
71
+ os.Exit(1)
72
+ }
73
+
74
+ ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
75
+ defer stop()
76
+
77
+ srv := server.New(embedder, extractive, abstractive, st, compact.GatingConfig{
78
+ W1c: cfg.GatingW1c,
79
+ W2c: cfg.GatingW2c,
80
+ W3c: cfg.GatingW3c,
81
+ W1t: cfg.GatingW1t,
82
+ W2t: cfg.GatingW2t,
83
+ W3t: cfg.GatingW3t,
84
+ TechNorm: cfg.GatingTechNorm,
85
+ Threshold: cfg.GatingThreshold,
86
+ })
87
+ listener, endpoint, cleanup, err := server.Listen()
88
+ if err != nil {
89
+ fmt.Fprintln(os.Stderr, err.Error())
90
+ os.Exit(1)
91
+ }
92
+ defer cleanup()
93
+
94
+ fmt.Println(endpoint)
95
+ if err := server.Serve(ctx, listener, srv); err != nil && !errors.Is(err, server.ErrServerClosed) {
96
+ fmt.Fprintln(os.Stderr, err.Error())
97
+ os.Exit(1)
98
+ }
99
+ }
100
+
101
+ func preflightONNXRuntime(cfg config.Config) error {
102
+ paths := make([]string, 0, 2)
103
+
104
+ embeddingPath, err := resolvedRuntimePath(cfg.EmbeddingBackend, cfg.ONNXRuntimePath)
105
+ if err != nil {
106
+ return err
107
+ }
108
+ if embeddingPath != "" {
109
+ paths = append(paths, embeddingPath)
110
+ }
111
+
112
+ summarizerRuntimePath := cfg.SummarizerRuntimePath
113
+ if summarizerRuntimePath == "" {
114
+ summarizerRuntimePath = cfg.ONNXRuntimePath
115
+ }
116
+ summarizerPath, err := resolvedRuntimePath(cfg.SummarizerBackend, summarizerRuntimePath)
117
+ if err != nil {
118
+ return err
119
+ }
120
+ if summarizerPath != "" {
121
+ paths = append(paths, summarizerPath)
122
+ }
123
+
124
+ seen := map[string]struct{}{}
125
+ for _, runtimePath := range paths {
126
+ if _, ok := seen[runtimePath]; ok {
127
+ continue
128
+ }
129
+ seen[runtimePath] = struct{}{}
130
+ if _, err := os.Stat(runtimePath); err != nil {
131
+ if os.IsNotExist(err) {
132
+ return fmt.Errorf("ONNX Runtime library not found at %s\nRun scripts/setup.sh to unpack it.", runtimePath)
133
+ }
134
+ return fmt.Errorf("failed to stat ONNX Runtime library %s: %w", runtimePath, err)
135
+ }
136
+ }
137
+
138
+ return nil
139
+ }
140
+
141
+ func resolvedRuntimePath(backend, explicit string) (string, error) {
142
+ switch backend {
143
+ case "", "bundled", "onnx-local":
144
+ return embed.ResolveRuntimePath(embed.Config{
145
+ Backend: backend,
146
+ RuntimePath: explicit,
147
+ })
148
+ default:
149
+ return "", nil
150
+ }
151
+ }
@@ -0,0 +1,222 @@
1
+ package model
2
+
3
+ import (
4
+ "fmt"
5
+ "os"
6
+ "strings"
7
+
8
+ "github.com/sugarme/tokenizer"
9
+ "github.com/sugarme/tokenizer/pretrained"
10
+ ort "github.com/yalue/onnxruntime_go"
11
+ )
12
+
13
+ type EncoderSpec struct {
14
+ Key string
15
+ Profile Profile
16
+ RuntimePath string
17
+ ModelPath string
18
+ TokenizerPath string
19
+ InputNames []string
20
+ OutputName string
21
+ Dimensions int
22
+ AddSpecialTokens bool
23
+ Pooling string
24
+ }
25
+
26
+ type EncoderModel struct {
27
+ key string
28
+ registry *Registry
29
+ tokenizer tokenizer.Tokenizer
30
+ session *ort.DynamicAdvancedSession
31
+ inputNames []string
32
+ dimensions int
33
+ addSpecialTokens bool
34
+ pooling string
35
+ }
36
+
37
+ func (r *Registry) LoadEncoder(spec EncoderSpec) (*EncoderModel, error) {
38
+ r.mu.Lock()
39
+ defer r.mu.Unlock()
40
+
41
+ if err := r.ensureRuntimeLocked(strings.TrimSpace(spec.RuntimePath)); err != nil {
42
+ return nil, fmt.Errorf("failed to initialize onnx runtime: %w", err)
43
+ }
44
+ if spec.Key == "" {
45
+ spec.Key = spec.Profile.Name
46
+ }
47
+ if loaded, ok := r.loaded[spec.Key]; ok && loaded.encoder != nil {
48
+ loaded.lastAccess = timeNow()
49
+ return loaded.encoder, nil
50
+ }
51
+
52
+ tk, err := pretrained.FromFile(spec.TokenizerPath)
53
+ if err != nil {
54
+ return nil, fmt.Errorf("failed to load tokenizer: %w", err)
55
+ }
56
+ session, err := ort.NewDynamicAdvancedSession(spec.ModelPath, spec.InputNames, []string{spec.OutputName}, nil)
57
+ if err != nil {
58
+ return nil, fmt.Errorf("failed to create onnx session: %w", err)
59
+ }
60
+
61
+ encoder := &EncoderModel{
62
+ key: spec.Key,
63
+ registry: r,
64
+ tokenizer: *tk,
65
+ session: session,
66
+ inputNames: append([]string(nil), spec.InputNames...),
67
+ dimensions: spec.Dimensions,
68
+ addSpecialTokens: spec.AddSpecialTokens,
69
+ pooling: spec.Pooling,
70
+ }
71
+ r.loaded[spec.Key] = &loadedModel{
72
+ key: spec.Key,
73
+ profile: spec.Profile,
74
+ lastAccess: timeNow(),
75
+ useCount: 0,
76
+ reservedBytes: fileSize(spec.ModelPath) + fileSize(spec.TokenizerPath),
77
+ closeFn: session.Destroy,
78
+ encoder: encoder,
79
+ }
80
+ if err := r.maybeEvictLocked(timeNow()); err != nil {
81
+ return nil, err
82
+ }
83
+ return encoder, nil
84
+ }
85
+
86
+ func (m *EncoderModel) EmbedText(text string) ([]float32, error) {
87
+ encoding, err := m.EncodeText(text, m.addSpecialTokens)
88
+ if err != nil {
89
+ return nil, err
90
+ }
91
+ return m.EmbedEncoding(*encoding)
92
+ }
93
+
94
+ func (m *EncoderModel) TokenCount(text string) (int, error) {
95
+ encoding, err := m.EncodeText(text, m.addSpecialTokens)
96
+ if err != nil {
97
+ return 0, err
98
+ }
99
+ return len(encoding.Ids), nil
100
+ }
101
+
102
+ func (m *EncoderModel) EncodeText(text string, addSpecialTokens bool) (*tokenizer.Encoding, error) {
103
+ m.registry.mu.Lock()
104
+ m.registry.touchLocked(m.key)
105
+ m.registry.mu.Unlock()
106
+
107
+ input := tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(text))
108
+ encoding, err := m.tokenizer.Encode(input, addSpecialTokens)
109
+ if err != nil {
110
+ return nil, fmt.Errorf("failed to tokenize sentence: %w", err)
111
+ }
112
+ if encoding == nil || len(encoding.Ids) == 0 {
113
+ return nil, fmt.Errorf("tokenizer returned no encodings")
114
+ }
115
+ return encoding, nil
116
+ }
117
+
118
+ func (m *EncoderModel) EmbedEncoding(encoding tokenizer.Encoding) ([]float32, error) {
119
+ batchSize := 1
120
+ seqLength := len(encoding.Ids)
121
+ inputShape := ort.NewShape(int64(batchSize), int64(seqLength))
122
+ inputs := make([]ort.Value, 0, len(m.inputNames))
123
+ encodings := []tokenizer.Encoding{encoding}
124
+
125
+ for _, name := range m.inputNames {
126
+ data, err := inputTensorData(name, encodings, seqLength)
127
+ if err != nil {
128
+ return nil, err
129
+ }
130
+ tensor, err := ort.NewTensor(inputShape, data)
131
+ if err != nil {
132
+ return nil, fmt.Errorf("failed creating %s tensor: %w", name, err)
133
+ }
134
+ defer tensor.Destroy()
135
+ inputs = append(inputs, tensor)
136
+ }
137
+
138
+ outputShape := ort.NewShape(int64(batchSize), int64(m.dimensions))
139
+ useMeanPooling := m.pooling == "mean"
140
+ if useMeanPooling {
141
+ outputShape = ort.NewShape(int64(batchSize), int64(seqLength), int64(m.dimensions))
142
+ }
143
+ outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
144
+ if err != nil {
145
+ return nil, fmt.Errorf("failed creating output tensor: %w", err)
146
+ }
147
+ defer outputTensor.Destroy()
148
+
149
+ if err := m.session.Run(inputs, []ort.Value{outputTensor}); err != nil {
150
+ return nil, fmt.Errorf("failed to run onnx session: %w", err)
151
+ }
152
+
153
+ flat := outputTensor.GetData()
154
+ expected := batchSize * m.dimensions
155
+ if useMeanPooling {
156
+ expected = batchSize * seqLength * m.dimensions
157
+ }
158
+ if len(flat) != expected {
159
+ return nil, fmt.Errorf("unexpected output tensor size: got %d elements, expected %d", len(flat), expected)
160
+ }
161
+
162
+ if useMeanPooling {
163
+ return meanPoolLastHiddenState(flat, encoding.AttentionMask, seqLength, m.dimensions), nil
164
+ }
165
+
166
+ vec := make([]float32, m.dimensions)
167
+ copy(vec, flat[:m.dimensions])
168
+ return vec, nil
169
+ }
170
+
171
+ func inputTensorData(name string, encodings []tokenizer.Encoding, seqLength int) ([]int64, error) {
172
+ data := make([]int64, len(encodings)*seqLength)
173
+ for batch := range encodings {
174
+ switch name {
175
+ case "input_ids":
176
+ for i, id := range encodings[batch].Ids {
177
+ data[batch*seqLength+i] = int64(id)
178
+ }
179
+ case "attention_mask":
180
+ for i, mask := range encodings[batch].AttentionMask {
181
+ data[batch*seqLength+i] = int64(mask)
182
+ }
183
+ case "token_type_ids":
184
+ for i, typeID := range encodings[batch].TypeIds {
185
+ data[batch*seqLength+i] = int64(typeID)
186
+ }
187
+ default:
188
+ return nil, fmt.Errorf("unsupported input tensor name %q", name)
189
+ }
190
+ }
191
+ return data, nil
192
+ }
193
+
194
// meanPoolLastHiddenState averages token vectors from a flattened
// (seqLength, dimensions) buffer, counting only the positions whose
// attention-mask entry is non-zero.
//
// Positions at or beyond len(attentionMask) count as masked out. If no
// position is counted, the zero vector is returned. The buffer must hold
// dimensions values for every counted position; EmbedEncoding checks the
// total output size before calling this.
func meanPoolLastHiddenState(flat []float32, attentionMask []int, seqLength int, dimensions int) []float32 {
	pooled := make([]float32, dimensions)
	usable := min(seqLength, len(attentionMask))

	var counted float32
	for tok := 0; tok < usable; tok++ {
		if attentionMask[tok] == 0 {
			continue
		}
		counted++
		start := tok * dimensions
		for d, v := range flat[start : start+dimensions] {
			pooled[d] += v
		}
	}

	if counted > 0 {
		for d := range pooled {
			pooled[d] /= counted
		}
	}
	return pooled
}
215
+
216
// fileSize returns the size in bytes reported by os.Stat for path. It
// returns 0 when the path cannot be stat'ed, e.g. because it does not exist.
func fileSize(path string) int64 {
	if info, err := os.Stat(path); err == nil {
		return info.Size()
	}
	return 0
}