@fugood/llama.node 1.0.2 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. package/package.json +14 -14
  2. package/src/llama.cpp/CMakeLists.txt +0 -1
  3. package/src/llama.cpp/common/CMakeLists.txt +4 -5
  4. package/src/llama.cpp/common/arg.cpp +44 -0
  5. package/src/llama.cpp/common/common.cpp +22 -6
  6. package/src/llama.cpp/common/common.h +15 -1
  7. package/src/llama.cpp/ggml/CMakeLists.txt +10 -2
  8. package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  9. package/src/llama.cpp/ggml/include/ggml.h +104 -10
  10. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  11. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  12. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +749 -163
  15. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
  16. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  17. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +12 -9
  18. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +88 -9
  19. package/src/llama.cpp/include/llama.h +13 -47
  20. package/src/llama.cpp/src/llama-arch.cpp +298 -3
  21. package/src/llama.cpp/src/llama-arch.h +22 -1
  22. package/src/llama.cpp/src/llama-batch.cpp +103 -71
  23. package/src/llama.cpp/src/llama-batch.h +31 -18
  24. package/src/llama.cpp/src/llama-chat.cpp +59 -1
  25. package/src/llama.cpp/src/llama-chat.h +3 -0
  26. package/src/llama.cpp/src/llama-context.cpp +134 -95
  27. package/src/llama.cpp/src/llama-context.h +13 -16
  28. package/src/llama.cpp/src/llama-cparams.h +3 -2
  29. package/src/llama.cpp/src/llama-graph.cpp +279 -180
  30. package/src/llama.cpp/src/llama-graph.h +183 -122
  31. package/src/llama.cpp/src/llama-hparams.cpp +47 -1
  32. package/src/llama.cpp/src/llama-hparams.h +12 -1
  33. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  34. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  35. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  36. package/src/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  37. package/src/llama.cpp/src/llama-kv-cells.h +62 -10
  38. package/src/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  39. package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
  40. package/src/llama.cpp/src/llama-memory-recurrent.cpp +21 -11
  41. package/src/llama.cpp/src/llama-memory.cpp +17 -0
  42. package/src/llama.cpp/src/llama-memory.h +3 -0
  43. package/src/llama.cpp/src/llama-model.cpp +3373 -743
  44. package/src/llama.cpp/src/llama-model.h +20 -4
  45. package/src/llama.cpp/src/llama-quant.cpp +2 -2
  46. package/src/llama.cpp/src/llama-vocab.cpp +376 -10
  47. package/src/llama.cpp/src/llama-vocab.h +43 -0
  48. package/src/llama.cpp/src/unicode.cpp +207 -0
  49. package/src/llama.cpp/src/unicode.h +2 -0
  50. package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50
package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp

@@ -1541,7 +1541,7 @@ class tinyBLAS_BF16_PPC {
  } else if constexpr(RM == 8 && RN == 4) {
  KERNEL_8x4(ii,jj);
  } else {
- static_assert(false, "RN/RM values not supported");
+ assert(false && "RN/RM values not supported");
  }
  }

@@ -1573,13 +1573,13 @@ class tinyBLAS_BF16_PPC {
  const int nth;
  };

- template <typename TA, typename TB, typename TC>
+ template <typename TA>
  class tinyBLAS_Q0_PPC {
  public:
  tinyBLAS_Q0_PPC(int64_t k,
  const TA *A, int64_t lda,
- const TB *B, int64_t ldb,
- TC *C, int64_t ldc,
+ const block_q8_0 *B, int64_t ldb,
+ float *C, int64_t ldc,
  int ith, int nth)
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  }
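With the B and C types now fixed to block_q8_0 and float, only the A-side quant type remains a template parameter. A hypothetical instantiation under that assumption (illustrative only; the real call site lives elsewhere in sgemm.cpp, and `matmul` as the entry point is assumed by analogy with the other tinyBLAS classes):

    // Illustrative only: A is q4_0 or q8_0 data, B is q8_0, C is float.
    tinyBLAS_Q0_PPC<block_q4_0> tb(k, Aq, lda, Bq, ldb, C, ldc, ith, nth);
    tb.matmul(m, n);  // assumed entry point, not shown in this hunk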
@@ -1590,8 +1590,7 @@ class tinyBLAS_Q0_PPC {

  private:

- template<int RM, int RN>
- inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
+ inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
  for (int I = 0; I < RM; I++) {
  for (int J = 0; J < RN; J++) {
  *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
@@ -1611,29 +1610,67 @@ class tinyBLAS_Q0_PPC {
  fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
  }
  }
-
- template<typename VA, typename VB, int size>
- void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
- int64_t i, j;
- TA *aoffset = NULL;
- VA *vecOffset = NULL;
- TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
- TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
- VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
- VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
- VB t1, t2, t3, t4, t5, t6, t7, t8;
+ /* This function processes quantized data from block_q4_0 elements.
+ * First we extract the two int4 values packed in each int8_t into two vectors of signed int8.
+ * Then we subtract 8 from each resulting element to recenter the unsigned 4-bit values as signed int8.
+ * We also compute the row sum, which is needed to compensate for that conversion. */
+ inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
  const vector signed char lowMask = vec_splats((signed char)0xF);
  const vector unsigned char v4 = vec_splats((unsigned char)0x4);
  const vector signed char v8 = vec_splats((signed char)0x8);
- aoffset = const_cast<TA*>(a);
- vecOffset = vec;
+ vector signed int vsum = {0};
+ vector signed int vsum2 = {0};
+ c[0] = vec_and(c[1], lowMask);
+ c[1] = vec_sr(c[1], v4);
+ c[0] = vec_sub(c[0], v8);
+ c[1] = vec_sub(c[1], v8);
+ vsum = vec_sum4s(c[0], vsum);
+ vsum2 = vec_sum4s(c[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ }
+
+ template <typename V1, typename V2>
+ inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
  vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
  vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
  vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
  vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
- vector signed int vsum = {0};
- vector signed int vsum2 = {0};
+ V2 t1, t2, t3, t4, t5, t6, t7, t8;
+ vector unsigned char xor_vector;
+ uint8_t flip_vec = 0x80;
+ xor_vector = vec_splats(flip_vec);
+ t1 = vec_perm(s1, s2, swiz1);
+ t2 = vec_perm(s1, s2, swiz2);
+ t3 = vec_perm(s3, s4, swiz1);
+ t4 = vec_perm(s3, s4, swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ if (flip == true) {
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
+ }
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset+16);
+ vec_xst(t7, 0, vecOffset+32);
+ vec_xst(t8, 0, vecOffset+48);
+ }

+ template<int size>
+ void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
+ int64_t i, j;
+ TA *aoffset = NULL;
+ int8_t *vecOffset = NULL;
+ TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+ TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+ vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+ vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+ aoffset = const_cast<TA*>(a);
+ vecOffset = vec;
  j = (rows >> 3);
  if (j > 0) {
  do {
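In scalar terms, process_q4_elements splits each q4_0 byte into two nibbles, recenters them from [0, 15] to [-8, 7], and accumulates the row sum used later to compensate for that shift. A minimal scalar sketch of the same computation (names are illustrative; the real code operates on 16-byte VSX vectors):

    #include <cstdint>

    // Scalar model of one process_q4_elements call: decode 16 q4_0 bytes
    // into 32 signed values and return the compensation row sum.
    static int unpack_q4_chunk(const uint8_t qs[16], int8_t lo[16], int8_t hi[16]) {
        int rowsum = 0;
        for (int i = 0; i < 16; i++) {
            lo[i] = (int8_t)(qs[i] & 0x0F) - 8;  // low nibble, recentered
            hi[i] = (int8_t)(qs[i] >> 4) - 8;    // high nibble, recentered
            rowsum += lo[i] + hi[i];             // matches comparray[] in the diff
        }
        return rowsum;
    }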
@@ -1646,159 +1683,30 @@ class tinyBLAS_Q0_PPC {
  aoffset7 = aoffset6 + lda;
  aoffset8 = aoffset7 + lda;
  aoffset += 8 * lda;
-
  i = (cols >> 2);
  if (i > 0) {
  do {
- c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
- c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
- c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
- c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
- c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
- c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
- c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
- c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
-
- c1[0] = vec_and(c1[1], lowMask);
- c1[1] = vec_sr(c1[1], v4);
- c1[0] = vec_sub(c1[0], v8);
- c1[1] = vec_sub(c1[1], v8);
- vsum = vec_sum4s(c1[0], vsum);
- vsum2 = vec_sum4s(c1[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c2[0] = vec_and(c2[1], lowMask);
- c2[1] = vec_sr(c2[1], v4);
- c2[0] = vec_sub(c2[0], v8);
- c2[1] = vec_sub(c2[1], v8);
- vsum = vec_sum4s(c2[0], vsum);
- vsum2 = vec_sum4s(c2[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c3[0] = vec_and(c3[1], lowMask);
- c3[1] = vec_sr(c3[1], v4);
- c3[0] = vec_sub(c3[0], v8);
- c3[1] = vec_sub(c3[1], v8);
- vsum = vec_sum4s(c3[0], vsum);
- vsum2 = vec_sum4s(c3[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c4[0] = vec_and(c4[1], lowMask);
- c4[1] = vec_sr(c4[1], v4);
- c4[0] = vec_sub(c4[0], v8);
- c4[1] = vec_sub(c4[1], v8);
- vsum = vec_sum4s(c4[0], vsum);
- vsum2 = vec_sum4s(c4[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c5[0] = vec_and(c5[1], lowMask);
- c5[1] = vec_sr(c5[1], v4);
- c5[0] = vec_sub(c5[0], v8);
- c5[1] = vec_sub(c5[1], v8);
- vsum = vec_sum4s(c5[0], vsum);
- vsum2 = vec_sum4s(c5[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c6[0] = vec_and(c6[1], lowMask);
- c6[1] = vec_sr(c6[1], v4);
- c6[0] = vec_sub(c6[0], v8);
- c6[1] = vec_sub(c6[1], v8);
- vsum = vec_sum4s(c6[0], vsum);
- vsum2 = vec_sum4s(c6[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c7[0] = vec_and(c7[1], lowMask);
- c7[1] = vec_sr(c7[1], v4);
- c7[0] = vec_sub(c7[0], v8);
- c7[1] = vec_sub(c7[1], v8);
- vsum = vec_sum4s(c7[0], vsum);
- vsum2 = vec_sum4s(c7[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c8[0] = vec_and(c8[1], lowMask);
- c8[1] = vec_sr(c8[1], v4);
- c8[0] = vec_sub(c8[0], v8);
- c8[1] = vec_sub(c8[1], v8);
- vsum = vec_sum4s(c8[0], vsum);
- vsum2 = vec_sum4s(c8[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
- t1 = vec_perm(c5[0], c6[0], swiz1);
- t2 = vec_perm(c5[0], c6[0], swiz2);
- t3 = vec_perm(c7[0], c8[0], swiz1);
- t4 = vec_perm(c7[0], c8[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset+128);
- vec_xst(t6, 0, vecOffset+144);
- vec_xst(t7, 0, vecOffset+160);
- vec_xst(t8, 0, vecOffset+176);
-
- t1 = vec_perm(c5[1], c6[1], swiz1);
- t2 = vec_perm(c5[1], c6[1], swiz2);
- t3 = vec_perm(c7[1], c8[1], swiz1);
- t4 = vec_perm(c7[1], c8[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset+192);
- vec_xst(t6, 0, vecOffset+208);
- vec_xst(t7, 0, vecOffset+224);
- vec_xst(t8, 0, vecOffset+240);
-
+ c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
+ c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
+ c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
+ c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
+ c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
+ c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
+ c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
+ c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
+
+ process_q4_elements(c1, &comparray[0]);
+ process_q4_elements(c2, &comparray[1]);
+ process_q4_elements(c3, &comparray[2]);
+ process_q4_elements(c4, &comparray[3]);
+ process_q4_elements(c5, &comparray[4]);
+ process_q4_elements(c6, &comparray[5]);
+ process_q4_elements(c7, &comparray[6]);
+ process_q4_elements(c8, &comparray[7]);
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+ vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
+ vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
  aoffset1 += lda;
  aoffset2 += lda;
  aoffset3 += lda;
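The vector_permute_store helper that replaces the repeated permute blocks above is a byte-level transpose: the four swizzle constants gather matching 4-byte words from four source vectors, and the optional XOR with 0x80 biases signed q8 values to unsigned (used when packing the B operand). A scalar model under those assumptions:

    #include <cstdint>

    // Scalar model of vector_permute_store: treat four 16-byte rows as a
    // 4x4 matrix of 4-byte words, transpose it, and optionally flip each
    // byte's sign bit (the 0x80 bias applied to the unsigned B operand).
    static void permute_store_model(const uint8_t src[4][16], uint8_t *dst, bool flip) {
        for (int w = 0; w < 4; w++)          // word index within a row
            for (int r = 0; r < 4; r++)      // source row
                for (int b = 0; b < 4; b++) {
                    uint8_t v = src[r][4 * w + b];
                    dst[16 * w + 4 * r + b] = flip ? (uint8_t)(v ^ 0x80) : v;
                }
    }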
@@ -1821,85 +1729,20 @@ class tinyBLAS_Q0_PPC {
  aoffset3 = aoffset2 + lda;
  aoffset4 = aoffset3 + lda;
  aoffset += 4 * lda;
-
  i = (cols >> 2);
  if (i > 0) {
  do {
- c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
- c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
- c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
- c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
-
- c1[0] = vec_and(c1[1], lowMask);
- c1[1] = vec_sr(c1[1], v4);
- c1[0] = vec_sub(c1[0], v8);
- c1[1] = vec_sub(c1[1], v8);
- vsum = vec_sum4s(c1[0], vsum);
- vsum2 = vec_sum4s(c1[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c2[0] = vec_and(c2[1], lowMask);
- c2[1] = vec_sr(c2[1], v4);
- c2[0] = vec_sub(c2[0], v8);
- c2[1] = vec_sub(c2[1], v8);
- vsum = vec_sum4s(c2[0], vsum);
- vsum2 = vec_sum4s(c2[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c3[0] = vec_and(c3[1], lowMask);
- c3[1] = vec_sr(c3[1], v4);
- c3[0] = vec_sub(c3[0], v8);
- c3[1] = vec_sub(c3[1], v8);
- vsum = vec_sum4s(c3[0], vsum);
- vsum2 = vec_sum4s(c3[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c4[0] = vec_and(c4[1], lowMask);
- c4[1] = vec_sr(c4[1], v4);
- c4[0] = vec_sub(c4[0], v8);
- c4[1] = vec_sub(c4[1], v8);
- vsum = vec_sum4s(c4[0], vsum);
- vsum2 = vec_sum4s(c4[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats( 0);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
+ c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
+ c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
+ c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
+ c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
+
+ process_q4_elements(c1, &comparray[0]);
+ process_q4_elements(c2, &comparray[1]);
+ process_q4_elements(c3, &comparray[2]);
+ process_q4_elements(c4, &comparray[3]);
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
  aoffset1 += lda;
  aoffset2 += lda;
  aoffset3 += lda;
@@ -1918,80 +1761,17 @@ class tinyBLAS_Q0_PPC {
  if (i > 0) {
  do {
  switch(rows) {
- case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
- case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
- case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+ case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
+ case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
+ case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
  break;
  }
- c1[0] = vec_and(c1[1], lowMask);
- c1[1] = vec_sr(c1[1], v4);
- c1[0] = vec_sub(c1[0], v8);
- c1[1] = vec_sub(c1[1], v8);
- vsum = vec_sum4s(c1[0], vsum);
- vsum2 = vec_sum4s(c1[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c2[0] = vec_and(c2[1], lowMask);
- c2[1] = vec_sr(c2[1], v4);
- c2[0] = vec_sub(c2[0], v8);
- c2[1] = vec_sub(c2[1], v8);
- vsum = vec_sum4s(c2[0], vsum);
- vsum2 = vec_sum4s(c2[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c3[0] = vec_and(c3[1], lowMask);
- c3[1] = vec_sr(c3[1], v4);
- c3[0] = vec_sub(c3[0], v8);
- c3[1] = vec_sub(c3[1], v8);
- vsum = vec_sum4s(c3[0], vsum);
- vsum2 = vec_sum4s(c3[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c4[0] = vec_and(c4[1], lowMask);
- c4[1] = vec_sr(c4[1], v4);
- c4[0] = vec_sub(c4[0], v8);
- c4[1] = vec_sub(c4[1], v8);
- vsum = vec_sum4s(c4[0], vsum);
- vsum2 = vec_sum4s(c4[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
+ process_q4_elements(c1, &comparray[0]);
+ process_q4_elements(c2, &comparray[1]);
+ process_q4_elements(c3, &comparray[2]);
+ process_q4_elements(c4, &comparray[3]);
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
  aoffset1 += lda;
  aoffset2 += lda;
  aoffset3 += lda;
@@ -2001,146 +1781,40 @@ class tinyBLAS_Q0_PPC {
  }
  }
  }
-
  template<typename VA, typename VB>
- void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+ void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
  int64_t i, j;
- TB *aoffset = NULL;
+ block_q8_0 *aoffset = NULL;
  VA *vecOffset = NULL;
- TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
- TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
- VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
- VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
- VB t1, t2, t3, t4, t5, t6, t7, t8;
- vector unsigned char xor_vector;
- uint8_t flip_vec = 0x80;
- xor_vector = vec_splats(flip_vec);
- vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
- vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
- vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
- vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
-
- aoffset = const_cast<TB*>(a);
+ block_q8_0* aoffsets[8];
+ __vector_pair arr[8];
+ VB c[8][2] = {0};
+ VB c1[8] = {0}; VB c2[8] = {0};
+ aoffset = const_cast<block_q8_0*>(a);
  vecOffset = vec;
  j = (rows >> 3);
  if (j > 0) {
  do {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
- aoffset5 = aoffset4 + lda;
- aoffset6 = aoffset5 + lda;
- aoffset7 = aoffset6 + lda;
- aoffset8 = aoffset7 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 8; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
  aoffset += 8 * lda;

  i = (cols >> 3);
  if (i > 0) {
  do {
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
-
- __builtin_vsx_disassemble_pair(c1, &C1);
- __builtin_vsx_disassemble_pair(c2, &C2);
- __builtin_vsx_disassemble_pair(c3, &C3);
- __builtin_vsx_disassemble_pair(c4, &C4);
- __builtin_vsx_disassemble_pair(c5, &C5);
- __builtin_vsx_disassemble_pair(c6, &C6);
- __builtin_vsx_disassemble_pair(c7, &C7);
- __builtin_vsx_disassemble_pair(c8, &C8);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
- t1 = vec_perm(c5[0], c6[0], swiz1);
- t2 = vec_perm(c5[0], c6[0], swiz2);
- t3 = vec_perm(c7[0], c8[0], swiz1);
- t4 = vec_perm(c7[0], c8[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset+128);
- vec_xst(t6, 0, vecOffset+144);
- vec_xst(t7, 0, vecOffset+160);
- vec_xst(t8, 0, vecOffset+176);
-
- t1 = vec_perm(c5[1], c6[1], swiz1);
- t2 = vec_perm(c5[1], c6[1], swiz2);
- t3 = vec_perm(c7[1], c8[1], swiz1);
- t4 = vec_perm(c7[1], c8[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ for (int it = 0; it < 8; it++) {
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+ c1[it] = c[it][0];
+ c2[it] = c[it][1];
  }
- vec_xst(t5, 0, vecOffset+192);
- vec_xst(t6, 0, vecOffset+208);
- vec_xst(t7, 0, vecOffset+224);
- vec_xst(t8, 0, vecOffset+240);
-
- aoffset1 += lda;
- aoffset2 += lda;
- aoffset3 += lda;
- aoffset4 += lda;
- aoffset5 += lda;
- aoffset6 += lda;
- aoffset7 += lda;
- aoffset8 += lda;
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+ vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
+ vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
+ for (int it = 0; it < 8; it++)
+ aoffsets[it] += lda;
  vecOffset += 256;
  i--;
  } while(i > 0);
@@ -2150,129 +1824,53 @@ class tinyBLAS_Q0_PPC {
  }

  if (rows & 4) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
- aoffset += 4 * lda;
-
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 4; it++ )
+ aoffsets[it] = aoffsets[it-1] + lda;
+ aoffset += 4 * lda;
  i = (cols >> 3);
  if (i > 0) {
  do {
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
-
- __builtin_vsx_disassemble_pair(c1, &C1);
- __builtin_vsx_disassemble_pair(c2, &C2);
- __builtin_vsx_disassemble_pair(c3, &C3);
- __builtin_vsx_disassemble_pair(c4, &C4);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ for (int it = 0; it < 4; it++) {
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+ c1[it] = c[it][0];
+ c2[it] = c[it][1];
  }
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+ for (int it = 0; it < 4; it++) {
+ aoffsets[it] += lda;
  }
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
- aoffset1 += lda;
- aoffset2 += lda;
- aoffset3 += lda;
- aoffset4 += lda;
  vecOffset += 128;
  i--;
  } while(i > 0);
  }
  }
+
  if (rows & 3) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 3; it++ )
+ aoffsets[it] = aoffsets[it-1] + lda;
  i = (cols >> 3);
  if (i > 0) {
  do {
  switch(rows) {
- case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
- __builtin_vsx_disassemble_pair(c3, &C3);
- case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
- __builtin_vsx_disassemble_pair(c2, &C2);
- case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
- __builtin_vsx_disassemble_pair(c1, &C1);
+ case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
+ __builtin_vsx_disassemble_pair(c[2], &arr[2]);
+ c1[2] = c[2][0]; c2[2] = c[2][1];
+ case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
+ __builtin_vsx_disassemble_pair(c[1], &arr[1]);
+ c1[1] = c[1][0]; c2[1] = c[1][1];
+ case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
+ __builtin_vsx_disassemble_pair(c[0], &arr[0]);
+ c1[0] = c[0][0]; c2[0] = c[0][1];
  break;
  }
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
- aoffset1 += lda;
- aoffset2 += lda;
- aoffset3 += lda;
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+ for (int it = 0; it < 3; it++)
+ aoffsets[it] += lda;
  vecOffset += 128;
  i--;
  } while(i > 0);
@@ -2281,159 +1879,42 @@ class tinyBLAS_Q0_PPC {
  }

  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t mc, nc, mp, np;
- int m_rem = MIN(m - m0, 8);
- int n_rem = MIN(n - n0, 8);
- // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
- // issues. After resolving them, below code will be enabled.
- /*if (m_rem >= 16 && n_rem >= 8) {
- mc = 16;
- nc = 8;
- gemm<16,8>(m0, m, n0, n);
- } else if(m_rem >= 8 && n_rem >= 16) {
- mc = 8;
- nc = 16;
- gemm<8,16>(m0, m, n0, n);
- }*/
+ int m_rem = MIN(m - m0, 16);
+ int n_rem = MIN(n - n0, 16);
+
+ int mc = 0, nc = 0;
+
  if (m_rem >= 8 && n_rem >= 8) {
- mc = 8;
- nc = 8;
- gemm<8,8>(m0, m, n0, n);
+ mc = 8;
+ nc = 8;
+ gemm<8, 8>(m0, m, n0, n);
  } else if (m_rem >= 4 && n_rem >= 8) {
  mc = 4;
  nc = 8;
- gemm<4,8>(m0, m, n0, n);
+ gemm<4, 8>(m0, m, n0, n);
  } else if (m_rem >= 8 && n_rem >= 4) {
  mc = 8;
  nc = 4;
- gemm<8,4>(m0, m, n0, n);
+ gemm<8, 4>(m0, m, n0, n);
  } else if (m_rem >= 4 && n_rem >= 4) {
  mc = 4;
  nc = 4;
- gemm_small<4, 4>(m0, m, n0, n);
- } else if ((m_rem < 4) && (n_rem > 4)) {
- nc = 4;
- switch(m_rem) {
- case 1:
- mc = 1;
- gemm_small<1, 4>(m0, m, n0, n);
- break;
- case 2:
- mc = 2;
- gemm_small<2, 4>(m0, m, n0, n);
- break;
- case 3:
- mc = 3;
- gemm_small<3, 4>(m0, m, n0, n);
- break;
- default:
- return;
- }
- } else if ((m_rem > 4) && (n_rem < 4)) {
- mc = 4;
- switch(n_rem) {
- case 1:
- nc = 1;
- gemm_small<4, 1>(m0, m, n0, n);
- break;
- case 2:
- nc = 2;
- gemm_small<4, 2>(m0, m, n0, n);
- break;
- case 3:
- nc = 3;
- gemm_small<4, 3>(m0, m, n0, n);
- break;
- default:
- return;
- }
+ gemm_small(m0, m, n0, n, mc, nc);
  } else {
- switch((m_rem << 4) | n_rem) {
- case 0x43:
- mc = 4;
- nc = 3;
- gemm_small<4, 3>(m0, m, n0, n);
- break;
- case 0x42:
- mc = 4;
- nc = 2;
- gemm_small<4, 2>(m0, m, n0, n);
- break;
- case 0x41:
- mc = 4;
- nc = 1;
- gemm_small<4, 1>(m0, m, n0, n);
- break;
- case 0x34:
- mc = 3;
- nc = 4;
- gemm_small<3, 4>(m0, m, n0, n);
- break;
- case 0x33:
- mc = 3;
- nc = 3;
- gemm_small<3, 3>(m0, m, n0, n);
- break;
- case 0x32:
- mc = 3;
- nc = 2;
- gemm_small<3, 2>(m0, m, n0, n);
- break;
- case 0x31:
- mc = 3;
- nc = 1;
- gemm_small<3, 1>(m0, m, n0, n);
- break;
- case 0x24:
- mc = 2;
- nc = 4;
- gemm_small<2, 4>(m0, m, n0, n);
- break;
- case 0x23:
- mc = 2;
- nc = 3;
- gemm_small<2, 3>(m0, m, n0, n);
- break;
- case 0x22:
- mc = 2;
- nc = 2;
- gemm_small<2, 2>(m0, m, n0, n);
- break;
- case 0x21:
- mc = 2;
- nc = 1;
- gemm_small<2, 1>(m0, m, n0, n);
- break;
- case 0x14:
- mc = 1;
- nc = 4;
- gemm_small<1, 4>(m0, m, n0, n);
- break;
- case 0x13:
- mc = 1;
- nc = 3;
- gemm_small<1, 3>(m0, m, n0, n);
- break;
- case 0x12:
- mc = 1;
- nc = 2;
- gemm_small<1, 2>(m0, m, n0, n);
- break;
- case 0x11:
- mc = 1;
- nc = 1;
- gemm_small<1, 1>(m0, m, n0, n);
- break;
- default:
- return;
- }
+ mc = (m_rem >= 4) ? 4 : m_rem;
+ nc = (n_rem >= 4) ? 4 : n_rem;
+ if (mc == 0 || nc == 0)
+ return;
+ gemm_small(m0, m, n0, n, mc, nc);
  }
- mp = m0 + (m - m0) / mc * mc;
- np = n0 + (n - n0) / nc * nc;
+
+ int64_t mp = m0 + ((m - m0) / mc) * mc;
+ int64_t np = n0 + ((n - n0) / nc) * nc;
  mnpack(mp, m, n0, np);
  mnpack(m0, m, np, n);
  }

+
  void KERNEL_4x8(int64_t ii, int64_t jj) {
  vec_t vec_A[8], vec_B[16] = {0};
  acc_t acc_0, acc_1;
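The rewritten mnpack keeps the MMA kernels for the 8x8, 4x8, and 8x4 tiles and collapses the old per-size switch into a single clamped gemm_small call; the two trailing recursive calls then cover the remainder strips. A compact model of the control flow, with the kernel dispatch stubbed out (illustrative only):

    #include <algorithm>
    #include <cstdint>

    // Sketch of the tiling recursion; the gemm/gemm_small dispatch is elided.
    void mnpack_model(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int m_rem = (int)std::min<int64_t>(m - m0, 16);
        int n_rem = (int)std::min<int64_t>(n - n0, 16);
        int mc, nc;
        if (m_rem >= 8 && n_rem >= 8)      { mc = 8; nc = 8; } // KERNEL_8x8
        else if (m_rem >= 4 && n_rem >= 8) { mc = 4; nc = 8; } // KERNEL_4x8
        else if (m_rem >= 8 && n_rem >= 4) { mc = 8; nc = 4; } // KERNEL_8x4
        else {
            mc = std::min(m_rem, 4);  // everything else goes through gemm_small
            nc = std::min(n_rem, 4);
            if (mc == 0 || nc == 0) return;
        }
        // ... cover [m0,mp) x [n0,np) with mc x nc tiles here ...
        int64_t mp = m0 + ((m - m0) / mc) * mc;
        int64_t np = n0 + ((n - n0) / nc) * nc;
        mnpack_model(mp, m, n0, np);  // bottom remainder strip
        mnpack_model(m0, m, np, n);   // right remainder strip
    }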
@@ -2445,9 +1926,9 @@ class tinyBLAS_Q0_PPC {
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
  if (std::is_same_v<TA, block_q4_0>) {
- packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+ packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
  } else {
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
  }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x++) {
@@ -2475,8 +1956,8 @@ class tinyBLAS_Q0_PPC {
  compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
  compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
  }
- save_res<4, 4>(ii, jj, 0, fin_res);
- save_res<4, 4>(ii, jj+4, 4, fin_res);
+ save_res(ii, jj, 0, fin_res);
+ save_res(ii, jj+4, 4, fin_res);
  }

  void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -2490,9 +1971,9 @@ class tinyBLAS_Q0_PPC {
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
  if (std::is_same_v<TA, block_q4_0>) {
- packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+ packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
  } else {
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
  }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x++) {
@@ -2519,8 +2000,8 @@ class tinyBLAS_Q0_PPC {
  compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
  compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
  }
- save_res<4, 4>(ii, jj, 0, fin_res);
- save_res<4, 4>(ii+4, jj, 4, fin_res);
+ save_res(ii, jj, 0, fin_res);
+ save_res(ii+4, jj, 4, fin_res);
  }

  void KERNEL_8x8(int64_t ii, int64_t jj) {
@@ -2536,9 +2017,9 @@ class tinyBLAS_Q0_PPC {
  __builtin_mma_xxsetaccz(&acc_2);
  __builtin_mma_xxsetaccz(&acc_3);
  if (std::is_same_v<TA, block_q4_0>) {
- packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+ packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
  } else {
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
  }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x++) {
@@ -2570,14 +2051,13 @@ class tinyBLAS_Q0_PPC {
  compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
  compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
  }
- save_res<4, 4>(ii, jj, 0, fin_res);
- save_res<4, 4>(ii+4, jj, 4, fin_res);
- save_res<4, 4>(ii, jj+4, 8, fin_res);
- save_res<4, 4>(ii+4, jj+4, 12, fin_res);
+ save_res(ii, jj, 0, fin_res);
+ save_res(ii+4, jj, 4, fin_res);
+ save_res(ii, jj+4, 8, fin_res);
+ save_res(ii+4, jj+4, 12, fin_res);
  }

- template<int RM, int RN>
- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
  int64_t ytiles = (m - m0) / RM;
  int64_t xtiles = (n - n0) / RN;
  int64_t tiles = xtiles * ytiles;
@@ -2606,9 +2086,9 @@ class tinyBLAS_Q0_PPC {
  __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
  __builtin_mma_xxsetaccz(&acc_0);
  if (isAblock_q4) {
- packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+ packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
  } else {
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
  }
  packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
  for(int x = 0; x < 8; x+=4) {
@@ -2641,7 +2121,7 @@ class tinyBLAS_Q0_PPC {
  fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
  }
  }
- save_res<RM, RN>(ii, jj, 0, fin_res);
+ save_res(ii, jj, 0, fin_res, RM, RN);
  }
  }
@@ -2654,7 +2134,7 @@ class tinyBLAS_Q0_PPC {
  } else if constexpr(RM == 8 && RN == 8) {
  KERNEL_8x8(ii,jj);
  } else {
- static_assert(false, "RN/RM values not supported");
+ assert(false && "RN/RM values not supported");
  }
  }

@@ -2676,10 +2156,8 @@ class tinyBLAS_Q0_PPC {
  }

  const TA *const A;
- const TB *const B;
- TC *C;
- TA *At;
- TB *Bt;
+ const block_q8_0 *const B;
+ float *C;
  const int64_t k;
  const int64_t lda;
  const int64_t ldb;
@@ -2688,13 +2166,12 @@ class tinyBLAS_PPC {
  const int nth;
  };

- template <typename TA, typename TB, typename TC>
  class tinyBLAS_PPC {
  public:
  tinyBLAS_PPC(int64_t k,
- const TA *A, int64_t lda,
- const TB *B, int64_t ldb,
- TC *C, int64_t ldc,
+ const float *A, int64_t lda,
+ const float *B, int64_t ldb,
+ float *C, int64_t ldc,
  int ith, int nth)
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  }
@@ -2707,247 +2184,139 @@ class tinyBLAS_PPC {
2707
2184
 
2708
2185
  void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
2709
2186
 
2710
- template<typename VA>
2711
- void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
2187
+ inline void vector_permute_store_4(vector float *src, float *vecOffset) {
2188
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
2189
+ t1 = vec_mergeh(src[0], src[1]);
2190
+ t2 = vec_mergeh(src[2], src[3]);
2191
+ t3 = vec_mergel(src[0], src[1]);
2192
+ t4 = vec_mergel(src[2], src[3]);
2193
+
2194
+ t5 = vec_xxpermdi(t1, t2, 0);
2195
+ t6 = vec_xxpermdi(t1, t2, 3);
2196
+ t7 = vec_xxpermdi(t3, t4, 0);
2197
+ t8 = vec_xxpermdi(t3, t4, 3);
2198
+
2199
+ vec_xst(t5, 0, vecOffset);
2200
+ vec_xst(t6, 0, vecOffset + 4);
2201
+ vec_xst(t7, 0, vecOffset + 8);
2202
+ vec_xst(t8, 0, vecOffset + 12);
2203
+ }
2204
+
2205
+ inline void vector_permute_store_8(vector float *src, float *vecOffset) {
2206
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
2207
+ t1 = vec_mergeh(src[0], src[1]);
2208
+ t2 = vec_mergeh(src[2], src[3]);
2209
+ t3 = vec_mergeh(src[4], src[5]);
2210
+ t4 = vec_mergeh(src[6], src[7]);
2211
+
2212
+ t5 = vec_xxpermdi(t1, t2, 0);
2213
+ t6 = vec_xxpermdi(t3, t4, 0);
2214
+ t7 = vec_xxpermdi(t1, t2, 3);
2215
+ t8 = vec_xxpermdi(t3, t4, 3);
2216
+
2217
+ vec_xst(t5, 0, vecOffset);
2218
+ vec_xst(t6, 0, vecOffset + 4);
2219
+ vec_xst(t7, 0, vecOffset + 8);
2220
+ vec_xst(t8, 0, vecOffset + 12);
2221
+
2222
+ t1 = vec_mergel(src[0], src[1]);
2223
+ t2 = vec_mergel(src[2], src[3]);
2224
+ t3 = vec_mergel(src[4], src[5]);
2225
+ t4 = vec_mergel(src[6], src[7]);
2226
+
2227
+ t5 = vec_xxpermdi(t1, t2, 0);
2228
+ t6 = vec_xxpermdi(t3, t4, 0);
2229
+ t7 = vec_xxpermdi(t1, t2, 3);
2230
+ t8 = vec_xxpermdi(t3, t4, 3);
2231
+
2232
+ vec_xst(t5, 0, vecOffset + 16);
2233
+ vec_xst(t6, 0, vecOffset + 20);
2234
+ vec_xst(t7, 0, vecOffset + 24);
2235
+ vec_xst(t8, 0, vecOffset + 28);
2236
+ }
2237
+
2238
+ void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) {
2712
2239
  int64_t i, j;
2713
- TA *aoffset = NULL, *boffset = NULL;
2714
- TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
2715
- TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
2716
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
2717
- VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
2718
- VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
2719
- VA t1, t2, t3, t4, t5, t6, t7, t8;
2720
- aoffset = const_cast<TA*>(a);
2240
+ float * aoffsets[8];
2241
+ float *aoffset = NULL, *boffset = NULL;
2242
+ __vector_pair arr[8];
2243
+ vector float c[8][2] = {0};
2244
+ vector float c1[8] = {0};
2245
+ vector float c2[8] = {0};
2246
+ aoffset = const_cast<float*>(a);
2721
2247
  boffset = vec;
2722
2248
  j = (rows >> 3);
2723
2249
  if (j > 0) {
2724
2250
 
2725
2251
  do {
2726
- aoffset1 = aoffset;
2727
- aoffset2 = aoffset1 + lda;
2728
- aoffset3 = aoffset2 + lda;
2729
- aoffset4 = aoffset3 + lda;
2730
- aoffset5 = aoffset4 + lda;
2731
- aoffset6 = aoffset5 + lda;
2732
- aoffset7 = aoffset6 + lda;
2733
- aoffset8 = aoffset7 + lda;
2252
+ aoffsets[0] = aoffset;
2253
+ for (int it = 1; it< 8; it++)
2254
+ aoffsets[it] = aoffsets[it-1] + lda;
2734
2255
  aoffset += 8 * lda;
2735
2256
  i = (cols >> 3);
2736
2257
  if (i > 0) {
2737
2258
  do {
2738
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
2739
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
2740
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
2741
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
2742
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
2743
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
2744
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
2745
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
2746
- __builtin_vsx_disassemble_pair(c1, &C1);
2747
- __builtin_vsx_disassemble_pair(c2, &C2);
2748
- __builtin_vsx_disassemble_pair(c3, &C3);
2749
- __builtin_vsx_disassemble_pair(c4, &C4);
2750
- __builtin_vsx_disassemble_pair(c5, &C5);
2751
- __builtin_vsx_disassemble_pair(c6, &C6);
2752
- __builtin_vsx_disassemble_pair(c7, &C7);
2753
- __builtin_vsx_disassemble_pair(c8, &C8);
2754
-
2755
- t1 = vec_mergeh(c1[0], c2[0]);
2756
- t2 = vec_mergeh(c3[0], c4[0]);
2757
- t3 = vec_mergeh(c5[0], c6[0]);
2758
- t4 = vec_mergeh(c7[0], c8[0]);
2759
- t5 = vec_xxpermdi(t1, t2, 0);
2760
- t6 = vec_xxpermdi(t3, t4, 0);
2761
- t7 = vec_xxpermdi(t1, t2, 3);
2762
- t8 = vec_xxpermdi(t3, t4, 3);
2763
- vec_xst(t5, 0, boffset);
2764
- vec_xst(t6, 0, boffset+4);
2765
- vec_xst(t7, 0, boffset+8);
2766
- vec_xst(t8, 0, boffset+12);
2767
-
2768
- t1 = vec_mergel(c1[0], c2[0]);
2769
- t2 = vec_mergel(c3[0], c4[0]);
2770
- t3 = vec_mergel(c5[0], c6[0]);
2771
- t4 = vec_mergel(c7[0], c8[0]);
2772
- t5 = vec_xxpermdi(t1, t2, 0);
2773
- t6 = vec_xxpermdi(t3, t4, 0);
2774
- t7 = vec_xxpermdi(t1, t2, 3);
2775
- t8 = vec_xxpermdi(t3, t4, 3);
2776
- vec_xst(t5, 0, boffset+16);
2777
- vec_xst(t6, 0, boffset+20);
2778
- vec_xst(t7, 0, boffset+24);
2779
- vec_xst(t8, 0, boffset+28);
2780
-
2781
- t1 = vec_mergeh(c1[1], c2[1]);
2782
- t2 = vec_mergeh(c3[1], c4[1]);
2783
- t3 = vec_mergeh(c5[1], c6[1]);
2784
- t4 = vec_mergeh(c7[1], c8[1]);
2785
- t5 = vec_xxpermdi(t1, t2, 0);
2786
- t6 = vec_xxpermdi(t3, t4, 0);
2787
- t7 = vec_xxpermdi(t1, t2, 3);
2788
- t8 = vec_xxpermdi(t3, t4, 3);
2789
- vec_xst(t5, 0, boffset+32);
2790
- vec_xst(t6, 0, boffset+36);
2791
- vec_xst(t7, 0, boffset+40);
2792
- vec_xst(t8, 0, boffset+44);
2793
-
2794
- t1 = vec_mergel(c1[1], c2[1]);
2795
- t2 = vec_mergel(c3[1], c4[1]);
2796
- t3 = vec_mergel(c5[1], c6[1]);
2797
- t4 = vec_mergel(c7[1], c8[1]);
2798
- t5 = vec_xxpermdi(t1, t2, 0);
2799
- t6 = vec_xxpermdi(t3, t4, 0);
2800
- t7 = vec_xxpermdi(t1, t2, 3);
2801
- t8 = vec_xxpermdi(t3, t4, 3);
2802
- vec_xst(t5, 0, boffset+48);
2803
- vec_xst(t6, 0, boffset+52);
2804
- vec_xst(t7, 0, boffset+56);
2805
- vec_xst(t8, 0, boffset+60);
2806
-
2807
- aoffset1 += 8*lda;
2808
- aoffset2 += 8*lda;
2809
- aoffset3 += 8*lda;
2810
- aoffset4 += 8*lda;
2259
+ for (int it = 0; it< 8; it++) {
2260
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
2261
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2262
+ c1[it] = c[it][0];
2263
+ c2[it] = c[it][1];
2264
+ }
2265
+
2266
+ vector_permute_store_8(c1, boffset);
2267
+ vector_permute_store_8(c2, boffset+32);
2268
+ for (int it = 0; it < 4; it++)
2269
+ aoffsets[it] = aoffsets[it] + 8*lda;
2811
2270
  boffset += 64;
2812
2271
  i--;
2813
2272
  } while(i > 0);
2814
2273
  }
2815
2274
  if (cols & 4) {
2816
- c1[0] = vec_xl(0, aoffset1);
2817
- c2[0] = vec_xl(0, aoffset2);
2818
- c3[0] = vec_xl(0, aoffset3);
2819
- c4[0] = vec_xl(0, aoffset4);
2820
- c5[0] = vec_xl(0, aoffset5);
2821
- c6[0] = vec_xl(0, aoffset6);
2822
- c7[0] = vec_xl(0, aoffset7);
2823
- c8[0] = vec_xl(0, aoffset8);
2824
-
2825
- t1 = vec_mergeh(c1[0], c2[0]);
2826
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_mergeh(c5[0], c6[0]);
- t4 = vec_mergeh(c7[0], c8[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset);
- vec_xst(t6, 0, boffset+4);
- vec_xst(t7, 0, boffset+8);
- vec_xst(t8, 0, boffset+12);
-
- t1 = vec_mergel(c1[0], c2[0]);
- t2 = vec_mergel(c3[0], c4[0]);
- t3 = vec_mergel(c5[0], c6[0]);
- t4 = vec_mergel(c7[0], c8[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset+16);
- vec_xst(t6, 0, boffset+20);
- vec_xst(t7, 0, boffset+24);
- vec_xst(t8, 0, boffset+28);
+ for (int it = 0; it < 8 ; it++)
+ c1[it] = vec_xl(0, aoffsets[it]);
+ vector_permute_store_8(c1, boffset);
  }
  j--;
  } while(j > 0);
  }
 
  if (rows & 4) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 4; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
  aoffset += 4 * lda;
  i = (cols >> 3);
  if (i > 0) {
  do {
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
- __builtin_vsx_disassemble_pair(c1, &C1);
- __builtin_vsx_disassemble_pair(c2, &C2);
- __builtin_vsx_disassemble_pair(c3, &C3);
- __builtin_vsx_disassemble_pair(c4, &C4);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_mergel(c1[0], c2[0]);
- t4 = vec_mergel(c3[0], c4[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t1, t2, 3);
- t7 = vec_xxpermdi(t3, t4, 0);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset);
- vec_xst(t6, 0, boffset+4);
- vec_xst(t7, 0, boffset+8);
- vec_xst(t8, 0, boffset+12);
-
- t1 = vec_mergeh(c1[1], c2[1]);
- t2 = vec_mergeh(c3[1], c4[1]);
- t3 = vec_mergel(c1[1], c2[1]);
- t4 = vec_mergel(c3[1], c4[1]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t1, t2, 3);
- t7 = vec_xxpermdi(t3, t4, 0);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset+16);
- vec_xst(t6, 0, boffset+20);
- vec_xst(t7, 0, boffset+24);
- vec_xst(t8, 0, boffset+28);
-
- aoffset1 += 8*lda;
- aoffset2 += 8*lda;
- aoffset3 += 8*lda;
- aoffset4 += 8*lda;
+ for (int it = 0; it < 4; it++) {
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+ c1[it] = c[it][0];
+ c2[it] = c[it][1];
+ }
+ vector_permute_store_4(c1, boffset);
+ vector_permute_store_4(c2, boffset+16);
+ for (int it = 0; it < 4; it++)
+ aoffsets[it] += 8*lda;
  boffset += 32;
  i--;
  } while(i > 0);
  }
 
  if (cols & 4) {
- c1[0] = vec_xl(0, aoffset1);
- c2[0] = vec_xl(0, aoffset2);
- c3[0] = vec_xl(0, aoffset3);
- c4[0] = vec_xl(0, aoffset4);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_xxpermdi(t1, t2, 0);
- t4 = vec_xxpermdi(t1, t2, 3);
- vec_xst(t3, 0, boffset);
- vec_xst(t4, 0, boffset+4);
-
- t1 = vec_mergel(c1[0], c2[0]);
- t2 = vec_mergel(c3[0], c4[0]);
- t3 = vec_xxpermdi(t1, t2, 0);
- t4 = vec_xxpermdi(t1, t2, 3);
- vec_xst(t3, 0, boffset+8);
- vec_xst(t4, 0, boffset+12);
+ for (int it = 0; it < 4; it++)
+ c1[it] = vec_xl(0, aoffsets[it]);
+ vector_permute_store_4(c1, boffset);
  }
  }
  if (rows & 3) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 3; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
  if (cols & 4) {
- c1[0] = vec_xl(0, aoffset1);
- c2[0] = vec_xl(0, aoffset2);
- c3[0] = vec_xl(0, aoffset3);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_xxpermdi(t1, t2, 0);
- t4 = vec_xxpermdi(t1, t2, 3);
- vec_xst(t3, 0, boffset);
- vec_xst(t4, 0, boffset+4);
-
- t1 = vec_mergel(c1[0], c2[0]);
- t2 = vec_mergel(c3[0], c4[0]);
- t3 = vec_xxpermdi(t1, t2, 0);
- t4 = vec_xxpermdi(t1, t2, 3);
- vec_xst(t3, 0, boffset+8);
- vec_xst(t4, 0, boffset+12);
+ for (int it = 0; it < 3; it++)
+ c1[it] = vec_xl(0, aoffsets[it]);
+ vector_permute_store_4(c1, boffset);
  }
  }
  }
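
Note: the repeated merge/permute/store sequences deleted above are factored into two helpers, vector_permute_store_4 and vector_permute_store_8, whose definitions fall outside this hunk. Reconstructed from the removed inline code and the call sites (name and signature inferred, not verbatim upstream; assumes the VSX/altivec context of sgemm.cpp), the 4-row helper would look roughly like:

    inline void vector_permute_store_4(vector float *src, float *vecOffset) {
        // Pairwise interleave of the four rows: mergeh takes the high halves,
        // mergel the low halves -- the first half of a 4x4 transpose.
        vector float t1 = vec_mergeh(src[0], src[1]);
        vector float t2 = vec_mergeh(src[2], src[3]);
        vector float t3 = vec_mergel(src[0], src[1]);
        vector float t4 = vec_mergel(src[2], src[3]);
        // xxpermdi stitches the doubleword halves into transposed columns,
        // stored contiguously exactly as the removed vec_xst chains did.
        vec_xst(vec_xxpermdi(t1, t2, 0), 0, vecOffset);
        vec_xst(vec_xxpermdi(t1, t2, 3), 0, vecOffset + 4);
        vec_xst(vec_xxpermdi(t3, t4, 0), 0, vecOffset + 8);
        vec_xst(vec_xxpermdi(t3, t4, 3), 0, vecOffset + 12);
    }

vector_permute_store_8 presumably follows the same pattern over eight source rows, writing 32 floats per call, matching the removed 8-row sequence.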
@@ -2957,8 +2326,8 @@ class tinyBLAS_PPC {
  acc_t acc_0;
  __builtin_mma_xxsetaccz(&acc_0);
  for (int l = 0; l < k; l+=4) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+ packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+ packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
@@ -2973,8 +2342,8 @@ class tinyBLAS_PPC {
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
  for (int64_t l = 0; l < k; l+=4) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
+ packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+ packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -2994,8 +2363,8 @@ class tinyBLAS_PPC {
  __builtin_mma_xxsetaccz(&acc_0);
  __builtin_mma_xxsetaccz(&acc_1);
  for (int64_t l = 0; l < k; l+=4) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+ packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
+ packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -3017,8 +2386,8 @@ class tinyBLAS_PPC {
  __builtin_mma_xxsetaccz(&acc_2);
  __builtin_mma_xxsetaccz(&acc_3);
  for (int l = 0; l < k; l+=8) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
+ packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+ packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
  for(int x = 0; x < 16; x+=2) {
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
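
Note: each __builtin_mma_xvf32gerpp above is a rank-1 (outer-product) accumulate: it multiplies a 4-float slice of packed A by a 4-float slice of packed B and adds the 4x4 result into an MMA accumulator tile, so looping over the k dimension in steps of 4 (or 8) builds up a block of C. A scalar reference for the single-instruction semantics, illustrative only:

    // What one xvf32gerpp does, expressed in scalar C++ (sketch, not library code).
    void xvf32gerpp_ref(float acc[4][4], const float a[4], const float b[4]) {
        for (int i = 0; i < 4; i++)
            for (int j = 0; j < 4; j++)
                acc[i][j] += a[i] * b[j];   // positive multiply, positive accumulate
    }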
@@ -3033,155 +2402,37 @@ class tinyBLAS_PPC {
  }
 
  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t mc, nc, mp, np;
- int m_rem = MIN(m - m0, 16);
- int n_rem = MIN(n - n0, 16);
- if (m_rem >= 16 && n_rem >= 8) {
- mc = 8;
- nc = 8;
- gemm<8,8>(m0, m, n0, n);
- } else if(m_rem >= 8 && n_rem >= 16) {
- mc = 8;
- nc = 8;
- gemm<8,8>(m0, m, n0, n);
- } else if (m_rem >= 8 && n_rem >= 8) {
- mc = 8;
- nc = 8;
- gemm<8,8>(m0, m, n0, n);
+ int m_rem = MIN(m - m0, 8);
+ int n_rem = MIN(n - n0, 8);
+ int mc = 0, nc = 0;
+ if (m_rem >= 8 && n_rem >= 8) {
+ mc = 8;
+ nc = 8;
+ gemm<8, 8>(m0, m, n0, n);
  } else if (m_rem >= 4 && n_rem >= 8) {
- mc = 4;
- nc = 8;
- gemm<4,8>(m0, m, n0, n);
+ mc = 4;
+ nc = 8;
+ gemm<4, 8>(m0, m, n0, n);
  } else if (m_rem >= 8 && n_rem >= 4) {
- mc = 8;
- nc = 4;
- gemm<8,4>(m0, m, n0, n);
+ mc = 8;
+ nc = 4;
+ gemm<8, 4>(m0, m, n0, n);
  } else if (m_rem >= 4 && n_rem >= 4) {
- mc = 4;
- nc = 4;
- gemm<4,4>(m0, m, n0, n);
- } else if ((m_rem < 4) && (n_rem > 4)) {
- nc = 4;
- switch(m_rem) {
- case 1:
- mc = 1;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 2:
- mc = 2;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 3:
- mc = 3;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- default:
- return;
- }
- } else if ((m_rem > 4) && (n_rem < 4)) {
- mc = 4;
- switch(n_rem) {
- case 1:
- nc = 1;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 2:
- nc = 2;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 3:
- nc = 3;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- default:
- return;
- }
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
  } else {
- switch((m_rem << 4) | n_rem) {
- case 0x43:
- mc = 4;
- nc = 3;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x42:
- mc = 4;
- nc = 2;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x41:
- mc = 4;
- nc = 1;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x34:
- mc = 3;
- nc = 4;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x33:
- mc = 3;
- nc = 3;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x32:
- mc = 3;
- nc = 2;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x31:
- mc = 3;
- nc = 1;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x24:
- mc = 2;
- nc = 4;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x23:
- mc = 2;
- nc = 3;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x22:
- mc = 2;
- nc = 2;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x21:
- mc = 2;
- nc = 1;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x14:
- mc = 1;
- nc = 4;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x13:
- mc = 1;
- nc = 3;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x12:
- mc = 1;
- nc = 2;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x11:
- mc = 1;
- nc = 1;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- default:
- return;
- }
+ mc = (m_rem >= 4) ? 4 : m_rem;
+ nc = (n_rem >= 4) ? 4 : n_rem;
+ if (mc == 0 || nc == 0)
+ return;
+ gemm_small(m0, m, n0, n, mc, nc);
  }
- mp = m0 + (m - m0) / mc * mc;
- np = n0 + (n - n0) / nc * nc;
+ int64_t mp = m0 + ((m - m0) / mc) * mc;
+ int64_t np = n0 + ((n - n0) / nc) * nc;
  mnpack(mp, m, n0, np);
  mnpack(m0, m, np, n);
- }
+ }
 
  void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
  int64_t ytiles = (m - m0) / RM;
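
Note: the rewritten mnpack replaces the hand-enumerated 15-case switch with a simple rule (take the largest tile of 8 or 4 that fits the remainder, fall back to gemm_small below 4, then recurse on the leftover row and column strips). A standalone sketch that mirrors the new control flow, with the gemm/gemm_small calls replaced by a printf so the tiling can be traced (illustration only, not library code; uses std::min in place of the file's MIN macro):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    static void mnpack_trace(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int m_rem = (int)std::min<int64_t>(m - m0, 8);
        int n_rem = (int)std::min<int64_t>(n - n0, 8);
        int mc = 0, nc = 0;
        if      (m_rem >= 8 && n_rem >= 8) { mc = 8; nc = 8; }   // gemm<8,8>
        else if (m_rem >= 4 && n_rem >= 8) { mc = 4; nc = 8; }   // gemm<4,8>
        else if (m_rem >= 8 && n_rem >= 4) { mc = 8; nc = 4; }   // gemm<8,4>
        else if (m_rem >= 4 && n_rem >= 4) { mc = 4; nc = 4; }   // gemm<4,4>
        else {                                                   // gemm_small
            mc = (m_rem >= 4) ? 4 : m_rem;
            nc = (n_rem >= 4) ? 4 : n_rem;
            if (mc == 0 || nc == 0)
                return;                    // nothing left to cover
        }
        int64_t mp = m0 + ((m - m0) / mc) * mc;
        int64_t np = n0 + ((n - n0) / nc) * nc;
        printf("rows [%lld,%lld) x cols [%lld,%lld) via %dx%d tiles\n",
               (long long)m0, (long long)mp, (long long)n0, (long long)np, mc, nc);
        mnpack_trace(mp, m, n0, np);   // leftover rows (bottom strip)
        mnpack_trace(m0, m, np, n);    // leftover cols (right strip)
    }

    int main() { mnpack_trace(0, 13, 0, 10); }

For a 13x10 output this prints one 8x8 block, then 4x8, 1x4, 4x2, and 1x2 edge blocks -- the same coverage the removed switch enumerated case by case.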
@@ -3206,22 +2457,22 @@ class tinyBLAS_PPC {
  * matrix elements.
  */
  if (RM == 1) {
- TA* a = const_cast<TA*>(A+(ii)*lda+l);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
+ float* a = const_cast<float*>(A+(ii)*lda+l);
+ packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
  vec_A[0] = (vec_t)vec_xl(0,a);
- vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
- vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
- vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
+ vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
+ vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
+ vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
  } else if (RN == 1) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
- TB* b = const_cast<TB*>(B+(jj)*ldb+l);
+ packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
+ float* b = const_cast<float*>(B+(jj)*ldb+l);
  vec_B[0] = (vec_t)vec_xl(0,b);
- vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
- vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
- vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
+ vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1));
+ vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2));
+ vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3));
  } else {
- packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
+ packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
+ packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
  }
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -3231,7 +2482,7 @@ class tinyBLAS_PPC {
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
  for (int I = 0; I < RM; I++) {
  for (int J = 0; J < RN; J++) {
- *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
  }
  }
  }
@@ -3263,11 +2514,9 @@ class tinyBLAS_PPC {
  }
  }
 
- const TA *const A;
- const TB *const B;
- TC *C;
- TA *At;
- TB *Bt;
+ const float *const A;
+ const float *const B;
+ float *C;
  const int64_t k;
  const int64_t lda;
  const int64_t ldb;
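
Note: this member hunk is the core of the de-templating. tinyBLAS_PPC was only ever instantiated as <float, float, float> (see the call site below), and the At/Bt scratch pointers were unused, so the TA/TB/TC parameters and dead members are dropped. A minimal sketch of the resulting class shape -- the constructor arguments are inferred from the brace-initialized call sites, and the trailing ith/nth thread parameters are an assumption, not the verbatim upstream declaration:

    #include <cstdint>

    // Sketch only: float-only tinyBLAS_PPC after removing the template parameters.
    class tinyBLAS_PPC {
      public:
        tinyBLAS_PPC(int64_t k, const float *A, int64_t lda,
                     const float *B, int64_t ldb,
                     float *C, int64_t ldc, int ith, int nth)
            : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc),
              ith(ith), nth(nth) {}
      private:
        const float *const A;   // left operand
        const float *const B;   // right operand
        float *C;               // output
        const int64_t k, lda, ldb, ldc;
        const int ith, nth;     // thread id / thread count (assumed)
    };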
@@ -3366,7 +2615,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
  #elif defined(__MMA__)
  if (k % 8)
  return false;
- tinyBLAS_PPC<float, float, float> tb{
+ tinyBLAS_PPC tb{
  k, (const float *)A, lda,
  (const float *)B, ldb,
  (float *)C, ldc,
@@ -3493,7 +2742,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
  return false;
  if (m < 8 && m != 4)
  return false;
- tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
+ tinyBLAS_Q0_PPC<block_q8_0> tb{
  k, (const block_q8_0 *)A, lda,
  (const block_q8_0 *)B, ldb,
  (float *)C, ldc,
@@ -3530,7 +2779,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
  return false;
  if (m < 8 && m != 4)
  return false;
- tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
+ tinyBLAS_Q0_PPC<block_q4_0> tb{
  k, (const block_q4_0 *)A, lda,
  (const block_q8_0 *)B, ldb,
  (float *)C, ldc,
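
Note: the quantized variant gets the same treatment, but keeps one template parameter. Both call sites above fix the right-hand operand to block_q8_0 and the output to float; only the left-hand block type varies between q8_0 and q4_0. The three-parameter form therefore collapses to a shape roughly like this (sketch inferred from the two instantiations, not verbatim upstream):

    // Only the A-side block type remains a parameter; B is always block_q8_0
    // and the output float, per the call sites above.
    template <typename TA>   // TA is block_q8_0 or block_q4_0
    class tinyBLAS_Q0_PPC;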