@fugood/llama.node 0.3.9 → 0.3.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.js +2 -2
  18. package/lib/binding.ts +47 -8
  19. package/lib/index.js +21 -1
  20. package/lib/index.ts +31 -1
  21. package/package.json +12 -3
  22. package/src/LlamaCompletionWorker.cpp +33 -6
  23. package/src/LlamaCompletionWorker.h +3 -1
  24. package/src/LlamaContext.cpp +336 -28
  25. package/src/LlamaContext.h +2 -0
  26. package/src/common.hpp +19 -2
  27. package/src/llama.cpp/.github/workflows/build.yml +289 -107
  28. package/src/llama.cpp/.github/workflows/close-issue.yml +1 -1
  29. package/src/llama.cpp/.github/workflows/docker.yml +2 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +25 -2
  31. package/src/llama.cpp/CMakeLists.txt +10 -19
  32. package/src/llama.cpp/cmake/build-info.cmake +1 -1
  33. package/src/llama.cpp/common/CMakeLists.txt +32 -0
  34. package/src/llama.cpp/common/arg.cpp +66 -16
  35. package/src/llama.cpp/common/chat-template.hpp +515 -0
  36. package/src/llama.cpp/common/chat.cpp +966 -0
  37. package/src/llama.cpp/common/chat.hpp +52 -0
  38. package/src/llama.cpp/common/common.cpp +159 -36
  39. package/src/llama.cpp/common/common.h +56 -14
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +46 -66
  41. package/src/llama.cpp/common/json-schema-to-grammar.h +15 -1
  42. package/src/llama.cpp/common/llguidance.cpp +270 -0
  43. package/src/llama.cpp/common/log.cpp +1 -10
  44. package/src/llama.cpp/common/log.h +10 -0
  45. package/src/llama.cpp/common/minja.hpp +2868 -0
  46. package/src/llama.cpp/common/sampling.cpp +22 -1
  47. package/src/llama.cpp/common/sampling.h +3 -0
  48. package/src/llama.cpp/docs/build.md +54 -9
  49. package/src/llama.cpp/examples/export-lora/export-lora.cpp +12 -2
  50. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +1 -1
  51. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  52. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +59 -0
  53. package/src/llama.cpp/examples/llava/clip.cpp +133 -14
  54. package/src/llama.cpp/examples/llava/clip.h +2 -0
  55. package/src/llama.cpp/examples/llava/llava.cpp +22 -8
  56. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +9 -1
  57. package/src/llama.cpp/examples/main/main.cpp +26 -25
  58. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +136 -137
  59. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +18 -4
  60. package/src/llama.cpp/examples/run/run.cpp +224 -69
  61. package/src/llama.cpp/examples/server/server.cpp +252 -81
  62. package/src/llama.cpp/examples/server/utils.hpp +73 -21
  63. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +6 -4
  64. package/src/llama.cpp/examples/simple-cmake-pkg/CMakeLists.txt +11 -0
  65. package/src/llama.cpp/ggml/CMakeLists.txt +78 -1
  66. package/src/llama.cpp/ggml/include/ggml.h +1 -1
  67. package/src/llama.cpp/ggml/src/CMakeLists.txt +21 -4
  68. package/src/llama.cpp/ggml/src/ggml-alloc.c +1 -13
  69. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +91 -78
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +7 -7
  71. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -1
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  73. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +46 -0
  74. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +16 -1
  75. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +28 -8
  77. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +5 -7
  78. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +33 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +1 -5
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +323 -121
  81. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
  82. package/src/llama.cpp/ggml/src/ggml.c +23 -13
  83. package/src/llama.cpp/include/llama.h +14 -1
  84. package/src/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +112 -0
  85. package/src/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +46 -0
  86. package/src/llama.cpp/src/CMakeLists.txt +1 -1
  87. package/src/llama.cpp/src/llama-arch.cpp +7 -2
  88. package/src/llama.cpp/src/llama-arch.h +3 -1
  89. package/src/llama.cpp/src/llama-chat.cpp +11 -2
  90. package/src/llama.cpp/src/llama-chat.h +1 -0
  91. package/src/llama.cpp/src/llama-grammar.cpp +86 -6
  92. package/src/llama.cpp/src/llama-grammar.h +22 -1
  93. package/src/llama.cpp/src/llama-mmap.cpp +1 -0
  94. package/src/llama.cpp/src/llama-model-loader.cpp +1 -1
  95. package/src/llama.cpp/src/llama-model.cpp +76 -6
  96. package/src/llama.cpp/src/llama-sampling.cpp +47 -4
  97. package/src/llama.cpp/src/llama-vocab.cpp +10 -4
  98. package/src/llama.cpp/src/llama.cpp +181 -123
  99. package/src/llama.cpp/tests/CMakeLists.txt +4 -0
  100. package/src/llama.cpp/tests/test-backend-ops.cpp +158 -57
  101. package/src/llama.cpp/tests/test-chat-template.cpp +154 -31
  102. package/src/llama.cpp/tests/test-chat.cpp +607 -0
  103. package/src/llama.cpp/tests/test-grammar-integration.cpp +2 -2
  104. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +1140 -0
  105. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +1 -1
  106. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +0 -32
package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
 #include <hip/hip_runtime.h>
 #include <hipblas/hipblas.h>
 #include <hip/hip_fp16.h>
@@ -8,6 +9,7 @@
 // for rocblas_initialize()
 #include "rocblas/rocblas.h"
 #endif // __HIP_PLATFORM_AMD__
+
 #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
 #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
@@ -19,6 +21,13 @@
 #define CUBLAS_TF32_TENSOR_OP_MATH 0
 #define CUDA_R_16F HIPBLAS_R_16F
 #define CUDA_R_32F HIPBLAS_R_32F
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
+#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
+#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
+#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
+#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
 #define cublasCreate hipblasCreate
@@ -74,6 +83,21 @@
 #define cudaMemGetInfo hipMemGetInfo
 #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
 #define cudaSetDevice hipSetDevice
+#define cuDeviceGet hipDeviceGet
+#define CUdevice hipDevice_t
+#define CUdeviceptr hipDeviceptr_t
+#define cuMemUnmap hipMemUnmap
+#define CUmemAccessDesc hipMemAccessDesc
+#define cuMemAddressFree hipMemAddressFree
+#define cuMemRelease hipMemRelease
+#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
+#define cuMemCreate hipMemCreate
+#define cuMemAddressReserve hipMemAddressReserve
+#define cuMemMap hipMemMap
+#define cuMemSetAccess hipMemSetAccess
+#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
+#define CUmemAllocationProp hipMemAllocationProp
+#define cuDeviceGetAttribute hipDeviceGetAttribute
 #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
 #define cudaStreamDestroy hipStreamDestroy
 #define cudaStreamFireAndForget hipStreamFireAndForget
@@ -81,6 +105,28 @@
 #define cudaStreamPerThread hipStreamPerThread
 #define cudaStreamSynchronize hipStreamSynchronize
 #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
+#define cudaGraphExec_t hipGraphExec_t
+#define cudaGraphNode_t hipGraphNode_t
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaGraphExecDestroy hipGraphExecDestroy
+#define cudaGraphLaunch hipGraphLaunch
+#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
+#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphNodeType hipGraphNodeType
+#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
+#define cudaGraphInstantiate hipGraphInstantiate
+#define cudaStreamEndCapture hipStreamEndCapture
+#define cudaGraphDestroy hipGraphDestroy
+#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
+#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
+#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cudaGraphNodeGetType hipGraphNodeGetType
+#define cudaGraphGetNodes hipGraphGetNodes
+#define cudaGraphExecUpdate hipGraphExecUpdate
+#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture hipStreamBeginCapture
+#define cudaGraph_t hipGraph_t
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
 #define __trap() do { abort(); __builtin_unreachable(); } while(0)
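
The block above extends the CUDA-to-HIP aliasing table with virtual memory management (cuMem*/CU_MEM_*) and CUDA graph (cudaGraph*) names, which lets the shared CUDA backend sources build unchanged under HIP. Below is a minimal sketch of that aliasing pattern, assuming a HIP toolchain; it is an illustrative toy program, not code from this package:

// Illustrative only: CUDA-style host code compiles for HIP once macros remap the
// names, which is the mechanism vendors/hip.h relies on.
#include <hip/hip_runtime.h>

#define cudaStream_t              hipStream_t
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamNonBlocking     hipStreamNonBlocking
#define cudaStreamSynchronize     hipStreamSynchronize
#define cudaStreamDestroy         hipStreamDestroy

int main() {
    cudaStream_t stream;                                        // expands to hipStream_t
    cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);  // expands to hipStreamCreateWithFlags(...)
    cudaStreamSynchronize(stream);                              // hipStreamSynchronize
    cudaStreamDestroy(stream);                                  // hipStreamDestroy
    return 0;
}
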
package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt
@@ -40,13 +40,20 @@ find_package(hip REQUIRED)
 find_package(hipblas REQUIRED)
 find_package(rocblas REQUIRED)
 
+if (${hip_VERSION} VERSION_LESS 5.5)
+    message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
+endif()
+
 message(STATUS "HIP and hipBLAS found")
 
+# Workaround old compilers
+set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024")
+
 file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
 list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
 
 file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
-file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
+file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
 file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
@@ -92,6 +99,14 @@ if (GGML_CUDA_NO_PEER_COPY)
    add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
 endif()
 
+if (GGML_HIP_GRAPHS)
+    add_compile_definitions(GGML_HIP_GRAPHS)
+endif()
+
+if (GGML_HIP_NO_VMM)
+    add_compile_definitions(GGML_HIP_NO_VMM)
+endif()
+
 if (CXX_IS_HIPCC)
    set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
    target_link_libraries(ggml-hip PRIVATE hip::device)
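
GGML_HIP_GRAPHS and GGML_HIP_NO_VMM are exposed as ordinary compile definitions, so the C++ sources can gate the HIP graph and VMM code paths with the preprocessor. The sketch below shows the generic pattern such a definition enables; it is illustrative and not the package's actual gating code:

// Illustrative only: a CMake-level add_compile_definitions(GGML_HIP_GRAPHS)
// becomes a plain preprocessor switch in the C++ sources.
#include <cstdio>

static void run_graph_capture_path() { std::puts("HIP graphs path");   }
static void run_plain_stream_path()  { std::puts("plain stream path"); }

int main() {
#ifdef GGML_HIP_GRAPHS
    run_graph_capture_path();  // built with -DGGML_HIP_GRAPHS (e.g. via the new CMake option)
#else
    run_plain_stream_path();   // default when the option is off
#endif
    return 0;
}
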
package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt
@@ -29,7 +29,7 @@ if (MUSAToolkit_FOUND)
     list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
 
     file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
-    file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
+    file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
     file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -181,7 +181,7 @@ struct ggml_backend_rpc_context {
 
 struct ggml_backend_rpc_buffer_context {
     std::shared_ptr<socket_t> sock;
-    std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
+    void * base_ptr;
     uint64_t remote_ptr;
 };
 
@@ -423,16 +423,15 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
 
 static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
-        return ctx->base_cache[buffer];
+    if (ctx->base_ptr != nullptr) {
+        return ctx->base_ptr;
     }
     rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
     rpc_msg_buffer_get_base_rsp response;
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
     GGML_ASSERT(status);
-    void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
-    ctx->base_cache[buffer] = base_ptr;
-    return base_ptr;
+    ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
+    return ctx->base_ptr;
 }
 
 static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
@@ -557,7 +556,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
     if (response.remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
+            new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
             response.remote_size);
         return buffer;
     } else {
@@ -1046,7 +1045,28 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
         ggml_free(ctx);
         return false;
     }
-    GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
+
+    uint64_t src_size   = (uint64_t) ggml_nbytes(src);
+    uint64_t dst_data   = (uint64_t) dst->data;
+    uint64_t dst_base   = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);
+    uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);
+
+    if (dst_data + src_size > dst_base + dst_buf_sz) {
+        GGML_PRINT_DEBUG("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
+                         "    write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
+                         "    buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
+                         __func__,
+                         dst_data,
+                         dst_data + src_size,
+                         dst_base,
+                         dst_base + dst_buf_sz);
+        ggml_free(ctx);
+        return false;
+    }
+
+    GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n",
+                     __func__, (void*) src->buffer, (void*) dst->buffer);
+
     response.result = ggml_backend_buffer_copy_tensor(src, dst);
     ggml_free(ctx);
     return true;
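
The added check in rpc_server::copy_tensor rejects a copy whose destination write range would end past the destination buffer, instead of forwarding it to the backend. The standalone sketch below reproduces the same bounds arithmetic with a hypothetical helper, purely for illustration:

#include <cstdint>
#include <cstdio>

// True when writing `size` bytes at `dst_data` ends within the buffer that starts at
// `dst_base` and spans `dst_buf_sz` bytes; this is the condition the RPC server now enforces.
static bool write_fits(uint64_t dst_data, uint64_t size, uint64_t dst_base, uint64_t dst_buf_sz) {
    return dst_data + size <= dst_base + dst_buf_sz;
}

int main() {
    // 1 KiB buffer based at 0x1000 (ends at 0x1400).
    std::printf("%d\n", write_fits(0x1100, 256, 0x1000, 1024)); // 1: end 0x1200 <= 0x1400
    std::printf("%d\n", write_fits(0x1380, 256, 0x1000, 1024)); // 0: end 0x1480 >  0x1400
    return 0;
}
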
package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -3878,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
 }
 
-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
-}
-
 static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
@@ -4090,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
             ggml_sycl_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            ggml_sycl_soft_max(ctx, dst);
+            ggml_sycl_op_soft_max(ctx, dst);
             break;
         case GGML_OP_ROPE:
             ggml_sycl_rope(ctx, dst);
@@ -4541,14 +4537,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
-        case GGML_OP_NORM:
         case GGML_OP_ADD:
         case GGML_OP_ADD1:
         case GGML_OP_LOG:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+            return true;
+        case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_GROUP_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
@@ -4580,7 +4579,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
-        case GGML_OP_GROUP_NORM:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_LEAKY_RELU:
package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp
@@ -1,7 +1,7 @@
-#include "norm.hpp"
+#include "softmax.hpp"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
                          const int nrows_y, const float scale, const float max_bias, const float m0,
                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         slope = sycl::pow(base, float(exp));
     }
 
-    float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+    float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
     float max_val = -INFINITY;
 
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
 
         vals[col] = val;
         max_val = sycl::max(max_val, val);
@@ -65,7 +65,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
             item_ct1.barrier(sycl::access::fence_space::local_space);
             max_val = buf[lane_id];
             for (size_t i = 1; i < nreduce; i += 1) {
-                max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+                max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
             }
             max_val = warp_reduce_max(max_val, item_ct1);
         }
@@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
     }
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, queue_ptr stream) {
@@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
     });
 }
 
-static void soft_max_f32_sycl(const float * x, const float * mask,
+template<typename T>
+static void soft_max_f32_sycl(const float * x, const T * mask,
                               float * dst, const int ncols_x, const int nrows_x,
                               const int nrows_y, const float scale, const float max_bias,
                               queue_ptr stream, int device) {
@@ -223,22 +224,16 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
     }
 }
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream) {
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
-#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src0->ne[1];
+    const int64_t ne00 = dst->src[0]->ne[0];
+    const int64_t nrows_x = ggml_nrows(dst->src[0]);
+    const int64_t nrows_y = dst->src[0]->ne[1];
 
     float scale = 1.0f;
     float max_bias = 0.0f;
@@ -246,6 +241,21 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s
     memcpy(&scale, dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
-                      nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
+
+    ggml_sycl_set_device(ctx.device);
+    dpct::queue_ptr main_stream = ctx.stream();
+
+    if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
+        const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
+        soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
+                                      main_stream, ctx.device);
+    } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
+        const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
+        soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    } else {
+        /* mask unavailable */
+        soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    }
 }
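
ggml_sycl_op_soft_max now reads its operands from dst->src[...] and selects the kernel instantiation from the mask dtype (F16, F32, or no mask). The sketch below shows only that dispatch shape in plain C++; the kernel body is elided, half_t stands in for sycl::half, and all names are hypothetical:

#include <cstdio>

// Sketch of the dispatch pattern only: the launcher is templated on the mask element
// type, and the op entry point picks the instantiation from the mask tensor's dtype.
using half_t = float;                      // stand-in for sycl::half in this sketch

enum class mask_dtype { none, f16, f32 };

template <typename T>
static void masked_scale_rows(const float * x, const T * mask, float * dst, int n) {
    for (int i = 0; i < n; ++i) {
        // the cast mirrors the patch: slope * static_cast<float>(mask[iy])
        dst[i] = x[i] + (mask ? static_cast<float>(mask[i]) : 0.0f);
    }
}

static void soft_max_dispatch(const float * x, const void * mask, mask_dtype mt, float * dst, int n) {
    if (mt == mask_dtype::f16) {
        masked_scale_rows<half_t>(x, static_cast<const half_t *>(mask), dst, n);
    } else if (mt == mask_dtype::f32) {
        masked_scale_rows<float>(x, static_cast<const float *>(mask), dst, n);
    } else {
        masked_scale_rows<float>(x, nullptr, dst, n);   // mask unavailable
    }
}

int main() {
    float x[3] = {0.f, 1.f, 2.f}, m[3] = {0.5f, 0.5f, 0.5f}, out[3];
    soft_max_dispatch(x, m, mask_dtype::f32, out, 3);
    std::printf("%.1f %.1f %.1f\n", out[0], out[1], out[2]);    // 0.5 1.5 2.5
    return 0;
}
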
package/src/llama.cpp/ggml/src/ggml-sycl/softmax.hpp
@@ -15,10 +15,6 @@
 
 #include "common.hpp"
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream);
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
 
 #endif // GGML_SYCL_SOFTMAX_HPP