@fugood/llama.node 0.3.15 → 0.3.17
This diff shows the published contents of the two package versions as they appear in their respective public registries. It is provided for informational purposes only.
- package/CMakeLists.txt +3 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +5 -0
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +8 -0
- package/src/LlamaCompletionWorker.h +1 -0
- package/src/LlamaContext.cpp +3 -2
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
- package/src/llama.cpp/.github/workflows/build.yml +70 -27
- package/src/llama.cpp/.github/workflows/docker.yml +6 -6
- package/src/llama.cpp/.github/workflows/server.yml +7 -11
- package/src/llama.cpp/CMakeLists.txt +23 -1
- package/src/llama.cpp/common/CMakeLists.txt +6 -3
- package/src/llama.cpp/common/arg.cpp +809 -105
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +1 -1
- package/src/llama.cpp/common/common.cpp +31 -521
- package/src/llama.cpp/common/common.h +17 -36
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +30 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
- package/src/llama.cpp/common/minja/minja.hpp +119 -93
- package/src/llama.cpp/common/sampling.cpp +3 -0
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +0 -9
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
- package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
- package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
- package/src/llama.cpp/examples/llava/clip.h +39 -22
- package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/examples/llava/llava.cpp +64 -52
- package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
- package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
- package/src/llama.cpp/examples/llava/mtmd.h +168 -0
- package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
- package/src/llama.cpp/examples/main/main.cpp +16 -5
- package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
- package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
- package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
- package/src/llama.cpp/examples/run/run.cpp +14 -28
- package/src/llama.cpp/examples/server/httplib.h +313 -247
- package/src/llama.cpp/examples/server/server.cpp +243 -139
- package/src/llama.cpp/examples/server/utils.hpp +51 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +14 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +66 -99
- package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -8
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2413 -228
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1004 -13516
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +127 -33
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +29 -293
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +210 -286
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +692 -126
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +21 -10
- package/src/llama.cpp/ggml/src/ggml.c +141 -245
- package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
- package/src/llama.cpp/include/llama.h +30 -11
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +2 -0
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/src/CMakeLists.txt +3 -2
- package/src/llama.cpp/src/llama-adapter.cpp +37 -1
- package/src/llama.cpp/src/llama-arch.cpp +161 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-chat.cpp +82 -17
- package/src/llama.cpp/src/llama-chat.h +6 -2
- package/src/llama.cpp/src/llama-context.cpp +108 -92
- package/src/llama.cpp/src/llama-context.h +1 -2
- package/src/llama.cpp/src/llama-graph.cpp +189 -119
- package/src/llama.cpp/src/llama-graph.h +26 -6
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
- package/src/llama.cpp/src/llama-kv-cache.h +41 -115
- package/src/llama.cpp/src/llama-memory.h +1 -1
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model.cpp +1544 -291
- package/src/llama.cpp/src/llama-model.h +13 -1
- package/src/llama.cpp/src/llama-quant.cpp +29 -8
- package/src/llama.cpp/src/llama-sampling.cpp +7 -1
- package/src/llama.cpp/src/llama-vocab.cpp +44 -6
- package/src/llama.cpp/src/llama.cpp +1 -1
- package/src/llama.cpp/tests/CMakeLists.txt +43 -30
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +139 -57
- package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
- package/src/llama.cpp/tests/test-chat.cpp +12 -2
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp

@@ -170,7 +170,6 @@ static size_t g_scratch_offset = 0;
 int get_current_device_id();
 
 inline dpct::err0 ggml_sycl_set_device(const int device) try {
-
     int current_device_id;
     SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
 
@@ -242,6 +241,14 @@ struct ggml_sycl_pool_alloc {
         }
     }
 
+    T * realloc(size_t size) {
+        GGML_ASSERT(pool != nullptr);
+        if (ptr)
+            pool->free(ptr, actual_size);
+        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
+        return ptr;
+    }
+
     // size is in number of elements
     T * alloc(size_t size) {
         GGML_ASSERT(pool != nullptr);
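The new `realloc` replaces the free/alloc pair callers previously had to write by hand; note that, unlike C's `realloc`, the old contents are discarded rather than copied, since the buffer is returned to the pool before a larger one is taken. A minimal usage sketch, assuming an existing `ggml_sycl_pool` instance named `pool` and made-up sizes:

```cpp
// Sketch only: `pool` is an assumed ggml_sycl_pool implementation.
ggml_sycl_pool_alloc<float> buf(pool);
buf.alloc(1024);                             // first allocation: 1024 floats
// ... a later op needs more room ...
if (2048 * sizeof(float) > buf.actual_size) {
    buf.realloc(2048);                       // frees the old block back to the pool,
}                                            // then grabs a larger one (contents lost)
```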
@@ -306,7 +313,6 @@ struct ggml_backend_sycl_context {
     int device;
     std::string name;
     optimize_feature opt_feature;
-    bool optimized_graph=false;
 
     queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } };
 
@@ -371,10 +377,29 @@ struct ggml_backend_sycl_context {
     dnnl::stream stream_dnnl() {
         return stream_dnnl(device, 0);
     }
+    dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md,
+                                    const dnnl::engine & eng, const queue_ptr q) {
+        ggml_sycl_pool_alloc<uint8_t> * pool;
+        auto it = scratchpad_map.find(q);
+        if (it == scratchpad_map.end()) {
+            scratchpad_map[q] = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(this->pool());
+            pool = scratchpad_map[q].get();
+        } else {
+            pool = it->second.get();
+        }
+
+        size_t scratchpad_size = scratchpad_md.get_size();
+        if (scratchpad_size > pool->actual_size) {
+            pool->realloc(scratchpad_size);
+        }
+        void * mem_ptr = pool->get();
+        return dnnl::memory(scratchpad_md, eng, mem_ptr);
+    }
 #endif
 
     // pool
     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
+    std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;
 
     std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
 
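`scratchpad_map` keeps one lazily grown scratchpad per SYCL queue, so repeated oneDNN primitive executions reuse a pooled buffer instead of allocating on every call; the `realloc` added above only fires when a primitive asks for more than the cached capacity. A hedged sketch of how a caller might feed it to a primitive (this assumes the primitive was created with `dnnl::scratchpad_mode::user`; `prim`, `prim_desc`, `engine`, `stream`, and the memory objects are placeholders, not code from this diff):

```cpp
// Sketch: wire the cached scratchpad into a oneDNN primitive execution.
dnnl::memory::desc scratch_md = prim_desc.scratchpad_desc();
dnnl::memory scratch = ctx.get_scratchpad_mem(scratch_md, engine, queue);
prim.execute(stream, {
    { DNNL_ARG_SRC,        src_mem },
    { DNNL_ARG_DST,        dst_mem },
    { DNNL_ARG_SCRATCHPAD, scratch },
});
```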
@@ -468,298 +493,9 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
 
 int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
 
-typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                       const ggml_tensor *src1,
-                                       ggml_tensor *dst, const float *src0_dd,
-                                       const float *src1_dd, float *dst_dd,
-                                       const queue_ptr &main_stream);
-
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-                        int ne0, int ne1, int ne2, int ne3,
-                        int ne10, int ne11, int ne12, int ne13,
-                        /*int s0, */ int s1, int s2, int s3,
-                        /*int s00,*/ int s01, int s02, int s03,
-                        /*int s10,*/ int s11, int s12, int s13,
-                        const sycl::nd_item<3> &item_ct1) {
-    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1));
-    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) /
-                   ne3;
-    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) %
-                   ne3;
-
-    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
-    }
-
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    for (int i0 = i0s; i0 < ne0;
-         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
-        const int i10 = i0 % ne10;
-        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
-    }
-}
-
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
-                                int ne0, int ne1, int ne2, int ne3,
-                                int ne10, int ne11, int ne12, int ne13,
-                                /*int s0, */ int s1, int s2, int s3,
-                                /*int s00,*/ int s01, int s02, int s03,
-                                /*int s10,*/ int s11, int s12, int s13,
-                                const sycl::nd_item<3> &item_ct1) {
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    const int i3 = i/(ne2*ne1*ne0);
-    const int i2 = (i/(ne1*ne0)) % ne2;
-    const int i1 = (i/ne0) % ne1;
-    const int i0 = i % ne0;
-
-    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
-    }
-
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    const int i10 = i0 % ne10;
-    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
-}
-
-
-template<float (*bin_op)(const float, const float)>
-struct bin_bcast_sycl {
-    template <typename src0_t, typename src1_t, typename dst_t>
-    void operator()(ggml_backend_sycl_context & ctx,
-                    const struct ggml_tensor *src0,
-                    const struct ggml_tensor *src1, struct ggml_tensor *dst,
-                    const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
-                    queue_ptr stream) {
-
-        GGML_TENSOR_BINARY_OP_LOCALS
-
-        int nr0 = ne10/ne0;
-        int nr1 = ne11/ne1;
-        int nr2 = ne12/ne2;
-        int nr3 = ne13/ne3;
-
-        int nr[4] = { nr0, nr1, nr2, nr3 };
-
-        // collapse dimensions until first broadcast dimension
-        int64_t cne[] = {ne0, ne1, ne2, ne3};
-        int64_t cne0[] = {ne00, ne01, ne02, ne03};
-        int64_t cne1[] = {ne10, ne11, ne12, ne13};
-        size_t cnb[] = {nb0, nb1, nb2, nb3};
-        size_t cnb0[] = {nb00, nb01, nb02, nb03};
-        size_t cnb1[] = {nb10, nb11, nb12, nb13};
-        auto collapse = [](int64_t cne[]) {
-            cne[0] *= cne[1];
-            cne[1] = cne[2];
-            cne[2] = cne[3];
-            cne[3] = 1;
-        };
-
-        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
-            cnb[1] *= cne[1];
-            cnb[2] *= cne[2];
-            cnb[3] *= cne[3];
-        };
-
-        if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
-            for (int i = 0; i < 4; i++) {
-                if (nr[i] != 1) {
-                    break;
-                }
-                if (i > 0) {
-                    collapse_nb(cnb, cne);
-                    collapse_nb(cnb0, cne0);
-                    collapse_nb(cnb1, cne1);
-                    collapse(cne);
-                    collapse(cne0);
-                    collapse(cne1);
-                }
-            }
-        }
-        {
-            int64_t ne0 = cne[0];
-            int64_t ne1 = cne[1];
-            int64_t ne2 = cne[2];
-            int64_t ne3 = cne[3];
-
-            int64_t ne10 = cne1[0];
-            int64_t ne11 = cne1[1];
-            int64_t ne12 = cne1[2];
-            int64_t ne13 = cne1[3];
-
-            size_t nb0 = cnb[0];
-            size_t nb1 = cnb[1];
-            size_t nb2 = cnb[2];
-            size_t nb3 = cnb[3];
-
-            size_t nb00 = cnb0[0];
-            size_t nb01 = cnb0[1];
-            size_t nb02 = cnb0[2];
-            size_t nb03 = cnb0[3];
-
-            size_t nb10 = cnb1[0];
-            size_t nb11 = cnb1[1];
-            size_t nb12 = cnb1[2];
-            size_t nb13 = cnb1[3];
-
-            size_t s0 = nb0 / sizeof(dst_t);
-            size_t s1 = nb1 / sizeof(dst_t);
-            size_t s2 = nb2 / sizeof(dst_t);
-            size_t s3 = nb3 / sizeof(dst_t);
-
-            size_t s10 = nb10 / sizeof(src1_t);
-            size_t s11 = nb11 / sizeof(src1_t);
-            size_t s12 = nb12 / sizeof(src1_t);
-            size_t s13 = nb13 / sizeof(src1_t);
-
-            size_t s00 = nb00 / sizeof(src0_t);
-            size_t s01 = nb01 / sizeof(src0_t);
-            size_t s02 = nb02 / sizeof(src0_t);
-            size_t s03 = nb03 / sizeof(src0_t);
-
-            GGML_UNUSED(s00);
-
-            GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
-
-            GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
-
-            GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
-
-            GGML_ASSERT(s0 == 1);
-            GGML_ASSERT(s10 == 1);
-
-            const int block_size = 128;
-
-            int64_t hne0 = std::max(ne0/2LL, 1LL);
-
-            sycl::range<3> block_dims(1, 1, 1);
-            block_dims[2] = std::min<unsigned int>(hne0, block_size);
-            block_dims[1] = std::min<unsigned int>(
-                ne1, block_size / (unsigned int)block_dims[2]);
-            block_dims[0] = std::min(
-                std::min<unsigned int>(
-                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
-                                   (unsigned int)block_dims[1]),
-                64U);
-
-            sycl::range<3> block_nums(
-                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
-                (ne1 + block_dims[1] - 1) / block_dims[1],
-                (hne0 + block_dims[2] - 1) / block_dims[2]);
-
-            if (block_nums[0] > 65535) {
-                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
-                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
-                {
-                    dpct::has_capability_or_fail(stream->get_device(),
-                                                 {sycl::aspect::fp16});
-
-                    stream->parallel_for(
-                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
-                                              sycl::range<3>(1, 1, block_size),
-                                          sycl::range<3>(1, 1, block_size)),
-                        [=](sycl::nd_item<3> item_ct1) {
-                            k_bin_bcast_unravel<bin_op>(
-                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
-                                ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
-                                s03, s11, s12, s13, item_ct1);
-                        });
-                }
-            } else {
-                /*
-                DPCT1049:16: The work-group size passed to the SYCL kernel may
-                exceed the limit. To get the device limit, query
-                info::device::max_work_group_size. Adjust the work-group size if
-                needed.
-                */
-                dpct::has_capability_or_fail(stream->get_device(),
-                                             {sycl::aspect::fp16});
-
-                stream->parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
-                                            ne2, ne3, ne10, ne11, ne12, ne13,
-                                            s1, s2, s3, s01, s02, s03, s11,
-                                            s12, s13, item_ct1);
-                    });
-            }
-        }
-        GGML_UNUSED(ctx);
-    }
-};
-
-template <class op>
-inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                   const ggml_tensor *src1, ggml_tensor *dst,
-                                   const float *src0_dd, const float *src1_dd,
-                                   float *dst_dd,
-                                   const queue_ptr &main_stream) {
-
-    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
-             (sycl::half *)dst_dd, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
-             main_stream);
-    } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
-        op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
-             main_stream);
-    } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
-        op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
-             main_stream);
-    } else {
-        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
-                ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ABORT("fatal error");
-    }
+constexpr size_t ceil_div(const size_t m, const size_t n) {
+    return (m + n - 1) / n;
 }
 
 bool gpu_has_xmx(sycl::device &dev);
-
-void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                          const ggml_tensor *src1, ggml_tensor *dst,
-                          const ggml_sycl_op_flatten_t op);
 #endif // GGML_SYCL_COMMON_HPP
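The broadcast kernels removed here were not deleted outright; they moved into the new `ggml-sycl/binbcast.cpp`/`binbcast.hpp` pair listed above (+350/+39 lines). What replaces them in this header is the small `ceil_div` helper, the usual rounded-up integer division for sizing ND-range launches. A quick self-contained check of its behavior:

```cpp
#include <cstddef>

constexpr size_t ceil_div(const size_t m, const size_t n) {
    return (m + n - 1) / n;
}

// 1000 elements in blocks of 128 need 8 blocks; exact multiples do not over-allocate.
static_assert(ceil_div(1000, 128) == 8);
static_assert(ceil_div(1024, 128) == 8);
static_assert(ceil_div(1025, 128) == 9);
```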
package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp

@@ -16,9 +16,18 @@
 #include <sycl/sycl.hpp>
 #include <sycl/half_type.hpp>
 #include <syclcompat/math.hpp>
-#include <oneapi/mkl.hpp>
 #include <map>
 
+#ifdef GGML_SYCL_USE_INTEL_ONEMKL
+#include <oneapi/mkl.hpp>
+// Allow to use the same namespace for Intel oneMKL and oneMath
+namespace oneapi {
+    namespace math = mkl;
+}
+#else
+#include <oneapi/math.hpp>
+#endif
+
 #include "ggml.h"
 
 #if defined(__linux__)
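The `namespace math = mkl;` alias is the whole trick here: with `GGML_SYCL_USE_INTEL_ONEMKL` defined, every later `oneapi::math::...` reference resolves to Intel oneMKL, and otherwise to the open oneMath library, so the rest of the file needs no `#ifdef`s. A toy illustration of the same C++ mechanism (these names are invented, not the real libraries):

```cpp
namespace vendor_blas {                 // stands in for oneapi::mkl
    inline int gemm() { return 42; }
}

namespace oneapi_like {                 // stands in for oneapi
    namespace math = vendor_blas;       // alias: same symbols under the unified name
}

int main() {
    return oneapi_like::math::gemm();   // calls vendor_blas::gemm()
}
```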
@@ -83,13 +92,32 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
 }
 
 template <typename Ts> struct matrix_info_t {
-    oneapi::mkl::transpose transpose_info[2];
+    oneapi::math::transpose transpose_info[2];
     Ts value_info[2];
     std::int64_t size_info[3];
    std::int64_t ld_info[3];
     std::int64_t groupsize_info;
 };
 
+inline auto get_onemath_backend(sycl::queue& queue)
+#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
+    -> sycl::queue&
+#endif
+{
+    // If the backend is known at compile-time, use oneMath backend_selector to use
+    // compile-time dispatching and avoid the need to dlopen libraries. Otherwise
+    // fallback to runtime dispatching.
+#if defined(GGML_SYCL_NVIDIA)
+    return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
+#elif defined(GGML_SYCL_AMD)
+    return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
+#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
+    return queue;
+#else
+    static_assert(false, "Unsupported backend");
+#endif
+}
+
 namespace dpct
 {
     typedef sycl::queue *queue_ptr;
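`get_onemath_backend` returns a different type per build: a oneMath `backend_selector` by value on the NVIDIA/AMD paths (so dispatch is resolved at compile time and no backend library has to be dlopen'd), but a `sycl::queue&` on the generic/Intel paths, pinned by the conditional trailing return type because bare `auto` would deduce a non-reference and copy the queue. The same shape in miniature (toy types; `USE_RUNTIME_DISPATCH` is an invented macro):

```cpp
struct selector { int backend_id; };    // stands in for oneapi::math::backend_selector

inline auto pick_backend(int & runtime_queue)
#if defined(USE_RUNTIME_DISPATCH)
    -> int &                            // force a reference; bare auto would copy
#endif
{
#if defined(USE_RUNTIME_DISPATCH)
    return runtime_queue;               // runtime dispatch: hand back the queue itself
#else
    return selector{ 1 };               // compile-time dispatch: selector by value
#endif
}
```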
@@ -1686,26 +1714,18 @@ namespace dpct
 
     namespace detail
     {
-        template <class Ta, class Tb, class Tc, class Ts>
-        inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                              oneapi::mkl::transpose b_trans, int m, int n, int k,
-                              const void *alpha, const void *a, int lda, const void *b,
-                              int ldb, const void *beta, void *c, int ldc)
-        {
-            Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
-            Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
-            auto data_a = get_memory<const Ta>(a);
-            auto data_b = get_memory<const Tb>(b);
-            auto data_c = get_memory<Tc>(c);
-#if defined(GGML_SYCL_NVIDIA)
-            oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
-                                                  a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
-                                                  beta_value, data_c, ldc);
-#else
-            oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
-                                                  beta_value, data_c, ldc);
-#endif
-        }
+        template <class Ta, class Tb, class Tc, class Ts>
+        inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
+                              int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
+                              const void * beta, void * c, int ldc) {
+            Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
+            Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
+            auto data_a = get_memory<const Ta>(a);
+            auto data_b = get_memory<const Tb>(b);
+            auto data_c = get_memory<Tc>(c);
+            oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
+                                                   lda, data_b, ldb, beta_value, data_c, ldc);
+        }
 
         template <typename VecT, class BinaryOperation, class = void>
         class vectorized_binary
@@ -1735,7 +1755,7 @@ namespace dpct
         };
 
         template <class Ta, class Tb, class Tc, class Ts>
-        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
+        inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
                                     int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
                                     int ldb, const void * beta, void ** c, int ldc, int batch_size,
                                     matrix_info_t<float> * matrix_info) {
@@ -1754,48 +1774,28 @@ namespace dpct
            matrix_info->ld_info[2] = ldc;
            matrix_info->groupsize_info = batch_size;
 
-#if defined(GGML_SYCL_NVIDIA)
-            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
-                oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
-                matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
-                matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
-                reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
-                matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
-                reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
-#else
-            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
-                q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
-                matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
-                reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
-                matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
-                reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
-#endif
+            sycl::event e = oneapi::math::blas::column_major::gemm_batch(
+                get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
+                matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
+                reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
+                reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
+                reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
+                matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
         }
 
         template <class Ta, class Tb, class Tc, class Ts>
-        inline void
-        gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                        oneapi::mkl::transpose b_trans, int m, int n, int k,
-                        const void *alpha, const void *a, int lda,
-                        long long int stride_a, const void *b, int ldb,
-                        long long int stride_b, const void *beta, void *c,
-                        int ldc, long long int stride_c, int batch_size)
-        {
+        inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
+                                    int m, int n, int k, const void * alpha, const void * a, int lda,
+                                    long long int stride_a, const void * b, int ldb, long long int stride_b,
+                                    const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
             Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
             Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
             auto data_a = get_memory<const Ta>(a);
             auto data_b = get_memory<const Tb>(b);
             auto data_c = get_memory<Tc>(c);
-#if defined(GGML_SYCL_NVIDIA)
-            oneapi::mkl::blas::column_major::gemm_batch(
-                oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
-                alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
-                batch_size);
-#else
-            oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
-                                                        stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
-                                                        stride_c, batch_size);
-#endif
+            oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
+                                                         data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
+                                                         data_c, ldc, stride_c, batch_size);
         }
 
     } // namespace detail
@@ -2259,13 +2259,10 @@ namespace dpct
                           sycl::range<3>(x, y, 1), direction);
     }
 
-    inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                     oneapi::mkl::transpose b_trans, int m, int n, int k,
-                     const void *alpha, const void *a, library_data_t a_type,
-                     int lda, const void *b, library_data_t b_type, int ldb,
-                     const void *beta, void *c, library_data_t c_type, int ldc,
-                     library_data_t scaling_type)
-    {
+    inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
+                     int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
+                     library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
+                     library_data_t scaling_type) {
         if (scaling_type == library_data_t::real_float &&
             c_type == library_data_t::complex_float)
         {
@@ -2329,9 +2326,8 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
-                              float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b,
-                                            ldb, beta, c, ldc);
+            detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
             break;
         }
         case detail::get_type_combination_id(
@@ -2369,8 +2365,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
-                              oneapi::mkl::bfloat16, float>(
+            detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
             break;
         }
@@ -2390,7 +2385,7 @@ namespace dpct
         default:
             throw std::runtime_error("the combination of data type is unsupported");
         }
-    }
+    } // gemm()
 
     /// Computes a batch of matrix-matrix product with general matrices.
     /// \param [in] q The queue where the routine should be executed.
@@ -2412,7 +2407,7 @@ namespace dpct
     /// \param [in] ldc Leading dimension of C.
     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
     /// \param [in] scaling_type Data type of the scaling factors.
-    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
+    inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
                            int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
                            const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
                            library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
@@ -2450,7 +2445,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
+            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
@@ -2458,7 +2453,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
+            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
@@ -2534,15 +2529,11 @@ namespace dpct
     /// \param [in] stride_c Stride between the different C matrices.
     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
     /// \param [in] scaling_type Data type of the scaling factors.
-    inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                           oneapi::mkl::transpose b_trans, int m, int n, int k,
-                           const void *alpha, const void *a, library_data_t a_type,
-                           int lda, long long int stride_a, const void *b,
-                           library_data_t b_type, int ldb, long long int stride_b,
-                           const void *beta, void *c, library_data_t c_type,
-                           int ldc, long long int stride_c, int batch_size,
-                           library_data_t scaling_type)
-    {
+    inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
+                           int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
+                           long long int stride_a, const void * b, library_data_t b_type, int ldb,
+                           long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
+                           long long int stride_c, int batch_size, library_data_t scaling_type) {
         if (scaling_type == library_data_t::real_float &&
             c_type == library_data_t::complex_float)
         {
@@ -2611,20 +2602,18 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
-                                    oneapi::mkl::bfloat16, float>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
-                beta, c, ldc, stride_c, batch_size);
+            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
+                batch_size);
             break;
         }
         case detail::get_type_combination_id(
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
-                                    float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
-                                                  stride_a, b, ldb, stride_b, beta, c, ldc,
-                                                  stride_c, batch_size);
+            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
+                batch_size);
             break;
         }
 #endif