mini_embed 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,13 +8,18 @@
8
8
  #include <fcntl.h>
9
9
  #include <unistd.h>
10
10
  #include <ctype.h>
11
+ #include <limits.h>
11
12
  #include "ruby.h"
12
13
 
13
- #define HASH_SIZE 131071
14
- #define MAX_DIMS 4
15
- #define GGUF_ALIGN 32
16
- #define MAX_MERGES 10000
17
- #define MAX_REGEX 256
14
+ #define HASH_SIZE 131071
15
+ #define MAX_DIMS 4
16
+ #define GGUF_ALIGN 32
17
+ #define MAX_MERGES 100000
18
+ #define MERGE_HASH_SIZE 65537
19
+ #define QK8_0 32
20
+ #define QK_K 256
21
+ #define K_SCALE_SIZE 12
22
+ #define MAX_DIM 16384
18
23
 
19
24
  enum ggml_type {
20
25
  GGML_TYPE_F32 = 0,
@@ -40,6 +45,11 @@ enum llama_vocab_type {
40
45
  LLAMA_VOCAB_TYPE_WPM = 3,
41
46
  };
42
47
 
48
+ enum normalize_type {
49
+ NORM_NONE = 0,
50
+ NORM_L2 = 1,
51
+ };
52
+
43
53
  /* ------------------------------------------------------------------------- */
44
54
  // Unicode helper functions
45
55
  static int unicode_len_utf8(char c) {
@@ -77,135 +87,233 @@ static uint32_t unicode_cpt_from_utf8(const char *s, size_t *len) {
77
87
  }
78
88
 
79
89
  /* ------------------------------------------------------------------------- */
80
- // Simple regex pattern matcher (simplified)
81
- typedef struct {
82
- char *pattern;
83
- int pattern_len;
84
- } RegexPattern;
85
-
86
- static int match_regex(const char *text, const RegexPattern *patterns, int num_patterns) {
87
- for (int i = 0; i < num_patterns; i++) {
88
- const char *p = patterns[i].pattern;
89
- if (strstr(p, "\\p{L}")) {
90
- size_t len;
91
- uint32_t cp = unicode_cpt_from_utf8(text, &len);
92
- if (unicode_is_letter(cp)) return 1;
93
- } else if (strstr(p, "\\p{N}")) {
94
- size_t len;
95
- uint32_t cp = unicode_cpt_from_utf8(text, &len);
96
- if (unicode_is_number(cp)) return 1;
97
- } else if (p[0] == '\\' && p[1] == 's') {
98
- if (isspace(text[0])) return 1;
99
- } else if (p[0] == '\\' && p[1] == 'r') {
100
- if (text[0] == '\r') return 1;
101
- } else if (p[0] == '\\' && p[1] == 'n') {
102
- if (text[0] == '\n') return 1;
103
- } else if (p[0] == '.' && p[1] == '*') {
104
- return 1;
105
- } else if (isalnum(p[0]) || ispunct(p[0])) {
106
- if (text[0] == p[0]) return 1;
107
- }
90
+ // Pre-tokenizer (GPT-2/Llama style, replaces broken regex)
91
+ #define CHAR_CLASS_SPACE 0
92
+ #define CHAR_CLASS_LETTER 1
93
+ #define CHAR_CLASS_NUMBER 2
94
+ #define CHAR_CLASS_NEWLINE 3
95
+ #define CHAR_CLASS_OTHER 4
96
+
97
+ static int get_char_class(uint32_t cp) {
98
+ if (unicode_is_letter(cp)) return CHAR_CLASS_LETTER;
99
+ if (unicode_is_number(cp)) return CHAR_CLASS_NUMBER;
100
+ if (cp == '\n' || cp == '\r') return CHAR_CLASS_NEWLINE;
101
+ if (cp == ' ' || cp == '\t') return CHAR_CLASS_SPACE;
102
+ return CHAR_CLASS_OTHER;
103
+ }
104
+
105
+ static int is_contraction(const char *text, size_t pos, size_t text_len) {
106
+ if (pos >= text_len) return 0;
107
+ unsigned char c = (unsigned char)text[pos];
108
+ if (c != '\'' && c != 0xE2) return 0;
109
+ if (c == 0xE2 && pos + 2 < text_len && text[pos+1] == 0x80 && (text[pos+2] == 0x99 || text[pos+2] == 0x98)) {
110
+ if (pos + 3 >= text_len) return 0;
111
+ char next = tolower((unsigned char)text[pos + 3]);
112
+ return next == 's' || next == 't' || next == 'r' || next == 'v' ||
113
+ next == 'm' || next == 'l' || next == 'd';
114
+ }
115
+ if (c == '\'' && pos + 1 < text_len) {
116
+ char next = tolower((unsigned char)text[pos + 1]);
117
+ return next == 's' || next == 't' || next == 'r' || next == 'v' ||
118
+ next == 'm' || next == 'l' || next == 'd';
108
119
  }
109
120
  return 0;
110
121
  }
111
122
 
112
- static char** unicode_regex_split(const char *text, const RegexPattern *patterns, int num_patterns, int *num_words) {
123
+ static size_t contraction_len(const char *text, size_t pos) {
124
+ unsigned char c = (unsigned char)text[pos];
125
+ if (c == '\'') return 2;
126
+ return 4;
127
+ }
128
+
129
+ static char** pre_tokenize(const char *text, int *num_words) {
113
130
  char **words = NULL;
114
131
  int word_count = 0, word_capacity = 0;
115
- size_t text_len = strlen(text), pos = 0;
116
-
117
- while (pos < text_len) {
118
- size_t start = pos;
119
- while (start < text_len && !match_regex(text + start, patterns, num_patterns)) start++;
120
- if (start >= text_len) break;
121
- size_t end = start;
122
- while (end < text_len && match_regex(text + end, patterns, num_patterns)) end++;
123
- if (end > start) {
124
- size_t word_len = end - start;
125
- char *word = malloc(word_len + 1);
126
- if (!word) { while (--word_count >= 0) free(words[word_count]); free(words); *num_words = 0; return NULL; }
127
- memcpy(word, text + start, word_len);
128
- word[word_len] = '\0';
129
- if (word_count >= word_capacity) {
130
- word_capacity = word_capacity ? word_capacity * 2 : 16;
131
- words = realloc(words, word_capacity * sizeof(char*));
132
- if (!words) { free(word); while (--word_count >= 0) free(words[word_count]); *num_words = 0; return NULL; }
132
+ size_t text_len = strlen(text);
133
+
134
+ if (text_len == 0) {
135
+ *num_words = 0;
136
+ return NULL;
137
+ }
138
+
139
+ #define ADD_WORD(ptr, len) do { \
140
+ char *w = malloc((len) + 1); \
141
+ if (!w) goto error; \
142
+ memcpy(w, ptr, len); \
143
+ w[len] = '\0'; \
144
+ if (word_count >= word_capacity) { \
145
+ word_capacity = word_capacity ? word_capacity * 2 : 16; \
146
+ char **nw = realloc(words, word_capacity * sizeof(char*)); \
147
+ if (!nw) { free(w); goto error; } \
148
+ words = nw; \
149
+ } \
150
+ words[word_count++] = w; \
151
+ } while(0)
152
+
153
+ size_t i = 0;
154
+ while (i < text_len) {
155
+ size_t char_len;
156
+ uint32_t cp = unicode_cpt_from_utf8(text + i, &char_len);
157
+ int cls = get_char_class(cp);
158
+
159
+ if (cls == CHAR_CLASS_NEWLINE) {
160
+ ADD_WORD(text + i, char_len);
161
+ i += char_len;
162
+ continue;
163
+ }
164
+
165
+ if (cls == CHAR_CLASS_SPACE) {
166
+ size_t space_start = i;
167
+ while (i < text_len) {
168
+ size_t cl;
169
+ uint32_t c = unicode_cpt_from_utf8(text + i, &cl);
170
+ int cc = get_char_class(c);
171
+ if (cc != CHAR_CLASS_SPACE) break;
172
+ i += cl;
133
173
  }
134
- words[word_count++] = word;
174
+ if (i >= text_len) break;
175
+ size_t space_len = i - space_start;
176
+ ADD_WORD(text + space_start, space_len);
177
+ continue;
135
178
  }
136
- pos = end;
179
+
180
+ size_t start = i;
181
+ i += char_len;
182
+
183
+ while (i < text_len) {
184
+ size_t cl;
185
+ uint32_t c = unicode_cpt_from_utf8(text + i, &cl);
186
+ int ccls = get_char_class(c);
187
+
188
+ if (is_contraction(text, i, text_len)) {
189
+ size_t clen = contraction_len(text, i);
190
+ i += clen;
191
+ continue;
192
+ }
193
+
194
+ if (ccls != cls) break;
195
+ if (cls == CHAR_CLASS_NUMBER) {
196
+ int digits = 0;
197
+ size_t check = start;
198
+ while (check < i) {
199
+ size_t dl;
200
+ uint32_t dc = unicode_cpt_from_utf8(text + check, &dl);
201
+ if (get_char_class(dc) == CHAR_CLASS_NUMBER) digits++;
202
+ check += dl;
203
+ }
204
+ if (digits >= 3) break;
205
+ }
206
+ i += cl;
207
+ }
208
+
209
+ ADD_WORD(text + start, i - start);
137
210
  }
211
+
212
+ #undef ADD_WORD
138
213
  *num_words = word_count;
139
214
  return words;
215
+
216
+ error:
217
+ for (int j = 0; j < word_count; j++) free(words[j]);
218
+ free(words);
219
+ *num_words = 0;
220
+ return NULL;
140
221
  }
141
222
 
142
223
  /* ------------------------------------------------------------------------- */
143
- // BPE merge structures
144
- typedef struct {
224
+ // BPE merge structures with hash table for O(1) lookup
225
+ typedef struct MergeHashNode {
145
226
  char *left;
146
227
  char *right;
147
- char *merged;
148
228
  int rank;
149
- } BPEMerge;
229
+ struct MergeHashNode *next;
230
+ } MergeHashNode;
150
231
 
151
232
  typedef struct {
152
- BPEMerge *merges;
233
+ MergeHashNode **table;
234
+ int table_size;
153
235
  int num_merges;
154
- int capacity;
155
236
  } BPEMergeTable;
156
237
 
238
+ static uint64_t merge_hash(const char *left, const char *right) {
239
+ uint64_t h = 0xcbf29ce484222325ULL;
240
+ while (*left) { h ^= (uint64_t)(unsigned char)*left++; h *= 0x100000001b3ULL; }
241
+ h ^= (uint64_t)' ';
242
+ h *= 0x100000001b3ULL;
243
+ while (*right) { h ^= (uint64_t)(unsigned char)*right++; h *= 0x100000001b3ULL; }
244
+ return h;
245
+ }
246
+
157
247
  static void bpe_merge_table_init(BPEMergeTable *table) {
158
- memset(table, 0, sizeof(*table));
248
+ table->table_size = MERGE_HASH_SIZE;
249
+ table->table = calloc(MERGE_HASH_SIZE, sizeof(MergeHashNode*));
250
+ table->num_merges = 0;
159
251
  }
160
252
 
161
- static void bpe_merge_table_add(BPEMergeTable *table, const char *left, const char *right, const char *merged, int rank) {
162
- if (table->num_merges >= table->capacity) {
163
- table->capacity = table->capacity ? table->capacity * 2 : 100;
164
- table->merges = realloc(table->merges, table->capacity * sizeof(BPEMerge));
165
- }
166
- BPEMerge *m = &table->merges[table->num_merges++];
167
- m->left = strdup(left);
168
- m->right = strdup(right);
169
- m->merged = strdup(merged);
170
- m->rank = rank;
253
+ static void bpe_merge_table_add(BPEMergeTable *table, const char *left, const char *right, int rank) {
254
+ uint64_t h = merge_hash(left, right) % table->table_size;
255
+ MergeHashNode *n = malloc(sizeof(MergeHashNode));
256
+ if (!n) return;
257
+ n->left = strdup(left);
258
+ n->right = strdup(right);
259
+ n->rank = rank;
260
+ n->next = table->table[h];
261
+ table->table[h] = n;
262
+ table->num_merges++;
171
263
  }
172
264
 
173
265
  static void bpe_merge_table_free(BPEMergeTable *table) {
174
- for (int i = 0; i < table->num_merges; i++) {
175
- free(table->merges[i].left);
176
- free(table->merges[i].right);
177
- free(table->merges[i].merged);
266
+ if (!table->table) return;
267
+ for (int i = 0; i < table->table_size; i++) {
268
+ MergeHashNode *n = table->table[i];
269
+ while (n) {
270
+ MergeHashNode *next = n->next;
271
+ free(n->left);
272
+ free(n->right);
273
+ free(n);
274
+ n = next;
275
+ }
178
276
  }
179
- free(table->merges);
180
- table->merges = NULL;
181
- table->num_merges = 0;
277
+ free(table->table);
278
+ table->table = NULL;
182
279
  }
183
280
 
184
281
  static int bpe_merge_rank(const BPEMergeTable *table, const char *left, const char *right) {
185
- for (int i = 0; i < table->num_merges; i++) {
186
- if (strcmp(table->merges[i].left, left) == 0 && strcmp(table->merges[i].right, right) == 0)
187
- return table->merges[i].rank;
282
+ uint64_t h = merge_hash(left, right) % table->table_size;
283
+ MergeHashNode *n = table->table[h];
284
+ while (n) {
285
+ if (strcmp(n->left, left) == 0 && strcmp(n->right, right) == 0)
286
+ return n->rank;
287
+ n = n->next;
188
288
  }
189
289
  return -1;
190
290
  }
191
291
 
192
292
  /* ------------------------------------------------------------------------- */
193
- // BPE tokenization
293
+ // BPE tokenization (correct iterative algorithm)
194
294
  typedef struct {
195
- char *text;
295
+ const char *text;
196
296
  int start, end;
197
297
  int prev, next;
198
298
  int used;
199
299
  } BPESymbol;
200
300
 
201
- static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int (*text_to_id)(void*, const char*), void *vocab_data, int *token_ids, int *num_tokens) {
301
+ static int text_to_id(void *vocab_data, const char *text);
302
+
303
+ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word,
304
+ void *vocab_data, int *token_ids, int *num_tokens) {
202
305
  int word_len = strlen(word);
306
+ if (word_len == 0) return;
307
+
203
308
  int num_symbols = 0;
204
309
  BPESymbol *symbols = malloc(word_len * sizeof(BPESymbol));
310
+ if (!symbols) return;
311
+
205
312
  int offset = 0;
206
313
  while (offset < word_len) {
207
314
  int char_len = unicode_len_utf8(word[offset]);
208
- symbols[num_symbols].text = (char*)word + offset;
315
+ if (offset + char_len > word_len) char_len = word_len - offset;
316
+ symbols[num_symbols].text = word;
209
317
  symbols[num_symbols].start = offset;
210
318
  symbols[num_symbols].end = offset + char_len;
211
319
  symbols[num_symbols].prev = num_symbols - 1;
@@ -215,6 +323,8 @@ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int
215
323
  num_symbols++;
216
324
  }
217
325
 
326
+ if (num_symbols > 0) symbols[num_symbols - 1].next = -1;
327
+
218
328
  if (num_symbols <= 1) {
219
329
  int id = text_to_id(vocab_data, word);
220
330
  if (id != -1) token_ids[(*num_tokens)++] = id;
@@ -222,56 +332,61 @@ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int
222
332
  return;
223
333
  }
224
334
 
225
- typedef struct { int left, right, rank; } Bigram;
226
- Bigram *bigrams = malloc(num_symbols * num_symbols * sizeof(Bigram));
227
- int num_bigrams = 0;
228
- for (int i = 0; i < num_symbols - 1; i++) {
229
- if (symbols[i].used && symbols[i+1].used) {
230
- char *left_str = malloc(symbols[i].end - symbols[i].start + 1);
231
- char *right_str = malloc(symbols[i+1].end - symbols[i+1].start + 1);
232
- memcpy(left_str, symbols[i].text, symbols[i].end - symbols[i].start);
233
- memcpy(right_str, symbols[i+1].text, symbols[i+1].end - symbols[i+1].start);
234
- left_str[symbols[i].end - symbols[i].start] = '\0';
235
- right_str[symbols[i+1].end - symbols[i+1].start] = '\0';
236
- int rank = bpe_merge_rank(merges, left_str, right_str);
237
- if (rank != -1) {
238
- bigrams[num_bigrams].left = i;
239
- bigrams[num_bigrams].right = i+1;
240
- bigrams[num_bigrams].rank = rank;
241
- num_bigrams++;
335
+ while (1) {
336
+ int best_rank = INT_MAX;
337
+ int best_idx = -1;
338
+
339
+ int idx = 0;
340
+ while (idx != -1) {
341
+ int next = symbols[idx].next;
342
+ if (next != -1 && symbols[idx].used && symbols[next].used) {
343
+ int left_len = symbols[idx].end - symbols[idx].start;
344
+ int right_len = symbols[next].end - symbols[next].start;
345
+ char *left_str = malloc(left_len + 1);
346
+ char *right_str = malloc(right_len + 1);
347
+ if (left_str && right_str) {
348
+ memcpy(left_str, word + symbols[idx].start, left_len);
349
+ left_str[left_len] = '\0';
350
+ memcpy(right_str, word + symbols[next].start, right_len);
351
+ right_str[right_len] = '\0';
352
+ int rank = bpe_merge_rank(merges, left_str, right_str);
353
+ if (rank != -1 && rank < best_rank) {
354
+ best_rank = rank;
355
+ best_idx = idx;
356
+ }
357
+ }
358
+ free(left_str);
359
+ free(right_str);
242
360
  }
243
- free(left_str); free(right_str);
361
+ idx = symbols[idx].next;
244
362
  }
245
- }
246
- for (int i = 0; i < num_bigrams - 1; i++)
247
- for (int j = i+1; j < num_bigrams; j++)
248
- if (bigrams[i].rank > bigrams[j].rank) {
249
- Bigram tmp = bigrams[i];
250
- bigrams[i] = bigrams[j];
251
- bigrams[j] = tmp;
252
- }
253
363
 
254
- int *merged = calloc(num_symbols, sizeof(int));
255
- for (int i = 0; i < num_bigrams; i++) {
256
- int left = bigrams[i].left, right = bigrams[i].right;
257
- if (merged[left] || merged[right]) continue;
258
- symbols[left].end = symbols[right].end;
259
- symbols[left].next = symbols[right].next;
260
- merged[right] = 1;
261
- if (symbols[right].next < num_symbols) symbols[symbols[right].next].prev = left;
364
+ if (best_idx == -1) break;
365
+
366
+ int right_idx = symbols[best_idx].next;
367
+ symbols[best_idx].end = symbols[right_idx].end;
368
+ symbols[best_idx].next = symbols[right_idx].next;
369
+ symbols[right_idx].used = 0;
370
+
371
+ if (symbols[right_idx].next != -1) {
372
+ symbols[symbols[right_idx].next].prev = best_idx;
373
+ }
262
374
  }
263
375
 
264
376
  for (int i = 0; i < num_symbols; i++) {
265
- if (!merged[i] && symbols[i].used) {
266
- char *substr = malloc(symbols[i].end - symbols[i].start + 1);
267
- memcpy(substr, word + symbols[i].start, symbols[i].end - symbols[i].start);
268
- substr[symbols[i].end - symbols[i].start] = '\0';
269
- int id = text_to_id(vocab_data, substr);
270
- if (id != -1) token_ids[(*num_tokens)++] = id;
271
- free(substr);
377
+ if (symbols[i].used) {
378
+ int len = symbols[i].end - symbols[i].start;
379
+ char *substr = malloc(len + 1);
380
+ if (substr) {
381
+ memcpy(substr, word + symbols[i].start, len);
382
+ substr[len] = '\0';
383
+ int id = text_to_id(vocab_data, substr);
384
+ if (id != -1) token_ids[(*num_tokens)++] = id;
385
+ free(substr);
386
+ }
272
387
  }
273
388
  }
274
- free(bigrams); free(merged); free(symbols);
389
+ free(symbols);
275
390
  }
276
391
 
277
392
  /* ------------------------------------------------------------------------- */
@@ -329,36 +444,41 @@ typedef struct {
329
444
  int vocab_size;
330
445
  int dim;
331
446
  char **tokens;
332
- float *float_data;
333
- void *tensor_data;
334
- int tensor_type;
335
447
  void *mapped;
336
448
  size_t mapped_size;
337
449
  HashNode **table;
338
450
  BPEMergeTable merges;
339
- RegexPattern *pre_patterns;
340
- int num_pre_patterns;
341
451
  int unknown_token_id;
342
452
  int bos_token_id;
343
453
  int eos_token_id;
344
454
  int vocab_type;
345
455
  char space_marker[8];
456
+ int space_marker_len;
457
+ const void *raw_tensor_data;
458
+ int tensor_type;
459
+ size_t row_bytes;
460
+ int need_transpose;
461
+ uint64_t raw_dim0, raw_dim1;
462
+ int normalize;
346
463
  } EmbedModel;
347
464
 
348
465
  typedef struct {
349
466
  EmbedModel *model;
350
467
  } ruby_embedder;
351
468
 
352
- static unsigned long hash(const char *s) {
353
- unsigned long h = 5381;
354
- int c;
355
- while ((c = *s++)) h = ((h << 5) + h) + c;
469
+ static uint64_t vocab_hash(const char *s) {
470
+ uint64_t h = 0xcbf29ce484222325ULL;
471
+ while (*s) {
472
+ h ^= (uint64_t)(unsigned char)*s++;
473
+ h *= 0x100000001b3ULL;
474
+ }
356
475
  return h % HASH_SIZE;
357
476
  }
358
477
 
359
478
  static void hset(EmbedModel *m, char *k, int id) {
360
- unsigned long h = hash(k);
479
+ uint64_t h = vocab_hash(k);
361
480
  HashNode *n = malloc(sizeof(*n));
481
+ if (!n) return;
362
482
  n->key = k;
363
483
  n->id = id;
364
484
  n->next = m->table[h];
@@ -366,7 +486,8 @@ static void hset(EmbedModel *m, char *k, int id) {
366
486
  }
367
487
 
368
488
  static int hget(EmbedModel *m, const char *k) {
369
- HashNode *n = m->table[hash(k)];
489
+ if (!k || !m->table) return -1;
490
+ HashNode *n = m->table[vocab_hash(k)];
370
491
  while (n) {
371
492
  if (strcmp(n->key, k) == 0) return n->id;
372
493
  n = n->next;
@@ -386,30 +507,51 @@ static void *map_file(const char *path, size_t *size) {
386
507
  struct stat st;
387
508
  if (fstat(fd, &st) != 0) { close(fd); return NULL; }
388
509
  *size = st.st_size;
510
+ if (*size == 0) { close(fd); return NULL; }
389
511
  void *data = mmap(NULL, *size, PROT_READ, MAP_PRIVATE, fd, 0);
390
512
  close(fd);
391
513
  return data == MAP_FAILED ? NULL : data;
392
514
  }
393
515
 
394
516
  /* ------------------------------------------------------------------------- */
395
- // FP16 conversion
517
+ // FP16 conversion (corrected)
396
518
  static float fp16_to_fp32(uint16_t h) {
397
- uint16_t sign = (h >> 15) & 1;
398
- uint16_t exp = (h >> 10) & 0x1F;
399
- uint16_t mant = h & 0x3FF;
400
- if (exp == 0) return (mant / 1024.0f) * 6.103515625e-5f * (sign ? -1.0f : 1.0f);
401
- if (exp == 31) return 0.0f;
402
- return (1.0f + mant / 1024.0f) * (1 << (exp - 15)) * (sign ? -1.0f : 1.0f);
519
+ const uint32_t sign = (h >> 15) & 1;
520
+ const uint32_t exp = (h >> 10) & 0x1F;
521
+ const uint32_t mant = h & 0x3FF;
522
+
523
+ uint32_t f;
524
+ if (exp == 0) {
525
+ if (mant == 0) {
526
+ f = sign << 31;
527
+ } else {
528
+ uint32_t e = 0;
529
+ uint32_t m = mant;
530
+ while (!(m & 0x400)) { m <<= 1; e++; }
531
+ f = (sign << 31) | ((127 - 15 - e + 1) << 23) | ((m & 0x3FF) << 13);
532
+ }
533
+ } else if (exp == 31) {
534
+ f = (sign << 31) | (0xFF << 23) | (mant << 13);
535
+ } else {
536
+ f = (sign << 31) | ((exp + 127 - 15) << 23) | (mant << 13);
537
+ }
538
+
539
+ float result;
540
+ memcpy(&result, &f, sizeof(result));
541
+ return result;
403
542
  }
404
543
 
405
544
  /* ------------------------------------------------------------------------- */
406
- // Block dequantization functions
545
+ // Block dequantization functions (correct sizes)
407
546
  static void dequantize_row_q4_0(const void *vx, float *y, int k) {
408
- const int nb = k / 32;
547
+ const int nb = k / QK8_0;
409
548
  const uint8_t *x = vx;
410
549
  for (int i = 0; i < nb; i++) {
411
- const float d = ((const float*)(x + i*34))[0];
412
- const uint8_t *q = x + i*34 + 4;
550
+ const uint8_t *block = x + i * 18;
551
+ uint16_t d16;
552
+ memcpy(&d16, block, 2);
553
+ const float d = fp16_to_fp32(d16);
554
+ const uint8_t *q = block + 2;
413
555
  for (int j = 0; j < 32; j++) {
414
556
  const int v = (q[j/2] >> (4*(j%2))) & 0x0F;
415
557
  y[i*32 + j] = (v - 8.0f) * d;
@@ -418,12 +560,16 @@ static void dequantize_row_q4_0(const void *vx, float *y, int k) {
418
560
  }
419
561
 
420
562
  static void dequantize_row_q4_1(const void *vx, float *y, int k) {
421
- const int nb = k / 32;
563
+ const int nb = k / QK8_0;
422
564
  const uint8_t *x = vx;
423
565
  for (int i = 0; i < nb; i++) {
424
- const float d = ((const float*)(x + i*36))[0];
425
- const float m = ((const float*)(x + i*36))[1];
426
- const uint8_t *q = x + i*36 + 8;
566
+ const uint8_t *block = x + i * 20;
567
+ uint16_t d16, m16;
568
+ memcpy(&d16, block, 2);
569
+ memcpy(&m16, block + 2, 2);
570
+ const float d = fp16_to_fp32(d16);
571
+ const float m = fp16_to_fp32(m16);
572
+ const uint8_t *q = block + 4;
427
573
  for (int j = 0; j < 32; j++) {
428
574
  const int v = (q[j/2] >> (4*(j%2))) & 0x0F;
429
575
  y[i*32 + j] = v * d + m;
@@ -432,14 +578,16 @@ static void dequantize_row_q4_1(const void *vx, float *y, int k) {
432
578
  }
433
579
 
434
580
  static void dequantize_row_q5_0(const void *vx, float *y, int k) {
435
- const int nb = k / 32;
581
+ const int nb = k / QK8_0;
436
582
  const uint8_t *x = vx;
437
583
  for (int i = 0; i < nb; i++) {
438
- const float d = ((const float*)(x + i*40))[0];
439
- const uint8_t *qh = x + i*40 + 4;
440
- const uint8_t *ql = x + i*40 + 8;
584
+ const uint8_t *block = x + i * 22;
585
+ uint16_t d16;
586
+ memcpy(&d16, block, 2);
587
+ const float d = fp16_to_fp32(d16);
441
588
  uint32_t qh32;
442
- memcpy(&qh32, qh, 4);
589
+ memcpy(&qh32, block + 2, 4);
590
+ const uint8_t *ql = block + 6;
443
591
  for (int j = 0; j < 32; j++) {
444
592
  const uint8_t vh = (qh32 >> j) & 1;
445
593
  const int v = ((ql[j/2] >> (4*(j%2))) & 0x0F) | (vh << 4);
@@ -449,15 +597,18 @@ static void dequantize_row_q5_0(const void *vx, float *y, int k) {
449
597
  }
450
598
 
451
599
  static void dequantize_row_q5_1(const void *vx, float *y, int k) {
452
- const int nb = k / 32;
600
+ const int nb = k / QK8_0;
453
601
  const uint8_t *x = vx;
454
602
  for (int i = 0; i < nb; i++) {
455
- const float d = ((const float*)(x + i*44))[0];
456
- const float m = ((const float*)(x + i*44))[1];
457
- const uint8_t *qh = x + i*44 + 8;
458
- const uint8_t *ql = x + i*44 + 12;
603
+ const uint8_t *block = x + i * 24;
604
+ uint16_t d16, m16;
605
+ memcpy(&d16, block, 2);
606
+ memcpy(&m16, block + 2, 2);
607
+ const float d = fp16_to_fp32(d16);
608
+ const float m = fp16_to_fp32(m16);
459
609
  uint32_t qh32;
460
- memcpy(&qh32, qh, 4);
610
+ memcpy(&qh32, block + 4, 4);
611
+ const uint8_t *ql = block + 8;
461
612
  for (int j = 0; j < 32; j++) {
462
613
  const uint8_t vh = (qh32 >> j) & 1;
463
614
  const int v = ((ql[j/2] >> (4*(j%2))) & 0x0F) | (vh << 4);
@@ -467,11 +618,13 @@ static void dequantize_row_q5_1(const void *vx, float *y, int k) {
467
618
  }
468
619
 
469
620
  static void dequantize_row_q8_0(const void *vx, float *y, int k) {
470
- const int nb = k / 32;
621
+ const int nb = k / QK8_0;
471
622
  const uint8_t *x = vx;
472
623
  for (int i = 0; i < nb; i++) {
473
- const float d = ((const float*)(x + i*36))[0];
474
- const int8_t *q = (const int8_t*)(x + i*36 + 4);
624
+ const uint8_t *block = x + i * 34;
625
+ float d;
626
+ memcpy(&d, block, 4);
627
+ const int8_t *q = (const int8_t*)(block + 4);
475
628
  for (int j = 0; j < 32; j++) {
476
629
  y[i*32 + j] = (float)q[j] * d;
477
630
  }
@@ -479,197 +632,315 @@ static void dequantize_row_q8_0(const void *vx, float *y, int k) {
479
632
  }
480
633
 
481
634
  static void dequantize_row_q8_1(const void *vx, float *y, int k) {
482
- const int nb = k / 32;
635
+ const int nb = k / QK8_0;
483
636
  const uint8_t *x = vx;
484
637
  for (int i = 0; i < nb; i++) {
485
- const float d = ((const float*)(x + i*40))[0];
486
- const float s = ((const float*)(x + i*40))[1];
487
- const int8_t *q = (const int8_t*)(x + i*40 + 8);
638
+ const uint8_t *block = x + i * 40;
639
+ float d, s;
640
+ memcpy(&d, block, 4);
641
+ memcpy(&s, block + 4, 4);
642
+ const int8_t *q = (const int8_t*)(block + 8);
488
643
  for (int j = 0; j < 32; j++) {
489
644
  y[i*32 + j] = (float)q[j] * d + s;
490
645
  }
491
646
  }
492
647
  }
493
648
 
649
+ // K-quant scale helpers
650
+ static inline void get_scale_min_k4(int j, const uint8_t *q, uint8_t *d, uint8_t *m) {
651
+ if (j < 4) {
652
+ *d = q[j] & 63;
653
+ *m = q[j + 4] & 63;
654
+ } else {
655
+ *d = (q[j+4] & 0xF) | ((q[j-3] >> 6) << 4);
656
+ *m = (q[j+4] >> 4) | ((q[j-1] >> 6) << 4);
657
+ }
658
+ }
659
+
494
660
  static void dequantize_row_q2_K(const void *vx, float *y, int k) {
495
- const int nb = k / 256;
661
+ const int nb = k / QK_K;
496
662
  const uint8_t *x = vx;
497
663
  for (int i = 0; i < nb; i++) {
498
- const float d = ((const float*)(x + i*336))[0];
499
- const float m = ((const float*)(x + i*336))[1];
500
- const uint8_t *q = x + i*336 + 8;
501
- const uint8_t *scales = q + 64;
502
- for (int j = 0; j < 256; j += 32) {
503
- const uint8_t ls = scales[j/32] & 0xF;
504
- const uint8_t ms = scales[j/32] >> 4;
505
- for (int l = 0; l < 32; l++) {
664
+ const uint8_t *block = x + i * 84;
665
+ uint16_t d16, dmin16;
666
+ memcpy(&d16, block, 2);
667
+ memcpy(&dmin16, block + 2, 2);
668
+ const float d = fp16_to_fp32(d16);
669
+ const float min = fp16_to_fp32(dmin16);
670
+ const uint8_t *scales = block + 4;
671
+ const uint8_t *q = block + 20;
672
+ for (int j = 0; j < QK_K; j += 64) {
673
+ const float dl = d * (scales[j/64] & 0xF);
674
+ const float ml = min * (scales[j/64] >> 4);
675
+ for (int l = 0; l < 64; l++) {
506
676
  const int v = (q[(j+l)/4] >> (2*((j+l)%4))) & 0x03;
507
- const float dl = d * (ls - 32);
508
- const float ml = m * (ms - 32);
509
- y[i*256 + j + l] = v * dl + ml;
677
+ y[i*QK_K + j + l] = v * dl + ml;
510
678
  }
511
679
  }
512
680
  }
513
681
  }
514
682
 
515
683
  static void dequantize_row_q3_K(const void *vx, float *y, int k) {
516
- const int nb = k / 256;
684
+ const int nb = k / QK_K;
517
685
  const uint8_t *x = vx;
518
686
  for (int i = 0; i < nb; i++) {
519
- const float d = ((const float*)(x + i*352))[0];
520
- const uint8_t *q = x + i*352 + 4;
521
- const uint8_t *scales = q + 256;
522
- const uint8_t *h = scales + 32;
523
- for (int j = 0; j < 256; j += 64) {
687
+ const uint8_t *block = x + i * 110;
688
+ uint16_t d16;
689
+ memcpy(&d16, block, 2);
690
+ const float d = fp16_to_fp32(d16);
691
+ const uint8_t *hmask = block + 2;
692
+ const uint8_t *q = block + 34;
693
+ const uint8_t *scales = block + 98;
694
+ for (int j = 0; j < QK_K; j += 64) {
524
695
  const uint8_t ls1 = scales[j/64] & 0x1F;
525
- const uint8_t ls2 = (scales[j/64] >> 4) | ((scales[j/64+1] & 0x0F) << 4);
526
- const uint8_t ms = scales[j/64+1] >> 4;
696
+ const uint8_t ls2 = (scales[j/64] >> 5) | ((scales[j/64 + 1] & 0x7) << 3);
697
+ const uint8_t ls3 = ((scales[j/64 + 1] >> 3) & 0x1F);
698
+ const uint8_t ls4 = (scales[j/64 + 1] >> 8);
527
699
  for (int l = 0; l < 64; l++) {
528
700
  int v = (q[(j+l)/2] >> (4*((j+l)%2))) & 0x0F;
529
- const int bit = (h[(j+l)/8] >> ((j+l)%8)) & 1;
701
+ const int bit = (hmask[(j+l)/8] >> ((j+l)%8)) & 1;
530
702
  v |= bit << 4;
531
- const float dl = d * (ls1 - 32);
532
- const float ml = (l < 32) ? (ls2 - 32) * d : (ms - 32) * d;
533
- y[i*256 + j + l] = v * dl + ml;
703
+ float ls;
704
+ if (l < 16) ls = ls1;
705
+ else if (l < 32) ls = ls2;
706
+ else if (l < 48) ls = ls3;
707
+ else ls = ls4;
708
+ y[i*QK_K + j + l] = (v - 32.0f) * d * ls;
534
709
  }
535
710
  }
536
711
  }
537
712
  }
538
713
 
539
714
  static void dequantize_row_q4_K(const void *vx, float *y, int k) {
540
- const int nb = k / 256;
715
+ const int nb = k / QK_K;
541
716
  const uint8_t *x = vx;
542
717
  for (int i = 0; i < nb; i++) {
543
- const float d = ((const float*)(x + i*416))[0];
544
- const float m = ((const float*)(x + i*416))[1];
545
- const uint8_t *q = x + i*416 + 8;
546
- const uint8_t *scales = q + 128;
547
- for (int j = 0; j < 256; j += 32) {
548
- const uint8_t ls = scales[j/32] & 0x3F;
549
- const uint8_t ms = scales[j/32] >> 6;
718
+ const uint8_t *block = x + i * 144;
719
+ uint16_t d16, dmin16;
720
+ memcpy(&d16, block, 2);
721
+ memcpy(&dmin16, block + 2, 2);
722
+ const float d = fp16_to_fp32(d16);
723
+ const float min = fp16_to_fp32(dmin16);
724
+ const uint8_t *scales = block + 4;
725
+ const uint8_t *q = block + 16;
726
+ int is = 0;
727
+ for (int j = 0; j < QK_K; j += 64) {
728
+ uint8_t sc, m;
729
+ get_scale_min_k4(is, scales, &sc, &m);
730
+ float d1 = d * sc;
731
+ float m1 = min * m;
732
+ get_scale_min_k4(is + 1, scales, &sc, &m);
733
+ float d2 = d * sc;
734
+ float m2 = min * m;
735
+ for (int l = 0; l < 32; l++) {
736
+ y[i*QK_K + j + l] = d1 * (q[l] & 0xF) - m1;
737
+ }
550
738
  for (int l = 0; l < 32; l++) {
551
- const int v = (q[(j+l)/2] >> (4*((j+l)%2))) & 0x0F;
552
- const float dl = d * (ls - 32);
553
- const float ml = m * (ms - 2);
554
- y[i*256 + j + l] = v * dl + ml;
739
+ y[i*QK_K + j + 32 + l] = d2 * (q[l] >> 4) - m2;
555
740
  }
741
+ q += 32;
742
+ is += 2;
556
743
  }
557
744
  }
558
745
  }
559
746
 
560
747
  static void dequantize_row_q5_K(const void *vx, float *y, int k) {
561
- const int nb = k / 256;
748
+ const int nb = k / QK_K;
562
749
  const uint8_t *x = vx;
563
750
  for (int i = 0; i < nb; i++) {
564
- const float d = ((const float*)(x + i*448))[0];
565
- const float m = ((const float*)(x + i*448))[1];
566
- const uint8_t *q = x + i*448 + 8;
567
- const uint8_t *qh = q + 128;
568
- const uint8_t *scales = qh + 32;
569
- for (int j = 0; j < 256; j += 32) {
570
- const uint8_t ls = scales[j/32] & 0x3F;
571
- const uint8_t ms = scales[j/32] >> 6;
751
+ const uint8_t *block = x + i * 176;
752
+ uint16_t d16, dmin16;
753
+ memcpy(&d16, block, 2);
754
+ memcpy(&dmin16, block + 2, 2);
755
+ const float d = fp16_to_fp32(d16);
756
+ const float min = fp16_to_fp32(dmin16);
757
+ const uint8_t *scales = block + 4;
758
+ const uint8_t *qh = block + 16;
759
+ const uint8_t *ql = block + 48;
760
+ int is = 0;
761
+ for (int j = 0; j < QK_K; j += 64) {
762
+ uint8_t sc, m;
763
+ get_scale_min_k4(is, scales, &sc, &m);
764
+ float d1 = d * sc;
765
+ float m1 = min * m;
766
+ get_scale_min_k4(is + 1, scales, &sc, &m);
767
+ float d2 = d * sc;
768
+ float m2 = min * m;
572
769
  for (int l = 0; l < 32; l++) {
573
- int v = (q[(j+l)/2] >> (4*((j+l)%2))) & 0x0F;
574
- const int bit = (qh[(j+l)/8] >> ((j+l)%8)) & 1;
575
- v |= bit << 4;
576
- const float dl = d * (ls - 32);
577
- const float ml = m * (ms - 2);
578
- y[i*256 + j + l] = v * dl + ml;
770
+ int vh = (qh[j/64 * 4 + l/8] >> (l%8)) & 1;
771
+ int v = (ql[l] & 0xF) | (vh << 4);
772
+ y[i*QK_K + j + l] = d1 * v - m1;
773
+ }
774
+ for (int l = 0; l < 32; l++) {
775
+ int vh = (qh[j/64 * 4 + 4 + l/8] >> (l%8)) & 1;
776
+ int v = (ql[l] >> 4) | (vh << 4);
777
+ y[i*QK_K + j + 32 + l] = d2 * v - m2;
579
778
  }
779
+ ql += 32;
780
+ is += 2;
580
781
  }
581
782
  }
582
783
  }
583
784
 
584
785
  static void dequantize_row_q6_K(const void *vx, float *y, int k) {
585
- const int nb = k / 256;
786
+ const int nb = k / QK_K;
586
787
  const uint8_t *x = vx;
587
788
  for (int i = 0; i < nb; i++) {
588
- const float d = ((const float*)(x + i*480))[0];
589
- const uint8_t *q = x + i*480 + 4;
590
- const uint8_t *qh = q + 256;
591
- const uint8_t *scales = qh + 64;
592
- for (int j = 0; j < 256; j += 64) {
593
- const uint8_t ls = scales[j/64];
594
- for (int l = 0; l < 64; l++) {
595
- int v = (q[(j+l)/2] >> (4*((j+l)%2))) & 0x0F;
596
- const int bit = (qh[(j+l)/4] >> (2*((j+l)%4))) & 0x03;
597
- v |= bit << 4;
598
- y[i*256 + j + l] = v * d * (ls - 32);
789
+ const uint8_t *block = x + i * 210;
790
+ const uint8_t *ql = block;
791
+ const uint8_t *qh = block + 128;
792
+ const int8_t *scales = (const int8_t*)(block + 192);
793
+ uint16_t d16;
794
+ memcpy(&d16, block + 208, 2);
795
+ const float d = fp16_to_fp32(d16);
796
+ for (int j = 0; j < QK_K; j += 128) {
797
+ for (int l = 0; l < 32; l++) {
798
+ int v = (ql[j/2 + l] & 0xF) | (((qh[j/4 + l/2] >> ((l%2)*4)) & 0xF) << 4);
799
+ y[i*QK_K + j + l] = v * d * scales[j/128 * 8 + l/4];
800
+ }
801
+ for (int l = 0; l < 32; l++) {
802
+ int v = (ql[j/2 + 32 + l] >> 4) | (((qh[j/4 + 16 + l/2] >> ((l%2)*4)) & 0xF) << 4);
803
+ y[i*QK_K + j + 32 + l] = v * d * scales[j/128 * 8 + 8 + l/4];
804
+ }
805
+ for (int l = 0; l < 32; l++) {
806
+ int v = (ql[j/2 + 64 + l] & 0xF) | (((qh[j/4 + 32 + l/2] >> ((l%2)*4)) & 0xF) << 4);
807
+ y[i*QK_K + j + 64 + l] = v * d * scales[j/128 * 8 + 4 + l/4];
808
+ }
809
+ for (int l = 0; l < 32; l++) {
810
+ int v = (ql[j/2 + 96 + l] >> 4) | (((qh[j/4 + 48 + l/2] >> ((l%2)*4)) & 0xF) << 4);
811
+ y[i*QK_K + j + 96 + l] = v * d * scales[j/128 * 8 + 12 + l/4];
599
812
  }
600
813
  }
601
814
  }
602
815
  }
603
816
 
604
817
  static void dequantize_row_q8_K(const void *vx, float *y, int k) {
605
- const int nb = k / 256;
818
+ const int nb = k / QK_K;
606
819
  const uint8_t *x = vx;
607
820
  for (int i = 0; i < nb; i++) {
608
- const float d = ((const float*)(x + i*544))[0];
609
- const int8_t *q = (const int8_t*)(x + i*544 + 4);
610
- const uint8_t *scales = (const uint8_t*)(q + 256);
611
- for (int j = 0; j < 256; j += 32) {
612
- const uint8_t ls = scales[j/32];
613
- for (int l = 0; l < 32; l++) {
614
- y[i*256 + j + l] = (float)q[j+l] * d * ls;
615
- }
821
+ const uint8_t *block = x + i * 292;
822
+ float d;
823
+ memcpy(&d, block, 4);
824
+ const int8_t *q = (const int8_t*)(block + 4);
825
+ for (int j = 0; j < QK_K; j++) {
826
+ y[i*QK_K + j] = (float)q[j] * d;
616
827
  }
617
828
  }
618
829
  }
619
830
 
620
- static float* dequantize_tensor(const void *data, int type, int n_rows, int n_cols) {
621
- if (type == GGML_TYPE_F32) {
622
- float *out = malloc(n_rows * n_cols * sizeof(float));
623
- if (!out) return NULL;
624
- memcpy(out, data, n_rows * n_cols * sizeof(float));
625
- return out;
831
+ // Lazy single-row dequantization
832
+ static void dequantize_row_lazy(const EmbedModel *m, int row, float *out) {
833
+ if (!m->raw_tensor_data || row < 0 || row >= m->vocab_size) {
834
+ memset(out, 0, sizeof(float) * m->dim);
835
+ return;
626
836
  }
627
- if (type == GGML_TYPE_F16) {
628
- float *out = malloc(n_rows * n_cols * sizeof(float));
629
- if (!out) return NULL;
630
- const uint16_t *in = data;
631
- for (int i = 0; i < n_rows * n_cols; i++) {
632
- out[i] = fp16_to_fp32(in[i]);
837
+
838
+ const uint8_t *raw;
839
+ int effective_cols;
840
+
841
+ if (m->need_transpose) {
842
+ int src_row_size;
843
+ switch (m->tensor_type) {
844
+ case GGML_TYPE_F32: src_row_size = m->raw_dim1 * sizeof(float); break;
845
+ case GGML_TYPE_F16: src_row_size = m->raw_dim1 * sizeof(uint16_t); break;
846
+ default: {
847
+ size_t rb = 0;
848
+ int nc = (int)m->raw_dim1;
849
+ switch (m->tensor_type) {
850
+ case GGML_TYPE_Q4_0: rb = (nc / 32) * 18; break;
851
+ case GGML_TYPE_Q4_1: rb = (nc / 32) * 20; break;
852
+ case GGML_TYPE_Q5_0: rb = (nc / 32) * 22; break;
853
+ case GGML_TYPE_Q5_1: rb = (nc / 32) * 24; break;
854
+ case GGML_TYPE_Q8_0: rb = (nc / 32) * 34; break;
855
+ case GGML_TYPE_Q8_1: rb = (nc / 32) * 40; break;
856
+ case GGML_TYPE_Q2_K: rb = (nc / 256) * 84; break;
857
+ case GGML_TYPE_Q3_K: rb = (nc / 256) * 110; break;
858
+ case GGML_TYPE_Q4_K: rb = (nc / 256) * 144; break;
859
+ case GGML_TYPE_Q5_K: rb = (nc / 256) * 176; break;
860
+ case GGML_TYPE_Q6_K: rb = (nc / 256) * 210; break;
861
+ case GGML_TYPE_Q8_K: rb = (nc / 256) * 292; break;
862
+ default: src_row_size = 0; return;
863
+ }
864
+ src_row_size = (int)rb;
865
+ }
866
+ }
867
+ float *temp_row = malloc(m->raw_dim1 * sizeof(float));
868
+ if (!temp_row) return;
869
+ for (int col = 0; col < m->dim; col++) {
870
+ const uint8_t *src_row = (const uint8_t*)m->raw_tensor_data + col * src_row_size;
871
+ if (m->tensor_type == GGML_TYPE_F32) {
872
+ float val;
873
+ memcpy(&val, src_row + row * sizeof(float), sizeof(float));
874
+ out[col] = val;
875
+ } else if (m->tensor_type == GGML_TYPE_F16) {
876
+ uint16_t val;
877
+ memcpy(&val, src_row + row * sizeof(uint16_t), sizeof(uint16_t));
878
+ out[col] = fp16_to_fp32(val);
879
+ } else {
880
+ memset(out, 0, sizeof(float) * m->dim);
881
+ free(temp_row);
882
+ return;
883
+ }
633
884
  }
634
- return out;
885
+ free(temp_row);
886
+ return;
635
887
  }
636
888
 
637
- float *out = malloc(n_rows * n_cols * sizeof(float));
638
- if (!out) return NULL;
639
- const uint8_t *in = data;
640
- size_t row_bytes = 0;
641
- void (*dequant_func)(const void*, float*, int) = NULL;
642
-
643
- switch (type) {
644
- case GGML_TYPE_Q4_0: dequant_func = dequantize_row_q4_0; row_bytes = (n_cols / 32) * 34; break;
645
- case GGML_TYPE_Q4_1: dequant_func = dequantize_row_q4_1; row_bytes = (n_cols / 32) * 36; break;
646
- case GGML_TYPE_Q5_0: dequant_func = dequantize_row_q5_0; row_bytes = (n_cols / 32) * 40; break;
647
- case GGML_TYPE_Q5_1: dequant_func = dequantize_row_q5_1; row_bytes = (n_cols / 32) * 44; break;
648
- case GGML_TYPE_Q8_0: dequant_func = dequantize_row_q8_0; row_bytes = (n_cols / 32) * 36; break;
649
- case GGML_TYPE_Q8_1: dequant_func = dequantize_row_q8_1; row_bytes = (n_cols / 32) * 40; break;
650
- case GGML_TYPE_Q2_K: dequant_func = dequantize_row_q2_K; row_bytes = (n_cols / 256) * 336; break;
651
- case GGML_TYPE_Q3_K: dequant_func = dequantize_row_q3_K; row_bytes = (n_cols / 256) * 352; break;
652
- case GGML_TYPE_Q4_K: dequant_func = dequantize_row_q4_K; row_bytes = (n_cols / 256) * 416; break;
653
- case GGML_TYPE_Q5_K: dequant_func = dequantize_row_q5_K; row_bytes = (n_cols / 256) * 448; break;
654
- case GGML_TYPE_Q6_K: dequant_func = dequantize_row_q6_K; row_bytes = (n_cols / 256) * 480; break;
655
- case GGML_TYPE_Q8_K: dequant_func = dequantize_row_q8_K; row_bytes = (n_cols / 256) * 544; break;
889
+ raw = (const uint8_t*)m->raw_tensor_data + row * m->row_bytes;
890
+ effective_cols = m->dim;
891
+
892
+ switch (m->tensor_type) {
893
+ case GGML_TYPE_F32:
894
+ memcpy(out, raw, effective_cols * sizeof(float));
895
+ break;
896
+ case GGML_TYPE_F16:
897
+ for (int j = 0; j < effective_cols; j++) {
898
+ uint16_t h;
899
+ memcpy(&h, raw + j * sizeof(uint16_t), sizeof(uint16_t));
900
+ out[j] = fp16_to_fp32(h);
901
+ }
902
+ break;
903
+ case GGML_TYPE_Q4_0: dequantize_row_q4_0(raw, out, effective_cols); break;
904
+ case GGML_TYPE_Q4_1: dequantize_row_q4_1(raw, out, effective_cols); break;
905
+ case GGML_TYPE_Q5_0: dequantize_row_q5_0(raw, out, effective_cols); break;
906
+ case GGML_TYPE_Q5_1: dequantize_row_q5_1(raw, out, effective_cols); break;
907
+ case GGML_TYPE_Q8_0: dequantize_row_q8_0(raw, out, effective_cols); break;
908
+ case GGML_TYPE_Q8_1: dequantize_row_q8_1(raw, out, effective_cols); break;
909
+ case GGML_TYPE_Q2_K: dequantize_row_q2_K(raw, out, effective_cols); break;
910
+ case GGML_TYPE_Q3_K: dequantize_row_q3_K(raw, out, effective_cols); break;
911
+ case GGML_TYPE_Q4_K: dequantize_row_q4_K(raw, out, effective_cols); break;
912
+ case GGML_TYPE_Q5_K: dequantize_row_q5_K(raw, out, effective_cols); break;
913
+ case GGML_TYPE_Q6_K: dequantize_row_q6_K(raw, out, effective_cols); break;
914
+ case GGML_TYPE_Q8_K: dequantize_row_q8_K(raw, out, effective_cols); break;
656
915
  default:
657
- free(out);
658
- return NULL;
916
+ memset(out, 0, sizeof(float) * effective_cols);
659
917
  }
660
918
 
661
- for (int r = 0; r < n_rows; r++) {
662
- dequant_func(in + r * row_bytes, out + r * n_cols, n_cols);
919
+ for (int j = 0; j < effective_cols; j++) {
920
+ if (isnan(out[j]) || isinf(out[j]) || fabsf(out[j]) > 1e10f) {
921
+ out[j] = 0.0f;
922
+ }
663
923
  }
924
+ }
664
925
 
665
- // Sanitize the tensor: replace NaNs, Infs, and astronomically large values with zero
666
- int total = n_rows * n_cols;
667
- for (int i = 0; i < total; i++) {
668
- if (isnan(out[i]) || isinf(out[i]) || fabs(out[i]) > 1e10f) {
669
- out[i] = 0.0f;
670
- }
926
+ static size_t get_row_bytes(int type, int n_cols) {
927
+ switch (type) {
928
+ case GGML_TYPE_F32: return n_cols * sizeof(float);
929
+ case GGML_TYPE_F16: return n_cols * sizeof(uint16_t);
930
+ case GGML_TYPE_Q4_0: return (n_cols / 32) * 18;
931
+ case GGML_TYPE_Q4_1: return (n_cols / 32) * 20;
932
+ case GGML_TYPE_Q5_0: return (n_cols / 32) * 22;
933
+ case GGML_TYPE_Q5_1: return (n_cols / 32) * 24;
934
+ case GGML_TYPE_Q8_0: return (n_cols / 32) * 34;
935
+ case GGML_TYPE_Q8_1: return (n_cols / 32) * 40;
936
+ case GGML_TYPE_Q2_K: return (n_cols / 256) * 84;
937
+ case GGML_TYPE_Q3_K: return (n_cols / 256) * 110;
938
+ case GGML_TYPE_Q4_K: return (n_cols / 256) * 144;
939
+ case GGML_TYPE_Q5_K: return (n_cols / 256) * 176;
940
+ case GGML_TYPE_Q6_K: return (n_cols / 256) * 210;
941
+ case GGML_TYPE_Q8_K: return (n_cols / 256) * 292;
942
+ default: return 0;
671
943
  }
672
- return out;
673
944
  }
674
945
 
675
946
  /* ------------------------------------------------------------------------- */
@@ -685,7 +956,7 @@ static int skip_value(uint8_t **p, uint8_t *end, uint32_t type) {
685
956
  case 9: {
686
957
  uint32_t subtype = rd32(p, end);
687
958
  uint64_t n = rd64(p, end);
688
- for (uint64_t i = 0; i < n; i++)
959
+ for (uint64_t i = 0; i < n && i < 1000000; i++)
689
960
  if (!skip_value(p, end, subtype)) return 0;
690
961
  return 1;
691
962
  }
@@ -711,13 +982,8 @@ static void free_model_contents(EmbedModel *m) {
711
982
  }
712
983
  free(m->table);
713
984
  }
714
- if (m->float_data) free(m->float_data);
715
985
  if (m->mapped) munmap(m->mapped, m->mapped_size);
716
986
  bpe_merge_table_free(&m->merges);
717
- if (m->pre_patterns) {
718
- for (int i = 0; i < m->num_pre_patterns; i++) free(m->pre_patterns[i].pattern);
719
- free(m->pre_patterns);
720
- }
721
987
  free(m);
722
988
  }
723
989
 
@@ -741,35 +1007,29 @@ static uint8_t *find_tensor_info_start(uint8_t *cur, uint8_t *end) {
741
1007
 
742
1008
  /* ------------------------------------------------------------------------- */
743
1009
  static void detect_space_marker(EmbedModel *m) {
744
- const char *candidates[] = {"▁", "Ġ", " "};
745
- for (int i = 0; i < 3; i++) {
746
- const char *marker = candidates[i];
747
- int marker_len = strlen(marker);
748
- for (int j = 0; j < m->vocab_size; j++) {
749
- if (strncmp(m->tokens[j], marker, marker_len) == 0) {
750
- strcpy(m->space_marker, marker);
751
- return;
1010
+ int marker_count[4] = {0};
1011
+ const char *markers[] = {"▁", "Ġ", "ĉ", " "};
1012
+ int marker_lens[] = {3, 2, 2, 1};
1013
+
1014
+ for (int i = 0; i < m->vocab_size && i < 5000; i++) {
1015
+ for (int j = 0; j < 3; j++) {
1016
+ if (strncmp(m->tokens[i], markers[j], marker_lens[j]) == 0) {
1017
+ marker_count[j]++;
752
1018
  }
753
1019
  }
1020
+ if (m->tokens[i][0] == ' ' && strlen(m->tokens[i]) > 1) {
1021
+ marker_count[3]++;
1022
+ }
1023
+ }
1024
+
1025
+ int best = 0;
1026
+ for (int i = 1; i < 4; i++) {
1027
+ if (marker_count[i] > marker_count[best]) best = i;
754
1028
  }
755
- m->space_marker[0] = '\0';
756
- }
757
1029
 
758
- static void setup_default_pre_patterns(EmbedModel *m) {
759
- const char *default_patterns[] = {
760
- "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])",
761
- "[^\\r\\n\\p{L}\\p{N}]?\\p{L}+",
762
- "\\p{N}{1,3}",
763
- " ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*",
764
- "\\s*[\\r\\n]+",
765
- "\\s+(?!\\S)",
766
- "\\s+"
767
- };
768
- m->num_pre_patterns = sizeof(default_patterns)/sizeof(default_patterns[0]);
769
- m->pre_patterns = malloc(m->num_pre_patterns * sizeof(RegexPattern));
770
- for (int i = 0; i < m->num_pre_patterns; i++) {
771
- m->pre_patterns[i].pattern = strdup(default_patterns[i]);
772
- m->pre_patterns[i].pattern_len = strlen(default_patterns[i]);
1030
+ if (marker_count[best] > 10) {
1031
+ strcpy(m->space_marker, markers[best]);
1032
+ m->space_marker_len = marker_lens[best];
773
1033
  }
774
1034
  }
775
1035
 
@@ -777,10 +1037,10 @@ static void parse_merge(const char *merge_str, char **left, char **right) {
777
1037
  const char *space = strchr(merge_str, ' ');
778
1038
  if (space) {
779
1039
  int left_len = space - merge_str;
780
- *left = malloc(left_len+1);
1040
+ *left = malloc(left_len + 1);
781
1041
  memcpy(*left, merge_str, left_len);
782
1042
  (*left)[left_len] = '\0';
783
- *right = strdup(space+1);
1043
+ *right = strdup(space + 1);
784
1044
  } else {
785
1045
  *left = strdup(merge_str);
786
1046
  *right = strdup("");
@@ -793,12 +1053,15 @@ static EmbedModel *embed_load_gguf(const char *path) {
793
1053
  uint8_t *base = map_file(path, &sz);
794
1054
  if (!base) return NULL;
795
1055
  uint8_t *cur = base, *end = base + sz;
796
- if (memcmp(cur, "GGUF", 4) != 0) { munmap(base, sz); return NULL; }
1056
+ if (sz < 4 || memcmp(cur, "GGUF", 4) != 0) { munmap(base, sz); return NULL; }
797
1057
  cur += 4;
798
1058
  uint32_t version = rd32(&cur, end);
1059
+ (void)version;
799
1060
  uint64_t n_tensors = rd64(&cur, end);
800
1061
  uint64_t n_kv = rd64(&cur, end);
801
1062
 
1063
+ if (n_kv > 1000000 || n_tensors > 1000000) { munmap(base, sz); return NULL; }
1064
+
802
1065
  EmbedModel *m = calloc(1, sizeof(*m));
803
1066
  if (!m) { munmap(base, sz); return NULL; }
804
1067
  m->mapped = base;
@@ -806,22 +1069,22 @@ static EmbedModel *embed_load_gguf(const char *path) {
806
1069
  m->table = calloc(HASH_SIZE, sizeof(HashNode*));
807
1070
  if (!m->table) { free_model_contents(m); return NULL; }
808
1071
  bpe_merge_table_init(&m->merges);
809
- setup_default_pre_patterns(m);
810
1072
  m->unknown_token_id = -1;
811
1073
  m->bos_token_id = -1;
812
1074
  m->eos_token_id = -1;
813
1075
  m->vocab_type = LLAMA_VOCAB_TYPE_NONE;
814
- m->space_marker[0] = '\0';
1076
+ m->normalize = NORM_NONE;
815
1077
 
816
1078
  int vocab_found = 0;
817
1079
  for (uint64_t i = 0; i < n_kv; i++) {
818
1080
  char *key = rdstr(&cur, end);
819
1081
  if (!key) { free_model_contents(m); return NULL; }
820
1082
  uint32_t type = rd32(&cur, end);
1083
+
821
1084
  if ((strcmp(key, "tokenizer.ggml.tokens") == 0 || strcmp(key, "tokenizer.ggml.token_list") == 0) && type == 9) {
822
1085
  uint32_t subtype = rd32(&cur, end);
823
1086
  uint64_t n = rd64(&cur, end);
824
- if (subtype != 8) { free(key); free_model_contents(m); return NULL; }
1087
+ if (subtype != 8 || n > 1000000) { free(key); free_model_contents(m); return NULL; }
825
1088
  m->tokens = malloc(sizeof(char*) * n);
826
1089
  if (!m->tokens) { free(key); free_model_contents(m); return NULL; }
827
1090
  m->vocab_size = (int)n;
@@ -841,8 +1104,17 @@ static EmbedModel *embed_load_gguf(const char *path) {
841
1104
  if (merge_str) {
842
1105
  char *left, *right;
843
1106
  parse_merge(merge_str, &left, &right);
844
- bpe_merge_table_add(&m->merges, left, right, merge_str, (int)j);
845
- free(left); free(right);
1107
+ bpe_merge_table_add(&m->merges, left, right, (int)j);
1108
+ free(left);
1109
+ free(right);
1110
+ free(merge_str);
1111
+ } else {
1112
+ break;
1113
+ }
1114
+ }
1115
+ if (n > MAX_MERGES) {
1116
+ for (uint64_t j = MAX_MERGES; j < n; j++) {
1117
+ char *merge_str = rdstr(&cur, end);
846
1118
  free(merge_str);
847
1119
  }
848
1120
  }
@@ -852,24 +1124,32 @@ static EmbedModel *embed_load_gguf(const char *path) {
852
1124
  } else if (strcmp(key, "tokenizer.ggml.model") == 0 && type == 8) {
853
1125
  char *model_type = rdstr(&cur, end);
854
1126
  if (model_type) {
855
- if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0) m->vocab_type = LLAMA_VOCAB_TYPE_BPE;
856
- else if (strcmp(model_type, "bert") == 0) m->vocab_type = LLAMA_VOCAB_TYPE_WPM;
1127
+ if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0 ||
1128
+ strcmp(model_type, "phi") == 0 || strcmp(model_type, "qwen") == 0)
1129
+ m->vocab_type = LLAMA_VOCAB_TYPE_BPE;
1130
+ else if (strcmp(model_type, "bert") == 0)
1131
+ m->vocab_type = LLAMA_VOCAB_TYPE_WPM;
1132
+ else if (strcmp(model_type, "spm") == 0)
1133
+ m->vocab_type = LLAMA_VOCAB_TYPE_SPM;
857
1134
  free(model_type);
858
1135
  }
859
1136
  } else if (strcmp(key, "tokenizer.ggml.pre") == 0 && type == 8) {
860
1137
  char *pre = rdstr(&cur, end);
861
- if (pre) free(pre);
1138
+ free(pre);
862
1139
  } else if (strcmp(key, "tokenizer.ggml.unknown_token_id") == 0 && type == 6) {
863
- m->unknown_token_id = rd32(&cur, end);
1140
+ m->unknown_token_id = (int)rd32(&cur, end);
864
1141
  } else if (strcmp(key, "tokenizer.ggml.bos_token_id") == 0 && type == 6) {
865
- m->bos_token_id = rd32(&cur, end);
1142
+ m->bos_token_id = (int)rd32(&cur, end);
866
1143
  } else if (strcmp(key, "tokenizer.ggml.eos_token_id") == 0 && type == 6) {
867
- m->eos_token_id = rd32(&cur, end);
1144
+ m->eos_token_id = (int)rd32(&cur, end);
1145
+ } else if (strcmp(key, "general.alignment") == 0 && type == 6) {
1146
+ rd32(&cur, end);
868
1147
  } else {
869
1148
  if (!skip_value(&cur, end, type)) { free(key); free_model_contents(m); return NULL; }
870
1149
  }
871
1150
  free(key);
872
1151
  }
1152
+
873
1153
  if (!vocab_found) { free_model_contents(m); return NULL; }
874
1154
  detect_space_marker(m);
875
1155
 
@@ -877,10 +1157,6 @@ static EmbedModel *embed_load_gguf(const char *path) {
877
1157
  align_to_32(&cur, end, base);
878
1158
  uint8_t *tensor_start = cur;
879
1159
  int embd_found = 0;
880
- void *raw_tensor_data = NULL;
881
- int tensor_type = -1;
882
- uint64_t dim0 = 0, dim1 = 0;
883
- int need_transpose = 0;
884
1160
 
885
1161
  for (int attempt = 0; attempt < 2; attempt++) {
886
1162
  cur = tensor_start;
@@ -892,21 +1168,58 @@ static EmbedModel *embed_load_gguf(const char *path) {
892
1168
  for (uint32_t d = 0; d < n_dims && d < MAX_DIMS; d++) dims[d] = rd64(&cur, end);
893
1169
  uint32_t type = rd32(&cur, end);
894
1170
  uint64_t offset = rd64(&cur, end);
1171
+
895
1172
  int is_token_embd = (strcmp(name, "token_embd.weight") == 0 ||
896
1173
  strcmp(name, "embeddings.word_embeddings.weight") == 0 ||
897
1174
  strcmp(name, "model.embed_tokens.weight") == 0);
1175
+
898
1176
  if (!is_token_embd && n_dims == 2 && m->vocab_size > 0) {
899
1177
  if ((uint64_t)m->vocab_size == dims[0] && strstr(name, "embd")) is_token_embd = 1;
900
1178
  else if ((uint64_t)m->vocab_size == dims[1] && strstr(name, "embd")) is_token_embd = 1;
901
1179
  }
1180
+
902
1181
  if (!embd_found && is_token_embd) {
903
- if (n_dims < 2 || dims[1] == 0) { free(name); free_model_contents(m); return NULL; }
904
- dim0 = dims[0]; dim1 = dims[1];
905
- if (dim0 == (uint64_t)m->vocab_size) { m->dim = (int)dim1; need_transpose = 0; }
906
- else if (dim1 == (uint64_t)m->vocab_size) { m->dim = (int)dim0; need_transpose = 1; }
907
- else { m->dim = (dim0 < dim1) ? (int)dim0 : (int)dim1; need_transpose = (dim0 > dim1) ? 1 : 0; }
908
- raw_tensor_data = base + offset;
909
- tensor_type = type;
1182
+ if (n_dims < 2 || dims[1] == 0) {
1183
+ free(name); free_model_contents(m); return NULL;
1184
+ }
1185
+
1186
+ uint64_t ne0 = dims[0];
1187
+ uint64_t ne1 = dims[1];
1188
+
1189
+ int need_transpose = 0;
1190
+ int dim;
1191
+
1192
+ if (ne1 == (uint64_t)m->vocab_size) {
1193
+ dim = (int)ne0;
1194
+ need_transpose = 0;
1195
+ } else if (ne0 == (uint64_t)m->vocab_size) {
1196
+ dim = (int)ne1;
1197
+ need_transpose = 1;
1198
+ } else {
1199
+ dim = (ne0 < ne1) ? (int)ne0 : (int)ne1;
1200
+ need_transpose = (ne0 > ne1) ? 1 : 0;
1201
+ }
1202
+
1203
+ if (dim <= 0 || dim > MAX_DIM) {
1204
+ free(name); free_model_contents(m); return NULL;
1205
+ }
1206
+
1207
+ size_t row_bytes = get_row_bytes(type, (int)(need_transpose ? ne1 : ne0));
1208
+ size_t total_size = (size_t)(need_transpose ? ne1 : ne0) * row_bytes;
1209
+
1210
+ if (offset >= sz || offset + total_size > sz) {
1211
+ free(name);
1212
+ free_model_contents(m);
1213
+ return NULL;
1214
+ }
1215
+
1216
+ m->dim = dim;
1217
+ m->raw_dim0 = ne0;
1218
+ m->raw_dim1 = ne1;
1219
+ m->need_transpose = need_transpose;
1220
+ m->raw_tensor_data = base + offset;
1221
+ m->tensor_type = type;
1222
+ m->row_bytes = row_bytes;
910
1223
  embd_found = 1;
911
1224
  free(name);
912
1225
  break;
@@ -919,97 +1232,105 @@ static EmbedModel *embed_load_gguf(const char *path) {
919
1232
  if (!tensor_start) break;
920
1233
  }
921
1234
  }
922
- if (!embd_found || m->dim == 0) { free_model_contents(m); return NULL; }
923
1235
 
924
- if (tensor_type == GGML_TYPE_F32 && !need_transpose) {
925
- m->float_data = NULL;
926
- m->tensor_data = raw_tensor_data;
927
- } else {
928
- int n_rows = need_transpose ? (int)dim1 : (int)dim0;
929
- int n_cols = need_transpose ? (int)dim0 : (int)dim1;
930
- m->float_data = dequantize_tensor(raw_tensor_data, tensor_type, n_rows, n_cols);
931
- if (!m->float_data) { free_model_contents(m); return NULL; }
932
- m->tensor_data = m->float_data;
1236
+ if (!embd_found || m->dim == 0) {
1237
+ free_model_contents(m); return NULL;
933
1238
  }
934
- m->tensor_type = tensor_type;
935
1239
 
936
1240
  return m;
937
1241
  }
938
1242
 
1243
+ /* ------------------------------------------------------------------------- */
1244
+ // L2 normalization
1245
+ static void normalize_l2(float *vec, int dim) {
1246
+ float sum = 0;
1247
+ for (int i = 0; i < dim; i++) sum += vec[i] * vec[i];
1248
+ float norm = sqrtf(sum);
1249
+ if (norm > 1e-8f) {
1250
+ float inv = 1.0f / norm;
1251
+ for (int i = 0; i < dim; i++) vec[i] *= inv;
1252
+ }
1253
+ }
1254
+
939
1255
  /* ------------------------------------------------------------------------- */
940
1256
  static void embed_text(EmbedModel *m, const char *txt, float *out) {
941
1257
  memset(out, 0, sizeof(float) * m->dim);
1258
+ if (!txt || !*txt) return;
1259
+
942
1260
  int num_words = 0;
943
- char **words = unicode_regex_split(txt, m->pre_patterns, m->num_pre_patterns, &num_words);
1261
+ char **words = pre_tokenize(txt, &num_words);
1262
+
944
1263
  if (!words || num_words == 0) {
945
- // Fallback to simple space split
946
- char *copy = strdup(txt);
947
- if (copy) {
948
- char *tok = strtok(copy, " \t\n\r");
949
- int used = 0;
950
- const float *embd = (float*)m->tensor_data;
951
- while (tok) {
952
- int id = hget(m, tok);
953
- if (id >= 0 && id < m->vocab_size) {
954
- const float *vec = embd + id * m->dim;
955
- for (int i = 0; i < m->dim; i++) out[i] += vec[i];
956
- used++;
957
- }
958
- tok = strtok(NULL, " \t\n\r");
959
- }
960
- if (used) { float inv = 1.0f / used; for (int i = 0; i < m->dim; i++) out[i] *= inv; }
961
- free(copy);
962
- }
963
1264
  if (words) free(words);
964
1265
  return;
965
1266
  }
966
1267
 
967
1268
  int *token_ids = malloc(m->vocab_size * sizeof(int));
1269
+ if (!token_ids) {
1270
+ for (int i = 0; i < num_words; i++) free(words[i]);
1271
+ free(words);
1272
+ return;
1273
+ }
1274
+
968
1275
  int used = 0;
969
- const float *embd = (float*)m->tensor_data;
1276
+ float *temp_vec = malloc(m->dim * sizeof(float));
1277
+
970
1278
  for (int i = 0; i < num_words; i++) {
971
1279
  char *word = words[i];
972
1280
  int id = hget(m, word);
973
- if (id == -1 && m->space_marker[0]) {
974
- char *with_marker = malloc(strlen(m->space_marker) + strlen(word) + 1);
975
- strcpy(with_marker, m->space_marker);
976
- strcat(with_marker, word);
977
- id = hget(m, with_marker);
978
- free(with_marker);
1281
+
1282
+ if (id == -1 && m->space_marker_len > 0) {
1283
+ size_t with_marker_len = m->space_marker_len + strlen(word);
1284
+ char *with_marker = malloc(with_marker_len + 1);
1285
+ if (with_marker) {
1286
+ memcpy(with_marker, m->space_marker, m->space_marker_len);
1287
+ strcpy(with_marker + m->space_marker_len, word);
1288
+ id = hget(m, with_marker);
1289
+ free(with_marker);
1290
+ }
979
1291
  }
980
- if (id != -1) {
981
- const float *vec = embd + id * m->dim;
982
- for (int j = 0; j < m->dim; j++) out[j] += vec[j];
1292
+
1293
+ if (id != -1 && id >= 0 && id < m->vocab_size) {
1294
+ dequantize_row_lazy(m, id, temp_vec);
1295
+ for (int j = 0; j < m->dim; j++) out[j] += temp_vec[j];
983
1296
  used++;
984
1297
  } else {
985
1298
  int num_tokens = 0;
986
- bpe_tokenize_word(&m->merges, word, text_to_id, m, token_ids, &num_tokens);
1299
+ bpe_tokenize_word(&m->merges, word, m, token_ids, &num_tokens);
987
1300
  for (int k = 0; k < num_tokens; k++) {
988
1301
  int tid = token_ids[k];
989
1302
  if (tid >= 0 && tid < m->vocab_size) {
990
- const float *vec = embd + tid * m->dim;
991
- for (int j = 0; j < m->dim; j++) out[j] += vec[j];
1303
+ dequantize_row_lazy(m, tid, temp_vec);
1304
+ for (int j = 0; j < m->dim; j++) out[j] += temp_vec[j];
992
1305
  used++;
993
1306
  } else if (m->unknown_token_id != -1 && m->unknown_token_id < m->vocab_size) {
994
- const float *vec = embd + m->unknown_token_id * m->dim;
995
- for (int j = 0; j < m->dim; j++) out[j] += vec[j];
1307
+ dequantize_row_lazy(m, m->unknown_token_id, temp_vec);
1308
+ for (int j = 0; j < m->dim; j++) out[j] += temp_vec[j];
996
1309
  used++;
997
1310
  }
998
1311
  }
999
1312
  }
1000
1313
  free(word);
1001
1314
  }
1315
+
1002
1316
  free(words);
1003
1317
  free(token_ids);
1318
+ free(temp_vec);
1319
+
1004
1320
  if (used > 0) {
1005
1321
  float inv = 1.0f / used;
1006
1322
  for (int i = 0; i < m->dim; i++) out[i] *= inv;
1007
1323
  }
1324
+
1008
1325
  for (int i = 0; i < m->dim; i++) {
1009
1326
  if (isnan(out[i]) || isinf(out[i])) {
1010
1327
  out[i] = 0.0f;
1011
1328
  }
1012
1329
  }
1330
+
1331
+ if (m->normalize == NORM_L2) {
1332
+ normalize_l2(out, m->dim);
1333
+ }
1013
1334
  }
1014
1335
 
1015
1336
  /* ------------------------------------------------------------------------- */
@@ -1020,7 +1341,14 @@ static void rb_embedder_free(void *p) {
1020
1341
  }
1021
1342
 
1022
1343
  static size_t rb_embedder_memsize(const void *p) {
1023
- return sizeof(ruby_embedder);
1344
+ const ruby_embedder *e = p;
1345
+ size_t sz = sizeof(ruby_embedder);
1346
+ if (e && e->model) {
1347
+ sz += e->model->vocab_size * sizeof(char*);
1348
+ sz += e->model->mapped_size;
1349
+ sz += HASH_SIZE * sizeof(HashNode*);
1350
+ }
1351
+ return sz;
1024
1352
  }
1025
1353
 
1026
1354
  static const rb_data_type_t ruby_embedder_type = {
@@ -1037,18 +1365,44 @@ static VALUE rb_embedder_alloc(VALUE klass) {
1037
1365
  static VALUE rb_embedder_initialize(VALUE self, VALUE opts) {
1038
1366
  ruby_embedder *e;
1039
1367
  TypedData_Get_Struct(self, ruby_embedder, &ruby_embedder_type, e);
1368
+
1369
+ Check_Type(opts, T_HASH);
1040
1370
  VALUE path = rb_hash_aref(opts, ID2SYM(rb_intern("model")));
1371
+ if (NIL_P(path)) rb_raise(rb_eArgError, "missing required key: model");
1041
1372
  const char *cpath = StringValueCStr(path);
1373
+
1374
+ VALUE normalize = rb_hash_aref(opts, ID2SYM(rb_intern("normalize")));
1375
+ int norm_type = NORM_NONE;
1376
+ if (!NIL_P(normalize)) {
1377
+ if (SYMBOL_P(normalize)) {
1378
+ ID sym_id = SYM2ID(normalize);
1379
+ if (sym_id == rb_intern("l2") || sym_id == rb_intern("L2")) {
1380
+ norm_type = NORM_L2;
1381
+ }
1382
+ } else if (TYPE(normalize) == T_STRING) {
1383
+ const char *norm_str = StringValueCStr(normalize);
1384
+ if (strcasecmp(norm_str, "l2") == 0) {
1385
+ norm_type = NORM_L2;
1386
+ }
1387
+ }
1388
+ }
1389
+
1042
1390
  e->model = embed_load_gguf(cpath);
1043
- if (!e->model) rb_raise(rb_eRuntimeError, "failed to load GGUF model");
1391
+ if (!e->model) rb_raise(rb_eRuntimeError, "failed to load GGUF model: %s", cpath);
1392
+
1393
+ e->model->normalize = norm_type;
1044
1394
  return self;
1045
1395
  }
1046
1396
 
1047
1397
  static VALUE rb_embed(VALUE self, VALUE opts) {
1048
1398
  ruby_embedder *e;
1049
1399
  TypedData_Get_Struct(self, ruby_embedder, &ruby_embedder_type, e);
1400
+
1401
+ Check_Type(opts, T_HASH);
1050
1402
  VALUE text = rb_hash_aref(opts, ID2SYM(rb_intern("text")));
1403
+ if (NIL_P(text)) rb_raise(rb_eArgError, "missing required key: text");
1051
1404
  const char *ctext = StringValueCStr(text);
1405
+
1052
1406
  VALUE out = rb_str_new(NULL, e->model->dim * sizeof(float));
1053
1407
  embed_text(e->model, ctext, (float*)RSTRING_PTR(out));
1054
1408
  return out;