eigen-db 4.1.0 → 4.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +8 -0
- package/README.md +79 -27
- package/dist/eigen-db.js +317 -195
- package/dist/eigen-db.js.map +1 -1
- package/dist/eigen-db.umd.cjs +1 -1
- package/dist/eigen-db.umd.cjs.map +1 -1
- package/package.json +1 -1
- package/src/lib/__tests__/result-set.test.ts +19 -19
- package/src/lib/__tests__/vector-db.test.ts +429 -16
- package/src/lib/memory-manager.ts +8 -0
- package/src/lib/result-set.ts +16 -15
- package/src/lib/simd-binary.ts +1 -1
- package/src/lib/simd-optimized.wat +362 -0
- package/src/lib/simd.wat +42 -248
- package/src/lib/types.ts +4 -6
- package/src/lib/vector-db.ts +241 -9
package/src/lib/result-set.ts
CHANGED
|
@@ -7,27 +7,28 @@
|
|
|
7
7
|
* 2. iterableResults — returns a lazy Iterable<ResultItem> where keys are
|
|
8
8
|
* resolved only as each item is consumed (for pagination / streaming)
|
|
9
9
|
*
|
|
10
|
-
*
|
|
11
|
-
* this equals cosine
|
|
10
|
+
* Similarity is the dot product of query and stored vectors. For normalized
|
|
11
|
+
* vectors (the default), this equals cosine similarity, ranging from 1
|
|
12
|
+
* (identical) to -1 (opposite).
|
|
12
13
|
*/
|
|
13
14
|
|
|
14
15
|
/**
 * A single search hit: the key of a stored vector and its similarity to the query.
 * Result sets are ordered by descending `similarity`.
 */
export interface ResultItem {
  /** Key of the matched stored vector, produced by the {@link KeyResolver}. */
  key: string;
  /**
   * Dot product of the query and the stored vector. For normalized vectors this
   * equals cosine similarity, from 1 (identical) to -1 (opposite).
   */
  similarity: number;
}
|
|
18
19
|
|
|
19
20
|
/**
 * Maps a slot index (the position of a vector's score in the scores array)
 * to that vector's key. Called lazily by the result builders, so it only runs
 * for results that are actually returned or consumed.
 */
export type KeyResolver = (index: number) => string;
|
|
20
21
|
|
|
21
22
|
/**
|
|
22
|
-
* Sort by
|
|
23
|
+
* Sort by descending similarity and return the top K results as a plain array.
|
|
23
24
|
* All keys are resolved eagerly.
|
|
24
|
-
* If
|
|
25
|
+
* If minSimilarity is provided, results with similarity < minSimilarity are excluded.
|
|
25
26
|
*/
|
|
26
27
|
export function topKResults(
|
|
27
28
|
scores: Float32Array,
|
|
28
29
|
resolveKey: KeyResolver,
|
|
29
30
|
topK: number,
|
|
30
|
-
|
|
31
|
+
minSimilarity?: number,
|
|
31
32
|
): ResultItem[] {
|
|
32
33
|
const n = scores.length;
|
|
33
34
|
if (n === 0) return [];
|
|
@@ -40,19 +41,19 @@ export function topKResults(
|
|
|
40
41
|
const results: ResultItem[] = [];
|
|
41
42
|
for (let i = 0; i < k; i++) {
|
|
42
43
|
const idx = indices[i];
|
|
43
|
-
const
|
|
44
|
-
if (
|
|
45
|
-
results.push({ key: resolveKey(idx),
|
|
44
|
+
const similarity = scores[idx];
|
|
45
|
+
if (minSimilarity !== undefined && similarity < minSimilarity) break;
|
|
46
|
+
results.push({ key: resolveKey(idx), similarity });
|
|
46
47
|
}
|
|
47
48
|
return results;
|
|
48
49
|
}
|
|
49
50
|
|
|
50
51
|
/**
|
|
51
|
-
* Sort by
|
|
52
|
+
* Sort by descending similarity and return a lazy iterable over the top K results.
|
|
52
53
|
* Keys are resolved only when each item is consumed, saving allocations
|
|
53
54
|
* when the caller iterates partially (e.g., pagination).
|
|
54
55
|
*
|
|
55
|
-
* If
|
|
56
|
+
* If minSimilarity is provided, iteration stops when similarity < minSimilarity.
|
|
56
57
|
*
|
|
57
58
|
* The returned iterable is re-iterable — each call to [Symbol.iterator]()
|
|
58
59
|
* produces a fresh cursor over the same pre-sorted data.
|
|
@@ -61,7 +62,7 @@ export function iterableResults(
|
|
|
61
62
|
scores: Float32Array,
|
|
62
63
|
resolveKey: KeyResolver,
|
|
63
64
|
topK: number,
|
|
64
|
-
|
|
65
|
+
minSimilarity?: number,
|
|
65
66
|
): Iterable<ResultItem> {
|
|
66
67
|
const n = scores.length;
|
|
67
68
|
if (n === 0) return [];
|
|
@@ -79,11 +80,11 @@ export function iterableResults(
|
|
|
79
80
|
next(): IteratorResult<ResultItem> {
|
|
80
81
|
if (i >= k) return { done: true, value: undefined };
|
|
81
82
|
const idx = indices[i++];
|
|
82
|
-
const
|
|
83
|
-
if (
|
|
83
|
+
const similarity = scores[idx];
|
|
84
|
+
if (minSimilarity !== undefined && similarity < minSimilarity) return { done: true, value: undefined };
|
|
84
85
|
return {
|
|
85
86
|
done: false,
|
|
86
|
-
value: { key: resolveKey(idx),
|
|
87
|
+
value: { key: resolveKey(idx), similarity },
|
|
87
88
|
};
|
|
88
89
|
},
|
|
89
90
|
};
|
package/src/lib/simd-binary.ts
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
// AUTO-GENERATED - Do not edit. Run: npx tsx scripts/compile-wat.ts
|
|
2
|
-
const SIMD_WASM_BASE64 = "AGFzbQEAAAABDgJgAn9/AGAFf39/
|
|
2
|
+
const SIMD_WASM_BASE64 = "AGFzbQEAAAABDgJgAn9/AGAFf39/f38AAg8BA2VudgZtZW1vcnkCAAEDAwIAAQcaAglub3JtYWxpemUAAApzZWFyY2hfYWxsAAEKpgQCrQIFAX8BewN9AXsCf/0MAAAAAAAAAAAAAAAAAAAAACEDIAFBfHEhCEEAIQICQANAIAIgCE8NASAAIAJBAnRqIQkgAyAJ/QAEACAJ/QAEAP3mAf3kASEDIAJBBGohAgwACwsgA/0fACAD/R8BkiAD/R8CIAP9HwOSkiEEAkADQCACIAFPDQEgACACQQJ0aiEJIAQgCSoCACAJKgIAlJIhBCACQQFqIQIMAAsLIASRIQUgBUMAAAAAWwRADwtDAACAPyAFlSEGIAb9EyEHQQAhAgJAA0AgAiAITw0BIAAgAkECdGohCSAJIAn9AAQAIAf95gH9CwQAIAJBBGohAgwACwsCQANAIAIgAU8NASAAIAJBAnRqIQkgCSAJKgIAIAaUOAIAIAJBAWohAgwACwsL9AEEAn8BewF9BX8gBEF8cSEKIARBAnQhDUEAIQUCQANAIAUgA08NASABIAUgDWxqIQn9DAAAAAAAAAAAAAAAAAAAAAAhB0EAIQYCQANAIAYgCk8NASAAIAZBAnRqIQsgCSAGQQJ0aiEMIAcgC/0ABAAgDP0ABAD95gH95AEhByAGQQRqIQYMAAsLIAf9HwAgB/0fAZIgB/0fAiAH/R8DkpIhCAJAA0AgBiAETw0BIAAgBkECdGohCyAJIAZBAnRqIQwgCCALKgIAIAwqAgCUkiEIIAZBAWohBgwACwsgAiAFQQJ0aiAIOAIAIAVBAWohBQwACwsL";
|
|
3
3
|
|
|
4
4
|
export function getSimdWasmBinary(): Uint8Array {
|
|
5
5
|
const binaryString = atob(SIMD_WASM_BASE64);
|
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
;; SIMD kernels for in-place vector normalization and brute-force dot-product
;; search over a flat f32 database.
;; Hand-unrolled variant: each hot loop processes 16 f32 values (four v128
;; vectors) per iteration, then finishes with 4-wide and scalar tail loops.
;; NOTE(review): neither function validates its pointers or lengths. Any
;; out-of-bounds access traps under standard Wasm semantics, so callers must
;; keep every buffer inside the imported memory.
(module
  ;; Linear memory is imported from the JavaScript host (module "env",
  ;; field "memory", minimum 1 page), so data written by JS is read here directly.
  (import "env" "memory" (memory 1))

  ;; normalize(ptr: i32, dimensions: i32)
  ;; Scales the `dim` f32 values starting at byte address `ptr` in place so
  ;; that their Euclidean (L2) norm becomes 1.
  ;;  - Element k lives at byte address ptr + k * 4.
  ;;  - If the norm is exactly 0, the function returns early and the vector is
  ;;    left unchanged.
  ;;    NOTE(review): only an exact zero is special-cased. A NaN or +inf norm
  ;;    (e.g. overflow in the sum of squares) still goes through the scaling phase.
  ;;  - Four independent accumulators are used in the unrolled loop, so
  ;;    consecutive adds do not depend on each other. Float addition is not
  ;;    associative, so the norm may differ in the last bits from a strictly
  ;;    sequential sum.
  ;;  - Loop bounds are compared unsigned (i32.ge_u).
  (func (export "normalize") (param $ptr i32) (param $dim i32)
    (local $i i32)
    (local $acc0 v128)
    (local $acc1 v128)
    (local $acc2 v128)
    (local $acc3 v128)
    (local $sum f32)
    (local $mag f32)
    (local $inv_mag f32)
    (local $inv_vec v128)
    (local $unroll_end i32)   ;; dim rounded down to a multiple of 16
    (local $simd_end i32)     ;; dim rounded down to a multiple of 4
    (local $offset i32)       ;; byte address of element $i

    ;; Phase 1: sum of squares, 16 elements per iteration spread across
    ;; 4 independent v128 accumulators (4 lanes each).
    (local.set $acc0 (v128.const f32x4 0 0 0 0))
    (local.set $acc1 (v128.const f32x4 0 0 0 0))
    (local.set $acc2 (v128.const f32x4 0 0 0 0))
    (local.set $acc3 (v128.const f32x4 0 0 0 0))
    ;; Clearing the low bits (AND with -16 / -4) rounds dim down to the
    ;; largest multiple of 16 / 4.
    (local.set $unroll_end (i32.and (local.get $dim) (i32.const -16)))
    (local.set $simd_end (i32.and (local.get $dim) (i32.const -4)))
    (local.set $i (i32.const 0))

    (block $break_sum_u
      (loop $loop_sum_u
        (br_if $break_sum_u (i32.ge_u (local.get $i) (local.get $unroll_end)))
        ;; f32 elements are 4 bytes wide, hence i << 2.
        (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))

        ;; offset=16/32/48 address the 2nd/3rd/4th groups of 4 floats.
        (local.set $acc0 (f32x4.add (local.get $acc0)
          (f32x4.mul (v128.load (local.get $offset)) (v128.load (local.get $offset)))))
        (local.set $acc1 (f32x4.add (local.get $acc1)
          (f32x4.mul (v128.load offset=16 (local.get $offset)) (v128.load offset=16 (local.get $offset)))))
        (local.set $acc2 (f32x4.add (local.get $acc2)
          (f32x4.mul (v128.load offset=32 (local.get $offset)) (v128.load offset=32 (local.get $offset)))))
        (local.set $acc3 (f32x4.add (local.get $acc3)
          (f32x4.mul (v128.load offset=48 (local.get $offset)) (v128.load offset=48 (local.get $offset)))))

        (local.set $i (i32.add (local.get $i) (i32.const 16)))
        (br $loop_sum_u)
      )
    )

    ;; Fold the 4 accumulators into $acc0 as (acc0 + acc1) + (acc2 + acc3).
    (local.set $acc0 (f32x4.add (f32x4.add (local.get $acc0) (local.get $acc1))
                                (f32x4.add (local.get $acc2) (local.get $acc3))))

    ;; Remaining whole groups of 4 after the 16-wide loop ($i continues from there).
    (block $break_sum4
      (loop $loop_sum4
        (br_if $break_sum4 (i32.ge_u (local.get $i) (local.get $simd_end)))
        (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
        (local.set $acc0 (f32x4.add (local.get $acc0)
          (f32x4.mul (v128.load (local.get $offset)) (v128.load (local.get $offset)))))
        (local.set $i (i32.add (local.get $i) (i32.const 4)))
        (br $loop_sum4)
      )
    )

    ;; Horizontal sum of the 4 lanes: (l0 + l1) + (l2 + l3).
    (local.set $sum
      (f32.add
        (f32.add (f32x4.extract_lane 0 (local.get $acc0)) (f32x4.extract_lane 1 (local.get $acc0)))
        (f32.add (f32x4.extract_lane 2 (local.get $acc0)) (f32x4.extract_lane 3 (local.get $acc0)))))

    ;; Scalar remainder: the last dim % 4 elements.
    (block $break_rem_sum
      (loop $loop_rem_sum
        (br_if $break_rem_sum (i32.ge_u (local.get $i) (local.get $dim)))
        (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
        (local.set $sum (f32.add (local.get $sum)
          (f32.mul (f32.load (local.get $offset)) (f32.load (local.get $offset)))))
        (local.set $i (i32.add (local.get $i) (i32.const 1)))
        (br $loop_rem_sum)
      )
    )

    ;; Magnitude check: a zero vector cannot be normalized, so leave it as-is.
    (local.set $mag (f32.sqrt (local.get $sum)))
    (if (f32.eq (local.get $mag) (f32.const 0))
      (then (return)))

    ;; Phase 2: multiply every element by 1/mag. This needs one division in
    ;; total instead of one per element. Same 16 / 4 / 1 loop structure as Phase 1.
    (local.set $inv_mag (f32.div (f32.const 1) (local.get $mag)))
    (local.set $inv_vec (f32x4.splat (local.get $inv_mag)))
    (local.set $i (i32.const 0))

    (block $break_norm_u
      (loop $loop_norm_u
        (br_if $break_norm_u (i32.ge_u (local.get $i) (local.get $unroll_end)))
        (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))

        (v128.store (local.get $offset)
          (f32x4.mul (v128.load (local.get $offset)) (local.get $inv_vec)))
        (v128.store offset=16 (local.get $offset)
          (f32x4.mul (v128.load offset=16 (local.get $offset)) (local.get $inv_vec)))
        (v128.store offset=32 (local.get $offset)
          (f32x4.mul (v128.load offset=32 (local.get $offset)) (local.get $inv_vec)))
        (v128.store offset=48 (local.get $offset)
          (f32x4.mul (v128.load offset=48 (local.get $offset)) (local.get $inv_vec)))

        (local.set $i (i32.add (local.get $i) (i32.const 16)))
        (br $loop_norm_u)
      )
    )

    ;; Remaining whole groups of 4.
    (block $break_norm4
      (loop $loop_norm4
        (br_if $break_norm4 (i32.ge_u (local.get $i) (local.get $simd_end)))
        (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
        (v128.store (local.get $offset)
          (f32x4.mul (v128.load (local.get $offset)) (local.get $inv_vec)))
        (local.set $i (i32.add (local.get $i) (i32.const 4)))
        (br $loop_norm4)
      )
    )

    ;; Scalar remainder: the last dim % 4 elements.
    (block $break_rem_norm
      (loop $loop_rem_norm
        (br_if $break_rem_norm (i32.ge_u (local.get $i) (local.get $dim)))
        (local.set $offset (i32.add (local.get $ptr) (i32.shl (local.get $i) (i32.const 2))))
        (f32.store (local.get $offset)
          (f32.mul (f32.load (local.get $offset)) (local.get $inv_mag)))
        (local.set $i (i32.add (local.get $i) (i32.const 1)))
        (br $loop_rem_norm)
      )
    )
  )

  ;; search_all(query_ptr, db_ptr, scores_ptr, db_size, dimensions)
  ;; For every i in [0, db_size), stores the raw dot product
  ;;   scores[i] = sum_k query[k] * db[i][k]
  ;; as an f32 at byte address scores_ptr + i * 4.
  ;; Memory layout (all f32, 4 bytes each):
  ;;  - query: `dim` contiguous values at query_ptr
  ;;  - db:    `db_size` vectors of `dim` values stored back-to-back;
  ;;           vector i starts at db_ptr + i * dim * 4
  ;;  - scores: `db_size` output values at scores_ptr
  ;; No normalization is done here; the score equals cosine similarity only
  ;; if the caller stored normalized vectors.
  ;; Optimized with:
  ;;  - 2-vector outer loop unrolling: each query chunk is loaded once and
  ;;    reused for both vectors (halves query memory reads)
  ;;  - 4x inner loop unrolling (16 floats/iteration, 4 accumulators per vector)
  ;;  - 4-wide and scalar tails when dim is not a multiple of 16 / 4
  ;;  - a separate single-vector path for the final vector when db_size is odd
  ;; NOTE(review): addresses are computed with 32-bit i32.mul/i32.add, which
  ;; wrap silently on overflow. Confirm that callers never pass db_size * dim * 4
  ;; values that exceed 2^32.
  (func (export "search_all") (param $query_ptr i32) (param $db_ptr i32) (param $scores_ptr i32) (param $db_size i32) (param $dim i32)
    (local $i i32)              ;; database vector index
    (local $j i32)              ;; element index within a vector
    (local $accA0 v128)         ;; accumulators for vector i ("A")
    (local $accA1 v128)
    (local $accA2 v128)
    (local $accA3 v128)
    (local $accB0 v128)         ;; accumulators for vector i+1 ("B")
    (local $accB1 v128)
    (local $accB2 v128)
    (local $accB3 v128)
    (local $q0 v128)            ;; cached query chunk shared by A and B
    (local $q1 v128)
    (local $q2 v128)
    (local $q3 v128)
    (local $dotA f32)
    (local $dotB f32)
    (local $vec_ptrA i32)       ;; byte address of vector i
    (local $vec_ptrB i32)       ;; byte address of vector i+1
    (local $unroll_end i32)     ;; dim rounded down to a multiple of 16
    (local $simd_end i32)       ;; dim rounded down to a multiple of 4
    (local $q_offset i32)
    (local $vA_offset i32)
    (local $vB_offset i32)
    (local $bytes_per_vec i32)  ;; dim * 4
    (local $pair_end i32)       ;; db_size rounded down to an even number

    (local.set $unroll_end (i32.and (local.get $dim) (i32.const -16)))
    (local.set $simd_end (i32.and (local.get $dim) (i32.const -4)))
    (local.set $bytes_per_vec (i32.shl (local.get $dim) (i32.const 2)))
    (local.set $pair_end (i32.and (local.get $db_size) (i32.const -2)))
    (local.set $i (i32.const 0))

    ;; Main loop: process 2 database vectors (i and i+1) per iteration.
    (block $break_outer
      (loop $loop_outer
        (br_if $break_outer (i32.ge_u (local.get $i) (local.get $pair_end)))

        ;; Vectors are contiguous, so vector i+1 begins right after vector i.
        (local.set $vec_ptrA
          (i32.add (local.get $db_ptr) (i32.mul (local.get $i) (local.get $bytes_per_vec))))
        (local.set $vec_ptrB
          (i32.add (local.get $vec_ptrA) (local.get $bytes_per_vec)))

        (local.set $accA0 (v128.const f32x4 0 0 0 0))
        (local.set $accA1 (v128.const f32x4 0 0 0 0))
        (local.set $accA2 (v128.const f32x4 0 0 0 0))
        (local.set $accA3 (v128.const f32x4 0 0 0 0))
        (local.set $accB0 (v128.const f32x4 0 0 0 0))
        (local.set $accB1 (v128.const f32x4 0 0 0 0))
        (local.set $accB2 (v128.const f32x4 0 0 0 0))
        (local.set $accB3 (v128.const f32x4 0 0 0 0))
        (local.set $j (i32.const 0))

        ;; Inner loop: load 16 query floats once, then take the dot product
        ;; against the same 16 elements of both vectors.
        (block $break_inner
          (loop $loop_inner
            (br_if $break_inner (i32.ge_u (local.get $j) (local.get $unroll_end)))
            (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $vB_offset (i32.add (local.get $vec_ptrB) (i32.shl (local.get $j) (i32.const 2))))

            ;; Load query chunk (shared between both vectors)
            (local.set $q0 (v128.load (local.get $q_offset)))
            (local.set $q1 (v128.load offset=16 (local.get $q_offset)))
            (local.set $q2 (v128.load offset=32 (local.get $q_offset)))
            (local.set $q3 (v128.load offset=48 (local.get $q_offset)))

            ;; Vector A
            (local.set $accA0 (f32x4.add (local.get $accA0)
              (f32x4.mul (local.get $q0) (v128.load (local.get $vA_offset)))))
            (local.set $accA1 (f32x4.add (local.get $accA1)
              (f32x4.mul (local.get $q1) (v128.load offset=16 (local.get $vA_offset)))))
            (local.set $accA2 (f32x4.add (local.get $accA2)
              (f32x4.mul (local.get $q2) (v128.load offset=32 (local.get $vA_offset)))))
            (local.set $accA3 (f32x4.add (local.get $accA3)
              (f32x4.mul (local.get $q3) (v128.load offset=48 (local.get $vA_offset)))))

            ;; Vector B (reuses query loads)
            (local.set $accB0 (f32x4.add (local.get $accB0)
              (f32x4.mul (local.get $q0) (v128.load (local.get $vB_offset)))))
            (local.set $accB1 (f32x4.add (local.get $accB1)
              (f32x4.mul (local.get $q1) (v128.load offset=16 (local.get $vB_offset)))))
            (local.set $accB2 (f32x4.add (local.get $accB2)
              (f32x4.mul (local.get $q2) (v128.load offset=32 (local.get $vB_offset)))))
            (local.set $accB3 (f32x4.add (local.get $accB3)
              (f32x4.mul (local.get $q3) (v128.load offset=48 (local.get $vB_offset)))))

            (local.set $j (i32.add (local.get $j) (i32.const 16)))
            (br $loop_inner)
          )
        )

        ;; Merge accumulators: (acc0 + acc1) + (acc2 + acc3), per vector.
        (local.set $accA0 (f32x4.add (f32x4.add (local.get $accA0) (local.get $accA1))
                                     (f32x4.add (local.get $accA2) (local.get $accA3))))
        (local.set $accB0 (f32x4.add (f32x4.add (local.get $accB0) (local.get $accB1))
                                     (f32x4.add (local.get $accB2) (local.get $accB3))))

        ;; 4-wide cleanup (both vectors), continuing from the current $j.
        (block $break_inner4
          (loop $loop_inner4
            (br_if $break_inner4 (i32.ge_u (local.get $j) (local.get $simd_end)))
            (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $vB_offset (i32.add (local.get $vec_ptrB) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $q0 (v128.load (local.get $q_offset)))
            (local.set $accA0 (f32x4.add (local.get $accA0) (f32x4.mul (local.get $q0) (v128.load (local.get $vA_offset)))))
            (local.set $accB0 (f32x4.add (local.get $accB0) (f32x4.mul (local.get $q0) (v128.load (local.get $vB_offset)))))
            (local.set $j (i32.add (local.get $j) (i32.const 4)))
            (br $loop_inner4)
          )
        )

        ;; Horizontal sums: reduce each 4-lane accumulator to a scalar.
        (local.set $dotA
          (f32.add
            (f32.add (f32x4.extract_lane 0 (local.get $accA0)) (f32x4.extract_lane 1 (local.get $accA0)))
            (f32.add (f32x4.extract_lane 2 (local.get $accA0)) (f32x4.extract_lane 3 (local.get $accA0)))))
        (local.set $dotB
          (f32.add
            (f32.add (f32x4.extract_lane 0 (local.get $accB0)) (f32x4.extract_lane 1 (local.get $accB0)))
            (f32.add (f32x4.extract_lane 2 (local.get $accB0)) (f32x4.extract_lane 3 (local.get $accB0)))))

        ;; Scalar remainder (both vectors): the last dim % 4 elements.
        (block $break_rem
          (loop $loop_rem
            (br_if $break_rem (i32.ge_u (local.get $j) (local.get $dim)))
            (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $vB_offset (i32.add (local.get $vec_ptrB) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $dotA (f32.add (local.get $dotA)
              (f32.mul (f32.load (local.get $q_offset)) (f32.load (local.get $vA_offset)))))
            (local.set $dotB (f32.add (local.get $dotB)
              (f32.mul (f32.load (local.get $q_offset)) (f32.load (local.get $vB_offset)))))
            (local.set $j (i32.add (local.get $j) (i32.const 1)))
            (br $loop_rem)
          )
        )

        ;; Store scores for vectors i and i+1 (4 bytes per score).
        (f32.store
          (i32.add (local.get $scores_ptr) (i32.shl (local.get $i) (i32.const 2)))
          (local.get $dotA))
        (f32.store
          (i32.add (local.get $scores_ptr) (i32.shl (i32.add (local.get $i) (i32.const 1)) (i32.const 2)))
          (local.get $dotB))

        (local.set $i (i32.add (local.get $i) (i32.const 2)))
        (br $loop_outer)
      )
    )

    ;; Handle the last vector if db_size is odd (here $i == pair_end == db_size - 1).
    ;; Same 16 / 4 / 1 structure as above, for a single vector; the query is
    ;; loaded directly inside the loop because there is no second vector to share it with.
    (if (i32.lt_u (local.get $i) (local.get $db_size))
      (then
        (local.set $vec_ptrA
          (i32.add (local.get $db_ptr) (i32.mul (local.get $i) (local.get $bytes_per_vec))))
        (local.set $accA0 (v128.const f32x4 0 0 0 0))
        (local.set $accA1 (v128.const f32x4 0 0 0 0))
        (local.set $accA2 (v128.const f32x4 0 0 0 0))
        (local.set $accA3 (v128.const f32x4 0 0 0 0))
        (local.set $j (i32.const 0))

        (block $break_last
          (loop $loop_last
            (br_if $break_last (i32.ge_u (local.get $j) (local.get $unroll_end)))
            (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $accA0 (f32x4.add (local.get $accA0)
              (f32x4.mul (v128.load (local.get $q_offset)) (v128.load (local.get $vA_offset)))))
            (local.set $accA1 (f32x4.add (local.get $accA1)
              (f32x4.mul (v128.load offset=16 (local.get $q_offset)) (v128.load offset=16 (local.get $vA_offset)))))
            (local.set $accA2 (f32x4.add (local.get $accA2)
              (f32x4.mul (v128.load offset=32 (local.get $q_offset)) (v128.load offset=32 (local.get $vA_offset)))))
            (local.set $accA3 (f32x4.add (local.get $accA3)
              (f32x4.mul (v128.load offset=48 (local.get $q_offset)) (v128.load offset=48 (local.get $vA_offset)))))
            (local.set $j (i32.add (local.get $j) (i32.const 16)))
            (br $loop_last)
          )
        )

        ;; Merge the 4 accumulators.
        (local.set $accA0 (f32x4.add (f32x4.add (local.get $accA0) (local.get $accA1))
                                     (f32x4.add (local.get $accA2) (local.get $accA3))))

        ;; 4-wide cleanup.
        (block $break_last4
          (loop $loop_last4
            (br_if $break_last4 (i32.ge_u (local.get $j) (local.get $simd_end)))
            (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $accA0 (f32x4.add (local.get $accA0)
              (f32x4.mul (v128.load (local.get $q_offset)) (v128.load (local.get $vA_offset)))))
            (local.set $j (i32.add (local.get $j) (i32.const 4)))
            (br $loop_last4)
          )
        )

        ;; Horizontal sum.
        (local.set $dotA
          (f32.add
            (f32.add (f32x4.extract_lane 0 (local.get $accA0)) (f32x4.extract_lane 1 (local.get $accA0)))
            (f32.add (f32x4.extract_lane 2 (local.get $accA0)) (f32x4.extract_lane 3 (local.get $accA0)))))

        ;; Scalar remainder.
        (block $break_last_rem
          (loop $loop_last_rem
            (br_if $break_last_rem (i32.ge_u (local.get $j) (local.get $dim)))
            (local.set $q_offset (i32.add (local.get $query_ptr) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $vA_offset (i32.add (local.get $vec_ptrA) (i32.shl (local.get $j) (i32.const 2))))
            (local.set $dotA (f32.add (local.get $dotA)
              (f32.mul (f32.load (local.get $q_offset)) (f32.load (local.get $vA_offset)))))
            (local.set $j (i32.add (local.get $j) (i32.const 1)))
            (br $loop_last_rem)
          )
        )

        ;; Store the score for the final vector.
        (f32.store
          (i32.add (local.get $scores_ptr) (i32.shl (local.get $i) (i32.const 2)))
          (local.get $dotA))
      )
    )
  )
)
|