whispercpp 1.3.0 → 1.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (132) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
@@ -0,0 +1,95 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #include "common.hpp"
14
+ #include "ggml-impl.h"
15
+
16
+ int get_current_device_id() {
17
+ return dpct::dev_mgr::instance().current_device_id();
18
+ }
19
+
20
+ void* ggml_sycl_host_malloc(size_t size) try {
21
+ if (getenv("GGML_SYCL_NO_PINNED") != nullptr) {
22
+ return nullptr;
23
+ }
24
+
25
+ void* ptr = nullptr;
26
+ // allow to use dpct::get_in_order_queue() for host malloc
27
+ dpct::err0 err = CHECK_TRY_ERROR(
28
+ ptr = (void*)sycl::malloc_host(size, dpct::get_in_order_queue()));
29
+
30
+ if (err != 0) {
31
+ // clear the error
32
+ GGML_LOG_ERROR("WARNING: failed to allocate %.2f MB of pinned memory: %s\n", size / 1024.0 / 1024.0, "syclGetErrorString is not supported");
33
+ return nullptr;
34
+ }
35
+
36
+ return ptr;
37
+ } catch (sycl::exception const& exc) {
38
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
39
+ << ", line:" << __LINE__ << std::endl;
40
+ std::exit(1);
41
+ }
42
+
43
+ void ggml_sycl_host_free(void* ptr) try {
44
+ // allow to use dpct::get_in_order_queue() for host malloc
45
+ SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue())));
46
+ } catch (sycl::exception const& exc) {
47
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
48
+ << ", line:" << __LINE__ << std::endl;
49
+ std::exit(1);
50
+ }
51
+
52
// Shrink the per-block size (by halving) until the total global range
// accumulate_block_num * block_size fits in a signed 32-bit int, which is
// the largest global range some SYCL backends accept. Returns the adjusted
// block size (may reach 0 if even a single-element block is too large).
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
    constexpr int64_t max_range = std::numeric_limits<int>::max();
    int64_t blk = block_size;
    for (;;) {
        const int64_t global_range = accumulate_block_num * blk;
        if (global_range <= max_range) {
            break;
        }
        blk /= 2;
    }
    return blk;
}
62
+
63
+ void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
64
+ const ggml_tensor *src1, ggml_tensor *dst,
65
+ const ggml_sycl_op_flatten_t op) try {
66
+
67
+ const bool use_src1 = src1 != nullptr;
68
+
69
+ GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
70
+ GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
71
+
72
+ // dd = data device
73
+ float * src0_ddf = (float *) src0->data;
74
+ float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
75
+ float * dst_ddf = (float *) dst->data;
76
+
77
+ ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
78
+ ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
79
+ ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
80
+
81
+ ggml_sycl_set_device(ctx.device);
82
+ queue_ptr main_stream = ctx.stream();
83
+ // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
84
+ // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
85
+
86
+ // do the computation
87
+ op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
88
+ // print_ggml_tensor("tensor", dst);
89
+ }
90
+ catch (sycl::exception const &exc) {
91
+
92
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
93
+ << ", line:" << __LINE__ << std::endl;
94
+ std::exit(1);
95
+ }
@@ -0,0 +1,196 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #include "concat.hpp"
14
+ #include "common.hpp"
15
+
16
+ static void concat_f32_dim0(const float *x, const float *y, float *dst,
17
+ const int ne0, const int ne00,
18
+ const sycl::nd_item<3> &item_ct1) {
19
+ int nidx = item_ct1.get_local_id(2) +
20
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
21
+ if (nidx >= ne0) {
22
+ return;
23
+ }
24
+ // operation
25
+ int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
26
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
27
+ if (nidx < ne00) { // src0
28
+ int offset_src = nidx + item_ct1.get_group(1) * ne00 +
29
+ item_ct1.get_group(0) * ne00 * item_ct1.get_group_range(1);
30
+ dst[offset_dst] = x[offset_src];
31
+ } else {
32
+ int offset_src =
33
+ nidx - ne00 + item_ct1.get_group(1) * (ne0 - ne00) +
34
+ item_ct1.get_group(0) * (ne0 - ne00) * item_ct1.get_group_range(1);
35
+ dst[offset_dst] = y[offset_src];
36
+ }
37
+ }
38
+
39
+ static void concat_f32_dim1(const float *x, const float *y, float *dst,
40
+ const int ne0, const int ne01,
41
+ const sycl::nd_item<3> &item_ct1) {
42
+ int nidx = item_ct1.get_local_id(2) +
43
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
44
+ if (nidx >= ne0) {
45
+ return;
46
+ }
47
+ // operation
48
+ int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
49
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
50
+ if (item_ct1.get_group(1) < (size_t) ne01) { // src0
51
+ int offset_src =
52
+ nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * ne01;
53
+ dst[offset_dst] = x[offset_src];
54
+ } else {
55
+ int offset_src =
56
+ nidx + (item_ct1.get_group(1) - ne01) * ne0 +
57
+ item_ct1.get_group(0) * ne0 * (item_ct1.get_group_range(1) - ne01);
58
+ dst[offset_dst] = y[offset_src];
59
+ }
60
+ }
61
+
62
+ static void concat_f32_dim2(const float *x, const float *y, float *dst,
63
+ const int ne0, const int ne02,
64
+ const sycl::nd_item<3> &item_ct1) {
65
+ int nidx = item_ct1.get_local_id(2) +
66
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
67
+ if (nidx >= ne0) {
68
+ return;
69
+ }
70
+ // operation
71
+ int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
72
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
73
+ if (item_ct1.get_group(0) < (size_t) ne02) { // src0
74
+ int offset_src = nidx + item_ct1.get_group(1) * ne0 +
75
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
76
+ dst[offset_dst] = x[offset_src];
77
+ } else {
78
+ int offset_src =
79
+ nidx + item_ct1.get_group(1) * ne0 +
80
+ (item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1);
81
+ dst[offset_dst] = y[offset_src];
82
+ }
83
+ }
84
+
85
+ static void concat_f32_sycl(const float *x, const float *y, float *dst,
86
+ int ne00, int ne01, int ne02, int ne0, int ne1,
87
+ int ne2, int dim, queue_ptr stream) {
88
+ int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
89
+ sycl::range<3> gridDim(ne2, ne1, num_blocks);
90
+ switch (dim) {
91
+ case 0:
92
+ stream->parallel_for(
93
+ sycl::nd_range<3>(gridDim *
94
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
95
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
96
+ [=](sycl::nd_item<3> item_ct1) {
97
+ concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
98
+ });
99
+ break;
100
+ case 1:
101
+ stream->parallel_for(
102
+ sycl::nd_range<3>(gridDim *
103
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
104
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
105
+ [=](sycl::nd_item<3> item_ct1) {
106
+ concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
107
+ });
108
+ break;
109
+ // dim >=2 will be dispatched to the default path
110
+ default:
111
+ stream->parallel_for(
112
+ sycl::nd_range<3>(gridDim *
113
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
114
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
115
+ [=](sycl::nd_item<3> item_ct1) {
116
+ concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
117
+ });
118
+ break;
119
+ }
120
+ }
121
+
122
+ // non-contiguous kernel (slow)
123
+ static void concat_f32_sycl_non_cont(
124
+ queue_ptr stream, const char *src0, const char *src1, char *dst,
125
+ int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00,
126
+ uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /*ne10*/,
127
+ int64_t /*ne11*/, int64_t /*ne12*/, int64_t /*ne13*/, uint64_t nb10,
128
+ uint64_t nb11, uint64_t nb12, uint64_t nb13, int64_t ne0, int64_t ne1,
129
+ int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
130
+ uint64_t nb3, int32_t dim) {
131
+ sycl::range<3> gridDim(ne3, ne2, ne1);
132
+ stream->parallel_for(
133
+ sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
134
+ [=](sycl::nd_item<3> item_ct1) {
135
+ int64_t i3 = item_ct1.get_group(0);
136
+ int64_t i2 = item_ct1.get_group(1);
137
+ int64_t i1 = item_ct1.get_group(2);
138
+
139
+ int64_t o[4] = {0, 0, 0, 0};
140
+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
141
+
142
+ const float *x;
143
+
144
+ for (int i0 = item_ct1.get_local_id(2); i0 < ne0;
145
+ i0 += item_ct1.get_local_range(2)) {
146
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
147
+ x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
148
+ (i0)*nb00);
149
+ } else {
150
+ x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 +
151
+ (i1 - o[1]) * nb11 + (i0 - o[0]) * nb10);
152
+ }
153
+
154
+ float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
155
+
156
+ *y = *x;
157
+ }
158
+ });
159
+ }
160
+
161
+ void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
162
+ const ggml_tensor *src1, ggml_tensor *dst) {
163
+ queue_ptr stream = ctx.stream();
164
+
165
+ const int32_t dim = ((int32_t *)dst->op_params)[0];
166
+
167
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
168
+ const float *src0_d = (const float *)src0->data;
169
+ const float *src1_d = (const float *)src1->data;
170
+
171
+ float *dst_d = (float *)dst->data;
172
+
173
+ if (dim != 3) {
174
+ for (int i3 = 0; i3 < dst->ne[3]; i3++) {
175
+ concat_f32_sycl(
176
+ src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4),
177
+ dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1],
178
+ src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
179
+ }
180
+ } else {
181
+ const size_t size0 = ggml_nbytes(src0);
182
+ const size_t size1 = ggml_nbytes(src1);
183
+
184
+ SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait()));
185
+ SYCL_CHECK(CHECK_TRY_ERROR(
186
+ stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait()));
187
+ }
188
+ } else
189
+ concat_f32_sycl_non_cont(
190
+ stream, (const char *)src0->data, (const char *)src1->data,
191
+ (char *)dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
192
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
193
+ src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1],
194
+ src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2],
195
+ dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
196
+ }
@@ -0,0 +1,99 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #include "conv.hpp"
14
+
15
+ static void conv_transpose_1d_kernel(
16
+ const int s0, const int output_size,
17
+ const int src0_ne0, const int src0_ne1, const int src0_ne2,
18
+ const int src1_ne0, const int dst_ne0,
19
+ const float * src0, const float * src1, float * dst,
20
+ const sycl::nd_item<3> &item_ct1) {
21
+ int global_index = item_ct1.get_local_id(2) +
22
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
23
+ if (global_index >= output_size) {
24
+ return;
25
+ }
26
+
27
+ int out_index = global_index / dst_ne0;
28
+
29
+ float accumulator = 0;
30
+
31
+ for (int c = 0; c < src0_ne2; c++) {
32
+ int idx = global_index % dst_ne0;
33
+
34
+ int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
35
+ int input_offset = src1_ne0 * c;
36
+
37
+ for (int i = 0; i < src1_ne0; i++) {
38
+ if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
39
+ continue;
40
+ }
41
+ int weight_idx = idx - i*s0;
42
+
43
+ float kernel_weight = src0[kernel_offset + weight_idx];
44
+ float input_value = src1[input_offset+i];
45
+
46
+ accumulator += kernel_weight * input_value;
47
+ }
48
+ }
49
+ dst[global_index] = accumulator;
50
+ }
51
+
52
+ static void conv_transpose_1d_f32_f32_sycl(
53
+ const int s0, const int output_size,
54
+ const int src0_ne0, const int src0_ne1, const int src0_ne2,
55
+ const int src1_ne0, const int dst_ne0,
56
+ const float *src0, const float *src1, float *dst,
57
+ const queue_ptr& stream) {
58
+
59
+ const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
60
+ const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
61
+ const sycl::range<3> block_nums(1, 1, num_blocks);
62
+ stream->parallel_for(
63
+ sycl::nd_range<3>(
64
+ block_nums * block_dims, block_dims),
65
+ [=](sycl::nd_item<3> item_ct1) {
66
+ conv_transpose_1d_kernel(
67
+ s0, output_size,
68
+ src0_ne0, src0_ne1, src0_ne2,
69
+ src1_ne0, dst_ne0,
70
+ src0, src1, dst, item_ct1);
71
+ });
72
+ }
73
+
74
+ void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
75
+ const ggml_tensor *src1, ggml_tensor *dst) {
76
+ const float * src0_d = (const float *)src0->data;
77
+ const float * src1_d = (const float *)src1->data;
78
+
79
+ float * dst_d = (float *)dst->data;
80
+ dpct::queue_ptr stream = ctx.stream();
81
+
82
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
83
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
84
+
85
+ GGML_ASSERT(ggml_is_contiguous(src0));
86
+ GGML_ASSERT(ggml_is_contiguous(src1));
87
+
88
+ const int32_t * opts = (const int32_t *)dst->op_params;
89
+
90
+ const int s0 = opts[0];
91
+
92
+ const int64_t output_size = ggml_nelements(dst);
93
+
94
+ conv_transpose_1d_f32_f32_sycl(s0, output_size,
95
+ src0->ne[0], src0->ne[1], src0->ne[2],
96
+ src1->ne[0], dst->ne[0],
97
+ src0_d, src1_d, dst_d, stream);
98
+ }
99
+