mini_embed 0.3.0 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (4)
  1. checksums.yaml +4 -4
  2. data/README.md +2 -0
  3. data/ext/mini_embed/mini_embed.c +643 -159
  4. metadata +1 -1
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: d6230ebfba3a401a8d26543f106e46952198bc6e89c5cf7632da40346933cf64
4
- data.tar.gz: d5d37dd58c4bb3671053acb280db02ebb2ef78722d9c115f57f2594ad3a9ab50
3
+ metadata.gz: d41689bea618e06a4f1b94ee27aa4ed2c68a8d13677b5c31473e50cce5812062
4
+ data.tar.gz: fbea8167ae5dc748a00ec2670150ccf6d2a38ca4861a90df1e287f0c6adb9854
5
5
  SHA512:
6
- metadata.gz: a826aad05808580120035f689412afdf976d77637cd9b1cb57df02740a7c86efff0120bf7fba172498e0b5d7ed82617bac99777e731d33e84bcbb823db543e29
7
- data.tar.gz: f5bb3db889b9c51348daed59c3fbab9496237c3e9a64cb908ef386a1093e5e678531a5ad10eb051d0614dbe1fb9217d93a32049e6a5b8392b053d2474d6e9606
6
+ metadata.gz: 4ee669edc9f38921ec3d195cfe8a781c64c107d70f3fee8d1e47ef0e916161294c6243af0f223c2f8efea76c57c329a12c6b72566233b791839f5b0909058efc
7
+ data.tar.gz: 370a99e583b830ac99b4bd70eca14769d7bda162f5acab4456e63cec226cb1d3631831cf92bf87c2523909cdfec7fd231d5746cef9c30becfdda1da7400eea0e
data/README.md CHANGED
@@ -1,5 +1,7 @@
1
1
  # mini_embed
2
2
 
3
+ [![CircleCI](https://dl.circleci.com/status-badge/img/gh/Makapoxa/mini_embed/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/Makapoxa/mini_embed/tree/main) [![Gem Version](https://badge.fury.io/rb/mini_embed.svg)](https://badge.fury.io/rb/mini_embed)
4
+
3
5
  A minimal, dependency‑free C extension for Ruby that loads [GGUF](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) embedding models and computes text embeddings **locally**.
4
6
 
5
7
  **⚠️ Important:** This gem is intended for **small projects, prototypes, and hobbyist use**. It allows you to experiment with embeddings without relying on external APIs or cloud costs. **Do not use MiniEmbed in production** – it lacks the performance, scalability, and tokenization robustness of dedicated solutions. For real applications, use a proper inference server like [llama.cpp](https://github.com/ggerganov/llama.cpp) with its HTTP API, or managed services such as OpenAI, Cohere, or Hugging Face.
@@ -440,9 +440,23 @@ typedef struct HashNode {
440
440
  struct HashNode *next;
441
441
  } HashNode;
442
442
 
443
+ typedef struct {
444
+ char *name;
445
+ uint32_t n_dims;
446
+ uint64_t dims[MAX_DIMS];
447
+ int type;
448
+ const uint8_t *data;
449
+ size_t row_bytes;
450
+ } Tensor;
451
+
443
452
  typedef struct {
444
453
  int vocab_size;
445
454
  int dim;
455
+ int n_layers;
456
+ int n_heads;
457
+ int n_ctx;
458
+ int n_ff;
459
+ float eps;
446
460
  char **tokens;
447
461
  void *mapped;
448
462
  size_t mapped_size;
@@ -460,6 +474,11 @@ typedef struct {
460
474
  int need_transpose;
461
475
  uint64_t raw_dim0, raw_dim1;
462
476
  int normalize;
477
+ Tensor *tensors;
478
+ int n_tensors;
479
+ int sep_token_id;
480
+ int pad_token_id;
481
+ int cls_token_id;
463
482
  } EmbedModel;
464
483
 
465
484
  typedef struct {
@@ -541,6 +560,36 @@ static float fp16_to_fp32(uint16_t h) {
541
560
  return result;
542
561
  }
543
562
 
563
+ static uint16_t fp32_to_fp16(float f) {
564
+ uint32_t x;
565
+ memcpy(&x, &f, sizeof(x));
566
+
567
+ uint32_t sign = (x >> 16) & 0x8000;
568
+ int exp = ((x >> 23) & 0xFF) - 127 + 15;
569
+ uint32_t mant = x & 0x7FFFFF;
570
+
571
+ if (exp <= 0) {
572
+ if (exp < -10) return (uint16_t)sign;
573
+ mant |= 0x800000;
574
+ uint32_t t = mant >> (1 - exp);
575
+ if (t & 0x00001000) t += 0x00002000;
576
+ return (uint16_t)(sign | (t >> 13));
577
+ } else if (exp >= 31) {
578
+ if (mant == 0) return (uint16_t)(sign | 0x7C00);
579
+ return (uint16_t)(sign | 0x7C00 | (mant >> 13));
580
+ } else {
581
+ if (mant & 0x00001000) {
582
+ mant += 0x00002000;
583
+ if (mant & 0x00800000) {
584
+ mant = 0;
585
+ exp += 1;
586
+ }
587
+ }
588
+ if (exp >= 31) return (uint16_t)(sign | 0x7C00);
589
+ return (uint16_t)(sign | ((uint32_t)exp << 10) | (mant >> 13));
590
+ }
591
+ }
592
+
544
593
  /* ------------------------------------------------------------------------- */
545
594
  // Block dequantization functions (correct sizes)
546
595
  static void dequantize_row_q4_0(const void *vx, float *y, int k) {
@@ -552,9 +601,9 @@ static void dequantize_row_q4_0(const void *vx, float *y, int k) {
552
601
  memcpy(&d16, block, 2);
553
602
  const float d = fp16_to_fp32(d16);
554
603
  const uint8_t *q = block + 2;
555
- for (int j = 0; j < 32; j++) {
556
- const int v = (q[j/2] >> (4*(j%2))) & 0x0F;
557
- y[i*32 + j] = (v - 8.0f) * d;
604
+ for (int j = 0; j < 16; j++) {
605
+ y[i*32 + j] = ((q[j] & 0x0F) - 8.0f) * d;
606
+ y[i*32 + j + 16] = ((q[j] >> 4) - 8.0f) * d;
558
607
  }
559
608
  }
560
609
  }
@@ -570,9 +619,9 @@ static void dequantize_row_q4_1(const void *vx, float *y, int k) {
570
619
  const float d = fp16_to_fp32(d16);
571
620
  const float m = fp16_to_fp32(m16);
572
621
  const uint8_t *q = block + 4;
573
- for (int j = 0; j < 32; j++) {
574
- const int v = (q[j/2] >> (4*(j%2))) & 0x0F;
575
- y[i*32 + j] = v * d + m;
622
+ for (int j = 0; j < 16; j++) {
623
+ y[i*32 + j] = (q[j] & 0x0F) * d + m;
624
+ y[i*32 + j + 16] = (q[j] >> 4) * d + m;
576
625
  }
577
626
  }
578
627
  }
@@ -588,10 +637,13 @@ static void dequantize_row_q5_0(const void *vx, float *y, int k) {
588
637
  uint32_t qh32;
589
638
  memcpy(&qh32, block + 2, 4);
590
639
  const uint8_t *ql = block + 6;
591
- for (int j = 0; j < 32; j++) {
592
- const uint8_t vh = (qh32 >> j) & 1;
593
- const int v = ((ql[j/2] >> (4*(j%2))) & 0x0F) | (vh << 4);
594
- y[i*32 + j] = (v - 16.0f) * d;
640
+ for (int j = 0; j < 16; j++) {
641
+ const uint8_t xh0 = ((qh32 >> (j + 0)) << 4) & 0x10;
642
+ const uint8_t xh1 = ((qh32 >> (j + 12))) & 0x10;
643
+ const int x0 = ((ql[j] & 0x0F) | xh0) - 16;
644
+ const int x1 = ((ql[j] >> 4) | xh1) - 16;
645
+ y[i*32 + j] = x0 * d;
646
+ y[i*32 + j + 16] = x1 * d;
595
647
  }
596
648
  }
597
649
  }
@@ -609,10 +661,13 @@ static void dequantize_row_q5_1(const void *vx, float *y, int k) {
609
661
  uint32_t qh32;
610
662
  memcpy(&qh32, block + 4, 4);
611
663
  const uint8_t *ql = block + 8;
612
- for (int j = 0; j < 32; j++) {
613
- const uint8_t vh = (qh32 >> j) & 1;
614
- const int v = ((ql[j/2] >> (4*(j%2))) & 0x0F) | (vh << 4);
615
- y[i*32 + j] = v * d + m;
664
+ for (int j = 0; j < 16; j++) {
665
+ const uint8_t xh0 = ((qh32 >> (j + 0)) << 4) & 0x10;
666
+ const uint8_t xh1 = ((qh32 >> (j + 12))) & 0x10;
667
+ const int x0 = (ql[j] & 0x0F) | xh0;
668
+ const int x1 = (ql[j] >> 4) | xh1;
669
+ y[i*32 + j] = x0 * d + m;
670
+ y[i*32 + j + 16] = x1 * d + m;
616
671
  }
617
672
  }
618
673
  }
@@ -622,9 +677,10 @@ static void dequantize_row_q8_0(const void *vx, float *y, int k) {
622
677
  const uint8_t *x = vx;
623
678
  for (int i = 0; i < nb; i++) {
624
679
  const uint8_t *block = x + i * 34;
625
- float d;
626
- memcpy(&d, block, 4);
627
- const int8_t *q = (const int8_t*)(block + 4);
680
+ uint16_t d16;
681
+ memcpy(&d16, block, 2);
682
+ const float d = fp16_to_fp32(d16);
683
+ const int8_t *q = (const int8_t*)(block + 2);
628
684
  for (int j = 0; j < 32; j++) {
629
685
  y[i*32 + j] = (float)q[j] * d;
630
686
  }
@@ -635,13 +691,15 @@ static void dequantize_row_q8_1(const void *vx, float *y, int k) {
635
691
  const int nb = k / QK8_0;
636
692
  const uint8_t *x = vx;
637
693
  for (int i = 0; i < nb; i++) {
638
- const uint8_t *block = x + i * 40;
639
- float d, s;
640
- memcpy(&d, block, 4);
641
- memcpy(&s, block + 4, 4);
642
- const int8_t *q = (const int8_t*)(block + 8);
694
+ const uint8_t *block = x + i * 36;
695
+ uint16_t d16, s16;
696
+ memcpy(&d16, block, 2);
697
+ memcpy(&s16, block + 2, 2);
698
+ const float d = fp16_to_fp32(d16);
699
+ (void)s16;
700
+ const int8_t *q = (const int8_t*)(block + 4);
643
701
  for (int j = 0; j < 32; j++) {
644
- y[i*32 + j] = (float)q[j] * d + s;
702
+ y[i*32 + j] = (float)q[j] * d;
645
703
  }
646
704
  }
647
705
  }
@@ -652,8 +710,8 @@ static inline void get_scale_min_k4(int j, const uint8_t *q, uint8_t *d, uint8_t
652
710
  *d = q[j] & 63;
653
711
  *m = q[j + 4] & 63;
654
712
  } else {
655
- *d = (q[j+4] & 0xF) | ((q[j-3] >> 6) << 4);
656
- *m = (q[j+4] >> 4) | ((q[j-1] >> 6) << 4);
713
+ *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
714
+ *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
657
715
  }
658
716
  }
659
717
 
@@ -663,19 +721,30 @@ static void dequantize_row_q2_K(const void *vx, float *y, int k) {
663
721
  for (int i = 0; i < nb; i++) {
664
722
  const uint8_t *block = x + i * 84;
665
723
  uint16_t d16, dmin16;
666
- memcpy(&d16, block, 2);
667
- memcpy(&dmin16, block + 2, 2);
724
+ memcpy(&d16, block + 80, 2);
725
+ memcpy(&dmin16, block + 82, 2);
668
726
  const float d = fp16_to_fp32(d16);
669
727
  const float min = fp16_to_fp32(dmin16);
670
- const uint8_t *scales = block + 4;
671
- const uint8_t *q = block + 20;
672
- for (int j = 0; j < QK_K; j += 64) {
673
- const float dl = d * (scales[j/64] & 0xF);
674
- const float ml = min * (scales[j/64] >> 4);
675
- for (int l = 0; l < 64; l++) {
676
- const int v = (q[(j+l)/4] >> (2*((j+l)%4))) & 0x03;
677
- y[i*QK_K + j + l] = v * dl + ml;
728
+ const uint8_t *scales = block;
729
+ const uint8_t *q = block + 16;
730
+ float *dst = y + (size_t)i * QK_K;
731
+ int is = 0;
732
+ for (int n = 0; n < QK_K; n += 128) {
733
+ int shift = 0;
734
+ for (int j = 0; j < 4; j++) {
735
+ uint8_t sc = scales[is++];
736
+ float dl = d * (sc & 0x0F);
737
+ float ml = min * (sc >> 4);
738
+ for (int l = 0; l < 16; l++) *dst++ = dl * ((q[l] >> shift) & 3) - ml;
739
+
740
+ sc = scales[is++];
741
+ dl = d * (sc & 0x0F);
742
+ ml = min * (sc >> 4);
743
+ for (int l = 0; l < 16; l++) *dst++ = dl * ((q[l + 16] >> shift) & 3) - ml;
744
+
745
+ shift += 2;
678
746
  }
747
+ q += 32;
679
748
  }
680
749
  }
681
750
  }
@@ -683,30 +752,45 @@ static void dequantize_row_q2_K(const void *vx, float *y, int k) {
683
752
  static void dequantize_row_q3_K(const void *vx, float *y, int k) {
684
753
  const int nb = k / QK_K;
685
754
  const uint8_t *x = vx;
755
+ const uint32_t kmask1 = 0x03030303;
756
+ const uint32_t kmask2 = 0x0f0f0f0f;
757
+ uint32_t aux[4];
758
+ const int8_t *scales = (const int8_t*)aux;
686
759
  for (int i = 0; i < nb; i++) {
687
760
  const uint8_t *block = x + i * 110;
688
761
  uint16_t d16;
689
- memcpy(&d16, block, 2);
690
- const float d = fp16_to_fp32(d16);
691
- const uint8_t *hmask = block + 2;
692
- const uint8_t *q = block + 34;
693
- const uint8_t *scales = block + 98;
694
- for (int j = 0; j < QK_K; j += 64) {
695
- const uint8_t ls1 = scales[j/64] & 0x1F;
696
- const uint8_t ls2 = (scales[j/64] >> 5) | ((scales[j/64 + 1] & 0x7) << 3);
697
- const uint8_t ls3 = ((scales[j/64 + 1] >> 3) & 0x1F);
698
- const uint8_t ls4 = (scales[j/64 + 1] >> 8);
699
- for (int l = 0; l < 64; l++) {
700
- int v = (q[(j+l)/2] >> (4*((j+l)%2))) & 0x0F;
701
- const int bit = (hmask[(j+l)/8] >> ((j+l)%8)) & 1;
702
- v |= bit << 4;
703
- float ls;
704
- if (l < 16) ls = ls1;
705
- else if (l < 32) ls = ls2;
706
- else if (l < 48) ls = ls3;
707
- else ls = ls4;
708
- y[i*QK_K + j + l] = (v - 32.0f) * d * ls;
762
+ memcpy(&d16, block + 108, 2);
763
+ const float d_all = fp16_to_fp32(d16);
764
+ const uint8_t *q = block + 32;
765
+ const uint8_t *hm = block;
766
+ uint8_t m = 1;
767
+ float *dst = y + (size_t)i * QK_K;
768
+
769
+ memcpy(aux, block + 96, 12);
770
+ uint32_t tmp = aux[2];
771
+ aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
772
+ aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
773
+ aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
774
+ aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
775
+
776
+ int is = 0;
777
+ for (int n = 0; n < QK_K; n += 128) {
778
+ int shift = 0;
779
+ for (int j = 0; j < 4; j++) {
780
+ float dl = d_all * (scales[is++] - 32);
781
+ for (int l = 0; l < 16; l++) {
782
+ *dst++ = dl * ((int)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
783
+ }
784
+
785
+ dl = d_all * (scales[is++] - 32);
786
+ for (int l = 0; l < 16; l++) {
787
+ *dst++ = dl * ((int)((q[l + 16] >> shift) & 3) - ((hm[l + 16] & m) ? 0 : 4));
788
+ }
789
+
790
+ shift += 2;
791
+ m <<= 1;
709
792
  }
793
+ q += 32;
710
794
  }
711
795
  }
712
796
  }
@@ -758,6 +842,7 @@ static void dequantize_row_q5_K(const void *vx, float *y, int k) {
758
842
  const uint8_t *qh = block + 16;
759
843
  const uint8_t *ql = block + 48;
760
844
  int is = 0;
845
+ uint8_t u1 = 1, u2 = 2;
761
846
  for (int j = 0; j < QK_K; j += 64) {
762
847
  uint8_t sc, m;
763
848
  get_scale_min_k4(is, scales, &sc, &m);
@@ -767,17 +852,17 @@ static void dequantize_row_q5_K(const void *vx, float *y, int k) {
767
852
  float d2 = d * sc;
768
853
  float m2 = min * m;
769
854
  for (int l = 0; l < 32; l++) {
770
- int vh = (qh[j/64 * 4 + l/8] >> (l%8)) & 1;
771
- int v = (ql[l] & 0xF) | (vh << 4);
855
+ int v = (ql[l] & 0xF) + ((qh[l] & u1) ? 16 : 0);
772
856
  y[i*QK_K + j + l] = d1 * v - m1;
773
857
  }
774
858
  for (int l = 0; l < 32; l++) {
775
- int vh = (qh[j/64 * 4 + 4 + l/8] >> (l%8)) & 1;
776
- int v = (ql[l] >> 4) | (vh << 4);
859
+ int v = (ql[l] >> 4) + ((qh[l] & u2) ? 16 : 0);
777
860
  y[i*QK_K + j + 32 + l] = d2 * v - m2;
778
861
  }
779
862
  ql += 32;
780
863
  is += 2;
864
+ u1 <<= 2;
865
+ u2 <<= 2;
781
866
  }
782
867
  }
783
868
  }
@@ -793,23 +878,23 @@ static void dequantize_row_q6_K(const void *vx, float *y, int k) {
793
878
  uint16_t d16;
794
879
  memcpy(&d16, block + 208, 2);
795
880
  const float d = fp16_to_fp32(d16);
881
+ float *dst = y + (size_t)i * QK_K;
796
882
  for (int j = 0; j < QK_K; j += 128) {
797
883
  for (int l = 0; l < 32; l++) {
798
- int v = (ql[j/2 + l] & 0xF) | (((qh[j/4 + l/2] >> ((l%2)*4)) & 0xF) << 4);
799
- y[i*QK_K + j + l] = v * d * scales[j/128 * 8 + l/4];
800
- }
801
- for (int l = 0; l < 32; l++) {
802
- int v = (ql[j/2 + 32 + l] >> 4) | (((qh[j/4 + 16 + l/2] >> ((l%2)*4)) & 0xF) << 4);
803
- y[i*QK_K + j + 32 + l] = v * d * scales[j/128 * 8 + 8 + l/4];
804
- }
805
- for (int l = 0; l < 32; l++) {
806
- int v = (ql[j/2 + 64 + l] & 0xF) | (((qh[j/4 + 32 + l/2] >> ((l%2)*4)) & 0xF) << 4);
807
- y[i*QK_K + j + 64 + l] = v * d * scales[j/128 * 8 + 4 + l/4];
808
- }
809
- for (int l = 0; l < 32; l++) {
810
- int v = (ql[j/2 + 96 + l] >> 4) | (((qh[j/4 + 48 + l/2] >> ((l%2)*4)) & 0xF) << 4);
811
- y[i*QK_K + j + 96 + l] = v * d * scales[j/128 * 8 + 12 + l/4];
884
+ int is = l / 16;
885
+ int q1 = ((ql[l] & 0x0F) | (((qh[l] >> 0) & 3) << 4)) - 32;
886
+ int q2 = ((ql[l + 32] & 0x0F) | (((qh[l] >> 2) & 3) << 4)) - 32;
887
+ int q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
888
+ int q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
889
+ dst[l] = d * scales[is + 0] * q1;
890
+ dst[l + 32] = d * scales[is + 2] * q2;
891
+ dst[l + 64] = d * scales[is + 4] * q3;
892
+ dst[l + 96] = d * scales[is + 6] * q4;
812
893
  }
894
+ dst += 128;
895
+ ql += 64;
896
+ qh += 32;
897
+ scales += 8;
813
898
  }
814
899
  }
815
900
  }
@@ -852,7 +937,7 @@ static void dequantize_row_lazy(const EmbedModel *m, int row, float *out) {
852
937
  case GGML_TYPE_Q5_0: rb = (nc / 32) * 22; break;
853
938
  case GGML_TYPE_Q5_1: rb = (nc / 32) * 24; break;
854
939
  case GGML_TYPE_Q8_0: rb = (nc / 32) * 34; break;
855
- case GGML_TYPE_Q8_1: rb = (nc / 32) * 40; break;
940
+ case GGML_TYPE_Q8_1: rb = (nc / 32) * 36; break;
856
941
  case GGML_TYPE_Q2_K: rb = (nc / 256) * 84; break;
857
942
  case GGML_TYPE_Q3_K: rb = (nc / 256) * 110; break;
858
943
  case GGML_TYPE_Q4_K: rb = (nc / 256) * 144; break;
@@ -923,6 +1008,30 @@ static void dequantize_row_lazy(const EmbedModel *m, int row, float *out) {
923
1008
  }
924
1009
  }
925
1010
 
1011
+ static int tensor_type_block_size(int type) {
1012
+ switch (type) {
1013
+ case GGML_TYPE_F32:
1014
+ case GGML_TYPE_F16:
1015
+ return 1;
1016
+ case GGML_TYPE_Q4_0:
1017
+ case GGML_TYPE_Q4_1:
1018
+ case GGML_TYPE_Q5_0:
1019
+ case GGML_TYPE_Q5_1:
1020
+ case GGML_TYPE_Q8_0:
1021
+ case GGML_TYPE_Q8_1:
1022
+ return QK8_0;
1023
+ case GGML_TYPE_Q2_K:
1024
+ case GGML_TYPE_Q3_K:
1025
+ case GGML_TYPE_Q4_K:
1026
+ case GGML_TYPE_Q5_K:
1027
+ case GGML_TYPE_Q6_K:
1028
+ case GGML_TYPE_Q8_K:
1029
+ return QK_K;
1030
+ default:
1031
+ return 0;
1032
+ }
1033
+ }
1034
+
926
1035
  static size_t get_row_bytes(int type, int n_cols) {
927
1036
  switch (type) {
928
1037
  case GGML_TYPE_F32: return n_cols * sizeof(float);
@@ -932,7 +1041,7 @@ static size_t get_row_bytes(int type, int n_cols) {
932
1041
  case GGML_TYPE_Q5_0: return (n_cols / 32) * 22;
933
1042
  case GGML_TYPE_Q5_1: return (n_cols / 32) * 24;
934
1043
  case GGML_TYPE_Q8_0: return (n_cols / 32) * 34;
935
- case GGML_TYPE_Q8_1: return (n_cols / 32) * 40;
1044
+ case GGML_TYPE_Q8_1: return (n_cols / 32) * 36;
936
1045
  case GGML_TYPE_Q2_K: return (n_cols / 256) * 84;
937
1046
  case GGML_TYPE_Q3_K: return (n_cols / 256) * 110;
938
1047
  case GGML_TYPE_Q4_K: return (n_cols / 256) * 144;
@@ -949,6 +1058,7 @@ static int skip_value(uint8_t **p, uint8_t *end, uint32_t type) {
949
1058
  case 0: case 1: case 7: return safe_advance(p, end, 1);
950
1059
  case 2: case 3: return safe_advance(p, end, 2);
951
1060
  case 4: case 5: case 6: return safe_advance(p, end, 4);
1061
+ case 10: case 11: case 12: return safe_advance(p, end, 8);
952
1062
  case 8: {
953
1063
  uint64_t len = rd64(p, end);
954
1064
  return safe_advance(p, end, len);
@@ -982,6 +1092,10 @@ static void free_model_contents(EmbedModel *m) {
982
1092
  }
983
1093
  free(m->table);
984
1094
  }
1095
+ if (m->tensors) {
1096
+ for (int i = 0; i < m->n_tensors; i++) free(m->tensors[i].name);
1097
+ free(m->tensors);
1098
+ }
985
1099
  if (m->mapped) munmap(m->mapped, m->mapped_size);
986
1100
  bpe_merge_table_free(&m->merges);
987
1101
  free(m);
@@ -1047,6 +1161,21 @@ static void parse_merge(const char *merge_str, char **left, char **right) {
1047
1161
  }
1048
1162
  }
1049
1163
 
1164
+ static Tensor *find_tensor(EmbedModel *m, const char *name) {
1165
+ if (!m || !m->tensors) return NULL;
1166
+ for (int i = 0; i < m->n_tensors; i++) {
1167
+ if (strcmp(m->tensors[i].name, name) == 0) return &m->tensors[i];
1168
+ }
1169
+ return NULL;
1170
+ }
1171
+
1172
+ static float rd_float32(uint8_t **p, uint8_t *end) {
1173
+ uint32_t bits = rd32(p, end);
1174
+ float v;
1175
+ memcpy(&v, &bits, sizeof(v));
1176
+ return v;
1177
+ }
1178
+
1050
1179
  /* ------------------------------------------------------------------------- */
1051
1180
  static EmbedModel *embed_load_gguf(const char *path) {
1052
1181
  size_t sz;
@@ -1072,8 +1201,12 @@ static EmbedModel *embed_load_gguf(const char *path) {
1072
1201
  m->unknown_token_id = -1;
1073
1202
  m->bos_token_id = -1;
1074
1203
  m->eos_token_id = -1;
1204
+ m->sep_token_id = -1;
1205
+ m->pad_token_id = 0;
1206
+ m->cls_token_id = -1;
1075
1207
  m->vocab_type = LLAMA_VOCAB_TYPE_NONE;
1076
1208
  m->normalize = NORM_NONE;
1209
+ m->eps = 1e-12f;
1077
1210
 
1078
1211
  int vocab_found = 0;
1079
1212
  for (uint64_t i = 0; i < n_kv; i++) {
@@ -1136,12 +1269,31 @@ static EmbedModel *embed_load_gguf(const char *path) {
1136
1269
  } else if (strcmp(key, "tokenizer.ggml.pre") == 0 && type == 8) {
1137
1270
  char *pre = rdstr(&cur, end);
1138
1271
  free(pre);
1139
- } else if (strcmp(key, "tokenizer.ggml.unknown_token_id") == 0 && type == 6) {
1272
+ } else if (strcmp(key, "bert.block_count") == 0 && type == 4) {
1273
+ m->n_layers = (int)rd32(&cur, end);
1274
+ } else if (strcmp(key, "bert.context_length") == 0 && type == 4) {
1275
+ m->n_ctx = (int)rd32(&cur, end);
1276
+ } else if (strcmp(key, "bert.embedding_length") == 0 && type == 4) {
1277
+ m->dim = (int)rd32(&cur, end);
1278
+ } else if (strcmp(key, "bert.feed_forward_length") == 0 && type == 4) {
1279
+ m->n_ff = (int)rd32(&cur, end);
1280
+ } else if (strcmp(key, "bert.attention.head_count") == 0 && type == 4) {
1281
+ m->n_heads = (int)rd32(&cur, end);
1282
+ } else if (strcmp(key, "bert.attention.layer_norm_epsilon") == 0 && type == 6) {
1283
+ m->eps = rd_float32(&cur, end);
1284
+ } else if (strcmp(key, "tokenizer.ggml.unknown_token_id") == 0 && type == 4) {
1140
1285
  m->unknown_token_id = (int)rd32(&cur, end);
1141
- } else if (strcmp(key, "tokenizer.ggml.bos_token_id") == 0 && type == 6) {
1286
+ } else if (strcmp(key, "tokenizer.ggml.bos_token_id") == 0 && type == 4) {
1142
1287
  m->bos_token_id = (int)rd32(&cur, end);
1143
- } else if (strcmp(key, "tokenizer.ggml.eos_token_id") == 0 && type == 6) {
1288
+ } else if (strcmp(key, "tokenizer.ggml.eos_token_id") == 0 && type == 4) {
1144
1289
  m->eos_token_id = (int)rd32(&cur, end);
1290
+ m->sep_token_id = m->eos_token_id;
1291
+ } else if (strcmp(key, "tokenizer.ggml.seperator_token_id") == 0 && type == 4) {
1292
+ m->sep_token_id = (int)rd32(&cur, end);
1293
+ } else if (strcmp(key, "tokenizer.ggml.padding_token_id") == 0 && type == 4) {
1294
+ m->pad_token_id = (int)rd32(&cur, end);
1295
+ } else if (strcmp(key, "tokenizer.ggml.cls_token_id") == 0 && type == 4) {
1296
+ m->cls_token_id = (int)rd32(&cur, end);
1145
1297
  } else if (strcmp(key, "general.alignment") == 0 && type == 6) {
1146
1298
  rd32(&cur, end);
1147
1299
  } else {
@@ -1153,107 +1305,439 @@ static EmbedModel *embed_load_gguf(const char *path) {
1153
1305
  if (!vocab_found) { free_model_contents(m); return NULL; }
1154
1306
  detect_space_marker(m);
1155
1307
 
1156
- uint8_t *after_kv = cur;
1308
+ m->tensors = calloc((size_t)n_tensors, sizeof(Tensor));
1309
+ if (!m->tensors) { free_model_contents(m); return NULL; }
1310
+ m->n_tensors = (int)n_tensors;
1311
+
1312
+ for (uint64_t i = 0; i < n_tensors; i++) {
1313
+ Tensor *t = &m->tensors[i];
1314
+ t->name = rdstr(&cur, end);
1315
+ if (!t->name) { free_model_contents(m); return NULL; }
1316
+ t->n_dims = rd32(&cur, end);
1317
+ if (t->n_dims == 0 || t->n_dims > MAX_DIMS) { free_model_contents(m); return NULL; }
1318
+ for (uint32_t d = 0; d < t->n_dims; d++) t->dims[d] = rd64(&cur, end);
1319
+ t->type = (int)rd32(&cur, end);
1320
+ uint64_t offset = rd64(&cur, end);
1321
+ int block_size = tensor_type_block_size(t->type);
1322
+ if (block_size == 0 || t->dims[0] % (uint64_t)block_size != 0) {
1323
+ free_model_contents(m);
1324
+ return NULL;
1325
+ }
1326
+ t->row_bytes = get_row_bytes(t->type, (int)t->dims[0]);
1327
+ if (t->row_bytes == 0) { free_model_contents(m); return NULL; }
1328
+ t->data = (const uint8_t*)(uintptr_t)offset;
1329
+ }
1330
+
1157
1331
  align_to_32(&cur, end, base);
1158
- uint8_t *tensor_start = cur;
1159
- int embd_found = 0;
1160
-
1161
- for (int attempt = 0; attempt < 2; attempt++) {
1162
- cur = tensor_start;
1163
- for (uint64_t i = 0; i < n_tensors; i++) {
1164
- char *name = rdstr(&cur, end);
1165
- if (!name) break;
1166
- uint32_t n_dims = rd32(&cur, end);
1167
- uint64_t dims[MAX_DIMS] = {0};
1168
- for (uint32_t d = 0; d < n_dims && d < MAX_DIMS; d++) dims[d] = rd64(&cur, end);
1169
- uint32_t type = rd32(&cur, end);
1170
- uint64_t offset = rd64(&cur, end);
1171
-
1172
- int is_token_embd = (strcmp(name, "token_embd.weight") == 0 ||
1173
- strcmp(name, "embeddings.word_embeddings.weight") == 0 ||
1174
- strcmp(name, "model.embed_tokens.weight") == 0);
1175
-
1176
- if (!is_token_embd && n_dims == 2 && m->vocab_size > 0) {
1177
- if ((uint64_t)m->vocab_size == dims[0] && strstr(name, "embd")) is_token_embd = 1;
1178
- else if ((uint64_t)m->vocab_size == dims[1] && strstr(name, "embd")) is_token_embd = 1;
1332
+ uint8_t *data_start = cur;
1333
+ for (int i = 0; i < m->n_tensors; i++) {
1334
+ Tensor *t = &m->tensors[i];
1335
+ uint64_t offset = (uint64_t)(uintptr_t)t->data;
1336
+ size_t rows = t->n_dims > 1 ? (size_t)t->dims[1] : 1;
1337
+ size_t total_size = rows * t->row_bytes;
1338
+ if (offset > (uint64_t)sz || data_start + offset < data_start ||
1339
+ data_start + offset + total_size > end) {
1340
+ free_model_contents(m);
1341
+ return NULL;
1342
+ }
1343
+ t->data = data_start + offset;
1344
+ }
1345
+
1346
+ Tensor *embd = find_tensor(m, "token_embd.weight");
1347
+ if (!embd) embd = find_tensor(m, "embeddings.word_embeddings.weight");
1348
+ if (!embd || embd->n_dims < 2 || embd->dims[1] != (uint64_t)m->vocab_size) {
1349
+ free_model_contents(m);
1350
+ return NULL;
1351
+ }
1352
+
1353
+ if (m->dim == 0) m->dim = (int)embd->dims[0];
1354
+ if (m->n_ctx == 0) m->n_ctx = 512;
1355
+ if (m->n_ff == 0) m->n_ff = m->dim * 4;
1356
+ if (m->n_heads == 0) m->n_heads = 12;
1357
+ if (m->n_layers == 0) m->n_layers = 12;
1358
+ if (m->cls_token_id < 0) m->cls_token_id = m->bos_token_id;
1359
+ if (m->sep_token_id < 0) m->sep_token_id = m->eos_token_id;
1360
+
1361
+ m->raw_tensor_data = embd->data;
1362
+ m->tensor_type = embd->type;
1363
+ m->row_bytes = embd->row_bytes;
1364
+ m->raw_dim0 = embd->dims[0];
1365
+ m->raw_dim1 = embd->dims[1];
1366
+ m->need_transpose = 0;
1367
+
1368
+ if (m->dim <= 0 || m->dim > MAX_DIM) {
1369
+ free_model_contents(m); return NULL;
1370
+ }
1371
+
1372
+ return m;
1373
+ }
1374
+
1375
+ /* ------------------------------------------------------------------------- */
1376
+ // L2 normalization
1377
+ static void normalize_l2(float *vec, int dim) {
1378
+ double sum = 0.0;
1379
+ for (int i = 0; i < dim; i++) sum += vec[i] * vec[i];
1380
+ double norm = sqrt(sum);
1381
+ if (norm > 0.0) {
1382
+ float inv = (float)(1.0 / norm);
1383
+ for (int i = 0; i < dim; i++) vec[i] *= inv;
1384
+ }
1385
+ }
1386
+
1387
+ static void tensor_get_row(const Tensor *t, int row, float *out) {
1388
+ if (!t || row < 0 || (t->n_dims > 1 && row >= (int)t->dims[1])) {
1389
+ return;
1390
+ }
1391
+
1392
+ const uint8_t *raw = t->data + (size_t)row * t->row_bytes;
1393
+ int cols = (int)t->dims[0];
1394
+ switch (t->type) {
1395
+ case GGML_TYPE_F32:
1396
+ memcpy(out, raw, (size_t)cols * sizeof(float));
1397
+ break;
1398
+ case GGML_TYPE_F16:
1399
+ for (int i = 0; i < cols; i++) {
1400
+ uint16_t h;
1401
+ memcpy(&h, raw + (size_t)i * sizeof(uint16_t), sizeof(uint16_t));
1402
+ out[i] = fp16_to_fp32(h);
1179
1403
  }
1404
+ break;
1405
+ case GGML_TYPE_Q4_0:
1406
+ dequantize_row_q4_0(raw, out, cols);
1407
+ break;
1408
+ case GGML_TYPE_Q4_1:
1409
+ dequantize_row_q4_1(raw, out, cols);
1410
+ break;
1411
+ case GGML_TYPE_Q5_0:
1412
+ dequantize_row_q5_0(raw, out, cols);
1413
+ break;
1414
+ case GGML_TYPE_Q5_1:
1415
+ dequantize_row_q5_1(raw, out, cols);
1416
+ break;
1417
+ case GGML_TYPE_Q8_0:
1418
+ dequantize_row_q8_0(raw, out, cols);
1419
+ break;
1420
+ case GGML_TYPE_Q8_1:
1421
+ dequantize_row_q8_1(raw, out, cols);
1422
+ break;
1423
+ case GGML_TYPE_Q2_K:
1424
+ dequantize_row_q2_K(raw, out, cols);
1425
+ break;
1426
+ case GGML_TYPE_Q3_K:
1427
+ dequantize_row_q3_K(raw, out, cols);
1428
+ break;
1429
+ case GGML_TYPE_Q4_K:
1430
+ dequantize_row_q4_K(raw, out, cols);
1431
+ break;
1432
+ case GGML_TYPE_Q5_K:
1433
+ dequantize_row_q5_K(raw, out, cols);
1434
+ break;
1435
+ case GGML_TYPE_Q6_K:
1436
+ dequantize_row_q6_K(raw, out, cols);
1437
+ break;
1438
+ case GGML_TYPE_Q8_K:
1439
+ dequantize_row_q8_K(raw, out, cols);
1440
+ break;
1441
+ default:
1442
+ memset(out, 0, (size_t)cols * sizeof(float));
1443
+ break;
1444
+ }
1445
+ }
1180
1446
 
1181
- if (!embd_found && is_token_embd) {
1182
- if (n_dims < 2 || dims[1] == 0) {
1183
- free(name); free_model_contents(m); return NULL;
1184
- }
1185
-
1186
- uint64_t ne0 = dims[0];
1187
- uint64_t ne1 = dims[1];
1188
-
1189
- int need_transpose = 0;
1190
- int dim;
1191
-
1192
- if (ne1 == (uint64_t)m->vocab_size) {
1193
- dim = (int)ne0;
1194
- need_transpose = 0;
1195
- } else if (ne0 == (uint64_t)m->vocab_size) {
1196
- dim = (int)ne1;
1197
- need_transpose = 1;
1198
- } else {
1199
- dim = (ne0 < ne1) ? (int)ne0 : (int)ne1;
1200
- need_transpose = (ne0 > ne1) ? 1 : 0;
1201
- }
1447
+ static const float *tensor_f32_data(const Tensor *t) {
1448
+ if (!t || t->type != GGML_TYPE_F32) return NULL;
1449
+ return (const float*)t->data;
1450
+ }
1202
1451
 
1203
- if (dim <= 0 || dim > MAX_DIM) {
1204
- free(name); free_model_contents(m); return NULL;
1205
- }
1452
+ static float dot_q4_0_q8_0_like_ggml(const uint8_t *raw, const float *x, int n) {
1453
+ int nb = n / QK8_0;
1454
+ float sumf = 0.0f;
1206
1455
 
1207
- size_t row_bytes = get_row_bytes(type, (int)(need_transpose ? ne1 : ne0));
1208
- size_t total_size = (size_t)(need_transpose ? ne1 : ne0) * row_bytes;
1209
-
1210
- if (offset >= sz || offset + total_size > sz) {
1211
- free(name);
1212
- free_model_contents(m);
1213
- return NULL;
1214
- }
1456
+ for (int ib = 0; ib < nb; ib++) {
1457
+ const uint8_t *block = raw + (size_t)ib * 18;
1458
+ uint16_t d16;
1459
+ memcpy(&d16, block, 2);
1460
+ const float dx = fp16_to_fp32(d16);
1461
+ const uint8_t *q = block + 2;
1215
1462
 
1216
- m->dim = dim;
1217
- m->raw_dim0 = ne0;
1218
- m->raw_dim1 = ne1;
1219
- m->need_transpose = need_transpose;
1220
- m->raw_tensor_data = base + offset;
1221
- m->tensor_type = type;
1222
- m->row_bytes = row_bytes;
1223
- embd_found = 1;
1224
- free(name);
1463
+ const float *xb = x + (size_t)ib * QK8_0;
1464
+ float amax = 0.0f;
1465
+ for (int j = 0; j < QK8_0; j++) {
1466
+ float av = fabsf(xb[j]);
1467
+ if (av > amax) amax = av;
1468
+ }
1469
+
1470
+ const float d = amax / 127.0f;
1471
+ const float id = d ? 1.0f / d : 0.0f;
1472
+ const float dy = fp16_to_fp32(fp32_to_fp16(d));
1473
+ int8_t qy[QK8_0];
1474
+ for (int j = 0; j < QK8_0; j++) qy[j] = (int8_t)roundf(xb[j] * id);
1475
+
1476
+ int sumi0 = 0;
1477
+ int sumi1 = 0;
1478
+ for (int j = 0; j < QK8_0/2; j++) {
1479
+ const int v0 = (q[j] & 0x0F) - 8;
1480
+ const int v1 = (q[j] >> 4) - 8;
1481
+ sumi0 += v0 * qy[j];
1482
+ sumi1 += v1 * qy[j + QK8_0/2];
1483
+ }
1484
+ sumf += (float)(sumi0 + sumi1) * dx * dy;
1485
+ }
1486
+
1487
+ return sumf;
1488
+ }
1489
+
1490
+ static int ascii_wordpiece_tokenize(EmbedModel *m, const char *txt, int *ids, int max_ids) {
1491
+ int n = 0;
1492
+ if (m->cls_token_id >= 0 && n < max_ids) ids[n++] = m->cls_token_id;
1493
+
1494
+ size_t len = strlen(txt);
1495
+ size_t i = 0;
1496
+ while (i < len && n < max_ids - 1) {
1497
+ while (i < len && isspace((unsigned char)txt[i])) i++;
1498
+ if (i >= len) break;
1499
+
1500
+ char word[256];
1501
+ int wl = 0;
1502
+ if (isalnum((unsigned char)txt[i])) {
1503
+ while (i < len && (isalnum((unsigned char)txt[i]) || txt[i] == '_') && wl < (int)sizeof(word) - 1) {
1504
+ word[wl++] = (char)tolower((unsigned char)txt[i++]);
1505
+ }
1506
+ while (i < len && (isalnum((unsigned char)txt[i]) || txt[i] == '_')) i++;
1507
+ } else {
1508
+ word[wl++] = txt[i++];
1509
+ }
1510
+ word[wl] = '\0';
1511
+ if (wl == 0) continue;
1512
+
1513
+ char word1[260];
1514
+ const char marker[] = "\xE2\x96\x81";
1515
+ memcpy(word1, marker, 3);
1516
+ memcpy(word1 + 3, word, (size_t)wl + 1);
1517
+ int w1l = wl + 3;
1518
+
1519
+ int current_tokens = n;
1520
+ for (int start = 0; start < w1l && n < max_ids - 1; start++) {
1521
+ int matched = 0;
1522
+ for (int end_pos = w1l; end_pos > start; end_pos--) {
1523
+ char piece[260];
1524
+ int plen = end_pos - start;
1525
+ memcpy(piece, word1 + start, plen);
1526
+ piece[plen] = '\0';
1527
+ int piece_id = hget(m, piece);
1528
+ if (piece_id >= 0) {
1529
+ ids[n++] = piece_id;
1530
+ start = end_pos - 1;
1531
+ matched = 1;
1532
+ break;
1533
+ }
1534
+ }
1535
+ if (!matched) {
1536
+ n = current_tokens;
1225
1537
  break;
1226
1538
  }
1227
- free(name);
1228
1539
  }
1229
- if (embd_found) break;
1230
- if (attempt == 0) {
1231
- tensor_start = find_tensor_info_start(after_kv, end);
1232
- if (!tensor_start) break;
1540
+
1541
+ if (n == current_tokens && m->unknown_token_id >= 0 && n < max_ids - 1) ids[n++] = m->unknown_token_id;
1542
+ }
1543
+
1544
+ if (m->sep_token_id >= 0 && n < max_ids) ids[n++] = m->sep_token_id;
1545
+ return n;
1546
+ }
1547
+
1548
+ static void linear_one(const Tensor *w, const Tensor *b, const float *x, float *out, float *row) {
1549
+ int in = (int)w->dims[0];
1550
+ int out_dim = (int)w->dims[1];
1551
+ const float *bias = tensor_f32_data(b);
1552
+ for (int o = 0; o < out_dim; o++) {
1553
+ float sum = bias ? bias[o] : 0.0f;
1554
+ if (w->type == GGML_TYPE_Q4_0) {
1555
+ const uint8_t *raw = w->data + (size_t)o * w->row_bytes;
1556
+ sum += dot_q4_0_q8_0_like_ggml(raw, x, in);
1557
+ } else {
1558
+ tensor_get_row(w, o, row);
1559
+ for (int i = 0; i < in; i++) sum += row[i] * x[i];
1233
1560
  }
1561
+ out[o] = sum;
1234
1562
  }
1563
+ }
1235
1564
 
1236
- if (!embd_found || m->dim == 0) {
1237
- free_model_contents(m); return NULL;
1565
+ static void linear_batch(const Tensor *w, const Tensor *b, const float *x, int seq, float *out, float *row) {
1566
+ int in = (int)w->dims[0];
1567
+ int out_dim = (int)w->dims[1];
1568
+ for (int t = 0; t < seq; t++) {
1569
+ linear_one(w, b, x + (size_t)t * in, out + (size_t)t * out_dim, row);
1238
1570
  }
1571
+ }
1239
1572
 
1240
- return m;
1573
+ static void layer_norm(const float *x, const Tensor *w, const Tensor *b, int seq, int dim, float eps, float *out) {
1574
+ const float *weight = tensor_f32_data(w);
1575
+ const float *bias = tensor_f32_data(b);
1576
+ for (int t = 0; t < seq; t++) {
1577
+ const float *src = x + (size_t)t * dim;
1578
+ float *dst = out + (size_t)t * dim;
1579
+ float mean = 0.0f;
1580
+ for (int i = 0; i < dim; i++) mean += src[i];
1581
+ mean /= (float)dim;
1582
+ float var = 0.0f;
1583
+ for (int i = 0; i < dim; i++) {
1584
+ float d = src[i] - mean;
1585
+ var += d * d;
1586
+ }
1587
+ var /= (float)dim;
1588
+ float scale = 1.0f / sqrtf(var + eps);
1589
+ for (int i = 0; i < dim; i++) {
1590
+ dst[i] = (src[i] - mean) * scale * (weight ? weight[i] : 1.0f) + (bias ? bias[i] : 0.0f);
1591
+ }
1592
+ }
1241
1593
  }
1242
1594
 
1243
- /* ------------------------------------------------------------------------- */
1244
- // L2 normalization
1245
- static void normalize_l2(float *vec, int dim) {
1246
- float sum = 0;
1247
- for (int i = 0; i < dim; i++) sum += vec[i] * vec[i];
1248
- float norm = sqrtf(sum);
1249
- if (norm > 1e-8f) {
1250
- float inv = 1.0f / norm;
1251
- for (int i = 0; i < dim; i++) vec[i] *= inv;
1595
/*
 * Tanh-based GELU approximation with the input and result rounded through
 * fp16 — presumably to reproduce ggml's fp16 GELU table bit-for-bit (TODO:
 * confirm against the reference implementation).
 */
static float gelu_approx(float x) {
    /* Saturate far outside the active region. */
    if (x <= -10.0f) return 0.0f;
    if (x >= 10.0f) return x;
    const float sqrt_2_over_pi = 0.7978845608028654f;
    float xr = fp16_to_fp32(fp32_to_fp16(x)); /* fp16-round the input */
    float y = 0.5f * xr * (1.0f + tanhf(sqrt_2_over_pi * xr * (1.0f + 0.044715f * xr * xr)));
    return fp16_to_fp32(fp32_to_fp16(y));     /* fp16-round the result */
}
1603
+
1604
/*
 * Full BERT-style encoder forward pass producing one embedding vector.
 *
 * Returns 0 when the model is not a WordPiece/BERT-style model (wrong vocab
 * type or no "blk.0.attn_q.weight" tensor) so the caller can fall back to
 * another embedding path. Returns 1 otherwise — including allocation or
 * missing-tensor failures, in which case `out` is left as the zero vector.
 *
 * Pipeline: tokenize -> token+position+type embeddings -> embedding LayerNorm
 * -> n_layers x (multi-head self-attention, FFN; each with residual add and
 * post-LayerNorm) -> mean pooling over tokens -> L2 normalization.
 */
static int bert_embed_text(EmbedModel *m, const char *txt, float *out) {
    if (m->vocab_type != LLAMA_VOCAB_TYPE_WPM || !find_tensor(m, "blk.0.attn_q.weight")) return 0;

    memset(out, 0, (size_t)m->dim * sizeof(float));
    if (!txt || !*txt) return 1; /* empty input -> zero vector */

    /* Tokenize, capping the sequence at the model context length. */
    int max_seq = m->n_ctx > 0 ? m->n_ctx : 512;
    int *ids = malloc((size_t)max_seq * sizeof(int));
    if (!ids) return 1;
    int seq = ascii_wordpiece_tokenize(m, txt, ids, max_seq);
    if (seq <= 0) { free(ids); return 1; }

    int dim = m->dim;
    int ff = m->n_ff;
    int heads = m->n_heads;
    int head_dim = dim / heads; /* assumes heads divides dim — not validated here */
    /* Whole-sequence scratch buffers. `row` doubles as a dequantized
       weight-row buffer, so it must fit the widest matrix (max of ff, dim). */
    float *hidden = calloc((size_t)seq * dim, sizeof(float));
    float *tmp = calloc((size_t)seq * dim, sizeof(float));
    float *q = calloc((size_t)seq * dim, sizeof(float));
    float *k = calloc((size_t)seq * dim, sizeof(float));
    float *v = calloc((size_t)seq * dim, sizeof(float));
    float *ctx = calloc((size_t)seq * dim, sizeof(float));
    float *proj = calloc((size_t)seq * dim, sizeof(float));
    float *ffn = calloc((size_t)seq * ff, sizeof(float));
    float *row = malloc((size_t)(ff > dim ? ff : dim) * sizeof(float));
    float *scores = malloc((size_t)seq * sizeof(float));
    if (!hidden || !tmp || !q || !k || !v || !ctx || !proj || !ffn || !row || !scores) {
        free(ids); free(hidden); free(tmp); free(q); free(k); free(v); free(ctx); free(proj); free(ffn); free(row); free(scores);
        return 1;
    }

    Tensor *tok_emb = find_tensor(m, "token_embd.weight");
    Tensor *pos_emb = find_tensor(m, "position_embd.weight");
    Tensor *typ_emb = find_tensor(m, "token_types.weight");
    Tensor *emb_norm_w = find_tensor(m, "token_embd_norm.weight");
    Tensor *emb_norm_b = find_tensor(m, "token_embd_norm.bias");

    float *tok = row; /* alias: reuse row scratch for the token embedding */
    float *pos = malloc((size_t)dim * sizeof(float));
    float *typ = malloc((size_t)dim * sizeof(float));
    if (!tok_emb || !pos_emb || !typ_emb || !pos || !typ) {
        free(ids); free(hidden); free(tmp); free(q); free(k); free(v); free(ctx); free(proj); free(ffn); free(row); free(scores); free(pos); free(typ);
        return 1;
    }

    /* Input embeddings: token + learned position + token-type (always row 0),
       then the embedding LayerNorm. */
    for (int t = 0; t < seq; t++) {
        tensor_get_row(tok_emb, ids[t], tok);
        tensor_get_row(pos_emb, t, pos);
        tensor_get_row(typ_emb, 0, typ);
        for (int d = 0; d < dim; d++) hidden[(size_t)t * dim + d] = tok[d] + pos[d] + typ[d];
    }
    layer_norm(hidden, emb_norm_w, emb_norm_b, seq, dim, m->eps, tmp);
    memcpy(hidden, tmp, (size_t)seq * dim * sizeof(float));

    for (int layer = 0; layer < m->n_layers; layer++) {
        char name[80];
        /* Look up a per-layer tensor by its GGUF name "blk.<layer>.<suffix>". */
        #define TENSOR(suffix) (snprintf(name, sizeof(name), "blk.%d.%s", layer, suffix), find_tensor(m, name))
        Tensor *qw = TENSOR("attn_q.weight");
        Tensor *qb = TENSOR("attn_q.bias");
        Tensor *kw = TENSOR("attn_k.weight");
        Tensor *kb = TENSOR("attn_k.bias");
        Tensor *vw = TENSOR("attn_v.weight");
        Tensor *vb = TENSOR("attn_v.bias");
        Tensor *ow = TENSOR("attn_output.weight");
        Tensor *ob = TENSOR("attn_output.bias");
        Tensor *an_w = TENSOR("attn_output_norm.weight");
        Tensor *an_b = TENSOR("attn_output_norm.bias");
        Tensor *fu_w = TENSOR("ffn_up.weight");
        Tensor *fu_b = TENSOR("ffn_up.bias");
        Tensor *fd_w = TENSOR("ffn_down.weight");
        Tensor *fd_b = TENSOR("ffn_down.bias");
        Tensor *ln_w = TENSOR("layer_output_norm.weight");
        Tensor *ln_b = TENSOR("layer_output_norm.bias");
        #undef TENSOR

        /* Any missing tensor stops the stack here; the layers processed so
           far still feed the pooled output below. */
        if (!qw || !qb || !kw || !kb || !vw || !vb || !ow || !ob || !an_w || !an_b ||
            !fu_w || !fu_b || !fd_w || !fd_b || !ln_w || !ln_b) break;

        /* Self-attention: project Q/K/V for the whole sequence. */
        linear_batch(qw, qb, hidden, seq, q, row);
        linear_batch(kw, kb, hidden, seq, k, row);
        linear_batch(vw, vb, hidden, seq, v, row);
        memset(ctx, 0, (size_t)seq * dim * sizeof(float));

        float att_scale = 1.0f / sqrtf((float)head_dim);
        for (int h = 0; h < heads; h++) {
            int off = h * head_dim;
            for (int ti = 0; ti < seq; ti++) {
                /* Scaled dot-product scores with the max-subtraction trick
                   for a numerically stable softmax. */
                float max_score = -INFINITY;
                for (int tj = 0; tj < seq; tj++) {
                    float dot = 0.0f;
                    const float *qv0 = q + (size_t)ti * dim + off;
                    const float *kv0 = k + (size_t)tj * dim + off;
                    for (int d = 0; d < head_dim; d++) dot += qv0[d] * kv0[d];
                    scores[tj] = dot * att_scale;
                    if (scores[tj] > max_score) max_score = scores[tj];
                }
                double sum = 0.0;
                for (int tj = 0; tj < seq; tj++) {
                    scores[tj] = expf(scores[tj] - max_score);
                    sum += scores[tj];
                }
                float inv_sum = (float)(1.0 / sum);
                /* Accumulate the probability-weighted value vectors. */
                float *dst = ctx + (size_t)ti * dim + off;
                for (int tj = 0; tj < seq; tj++) {
                    float p = scores[tj] * inv_sum;
                    const float *vv0 = v + (size_t)tj * dim + off;
                    for (int d = 0; d < head_dim; d++) dst[d] += p * vv0[d];
                }
            }
        }

        /* Attention output projection, residual add, post-attention norm. */
        linear_batch(ow, ob, ctx, seq, proj, row);
        for (int i = 0; i < seq * dim; i++) tmp[i] = hidden[i] + proj[i];
        layer_norm(tmp, an_w, an_b, seq, dim, m->eps, hidden);

        /* Feed-forward: up-projection, GELU, down-projection, residual, norm. */
        linear_batch(fu_w, fu_b, hidden, seq, ffn, row);
        for (int i = 0; i < seq * ff; i++) ffn[i] = gelu_approx(ffn[i]);
        linear_batch(fd_w, fd_b, ffn, seq, proj, row);
        for (int i = 0; i < seq * dim; i++) tmp[i] = hidden[i] + proj[i];
        layer_norm(tmp, ln_w, ln_b, seq, dim, m->eps, hidden);
    }

    /* Mean-pool the final hidden states over all tokens, then L2-normalize. */
    for (int t = 0; t < seq; t++) {
        for (int d = 0; d < dim; d++) out[d] += hidden[(size_t)t * dim + d];
    }
    float inv = 1.0f / (float)seq;
    for (int d = 0; d < dim; d++) out[d] *= inv;
    normalize_l2(out, dim);

    free(ids); free(hidden); free(tmp); free(q); free(k); free(v); free(ctx); free(proj); free(ffn); free(row); free(scores); free(pos); free(typ);
    return 1;
}
1254
1736
 
1255
1737
  /* ------------------------------------------------------------------------- */
1256
1738
  static void embed_text(EmbedModel *m, const char *txt, float *out) {
1739
+ if (bert_embed_text(m, txt, out)) return;
1740
+
1257
1741
  memset(out, 0, sizeof(float) * m->dim);
1258
1742
  if (!txt || !*txt) return;
1259
1743
 
@@ -1413,4 +1897,4 @@ void Init_mini_embed(void) {
1413
1897
  rb_define_alloc_func(c, rb_embedder_alloc);
1414
1898
  rb_define_method(c, "initialize", rb_embedder_initialize, 1);
1415
1899
  rb_define_method(c, "embed", rb_embed, 1);
1416
- }
1900
+ }
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: mini_embed
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.0
4
+ version: 0.4.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Makapoxa