@fugood/llama.node 1.2.1 → 1.2.3

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (30)
  1. package/package.json +14 -14
  2. package/src/llama.cpp/common/arg.cpp +359 -310
  3. package/src/llama.cpp/common/chat.cpp +27 -15
  4. package/src/llama.cpp/common/common.cpp +1 -0
  5. package/src/llama.cpp/common/sampling.cpp +1 -0
  6. package/src/llama.cpp/ggml/CMakeLists.txt +37 -21
  7. package/src/llama.cpp/ggml/include/ggml-backend.h +2 -1
  8. package/src/llama.cpp/ggml/include/ggml-zdnn.h +3 -0
  9. package/src/llama.cpp/ggml/src/CMakeLists.txt +3 -0
  10. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  11. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +2 -2
  12. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +14 -0
  13. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +17 -3
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +1 -1
  15. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +93 -862
  16. package/src/llama.cpp/include/llama.h +15 -11
  17. package/src/llama.cpp/src/llama-context.cpp +151 -0
  18. package/src/llama.cpp/src/llama-context.h +10 -0
  19. package/src/llama.cpp/src/llama-cparams.h +1 -1
  20. package/src/llama.cpp/src/llama-kv-cache-iswa.cpp +8 -0
  21. package/src/llama.cpp/src/llama-kv-cache-iswa.h +2 -0
  22. package/src/llama.cpp/src/llama-kv-cache.cpp +8 -0
  23. package/src/llama.cpp/src/llama-kv-cache.h +2 -0
  24. package/src/llama.cpp/src/llama-memory-hybrid.cpp +8 -0
  25. package/src/llama.cpp/src/llama-memory-hybrid.h +2 -0
  26. package/src/llama.cpp/src/llama-memory-recurrent.cpp +8 -0
  27. package/src/llama.cpp/src/llama-memory-recurrent.h +3 -0
  28. package/src/llama.cpp/src/llama-memory.h +3 -0
  29. package/src/llama.cpp/src/llama-model.cpp +14 -4
  30. package/src/llama.cpp/src/llama-model.h +5 -1
@@ -41,628 +41,15 @@ static void ggml_compute_forward_dup_same_cont(
41
41
  }
42
42
  }
43
43
 
44
- static void ggml_compute_forward_dup_f16(
45
- const ggml_compute_params * params,
46
- ggml_tensor * dst) {
47
-
48
- const ggml_tensor * src0 = dst->src[0];
49
-
50
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
51
-
52
- GGML_TENSOR_UNARY_OP_LOCALS
53
-
54
- const int ith = params->ith; // thread index
55
- const int nth = params->nth; // number of threads
56
-
57
- // parallelize by rows
58
- const int nr = ne01;
59
- // number of rows per thread
60
- const int dr = (nr + nth - 1) / nth;
61
- // row range for this thread
62
- const int ir0 = dr * ith;
63
- const int ir1 = MIN(ir0 + dr, nr);
64
-
65
- if (src0->type == dst->type &&
66
- ne00 == ne0 &&
67
- nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
68
- // copy by rows
69
- const size_t rs = ne00*nb00;
70
- for (int64_t i03 = 0; i03 < ne03; i03++) {
71
- for (int64_t i02 = 0; i02 < ne02; i02++) {
72
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
73
- memcpy(
74
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
75
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
76
- rs);
77
- }
78
- }
79
- }
80
- return;
81
- }
82
-
83
- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
84
-
85
- if (ggml_is_contiguous(dst)) {
86
- if (nb00 == sizeof(ggml_fp16_t)) {
87
- if (dst->type == GGML_TYPE_F16) {
88
- size_t id = 0;
89
- const size_t rs = ne00 * nb00;
90
- char * dst_ptr = (char *) dst->data;
91
-
92
- for (int i03 = 0; i03 < ne03; i03++) {
93
- for (int i02 = 0; i02 < ne02; i02++) {
94
- id += rs * ir0;
95
- for (int i01 = ir0; i01 < ir1; i01++) {
96
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
97
- memcpy(dst_ptr + id, src0_ptr, rs);
98
- id += rs;
99
- }
100
- id += rs * (ne01 - ir1);
101
- }
102
- }
103
- } else if (dst->type == GGML_TYPE_F32) {
104
- size_t id = 0;
105
- float * dst_ptr = (float *) dst->data;
106
-
107
- for (int i03 = 0; i03 < ne03; i03++) {
108
- for (int i02 = 0; i02 < ne02; i02++) {
109
- id += ne00 * ir0;
110
- for (int i01 = ir0; i01 < ir1; i01++) {
111
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
112
- for (int i00 = 0; i00 < ne00; i00++) {
113
- dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
114
- id++;
115
- }
116
- }
117
- id += ne00 * (ne01 - ir1);
118
- }
119
- }
120
- } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
121
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
122
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
123
-
124
- size_t id = 0;
125
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
126
- char * dst_ptr = (char *) dst->data;
127
-
128
- for (int i03 = 0; i03 < ne03; i03++) {
129
- for (int i02 = 0; i02 < ne02; i02++) {
130
- id += rs * ir0;
131
- for (int i01 = ir0; i01 < ir1; i01++) {
132
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
133
-
134
- for (int i00 = 0; i00 < ne00; i00++) {
135
- src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
136
- }
137
-
138
- quantize_row_q(src0_f32, dst_ptr + id, ne00);
139
- id += rs;
140
- }
141
- id += rs * (ne01 - ir1);
142
- }
143
- }
144
- } else {
145
- GGML_ABORT("fatal error"); // TODO: implement
146
- }
147
- } else {
148
- //printf("%s: this is not optimal - fix me\n", __func__);
149
-
150
- if (dst->type == GGML_TYPE_F32) {
151
- size_t id = 0;
152
- float * dst_ptr = (float *) dst->data;
153
-
154
- for (int i03 = 0; i03 < ne03; i03++) {
155
- for (int i02 = 0; i02 < ne02; i02++) {
156
- id += ne00 * ir0;
157
- for (int i01 = ir0; i01 < ir1; i01++) {
158
- for (int i00 = 0; i00 < ne00; i00++) {
159
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
160
-
161
- dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
162
- id++;
163
- }
164
- }
165
- id += ne00 * (ne01 - ir1);
166
- }
167
- }
168
- } else if (dst->type == GGML_TYPE_F16) {
169
- size_t id = 0;
170
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
171
-
172
- for (int i03 = 0; i03 < ne03; i03++) {
173
- for (int i02 = 0; i02 < ne02; i02++) {
174
- id += ne00 * ir0;
175
- for (int i01 = ir0; i01 < ir1; i01++) {
176
- for (int i00 = 0; i00 < ne00; i00++) {
177
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
178
-
179
- dst_ptr[id] = *src0_ptr;
180
- id++;
181
- }
182
- }
183
- id += ne00 * (ne01 - ir1);
184
- }
185
- }
186
- } else {
187
- GGML_ABORT("fatal error"); // TODO: implement
188
- }
189
- }
190
- return;
191
- }
192
-
193
- // dst counters
194
- int64_t i10 = 0;
195
- int64_t i11 = 0;
196
- int64_t i12 = 0;
197
- int64_t i13 = 0;
198
-
199
- if (dst->type == GGML_TYPE_F16) {
200
- for (int64_t i03 = 0; i03 < ne03; i03++) {
201
- for (int64_t i02 = 0; i02 < ne02; i02++) {
202
- i10 += ne00 * ir0;
203
- while (i10 >= ne0) {
204
- i10 -= ne0;
205
- if (++i11 == ne1) {
206
- i11 = 0;
207
- if (++i12 == ne2) {
208
- i12 = 0;
209
- if (++i13 == ne3) {
210
- i13 = 0;
211
- }
212
- }
213
- }
214
- }
215
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
216
- for (int64_t i00 = 0; i00 < ne00; i00++) {
217
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
218
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
219
-
220
- memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
221
-
222
- if (++i10 == ne00) {
223
- i10 = 0;
224
- if (++i11 == ne01) {
225
- i11 = 0;
226
- if (++i12 == ne02) {
227
- i12 = 0;
228
- if (++i13 == ne03) {
229
- i13 = 0;
230
- }
231
- }
232
- }
233
- }
234
- }
235
- }
236
- i10 += ne00 * (ne01 - ir1);
237
- while (i10 >= ne0) {
238
- i10 -= ne0;
239
- if (++i11 == ne1) {
240
- i11 = 0;
241
- if (++i12 == ne2) {
242
- i12 = 0;
243
- if (++i13 == ne3) {
244
- i13 = 0;
245
- }
246
- }
247
- }
248
- }
249
- }
250
- }
251
- } else if (dst->type == GGML_TYPE_F32) {
252
- for (int64_t i03 = 0; i03 < ne03; i03++) {
253
- for (int64_t i02 = 0; i02 < ne02; i02++) {
254
- i10 += ne00 * ir0;
255
- while (i10 >= ne0) {
256
- i10 -= ne0;
257
- if (++i11 == ne1) {
258
- i11 = 0;
259
- if (++i12 == ne2) {
260
- i12 = 0;
261
- if (++i13 == ne3) {
262
- i13 = 0;
263
- }
264
- }
265
- }
266
- }
267
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
268
- for (int64_t i00 = 0; i00 < ne00; i00++) {
269
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
270
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
271
-
272
- *(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
273
-
274
- if (++i10 == ne0) {
275
- i10 = 0;
276
- if (++i11 == ne1) {
277
- i11 = 0;
278
- if (++i12 == ne2) {
279
- i12 = 0;
280
- if (++i13 == ne3) {
281
- i13 = 0;
282
- }
283
- }
284
- }
285
- }
286
- }
287
- }
288
- i10 += ne00 * (ne01 - ir1);
289
- while (i10 >= ne0) {
290
- i10 -= ne0;
291
- if (++i11 == ne1) {
292
- i11 = 0;
293
- if (++i12 == ne2) {
294
- i12 = 0;
295
- if (++i13 == ne3) {
296
- i13 = 0;
297
- }
298
- }
299
- }
300
- }
301
- }
302
- }
303
- } else {
304
- GGML_ABORT("fatal error"); // TODO: implement
305
- }
306
- }
307
-
308
- static void ggml_compute_forward_dup_bf16(
309
- const ggml_compute_params * params,
310
- ggml_tensor * dst) {
311
-
312
- const ggml_tensor * src0 = dst->src[0];
313
-
314
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
315
-
316
- GGML_TENSOR_UNARY_OP_LOCALS
317
-
318
- const int ith = params->ith; // thread index
319
- const int nth = params->nth; // number of threads
320
-
321
- // parallelize by rows
322
- const int nr = ne01;
323
- // number of rows per thread
324
- const int dr = (nr + nth - 1) / nth;
325
- // row range for this thread
326
- const int ir0 = dr * ith;
327
- const int ir1 = MIN(ir0 + dr, nr);
328
-
329
- if (src0->type == dst->type &&
330
- ne00 == ne0 &&
331
- nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
332
- // copy by rows
333
- const size_t rs = ne00*nb00;
334
- for (int64_t i03 = 0; i03 < ne03; i03++) {
335
- for (int64_t i02 = 0; i02 < ne02; i02++) {
336
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
337
- memcpy(
338
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
339
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
340
- rs);
341
- }
342
- }
343
- }
344
- return;
345
- }
346
-
347
- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
348
-
349
- if (ggml_is_contiguous(dst)) {
350
- if (nb00 == sizeof(ggml_bf16_t)) {
351
- if (dst->type == GGML_TYPE_BF16) {
352
- size_t id = 0;
353
- const size_t rs = ne00 * nb00;
354
- char * dst_ptr = (char *) dst->data;
355
-
356
- for (int i03 = 0; i03 < ne03; i03++) {
357
- for (int i02 = 0; i02 < ne02; i02++) {
358
- id += rs * ir0;
359
- for (int i01 = ir0; i01 < ir1; i01++) {
360
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
361
- memcpy(dst_ptr + id, src0_ptr, rs);
362
- id += rs;
363
- }
364
- id += rs * (ne01 - ir1);
365
- }
366
- }
367
- } else if (dst->type == GGML_TYPE_F16) {
368
- size_t id = 0;
369
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
370
-
371
- for (int i03 = 0; i03 < ne03; i03++) {
372
- for (int i02 = 0; i02 < ne02; i02++) {
373
- id += ne00 * ir0;
374
- for (int i01 = ir0; i01 < ir1; i01++) {
375
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
376
- for (int i00 = 0; i00 < ne00; i00++) {
377
- dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
378
- id++;
379
- }
380
- }
381
- id += ne00 * (ne01 - ir1);
382
- }
383
- }
384
- } else if (dst->type == GGML_TYPE_F32) {
385
- size_t id = 0;
386
- float * dst_ptr = (float *) dst->data;
387
-
388
- for (int i03 = 0; i03 < ne03; i03++) {
389
- for (int i02 = 0; i02 < ne02; i02++) {
390
- id += ne00 * ir0;
391
- for (int i01 = ir0; i01 < ir1; i01++) {
392
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
393
- for (int i00 = 0; i00 < ne00; i00++) {
394
- dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
395
- id++;
396
- }
397
- }
398
- id += ne00 * (ne01 - ir1);
399
- }
400
- }
401
- } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
402
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
403
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
404
-
405
- size_t id = 0;
406
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
407
- char * dst_ptr = (char *) dst->data;
408
-
409
- for (int i03 = 0; i03 < ne03; i03++) {
410
- for (int i02 = 0; i02 < ne02; i02++) {
411
- id += rs * ir0;
412
- for (int i01 = ir0; i01 < ir1; i01++) {
413
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
414
-
415
- for (int i00 = 0; i00 < ne00; i00++) {
416
- src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
417
- }
418
-
419
- quantize_row_q(src0_f32, dst_ptr + id, ne00);
420
- id += rs;
421
- }
422
- id += rs * (ne01 - ir1);
423
- }
424
- }
425
- } else {
426
- GGML_ABORT("fatal error"); // TODO: implement
427
- }
428
- } else {
429
- //printf("%s: this is not optimal - fix me\n", __func__);
430
-
431
- if (dst->type == GGML_TYPE_F32) {
432
- size_t id = 0;
433
- float * dst_ptr = (float *) dst->data;
434
-
435
- for (int i03 = 0; i03 < ne03; i03++) {
436
- for (int i02 = 0; i02 < ne02; i02++) {
437
- id += ne00 * ir0;
438
- for (int i01 = ir0; i01 < ir1; i01++) {
439
- for (int i00 = 0; i00 < ne00; i00++) {
440
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
441
-
442
- dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
443
- id++;
444
- }
445
- }
446
- id += ne00 * (ne01 - ir1);
447
- }
448
- }
449
- } else if (dst->type == GGML_TYPE_BF16) {
450
- size_t id = 0;
451
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
452
-
453
- for (int i03 = 0; i03 < ne03; i03++) {
454
- for (int i02 = 0; i02 < ne02; i02++) {
455
- id += ne00 * ir0;
456
- for (int i01 = ir0; i01 < ir1; i01++) {
457
- for (int i00 = 0; i00 < ne00; i00++) {
458
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
459
-
460
- dst_ptr[id] = *src0_ptr;
461
- id++;
462
- }
463
- }
464
- id += ne00 * (ne01 - ir1);
465
- }
466
- }
467
- } else if (dst->type == GGML_TYPE_F16) {
468
- size_t id = 0;
469
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
470
-
471
- for (int i03 = 0; i03 < ne03; i03++) {
472
- for (int i02 = 0; i02 < ne02; i02++) {
473
- id += ne00 * ir0;
474
- for (int i01 = ir0; i01 < ir1; i01++) {
475
- for (int i00 = 0; i00 < ne00; i00++) {
476
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
477
-
478
- dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
479
- id++;
480
- }
481
- }
482
- id += ne00 * (ne01 - ir1);
483
- }
484
- }
485
- } else {
486
- GGML_ABORT("fatal error"); // TODO: implement
487
- }
488
- }
489
- return;
490
- }
491
-
492
- // dst counters
493
- int64_t i10 = 0;
494
- int64_t i11 = 0;
495
- int64_t i12 = 0;
496
- int64_t i13 = 0;
497
-
498
- if (dst->type == GGML_TYPE_BF16) {
499
- for (int64_t i03 = 0; i03 < ne03; i03++) {
500
- for (int64_t i02 = 0; i02 < ne02; i02++) {
501
- i10 += ne00 * ir0;
502
- while (i10 >= ne0) {
503
- i10 -= ne0;
504
- if (++i11 == ne1) {
505
- i11 = 0;
506
- if (++i12 == ne2) {
507
- i12 = 0;
508
- if (++i13 == ne3) {
509
- i13 = 0;
510
- }
511
- }
512
- }
513
- }
514
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
515
- for (int64_t i00 = 0; i00 < ne00; i00++) {
516
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
517
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
518
-
519
- memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
520
-
521
- if (++i10 == ne00) {
522
- i10 = 0;
523
- if (++i11 == ne01) {
524
- i11 = 0;
525
- if (++i12 == ne02) {
526
- i12 = 0;
527
- if (++i13 == ne03) {
528
- i13 = 0;
529
- }
530
- }
531
- }
532
- }
533
- }
534
- }
535
- i10 += ne00 * (ne01 - ir1);
536
- while (i10 >= ne0) {
537
- i10 -= ne0;
538
- if (++i11 == ne1) {
539
- i11 = 0;
540
- if (++i12 == ne2) {
541
- i12 = 0;
542
- if (++i13 == ne3) {
543
- i13 = 0;
544
- }
545
- }
546
- }
547
- }
548
- }
549
- }
550
- } else if (dst->type == GGML_TYPE_F16) {
551
- for (int64_t i03 = 0; i03 < ne03; i03++) {
552
- for (int64_t i02 = 0; i02 < ne02; i02++) {
553
- i10 += ne00 * ir0;
554
- while (i10 >= ne0) {
555
- i10 -= ne0;
556
- if (++i11 == ne1) {
557
- i11 = 0;
558
- if (++i12 == ne2) {
559
- i12 = 0;
560
- if (++i13 == ne3) {
561
- i13 = 0;
562
- }
563
- }
564
- }
565
- }
566
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
567
- for (int64_t i00 = 0; i00 < ne00; i00++) {
568
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
569
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
570
-
571
- *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
572
-
573
- if (++i10 == ne0) {
574
- i10 = 0;
575
- if (++i11 == ne1) {
576
- i11 = 0;
577
- if (++i12 == ne2) {
578
- i12 = 0;
579
- if (++i13 == ne3) {
580
- i13 = 0;
581
- }
582
- }
583
- }
584
- }
585
- }
586
- }
587
- i10 += ne00 * (ne01 - ir1);
588
- while (i10 >= ne0) {
589
- i10 -= ne0;
590
- if (++i11 == ne1) {
591
- i11 = 0;
592
- if (++i12 == ne2) {
593
- i12 = 0;
594
- if (++i13 == ne3) {
595
- i13 = 0;
596
- }
597
- }
598
- }
599
- }
600
- }
601
- }
602
- } else if (dst->type == GGML_TYPE_F32) {
603
- for (int64_t i03 = 0; i03 < ne03; i03++) {
604
- for (int64_t i02 = 0; i02 < ne02; i02++) {
605
- i10 += ne00 * ir0;
606
- while (i10 >= ne0) {
607
- i10 -= ne0;
608
- if (++i11 == ne1) {
609
- i11 = 0;
610
- if (++i12 == ne2) {
611
- i12 = 0;
612
- if (++i13 == ne3) {
613
- i13 = 0;
614
- }
615
- }
616
- }
617
- }
618
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
619
- for (int64_t i00 = 0; i00 < ne00; i00++) {
620
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
621
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
622
-
623
- *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
624
-
625
- if (++i10 == ne0) {
626
- i10 = 0;
627
- if (++i11 == ne1) {
628
- i11 = 0;
629
- if (++i12 == ne2) {
630
- i12 = 0;
631
- if (++i13 == ne3) {
632
- i13 = 0;
633
- }
634
- }
635
- }
636
- }
637
- }
638
- }
639
- i10 += ne00 * (ne01 - ir1);
640
- while (i10 >= ne0) {
641
- i10 -= ne0;
642
- if (++i11 == ne1) {
643
- i11 = 0;
644
- if (++i12 == ne2) {
645
- i12 = 0;
646
- if (++i13 == ne3) {
647
- i13 = 0;
648
- }
649
- }
650
- }
651
- }
652
- }
653
- }
654
- } else {
655
- GGML_ABORT("fatal error"); // TODO: implement
656
- }
657
- }
658
-
659
- static void ggml_compute_forward_dup_f32(
44
+ template<typename src_t, typename dst_t>
45
+ static void ggml_compute_forward_dup_flt(
660
46
  const ggml_compute_params * params,
661
47
  ggml_tensor * dst) {
662
48
 
663
49
  const ggml_tensor * src0 = dst->src[0];
664
50
 
665
51
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
52
+ GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
666
53
 
667
54
  GGML_TENSOR_UNARY_OP_LOCALS
668
55
 
@@ -677,6 +64,7 @@ static void ggml_compute_forward_dup_f32(
677
64
  const int ir0 = dr * ith;
678
65
  const int ir1 = MIN(ir0 + dr, nr);
679
66
 
67
+ // case: type & row size equal
680
68
  if (src0->type == dst->type &&
681
69
  ne00 == ne0 &&
682
70
  nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
@@ -695,173 +83,78 @@ static void ggml_compute_forward_dup_f32(
695
83
  return;
696
84
  }
697
85
 
86
+ // case: dst tensor is contiguous
698
87
  if (ggml_is_contiguous(dst)) {
699
- // TODO: simplify
700
- if (nb00 == sizeof(float)) {
701
- if (ggml_get_type_traits_cpu(dst->type)->from_float) {
702
- ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
703
-
88
+ if (nb00 == sizeof(src_t)) {
89
+ if constexpr (std::is_same_v<dst_t, src_t>) {
90
+ // same type
704
91
  size_t id = 0;
705
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
92
+ const size_t rs = ne00 * nb00;
706
93
  char * dst_ptr = (char *) dst->data;
707
94
 
708
95
  for (int i03 = 0; i03 < ne03; i03++) {
709
96
  for (int i02 = 0; i02 < ne02; i02++) {
710
97
  id += rs * ir0;
711
98
  for (int i01 = ir0; i01 < ir1; i01++) {
712
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
713
- from_float(src0_ptr, dst_ptr + id, ne00);
99
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
100
+ memcpy(dst_ptr + id, src0_ptr, rs);
714
101
  id += rs;
715
102
  }
716
103
  id += rs * (ne01 - ir1);
717
104
  }
718
105
  }
719
106
  } else {
720
- GGML_ABORT("fatal error"); // TODO: implement
721
- }
722
- } else {
723
- //printf("%s: this is not optimal - fix me\n", __func__);
724
-
725
- if (dst->type == GGML_TYPE_F32) {
726
- size_t id = 0;
727
- float * dst_ptr = (float *) dst->data;
728
-
729
- for (int i03 = 0; i03 < ne03; i03++) {
730
- for (int i02 = 0; i02 < ne02; i02++) {
731
- id += ne00 * ir0;
732
- for (int i01 = ir0; i01 < ir1; i01++) {
733
- for (int i00 = 0; i00 < ne00; i00++) {
734
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
735
-
736
- dst_ptr[id] = *src0_ptr;
737
- id++;
738
- }
739
- }
740
- id += ne00 * (ne01 - ir1);
741
- }
742
- }
743
- } else if (dst->type == GGML_TYPE_F16) {
107
+ // casting between non-quantized types
744
108
  size_t id = 0;
745
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
109
+ dst_t * dst_ptr = (dst_t *) dst->data;
746
110
 
747
111
  for (int i03 = 0; i03 < ne03; i03++) {
748
112
  for (int i02 = 0; i02 < ne02; i02++) {
749
113
  id += ne00 * ir0;
750
114
  for (int i01 = ir0; i01 < ir1; i01++) {
115
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
751
116
  for (int i00 = 0; i00 < ne00; i00++) {
752
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
753
-
754
- dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
117
+ float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
118
+ dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
755
119
  id++;
756
120
  }
757
121
  }
758
122
  id += ne00 * (ne01 - ir1);
759
123
  }
760
124
  }
761
- } else if (dst->type == GGML_TYPE_BF16) {
762
- size_t id = 0;
763
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
764
-
765
- for (int i03 = 0; i03 < ne03; i03++) {
766
- for (int i02 = 0; i02 < ne02; i02++) {
767
- id += ne00 * ir0;
768
- for (int i01 = ir0; i01 < ir1; i01++) {
769
- for (int i00 = 0; i00 < ne00; i00++) {
770
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
125
+ }
126
+ } else {
127
+ //printf("%s: this is not optimal - fix me\n", __func__);
771
128
 
772
- dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
773
- id++;
774
- }
775
- }
776
- id += ne00 * (ne01 - ir1);
777
- }
778
- }
779
- } else if (dst->type == GGML_TYPE_I32) {
780
- size_t id = 0;
781
- int32_t * dst_ptr = (int32_t *) dst->data;
129
+ size_t id = 0;
130
+ dst_t * dst_ptr = (dst_t *) dst->data;
782
131
 
783
- for (int i03 = 0; i03 < ne03; i03++) {
784
- for (int i02 = 0; i02 < ne02; i02++) {
785
- id += ne00 * ir0;
786
- for (int i01 = ir0; i01 < ir1; i01++) {
787
- for (int i00 = 0; i00 < ne00; i00++) {
788
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
132
+ for (int i03 = 0; i03 < ne03; i03++) {
133
+ for (int i02 = 0; i02 < ne02; i02++) {
134
+ id += ne00 * ir0;
135
+ for (int i01 = ir0; i01 < ir1; i01++) {
136
+ for (int i00 = 0; i00 < ne00; i00++) {
137
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
789
138
 
790
- dst_ptr[id] = *src0_ptr;
791
- id++;
792
- }
139
+ float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
140
+ dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
141
+ id++;
793
142
  }
794
- id += ne00 * (ne01 - ir1);
795
143
  }
144
+ id += ne00 * (ne01 - ir1);
796
145
  }
797
- } else {
798
- GGML_ABORT("fatal error"); // TODO: implement
799
146
  }
800
147
  }
801
-
802
148
  return;
803
149
  }
804
150
 
805
151
  // dst counters
806
-
807
152
  int64_t i10 = 0;
808
153
  int64_t i11 = 0;
809
154
  int64_t i12 = 0;
810
155
  int64_t i13 = 0;
811
156
 
812
- if (dst->type == GGML_TYPE_F32) {
813
- for (int64_t i03 = 0; i03 < ne03; i03++) {
814
- for (int64_t i02 = 0; i02 < ne02; i02++) {
815
- i10 += ne00 * ir0;
816
- while (i10 >= ne0) {
817
- i10 -= ne0;
818
- if (++i11 == ne1) {
819
- i11 = 0;
820
- if (++i12 == ne2) {
821
- i12 = 0;
822
- if (++i13 == ne3) {
823
- i13 = 0;
824
- }
825
- }
826
- }
827
- }
828
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
829
- for (int64_t i00 = 0; i00 < ne00; i00++) {
830
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
831
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
832
-
833
- memcpy(dst_ptr, src0_ptr, sizeof(float));
834
-
835
- if (++i10 == ne0) {
836
- i10 = 0;
837
- if (++i11 == ne1) {
838
- i11 = 0;
839
- if (++i12 == ne2) {
840
- i12 = 0;
841
- if (++i13 == ne3) {
842
- i13 = 0;
843
- }
844
- }
845
- }
846
- }
847
- }
848
- }
849
- i10 += ne00 * (ne01 - ir1);
850
- while (i10 >= ne0) {
851
- i10 -= ne0;
852
- if (++i11 == ne1) {
853
- i11 = 0;
854
- if (++i12 == ne2) {
855
- i12 = 0;
856
- if (++i13 == ne3) {
857
- i13 = 0;
858
- }
859
- }
860
- }
861
- }
862
- }
863
- }
864
- } else if (dst->type == GGML_TYPE_F16) {
157
+ if constexpr (std::is_same_v<dst_t, src_t>) {
865
158
  for (int64_t i03 = 0; i03 < ne03; i03++) {
866
159
  for (int64_t i02 = 0; i02 < ne02; i02++) {
867
160
  i10 += ne00 * ir0;
@@ -882,15 +175,15 @@ static void ggml_compute_forward_dup_f32(
882
175
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
883
176
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
884
177
 
885
- *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
178
+ memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
886
179
 
887
- if (++i10 == ne0) {
180
+ if (++i10 == ne00) {
888
181
  i10 = 0;
889
- if (++i11 == ne1) {
182
+ if (++i11 == ne01) {
890
183
  i11 = 0;
891
- if (++i12 == ne2) {
184
+ if (++i12 == ne02) {
892
185
  i12 = 0;
893
- if (++i13 == ne3) {
186
+ if (++i13 == ne03) {
894
187
  i13 = 0;
895
188
  }
896
189
  }
@@ -913,59 +206,8 @@ static void ggml_compute_forward_dup_f32(
913
206
  }
914
207
  }
915
208
  }
916
- } else if (dst->type == GGML_TYPE_BF16) {
917
- for (int64_t i03 = 0; i03 < ne03; i03++) {
918
- for (int64_t i02 = 0; i02 < ne02; i02++) {
919
- i10 += ne00 * ir0;
920
- while (i10 >= ne0) {
921
- i10 -= ne0;
922
- if (++i11 == ne1) {
923
- i11 = 0;
924
- if (++i12 == ne2) {
925
- i12 = 0;
926
- if (++i13 == ne3) {
927
- i13 = 0;
928
- }
929
- }
930
- }
931
- }
932
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
933
- for (int64_t i00 = 0; i00 < ne00; i00++) {
934
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
935
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
936
209
 
937
- *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
938
-
939
- if (++i10 == ne0) {
940
- i10 = 0;
941
- if (++i11 == ne1) {
942
- i11 = 0;
943
- if (++i12 == ne2) {
944
- i12 = 0;
945
- if (++i13 == ne3) {
946
- i13 = 0;
947
- }
948
- }
949
- }
950
- }
951
- }
952
- }
953
- i10 += ne00 * (ne01 - ir1);
954
- while (i10 >= ne0) {
955
- i10 -= ne0;
956
- if (++i11 == ne1) {
957
- i11 = 0;
958
- if (++i12 == ne2) {
959
- i12 = 0;
960
- if (++i13 == ne3) {
961
- i13 = 0;
962
- }
963
- }
964
- }
965
- }
966
- }
967
- }
968
- } else if (dst->type == GGML_TYPE_I32) {
210
+ } else {
969
211
  for (int64_t i03 = 0; i03 < ne03; i03++) {
970
212
  for (int64_t i02 = 0; i02 < ne02; i02++) {
971
213
  i10 += ne00 * ir0;
@@ -986,7 +228,8 @@ static void ggml_compute_forward_dup_f32(
986
228
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
987
229
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
988
230
 
989
- *(int32_t *) dst_ptr = *(const float *) src0_ptr;
231
+ float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
232
+ *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
990
233
 
991
234
  if (++i10 == ne0) {
992
235
  i10 = 0;
@@ -1017,18 +260,19 @@ static void ggml_compute_forward_dup_f32(
1017
260
  }
1018
261
  }
1019
262
  }
1020
- } else {
1021
- GGML_ABORT("fatal error"); // TODO: implement
1022
263
  }
1023
264
  }
1024
265
 
1025
- static void ggml_compute_forward_dup_i32(
266
+
267
+ template<typename src_t>
268
+ static void ggml_compute_forward_dup_to_q(
1026
269
  const ggml_compute_params * params,
1027
270
  ggml_tensor * dst) {
1028
271
 
1029
272
  const ggml_tensor * src0 = dst->src[0];
1030
273
 
1031
274
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
275
+ GGML_ASSERT(!ggml_is_quantized(src0->type));
1032
276
 
1033
277
  GGML_TENSOR_UNARY_OP_LOCALS
1034
278
 
@@ -1043,68 +287,36 @@ static void ggml_compute_forward_dup_i32(
1043
287
  const int ir0 = dr * ith;
1044
288
  const int ir1 = MIN(ir0 + dr, nr);
1045
289
 
1046
- // dst counters
1047
-
1048
- int64_t i10 = 0;
1049
- int64_t i11 = 0;
1050
- int64_t i12 = 0;
1051
- int64_t i13 = 0;
290
+ if (ggml_is_contiguous(dst) &&
291
+ nb00 == sizeof(src_t) &&
292
+ ggml_get_type_traits_cpu(dst->type)->from_float) {
293
+ // casting non-quantized types --> intermediate f32 --> quantized
294
+ ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
295
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
1052
296
 
1053
- // TODO: not optimal, but works
1054
- if (dst->type == GGML_TYPE_F32) {
1055
- for (int64_t i03 = 0; i03 < ne03; i03++) {
1056
- for (int64_t i02 = 0; i02 < ne02; i02++) {
1057
- i10 += ne00 * ir0;
1058
- while (i10 >= ne0) {
1059
- i10 -= ne0;
1060
- if (++i11 == ne1) {
1061
- i11 = 0;
1062
- if (++i12 == ne2) {
1063
- i12 = 0;
1064
- if (++i13 == ne3) {
1065
- i13 = 0;
1066
- }
1067
- }
1068
- }
1069
- }
1070
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
1071
- for (int64_t i00 = 0; i00 < ne00; i00++) {
1072
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
1073
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
297
+ size_t id = 0;
298
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
299
+ char * dst_ptr = (char *) dst->data;
1074
300
 
1075
- *(float *) dst_ptr = *(const int32_t *) src0_ptr;
301
+ for (int i03 = 0; i03 < ne03; i03++) {
302
+ for (int i02 = 0; i02 < ne02; i02++) {
303
+ id += rs * ir0;
304
+ for (int i01 = ir0; i01 < ir1; i01++) {
305
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
1076
306
 
1077
- if (++i10 == ne0) {
1078
- i10 = 0;
1079
- if (++i11 == ne1) {
1080
- i11 = 0;
1081
- if (++i12 == ne2) {
1082
- i12 = 0;
1083
- if (++i13 == ne3) {
1084
- i13 = 0;
1085
- }
1086
- }
1087
- }
1088
- }
1089
- }
1090
- }
1091
- i10 += ne00 * (ne01 - ir1);
1092
- while (i10 >= ne0) {
1093
- i10 -= ne0;
1094
- if (++i11 == ne1) {
1095
- i11 = 0;
1096
- if (++i12 == ne2) {
1097
- i12 = 0;
1098
- if (++i13 == ne3) {
1099
- i13 = 0;
1100
- }
1101
- }
307
+ for (int i00 = 0; i00 < ne00; i00++) {
308
+ src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
1102
309
  }
310
+
311
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
312
+ id += rs;
1103
313
  }
314
+ id += rs * (ne01 - ir1);
1104
315
  }
1105
316
  }
1106
317
  } else {
1107
- GGML_ABORT("fatal error"); // TODO: implement
318
+ // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
319
+ GGML_ABORT("not implemented");
1108
320
  }
1109
321
  }
1110
322
 
@@ -1258,7 +470,7 @@ static void ggml_compute_forward_dup_bytes(
1258
470
  }
1259
471
  }
1260
472
 
1261
- static void ggml_compute_forward_dup_q(
473
+ static void ggml_compute_forward_dup_from_q(
1262
474
  const ggml_compute_params * params,
1263
475
  ggml_tensor * dst) {
1264
476
 
@@ -1323,24 +535,35 @@ void ggml_compute_forward_dup(
1323
535
  switch (src0->type) {
1324
536
  case GGML_TYPE_F16:
1325
537
  {
1326
- ggml_compute_forward_dup_f16(params, dst);
538
+ /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
539
+ else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
540
+ else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_fp16_t, float >(params, dst);
541
+ else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
1327
542
  } break;
1328
543
  case GGML_TYPE_BF16:
1329
544
  {
1330
- ggml_compute_forward_dup_bf16(params, dst);
545
+ /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
546
+ else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
547
+ else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_bf16_t, float >(params, dst);
548
+ else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
1331
549
  } break;
1332
550
  case GGML_TYPE_F32:
1333
551
  {
1334
- ggml_compute_forward_dup_f32(params, dst);
552
+ /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
553
+ else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
554
+ else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<float, float >(params, dst);
555
+ else if (dst->type == GGML_TYPE_I32) ggml_compute_forward_dup_flt<float, int32_t >(params, dst);
556
+ else ggml_compute_forward_dup_to_q<float>(params, dst);
1335
557
  } break;
1336
558
  case GGML_TYPE_I32:
1337
559
  {
1338
- ggml_compute_forward_dup_i32(params, dst);
560
+ if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
561
+ else GGML_ABORT("not implemented");
1339
562
  } break;
1340
563
  default:
1341
564
  {
1342
565
  if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
1343
- ggml_compute_forward_dup_q(params, dst);
566
+ ggml_compute_forward_dup_from_q(params, dst);
1344
567
  break;
1345
568
  }
1346
569
  GGML_ABORT("fatal error");
@@ -5516,6 +4739,7 @@ void ggml_compute_forward_get_rows(
5516
4739
  //}
5517
4740
  }
5518
4741
 
4742
+ template<typename idx_t>
5519
4743
  static void ggml_compute_forward_set_rows_f32(
5520
4744
  const ggml_compute_params * params,
5521
4745
  ggml_tensor * dst) {
@@ -5554,7 +4778,7 @@ static void ggml_compute_forward_set_rows_f32(
5554
4778
  const int64_t i11 = i02%ne11;
5555
4779
  const int64_t i10 = i;
5556
4780
 
5557
- const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4781
+ const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
5558
4782
 
5559
4783
  GGML_ASSERT(i1 >= 0 && i1 < ne1);
5560
4784
 
@@ -5571,11 +4795,18 @@ void ggml_compute_forward_set_rows(
5571
4795
  ggml_tensor * dst) {
5572
4796
 
5573
4797
  const ggml_tensor * src0 = dst->src[0];
4798
+ const ggml_tensor * src1 = dst->src[1];
5574
4799
 
5575
4800
  switch (src0->type) {
5576
4801
  case GGML_TYPE_F32:
5577
4802
  {
5578
- ggml_compute_forward_set_rows_f32(params, dst);
4803
+ if (src1->type == GGML_TYPE_I64) {
4804
+ ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
4805
+ } else if (src1->type == GGML_TYPE_I32) {
4806
+ ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
4807
+ } else {
4808
+ GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
4809
+ }
5579
4810
  } break;
5580
4811
  default:
5581
4812
  {