@fugood/llama.node 0.3.9 → 0.3.11
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.js +2 -2
- package/lib/binding.ts +47 -8
- package/lib/index.js +21 -1
- package/lib/index.ts +31 -1
- package/package.json +12 -3
- package/src/LlamaCompletionWorker.cpp +33 -6
- package/src/LlamaCompletionWorker.h +3 -1
- package/src/LlamaContext.cpp +336 -28
- package/src/LlamaContext.h +2 -0
- package/src/common.hpp +19 -2
- package/src/llama.cpp/.github/workflows/build.yml +289 -107
- package/src/llama.cpp/.github/workflows/close-issue.yml +1 -1
- package/src/llama.cpp/.github/workflows/docker.yml +2 -1
- package/src/llama.cpp/.github/workflows/server.yml +25 -2
- package/src/llama.cpp/CMakeLists.txt +10 -19
- package/src/llama.cpp/cmake/build-info.cmake +1 -1
- package/src/llama.cpp/common/CMakeLists.txt +32 -0
- package/src/llama.cpp/common/arg.cpp +66 -16
- package/src/llama.cpp/common/chat-template.hpp +515 -0
- package/src/llama.cpp/common/chat.cpp +966 -0
- package/src/llama.cpp/common/chat.hpp +52 -0
- package/src/llama.cpp/common/common.cpp +159 -36
- package/src/llama.cpp/common/common.h +56 -14
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +46 -66
- package/src/llama.cpp/common/json-schema-to-grammar.h +15 -1
- package/src/llama.cpp/common/llguidance.cpp +270 -0
- package/src/llama.cpp/common/log.cpp +1 -10
- package/src/llama.cpp/common/log.h +10 -0
- package/src/llama.cpp/common/minja.hpp +2868 -0
- package/src/llama.cpp/common/sampling.cpp +22 -1
- package/src/llama.cpp/common/sampling.h +3 -0
- package/src/llama.cpp/docs/build.md +54 -9
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +12 -2
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +1 -1
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +59 -0
- package/src/llama.cpp/examples/llava/clip.cpp +133 -14
- package/src/llama.cpp/examples/llava/clip.h +2 -0
- package/src/llama.cpp/examples/llava/llava.cpp +22 -8
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +9 -1
- package/src/llama.cpp/examples/main/main.cpp +26 -25
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +136 -137
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +18 -4
- package/src/llama.cpp/examples/run/run.cpp +224 -69
- package/src/llama.cpp/examples/server/server.cpp +252 -81
- package/src/llama.cpp/examples/server/utils.hpp +73 -21
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +6 -4
- package/src/llama.cpp/examples/simple-cmake-pkg/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +78 -1
- package/src/llama.cpp/ggml/include/ggml.h +1 -1
- package/src/llama.cpp/ggml/src/CMakeLists.txt +21 -4
- package/src/llama.cpp/ggml/src/ggml-alloc.c +1 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +91 -78
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +7 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +46 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +16 -1
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +28 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +5 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +33 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +1 -5
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +323 -121
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
- package/src/llama.cpp/ggml/src/ggml.c +23 -13
- package/src/llama.cpp/include/llama.h +14 -1
- package/src/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +46 -0
- package/src/llama.cpp/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/src/llama-arch.cpp +7 -2
- package/src/llama.cpp/src/llama-arch.h +3 -1
- package/src/llama.cpp/src/llama-chat.cpp +11 -2
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +86 -6
- package/src/llama.cpp/src/llama-grammar.h +22 -1
- package/src/llama.cpp/src/llama-mmap.cpp +1 -0
- package/src/llama.cpp/src/llama-model-loader.cpp +1 -1
- package/src/llama.cpp/src/llama-model.cpp +76 -6
- package/src/llama.cpp/src/llama-sampling.cpp +47 -4
- package/src/llama.cpp/src/llama-vocab.cpp +10 -4
- package/src/llama.cpp/src/llama.cpp +181 -123
- package/src/llama.cpp/tests/CMakeLists.txt +4 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +158 -57
- package/src/llama.cpp/tests/test-chat-template.cpp +154 -31
- package/src/llama.cpp/tests/test-chat.cpp +607 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +2 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +1140 -0
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +0 -32
package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h:

```diff
@@ -1,5 +1,6 @@
 #pragma once
 
+#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
 #include <hip/hip_runtime.h>
 #include <hipblas/hipblas.h>
 #include <hip/hip_fp16.h>
@@ -8,6 +9,7 @@
 // for rocblas_initialize()
 #include "rocblas/rocblas.h"
 #endif // __HIP_PLATFORM_AMD__
+
 #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
 #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
@@ -19,6 +21,13 @@
 #define CUBLAS_TF32_TENSOR_OP_MATH 0
 #define CUDA_R_16F HIPBLAS_R_16F
 #define CUDA_R_32F HIPBLAS_R_32F
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
+#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
+#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
+#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
+#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
 #define cublasCreate hipblasCreate
@@ -74,6 +83,21 @@
 #define cudaMemGetInfo hipMemGetInfo
 #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
 #define cudaSetDevice hipSetDevice
+#define cuDeviceGet hipDeviceGet
+#define CUdevice hipDevice_t
+#define CUdeviceptr hipDeviceptr_t
+#define cuMemUnmap hipMemUnmap
+#define CUmemAccessDesc hipMemAccessDesc
+#define cuMemAddressFree hipMemAddressFree
+#define cuMemRelease hipMemRelease
+#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
+#define cuMemCreate hipMemCreate
+#define cuMemAddressReserve hipMemAddressReserve
+#define cuMemMap hipMemMap
+#define cuMemSetAccess hipMemSetAccess
+#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
+#define CUmemAllocationProp hipMemAllocationProp
+#define cuDeviceGetAttribute hipDeviceGetAttribute
 #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
 #define cudaStreamDestroy hipStreamDestroy
 #define cudaStreamFireAndForget hipStreamFireAndForget
```
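The new `CU_*`/`cuMem*` aliases let ggml's CUDA virtual-memory-management (VMM) pool compile unchanged against HIP. A minimal sketch of that allocation pattern, written against the aliased names with `vendors/hip.h` in effect; the function name, rounding, and overall shape are illustrative, not code taken from ggml:

```cpp
// Hypothetical illustration: reserve a virtual address range and back it with
// physical device memory, using the CUDA driver-API names aliased to HIP above.
static void * vmm_alloc_sketch(int device, size_t size) {
    CUmemAllocationProp prop = {};
    prop.type          = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id   = device;

    size_t granularity = 0;
    CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
    size = ((size + granularity - 1) / granularity) * granularity;   // round up to the granularity

    CUdeviceptr ptr = 0;
    CU_CHECK(cuMemAddressReserve(&ptr, size, 0, 0, 0));              // reserve a virtual range

    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemCreate(&handle, size, &prop, 0));                  // allocate physical backing
    CU_CHECK(cuMemMap(ptr, size, 0, handle, 0));                     // map it into the range

    CUmemAccessDesc access = {};
    access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access.location.id   = device;
    access.flags         = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CU_CHECK(cuMemSetAccess(ptr, size, &access, 1));                 // grant read/write access

    return reinterpret_cast<void *>(ptr);
}
```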
package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h (continued):

```diff
@@ -81,6 +105,28 @@
 #define cudaStreamPerThread hipStreamPerThread
 #define cudaStreamSynchronize hipStreamSynchronize
 #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
+#define cudaGraphExec_t hipGraphExec_t
+#define cudaGraphNode_t hipGraphNode_t
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaGraphExecDestroy hipGraphExecDestroy
+#define cudaGraphLaunch hipGraphLaunch
+#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
+#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphNodeType hipGraphNodeType
+#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
+#define cudaGraphInstantiate hipGraphInstantiate
+#define cudaStreamEndCapture hipStreamEndCapture
+#define cudaGraphDestroy hipGraphDestroy
+#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
+#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
+#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cudaGraphNodeGetType hipGraphNodeGetType
+#define cudaGraphGetNodes hipGraphGetNodes
+#define cudaGraphExecUpdate hipGraphExecUpdate
+#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture hipStreamBeginCapture
+#define cudaGraph_t hipGraph_t
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
 #define __trap() do { abort(); __builtin_unreachable(); } while(0)
```
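The `cudaGraph*` and stream-capture aliases map ggml's CUDA-graph execution path onto hipGraph. For orientation, a minimal capture-and-replay sequence using only the aliased names; this is a sketch, not ggml's actual graph code, and error handling is omitted:

```cpp
// Hypothetical sketch: capture the work submitted to `stream` into a graph and replay it.
static void capture_and_replay(cudaStream_t stream) {
    cudaGraph_t     graph    = nullptr;
    cudaGraphExec_t instance = nullptr;

    cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed); // start recording work on the stream
    // ... enqueue the kernels that make up one iteration of work ...
    cudaStreamEndCapture(stream, &graph);                         // finish recording into `graph`

    cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);  // build an executable graph
    cudaGraphLaunch(instance, stream);                            // replay the captured sequence

    cudaGraphExecDestroy(instance);
    cudaGraphDestroy(graph);
}
```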
package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt:

```diff
@@ -40,13 +40,20 @@ find_package(hip REQUIRED)
 find_package(hipblas REQUIRED)
 find_package(rocblas REQUIRED)
 
+if (${hip_VERSION} VERSION_LESS 5.5)
+    message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
+endif()
+
 message(STATUS "HIP and hipBLAS found")
 
+# Workaround old compilers
+set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024")
+
 file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
 list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
 
 file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
-file(GLOB SRCS "../ggml-cuda/template-instances/fattn-
+file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
 file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
@@ -92,6 +99,14 @@ if (GGML_CUDA_NO_PEER_COPY)
     add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
 endif()
 
+if (GGML_HIP_GRAPHS)
+    add_compile_definitions(GGML_HIP_GRAPHS)
+endif()
+
+if (GGML_HIP_NO_VMM)
+    add_compile_definitions(GGML_HIP_NO_VMM)
+endif()
+
 if (CXX_IS_HIPCC)
     set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
     target_link_libraries(ggml-hip PRIVATE hip::device)
```
package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt:

```diff
@@ -29,7 +29,7 @@ if (MUSAToolkit_FOUND)
     list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
 
     file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
-    file(GLOB SRCS "../ggml-cuda/template-instances/fattn-
+    file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
     file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
```
package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp:

```diff
@@ -181,7 +181,7 @@ struct ggml_backend_rpc_context {
 
 struct ggml_backend_rpc_buffer_context {
     std::shared_ptr<socket_t> sock;
-
+    void * base_ptr;
     uint64_t remote_ptr;
 };
 
@@ -423,16 +423,15 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
 
 static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    if (ctx->
-        return ctx->
+    if (ctx->base_ptr != nullptr) {
+        return ctx->base_ptr;
     }
     rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
     rpc_msg_buffer_get_base_rsp response;
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
     GGML_ASSERT(status);
-
-    ctx->
-    return base_ptr;
+    ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
+    return ctx->base_ptr;
 }
 
 static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
@@ -557,7 +556,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
     if (response.remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock,
+            new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
            response.remote_size);
         return buffer;
     } else {
@@ -1046,7 +1045,28 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
         ggml_free(ctx);
         return false;
     }
-
+
+    uint64_t src_size   = (uint64_t) ggml_nbytes(src);
+    uint64_t dst_data   = (uint64_t) dst->data;
+    uint64_t dst_base   = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);
+    uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);
+
+    if (dst_data + src_size > dst_base + dst_buf_sz) {
+        GGML_PRINT_DEBUG("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
+                         "    write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
+                         "    buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
+                         __func__,
+                         dst_data,
+                         dst_data + src_size,
+                         dst_base,
+                         dst_base + dst_buf_sz);
+        ggml_free(ctx);
+        return false;
+    }
+
+    GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n",
+                     __func__, (void*) src->buffer, (void*) dst->buffer);
+
     response.result = ggml_backend_buffer_copy_tensor(src, dst);
     ggml_free(ctx);
     return true;
```
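The new guard in `rpc_server::copy_tensor` rejects copies whose destination range would run past the end of the destination buffer, answering the request with `false` instead of forwarding it to `ggml_backend_buffer_copy_tensor`. The same bound test, restated as a standalone helper for clarity (the helper name is hypothetical; the real check is inline above):

```cpp
#include <cstdint>

// True when writing `size` bytes starting at `dst_data` stays within the
// destination buffer that spans [dst_base, dst_base + dst_buf_sz).
static bool copy_fits_in_dst(uint64_t dst_data, uint64_t size,
                             uint64_t dst_base, uint64_t dst_buf_sz) {
    return dst_data + size <= dst_base + dst_buf_sz;
}
```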
package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp:

```diff
@@ -3878,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
 }
 
-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
-}
-
 static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
@@ -4090,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
             ggml_sycl_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-
+            ggml_sycl_op_soft_max(ctx, dst);
             break;
         case GGML_OP_ROPE:
             ggml_sycl_rope(ctx, dst);
@@ -4541,14 +4537,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
-        case GGML_OP_NORM:
         case GGML_OP_ADD:
         case GGML_OP_ADD1:
         case GGML_OP_LOG:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+            return true;
+        case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_GROUP_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
@@ -4580,7 +4579,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
-        case GGML_OP_GROUP_NORM:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_LEAKY_RELU:
```
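The reshuffled `supports_op` cases gate NORM, RMS_NORM, and GROUP_NORM on the input being contiguous, where NORM was previously accepted unconditionally and GROUP_NORM was listed in a later group. Because the fall-through structure is easy to misread in diff form, here is the new behaviour of just those three cases restated as a standalone sketch (the helper is hypothetical; in ggml-sycl.cpp this logic lives inside the larger switch):

```cpp
#include "ggml.h"   // ggml_tensor, ggml_op, ggml_is_contiguous

// Restates the new supports_op rule for the three norm ops: the SYCL backend
// only claims them when the first source tensor is contiguous.
static bool sycl_supports_norm_op(const ggml_tensor * op) {
    switch (op->op) {
        case GGML_OP_NORM:
        case GGML_OP_RMS_NORM:
        case GGML_OP_GROUP_NORM:
            return ggml_is_contiguous(op->src[0]);
        default:
            return false; // other ops are handled elsewhere in supports_op
    }
}
```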
package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp:

```diff
@@ -1,7 +1,7 @@
-#include "
+#include "softmax.hpp"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
                          const int nrows_y, const float scale, const float max_bias, const float m0,
                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         slope = sycl::pow(base, float(exp));
     }
 
-    float *vals = vals_smem ? buf +
+    float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
     float max_val = -INFINITY;
 
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
 
         vals[col] = val;
         max_val = sycl::max(max_val, val);
@@ -65,7 +65,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         item_ct1.barrier(sycl::access::fence_space::local_space);
         max_val = buf[lane_id];
         for (size_t i = 1; i < nreduce; i += 1) {
-            max_val =
+            max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
         }
         max_val = warp_reduce_max(max_val, item_ct1);
     }
@@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
     }
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, queue_ptr stream) {
@@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
         });
     }
 }
-
+template<typename T>
+static void soft_max_f32_sycl(const float * x, const T * mask,
                               float * dst, const int ncols_x, const int nrows_x,
                               const int nrows_y, const float scale, const float max_bias,
                               queue_ptr stream, int device) {
@@ -223,22 +224,16 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
     }
 }
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream) {
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
-    GGML_ASSERT(
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-
-#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00 =
-    const int64_t nrows_x = ggml_nrows(
-    const int64_t nrows_y =
+    const int64_t ne00 = dst->src[0]->ne[0];
+    const int64_t nrows_x = ggml_nrows(dst->src[0]);
+    const int64_t nrows_y = dst->src[0]->ne[1];
 
     float scale = 1.0f;
     float max_bias = 0.0f;
@@ -246,6 +241,21 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s
     memcpy(&scale, dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-
-
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
+
+    ggml_sycl_set_device(ctx.device);
+    dpct::queue_ptr main_stream = ctx.stream();
+
+    if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
+        const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
+        soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
+                                      main_stream, ctx.device);
+    } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
+        const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
+        soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    } else {
+        /* mask unavailable */
+        soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    }
 }
```
package/src/llama.cpp/ggml/src/ggml-sycl/softmax.hpp:

```diff
@@ -15,10 +15,6 @@
 
 #include "common.hpp"
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream);
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
 
 #endif // GGML_SYCL_SOFTMAX_HPP
```