mini_embed 0.1.1 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +9 -5
- data/ext/mini_embed/mini_embed.c +185 -354
- data/lib/mini_embed.rb +14 -0
- metadata +1 -1
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 2df9f5c081f8a7fa2447261817ebba58e5c062921a1dc6ee3ec8048fdc300022
|
|
4
|
+
data.tar.gz: 4ee4f87161506c59e6deda7dd12b819c33e86ae5f5843aa89a9754b57b27f968
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 4646b9f96a6ef525751d3046f0524d308940c1deb3a85623f775666ab2e1bbcddbec62c16f110f2cc1eae620f0b52c900a64f03cfa01f8e856d3546eb404ee98
|
|
7
|
+
data.tar.gz: f03b8103bddc296f1601d62bda851655529dc47cea1eee9f6bfcfc03fb5c005ad33f2b53a81e84d882b2813c670254bdbdb465b2b64f76c4ffc66248f5b27f73
|
data/README.md
CHANGED
|
@@ -52,15 +52,19 @@ require 'mini_embed'
|
|
|
52
52
|
# Load a GGUF model (F32, F16, Q8_0, Q4_K, etc. are all supported)
|
|
53
53
|
model = MiniEmbed.new(model: '/path/to/gte-small.Q8_0.gguf')
|
|
54
54
|
|
|
55
|
-
# Get
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
# Get an embedding as an array of floats
|
|
59
|
-
embedding = binary.unpack('e*')
|
|
55
|
+
# Get embedding as an array of floats (default)
|
|
56
|
+
embedding = model.embeddings(text: 'hello world')
|
|
60
57
|
puts embedding.size # e.g. 384
|
|
61
58
|
puts embedding[0..4] # e.g. [0.0123, -0.0456, ...]
|
|
59
|
+
|
|
60
|
+
# Or get the raw binary string (little‑endian 32‑bit floats)
|
|
61
|
+
binary = model.embeddings(text: 'hello world', type: :binary)
|
|
62
|
+
embedding_from_binary = binary.unpack('e*')
|
|
62
63
|
```
|
|
63
64
|
|
|
65
|
+
Note: The type parameter is optional – it defaults to :vector which returns a Ruby `Array<Float>`. Use `type: :binary` to get the raw binary string (compatible with the original C extension).
|
|
66
|
+
|
|
67
|
+
|
|
64
68
|
## Simple tokenization note
|
|
65
69
|
MiniEmbed uses a naive space‑based tokenizer. This means it splits input on spaces and looks up each token exactly in the model's vocabulary. For models trained with subword tokenization (like BERT), this will not work for out‑of‑vocabulary words.
|
|
66
70
|
If you need proper subword tokenization, you can:
|
data/ext/mini_embed/mini_embed.c
CHANGED
|
@@ -41,17 +41,16 @@ enum llama_vocab_type {
|
|
|
41
41
|
};
|
|
42
42
|
|
|
43
43
|
/* ------------------------------------------------------------------------- */
|
|
44
|
-
// Unicode helper functions
|
|
44
|
+
// Unicode helper functions
|
|
45
45
|
static int unicode_len_utf8(char c) {
|
|
46
46
|
if ((c & 0x80) == 0) return 1;
|
|
47
47
|
if ((c & 0xE0) == 0xC0) return 2;
|
|
48
48
|
if ((c & 0xF0) == 0xE0) return 3;
|
|
49
49
|
if ((c & 0xF8) == 0xF0) return 4;
|
|
50
|
-
return 1;
|
|
50
|
+
return 1;
|
|
51
51
|
}
|
|
52
52
|
|
|
53
53
|
static int unicode_is_letter(uint32_t cp) {
|
|
54
|
-
// Basic Unicode letter detection (simplified)
|
|
55
54
|
return (cp >= 0x41 && cp <= 0x5A) || (cp >= 0x61 && cp <= 0x7A) ||
|
|
56
55
|
(cp >= 0xC0 && cp <= 0xD6) || (cp >= 0xD8 && cp <= 0xF6) ||
|
|
57
56
|
(cp >= 0xF8 && cp <= 0x2FF) || (cp >= 0x370 && cp <= 0x37D) ||
|
|
@@ -68,74 +67,42 @@ static int unicode_is_number(uint32_t cp) {
|
|
|
68
67
|
}
|
|
69
68
|
|
|
70
69
|
static uint32_t unicode_cpt_from_utf8(const char *s, size_t *len) {
|
|
71
|
-
uint32_t cp = 0;
|
|
72
70
|
unsigned char c = (unsigned char)s[0];
|
|
73
|
-
|
|
74
|
-
if (c
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
} else if ((c & 0xE0) == 0xC0) {
|
|
78
|
-
*len = 2;
|
|
79
|
-
cp = (c & 0x1F) << 6;
|
|
80
|
-
cp |= (s[1] & 0x3F);
|
|
81
|
-
return cp;
|
|
82
|
-
} else if ((c & 0xF0) == 0xE0) {
|
|
83
|
-
*len = 3;
|
|
84
|
-
cp = (c & 0x0F) << 12;
|
|
85
|
-
cp |= (s[1] & 0x3F) << 6;
|
|
86
|
-
cp |= (s[2] & 0x3F);
|
|
87
|
-
return cp;
|
|
88
|
-
} else if ((c & 0xF8) == 0xF0) {
|
|
89
|
-
*len = 4;
|
|
90
|
-
cp = (c & 0x07) << 18;
|
|
91
|
-
cp |= (s[1] & 0x3F) << 12;
|
|
92
|
-
cp |= (s[2] & 0x3F) << 6;
|
|
93
|
-
cp |= (s[3] & 0x3F);
|
|
94
|
-
return cp;
|
|
95
|
-
}
|
|
96
|
-
|
|
71
|
+
if (c < 0x80) { *len = 1; return c; }
|
|
72
|
+
if ((c & 0xE0) == 0xC0) { *len = 2; return ((c & 0x1F) << 6) | (s[1] & 0x3F); }
|
|
73
|
+
if ((c & 0xF0) == 0xE0) { *len = 3; return ((c & 0x0F) << 12) | ((s[1] & 0x3F) << 6) | (s[2] & 0x3F); }
|
|
74
|
+
if ((c & 0xF8) == 0xF0) { *len = 4; return ((c & 0x07) << 18) | ((s[1] & 0x3F) << 12) | ((s[2] & 0x3F) << 6) | (s[3] & 0x3F); }
|
|
97
75
|
*len = 1;
|
|
98
76
|
return c;
|
|
99
77
|
}
|
|
100
78
|
|
|
101
79
|
/* ------------------------------------------------------------------------- */
|
|
102
|
-
// Simple regex pattern matcher
|
|
80
|
+
// Simple regex pattern matcher (simplified)
|
|
103
81
|
typedef struct {
|
|
104
82
|
char *pattern;
|
|
105
83
|
int pattern_len;
|
|
106
84
|
} RegexPattern;
|
|
107
85
|
|
|
108
86
|
static int match_regex(const char *text, const RegexPattern *patterns, int num_patterns) {
|
|
109
|
-
// Simplified implementation for common BPE patterns
|
|
110
|
-
// Full regex engine would be complex; this handles the most common cases
|
|
111
|
-
|
|
112
87
|
for (int i = 0; i < num_patterns; i++) {
|
|
113
88
|
const char *p = patterns[i].pattern;
|
|
114
|
-
int plen = patterns[i].pattern_len;
|
|
115
|
-
|
|
116
|
-
// Check for common patterns
|
|
117
89
|
if (strstr(p, "\\p{L}")) {
|
|
118
|
-
// Match Unicode letter
|
|
119
90
|
size_t len;
|
|
120
91
|
uint32_t cp = unicode_cpt_from_utf8(text, &len);
|
|
121
92
|
if (unicode_is_letter(cp)) return 1;
|
|
122
93
|
} else if (strstr(p, "\\p{N}")) {
|
|
123
|
-
// Match Unicode number
|
|
124
94
|
size_t len;
|
|
125
95
|
uint32_t cp = unicode_cpt_from_utf8(text, &len);
|
|
126
96
|
if (unicode_is_number(cp)) return 1;
|
|
127
97
|
} else if (p[0] == '\\' && p[1] == 's') {
|
|
128
|
-
// Match whitespace
|
|
129
98
|
if (isspace(text[0])) return 1;
|
|
130
99
|
} else if (p[0] == '\\' && p[1] == 'r') {
|
|
131
100
|
if (text[0] == '\r') return 1;
|
|
132
101
|
} else if (p[0] == '\\' && p[1] == 'n') {
|
|
133
102
|
if (text[0] == '\n') return 1;
|
|
134
103
|
} else if (p[0] == '.' && p[1] == '*') {
|
|
135
|
-
// Match anything
|
|
136
104
|
return 1;
|
|
137
105
|
} else if (isalnum(p[0]) || ispunct(p[0])) {
|
|
138
|
-
// Match literal character
|
|
139
106
|
if (text[0] == p[0]) return 1;
|
|
140
107
|
}
|
|
141
108
|
}
|
|
@@ -144,65 +111,36 @@ static int match_regex(const char *text, const RegexPattern *patterns, int num_p
|
|
|
144
111
|
|
|
145
112
|
static char** unicode_regex_split(const char *text, const RegexPattern *patterns, int num_patterns, int *num_words) {
|
|
146
113
|
char **words = NULL;
|
|
147
|
-
int word_count = 0;
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
size_t text_len = strlen(text);
|
|
151
|
-
size_t pos = 0;
|
|
152
|
-
|
|
114
|
+
int word_count = 0, word_capacity = 0;
|
|
115
|
+
size_t text_len = strlen(text), pos = 0;
|
|
116
|
+
|
|
153
117
|
while (pos < text_len) {
|
|
154
|
-
// Find the start of a word (character that matches any regex)
|
|
155
118
|
size_t start = pos;
|
|
156
|
-
while (start < text_len)
|
|
157
|
-
if (match_regex(text + start, patterns, num_patterns)) {
|
|
158
|
-
break;
|
|
159
|
-
}
|
|
160
|
-
start++;
|
|
161
|
-
}
|
|
162
|
-
|
|
119
|
+
while (start < text_len && !match_regex(text + start, patterns, num_patterns)) start++;
|
|
163
120
|
if (start >= text_len) break;
|
|
164
|
-
|
|
165
|
-
// Find the end of the word (character that doesn't match any regex)
|
|
166
121
|
size_t end = start;
|
|
167
|
-
while (end < text_len)
|
|
168
|
-
if (!match_regex(text + end, patterns, num_patterns)) {
|
|
169
|
-
break;
|
|
170
|
-
}
|
|
171
|
-
end++;
|
|
172
|
-
}
|
|
173
|
-
|
|
122
|
+
while (end < text_len && match_regex(text + end, patterns, num_patterns)) end++;
|
|
174
123
|
if (end > start) {
|
|
175
|
-
// Extract the word
|
|
176
124
|
size_t word_len = end - start;
|
|
177
125
|
char *word = malloc(word_len + 1);
|
|
178
|
-
if (word) {
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
words = realloc(words, word_capacity * sizeof(char*));
|
|
186
|
-
if (!words) {
|
|
187
|
-
for (int i = 0; i < word_count; i++) free(words[i]);
|
|
188
|
-
free(words);
|
|
189
|
-
*num_words = 0;
|
|
190
|
-
return NULL;
|
|
191
|
-
}
|
|
192
|
-
}
|
|
193
|
-
words[word_count++] = word;
|
|
126
|
+
if (!word) { while (--word_count >= 0) free(words[word_count]); free(words); *num_words = 0; return NULL; }
|
|
127
|
+
memcpy(word, text + start, word_len);
|
|
128
|
+
word[word_len] = '\0';
|
|
129
|
+
if (word_count >= word_capacity) {
|
|
130
|
+
word_capacity = word_capacity ? word_capacity * 2 : 16;
|
|
131
|
+
words = realloc(words, word_capacity * sizeof(char*));
|
|
132
|
+
if (!words) { free(word); while (--word_count >= 0) free(words[word_count]); *num_words = 0; return NULL; }
|
|
194
133
|
}
|
|
134
|
+
words[word_count++] = word;
|
|
195
135
|
}
|
|
196
|
-
|
|
197
136
|
pos = end;
|
|
198
137
|
}
|
|
199
|
-
|
|
200
138
|
*num_words = word_count;
|
|
201
139
|
return words;
|
|
202
140
|
}
|
|
203
141
|
|
|
204
142
|
/* ------------------------------------------------------------------------- */
|
|
205
|
-
// BPE merge
|
|
143
|
+
// BPE merge structures
|
|
206
144
|
typedef struct {
|
|
207
145
|
char *left;
|
|
208
146
|
char *right;
|
|
@@ -217,22 +155,19 @@ typedef struct {
|
|
|
217
155
|
} BPEMergeTable;
|
|
218
156
|
|
|
219
157
|
static void bpe_merge_table_init(BPEMergeTable *table) {
|
|
220
|
-
table
|
|
221
|
-
table->num_merges = 0;
|
|
222
|
-
table->capacity = 0;
|
|
158
|
+
memset(table, 0, sizeof(*table));
|
|
223
159
|
}
|
|
224
160
|
|
|
225
161
|
static void bpe_merge_table_add(BPEMergeTable *table, const char *left, const char *right, const char *merged, int rank) {
|
|
226
162
|
if (table->num_merges >= table->capacity) {
|
|
227
|
-
table->capacity = table->capacity
|
|
163
|
+
table->capacity = table->capacity ? table->capacity * 2 : 100;
|
|
228
164
|
table->merges = realloc(table->merges, table->capacity * sizeof(BPEMerge));
|
|
229
165
|
}
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
merge->rank = rank;
|
|
166
|
+
BPEMerge *m = &table->merges[table->num_merges++];
|
|
167
|
+
m->left = strdup(left);
|
|
168
|
+
m->right = strdup(right);
|
|
169
|
+
m->merged = strdup(merged);
|
|
170
|
+
m->rank = rank;
|
|
236
171
|
}
|
|
237
172
|
|
|
238
173
|
static void bpe_merge_table_free(BPEMergeTable *table) {
|
|
@@ -248,40 +183,25 @@ static void bpe_merge_table_free(BPEMergeTable *table) {
|
|
|
248
183
|
|
|
249
184
|
static int bpe_merge_rank(const BPEMergeTable *table, const char *left, const char *right) {
|
|
250
185
|
for (int i = 0; i < table->num_merges; i++) {
|
|
251
|
-
if (strcmp(table->merges[i].left, left) == 0 && strcmp(table->merges[i].right, right) == 0)
|
|
186
|
+
if (strcmp(table->merges[i].left, left) == 0 && strcmp(table->merges[i].right, right) == 0)
|
|
252
187
|
return table->merges[i].rank;
|
|
253
|
-
}
|
|
254
188
|
}
|
|
255
189
|
return -1;
|
|
256
190
|
}
|
|
257
191
|
|
|
258
|
-
static char* bpe_merge(const BPEMergeTable *table, const char *left, const char *right) {
|
|
259
|
-
for (int i = 0; i < table->num_merges; i++) {
|
|
260
|
-
if (strcmp(table->merges[i].left, left) == 0 && strcmp(table->merges[i].right, right) == 0) {
|
|
261
|
-
return table->merges[i].merged;
|
|
262
|
-
}
|
|
263
|
-
}
|
|
264
|
-
return NULL;
|
|
265
|
-
}
|
|
266
|
-
|
|
267
192
|
/* ------------------------------------------------------------------------- */
|
|
268
|
-
// BPE tokenization
|
|
193
|
+
// BPE tokenization
|
|
269
194
|
typedef struct {
|
|
270
195
|
char *text;
|
|
271
|
-
int start;
|
|
272
|
-
int
|
|
273
|
-
int prev;
|
|
274
|
-
int next;
|
|
196
|
+
int start, end;
|
|
197
|
+
int prev, next;
|
|
275
198
|
int used;
|
|
276
199
|
} BPESymbol;
|
|
277
200
|
|
|
278
201
|
static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int (*text_to_id)(void*, const char*), void *vocab_data, int *token_ids, int *num_tokens) {
|
|
279
|
-
// Initialize symbols from characters
|
|
280
202
|
int word_len = strlen(word);
|
|
281
203
|
int num_symbols = 0;
|
|
282
204
|
BPESymbol *symbols = malloc(word_len * sizeof(BPESymbol));
|
|
283
|
-
|
|
284
|
-
// Split into UTF-8 characters
|
|
285
205
|
int offset = 0;
|
|
286
206
|
while (offset < word_len) {
|
|
287
207
|
int char_len = unicode_len_utf8(word[offset]);
|
|
@@ -294,39 +214,25 @@ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int
|
|
|
294
214
|
offset += char_len;
|
|
295
215
|
num_symbols++;
|
|
296
216
|
}
|
|
297
|
-
|
|
217
|
+
|
|
298
218
|
if (num_symbols <= 1) {
|
|
299
|
-
// Single character, just tokenize it
|
|
300
219
|
int id = text_to_id(vocab_data, word);
|
|
301
|
-
if (id != -1)
|
|
302
|
-
token_ids[*num_tokens] = id;
|
|
303
|
-
(*num_tokens)++;
|
|
304
|
-
}
|
|
220
|
+
if (id != -1) token_ids[(*num_tokens)++] = id;
|
|
305
221
|
free(symbols);
|
|
306
222
|
return;
|
|
307
223
|
}
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
int left;
|
|
312
|
-
int right;
|
|
313
|
-
int rank;
|
|
314
|
-
} Bigram;
|
|
315
|
-
|
|
316
|
-
Bigram *bigrams = malloc(word_len * word_len * sizeof(Bigram));
|
|
224
|
+
|
|
225
|
+
typedef struct { int left, right, rank; } Bigram;
|
|
226
|
+
Bigram *bigrams = malloc(num_symbols * num_symbols * sizeof(Bigram));
|
|
317
227
|
int num_bigrams = 0;
|
|
318
|
-
|
|
319
|
-
// Initialize bigrams
|
|
320
228
|
for (int i = 0; i < num_symbols - 1; i++) {
|
|
321
229
|
if (symbols[i].used && symbols[i+1].used) {
|
|
322
|
-
// Get the concatenated string for this pair
|
|
323
230
|
char *left_str = malloc(symbols[i].end - symbols[i].start + 1);
|
|
324
231
|
char *right_str = malloc(symbols[i+1].end - symbols[i+1].start + 1);
|
|
325
232
|
memcpy(left_str, symbols[i].text, symbols[i].end - symbols[i].start);
|
|
326
233
|
memcpy(right_str, symbols[i+1].text, symbols[i+1].end - symbols[i+1].start);
|
|
327
234
|
left_str[symbols[i].end - symbols[i].start] = '\0';
|
|
328
235
|
right_str[symbols[i+1].end - symbols[i+1].start] = '\0';
|
|
329
|
-
|
|
330
236
|
int rank = bpe_merge_rank(merges, left_str, right_str);
|
|
331
237
|
if (rank != -1) {
|
|
332
238
|
bigrams[num_bigrams].left = i;
|
|
@@ -334,70 +240,42 @@ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int
|
|
|
334
240
|
bigrams[num_bigrams].rank = rank;
|
|
335
241
|
num_bigrams++;
|
|
336
242
|
}
|
|
337
|
-
|
|
338
|
-
free(left_str);
|
|
339
|
-
free(right_str);
|
|
243
|
+
free(left_str); free(right_str);
|
|
340
244
|
}
|
|
341
245
|
}
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
for (int i = 0; i < num_bigrams - 1; i++) {
|
|
345
|
-
for (int j = i+1; j < num_bigrams; j++) {
|
|
246
|
+
for (int i = 0; i < num_bigrams - 1; i++)
|
|
247
|
+
for (int j = i+1; j < num_bigrams; j++)
|
|
346
248
|
if (bigrams[i].rank > bigrams[j].rank) {
|
|
347
|
-
Bigram
|
|
249
|
+
Bigram tmp = bigrams[i];
|
|
348
250
|
bigrams[i] = bigrams[j];
|
|
349
|
-
bigrams[j] =
|
|
251
|
+
bigrams[j] = tmp;
|
|
350
252
|
}
|
|
351
|
-
|
|
352
|
-
}
|
|
353
|
-
|
|
354
|
-
// Apply merges
|
|
253
|
+
|
|
355
254
|
int *merged = calloc(num_symbols, sizeof(int));
|
|
356
255
|
for (int i = 0; i < num_bigrams; i++) {
|
|
357
|
-
int left = bigrams[i].left;
|
|
358
|
-
int right = bigrams[i].right;
|
|
359
|
-
|
|
256
|
+
int left = bigrams[i].left, right = bigrams[i].right;
|
|
360
257
|
if (merged[left] || merged[right]) continue;
|
|
361
|
-
|
|
362
|
-
// Merge right into left
|
|
363
258
|
symbols[left].end = symbols[right].end;
|
|
364
259
|
symbols[left].next = symbols[right].next;
|
|
365
260
|
merged[right] = 1;
|
|
366
|
-
|
|
367
|
-
// Update next symbol's prev
|
|
368
|
-
if (symbols[right].next < num_symbols) {
|
|
369
|
-
symbols[symbols[right].next].prev = left;
|
|
370
|
-
}
|
|
261
|
+
if (symbols[right].next < num_symbols) symbols[symbols[right].next].prev = left;
|
|
371
262
|
}
|
|
372
|
-
|
|
373
|
-
// Collect final tokens
|
|
263
|
+
|
|
374
264
|
for (int i = 0; i < num_symbols; i++) {
|
|
375
265
|
if (!merged[i] && symbols[i].used) {
|
|
376
|
-
// Extract the substring
|
|
377
266
|
char *substr = malloc(symbols[i].end - symbols[i].start + 1);
|
|
378
267
|
memcpy(substr, word + symbols[i].start, symbols[i].end - symbols[i].start);
|
|
379
268
|
substr[symbols[i].end - symbols[i].start] = '\0';
|
|
380
|
-
|
|
381
269
|
int id = text_to_id(vocab_data, substr);
|
|
382
|
-
if (id != -1)
|
|
383
|
-
token_ids[*num_tokens] = id;
|
|
384
|
-
(*num_tokens)++;
|
|
385
|
-
} else {
|
|
386
|
-
// Unknown token - use byte-level fallback
|
|
387
|
-
// For simplicity, we'll use space as a placeholder
|
|
388
|
-
// In a full implementation, you'd encode bytes individually
|
|
389
|
-
}
|
|
390
|
-
|
|
270
|
+
if (id != -1) token_ids[(*num_tokens)++] = id;
|
|
391
271
|
free(substr);
|
|
392
272
|
}
|
|
393
273
|
}
|
|
394
|
-
|
|
395
|
-
free(bigrams);
|
|
396
|
-
free(merged);
|
|
397
|
-
free(symbols);
|
|
274
|
+
free(bigrams); free(merged); free(symbols);
|
|
398
275
|
}
|
|
399
276
|
|
|
400
277
|
/* ------------------------------------------------------------------------- */
|
|
278
|
+
// GGUF parsing
|
|
401
279
|
static int safe_advance(uint8_t **p, uint8_t *end, size_t sz) {
|
|
402
280
|
if (*p + sz > end) return 0;
|
|
403
281
|
*p += sz;
|
|
@@ -405,14 +283,14 @@ static int safe_advance(uint8_t **p, uint8_t *end, size_t sz) {
|
|
|
405
283
|
}
|
|
406
284
|
|
|
407
285
|
static uint32_t rd32(uint8_t **p, uint8_t *end) {
|
|
408
|
-
uint32_t v
|
|
286
|
+
uint32_t v;
|
|
409
287
|
if (!safe_advance(p, end, 4)) return 0;
|
|
410
288
|
memcpy(&v, *p - 4, 4);
|
|
411
289
|
return v;
|
|
412
290
|
}
|
|
413
291
|
|
|
414
292
|
static uint64_t rd64(uint8_t **p, uint8_t *end) {
|
|
415
|
-
uint64_t v
|
|
293
|
+
uint64_t v;
|
|
416
294
|
if (!safe_advance(p, end, 8)) return 0;
|
|
417
295
|
memcpy(&v, *p - 8, 8);
|
|
418
296
|
return v;
|
|
@@ -423,9 +301,9 @@ static char *rdstr(uint8_t **p, uint8_t *end) {
|
|
|
423
301
|
uint64_t len;
|
|
424
302
|
memcpy(&len, *p, 8);
|
|
425
303
|
*p += 8;
|
|
426
|
-
if (len == 0 || len > (1
|
|
304
|
+
if (len == 0 || len > (1<<20)) return NULL;
|
|
427
305
|
if (*p + len > end) return NULL;
|
|
428
|
-
char *s = malloc(len
|
|
306
|
+
char *s = malloc(len+1);
|
|
429
307
|
if (!s) return NULL;
|
|
430
308
|
memcpy(s, *p, len);
|
|
431
309
|
s[len] = '\0';
|
|
@@ -436,29 +314,27 @@ static char *rdstr(uint8_t **p, uint8_t *end) {
|
|
|
436
314
|
static void align_to_32(uint8_t **p, uint8_t *end, uint8_t *base) {
|
|
437
315
|
size_t off = *p - base;
|
|
438
316
|
size_t aligned = (off + GGUF_ALIGN - 1) & ~(GGUF_ALIGN - 1);
|
|
439
|
-
if (base + aligned <= end)
|
|
440
|
-
*p = base + aligned;
|
|
317
|
+
if (base + aligned <= end) *p = base + aligned;
|
|
441
318
|
}
|
|
442
319
|
|
|
443
320
|
/* ------------------------------------------------------------------------- */
|
|
321
|
+
// Hash table for vocabulary
|
|
444
322
|
typedef struct HashNode {
|
|
445
323
|
char *key;
|
|
446
|
-
int
|
|
324
|
+
int id;
|
|
447
325
|
struct HashNode *next;
|
|
448
326
|
} HashNode;
|
|
449
327
|
|
|
450
328
|
typedef struct {
|
|
451
|
-
int
|
|
452
|
-
int
|
|
453
|
-
char
|
|
454
|
-
float
|
|
455
|
-
void
|
|
456
|
-
int
|
|
457
|
-
void
|
|
458
|
-
size_t
|
|
329
|
+
int vocab_size;
|
|
330
|
+
int dim;
|
|
331
|
+
char **tokens;
|
|
332
|
+
float *float_data;
|
|
333
|
+
void *tensor_data;
|
|
334
|
+
int tensor_type;
|
|
335
|
+
void *mapped;
|
|
336
|
+
size_t mapped_size;
|
|
459
337
|
HashNode **table;
|
|
460
|
-
|
|
461
|
-
// BPE tokenization data
|
|
462
338
|
BPEMergeTable merges;
|
|
463
339
|
RegexPattern *pre_patterns;
|
|
464
340
|
int num_pre_patterns;
|
|
@@ -466,6 +342,7 @@ typedef struct {
|
|
|
466
342
|
int bos_token_id;
|
|
467
343
|
int eos_token_id;
|
|
468
344
|
int vocab_type;
|
|
345
|
+
char space_marker[8];
|
|
469
346
|
} EmbedModel;
|
|
470
347
|
|
|
471
348
|
typedef struct {
|
|
@@ -498,11 +375,11 @@ static int hget(EmbedModel *m, const char *k) {
|
|
|
498
375
|
}
|
|
499
376
|
|
|
500
377
|
static int text_to_id(void *vocab_data, const char *text) {
|
|
501
|
-
|
|
502
|
-
return hget(m, text);
|
|
378
|
+
return hget((EmbedModel*)vocab_data, text);
|
|
503
379
|
}
|
|
504
380
|
|
|
505
381
|
/* ------------------------------------------------------------------------- */
|
|
382
|
+
// File mapping
|
|
506
383
|
static void *map_file(const char *path, size_t *size) {
|
|
507
384
|
int fd = open(path, O_RDONLY);
|
|
508
385
|
if (fd < 0) return NULL;
|
|
@@ -511,29 +388,22 @@ static void *map_file(const char *path, size_t *size) {
|
|
|
511
388
|
*size = st.st_size;
|
|
512
389
|
void *data = mmap(NULL, *size, PROT_READ, MAP_PRIVATE, fd, 0);
|
|
513
390
|
close(fd);
|
|
514
|
-
|
|
515
|
-
return data;
|
|
391
|
+
return data == MAP_FAILED ? NULL : data;
|
|
516
392
|
}
|
|
517
393
|
|
|
518
394
|
/* ------------------------------------------------------------------------- */
|
|
395
|
+
// FP16 conversion
|
|
519
396
|
static float fp16_to_fp32(uint16_t h) {
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
if (exp ==
|
|
525
|
-
|
|
526
|
-
} else if (exp == 31) {
|
|
527
|
-
return 0.0f;
|
|
528
|
-
} else {
|
|
529
|
-
val = (1.0f + mant / 1024.0f) * (1 << (exp - 15));
|
|
530
|
-
}
|
|
531
|
-
return sign ? -val : val;
|
|
397
|
+
uint16_t sign = (h >> 15) & 1;
|
|
398
|
+
uint16_t exp = (h >> 10) & 0x1F;
|
|
399
|
+
uint16_t mant = h & 0x3FF;
|
|
400
|
+
if (exp == 0) return (mant / 1024.0f) * 6.103515625e-5f * (sign ? -1.0f : 1.0f);
|
|
401
|
+
if (exp == 31) return 0.0f;
|
|
402
|
+
return (1.0f + mant / 1024.0f) * (1 << (exp - 15)) * (sign ? -1.0f : 1.0f);
|
|
532
403
|
}
|
|
533
404
|
|
|
534
405
|
/* ------------------------------------------------------------------------- */
|
|
535
|
-
|
|
536
|
-
|
|
406
|
+
// Block dequantization functions
|
|
537
407
|
static void dequantize_row_q4_0(const void *vx, float *y, int k) {
|
|
538
408
|
const int nb = k / 32;
|
|
539
409
|
const uint8_t *x = vx;
|
|
@@ -621,7 +491,6 @@ static void dequantize_row_q8_1(const void *vx, float *y, int k) {
|
|
|
621
491
|
}
|
|
622
492
|
}
|
|
623
493
|
|
|
624
|
-
/* K-quants */
|
|
625
494
|
static void dequantize_row_q2_K(const void *vx, float *y, int k) {
|
|
626
495
|
const int nb = k / 256;
|
|
627
496
|
const uint8_t *x = vx;
|
|
@@ -748,7 +617,6 @@ static void dequantize_row_q8_K(const void *vx, float *y, int k) {
|
|
|
748
617
|
}
|
|
749
618
|
}
|
|
750
619
|
|
|
751
|
-
/* ------------------------------------------------------------------------- */
|
|
752
620
|
static float* dequantize_tensor(const void *data, int type, int n_rows, int n_cols) {
|
|
753
621
|
if (type == GGML_TYPE_F32) {
|
|
754
622
|
float *out = malloc(n_rows * n_cols * sizeof(float));
|
|
@@ -793,6 +661,14 @@ static float* dequantize_tensor(const void *data, int type, int n_rows, int n_co
|
|
|
793
661
|
for (int r = 0; r < n_rows; r++) {
|
|
794
662
|
dequant_func(in + r * row_bytes, out + r * n_cols, n_cols);
|
|
795
663
|
}
|
|
664
|
+
|
|
665
|
+
// Sanitize the tensor: replace NaNs, Infs, and astronomically large values with zero
|
|
666
|
+
int total = n_rows * n_cols;
|
|
667
|
+
for (int i = 0; i < total; i++) {
|
|
668
|
+
if (isnan(out[i]) || isinf(out[i]) || fabs(out[i]) > 1e10f) {
|
|
669
|
+
out[i] = 0.0f;
|
|
670
|
+
}
|
|
671
|
+
}
|
|
796
672
|
return out;
|
|
797
673
|
}
|
|
798
674
|
|
|
@@ -837,45 +713,49 @@ static void free_model_contents(EmbedModel *m) {
|
|
|
837
713
|
}
|
|
838
714
|
if (m->float_data) free(m->float_data);
|
|
839
715
|
if (m->mapped) munmap(m->mapped, m->mapped_size);
|
|
840
|
-
|
|
841
|
-
// Free BPE tokenization data
|
|
842
716
|
bpe_merge_table_free(&m->merges);
|
|
843
717
|
if (m->pre_patterns) {
|
|
844
|
-
for (int i = 0; i < m->num_pre_patterns; i++)
|
|
845
|
-
free(m->pre_patterns[i].pattern);
|
|
846
|
-
}
|
|
718
|
+
for (int i = 0; i < m->num_pre_patterns; i++) free(m->pre_patterns[i].pattern);
|
|
847
719
|
free(m->pre_patterns);
|
|
848
720
|
}
|
|
849
|
-
|
|
850
721
|
free(m);
|
|
851
722
|
}
|
|
852
723
|
|
|
853
724
|
/* ------------------------------------------------------------------------- */
|
|
854
725
|
static int is_printable_string(const char *s, size_t len) {
|
|
855
|
-
for (size_t i = 0; i < len; i++)
|
|
856
|
-
if (!isprint((unsigned char)s[i])) return 0;
|
|
726
|
+
for (size_t i = 0; i < len; i++) if (!isprint((unsigned char)s[i])) return 0;
|
|
857
727
|
return 1;
|
|
858
728
|
}
|
|
859
729
|
|
|
860
|
-
/* Fallback: find the start of tensor info by scanning for a valid string */
|
|
861
730
|
static uint8_t *find_tensor_info_start(uint8_t *cur, uint8_t *end) {
|
|
862
731
|
uint8_t *scan = cur;
|
|
863
732
|
while (scan + 8 < end) {
|
|
864
733
|
uint64_t len;
|
|
865
734
|
memcpy(&len, scan, 8);
|
|
866
|
-
if (len > 0 && len < 256 && scan + 8 + len <= end)
|
|
867
|
-
|
|
868
|
-
return scan;
|
|
869
|
-
}
|
|
870
|
-
}
|
|
735
|
+
if (len > 0 && len < 256 && scan + 8 + len <= end && is_printable_string((char*)scan+8, len))
|
|
736
|
+
return scan;
|
|
871
737
|
scan++;
|
|
872
738
|
}
|
|
873
739
|
return NULL;
|
|
874
740
|
}
|
|
875
741
|
|
|
876
742
|
/* ------------------------------------------------------------------------- */
|
|
743
|
+
static void detect_space_marker(EmbedModel *m) {
|
|
744
|
+
const char *candidates[] = {"▁", "Ġ", " "};
|
|
745
|
+
for (int i = 0; i < 3; i++) {
|
|
746
|
+
const char *marker = candidates[i];
|
|
747
|
+
int marker_len = strlen(marker);
|
|
748
|
+
for (int j = 0; j < m->vocab_size; j++) {
|
|
749
|
+
if (strncmp(m->tokens[j], marker, marker_len) == 0) {
|
|
750
|
+
strcpy(m->space_marker, marker);
|
|
751
|
+
return;
|
|
752
|
+
}
|
|
753
|
+
}
|
|
754
|
+
}
|
|
755
|
+
m->space_marker[0] = '\0';
|
|
756
|
+
}
|
|
757
|
+
|
|
877
758
|
static void setup_default_pre_patterns(EmbedModel *m) {
|
|
878
|
-
// Default pre-tokenization regex patterns (similar to Llama 3)
|
|
879
759
|
const char *default_patterns[] = {
|
|
880
760
|
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])",
|
|
881
761
|
"[^\\r\\n\\p{L}\\p{N}]?\\p{L}+",
|
|
@@ -885,29 +765,23 @@ static void setup_default_pre_patterns(EmbedModel *m) {
|
|
|
885
765
|
"\\s+(?!\\S)",
|
|
886
766
|
"\\s+"
|
|
887
767
|
};
|
|
888
|
-
|
|
889
|
-
m->num_pre_patterns = sizeof(default_patterns) / sizeof(default_patterns[0]);
|
|
768
|
+
m->num_pre_patterns = sizeof(default_patterns)/sizeof(default_patterns[0]);
|
|
890
769
|
m->pre_patterns = malloc(m->num_pre_patterns * sizeof(RegexPattern));
|
|
891
|
-
|
|
892
770
|
for (int i = 0; i < m->num_pre_patterns; i++) {
|
|
893
771
|
m->pre_patterns[i].pattern = strdup(default_patterns[i]);
|
|
894
772
|
m->pre_patterns[i].pattern_len = strlen(default_patterns[i]);
|
|
895
773
|
}
|
|
896
774
|
}
|
|
897
775
|
|
|
898
|
-
/* ------------------------------------------------------------------------- */
|
|
899
776
|
static void parse_merge(const char *merge_str, char **left, char **right) {
|
|
900
|
-
// Parse a merge string like "h ello" -> left="h", right="ello"
|
|
901
777
|
const char *space = strchr(merge_str, ' ');
|
|
902
778
|
if (space) {
|
|
903
779
|
int left_len = space - merge_str;
|
|
904
|
-
*left = malloc(left_len
|
|
780
|
+
*left = malloc(left_len+1);
|
|
905
781
|
memcpy(*left, merge_str, left_len);
|
|
906
782
|
(*left)[left_len] = '\0';
|
|
907
|
-
|
|
908
|
-
*right = strdup(space + 1);
|
|
783
|
+
*right = strdup(space+1);
|
|
909
784
|
} else {
|
|
910
|
-
// No space - treat as single token
|
|
911
785
|
*left = strdup(merge_str);
|
|
912
786
|
*right = strdup("");
|
|
913
787
|
}
|
|
@@ -918,13 +792,10 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
918
792
|
size_t sz;
|
|
919
793
|
uint8_t *base = map_file(path, &sz);
|
|
920
794
|
if (!base) return NULL;
|
|
921
|
-
uint8_t *cur = base;
|
|
922
|
-
uint8_t *end = base + sz;
|
|
923
|
-
|
|
795
|
+
uint8_t *cur = base, *end = base + sz;
|
|
924
796
|
if (memcmp(cur, "GGUF", 4) != 0) { munmap(base, sz); return NULL; }
|
|
925
797
|
cur += 4;
|
|
926
798
|
uint32_t version = rd32(&cur, end);
|
|
927
|
-
(void)version;
|
|
928
799
|
uint64_t n_tensors = rd64(&cur, end);
|
|
929
800
|
uint64_t n_kv = rd64(&cur, end);
|
|
930
801
|
|
|
@@ -934,26 +805,20 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
934
805
|
m->mapped_size = sz;
|
|
935
806
|
m->table = calloc(HASH_SIZE, sizeof(HashNode*));
|
|
936
807
|
if (!m->table) { free_model_contents(m); return NULL; }
|
|
937
|
-
|
|
938
|
-
// Initialize BPE structures
|
|
939
808
|
bpe_merge_table_init(&m->merges);
|
|
940
809
|
setup_default_pre_patterns(m);
|
|
941
|
-
|
|
942
|
-
// Default values
|
|
943
810
|
m->unknown_token_id = -1;
|
|
944
811
|
m->bos_token_id = -1;
|
|
945
812
|
m->eos_token_id = -1;
|
|
946
813
|
m->vocab_type = LLAMA_VOCAB_TYPE_NONE;
|
|
814
|
+
m->space_marker[0] = '\0';
|
|
947
815
|
|
|
948
|
-
/* ---------- Metadata ---------- */
|
|
949
816
|
int vocab_found = 0;
|
|
950
817
|
for (uint64_t i = 0; i < n_kv; i++) {
|
|
951
818
|
char *key = rdstr(&cur, end);
|
|
952
819
|
if (!key) { free_model_contents(m); return NULL; }
|
|
953
820
|
uint32_t type = rd32(&cur, end);
|
|
954
|
-
|
|
955
|
-
if ((strcmp(key, "tokenizer.ggml.tokens") == 0 ||
|
|
956
|
-
strcmp(key, "tokenizer.ggml.token_list") == 0) && type == 9) {
|
|
821
|
+
if ((strcmp(key, "tokenizer.ggml.tokens") == 0 || strcmp(key, "tokenizer.ggml.token_list") == 0) && type == 9) {
|
|
957
822
|
uint32_t subtype = rd32(&cur, end);
|
|
958
823
|
uint64_t n = rd64(&cur, end);
|
|
959
824
|
if (subtype != 8) { free(key); free_model_contents(m); return NULL; }
|
|
@@ -971,40 +836,29 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
971
836
|
uint32_t subtype = rd32(&cur, end);
|
|
972
837
|
uint64_t n = rd64(&cur, end);
|
|
973
838
|
if (subtype == 8) {
|
|
974
|
-
// Parse merges
|
|
975
839
|
for (uint64_t j = 0; j < n && j < MAX_MERGES; j++) {
|
|
976
840
|
char *merge_str = rdstr(&cur, end);
|
|
977
841
|
if (merge_str) {
|
|
978
842
|
char *left, *right;
|
|
979
843
|
parse_merge(merge_str, &left, &right);
|
|
980
|
-
bpe_merge_table_add(&m->merges, left, right, merge_str, j);
|
|
981
|
-
free(left);
|
|
982
|
-
free(right);
|
|
844
|
+
bpe_merge_table_add(&m->merges, left, right, merge_str, (int)j);
|
|
845
|
+
free(left); free(right);
|
|
983
846
|
free(merge_str);
|
|
984
847
|
}
|
|
985
848
|
}
|
|
986
849
|
} else {
|
|
987
|
-
|
|
988
|
-
if (!skip_value(&cur, end, type)) {
|
|
989
|
-
free(key); free_model_contents(m); return NULL;
|
|
990
|
-
}
|
|
850
|
+
if (!skip_value(&cur, end, type)) { free(key); free_model_contents(m); return NULL; }
|
|
991
851
|
}
|
|
992
852
|
} else if (strcmp(key, "tokenizer.ggml.model") == 0 && type == 8) {
|
|
993
853
|
char *model_type = rdstr(&cur, end);
|
|
994
854
|
if (model_type) {
|
|
995
|
-
if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0)
|
|
996
|
-
|
|
997
|
-
} else if (strcmp(model_type, "bert") == 0) {
|
|
998
|
-
m->vocab_type = LLAMA_VOCAB_TYPE_WPM;
|
|
999
|
-
}
|
|
855
|
+
if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0) m->vocab_type = LLAMA_VOCAB_TYPE_BPE;
|
|
856
|
+
else if (strcmp(model_type, "bert") == 0) m->vocab_type = LLAMA_VOCAB_TYPE_WPM;
|
|
1000
857
|
free(model_type);
|
|
1001
858
|
}
|
|
1002
859
|
} else if (strcmp(key, "tokenizer.ggml.pre") == 0 && type == 8) {
|
|
1003
|
-
char *
|
|
1004
|
-
if (
|
|
1005
|
-
// Could load custom regex patterns here if needed
|
|
1006
|
-
free(pre_type);
|
|
1007
|
-
}
|
|
860
|
+
char *pre = rdstr(&cur, end);
|
|
861
|
+
if (pre) free(pre);
|
|
1008
862
|
} else if (strcmp(key, "tokenizer.ggml.unknown_token_id") == 0 && type == 6) {
|
|
1009
863
|
m->unknown_token_id = rd32(&cur, end);
|
|
1010
864
|
} else if (strcmp(key, "tokenizer.ggml.bos_token_id") == 0 && type == 6) {
|
|
@@ -1012,20 +866,16 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
1012
866
|
} else if (strcmp(key, "tokenizer.ggml.eos_token_id") == 0 && type == 6) {
|
|
1013
867
|
m->eos_token_id = rd32(&cur, end);
|
|
1014
868
|
} else {
|
|
1015
|
-
if (!skip_value(&cur, end, type)) {
|
|
1016
|
-
free(key); free_model_contents(m); return NULL;
|
|
1017
|
-
}
|
|
869
|
+
if (!skip_value(&cur, end, type)) { free(key); free_model_contents(m); return NULL; }
|
|
1018
870
|
}
|
|
1019
871
|
free(key);
|
|
1020
872
|
}
|
|
1021
|
-
|
|
1022
873
|
if (!vocab_found) { free_model_contents(m); return NULL; }
|
|
874
|
+
detect_space_marker(m);
|
|
1023
875
|
|
|
1024
876
|
uint8_t *after_kv = cur;
|
|
1025
877
|
align_to_32(&cur, end, base);
|
|
1026
878
|
uint8_t *tensor_start = cur;
|
|
1027
|
-
|
|
1028
|
-
/* ---------- Tensor info ---------- */
|
|
1029
879
|
int embd_found = 0;
|
|
1030
880
|
void *raw_tensor_data = NULL;
|
|
1031
881
|
int tensor_type = -1;
|
|
@@ -1039,39 +889,27 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
1039
889
|
if (!name) break;
|
|
1040
890
|
uint32_t n_dims = rd32(&cur, end);
|
|
1041
891
|
uint64_t dims[MAX_DIMS] = {0};
|
|
1042
|
-
for (uint32_t d = 0; d < n_dims && d < MAX_DIMS; d++)
|
|
1043
|
-
dims[d] = rd64(&cur, end);
|
|
892
|
+
for (uint32_t d = 0; d < n_dims && d < MAX_DIMS; d++) dims[d] = rd64(&cur, end);
|
|
1044
893
|
uint32_t type = rd32(&cur, end);
|
|
1045
894
|
uint64_t offset = rd64(&cur, end);
|
|
1046
|
-
|
|
1047
895
|
int is_token_embd = (strcmp(name, "token_embd.weight") == 0 ||
|
|
1048
896
|
strcmp(name, "embeddings.word_embeddings.weight") == 0 ||
|
|
1049
897
|
strcmp(name, "model.embed_tokens.weight") == 0);
|
|
1050
|
-
|
|
1051
898
|
if (!is_token_embd && n_dims == 2 && m->vocab_size > 0) {
|
|
1052
|
-
if ((uint64_t)m->vocab_size == dims[0] && strstr(name, "embd")
|
|
1053
|
-
|
|
1054
|
-
else if ((uint64_t)m->vocab_size == dims[1] && strstr(name, "embd") != NULL)
|
|
1055
|
-
is_token_embd = 1;
|
|
899
|
+
if ((uint64_t)m->vocab_size == dims[0] && strstr(name, "embd")) is_token_embd = 1;
|
|
900
|
+
else if ((uint64_t)m->vocab_size == dims[1] && strstr(name, "embd")) is_token_embd = 1;
|
|
1056
901
|
}
|
|
1057
|
-
|
|
1058
902
|
if (!embd_found && is_token_embd) {
|
|
1059
903
|
if (n_dims < 2 || dims[1] == 0) { free(name); free_model_contents(m); return NULL; }
|
|
1060
|
-
dim0 = dims[0];
|
|
1061
|
-
dim1 =
|
|
1062
|
-
if (
|
|
1063
|
-
|
|
1064
|
-
need_transpose = 0;
|
|
1065
|
-
} else if (dim1 == (uint64_t)m->vocab_size) {
|
|
1066
|
-
m->dim = (int)dim0;
|
|
1067
|
-
need_transpose = 1;
|
|
1068
|
-
} else {
|
|
1069
|
-
m->dim = (dim0 < dim1) ? (int)dim0 : (int)dim1;
|
|
1070
|
-
need_transpose = (dim0 > dim1) ? 1 : 0;
|
|
1071
|
-
}
|
|
904
|
+
dim0 = dims[0]; dim1 = dims[1];
|
|
905
|
+
if (dim0 == (uint64_t)m->vocab_size) { m->dim = (int)dim1; need_transpose = 0; }
|
|
906
|
+
else if (dim1 == (uint64_t)m->vocab_size) { m->dim = (int)dim0; need_transpose = 1; }
|
|
907
|
+
else { m->dim = (dim0 < dim1) ? (int)dim0 : (int)dim1; need_transpose = (dim0 > dim1) ? 1 : 0; }
|
|
1072
908
|
raw_tensor_data = base + offset;
|
|
1073
909
|
tensor_type = type;
|
|
1074
910
|
embd_found = 1;
|
|
911
|
+
free(name);
|
|
912
|
+
break;
|
|
1075
913
|
}
|
|
1076
914
|
free(name);
|
|
1077
915
|
}
|
|
@@ -1081,13 +919,8 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
1081
919
|
if (!tensor_start) break;
|
|
1082
920
|
}
|
|
1083
921
|
}
|
|
922
|
+
if (!embd_found || m->dim == 0) { free_model_contents(m); return NULL; }
|
|
1084
923
|
|
|
1085
|
-
if (!embd_found || m->dim == 0) {
|
|
1086
|
-
free_model_contents(m);
|
|
1087
|
-
return NULL;
|
|
1088
|
-
}
|
|
1089
|
-
|
|
1090
|
-
/* Dequantize */
|
|
1091
924
|
if (tensor_type == GGML_TYPE_F32 && !need_transpose) {
|
|
1092
925
|
m->float_data = NULL;
|
|
1093
926
|
m->tensor_data = raw_tensor_data;
|
|
@@ -1095,10 +928,7 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
1095
928
|
int n_rows = need_transpose ? (int)dim1 : (int)dim0;
|
|
1096
929
|
int n_cols = need_transpose ? (int)dim0 : (int)dim1;
|
|
1097
930
|
m->float_data = dequantize_tensor(raw_tensor_data, tensor_type, n_rows, n_cols);
|
|
1098
|
-
if (!m->float_data) {
|
|
1099
|
-
free_model_contents(m);
|
|
1100
|
-
return NULL;
|
|
1101
|
-
}
|
|
931
|
+
if (!m->float_data) { free_model_contents(m); return NULL; }
|
|
1102
932
|
m->tensor_data = m->float_data;
|
|
1103
933
|
}
|
|
1104
934
|
m->tensor_type = tensor_type;
|
|
@@ -1109,79 +939,84 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
1109
939
|
/* ------------------------------------------------------------------------- */
|
|
1110
940
|
/* Compute a mean-pooled embedding for `txt` into `out` (length m->dim).
 *
 * Words come from the model's pre-tokenizer regex split; each word is first
 * looked up directly in the vocab (retrying with the tokenizer's space
 * marker prepended), and otherwise BPE-tokenized.  Matched token vectors are
 * summed and divided by the number of tokens used.  Non-finite components
 * are zeroed at the end.
 *
 * BUGFIX: the malloc results for `token_ids` and `with_marker` were used
 * unchecked; on allocation failure they are now skipped gracefully instead
 * of dereferencing NULL.
 */
static void embed_text(EmbedModel *m, const char *txt, float *out) {
    memset(out, 0, sizeof(float) * m->dim);
    int num_words = 0;
    char **words = unicode_regex_split(txt, m->pre_patterns, m->num_pre_patterns, &num_words);
    if (!words || num_words == 0) {
        /* Fallback to a simple whitespace split when regex pre-tokenization
         * yields nothing. */
        char *copy = strdup(txt);
        if (copy) {
            char *tok = strtok(copy, " \t\n\r");
            int used = 0;
            const float *embd = (const float *)m->tensor_data;
            while (tok) {
                int id = hget(m, tok);
                if (id >= 0 && id < m->vocab_size) {
                    const float *vec = embd + (size_t)id * m->dim;
                    for (int i = 0; i < m->dim; i++) out[i] += vec[i];
                    used++;
                }
                tok = strtok(NULL, " \t\n\r");
            }
            if (used) {
                float inv = 1.0f / used;
                for (int i = 0; i < m->dim; i++) out[i] *= inv;
            }
            free(copy);
        }
        if (words) free(words);
        return;
    }

    /* Scratch buffer for BPE token ids; a single word cannot expand to more
     * tokens than there are vocab entries. */
    int *token_ids = malloc((size_t)m->vocab_size * sizeof(int));
    int used = 0;
    const float *embd = (const float *)m->tensor_data;
    for (int i = 0; i < num_words; i++) {
        char *word = words[i];
        int id = hget(m, word);
        if (id == -1 && m->space_marker[0]) {
            /* Retry with the tokenizer's space marker prepended. */
            char *with_marker = malloc(strlen(m->space_marker) + strlen(word) + 1);
            if (with_marker) {                     /* BUGFIX: was unchecked */
                strcpy(with_marker, m->space_marker);
                strcat(with_marker, word);
                id = hget(m, with_marker);
                free(with_marker);
            }
        }
        if (id != -1) {
            const float *vec = embd + (size_t)id * m->dim;
            for (int j = 0; j < m->dim; j++) out[j] += vec[j];
            used++;
        } else if (token_ids) {                    /* BUGFIX: was unchecked */
            int num_tokens = 0;
            bpe_tokenize_word(&m->merges, word, text_to_id, m, token_ids, &num_tokens);
            for (int k = 0; k < num_tokens; k++) {
                int tid = token_ids[k];
                if (tid >= 0 && tid < m->vocab_size) {
                    const float *vec = embd + (size_t)tid * m->dim;
                    for (int j = 0; j < m->dim; j++) out[j] += vec[j];
                    used++;
                } else if (m->unknown_token_id != -1 && m->unknown_token_id < m->vocab_size) {
                    /* Out-of-range BPE result: fall back to the UNK vector. */
                    const float *vec = embd + (size_t)m->unknown_token_id * m->dim;
                    for (int j = 0; j < m->dim; j++) out[j] += vec[j];
                    used++;
                }
            }
        }
        free(word);
    }
    free(words);
    free(token_ids);
    if (used > 0) {
        float inv = 1.0f / used;
        for (int i = 0; i < m->dim; i++) out[i] *= inv;
    }
    /* Guard against non-finite values leaking into the result (e.g. from
     * degenerate model data). */
    for (int i = 0; i < m->dim; i++) {
        if (isnan(out[i]) || isinf(out[i])) out[i] = 0.0f;
    }
}
1178
1014
|
|
|
1179
1015
|
/* ------------------------------------------------------------------------- */
|
|
1016
|
+
// Ruby bindings
|
|
1180
1017
|
static void rb_embedder_free(void *p) {
|
|
1181
1018
|
ruby_embedder *e = p;
|
|
1182
|
-
if (
|
|
1183
|
-
if (e->model) free_model_contents(e->model);
|
|
1184
|
-
free(e);
|
|
1019
|
+
if (e) { if (e->model) free_model_contents(e->model); free(e); }
|
|
1185
1020
|
}
|
|
1186
1021
|
|
|
1187
1022
|
static size_t rb_embedder_memsize(const void *p) {
|
|
@@ -1202,22 +1037,18 @@ static VALUE rb_embedder_alloc(VALUE klass) {
|
|
|
1202
1037
|
/* MiniEmbed#initialize(model: path)
 *
 * Loads the GGUF model at `path` into the wrapped struct.
 * Raises ArgumentError when the :model option is missing (previously this
 * surfaced as a confusing TypeError from StringValueCStr on Qnil), and
 * RuntimeError when the file cannot be parsed as a supported GGUF model.
 */
static VALUE rb_embedder_initialize(VALUE self, VALUE opts) {
    ruby_embedder *e;
    TypedData_Get_Struct(self, ruby_embedder, &ruby_embedder_type, e);
    VALUE path = rb_hash_aref(opts, ID2SYM(rb_intern("model")));
    if (NIL_P(path)) rb_raise(rb_eArgError, "missing required option :model");
    const char *cpath = StringValueCStr(path);
    e->model = embed_load_gguf(cpath);
    if (!e->model) rb_raise(rb_eRuntimeError, "failed to load GGUF model");
    return self;
}
|
|
1213
1046
|
|
|
1214
1047
|
/* MiniEmbed#embed(text: str)
 *
 * Returns the embedding as a binary String of m->dim little-endian 32-bit
 * floats (unpack with 'e*' on the Ruby side).
 * BUGFIX: guards against being called before a model is loaded (alloc leaves
 * e->model NULL until #initialize succeeds) and against a missing :text
 * option, instead of dereferencing NULL / raising an opaque TypeError.
 */
static VALUE rb_embed(VALUE self, VALUE opts) {
    ruby_embedder *e;
    TypedData_Get_Struct(self, ruby_embedder, &ruby_embedder_type, e);
    if (!e->model) rb_raise(rb_eRuntimeError, "model not loaded");
    VALUE text = rb_hash_aref(opts, ID2SYM(rb_intern("text")));
    if (NIL_P(text)) rb_raise(rb_eArgError, "missing required option :text");
    const char *ctext = StringValueCStr(text);
    VALUE out = rb_str_new(NULL, (long)(e->model->dim * sizeof(float)));
    embed_text(e->model, ctext, (float *)RSTRING_PTR(out));
    return out;
}
|
@@ -1227,5 +1058,5 @@ void Init_mini_embed(void) {
|
|
|
1227
1058
|
VALUE c = rb_define_class("MiniEmbed", rb_cObject);
|
|
1228
1059
|
rb_define_alloc_func(c, rb_embedder_alloc);
|
|
1229
1060
|
rb_define_method(c, "initialize", rb_embedder_initialize, 1);
|
|
1230
|
-
rb_define_method(c, "
|
|
1061
|
+
rb_define_method(c, "embed", rb_embed, 1);
|
|
1231
1062
|
}
|
data/lib/mini_embed.rb
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
1
|
# frozen_string_literal: true
|
|
2
2
|
|
|
3
3
|
require 'mini_embed/mini_embed'
|
|
4
|
+
|
|
5
|
+
class MiniEmbed
  # Compute an embedding for +text+ via the C extension.
  #
  # BUGFIX: the keyword default was self-referential (`text: text`), so
  # calling +embeddings+ without a +text:+ argument silently passed +nil+
  # down to the C method instead of reporting the missing argument.  The
  # keyword is now required, so Ruby raises ArgumentError at the call site.
  #
  # @param text [String] text to extract embeddings from
  # @param type [Symbol] :vector (default) for an Array<Float>,
  #   :binary for the raw string of little-endian 32-bit floats
  # @return [Array<Float>, String] depending on +type+
  # @raise [ArgumentError] when +type+ is not :vector or :binary
  def embeddings(text:, type: :vector)
    binary_data = embed(text: text) # call original C method

    case type
    when :binary then binary_data
    when :vector then binary_data.unpack('e*')
    else
      raise ArgumentError, "Unsupported data type: #{type}"
    end
  end
end
|