eigen-db 4.0.2 → 4.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/lib/simd.wat CHANGED
@@ -4,217 +4,359 @@
4
4
 
5
5
  ;; normalize(ptr: i32, dimensions: i32)
6
6
  ;; Normalizes a vector in-place to unit length using SIMD.
7
- ;; ptr: byte offset of the vector in memory
8
- ;; dimensions: number of f32 elements (must be a multiple of 4 for SIMD path)
7
+ ;; Optimized with 4x loop unrolling (16 floats/iteration) and multiple accumulators.
9
8
  (func (export "normalize") (param $ptr i32) (param $dim i32)
10
9
  (local $i i32)
11
- (local $sum_vec v128)
10
+ (local $acc0 v128)
11
+ (local $acc1 v128)
12
+ (local $acc2 v128)
13
+ (local $acc3 v128)
12
14
  (local $sum f32)
13
15
  (local $mag f32)
14
16
  (local $inv_mag f32)
15
17
  (local $inv_vec v128)
18
+ (local $unroll_end i32)
16
19
  (local $simd_end i32)
17
- (local $remainder i32)
18
20
  (local $offset i32)
19
21
 
20
- ;; Phase 1: Compute sum of squares using SIMD (4 floats at a time)
21
- (local.set $sum_vec (v128.const f32x4 0 0 0 0))
22
- (local.set $simd_end (i32.and (local.get $dim) (i32.const -4))) ;; dim & ~3
22
+ ;; Phase 1: Sum of squares with 4x unroll and 4 independent accumulators
23
+ (local.set $acc0 (v128.const f32x4 0 0 0 0))
24
+ (local.set $acc1 (v128.const f32x4 0 0 0 0))
25
+ (local.set $acc2 (v128.const f32x4 0 0 0 0))
26
+ (local.set $acc3 (v128.const f32x4 0 0 0 0))
27
+ (local.set $unroll_end (i32.and (local.get $dim) (i32.const -16)))
28
+ (local.set $simd_end (i32.and (local.get $dim) (i32.const -4)))
23
29
  (local.set $i (i32.const 0))
24
30
 
25
- (block $break_sum
26
- (loop $loop_sum
27
- (br_if $break_sum (i32.ge_u (local.get $i) (local.get $simd_end)))
31
+ (block $break_sum_u
32
+ (loop $loop_sum_u
33
+ (br_if $break_sum_u (i32.ge_u (local.get $i) (local.get $unroll_end)))
28
34
  (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
29
- (local.set $sum_vec
30
- (f32x4.add
31
- (local.get $sum_vec)
32
- (f32x4.mul
33
- (v128.load (local.get $offset))
34
- (v128.load (local.get $offset))
35
- )
36
- )
37
- )
35
+
36
+ (local.set $acc0 (f32x4.add (local.get $acc0)
37
+ (f32x4.mul (v128.load (local.get $offset)) (v128.load (local.get $offset)))))
38
+ (local.set $acc1 (f32x4.add (local.get $acc1)
39
+ (f32x4.mul (v128.load offset=16 (local.get $offset)) (v128.load offset=16 (local.get $offset)))))
40
+ (local.set $acc2 (f32x4.add (local.get $acc2)
41
+ (f32x4.mul (v128.load offset=32 (local.get $offset)) (v128.load offset=32 (local.get $offset)))))
42
+ (local.set $acc3 (f32x4.add (local.get $acc3)
43
+ (f32x4.mul (v128.load offset=48 (local.get $offset)) (v128.load offset=48 (local.get $offset)))))
44
+
45
+ (local.set $i (i32.add (local.get $i) (i32.const 16)))
46
+ (br $loop_sum_u)
47
+ )
48
+ )
49
+
50
+ ;; Merge 4 accumulators
51
+ (local.set $acc0 (f32x4.add (f32x4.add (local.get $acc0) (local.get $acc1))
52
+ (f32x4.add (local.get $acc2) (local.get $acc3))))
53
+
54
+ ;; Remaining 4-wide chunks
55
+ (block $break_sum4
56
+ (loop $loop_sum4
57
+ (br_if $break_sum4 (i32.ge_u (local.get $i) (local.get $simd_end)))
58
+ (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
59
+ (local.set $acc0 (f32x4.add (local.get $acc0)
60
+ (f32x4.mul (v128.load (local.get $offset)) (v128.load (local.get $offset)))))
38
61
  (local.set $i (i32.add (local.get $i) (i32.const 4)))
39
- (br $loop_sum)
62
+ (br $loop_sum4)
40
63
  )
41
64
  )
42
65
 
43
- ;; Horizontal sum of SIMD lanes
66
+ ;; Horizontal sum
44
67
  (local.set $sum
45
68
  (f32.add
46
- (f32.add
47
- (f32x4.extract_lane 0 (local.get $sum_vec))
48
- (f32x4.extract_lane 1 (local.get $sum_vec))
49
- )
50
- (f32.add
51
- (f32x4.extract_lane 2 (local.get $sum_vec))
52
- (f32x4.extract_lane 3 (local.get $sum_vec))
53
- )
54
- )
55
- )
69
+ (f32.add (f32x4.extract_lane 0 (local.get $acc0)) (f32x4.extract_lane 1 (local.get $acc0)))
70
+ (f32.add (f32x4.extract_lane 2 (local.get $acc0)) (f32x4.extract_lane 3 (local.get $acc0)))))
56
71
 
57
- ;; Handle remainder elements (dim % 4)
58
- (local.set $remainder (local.get $simd_end))
72
+ ;; Scalar remainder
59
73
  (block $break_rem_sum
60
74
  (loop $loop_rem_sum
61
- (br_if $break_rem_sum (i32.ge_u (local.get $remainder) (local.get $dim)))
62
- (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $remainder) (i32.const 2))))
63
- (local.set $sum
64
- (f32.add
65
- (local.get $sum)
66
- (f32.mul
67
- (f32.load (local.get $offset))
68
- (f32.load (local.get $offset))
69
- )
70
- )
71
- )
72
- (local.set $remainder (i32.add (local.get $remainder) (i32.const 1)))
75
+ (br_if $break_rem_sum (i32.ge_u (local.get $i) (local.get $dim)))
76
+ (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
77
+ (local.set $sum (f32.add (local.get $sum)
78
+ (f32.mul (f32.load (local.get $offset)) (f32.load (local.get $offset)))))
79
+ (local.set $i (i32.add (local.get $i) (i32.const 1)))
73
80
  (br $loop_rem_sum)
74
81
  )
75
82
  )
76
83
 
77
- ;; Compute magnitude and check for zero
84
+ ;; Magnitude check
78
85
  (local.set $mag (f32.sqrt (local.get $sum)))
79
86
  (if (f32.eq (local.get $mag) (f32.const 0))
80
- (then (return))
81
- )
87
+ (then (return)))
82
88
 
83
- ;; Phase 2: Divide each element by magnitude using SIMD
89
+ ;; Phase 2: Scale by inverse magnitude (4x unrolled)
84
90
  (local.set $inv_mag (f32.div (f32.const 1) (local.get $mag)))
85
91
  (local.set $inv_vec (f32x4.splat (local.get $inv_mag)))
86
92
  (local.set $i (i32.const 0))
87
93
 
88
- (block $break_norm
89
- (loop $loop_norm
90
- (br_if $break_norm (i32.ge_u (local.get $i) (local.get $simd_end)))
94
+ (block $break_norm_u
95
+ (loop $loop_norm_u
96
+ (br_if $break_norm_u (i32.ge_u (local.get $i) (local.get $unroll_end)))
91
97
  (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
92
- (v128.store
93
- (local.get $offset)
94
- (f32x4.mul
95
- (v128.load (local.get $offset))
96
- (local.get $inv_vec)
97
- )
98
- )
98
+
99
+ (v128.store (local.get $offset)
100
+ (f32x4.mul (v128.load (local.get $offset)) (local.get $inv_vec)))
101
+ (v128.store offset=16 (local.get $offset)
102
+ (f32x4.mul (v128.load offset=16 (local.get $offset)) (local.get $inv_vec)))
103
+ (v128.store offset=32 (local.get $offset)
104
+ (f32x4.mul (v128.load offset=32 (local.get $offset)) (local.get $inv_vec)))
105
+ (v128.store offset=48 (local.get $offset)
106
+ (f32x4.mul (v128.load offset=48 (local.get $offset)) (local.get $inv_vec)))
107
+
108
+ (local.set $i (i32.add (local.get $i) (i32.const 16)))
109
+ (br $loop_norm_u)
110
+ )
111
+ )
112
+
113
+ ;; Remaining 4-wide chunks
114
+ (block $break_norm4
115
+ (loop $loop_norm4
116
+ (br_if $break_norm4 (i32.ge_u (local.get $i) (local.get $simd_end)))
117
+ (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
118
+ (v128.store (local.get $offset)
119
+ (f32x4.mul (v128.load (local.get $offset)) (local.get $inv_vec)))
99
120
  (local.set $i (i32.add (local.get $i) (i32.const 4)))
100
- (br $loop_norm)
121
+ (br $loop_norm4)
101
122
  )
102
123
  )
103
124
 
104
- ;; Handle remainder elements
105
- (local.set $remainder (local.get $simd_end))
125
+ ;; Scalar remainder
106
126
  (block $break_rem_norm
107
127
  (loop $loop_rem_norm
108
- (br_if $break_rem_norm (i32.ge_u (local.get $remainder) (local.get $dim)))
109
- (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $remainder) (i32.const 2))))
110
- (f32.store
111
- (local.get $offset)
112
- (f32.mul
113
- (f32.load (local.get $offset))
114
- (local.get $inv_mag)
115
- )
116
- )
117
- (local.set $remainder (i32.add (local.get $remainder) (i32.const 1)))
128
+ (br_if $break_rem_norm (i32.ge_u (local.get $i) (local.get $dim)))
129
+ (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
130
+ (f32.store (local.get $offset)
131
+ (f32.mul (f32.load (local.get $offset)) (local.get $inv_mag)))
132
+ (local.set $i (i32.add (local.get $i) (i32.const 1)))
118
133
  (br $loop_rem_norm)
119
134
  )
120
135
  )
121
136
  )
122
137
 
123
- ;; search_all(query_ptr: i32, db_ptr: i32, scores_ptr: i32, db_size: i32, dimensions: i32)
138
+ ;; search_all(query_ptr, db_ptr, scores_ptr, db_size, dimensions)
124
139
  ;; Computes dot products of query against every vector in the database.
125
- ;; Uses 128-bit SIMD for 4-wide f32 multiply-accumulate.
140
+ ;; Optimized with:
141
+ ;; - 2-vector outer loop unrolling (halves query memory reads)
142
+ ;; - 4x inner loop unrolling (16 floats/iteration, 4 accumulators per vector)
126
143
  (func (export "search_all") (param $query_ptr i32) (param $db_ptr i32) (param $scores_ptr i32) (param $db_size i32) (param $dim i32)
127
144
  (local $i i32)
128
145
  (local $j i32)
129
- (local $acc v128)
130
- (local $dot f32)
131
- (local $vec_ptr i32)
146
+ (local $accA0 v128)
147
+ (local $accA1 v128)
148
+ (local $accA2 v128)
149
+ (local $accA3 v128)
150
+ (local $accB0 v128)
151
+ (local $accB1 v128)
152
+ (local $accB2 v128)
153
+ (local $accB3 v128)
154
+ (local $q0 v128)
155
+ (local $q1 v128)
156
+ (local $q2 v128)
157
+ (local $q3 v128)
158
+ (local $dotA f32)
159
+ (local $dotB f32)
160
+ (local $vec_ptrA i32)
161
+ (local $vec_ptrB i32)
162
+ (local $unroll_end i32)
132
163
  (local $simd_end i32)
133
- (local $remainder i32)
134
164
  (local $q_offset i32)
135
- (local $v_offset i32)
165
+ (local $vA_offset i32)
166
+ (local $vB_offset i32)
136
167
  (local $bytes_per_vec i32)
168
+ (local $pair_end i32)
137
169
 
170
+ (local.set $unroll_end (i32.and (local.get $dim) (i32.const -16)))
138
171
  (local.set $simd_end (i32.and (local.get $dim) (i32.const -4)))
139
172
  (local.set $bytes_per_vec (i32.shl (local.get $dim) (i32.const 2)))
173
+ (local.set $pair_end (i32.and (local.get $db_size) (i32.const -2)))
140
174
  (local.set $i (i32.const 0))
141
175
 
176
+ ;; Main loop: process 2 database vectors per iteration
142
177
  (block $break_outer
143
178
  (loop $loop_outer
144
- (br_if $break_outer (i32.ge_u (local.get $i) (local.get $db_size)))
179
+ (br_if $break_outer (i32.ge_u (local.get $i) (local.get $pair_end)))
145
180
 
146
- ;; Pointer to the i-th database vector
147
- (local.set $vec_ptr
148
- (i32.add (local.get $db_ptr) (i32.mul (local.get $i) (local.get $bytes_per_vec)))
149
- )
181
+ (local.set $vec_ptrA
182
+ (i32.add (local.get $db_ptr) (i32.mul (local.get $i) (local.get $bytes_per_vec))))
183
+ (local.set $vec_ptrB
184
+ (i32.add (local.get $vec_ptrA) (local.get $bytes_per_vec)))
150
185
 
151
- ;; SIMD dot product accumulator
152
- (local.set $acc (v128.const f32x4 0 0 0 0))
186
+ (local.set $accA0 (v128.const f32x4 0 0 0 0))
187
+ (local.set $accA1 (v128.const f32x4 0 0 0 0))
188
+ (local.set $accA2 (v128.const f32x4 0 0 0 0))
189
+ (local.set $accA3 (v128.const f32x4 0 0 0 0))
190
+ (local.set $accB0 (v128.const f32x4 0 0 0 0))
191
+ (local.set $accB1 (v128.const f32x4 0 0 0 0))
192
+ (local.set $accB2 (v128.const f32x4 0 0 0 0))
193
+ (local.set $accB3 (v128.const f32x4 0 0 0 0))
153
194
  (local.set $j (i32.const 0))
154
195
 
196
+ ;; Inner loop: load query once, dot product against both vectors
155
197
  (block $break_inner
156
198
  (loop $loop_inner
157
- (br_if $break_inner (i32.ge_u (local.get $j) (local.get $simd_end)))
199
+ (br_if $break_inner (i32.ge_u (local.get $j) (local.get $unroll_end)))
158
200
  (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
159
- (local.set $v_offset (i32.add (local.get $vec_ptr) (i32.shl (local.get $j) (i32.const 2))))
160
- (local.set $acc
161
- (f32x4.add
162
- (local.get $acc)
163
- (f32x4.mul
164
- (v128.load (local.get $q_offset))
165
- (v128.load (local.get $v_offset))
166
- )
167
- )
168
- )
169
- (local.set $j (i32.add (local.get $j) (i32.const 4)))
201
+ (local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
202
+ (local.set $vB_offset (i32.add (local.get $vec_ptrB) (i32.shl (local.get $j) (i32.const 2))))
203
+
204
+ ;; Load query chunk (shared between both vectors)
205
+ (local.set $q0 (v128.load (local.get $q_offset)))
206
+ (local.set $q1 (v128.load offset=16 (local.get $q_offset)))
207
+ (local.set $q2 (v128.load offset=32 (local.get $q_offset)))
208
+ (local.set $q3 (v128.load offset=48 (local.get $q_offset)))
209
+
210
+ ;; Vector A
211
+ (local.set $accA0 (f32x4.add (local.get $accA0)
212
+ (f32x4.mul (local.get $q0) (v128.load (local.get $vA_offset)))))
213
+ (local.set $accA1 (f32x4.add (local.get $accA1)
214
+ (f32x4.mul (local.get $q1) (v128.load offset=16 (local.get $vA_offset)))))
215
+ (local.set $accA2 (f32x4.add (local.get $accA2)
216
+ (f32x4.mul (local.get $q2) (v128.load offset=32 (local.get $vA_offset)))))
217
+ (local.set $accA3 (f32x4.add (local.get $accA3)
218
+ (f32x4.mul (local.get $q3) (v128.load offset=48 (local.get $vA_offset)))))
219
+
220
+ ;; Vector B (reuses query loads)
221
+ (local.set $accB0 (f32x4.add (local.get $accB0)
222
+ (f32x4.mul (local.get $q0) (v128.load (local.get $vB_offset)))))
223
+ (local.set $accB1 (f32x4.add (local.get $accB1)
224
+ (f32x4.mul (local.get $q1) (v128.load offset=16 (local.get $vB_offset)))))
225
+ (local.set $accB2 (f32x4.add (local.get $accB2)
226
+ (f32x4.mul (local.get $q2) (v128.load offset=32 (local.get $vB_offset)))))
227
+ (local.set $accB3 (f32x4.add (local.get $accB3)
228
+ (f32x4.mul (local.get $q3) (v128.load offset=48 (local.get $vB_offset)))))
229
+
230
+ (local.set $j (i32.add (local.get $j) (i32.const 16)))
170
231
  (br $loop_inner)
171
232
  )
172
233
  )
173
234
 
174
- ;; Horizontal sum of SIMD accumulator
175
- (local.set $dot
176
- (f32.add
177
- (f32.add
178
- (f32x4.extract_lane 0 (local.get $acc))
179
- (f32x4.extract_lane 1 (local.get $acc))
180
- )
181
- (f32.add
182
- (f32x4.extract_lane 2 (local.get $acc))
183
- (f32x4.extract_lane 3 (local.get $acc))
184
- )
235
+ ;; Merge accumulators
236
+ (local.set $accA0 (f32x4.add (f32x4.add (local.get $accA0) (local.get $accA1))
237
+ (f32x4.add (local.get $accA2) (local.get $accA3))))
238
+ (local.set $accB0 (f32x4.add (f32x4.add (local.get $accB0) (local.get $accB1))
239
+ (f32x4.add (local.get $accB2) (local.get $accB3))))
240
+
241
+ ;; 4-wide cleanup (both vectors)
242
+ (block $break_inner4
243
+ (loop $loop_inner4
244
+ (br_if $break_inner4 (i32.ge_u (local.get $j) (local.get $simd_end)))
245
+ (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
246
+ (local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
247
+ (local.set $vB_offset (i32.add (local.get $vec_ptrB) (i32.shl (local.get $j) (i32.const 2))))
248
+ (local.set $q0 (v128.load (local.get $q_offset)))
249
+ (local.set $accA0 (f32x4.add (local.get $accA0) (f32x4.mul (local.get $q0) (v128.load (local.get $vA_offset)))))
250
+ (local.set $accB0 (f32x4.add (local.get $accB0) (f32x4.mul (local.get $q0) (v128.load (local.get $vB_offset)))))
251
+ (local.set $j (i32.add (local.get $j) (i32.const 4)))
252
+ (br $loop_inner4)
185
253
  )
186
254
  )
187
255
 
188
- ;; Handle remainder elements (dim % 4)
189
- (local.set $remainder (local.get $simd_end))
256
+ ;; Horizontal sums
257
+ (local.set $dotA
258
+ (f32.add
259
+ (f32.add (f32x4.extract_lane 0 (local.get $accA0)) (f32x4.extract_lane 1 (local.get $accA0)))
260
+ (f32.add (f32x4.extract_lane 2 (local.get $accA0)) (f32x4.extract_lane 3 (local.get $accA0)))))
261
+ (local.set $dotB
262
+ (f32.add
263
+ (f32.add (f32x4.extract_lane 0 (local.get $accB0)) (f32x4.extract_lane 1 (local.get $accB0)))
264
+ (f32.add (f32x4.extract_lane 2 (local.get $accB0)) (f32x4.extract_lane 3 (local.get $accB0)))))
265
+
266
+ ;; Scalar remainder (both vectors)
190
267
  (block $break_rem
191
268
  (loop $loop_rem
192
- (br_if $break_rem (i32.ge_u (local.get $remainder) (local.get $dim)))
193
- (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $remainder) (i32.const 2))))
194
- (local.set $v_offset (i32.add (local.get $vec_ptr) (i32.shl (local.get $remainder) (i32.const 2))))
195
- (local.set $dot
196
- (f32.add
197
- (local.get $dot)
198
- (f32.mul
199
- (f32.load (local.get $q_offset))
200
- (f32.load (local.get $v_offset))
201
- )
202
- )
203
- )
204
- (local.set $remainder (i32.add (local.get $remainder) (i32.const 1)))
269
+ (br_if $break_rem (i32.ge_u (local.get $j) (local.get $dim)))
270
+ (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
271
+ (local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
272
+ (local.set $vB_offset (i32.add (local.get $vec_ptrB) (i32.shl (local.get $j) (i32.const 2))))
273
+ (local.set $dotA (f32.add (local.get $dotA)
274
+ (f32.mul (f32.load (local.get $q_offset)) (f32.load (local.get $vA_offset)))))
275
+ (local.set $dotB (f32.add (local.get $dotB)
276
+ (f32.mul (f32.load (local.get $q_offset)) (f32.load (local.get $vB_offset)))))
277
+ (local.set $j (i32.add (local.get $j) (i32.const 1)))
205
278
  (br $loop_rem)
206
279
  )
207
280
  )
208
281
 
209
- ;; Store score for vector i
282
+ ;; Store scores for vectors i and i+1
210
283
  (f32.store
211
284
  (i32.add (local.get $scores_ptr) (i32.shl (local.get $i) (i32.const 2)))
212
- (local.get $dot)
213
- )
285
+ (local.get $dotA))
286
+ (f32.store
287
+ (i32.add (local.get $scores_ptr) (i32.shl (i32.add (local.get $i) (i32.const 1)) (i32.const 2)))
288
+ (local.get $dotB))
214
289
 
215
- (local.set $i (i32.add (local.get $i) (i32.const 1)))
290
+ (local.set $i (i32.add (local.get $i) (i32.const 2)))
216
291
  (br $loop_outer)
217
292
  )
218
293
  )
294
+
295
+ ;; Handle last vector if db_size is odd
296
+ (if (i32.lt_u (local.get $i) (local.get $db_size))
297
+ (then
298
+ (local.set $vec_ptrA
299
+ (i32.add (local.get $db_ptr) (i32.mul (local.get $i) (local.get $bytes_per_vec))))
300
+ (local.set $accA0 (v128.const f32x4 0 0 0 0))
301
+ (local.set $accA1 (v128.const f32x4 0 0 0 0))
302
+ (local.set $accA2 (v128.const f32x4 0 0 0 0))
303
+ (local.set $accA3 (v128.const f32x4 0 0 0 0))
304
+ (local.set $j (i32.const 0))
305
+
306
+ (block $break_last
307
+ (loop $loop_last
308
+ (br_if $break_last (i32.ge_u (local.get $j) (local.get $unroll_end)))
309
+ (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
310
+ (local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
311
+ (local.set $accA0 (f32x4.add (local.get $accA0)
312
+ (f32x4.mul (v128.load (local.get $q_offset)) (v128.load (local.get $vA_offset)))))
313
+ (local.set $accA1 (f32x4.add (local.get $accA1)
314
+ (f32x4.mul (v128.load offset=16 (local.get $q_offset)) (v128.load offset=16 (local.get $vA_offset)))))
315
+ (local.set $accA2 (f32x4.add (local.get $accA2)
316
+ (f32x4.mul (v128.load offset=32 (local.get $q_offset)) (v128.load offset=32 (local.get $vA_offset)))))
317
+ (local.set $accA3 (f32x4.add (local.get $accA3)
318
+ (f32x4.mul (v128.load offset=48 (local.get $q_offset)) (v128.load offset=48 (local.get $vA_offset)))))
319
+ (local.set $j (i32.add (local.get $j) (i32.const 16)))
320
+ (br $loop_last)
321
+ )
322
+ )
323
+
324
+ (local.set $accA0 (f32x4.add (f32x4.add (local.get $accA0) (local.get $accA1))
325
+ (f32x4.add (local.get $accA2) (local.get $accA3))))
326
+
327
+ (block $break_last4
328
+ (loop $loop_last4
329
+ (br_if $break_last4 (i32.ge_u (local.get $j) (local.get $simd_end)))
330
+ (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
331
+ (local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
332
+ (local.set $accA0 (f32x4.add (local.get $accA0)
333
+ (f32x4.mul (v128.load (local.get $q_offset)) (v128.load (local.get $vA_offset)))))
334
+ (local.set $j (i32.add (local.get $j) (i32.const 4)))
335
+ (br $loop_last4)
336
+ )
337
+ )
338
+
339
+ (local.set $dotA
340
+ (f32.add
341
+ (f32.add (f32x4.extract_lane 0 (local.get $accA0)) (f32x4.extract_lane 1 (local.get $accA0)))
342
+ (f32.add (f32x4.extract_lane 2 (local.get $accA0)) (f32x4.extract_lane 3 (local.get $accA0)))))
343
+
344
+ (block $break_last_rem
345
+ (loop $loop_last_rem
346
+ (br_if $break_last_rem (i32.ge_u (local.get $j) (local.get $dim)))
347
+ (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
348
+ (local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
349
+ (local.set $dotA (f32.add (local.get $dotA)
350
+ (f32.mul (f32.load (local.get $q_offset)) (f32.load (local.get $vA_offset)))))
351
+ (local.set $j (i32.add (local.get $j) (i32.const 1)))
352
+ (br $loop_last_rem)
353
+ )
354
+ )
355
+
356
+ (f32.store
357
+ (i32.add (local.get $scores_ptr) (i32.shl (local.get $i) (i32.const 2)))
358
+ (local.get $dotA))
359
+ )
360
+ )
219
361
  )
220
362
  )
package/src/lib/types.ts CHANGED
@@ -39,8 +39,10 @@ export interface SetOptions {
39
39
  * Returns a plain ResultItem[] array by default.
40
40
  */
41
41
  export interface QueryOptions {
42
- /** Maximum number of results to return. Defaults to all. */
42
+ /** Maximum number of results to return. Defaults to Infinity (all results). */
43
43
  topK?: number;
44
+ /** Maximum distance threshold (inclusive). Results with distance > maxDistance are excluded. */
45
+ maxDistance?: number;
44
46
  /** Override normalization for this call. */
45
47
  normalize?: boolean;
46
48
  /** When true, returns an Iterable<ResultItem> instead of ResultItem[]. */
@@ -185,16 +185,20 @@ export class VectorDB {
185
185
  /**
186
186
  * Search for the most similar vectors to the given query vector.
187
187
  *
188
- * Default: returns a plain ResultItem[] sorted by descending similarity.
188
+ * Default: returns a plain ResultItem[] sorted by ascending distance.
189
189
  * With `{ iterable: true }`: returns a lazy Iterable<ResultItem> where keys
190
190
  * are resolved only as each item is consumed.
191
+ *
192
+ * Distance is defined as `1 - dotProduct`. With normalization (default),
193
+ * this equals cosine distance: 0 = identical, 2 = opposite.
191
194
  */
192
195
  query(value: VectorInput, options: QueryOptions & { iterable: true }): Iterable<ResultItem>;
193
196
  query(value: VectorInput, options?: QueryOptions): ResultItem[];
194
197
  query(value: VectorInput, options?: QueryOptions): ResultItem[] | Iterable<ResultItem> {
195
198
  this.assertOpen();
196
199
 
197
- const k = options?.topK ?? this.size;
200
+ const k = options?.topK ?? Infinity;
201
+ const maxDistance = options?.maxDistance;
198
202
  const iterable = options && "iterable" in options && options.iterable;
199
203
 
200
204
  if (this.size === 0) {
@@ -256,9 +260,9 @@ export class VectorDB {
256
260
  };
257
261
 
258
262
  if (iterable) {
259
- return iterableResults(scores, resolveKey, k);
263
+ return iterableResults(scores, resolveKey, k, maxDistance);
260
264
  }
261
- return topKResults(scores, resolveKey, k);
265
+ return topKResults(scores, resolveKey, k, maxDistance);
262
266
  }
263
267
 
264
268
  /**