mini_embed 0.1.1 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,13 +8,18 @@
8
8
  #include <fcntl.h>
9
9
  #include <unistd.h>
10
10
  #include <ctype.h>
11
+ #include <limits.h>
11
12
  #include "ruby.h"
12
13
 
13
- #define HASH_SIZE 131071
14
- #define MAX_DIMS 4
15
- #define GGUF_ALIGN 32
16
- #define MAX_MERGES 10000
17
- #define MAX_REGEX 256
14
+ #define HASH_SIZE 131071
15
+ #define MAX_DIMS 4
16
+ #define GGUF_ALIGN 32
17
+ #define MAX_MERGES 100000
18
+ #define MERGE_HASH_SIZE 65537
19
+ #define QK8_0 32
20
+ #define QK_K 256
21
+ #define K_SCALE_SIZE 12
22
+ #define MAX_DIM 16384
18
23
 
19
24
  enum ggml_type {
20
25
  GGML_TYPE_F32 = 0,
@@ -40,18 +45,22 @@ enum llama_vocab_type {
40
45
  LLAMA_VOCAB_TYPE_WPM = 3,
41
46
  };
42
47
 
48
+ enum normalize_type {
49
+ NORM_NONE = 0,
50
+ NORM_L2 = 1,
51
+ };
52
+
43
53
  /* ------------------------------------------------------------------------- */
44
- // Unicode helper functions (adapted from llama.cpp)
54
+ // Unicode helper functions
45
55
  static int unicode_len_utf8(char c) {
46
56
  if ((c & 0x80) == 0) return 1;
47
57
  if ((c & 0xE0) == 0xC0) return 2;
48
58
  if ((c & 0xF0) == 0xE0) return 3;
49
59
  if ((c & 0xF8) == 0xF0) return 4;
50
- return 1; // fallback
60
+ return 1;
51
61
  }
52
62
 
53
63
  static int unicode_is_letter(uint32_t cp) {
54
- // Basic Unicode letter detection (simplified)
55
64
  return (cp >= 0x41 && cp <= 0x5A) || (cp >= 0x61 && cp <= 0x7A) ||
56
65
  (cp >= 0xC0 && cp <= 0xD6) || (cp >= 0xD8 && cp <= 0xF6) ||
57
66
  (cp >= 0xF8 && cp <= 0x2FF) || (cp >= 0x370 && cp <= 0x37D) ||
@@ -68,224 +77,243 @@ static int unicode_is_number(uint32_t cp) {
68
77
  }
69
78
 
70
79
  static uint32_t unicode_cpt_from_utf8(const char *s, size_t *len) {
71
- uint32_t cp = 0;
72
80
  unsigned char c = (unsigned char)s[0];
73
-
74
- if (c < 0x80) {
75
- *len = 1;
76
- return c;
77
- } else if ((c & 0xE0) == 0xC0) {
78
- *len = 2;
79
- cp = (c & 0x1F) << 6;
80
- cp |= (s[1] & 0x3F);
81
- return cp;
82
- } else if ((c & 0xF0) == 0xE0) {
83
- *len = 3;
84
- cp = (c & 0x0F) << 12;
85
- cp |= (s[1] & 0x3F) << 6;
86
- cp |= (s[2] & 0x3F);
87
- return cp;
88
- } else if ((c & 0xF8) == 0xF0) {
89
- *len = 4;
90
- cp = (c & 0x07) << 18;
91
- cp |= (s[1] & 0x3F) << 12;
92
- cp |= (s[2] & 0x3F) << 6;
93
- cp |= (s[3] & 0x3F);
94
- return cp;
95
- }
96
-
81
+ if (c < 0x80) { *len = 1; return c; }
82
+ if ((c & 0xE0) == 0xC0) { *len = 2; return ((c & 0x1F) << 6) | (s[1] & 0x3F); }
83
+ if ((c & 0xF0) == 0xE0) { *len = 3; return ((c & 0x0F) << 12) | ((s[1] & 0x3F) << 6) | (s[2] & 0x3F); }
84
+ if ((c & 0xF8) == 0xF0) { *len = 4; return ((c & 0x07) << 18) | ((s[1] & 0x3F) << 12) | ((s[2] & 0x3F) << 6) | (s[3] & 0x3F); }
97
85
  *len = 1;
98
86
  return c;
99
87
  }
100
88
 
101
89
  /* ------------------------------------------------------------------------- */
102
- // Simple regex pattern matcher for pre-tokenization
103
- typedef struct {
104
- char *pattern;
105
- int pattern_len;
106
- } RegexPattern;
107
-
108
- static int match_regex(const char *text, const RegexPattern *patterns, int num_patterns) {
109
- // Simplified implementation for common BPE patterns
110
- // Full regex engine would be complex; this handles the most common cases
111
-
112
- for (int i = 0; i < num_patterns; i++) {
113
- const char *p = patterns[i].pattern;
114
- int plen = patterns[i].pattern_len;
115
-
116
- // Check for common patterns
117
- if (strstr(p, "\\p{L}")) {
118
- // Match Unicode letter
119
- size_t len;
120
- uint32_t cp = unicode_cpt_from_utf8(text, &len);
121
- if (unicode_is_letter(cp)) return 1;
122
- } else if (strstr(p, "\\p{N}")) {
123
- // Match Unicode number
124
- size_t len;
125
- uint32_t cp = unicode_cpt_from_utf8(text, &len);
126
- if (unicode_is_number(cp)) return 1;
127
- } else if (p[0] == '\\' && p[1] == 's') {
128
- // Match whitespace
129
- if (isspace(text[0])) return 1;
130
- } else if (p[0] == '\\' && p[1] == 'r') {
131
- if (text[0] == '\r') return 1;
132
- } else if (p[0] == '\\' && p[1] == 'n') {
133
- if (text[0] == '\n') return 1;
134
- } else if (p[0] == '.' && p[1] == '*') {
135
- // Match anything
136
- return 1;
137
- } else if (isalnum(p[0]) || ispunct(p[0])) {
138
- // Match literal character
139
- if (text[0] == p[0]) return 1;
140
- }
90
+ // Pre-tokenizer (GPT-2/Llama style, replaces broken regex)
91
+ #define CHAR_CLASS_SPACE 0
92
+ #define CHAR_CLASS_LETTER 1
93
+ #define CHAR_CLASS_NUMBER 2
94
+ #define CHAR_CLASS_NEWLINE 3
95
+ #define CHAR_CLASS_OTHER 4
96
+
97
+ static int get_char_class(uint32_t cp) {
98
+ if (unicode_is_letter(cp)) return CHAR_CLASS_LETTER;
99
+ if (unicode_is_number(cp)) return CHAR_CLASS_NUMBER;
100
+ if (cp == '\n' || cp == '\r') return CHAR_CLASS_NEWLINE;
101
+ if (cp == ' ' || cp == '\t') return CHAR_CLASS_SPACE;
102
+ return CHAR_CLASS_OTHER;
103
+ }
104
+
105
+ static int is_contraction(const char *text, size_t pos, size_t text_len) {
106
+ if (pos >= text_len) return 0;
107
+ unsigned char c = (unsigned char)text[pos];
108
+ if (c != '\'' && c != 0xE2) return 0;
109
+ if (c == 0xE2 && pos + 2 < text_len && text[pos+1] == 0x80 && (text[pos+2] == 0x99 || text[pos+2] == 0x98)) {
110
+ if (pos + 3 >= text_len) return 0;
111
+ char next = tolower((unsigned char)text[pos + 3]);
112
+ return next == 's' || next == 't' || next == 'r' || next == 'v' ||
113
+ next == 'm' || next == 'l' || next == 'd';
114
+ }
115
+ if (c == '\'' && pos + 1 < text_len) {
116
+ char next = tolower((unsigned char)text[pos + 1]);
117
+ return next == 's' || next == 't' || next == 'r' || next == 'v' ||
118
+ next == 'm' || next == 'l' || next == 'd';
141
119
  }
142
120
  return 0;
143
121
  }
144
122
 
145
- static char** unicode_regex_split(const char *text, const RegexPattern *patterns, int num_patterns, int *num_words) {
123
+ static size_t contraction_len(const char *text, size_t pos) {
124
+ unsigned char c = (unsigned char)text[pos];
125
+ if (c == '\'') return 2;
126
+ return 4;
127
+ }
128
+
129
+ static char** pre_tokenize(const char *text, int *num_words) {
146
130
  char **words = NULL;
147
- int word_count = 0;
148
- int word_capacity = 0;
149
-
131
+ int word_count = 0, word_capacity = 0;
150
132
  size_t text_len = strlen(text);
151
- size_t pos = 0;
152
-
153
- while (pos < text_len) {
154
- // Find the start of a word (character that matches any regex)
155
- size_t start = pos;
156
- while (start < text_len) {
157
- if (match_regex(text + start, patterns, num_patterns)) {
158
- break;
159
- }
160
- start++;
133
+
134
+ if (text_len == 0) {
135
+ *num_words = 0;
136
+ return NULL;
137
+ }
138
+
139
+ #define ADD_WORD(ptr, len) do { \
140
+ char *w = malloc((len) + 1); \
141
+ if (!w) goto error; \
142
+ memcpy(w, ptr, len); \
143
+ w[len] = '\0'; \
144
+ if (word_count >= word_capacity) { \
145
+ word_capacity = word_capacity ? word_capacity * 2 : 16; \
146
+ char **nw = realloc(words, word_capacity * sizeof(char*)); \
147
+ if (!nw) { free(w); goto error; } \
148
+ words = nw; \
149
+ } \
150
+ words[word_count++] = w; \
151
+ } while(0)
152
+
153
+ size_t i = 0;
154
+ while (i < text_len) {
155
+ size_t char_len;
156
+ uint32_t cp = unicode_cpt_from_utf8(text + i, &char_len);
157
+ int cls = get_char_class(cp);
158
+
159
+ if (cls == CHAR_CLASS_NEWLINE) {
160
+ ADD_WORD(text + i, char_len);
161
+ i += char_len;
162
+ continue;
161
163
  }
162
-
163
- if (start >= text_len) break;
164
-
165
- // Find the end of the word (character that doesn't match any regex)
166
- size_t end = start;
167
- while (end < text_len) {
168
- if (!match_regex(text + end, patterns, num_patterns)) {
169
- break;
164
+
165
+ if (cls == CHAR_CLASS_SPACE) {
166
+ size_t space_start = i;
167
+ while (i < text_len) {
168
+ size_t cl;
169
+ uint32_t c = unicode_cpt_from_utf8(text + i, &cl);
170
+ int cc = get_char_class(c);
171
+ if (cc != CHAR_CLASS_SPACE) break;
172
+ i += cl;
170
173
  }
171
- end++;
174
+ if (i >= text_len) break;
175
+ size_t space_len = i - space_start;
176
+ ADD_WORD(text + space_start, space_len);
177
+ continue;
172
178
  }
173
-
174
- if (end > start) {
175
- // Extract the word
176
- size_t word_len = end - start;
177
- char *word = malloc(word_len + 1);
178
- if (word) {
179
- memcpy(word, text + start, word_len);
180
- word[word_len] = '\0';
181
-
182
- // Add to array
183
- if (word_count >= word_capacity) {
184
- word_capacity = word_capacity == 0 ? 16 : word_capacity * 2;
185
- words = realloc(words, word_capacity * sizeof(char*));
186
- if (!words) {
187
- for (int i = 0; i < word_count; i++) free(words[i]);
188
- free(words);
189
- *num_words = 0;
190
- return NULL;
191
- }
179
+
180
+ size_t start = i;
181
+ i += char_len;
182
+
183
+ while (i < text_len) {
184
+ size_t cl;
185
+ uint32_t c = unicode_cpt_from_utf8(text + i, &cl);
186
+ int ccls = get_char_class(c);
187
+
188
+ if (is_contraction(text, i, text_len)) {
189
+ size_t clen = contraction_len(text, i);
190
+ i += clen;
191
+ continue;
192
+ }
193
+
194
+ if (ccls != cls) break;
195
+ if (cls == CHAR_CLASS_NUMBER) {
196
+ int digits = 0;
197
+ size_t check = start;
198
+ while (check < i) {
199
+ size_t dl;
200
+ uint32_t dc = unicode_cpt_from_utf8(text + check, &dl);
201
+ if (get_char_class(dc) == CHAR_CLASS_NUMBER) digits++;
202
+ check += dl;
192
203
  }
193
- words[word_count++] = word;
204
+ if (digits >= 3) break;
194
205
  }
206
+ i += cl;
195
207
  }
196
-
197
- pos = end;
208
+
209
+ ADD_WORD(text + start, i - start);
198
210
  }
199
-
211
+
212
+ #undef ADD_WORD
200
213
  *num_words = word_count;
201
214
  return words;
215
+
216
+ error:
217
+ for (int j = 0; j < word_count; j++) free(words[j]);
218
+ free(words);
219
+ *num_words = 0;
220
+ return NULL;
202
221
  }
203
222
 
204
223
  /* ------------------------------------------------------------------------- */
205
- // BPE merge structure
206
- typedef struct {
224
+ // BPE merge structures with hash table for O(1) lookup
225
+ typedef struct MergeHashNode {
207
226
  char *left;
208
227
  char *right;
209
- char *merged;
210
228
  int rank;
211
- } BPEMerge;
229
+ struct MergeHashNode *next;
230
+ } MergeHashNode;
212
231
 
213
232
  typedef struct {
214
- BPEMerge *merges;
233
+ MergeHashNode **table;
234
+ int table_size;
215
235
  int num_merges;
216
- int capacity;
217
236
  } BPEMergeTable;
218
237
 
238
+ static uint64_t merge_hash(const char *left, const char *right) {
239
+ uint64_t h = 0xcbf29ce484222325ULL;
240
+ while (*left) { h ^= (uint64_t)(unsigned char)*left++; h *= 0x100000001b3ULL; }
241
+ h ^= (uint64_t)' ';
242
+ h *= 0x100000001b3ULL;
243
+ while (*right) { h ^= (uint64_t)(unsigned char)*right++; h *= 0x100000001b3ULL; }
244
+ return h;
245
+ }
246
+
219
247
  static void bpe_merge_table_init(BPEMergeTable *table) {
220
- table->merges = NULL;
248
+ table->table_size = MERGE_HASH_SIZE;
249
+ table->table = calloc(MERGE_HASH_SIZE, sizeof(MergeHashNode*));
221
250
  table->num_merges = 0;
222
- table->capacity = 0;
223
251
  }
224
252
 
225
- static void bpe_merge_table_add(BPEMergeTable *table, const char *left, const char *right, const char *merged, int rank) {
226
- if (table->num_merges >= table->capacity) {
227
- table->capacity = table->capacity == 0 ? 100 : table->capacity * 2;
228
- table->merges = realloc(table->merges, table->capacity * sizeof(BPEMerge));
229
- }
230
-
231
- BPEMerge *merge = &table->merges[table->num_merges++];
232
- merge->left = strdup(left);
233
- merge->right = strdup(right);
234
- merge->merged = strdup(merged);
235
- merge->rank = rank;
253
+ static void bpe_merge_table_add(BPEMergeTable *table, const char *left, const char *right, int rank) {
254
+ uint64_t h = merge_hash(left, right) % table->table_size;
255
+ MergeHashNode *n = malloc(sizeof(MergeHashNode));
256
+ if (!n) return;
257
+ n->left = strdup(left);
258
+ n->right = strdup(right);
259
+ n->rank = rank;
260
+ n->next = table->table[h];
261
+ table->table[h] = n;
262
+ table->num_merges++;
236
263
  }
237
264
 
238
265
  static void bpe_merge_table_free(BPEMergeTable *table) {
239
- for (int i = 0; i < table->num_merges; i++) {
240
- free(table->merges[i].left);
241
- free(table->merges[i].right);
242
- free(table->merges[i].merged);
266
+ if (!table->table) return;
267
+ for (int i = 0; i < table->table_size; i++) {
268
+ MergeHashNode *n = table->table[i];
269
+ while (n) {
270
+ MergeHashNode *next = n->next;
271
+ free(n->left);
272
+ free(n->right);
273
+ free(n);
274
+ n = next;
275
+ }
243
276
  }
244
- free(table->merges);
245
- table->merges = NULL;
246
- table->num_merges = 0;
277
+ free(table->table);
278
+ table->table = NULL;
247
279
  }
248
280
 
249
281
  static int bpe_merge_rank(const BPEMergeTable *table, const char *left, const char *right) {
250
- for (int i = 0; i < table->num_merges; i++) {
251
- if (strcmp(table->merges[i].left, left) == 0 && strcmp(table->merges[i].right, right) == 0) {
252
- return table->merges[i].rank;
253
- }
282
+ uint64_t h = merge_hash(left, right) % table->table_size;
283
+ MergeHashNode *n = table->table[h];
284
+ while (n) {
285
+ if (strcmp(n->left, left) == 0 && strcmp(n->right, right) == 0)
286
+ return n->rank;
287
+ n = n->next;
254
288
  }
255
289
  return -1;
256
290
  }
257
291
 
258
- static char* bpe_merge(const BPEMergeTable *table, const char *left, const char *right) {
259
- for (int i = 0; i < table->num_merges; i++) {
260
- if (strcmp(table->merges[i].left, left) == 0 && strcmp(table->merges[i].right, right) == 0) {
261
- return table->merges[i].merged;
262
- }
263
- }
264
- return NULL;
265
- }
266
-
267
292
  /* ------------------------------------------------------------------------- */
268
- // BPE tokenization helper structures
293
+ // BPE tokenization (correct iterative algorithm)
269
294
  typedef struct {
270
- char *text;
271
- int start;
272
- int end;
273
- int prev;
274
- int next;
295
+ const char *text;
296
+ int start, end;
297
+ int prev, next;
275
298
  int used;
276
299
  } BPESymbol;
277
300
 
278
- static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int (*text_to_id)(void*, const char*), void *vocab_data, int *token_ids, int *num_tokens) {
279
- // Initialize symbols from characters
301
+ static int text_to_id(void *vocab_data, const char *text);
302
+
303
+ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word,
304
+ void *vocab_data, int *token_ids, int *num_tokens) {
280
305
  int word_len = strlen(word);
306
+ if (word_len == 0) return;
307
+
281
308
  int num_symbols = 0;
282
309
  BPESymbol *symbols = malloc(word_len * sizeof(BPESymbol));
283
-
284
- // Split into UTF-8 characters
310
+ if (!symbols) return;
311
+
285
312
  int offset = 0;
286
313
  while (offset < word_len) {
287
314
  int char_len = unicode_len_utf8(word[offset]);
288
- symbols[num_symbols].text = (char*)word + offset;
315
+ if (offset + char_len > word_len) char_len = word_len - offset;
316
+ symbols[num_symbols].text = word;
289
317
  symbols[num_symbols].start = offset;
290
318
  symbols[num_symbols].end = offset + char_len;
291
319
  symbols[num_symbols].prev = num_symbols - 1;
@@ -294,110 +322,75 @@ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int
294
322
  offset += char_len;
295
323
  num_symbols++;
296
324
  }
297
-
325
+
326
+ if (num_symbols > 0) symbols[num_symbols - 1].next = -1;
327
+
298
328
  if (num_symbols <= 1) {
299
- // Single character, just tokenize it
300
329
  int id = text_to_id(vocab_data, word);
301
- if (id != -1) {
302
- token_ids[*num_tokens] = id;
303
- (*num_tokens)++;
304
- }
330
+ if (id != -1) token_ids[(*num_tokens)++] = id;
305
331
  free(symbols);
306
332
  return;
307
333
  }
308
-
309
- // Build priority queue for merges (simplified)
310
- typedef struct {
311
- int left;
312
- int right;
313
- int rank;
314
- } Bigram;
315
-
316
- Bigram *bigrams = malloc(word_len * word_len * sizeof(Bigram));
317
- int num_bigrams = 0;
318
-
319
- // Initialize bigrams
320
- for (int i = 0; i < num_symbols - 1; i++) {
321
- if (symbols[i].used && symbols[i+1].used) {
322
- // Get the concatenated string for this pair
323
- char *left_str = malloc(symbols[i].end - symbols[i].start + 1);
324
- char *right_str = malloc(symbols[i+1].end - symbols[i+1].start + 1);
325
- memcpy(left_str, symbols[i].text, symbols[i].end - symbols[i].start);
326
- memcpy(right_str, symbols[i+1].text, symbols[i+1].end - symbols[i+1].start);
327
- left_str[symbols[i].end - symbols[i].start] = '\0';
328
- right_str[symbols[i+1].end - symbols[i+1].start] = '\0';
329
-
330
- int rank = bpe_merge_rank(merges, left_str, right_str);
331
- if (rank != -1) {
332
- bigrams[num_bigrams].left = i;
333
- bigrams[num_bigrams].right = i+1;
334
- bigrams[num_bigrams].rank = rank;
335
- num_bigrams++;
336
- }
337
-
338
- free(left_str);
339
- free(right_str);
340
- }
341
- }
342
-
343
- // Sort bigrams by rank (lower rank = higher priority)
344
- for (int i = 0; i < num_bigrams - 1; i++) {
345
- for (int j = i+1; j < num_bigrams; j++) {
346
- if (bigrams[i].rank > bigrams[j].rank) {
347
- Bigram temp = bigrams[i];
348
- bigrams[i] = bigrams[j];
349
- bigrams[j] = temp;
334
+
335
+ while (1) {
336
+ int best_rank = INT_MAX;
337
+ int best_idx = -1;
338
+
339
+ int idx = 0;
340
+ while (idx != -1) {
341
+ int next = symbols[idx].next;
342
+ if (next != -1 && symbols[idx].used && symbols[next].used) {
343
+ int left_len = symbols[idx].end - symbols[idx].start;
344
+ int right_len = symbols[next].end - symbols[next].start;
345
+ char *left_str = malloc(left_len + 1);
346
+ char *right_str = malloc(right_len + 1);
347
+ if (left_str && right_str) {
348
+ memcpy(left_str, word + symbols[idx].start, left_len);
349
+ left_str[left_len] = '\0';
350
+ memcpy(right_str, word + symbols[next].start, right_len);
351
+ right_str[right_len] = '\0';
352
+ int rank = bpe_merge_rank(merges, left_str, right_str);
353
+ if (rank != -1 && rank < best_rank) {
354
+ best_rank = rank;
355
+ best_idx = idx;
356
+ }
357
+ }
358
+ free(left_str);
359
+ free(right_str);
350
360
  }
361
+ idx = symbols[idx].next;
351
362
  }
352
- }
353
-
354
- // Apply merges
355
- int *merged = calloc(num_symbols, sizeof(int));
356
- for (int i = 0; i < num_bigrams; i++) {
357
- int left = bigrams[i].left;
358
- int right = bigrams[i].right;
359
-
360
- if (merged[left] || merged[right]) continue;
361
-
362
- // Merge right into left
363
- symbols[left].end = symbols[right].end;
364
- symbols[left].next = symbols[right].next;
365
- merged[right] = 1;
366
-
367
- // Update next symbol's prev
368
- if (symbols[right].next < num_symbols) {
369
- symbols[symbols[right].next].prev = left;
363
+
364
+ if (best_idx == -1) break;
365
+
366
+ int right_idx = symbols[best_idx].next;
367
+ symbols[best_idx].end = symbols[right_idx].end;
368
+ symbols[best_idx].next = symbols[right_idx].next;
369
+ symbols[right_idx].used = 0;
370
+
371
+ if (symbols[right_idx].next != -1) {
372
+ symbols[symbols[right_idx].next].prev = best_idx;
370
373
  }
371
374
  }
372
-
373
- // Collect final tokens
375
+
374
376
  for (int i = 0; i < num_symbols; i++) {
375
- if (!merged[i] && symbols[i].used) {
376
- // Extract the substring
377
- char *substr = malloc(symbols[i].end - symbols[i].start + 1);
378
- memcpy(substr, word + symbols[i].start, symbols[i].end - symbols[i].start);
379
- substr[symbols[i].end - symbols[i].start] = '\0';
380
-
381
- int id = text_to_id(vocab_data, substr);
382
- if (id != -1) {
383
- token_ids[*num_tokens] = id;
384
- (*num_tokens)++;
385
- } else {
386
- // Unknown token - use byte-level fallback
387
- // For simplicity, we'll use space as a placeholder
388
- // In a full implementation, you'd encode bytes individually
377
+ if (symbols[i].used) {
378
+ int len = symbols[i].end - symbols[i].start;
379
+ char *substr = malloc(len + 1);
380
+ if (substr) {
381
+ memcpy(substr, word + symbols[i].start, len);
382
+ substr[len] = '\0';
383
+ int id = text_to_id(vocab_data, substr);
384
+ if (id != -1) token_ids[(*num_tokens)++] = id;
385
+ free(substr);
389
386
  }
390
-
391
- free(substr);
392
387
  }
393
388
  }
394
-
395
- free(bigrams);
396
- free(merged);
397
389
  free(symbols);
398
390
  }
399
391
 
400
392
  /* ------------------------------------------------------------------------- */
393
+ // GGUF parsing
401
394
  static int safe_advance(uint8_t **p, uint8_t *end, size_t sz) {
402
395
  if (*p + sz > end) return 0;
403
396
  *p += sz;
@@ -405,14 +398,14 @@ static int safe_advance(uint8_t **p, uint8_t *end, size_t sz) {
405
398
  }
406
399
 
407
400
  static uint32_t rd32(uint8_t **p, uint8_t *end) {
408
- uint32_t v = 0;
401
+ uint32_t v;
409
402
  if (!safe_advance(p, end, 4)) return 0;
410
403
  memcpy(&v, *p - 4, 4);
411
404
  return v;
412
405
  }
413
406
 
414
407
  static uint64_t rd64(uint8_t **p, uint8_t *end) {
415
- uint64_t v = 0;
408
+ uint64_t v;
416
409
  if (!safe_advance(p, end, 8)) return 0;
417
410
  memcpy(&v, *p - 8, 8);
418
411
  return v;
@@ -423,9 +416,9 @@ static char *rdstr(uint8_t **p, uint8_t *end) {
423
416
  uint64_t len;
424
417
  memcpy(&len, *p, 8);
425
418
  *p += 8;
426
- if (len == 0 || len > (1 << 20)) return NULL;
419
+ if (len == 0 || len > (1<<20)) return NULL;
427
420
  if (*p + len > end) return NULL;
428
- char *s = malloc(len + 1);
421
+ char *s = malloc(len+1);
429
422
  if (!s) return NULL;
430
423
  memcpy(s, *p, len);
431
424
  s[len] = '\0';
@@ -436,52 +429,56 @@ static char *rdstr(uint8_t **p, uint8_t *end) {
436
429
  static void align_to_32(uint8_t **p, uint8_t *end, uint8_t *base) {
437
430
  size_t off = *p - base;
438
431
  size_t aligned = (off + GGUF_ALIGN - 1) & ~(GGUF_ALIGN - 1);
439
- if (base + aligned <= end)
440
- *p = base + aligned;
432
+ if (base + aligned <= end) *p = base + aligned;
441
433
  }
442
434
 
443
435
  /* ------------------------------------------------------------------------- */
436
+ // Hash table for vocabulary
444
437
  typedef struct HashNode {
445
438
  char *key;
446
- int id;
439
+ int id;
447
440
  struct HashNode *next;
448
441
  } HashNode;
449
442
 
450
443
  typedef struct {
451
- int vocab_size;
452
- int dim;
453
- char **tokens;
454
- float *float_data;
455
- void *tensor_data;
456
- int tensor_type;
457
- void *mapped;
458
- size_t mapped_size;
444
+ int vocab_size;
445
+ int dim;
446
+ char **tokens;
447
+ void *mapped;
448
+ size_t mapped_size;
459
449
  HashNode **table;
460
-
461
- // BPE tokenization data
462
450
  BPEMergeTable merges;
463
- RegexPattern *pre_patterns;
464
- int num_pre_patterns;
465
451
  int unknown_token_id;
466
452
  int bos_token_id;
467
453
  int eos_token_id;
468
454
  int vocab_type;
455
+ char space_marker[8];
456
+ int space_marker_len;
457
+ const void *raw_tensor_data;
458
+ int tensor_type;
459
+ size_t row_bytes;
460
+ int need_transpose;
461
+ uint64_t raw_dim0, raw_dim1;
462
+ int normalize;
469
463
  } EmbedModel;
470
464
 
471
465
  typedef struct {
472
466
  EmbedModel *model;
473
467
  } ruby_embedder;
474
468
 
475
- static unsigned long hash(const char *s) {
476
- unsigned long h = 5381;
477
- int c;
478
- while ((c = *s++)) h = ((h << 5) + h) + c;
469
+ static uint64_t vocab_hash(const char *s) {
470
+ uint64_t h = 0xcbf29ce484222325ULL;
471
+ while (*s) {
472
+ h ^= (uint64_t)(unsigned char)*s++;
473
+ h *= 0x100000001b3ULL;
474
+ }
479
475
  return h % HASH_SIZE;
480
476
  }
481
477
 
482
478
  static void hset(EmbedModel *m, char *k, int id) {
483
- unsigned long h = hash(k);
479
+ uint64_t h = vocab_hash(k);
484
480
  HashNode *n = malloc(sizeof(*n));
481
+ if (!n) return;
485
482
  n->key = k;
486
483
  n->id = id;
487
484
  n->next = m->table[h];
@@ -489,7 +486,8 @@ static void hset(EmbedModel *m, char *k, int id) {
489
486
  }
490
487
 
491
488
  static int hget(EmbedModel *m, const char *k) {
492
- HashNode *n = m->table[hash(k)];
489
+ if (!k || !m->table) return -1;
490
+ HashNode *n = m->table[vocab_hash(k)];
493
491
  while (n) {
494
492
  if (strcmp(n->key, k) == 0) return n->id;
495
493
  n = n->next;
@@ -498,48 +496,62 @@ static int hget(EmbedModel *m, const char *k) {
498
496
  }
499
497
 
500
498
  static int text_to_id(void *vocab_data, const char *text) {
501
- EmbedModel *m = (EmbedModel*)vocab_data;
502
- return hget(m, text);
499
+ return hget((EmbedModel*)vocab_data, text);
503
500
  }
504
501
 
505
502
  /* ------------------------------------------------------------------------- */
503
+ // File mapping
506
504
  static void *map_file(const char *path, size_t *size) {
507
505
  int fd = open(path, O_RDONLY);
508
506
  if (fd < 0) return NULL;
509
507
  struct stat st;
510
508
  if (fstat(fd, &st) != 0) { close(fd); return NULL; }
511
509
  *size = st.st_size;
510
+ if (*size == 0) { close(fd); return NULL; }
512
511
  void *data = mmap(NULL, *size, PROT_READ, MAP_PRIVATE, fd, 0);
513
512
  close(fd);
514
- if (data == MAP_FAILED) return NULL;
515
- return data;
513
+ return data == MAP_FAILED ? NULL : data;
516
514
  }
517
515
 
518
516
  /* ------------------------------------------------------------------------- */
517
+ // FP16 conversion (corrected)
519
518
  static float fp16_to_fp32(uint16_t h) {
520
- const uint16_t sign = (h >> 15) & 1;
521
- const uint16_t exp = (h >> 10) & 0x1F;
522
- const uint16_t mant = h & 0x3FF;
523
- float val;
519
+ const uint32_t sign = (h >> 15) & 1;
520
+ const uint32_t exp = (h >> 10) & 0x1F;
521
+ const uint32_t mant = h & 0x3FF;
522
+
523
+ uint32_t f;
524
524
  if (exp == 0) {
525
- val = (mant / 1024.0f) * 6.103515625e-5f;
525
+ if (mant == 0) {
526
+ f = sign << 31;
527
+ } else {
528
+ uint32_t e = 0;
529
+ uint32_t m = mant;
530
+ while (!(m & 0x400)) { m <<= 1; e++; }
531
+ f = (sign << 31) | ((127 - 15 - e + 1) << 23) | ((m & 0x3FF) << 13);
532
+ }
526
533
  } else if (exp == 31) {
527
- return 0.0f;
534
+ f = (sign << 31) | (0xFF << 23) | (mant << 13);
528
535
  } else {
529
- val = (1.0f + mant / 1024.0f) * (1 << (exp - 15));
536
+ f = (sign << 31) | ((exp + 127 - 15) << 23) | (mant << 13);
530
537
  }
531
- return sign ? -val : val;
538
+
539
+ float result;
540
+ memcpy(&result, &f, sizeof(result));
541
+ return result;
532
542
  }
533
543
 
534
544
  /* ------------------------------------------------------------------------- */
535
- /* Block dequantization */
536
-
545
+ // Block dequantization functions (correct sizes)
537
546
  static void dequantize_row_q4_0(const void *vx, float *y, int k) {
538
- const int nb = k / 32;
547
+ const int nb = k / QK8_0;
539
548
  const uint8_t *x = vx;
540
549
  for (int i = 0; i < nb; i++) {
541
- const float d = ((const float*)(x + i*34))[0];
542
- const uint8_t *q = x + i*34 + 4;
550
+ const uint8_t *block = x + i * 18;
551
+ uint16_t d16;
552
+ memcpy(&d16, block, 2);
553
+ const float d = fp16_to_fp32(d16);
554
+ const uint8_t *q = block + 2;
543
555
  for (int j = 0; j < 32; j++) {
544
556
  const int v = (q[j/2] >> (4*(j%2))) & 0x0F;
545
557
  y[i*32 + j] = (v - 8.0f) * d;
@@ -548,12 +560,16 @@ static void dequantize_row_q4_0(const void *vx, float *y, int k) {
548
560
  }
549
561
 
550
562
  static void dequantize_row_q4_1(const void *vx, float *y, int k) {
551
- const int nb = k / 32;
563
+ const int nb = k / QK8_0;
552
564
  const uint8_t *x = vx;
553
565
  for (int i = 0; i < nb; i++) {
554
- const float d = ((const float*)(x + i*36))[0];
555
- const float m = ((const float*)(x + i*36))[1];
556
- const uint8_t *q = x + i*36 + 8;
566
+ const uint8_t *block = x + i * 20;
567
+ uint16_t d16, m16;
568
+ memcpy(&d16, block, 2);
569
+ memcpy(&m16, block + 2, 2);
570
+ const float d = fp16_to_fp32(d16);
571
+ const float m = fp16_to_fp32(m16);
572
+ const uint8_t *q = block + 4;
557
573
  for (int j = 0; j < 32; j++) {
558
574
  const int v = (q[j/2] >> (4*(j%2))) & 0x0F;
559
575
  y[i*32 + j] = v * d + m;
@@ -562,14 +578,16 @@ static void dequantize_row_q4_1(const void *vx, float *y, int k) {
562
578
  }
563
579
 
564
580
  static void dequantize_row_q5_0(const void *vx, float *y, int k) {
565
- const int nb = k / 32;
581
+ const int nb = k / QK8_0;
566
582
  const uint8_t *x = vx;
567
583
  for (int i = 0; i < nb; i++) {
568
- const float d = ((const float*)(x + i*40))[0];
569
- const uint8_t *qh = x + i*40 + 4;
570
- const uint8_t *ql = x + i*40 + 8;
584
+ const uint8_t *block = x + i * 22;
585
+ uint16_t d16;
586
+ memcpy(&d16, block, 2);
587
+ const float d = fp16_to_fp32(d16);
571
588
  uint32_t qh32;
572
- memcpy(&qh32, qh, 4);
589
+ memcpy(&qh32, block + 2, 4);
590
+ const uint8_t *ql = block + 6;
573
591
  for (int j = 0; j < 32; j++) {
574
592
  const uint8_t vh = (qh32 >> j) & 1;
575
593
  const int v = ((ql[j/2] >> (4*(j%2))) & 0x0F) | (vh << 4);
@@ -579,15 +597,18 @@ static void dequantize_row_q5_0(const void *vx, float *y, int k) {
579
597
  }
580
598
 
581
599
  static void dequantize_row_q5_1(const void *vx, float *y, int k) {
582
- const int nb = k / 32;
600
+ const int nb = k / QK8_0;
583
601
  const uint8_t *x = vx;
584
602
  for (int i = 0; i < nb; i++) {
585
- const float d = ((const float*)(x + i*44))[0];
586
- const float m = ((const float*)(x + i*44))[1];
587
- const uint8_t *qh = x + i*44 + 8;
588
- const uint8_t *ql = x + i*44 + 12;
603
+ const uint8_t *block = x + i * 24;
604
+ uint16_t d16, m16;
605
+ memcpy(&d16, block, 2);
606
+ memcpy(&m16, block + 2, 2);
607
+ const float d = fp16_to_fp32(d16);
608
+ const float m = fp16_to_fp32(m16);
589
609
  uint32_t qh32;
590
- memcpy(&qh32, qh, 4);
610
+ memcpy(&qh32, block + 4, 4);
611
+ const uint8_t *ql = block + 8;
591
612
  for (int j = 0; j < 32; j++) {
592
613
  const uint8_t vh = (qh32 >> j) & 1;
593
614
  const int v = ((ql[j/2] >> (4*(j%2))) & 0x0F) | (vh << 4);
@@ -597,11 +618,13 @@ static void dequantize_row_q5_1(const void *vx, float *y, int k) {
597
618
  }
598
619
 
599
620
  static void dequantize_row_q8_0(const void *vx, float *y, int k) {
600
- const int nb = k / 32;
621
+ const int nb = k / QK8_0;
601
622
  const uint8_t *x = vx;
602
623
  for (int i = 0; i < nb; i++) {
603
- const float d = ((const float*)(x + i*36))[0];
604
- const int8_t *q = (const int8_t*)(x + i*36 + 4);
624
+ const uint8_t *block = x + i * 34;
625
+ float d;
626
+ memcpy(&d, block, 4);
627
+ const int8_t *q = (const int8_t*)(block + 4);
605
628
  for (int j = 0; j < 32; j++) {
606
629
  y[i*32 + j] = (float)q[j] * d;
607
630
  }
@@ -609,191 +632,315 @@ static void dequantize_row_q8_0(const void *vx, float *y, int k) {
609
632
  }
610
633
 
611
634
  static void dequantize_row_q8_1(const void *vx, float *y, int k) {
612
- const int nb = k / 32;
635
+ const int nb = k / QK8_0;
613
636
  const uint8_t *x = vx;
614
637
  for (int i = 0; i < nb; i++) {
615
- const float d = ((const float*)(x + i*40))[0];
616
- const float s = ((const float*)(x + i*40))[1];
617
- const int8_t *q = (const int8_t*)(x + i*40 + 8);
638
+ const uint8_t *block = x + i * 40;
639
+ float d, s;
640
+ memcpy(&d, block, 4);
641
+ memcpy(&s, block + 4, 4);
642
+ const int8_t *q = (const int8_t*)(block + 8);
618
643
  for (int j = 0; j < 32; j++) {
619
644
  y[i*32 + j] = (float)q[j] * d + s;
620
645
  }
621
646
  }
622
647
  }
623
648
 
624
- /* K-quants */
649
+ // K-quant scale helpers
650
+ static inline void get_scale_min_k4(int j, const uint8_t *q, uint8_t *d, uint8_t *m) {
651
+ if (j < 4) {
652
+ *d = q[j] & 63;
653
+ *m = q[j + 4] & 63;
654
+ } else {
655
+ *d = (q[j+4] & 0xF) | ((q[j-3] >> 6) << 4);
656
+ *m = (q[j+4] >> 4) | ((q[j-1] >> 6) << 4);
657
+ }
658
+ }
659
+
625
660
  static void dequantize_row_q2_K(const void *vx, float *y, int k) {
626
- const int nb = k / 256;
661
+ const int nb = k / QK_K;
627
662
  const uint8_t *x = vx;
628
663
  for (int i = 0; i < nb; i++) {
629
- const float d = ((const float*)(x + i*336))[0];
630
- const float m = ((const float*)(x + i*336))[1];
631
- const uint8_t *q = x + i*336 + 8;
632
- const uint8_t *scales = q + 64;
633
- for (int j = 0; j < 256; j += 32) {
634
- const uint8_t ls = scales[j/32] & 0xF;
635
- const uint8_t ms = scales[j/32] >> 4;
636
- for (int l = 0; l < 32; l++) {
664
+ const uint8_t *block = x + i * 84;
665
+ uint16_t d16, dmin16;
666
+ memcpy(&d16, block, 2);
667
+ memcpy(&dmin16, block + 2, 2);
668
+ const float d = fp16_to_fp32(d16);
669
+ const float min = fp16_to_fp32(dmin16);
670
+ const uint8_t *scales = block + 4;
671
+ const uint8_t *q = block + 20;
672
+ for (int j = 0; j < QK_K; j += 64) {
673
+ const float dl = d * (scales[j/64] & 0xF);
674
+ const float ml = min * (scales[j/64] >> 4);
675
+ for (int l = 0; l < 64; l++) {
637
676
  const int v = (q[(j+l)/4] >> (2*((j+l)%4))) & 0x03;
638
- const float dl = d * (ls - 32);
639
- const float ml = m * (ms - 32);
640
- y[i*256 + j + l] = v * dl + ml;
677
+ y[i*QK_K + j + l] = v * dl + ml;
641
678
  }
642
679
  }
643
680
  }
644
681
  }
645
682
 
646
683
  static void dequantize_row_q3_K(const void *vx, float *y, int k) {
647
- const int nb = k / 256;
684
+ const int nb = k / QK_K;
648
685
  const uint8_t *x = vx;
649
686
  for (int i = 0; i < nb; i++) {
650
- const float d = ((const float*)(x + i*352))[0];
651
- const uint8_t *q = x + i*352 + 4;
652
- const uint8_t *scales = q + 256;
653
- const uint8_t *h = scales + 32;
654
- for (int j = 0; j < 256; j += 64) {
687
+ const uint8_t *block = x + i * 110;
688
+ uint16_t d16;
689
+ memcpy(&d16, block, 2);
690
+ const float d = fp16_to_fp32(d16);
691
+ const uint8_t *hmask = block + 2;
692
+ const uint8_t *q = block + 34;
693
+ const uint8_t *scales = block + 98;
694
+ for (int j = 0; j < QK_K; j += 64) {
655
695
  const uint8_t ls1 = scales[j/64] & 0x1F;
656
- const uint8_t ls2 = (scales[j/64] >> 4) | ((scales[j/64+1] & 0x0F) << 4);
657
- const uint8_t ms = scales[j/64+1] >> 4;
696
+ const uint8_t ls2 = (scales[j/64] >> 5) | ((scales[j/64 + 1] & 0x7) << 3);
697
+ const uint8_t ls3 = ((scales[j/64 + 1] >> 3) & 0x1F);
698
+ const uint8_t ls4 = (scales[j/64 + 1] >> 8);
658
699
  for (int l = 0; l < 64; l++) {
659
700
  int v = (q[(j+l)/2] >> (4*((j+l)%2))) & 0x0F;
660
- const int bit = (h[(j+l)/8] >> ((j+l)%8)) & 1;
701
+ const int bit = (hmask[(j+l)/8] >> ((j+l)%8)) & 1;
661
702
  v |= bit << 4;
662
- const float dl = d * (ls1 - 32);
663
- const float ml = (l < 32) ? (ls2 - 32) * d : (ms - 32) * d;
664
- y[i*256 + j + l] = v * dl + ml;
703
+ float ls;
704
+ if (l < 16) ls = ls1;
705
+ else if (l < 32) ls = ls2;
706
+ else if (l < 48) ls = ls3;
707
+ else ls = ls4;
708
+ y[i*QK_K + j + l] = (v - 32.0f) * d * ls;
665
709
  }
666
710
  }
667
711
  }
668
712
  }
669
713
 
670
714
  static void dequantize_row_q4_K(const void *vx, float *y, int k) {
671
- const int nb = k / 256;
715
+ const int nb = k / QK_K;
672
716
  const uint8_t *x = vx;
673
717
  for (int i = 0; i < nb; i++) {
674
- const float d = ((const float*)(x + i*416))[0];
675
- const float m = ((const float*)(x + i*416))[1];
676
- const uint8_t *q = x + i*416 + 8;
677
- const uint8_t *scales = q + 128;
678
- for (int j = 0; j < 256; j += 32) {
679
- const uint8_t ls = scales[j/32] & 0x3F;
680
- const uint8_t ms = scales[j/32] >> 6;
718
+ const uint8_t *block = x + i * 144;
719
+ uint16_t d16, dmin16;
720
+ memcpy(&d16, block, 2);
721
+ memcpy(&dmin16, block + 2, 2);
722
+ const float d = fp16_to_fp32(d16);
723
+ const float min = fp16_to_fp32(dmin16);
724
+ const uint8_t *scales = block + 4;
725
+ const uint8_t *q = block + 16;
726
+ int is = 0;
727
+ for (int j = 0; j < QK_K; j += 64) {
728
+ uint8_t sc, m;
729
+ get_scale_min_k4(is, scales, &sc, &m);
730
+ float d1 = d * sc;
731
+ float m1 = min * m;
732
+ get_scale_min_k4(is + 1, scales, &sc, &m);
733
+ float d2 = d * sc;
734
+ float m2 = min * m;
681
735
  for (int l = 0; l < 32; l++) {
682
- const int v = (q[(j+l)/2] >> (4*((j+l)%2))) & 0x0F;
683
- const float dl = d * (ls - 32);
684
- const float ml = m * (ms - 2);
685
- y[i*256 + j + l] = v * dl + ml;
736
+ y[i*QK_K + j + l] = d1 * (q[l] & 0xF) - m1;
686
737
  }
738
+ for (int l = 0; l < 32; l++) {
739
+ y[i*QK_K + j + 32 + l] = d2 * (q[l] >> 4) - m2;
740
+ }
741
+ q += 32;
742
+ is += 2;
687
743
  }
688
744
  }
689
745
  }
690
746
 
691
747
  static void dequantize_row_q5_K(const void *vx, float *y, int k) {
692
- const int nb = k / 256;
748
+ const int nb = k / QK_K;
693
749
  const uint8_t *x = vx;
694
750
  for (int i = 0; i < nb; i++) {
695
- const float d = ((const float*)(x + i*448))[0];
696
- const float m = ((const float*)(x + i*448))[1];
697
- const uint8_t *q = x + i*448 + 8;
698
- const uint8_t *qh = q + 128;
699
- const uint8_t *scales = qh + 32;
700
- for (int j = 0; j < 256; j += 32) {
701
- const uint8_t ls = scales[j/32] & 0x3F;
702
- const uint8_t ms = scales[j/32] >> 6;
751
+ const uint8_t *block = x + i * 176;
752
+ uint16_t d16, dmin16;
753
+ memcpy(&d16, block, 2);
754
+ memcpy(&dmin16, block + 2, 2);
755
+ const float d = fp16_to_fp32(d16);
756
+ const float min = fp16_to_fp32(dmin16);
757
+ const uint8_t *scales = block + 4;
758
+ const uint8_t *qh = block + 16;
759
+ const uint8_t *ql = block + 48;
760
+ int is = 0;
761
+ for (int j = 0; j < QK_K; j += 64) {
762
+ uint8_t sc, m;
763
+ get_scale_min_k4(is, scales, &sc, &m);
764
+ float d1 = d * sc;
765
+ float m1 = min * m;
766
+ get_scale_min_k4(is + 1, scales, &sc, &m);
767
+ float d2 = d * sc;
768
+ float m2 = min * m;
703
769
  for (int l = 0; l < 32; l++) {
704
- int v = (q[(j+l)/2] >> (4*((j+l)%2))) & 0x0F;
705
- const int bit = (qh[(j+l)/8] >> ((j+l)%8)) & 1;
706
- v |= bit << 4;
707
- const float dl = d * (ls - 32);
708
- const float ml = m * (ms - 2);
709
- y[i*256 + j + l] = v * dl + ml;
770
+ int vh = (qh[j/64 * 4 + l/8] >> (l%8)) & 1;
771
+ int v = (ql[l] & 0xF) | (vh << 4);
772
+ y[i*QK_K + j + l] = d1 * v - m1;
773
+ }
774
+ for (int l = 0; l < 32; l++) {
775
+ int vh = (qh[j/64 * 4 + 4 + l/8] >> (l%8)) & 1;
776
+ int v = (ql[l] >> 4) | (vh << 4);
777
+ y[i*QK_K + j + 32 + l] = d2 * v - m2;
710
778
  }
779
+ ql += 32;
780
+ is += 2;
711
781
  }
712
782
  }
713
783
  }
714
784
 
715
785
  static void dequantize_row_q6_K(const void *vx, float *y, int k) {
716
- const int nb = k / 256;
786
+ const int nb = k / QK_K;
717
787
  const uint8_t *x = vx;
718
788
  for (int i = 0; i < nb; i++) {
719
- const float d = ((const float*)(x + i*480))[0];
720
- const uint8_t *q = x + i*480 + 4;
721
- const uint8_t *qh = q + 256;
722
- const uint8_t *scales = qh + 64;
723
- for (int j = 0; j < 256; j += 64) {
724
- const uint8_t ls = scales[j/64];
725
- for (int l = 0; l < 64; l++) {
726
- int v = (q[(j+l)/2] >> (4*((j+l)%2))) & 0x0F;
727
- const int bit = (qh[(j+l)/4] >> (2*((j+l)%4))) & 0x03;
728
- v |= bit << 4;
729
- y[i*256 + j + l] = v * d * (ls - 32);
789
+ const uint8_t *block = x + i * 210;
790
+ const uint8_t *ql = block;
791
+ const uint8_t *qh = block + 128;
792
+ const int8_t *scales = (const int8_t*)(block + 192);
793
+ uint16_t d16;
794
+ memcpy(&d16, block + 208, 2);
795
+ const float d = fp16_to_fp32(d16);
796
+ for (int j = 0; j < QK_K; j += 128) {
797
+ for (int l = 0; l < 32; l++) {
798
+ int v = (ql[j/2 + l] & 0xF) | (((qh[j/4 + l/2] >> ((l%2)*4)) & 0xF) << 4);
799
+ y[i*QK_K + j + l] = v * d * scales[j/128 * 8 + l/4];
800
+ }
801
+ for (int l = 0; l < 32; l++) {
802
+ int v = (ql[j/2 + 32 + l] >> 4) | (((qh[j/4 + 16 + l/2] >> ((l%2)*4)) & 0xF) << 4);
803
+ y[i*QK_K + j + 32 + l] = v * d * scales[j/128 * 8 + 8 + l/4];
804
+ }
805
+ for (int l = 0; l < 32; l++) {
806
+ int v = (ql[j/2 + 64 + l] & 0xF) | (((qh[j/4 + 32 + l/2] >> ((l%2)*4)) & 0xF) << 4);
807
+ y[i*QK_K + j + 64 + l] = v * d * scales[j/128 * 8 + 4 + l/4];
808
+ }
809
+ for (int l = 0; l < 32; l++) {
810
+ int v = (ql[j/2 + 96 + l] >> 4) | (((qh[j/4 + 48 + l/2] >> ((l%2)*4)) & 0xF) << 4);
811
+ y[i*QK_K + j + 96 + l] = v * d * scales[j/128 * 8 + 12 + l/4];
730
812
  }
731
813
  }
732
814
  }
733
815
  }
734
816
 
735
817
  static void dequantize_row_q8_K(const void *vx, float *y, int k) {
736
- const int nb = k / 256;
818
+ const int nb = k / QK_K;
737
819
  const uint8_t *x = vx;
738
820
  for (int i = 0; i < nb; i++) {
739
- const float d = ((const float*)(x + i*544))[0];
740
- const int8_t *q = (const int8_t*)(x + i*544 + 4);
741
- const uint8_t *scales = (const uint8_t*)(q + 256);
742
- for (int j = 0; j < 256; j += 32) {
743
- const uint8_t ls = scales[j/32];
744
- for (int l = 0; l < 32; l++) {
745
- y[i*256 + j + l] = (float)q[j+l] * d * ls;
746
- }
821
+ const uint8_t *block = x + i * 292;
822
+ float d;
823
+ memcpy(&d, block, 4);
824
+ const int8_t *q = (const int8_t*)(block + 4);
825
+ for (int j = 0; j < QK_K; j++) {
826
+ y[i*QK_K + j] = (float)q[j] * d;
747
827
  }
748
828
  }
749
829
  }
750
830
 
751
- /* ------------------------------------------------------------------------- */
752
- static float* dequantize_tensor(const void *data, int type, int n_rows, int n_cols) {
753
- if (type == GGML_TYPE_F32) {
754
- float *out = malloc(n_rows * n_cols * sizeof(float));
755
- if (!out) return NULL;
756
- memcpy(out, data, n_rows * n_cols * sizeof(float));
757
- return out;
831
+ // Lazy single-row dequantization
832
+ static void dequantize_row_lazy(const EmbedModel *m, int row, float *out) {
833
+ if (!m->raw_tensor_data || row < 0 || row >= m->vocab_size) {
834
+ memset(out, 0, sizeof(float) * m->dim);
835
+ return;
758
836
  }
759
- if (type == GGML_TYPE_F16) {
760
- float *out = malloc(n_rows * n_cols * sizeof(float));
761
- if (!out) return NULL;
762
- const uint16_t *in = data;
763
- for (int i = 0; i < n_rows * n_cols; i++) {
764
- out[i] = fp16_to_fp32(in[i]);
837
+
838
+ const uint8_t *raw;
839
+ int effective_cols;
840
+
841
+ if (m->need_transpose) {
842
+ int src_row_size;
843
+ switch (m->tensor_type) {
844
+ case GGML_TYPE_F32: src_row_size = m->raw_dim1 * sizeof(float); break;
845
+ case GGML_TYPE_F16: src_row_size = m->raw_dim1 * sizeof(uint16_t); break;
846
+ default: {
847
+ size_t rb = 0;
848
+ int nc = (int)m->raw_dim1;
849
+ switch (m->tensor_type) {
850
+ case GGML_TYPE_Q4_0: rb = (nc / 32) * 18; break;
851
+ case GGML_TYPE_Q4_1: rb = (nc / 32) * 20; break;
852
+ case GGML_TYPE_Q5_0: rb = (nc / 32) * 22; break;
853
+ case GGML_TYPE_Q5_1: rb = (nc / 32) * 24; break;
854
+ case GGML_TYPE_Q8_0: rb = (nc / 32) * 34; break;
855
+ case GGML_TYPE_Q8_1: rb = (nc / 32) * 40; break;
856
+ case GGML_TYPE_Q2_K: rb = (nc / 256) * 84; break;
857
+ case GGML_TYPE_Q3_K: rb = (nc / 256) * 110; break;
858
+ case GGML_TYPE_Q4_K: rb = (nc / 256) * 144; break;
859
+ case GGML_TYPE_Q5_K: rb = (nc / 256) * 176; break;
860
+ case GGML_TYPE_Q6_K: rb = (nc / 256) * 210; break;
861
+ case GGML_TYPE_Q8_K: rb = (nc / 256) * 292; break;
862
+ default: src_row_size = 0; return;
863
+ }
864
+ src_row_size = (int)rb;
865
+ }
765
866
  }
766
- return out;
867
+ float *temp_row = malloc(m->raw_dim1 * sizeof(float));
868
+ if (!temp_row) return;
869
+ for (int col = 0; col < m->dim; col++) {
870
+ const uint8_t *src_row = (const uint8_t*)m->raw_tensor_data + col * src_row_size;
871
+ if (m->tensor_type == GGML_TYPE_F32) {
872
+ float val;
873
+ memcpy(&val, src_row + row * sizeof(float), sizeof(float));
874
+ out[col] = val;
875
+ } else if (m->tensor_type == GGML_TYPE_F16) {
876
+ uint16_t val;
877
+ memcpy(&val, src_row + row * sizeof(uint16_t), sizeof(uint16_t));
878
+ out[col] = fp16_to_fp32(val);
879
+ } else {
880
+ memset(out, 0, sizeof(float) * m->dim);
881
+ free(temp_row);
882
+ return;
883
+ }
884
+ }
885
+ free(temp_row);
886
+ return;
767
887
  }
768
888
 
769
- float *out = malloc(n_rows * n_cols * sizeof(float));
770
- if (!out) return NULL;
771
- const uint8_t *in = data;
772
- size_t row_bytes = 0;
773
- void (*dequant_func)(const void*, float*, int) = NULL;
774
-
775
- switch (type) {
776
- case GGML_TYPE_Q4_0: dequant_func = dequantize_row_q4_0; row_bytes = (n_cols / 32) * 34; break;
777
- case GGML_TYPE_Q4_1: dequant_func = dequantize_row_q4_1; row_bytes = (n_cols / 32) * 36; break;
778
- case GGML_TYPE_Q5_0: dequant_func = dequantize_row_q5_0; row_bytes = (n_cols / 32) * 40; break;
779
- case GGML_TYPE_Q5_1: dequant_func = dequantize_row_q5_1; row_bytes = (n_cols / 32) * 44; break;
780
- case GGML_TYPE_Q8_0: dequant_func = dequantize_row_q8_0; row_bytes = (n_cols / 32) * 36; break;
781
- case GGML_TYPE_Q8_1: dequant_func = dequantize_row_q8_1; row_bytes = (n_cols / 32) * 40; break;
782
- case GGML_TYPE_Q2_K: dequant_func = dequantize_row_q2_K; row_bytes = (n_cols / 256) * 336; break;
783
- case GGML_TYPE_Q3_K: dequant_func = dequantize_row_q3_K; row_bytes = (n_cols / 256) * 352; break;
784
- case GGML_TYPE_Q4_K: dequant_func = dequantize_row_q4_K; row_bytes = (n_cols / 256) * 416; break;
785
- case GGML_TYPE_Q5_K: dequant_func = dequantize_row_q5_K; row_bytes = (n_cols / 256) * 448; break;
786
- case GGML_TYPE_Q6_K: dequant_func = dequantize_row_q6_K; row_bytes = (n_cols / 256) * 480; break;
787
- case GGML_TYPE_Q8_K: dequant_func = dequantize_row_q8_K; row_bytes = (n_cols / 256) * 544; break;
889
+ raw = (const uint8_t*)m->raw_tensor_data + row * m->row_bytes;
890
+ effective_cols = m->dim;
891
+
892
+ switch (m->tensor_type) {
893
+ case GGML_TYPE_F32:
894
+ memcpy(out, raw, effective_cols * sizeof(float));
895
+ break;
896
+ case GGML_TYPE_F16:
897
+ for (int j = 0; j < effective_cols; j++) {
898
+ uint16_t h;
899
+ memcpy(&h, raw + j * sizeof(uint16_t), sizeof(uint16_t));
900
+ out[j] = fp16_to_fp32(h);
901
+ }
902
+ break;
903
+ case GGML_TYPE_Q4_0: dequantize_row_q4_0(raw, out, effective_cols); break;
904
+ case GGML_TYPE_Q4_1: dequantize_row_q4_1(raw, out, effective_cols); break;
905
+ case GGML_TYPE_Q5_0: dequantize_row_q5_0(raw, out, effective_cols); break;
906
+ case GGML_TYPE_Q5_1: dequantize_row_q5_1(raw, out, effective_cols); break;
907
+ case GGML_TYPE_Q8_0: dequantize_row_q8_0(raw, out, effective_cols); break;
908
+ case GGML_TYPE_Q8_1: dequantize_row_q8_1(raw, out, effective_cols); break;
909
+ case GGML_TYPE_Q2_K: dequantize_row_q2_K(raw, out, effective_cols); break;
910
+ case GGML_TYPE_Q3_K: dequantize_row_q3_K(raw, out, effective_cols); break;
911
+ case GGML_TYPE_Q4_K: dequantize_row_q4_K(raw, out, effective_cols); break;
912
+ case GGML_TYPE_Q5_K: dequantize_row_q5_K(raw, out, effective_cols); break;
913
+ case GGML_TYPE_Q6_K: dequantize_row_q6_K(raw, out, effective_cols); break;
914
+ case GGML_TYPE_Q8_K: dequantize_row_q8_K(raw, out, effective_cols); break;
788
915
  default:
789
- free(out);
790
- return NULL;
916
+ memset(out, 0, sizeof(float) * effective_cols);
791
917
  }
792
918
 
793
- for (int r = 0; r < n_rows; r++) {
794
- dequant_func(in + r * row_bytes, out + r * n_cols, n_cols);
919
+ for (int j = 0; j < effective_cols; j++) {
920
+ if (isnan(out[j]) || isinf(out[j]) || fabsf(out[j]) > 1e10f) {
921
+ out[j] = 0.0f;
922
+ }
923
+ }
924
+ }
925
+
926
+ static size_t get_row_bytes(int type, int n_cols) {
927
+ switch (type) {
928
+ case GGML_TYPE_F32: return n_cols * sizeof(float);
929
+ case GGML_TYPE_F16: return n_cols * sizeof(uint16_t);
930
+ case GGML_TYPE_Q4_0: return (n_cols / 32) * 18;
931
+ case GGML_TYPE_Q4_1: return (n_cols / 32) * 20;
932
+ case GGML_TYPE_Q5_0: return (n_cols / 32) * 22;
933
+ case GGML_TYPE_Q5_1: return (n_cols / 32) * 24;
934
+ case GGML_TYPE_Q8_0: return (n_cols / 32) * 34;
935
+ case GGML_TYPE_Q8_1: return (n_cols / 32) * 40;
936
+ case GGML_TYPE_Q2_K: return (n_cols / 256) * 84;
937
+ case GGML_TYPE_Q3_K: return (n_cols / 256) * 110;
938
+ case GGML_TYPE_Q4_K: return (n_cols / 256) * 144;
939
+ case GGML_TYPE_Q5_K: return (n_cols / 256) * 176;
940
+ case GGML_TYPE_Q6_K: return (n_cols / 256) * 210;
941
+ case GGML_TYPE_Q8_K: return (n_cols / 256) * 292;
942
+ default: return 0;
795
943
  }
796
- return out;
797
944
  }
798
945
 
799
946
  /* ------------------------------------------------------------------------- */
@@ -809,7 +956,7 @@ static int skip_value(uint8_t **p, uint8_t *end, uint32_t type) {
809
956
  case 9: {
810
957
  uint32_t subtype = rd32(p, end);
811
958
  uint64_t n = rd64(p, end);
812
- for (uint64_t i = 0; i < n; i++)
959
+ for (uint64_t i = 0; i < n && i < 1000000; i++)
813
960
  if (!skip_value(p, end, subtype)) return 0;
814
961
  return 1;
815
962
  }
@@ -835,79 +982,66 @@ static void free_model_contents(EmbedModel *m) {
835
982
  }
836
983
  free(m->table);
837
984
  }
838
- if (m->float_data) free(m->float_data);
839
985
  if (m->mapped) munmap(m->mapped, m->mapped_size);
840
-
841
- // Free BPE tokenization data
842
986
  bpe_merge_table_free(&m->merges);
843
- if (m->pre_patterns) {
844
- for (int i = 0; i < m->num_pre_patterns; i++) {
845
- free(m->pre_patterns[i].pattern);
846
- }
847
- free(m->pre_patterns);
848
- }
849
-
850
987
  free(m);
851
988
  }
852
989
 
853
990
  /* ------------------------------------------------------------------------- */
854
991
  static int is_printable_string(const char *s, size_t len) {
855
- for (size_t i = 0; i < len; i++)
856
- if (!isprint((unsigned char)s[i])) return 0;
992
+ for (size_t i = 0; i < len; i++) if (!isprint((unsigned char)s[i])) return 0;
857
993
  return 1;
858
994
  }
859
995
 
860
- /* Fallback: find the start of tensor info by scanning for a valid string */
861
996
  static uint8_t *find_tensor_info_start(uint8_t *cur, uint8_t *end) {
862
997
  uint8_t *scan = cur;
863
998
  while (scan + 8 < end) {
864
999
  uint64_t len;
865
1000
  memcpy(&len, scan, 8);
866
- if (len > 0 && len < 256 && scan + 8 + len <= end) {
867
- if (is_printable_string((char*)scan + 8, len)) {
868
- return scan;
869
- }
870
- }
1001
+ if (len > 0 && len < 256 && scan + 8 + len <= end && is_printable_string((char*)scan+8, len))
1002
+ return scan;
871
1003
  scan++;
872
1004
  }
873
1005
  return NULL;
874
1006
  }
875
1007
 
876
1008
  /* ------------------------------------------------------------------------- */
877
- static void setup_default_pre_patterns(EmbedModel *m) {
878
- // Default pre-tokenization regex patterns (similar to Llama 3)
879
- const char *default_patterns[] = {
880
- "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])",
881
- "[^\\r\\n\\p{L}\\p{N}]?\\p{L}+",
882
- "\\p{N}{1,3}",
883
- " ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*",
884
- "\\s*[\\r\\n]+",
885
- "\\s+(?!\\S)",
886
- "\\s+"
887
- };
888
-
889
- m->num_pre_patterns = sizeof(default_patterns) / sizeof(default_patterns[0]);
890
- m->pre_patterns = malloc(m->num_pre_patterns * sizeof(RegexPattern));
891
-
892
- for (int i = 0; i < m->num_pre_patterns; i++) {
893
- m->pre_patterns[i].pattern = strdup(default_patterns[i]);
894
- m->pre_patterns[i].pattern_len = strlen(default_patterns[i]);
1009
+ static void detect_space_marker(EmbedModel *m) {
1010
+ int marker_count[4] = {0};
1011
+ const char *markers[] = {"▁", "Ġ", "ĉ", " "};
1012
+ int marker_lens[] = {3, 2, 2, 1};
1013
+
1014
+ for (int i = 0; i < m->vocab_size && i < 5000; i++) {
1015
+ for (int j = 0; j < 3; j++) {
1016
+ if (strncmp(m->tokens[i], markers[j], marker_lens[j]) == 0) {
1017
+ marker_count[j]++;
1018
+ }
1019
+ }
1020
+ if (m->tokens[i][0] == ' ' && strlen(m->tokens[i]) > 1) {
1021
+ marker_count[3]++;
1022
+ }
1023
+ }
1024
+
1025
+ int best = 0;
1026
+ for (int i = 1; i < 4; i++) {
1027
+ if (marker_count[i] > marker_count[best]) best = i;
1028
+ }
1029
+
1030
+ if (marker_count[best] > 10) {
1031
+ strcpy(m->space_marker, markers[best]);
1032
+ m->space_marker_len = marker_lens[best];
895
1033
  }
896
1034
  }
897
1035
 
898
- /* ------------------------------------------------------------------------- */
899
1036
  static void parse_merge(const char *merge_str, char **left, char **right) {
900
- // Parse a merge string like "h ello" -> left="h", right="ello"
901
1037
  const char *space = strchr(merge_str, ' ');
902
1038
  if (space) {
903
1039
  int left_len = space - merge_str;
904
1040
  *left = malloc(left_len + 1);
905
1041
  memcpy(*left, merge_str, left_len);
906
1042
  (*left)[left_len] = '\0';
907
-
908
1043
  *right = strdup(space + 1);
909
1044
  } else {
910
- // No space - treat as single token
911
1045
  *left = strdup(merge_str);
912
1046
  *right = strdup("");
913
1047
  }
@@ -918,45 +1052,39 @@ static EmbedModel *embed_load_gguf(const char *path) {
918
1052
  size_t sz;
919
1053
  uint8_t *base = map_file(path, &sz);
920
1054
  if (!base) return NULL;
921
- uint8_t *cur = base;
922
- uint8_t *end = base + sz;
923
-
924
- if (memcmp(cur, "GGUF", 4) != 0) { munmap(base, sz); return NULL; }
1055
+ uint8_t *cur = base, *end = base + sz;
1056
+ if (sz < 4 || memcmp(cur, "GGUF", 4) != 0) { munmap(base, sz); return NULL; }
925
1057
  cur += 4;
926
1058
  uint32_t version = rd32(&cur, end);
927
1059
  (void)version;
928
1060
  uint64_t n_tensors = rd64(&cur, end);
929
1061
  uint64_t n_kv = rd64(&cur, end);
930
1062
 
1063
+ if (n_kv > 1000000 || n_tensors > 1000000) { munmap(base, sz); return NULL; }
1064
+
931
1065
  EmbedModel *m = calloc(1, sizeof(*m));
932
1066
  if (!m) { munmap(base, sz); return NULL; }
933
1067
  m->mapped = base;
934
1068
  m->mapped_size = sz;
935
1069
  m->table = calloc(HASH_SIZE, sizeof(HashNode*));
936
1070
  if (!m->table) { free_model_contents(m); return NULL; }
937
-
938
- // Initialize BPE structures
939
1071
  bpe_merge_table_init(&m->merges);
940
- setup_default_pre_patterns(m);
941
-
942
- // Default values
943
1072
  m->unknown_token_id = -1;
944
1073
  m->bos_token_id = -1;
945
1074
  m->eos_token_id = -1;
946
1075
  m->vocab_type = LLAMA_VOCAB_TYPE_NONE;
1076
+ m->normalize = NORM_NONE;
947
1077
 
948
- /* ---------- Metadata ---------- */
949
1078
  int vocab_found = 0;
950
1079
  for (uint64_t i = 0; i < n_kv; i++) {
951
1080
  char *key = rdstr(&cur, end);
952
1081
  if (!key) { free_model_contents(m); return NULL; }
953
1082
  uint32_t type = rd32(&cur, end);
954
1083
 
955
- if ((strcmp(key, "tokenizer.ggml.tokens") == 0 ||
956
- strcmp(key, "tokenizer.ggml.token_list") == 0) && type == 9) {
1084
+ if ((strcmp(key, "tokenizer.ggml.tokens") == 0 || strcmp(key, "tokenizer.ggml.token_list") == 0) && type == 9) {
957
1085
  uint32_t subtype = rd32(&cur, end);
958
1086
  uint64_t n = rd64(&cur, end);
959
- if (subtype != 8) { free(key); free_model_contents(m); return NULL; }
1087
+ if (subtype != 8 || n > 1000000) { free(key); free_model_contents(m); return NULL; }
960
1088
  m->tokens = malloc(sizeof(char*) * n);
961
1089
  if (!m->tokens) { free(key); free_model_contents(m); return NULL; }
962
1090
  m->vocab_size = (int)n;
@@ -971,66 +1099,64 @@ static EmbedModel *embed_load_gguf(const char *path) {
971
1099
  uint32_t subtype = rd32(&cur, end);
972
1100
  uint64_t n = rd64(&cur, end);
973
1101
  if (subtype == 8) {
974
- // Parse merges
975
1102
  for (uint64_t j = 0; j < n && j < MAX_MERGES; j++) {
976
1103
  char *merge_str = rdstr(&cur, end);
977
1104
  if (merge_str) {
978
1105
  char *left, *right;
979
1106
  parse_merge(merge_str, &left, &right);
980
- bpe_merge_table_add(&m->merges, left, right, merge_str, j);
1107
+ bpe_merge_table_add(&m->merges, left, right, (int)j);
981
1108
  free(left);
982
1109
  free(right);
983
1110
  free(merge_str);
1111
+ } else {
1112
+ break;
984
1113
  }
985
1114
  }
986
- } else {
987
- // Skip if not string array
988
- if (!skip_value(&cur, end, type)) {
989
- free(key); free_model_contents(m); return NULL;
1115
+ if (n > MAX_MERGES) {
1116
+ for (uint64_t j = MAX_MERGES; j < n; j++) {
1117
+ char *merge_str = rdstr(&cur, end);
1118
+ free(merge_str);
1119
+ }
990
1120
  }
1121
+ } else {
1122
+ if (!skip_value(&cur, end, type)) { free(key); free_model_contents(m); return NULL; }
991
1123
  }
992
1124
  } else if (strcmp(key, "tokenizer.ggml.model") == 0 && type == 8) {
993
1125
  char *model_type = rdstr(&cur, end);
994
1126
  if (model_type) {
995
- if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0) {
1127
+ if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0 ||
1128
+ strcmp(model_type, "phi") == 0 || strcmp(model_type, "qwen") == 0)
996
1129
  m->vocab_type = LLAMA_VOCAB_TYPE_BPE;
997
- } else if (strcmp(model_type, "bert") == 0) {
1130
+ else if (strcmp(model_type, "bert") == 0)
998
1131
  m->vocab_type = LLAMA_VOCAB_TYPE_WPM;
999
- }
1132
+ else if (strcmp(model_type, "spm") == 0)
1133
+ m->vocab_type = LLAMA_VOCAB_TYPE_SPM;
1000
1134
  free(model_type);
1001
1135
  }
1002
1136
  } else if (strcmp(key, "tokenizer.ggml.pre") == 0 && type == 8) {
1003
- char *pre_type = rdstr(&cur, end);
1004
- if (pre_type) {
1005
- // Could load custom regex patterns here if needed
1006
- free(pre_type);
1007
- }
1137
+ char *pre = rdstr(&cur, end);
1138
+ free(pre);
1008
1139
  } else if (strcmp(key, "tokenizer.ggml.unknown_token_id") == 0 && type == 6) {
1009
- m->unknown_token_id = rd32(&cur, end);
1140
+ m->unknown_token_id = (int)rd32(&cur, end);
1010
1141
  } else if (strcmp(key, "tokenizer.ggml.bos_token_id") == 0 && type == 6) {
1011
- m->bos_token_id = rd32(&cur, end);
1142
+ m->bos_token_id = (int)rd32(&cur, end);
1012
1143
  } else if (strcmp(key, "tokenizer.ggml.eos_token_id") == 0 && type == 6) {
1013
- m->eos_token_id = rd32(&cur, end);
1144
+ m->eos_token_id = (int)rd32(&cur, end);
1145
+ } else if (strcmp(key, "general.alignment") == 0 && type == 6) {
1146
+ rd32(&cur, end);
1014
1147
  } else {
1015
- if (!skip_value(&cur, end, type)) {
1016
- free(key); free_model_contents(m); return NULL;
1017
- }
1148
+ if (!skip_value(&cur, end, type)) { free(key); free_model_contents(m); return NULL; }
1018
1149
  }
1019
1150
  free(key);
1020
1151
  }
1021
1152
 
1022
1153
  if (!vocab_found) { free_model_contents(m); return NULL; }
1154
+ detect_space_marker(m);
1023
1155
 
1024
1156
  uint8_t *after_kv = cur;
1025
1157
  align_to_32(&cur, end, base);
1026
1158
  uint8_t *tensor_start = cur;
1027
-
1028
- /* ---------- Tensor info ---------- */
1029
1159
  int embd_found = 0;
1030
- void *raw_tensor_data = NULL;
1031
- int tensor_type = -1;
1032
- uint64_t dim0 = 0, dim1 = 0;
1033
- int need_transpose = 0;
1034
1160
 
1035
1161
  for (int attempt = 0; attempt < 2; attempt++) {
1036
1162
  cur = tensor_start;
@@ -1039,8 +1165,7 @@ static EmbedModel *embed_load_gguf(const char *path) {
1039
1165
  if (!name) break;
1040
1166
  uint32_t n_dims = rd32(&cur, end);
1041
1167
  uint64_t dims[MAX_DIMS] = {0};
1042
- for (uint32_t d = 0; d < n_dims && d < MAX_DIMS; d++)
1043
- dims[d] = rd64(&cur, end);
1168
+ for (uint32_t d = 0; d < n_dims && d < MAX_DIMS; d++) dims[d] = rd64(&cur, end);
1044
1169
  uint32_t type = rd32(&cur, end);
1045
1170
  uint64_t offset = rd64(&cur, end);
1046
1171
 
@@ -1049,29 +1174,55 @@ static EmbedModel *embed_load_gguf(const char *path) {
1049
1174
  strcmp(name, "model.embed_tokens.weight") == 0);
1050
1175
 
1051
1176
  if (!is_token_embd && n_dims == 2 && m->vocab_size > 0) {
1052
- if ((uint64_t)m->vocab_size == dims[0] && strstr(name, "embd") != NULL)
1053
- is_token_embd = 1;
1054
- else if ((uint64_t)m->vocab_size == dims[1] && strstr(name, "embd") != NULL)
1055
- is_token_embd = 1;
1177
+ if ((uint64_t)m->vocab_size == dims[0] && strstr(name, "embd")) is_token_embd = 1;
1178
+ else if ((uint64_t)m->vocab_size == dims[1] && strstr(name, "embd")) is_token_embd = 1;
1056
1179
  }
1057
1180
 
1058
1181
  if (!embd_found && is_token_embd) {
1059
- if (n_dims < 2 || dims[1] == 0) { free(name); free_model_contents(m); return NULL; }
1060
- dim0 = dims[0];
1061
- dim1 = dims[1];
1062
- if (dim0 == (uint64_t)m->vocab_size) {
1063
- m->dim = (int)dim1;
1182
+ if (n_dims < 2 || dims[1] == 0) {
1183
+ free(name); free_model_contents(m); return NULL;
1184
+ }
1185
+
1186
+ uint64_t ne0 = dims[0];
1187
+ uint64_t ne1 = dims[1];
1188
+
1189
+ int need_transpose = 0;
1190
+ int dim;
1191
+
1192
+ if (ne1 == (uint64_t)m->vocab_size) {
1193
+ dim = (int)ne0;
1064
1194
  need_transpose = 0;
1065
- } else if (dim1 == (uint64_t)m->vocab_size) {
1066
- m->dim = (int)dim0;
1195
+ } else if (ne0 == (uint64_t)m->vocab_size) {
1196
+ dim = (int)ne1;
1067
1197
  need_transpose = 1;
1068
1198
  } else {
1069
- m->dim = (dim0 < dim1) ? (int)dim0 : (int)dim1;
1070
- need_transpose = (dim0 > dim1) ? 1 : 0;
1199
+ dim = (ne0 < ne1) ? (int)ne0 : (int)ne1;
1200
+ need_transpose = (ne0 > ne1) ? 1 : 0;
1071
1201
  }
1072
- raw_tensor_data = base + offset;
1073
- tensor_type = type;
1202
+
1203
+ if (dim <= 0 || dim > MAX_DIM) {
1204
+ free(name); free_model_contents(m); return NULL;
1205
+ }
1206
+
1207
+ size_t row_bytes = get_row_bytes(type, (int)(need_transpose ? ne1 : ne0));
1208
+ size_t total_size = (size_t)(need_transpose ? ne1 : ne0) * row_bytes;
1209
+
1210
+ if (offset >= sz || offset + total_size > sz) {
1211
+ free(name);
1212
+ free_model_contents(m);
1213
+ return NULL;
1214
+ }
1215
+
1216
+ m->dim = dim;
1217
+ m->raw_dim0 = ne0;
1218
+ m->raw_dim1 = ne1;
1219
+ m->need_transpose = need_transpose;
1220
+ m->raw_tensor_data = base + offset;
1221
+ m->tensor_type = type;
1222
+ m->row_bytes = row_bytes;
1074
1223
  embd_found = 1;
1224
+ free(name);
1225
+ break;
1075
1226
  }
1076
1227
  free(name);
1077
1228
  }
@@ -1082,110 +1233,122 @@ static EmbedModel *embed_load_gguf(const char *path) {
1082
1233
  }
1083
1234
  }
1084
1235
 
1085
- if (!embd_found || m->dim == 0) {
1086
- free_model_contents(m);
1087
- return NULL;
1088
- }
1089
-
1090
- /* Dequantize */
1091
- if (tensor_type == GGML_TYPE_F32 && !need_transpose) {
1092
- m->float_data = NULL;
1093
- m->tensor_data = raw_tensor_data;
1094
- } else {
1095
- int n_rows = need_transpose ? (int)dim1 : (int)dim0;
1096
- int n_cols = need_transpose ? (int)dim0 : (int)dim1;
1097
- m->float_data = dequantize_tensor(raw_tensor_data, tensor_type, n_rows, n_cols);
1098
- if (!m->float_data) {
1099
- free_model_contents(m);
1100
- return NULL;
1101
- }
1102
- m->tensor_data = m->float_data;
1236
+ if (!embd_found || m->dim == 0) {
1237
+ free_model_contents(m); return NULL;
1103
1238
  }
1104
- m->tensor_type = tensor_type;
1105
1239
 
1106
1240
  return m;
1107
1241
  }
1108
1242
 
1243
+ /* ------------------------------------------------------------------------- */
1244
+ // L2 normalization
1245
+ static void normalize_l2(float *vec, int dim) {
1246
+ float sum = 0;
1247
+ for (int i = 0; i < dim; i++) sum += vec[i] * vec[i];
1248
+ float norm = sqrtf(sum);
1249
+ if (norm > 1e-8f) {
1250
+ float inv = 1.0f / norm;
1251
+ for (int i = 0; i < dim; i++) vec[i] *= inv;
1252
+ }
1253
+ }
1254
+
1109
1255
  /* ------------------------------------------------------------------------- */
1110
1256
  static void embed_text(EmbedModel *m, const char *txt, float *out) {
1111
1257
  memset(out, 0, sizeof(float) * m->dim);
1112
-
1113
- // Pre-tokenize using regex
1258
+ if (!txt || !*txt) return;
1259
+
1114
1260
  int num_words = 0;
1115
- char **words = unicode_regex_split(txt, m->pre_patterns, m->num_pre_patterns, &num_words);
1116
-
1261
+ char **words = pre_tokenize(txt, &num_words);
1262
+
1117
1263
  if (!words || num_words == 0) {
1118
- // Fallback to space splitting if regex fails
1119
- char *copy = strdup(txt);
1120
- if (!copy) return;
1121
-
1122
- char *tok = strtok(copy, " \t\n\r");
1123
- int used = 0;
1124
- const float *embd_matrix = m->tensor_data;
1125
-
1126
- while (tok) {
1127
- int id = hget(m, tok);
1128
- if (id >= 0 && id < m->vocab_size) {
1129
- const float *vec = embd_matrix + id * m->dim;
1130
- for (int i = 0; i < m->dim; i++) out[i] += vec[i];
1131
- used++;
1132
- }
1133
- tok = strtok(NULL, " \t\n\r");
1134
- }
1135
-
1136
- if (used > 0) {
1137
- float inv = 1.0f / used;
1138
- for (int i = 0; i < m->dim; i++) out[i] *= inv;
1139
- }
1140
- free(copy);
1264
+ if (words) free(words);
1265
+ return;
1266
+ }
1267
+
1268
+ int *token_ids = malloc(m->vocab_size * sizeof(int));
1269
+ if (!token_ids) {
1270
+ for (int i = 0; i < num_words; i++) free(words[i]);
1271
+ free(words);
1141
1272
  return;
1142
1273
  }
1143
-
1144
- // Tokenize each word using BPE
1145
- int *token_ids = malloc(m->vocab_size * sizeof(int)); // Max possible tokens
1146
- int num_tokens = 0;
1147
- const float *embd_matrix = m->tensor_data;
1274
+
1148
1275
  int used = 0;
1149
-
1276
+ float *temp_vec = malloc(m->dim * sizeof(float));
1277
+
1150
1278
  for (int i = 0; i < num_words; i++) {
1151
- num_tokens = 0;
1152
- bpe_tokenize_word(&m->merges, words[i], text_to_id, m, token_ids, &num_tokens);
1153
-
1154
- for (int j = 0; j < num_tokens; j++) {
1155
- int id = token_ids[j];
1156
- if (id >= 0 && id < m->vocab_size) {
1157
- const float *vec = embd_matrix + id * m->dim;
1158
- for (int k = 0; k < m->dim; k++) out[k] += vec[k];
1159
- used++;
1160
- } else if (m->unknown_token_id != -1 && m->unknown_token_id < m->vocab_size) {
1161
- // Use unknown token as fallback
1162
- const float *vec = embd_matrix + m->unknown_token_id * m->dim;
1163
- for (int k = 0; k < m->dim; k++) out[k] += vec[k];
1164
- used++;
1279
+ char *word = words[i];
1280
+ int id = hget(m, word);
1281
+
1282
+ if (id == -1 && m->space_marker_len > 0) {
1283
+ size_t with_marker_len = m->space_marker_len + strlen(word);
1284
+ char *with_marker = malloc(with_marker_len + 1);
1285
+ if (with_marker) {
1286
+ memcpy(with_marker, m->space_marker, m->space_marker_len);
1287
+ strcpy(with_marker + m->space_marker_len, word);
1288
+ id = hget(m, with_marker);
1289
+ free(with_marker);
1290
+ }
1291
+ }
1292
+
1293
+ if (id != -1 && id >= 0 && id < m->vocab_size) {
1294
+ dequantize_row_lazy(m, id, temp_vec);
1295
+ for (int j = 0; j < m->dim; j++) out[j] += temp_vec[j];
1296
+ used++;
1297
+ } else {
1298
+ int num_tokens = 0;
1299
+ bpe_tokenize_word(&m->merges, word, m, token_ids, &num_tokens);
1300
+ for (int k = 0; k < num_tokens; k++) {
1301
+ int tid = token_ids[k];
1302
+ if (tid >= 0 && tid < m->vocab_size) {
1303
+ dequantize_row_lazy(m, tid, temp_vec);
1304
+ for (int j = 0; j < m->dim; j++) out[j] += temp_vec[j];
1305
+ used++;
1306
+ } else if (m->unknown_token_id != -1 && m->unknown_token_id < m->vocab_size) {
1307
+ dequantize_row_lazy(m, m->unknown_token_id, temp_vec);
1308
+ for (int j = 0; j < m->dim; j++) out[j] += temp_vec[j];
1309
+ used++;
1310
+ }
1165
1311
  }
1166
1312
  }
1167
-
1168
- free(words[i]);
1313
+ free(word);
1169
1314
  }
1315
+
1170
1316
  free(words);
1171
1317
  free(token_ids);
1172
-
1318
+ free(temp_vec);
1319
+
1173
1320
  if (used > 0) {
1174
1321
  float inv = 1.0f / used;
1175
1322
  for (int i = 0; i < m->dim; i++) out[i] *= inv;
1176
1323
  }
1324
+
1325
+ for (int i = 0; i < m->dim; i++) {
1326
+ if (isnan(out[i]) || isinf(out[i])) {
1327
+ out[i] = 0.0f;
1328
+ }
1329
+ }
1330
+
1331
+ if (m->normalize == NORM_L2) {
1332
+ normalize_l2(out, m->dim);
1333
+ }
1177
1334
  }
1178
1335
 
1179
1336
  /* ------------------------------------------------------------------------- */
1337
+ // Ruby bindings
1180
1338
  static void rb_embedder_free(void *p) {
1181
1339
  ruby_embedder *e = p;
1182
- if (!e) return;
1183
- if (e->model) free_model_contents(e->model);
1184
- free(e);
1340
+ if (e) { if (e->model) free_model_contents(e->model); free(e); }
1185
1341
  }
1186
1342
 
1187
1343
  static size_t rb_embedder_memsize(const void *p) {
1188
- return sizeof(ruby_embedder);
1344
+ const ruby_embedder *e = p;
1345
+ size_t sz = sizeof(ruby_embedder);
1346
+ if (e && e->model) {
1347
+ sz += e->model->vocab_size * sizeof(char*);
1348
+ sz += e->model->mapped_size;
1349
+ sz += HASH_SIZE * sizeof(HashNode*);
1350
+ }
1351
+ return sz;
1189
1352
  }
1190
1353
 
1191
1354
  static const rb_data_type_t ruby_embedder_type = {
@@ -1203,11 +1366,31 @@ static VALUE rb_embedder_initialize(VALUE self, VALUE opts) {
1203
1366
  ruby_embedder *e;
1204
1367
  TypedData_Get_Struct(self, ruby_embedder, &ruby_embedder_type, e);
1205
1368
 
1369
+ Check_Type(opts, T_HASH);
1206
1370
  VALUE path = rb_hash_aref(opts, ID2SYM(rb_intern("model")));
1371
+ if (NIL_P(path)) rb_raise(rb_eArgError, "missing required key: model");
1207
1372
  const char *cpath = StringValueCStr(path);
1373
+
1374
+ VALUE normalize = rb_hash_aref(opts, ID2SYM(rb_intern("normalize")));
1375
+ int norm_type = NORM_NONE;
1376
+ if (!NIL_P(normalize)) {
1377
+ if (SYMBOL_P(normalize)) {
1378
+ ID sym_id = SYM2ID(normalize);
1379
+ if (sym_id == rb_intern("l2") || sym_id == rb_intern("L2")) {
1380
+ norm_type = NORM_L2;
1381
+ }
1382
+ } else if (TYPE(normalize) == T_STRING) {
1383
+ const char *norm_str = StringValueCStr(normalize);
1384
+ if (strcasecmp(norm_str, "l2") == 0) {
1385
+ norm_type = NORM_L2;
1386
+ }
1387
+ }
1388
+ }
1389
+
1208
1390
  e->model = embed_load_gguf(cpath);
1209
- if (!e->model)
1210
- rb_raise(rb_eRuntimeError, "failed to load GGUF model");
1391
+ if (!e->model) rb_raise(rb_eRuntimeError, "failed to load GGUF model: %s", cpath);
1392
+
1393
+ e->model->normalize = norm_type;
1211
1394
  return self;
1212
1395
  }
1213
1396
 
@@ -1215,7 +1398,9 @@ static VALUE rb_embed(VALUE self, VALUE opts) {
1215
1398
  ruby_embedder *e;
1216
1399
  TypedData_Get_Struct(self, ruby_embedder, &ruby_embedder_type, e);
1217
1400
 
1401
+ Check_Type(opts, T_HASH);
1218
1402
  VALUE text = rb_hash_aref(opts, ID2SYM(rb_intern("text")));
1403
+ if (NIL_P(text)) rb_raise(rb_eArgError, "missing required key: text");
1219
1404
  const char *ctext = StringValueCStr(text);
1220
1405
 
1221
1406
  VALUE out = rb_str_new(NULL, e->model->dim * sizeof(float));
@@ -1227,5 +1412,5 @@ void Init_mini_embed(void) {
1227
1412
  VALUE c = rb_define_class("MiniEmbed", rb_cObject);
1228
1413
  rb_define_alloc_func(c, rb_embedder_alloc);
1229
1414
  rb_define_method(c, "initialize", rb_embedder_initialize, 1);
1230
- rb_define_method(c, "embeddings", rb_embed, 1);
1415
+ rb_define_method(c, "embed", rb_embed, 1);
1231
1416
  }