whispercpp 1.3.0 → 1.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +60 -11
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -16
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/{whisper.h → include/whisper.h} +23 -22
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1492 -9
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -21755
@@ -0,0 +1,95 @@
|
|
1
|
+
//
|
2
|
+
// MIT license
|
3
|
+
// Copyright (C) 2024 Intel Corporation
|
4
|
+
// SPDX-License-Identifier: MIT
|
5
|
+
//
|
6
|
+
|
7
|
+
//
|
8
|
+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
9
|
+
// See https://llvm.org/LICENSE.txt for license information.
|
10
|
+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
11
|
+
//
|
12
|
+
|
13
|
+
#include "common.hpp"
|
14
|
+
#include "ggml-impl.h"
|
15
|
+
|
16
|
+
int get_current_device_id() {
|
17
|
+
return dpct::dev_mgr::instance().current_device_id();
|
18
|
+
}
|
19
|
+
|
20
|
+
void* ggml_sycl_host_malloc(size_t size) try {
|
21
|
+
if (getenv("GGML_SYCL_NO_PINNED") != nullptr) {
|
22
|
+
return nullptr;
|
23
|
+
}
|
24
|
+
|
25
|
+
void* ptr = nullptr;
|
26
|
+
// allow to use dpct::get_in_order_queue() for host malloc
|
27
|
+
dpct::err0 err = CHECK_TRY_ERROR(
|
28
|
+
ptr = (void*)sycl::malloc_host(size, dpct::get_in_order_queue()));
|
29
|
+
|
30
|
+
if (err != 0) {
|
31
|
+
// clear the error
|
32
|
+
GGML_LOG_ERROR("WARNING: failed to allocate %.2f MB of pinned memory: %s\n", size / 1024.0 / 1024.0, "syclGetErrorString is not supported");
|
33
|
+
return nullptr;
|
34
|
+
}
|
35
|
+
|
36
|
+
return ptr;
|
37
|
+
} catch (sycl::exception const& exc) {
|
38
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
39
|
+
<< ", line:" << __LINE__ << std::endl;
|
40
|
+
std::exit(1);
|
41
|
+
}
|
42
|
+
|
43
|
+
// Releases `ptr` with sycl::free, using dpct's in-order queue (the same queue
// ggml_sycl_host_malloc allocates with). A failing SYCL_CHECK or an escaping
// sycl::exception is fatal; the latter is printed to stderr before exit(1).
void ggml_sycl_host_free(void* ptr) try {
    // the in-order queue identifies the context the allocation belongs to
    SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue())));
} catch (sycl::exception const& exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
              << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}
|
51
|
+
|
52
|
+
// Halves `block_size` until `accumulate_block_num * block_size` no longer
// exceeds std::numeric_limits<int>::max(), and returns the reduced size.
//
// accumulate_block_num: number of blocks the size will be multiplied by.
// block_size:           starting block size.
//
// The result never drops below 1 for a positive `block_size`: the previous
// loop kept halving down to 0 when `accumulate_block_num` alone exceeded the
// limit, yielding an unusable zero size. The comparison is done by division
// so the product cannot overflow int64_t. A `block_size` <= 1 is returned
// unchanged.
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
    const int64_t max_range = std::numeric_limits<int>::max();
    int64_t sycl_down_blk_size = block_size;
    // For positive operands, a * b > max  <=>  a > max / b (integer division).
    while (sycl_down_blk_size > 1 &&
           accumulate_block_num > max_range / sycl_down_blk_size) {
        sycl_down_blk_size /= 2;
    }
    return sycl_down_blk_size;
}
|
62
|
+
|
63
|
+
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
64
|
+
const ggml_tensor *src1, ggml_tensor *dst,
|
65
|
+
const ggml_sycl_op_flatten_t op) try {
|
66
|
+
|
67
|
+
const bool use_src1 = src1 != nullptr;
|
68
|
+
|
69
|
+
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
|
70
|
+
GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
|
71
|
+
|
72
|
+
// dd = data device
|
73
|
+
float * src0_ddf = (float *) src0->data;
|
74
|
+
float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
|
75
|
+
float * dst_ddf = (float *) dst->data;
|
76
|
+
|
77
|
+
ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
|
78
|
+
ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
|
79
|
+
ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
|
80
|
+
|
81
|
+
ggml_sycl_set_device(ctx.device);
|
82
|
+
queue_ptr main_stream = ctx.stream();
|
83
|
+
// GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
|
84
|
+
// ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
|
85
|
+
|
86
|
+
// do the computation
|
87
|
+
op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
|
88
|
+
// print_ggml_tensor("tensor", dst);
|
89
|
+
}
|
90
|
+
catch (sycl::exception const &exc) {
|
91
|
+
|
92
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
93
|
+
<< ", line:" << __LINE__ << std::endl;
|
94
|
+
std::exit(1);
|
95
|
+
}
|
@@ -0,0 +1,196 @@
|
|
1
|
+
//
|
2
|
+
// MIT license
|
3
|
+
// Copyright (C) 2024 Intel Corporation
|
4
|
+
// SPDX-License-Identifier: MIT
|
5
|
+
//
|
6
|
+
|
7
|
+
//
|
8
|
+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
9
|
+
// See https://llvm.org/LICENSE.txt for license information.
|
10
|
+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
11
|
+
//
|
12
|
+
|
13
|
+
#include "concat.hpp"
|
14
|
+
#include "common.hpp"
|
15
|
+
|
16
|
+
static void concat_f32_dim0(const float *x, const float *y, float *dst,
|
17
|
+
const int ne0, const int ne00,
|
18
|
+
const sycl::nd_item<3> &item_ct1) {
|
19
|
+
int nidx = item_ct1.get_local_id(2) +
|
20
|
+
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
21
|
+
if (nidx >= ne0) {
|
22
|
+
return;
|
23
|
+
}
|
24
|
+
// operation
|
25
|
+
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
26
|
+
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
27
|
+
if (nidx < ne00) { // src0
|
28
|
+
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
|
29
|
+
item_ct1.get_group(0) * ne00 * item_ct1.get_group_range(1);
|
30
|
+
dst[offset_dst] = x[offset_src];
|
31
|
+
} else {
|
32
|
+
int offset_src =
|
33
|
+
nidx - ne00 + item_ct1.get_group(1) * (ne0 - ne00) +
|
34
|
+
item_ct1.get_group(0) * (ne0 - ne00) * item_ct1.get_group_range(1);
|
35
|
+
dst[offset_dst] = y[offset_src];
|
36
|
+
}
|
37
|
+
}
|
38
|
+
|
39
|
+
static void concat_f32_dim1(const float *x, const float *y, float *dst,
|
40
|
+
const int ne0, const int ne01,
|
41
|
+
const sycl::nd_item<3> &item_ct1) {
|
42
|
+
int nidx = item_ct1.get_local_id(2) +
|
43
|
+
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
44
|
+
if (nidx >= ne0) {
|
45
|
+
return;
|
46
|
+
}
|
47
|
+
// operation
|
48
|
+
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
49
|
+
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
50
|
+
if (item_ct1.get_group(1) < (size_t) ne01) { // src0
|
51
|
+
int offset_src =
|
52
|
+
nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * ne01;
|
53
|
+
dst[offset_dst] = x[offset_src];
|
54
|
+
} else {
|
55
|
+
int offset_src =
|
56
|
+
nidx + (item_ct1.get_group(1) - ne01) * ne0 +
|
57
|
+
item_ct1.get_group(0) * ne0 * (item_ct1.get_group_range(1) - ne01);
|
58
|
+
dst[offset_dst] = y[offset_src];
|
59
|
+
}
|
60
|
+
}
|
61
|
+
|
62
|
+
static void concat_f32_dim2(const float *x, const float *y, float *dst,
|
63
|
+
const int ne0, const int ne02,
|
64
|
+
const sycl::nd_item<3> &item_ct1) {
|
65
|
+
int nidx = item_ct1.get_local_id(2) +
|
66
|
+
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
67
|
+
if (nidx >= ne0) {
|
68
|
+
return;
|
69
|
+
}
|
70
|
+
// operation
|
71
|
+
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
72
|
+
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
73
|
+
if (item_ct1.get_group(0) < (size_t) ne02) { // src0
|
74
|
+
int offset_src = nidx + item_ct1.get_group(1) * ne0 +
|
75
|
+
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
76
|
+
dst[offset_dst] = x[offset_src];
|
77
|
+
} else {
|
78
|
+
int offset_src =
|
79
|
+
nidx + item_ct1.get_group(1) * ne0 +
|
80
|
+
(item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1);
|
81
|
+
dst[offset_dst] = y[offset_src];
|
82
|
+
}
|
83
|
+
}
|
84
|
+
|
85
|
+
static void concat_f32_sycl(const float *x, const float *y, float *dst,
|
86
|
+
int ne00, int ne01, int ne02, int ne0, int ne1,
|
87
|
+
int ne2, int dim, queue_ptr stream) {
|
88
|
+
int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
|
89
|
+
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
90
|
+
switch (dim) {
|
91
|
+
case 0:
|
92
|
+
stream->parallel_for(
|
93
|
+
sycl::nd_range<3>(gridDim *
|
94
|
+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
95
|
+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
96
|
+
[=](sycl::nd_item<3> item_ct1) {
|
97
|
+
concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
|
98
|
+
});
|
99
|
+
break;
|
100
|
+
case 1:
|
101
|
+
stream->parallel_for(
|
102
|
+
sycl::nd_range<3>(gridDim *
|
103
|
+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
104
|
+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
105
|
+
[=](sycl::nd_item<3> item_ct1) {
|
106
|
+
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
|
107
|
+
});
|
108
|
+
break;
|
109
|
+
// dim >=2 will be dispatched to the default path
|
110
|
+
default:
|
111
|
+
stream->parallel_for(
|
112
|
+
sycl::nd_range<3>(gridDim *
|
113
|
+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
114
|
+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
115
|
+
[=](sycl::nd_item<3> item_ct1) {
|
116
|
+
concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
|
117
|
+
});
|
118
|
+
break;
|
119
|
+
}
|
120
|
+
}
|
121
|
+
|
122
|
+
// non-contiguous kernel (slow)
|
123
|
+
static void concat_f32_sycl_non_cont(
|
124
|
+
queue_ptr stream, const char *src0, const char *src1, char *dst,
|
125
|
+
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00,
|
126
|
+
uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /*ne10*/,
|
127
|
+
int64_t /*ne11*/, int64_t /*ne12*/, int64_t /*ne13*/, uint64_t nb10,
|
128
|
+
uint64_t nb11, uint64_t nb12, uint64_t nb13, int64_t ne0, int64_t ne1,
|
129
|
+
int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
|
130
|
+
uint64_t nb3, int32_t dim) {
|
131
|
+
sycl::range<3> gridDim(ne3, ne2, ne1);
|
132
|
+
stream->parallel_for(
|
133
|
+
sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
|
134
|
+
[=](sycl::nd_item<3> item_ct1) {
|
135
|
+
int64_t i3 = item_ct1.get_group(0);
|
136
|
+
int64_t i2 = item_ct1.get_group(1);
|
137
|
+
int64_t i1 = item_ct1.get_group(2);
|
138
|
+
|
139
|
+
int64_t o[4] = {0, 0, 0, 0};
|
140
|
+
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
|
141
|
+
|
142
|
+
const float *x;
|
143
|
+
|
144
|
+
for (int i0 = item_ct1.get_local_id(2); i0 < ne0;
|
145
|
+
i0 += item_ct1.get_local_range(2)) {
|
146
|
+
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
147
|
+
x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
|
148
|
+
(i0)*nb00);
|
149
|
+
} else {
|
150
|
+
x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 +
|
151
|
+
(i1 - o[1]) * nb11 + (i0 - o[0]) * nb10);
|
152
|
+
}
|
153
|
+
|
154
|
+
float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
|
155
|
+
|
156
|
+
*y = *x;
|
157
|
+
}
|
158
|
+
});
|
159
|
+
}
|
160
|
+
|
161
|
+
// Concatenates src0 and src1 into dst along the dimension stored in
// dst->op_params[0], treating all tensor data as float.
//
// - Either input non-contiguous: the strided kernel handles every dimension.
// - Both contiguous, dim != 3: one 3-D kernel launch per index of dst's
//   outermost dimension.
// - Both contiguous, dim == 3: two blocking memcpy calls place src0 and then
//   src1 back to back in dst.
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                         const ggml_tensor *src1, ggml_tensor *dst) {
    queue_ptr stream = ctx.stream();

    const int32_t dim = ((int32_t *)dst->op_params)[0];

    if (!(ggml_is_contiguous(src0) && ggml_is_contiguous(src1))) {
        concat_f32_sycl_non_cont(
            stream, (const char *)src0->data, (const char *)src1->data,
            (char *)dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
            src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
            src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1],
            src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2],
            dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
        return;
    }

    const float *src0_d = (const float *)src0->data;
    const float *src1_d = (const float *)src1->data;
    float       *dst_d  = (float *)dst->data;

    if (dim == 3) {
        const size_t size0 = ggml_nbytes(src0);
        const size_t size1 = ggml_nbytes(src1);

        SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait()));
        SYCL_CHECK(CHECK_TRY_ERROR(
            stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait()));
        return;
    }

    // nb[3] is a byte stride; dividing by the element size gives a float offset
    const size_t f32_size = sizeof(float);
    for (int i3 = 0; i3 < dst->ne[3]; i3++) {
        concat_f32_sycl(
            src0_d + i3 * (src0->nb[3] / f32_size),
            src1_d + i3 * (src1->nb[3] / f32_size),
            dst_d + i3 * (dst->nb[3] / f32_size),
            src0->ne[0], src0->ne[1], src0->ne[2],
            dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
    }
}
|
@@ -0,0 +1,99 @@
|
|
1
|
+
//
|
2
|
+
// MIT license
|
3
|
+
// Copyright (C) 2024 Intel Corporation
|
4
|
+
// SPDX-License-Identifier: MIT
|
5
|
+
//
|
6
|
+
|
7
|
+
//
|
8
|
+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
9
|
+
// See https://llvm.org/LICENSE.txt for license information.
|
10
|
+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
11
|
+
//
|
12
|
+
|
13
|
+
#include "conv.hpp"
|
14
|
+
|
15
|
+
static void conv_transpose_1d_kernel(
|
16
|
+
const int s0, const int output_size,
|
17
|
+
const int src0_ne0, const int src0_ne1, const int src0_ne2,
|
18
|
+
const int src1_ne0, const int dst_ne0,
|
19
|
+
const float * src0, const float * src1, float * dst,
|
20
|
+
const sycl::nd_item<3> &item_ct1) {
|
21
|
+
int global_index = item_ct1.get_local_id(2) +
|
22
|
+
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
23
|
+
if (global_index >= output_size) {
|
24
|
+
return;
|
25
|
+
}
|
26
|
+
|
27
|
+
int out_index = global_index / dst_ne0;
|
28
|
+
|
29
|
+
float accumulator = 0;
|
30
|
+
|
31
|
+
for (int c = 0; c < src0_ne2; c++) {
|
32
|
+
int idx = global_index % dst_ne0;
|
33
|
+
|
34
|
+
int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
|
35
|
+
int input_offset = src1_ne0 * c;
|
36
|
+
|
37
|
+
for (int i = 0; i < src1_ne0; i++) {
|
38
|
+
if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
|
39
|
+
continue;
|
40
|
+
}
|
41
|
+
int weight_idx = idx - i*s0;
|
42
|
+
|
43
|
+
float kernel_weight = src0[kernel_offset + weight_idx];
|
44
|
+
float input_value = src1[input_offset+i];
|
45
|
+
|
46
|
+
accumulator += kernel_weight * input_value;
|
47
|
+
}
|
48
|
+
}
|
49
|
+
dst[global_index] = accumulator;
|
50
|
+
}
|
51
|
+
|
52
|
+
static void conv_transpose_1d_f32_f32_sycl(
|
53
|
+
const int s0, const int output_size,
|
54
|
+
const int src0_ne0, const int src0_ne1, const int src0_ne2,
|
55
|
+
const int src1_ne0, const int dst_ne0,
|
56
|
+
const float *src0, const float *src1, float *dst,
|
57
|
+
const queue_ptr& stream) {
|
58
|
+
|
59
|
+
const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
|
60
|
+
const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
|
61
|
+
const sycl::range<3> block_nums(1, 1, num_blocks);
|
62
|
+
stream->parallel_for(
|
63
|
+
sycl::nd_range<3>(
|
64
|
+
block_nums * block_dims, block_dims),
|
65
|
+
[=](sycl::nd_item<3> item_ct1) {
|
66
|
+
conv_transpose_1d_kernel(
|
67
|
+
s0, output_size,
|
68
|
+
src0_ne0, src0_ne1, src0_ne2,
|
69
|
+
src1_ne0, dst_ne0,
|
70
|
+
src0, src1, dst, item_ct1);
|
71
|
+
});
|
72
|
+
}
|
73
|
+
|
74
|
+
// Runs the 1-D transposed convolution dst = conv_transpose_1d(src0, src1)
// on the context's stream.
//
// src0, src1 and dst must all be GGML_TYPE_F32, and src0/src1 must be
// contiguous. The stride s0 is read from dst->op_params[0].
//
// All checks run before any tensor data is touched. src1 is type-checked as
// well because its data is read as float, and the element count is checked
// to fit in `int` because the launcher and kernel index with `int`.
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                    const ggml_tensor *src1, ggml_tensor *dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    const int64_t output_size = ggml_nelements(dst);
    // reject sizes that would be truncated by the int parameter below
    GGML_ASSERT(output_size == (int64_t)(int) output_size);

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;

    float * dst_d = (float *)dst->data;
    dpct::queue_ptr stream = ctx.stream();

    const int32_t * opts = (const int32_t *)dst->op_params;

    const int s0 = opts[0];

    conv_transpose_1d_f32_f32_sycl(s0, (int) output_size,
        src0->ne[0], src0->ne[1], src0->ne[2],
        src1->ne[0], dst->ne[0],
        src0_d, src1_d, dst_d, stream);
}
|
99
|
+
|