eigen-db 4.0.2 → 4.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +4 -0
- package/README.md +70 -20
- package/dist/eigen-db.js +192 -191
- package/dist/eigen-db.js.map +1 -1
- package/dist/eigen-db.umd.cjs +1 -1
- package/dist/eigen-db.umd.cjs.map +1 -1
- package/package.json +1 -1
- package/src/lib/__tests__/result-set.test.ts +53 -8
- package/src/lib/__tests__/vector-db.test.ts +64 -9
- package/src/lib/result-set.ts +28 -8
- package/src/lib/simd-binary.ts +1 -1
- package/src/lib/simd.wat +270 -128
- package/src/lib/types.ts +3 -1
- package/src/lib/vector-db.ts +8 -4
package/src/lib/simd.wat
CHANGED
|
@@ -4,217 +4,359 @@
|
|
|
4
4
|
|
|
5
5
|
;; normalize(ptr: i32, dimensions: i32)
|
|
6
6
|
;; Normalizes a vector in-place to unit length using SIMD.
|
|
7
|
-
;;
|
|
8
|
-
;; dimensions: number of f32 elements (must be a multiple of 4 for SIMD path)
|
|
7
|
+
;; Optimized with 4x loop unrolling (16 floats/iteration) and multiple accumulators.
|
|
9
8
|
(func (export "normalize") (param $ptr i32) (param $dim i32)
|
|
10
9
|
(local $i i32)
|
|
11
|
-
(local $
|
|
10
|
+
(local $acc0 v128)
|
|
11
|
+
(local $acc1 v128)
|
|
12
|
+
(local $acc2 v128)
|
|
13
|
+
(local $acc3 v128)
|
|
12
14
|
(local $sum f32)
|
|
13
15
|
(local $mag f32)
|
|
14
16
|
(local $inv_mag f32)
|
|
15
17
|
(local $inv_vec v128)
|
|
18
|
+
(local $unroll_end i32)
|
|
16
19
|
(local $simd_end i32)
|
|
17
|
-
(local $remainder i32)
|
|
18
20
|
(local $offset i32)
|
|
19
21
|
|
|
20
|
-
;; Phase 1:
|
|
21
|
-
(local.set $
|
|
22
|
-
(local.set $
|
|
22
|
+
;; Phase 1: Sum of squares with 4x unroll and 4 independent accumulators
|
|
23
|
+
(local.set $acc0 (v128.const f32x4 0 0 0 0))
|
|
24
|
+
(local.set $acc1 (v128.const f32x4 0 0 0 0))
|
|
25
|
+
(local.set $acc2 (v128.const f32x4 0 0 0 0))
|
|
26
|
+
(local.set $acc3 (v128.const f32x4 0 0 0 0))
|
|
27
|
+
(local.set $unroll_end (i32.and (local.get $dim) (i32.const -16)))
|
|
28
|
+
(local.set $simd_end (i32.and (local.get $dim) (i32.const -4)))
|
|
23
29
|
(local.set $i (i32.const 0))
|
|
24
30
|
|
|
25
|
-
(block $
|
|
26
|
-
(loop $
|
|
27
|
-
(br_if $
|
|
31
|
+
(block $break_sum_u
|
|
32
|
+
(loop $loop_sum_u
|
|
33
|
+
(br_if $break_sum_u (i32.ge_u (local.get $i) (local.get $unroll_end)))
|
|
28
34
|
(local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
35
|
+
|
|
36
|
+
(local.set $acc0 (f32x4.add (local.get $acc0)
|
|
37
|
+
(f32x4.mul (v128.load (local.get $offset)) (v128.load (local.get $offset)))))
|
|
38
|
+
(local.set $acc1 (f32x4.add (local.get $acc1)
|
|
39
|
+
(f32x4.mul (v128.load offset=16 (local.get $offset)) (v128.load offset=16 (local.get $offset)))))
|
|
40
|
+
(local.set $acc2 (f32x4.add (local.get $acc2)
|
|
41
|
+
(f32x4.mul (v128.load offset=32 (local.get $offset)) (v128.load offset=32 (local.get $offset)))))
|
|
42
|
+
(local.set $acc3 (f32x4.add (local.get $acc3)
|
|
43
|
+
(f32x4.mul (v128.load offset=48 (local.get $offset)) (v128.load offset=48 (local.get $offset)))))
|
|
44
|
+
|
|
45
|
+
(local.set $i (i32.add (local.get $i) (i32.const 16)))
|
|
46
|
+
(br $loop_sum_u)
|
|
47
|
+
)
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
;; Merge 4 accumulators
|
|
51
|
+
(local.set $acc0 (f32x4.add (f32x4.add (local.get $acc0) (local.get $acc1))
|
|
52
|
+
(f32x4.add (local.get $acc2) (local.get $acc3))))
|
|
53
|
+
|
|
54
|
+
;; Remaining 4-wide chunks
|
|
55
|
+
(block $break_sum4
|
|
56
|
+
(loop $loop_sum4
|
|
57
|
+
(br_if $break_sum4 (i32.ge_u (local.get $i) (local.get $simd_end)))
|
|
58
|
+
(local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
|
|
59
|
+
(local.set $acc0 (f32x4.add (local.get $acc0)
|
|
60
|
+
(f32x4.mul (v128.load (local.get $offset)) (v128.load (local.get $offset)))))
|
|
38
61
|
(local.set $i (i32.add (local.get $i) (i32.const 4)))
|
|
39
|
-
(br $
|
|
62
|
+
(br $loop_sum4)
|
|
40
63
|
)
|
|
41
64
|
)
|
|
42
65
|
|
|
43
|
-
;; Horizontal sum
|
|
66
|
+
;; Horizontal sum
|
|
44
67
|
(local.set $sum
|
|
45
68
|
(f32.add
|
|
46
|
-
(f32.add
|
|
47
|
-
|
|
48
|
-
(f32x4.extract_lane 1 (local.get $sum_vec))
|
|
49
|
-
)
|
|
50
|
-
(f32.add
|
|
51
|
-
(f32x4.extract_lane 2 (local.get $sum_vec))
|
|
52
|
-
(f32x4.extract_lane 3 (local.get $sum_vec))
|
|
53
|
-
)
|
|
54
|
-
)
|
|
55
|
-
)
|
|
69
|
+
(f32.add (f32x4.extract_lane 0 (local.get $acc0)) (f32x4.extract_lane 1 (local.get $acc0)))
|
|
70
|
+
(f32.add (f32x4.extract_lane 2 (local.get $acc0)) (f32x4.extract_lane 3 (local.get $acc0)))))
|
|
56
71
|
|
|
57
|
-
;;
|
|
58
|
-
(local.set $remainder (local.get $simd_end))
|
|
72
|
+
;; Scalar remainder
|
|
59
73
|
(block $break_rem_sum
|
|
60
74
|
(loop $loop_rem_sum
|
|
61
|
-
(br_if $break_rem_sum (i32.ge_u (local.get $
|
|
62
|
-
(local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $
|
|
63
|
-
(local.set $sum
|
|
64
|
-
(f32.
|
|
65
|
-
|
|
66
|
-
(f32.mul
|
|
67
|
-
(f32.load (local.get $offset))
|
|
68
|
-
(f32.load (local.get $offset))
|
|
69
|
-
)
|
|
70
|
-
)
|
|
71
|
-
)
|
|
72
|
-
(local.set $remainder (i32.add (local.get $remainder) (i32.const 1)))
|
|
75
|
+
(br_if $break_rem_sum (i32.ge_u (local.get $i) (local.get $dim)))
|
|
76
|
+
(local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
|
|
77
|
+
(local.set $sum (f32.add (local.get $sum)
|
|
78
|
+
(f32.mul (f32.load (local.get $offset)) (f32.load (local.get $offset)))))
|
|
79
|
+
(local.set $i (i32.add (local.get $i) (i32.const 1)))
|
|
73
80
|
(br $loop_rem_sum)
|
|
74
81
|
)
|
|
75
82
|
)
|
|
76
83
|
|
|
77
|
-
;;
|
|
84
|
+
;; Magnitude check
|
|
78
85
|
(local.set $mag (f32.sqrt (local.get $sum)))
|
|
79
86
|
(if (f32.eq (local.get $mag) (f32.const 0))
|
|
80
|
-
(then (return))
|
|
81
|
-
)
|
|
87
|
+
(then (return)))
|
|
82
88
|
|
|
83
|
-
;; Phase 2:
|
|
89
|
+
;; Phase 2: Scale by inverse magnitude (4x unrolled)
|
|
84
90
|
(local.set $inv_mag (f32.div (f32.const 1) (local.get $mag)))
|
|
85
91
|
(local.set $inv_vec (f32x4.splat (local.get $inv_mag)))
|
|
86
92
|
(local.set $i (i32.const 0))
|
|
87
93
|
|
|
88
|
-
(block $
|
|
89
|
-
(loop $
|
|
90
|
-
(br_if $
|
|
94
|
+
(block $break_norm_u
|
|
95
|
+
(loop $loop_norm_u
|
|
96
|
+
(br_if $break_norm_u (i32.ge_u (local.get $i) (local.get $unroll_end)))
|
|
91
97
|
(local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
(f32x4.mul
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
98
|
+
|
|
99
|
+
(v128.store (local.get $offset)
|
|
100
|
+
(f32x4.mul (v128.load (local.get $offset)) (local.get $inv_vec)))
|
|
101
|
+
(v128.store offset=16 (local.get $offset)
|
|
102
|
+
(f32x4.mul (v128.load offset=16 (local.get $offset)) (local.get $inv_vec)))
|
|
103
|
+
(v128.store offset=32 (local.get $offset)
|
|
104
|
+
(f32x4.mul (v128.load offset=32 (local.get $offset)) (local.get $inv_vec)))
|
|
105
|
+
(v128.store offset=48 (local.get $offset)
|
|
106
|
+
(f32x4.mul (v128.load offset=48 (local.get $offset)) (local.get $inv_vec)))
|
|
107
|
+
|
|
108
|
+
(local.set $i (i32.add (local.get $i) (i32.const 16)))
|
|
109
|
+
(br $loop_norm_u)
|
|
110
|
+
)
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
;; Remaining 4-wide chunks
|
|
114
|
+
(block $break_norm4
|
|
115
|
+
(loop $loop_norm4
|
|
116
|
+
(br_if $break_norm4 (i32.ge_u (local.get $i) (local.get $simd_end)))
|
|
117
|
+
(local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
|
|
118
|
+
(v128.store (local.get $offset)
|
|
119
|
+
(f32x4.mul (v128.load (local.get $offset)) (local.get $inv_vec)))
|
|
99
120
|
(local.set $i (i32.add (local.get $i) (i32.const 4)))
|
|
100
|
-
(br $
|
|
121
|
+
(br $loop_norm4)
|
|
101
122
|
)
|
|
102
123
|
)
|
|
103
124
|
|
|
104
|
-
;;
|
|
105
|
-
(local.set $remainder (local.get $simd_end))
|
|
125
|
+
;; Scalar remainder
|
|
106
126
|
(block $break_rem_norm
|
|
107
127
|
(loop $loop_rem_norm
|
|
108
|
-
(br_if $break_rem_norm (i32.ge_u (local.get $
|
|
109
|
-
(local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $
|
|
110
|
-
(f32.store
|
|
111
|
-
(local.get $offset)
|
|
112
|
-
|
|
113
|
-
(f32.load (local.get $offset))
|
|
114
|
-
(local.get $inv_mag)
|
|
115
|
-
)
|
|
116
|
-
)
|
|
117
|
-
(local.set $remainder (i32.add (local.get $remainder) (i32.const 1)))
|
|
128
|
+
(br_if $break_rem_norm (i32.ge_u (local.get $i) (local.get $dim)))
|
|
129
|
+
(local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
|
|
130
|
+
(f32.store (local.get $offset)
|
|
131
|
+
(f32.mul (f32.load (local.get $offset)) (local.get $inv_mag)))
|
|
132
|
+
(local.set $i (i32.add (local.get $i) (i32.const 1)))
|
|
118
133
|
(br $loop_rem_norm)
|
|
119
134
|
)
|
|
120
135
|
)
|
|
121
136
|
)
|
|
122
137
|
|
|
123
|
-
;; search_all(query_ptr
|
|
138
|
+
;; search_all(query_ptr, db_ptr, scores_ptr, db_size, dimensions)
|
|
124
139
|
;; Computes dot products of query against every vector in the database.
|
|
125
|
-
;;
|
|
140
|
+
;; Optimized with:
|
|
141
|
+
;; - 2-vector outer loop unrolling (halves query memory reads)
|
|
142
|
+
;; - 4x inner loop unrolling (16 floats/iteration, 4 accumulators per vector)
|
|
126
143
|
(func (export "search_all") (param $query_ptr i32) (param $db_ptr i32) (param $scores_ptr i32) (param $db_size i32) (param $dim i32)
|
|
127
144
|
(local $i i32)
|
|
128
145
|
(local $j i32)
|
|
129
|
-
(local $
|
|
130
|
-
(local $
|
|
131
|
-
(local $
|
|
146
|
+
(local $accA0 v128)
|
|
147
|
+
(local $accA1 v128)
|
|
148
|
+
(local $accA2 v128)
|
|
149
|
+
(local $accA3 v128)
|
|
150
|
+
(local $accB0 v128)
|
|
151
|
+
(local $accB1 v128)
|
|
152
|
+
(local $accB2 v128)
|
|
153
|
+
(local $accB3 v128)
|
|
154
|
+
(local $q0 v128)
|
|
155
|
+
(local $q1 v128)
|
|
156
|
+
(local $q2 v128)
|
|
157
|
+
(local $q3 v128)
|
|
158
|
+
(local $dotA f32)
|
|
159
|
+
(local $dotB f32)
|
|
160
|
+
(local $vec_ptrA i32)
|
|
161
|
+
(local $vec_ptrB i32)
|
|
162
|
+
(local $unroll_end i32)
|
|
132
163
|
(local $simd_end i32)
|
|
133
|
-
(local $remainder i32)
|
|
134
164
|
(local $q_offset i32)
|
|
135
|
-
(local $
|
|
165
|
+
(local $vA_offset i32)
|
|
166
|
+
(local $vB_offset i32)
|
|
136
167
|
(local $bytes_per_vec i32)
|
|
168
|
+
(local $pair_end i32)
|
|
137
169
|
|
|
170
|
+
(local.set $unroll_end (i32.and (local.get $dim) (i32.const -16)))
|
|
138
171
|
(local.set $simd_end (i32.and (local.get $dim) (i32.const -4)))
|
|
139
172
|
(local.set $bytes_per_vec (i32.shl (local.get $dim) (i32.const 2)))
|
|
173
|
+
(local.set $pair_end (i32.and (local.get $db_size) (i32.const -2)))
|
|
140
174
|
(local.set $i (i32.const 0))
|
|
141
175
|
|
|
176
|
+
;; Main loop: process 2 database vectors per iteration
|
|
142
177
|
(block $break_outer
|
|
143
178
|
(loop $loop_outer
|
|
144
|
-
(br_if $break_outer (i32.ge_u (local.get $i) (local.get $
|
|
179
|
+
(br_if $break_outer (i32.ge_u (local.get $i) (local.get $pair_end)))
|
|
145
180
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
181
|
+
(local.set $vec_ptrA
|
|
182
|
+
(i32.add (local.get $db_ptr) (i32.mul (local.get $i) (local.get $bytes_per_vec))))
|
|
183
|
+
(local.set $vec_ptrB
|
|
184
|
+
(i32.add (local.get $vec_ptrA) (local.get $bytes_per_vec)))
|
|
150
185
|
|
|
151
|
-
|
|
152
|
-
(local.set $
|
|
186
|
+
(local.set $accA0 (v128.const f32x4 0 0 0 0))
|
|
187
|
+
(local.set $accA1 (v128.const f32x4 0 0 0 0))
|
|
188
|
+
(local.set $accA2 (v128.const f32x4 0 0 0 0))
|
|
189
|
+
(local.set $accA3 (v128.const f32x4 0 0 0 0))
|
|
190
|
+
(local.set $accB0 (v128.const f32x4 0 0 0 0))
|
|
191
|
+
(local.set $accB1 (v128.const f32x4 0 0 0 0))
|
|
192
|
+
(local.set $accB2 (v128.const f32x4 0 0 0 0))
|
|
193
|
+
(local.set $accB3 (v128.const f32x4 0 0 0 0))
|
|
153
194
|
(local.set $j (i32.const 0))
|
|
154
195
|
|
|
196
|
+
;; Inner loop: load query once, dot product against both vectors
|
|
155
197
|
(block $break_inner
|
|
156
198
|
(loop $loop_inner
|
|
157
|
-
(br_if $break_inner (i32.ge_u (local.get $j) (local.get $
|
|
199
|
+
(br_if $break_inner (i32.ge_u (local.get $j) (local.get $unroll_end)))
|
|
158
200
|
(local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
|
|
159
|
-
(local.set $
|
|
160
|
-
(local.set $
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
(local.set $
|
|
201
|
+
(local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
|
|
202
|
+
(local.set $vB_offset (i32.add (local.get $vec_ptrB) (i32.shl (local.get $j) (i32.const 2))))
|
|
203
|
+
|
|
204
|
+
;; Load query chunk (shared between both vectors)
|
|
205
|
+
(local.set $q0 (v128.load (local.get $q_offset)))
|
|
206
|
+
(local.set $q1 (v128.load offset=16 (local.get $q_offset)))
|
|
207
|
+
(local.set $q2 (v128.load offset=32 (local.get $q_offset)))
|
|
208
|
+
(local.set $q3 (v128.load offset=48 (local.get $q_offset)))
|
|
209
|
+
|
|
210
|
+
;; Vector A
|
|
211
|
+
(local.set $accA0 (f32x4.add (local.get $accA0)
|
|
212
|
+
(f32x4.mul (local.get $q0) (v128.load (local.get $vA_offset)))))
|
|
213
|
+
(local.set $accA1 (f32x4.add (local.get $accA1)
|
|
214
|
+
(f32x4.mul (local.get $q1) (v128.load offset=16 (local.get $vA_offset)))))
|
|
215
|
+
(local.set $accA2 (f32x4.add (local.get $accA2)
|
|
216
|
+
(f32x4.mul (local.get $q2) (v128.load offset=32 (local.get $vA_offset)))))
|
|
217
|
+
(local.set $accA3 (f32x4.add (local.get $accA3)
|
|
218
|
+
(f32x4.mul (local.get $q3) (v128.load offset=48 (local.get $vA_offset)))))
|
|
219
|
+
|
|
220
|
+
;; Vector B (reuses query loads)
|
|
221
|
+
(local.set $accB0 (f32x4.add (local.get $accB0)
|
|
222
|
+
(f32x4.mul (local.get $q0) (v128.load (local.get $vB_offset)))))
|
|
223
|
+
(local.set $accB1 (f32x4.add (local.get $accB1)
|
|
224
|
+
(f32x4.mul (local.get $q1) (v128.load offset=16 (local.get $vB_offset)))))
|
|
225
|
+
(local.set $accB2 (f32x4.add (local.get $accB2)
|
|
226
|
+
(f32x4.mul (local.get $q2) (v128.load offset=32 (local.get $vB_offset)))))
|
|
227
|
+
(local.set $accB3 (f32x4.add (local.get $accB3)
|
|
228
|
+
(f32x4.mul (local.get $q3) (v128.load offset=48 (local.get $vB_offset)))))
|
|
229
|
+
|
|
230
|
+
(local.set $j (i32.add (local.get $j) (i32.const 16)))
|
|
170
231
|
(br $loop_inner)
|
|
171
232
|
)
|
|
172
233
|
)
|
|
173
234
|
|
|
174
|
-
;;
|
|
175
|
-
(local.set $
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
)
|
|
235
|
+
;; Merge accumulators
|
|
236
|
+
(local.set $accA0 (f32x4.add (f32x4.add (local.get $accA0) (local.get $accA1))
|
|
237
|
+
(f32x4.add (local.get $accA2) (local.get $accA3))))
|
|
238
|
+
(local.set $accB0 (f32x4.add (f32x4.add (local.get $accB0) (local.get $accB1))
|
|
239
|
+
(f32x4.add (local.get $accB2) (local.get $accB3))))
|
|
240
|
+
|
|
241
|
+
;; 4-wide cleanup (both vectors)
|
|
242
|
+
(block $break_inner4
|
|
243
|
+
(loop $loop_inner4
|
|
244
|
+
(br_if $break_inner4 (i32.ge_u (local.get $j) (local.get $simd_end)))
|
|
245
|
+
(local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
|
|
246
|
+
(local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
|
|
247
|
+
(local.set $vB_offset (i32.add (local.get $vec_ptrB) (i32.shl (local.get $j) (i32.const 2))))
|
|
248
|
+
(local.set $q0 (v128.load (local.get $q_offset)))
|
|
249
|
+
(local.set $accA0 (f32x4.add (local.get $accA0) (f32x4.mul (local.get $q0) (v128.load (local.get $vA_offset)))))
|
|
250
|
+
(local.set $accB0 (f32x4.add (local.get $accB0) (f32x4.mul (local.get $q0) (v128.load (local.get $vB_offset)))))
|
|
251
|
+
(local.set $j (i32.add (local.get $j) (i32.const 4)))
|
|
252
|
+
(br $loop_inner4)
|
|
185
253
|
)
|
|
186
254
|
)
|
|
187
255
|
|
|
188
|
-
;;
|
|
189
|
-
(local.set $
|
|
256
|
+
;; Horizontal sums
|
|
257
|
+
(local.set $dotA
|
|
258
|
+
(f32.add
|
|
259
|
+
(f32.add (f32x4.extract_lane 0 (local.get $accA0)) (f32x4.extract_lane 1 (local.get $accA0)))
|
|
260
|
+
(f32.add (f32x4.extract_lane 2 (local.get $accA0)) (f32x4.extract_lane 3 (local.get $accA0)))))
|
|
261
|
+
(local.set $dotB
|
|
262
|
+
(f32.add
|
|
263
|
+
(f32.add (f32x4.extract_lane 0 (local.get $accB0)) (f32x4.extract_lane 1 (local.get $accB0)))
|
|
264
|
+
(f32.add (f32x4.extract_lane 2 (local.get $accB0)) (f32x4.extract_lane 3 (local.get $accB0)))))
|
|
265
|
+
|
|
266
|
+
;; Scalar remainder (both vectors)
|
|
190
267
|
(block $break_rem
|
|
191
268
|
(loop $loop_rem
|
|
192
|
-
(br_if $break_rem (i32.ge_u (local.get $
|
|
193
|
-
(local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $
|
|
194
|
-
(local.set $
|
|
195
|
-
(local.set $
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
)
|
|
202
|
-
)
|
|
203
|
-
)
|
|
204
|
-
(local.set $remainder (i32.add (local.get $remainder) (i32.const 1)))
|
|
269
|
+
(br_if $break_rem (i32.ge_u (local.get $j) (local.get $dim)))
|
|
270
|
+
(local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
|
|
271
|
+
(local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
|
|
272
|
+
(local.set $vB_offset (i32.add (local.get $vec_ptrB) (i32.shl (local.get $j) (i32.const 2))))
|
|
273
|
+
(local.set $dotA (f32.add (local.get $dotA)
|
|
274
|
+
(f32.mul (f32.load (local.get $q_offset)) (f32.load (local.get $vA_offset)))))
|
|
275
|
+
(local.set $dotB (f32.add (local.get $dotB)
|
|
276
|
+
(f32.mul (f32.load (local.get $q_offset)) (f32.load (local.get $vB_offset)))))
|
|
277
|
+
(local.set $j (i32.add (local.get $j) (i32.const 1)))
|
|
205
278
|
(br $loop_rem)
|
|
206
279
|
)
|
|
207
280
|
)
|
|
208
281
|
|
|
209
|
-
;; Store
|
|
282
|
+
;; Store scores for vectors i and i+1
|
|
210
283
|
(f32.store
|
|
211
284
|
(i32.add (local.get $scores_ptr) (i32.shl (local.get $i) (i32.const 2)))
|
|
212
|
-
(local.get $
|
|
213
|
-
|
|
285
|
+
(local.get $dotA))
|
|
286
|
+
(f32.store
|
|
287
|
+
(i32.add (local.get $scores_ptr) (i32.shl (i32.add (local.get $i) (i32.const 1)) (i32.const 2)))
|
|
288
|
+
(local.get $dotB))
|
|
214
289
|
|
|
215
|
-
(local.set $i (i32.add (local.get $i) (i32.const
|
|
290
|
+
(local.set $i (i32.add (local.get $i) (i32.const 2)))
|
|
216
291
|
(br $loop_outer)
|
|
217
292
|
)
|
|
218
293
|
)
|
|
294
|
+
|
|
295
|
+
;; Handle last vector if db_size is odd
|
|
296
|
+
(if (i32.lt_u (local.get $i) (local.get $db_size))
|
|
297
|
+
(then
|
|
298
|
+
(local.set $vec_ptrA
|
|
299
|
+
(i32.add (local.get $db_ptr) (i32.mul (local.get $i) (local.get $bytes_per_vec))))
|
|
300
|
+
(local.set $accA0 (v128.const f32x4 0 0 0 0))
|
|
301
|
+
(local.set $accA1 (v128.const f32x4 0 0 0 0))
|
|
302
|
+
(local.set $accA2 (v128.const f32x4 0 0 0 0))
|
|
303
|
+
(local.set $accA3 (v128.const f32x4 0 0 0 0))
|
|
304
|
+
(local.set $j (i32.const 0))
|
|
305
|
+
|
|
306
|
+
(block $break_last
|
|
307
|
+
(loop $loop_last
|
|
308
|
+
(br_if $break_last (i32.ge_u (local.get $j) (local.get $unroll_end)))
|
|
309
|
+
(local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
|
|
310
|
+
(local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
|
|
311
|
+
(local.set $accA0 (f32x4.add (local.get $accA0)
|
|
312
|
+
(f32x4.mul (v128.load (local.get $q_offset)) (v128.load (local.get $vA_offset)))))
|
|
313
|
+
(local.set $accA1 (f32x4.add (local.get $accA1)
|
|
314
|
+
(f32x4.mul (v128.load offset=16 (local.get $q_offset)) (v128.load offset=16 (local.get $vA_offset)))))
|
|
315
|
+
(local.set $accA2 (f32x4.add (local.get $accA2)
|
|
316
|
+
(f32x4.mul (v128.load offset=32 (local.get $q_offset)) (v128.load offset=32 (local.get $vA_offset)))))
|
|
317
|
+
(local.set $accA3 (f32x4.add (local.get $accA3)
|
|
318
|
+
(f32x4.mul (v128.load offset=48 (local.get $q_offset)) (v128.load offset=48 (local.get $vA_offset)))))
|
|
319
|
+
(local.set $j (i32.add (local.get $j) (i32.const 16)))
|
|
320
|
+
(br $loop_last)
|
|
321
|
+
)
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
(local.set $accA0 (f32x4.add (f32x4.add (local.get $accA0) (local.get $accA1))
|
|
325
|
+
(f32x4.add (local.get $accA2) (local.get $accA3))))
|
|
326
|
+
|
|
327
|
+
(block $break_last4
|
|
328
|
+
(loop $loop_last4
|
|
329
|
+
(br_if $break_last4 (i32.ge_u (local.get $j) (local.get $simd_end)))
|
|
330
|
+
(local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
|
|
331
|
+
(local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
|
|
332
|
+
(local.set $accA0 (f32x4.add (local.get $accA0)
|
|
333
|
+
(f32x4.mul (v128.load (local.get $q_offset)) (v128.load (local.get $vA_offset)))))
|
|
334
|
+
(local.set $j (i32.add (local.get $j) (i32.const 4)))
|
|
335
|
+
(br $loop_last4)
|
|
336
|
+
)
|
|
337
|
+
)
|
|
338
|
+
|
|
339
|
+
(local.set $dotA
|
|
340
|
+
(f32.add
|
|
341
|
+
(f32.add (f32x4.extract_lane 0 (local.get $accA0)) (f32x4.extract_lane 1 (local.get $accA0)))
|
|
342
|
+
(f32.add (f32x4.extract_lane 2 (local.get $accA0)) (f32x4.extract_lane 3 (local.get $accA0)))))
|
|
343
|
+
|
|
344
|
+
(block $break_last_rem
|
|
345
|
+
(loop $loop_last_rem
|
|
346
|
+
(br_if $break_last_rem (i32.ge_u (local.get $j) (local.get $dim)))
|
|
347
|
+
(local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
|
|
348
|
+
(local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
|
|
349
|
+
(local.set $dotA (f32.add (local.get $dotA)
|
|
350
|
+
(f32.mul (f32.load (local.get $q_offset)) (f32.load (local.get $vA_offset)))))
|
|
351
|
+
(local.set $j (i32.add (local.get $j) (i32.const 1)))
|
|
352
|
+
(br $loop_last_rem)
|
|
353
|
+
)
|
|
354
|
+
)
|
|
355
|
+
|
|
356
|
+
(f32.store
|
|
357
|
+
(i32.add (local.get $scores_ptr) (i32.shl (local.get $i) (i32.const 2)))
|
|
358
|
+
(local.get $dotA))
|
|
359
|
+
)
|
|
360
|
+
)
|
|
219
361
|
)
|
|
220
362
|
)
|
package/src/lib/types.ts
CHANGED
|
@@ -39,8 +39,10 @@ export interface SetOptions {
|
|
|
39
39
|
* Returns a plain ResultItem[] array by default.
|
|
40
40
|
*/
|
|
41
41
|
export interface QueryOptions {
|
|
42
|
-
/** Maximum number of results to return. Defaults to all. */
|
|
42
|
+
/** Maximum number of results to return. Defaults to Infinity (all results). */
|
|
43
43
|
topK?: number;
|
|
44
|
+
/** Maximum distance threshold (inclusive). Results with distance > maxDistance are excluded. */
|
|
45
|
+
maxDistance?: number;
|
|
44
46
|
/** Override normalization for this call. */
|
|
45
47
|
normalize?: boolean;
|
|
46
48
|
/** When true, returns an Iterable<ResultItem> instead of ResultItem[]. */
|
package/src/lib/vector-db.ts
CHANGED
|
@@ -185,16 +185,20 @@ export class VectorDB {
|
|
|
185
185
|
/**
|
|
186
186
|
* Search for the most similar vectors to the given query vector.
|
|
187
187
|
*
|
|
188
|
-
* Default: returns a plain ResultItem[] sorted by
|
|
188
|
+
* Default: returns a plain ResultItem[] sorted by ascending distance.
|
|
189
189
|
* With `{ iterable: true }`: returns a lazy Iterable<ResultItem> where keys
|
|
190
190
|
* are resolved only as each item is consumed.
|
|
191
|
+
*
|
|
192
|
+
* Distance is defined as `1 - dotProduct`. With normalization (default),
|
|
193
|
+
* this equals cosine distance: 0 = identical, 2 = opposite.
|
|
191
194
|
*/
|
|
192
195
|
query(value: VectorInput, options: QueryOptions & { iterable: true }): Iterable<ResultItem>;
|
|
193
196
|
query(value: VectorInput, options?: QueryOptions): ResultItem[];
|
|
194
197
|
query(value: VectorInput, options?: QueryOptions): ResultItem[] | Iterable<ResultItem> {
|
|
195
198
|
this.assertOpen();
|
|
196
199
|
|
|
197
|
-
const k = options?.topK ??
|
|
200
|
+
const k = options?.topK ?? Infinity;
|
|
201
|
+
const maxDistance = options?.maxDistance;
|
|
198
202
|
const iterable = options && "iterable" in options && options.iterable;
|
|
199
203
|
|
|
200
204
|
if (this.size === 0) {
|
|
@@ -256,9 +260,9 @@ export class VectorDB {
|
|
|
256
260
|
};
|
|
257
261
|
|
|
258
262
|
if (iterable) {
|
|
259
|
-
return iterableResults(scores, resolveKey, k);
|
|
263
|
+
return iterableResults(scores, resolveKey, k, maxDistance);
|
|
260
264
|
}
|
|
261
|
-
return topKResults(scores, resolveKey, k);
|
|
265
|
+
return topKResults(scores, resolveKey, k, maxDistance);
|
|
262
266
|
}
|
|
263
267
|
|
|
264
268
|
/**
|