mini_embed 0.4.0 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (3)
  1. checksums.yaml +4 -4
  2. data/ext/mini_embed/mini_embed.c +153 -61
  3. metadata +1 -1
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 0f3a2f9365c3ba228faf709ec97d986dadcc22a78ab0b706d35ba3da5e1552ce
4
- data.tar.gz: 778847fe77dc4cb8b8774b62fe6f212c435880e470868f15c2517f1adb37211a
3
+ metadata.gz: d41689bea618e06a4f1b94ee27aa4ed2c68a8d13677b5c31473e50cce5812062
4
+ data.tar.gz: fbea8167ae5dc748a00ec2670150ccf6d2a38ca4861a90df1e287f0c6adb9854
5
5
  SHA512:
6
- metadata.gz: '086948ced123967c0aa5f7e0fcb6624dffb2f68f6c95e4abd2dbaa429fb8717d98be76f6b173925b21af1bc14a0fe8af9d6d68d891ba4ce90d5a0b2145df55ef'
7
- data.tar.gz: 2ac2c25baf87dd7b21fc38ccce6c3be0a3c133008ef27fa1790a95d3bc6146d5cb522166c93d8961edb16c3128e13d8f0d9206869abee43af10586f0124e00f2
6
+ metadata.gz: 4ee669edc9f38921ec3d195cfe8a781c64c107d70f3fee8d1e47ef0e916161294c6243af0f223c2f8efea76c57c329a12c6b72566233b791839f5b0909058efc
7
+ data.tar.gz: 370a99e583b830ac99b4bd70eca14769d7bda162f5acab4456e63cec226cb1d3631831cf92bf87c2523909cdfec7fd231d5746cef9c30becfdda1da7400eea0e
@@ -637,10 +637,13 @@ static void dequantize_row_q5_0(const void *vx, float *y, int k) {
637
637
  uint32_t qh32;
638
638
  memcpy(&qh32, block + 2, 4);
639
639
  const uint8_t *ql = block + 6;
640
- for (int j = 0; j < 32; j++) {
641
- const uint8_t vh = (qh32 >> j) & 1;
642
- const int v = ((ql[j/2] >> (4*(j%2))) & 0x0F) | (vh << 4);
643
- y[i*32 + j] = (v - 16.0f) * d;
640
+ for (int j = 0; j < 16; j++) {
641
+ const uint8_t xh0 = ((qh32 >> (j + 0)) << 4) & 0x10;
642
+ const uint8_t xh1 = ((qh32 >> (j + 12))) & 0x10;
643
+ const int x0 = ((ql[j] & 0x0F) | xh0) - 16;
644
+ const int x1 = ((ql[j] >> 4) | xh1) - 16;
645
+ y[i*32 + j] = x0 * d;
646
+ y[i*32 + j + 16] = x1 * d;
644
647
  }
645
648
  }
646
649
  }
@@ -658,10 +661,13 @@ static void dequantize_row_q5_1(const void *vx, float *y, int k) {
658
661
  uint32_t qh32;
659
662
  memcpy(&qh32, block + 4, 4);
660
663
  const uint8_t *ql = block + 8;
661
- for (int j = 0; j < 32; j++) {
662
- const uint8_t vh = (qh32 >> j) & 1;
663
- const int v = ((ql[j/2] >> (4*(j%2))) & 0x0F) | (vh << 4);
664
- y[i*32 + j] = v * d + m;
664
+ for (int j = 0; j < 16; j++) {
665
+ const uint8_t xh0 = ((qh32 >> (j + 0)) << 4) & 0x10;
666
+ const uint8_t xh1 = ((qh32 >> (j + 12))) & 0x10;
667
+ const int x0 = (ql[j] & 0x0F) | xh0;
668
+ const int x1 = (ql[j] >> 4) | xh1;
669
+ y[i*32 + j] = x0 * d + m;
670
+ y[i*32 + j + 16] = x1 * d + m;
665
671
  }
666
672
  }
667
673
  }
@@ -690,10 +696,10 @@ static void dequantize_row_q8_1(const void *vx, float *y, int k) {
690
696
  memcpy(&d16, block, 2);
691
697
  memcpy(&s16, block + 2, 2);
692
698
  const float d = fp16_to_fp32(d16);
693
- const float s = fp16_to_fp32(s16);
699
+ (void)s16;
694
700
  const int8_t *q = (const int8_t*)(block + 4);
695
701
  for (int j = 0; j < 32; j++) {
696
- y[i*32 + j] = (float)q[j] * d + s;
702
+ y[i*32 + j] = (float)q[j] * d;
697
703
  }
698
704
  }
699
705
  }
@@ -704,8 +710,8 @@ static inline void get_scale_min_k4(int j, const uint8_t *q, uint8_t *d, uint8_t
704
710
  *d = q[j] & 63;
705
711
  *m = q[j + 4] & 63;
706
712
  } else {
707
- *d = (q[j+4] & 0xF) | ((q[j-3] >> 6) << 4);
708
- *m = (q[j+4] >> 4) | ((q[j-1] >> 6) << 4);
713
+ *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
714
+ *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
709
715
  }
710
716
  }
711
717
 
@@ -715,19 +721,30 @@ static void dequantize_row_q2_K(const void *vx, float *y, int k) {
715
721
  for (int i = 0; i < nb; i++) {
716
722
  const uint8_t *block = x + i * 84;
717
723
  uint16_t d16, dmin16;
718
- memcpy(&d16, block, 2);
719
- memcpy(&dmin16, block + 2, 2);
724
+ memcpy(&d16, block + 80, 2);
725
+ memcpy(&dmin16, block + 82, 2);
720
726
  const float d = fp16_to_fp32(d16);
721
727
  const float min = fp16_to_fp32(dmin16);
722
- const uint8_t *scales = block + 4;
723
- const uint8_t *q = block + 20;
724
- for (int j = 0; j < QK_K; j += 64) {
725
- const float dl = d * (scales[j/64] & 0xF);
726
- const float ml = min * (scales[j/64] >> 4);
727
- for (int l = 0; l < 64; l++) {
728
- const int v = (q[(j+l)/4] >> (2*((j+l)%4))) & 0x03;
729
- y[i*QK_K + j + l] = v * dl + ml;
728
+ const uint8_t *scales = block;
729
+ const uint8_t *q = block + 16;
730
+ float *dst = y + (size_t)i * QK_K;
731
+ int is = 0;
732
+ for (int n = 0; n < QK_K; n += 128) {
733
+ int shift = 0;
734
+ for (int j = 0; j < 4; j++) {
735
+ uint8_t sc = scales[is++];
736
+ float dl = d * (sc & 0x0F);
737
+ float ml = min * (sc >> 4);
738
+ for (int l = 0; l < 16; l++) *dst++ = dl * ((q[l] >> shift) & 3) - ml;
739
+
740
+ sc = scales[is++];
741
+ dl = d * (sc & 0x0F);
742
+ ml = min * (sc >> 4);
743
+ for (int l = 0; l < 16; l++) *dst++ = dl * ((q[l + 16] >> shift) & 3) - ml;
744
+
745
+ shift += 2;
730
746
  }
747
+ q += 32;
731
748
  }
732
749
  }
733
750
  }
@@ -735,30 +752,45 @@ static void dequantize_row_q2_K(const void *vx, float *y, int k) {
735
752
  static void dequantize_row_q3_K(const void *vx, float *y, int k) {
736
753
  const int nb = k / QK_K;
737
754
  const uint8_t *x = vx;
755
+ const uint32_t kmask1 = 0x03030303;
756
+ const uint32_t kmask2 = 0x0f0f0f0f;
757
+ uint32_t aux[4];
758
+ const int8_t *scales = (const int8_t*)aux;
738
759
  for (int i = 0; i < nb; i++) {
739
760
  const uint8_t *block = x + i * 110;
740
761
  uint16_t d16;
741
- memcpy(&d16, block, 2);
742
- const float d = fp16_to_fp32(d16);
743
- const uint8_t *hmask = block + 2;
744
- const uint8_t *q = block + 34;
745
- const uint8_t *scales = block + 98;
746
- for (int j = 0; j < QK_K; j += 64) {
747
- const uint8_t ls1 = scales[j/64] & 0x1F;
748
- const uint8_t ls2 = (scales[j/64] >> 5) | ((scales[j/64 + 1] & 0x7) << 3);
749
- const uint8_t ls3 = ((scales[j/64 + 1] >> 3) & 0x1F);
750
- const uint8_t ls4 = (scales[j/64 + 1] >> 8);
751
- for (int l = 0; l < 64; l++) {
752
- int v = (q[(j+l)/2] >> (4*((j+l)%2))) & 0x0F;
753
- const int bit = (hmask[(j+l)/8] >> ((j+l)%8)) & 1;
754
- v |= bit << 4;
755
- float ls;
756
- if (l < 16) ls = ls1;
757
- else if (l < 32) ls = ls2;
758
- else if (l < 48) ls = ls3;
759
- else ls = ls4;
760
- y[i*QK_K + j + l] = (v - 32.0f) * d * ls;
762
+ memcpy(&d16, block + 108, 2);
763
+ const float d_all = fp16_to_fp32(d16);
764
+ const uint8_t *q = block + 32;
765
+ const uint8_t *hm = block;
766
+ uint8_t m = 1;
767
+ float *dst = y + (size_t)i * QK_K;
768
+
769
+ memcpy(aux, block + 96, 12);
770
+ uint32_t tmp = aux[2];
771
+ aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
772
+ aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
773
+ aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
774
+ aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
775
+
776
+ int is = 0;
777
+ for (int n = 0; n < QK_K; n += 128) {
778
+ int shift = 0;
779
+ for (int j = 0; j < 4; j++) {
780
+ float dl = d_all * (scales[is++] - 32);
781
+ for (int l = 0; l < 16; l++) {
782
+ *dst++ = dl * ((int)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
783
+ }
784
+
785
+ dl = d_all * (scales[is++] - 32);
786
+ for (int l = 0; l < 16; l++) {
787
+ *dst++ = dl * ((int)((q[l + 16] >> shift) & 3) - ((hm[l + 16] & m) ? 0 : 4));
788
+ }
789
+
790
+ shift += 2;
791
+ m <<= 1;
761
792
  }
793
+ q += 32;
762
794
  }
763
795
  }
764
796
  }
@@ -810,6 +842,7 @@ static void dequantize_row_q5_K(const void *vx, float *y, int k) {
810
842
  const uint8_t *qh = block + 16;
811
843
  const uint8_t *ql = block + 48;
812
844
  int is = 0;
845
+ uint8_t u1 = 1, u2 = 2;
813
846
  for (int j = 0; j < QK_K; j += 64) {
814
847
  uint8_t sc, m;
815
848
  get_scale_min_k4(is, scales, &sc, &m);
@@ -819,17 +852,17 @@ static void dequantize_row_q5_K(const void *vx, float *y, int k) {
819
852
  float d2 = d * sc;
820
853
  float m2 = min * m;
821
854
  for (int l = 0; l < 32; l++) {
822
- int vh = (qh[j/64 * 4 + l/8] >> (l%8)) & 1;
823
- int v = (ql[l] & 0xF) | (vh << 4);
855
+ int v = (ql[l] & 0xF) + ((qh[l] & u1) ? 16 : 0);
824
856
  y[i*QK_K + j + l] = d1 * v - m1;
825
857
  }
826
858
  for (int l = 0; l < 32; l++) {
827
- int vh = (qh[j/64 * 4 + 4 + l/8] >> (l%8)) & 1;
828
- int v = (ql[l] >> 4) | (vh << 4);
859
+ int v = (ql[l] >> 4) + ((qh[l] & u2) ? 16 : 0);
829
860
  y[i*QK_K + j + 32 + l] = d2 * v - m2;
830
861
  }
831
862
  ql += 32;
832
863
  is += 2;
864
+ u1 <<= 2;
865
+ u2 <<= 2;
833
866
  }
834
867
  }
835
868
  }
@@ -845,23 +878,23 @@ static void dequantize_row_q6_K(const void *vx, float *y, int k) {
845
878
  uint16_t d16;
846
879
  memcpy(&d16, block + 208, 2);
847
880
  const float d = fp16_to_fp32(d16);
881
+ float *dst = y + (size_t)i * QK_K;
848
882
  for (int j = 0; j < QK_K; j += 128) {
849
883
  for (int l = 0; l < 32; l++) {
850
- int v = (ql[j/2 + l] & 0xF) | (((qh[j/4 + l/2] >> ((l%2)*4)) & 0xF) << 4);
851
- y[i*QK_K + j + l] = v * d * scales[j/128 * 8 + l/4];
852
- }
853
- for (int l = 0; l < 32; l++) {
854
- int v = (ql[j/2 + 32 + l] >> 4) | (((qh[j/4 + 16 + l/2] >> ((l%2)*4)) & 0xF) << 4);
855
- y[i*QK_K + j + 32 + l] = v * d * scales[j/128 * 8 + 8 + l/4];
856
- }
857
- for (int l = 0; l < 32; l++) {
858
- int v = (ql[j/2 + 64 + l] & 0xF) | (((qh[j/4 + 32 + l/2] >> ((l%2)*4)) & 0xF) << 4);
859
- y[i*QK_K + j + 64 + l] = v * d * scales[j/128 * 8 + 4 + l/4];
860
- }
861
- for (int l = 0; l < 32; l++) {
862
- int v = (ql[j/2 + 96 + l] >> 4) | (((qh[j/4 + 48 + l/2] >> ((l%2)*4)) & 0xF) << 4);
863
- y[i*QK_K + j + 96 + l] = v * d * scales[j/128 * 8 + 12 + l/4];
884
+ int is = l / 16;
885
+ int q1 = ((ql[l] & 0x0F) | (((qh[l] >> 0) & 3) << 4)) - 32;
886
+ int q2 = ((ql[l + 32] & 0x0F) | (((qh[l] >> 2) & 3) << 4)) - 32;
887
+ int q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
888
+ int q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
889
+ dst[l] = d * scales[is + 0] * q1;
890
+ dst[l + 32] = d * scales[is + 2] * q2;
891
+ dst[l + 64] = d * scales[is + 4] * q3;
892
+ dst[l + 96] = d * scales[is + 6] * q4;
864
893
  }
894
+ dst += 128;
895
+ ql += 64;
896
+ qh += 32;
897
+ scales += 8;
865
898
  }
866
899
  }
867
900
  }
@@ -904,7 +937,7 @@ static void dequantize_row_lazy(const EmbedModel *m, int row, float *out) {
904
937
  case GGML_TYPE_Q5_0: rb = (nc / 32) * 22; break;
905
938
  case GGML_TYPE_Q5_1: rb = (nc / 32) * 24; break;
906
939
  case GGML_TYPE_Q8_0: rb = (nc / 32) * 34; break;
907
- case GGML_TYPE_Q8_1: rb = (nc / 32) * 40; break;
940
+ case GGML_TYPE_Q8_1: rb = (nc / 32) * 36; break;
908
941
  case GGML_TYPE_Q2_K: rb = (nc / 256) * 84; break;
909
942
  case GGML_TYPE_Q3_K: rb = (nc / 256) * 110; break;
910
943
  case GGML_TYPE_Q4_K: rb = (nc / 256) * 144; break;
@@ -975,6 +1008,30 @@ static void dequantize_row_lazy(const EmbedModel *m, int row, float *out) {
975
1008
  }
976
1009
  }
977
1010
 
1011
+ static int tensor_type_block_size(int type) {
1012
+ switch (type) {
1013
+ case GGML_TYPE_F32:
1014
+ case GGML_TYPE_F16:
1015
+ return 1;
1016
+ case GGML_TYPE_Q4_0:
1017
+ case GGML_TYPE_Q4_1:
1018
+ case GGML_TYPE_Q5_0:
1019
+ case GGML_TYPE_Q5_1:
1020
+ case GGML_TYPE_Q8_0:
1021
+ case GGML_TYPE_Q8_1:
1022
+ return QK8_0;
1023
+ case GGML_TYPE_Q2_K:
1024
+ case GGML_TYPE_Q3_K:
1025
+ case GGML_TYPE_Q4_K:
1026
+ case GGML_TYPE_Q5_K:
1027
+ case GGML_TYPE_Q6_K:
1028
+ case GGML_TYPE_Q8_K:
1029
+ return QK_K;
1030
+ default:
1031
+ return 0;
1032
+ }
1033
+ }
1034
+
978
1035
  static size_t get_row_bytes(int type, int n_cols) {
979
1036
  switch (type) {
980
1037
  case GGML_TYPE_F32: return n_cols * sizeof(float);
@@ -1261,6 +1318,11 @@ static EmbedModel *embed_load_gguf(const char *path) {
1261
1318
  for (uint32_t d = 0; d < t->n_dims; d++) t->dims[d] = rd64(&cur, end);
1262
1319
  t->type = (int)rd32(&cur, end);
1263
1320
  uint64_t offset = rd64(&cur, end);
1321
+ int block_size = tensor_type_block_size(t->type);
1322
+ if (block_size == 0 || t->dims[0] % (uint64_t)block_size != 0) {
1323
+ free_model_contents(m);
1324
+ return NULL;
1325
+ }
1264
1326
  t->row_bytes = get_row_bytes(t->type, (int)t->dims[0]);
1265
1327
  if (t->row_bytes == 0) { free_model_contents(m); return NULL; }
1266
1328
  t->data = (const uint8_t*)(uintptr_t)offset;
@@ -1343,9 +1405,39 @@ static void tensor_get_row(const Tensor *t, int row, float *out) {
1343
1405
  case GGML_TYPE_Q4_0:
1344
1406
  dequantize_row_q4_0(raw, out, cols);
1345
1407
  break;
1408
+ case GGML_TYPE_Q4_1:
1409
+ dequantize_row_q4_1(raw, out, cols);
1410
+ break;
1411
+ case GGML_TYPE_Q5_0:
1412
+ dequantize_row_q5_0(raw, out, cols);
1413
+ break;
1414
+ case GGML_TYPE_Q5_1:
1415
+ dequantize_row_q5_1(raw, out, cols);
1416
+ break;
1346
1417
  case GGML_TYPE_Q8_0:
1347
1418
  dequantize_row_q8_0(raw, out, cols);
1348
1419
  break;
1420
+ case GGML_TYPE_Q8_1:
1421
+ dequantize_row_q8_1(raw, out, cols);
1422
+ break;
1423
+ case GGML_TYPE_Q2_K:
1424
+ dequantize_row_q2_K(raw, out, cols);
1425
+ break;
1426
+ case GGML_TYPE_Q3_K:
1427
+ dequantize_row_q3_K(raw, out, cols);
1428
+ break;
1429
+ case GGML_TYPE_Q4_K:
1430
+ dequantize_row_q4_K(raw, out, cols);
1431
+ break;
1432
+ case GGML_TYPE_Q5_K:
1433
+ dequantize_row_q5_K(raw, out, cols);
1434
+ break;
1435
+ case GGML_TYPE_Q6_K:
1436
+ dequantize_row_q6_K(raw, out, cols);
1437
+ break;
1438
+ case GGML_TYPE_Q8_K:
1439
+ dequantize_row_q8_K(raw, out, cols);
1440
+ break;
1349
1441
  default:
1350
1442
  memset(out, 0, (size_t)cols * sizeof(float));
1351
1443
  break;
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: mini_embed
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.0
4
+ version: 0.4.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Makapoxa