llama_cpp 0.12.7 → 0.14.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -69,6 +69,33 @@ struct vk_queue {
69
69
  vk::PipelineStageFlags stage_flags;
70
70
  };
71
71
 
72
+ struct vk_pipeline_struct {
73
+ std::string name;
74
+ vk::ShaderModule shader_module;
75
+ vk::DescriptorSetLayout dsl;
76
+ std::vector<vk::DescriptorPool> descriptor_pools;
77
+ std::vector<vk::DescriptorSet> descriptor_sets;
78
+ uint32_t descriptor_set_idx;
79
+ vk::PipelineLayout layout;
80
+ vk::Pipeline pipeline;
81
+ uint32_t push_constant_size;
82
+ uint32_t parameter_count;
83
+ std::array<uint32_t, 3> wg_denoms;
84
+ uint32_t align;
85
+ };
86
+
87
+ typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
88
+ typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
89
+
90
+ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
91
+
92
+ struct vk_matmul_pipeline_struct {
93
+ vk_pipeline l, m, s;
94
+ vk_pipeline a_l, a_m, a_s;
95
+ };
96
+
97
+ typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
98
+
72
99
  struct vk_device {
73
100
  vk::PhysicalDevice physical_device;
74
101
  vk::PhysicalDeviceProperties properties;
@@ -84,10 +111,61 @@ struct vk_device {
84
111
  uint32_t subgroup_size;
85
112
  bool uma;
86
113
 
114
+ bool initialized;
115
+ size_t idx;
116
+
117
+ vk_matmul_pipeline pipeline_matmul_f32;
118
+ vk_matmul_pipeline pipeline_matmul_f16;
119
+ vk_matmul_pipeline pipeline_matmul_f16_f32;
120
+ vk_pipeline pipeline_matmul_split_k_reduce;
121
+
122
+ vk_matmul_pipeline pipeline_dequant_mul_mat_mat[VK_NUM_TYPES];
123
+
124
+ vk_pipeline pipeline_dequant[VK_NUM_TYPES];
125
+ vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
126
+
127
+ vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
128
+ vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
129
+ vk_pipeline pipeline_get_rows[VK_NUM_TYPES];
130
+ vk_pipeline pipeline_get_rows_f32[VK_NUM_TYPES];
131
+ vk_pipeline pipeline_mul_f32;
132
+ vk_pipeline pipeline_add_f32;
133
+ vk_pipeline pipeline_scale_f32;
134
+ vk_pipeline pipeline_sqr_f32;
135
+ vk_pipeline pipeline_clamp_f32;
136
+ vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
137
+ vk_pipeline pipeline_norm_f32;
138
+ vk_pipeline pipeline_rms_norm_f32;
139
+ vk_pipeline pipeline_gelu_f32;
140
+ vk_pipeline pipeline_silu_f32;
141
+ vk_pipeline pipeline_relu_f32;
142
+ vk_pipeline pipeline_diag_mask_inf_f32;
143
+ vk_pipeline pipeline_soft_max_f32;
144
+ vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
145
+ vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
146
+ vk_pipeline pipeline_argsort_f32;
147
+
148
+ std::vector<vk_pipeline_ref> pipelines;
149
+
87
150
  ~vk_device() {
88
151
  #ifdef GGML_VULKAN_DEBUG
89
152
  std::cerr << "destroy device " << name << std::endl;
90
153
  #endif
154
+ device.destroyCommandPool(compute_queue.pool);
155
+ if (!single_queue) {
156
+ device.destroyCommandPool(transfer_queue.pool);
157
+ }
158
+
159
+ for (auto& pipeline : pipelines) {
160
+ if (pipeline.expired()) {
161
+ continue;
162
+ }
163
+
164
+ vk_pipeline pl = pipeline.lock();
165
+ ggml_vk_destroy_pipeline(device, pl);
166
+ }
167
+ pipelines.clear();
168
+
91
169
  device.destroy();
92
170
  }
93
171
  };
@@ -125,21 +203,6 @@ struct vk_subbuffer {
125
203
  uint64_t size;
126
204
  };
127
205
 
128
- struct vk_pipeline {
129
- std::string name;
130
- vk::ShaderModule shader_module;
131
- vk::DescriptorSetLayout dsl;
132
- std::vector<vk::DescriptorPool> descriptor_pools;
133
- std::vector<vk::DescriptorSet> descriptor_sets;
134
- uint32_t descriptor_set_idx;
135
- vk::PipelineLayout layout;
136
- vk::Pipeline pipeline;
137
- uint32_t push_constant_size;
138
- uint32_t parameter_count;
139
- std::array<uint32_t, 3> wg_denoms;
140
- uint32_t align;
141
- };
142
-
143
206
  struct vk_semaphore {
144
207
  vk::Semaphore s;
145
208
  uint64_t value;
@@ -160,11 +223,21 @@ struct vk_op_push_constants {
160
223
  float param2;
161
224
  };
162
225
 
163
- struct vk_op_cpy_push_constants {
226
+ struct vk_op_unary_push_constants {
227
+ uint32_t ne;
228
+ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
229
+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
230
+ uint32_t d_offset;
231
+ float param1; float param2;
232
+ };
233
+
234
+ struct vk_op_binary_push_constants {
164
235
  uint32_t ne;
165
- uint32_t ne00; uint32_t ne01; uint32_t nb00; uint32_t nb01; uint32_t nb02;
166
- uint32_t ne10; uint32_t ne11; uint32_t nb10; uint32_t nb11; uint32_t nb12;
236
+ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
237
+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
238
+ uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
167
239
  uint32_t d_offset;
240
+ float param1; float param2;
168
241
  };
169
242
 
170
243
  struct vk_op_diag_mask_push_constants {
@@ -196,6 +269,22 @@ struct vk_op_rope_neox_push_constants {
196
269
  float inv_ndims;
197
270
  };
198
271
 
272
+ struct vk_op_soft_max_push_constants {
273
+ uint32_t KX;
274
+ uint32_t KY;
275
+ uint32_t KZ;
276
+ float scale;
277
+ float max_bias;
278
+ float m0;
279
+ float m1;
280
+ uint32_t n_head_log2;
281
+ };
282
+
283
+ struct vk_op_argsort_push_constants {
284
+ uint32_t ncols;
285
+ bool ascending;
286
+ };
287
+
199
288
  // Allow pre-recording command buffers
200
289
  struct vk_staging_memcpy {
201
290
  vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -236,7 +325,6 @@ struct ggml_tensor_extra_gpu {
236
325
  };
237
326
 
238
327
  struct ggml_vk_garbage_collector {
239
- std::vector<vk_pipeline *> pipelines;
240
328
  std::vector<vk_semaphore> tl_semaphores;
241
329
  std::vector<vk_semaphore> semaphores;
242
330
  std::vector<vk::Event> events;
@@ -247,35 +335,7 @@ struct ggml_vk_garbage_collector {
247
335
  struct ggml_backend_vk_context {
248
336
  std::string name;
249
337
 
250
- std::weak_ptr<vk_device> device;
251
- vk_pipeline pipeline_matmul_f32_l, pipeline_matmul_f32_m, pipeline_matmul_f32_s;
252
- vk_pipeline pipeline_matmul_f32_aligned_l, pipeline_matmul_f32_aligned_m, pipeline_matmul_f32_aligned_s;
253
- vk_pipeline pipeline_matmul_f16_l, pipeline_matmul_f16_m, pipeline_matmul_f16_s;
254
- vk_pipeline pipeline_matmul_f16_aligned_l, pipeline_matmul_f16_aligned_m, pipeline_matmul_f16_aligned_s;
255
- vk_pipeline pipeline_matmul_f16_f32_l, pipeline_matmul_f16_f32_m, pipeline_matmul_f16_f32_s;
256
- vk_pipeline pipeline_matmul_f16_f32_aligned_l, pipeline_matmul_f16_f32_aligned_m, pipeline_matmul_f16_f32_aligned_s;
257
- vk_pipeline pipeline_matmul_split_k_reduce;
258
- vk_pipeline pipeline_dequant[VK_NUM_TYPES];
259
- vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
260
- vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
261
- vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
262
- vk_pipeline pipeline_get_rows[VK_NUM_TYPES];
263
- vk_pipeline pipeline_get_rows_f32[VK_NUM_TYPES];
264
- vk_pipeline pipeline_mul_f32;
265
- vk_pipeline pipeline_add_f32;
266
- vk_pipeline pipeline_scale_f32;
267
- vk_pipeline pipeline_sqr_f32;
268
- vk_pipeline pipeline_clamp_f32;
269
- vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
270
- vk_pipeline pipeline_norm_f32;
271
- vk_pipeline pipeline_rms_norm_f32;
272
- vk_pipeline pipeline_gelu_f32;
273
- vk_pipeline pipeline_silu_f32;
274
- vk_pipeline pipeline_relu_f32;
275
- vk_pipeline pipeline_diag_mask_inf_f32;
276
- vk_pipeline pipeline_soft_max_f32;
277
- vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
278
- vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
338
+ std::shared_ptr<vk_device> device;
279
339
 
280
340
  size_t semaphore_idx, event_idx;
281
341
  ggml_vk_garbage_collector gc;
@@ -304,13 +364,31 @@ struct vk_instance {
304
364
 
305
365
  std::vector<size_t> device_indices;
306
366
 
307
- std::shared_ptr<vk_device> devices[GGML_VK_MAX_DEVICES];
308
367
  ggml_backend_t backends[GGML_VK_MAX_DEVICES];
309
368
  ggml_backend_vk_context contexts[GGML_VK_MAX_DEVICES];
310
369
  ggml_backend_buffer_type buffer_types[GGML_VK_MAX_DEVICES];
311
370
  bool initialized[GGML_VK_MAX_DEVICES];
312
371
  };
313
372
 
373
+ static std::shared_ptr<vk_device> ggml_vk_get_device(size_t idx) {
374
+ #ifdef GGML_VULKAN_DEBUG
375
+ std::cerr << "ggml_vk_get_device(" << idx << ")" << std::endl;
376
+ #endif
377
+ static std::weak_ptr<vk_device> devices[GGML_VK_MAX_DEVICES];
378
+
379
+ if (devices[idx].expired()) {
380
+ #ifdef GGML_VULKAN_DEBUG
381
+ std::cerr << "Initializing new vk_device" << std::endl;
382
+ #endif
383
+ std::shared_ptr<vk_device> device = std::make_shared<vk_device>();
384
+ device->initialized = false;
385
+ devices[idx] = device;
386
+ return device;
387
+ }
388
+
389
+ return devices[idx].lock();
390
+ }
391
+
314
392
  #ifdef GGML_VULKAN_CHECK_RESULTS
315
393
  static size_t vk_skip_checks;
316
394
  static size_t vk_output_tensor;
@@ -334,14 +412,15 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
334
412
  GGML_ASSERT(parameter_count > 0);
335
413
  GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
336
414
 
337
- pipeline.name = name;
338
- pipeline.parameter_count = parameter_count;
339
- pipeline.push_constant_size = push_constant_size;
340
- pipeline.wg_denoms = wg_denoms;
341
- pipeline.align = align;
415
+ pipeline = std::make_shared<vk_pipeline_struct>();
416
+ pipeline->name = name;
417
+ pipeline->parameter_count = parameter_count;
418
+ pipeline->push_constant_size = push_constant_size;
419
+ pipeline->wg_denoms = wg_denoms;
420
+ pipeline->align = align;
342
421
 
343
422
  vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
344
- pipeline.shader_module = ctx->device.lock()->device.createShaderModule(shader_module_create_info);
423
+ pipeline->shader_module = ctx->device->device.createShaderModule(shader_module_create_info);
345
424
 
346
425
  std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
347
426
  std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
@@ -355,49 +434,49 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
355
434
  vk::PushConstantRange pcr(
356
435
  vk::ShaderStageFlagBits::eCompute,
357
436
  0,
358
- pipeline.push_constant_size
437
+ pipeline->push_constant_size
359
438
  );
360
439
 
361
440
  vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
362
441
  {},
363
442
  dsl_binding);
364
443
  descriptor_set_layout_create_info.setPNext(&dslbfci);
365
- pipeline.dsl = ctx->device.lock()->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
444
+ pipeline->dsl = ctx->device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
366
445
 
367
446
  // Check if device supports multiple descriptors per pool
368
- if (ctx->device.lock()->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) {
447
+ if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) {
369
448
  const uint32_t alloc_count = 2;
370
449
 
371
450
  // Try allocating multiple sets from one pool
372
451
  // This fails on AMD for some reason, so add a fall back to allocating one pool per set
373
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
452
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
374
453
  vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, alloc_count, descriptor_pool_size);
375
- vk::DescriptorPool pool = ctx->device.lock()->device.createDescriptorPool(descriptor_pool_create_info);
454
+ vk::DescriptorPool pool = ctx->device->device.createDescriptorPool(descriptor_pool_create_info);
376
455
 
377
456
  std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
378
457
  for (uint32_t i = 0; i < alloc_count; i++) {
379
- layouts[i] = pipeline.dsl;
458
+ layouts[i] = pipeline->dsl;
380
459
  }
381
460
  try {
382
461
  vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pool, alloc_count, layouts.data());
383
- std::vector<vk::DescriptorSet> sets = ctx->device.lock()->device.allocateDescriptorSets(descriptor_set_alloc_info);
462
+ std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
384
463
  } catch(vk::OutOfPoolMemoryError const&) {
385
- ctx->device.lock()->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE;
464
+ ctx->device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE;
386
465
  }
387
466
 
388
- ctx->device.lock()->device.destroyDescriptorPool(pool);
467
+ ctx->device->device.destroyDescriptorPool(pool);
389
468
  }
390
469
 
391
- if (ctx->device.lock()->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
392
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
470
+ if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
471
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
393
472
  vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 128, descriptor_pool_size);
394
- pipeline.descriptor_pools.push_back(ctx->device.lock()->device.createDescriptorPool(descriptor_pool_create_info));
473
+ pipeline->descriptor_pools.push_back(ctx->device->device.createDescriptorPool(descriptor_pool_create_info));
395
474
  }
396
475
 
397
- pipeline.descriptor_set_idx = 0;
476
+ pipeline->descriptor_set_idx = 0;
398
477
 
399
- vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline.dsl, pcr);
400
- pipeline.layout = ctx->device.lock()->device.createPipelineLayout(pipeline_layout_create_info);
478
+ vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr);
479
+ pipeline->layout = ctx->device->device.createPipelineLayout(pipeline_layout_create_info);
401
480
 
402
481
  std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());
403
482
 
@@ -417,72 +496,75 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
417
496
  vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
418
497
  vk::PipelineShaderStageCreateFlags(),
419
498
  vk::ShaderStageFlagBits::eCompute,
420
- pipeline.shader_module,
499
+ pipeline->shader_module,
421
500
  entrypoint.c_str(),
422
501
  &specialization_info);
423
502
  vk::ComputePipelineCreateInfo compute_pipeline_create_info(
424
503
  vk::PipelineCreateFlags(),
425
504
  pipeline_shader_create_info,
426
- pipeline.layout);
427
- pipeline.pipeline = ctx->device.lock()->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
505
+ pipeline->layout);
506
+ pipeline->pipeline = ctx->device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
428
507
 
429
- ctx->gc.pipelines.push_back(&pipeline);
508
+ ctx->device->pipelines.push_back(pipeline);
430
509
  }
431
510
 
432
- static void ggml_vk_destroy_pipeline(ggml_backend_vk_context * ctx, vk_pipeline * pipeline) {
511
+ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
512
+ #ifdef GGML_VULKAN_DEBUG
513
+ std::cerr << "ggml_pipeline_destroy_pipeline(" << pipeline->name << ")" << std::endl;
514
+ #endif
433
515
  for (auto& pool : pipeline->descriptor_pools) {
434
- ctx->device.lock()->device.destroyDescriptorPool(pool);
516
+ device.destroyDescriptorPool(pool);
435
517
  }
436
518
  pipeline->descriptor_pools.clear();
437
519
  pipeline->descriptor_sets.clear();
438
520
  pipeline->descriptor_set_idx = 0;
439
521
 
440
- ctx->device.lock()->device.destroyDescriptorSetLayout(pipeline->dsl);
522
+ device.destroyDescriptorSetLayout(pipeline->dsl);
441
523
 
442
- ctx->device.lock()->device.destroyPipelineLayout(pipeline->layout);
524
+ device.destroyPipelineLayout(pipeline->layout);
443
525
 
444
- ctx->device.lock()->device.destroyShaderModule(pipeline->shader_module);
526
+ device.destroyShaderModule(pipeline->shader_module);
445
527
 
446
- ctx->device.lock()->device.destroyPipeline(pipeline->pipeline);
528
+ device.destroyPipeline(pipeline->pipeline);
447
529
  }
448
530
 
449
531
  static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx, vk_pipeline& pipeline, uint32_t n) {
450
532
  #ifdef GGML_VULKAN_DEBUG
451
- std::cerr << "ggml_pipeline_allocate_descriptor_sets(" << pipeline.name << ", " << n << ")" << std::endl;
533
+ std::cerr << "ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")" << std::endl;
452
534
  #endif
453
- if (pipeline.descriptor_sets.size() >= pipeline.descriptor_set_idx + n) {
535
+ if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) {
454
536
  // Enough descriptors are available
455
537
  return;
456
538
  }
457
539
 
458
- if (ctx->device.lock()->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
459
- const uint32_t alloc_count = pipeline.descriptor_set_idx + n - pipeline.descriptor_sets.size();
540
+ if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
541
+ const uint32_t alloc_count = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size();
460
542
 
461
543
  std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
462
544
  for (uint32_t i = 0; i < alloc_count; i++) {
463
- layouts[i] = pipeline.dsl;
545
+ layouts[i] = pipeline->dsl;
464
546
  }
465
- vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline.descriptor_pools[0], alloc_count, layouts.data());
466
- std::vector<vk::DescriptorSet> sets = ctx->device.lock()->device.allocateDescriptorSets(descriptor_set_alloc_info);
467
- pipeline.descriptor_sets.insert(pipeline.descriptor_sets.end(), sets.begin(), sets.end());
547
+ vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[0], alloc_count, layouts.data());
548
+ std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
549
+ pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end());
468
550
  } else {
469
- for (uint32_t i = pipeline.descriptor_sets.size(); i < pipeline.descriptor_set_idx + n; i++) {
470
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
551
+ for (uint32_t i = pipeline->descriptor_sets.size(); i < pipeline->descriptor_set_idx + n; i++) {
552
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
471
553
  vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 1, descriptor_pool_size);
472
- pipeline.descriptor_pools.push_back(ctx->device.lock()->device.createDescriptorPool(descriptor_pool_create_info));
554
+ pipeline->descriptor_pools.push_back(ctx->device->device.createDescriptorPool(descriptor_pool_create_info));
473
555
 
474
- vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline.descriptor_pools[i], 1, &pipeline.dsl);
475
- std::vector<vk::DescriptorSet> sets = ctx->device.lock()->device.allocateDescriptorSets(descriptor_set_alloc_info);
476
- pipeline.descriptor_sets.push_back(sets[0]);
556
+ vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[i], 1, &pipeline->dsl);
557
+ std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
558
+ pipeline->descriptor_sets.push_back(sets[0]);
477
559
  }
478
560
  }
479
561
  }
480
562
 
481
563
  static void ggml_pipeline_cleanup(vk_pipeline& pipeline) {
482
564
  #ifdef GGML_VULKAN_DEBUG
483
- std::cerr << "ggml_pipeline_cleanup(" << pipeline.name << ")" << std::endl;
565
+ std::cerr << "ggml_pipeline_cleanup(" << pipeline->name << ")" << std::endl;
484
566
  #endif
485
- pipeline.descriptor_set_idx = 0;
567
+ pipeline->descriptor_set_idx = 0;
486
568
  }
487
569
 
488
570
  static vk::CommandBuffer ggml_vk_create_cmd_buffer(ggml_backend_vk_context * ctx, vk_queue& q) {
@@ -498,7 +580,7 @@ static vk::CommandBuffer ggml_vk_create_cmd_buffer(ggml_backend_vk_context * ctx
498
580
  q.pool,
499
581
  vk::CommandBufferLevel::ePrimary,
500
582
  1);
501
- const std::vector<vk::CommandBuffer> cmd_buffers = ctx->device.lock()->device.allocateCommandBuffers(command_buffer_alloc_info);
583
+ const std::vector<vk::CommandBuffer> cmd_buffers = ctx->device->device.allocateCommandBuffers(command_buffer_alloc_info);
502
584
  auto buf = cmd_buffers.front();
503
585
 
504
586
  q.cmd_buffers.push_back(buf);
@@ -643,11 +725,11 @@ static void ggml_vk_create_queue(ggml_backend_vk_context * ctx, vk_queue& q, uin
643
725
  q.queue_family_index = queue_family_index;
644
726
 
645
727
  vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index);
646
- q.pool = ctx->device.lock()->device.createCommandPool(command_pool_create_info_compute);
728
+ q.pool = ctx->device->device.createCommandPool(command_pool_create_info_compute);
647
729
 
648
730
  q.cmd_buffer_idx = 0;
649
731
 
650
- q.queue = ctx->device.lock()->device.getQueue(queue_family_index, queue_index);
732
+ q.queue = ctx->device->device.getQueue(queue_family_index, queue_index);
651
733
 
652
734
  q.stage_flags = stage_flags;
653
735
  }
@@ -671,7 +753,7 @@ static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context *
671
753
  vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
672
754
  vk::SemaphoreCreateInfo ci{};
673
755
  ci.setPNext(&tci);
674
- vk::Semaphore semaphore = ctx->device.lock()->device.createSemaphore(ci);
756
+ vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
675
757
  ctx->gc.semaphores.push_back({ semaphore, 0 });
676
758
  return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];
677
759
  }
@@ -684,7 +766,7 @@ static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context
684
766
  vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
685
767
  vk::SemaphoreCreateInfo ci{};
686
768
  ci.setPNext(&tci);
687
- vk::Semaphore semaphore = ctx->device.lock()->device.createSemaphore(ci);
769
+ vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
688
770
  ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
689
771
  }
690
772
  return &ctx->gc.tl_semaphores[ctx->semaphore_idx++];
@@ -692,7 +774,7 @@ static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context
692
774
 
693
775
  static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
694
776
  if (ctx->event_idx >= ctx->gc.events.size()) {
695
- ctx->gc.events.push_back(ctx->device.lock()->device.createEvent({}));
777
+ ctx->gc.events.push_back(ctx->device->device.createEvent({}));
696
778
  }
697
779
  return ctx->gc.events[ctx->event_idx++];
698
780
  }
@@ -703,7 +785,7 @@ static void ggml_vk_queue_cleanup(ggml_backend_vk_context * ctx, vk_queue& q) {
703
785
  #endif
704
786
  // Requires command buffers to be done
705
787
 
706
- ctx->device.lock()->device.resetCommandPool(q.pool);
788
+ ctx->device->device.resetCommandPool(q.pool);
707
789
  q.cmd_buffer_idx = 0;
708
790
  }
709
791
 
@@ -740,11 +822,11 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
740
822
  nullptr,
741
823
  };
742
824
 
743
- buf->buffer = ctx->device.lock()->device.createBuffer(buffer_create_info);
825
+ buf->buffer = ctx->device->device.createBuffer(buffer_create_info);
744
826
 
745
- vk::MemoryRequirements mem_req = ctx->device.lock()->device.getBufferMemoryRequirements(buf->buffer);
827
+ vk::MemoryRequirements mem_req = ctx->device->device.getBufferMemoryRequirements(buf->buffer);
746
828
 
747
- vk::PhysicalDeviceMemoryProperties mem_props = ctx->device.lock()->physical_device.getMemoryProperties();
829
+ vk::PhysicalDeviceMemoryProperties mem_props = ctx->device->physical_device.getMemoryProperties();
748
830
 
749
831
  uint32_t memory_type_index = UINT32_MAX;
750
832
 
@@ -757,30 +839,30 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
757
839
  }
758
840
 
759
841
  if (memory_type_index == UINT32_MAX) {
760
- ctx->device.lock()->device.destroyBuffer(buf->buffer);
842
+ ctx->device->device.destroyBuffer(buf->buffer);
761
843
  buf->size = 0;
762
844
  throw vk::OutOfDeviceMemoryError("No suitable memory type found");
763
845
  }
764
846
 
765
847
  try {
766
- buf->device_memory = ctx->device.lock()->device.allocateMemory({ mem_req.size, memory_type_index });
848
+ buf->device_memory = ctx->device->device.allocateMemory({ mem_req.size, memory_type_index });
767
849
  } catch (const vk::SystemError& e) {
768
850
  // Out of Host/Device memory, clean up buffer
769
- ctx->device.lock()->device.destroyBuffer(buf->buffer);
851
+ ctx->device->device.destroyBuffer(buf->buffer);
770
852
  buf->size = 0;
771
853
  throw e;
772
854
  }
773
855
  buf->ptr = nullptr;
774
856
 
775
857
  if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
776
- buf->ptr = ctx->device.lock()->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
858
+ buf->ptr = ctx->device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
777
859
  }
778
860
 
779
- ctx->device.lock()->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
861
+ ctx->device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
780
862
 
781
863
  buf->ctx = ctx;
782
864
 
783
- buf->device = ctx->device.lock();
865
+ buf->device = ctx->device;
784
866
 
785
867
  #ifdef GGML_VULKAN_DEBUG
786
868
  std::cerr << "Created buffer " << buf->buffer << std::endl;
@@ -802,7 +884,7 @@ static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size
802
884
  static vk_buffer ggml_vk_create_buffer_device(ggml_backend_vk_context * ctx, size_t size) {
803
885
  vk_buffer buf;
804
886
  try {
805
- if (ctx->device.lock()->uma) {
887
+ if (ctx->device->uma) {
806
888
  // Fall back to host memory type
807
889
  buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
808
890
  } else {
@@ -883,10 +965,16 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
883
965
  std::cerr << "ggml_vk_load_shaders(" << ctx->name << ")" << std::endl;
884
966
  #endif
885
967
 
968
+ const std::shared_ptr<vk_device> device = ctx->device;
969
+
886
970
  // mulmat
887
- std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, ctx->device.lock()->subgroup_size * 2, 64, 2, 4, 4, ctx->device.lock()->subgroup_size };
888
- std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, ctx->device.lock()->subgroup_size, 32, 2, 4, 2, ctx->device.lock()->subgroup_size };
889
- std::initializer_list<uint32_t> warptile_s = { ctx->device.lock()->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, ctx->device.lock()->subgroup_size };
971
+ std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
972
+ std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
973
+ std::initializer_list<uint32_t> warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
974
+
975
+ std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
976
+ std::initializer_list<uint32_t> warptile_mmq_m = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
977
+ std::initializer_list<uint32_t> warptile_mmq_s = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
890
978
 
891
979
  std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
892
980
  std::array<uint32_t, 3> m_wg_denoms = { 64, 64, 1 };
@@ -896,126 +984,206 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
896
984
  uint32_t m_align = 64;
897
985
  uint32_t s_align = 32;
898
986
 
899
- if (ctx->device.lock()->fp16) {
900
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_l_len, matmul_f32_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
901
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_m_len, matmul_f32_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
902
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_s, "matmul_f32_s", matmul_f32_s_len, matmul_f32_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
903
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_l, "matmul_f32_aligned_l", matmul_f32_aligned_l_len, matmul_f32_aligned_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
904
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_m, "matmul_f32_aligned_m", matmul_f32_aligned_m_len, matmul_f32_aligned_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
905
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_s, "matmul_f32_aligned_s", matmul_f32_aligned_s_len, matmul_f32_aligned_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
906
-
907
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_l, "matmul_f16_l", matmul_f16_l_len, matmul_f16_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
908
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_m, "matmul_f16_m", matmul_f16_m_len, matmul_f16_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
909
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_s, "matmul_f16_s", matmul_f16_s_len, matmul_f16_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
910
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_l, "matmul_f16_aligned_l", matmul_f16_aligned_l_len, matmul_f16_aligned_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
911
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_m, "matmul_f16_aligned_m", matmul_f16_aligned_m_len, matmul_f16_aligned_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
912
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_s, "matmul_f16_aligned_s", matmul_f16_aligned_s_len, matmul_f16_aligned_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
913
-
914
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_l, "matmul_f16_f32_l", matmul_f16_f32_l_len, matmul_f16_f32_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
915
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_m, "matmul_f16_f32_m", matmul_f16_f32_m_len, matmul_f16_f32_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
916
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_s, "matmul_f16_f32_s", matmul_f16_f32_s_len, matmul_f16_f32_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
917
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_l_len, matmul_f16_f32_aligned_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
918
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_m_len, matmul_f16_f32_aligned_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
919
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_s_len, matmul_f16_f32_aligned_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
987
+ ctx->device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
988
+ ctx->device->pipeline_matmul_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
989
+ ctx->device->pipeline_matmul_f16 = std::make_shared<vk_matmul_pipeline_struct>();
990
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
991
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
992
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
993
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
994
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
995
+
996
+ if (device->fp16) {
997
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
998
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
999
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1000
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1001
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1002
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1003
+
1004
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1005
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1006
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1007
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1008
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1009
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1010
+
1011
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1012
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1013
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1014
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1015
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1016
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1017
+
1018
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1019
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1020
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1021
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1022
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1023
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1024
+
1025
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_0_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1026
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_0_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1027
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_0_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1028
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1029
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1030
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1031
+
1032
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1033
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1034
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1035
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1036
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1037
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1038
+
1039
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1040
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1041
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1042
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1043
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1044
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1045
+
1046
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1047
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1048
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1049
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1050
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1051
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
920
1052
  } else {
921
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_l_fp32_len, matmul_f32_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
922
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_m_fp32_len, matmul_f32_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
923
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_s, "matmul_f32_s", matmul_f32_s_fp32_len, matmul_f32_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
924
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_l, "matmul_f32_aligned_l", matmul_f32_aligned_l_fp32_len, matmul_f32_aligned_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
925
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_m, "matmul_f32_aligned_m", matmul_f32_aligned_m_fp32_len, matmul_f32_aligned_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
926
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_s, "matmul_f32_aligned_s", matmul_f32_aligned_s_fp32_len, matmul_f32_aligned_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
927
-
928
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_l, "matmul_f16_l", matmul_f16_l_fp32_len, matmul_f16_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
929
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_m, "matmul_f16_m", matmul_f16_m_fp32_len, matmul_f16_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
930
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_s, "matmul_f16_s", matmul_f16_s_fp32_len, matmul_f16_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
931
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_l, "matmul_f16_aligned_l", matmul_f16_aligned_l_fp32_len, matmul_f16_aligned_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
932
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_m, "matmul_f16_aligned_m", matmul_f16_aligned_m_fp32_len, matmul_f16_aligned_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
933
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_s, "matmul_f16_aligned_s", matmul_f16_aligned_s_fp32_len, matmul_f16_aligned_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
934
-
935
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_l, "matmul_f16_f32_l", matmul_f16_f32_l_fp32_len, matmul_f16_f32_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
936
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_m, "matmul_f16_f32_m", matmul_f16_f32_m_fp32_len, matmul_f16_f32_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
937
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_s, "matmul_f16_f32_s", matmul_f16_f32_s_fp32_len, matmul_f16_f32_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
938
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_l_fp32_len, matmul_f16_f32_aligned_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
939
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_m_fp32_len, matmul_f16_f32_aligned_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
940
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_s_fp32_len, matmul_f16_f32_aligned_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
941
- }
942
-
943
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32", mul_mat_vec_f16_f32_len, mul_mat_vec_f16_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
944
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32", mul_mat_vec_q4_0_f32_len, mul_mat_vec_q4_0_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
945
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32", mul_mat_vec_q4_1_f32_len, mul_mat_vec_q4_1_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
946
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32", mul_mat_vec_q5_0_f32_len, mul_mat_vec_q5_0_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
947
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32", mul_mat_vec_q5_1_f32_len, mul_mat_vec_q5_1_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
948
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32", mul_mat_vec_q8_0_f32_len, mul_mat_vec_q8_0_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
949
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_K_f32", mul_mat_vec_q2_K_f32_len, mul_mat_vec_q2_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
950
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_K_f32", mul_mat_vec_q3_K_f32_len, mul_mat_vec_q3_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
951
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_K_f32", mul_mat_vec_q4_K_f32_len, mul_mat_vec_q4_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
952
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f32", mul_mat_vec_q5_K_f32_len, mul_mat_vec_q5_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
953
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
1053
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1054
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1055
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1056
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1057
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1058
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1059
+
1060
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1061
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1062
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1063
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1064
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1065
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1066
+
1067
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1068
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1069
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1070
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1071
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1072
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1073
+
1074
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1075
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1076
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1077
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1078
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1079
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1080
+
1081
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1082
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1083
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1084
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1085
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1086
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1087
+
1088
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1089
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1090
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1091
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1092
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1093
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1094
+
1095
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1096
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1097
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1098
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1099
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1100
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1101
+
1102
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1103
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1104
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1105
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1106
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1107
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1108
+ }
1109
+
1110
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32", mul_mat_vec_f16_f32_len, mul_mat_vec_f16_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1111
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32", mul_mat_vec_q4_0_f32_len, mul_mat_vec_q4_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1112
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32", mul_mat_vec_q4_1_f32_len, mul_mat_vec_q4_1_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1113
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32", mul_mat_vec_q5_0_f32_len, mul_mat_vec_q5_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1114
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32", mul_mat_vec_q5_1_f32_len, mul_mat_vec_q5_1_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1115
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32", mul_mat_vec_q8_0_f32_len, mul_mat_vec_q8_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1116
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_K_f32", mul_mat_vec_q2_K_f32_len, mul_mat_vec_q2_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1117
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_K_f32", mul_mat_vec_q3_K_f32_len, mul_mat_vec_q3_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1118
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_K_f32", mul_mat_vec_q4_K_f32_len, mul_mat_vec_q4_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1119
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f32", mul_mat_vec_q5_K_f32_len, mul_mat_vec_q5_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1120
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
954
1121
 
955
1122
  // dequant shaders
956
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", f32_to_f16_len, f32_to_f16_data, "main", 2, 4 * sizeof(int), { 64, 1, 1}, {}, 1);
957
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F16 ], "dequant_f16", dequant_f16_len, dequant_f16_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
958
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
959
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
960
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
961
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
962
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
963
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
964
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
965
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
966
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_K", dequant_q5_K_len, dequant_q5_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
967
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_K", dequant_q6_K_len, dequant_q6_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
1123
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1124
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1125
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1126
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1127
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1128
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1129
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
1130
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
1131
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
1132
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_K", dequant_q5_K_len, dequant_q5_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
1133
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_K", dequant_q6_K_len, dequant_q6_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
968
1134
 
969
1135
  // get_rows
970
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
971
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
972
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
973
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
974
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
975
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1136
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1137
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1138
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1139
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1140
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1141
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
976
1142
 
977
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
978
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
979
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
980
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
981
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
982
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1143
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1144
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1145
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1146
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1147
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1148
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
983
1149
 
984
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
1150
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
985
1151
 
986
- ggml_vk_create_pipeline(ctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
987
- ggml_vk_create_pipeline(ctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
1152
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
1153
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
988
1154
 
989
- ggml_vk_create_pipeline(ctx, ctx->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
990
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
1155
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
1156
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
991
1157
 
992
- ggml_vk_create_pipeline(ctx, ctx->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_cpy_push_constants), {512, 1, 1}, {}, 1);
993
- ggml_vk_create_pipeline(ctx, ctx->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_cpy_push_constants), {512, 1, 1}, {}, 1);
994
- ggml_vk_create_pipeline(ctx, ctx->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_cpy_push_constants), {512, 1, 1}, {}, 1);
1158
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1159
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1160
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
995
1161
 
996
- ggml_vk_create_pipeline(ctx, ctx->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1162
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
997
1163
 
998
- ggml_vk_create_pipeline(ctx, ctx->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1164
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
999
1165
 
1000
- ggml_vk_create_pipeline(ctx, ctx->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1166
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1001
1167
 
1002
- ggml_vk_create_pipeline(ctx, ctx->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1168
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1003
1169
 
1004
- ggml_vk_create_pipeline(ctx, ctx->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1170
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1005
1171
 
1006
- ggml_vk_create_pipeline(ctx, ctx->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1007
- ggml_vk_create_pipeline(ctx, ctx->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1008
- ggml_vk_create_pipeline(ctx, ctx->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1172
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1173
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1174
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1009
1175
 
1010
- ggml_vk_create_pipeline(ctx, ctx->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
1176
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
1011
1177
 
1012
- ggml_vk_create_pipeline(ctx, ctx->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
1178
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
1013
1179
 
1014
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1015
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1180
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1181
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1016
1182
 
1017
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
1018
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
1183
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
1184
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
1185
+
1186
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
1019
1187
  }
1020
1188
 
1021
1189
  static void ggml_vk_print_gpu_info(size_t idx) {
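Every creation call in the hunk above follows one shape: the handle being filled now lives on ctx->device rather than on the backend context and is passed by reference as the second argument, followed by a debug name, the embedded SPIR-V blob, the entry point, the binding count, the push-constant size, the workgroup denominators, the specialization constants and an alignment. A hedged reconstruction of the helper's parameter list, inferred only from these call sites (the actual declaration sits earlier in ggml-vulkan.cpp and may differ in exact types):

    // Inferred shape only, not the file's real declaration.
    static void ggml_vk_create_pipeline(
        ggml_backend_vk_context *    ctx,
        vk_pipeline&                 pipeline,                  // device-owned handle to fill, e.g. ctx->device->pipeline_add_f32
        const std::string&           name,                      // debug name, e.g. "matmul_q4_0_f32_aligned_l"
        size_t                       spv_len,                   // length of the embedded SPIR-V blob (..._len)
        const void *                 spv_data,                  // and its data pointer (..._data)
        const char *                 entrypoint,                // "main" in every call above
        uint32_t                     parameter_count,           // number of storage-buffer bindings (2 to 4 above)
        uint32_t                     push_constant_size,        // e.g. 14 * sizeof(uint32_t) for the matmul shaders
        std::array<uint32_t, 3>      wg_denoms,                 // per-axis divisors used later to size dispatches
        const std::vector<uint32_t>& specialization_constants,  // warptile_* vectors, { device->subgroup_size }, or {}
        uint32_t                     align);                    // minimum alignment served by the "aligned" variants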
@@ -1057,8 +1225,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
1057
1225
  }
1058
1226
  }
1059
1227
 
1060
- const char* GGML_VULKAN_DISABLE_F16 = getenv("GGML_VULKAN_DISABLE_F16");
1061
- bool force_disable_f16 = GGML_VULKAN_DISABLE_F16 != nullptr;
1228
+ const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
1229
+ bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
1062
1230
 
1063
1231
  bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
1064
1232
 
@@ -1106,7 +1274,9 @@ void ggml_vk_instance_init() {
1106
1274
 
1107
1275
  const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
1108
1276
  const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
1277
+ #ifdef __APPLE__
1109
1278
  const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
1279
+ #endif
1110
1280
 
1111
1281
  std::vector<const char*> layers;
1112
1282
 
@@ -1117,13 +1287,17 @@ void ggml_vk_instance_init() {
1117
1287
  if (validation_ext) {
1118
1288
  extensions.push_back("VK_EXT_validation_features");
1119
1289
  }
1290
+ #ifdef __APPLE__
1120
1291
  if (portability_enumeration_ext) {
1121
1292
  extensions.push_back("VK_KHR_portability_enumeration");
1122
1293
  }
1294
+ #endif
1123
1295
  vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
1296
+ #ifdef __APPLE__
1124
1297
  if (portability_enumeration_ext) {
1125
1298
  instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
1126
1299
  }
1300
+ #endif
1127
1301
 
1128
1302
  std::vector<vk::ValidationFeatureEnableEXT> features_enable;
1129
1303
  vk::ValidationFeaturesEXT validation_features;
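The __APPLE__ guards added here keep the MoltenVK portability pieces together: the eEnumeratePortabilityKHR instance flag is only valid when the VK_KHR_portability_enumeration extension is actually enabled, so on non-Apple builds both are compiled out as a pair. A minimal sketch of that pairing, reusing the app_info and layers variables from the hunk above and dropping the runtime availability check for brevity:

    std::vector<const char *> extensions;
    vk::InstanceCreateFlags flags {};
    #ifdef __APPLE__
    // Flag and extension must be enabled together.
    extensions.push_back("VK_KHR_portability_enumeration");
    flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
    #endif
    vk::InstanceCreateInfo instance_create_info(flags, &app_info, layers, extensions);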
@@ -1182,140 +1356,152 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
1182
1356
  throw std::runtime_error("Device not found");
1183
1357
  }
1184
1358
 
1185
- vk_instance.devices[idx] = std::make_shared<vk_device>();
1186
- ctx->device = vk_instance.devices[idx];
1187
- ctx->device.lock()->physical_device = devices[dev_num];
1188
- const std::vector<vk::ExtensionProperties> ext_props = ctx->device.lock()->physical_device.enumerateDeviceExtensionProperties();
1359
+ ctx->device = ggml_vk_get_device(idx);
1360
+ if (!ctx->device->initialized) {
1361
+ ctx->device->physical_device = devices[dev_num];
1362
+ const std::vector<vk::ExtensionProperties> ext_props = ctx->device->physical_device.enumerateDeviceExtensionProperties();
1189
1363
 
1190
- bool maintenance4_support = false;
1364
+ bool maintenance4_support = false;
1191
1365
 
1192
- // Check if maintenance4 is supported
1193
- for (const auto& properties : ext_props) {
1194
- if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
1195
- maintenance4_support = true;
1366
+ // Check if maintenance4 is supported
1367
+ for (const auto& properties : ext_props) {
1368
+ if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
1369
+ maintenance4_support = true;
1370
+ }
1196
1371
  }
1197
- }
1198
1372
 
1199
- vk::PhysicalDeviceProperties2 props2;
1200
- vk::PhysicalDeviceMaintenance3Properties props3;
1201
- vk::PhysicalDeviceMaintenance4Properties props4;
1202
- vk::PhysicalDeviceSubgroupProperties subgroup_props;
1203
- props2.pNext = &props3;
1204
- props3.pNext = &subgroup_props;
1205
- if (maintenance4_support) {
1206
- subgroup_props.pNext = &props4;
1207
- }
1208
- ctx->device.lock()->physical_device.getProperties2(&props2);
1209
- ctx->device.lock()->properties = props2.properties;
1373
+ vk::PhysicalDeviceProperties2 props2;
1374
+ vk::PhysicalDeviceMaintenance3Properties props3;
1375
+ vk::PhysicalDeviceMaintenance4Properties props4;
1376
+ vk::PhysicalDeviceSubgroupProperties subgroup_props;
1377
+ props2.pNext = &props3;
1378
+ props3.pNext = &subgroup_props;
1379
+ if (maintenance4_support) {
1380
+ subgroup_props.pNext = &props4;
1381
+ }
1382
+ ctx->device->physical_device.getProperties2(&props2);
1383
+ ctx->device->properties = props2.properties;
1210
1384
 
1211
- if (maintenance4_support) {
1212
- ctx->device.lock()->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
1213
- } else {
1214
- ctx->device.lock()->max_memory_allocation_size = props3.maxMemoryAllocationSize;
1215
- }
1385
+ const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
1216
1386
 
1217
- ctx->device.lock()->vendor_id = ctx->device.lock()->properties.vendorID;
1218
- ctx->device.lock()->subgroup_size = subgroup_props.subgroupSize;
1219
- ctx->device.lock()->uma = ctx->device.lock()->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
1387
+ if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
1388
+ ctx->device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
1389
+ } else if (maintenance4_support) {
1390
+ ctx->device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
1391
+ } else {
1392
+ ctx->device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
1393
+ }
1220
1394
 
1221
- bool fp16_storage = false;
1222
- bool fp16_compute = false;
1395
+ ctx->device->vendor_id = ctx->device->properties.vendorID;
1396
+ ctx->device->subgroup_size = subgroup_props.subgroupSize;
1397
+ ctx->device->uma = ctx->device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
1223
1398
 
1224
- for (const auto& properties : ext_props) {
1225
- if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1226
- fp16_storage = true;
1227
- } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1228
- fp16_compute = true;
1399
+ bool fp16_storage = false;
1400
+ bool fp16_compute = false;
1401
+
1402
+ for (const auto& properties : ext_props) {
1403
+ if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1404
+ fp16_storage = true;
1405
+ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1406
+ fp16_compute = true;
1407
+ }
1229
1408
  }
1230
- }
1231
1409
 
1232
- const char* GGML_VULKAN_DISABLE_F16 = getenv("GGML_VULKAN_DISABLE_F16");
1233
- bool force_disable_f16 = GGML_VULKAN_DISABLE_F16 != nullptr;
1410
+ const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
1411
+ const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
1234
1412
 
1235
- ctx->device.lock()->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
1413
+ ctx->device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
1236
1414
 
1237
- std::vector<vk::QueueFamilyProperties> queue_family_props = ctx->device.lock()->physical_device.getQueueFamilyProperties();
1415
+ std::vector<vk::QueueFamilyProperties> queue_family_props = ctx->device->physical_device.getQueueFamilyProperties();
1238
1416
 
1239
- // Try to find a non-graphics compute queue and transfer-focused queues
1240
- const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
1241
- const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
1417
+ // Try to find a non-graphics compute queue and transfer-focused queues
1418
+ const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
1419
+ const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
1242
1420
 
1243
- const float priorities[] = { 1.0f, 1.0f };
1244
- ctx->device.lock()->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
1421
+ const float priorities[] = { 1.0f, 1.0f };
1422
+ ctx->device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
1245
1423
 
1246
- std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
1247
- if (compute_queue_family_index != transfer_queue_family_index) {
1248
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
1249
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
1250
- } else if(!ctx->device.lock()->single_queue) {
1251
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
1252
- } else {
1253
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
1254
- }
1255
- vk::DeviceCreateInfo device_create_info;
1256
- std::vector<const char *> device_extensions;
1257
- vk::PhysicalDeviceFeatures device_features = ctx->device.lock()->physical_device.getFeatures();
1424
+ std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
1425
+ if (compute_queue_family_index != transfer_queue_family_index) {
1426
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
1427
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
1428
+ } else if(!ctx->device->single_queue) {
1429
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
1430
+ } else {
1431
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
1432
+ }
1433
+ vk::DeviceCreateInfo device_create_info;
1434
+ std::vector<const char *> device_extensions;
1435
+ vk::PhysicalDeviceFeatures device_features = ctx->device->physical_device.getFeatures();
1258
1436
 
1259
- VkPhysicalDeviceFeatures2 device_features2;
1260
- device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
1261
- device_features2.pNext = nullptr;
1262
- device_features2.features = (VkPhysicalDeviceFeatures)device_features;
1437
+ VkPhysicalDeviceFeatures2 device_features2;
1438
+ device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
1439
+ device_features2.pNext = nullptr;
1440
+ device_features2.features = (VkPhysicalDeviceFeatures)device_features;
1263
1441
 
1264
- VkPhysicalDeviceVulkan11Features vk11_features;
1265
- vk11_features.pNext = nullptr;
1266
- vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
1267
- device_features2.pNext = &vk11_features;
1442
+ VkPhysicalDeviceVulkan11Features vk11_features;
1443
+ vk11_features.pNext = nullptr;
1444
+ vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
1445
+ device_features2.pNext = &vk11_features;
1268
1446
 
1269
- VkPhysicalDeviceVulkan12Features vk12_features;
1270
- vk12_features.pNext = nullptr;
1271
- vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1272
- vk11_features.pNext = &vk12_features;
1447
+ VkPhysicalDeviceVulkan12Features vk12_features;
1448
+ vk12_features.pNext = nullptr;
1449
+ vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1450
+ vk11_features.pNext = &vk12_features;
1273
1451
 
1274
- vkGetPhysicalDeviceFeatures2(ctx->device.lock()->physical_device, &device_features2);
1452
+ vkGetPhysicalDeviceFeatures2(ctx->device->physical_device, &device_features2);
1275
1453
 
1276
- ctx->device.lock()->fp16 = ctx->device.lock()->fp16 && vk12_features.shaderFloat16;
1454
+ ctx->device->fp16 = ctx->device->fp16 && vk12_features.shaderFloat16;
1277
1455
 
1278
- if (!vk11_features.storageBuffer16BitAccess) {
1279
- std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
1280
- throw std::runtime_error("Unsupported device");
1281
- }
1456
+ if (!vk11_features.storageBuffer16BitAccess) {
1457
+ std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
1458
+ throw std::runtime_error("Unsupported device");
1459
+ }
1282
1460
 
1283
- device_extensions.push_back("VK_KHR_16bit_storage");
1461
+ device_extensions.push_back("VK_KHR_16bit_storage");
1284
1462
 
1285
1463
  #ifdef GGML_VULKAN_VALIDATE
1286
- device_extensions.push_back("VK_KHR_shader_non_semantic_info");
1464
+ device_extensions.push_back("VK_KHR_shader_non_semantic_info");
1287
1465
  #endif
1288
1466
 
1289
- if (ctx->device.lock()->fp16) {
1290
- device_extensions.push_back("VK_KHR_shader_float16_int8");
1291
- }
1292
- ctx->device.lock()->name = ctx->device.lock()->properties.deviceName.data();
1467
+ if (ctx->device->fp16) {
1468
+ device_extensions.push_back("VK_KHR_shader_float16_int8");
1469
+ }
1470
+ ctx->device->name = ctx->device->properties.deviceName.data();
1293
1471
 
1294
- device_create_info = {
1295
- vk::DeviceCreateFlags(),
1296
- device_queue_create_infos,
1297
- {},
1298
- device_extensions
1299
- };
1300
- device_create_info.setPNext(&device_features2);
1301
- ctx->device.lock()->device = ctx->device.lock()->physical_device.createDevice(device_create_info);
1472
+ device_create_info = {
1473
+ vk::DeviceCreateFlags(),
1474
+ device_queue_create_infos,
1475
+ {},
1476
+ device_extensions
1477
+ };
1478
+ device_create_info.setPNext(&device_features2);
1479
+ ctx->device->device = ctx->device->physical_device.createDevice(device_create_info);
1302
1480
 
1303
- ctx->device.lock()->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;
1481
+ ctx->device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;
1304
1482
 
1305
- // Shaders
1306
- ggml_vk_load_shaders(ctx);
1483
+ // Queues
1484
+ ggml_vk_create_queue(ctx, ctx->device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
1307
1485
 
1308
- // Queues
1309
- ggml_vk_create_queue(ctx, ctx->device.lock()->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
1310
- if (!ctx->device.lock()->single_queue) {
1311
- const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
1312
- ggml_vk_create_queue(ctx, ctx->device.lock()->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
1313
- } else {
1314
- // TODO: Use pointer or reference to avoid copy
1315
- ctx->device.lock()->transfer_queue = ctx->device.lock()->compute_queue;
1486
+ // Shaders
1487
+ ggml_vk_load_shaders(ctx);
1488
+
1489
+ if (!ctx->device->single_queue) {
1490
+ const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
1491
+ ggml_vk_create_queue(ctx, ctx->device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
1492
+ } else {
1493
+ // TODO: Use pointer or reference to avoid copy
1494
+ ctx->device->transfer_queue = ctx->device->compute_queue;
1495
+ }
1496
+
1497
+ ctx->device->idx = dev_num;
1498
+ ctx->device->initialized = true;
1499
+ } else if (ctx->device->idx != dev_num) {
1500
+ std::cerr << "ggml_vulkan: Device " << ctx->device->name << " already initialized with index " << ctx->device->idx << ", but trying to reinitialize with index " << dev_num << std::endl;
1501
+ throw std::runtime_error("Device already initialized");
1316
1502
  }
1317
1503
 
1318
- ctx->fence = ctx->device.lock()->device.createFence({});
1504
+ ctx->fence = ctx->device->device.createFence({});
1319
1505
 
1320
1506
  ctx->compute_ctx = nullptr;
1321
1507
  ctx->transfer_ctx = nullptr;
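Device initialization reads two environment variables in the hunk above: GGML_VK_DISABLE_F16, which vetoes fp16 even when 16-bit storage and compute support are present, and GGML_VK_FORCE_MAX_ALLOCATION_SIZE, which is parsed with std::stoi and takes precedence over the maintenance4 / maxMemoryAllocationSize logic. A self-contained sketch of the size-selection order, using uint64_t in place of vk::DeviceSize so it compiles without the Vulkan headers:

    #include <algorithm>
    #include <cstdint>
    #include <cstdlib>
    #include <string>

    // Mirrors the precedence in the hunk above: env override, then maintenance4, then the plain limit.
    static uint64_t pick_max_allocation_size(bool maintenance4_support,
                                             uint64_t max_memory_allocation_size,  // props3.maxMemoryAllocationSize
                                             uint64_t max_buffer_size) {           // props4.maxBufferSize
        if (const char * forced = std::getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE")) {
            return std::stoi(forced);  // as in the file: bytes as a decimal int, so values above INT_MAX need care
        }
        if (maintenance4_support) {
            return std::min(max_memory_allocation_size, max_buffer_size);
        }
        return max_memory_allocation_size;
    }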
@@ -1333,7 +1519,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
1333
1519
  #endif
1334
1520
  }
1335
1521
 
1336
- static vk_pipeline* ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
1522
+ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
1337
1523
  #ifdef GGML_VULKAN_DEBUG
1338
1524
  std::cerr << "ggml_vk_get_to_fp16()" << std::endl;
1339
1525
  #endif
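The getters in this and the following hunks now return the pipeline handle by value instead of through a raw vk_pipeline* pointer. Because the handle behaves like a smart pointer (note the pipeline-> accesses in the dispatch code further down), the existing return nullptr branches still work: they produce an empty handle that callers can keep comparing against nullptr. A minimal sketch of the idiom with hypothetical names:

    #include <memory>

    struct pipeline_struct { /* shader module, layouts, ... */ };
    using pipeline_handle = std::shared_ptr<pipeline_struct>;   // stand-in for the handle type used here

    static pipeline_handle lookup(bool supported, const pipeline_handle& table_entry) {
        if (!supported) {
            return nullptr;      // constructs an empty handle, no raw pointer involved
        }
        return table_entry;      // copies the handle, sharing ownership with the device-owned table
    }

    // caller: if (lookup(...) == nullptr) { /* type not supported */ }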
@@ -1354,10 +1540,36 @@ static vk_pipeline* ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
1354
1540
  return nullptr;
1355
1541
  }
1356
1542
 
1357
- return &ctx->pipeline_dequant[type];
1543
+ return ctx->device->pipeline_dequant[type];
1544
+ }
1545
+
1546
+ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
1547
+ #ifdef GGML_VULKAN_DEBUG
1548
+ std::cerr << "ggml_vk_get_mul_mat_mat_pipeline()" << std::endl;
1549
+ #endif
1550
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
1551
+ return ctx->device->pipeline_matmul_f32;
1552
+ }
1553
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
1554
+ return ctx->device->pipeline_matmul_f16_f32;
1555
+ }
1556
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
1557
+ return ctx->device->pipeline_matmul_f16;
1558
+ }
1559
+
1560
+ GGML_ASSERT(src1_type == GGML_TYPE_F32);
1561
+
1562
+ switch (src0_type) {
1563
+ case GGML_TYPE_Q4_0:
1564
+ break;
1565
+ default:
1566
+ return nullptr;
1567
+ }
1568
+
1569
+ return ctx->device->pipeline_dequant_mul_mat_mat[src0_type];
1358
1570
  }
1359
1571
 
1360
- static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type type) {
1572
+ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type type) {
1361
1573
  #ifdef GGML_VULKAN_DEBUG
1362
1574
  std::cerr << "ggml_vk_get_dequantize_mul_mat_vec()" << std::endl;
1363
1575
  #endif
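ggml_vk_get_mul_mat_mat_pipeline above selects the pipeline family (f32, f16, f16_f32, or a quantized matmul family, with only Q4_0 wired up in this switch), while the concrete l/m/s or aligned variant inside that family is chosen further down by ggml_vk_guess_matmul_pipeline. A hypothetical call site combining the two; m, n and aligned stand for values the real matmul path derives from the tensor shapes:

    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_Q4_0, GGML_TYPE_F32);
    if (mmp == nullptr) {
        // no direct matrix-matrix pipeline for this src0/src1 combination; the caller handles the fallback
    }
    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, aligned);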
@@ -1378,7 +1590,7 @@ static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
1378
1590
  return nullptr;
1379
1591
  }
1380
1592
 
1381
- return &ctx->pipeline_dequant_mul_mat_vec_f32[type];
1593
+ return ctx->device->pipeline_dequant_mul_mat_vec_f32[type];
1382
1594
  }
1383
1595
 
1384
1596
  static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
@@ -1457,8 +1669,8 @@ static void * ggml_vk_host_malloc(ggml_backend_vk_context * ctx, size_t size) {
1457
1669
  if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
1458
1670
  fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
1459
1671
  size/1024.0/1024.0);
1460
- ctx->device.lock()->device.freeMemory(buf->device_memory);
1461
- ctx->device.lock()->device.destroyBuffer(buf->buffer);
1672
+ ctx->device->device.freeMemory(buf->device_memory);
1673
+ ctx->device->device.destroyBuffer(buf->buffer);
1462
1674
  return nullptr;
1463
1675
  }
1464
1676
 
@@ -1522,30 +1734,30 @@ static vk_submission ggml_vk_begin_submission(ggml_backend_vk_context * ctx, vk_
1522
1734
  }
1523
1735
 
1524
1736
  static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, std::vector<vk_subbuffer>&& buffers, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
1525
- const uint32_t wg0 = CEIL_DIV(elements[0], pipeline.wg_denoms[0]);
1526
- const uint32_t wg1 = CEIL_DIV(elements[1], pipeline.wg_denoms[1]);
1527
- const uint32_t wg2 = CEIL_DIV(elements[2], pipeline.wg_denoms[2]);
1737
+ const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
1738
+ const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
1739
+ const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
1528
1740
  #ifdef GGML_VULKAN_DEBUG
1529
- std::cerr << "ggml_vk_dispatch_pipeline(" << pipeline.name << ", (" << wg0 << "," << wg1 << "," << wg2 << "))" << std::endl;
1741
+ std::cerr << "ggml_vk_dispatch_pipeline(" << pipeline->name << ", (" << wg0 << "," << wg1 << "," << wg2 << "))" << std::endl;
1530
1742
  #endif
1531
1743
  std::vector<vk::DescriptorBufferInfo> descriptor_buffer_infos;
1532
1744
  std::vector<vk::WriteDescriptorSet> write_descriptor_sets;
1533
- GGML_ASSERT(pipeline.descriptor_set_idx < pipeline.descriptor_sets.size());
1534
- GGML_ASSERT(buffers.size() == pipeline.parameter_count);
1535
- vk::DescriptorSet& descriptor_set = pipeline.descriptor_sets[pipeline.descriptor_set_idx++];
1536
- for (uint32_t i = 0; i < pipeline.parameter_count; i++) {
1745
+ GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size());
1746
+ GGML_ASSERT(buffers.size() == pipeline->parameter_count);
1747
+ vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++];
1748
+ for (uint32_t i = 0; i < pipeline->parameter_count; i++) {
1537
1749
  descriptor_buffer_infos.push_back({buffers[i].buffer->buffer, buffers[i].offset, buffers[i].size});
1538
1750
  }
1539
- for (uint32_t i = 0; i < pipeline.parameter_count; i++) {
1751
+ for (uint32_t i = 0; i < pipeline->parameter_count; i++) {
1540
1752
  write_descriptor_sets.push_back({descriptor_set, i, 0, 1, vk::DescriptorType::eStorageBuffer, nullptr, &descriptor_buffer_infos[i]});
1541
1753
  }
1542
1754
 
1543
- ctx->device.lock()->device.updateDescriptorSets(write_descriptor_sets, {});
1755
+ ctx->device->device.updateDescriptorSets(write_descriptor_sets, {});
1544
1756
 
1545
- subctx->s->buffer.pushConstants(pipeline.layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
1546
- subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline.pipeline);
1757
+ subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
1758
+ subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
1547
1759
  subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
1548
- pipeline.layout,
1760
+ pipeline->layout,
1549
1761
  0,
1550
1762
  { descriptor_set },
1551
1763
  {});
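The dispatch sizes above come from dividing the requested element extents by the pipeline's wg_denoms and rounding up. A self-contained sketch of that computation, assuming CEIL_DIV is the usual round-up integer division (the macro itself is defined elsewhere in the file):

    #include <array>
    #include <cstdint>

    #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))   // assumed definition

    // Example: elements {520, 1, 1} with wg_denoms {512, 1, 1} yields {2, 1, 1} workgroups.
    static std::array<uint32_t, 3> workgroup_counts(const std::array<uint32_t, 3>& elements,
                                                    const std::array<uint32_t, 3>& wg_denoms) {
        std::array<uint32_t, 3> wg = {
            CEIL_DIV(elements[0], wg_denoms[0]),
            CEIL_DIV(elements[1], wg_denoms[1]),
            CEIL_DIV(elements[2], wg_denoms[2]),
        };
        return wg;
    }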
@@ -1804,7 +2016,7 @@ static void ggml_vk_buffer_write_2d(ggml_backend_vk_context * ctx, vk_buffer& ds
1804
2016
  memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
1805
2017
  }
1806
2018
  } else {
1807
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
2019
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
1808
2020
  ggml_vk_ctx_begin(ctx, subctx);
1809
2021
  ggml_vk_buffer_write_2d_async(ctx, subctx, dst, offset, src, spitch, width, height, true);
1810
2022
  ggml_vk_ctx_end(subctx);
@@ -1814,8 +2026,9 @@ static void ggml_vk_buffer_write_2d(ggml_backend_vk_context * ctx, vk_buffer& ds
1814
2026
  }
1815
2027
 
1816
2028
  ggml_vk_submit(subctx, ctx->fence);
1817
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
1818
- ctx->device.lock()->device.resetFences({ ctx->fence });
2029
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
2030
+ ctx->device->device.resetFences({ ctx->fence });
2031
+ ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
1819
2032
  }
1820
2033
  }
1821
2034
 
@@ -1900,18 +2113,19 @@ static void ggml_vk_buffer_read(ggml_backend_vk_context * ctx, vk_buffer& src, s
1900
2113
 
1901
2114
  memcpy(dst, (uint8_t *) src->ptr + offset, size);
1902
2115
  } else {
1903
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
2116
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
1904
2117
  ggml_vk_ctx_begin(ctx, subctx);
1905
2118
  ggml_vk_buffer_read_async(ctx, subctx, src, offset, dst, size, true);
1906
2119
  ggml_vk_ctx_end(subctx);
1907
2120
 
1908
2121
  ggml_vk_submit(subctx, ctx->fence);
1909
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
1910
- ctx->device.lock()->device.resetFences({ ctx->fence });
2122
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
2123
+ ctx->device->device.resetFences({ ctx->fence });
1911
2124
 
1912
2125
  for (auto& cpy : subctx->out_memcpys) {
1913
2126
  memcpy(cpy.dst, cpy.src, cpy.n);
1914
2127
  }
2128
+ ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
1915
2129
  }
1916
2130
  }
1917
2131
 
@@ -1935,15 +2149,13 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
1935
2149
  // Copy within the device
1936
2150
  ggml_backend_vk_context * ctx = src->ctx;
1937
2151
 
1938
- VkBufferCopy bc{ src_offset, dst_offset, size };
1939
-
1940
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
2152
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
1941
2153
  ggml_vk_ctx_begin(ctx, subctx);
1942
2154
  ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
1943
2155
  ggml_vk_ctx_end(subctx);
1944
2156
  ggml_vk_submit(subctx, ctx->fence);
1945
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
1946
- ctx->device.lock()->device.resetFences({ ctx->fence });
2157
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
2158
+ ctx->device->device.resetFences({ ctx->fence });
1947
2159
  } else {
1948
2160
  #ifdef GGML_VULKAN_DEBUG
1949
2161
  std::cerr << "ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")" << std::endl;
@@ -1971,14 +2183,14 @@ static void ggml_vk_buffer_memset(ggml_backend_vk_context * ctx, vk_buffer& dst,
1971
2183
  // Make sure ctx owns the buffer
1972
2184
  GGML_ASSERT(dst->ctx == ctx);
1973
2185
 
1974
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
2186
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
1975
2187
  ggml_vk_ctx_begin(ctx, subctx);
1976
2188
  subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
1977
2189
  ggml_vk_ctx_end(subctx);
1978
2190
 
1979
2191
  ggml_vk_submit(subctx, ctx->fence);
1980
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_memset waitForFences");
1981
- ctx->device.lock()->device.resetFences({ ctx->fence });
2192
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_memset waitForFences");
2193
+ ctx->device->device.resetFences({ ctx->fence });
1982
2194
  }
1983
2195
 
1984
2196
  static void ggml_vk_h2d_tensor_2d(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& dst, size_t offset, const ggml_tensor * src, uint64_t i3, uint64_t i2, uint64_t i1) {
@@ -2039,176 +2251,63 @@ static void ggml_vk_d2h_tensor_2d(ggml_backend_vk_context * ctx, vk_context * su
2039
2251
 
2040
2252
  static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
2041
2253
  #ifdef GGML_VULKAN_DEBUG
2042
- std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")";
2254
+ std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")" << std::endl;
2043
2255
  #endif
2044
2256
  if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
2045
- #ifdef GGML_VULKAN_DEBUG
2046
- std::cerr << " = 4" << std::endl;
2047
- #endif
2048
2257
  return 4;
2049
2258
  }
2050
2259
 
2051
- #ifdef GGML_VULKAN_DEBUG
2052
- std::cerr << " = 1" << std::endl;
2053
- #endif
2054
2260
  return 1;
2055
2261
  }
2056
2262
 
2057
- static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, int m, int n) {
2058
- #ifdef GGML_VULKAN_DEBUG
2059
- std::cerr << "ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")" << std::endl;
2060
- #endif
2263
+ static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
2061
2264
  if (m <= 32 || n <= 32) {
2062
- return ctx->pipeline_matmul_f32_aligned_s.align;
2265
+ return aligned ? mmp->a_s : mmp->s;
2063
2266
  }
2064
- if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) {
2065
- return ctx->pipeline_matmul_f32_aligned_m.align;
2066
- }
2067
- return ctx->pipeline_matmul_f32_aligned_l.align;
2267
+ return aligned ? mmp->a_m : mmp->m;
2268
+
2269
+ GGML_UNUSED(ctx);
2068
2270
  }
2069
2271
 
2070
- static vk_pipeline* ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
2071
- if (bit16_x && bit16_y) {
2072
- if (m <= 32 || n <= 32) {
2073
- #ifdef GGML_VULKAN_DEBUG
2074
- std::cerr << " S" << std::endl;
2075
- #endif
2076
- return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
2077
- }
2078
- #ifdef GGML_VULKAN_DEBUG
2079
- std::cerr << " M" << std::endl;
2080
- #endif
2081
- return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
2082
- }
2083
- if (bit16_x && !bit16_y) {
2084
- if (m <= 32 || n <= 32) {
2085
- #ifdef GGML_VULKAN_DEBUG
2086
- std::cerr << " S" << std::endl;
2087
- #endif
2088
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
2089
- }
2090
- #ifdef GGML_VULKAN_DEBUG
2091
- std::cerr << " M" << std::endl;
2092
- #endif
2093
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
2094
- }
2095
- if (!bit16_x && bit16_y) {
2096
- GGML_ASSERT(false);
2097
- }
2272
+ static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
2273
+ return aligned ? mmp->a_m : mmp->m;
2098
2274
 
2099
- if (m <= 32 || n <= 32) {
2100
- #ifdef GGML_VULKAN_DEBUG
2101
- std::cerr << " S" << std::endl;
2102
- #endif
2103
- return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
2104
- }
2105
- #ifdef GGML_VULKAN_DEBUG
2106
- std::cerr << " M" << std::endl;
2107
- #endif
2108
- return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
2275
+ GGML_UNUSED(ctx);
2109
2276
  }
2110
2277
 
2111
- static vk_pipeline* ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
2112
- #ifdef GGML_VULKAN_DEBUG
2113
- std::cerr << " M" << std::endl;
2114
- #endif
2115
- if (bit16_x && bit16_y) {
2116
- return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
2117
- }
2118
- if (bit16_x && !bit16_y) {
2119
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
2120
- }
2121
- if (!bit16_x && bit16_y) {
2122
- GGML_ASSERT(false);
2123
- }
2124
- return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
2125
- }
2278
+ static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
2279
+ return aligned ? mmp->a_s : mmp->s;
2126
2280
 
2127
- static vk_pipeline* ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
2128
- #ifdef GGML_VULKAN_DEBUG
2129
- std::cerr << " S" << std::endl;
2130
- #endif
2131
- if (bit16_x && bit16_y) {
2132
- return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
2133
- }
2134
- if (bit16_x && !bit16_y) {
2135
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
2136
- }
2137
- if (!bit16_x && bit16_y) {
2138
- GGML_ASSERT(false);
2139
- }
2140
- return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
2281
+ GGML_UNUSED(ctx);
2141
2282
  }
2142
2283
 
2143
- static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
2284
+ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
2144
2285
  #ifdef GGML_VULKAN_DEBUG
2145
- std::cerr << "ggml_vk_guess_matmul_pipeline(" << bit16_x << ", " << bit16_y << ", " << m << ", " << n << ", " << aligned << ")";
2286
+ std::cerr << "ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")" << std::endl;
2146
2287
  #endif
2147
- switch (ctx->device.lock()->vendor_id) {
2288
+ switch (ctx->device->vendor_id) {
2148
2289
  case VK_VENDOR_ID_AMD:
2149
- return ggml_vk_guess_matmul_pipeline_amd(ctx, bit16_x, bit16_y, m, n, aligned);
2290
+ return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
2150
2291
  case VK_VENDOR_ID_APPLE:
2151
- return ggml_vk_guess_matmul_pipeline_apple(ctx, bit16_x, bit16_y, aligned);
2292
+ return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
2152
2293
  case VK_VENDOR_ID_INTEL:
2153
- return ggml_vk_guess_matmul_pipeline_intel(ctx, bit16_x, bit16_y, aligned);
2154
- }
2155
-
2156
- if (bit16_x && bit16_y) {
2157
- if (m <= 32 || n <= 32) {
2158
- #ifdef GGML_VULKAN_DEBUG
2159
- std::cerr << " S" << std::endl;
2160
- #endif
2161
- return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
2162
- }
2163
- if (m <= 64 || n <= 64) {
2164
- #ifdef GGML_VULKAN_DEBUG
2165
- std::cerr << " M" << std::endl;
2166
- #endif
2167
- return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
2168
- }
2169
- #ifdef GGML_VULKAN_DEBUG
2170
- std::cerr << " L" << std::endl;
2171
- #endif
2172
- return aligned ? &ctx->pipeline_matmul_f16_aligned_l : &ctx->pipeline_matmul_f16_l;
2173
- }
2174
- if (bit16_x && !bit16_y) {
2175
- if (m <= 32 || n <= 32) {
2176
- #ifdef GGML_VULKAN_DEBUG
2177
- std::cerr << " S" << std::endl;
2178
- #endif
2179
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
2180
- }
2181
- if (m <= 64 || n <= 64) {
2182
- #ifdef GGML_VULKAN_DEBUG
2183
- std::cerr << " M" << std::endl;
2184
- #endif
2185
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
2186
- }
2187
- #ifdef GGML_VULKAN_DEBUG
2188
- std::cerr << " L" << std::endl;
2189
- #endif
2190
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_l : &ctx->pipeline_matmul_f16_f32_l;
2191
- }
2192
- if (!bit16_x && bit16_y) {
2193
- GGML_ASSERT(false);
2294
+ return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
2194
2295
  }
2195
2296
 
2196
2297
  if (m <= 32 || n <= 32) {
2197
- #ifdef GGML_VULKAN_DEBUG
2198
- std::cerr << " S" << std::endl;
2199
- #endif
2200
- return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
2298
+ return aligned ? mmp->a_s : mmp->s;
2201
2299
  }
2202
2300
  if (m <= 64 || n <= 64) {
2203
- #ifdef GGML_VULKAN_DEBUG
2204
- std::cerr << " M" << std::endl;
2205
- #endif
2206
- return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
2301
+ return aligned ? mmp->a_m : mmp->m;
2207
2302
  }
2303
+ return aligned ? mmp->a_l : mmp->l;
2304
+ }
2305
+
2306
+ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
2208
2307
  #ifdef GGML_VULKAN_DEBUG
2209
- std::cerr << " L" << std::endl;
2308
+ std::cerr << "ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")" << std::endl;
2210
2309
  #endif
2211
- return aligned ? &ctx->pipeline_matmul_f32_aligned_l : &ctx->pipeline_matmul_f32_l;
2310
+ return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, false)->align;
2212
2311
  }
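Taken together, the helpers above reduce to a size-bucketed lookup into the per-type bundle, with the vendor overrides (AMD, Apple, Intel) only restricting which bucket is reachable and the align query reading ->align off whichever variant would be picked. A minimal, self-contained sketch of that selection, using simplified stand-in types rather than the structs from this diff:

    #include <memory>

    struct pipeline_stub { int align; };                  // stand-in for vk_pipeline_struct
    using pipeline_ref = std::shared_ptr<pipeline_stub>;  // stand-in for the vk_pipeline handle

    struct matmul_bundle {                                // shaped like vk_matmul_pipeline_struct
        pipeline_ref l, m, s;                             // large / medium / small tile variants
        pipeline_ref a_l, a_m, a_s;                       // their aligned counterparts
    };

    // Same size buckets as the generic path of ggml_vk_guess_matmul_pipeline:
    // small when m or n <= 32, medium when m or n <= 64, large otherwise.
    static pipeline_ref pick_matmul_pipeline(const matmul_bundle & mmp, int m, int n, bool aligned) {
        if (m <= 32 || n <= 32) {
            return aligned ? mmp.a_s : mmp.s;
        }
        if (m <= 64 || n <= 64) {
            return aligned ? mmp.a_m : mmp.m;
        }
        return aligned ? mmp.a_l : mmp.l;
    }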
2213
2312
 
2214
2313
  static void ggml_vk_matmul(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d) {
@@ -2226,10 +2325,10 @@ static void ggml_vk_matmul(ggml_backend_vk_context * ctx, vk_context * subctx, v
2226
2325
 
2227
2326
  const std::array<uint32_t, 14> pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d };
2228
2327
  // Make sure enough workgroups get assigned for split k to work
2229
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1.size() * sizeof(uint32_t), pc1.data(), { (CEIL_DIV(m, pipeline.wg_denoms[0]) * pipeline.wg_denoms[0]) * split_k, n, batch });
2328
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1.size() * sizeof(uint32_t), pc1.data(), { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
2230
2329
  ggml_vk_sync_buffers(subctx);
2231
2330
  const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
2232
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
2331
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
2233
2332
  }
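To make the padding in the first dispatch concrete, assuming CEIL_DIV is the usual rounding-up division (a + b - 1) / b: with m = 100, pipeline->wg_denoms[0] = 64 and split_k = 4, the x extent is CEIL_DIV(100, 64) * 64 * 4 = 2 * 64 * 4 = 512, so every one of the 4 k-slices is given the same padded workgroup range; the second dispatch of pipeline_matmul_split_k_reduce then folds the split_k partial results for each of the m * n * batch outputs back into d.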
2234
2333
 
2235
2334
  static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
@@ -2239,41 +2338,39 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
2239
2338
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
2240
2339
  }
2241
2340
 
2242
- static vk_pipeline * ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_type from, ggml_type to) {
2341
+ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_type from, ggml_type to) {
2243
2342
  if (from == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
2244
- return &ctx->pipeline_cpy_f32_f32;
2343
+ return ctx->device->pipeline_cpy_f32_f32;
2245
2344
  }
2246
2345
  if (from == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
2247
- return &ctx->pipeline_cpy_f32_f16;
2346
+ return ctx->device->pipeline_cpy_f32_f16;
2248
2347
  }
2249
2348
  if (from == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
2250
- return &ctx->pipeline_cpy_f16_f16;
2349
+ return ctx->device->pipeline_cpy_f16_f16;
2251
2350
  }
2252
2351
 
2253
2352
  std::cerr << "Missing CPY op for types: " << ggml_type_name(from) << " " << ggml_type_name(to) << std::endl;
2254
2353
  GGML_ASSERT(false);
2255
2354
  }
2256
2355
 
2257
- static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline * pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out, ggml_type buffer_type, bool aligned=true) {
2356
+ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) {
2258
2357
  #ifdef GGML_VULKAN_DEBUG
2259
2358
  std::cerr << "ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", backend=" << tensor->backend << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
2260
2359
  std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")" << std::endl;
2261
2360
  #endif
2262
2361
  const int tensor_type_size = ggml_type_size(tensor->type);
2263
- const int dst_type_size = ggml_type_size(buffer_type);
2264
2362
 
2265
- const uint32_t ne = tensor->ne[0] * tensor->ne[1] * tensor->ne[2];
2363
+ const uint32_t ne = ggml_nelements(tensor);
2266
2364
 
2267
- const uint32_t nb2 = aligned ? ggml_vk_align_size(dst_type_size * tensor->ne[0] * tensor->ne[1], ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size : tensor->ne[0] * tensor->ne[1];
2268
-
2269
- const vk_op_cpy_push_constants pc = {
2365
+ const vk_op_unary_push_constants pc = {
2270
2366
  (uint32_t)ne,
2271
- (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size,
2272
- (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], 1 , (uint32_t)tensor->ne[0] , nb2,
2367
+ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
2368
+ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
2273
2369
  0,
2370
+ 0.0f, 0.0f,
2274
2371
  };
2275
2372
  ggml_vk_sync_buffers(subctx);
2276
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { in, out }, sizeof(vk_op_cpy_push_constants), &pc, { ne, 1, 1 });
2373
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 });
2277
2374
  }
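The flat brace-initializer above reads more easily against a struct; a hedged reconstruction of the layout it implies (the struct and its field names are an assumption inferred from the initializer, not the real vk_op_unary_push_constants definition):

    #include <cstdint>

    // Assumed layout, matching the 20 values pushed above, in order.
    struct vk_op_unary_push_constants_sketch {
        uint32_t ne;                        // total element count (ggml_nelements)
        uint32_t ne00, ne01, ne02, ne03;    // source extents
        uint32_t nb00, nb01, nb02, nb03;    // source strides, in elements
        uint32_t ne10, ne11, ne12, ne13;    // destination extents
        uint32_t nb10, nb11, nb12, nb13;    // destination strides, in elements
        uint32_t d_offset;                  // output offset (0 on this copy path)
        float    param1, param2;            // per-op scalars (unused by the copy path)
    };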
2278
2375
 
2279
2376
  static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2313,23 +2410,30 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2313
2410
  bool src0_uma = false;
2314
2411
  bool src1_uma = false;
2315
2412
 
2316
- if (ctx->device.lock()->uma) {
2413
+ if (ctx->device->uma) {
2317
2414
  ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
2318
2415
  ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
2319
2416
  src0_uma = d_Qx != nullptr;
2320
2417
  src1_uma = d_Qy != nullptr;
2321
2418
  }
2322
2419
 
2323
- const bool load_x = src0->backend != GGML_BACKEND_GPU && !src0_uma;
2324
- const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma;
2420
+ const bool load_x = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma;
2421
+ const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
2325
2422
 
2326
2423
  const bool x_non_contig = !load_x && !ggml_vk_dim01_contiguous(src0);
2327
2424
  const bool y_non_contig = !load_y && !ggml_vk_dim01_contiguous(src1);
2328
2425
 
2329
- const bool f16_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
2426
+ const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
2330
2427
 
2331
- const bool qx_needs_dequant = src0->type != GGML_TYPE_F16 || x_non_contig;
2332
- const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
2428
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
2429
+
2430
+ const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
2431
+ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
2432
+
2433
+ if (mmp == nullptr) {
2434
+ // Fall back to dequant + f16 mulmat
2435
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
2436
+ }
2333
2437
 
2334
2438
  // Not implemented
2335
2439
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
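As a concrete case of the fallback, assume src0 is a quantized type for which ggml_vk_get_mul_mat_mat_pipeline has no dedicated kernel: mmp comes back nullptr, qx_needs_dequant is forced to true, and the second lookup returns the plain F16 bundle (or the F16 x F32 one when y_f32_kernel is set), so src0 is first converted to F16 by the to_fp16 pipeline dispatched further down and only then multiplied with the generic matmul kernels.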
@@ -2338,17 +2442,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2338
2442
  const int y_ne = ne11 * ne10;
2339
2443
  const int d_ne = ne11 * ne01;
2340
2444
 
2341
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, ne01, ne11));
2445
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
2342
2446
  const bool aligned = ne10 == kpad;
2343
2447
 
2344
2448
  const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
2345
2449
 
2346
- vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(ctx, true, !f16_f32_kernel, ne01, ne11, aligned);
2450
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
2347
2451
 
2348
2452
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
2349
2453
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
2350
- const uint64_t x_sz = sizeof(ggml_fp16_t) * x_ne;
2351
- const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
2454
+ const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
2455
+ const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
2352
2456
  const uint64_t d_sz = sizeof(float) * d_ne;
2353
2457
 
2354
2458
  vk_buffer d_D = extra->buffer_gpu.lock();
@@ -2379,7 +2483,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2379
2483
  } else {
2380
2484
  d_X = d_Qx;
2381
2485
  x_buf_offset = qx_buf_offset;
2382
- GGML_ASSERT(qx_sz == x_sz); // NOLINT
2486
+ GGML_ASSERT(qx_sz == x_sz);
2383
2487
  }
2384
2488
  if (qy_needs_dequant) {
2385
2489
  d_Y = ctx->prealloc_y;
@@ -2390,8 +2494,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2390
2494
  GGML_ASSERT(qy_sz == y_sz);
2391
2495
  }
2392
2496
 
2393
- vk_pipeline * to_fp16_vk_0 = nullptr;
2394
- vk_pipeline * to_fp16_vk_1 = nullptr;
2497
+ vk_pipeline to_fp16_vk_0 = nullptr;
2498
+ vk_pipeline to_fp16_vk_1 = nullptr;
2395
2499
 
2396
2500
  if (x_non_contig) {
2397
2501
  to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
@@ -2407,19 +2511,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2407
2511
  GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
2408
2512
 
2409
2513
  // Allocate descriptor sets
2410
- ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, ne12 * ne13);
2514
+ ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1);
2411
2515
  if (qx_needs_dequant) {
2412
- ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_0, x_non_contig ? 1 : ne12 * ne13);
2516
+ ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
2413
2517
  }
2414
2518
  if (qy_needs_dequant) {
2415
- ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
2519
+ ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, 1);
2416
2520
  }
2417
2521
  if (split_k > 1) {
2418
- ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_matmul_split_k_reduce, ne12 * ne13);
2522
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1);
2419
2523
  }
2420
2524
 
2421
2525
  if (x_non_contig) {
2422
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }, dst->type, false);
2526
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
2423
2527
  } else if (load_x || qx_needs_dequant) {
2424
2528
  if (load_x) {
2425
2529
  // copy data to device
@@ -2428,13 +2532,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2428
2532
  }
2429
2533
 
2430
2534
  if (qx_needs_dequant) {
2431
- const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, (int)ne10 };
2535
+ const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
2432
2536
  ggml_vk_sync_buffers(subctx);
2433
- ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
2537
+ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
2434
2538
  }
2435
2539
  }
2436
2540
  if (y_non_contig) {
2437
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, dst->type);
2541
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
2438
2542
  } else if (load_y) {
2439
2543
  ggml_vk_h2d_tensor_2d(ctx, subctx, d_Qy, 0, src1, 0, 0, ggml_nrows(src1));
2440
2544
  }
@@ -2451,9 +2555,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2451
2555
  }
2452
2556
 
2453
2557
  // compute
2454
- ggml_vk_matmul(ctx, subctx, *pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21); // NOLINT
2558
+ ggml_vk_matmul(ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21); // NOLINT
2455
2559
 
2456
- if (dst->backend == GGML_BACKEND_CPU) {
2560
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
2457
2561
  // copy dst to host
2458
2562
  float * d = (float *) ((char *) dst->data);
2459
2563
  ggml_vk_buffer_read_async(ctx, subctx, d_D, 0, d, sizeof(float) * d_ne * ne12 * ne13);
@@ -2499,15 +2603,15 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2499
2603
  bool src0_uma = false;
2500
2604
  bool src1_uma = false;
2501
2605
 
2502
- if (ctx->device.lock()->uma) {
2606
+ if (ctx->device->uma) {
2503
2607
  ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
2504
2608
  ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
2505
2609
  src0_uma = d_Qx != nullptr;
2506
2610
  src1_uma = d_Qy != nullptr;
2507
2611
  }
2508
2612
 
2509
- const bool load_x = src0->backend != GGML_BACKEND_GPU && !src0_uma;
2510
- const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma;
2613
+ const bool load_x = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma;
2614
+ const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
2511
2615
 
2512
2616
  const bool x_non_contig = !load_x && !ggml_vk_dim01_contiguous(src0);
2513
2617
  const bool y_non_contig = !load_y && !ggml_vk_dim01_contiguous(src1);
@@ -2521,9 +2625,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2521
2625
  const uint64_t y_ne = ne11 * ne10;
2522
2626
  const uint64_t d_ne = ne11 * ne01;
2523
2627
 
2524
- const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
2628
+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
2525
2629
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
2526
- const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
2630
+ const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
2527
2631
  const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
2528
2632
  const uint64_t d_sz = sizeof(float) * d_ne;
2529
2633
 
@@ -2563,8 +2667,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2563
2667
  GGML_ASSERT(qy_sz == y_sz);
2564
2668
  }
2565
2669
 
2566
- vk_pipeline * to_fp16_vk_0 = nullptr;
2567
- vk_pipeline* to_fp16_vk_1 = nullptr;
2670
+ vk_pipeline to_fp16_vk_0 = nullptr;
2671
+ vk_pipeline to_fp16_vk_1 = nullptr;
2568
2672
  if (x_non_contig) {
2569
2673
  to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
2570
2674
  }
@@ -2573,30 +2677,30 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2573
2677
  } else {
2574
2678
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
2575
2679
  }
2576
- vk_pipeline* dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type);
2680
+ vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type);
2577
2681
  GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
2578
2682
  GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
2579
2683
  GGML_ASSERT(dmmv != nullptr);
2580
2684
 
2581
2685
  // Allocate descriptor sets
2582
2686
  if (qx_needs_dequant) {
2583
- ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_0, 1);
2687
+ ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
2584
2688
  }
2585
2689
  if (qy_needs_dequant) {
2586
- ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
2690
+ ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
2587
2691
  }
2588
- ggml_pipeline_allocate_descriptor_sets(ctx, *dmmv, ne12 * ne13);
2692
+ ggml_pipeline_allocate_descriptor_sets(ctx, dmmv, ne12 * ne13);
2589
2693
 
2590
2694
  if (x_non_contig) {
2591
- GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment));
2592
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }, src0->type);
2695
+ GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
2696
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
2593
2697
  } else if (load_x) {
2594
2698
  // copy data to device
2595
2699
  ggml_vk_h2d_tensor_2d(ctx, subctx, d_Qx, 0, src0, 0, 0, ggml_nrows(src0));
2596
2700
  }
2597
2701
  if (y_non_contig) {
2598
2702
  GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
2599
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, src1->type);
2703
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
2600
2704
  } else if (load_y) {
2601
2705
  ggml_vk_h2d_tensor_2d(ctx, subctx, d_Qy, 0, src1, 0, 0, ggml_nrows(src1));
2602
2706
  }
@@ -2613,24 +2717,24 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2613
2717
  const uint64_t y_offset = y_buf_offset + y_sz * it_idx1;
2614
2718
  const uint64_t d_offset = d_buf_offset + d_sz * it_idx1;
2615
2719
 
2616
- const uint64_t y_buffer_offset = (y_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2720
+ const uint64_t y_buffer_offset = (y_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2617
2721
  const uint64_t y_shader_offset = y_offset - y_buffer_offset;
2618
2722
 
2619
- const uint64_t d_buffer_offset = (d_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2723
+ const uint64_t d_buffer_offset = (d_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2620
2724
  const uint64_t d_shader_offset = d_offset - d_buffer_offset;
2621
2725
 
2622
2726
  if (!y_non_contig && qy_needs_dequant) {
2623
- const std::vector<int> pc = { (int)ne11, (int)ne10, (int)ne10, (int)ne10 };
2727
+ const std::vector<uint32_t> pc = { (uint32_t)ne11, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(y_ne / 32) };
2624
2728
  ggml_vk_sync_buffers(subctx);
2625
- ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)y_ne, 1, 1});
2729
+ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)y_ne, 1, 1});
2626
2730
  }
2627
2731
 
2628
2732
  // compute
2629
- const std::array<int, 3> pc = { (int)ne00, (int)(y_shader_offset / ggml_type_size(src1->type)), (int)(d_shader_offset / ggml_type_size(dst->type))};
2733
+ const std::array<uint32_t, 3> pc = { (uint32_t)ne00, (uint32_t)(y_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type))};
2630
2734
  ggml_vk_sync_buffers(subctx);
2631
- ggml_vk_dispatch_pipeline(ctx, subctx, *dmmv, { { d_X, x_offset, x_sz }, { d_Y, y_buffer_offset, y_sz + y_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 3 * sizeof(int), &pc, { (uint32_t)ne01, 1, 1});
2735
+ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_offset, x_sz }, { d_Y, y_buffer_offset, y_sz + y_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 3 * sizeof(int), &pc, { (uint32_t)ne01, 1, 1});
2632
2736
 
2633
- if (dst->backend == GGML_BACKEND_CPU) {
2737
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
2634
2738
  // copy dst to host
2635
2739
  float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
2636
2740
  ggml_vk_sync_buffers(subctx);
@@ -2647,7 +2751,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
2647
2751
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
2648
2752
  #endif
2649
2753
  GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
2650
- GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
2754
+ GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU);
2651
2755
  GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT
2652
2756
  GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT
2653
2757
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -2674,18 +2778,18 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
2674
2778
 
2675
2779
  bool src1_uma = false;
2676
2780
 
2677
- if (ctx->device.lock()->uma) {
2781
+ if (ctx->device->uma) {
2678
2782
  ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
2679
2783
  src1_uma = d_Qy != nullptr;
2680
2784
  }
2681
2785
 
2682
- const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma;
2786
+ const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
2683
2787
 
2684
2788
  const uint64_t x_ne = ne00 * ne01 * ne02;
2685
2789
  const uint64_t y_ne = ne10 * ne11 * ne12;
2686
2790
  const uint64_t d_ne = ne01 * ne11 * ne12;
2687
2791
 
2688
- const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
2792
+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
2689
2793
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
2690
2794
  const uint64_t d_sz = sizeof(float) * d_ne;
2691
2795
 
@@ -2704,12 +2808,12 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
2704
2808
  }
2705
2809
 
2706
2810
  // Allocate descriptor sets
2707
- ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, 1);
2811
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
2708
2812
 
2709
- const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2813
+ const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2710
2814
  const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
2711
2815
 
2712
- const uint64_t d_buffer_offset = (d_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2816
+ const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2713
2817
  const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
2714
2818
 
2715
2819
  if (load_y) {
@@ -2719,9 +2823,9 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
2719
2823
  // compute
2720
2824
  const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
2721
2825
  ggml_vk_sync_buffers(subctx);
2722
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
2826
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
2723
2827
 
2724
- if (dst->backend == GGML_BACKEND_CPU) {
2828
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
2725
2829
  // copy dst to host
2726
2830
  float * d = (float *) dst->data;
2727
2831
  ggml_vk_sync_buffers(subctx);
@@ -2738,7 +2842,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
2738
2842
  GGML_ASSERT(!ggml_is_transposed(src0));
2739
2843
  GGML_ASSERT(!ggml_is_transposed(src1));
2740
2844
  GGML_ASSERT(!ggml_is_permuted(src0));
2741
- GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
2845
+ GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU);
2742
2846
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
2743
2847
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
2744
2848
 
@@ -2766,12 +2870,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
2766
2870
 
2767
2871
  bool src1_uma = false;
2768
2872
 
2769
- if (ctx->device.lock()->uma) {
2873
+ if (ctx->device->uma) {
2770
2874
  ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
2771
2875
  src1_uma = d_Qy != nullptr;
2772
2876
  }
2773
2877
 
2774
- const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma;
2878
+ const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
2775
2879
 
2776
2880
  const uint64_t d_ne = ne01 * ne11 * ne12;
2777
2881
 
@@ -2797,12 +2901,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
2797
2901
  }
2798
2902
 
2799
2903
  // Allocate descriptor sets
2800
- ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, 1);
2904
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
2801
2905
 
2802
- const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2906
+ const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2803
2907
  const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
2804
2908
 
2805
- const uint64_t d_buffer_offset = (d_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2909
+ const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2806
2910
  const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
2807
2911
 
2808
2912
  if (load_y) {
@@ -2812,9 +2916,9 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
2812
2916
  // compute
2813
2917
  const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
2814
2918
  ggml_vk_sync_buffers(subctx);
2815
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
2919
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
2816
2920
 
2817
- if (dst->backend == GGML_BACKEND_CPU) {
2921
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
2818
2922
  // copy dst to host
2819
2923
  float * d = (float *) dst->data;
2820
2924
  ggml_vk_sync_buffers(subctx);
@@ -2832,7 +2936,7 @@ static bool ggml_vk_can_mul_mat(const ggml_tensor * src0, const ggml_tensor * sr
2832
2936
  return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
2833
2937
  (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || ggml_is_quantized(src1->type)) &&
2834
2938
  dst->type == GGML_TYPE_F32 &&
2835
- ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU);
2939
+ ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_TYPE_GPU);
2836
2940
  }
2837
2941
 
2838
2942
  static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
@@ -2850,6 +2954,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx,
2850
2954
  }
2851
2955
  }
2852
2956
 
2957
+ // static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
2958
+ //
2959
+ // }
2960
+
2853
2961
  static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2854
2962
  // guaranteed to be an integer due to the check in ggml_can_repeat
2855
2963
  const uint64_t ne0 = dst->ne[0];
@@ -2880,8 +2988,8 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx
2880
2988
  // TODO: support for transposed / permuted tensors
2881
2989
  GGML_ASSERT(nb0 == sizeof(float));
2882
2990
  GGML_ASSERT(nb00 == sizeof(float));
2883
- GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
2884
- GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
2991
+ GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU);
2992
+ GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
2885
2993
 
2886
2994
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
2887
2995
  ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
@@ -2921,40 +3029,40 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx
2921
3029
  }
2922
3030
 
2923
3031
 
2924
- static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op) {
3032
+ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
2925
3033
  switch (op) {
2926
3034
  case GGML_OP_ADD:
2927
3035
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2928
- return &ctx->pipeline_add_f32;
3036
+ return ctx->device->pipeline_add_f32;
2929
3037
  }
2930
3038
  return nullptr;
2931
3039
  case GGML_OP_GET_ROWS:
2932
3040
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
2933
3041
  if (dst->type == GGML_TYPE_F16) {
2934
- return &ctx->pipeline_get_rows[src0->type];
3042
+ return ctx->device->pipeline_get_rows[src0->type];
2935
3043
  }
2936
3044
  if (dst->type == GGML_TYPE_F32) {
2937
- return &ctx->pipeline_get_rows_f32[src0->type];
3045
+ return ctx->device->pipeline_get_rows_f32[src0->type];
2938
3046
  }
2939
3047
  return nullptr;
2940
3048
  case GGML_OP_MUL:
2941
3049
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2942
- return &ctx->pipeline_mul_f32;
3050
+ return ctx->device->pipeline_mul_f32;
2943
3051
  }
2944
3052
  return nullptr;
2945
3053
  case GGML_OP_SCALE:
2946
3054
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2947
- return &ctx->pipeline_scale_f32;
3055
+ return ctx->device->pipeline_scale_f32;
2948
3056
  }
2949
3057
  return nullptr;
2950
3058
  case GGML_OP_SQR:
2951
3059
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2952
- return &ctx->pipeline_sqr_f32;
3060
+ return ctx->device->pipeline_sqr_f32;
2953
3061
  }
2954
3062
  return nullptr;
2955
3063
  case GGML_OP_CLAMP:
2956
3064
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2957
- return &ctx->pipeline_clamp_f32;
3065
+ return ctx->device->pipeline_clamp_f32;
2958
3066
  }
2959
3067
  return nullptr;
2960
3068
  case GGML_OP_CPY:
@@ -2963,29 +3071,29 @@ static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
2963
3071
  return ggml_vk_get_cpy_pipeline(ctx, src0->type, dst->type);
2964
3072
  case GGML_OP_NORM:
2965
3073
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2966
- return &ctx->pipeline_norm_f32;
3074
+ return ctx->device->pipeline_norm_f32;
2967
3075
  }
2968
3076
  return nullptr;
2969
3077
  case GGML_OP_RMS_NORM:
2970
3078
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2971
- return &ctx->pipeline_rms_norm_f32;
3079
+ return ctx->device->pipeline_rms_norm_f32;
2972
3080
  }
2973
3081
  return nullptr;
2974
3082
  case GGML_OP_UNARY:
2975
3083
  switch (ggml_get_unary_op(dst)) {
2976
3084
  case GGML_UNARY_OP_SILU:
2977
3085
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2978
- return &ctx->pipeline_silu_f32;
3086
+ return ctx->device->pipeline_silu_f32;
2979
3087
  }
2980
3088
  break;
2981
3089
  case GGML_UNARY_OP_GELU:
2982
3090
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2983
- return &ctx->pipeline_gelu_f32;
3091
+ return ctx->device->pipeline_gelu_f32;
2984
3092
  }
2985
3093
  break;
2986
3094
  case GGML_UNARY_OP_RELU:
2987
3095
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2988
- return &ctx->pipeline_relu_f32;
3096
+ return ctx->device->pipeline_relu_f32;
2989
3097
  }
2990
3098
  break;
2991
3099
  default:
@@ -2994,12 +3102,12 @@ static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
2994
3102
  return nullptr;
2995
3103
  case GGML_OP_DIAG_MASK_INF:
2996
3104
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2997
- return &ctx->pipeline_diag_mask_inf_f32;
3105
+ return ctx->device->pipeline_diag_mask_inf_f32;
2998
3106
  }
2999
3107
  return nullptr;
3000
3108
  case GGML_OP_SOFT_MAX:
3001
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3002
- return &ctx->pipeline_soft_max_f32;
3109
+ if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
3110
+ return ctx->device->pipeline_soft_max_f32;
3003
3111
  }
3004
3112
  return nullptr;
3005
3113
  case GGML_OP_ROPE:
@@ -3014,21 +3122,26 @@ static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
3014
3122
 
3015
3123
  if (is_neox) {
3016
3124
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3017
- return &ctx->pipeline_rope_neox_f32;
3125
+ return ctx->device->pipeline_rope_neox_f32;
3018
3126
  }
3019
3127
  if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
3020
- return &ctx->pipeline_rope_neox_f16;
3128
+ return ctx->device->pipeline_rope_neox_f16;
3021
3129
  }
3022
3130
  } else {
3023
3131
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3024
- return &ctx->pipeline_rope_f32;
3132
+ return ctx->device->pipeline_rope_f32;
3025
3133
  }
3026
3134
  if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
3027
- return &ctx->pipeline_rope_f16;
3135
+ return ctx->device->pipeline_rope_f16;
3028
3136
  }
3029
3137
  }
3030
3138
  return nullptr;
3031
3139
  }
3140
+ case GGML_OP_ARGSORT:
3141
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
3142
+ return ctx->device->pipeline_argsort_f32;
3143
+ }
3144
+ return nullptr;
3032
3145
  default:
3033
3146
  return nullptr;
3034
3147
  }
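Two details of how this lookup is consumed by ggml_vk_op_f32 further down: the new src2 argument is only inspected by the GGML_OP_SOFT_MAX case (every other caller passes nullptr for it), and a nullptr result is not immediately fatal, since the caller first checks for a host-side handler from ggml_vk_op_get_func(op), the path that still serves ops such as GGML_OP_REPEAT (which has no case in this switch), before treating the op as unsupported.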
@@ -3044,17 +3157,19 @@ static ggml_vk_func_t ggml_vk_op_get_func(ggml_op op) {
3044
3157
  }
3045
3158
 
3046
3159
  template<typename PC>
3047
- static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op, const PC&& pc) {
3160
+ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) {
3048
3161
  #ifdef GGML_VULKAN_DEBUG
3049
3162
  std::cerr << "ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
3050
3163
  if (src1 != nullptr) {
3051
3164
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
3052
3165
  }
3166
+ if (src2 != nullptr) {
3167
+ std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", backend=" << src2->backend << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
3168
+ }
3053
3169
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "), " << ggml_op_name(op) << ")" << std::endl;
3054
3170
  #endif
3055
3171
  GGML_ASSERT(!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type))); // NOLINT
3056
3172
  GGML_ASSERT(op == GGML_OP_CPY || ggml_vk_dim01_contiguous(src0)); // NOLINT
3057
- GGML_ASSERT(src1 == nullptr || ggml_vk_dim01_contiguous(src1)); // NOLINT
3058
3173
  GGML_ASSERT(dst->extra != nullptr);
3059
3174
  const uint64_t ne00 = src0->ne[0];
3060
3175
  const uint64_t ne01 = src0->ne[1];
@@ -3071,7 +3186,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3071
3186
  const uint64_t nb2 = dst->nb[2];
3072
3187
  const uint64_t nb3 = dst->nb[3];
3073
3188
 
3074
- vk_pipeline * pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, dst, op);
3189
+ const bool use_src2 = src2 != nullptr;
3190
+ const uint64_t ne2 = use_src2 ? src2->ne[0] * src2->ne[1] : 0;
3191
+
3192
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
3075
3193
  ggml_vk_func_t op_func;
3076
3194
 
3077
3195
  if (pipeline == nullptr) {
@@ -3092,40 +3210,50 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3092
3210
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
3093
3211
  ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
3094
3212
  ggml_tensor_extra_gpu * extra_src1 = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
3213
+ ggml_tensor_extra_gpu * extra_src2 = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
3095
3214
 
3096
3215
  vk_buffer d_X = nullptr;
3097
3216
  size_t x_buf_offset = 0;
3098
3217
  vk_buffer d_Y = nullptr;
3099
3218
  size_t y_buf_offset = 0;
3219
+ vk_buffer d_Z = nullptr;
3220
+ size_t z_buf_offset = 0;
3100
3221
 
3101
3222
  bool src0_uma = false;
3102
3223
  bool src1_uma = false;
3224
+ bool src2_uma = false;
3103
3225
 
3104
- if (ctx->device.lock()->uma) {
3226
+ if (ctx->device->uma) {
3105
3227
  ggml_vk_host_get(ctx, src0->data, d_X, x_buf_offset);
3106
3228
  src0_uma = d_X != nullptr;
3107
3229
  if (use_src1) {
3108
3230
  ggml_vk_host_get(ctx, src1->data, d_Y, y_buf_offset);
3109
3231
  src1_uma = d_Y != nullptr;
3110
3232
  }
3233
+ if (use_src2) {
3234
+ ggml_vk_host_get(ctx, src2->data, d_Z, z_buf_offset);
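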
3235
+ src2_uma = d_Z != nullptr;
3236
+ }
3111
3237
  }
3112
3238
 
3113
- const bool transfer_src0 = src0->backend != GGML_BACKEND_GPU && !src0_uma;
3114
- const bool transfer_src1 = use_src1 && src1->backend != GGML_BACKEND_GPU && !src1_uma;
3239
+ const bool transfer_src0 = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma;
3240
+ const bool transfer_src1 = use_src1 && src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
3241
+ const bool transfer_src2 = use_src2 && src2->backend != GGML_BACKEND_TYPE_GPU && !src2_uma;
3115
3242
 
3116
- uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type) * ne0, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
3117
- uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) : 0;
3243
+ uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type) * ne0, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
3244
+ uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
3245
+ uint64_t z_sz = use_src2 ? ggml_vk_align_size(ggml_type_size(src2->type) * ne2, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
3118
3246
  uint64_t d_sz = ggml_type_size(dst->type) * ne0;
3119
3247
 
3120
3248
  vk_buffer d_D = extra->buffer_gpu.lock();
3121
3249
 
3122
3250
  // Workaround for tiny tensor inputs on ROPE
3123
- if (use_src1 && src1->backend == GGML_BACKEND_GPU && y_sz > d_D->size) {
3251
+ if (use_src1 && src1->backend == GGML_BACKEND_TYPE_GPU && y_sz > d_D->size) {
3124
3252
  y_sz = VK_WHOLE_SIZE;
3125
3253
  }
3126
3254
 
3127
3255
  GGML_ASSERT(d_D != nullptr);
3128
- uint64_t d_buf_offset = (extra->offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
3256
+ uint64_t d_buf_offset = (extra->offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
3129
3257
  GGML_ASSERT(d_buf_offset == extra->offset || op == GGML_OP_CPY); // NOLINT
3130
3258
  if (transfer_src0) {
3131
3259
  d_X = ctx->prealloc_qx;
@@ -3142,6 +3270,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3142
3270
  GGML_ASSERT(d_Y != nullptr);
3143
3271
  }
3144
3272
 
3273
+ GGML_ASSERT(!transfer_src2);
3274
+ if (use_src2 && !src2_uma) {
3275
+ d_Z = extra_src2->buffer_gpu.lock();
3276
+ z_buf_offset = extra_src2->offset;
3277
+ GGML_ASSERT(d_Z != nullptr);
3278
+ }
3279
+
3145
3280
  if (op == GGML_OP_CPY) {
3146
3281
  GGML_ASSERT(!transfer_src0);
3147
3282
  GGML_ASSERT(!transfer_src1);
@@ -3169,7 +3304,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3169
3304
 
3170
3305
  // Single call if dimension 2 is contiguous
3171
3306
  if (op == GGML_OP_CPY || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) {
3172
- ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, 1);
3307
+ ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1);
3173
3308
 
3174
3309
  switch (dst->op) {
3175
3310
  case GGML_OP_NORM:
@@ -3198,26 +3333,42 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3198
3333
  }
3199
3334
  }
3200
3335
 
3201
- if (!use_src1 && op == GGML_OP_SOFT_MAX) {
3202
- // Empty src1 is possible on soft_max, but the shader needs a buffer
3336
+ if (op == GGML_OP_SOFT_MAX) {
3337
+ // Empty src1 and src2 are possible on soft_max, but the shader needs buffers
3338
+ vk_subbuffer subbuf_y;
3339
+ if (use_src1) {
3340
+ subbuf_y = { d_Y, y_buf_offset, y_sz };
3341
+ } else {
3342
+ subbuf_y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
3343
+ }
3344
+
3345
+ vk_subbuffer subbuf_z;
3346
+ if (use_src2) {
3347
+ subbuf_z = { d_Z, z_buf_offset, z_sz };
3348
+ } else {
3349
+ subbuf_z = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
3350
+ }
3351
+
3203
3352
  ggml_vk_sync_buffers(subctx);
3204
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { ctx->prealloc_y, 0, ctx->prealloc_y->size }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3353
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3205
3354
  } else if (use_src1) {
3206
3355
  ggml_vk_sync_buffers(subctx);
3207
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3356
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3208
3357
  } else {
3209
3358
  ggml_vk_sync_buffers(subctx);
3210
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3359
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3211
3360
  }
3212
- if (dst->backend == GGML_BACKEND_CPU && op == GGML_OP_CPY) {
3361
+ if (dst->backend == GGML_BACKEND_TYPE_CPU && op == GGML_OP_CPY) {
3213
3362
  ggml_vk_d2h_tensor_2d(ctx, subctx, d_D, 0, dst);
3214
- } else if(dst->backend == GGML_BACKEND_CPU) {
3363
+ } else if(dst->backend == GGML_BACKEND_TYPE_CPU) {
3215
3364
  // copy dst to host
3216
3365
  float * d = (float *) dst->data;
3217
3366
  ggml_vk_buffer_read_async(ctx, subctx, d_D, 0, d, d_sz);
3218
3367
  }
3219
3368
  } else {
3220
- ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, ne02 * ne03);
3369
+ GGML_ASSERT(op != GGML_OP_SOFT_MAX);
3370
+
3371
+ ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, ne02 * ne03);
3221
3372
 
3222
3373
  switch (dst->op) {
3223
3374
  case GGML_OP_NORM:
@@ -3242,18 +3393,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3242
3393
  const uint32_t y_offset = y_sz * it_idx1;
3243
3394
  const uint32_t d_offset = d_sz * it_idx0;
3244
3395
 
3245
- if (!use_src1 && op == GGML_OP_SOFT_MAX) {
3246
- // Empty src1 is possible on soft_max, but the shader needs a buffer
3247
- ggml_vk_sync_buffers(subctx);
3248
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { ctx->prealloc_y, 0, ctx->prealloc_y->size }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3249
- } else if (use_src1) {
3396
+ if (use_src1) {
3250
3397
  ggml_vk_sync_buffers(subctx);
3251
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_Y, y_buf_offset + y_offset, y_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
3398
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_Y, y_buf_offset + y_offset, y_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
3252
3399
  } else {
3253
3400
  ggml_vk_sync_buffers(subctx);
3254
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
3401
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
3255
3402
  }
3256
- if (dst->backend == GGML_BACKEND_CPU) {
3403
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
3257
3404
  // copy dst to host
3258
3405
  ggml_vk_buffer_read_async(ctx, subctx, d_D, d_buf_offset + d_offset, (char *) dst->data + i02*nb2 + i03*nb3, d_sz);
3259
3406
  }
@@ -3263,69 +3410,141 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3263
3410
  }
3264
3411
 
3265
3412
  static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3266
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3413
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3267
3414
  }
3268
3415
 
3269
3416
  static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3270
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_GET_ROWS, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3417
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3271
3418
  }
3272
3419
 
3273
3420
  static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3274
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ADD, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3421
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3422
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
3423
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3424
+
3425
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, {
3426
+ (uint32_t)ggml_nelements(src0),
3427
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3428
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
3429
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3430
+ 0,
3431
+ 0.0f, 0.0f,
3432
+ });
3275
3433
  }
3276
3434
 
3277
3435
  static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3278
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_MUL, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3436
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3437
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
3438
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3439
+
3440
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, {
3441
+ (uint32_t)ggml_nelements(src0),
3442
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3443
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
3444
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3445
+ 0,
3446
+ 0.0f, 0.0f,
3447
+ });
3279
3448
  }
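ggml_vk_add and ggml_vk_mul above now pass full 4-D extents and element strides instead of a flat element count. A minimal sketch of a push-constant block matching the initializer order used here; the struct name and member names below are placeholders, the real vk_op_binary_push_constants definition lives earlier in this file and may differ.

    #include <cstdint>

    // Sketch only: field order mirrors the initializer lists above
    // (element count, src0 ne/nb, src1 ne/nb, dst ne/nb, offset, two float params).
    struct vk_op_binary_push_constants_sketch {
        uint32_t ne;                                              // total elements of src0
        uint32_t ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03;  // src0 extents and strides (strides in elements)
        uint32_t ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13;  // src1 extents and strides
        uint32_t ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23;  // dst extents and strides
        uint32_t d_offset;                                        // destination element offset
        float    param1, param2;                                  // op-specific scalars (unused for ADD/MUL)
    };

Dividing each nb[] by the type size, as the call sites above do, converts ggml's byte strides into element strides before they reach the shader.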
3280
3449
 
3281
3450
  static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3282
3451
  float * op_params = (float *)dst->op_params;
3283
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_SCALE, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
3452
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3453
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3454
+
3455
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
3456
+ (uint32_t)ggml_nelements(src0),
3457
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3458
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3459
+ 0,
3460
+ op_params[0], 0.0f
3461
+ });
3284
3462
  }
3285
3463
 
3286
3464
  static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3287
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_SQR, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
3465
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3466
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3467
+
3468
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
3469
+ (uint32_t)ggml_nelements(src0),
3470
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3471
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3472
+ 0,
3473
+ 0.0f, 0.0f,
3474
+ });
3288
3475
  }
3289
3476
 
3290
3477
  static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3291
3478
  float * op_params = (float *)dst->op_params;
3292
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_CLAMP, { (uint32_t)ggml_nelements(src0), 0, op_params[0], op_params[1] });
3479
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3480
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3481
+
3482
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
3483
+ (uint32_t)ggml_nelements(src0),
3484
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3485
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3486
+ 0,
3487
+ op_params[0], op_params[1],
3488
+ });
3293
3489
  }
3294
3490
 
3295
3491
  static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3296
3492
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
3297
- const int src0_type_size = ggml_type_size(src0->type);
3298
- const int dst_type_size = ggml_type_size(dst->type);
3299
- const uint32_t d_offset = (extra->offset % ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
3300
- ggml_vk_op_f32<vk_op_cpy_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_CPY, {
3493
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3494
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3495
+ const uint32_t d_offset = (extra->offset % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
3496
+
3497
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
3301
3498
  (uint32_t)ggml_nelements(src0),
3302
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size,
3303
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size,
3499
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3500
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3304
3501
  d_offset,
3502
+ 0.0f, 0.0f,
3305
3503
  });
3306
3504
  }
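The d_offset computed by ggml_vk_cpy splits the tensor's byte offset into an aligned part and a small element-count remainder that the shader handles. A worked example of that arithmetic, with the alignment and offset values assumed purely for illustration:

    // Worked example (assumed values): f32 destination at byte offset 200 inside its buffer,
    // with minStorageBufferOffsetAlignment == 64.
    const uint32_t offset        = 200;            // extra->offset
    const uint32_t alignment     = 64;             // device limit (assumed)
    const uint32_t dst_type_size = sizeof(float);  // 4 bytes
    const uint32_t d_offset      = (offset % alignment) / dst_type_size;  // (200 % 64) / 4 == 2 elements
    // The buffer range would be bound at the aligned-down byte offset (192 here), and the
    // shader skips the remaining d_offset elements.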
3307
3505
 
3308
3506
  static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3309
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
3507
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
3310
3508
  }
3311
3509
 
3312
3510
  static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3313
3511
  float * op_params = (float *)dst->op_params;
3314
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
3512
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
3315
3513
  }
3316
3514
 
3317
3515
  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3318
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
3516
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
3319
3517
  }
3320
3518
 
3321
3519
  static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3322
3520
  int32_t * op_params = (int32_t *)dst->op_params;
3323
- ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
3521
+ ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
3324
3522
  }
3325
3523
 
3326
- static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3524
+ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
3327
3525
  float * op_params = (float *)dst->op_params;
3328
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_SOFT_MAX, { (uint32_t)src0->ne[0], (uint32_t)(src1 != nullptr ? ggml_nrows(src1) : 0), op_params[0], 0.0f });
3526
+
3527
+ float scale = op_params[0];
3528
+ float max_bias = op_params[1];
3529
+
3530
+ const uint32_t ncols = (uint32_t)src0->ne[0];
3531
+ const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
3532
+ const uint32_t nrows_y = (uint32_t)src0->ne[1];
3533
+
3534
+ const uint32_t n_head_kv = nrows_x/nrows_y;
3535
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
3536
+
3537
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
3538
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
3539
+
3540
+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
3541
+ ncols,
3542
+ nrows_y,
3543
+ src2 != nullptr ? (uint32_t)1 : (uint32_t)0,
3544
+ scale, max_bias,
3545
+ m0, m1,
3546
+ n_head_log2,
3547
+ });
3329
3548
  }
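The soft_max path now forwards scale, max_bias and the derived m0/m1/n_head_log2 so the shader can apply ALiBi-style positional penalties. A reference sketch of the per-head slope these constants encode, assuming the usual ggml ALiBi convention; the function name is hypothetical and the shader that consumes the constants is not part of this hunk.

    #include <cmath>
    #include <cstdint>

    // Reference sketch: per-head ALiBi slope reconstructed from m0, m1 and n_head_log2.
    // With max_bias == 0 both m0 and m1 collapse to 1, so any mask term passes through unscaled.
    static float alibi_slope_sketch(uint32_t head, uint32_t n_head_log2, float m0, float m1) {
        if (head < n_head_log2) {
            return powf(m0, (float)(head + 1));
        }
        return powf(m1, (float)(2*(head - n_head_log2) + 1));
    }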
3330
3549
 
3331
3550
  static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -3351,15 +3570,20 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
3351
3570
  if (is_neox) {
3352
3571
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
3353
3572
  const float inv_ndims = -1.0f / n_dims;
3354
- ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f, theta_scale, inv_ndims });
3573
+ ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f, theta_scale, inv_ndims });
3355
3574
  } else {
3356
- ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f });
3575
+ ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f });
3357
3576
  }
3358
3577
  }
3359
3578
 
3579
+ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3580
+ int32_t * op_params = (int32_t *)dst->op_params;
3581
+ ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { (uint32_t)src0->ne[0], ((ggml_sort_order) op_params[0]) == GGML_SORT_ORDER_ASC });
3582
+ }
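ggml_vk_argsort forwards only the row width and the requested order; the comparison against GGML_SORT_ORDER_ASC collapses the enum into a flag. The push-constant block is presumably as small as the two-entry initializer suggests; the sketch below uses placeholder names and is not quoted from the real definition.

    #include <cstdint>

    // Hypothetical sketch matching the two-entry initializer above.
    struct vk_op_argsort_push_constants_sketch {
        uint32_t ncols;     // src0->ne[0]: elements per row to sort
        uint32_t ascending; // nonzero for GGML_SORT_ORDER_ASC
    };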
3583
+
3360
3584
  static void ggml_vk_nop(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3361
3585
  // If backend is CPU, data from src0 has to be copied off the device
3362
- if (dst->backend == GGML_BACKEND_CPU) {
3586
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
3363
3587
  ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
3364
3588
  vk_buffer d_D = extra_src0->buffer_gpu.lock();
3365
3589
  ggml_vk_sync_buffers(subctx);
@@ -3408,43 +3632,43 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3408
3632
  const size_t y_ne = k * n * batch;
3409
3633
  const size_t d_ne = m * n * batch;
3410
3634
 
3411
- vk_pipeline * p;
3635
+ vk_pipeline p;
3412
3636
  std::string shname;
3413
3637
  if (shader_size == 0) {
3414
3638
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3415
- p = &ctx->pipeline_matmul_f32_aligned_s;
3639
+ p = ctx->device->pipeline_matmul_f32->a_s;
3416
3640
  shname = "F32_ALIGNED_S";
3417
3641
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3418
- p = &ctx->pipeline_matmul_f16_f32_aligned_s;
3642
+ p = ctx->device->pipeline_matmul_f16_f32->a_s;
3419
3643
  shname = "F16_F32_ALIGNED_S";
3420
3644
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3421
- p = &ctx->pipeline_matmul_f16_aligned_s;
3645
+ p = ctx->device->pipeline_matmul_f16->a_s;
3422
3646
  shname = "F16_ALIGNED_S";
3423
3647
  } else {
3424
3648
  GGML_ASSERT(false);
3425
3649
  }
3426
3650
  } else if (shader_size == 1) {
3427
3651
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3428
- p = &ctx->pipeline_matmul_f32_aligned_m;
3652
+ p = ctx->device->pipeline_matmul_f32->a_m;
3429
3653
  shname = "F32_ALIGNED_M";
3430
3654
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3431
- p = &ctx->pipeline_matmul_f16_f32_aligned_m;
3655
+ p = ctx->device->pipeline_matmul_f16_f32->a_m;
3432
3656
  shname = "F16_F32_ALIGNED_M";
3433
3657
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3434
- p = &ctx->pipeline_matmul_f16_aligned_m;
3658
+ p = ctx->device->pipeline_matmul_f16->a_m;
3435
3659
  shname = "F16_ALIGNED_M";
3436
3660
  } else {
3437
3661
  GGML_ASSERT(false);
3438
3662
  }
3439
3663
  } else if (shader_size == 2) {
3440
3664
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3441
- p = &ctx->pipeline_matmul_f32_aligned_l;
3665
+ p = ctx->device->pipeline_matmul_f32->a_l;
3442
3666
  shname = "F32_ALIGNED_L";
3443
3667
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3444
- p = &ctx->pipeline_matmul_f16_f32_aligned_l;
3668
+ p = ctx->device->pipeline_matmul_f16_f32->a_l;
3445
3669
  shname = "F16_F32_ALIGNED_L";
3446
3670
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3447
- p = &ctx->pipeline_matmul_f16_aligned_l;
3671
+ p = ctx->device->pipeline_matmul_f16->a_l;
3448
3672
  shname = "F16_ALIGNED_L";
3449
3673
  } else {
3450
3674
  GGML_ASSERT(false);
@@ -3458,43 +3682,43 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3458
3682
  if (k != kpad) {
3459
3683
  if (shader_size == 0) {
3460
3684
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3461
- p = &ctx->pipeline_matmul_f32_s;
3685
+ p = ctx->device->pipeline_matmul_f32->s;
3462
3686
  shname = "F32_S";
3463
3687
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3464
- p = &ctx->pipeline_matmul_f16_f32_s;
3688
+ p = ctx->device->pipeline_matmul_f16_f32->s;
3465
3689
  shname = "F16_F32_S";
3466
3690
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3467
- p = &ctx->pipeline_matmul_f16_s;
3691
+ p = ctx->device->pipeline_matmul_f16->s;
3468
3692
  shname = "F16_S";
3469
3693
  }
3470
3694
  } else if (shader_size == 1) {
3471
3695
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3472
- p = &ctx->pipeline_matmul_f32_m;
3696
+ p = ctx->device->pipeline_matmul_f32->m;
3473
3697
  shname = "F32_M";
3474
3698
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3475
- p = &ctx->pipeline_matmul_f16_f32_m;
3699
+ p = ctx->device->pipeline_matmul_f16_f32->m;
3476
3700
  shname = "F16_F32_M";
3477
3701
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3478
- p = &ctx->pipeline_matmul_f16_m;
3702
+ p = ctx->device->pipeline_matmul_f16->m;
3479
3703
  shname = "F16_M";
3480
3704
  }
3481
3705
  } else if (shader_size == 2) {
3482
3706
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3483
- p = &ctx->pipeline_matmul_f32_l;
3707
+ p = ctx->device->pipeline_matmul_f32->l;
3484
3708
  shname = "F32_L";
3485
3709
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3486
- p = &ctx->pipeline_matmul_f16_f32_l;
3710
+ p = ctx->device->pipeline_matmul_f16_f32->l;
3487
3711
  shname = "F16_F32_L";
3488
3712
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3489
- p = &ctx->pipeline_matmul_f16_l;
3713
+ p = ctx->device->pipeline_matmul_f16->l;
3490
3714
  shname = "F16_L";
3491
3715
  }
3492
3716
  }
3493
3717
  }
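Both selection ladders in ggml_vk_test_matmul reduce to one decision: take the aligned s/m/l variant first, then fall back to the unaligned one when k is not a multiple of the pipeline's alignment. A compact sketch of that decision using the grouped pipeline handle seen above; the helper name is hypothetical.

    // Sketch: selects one of the six variants (a_s/a_m/a_l for aligned k, s/m/l otherwise).
    static vk_pipeline pick_matmul_variant(const vk_matmul_pipeline& mmp, size_t shader_size, bool aligned) {
        if (aligned) {
            return shader_size == 0 ? mmp->a_s : shader_size == 1 ? mmp->a_m : mmp->a_l;
        }
        return shader_size == 0 ? mmp->s : shader_size == 1 ? mmp->m : mmp->l;
    }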
3494
3718
 
3495
- ggml_pipeline_allocate_descriptor_sets(ctx, *p, num_it);
3719
+ ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
3496
3720
  if (split_k > 1) {
3497
- ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_matmul_split_k_reduce, num_it);
3721
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
3498
3722
 
3499
3723
  if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
3500
3724
  // Resize buffer
@@ -3524,9 +3748,11 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3524
3748
  }
3525
3749
  for (size_t i = 0; i < y_ne; i++) {
3526
3750
  if (std::is_same<float, Y_TYPE>()) {
3527
- y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
3751
+ // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
3752
+ y[i] = (i % k == i / k) ? 1.0f : 0.0f;
3528
3753
  } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
3529
- y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
3754
+ // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
3755
+ y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
3530
3756
  } else {
3531
3757
  GGML_ASSERT(false);
3532
3758
  }
@@ -3535,17 +3761,17 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3535
3761
  ggml_vk_buffer_write(ctx, d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
3536
3762
  ggml_vk_buffer_write(ctx, d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
3537
3763
 
3538
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
3764
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
3539
3765
  for (size_t i = 0; i < num_it; i++) {
3540
3766
  ggml_vk_ctx_begin(ctx, subctx);
3541
- ggml_vk_matmul(ctx, subctx, *p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
3767
+ ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
3542
3768
  ggml_vk_ctx_end(subctx);
3543
3769
  }
3544
3770
 
3545
3771
  auto begin = std::chrono::high_resolution_clock::now();
3546
3772
  ggml_vk_submit(subctx, ctx->fence);
3547
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
3548
- ctx->device.lock()->device.resetFences({ ctx->fence });
3773
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
3774
+ ctx->device->device.resetFences({ ctx->fence });
3549
3775
 
3550
3776
  auto end = std::chrono::high_resolution_clock::now();
3551
3777
  double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
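The recurring replacement of ctx->device.lock()-> with ctx->device-> in these hunks points at an ownership change for the device handle: what previously behaved like a weak reference that had to be locked on every use now behaves like a handle held for the context's lifetime. A minimal sketch of that shape; the member type is inferred, not quoted from the actual context definition earlier in this diff.

    #include <memory>

    struct vk_device;  // defined elsewhere in this file

    // Sketch only: inferred from the removal of .lock() at every call site.
    struct backend_context_sketch {
        std::shared_ptr<vk_device> device;  // previously accessed through a weak reference
    };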
@@ -3624,6 +3850,8 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3624
3850
  std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
3625
3851
  std::cerr << "Actual result: " << std::endl << std::endl;
3626
3852
  ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
3853
+ std::cerr << std::endl;
3854
+ ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
3627
3855
  std::cerr << "Expected result: " << std::endl << std::endl;
3628
3856
  ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
3629
3857
 
@@ -3649,15 +3877,15 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3649
3877
 
3650
3878
  free(d_chk);
3651
3879
 
3652
- ggml_vk_queue_cleanup(ctx, ctx->device.lock()->transfer_queue);
3653
- ggml_vk_queue_cleanup(ctx, ctx->device.lock()->compute_queue);
3880
+ ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
3881
+ ggml_vk_queue_cleanup(ctx, ctx->device->compute_queue);
3654
3882
 
3655
3883
  ggml_vk_destroy_buffer(d_X);
3656
3884
  ggml_vk_destroy_buffer(d_Y);
3657
3885
  ggml_vk_destroy_buffer(d_D);
3658
3886
 
3659
- ggml_pipeline_cleanup(*p);
3660
- ggml_pipeline_cleanup(ctx->pipeline_matmul_split_k_reduce);
3887
+ ggml_pipeline_cleanup(p);
3888
+ ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce);
3661
3889
 
3662
3890
  free(x);
3663
3891
  free(y);
@@ -3730,7 +3958,7 @@ static void ggml_vk_test_h2d_nc(ggml_backend_vk_context * ctx, size_t ne0, size_
3730
3958
  data[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
3731
3959
  }
3732
3960
 
3733
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
3961
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
3734
3962
  ggml_vk_ctx_begin(ctx, subctx);
3735
3963
 
3736
3964
  vk_buffer buffer = ggml_vk_create_buffer_check(ctx, ggml_nbytes(tensor), vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -3739,8 +3967,8 @@ static void ggml_vk_test_h2d_nc(ggml_backend_vk_context * ctx, size_t ne0, size_
3739
3967
 
3740
3968
  ggml_vk_ctx_end(subctx);
3741
3969
  ggml_vk_submit(subctx, ctx->fence);
3742
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_h2d_nc waitForFences");
3743
- ctx->device.lock()->device.resetFences({ ctx->fence });
3970
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_h2d_nc waitForFences");
3971
+ ctx->device->device.resetFences({ ctx->fence });
3744
3972
 
3745
3973
  ggml_vk_buffer_read(ctx, buffer, 0, result_data, ggml_nbytes(tensor));
3746
3974
 
@@ -3812,7 +4040,7 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
3812
4040
  x[i] = rand() / (float)RAND_MAX;
3813
4041
  }
3814
4042
 
3815
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
4043
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
3816
4044
  ggml_vk_ctx_begin(ctx, subctx);
3817
4045
 
3818
4046
  auto begin = std::chrono::high_resolution_clock::now();
@@ -3826,8 +4054,8 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
3826
4054
 
3827
4055
  ggml_vk_ctx_end(subctx);
3828
4056
  ggml_vk_submit(subctx, ctx->fence);
3829
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
3830
- ctx->device.lock()->device.resetFences({ ctx->fence });
4057
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
4058
+ ctx->device->device.resetFences({ ctx->fence });
3831
4059
 
3832
4060
  auto end = std::chrono::high_resolution_clock::now();
3833
4061
 
@@ -3841,8 +4069,8 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
3841
4069
 
3842
4070
  ggml_vk_ctx_end(subctx);
3843
4071
  ggml_vk_submit(subctx, ctx->fence);
3844
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
3845
- ctx->device.lock()->device.resetFences({ ctx->fence });
4072
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
4073
+ ctx->device->device.resetFences({ ctx->fence });
3846
4074
 
3847
4075
  for (auto& cpy : subctx->out_memcpys) {
3848
4076
  memcpy(cpy.dst, cpy.src, cpy.n);
@@ -3873,89 +4101,118 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
3873
4101
  }
3874
4102
  }
3875
4103
 
3876
- static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
3877
- #ifdef GGML_VULKAN_DEBUG
3878
- std::cerr << "ggml_vk_test_dequant(" << ne << ")" << std::endl;
3879
- #endif
3880
- const size_t x_sz = sizeof(float) * ne;
3881
- const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
3882
- const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
3883
- float * x = (float *) malloc(x_sz);
3884
- void * qx = malloc(qx_sz);
3885
- vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
3886
- vk_buffer x_buf = ggml_vk_create_buffer_check(ctx, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal);
3887
- ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
3888
-
3889
- for (size_t i = 0; i < ne; i++) {
3890
- x[i] = rand() / (float)RAND_MAX;
3891
- }
3892
-
4104
+ static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
3893
4105
  std::vector<int64_t> hist_cur(1 << 4, 0);
3894
4106
 
3895
- vk_pipeline& p = ctx->pipeline_dequant[quant];
3896
-
3897
4107
  switch(quant) {
4108
+ case GGML_TYPE_F32:
4109
+ memcpy(to, from, sizeof(float) * ne);
4110
+ break;
3898
4111
  case GGML_TYPE_Q4_0:
3899
- ggml_quantize_q4_0(x, qx, ne, ne, hist_cur.data());
4112
+ ggml_quantize_q4_0(from, to, ne, ne, hist_cur.data());
3900
4113
  break;
3901
4114
  case GGML_TYPE_Q4_1:
3902
- ggml_quantize_q4_1(x, qx, ne, ne, hist_cur.data());
4115
+ ggml_quantize_q4_1(from, to, ne, ne, hist_cur.data());
3903
4116
  break;
3904
4117
  case GGML_TYPE_Q5_0:
3905
- ggml_quantize_q5_0(x, qx, ne, ne, hist_cur.data());
4118
+ ggml_quantize_q5_0(from, to, ne, ne, hist_cur.data());
3906
4119
  break;
3907
4120
  case GGML_TYPE_Q5_1:
3908
- ggml_quantize_q4_1(x, qx, ne, ne, hist_cur.data());
4121
+ ggml_quantize_q5_1(from, to, ne, ne, hist_cur.data());
3909
4122
  break;
3910
4123
  case GGML_TYPE_Q8_0:
3911
- ggml_quantize_q8_0(x, qx, ne, ne, hist_cur.data());
4124
+ ggml_quantize_q8_0(from, to, ne, ne, hist_cur.data());
3912
4125
  break;
3913
4126
  case GGML_TYPE_Q2_K:
3914
- ggml_quantize_q2_K(x, qx, ne, ne, hist_cur.data());
4127
+ ggml_quantize_q2_K(from, to, ne, ne, hist_cur.data());
3915
4128
  break;
3916
4129
  case GGML_TYPE_Q3_K:
3917
- ggml_quantize_q3_K(x, qx, ne, ne, hist_cur.data());
4130
+ ggml_quantize_q3_K(from, to, ne, ne, hist_cur.data());
3918
4131
  break;
3919
4132
  case GGML_TYPE_Q4_K:
3920
- ggml_quantize_q4_K(x, qx, ne, ne, hist_cur.data());
4133
+ ggml_quantize_q4_K(from, to, ne, ne, hist_cur.data());
3921
4134
  break;
3922
4135
  case GGML_TYPE_Q5_K:
3923
- ggml_quantize_q5_K(x, qx, ne, ne, hist_cur.data());
4136
+ ggml_quantize_q5_K(from, to, ne, ne, hist_cur.data());
3924
4137
  break;
3925
4138
  case GGML_TYPE_Q6_K:
3926
- ggml_quantize_q6_K(x, qx, ne, ne, hist_cur.data());
4139
+ ggml_quantize_q6_K(from, to, ne, ne, hist_cur.data());
3927
4140
  break;
3928
4141
  default:
3929
4142
  GGML_ASSERT(false);
3930
4143
  }
4144
+ }
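ggml_vk_quantize_data factors the per-type quantization switch out of the dequant test so the dequant-matmul test further below can reuse it. A short usage sketch for a Q4_0 round trip; the helper name and buffer sizing are illustrative, following the same ggml_type_size/ggml_blck_size pattern the tests use.

    #include <cstdint>
    #include <cstdlib>
    #include <vector>

    // Usage sketch: quantize `ne` random floats into Q4_0 blocks.
    static std::vector<uint8_t> quantize_q4_0_sketch(size_t ne) {
        std::vector<float> src(ne);  // ne should be a multiple of the Q4_0 block size
        for (size_t i = 0; i < ne; i++) {
            src[i] = rand() / (float) RAND_MAX;
        }
        std::vector<uint8_t> dst(ne * ggml_type_size(GGML_TYPE_Q4_0) / ggml_blck_size(GGML_TYPE_Q4_0));
        ggml_vk_quantize_data(src.data(), dst.data(), ne, GGML_TYPE_Q4_0);
        return dst;
    }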
4145
+
4146
+ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
4147
+ #ifdef GGML_VULKAN_DEBUG
4148
+ std::cerr << "ggml_vk_test_dequant(" << ne << ")" << std::endl;
4149
+ #endif
4150
+ const size_t x_sz = sizeof(float) * ne;
4151
+ const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
4152
+ const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
4153
+ float * x = (float *) malloc(x_sz);
4154
+ void * qx = malloc(qx_sz);
4155
+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
4156
+ vk_buffer x_buf = ggml_vk_create_buffer_check(ctx, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal);
4157
+ ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
4158
+
4159
+ for (size_t i = 0; i < ne; i++) {
4160
+ x[i] = rand() / (float)RAND_MAX;
4161
+ }
4162
+
4163
+ vk_pipeline p = ctx->device->pipeline_dequant[quant];
4164
+
4165
+ ggml_vk_quantize_data(x, qx, ne, quant);
3931
4166
 
3932
4167
  ggml_pipeline_allocate_descriptor_sets(ctx, p, 1);
3933
4168
 
3934
4169
  ggml_vk_buffer_write(ctx, qx_buf, 0, qx, qx_sz);
3935
4170
 
3936
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
4171
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
3937
4172
  ggml_vk_ctx_begin(ctx, subctx);
3938
- const std::vector<int> pc = { 1, (int)ne, (int)ne, (int)ne };
4173
+ const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
3939
4174
  ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
3940
4175
  ggml_vk_ctx_end(subctx);
3941
4176
 
3942
4177
  auto begin = std::chrono::high_resolution_clock::now();
3943
4178
 
3944
4179
  ggml_vk_submit(subctx, ctx->fence);
3945
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
3946
- ctx->device.lock()->device.resetFences({ ctx->fence });
4180
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
4181
+ ctx->device->device.resetFences({ ctx->fence });
3947
4182
 
3948
4183
  auto end = std::chrono::high_resolution_clock::now();
3949
4184
 
3950
4185
  double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
3951
4186
  ggml_vk_buffer_read(ctx, x_buf, 0, x_chk, x_sz_f16);
3952
4187
 
4188
+ int first_err = -1;
4189
+
3953
4190
  double avg_err = 0.0;
3954
4191
  for (size_t i = 0; i < ne; i++) {
3955
- avg_err += std::fabs(x[i] - ggml_fp16_to_fp32(x_chk[i]));
4192
+ double error = std::fabs(x[i] - ggml_fp16_to_fp32(x_chk[i]));
4193
+ avg_err += error;
4194
+
4195
+ if (first_err < 0 && error > 0.05) {
4196
+ first_err = i;
4197
+ }
3956
4198
  }
3957
4199
 
3958
- std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err / ne << std::endl;
4200
+ avg_err /= ne;
4201
+
4202
+ std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;
4203
+
4204
+ if (avg_err > 0.1) {
4205
+ std::cerr << "first_error = " << first_err << std::endl;
4206
+ std::cerr << "Actual result: " << std::endl << std::endl;
4207
+ for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
4208
+ std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
4209
+ }
4210
+ std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
4211
+ for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
4212
+ std::cerr << x[i] << ", ";
4213
+ }
4214
+ std::cerr << std::endl;
4215
+ }
3959
4216
 
3960
4217
  ggml_vk_destroy_buffer(x_buf);
3961
4218
  ggml_vk_destroy_buffer(qx_buf);
@@ -3964,6 +4221,190 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
3964
4221
  free(qx);
3965
4222
  free(x_chk);
3966
4223
  }
4224
+
4225
+ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
4226
+ #ifdef GGML_VULKAN_DEBUG
4227
+ std::cerr << "ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")" << std::endl;
4228
+ #endif
4229
+ const size_t x_ne = m * k * batch;
4230
+ const size_t y_ne = k * n * batch;
4231
+ const size_t d_ne = m * n * batch;
4232
+
4233
+ vk_pipeline p;
4234
+ std::string shname;
4235
+ if (shader_size == 0) {
4236
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_s;
4237
+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
4238
+ } else if (shader_size == 1) {
4239
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_m;
4240
+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
4241
+ } else if (shader_size == 2) {
4242
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_l;
4243
+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
4244
+ } else {
4245
+ GGML_ASSERT(0);
4246
+ }
4247
+
4248
+ const size_t kpad = ggml_vk_align_size(k, p->align);
4249
+
4250
+ if (k != kpad) {
4251
+ if (shader_size == 0) {
4252
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->s;
4253
+ shname = std::string(ggml_type_name(quant)) + "_S";
4254
+ } else if (shader_size == 1) {
4255
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->m;
4256
+ shname = std::string(ggml_type_name(quant)) + "_M";
4257
+ } else if (shader_size == 2) {
4258
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->l;
4259
+ shname = std::string(ggml_type_name(quant)) + "_L";
4260
+ } else {
4261
+ GGML_ASSERT(0);
4262
+ }
4263
+ }
4264
+
4265
+ const size_t x_sz = sizeof(float) * x_ne;
4266
+ const size_t y_sz = sizeof(float) * y_ne;
4267
+ const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
4268
+ const size_t d_sz = sizeof(float) * d_ne;
4269
+ float * x = (float *) malloc(x_sz);
4270
+ float * y = (float *) malloc(y_sz);
4271
+ void * qx = malloc(qx_sz);
4272
+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
4273
+ vk_buffer y_buf = ggml_vk_create_buffer_check(ctx, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
4274
+ vk_buffer d_buf = ggml_vk_create_buffer_check(ctx, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
4275
+ float * d = (float *) malloc(d_sz);
4276
+ float * d_chk = (float *) malloc(d_sz);
4277
+
4278
+ for (size_t i = 0; i < x_ne; i++) {
4279
+ x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
4280
+ }
4281
+
4282
+ ggml_vk_quantize_data(x, qx, x_ne, quant);
4283
+
4284
+ for (size_t i = 0; i < y_ne; i++) {
4285
+ // y[i] = rand() / (float)RAND_MAX;
4286
+ y[i] = (i % k == i / k) ? 1.0f : 0.0f;
4287
+ }
4288
+
4289
+ ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
4290
+ if (split_k > 1) {
4291
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
4292
+
4293
+ if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
4294
+ // Resize buffer
4295
+ if (ctx->prealloc_split_k != nullptr) {
4296
+ ggml_vk_destroy_buffer(ctx->prealloc_split_k);
4297
+ }
4298
+ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
4299
+ }
4300
+ }
4301
+
4302
+ ggml_vk_buffer_write(ctx, qx_buf, 0, qx, qx_sz);
4303
+ ggml_vk_buffer_write(ctx, y_buf, 0, y, y_sz);
4304
+
4305
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
4306
+ for (size_t i = 0; i < num_it; i++) {
4307
+ ggml_vk_ctx_begin(ctx, subctx);
4308
+ ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
4309
+ ggml_vk_ctx_end(subctx);
4310
+ }
4311
+
4312
+ auto begin = std::chrono::high_resolution_clock::now();
4313
+
4314
+ ggml_vk_submit(subctx, ctx->fence);
4315
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
4316
+ ctx->device->device.resetFences({ ctx->fence });
4317
+
4318
+ auto end = std::chrono::high_resolution_clock::now();
4319
+
4320
+ double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
4321
+ ggml_vk_buffer_read(ctx, d_buf, 0, d, d_sz);
4322
+
4323
+ ggml_init_params iparams = {
4324
+ /*.mem_size =*/ 1024*1024*1024,
4325
+ /*.mem_buffer =*/ NULL,
4326
+ /*.no_alloc =*/ true,
4327
+ };
4328
+
4329
+ ggml_context * ggml_ctx = ggml_init(iparams);
4330
+
4331
+ ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
4332
+ ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
4333
+ ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
4334
+
4335
+ src0_ggml->data = qx;
4336
+ src1_ggml->data = y;
4337
+ tensor_ggml->data = d_chk;
4338
+
4339
+ ctx->disable = true;
4340
+
4341
+ ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
4342
+ ggml_build_forward_expand(cgraph, tensor_ggml);
4343
+
4344
+ ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
4345
+
4346
+ ctx->disable = false;
4347
+
4348
+ ggml_free(ggml_ctx);
4349
+
4350
+ double avg_err = 0.0;
4351
+ int first_err_n = -1;
4352
+ int first_err_m = -1;
4353
+ int first_err_b = -1;
4354
+
4355
+ for (size_t i = 0; i < m*n*batch; i++) {
4356
+ double err = std::fabs(d[i] - d_chk[i]);
4357
+ avg_err += err;
4358
+
4359
+ if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
4360
+ first_err_b = i / (m * n);
4361
+ first_err_n = (i % (m * n)) / m;
4362
+ first_err_m = (i % (m * n)) % m;
4363
+ }
4364
+ }
4365
+
4366
+ avg_err /= m * n;
4367
+
4368
+ std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl;
4369
+
4370
+ if (avg_err > 0.1 || std::isnan(avg_err)) {
4371
+ std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
4372
+ std::cerr << "Actual result: " << std::endl << std::endl;
4373
+ ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4374
+ std::cerr << std::endl;
4375
+ std::cerr << "Expected result: " << std::endl << std::endl;
4376
+ ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4377
+
4378
+ if (split_k > 1) {
4379
+ float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
4380
+ ggml_vk_buffer_read(ctx, ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
4381
+
4382
+ std::cerr << "d_buf0: " << std::endl << std::endl;
4383
+ ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4384
+
4385
+ std::cerr << "d_buf1: " << std::endl << std::endl;
4386
+ ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4387
+
4388
+ std::cerr << "d_buf2: " << std::endl << std::endl;
4389
+ ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4390
+
4391
+ std::cerr << "d_buf3: " << std::endl << std::endl;
4392
+ ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4393
+
4394
+ free(split_k_buf);
4395
+ }
4396
+ }
4397
+
4398
+ ggml_vk_destroy_buffer(qx_buf);
4399
+ ggml_vk_destroy_buffer(y_buf);
4400
+ ggml_vk_destroy_buffer(d_buf);
4401
+
4402
+ free(x);
4403
+ free(qx);
4404
+ free(y);
4405
+ free(d);
4406
+ free(d_chk);
4407
+ }
3967
4408
  #endif
3968
4409
 
3969
4410
  static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor) {
@@ -3976,29 +4417,19 @@ static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor)
3976
4417
  return extra;
3977
4418
  }
3978
4419
 
3979
- static ggml_tensor * ggml_vk_find_last_use(const ggml_tensor * node, ggml_cgraph * graph) {
3980
- GGML_ASSERT(node != nullptr);
3981
-
3982
- for (int i = graph->n_nodes - 1; i >= 0; i--) {
3983
- for (int j = 0; j < GGML_MAX_SRC; j++) {
3984
- if (graph->nodes[i]->src[j] == node) {
3985
- return graph->nodes[i];
3986
- }
3987
- }
3988
- }
3989
-
3990
- return nullptr;
4420
+ static bool ggml_vk_cpu_assist_op(const ggml_tensor * node) {
4421
+ return node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID;
3991
4422
  }
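ggml_vk_cpu_assist_op replaces the hard-coded GGML_OP_MUL_MAT checks in the gating logic below, so MUL_MAT_ID takes the same offload path. The gate itself stays the pattern repeated throughout these hunks; the helper name in this sketch is hypothetical.

    // Sketch of the gating pattern used in ggml_vk_preallocate_buffers_graph,
    // ggml_vk_build_graph and ggml_vk_compute_forward below.
    static bool vk_should_skip_node_sketch(const ggml_backend_vk_context * ctx, const ggml_tensor * node, bool any_on_device) {
        return ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node));
    }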
3992
4423
 
3993
4424
  static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){
3994
4425
  #ifdef GGML_VULKAN_DEBUG
3995
4426
  std::cerr << "ggml_vk_preallocate_buffers_graph(" << node << ")" << std::endl;
3996
4427
  #endif
3997
- const bool any_on_device = node->backend == GGML_BACKEND_GPU
3998
- || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
3999
- || (node->src[1] != nullptr && (node->src[1]->backend == GGML_BACKEND_GPU));
4428
+ const bool any_on_device = node->backend == GGML_BACKEND_TYPE_GPU
4429
+ || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_TYPE_GPU || node->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
4430
+ || (node->src[1] != nullptr && (node->src[1]->backend == GGML_BACKEND_TYPE_GPU));
4000
4431
 
4001
- if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT)) {
4432
+ if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node))) {
4002
4433
  return;
4003
4434
  }
4004
4435
 
@@ -4029,7 +4460,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4029
4460
  const bool f16_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32;
4030
4461
 
4031
4462
  int split_k;
4032
- if (node->op == GGML_OP_MUL_MAT) {
4463
+ if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
4033
4464
  split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
4034
4465
  } else {
4035
4466
  split_k = 1;
@@ -4038,11 +4469,11 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4038
4469
  const uint32_t y_ne = ne10 * ne11;
4039
4470
  const uint32_t d_ne = ne20 * ne21;
4040
4471
 
4041
- const uint64_t qx_sz = use_src0 ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
4042
- const uint64_t qy_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
4043
- const uint64_t x_sz = use_src0 ? ggml_vk_align_size(sizeof(ggml_fp16_t) * x_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
4044
- const uint64_t y_sz = use_src1 ? ggml_vk_align_size(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
4045
- uint64_t d_sz = ggml_vk_align_size(ggml_type_size(node->type) * d_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne22 * ne23;
4472
+ const uint64_t qx_sz = use_src0 ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
4473
+ const uint64_t qy_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
4474
+ const uint64_t x_sz = use_src0 ? ggml_vk_align_size(sizeof(ggml_fp16_t) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
4475
+ const uint64_t y_sz = use_src1 ? ggml_vk_align_size(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
4476
+ uint64_t d_sz = ggml_vk_align_size(ggml_type_size(node->type) * d_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne22 * ne23;
4046
4477
  const uint64_t split_k_size = split_k > 1 ? d_sz * 4 : 0;
4047
4478
 
4048
4479
  if (extra->buffer_gpu.expired()) {
@@ -4070,6 +4501,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4070
4501
  case GGML_OP_DIAG_MASK_INF:
4071
4502
  case GGML_OP_SOFT_MAX:
4072
4503
  case GGML_OP_ROPE:
4504
+ case GGML_OP_ARGSORT:
4073
4505
  break;
4074
4506
  case GGML_OP_UNARY:
4075
4507
  switch (ggml_get_unary_op(node)) {
@@ -4082,6 +4514,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4082
4514
  }
4083
4515
  break;
4084
4516
  case GGML_OP_MUL_MAT:
4517
+ case GGML_OP_MUL_MAT_ID:
4085
4518
  if (ctx->prealloc_size_qx < qx_sz) {
4086
4519
  ctx->prealloc_size_qx = qx_sz;
4087
4520
  }
@@ -4115,21 +4548,66 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
4115
4548
  #endif
4116
4549
  #if defined(GGML_VULKAN_RUN_TESTS)
4117
4550
  ctx->staging = ggml_vk_create_buffer_check(ctx, 100ul * 1024ul * 1024ul,
4118
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached
4551
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
4119
4552
  vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
4120
4553
  ggml_vk_test_transfer(ctx, 8192 * 1000, false);
4121
4554
  ggml_vk_test_transfer(ctx, 8192 * 1000, true);
4122
4555
 
4123
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_0);
4124
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_1);
4125
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_0);
4126
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_1);
4127
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q8_0);
4128
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q2_K);
4129
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q3_K);
4130
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_K);
4131
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_K);
4132
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q6_K);
4556
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
4557
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
4558
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
4559
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
4560
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
4561
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
4562
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
4563
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
4564
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
4565
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
4566
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
4567
+
4568
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
4569
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
4570
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
4571
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
4572
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
4573
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
4574
+
4575
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
4576
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
4577
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
4578
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
4579
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
4580
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
4581
+
4582
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
4583
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
4584
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
4585
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
4586
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
4587
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
4588
+
4589
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
4590
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
4591
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
4592
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
4593
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
4594
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
4595
+
4596
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
4597
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
4598
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
4599
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
4600
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
4601
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
4602
+
4603
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
4604
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
4605
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
4606
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
4607
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
4608
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
4609
+
4610
+ std::cerr << std::endl;
4133
4611
 
4134
4612
  const std::vector<size_t> vals {
4135
4613
  8, 8, 8,
@@ -4215,11 +4693,11 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
4215
4693
  }
4216
4694
 
4217
4695
  static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, bool last_node){
4218
- const bool any_on_device = node->backend == GGML_BACKEND_GPU
4219
- || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
4220
- || (node->src[1] != nullptr && node->src[1]->backend == GGML_BACKEND_GPU);
4696
+ const bool any_on_device = node->backend == GGML_BACKEND_TYPE_GPU
4697
+ || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_TYPE_GPU || node->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
4698
+ || (node->src[1] != nullptr && node->src[1]->backend == GGML_BACKEND_TYPE_GPU);
4221
4699
 
4222
- if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT) || (node->op == GGML_OP_MUL_MAT && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
4700
+ if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node)) || (ggml_vk_cpu_assist_op(node) && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
4223
4701
  return;
4224
4702
  }
4225
4703
 
@@ -4231,6 +4709,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4231
4709
 
4232
4710
  const ggml_tensor * src0 = node->src[0];
4233
4711
  const ggml_tensor * src1 = node->src[1];
4712
+ const ggml_tensor * src2 = node->src[2];
4234
4713
 
4235
4714
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
4236
4715
 
@@ -4265,7 +4744,9 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4265
4744
  case GGML_OP_SOFT_MAX:
4266
4745
  case GGML_OP_ROPE:
4267
4746
  case GGML_OP_MUL_MAT:
4747
+ case GGML_OP_MUL_MAT_ID:
4268
4748
  case GGML_OP_NONE:
4749
+ case GGML_OP_ARGSORT:
4269
4750
  break;
4270
4751
  default:
4271
4752
  if (any_on_device) {
@@ -4276,7 +4757,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4276
4757
  }
4277
4758
 
4278
4759
  if (ctx->compute_ctx == nullptr) {
4279
- ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
4760
+ ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
4280
4761
  ggml_vk_ctx_begin(ctx, ctx->compute_ctx);
4281
4762
  }
4282
4763
 
@@ -4347,16 +4828,25 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4347
4828
 
4348
4829
  break;
4349
4830
  case GGML_OP_SOFT_MAX:
4350
- ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, node);
4831
+ ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, src2, node);
4351
4832
 
4352
4833
  break;
4353
4834
  case GGML_OP_ROPE:
4354
4835
  ggml_vk_rope(ctx, ctx->compute_ctx, src0, src1, node);
4355
4836
 
4837
+ break;
4838
+ case GGML_OP_ARGSORT:
4839
+ ggml_vk_argsort(ctx, ctx->compute_ctx, src0, node);
4356
4840
  break;
4357
4841
  case GGML_OP_MUL_MAT:
4358
4842
  ggml_vk_mul_mat(ctx, ctx->compute_ctx, src0, src1, node);
4359
4843
 
4844
+ break;
4845
+ case GGML_OP_MUL_MAT_ID:
4846
+ //ggml_vk_mul_mat_id(ctx, ctx->compute_ctx, src0, src1, node);
4847
+ std::cerr << "ggml_vulkan: GGML_OP_MUL_MAT_ID not implemented yet." << std::endl;
4848
+ GGML_ASSERT(false);
4849
+
4360
4850
  break;
4361
4851
  default:
4362
4852
  return;
@@ -4371,7 +4861,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4371
4861
  last_node = true;
4372
4862
  #endif
4373
4863
 
4374
- if (node->backend == GGML_BACKEND_CPU || last_node) {
4864
+ if (node->backend == GGML_BACKEND_TYPE_CPU || last_node) {
4375
4865
  ggml_vk_ctx_end(ctx->compute_ctx);
4376
4866
  ctx->compute_ctx->exit_tensor = node;
4377
4867
  ctx->compute_ctx = nullptr;
@@ -4379,11 +4869,11 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4379
4869
  }
4380
4870
 
4381
4871
  static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor){
4382
- const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
4383
- || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
4384
- || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
4872
+ const bool any_on_device = tensor->backend == GGML_BACKEND_TYPE_GPU
4873
+ || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
4874
+ || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU);
4385
4875
 
4386
- if (ctx->disable || (!any_on_device && tensor->op != GGML_OP_MUL_MAT)) {
4876
+ if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(tensor))) {
4387
4877
  return false;
4388
4878
  }
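
Editor's note: the gates in this hunk (and in ggml_vk_build_graph above) now call a new ggml_vk_cpu_assist_op(tensor) helper instead of hard-coding tensor->op != GGML_OP_MUL_MAT. The helper's body is not part of this diff; the following is only a plausible sketch of what such a predicate would look like, assuming it simply whitelists the matrix-multiply ops that the Vulkan backend can compute on behalf of a CPU-resident graph (the _sketch suffix marks it as a reconstruction, not the real function).

#include "ggml.h"

// Plausible reconstruction (assumption): only the matrix-multiply ops qualify for
// GPU assistance of an otherwise CPU-resident graph.
static bool ggml_vk_cpu_assist_op_sketch(const ggml_tensor * node) {
    return node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID;
}
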
4389
4879
 
@@ -4409,6 +4899,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
4409
4899
  case GGML_OP_PERMUTE:
4410
4900
  case GGML_OP_TRANSPOSE:
4411
4901
  case GGML_OP_NONE:
4902
+ case GGML_OP_ARGSORT:
4412
4903
  extra = (ggml_tensor_extra_gpu *) tensor->extra;
4413
4904
 
4414
4905
  break;
@@ -4424,6 +4915,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
4424
4915
  }
4425
4916
  break;
4426
4917
  case GGML_OP_MUL_MAT:
4918
+ case GGML_OP_MUL_MAT_ID:
4427
4919
  if (!any_on_device && !ggml_vk_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
4428
4920
  return false;
4429
4921
  }
@@ -4442,7 +4934,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
4442
4934
  if (params->ith != 0) {
4443
4935
  return true;
4444
4936
  }
4445
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
4937
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
4446
4938
  return true;
4447
4939
  }
4448
4940
 
@@ -4469,8 +4961,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
4469
4961
  }
4470
4962
 
4471
4963
  if (tensor == subctx.exit_tensor) {
4472
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
4473
- ctx->device.lock()->device.resetFences({ ctx->fence });
4964
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
4965
+ ctx->device->device.resetFences({ ctx->fence });
4474
4966
 
4475
4967
  // Do staging buffer copies
4476
4968
  for (auto& cpy : subctx.out_memcpys) {
@@ -4498,20 +4990,25 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
4498
4990
  }
4499
4991
  ctx->gc.temp_buffers.clear();
4500
4992
 
4501
- for (auto * pipeline : ctx->gc.pipelines) {
4502
- ggml_pipeline_cleanup(*pipeline);
4993
+ for (auto& pipeline : ctx->device->pipelines) {
4994
+ if (pipeline.expired()) {
4995
+ continue;
4996
+ }
4997
+
4998
+ vk_pipeline pl = pipeline.lock();
4999
+ ggml_pipeline_cleanup(pl);
4503
5000
  }
4504
5001
 
4505
- ggml_vk_queue_cleanup(ctx, ctx->device.lock()->compute_queue);
4506
- ggml_vk_queue_cleanup(ctx, ctx->device.lock()->transfer_queue);
5002
+ ggml_vk_queue_cleanup(ctx, ctx->device->compute_queue);
5003
+ ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
4507
5004
 
4508
5005
  for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
4509
- ctx->device.lock()->device.destroySemaphore({ ctx->gc.semaphores[i].s });
5006
+ ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
4510
5007
  }
4511
5008
  ctx->gc.semaphores.clear();
4512
5009
 
4513
5010
  for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) {
4514
- ctx->device.lock()->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
5011
+ ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
4515
5012
  }
4516
5013
  ctx->gc.tl_semaphores.clear();
4517
5014
  ctx->semaphore_idx = 0;
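
Editor's note: cleanup now iterates ctx->device->pipelines, a vector of weak references owned by the device, and has to skip entries whose pipelines were already destroyed. If the registry is expected to grow over many graph evaluations, the expired entries could also be pruned; a standard-library-only sketch of that idiom follows (pipeline_stub is a stand-in type, not the real vk_pipeline_struct).

#include <algorithm>
#include <memory>
#include <vector>

struct pipeline_stub { /* shader module, layouts, descriptor pools, ... */ };

// Optional housekeeping: drop registry entries whose pipelines were already destroyed,
// so the weak_ptr vector does not keep growing across graph runs.
static void prune_expired(std::vector<std::weak_ptr<pipeline_stub>> & pipelines) {
    pipelines.erase(
        std::remove_if(pipelines.begin(), pipelines.end(),
                       [](const std::weak_ptr<pipeline_stub> & p) { return p.expired(); }),
        pipelines.end());
}
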
@@ -4519,7 +5016,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
4519
5016
  ctx->event_idx = 0;
4520
5017
 
4521
5018
  for (auto& event : ctx->gc.events) {
4522
- ctx->device.lock()->device.resetEvent(event);
5019
+ ctx->device->device.resetEvent(event);
4523
5020
  }
4524
5021
 
4525
5022
  ctx->staging_offset = 0;
@@ -4556,21 +5053,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
4556
5053
  ctx->staging_size = 0;
4557
5054
 
4558
5055
  for (auto& event : ctx->gc.events) {
4559
- ctx->device.lock()->device.destroyEvent(event);
5056
+ ctx->device->device.destroyEvent(event);
4560
5057
  }
4561
5058
  ctx->gc.events.clear();
4562
5059
 
4563
- for (auto* pipeline : ctx->gc.pipelines) {
4564
- ggml_vk_destroy_pipeline(ctx, pipeline);
4565
- }
4566
- ctx->gc.pipelines.clear();
4567
-
4568
- ctx->device.lock()->device.destroyFence(ctx->fence);
4569
-
4570
- ctx->device.lock()->device.destroyCommandPool(ctx->device.lock()->compute_queue.pool);
4571
- if (!ctx->device.lock()->single_queue) {
4572
- ctx->device.lock()->device.destroyCommandPool(ctx->device.lock()->transfer_queue.pool);
4573
- }
5060
+ ctx->device->device.destroyFence(ctx->fence);
4574
5061
  }
4575
5062
 
4576
5063
  GGML_CALL static int ggml_vk_get_device_count() {
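
Editor's note: in this hunk the per-call ctx->device.lock() chains disappear and the command pools are no longer destroyed in ggml_vk_cleanup; the backend context now co-owns the device, pool and pipeline teardown runs in the device's destructor, and ggml_backend_vk_free just calls ctx->device.reset(). A minimal ownership sketch using only standard-library types (the struct names are stand-ins, not the real declarations):

#include <memory>

struct vk_device_stub { /* queues, pipelines, command pools, ... */ };

// 0.12.7 layout: the context referenced the device weakly, so every access went
// through ctx->device.lock() and had to cope with an already-destroyed device.
struct ctx_old {
    std::weak_ptr<vk_device_stub> device;
};

// 0.14.0 layout: the context co-owns the device, so ctx->device-> is valid for the
// context's lifetime and teardown happens when the last shared_ptr is released.
struct ctx_new {
    std::shared_ptr<vk_device_stub> device;
};
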
@@ -4745,7 +5232,7 @@ GGML_CALL static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t b
4745
5232
  extra->offset = (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
4746
5233
  }
4747
5234
 
4748
- tensor->backend = GGML_BACKEND_GPU;
5235
+ tensor->backend = GGML_BACKEND_TYPE_GPU;
4749
5236
  tensor->extra = extra;
4750
5237
  }
4751
5238
 
@@ -4753,7 +5240,7 @@ GGML_CALL static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t bu
4753
5240
  #ifdef GGML_VULKAN_DEBUG
4754
5241
  std::cerr << "ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")" << std::endl;
4755
5242
  #endif
4756
- GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
5243
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
4757
5244
 
4758
5245
  ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
4759
5246
 
@@ -4768,7 +5255,7 @@ GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t bu
4768
5255
  #ifdef GGML_VULKAN_DEBUG
4769
5256
  std::cerr << "ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")" << std::endl;
4770
5257
  #endif
4771
- GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
5258
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
4772
5259
 
4773
5260
  ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
4774
5261
 
@@ -4781,7 +5268,6 @@ GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t bu
4781
5268
 
4782
5269
  GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
4783
5270
  if (ggml_backend_buffer_is_vk(src->buffer)) {
4784
- ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
4785
5271
  ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
4786
5272
  ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
4787
5273
 
@@ -4793,6 +5279,8 @@ GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t bu
4793
5279
  return true;
4794
5280
  }
4795
5281
  return false;
5282
+
5283
+ UNUSED(buffer);
4796
5284
  }
4797
5285
 
4798
5286
  GGML_CALL static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
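
Editor's note: with the unused local removed, the buffer parameter is no longer referenced, so the hunk adds UNUSED(buffer) after the final return. ggml's UNUSED macro is essentially a cast-to-void; the statement is unreachable, but it still counts as a syntactic use of the parameter, which is what silences -Wunused-parameter. A self-contained illustration (macro and function names are made up):

#define UNUSED_SKETCH(x) (void)(x)

static bool cpy_tensor_stub(int buffer /* kept only for the interface */) {
    return false;

    UNUSED_SKETCH(buffer);   // never executed, but marks the parameter as used
}
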
@@ -4839,12 +5327,12 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(
4839
5327
 
4840
5328
  GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
4841
5329
  ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
4842
- return ctx->ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
5330
+ return ctx->ctx->device->properties.limits.minStorageBufferOffsetAlignment;
4843
5331
  }
4844
5332
 
4845
5333
  GGML_CALL static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
4846
5334
  ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
4847
- return ctx->ctx->device.lock()->max_memory_allocation_size;
5335
+ return ctx->ctx->device->max_memory_allocation_size;
4848
5336
  }
4849
5337
 
4850
5338
  GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
@@ -4930,7 +5418,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_bu
4930
5418
  }
4931
5419
 
4932
5420
  GGML_CALL static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
4933
- return vk_instance.contexts[0].device.lock()->properties.limits.minMemoryMapAlignment;
5421
+ return vk_instance.contexts[0].device->properties.limits.minMemoryMapAlignment;
4934
5422
 
4935
5423
  UNUSED(buft);
4936
5424
  }
@@ -4975,8 +5463,7 @@ GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend) {
4975
5463
 
4976
5464
  ggml_vk_cleanup(ctx);
4977
5465
 
4978
- // Release device
4979
- vk_instance.devices[ctx->idx].reset();
5466
+ ctx->device.reset();
4980
5467
  ctx->initialized = false;
4981
5468
 
4982
5469
  vk_instance.initialized[idx] = false;
@@ -4999,13 +5486,13 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g
4999
5486
  #endif
5000
5487
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
5001
5488
  GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_buffer_type(ctx->idx) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
5002
- GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
5489
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
5003
5490
 
5004
5491
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
5005
5492
 
5006
5493
  if (ctx->transfer_ctx == nullptr) {
5007
5494
  // Initialize new transfer context
5008
- ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
5495
+ ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
5009
5496
  ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
5010
5497
  }
5011
5498
 
@@ -5020,13 +5507,13 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c
5020
5507
  #endif
5021
5508
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
5022
5509
  GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_buffer_type(ctx->idx) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
5023
- GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
5510
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
5024
5511
 
5025
5512
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
5026
5513
 
5027
5514
  if (ctx->transfer_ctx == nullptr) {
5028
5515
  // Initialize new transfer context
5029
- ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
5516
+ ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
5030
5517
  ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
5031
5518
  }
5032
5519
 
@@ -5046,7 +5533,7 @@ GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, c
5046
5533
 
5047
5534
  if (ctx->transfer_ctx == nullptr) {
5048
5535
  // Initialize new transfer context
5049
- ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
5536
+ ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
5050
5537
  ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
5051
5538
  }
5052
5539
 
@@ -5076,8 +5563,8 @@ GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
5076
5563
  }
5077
5564
 
5078
5565
  ggml_vk_submit(ctx->transfer_ctx, ctx->fence);
5079
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
5080
- ctx->device.lock()->device.resetFences({ ctx->fence });
5566
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
5567
+ ctx->device->device.resetFences({ ctx->fence });
5081
5568
 
5082
5569
  for (auto& cpy : ctx->transfer_ctx->out_memcpys) {
5083
5570
  memcpy(cpy.dst, cpy.src, cpy.n);
@@ -5086,7 +5573,7 @@ GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
5086
5573
  ctx->transfer_ctx = nullptr;
5087
5574
  }
5088
5575
 
5089
- GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
5576
+ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
5090
5577
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
5091
5578
 
5092
5579
  for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -5097,7 +5584,7 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml
5097
5584
  int last_node = cgraph->n_nodes - 1;
5098
5585
 
5099
5586
  // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
5100
- while (last_node > 0 && cgraph->nodes[last_node]->backend != GGML_BACKEND_GPU) {
5587
+ while (last_node > 0 && cgraph->nodes[last_node]->backend != GGML_BACKEND_TYPE_GPU) {
5101
5588
  last_node -= 1;
5102
5589
  }
5103
5590
 
@@ -5106,7 +5593,7 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml
5106
5593
  }
5107
5594
 
5108
5595
  ggml_compute_params params = {};
5109
- params.type = GGML_TASK_COMPUTE;
5596
+ params.type = GGML_TASK_TYPE_COMPUTE;
5110
5597
  params.ith = 0;
5111
5598
  for (int i = 0; i < cgraph->n_nodes; i++) {
5112
5599
  ggml_tensor * node = cgraph->nodes[i];
@@ -5129,7 +5616,7 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml
5129
5616
 
5130
5617
  ggml_vk_graph_cleanup(ctx);
5131
5618
 
5132
- return true;
5619
+ return GGML_STATUS_SUCCESS;
5133
5620
 
5134
5621
  UNUSED(backend);
5135
5622
  }
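
Editor's note: graph_compute now returns ggml_status instead of bool, in line with the backend interface of this ggml revision. A hedged sketch of a caller under that assumption (the wrapper name is made up; ggml_backend_graph_compute and GGML_STATUS_SUCCESS are the library's own identifiers):

#include "ggml-backend.h"

// Success is no longer a bare bool, so the status is compared explicitly.
static bool run_graph(ggml_backend_t backend, ggml_cgraph * graph) {
    return ggml_backend_graph_compute(backend, graph) == GGML_STATUS_SUCCESS;
}
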
@@ -5147,6 +5634,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
5147
5634
  }
5148
5635
  break;
5149
5636
  case GGML_OP_MUL_MAT:
5637
+ case GGML_OP_MUL_MAT_ID:
5150
5638
  {
5151
5639
  struct ggml_tensor * a;
5152
5640
  struct ggml_tensor * b;
@@ -5220,6 +5708,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
5220
5708
  case GGML_OP_CONT:
5221
5709
  case GGML_OP_DIAG_MASK_INF:
5222
5710
  case GGML_OP_SOFT_MAX:
5711
+ case GGML_OP_ARGSORT:
5223
5712
  return true;
5224
5713
  default:
5225
5714
  return false;
@@ -5244,6 +5733,11 @@ static ggml_backend_i ggml_backend_vk_interface = {
5244
5733
  /* .supports_op = */ ggml_backend_vk_supports_op,
5245
5734
  };
5246
5735
 
5736
+ static ggml_guid_t ggml_backend_vk_guid() {
5737
+ static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
5738
+ return &guid;
5739
+ }
5740
+
5247
5741
  GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t idx) {
5248
5742
  if (vk_instance.initialized[idx]) {
5249
5743
  return vk_instance.backends[idx];
@@ -5262,6 +5756,7 @@ GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t idx) {
5262
5756
  vk_instance.initialized[idx] = true;
5263
5757
 
5264
5758
  ggml_backend_t vk_backend = new ggml_backend {
5759
+ /* .guid = */ ggml_backend_vk_guid(),
5265
5760
  /* .interface = */ ggml_backend_vk_interface,
5266
5761
  /* .context = */ &vk_instance.contexts[ctx->idx],
5267
5762
  };
@@ -5272,7 +5767,7 @@ GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t idx) {
5272
5767
  }
5273
5768
 
5274
5769
  GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend) {
5275
- return backend && backend->iface.get_name == ggml_backend_vk_name;
5770
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
5276
5771
  }
5277
5772
 
5278
5773
  GGML_CALL int ggml_backend_vk_get_device_count() {
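
Editor's note: backend identification switches from comparing the get_name function pointer to a 16-byte GUID checked with ggml_guid_matches. A standard-library-only sketch of what that amounts to (the type aliases and function names here are stand-ins, not ggml's declarations; the byte values are the ones added in the hunk above):

#include <cstdint>
#include <cstring>

using guid_bytes = std::uint8_t[16];

// One function-local static GUID per backend: repeated calls return the same bytes,
// and two different backends cannot accidentally share an identifier.
static const guid_bytes & vk_guid_sketch() {
    static const guid_bytes guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02,
                                     0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
    return guid;
}

// ggml_guid_matches is, in effect, a 16-byte comparison of this kind.
static bool guid_matches_sketch(const guid_bytes & a, const guid_bytes & b) {
    return std::memcmp(a, b, sizeof(guid_bytes)) == 0;
}
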
@@ -5410,13 +5905,14 @@ static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * d
5410
5905
  static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tensor * tensor, const char * name) {
5411
5906
  void * tensor_data = tensor->data;
5412
5907
 
5413
- if (tensor->backend == GGML_BACKEND_GPU) {
5908
+ if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
5414
5909
  const size_t tensor_size = ggml_nbytes(tensor);
5415
5910
  tensor_data = malloc(tensor_size);
5416
5911
 
5417
5912
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
5418
5913
 
5419
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, extra->offset, tensor_data, tensor_size);
5914
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
5915
+ ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset, tensor_data, tensor_size);
5420
5916
  }
5421
5917
 
5422
5918
  std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
@@ -5436,14 +5932,14 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso
5436
5932
  std::vector<const ggml_tensor *> done;
5437
5933
  ggml_vk_print_graph_origin(tensor, done);
5438
5934
 
5439
- if (tensor->backend == GGML_BACKEND_GPU) {
5935
+ if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
5440
5936
  free(tensor_data);
5441
5937
  }
5442
5938
  }
5443
5939
 
5444
5940
  static void ggml_vk_check_tensor(const std::string& name, const ggml_tensor * tensor) {
5445
5941
  return;
5446
- GGML_ASSERT(tensor->backend == GGML_BACKEND_CPU);
5942
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_CPU);
5447
5943
  if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) {
5448
5944
  return;
5449
5945
  }
@@ -5481,7 +5977,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5481
5977
  if (params->ith != 0) {
5482
5978
  return;
5483
5979
  }
5484
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
5980
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
5485
5981
  return;
5486
5982
  }
5487
5983
 
@@ -5492,6 +5988,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5492
5988
 
5493
5989
  ggml_tensor * src0 = tensor->src[0];
5494
5990
  ggml_tensor * src1 = tensor->src[1];
5991
+ ggml_tensor * src2 = tensor->src[2];
5495
5992
 
5496
5993
  struct ggml_init_params iparams = {
5497
5994
  /*.mem_size =*/ 1024*1024*1024,
@@ -5503,13 +6000,16 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5503
6000
 
5504
6001
  struct ggml_tensor * src0_clone = nullptr;
5505
6002
  struct ggml_tensor * src1_clone = nullptr;
6003
+ struct ggml_tensor * src2_clone = nullptr;
5506
6004
  struct ggml_tensor * tensor_clone = nullptr;
5507
6005
 
5508
6006
  size_t src0_size;
5509
6007
  size_t src1_size;
6008
+ size_t src2_size;
5510
6009
 
5511
6010
  void * src0_buffer;
5512
6011
  void * src1_buffer;
6012
+ void * src2_buffer;
5513
6013
 
5514
6014
  if (src0 != nullptr) {
5515
6015
  src0_clone = ggml_dup_tensor(ggml_ctx, src0);
@@ -5518,17 +6018,18 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5518
6018
 
5519
6019
  src0_buffer = malloc(src0_size);
5520
6020
  src0_clone->data = src0_buffer;
5521
- if (src0->backend == GGML_BACKEND_CPU) {
6021
+ if (src0->backend == GGML_BACKEND_TYPE_CPU) {
5522
6022
  memcpy(src0_clone->data, src0->data, src0_size);
5523
6023
  memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
5524
- } else if (src0->backend == GGML_BACKEND_GPU) {
6024
+ } else if (src0->backend == GGML_BACKEND_TYPE_GPU) {
5525
6025
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src0->extra;
6026
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
5526
6027
  uint64_t offset = extra->offset;
5527
6028
  if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
5528
6029
  for (int i3 = 0; i3 < src0->ne[3]; i3++) {
5529
6030
  for (int i2 = 0; i2 < src0->ne[2]; i2++) {
5530
6031
  const int idx = i3*src0->ne[2] + i2;
5531
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
6032
+ ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
5532
6033
  }
5533
6034
  }
5534
6035
 
@@ -5538,10 +6039,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5538
6039
  src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
5539
6040
  }
5540
6041
  } else {
5541
- if (offset + src0_size >= extra->buffer_gpu->size) {
5542
- src0_size = extra->buffer_gpu->size - offset;
6042
+ if (offset + src0_size >= buffer_gpu->size) {
6043
+ src0_size = buffer_gpu->size - offset;
5543
6044
  }
5544
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, offset, src0_clone->data, src0_size);
6045
+ ggml_vk_buffer_read(ctx, buffer_gpu, offset, src0_clone->data, src0_size);
5545
6046
  memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
5546
6047
  }
5547
6048
  } else {
@@ -5561,17 +6062,18 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5561
6062
 
5562
6063
  src1_buffer = malloc(src1_size);
5563
6064
  src1_clone->data = src1_buffer;
5564
- if (src1->backend == GGML_BACKEND_CPU) {
6065
+ if (src1->backend == GGML_BACKEND_TYPE_CPU) {
5565
6066
  memcpy(src1_clone->data, src1->data, src1_size);
5566
6067
  memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
5567
- } else if (src1->backend == GGML_BACKEND_GPU) {
6068
+ } else if (src1->backend == GGML_BACKEND_TYPE_GPU) {
5568
6069
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src1->extra;
6070
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
5569
6071
  uint64_t offset = extra->offset;
5570
6072
  if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
5571
6073
  for (int i3 = 0; i3 < src1->ne[3]; i3++) {
5572
6074
  for (int i2 = 0; i2 < src1->ne[2]; i2++) {
5573
6075
  const int idx = i3*src1->ne[2] + i2;
5574
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
6076
+ ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
5575
6077
  }
5576
6078
  }
5577
6079
 
@@ -5581,10 +6083,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5581
6083
  src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
5582
6084
  }
5583
6085
  } else {
5584
- if (offset + src1_size >= extra->buffer_gpu->size) {
5585
- src1_size = extra->buffer_gpu->size - offset;
6086
+ if (offset + src1_size >= buffer_gpu->size) {
6087
+ src1_size = buffer_gpu->size - offset;
5586
6088
  }
5587
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, offset, src1_clone->data, src1_size);
6089
+ ggml_vk_buffer_read(ctx, buffer_gpu, offset, src1_clone->data, src1_size);
5588
6090
  memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
5589
6091
  }
5590
6092
  } else {
@@ -5613,6 +6115,66 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5613
6115
 
5614
6116
  ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src1", src1_clone);
5615
6117
  }
6118
+ if (src2 != nullptr) {
6119
+ src2_clone = ggml_dup_tensor(ggml_ctx, src2);
6120
+
6121
+ src2_size = ggml_nbytes(src2);
6122
+
6123
+ src2_buffer = malloc(src2_size);
6124
+ src2_clone->data = src2_buffer;
6125
+ if (src2->backend == GGML_BACKEND_TYPE_CPU) {
6126
+ memcpy(src2_clone->data, src2->data, src2_size);
6127
+ memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
6128
+ } else if (src2->backend == GGML_BACKEND_TYPE_GPU) {
6129
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src2->extra;
6130
+ vk_buffer buf = extra->buffer_gpu.lock();
6131
+ uint64_t offset = extra->offset;
6132
+ if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
6133
+ for (int i3 = 0; i3 < src2->ne[3]; i3++) {
6134
+ for (int i2 = 0; i2 < src2->ne[2]; i2++) {
6135
+ const int idx = i3*src2->ne[2] + i2;
6136
+ ggml_vk_buffer_read(ctx, buf, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
6137
+ }
6138
+ }
6139
+
6140
+ src2_clone->nb[0] = src2->nb[0];
6141
+ src2_clone->nb[1] = src2->nb[1];
6142
+ for (int i = 2; i < GGML_MAX_DIMS; i++) {
6143
+ src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
6144
+ }
6145
+ } else {
6146
+ if (offset + src2_size >= buf->size) {
6147
+ src2_size = buf->size - offset;
6148
+ }
6149
+ ggml_vk_buffer_read(ctx, buf, offset, src2_clone->data, src2_size);
6150
+ memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
6151
+ }
6152
+ } else {
6153
+ GGML_ASSERT(false);
6154
+ }
6155
+
6156
+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
6157
+ ggml_vk_print_tensor(ctx, src2, "src2");
6158
+ std::cerr << "TENSOR CHECK: " << ggml_op_name(src2_clone->op) << " (check " << check_counter << ")" << std::endl;
6159
+ std::cerr << "src2_clone=" << tensor << " src2_clone->backend: " << src2_clone->backend << " src2_clone->type: " << ggml_type_name(src2_clone->type) << " ne0=" << src2_clone->ne[0] << " nb0=" << src2_clone->nb[0] << " ne1=" << src2_clone->ne[1] << " nb1=" << src2_clone->nb[1] << " ne2=" << src2_clone->ne[2] << " nb2=" << src2_clone->nb[2] << " ne3=" << src2_clone->ne[3] << " nb3=" << src2_clone->nb[3] << std::endl;
6160
+ if (src2->src[0] != nullptr) {
6161
+ std::cerr << "src2->src[0]=" << src2->src[0] << " op=" << ggml_op_name(src2->src[0]->op) << " type=" << ggml_type_name(src2->src[0]->type) << " backend=" << src2->src[0]->backend << " ne0=" << src2->src[0]->ne[0] << " nb0=" << src2->src[0]->nb[0] << " ne1=" << src2->src[0]->ne[1] << " nb1=" << src2->src[0]->nb[1] << " ne2=" << src2->src[0]->ne[2] << " nb2=" << src2->src[0]->nb[2] << " ne3=" << src2->src[0]->ne[3] << " nb3=" << src2->src[0]->nb[3] << std::endl;
6162
+ }
6163
+ if (src2->src[1] != nullptr) {
6164
+ std::cerr << "src2->src[1]=" << src2->src[1] << " op=" << ggml_op_name(src2->src[1]->op) << " type=" << ggml_type_name(src2->src[1]->type) << " backend=" << src2->src[1]->backend << " ne0=" << src2->src[1]->ne[0] << " nb0=" << src2->src[1]->nb[0] << " ne1=" << src2->src[1]->ne[1] << " nb1=" << src2->src[1]->nb[1] << " ne2=" << src2->src[1]->ne[2] << " nb2=" << src2->src[1]->nb[2] << " ne3=" << src2->src[1]->ne[3] << " nb3=" << src2->src[1]->nb[3] << std::endl;
6165
+ }
6166
+ std::cerr << std::endl << "Result:" << std::endl;
6167
+ ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 0, 0);
6168
+ std::cerr << std::endl;
6169
+ std::cerr << std::endl << "Result:" << std::endl;
6170
+ ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 1, 0);
6171
+ std::cerr << std::endl;
6172
+ std::vector<const ggml_tensor *> done;
6173
+ ggml_vk_print_graph_origin(src2_clone, done);
6174
+ }
6175
+
6176
+ ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src2", src2_clone);
6177
+ }
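
Editor's note: the new src2 block mirrors the src0/src1 clones: when the tensor is only contiguous in its first two dimensions it is copied plane by plane, after which packed strides are rebuilt for the higher dimensions. The stride recurrence it applies is the usual row-major packing; a minimal standalone sketch (plain arrays, no ggml types):

#include <cstddef>
#include <cstdint>

// Rebuild packed byte strides nb[] above dimension 1, as the clone code does after a
// plane-by-plane copy: nb[0] and nb[1] are kept from the source, then
// nb[i] = nb[i-1] * ne[i-1] for the remaining dimensions.
static void pack_strides(const std::int64_t * ne, std::size_t * nb, int n_dims) {
    for (int i = 2; i < n_dims; i++) {
        nb[i] = nb[i - 1] * static_cast<std::size_t>(ne[i - 1]);
    }
}
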
5616
6178
 
5617
6179
  if (tensor->op == GGML_OP_MUL_MAT) {
5618
6180
  tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
@@ -5632,7 +6194,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5632
6194
  tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
5633
6195
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
5634
6196
  if (src1 != nullptr) {
5635
- tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, *(float *)tensor->op_params);
6197
+ tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
5636
6198
  } else {
5637
6199
  tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
5638
6200
  }
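
Editor's note: the reference check now passes the extra src2 (positions) tensor to ggml_soft_max_ext and reads two floats from op_params (scale and max_bias) instead of one. A small sketch of pulling those parameters out of the op_params blob, assuming the two-leading-floats layout the call above relies on (the helper name is made up):

#include <cstring>

// Equivalent to the ((float *)tensor->op_params)[0] / [1] reads above, but copied out
// with memcpy to avoid the aliasing cast.
static void read_soft_max_params(const void * op_params, float & scale, float & max_bias) {
    std::memcpy(&scale,    op_params, sizeof(float));
    std::memcpy(&max_bias, static_cast<const char *>(op_params) + sizeof(float), sizeof(float));
}
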
@@ -5715,6 +6277,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5715
6277
  if (src1 != nullptr) {
5716
6278
  free(src1_buffer);
5717
6279
  }
6280
+ if (src2 != nullptr) {
6281
+ free(src2_buffer);
6282
+ }
5718
6283
 
5719
6284
  ggml_free(ggml_ctx);
5720
6285
  }
@@ -5723,7 +6288,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_
5723
6288
  if (params->ith != 0) {
5724
6289
  return;
5725
6290
  }
5726
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
6291
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
5727
6292
  return;
5728
6293
  }
5729
6294
  if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
@@ -5735,17 +6300,18 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_
5735
6300
 
5736
6301
  void * tensor_data = tensor->data;
5737
6302
 
5738
- if (tensor->backend == GGML_BACKEND_GPU) {
6303
+ if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
5739
6304
  size_t tensor_size = ggml_nbytes(tensor);
5740
6305
  tensor_data = malloc(tensor_size);
5741
6306
 
5742
6307
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
5743
6308
 
5744
- if (extra->offset + tensor_size >= extra->buffer_gpu->size) {
5745
- tensor_size = extra->buffer_gpu->size - (extra->offset);
6309
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
6310
+ if (extra->offset + tensor_size >= buffer_gpu->size) {
6311
+ tensor_size = buffer_gpu->size - (extra->offset);
5746
6312
  }
5747
6313
 
5748
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, extra->offset, tensor_data, tensor_size);
6314
+ ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset, tensor_data, tensor_size);
5749
6315
  }
5750
6316
 
5751
6317
  float first_error_result = -1.0f;
@@ -5868,7 +6434,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_
5868
6434
  comp_result = nullptr;
5869
6435
  comp_size = 0;
5870
6436
 
5871
- if (tensor->backend == GGML_BACKEND_GPU) {
6437
+ if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
5872
6438
  free(tensor_data);
5873
6439
  }
5874
6440
  }