@fugood/llama.node 1.2.0 → 1.2.2

This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only.
@@ -41,13 +41,15 @@ static void ggml_compute_forward_dup_same_cont(
  }
  }

- static void ggml_compute_forward_dup_f16(
+ template<typename src_t, typename dst_t>
+ static void ggml_compute_forward_dup_flt(
  const ggml_compute_params * params,
  ggml_tensor * dst) {

  const ggml_tensor * src0 = dst->src[0];

  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+ GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));

  GGML_TENSOR_UNARY_OP_LOCALS
@@ -62,6 +64,7 @@ static void ggml_compute_forward_dup_f16(
  const int ir0 = dr * ith;
  const int ir1 = MIN(ir0 + dr, nr);

+ // case: type & row size equal
  if (src0->type == dst->type &&
  ne00 == ne0 &&
  nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
@@ -80,11 +83,11 @@ static void ggml_compute_forward_dup_f16(
  return;
  }

- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
+ // case: dst tensor is contiguous
  if (ggml_is_contiguous(dst)) {
- if (nb00 == sizeof(ggml_fp16_t)) {
- if (dst->type == GGML_TYPE_F16) {
+ if (nb00 == sizeof(src_t)) {
+ if constexpr (std::is_same_v<dst_t, src_t>) {
+ // same type
  size_t id = 0;
  const size_t rs = ne00 * nb00;
  char * dst_ptr = (char *) dst->data;
@@ -100,91 +103,46 @@ static void ggml_compute_forward_dup_f16(
  id += rs * (ne01 - ir1);
  }
  }
- } else if (dst->type == GGML_TYPE_F32) {
+ } else {
+ // casting between non-quantized types
  size_t id = 0;
- float * dst_ptr = (float *) dst->data;
+ dst_t * dst_ptr = (dst_t *) dst->data;

  for (int i03 = 0; i03 < ne03; i03++) {
  for (int i02 = 0; i02 < ne02; i02++) {
  id += ne00 * ir0;
  for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
+ float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
+ dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
  id++;
  }
  }
  id += ne00 * (ne01 - ir1);
  }
  }
- } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
- for (int i00 = 0; i00 < ne00; i00++) {
- src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
- }
-
- quantize_row_q(src0_f32, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
  }
  } else {
  //printf("%s: this is not optimal - fix me\n", __func__);

- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+ size_t id = 0;
+ dst_t * dst_ptr = (dst_t *) dst->data;

- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

- dst_ptr[id] = *src0_ptr;
- id++;
- }
+ float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
+ dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
+ id++;
  }
- id += ne00 * (ne01 - ir1);
  }
+ id += ne00 * (ne01 - ir1);
  }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
  }
  }
  return;
@@ -196,7 +154,7 @@ static void ggml_compute_forward_dup_f16(
  int64_t i12 = 0;
  int64_t i13 = 0;

- if (dst->type == GGML_TYPE_F16) {
+ if constexpr (std::is_same_v<dst_t, src_t>) {
  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
  i10 += ne00 * ir0;
@@ -217,7 +175,7 @@ static void ggml_compute_forward_dup_f16(
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

- memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
+ memcpy(dst_ptr, src0_ptr, sizeof(dst_t));

  if (++i10 == ne00) {
  i10 = 0;
@@ -248,7 +206,8 @@ static void ggml_compute_forward_dup_f16(
  }
  }
  }
- } else if (dst->type == GGML_TYPE_F32) {
+
+ } else {
  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
  i10 += ne00 * ir0;
@@ -269,7 +228,8 @@ static void ggml_compute_forward_dup_f16(
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

- *(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
+ float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
+ *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);

  if (++i10 == ne0) {
  i10 = 0;
@@ -300,18 +260,19 @@ static void ggml_compute_forward_dup_f16(
  }
  }
  }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
  }
  }

- static void ggml_compute_forward_dup_bf16(
+
+ template<typename src_t>
+ static void ggml_compute_forward_dup_to_q(
  const ggml_compute_params * params,
  ggml_tensor * dst) {

  const ggml_tensor * src0 = dst->src[0];

  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+ GGML_ASSERT(!ggml_is_quantized(src0->type));

  GGML_TENSOR_UNARY_OP_LOCALS
@@ -326,785 +287,36 @@ static void ggml_compute_forward_dup_bf16(
  const int ir0 = dr * ith;
  const int ir1 = MIN(ir0 + dr, nr);

- if (src0->type == dst->type &&
- ne00 == ne0 &&
- nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
- // copy by rows
- const size_t rs = ne00*nb00;
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- memcpy(
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
- rs);
- }
- }
- }
- return;
- }
-
- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
- if (ggml_is_contiguous(dst)) {
- if (nb00 == sizeof(ggml_bf16_t)) {
- if (dst->type == GGML_TYPE_BF16) {
- size_t id = 0;
- const size_t rs = ne00 * nb00;
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, rs);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
- ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
- for (int i00 = 0; i00 < ne00; i00++) {
- src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
- }
-
- quantize_row_q(src0_f32, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_BF16) {
- size_t id = 0;
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
- return;
- }
-
- // dst counters
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
+ if (ggml_is_contiguous(dst) &&
+ nb00 == sizeof(src_t) &&
+ ggml_get_type_traits_cpu(dst->type)->from_float) {
+ // casting non-quantized types --> intermediate f32 --> quantized
+ ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

- if (dst->type == GGML_TYPE_BF16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+ size_t id = 0;
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+ char * dst_ptr = (char *) dst->data;

- memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

- if (++i10 == ne00) {
- i10 = 0;
- if (++i11 == ne01) {
- i11 = 0;
- if (++i12 == ne02) {
- i12 = 0;
- if (++i13 == ne03) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
+ for (int i00 = 0; i00 < ne00; i00++) {
+ src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
  }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));

- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
-
- static void ggml_compute_forward_dup_f32(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
-
- const ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
- // parallelize by rows
- const int nr = ne01;
- // number of rows per thread
- const int dr = (nr + nth - 1) / nth;
- // row range for this thread
- const int ir0 = dr * ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (src0->type == dst->type &&
- ne00 == ne0 &&
- nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
- // copy by rows
- const size_t rs = ne00*nb00;
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- memcpy(
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
- rs);
- }
- }
- }
- return;
- }
-
- if (ggml_is_contiguous(dst)) {
- // TODO: simplify
- if (nb00 == sizeof(float)) {
- if (ggml_get_type_traits_cpu(dst->type)->from_float) {
- ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- from_float(src0_ptr, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_BF16) {
- size_t id = 0;
- ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == GGML_TYPE_I32) {
- size_t id = 0;
- int32_t * dst_ptr = (int32_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
-
- return;
- }
-
- // dst counters
-
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- if (dst->type == GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, sizeof(float));
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_F16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_BF16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == GGML_TYPE_I32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(int32_t *) dst_ptr = *(const float *) src0_ptr;
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- }
-
- static void ggml_compute_forward_dup_i32(
- const ggml_compute_params * params,
- ggml_tensor * dst) {
-
- const ggml_tensor * src0 = dst->src[0];
-
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
- GGML_TENSOR_UNARY_OP_LOCALS
-
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
- // parallelize by rows
- const int nr = ne01;
- // number of rows per thread
- const int dr = (nr + nth - 1) / nth;
- // row range for this thread
- const int ir0 = dr * ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- // dst counters
-
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- // TODO: not optimal, but works
- if (dst->type == GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(float *) dst_ptr = *(const int32_t *) src0_ptr;
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
+ id += rs;
  }
+ id += rs * (ne01 - ir1);
  }
  }
  } else {
- GGML_ABORT("fatal error"); // TODO: implement
+ // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
+ GGML_ABORT("not implemented");
  }
  }
@@ -1258,7 +470,7 @@ static void ggml_compute_forward_dup_bytes(
  }
  }

- static void ggml_compute_forward_dup_q(
+ static void ggml_compute_forward_dup_from_q(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
@@ -1323,24 +535,35 @@ void ggml_compute_forward_dup(
  switch (src0->type) {
  case GGML_TYPE_F16:
  {
- ggml_compute_forward_dup_f16(params, dst);
+ /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_fp16_t, float >(params, dst);
+ else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
  } break;
  case GGML_TYPE_BF16:
  {
- ggml_compute_forward_dup_bf16(params, dst);
+ /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_bf16_t, float >(params, dst);
+ else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
  } break;
  case GGML_TYPE_F32:
  {
- ggml_compute_forward_dup_f32(params, dst);
+ /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
+ else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<float, float >(params, dst);
+ else if (dst->type == GGML_TYPE_I32) ggml_compute_forward_dup_flt<float, int32_t >(params, dst);
+ else ggml_compute_forward_dup_to_q<float>(params, dst);
  } break;
  case GGML_TYPE_I32:
  {
- ggml_compute_forward_dup_i32(params, dst);
+ if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
+ else GGML_ABORT("not implemented");
  } break;
  default:
  {
  if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
- ggml_compute_forward_dup_q(params, dst);
+ ggml_compute_forward_dup_from_q(params, dst);
  break;
  }
  GGML_ABORT("fatal error");
@@ -8599,7 +7822,6 @@ static void ggml_compute_forward_timestep_embedding_f32(
  }
  if (dim % 2 != 0 && ith == 0) {
  embed_data[2 * half] = 0.f;
- embed_data[dim] = 0.f;
  }
  }
  }
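The final hunk is independent of the dup refactor: in `ggml_compute_forward_timestep_embedding_f32`, when `dim` is odd the loop that fills the embedding pairs leaves exactly one trailing slot, and the surviving `embed_data[2 * half] = 0.f;` zeroes it. The removed `embed_data[dim] = 0.f;` targeted the next element, which (assuming `half == dim / 2` and an output row of exactly `dim` floats — neither of which is shown in this hunk) lies one past the end of the row. A standalone walk-through of the indices:

```cpp
#include <cstdio>

// Illustrative only: mirrors the kernel's indexing under the assumption
// that a row holds exactly `dim` floats (valid indices 0..dim-1) and
// that the main loop writes pairs at j and j + half for j in [0, half).
// dim = 5 is an arbitrary odd example.
int main() {
    const int dim  = 5;
    const int half = dim / 2; // = 2

    // the loop covers indices 0..2*half-1 = 0..3, leaving index 4 unwritten
    std::printf("last index written by the loop: %d\n", 2 * half - 1); // 3
    std::printf("padding slot kept by the fix:   %d\n", 2 * half);     // 4
    std::printf("removed write targeted index:   %d\n", dim);          // 5, past the end
    return 0;
}
```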