whisper.rn 0.5.1 → 0.5.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/jni.cpp +12 -3
- package/cpp/ggml-alloc.c +38 -14
- package/cpp/ggml-backend-impl.h +0 -3
- package/cpp/ggml-backend.h +2 -0
- package/cpp/ggml-cpu/amx/amx.cpp +1 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
- package/cpp/ggml-cpu/ggml-cpu.c +17 -3
- package/cpp/ggml-cpu/ops.cpp +33 -17
- package/cpp/ggml-cpu/unary-ops.cpp +135 -0
- package/cpp/ggml-cpu/unary-ops.h +5 -0
- package/cpp/ggml-cpu/vec.cpp +66 -0
- package/cpp/ggml-cpu/vec.h +10 -8
- package/cpp/ggml-impl.h +51 -2
- package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
- package/cpp/ggml-metal/ggml-metal-device.cpp +199 -10
- package/cpp/ggml-metal/ggml-metal-device.h +18 -0
- package/cpp/ggml-metal/ggml-metal-device.m +27 -14
- package/cpp/ggml-metal/ggml-metal-impl.h +87 -7
- package/cpp/ggml-metal/ggml-metal-ops.cpp +513 -88
- package/cpp/ggml-metal/ggml-metal-ops.h +6 -0
- package/cpp/ggml-metal/ggml-metal.cpp +3 -3
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +166 -2
- package/cpp/ggml.h +66 -0
- package/cpp/jsi/RNWhisperJSI.cpp +7 -2
- package/cpp/rn-whisper.h +1 -0
- package/cpp/whisper.cpp +4 -2
- package/ios/RNWhisperContext.mm +3 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +51 -2
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +66 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +51 -2
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +66 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +51 -2
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +66 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +51 -2
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +66 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +2 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +2 -0
- package/src/version.json +1 -1
package/cpp/ggml-impl.h
CHANGED
@@ -102,6 +102,9 @@ static bool wsp_ggml_op_is_empty(enum wsp_ggml_op op) {
     }
 }
 
+static inline float wsp_ggml_softplus(float input) {
+    return (input > 20.0f) ? input : logf(1 + expf(input));
+}
 //
 // logging
 //
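Note: the new wsp_ggml_softplus helper is the numerically stable form of log(1 + exp(x)): once x > 20, exp(x) would saturate a float while the result already equals x to float precision. A standalone sketch of the same guard (plain C++, names ours, not part of the package):

```cpp
#include <cmath>
#include <cstdio>

// Stable softplus: for large x, log(1 + exp(x)) ~= x, so skip the exp()
// before it overflows; log1p keeps precision when exp(x) is small.
static float softplus(float x) {
    return (x > 20.0f) ? x : std::log1p(std::exp(x));
}

int main() {
    // softplus(100) returns 100 exactly; the naive form would hit inf
    std::printf("%f %f\n", softplus(-5.0f), softplus(100.0f));
    return 0;
}
```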
@@ -562,14 +565,23 @@ static inline wsp_ggml_bf16_t wsp_ggml_compute_fp32_to_bf16(float s) {
 #define WSP_GGML_FP32_TO_BF16(x) wsp_ggml_compute_fp32_to_bf16(x)
 #define WSP_GGML_BF16_TO_FP32(x) wsp_ggml_compute_bf16_to_fp32(x)
 
+static inline int32_t wsp_ggml_node_get_use_count(const struct wsp_ggml_cgraph * cgraph, int node_idx) {
+    const struct wsp_ggml_tensor * node = cgraph->nodes[node_idx];
+
+    size_t hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
+    if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
+        return 0;
+    }
+    return cgraph->use_counts[hash_pos];
+}
+
 // return true if the node's results are only used by N other nodes
 // and can be fused into their calculations.
 static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
     const struct wsp_ggml_tensor * node = cgraph->nodes[node_idx];
 
     // check the use count against how many we're replacing
-    size_t hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
-    if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+    if (wsp_ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
         return false;
     }
 
@@ -635,6 +647,36 @@ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int
     return wsp_ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
 }
 
+WSP_GGML_API bool wsp_ggml_can_fuse_subgraph_ext(const struct wsp_ggml_cgraph * cgraph,
+                                                 const int * node_idxs,
+                                                 int count,
+                                                 const enum wsp_ggml_op * ops,
+                                                 const int * outputs,
+                                                 int num_outputs);
+
+// Returns true if the subgraph formed by {node_idxs} can be fused
+// checks whethers all nodes which are not part of outputs can be elided
+// by checking if their num_uses are confined to the subgraph
+static inline bool wsp_ggml_can_fuse_subgraph(const struct wsp_ggml_cgraph * cgraph,
+                                              int node_idx,
+                                              int count,
+                                              const enum wsp_ggml_op * ops,
+                                              const int * outputs,
+                                              int num_outputs) {
+    WSP_GGML_ASSERT(count < 32);
+    if (node_idx + count > cgraph->n_nodes) {
+        return false;
+    }
+
+    int idxs[32];
+
+    for (int i = 0; i < count; ++i) {
+        idxs[i] = node_idx + i;
+    }
+
+    return wsp_ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
+}
+
 #ifdef __cplusplus
 }
 #endif
@@ -648,6 +690,13 @@ inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_id
     return wsp_ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
 }
 
+inline bool wsp_ggml_can_fuse_subgraph(const struct wsp_ggml_cgraph * cgraph,
+                                       int start_idx,
+                                       std::initializer_list<enum wsp_ggml_op> ops,
+                                       std::initializer_list<int> outputs = {}) {
+    return wsp_ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
+}
+
 // expose GGUF internals for test code
 WSP_GGML_API size_t wsp_gguf_type_size(enum wsp_gguf_type type);
 WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file_impl(FILE * file, struct wsp_gguf_init_params params);
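Note: taken together, the ggml-impl.h additions let a backend ask whether a run of consecutive nodes forms a fusable subgraph: interior nodes must have all of their uses confined to the subgraph (via the new wsp_ggml_node_get_use_count), while nodes listed in `outputs` may escape. A hedged usage sketch of the C++ overload (the op sequence and indices are illustrative, not taken from this diff):

```cpp
#include "ggml-impl.h"  // wsp_ggml_can_fuse_subgraph (C++ convenience overload)

// Illustrative backend-side scan: find a MUL_MAT -> ADD pair whose
// intermediate result never leaves the pair, so one fused kernel could
// replace both nodes. Node i + 1 is declared the only live output.
static bool find_fusable_pair(const struct wsp_ggml_cgraph * cgraph, int * out_idx) {
    for (int i = 0; i + 1 < cgraph->n_nodes; ++i) {
        if (wsp_ggml_can_fuse_subgraph(cgraph, i, { WSP_GGML_OP_MUL_MAT, WSP_GGML_OP_ADD }, { i + 1 })) {
            *out_idx = i;
            return true;
        }
    }
    return false;
}
```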
package/cpp/ggml-metal/ggml-metal-common.cpp
CHANGED
@@ -112,7 +112,7 @@ static bool wsp_ggml_mem_ranges_add_dst(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
 }
 
 bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
-    for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
+    for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
         if (tensor->src[i]) {
             wsp_ggml_mem_ranges_add_src(mrs, tensor->src[i]);
         }
@@ -173,7 +173,7 @@ static bool wsp_ggml_mem_ranges_check_dst(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
 }
 
 bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
-    for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
+    for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
         if (tensor->src[i]) {
             if (!wsp_ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
                 return false;
package/cpp/ggml-metal/ggml-metal-device.cpp
CHANGED
@@ -268,6 +268,25 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
     return res;
 }
 
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    assert(op->op == WSP_GGML_OP_SUM);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_op_sum_%s", wsp_ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
     WSP_GGML_ASSERT(op->src[0]->nb[0] == wsp_ggml_type_size(op->src[0]->type));
 
@@ -338,7 +357,13 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_ssm_conv_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
+    const char * suffix = "";
+
+    if (op->src[1]->ne[0] % 4 == 0) {
+        suffix = "_4";
+    }
+
+    snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type), suffix);
     snprintf(name, 256, "%s", base);
 
     wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
@@ -352,15 +377,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
 }
 
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+
     char base[256];
     char name[256];
 
-    if (op->src[3]->ne[0] == 1) {
-        snprintf(base, 256, "kernel_ssm_scan_group_%s", wsp_ggml_type_name(op->src[0]->type));
-    } else {
-        snprintf(base, 256, "kernel_ssm_scan_%s", wsp_ggml_type_name(op->src[0]->type));
-    }
-    snprintf(name, 256, "%s", base);
+    const int nsg = (ne00 + 31)/32;
+
+    snprintf(base, 256, "kernel_ssm_scan_%s", wsp_ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_nsg=%d", base, nsg);
 
     wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
     if (res) {
@@ -369,7 +394,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
 
     res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 
-    wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
+    wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
 
     return res;
 }
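Note: the SSM_SCAN pipeline is now specialized per simdgroup count: nsg = (ne00 + 31)/32 simdgroups of 32 threads cover the inner dimension of src[0], and the threadgroup-memory reservation grows to match (one float per thread instead of a fixed 32). The arithmetic, spelled out (values illustrative):

```cpp
#include <cstdio>

int main() {
    // mirrors: const int nsg = (ne00 + 31)/32;  smem = 32*sizeof(float)*nsg
    const int    ne00 = 128;                   // illustrative inner dim
    const int    nsg  = (ne00 + 31)/32;        // 4 simdgroups x 32 threads
    const size_t smem = 32*sizeof(float)*nsg;  // 512 bytes of threadgroup memory
    std::printf("nsg=%d smem=%zu\n", nsg, smem);
    return 0;
}
```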
@@ -918,6 +943,96 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
     return res;
 }
 
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
+        wsp_ggml_metal_library_t lib,
+        const struct wsp_ggml_tensor * op,
+        bool has_mask,
+        int32_t ncpsg) {
+    assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
+    WSP_GGML_UNUSED(op);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_%s",
+            "flash_attn_ext_pad");
+
+    snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
+            base,
+            has_mask,
+            ncpsg);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
+
+    wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
+    //wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
+    //wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
+    //wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
+
+    //wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
+    //wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
+    //wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
+    //wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
+    //wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
+    wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    wsp_ggml_metal_cv_free(cv);
+
+    return res;
+}
+
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(
+        wsp_ggml_metal_library_t lib,
+        const struct wsp_ggml_tensor * op,
+        int32_t nqptg,
+        int32_t ncpsg) {
+    assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
+    WSP_GGML_UNUSED(op);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_%s",
+            "flash_attn_ext_blk");
+
+    snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
+            base,
+            nqptg,
+            ncpsg);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
+
+    //wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
+    //wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
+    //wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
+    //wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
+
+    //wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
+    //wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
+    //wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
+    //wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
+    wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
+    wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    wsp_ggml_metal_cv_free(cv);
+
+    return res;
+}
+
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
         wsp_ggml_metal_library_t lib,
         const wsp_ggml_tensor * op,
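Note: every getter above follows the same caching pattern: `base` selects the Metal function to compile, while `name` appends each specialization parameter (mask flag, ncpsg, ...), so every parameter combination yields its own cached pipeline with the function-constant values (`cv`) baked in at compile time. A minimal sketch of that cache shape, assuming a plain map rather than the library's actual data structure:

```cpp
#include <string>
#include <unordered_map>

struct pipeline;  // stand-in for a compiled Metal pipeline state

// Hypothetical cache: one compiled variant per fully specialized name,
// e.g. "kernel_flash_attn_ext_pad_mask=1_ncpsg=32"; on a miss, the shared
// base function is compiled with the constants baked in and memoized.
template <typename Compile>
pipeline * get_pipeline_cached(std::unordered_map<std::string, pipeline *> & cache,
                               const std::string & base, const std::string & name,
                               Compile && compile) {
    auto it = cache.find(name);
    if (it != cache.end()) {
        return it->second;           // reuse the specialized variant
    }
    pipeline * res = compile(base);  // constants distinguish name from base
    cache.emplace(name, res);
    return res;
}
```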
@@ -925,6 +1040,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg) {
     assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
 
@@ -937,18 +1053,23 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
     const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
     const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
 
+    // do bounds checks for the mask?
+    const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
+
     snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
             "flash_attn_ext",
             wsp_ggml_type_name(op->src[1]->type),
             dk,
             dv);
 
-    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
+    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
             base,
             has_mask,
             has_sinks,
             has_bias,
             has_scap,
+            has_kvpad,
+            bc_mask,
             ns10,
             ns20,
             nsg);
@@ -964,6 +1085,9 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
     wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
     wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
     wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
+    wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
+
+    wsp_ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
 
     wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
     wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -983,6 +1107,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg,
         int32_t nwg) {
     assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
@@ -1002,12 +1127,13 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(
             dk,
             dv);
 
-    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
+    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
             base,
             has_mask,
             has_sinks,
             has_bias,
             has_scap,
+            has_kvpad,
             ns10,
             ns20,
             nsg, nwg);
@@ -1023,6 +1149,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(
     wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
     wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
     wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
+    wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
 
     wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
     wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
@@ -1279,6 +1406,31 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
     return res;
 }
 
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    assert(op->op == WSP_GGML_OP_CONV_TRANSPOSE_2D);
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[1]));
+    WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
+    snprintf(name, 256, "%s", base);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
     assert(op->op == WSP_GGML_OP_UPSCALE);
 
@@ -1374,3 +1526,40 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_timestep_embedding(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
     return res;
 }
 
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_adamw(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    assert(op->op == WSP_GGML_OP_OPT_STEP_ADAMW);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_opt_step_adamw_%s", wsp_ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_sgd(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    assert(op->op == WSP_GGML_OP_OPT_STEP_SGD);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_opt_step_sgd_%s", wsp_ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
package/cpp/ggml-metal/ggml-metal-device.h
CHANGED
@@ -109,6 +109,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_set_rows
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_repeat            (wsp_ggml_metal_library_t lib, enum wsp_ggml_type tsrc);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary             (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu               (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum               (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows          (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max          (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv          (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
@@ -129,11 +130,26 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_norm
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope              (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_im2col            (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale           (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad               (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad_reflect_1d    (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_arange            (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_timestep_embedding(wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_adamw    (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_sgd      (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
+        wsp_ggml_metal_library_t lib,
+        const struct wsp_ggml_tensor * op,
+        bool has_mask,
+        int32_t ncpsg);
+
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(
+        wsp_ggml_metal_library_t lib,
+        const struct wsp_ggml_tensor * op,
+        int32_t nqptg,
+        int32_t ncpsg);
 
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
         wsp_ggml_metal_library_t lib,
@@ -142,6 +158,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg);
 
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +168,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg,
         int32_t nwg);
 
package/cpp/ggml-metal/ggml-metal-device.m
CHANGED
@@ -7,6 +7,8 @@
 
 #include <Metal/Metal.h>
 
+#include <stdatomic.h>
+
 #ifndef TARGET_OS_VISION
 #define TARGET_OS_VISION 0
 #endif
@@ -22,6 +24,9 @@
 // overload of MTLGPUFamilyMetal3 (not available in some environments)
 static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
 
+// virtual address for GPU memory allocations
+static atomic_uintptr_t g_addr_device = 0x000000400ULL;
+
 #if !WSP_GGML_METAL_EMBED_LIBRARY
 // Here to assist with NSBundle Path Hack
 @interface WSPGGMLMetalClass : NSObject
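Note: for buffers placed in private GPU memory the host never dereferences `all_data`; it only needs distinct, stable pseudo-addresses so that tensor offsets within a buffer can be computed. Bumping an atomic counter by each buffer's size hands out non-overlapping ranges even when buffers are created from several threads. A standalone C++ analogue of the idea (names ours):

```cpp
#include <atomic>
#include <cstdint>
#include <cstdio>

// Analogue of g_addr_device: a process-wide bump counter handing out
// unique, non-overlapping "virtual address" ranges for private buffers.
static std::atomic<std::uintptr_t> g_addr{0x400};

static void * reserve_range(std::size_t size) {
    // fetch_add is atomic: two threads can never receive overlapping ranges
    return reinterpret_cast<void *>(g_addr.fetch_add(size, std::memory_order_relaxed));
}

int main() {
    void * a = reserve_range(1024);
    void * b = reserve_range(4096);
    std::printf("%p %p\n", a, b);  // b begins exactly 1024 bytes after a
    return 0;
}
```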
@@ -652,6 +657,11 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct wsp_ggml_tensor * op) {
         case WSP_GGML_OP_SCALE:
         case WSP_GGML_OP_CONV_TRANSPOSE_1D:
             return true;
+        case WSP_GGML_OP_CONV_TRANSPOSE_2D:
+            return wsp_ggml_is_contiguous(op->src[0]) && wsp_ggml_is_contiguous(op->src[1]) &&
+                (op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32) &&
+                op->src[1]->type == WSP_GGML_TYPE_F32 &&
+                op->type == WSP_GGML_TYPE_F32;
         case WSP_GGML_OP_CLAMP:
             return op->src[0]->type == WSP_GGML_TYPE_F32;
         case WSP_GGML_OP_SQR:
@@ -660,6 +670,8 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct wsp_ggml_tensor * op) {
         case WSP_GGML_OP_COS:
         case WSP_GGML_OP_LOG:
             return wsp_ggml_is_contiguous(op->src[0]) && op->src[0]->type == WSP_GGML_TYPE_F32;
+        case WSP_GGML_OP_SUM:
+            return has_simdgroup_reduction && wsp_ggml_is_contiguous(op->src[0]);
         case WSP_GGML_OP_SUM_ROWS:
         case WSP_GGML_OP_MEAN:
         case WSP_GGML_OP_SOFT_MAX:
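Note: WSP_GGML_OP_SUM is gated on `has_simdgroup_reduction` because the kernel folds partial sums across a 32-lane simdgroup. A CPU analogue of that tree reduction, for intuition only (the real exchange happens between shader lanes, not array slots):

```cpp
#include <cstdio>

// Sketch of a 32-lane tree reduction: each step halves the active lanes,
// mirroring simd_shuffle_down-style lane exchanges in the Metal shader.
static float simd_sum32(float v[32]) {
    for (int off = 16; off > 0; off >>= 1) {
        for (int lane = 0; lane < off; ++lane) {
            v[lane] += v[lane + off];
        }
    }
    return v[0];  // lane 0 holds the total after log2(32) = 5 steps
}

int main() {
    float v[32];
    for (int i = 0; i < 32; ++i) v[i] = 1.0f;
    std::printf("%f\n", simd_sum32(v));  // 32.000000
    return 0;
}
```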
@@ -696,7 +708,8 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct wsp_ggml_tensor * op) {
             return true;
         case WSP_GGML_OP_FLASH_ATTN_EXT:
             // for new head sizes, add checks here
-            if (op->src[0]->ne[0] != 40 &&
+            if (op->src[0]->ne[0] != 32 &&
+                op->src[0]->ne[0] != 40 &&
                 op->src[0]->ne[0] != 64 &&
                 op->src[0]->ne[0] != 80 &&
                 op->src[0]->ne[0] != 96 &&
@@ -780,9 +793,7 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct wsp_ggml_tensor * op) {
             };
         }
         case WSP_GGML_OP_GET_ROWS:
-        {
-            return op->ne[3] == 1;
-        }
+            return true;
         case WSP_GGML_OP_SET_ROWS:
         {
             if (op->src[0]->type != WSP_GGML_TYPE_F32) {
@@ -804,6 +815,9 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct wsp_ggml_tensor * op) {
                 return false;
             };
         }
+        case WSP_GGML_OP_OPT_STEP_ADAMW:
+        case WSP_GGML_OP_OPT_STEP_SGD:
+            return has_simdgroup_reduction;
         default:
             return false;
     }
@@ -828,7 +842,7 @@ struct wsp_ggml_metal_buffer_wrapper {
 };
 
 struct wsp_ggml_metal_buffer {
-    void * all_data;
+    void * all_data;
     size_t all_size;
 
     // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
@@ -966,14 +980,15 @@ wsp_ggml_metal_buffer_t wsp_ggml_metal_buffer_init(wsp_ggml_metal_device_t dev, size_t size, bool shared) {
     if (shared) {
         res->all_data = wsp_ggml_metal_host_malloc(size_aligned);
         res->is_shared = true;
-        res->owned = true;
     } else {
-        // dummy, non-NULL value - we're not going to use this buffer
-        res->all_data = (void *) 0x000000400ULL;
+        // use virtual address from g_addr_device counter
+        res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
         res->is_shared = false;
     }
     res->all_size = size_aligned;
 
+    res->owned = true;
+
     res->device = wsp_ggml_metal_device_get_obj(dev);
     res->queue = wsp_ggml_metal_device_get_queue(dev);
 
@@ -984,15 +999,13 @@ wsp_ggml_metal_buffer_t wsp_ggml_metal_buffer_init(wsp_ggml_metal_device_t dev, size_t size, bool shared) {
     res->buffers[0].metal = nil;
 
     if (size_aligned > 0) {
-        if (props_dev->use_shared_buffers &&shared) {
+        if (props_dev->use_shared_buffers && shared) {
             res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
                                                                    length:size_aligned
                                                                   options:MTLResourceStorageModeShared
                                                               deallocator:nil];
         } else {
             res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
-
-            res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
         }
     }
 
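Note: with the `gpuAddress` read gone, `all_data` for a private buffer is now the synthetic base handed out by `g_addr_device` rather than a real device address. This works because the backend only ever uses `tensor->data` to derive an offset into the `MTLBuffer`, never as a host pointer; any unique base yields the same offsets. A sketch of that offset math (not the exact library code):

```cpp
#include <cstddef>
#include <cstdint>

// Whether all_data is real host memory (shared buffers) or a synthetic
// counter value (private buffers), a tensor's position inside the buffer
// is simply its distance from the buffer's base address.
static std::size_t tensor_offset(const void * tensor_data, const void * buf_all_data) {
    return (std::size_t) ((std::uintptr_t) tensor_data - (std::uintptr_t) buf_all_data);
}
```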
@@ -1140,7 +1153,7 @@ bool wsp_ggml_metal_buffer_is_shared(wsp_ggml_metal_buffer_t buf) {
 
 void wsp_ggml_metal_buffer_memset_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
     if (buf->is_shared) {
-        memset((char *)tensor->data + offset, value, size);
+        memset((char *) tensor->data + offset, value, size);
         return;
     }
 
@@ -1169,7 +1182,7 @@ void wsp_ggml_metal_buffer_memset_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
 
 void wsp_ggml_metal_buffer_set_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     if (buf->is_shared) {
-        memcpy((char *)tensor->data + offset, data, size);
+        memcpy((char *) tensor->data + offset, data, size);
         return;
     }
 
@@ -1224,7 +1237,7 @@ void wsp_ggml_metal_buffer_set_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
 
 void wsp_ggml_metal_buffer_get_tensor(wsp_ggml_metal_buffer_t buf, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     if (buf->is_shared) {
-        memcpy(data, (const char *)tensor->data + offset, size);
+        memcpy(data, (const char *) tensor->data + offset, size);
         return;
     }
 