whisper.rn 0.5.0 → 0.5.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/gradle.properties +1 -1
- package/android/src/main/jni.cpp +12 -3
- package/cpp/ggml-alloc.c +292 -130
- package/cpp/ggml-backend-impl.h +4 -4
- package/cpp/ggml-backend-reg.cpp +13 -5
- package/cpp/ggml-backend.cpp +207 -17
- package/cpp/ggml-backend.h +19 -1
- package/cpp/ggml-cpu/amx/amx.cpp +5 -2
- package/cpp/ggml-cpu/arch/x86/repack.cpp +2 -2
- package/cpp/ggml-cpu/arch-fallback.h +0 -4
- package/cpp/ggml-cpu/common.h +14 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +14 -7
- package/cpp/ggml-cpu/ggml-cpu.c +65 -44
- package/cpp/ggml-cpu/ggml-cpu.cpp +14 -4
- package/cpp/ggml-cpu/ops.cpp +542 -775
- package/cpp/ggml-cpu/ops.h +2 -0
- package/cpp/ggml-cpu/simd-mappings.h +88 -59
- package/cpp/ggml-cpu/unary-ops.cpp +135 -0
- package/cpp/ggml-cpu/unary-ops.h +5 -0
- package/cpp/ggml-cpu/vec.cpp +227 -20
- package/cpp/ggml-cpu/vec.h +407 -56
- package/cpp/ggml-cpu.h +1 -1
- package/cpp/ggml-impl.h +94 -12
- package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
- package/cpp/ggml-metal/ggml-metal-common.h +52 -0
- package/cpp/ggml-metal/ggml-metal-context.h +33 -0
- package/cpp/ggml-metal/ggml-metal-context.m +600 -0
- package/cpp/ggml-metal/ggml-metal-device.cpp +1565 -0
- package/cpp/ggml-metal/ggml-metal-device.h +244 -0
- package/cpp/ggml-metal/ggml-metal-device.m +1325 -0
- package/cpp/ggml-metal/ggml-metal-impl.h +802 -0
- package/cpp/ggml-metal/ggml-metal-ops.cpp +3583 -0
- package/cpp/ggml-metal/ggml-metal-ops.h +88 -0
- package/cpp/ggml-metal/ggml-metal.cpp +718 -0
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/cpp/ggml-metal-impl.h +40 -40
- package/cpp/ggml-metal.h +1 -6
- package/cpp/ggml-quants.c +1 -0
- package/cpp/ggml.c +341 -15
- package/cpp/ggml.h +150 -5
- package/cpp/jsi/RNWhisperJSI.cpp +9 -2
- package/cpp/jsi/ThreadPool.h +3 -3
- package/cpp/rn-whisper.h +1 -0
- package/cpp/whisper.cpp +89 -72
- package/cpp/whisper.h +1 -0
- package/ios/CMakeLists.txt +6 -1
- package/ios/RNWhisperContext.mm +3 -1
- package/ios/RNWhisperVadContext.mm +14 -13
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +2 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +2 -0
- package/src/version.json +1 -1
- package/whisper-rn.podspec +8 -9
- package/cpp/ggml-metal.m +0 -6779
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
package/cpp/ggml-cpu/ops.cpp
CHANGED
|
@@ -41,13 +41,15 @@ static void wsp_ggml_compute_forward_dup_same_cont(
|
|
|
41
41
|
}
|
|
42
42
|
}
|
|
43
43
|
|
|
44
|
-
|
|
44
|
+
template<typename src_t, typename dst_t>
|
|
45
|
+
static void wsp_ggml_compute_forward_dup_flt(
|
|
45
46
|
const wsp_ggml_compute_params * params,
|
|
46
47
|
wsp_ggml_tensor * dst) {
|
|
47
48
|
|
|
48
49
|
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
49
50
|
|
|
50
51
|
WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
|
|
52
|
+
WSP_GGML_ASSERT(!wsp_ggml_is_quantized(src0->type) && !wsp_ggml_is_quantized(dst->type));
|
|
51
53
|
|
|
52
54
|
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
53
55
|
|
|
@@ -62,6 +64,7 @@ static void wsp_ggml_compute_forward_dup_f16(
|
|
|
62
64
|
const int ir0 = dr * ith;
|
|
63
65
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
64
66
|
|
|
67
|
+
// case: type & row size equal
|
|
65
68
|
if (src0->type == dst->type &&
|
|
66
69
|
ne00 == ne0 &&
|
|
67
70
|
nb00 == wsp_ggml_type_size(src0->type) && nb0 == wsp_ggml_type_size(dst->type)) {
|
|
@@ -80,11 +83,11 @@ static void wsp_ggml_compute_forward_dup_f16(
|
|
|
80
83
|
return;
|
|
81
84
|
}
|
|
82
85
|
|
|
83
|
-
//
|
|
84
|
-
|
|
86
|
+
// case: dst tensor is contiguous
|
|
85
87
|
if (wsp_ggml_is_contiguous(dst)) {
|
|
86
|
-
if (nb00 == sizeof(
|
|
87
|
-
if (
|
|
88
|
+
if (nb00 == sizeof(src_t)) {
|
|
89
|
+
if constexpr (std::is_same_v<dst_t, src_t>) {
|
|
90
|
+
// same type
|
|
88
91
|
size_t id = 0;
|
|
89
92
|
const size_t rs = ne00 * nb00;
|
|
90
93
|
char * dst_ptr = (char *) dst->data;
|
|
@@ -100,750 +103,58 @@ static void wsp_ggml_compute_forward_dup_f16(
|
|
|
100
103
|
id += rs * (ne01 - ir1);
|
|
101
104
|
}
|
|
102
105
|
}
|
|
103
|
-
} else
|
|
106
|
+
} else {
|
|
107
|
+
// casting between non-quantized types
|
|
104
108
|
size_t id = 0;
|
|
105
|
-
|
|
109
|
+
dst_t * dst_ptr = (dst_t *) dst->data;
|
|
106
110
|
|
|
107
111
|
for (int i03 = 0; i03 < ne03; i03++) {
|
|
108
112
|
for (int i02 = 0; i02 < ne02; i02++) {
|
|
109
113
|
id += ne00 * ir0;
|
|
110
114
|
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
111
|
-
const
|
|
115
|
+
const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
112
116
|
for (int i00 = 0; i00 < ne00; i00++) {
|
|
113
|
-
|
|
117
|
+
float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
|
|
118
|
+
dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
|
|
114
119
|
id++;
|
|
115
120
|
}
|
|
116
121
|
}
|
|
117
122
|
id += ne00 * (ne01 - ir1);
|
|
118
123
|
}
|
|
119
124
|
}
|
|
120
|
-
} else if (wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
121
|
-
wsp_ggml_from_float_t const wsp_quantize_row_q = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
122
|
-
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
123
|
-
|
|
124
|
-
size_t id = 0;
|
|
125
|
-
size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
|
|
126
|
-
char * dst_ptr = (char *) dst->data;
|
|
127
|
-
|
|
128
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
129
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
130
|
-
id += rs * ir0;
|
|
131
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
132
|
-
const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
133
|
-
|
|
134
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
135
|
-
src0_f32[i00] = WSP_GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
|
|
136
|
-
}
|
|
137
|
-
|
|
138
|
-
wsp_quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
|
139
|
-
id += rs;
|
|
140
|
-
}
|
|
141
|
-
id += rs * (ne01 - ir1);
|
|
142
|
-
}
|
|
143
|
-
}
|
|
144
|
-
} else {
|
|
145
|
-
WSP_GGML_ABORT("fatal error"); // TODO: implement
|
|
146
125
|
}
|
|
147
126
|
} else {
|
|
148
127
|
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
149
128
|
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
float * dst_ptr = (float *) dst->data;
|
|
153
|
-
|
|
154
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
155
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
156
|
-
id += ne00 * ir0;
|
|
157
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
158
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
159
|
-
const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
160
|
-
|
|
161
|
-
dst_ptr[id] = WSP_GGML_CPU_FP16_TO_FP32(*src0_ptr);
|
|
162
|
-
id++;
|
|
163
|
-
}
|
|
164
|
-
}
|
|
165
|
-
id += ne00 * (ne01 - ir1);
|
|
166
|
-
}
|
|
167
|
-
}
|
|
168
|
-
} else if (dst->type == WSP_GGML_TYPE_F16) {
|
|
169
|
-
size_t id = 0;
|
|
170
|
-
wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data;
|
|
171
|
-
|
|
172
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
173
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
174
|
-
id += ne00 * ir0;
|
|
175
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
176
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
177
|
-
const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
178
|
-
|
|
179
|
-
dst_ptr[id] = *src0_ptr;
|
|
180
|
-
id++;
|
|
181
|
-
}
|
|
182
|
-
}
|
|
183
|
-
id += ne00 * (ne01 - ir1);
|
|
184
|
-
}
|
|
185
|
-
}
|
|
186
|
-
} else {
|
|
187
|
-
WSP_GGML_ABORT("fatal error"); // TODO: implement
|
|
188
|
-
}
|
|
189
|
-
}
|
|
190
|
-
return;
|
|
191
|
-
}
|
|
192
|
-
|
|
193
|
-
// dst counters
|
|
194
|
-
int64_t i10 = 0;
|
|
195
|
-
int64_t i11 = 0;
|
|
196
|
-
int64_t i12 = 0;
|
|
197
|
-
int64_t i13 = 0;
|
|
198
|
-
|
|
199
|
-
if (dst->type == WSP_GGML_TYPE_F16) {
|
|
200
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
201
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
202
|
-
i10 += ne00 * ir0;
|
|
203
|
-
while (i10 >= ne0) {
|
|
204
|
-
i10 -= ne0;
|
|
205
|
-
if (++i11 == ne1) {
|
|
206
|
-
i11 = 0;
|
|
207
|
-
if (++i12 == ne2) {
|
|
208
|
-
i12 = 0;
|
|
209
|
-
if (++i13 == ne3) {
|
|
210
|
-
i13 = 0;
|
|
211
|
-
}
|
|
212
|
-
}
|
|
213
|
-
}
|
|
214
|
-
}
|
|
215
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
216
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
217
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
218
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
219
|
-
|
|
220
|
-
memcpy(dst_ptr, src0_ptr, sizeof(wsp_ggml_fp16_t));
|
|
221
|
-
|
|
222
|
-
if (++i10 == ne00) {
|
|
223
|
-
i10 = 0;
|
|
224
|
-
if (++i11 == ne01) {
|
|
225
|
-
i11 = 0;
|
|
226
|
-
if (++i12 == ne02) {
|
|
227
|
-
i12 = 0;
|
|
228
|
-
if (++i13 == ne03) {
|
|
229
|
-
i13 = 0;
|
|
230
|
-
}
|
|
231
|
-
}
|
|
232
|
-
}
|
|
233
|
-
}
|
|
234
|
-
}
|
|
235
|
-
}
|
|
236
|
-
i10 += ne00 * (ne01 - ir1);
|
|
237
|
-
while (i10 >= ne0) {
|
|
238
|
-
i10 -= ne0;
|
|
239
|
-
if (++i11 == ne1) {
|
|
240
|
-
i11 = 0;
|
|
241
|
-
if (++i12 == ne2) {
|
|
242
|
-
i12 = 0;
|
|
243
|
-
if (++i13 == ne3) {
|
|
244
|
-
i13 = 0;
|
|
245
|
-
}
|
|
246
|
-
}
|
|
247
|
-
}
|
|
248
|
-
}
|
|
249
|
-
}
|
|
250
|
-
}
|
|
251
|
-
} else if (dst->type == WSP_GGML_TYPE_F32) {
|
|
252
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
253
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
254
|
-
i10 += ne00 * ir0;
|
|
255
|
-
while (i10 >= ne0) {
|
|
256
|
-
i10 -= ne0;
|
|
257
|
-
if (++i11 == ne1) {
|
|
258
|
-
i11 = 0;
|
|
259
|
-
if (++i12 == ne2) {
|
|
260
|
-
i12 = 0;
|
|
261
|
-
if (++i13 == ne3) {
|
|
262
|
-
i13 = 0;
|
|
263
|
-
}
|
|
264
|
-
}
|
|
265
|
-
}
|
|
266
|
-
}
|
|
267
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
268
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
269
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
270
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
271
|
-
|
|
272
|
-
*(float *) dst_ptr = WSP_GGML_CPU_FP16_TO_FP32(*(const wsp_ggml_fp16_t *) src0_ptr);
|
|
273
|
-
|
|
274
|
-
if (++i10 == ne0) {
|
|
275
|
-
i10 = 0;
|
|
276
|
-
if (++i11 == ne1) {
|
|
277
|
-
i11 = 0;
|
|
278
|
-
if (++i12 == ne2) {
|
|
279
|
-
i12 = 0;
|
|
280
|
-
if (++i13 == ne3) {
|
|
281
|
-
i13 = 0;
|
|
282
|
-
}
|
|
283
|
-
}
|
|
284
|
-
}
|
|
285
|
-
}
|
|
286
|
-
}
|
|
287
|
-
}
|
|
288
|
-
i10 += ne00 * (ne01 - ir1);
|
|
289
|
-
while (i10 >= ne0) {
|
|
290
|
-
i10 -= ne0;
|
|
291
|
-
if (++i11 == ne1) {
|
|
292
|
-
i11 = 0;
|
|
293
|
-
if (++i12 == ne2) {
|
|
294
|
-
i12 = 0;
|
|
295
|
-
if (++i13 == ne3) {
|
|
296
|
-
i13 = 0;
|
|
297
|
-
}
|
|
298
|
-
}
|
|
299
|
-
}
|
|
300
|
-
}
|
|
301
|
-
}
|
|
302
|
-
}
|
|
303
|
-
} else {
|
|
304
|
-
WSP_GGML_ABORT("fatal error"); // TODO: implement
|
|
305
|
-
}
|
|
306
|
-
}
|
|
307
|
-
|
|
308
|
-
static void wsp_ggml_compute_forward_dup_bf16(
|
|
309
|
-
const wsp_ggml_compute_params * params,
|
|
310
|
-
wsp_ggml_tensor * dst) {
|
|
311
|
-
|
|
312
|
-
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
313
|
-
|
|
314
|
-
WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
|
|
315
|
-
|
|
316
|
-
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
317
|
-
|
|
318
|
-
const int ith = params->ith; // thread index
|
|
319
|
-
const int nth = params->nth; // number of threads
|
|
320
|
-
|
|
321
|
-
// parallelize by rows
|
|
322
|
-
const int nr = ne01;
|
|
323
|
-
// number of rows per thread
|
|
324
|
-
const int dr = (nr + nth - 1) / nth;
|
|
325
|
-
// row range for this thread
|
|
326
|
-
const int ir0 = dr * ith;
|
|
327
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
328
|
-
|
|
329
|
-
if (src0->type == dst->type &&
|
|
330
|
-
ne00 == ne0 &&
|
|
331
|
-
nb00 == wsp_ggml_type_size(src0->type) && nb0 == wsp_ggml_type_size(dst->type)) {
|
|
332
|
-
// copy by rows
|
|
333
|
-
const size_t rs = ne00*nb00;
|
|
334
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
335
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
336
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
337
|
-
memcpy(
|
|
338
|
-
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
|
339
|
-
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
|
340
|
-
rs);
|
|
341
|
-
}
|
|
342
|
-
}
|
|
343
|
-
}
|
|
344
|
-
return;
|
|
345
|
-
}
|
|
346
|
-
|
|
347
|
-
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
|
|
348
|
-
|
|
349
|
-
if (wsp_ggml_is_contiguous(dst)) {
|
|
350
|
-
if (nb00 == sizeof(wsp_ggml_bf16_t)) {
|
|
351
|
-
if (dst->type == WSP_GGML_TYPE_BF16) {
|
|
352
|
-
size_t id = 0;
|
|
353
|
-
const size_t rs = ne00 * nb00;
|
|
354
|
-
char * dst_ptr = (char *) dst->data;
|
|
355
|
-
|
|
356
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
357
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
358
|
-
id += rs * ir0;
|
|
359
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
360
|
-
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
|
361
|
-
memcpy(dst_ptr + id, src0_ptr, rs);
|
|
362
|
-
id += rs;
|
|
363
|
-
}
|
|
364
|
-
id += rs * (ne01 - ir1);
|
|
365
|
-
}
|
|
366
|
-
}
|
|
367
|
-
} else if (dst->type == WSP_GGML_TYPE_F16) {
|
|
368
|
-
size_t id = 0;
|
|
369
|
-
wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data;
|
|
370
|
-
|
|
371
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
372
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
373
|
-
id += ne00 * ir0;
|
|
374
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
375
|
-
const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
376
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
377
|
-
dst_ptr[id] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(src0_ptr[i00]));
|
|
378
|
-
id++;
|
|
379
|
-
}
|
|
380
|
-
}
|
|
381
|
-
id += ne00 * (ne01 - ir1);
|
|
382
|
-
}
|
|
383
|
-
}
|
|
384
|
-
} else if (dst->type == WSP_GGML_TYPE_F32) {
|
|
385
|
-
size_t id = 0;
|
|
386
|
-
float * dst_ptr = (float *) dst->data;
|
|
387
|
-
|
|
388
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
389
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
390
|
-
id += ne00 * ir0;
|
|
391
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
392
|
-
const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
393
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
394
|
-
dst_ptr[id] = WSP_GGML_BF16_TO_FP32(src0_ptr[i00]);
|
|
395
|
-
id++;
|
|
396
|
-
}
|
|
397
|
-
}
|
|
398
|
-
id += ne00 * (ne01 - ir1);
|
|
399
|
-
}
|
|
400
|
-
}
|
|
401
|
-
} else if (wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
402
|
-
wsp_ggml_from_float_t const wsp_quantize_row_q = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
403
|
-
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
404
|
-
|
|
405
|
-
size_t id = 0;
|
|
406
|
-
size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
|
|
407
|
-
char * dst_ptr = (char *) dst->data;
|
|
408
|
-
|
|
409
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
410
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
411
|
-
id += rs * ir0;
|
|
412
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
413
|
-
const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
414
|
-
|
|
415
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
416
|
-
src0_f32[i00] = WSP_GGML_BF16_TO_FP32(src0_ptr[i00]);
|
|
417
|
-
}
|
|
129
|
+
size_t id = 0;
|
|
130
|
+
dst_t * dst_ptr = (dst_t *) dst->data;
|
|
418
131
|
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
430
|
-
|
|
431
|
-
if (dst->type == WSP_GGML_TYPE_F32) {
|
|
432
|
-
size_t id = 0;
|
|
433
|
-
float * dst_ptr = (float *) dst->data;
|
|
434
|
-
|
|
435
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
436
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
437
|
-
id += ne00 * ir0;
|
|
438
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
439
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
440
|
-
const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
441
|
-
|
|
442
|
-
dst_ptr[id] = WSP_GGML_BF16_TO_FP32(*src0_ptr);
|
|
443
|
-
id++;
|
|
444
|
-
}
|
|
445
|
-
}
|
|
446
|
-
id += ne00 * (ne01 - ir1);
|
|
447
|
-
}
|
|
448
|
-
}
|
|
449
|
-
} else if (dst->type == WSP_GGML_TYPE_BF16) {
|
|
450
|
-
size_t id = 0;
|
|
451
|
-
wsp_ggml_bf16_t * dst_ptr = (wsp_ggml_bf16_t *) dst->data;
|
|
452
|
-
|
|
453
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
454
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
455
|
-
id += ne00 * ir0;
|
|
456
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
457
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
458
|
-
const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
459
|
-
|
|
460
|
-
dst_ptr[id] = *src0_ptr;
|
|
461
|
-
id++;
|
|
462
|
-
}
|
|
463
|
-
}
|
|
464
|
-
id += ne00 * (ne01 - ir1);
|
|
465
|
-
}
|
|
466
|
-
}
|
|
467
|
-
} else if (dst->type == WSP_GGML_TYPE_F16) {
|
|
468
|
-
size_t id = 0;
|
|
469
|
-
wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data;
|
|
470
|
-
|
|
471
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
472
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
473
|
-
id += ne00 * ir0;
|
|
474
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
475
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
476
|
-
const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
477
|
-
|
|
478
|
-
dst_ptr[id] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(*src0_ptr));
|
|
479
|
-
id++;
|
|
480
|
-
}
|
|
481
|
-
}
|
|
482
|
-
id += ne00 * (ne01 - ir1);
|
|
483
|
-
}
|
|
484
|
-
}
|
|
485
|
-
} else {
|
|
486
|
-
WSP_GGML_ABORT("fatal error"); // TODO: implement
|
|
487
|
-
}
|
|
488
|
-
}
|
|
489
|
-
return;
|
|
490
|
-
}
|
|
491
|
-
|
|
492
|
-
// dst counters
|
|
493
|
-
int64_t i10 = 0;
|
|
494
|
-
int64_t i11 = 0;
|
|
495
|
-
int64_t i12 = 0;
|
|
496
|
-
int64_t i13 = 0;
|
|
497
|
-
|
|
498
|
-
if (dst->type == WSP_GGML_TYPE_BF16) {
|
|
499
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
500
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
501
|
-
i10 += ne00 * ir0;
|
|
502
|
-
while (i10 >= ne0) {
|
|
503
|
-
i10 -= ne0;
|
|
504
|
-
if (++i11 == ne1) {
|
|
505
|
-
i11 = 0;
|
|
506
|
-
if (++i12 == ne2) {
|
|
507
|
-
i12 = 0;
|
|
508
|
-
if (++i13 == ne3) {
|
|
509
|
-
i13 = 0;
|
|
510
|
-
}
|
|
511
|
-
}
|
|
512
|
-
}
|
|
513
|
-
}
|
|
514
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
515
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
516
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
517
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
518
|
-
|
|
519
|
-
memcpy(dst_ptr, src0_ptr, sizeof(wsp_ggml_bf16_t));
|
|
520
|
-
|
|
521
|
-
if (++i10 == ne00) {
|
|
522
|
-
i10 = 0;
|
|
523
|
-
if (++i11 == ne01) {
|
|
524
|
-
i11 = 0;
|
|
525
|
-
if (++i12 == ne02) {
|
|
526
|
-
i12 = 0;
|
|
527
|
-
if (++i13 == ne03) {
|
|
528
|
-
i13 = 0;
|
|
529
|
-
}
|
|
530
|
-
}
|
|
531
|
-
}
|
|
532
|
-
}
|
|
533
|
-
}
|
|
534
|
-
}
|
|
535
|
-
i10 += ne00 * (ne01 - ir1);
|
|
536
|
-
while (i10 >= ne0) {
|
|
537
|
-
i10 -= ne0;
|
|
538
|
-
if (++i11 == ne1) {
|
|
539
|
-
i11 = 0;
|
|
540
|
-
if (++i12 == ne2) {
|
|
541
|
-
i12 = 0;
|
|
542
|
-
if (++i13 == ne3) {
|
|
543
|
-
i13 = 0;
|
|
544
|
-
}
|
|
545
|
-
}
|
|
546
|
-
}
|
|
547
|
-
}
|
|
548
|
-
}
|
|
549
|
-
}
|
|
550
|
-
} else if (dst->type == WSP_GGML_TYPE_F16) {
|
|
551
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
552
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
553
|
-
i10 += ne00 * ir0;
|
|
554
|
-
while (i10 >= ne0) {
|
|
555
|
-
i10 -= ne0;
|
|
556
|
-
if (++i11 == ne1) {
|
|
557
|
-
i11 = 0;
|
|
558
|
-
if (++i12 == ne2) {
|
|
559
|
-
i12 = 0;
|
|
560
|
-
if (++i13 == ne3) {
|
|
561
|
-
i13 = 0;
|
|
562
|
-
}
|
|
563
|
-
}
|
|
564
|
-
}
|
|
565
|
-
}
|
|
566
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
567
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
568
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
569
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
570
|
-
|
|
571
|
-
*(wsp_ggml_fp16_t *) dst_ptr = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(*(const wsp_ggml_bf16_t *) src0_ptr));
|
|
572
|
-
|
|
573
|
-
if (++i10 == ne0) {
|
|
574
|
-
i10 = 0;
|
|
575
|
-
if (++i11 == ne1) {
|
|
576
|
-
i11 = 0;
|
|
577
|
-
if (++i12 == ne2) {
|
|
578
|
-
i12 = 0;
|
|
579
|
-
if (++i13 == ne3) {
|
|
580
|
-
i13 = 0;
|
|
581
|
-
}
|
|
582
|
-
}
|
|
583
|
-
}
|
|
584
|
-
}
|
|
585
|
-
}
|
|
586
|
-
}
|
|
587
|
-
i10 += ne00 * (ne01 - ir1);
|
|
588
|
-
while (i10 >= ne0) {
|
|
589
|
-
i10 -= ne0;
|
|
590
|
-
if (++i11 == ne1) {
|
|
591
|
-
i11 = 0;
|
|
592
|
-
if (++i12 == ne2) {
|
|
593
|
-
i12 = 0;
|
|
594
|
-
if (++i13 == ne3) {
|
|
595
|
-
i13 = 0;
|
|
596
|
-
}
|
|
597
|
-
}
|
|
598
|
-
}
|
|
599
|
-
}
|
|
600
|
-
}
|
|
601
|
-
}
|
|
602
|
-
} else if (dst->type == WSP_GGML_TYPE_F32) {
|
|
603
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
604
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
605
|
-
i10 += ne00 * ir0;
|
|
606
|
-
while (i10 >= ne0) {
|
|
607
|
-
i10 -= ne0;
|
|
608
|
-
if (++i11 == ne1) {
|
|
609
|
-
i11 = 0;
|
|
610
|
-
if (++i12 == ne2) {
|
|
611
|
-
i12 = 0;
|
|
612
|
-
if (++i13 == ne3) {
|
|
613
|
-
i13 = 0;
|
|
614
|
-
}
|
|
615
|
-
}
|
|
616
|
-
}
|
|
617
|
-
}
|
|
618
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
619
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
620
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
621
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
622
|
-
|
|
623
|
-
*(float *) dst_ptr = WSP_GGML_BF16_TO_FP32(*(const wsp_ggml_bf16_t *) src0_ptr);
|
|
624
|
-
|
|
625
|
-
if (++i10 == ne0) {
|
|
626
|
-
i10 = 0;
|
|
627
|
-
if (++i11 == ne1) {
|
|
628
|
-
i11 = 0;
|
|
629
|
-
if (++i12 == ne2) {
|
|
630
|
-
i12 = 0;
|
|
631
|
-
if (++i13 == ne3) {
|
|
632
|
-
i13 = 0;
|
|
633
|
-
}
|
|
634
|
-
}
|
|
635
|
-
}
|
|
636
|
-
}
|
|
637
|
-
}
|
|
638
|
-
}
|
|
639
|
-
i10 += ne00 * (ne01 - ir1);
|
|
640
|
-
while (i10 >= ne0) {
|
|
641
|
-
i10 -= ne0;
|
|
642
|
-
if (++i11 == ne1) {
|
|
643
|
-
i11 = 0;
|
|
644
|
-
if (++i12 == ne2) {
|
|
645
|
-
i12 = 0;
|
|
646
|
-
if (++i13 == ne3) {
|
|
647
|
-
i13 = 0;
|
|
648
|
-
}
|
|
649
|
-
}
|
|
650
|
-
}
|
|
651
|
-
}
|
|
652
|
-
}
|
|
653
|
-
}
|
|
654
|
-
} else {
|
|
655
|
-
WSP_GGML_ABORT("fatal error"); // TODO: implement
|
|
656
|
-
}
|
|
657
|
-
}
|
|
658
|
-
|
|
659
|
-
static void wsp_ggml_compute_forward_dup_f32(
|
|
660
|
-
const wsp_ggml_compute_params * params,
|
|
661
|
-
wsp_ggml_tensor * dst) {
|
|
662
|
-
|
|
663
|
-
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
664
|
-
|
|
665
|
-
WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
|
|
666
|
-
|
|
667
|
-
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
668
|
-
|
|
669
|
-
const int ith = params->ith; // thread index
|
|
670
|
-
const int nth = params->nth; // number of threads
|
|
671
|
-
|
|
672
|
-
// parallelize by rows
|
|
673
|
-
const int nr = ne01;
|
|
674
|
-
// number of rows per thread
|
|
675
|
-
const int dr = (nr + nth - 1) / nth;
|
|
676
|
-
// row range for this thread
|
|
677
|
-
const int ir0 = dr * ith;
|
|
678
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
679
|
-
|
|
680
|
-
if (src0->type == dst->type &&
|
|
681
|
-
ne00 == ne0 &&
|
|
682
|
-
nb00 == wsp_ggml_type_size(src0->type) && nb0 == wsp_ggml_type_size(dst->type)) {
|
|
683
|
-
// copy by rows
|
|
684
|
-
const size_t rs = ne00*nb00;
|
|
685
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
686
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
687
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
688
|
-
memcpy(
|
|
689
|
-
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
|
690
|
-
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
|
691
|
-
rs);
|
|
692
|
-
}
|
|
693
|
-
}
|
|
694
|
-
}
|
|
695
|
-
return;
|
|
696
|
-
}
|
|
697
|
-
|
|
698
|
-
if (wsp_ggml_is_contiguous(dst)) {
|
|
699
|
-
// TODO: simplify
|
|
700
|
-
if (nb00 == sizeof(float)) {
|
|
701
|
-
if (wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
702
|
-
wsp_ggml_from_float_t const from_float = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
703
|
-
|
|
704
|
-
size_t id = 0;
|
|
705
|
-
size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
|
|
706
|
-
char * dst_ptr = (char *) dst->data;
|
|
707
|
-
|
|
708
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
709
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
710
|
-
id += rs * ir0;
|
|
711
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
712
|
-
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
713
|
-
from_float(src0_ptr, dst_ptr + id, ne00);
|
|
714
|
-
id += rs;
|
|
715
|
-
}
|
|
716
|
-
id += rs * (ne01 - ir1);
|
|
717
|
-
}
|
|
718
|
-
}
|
|
719
|
-
} else {
|
|
720
|
-
WSP_GGML_ABORT("fatal error"); // TODO: implement
|
|
721
|
-
}
|
|
722
|
-
} else {
|
|
723
|
-
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
724
|
-
|
|
725
|
-
if (dst->type == WSP_GGML_TYPE_F32) {
|
|
726
|
-
size_t id = 0;
|
|
727
|
-
float * dst_ptr = (float *) dst->data;
|
|
728
|
-
|
|
729
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
730
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
731
|
-
id += ne00 * ir0;
|
|
732
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
733
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
734
|
-
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
735
|
-
|
|
736
|
-
dst_ptr[id] = *src0_ptr;
|
|
737
|
-
id++;
|
|
738
|
-
}
|
|
739
|
-
}
|
|
740
|
-
id += ne00 * (ne01 - ir1);
|
|
741
|
-
}
|
|
742
|
-
}
|
|
743
|
-
} else if (dst->type == WSP_GGML_TYPE_F16) {
|
|
744
|
-
size_t id = 0;
|
|
745
|
-
wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data;
|
|
746
|
-
|
|
747
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
748
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
749
|
-
id += ne00 * ir0;
|
|
750
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
751
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
752
|
-
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
753
|
-
|
|
754
|
-
dst_ptr[id] = WSP_GGML_CPU_FP32_TO_FP16(*src0_ptr);
|
|
755
|
-
id++;
|
|
756
|
-
}
|
|
757
|
-
}
|
|
758
|
-
id += ne00 * (ne01 - ir1);
|
|
759
|
-
}
|
|
760
|
-
}
|
|
761
|
-
} else if (dst->type == WSP_GGML_TYPE_BF16) {
|
|
762
|
-
size_t id = 0;
|
|
763
|
-
wsp_ggml_bf16_t * dst_ptr = (wsp_ggml_bf16_t *) dst->data;
|
|
764
|
-
|
|
765
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
766
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
767
|
-
id += ne00 * ir0;
|
|
768
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
769
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
770
|
-
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
771
|
-
|
|
772
|
-
dst_ptr[id] = WSP_GGML_FP32_TO_BF16(*src0_ptr);
|
|
773
|
-
id++;
|
|
774
|
-
}
|
|
775
|
-
}
|
|
776
|
-
id += ne00 * (ne01 - ir1);
|
|
777
|
-
}
|
|
778
|
-
}
|
|
779
|
-
} else {
|
|
780
|
-
WSP_GGML_ABORT("fatal error"); // TODO: implement
|
|
781
|
-
}
|
|
782
|
-
}
|
|
783
|
-
|
|
784
|
-
return;
|
|
785
|
-
}
|
|
786
|
-
|
|
787
|
-
// dst counters
|
|
788
|
-
|
|
789
|
-
int64_t i10 = 0;
|
|
790
|
-
int64_t i11 = 0;
|
|
791
|
-
int64_t i12 = 0;
|
|
792
|
-
int64_t i13 = 0;
|
|
793
|
-
|
|
794
|
-
if (dst->type == WSP_GGML_TYPE_F32) {
|
|
795
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
796
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
797
|
-
i10 += ne00 * ir0;
|
|
798
|
-
while (i10 >= ne0) {
|
|
799
|
-
i10 -= ne0;
|
|
800
|
-
if (++i11 == ne1) {
|
|
801
|
-
i11 = 0;
|
|
802
|
-
if (++i12 == ne2) {
|
|
803
|
-
i12 = 0;
|
|
804
|
-
if (++i13 == ne3) {
|
|
805
|
-
i13 = 0;
|
|
806
|
-
}
|
|
807
|
-
}
|
|
808
|
-
}
|
|
809
|
-
}
|
|
810
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
811
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
812
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
813
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
814
|
-
|
|
815
|
-
memcpy(dst_ptr, src0_ptr, sizeof(float));
|
|
816
|
-
|
|
817
|
-
if (++i10 == ne0) {
|
|
818
|
-
i10 = 0;
|
|
819
|
-
if (++i11 == ne1) {
|
|
820
|
-
i11 = 0;
|
|
821
|
-
if (++i12 == ne2) {
|
|
822
|
-
i12 = 0;
|
|
823
|
-
if (++i13 == ne3) {
|
|
824
|
-
i13 = 0;
|
|
825
|
-
}
|
|
826
|
-
}
|
|
827
|
-
}
|
|
828
|
-
}
|
|
829
|
-
}
|
|
830
|
-
}
|
|
831
|
-
i10 += ne00 * (ne01 - ir1);
|
|
832
|
-
while (i10 >= ne0) {
|
|
833
|
-
i10 -= ne0;
|
|
834
|
-
if (++i11 == ne1) {
|
|
835
|
-
i11 = 0;
|
|
836
|
-
if (++i12 == ne2) {
|
|
837
|
-
i12 = 0;
|
|
838
|
-
if (++i13 == ne3) {
|
|
839
|
-
i13 = 0;
|
|
840
|
-
}
|
|
132
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
|
133
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
|
134
|
+
id += ne00 * ir0;
|
|
135
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
136
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
|
137
|
+
const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
138
|
+
|
|
139
|
+
float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
|
|
140
|
+
dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
|
|
141
|
+
id++;
|
|
841
142
|
}
|
|
842
143
|
}
|
|
144
|
+
id += ne00 * (ne01 - ir1);
|
|
843
145
|
}
|
|
844
146
|
}
|
|
845
147
|
}
|
|
846
|
-
|
|
148
|
+
return;
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
// dst counters
|
|
152
|
+
int64_t i10 = 0;
|
|
153
|
+
int64_t i11 = 0;
|
|
154
|
+
int64_t i12 = 0;
|
|
155
|
+
int64_t i13 = 0;
|
|
156
|
+
|
|
157
|
+
if constexpr (std::is_same_v<dst_t, src_t>) {
|
|
847
158
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
848
159
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
849
160
|
i10 += ne00 * ir0;
|
|
@@ -864,15 +175,15 @@ static void wsp_ggml_compute_forward_dup_f32(
|
|
|
864
175
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
865
176
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
866
177
|
|
|
867
|
-
|
|
178
|
+
memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
|
|
868
179
|
|
|
869
|
-
if (++i10 ==
|
|
180
|
+
if (++i10 == ne00) {
|
|
870
181
|
i10 = 0;
|
|
871
|
-
if (++i11 ==
|
|
182
|
+
if (++i11 == ne01) {
|
|
872
183
|
i11 = 0;
|
|
873
|
-
if (++i12 ==
|
|
184
|
+
if (++i12 == ne02) {
|
|
874
185
|
i12 = 0;
|
|
875
|
-
if (++i13 ==
|
|
186
|
+
if (++i13 == ne03) {
|
|
876
187
|
i13 = 0;
|
|
877
188
|
}
|
|
878
189
|
}
|
|
@@ -895,7 +206,8 @@ static void wsp_ggml_compute_forward_dup_f32(
|
|
|
895
206
|
}
|
|
896
207
|
}
|
|
897
208
|
}
|
|
898
|
-
|
|
209
|
+
|
|
210
|
+
} else {
|
|
899
211
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
900
212
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
901
213
|
i10 += ne00 * ir0;
|
|
@@ -916,7 +228,8 @@ static void wsp_ggml_compute_forward_dup_f32(
|
|
|
916
228
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
917
229
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
918
230
|
|
|
919
|
-
|
|
231
|
+
float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
|
|
232
|
+
*(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
|
|
920
233
|
|
|
921
234
|
if (++i10 == ne0) {
|
|
922
235
|
i10 = 0;
|
|
@@ -947,8 +260,63 @@ static void wsp_ggml_compute_forward_dup_f32(
|
|
|
947
260
|
}
|
|
948
261
|
}
|
|
949
262
|
}
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
template<typename src_t>
|
|
268
|
+
static void wsp_ggml_compute_forward_dup_to_q(
|
|
269
|
+
const wsp_ggml_compute_params * params,
|
|
270
|
+
wsp_ggml_tensor * dst) {
|
|
271
|
+
|
|
272
|
+
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
273
|
+
|
|
274
|
+
WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
|
|
275
|
+
WSP_GGML_ASSERT(!wsp_ggml_is_quantized(src0->type));
|
|
276
|
+
|
|
277
|
+
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
278
|
+
|
|
279
|
+
const int ith = params->ith; // thread index
|
|
280
|
+
const int nth = params->nth; // number of threads
|
|
281
|
+
|
|
282
|
+
// parallelize by rows
|
|
283
|
+
const int nr = ne01;
|
|
284
|
+
// number of rows per thread
|
|
285
|
+
const int dr = (nr + nth - 1) / nth;
|
|
286
|
+
// row range for this thread
|
|
287
|
+
const int ir0 = dr * ith;
|
|
288
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
289
|
+
|
|
290
|
+
if (wsp_ggml_is_contiguous(dst) &&
|
|
291
|
+
nb00 == sizeof(src_t) &&
|
|
292
|
+
wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
293
|
+
// casting non-quantized types --> intermediate f32 --> quantized
|
|
294
|
+
wsp_ggml_from_float_t const wsp_quantize_row_q = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
295
|
+
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
296
|
+
|
|
297
|
+
size_t id = 0;
|
|
298
|
+
size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
|
|
299
|
+
char * dst_ptr = (char *) dst->data;
|
|
300
|
+
|
|
301
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
|
302
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
|
303
|
+
id += rs * ir0;
|
|
304
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
305
|
+
const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
306
|
+
|
|
307
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
|
308
|
+
src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
|
|
309
|
+
}
|
|
310
|
+
|
|
311
|
+
wsp_quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
|
312
|
+
id += rs;
|
|
313
|
+
}
|
|
314
|
+
id += rs * (ne01 - ir1);
|
|
315
|
+
}
|
|
316
|
+
}
|
|
950
317
|
} else {
|
|
951
|
-
|
|
318
|
+
// printf("%s %s\n", wsp_ggml_type_name(src0->type), wsp_ggml_type_name(dst->type));
|
|
319
|
+
WSP_GGML_ABORT("not implemented");
|
|
952
320
|
}
|
|
953
321
|
}
|
|
954
322
|
|
|
@@ -1102,7 +470,7 @@ static void wsp_ggml_compute_forward_dup_bytes(
|
|
|
1102
470
|
}
|
|
1103
471
|
}
|
|
1104
472
|
|
|
1105
|
-
static void
|
|
473
|
+
static void wsp_ggml_compute_forward_dup_from_q(
|
|
1106
474
|
const wsp_ggml_compute_params * params,
|
|
1107
475
|
wsp_ggml_tensor * dst) {
|
|
1108
476
|
|
|
@@ -1167,20 +535,35 @@ void wsp_ggml_compute_forward_dup(
|
|
|
1167
535
|
switch (src0->type) {
|
|
1168
536
|
case WSP_GGML_TYPE_F16:
|
|
1169
537
|
{
|
|
1170
|
-
|
|
538
|
+
/**/ if (dst->type == WSP_GGML_TYPE_F16) wsp_ggml_compute_forward_dup_flt<wsp_ggml_fp16_t, wsp_ggml_fp16_t>(params, dst);
|
|
539
|
+
else if (dst->type == WSP_GGML_TYPE_BF16) wsp_ggml_compute_forward_dup_flt<wsp_ggml_fp16_t, wsp_ggml_bf16_t>(params, dst);
|
|
540
|
+
else if (dst->type == WSP_GGML_TYPE_F32) wsp_ggml_compute_forward_dup_flt<wsp_ggml_fp16_t, float >(params, dst);
|
|
541
|
+
else wsp_ggml_compute_forward_dup_to_q<wsp_ggml_fp16_t>(params, dst);
|
|
1171
542
|
} break;
|
|
1172
543
|
case WSP_GGML_TYPE_BF16:
|
|
1173
544
|
{
|
|
1174
|
-
|
|
545
|
+
/**/ if (dst->type == WSP_GGML_TYPE_F16) wsp_ggml_compute_forward_dup_flt<wsp_ggml_bf16_t, wsp_ggml_fp16_t>(params, dst);
|
|
546
|
+
else if (dst->type == WSP_GGML_TYPE_BF16) wsp_ggml_compute_forward_dup_flt<wsp_ggml_bf16_t, wsp_ggml_bf16_t>(params, dst);
|
|
547
|
+
else if (dst->type == WSP_GGML_TYPE_F32) wsp_ggml_compute_forward_dup_flt<wsp_ggml_bf16_t, float >(params, dst);
|
|
548
|
+
else wsp_ggml_compute_forward_dup_to_q<wsp_ggml_bf16_t>(params, dst);
|
|
1175
549
|
} break;
|
|
1176
550
|
case WSP_GGML_TYPE_F32:
|
|
1177
551
|
{
|
|
1178
|
-
|
|
552
|
+
/**/ if (dst->type == WSP_GGML_TYPE_F16) wsp_ggml_compute_forward_dup_flt<float, wsp_ggml_fp16_t>(params, dst);
|
|
553
|
+
else if (dst->type == WSP_GGML_TYPE_BF16) wsp_ggml_compute_forward_dup_flt<float, wsp_ggml_bf16_t>(params, dst);
|
|
554
|
+
else if (dst->type == WSP_GGML_TYPE_F32) wsp_ggml_compute_forward_dup_flt<float, float >(params, dst);
|
|
555
|
+
else if (dst->type == WSP_GGML_TYPE_I32) wsp_ggml_compute_forward_dup_flt<float, int32_t >(params, dst);
|
|
556
|
+
else wsp_ggml_compute_forward_dup_to_q<float>(params, dst);
|
|
557
|
+
} break;
|
|
558
|
+
case WSP_GGML_TYPE_I32:
|
|
559
|
+
{
|
|
560
|
+
if (dst->type == WSP_GGML_TYPE_F32) wsp_ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
|
|
561
|
+
else WSP_GGML_ABORT("not implemented");
|
|
1179
562
|
} break;
|
|
1180
563
|
default:
|
|
1181
564
|
{
|
|
1182
565
|
if (wsp_ggml_is_quantized(src0->type) && dst->type == WSP_GGML_TYPE_F32) {
|
|
1183
|
-
|
|
566
|
+
wsp_ggml_compute_forward_dup_from_q(params, dst);
|
|
1184
567
|
break;
|
|
1185
568
|
}
|
|
1186
569
|
WSP_GGML_ABORT("fatal error");
|
|
@@ -4084,31 +3467,27 @@ static void wsp_ggml_compute_forward_norm_f32(
|
|
|
4084
3467
|
|
|
4085
3468
|
WSP_GGML_ASSERT(eps >= 0.0f);
|
|
4086
3469
|
|
|
4087
|
-
// TODO: optimize
|
|
4088
3470
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
4089
3471
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
4090
3472
|
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
|
4091
3473
|
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
4092
3474
|
|
|
4093
|
-
|
|
4094
|
-
|
|
4095
|
-
sum += (wsp_ggml_float)x[i00];
|
|
4096
|
-
}
|
|
4097
|
-
|
|
3475
|
+
float sum = 0.0;
|
|
3476
|
+
wsp_ggml_vec_sum_f32(ne00, &sum, x);
|
|
4098
3477
|
float mean = sum/ne00;
|
|
4099
3478
|
|
|
4100
3479
|
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
3480
|
+
float variance = 0;
|
|
4101
3481
|
|
|
4102
|
-
|
|
4103
|
-
|
|
4104
|
-
|
|
4105
|
-
|
|
4106
|
-
|
|
4107
|
-
|
|
3482
|
+
#ifdef WSP_GGML_USE_ACCELERATE
|
|
3483
|
+
mean = -mean;
|
|
3484
|
+
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
|
|
3485
|
+
vDSP_measqv(y, 1, &variance, ne00);
|
|
3486
|
+
#else
|
|
3487
|
+
variance = wsp_ggml_vec_cvar_f32(ne00, y, x, mean);
|
|
3488
|
+
#endif //WSP_GGML_USE_ACCELERATE
|
|
4108
3489
|
|
|
4109
|
-
float variance = sum2/ne00;
|
|
4110
3490
|
const float scale = 1.0f/sqrtf(variance + eps);
|
|
4111
|
-
|
|
4112
3491
|
wsp_ggml_vec_scale_f32(ne00, y, scale);
|
|
4113
3492
|
}
|
|
4114
3493
|
}
|
|
@@ -5356,6 +4735,7 @@ void wsp_ggml_compute_forward_get_rows(
|
|
|
5356
4735
|
//}
|
|
5357
4736
|
}
|
|
5358
4737
|
|
|
4738
|
+
template<typename idx_t>
|
|
5359
4739
|
static void wsp_ggml_compute_forward_set_rows_f32(
|
|
5360
4740
|
const wsp_ggml_compute_params * params,
|
|
5361
4741
|
wsp_ggml_tensor * dst) {
|
|
@@ -5394,7 +4774,7 @@ static void wsp_ggml_compute_forward_set_rows_f32(
|
|
|
5394
4774
|
const int64_t i11 = i02%ne11;
|
|
5395
4775
|
const int64_t i10 = i;
|
|
5396
4776
|
|
|
5397
|
-
const int64_t i1 = *(
|
|
4777
|
+
const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
|
5398
4778
|
|
|
5399
4779
|
WSP_GGML_ASSERT(i1 >= 0 && i1 < ne1);
|
|
5400
4780
|
|
|
@@ -5411,11 +4791,18 @@ void wsp_ggml_compute_forward_set_rows(
|
|
|
5411
4791
|
wsp_ggml_tensor * dst) {
|
|
5412
4792
|
|
|
5413
4793
|
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
4794
|
+
const wsp_ggml_tensor * src1 = dst->src[1];
|
|
5414
4795
|
|
|
5415
4796
|
switch (src0->type) {
|
|
5416
4797
|
case WSP_GGML_TYPE_F32:
|
|
5417
4798
|
{
|
|
5418
|
-
|
|
4799
|
+
if (src1->type == WSP_GGML_TYPE_I64) {
|
|
4800
|
+
wsp_ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
|
|
4801
|
+
} else if (src1->type == WSP_GGML_TYPE_I32) {
|
|
4802
|
+
wsp_ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
|
|
4803
|
+
} else {
|
|
4804
|
+
WSP_GGML_ABORT("src1->type = %d (%s) not supported", src1->type, wsp_ggml_type_name(src1->type));
|
|
4805
|
+
}
|
|
5419
4806
|
} break;
|
|
5420
4807
|
default:
|
|
5421
4808
|
{
|
|
@@ -7027,6 +6414,209 @@ void wsp_ggml_compute_forward_im2col_back_f32(
|
|
|
7027
6414
|
}
|
|
7028
6415
|
}
|
|
7029
6416
|
|
|
6417
|
+
|
|
6418
|
+
// wsp_ggml_compute_forward_im2col_3d_f16
|
|
6419
|
+
// src0: kernel [OC*IC, KD, KH, KW]
|
|
6420
|
+
// src1: image [N*IC, ID, IH, IW]
|
|
6421
|
+
// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
|
|
6422
|
+
static void wsp_ggml_compute_forward_im2col_3d_f16(
|
|
6423
|
+
const wsp_ggml_compute_params * params,
|
|
6424
|
+
wsp_ggml_tensor * dst) {
|
|
6425
|
+
|
|
6426
|
+
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
6427
|
+
const wsp_ggml_tensor * src1 = dst->src[1];
|
|
6428
|
+
|
|
6429
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
|
|
6430
|
+
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
6431
|
+
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
|
|
6432
|
+
|
|
6433
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS;
|
|
6434
|
+
|
|
6435
|
+
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
|
6436
|
+
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
|
6437
|
+
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
|
|
6438
|
+
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
|
|
6439
|
+
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
|
|
6440
|
+
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
|
|
6441
|
+
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
|
|
6442
|
+
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
|
|
6443
|
+
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
|
|
6444
|
+
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
|
|
6445
|
+
|
|
6446
|
+
|
|
6447
|
+
const int ith = params->ith;
|
|
6448
|
+
const int nth = params->nth;
|
|
6449
|
+
|
|
6450
|
+
const int64_t N = ne13 / IC;
|
|
6451
|
+
const int64_t ID = ne12;
|
|
6452
|
+
const int64_t IH = ne11;
|
|
6453
|
+
const int64_t IW = ne10;
|
|
6454
|
+
|
|
6455
|
+
const int64_t OC = ne03 / IC;
|
|
6456
|
+
WSP_GGML_UNUSED(OC);
|
|
6457
|
+
const int64_t KD = ne02;
|
|
6458
|
+
const int64_t KH = ne01;
|
|
6459
|
+
const int64_t KW = ne00;
|
|
6460
|
+
|
|
6461
|
+
const int64_t OD = ne3 / N;
|
|
6462
|
+
const int64_t OH = ne2;
|
|
6463
|
+
const int64_t OW = ne1;
|
|
6464
|
+
const int64_t OH_OW = OH*OW;
|
|
6465
|
+
const int64_t KD_KH_KW = KD*KH*KW;
|
|
6466
|
+
const int64_t KH_KW = KH*KW;
|
|
6467
|
+
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
|
|
6468
|
+
|
|
6469
|
+
WSP_GGML_ASSERT(nb10 == sizeof(float));
|
|
6470
|
+
|
|
6471
|
+
// im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
|
|
6472
|
+
{
|
|
6473
|
+
wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data;
|
|
6474
|
+
|
|
6475
|
+
for (int64_t in = 0; in < N; in++) {
|
|
6476
|
+
for (int64_t iod = 0; iod < OD; iod++) {
|
|
6477
|
+
for (int64_t ioh = 0; ioh < OH; ioh++) {
|
|
6478
|
+
for (int64_t iow = 0; iow < OW; iow++) {
|
|
6479
|
+
for (int64_t iic = ith; iic < IC; iic += nth) {
|
|
6480
|
+
|
|
6481
|
+
// micro kernel
|
|
6482
|
+
wsp_ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
|
|
6483
|
+
const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
|
|
6484
|
+
|
|
6485
|
+
for (int64_t ikd = 0; ikd < KD; ikd++) {
|
|
6486
|
+
for (int64_t ikh = 0; ikh < KH; ikh++) {
|
|
6487
|
+
for (int64_t ikw = 0; ikw < KW; ikw++) {
|
|
6488
|
+
const int64_t iiw = iow*s0 + ikw*d0 - p0;
|
|
6489
|
+
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
|
6490
|
+
const int64_t iid = iod*s2 + ikd*d2 - p2;
|
|
6491
|
+
|
|
6492
|
+
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
|
|
6493
|
+
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
|
|
6494
|
+
} else {
|
|
6495
|
+
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
|
|
6496
|
+
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = WSP_GGML_CPU_FP32_TO_FP16(*s);
|
|
6497
|
+
}
|
|
6498
|
+
}
|
|
6499
|
+
}
|
|
6500
|
+
}
|
|
6501
|
+
}
|
|
6502
|
+
}
|
|
6503
|
+
}
|
|
6504
|
+
}
|
|
6505
|
+
}
|
|
6506
|
+
}
|
|
6507
|
+
}
|
|
6508
|
+
|
|
6509
|
+
// wsp_ggml_compute_forward_im2col_3d_f32
|
|
6510
|
+
// src0: kernel [OC*IC, KD, KH, KW]
|
|
6511
|
+
// src1: image [N*IC, ID, IH, IW]
|
|
6512
|
+
// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
|
|
6513
|
+
static void wsp_ggml_compute_forward_im2col_3d_f32(
|
|
6514
|
+
const wsp_ggml_compute_params * params,
|
|
6515
|
+
wsp_ggml_tensor * dst) {
|
|
6516
|
+
|
|
6517
|
+
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
6518
|
+
const wsp_ggml_tensor * src1 = dst->src[1];
|
|
6519
|
+
|
|
6520
|
+
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
6521
|
+
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
|
|
6522
|
+
|
|
6523
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS;
|
|
6524
|
+
|
|
6525
|
+
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
|
6526
|
+
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
|
6527
|
+
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
|
|
6528
|
+
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
|
|
6529
|
+
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
|
|
6530
|
+
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
|
|
6531
|
+
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
|
|
6532
|
+
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
|
|
6533
|
+
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
|
|
6534
|
+
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
|
|
6535
|
+
|
|
6536
|
+
|
|
6537
|
+
const int ith = params->ith;
|
|
6538
|
+
const int nth = params->nth;
|
|
6539
|
+
|
|
6540
|
+
const int64_t N = ne13 / IC;
|
|
6541
|
+
const int64_t ID = ne12;
|
|
6542
|
+
const int64_t IH = ne11;
|
|
6543
|
+
const int64_t IW = ne10;
|
|
6544
|
+
|
|
6545
|
+
const int64_t OC = ne03 / IC;
|
|
6546
|
+
WSP_GGML_UNUSED(OC);
|
|
6547
|
+
const int64_t KD = ne02;
|
|
6548
|
+
const int64_t KH = ne01;
|
|
6549
|
+
const int64_t KW = ne00;
|
|
6550
|
+
|
|
6551
|
+
const int64_t OD = ne3 / N;
|
|
6552
|
+
const int64_t OH = ne2;
|
|
6553
|
+
const int64_t OW = ne1;
|
|
6554
|
+
|
|
6555
|
+
const int64_t OH_OW = OH*OW;
|
|
6556
|
+
const int64_t KD_KH_KW = KD*KH*KW;
|
|
6557
|
+
const int64_t KH_KW = KH*KW;
|
|
6558
|
+
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
|
|
6559
|
+
|
|
6560
|
+
WSP_GGML_ASSERT(nb10 == sizeof(float));
|
|
6561
|
+
|
|
6562
|
+
// im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
|
|
6563
|
+
{
|
|
6564
|
+
float * const wdata = (float *) dst->data;
|
|
6565
|
+
|
|
6566
|
+
for (int64_t in = 0; in < N; in++) {
|
|
6567
|
+
for (int64_t iod = 0; iod < OD; iod++) {
|
|
6568
|
+
for (int64_t ioh = 0; ioh < OH; ioh++) {
|
|
6569
|
+
for (int64_t iow = 0; iow < OW; iow++) {
|
|
6570
|
+
for (int64_t iic = ith; iic < IC; iic += nth) {
|
|
6571
|
+
|
|
6572
|
+
// micro kernel
|
|
6573
|
+
float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
|
|
6574
|
+
const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
|
|
6575
|
+
|
|
6576
|
+
for (int64_t ikd = 0; ikd < KD; ikd++) {
|
|
6577
|
+
for (int64_t ikh = 0; ikh < KH; ikh++) {
|
|
6578
|
+
for (int64_t ikw = 0; ikw < KW; ikw++) {
|
|
6579
|
+
const int64_t iiw = iow*s0 + ikw*d0 - p0;
|
|
6580
|
+
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
|
6581
|
+
const int64_t iid = iod*s2 + ikd*d2 - p2;
|
|
6582
|
+
|
|
6583
|
+
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
|
|
6584
|
+
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
|
|
6585
|
+
} else {
|
|
6586
|
+
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
|
|
6587
|
+
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
|
|
6588
|
+
}
|
|
6589
|
+
}
|
|
6590
|
+
}
|
|
6591
|
+
}
|
|
6592
|
+
}
|
|
6593
|
+
}
|
|
6594
|
+
}
|
|
6595
|
+
}
|
|
6596
|
+
}
|
|
6597
|
+
}
|
|
6598
|
+
}
|
|
6599
|
+
|
|
6600
|
+
|
|
6601
|
+
void wsp_ggml_compute_forward_im2col_3d(
|
|
6602
|
+
const wsp_ggml_compute_params * params,
|
|
6603
|
+
wsp_ggml_tensor * dst) {
|
|
6604
|
+
switch (dst->type) {
|
|
6605
|
+
case WSP_GGML_TYPE_F16:
|
|
6606
|
+
{
|
|
6607
|
+
wsp_ggml_compute_forward_im2col_3d_f16(params, dst);
|
|
6608
|
+
} break;
|
|
6609
|
+
case WSP_GGML_TYPE_F32:
|
|
6610
|
+
{
|
|
6611
|
+
wsp_ggml_compute_forward_im2col_3d_f32(params, dst);
|
|
6612
|
+
} break;
|
|
6613
|
+
default:
|
|
6614
|
+
{
|
|
6615
|
+
WSP_GGML_ABORT("fatal error");
|
|
6616
|
+
}
|
|
6617
|
+
}
|
|
6618
|
+
}
|
|
6619
|
+
|
|
7030
6620
|
static void wsp_ggml_call_mul_mat(wsp_ggml_type type, const wsp_ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
|
|
7031
6621
|
void * a, void * b, float * c) {
|
|
7032
6622
|
const wsp_ggml_type_traits * traits = wsp_ggml_get_type_traits(type);
|
|
@@ -7207,6 +6797,148 @@ void wsp_ggml_compute_forward_conv_2d(
|
|
|
7207
6797
|
wsp_ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
|
|
7208
6798
|
}
|
|
7209
6799
|
|
|
6800
|
+
// wsp_ggml_compute_forward_conv_3d
|
|
6801
|
+
|
|
6802
|
+
static void wsp_ggml_compute_forward_conv_3d_impl(const wsp_ggml_compute_params * params,
|
|
6803
|
+
const wsp_ggml_tensor * kernel,
|
|
6804
|
+
const wsp_ggml_tensor * src,
|
|
6805
|
+
wsp_ggml_tensor * dst,
|
|
6806
|
+
wsp_ggml_type kernel_type) {
|
|
6807
|
+
|
|
6808
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(kernel));
|
|
6809
|
+
WSP_GGML_ASSERT(kernel_type == WSP_GGML_TYPE_F16 || kernel_type == WSP_GGML_TYPE_F32);
|
|
6810
|
+
WSP_GGML_ASSERT(kernel->type == kernel_type);
|
|
6811
|
+
|
|
6812
|
+
const wsp_ggml_type_traits * traits = wsp_ggml_get_type_traits(kernel_type);
|
|
6813
|
+
|
|
6814
|
+
const int32_t s0 = dst->op_params[0];
|
|
6815
|
+
const int32_t s1 = dst->op_params[1];
|
|
6816
|
+
const int32_t s2 = dst->op_params[2];
|
|
6817
|
+
const int32_t p0 = dst->op_params[3];
|
|
6818
|
+
const int32_t p1 = dst->op_params[4];
|
|
6819
|
+
const int32_t p2 = dst->op_params[5];
|
|
6820
|
+
const int32_t d0 = dst->op_params[6];
|
|
6821
|
+
const int32_t d1 = dst->op_params[7];
|
|
6822
|
+
const int32_t d2 = dst->op_params[8];
|
|
6823
|
+
const int32_t c = dst->op_params[9];
|
|
6824
|
+
const int32_t n = dst->op_params[10];
|
|
6825
|
+
const int32_t oc = dst->op_params[11];
|
|
6826
|
+
|
|
6827
|
+
const int64_t src_w = src->ne[0];
|
|
6828
|
+
const int64_t src_h = src->ne[1];
|
|
6829
|
+
const int64_t src_d = src->ne[2];
|
|
6830
|
+
const int64_t knl_w = kernel->ne[0];
|
|
6831
|
+
const int64_t knl_h = kernel->ne[1];
|
|
6832
|
+
const int64_t knl_d = kernel->ne[2];
|
|
6833
|
+
const int64_t dst_w = dst->ne[0];
|
|
6834
|
+
const int64_t dst_h = dst->ne[1];
|
|
6835
|
+
const int64_t dst_d = dst->ne[2];
|
|
6836
|
+
|
|
6837
|
+
const float * src_data = (float *) src->data;
|
|
6838
|
+
void * knl_data = kernel->data;
|
|
6839
|
+
float * dst_data = (float *) dst->data;
|
|
6840
|
+
|
|
6841
|
+
const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
|
|
6842
|
+
const int64_t knl_n_total = knl_n_per_channel * c;
|
|
6843
|
+
const int64_t patch_total = n * dst_w * dst_h * dst_d;
|
|
6844
|
+
|
|
6845
|
+
const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
|
|
6846
|
+
const int64_t batch_size = params->wsize / space_per_patch;
|
|
6847
|
+
const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
|
|
6848
|
+
const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
|
|
6849
|
+
|
|
6850
|
+
WSP_GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
|
|
6851
|
+
|
|
6852
|
+
void * tmp = params->wdata;
|
|
6853
|
+
|
|
6854
|
+
for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
|
|
6855
|
+
const int64_t patch_start_batch = batch_i * patches_per_batch;
|
|
6856
|
+
const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
|
|
6857
|
+
const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
|
|
6858
|
+
|
|
6859
|
+
const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
|
|
6860
|
+
const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
|
|
6861
|
+
const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
|
|
6862
|
+
|
|
6863
|
+
for (int64_t p = patch_start; p < patch_end; ++p) {
|
|
6864
|
+
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
|
|
6865
|
+
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
|
|
6866
|
+
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
|
|
6867
|
+
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
|
6868
|
+
const int64_t dst_y = p_in_depth / dst_w;
|
|
6869
|
+
const int64_t dst_x = p_in_depth % dst_w;
|
|
6870
|
+
|
|
6871
|
+
char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
|
|
6872
|
+
|
|
6873
|
+
for (int64_t ic = 0; ic < c; ++ic) {
|
|
6874
|
+
for (int64_t kz = 0; kz < knl_d; ++kz) {
|
|
6875
|
+
for (int64_t ky = 0; ky < knl_h; ++ky) {
|
|
6876
|
+
for (int64_t kx = 0; kx < knl_w; ++kx) {
|
|
6877
|
+
const int64_t sz = dst_z * s2 + kz * d2 - p2;
|
|
6878
|
+
const int64_t sy = dst_y * s1 + ky * d1 - p1;
|
|
6879
|
+
const int64_t sx = dst_x * s0 + kx * d0 - p0;
|
|
6880
|
+
|
|
6881
|
+
int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
|
|
6882
|
+
|
|
6883
|
+
float src_val;
|
|
6884
|
+
if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
|
|
6885
|
+
src_val = 0.0f;
|
|
6886
|
+
} else {
|
|
6887
|
+
const int64_t cn_idx = batch_idx * c + ic;
|
|
6888
|
+
const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
|
|
6889
|
+
src_val = *src_ptr;
|
|
6890
|
+
}
|
|
6891
|
+
|
|
6892
|
+
char * element_ptr = dst_row + dst_idx * traits->type_size;
|
|
6893
|
+
if (kernel_type == WSP_GGML_TYPE_F32) {
|
|
6894
|
+
*(float *)element_ptr = src_val;
|
|
6895
|
+
} else if (kernel_type == WSP_GGML_TYPE_F16) {
|
|
6896
|
+
*(wsp_ggml_fp16_t *)element_ptr = WSP_GGML_CPU_FP32_TO_FP16(src_val);
|
|
6897
|
+
}
|
|
6898
|
+
}
|
|
6899
|
+
}
|
|
6900
|
+
}
|
|
6901
|
+
}
|
|
6902
|
+
}
|
|
6903
|
+
|
|
6904
|
+
wsp_ggml_barrier(params->threadpool);
|
|
6905
|
+
|
|
6906
|
+
float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
|
|
6907
|
+
wsp_ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
|
|
6908
|
+
|
|
6909
|
+
wsp_ggml_barrier(params->threadpool);
|
|
6910
|
+
|
|
6911
|
+
const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
|
|
6912
|
+
const int64_t permute_start = params->ith * permute_per_thread;
|
|
6913
|
+
const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
|
|
6914
|
+
|
|
6915
|
+
for (int64_t i = permute_start; i < permute_end; ++i) {
|
|
6916
|
+
const int64_t p = patch_start_batch + i;
|
|
6917
|
+
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
|
|
6918
|
+
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
|
|
6919
|
+
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
|
|
6920
|
+
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
|
6921
|
+
const int64_t dst_y = p_in_depth / dst_w;
|
|
6922
|
+
const int64_t dst_x = p_in_depth % dst_w;
|
|
6923
|
+
|
|
6924
|
+
for (int64_t ioc = 0; ioc < oc; ++ioc) {
|
|
6925
|
+
const float value = gemm_output[i * oc + ioc];
|
|
6926
|
+
const int64_t ocn_idx = batch_idx * oc + ioc;
|
|
6927
|
+
float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
|
|
6928
|
+
*dst_ptr = value;
|
|
6929
|
+
}
|
|
6930
|
+
}
|
|
6931
|
+
}
|
|
6932
|
+
}
|
|
6933
|
+
|
|
6934
|
+
void wsp_ggml_compute_forward_conv_3d(
|
|
6935
|
+
const wsp_ggml_compute_params * params,
|
|
6936
|
+
wsp_ggml_tensor * dst) {
|
|
6937
|
+
const wsp_ggml_tensor * src0 = dst->src[0];
|
|
6938
|
+
const wsp_ggml_tensor * src1 = dst->src[1];
|
|
6939
|
+
wsp_ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
|
|
6940
|
+
}
|
|
6941
|
+
|
|
7210
6942
|
// wsp_ggml_compute_forward_conv_transpose_2d
|
|
7211
6943
|
|
|
7212
6944
|
void wsp_ggml_compute_forward_conv_transpose_2d(
|
|
@@ -7872,6 +7604,15 @@ static void wsp_ggml_compute_forward_pad_f32(
|
|
|
7872
7604
|
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
7873
7605
|
|
|
7874
7606
|
float * dst_ptr = (float *) dst->data;
|
|
7607
|
+
const int32_t lp0 = wsp_ggml_get_op_params_i32(dst, 0);
|
|
7608
|
+
const int32_t rp0 = wsp_ggml_get_op_params_i32(dst, 1);
|
|
7609
|
+
const int32_t lp1 = wsp_ggml_get_op_params_i32(dst, 2);
|
|
7610
|
+
const int32_t rp1 = wsp_ggml_get_op_params_i32(dst, 3);
|
|
7611
|
+
const int32_t lp2 = wsp_ggml_get_op_params_i32(dst, 4);
|
|
7612
|
+
const int32_t rp2 = wsp_ggml_get_op_params_i32(dst, 5);
|
|
7613
|
+
const int32_t lp3 = wsp_ggml_get_op_params_i32(dst, 6);
|
|
7614
|
+
const int32_t rp3 = wsp_ggml_get_op_params_i32(dst, 7);
|
|
7615
|
+
|
|
7875
7616
|
|
|
7876
7617
|
// TODO: optimize
|
|
7877
7618
|
|
|
@@ -7880,10 +7621,12 @@ static void wsp_ggml_compute_forward_pad_f32(
|
|
|
7880
7621
|
for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
|
7881
7622
|
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
|
7882
7623
|
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
|
7883
|
-
|
|
7884
|
-
|
|
7885
|
-
|
|
7886
|
-
|
|
7624
|
+
if ((i0 >= lp0 && i0 < ne0 - rp0) \
|
|
7625
|
+
&& (i1 >= lp1 && i1 < ne1 - rp1) \
|
|
7626
|
+
&& (i2 >= lp2 && i2 < ne2 - rp2) \
|
|
7627
|
+
&& (i3 >= lp3 && i3 < ne3 - rp3)) {
|
|
7628
|
+
const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
|
|
7629
|
+
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
|
|
7887
7630
|
dst_ptr[dst_idx] = *src_ptr;
|
|
7888
7631
|
} else {
|
|
7889
7632
|
dst_ptr[dst_idx] = 0;
|
|
@@ -8082,7 +7825,7 @@ static void wsp_ggml_compute_forward_timestep_embedding_f32(
|
|
|
8082
7825
|
embed_data[j + half] = sinf(arg);
|
|
8083
7826
|
}
|
|
8084
7827
|
if (dim % 2 != 0 && ith == 0) {
|
|
8085
|
-
embed_data[
|
|
7828
|
+
embed_data[2 * half] = 0.f;
|
|
8086
7829
|
}
|
|
8087
7830
|
}
|
|
8088
7831
|
}
|
|
@@ -8388,7 +8131,7 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8388
8131
|
}
|
|
8389
8132
|
|
|
8390
8133
|
// V /= S
|
|
8391
|
-
const float S_inv = 1.0f/S;
|
|
8134
|
+
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
|
8392
8135
|
wsp_ggml_vec_scale_f32(DV, VKQ32, S_inv);
|
|
8393
8136
|
|
|
8394
8137
|
// dst indices
|
|
@@ -8861,8 +8604,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
|
|
|
8861
8604
|
WSP_GGML_ASSERT(src4->nb[0] == sizeof(float));
|
|
8862
8605
|
WSP_GGML_ASSERT(src5->nb[0] == sizeof(float));
|
|
8863
8606
|
WSP_GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
|
|
8864
|
-
|
|
8865
|
-
WSP_GGML_ASSERT((ng & -ng) == ng);
|
|
8607
|
+
WSP_GGML_ASSERT(nh % ng == 0);
|
|
8866
8608
|
|
|
8867
8609
|
// heads per thread
|
|
8868
8610
|
const int dh = (nh + nth - 1)/nth;
|
|
@@ -8891,8 +8633,9 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
|
|
|
8891
8633
|
// n_head
|
|
8892
8634
|
for (int h = ih0; h < ih1; ++h) {
|
|
8893
8635
|
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8894
|
-
const float dt_soft_plus =
|
|
8636
|
+
const float dt_soft_plus = wsp_ggml_softplus(dt[h]);
|
|
8895
8637
|
const float dA = expf(dt_soft_plus * A[h]);
|
|
8638
|
+
const int g = h / (nh / ng); // repeat_interleave
|
|
8896
8639
|
|
|
8897
8640
|
// dim
|
|
8898
8641
|
for (int i1 = 0; i1 < nr; ++i1) {
|
|
@@ -8915,8 +8658,8 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
|
|
|
8915
8658
|
// TODO: maybe unroll more?
|
|
8916
8659
|
for (int j = 0; j < 1; j++) {
|
|
8917
8660
|
WSP_GGML_F32_VEC t0 = WSP_GGML_F32_VEC_LOAD(s0 + i + j*wsp_ggml_f32_epr + ii*nc);
|
|
8918
|
-
WSP_GGML_F32_VEC t1 = WSP_GGML_F32_VEC_LOAD(B + i + j*wsp_ggml_f32_epr +
|
|
8919
|
-
WSP_GGML_F32_VEC t2 = WSP_GGML_F32_VEC_LOAD(C + i + j*wsp_ggml_f32_epr +
|
|
8661
|
+
WSP_GGML_F32_VEC t1 = WSP_GGML_F32_VEC_LOAD(B + i + j*wsp_ggml_f32_epr + g*nc);
|
|
8662
|
+
WSP_GGML_F32_VEC t2 = WSP_GGML_F32_VEC_LOAD(C + i + j*wsp_ggml_f32_epr + g*nc);
|
|
8920
8663
|
|
|
8921
8664
|
t0 = WSP_GGML_F32_VEC_MUL(t0, adA);
|
|
8922
8665
|
t1 = WSP_GGML_F32_VEC_MUL(t1, axdt);
|
|
@@ -8930,6 +8673,9 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
|
|
|
8930
8673
|
}
|
|
8931
8674
|
|
|
8932
8675
|
sumf = WSP_GGML_F32xt_REDUCE_ONE(sum);
|
|
8676
|
+
#elif defined(__riscv_v_intrinsic)
|
|
8677
|
+
// todo: RVV implementation
|
|
8678
|
+
const int np = 0;
|
|
8933
8679
|
#else
|
|
8934
8680
|
const int np = (nc & ~(WSP_GGML_F32_STEP - 1));
|
|
8935
8681
|
|
|
@@ -8945,8 +8691,8 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
|
|
|
8945
8691
|
for (int i = 0; i < np; i += WSP_GGML_F32_STEP) {
|
|
8946
8692
|
for (int j = 0; j < WSP_GGML_F32_ARR; j++) {
|
|
8947
8693
|
ax[j] = WSP_GGML_F32_VEC_LOAD(s0 + i + j*WSP_GGML_F32_EPR + ii*nc);
|
|
8948
|
-
ay[j] = WSP_GGML_F32_VEC_LOAD(B + i + j*WSP_GGML_F32_EPR +
|
|
8949
|
-
az[j] = WSP_GGML_F32_VEC_LOAD(C + i + j*WSP_GGML_F32_EPR +
|
|
8694
|
+
ay[j] = WSP_GGML_F32_VEC_LOAD(B + i + j*WSP_GGML_F32_EPR + g*nc);
|
|
8695
|
+
az[j] = WSP_GGML_F32_VEC_LOAD(C + i + j*WSP_GGML_F32_EPR + g*nc);
|
|
8950
8696
|
|
|
8951
8697
|
ax[j] = WSP_GGML_F32_VEC_MUL(ax[j], adA);
|
|
8952
8698
|
ay[j] = WSP_GGML_F32_VEC_MUL(ay[j], axdt);
|
|
@@ -8968,7 +8714,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
|
|
|
8968
8714
|
// d_state
|
|
8969
8715
|
for (int i0 = np; i0 < nc; ++i0) {
|
|
8970
8716
|
const int i = i0 + ii*nc;
|
|
8971
|
-
const int ig = i0 +
|
|
8717
|
+
const int ig = i0 + g*nc;
|
|
8972
8718
|
// state = prev_state * dA + dB * x
|
|
8973
8719
|
const float state = (s0[i] * dA) + (B[ig] * x_dt);
|
|
8974
8720
|
// y = rowwise_dotprod(state, C)
|
|
@@ -8984,7 +8730,8 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
|
|
|
8984
8730
|
// n_head
|
|
8985
8731
|
for (int h = ih0; h < ih1; ++h) {
|
|
8986
8732
|
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8987
|
-
const float dt_soft_plus =
|
|
8733
|
+
const float dt_soft_plus = wsp_ggml_softplus(dt[h]);
|
|
8734
|
+
const int g = h / (nh / ng); // repeat_interleave
|
|
8988
8735
|
|
|
8989
8736
|
// dim
|
|
8990
8737
|
for (int i1 = 0; i1 < nr; ++i1) {
|
|
@@ -8999,8 +8746,8 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
|
|
|
8999
8746
|
// TODO: what happens when (d_state % svcntw()) != 0?
|
|
9000
8747
|
for (int64_t k = 0; k < nc; k += svcntw()) {
|
|
9001
8748
|
svfloat32_t vA = WSP_GGML_F32_VEC_LOAD(&A[h*nc + k]);
|
|
9002
|
-
svfloat32_t vB = WSP_GGML_F32_VEC_LOAD(&B[k +
|
|
9003
|
-
svfloat32_t vC = WSP_GGML_F32_VEC_LOAD(&C[k +
|
|
8749
|
+
svfloat32_t vB = WSP_GGML_F32_VEC_LOAD(&B[k + g*nc]);
|
|
8750
|
+
svfloat32_t vC = WSP_GGML_F32_VEC_LOAD(&C[k + g*nc]);
|
|
9004
8751
|
svfloat32_t vs0 = WSP_GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
|
|
9005
8752
|
|
|
9006
8753
|
svfloat32_t t1 = WSP_GGML_F32_VEC_MUL(vdt_soft_plus, vA);
|
|
@@ -9020,7 +8767,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
|
|
|
9020
8767
|
// d_state
|
|
9021
8768
|
for (int i0 = 0; i0 < nc; ++i0) {
|
|
9022
8769
|
const int i = i0 + ii*nc;
|
|
9023
|
-
const int ig = i0 +
|
|
8770
|
+
const int ig = i0 + g*nc;
|
|
9024
8771
|
// state = prev_state * dA + dB * x
|
|
9025
8772
|
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
|
|
9026
8773
|
// y = rowwise_dotprod(state, C)
|
|
@@ -9246,6 +8993,26 @@ void wsp_ggml_compute_forward_unary(
|
|
|
9246
8993
|
{
|
|
9247
8994
|
wsp_ggml_compute_forward_exp(params, dst);
|
|
9248
8995
|
} break;
|
|
8996
|
+
case WSP_GGML_UNARY_OP_FLOOR:
|
|
8997
|
+
{
|
|
8998
|
+
wsp_ggml_compute_forward_floor(params, dst);
|
|
8999
|
+
} break;
|
|
9000
|
+
case WSP_GGML_UNARY_OP_CEIL:
|
|
9001
|
+
{
|
|
9002
|
+
wsp_ggml_compute_forward_ceil(params, dst);
|
|
9003
|
+
} break;
|
|
9004
|
+
case WSP_GGML_UNARY_OP_ROUND:
|
|
9005
|
+
{
|
|
9006
|
+
wsp_ggml_compute_forward_round(params, dst);
|
|
9007
|
+
} break;
|
|
9008
|
+
case WSP_GGML_UNARY_OP_TRUNC:
|
|
9009
|
+
{
|
|
9010
|
+
wsp_ggml_compute_forward_trunc(params, dst);
|
|
9011
|
+
} break;
|
|
9012
|
+
case WSP_GGML_UNARY_OP_XIELU:
|
|
9013
|
+
{
|
|
9014
|
+
wsp_ggml_compute_forward_xielu(params, dst);
|
|
9015
|
+
} break;
|
|
9249
9016
|
default:
|
|
9250
9017
|
{
|
|
9251
9018
|
WSP_GGML_ABORT("fatal error");
|
|
@@ -9881,8 +9648,8 @@ static void wsp_ggml_compute_forward_rwkv_wkv7_f32(
|
|
|
9881
9648
|
int64_t h_stride_2d = head_size * head_size;
|
|
9882
9649
|
|
|
9883
9650
|
#if defined(WSP_GGML_SIMD)
|
|
9884
|
-
#if defined(__ARM_FEATURE_SVE)
|
|
9885
|
-
// scalar Route to scalar implementation //TODO: Write SVE code
|
|
9651
|
+
#if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
|
|
9652
|
+
// scalar Route to scalar implementation //TODO: Write SVE code and RVV code
|
|
9886
9653
|
for (int64_t t = 0; t < T; t++) {
|
|
9887
9654
|
int64_t t_offset = t * t_stride;
|
|
9888
9655
|
int64_t state_offset = head_size * C * (t / (T / n_seqs));
|