whisper.rn 0.4.0-rc.4 → 0.4.0-rc.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -6
- package/android/build.gradle +4 -0
- package/android/src/main/CMakeLists.txt +5 -0
- package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +51 -133
- package/android/src/main/jni-utils.h +76 -0
- package/android/src/main/jni.cpp +187 -112
- package/cpp/README.md +1 -1
- package/cpp/coreml/whisper-encoder-impl.h +1 -1
- package/cpp/coreml/whisper-encoder.h +4 -0
- package/cpp/coreml/whisper-encoder.mm +4 -2
- package/cpp/ggml-alloc.c +55 -19
- package/cpp/ggml-alloc.h +7 -0
- package/cpp/ggml-backend-impl.h +46 -21
- package/cpp/ggml-backend.c +563 -156
- package/cpp/ggml-backend.h +62 -17
- package/cpp/ggml-impl.h +1 -1
- package/cpp/ggml-metal-whisper.metal +1010 -253
- package/cpp/ggml-metal.h +7 -1
- package/cpp/ggml-metal.m +618 -187
- package/cpp/ggml-quants.c +64 -59
- package/cpp/ggml-quants.h +40 -40
- package/cpp/ggml.c +751 -1466
- package/cpp/ggml.h +90 -25
- package/cpp/rn-audioutils.cpp +68 -0
- package/cpp/rn-audioutils.h +14 -0
- package/cpp/rn-whisper-log.h +11 -0
- package/cpp/rn-whisper.cpp +141 -59
- package/cpp/rn-whisper.h +47 -15
- package/cpp/whisper.cpp +1635 -928
- package/cpp/whisper.h +55 -10
- package/ios/RNWhisper.mm +7 -7
- package/ios/RNWhisperAudioUtils.h +0 -2
- package/ios/RNWhisperAudioUtils.m +0 -56
- package/ios/RNWhisperContext.h +3 -11
- package/ios/RNWhisperContext.mm +62 -134
- package/lib/commonjs/version.json +1 -1
- package/lib/module/version.json +1 -1
- package/package.json +6 -5
- package/src/version.json +1 -1
package/cpp/ggml.c
CHANGED
@@ -233,24 +233,6 @@ inline static void * wsp_ggml_aligned_malloc(size_t size) {
 #define UNUSED WSP_GGML_UNUSED
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
 
-//
-// tensor access macros
-//
-
-#define WSP_GGML_TENSOR_UNARY_OP_LOCALS \
-    WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
-    WSP_GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
-    WSP_GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
-    WSP_GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
-
-#define WSP_GGML_TENSOR_BINARY_OP_LOCALS \
-    WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
-    WSP_GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
-    WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
-    WSP_GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
-    WSP_GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
-    WSP_GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
-
 #if defined(WSP_GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
 #if defined(WSP_GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
@@ -455,9 +437,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .blck_size = QK4_0,
         .type_size = sizeof(block_q4_0),
         .is_quantized = true,
-        .to_float = (wsp_ggml_to_float_t)
-        .from_float =
-        .from_float_reference = (wsp_ggml_from_float_t)
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q4_0,
+        .from_float = wsp_quantize_row_q4_0,
+        .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q4_0_reference,
         .vec_dot = wsp_ggml_vec_dot_q4_0_q8_0,
         .vec_dot_type = WSP_GGML_TYPE_Q8_0,
     },
@@ -466,9 +448,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .blck_size = QK4_1,
         .type_size = sizeof(block_q4_1),
         .is_quantized = true,
-        .to_float = (wsp_ggml_to_float_t)
-        .from_float =
-        .from_float_reference = (wsp_ggml_from_float_t)
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q4_1,
+        .from_float = wsp_quantize_row_q4_1,
+        .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q4_1_reference,
        .vec_dot = wsp_ggml_vec_dot_q4_1_q8_1,
         .vec_dot_type = WSP_GGML_TYPE_Q8_1,
     },
@@ -499,9 +481,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .blck_size = QK5_0,
         .type_size = sizeof(block_q5_0),
         .is_quantized = true,
-        .to_float = (wsp_ggml_to_float_t)
-        .from_float =
-        .from_float_reference = (wsp_ggml_from_float_t)
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q5_0,
+        .from_float = wsp_quantize_row_q5_0,
+        .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q5_0_reference,
         .vec_dot = wsp_ggml_vec_dot_q5_0_q8_0,
         .vec_dot_type = WSP_GGML_TYPE_Q8_0,
     },
@@ -510,9 +492,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .blck_size = QK5_1,
         .type_size = sizeof(block_q5_1),
         .is_quantized = true,
-        .to_float = (wsp_ggml_to_float_t)
-        .from_float =
-        .from_float_reference = (wsp_ggml_from_float_t)
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q5_1,
+        .from_float = wsp_quantize_row_q5_1,
+        .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q5_1_reference,
         .vec_dot = wsp_ggml_vec_dot_q5_1_q8_1,
         .vec_dot_type = WSP_GGML_TYPE_Q8_1,
     },
@@ -521,9 +503,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .blck_size = QK8_0,
         .type_size = sizeof(block_q8_0),
         .is_quantized = true,
-        .to_float = (wsp_ggml_to_float_t)
-        .from_float =
-        .from_float_reference = (wsp_ggml_from_float_t)
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q8_0,
+        .from_float = wsp_quantize_row_q8_0,
+        .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q8_0_reference,
         .vec_dot = wsp_ggml_vec_dot_q8_0_q8_0,
         .vec_dot_type = WSP_GGML_TYPE_Q8_0,
     },
@@ -532,8 +514,8 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .blck_size = QK8_1,
         .type_size = sizeof(block_q8_1),
         .is_quantized = true,
-        .from_float =
-        .from_float_reference = (wsp_ggml_from_float_t)
+        .from_float = wsp_quantize_row_q8_1,
+        .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q8_1_reference,
         .vec_dot_type = WSP_GGML_TYPE_Q8_1,
     },
     [WSP_GGML_TYPE_Q2_K] = {
@@ -541,9 +523,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .blck_size = QK_K,
         .type_size = sizeof(block_q2_K),
         .is_quantized = true,
-        .to_float = (wsp_ggml_to_float_t)
-        .from_float =
-        .from_float_reference = (wsp_ggml_from_float_t)
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q2_K,
+        .from_float = wsp_quantize_row_q2_K,
+        .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q2_K_reference,
         .vec_dot = wsp_ggml_vec_dot_q2_K_q8_K,
         .vec_dot_type = WSP_GGML_TYPE_Q8_K,
     },
@@ -552,9 +534,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .blck_size = QK_K,
         .type_size = sizeof(block_q3_K),
         .is_quantized = true,
-        .to_float = (wsp_ggml_to_float_t)
-        .from_float =
-        .from_float_reference = (wsp_ggml_from_float_t)
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q3_K,
+        .from_float = wsp_quantize_row_q3_K,
+        .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q3_K_reference,
         .vec_dot = wsp_ggml_vec_dot_q3_K_q8_K,
         .vec_dot_type = WSP_GGML_TYPE_Q8_K,
     },
@@ -563,9 +545,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .blck_size = QK_K,
         .type_size = sizeof(block_q4_K),
         .is_quantized = true,
-        .to_float = (wsp_ggml_to_float_t)
-        .from_float =
-        .from_float_reference = (wsp_ggml_from_float_t)
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q4_K,
+        .from_float = wsp_quantize_row_q4_K,
+        .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q4_K_reference,
         .vec_dot = wsp_ggml_vec_dot_q4_K_q8_K,
         .vec_dot_type = WSP_GGML_TYPE_Q8_K,
     },
@@ -574,9 +556,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .blck_size = QK_K,
         .type_size = sizeof(block_q5_K),
         .is_quantized = true,
-        .to_float = (wsp_ggml_to_float_t)
-        .from_float =
-        .from_float_reference = (wsp_ggml_from_float_t)
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q5_K,
+        .from_float = wsp_quantize_row_q5_K,
+        .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q5_K_reference,
         .vec_dot = wsp_ggml_vec_dot_q5_K_q8_K,
         .vec_dot_type = WSP_GGML_TYPE_Q8_K,
     },
@@ -585,9 +567,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .blck_size = QK_K,
         .type_size = sizeof(block_q6_K),
         .is_quantized = true,
-        .to_float = (wsp_ggml_to_float_t)
-        .from_float =
-        .from_float_reference = (wsp_ggml_from_float_t)
+        .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q6_K,
+        .from_float = wsp_quantize_row_q6_K,
+        .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q6_K_reference,
         .vec_dot = wsp_ggml_vec_dot_q6_K_q8_K,
         .vec_dot_type = WSP_GGML_TYPE_Q8_K,
     },
@@ -596,7 +578,7 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
         .blck_size = QK_K,
         .type_size = sizeof(block_q8_K),
         .is_quantized = true,
-        .from_float =
+        .from_float = wsp_quantize_row_q8_K,
     }
 };
 
@@ -1613,6 +1595,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "GROUP_NORM",
 
     "MUL_MAT",
+    "MUL_MAT_ID",
     "OUT_PROD",
 
     "SCALE",
@@ -1634,17 +1617,13 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "ROPE_BACK",
     "ALIBI",
     "CLAMP",
-    "CONV_1D",
-    "CONV_1D_STAGE_0",
-    "CONV_1D_STAGE_1",
     "CONV_TRANSPOSE_1D",
-    "
-    "CONV_2D_STAGE_0",
-    "CONV_2D_STAGE_1",
+    "IM2COL",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
     "UPSCALE",
+    "ARGSORT",
 
     "FLASH_ATTN",
     "FLASH_FF",
@@ -1671,7 +1650,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(WSP_GGML_OP_COUNT ==
+static_assert(WSP_GGML_OP_COUNT == 70, "WSP_GGML_OP_COUNT != 70");
 
 static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "none",
@@ -1700,6 +1679,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "group_norm(x)",
 
     "X*Y",
+    "X[i]*Y",
     "X*Y",
 
     "x*v",
@@ -1721,17 +1701,13 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "rope_back(x)",
     "alibi(x)",
     "clamp(x)",
-    "conv_1d(x)",
-    "conv_1d_stage_0(x)",
-    "conv_1d_stage_1(x)",
     "conv_transpose_1d(x)",
-    "
-    "conv_2d_stage_0(x)",
-    "conv_2d_stage_1(x)",
+    "im2col(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
     "upscale(x)",
+    "argsort(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
@@ -1758,10 +1734,28 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(WSP_GGML_OP_COUNT ==
+static_assert(WSP_GGML_OP_COUNT == 70, "WSP_GGML_OP_COUNT != 70");
 
 static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");
 
+
+static const char * WSP_GGML_UNARY_OP_NAME[WSP_GGML_UNARY_OP_COUNT] = {
+    "ABS",
+    "SGN",
+    "NEG",
+    "STEP",
+    "TANH",
+    "ELU",
+    "RELU",
+    "GELU",
+    "GELU_QUICK",
+    "SILU",
+    "LEAKY",
+};
+
+static_assert(WSP_GGML_UNARY_OP_COUNT == 11, "WSP_GGML_UNARY_OP_COUNT != 11");
+
+
 static_assert(sizeof(struct wsp_ggml_object)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_object size must be a multiple of WSP_GGML_MEM_ALIGN");
 static_assert(sizeof(struct wsp_ggml_tensor)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_tensor size must be a multiple of WSP_GGML_MEM_ALIGN");
 
@@ -1781,18 +1775,13 @@ static void wsp_ggml_setup_op_has_task_pass(void) {
 
         p[WSP_GGML_OP_ACC ] = true;
         p[WSP_GGML_OP_MUL_MAT ] = true;
+        p[WSP_GGML_OP_MUL_MAT_ID ] = true;
         p[WSP_GGML_OP_OUT_PROD ] = true;
         p[WSP_GGML_OP_SET ] = true;
         p[WSP_GGML_OP_GET_ROWS_BACK ] = true;
         p[WSP_GGML_OP_DIAG_MASK_INF ] = true;
         p[WSP_GGML_OP_DIAG_MASK_ZERO ] = true;
-        p[WSP_GGML_OP_CONV_1D ] = true;
-        p[WSP_GGML_OP_CONV_1D_STAGE_0 ] = true;
-        p[WSP_GGML_OP_CONV_1D_STAGE_1 ] = true;
         p[WSP_GGML_OP_CONV_TRANSPOSE_1D ] = true;
-        p[WSP_GGML_OP_CONV_2D ] = true;
-        p[WSP_GGML_OP_CONV_2D_STAGE_0 ] = true;
-        p[WSP_GGML_OP_CONV_2D_STAGE_1 ] = true;
         p[WSP_GGML_OP_CONV_TRANSPOSE_2D ] = true;
         p[WSP_GGML_OP_FLASH_ATTN_BACK ] = true;
         p[WSP_GGML_OP_CROSS_ENTROPY_LOSS ] = true;
@@ -2039,6 +2028,20 @@ const char * wsp_ggml_op_symbol(enum wsp_ggml_op op) {
     return WSP_GGML_OP_SYMBOL[op];
 }
 
+const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op) {
+    return WSP_GGML_UNARY_OP_NAME[op];
+}
+
+const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t) {
+    if (t->op == WSP_GGML_OP_UNARY) {
+        enum wsp_ggml_unary_op uop = wsp_ggml_get_unary_op(t);
+        return wsp_ggml_unary_op_name(uop);
+    }
+    else {
+        return wsp_ggml_op_name(t->op);
+    }
+}
+
 size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor) {
     return wsp_ggml_type_size(tensor->type);
 }
@@ -3170,9 +3173,7 @@ static struct wsp_ggml_tensor * wsp_ggml_add_impl(
         struct wsp_ggml_tensor * a,
         struct wsp_ggml_tensor * b,
         bool inplace) {
-
-    // WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
-    WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(b, a));
+    WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
 
     bool is_node = false;
 
@@ -3387,9 +3388,7 @@ static struct wsp_ggml_tensor * wsp_ggml_mul_impl(
         struct wsp_ggml_tensor * a,
         struct wsp_ggml_tensor * b,
         bool inplace) {
-
-    // WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
-    WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(b, a));
+    WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
 
     bool is_node = false;
 
@@ -3434,7 +3433,7 @@ static struct wsp_ggml_tensor * wsp_ggml_div_impl(
         struct wsp_ggml_tensor * a,
         struct wsp_ggml_tensor * b,
         bool inplace) {
-    WSP_GGML_ASSERT(
+    WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
 
     bool is_node = false;
 
@@ -4072,6 +4071,49 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat(
     return result;
 }
 
+// wsp_ggml_mul_mat_id
+
+struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * as[],
+        struct wsp_ggml_tensor * ids,
+        int id,
+        struct wsp_ggml_tensor * b) {
+
+    int64_t n_as = ids->ne[0];
+
+    WSP_GGML_ASSERT(ids->type == WSP_GGML_TYPE_I32);
+    WSP_GGML_ASSERT(wsp_ggml_is_vector(ids));
+    WSP_GGML_ASSERT(n_as > 0 && n_as <= WSP_GGML_MAX_SRC - 2);
+    WSP_GGML_ASSERT(id >= 0 && id < n_as);
+
+    bool is_node = false;
+
+    if (as[0]->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
+
+    wsp_ggml_set_op_params_i32(result, 0, id);
+
+    result->op   = WSP_GGML_OP_MUL_MAT_ID;
+    result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = ids;
+    result->src[1] = b;
+
+    for (int64_t i = 0; i < n_as; i++) {
+        struct wsp_ggml_tensor * a = as[i];
+        WSP_GGML_ASSERT(wsp_ggml_are_same_shape(as[0], a));
+        WSP_GGML_ASSERT(wsp_ggml_can_mul_mat(a, b));
+        WSP_GGML_ASSERT(!wsp_ggml_is_transposed(a));
+        result->src[i + 2] = a;
+    }
+
+    return result;
+}
+
 // wsp_ggml_out_prod
 
 struct wsp_ggml_tensor * wsp_ggml_out_prod(
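
The new `wsp_ggml_mul_mat_id` builds a matrix product that selects one of several candidate weight matrices through an I32 id tensor at graph-build time. A minimal call sketch, not part of the diff, assuming the wsp_-prefixed ggml context helpers bundled by this package behave like upstream ggml:

```c
// Hypothetical usage sketch (illustration only): multiply b by the matrix
// chosen by entry `id` of the ids vector out of n_as same-shape candidates.
static struct wsp_ggml_tensor * pick_and_mul(
        struct wsp_ggml_context * ctx,
        struct wsp_ggml_tensor * as[],  // n_as candidate matrices, same shape
        struct wsp_ggml_tensor * ids,   // I32 vector of length n_as
        int id,
        struct wsp_ggml_tensor * b) {
    // Result shape is [as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3]] per the new op.
    return wsp_ggml_mul_mat_id(ctx, as, ids, id, b);
}
```
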
@@ -4225,7 +4267,7 @@ struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace(
         struct wsp_ggml_tensor * b,
         size_t nb1,
         size_t offset) {
-    return wsp_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset,
+    return wsp_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
 }
 
 // wsp_ggml_cpy
@@ -4842,7 +4884,17 @@ struct wsp_ggml_tensor * wsp_ggml_diag_mask_zero_inplace(
 static struct wsp_ggml_tensor * wsp_ggml_soft_max_impl(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * mask,
+        float scale,
         bool inplace) {
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a));
+    if (mask) {
+        WSP_GGML_ASSERT(wsp_ggml_is_contiguous(mask));
+        WSP_GGML_ASSERT(mask->ne[2] == 1);
+        WSP_GGML_ASSERT(mask->ne[3] == 1);
+        WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(mask, a));
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -4851,9 +4903,13 @@ static struct wsp_ggml_tensor * wsp_ggml_soft_max_impl(
 
     struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
 
+    float params[] = { scale };
+    wsp_ggml_set_op_params(result, params, sizeof(params));
+
     result->op   = WSP_GGML_OP_SOFT_MAX;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
+    result->src[1] = mask;
 
     return result;
 }
@@ -4861,13 +4917,21 @@ static struct wsp_ggml_tensor * wsp_ggml_soft_max_impl(
 struct wsp_ggml_tensor * wsp_ggml_soft_max(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a) {
-    return wsp_ggml_soft_max_impl(ctx, a, false);
+    return wsp_ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
 }
 
 struct wsp_ggml_tensor * wsp_ggml_soft_max_inplace(
         struct wsp_ggml_context * ctx,
         struct wsp_ggml_tensor * a) {
-    return wsp_ggml_soft_max_impl(ctx, a, true);
+    return wsp_ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
+}
+
+struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * mask,
+        float scale) {
+    return wsp_ggml_soft_max_impl(ctx, a, mask, scale, false);
 }
 
 // wsp_ggml_soft_max_back
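
`wsp_ggml_soft_max_ext` folds the scale factor and an additive mask into the softmax op itself. A hedged sketch of the intended attention-style call, not part of the diff, with the mask broadcast across rows as the new asserts in `wsp_ggml_soft_max_impl` require:

```c
// Hypothetical usage (illustration only): softmax(scores * kq_scale + mask) in one op.
static struct wsp_ggml_tensor * masked_softmax(
        struct wsp_ggml_context * ctx,
        struct wsp_ggml_tensor * scores,   // F32 attention logits
        struct wsp_ggml_tensor * mask,     // additive mask, broadcast over rows (may be NULL)
        float kq_scale) {
    return wsp_ggml_soft_max_ext(ctx, scores, mask, kq_scale);
}
```
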
@@ -5040,8 +5104,13 @@ struct wsp_ggml_tensor * wsp_ggml_rope_back(
         int n_dims,
         int mode,
         int n_ctx,
+        int n_orig_ctx,
         float freq_base,
         float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow,
         float xpos_base,
         bool xpos_down) {
     WSP_GGML_ASSERT(wsp_ggml_is_vector(b));
@@ -5058,11 +5127,15 @@ struct wsp_ggml_tensor * wsp_ggml_rope_back(
 
     struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
 
-    int32_t params[
-    memcpy(params +
-    memcpy(params +
-    memcpy(params +
-    memcpy(params +
+    int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+    memcpy(params +  5, &freq_base,   sizeof(float));
+    memcpy(params +  6, &freq_scale,  sizeof(float));
+    memcpy(params +  7, &ext_factor,  sizeof(float));
+    memcpy(params +  8, &attn_factor, sizeof(float));
+    memcpy(params +  9, &beta_fast,   sizeof(float));
+    memcpy(params + 10, &beta_slow,   sizeof(float));
+    memcpy(params + 11, &xpos_base,   sizeof(float));
+    memcpy(params + 12, &xpos_down,   sizeof(bool));
     wsp_ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = WSP_GGML_OP_ROPE_BACK;
@@ -5137,82 +5210,6 @@ static int64_t wsp_ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, in
     return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
 }
 
-// im2col: [N, IC, IL] => [N, OL, IC*K]
-// a: [OC,IC, K]
-// b: [N, IC, IL]
-// result: [N, OL, IC*K]
-static struct wsp_ggml_tensor * wsp_ggml_conv_1d_stage_0(
-    struct wsp_ggml_context * ctx,
-    struct wsp_ggml_tensor * a,
-    struct wsp_ggml_tensor * b,
-    int s0,
-    int p0,
-    int d0) {
-    WSP_GGML_ASSERT(a->ne[1] == b->ne[1]);
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        WSP_GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t OL = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
-
-    const int64_t ne[4] = {
-        a->ne[1] * a->ne[0],
-        OL,
-        b->ne[2],
-        1,
-    };
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne);
-
-    int32_t params[] = { s0, p0, d0 };
-    wsp_ggml_set_op_params(result, params, sizeof(params));
-
-    result->op = WSP_GGML_OP_CONV_1D_STAGE_0;
-    result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-// wsp_ggml_conv_1d_stage_1
-
-// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
-// a: [OC, IC, K]
-// b: [N, OL, IC * K]
-// result: [N, OC, OL]
-static struct wsp_ggml_tensor * wsp_ggml_conv_1d_stage_1(
-    struct wsp_ggml_context * ctx,
-    struct wsp_ggml_tensor * a,
-    struct wsp_ggml_tensor * b) {
-
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        WSP_GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t ne[4] = {
-        b->ne[1],
-        a->ne[2],
-        b->ne[2],
-        1,
-    };
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
-
-    result->op = WSP_GGML_OP_CONV_1D_STAGE_1;
-    result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-// wsp_ggml_conv_1d
-
 WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
     struct wsp_ggml_context * ctx,
     struct wsp_ggml_tensor * a,
@@ -5220,43 +5217,17 @@ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
     int s0,
     int p0,
     int d0) {
-    struct wsp_ggml_tensor *
-    result = wsp_ggml_conv_1d_stage_1(ctx, a, result);
-    return result;
-}
+    struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
 
-
-
-//
-//
-//        int s0,
-//        int p0,
-//        int d0) {
-//    WSP_GGML_ASSERT(wsp_ggml_is_matrix(b));
-//    WSP_GGML_ASSERT(a->ne[1] == b->ne[1]);
-//    bool is_node = false;
+    struct wsp_ggml_tensor * result =
+        wsp_ggml_mul_mat(ctx,
+                wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
+                wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC,IC, K] => [OC, IC * K]
 
-
-//        WSP_GGML_ASSERT(false); // TODO: implement backward
-//        is_node = true;
-//    }
+    result = wsp_ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
 
-
-
-//        a->ne[2], 1, 1,
-//    };
-//    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 2, ne);
-
-//    int32_t params[] = { s0, p0, d0 };
-//    wsp_ggml_set_op_params(result, params, sizeof(params));
-
-//    result->op = WSP_GGML_OP_CONV_1D;
-//    result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
-//    result->src[0] = a;
-//    result->src[1] = b;
-
-//    return result;
-// }
+    return result;
+}
 
 // wsp_ggml_conv_1d_ph
 
@@ -5319,7 +5290,7 @@ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OH, OW, IC*KH*KW]
-
+struct wsp_ggml_tensor * wsp_ggml_im2col(
     struct wsp_ggml_context * ctx,
     struct wsp_ggml_tensor * a,
     struct wsp_ggml_tensor * b,
@@ -5328,9 +5299,14 @@ static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_0(
     int p0,
     int p1,
     int d0,
-    int d1
+    int d1,
+    bool is_2D) {
 
-
+    if(is_2D) {
+        WSP_GGML_ASSERT(a->ne[2] == b->ne[2]);
+    } else {
+        WSP_GGML_ASSERT(a->ne[1] == b->ne[1]);
+    }
     bool is_node = false;
 
     if (a->grad || b->grad) {
@@ -5338,81 +5314,51 @@ static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_0(
         is_node = true;
     }
 
-    const int64_t OH = wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
-    const int64_t OW =
+    const int64_t OH = is_2D ? wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
+    const int64_t OW =         wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
 
     const int64_t ne[4] = {
-        a->ne[2] * a->ne[1] * a->ne[0],
+        is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
         OW,
-        OH,
-        b->ne[3],
+        is_2D ? OH : b->ne[2],
+        is_2D ? b->ne[3] : 1,
     };
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne);
 
-
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne);
+    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
     wsp_ggml_set_op_params(result, params, sizeof(params));
 
-    result->op =
-    result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-
-}
-
-// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-// a: [OC, IC, KH, KW]
-// b: [N, OH, OW, IC * KH * KW]
-// result: [N, OC, OH, OW]
-static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_1(
-    struct wsp_ggml_context * ctx,
-    struct wsp_ggml_tensor * a,
-    struct wsp_ggml_tensor * b) {
-
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        WSP_GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t ne[4] = {
-        b->ne[1],
-        b->ne[2],
-        a->ne[3],
-        b->ne[3],
-    };
-    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
-
-    result->op = WSP_GGML_OP_CONV_2D_STAGE_1;
+    result->op = WSP_GGML_OP_IM2COL;
     result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
 
     return result;
-
 }
 
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OC, OH, OW]
 struct wsp_ggml_tensor * wsp_ggml_conv_2d(
-
-
-
-
-
-
-
-
-
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        struct wsp_ggml_tensor * b,
+        int s0,
+        int s1,
+        int p0,
+        int p1,
+        int d0,
+        int d1) {
+    struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
 
-    struct wsp_ggml_tensor * result =
-
+    struct wsp_ggml_tensor * result =
+        wsp_ggml_mul_mat(ctx,
+                wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
+                wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3]));                       // [OC,IC, KH, KW] => [OC, IC * KH * KW]
 
-
+    result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]
 
+    return result;
 }
 
 // wsp_ggml_conv_2d_sk_p0
@@ -5580,6 +5526,43 @@ struct wsp_ggml_tensor * wsp_ggml_upscale(
     return wsp_ggml_upscale_impl(ctx, a, scale_factor);
 }
 
+// wsp_ggml_argsort
+
+struct wsp_ggml_tensor * wsp_ggml_argsort(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        enum wsp_ggml_sort_order order) {
+    bool is_node = false;
+
+    struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_I32, a->n_dims, a->ne);
+
+    wsp_ggml_set_op_params_i32(result, 0, (int32_t) order);
+
+    result->op   = WSP_GGML_OP_ARGSORT;
+    result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+// wsp_ggml_top_k
+
+struct wsp_ggml_tensor * wsp_ggml_top_k(
+        struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * a,
+        int k) {
+    WSP_GGML_ASSERT(a->ne[0] >= k);
+
+    struct wsp_ggml_tensor * result = wsp_ggml_argsort(ctx, a, WSP_GGML_SORT_DESC);
+
+    result = wsp_ggml_view_4d(ctx, result,
+                k, result->ne[1], result->ne[2], result->ne[3],
+                   result->nb[1], result->nb[2], result->nb[3],
+                0);
+
+    return result;
+}
+
 // wsp_ggml_flash_attn
 
 struct wsp_ggml_tensor * wsp_ggml_flash_attn(
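
The new `wsp_ggml_top_k` is built directly on `wsp_ggml_argsort`: it sorts descending and returns a k-wide view of the resulting index tensor. A small sketch, not part of the diff, assuming the sort-order enum members match those used in the hunk above:

```c
// Hypothetical usage (illustration only): indices of the k largest values per row.
static struct wsp_ggml_tensor * top_k_indices(
        struct wsp_ggml_context * ctx,
        struct wsp_ggml_tensor * logits, // F32, requires logits->ne[0] >= k
        int k) {
    // Returns an I32 tensor whose first dimension is k; it is a view into the
    // full argsort result, so its strides follow the parent tensor.
    return wsp_ggml_top_k(ctx, logits, k);
}
```
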
@@ -6472,7 +6455,7 @@ static void wsp_ggml_compute_forward_dup_f16(
             }
         }
     } else if (type_traits[dst->type].from_float) {
-        wsp_ggml_from_float_t const
+        wsp_ggml_from_float_t const wsp_quantize_row_q = type_traits[dst->type].from_float;
         float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
         size_t id = 0;
@@ -6489,7 +6472,7 @@ static void wsp_ggml_compute_forward_dup_f16(
                     src0_f32[i00] = WSP_GGML_FP16_TO_FP32(src0_ptr[i00]);
                 }
 
-
+                wsp_quantize_row_q(src0_f32, dst_ptr + id, ne00);
                 id += rs;
             }
             id += rs * (ne01 - ir1);
@@ -6725,7 +6708,7 @@ static void wsp_ggml_compute_forward_dup_f32(
             }
         }
     } else if (type_traits[dst->type].from_float) {
-        wsp_ggml_from_float_t const
+        wsp_ggml_from_float_t const wsp_quantize_row_q = type_traits[dst->type].from_float;
 
         size_t id = 0;
         size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
@@ -6736,7 +6719,7 @@ static void wsp_ggml_compute_forward_dup_f32(
             id += rs * ir0;
             for (int i01 = ir0; i01 < ir1; i01++) {
                 const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
+                wsp_quantize_row_q(src0_ptr, dst_ptr + id, ne00);
                 id += rs;
             }
             id += rs * (ne01 - ir1);
@@ -6939,7 +6922,7 @@ static void wsp_ggml_compute_forward_add_f32(
         const struct wsp_ggml_tensor * src0,
         const struct wsp_ggml_tensor * src1,
         struct wsp_ggml_tensor * dst) {
-    WSP_GGML_ASSERT(
+    WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
 
     if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
         return;
@@ -6972,16 +6955,19 @@ static void wsp_ggml_compute_forward_add_f32(
             const int64_t i13 = i03 % ne13;
             const int64_t i12 = i02 % ne12;
             const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
 
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
             float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
+            for (int64_t r = 0; r < nr0; ++r) {
 #ifdef WSP_GGML_USE_ACCELERATE
-
+                vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-
+                wsp_ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
+            }
         }
     } else {
         // src1 is not contiguous
@@ -6998,8 +6984,9 @@ static void wsp_ggml_compute_forward_add_f32(
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-            for (
-
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
             }
@@ -7158,8 +7145,8 @@ static void wsp_ggml_compute_forward_add_q_f32(
 
     const enum wsp_ggml_type type = src0->type;
     const enum wsp_ggml_type dtype = dst->type;
-    wsp_ggml_to_float_t const
-    wsp_ggml_from_float_t const
+    wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
+    wsp_ggml_from_float_t const wsp_quantize_row_q = type_traits[dtype].from_float;
 
     // we don't support permuted src0 or src1
     WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
@@ -7204,12 +7191,12 @@ static void wsp_ggml_compute_forward_add_q_f32(
         assert(ne00 % 32 == 0);
 
         // unquantize row from src0 to temp buffer
-
+        wsp_dewsp_quantize_row_q(src0_row, wdata, ne00);
         // add src1
         wsp_ggml_vec_acc_f32(ne00, wdata, src1_row);
         // quantize row to dst
-        if (
-
+        if (wsp_quantize_row_q != NULL) {
+            wsp_quantize_row_q(wdata, dst_row, ne00);
         } else {
             memcpy(dst_row, wdata, ne0*nb0);
         }
@@ -7435,8 +7422,8 @@ static void wsp_ggml_compute_forward_add1_q_f32(
     WSP_GGML_TENSOR_UNARY_OP_LOCALS
 
     const enum wsp_ggml_type type = src0->type;
-    wsp_ggml_to_float_t const
-    wsp_ggml_from_float_t const
+    wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
+    wsp_ggml_from_float_t const wsp_quantize_row_q = type_traits[type].from_float;
 
     // we don't support permuted src0
     WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
@@ -7471,11 +7458,11 @@ static void wsp_ggml_compute_forward_add1_q_f32(
         assert(ne0 % 32 == 0);
 
         // unquantize row from src0 to temp buffer
-
+        wsp_dewsp_quantize_row_q(src0_row, wdata, ne0);
         // add src1
         wsp_ggml_vec_acc1_f32(ne0, wdata, v);
         // quantize row to dst
-
+        wsp_quantize_row_q(wdata, dst_row, ne0);
     }
 }
 
@@ -7719,7 +7706,7 @@ static void wsp_ggml_compute_forward_mul_f32(
         const struct wsp_ggml_tensor * src0,
         const struct wsp_ggml_tensor * src1,
         struct wsp_ggml_tensor * dst) {
-    WSP_GGML_ASSERT(
+    WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
 
     if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
         return;
@@ -7742,7 +7729,6 @@ static void wsp_ggml_compute_forward_mul_f32(
 
     WSP_GGML_ASSERT( nb0 == sizeof(float));
     WSP_GGML_ASSERT(nb00 == sizeof(float));
-    WSP_GGML_ASSERT(ne00 == ne10);
 
     if (nb10 == sizeof(float)) {
         for (int64_t ir = ith; ir < nr; ir += nth) {
@@ -7754,20 +7740,21 @@ static void wsp_ggml_compute_forward_mul_f32(
             const int64_t i13 = i03 % ne13;
             const int64_t i12 = i02 % ne12;
             const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
 
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
             float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
+            for (int64_t r = 0 ; r < nr0; ++r) {
 #ifdef WSP_GGML_USE_ACCELERATE
-
+                UNUSED(wsp_ggml_vec_mul_f32);
 
-
+                vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-
+                wsp_ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
-
-                // }
+            }
         }
     } else {
         // src1 is not contiguous
@@ -7785,8 +7772,9 @@ static void wsp_ggml_compute_forward_mul_f32(
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-            for (int64_t i0 = 0; i0 < ne00; i0
-
+            for (int64_t i0 = 0; i0 < ne00; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
             }
@@ -7820,14 +7808,16 @@ static void wsp_ggml_compute_forward_div_f32(
         const struct wsp_ggml_tensor * src0,
         const struct wsp_ggml_tensor * src1,
         struct wsp_ggml_tensor * dst) {
-
-    assert(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst));
+    WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
 
     if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = wsp_ggml_nrows(src0);
 
     WSP_GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -7835,41 +7825,50 @@ static void wsp_ggml_compute_forward_div_f32(
     WSP_GGML_ASSERT(nb00 == sizeof(float));
 
     if (nb10 == sizeof(float)) {
-        for (
-            // src0
-            const
-            const
-            const
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
+            for (int64_t r = 0; r < nr0; ++r) {
 #ifdef WSP_GGML_USE_ACCELERATE
-
+                UNUSED(wsp_ggml_vec_div_f32);
 
-
-                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                ne0);
+                vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-
-                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+                wsp_ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
-
-                // }
+            }
         }
     } else {
         // src1 is not contiguous
-        for (
-            // src0
-
-            const
-            const
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-
-
-
-
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
             }
@@ -8315,7 +8314,7 @@ static void wsp_ggml_compute_forward_repeat_f16(
         return;
     }
 
-    WSP_GGML_TENSOR_UNARY_OP_LOCALS
+    WSP_GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in wsp_ggml_can_repeat
     const int nr0 = (int)(ne0/ne00);
@@ -8460,6 +8459,7 @@ static void wsp_ggml_compute_forward_concat_f32(
     WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
+    const int nth = params->nth;
 
     WSP_GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -8469,7 +8469,7 @@ static void wsp_ggml_compute_forward_concat_f32(
     WSP_GGML_ASSERT(nb10 == sizeof(float));
 
     for (int i3 = 0; i3 < ne3; i3++) {
-        for (int i2 = ith; i2 < ne2; i2
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
             if (i2 < ne02) { // src0
                 for (int i1 = 0; i1 < ne1; i1++) {
                     for (int i0 = 0; i0 < ne0; i0++) {
@@ -9507,6 +9507,8 @@ static bool wsp_ggml_compute_forward_mul_mat_use_blas(
     // TODO: find the optimal values for these
     if (wsp_ggml_is_contiguous(src0) &&
         wsp_ggml_is_contiguous(src1) &&
+        //src0->type == WSP_GGML_TYPE_F32 &&
+        src1->type == WSP_GGML_TYPE_F32 &&
         (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
 
         /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
@@ -9545,7 +9547,7 @@ static void wsp_ggml_compute_forward_mul_mat(
 
     // we don't support permuted src0 or src1
     WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
-    WSP_GGML_ASSERT(nb10 ==
+    WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(src1->type));
 
     // dst cannot be transposed or permuted
     WSP_GGML_ASSERT(nb0 == sizeof(float));
@@ -9627,6 +9629,8 @@ static void wsp_ggml_compute_forward_mul_mat(
         char * wdata = params->wdata;
         const size_t row_size = ne10*wsp_ggml_type_size(vec_dot_type)/wsp_ggml_blck_size(vec_dot_type);
 
+        assert(params->wsize >= ne11*ne12*ne13*row_size);
+
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = 0; i11 < ne11; ++i11) {
@@ -9728,6 +9732,26 @@ static void wsp_ggml_compute_forward_mul_mat(
     }
 }
 
+// wsp_ggml_compute_forward_mul_mat_id
+
+static void wsp_ggml_compute_forward_mul_mat_id(
+        const struct wsp_ggml_compute_params * params,
+        struct wsp_ggml_tensor * dst) {
+
+    const struct wsp_ggml_tensor * ids = dst->src[0];
+    const struct wsp_ggml_tensor * src1 = dst->src[1];
+
+    const int id = wsp_ggml_get_op_params_i32(dst, 0);
+
+    const int a_id = ((int32_t *)ids->data)[id];
+
+    WSP_GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+
+    const struct wsp_ggml_tensor * src0 = dst->src[a_id + 2];
+
+    wsp_ggml_compute_forward_mul_mat(params, src0, src1, dst);
+}
+
 // wsp_ggml_compute_forward_out_prod
 
 static void wsp_ggml_compute_forward_out_prod_f32(
@@ -9743,10 +9767,12 @@ static void wsp_ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
+    WSP_GGML_ASSERT(ne0  == ne00);
+    WSP_GGML_ASSERT(ne1  == ne10);
+    WSP_GGML_ASSERT(ne2  == ne02);
     WSP_GGML_ASSERT(ne02 == ne12);
-    WSP_GGML_ASSERT(ne03 == ne13);
-    WSP_GGML_ASSERT(ne2  == ne12);
     WSP_GGML_ASSERT(ne3  == ne13);
+    WSP_GGML_ASSERT(ne03 == ne13);
 
     // we don't support permuted src0 or src1
     WSP_GGML_ASSERT(nb00 == sizeof(float));
@@ -9757,18 +9783,25 @@ static void wsp_ggml_compute_forward_out_prod_f32(
     // WSP_GGML_ASSERT(nb1 <= nb2);
     // WSP_GGML_ASSERT(nb2 <= nb3);
 
-    WSP_GGML_ASSERT(ne0 == ne00);
-    WSP_GGML_ASSERT(ne1 == ne10);
-    WSP_GGML_ASSERT(ne2 == ne02);
-    WSP_GGML_ASSERT(ne3 == ne03);
-
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows
 
     // TODO: #if defined(WSP_GGML_USE_CUBLAS) wsp_ggml_cuda_out_prod
-    // TODO: #if defined(
+    // TODO: #if defined(WSP_GGML_USE_CLBLAST)
+
+#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
+    bool use_blas = wsp_ggml_is_matrix(src0) &&
+        wsp_ggml_is_matrix(src1) &&
+        wsp_ggml_is_contiguous(src0) &&
+        (wsp_ggml_is_contiguous(src1) || wsp_ggml_is_transposed(src1));
+#endif
 
     if (params->type == WSP_GGML_TASK_INIT) {
+#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) // gemm beta will zero dst
+        if (use_blas) {
+            return;
+        }
+#endif
         wsp_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
         return;
     }
@@ -9777,6 +9810,50 @@ static void wsp_ggml_compute_forward_out_prod_f32(
         return;
     }
 
+#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
+    if (use_blas) {
+        if (params->ith != 0) { // All threads other than the first do no work.
+            return;
+        }
+        // Arguments to wsp_ggml_compute_forward_out_prod (expressed as major,minor)
+        // src0: (k,n)
+        // src1: (k,m)
+        // dst:  (m,n)
+        //
+        // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+        // Also expressed as (major,minor)
+        // a: (m,k): so src1 transposed
+        // b: (k,n): so src0
+        // c: (m,n)
+        //
+        // However, if wsp_ggml_is_transposed(src1) is true, then
+        // src1->data already contains a transposed version, so sgemm mustn't
+        // transpose it further.
+
+        int n = src0->ne[0];
+        int k = src0->ne[1];
+        int m = src1->ne[0];
+
+        int transposeA, lda;
+
+        if (!wsp_ggml_is_transposed(src1)) {
+            transposeA = CblasTrans;
+            lda = m;
+        } else {
+            transposeA = CblasNoTrans;
+            lda = k;
+        }
+
+        float * a = (float *) ((char *) src1->data);
+        float * b = (float *) ((char *) src0->data);
+        float * c = (float *) ((char *) dst->data);
+
+        cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+        return;
+    }
+#endif
+
     // dst[:,:,:,:] = 0
     // for i2,i3:
     //   for i1:
@@ -9880,7 +9957,7 @@ static void wsp_ggml_compute_forward_out_prod_q_f32(
     const int nth = params->nth;
 
     const enum wsp_ggml_type type = src0->type;
-    wsp_ggml_to_float_t const
+    wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
 
     WSP_GGML_ASSERT(ne02 == ne12);
     WSP_GGML_ASSERT(ne03 == ne13);
@@ -9957,7 +10034,7 @@ static void wsp_ggml_compute_forward_out_prod_q_f32(
         float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
         float * d  = (float *) ((char *)  dst->data + (         i1*nb1  +  i2*nb2  +  i3*nb3));
 
-
+        wsp_dewsp_quantize_row_q(s0, wdata, ne0);
         wsp_ggml_vec_mad_f32(ne0, d, wdata, *s1);
     }
 }
@@ -10251,7 +10328,7 @@ static void wsp_ggml_compute_forward_get_rows_q(
     const int nc = src0->ne[0];
     const int nr = wsp_ggml_nelements(src1);
     const enum wsp_ggml_type type = src0->type;
-    wsp_ggml_to_float_t const
+    wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
 
     assert( dst->ne[0] == nc);
     assert( dst->ne[1] == nr);
@@ -10260,7 +10337,7 @@ static void wsp_ggml_compute_forward_get_rows_q(
     for (int i = 0; i < nr; ++i) {
         const int r = ((int32_t *) src1->data)[i];
 
-
+        wsp_dewsp_quantize_row_q(
                 (const void *) ((char *) src0->data + r*src0->nb[1]),
                      (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
     }
@@ -10630,20 +10707,25 @@ static void wsp_ggml_compute_forward_diag_mask_zero(
|
|
|
10630
10707
|
static void wsp_ggml_compute_forward_soft_max_f32(
|
|
10631
10708
|
const struct wsp_ggml_compute_params * params,
|
|
10632
10709
|
const struct wsp_ggml_tensor * src0,
|
|
10633
|
-
struct wsp_ggml_tensor *
|
|
10634
|
-
|
|
10635
|
-
|
|
10636
|
-
|
|
10710
|
+
const struct wsp_ggml_tensor * src1,
|
|
10711
|
+
struct wsp_ggml_tensor * dst) {
|
|
10712
|
+
assert(wsp_ggml_is_contiguous(dst));
|
|
10713
|
+
assert(wsp_ggml_are_same_shape(src0, dst));
|
|
10637
10714
|
|
|
10638
10715
|
if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
|
|
10639
10716
|
return;
|
|
10640
10717
|
}
|
|
10641
10718
|
|
|
10719
|
+
float scale = 1.0f;
|
|
10720
|
+
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
|
10721
|
+
|
|
10642
10722
|
// TODO: handle transposed/permuted matrices
|
|
10643
10723
|
|
|
10644
10724
|
const int ith = params->ith;
|
|
10645
10725
|
const int nth = params->nth;
|
|
10646
10726
|
|
|
10727
|
+
const int64_t ne11 = src1 ? src1->ne[1] : 1;
|
|
10728
|
+
|
|
10647
10729
|
const int nc = src0->ne[0];
|
|
10648
10730
|
const int nr = wsp_ggml_nrows(src0);
|
|
10649
10731
|
|
|
@@ -10654,29 +10736,40 @@ static void wsp_ggml_compute_forward_soft_max_f32(
|
|
|
10654
10736
|
const int ir0 = dr*ith;
|
|
10655
10737
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
10656
10738
|
|
|
10739
|
+
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
|
10740
|
+
|
|
10657
10741
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
10658
|
-
float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
|
10659
|
-
float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
|
10742
|
+
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
|
10743
|
+
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
|
10744
|
+
|
|
10745
|
+
// broadcast the mask across rows
|
|
10746
|
+
float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
|
|
10747
|
+
|
|
10748
|
+
wsp_ggml_vec_cpy_f32 (nc, wp, sp);
|
|
10749
|
+
wsp_ggml_vec_scale_f32(nc, wp, scale);
|
|
10750
|
+
if (mp) {
|
|
10751
|
+
wsp_ggml_vec_acc_f32(nc, wp, mp);
|
|
10752
|
+
}
|
|
10660
10753
|
|
|
10661
10754
|
#ifndef NDEBUG
|
|
10662
10755
|
for (int i = 0; i < nc; ++i) {
|
|
10663
10756
|
//printf("p[%d] = %f\n", i, p[i]);
|
|
10664
|
-
assert(!isnan(
|
|
10757
|
+
assert(!isnan(wp[i]));
|
|
10665
10758
|
}
|
|
10666
10759
|
#endif
|
|
10667
10760
|
|
|
10668
10761
|
float max = -INFINITY;
|
|
10669
|
-
wsp_ggml_vec_max_f32(nc, &max,
|
|
10762
|
+
wsp_ggml_vec_max_f32(nc, &max, wp);
|
|
10670
10763
|
|
|
10671
10764
|
wsp_ggml_float sum = 0.0;
|
|
10672
10765
|
|
|
10673
10766
|
uint16_t scvt;
|
|
10674
10767
|
for (int i = 0; i < nc; i++) {
|
|
10675
|
-
if (
|
|
10768
|
+
if (wp[i] == -INFINITY) {
|
|
10676
10769
|
dp[i] = 0.0f;
|
|
10677
10770
|
} else {
|
|
10678
|
-
// const float val = (
|
|
10679
|
-
wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(
|
|
10771
|
+
// const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
|
|
10772
|
+
wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(wp[i] - max);
|
|
10680
10773
|
memcpy(&scvt, &s, sizeof(scvt));
|
|
10681
10774
|
const float val = WSP_GGML_FP16_TO_FP32(wsp_ggml_table_exp_f16[scvt]);
|
|
10682
10775
|
sum += (wsp_ggml_float)val;
|
|
@@ -10701,11 +10794,12 @@ static void wsp_ggml_compute_forward_soft_max_f32(
|
|
|
10701
10794
|
static void wsp_ggml_compute_forward_soft_max(
|
|
10702
10795
|
const struct wsp_ggml_compute_params * params,
|
|
10703
10796
|
const struct wsp_ggml_tensor * src0,
|
|
10704
|
-
struct wsp_ggml_tensor *
|
|
10797
|
+
const struct wsp_ggml_tensor * src1,
|
|
10798
|
+
struct wsp_ggml_tensor * dst) {
|
|
10705
10799
|
switch (src0->type) {
|
|
10706
10800
|
case WSP_GGML_TYPE_F32:
|
|
10707
10801
|
{
|
|
10708
|
-
wsp_ggml_compute_forward_soft_max_f32(params, src0, dst);
|
|
10802
|
+
wsp_ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
|
|
10709
10803
|
} break;
|
|
10710
10804
|
default:
|
|
10711
10805
|
{
|
|
@@ -11086,7 +11180,8 @@ static void wsp_ggml_compute_forward_rope_f32(
|
|
|
11086
11180
|
const struct wsp_ggml_compute_params * params,
|
|
11087
11181
|
const struct wsp_ggml_tensor * src0,
|
|
11088
11182
|
const struct wsp_ggml_tensor * src1,
|
|
11089
|
-
struct wsp_ggml_tensor * dst
|
|
11183
|
+
struct wsp_ggml_tensor * dst,
|
|
11184
|
+
const bool forward) {
|
|
11090
11185
|
if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
|
|
11091
11186
|
return;
|
|
11092
11187
|
}
|
|
@@ -11145,6 +11240,11 @@ static void wsp_ggml_compute_forward_rope_f32(
|
|
|
11145
11240
|
const bool is_neox = mode & 2;
|
|
11146
11241
|
const bool is_glm = mode & 4;
|
|
11147
11242
|
|
|
11243
|
+
// backward process uses inverse rotation by cos and sin.
|
|
11244
|
+
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
|
11245
|
+
// this essentially just switches the sign of sin.
|
|
11246
|
+
const float sin_sign = forward ? 1.0f : -1.0f;
|
|
11247
|
+
|
|
11148
11248
|
const int32_t * pos = (const int32_t *) src1->data;
|
|
11149
11249
|
|
|
11150
11250
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
@@ -11161,9 +11261,9 @@ static void wsp_ggml_compute_forward_rope_f32(
|
|
|
11161
11261
|
float block_theta = MAX(p - (n_ctx - 2), 0);
|
|
11162
11262
|
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
|
11163
11263
|
const float cos_theta = cosf(theta_base);
|
|
11164
|
-
const float sin_theta = sinf(theta_base);
|
|
11264
|
+
const float sin_theta = sinf(theta_base) * sin_sign;
|
|
11165
11265
|
const float cos_block_theta = cosf(block_theta);
|
|
11166
|
-
const float sin_block_theta = sinf(block_theta);
|
|
11266
|
+
const float sin_block_theta = sinf(block_theta) * sin_sign;
|
|
11167
11267
|
|
|
11168
11268
|
theta_base *= theta_scale;
|
|
11169
11269
|
block_theta *= theta_scale;
|
|
@@ -11187,6 +11287,7 @@ static void wsp_ggml_compute_forward_rope_f32(
|
|
|
11187
11287
|
rope_yarn(
|
|
11188
11288
|
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
|
|
11189
11289
|
);
|
|
11290
|
+
sin_theta *= sin_sign;
|
|
11190
11291
|
|
|
11191
11292
|
// zeta scaling for xPos only:
|
|
11192
11293
|
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
|
@@ -11217,6 +11318,7 @@ static void wsp_ggml_compute_forward_rope_f32(
|
|
|
11217
11318
|
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
|
11218
11319
|
&cos_theta, &sin_theta
|
|
11219
11320
|
);
|
|
11321
|
+
sin_theta *= sin_sign;
|
|
11220
11322
|
|
|
11221
11323
|
theta_base *= theta_scale;
|
|
11222
11324
|
|
|
@@ -11242,7 +11344,8 @@ static void wsp_ggml_compute_forward_rope_f16(
|
|
|
11242
11344
|
const struct wsp_ggml_compute_params * params,
|
|
11243
11345
|
const struct wsp_ggml_tensor * src0,
|
|
11244
11346
|
const struct wsp_ggml_tensor * src1,
|
|
11245
|
-
struct wsp_ggml_tensor * dst
|
|
11347
|
+
struct wsp_ggml_tensor * dst,
|
|
11348
|
+
const bool forward) {
|
|
11246
11349
|
if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
|
|
11247
11350
|
return;
|
|
11248
11351
|
}
|
|
@@ -11294,6 +11397,11 @@ static void wsp_ggml_compute_forward_rope_f16(
|
|
|
11294
11397
|
const bool is_neox = mode & 2;
|
|
11295
11398
|
const bool is_glm = mode & 4;
|
|
11296
11399
|
|
|
11400
|
+
// backward process uses inverse rotation by cos and sin.
|
|
11401
|
+
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
|
11402
|
+
// this essentially just switches the sign of sin.
|
|
11403
|
+
const float sin_sign = forward ? 1.0f : -1.0f;
|
|
11404
|
+
|
|
11297
11405
|
const int32_t * pos = (const int32_t *) src1->data;
|
|
11298
11406
|
|
|
11299
11407
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
@@ -11310,9 +11418,9 @@ static void wsp_ggml_compute_forward_rope_f16(
|
|
|
11310
11418
|
float block_theta = MAX(p - (n_ctx - 2), 0);
|
|
11311
11419
|
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
|
11312
11420
|
const float cos_theta = cosf(theta_base);
|
|
11313
|
-
const float sin_theta = sinf(theta_base);
|
|
11421
|
+
const float sin_theta = sinf(theta_base) * sin_sign;
|
|
11314
11422
|
const float cos_block_theta = cosf(block_theta);
|
|
11315
|
-
const float sin_block_theta = sinf(block_theta);
|
|
11423
|
+
const float sin_block_theta = sinf(block_theta) * sin_sign;
|
|
11316
11424
|
|
|
11317
11425
|
theta_base *= theta_scale;
|
|
11318
11426
|
block_theta *= theta_scale;
|
|
@@ -11336,6 +11444,7 @@ static void wsp_ggml_compute_forward_rope_f16(
|
|
|
11336
11444
|
rope_yarn(
|
|
11337
11445
|
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
|
|
11338
11446
|
);
|
|
11447
|
+
sin_theta *= sin_sign;
|
|
11339
11448
|
|
|
11340
11449
|
theta_base *= theta_scale;
|
|
11341
11450
|
|
|
@@ -11362,6 +11471,7 @@ static void wsp_ggml_compute_forward_rope_f16(
|
|
|
11362
11471
|
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
|
11363
11472
|
&cos_theta, &sin_theta
|
|
11364
11473
|
);
|
|
11474
|
+
sin_theta *= sin_sign;
|
|
11365
11475
|
|
|
11366
11476
|
theta_base *= theta_scale;
|
|
11367
11477
|
|
|
@@ -11391,11 +11501,11 @@ static void wsp_ggml_compute_forward_rope(
|
|
|
11391
11501
|
switch (src0->type) {
|
|
11392
11502
|
case WSP_GGML_TYPE_F16:
|
|
11393
11503
|
{
|
|
11394
|
-
wsp_ggml_compute_forward_rope_f16(params, src0, src1, dst);
|
|
11504
|
+
wsp_ggml_compute_forward_rope_f16(params, src0, src1, dst, true);
|
|
11395
11505
|
} break;
|
|
11396
11506
|
case WSP_GGML_TYPE_F32:
|
|
11397
11507
|
{
|
|
11398
|
-
wsp_ggml_compute_forward_rope_f32(params, src0, src1, dst);
|
|
11508
|
+
wsp_ggml_compute_forward_rope_f32(params, src0, src1, dst, true);
|
|
11399
11509
|
} break;
|
|
11400
11510
|
default:
|
|
11401
11511
|
{
|
|
@@ -11406,216 +11516,6 @@ static void wsp_ggml_compute_forward_rope(
|
|
|
11406
11516
|
|
|
11407
11517
|
// wsp_ggml_compute_forward_rope_back
|
|
11408
11518
|
|
|
11409
|
-
static void wsp_ggml_compute_forward_rope_back_f32(
|
|
11410
|
-
const struct wsp_ggml_compute_params * params,
|
|
11411
|
-
const struct wsp_ggml_tensor * src0,
|
|
11412
|
-
const struct wsp_ggml_tensor * src1,
|
|
11413
|
-
struct wsp_ggml_tensor * dst) {
|
|
11414
|
-
|
|
11415
|
-
if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
|
|
11416
|
-
return;
|
|
11417
|
-
}
|
|
11418
|
-
|
|
11419
|
-
// y = rope(x, src1)
|
|
11420
|
-
// dx = rope_back(dy, src1)
|
|
11421
|
-
// src0 is dy, src1 contains options
|
|
11422
|
-
|
|
11423
|
-
float freq_base;
|
|
11424
|
-
float freq_scale;
|
|
11425
|
-
|
|
11426
|
-
// these two only relevant for xPos RoPE:
|
|
11427
|
-
float xpos_base;
|
|
11428
|
-
bool xpos_down;
|
|
11429
|
-
|
|
11430
|
-
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
11431
|
-
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
11432
|
-
const int mode = ((int32_t *) dst->op_params)[2];
|
|
11433
|
-
const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
|
|
11434
|
-
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
|
11435
|
-
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
11436
|
-
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
11437
|
-
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
|
|
11438
|
-
|
|
11439
|
-
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
11440
|
-
|
|
11441
|
-
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
|
11442
|
-
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
|
11443
|
-
|
|
11444
|
-
assert(nb0 == sizeof(float));
|
|
11445
|
-
|
|
11446
|
-
const int ith = params->ith;
|
|
11447
|
-
const int nth = params->nth;
|
|
11448
|
-
|
|
11449
|
-
const int nr = wsp_ggml_nrows(dst);
|
|
11450
|
-
|
|
11451
|
-
// rows per thread
|
|
11452
|
-
const int dr = (nr + nth - 1)/nth;
|
|
11453
|
-
|
|
11454
|
-
// row range for this thread
|
|
11455
|
-
const int ir0 = dr*ith;
|
|
11456
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
11457
|
-
|
|
11458
|
-
// row index used to determine which thread to use
|
|
11459
|
-
int ir = 0;
|
|
11460
|
-
|
|
11461
|
-
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
11462
|
-
|
|
11463
|
-
const bool is_neox = mode & 2;
|
|
11464
|
-
|
|
11465
|
-
const int32_t * pos = (const int32_t *) src1->data;
|
|
11466
|
-
|
|
11467
|
-
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
11468
|
-
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
|
11469
|
-
const int64_t p = pos[i2];
|
|
11470
|
-
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
11471
|
-
if (ir++ < ir0) continue;
|
|
11472
|
-
if (ir > ir1) break;
|
|
11473
|
-
|
|
11474
|
-
float theta_base = freq_scale * (float)p;
|
|
11475
|
-
|
|
11476
|
-
if (!is_neox) {
|
|
11477
|
-
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
|
11478
|
-
const float cos_theta = cosf(theta_base);
|
|
11479
|
-
const float sin_theta = sinf(theta_base);
|
|
11480
|
-
|
|
11481
|
-
// zeta scaling for xPos only:
|
|
11482
|
-
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
|
11483
|
-
if (xpos_down) zeta = 1.0f / zeta;
|
|
11484
|
-
|
|
11485
|
-
theta_base *= theta_scale;
|
|
11486
|
-
|
|
11487
|
-
const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
11488
|
-
float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
11489
|
-
|
|
11490
|
-
const float dy0 = dy[0];
|
|
11491
|
-
const float dy1 = dy[1];
|
|
11492
|
-
|
|
11493
|
-
dx[0] = dy0*cos_theta*zeta + dy1*sin_theta*zeta;
|
|
11494
|
-
dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
|
|
11495
|
-
}
|
|
11496
|
-
} else {
|
|
11497
|
-
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
|
11498
|
-
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
|
11499
|
-
const float cos_theta = cosf(theta_base);
|
|
11500
|
-
const float sin_theta = sinf(theta_base);
|
|
11501
|
-
|
|
11502
|
-
theta_base *= theta_scale;
|
|
11503
|
-
|
|
11504
|
-
const int64_t i0 = ib*n_dims + ic/2;
|
|
11505
|
-
|
|
11506
|
-
const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
11507
|
-
float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
11508
|
-
|
|
11509
|
-
const float dy0 = dy[0];
|
|
11510
|
-
const float dy1 = dy[n_dims/2];
|
|
11511
|
-
|
|
11512
|
-
dx[0] = dy0*cos_theta + dy1*sin_theta;
|
|
11513
|
-
dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
|
|
11514
|
-
}
|
|
11515
|
-
}
|
|
11516
|
-
}
|
|
11517
|
-
}
|
|
11518
|
-
}
|
|
11519
|
-
}
|
|
11520
|
-
}
|
|
11521
|
-
|
|
11522
|
-
static void wsp_ggml_compute_forward_rope_back_f16(
|
|
11523
|
-
const struct wsp_ggml_compute_params * params,
|
|
11524
|
-
const struct wsp_ggml_tensor * src0,
|
|
11525
|
-
const struct wsp_ggml_tensor * src1,
|
|
11526
|
-
struct wsp_ggml_tensor * dst) {
|
|
11527
|
-
|
|
11528
|
-
if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
|
|
11529
|
-
return;
|
|
11530
|
-
}
|
|
11531
|
-
|
|
11532
|
-
// y = rope(x, src1)
|
|
11533
|
-
// dx = rope_back(dy, src1)
|
|
11534
|
-
// src0 is dy, src1 contains options
|
|
11535
|
-
|
|
11536
|
-
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
11537
|
-
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
11538
|
-
const int mode = ((int32_t *) dst->op_params)[2];
|
|
11539
|
-
|
|
11540
|
-
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
11541
|
-
|
|
11542
|
-
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
|
11543
|
-
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
|
11544
|
-
|
|
11545
|
-
assert(nb0 == sizeof(wsp_ggml_fp16_t));
|
|
11546
|
-
|
|
11547
|
-
const int ith = params->ith;
|
|
11548
|
-
const int nth = params->nth;
|
|
11549
|
-
|
|
11550
|
-
const int nr = wsp_ggml_nrows(dst);
|
|
11551
|
-
|
|
11552
|
-
// rows per thread
|
|
11553
|
-
const int dr = (nr + nth - 1)/nth;
|
|
11554
|
-
|
|
11555
|
-
// row range for this thread
|
|
11556
|
-
const int ir0 = dr*ith;
|
|
11557
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
11558
|
-
|
|
11559
|
-
// row index used to determine which thread to use
|
|
11560
|
-
int ir = 0;
|
|
11561
|
-
|
|
11562
|
-
const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
|
11563
|
-
|
|
11564
|
-
const bool is_neox = mode & 2;
|
|
11565
|
-
|
|
11566
|
-
const int32_t * pos = (const int32_t *) src1->data;
|
|
11567
|
-
|
|
11568
|
-
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
11569
|
-
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
|
11570
|
-
const int64_t p = pos[i2];
|
|
11571
|
-
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
11572
|
-
if (ir++ < ir0) continue;
|
|
11573
|
-
if (ir > ir1) break;
|
|
11574
|
-
|
|
11575
|
-
float theta_base = (float)p;
|
|
11576
|
-
|
|
11577
|
-
if (!is_neox) {
|
|
11578
|
-
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
|
11579
|
-
const float cos_theta = cosf(theta_base);
|
|
11580
|
-
const float sin_theta = sinf(theta_base);
|
|
11581
|
-
|
|
11582
|
-
theta_base *= theta_scale;
|
|
11583
|
-
|
|
11584
|
-
const wsp_ggml_fp16_t * const dy = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
11585
|
-
wsp_ggml_fp16_t * dx = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
11586
|
-
|
|
11587
|
-
const float dy0 = WSP_GGML_FP16_TO_FP32(dy[0]);
|
|
11588
|
-
const float dy1 = WSP_GGML_FP16_TO_FP32(dy[1]);
|
|
11589
|
-
|
|
11590
|
-
dx[0] = WSP_GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
|
|
11591
|
-
dx[1] = WSP_GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
|
|
11592
|
-
}
|
|
11593
|
-
} else {
|
|
11594
|
-
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
|
11595
|
-
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
|
11596
|
-
const float cos_theta = cosf(theta_base);
|
|
11597
|
-
const float sin_theta = sinf(theta_base);
|
|
11598
|
-
|
|
11599
|
-
theta_base *= theta_scale;
|
|
11600
|
-
|
|
11601
|
-
const int64_t i0 = ib*n_dims + ic/2;
|
|
11602
|
-
|
|
11603
|
-
const wsp_ggml_fp16_t * const dy = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
11604
|
-
wsp_ggml_fp16_t * dx = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
11605
|
-
|
|
11606
|
-
const float dy0 = WSP_GGML_FP16_TO_FP32(dy[0]);
|
|
11607
|
-
const float dy1 = WSP_GGML_FP16_TO_FP32(dy[n_dims/2]);
|
|
11608
|
-
|
|
11609
|
-
dx[0] = WSP_GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
|
|
11610
|
-
dx[n_dims/2] = WSP_GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
|
|
11611
|
-
}
|
|
11612
|
-
}
|
|
11613
|
-
}
|
|
11614
|
-
}
|
|
11615
|
-
}
|
|
11616
|
-
}
|
|
11617
|
-
}
|
|
11618
|
-
|
|
11619
11519
|
static void wsp_ggml_compute_forward_rope_back(
|
|
11620
11520
|
const struct wsp_ggml_compute_params * params,
|
|
11621
11521
|
const struct wsp_ggml_tensor * src0,
|
|
@@ -11624,11 +11524,11 @@ static void wsp_ggml_compute_forward_rope_back(
|
|
|
11624
11524
|
switch (src0->type) {
|
|
11625
11525
|
case WSP_GGML_TYPE_F16:
|
|
11626
11526
|
{
|
|
11627
|
-
|
|
11527
|
+
wsp_ggml_compute_forward_rope_f16(params, src0, src1, dst, false);
|
|
11628
11528
|
} break;
|
|
11629
11529
|
case WSP_GGML_TYPE_F32:
|
|
11630
11530
|
{
|
|
11631
|
-
|
|
11531
|
+
wsp_ggml_compute_forward_rope_f32(params, src0, src1, dst, false);
|
|
11632
11532
|
} break;
|
|
11633
11533
|
default:
|
|
11634
11534
|
{
|
|
@@ -11637,9 +11537,9 @@ static void wsp_ggml_compute_forward_rope_back(
|
|
|
11637
11537
|
}
|
|
11638
11538
|
}
|
|
11639
11539
|
|
|
11640
|
-
//
|
|
11540
|
+
// wsp_ggml_compute_forward_conv_transpose_1d
|
|
11641
11541
|
|
|
11642
|
-
static void
|
|
11542
|
+
static void wsp_ggml_compute_forward_conv_transpose_1d_f16_f32(
|
|
11643
11543
|
const struct wsp_ggml_compute_params * params,
|
|
11644
11544
|
const struct wsp_ggml_tensor * src0,
|
|
11645
11545
|
const struct wsp_ggml_tensor * src1,
|
|
@@ -11656,14 +11556,7 @@ static void wsp_ggml_compute_forward_conv_1d_f16_f32(
|
|
|
11656
11556
|
const int ith = params->ith;
|
|
11657
11557
|
const int nth = params->nth;
|
|
11658
11558
|
|
|
11659
|
-
const int nk = ne00;
|
|
11660
|
-
|
|
11661
|
-
// size of the convolution row - the kernel size unrolled across all input channels
|
|
11662
|
-
const int ew0 = nk*ne01;
|
|
11663
|
-
|
|
11664
|
-
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
|
11665
|
-
const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
|
|
11666
|
-
const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
|
|
11559
|
+
const int nk = ne00*ne01*ne02;
|
|
11667
11560
|
|
|
11668
11561
|
WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
|
|
11669
11562
|
WSP_GGML_ASSERT(nb10 == sizeof(float));
|
|
@@ -11671,23 +11564,37 @@ static void wsp_ggml_compute_forward_conv_1d_f16_f32(
|
|
|
11671
11564
|
if (params->type == WSP_GGML_TASK_INIT) {
|
|
11672
11565
|
memset(params->wdata, 0, params->wsize);
|
|
11673
11566
|
|
|
11674
|
-
|
|
11567
|
+
// permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
|
|
11568
|
+
{
|
|
11569
|
+
wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
|
|
11675
11570
|
|
|
11676
|
-
|
|
11677
|
-
|
|
11678
|
-
|
|
11571
|
+
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
11572
|
+
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
|
11573
|
+
const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
|
|
11574
|
+
wsp_ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
|
|
11575
|
+
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
11576
|
+
dst_data[i00*ne02 + i02] = src[i00];
|
|
11577
|
+
}
|
|
11578
|
+
}
|
|
11579
|
+
}
|
|
11580
|
+
}
|
|
11679
11581
|
|
|
11680
|
-
|
|
11681
|
-
|
|
11682
|
-
|
|
11582
|
+
// permute source data (src1) from (L x Cin) to (Cin x L)
|
|
11583
|
+
{
|
|
11584
|
+
wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + nk;
|
|
11585
|
+
wsp_ggml_fp16_t * dst_data = wdata;
|
|
11683
11586
|
|
|
11684
|
-
|
|
11685
|
-
|
|
11686
|
-
|
|
11587
|
+
for (int64_t i11 = 0; i11 < ne11; i11++) {
|
|
11588
|
+
const float * const src = (float *)((char *) src1->data + i11*nb11);
|
|
11589
|
+
for (int64_t i10 = 0; i10 < ne10; i10++) {
|
|
11590
|
+
dst_data[i10*ne11 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]);
|
|
11687
11591
|
}
|
|
11688
11592
|
}
|
|
11689
11593
|
}
|
|
11690
11594
|
|
|
11595
|
+
// need to zero dst since we are accumulating into it
|
|
11596
|
+
memset(dst->data, 0, wsp_ggml_nbytes(dst));
|
|
11597
|
+
|
|
11691
11598
|
return;
|
|
11692
11599
|
}
|
|
11693
11600
|
|
|
@@ -11695,424 +11602,7 @@ static void wsp_ggml_compute_forward_conv_1d_f16_f32(
|
|
|
11695
11602
|
return;
|
|
11696
11603
|
}
|
|
11697
11604
|
|
|
11698
|
-
|
|
11699
|
-
const int nr = ne2;
|
|
11700
|
-
|
|
11701
|
-
// rows per thread
|
|
11702
|
-
const int dr = (nr + nth - 1)/nth;
|
|
11703
|
-
|
|
11704
|
-
// row range for this thread
|
|
11705
|
-
const int ir0 = dr*ith;
|
|
11706
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
11707
|
-
|
|
11708
|
-
wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
|
|
11709
|
-
|
|
11710
|
-
for (int i2 = 0; i2 < ne2; i2++) {
|
|
11711
|
-
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
11712
|
-
float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
|
|
11713
|
-
|
|
11714
|
-
for (int i0 = 0; i0 < ne0; i0++) {
|
|
11715
|
-
wsp_ggml_vec_dot_f16(ew0, dst_data + i0,
|
|
11716
|
-
(wsp_ggml_fp16_t *) ((char *) src0->data + i1*nb02),
|
|
11717
|
-
(wsp_ggml_fp16_t *) wdata + i2*nb2 + i0*ew0);
|
|
11718
|
-
}
|
|
11719
|
-
}
|
|
11720
|
-
}
|
|
11721
|
-
}
|
|
11722
|
-
|
|
11723
|
-
static void wsp_ggml_compute_forward_conv_1d_f32(
|
|
11724
|
-
const struct wsp_ggml_compute_params * params,
|
|
11725
|
-
const struct wsp_ggml_tensor * src0,
|
|
11726
|
-
const struct wsp_ggml_tensor * src1,
|
|
11727
|
-
struct wsp_ggml_tensor * dst) {
|
|
11728
|
-
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
11729
|
-
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
11730
|
-
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
|
|
11731
|
-
|
|
11732
|
-
int64_t t0 = wsp_ggml_perf_time_us();
|
|
11733
|
-
UNUSED(t0);
|
|
11734
|
-
|
|
11735
|
-
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
11736
|
-
|
|
11737
|
-
const int ith = params->ith;
|
|
11738
|
-
const int nth = params->nth;
|
|
11739
|
-
|
|
11740
|
-
const int nk = ne00;
|
|
11741
|
-
|
|
11742
|
-
const int ew0 = nk*ne01;
|
|
11743
|
-
|
|
11744
|
-
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
|
11745
|
-
const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
|
|
11746
|
-
const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
|
|
11747
|
-
|
|
11748
|
-
WSP_GGML_ASSERT(nb00 == sizeof(float));
|
|
11749
|
-
WSP_GGML_ASSERT(nb10 == sizeof(float));
|
|
11750
|
-
|
|
11751
|
-
if (params->type == WSP_GGML_TASK_INIT) {
|
|
11752
|
-
memset(params->wdata, 0, params->wsize);
|
|
11753
|
-
|
|
11754
|
-
float * const wdata = (float *) params->wdata + 0;
|
|
11755
|
-
|
|
11756
|
-
for (int64_t i11 = 0; i11 < ne11; i11++) {
|
|
11757
|
-
const float * const src = (float *)((char *) src1->data + i11*nb11);
|
|
11758
|
-
float * dst_data = wdata;
|
|
11759
|
-
|
|
11760
|
-
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
|
11761
|
-
for (int64_t ik = 0; ik < nk; ik++) {
|
|
11762
|
-
const int idx0 = i0*s0 + ik*d0 - p0;
|
|
11763
|
-
|
|
11764
|
-
if(!(idx0 < 0 || idx0 >= ne10)) {
|
|
11765
|
-
dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
|
|
11766
|
-
}
|
|
11767
|
-
}
|
|
11768
|
-
}
|
|
11769
|
-
}
|
|
11770
|
-
|
|
11771
|
-
return;
|
|
11772
|
-
}
|
|
11773
|
-
|
|
11774
|
-
if (params->type == WSP_GGML_TASK_FINALIZE) {
|
|
11775
|
-
return;
|
|
11776
|
-
}
|
|
11777
|
-
|
|
11778
|
-
// total rows in dst
|
|
11779
|
-
const int nr = ne02;
|
|
11780
|
-
|
|
11781
|
-
// rows per thread
|
|
11782
|
-
const int dr = (nr + nth - 1)/nth;
|
|
11783
|
-
|
|
11784
|
-
// row range for this thread
|
|
11785
|
-
const int ir0 = dr*ith;
|
|
11786
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
11787
|
-
|
|
11788
|
-
float * const wdata = (float *) params->wdata + 0;
|
|
11789
|
-
|
|
11790
|
-
for (int i2 = 0; i2 < ne2; i2++) {
|
|
11791
|
-
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
11792
|
-
float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
|
|
11793
|
-
|
|
11794
|
-
for (int i0 = 0; i0 < ne0; i0++) {
|
|
11795
|
-
wsp_ggml_vec_dot_f32(ew0, dst_data + i0,
|
|
11796
|
-
(float *) ((char *) src0->data + i1*nb02),
|
|
11797
|
-
(float *) wdata + i2*nb2 + i0*ew0);
|
|
11798
|
-
}
|
|
11799
|
-
}
|
|
11800
|
-
}
|
|
11801
|
-
}
|
|
11802
|
-
|
|
11803
|
-
// TODO: reuse wsp_ggml_mul_mat or implement wsp_ggml_im2col and remove stage_0 and stage_1
|
|
11804
|
-
static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
|
|
11805
|
-
wsp_ggml_fp16_t * A,
|
|
11806
|
-
wsp_ggml_fp16_t * B,
|
|
11807
|
-
float * C,
|
|
11808
|
-
const int ith, const int nth) {
|
|
11809
|
-
// does not seem to make a difference
|
|
11810
|
-
int64_t m0, m1, n0, n1;
|
|
11811
|
-
// patches per thread
|
|
11812
|
-
if (m > n) {
|
|
11813
|
-
n0 = 0;
|
|
11814
|
-
n1 = n;
|
|
11815
|
-
|
|
11816
|
-
// total patches in dst
|
|
11817
|
-
const int np = m;
|
|
11818
|
-
|
|
11819
|
-
// patches per thread
|
|
11820
|
-
const int dp = (np + nth - 1)/nth;
|
|
11821
|
-
|
|
11822
|
-
// patch range for this thread
|
|
11823
|
-
m0 = dp*ith;
|
|
11824
|
-
m1 = MIN(m0 + dp, np);
|
|
11825
|
-
} else {
|
|
11826
|
-
m0 = 0;
|
|
11827
|
-
m1 = m;
|
|
11828
|
-
|
|
11829
|
-
// total patches in dst
|
|
11830
|
-
const int np = n;
|
|
11831
|
-
|
|
11832
|
-
// patches per thread
|
|
11833
|
-
const int dp = (np + nth - 1)/nth;
|
|
11834
|
-
|
|
11835
|
-
// patch range for this thread
|
|
11836
|
-
n0 = dp*ith;
|
|
11837
|
-
n1 = MIN(n0 + dp, np);
|
|
11838
|
-
}
|
|
11839
|
-
|
|
11840
|
-
// block-tiling attempt
|
|
11841
|
-
int64_t blck_n = 16;
|
|
11842
|
-
int64_t blck_m = 16;
|
|
11843
|
-
|
|
11844
|
-
// int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
|
|
11845
|
-
// int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(wsp_ggml_fp16_t) * K);
|
|
11846
|
-
// if (blck_size > 0) {
|
|
11847
|
-
// blck_0 = 4;
|
|
11848
|
-
// blck_1 = blck_size / blck_0;
|
|
11849
|
-
// if (blck_1 < 0) {
|
|
11850
|
-
// blck_1 = 1;
|
|
11851
|
-
// }
|
|
11852
|
-
// // blck_0 = (int64_t)sqrt(blck_size);
|
|
11853
|
-
// // blck_1 = blck_0;
|
|
11854
|
-
// }
|
|
11855
|
-
// // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
|
|
11856
|
-
|
|
11857
|
-
for (int j = n0; j < n1; j+=blck_n) {
|
|
11858
|
-
for (int i = m0; i < m1; i+=blck_m) {
|
|
11859
|
-
// printf("i j k => %d %d %d\n", i, j, K);
|
|
11860
|
-
for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
|
|
11861
|
-
for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
|
|
11862
|
-
wsp_ggml_vec_dot_f16(k,
|
|
11863
|
-
C + ii*n + jj,
|
|
11864
|
-
A + ii * k,
|
|
11865
|
-
B + jj * k);
|
|
11866
|
-
}
|
|
11867
|
-
}
|
|
11868
|
-
}
|
|
11869
|
-
}
|
|
11870
|
-
}
|
|
11871
|
-
|
|
11872
|
-
// src0: kernel [OC, IC, K]
|
|
11873
|
-
// src1: signal [N, IC, IL]
|
|
11874
|
-
// dst: result [N, OL, IC*K]
|
|
11875
|
-
static void wsp_ggml_compute_forward_conv_1d_stage_0_f32(
|
|
11876
|
-
const struct wsp_ggml_compute_params * params,
|
|
11877
|
-
const struct wsp_ggml_tensor * src0,
|
|
11878
|
-
const struct wsp_ggml_tensor * src1,
|
|
11879
|
-
struct wsp_ggml_tensor * dst) {
|
|
11880
|
-
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
|
|
11881
|
-
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
11882
|
-
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
|
|
11883
|
-
|
|
11884
|
-
int64_t t0 = wsp_ggml_perf_time_us();
|
|
11885
|
-
UNUSED(t0);
|
|
11886
|
-
|
|
11887
|
-
WSP_GGML_TENSOR_BINARY_OP_LOCALS;
|
|
11888
|
-
|
|
11889
|
-
const int64_t N = ne12;
|
|
11890
|
-
const int64_t IC = ne11;
|
|
11891
|
-
const int64_t IL = ne10;
|
|
11892
|
-
|
|
11893
|
-
const int64_t K = ne00;
|
|
11894
|
-
|
|
11895
|
-
const int64_t OL = ne1;
|
|
11896
|
-
|
|
11897
|
-
const int ith = params->ith;
|
|
11898
|
-
const int nth = params->nth;
|
|
11899
|
-
|
|
11900
|
-
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
|
11901
|
-
const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
|
|
11902
|
-
const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
|
|
11903
|
-
|
|
11904
|
-
WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
|
|
11905
|
-
WSP_GGML_ASSERT(nb10 == sizeof(float));
|
|
11906
|
-
|
|
11907
|
-
if (params->type == WSP_GGML_TASK_INIT) {
|
|
11908
|
-
memset(dst->data, 0, wsp_ggml_nbytes(dst));
|
|
11909
|
-
return;
|
|
11910
|
-
}
|
|
11911
|
-
|
|
11912
|
-
if (params->type == WSP_GGML_TASK_FINALIZE) {
|
|
11913
|
-
return;
|
|
11914
|
-
}
|
|
11915
|
-
|
|
11916
|
-
// im2col: [N, IC, IL] => [N, OL, IC*K]
|
|
11917
|
-
{
|
|
11918
|
-
wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data;
|
|
11919
|
-
|
|
11920
|
-
for (int64_t in = 0; in < N; in++) {
|
|
11921
|
-
for (int64_t iol = 0; iol < OL; iol++) {
|
|
11922
|
-
for (int64_t iic = ith; iic < IC; iic+=nth) {
|
|
11923
|
-
|
|
11924
|
-
// micro kernel
|
|
11925
|
-
wsp_ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
|
|
11926
|
-
const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
|
|
11927
|
-
|
|
11928
|
-
for (int64_t ik = 0; ik < K; ik++) {
|
|
11929
|
-
const int64_t iil = iol*s0 + ik*d0 - p0;
|
|
11930
|
-
|
|
11931
|
-
if (!(iil < 0 || iil >= IL)) {
|
|
11932
|
-
dst_data[iic*K + ik] = WSP_GGML_FP32_TO_FP16(src_data[iil]);
|
|
11933
|
-
}
|
|
11934
|
-
}
|
|
11935
|
-
}
|
|
11936
|
-
}
|
|
11937
|
-
}
|
|
11938
|
-
}
|
|
11939
|
-
}
|
|
11940
|
-
|
|
11941
|
-
// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
|
|
11942
|
-
// src0: [OC, IC, K]
|
|
11943
|
-
// src1: [N, OL, IC * K]
|
|
11944
|
-
// result: [N, OC, OL]
|
|
11945
|
-
static void wsp_ggml_compute_forward_conv_1d_stage_1_f16(
|
|
11946
|
-
const struct wsp_ggml_compute_params * params,
|
|
11947
|
-
const struct wsp_ggml_tensor * src0,
|
|
11948
|
-
const struct wsp_ggml_tensor * src1,
|
|
11949
|
-
struct wsp_ggml_tensor * dst) {
|
|
11950
|
-
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
|
|
11951
|
-
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16);
|
|
11952
|
-
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
|
|
11953
|
-
|
|
11954
|
-
int64_t t0 = wsp_ggml_perf_time_us();
|
|
11955
|
-
UNUSED(t0);
|
|
11956
|
-
|
|
11957
|
-
if (params->type == WSP_GGML_TASK_INIT) {
|
|
11958
|
-
return;
|
|
11959
|
-
}
|
|
11960
|
-
|
|
11961
|
-
if (params->type == WSP_GGML_TASK_FINALIZE) {
|
|
11962
|
-
return;
|
|
11963
|
-
}
|
|
11964
|
-
|
|
11965
|
-
WSP_GGML_TENSOR_BINARY_OP_LOCALS;
|
|
11966
|
-
|
|
11967
|
-
WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
|
|
11968
|
-
WSP_GGML_ASSERT(nb10 == sizeof(wsp_ggml_fp16_t));
|
|
11969
|
-
WSP_GGML_ASSERT(nb0 == sizeof(float));
|
|
11970
|
-
|
|
11971
|
-
const int N = ne12;
|
|
11972
|
-
const int OL = ne11;
|
|
11973
|
-
|
|
11974
|
-
const int OC = ne02;
|
|
11975
|
-
const int IC = ne01;
|
|
11976
|
-
const int K = ne00;
|
|
11977
|
-
|
|
11978
|
-
const int ith = params->ith;
|
|
11979
|
-
const int nth = params->nth;
|
|
11980
|
-
|
|
11981
|
-
int64_t m = OC;
|
|
11982
|
-
int64_t n = OL;
|
|
11983
|
-
int64_t k = IC * K;
|
|
11984
|
-
|
|
11985
|
-
// [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
|
|
11986
|
-
for (int i = 0; i < N; i++) {
|
|
11987
|
-
wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k]
|
|
11988
|
-
wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)src1->data + i * m * k; // [n, k]
|
|
11989
|
-
float * C = (float *)dst->data + i * m * n; // [m, n]
|
|
11990
|
-
|
|
11991
|
-
gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
|
|
11992
|
-
}
|
|
11993
|
-
}
|
|
11994
|
-
|
|
11995
|
-
static void wsp_ggml_compute_forward_conv_1d(
|
|
11996
|
-
const struct wsp_ggml_compute_params * params,
|
|
11997
|
-
const struct wsp_ggml_tensor * src0,
|
|
11998
|
-
const struct wsp_ggml_tensor * src1,
|
|
11999
|
-
struct wsp_ggml_tensor * dst) {
|
|
12000
|
-
switch(src0->type) {
|
|
12001
|
-
case WSP_GGML_TYPE_F16:
|
|
12002
|
-
{
|
|
12003
|
-
wsp_ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
|
|
12004
|
-
} break;
|
|
12005
|
-
case WSP_GGML_TYPE_F32:
|
|
12006
|
-
{
|
|
12007
|
-
wsp_ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
|
|
12008
|
-
} break;
|
|
12009
|
-
default:
|
|
12010
|
-
{
|
|
12011
|
-
WSP_GGML_ASSERT(false);
|
|
12012
|
-
} break;
|
|
12013
|
-
}
|
|
12014
|
-
}
|
|
12015
|
-
|
|
12016
|
-
static void wsp_ggml_compute_forward_conv_1d_stage_0(
|
|
12017
|
-
const struct wsp_ggml_compute_params * params,
|
|
12018
|
-
const struct wsp_ggml_tensor * src0,
|
|
12019
|
-
const struct wsp_ggml_tensor * src1,
|
|
12020
|
-
struct wsp_ggml_tensor * dst) {
|
|
12021
|
-
switch(src0->type) {
|
|
12022
|
-
case WSP_GGML_TYPE_F16:
|
|
12023
|
-
{
|
|
12024
|
-
wsp_ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
|
|
12025
|
-
} break;
|
|
12026
|
-
default:
|
|
12027
|
-
{
|
|
12028
|
-
WSP_GGML_ASSERT(false);
|
|
12029
|
-
} break;
|
|
12030
|
-
}
|
|
12031
|
-
}
|
|
12032
|
-
|
|
12033
|
-
static void wsp_ggml_compute_forward_conv_1d_stage_1(
|
|
12034
|
-
const struct wsp_ggml_compute_params * params,
|
|
12035
|
-
const struct wsp_ggml_tensor * src0,
|
|
12036
|
-
const struct wsp_ggml_tensor * src1,
|
|
12037
|
-
struct wsp_ggml_tensor * dst) {
|
|
12038
|
-
switch(src0->type) {
|
|
12039
|
-
case WSP_GGML_TYPE_F16:
|
|
12040
|
-
{
|
|
12041
|
-
wsp_ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
|
|
12042
|
-
} break;
|
|
12043
|
-
default:
|
|
12044
|
-
{
|
|
12045
|
-
WSP_GGML_ASSERT(false);
|
|
12046
|
-
} break;
|
|
12047
|
-
}
|
|
12048
|
-
}
|
|
12049
|
-
|
|
12050
|
-
// wsp_ggml_compute_forward_conv_transpose_1d
|
|
12051
|
-
|
|
12052
|
-
static void wsp_ggml_compute_forward_conv_transpose_1d_f16_f32(
|
|
12053
|
-
const struct wsp_ggml_compute_params * params,
|
|
12054
|
-
const struct wsp_ggml_tensor * src0,
|
|
12055
|
-
const struct wsp_ggml_tensor * src1,
|
|
12056
|
-
struct wsp_ggml_tensor * dst) {
|
|
12057
|
-
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
|
|
12058
|
-
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
12059
|
-
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
|
|
12060
|
-
|
|
12061
|
-
int64_t t0 = wsp_ggml_perf_time_us();
|
|
12062
|
-
UNUSED(t0);
|
|
12063
|
-
|
|
12064
|
-
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
12065
|
-
|
|
12066
|
-
const int ith = params->ith;
|
|
12067
|
-
const int nth = params->nth;
|
|
12068
|
-
|
|
12069
|
-
const int nk = ne00*ne01*ne02;
|
|
12070
|
-
|
|
12071
|
-
WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
|
|
12072
|
-
WSP_GGML_ASSERT(nb10 == sizeof(float));
|
|
12073
|
-
|
|
12074
|
-
if (params->type == WSP_GGML_TASK_INIT) {
|
|
12075
|
-
memset(params->wdata, 0, params->wsize);
|
|
12076
|
-
|
|
12077
|
-
// permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
|
|
12078
|
-
{
|
|
12079
|
-
wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
|
|
12080
|
-
|
|
12081
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
12082
|
-
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
|
12083
|
-
const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
|
|
12084
|
-
wsp_ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
|
|
12085
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
12086
|
-
dst_data[i00*ne02 + i02] = src[i00];
|
|
12087
|
-
}
|
|
12088
|
-
}
|
|
12089
|
-
}
|
|
12090
|
-
}
|
|
12091
|
-
|
|
12092
|
-
// permute source data (src1) from (L x Cin) to (Cin x L)
|
|
12093
|
-
{
|
|
12094
|
-
wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + nk;
|
|
12095
|
-
wsp_ggml_fp16_t * dst_data = wdata;
|
|
12096
|
-
|
|
12097
|
-
for (int64_t i11 = 0; i11 < ne11; i11++) {
|
|
12098
|
-
const float * const src = (float *)((char *) src1->data + i11*nb11);
|
|
12099
|
-
for (int64_t i10 = 0; i10 < ne10; i10++) {
|
|
12100
|
-
dst_data[i10*ne11 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]);
|
|
12101
|
-
}
|
|
12102
|
-
}
|
|
12103
|
-
}
|
|
12104
|
-
|
|
12105
|
-
// need to zero dst since we are accumulating into it
|
|
12106
|
-
memset(dst->data, 0, wsp_ggml_nbytes(dst));
|
|
12107
|
-
|
|
12108
|
-
return;
|
|
12109
|
-
}
|
|
12110
|
-
|
|
12111
|
-
if (params->type == WSP_GGML_TASK_FINALIZE) {
|
|
12112
|
-
return;
|
|
12113
|
-
}
|
|
12114
|
-
|
|
12115
|
-
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
|
11605
|
+
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
|
12116
11606
|
|
|
12117
11607
|
// total rows in dst
|
|
12118
11608
|
const int nr = ne1;
|
|
@@ -12258,229 +11748,81 @@ static void wsp_ggml_compute_forward_conv_transpose_1d(
|
|
|
12258
11748
|
}
|
|
12259
11749
|
}
|
|
12260
11750
|
|
|
12261
|
-
// wsp_ggml_compute_forward_conv_2d
|
|
12262
|
-
|
|
12263
11751
|
// src0: kernel [OC, IC, KH, KW]
|
|
12264
11752
|
// src1: image [N, IC, IH, IW]
|
|
12265
11753
|
// dst: result [N, OH, OW, IC*KH*KW]
|
|
12266
|
-
static void
|
|
12267
|
-
const struct wsp_ggml_compute_params * params,
|
|
12268
|
-
const struct wsp_ggml_tensor * src0,
|
|
12269
|
-
const struct wsp_ggml_tensor * src1,
|
|
12270
|
-
struct wsp_ggml_tensor * dst) {
|
|
12271
|
-
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
|
|
12272
|
-
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
12273
|
-
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
|
|
12274
|
-
|
|
12275
|
-
int64_t t0 = wsp_ggml_perf_time_us();
|
|
12276
|
-
UNUSED(t0);
|
|
12277
|
-
|
|
12278
|
-
WSP_GGML_TENSOR_BINARY_OP_LOCALS;
|
|
12279
|
-
|
|
12280
|
-
const int64_t N = ne13;
|
|
12281
|
-
const int64_t IC = ne12;
|
|
12282
|
-
const int64_t IH = ne11;
|
|
12283
|
-
const int64_t IW = ne10;
|
|
12284
|
-
|
|
12285
|
-
// const int64_t OC = ne03;
|
|
12286
|
-
// const int64_t IC = ne02;
|
|
12287
|
-
const int64_t KH = ne01;
|
|
12288
|
-
const int64_t KW = ne00;
|
|
12289
|
-
|
|
12290
|
-
const int64_t OH = ne2;
|
|
12291
|
-
const int64_t OW = ne1;
|
|
12292
|
-
|
|
12293
|
-
const int ith = params->ith;
|
|
12294
|
-
const int nth = params->nth;
|
|
12295
|
-
|
|
12296
|
-
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
|
12297
|
-
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
|
|
12298
|
-
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
|
|
12299
|
-
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
|
|
12300
|
-
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
|
|
12301
|
-
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
|
|
12302
|
-
|
|
12303
|
-
WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
|
|
12304
|
-
WSP_GGML_ASSERT(nb10 == sizeof(float));
|
|
12305
|
-
|
|
12306
|
-
if (params->type == WSP_GGML_TASK_INIT) {
|
|
12307
|
-
memset(dst->data, 0, wsp_ggml_nbytes(dst));
|
|
12308
|
-
return;
|
|
12309
|
-
}
|
|
12310
|
-
|
|
12311
|
-
if (params->type == WSP_GGML_TASK_FINALIZE) {
|
|
12312
|
-
return;
|
|
12313
|
-
}
|
|
12314
|
-
|
|
12315
|
-
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
|
|
12316
|
-
{
|
|
12317
|
-
wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data;
|
|
12318
|
-
|
|
12319
|
-
for (int64_t in = 0; in < N; in++) {
|
|
12320
|
-
for (int64_t ioh = 0; ioh < OH; ioh++) {
|
|
12321
|
-
for (int64_t iow = 0; iow < OW; iow++) {
|
|
12322
|
-
for (int64_t iic = ith; iic < IC; iic+=nth) {
|
|
12323
|
-
|
|
12324
|
-
// micro kernel
|
|
12325
|
-
wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
|
|
12326
|
-
const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
|
|
12327
|
-
|
|
12328
|
-
for (int64_t ikh = 0; ikh < KH; ikh++) {
|
|
12329
|
-
for (int64_t ikw = 0; ikw < KW; ikw++) {
|
|
12330
|
-
const int64_t iiw = iow*s0 + ikw*d0 - p0;
|
|
12331
|
-
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
|
12332
|
-
|
|
12333
|
-
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
|
|
12334
|
-
dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
|
|
12335
|
-
}
|
|
12336
|
-
}
|
|
12337
|
-
}
|
|
12338
|
-
}
|
|
12339
|
-
}
|
|
12340
|
-
}
|
|
12341
|
-
}
|
|
12342
|
-
}
|
|
12343
|
-
}
|
|
12344
|
-
|
|
12345
|
-
// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
|
|
12346
|
-
// src0: [OC, IC, KH, KW]
|
|
12347
|
-
// src1: [N, OH, OW, IC * KH * KW]
|
|
12348
|
-
// result: [N, OC, OH, OW]
|
|
12349
|
-
static void wsp_ggml_compute_forward_conv_2d_stage_1_f16(
|
|
12350
|
-
const struct wsp_ggml_compute_params * params,
|
|
12351
|
-
const struct wsp_ggml_tensor * src0,
|
|
12352
|
-
const struct wsp_ggml_tensor * src1,
|
|
12353
|
-
struct wsp_ggml_tensor * dst) {
|
|
12354
|
-
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
|
|
12355
|
-
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16);
|
|
12356
|
-
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
|
|
12357
|
-
|
|
12358
|
-
int64_t t0 = wsp_ggml_perf_time_us();
|
|
12359
|
-
UNUSED(t0);
|
|
12360
|
-
|
|
12361
|
-
if (params->type == WSP_GGML_TASK_INIT) {
|
|
12362
|
-
return;
|
|
12363
|
-
}
|
|
12364
|
-
|
|
12365
|
-
if (params->type == WSP_GGML_TASK_FINALIZE) {
|
|
12366
|
-
return;
|
|
12367
|
-
}
|
|
12368
|
-
|
|
12369
|
-
WSP_GGML_TENSOR_BINARY_OP_LOCALS;
|
|
12370
|
-
|
|
12371
|
-
WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
|
|
12372
|
-
WSP_GGML_ASSERT(nb10 == sizeof(wsp_ggml_fp16_t));
|
|
12373
|
-
WSP_GGML_ASSERT(nb0 == sizeof(float));
|
|
12374
|
-
|
|
12375
|
-
const int N = ne13;
|
|
12376
|
-
const int OH = ne12;
|
|
12377
|
-
const int OW = ne11;
|
|
12378
|
-
|
|
12379
|
-
const int OC = ne03;
|
|
12380
|
-
const int IC = ne02;
|
|
12381
|
-
const int KH = ne01;
|
|
12382
|
-
const int KW = ne00;
|
|
12383
|
-
|
|
12384
|
-
const int ith = params->ith;
|
|
12385
|
-
const int nth = params->nth;
|
|
12386
|
-
|
|
12387
|
-
int64_t m = OC;
|
|
12388
|
-
int64_t n = OH * OW;
|
|
12389
|
-
int64_t k = IC * KH * KW;
|
|
12390
|
-
|
|
12391
|
-
// [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
|
|
12392
|
-
for (int i = 0; i < N; i++) {
|
|
12393
|
-
wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k]
|
|
12394
|
-
wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)src1->data + i * m * k; // [n, k]
|
|
12395
|
-
float * C = (float *)dst->data + i * m * n; // [m, n]
|
|
12396
|
-
|
|
12397
|
-
gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
|
|
12398
|
-
}
|
|
12399
|
-
}
|
|
12400
|
-
|
|
12401
|
-
static void wsp_ggml_compute_forward_conv_2d_f16_f32(
|
|
11754
|
+
static void wsp_ggml_compute_forward_im2col_f16(
|
|
12402
11755
|
const struct wsp_ggml_compute_params * params,
|
|
12403
11756
|
const struct wsp_ggml_tensor * src0,
|
|
12404
11757
|
const struct wsp_ggml_tensor * src1,
|
|
12405
11758
|
struct wsp_ggml_tensor * dst) {
|
|
12406
11759
|
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
|
|
12407
11760
|
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
12408
|
-
WSP_GGML_ASSERT( dst->type ==
|
|
11761
|
+
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
|
|
12409
11762
|
|
|
12410
11763
|
int64_t t0 = wsp_ggml_perf_time_us();
|
|
12411
11764
|
UNUSED(t0);
|
|
12412
11765
|
|
|
12413
|
-
WSP_GGML_TENSOR_BINARY_OP_LOCALS
|
|
11766
|
+
WSP_GGML_TENSOR_BINARY_OP_LOCALS;
|
|
12414
11767
|
|
|
12415
|
-
|
|
12416
|
-
|
|
12417
|
-
|
|
12418
|
-
|
|
12419
|
-
|
|
12420
|
-
|
|
12421
|
-
|
|
12422
|
-
// nk1: KH
|
|
12423
|
-
// ne13: N
|
|
12424
|
-
|
|
12425
|
-
const int N = ne13;
|
|
12426
|
-
const int IC = ne12;
|
|
12427
|
-
const int IH = ne11;
|
|
12428
|
-
const int IW = ne10;
|
|
12429
|
-
|
|
12430
|
-
const int OC = ne03;
|
|
12431
|
-
// const int IC = ne02;
|
|
12432
|
-
const int KH = ne01;
|
|
12433
|
-
const int KW = ne00;
|
|
12434
|
-
|
|
12435
|
-
const int OH = ne1;
|
|
12436
|
-
const int OW = ne0;
|
|
11768
|
+
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
|
11769
|
+
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
|
11770
|
+
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
|
11771
|
+
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
|
11772
|
+
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
|
11773
|
+
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
|
11774
|
+
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
|
12437
11775
|
|
|
12438
11776
|
const int ith = params->ith;
|
|
12439
11777
|
const int nth = params->nth;
|
|
12440
11778
|
|
|
12441
|
-
|
|
12442
|
-
|
|
11779
|
+
const int64_t N = is_2D ? ne13 : ne12;
|
|
11780
|
+
const int64_t IC = is_2D ? ne12 : ne11;
|
|
11781
|
+
const int64_t IH = is_2D ? ne11 : 1;
|
|
11782
|
+
const int64_t IW = ne10;
|
|
12443
11783
|
|
|
12444
|
-
|
|
12445
|
-
|
|
12446
|
-
// ew0: IC*KH*KW
|
|
11784
|
+
const int64_t KH = is_2D ? ne01 : 1;
|
|
11785
|
+
const int64_t KW = ne00;
|
|
12447
11786
|
|
|
12448
|
-
const
|
|
12449
|
-
const
|
|
12450
|
-
|
|
12451
|
-
|
|
12452
|
-
|
|
12453
|
-
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
|
|
11787
|
+
const int64_t OH = is_2D ? ne2 : 1;
|
|
11788
|
+
const int64_t OW = ne1;
|
|
11789
|
+
|
|
11790
|
+
int ofs0 = is_2D ? nb13 : nb12;
|
|
11791
|
+
int ofs1 = is_2D ? nb12 : nb11;
|
|
12454
11792
|
|
|
12455
11793
|
WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
|
|
12456
11794
|
WSP_GGML_ASSERT(nb10 == sizeof(float));
|
|
12457
11795
|
|
|
12458
11796
|
if (params->type == WSP_GGML_TASK_INIT) {
|
|
12459
|
-
|
|
11797
|
+
return;
|
|
11798
|
+
}
|
|
12460
11799
|
|
|
12461
|
-
|
|
12462
|
-
|
|
11800
|
+
if (params->type == WSP_GGML_TASK_FINALIZE) {
|
|
11801
|
+
return;
|
|
11802
|
+
}
|
|
12463
11803
|
|
|
12464
|
-
|
|
12465
|
-
|
|
11804
|
+
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
|
|
11805
|
+
{
|
|
11806
|
+
wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data;
|
|
12466
11807
|
|
|
12467
|
-
|
|
12468
|
-
|
|
12469
|
-
|
|
12470
|
-
|
|
11808
|
+
for (int64_t in = 0; in < N; in++) {
|
|
11809
|
+
for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
|
|
11810
|
+
for (int64_t iow = 0; iow < OW; iow++) {
|
|
11811
|
+
for (int64_t iic = ith; iic < IC; iic += nth) {
|
|
12471
11812
|
|
|
12472
|
-
|
|
12473
|
-
|
|
12474
|
-
|
|
11813
|
+
// micro kernel
|
|
11814
|
+
wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
|
|
11815
|
+
const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
|
|
12475
11816
|
|
|
12476
|
-
|
|
12477
|
-
|
|
12478
|
-
|
|
12479
|
-
|
|
11817
|
+
for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
|
|
11818
|
+
for (int64_t ikw = 0; ikw < KW; ikw++) {
|
|
11819
|
+
const int64_t iiw = iow*s0 + ikw*d0 - p0;
|
|
11820
|
+
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
|
12480
11821
|
|
|
12481
|
-
|
|
12482
|
-
|
|
12483
|
-
|
|
11822
|
+
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
|
11823
|
+
dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
|
|
11824
|
+
} else {
|
|
11825
|
+
dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
|
|
12484
11826
|
}
|
|
12485
11827
|
}
|
|
12486
11828
|
}
|
|
@@ -12488,77 +11830,10 @@ static void wsp_ggml_compute_forward_conv_2d_f16_f32(
|
|
|
12488
11830
|
}
|
|
12489
11831
|
}
|
|
12490
11832
|
}
|
|
12491
|
-
|
|
12492
|
-
return;
|
|
12493
|
-
}
|
|
12494
|
-
|
|
12495
|
-
if (params->type == WSP_GGML_TASK_FINALIZE) {
|
|
12496
|
-
return;
|
|
12497
|
-
}
|
|
12498
|
-
|
|
12499
|
-
wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
|
|
12500
|
-
// wdata: [N*OH*OW, IC*KH*KW]
|
|
12501
|
-
// dst: result [N, OC, OH, OW]
|
|
12502
|
-
// src0: kernel [OC, IC, KH, KW]
|
|
12503
|
-
|
|
12504
|
-
int64_t m = OC;
|
|
12505
|
-
int64_t n = OH * OW;
|
|
12506
|
-
int64_t k = IC * KH * KW;
|
|
12507
|
-
|
|
12508
|
-
// [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
|
|
12509
|
-
for (int i = 0; i < N; i++) {
|
|
12510
|
-
wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k]
|
|
12511
|
-
wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)wdata + i * m * k; // [n, k]
|
|
12512
|
-
float * C = (float *)dst->data + i * m * n; // [m * k]
|
|
12513
|
-
|
|
12514
|
-
gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
|
|
12515
|
-
}
|
|
12516
|
-
}
|
|
12517
|
-
|
|
12518
|
-
static void wsp_ggml_compute_forward_conv_2d(
|
|
12519
|
-
const struct wsp_ggml_compute_params * params,
|
|
12520
|
-
const struct wsp_ggml_tensor * src0,
|
|
12521
|
-
const struct wsp_ggml_tensor * src1,
|
|
12522
|
-
struct wsp_ggml_tensor * dst) {
|
|
12523
|
-
switch (src0->type) {
|
|
12524
|
-
case WSP_GGML_TYPE_F16:
|
|
12525
|
-
{
|
|
12526
|
-
wsp_ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst);
|
|
12527
|
-
} break;
|
|
12528
|
-
case WSP_GGML_TYPE_F32:
|
|
12529
|
-
{
|
|
12530
|
-
//wsp_ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
|
|
12531
|
-
WSP_GGML_ASSERT(false);
|
|
12532
|
-
} break;
|
|
12533
|
-
default:
|
|
12534
|
-
{
|
|
12535
|
-
WSP_GGML_ASSERT(false);
|
|
12536
|
-
} break;
|
|
12537
|
-
}
|
|
12538
|
-
}
|
|
12539
|
-
|
|
12540
|
-
static void wsp_ggml_compute_forward_conv_2d_stage_0(
|
|
12541
|
-
const struct wsp_ggml_compute_params * params,
|
|
12542
|
-
const struct wsp_ggml_tensor * src0,
|
|
12543
|
-
const struct wsp_ggml_tensor * src1,
|
|
12544
|
-
struct wsp_ggml_tensor * dst) {
|
|
12545
|
-
switch (src0->type) {
|
|
12546
|
-
case WSP_GGML_TYPE_F16:
|
|
12547
|
-
{
|
|
12548
|
-
wsp_ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst);
|
|
12549
|
-
} break;
|
|
12550
|
-
case WSP_GGML_TYPE_F32:
|
|
12551
|
-
{
|
|
12552
|
-
WSP_GGML_ASSERT(false);
|
|
12553
|
-
} break;
|
|
12554
|
-
default:
|
|
12555
|
-
{
|
|
12556
|
-
WSP_GGML_ASSERT(false);
|
|
12557
|
-
} break;
|
|
12558
11833
|
}
|
|
12559
11834
|
}
|
|
12560
11835
|
|
|
12561
|
-
static void
|
|
11836
|
+
static void wsp_ggml_compute_forward_im2col(
|
|
12562
11837
|
const struct wsp_ggml_compute_params * params,
|
|
12563
11838
|
const struct wsp_ggml_tensor * src0,
|
|
12564
11839
|
const struct wsp_ggml_tensor * src1,
|
|
@@ -12566,7 +11841,7 @@ static void wsp_ggml_compute_forward_conv_2d_stage_1(
|
|
|
12566
11841
|
switch (src0->type) {
|
|
12567
11842
|
case WSP_GGML_TYPE_F16:
|
|
12568
11843
|
{
|
|
12569
|
-
|
|
11844
|
+
wsp_ggml_compute_forward_im2col_f16(params, src0, src1, dst);
|
|
12570
11845
|
} break;
|
|
12571
11846
|
case WSP_GGML_TYPE_F32:
|
|
12572
11847
|
{
|
|
@@ -12880,6 +12155,67 @@ static void wsp_ggml_compute_forward_upscale(
|
|
|
12880
12155
|
}
|
|
12881
12156
|
}
|
|
12882
12157
|
|
|
12158
|
+
// wsp_ggml_compute_forward_argsort
|
|
12159
|
+
|
|
12160
|
+
static void wsp_ggml_compute_forward_argsort_f32(
|
|
12161
|
+
const struct wsp_ggml_compute_params * params,
|
|
12162
|
+
const struct wsp_ggml_tensor * src0,
|
|
12163
|
+
struct wsp_ggml_tensor * dst) {
|
|
12164
|
+
|
|
12165
|
+
if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
|
|
12166
|
+
return;
|
|
12167
|
+
}
|
|
12168
|
+
|
|
12169
|
+
WSP_GGML_TENSOR_UNARY_OP_LOCALS
|
|
12170
|
+
|
|
12171
|
+
WSP_GGML_ASSERT(nb0 == sizeof(float));
|
|
12172
|
+
|
|
12173
|
+
const int ith = params->ith;
|
|
12174
|
+
const int nth = params->nth;
|
|
12175
|
+
|
|
12176
|
+
const int64_t nr = wsp_ggml_nrows(src0);
|
|
12177
|
+
|
|
12178
|
+
enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) wsp_ggml_get_op_params_i32(dst, 0);
|
|
12179
|
+
|
|
12180
|
+
for (int64_t i = ith; i < nr; i += nth) {
|
|
12181
|
+
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
12182
|
+
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
|
12183
|
+
|
|
12184
|
+
for (int64_t j = 0; j < ne0; j++) {
|
|
12185
|
+
dst_data[j] = j;
|
|
12186
|
+
}
|
|
12187
|
+
|
|
12188
|
+
// C doesn't have a functional sort, so we do a bubble sort instead
|
|
12189
|
+
for (int64_t j = 0; j < ne0; j++) {
|
|
12190
|
+
for (int64_t k = j + 1; k < ne0; k++) {
|
|
12191
|
+
if ((order == WSP_GGML_SORT_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
|
|
12192
|
+
(order == WSP_GGML_SORT_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
|
|
12193
|
+
int32_t tmp = dst_data[j];
|
|
12194
|
+
dst_data[j] = dst_data[k];
|
|
12195
|
+
dst_data[k] = tmp;
|
|
12196
|
+
}
|
|
12197
|
+
}
|
|
12198
|
+
}
|
|
12199
|
+
}
|
|
12200
|
+
}
|
|
12201
|
+
|
|
12202
|
+
static void wsp_ggml_compute_forward_argsort(
|
|
12203
|
+
const struct wsp_ggml_compute_params * params,
|
|
12204
|
+
const struct wsp_ggml_tensor * src0,
|
|
12205
|
+
struct wsp_ggml_tensor * dst) {
|
|
12206
|
+
|
|
12207
|
+
switch (src0->type) {
|
|
12208
|
+
case WSP_GGML_TYPE_F32:
|
|
12209
|
+
{
|
|
12210
|
+
wsp_ggml_compute_forward_argsort_f32(params, src0, dst);
|
|
12211
|
+
} break;
|
|
12212
|
+
default:
|
|
12213
|
+
{
|
|
12214
|
+
WSP_GGML_ASSERT(false);
|
|
12215
|
+
} break;
|
|
12216
|
+
}
|
|
12217
|
+
}
|
|
12218
|
+
|
|
12883
12219
|
// wsp_ggml_compute_forward_flash_attn
|
|
12884
12220
|
|
|
12885
12221
|
static void wsp_ggml_compute_forward_flash_attn_f32(
|
|
@@ -14703,6 +14039,10 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
             {
                 wsp_ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
             } break;
+        case WSP_GGML_OP_MUL_MAT_ID:
+            {
+                wsp_ggml_compute_forward_mul_mat_id(params, tensor);
+            } break;
         case WSP_GGML_OP_OUT_PROD:
             {
                 wsp_ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
@@ -14761,7 +14101,7 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
             } break;
         case WSP_GGML_OP_SOFT_MAX:
             {
-                wsp_ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
+                wsp_ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case WSP_GGML_OP_SOFT_MAX_BACK:
             {
@@ -14783,33 +14123,13 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
             {
                 wsp_ggml_compute_forward_clamp(params, tensor->src[0], tensor);
             } break;
-        case WSP_GGML_OP_CONV_1D:
-            {
-                wsp_ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case WSP_GGML_OP_CONV_1D_STAGE_0:
-            {
-                wsp_ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case WSP_GGML_OP_CONV_1D_STAGE_1:
-            {
-                wsp_ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
         case WSP_GGML_OP_CONV_TRANSPOSE_1D:
             {
                 wsp_ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
-        case
-            {
-                wsp_ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case WSP_GGML_OP_CONV_2D_STAGE_0:
-            {
-                wsp_ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case WSP_GGML_OP_CONV_2D_STAGE_1:
+        case WSP_GGML_OP_IM2COL:
             {
-
+                wsp_ggml_compute_forward_im2col(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case WSP_GGML_OP_CONV_TRANSPOSE_2D:
             {
@@ -14827,6 +14147,10 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
             {
                 wsp_ggml_compute_forward_upscale(params, tensor->src[0], tensor);
             } break;
+        case WSP_GGML_OP_ARGSORT:
+            {
+                wsp_ggml_compute_forward_argsort(params, tensor->src[0], tensor);
+            } break;
         case WSP_GGML_OP_FLASH_ATTN:
             {
                 const int32_t t = wsp_ggml_get_op_params_i32(tensor, 0);
@@ -15477,6 +14801,10 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
                         zero_table);
                 }
             } break;
+        case WSP_GGML_OP_MUL_MAT_ID:
+            {
+                WSP_GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case WSP_GGML_OP_OUT_PROD:
             {
                 WSP_GGML_ASSERT(false); // TODO: not implemented
@@ -15708,17 +15036,20 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
                 // necessary for llama
                 if (src0->grad) {
                     //const int n_past = ((int32_t *) tensor->op_params)[0];
-                    const int n_dims
-                    const int mode
-                    const int n_ctx
-
-                    float freq_scale;
-
-
-                    memcpy(&
-                    memcpy(&
-                    memcpy(&
-                    memcpy(&
+                    const int n_dims = ((int32_t *) tensor->op_params)[1];
+                    const int mode = ((int32_t *) tensor->op_params)[2];
+                    const int n_ctx = ((int32_t *) tensor->op_params)[3];
+                    const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
+                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+
+                    memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
+                    memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
+                    memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
+                    memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
+                    memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
+                    memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
+                    memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
+                    memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));

                     src0->grad = wsp_ggml_add_or_set(ctx,
                             src0->grad,
@@ -15728,8 +15059,13 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
                                 n_dims,
                                 mode,
                                 n_ctx,
+                                n_orig_ctx,
                                 freq_base,
                                 freq_scale,
+                                ext_factor,
+                                attn_factor,
+                                beta_fast,
+                                beta_slow,
                                 xpos_base,
                                 xpos_down),
                             zero_table);
@@ -15739,17 +15075,20 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
             {
                 if (src0->grad) {
                     //const int n_past = ((int32_t *) tensor->op_params)[0];
-                    const int n_dims
-                    const int mode
-                    const int n_ctx
-
-                    float freq_scale;
-
-
-                    memcpy(&
-                    memcpy(&
-                    memcpy(&
-                    memcpy(&
+                    const int n_dims = ((int32_t *) tensor->op_params)[1];
+                    const int mode = ((int32_t *) tensor->op_params)[2];
+                    const int n_ctx = ((int32_t *) tensor->op_params)[3];
+                    const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
+                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+
+                    memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
+                    memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
+                    memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
+                    memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
+                    memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
+                    memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
+                    memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
+                    memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));

                     src0->grad = wsp_ggml_add_or_set(ctx,
                             src0->grad,
@@ -15758,14 +15097,14 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
                                 src1,
                                 n_dims,
                                 mode,
-                                0,
                                 n_ctx,
+                                n_orig_ctx,
                                 freq_base,
                                 freq_scale,
-
-
-
-
+                                ext_factor,
+                                attn_factor,
+                                beta_fast,
+                                beta_slow,
                                 xpos_base,
                                 xpos_down,
                                 false),
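The reworked ROPE backward cases above read their extended parameters back out of the tensor's int32-typed op_params array: integer slots 1..4 hold n_dims, mode, n_ctx and n_orig_ctx, while the float (and one bool) values at slots 5..12 are copied out bit-for-bit with memcpy. A small self-contained sketch of that round-trip; the slot numbers follow the diff, everything else is illustrative:

#include <stdint.h>
#include <string.h>
#include <stdio.h>

int main(void) {
    /* Stand-in for tensor->op_params: an int32 array used as raw storage. */
    int32_t op_params[13] = { 0 };

    op_params[1] = 64;                                   /* n_dims, stored directly as int32 */

    const float freq_base = 10000.0f;
    memcpy(op_params + 5, &freq_base, sizeof(float));    /* store the float's bits in slot 5 */

    int   n_dims        = op_params[1];
    float freq_base_out = 0.0f;
    memcpy(&freq_base_out, op_params + 5, sizeof(float)); /* load the float back out */

    printf("n_dims=%d freq_base=%.1f\n", n_dims, freq_base_out); /* 64, 10000.0 */
    return 0;
}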
@@ -15780,31 +15119,11 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
             {
                 WSP_GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case WSP_GGML_OP_CONV_1D:
-            {
-                WSP_GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case WSP_GGML_OP_CONV_1D_STAGE_0:
-            {
-                WSP_GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case WSP_GGML_OP_CONV_1D_STAGE_1:
-            {
-                WSP_GGML_ASSERT(false); // TODO: not implemented
-            } break;
         case WSP_GGML_OP_CONV_TRANSPOSE_1D:
             {
                 WSP_GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case
-            {
-                WSP_GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case WSP_GGML_OP_CONV_2D_STAGE_0:
-            {
-                WSP_GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case WSP_GGML_OP_CONV_2D_STAGE_1:
+        case WSP_GGML_OP_IM2COL:
             {
                 WSP_GGML_ASSERT(false); // TODO: not implemented
             } break;
@@ -15824,6 +15143,10 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
             {
                 WSP_GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case WSP_GGML_OP_ARGSORT:
+            {
+                WSP_GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case WSP_GGML_OP_FLASH_ATTN:
             {
                 struct wsp_ggml_tensor * flash_grad = NULL;
@@ -16184,12 +15507,8 @@ struct wsp_ggml_cgraph * wsp_ggml_new_graph(struct wsp_ggml_context * ctx) {
     return wsp_ggml_new_graph_custom(ctx, WSP_GGML_DEFAULT_GRAPH_SIZE, false);
 }

-struct wsp_ggml_cgraph
-
-    struct wsp_ggml_object * obj = wsp_ggml_new_object(ctx, WSP_GGML_OBJECT_GRAPH, obj_size);
-    struct wsp_ggml_cgraph * cgraph = (struct wsp_ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
-
-    *cgraph = (struct wsp_ggml_cgraph) {
+struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph0, int i0, int i1) {
+    struct wsp_ggml_cgraph cgraph = {
         /*.size =*/ 0,
         /*.n_nodes =*/ i1 - i0,
         /*.n_leafs =*/ 0,
@@ -16424,7 +15743,6 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
                 n_tasks = n_threads;
             } break;
         case WSP_GGML_OP_SUB:
-        case WSP_GGML_OP_DIV:
         case WSP_GGML_OP_SQR:
         case WSP_GGML_OP_SQRT:
         case WSP_GGML_OP_LOG:
@@ -16457,10 +15775,13 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
                     {
                         n_tasks = n_threads;
                     } break;
+                default:
+                    WSP_GGML_ASSERT(false);
             }
             break;
         case WSP_GGML_OP_SILU_BACK:
         case WSP_GGML_OP_MUL:
+        case WSP_GGML_OP_DIV:
         case WSP_GGML_OP_NORM:
         case WSP_GGML_OP_RMS_NORM:
         case WSP_GGML_OP_RMS_NORM_BACK:
@@ -16498,6 +15819,11 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
                 }
 #endif
             } break;
+        case WSP_GGML_OP_MUL_MAT_ID:
+            {
+                // FIXME: blas
+                n_tasks = n_threads;
+            } break;
         case WSP_GGML_OP_OUT_PROD:
             {
                 n_tasks = n_threads;
@@ -16517,7 +15843,6 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
             } break;
         case WSP_GGML_OP_DIAG_MASK_ZERO:
         case WSP_GGML_OP_DIAG_MASK_INF:
-        case WSP_GGML_OP_SOFT_MAX:
         case WSP_GGML_OP_SOFT_MAX_BACK:
         case WSP_GGML_OP_ROPE:
         case WSP_GGML_OP_ROPE_BACK:
@@ -16533,31 +15858,15 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
             {
                 n_tasks = 1; //TODO
             } break;
-        case
-            {
-                n_tasks = n_threads;
-            } break;
-        case WSP_GGML_OP_CONV_1D_STAGE_0:
-            {
-                n_tasks = n_threads;
-            } break;
-        case WSP_GGML_OP_CONV_1D_STAGE_1:
+        case WSP_GGML_OP_SOFT_MAX:
             {
-                n_tasks = n_threads;
+                n_tasks = MIN(MIN(4, n_threads), wsp_ggml_nrows(node->src[0]));
             } break;
         case WSP_GGML_OP_CONV_TRANSPOSE_1D:
             {
                 n_tasks = n_threads;
             } break;
-        case
-            {
-                n_tasks = n_threads;
-            } break;
-        case WSP_GGML_OP_CONV_2D_STAGE_0:
-            {
-                n_tasks = n_threads;
-            } break;
-        case WSP_GGML_OP_CONV_2D_STAGE_1:
+        case WSP_GGML_OP_IM2COL:
             {
                 n_tasks = n_threads;
             } break;
@@ -16574,6 +15883,10 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case WSP_GGML_OP_ARGSORT:
+            {
+                n_tasks = n_threads;
+            } break;
         case WSP_GGML_OP_FLASH_ATTN:
             {
                 n_tasks = n_threads;
@@ -16642,6 +15955,12 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
             } break;
         default:
             {
+                fprintf(stderr, "%s: op not implemented: ", __func__);
+                if (node->op < WSP_GGML_OP_COUNT) {
+                    fprintf(stderr, "%s\n", wsp_ggml_op_name(node->op));
+                } else {
+                    fprintf(stderr, "%d\n", node->op);
+                }
                 WSP_GGML_ASSERT(false);
             } break;
     }
@@ -16782,18 +16101,16 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n

     // thread scheduling for the different operations + work buffer size estimation
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        int n_tasks = 1;
-
         struct wsp_ggml_tensor * node = cgraph->nodes[i];

+        const int n_tasks = wsp_ggml_get_n_tasks(node, n_threads);
+
         size_t cur = 0;

         switch (node->op) {
             case WSP_GGML_OP_CPY:
             case WSP_GGML_OP_DUP:
                 {
-                    n_tasks = n_threads;
-
                     if (wsp_ggml_is_quantized(node->type)) {
                         cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     }
@@ -16801,16 +16118,12 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
             case WSP_GGML_OP_ADD:
             case WSP_GGML_OP_ADD1:
                 {
-                    n_tasks = n_threads;
-
                     if (wsp_ggml_is_quantized(node->src[0]->type)) {
                         cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
             case WSP_GGML_OP_ACC:
                 {
-                    n_tasks = n_threads;
-
                     if (wsp_ggml_is_quantized(node->src[0]->type)) {
                         cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                     }
@@ -16836,45 +16149,32 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
                         cur = wsp_ggml_type_size(vec_dot_type)*wsp_ggml_nelements(node->src[1])/wsp_ggml_blck_size(vec_dot_type);
                     }
                 } break;
+            case WSP_GGML_OP_MUL_MAT_ID:
+                {
+                    const struct wsp_ggml_tensor * a = node->src[2];
+                    const struct wsp_ggml_tensor * b = node->src[1];
+                    const enum wsp_ggml_type vec_dot_type = type_traits[a->type].vec_dot_type;
+#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
+                    if (wsp_ggml_compute_forward_mul_mat_use_blas(a, b, node)) {
+                        if (a->type != WSP_GGML_TYPE_F32) {
+                            // here we need memory just for single 2D matrix from src0
+                            cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(a->ne[0]*a->ne[1]);
+                        }
+                    } else
+#endif
+                    if (b->type != vec_dot_type) {
+                        cur = wsp_ggml_type_size(vec_dot_type)*wsp_ggml_nelements(b)/wsp_ggml_blck_size(vec_dot_type);
+                    }
+                } break;
             case WSP_GGML_OP_OUT_PROD:
                 {
-                    n_tasks = n_threads;
-
                     if (wsp_ggml_is_quantized(node->src[0]->type)) {
                         cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
-            case
+            case WSP_GGML_OP_SOFT_MAX:
                 {
-
-                    WSP_GGML_ASSERT(node->src[1]->ne[2] == 1);
-                    WSP_GGML_ASSERT(node->src[1]->ne[3] == 1);
-
-                    const int64_t ne00 = node->src[0]->ne[0];
-                    const int64_t ne01 = node->src[0]->ne[1];
-                    const int64_t ne02 = node->src[0]->ne[2];
-
-                    const int64_t ne10 = node->src[1]->ne[0];
-                    const int64_t ne11 = node->src[1]->ne[1];
-
-                    const int64_t ne0 = node->ne[0];
-                    const int64_t ne1 = node->ne[1];
-                    const int64_t nk = ne00;
-                    const int64_t ew0 = nk * ne01;
-
-                    UNUSED(ne02);
-                    UNUSED(ne10);
-                    UNUSED(ne11);
-
-                    if (node->src[0]->type == WSP_GGML_TYPE_F16 &&
-                        node->src[1]->type == WSP_GGML_TYPE_F32) {
-                        cur = sizeof(wsp_ggml_fp16_t)*(ne0*ne1*ew0);
-                    } else if (node->src[0]->type == WSP_GGML_TYPE_F32 &&
-                               node->src[1]->type == WSP_GGML_TYPE_F32) {
-                        cur = sizeof(float)*(ne0*ne1*ew0);
-                    } else {
-                        WSP_GGML_ASSERT(false);
-                    }
+                    cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
                 } break;
             case WSP_GGML_OP_CONV_TRANSPOSE_1D:
                 {
@@ -16901,38 +16201,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
                         WSP_GGML_ASSERT(false);
                     }
                 } break;
-            case WSP_GGML_OP_CONV_2D:
-                {
-                    const int64_t ne00 = node->src[0]->ne[0]; // W
-                    const int64_t ne01 = node->src[0]->ne[1]; // H
-                    const int64_t ne02 = node->src[0]->ne[2]; // C
-                    const int64_t ne03 = node->src[0]->ne[3]; // N
-
-                    const int64_t ne10 = node->src[1]->ne[0]; // W
-                    const int64_t ne11 = node->src[1]->ne[1]; // H
-                    const int64_t ne12 = node->src[1]->ne[2]; // C
-
-                    const int64_t ne0 = node->ne[0];
-                    const int64_t ne1 = node->ne[1];
-                    const int64_t ne2 = node->ne[2];
-                    const int64_t ne3 = node->ne[3];
-                    const int64_t nk = ne00*ne01;
-                    const int64_t ew0 = nk * ne02;
-
-                    UNUSED(ne03);
-                    UNUSED(ne2);
-
-                    if (node->src[0]->type == WSP_GGML_TYPE_F16 &&
-                        node->src[1]->type == WSP_GGML_TYPE_F32) {
-                        // im2col: [N*OH*OW, IC*KH*KW]
-                        cur = sizeof(wsp_ggml_fp16_t)*(ne3*ne0*ne1*ew0);
-                    } else if (node->src[0]->type == WSP_GGML_TYPE_F32 &&
-                               node->src[1]->type == WSP_GGML_TYPE_F32) {
-                        cur = sizeof(float)* (ne10*ne11*ne12);
-                    } else {
-                        WSP_GGML_ASSERT(false);
-                    }
-                } break;
             case WSP_GGML_OP_CONV_TRANSPOSE_2D:
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // W
@@ -16949,8 +16217,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
                 } break;
             case WSP_GGML_OP_FLASH_ATTN:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t ne11 = wsp_ggml_up(node->src[1]->ne[1], WSP_GGML_SOFT_MAX_UNROLL);

                     if (node->src[1]->type == WSP_GGML_TYPE_F32) {
@@ -16963,8 +16229,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
                 } break;
             case WSP_GGML_OP_FLASH_FF:
                 {
-                    n_tasks = n_threads;
-
                     if (node->src[1]->type == WSP_GGML_TYPE_F32) {
                         cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -16975,8 +16239,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
                 } break;
             case WSP_GGML_OP_FLASH_ATTN_BACK:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t D = node->src[0]->ne[0];
                     const int64_t ne11 = wsp_ggml_up(node->src[1]->ne[1], WSP_GGML_SOFT_MAX_UNROLL);
                     const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in wsp_ggml_compute_forward_flash_attn_back
@@ -16991,8 +16253,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n

             case WSP_GGML_OP_CROSS_ENTROPY_LOSS:
                 {
-                    n_tasks = n_threads;
-
                     cur = wsp_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
                 } break;
             case WSP_GGML_OP_COUNT:
@@ -18719,14 +17979,14 @@ enum wsp_ggml_opt_result wsp_ggml_opt_resume_g(

 ////////////////////////////////////////////////////////////////////////////////

-size_t
+size_t wsp_ggml_wsp_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
     assert(k % QK4_0 == 0);
     const int nb = k / QK4_0;

     for (int b = 0; b < n; b += k) {
         block_q4_0 * restrict y = (block_q4_0 *) dst + b/QK4_0;

-
+        wsp_quantize_row_q4_0_reference(src + b, y, k);

         for (int i = 0; i < nb; i++) {
             for (int j = 0; j < QK4_0; j += 2) {
@@ -18742,14 +18002,14 @@ size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64
     return (n/QK4_0*sizeof(block_q4_0));
 }

-size_t
+size_t wsp_ggml_wsp_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
     assert(k % QK4_1 == 0);
     const int nb = k / QK4_1;

     for (int b = 0; b < n; b += k) {
         block_q4_1 * restrict y = (block_q4_1 *) dst + b/QK4_1;

-
+        wsp_quantize_row_q4_1_reference(src + b, y, k);

         for (int i = 0; i < nb; i++) {
             for (int j = 0; j < QK4_1; j += 2) {
@@ -18765,22 +18025,22 @@ size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64
     return (n/QK4_1*sizeof(block_q4_1));
 }

-size_t
+size_t wsp_ggml_wsp_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
     assert(k % QK5_0 == 0);
     const int nb = k / QK5_0;

     for (int b = 0; b < n; b += k) {
         block_q5_0 * restrict y = (block_q5_0 *)dst + b/QK5_0;

-
+        wsp_quantize_row_q5_0_reference(src + b, y, k);

         for (int i = 0; i < nb; i++) {
             uint32_t qh;
             memcpy(&qh, &y[i].qh, sizeof(qh));

             for (int j = 0; j < QK5_0; j += 2) {
-                const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-                const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
+                const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
+                const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));

                 // cast to 16 bins
                 const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
@@ -18795,22 +18055,22 @@ size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64
     return (n/QK5_0*sizeof(block_q5_0));
 }

-size_t
+size_t wsp_ggml_wsp_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
     assert(k % QK5_1 == 0);
     const int nb = k / QK5_1;

     for (int b = 0; b < n; b += k) {
         block_q5_1 * restrict y = (block_q5_1 *)dst + b/QK5_1;

-
+        wsp_quantize_row_q5_1_reference(src + b, y, k);

         for (int i = 0; i < nb; i++) {
             uint32_t qh;
             memcpy(&qh, &y[i].qh, sizeof(qh));

             for (int j = 0; j < QK5_1; j += 2) {
-                const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-                const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
+                const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
+                const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));

                 // cast to 16 bins
                 const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
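The q5_0/q5_1 histogram loops above now index the high-bit word with j/2 rather than j: with the loop stepping j by 2, qs[j/2] packs two 4-bit values per byte (elements i and i+16 of the block), so the matching fifth bits sit at bit positions i and i+16 of qh. A small sketch of that unpacking, using made-up input values purely for illustration:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    /* One packed byte: low nibble = element i, high nibble = element i+16. */
    const uint8_t  qs = 0xA3;                         /* nibbles 0x3 and 0xA */
    const uint32_t qh = (1u << 2) | (1u << (2 + 16)); /* high bits for i=2 and i=18 */

    const int i = 2; /* element pair (i, i + 16) within the block */

    const uint8_t vh0 = ((qh & (1u << (i + 0 ))) >> (i + 0 )) << 4;
    const uint8_t vh1 = ((qh & (1u << (i + 16))) >> (i + 12));

    const uint8_t v0 = (qs & 0x0F) | vh0; /* element i      -> 19 (0x13) */
    const uint8_t v1 = (qs >> 4)   | vh1; /* element i + 16 -> 26 (0x1A) */

    printf("%u %u\n", v0, v1);
    return 0;
}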
@@ -18825,14 +18085,14 @@ size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64
     return (n/QK5_1*sizeof(block_q5_1));
 }

-size_t
+size_t wsp_ggml_wsp_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;

     for (int b = 0; b < n; b += k) {
         block_q8_0 * restrict y = (block_q8_0 *)dst + b/QK8_0;

-
+        wsp_quantize_row_q8_0_reference(src + b, y, k);

         for (int i = 0; i < nb; i++) {
             for (int j = 0; j < QK8_0; ++j) {
@@ -18846,68 +18106,68 @@ size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64
     return (n/QK8_0*sizeof(block_q8_0));
 }

-size_t
+size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
     size_t result = 0;
     switch (type) {
         case WSP_GGML_TYPE_Q4_0:
             {
                 WSP_GGML_ASSERT(start % QK4_0 == 0);
                 block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
-                result =
+                result = wsp_ggml_wsp_quantize_q4_0(src + start, block, n, n, hist);
             } break;
         case WSP_GGML_TYPE_Q4_1:
             {
                 WSP_GGML_ASSERT(start % QK4_1 == 0);
                 block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
-                result =
+                result = wsp_ggml_wsp_quantize_q4_1(src + start, block, n, n, hist);
             } break;
         case WSP_GGML_TYPE_Q5_0:
             {
                 WSP_GGML_ASSERT(start % QK5_0 == 0);
                 block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
-                result =
+                result = wsp_ggml_wsp_quantize_q5_0(src + start, block, n, n, hist);
             } break;
         case WSP_GGML_TYPE_Q5_1:
             {
                 WSP_GGML_ASSERT(start % QK5_1 == 0);
                 block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
-                result =
+                result = wsp_ggml_wsp_quantize_q5_1(src + start, block, n, n, hist);
             } break;
         case WSP_GGML_TYPE_Q8_0:
             {
                 WSP_GGML_ASSERT(start % QK8_0 == 0);
                 block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
-                result =
+                result = wsp_ggml_wsp_quantize_q8_0(src + start, block, n, n, hist);
             } break;
         case WSP_GGML_TYPE_Q2_K:
             {
                 WSP_GGML_ASSERT(start % QK_K == 0);
                 block_q2_K * block = (block_q2_K*)dst + start / QK_K;
-                result =
+                result = wsp_ggml_wsp_quantize_q2_K(src + start, block, n, n, hist);
             } break;
         case WSP_GGML_TYPE_Q3_K:
             {
                 WSP_GGML_ASSERT(start % QK_K == 0);
                 block_q3_K * block = (block_q3_K*)dst + start / QK_K;
-                result =
+                result = wsp_ggml_wsp_quantize_q3_K(src + start, block, n, n, hist);
             } break;
         case WSP_GGML_TYPE_Q4_K:
             {
                 WSP_GGML_ASSERT(start % QK_K == 0);
                 block_q4_K * block = (block_q4_K*)dst + start / QK_K;
-                result =
+                result = wsp_ggml_wsp_quantize_q4_K(src + start, block, n, n, hist);
             } break;
         case WSP_GGML_TYPE_Q5_K:
             {
                 WSP_GGML_ASSERT(start % QK_K == 0);
                 block_q5_K * block = (block_q5_K*)dst + start / QK_K;
-                result =
+                result = wsp_ggml_wsp_quantize_q5_K(src + start, block, n, n, hist);
             } break;
         case WSP_GGML_TYPE_Q6_K:
             {
                 WSP_GGML_ASSERT(start % QK_K == 0);
                 block_q6_K * block = (block_q6_K*)dst + start / QK_K;
-                result =
+                result = wsp_ggml_wsp_quantize_q6_K(src + start, block, n, n, hist);
             } break;
         case WSP_GGML_TYPE_F16:
             {
@@ -19000,6 +18260,7 @@ struct wsp_gguf_kv {

 struct wsp_gguf_header {
     char magic[4];
+
     uint32_t version;
     uint64_t n_tensors; // GGUFv2
     uint64_t n_kv; // GGUFv2
@@ -19089,7 +18350,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp

     for (uint32_t i = 0; i < sizeof(magic); i++) {
         if (magic[i] != WSP_GGUF_MAGIC[i]) {
-            fprintf(stderr, "%s: invalid magic characters %
+            fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
             fclose(file);
             return NULL;
         }
@@ -19104,7 +18365,6 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
     {
         strncpy(ctx->header.magic, magic, 4);

-
         ctx->kv = NULL;
         ctx->infos = NULL;
         ctx->data = NULL;
@@ -19132,7 +18392,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
     {
         ctx->kv = malloc(ctx->header.n_kv * sizeof(struct wsp_gguf_kv));

-        for (
+        for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
             struct wsp_gguf_kv * kv = &ctx->kv[i];

             //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
@@ -19179,7 +18439,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
                 case WSP_GGUF_TYPE_STRING:
                     {
                         kv->value.arr.data = malloc(kv->value.arr.n * sizeof(struct wsp_gguf_str));
-                        for (
+                        for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
                             ok = ok && wsp_gguf_fread_str(file, &((struct wsp_gguf_str *) kv->value.arr.data)[j], &offset);
                         }
                     } break;
@@ -19207,7 +18467,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
     {
         ctx->infos = malloc(ctx->header.n_tensors * sizeof(struct wsp_gguf_tensor_info));

-        for (
+        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
             struct wsp_gguf_tensor_info * info = &ctx->infos[i];

             for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) {
@@ -19254,7 +18514,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
     // compute the total size of the data section, taking into account the alignment
     {
         ctx->size = 0;
-        for (
+        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
             struct wsp_gguf_tensor_info * info = &ctx->infos[i];

             const int64_t ne =
@@ -19323,7 +18583,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
         wsp_ggml_set_no_alloc(ctx_data, true);

         // create the tensors
-        for (
+        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
             const int64_t ne[WSP_GGML_MAX_DIMS] = {
                 ctx->infos[i].ne[0],
                 ctx->infos[i].ne[1],
@@ -19458,24 +18718,29 @@ int wsp_gguf_find_key(const struct wsp_gguf_context * ctx, const char * key) {
 }

 const char * wsp_gguf_get_key(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     return ctx->kv[key_id].key.data;
 }

 enum wsp_gguf_type wsp_gguf_get_kv_type(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     return ctx->kv[key_id].type;
 }

 enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
     return ctx->kv[key_id].value.arr.type;
 }

 const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
     return ctx->kv[key_id].value.arr.data;
 }

 const char * wsp_gguf_get_arr_str(const struct wsp_gguf_context * ctx, int key_id, int i) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
     struct wsp_gguf_kv * kv = &ctx->kv[key_id];
     struct wsp_gguf_str * str = &((struct wsp_gguf_str *) kv->value.arr.data)[i];
@@ -19483,70 +18748,90 @@ const char * wsp_gguf_get_arr_str(const struct wsp_gguf_context * ctx, int key_i
 }

 int wsp_gguf_get_arr_n(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
     return ctx->kv[key_id].value.arr.n;
 }

 uint8_t wsp_gguf_get_val_u8(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT8);
     return ctx->kv[key_id].value.uint8;
 }

 int8_t wsp_gguf_get_val_i8(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT8);
     return ctx->kv[key_id].value.int8;
 }

 uint16_t wsp_gguf_get_val_u16(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT16);
     return ctx->kv[key_id].value.uint16;
 }

 int16_t wsp_gguf_get_val_i16(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT16);
     return ctx->kv[key_id].value.int16;
 }

 uint32_t wsp_gguf_get_val_u32(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT32);
     return ctx->kv[key_id].value.uint32;
 }

 int32_t wsp_gguf_get_val_i32(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT32);
     return ctx->kv[key_id].value.int32;
 }

 float wsp_gguf_get_val_f32(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_FLOAT32);
     return ctx->kv[key_id].value.float32;
 }

 uint64_t wsp_gguf_get_val_u64(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT64);
     return ctx->kv[key_id].value.uint64;
 }

 int64_t wsp_gguf_get_val_i64(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT64);
     return ctx->kv[key_id].value.int64;
 }

 double wsp_gguf_get_val_f64(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_FLOAT64);
     return ctx->kv[key_id].value.float64;
 }

 bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_BOOL);
     return ctx->kv[key_id].value.bool_;
 }

 const char * wsp_gguf_get_val_str(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
     WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_STRING);
     return ctx->kv[key_id].value.str.data;
 }

+const void * wsp_gguf_get_val_data(const struct wsp_gguf_context * ctx, int key_id) {
+    WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
+    WSP_GGML_ASSERT(ctx->kv[key_id].type != WSP_GGUF_TYPE_ARRAY);
+    WSP_GGML_ASSERT(ctx->kv[key_id].type != WSP_GGUF_TYPE_STRING);
+    return &ctx->kv[key_id].value;
+}
+
 int wsp_gguf_get_n_tensors(const struct wsp_gguf_context * ctx) {
     return ctx->header.n_tensors;
 }