llama_cpp 0.13.0 → 0.14.1

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry it was published to. It is provided for informational purposes only.
@@ -69,6 +69,33 @@ struct vk_queue {
69
69
  vk::PipelineStageFlags stage_flags;
70
70
  };
71
71
 
72
+ struct vk_pipeline_struct {
73
+ std::string name;
74
+ vk::ShaderModule shader_module;
75
+ vk::DescriptorSetLayout dsl;
76
+ std::vector<vk::DescriptorPool> descriptor_pools;
77
+ std::vector<vk::DescriptorSet> descriptor_sets;
78
+ uint32_t descriptor_set_idx;
79
+ vk::PipelineLayout layout;
80
+ vk::Pipeline pipeline;
81
+ uint32_t push_constant_size;
82
+ uint32_t parameter_count;
83
+ std::array<uint32_t, 3> wg_denoms;
84
+ uint32_t align;
85
+ };
86
+
87
+ typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
88
+ typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
89
+
90
+ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
91
+
92
+ struct vk_matmul_pipeline_struct {
93
+ vk_pipeline l, m, s;
94
+ vk_pipeline a_l, a_m, a_s;
95
+ };
96
+
97
+ typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
98
+
72
99
  struct vk_device {
73
100
  vk::PhysicalDevice physical_device;
74
101
  vk::PhysicalDeviceProperties properties;
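The new vk_matmul_pipeline_struct introduced above groups the large/medium/small tile variants of a matmul shader together with their aligned counterparts, so later code can carry a single handle and choose a variant at dispatch time. A minimal standalone sketch of that selection idea follows; the picker function and its thresholds are hypothetical, not the actual ggml-vulkan logic:

    #include <cstdint>
    #include <memory>

    struct vk_pipeline_struct;                                   // opaque here
    using vk_pipeline = std::shared_ptr<vk_pipeline_struct>;

    struct vk_matmul_pipeline_struct {
        vk_pipeline l, m, s;        // large / medium / small tile variants
        vk_pipeline a_l, a_m, a_s;  // counterparts for aligned inputs
    };
    using vk_matmul_pipeline = std::shared_ptr<vk_matmul_pipeline_struct>;

    // Hypothetical variant picker: prefer bigger tiles for bigger problems and
    // take the aligned shader only when the input allows it.
    static vk_pipeline pick_matmul_variant(const vk_matmul_pipeline& p,
                                            uint32_t m, uint32_t n, bool aligned) {
        if (m > 64 && n > 64) return aligned ? p->a_l : p->l;
        if (m > 32 && n > 32) return aligned ? p->a_m : p->m;
        return aligned ? p->a_s : p->s;
    }
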
@@ -84,10 +111,61 @@ struct vk_device {
84
111
  uint32_t subgroup_size;
85
112
  bool uma;
86
113
 
114
+ bool initialized;
115
+ size_t idx;
116
+
117
+ vk_matmul_pipeline pipeline_matmul_f32;
118
+ vk_matmul_pipeline pipeline_matmul_f16;
119
+ vk_matmul_pipeline pipeline_matmul_f16_f32;
120
+ vk_pipeline pipeline_matmul_split_k_reduce;
121
+
122
+ vk_matmul_pipeline pipeline_dequant_mul_mat_mat[VK_NUM_TYPES];
123
+
124
+ vk_pipeline pipeline_dequant[VK_NUM_TYPES];
125
+ vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
126
+
127
+ vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
128
+ vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
129
+ vk_pipeline pipeline_get_rows[VK_NUM_TYPES];
130
+ vk_pipeline pipeline_get_rows_f32[VK_NUM_TYPES];
131
+ vk_pipeline pipeline_mul_f32;
132
+ vk_pipeline pipeline_add_f32;
133
+ vk_pipeline pipeline_scale_f32;
134
+ vk_pipeline pipeline_sqr_f32;
135
+ vk_pipeline pipeline_clamp_f32;
136
+ vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
137
+ vk_pipeline pipeline_norm_f32;
138
+ vk_pipeline pipeline_rms_norm_f32;
139
+ vk_pipeline pipeline_gelu_f32;
140
+ vk_pipeline pipeline_silu_f32;
141
+ vk_pipeline pipeline_relu_f32;
142
+ vk_pipeline pipeline_diag_mask_inf_f32;
143
+ vk_pipeline pipeline_soft_max_f32;
144
+ vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
145
+ vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
146
+ vk_pipeline pipeline_argsort_f32;
147
+
148
+ std::vector<vk_pipeline_ref> pipelines;
149
+
87
150
  ~vk_device() {
88
151
  #ifdef GGML_VULKAN_DEBUG
89
152
  std::cerr << "destroy device " << name << std::endl;
90
153
  #endif
154
+ device.destroyCommandPool(compute_queue.pool);
155
+ if (!single_queue) {
156
+ device.destroyCommandPool(transfer_queue.pool);
157
+ }
158
+
159
+ for (auto& pipeline : pipelines) {
160
+ if (pipeline.expired()) {
161
+ continue;
162
+ }
163
+
164
+ vk_pipeline pl = pipeline.lock();
165
+ ggml_vk_destroy_pipeline(device, pl);
166
+ }
167
+ pipelines.clear();
168
+
91
169
  device.destroy();
92
170
  }
93
171
  };
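The destructor added above keeps only weak references in `pipelines`, so pipeline lifetime stays with whoever holds the shared_ptr, while anything still alive when the device goes away is destroyed explicitly. A self-contained sketch of that pattern, with placeholder types standing in for the Vulkan handles:

    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    struct pipeline_struct { std::string name; };
    using pipeline_ref = std::weak_ptr<pipeline_struct>;

    static void destroy_pipeline(std::shared_ptr<pipeline_struct>& p) {
        // stands in for destroying descriptor pools, layouts, shader module, pipeline
        std::cout << "destroying " << p->name << "\n";
    }

    struct device {
        std::vector<pipeline_ref> pipelines;   // weak: creation sites keep ownership

        ~device() {
            for (auto& ref : pipelines) {
                if (ref.expired()) {
                    continue;                  // already released by its owner
                }
                auto p = ref.lock();
                destroy_pipeline(p);
            }
            pipelines.clear();
        }
    };
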
@@ -125,21 +203,6 @@ struct vk_subbuffer {
125
203
  uint64_t size;
126
204
  };
127
205
 
128
- struct vk_pipeline {
129
- std::string name;
130
- vk::ShaderModule shader_module;
131
- vk::DescriptorSetLayout dsl;
132
- std::vector<vk::DescriptorPool> descriptor_pools;
133
- std::vector<vk::DescriptorSet> descriptor_sets;
134
- uint32_t descriptor_set_idx;
135
- vk::PipelineLayout layout;
136
- vk::Pipeline pipeline;
137
- uint32_t push_constant_size;
138
- uint32_t parameter_count;
139
- std::array<uint32_t, 3> wg_denoms;
140
- uint32_t align;
141
- };
142
-
143
206
  struct vk_semaphore {
144
207
  vk::Semaphore s;
145
208
  uint64_t value;
@@ -160,11 +223,21 @@ struct vk_op_push_constants {
160
223
  float param2;
161
224
  };
162
225
 
163
- struct vk_op_cpy_push_constants {
226
+ struct vk_op_unary_push_constants {
227
+ uint32_t ne;
228
+ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
229
+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
230
+ uint32_t d_offset;
231
+ float param1; float param2;
232
+ };
233
+
234
+ struct vk_op_binary_push_constants {
164
235
  uint32_t ne;
165
- uint32_t ne00; uint32_t ne01; uint32_t nb00; uint32_t nb01; uint32_t nb02;
166
- uint32_t ne10; uint32_t ne11; uint32_t nb10; uint32_t nb11; uint32_t nb12;
236
+ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
237
+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
238
+ uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
167
239
  uint32_t d_offset;
240
+ float param1; float param2;
168
241
  };
169
242
 
170
243
  struct vk_op_diag_mask_push_constants {
@@ -196,6 +269,22 @@ struct vk_op_rope_neox_push_constants {
196
269
  float inv_ndims;
197
270
  };
198
271
 
272
+ struct vk_op_soft_max_push_constants {
273
+ uint32_t KX;
274
+ uint32_t KY;
275
+ uint32_t KZ;
276
+ float scale;
277
+ float max_bias;
278
+ float m0;
279
+ float m1;
280
+ uint32_t n_head_log2;
281
+ };
282
+
283
+ struct vk_op_argsort_push_constants {
284
+ uint32_t ncols;
285
+ bool ascending;
286
+ };
287
+
199
288
  // Allow pre-recording command buffers
200
289
  struct vk_staging_memcpy {
201
290
  vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -236,7 +325,6 @@ struct ggml_tensor_extra_gpu {
236
325
  };
237
326
 
238
327
  struct ggml_vk_garbage_collector {
239
- std::vector<vk_pipeline *> pipelines;
240
328
  std::vector<vk_semaphore> tl_semaphores;
241
329
  std::vector<vk_semaphore> semaphores;
242
330
  std::vector<vk::Event> events;
@@ -247,35 +335,7 @@ struct ggml_vk_garbage_collector {
247
335
  struct ggml_backend_vk_context {
248
336
  std::string name;
249
337
 
250
- std::weak_ptr<vk_device> device;
251
- vk_pipeline pipeline_matmul_f32_l, pipeline_matmul_f32_m, pipeline_matmul_f32_s;
252
- vk_pipeline pipeline_matmul_f32_aligned_l, pipeline_matmul_f32_aligned_m, pipeline_matmul_f32_aligned_s;
253
- vk_pipeline pipeline_matmul_f16_l, pipeline_matmul_f16_m, pipeline_matmul_f16_s;
254
- vk_pipeline pipeline_matmul_f16_aligned_l, pipeline_matmul_f16_aligned_m, pipeline_matmul_f16_aligned_s;
255
- vk_pipeline pipeline_matmul_f16_f32_l, pipeline_matmul_f16_f32_m, pipeline_matmul_f16_f32_s;
256
- vk_pipeline pipeline_matmul_f16_f32_aligned_l, pipeline_matmul_f16_f32_aligned_m, pipeline_matmul_f16_f32_aligned_s;
257
- vk_pipeline pipeline_matmul_split_k_reduce;
258
- vk_pipeline pipeline_dequant[VK_NUM_TYPES];
259
- vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
260
- vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
261
- vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
262
- vk_pipeline pipeline_get_rows[VK_NUM_TYPES];
263
- vk_pipeline pipeline_get_rows_f32[VK_NUM_TYPES];
264
- vk_pipeline pipeline_mul_f32;
265
- vk_pipeline pipeline_add_f32;
266
- vk_pipeline pipeline_scale_f32;
267
- vk_pipeline pipeline_sqr_f32;
268
- vk_pipeline pipeline_clamp_f32;
269
- vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
270
- vk_pipeline pipeline_norm_f32;
271
- vk_pipeline pipeline_rms_norm_f32;
272
- vk_pipeline pipeline_gelu_f32;
273
- vk_pipeline pipeline_silu_f32;
274
- vk_pipeline pipeline_relu_f32;
275
- vk_pipeline pipeline_diag_mask_inf_f32;
276
- vk_pipeline pipeline_soft_max_f32;
277
- vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
278
- vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
338
+ std::shared_ptr<vk_device> device;
279
339
 
280
340
  size_t semaphore_idx, event_idx;
281
341
  ggml_vk_garbage_collector gc;
@@ -304,13 +364,31 @@ struct vk_instance {
304
364
 
305
365
  std::vector<size_t> device_indices;
306
366
 
307
- std::shared_ptr<vk_device> devices[GGML_VK_MAX_DEVICES];
308
367
  ggml_backend_t backends[GGML_VK_MAX_DEVICES];
309
368
  ggml_backend_vk_context contexts[GGML_VK_MAX_DEVICES];
310
369
  ggml_backend_buffer_type buffer_types[GGML_VK_MAX_DEVICES];
311
370
  bool initialized[GGML_VK_MAX_DEVICES];
312
371
  };
313
372
 
373
+ static std::shared_ptr<vk_device> ggml_vk_get_device(size_t idx) {
374
+ #ifdef GGML_VULKAN_DEBUG
375
+ std::cerr << "ggml_vk_get_device(" << idx << ")" << std::endl;
376
+ #endif
377
+ static std::weak_ptr<vk_device> devices[GGML_VK_MAX_DEVICES];
378
+
379
+ if (devices[idx].expired()) {
380
+ #ifdef GGML_VULKAN_DEBUG
381
+ std::cerr << "Initializing new vk_device" << std::endl;
382
+ #endif
383
+ std::shared_ptr<vk_device> device = std::make_shared<vk_device>();
384
+ device->initialized = false;
385
+ devices[idx] = device;
386
+ return device;
387
+ }
388
+
389
+ return devices[idx].lock();
390
+ }
391
+
314
392
  #ifdef GGML_VULKAN_CHECK_RESULTS
315
393
  static size_t vk_skip_checks;
316
394
  static size_t vk_output_tensor;
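ggml_vk_get_device above caches devices in a function-local array of weak_ptr, so repeated lookups of the same index share one vk_device for as long as anything still references it, and a fresh instance is created only after every owner has released theirs. Roughly how that behaves (illustrative only):

    // Both handles point at the same vk_device while at least one is alive:
    std::shared_ptr<vk_device> a = ggml_vk_get_device(0);
    std::shared_ptr<vk_device> b = ggml_vk_get_device(0);   // a.get() == b.get()

    a.reset();
    b.reset();                                               // last reference gone

    // The cached weak_ptr has expired, so this constructs a new vk_device
    // (with initialized == false, to be set up by the caller):
    std::shared_ptr<vk_device> c = ggml_vk_get_device(0);
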
@@ -334,14 +412,15 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
334
412
  GGML_ASSERT(parameter_count > 0);
335
413
  GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
336
414
 
337
- pipeline.name = name;
338
- pipeline.parameter_count = parameter_count;
339
- pipeline.push_constant_size = push_constant_size;
340
- pipeline.wg_denoms = wg_denoms;
341
- pipeline.align = align;
415
+ pipeline = std::make_shared<vk_pipeline_struct>();
416
+ pipeline->name = name;
417
+ pipeline->parameter_count = parameter_count;
418
+ pipeline->push_constant_size = push_constant_size;
419
+ pipeline->wg_denoms = wg_denoms;
420
+ pipeline->align = align;
342
421
 
343
422
  vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
344
- pipeline.shader_module = ctx->device.lock()->device.createShaderModule(shader_module_create_info);
423
+ pipeline->shader_module = ctx->device->device.createShaderModule(shader_module_create_info);
345
424
 
346
425
  std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
347
426
  std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
@@ -355,49 +434,49 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
355
434
  vk::PushConstantRange pcr(
356
435
  vk::ShaderStageFlagBits::eCompute,
357
436
  0,
358
- pipeline.push_constant_size
437
+ pipeline->push_constant_size
359
438
  );
360
439
 
361
440
  vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
362
441
  {},
363
442
  dsl_binding);
364
443
  descriptor_set_layout_create_info.setPNext(&dslbfci);
365
- pipeline.dsl = ctx->device.lock()->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
444
+ pipeline->dsl = ctx->device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
366
445
 
367
446
  // Check if device supports multiple descriptors per pool
368
- if (ctx->device.lock()->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) {
447
+ if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) {
369
448
  const uint32_t alloc_count = 2;
370
449
 
371
450
  // Try allocating multiple sets from one pool
372
451
  // This fails on AMD for some reason, so add a fall back to allocating one pool per set
373
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
452
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
374
453
  vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, alloc_count, descriptor_pool_size);
375
- vk::DescriptorPool pool = ctx->device.lock()->device.createDescriptorPool(descriptor_pool_create_info);
454
+ vk::DescriptorPool pool = ctx->device->device.createDescriptorPool(descriptor_pool_create_info);
376
455
 
377
456
  std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
378
457
  for (uint32_t i = 0; i < alloc_count; i++) {
379
- layouts[i] = pipeline.dsl;
458
+ layouts[i] = pipeline->dsl;
380
459
  }
381
460
  try {
382
461
  vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pool, alloc_count, layouts.data());
383
- std::vector<vk::DescriptorSet> sets = ctx->device.lock()->device.allocateDescriptorSets(descriptor_set_alloc_info);
462
+ std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
384
463
  } catch(vk::OutOfPoolMemoryError const&) {
385
- ctx->device.lock()->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE;
464
+ ctx->device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE;
386
465
  }
387
466
 
388
- ctx->device.lock()->device.destroyDescriptorPool(pool);
467
+ ctx->device->device.destroyDescriptorPool(pool);
389
468
  }
390
469
 
391
- if (ctx->device.lock()->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
392
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
470
+ if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
471
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
393
472
  vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 128, descriptor_pool_size);
394
- pipeline.descriptor_pools.push_back(ctx->device.lock()->device.createDescriptorPool(descriptor_pool_create_info));
473
+ pipeline->descriptor_pools.push_back(ctx->device->device.createDescriptorPool(descriptor_pool_create_info));
395
474
  }
396
475
 
397
- pipeline.descriptor_set_idx = 0;
476
+ pipeline->descriptor_set_idx = 0;
398
477
 
399
- vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline.dsl, pcr);
400
- pipeline.layout = ctx->device.lock()->device.createPipelineLayout(pipeline_layout_create_info);
478
+ vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr);
479
+ pipeline->layout = ctx->device->device.createPipelineLayout(pipeline_layout_create_info);
401
480
 
402
481
  std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());
403
482
 
@@ -417,72 +496,75 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
417
496
  vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
418
497
  vk::PipelineShaderStageCreateFlags(),
419
498
  vk::ShaderStageFlagBits::eCompute,
420
- pipeline.shader_module,
499
+ pipeline->shader_module,
421
500
  entrypoint.c_str(),
422
501
  &specialization_info);
423
502
  vk::ComputePipelineCreateInfo compute_pipeline_create_info(
424
503
  vk::PipelineCreateFlags(),
425
504
  pipeline_shader_create_info,
426
- pipeline.layout);
427
- pipeline.pipeline = ctx->device.lock()->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
505
+ pipeline->layout);
506
+ pipeline->pipeline = ctx->device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
428
507
 
429
- ctx->gc.pipelines.push_back(&pipeline);
508
+ ctx->device->pipelines.push_back(pipeline);
430
509
  }
431
510
 
432
- static void ggml_vk_destroy_pipeline(ggml_backend_vk_context * ctx, vk_pipeline * pipeline) {
511
+ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
512
+ #ifdef GGML_VULKAN_DEBUG
513
+ std::cerr << "ggml_pipeline_destroy_pipeline(" << pipeline->name << ")" << std::endl;
514
+ #endif
433
515
  for (auto& pool : pipeline->descriptor_pools) {
434
- ctx->device.lock()->device.destroyDescriptorPool(pool);
516
+ device.destroyDescriptorPool(pool);
435
517
  }
436
518
  pipeline->descriptor_pools.clear();
437
519
  pipeline->descriptor_sets.clear();
438
520
  pipeline->descriptor_set_idx = 0;
439
521
 
440
- ctx->device.lock()->device.destroyDescriptorSetLayout(pipeline->dsl);
522
+ device.destroyDescriptorSetLayout(pipeline->dsl);
441
523
 
442
- ctx->device.lock()->device.destroyPipelineLayout(pipeline->layout);
524
+ device.destroyPipelineLayout(pipeline->layout);
443
525
 
444
- ctx->device.lock()->device.destroyShaderModule(pipeline->shader_module);
526
+ device.destroyShaderModule(pipeline->shader_module);
445
527
 
446
- ctx->device.lock()->device.destroyPipeline(pipeline->pipeline);
528
+ device.destroyPipeline(pipeline->pipeline);
447
529
  }
448
530
 
449
531
  static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx, vk_pipeline& pipeline, uint32_t n) {
450
532
  #ifdef GGML_VULKAN_DEBUG
451
- std::cerr << "ggml_pipeline_allocate_descriptor_sets(" << pipeline.name << ", " << n << ")" << std::endl;
533
+ std::cerr << "ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")" << std::endl;
452
534
  #endif
453
- if (pipeline.descriptor_sets.size() >= pipeline.descriptor_set_idx + n) {
535
+ if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) {
454
536
  // Enough descriptors are available
455
537
  return;
456
538
  }
457
539
 
458
- if (ctx->device.lock()->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
459
- const uint32_t alloc_count = pipeline.descriptor_set_idx + n - pipeline.descriptor_sets.size();
540
+ if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
541
+ const uint32_t alloc_count = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size();
460
542
 
461
543
  std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
462
544
  for (uint32_t i = 0; i < alloc_count; i++) {
463
- layouts[i] = pipeline.dsl;
545
+ layouts[i] = pipeline->dsl;
464
546
  }
465
- vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline.descriptor_pools[0], alloc_count, layouts.data());
466
- std::vector<vk::DescriptorSet> sets = ctx->device.lock()->device.allocateDescriptorSets(descriptor_set_alloc_info);
467
- pipeline.descriptor_sets.insert(pipeline.descriptor_sets.end(), sets.begin(), sets.end());
547
+ vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[0], alloc_count, layouts.data());
548
+ std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
549
+ pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end());
468
550
  } else {
469
- for (uint32_t i = pipeline.descriptor_sets.size(); i < pipeline.descriptor_set_idx + n; i++) {
470
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
551
+ for (uint32_t i = pipeline->descriptor_sets.size(); i < pipeline->descriptor_set_idx + n; i++) {
552
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
471
553
  vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 1, descriptor_pool_size);
472
- pipeline.descriptor_pools.push_back(ctx->device.lock()->device.createDescriptorPool(descriptor_pool_create_info));
554
+ pipeline->descriptor_pools.push_back(ctx->device->device.createDescriptorPool(descriptor_pool_create_info));
473
555
 
474
- vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline.descriptor_pools[i], 1, &pipeline.dsl);
475
- std::vector<vk::DescriptorSet> sets = ctx->device.lock()->device.allocateDescriptorSets(descriptor_set_alloc_info);
476
- pipeline.descriptor_sets.push_back(sets[0]);
556
+ vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[i], 1, &pipeline->dsl);
557
+ std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
558
+ pipeline->descriptor_sets.push_back(sets[0]);
477
559
  }
478
560
  }
479
561
  }
480
562
 
481
563
  static void ggml_pipeline_cleanup(vk_pipeline& pipeline) {
482
564
  #ifdef GGML_VULKAN_DEBUG
483
- std::cerr << "ggml_pipeline_cleanup(" << pipeline.name << ")" << std::endl;
565
+ std::cerr << "ggml_pipeline_cleanup(" << pipeline->name << ")" << std::endl;
484
566
  #endif
485
- pipeline.descriptor_set_idx = 0;
567
+ pipeline->descriptor_set_idx = 0;
486
568
  }
487
569
 
488
570
  static vk::CommandBuffer ggml_vk_create_cmd_buffer(ggml_backend_vk_context * ctx, vk_queue& q) {
@@ -498,7 +580,7 @@ static vk::CommandBuffer ggml_vk_create_cmd_buffer(ggml_backend_vk_context * ctx
498
580
  q.pool,
499
581
  vk::CommandBufferLevel::ePrimary,
500
582
  1);
501
- const std::vector<vk::CommandBuffer> cmd_buffers = ctx->device.lock()->device.allocateCommandBuffers(command_buffer_alloc_info);
583
+ const std::vector<vk::CommandBuffer> cmd_buffers = ctx->device->device.allocateCommandBuffers(command_buffer_alloc_info);
502
584
  auto buf = cmd_buffers.front();
503
585
 
504
586
  q.cmd_buffers.push_back(buf);
@@ -643,11 +725,11 @@ static void ggml_vk_create_queue(ggml_backend_vk_context * ctx, vk_queue& q, uin
643
725
  q.queue_family_index = queue_family_index;
644
726
 
645
727
  vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index);
646
- q.pool = ctx->device.lock()->device.createCommandPool(command_pool_create_info_compute);
728
+ q.pool = ctx->device->device.createCommandPool(command_pool_create_info_compute);
647
729
 
648
730
  q.cmd_buffer_idx = 0;
649
731
 
650
- q.queue = ctx->device.lock()->device.getQueue(queue_family_index, queue_index);
732
+ q.queue = ctx->device->device.getQueue(queue_family_index, queue_index);
651
733
 
652
734
  q.stage_flags = stage_flags;
653
735
  }
@@ -671,7 +753,7 @@ static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context *
671
753
  vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
672
754
  vk::SemaphoreCreateInfo ci{};
673
755
  ci.setPNext(&tci);
674
- vk::Semaphore semaphore = ctx->device.lock()->device.createSemaphore(ci);
756
+ vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
675
757
  ctx->gc.semaphores.push_back({ semaphore, 0 });
676
758
  return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];
677
759
  }
@@ -684,7 +766,7 @@ static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context
684
766
  vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
685
767
  vk::SemaphoreCreateInfo ci{};
686
768
  ci.setPNext(&tci);
687
- vk::Semaphore semaphore = ctx->device.lock()->device.createSemaphore(ci);
769
+ vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
688
770
  ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
689
771
  }
690
772
  return &ctx->gc.tl_semaphores[ctx->semaphore_idx++];
@@ -692,7 +774,7 @@ static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context
692
774
 
693
775
  static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
694
776
  if (ctx->event_idx >= ctx->gc.events.size()) {
695
- ctx->gc.events.push_back(ctx->device.lock()->device.createEvent({}));
777
+ ctx->gc.events.push_back(ctx->device->device.createEvent({}));
696
778
  }
697
779
  return ctx->gc.events[ctx->event_idx++];
698
780
  }
@@ -703,7 +785,7 @@ static void ggml_vk_queue_cleanup(ggml_backend_vk_context * ctx, vk_queue& q) {
703
785
  #endif
704
786
  // Requires command buffers to be done
705
787
 
706
- ctx->device.lock()->device.resetCommandPool(q.pool);
788
+ ctx->device->device.resetCommandPool(q.pool);
707
789
  q.cmd_buffer_idx = 0;
708
790
  }
709
791
 
@@ -740,11 +822,11 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
740
822
  nullptr,
741
823
  };
742
824
 
743
- buf->buffer = ctx->device.lock()->device.createBuffer(buffer_create_info);
825
+ buf->buffer = ctx->device->device.createBuffer(buffer_create_info);
744
826
 
745
- vk::MemoryRequirements mem_req = ctx->device.lock()->device.getBufferMemoryRequirements(buf->buffer);
827
+ vk::MemoryRequirements mem_req = ctx->device->device.getBufferMemoryRequirements(buf->buffer);
746
828
 
747
- vk::PhysicalDeviceMemoryProperties mem_props = ctx->device.lock()->physical_device.getMemoryProperties();
829
+ vk::PhysicalDeviceMemoryProperties mem_props = ctx->device->physical_device.getMemoryProperties();
748
830
 
749
831
  uint32_t memory_type_index = UINT32_MAX;
750
832
 
@@ -757,30 +839,30 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
757
839
  }
758
840
 
759
841
  if (memory_type_index == UINT32_MAX) {
760
- ctx->device.lock()->device.destroyBuffer(buf->buffer);
842
+ ctx->device->device.destroyBuffer(buf->buffer);
761
843
  buf->size = 0;
762
844
  throw vk::OutOfDeviceMemoryError("No suitable memory type found");
763
845
  }
764
846
 
765
847
  try {
766
- buf->device_memory = ctx->device.lock()->device.allocateMemory({ mem_req.size, memory_type_index });
848
+ buf->device_memory = ctx->device->device.allocateMemory({ mem_req.size, memory_type_index });
767
849
  } catch (const vk::SystemError& e) {
768
850
  // Out of Host/Device memory, clean up buffer
769
- ctx->device.lock()->device.destroyBuffer(buf->buffer);
851
+ ctx->device->device.destroyBuffer(buf->buffer);
770
852
  buf->size = 0;
771
853
  throw e;
772
854
  }
773
855
  buf->ptr = nullptr;
774
856
 
775
857
  if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
776
- buf->ptr = ctx->device.lock()->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
858
+ buf->ptr = ctx->device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
777
859
  }
778
860
 
779
- ctx->device.lock()->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
861
+ ctx->device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
780
862
 
781
863
  buf->ctx = ctx;
782
864
 
783
- buf->device = ctx->device.lock();
865
+ buf->device = ctx->device;
784
866
 
785
867
  #ifdef GGML_VULKAN_DEBUG
786
868
  std::cerr << "Created buffer " << buf->buffer << std::endl;
@@ -802,7 +884,7 @@ static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size
802
884
  static vk_buffer ggml_vk_create_buffer_device(ggml_backend_vk_context * ctx, size_t size) {
803
885
  vk_buffer buf;
804
886
  try {
805
- if (ctx->device.lock()->uma) {
887
+ if (ctx->device->uma) {
806
888
  // Fall back to host memory type
807
889
  buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
808
890
  } else {
@@ -883,10 +965,16 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
883
965
  std::cerr << "ggml_vk_load_shaders(" << ctx->name << ")" << std::endl;
884
966
  #endif
885
967
 
968
+ const std::shared_ptr<vk_device> device = ctx->device;
969
+
886
970
  // mulmat
887
- std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, ctx->device.lock()->subgroup_size * 2, 64, 2, 4, 4, ctx->device.lock()->subgroup_size };
888
- std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, ctx->device.lock()->subgroup_size, 32, 2, 4, 2, ctx->device.lock()->subgroup_size };
889
- std::initializer_list<uint32_t> warptile_s = { ctx->device.lock()->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, ctx->device.lock()->subgroup_size };
971
+ std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
972
+ std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
973
+ std::initializer_list<uint32_t> warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
974
+
975
+ std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
976
+ std::initializer_list<uint32_t> warptile_mmq_m = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
977
+ std::initializer_list<uint32_t> warptile_mmq_s = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
890
978
 
891
979
  std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
892
980
  std::array<uint32_t, 3> m_wg_denoms = { 64, 64, 1 };
@@ -896,126 +984,206 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
896
984
  uint32_t m_align = 64;
897
985
  uint32_t s_align = 32;
898
986
 
899
- if (ctx->device.lock()->fp16) {
900
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_l_len, matmul_f32_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
901
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_m_len, matmul_f32_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
902
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_s, "matmul_f32_s", matmul_f32_s_len, matmul_f32_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
903
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_l, "matmul_f32_aligned_l", matmul_f32_aligned_l_len, matmul_f32_aligned_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
904
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_m, "matmul_f32_aligned_m", matmul_f32_aligned_m_len, matmul_f32_aligned_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
905
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_s, "matmul_f32_aligned_s", matmul_f32_aligned_s_len, matmul_f32_aligned_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
906
-
907
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_l, "matmul_f16_l", matmul_f16_l_len, matmul_f16_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
908
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_m, "matmul_f16_m", matmul_f16_m_len, matmul_f16_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
909
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_s, "matmul_f16_s", matmul_f16_s_len, matmul_f16_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
910
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_l, "matmul_f16_aligned_l", matmul_f16_aligned_l_len, matmul_f16_aligned_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
911
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_m, "matmul_f16_aligned_m", matmul_f16_aligned_m_len, matmul_f16_aligned_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
912
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_s, "matmul_f16_aligned_s", matmul_f16_aligned_s_len, matmul_f16_aligned_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
913
-
914
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_l, "matmul_f16_f32_l", matmul_f16_f32_l_len, matmul_f16_f32_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
915
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_m, "matmul_f16_f32_m", matmul_f16_f32_m_len, matmul_f16_f32_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
916
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_s, "matmul_f16_f32_s", matmul_f16_f32_s_len, matmul_f16_f32_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
917
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_l_len, matmul_f16_f32_aligned_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
918
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_m_len, matmul_f16_f32_aligned_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
919
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_s_len, matmul_f16_f32_aligned_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
987
+ ctx->device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
988
+ ctx->device->pipeline_matmul_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
989
+ ctx->device->pipeline_matmul_f16 = std::make_shared<vk_matmul_pipeline_struct>();
990
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
991
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
992
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
993
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
994
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
995
+
996
+ if (device->fp16) {
997
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
998
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
999
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1000
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1001
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1002
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1003
+
1004
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1005
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1006
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1007
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1008
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1009
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1010
+
1011
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1012
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1013
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1014
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1015
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1016
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1017
+
1018
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1019
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1020
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1021
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1022
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1023
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1024
+
1025
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_0_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1026
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_0_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1027
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_0_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1028
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1029
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1030
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1031
+
1032
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1033
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1034
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1035
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1036
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1037
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1038
+
1039
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1040
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1041
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1042
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1043
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1044
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1045
+
1046
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1047
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1048
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1049
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1050
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1051
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
920
1052
  } else {
921
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_l_fp32_len, matmul_f32_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
922
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_m_fp32_len, matmul_f32_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
923
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_s, "matmul_f32_s", matmul_f32_s_fp32_len, matmul_f32_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
924
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_l, "matmul_f32_aligned_l", matmul_f32_aligned_l_fp32_len, matmul_f32_aligned_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
925
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_m, "matmul_f32_aligned_m", matmul_f32_aligned_m_fp32_len, matmul_f32_aligned_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
926
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_s, "matmul_f32_aligned_s", matmul_f32_aligned_s_fp32_len, matmul_f32_aligned_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
927
-
928
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_l, "matmul_f16_l", matmul_f16_l_fp32_len, matmul_f16_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
929
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_m, "matmul_f16_m", matmul_f16_m_fp32_len, matmul_f16_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
930
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_s, "matmul_f16_s", matmul_f16_s_fp32_len, matmul_f16_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
931
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_l, "matmul_f16_aligned_l", matmul_f16_aligned_l_fp32_len, matmul_f16_aligned_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
932
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_m, "matmul_f16_aligned_m", matmul_f16_aligned_m_fp32_len, matmul_f16_aligned_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
933
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_s, "matmul_f16_aligned_s", matmul_f16_aligned_s_fp32_len, matmul_f16_aligned_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
934
-
935
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_l, "matmul_f16_f32_l", matmul_f16_f32_l_fp32_len, matmul_f16_f32_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
936
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_m, "matmul_f16_f32_m", matmul_f16_f32_m_fp32_len, matmul_f16_f32_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
937
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_s, "matmul_f16_f32_s", matmul_f16_f32_s_fp32_len, matmul_f16_f32_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
938
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_l_fp32_len, matmul_f16_f32_aligned_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
939
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_m_fp32_len, matmul_f16_f32_aligned_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
940
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_s_fp32_len, matmul_f16_f32_aligned_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
941
- }
942
-
943
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32", mul_mat_vec_f16_f32_len, mul_mat_vec_f16_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
944
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32", mul_mat_vec_q4_0_f32_len, mul_mat_vec_q4_0_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
945
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32", mul_mat_vec_q4_1_f32_len, mul_mat_vec_q4_1_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
946
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32", mul_mat_vec_q5_0_f32_len, mul_mat_vec_q5_0_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
947
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32", mul_mat_vec_q5_1_f32_len, mul_mat_vec_q5_1_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
948
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32", mul_mat_vec_q8_0_f32_len, mul_mat_vec_q8_0_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
949
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_K_f32", mul_mat_vec_q2_K_f32_len, mul_mat_vec_q2_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
950
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_K_f32", mul_mat_vec_q3_K_f32_len, mul_mat_vec_q3_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
951
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_K_f32", mul_mat_vec_q4_K_f32_len, mul_mat_vec_q4_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
952
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f32", mul_mat_vec_q5_K_f32_len, mul_mat_vec_q5_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
953
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
1053
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1054
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1055
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1056
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1057
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1058
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1059
+
1060
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1061
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1062
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1063
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1064
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1065
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1066
+
1067
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1068
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1069
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1070
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1071
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1072
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1073
+
1074
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1075
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1076
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1077
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1078
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1079
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1080
+
1081
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1082
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1083
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1084
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1085
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1086
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1087
+
1088
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1089
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1090
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1091
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1092
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1093
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1094
+
1095
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1096
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1097
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1098
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1099
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1100
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1101
+
1102
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1103
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1104
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1105
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1106
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1107
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1108
+ }
1109
+
1110
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32", mul_mat_vec_f16_f32_len, mul_mat_vec_f16_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1111
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32", mul_mat_vec_q4_0_f32_len, mul_mat_vec_q4_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1112
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32", mul_mat_vec_q4_1_f32_len, mul_mat_vec_q4_1_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1113
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32", mul_mat_vec_q5_0_f32_len, mul_mat_vec_q5_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1114
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32", mul_mat_vec_q5_1_f32_len, mul_mat_vec_q5_1_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1115
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32", mul_mat_vec_q8_0_f32_len, mul_mat_vec_q8_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1116
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_K_f32", mul_mat_vec_q2_K_f32_len, mul_mat_vec_q2_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1117
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_K_f32", mul_mat_vec_q3_K_f32_len, mul_mat_vec_q3_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1118
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_K_f32", mul_mat_vec_q4_K_f32_len, mul_mat_vec_q4_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1119
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f32", mul_mat_vec_q5_K_f32_len, mul_mat_vec_q5_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1120
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
954
1121
 
955
1122
  // dequant shaders
956
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", f32_to_f16_len, f32_to_f16_data, "main", 2, 4 * sizeof(int), { 64, 1, 1}, {}, 1);
957
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F16 ], "dequant_f16", dequant_f16_len, dequant_f16_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
958
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
959
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
960
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
961
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
962
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
963
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
964
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
965
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
966
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_K", dequant_q5_K_len, dequant_q5_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
967
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_K", dequant_q6_K_len, dequant_q6_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
1123
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1124
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1125
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1126
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1127
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1128
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1129
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
1130
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
1131
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
1132
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_K", dequant_q5_K_len, dequant_q5_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
1133
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_K", dequant_q6_K_len, dequant_q6_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
968
1134
 
969
1135
  // get_rows
970
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
971
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
972
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
973
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
974
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
975
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1136
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1137
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1138
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1139
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1140
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1141
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1142
+
1143
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1144
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1145
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1146
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1147
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1148
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
976
1149
 
977
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
978
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
979
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
980
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
981
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
982
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1150
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
983
1151
 
984
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
1152
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
1153
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
985
1154
 
986
- ggml_vk_create_pipeline(ctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
987
- ggml_vk_create_pipeline(ctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
1155
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
1156
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
988
1157
 
989
- ggml_vk_create_pipeline(ctx, ctx->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
990
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
1158
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1159
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1160
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
991
1161
 
992
- ggml_vk_create_pipeline(ctx, ctx->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_cpy_push_constants), {512, 1, 1}, {}, 1);
993
- ggml_vk_create_pipeline(ctx, ctx->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_cpy_push_constants), {512, 1, 1}, {}, 1);
994
- ggml_vk_create_pipeline(ctx, ctx->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_cpy_push_constants), {512, 1, 1}, {}, 1);
1162
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
995
1163
 
996
- ggml_vk_create_pipeline(ctx, ctx->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1164
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
997
1165
 
998
- ggml_vk_create_pipeline(ctx, ctx->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1166
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
999
1167
 
1000
- ggml_vk_create_pipeline(ctx, ctx->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1168
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1001
1169
 
1002
- ggml_vk_create_pipeline(ctx, ctx->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1170
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1003
1171
 
1004
- ggml_vk_create_pipeline(ctx, ctx->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1172
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1173
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1174
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1005
1175
 
1006
- ggml_vk_create_pipeline(ctx, ctx->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1007
- ggml_vk_create_pipeline(ctx, ctx->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1008
- ggml_vk_create_pipeline(ctx, ctx->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1176
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
1009
1177
 
1010
- ggml_vk_create_pipeline(ctx, ctx->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
1178
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
1011
1179
 
1012
- ggml_vk_create_pipeline(ctx, ctx->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
1180
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1181
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1013
1182
 
1014
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1015
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1183
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
1184
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
1016
1185
 
1017
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
1018
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
1186
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
1019
1187
  }
1020
1188
 
1021
1189
  static void ggml_vk_print_gpu_info(size_t idx) {
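The blocks above register one compute pipeline per GGML tensor type into per-device arrays (pipeline_dequant, pipeline_get_rows, pipeline_dequant_mul_mat_vec_f32), which later code looks up by type. A minimal, self-contained C++ sketch of that lookup pattern — pipeline_struct, tensor_type and register_pipeline here are simplified stand-ins, not the real ggml/Vulkan types:

    #include <array>
    #include <iostream>
    #include <memory>
    #include <string>

    // Simplified stand-ins: in the diff above, vk_pipeline is a shared_ptr to the real struct.
    struct pipeline_struct { std::string name; };
    using pipeline = std::shared_ptr<pipeline_struct>;

    enum tensor_type { TYPE_F32, TYPE_F16, TYPE_Q4_0, TYPE_COUNT };

    // One handle per type, mirroring vk_pipeline pipeline_dequant[VK_NUM_TYPES].
    std::array<pipeline, TYPE_COUNT> dequant_pipelines;

    void register_pipeline(pipeline &slot, const std::string &name) {
        slot = std::make_shared<pipeline_struct>(pipeline_struct{name});
    }

    pipeline get_dequant_pipeline(tensor_type t) {
        return dequant_pipelines[t];   // empty handle if no shader exists for this type
    }

    int main() {
        register_pipeline(dequant_pipelines[TYPE_Q4_0], "dequant_q4_0");
        if (auto p = get_dequant_pipeline(TYPE_Q4_0)) {
            std::cout << p->name << "\n";   // prints: dequant_q4_0
        }
    }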
@@ -1057,8 +1225,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
1057
1225
  }
1058
1226
  }
1059
1227
 
1060
- const char* GGML_VULKAN_DISABLE_F16 = getenv("GGML_VULKAN_DISABLE_F16");
1061
- bool force_disable_f16 = GGML_VULKAN_DISABLE_F16 != nullptr;
1228
+ const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
1229
+ bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
1062
1230
 
1063
1231
  bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
1064
1232
 
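In this version the fp16 opt-out is read from the environment variable GGML_VK_DISABLE_F16 (previously GGML_VULKAN_DISABLE_F16); only the variable's presence matters, not its value. A small stand-alone sketch of the same check — vulkan_fp16_enabled and its two flag parameters are illustrative names, the real flags come from the device extension scan:

    #include <cstdlib>

    bool vulkan_fp16_enabled(bool fp16_storage, bool fp16_compute) {
        // Any value, even an empty string, disables fp16: only presence is tested.
        const char *disable = std::getenv("GGML_VK_DISABLE_F16");
        const bool force_disable_f16 = disable != nullptr;
        return !force_disable_f16 && fp16_storage && fp16_compute;
    }

Running with e.g. GGML_VK_DISABLE_F16=1 therefore forces the f32 shader variants even on devices that advertise fp16 storage and compute.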
@@ -1188,140 +1356,152 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
1188
1356
  throw std::runtime_error("Device not found");
1189
1357
  }
1190
1358
 
1191
- vk_instance.devices[idx] = std::make_shared<vk_device>();
1192
- ctx->device = vk_instance.devices[idx];
1193
- ctx->device.lock()->physical_device = devices[dev_num];
1194
- const std::vector<vk::ExtensionProperties> ext_props = ctx->device.lock()->physical_device.enumerateDeviceExtensionProperties();
1359
+ ctx->device = ggml_vk_get_device(idx);
1360
+ if (!ctx->device->initialized) {
1361
+ ctx->device->physical_device = devices[dev_num];
1362
+ const std::vector<vk::ExtensionProperties> ext_props = ctx->device->physical_device.enumerateDeviceExtensionProperties();
1195
1363
 
1196
- bool maintenance4_support = false;
1364
+ bool maintenance4_support = false;
1197
1365
 
1198
- // Check if maintenance4 is supported
1199
- for (const auto& properties : ext_props) {
1200
- if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
1201
- maintenance4_support = true;
1366
+ // Check if maintenance4 is supported
1367
+ for (const auto& properties : ext_props) {
1368
+ if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
1369
+ maintenance4_support = true;
1370
+ }
1202
1371
  }
1203
- }
1204
1372
 
1205
- vk::PhysicalDeviceProperties2 props2;
1206
- vk::PhysicalDeviceMaintenance3Properties props3;
1207
- vk::PhysicalDeviceMaintenance4Properties props4;
1208
- vk::PhysicalDeviceSubgroupProperties subgroup_props;
1209
- props2.pNext = &props3;
1210
- props3.pNext = &subgroup_props;
1211
- if (maintenance4_support) {
1212
- subgroup_props.pNext = &props4;
1213
- }
1214
- ctx->device.lock()->physical_device.getProperties2(&props2);
1215
- ctx->device.lock()->properties = props2.properties;
1373
+ vk::PhysicalDeviceProperties2 props2;
1374
+ vk::PhysicalDeviceMaintenance3Properties props3;
1375
+ vk::PhysicalDeviceMaintenance4Properties props4;
1376
+ vk::PhysicalDeviceSubgroupProperties subgroup_props;
1377
+ props2.pNext = &props3;
1378
+ props3.pNext = &subgroup_props;
1379
+ if (maintenance4_support) {
1380
+ subgroup_props.pNext = &props4;
1381
+ }
1382
+ ctx->device->physical_device.getProperties2(&props2);
1383
+ ctx->device->properties = props2.properties;
1216
1384
 
1217
- if (maintenance4_support) {
1218
- ctx->device.lock()->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
1219
- } else {
1220
- ctx->device.lock()->max_memory_allocation_size = props3.maxMemoryAllocationSize;
1221
- }
1385
+ const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
1222
1386
 
1223
- ctx->device.lock()->vendor_id = ctx->device.lock()->properties.vendorID;
1224
- ctx->device.lock()->subgroup_size = subgroup_props.subgroupSize;
1225
- ctx->device.lock()->uma = ctx->device.lock()->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
1387
+ if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
1388
+ ctx->device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
1389
+ } else if (maintenance4_support) {
1390
+ ctx->device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
1391
+ } else {
1392
+ ctx->device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
1393
+ }
1226
1394
 
1227
- bool fp16_storage = false;
1228
- bool fp16_compute = false;
1395
+ ctx->device->vendor_id = ctx->device->properties.vendorID;
1396
+ ctx->device->subgroup_size = subgroup_props.subgroupSize;
1397
+ ctx->device->uma = ctx->device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
1229
1398
 
1230
- for (const auto& properties : ext_props) {
1231
- if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1232
- fp16_storage = true;
1233
- } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1234
- fp16_compute = true;
1399
+ bool fp16_storage = false;
1400
+ bool fp16_compute = false;
1401
+
1402
+ for (const auto& properties : ext_props) {
1403
+ if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1404
+ fp16_storage = true;
1405
+ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1406
+ fp16_compute = true;
1407
+ }
1235
1408
  }
1236
- }
1237
1409
 
1238
- const char* GGML_VULKAN_DISABLE_F16 = getenv("GGML_VULKAN_DISABLE_F16");
1239
- bool force_disable_f16 = GGML_VULKAN_DISABLE_F16 != nullptr;
1410
+ const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
1411
+ const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
1240
1412
 
1241
- ctx->device.lock()->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
1413
+ ctx->device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
1242
1414
 
1243
- std::vector<vk::QueueFamilyProperties> queue_family_props = ctx->device.lock()->physical_device.getQueueFamilyProperties();
1415
+ std::vector<vk::QueueFamilyProperties> queue_family_props = ctx->device->physical_device.getQueueFamilyProperties();
1244
1416
 
1245
- // Try to find a non-graphics compute queue and transfer-focused queues
1246
- const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
1247
- const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
1417
+ // Try to find a non-graphics compute queue and transfer-focused queues
1418
+ const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
1419
+ const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
1248
1420
 
1249
- const float priorities[] = { 1.0f, 1.0f };
1250
- ctx->device.lock()->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
1421
+ const float priorities[] = { 1.0f, 1.0f };
1422
+ ctx->device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
1251
1423
 
1252
- std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
1253
- if (compute_queue_family_index != transfer_queue_family_index) {
1254
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
1255
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
1256
- } else if(!ctx->device.lock()->single_queue) {
1257
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
1258
- } else {
1259
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
1260
- }
1261
- vk::DeviceCreateInfo device_create_info;
1262
- std::vector<const char *> device_extensions;
1263
- vk::PhysicalDeviceFeatures device_features = ctx->device.lock()->physical_device.getFeatures();
1424
+ std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
1425
+ if (compute_queue_family_index != transfer_queue_family_index) {
1426
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
1427
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
1428
+ } else if(!ctx->device->single_queue) {
1429
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
1430
+ } else {
1431
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
1432
+ }
1433
+ vk::DeviceCreateInfo device_create_info;
1434
+ std::vector<const char *> device_extensions;
1435
+ vk::PhysicalDeviceFeatures device_features = ctx->device->physical_device.getFeatures();
1264
1436
 
1265
- VkPhysicalDeviceFeatures2 device_features2;
1266
- device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
1267
- device_features2.pNext = nullptr;
1268
- device_features2.features = (VkPhysicalDeviceFeatures)device_features;
1437
+ VkPhysicalDeviceFeatures2 device_features2;
1438
+ device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
1439
+ device_features2.pNext = nullptr;
1440
+ device_features2.features = (VkPhysicalDeviceFeatures)device_features;
1269
1441
 
1270
- VkPhysicalDeviceVulkan11Features vk11_features;
1271
- vk11_features.pNext = nullptr;
1272
- vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
1273
- device_features2.pNext = &vk11_features;
1442
+ VkPhysicalDeviceVulkan11Features vk11_features;
1443
+ vk11_features.pNext = nullptr;
1444
+ vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
1445
+ device_features2.pNext = &vk11_features;
1274
1446
 
1275
- VkPhysicalDeviceVulkan12Features vk12_features;
1276
- vk12_features.pNext = nullptr;
1277
- vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1278
- vk11_features.pNext = &vk12_features;
1447
+ VkPhysicalDeviceVulkan12Features vk12_features;
1448
+ vk12_features.pNext = nullptr;
1449
+ vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1450
+ vk11_features.pNext = &vk12_features;
1279
1451
 
1280
- vkGetPhysicalDeviceFeatures2(ctx->device.lock()->physical_device, &device_features2);
1452
+ vkGetPhysicalDeviceFeatures2(ctx->device->physical_device, &device_features2);
1281
1453
 
1282
- ctx->device.lock()->fp16 = ctx->device.lock()->fp16 && vk12_features.shaderFloat16;
1454
+ ctx->device->fp16 = ctx->device->fp16 && vk12_features.shaderFloat16;
1283
1455
 
1284
- if (!vk11_features.storageBuffer16BitAccess) {
1285
- std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
1286
- throw std::runtime_error("Unsupported device");
1287
- }
1456
+ if (!vk11_features.storageBuffer16BitAccess) {
1457
+ std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
1458
+ throw std::runtime_error("Unsupported device");
1459
+ }
1288
1460
 
1289
- device_extensions.push_back("VK_KHR_16bit_storage");
1461
+ device_extensions.push_back("VK_KHR_16bit_storage");
1290
1462
 
1291
1463
  #ifdef GGML_VULKAN_VALIDATE
1292
- device_extensions.push_back("VK_KHR_shader_non_semantic_info");
1464
+ device_extensions.push_back("VK_KHR_shader_non_semantic_info");
1293
1465
  #endif
1294
1466
 
1295
- if (ctx->device.lock()->fp16) {
1296
- device_extensions.push_back("VK_KHR_shader_float16_int8");
1297
- }
1298
- ctx->device.lock()->name = ctx->device.lock()->properties.deviceName.data();
1467
+ if (ctx->device->fp16) {
1468
+ device_extensions.push_back("VK_KHR_shader_float16_int8");
1469
+ }
1470
+ ctx->device->name = ctx->device->properties.deviceName.data();
1299
1471
 
1300
- device_create_info = {
1301
- vk::DeviceCreateFlags(),
1302
- device_queue_create_infos,
1303
- {},
1304
- device_extensions
1305
- };
1306
- device_create_info.setPNext(&device_features2);
1307
- ctx->device.lock()->device = ctx->device.lock()->physical_device.createDevice(device_create_info);
1472
+ device_create_info = {
1473
+ vk::DeviceCreateFlags(),
1474
+ device_queue_create_infos,
1475
+ {},
1476
+ device_extensions
1477
+ };
1478
+ device_create_info.setPNext(&device_features2);
1479
+ ctx->device->device = ctx->device->physical_device.createDevice(device_create_info);
1308
1480
 
1309
- ctx->device.lock()->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;
1481
+ ctx->device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;
1310
1482
 
1311
- // Shaders
1312
- ggml_vk_load_shaders(ctx);
1483
+ // Queues
1484
+ ggml_vk_create_queue(ctx, ctx->device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
1313
1485
 
1314
- // Queues
1315
- ggml_vk_create_queue(ctx, ctx->device.lock()->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
1316
- if (!ctx->device.lock()->single_queue) {
1317
- const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
1318
- ggml_vk_create_queue(ctx, ctx->device.lock()->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
1319
- } else {
1320
- // TODO: Use pointer or reference to avoid copy
1321
- ctx->device.lock()->transfer_queue = ctx->device.lock()->compute_queue;
1486
+ // Shaders
1487
+ ggml_vk_load_shaders(ctx);
1488
+
1489
+ if (!ctx->device->single_queue) {
1490
+ const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
1491
+ ggml_vk_create_queue(ctx, ctx->device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
1492
+ } else {
1493
+ // TODO: Use pointer or reference to avoid copy
1494
+ ctx->device->transfer_queue = ctx->device->compute_queue;
1495
+ }
1496
+
1497
+ ctx->device->idx = dev_num;
1498
+ ctx->device->initialized = true;
1499
+ } else if (ctx->device->idx != dev_num) {
1500
+ std::cerr << "ggml_vulkan: Device " << ctx->device->name << " already initialized with index " << ctx->device->idx << ", but trying to reinitialize with index " << dev_num << std::endl;
1501
+ throw std::runtime_error("Device already initialized");
1322
1502
  }
1323
1503
 
1324
- ctx->fence = ctx->device.lock()->device.createFence({});
1504
+ ctx->fence = ctx->device->device.createFence({});
1325
1505
 
1326
1506
  ctx->compute_ctx = nullptr;
1327
1507
  ctx->transfer_ctx = nullptr;
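Two behaviours added in this hunk are easy to miss: the shared device is only initialized once (guarded by initialized/idx, with an error if it is re-requested under a different index), and the maximum allocation size can now be overridden via GGML_VK_FORCE_MAX_ALLOCATION_SIZE. A simplified sketch of the override logic; pick_max_allocation_size and its parameters are hypothetical, the real limits come from the maintenance3/maintenance4 property structs:

    #include <algorithm>
    #include <cstddef>
    #include <cstdlib>
    #include <string>

    // Simplified version of the allocation-limit selection above. The two limits would
    // normally come from VkPhysicalDeviceMaintenance3/4Properties; here they are parameters.
    size_t pick_max_allocation_size(size_t max_memory_allocation_size, size_t max_buffer_size,
                                    bool maintenance4_support) {
        if (const char *forced = std::getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE")) {
            return std::stoi(forced);   // matches the diff; note std::stoi limits the override to int range
        }
        if (maintenance4_support) {
            return std::min(max_memory_allocation_size, max_buffer_size);
        }
        return max_memory_allocation_size;
    }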
@@ -1339,7 +1519,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
1339
1519
  #endif
1340
1520
  }
1341
1521
 
1342
- static vk_pipeline* ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
1522
+ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
1343
1523
  #ifdef GGML_VULKAN_DEBUG
1344
1524
  std::cerr << "ggml_vk_get_to_fp16()" << std::endl;
1345
1525
  #endif
@@ -1360,10 +1540,36 @@ static vk_pipeline* ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
1360
1540
  return nullptr;
1361
1541
  }
1362
1542
 
1363
- return &ctx->pipeline_dequant[type];
1543
+ return ctx->device->pipeline_dequant[type];
1364
1544
  }
1365
1545
 
1366
- static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type type) {
1546
+ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
1547
+ #ifdef GGML_VULKAN_DEBUG
1548
+ std::cerr << "ggml_vk_get_mul_mat_mat_pipeline()" << std::endl;
1549
+ #endif
1550
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
1551
+ return ctx->device->pipeline_matmul_f32;
1552
+ }
1553
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
1554
+ return ctx->device->pipeline_matmul_f16_f32;
1555
+ }
1556
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
1557
+ return ctx->device->pipeline_matmul_f16;
1558
+ }
1559
+
1560
+ GGML_ASSERT(src1_type == GGML_TYPE_F32);
1561
+
1562
+ switch (src0_type) {
1563
+ case GGML_TYPE_Q4_0:
1564
+ break;
1565
+ default:
1566
+ return nullptr;
1567
+ }
1568
+
1569
+ return ctx->device->pipeline_dequant_mul_mat_mat[src0_type];
1570
+ }
1571
+
1572
+ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type type) {
1367
1573
  #ifdef GGML_VULKAN_DEBUG
1368
1574
  std::cerr << "ggml_vk_get_dequantize_mul_mat_vec()" << std::endl;
1369
1575
  #endif
@@ -1384,7 +1590,7 @@ static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
1384
1590
  return nullptr;
1385
1591
  }
1386
1592
 
1387
- return &ctx->pipeline_dequant_mul_mat_vec_f32[type];
1593
+ return ctx->device->pipeline_dequant_mul_mat_vec_f32[type];
1388
1594
  }
1389
1595
 
1390
1596
  static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
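The new ggml_vk_get_mul_mat_mat_pipeline maps the (src0, src1) type pair to a whole matmul pipeline group (each group bundling the l/m/s and aligned variants); only f32×f32, f16×f32, f16×f16 and, among quantized inputs, Q4_0×f32 are wired up in this version. A simplified, std-only sketch of that dispatch — matmul_group, tensor_type and the mm_* globals are stand-ins for the real vk_matmul_pipeline handles:

    #include <memory>

    // Simplified model: a matmul group holds large/medium/small variants plus their
    // aligned counterparts, as in vk_matmul_pipeline_struct.
    struct matmul_group { /* l, m, s, a_l, a_m, a_s handles */ };
    using matmul_pipeline = std::shared_ptr<matmul_group>;

    enum tensor_type { TYPE_F32, TYPE_F16, TYPE_Q4_0, TYPE_COUNT };

    matmul_pipeline mm_f32, mm_f16, mm_f16_f32;
    matmul_pipeline mm_dequant[TYPE_COUNT];   // mirrors pipeline_dequant_mul_mat_mat[]

    matmul_pipeline get_mul_mat_mat_pipeline(tensor_type src0, tensor_type src1) {
        if (src0 == TYPE_F32 && src1 == TYPE_F32) return mm_f32;
        if (src0 == TYPE_F16 && src1 == TYPE_F32) return mm_f16_f32;
        if (src0 == TYPE_F16 && src1 == TYPE_F16) return mm_f16;
        // The real code asserts src1 is f32 and switches on src0; only Q4_0 passes.
        // This sketch simply returns an empty handle for everything else.
        if (src1 != TYPE_F32 || src0 != TYPE_Q4_0) return nullptr;
        return mm_dequant[src0];
    }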
@@ -1463,8 +1669,8 @@ static void * ggml_vk_host_malloc(ggml_backend_vk_context * ctx, size_t size) {
1463
1669
  if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
1464
1670
  fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
1465
1671
  size/1024.0/1024.0);
1466
- ctx->device.lock()->device.freeMemory(buf->device_memory);
1467
- ctx->device.lock()->device.destroyBuffer(buf->buffer);
1672
+ ctx->device->device.freeMemory(buf->device_memory);
1673
+ ctx->device->device.destroyBuffer(buf->buffer);
1468
1674
  return nullptr;
1469
1675
  }
1470
1676
 
@@ -1528,30 +1734,30 @@ static vk_submission ggml_vk_begin_submission(ggml_backend_vk_context * ctx, vk_
1528
1734
  }
1529
1735
 
1530
1736
  static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, std::vector<vk_subbuffer>&& buffers, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
1531
- const uint32_t wg0 = CEIL_DIV(elements[0], pipeline.wg_denoms[0]);
1532
- const uint32_t wg1 = CEIL_DIV(elements[1], pipeline.wg_denoms[1]);
1533
- const uint32_t wg2 = CEIL_DIV(elements[2], pipeline.wg_denoms[2]);
1737
+ const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
1738
+ const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
1739
+ const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
1534
1740
  #ifdef GGML_VULKAN_DEBUG
1535
- std::cerr << "ggml_vk_dispatch_pipeline(" << pipeline.name << ", (" << wg0 << "," << wg1 << "," << wg2 << "))" << std::endl;
1741
+ std::cerr << "ggml_vk_dispatch_pipeline(" << pipeline->name << ", (" << wg0 << "," << wg1 << "," << wg2 << "))" << std::endl;
1536
1742
  #endif
1537
1743
  std::vector<vk::DescriptorBufferInfo> descriptor_buffer_infos;
1538
1744
  std::vector<vk::WriteDescriptorSet> write_descriptor_sets;
1539
- GGML_ASSERT(pipeline.descriptor_set_idx < pipeline.descriptor_sets.size());
1540
- GGML_ASSERT(buffers.size() == pipeline.parameter_count);
1541
- vk::DescriptorSet& descriptor_set = pipeline.descriptor_sets[pipeline.descriptor_set_idx++];
1542
- for (uint32_t i = 0; i < pipeline.parameter_count; i++) {
1745
+ GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size());
1746
+ GGML_ASSERT(buffers.size() == pipeline->parameter_count);
1747
+ vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++];
1748
+ for (uint32_t i = 0; i < pipeline->parameter_count; i++) {
1543
1749
  descriptor_buffer_infos.push_back({buffers[i].buffer->buffer, buffers[i].offset, buffers[i].size});
1544
1750
  }
1545
- for (uint32_t i = 0; i < pipeline.parameter_count; i++) {
1751
+ for (uint32_t i = 0; i < pipeline->parameter_count; i++) {
1546
1752
  write_descriptor_sets.push_back({descriptor_set, i, 0, 1, vk::DescriptorType::eStorageBuffer, nullptr, &descriptor_buffer_infos[i]});
1547
1753
  }
1548
1754
 
1549
- ctx->device.lock()->device.updateDescriptorSets(write_descriptor_sets, {});
1755
+ ctx->device->device.updateDescriptorSets(write_descriptor_sets, {});
1550
1756
 
1551
- subctx->s->buffer.pushConstants(pipeline.layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
1552
- subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline.pipeline);
1757
+ subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
1758
+ subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
1553
1759
  subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
1554
- pipeline.layout,
1760
+ pipeline->layout,
1555
1761
  0,
1556
1762
  { descriptor_set },
1557
1763
  {});
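ggml_vk_dispatch_pipeline now dereferences the shared pipeline handle and derives the workgroup count per axis from the requested element counts and the pipeline's wg_denoms. A self-contained sketch of that rounding, assuming CEIL_DIV is the usual round-up integer division; the element and wg_denom values below are made up for illustration:

    #include <array>
    #include <cstdint>
    #include <iostream>

    // Per-axis round-up division, as the dispatch above performs with CEIL_DIV.
    constexpr uint32_t ceil_div(uint32_t x, uint32_t y) { return (x + y - 1) / y; }

    int main() {
        std::array<uint32_t, 3> elements  = {1000, 70, 1};   // hypothetical problem size
        std::array<uint32_t, 3> wg_denoms = {128, 128, 1};   // illustrative denominators only
        for (int i = 0; i < 3; ++i) {
            std::cout << "axis " << i << ": "
                      << ceil_div(elements[i], wg_denoms[i]) << " workgroups\n";
        }
        // prints 8, 1 and 1 workgroups for this example
    }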
@@ -1810,7 +2016,7 @@ static void ggml_vk_buffer_write_2d(ggml_backend_vk_context * ctx, vk_buffer& ds
1810
2016
  memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
1811
2017
  }
1812
2018
  } else {
1813
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
2019
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
1814
2020
  ggml_vk_ctx_begin(ctx, subctx);
1815
2021
  ggml_vk_buffer_write_2d_async(ctx, subctx, dst, offset, src, spitch, width, height, true);
1816
2022
  ggml_vk_ctx_end(subctx);
@@ -1820,8 +2026,9 @@ static void ggml_vk_buffer_write_2d(ggml_backend_vk_context * ctx, vk_buffer& ds
1820
2026
  }
1821
2027
 
1822
2028
  ggml_vk_submit(subctx, ctx->fence);
1823
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
1824
- ctx->device.lock()->device.resetFences({ ctx->fence });
2029
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
2030
+ ctx->device->device.resetFences({ ctx->fence });
2031
+ ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
1825
2032
  }
1826
2033
  }
1827
2034
 
@@ -1906,18 +2113,19 @@ static void ggml_vk_buffer_read(ggml_backend_vk_context * ctx, vk_buffer& src, s
1906
2113
 
1907
2114
  memcpy(dst, (uint8_t *) src->ptr + offset, size);
1908
2115
  } else {
1909
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
2116
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
1910
2117
  ggml_vk_ctx_begin(ctx, subctx);
1911
2118
  ggml_vk_buffer_read_async(ctx, subctx, src, offset, dst, size, true);
1912
2119
  ggml_vk_ctx_end(subctx);
1913
2120
 
1914
2121
  ggml_vk_submit(subctx, ctx->fence);
1915
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
1916
- ctx->device.lock()->device.resetFences({ ctx->fence });
2122
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
2123
+ ctx->device->device.resetFences({ ctx->fence });
1917
2124
 
1918
2125
  for (auto& cpy : subctx->out_memcpys) {
1919
2126
  memcpy(cpy.dst, cpy.src, cpy.n);
1920
2127
  }
2128
+ ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
1921
2129
  }
1922
2130
  }
1923
2131
 
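After the fence wait, ggml_vk_buffer_read flushes subctx->out_memcpys on the host and, new in this version, calls ggml_vk_queue_cleanup on the transfer queue. A minimal sketch of the deferred-copy flush, assuming out_memcpys records destination, source and size of each pending host copy (deferred_memcpy and flush_out_memcpys are illustrative names, not the backend's types):

    #include <cstddef>
    #include <cstring>
    #include <vector>

    // Each entry is a host-side copy queued during command recording and executed
    // only once the GPU work guarded by the fence has completed.
    struct deferred_memcpy { void *dst; const void *src; size_t n; };

    void flush_out_memcpys(const std::vector<deferred_memcpy> &out_memcpys) {
        for (const auto &cpy : out_memcpys) {
            std::memcpy(cpy.dst, cpy.src, cpy.n);
        }
    }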
@@ -1941,15 +2149,13 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
1941
2149
  // Copy within the device
1942
2150
  ggml_backend_vk_context * ctx = src->ctx;
1943
2151
 
1944
- VkBufferCopy bc{ src_offset, dst_offset, size };
1945
-
1946
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
2152
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
1947
2153
  ggml_vk_ctx_begin(ctx, subctx);
1948
2154
  ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
1949
2155
  ggml_vk_ctx_end(subctx);
1950
2156
  ggml_vk_submit(subctx, ctx->fence);
1951
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
1952
- ctx->device.lock()->device.resetFences({ ctx->fence });
2157
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
2158
+ ctx->device->device.resetFences({ ctx->fence });
1953
2159
  } else {
1954
2160
  #ifdef GGML_VULKAN_DEBUG
1955
2161
  std::cerr << "ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")" << std::endl;
@@ -1977,14 +2183,14 @@ static void ggml_vk_buffer_memset(ggml_backend_vk_context * ctx, vk_buffer& dst,
1977
2183
  // Make sure ctx owns the buffer
1978
2184
  GGML_ASSERT(dst->ctx == ctx);
1979
2185
 
1980
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
2186
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
1981
2187
  ggml_vk_ctx_begin(ctx, subctx);
1982
2188
  subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
1983
2189
  ggml_vk_ctx_end(subctx);
1984
2190
 
1985
2191
  ggml_vk_submit(subctx, ctx->fence);
1986
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_memset waitForFences");
1987
- ctx->device.lock()->device.resetFences({ ctx->fence });
2192
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_memset waitForFences");
2193
+ ctx->device->device.resetFences({ ctx->fence });
1988
2194
  }
1989
2195
 
1990
2196
  static void ggml_vk_h2d_tensor_2d(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& dst, size_t offset, const ggml_tensor * src, uint64_t i3, uint64_t i2, uint64_t i1) {
@@ -2045,176 +2251,63 @@ static void ggml_vk_d2h_tensor_2d(ggml_backend_vk_context * ctx, vk_context * su
2045
2251
 
2046
2252
  static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
2047
2253
  #ifdef GGML_VULKAN_DEBUG
2048
- std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")";
2254
+ std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")" << std::endl;
2049
2255
  #endif
2050
2256
  if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
2051
- #ifdef GGML_VULKAN_DEBUG
2052
- std::cerr << " = 4" << std::endl;
2053
- #endif
2054
2257
  return 4;
2055
2258
  }
2056
2259
 
2057
- #ifdef GGML_VULKAN_DEBUG
2058
- std::cerr << " = 1" << std::endl;
2059
- #endif
2060
2260
  return 1;
2061
2261
  }
2062
2262
 
2063
- static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, int m, int n) {
2064
- #ifdef GGML_VULKAN_DEBUG
2065
- std::cerr << "ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")" << std::endl;
2066
- #endif
2263
+ static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
2067
2264
  if (m <= 32 || n <= 32) {
2068
- return ctx->pipeline_matmul_f32_aligned_s.align;
2069
- }
2070
- if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) {
2071
- return ctx->pipeline_matmul_f32_aligned_m.align;
2265
+ return aligned ? mmp->a_s : mmp->s;
2072
2266
  }
2073
- return ctx->pipeline_matmul_f32_aligned_l.align;
2267
+ return aligned ? mmp->a_m : mmp->m;
2268
+
2269
+ GGML_UNUSED(ctx);
2074
2270
  }
2075
2271
 
2076
- static vk_pipeline* ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
2077
- if (bit16_x && bit16_y) {
2078
- if (m <= 32 || n <= 32) {
2079
- #ifdef GGML_VULKAN_DEBUG
2080
- std::cerr << " S" << std::endl;
2081
- #endif
2082
- return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
2083
- }
2084
- #ifdef GGML_VULKAN_DEBUG
2085
- std::cerr << " M" << std::endl;
2086
- #endif
2087
- return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
2088
- }
2089
- if (bit16_x && !bit16_y) {
2090
- if (m <= 32 || n <= 32) {
2091
- #ifdef GGML_VULKAN_DEBUG
2092
- std::cerr << " S" << std::endl;
2093
- #endif
2094
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
2095
- }
2096
- #ifdef GGML_VULKAN_DEBUG
2097
- std::cerr << " M" << std::endl;
2098
- #endif
2099
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
2100
- }
2101
- if (!bit16_x && bit16_y) {
2102
- GGML_ASSERT(false);
2103
- }
2272
+ static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
2273
+ return aligned ? mmp->a_m : mmp->m;
2104
2274
 
2105
- if (m <= 32 || n <= 32) {
2106
- #ifdef GGML_VULKAN_DEBUG
2107
- std::cerr << " S" << std::endl;
2108
- #endif
2109
- return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
2110
- }
2111
- #ifdef GGML_VULKAN_DEBUG
2112
- std::cerr << " M" << std::endl;
2113
- #endif
2114
- return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
2275
+ GGML_UNUSED(ctx);
2115
2276
  }
2116
2277
 
2117
- static vk_pipeline* ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
2118
- #ifdef GGML_VULKAN_DEBUG
2119
- std::cerr << " M" << std::endl;
2120
- #endif
2121
- if (bit16_x && bit16_y) {
2122
- return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
2123
- }
2124
- if (bit16_x && !bit16_y) {
2125
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
2126
- }
2127
- if (!bit16_x && bit16_y) {
2128
- GGML_ASSERT(false);
2129
- }
2130
- return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
2131
- }
2278
+ static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
2279
+ return aligned ? mmp->a_s : mmp->s;
2132
2280
 
2133
- static vk_pipeline* ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
2134
- #ifdef GGML_VULKAN_DEBUG
2135
- std::cerr << " S" << std::endl;
2136
- #endif
2137
- if (bit16_x && bit16_y) {
2138
- return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
2139
- }
2140
- if (bit16_x && !bit16_y) {
2141
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
2142
- }
2143
- if (!bit16_x && bit16_y) {
2144
- GGML_ASSERT(false);
2145
- }
2146
- return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
2281
+ GGML_UNUSED(ctx);
2147
2282
  }
2148
2283
 
2149
- static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
2284
+ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
2150
2285
  #ifdef GGML_VULKAN_DEBUG
2151
- std::cerr << "ggml_vk_guess_matmul_pipeline(" << bit16_x << ", " << bit16_y << ", " << m << ", " << n << ", " << aligned << ")";
2286
+ std::cerr << "ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")" << std::endl;
2152
2287
  #endif
2153
- switch (ctx->device.lock()->vendor_id) {
2288
+ switch (ctx->device->vendor_id) {
2154
2289
  case VK_VENDOR_ID_AMD:
2155
- return ggml_vk_guess_matmul_pipeline_amd(ctx, bit16_x, bit16_y, m, n, aligned);
2290
+ return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
2156
2291
  case VK_VENDOR_ID_APPLE:
2157
- return ggml_vk_guess_matmul_pipeline_apple(ctx, bit16_x, bit16_y, aligned);
2292
+ return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
2158
2293
  case VK_VENDOR_ID_INTEL:
2159
- return ggml_vk_guess_matmul_pipeline_intel(ctx, bit16_x, bit16_y, aligned);
2160
- }
2161
-
2162
- if (bit16_x && bit16_y) {
2163
- if (m <= 32 || n <= 32) {
2164
- #ifdef GGML_VULKAN_DEBUG
2165
- std::cerr << " S" << std::endl;
2166
- #endif
2167
- return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
2168
- }
2169
- if (m <= 64 || n <= 64) {
2170
- #ifdef GGML_VULKAN_DEBUG
2171
- std::cerr << " M" << std::endl;
2172
- #endif
2173
- return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
2174
- }
2175
- #ifdef GGML_VULKAN_DEBUG
2176
- std::cerr << " L" << std::endl;
2177
- #endif
2178
- return aligned ? &ctx->pipeline_matmul_f16_aligned_l : &ctx->pipeline_matmul_f16_l;
2179
- }
2180
- if (bit16_x && !bit16_y) {
2181
- if (m <= 32 || n <= 32) {
2182
- #ifdef GGML_VULKAN_DEBUG
2183
- std::cerr << " S" << std::endl;
2184
- #endif
2185
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
2186
- }
2187
- if (m <= 64 || n <= 64) {
2188
- #ifdef GGML_VULKAN_DEBUG
2189
- std::cerr << " M" << std::endl;
2190
- #endif
2191
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
2192
- }
2193
- #ifdef GGML_VULKAN_DEBUG
2194
- std::cerr << " L" << std::endl;
2195
- #endif
2196
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_l : &ctx->pipeline_matmul_f16_f32_l;
2197
- }
2198
- if (!bit16_x && bit16_y) {
2199
- GGML_ASSERT(false);
2294
+ return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
2200
2295
  }
2201
2296
 
2202
2297
  if (m <= 32 || n <= 32) {
2203
- #ifdef GGML_VULKAN_DEBUG
2204
- std::cerr << " S" << std::endl;
2205
- #endif
2206
- return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
2298
+ return aligned ? mmp->a_s : mmp->s;
2207
2299
  }
2208
2300
  if (m <= 64 || n <= 64) {
2209
- #ifdef GGML_VULKAN_DEBUG
2210
- std::cerr << " M" << std::endl;
2211
- #endif
2212
- return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
2301
+ return aligned ? mmp->a_m : mmp->m;
2213
2302
  }
2303
+ return aligned ? mmp->a_l : mmp->l;
2304
+ }
2305
+
2306
+ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
2214
2307
  #ifdef GGML_VULKAN_DEBUG
2215
- std::cerr << " L" << std::endl;
2308
+ std::cerr << "ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")" << std::endl;
2216
2309
  #endif
2217
- return aligned ? &ctx->pipeline_matmul_f32_aligned_l : &ctx->pipeline_matmul_f32_l;
2310
+ return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, false)->align;
2218
2311
  }
2219
2312
 
2220
2313
  static void ggml_vk_matmul(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d) {
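Note on the hunks above: the rewritten ggml_vk_guess_matmul_pipeline chooses one of six compiled variants from a vk_matmul_pipeline bundle (small/medium/large tiles, each with an aligned counterpart), with Apple pinned to the medium variant, Intel to the small one, and AMD deferring to its own helper whose new body is not shown here. The sketch below is an editor-added condensation of that logic, not code from the package; it only uses names that appear in this diff.

    // Illustrative restatement of the selection logic; not part of the package source.
    static vk_pipeline pick_matmul_variant(ggml_backend_vk_context * ctx, vk_matmul_pipeline & mmp,
                                           int m, int n, bool aligned) {
        switch (ctx->device->vendor_id) {
        case VK_VENDOR_ID_AMD:   return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
        case VK_VENDOR_ID_APPLE: return aligned ? mmp->a_m : mmp->m;   // always the medium tile
        case VK_VENDOR_ID_INTEL: return aligned ? mmp->a_s : mmp->s;   // always the small tile
        }
        // Generic path: tile size grows with the output dimensions m x n.
        if (m <= 32 || n <= 32) return aligned ? mmp->a_s : mmp->s;
        if (m <= 64 || n <= 64) return aligned ? mmp->a_m : mmp->m;
        return aligned ? mmp->a_l : mmp->l;
    }

ggml_vk_guess_matmul_pipeline_align then simply reads the align field of the variant that would be picked for the unaligned case, and the caller compares ne10 against that padding to decide whether the aligned shader may be used.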
@@ -2232,10 +2325,10 @@ static void ggml_vk_matmul(ggml_backend_vk_context * ctx, vk_context * subctx, v
2232
2325
 
2233
2326
  const std::array<uint32_t, 14> pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d };
2234
2327
  // Make sure enough workgroups get assigned for split k to work
2235
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1.size() * sizeof(uint32_t), pc1.data(), { (CEIL_DIV(m, pipeline.wg_denoms[0]) * pipeline.wg_denoms[0]) * split_k, n, batch });
2328
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1.size() * sizeof(uint32_t), pc1.data(), { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
2236
2329
  ggml_vk_sync_buffers(subctx);
2237
2330
  const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
2238
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
2331
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
2239
2332
  }
2240
2333
 
2241
2334
  static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
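The split-k branch of ggml_vk_matmul above launches the matmul shader with split_k times the usual workgroup count so that each slice of the K dimension writes a partial product into split_k_buffer, then runs pipeline_matmul_split_k_reduce over all m*n*batch outputs to sum the slices. Below is a CPU-equivalent of that second pass, added here for orientation only; the slice-major layout of the partial buffer is an assumption inferred from pc2 = { m*n*batch, split_k }.

    // Reference version of what the split_k_reduce dispatch is expected to compute.
    static void split_k_reduce_reference(const float * partials, float * d,
                                         uint32_t ne /* m*n*batch */, uint32_t split_k) {
        for (uint32_t i = 0; i < ne; ++i) {            // one invocation per output element
            float acc = 0.0f;
            for (uint32_t k = 0; k < split_k; ++k) {
                acc += partials[k * ne + i];           // assumed slice-major partial layout
            }
            d[i] = acc;
        }
    }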
@@ -2245,41 +2338,39 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
2245
2338
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
2246
2339
  }
2247
2340
 
2248
- static vk_pipeline * ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_type from, ggml_type to) {
2341
+ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_type from, ggml_type to) {
2249
2342
  if (from == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
2250
- return &ctx->pipeline_cpy_f32_f32;
2343
+ return ctx->device->pipeline_cpy_f32_f32;
2251
2344
  }
2252
2345
  if (from == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
2253
- return &ctx->pipeline_cpy_f32_f16;
2346
+ return ctx->device->pipeline_cpy_f32_f16;
2254
2347
  }
2255
2348
  if (from == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
2256
- return &ctx->pipeline_cpy_f16_f16;
2349
+ return ctx->device->pipeline_cpy_f16_f16;
2257
2350
  }
2258
2351
 
2259
2352
  std::cerr << "Missing CPY op for types: " << ggml_type_name(from) << " " << ggml_type_name(to) << std::endl;
2260
2353
  GGML_ASSERT(false);
2261
2354
  }
2262
2355
 
2263
- static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline * pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out, ggml_type buffer_type, bool aligned=true) {
2356
+ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) {
2264
2357
  #ifdef GGML_VULKAN_DEBUG
2265
2358
  std::cerr << "ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", backend=" << tensor->backend << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
2266
2359
  std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")" << std::endl;
2267
2360
  #endif
2268
2361
  const int tensor_type_size = ggml_type_size(tensor->type);
2269
- const int dst_type_size = ggml_type_size(buffer_type);
2270
-
2271
- const uint32_t ne = tensor->ne[0] * tensor->ne[1] * tensor->ne[2];
2272
2362
 
2273
- const uint32_t nb2 = aligned ? ggml_vk_align_size(dst_type_size * tensor->ne[0] * tensor->ne[1], ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size : tensor->ne[0] * tensor->ne[1];
2363
+ const uint32_t ne = ggml_nelements(tensor);
2274
2364
 
2275
- const vk_op_cpy_push_constants pc = {
2365
+ const vk_op_unary_push_constants pc = {
2276
2366
  (uint32_t)ne,
2277
- (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size,
2278
- (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], 1 , (uint32_t)tensor->ne[0] , nb2,
2367
+ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
2368
+ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
2279
2369
  0,
2370
+ 0.0f, 0.0f,
2280
2371
  };
2281
2372
  ggml_vk_sync_buffers(subctx);
2282
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { in, out }, sizeof(vk_op_cpy_push_constants), &pc, { ne, 1, 1 });
2373
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 });
2283
2374
  }
2284
2375
 
2285
2376
  static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
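ggml_vk_cpy_to_contiguous above now reuses the generic unary-op push-constant block instead of the old cpy-specific one: it passes the full four-dimensional shape, the source strides in elements, contiguous destination strides, a destination offset of zero and two unused float parameters. The struct below reconstructs that layout purely from the order of the brace initializer; the field names are invented for readability and are not the names used in the package.

    // Hypothetical layout matching the initializer above (field names are editor guesses).
    struct vk_op_unary_push_constants_sketch {
        uint32_t ne;                         // total element count
        uint32_t ne00, ne01, ne02, ne03;     // source shape
        uint32_t nb00, nb01, nb02, nb03;     // source strides, in elements
        uint32_t ne10, ne11, ne12, ne13;     // destination shape
        uint32_t nb10, nb11, nb12, nb13;     // destination strides, in elements
        uint32_t d_offset;                   // destination element offset
        float    param1, param2;             // per-op parameters (scale factor, clamp bounds, ...)
    };

The same block is filled in again further down this diff for GGML_OP_SCALE, GGML_OP_SQR, GGML_OP_CLAMP and GGML_OP_CPY, which is why those wrappers now pass complete shapes and strides.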
@@ -2319,7 +2410,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2319
2410
  bool src0_uma = false;
2320
2411
  bool src1_uma = false;
2321
2412
 
2322
- if (ctx->device.lock()->uma) {
2413
+ if (ctx->device->uma) {
2323
2414
  ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
2324
2415
  ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
2325
2416
  src0_uma = d_Qx != nullptr;
@@ -2332,10 +2423,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2332
2423
  const bool x_non_contig = !load_x && !ggml_vk_dim01_contiguous(src0);
2333
2424
  const bool y_non_contig = !load_y && !ggml_vk_dim01_contiguous(src1);
2334
2425
 
2335
- const bool f16_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
2426
+ const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
2336
2427
 
2337
- const bool qx_needs_dequant = src0->type != GGML_TYPE_F16 || x_non_contig;
2338
- const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
2428
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
2429
+
2430
+ const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
2431
+ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
2432
+
2433
+ if (mmp == nullptr) {
2434
+ // Fall back to dequant + f16 mulmat
2435
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
2436
+ }
2339
2437
 
2340
2438
  // Not implemented
2341
2439
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
@@ -2344,17 +2442,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2344
2442
  const int y_ne = ne11 * ne10;
2345
2443
  const int d_ne = ne11 * ne01;
2346
2444
 
2347
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, ne01, ne11));
2445
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
2348
2446
  const bool aligned = ne10 == kpad;
2349
2447
 
2350
2448
  const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
2351
2449
 
2352
- vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(ctx, true, !f16_f32_kernel, ne01, ne11, aligned);
2450
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
2353
2451
 
2354
2452
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
2355
2453
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
2356
- const uint64_t x_sz = sizeof(ggml_fp16_t) * x_ne;
2357
- const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
2454
+ const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
2455
+ const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
2358
2456
  const uint64_t d_sz = sizeof(float) * d_ne;
2359
2457
 
2360
2458
  vk_buffer d_D = extra->buffer_gpu.lock();
@@ -2385,7 +2483,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2385
2483
  } else {
2386
2484
  d_X = d_Qx;
2387
2485
  x_buf_offset = qx_buf_offset;
2388
- GGML_ASSERT(qx_sz == x_sz); // NOLINT
2486
+ GGML_ASSERT(qx_sz == x_sz);
2389
2487
  }
2390
2488
  if (qy_needs_dequant) {
2391
2489
  d_Y = ctx->prealloc_y;
@@ -2396,8 +2494,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2396
2494
  GGML_ASSERT(qy_sz == y_sz);
2397
2495
  }
2398
2496
 
2399
- vk_pipeline * to_fp16_vk_0 = nullptr;
2400
- vk_pipeline * to_fp16_vk_1 = nullptr;
2497
+ vk_pipeline to_fp16_vk_0 = nullptr;
2498
+ vk_pipeline to_fp16_vk_1 = nullptr;
2401
2499
 
2402
2500
  if (x_non_contig) {
2403
2501
  to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
@@ -2413,19 +2511,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2413
2511
  GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
2414
2512
 
2415
2513
  // Allocate descriptor sets
2416
- ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, ne12 * ne13);
2514
+ ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1);
2417
2515
  if (qx_needs_dequant) {
2418
- ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_0, x_non_contig ? 1 : ne12 * ne13);
2516
+ ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
2419
2517
  }
2420
2518
  if (qy_needs_dequant) {
2421
- ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
2519
+ ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, 1);
2422
2520
  }
2423
2521
  if (split_k > 1) {
2424
- ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_matmul_split_k_reduce, ne12 * ne13);
2522
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1);
2425
2523
  }
2426
2524
 
2427
2525
  if (x_non_contig) {
2428
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }, dst->type, false);
2526
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
2429
2527
  } else if (load_x || qx_needs_dequant) {
2430
2528
  if (load_x) {
2431
2529
  // copy data to device
@@ -2434,13 +2532,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2434
2532
  }
2435
2533
 
2436
2534
  if (qx_needs_dequant) {
2437
- const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, (int)ne10 };
2535
+ const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
2438
2536
  ggml_vk_sync_buffers(subctx);
2439
- ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
2537
+ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
2440
2538
  }
2441
2539
  }
2442
2540
  if (y_non_contig) {
2443
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, dst->type);
2541
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
2444
2542
  } else if (load_y) {
2445
2543
  ggml_vk_h2d_tensor_2d(ctx, subctx, d_Qy, 0, src1, 0, 0, ggml_nrows(src1));
2446
2544
  }
@@ -2457,7 +2555,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2457
2555
  }
2458
2556
 
2459
2557
  // compute
2460
- ggml_vk_matmul(ctx, subctx, *pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21); // NOLINT
2558
+ ggml_vk_matmul(ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21); // NOLINT
2461
2559
 
2462
2560
  if (dst->backend == GGML_BACKEND_TYPE_CPU) {
2463
2561
  // copy dst to host
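In ggml_vk_mul_mat_q_f16 the matmul pipeline is now looked up per input-type pair, and a null result means there is no direct kernel for src0's quantization, so the code dequantizes src0 to F16 and falls back to the F16 matmul. The wrapper below is an editor-added condensation of that decision, gathered from the hunks above; it introduces no new package names. Because vk_pipeline is a nullable handle rather than a raw pointer, the same null checks also cover the to_fp16_vk_* copy pipelines.

    // Condensed view of the pipeline selection spread across the hunks above (illustrative only).
    static vk_pipeline select_mul_mat_pipeline(ggml_backend_vk_context * ctx,
                                               const ggml_tensor * src0, const ggml_tensor * src1,
                                               bool x_non_contig, bool y_non_contig, bool y_f32_kernel,
                                               int ne01, int ne11, bool aligned,
                                               bool & qx_needs_dequant) {
        // y_non_contig forces src1 through the F16 copy shader first, hence GGML_TYPE_F16.
        vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(
            ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
        qx_needs_dequant = mmp == nullptr || x_non_contig;
        if (mmp == nullptr) {
            // No direct quantized kernel: dequantize src0 and fall back to the F16 matmul.
            mmp = ggml_vk_get_mul_mat_mat_pipeline(
                ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
        }
        return ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
    }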
@@ -2505,7 +2603,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2505
2603
  bool src0_uma = false;
2506
2604
  bool src1_uma = false;
2507
2605
 
2508
- if (ctx->device.lock()->uma) {
2606
+ if (ctx->device->uma) {
2509
2607
  ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
2510
2608
  ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
2511
2609
  src0_uma = d_Qx != nullptr;
@@ -2527,9 +2625,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2527
2625
  const uint64_t y_ne = ne11 * ne10;
2528
2626
  const uint64_t d_ne = ne11 * ne01;
2529
2627
 
2530
- const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
2628
+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
2531
2629
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
2532
- const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
2630
+ const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
2533
2631
  const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
2534
2632
  const uint64_t d_sz = sizeof(float) * d_ne;
2535
2633
 
@@ -2569,8 +2667,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2569
2667
  GGML_ASSERT(qy_sz == y_sz);
2570
2668
  }
2571
2669
 
2572
- vk_pipeline * to_fp16_vk_0 = nullptr;
2573
- vk_pipeline* to_fp16_vk_1 = nullptr;
2670
+ vk_pipeline to_fp16_vk_0 = nullptr;
2671
+ vk_pipeline to_fp16_vk_1 = nullptr;
2574
2672
  if (x_non_contig) {
2575
2673
  to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
2576
2674
  }
@@ -2579,30 +2677,30 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2579
2677
  } else {
2580
2678
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
2581
2679
  }
2582
- vk_pipeline* dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type);
2680
+ vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type);
2583
2681
  GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
2584
2682
  GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
2585
2683
  GGML_ASSERT(dmmv != nullptr);
2586
2684
 
2587
2685
  // Allocate descriptor sets
2588
2686
  if (qx_needs_dequant) {
2589
- ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_0, 1);
2687
+ ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
2590
2688
  }
2591
2689
  if (qy_needs_dequant) {
2592
- ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
2690
+ ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
2593
2691
  }
2594
- ggml_pipeline_allocate_descriptor_sets(ctx, *dmmv, ne12 * ne13);
2692
+ ggml_pipeline_allocate_descriptor_sets(ctx, dmmv, ne12 * ne13);
2595
2693
 
2596
2694
  if (x_non_contig) {
2597
- GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment));
2598
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }, src0->type);
2695
+ GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
2696
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
2599
2697
  } else if (load_x) {
2600
2698
  // copy data to device
2601
2699
  ggml_vk_h2d_tensor_2d(ctx, subctx, d_Qx, 0, src0, 0, 0, ggml_nrows(src0));
2602
2700
  }
2603
2701
  if (y_non_contig) {
2604
2702
  GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
2605
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, src1->type);
2703
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
2606
2704
  } else if (load_y) {
2607
2705
  ggml_vk_h2d_tensor_2d(ctx, subctx, d_Qy, 0, src1, 0, 0, ggml_nrows(src1));
2608
2706
  }
@@ -2619,22 +2717,22 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2619
2717
  const uint64_t y_offset = y_buf_offset + y_sz * it_idx1;
2620
2718
  const uint64_t d_offset = d_buf_offset + d_sz * it_idx1;
2621
2719
 
2622
- const uint64_t y_buffer_offset = (y_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2720
+ const uint64_t y_buffer_offset = (y_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2623
2721
  const uint64_t y_shader_offset = y_offset - y_buffer_offset;
2624
2722
 
2625
- const uint64_t d_buffer_offset = (d_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2723
+ const uint64_t d_buffer_offset = (d_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2626
2724
  const uint64_t d_shader_offset = d_offset - d_buffer_offset;
2627
2725
 
2628
2726
  if (!y_non_contig && qy_needs_dequant) {
2629
- const std::vector<int> pc = { (int)ne11, (int)ne10, (int)ne10, (int)ne10 };
2727
+ const std::vector<uint32_t> pc = { (uint32_t)ne11, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(y_ne / 32) };
2630
2728
  ggml_vk_sync_buffers(subctx);
2631
- ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)y_ne, 1, 1});
2729
+ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)y_ne, 1, 1});
2632
2730
  }
2633
2731
 
2634
2732
  // compute
2635
- const std::array<int, 3> pc = { (int)ne00, (int)(y_shader_offset / ggml_type_size(src1->type)), (int)(d_shader_offset / ggml_type_size(dst->type))};
2733
+ const std::array<uint32_t, 3> pc = { (uint32_t)ne00, (uint32_t)(y_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type))};
2636
2734
  ggml_vk_sync_buffers(subctx);
2637
- ggml_vk_dispatch_pipeline(ctx, subctx, *dmmv, { { d_X, x_offset, x_sz }, { d_Y, y_buffer_offset, y_sz + y_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 3 * sizeof(int), &pc, { (uint32_t)ne01, 1, 1});
2735
+ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_offset, x_sz }, { d_Y, y_buffer_offset, y_sz + y_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 3 * sizeof(int), &pc, { (uint32_t)ne01, 1, 1});
2638
2736
 
2639
2737
  if (dst->backend == GGML_BACKEND_TYPE_CPU) {
2640
2738
  // copy dst to host
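The mat-vec path above repeatedly splits an arbitrary byte offset into a binding offset that satisfies minStorageBufferOffsetAlignment plus a remainder that the shader receives through push constants, converted to elements of the tensor's type. The helper below is not in the package; it only spells out that arithmetic.

    // Align a byte offset down to the device's storage-buffer alignment and keep the
    // remainder so the shader can index past it; mirrors the pattern used above.
    struct aligned_offset { uint64_t buffer_offset; uint64_t shader_offset; };

    static aligned_offset split_offset(uint64_t offset, uint64_t min_align) {
        aligned_offset r;
        r.buffer_offset = (offset / min_align) * min_align;  // where the descriptor is bound
        r.shader_offset = offset - r.buffer_offset;          // leftover bytes, handed to the shader
        return r;
    }

For example, the dmmv dispatch passes y_shader_offset / ggml_type_size(src1->type), i.e. the leftover converted to src1 elements.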
@@ -2680,7 +2778,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
2680
2778
 
2681
2779
  bool src1_uma = false;
2682
2780
 
2683
- if (ctx->device.lock()->uma) {
2781
+ if (ctx->device->uma) {
2684
2782
  ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
2685
2783
  src1_uma = d_Qy != nullptr;
2686
2784
  }
@@ -2691,7 +2789,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
2691
2789
  const uint64_t y_ne = ne10 * ne11 * ne12;
2692
2790
  const uint64_t d_ne = ne01 * ne11 * ne12;
2693
2791
 
2694
- const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
2792
+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
2695
2793
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
2696
2794
  const uint64_t d_sz = sizeof(float) * d_ne;
2697
2795
 
@@ -2710,12 +2808,12 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
2710
2808
  }
2711
2809
 
2712
2810
  // Allocate descriptor sets
2713
- ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, 1);
2811
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
2714
2812
 
2715
- const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2813
+ const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2716
2814
  const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
2717
2815
 
2718
- const uint64_t d_buffer_offset = (d_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2816
+ const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2719
2817
  const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
2720
2818
 
2721
2819
  if (load_y) {
@@ -2725,7 +2823,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
2725
2823
  // compute
2726
2824
  const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
2727
2825
  ggml_vk_sync_buffers(subctx);
2728
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
2826
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
2729
2827
 
2730
2828
  if (dst->backend == GGML_BACKEND_TYPE_CPU) {
2731
2829
  // copy dst to host
@@ -2772,7 +2870,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
2772
2870
 
2773
2871
  bool src1_uma = false;
2774
2872
 
2775
- if (ctx->device.lock()->uma) {
2873
+ if (ctx->device->uma) {
2776
2874
  ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
2777
2875
  src1_uma = d_Qy != nullptr;
2778
2876
  }
@@ -2803,12 +2901,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
2803
2901
  }
2804
2902
 
2805
2903
  // Allocate descriptor sets
2806
- ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, 1);
2904
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
2807
2905
 
2808
- const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2906
+ const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2809
2907
  const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
2810
2908
 
2811
- const uint64_t d_buffer_offset = (d_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2909
+ const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2812
2910
  const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
2813
2911
 
2814
2912
  if (load_y) {
@@ -2818,7 +2916,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
2818
2916
  // compute
2819
2917
  const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
2820
2918
  ggml_vk_sync_buffers(subctx);
2821
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
2919
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
2822
2920
 
2823
2921
  if (dst->backend == GGML_BACKEND_TYPE_CPU) {
2824
2922
  // copy dst to host
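Both specialized f16-by-f32 vector kernels above bind the tensors as they are and describe the layout through push constants; the nc variant additionally receives ne12 / ne02 so one src0 channel can serve several src1/dst channels. The shader source is not part of this hunk, so the mapping below is an assumption about what that ratio is used for, following the usual ggml broadcast convention.

    // Assumed channel-broadcast mapping behind the (ne12 / ne02) push constant above.
    static uint32_t src0_channel_for(uint32_t dst_channel, uint32_t ne02, uint32_t ne12) {
        const uint32_t channel_ratio = ne12 / ne02;  // dst channels sharing one src0 channel
        return dst_channel / channel_ratio;          // dst_channel is in [0, ne12)
    }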
@@ -2856,6 +2954,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx,
2856
2954
  }
2857
2955
  }
2858
2956
 
2957
+ // static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
2958
+ //
2959
+ // }
2960
+
2859
2961
  static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2860
2962
  // guaranteed to be an integer due to the check in ggml_can_repeat
2861
2963
  const uint64_t ne0 = dst->ne[0];
@@ -2927,40 +3029,40 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx
2927
3029
  }
2928
3030
 
2929
3031
 
2930
- static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op) {
3032
+ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
2931
3033
  switch (op) {
2932
3034
  case GGML_OP_ADD:
2933
3035
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2934
- return &ctx->pipeline_add_f32;
3036
+ return ctx->device->pipeline_add_f32;
2935
3037
  }
2936
3038
  return nullptr;
2937
3039
  case GGML_OP_GET_ROWS:
2938
3040
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
2939
3041
  if (dst->type == GGML_TYPE_F16) {
2940
- return &ctx->pipeline_get_rows[src0->type];
3042
+ return ctx->device->pipeline_get_rows[src0->type];
2941
3043
  }
2942
3044
  if (dst->type == GGML_TYPE_F32) {
2943
- return &ctx->pipeline_get_rows_f32[src0->type];
3045
+ return ctx->device->pipeline_get_rows_f32[src0->type];
2944
3046
  }
2945
3047
  return nullptr;
2946
3048
  case GGML_OP_MUL:
2947
3049
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2948
- return &ctx->pipeline_mul_f32;
3050
+ return ctx->device->pipeline_mul_f32;
2949
3051
  }
2950
3052
  return nullptr;
2951
3053
  case GGML_OP_SCALE:
2952
3054
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2953
- return &ctx->pipeline_scale_f32;
3055
+ return ctx->device->pipeline_scale_f32;
2954
3056
  }
2955
3057
  return nullptr;
2956
3058
  case GGML_OP_SQR:
2957
3059
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2958
- return &ctx->pipeline_sqr_f32;
3060
+ return ctx->device->pipeline_sqr_f32;
2959
3061
  }
2960
3062
  return nullptr;
2961
3063
  case GGML_OP_CLAMP:
2962
3064
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2963
- return &ctx->pipeline_clamp_f32;
3065
+ return ctx->device->pipeline_clamp_f32;
2964
3066
  }
2965
3067
  return nullptr;
2966
3068
  case GGML_OP_CPY:
@@ -2969,29 +3071,29 @@ static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
2969
3071
  return ggml_vk_get_cpy_pipeline(ctx, src0->type, dst->type);
2970
3072
  case GGML_OP_NORM:
2971
3073
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2972
- return &ctx->pipeline_norm_f32;
3074
+ return ctx->device->pipeline_norm_f32;
2973
3075
  }
2974
3076
  return nullptr;
2975
3077
  case GGML_OP_RMS_NORM:
2976
3078
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2977
- return &ctx->pipeline_rms_norm_f32;
3079
+ return ctx->device->pipeline_rms_norm_f32;
2978
3080
  }
2979
3081
  return nullptr;
2980
3082
  case GGML_OP_UNARY:
2981
3083
  switch (ggml_get_unary_op(dst)) {
2982
3084
  case GGML_UNARY_OP_SILU:
2983
3085
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2984
- return &ctx->pipeline_silu_f32;
3086
+ return ctx->device->pipeline_silu_f32;
2985
3087
  }
2986
3088
  break;
2987
3089
  case GGML_UNARY_OP_GELU:
2988
3090
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2989
- return &ctx->pipeline_gelu_f32;
3091
+ return ctx->device->pipeline_gelu_f32;
2990
3092
  }
2991
3093
  break;
2992
3094
  case GGML_UNARY_OP_RELU:
2993
3095
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2994
- return &ctx->pipeline_relu_f32;
3096
+ return ctx->device->pipeline_relu_f32;
2995
3097
  }
2996
3098
  break;
2997
3099
  default:
@@ -3000,12 +3102,12 @@ static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
3000
3102
  return nullptr;
3001
3103
  case GGML_OP_DIAG_MASK_INF:
3002
3104
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3003
- return &ctx->pipeline_diag_mask_inf_f32;
3105
+ return ctx->device->pipeline_diag_mask_inf_f32;
3004
3106
  }
3005
3107
  return nullptr;
3006
3108
  case GGML_OP_SOFT_MAX:
3007
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3008
- return &ctx->pipeline_soft_max_f32;
3109
+ if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
3110
+ return ctx->device->pipeline_soft_max_f32;
3009
3111
  }
3010
3112
  return nullptr;
3011
3113
  case GGML_OP_ROPE:
@@ -3020,21 +3122,26 @@ static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
3020
3122
 
3021
3123
  if (is_neox) {
3022
3124
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3023
- return &ctx->pipeline_rope_neox_f32;
3125
+ return ctx->device->pipeline_rope_neox_f32;
3024
3126
  }
3025
3127
  if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
3026
- return &ctx->pipeline_rope_neox_f16;
3128
+ return ctx->device->pipeline_rope_neox_f16;
3027
3129
  }
3028
3130
  } else {
3029
3131
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3030
- return &ctx->pipeline_rope_f32;
3132
+ return ctx->device->pipeline_rope_f32;
3031
3133
  }
3032
3134
  if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
3033
- return &ctx->pipeline_rope_f16;
3135
+ return ctx->device->pipeline_rope_f16;
3034
3136
  }
3035
3137
  }
3036
3138
  return nullptr;
3037
3139
  }
3140
+ case GGML_OP_ARGSORT:
3141
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
3142
+ return ctx->device->pipeline_argsort_f32;
3143
+ }
3144
+ return nullptr;
3038
3145
  default:
3039
3146
  return nullptr;
3040
3147
  }
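ggml_vk_op_get_pipeline above now takes an optional third source (so far only the F32 soft_max path checks it), pulls every pipeline from ctx->device, and still returns a null handle for unsupported type combinations; ggml_vk_op_f32 below checks for that and resolves a per-op fallback through ggml_vk_op_get_func instead. One detail worth flagging in the later ggml_vk_op_f32 hunk: the UMA branch for src2 appears to call ggml_vk_host_get with src1->data, where src2->data was presumably intended; the diff is reproduced here as released. The sketch below shows the lookup-or-fallback idiom; the fallback call signature is an assumption based on ggml_vk_op_repeat, not a quote from the package.

    // Editor sketch of the dispatch idiom in ggml_vk_op_f32; nullptr comparisons work
    // because vk_pipeline is a nullable handle.
    static bool try_dispatch_op(ggml_backend_vk_context * ctx, vk_context * subctx,
                                const ggml_tensor * src0, const ggml_tensor * src1,
                                const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
        vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
        if (pipeline == nullptr) {
            ggml_vk_func_t op_func = ggml_vk_op_get_func(op);   // e.g. the repeat fallback
            GGML_ASSERT(op_func != nullptr);
            op_func(ctx, subctx, src0, src1, dst);               // assumed fallback signature
            return false;                                        // handled without a compute shader
        }
        // ... bind buffers, fill push constants, ggml_vk_dispatch_pipeline(...) ...
        return true;
    }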
@@ -3050,17 +3157,19 @@ static ggml_vk_func_t ggml_vk_op_get_func(ggml_op op) {
3050
3157
  }
3051
3158
 
3052
3159
  template<typename PC>
3053
- static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op, const PC&& pc) {
3160
+ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) {
3054
3161
  #ifdef GGML_VULKAN_DEBUG
3055
3162
  std::cerr << "ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
3056
3163
  if (src1 != nullptr) {
3057
3164
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
3058
3165
  }
3166
+ if (src2 != nullptr) {
3167
+ std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", backend=" << src2->backend << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
3168
+ }
3059
3169
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "), " << ggml_op_name(op) << ")" << std::endl;
3060
3170
  #endif
3061
3171
  GGML_ASSERT(!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type))); // NOLINT
3062
3172
  GGML_ASSERT(op == GGML_OP_CPY || ggml_vk_dim01_contiguous(src0)); // NOLINT
3063
- GGML_ASSERT(src1 == nullptr || ggml_vk_dim01_contiguous(src1)); // NOLINT
3064
3173
  GGML_ASSERT(dst->extra != nullptr);
3065
3174
  const uint64_t ne00 = src0->ne[0];
3066
3175
  const uint64_t ne01 = src0->ne[1];
@@ -3077,7 +3186,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3077
3186
  const uint64_t nb2 = dst->nb[2];
3078
3187
  const uint64_t nb3 = dst->nb[3];
3079
3188
 
3080
- vk_pipeline * pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, dst, op);
3189
+ const bool use_src2 = src2 != nullptr;
3190
+ const uint64_t ne2 = use_src2 ? src2->ne[0] * src2->ne[1] : 0;
3191
+
3192
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
3081
3193
  ggml_vk_func_t op_func;
3082
3194
 
3083
3195
  if (pipeline == nullptr) {
@@ -3098,29 +3210,39 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3098
3210
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
3099
3211
  ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
3100
3212
  ggml_tensor_extra_gpu * extra_src1 = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
3213
+ ggml_tensor_extra_gpu * extra_src2 = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
3101
3214
 
3102
3215
  vk_buffer d_X = nullptr;
3103
3216
  size_t x_buf_offset = 0;
3104
3217
  vk_buffer d_Y = nullptr;
3105
3218
  size_t y_buf_offset = 0;
3219
+ vk_buffer d_Z = nullptr;
3220
+ size_t z_buf_offset = 0;
3106
3221
 
3107
3222
  bool src0_uma = false;
3108
3223
  bool src1_uma = false;
3224
+ bool src2_uma = false;
3109
3225
 
3110
- if (ctx->device.lock()->uma) {
3226
+ if (ctx->device->uma) {
3111
3227
  ggml_vk_host_get(ctx, src0->data, d_X, x_buf_offset);
3112
3228
  src0_uma = d_X != nullptr;
3113
3229
  if (use_src1) {
3114
3230
  ggml_vk_host_get(ctx, src1->data, d_Y, y_buf_offset);
3115
3231
  src1_uma = d_Y != nullptr;
3116
3232
  }
3233
+ if (use_src2) {
3234
+ ggml_vk_host_get(ctx, src1->data, d_Z, z_buf_offset);
3235
+ src2_uma = d_Z != nullptr;
3236
+ }
3117
3237
  }
3118
3238
 
3119
3239
  const bool transfer_src0 = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma;
3120
3240
  const bool transfer_src1 = use_src1 && src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
3241
+ const bool transfer_src2 = use_src2 && src2->backend != GGML_BACKEND_TYPE_GPU && !src2_uma;
3121
3242
 
3122
- uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type) * ne0, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
3123
- uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) : 0;
3243
+ uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type) * ne0, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
3244
+ uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
3245
+ uint64_t z_sz = use_src2 ? ggml_vk_align_size(ggml_type_size(src2->type) * ne2, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
3124
3246
  uint64_t d_sz = ggml_type_size(dst->type) * ne0;
3125
3247
 
3126
3248
  vk_buffer d_D = extra->buffer_gpu.lock();
@@ -3131,7 +3253,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3131
3253
  }
3132
3254
 
3133
3255
  GGML_ASSERT(d_D != nullptr);
3134
- uint64_t d_buf_offset = (extra->offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
3256
+ uint64_t d_buf_offset = (extra->offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
3135
3257
  GGML_ASSERT(d_buf_offset == extra->offset || op == GGML_OP_CPY); // NOLINT
3136
3258
  if (transfer_src0) {
3137
3259
  d_X = ctx->prealloc_qx;
@@ -3148,6 +3270,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3148
3270
  GGML_ASSERT(d_Y != nullptr);
3149
3271
  }
3150
3272
 
3273
+ GGML_ASSERT(!transfer_src2);
3274
+ if (use_src2 && !src2_uma) {
3275
+ d_Z = extra_src2->buffer_gpu.lock();
3276
+ z_buf_offset = extra_src2->offset;
3277
+ GGML_ASSERT(d_Z != nullptr);
3278
+ }
3279
+
3151
3280
  if (op == GGML_OP_CPY) {
3152
3281
  GGML_ASSERT(!transfer_src0);
3153
3282
  GGML_ASSERT(!transfer_src1);
@@ -3175,7 +3304,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3175
3304
 
3176
3305
  // Single call if dimension 2 is contiguous
3177
3306
  if (op == GGML_OP_CPY || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) {
3178
- ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, 1);
3307
+ ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1);
3179
3308
 
3180
3309
  switch (dst->op) {
3181
3310
  case GGML_OP_NORM:
@@ -3204,16 +3333,30 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3204
3333
  }
3205
3334
  }
3206
3335
 
3207
- if (!use_src1 && op == GGML_OP_SOFT_MAX) {
3208
- // Empty src1 is possible on soft_max, but the shader needs a buffer
3336
+ if (op == GGML_OP_SOFT_MAX) {
3337
+ // Empty src1 and src2 are possible on soft_max, but the shader needs buffers
3338
+ vk_subbuffer subbuf_y;
3339
+ if (use_src1) {
3340
+ subbuf_y = { d_Y, y_buf_offset, y_sz };
3341
+ } else {
3342
+ subbuf_y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
3343
+ }
3344
+
3345
+ vk_subbuffer subbuf_z;
3346
+ if (use_src2) {
3347
+ subbuf_z = { d_Z, z_buf_offset, z_sz };
3348
+ } else {
3349
+ subbuf_z = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
3350
+ }
3351
+
3209
3352
  ggml_vk_sync_buffers(subctx);
3210
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { ctx->prealloc_y, 0, ctx->prealloc_y->size }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3353
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3211
3354
  } else if (use_src1) {
3212
3355
  ggml_vk_sync_buffers(subctx);
3213
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3356
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3214
3357
  } else {
3215
3358
  ggml_vk_sync_buffers(subctx);
3216
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3359
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3217
3360
  }
3218
3361
  if (dst->backend == GGML_BACKEND_TYPE_CPU && op == GGML_OP_CPY) {
3219
3362
  ggml_vk_d2h_tensor_2d(ctx, subctx, d_D, 0, dst);
@@ -3223,7 +3366,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3223
3366
  ggml_vk_buffer_read_async(ctx, subctx, d_D, 0, d, d_sz);
3224
3367
  }
3225
3368
  } else {
3226
- ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, ne02 * ne03);
3369
+ GGML_ASSERT(op != GGML_OP_SOFT_MAX);
3370
+
3371
+ ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, ne02 * ne03);
3227
3372
 
3228
3373
  switch (dst->op) {
3229
3374
  case GGML_OP_NORM:
@@ -3248,16 +3393,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3248
3393
  const uint32_t y_offset = y_sz * it_idx1;
3249
3394
  const uint32_t d_offset = d_sz * it_idx0;
3250
3395
 
3251
- if (!use_src1 && op == GGML_OP_SOFT_MAX) {
3252
- // Empty src1 is possible on soft_max, but the shader needs a buffer
3396
+ if (use_src1) {
3253
3397
  ggml_vk_sync_buffers(subctx);
3254
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { ctx->prealloc_y, 0, ctx->prealloc_y->size }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3255
- } else if (use_src1) {
3256
- ggml_vk_sync_buffers(subctx);
3257
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_Y, y_buf_offset + y_offset, y_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
3398
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_Y, y_buf_offset + y_offset, y_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
3258
3399
  } else {
3259
3400
  ggml_vk_sync_buffers(subctx);
3260
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
3401
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
3261
3402
  }
3262
3403
  if (dst->backend == GGML_BACKEND_TYPE_CPU) {
3263
3404
  // copy dst to host
@@ -3269,69 +3410,141 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3269
3410
  }
3270
3411
 
3271
3412
  static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3272
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3413
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3273
3414
  }
3274
3415
 
3275
3416
  static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3276
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_GET_ROWS, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3417
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3277
3418
  }
3278
3419
 
3279
3420
  static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3280
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ADD, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3421
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3422
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
3423
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3424
+
3425
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, {
3426
+ (uint32_t)ggml_nelements(src0),
3427
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3428
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
3429
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3430
+ 0,
3431
+ 0.0f, 0.0f,
3432
+ });
3281
3433
  }
3282
3434
 
3283
3435
  static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3284
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_MUL, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3436
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3437
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
3438
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3439
+
3440
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, {
3441
+ (uint32_t)ggml_nelements(src0),
3442
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3443
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
3444
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3445
+ 0,
3446
+ 0.0f, 0.0f,
3447
+ });
3285
3448
  }
3286
3449
 
3287
3450
  static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3288
3451
  float * op_params = (float *)dst->op_params;
3289
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_SCALE, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
3452
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3453
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3454
+
3455
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
3456
+ (uint32_t)ggml_nelements(src0),
3457
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3458
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3459
+ 0,
3460
+ op_params[0], 0.0f
3461
+ });
3290
3462
  }
3291
3463
 
3292
3464
  static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3293
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_SQR, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
3465
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3466
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3467
+
3468
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
3469
+ (uint32_t)ggml_nelements(src0),
3470
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3471
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3472
+ 0,
3473
+ 0.0f, 0.0f,
3474
+ });
3294
3475
  }
3295
3476
 
3296
3477
  static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3297
3478
  float * op_params = (float *)dst->op_params;
3298
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_CLAMP, { (uint32_t)ggml_nelements(src0), 0, op_params[0], op_params[1] });
3479
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3480
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3481
+
3482
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
3483
+ (uint32_t)ggml_nelements(src0),
3484
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3485
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3486
+ 0,
3487
+ op_params[0], op_params[1],
3488
+ });
3299
3489
  }
3300
3490
 
3301
3491
  static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3302
3492
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
3303
- const int src0_type_size = ggml_type_size(src0->type);
3304
- const int dst_type_size = ggml_type_size(dst->type);
3305
- const uint32_t d_offset = (extra->offset % ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
3306
- ggml_vk_op_f32<vk_op_cpy_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_CPY, {
3493
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3494
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3495
+ const uint32_t d_offset = (extra->offset % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
3496
+
3497
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
3307
3498
  (uint32_t)ggml_nelements(src0),
3308
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size,
3309
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size,
3499
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3500
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3310
3501
  d_offset,
3502
+ 0.0f, 0.0f,
3311
3503
  });
3312
3504
  }
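For GGML_OP_CPY the destination offset is split in two: the buffer is presumably bound at an offset rounded down to minStorageBufferOffsetAlignment, and the remainder is handed to the shader as d_offset, counted in destination elements. A worked example with illustrative values:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint64_t offset        = 4100;  // byte offset of dst inside its device buffer
        const uint64_t alignment     = 256;   // a typical minStorageBufferOffsetAlignment
        const uint64_t dst_type_size = 4;     // f32 destination

        const uint64_t d_offset = (offset % alignment) / dst_type_size;  // (4100 % 256) / 4 == 1
        printf("descriptor bound at %llu, shader adds %llu elements\n",
               (unsigned long long)(offset - offset % alignment),        // 4096
               (unsigned long long) d_offset);
        return 0;
    }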
3313
3505
 
3314
3506
  static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3315
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
3507
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
3316
3508
  }
3317
3509
 
3318
3510
  static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3319
3511
  float * op_params = (float *)dst->op_params;
3320
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
3512
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
3321
3513
  }
3322
3514
 
3323
3515
  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3324
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
3516
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
3325
3517
  }
3326
3518
 
3327
3519
  static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3328
3520
  int32_t * op_params = (int32_t *)dst->op_params;
3329
- ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
3521
+ ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
3330
3522
  }
3331
3523
 
3332
- static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3524
+ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
3333
3525
  float * op_params = (float *)dst->op_params;
3334
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_SOFT_MAX, { (uint32_t)src0->ne[0], (uint32_t)(src1 != nullptr ? ggml_nrows(src1) : 0), op_params[0], 0.0f });
3526
+
3527
+ float scale = op_params[0];
3528
+ float max_bias = op_params[1];
3529
+
3530
+ const uint32_t ncols = (uint32_t)src0->ne[0];
3531
+ const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
3532
+ const uint32_t nrows_y = (uint32_t)src0->ne[1];
3533
+
3534
+ const uint32_t n_head_kv = nrows_x/nrows_y;
3535
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
3536
+
3537
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
3538
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
3539
+
3540
+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
3541
+ ncols,
3542
+ nrows_y,
3543
+ src2 != nullptr ? (uint32_t)1 : (uint32_t)0,
3544
+ scale, max_bias,
3545
+ m0, m1,
3546
+ n_head_log2,
3547
+ });
3335
3548
  }
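ggml_vk_soft_max now takes an optional src2 and precomputes the ALiBi parameters on the host: n_head_log2 is the largest power of two not exceeding the inferred head count, and m0/m1 are the two base slopes derived from max_bias. A standalone sketch with illustrative values (the per-head slope formula in the comment is the usual ggml scheme and is assumed to be what the shader applies):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const float    max_bias    = 8.0f;   // op_params[1]
        const uint32_t n_head_kv   = 32;     // nrows_x / nrows_y
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));  // 32

        const float m0 = powf(2.0f, -(max_bias)        / n_head_log2);  // 2^-0.25  ~= 0.841
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);  // 2^-0.125 ~= 0.917

        // typical per-head slope: m0^(h+1) for h < n_head_log2, else m1^(2*(h - n_head_log2) + 1)
        printf("n_head_log2=%u m0=%f m1=%f\n", n_head_log2, m0, m1);
        return 0;
    }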
3336
3549
 
3337
3550
  static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -3357,12 +3570,17 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
3357
3570
  if (is_neox) {
3358
3571
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
3359
3572
  const float inv_ndims = -1.0f / n_dims;
3360
- ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f, theta_scale, inv_ndims });
3573
+ ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f, theta_scale, inv_ndims });
3361
3574
  } else {
3362
- ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f });
3575
+ ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f });
3363
3576
  }
3364
3577
  }
3365
3578
 
3579
+ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3580
+ int32_t * op_params = (int32_t *)dst->op_params;
3581
+ ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { (uint32_t)src0->ne[0], ((ggml_sort_order) op_params[0]) == GGML_SORT_ORDER_ASC });
3582
+ }
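ggml_vk_argsort forwards the row width and whether the requested order is ascending; op_params[0] carries the ggml_sort_order chosen when the node was built. A minimal sketch of building such a node with the standard ggml helpers (context size and tensor shape are illustrative):

    #include "ggml.h"

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ggml_ctx = ggml_init(params);

        struct ggml_tensor * probs = ggml_new_tensor_1d(ggml_ctx, GGML_TYPE_F32, 8);
        struct ggml_tensor * idx   = ggml_argsort(ggml_ctx, probs, GGML_SORT_ORDER_ASC);

        // idx->op_params[0] now holds GGML_SORT_ORDER_ASC, which the backend code above
        // turns into the boolean push constant for the argsort shader.
        (void) idx;
        ggml_free(ggml_ctx);
        return 0;
    }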
3583
+
3366
3584
  static void ggml_vk_nop(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3367
3585
  // If backend is CPU, data from src0 has to be copied off the device
3368
3586
  if (dst->backend == GGML_BACKEND_TYPE_CPU) {
@@ -3414,43 +3632,43 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3414
3632
  const size_t y_ne = k * n * batch;
3415
3633
  const size_t d_ne = m * n * batch;
3416
3634
 
3417
- vk_pipeline * p;
3635
+ vk_pipeline p;
3418
3636
  std::string shname;
3419
3637
  if (shader_size == 0) {
3420
3638
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3421
- p = &ctx->pipeline_matmul_f32_aligned_s;
3639
+ p = ctx->device->pipeline_matmul_f32->a_s;
3422
3640
  shname = "F32_ALIGNED_S";
3423
3641
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3424
- p = &ctx->pipeline_matmul_f16_f32_aligned_s;
3642
+ p = ctx->device->pipeline_matmul_f16_f32->a_s;
3425
3643
  shname = "F16_F32_ALIGNED_S";
3426
3644
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3427
- p = &ctx->pipeline_matmul_f16_aligned_s;
3645
+ p = ctx->device->pipeline_matmul_f16->a_s;
3428
3646
  shname = "F16_ALIGNED_S";
3429
3647
  } else {
3430
3648
  GGML_ASSERT(false);
3431
3649
  }
3432
3650
  } else if (shader_size == 1) {
3433
3651
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3434
- p = &ctx->pipeline_matmul_f32_aligned_m;
3652
+ p = ctx->device->pipeline_matmul_f32->a_m;
3435
3653
  shname = "F32_ALIGNED_M";
3436
3654
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3437
- p = &ctx->pipeline_matmul_f16_f32_aligned_m;
3655
+ p = ctx->device->pipeline_matmul_f16_f32->a_m;
3438
3656
  shname = "F16_F32_ALIGNED_M";
3439
3657
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3440
- p = &ctx->pipeline_matmul_f16_aligned_m;
3658
+ p = ctx->device->pipeline_matmul_f16->a_m;
3441
3659
  shname = "F16_ALIGNED_M";
3442
3660
  } else {
3443
3661
  GGML_ASSERT(false);
3444
3662
  }
3445
3663
  } else if (shader_size == 2) {
3446
3664
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3447
- p = &ctx->pipeline_matmul_f32_aligned_l;
3665
+ p = ctx->device->pipeline_matmul_f32->a_l;
3448
3666
  shname = "F32_ALIGNED_L";
3449
3667
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3450
- p = &ctx->pipeline_matmul_f16_f32_aligned_l;
3668
+ p = ctx->device->pipeline_matmul_f16_f32->a_l;
3451
3669
  shname = "F16_F32_ALIGNED_L";
3452
3670
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3453
- p = &ctx->pipeline_matmul_f16_aligned_l;
3671
+ p = ctx->device->pipeline_matmul_f16->a_l;
3454
3672
  shname = "F16_ALIGNED_L";
3455
3673
  } else {
3456
3674
  GGML_ASSERT(false);
@@ -3464,43 +3682,43 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3464
3682
  if (k != kpad) {
3465
3683
  if (shader_size == 0) {
3466
3684
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3467
- p = &ctx->pipeline_matmul_f32_s;
3685
+ p = ctx->device->pipeline_matmul_f32->s;
3468
3686
  shname = "F32_S";
3469
3687
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3470
- p = &ctx->pipeline_matmul_f16_f32_s;
3688
+ p = ctx->device->pipeline_matmul_f16_f32->s;
3471
3689
  shname = "F16_F32_S";
3472
3690
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3473
- p = &ctx->pipeline_matmul_f16_s;
3691
+ p = ctx->device->pipeline_matmul_f16->s;
3474
3692
  shname = "F16_S";
3475
3693
  }
3476
3694
  } else if (shader_size == 1) {
3477
3695
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3478
- p = &ctx->pipeline_matmul_f32_m;
3696
+ p = ctx->device->pipeline_matmul_f32->m;
3479
3697
  shname = "F32_M";
3480
3698
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3481
- p = &ctx->pipeline_matmul_f16_f32_m;
3699
+ p = ctx->device->pipeline_matmul_f16_f32->m;
3482
3700
  shname = "F16_F32_M";
3483
3701
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3484
- p = &ctx->pipeline_matmul_f16_m;
3702
+ p = ctx->device->pipeline_matmul_f16->m;
3485
3703
  shname = "F16_M";
3486
3704
  }
3487
3705
  } else if (shader_size == 2) {
3488
3706
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3489
- p = &ctx->pipeline_matmul_f32_l;
3707
+ p = ctx->device->pipeline_matmul_f32->l;
3490
3708
  shname = "F32_L";
3491
3709
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3492
- p = &ctx->pipeline_matmul_f16_f32_l;
3710
+ p = ctx->device->pipeline_matmul_f16_f32->l;
3493
3711
  shname = "F16_F32_L";
3494
3712
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3495
- p = &ctx->pipeline_matmul_f16_l;
3713
+ p = ctx->device->pipeline_matmul_f16->l;
3496
3714
  shname = "F16_L";
3497
3715
  }
3498
3716
  }
3499
3717
  }
3500
3718
 
3501
- ggml_pipeline_allocate_descriptor_sets(ctx, *p, num_it);
3719
+ ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
3502
3720
  if (split_k > 1) {
3503
- ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_matmul_split_k_reduce, num_it);
3721
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
3504
3722
 
3505
3723
  if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
3506
3724
  // Resize buffer
@@ -3530,9 +3748,11 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3530
3748
  }
3531
3749
  for (size_t i = 0; i < y_ne; i++) {
3532
3750
  if (std::is_same<float, Y_TYPE>()) {
3533
- y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
3751
+ // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
3752
+ y[i] = (i % k == i / k) ? 1.0f : 0.0f;
3534
3753
  } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
3535
- y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
3754
+ // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
3755
+ y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
3536
3756
  } else {
3537
3757
  GGML_ASSERT(false);
3538
3758
  }
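The random right-hand side of the matmul test is replaced by an identity pattern: within the first batch, y holds a 1 exactly where the fast index equals the slow index, so (with n == k as in the calls below) the reference product reproduces the X matrix and a wrong output element points directly at the corresponding input. The fill, reduced to a small standalone example:

    #include <cstdio>

    int main() {
        const int k = 4, n = 4;  // small stand-in for the 512 x 512 case used below
        float y[k * n];
        for (int i = 0; i < k * n; ++i) {
            y[i] = (i % k == i / k) ? 1.0f : 0.0f;  // 1 on the diagonal of the first batch
        }
        for (int row = 0; row < n; ++row) {
            for (int col = 0; col < k; ++col) {
                printf("%.0f ", y[row * k + col]);
            }
            printf("\n");
        }
        return 0;
    }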
@@ -3541,17 +3761,17 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3541
3761
  ggml_vk_buffer_write(ctx, d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
3542
3762
  ggml_vk_buffer_write(ctx, d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
3543
3763
 
3544
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
3764
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
3545
3765
  for (size_t i = 0; i < num_it; i++) {
3546
3766
  ggml_vk_ctx_begin(ctx, subctx);
3547
- ggml_vk_matmul(ctx, subctx, *p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
3767
+ ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
3548
3768
  ggml_vk_ctx_end(subctx);
3549
3769
  }
3550
3770
 
3551
3771
  auto begin = std::chrono::high_resolution_clock::now();
3552
3772
  ggml_vk_submit(subctx, ctx->fence);
3553
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
3554
- ctx->device.lock()->device.resetFences({ ctx->fence });
3773
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
3774
+ ctx->device->device.resetFences({ ctx->fence });
3555
3775
 
3556
3776
  auto end = std::chrono::high_resolution_clock::now();
3557
3777
  double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
@@ -3630,6 +3850,8 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3630
3850
  std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
3631
3851
  std::cerr << "Actual result: " << std::endl << std::endl;
3632
3852
  ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
3853
+ std::cerr << std::endl;
3854
+ ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
3633
3855
  std::cerr << "Expected result: " << std::endl << std::endl;
3634
3856
  ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
3635
3857
 
@@ -3655,15 +3877,15 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3655
3877
 
3656
3878
  free(d_chk);
3657
3879
 
3658
- ggml_vk_queue_cleanup(ctx, ctx->device.lock()->transfer_queue);
3659
- ggml_vk_queue_cleanup(ctx, ctx->device.lock()->compute_queue);
3880
+ ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
3881
+ ggml_vk_queue_cleanup(ctx, ctx->device->compute_queue);
3660
3882
 
3661
3883
  ggml_vk_destroy_buffer(d_X);
3662
3884
  ggml_vk_destroy_buffer(d_Y);
3663
3885
  ggml_vk_destroy_buffer(d_D);
3664
3886
 
3665
- ggml_pipeline_cleanup(*p);
3666
- ggml_pipeline_cleanup(ctx->pipeline_matmul_split_k_reduce);
3887
+ ggml_pipeline_cleanup(p);
3888
+ ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce);
3667
3889
 
3668
3890
  free(x);
3669
3891
  free(y);
@@ -3736,7 +3958,7 @@ static void ggml_vk_test_h2d_nc(ggml_backend_vk_context * ctx, size_t ne0, size_
3736
3958
  data[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
3737
3959
  }
3738
3960
 
3739
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
3961
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
3740
3962
  ggml_vk_ctx_begin(ctx, subctx);
3741
3963
 
3742
3964
  vk_buffer buffer = ggml_vk_create_buffer_check(ctx, ggml_nbytes(tensor), vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -3745,8 +3967,8 @@ static void ggml_vk_test_h2d_nc(ggml_backend_vk_context * ctx, size_t ne0, size_
3745
3967
 
3746
3968
  ggml_vk_ctx_end(subctx);
3747
3969
  ggml_vk_submit(subctx, ctx->fence);
3748
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_h2d_nc waitForFences");
3749
- ctx->device.lock()->device.resetFences({ ctx->fence });
3970
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_h2d_nc waitForFences");
3971
+ ctx->device->device.resetFences({ ctx->fence });
3750
3972
 
3751
3973
  ggml_vk_buffer_read(ctx, buffer, 0, result_data, ggml_nbytes(tensor));
3752
3974
 
@@ -3818,7 +4040,7 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
3818
4040
  x[i] = rand() / (float)RAND_MAX;
3819
4041
  }
3820
4042
 
3821
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
4043
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
3822
4044
  ggml_vk_ctx_begin(ctx, subctx);
3823
4045
 
3824
4046
  auto begin = std::chrono::high_resolution_clock::now();
@@ -3832,8 +4054,8 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
3832
4054
 
3833
4055
  ggml_vk_ctx_end(subctx);
3834
4056
  ggml_vk_submit(subctx, ctx->fence);
3835
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
3836
- ctx->device.lock()->device.resetFences({ ctx->fence });
4057
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
4058
+ ctx->device->device.resetFences({ ctx->fence });
3837
4059
 
3838
4060
  auto end = std::chrono::high_resolution_clock::now();
3839
4061
 
@@ -3847,8 +4069,8 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
3847
4069
 
3848
4070
  ggml_vk_ctx_end(subctx);
3849
4071
  ggml_vk_submit(subctx, ctx->fence);
3850
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
3851
- ctx->device.lock()->device.resetFences({ ctx->fence });
4072
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
4073
+ ctx->device->device.resetFences({ ctx->fence });
3852
4074
 
3853
4075
  for (auto& cpy : subctx->out_memcpys) {
3854
4076
  memcpy(cpy.dst, cpy.src, cpy.n);
@@ -3879,6 +4101,10 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
3879
4101
  }
3880
4102
  }
3881
4103
 
4104
+ static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
4105
+ ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr);
4106
+ }
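ggml_vk_quantize_data wraps the generic ggml_quantize_chunk entry point (one chunk, starting at row 0), which replaces the per-type ggml_quantize_q* switch and its histogram buffer in the dequant test below. A minimal usage sketch with hypothetical sizes, using the same buffer-size formula as the rest of this file:

    #include <cstdint>
    #include <vector>
    #include "ggml.h"

    int main() {
        const size_t ne = 256;  // multiple of the Q4_0 block size (32)
        std::vector<float>   src(ne, 0.5f);
        std::vector<uint8_t> dst(ne * ggml_type_size(GGML_TYPE_Q4_0) / ggml_blck_size(GGML_TYPE_Q4_0));

        // same call as ggml_vk_quantize_data above
        ggml_quantize_chunk(GGML_TYPE_Q4_0, src.data(), dst.data(), 0, 1, ne, nullptr);
        return 0;
    }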
4107
+
3882
4108
  static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
3883
4109
  #ifdef GGML_VULKAN_DEBUG
3884
4110
  std::cerr << "ggml_vk_test_dequant(" << ne << ")" << std::endl;
@@ -3896,72 +4122,59 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
3896
4122
  x[i] = rand() / (float)RAND_MAX;
3897
4123
  }
3898
4124
 
3899
- std::vector<int64_t> hist_cur(1 << 4, 0);
4125
+ vk_pipeline p = ctx->device->pipeline_dequant[quant];
3900
4126
 
3901
- vk_pipeline& p = ctx->pipeline_dequant[quant];
3902
-
3903
- switch(quant) {
3904
- case GGML_TYPE_Q4_0:
3905
- ggml_quantize_q4_0(x, qx, ne, ne, hist_cur.data());
3906
- break;
3907
- case GGML_TYPE_Q4_1:
3908
- ggml_quantize_q4_1(x, qx, ne, ne, hist_cur.data());
3909
- break;
3910
- case GGML_TYPE_Q5_0:
3911
- ggml_quantize_q5_0(x, qx, ne, ne, hist_cur.data());
3912
- break;
3913
- case GGML_TYPE_Q5_1:
3914
- ggml_quantize_q4_1(x, qx, ne, ne, hist_cur.data());
3915
- break;
3916
- case GGML_TYPE_Q8_0:
3917
- ggml_quantize_q8_0(x, qx, ne, ne, hist_cur.data());
3918
- break;
3919
- case GGML_TYPE_Q2_K:
3920
- ggml_quantize_q2_K(x, qx, ne, ne, hist_cur.data());
3921
- break;
3922
- case GGML_TYPE_Q3_K:
3923
- ggml_quantize_q3_K(x, qx, ne, ne, hist_cur.data());
3924
- break;
3925
- case GGML_TYPE_Q4_K:
3926
- ggml_quantize_q4_K(x, qx, ne, ne, hist_cur.data());
3927
- break;
3928
- case GGML_TYPE_Q5_K:
3929
- ggml_quantize_q5_K(x, qx, ne, ne, hist_cur.data());
3930
- break;
3931
- case GGML_TYPE_Q6_K:
3932
- ggml_quantize_q6_K(x, qx, ne, ne, hist_cur.data());
3933
- break;
3934
- default:
3935
- GGML_ASSERT(false);
3936
- }
4127
+ ggml_vk_quantize_data(x, qx, ne, quant);
3937
4128
 
3938
4129
  ggml_pipeline_allocate_descriptor_sets(ctx, p, 1);
3939
4130
 
3940
4131
  ggml_vk_buffer_write(ctx, qx_buf, 0, qx, qx_sz);
3941
4132
 
3942
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
4133
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
3943
4134
  ggml_vk_ctx_begin(ctx, subctx);
3944
- const std::vector<int> pc = { 1, (int)ne, (int)ne, (int)ne };
4135
+ const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
3945
4136
  ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
3946
4137
  ggml_vk_ctx_end(subctx);
3947
4138
 
3948
4139
  auto begin = std::chrono::high_resolution_clock::now();
3949
4140
 
3950
4141
  ggml_vk_submit(subctx, ctx->fence);
3951
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
3952
- ctx->device.lock()->device.resetFences({ ctx->fence });
4142
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
4143
+ ctx->device->device.resetFences({ ctx->fence });
3953
4144
 
3954
4145
  auto end = std::chrono::high_resolution_clock::now();
3955
4146
 
3956
4147
  double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
3957
4148
  ggml_vk_buffer_read(ctx, x_buf, 0, x_chk, x_sz_f16);
3958
4149
 
4150
+ int first_err = -1;
4151
+
3959
4152
  double avg_err = 0.0;
3960
4153
  for (size_t i = 0; i < ne; i++) {
3961
- avg_err += std::fabs(x[i] - ggml_fp16_to_fp32(x_chk[i]));
4154
+ double error = std::fabs(x[i] - ggml_fp16_to_fp32(x_chk[i]));
4155
+ avg_err += error;
4156
+
4157
+ if (first_err < 0 && error > 0.05) {
4158
+ first_err = i;
4159
+ }
3962
4160
  }
3963
4161
 
3964
- std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err / ne << std::endl;
4162
+ avg_err /= ne;
4163
+
4164
+ std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;
4165
+
4166
+ if (avg_err > 0.1) {
4167
+ std::cerr << "first_error = " << first_err << std::endl;
4168
+ std::cerr << "Actual result: " << std::endl << std::endl;
4169
+ for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
4170
+ std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
4171
+ }
4172
+ std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
4173
+ for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
4174
+ std::cerr << x[i] << ", ";
4175
+ }
4176
+ std::cerr << std::endl;
4177
+ }
3965
4178
 
3966
4179
  ggml_vk_destroy_buffer(x_buf);
3967
4180
  ggml_vk_destroy_buffer(qx_buf);
@@ -3970,6 +4183,190 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
3970
4183
  free(qx);
3971
4184
  free(x_chk);
3972
4185
  }
4186
+
4187
+ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
4188
+ #ifdef GGML_VULKAN_DEBUG
4189
+ std::cerr << "ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")" << std::endl;
4190
+ #endif
4191
+ const size_t x_ne = m * k * batch;
4192
+ const size_t y_ne = k * n * batch;
4193
+ const size_t d_ne = m * n * batch;
4194
+
4195
+ vk_pipeline p;
4196
+ std::string shname;
4197
+ if (shader_size == 0) {
4198
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_s;
4199
+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
4200
+ } else if (shader_size == 1) {
4201
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_m;
4202
+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
4203
+ } else if (shader_size == 2) {
4204
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_l;
4205
+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
4206
+ } else {
4207
+ GGML_ASSERT(0);
4208
+ }
4209
+
4210
+ const size_t kpad = ggml_vk_align_size(k, p->align);
4211
+
4212
+ if (k != kpad) {
4213
+ if (shader_size == 0) {
4214
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->s;
4215
+ shname = std::string(ggml_type_name(quant)) + "_S";
4216
+ } else if (shader_size == 1) {
4217
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->m;
4218
+ shname = std::string(ggml_type_name(quant)) + "_M";
4219
+ } else if (shader_size == 2) {
4220
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->l;
4221
+ shname = std::string(ggml_type_name(quant)) + "_L";
4222
+ } else {
4223
+ GGML_ASSERT(0);
4224
+ }
4225
+ }
4226
+
4227
+ const size_t x_sz = sizeof(float) * x_ne;
4228
+ const size_t y_sz = sizeof(float) * y_ne;
4229
+ const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
4230
+ const size_t d_sz = sizeof(float) * d_ne;
4231
+ float * x = (float *) malloc(x_sz);
4232
+ float * y = (float *) malloc(y_sz);
4233
+ void * qx = malloc(qx_sz);
4234
+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
4235
+ vk_buffer y_buf = ggml_vk_create_buffer_check(ctx, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
4236
+ vk_buffer d_buf = ggml_vk_create_buffer_check(ctx, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
4237
+ float * d = (float *) malloc(d_sz);
4238
+ float * d_chk = (float *) malloc(d_sz);
4239
+
4240
+ for (size_t i = 0; i < x_ne; i++) {
4241
+ x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
4242
+ }
4243
+
4244
+ ggml_vk_quantize_data(x, qx, x_ne, quant);
4245
+
4246
+ for (size_t i = 0; i < y_ne; i++) {
4247
+ // y[i] = rand() / (float)RAND_MAX;
4248
+ y[i] = (i % k == i / k) ? 1.0f : 0.0f;
4249
+ }
4250
+
4251
+ ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
4252
+ if (split_k > 1) {
4253
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
4254
+
4255
+ if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
4256
+ // Resize buffer
4257
+ if (ctx->prealloc_split_k != nullptr) {
4258
+ ggml_vk_destroy_buffer(ctx->prealloc_split_k);
4259
+ }
4260
+ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
4261
+ }
4262
+ }
4263
+
4264
+ ggml_vk_buffer_write(ctx, qx_buf, 0, qx, qx_sz);
4265
+ ggml_vk_buffer_write(ctx, y_buf, 0, y, y_sz);
4266
+
4267
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
4268
+ for (size_t i = 0; i < num_it; i++) {
4269
+ ggml_vk_ctx_begin(ctx, subctx);
4270
+ ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
4271
+ ggml_vk_ctx_end(subctx);
4272
+ }
4273
+
4274
+ auto begin = std::chrono::high_resolution_clock::now();
4275
+
4276
+ ggml_vk_submit(subctx, ctx->fence);
4277
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
4278
+ ctx->device->device.resetFences({ ctx->fence });
4279
+
4280
+ auto end = std::chrono::high_resolution_clock::now();
4281
+
4282
+ double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
4283
+ ggml_vk_buffer_read(ctx, d_buf, 0, d, d_sz);
4284
+
4285
+ ggml_init_params iparams = {
4286
+ /*.mem_size =*/ 1024*1024*1024,
4287
+ /*.mem_buffer =*/ NULL,
4288
+ /*.no_alloc =*/ true,
4289
+ };
4290
+
4291
+ ggml_context * ggml_ctx = ggml_init(iparams);
4292
+
4293
+ ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
4294
+ ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
4295
+ ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
4296
+
4297
+ src0_ggml->data = qx;
4298
+ src1_ggml->data = y;
4299
+ tensor_ggml->data = d_chk;
4300
+
4301
+ ctx->disable = true;
4302
+
4303
+ ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
4304
+ ggml_build_forward_expand(cgraph, tensor_ggml);
4305
+
4306
+ ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
4307
+
4308
+ ctx->disable = false;
4309
+
4310
+ ggml_free(ggml_ctx);
4311
+
4312
+ double avg_err = 0.0;
4313
+ int first_err_n = -1;
4314
+ int first_err_m = -1;
4315
+ int first_err_b = -1;
4316
+
4317
+ for (size_t i = 0; i < m*n*batch; i++) {
4318
+ double err = std::fabs(d[i] - d_chk[i]);
4319
+ avg_err += err;
4320
+
4321
+ if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
4322
+ first_err_b = i / (m * n);
4323
+ first_err_n = (i % (m * n)) / m;
4324
+ first_err_m = (i % (m * n)) % m;
4325
+ }
4326
+ }
4327
+
4328
+ avg_err /= m * n;
4329
+
4330
+ std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl;
4331
+
4332
+ if (avg_err > 0.1 || std::isnan(avg_err)) {
4333
+ std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
4334
+ std::cerr << "Actual result: " << std::endl << std::endl;
4335
+ ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4336
+ std::cerr << std::endl;
4337
+ std::cerr << "Expected result: " << std::endl << std::endl;
4338
+ ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4339
+
4340
+ if (split_k > 1) {
4341
+ float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
4342
+ ggml_vk_buffer_read(ctx, ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
4343
+
4344
+ std::cerr << "d_buf0: " << std::endl << std::endl;
4345
+ ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4346
+
4347
+ std::cerr << "d_buf1: " << std::endl << std::endl;
4348
+ ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4349
+
4350
+ std::cerr << "d_buf2: " << std::endl << std::endl;
4351
+ ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4352
+
4353
+ std::cerr << "d_buf3: " << std::endl << std::endl;
4354
+ ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4355
+
4356
+ free(split_k_buf);
4357
+ }
4358
+ }
4359
+
4360
+ ggml_vk_destroy_buffer(qx_buf);
4361
+ ggml_vk_destroy_buffer(y_buf);
4362
+ ggml_vk_destroy_buffer(d_buf);
4363
+
4364
+ free(x);
4365
+ free(qx);
4366
+ free(y);
4367
+ free(d);
4368
+ free(d_chk);
4369
+ }
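When the new quantized-matmul test reports a large average error, the first offending flat index is decomposed into batch, column and row so both matrices can be printed around that spot. The decomposition, worked through for illustrative values:

    #include <cstdio>

    int main() {
        // same decomposition as above, for m = 128, n = 512 and a hypothetical flat index
        const size_t m = 128, n = 512;
        const size_t i = 70000;

        const size_t first_err_b = i / (m * n);          // 70000 / 65536        == 1
        const size_t first_err_n = (i % (m * n)) / m;    // (70000 % 65536) / 128 == 34
        const size_t first_err_m = (i % (m * n)) % m;    // (70000 % 65536) % 128 == 112
        printf("b=%zu n=%zu m=%zu\n", first_err_b, first_err_n, first_err_m);
        return 0;
    }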
3973
4370
  #endif
3974
4371
 
3975
4372
  static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor) {
@@ -3982,18 +4379,8 @@ static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor)
3982
4379
  return extra;
3983
4380
  }
3984
4381
 
3985
- static ggml_tensor * ggml_vk_find_last_use(const ggml_tensor * node, ggml_cgraph * graph) {
3986
- GGML_ASSERT(node != nullptr);
3987
-
3988
- for (int i = graph->n_nodes - 1; i >= 0; i--) {
3989
- for (int j = 0; j < GGML_MAX_SRC; j++) {
3990
- if (graph->nodes[i]->src[j] == node) {
3991
- return graph->nodes[i];
3992
- }
3993
- }
3994
- }
3995
-
3996
- return nullptr;
4382
+ static bool ggml_vk_cpu_assist_op(const ggml_tensor * node) {
4383
+ return node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID;
3997
4384
  }
3998
4385
 
3999
4386
  static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){
@@ -4004,7 +4391,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4004
4391
  || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_TYPE_GPU || node->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
4005
4392
  || (node->src[1] != nullptr && (node->src[1]->backend == GGML_BACKEND_TYPE_GPU));
4006
4393
 
4007
- if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT)) {
4394
+ if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node))) {
4008
4395
  return;
4009
4396
  }
4010
4397
 
@@ -4035,7 +4422,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4035
4422
  const bool f16_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32;
4036
4423
 
4037
4424
  int split_k;
4038
- if (node->op == GGML_OP_MUL_MAT) {
4425
+ if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
4039
4426
  split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
4040
4427
  } else {
4041
4428
  split_k = 1;
@@ -4044,11 +4431,11 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4044
4431
  const uint32_t y_ne = ne10 * ne11;
4045
4432
  const uint32_t d_ne = ne20 * ne21;
4046
4433
 
4047
- const uint64_t qx_sz = use_src0 ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
4048
- const uint64_t qy_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
4049
- const uint64_t x_sz = use_src0 ? ggml_vk_align_size(sizeof(ggml_fp16_t) * x_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
4050
- const uint64_t y_sz = use_src1 ? ggml_vk_align_size(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
4051
- uint64_t d_sz = ggml_vk_align_size(ggml_type_size(node->type) * d_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne22 * ne23;
4434
+ const uint64_t qx_sz = use_src0 ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
4435
+ const uint64_t qy_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
4436
+ const uint64_t x_sz = use_src0 ? ggml_vk_align_size(sizeof(ggml_fp16_t) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
4437
+ const uint64_t y_sz = use_src1 ? ggml_vk_align_size(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
4438
+ uint64_t d_sz = ggml_vk_align_size(ggml_type_size(node->type) * d_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne22 * ne23;
4052
4439
  const uint64_t split_k_size = split_k > 1 ? d_sz * 4 : 0;
4053
4440
 
4054
4441
  if (extra->buffer_gpu.expired()) {
@@ -4076,6 +4463,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4076
4463
  case GGML_OP_DIAG_MASK_INF:
4077
4464
  case GGML_OP_SOFT_MAX:
4078
4465
  case GGML_OP_ROPE:
4466
+ case GGML_OP_ARGSORT:
4079
4467
  break;
4080
4468
  case GGML_OP_UNARY:
4081
4469
  switch (ggml_get_unary_op(node)) {
@@ -4088,6 +4476,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4088
4476
  }
4089
4477
  break;
4090
4478
  case GGML_OP_MUL_MAT:
4479
+ case GGML_OP_MUL_MAT_ID:
4091
4480
  if (ctx->prealloc_size_qx < qx_sz) {
4092
4481
  ctx->prealloc_size_qx = qx_sz;
4093
4482
  }
@@ -4121,21 +4510,66 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
4121
4510
  #endif
4122
4511
  #if defined(GGML_VULKAN_RUN_TESTS)
4123
4512
  ctx->staging = ggml_vk_create_buffer_check(ctx, 100ul * 1024ul * 1024ul,
4124
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached
4513
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
4125
4514
  vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
4126
4515
  ggml_vk_test_transfer(ctx, 8192 * 1000, false);
4127
4516
  ggml_vk_test_transfer(ctx, 8192 * 1000, true);
4128
4517
 
4129
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_0);
4130
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_1);
4131
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_0);
4132
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_1);
4133
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q8_0);
4134
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q2_K);
4135
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q3_K);
4136
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_K);
4137
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_K);
4138
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q6_K);
4518
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
4519
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
4520
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
4521
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
4522
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
4523
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
4524
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
4525
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
4526
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
4527
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
4528
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
4529
+
4530
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
4531
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
4532
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
4533
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
4534
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
4535
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
4536
+
4537
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
4538
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
4539
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
4540
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
4541
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
4542
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
4543
+
4544
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
4545
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
4546
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
4547
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
4548
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
4549
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
4550
+
4551
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
4552
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
4553
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
4554
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
4555
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
4556
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
4557
+
4558
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
4559
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
4560
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
4561
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
4562
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
4563
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
4564
+
4565
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
4566
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
4567
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
4568
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
4569
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
4570
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
4571
+
4572
+ std::cerr << std::endl;
4139
4573
 
4140
4574
  const std::vector<size_t> vals {
4141
4575
  8, 8, 8,
@@ -4225,7 +4659,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4225
4659
  || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_TYPE_GPU || node->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
4226
4660
  || (node->src[1] != nullptr && node->src[1]->backend == GGML_BACKEND_TYPE_GPU);
4227
4661
 
4228
- if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT) || (node->op == GGML_OP_MUL_MAT && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
4662
+ if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node)) || (ggml_vk_cpu_assist_op(node) && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
4229
4663
  return;
4230
4664
  }
4231
4665
 
@@ -4237,6 +4671,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4237
4671
 
4238
4672
  const ggml_tensor * src0 = node->src[0];
4239
4673
  const ggml_tensor * src1 = node->src[1];
4674
+ const ggml_tensor * src2 = node->src[2];
4240
4675
 
4241
4676
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
4242
4677
 
@@ -4271,7 +4706,9 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4271
4706
  case GGML_OP_SOFT_MAX:
4272
4707
  case GGML_OP_ROPE:
4273
4708
  case GGML_OP_MUL_MAT:
4709
+ case GGML_OP_MUL_MAT_ID:
4274
4710
  case GGML_OP_NONE:
4711
+ case GGML_OP_ARGSORT:
4275
4712
  break;
4276
4713
  default:
4277
4714
  if (any_on_device) {
@@ -4282,7 +4719,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4282
4719
  }
4283
4720
 
4284
4721
  if (ctx->compute_ctx == nullptr) {
4285
- ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
4722
+ ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
4286
4723
  ggml_vk_ctx_begin(ctx, ctx->compute_ctx);
4287
4724
  }
4288
4725
 
@@ -4353,16 +4790,25 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4353
4790
 
4354
4791
  break;
4355
4792
  case GGML_OP_SOFT_MAX:
4356
- ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, node);
4793
+ ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, src2, node);
4357
4794
 
4358
4795
  break;
4359
4796
  case GGML_OP_ROPE:
4360
4797
  ggml_vk_rope(ctx, ctx->compute_ctx, src0, src1, node);
4361
4798
 
4799
+ break;
4800
+ case GGML_OP_ARGSORT:
4801
+ ggml_vk_argsort(ctx, ctx->compute_ctx, src0, node);
4362
4802
  break;
4363
4803
  case GGML_OP_MUL_MAT:
4364
4804
  ggml_vk_mul_mat(ctx, ctx->compute_ctx, src0, src1, node);
4365
4805
 
4806
+ break;
4807
+ case GGML_OP_MUL_MAT_ID:
4808
+ //ggml_vk_mul_mat_id(ctx, ctx->compute_ctx, src0, src1, node);
4809
+ std::cerr << "ggml_vulkan: GGML_OP_MUL_MAT_ID not implemented yet." << std::endl;
4810
+ GGML_ASSERT(false);
4811
+
4366
4812
  break;
4367
4813
  default:
4368
4814
  return;
@@ -4389,7 +4835,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
4389
4835
  || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
4390
4836
  || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU);
4391
4837
 
4392
- if (ctx->disable || (!any_on_device && tensor->op != GGML_OP_MUL_MAT)) {
4838
+ if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(tensor))) {
4393
4839
  return false;
4394
4840
  }
4395
4841
 
@@ -4415,6 +4861,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
4415
4861
  case GGML_OP_PERMUTE:
4416
4862
  case GGML_OP_TRANSPOSE:
4417
4863
  case GGML_OP_NONE:
4864
+ case GGML_OP_ARGSORT:
4418
4865
  extra = (ggml_tensor_extra_gpu *) tensor->extra;
4419
4866
 
4420
4867
  break;
@@ -4430,6 +4877,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
4430
4877
  }
4431
4878
  break;
4432
4879
  case GGML_OP_MUL_MAT:
4880
+ case GGML_OP_MUL_MAT_ID:
4433
4881
  if (!any_on_device && !ggml_vk_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
4434
4882
  return false;
4435
4883
  }
@@ -4475,8 +4923,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
4475
4923
  }
4476
4924
 
4477
4925
  if (tensor == subctx.exit_tensor) {
4478
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
4479
- ctx->device.lock()->device.resetFences({ ctx->fence });
4926
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
4927
+ ctx->device->device.resetFences({ ctx->fence });
4480
4928
 
4481
4929
  // Do staging buffer copies
4482
4930
  for (auto& cpy : subctx.out_memcpys) {
@@ -4504,20 +4952,25 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
4504
4952
  }
4505
4953
  ctx->gc.temp_buffers.clear();
4506
4954
 
4507
- for (auto * pipeline : ctx->gc.pipelines) {
4508
- ggml_pipeline_cleanup(*pipeline);
4955
+ for (auto& pipeline : ctx->device->pipelines) {
4956
+ if (pipeline.expired()) {
4957
+ continue;
4958
+ }
4959
+
4960
+ vk_pipeline pl = pipeline.lock();
4961
+ ggml_pipeline_cleanup(pl);
4509
4962
  }
4510
4963
 
4511
- ggml_vk_queue_cleanup(ctx, ctx->device.lock()->compute_queue);
4512
- ggml_vk_queue_cleanup(ctx, ctx->device.lock()->transfer_queue);
4964
+ ggml_vk_queue_cleanup(ctx, ctx->device->compute_queue);
4965
+ ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
4513
4966
 
4514
4967
  for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
4515
- ctx->device.lock()->device.destroySemaphore({ ctx->gc.semaphores[i].s });
4968
+ ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
4516
4969
  }
4517
4970
  ctx->gc.semaphores.clear();
4518
4971
 
4519
4972
  for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) {
4520
- ctx->device.lock()->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
4973
+ ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
4521
4974
  }
4522
4975
  ctx->gc.tl_semaphores.clear();
4523
4976
  ctx->semaphore_idx = 0;
@@ -4525,7 +4978,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
4525
4978
  ctx->event_idx = 0;
4526
4979
 
4527
4980
  for (auto& event : ctx->gc.events) {
4528
- ctx->device.lock()->device.resetEvent(event);
4981
+ ctx->device->device.resetEvent(event);
4529
4982
  }
4530
4983
 
4531
4984
  ctx->staging_offset = 0;
@@ -4562,21 +5015,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
4562
5015
  ctx->staging_size = 0;
4563
5016
 
4564
5017
  for (auto& event : ctx->gc.events) {
4565
- ctx->device.lock()->device.destroyEvent(event);
5018
+ ctx->device->device.destroyEvent(event);
4566
5019
  }
4567
5020
  ctx->gc.events.clear();
4568
5021
 
4569
- for (auto* pipeline : ctx->gc.pipelines) {
4570
- ggml_vk_destroy_pipeline(ctx, pipeline);
4571
- }
4572
- ctx->gc.pipelines.clear();
4573
-
4574
- ctx->device.lock()->device.destroyFence(ctx->fence);
4575
-
4576
- ctx->device.lock()->device.destroyCommandPool(ctx->device.lock()->compute_queue.pool);
4577
- if (!ctx->device.lock()->single_queue) {
4578
- ctx->device.lock()->device.destroyCommandPool(ctx->device.lock()->transfer_queue.pool);
4579
- }
5022
+ ctx->device->device.destroyFence(ctx->fence);
4580
5023
  }
4581
5024
 
4582
5025
  GGML_CALL static int ggml_vk_get_device_count() {
@@ -4787,7 +5230,6 @@ GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t bu
4787
5230
 
4788
5231
  GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
4789
5232
  if (ggml_backend_buffer_is_vk(src->buffer)) {
4790
- ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
4791
5233
  ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
4792
5234
  ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
4793
5235
 
@@ -4799,6 +5241,8 @@ GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t bu
4799
5241
  return true;
4800
5242
  }
4801
5243
  return false;
5244
+
5245
+ UNUSED(buffer);
4802
5246
  }
4803
5247
 
4804
5248
  GGML_CALL static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -4845,12 +5289,12 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(
4845
5289
 
4846
5290
  GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
4847
5291
  ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
4848
- return ctx->ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
5292
+ return ctx->ctx->device->properties.limits.minStorageBufferOffsetAlignment;
4849
5293
  }
4850
5294
 
4851
5295
  GGML_CALL static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
4852
5296
  ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
4853
- return ctx->ctx->device.lock()->max_memory_allocation_size;
5297
+ return ctx->ctx->device->max_memory_allocation_size;
4854
5298
  }
4855
5299
 
4856
5300
  GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
@@ -4936,7 +5380,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_bu
4936
5380
  }
4937
5381
 
4938
5382
  GGML_CALL static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
4939
- return vk_instance.contexts[0].device.lock()->properties.limits.minMemoryMapAlignment;
5383
+ return vk_instance.contexts[0].device->properties.limits.minMemoryMapAlignment;
4940
5384
 
4941
5385
  UNUSED(buft);
4942
5386
  }
@@ -4981,8 +5425,7 @@ GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend) {
4981
5425
 
4982
5426
  ggml_vk_cleanup(ctx);
4983
5427
 
4984
- // Release device
4985
- vk_instance.devices[ctx->idx].reset();
5428
+ ctx->device.reset();
4986
5429
  ctx->initialized = false;
4987
5430
 
4988
5431
  vk_instance.initialized[idx] = false;
@@ -5011,7 +5454,7 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g
5011
5454
 
5012
5455
  if (ctx->transfer_ctx == nullptr) {
5013
5456
  // Initialize new transfer context
5014
- ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
5457
+ ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
5015
5458
  ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
5016
5459
  }
5017
5460
 
@@ -5032,7 +5475,7 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c
5032
5475
 
5033
5476
  if (ctx->transfer_ctx == nullptr) {
5034
5477
  // Initialize new transfer context
5035
- ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
5478
+ ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
5036
5479
  ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
5037
5480
  }
5038
5481
 
@@ -5052,7 +5495,7 @@ GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, c
5052
5495
 
5053
5496
  if (ctx->transfer_ctx == nullptr) {
5054
5497
  // Initialize new transfer context
5055
- ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
5498
+ ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
5056
5499
  ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
5057
5500
  }
5058
5501
 
@@ -5082,8 +5525,8 @@ GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
5082
5525
  }
5083
5526
 
5084
5527
  ggml_vk_submit(ctx->transfer_ctx, ctx->fence);
5085
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
5086
- ctx->device.lock()->device.resetFences({ ctx->fence });
5528
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
5529
+ ctx->device->device.resetFences({ ctx->fence });
5087
5530
 
5088
5531
  for (auto& cpy : ctx->transfer_ctx->out_memcpys) {
5089
5532
  memcpy(cpy.dst, cpy.src, cpy.n);
@@ -5092,7 +5535,7 @@ GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
5092
5535
  ctx->transfer_ctx = nullptr;
5093
5536
  }
5094
5537
 
5095
- GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
5538
+ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
5096
5539
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
5097
5540
 
5098
5541
  for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -5135,7 +5578,7 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml
5135
5578
 
5136
5579
  ggml_vk_graph_cleanup(ctx);
5137
5580
 
5138
- return true;
5581
+ return GGML_STATUS_SUCCESS;
5139
5582
 
5140
5583
  UNUSED(backend);
5141
5584
  }
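graph_compute now reports a ggml_status rather than a bool, matching the updated ggml_backend_i interface; callers compare against GGML_STATUS_SUCCESS. A minimal caller sketch (graph construction omitted; assumes a Vulkan-capable device 0):

    #include <cstdio>
    #include "ggml-backend.h"
    #include "ggml-vulkan.h"

    int main() {
        ggml_backend_t backend = ggml_backend_vk_init(0);
        if (backend == NULL) {
            fprintf(stderr, "failed to initialize the Vulkan backend\n");
            return 1;
        }
        // ... allocate buffers from this backend and build a ggml_cgraph * gf ...
        // if (ggml_backend_graph_compute(backend, gf) != GGML_STATUS_SUCCESS) {
        //     fprintf(stderr, "graph compute failed\n");
        // }
        ggml_backend_free(backend);
        return 0;
    }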
@@ -5153,6 +5596,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
5153
5596
  }
5154
5597
  break;
5155
5598
  case GGML_OP_MUL_MAT:
5599
+ case GGML_OP_MUL_MAT_ID:
5156
5600
  {
5157
5601
  struct ggml_tensor * a;
5158
5602
  struct ggml_tensor * b;
@@ -5226,6 +5670,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
5226
5670
  case GGML_OP_CONT:
5227
5671
  case GGML_OP_DIAG_MASK_INF:
5228
5672
  case GGML_OP_SOFT_MAX:
5673
+ case GGML_OP_ARGSORT:
5229
5674
  return true;
5230
5675
  default:
5231
5676
  return false;
@@ -5248,6 +5693,11 @@ static ggml_backend_i ggml_backend_vk_interface = {
5248
5693
  /* .graph_plan_compute = */ NULL,
5249
5694
  /* .graph_compute = */ ggml_backend_vk_graph_compute,
5250
5695
  /* .supports_op = */ ggml_backend_vk_supports_op,
5696
+ /* .event_new = */ NULL,
5697
+ /* .event_free = */ NULL,
5698
+ /* .event_record = */ NULL,
5699
+ /* .event_wait = */ NULL,
5700
+ /* .event_synchronize = */ NULL,
5251
5701
  };
5252
5702
 
5253
5703
  static ggml_guid_t ggml_backend_vk_guid() {
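The five new table entries extend the backend interface with event hooks that this backend does not implement, so they are left NULL. The sketch below uses hypothetical stand-in types (not ggml's real ggml_backend_i) to show the usual pattern of guarding an optional function pointer before dispatch; whether ggml itself guards or simply avoids events for such backends is an assumption here:

#include <cstdio>

struct fake_backend_iface {
    void (*event_record)(void * event);                // optional hook, NULL when unsupported
};

static void record_event(const fake_backend_iface & iface, void * event) {
    if (iface.event_record == nullptr) {               // guard the optional hook
        std::puts("backend has no event support, skipping");
        return;
    }
    iface.event_record(event);
}

int main() {
    fake_backend_iface vk_like = { nullptr };          // mirrors the NULL entries above
    record_event(vk_like, nullptr);
    return 0;
}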
@@ -5428,7 +5878,8 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso
5428
5878
 
5429
5879
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
5430
5880
 
5431
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, extra->offset, tensor_data, tensor_size);
5881
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
5882
+ ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset, tensor_data, tensor_size);
5432
5883
  }
5433
5884
 
5434
5885
  std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
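Here, and throughout ggml_vk_check_results_0 below, the tensor extra's buffer_gpu has become a weak reference, so readers lock() it into a vk_buffer before use. A hedged sketch of that pattern with simplified stand-in types; the expiry check is illustrative and not claimed to be present in the package source:

#include <cstdio>
#include <memory>

struct fake_buffer { size_t size; };                   // stand-in for the GPU buffer object

struct fake_extra {                                    // stand-in for ggml_tensor_extra_gpu
    std::weak_ptr<fake_buffer> buffer_gpu;
    size_t offset;
};

static void read_tensor(const fake_extra & extra) {
    std::shared_ptr<fake_buffer> buf = extra.buffer_gpu.lock();  // pin the buffer for this read
    if (!buf) {                                                  // illustrative expiry check
        std::fprintf(stderr, "buffer already destroyed\n");
        return;
    }
    std::printf("reading %zu bytes at offset %zu\n", buf->size - extra.offset, extra.offset);
}

int main() {
    auto buf = std::make_shared<fake_buffer>();
    buf->size = 64;
    fake_extra extra{buf, 16};
    read_tensor(extra);
    return 0;
}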
@@ -5504,6 +5955,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5504
5955
 
5505
5956
  ggml_tensor * src0 = tensor->src[0];
5506
5957
  ggml_tensor * src1 = tensor->src[1];
5958
+ ggml_tensor * src2 = tensor->src[2];
5507
5959
 
5508
5960
  struct ggml_init_params iparams = {
5509
5961
  /*.mem_size =*/ 1024*1024*1024,
@@ -5515,13 +5967,16 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5515
5967
 
5516
5968
  struct ggml_tensor * src0_clone = nullptr;
5517
5969
  struct ggml_tensor * src1_clone = nullptr;
5970
+ struct ggml_tensor * src2_clone = nullptr;
5518
5971
  struct ggml_tensor * tensor_clone = nullptr;
5519
5972
 
5520
5973
  size_t src0_size;
5521
5974
  size_t src1_size;
5975
+ size_t src2_size;
5522
5976
 
5523
5977
  void * src0_buffer;
5524
5978
  void * src1_buffer;
5979
+ void * src2_buffer;
5525
5980
 
5526
5981
  if (src0 != nullptr) {
5527
5982
  src0_clone = ggml_dup_tensor(ggml_ctx, src0);
@@ -5535,12 +5990,13 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5535
5990
  memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
5536
5991
  } else if (src0->backend == GGML_BACKEND_TYPE_GPU) {
5537
5992
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src0->extra;
5993
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
5538
5994
  uint64_t offset = extra->offset;
5539
5995
  if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
5540
5996
  for (int i3 = 0; i3 < src0->ne[3]; i3++) {
5541
5997
  for (int i2 = 0; i2 < src0->ne[2]; i2++) {
5542
5998
  const int idx = i3*src0->ne[2] + i2;
5543
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
5999
+ ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
5544
6000
  }
5545
6001
  }
5546
6002
 
@@ -5550,10 +6006,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5550
6006
  src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
5551
6007
  }
5552
6008
  } else {
5553
- if (offset + src0_size >= extra->buffer_gpu->size) {
5554
- src0_size = extra->buffer_gpu->size - offset;
6009
+ if (offset + src0_size >= buffer_gpu->size) {
6010
+ src0_size = buffer_gpu->size - offset;
5555
6011
  }
5556
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, offset, src0_clone->data, src0_size);
6012
+ ggml_vk_buffer_read(ctx, buffer_gpu, offset, src0_clone->data, src0_size);
5557
6013
  memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
5558
6014
  }
5559
6015
  } else {
@@ -5578,12 +6034,13 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5578
6034
  memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
5579
6035
  } else if (src1->backend == GGML_BACKEND_TYPE_GPU) {
5580
6036
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src1->extra;
6037
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
5581
6038
  uint64_t offset = extra->offset;
5582
6039
  if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
5583
6040
  for (int i3 = 0; i3 < src1->ne[3]; i3++) {
5584
6041
  for (int i2 = 0; i2 < src1->ne[2]; i2++) {
5585
6042
  const int idx = i3*src1->ne[2] + i2;
5586
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
6043
+ ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
5587
6044
  }
5588
6045
  }
5589
6046
 
@@ -5593,10 +6050,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5593
6050
  src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
5594
6051
  }
5595
6052
  } else {
5596
- if (offset + src1_size >= extra->buffer_gpu->size) {
5597
- src1_size = extra->buffer_gpu->size - offset;
6053
+ if (offset + src1_size >= buffer_gpu->size) {
6054
+ src1_size = buffer_gpu->size - offset;
5598
6055
  }
5599
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, offset, src1_clone->data, src1_size);
6056
+ ggml_vk_buffer_read(ctx, buffer_gpu, offset, src1_clone->data, src1_size);
5600
6057
  memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
5601
6058
  }
5602
6059
  } else {
@@ -5625,6 +6082,66 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5625
6082
 
5626
6083
  ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src1", src1_clone);
5627
6084
  }
6085
+ if (src2 != nullptr) {
6086
+ src2_clone = ggml_dup_tensor(ggml_ctx, src2);
6087
+
6088
+ src2_size = ggml_nbytes(src2);
6089
+
6090
+ src2_buffer = malloc(src2_size);
6091
+ src2_clone->data = src2_buffer;
6092
+ if (src2->backend == GGML_BACKEND_TYPE_CPU) {
6093
+ memcpy(src2_clone->data, src2->data, src2_size);
6094
+ memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
6095
+ } else if (src2->backend == GGML_BACKEND_TYPE_GPU) {
6096
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src2->extra;
6097
+ vk_buffer buf = extra->buffer_gpu.lock();
6098
+ uint64_t offset = extra->offset;
6099
+ if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
6100
+ for (int i3 = 0; i3 < src2->ne[3]; i3++) {
6101
+ for (int i2 = 0; i2 < src2->ne[2]; i2++) {
6102
+ const int idx = i3*src2->ne[2] + i2;
6103
+ ggml_vk_buffer_read(ctx, buf, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
6104
+ }
6105
+ }
6106
+
6107
+ src2_clone->nb[0] = src2->nb[0];
6108
+ src2_clone->nb[1] = src2->nb[1];
6109
+ for (int i = 2; i < GGML_MAX_DIMS; i++) {
6110
+ src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
6111
+ }
6112
+ } else {
6113
+ if (offset + src2_size >= buf->size) {
6114
+ src2_size = buf->size - offset;
6115
+ }
6116
+ ggml_vk_buffer_read(ctx, buf, offset, src2_clone->data, src2_size);
6117
+ memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
6118
+ }
6119
+ } else {
6120
+ GGML_ASSERT(false);
6121
+ }
6122
+
6123
+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
6124
+ ggml_vk_print_tensor(ctx, src2, "src2");
6125
+ std::cerr << "TENSOR CHECK: " << ggml_op_name(src2_clone->op) << " (check " << check_counter << ")" << std::endl;
6126
+ std::cerr << "src2_clone=" << tensor << " src2_clone->backend: " << src2_clone->backend << " src2_clone->type: " << ggml_type_name(src2_clone->type) << " ne0=" << src2_clone->ne[0] << " nb0=" << src2_clone->nb[0] << " ne1=" << src2_clone->ne[1] << " nb1=" << src2_clone->nb[1] << " ne2=" << src2_clone->ne[2] << " nb2=" << src2_clone->nb[2] << " ne3=" << src2_clone->ne[3] << " nb3=" << src2_clone->nb[3] << std::endl;
6127
+ if (src2->src[0] != nullptr) {
6128
+ std::cerr << "src2->src[0]=" << src2->src[0] << " op=" << ggml_op_name(src2->src[0]->op) << " type=" << ggml_type_name(src2->src[0]->type) << " backend=" << src2->src[0]->backend << " ne0=" << src2->src[0]->ne[0] << " nb0=" << src2->src[0]->nb[0] << " ne1=" << src2->src[0]->ne[1] << " nb1=" << src2->src[0]->nb[1] << " ne2=" << src2->src[0]->ne[2] << " nb2=" << src2->src[0]->nb[2] << " ne3=" << src2->src[0]->ne[3] << " nb3=" << src2->src[0]->nb[3] << std::endl;
6129
+ }
6130
+ if (src2->src[1] != nullptr) {
6131
+ std::cerr << "src2->src[1]=" << src2->src[1] << " op=" << ggml_op_name(src2->src[1]->op) << " type=" << ggml_type_name(src2->src[1]->type) << " backend=" << src2->src[1]->backend << " ne0=" << src2->src[1]->ne[0] << " nb0=" << src2->src[1]->nb[0] << " ne1=" << src2->src[1]->ne[1] << " nb1=" << src2->src[1]->nb[1] << " ne2=" << src2->src[1]->ne[2] << " nb2=" << src2->src[1]->nb[2] << " ne3=" << src2->src[1]->ne[3] << " nb3=" << src2->src[1]->nb[3] << std::endl;
6132
+ }
6133
+ std::cerr << std::endl << "Result:" << std::endl;
6134
+ ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 0, 0);
6135
+ std::cerr << std::endl;
6136
+ std::cerr << std::endl << "Result:" << std::endl;
6137
+ ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 1, 0);
6138
+ std::cerr << std::endl;
6139
+ std::vector<const ggml_tensor *> done;
6140
+ ggml_vk_print_graph_origin(src2_clone, done);
6141
+ }
6142
+
6143
+ ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src2", src2_clone);
6144
+ }
5628
6145
 
5629
6146
  if (tensor->op == GGML_OP_MUL_MAT) {
5630
6147
  tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
@@ -5644,7 +6161,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5644
6161
  tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
5645
6162
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
5646
6163
  if (src1 != nullptr) {
5647
- tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, *(float *)tensor->op_params);
6164
+ tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
5648
6165
  } else {
5649
6166
  tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
5650
6167
  }
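The reference path for SOFT_MAX now mirrors the extended ggml_soft_max_ext signature: src1 stays the mask, the new src2 is forwarded as the additional tensor argument, and the scale and max_bias values are read from the first two floats of op_params. A small sketch (stand-in byte buffer, not the real ggml_tensor) of how those two parameters are unpacked:

#include <cstdio>
#include <cstring>

int main() {
    // stand-in for tensor->op_params, which ggml stores as an opaque byte array
    char op_params[2 * sizeof(float)] = {};
    const float example[2] = {0.125f, 8.0f};            // example scale and max_bias values
    std::memcpy(op_params, example, sizeof(example));

    float scale = 0.0f, max_bias = 0.0f;                 // ((float *) tensor->op_params)[0] and [1]
    std::memcpy(&scale,    op_params,                 sizeof(float));
    std::memcpy(&max_bias, op_params + sizeof(float), sizeof(float));
    std::printf("scale=%g max_bias=%g\n", scale, max_bias);
    return 0;
}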
@@ -5727,6 +6244,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5727
6244
  if (src1 != nullptr) {
5728
6245
  free(src1_buffer);
5729
6246
  }
6247
+ if (src2 != nullptr) {
6248
+ free(src2_buffer);
6249
+ }
5730
6250
 
5731
6251
  ggml_free(ggml_ctx);
5732
6252
  }
@@ -5753,11 +6273,12 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_
5753
6273
 
5754
6274
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
5755
6275
 
5756
- if (extra->offset + tensor_size >= extra->buffer_gpu->size) {
5757
- tensor_size = extra->buffer_gpu->size - (extra->offset);
6276
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
6277
+ if (extra->offset + tensor_size >= buffer_gpu->size) {
6278
+ tensor_size = buffer_gpu->size - (extra->offset);
5758
6279
  }
5759
6280
 
5760
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, extra->offset, tensor_data, tensor_size);
6281
+ ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset, tensor_data, tensor_size);
5761
6282
  }
5762
6283
 
5763
6284
  float first_error_result = -1.0f;