@xdarkicex/openclaw-memory-libravdb 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +46 -0
- package/docs/README.md +14 -0
- package/docs/architecture-decisions/README.md +6 -0
- package/docs/architecture-decisions/adr-001-onnx-over-ollama.md +21 -0
- package/docs/architecture-decisions/adr-002-libravdb-over-lancedb.md +19 -0
- package/docs/architecture-decisions/adr-003-convex-gating-over-threshold.md +27 -0
- package/docs/architecture-decisions/adr-004-sidecar-over-native-ts.md +21 -0
- package/docs/architecture.md +188 -0
- package/docs/contributing.md +76 -0
- package/docs/dependencies.md +38 -0
- package/docs/embedding-profiles.md +42 -0
- package/docs/gating.md +329 -0
- package/docs/implementation.md +381 -0
- package/docs/installation.md +272 -0
- package/docs/mathematics.md +695 -0
- package/docs/models.md +63 -0
- package/docs/problem.md +64 -0
- package/docs/security.md +86 -0
- package/openclaw.plugin.json +84 -0
- package/package.json +41 -0
- package/scripts/build-sidecar.sh +30 -0
- package/scripts/postinstall.js +169 -0
- package/scripts/setup.sh +20 -0
- package/scripts/setup.ts +505 -0
- package/scripts/sidecar-release.d.ts +4 -0
- package/scripts/sidecar-release.js +17 -0
- package/sidecar/cmd/inspect_onnx/main.go +105 -0
- package/sidecar/compact/gate.go +273 -0
- package/sidecar/compact/gate_test.go +85 -0
- package/sidecar/compact/summarize.go +345 -0
- package/sidecar/compact/summarize_test.go +319 -0
- package/sidecar/compact/tokens.go +11 -0
- package/sidecar/config/config.go +119 -0
- package/sidecar/config/config_test.go +75 -0
- package/sidecar/embed/engine.go +696 -0
- package/sidecar/embed/engine_test.go +349 -0
- package/sidecar/embed/matryoshka.go +93 -0
- package/sidecar/embed/matryoshka_test.go +150 -0
- package/sidecar/embed/onnx_local.go +319 -0
- package/sidecar/embed/onnx_local_test.go +159 -0
- package/sidecar/embed/profile_contract_test.go +71 -0
- package/sidecar/embed/profile_eval_test.go +923 -0
- package/sidecar/embed/profiles.go +39 -0
- package/sidecar/go.mod +21 -0
- package/sidecar/go.sum +30 -0
- package/sidecar/health/check.go +33 -0
- package/sidecar/health/check_test.go +55 -0
- package/sidecar/main.go +151 -0
- package/sidecar/model/encoder.go +222 -0
- package/sidecar/model/registry.go +262 -0
- package/sidecar/model/registry_test.go +102 -0
- package/sidecar/model/seq2seq.go +133 -0
- package/sidecar/server/rpc.go +343 -0
- package/sidecar/server/rpc_test.go +350 -0
- package/sidecar/server/transport.go +160 -0
- package/sidecar/store/libravdb.go +676 -0
- package/sidecar/store/libravdb_test.go +472 -0
- package/sidecar/summarize/engine.go +360 -0
- package/sidecar/summarize/engine_test.go +148 -0
- package/sidecar/summarize/onnx_local.go +494 -0
- package/sidecar/summarize/onnx_local_test.go +48 -0
- package/sidecar/summarize/profiles.go +52 -0
- package/sidecar/summarize/tokenizer.go +13 -0
- package/sidecar/summarize/tokenizer_hf.go +76 -0
- package/sidecar/summarize/util.go +13 -0
- package/src/cli.ts +205 -0
- package/src/context-engine.ts +195 -0
- package/src/index.ts +27 -0
- package/src/memory-provider.ts +24 -0
- package/src/openclaw-plugin-sdk.d.ts +53 -0
- package/src/plugin-runtime.ts +67 -0
- package/src/recall-cache.ts +34 -0
- package/src/recall-utils.ts +22 -0
- package/src/rpc.ts +84 -0
- package/src/scoring.ts +58 -0
- package/src/sidecar.ts +506 -0
- package/src/tokens.ts +36 -0
- package/src/types.ts +146 -0
- package/tsconfig.json +20 -0
- package/tsconfig.tests.json +12 -0
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
package model
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"fmt"
|
|
5
|
+
"math"
|
|
6
|
+
"sync"
|
|
7
|
+
"time"
|
|
8
|
+
|
|
9
|
+
ort "github.com/yalue/onnxruntime_go"
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
// Task identifies the kind of work a model profile is loaded for.
type Task int

// Known tasks. The zero value of Task is TaskEmbedding.
const (
	// TaskEmbedding is reported as "embedding" by taskName.
	TaskEmbedding Task = iota
	// TaskSummarization is reported as "summarization" by taskName and is
	// subject to MemoryPolicy.SummarizerIdleTTL during eviction.
	TaskSummarization
)
|
|
18
|
+
|
|
19
|
+
type Profile struct {
|
|
20
|
+
Name string
|
|
21
|
+
Family string
|
|
22
|
+
Task Task
|
|
23
|
+
Dims int
|
|
24
|
+
MaxCtxTokens int
|
|
25
|
+
Quantization string
|
|
26
|
+
Normalize bool
|
|
27
|
+
ModelPath string
|
|
28
|
+
TokenizerPath string
|
|
29
|
+
OrtLibPath string
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
// MemoryPolicy configures when the registry closes loaded models.
type MemoryPolicy struct {
	// SummarizerIdleTTL is how long a summarization model may sit unused
	// before the eviction pass closes it. Values <= 0 disable idle eviction
	// for summarization models.
	SummarizerIdleTTL time.Duration

	// EmbedderIdleTTL is the idle limit applied to every model whose task is
	// not TaskSummarization. Values <= 0 disable idle eviction for them.
	EmbedderIdleTTL time.Duration

	// MaxTotalModelBytes caps the summed reserved bytes of all loaded models.
	// Values <= 0 disable the budget check.
	MaxTotalModelBytes int64

	// EvictionK calibrates model-eviction sensitivity:
	// k = idleEligibilitySeconds * medianModelSizeBytes.
	// Increase k to keep idle models resident longer; decrease it to evict
	// more aggressively. Non-positive values fall back to defaultEvictionK
	// when a priority is computed.
	EvictionK float64
}
|
|
41
|
+
|
|
42
|
+
// Status is the JSON-serializable snapshot of one loaded model, as produced
// by Registry.Status. Task holds the textual task name, IdleFor is the time
// elapsed since LastAccess at snapshot time, and EvictionPriority is the
// score the budget eviction pass would assign (higher is evicted first).
type Status struct {
	Name             string        `json:"name"`
	Family           string        `json:"family"`
	Task             string        `json:"task"`
	Loaded           bool          `json:"loaded"`
	UseCount         int           `json:"useCount"`
	LastAccess       time.Time     `json:"lastAccess"`
	IdleFor          time.Duration `json:"idleFor"`
	ReservedBytes    int64         `json:"reservedBytes"`
	EvictionPriority float64       `json:"evictionPriority"`
}
|
|
53
|
+
|
|
54
|
+
type Registry struct {
|
|
55
|
+
mu sync.RWMutex
|
|
56
|
+
policy MemoryPolicy
|
|
57
|
+
runtimePath string
|
|
58
|
+
loaded map[string]*loadedModel
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
type loadedModel struct {
|
|
62
|
+
key string
|
|
63
|
+
profile Profile
|
|
64
|
+
lastAccess time.Time
|
|
65
|
+
useCount int
|
|
66
|
+
reservedBytes int64
|
|
67
|
+
closeFn func() error
|
|
68
|
+
encoder *EncoderModel
|
|
69
|
+
seq2seq *Seq2SeqModel
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
// defaultEvictionK is the EvictionK used by DefaultMemoryPolicy and the value
// evictionPriority falls back to when handed a non-positive k.
// NOTE(review): 1.2e11 matches e.g. 600 s of idle time × a 200 MB model under
// the k = idleSeconds * modelBytes calibration — confirm the intended basis.
const defaultEvictionK = 1.2e11

// timeNow is the clock used to stamp model access times. It is a variable,
// presumably so tests can substitute a deterministic clock.
var timeNow = time.Now
|
|
75
|
+
|
|
76
|
+
func DefaultMemoryPolicy() MemoryPolicy {
|
|
77
|
+
return MemoryPolicy{
|
|
78
|
+
SummarizerIdleTTL: 5 * time.Minute,
|
|
79
|
+
EmbedderIdleTTL: 30 * time.Minute,
|
|
80
|
+
MaxTotalModelBytes: 2 << 30,
|
|
81
|
+
EvictionK: defaultEvictionK,
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
func NewRegistry(policy MemoryPolicy) *Registry {
|
|
86
|
+
return &Registry{
|
|
87
|
+
policy: policy,
|
|
88
|
+
loaded: make(map[string]*loadedModel),
|
|
89
|
+
}
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
var defaultRegistry = NewRegistry(DefaultMemoryPolicy())
|
|
93
|
+
|
|
94
|
+
func DefaultRegistry() *Registry {
|
|
95
|
+
return defaultRegistry
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
func (r *Registry) Close() error {
|
|
99
|
+
r.mu.Lock()
|
|
100
|
+
defer r.mu.Unlock()
|
|
101
|
+
|
|
102
|
+
for key, loaded := range r.loaded {
|
|
103
|
+
_ = closeLoadedModel(loaded)
|
|
104
|
+
delete(r.loaded, key)
|
|
105
|
+
}
|
|
106
|
+
if ort.IsInitialized() {
|
|
107
|
+
if err := ort.DestroyEnvironment(); err != nil {
|
|
108
|
+
return err
|
|
109
|
+
}
|
|
110
|
+
}
|
|
111
|
+
r.runtimePath = ""
|
|
112
|
+
return nil
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
func (r *Registry) Status() map[string]Status {
|
|
116
|
+
r.mu.RLock()
|
|
117
|
+
defer r.mu.RUnlock()
|
|
118
|
+
|
|
119
|
+
now := time.Now()
|
|
120
|
+
out := make(map[string]Status, len(r.loaded))
|
|
121
|
+
for key, loaded := range r.loaded {
|
|
122
|
+
out[key] = Status{
|
|
123
|
+
Name: loaded.profile.Name,
|
|
124
|
+
Family: loaded.profile.Family,
|
|
125
|
+
Task: taskName(loaded.profile.Task),
|
|
126
|
+
Loaded: true,
|
|
127
|
+
UseCount: loaded.useCount,
|
|
128
|
+
LastAccess: loaded.lastAccess,
|
|
129
|
+
IdleFor: now.Sub(loaded.lastAccess),
|
|
130
|
+
ReservedBytes: loaded.reservedBytes,
|
|
131
|
+
EvictionPriority: evictionPriority(*loaded, now, r.policy.EvictionK),
|
|
132
|
+
}
|
|
133
|
+
}
|
|
134
|
+
return out
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
func (r *Registry) Unload(name string) error {
|
|
138
|
+
r.mu.Lock()
|
|
139
|
+
defer r.mu.Unlock()
|
|
140
|
+
|
|
141
|
+
loaded, ok := r.loaded[name]
|
|
142
|
+
if !ok {
|
|
143
|
+
return nil
|
|
144
|
+
}
|
|
145
|
+
if err := closeLoadedModel(loaded); err != nil {
|
|
146
|
+
return err
|
|
147
|
+
}
|
|
148
|
+
delete(r.loaded, name)
|
|
149
|
+
if len(r.loaded) == 0 && ort.IsInitialized() {
|
|
150
|
+
if err := ort.DestroyEnvironment(); err != nil {
|
|
151
|
+
return err
|
|
152
|
+
}
|
|
153
|
+
r.runtimePath = ""
|
|
154
|
+
}
|
|
155
|
+
return nil
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
func (r *Registry) touchLocked(key string) {
|
|
159
|
+
if loaded, ok := r.loaded[key]; ok {
|
|
160
|
+
loaded.lastAccess = time.Now()
|
|
161
|
+
loaded.useCount++
|
|
162
|
+
}
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
func (r *Registry) ensureRuntimeLocked(path string) error {
|
|
166
|
+
if path == "" {
|
|
167
|
+
return fmt.Errorf("onnx runtime path is required")
|
|
168
|
+
}
|
|
169
|
+
if ort.IsInitialized() {
|
|
170
|
+
if r.runtimePath != "" && r.runtimePath != path {
|
|
171
|
+
return fmt.Errorf("onnx runtime already initialized with %q, cannot switch to %q", r.runtimePath, path)
|
|
172
|
+
}
|
|
173
|
+
r.runtimePath = path
|
|
174
|
+
return nil
|
|
175
|
+
}
|
|
176
|
+
ort.SetSharedLibraryPath(path)
|
|
177
|
+
if err := ort.InitializeEnvironment(); err != nil {
|
|
178
|
+
return err
|
|
179
|
+
}
|
|
180
|
+
r.runtimePath = path
|
|
181
|
+
return nil
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
func (r *Registry) maybeEvictLocked(now time.Time) error {
|
|
185
|
+
for key, loaded := range r.loaded {
|
|
186
|
+
idleTTL := r.policy.EmbedderIdleTTL
|
|
187
|
+
if loaded.profile.Task == TaskSummarization {
|
|
188
|
+
idleTTL = r.policy.SummarizerIdleTTL
|
|
189
|
+
}
|
|
190
|
+
if idleTTL > 0 && now.Sub(loaded.lastAccess) > idleTTL {
|
|
191
|
+
if err := closeLoadedModel(loaded); err != nil {
|
|
192
|
+
return err
|
|
193
|
+
}
|
|
194
|
+
delete(r.loaded, key)
|
|
195
|
+
}
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
if r.policy.MaxTotalModelBytes <= 0 {
|
|
199
|
+
return nil
|
|
200
|
+
}
|
|
201
|
+
for totalReservedBytes(r.loaded) > r.policy.MaxTotalModelBytes {
|
|
202
|
+
evictKey := ""
|
|
203
|
+
evictScore := 0.0
|
|
204
|
+
for key, loaded := range r.loaded {
|
|
205
|
+
score := evictionPriority(*loaded, now, r.policy.EvictionK)
|
|
206
|
+
if evictKey == "" || score > evictScore {
|
|
207
|
+
evictKey = key
|
|
208
|
+
evictScore = score
|
|
209
|
+
}
|
|
210
|
+
}
|
|
211
|
+
if evictKey == "" {
|
|
212
|
+
return nil
|
|
213
|
+
}
|
|
214
|
+
if err := closeLoadedModel(r.loaded[evictKey]); err != nil {
|
|
215
|
+
return err
|
|
216
|
+
}
|
|
217
|
+
delete(r.loaded, evictKey)
|
|
218
|
+
}
|
|
219
|
+
return nil
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
func evictionPriority(m loadedModel, now time.Time, k float64) float64 {
|
|
223
|
+
deltaT := now.Sub(m.lastAccess).Seconds()
|
|
224
|
+
if deltaT < 0 {
|
|
225
|
+
deltaT = 0
|
|
226
|
+
}
|
|
227
|
+
size := float64(m.reservedBytes)
|
|
228
|
+
if size < 0 {
|
|
229
|
+
size = 0
|
|
230
|
+
}
|
|
231
|
+
if k <= 0 {
|
|
232
|
+
k = defaultEvictionK
|
|
233
|
+
}
|
|
234
|
+
freq := float64(m.useCount)
|
|
235
|
+
return (deltaT * size) / (k * (1.0 + math.Log(freq+1.0)))
|
|
236
|
+
}
|
|
237
|
+
|
|
238
|
+
func totalReservedBytes(loaded map[string]*loadedModel) int64 {
|
|
239
|
+
var total int64
|
|
240
|
+
for _, item := range loaded {
|
|
241
|
+
total += item.reservedBytes
|
|
242
|
+
}
|
|
243
|
+
return total
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
func closeLoadedModel(loaded *loadedModel) error {
|
|
247
|
+
if loaded == nil || loaded.closeFn == nil {
|
|
248
|
+
return nil
|
|
249
|
+
}
|
|
250
|
+
return loaded.closeFn()
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
func taskName(task Task) string {
|
|
254
|
+
switch task {
|
|
255
|
+
case TaskEmbedding:
|
|
256
|
+
return "embedding"
|
|
257
|
+
case TaskSummarization:
|
|
258
|
+
return "summarization"
|
|
259
|
+
default:
|
|
260
|
+
return "unknown"
|
|
261
|
+
}
|
|
262
|
+
}
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
package model
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"math"
|
|
5
|
+
"testing"
|
|
6
|
+
"time"
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
func TestEvictionPriorityPrefersLargeIdleRarelyUsedModels(t *testing.T) {
|
|
10
|
+
now := time.Unix(10_000, 0)
|
|
11
|
+
k := defaultEvictionK
|
|
12
|
+
|
|
13
|
+
largeIdleRare := loadedModel{
|
|
14
|
+
lastAccess: now.Add(-10 * time.Minute),
|
|
15
|
+
useCount: 1,
|
|
16
|
+
reservedBytes: 200 << 20,
|
|
17
|
+
}
|
|
18
|
+
smallRecentWarm := loadedModel{
|
|
19
|
+
lastAccess: now.Add(-10 * time.Second),
|
|
20
|
+
useCount: 50_000,
|
|
21
|
+
reservedBytes: 60 << 20,
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
oldScore := evictionPriority(largeIdleRare, now, k)
|
|
25
|
+
recentScore := evictionPriority(smallRecentWarm, now, k)
|
|
26
|
+
|
|
27
|
+
if !(oldScore > recentScore) {
|
|
28
|
+
t.Fatalf("expected larger, older, colder model to have higher eviction score: old=%f recent=%f", oldScore, recentScore)
|
|
29
|
+
}
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
func TestEvictionPriorityDecreasesWithUseCount(t *testing.T) {
|
|
33
|
+
now := time.Unix(10_000, 0)
|
|
34
|
+
k := defaultEvictionK
|
|
35
|
+
|
|
36
|
+
cold := loadedModel{
|
|
37
|
+
lastAccess: now.Add(-5 * time.Minute),
|
|
38
|
+
useCount: 1,
|
|
39
|
+
reservedBytes: 200 << 20,
|
|
40
|
+
}
|
|
41
|
+
warm := loadedModel{
|
|
42
|
+
lastAccess: now.Add(-5 * time.Minute),
|
|
43
|
+
useCount: 20,
|
|
44
|
+
reservedBytes: 200 << 20,
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
coldScore := evictionPriority(cold, now, k)
|
|
48
|
+
warmScore := evictionPriority(warm, now, k)
|
|
49
|
+
|
|
50
|
+
if !(coldScore > warmScore) {
|
|
51
|
+
t.Fatalf("expected use-count damping to reduce eviction score: cold=%f warm=%f", coldScore, warmScore)
|
|
52
|
+
}
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
func TestEvictionPriorityHandlesZeroUseCountAndUsesDefaultK(t *testing.T) {
|
|
56
|
+
now := time.Unix(10_000, 0)
|
|
57
|
+
model := loadedModel{
|
|
58
|
+
lastAccess: now.Add(-10 * time.Minute),
|
|
59
|
+
useCount: 0,
|
|
60
|
+
reservedBytes: 200 << 20,
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
score := evictionPriority(model, now, 0)
|
|
64
|
+
if score <= 0 {
|
|
65
|
+
t.Fatalf("expected positive eviction score for idle loaded model, got %f", score)
|
|
66
|
+
}
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
func TestEvictionPriorityHasLogarithmicDampingCurve(t *testing.T) {
|
|
70
|
+
now := time.Unix(10_000, 0)
|
|
71
|
+
k := defaultEvictionK
|
|
72
|
+
idle := 10 * time.Minute
|
|
73
|
+
size := int64(200 << 20)
|
|
74
|
+
|
|
75
|
+
modelForUseCount := func(useCount int) loadedModel {
|
|
76
|
+
return loadedModel{
|
|
77
|
+
lastAccess: now.Add(-idle),
|
|
78
|
+
useCount: useCount,
|
|
79
|
+
reservedBytes: size,
|
|
80
|
+
}
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
p1 := evictionPriority(modelForUseCount(1), now, k)
|
|
84
|
+
p100 := evictionPriority(modelForUseCount(100), now, k)
|
|
85
|
+
p10k := evictionPriority(modelForUseCount(10_000), now, k)
|
|
86
|
+
|
|
87
|
+
ratioLow := p1 / p100
|
|
88
|
+
ratioHigh := p100 / p10k
|
|
89
|
+
|
|
90
|
+
expectedLow := (1 + math.Log(101)) / (1 + math.Log(2))
|
|
91
|
+
expectedHigh := (1 + math.Log(10_001)) / (1 + math.Log(101))
|
|
92
|
+
|
|
93
|
+
if math.Abs(ratioLow-expectedLow) > 1e-9 {
|
|
94
|
+
t.Fatalf("unexpected low-range damping ratio: got %f want %f", ratioLow, expectedLow)
|
|
95
|
+
}
|
|
96
|
+
if math.Abs(ratioHigh-expectedHigh) > 1e-9 {
|
|
97
|
+
t.Fatalf("unexpected high-range damping ratio: got %f want %f", ratioHigh, expectedHigh)
|
|
98
|
+
}
|
|
99
|
+
if !(ratioLow > ratioHigh) {
|
|
100
|
+
t.Fatalf("expected stronger damping at low counts than high counts: low=%f high=%f", ratioLow, ratioHigh)
|
|
101
|
+
}
|
|
102
|
+
}
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
package model
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"fmt"
|
|
5
|
+
"strings"
|
|
6
|
+
|
|
7
|
+
ort "github.com/yalue/onnxruntime_go"
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
type Seq2SeqSpec struct {
|
|
11
|
+
Key string
|
|
12
|
+
Profile Profile
|
|
13
|
+
RuntimePath string
|
|
14
|
+
ModelPath string
|
|
15
|
+
EncoderPath string
|
|
16
|
+
DecoderPath string
|
|
17
|
+
TokenizerPath string
|
|
18
|
+
EncoderInputs []string
|
|
19
|
+
EncoderOutputs []string
|
|
20
|
+
DecoderInputs []string
|
|
21
|
+
DecoderOutputs []string
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
type Seq2SeqModel struct {
|
|
25
|
+
key string
|
|
26
|
+
registry *Registry
|
|
27
|
+
encoder *ort.DynamicAdvancedSession
|
|
28
|
+
decoder *ort.DynamicAdvancedSession
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
func (r *Registry) LoadSeq2Seq(spec Seq2SeqSpec) (*Seq2SeqModel, error) {
|
|
32
|
+
r.mu.Lock()
|
|
33
|
+
defer r.mu.Unlock()
|
|
34
|
+
|
|
35
|
+
if err := r.ensureRuntimeLocked(strings.TrimSpace(spec.RuntimePath)); err != nil {
|
|
36
|
+
return nil, fmt.Errorf("failed to initialize onnx runtime: %w", err)
|
|
37
|
+
}
|
|
38
|
+
if spec.Key == "" {
|
|
39
|
+
spec.Key = spec.Profile.Name
|
|
40
|
+
}
|
|
41
|
+
if loaded, ok := r.loaded[spec.Key]; ok && loaded.seq2seq != nil {
|
|
42
|
+
loaded.lastAccess = timeNow()
|
|
43
|
+
return loaded.seq2seq, nil
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
encoderPath := strings.TrimSpace(spec.EncoderPath)
|
|
47
|
+
decoderPath := strings.TrimSpace(spec.DecoderPath)
|
|
48
|
+
modelPath := strings.TrimSpace(spec.ModelPath)
|
|
49
|
+
|
|
50
|
+
if encoderPath == "" && decoderPath == "" && modelPath == "" {
|
|
51
|
+
return nil, fmt.Errorf("seq2seq model requires encoder, decoder, or model path")
|
|
52
|
+
}
|
|
53
|
+
if encoderPath == "" && modelPath != "" {
|
|
54
|
+
encoderPath = modelPath
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
var encoderSession *ort.DynamicAdvancedSession
|
|
58
|
+
var err error
|
|
59
|
+
if encoderPath != "" {
|
|
60
|
+
encoderSession, err = ort.NewDynamicAdvancedSession(encoderPath, spec.EncoderInputs, spec.EncoderOutputs, nil)
|
|
61
|
+
if err != nil {
|
|
62
|
+
return nil, fmt.Errorf("failed to create encoder session: %w", err)
|
|
63
|
+
}
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
var decoderSession *ort.DynamicAdvancedSession
|
|
67
|
+
if decoderPath != "" {
|
|
68
|
+
decoderSession, err = ort.NewDynamicAdvancedSession(decoderPath, spec.DecoderInputs, spec.DecoderOutputs, nil)
|
|
69
|
+
if err != nil {
|
|
70
|
+
if encoderSession != nil {
|
|
71
|
+
_ = encoderSession.Destroy()
|
|
72
|
+
}
|
|
73
|
+
return nil, fmt.Errorf("failed to create decoder session: %w", err)
|
|
74
|
+
}
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
seq2seq := &Seq2SeqModel{
|
|
78
|
+
key: spec.Key,
|
|
79
|
+
registry: r,
|
|
80
|
+
encoder: encoderSession,
|
|
81
|
+
decoder: decoderSession,
|
|
82
|
+
}
|
|
83
|
+
r.loaded[spec.Key] = &loadedModel{
|
|
84
|
+
key: spec.Key,
|
|
85
|
+
profile: spec.Profile,
|
|
86
|
+
lastAccess: timeNow(),
|
|
87
|
+
useCount: 0,
|
|
88
|
+
reservedBytes: fileSize(encoderPath) + fileSize(decoderPath) + fileSize(spec.ModelPath) + fileSize(spec.TokenizerPath),
|
|
89
|
+
closeFn: func() error {
|
|
90
|
+
if decoderSession != nil {
|
|
91
|
+
if err := decoderSession.Destroy(); err != nil {
|
|
92
|
+
return err
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
if encoderSession != nil {
|
|
96
|
+
if err := encoderSession.Destroy(); err != nil {
|
|
97
|
+
return err
|
|
98
|
+
}
|
|
99
|
+
}
|
|
100
|
+
return nil
|
|
101
|
+
},
|
|
102
|
+
seq2seq: seq2seq,
|
|
103
|
+
}
|
|
104
|
+
if err := r.maybeEvictLocked(timeNow()); err != nil {
|
|
105
|
+
return nil, err
|
|
106
|
+
}
|
|
107
|
+
return seq2seq, nil
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
func (m *Seq2SeqModel) Touch() {
|
|
111
|
+
if m == nil || m.registry == nil {
|
|
112
|
+
return
|
|
113
|
+
}
|
|
114
|
+
m.registry.mu.Lock()
|
|
115
|
+
m.registry.touchLocked(m.key)
|
|
116
|
+
m.registry.mu.Unlock()
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
func (m *Seq2SeqModel) RunEncoder(inputs, outputs []ort.Value) error {
|
|
120
|
+
if m == nil || m.encoder == nil {
|
|
121
|
+
return fmt.Errorf("encoder session not loaded")
|
|
122
|
+
}
|
|
123
|
+
m.Touch()
|
|
124
|
+
return m.encoder.Run(inputs, outputs)
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
func (m *Seq2SeqModel) RunDecoder(inputs, outputs []ort.Value) error {
|
|
128
|
+
if m == nil || m.decoder == nil {
|
|
129
|
+
return fmt.Errorf("decoder session not loaded")
|
|
130
|
+
}
|
|
131
|
+
m.Touch()
|
|
132
|
+
return m.decoder.Run(inputs, outputs)
|
|
133
|
+
}
|