llama_cpp 0.12.7 → 0.14.0

This diff reflects the changes between these publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
@@ -69,6 +69,33 @@ struct vk_queue {
  vk::PipelineStageFlags stage_flags;
  };

+ struct vk_pipeline_struct {
+ std::string name;
+ vk::ShaderModule shader_module;
+ vk::DescriptorSetLayout dsl;
+ std::vector<vk::DescriptorPool> descriptor_pools;
+ std::vector<vk::DescriptorSet> descriptor_sets;
+ uint32_t descriptor_set_idx;
+ vk::PipelineLayout layout;
+ vk::Pipeline pipeline;
+ uint32_t push_constant_size;
+ uint32_t parameter_count;
+ std::array<uint32_t, 3> wg_denoms;
+ uint32_t align;
+ };
+
+ typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
+ typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
+
+ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
+
+ struct vk_matmul_pipeline_struct {
+ vk_pipeline l, m, s;
+ vk_pipeline a_l, a_m, a_s;
+ };
+
+ typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
+
  struct vk_device {
  vk::PhysicalDevice physical_device;
  vk::PhysicalDeviceProperties properties;
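
Editor's note: this hunk turns pipelines into reference-counted objects. vk_pipeline is now a std::shared_ptr to the renamed vk_pipeline_struct, vk_pipeline_ref is the matching std::weak_ptr, and vk_matmul_pipeline groups the large/medium/small variants (l, m, s) together with their aligned counterparts (a_l, a_m, a_s). A minimal standalone sketch of the ownership idiom, using stand-in names rather than the real Vulkan types:

    // Sketch of the shared_ptr/weak_ptr pattern introduced above.
    // "registry" plays the role of vk_device::pipelines; the struct is a stand-in.
    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    struct pipeline_struct { std::string name; };
    typedef std::shared_ptr<pipeline_struct> pipeline;      // owning handle (like vk_pipeline)
    typedef std::weak_ptr<pipeline_struct>   pipeline_ref;  // non-owning entry (like vk_pipeline_ref)

    int main() {
        std::vector<pipeline_ref> registry;                  // device-side bookkeeping
        pipeline p = std::make_shared<pipeline_struct>();    // created once, shared by all users
        p->name = "matmul_f32_l";
        registry.push_back(p);                               // weak_ptr: does not extend lifetime

        for (auto& ref : registry) {
            if (auto live = ref.lock()) {                    // still alive -> safe to clean up
                std::cout << "would destroy " << live->name << "\n";
            }
        }
    }
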
@@ -84,10 +111,61 @@ struct vk_device {
  uint32_t subgroup_size;
  bool uma;

+ bool initialized;
+ size_t idx;
+
+ vk_matmul_pipeline pipeline_matmul_f32;
+ vk_matmul_pipeline pipeline_matmul_f16;
+ vk_matmul_pipeline pipeline_matmul_f16_f32;
+ vk_pipeline pipeline_matmul_split_k_reduce;
+
+ vk_matmul_pipeline pipeline_dequant_mul_mat_mat[VK_NUM_TYPES];
+
+ vk_pipeline pipeline_dequant[VK_NUM_TYPES];
+ vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
+
+ vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
+ vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
+ vk_pipeline pipeline_get_rows[VK_NUM_TYPES];
+ vk_pipeline pipeline_get_rows_f32[VK_NUM_TYPES];
+ vk_pipeline pipeline_mul_f32;
+ vk_pipeline pipeline_add_f32;
+ vk_pipeline pipeline_scale_f32;
+ vk_pipeline pipeline_sqr_f32;
+ vk_pipeline pipeline_clamp_f32;
+ vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
+ vk_pipeline pipeline_norm_f32;
+ vk_pipeline pipeline_rms_norm_f32;
+ vk_pipeline pipeline_gelu_f32;
+ vk_pipeline pipeline_silu_f32;
+ vk_pipeline pipeline_relu_f32;
+ vk_pipeline pipeline_diag_mask_inf_f32;
+ vk_pipeline pipeline_soft_max_f32;
+ vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
+ vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
+ vk_pipeline pipeline_argsort_f32;
+
+ std::vector<vk_pipeline_ref> pipelines;
+
  ~vk_device() {
  #ifdef GGML_VULKAN_DEBUG
  std::cerr << "destroy device " << name << std::endl;
  #endif
+ device.destroyCommandPool(compute_queue.pool);
+ if (!single_queue) {
+ device.destroyCommandPool(transfer_queue.pool);
+ }
+
+ for (auto& pipeline : pipelines) {
+ if (pipeline.expired()) {
+ continue;
+ }
+
+ vk_pipeline pl = pipeline.lock();
+ ggml_vk_destroy_pipeline(device, pl);
+ }
+ pipelines.clear();
+
  device.destroy();
  }
  };
@@ -125,21 +203,6 @@ struct vk_subbuffer {
  uint64_t size;
  };

- struct vk_pipeline {
- std::string name;
- vk::ShaderModule shader_module;
- vk::DescriptorSetLayout dsl;
- std::vector<vk::DescriptorPool> descriptor_pools;
- std::vector<vk::DescriptorSet> descriptor_sets;
- uint32_t descriptor_set_idx;
- vk::PipelineLayout layout;
- vk::Pipeline pipeline;
- uint32_t push_constant_size;
- uint32_t parameter_count;
- std::array<uint32_t, 3> wg_denoms;
- uint32_t align;
- };
-
  struct vk_semaphore {
  vk::Semaphore s;
  uint64_t value;
@@ -160,11 +223,21 @@ struct vk_op_push_constants {
  float param2;
  };

- struct vk_op_cpy_push_constants {
+ struct vk_op_unary_push_constants {
+ uint32_t ne;
+ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
+ uint32_t d_offset;
+ float param1; float param2;
+ };
+
+ struct vk_op_binary_push_constants {
  uint32_t ne;
- uint32_t ne00; uint32_t ne01; uint32_t nb00; uint32_t nb01; uint32_t nb02;
- uint32_t ne10; uint32_t ne11; uint32_t nb10; uint32_t nb11; uint32_t nb12;
+ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
+ uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
  uint32_t d_offset;
+ float param1; float param2;
  };

  struct vk_op_diag_mask_push_constants {
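
Editor's note: vk_op_cpy_push_constants is replaced by vk_op_unary_push_constants and vk_op_binary_push_constants, which carry full four-dimensional shapes (ne..) and strides (nb..) for each operand (presumably src0, src1 and dst in the binary case) plus d_offset and two float parameters. Below is a hypothetical fill helper under the assumption of contiguous row-major operands; the struct mirrors the diff, but the helper and the stride convention are illustrative and not taken from the package:

    // Illustrative only: filling the new 4-D push-constant layout for a contiguous tensor.
    #include <array>
    #include <cstdint>

    struct op_binary_push_constants {
        uint32_t ne;
        uint32_t ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03;
        uint32_t ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13;
        uint32_t ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23;
        uint32_t d_offset;
        float param1, param2;
    };

    // Contiguous row-major strides in elements: nb0 = 1, nb1 = ne0, nb2 = ne0*ne1, ...
    static std::array<uint32_t, 4> contiguous_strides(const std::array<uint32_t, 4>& ne) {
        return { 1, ne[0], ne[0] * ne[1], ne[0] * ne[1] * ne[2] };
    }

    int main() {
        const std::array<uint32_t, 4> ne0 = { 4096, 32, 1, 1 };   // src0 shape (made up)
        const std::array<uint32_t, 4> nb0 = contiguous_strides(ne0);

        op_binary_push_constants pc{};
        pc.ne   = ne0[0] * ne0[1] * ne0[2] * ne0[3];               // total elements to process
        pc.ne00 = ne0[0]; pc.ne01 = ne0[1]; pc.ne02 = ne0[2]; pc.ne03 = ne0[3];
        pc.nb00 = nb0[0]; pc.nb01 = nb0[1]; pc.nb02 = nb0[2]; pc.nb03 = nb0[3];
        // ne1x/nb1x (second operand) and ne2x/nb2x (destination) would be filled the same way.
        (void) pc;
        return 0;
    }
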
@@ -196,6 +269,22 @@ struct vk_op_rope_neox_push_constants {
  float inv_ndims;
  };

+ struct vk_op_soft_max_push_constants {
+ uint32_t KX;
+ uint32_t KY;
+ uint32_t KZ;
+ float scale;
+ float max_bias;
+ float m0;
+ float m1;
+ uint32_t n_head_log2;
+ };
+
+ struct vk_op_argsort_push_constants {
+ uint32_t ncols;
+ bool ascending;
+ };
+
  // Allow pre-recording command buffers
  struct vk_staging_memcpy {
  vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
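
Editor's note: the new vk_op_soft_max_push_constants adds the ALiBi-related fields max_bias, m0, m1 and n_head_log2 next to the existing scale, and vk_op_argsort_push_constants backs the new argsort pipeline. To the best of my knowledge m0/m1/n_head_log2 follow ggml's usual ALiBi slope convention; a standalone sketch of that derivation (the concrete values are made up for illustration):

    // Hedged sketch of how the ALiBi slope parameters are typically derived in ggml.
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t n_head   = 32;      // illustrative
        const float    max_bias = 8.0f;    // illustrative

        // Largest power of two <= n_head
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

        const float m0 = powf(2.0f, -(max_bias)        / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        // Per-head slope: heads below n_head_log2 use powers of m0, the rest interleave m1.
        for (uint32_t h = 0; h < n_head; ++h) {
            const float slope = (max_bias > 0.0f)
                ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1))
                : 1.0f;
            printf("head %u slope %f\n", h, slope);
        }
    }
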
@@ -236,7 +325,6 @@ struct ggml_tensor_extra_gpu {
  };

  struct ggml_vk_garbage_collector {
- std::vector<vk_pipeline *> pipelines;
  std::vector<vk_semaphore> tl_semaphores;
  std::vector<vk_semaphore> semaphores;
  std::vector<vk::Event> events;
@@ -247,35 +335,7 @@ struct ggml_vk_garbage_collector {
  struct ggml_backend_vk_context {
  std::string name;

- std::weak_ptr<vk_device> device;
- vk_pipeline pipeline_matmul_f32_l, pipeline_matmul_f32_m, pipeline_matmul_f32_s;
- vk_pipeline pipeline_matmul_f32_aligned_l, pipeline_matmul_f32_aligned_m, pipeline_matmul_f32_aligned_s;
- vk_pipeline pipeline_matmul_f16_l, pipeline_matmul_f16_m, pipeline_matmul_f16_s;
- vk_pipeline pipeline_matmul_f16_aligned_l, pipeline_matmul_f16_aligned_m, pipeline_matmul_f16_aligned_s;
- vk_pipeline pipeline_matmul_f16_f32_l, pipeline_matmul_f16_f32_m, pipeline_matmul_f16_f32_s;
- vk_pipeline pipeline_matmul_f16_f32_aligned_l, pipeline_matmul_f16_f32_aligned_m, pipeline_matmul_f16_f32_aligned_s;
- vk_pipeline pipeline_matmul_split_k_reduce;
- vk_pipeline pipeline_dequant[VK_NUM_TYPES];
- vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
- vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
- vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
- vk_pipeline pipeline_get_rows[VK_NUM_TYPES];
- vk_pipeline pipeline_get_rows_f32[VK_NUM_TYPES];
- vk_pipeline pipeline_mul_f32;
- vk_pipeline pipeline_add_f32;
- vk_pipeline pipeline_scale_f32;
- vk_pipeline pipeline_sqr_f32;
- vk_pipeline pipeline_clamp_f32;
- vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
- vk_pipeline pipeline_norm_f32;
- vk_pipeline pipeline_rms_norm_f32;
- vk_pipeline pipeline_gelu_f32;
- vk_pipeline pipeline_silu_f32;
- vk_pipeline pipeline_relu_f32;
- vk_pipeline pipeline_diag_mask_inf_f32;
- vk_pipeline pipeline_soft_max_f32;
- vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
- vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
+ std::shared_ptr<vk_device> device;

  size_t semaphore_idx, event_idx;
  ggml_vk_garbage_collector gc;
@@ -304,13 +364,31 @@ struct vk_instance {

  std::vector<size_t> device_indices;

- std::shared_ptr<vk_device> devices[GGML_VK_MAX_DEVICES];
  ggml_backend_t backends[GGML_VK_MAX_DEVICES];
  ggml_backend_vk_context contexts[GGML_VK_MAX_DEVICES];
  ggml_backend_buffer_type buffer_types[GGML_VK_MAX_DEVICES];
  bool initialized[GGML_VK_MAX_DEVICES];
  };

+ static std::shared_ptr<vk_device> ggml_vk_get_device(size_t idx) {
+ #ifdef GGML_VULKAN_DEBUG
+ std::cerr << "ggml_vk_get_device(" << idx << ")" << std::endl;
+ #endif
+ static std::weak_ptr<vk_device> devices[GGML_VK_MAX_DEVICES];
+
+ if (devices[idx].expired()) {
+ #ifdef GGML_VULKAN_DEBUG
+ std::cerr << "Initializing new vk_device" << std::endl;
+ #endif
+ std::shared_ptr<vk_device> device = std::make_shared<vk_device>();
+ device->initialized = false;
+ devices[idx] = device;
+ return device;
+ }
+
+ return devices[idx].lock();
+ }
+
  #ifdef GGML_VULKAN_CHECK_RESULTS
  static size_t vk_skip_checks;
  static size_t vk_output_tensor;
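
Editor's note: device handles move out of vk_instance into ggml_vk_get_device, which keeps a function-local array of std::weak_ptr<vk_device>. Every backend context asking for the same index shares one vk_device, and the device is torn down as soon as the last shared_ptr owner releases it. A standalone sketch of the same cache pattern with dummy types:

    // Sketch of the weak_ptr device cache shown above (illustrative types, not the real API).
    #include <cassert>
    #include <cstddef>
    #include <memory>

    struct device { bool initialized = false; };

    static std::shared_ptr<device> get_device(size_t idx) {
        static std::weak_ptr<device> cache[16];
        if (cache[idx].expired()) {                 // no live owner -> create a fresh device
            auto dev = std::make_shared<device>();
            cache[idx] = dev;                       // the cache itself does not own it
            return dev;
        }
        return cache[idx].lock();                   // share the existing instance
    }

    int main() {
        auto a = get_device(0);
        auto b = get_device(0);
        assert(a == b);                             // both backends share one device
        a.reset();
        b.reset();                                  // last owner gone -> ~device() runs here
        auto c = get_device(0);                     // a later request re-creates it
        assert(c->initialized == false);
    }
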
@@ -334,14 +412,15 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
  GGML_ASSERT(parameter_count > 0);
  GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT

- pipeline.name = name;
- pipeline.parameter_count = parameter_count;
- pipeline.push_constant_size = push_constant_size;
- pipeline.wg_denoms = wg_denoms;
- pipeline.align = align;
+ pipeline = std::make_shared<vk_pipeline_struct>();
+ pipeline->name = name;
+ pipeline->parameter_count = parameter_count;
+ pipeline->push_constant_size = push_constant_size;
+ pipeline->wg_denoms = wg_denoms;
+ pipeline->align = align;

  vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
- pipeline.shader_module = ctx->device.lock()->device.createShaderModule(shader_module_create_info);
+ pipeline->shader_module = ctx->device->device.createShaderModule(shader_module_create_info);

  std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
  std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
@@ -355,49 +434,49 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
  vk::PushConstantRange pcr(
  vk::ShaderStageFlagBits::eCompute,
  0,
- pipeline.push_constant_size
+ pipeline->push_constant_size
  );

  vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
  {},
  dsl_binding);
  descriptor_set_layout_create_info.setPNext(&dslbfci);
- pipeline.dsl = ctx->device.lock()->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
+ pipeline->dsl = ctx->device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);

  // Check if device supports multiple descriptors per pool
- if (ctx->device.lock()->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) {
+ if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) {
  const uint32_t alloc_count = 2;

  // Try allocating multiple sets from one pool
  // This fails on AMD for some reason, so add a fall back to allocating one pool per set
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
  vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, alloc_count, descriptor_pool_size);
- vk::DescriptorPool pool = ctx->device.lock()->device.createDescriptorPool(descriptor_pool_create_info);
+ vk::DescriptorPool pool = ctx->device->device.createDescriptorPool(descriptor_pool_create_info);

  std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
  for (uint32_t i = 0; i < alloc_count; i++) {
- layouts[i] = pipeline.dsl;
+ layouts[i] = pipeline->dsl;
  }
  try {
  vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pool, alloc_count, layouts.data());
- std::vector<vk::DescriptorSet> sets = ctx->device.lock()->device.allocateDescriptorSets(descriptor_set_alloc_info);
+ std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
  } catch(vk::OutOfPoolMemoryError const&) {
- ctx->device.lock()->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE;
+ ctx->device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE;
  }

- ctx->device.lock()->device.destroyDescriptorPool(pool);
+ ctx->device->device.destroyDescriptorPool(pool);
  }

- if (ctx->device.lock()->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
+ if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
  vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 128, descriptor_pool_size);
- pipeline.descriptor_pools.push_back(ctx->device.lock()->device.createDescriptorPool(descriptor_pool_create_info));
+ pipeline->descriptor_pools.push_back(ctx->device->device.createDescriptorPool(descriptor_pool_create_info));
  }

- pipeline.descriptor_set_idx = 0;
+ pipeline->descriptor_set_idx = 0;

- vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline.dsl, pcr);
- pipeline.layout = ctx->device.lock()->device.createPipelineLayout(pipeline_layout_create_info);
+ vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr);
+ pipeline->layout = ctx->device->device.createPipelineLayout(pipeline_layout_create_info);

  std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());

@@ -417,72 +496,75 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
  vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
  vk::PipelineShaderStageCreateFlags(),
  vk::ShaderStageFlagBits::eCompute,
- pipeline.shader_module,
+ pipeline->shader_module,
  entrypoint.c_str(),
  &specialization_info);
  vk::ComputePipelineCreateInfo compute_pipeline_create_info(
  vk::PipelineCreateFlags(),
  pipeline_shader_create_info,
- pipeline.layout);
- pipeline.pipeline = ctx->device.lock()->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
+ pipeline->layout);
+ pipeline->pipeline = ctx->device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;

- ctx->gc.pipelines.push_back(&pipeline);
+ ctx->device->pipelines.push_back(pipeline);
  }

- static void ggml_vk_destroy_pipeline(ggml_backend_vk_context * ctx, vk_pipeline * pipeline) {
+ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
+ #ifdef GGML_VULKAN_DEBUG
+ std::cerr << "ggml_pipeline_destroy_pipeline(" << pipeline->name << ")" << std::endl;
+ #endif
  for (auto& pool : pipeline->descriptor_pools) {
- ctx->device.lock()->device.destroyDescriptorPool(pool);
+ device.destroyDescriptorPool(pool);
  }
  pipeline->descriptor_pools.clear();
  pipeline->descriptor_sets.clear();
  pipeline->descriptor_set_idx = 0;

- ctx->device.lock()->device.destroyDescriptorSetLayout(pipeline->dsl);
+ device.destroyDescriptorSetLayout(pipeline->dsl);

- ctx->device.lock()->device.destroyPipelineLayout(pipeline->layout);
+ device.destroyPipelineLayout(pipeline->layout);

- ctx->device.lock()->device.destroyShaderModule(pipeline->shader_module);
+ device.destroyShaderModule(pipeline->shader_module);

- ctx->device.lock()->device.destroyPipeline(pipeline->pipeline);
+ device.destroyPipeline(pipeline->pipeline);
  }

  static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx, vk_pipeline& pipeline, uint32_t n) {
  #ifdef GGML_VULKAN_DEBUG
- std::cerr << "ggml_pipeline_allocate_descriptor_sets(" << pipeline.name << ", " << n << ")" << std::endl;
+ std::cerr << "ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")" << std::endl;
  #endif
- if (pipeline.descriptor_sets.size() >= pipeline.descriptor_set_idx + n) {
+ if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) {
  // Enough descriptors are available
  return;
  }

- if (ctx->device.lock()->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
- const uint32_t alloc_count = pipeline.descriptor_set_idx + n - pipeline.descriptor_sets.size();
+ if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
+ const uint32_t alloc_count = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size();

  std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
  for (uint32_t i = 0; i < alloc_count; i++) {
- layouts[i] = pipeline.dsl;
+ layouts[i] = pipeline->dsl;
  }
- vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline.descriptor_pools[0], alloc_count, layouts.data());
- std::vector<vk::DescriptorSet> sets = ctx->device.lock()->device.allocateDescriptorSets(descriptor_set_alloc_info);
- pipeline.descriptor_sets.insert(pipeline.descriptor_sets.end(), sets.begin(), sets.end());
+ vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[0], alloc_count, layouts.data());
+ std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
+ pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end());
  } else {
- for (uint32_t i = pipeline.descriptor_sets.size(); i < pipeline.descriptor_set_idx + n; i++) {
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
+ for (uint32_t i = pipeline->descriptor_sets.size(); i < pipeline->descriptor_set_idx + n; i++) {
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
  vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 1, descriptor_pool_size);
- pipeline.descriptor_pools.push_back(ctx->device.lock()->device.createDescriptorPool(descriptor_pool_create_info));
+ pipeline->descriptor_pools.push_back(ctx->device->device.createDescriptorPool(descriptor_pool_create_info));

- vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline.descriptor_pools[i], 1, &pipeline.dsl);
- std::vector<vk::DescriptorSet> sets = ctx->device.lock()->device.allocateDescriptorSets(descriptor_set_alloc_info);
- pipeline.descriptor_sets.push_back(sets[0]);
+ vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[i], 1, &pipeline->dsl);
+ std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
+ pipeline->descriptor_sets.push_back(sets[0]);
  }
  }
  }

  static void ggml_pipeline_cleanup(vk_pipeline& pipeline) {
  #ifdef GGML_VULKAN_DEBUG
- std::cerr << "ggml_pipeline_cleanup(" << pipeline.name << ")" << std::endl;
+ std::cerr << "ggml_pipeline_cleanup(" << pipeline->name << ")" << std::endl;
  #endif
- pipeline.descriptor_set_idx = 0;
+ pipeline->descriptor_set_idx = 0;
  }

  static vk::CommandBuffer ggml_vk_create_cmd_buffer(ggml_backend_vk_context * ctx, vk_queue& q) {
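
Editor's note: the hunks above keep the VK_DEVICE_DESCRIPTOR_POOL_MODE_* handling (probe once whether several descriptor sets can be allocated from a single pool, and fall back to one pool per set where that fails, reportedly on some AMD drivers), but now reach the device through the plain shared_ptr instead of a weak_ptr lock. A standalone sketch of that probe-and-cache idea with a stand-in allocator instead of the Vulkan calls; in the real code the detected mode is stored per device, not in a function-local static:

    // Probe-and-cache fallback, mirroring the descriptor pool mode logic above.
    #include <iostream>
    #include <stdexcept>

    enum pool_mode { MODE_UNKNOWN, MODE_MULTI, MODE_SINGLE };

    // Stand-in for allocating several descriptor sets from one pool; throws where
    // some drivers would report vk::OutOfPoolMemoryError.
    static void probe_multi_alloc(bool driver_supports_it) {
        if (!driver_supports_it) throw std::runtime_error("out of pool memory");
    }

    static pool_mode detect_mode(bool driver_supports_it) {
        static pool_mode mode = MODE_UNKNOWN;   // per-device field in the real backend
        if (mode == MODE_UNKNOWN) {             // probe only once
            try {
                probe_multi_alloc(driver_supports_it);
                mode = MODE_MULTI;              // one pool, many sets
            } catch (const std::runtime_error&) {
                mode = MODE_SINGLE;             // fall back to one pool per set
            }
        }
        return mode;
    }

    int main() {
        std::cout << (detect_mode(false) == MODE_SINGLE ? "single" : "multi") << "\n";
    }
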
@@ -498,7 +580,7 @@ static vk::CommandBuffer ggml_vk_create_cmd_buffer(ggml_backend_vk_context * ctx
  q.pool,
  vk::CommandBufferLevel::ePrimary,
  1);
- const std::vector<vk::CommandBuffer> cmd_buffers = ctx->device.lock()->device.allocateCommandBuffers(command_buffer_alloc_info);
+ const std::vector<vk::CommandBuffer> cmd_buffers = ctx->device->device.allocateCommandBuffers(command_buffer_alloc_info);
  auto buf = cmd_buffers.front();

  q.cmd_buffers.push_back(buf);
@@ -643,11 +725,11 @@ static void ggml_vk_create_queue(ggml_backend_vk_context * ctx, vk_queue& q, uin
  q.queue_family_index = queue_family_index;

  vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index);
- q.pool = ctx->device.lock()->device.createCommandPool(command_pool_create_info_compute);
+ q.pool = ctx->device->device.createCommandPool(command_pool_create_info_compute);

  q.cmd_buffer_idx = 0;

- q.queue = ctx->device.lock()->device.getQueue(queue_family_index, queue_index);
+ q.queue = ctx->device->device.getQueue(queue_family_index, queue_index);

  q.stage_flags = stage_flags;
  }
@@ -671,7 +753,7 @@ static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context *
  vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
  vk::SemaphoreCreateInfo ci{};
  ci.setPNext(&tci);
- vk::Semaphore semaphore = ctx->device.lock()->device.createSemaphore(ci);
+ vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
  ctx->gc.semaphores.push_back({ semaphore, 0 });
  return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];
  }
@@ -684,7 +766,7 @@ static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context
  vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
  vk::SemaphoreCreateInfo ci{};
  ci.setPNext(&tci);
- vk::Semaphore semaphore = ctx->device.lock()->device.createSemaphore(ci);
+ vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
  ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
  }
  return &ctx->gc.tl_semaphores[ctx->semaphore_idx++];
@@ -692,7 +774,7 @@ static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context

  static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
  if (ctx->event_idx >= ctx->gc.events.size()) {
- ctx->gc.events.push_back(ctx->device.lock()->device.createEvent({}));
+ ctx->gc.events.push_back(ctx->device->device.createEvent({}));
  }
  return ctx->gc.events[ctx->event_idx++];
  }
@@ -703,7 +785,7 @@ static void ggml_vk_queue_cleanup(ggml_backend_vk_context * ctx, vk_queue& q) {
  #endif
  // Requires command buffers to be done

- ctx->device.lock()->device.resetCommandPool(q.pool);
+ ctx->device->device.resetCommandPool(q.pool);
  q.cmd_buffer_idx = 0;
  }

@@ -740,11 +822,11 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
  nullptr,
  };

- buf->buffer = ctx->device.lock()->device.createBuffer(buffer_create_info);
+ buf->buffer = ctx->device->device.createBuffer(buffer_create_info);

- vk::MemoryRequirements mem_req = ctx->device.lock()->device.getBufferMemoryRequirements(buf->buffer);
+ vk::MemoryRequirements mem_req = ctx->device->device.getBufferMemoryRequirements(buf->buffer);

- vk::PhysicalDeviceMemoryProperties mem_props = ctx->device.lock()->physical_device.getMemoryProperties();
+ vk::PhysicalDeviceMemoryProperties mem_props = ctx->device->physical_device.getMemoryProperties();

  uint32_t memory_type_index = UINT32_MAX;

@@ -757,30 +839,30 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
  }

  if (memory_type_index == UINT32_MAX) {
- ctx->device.lock()->device.destroyBuffer(buf->buffer);
+ ctx->device->device.destroyBuffer(buf->buffer);
  buf->size = 0;
  throw vk::OutOfDeviceMemoryError("No suitable memory type found");
  }

  try {
- buf->device_memory = ctx->device.lock()->device.allocateMemory({ mem_req.size, memory_type_index });
+ buf->device_memory = ctx->device->device.allocateMemory({ mem_req.size, memory_type_index });
  } catch (const vk::SystemError& e) {
  // Out of Host/Device memory, clean up buffer
- ctx->device.lock()->device.destroyBuffer(buf->buffer);
+ ctx->device->device.destroyBuffer(buf->buffer);
  buf->size = 0;
  throw e;
  }
  buf->ptr = nullptr;

  if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
- buf->ptr = ctx->device.lock()->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
+ buf->ptr = ctx->device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
  }

- ctx->device.lock()->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
+ ctx->device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);

  buf->ctx = ctx;

- buf->device = ctx->device.lock();
+ buf->device = ctx->device;

  #ifdef GGML_VULKAN_DEBUG
  std::cerr << "Created buffer " << buf->buffer << std::endl;
@@ -802,7 +884,7 @@ static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size
  static vk_buffer ggml_vk_create_buffer_device(ggml_backend_vk_context * ctx, size_t size) {
  vk_buffer buf;
  try {
- if (ctx->device.lock()->uma) {
+ if (ctx->device->uma) {
  // Fall back to host memory type
  buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
  } else {
@@ -883,10 +965,16 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
  std::cerr << "ggml_vk_load_shaders(" << ctx->name << ")" << std::endl;
  #endif

+ const std::shared_ptr<vk_device> device = ctx->device;
+
  // mulmat
- std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, ctx->device.lock()->subgroup_size * 2, 64, 2, 4, 4, ctx->device.lock()->subgroup_size };
- std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, ctx->device.lock()->subgroup_size, 32, 2, 4, 2, ctx->device.lock()->subgroup_size };
- std::initializer_list<uint32_t> warptile_s = { ctx->device.lock()->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, ctx->device.lock()->subgroup_size };
+ std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
+ std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
+ std::initializer_list<uint32_t> warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
+
+ std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
+ std::initializer_list<uint32_t> warptile_mmq_m = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
+ std::initializer_list<uint32_t> warptile_mmq_s = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };

  std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
  std::array<uint32_t, 3> m_wg_denoms = { 64, 64, 1 };
@@ -896,126 +984,206 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
896
984
  uint32_t m_align = 64;
897
985
  uint32_t s_align = 32;
898
986
 
899
- if (ctx->device.lock()->fp16) {
900
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_l_len, matmul_f32_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
901
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_m_len, matmul_f32_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
902
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_s, "matmul_f32_s", matmul_f32_s_len, matmul_f32_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
903
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_l, "matmul_f32_aligned_l", matmul_f32_aligned_l_len, matmul_f32_aligned_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
904
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_m, "matmul_f32_aligned_m", matmul_f32_aligned_m_len, matmul_f32_aligned_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
905
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_s, "matmul_f32_aligned_s", matmul_f32_aligned_s_len, matmul_f32_aligned_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
906
-
907
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_l, "matmul_f16_l", matmul_f16_l_len, matmul_f16_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
908
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_m, "matmul_f16_m", matmul_f16_m_len, matmul_f16_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
909
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_s, "matmul_f16_s", matmul_f16_s_len, matmul_f16_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
910
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_l, "matmul_f16_aligned_l", matmul_f16_aligned_l_len, matmul_f16_aligned_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
911
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_m, "matmul_f16_aligned_m", matmul_f16_aligned_m_len, matmul_f16_aligned_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
912
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_s, "matmul_f16_aligned_s", matmul_f16_aligned_s_len, matmul_f16_aligned_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
913
-
914
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_l, "matmul_f16_f32_l", matmul_f16_f32_l_len, matmul_f16_f32_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
915
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_m, "matmul_f16_f32_m", matmul_f16_f32_m_len, matmul_f16_f32_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
916
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_s, "matmul_f16_f32_s", matmul_f16_f32_s_len, matmul_f16_f32_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
917
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_l_len, matmul_f16_f32_aligned_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
918
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_m_len, matmul_f16_f32_aligned_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
919
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_s_len, matmul_f16_f32_aligned_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
987
+ ctx->device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
988
+ ctx->device->pipeline_matmul_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
989
+ ctx->device->pipeline_matmul_f16 = std::make_shared<vk_matmul_pipeline_struct>();
990
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
991
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
992
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
993
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
994
+ ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
995
+
996
+ if (device->fp16) {
997
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
998
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
999
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1000
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1001
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1002
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1003
+
1004
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1005
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1006
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1007
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1008
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1009
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1010
+
1011
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1012
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1013
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1014
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1015
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1016
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1017
+
1018
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1019
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1020
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1021
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1022
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1023
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1024
+
1025
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_0_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1026
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_0_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1027
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_0_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1028
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1029
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1030
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1031
+
1032
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1033
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1034
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1035
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1036
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1037
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1038
+
1039
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1040
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1041
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1042
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1043
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1044
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1045
+
1046
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1047
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1048
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1049
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1050
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1051
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
920
1052
  } else {
921
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_l_fp32_len, matmul_f32_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
922
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_m_fp32_len, matmul_f32_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
923
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_s, "matmul_f32_s", matmul_f32_s_fp32_len, matmul_f32_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
924
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_l, "matmul_f32_aligned_l", matmul_f32_aligned_l_fp32_len, matmul_f32_aligned_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
925
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_m, "matmul_f32_aligned_m", matmul_f32_aligned_m_fp32_len, matmul_f32_aligned_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
926
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_s, "matmul_f32_aligned_s", matmul_f32_aligned_s_fp32_len, matmul_f32_aligned_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
927
-
928
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_l, "matmul_f16_l", matmul_f16_l_fp32_len, matmul_f16_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
929
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_m, "matmul_f16_m", matmul_f16_m_fp32_len, matmul_f16_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
930
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_s, "matmul_f16_s", matmul_f16_s_fp32_len, matmul_f16_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
931
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_l, "matmul_f16_aligned_l", matmul_f16_aligned_l_fp32_len, matmul_f16_aligned_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
932
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_m, "matmul_f16_aligned_m", matmul_f16_aligned_m_fp32_len, matmul_f16_aligned_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
933
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_s, "matmul_f16_aligned_s", matmul_f16_aligned_s_fp32_len, matmul_f16_aligned_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
934
-
935
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_l, "matmul_f16_f32_l", matmul_f16_f32_l_fp32_len, matmul_f16_f32_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
936
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_m, "matmul_f16_f32_m", matmul_f16_f32_m_fp32_len, matmul_f16_f32_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
937
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_s, "matmul_f16_f32_s", matmul_f16_f32_s_fp32_len, matmul_f16_f32_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
938
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_l_fp32_len, matmul_f16_f32_aligned_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
939
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_m_fp32_len, matmul_f16_f32_aligned_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
940
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_s_fp32_len, matmul_f16_f32_aligned_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
941
- }
942
-
943
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32", mul_mat_vec_f16_f32_len, mul_mat_vec_f16_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
944
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32", mul_mat_vec_q4_0_f32_len, mul_mat_vec_q4_0_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
945
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32", mul_mat_vec_q4_1_f32_len, mul_mat_vec_q4_1_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
946
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32", mul_mat_vec_q5_0_f32_len, mul_mat_vec_q5_0_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
947
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32", mul_mat_vec_q5_1_f32_len, mul_mat_vec_q5_1_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
948
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32", mul_mat_vec_q8_0_f32_len, mul_mat_vec_q8_0_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
949
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_K_f32", mul_mat_vec_q2_K_f32_len, mul_mat_vec_q2_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
950
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_K_f32", mul_mat_vec_q3_K_f32_len, mul_mat_vec_q3_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
951
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_K_f32", mul_mat_vec_q4_K_f32_len, mul_mat_vec_q4_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
952
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f32", mul_mat_vec_q5_K_f32_len, mul_mat_vec_q5_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
953
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
1053
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1054
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1055
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1056
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1057
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1058
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1059
+
1060
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1061
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1062
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1063
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1064
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1065
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1066
+
1067
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
1068
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
1069
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
1070
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
1071
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
1072
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
1073
+
1074
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1075
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1076
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1077
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1078
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1079
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1080
+
1081
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1082
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1083
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1084
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1085
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1086
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1087
+
1088
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1089
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1090
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1091
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1092
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1093
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1094
+
1095
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1096
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1097
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1098
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1099
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1100
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1101
+
1102
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1103
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1104
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1105
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
1106
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
1107
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
1108
+ }
1109
+
1110
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32", mul_mat_vec_f16_f32_len, mul_mat_vec_f16_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1111
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32", mul_mat_vec_q4_0_f32_len, mul_mat_vec_q4_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1112
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32", mul_mat_vec_q4_1_f32_len, mul_mat_vec_q4_1_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1113
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32", mul_mat_vec_q5_0_f32_len, mul_mat_vec_q5_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1114
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32", mul_mat_vec_q5_1_f32_len, mul_mat_vec_q5_1_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1115
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32", mul_mat_vec_q8_0_f32_len, mul_mat_vec_q8_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1116
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_K_f32", mul_mat_vec_q2_K_f32_len, mul_mat_vec_q2_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1117
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_K_f32", mul_mat_vec_q3_K_f32_len, mul_mat_vec_q3_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1118
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_K_f32", mul_mat_vec_q4_K_f32_len, mul_mat_vec_q4_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1119
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f32", mul_mat_vec_q5_K_f32_len, mul_mat_vec_q5_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
1120
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
954
1121
 
955
1122
  // dequant shaders
956
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", f32_to_f16_len, f32_to_f16_data, "main", 2, 4 * sizeof(int), { 64, 1, 1}, {}, 1);
957
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F16 ], "dequant_f16", dequant_f16_len, dequant_f16_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
958
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
959
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
960
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
961
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
962
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
963
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
964
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
965
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
966
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_K", dequant_q5_K_len, dequant_q5_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
967
- ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_K", dequant_q6_K_len, dequant_q6_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
1123
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1124
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1125
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1126
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1127
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1128
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
1129
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
1130
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
1131
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
1132
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_K", dequant_q5_K_len, dequant_q5_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
1133
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_K", dequant_q6_K_len, dequant_q6_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
968
1134
 
969
1135
  // get_rows
970
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
971
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
972
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
973
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
974
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
975
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1136
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1137
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1138
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1139
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1140
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1141
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
976
1142
 
977
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
978
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
979
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
980
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
981
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
982
- ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1143
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1144
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1145
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1146
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1147
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1148
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
983
1149
 
984
- ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
1150
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
985
1151
 
986
- ggml_vk_create_pipeline(ctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
987
- ggml_vk_create_pipeline(ctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
1152
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
1153
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
988
1154
 
989
- ggml_vk_create_pipeline(ctx, ctx->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
990
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
1155
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
1156
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
991
1157
 
992
- ggml_vk_create_pipeline(ctx, ctx->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_cpy_push_constants), {512, 1, 1}, {}, 1);
993
- ggml_vk_create_pipeline(ctx, ctx->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_cpy_push_constants), {512, 1, 1}, {}, 1);
994
- ggml_vk_create_pipeline(ctx, ctx->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_cpy_push_constants), {512, 1, 1}, {}, 1);
1158
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1159
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1160
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
995
1161
 
996
- ggml_vk_create_pipeline(ctx, ctx->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1162
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
997
1163
 
998
- ggml_vk_create_pipeline(ctx, ctx->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1164
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
999
1165
 
1000
- ggml_vk_create_pipeline(ctx, ctx->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1166
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1001
1167
 
1002
- ggml_vk_create_pipeline(ctx, ctx->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1168
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1003
1169
 
1004
- ggml_vk_create_pipeline(ctx, ctx->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1170
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1005
1171
 
1006
- ggml_vk_create_pipeline(ctx, ctx->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1007
- ggml_vk_create_pipeline(ctx, ctx->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1008
- ggml_vk_create_pipeline(ctx, ctx->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1172
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1173
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1174
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
1009
1175
 
1010
- ggml_vk_create_pipeline(ctx, ctx->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
1176
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
1011
1177
 
1012
- ggml_vk_create_pipeline(ctx, ctx->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
1178
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
1013
1179
 
1014
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1015
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1180
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1181
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1016
1182
 
1017
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
1018
- ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
1183
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
1184
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
1185
+
1186
+ ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
1019
1187
  }
1020
1188
 
1021
1189
  static void ggml_vk_print_gpu_info(size_t idx) {
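The shader-loading hunk above fills each vk_matmul_pipeline with six variants: large/medium/small tile sizes (l, m, s) plus aligned counterparts (a_l, a_m, a_s), one set per source type. A minimal sketch of that bundling pattern, using simplified stand-ins (pipeline_handle, create_pipeline and register_matmul_variants are illustrative names, not the package's API):

#include <memory>
#include <string>

// Illustrative stand-in for the real vk_pipeline handle and its creation call.
struct pipeline_handle { std::string name; };
using pipeline_ptr = std::shared_ptr<pipeline_handle>;

// Mirrors the shape of vk_matmul_pipeline_struct: three tile sizes plus aligned variants.
struct matmul_pipeline_set {
    pipeline_ptr l, m, s;
    pipeline_ptr a_l, a_m, a_s;
};

static pipeline_ptr create_pipeline(const std::string& name) {
    // The real backend compiles a SPIR-V blob and builds layouts/descriptor pools here.
    return std::make_shared<pipeline_handle>(pipeline_handle{name});
}

// Register all six variants for one shader family, e.g. "matmul_q4_0_f32".
static void register_matmul_variants(matmul_pipeline_set& set, const std::string& base) {
    set.l   = create_pipeline(base + "_l");
    set.m   = create_pipeline(base + "_m");
    set.s   = create_pipeline(base + "_s");
    set.a_l = create_pipeline(base + "_aligned_l");
    set.a_m = create_pipeline(base + "_aligned_m");
    set.a_s = create_pipeline(base + "_aligned_s");
}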
@@ -1057,8 +1225,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
1057
1225
  }
1058
1226
  }
1059
1227
 
1060
- const char* GGML_VULKAN_DISABLE_F16 = getenv("GGML_VULKAN_DISABLE_F16");
1061
- bool force_disable_f16 = GGML_VULKAN_DISABLE_F16 != nullptr;
1228
+ const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
1229
+ bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
1062
1230
 
1063
1231
  bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
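The renamed toggle follows the same convention as the GGML_VK_FORCE_MAX_ALLOCATION_SIZE override introduced further below: presence of the variable flips a boolean, and numeric overrides are parsed from its value. A hedged sketch of that convention (the helper names are illustrative, not the package's):

#include <cstdint>
#include <cstdlib>
#include <string>

// Presence-only toggle: setting the variable to any value enables the flag.
static bool env_flag_set(const char* name) {
    return std::getenv(name) != nullptr;
}

// Numeric override with a fallback when the variable is unset.
static uint64_t env_size_or(const char* name, uint64_t fallback) {
    const char* val = std::getenv(name);
    return val != nullptr ? std::stoull(val) : fallback;
}

// Usage in the spirit of the diff:
//   const bool force_disable_f16 = env_flag_set("GGML_VK_DISABLE_F16");
//   const uint64_t max_alloc     = env_size_or("GGML_VK_FORCE_MAX_ALLOCATION_SIZE", default_max_alloc);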
1064
1232
 
@@ -1106,7 +1274,9 @@ void ggml_vk_instance_init() {
1106
1274
 
1107
1275
  const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
1108
1276
  const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
1277
+ #ifdef __APPLE__
1109
1278
  const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
1279
+ #endif
1110
1280
 
1111
1281
  std::vector<const char*> layers;
1112
1282
 
@@ -1117,13 +1287,17 @@ void ggml_vk_instance_init() {
1117
1287
  if (validation_ext) {
1118
1288
  extensions.push_back("VK_EXT_validation_features");
1119
1289
  }
1290
+ #ifdef __APPLE__
1120
1291
  if (portability_enumeration_ext) {
1121
1292
  extensions.push_back("VK_KHR_portability_enumeration");
1122
1293
  }
1294
+ #endif
1123
1295
  vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
1296
+ #ifdef __APPLE__
1124
1297
  if (portability_enumeration_ext) {
1125
1298
  instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
1126
1299
  }
1300
+ #endif
1127
1301
 
1128
1302
  std::vector<vk::ValidationFeatureEnableEXT> features_enable;
1129
1303
  vk::ValidationFeaturesEXT validation_features;
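The portability-enumeration path above is now compiled only on Apple platforms, where MoltenVK requires VK_KHR_portability_enumeration. The availability helpers it relies on presumably just scan the enumerated instance extensions; a sketch of such a check (the helper name is illustrative):

#include <cstring>
#include <vector>
#include <vulkan/vulkan.hpp>

// Returns true if the named instance extension is reported by the Vulkan loader.
static bool instance_extension_available(const std::vector<vk::ExtensionProperties>& props,
                                         const char* name) {
    for (const auto& p : props) {
        if (std::strcmp(p.extensionName, name) == 0) {
            return true;
        }
    }
    return false;
}

// e.g. instance_extension_available(vk::enumerateInstanceExtensionProperties(),
//                                   "VK_KHR_portability_enumeration")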
@@ -1182,140 +1356,152 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
1182
1356
  throw std::runtime_error("Device not found");
1183
1357
  }
1184
1358
 
1185
- vk_instance.devices[idx] = std::make_shared<vk_device>();
1186
- ctx->device = vk_instance.devices[idx];
1187
- ctx->device.lock()->physical_device = devices[dev_num];
1188
- const std::vector<vk::ExtensionProperties> ext_props = ctx->device.lock()->physical_device.enumerateDeviceExtensionProperties();
1359
+ ctx->device = ggml_vk_get_device(idx);
1360
+ if (!ctx->device->initialized) {
1361
+ ctx->device->physical_device = devices[dev_num];
1362
+ const std::vector<vk::ExtensionProperties> ext_props = ctx->device->physical_device.enumerateDeviceExtensionProperties();
1189
1363
 
1190
- bool maintenance4_support = false;
1364
+ bool maintenance4_support = false;
1191
1365
 
1192
- // Check if maintenance4 is supported
1193
- for (const auto& properties : ext_props) {
1194
- if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
1195
- maintenance4_support = true;
1366
+ // Check if maintenance4 is supported
1367
+ for (const auto& properties : ext_props) {
1368
+ if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
1369
+ maintenance4_support = true;
1370
+ }
1196
1371
  }
1197
- }
1198
1372
 
1199
- vk::PhysicalDeviceProperties2 props2;
1200
- vk::PhysicalDeviceMaintenance3Properties props3;
1201
- vk::PhysicalDeviceMaintenance4Properties props4;
1202
- vk::PhysicalDeviceSubgroupProperties subgroup_props;
1203
- props2.pNext = &props3;
1204
- props3.pNext = &subgroup_props;
1205
- if (maintenance4_support) {
1206
- subgroup_props.pNext = &props4;
1207
- }
1208
- ctx->device.lock()->physical_device.getProperties2(&props2);
1209
- ctx->device.lock()->properties = props2.properties;
1373
+ vk::PhysicalDeviceProperties2 props2;
1374
+ vk::PhysicalDeviceMaintenance3Properties props3;
1375
+ vk::PhysicalDeviceMaintenance4Properties props4;
1376
+ vk::PhysicalDeviceSubgroupProperties subgroup_props;
1377
+ props2.pNext = &props3;
1378
+ props3.pNext = &subgroup_props;
1379
+ if (maintenance4_support) {
1380
+ subgroup_props.pNext = &props4;
1381
+ }
1382
+ ctx->device->physical_device.getProperties2(&props2);
1383
+ ctx->device->properties = props2.properties;
1210
1384
 
1211
- if (maintenance4_support) {
1212
- ctx->device.lock()->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
1213
- } else {
1214
- ctx->device.lock()->max_memory_allocation_size = props3.maxMemoryAllocationSize;
1215
- }
1385
+ const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
1216
1386
 
1217
- ctx->device.lock()->vendor_id = ctx->device.lock()->properties.vendorID;
1218
- ctx->device.lock()->subgroup_size = subgroup_props.subgroupSize;
1219
- ctx->device.lock()->uma = ctx->device.lock()->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
1387
+ if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
1388
+ ctx->device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
1389
+ } else if (maintenance4_support) {
1390
+ ctx->device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
1391
+ } else {
1392
+ ctx->device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
1393
+ }
1220
1394
 
1221
- bool fp16_storage = false;
1222
- bool fp16_compute = false;
1395
+ ctx->device->vendor_id = ctx->device->properties.vendorID;
1396
+ ctx->device->subgroup_size = subgroup_props.subgroupSize;
1397
+ ctx->device->uma = ctx->device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
1223
1398
 
1224
- for (const auto& properties : ext_props) {
1225
- if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1226
- fp16_storage = true;
1227
- } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1228
- fp16_compute = true;
1399
+ bool fp16_storage = false;
1400
+ bool fp16_compute = false;
1401
+
1402
+ for (const auto& properties : ext_props) {
1403
+ if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1404
+ fp16_storage = true;
1405
+ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1406
+ fp16_compute = true;
1407
+ }
1229
1408
  }
1230
- }
1231
1409
 
1232
- const char* GGML_VULKAN_DISABLE_F16 = getenv("GGML_VULKAN_DISABLE_F16");
1233
- bool force_disable_f16 = GGML_VULKAN_DISABLE_F16 != nullptr;
1410
+ const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
1411
+ const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
1234
1412
 
1235
- ctx->device.lock()->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
1413
+ ctx->device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
1236
1414
 
1237
- std::vector<vk::QueueFamilyProperties> queue_family_props = ctx->device.lock()->physical_device.getQueueFamilyProperties();
1415
+ std::vector<vk::QueueFamilyProperties> queue_family_props = ctx->device->physical_device.getQueueFamilyProperties();
1238
1416
 
1239
- // Try to find a non-graphics compute queue and transfer-focused queues
1240
- const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
1241
- const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
1417
+ // Try to find a non-graphics compute queue and transfer-focused queues
1418
+ const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
1419
+ const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
1242
1420
 
1243
- const float priorities[] = { 1.0f, 1.0f };
1244
- ctx->device.lock()->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
1421
+ const float priorities[] = { 1.0f, 1.0f };
1422
+ ctx->device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
1245
1423
 
1246
- std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
1247
- if (compute_queue_family_index != transfer_queue_family_index) {
1248
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
1249
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
1250
- } else if(!ctx->device.lock()->single_queue) {
1251
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
1252
- } else {
1253
- device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
1254
- }
1255
- vk::DeviceCreateInfo device_create_info;
1256
- std::vector<const char *> device_extensions;
1257
- vk::PhysicalDeviceFeatures device_features = ctx->device.lock()->physical_device.getFeatures();
1424
+ std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
1425
+ if (compute_queue_family_index != transfer_queue_family_index) {
1426
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
1427
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
1428
+ } else if(!ctx->device->single_queue) {
1429
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
1430
+ } else {
1431
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
1432
+ }
1433
+ vk::DeviceCreateInfo device_create_info;
1434
+ std::vector<const char *> device_extensions;
1435
+ vk::PhysicalDeviceFeatures device_features = ctx->device->physical_device.getFeatures();
1258
1436
 
1259
- VkPhysicalDeviceFeatures2 device_features2;
1260
- device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
1261
- device_features2.pNext = nullptr;
1262
- device_features2.features = (VkPhysicalDeviceFeatures)device_features;
1437
+ VkPhysicalDeviceFeatures2 device_features2;
1438
+ device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
1439
+ device_features2.pNext = nullptr;
1440
+ device_features2.features = (VkPhysicalDeviceFeatures)device_features;
1263
1441
 
1264
- VkPhysicalDeviceVulkan11Features vk11_features;
1265
- vk11_features.pNext = nullptr;
1266
- vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
1267
- device_features2.pNext = &vk11_features;
1442
+ VkPhysicalDeviceVulkan11Features vk11_features;
1443
+ vk11_features.pNext = nullptr;
1444
+ vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
1445
+ device_features2.pNext = &vk11_features;
1268
1446
 
1269
- VkPhysicalDeviceVulkan12Features vk12_features;
1270
- vk12_features.pNext = nullptr;
1271
- vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1272
- vk11_features.pNext = &vk12_features;
1447
+ VkPhysicalDeviceVulkan12Features vk12_features;
1448
+ vk12_features.pNext = nullptr;
1449
+ vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1450
+ vk11_features.pNext = &vk12_features;
1273
1451
 
1274
- vkGetPhysicalDeviceFeatures2(ctx->device.lock()->physical_device, &device_features2);
1452
+ vkGetPhysicalDeviceFeatures2(ctx->device->physical_device, &device_features2);
1275
1453
 
1276
- ctx->device.lock()->fp16 = ctx->device.lock()->fp16 && vk12_features.shaderFloat16;
1454
+ ctx->device->fp16 = ctx->device->fp16 && vk12_features.shaderFloat16;
1277
1455
 
1278
- if (!vk11_features.storageBuffer16BitAccess) {
1279
- std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
1280
- throw std::runtime_error("Unsupported device");
1281
- }
1456
+ if (!vk11_features.storageBuffer16BitAccess) {
1457
+ std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
1458
+ throw std::runtime_error("Unsupported device");
1459
+ }
1282
1460
 
1283
- device_extensions.push_back("VK_KHR_16bit_storage");
1461
+ device_extensions.push_back("VK_KHR_16bit_storage");
1284
1462
 
1285
1463
  #ifdef GGML_VULKAN_VALIDATE
1286
- device_extensions.push_back("VK_KHR_shader_non_semantic_info");
1464
+ device_extensions.push_back("VK_KHR_shader_non_semantic_info");
1287
1465
  #endif
1288
1466
 
1289
- if (ctx->device.lock()->fp16) {
1290
- device_extensions.push_back("VK_KHR_shader_float16_int8");
1291
- }
1292
- ctx->device.lock()->name = ctx->device.lock()->properties.deviceName.data();
1467
+ if (ctx->device->fp16) {
1468
+ device_extensions.push_back("VK_KHR_shader_float16_int8");
1469
+ }
1470
+ ctx->device->name = ctx->device->properties.deviceName.data();
1293
1471
 
1294
- device_create_info = {
1295
- vk::DeviceCreateFlags(),
1296
- device_queue_create_infos,
1297
- {},
1298
- device_extensions
1299
- };
1300
- device_create_info.setPNext(&device_features2);
1301
- ctx->device.lock()->device = ctx->device.lock()->physical_device.createDevice(device_create_info);
1472
+ device_create_info = {
1473
+ vk::DeviceCreateFlags(),
1474
+ device_queue_create_infos,
1475
+ {},
1476
+ device_extensions
1477
+ };
1478
+ device_create_info.setPNext(&device_features2);
1479
+ ctx->device->device = ctx->device->physical_device.createDevice(device_create_info);
1302
1480
 
1303
- ctx->device.lock()->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;
1481
+ ctx->device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;
1304
1482
 
1305
- // Shaders
1306
- ggml_vk_load_shaders(ctx);
1483
+ // Queues
1484
+ ggml_vk_create_queue(ctx, ctx->device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
1307
1485
 
1308
- // Queues
1309
- ggml_vk_create_queue(ctx, ctx->device.lock()->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
1310
- if (!ctx->device.lock()->single_queue) {
1311
- const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
1312
- ggml_vk_create_queue(ctx, ctx->device.lock()->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
1313
- } else {
1314
- // TODO: Use pointer or reference to avoid copy
1315
- ctx->device.lock()->transfer_queue = ctx->device.lock()->compute_queue;
1486
+ // Shaders
1487
+ ggml_vk_load_shaders(ctx);
1488
+
1489
+ if (!ctx->device->single_queue) {
1490
+ const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
1491
+ ggml_vk_create_queue(ctx, ctx->device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
1492
+ } else {
1493
+ // TODO: Use pointer or reference to avoid copy
1494
+ ctx->device->transfer_queue = ctx->device->compute_queue;
1495
+ }
1496
+
1497
+ ctx->device->idx = dev_num;
1498
+ ctx->device->initialized = true;
1499
+ } else if (ctx->device->idx != dev_num) {
1500
+ std::cerr << "ggml_vulkan: Device " << ctx->device->name << " already initialized with index " << ctx->device->idx << ", but trying to reinitialize with index " << dev_num << std::endl;
1501
+ throw std::runtime_error("Device already initialized");
1316
1502
  }
1317
1503
 
1318
- ctx->fence = ctx->device.lock()->device.createFence({});
1504
+ ctx->fence = ctx->device->device.createFence({});
1319
1505
 
1320
1506
  ctx->compute_ctx = nullptr;
1321
1507
  ctx->transfer_ctx = nullptr;
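The init hunk above is the core of the refactor: ctx->device changes from a weak_ptr that had to be lock()ed on every access to a shared_ptr obtained through ggml_vk_get_device, with an initialized flag so that a second backend context on the same GPU reuses the queues, shaders and pipelines instead of rebuilding them. A minimal sketch of that get-or-create pattern, with simplified types (device_state and get_device are illustrative names):

#include <memory>
#include <vector>

struct device_state {
    bool initialized = false;
    size_t idx = 0;
    // ... VkDevice, queues, pipelines, properties ...
};

// One shared slot per device index; later contexts receive the same object.
static std::vector<std::shared_ptr<device_state>> g_devices;

static std::shared_ptr<device_state> get_device(size_t idx) {
    if (g_devices.size() <= idx) {
        g_devices.resize(idx + 1);
    }
    if (!g_devices[idx]) {
        g_devices[idx] = std::make_shared<device_state>();
    }
    return g_devices[idx];
}

// Caller side, mirroring the diff's shape:
//   auto dev = get_device(idx);
//   if (!dev->initialized) { /* create device, queues, load shaders */ dev->initialized = true; }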
@@ -1333,7 +1519,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
1333
1519
  #endif
1334
1520
  }
1335
1521
 
1336
- static vk_pipeline* ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
1522
+ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
1337
1523
  #ifdef GGML_VULKAN_DEBUG
1338
1524
  std::cerr << "ggml_vk_get_to_fp16()" << std::endl;
1339
1525
  #endif
@@ -1354,10 +1540,36 @@ static vk_pipeline* ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
1354
1540
  return nullptr;
1355
1541
  }
1356
1542
 
1357
- return &ctx->pipeline_dequant[type];
1543
+ return ctx->device->pipeline_dequant[type];
1544
+ }
1545
+
1546
+ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
1547
+ #ifdef GGML_VULKAN_DEBUG
1548
+ std::cerr << "ggml_vk_get_mul_mat_mat_pipeline()" << std::endl;
1549
+ #endif
1550
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
1551
+ return ctx->device->pipeline_matmul_f32;
1552
+ }
1553
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
1554
+ return ctx->device->pipeline_matmul_f16_f32;
1555
+ }
1556
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
1557
+ return ctx->device->pipeline_matmul_f16;
1558
+ }
1559
+
1560
+ GGML_ASSERT(src1_type == GGML_TYPE_F32);
1561
+
1562
+ switch (src0_type) {
1563
+ case GGML_TYPE_Q4_0:
1564
+ break;
1565
+ default:
1566
+ return nullptr;
1567
+ }
1568
+
1569
+ return ctx->device->pipeline_dequant_mul_mat_mat[src0_type];
1358
1570
  }
1359
1571
 
1360
- static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type type) {
1572
+ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type type) {
1361
1573
  #ifdef GGML_VULKAN_DEBUG
1362
1574
  std::cerr << "ggml_vk_get_dequantize_mul_mat_vec()" << std::endl;
1363
1575
  #endif
@@ -1378,7 +1590,7 @@ static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
1378
1590
  return nullptr;
1379
1591
  }
1380
1592
 
1381
- return &ctx->pipeline_dequant_mul_mat_vec_f32[type];
1593
+ return ctx->device->pipeline_dequant_mul_mat_vec_f32[type];
1382
1594
  }
1383
1595
 
1384
1596
  static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
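ggml_vk_get_to_fp16, ggml_vk_get_mul_mat_mat_pipeline and ggml_vk_get_dequantize_mul_mat_vec now return the shared pipeline handles directly instead of raw pointers into the context, so an unsupported type shows up as an empty handle. A hedged caller-side sketch (types simplified, not the package's code):

#include <memory>
#include <stdexcept>

struct pipeline_struct { /* shader module, layouts, ... */ };
using pipeline = std::shared_ptr<pipeline_struct>;

// The selection helpers return an empty shared_ptr for unsupported type combinations,
// so callers check the handle itself before recording a dispatch.
static void dispatch_or_fail(const pipeline& p) {
    if (!p) {
        throw std::runtime_error("unsupported tensor type for this Vulkan shader");
    }
    // ... bind p->pipeline and record the dispatch ...
}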
@@ -1457,8 +1669,8 @@ static void * ggml_vk_host_malloc(ggml_backend_vk_context * ctx, size_t size) {
1457
1669
  if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
1458
1670
  fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
1459
1671
  size/1024.0/1024.0);
1460
- ctx->device.lock()->device.freeMemory(buf->device_memory);
1461
- ctx->device.lock()->device.destroyBuffer(buf->buffer);
1672
+ ctx->device->device.freeMemory(buf->device_memory);
1673
+ ctx->device->device.destroyBuffer(buf->buffer);
1462
1674
  return nullptr;
1463
1675
  }
1464
1676
 
@@ -1522,30 +1734,30 @@ static vk_submission ggml_vk_begin_submission(ggml_backend_vk_context * ctx, vk_
1522
1734
  }
1523
1735
 
1524
1736
  static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, std::vector<vk_subbuffer>&& buffers, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
1525
- const uint32_t wg0 = CEIL_DIV(elements[0], pipeline.wg_denoms[0]);
1526
- const uint32_t wg1 = CEIL_DIV(elements[1], pipeline.wg_denoms[1]);
1527
- const uint32_t wg2 = CEIL_DIV(elements[2], pipeline.wg_denoms[2]);
1737
+ const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
1738
+ const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
1739
+ const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
1528
1740
  #ifdef GGML_VULKAN_DEBUG
1529
- std::cerr << "ggml_vk_dispatch_pipeline(" << pipeline.name << ", (" << wg0 << "," << wg1 << "," << wg2 << "))" << std::endl;
1741
+ std::cerr << "ggml_vk_dispatch_pipeline(" << pipeline->name << ", (" << wg0 << "," << wg1 << "," << wg2 << "))" << std::endl;
1530
1742
  #endif
1531
1743
  std::vector<vk::DescriptorBufferInfo> descriptor_buffer_infos;
1532
1744
  std::vector<vk::WriteDescriptorSet> write_descriptor_sets;
1533
- GGML_ASSERT(pipeline.descriptor_set_idx < pipeline.descriptor_sets.size());
1534
- GGML_ASSERT(buffers.size() == pipeline.parameter_count);
1535
- vk::DescriptorSet& descriptor_set = pipeline.descriptor_sets[pipeline.descriptor_set_idx++];
1536
- for (uint32_t i = 0; i < pipeline.parameter_count; i++) {
1745
+ GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size());
1746
+ GGML_ASSERT(buffers.size() == pipeline->parameter_count);
1747
+ vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++];
1748
+ for (uint32_t i = 0; i < pipeline->parameter_count; i++) {
1537
1749
  descriptor_buffer_infos.push_back({buffers[i].buffer->buffer, buffers[i].offset, buffers[i].size});
1538
1750
  }
1539
- for (uint32_t i = 0; i < pipeline.parameter_count; i++) {
1751
+ for (uint32_t i = 0; i < pipeline->parameter_count; i++) {
1540
1752
  write_descriptor_sets.push_back({descriptor_set, i, 0, 1, vk::DescriptorType::eStorageBuffer, nullptr, &descriptor_buffer_infos[i]});
1541
1753
  }
1542
1754
 
1543
- ctx->device.lock()->device.updateDescriptorSets(write_descriptor_sets, {});
1755
+ ctx->device->device.updateDescriptorSets(write_descriptor_sets, {});
1544
1756
 
1545
- subctx->s->buffer.pushConstants(pipeline.layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
1546
- subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline.pipeline);
1757
+ subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
1758
+ subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
1547
1759
  subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
1548
- pipeline.layout,
1760
+ pipeline->layout,
1549
1761
  0,
1550
1762
  { descriptor_set },
1551
1763
  {});
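ggml_vk_dispatch_pipeline above derives the workgroup counts by rounding the element extents up against the pipeline's wg_denoms before updating descriptors and recording the dispatch. A small sketch of that arithmetic (ceil_div is assumed to match the CEIL_DIV macro used in the diff):

#include <array>
#include <cstdint>

// Integer ceil division, assumed equivalent to the CEIL_DIV macro in the diff.
static constexpr uint32_t ceil_div(uint32_t x, uint32_t y) { return (x + y - 1) / y; }

// Workgroup counts for a dispatch over `elements`, given the pipeline's workgroup denominators.
static std::array<uint32_t, 3> workgroup_counts(const std::array<uint32_t, 3>& elements,
                                                const std::array<uint32_t, 3>& wg_denoms) {
    return { ceil_div(elements[0], wg_denoms[0]),
             ceil_div(elements[1], wg_denoms[1]),
             ceil_div(elements[2], wg_denoms[2]) };
}
// e.g. elements {4096, 511, 1} with wg_denoms {128, 128, 1} -> {32, 4, 1} workgroups.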
@@ -1804,7 +2016,7 @@ static void ggml_vk_buffer_write_2d(ggml_backend_vk_context * ctx, vk_buffer& ds
1804
2016
  memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
1805
2017
  }
1806
2018
  } else {
1807
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
2019
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
1808
2020
  ggml_vk_ctx_begin(ctx, subctx);
1809
2021
  ggml_vk_buffer_write_2d_async(ctx, subctx, dst, offset, src, spitch, width, height, true);
1810
2022
  ggml_vk_ctx_end(subctx);
@@ -1814,8 +2026,9 @@ static void ggml_vk_buffer_write_2d(ggml_backend_vk_context * ctx, vk_buffer& ds
1814
2026
  }
1815
2027
 
1816
2028
  ggml_vk_submit(subctx, ctx->fence);
1817
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
1818
- ctx->device.lock()->device.resetFences({ ctx->fence });
2029
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
2030
+ ctx->device->device.resetFences({ ctx->fence });
2031
+ ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
1819
2032
  }
1820
2033
  }
1821
2034
 
@@ -1900,18 +2113,19 @@ static void ggml_vk_buffer_read(ggml_backend_vk_context * ctx, vk_buffer& src, s
1900
2113
 
1901
2114
  memcpy(dst, (uint8_t *) src->ptr + offset, size);
1902
2115
  } else {
1903
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
2116
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
1904
2117
  ggml_vk_ctx_begin(ctx, subctx);
1905
2118
  ggml_vk_buffer_read_async(ctx, subctx, src, offset, dst, size, true);
1906
2119
  ggml_vk_ctx_end(subctx);
1907
2120
 
1908
2121
  ggml_vk_submit(subctx, ctx->fence);
1909
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
1910
- ctx->device.lock()->device.resetFences({ ctx->fence });
2122
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
2123
+ ctx->device->device.resetFences({ ctx->fence });
1911
2124
 
1912
2125
  for (auto& cpy : subctx->out_memcpys) {
1913
2126
  memcpy(cpy.dst, cpy.src, cpy.n);
1914
2127
  }
2128
+ ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
1915
2129
  }
1916
2130
  }
1917
2131
 
@@ -1935,15 +2149,13 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
1935
2149
  // Copy within the device
1936
2150
  ggml_backend_vk_context * ctx = src->ctx;
1937
2151
 
1938
- VkBufferCopy bc{ src_offset, dst_offset, size };
1939
-
1940
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
2152
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
1941
2153
  ggml_vk_ctx_begin(ctx, subctx);
1942
2154
  ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
1943
2155
  ggml_vk_ctx_end(subctx);
1944
2156
  ggml_vk_submit(subctx, ctx->fence);
1945
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
1946
- ctx->device.lock()->device.resetFences({ ctx->fence });
2157
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
2158
+ ctx->device->device.resetFences({ ctx->fence });
1947
2159
  } else {
1948
2160
  #ifdef GGML_VULKAN_DEBUG
1949
2161
  std::cerr << "ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")" << std::endl;
@@ -1971,14 +2183,14 @@ static void ggml_vk_buffer_memset(ggml_backend_vk_context * ctx, vk_buffer& dst,
1971
2183
  // Make sure ctx owns the buffer
1972
2184
  GGML_ASSERT(dst->ctx == ctx);
1973
2185
 
1974
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
2186
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
1975
2187
  ggml_vk_ctx_begin(ctx, subctx);
1976
2188
  subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
1977
2189
  ggml_vk_ctx_end(subctx);
1978
2190
 
1979
2191
  ggml_vk_submit(subctx, ctx->fence);
1980
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_memset waitForFences");
1981
- ctx->device.lock()->device.resetFences({ ctx->fence });
2192
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_memset waitForFences");
2193
+ ctx->device->device.resetFences({ ctx->fence });
1982
2194
  }
1983
2195
 
1984
2196
  static void ggml_vk_h2d_tensor_2d(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& dst, size_t offset, const ggml_tensor * src, uint64_t i3, uint64_t i2, uint64_t i1) {
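The buffer write, read, copy and memset hunks above all synchronize the same way: record into a transient context on the transfer queue, submit against ctx->fence, block on waitForFences, reset the fence, and (on the read/write paths) recycle the queue's transient resources with ggml_vk_queue_cleanup. A compact sketch of that submit-and-wait shape (the callbacks stand in for ggml_vk_submit and ggml_vk_queue_cleanup):

#include <cstdint>
#include <functional>
#include <vulkan/vulkan.hpp>

// Synchronous submit helper: submit, block until the GPU signals the fence,
// reset the fence for reuse, then let the caller recycle staging resources.
static void submit_and_wait(vk::Device device, vk::Fence fence,
                            const std::function<void(vk::Fence)>& submit_fn,
                            const std::function<void()>& cleanup_fn) {
    submit_fn(fence);
    (void) device.waitForFences({ fence }, VK_TRUE, UINT64_MAX);
    device.resetFences({ fence });
    cleanup_fn();
}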
@@ -2039,176 +2251,63 @@ static void ggml_vk_d2h_tensor_2d(ggml_backend_vk_context * ctx, vk_context * su
2039
2251
 
2040
2252
  static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
2041
2253
  #ifdef GGML_VULKAN_DEBUG
2042
- std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")";
2254
+ std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")" << std::endl;
2043
2255
  #endif
2044
2256
  if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
2045
- #ifdef GGML_VULKAN_DEBUG
2046
- std::cerr << " = 4" << std::endl;
2047
- #endif
2048
2257
  return 4;
2049
2258
  }
2050
2259
 
2051
- #ifdef GGML_VULKAN_DEBUG
2052
- std::cerr << " = 1" << std::endl;
2053
- #endif
2054
2260
  return 1;
2055
2261
  }
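ggml_vk_guess_split_k above keeps the same heuristic but drops the per-branch debug prints: split the K dimension into 4 partial products only when K is large and the output tile is small, so enough workgroups stay busy. The logic, restated as a standalone function with a worked example:

#include <cstdint>

// Same condition as the diff: large K, small m x n output, and non-degenerate sizes.
static uint32_t guess_split_k(int m, int n, int k) {
    if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
        return 4;   // accumulate 4 partial results, reduced by pipeline_matmul_split_k_reduce
    }
    return 1;       // single pass
}
// e.g. guess_split_k(32, 32, 4096) == 4, guess_split_k(512, 512, 4096) == 1.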
2056
2262
 
2057
- static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, int m, int n) {
2058
- #ifdef GGML_VULKAN_DEBUG
2059
- std::cerr << "ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")" << std::endl;
2060
- #endif
2263
+ static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
2061
2264
  if (m <= 32 || n <= 32) {
2062
- return ctx->pipeline_matmul_f32_aligned_s.align;
2265
+ return aligned ? mmp->a_s : mmp->s;
2063
2266
  }
2064
- if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) {
2065
- return ctx->pipeline_matmul_f32_aligned_m.align;
2066
- }
2067
- return ctx->pipeline_matmul_f32_aligned_l.align;
2267
+ return aligned ? mmp->a_m : mmp->m;
2268
+
2269
+ GGML_UNUSED(ctx);
2068
2270
  }
2069
2271
 
2070
- static vk_pipeline* ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
2071
- if (bit16_x && bit16_y) {
2072
- if (m <= 32 || n <= 32) {
2073
- #ifdef GGML_VULKAN_DEBUG
2074
- std::cerr << " S" << std::endl;
2075
- #endif
2076
- return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
2077
- }
2078
- #ifdef GGML_VULKAN_DEBUG
2079
- std::cerr << " M" << std::endl;
2080
- #endif
2081
- return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
2082
- }
2083
- if (bit16_x && !bit16_y) {
2084
- if (m <= 32 || n <= 32) {
2085
- #ifdef GGML_VULKAN_DEBUG
2086
- std::cerr << " S" << std::endl;
2087
- #endif
2088
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
2089
- }
2090
- #ifdef GGML_VULKAN_DEBUG
2091
- std::cerr << " M" << std::endl;
2092
- #endif
2093
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
2094
- }
2095
- if (!bit16_x && bit16_y) {
2096
- GGML_ASSERT(false);
2097
- }
2272
+ static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
2273
+ return aligned ? mmp->a_m : mmp->m;
2098
2274
 
2099
- if (m <= 32 || n <= 32) {
2100
- #ifdef GGML_VULKAN_DEBUG
2101
- std::cerr << " S" << std::endl;
2102
- #endif
2103
- return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
2104
- }
2105
- #ifdef GGML_VULKAN_DEBUG
2106
- std::cerr << " M" << std::endl;
2107
- #endif
2108
- return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
2275
+ GGML_UNUSED(ctx);
2109
2276
  }
2110
2277
 
2111
- static vk_pipeline* ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
2112
- #ifdef GGML_VULKAN_DEBUG
2113
- std::cerr << " M" << std::endl;
2114
- #endif
2115
- if (bit16_x && bit16_y) {
2116
- return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
2117
- }
2118
- if (bit16_x && !bit16_y) {
2119
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
2120
- }
2121
- if (!bit16_x && bit16_y) {
2122
- GGML_ASSERT(false);
2123
- }
2124
- return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
2125
- }
2278
+ static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
2279
+ return aligned ? mmp->a_s : mmp->s;
2126
2280
 
2127
- static vk_pipeline* ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
2128
- #ifdef GGML_VULKAN_DEBUG
2129
- std::cerr << " S" << std::endl;
2130
- #endif
2131
- if (bit16_x && bit16_y) {
2132
- return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
2133
- }
2134
- if (bit16_x && !bit16_y) {
2135
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
2136
- }
2137
- if (!bit16_x && bit16_y) {
2138
- GGML_ASSERT(false);
2139
- }
2140
- return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
2281
+ GGML_UNUSED(ctx);
2141
2282
  }
2142
2283
 
2143
- static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
2284
+ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
2144
2285
  #ifdef GGML_VULKAN_DEBUG
2145
- std::cerr << "ggml_vk_guess_matmul_pipeline(" << bit16_x << ", " << bit16_y << ", " << m << ", " << n << ", " << aligned << ")";
2286
+ std::cerr << "ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")" << std::endl;
2146
2287
  #endif
2147
- switch (ctx->device.lock()->vendor_id) {
2288
+ switch (ctx->device->vendor_id) {
2148
2289
  case VK_VENDOR_ID_AMD:
2149
- return ggml_vk_guess_matmul_pipeline_amd(ctx, bit16_x, bit16_y, m, n, aligned);
2290
+ return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
2150
2291
  case VK_VENDOR_ID_APPLE:
2151
- return ggml_vk_guess_matmul_pipeline_apple(ctx, bit16_x, bit16_y, aligned);
2292
+ return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
2152
2293
  case VK_VENDOR_ID_INTEL:
2153
- return ggml_vk_guess_matmul_pipeline_intel(ctx, bit16_x, bit16_y, aligned);
2154
- }
2155
-
2156
- if (bit16_x && bit16_y) {
2157
- if (m <= 32 || n <= 32) {
2158
- #ifdef GGML_VULKAN_DEBUG
2159
- std::cerr << " S" << std::endl;
2160
- #endif
2161
- return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
2162
- }
2163
- if (m <= 64 || n <= 64) {
2164
- #ifdef GGML_VULKAN_DEBUG
2165
- std::cerr << " M" << std::endl;
2166
- #endif
2167
- return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
2168
- }
2169
- #ifdef GGML_VULKAN_DEBUG
2170
- std::cerr << " L" << std::endl;
2171
- #endif
2172
- return aligned ? &ctx->pipeline_matmul_f16_aligned_l : &ctx->pipeline_matmul_f16_l;
2173
- }
2174
- if (bit16_x && !bit16_y) {
2175
- if (m <= 32 || n <= 32) {
2176
- #ifdef GGML_VULKAN_DEBUG
2177
- std::cerr << " S" << std::endl;
2178
- #endif
2179
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
2180
- }
2181
- if (m <= 64 || n <= 64) {
2182
- #ifdef GGML_VULKAN_DEBUG
2183
- std::cerr << " M" << std::endl;
2184
- #endif
2185
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
2186
- }
2187
- #ifdef GGML_VULKAN_DEBUG
2188
- std::cerr << " L" << std::endl;
2189
- #endif
2190
- return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_l : &ctx->pipeline_matmul_f16_f32_l;
2191
- }
2192
- if (!bit16_x && bit16_y) {
2193
- GGML_ASSERT(false);
2294
+ return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
2194
2295
  }
2195
2296
 
2196
2297
  if (m <= 32 || n <= 32) {
2197
- #ifdef GGML_VULKAN_DEBUG
2198
- std::cerr << " S" << std::endl;
2199
- #endif
2200
- return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
2298
+ return aligned ? mmp->a_s : mmp->s;
2201
2299
  }
2202
2300
  if (m <= 64 || n <= 64) {
2203
- #ifdef GGML_VULKAN_DEBUG
2204
- std::cerr << " M" << std::endl;
2205
- #endif
2206
- return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
2301
+ return aligned ? mmp->a_m : mmp->m;
2207
2302
  }
2303
+ return aligned ? mmp->a_l : mmp->l;
2304
+ }
2305
+
2306
+ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
2208
2307
  #ifdef GGML_VULKAN_DEBUG
2209
- std::cerr << " L" << std::endl;
2308
+ std::cerr << "ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")" << std::endl;
2210
2309
  #endif
2211
- return aligned ? &ctx->pipeline_matmul_f32_aligned_l : &ctx->pipeline_matmul_f32_l;
2310
+ return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, false)->align;
2212
2311
  }
2213
2312
 
2214
2313
  static void ggml_vk_matmul(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d) {
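The hunks above replace the old per-type pipeline fields with a vk_matmul_pipeline bundle holding L/M/S tile variants plus their aligned counterparts, and the vendor-specific helpers (AMD, Apple, Intel) now just pick a fixed tier from that bundle. A minimal standalone sketch of the generic size-based selection, using simplified stand-in types rather than the backend's real pipeline handle:

    #include <array>
    #include <memory>

    // Simplified stand-ins for vk_pipeline / vk_matmul_pipeline (illustrative only).
    struct fake_pipeline { std::array<unsigned, 3> wg_denoms; unsigned align; };
    using pipeline_ref = std::shared_ptr<fake_pipeline>;

    struct matmul_bundle {
        pipeline_ref l, m, s;       // large / medium / small tile variants
        pipeline_ref a_l, a_m, a_s; // aligned variants (k already padded to `align`)
    };

    // Mirrors the generic (non-vendor-specific) selection above: small shapes take
    // the S tile, mid-size shapes the M tile, everything else the L tile.
    static pipeline_ref pick_matmul(const matmul_bundle& mmp, int m, int n, bool aligned) {
        if (m <= 32 || n <= 32) return aligned ? mmp.a_s : mmp.s;
        if (m <= 64 || n <= 64) return aligned ? mmp.a_m : mmp.m;
        return aligned ? mmp.a_l : mmp.l;
    }

    // The alignment guess is now simply the align field of the unaligned choice.
    static unsigned pick_align(const matmul_bundle& mmp, int m, int n) {
        return pick_matmul(mmp, m, n, /*aligned=*/false)->align;
    }

Because the helpers return the handle by value, the align lookup no longer needs to duplicate the size thresholds, which is exactly what the rewritten ggml_vk_guess_matmul_pipeline_align above does.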
@@ -2226,10 +2325,10 @@ static void ggml_vk_matmul(ggml_backend_vk_context * ctx, vk_context * subctx, v
2226
2325
 
2227
2326
  const std::array<uint32_t, 14> pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d };
2228
2327
  // Make sure enough workgroups get assigned for split k to work
2229
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1.size() * sizeof(uint32_t), pc1.data(), { (CEIL_DIV(m, pipeline.wg_denoms[0]) * pipeline.wg_denoms[0]) * split_k, n, batch });
2328
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1.size() * sizeof(uint32_t), pc1.data(), { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
2230
2329
  ggml_vk_sync_buffers(subctx);
2231
2330
  const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
2232
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
2331
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
2233
2332
  }
2234
2333
 
2235
2334
  static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
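The hunk above keeps the two-pass split-k scheme but reads the workgroup denominators through the pipeline handle: the first dispatch rounds m up to a multiple of the pipeline's workgroup denominator and multiplies the x extent by split_k, and the reduce pass then folds the partial results over all m*n*batch outputs. A small sketch of that arithmetic (the helper and struct names here are illustrative, not the backend's):

    #include <cstdint>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    // Dispatch extents for the two split-k passes, mirroring the values passed above.
    struct splitk_dispatch {
        uint32_t mm_x, mm_y, mm_z;   // first pass: split_k slices stacked along x
        uint32_t reduce_x;           // second pass: one entry per output element
    };

    static splitk_dispatch plan_splitk(uint32_t m, uint32_t n, uint32_t batch,
                                       uint32_t split_k, uint32_t wg_denom_m) {
        splitk_dispatch d{};
        // Round m up to the tile size so every k-slice gets whole workgroups.
        d.mm_x = ceil_div(m, wg_denom_m) * wg_denom_m * split_k;
        d.mm_y = n;
        d.mm_z = batch;
        // The reduce pass walks all m*n*batch outputs and sums split_k partials each.
        d.reduce_x = m * n * batch;
        return d;
    }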
@@ -2239,41 +2338,39 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
2239
2338
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
2240
2339
  }
2241
2340
 
2242
- static vk_pipeline * ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_type from, ggml_type to) {
2341
+ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_type from, ggml_type to) {
2243
2342
  if (from == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
2244
- return &ctx->pipeline_cpy_f32_f32;
2343
+ return ctx->device->pipeline_cpy_f32_f32;
2245
2344
  }
2246
2345
  if (from == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
2247
- return &ctx->pipeline_cpy_f32_f16;
2346
+ return ctx->device->pipeline_cpy_f32_f16;
2248
2347
  }
2249
2348
  if (from == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
2250
- return &ctx->pipeline_cpy_f16_f16;
2349
+ return ctx->device->pipeline_cpy_f16_f16;
2251
2350
  }
2252
2351
 
2253
2352
  std::cerr << "Missing CPY op for types: " << ggml_type_name(from) << " " << ggml_type_name(to) << std::endl;
2254
2353
  GGML_ASSERT(false);
2255
2354
  }
2256
2355
 
2257
- static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline * pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out, ggml_type buffer_type, bool aligned=true) {
2356
+ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) {
2258
2357
  #ifdef GGML_VULKAN_DEBUG
2259
2358
  std::cerr << "ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", backend=" << tensor->backend << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
2260
2359
  std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")" << std::endl;
2261
2360
  #endif
2262
2361
  const int tensor_type_size = ggml_type_size(tensor->type);
2263
- const int dst_type_size = ggml_type_size(buffer_type);
2264
2362
 
2265
- const uint32_t ne = tensor->ne[0] * tensor->ne[1] * tensor->ne[2];
2363
+ const uint32_t ne = ggml_nelements(tensor);
2266
2364
 
2267
- const uint32_t nb2 = aligned ? ggml_vk_align_size(dst_type_size * tensor->ne[0] * tensor->ne[1], ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size : tensor->ne[0] * tensor->ne[1];
2268
-
2269
- const vk_op_cpy_push_constants pc = {
2365
+ const vk_op_unary_push_constants pc = {
2270
2366
  (uint32_t)ne,
2271
- (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size,
2272
- (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], 1 , (uint32_t)tensor->ne[0] , nb2,
2367
+ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
2368
+ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
2273
2369
  0,
2370
+ 0.0f, 0.0f,
2274
2371
  };
2275
2372
  ggml_vk_sync_buffers(subctx);
2276
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { in, out }, sizeof(vk_op_cpy_push_constants), &pc, { ne, 1, 1 });
2373
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 });
2277
2374
  }
2278
2375
 
2279
2376
  static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
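In the hunk above, ggml_vk_cpy_to_contiguous drops its destination-type and alignment arguments and instead fills a vk_op_unary_push_constants block with the element count, the source's full 4-D shape and element strides, and a dense row-major destination layout. The struct definition is outside this excerpt, so the field names in the sketch below are assumptions; the ordering follows the initializer shown above:

    #include <cstdint>

    // Assumed field names; the order matches the initializer list above.
    struct unary_pc {
        uint32_t ne;
        uint32_t ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03; // source shape / element strides
        uint32_t ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13; // destination shape / packed strides
        uint32_t d_offset;
        float    param1, param2;
    };

    // Hypothetical helper: constants for copying a (possibly strided) 4-D tensor
    // into a dense row-major buffer of the same shape.
    template <typename Tensor> // needs int64_t ne[4] and byte strides nb[4], like ggml_tensor
    static unary_pc contiguous_copy_pc(const Tensor& t, uint32_t type_size) {
        unary_pc pc{};
        pc.ne   = (uint32_t)(t.ne[0] * t.ne[1] * t.ne[2] * t.ne[3]);
        pc.ne00 = (uint32_t)t.ne[0]; pc.ne01 = (uint32_t)t.ne[1];
        pc.ne02 = (uint32_t)t.ne[2]; pc.ne03 = (uint32_t)t.ne[3];
        pc.nb00 = (uint32_t)(t.nb[0] / type_size); pc.nb01 = (uint32_t)(t.nb[1] / type_size);
        pc.nb02 = (uint32_t)(t.nb[2] / type_size); pc.nb03 = (uint32_t)(t.nb[3] / type_size);
        pc.ne10 = pc.ne00; pc.ne11 = pc.ne01; pc.ne12 = pc.ne02; pc.ne13 = pc.ne03;
        pc.nb10 = 1;                        // destination is dense:
        pc.nb11 = pc.ne00;                  // strides 1, ne0, ne0*ne1, ne0*ne1*ne2
        pc.nb12 = pc.ne00 * pc.ne01;
        pc.nb13 = pc.ne00 * pc.ne01 * pc.ne02;
        return pc;                          // d_offset and the float params stay zero
    }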
@@ -2313,23 +2410,30 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2313
2410
  bool src0_uma = false;
2314
2411
  bool src1_uma = false;
2315
2412
 
2316
- if (ctx->device.lock()->uma) {
2413
+ if (ctx->device->uma) {
2317
2414
  ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
2318
2415
  ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
2319
2416
  src0_uma = d_Qx != nullptr;
2320
2417
  src1_uma = d_Qy != nullptr;
2321
2418
  }
2322
2419
 
2323
- const bool load_x = src0->backend != GGML_BACKEND_GPU && !src0_uma;
2324
- const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma;
2420
+ const bool load_x = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma;
2421
+ const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
2325
2422
 
2326
2423
  const bool x_non_contig = !load_x && !ggml_vk_dim01_contiguous(src0);
2327
2424
  const bool y_non_contig = !load_y && !ggml_vk_dim01_contiguous(src1);
2328
2425
 
2329
- const bool f16_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
2426
+ const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
2330
2427
 
2331
- const bool qx_needs_dequant = src0->type != GGML_TYPE_F16 || x_non_contig;
2332
- const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
2428
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
2429
+
2430
+ const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
2431
+ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
2432
+
2433
+ if (mmp == nullptr) {
2434
+ // Fall back to dequant + f16 mulmat
2435
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
2436
+ }
2333
2437
 
2334
2438
  // Not implemented
2335
2439
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
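The hunk above is where the new pipeline bundles change behaviour: ggml_vk_get_mul_mat_mat_pipeline is queried for the actual (src0, src1) type pair first, and only when no direct kernel exists does the code fall back to dequantizing src0 to F16 and running the F16 matmul. A compact sketch of that decision, with a stubbed lookup standing in for the device's per-type pipeline table (the stub's behaviour is illustrative only):

    #include <memory>

    enum class dtype { f32, f16, q4_0 };            // illustrative subset of ggml types

    struct matmul_bundle {};                         // stand-in for vk_matmul_pipeline_struct
    using matmul_ref = std::shared_ptr<matmul_bundle>;

    // Stub lookup: pretend only f16/f32 source types have direct matmul kernels.
    static matmul_ref lookup_mat_mat_pipeline(dtype a, dtype /*b*/) {
        return a == dtype::q4_0 ? nullptr : std::make_shared<matmul_bundle>();
    }

    struct matmul_plan {
        matmul_ref mmp;
        bool dequant_a;   // src0 must be dequantized / made contiguous first
        bool convert_b;   // src1 must be converted to F16 first
    };

    static matmul_plan plan_mul_mat(dtype a, dtype b, bool a_contig, bool b_contig) {
        const bool b_f32_kernel = (b == dtype::f32) && b_contig;
        matmul_plan p{};
        p.mmp       = lookup_mat_mat_pipeline(a, b_contig ? b : dtype::f16);
        p.dequant_a = (p.mmp == nullptr) || !a_contig;
        p.convert_b = (b != dtype::f16 && !b_f32_kernel) || !b_contig;
        if (p.mmp == nullptr) {
            // No direct kernel for this src0 type: dequantize to F16, use the F16 matmul.
            p.mmp = lookup_mat_mat_pipeline(dtype::f16, b_f32_kernel ? dtype::f32 : dtype::f16);
        }
        return p;
    }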
@@ -2338,17 +2442,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2338
2442
  const int y_ne = ne11 * ne10;
2339
2443
  const int d_ne = ne11 * ne01;
2340
2444
 
2341
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, ne01, ne11));
2445
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
2342
2446
  const bool aligned = ne10 == kpad;
2343
2447
 
2344
2448
  const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
2345
2449
 
2346
- vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(ctx, true, !f16_f32_kernel, ne01, ne11, aligned);
2450
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
2347
2451
 
2348
2452
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
2349
2453
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
2350
- const uint64_t x_sz = sizeof(ggml_fp16_t) * x_ne;
2351
- const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
2454
+ const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
2455
+ const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
2352
2456
  const uint64_t d_sz = sizeof(float) * d_ne;
2353
2457
 
2354
2458
  vk_buffer d_D = extra->buffer_gpu.lock();
@@ -2379,7 +2483,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2379
2483
  } else {
2380
2484
  d_X = d_Qx;
2381
2485
  x_buf_offset = qx_buf_offset;
2382
- GGML_ASSERT(qx_sz == x_sz); // NOLINT
2486
+ GGML_ASSERT(qx_sz == x_sz);
2383
2487
  }
2384
2488
  if (qy_needs_dequant) {
2385
2489
  d_Y = ctx->prealloc_y;
@@ -2390,8 +2494,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2390
2494
  GGML_ASSERT(qy_sz == y_sz);
2391
2495
  }
2392
2496
 
2393
- vk_pipeline * to_fp16_vk_0 = nullptr;
2394
- vk_pipeline * to_fp16_vk_1 = nullptr;
2497
+ vk_pipeline to_fp16_vk_0 = nullptr;
2498
+ vk_pipeline to_fp16_vk_1 = nullptr;
2395
2499
 
2396
2500
  if (x_non_contig) {
2397
2501
  to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
@@ -2407,19 +2511,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2407
2511
  GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
2408
2512
 
2409
2513
  // Allocate descriptor sets
2410
- ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, ne12 * ne13);
2514
+ ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1);
2411
2515
  if (qx_needs_dequant) {
2412
- ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_0, x_non_contig ? 1 : ne12 * ne13);
2516
+ ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
2413
2517
  }
2414
2518
  if (qy_needs_dequant) {
2415
- ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
2519
+ ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, 1);
2416
2520
  }
2417
2521
  if (split_k > 1) {
2418
- ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_matmul_split_k_reduce, ne12 * ne13);
2522
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1);
2419
2523
  }
2420
2524
 
2421
2525
  if (x_non_contig) {
2422
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }, dst->type, false);
2526
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
2423
2527
  } else if (load_x || qx_needs_dequant) {
2424
2528
  if (load_x) {
2425
2529
  // copy data to device
@@ -2428,13 +2532,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2428
2532
  }
2429
2533
 
2430
2534
  if (qx_needs_dequant) {
2431
- const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, (int)ne10 };
2535
+ const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
2432
2536
  ggml_vk_sync_buffers(subctx);
2433
- ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
2537
+ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
2434
2538
  }
2435
2539
  }
2436
2540
  if (y_non_contig) {
2437
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, dst->type);
2541
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
2438
2542
  } else if (load_y) {
2439
2543
  ggml_vk_h2d_tensor_2d(ctx, subctx, d_Qy, 0, src1, 0, 0, ggml_nrows(src1));
2440
2544
  }
@@ -2451,9 +2555,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
2451
2555
  }
2452
2556
 
2453
2557
  // compute
2454
- ggml_vk_matmul(ctx, subctx, *pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21); // NOLINT
2558
+ ggml_vk_matmul(ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21); // NOLINT
2455
2559
 
2456
- if (dst->backend == GGML_BACKEND_CPU) {
2560
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
2457
2561
  // copy dst to host
2458
2562
  float * d = (float *) ((char *) dst->data);
2459
2563
  ggml_vk_buffer_read_async(ctx, subctx, d_D, 0, d, sizeof(float) * d_ne * ne12 * ne13);
@@ -2499,15 +2603,15 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2499
2603
  bool src0_uma = false;
2500
2604
  bool src1_uma = false;
2501
2605
 
2502
- if (ctx->device.lock()->uma) {
2606
+ if (ctx->device->uma) {
2503
2607
  ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
2504
2608
  ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
2505
2609
  src0_uma = d_Qx != nullptr;
2506
2610
  src1_uma = d_Qy != nullptr;
2507
2611
  }
2508
2612
 
2509
- const bool load_x = src0->backend != GGML_BACKEND_GPU && !src0_uma;
2510
- const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma;
2613
+ const bool load_x = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma;
2614
+ const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
2511
2615
 
2512
2616
  const bool x_non_contig = !load_x && !ggml_vk_dim01_contiguous(src0);
2513
2617
  const bool y_non_contig = !load_y && !ggml_vk_dim01_contiguous(src1);
@@ -2521,9 +2625,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2521
2625
  const uint64_t y_ne = ne11 * ne10;
2522
2626
  const uint64_t d_ne = ne11 * ne01;
2523
2627
 
2524
- const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
2628
+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
2525
2629
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
2526
- const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
2630
+ const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
2527
2631
  const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
2528
2632
  const uint64_t d_sz = sizeof(float) * d_ne;
2529
2633
 
@@ -2563,8 +2667,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2563
2667
  GGML_ASSERT(qy_sz == y_sz);
2564
2668
  }
2565
2669
 
2566
- vk_pipeline * to_fp16_vk_0 = nullptr;
2567
- vk_pipeline* to_fp16_vk_1 = nullptr;
2670
+ vk_pipeline to_fp16_vk_0 = nullptr;
2671
+ vk_pipeline to_fp16_vk_1 = nullptr;
2568
2672
  if (x_non_contig) {
2569
2673
  to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
2570
2674
  }
@@ -2573,30 +2677,30 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2573
2677
  } else {
2574
2678
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
2575
2679
  }
2576
- vk_pipeline* dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type);
2680
+ vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type);
2577
2681
  GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
2578
2682
  GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
2579
2683
  GGML_ASSERT(dmmv != nullptr);
2580
2684
 
2581
2685
  // Allocate descriptor sets
2582
2686
  if (qx_needs_dequant) {
2583
- ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_0, 1);
2687
+ ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
2584
2688
  }
2585
2689
  if (qy_needs_dequant) {
2586
- ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
2690
+ ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
2587
2691
  }
2588
- ggml_pipeline_allocate_descriptor_sets(ctx, *dmmv, ne12 * ne13);
2692
+ ggml_pipeline_allocate_descriptor_sets(ctx, dmmv, ne12 * ne13);
2589
2693
 
2590
2694
  if (x_non_contig) {
2591
- GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment));
2592
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }, src0->type);
2695
+ GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
2696
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
2593
2697
  } else if (load_x) {
2594
2698
  // copy data to device
2595
2699
  ggml_vk_h2d_tensor_2d(ctx, subctx, d_Qx, 0, src0, 0, 0, ggml_nrows(src0));
2596
2700
  }
2597
2701
  if (y_non_contig) {
2598
2702
  GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
2599
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, src1->type);
2703
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
2600
2704
  } else if (load_y) {
2601
2705
  ggml_vk_h2d_tensor_2d(ctx, subctx, d_Qy, 0, src1, 0, 0, ggml_nrows(src1));
2602
2706
  }
@@ -2613,24 +2717,24 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
2613
2717
  const uint64_t y_offset = y_buf_offset + y_sz * it_idx1;
2614
2718
  const uint64_t d_offset = d_buf_offset + d_sz * it_idx1;
2615
2719
 
2616
- const uint64_t y_buffer_offset = (y_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2720
+ const uint64_t y_buffer_offset = (y_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2617
2721
  const uint64_t y_shader_offset = y_offset - y_buffer_offset;
2618
2722
 
2619
- const uint64_t d_buffer_offset = (d_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2723
+ const uint64_t d_buffer_offset = (d_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2620
2724
  const uint64_t d_shader_offset = d_offset - d_buffer_offset;
2621
2725
 
2622
2726
  if (!y_non_contig && qy_needs_dequant) {
2623
- const std::vector<int> pc = { (int)ne11, (int)ne10, (int)ne10, (int)ne10 };
2727
+ const std::vector<uint32_t> pc = { (uint32_t)ne11, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(y_ne / 32) };
2624
2728
  ggml_vk_sync_buffers(subctx);
2625
- ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)y_ne, 1, 1});
2729
+ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)y_ne, 1, 1});
2626
2730
  }
2627
2731
 
2628
2732
  // compute
2629
- const std::array<int, 3> pc = { (int)ne00, (int)(y_shader_offset / ggml_type_size(src1->type)), (int)(d_shader_offset / ggml_type_size(dst->type))};
2733
+ const std::array<uint32_t, 3> pc = { (uint32_t)ne00, (uint32_t)(y_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type))};
2630
2734
  ggml_vk_sync_buffers(subctx);
2631
- ggml_vk_dispatch_pipeline(ctx, subctx, *dmmv, { { d_X, x_offset, x_sz }, { d_Y, y_buffer_offset, y_sz + y_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 3 * sizeof(int), &pc, { (uint32_t)ne01, 1, 1});
2735
+ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_offset, x_sz }, { d_Y, y_buffer_offset, y_sz + y_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 3 * sizeof(int), &pc, { (uint32_t)ne01, 1, 1});
2632
2736
 
2633
- if (dst->backend == GGML_BACKEND_CPU) {
2737
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
2634
2738
  // copy dst to host
2635
2739
  float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
2636
2740
  ggml_vk_sync_buffers(subctx);
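The offset handling in the hunk above follows a pattern used throughout these kernels: the descriptor binds the buffer at the offset rounded down to minStorageBufferOffsetAlignment, and the remainder (y_shader_offset, d_shader_offset) is handed to the shader through the push constants, divided by the element size. A sketch of that split:

    #include <cstdint>

    struct bound_offset {
        uint64_t binding_offset;  // where the descriptor binds the buffer (aligned down)
        uint64_t shader_offset;   // remainder, passed to the shader via push constants
    };

    // Split a byte offset so the descriptor respects minStorageBufferOffsetAlignment
    // while the shader still indexes from the exact element the tensor starts at.
    static bound_offset split_offset(uint64_t byte_offset, uint64_t min_align) {
        bound_offset r{};
        r.binding_offset = (byte_offset / min_align) * min_align;
        r.shader_offset  = byte_offset - r.binding_offset;
        return r;
    }

In the dispatch above the remainder is also added to the bound range (y_sz + y_shader_offset, d_sz + d_shader_offset) so the shader can still address the full tensor past the rounded-down base.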
@@ -2647,7 +2751,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
2647
2751
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
2648
2752
  #endif
2649
2753
  GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
2650
- GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
2754
+ GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU);
2651
2755
  GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT
2652
2756
  GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT
2653
2757
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -2674,18 +2778,18 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
2674
2778
 
2675
2779
  bool src1_uma = false;
2676
2780
 
2677
- if (ctx->device.lock()->uma) {
2781
+ if (ctx->device->uma) {
2678
2782
  ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
2679
2783
  src1_uma = d_Qy != nullptr;
2680
2784
  }
2681
2785
 
2682
- const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma;
2786
+ const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
2683
2787
 
2684
2788
  const uint64_t x_ne = ne00 * ne01 * ne02;
2685
2789
  const uint64_t y_ne = ne10 * ne11 * ne12;
2686
2790
  const uint64_t d_ne = ne01 * ne11 * ne12;
2687
2791
 
2688
- const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
2792
+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
2689
2793
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
2690
2794
  const uint64_t d_sz = sizeof(float) * d_ne;
2691
2795
 
@@ -2704,12 +2808,12 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
2704
2808
  }
2705
2809
 
2706
2810
  // Allocate descriptor sets
2707
- ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, 1);
2811
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
2708
2812
 
2709
- const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2813
+ const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2710
2814
  const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
2711
2815
 
2712
- const uint64_t d_buffer_offset = (d_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2816
+ const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2713
2817
  const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
2714
2818
 
2715
2819
  if (load_y) {
@@ -2719,9 +2823,9 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
2719
2823
  // compute
2720
2824
  const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
2721
2825
  ggml_vk_sync_buffers(subctx);
2722
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
2826
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
2723
2827
 
2724
- if (dst->backend == GGML_BACKEND_CPU) {
2828
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
2725
2829
  // copy dst to host
2726
2830
  float * d = (float *) dst->data;
2727
2831
  ggml_vk_sync_buffers(subctx);
@@ -2738,7 +2842,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
2738
2842
  GGML_ASSERT(!ggml_is_transposed(src0));
2739
2843
  GGML_ASSERT(!ggml_is_transposed(src1));
2740
2844
  GGML_ASSERT(!ggml_is_permuted(src0));
2741
- GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
2845
+ GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU);
2742
2846
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
2743
2847
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
2744
2848
 
@@ -2766,12 +2870,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
2766
2870
 
2767
2871
  bool src1_uma = false;
2768
2872
 
2769
- if (ctx->device.lock()->uma) {
2873
+ if (ctx->device->uma) {
2770
2874
  ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
2771
2875
  src1_uma = d_Qy != nullptr;
2772
2876
  }
2773
2877
 
2774
- const bool load_y = src1->backend != GGML_BACKEND_GPU && !src1_uma;
2878
+ const bool load_y = src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
2775
2879
 
2776
2880
  const uint64_t d_ne = ne01 * ne11 * ne12;
2777
2881
 
@@ -2797,12 +2901,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
2797
2901
  }
2798
2902
 
2799
2903
  // Allocate descriptor sets
2800
- ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, 1);
2904
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
2801
2905
 
2802
- const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2906
+ const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2803
2907
  const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
2804
2908
 
2805
- const uint64_t d_buffer_offset = (d_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
2909
+ const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
2806
2910
  const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
2807
2911
 
2808
2912
  if (load_y) {
@@ -2812,9 +2916,9 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
2812
2916
  // compute
2813
2917
  const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
2814
2918
  ggml_vk_sync_buffers(subctx);
2815
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
2919
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
2816
2920
 
2817
- if (dst->backend == GGML_BACKEND_CPU) {
2921
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
2818
2922
  // copy dst to host
2819
2923
  float * d = (float *) dst->data;
2820
2924
  ggml_vk_sync_buffers(subctx);
@@ -2832,7 +2936,7 @@ static bool ggml_vk_can_mul_mat(const ggml_tensor * src0, const ggml_tensor * sr
2832
2936
  return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
2833
2937
  (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || ggml_is_quantized(src1->type)) &&
2834
2938
  dst->type == GGML_TYPE_F32 &&
2835
- ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU);
2939
+ ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_TYPE_GPU);
2836
2940
  }
2837
2941
 
2838
2942
  static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
@@ -2850,6 +2954,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx,
2850
2954
  }
2851
2955
  }
2852
2956
 
2957
+ // static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
2958
+ //
2959
+ // }
2960
+
2853
2961
  static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2854
2962
  // guaranteed to be an integer due to the check in ggml_can_repeat
2855
2963
  const uint64_t ne0 = dst->ne[0];
@@ -2880,8 +2988,8 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx
2880
2988
  // TODO: support for transposed / permuted tensors
2881
2989
  GGML_ASSERT(nb0 == sizeof(float));
2882
2990
  GGML_ASSERT(nb00 == sizeof(float));
2883
- GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
2884
- GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
2991
+ GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU);
2992
+ GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
2885
2993
 
2886
2994
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
2887
2995
  ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
@@ -2921,40 +3029,40 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx
2921
3029
  }
2922
3030
 
2923
3031
 
2924
- static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op) {
3032
+ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
2925
3033
  switch (op) {
2926
3034
  case GGML_OP_ADD:
2927
3035
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2928
- return &ctx->pipeline_add_f32;
3036
+ return ctx->device->pipeline_add_f32;
2929
3037
  }
2930
3038
  return nullptr;
2931
3039
  case GGML_OP_GET_ROWS:
2932
3040
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
2933
3041
  if (dst->type == GGML_TYPE_F16) {
2934
- return &ctx->pipeline_get_rows[src0->type];
3042
+ return ctx->device->pipeline_get_rows[src0->type];
2935
3043
  }
2936
3044
  if (dst->type == GGML_TYPE_F32) {
2937
- return &ctx->pipeline_get_rows_f32[src0->type];
3045
+ return ctx->device->pipeline_get_rows_f32[src0->type];
2938
3046
  }
2939
3047
  return nullptr;
2940
3048
  case GGML_OP_MUL:
2941
3049
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2942
- return &ctx->pipeline_mul_f32;
3050
+ return ctx->device->pipeline_mul_f32;
2943
3051
  }
2944
3052
  return nullptr;
2945
3053
  case GGML_OP_SCALE:
2946
3054
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2947
- return &ctx->pipeline_scale_f32;
3055
+ return ctx->device->pipeline_scale_f32;
2948
3056
  }
2949
3057
  return nullptr;
2950
3058
  case GGML_OP_SQR:
2951
3059
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2952
- return &ctx->pipeline_sqr_f32;
3060
+ return ctx->device->pipeline_sqr_f32;
2953
3061
  }
2954
3062
  return nullptr;
2955
3063
  case GGML_OP_CLAMP:
2956
3064
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2957
- return &ctx->pipeline_clamp_f32;
3065
+ return ctx->device->pipeline_clamp_f32;
2958
3066
  }
2959
3067
  return nullptr;
2960
3068
  case GGML_OP_CPY:
@@ -2963,29 +3071,29 @@ static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
2963
3071
  return ggml_vk_get_cpy_pipeline(ctx, src0->type, dst->type);
2964
3072
  case GGML_OP_NORM:
2965
3073
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2966
- return &ctx->pipeline_norm_f32;
3074
+ return ctx->device->pipeline_norm_f32;
2967
3075
  }
2968
3076
  return nullptr;
2969
3077
  case GGML_OP_RMS_NORM:
2970
3078
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2971
- return &ctx->pipeline_rms_norm_f32;
3079
+ return ctx->device->pipeline_rms_norm_f32;
2972
3080
  }
2973
3081
  return nullptr;
2974
3082
  case GGML_OP_UNARY:
2975
3083
  switch (ggml_get_unary_op(dst)) {
2976
3084
  case GGML_UNARY_OP_SILU:
2977
3085
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2978
- return &ctx->pipeline_silu_f32;
3086
+ return ctx->device->pipeline_silu_f32;
2979
3087
  }
2980
3088
  break;
2981
3089
  case GGML_UNARY_OP_GELU:
2982
3090
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2983
- return &ctx->pipeline_gelu_f32;
3091
+ return ctx->device->pipeline_gelu_f32;
2984
3092
  }
2985
3093
  break;
2986
3094
  case GGML_UNARY_OP_RELU:
2987
3095
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2988
- return &ctx->pipeline_relu_f32;
3096
+ return ctx->device->pipeline_relu_f32;
2989
3097
  }
2990
3098
  break;
2991
3099
  default:
@@ -2994,12 +3102,12 @@ static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
2994
3102
  return nullptr;
2995
3103
  case GGML_OP_DIAG_MASK_INF:
2996
3104
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2997
- return &ctx->pipeline_diag_mask_inf_f32;
3105
+ return ctx->device->pipeline_diag_mask_inf_f32;
2998
3106
  }
2999
3107
  return nullptr;
3000
3108
  case GGML_OP_SOFT_MAX:
3001
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3002
- return &ctx->pipeline_soft_max_f32;
3109
+ if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
3110
+ return ctx->device->pipeline_soft_max_f32;
3003
3111
  }
3004
3112
  return nullptr;
3005
3113
  case GGML_OP_ROPE:
@@ -3014,21 +3122,26 @@ static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
3014
3122
 
3015
3123
  if (is_neox) {
3016
3124
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3017
- return &ctx->pipeline_rope_neox_f32;
3125
+ return ctx->device->pipeline_rope_neox_f32;
3018
3126
  }
3019
3127
  if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
3020
- return &ctx->pipeline_rope_neox_f16;
3128
+ return ctx->device->pipeline_rope_neox_f16;
3021
3129
  }
3022
3130
  } else {
3023
3131
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3024
- return &ctx->pipeline_rope_f32;
3132
+ return ctx->device->pipeline_rope_f32;
3025
3133
  }
3026
3134
  if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
3027
- return &ctx->pipeline_rope_f16;
3135
+ return ctx->device->pipeline_rope_f16;
3028
3136
  }
3029
3137
  }
3030
3138
  return nullptr;
3031
3139
  }
3140
+ case GGML_OP_ARGSORT:
3141
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
3142
+ return ctx->device->pipeline_argsort_f32;
3143
+ }
3144
+ return nullptr;
3032
3145
  default:
3033
3146
  return nullptr;
3034
3147
  }
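The two hunks above extend ggml_vk_op_get_pipeline with an optional third source and a new GGML_OP_ARGSORT case: soft_max now accepts nullable src1 and src2 operands, each required to be F32 when present, and argsort maps F32 input to I32 output. A sketch of the nullable-operand check, with hypothetical enum and handle types standing in for ggml's:

    #include <memory>

    enum class type { f32, f16, i32 };

    struct op_pipeline {};
    using pipeline_ref = std::shared_ptr<op_pipeline>;

    struct tensor { type t; };

    // Nullable src1/src2: missing operands are accepted, but when present they
    // must be F32 for the f32 soft_max pipeline to apply.
    static pipeline_ref pick_softmax(const tensor* src0, const tensor* src1,
                                     const tensor* src2, const tensor* dst,
                                     const pipeline_ref& soft_max_f32) {
        const bool ok = src0->t == type::f32 &&
                        (src1 == nullptr || src1->t == type::f32) &&
                        (src2 == nullptr || src2->t == type::f32) &&
                        dst->t == type::f32;
        return ok ? soft_max_f32 : nullptr;  // nullptr signals an unsupported combination
    }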
@@ -3044,17 +3157,19 @@ static ggml_vk_func_t ggml_vk_op_get_func(ggml_op op) {
3044
3157
  }
3045
3158
 
3046
3159
  template<typename PC>
3047
- static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op, const PC&& pc) {
3160
+ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) {
3048
3161
  #ifdef GGML_VULKAN_DEBUG
3049
3162
  std::cerr << "ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
3050
3163
  if (src1 != nullptr) {
3051
3164
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
3052
3165
  }
3166
+ if (src2 != nullptr) {
3167
+ std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", backend=" << src2->backend << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
3168
+ }
3053
3169
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "), " << ggml_op_name(op) << ")" << std::endl;
3054
3170
  #endif
3055
3171
  GGML_ASSERT(!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type))); // NOLINT
3056
3172
  GGML_ASSERT(op == GGML_OP_CPY || ggml_vk_dim01_contiguous(src0)); // NOLINT
3057
- GGML_ASSERT(src1 == nullptr || ggml_vk_dim01_contiguous(src1)); // NOLINT
3058
3173
  GGML_ASSERT(dst->extra != nullptr);
3059
3174
  const uint64_t ne00 = src0->ne[0];
3060
3175
  const uint64_t ne01 = src0->ne[1];
@@ -3071,7 +3186,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3071
3186
  const uint64_t nb2 = dst->nb[2];
3072
3187
  const uint64_t nb3 = dst->nb[3];
3073
3188
 
3074
- vk_pipeline * pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, dst, op);
3189
+ const bool use_src2 = src2 != nullptr;
3190
+ const uint64_t ne2 = use_src2 ? src2->ne[0] * src2->ne[1] : 0;
3191
+
3192
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
3075
3193
  ggml_vk_func_t op_func;
3076
3194
 
3077
3195
  if (pipeline == nullptr) {
@@ -3092,40 +3210,50 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3092
3210
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
3093
3211
  ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
3094
3212
  ggml_tensor_extra_gpu * extra_src1 = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
3213
+ ggml_tensor_extra_gpu * extra_src2 = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
3095
3214
 
3096
3215
  vk_buffer d_X = nullptr;
3097
3216
  size_t x_buf_offset = 0;
3098
3217
  vk_buffer d_Y = nullptr;
3099
3218
  size_t y_buf_offset = 0;
3219
+ vk_buffer d_Z = nullptr;
3220
+ size_t z_buf_offset = 0;
3100
3221
 
3101
3222
  bool src0_uma = false;
3102
3223
  bool src1_uma = false;
3224
+ bool src2_uma = false;
3103
3225
 
3104
- if (ctx->device.lock()->uma) {
3226
+ if (ctx->device->uma) {
3105
3227
  ggml_vk_host_get(ctx, src0->data, d_X, x_buf_offset);
3106
3228
  src0_uma = d_X != nullptr;
3107
3229
  if (use_src1) {
3108
3230
  ggml_vk_host_get(ctx, src1->data, d_Y, y_buf_offset);
3109
3231
  src1_uma = d_Y != nullptr;
3110
3232
  }
3233
+ if (use_src2) {
3234
+ ggml_vk_host_get(ctx, src1->data, d_Z, z_buf_offset);
3235
+ src2_uma = d_Z != nullptr;
3236
+ }
3111
3237
  }
3112
3238
 
3113
- const bool transfer_src0 = src0->backend != GGML_BACKEND_GPU && !src0_uma;
3114
- const bool transfer_src1 = use_src1 && src1->backend != GGML_BACKEND_GPU && !src1_uma;
3239
+ const bool transfer_src0 = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma;
3240
+ const bool transfer_src1 = use_src1 && src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
3241
+ const bool transfer_src2 = use_src2 && src2->backend != GGML_BACKEND_TYPE_GPU && !src2_uma;
3115
3242
 
3116
- uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type) * ne0, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
3117
- uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) : 0;
3243
+ uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type) * ne0, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
3244
+ uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
3245
+ uint64_t z_sz = use_src2 ? ggml_vk_align_size(ggml_type_size(src2->type) * ne2, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
3118
3246
  uint64_t d_sz = ggml_type_size(dst->type) * ne0;
3119
3247
 
3120
3248
  vk_buffer d_D = extra->buffer_gpu.lock();
3121
3249
 
3122
3250
  // Workaround for tiny tensor inputs on ROPE
3123
- if (use_src1 && src1->backend == GGML_BACKEND_GPU && y_sz > d_D->size) {
3251
+ if (use_src1 && src1->backend == GGML_BACKEND_TYPE_GPU && y_sz > d_D->size) {
3124
3252
  y_sz = VK_WHOLE_SIZE;
3125
3253
  }
3126
3254
 
3127
3255
  GGML_ASSERT(d_D != nullptr);
3128
- uint64_t d_buf_offset = (extra->offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
3256
+ uint64_t d_buf_offset = (extra->offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
3129
3257
  GGML_ASSERT(d_buf_offset == extra->offset || op == GGML_OP_CPY); // NOLINT
3130
3258
  if (transfer_src0) {
3131
3259
  d_X = ctx->prealloc_qx;
@@ -3142,6 +3270,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3142
3270
  GGML_ASSERT(d_Y != nullptr);
3143
3271
  }
3144
3272
 
3273
+ GGML_ASSERT(!transfer_src2);
3274
+ if (use_src2 && !src2_uma) {
3275
+ d_Z = extra_src2->buffer_gpu.lock();
3276
+ z_buf_offset = extra_src2->offset;
3277
+ GGML_ASSERT(d_Z != nullptr);
3278
+ }
3279
+
3145
3280
  if (op == GGML_OP_CPY) {
3146
3281
  GGML_ASSERT(!transfer_src0);
3147
3282
  GGML_ASSERT(!transfer_src1);
@@ -3169,7 +3304,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3169
3304
 
3170
3305
  // Single call if dimension 2 is contiguous
3171
3306
  if (op == GGML_OP_CPY || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) {
3172
- ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, 1);
3307
+ ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1);
3173
3308
 
3174
3309
  switch (dst->op) {
3175
3310
  case GGML_OP_NORM:
@@ -3198,26 +3333,42 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3198
3333
  }
3199
3334
  }
3200
3335
 
3201
- if (!use_src1 && op == GGML_OP_SOFT_MAX) {
3202
- // Empty src1 is possible on soft_max, but the shader needs a buffer
3336
+ if (op == GGML_OP_SOFT_MAX) {
3337
+ // Empty src1 and src2 are possible on soft_max, but the shader needs buffers
3338
+ vk_subbuffer subbuf_y;
3339
+ if (use_src1) {
3340
+ subbuf_y = { d_Y, y_buf_offset, y_sz };
3341
+ } else {
3342
+ subbuf_y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
3343
+ }
3344
+
3345
+ vk_subbuffer subbuf_z;
3346
+ if (use_src2) {
3347
+ subbuf_z = { d_Z, z_buf_offset, z_sz };
3348
+ } else {
3349
+ subbuf_z = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
3350
+ }
3351
+
3203
3352
  ggml_vk_sync_buffers(subctx);
3204
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { ctx->prealloc_y, 0, ctx->prealloc_y->size }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3353
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3205
3354
  } else if (use_src1) {
3206
3355
  ggml_vk_sync_buffers(subctx);
3207
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3356
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3208
3357
  } else {
3209
3358
  ggml_vk_sync_buffers(subctx);
3210
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3359
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3211
3360
  }
3212
- if (dst->backend == GGML_BACKEND_CPU && op == GGML_OP_CPY) {
3361
+ if (dst->backend == GGML_BACKEND_TYPE_CPU && op == GGML_OP_CPY) {
3213
3362
  ggml_vk_d2h_tensor_2d(ctx, subctx, d_D, 0, dst);
3214
- } else if(dst->backend == GGML_BACKEND_CPU) {
3363
+ } else if(dst->backend == GGML_BACKEND_TYPE_CPU) {
3215
3364
  // copy dst to host
3216
3365
  float * d = (float *) dst->data;
3217
3366
  ggml_vk_buffer_read_async(ctx, subctx, d_D, 0, d, d_sz);
3218
3367
  }
3219
3368
  } else {
3220
- ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, ne02 * ne03);
3369
+ GGML_ASSERT(op != GGML_OP_SOFT_MAX);
3370
+
3371
+ ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, ne02 * ne03);
3221
3372
 
3222
3373
  switch (dst->op) {
3223
3374
  case GGML_OP_NORM:
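Because the soft_max dispatch in the hunk above always fills the src1 and src2 descriptor slots, it substitutes the preallocated scratch buffer (ctx->prealloc_y) whenever one of those operands is absent, so every binding is backed by a valid buffer. The same idea in isolation, with a simplified subbuffer type:

    #include <cstdint>
    #include <memory>

    struct buffer { uint64_t size; };
    using buffer_ref = std::shared_ptr<buffer>;

    struct subbuffer { buffer_ref buf; uint64_t offset; uint64_t range; };

    // Bind the real tensor buffer when the operand exists; otherwise point the
    // descriptor at a preallocated scratch buffer, never at a null binding.
    static subbuffer bind_or_placeholder(bool present, const buffer_ref& real,
                                         uint64_t offset, uint64_t range,
                                         const buffer_ref& scratch) {
        if (present) {
            return { real, offset, range };
        }
        return { scratch, 0, scratch->size };
    }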
@@ -3242,18 +3393,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3242
3393
  const uint32_t y_offset = y_sz * it_idx1;
3243
3394
  const uint32_t d_offset = d_sz * it_idx0;
3244
3395
 
3245
- if (!use_src1 && op == GGML_OP_SOFT_MAX) {
3246
- // Empty src1 is possible on soft_max, but the shader needs a buffer
3247
- ggml_vk_sync_buffers(subctx);
3248
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { ctx->prealloc_y, 0, ctx->prealloc_y->size }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
3249
- } else if (use_src1) {
3396
+ if (use_src1) {
3250
3397
  ggml_vk_sync_buffers(subctx);
3251
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_Y, y_buf_offset + y_offset, y_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
3398
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_Y, y_buf_offset + y_offset, y_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
3252
3399
  } else {
3253
3400
  ggml_vk_sync_buffers(subctx);
3254
- ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
3401
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
3255
3402
  }
3256
- if (dst->backend == GGML_BACKEND_CPU) {
3403
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
3257
3404
  // copy dst to host
3258
3405
  ggml_vk_buffer_read_async(ctx, subctx, d_D, d_buf_offset + d_offset, (char *) dst->data + i02*nb2 + i03*nb3, d_sz);
3259
3406
  }
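The hunk that follows switches ggml_vk_add and ggml_vk_mul from the old four-word push-constant block to vk_op_binary_push_constants, carrying the element count plus full 4-D shapes and element strides for src0, src1 and dst, which is what lets the shaders handle broadcasting and non-contiguous operands. The struct itself is outside this excerpt, so the layout below is inferred from the initializer order; a hedged helper that fills it:

    #include <cstdint>

    // Inferred from the initializers below: count, then ne/nb (in elements) for
    // src0, src1 and dst, a destination offset, and two float parameters.
    struct binary_pc {
        uint32_t ne;
        uint32_t a_ne[4], a_nb[4];
        uint32_t b_ne[4], b_nb[4];
        uint32_t d_ne[4], d_nb[4];
        uint32_t d_offset;
        float    param1, param2;
    };

    struct tensor_view {          // minimal stand-in for ggml_tensor's ne/nb fields
        int64_t  ne[4];
        uint64_t nb[4];           // byte strides
        uint32_t type_size;       // bytes per element (or per block for quantized types)
    };

    static void fill(uint32_t* ne_out, uint32_t* nb_out, const tensor_view& t) {
        for (int i = 0; i < 4; ++i) {
            ne_out[i] = (uint32_t)t.ne[i];
            nb_out[i] = (uint32_t)(t.nb[i] / t.type_size);  // byte stride -> element stride
        }
    }

    static binary_pc make_binary_pc(const tensor_view& a, const tensor_view& b, const tensor_view& d) {
        binary_pc pc{};
        pc.ne = (uint32_t)(a.ne[0] * a.ne[1] * a.ne[2] * a.ne[3]);
        fill(pc.a_ne, pc.a_nb, a);
        fill(pc.b_ne, pc.b_nb, b);
        fill(pc.d_ne, pc.d_nb, d);
        return pc;                // d_offset and the float params stay zero, as below
    }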
@@ -3263,69 +3410,141 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
3263
3410
  }
3264
3411
 
3265
3412
  static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3266
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3413
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3267
3414
  }
3268
3415
 
3269
3416
  static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3270
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_GET_ROWS, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3417
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3271
3418
  }
3272
3419
 
3273
3420
  static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3274
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ADD, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3421
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3422
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
3423
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3424
+
3425
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, {
3426
+ (uint32_t)ggml_nelements(src0),
3427
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3428
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
3429
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3430
+ 0,
3431
+ 0.0f, 0.0f,
3432
+ });
3275
3433
  }
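The ggml_vk_add call above (ggml_vk_mul just below follows the same pattern) now passes full 4-D shapes plus per-dimension strides expressed in elements, each nb[i] divided by the element size, for src0, src1 and dst. A minimal standalone sketch of that stride conversion, using a simplified stand-in for ggml_tensor (the toy_tensor type below is illustrative, not from the package source):

    // Illustrative sketch only; not part of the package source.
    // Shows how byte strides (nb) become element strides once divided by the
    // element size, which is what the vk_op_binary_push_constants carry.
    #include <cstdint>
    #include <cstdio>

    struct toy_tensor {          // hypothetical stand-in for ggml_tensor
        int64_t ne[4];           // extent of each dimension
        size_t  nb[4];           // stride of each dimension, in bytes
        size_t  type_size;       // bytes per element
    };

    int main() {
        // A contiguous 8 x 4 x 2 x 1 float32 tensor.
        toy_tensor t = { {8, 4, 2, 1}, {4, 32, 128, 256}, 4 };
        for (int i = 0; i < 4; ++i) {
            printf("dim %d: ne=%lld nb=%zu bytes -> %zu elements\n",
                   i, (long long) t.ne[i], t.nb[i], t.nb[i] / t.type_size);
        }
        return 0;
    }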
3276
3434
 
3277
3435
  static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3278
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_MUL, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
3436
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3437
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
3438
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3439
+
3440
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, {
3441
+ (uint32_t)ggml_nelements(src0),
3442
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3443
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
3444
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3445
+ 0,
3446
+ 0.0f, 0.0f,
3447
+ });
3279
3448
  }
3280
3449
 
3281
3450
  static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3282
3451
  float * op_params = (float *)dst->op_params;
3283
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_SCALE, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
3452
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3453
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3454
+
3455
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
3456
+ (uint32_t)ggml_nelements(src0),
3457
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3458
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3459
+ 0,
3460
+ op_params[0], 0.0f
3461
+ });
3284
3462
  }
3285
3463
 
3286
3464
  static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3287
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_SQR, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
3465
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3466
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3467
+
3468
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
3469
+ (uint32_t)ggml_nelements(src0),
3470
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3471
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3472
+ 0,
3473
+ 0.0f, 0.0f,
3474
+ });
3288
3475
  }
3289
3476
 
3290
3477
  static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3291
3478
  float * op_params = (float *)dst->op_params;
3292
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_CLAMP, { (uint32_t)ggml_nelements(src0), 0, op_params[0], op_params[1] });
3479
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3480
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3481
+
3482
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
3483
+ (uint32_t)ggml_nelements(src0),
3484
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3485
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3486
+ 0,
3487
+ op_params[0], op_params[1],
3488
+ });
3293
3489
  }
3294
3490
 
3295
3491
  static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3296
3492
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
3297
- const int src0_type_size = ggml_type_size(src0->type);
3298
- const int dst_type_size = ggml_type_size(dst->type);
3299
- const uint32_t d_offset = (extra->offset % ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
3300
- ggml_vk_op_f32<vk_op_cpy_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_CPY, {
3493
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
3494
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
3495
+ const uint32_t d_offset = (extra->offset % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
3496
+
3497
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
3301
3498
  (uint32_t)ggml_nelements(src0),
3302
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size,
3303
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size,
3499
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
3500
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
3304
3501
  d_offset,
3502
+ 0.0f, 0.0f,
3305
3503
  });
3306
3504
  }
3307
3505
 
3308
3506
  static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3309
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
3507
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
3310
3508
  }
3311
3509
 
3312
3510
  static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3313
3511
  float * op_params = (float *)dst->op_params;
3314
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
3512
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
3315
3513
  }
3316
3514
 
3317
3515
  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3318
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
3516
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
3319
3517
  }
3320
3518
 
3321
3519
  static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3322
3520
  int32_t * op_params = (int32_t *)dst->op_params;
3323
- ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
3521
+ ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
3324
3522
  }
3325
3523
 
3326
- static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3524
+ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
3327
3525
  float * op_params = (float *)dst->op_params;
3328
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_SOFT_MAX, { (uint32_t)src0->ne[0], (uint32_t)(src1 != nullptr ? ggml_nrows(src1) : 0), op_params[0], 0.0f });
3526
+
3527
+ float scale = op_params[0];
3528
+ float max_bias = op_params[1];
3529
+
3530
+ const uint32_t ncols = (uint32_t)src0->ne[0];
3531
+ const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
3532
+ const uint32_t nrows_y = (uint32_t)src0->ne[1];
3533
+
3534
+ const uint32_t n_head_kv = nrows_x/nrows_y;
3535
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
3536
+
3537
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
3538
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
3539
+
3540
+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
3541
+ ncols,
3542
+ nrows_y,
3543
+ src2 != nullptr ? (uint32_t)1 : (uint32_t)0,
3544
+ scale, max_bias,
3545
+ m0, m1,
3546
+ n_head_log2,
3547
+ });
3329
3548
  }
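The reworked soft_max call above forwards scale, max_bias, the derived slope bases m0/m1 and n_head_log2 to the shader, plus a flag for the optional third source tensor. A small self-contained sketch of that derivation, with example values standing in for the tensor shapes:

    // Illustrative sketch only; mirrors the m0/m1 computation in the hunk above.
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const float    max_bias = 8.0f;       // example; comes from dst->op_params[1]
        const uint32_t nrows_x  = 32 * 64;    // example: 32 heads of 64 rows each
        const uint32_t nrows_y  = 64;         // example: src0->ne[1]

        const uint32_t n_head_kv   = nrows_x / nrows_y;
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));

        // Slope bases: heads below n_head_log2 use powers of m0, the rest powers of m1.
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        printf("n_head_kv=%u n_head_log2=%u m0=%f m1=%f\n",
               n_head_kv, n_head_log2, m0, m1);
        return 0;
    }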
3330
3549
 
3331
3550
  static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -3351,15 +3570,20 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
3351
3570
  if (is_neox) {
3352
3571
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
3353
3572
  const float inv_ndims = -1.0f / n_dims;
3354
- ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f, theta_scale, inv_ndims });
3573
+ ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f, theta_scale, inv_ndims });
3355
3574
  } else {
3356
- ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f });
3575
+ ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f });
3357
3576
  }
3358
3577
  }
3359
3578
 
3579
+ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3580
+ int32_t * op_params = (int32_t *)dst->op_params;
3581
+ ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { (uint32_t)src0->ne[0], ((ggml_sort_order) op_params[0]) == GGML_SORT_ORDER_ASC });
3582
+ }
3583
+
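The new argsort path only sends the row width and an ascending/descending flag. A CPU reference of what GGML_OP_ARGSORT produces per row (indices that sort the values) may help read those two push constants; this is an illustration, not the shader implementation:

    // Illustrative CPU reference only; not the Vulkan shader.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    static std::vector<int32_t> argsort_row(const std::vector<float> & row, bool ascending) {
        std::vector<int32_t> idx(row.size());
        std::iota(idx.begin(), idx.end(), 0);   // 0, 1, 2, ...
        std::sort(idx.begin(), idx.end(), [&](int32_t a, int32_t b) {
            return ascending ? row[a] < row[b] : row[a] > row[b];
        });
        return idx;
    }

    int main() {
        const std::vector<float> row = { 0.3f, -1.0f, 2.5f, 0.0f };
        for (int32_t i : argsort_row(row, /*ascending=*/true)) {
            printf("%d ", i);                   // prints: 1 3 0 2
        }
        printf("\n");
        return 0;
    }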
3360
3584
  static void ggml_vk_nop(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
3361
3585
  // If backend is CPU, data from src0 has to be copied off the device
3362
- if (dst->backend == GGML_BACKEND_CPU) {
3586
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
3363
3587
  ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
3364
3588
  vk_buffer d_D = extra_src0->buffer_gpu.lock();
3365
3589
  ggml_vk_sync_buffers(subctx);
@@ -3408,43 +3632,43 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3408
3632
  const size_t y_ne = k * n * batch;
3409
3633
  const size_t d_ne = m * n * batch;
3410
3634
 
3411
- vk_pipeline * p;
3635
+ vk_pipeline p;
3412
3636
  std::string shname;
3413
3637
  if (shader_size == 0) {
3414
3638
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3415
- p = &ctx->pipeline_matmul_f32_aligned_s;
3639
+ p = ctx->device->pipeline_matmul_f32->a_s;
3416
3640
  shname = "F32_ALIGNED_S";
3417
3641
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3418
- p = &ctx->pipeline_matmul_f16_f32_aligned_s;
3642
+ p = ctx->device->pipeline_matmul_f16_f32->a_s;
3419
3643
  shname = "F16_F32_ALIGNED_S";
3420
3644
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3421
- p = &ctx->pipeline_matmul_f16_aligned_s;
3645
+ p = ctx->device->pipeline_matmul_f16->a_s;
3422
3646
  shname = "F16_ALIGNED_S";
3423
3647
  } else {
3424
3648
  GGML_ASSERT(false);
3425
3649
  }
3426
3650
  } else if (shader_size == 1) {
3427
3651
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3428
- p = &ctx->pipeline_matmul_f32_aligned_m;
3652
+ p = ctx->device->pipeline_matmul_f32->a_m;
3429
3653
  shname = "F32_ALIGNED_M";
3430
3654
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3431
- p = &ctx->pipeline_matmul_f16_f32_aligned_m;
3655
+ p = ctx->device->pipeline_matmul_f16_f32->a_m;
3432
3656
  shname = "F16_F32_ALIGNED_M";
3433
3657
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3434
- p = &ctx->pipeline_matmul_f16_aligned_m;
3658
+ p = ctx->device->pipeline_matmul_f16->a_m;
3435
3659
  shname = "F16_ALIGNED_M";
3436
3660
  } else {
3437
3661
  GGML_ASSERT(false);
3438
3662
  }
3439
3663
  } else if (shader_size == 2) {
3440
3664
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3441
- p = &ctx->pipeline_matmul_f32_aligned_l;
3665
+ p = ctx->device->pipeline_matmul_f32->a_l;
3442
3666
  shname = "F32_ALIGNED_L";
3443
3667
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3444
- p = &ctx->pipeline_matmul_f16_f32_aligned_l;
3668
+ p = ctx->device->pipeline_matmul_f16_f32->a_l;
3445
3669
  shname = "F16_F32_ALIGNED_L";
3446
3670
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3447
- p = &ctx->pipeline_matmul_f16_aligned_l;
3671
+ p = ctx->device->pipeline_matmul_f16->a_l;
3448
3672
  shname = "F16_ALIGNED_L";
3449
3673
  } else {
3450
3674
  GGML_ASSERT(false);
@@ -3458,43 +3682,43 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3458
3682
  if (k != kpad) {
3459
3683
  if (shader_size == 0) {
3460
3684
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3461
- p = &ctx->pipeline_matmul_f32_s;
3685
+ p = ctx->device->pipeline_matmul_f32->s;
3462
3686
  shname = "F32_S";
3463
3687
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3464
- p = &ctx->pipeline_matmul_f16_f32_s;
3688
+ p = ctx->device->pipeline_matmul_f16_f32->s;
3465
3689
  shname = "F16_F32_S";
3466
3690
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3467
- p = &ctx->pipeline_matmul_f16_s;
3691
+ p = ctx->device->pipeline_matmul_f16->s;
3468
3692
  shname = "F16_S";
3469
3693
  }
3470
3694
  } else if (shader_size == 1) {
3471
3695
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3472
- p = &ctx->pipeline_matmul_f32_m;
3696
+ p = ctx->device->pipeline_matmul_f32->m;
3473
3697
  shname = "F32_M";
3474
3698
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3475
- p = &ctx->pipeline_matmul_f16_f32_m;
3699
+ p = ctx->device->pipeline_matmul_f16_f32->m;
3476
3700
  shname = "F16_F32_M";
3477
3701
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3478
- p = &ctx->pipeline_matmul_f16_m;
3702
+ p = ctx->device->pipeline_matmul_f16->m;
3479
3703
  shname = "F16_M";
3480
3704
  }
3481
3705
  } else if (shader_size == 2) {
3482
3706
  if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3483
- p = &ctx->pipeline_matmul_f32_l;
3707
+ p = ctx->device->pipeline_matmul_f32->l;
3484
3708
  shname = "F32_L";
3485
3709
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
3486
- p = &ctx->pipeline_matmul_f16_f32_l;
3710
+ p = ctx->device->pipeline_matmul_f16_f32->l;
3487
3711
  shname = "F16_F32_L";
3488
3712
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
3489
- p = &ctx->pipeline_matmul_f16_l;
3713
+ p = ctx->device->pipeline_matmul_f16->l;
3490
3714
  shname = "F16_L";
3491
3715
  }
3492
3716
  }
3493
3717
  }
3494
3718
 
3495
- ggml_pipeline_allocate_descriptor_sets(ctx, *p, num_it);
3719
+ ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
3496
3720
  if (split_k > 1) {
3497
- ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_matmul_split_k_reduce, num_it);
3721
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
3498
3722
 
3499
3723
  if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
3500
3724
  // Resize buffer
@@ -3524,9 +3748,11 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3524
3748
  }
3525
3749
  for (size_t i = 0; i < y_ne; i++) {
3526
3750
  if (std::is_same<float, Y_TYPE>()) {
3527
- y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
3751
+ // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
3752
+ y[i] = (i % k == i / k) ? 1.0f : 0.0f;
3528
3753
  } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
3529
- y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
3754
+ // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
3755
+ y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
3530
3756
  } else {
3531
3757
  GGML_ASSERT(false);
3532
3758
  }
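The test inputs above switch from random data to the pattern (i % k == i / k), which appears intended to lay ones over the first k x k block of Y where the column index equals the row index, with zeros everywhere else, so mismatches in the product are easier to localize by eye. The pattern in isolation, for a small example size:

    // Illustrative only; prints the (i % k == i / k) pattern for a small k.
    #include <cstdio>

    int main() {
        const int k = 4, n = 6;                 // small example sizes
        for (int i = 0; i < k * n; ++i) {
            printf("%d%s", (i % k == i / k) ? 1 : 0,
                   ((i + 1) % k == 0) ? "\n" : " ");
        }
        return 0;
    }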
@@ -3535,17 +3761,17 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3535
3761
  ggml_vk_buffer_write(ctx, d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
3536
3762
  ggml_vk_buffer_write(ctx, d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
3537
3763
 
3538
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
3764
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
3539
3765
  for (size_t i = 0; i < num_it; i++) {
3540
3766
  ggml_vk_ctx_begin(ctx, subctx);
3541
- ggml_vk_matmul(ctx, subctx, *p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
3767
+ ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
3542
3768
  ggml_vk_ctx_end(subctx);
3543
3769
  }
3544
3770
 
3545
3771
  auto begin = std::chrono::high_resolution_clock::now();
3546
3772
  ggml_vk_submit(subctx, ctx->fence);
3547
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
3548
- ctx->device.lock()->device.resetFences({ ctx->fence });
3773
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
3774
+ ctx->device->device.resetFences({ ctx->fence });
3549
3775
 
3550
3776
  auto end = std::chrono::high_resolution_clock::now();
3551
3777
  double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
@@ -3624,6 +3850,8 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3624
3850
  std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
3625
3851
  std::cerr << "Actual result: " << std::endl << std::endl;
3626
3852
  ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
3853
+ std::cerr << std::endl;
3854
+ ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
3627
3855
  std::cerr << "Expected result: " << std::endl << std::endl;
3628
3856
  ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
3629
3857
 
@@ -3649,15 +3877,15 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
3649
3877
 
3650
3878
  free(d_chk);
3651
3879
 
3652
- ggml_vk_queue_cleanup(ctx, ctx->device.lock()->transfer_queue);
3653
- ggml_vk_queue_cleanup(ctx, ctx->device.lock()->compute_queue);
3880
+ ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
3881
+ ggml_vk_queue_cleanup(ctx, ctx->device->compute_queue);
3654
3882
 
3655
3883
  ggml_vk_destroy_buffer(d_X);
3656
3884
  ggml_vk_destroy_buffer(d_Y);
3657
3885
  ggml_vk_destroy_buffer(d_D);
3658
3886
 
3659
- ggml_pipeline_cleanup(*p);
3660
- ggml_pipeline_cleanup(ctx->pipeline_matmul_split_k_reduce);
3887
+ ggml_pipeline_cleanup(p);
3888
+ ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce);
3661
3889
 
3662
3890
  free(x);
3663
3891
  free(y);
@@ -3730,7 +3958,7 @@ static void ggml_vk_test_h2d_nc(ggml_backend_vk_context * ctx, size_t ne0, size_
3730
3958
  data[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
3731
3959
  }
3732
3960
 
3733
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
3961
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
3734
3962
  ggml_vk_ctx_begin(ctx, subctx);
3735
3963
 
3736
3964
  vk_buffer buffer = ggml_vk_create_buffer_check(ctx, ggml_nbytes(tensor), vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -3739,8 +3967,8 @@ static void ggml_vk_test_h2d_nc(ggml_backend_vk_context * ctx, size_t ne0, size_
3739
3967
 
3740
3968
  ggml_vk_ctx_end(subctx);
3741
3969
  ggml_vk_submit(subctx, ctx->fence);
3742
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_h2d_nc waitForFences");
3743
- ctx->device.lock()->device.resetFences({ ctx->fence });
3970
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_h2d_nc waitForFences");
3971
+ ctx->device->device.resetFences({ ctx->fence });
3744
3972
 
3745
3973
  ggml_vk_buffer_read(ctx, buffer, 0, result_data, ggml_nbytes(tensor));
3746
3974
 
@@ -3812,7 +4040,7 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
3812
4040
  x[i] = rand() / (float)RAND_MAX;
3813
4041
  }
3814
4042
 
3815
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
4043
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
3816
4044
  ggml_vk_ctx_begin(ctx, subctx);
3817
4045
 
3818
4046
  auto begin = std::chrono::high_resolution_clock::now();
@@ -3826,8 +4054,8 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
3826
4054
 
3827
4055
  ggml_vk_ctx_end(subctx);
3828
4056
  ggml_vk_submit(subctx, ctx->fence);
3829
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
3830
- ctx->device.lock()->device.resetFences({ ctx->fence });
4057
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
4058
+ ctx->device->device.resetFences({ ctx->fence });
3831
4059
 
3832
4060
  auto end = std::chrono::high_resolution_clock::now();
3833
4061
 
@@ -3841,8 +4069,8 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
3841
4069
 
3842
4070
  ggml_vk_ctx_end(subctx);
3843
4071
  ggml_vk_submit(subctx, ctx->fence);
3844
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
3845
- ctx->device.lock()->device.resetFences({ ctx->fence });
4072
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
4073
+ ctx->device->device.resetFences({ ctx->fence });
3846
4074
 
3847
4075
  for (auto& cpy : subctx->out_memcpys) {
3848
4076
  memcpy(cpy.dst, cpy.src, cpy.n);
@@ -3873,89 +4101,118 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
3873
4101
  }
3874
4102
  }
3875
4103
 
3876
- static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
3877
- #ifdef GGML_VULKAN_DEBUG
3878
- std::cerr << "ggml_vk_test_dequant(" << ne << ")" << std::endl;
3879
- #endif
3880
- const size_t x_sz = sizeof(float) * ne;
3881
- const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
3882
- const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
3883
- float * x = (float *) malloc(x_sz);
3884
- void * qx = malloc(qx_sz);
3885
- vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
3886
- vk_buffer x_buf = ggml_vk_create_buffer_check(ctx, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal);
3887
- ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
3888
-
3889
- for (size_t i = 0; i < ne; i++) {
3890
- x[i] = rand() / (float)RAND_MAX;
3891
- }
3892
-
4104
+ static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
3893
4105
  std::vector<int64_t> hist_cur(1 << 4, 0);
3894
4106
 
3895
- vk_pipeline& p = ctx->pipeline_dequant[quant];
3896
-
3897
4107
  switch(quant) {
4108
+ case GGML_TYPE_F32:
4109
+ memcpy(to, from, sizeof(float) * ne);
4110
+ break;
3898
4111
  case GGML_TYPE_Q4_0:
3899
- ggml_quantize_q4_0(x, qx, ne, ne, hist_cur.data());
4112
+ ggml_quantize_q4_0(from, to, ne, ne, hist_cur.data());
3900
4113
  break;
3901
4114
  case GGML_TYPE_Q4_1:
3902
- ggml_quantize_q4_1(x, qx, ne, ne, hist_cur.data());
4115
+ ggml_quantize_q4_1(from, to, ne, ne, hist_cur.data());
3903
4116
  break;
3904
4117
  case GGML_TYPE_Q5_0:
3905
- ggml_quantize_q5_0(x, qx, ne, ne, hist_cur.data());
4118
+ ggml_quantize_q5_0(from, to, ne, ne, hist_cur.data());
3906
4119
  break;
3907
4120
  case GGML_TYPE_Q5_1:
3908
- ggml_quantize_q4_1(x, qx, ne, ne, hist_cur.data());
4121
+ ggml_quantize_q5_1(from, to, ne, ne, hist_cur.data());
3909
4122
  break;
3910
4123
  case GGML_TYPE_Q8_0:
3911
- ggml_quantize_q8_0(x, qx, ne, ne, hist_cur.data());
4124
+ ggml_quantize_q8_0(from, to, ne, ne, hist_cur.data());
3912
4125
  break;
3913
4126
  case GGML_TYPE_Q2_K:
3914
- ggml_quantize_q2_K(x, qx, ne, ne, hist_cur.data());
4127
+ ggml_quantize_q2_K(from, to, ne, ne, hist_cur.data());
3915
4128
  break;
3916
4129
  case GGML_TYPE_Q3_K:
3917
- ggml_quantize_q3_K(x, qx, ne, ne, hist_cur.data());
4130
+ ggml_quantize_q3_K(from, to, ne, ne, hist_cur.data());
3918
4131
  break;
3919
4132
  case GGML_TYPE_Q4_K:
3920
- ggml_quantize_q4_K(x, qx, ne, ne, hist_cur.data());
4133
+ ggml_quantize_q4_K(from, to, ne, ne, hist_cur.data());
3921
4134
  break;
3922
4135
  case GGML_TYPE_Q5_K:
3923
- ggml_quantize_q5_K(x, qx, ne, ne, hist_cur.data());
4136
+ ggml_quantize_q5_K(from, to, ne, ne, hist_cur.data());
3924
4137
  break;
3925
4138
  case GGML_TYPE_Q6_K:
3926
- ggml_quantize_q6_K(x, qx, ne, ne, hist_cur.data());
4139
+ ggml_quantize_q6_K(from, to, ne, ne, hist_cur.data());
3927
4140
  break;
3928
4141
  default:
3929
4142
  GGML_ASSERT(false);
3930
4143
  }
4144
+ }
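On both sides of this refactor the quantized buffer is sized as ne * ggml_type_size(quant) / ggml_blck_size(quant), i.e. bytes per block times the number of blocks. A worked example with the standard Q4_0 layout, whose blocks hold 32 elements in 18 bytes (a 16-bit scale plus 16 bytes of packed 4-bit values):

    // Worked example only; the Q4_0 numbers are the standard ggml block layout.
    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t ne         = 7680;   // example element count, as in the tests below
        const size_t type_size  = 18;     // bytes per Q4_0 block
        const size_t block_size = 32;     // elements per Q4_0 block

        const size_t qx_sz = ne * type_size / block_size;
        printf("qx_sz = %zu bytes for %zu Q4_0 elements\n", qx_sz, ne);  // 4320
        return 0;
    }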
4145
+
4146
+ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
4147
+ #ifdef GGML_VULKAN_DEBUG
4148
+ std::cerr << "ggml_vk_test_dequant(" << ne << ")" << std::endl;
4149
+ #endif
4150
+ const size_t x_sz = sizeof(float) * ne;
4151
+ const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
4152
+ const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
4153
+ float * x = (float *) malloc(x_sz);
4154
+ void * qx = malloc(qx_sz);
4155
+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
4156
+ vk_buffer x_buf = ggml_vk_create_buffer_check(ctx, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal);
4157
+ ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
4158
+
4159
+ for (size_t i = 0; i < ne; i++) {
4160
+ x[i] = rand() / (float)RAND_MAX;
4161
+ }
4162
+
4163
+ vk_pipeline p = ctx->device->pipeline_dequant[quant];
4164
+
4165
+ ggml_vk_quantize_data(x, qx, ne, quant);
3931
4166
 
3932
4167
  ggml_pipeline_allocate_descriptor_sets(ctx, p, 1);
3933
4168
 
3934
4169
  ggml_vk_buffer_write(ctx, qx_buf, 0, qx, qx_sz);
3935
4170
 
3936
- vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
4171
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
3937
4172
  ggml_vk_ctx_begin(ctx, subctx);
3938
- const std::vector<int> pc = { 1, (int)ne, (int)ne, (int)ne };
4173
+ const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
3939
4174
  ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
3940
4175
  ggml_vk_ctx_end(subctx);
3941
4176
 
3942
4177
  auto begin = std::chrono::high_resolution_clock::now();
3943
4178
 
3944
4179
  ggml_vk_submit(subctx, ctx->fence);
3945
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
3946
- ctx->device.lock()->device.resetFences({ ctx->fence });
4180
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
4181
+ ctx->device->device.resetFences({ ctx->fence });
3947
4182
 
3948
4183
  auto end = std::chrono::high_resolution_clock::now();
3949
4184
 
3950
4185
  double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
3951
4186
  ggml_vk_buffer_read(ctx, x_buf, 0, x_chk, x_sz_f16);
3952
4187
 
4188
+ int first_err = -1;
4189
+
3953
4190
  double avg_err = 0.0;
3954
4191
  for (size_t i = 0; i < ne; i++) {
3955
- avg_err += std::fabs(x[i] - ggml_fp16_to_fp32(x_chk[i]));
4192
+ double error = std::fabs(x[i] - ggml_fp16_to_fp32(x_chk[i]));
4193
+ avg_err += error;
4194
+
4195
+ if (first_err < 0 && error > 0.05) {
4196
+ first_err = i;
4197
+ }
3956
4198
  }
3957
4199
 
3958
- std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err / ne << std::endl;
4200
+ avg_err /= ne;
4201
+
4202
+ std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;
4203
+
4204
+ if (avg_err > 0.1) {
4205
+ std::cerr << "first_error = " << first_err << std::endl;
4206
+ std::cerr << "Actual result: " << std::endl << std::endl;
4207
+ for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
4208
+ std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
4209
+ }
4210
+ std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
4211
+ for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
4212
+ std::cerr << x[i] << ", ";
4213
+ }
4214
+ std::cerr << std::endl;
4215
+ }
3959
4216
 
3960
4217
  ggml_vk_destroy_buffer(x_buf);
3961
4218
  ggml_vk_destroy_buffer(qx_buf);
@@ -3964,6 +4221,190 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
3964
4221
  free(qx);
3965
4222
  free(x_chk);
3966
4223
  }
4224
+
4225
+ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
4226
+ #ifdef GGML_VULKAN_DEBUG
4227
+ std::cerr << "ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")" << std::endl;
4228
+ #endif
4229
+ const size_t x_ne = m * k * batch;
4230
+ const size_t y_ne = k * n * batch;
4231
+ const size_t d_ne = m * n * batch;
4232
+
4233
+ vk_pipeline p;
4234
+ std::string shname;
4235
+ if (shader_size == 0) {
4236
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_s;
4237
+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
4238
+ } else if (shader_size == 1) {
4239
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_m;
4240
+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
4241
+ } else if (shader_size == 2) {
4242
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_l;
4243
+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
4244
+ } else {
4245
+ GGML_ASSERT(0);
4246
+ }
4247
+
4248
+ const size_t kpad = ggml_vk_align_size(k, p->align);
4249
+
4250
+ if (k != kpad) {
4251
+ if (shader_size == 0) {
4252
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->s;
4253
+ shname = std::string(ggml_type_name(quant)) + "_S";
4254
+ } else if (shader_size == 1) {
4255
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->m;
4256
+ shname = std::string(ggml_type_name(quant)) + "_M";
4257
+ } else if (shader_size == 2) {
4258
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->l;
4259
+ shname = std::string(ggml_type_name(quant)) + "_L";
4260
+ } else {
4261
+ GGML_ASSERT(0);
4262
+ }
4263
+ }
4264
+
4265
+ const size_t x_sz = sizeof(float) * x_ne;
4266
+ const size_t y_sz = sizeof(float) * y_ne;
4267
+ const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
4268
+ const size_t d_sz = sizeof(float) * d_ne;
4269
+ float * x = (float *) malloc(x_sz);
4270
+ float * y = (float *) malloc(y_sz);
4271
+ void * qx = malloc(qx_sz);
4272
+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
4273
+ vk_buffer y_buf = ggml_vk_create_buffer_check(ctx, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
4274
+ vk_buffer d_buf = ggml_vk_create_buffer_check(ctx, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
4275
+ float * d = (float *) malloc(d_sz);
4276
+ float * d_chk = (float *) malloc(d_sz);
4277
+
4278
+ for (size_t i = 0; i < x_ne; i++) {
4279
+ x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
4280
+ }
4281
+
4282
+ ggml_vk_quantize_data(x, qx, x_ne, quant);
4283
+
4284
+ for (size_t i = 0; i < y_ne; i++) {
4285
+ // y[i] = rand() / (float)RAND_MAX;
4286
+ y[i] = (i % k == i / k) ? 1.0f : 0.0f;
4287
+ }
4288
+
4289
+ ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
4290
+ if (split_k > 1) {
4291
+ ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
4292
+
4293
+ if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
4294
+ // Resize buffer
4295
+ if (ctx->prealloc_split_k != nullptr) {
4296
+ ggml_vk_destroy_buffer(ctx->prealloc_split_k);
4297
+ }
4298
+ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
4299
+ }
4300
+ }
4301
+
4302
+ ggml_vk_buffer_write(ctx, qx_buf, 0, qx, qx_sz);
4303
+ ggml_vk_buffer_write(ctx, y_buf, 0, y, y_sz);
4304
+
4305
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
4306
+ for (size_t i = 0; i < num_it; i++) {
4307
+ ggml_vk_ctx_begin(ctx, subctx);
4308
+ ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
4309
+ ggml_vk_ctx_end(subctx);
4310
+ }
4311
+
4312
+ auto begin = std::chrono::high_resolution_clock::now();
4313
+
4314
+ ggml_vk_submit(subctx, ctx->fence);
4315
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
4316
+ ctx->device->device.resetFences({ ctx->fence });
4317
+
4318
+ auto end = std::chrono::high_resolution_clock::now();
4319
+
4320
+ double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
4321
+ ggml_vk_buffer_read(ctx, d_buf, 0, d, d_sz);
4322
+
4323
+ ggml_init_params iparams = {
4324
+ /*.mem_size =*/ 1024*1024*1024,
4325
+ /*.mem_buffer =*/ NULL,
4326
+ /*.no_alloc =*/ true,
4327
+ };
4328
+
4329
+ ggml_context * ggml_ctx = ggml_init(iparams);
4330
+
4331
+ ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
4332
+ ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
4333
+ ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
4334
+
4335
+ src0_ggml->data = qx;
4336
+ src1_ggml->data = y;
4337
+ tensor_ggml->data = d_chk;
4338
+
4339
+ ctx->disable = true;
4340
+
4341
+ ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
4342
+ ggml_build_forward_expand(cgraph, tensor_ggml);
4343
+
4344
+ ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
4345
+
4346
+ ctx->disable = false;
4347
+
4348
+ ggml_free(ggml_ctx);
4349
+
4350
+ double avg_err = 0.0;
4351
+ int first_err_n = -1;
4352
+ int first_err_m = -1;
4353
+ int first_err_b = -1;
4354
+
4355
+ for (size_t i = 0; i < m*n*batch; i++) {
4356
+ double err = std::fabs(d[i] - d_chk[i]);
4357
+ avg_err += err;
4358
+
4359
+ if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
4360
+ first_err_b = i / (m * n);
4361
+ first_err_n = (i % (m * n)) / m;
4362
+ first_err_m = (i % (m * n)) % m;
4363
+ }
4364
+ }
4365
+
4366
+ avg_err /= m * n;
4367
+
4368
+ std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl;
4369
+
4370
+ if (avg_err > 0.1 || std::isnan(avg_err)) {
4371
+ std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
4372
+ std::cerr << "Actual result: " << std::endl << std::endl;
4373
+ ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4374
+ std::cerr << std::endl;
4375
+ std::cerr << "Expected result: " << std::endl << std::endl;
4376
+ ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4377
+
4378
+ if (split_k > 1) {
4379
+ float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
4380
+ ggml_vk_buffer_read(ctx, ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
4381
+
4382
+ std::cerr << "d_buf0: " << std::endl << std::endl;
4383
+ ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4384
+
4385
+ std::cerr << "d_buf1: " << std::endl << std::endl;
4386
+ ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4387
+
4388
+ std::cerr << "d_buf2: " << std::endl << std::endl;
4389
+ ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4390
+
4391
+ std::cerr << "d_buf3: " << std::endl << std::endl;
4392
+ ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4393
+
4394
+ free(split_k_buf);
4395
+ }
4396
+ }
4397
+
4398
+ ggml_vk_destroy_buffer(qx_buf);
4399
+ ggml_vk_destroy_buffer(y_buf);
4400
+ ggml_vk_destroy_buffer(d_buf);
4401
+
4402
+ free(x);
4403
+ free(qx);
4404
+ free(y);
4405
+ free(d);
4406
+ free(d_chk);
4407
+ }
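The mismatch report in the new test decomposes a flat index into first_err_b, first_err_n and first_err_m via i / (m*n), (i % (m*n)) / m and (i % (m*n)) % m, which assumes the m dimension varies fastest in the d buffer. A tiny round-trip check of that decomposition, with sizes matching the calls below:

    // Illustrative round-trip check of the index decomposition used above.
    #include <cstdio>

    int main() {
        const int m = 128, n = 512;              // matches the test sizes below
        const int b = 1, col_n = 37, row_m = 5;  // arbitrary example position
        const int i = b * (m * n) + col_n * m + row_m;

        printf("b=%d n=%d m=%d\n",
               i / (m * n), (i % (m * n)) / m, (i % (m * n)) % m);  // b=1 n=37 m=5
        return 0;
    }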
3967
4408
  #endif
3968
4409
 
3969
4410
  static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor) {
@@ -3976,29 +4417,19 @@ static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor)
3976
4417
  return extra;
3977
4418
  }
3978
4419
 
3979
- static ggml_tensor * ggml_vk_find_last_use(const ggml_tensor * node, ggml_cgraph * graph) {
3980
- GGML_ASSERT(node != nullptr);
3981
-
3982
- for (int i = graph->n_nodes - 1; i >= 0; i--) {
3983
- for (int j = 0; j < GGML_MAX_SRC; j++) {
3984
- if (graph->nodes[i]->src[j] == node) {
3985
- return graph->nodes[i];
3986
- }
3987
- }
3988
- }
3989
-
3990
- return nullptr;
4420
+ static bool ggml_vk_cpu_assist_op(const ggml_tensor * node) {
4421
+ return node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID;
3991
4422
  }
3992
4423
 
3993
4424
  static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){
3994
4425
  #ifdef GGML_VULKAN_DEBUG
3995
4426
  std::cerr << "ggml_vk_preallocate_buffers_graph(" << node << ")" << std::endl;
3996
4427
  #endif
3997
- const bool any_on_device = node->backend == GGML_BACKEND_GPU
3998
- || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
3999
- || (node->src[1] != nullptr && (node->src[1]->backend == GGML_BACKEND_GPU));
4428
+ const bool any_on_device = node->backend == GGML_BACKEND_TYPE_GPU
4429
+ || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_TYPE_GPU || node->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
4430
+ || (node->src[1] != nullptr && (node->src[1]->backend == GGML_BACKEND_TYPE_GPU));
4000
4431
 
4001
- if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT)) {
4432
+ if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node))) {
4002
4433
  return;
4003
4434
  }
4004
4435
 
@@ -4029,7 +4460,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4029
4460
  const bool f16_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32;
4030
4461
 
4031
4462
  int split_k;
4032
- if (node->op == GGML_OP_MUL_MAT) {
4463
+ if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
4033
4464
  split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
4034
4465
  } else {
4035
4466
  split_k = 1;
@@ -4038,11 +4469,11 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4038
4469
  const uint32_t y_ne = ne10 * ne11;
4039
4470
  const uint32_t d_ne = ne20 * ne21;
4040
4471
 
4041
- const uint64_t qx_sz = use_src0 ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
4042
- const uint64_t qy_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
4043
- const uint64_t x_sz = use_src0 ? ggml_vk_align_size(sizeof(ggml_fp16_t) * x_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
4044
- const uint64_t y_sz = use_src1 ? ggml_vk_align_size(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
4045
- uint64_t d_sz = ggml_vk_align_size(ggml_type_size(node->type) * d_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne22 * ne23;
4472
+ const uint64_t qx_sz = use_src0 ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
4473
+ const uint64_t qy_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
4474
+ const uint64_t x_sz = use_src0 ? ggml_vk_align_size(sizeof(ggml_fp16_t) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
4475
+ const uint64_t y_sz = use_src1 ? ggml_vk_align_size(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
4476
+ uint64_t d_sz = ggml_vk_align_size(ggml_type_size(node->type) * d_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne22 * ne23;
4046
4477
  const uint64_t split_k_size = split_k > 1 ? d_sz * 4 : 0;
4047
4478
 
4048
4479
  if (extra->buffer_gpu.expired()) {
@@ -4070,6 +4501,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4070
4501
  case GGML_OP_DIAG_MASK_INF:
4071
4502
  case GGML_OP_SOFT_MAX:
4072
4503
  case GGML_OP_ROPE:
4504
+ case GGML_OP_ARGSORT:
4073
4505
  break;
4074
4506
  case GGML_OP_UNARY:
4075
4507
  switch (ggml_get_unary_op(node)) {
@@ -4082,6 +4514,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
4082
4514
  }
4083
4515
  break;
4084
4516
  case GGML_OP_MUL_MAT:
4517
+ case GGML_OP_MUL_MAT_ID:
4085
4518
  if (ctx->prealloc_size_qx < qx_sz) {
4086
4519
  ctx->prealloc_size_qx = qx_sz;
4087
4520
  }
@@ -4115,21 +4548,66 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
4115
4548
  #endif
4116
4549
  #if defined(GGML_VULKAN_RUN_TESTS)
4117
4550
  ctx->staging = ggml_vk_create_buffer_check(ctx, 100ul * 1024ul * 1024ul,
4118
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached
4551
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
4119
4552
  vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
4120
4553
  ggml_vk_test_transfer(ctx, 8192 * 1000, false);
4121
4554
  ggml_vk_test_transfer(ctx, 8192 * 1000, true);
4122
4555
 
4123
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_0);
4124
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_1);
4125
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_0);
4126
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_1);
4127
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q8_0);
4128
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q2_K);
4129
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q3_K);
4130
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_K);
4131
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_K);
4132
- ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q6_K);
4556
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
4557
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
4558
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
4559
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
4560
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
4561
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
4562
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
4563
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
4564
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
4565
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
4566
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
4567
+
4568
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
4569
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
4570
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
4571
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
4572
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
4573
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
4574
+
4575
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
4576
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
4577
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
4578
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
4579
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
4580
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
4581
+
4582
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
4583
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
4584
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
4585
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
4586
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
4587
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
4588
+
4589
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
4590
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
4591
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
4592
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
4593
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
4594
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
4595
+
4596
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
4597
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
4598
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
4599
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
4600
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
4601
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
4602
+
4603
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
4604
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
4605
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
4606
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
4607
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
4608
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
4609
+
4610
+ std::cerr << std::endl;
4133
4611
 
4134
4612
  const std::vector<size_t> vals {
4135
4613
  8, 8, 8,
@@ -4215,11 +4693,11 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
4215
4693
  }
4216
4694
 
4217
4695
  static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, bool last_node){
4218
- const bool any_on_device = node->backend == GGML_BACKEND_GPU
4219
- || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
4220
- || (node->src[1] != nullptr && node->src[1]->backend == GGML_BACKEND_GPU);
4696
+ const bool any_on_device = node->backend == GGML_BACKEND_TYPE_GPU
4697
+ || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_TYPE_GPU || node->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
4698
+ || (node->src[1] != nullptr && node->src[1]->backend == GGML_BACKEND_TYPE_GPU);
4221
4699
 
4222
- if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT) || (node->op == GGML_OP_MUL_MAT && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
4700
+ if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node)) || (ggml_vk_cpu_assist_op(node) && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
4223
4701
  return;
4224
4702
  }
4225
4703
 
@@ -4231,6 +4709,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4231
4709
 
4232
4710
  const ggml_tensor * src0 = node->src[0];
4233
4711
  const ggml_tensor * src1 = node->src[1];
4712
+ const ggml_tensor * src2 = node->src[2];
4234
4713
 
4235
4714
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
4236
4715
 
@@ -4265,7 +4744,9 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4265
4744
  case GGML_OP_SOFT_MAX:
4266
4745
  case GGML_OP_ROPE:
4267
4746
  case GGML_OP_MUL_MAT:
4747
+ case GGML_OP_MUL_MAT_ID:
4268
4748
  case GGML_OP_NONE:
4749
+ case GGML_OP_ARGSORT:
4269
4750
  break;
4270
4751
  default:
4271
4752
  if (any_on_device) {
@@ -4276,7 +4757,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4276
4757
  }
4277
4758
 
4278
4759
  if (ctx->compute_ctx == nullptr) {
4279
- ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
4760
+ ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
4280
4761
  ggml_vk_ctx_begin(ctx, ctx->compute_ctx);
4281
4762
  }
4282
4763
 
@@ -4347,16 +4828,25 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4347
4828
 
4348
4829
  break;
4349
4830
  case GGML_OP_SOFT_MAX:
4350
- ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, node);
4831
+ ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, src2, node);
4351
4832
 
4352
4833
  break;
4353
4834
  case GGML_OP_ROPE:
4354
4835
  ggml_vk_rope(ctx, ctx->compute_ctx, src0, src1, node);
4355
4836
 
4837
+ break;
4838
+ case GGML_OP_ARGSORT:
4839
+ ggml_vk_argsort(ctx, ctx->compute_ctx, src0, node);
4356
4840
  break;
4357
4841
  case GGML_OP_MUL_MAT:
4358
4842
  ggml_vk_mul_mat(ctx, ctx->compute_ctx, src0, src1, node);
4359
4843
 
4844
+ break;
4845
+ case GGML_OP_MUL_MAT_ID:
4846
+ //ggml_vk_mul_mat_id(ctx, ctx->compute_ctx, src0, src1, node);
4847
+ std::cerr << "ggml_vulkan: GGML_OP_MUL_MAT_ID not implemented yet." << std::endl;
4848
+ GGML_ASSERT(false);
4849
+
4360
4850
  break;
4361
4851
  default:
4362
4852
  return;
@@ -4371,7 +4861,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4371
4861
  last_node = true;
4372
4862
  #endif
4373
4863
 
4374
- if (node->backend == GGML_BACKEND_CPU || last_node) {
4864
+ if (node->backend == GGML_BACKEND_TYPE_CPU || last_node) {
4375
4865
  ggml_vk_ctx_end(ctx->compute_ctx);
4376
4866
  ctx->compute_ctx->exit_tensor = node;
4377
4867
  ctx->compute_ctx = nullptr;
@@ -4379,11 +4869,11 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
4379
4869
  }
4380
4870
 
4381
4871
  static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor){
4382
- const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
4383
- || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
4384
- || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
4872
+ const bool any_on_device = tensor->backend == GGML_BACKEND_TYPE_GPU
4873
+ || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
4874
+ || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU);
4385
4875
 
4386
- if (ctx->disable || (!any_on_device && tensor->op != GGML_OP_MUL_MAT)) {
4876
+ if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(tensor))) {
4387
4877
  return false;
4388
4878
  }
4389
4879
 
@@ -4409,6 +4899,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
4409
4899
  case GGML_OP_PERMUTE:
4410
4900
  case GGML_OP_TRANSPOSE:
4411
4901
  case GGML_OP_NONE:
4902
+ case GGML_OP_ARGSORT:
4412
4903
  extra = (ggml_tensor_extra_gpu *) tensor->extra;
4413
4904
 
4414
4905
  break;
@@ -4424,6 +4915,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
4424
4915
  }
4425
4916
  break;
4426
4917
  case GGML_OP_MUL_MAT:
4918
+ case GGML_OP_MUL_MAT_ID:
4427
4919
  if (!any_on_device && !ggml_vk_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
4428
4920
  return false;
4429
4921
  }
@@ -4442,7 +4934,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
4442
4934
  if (params->ith != 0) {
4443
4935
  return true;
4444
4936
  }
4445
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
4937
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
4446
4938
  return true;
4447
4939
  }
4448
4940
 
@@ -4469,8 +4961,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
4469
4961
  }
4470
4962
 
4471
4963
  if (tensor == subctx.exit_tensor) {
4472
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
4473
- ctx->device.lock()->device.resetFences({ ctx->fence });
4964
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
4965
+ ctx->device->device.resetFences({ ctx->fence });
4474
4966
 
4475
4967
  // Do staging buffer copies
4476
4968
  for (auto& cpy : subctx.out_memcpys) {
@@ -4498,20 +4990,25 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
4498
4990
  }
4499
4991
  ctx->gc.temp_buffers.clear();
4500
4992
 
4501
- for (auto * pipeline : ctx->gc.pipelines) {
4502
- ggml_pipeline_cleanup(*pipeline);
4993
+ for (auto& pipeline : ctx->device->pipelines) {
4994
+ if (pipeline.expired()) {
4995
+ continue;
4996
+ }
4997
+
4998
+ vk_pipeline pl = pipeline.lock();
4999
+ ggml_pipeline_cleanup(pl);
4503
5000
  }
4504
5001
 
4505
- ggml_vk_queue_cleanup(ctx, ctx->device.lock()->compute_queue);
4506
- ggml_vk_queue_cleanup(ctx, ctx->device.lock()->transfer_queue);
5002
+ ggml_vk_queue_cleanup(ctx, ctx->device->compute_queue);
5003
+ ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
4507
5004
 
4508
5005
  for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
4509
- ctx->device.lock()->device.destroySemaphore({ ctx->gc.semaphores[i].s });
5006
+ ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
4510
5007
  }
4511
5008
  ctx->gc.semaphores.clear();
4512
5009
 
4513
5010
  for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) {
4514
- ctx->device.lock()->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
5011
+ ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
4515
5012
  }
4516
5013
  ctx->gc.tl_semaphores.clear();
4517
5014
  ctx->semaphore_idx = 0;
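Editorial note: pipeline cleanup now walks a registry of weak references on the device (ctx->device->pipelines) rather than a per-context list of raw pointers, skipping entries whose owner already released them. Below is a small self-contained sketch of that expired()/lock() walk; pipeline, pipeline_ref, and cleanup_registry are illustrative names, not the package's API.

```cpp
#include <memory>
#include <string>
#include <vector>

struct pipeline { std::string name; };
using pipeline_ref = std::weak_ptr<pipeline>;

void cleanup_registry(std::vector<pipeline_ref>& registry) {
    for (auto& ref : registry) {
        if (ref.expired()) {                       // already destroyed elsewhere, nothing to do
            continue;
        }
        std::shared_ptr<pipeline> pl = ref.lock(); // pin it for the duration of the cleanup call
        // the diffed code calls ggml_pipeline_cleanup(pl) here to recycle descriptor sets
        (void) pl;
    }
}
```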
@@ -4519,7 +5016,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
4519
5016
  ctx->event_idx = 0;
4520
5017
 
4521
5018
  for (auto& event : ctx->gc.events) {
4522
- ctx->device.lock()->device.resetEvent(event);
5019
+ ctx->device->device.resetEvent(event);
4523
5020
  }
4524
5021
 
4525
5022
  ctx->staging_offset = 0;
@@ -4556,21 +5053,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
4556
5053
  ctx->staging_size = 0;
4557
5054
 
4558
5055
  for (auto& event : ctx->gc.events) {
4559
- ctx->device.lock()->device.destroyEvent(event);
5056
+ ctx->device->device.destroyEvent(event);
4560
5057
  }
4561
5058
  ctx->gc.events.clear();
4562
5059
 
4563
- for (auto* pipeline : ctx->gc.pipelines) {
4564
- ggml_vk_destroy_pipeline(ctx, pipeline);
4565
- }
4566
- ctx->gc.pipelines.clear();
4567
-
4568
- ctx->device.lock()->device.destroyFence(ctx->fence);
4569
-
4570
- ctx->device.lock()->device.destroyCommandPool(ctx->device.lock()->compute_queue.pool);
4571
- if (!ctx->device.lock()->single_queue) {
4572
- ctx->device.lock()->device.destroyCommandPool(ctx->device.lock()->transfer_queue.pool);
4573
- }
5060
+ ctx->device->device.destroyFence(ctx->fence);
4574
5061
  }
4575
5062
 
4576
5063
  GGML_CALL static int ggml_vk_get_device_count() {
@@ -4745,7 +5232,7 @@ GGML_CALL static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t b
4745
5232
  extra->offset = (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
4746
5233
  }
4747
5234
 
4748
- tensor->backend = GGML_BACKEND_GPU;
5235
+ tensor->backend = GGML_BACKEND_TYPE_GPU;
4749
5236
  tensor->extra = extra;
4750
5237
  }
4751
5238
 
@@ -4753,7 +5240,7 @@ GGML_CALL static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t bu
4753
5240
  #ifdef GGML_VULKAN_DEBUG
4754
5241
  std::cerr << "ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")" << std::endl;
4755
5242
  #endif
4756
- GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
5243
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
4757
5244
 
4758
5245
  ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
4759
5246
 
@@ -4768,7 +5255,7 @@ GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t bu
4768
5255
  #ifdef GGML_VULKAN_DEBUG
4769
5256
  std::cerr << "ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")" << std::endl;
4770
5257
  #endif
4771
- GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
5258
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
4772
5259
 
4773
5260
  ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
4774
5261
 
@@ -4781,7 +5268,6 @@ GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t bu
4781
5268
 
4782
5269
  GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
4783
5270
  if (ggml_backend_buffer_is_vk(src->buffer)) {
4784
- ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
4785
5271
  ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
4786
5272
  ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
4787
5273
 
@@ -4793,6 +5279,8 @@ GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t bu
4793
5279
  return true;
4794
5280
  }
4795
5281
  return false;
5282
+
5283
+ UNUSED(buffer);
4796
5284
  }
4797
5285
 
4798
5286
  GGML_CALL static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -4839,12 +5327,12 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(
4839
5327
 
4840
5328
  GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
4841
5329
  ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
4842
- return ctx->ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
5330
+ return ctx->ctx->device->properties.limits.minStorageBufferOffsetAlignment;
4843
5331
  }
4844
5332
 
4845
5333
  GGML_CALL static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
4846
5334
  ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
4847
- return ctx->ctx->device.lock()->max_memory_allocation_size;
5335
+ return ctx->ctx->device->max_memory_allocation_size;
4848
5336
  }
4849
5337
 
4850
5338
  GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
@@ -4930,7 +5418,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_bu
4930
5418
  }
4931
5419
 
4932
5420
  GGML_CALL static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
4933
- return vk_instance.contexts[0].device.lock()->properties.limits.minMemoryMapAlignment;
5421
+ return vk_instance.contexts[0].device->properties.limits.minMemoryMapAlignment;
4934
5422
 
4935
5423
  UNUSED(buft);
4936
5424
  }
@@ -4975,8 +5463,7 @@ GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend) {
4975
5463
 
4976
5464
  ggml_vk_cleanup(ctx);
4977
5465
 
4978
- // Release device
4979
- vk_instance.devices[ctx->idx].reset();
5466
+ ctx->device.reset();
4980
5467
  ctx->initialized = false;
4981
5468
 
4982
5469
  vk_instance.initialized[idx] = false;
@@ -4999,13 +5486,13 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g
4999
5486
  #endif
5000
5487
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
5001
5488
  GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_buffer_type(ctx->idx) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
5002
- GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
5489
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
5003
5490
 
5004
5491
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
5005
5492
 
5006
5493
  if (ctx->transfer_ctx == nullptr) {
5007
5494
  // Initialize new transfer context
5008
- ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
5495
+ ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
5009
5496
  ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
5010
5497
  }
5011
5498
 
@@ -5020,13 +5507,13 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c
5020
5507
  #endif
5021
5508
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
5022
5509
  GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_buffer_type(ctx->idx) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
5023
- GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
5510
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
5024
5511
 
5025
5512
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
5026
5513
 
5027
5514
  if (ctx->transfer_ctx == nullptr) {
5028
5515
  // Initialize new transfer context
5029
- ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
5516
+ ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
5030
5517
  ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
5031
5518
  }
5032
5519
 
@@ -5046,7 +5533,7 @@ GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, c
5046
5533
 
5047
5534
  if (ctx->transfer_ctx == nullptr) {
5048
5535
  // Initialize new transfer context
5049
- ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
5536
+ ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
5050
5537
  ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
5051
5538
  }
5052
5539
 
@@ -5076,8 +5563,8 @@ GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
5076
5563
  }
5077
5564
 
5078
5565
  ggml_vk_submit(ctx->transfer_ctx, ctx->fence);
5079
- VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
5080
- ctx->device.lock()->device.resetFences({ ctx->fence });
5566
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
5567
+ ctx->device->device.resetFences({ ctx->fence });
5081
5568
 
5082
5569
  for (auto& cpy : ctx->transfer_ctx->out_memcpys) {
5083
5570
  memcpy(cpy.dst, cpy.src, cpy.n);
@@ -5086,7 +5573,7 @@ GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
5086
5573
  ctx->transfer_ctx = nullptr;
5087
5574
  }
5088
5575
 
5089
- GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
5576
+ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
5090
5577
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
5091
5578
 
5092
5579
  for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -5097,7 +5584,7 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml
5097
5584
  int last_node = cgraph->n_nodes - 1;
5098
5585
 
5099
5586
  // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
5100
- while (last_node > 0 && cgraph->nodes[last_node]->backend != GGML_BACKEND_GPU) {
5587
+ while (last_node > 0 && cgraph->nodes[last_node]->backend != GGML_BACKEND_TYPE_GPU) {
5101
5588
  last_node -= 1;
5102
5589
  }
5103
5590
 
@@ -5106,7 +5593,7 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml
5106
5593
  }
5107
5594
 
5108
5595
  ggml_compute_params params = {};
5109
- params.type = GGML_TASK_COMPUTE;
5596
+ params.type = GGML_TASK_TYPE_COMPUTE;
5110
5597
  params.ith = 0;
5111
5598
  for (int i = 0; i < cgraph->n_nodes; i++) {
5112
5599
  ggml_tensor * node = cgraph->nodes[i];
@@ -5129,7 +5616,7 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml
5129
5616
 
5130
5617
  ggml_vk_graph_cleanup(ctx);
5131
5618
 
5132
- return true;
5619
+ return GGML_STATUS_SUCCESS;
5133
5620
 
5134
5621
  UNUSED(backend);
5135
5622
  }
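Editorial note: graph_compute now reports a ggml_status instead of a bare bool. A hedged sketch of that mapping, using a stand-in enum (ggml_status_stub) rather than ggml's real header:

```cpp
// Stand-in for ggml's status enum; the real values live in ggml's headers.
enum ggml_status_stub { STATUS_FAILED = -1, STATUS_SUCCESS = 0 };

// A boolean "did it work" result maps onto the richer status type like this.
ggml_status_stub to_status(bool ok) {
    return ok ? STATUS_SUCCESS : STATUS_FAILED;
}
```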
@@ -5147,6 +5634,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
5147
5634
  }
5148
5635
  break;
5149
5636
  case GGML_OP_MUL_MAT:
5637
+ case GGML_OP_MUL_MAT_ID:
5150
5638
  {
5151
5639
  struct ggml_tensor * a;
5152
5640
  struct ggml_tensor * b;
@@ -5220,6 +5708,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
5220
5708
  case GGML_OP_CONT:
5221
5709
  case GGML_OP_DIAG_MASK_INF:
5222
5710
  case GGML_OP_SOFT_MAX:
5711
+ case GGML_OP_ARGSORT:
5223
5712
  return true;
5224
5713
  default:
5225
5714
  return false;
@@ -5244,6 +5733,11 @@ static ggml_backend_i ggml_backend_vk_interface = {
5244
5733
  /* .supports_op = */ ggml_backend_vk_supports_op,
5245
5734
  };
5246
5735
 
5736
+ static ggml_guid_t ggml_backend_vk_guid() {
5737
+ static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
5738
+ return &guid;
5739
+ }
5740
+
5247
5741
  GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t idx) {
5248
5742
  if (vk_instance.initialized[idx]) {
5249
5743
  return vk_instance.backends[idx];
@@ -5262,6 +5756,7 @@ GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t idx) {
5262
5756
  vk_instance.initialized[idx] = true;
5263
5757
 
5264
5758
  ggml_backend_t vk_backend = new ggml_backend {
5759
+ /* .guid = */ ggml_backend_vk_guid(),
5265
5760
  /* .interface = */ ggml_backend_vk_interface,
5266
5761
  /* .context = */ &vk_instance.contexts[ctx->idx],
5267
5762
  };
@@ -5272,7 +5767,7 @@ GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t idx) {
5272
5767
  }
5273
5768
 
5274
5769
  GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend) {
5275
- return backend && backend->iface.get_name == ggml_backend_vk_name;
5770
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
5276
5771
  }
5277
5772
 
5278
5773
  GGML_CALL int ggml_backend_vk_get_device_count() {
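Editorial note: backend identity checks switch from comparing the get_name function pointer to comparing a 16-byte GUID (ggml_guid_matches against ggml_backend_vk_guid()). The sketch below reproduces the idea with stand-in types; guid_t, vk_backend_guid, guid_matches, and is_vk_backend are illustrative, and only the GUID bytes come from the hunk above.

```cpp
#include <array>
#include <cstdint>
#include <cstring>

using guid_t = std::array<std::uint8_t, 16>;

static const guid_t* vk_backend_guid() {
    // same bytes as the ggml_backend_vk_guid() added above
    static const guid_t guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02,
                                 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
    return &guid;
}

static bool guid_matches(const guid_t* a, const guid_t* b) {
    return std::memcmp(a->data(), b->data(), a->size()) == 0;
}

bool is_vk_backend(const guid_t* backend_guid) {
    // identity no longer depends on the backend's name callback
    return backend_guid != nullptr && guid_matches(backend_guid, vk_backend_guid());
}
```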
@@ -5410,13 +5905,14 @@ static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * d
5410
5905
  static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tensor * tensor, const char * name) {
5411
5906
  void * tensor_data = tensor->data;
5412
5907
 
5413
- if (tensor->backend == GGML_BACKEND_GPU) {
5908
+ if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
5414
5909
  const size_t tensor_size = ggml_nbytes(tensor);
5415
5910
  tensor_data = malloc(tensor_size);
5416
5911
 
5417
5912
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
5418
5913
 
5419
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, extra->offset, tensor_data, tensor_size);
5914
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
5915
+ ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset, tensor_data, tensor_size);
5420
5916
  }
5421
5917
 
5422
5918
  std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
@@ -5436,14 +5932,14 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso
5436
5932
  std::vector<const ggml_tensor *> done;
5437
5933
  ggml_vk_print_graph_origin(tensor, done);
5438
5934
 
5439
- if (tensor->backend == GGML_BACKEND_GPU) {
5935
+ if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
5440
5936
  free(tensor_data);
5441
5937
  }
5442
5938
  }
5443
5939
 
5444
5940
  static void ggml_vk_check_tensor(const std::string& name, const ggml_tensor * tensor) {
5445
5941
  return;
5446
- GGML_ASSERT(tensor->backend == GGML_BACKEND_CPU);
5942
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_CPU);
5447
5943
  if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) {
5448
5944
  return;
5449
5945
  }
@@ -5481,7 +5977,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5481
5977
  if (params->ith != 0) {
5482
5978
  return;
5483
5979
  }
5484
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
5980
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
5485
5981
  return;
5486
5982
  }
5487
5983
 
@@ -5492,6 +5988,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5492
5988
 
5493
5989
  ggml_tensor * src0 = tensor->src[0];
5494
5990
  ggml_tensor * src1 = tensor->src[1];
5991
+ ggml_tensor * src2 = tensor->src[2];
5495
5992
 
5496
5993
  struct ggml_init_params iparams = {
5497
5994
  /*.mem_size =*/ 1024*1024*1024,
@@ -5503,13 +6000,16 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5503
6000
 
5504
6001
  struct ggml_tensor * src0_clone = nullptr;
5505
6002
  struct ggml_tensor * src1_clone = nullptr;
6003
+ struct ggml_tensor * src2_clone = nullptr;
5506
6004
  struct ggml_tensor * tensor_clone = nullptr;
5507
6005
 
5508
6006
  size_t src0_size;
5509
6007
  size_t src1_size;
6008
+ size_t src2_size;
5510
6009
 
5511
6010
  void * src0_buffer;
5512
6011
  void * src1_buffer;
6012
+ void * src2_buffer;
5513
6013
 
5514
6014
  if (src0 != nullptr) {
5515
6015
  src0_clone = ggml_dup_tensor(ggml_ctx, src0);
@@ -5518,17 +6018,18 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5518
6018
 
5519
6019
  src0_buffer = malloc(src0_size);
5520
6020
  src0_clone->data = src0_buffer;
5521
- if (src0->backend == GGML_BACKEND_CPU) {
6021
+ if (src0->backend == GGML_BACKEND_TYPE_CPU) {
5522
6022
  memcpy(src0_clone->data, src0->data, src0_size);
5523
6023
  memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
5524
- } else if (src0->backend == GGML_BACKEND_GPU) {
6024
+ } else if (src0->backend == GGML_BACKEND_TYPE_GPU) {
5525
6025
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src0->extra;
6026
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
5526
6027
  uint64_t offset = extra->offset;
5527
6028
  if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
5528
6029
  for (int i3 = 0; i3 < src0->ne[3]; i3++) {
5529
6030
  for (int i2 = 0; i2 < src0->ne[2]; i2++) {
5530
6031
  const int idx = i3*src0->ne[2] + i2;
5531
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
6032
+ ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
5532
6033
  }
5533
6034
  }
5534
6035
 
@@ -5538,10 +6039,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5538
6039
  src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
5539
6040
  }
5540
6041
  } else {
5541
- if (offset + src0_size >= extra->buffer_gpu->size) {
5542
- src0_size = extra->buffer_gpu->size - offset;
6042
+ if (offset + src0_size >= buffer_gpu->size) {
6043
+ src0_size = buffer_gpu->size - offset;
5543
6044
  }
5544
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, offset, src0_clone->data, src0_size);
6045
+ ggml_vk_buffer_read(ctx, buffer_gpu, offset, src0_clone->data, src0_size);
5545
6046
  memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
5546
6047
  }
5547
6048
  } else {
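Editorial note: the debug-check paths now go through extra->buffer_gpu.lock() before reading, and keep clamping the read size to the end of the buffer. A compact sketch of that pattern with stand-in types (vk_buffer_stub and clamp_read_size are illustrative):

```cpp
#include <cstdint>
#include <memory>

struct vk_buffer_stub { std::uint64_t size; };   // stand-in for the real vk_buffer

// Returns how many bytes may safely be read starting at `offset`.
std::uint64_t clamp_read_size(const std::weak_ptr<vk_buffer_stub>& ref,
                              std::uint64_t offset, std::uint64_t requested) {
    std::shared_ptr<vk_buffer_stub> buf = ref.lock();
    if (!buf) {
        return 0;                            // buffer already released elsewhere
    }
    if (offset + requested >= buf->size) {
        return buf->size - offset;           // same clamp as the src0/src1/src2 branches
    }
    return requested;
}
```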
@@ -5561,17 +6062,18 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5561
6062
 
5562
6063
  src1_buffer = malloc(src1_size);
5563
6064
  src1_clone->data = src1_buffer;
5564
- if (src1->backend == GGML_BACKEND_CPU) {
6065
+ if (src1->backend == GGML_BACKEND_TYPE_CPU) {
5565
6066
  memcpy(src1_clone->data, src1->data, src1_size);
5566
6067
  memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
5567
- } else if (src1->backend == GGML_BACKEND_GPU) {
6068
+ } else if (src1->backend == GGML_BACKEND_TYPE_GPU) {
5568
6069
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src1->extra;
6070
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
5569
6071
  uint64_t offset = extra->offset;
5570
6072
  if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
5571
6073
  for (int i3 = 0; i3 < src1->ne[3]; i3++) {
5572
6074
  for (int i2 = 0; i2 < src1->ne[2]; i2++) {
5573
6075
  const int idx = i3*src1->ne[2] + i2;
5574
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
6076
+ ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
5575
6077
  }
5576
6078
  }
5577
6079
 
@@ -5581,10 +6083,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5581
6083
  src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
5582
6084
  }
5583
6085
  } else {
5584
- if (offset + src1_size >= extra->buffer_gpu->size) {
5585
- src1_size = extra->buffer_gpu->size - offset;
6086
+ if (offset + src1_size >= buffer_gpu->size) {
6087
+ src1_size = buffer_gpu->size - offset;
5586
6088
  }
5587
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, offset, src1_clone->data, src1_size);
6089
+ ggml_vk_buffer_read(ctx, buffer_gpu, offset, src1_clone->data, src1_size);
5588
6090
  memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
5589
6091
  }
5590
6092
  } else {
@@ -5613,6 +6115,66 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5613
6115
 
5614
6116
  ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src1", src1_clone);
5615
6117
  }
6118
+ if (src2 != nullptr) {
6119
+ src2_clone = ggml_dup_tensor(ggml_ctx, src2);
6120
+
6121
+ src2_size = ggml_nbytes(src2);
6122
+
6123
+ src2_buffer = malloc(src2_size);
6124
+ src2_clone->data = src2_buffer;
6125
+ if (src2->backend == GGML_BACKEND_TYPE_CPU) {
6126
+ memcpy(src2_clone->data, src2->data, src2_size);
6127
+ memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
6128
+ } else if (src2->backend == GGML_BACKEND_TYPE_GPU) {
6129
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src2->extra;
6130
+ vk_buffer buf = extra->buffer_gpu.lock();
6131
+ uint64_t offset = extra->offset;
6132
+ if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
6133
+ for (int i3 = 0; i3 < src2->ne[3]; i3++) {
6134
+ for (int i2 = 0; i2 < src2->ne[2]; i2++) {
6135
+ const int idx = i3*src2->ne[2] + i2;
6136
+ ggml_vk_buffer_read(ctx, buf, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
6137
+ }
6138
+ }
6139
+
6140
+ src2_clone->nb[0] = src2->nb[0];
6141
+ src2_clone->nb[1] = src2->nb[1];
6142
+ for (int i = 2; i < GGML_MAX_DIMS; i++) {
6143
+ src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
6144
+ }
6145
+ } else {
6146
+ if (offset + src2_size >= buf->size) {
6147
+ src2_size = buf->size - offset;
6148
+ }
6149
+ ggml_vk_buffer_read(ctx, buf, offset, src2_clone->data, src2_size);
6150
+ memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
6151
+ }
6152
+ } else {
6153
+ GGML_ASSERT(false);
6154
+ }
6155
+
6156
+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
6157
+ ggml_vk_print_tensor(ctx, src2, "src2");
6158
+ std::cerr << "TENSOR CHECK: " << ggml_op_name(src2_clone->op) << " (check " << check_counter << ")" << std::endl;
6159
+ std::cerr << "src2_clone=" << tensor << " src2_clone->backend: " << src2_clone->backend << " src2_clone->type: " << ggml_type_name(src2_clone->type) << " ne0=" << src2_clone->ne[0] << " nb0=" << src2_clone->nb[0] << " ne1=" << src2_clone->ne[1] << " nb1=" << src2_clone->nb[1] << " ne2=" << src2_clone->ne[2] << " nb2=" << src2_clone->nb[2] << " ne3=" << src2_clone->ne[3] << " nb3=" << src2_clone->nb[3] << std::endl;
6160
+ if (src2->src[0] != nullptr) {
6161
+ std::cerr << "src2->src[0]=" << src2->src[0] << " op=" << ggml_op_name(src2->src[0]->op) << " type=" << ggml_type_name(src2->src[0]->type) << " backend=" << src2->src[0]->backend << " ne0=" << src2->src[0]->ne[0] << " nb0=" << src2->src[0]->nb[0] << " ne1=" << src2->src[0]->ne[1] << " nb1=" << src2->src[0]->nb[1] << " ne2=" << src2->src[0]->ne[2] << " nb2=" << src2->src[0]->nb[2] << " ne3=" << src2->src[0]->ne[3] << " nb3=" << src2->src[0]->nb[3] << std::endl;
6162
+ }
6163
+ if (src2->src[1] != nullptr) {
6164
+ std::cerr << "src2->src[1]=" << src2->src[1] << " op=" << ggml_op_name(src2->src[1]->op) << " type=" << ggml_type_name(src2->src[1]->type) << " backend=" << src2->src[1]->backend << " ne0=" << src2->src[1]->ne[0] << " nb0=" << src2->src[1]->nb[0] << " ne1=" << src2->src[1]->ne[1] << " nb1=" << src2->src[1]->nb[1] << " ne2=" << src2->src[1]->ne[2] << " nb2=" << src2->src[1]->nb[2] << " ne3=" << src2->src[1]->ne[3] << " nb3=" << src2->src[1]->nb[3] << std::endl;
6165
+ }
6166
+ std::cerr << std::endl << "Result:" << std::endl;
6167
+ ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 0, 0);
6168
+ std::cerr << std::endl;
6169
+ std::cerr << std::endl << "Result:" << std::endl;
6170
+ ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 1, 0);
6171
+ std::cerr << std::endl;
6172
+ std::vector<const ggml_tensor *> done;
6173
+ ggml_vk_print_graph_origin(src2_clone, done);
6174
+ }
6175
+
6176
+ ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src2", src2_clone);
6177
+ }
5616
6178
 
5617
6179
  if (tensor->op == GGML_OP_MUL_MAT) {
5618
6180
  tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
@@ -5632,7 +6194,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5632
6194
  tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
5633
6195
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
5634
6196
  if (src1 != nullptr) {
5635
- tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, *(float *)tensor->op_params);
6197
+ tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
5636
6198
  } else {
5637
6199
  tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
5638
6200
  }
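Editorial note: the reference soft-max call above now forwards src2 and reads two floats (scale and max_bias) out of tensor->op_params instead of one. A small sketch of unpacking those two parameters; op_params_view and unpack_softmax_params are illustrative stand-ins for ggml's op_params storage.

```cpp
struct op_params_view { float raw[2]; };   // stand-in: ggml packs op params in a small buffer

void unpack_softmax_params(const op_params_view& p, float& scale, float& max_bias) {
    scale    = p.raw[0];   // ((float *)tensor->op_params)[0] in the hunk above
    max_bias = p.raw[1];   // ((float *)tensor->op_params)[1]
}
```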
@@ -5715,6 +6277,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
5715
6277
  if (src1 != nullptr) {
5716
6278
  free(src1_buffer);
5717
6279
  }
6280
+ if (src2 != nullptr) {
6281
+ free(src2_buffer);
6282
+ }
5718
6283
 
5719
6284
  ggml_free(ggml_ctx);
5720
6285
  }
@@ -5723,7 +6288,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_
5723
6288
  if (params->ith != 0) {
5724
6289
  return;
5725
6290
  }
5726
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
6291
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
5727
6292
  return;
5728
6293
  }
5729
6294
  if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
@@ -5735,17 +6300,18 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_
5735
6300
 
5736
6301
  void * tensor_data = tensor->data;
5737
6302
 
5738
- if (tensor->backend == GGML_BACKEND_GPU) {
6303
+ if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
5739
6304
  size_t tensor_size = ggml_nbytes(tensor);
5740
6305
  tensor_data = malloc(tensor_size);
5741
6306
 
5742
6307
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
5743
6308
 
5744
- if (extra->offset + tensor_size >= extra->buffer_gpu->size) {
5745
- tensor_size = extra->buffer_gpu->size - (extra->offset);
6309
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
6310
+ if (extra->offset + tensor_size >= buffer_gpu->size) {
6311
+ tensor_size = buffer_gpu->size - (extra->offset);
5746
6312
  }
5747
6313
 
5748
- ggml_vk_buffer_read(ctx, extra->buffer_gpu, extra->offset, tensor_data, tensor_size);
6314
+ ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset, tensor_data, tensor_size);
5749
6315
  }
5750
6316
 
5751
6317
  float first_error_result = -1.0f;
@@ -5868,7 +6434,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_
5868
6434
  comp_result = nullptr;
5869
6435
  comp_size = 0;
5870
6436
 
5871
- if (tensor->backend == GGML_BACKEND_GPU) {
6437
+ if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
5872
6438
  free(tensor_data);
5873
6439
  }
5874
6440
  }