mini_embed 0.1.1 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +9 -5
- data/ext/mini_embed/mini_embed.c +788 -603
- data/lib/mini_embed.rb +14 -0
- metadata +1 -1
data/ext/mini_embed/mini_embed.c
CHANGED
|
@@ -8,13 +8,18 @@
|
|
|
8
8
|
#include <fcntl.h>
|
|
9
9
|
#include <unistd.h>
|
|
10
10
|
#include <ctype.h>
|
|
11
|
+
#include <limits.h>
|
|
11
12
|
#include "ruby.h"
|
|
12
13
|
|
|
13
|
-
#define HASH_SIZE
|
|
14
|
-
#define MAX_DIMS
|
|
15
|
-
#define GGUF_ALIGN
|
|
16
|
-
#define MAX_MERGES
|
|
17
|
-
#define
|
|
14
|
+
#define HASH_SIZE 131071
|
|
15
|
+
#define MAX_DIMS 4
|
|
16
|
+
#define GGUF_ALIGN 32
|
|
17
|
+
#define MAX_MERGES 100000
|
|
18
|
+
#define MERGE_HASH_SIZE 65537
|
|
19
|
+
#define QK8_0 32
|
|
20
|
+
#define QK_K 256
|
|
21
|
+
#define K_SCALE_SIZE 12
|
|
22
|
+
#define MAX_DIM 16384
|
|
18
23
|
|
|
19
24
|
enum ggml_type {
|
|
20
25
|
GGML_TYPE_F32 = 0,
|
|
@@ -40,18 +45,22 @@ enum llama_vocab_type {
|
|
|
40
45
|
LLAMA_VOCAB_TYPE_WPM = 3,
|
|
41
46
|
};
|
|
42
47
|
|
|
48
|
+
// Normalization selector constants.
// NOTE(review): the code that consumes these values is outside this chunk;
// the per-value notes below are inferred from the names only - confirm.
enum normalize_type {
    NORM_NONE = 0,  // presumably: leave the vector as computed
    NORM_L2 = 1,    // presumably: scale the vector to unit L2 length
};
|
|
52
|
+
|
|
43
53
|
/* ------------------------------------------------------------------------- */
|
|
44
|
-
// Unicode helper functions
|
|
54
|
+
// Unicode helper functions
|
|
45
55
|
// Returns the byte length (1-4) announced by the UTF-8 lead byte `c`.
// Bytes that are not a valid lead (continuation bytes, 0xF8-0xFF) are
// reported as length 1 so callers always advance.
static int unicode_len_utf8(char c) {
    const unsigned char lead = (unsigned char)c;

    if (lead < 0x80) return 1;            // 0xxxxxxx
    if ((lead >> 5) == 0x06) return 2;    // 110xxxxx
    if ((lead >> 4) == 0x0E) return 3;    // 1110xxxx
    if ((lead >> 3) == 0x1E) return 4;    // 11110xxx
    return 1;
}
|
|
52
62
|
|
|
53
63
|
static int unicode_is_letter(uint32_t cp) {
|
|
54
|
-
// Basic Unicode letter detection (simplified)
|
|
55
64
|
return (cp >= 0x41 && cp <= 0x5A) || (cp >= 0x61 && cp <= 0x7A) ||
|
|
56
65
|
(cp >= 0xC0 && cp <= 0xD6) || (cp >= 0xD8 && cp <= 0xF6) ||
|
|
57
66
|
(cp >= 0xF8 && cp <= 0x2FF) || (cp >= 0x370 && cp <= 0x37D) ||
|
|
@@ -68,224 +77,243 @@ static int unicode_is_number(uint32_t cp) {
|
|
|
68
77
|
}
|
|
69
78
|
|
|
70
79
|
// Decodes one UTF-8 code point starting at `s`.
//
// s   - NUL-terminated byte string; s[0] is the byte to decode.
// len - out: number of bytes consumed (always >= 1).
//
// Returns the decoded code point. If the lead byte is invalid, or any of the
// continuation bytes it announces is missing or malformed (including hitting
// the NUL terminator), the lead byte itself is returned with *len = 1.
// Continuation bytes are checked in order and decoding stops at the first bad
// one, so the terminator is never read past.
static uint32_t unicode_cpt_from_utf8(const char *s, size_t *len) {
    const unsigned char *b = (const unsigned char *)s;
    const unsigned char lead = b[0];
    size_t need;
    uint32_t cp;

    if (lead < 0x80) { *len = 1; return lead; }

    if ((lead & 0xE0) == 0xC0)      { need = 2; cp = lead & 0x1F; }
    else if ((lead & 0xF0) == 0xE0) { need = 3; cp = lead & 0x0F; }
    else if ((lead & 0xF8) == 0xF0) { need = 4; cp = lead & 0x07; }
    else { *len = 1; return lead; }

    for (size_t i = 1; i < need; i++) {
        // A NUL terminator fails this test too, which bounds the read.
        if ((b[i] & 0xC0) != 0x80) { *len = 1; return lead; }
        cp = (cp << 6) | (uint32_t)(b[i] & 0x3F);
    }

    *len = need;
    return cp;
}
|
|
100
88
|
|
|
101
89
|
/* ------------------------------------------------------------------------- */
|
|
102
|
-
//
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
if (text[0] == '\r') return 1;
|
|
132
|
-
} else if (p[0] == '\\' && p[1] == 'n') {
|
|
133
|
-
if (text[0] == '\n') return 1;
|
|
134
|
-
} else if (p[0] == '.' && p[1] == '*') {
|
|
135
|
-
// Match anything
|
|
136
|
-
return 1;
|
|
137
|
-
} else if (isalnum(p[0]) || ispunct(p[0])) {
|
|
138
|
-
// Match literal character
|
|
139
|
-
if (text[0] == p[0]) return 1;
|
|
140
|
-
}
|
|
90
|
+
// Pre-tokenizer (GPT-2/Llama style, replaces broken regex)
|
|
91
|
+
#define CHAR_CLASS_SPACE 0
|
|
92
|
+
#define CHAR_CLASS_LETTER 1
|
|
93
|
+
#define CHAR_CLASS_NUMBER 2
|
|
94
|
+
#define CHAR_CLASS_NEWLINE 3
|
|
95
|
+
#define CHAR_CLASS_OTHER 4
|
|
96
|
+
|
|
97
|
+
static int get_char_class(uint32_t cp) {
|
|
98
|
+
if (unicode_is_letter(cp)) return CHAR_CLASS_LETTER;
|
|
99
|
+
if (unicode_is_number(cp)) return CHAR_CLASS_NUMBER;
|
|
100
|
+
if (cp == '\n' || cp == '\r') return CHAR_CLASS_NEWLINE;
|
|
101
|
+
if (cp == ' ' || cp == '\t') return CHAR_CLASS_SPACE;
|
|
102
|
+
return CHAR_CLASS_OTHER;
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
// Reports whether an English-style contraction suffix starts at text[pos]:
// an apostrophe followed by one of s, t, r, v, m, l, d (case-insensitive).
//
// The apostrophe may be ASCII '\'' (1 byte) or a curly quote U+2019 / U+2018
// encoded as E2 80 99 / E2 80 98 (3 bytes).
//
// text     - buffer being scanned.
// pos      - byte offset of the candidate apostrophe.
// text_len - number of valid bytes in `text`.
//
// Returns 1 on a match, 0 otherwise (including pos >= text_len).
//
// All bytes are compared as unsigned char: comparing a plain `char` against
// 0x80/0x99 can never succeed on platforms where `char` is signed.
static int is_contraction(const char *text, size_t pos, size_t text_len) {
    if (pos >= text_len) return 0;

    const unsigned char *b = (const unsigned char *)text + pos;
    const size_t avail = text_len - pos;
    size_t quote_len;

    if (b[0] == '\'') {
        quote_len = 1;
    } else if (b[0] == 0xE2 && avail > 2 && b[1] == 0x80 &&
               (b[2] == 0x99 || b[2] == 0x98)) {
        quote_len = 3;
    } else {
        return 0;
    }

    // The letter after the apostrophe must lie inside the buffer.
    if (avail <= quote_len) return 0;

    switch (tolower(b[quote_len])) {
        case 's': case 't': case 'r': case 'v':
        case 'm': case 'l': case 'd':
            return 1;
        default:
            return 0;
    }
}
|
|
144
122
|
|
|
145
|
-
static
|
|
123
|
+
// Byte length of the contraction suffix at text[pos]: 2 when it starts with
// an ASCII apostrophe (apostrophe + letter), otherwise 4 (a 3-byte quote +
// letter). The caller is expected to have confirmed the match already.
static size_t contraction_len(const char *text, size_t pos) {
    return ((unsigned char)text[pos] == '\'') ? 2 : 4;
}
|
|
128
|
+
|
|
129
|
+
static char** pre_tokenize(const char *text, int *num_words) {
|
|
146
130
|
char **words = NULL;
|
|
147
|
-
int word_count = 0;
|
|
148
|
-
int word_capacity = 0;
|
|
149
|
-
|
|
131
|
+
int word_count = 0, word_capacity = 0;
|
|
150
132
|
size_t text_len = strlen(text);
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
133
|
+
|
|
134
|
+
if (text_len == 0) {
|
|
135
|
+
*num_words = 0;
|
|
136
|
+
return NULL;
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
#define ADD_WORD(ptr, len) do { \
|
|
140
|
+
char *w = malloc((len) + 1); \
|
|
141
|
+
if (!w) goto error; \
|
|
142
|
+
memcpy(w, ptr, len); \
|
|
143
|
+
w[len] = '\0'; \
|
|
144
|
+
if (word_count >= word_capacity) { \
|
|
145
|
+
word_capacity = word_capacity ? word_capacity * 2 : 16; \
|
|
146
|
+
char **nw = realloc(words, word_capacity * sizeof(char*)); \
|
|
147
|
+
if (!nw) { free(w); goto error; } \
|
|
148
|
+
words = nw; \
|
|
149
|
+
} \
|
|
150
|
+
words[word_count++] = w; \
|
|
151
|
+
} while(0)
|
|
152
|
+
|
|
153
|
+
size_t i = 0;
|
|
154
|
+
while (i < text_len) {
|
|
155
|
+
size_t char_len;
|
|
156
|
+
uint32_t cp = unicode_cpt_from_utf8(text + i, &char_len);
|
|
157
|
+
int cls = get_char_class(cp);
|
|
158
|
+
|
|
159
|
+
if (cls == CHAR_CLASS_NEWLINE) {
|
|
160
|
+
ADD_WORD(text + i, char_len);
|
|
161
|
+
i += char_len;
|
|
162
|
+
continue;
|
|
161
163
|
}
|
|
162
|
-
|
|
163
|
-
if (
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
break;
|
|
164
|
+
|
|
165
|
+
if (cls == CHAR_CLASS_SPACE) {
|
|
166
|
+
size_t space_start = i;
|
|
167
|
+
while (i < text_len) {
|
|
168
|
+
size_t cl;
|
|
169
|
+
uint32_t c = unicode_cpt_from_utf8(text + i, &cl);
|
|
170
|
+
int cc = get_char_class(c);
|
|
171
|
+
if (cc != CHAR_CLASS_SPACE) break;
|
|
172
|
+
i += cl;
|
|
170
173
|
}
|
|
171
|
-
|
|
174
|
+
if (i >= text_len) break;
|
|
175
|
+
size_t space_len = i - space_start;
|
|
176
|
+
ADD_WORD(text + space_start, space_len);
|
|
177
|
+
continue;
|
|
172
178
|
}
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
179
|
+
|
|
180
|
+
size_t start = i;
|
|
181
|
+
i += char_len;
|
|
182
|
+
|
|
183
|
+
while (i < text_len) {
|
|
184
|
+
size_t cl;
|
|
185
|
+
uint32_t c = unicode_cpt_from_utf8(text + i, &cl);
|
|
186
|
+
int ccls = get_char_class(c);
|
|
187
|
+
|
|
188
|
+
if (is_contraction(text, i, text_len)) {
|
|
189
|
+
size_t clen = contraction_len(text, i);
|
|
190
|
+
i += clen;
|
|
191
|
+
continue;
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
if (ccls != cls) break;
|
|
195
|
+
if (cls == CHAR_CLASS_NUMBER) {
|
|
196
|
+
int digits = 0;
|
|
197
|
+
size_t check = start;
|
|
198
|
+
while (check < i) {
|
|
199
|
+
size_t dl;
|
|
200
|
+
uint32_t dc = unicode_cpt_from_utf8(text + check, &dl);
|
|
201
|
+
if (get_char_class(dc) == CHAR_CLASS_NUMBER) digits++;
|
|
202
|
+
check += dl;
|
|
192
203
|
}
|
|
193
|
-
|
|
204
|
+
if (digits >= 3) break;
|
|
194
205
|
}
|
|
206
|
+
i += cl;
|
|
195
207
|
}
|
|
196
|
-
|
|
197
|
-
|
|
208
|
+
|
|
209
|
+
ADD_WORD(text + start, i - start);
|
|
198
210
|
}
|
|
199
|
-
|
|
211
|
+
|
|
212
|
+
#undef ADD_WORD
|
|
200
213
|
*num_words = word_count;
|
|
201
214
|
return words;
|
|
215
|
+
|
|
216
|
+
error:
|
|
217
|
+
for (int j = 0; j < word_count; j++) free(words[j]);
|
|
218
|
+
free(words);
|
|
219
|
+
*num_words = 0;
|
|
220
|
+
return NULL;
|
|
202
221
|
}
|
|
203
222
|
|
|
204
223
|
/* ------------------------------------------------------------------------- */
|
|
205
|
-
// BPE merge
|
|
206
|
-
typedef struct {
|
|
224
|
+
// BPE merge structures with hash table for O(1) lookup
|
|
225
|
+
// One entry in a BPEMergeTable bucket chain: a (left, right) pair and its rank.
typedef struct MergeHashNode {
    char *left;                  // heap copy of the left piece (strdup'd in bpe_merge_table_add)
    char *right;                 // heap copy of the right piece (strdup'd in bpe_merge_table_add)
    int rank;                    // value bpe_merge_rank returns when both pieces match
    struct MergeHashNode *next;  // next entry in the same bucket, or NULL
} MergeHashNode;
|
|
212
231
|
|
|
213
232
|
// Chained hash table of merge pairs, keyed by merge_hash(left, right).
typedef struct {
    MergeHashNode **table;  // bucket heads; set to NULL by bpe_merge_table_free
    int table_size;         // number of buckets (MERGE_HASH_SIZE after init)
    int num_merges;         // number of entries inserted so far
} BPEMergeTable;
|
|
218
237
|
|
|
238
|
+
// 64-bit FNV-1a style hash over the bytes of `left`, a single ' ' separator,
// then the bytes of `right`. Both arguments must be NUL-terminated.
static uint64_t merge_hash(const char *left, const char *right) {
    const uint64_t prime = 0x100000001b3ULL;
    uint64_t acc = 0xcbf29ce484222325ULL;
    const unsigned char *p;

    for (p = (const unsigned char *)left; *p != 0; p++) {
        acc = (acc ^ (uint64_t)*p) * prime;
    }

    // The separator keeps ("ab", "c") and ("a", "bc") from hashing the same bytes.
    acc = (acc ^ (uint64_t)' ') * prime;

    for (p = (const unsigned char *)right; *p != 0; p++) {
        acc = (acc ^ (uint64_t)*p) * prime;
    }
    return acc;
}
|
|
246
|
+
|
|
219
247
|
// Prepares an empty merge table with MERGE_HASH_SIZE zeroed buckets.
// If the bucket allocation fails, table->table is left NULL.
static void bpe_merge_table_init(BPEMergeTable *table) {
    table->num_merges = 0;
    table->table_size = MERGE_HASH_SIZE;
    table->table = calloc((size_t)table->table_size, sizeof *table->table);
}
|
|
224
252
|
|
|
225
|
-
static void bpe_merge_table_add(BPEMergeTable *table, const char *left, const char *right,
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
merge->rank = rank;
|
|
253
|
+
// Inserts the pair (left, right) with the given rank at the head of its
// bucket chain. Both strings are copied; the table owns the copies.
//
// The insert is silently skipped when the table has no bucket array (e.g.
// the allocation in bpe_merge_table_init failed), when either string is
// NULL, or when any allocation fails. In every skipped case nothing is
// leaked and the table is left unchanged.
static void bpe_merge_table_add(BPEMergeTable *table, const char *left, const char *right, int rank) {
    if (!table->table || table->table_size <= 0 || !left || !right) return;

    MergeHashNode *n = malloc(sizeof *n);
    if (!n) return;

    n->left = strdup(left);
    n->right = strdup(right);
    if (!n->left || !n->right) {
        // A node with a NULL key would crash strcmp() in bpe_merge_rank.
        free(n->left);
        free(n->right);
        free(n);
        return;
    }
    n->rank = rank;

    uint64_t h = merge_hash(left, right) % (uint64_t)table->table_size;
    n->next = table->table[h];
    table->table[h] = n;
    table->num_merges++;
}
|
|
237
264
|
|
|
238
265
|
// Releases every node (including its two string copies) and the bucket
// array, then resets the table to an empty state. Safe to call on a table
// whose bucket array is already NULL, so calling it twice is harmless.
static void bpe_merge_table_free(BPEMergeTable *table) {
    if (!table->table) return;

    for (int i = 0; i < table->table_size; i++) {
        MergeHashNode *n = table->table[i];
        while (n) {
            MergeHashNode *next = n->next;  // read the link before freeing the node
            free(n->left);
            free(n->right);
            free(n);
            n = next;
        }
    }

    free(table->table);
    table->table = NULL;
    // Keep the count consistent with the (now empty) table.
    table->num_merges = 0;
}
|
|
248
280
|
|
|
249
281
|
static int bpe_merge_rank(const BPEMergeTable *table, const char *left, const char *right) {
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
282
|
+
uint64_t h = merge_hash(left, right) % table->table_size;
|
|
283
|
+
MergeHashNode *n = table->table[h];
|
|
284
|
+
while (n) {
|
|
285
|
+
if (strcmp(n->left, left) == 0 && strcmp(n->right, right) == 0)
|
|
286
|
+
return n->rank;
|
|
287
|
+
n = n->next;
|
|
254
288
|
}
|
|
255
289
|
return -1;
|
|
256
290
|
}
|
|
257
291
|
|
|
258
|
-
static char* bpe_merge(const BPEMergeTable *table, const char *left, const char *right) {
|
|
259
|
-
for (int i = 0; i < table->num_merges; i++) {
|
|
260
|
-
if (strcmp(table->merges[i].left, left) == 0 && strcmp(table->merges[i].right, right) == 0) {
|
|
261
|
-
return table->merges[i].merged;
|
|
262
|
-
}
|
|
263
|
-
}
|
|
264
|
-
return NULL;
|
|
265
|
-
}
|
|
266
|
-
|
|
267
292
|
/* ------------------------------------------------------------------------- */
|
|
268
|
-
// BPE tokenization
|
|
293
|
+
// BPE tokenization (correct iterative algorithm)
|
|
269
294
|
// One piece of the word being tokenized, held in an index-linked list.
typedef struct {
    const char *text;  // the word this piece belongs to (not owned)
    int start, end;    // byte range [start, end) of the piece inside the word
    int prev, next;    // neighbouring symbol indices; next == -1 marks the last piece
    int used;          // cleared when the piece has been merged into its left neighbour
} BPESymbol;
|
|
277
300
|
|
|
278
|
-
static
|
|
279
|
-
|
|
301
|
+
static int text_to_id(void *vocab_data, const char *text);
|
|
302
|
+
|
|
303
|
+
static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word,
|
|
304
|
+
void *vocab_data, int *token_ids, int *num_tokens) {
|
|
280
305
|
int word_len = strlen(word);
|
|
306
|
+
if (word_len == 0) return;
|
|
307
|
+
|
|
281
308
|
int num_symbols = 0;
|
|
282
309
|
BPESymbol *symbols = malloc(word_len * sizeof(BPESymbol));
|
|
283
|
-
|
|
284
|
-
|
|
310
|
+
if (!symbols) return;
|
|
311
|
+
|
|
285
312
|
int offset = 0;
|
|
286
313
|
while (offset < word_len) {
|
|
287
314
|
int char_len = unicode_len_utf8(word[offset]);
|
|
288
|
-
|
|
315
|
+
if (offset + char_len > word_len) char_len = word_len - offset;
|
|
316
|
+
symbols[num_symbols].text = word;
|
|
289
317
|
symbols[num_symbols].start = offset;
|
|
290
318
|
symbols[num_symbols].end = offset + char_len;
|
|
291
319
|
symbols[num_symbols].prev = num_symbols - 1;
|
|
@@ -294,110 +322,75 @@ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int
|
|
|
294
322
|
offset += char_len;
|
|
295
323
|
num_symbols++;
|
|
296
324
|
}
|
|
297
|
-
|
|
325
|
+
|
|
326
|
+
if (num_symbols > 0) symbols[num_symbols - 1].next = -1;
|
|
327
|
+
|
|
298
328
|
if (num_symbols <= 1) {
|
|
299
|
-
// Single character, just tokenize it
|
|
300
329
|
int id = text_to_id(vocab_data, word);
|
|
301
|
-
if (id != -1)
|
|
302
|
-
token_ids[*num_tokens] = id;
|
|
303
|
-
(*num_tokens)++;
|
|
304
|
-
}
|
|
330
|
+
if (id != -1) token_ids[(*num_tokens)++] = id;
|
|
305
331
|
free(symbols);
|
|
306
332
|
return;
|
|
307
333
|
}
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
int
|
|
312
|
-
|
|
313
|
-
int
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
bigrams[num_bigrams].rank = rank;
|
|
335
|
-
num_bigrams++;
|
|
336
|
-
}
|
|
337
|
-
|
|
338
|
-
free(left_str);
|
|
339
|
-
free(right_str);
|
|
340
|
-
}
|
|
341
|
-
}
|
|
342
|
-
|
|
343
|
-
// Sort bigrams by rank (lower rank = higher priority)
|
|
344
|
-
for (int i = 0; i < num_bigrams - 1; i++) {
|
|
345
|
-
for (int j = i+1; j < num_bigrams; j++) {
|
|
346
|
-
if (bigrams[i].rank > bigrams[j].rank) {
|
|
347
|
-
Bigram temp = bigrams[i];
|
|
348
|
-
bigrams[i] = bigrams[j];
|
|
349
|
-
bigrams[j] = temp;
|
|
334
|
+
|
|
335
|
+
while (1) {
|
|
336
|
+
int best_rank = INT_MAX;
|
|
337
|
+
int best_idx = -1;
|
|
338
|
+
|
|
339
|
+
int idx = 0;
|
|
340
|
+
while (idx != -1) {
|
|
341
|
+
int next = symbols[idx].next;
|
|
342
|
+
if (next != -1 && symbols[idx].used && symbols[next].used) {
|
|
343
|
+
int left_len = symbols[idx].end - symbols[idx].start;
|
|
344
|
+
int right_len = symbols[next].end - symbols[next].start;
|
|
345
|
+
char *left_str = malloc(left_len + 1);
|
|
346
|
+
char *right_str = malloc(right_len + 1);
|
|
347
|
+
if (left_str && right_str) {
|
|
348
|
+
memcpy(left_str, word + symbols[idx].start, left_len);
|
|
349
|
+
left_str[left_len] = '\0';
|
|
350
|
+
memcpy(right_str, word + symbols[next].start, right_len);
|
|
351
|
+
right_str[right_len] = '\0';
|
|
352
|
+
int rank = bpe_merge_rank(merges, left_str, right_str);
|
|
353
|
+
if (rank != -1 && rank < best_rank) {
|
|
354
|
+
best_rank = rank;
|
|
355
|
+
best_idx = idx;
|
|
356
|
+
}
|
|
357
|
+
}
|
|
358
|
+
free(left_str);
|
|
359
|
+
free(right_str);
|
|
350
360
|
}
|
|
361
|
+
idx = symbols[idx].next;
|
|
351
362
|
}
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
if (
|
|
361
|
-
|
|
362
|
-
// Merge right into left
|
|
363
|
-
symbols[left].end = symbols[right].end;
|
|
364
|
-
symbols[left].next = symbols[right].next;
|
|
365
|
-
merged[right] = 1;
|
|
366
|
-
|
|
367
|
-
// Update next symbol's prev
|
|
368
|
-
if (symbols[right].next < num_symbols) {
|
|
369
|
-
symbols[symbols[right].next].prev = left;
|
|
363
|
+
|
|
364
|
+
if (best_idx == -1) break;
|
|
365
|
+
|
|
366
|
+
int right_idx = symbols[best_idx].next;
|
|
367
|
+
symbols[best_idx].end = symbols[right_idx].end;
|
|
368
|
+
symbols[best_idx].next = symbols[right_idx].next;
|
|
369
|
+
symbols[right_idx].used = 0;
|
|
370
|
+
|
|
371
|
+
if (symbols[right_idx].next != -1) {
|
|
372
|
+
symbols[symbols[right_idx].next].prev = best_idx;
|
|
370
373
|
}
|
|
371
374
|
}
|
|
372
|
-
|
|
373
|
-
// Collect final tokens
|
|
375
|
+
|
|
374
376
|
for (int i = 0; i < num_symbols; i++) {
|
|
375
|
-
if (
|
|
376
|
-
|
|
377
|
-
char *substr = malloc(
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
(*num_tokens)++;
|
|
385
|
-
} else {
|
|
386
|
-
// Unknown token - use byte-level fallback
|
|
387
|
-
// For simplicity, we'll use space as a placeholder
|
|
388
|
-
// In a full implementation, you'd encode bytes individually
|
|
377
|
+
if (symbols[i].used) {
|
|
378
|
+
int len = symbols[i].end - symbols[i].start;
|
|
379
|
+
char *substr = malloc(len + 1);
|
|
380
|
+
if (substr) {
|
|
381
|
+
memcpy(substr, word + symbols[i].start, len);
|
|
382
|
+
substr[len] = '\0';
|
|
383
|
+
int id = text_to_id(vocab_data, substr);
|
|
384
|
+
if (id != -1) token_ids[(*num_tokens)++] = id;
|
|
385
|
+
free(substr);
|
|
389
386
|
}
|
|
390
|
-
|
|
391
|
-
free(substr);
|
|
392
387
|
}
|
|
393
388
|
}
|
|
394
|
-
|
|
395
|
-
free(bigrams);
|
|
396
|
-
free(merged);
|
|
397
389
|
free(symbols);
|
|
398
390
|
}
|
|
399
391
|
|
|
400
392
|
/* ------------------------------------------------------------------------- */
|
|
393
|
+
// GGUF parsing
|
|
401
394
|
static int safe_advance(uint8_t **p, uint8_t *end, size_t sz) {
|
|
402
395
|
if (*p + sz > end) return 0;
|
|
403
396
|
*p += sz;
|
|
@@ -405,14 +398,14 @@ static int safe_advance(uint8_t **p, uint8_t *end, size_t sz) {
|
|
|
405
398
|
}
|
|
406
399
|
|
|
407
400
|
// Reads a 32-bit value in host byte order from *p and advances *p by 4.
// Returns 0 without reading when safe_advance rejects the step; callers
// cannot tell that apart from a stored 0.
static uint32_t rd32(uint8_t **p, uint8_t *end) {
    if (!safe_advance(p, end, 4)) return 0;

    uint32_t value;
    // memcpy from the bytes just stepped over: the source may be unaligned.
    memcpy(&value, *p - sizeof value, sizeof value);
    return value;
}
|
|
413
406
|
|
|
414
407
|
static uint64_t rd64(uint8_t **p, uint8_t *end) {
|
|
415
|
-
uint64_t v
|
|
408
|
+
uint64_t v;
|
|
416
409
|
if (!safe_advance(p, end, 8)) return 0;
|
|
417
410
|
memcpy(&v, *p - 8, 8);
|
|
418
411
|
return v;
|
|
@@ -423,9 +416,9 @@ static char *rdstr(uint8_t **p, uint8_t *end) {
|
|
|
423
416
|
uint64_t len;
|
|
424
417
|
memcpy(&len, *p, 8);
|
|
425
418
|
*p += 8;
|
|
426
|
-
if (len == 0 || len > (1
|
|
419
|
+
if (len == 0 || len > (1<<20)) return NULL;
|
|
427
420
|
if (*p + len > end) return NULL;
|
|
428
|
-
char *s = malloc(len
|
|
421
|
+
char *s = malloc(len+1);
|
|
429
422
|
if (!s) return NULL;
|
|
430
423
|
memcpy(s, *p, len);
|
|
431
424
|
s[len] = '\0';
|
|
@@ -436,52 +429,56 @@ static char *rdstr(uint8_t **p, uint8_t *end) {
|
|
|
436
429
|
static void align_to_32(uint8_t **p, uint8_t *end, uint8_t *base) {
|
|
437
430
|
size_t off = *p - base;
|
|
438
431
|
size_t aligned = (off + GGUF_ALIGN - 1) & ~(GGUF_ALIGN - 1);
|
|
439
|
-
if (base + aligned <= end)
|
|
440
|
-
*p = base + aligned;
|
|
432
|
+
if (base + aligned <= end) *p = base + aligned;
|
|
441
433
|
}
|
|
442
434
|
|
|
443
435
|
/* ------------------------------------------------------------------------- */
|
|
436
|
+
// Hash table for vocabulary
|
|
444
437
|
// One entry in the vocabulary hash table's bucket chain.
typedef struct HashNode {
    char *key;              // token text; stored by pointer in hset, not copied
    int id;                 // id returned by hget when the key matches
    struct HashNode *next;  // next entry in the same bucket, or NULL
} HashNode;
|
|
449
442
|
|
|
450
443
|
// State for one loaded embedding model.
// NOTE(review): the loader that fills this struct is outside this chunk;
// field notes marked "presumably" are inferred from names - confirm.
typedef struct {
    int vocab_size;               // presumably the number of entries in `tokens`
    int dim;                      // presumably the embedding width
    char **tokens;                // presumably token strings indexed by id
    void *mapped;                 // presumably the mapping returned by map_file
    size_t mapped_size;           // presumably the byte length of `mapped`
    HashNode **table;             // vocabulary buckets indexed by vocab_hash(); hget treats NULL as "not found"
    BPEMergeTable merges;         // merge-pair ranks
    int unknown_token_id;
    int bos_token_id;
    int eos_token_id;
    int vocab_type;               // presumably an enum llama_vocab_type value
    char space_marker[8];         // presumably the byte sequence standing in for a space
    int space_marker_len;         // presumably the number of bytes used in space_marker
    const void *raw_tensor_data;  // presumably points into `mapped` (not owned)
    int tensor_type;              // presumably an enum ggml_type value
    size_t row_bytes;             // presumably bytes per stored tensor row
    int need_transpose;
    uint64_t raw_dim0, raw_dim1;  // presumably the tensor dimensions as stored in the file
    int normalize;                // presumably an enum normalize_type value
} EmbedModel;
|
|
470
464
|
|
|
471
465
|
// Wrapper holding a single EmbedModel pointer.
// NOTE(review): presumably the payload of the Ruby object - confirm against
// the binding code, which is outside this chunk.
typedef struct {
    EmbedModel *model;
} ruby_embedder;
|
|
474
468
|
|
|
475
|
-
static
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
469
|
+
static uint64_t vocab_hash(const char *s) {
|
|
470
|
+
uint64_t h = 0xcbf29ce484222325ULL;
|
|
471
|
+
while (*s) {
|
|
472
|
+
h ^= (uint64_t)(unsigned char)*s++;
|
|
473
|
+
h *= 0x100000001b3ULL;
|
|
474
|
+
}
|
|
479
475
|
return h % HASH_SIZE;
|
|
480
476
|
}
|
|
481
477
|
|
|
482
478
|
static void hset(EmbedModel *m, char *k, int id) {
|
|
483
|
-
|
|
479
|
+
uint64_t h = vocab_hash(k);
|
|
484
480
|
HashNode *n = malloc(sizeof(*n));
|
|
481
|
+
if (!n) return;
|
|
485
482
|
n->key = k;
|
|
486
483
|
n->id = id;
|
|
487
484
|
n->next = m->table[h];
|
|
@@ -489,7 +486,8 @@ static void hset(EmbedModel *m, char *k, int id) {
|
|
|
489
486
|
}
|
|
490
487
|
|
|
491
488
|
static int hget(EmbedModel *m, const char *k) {
|
|
492
|
-
|
|
489
|
+
if (!k || !m->table) return -1;
|
|
490
|
+
HashNode *n = m->table[vocab_hash(k)];
|
|
493
491
|
while (n) {
|
|
494
492
|
if (strcmp(n->key, k) == 0) return n->id;
|
|
495
493
|
n = n->next;
|
|
@@ -498,48 +496,62 @@ static int hget(EmbedModel *m, const char *k) {
|
|
|
498
496
|
}
|
|
499
497
|
|
|
500
498
|
static int text_to_id(void *vocab_data, const char *text) {
|
|
501
|
-
|
|
502
|
-
return hget(m, text);
|
|
499
|
+
return hget((EmbedModel*)vocab_data, text);
|
|
503
500
|
}
|
|
504
501
|
|
|
505
502
|
/* ------------------------------------------------------------------------- */
|
|
503
|
+
// File mapping
|
|
506
504
|
// Maps the file at `path` read-only (private mapping).
//
// size - out: set to the file size once fstat succeeds.
//
// Returns the mapping, or NULL when the file cannot be opened or stat'ed,
// is empty, or cannot be mapped. The descriptor is closed in every case.
static void *map_file(const char *path, size_t *size) {
    const int fd = open(path, O_RDONLY);
    if (fd < 0) return NULL;

    void *data = NULL;
    struct stat st;
    if (fstat(fd, &st) == 0) {
        *size = st.st_size;
        if (*size != 0) {  // a zero-length file is reported as NULL
            void *m = mmap(NULL, *size, PROT_READ, MAP_PRIVATE, fd, 0);
            if (m != MAP_FAILED) data = m;
        }
    }

    close(fd);
    return data;
}
|
|
517
515
|
|
|
518
516
|
/* ------------------------------------------------------------------------- */
|
|
517
|
+
// FP16 conversion (corrected)
|
|
519
518
|
// Converts an IEEE 754 half-precision bit pattern to float.
// Zeros, subnormals, normals, infinities and NaNs are all handled; the sign
// bit is carried over unchanged.
static float fp16_to_fp32(uint16_t h) {
    const uint32_t sign_bit = (uint32_t)(h & 0x8000u) << 16;
    const uint32_t exp_field = (h >> 10) & 0x1Fu;
    uint32_t frac = h & 0x3FFu;
    uint32_t bits;

    if (exp_field == 0x1Fu) {
        // Inf / NaN: all-ones exponent, payload shifted into place.
        bits = sign_bit | 0x7F800000u | (frac << 13);
    } else if (exp_field != 0) {
        // Normal number: rebias the exponent from 15 to 127.
        bits = sign_bit | ((exp_field + 112u) << 23) | (frac << 13);
    } else if (frac == 0) {
        bits = sign_bit;  // signed zero
    } else {
        // Subnormal: shift until the implicit leading 1 appears, lowering
        // the float exponent once per shift.
        uint32_t biased = 113u;
        while ((frac & 0x400u) == 0) {
            frac <<= 1;
            biased--;
        }
        bits = sign_bit | (biased << 23) | ((frac & 0x3FFu) << 13);
    }

    float out;
    memcpy(&out, &bits, sizeof(out));
    return out;
}
|
|
533
543
|
|
|
534
544
|
/* ------------------------------------------------------------------------- */
|
|
535
|
-
|
|
536
|
-
|
|
545
|
+
// Block dequantization functions (correct sizes)
|
|
537
546
|
static void dequantize_row_q4_0(const void *vx, float *y, int k) {
|
|
538
|
-
const int nb = k /
|
|
547
|
+
const int nb = k / QK8_0;
|
|
539
548
|
const uint8_t *x = vx;
|
|
540
549
|
for (int i = 0; i < nb; i++) {
|
|
541
|
-
const
|
|
542
|
-
|
|
550
|
+
const uint8_t *block = x + i * 18;
|
|
551
|
+
uint16_t d16;
|
|
552
|
+
memcpy(&d16, block, 2);
|
|
553
|
+
const float d = fp16_to_fp32(d16);
|
|
554
|
+
const uint8_t *q = block + 2;
|
|
543
555
|
for (int j = 0; j < 32; j++) {
|
|
544
556
|
const int v = (q[j/2] >> (4*(j%2))) & 0x0F;
|
|
545
557
|
y[i*32 + j] = (v - 8.0f) * d;
|
|
@@ -548,12 +560,16 @@ static void dequantize_row_q4_0(const void *vx, float *y, int k) {
|
|
|
548
560
|
}
|
|
549
561
|
|
|
550
562
|
static void dequantize_row_q4_1(const void *vx, float *y, int k) {
|
|
551
|
-
const int nb = k /
|
|
563
|
+
const int nb = k / QK8_0;
|
|
552
564
|
const uint8_t *x = vx;
|
|
553
565
|
for (int i = 0; i < nb; i++) {
|
|
554
|
-
const
|
|
555
|
-
|
|
556
|
-
|
|
566
|
+
const uint8_t *block = x + i * 20;
|
|
567
|
+
uint16_t d16, m16;
|
|
568
|
+
memcpy(&d16, block, 2);
|
|
569
|
+
memcpy(&m16, block + 2, 2);
|
|
570
|
+
const float d = fp16_to_fp32(d16);
|
|
571
|
+
const float m = fp16_to_fp32(m16);
|
|
572
|
+
const uint8_t *q = block + 4;
|
|
557
573
|
for (int j = 0; j < 32; j++) {
|
|
558
574
|
const int v = (q[j/2] >> (4*(j%2))) & 0x0F;
|
|
559
575
|
y[i*32 + j] = v * d + m;
|
|
@@ -562,14 +578,16 @@ static void dequantize_row_q4_1(const void *vx, float *y, int k) {
|
|
|
562
578
|
}
|
|
563
579
|
|
|
564
580
|
static void dequantize_row_q5_0(const void *vx, float *y, int k) {
|
|
565
|
-
const int nb = k /
|
|
581
|
+
const int nb = k / QK8_0;
|
|
566
582
|
const uint8_t *x = vx;
|
|
567
583
|
for (int i = 0; i < nb; i++) {
|
|
568
|
-
const
|
|
569
|
-
|
|
570
|
-
|
|
584
|
+
const uint8_t *block = x + i * 22;
|
|
585
|
+
uint16_t d16;
|
|
586
|
+
memcpy(&d16, block, 2);
|
|
587
|
+
const float d = fp16_to_fp32(d16);
|
|
571
588
|
uint32_t qh32;
|
|
572
|
-
memcpy(&qh32,
|
|
589
|
+
memcpy(&qh32, block + 2, 4);
|
|
590
|
+
const uint8_t *ql = block + 6;
|
|
573
591
|
for (int j = 0; j < 32; j++) {
|
|
574
592
|
const uint8_t vh = (qh32 >> j) & 1;
|
|
575
593
|
const int v = ((ql[j/2] >> (4*(j%2))) & 0x0F) | (vh << 4);
|
|
@@ -579,15 +597,18 @@ static void dequantize_row_q5_0(const void *vx, float *y, int k) {
|
|
|
579
597
|
}
|
|
580
598
|
|
|
581
599
|
static void dequantize_row_q5_1(const void *vx, float *y, int k) {
|
|
582
|
-
const int nb = k /
|
|
600
|
+
const int nb = k / QK8_0;
|
|
583
601
|
const uint8_t *x = vx;
|
|
584
602
|
for (int i = 0; i < nb; i++) {
|
|
585
|
-
const
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
603
|
+
const uint8_t *block = x + i * 24;
|
|
604
|
+
uint16_t d16, m16;
|
|
605
|
+
memcpy(&d16, block, 2);
|
|
606
|
+
memcpy(&m16, block + 2, 2);
|
|
607
|
+
const float d = fp16_to_fp32(d16);
|
|
608
|
+
const float m = fp16_to_fp32(m16);
|
|
589
609
|
uint32_t qh32;
|
|
590
|
-
memcpy(&qh32,
|
|
610
|
+
memcpy(&qh32, block + 4, 4);
|
|
611
|
+
const uint8_t *ql = block + 8;
|
|
591
612
|
for (int j = 0; j < 32; j++) {
|
|
592
613
|
const uint8_t vh = (qh32 >> j) & 1;
|
|
593
614
|
const int v = ((ql[j/2] >> (4*(j%2))) & 0x0F) | (vh << 4);
|
|
@@ -597,11 +618,13 @@ static void dequantize_row_q5_1(const void *vx, float *y, int k) {
|
|
|
597
618
|
}
|
|
598
619
|
|
|
599
620
|
static void dequantize_row_q8_0(const void *vx, float *y, int k) {
|
|
600
|
-
const int nb = k /
|
|
621
|
+
const int nb = k / QK8_0;
|
|
601
622
|
const uint8_t *x = vx;
|
|
602
623
|
for (int i = 0; i < nb; i++) {
|
|
603
|
-
const
|
|
604
|
-
|
|
624
|
+
const uint8_t *block = x + i * 34;
|
|
625
|
+
float d;
|
|
626
|
+
memcpy(&d, block, 4);
|
|
627
|
+
const int8_t *q = (const int8_t*)(block + 4);
|
|
605
628
|
for (int j = 0; j < 32; j++) {
|
|
606
629
|
y[i*32 + j] = (float)q[j] * d;
|
|
607
630
|
}
|
|
@@ -609,191 +632,315 @@ static void dequantize_row_q8_0(const void *vx, float *y, int k) {
|
|
|
609
632
|
}
|
|
610
633
|
|
|
611
634
|
static void dequantize_row_q8_1(const void *vx, float *y, int k) {
|
|
612
|
-
const int nb = k /
|
|
635
|
+
const int nb = k / QK8_0;
|
|
613
636
|
const uint8_t *x = vx;
|
|
614
637
|
for (int i = 0; i < nb; i++) {
|
|
615
|
-
const
|
|
616
|
-
|
|
617
|
-
|
|
638
|
+
const uint8_t *block = x + i * 40;
|
|
639
|
+
float d, s;
|
|
640
|
+
memcpy(&d, block, 4);
|
|
641
|
+
memcpy(&s, block + 4, 4);
|
|
642
|
+
const int8_t *q = (const int8_t*)(block + 8);
|
|
618
643
|
for (int j = 0; j < 32; j++) {
|
|
619
644
|
y[i*32 + j] = (float)q[j] * d + s;
|
|
620
645
|
}
|
|
621
646
|
}
|
|
622
647
|
}
|
|
623
648
|
|
|
624
|
-
|
|
649
|
+
// K-quant scale helpers
|
|
650
|
+
// Unpacks the 6-bit scale and 6-bit min for sub-block `j` (0..7) from the
// 12-byte packed `q` array used by the K-quant formats.
//
// Packing (ggml K-quant layout):
//   j < 4 : scale = low 6 bits of q[j],   min = low 6 bits of q[j+4]
//   j >= 4: low 4 bits come from the two nibbles of q[j+4]; the top 2 bits
//           of the scale are the top 2 bits of q[j-4], and the top 2 bits
//           of the min are the top 2 bits of q[j].
//
// d - out: scale (0..63).
// m - out: min   (0..63).
static inline void get_scale_min_k4(int j, const uint8_t *q, uint8_t *d, uint8_t *m) {
    if (j < 4) {
        *d = q[j] & 63;
        *m = q[j + 4] & 63;
    } else {
        *d = (uint8_t)((q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4));
        *m = (uint8_t)((q[j + 4] >> 4) | ((q[j] >> 6) << 4));
    }
}
|
|
659
|
+
|
|
625
660
|
// Expands k / QK_K blocks of 84 bytes each from `vx` into `y`.
// As read here, a block is: fp16 d at +0, fp16 dmin at +2, 16 scale bytes at
// +4, 64 bytes of 2-bit values at +20. Each output is v * dl + ml, with dl
// and ml taken from the low and high nibble of one scale byte.
//
// NOTE(review): only scales[0..3] are read (one byte per 64 outputs), so 12
// of the 16 scale bytes are never used - confirm this matches the producer.
// NOTE(review): the field order (d/dmin first) and the `+ ml` sign should be
// checked against the reference Q2_K block definition - confirm.
static void dequantize_row_q2_K(const void *vx, float *y, int k) {
    const int nb = k / QK_K;
    const uint8_t *x = vx;
    for (int i = 0; i < nb; i++) {
        const uint8_t *block = x + i * 84;
        uint16_t d16, dmin16;
        memcpy(&d16, block, 2);
        memcpy(&dmin16, block + 2, 2);
        const float d = fp16_to_fp32(d16);
        const float min = fp16_to_fp32(dmin16);
        const uint8_t *scales = block + 4;
        const uint8_t *q = block + 20;
        for (int j = 0; j < QK_K; j += 64) {
            const float dl = d * (scales[j/64] & 0xF);
            const float ml = min * (scales[j/64] >> 4);
            for (int l = 0; l < 64; l++) {
                // Four 2-bit values per byte, lowest bits first.
                const int v = (q[(j+l)/4] >> (2*((j+l)%4))) & 0x03;
                y[i*QK_K + j + l] = v * dl + ml;
            }
        }
    }
}
|
|
645
682
|
|
|
646
683
|
static void dequantize_row_q3_K(const void *vx, float *y, int k) {
|
|
647
|
-
const int nb = k /
|
|
684
|
+
const int nb = k / QK_K;
|
|
648
685
|
const uint8_t *x = vx;
|
|
649
686
|
for (int i = 0; i < nb; i++) {
|
|
650
|
-
const
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
const
|
|
654
|
-
|
|
687
|
+
const uint8_t *block = x + i * 110;
|
|
688
|
+
uint16_t d16;
|
|
689
|
+
memcpy(&d16, block, 2);
|
|
690
|
+
const float d = fp16_to_fp32(d16);
|
|
691
|
+
const uint8_t *hmask = block + 2;
|
|
692
|
+
const uint8_t *q = block + 34;
|
|
693
|
+
const uint8_t *scales = block + 98;
|
|
694
|
+
for (int j = 0; j < QK_K; j += 64) {
|
|
655
695
|
const uint8_t ls1 = scales[j/64] & 0x1F;
|
|
656
|
-
const uint8_t ls2 = (scales[j/64] >>
|
|
657
|
-
const uint8_t
|
|
696
|
+
const uint8_t ls2 = (scales[j/64] >> 5) | ((scales[j/64 + 1] & 0x7) << 3);
|
|
697
|
+
const uint8_t ls3 = ((scales[j/64 + 1] >> 3) & 0x1F);
|
|
698
|
+
const uint8_t ls4 = (scales[j/64 + 1] >> 8);
|
|
658
699
|
for (int l = 0; l < 64; l++) {
|
|
659
700
|
int v = (q[(j+l)/2] >> (4*((j+l)%2))) & 0x0F;
|
|
660
|
-
const int bit = (
|
|
701
|
+
const int bit = (hmask[(j+l)/8] >> ((j+l)%8)) & 1;
|
|
661
702
|
v |= bit << 4;
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
703
|
+
float ls;
|
|
704
|
+
if (l < 16) ls = ls1;
|
|
705
|
+
else if (l < 32) ls = ls2;
|
|
706
|
+
else if (l < 48) ls = ls3;
|
|
707
|
+
else ls = ls4;
|
|
708
|
+
y[i*QK_K + j + l] = (v - 32.0f) * d * ls;
|
|
665
709
|
}
|
|
666
710
|
}
|
|
667
711
|
}
|
|
668
712
|
}
|
|
669
713
|
|
|
670
714
|
static void dequantize_row_q4_K(const void *vx, float *y, int k) {
|
|
671
|
-
const int nb = k /
|
|
715
|
+
const int nb = k / QK_K;
|
|
672
716
|
const uint8_t *x = vx;
|
|
673
717
|
for (int i = 0; i < nb; i++) {
|
|
674
|
-
const
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
718
|
+
const uint8_t *block = x + i * 144;
|
|
719
|
+
uint16_t d16, dmin16;
|
|
720
|
+
memcpy(&d16, block, 2);
|
|
721
|
+
memcpy(&dmin16, block + 2, 2);
|
|
722
|
+
const float d = fp16_to_fp32(d16);
|
|
723
|
+
const float min = fp16_to_fp32(dmin16);
|
|
724
|
+
const uint8_t *scales = block + 4;
|
|
725
|
+
const uint8_t *q = block + 16;
|
|
726
|
+
int is = 0;
|
|
727
|
+
for (int j = 0; j < QK_K; j += 64) {
|
|
728
|
+
uint8_t sc, m;
|
|
729
|
+
get_scale_min_k4(is, scales, &sc, &m);
|
|
730
|
+
float d1 = d * sc;
|
|
731
|
+
float m1 = min * m;
|
|
732
|
+
get_scale_min_k4(is + 1, scales, &sc, &m);
|
|
733
|
+
float d2 = d * sc;
|
|
734
|
+
float m2 = min * m;
|
|
681
735
|
for (int l = 0; l < 32; l++) {
|
|
682
|
-
|
|
683
|
-
const float dl = d * (ls - 32);
|
|
684
|
-
const float ml = m * (ms - 2);
|
|
685
|
-
y[i*256 + j + l] = v * dl + ml;
|
|
736
|
+
y[i*QK_K + j + l] = d1 * (q[l] & 0xF) - m1;
|
|
686
737
|
}
|
|
738
|
+
for (int l = 0; l < 32; l++) {
|
|
739
|
+
y[i*QK_K + j + 32 + l] = d2 * (q[l] >> 4) - m2;
|
|
740
|
+
}
|
|
741
|
+
q += 32;
|
|
742
|
+
is += 2;
|
|
687
743
|
}
|
|
688
744
|
}
|
|
689
745
|
}
|
|
690
746
|
|
|
691
747
|
static void dequantize_row_q5_K(const void *vx, float *y, int k) {
|
|
692
|
-
const int nb = k /
|
|
748
|
+
const int nb = k / QK_K;
|
|
693
749
|
const uint8_t *x = vx;
|
|
694
750
|
for (int i = 0; i < nb; i++) {
|
|
695
|
-
const
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
const
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
751
|
+
const uint8_t *block = x + i * 176;
|
|
752
|
+
uint16_t d16, dmin16;
|
|
753
|
+
memcpy(&d16, block, 2);
|
|
754
|
+
memcpy(&dmin16, block + 2, 2);
|
|
755
|
+
const float d = fp16_to_fp32(d16);
|
|
756
|
+
const float min = fp16_to_fp32(dmin16);
|
|
757
|
+
const uint8_t *scales = block + 4;
|
|
758
|
+
const uint8_t *qh = block + 16;
|
|
759
|
+
const uint8_t *ql = block + 48;
|
|
760
|
+
int is = 0;
|
|
761
|
+
for (int j = 0; j < QK_K; j += 64) {
|
|
762
|
+
uint8_t sc, m;
|
|
763
|
+
get_scale_min_k4(is, scales, &sc, &m);
|
|
764
|
+
float d1 = d * sc;
|
|
765
|
+
float m1 = min * m;
|
|
766
|
+
get_scale_min_k4(is + 1, scales, &sc, &m);
|
|
767
|
+
float d2 = d * sc;
|
|
768
|
+
float m2 = min * m;
|
|
703
769
|
for (int l = 0; l < 32; l++) {
|
|
704
|
-
int
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
770
|
+
int vh = (qh[j/64 * 4 + l/8] >> (l%8)) & 1;
|
|
771
|
+
int v = (ql[l] & 0xF) | (vh << 4);
|
|
772
|
+
y[i*QK_K + j + l] = d1 * v - m1;
|
|
773
|
+
}
|
|
774
|
+
for (int l = 0; l < 32; l++) {
|
|
775
|
+
int vh = (qh[j/64 * 4 + 4 + l/8] >> (l%8)) & 1;
|
|
776
|
+
int v = (ql[l] >> 4) | (vh << 4);
|
|
777
|
+
y[i*QK_K + j + 32 + l] = d2 * v - m2;
|
|
710
778
|
}
|
|
779
|
+
ql += 32;
|
|
780
|
+
is += 2;
|
|
711
781
|
}
|
|
712
782
|
}
|
|
713
783
|
}
|
|
714
784
|
|
|
715
785
|
static void dequantize_row_q6_K(const void *vx, float *y, int k) {
|
|
716
|
-
const int nb = k /
|
|
786
|
+
const int nb = k / QK_K;
|
|
717
787
|
const uint8_t *x = vx;
|
|
718
788
|
for (int i = 0; i < nb; i++) {
|
|
719
|
-
const
|
|
720
|
-
const uint8_t *
|
|
721
|
-
const uint8_t *qh =
|
|
722
|
-
const
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
v
|
|
729
|
-
y[i*
|
|
789
|
+
const uint8_t *block = x + i * 210;
|
|
790
|
+
const uint8_t *ql = block;
|
|
791
|
+
const uint8_t *qh = block + 128;
|
|
792
|
+
const int8_t *scales = (const int8_t*)(block + 192);
|
|
793
|
+
uint16_t d16;
|
|
794
|
+
memcpy(&d16, block + 208, 2);
|
|
795
|
+
const float d = fp16_to_fp32(d16);
|
|
796
|
+
for (int j = 0; j < QK_K; j += 128) {
|
|
797
|
+
for (int l = 0; l < 32; l++) {
|
|
798
|
+
int v = (ql[j/2 + l] & 0xF) | (((qh[j/4 + l/2] >> ((l%2)*4)) & 0xF) << 4);
|
|
799
|
+
y[i*QK_K + j + l] = v * d * scales[j/128 * 8 + l/4];
|
|
800
|
+
}
|
|
801
|
+
for (int l = 0; l < 32; l++) {
|
|
802
|
+
int v = (ql[j/2 + 32 + l] >> 4) | (((qh[j/4 + 16 + l/2] >> ((l%2)*4)) & 0xF) << 4);
|
|
803
|
+
y[i*QK_K + j + 32 + l] = v * d * scales[j/128 * 8 + 8 + l/4];
|
|
804
|
+
}
|
|
805
|
+
for (int l = 0; l < 32; l++) {
|
|
806
|
+
int v = (ql[j/2 + 64 + l] & 0xF) | (((qh[j/4 + 32 + l/2] >> ((l%2)*4)) & 0xF) << 4);
|
|
807
|
+
y[i*QK_K + j + 64 + l] = v * d * scales[j/128 * 8 + 4 + l/4];
|
|
808
|
+
}
|
|
809
|
+
for (int l = 0; l < 32; l++) {
|
|
810
|
+
int v = (ql[j/2 + 96 + l] >> 4) | (((qh[j/4 + 48 + l/2] >> ((l%2)*4)) & 0xF) << 4);
|
|
811
|
+
y[i*QK_K + j + 96 + l] = v * d * scales[j/128 * 8 + 12 + l/4];
|
|
730
812
|
}
|
|
731
813
|
}
|
|
732
814
|
}
|
|
733
815
|
}
|
|
734
816
|
|
|
735
817
|
static void dequantize_row_q8_K(const void *vx, float *y, int k) {
|
|
736
|
-
const int nb = k /
|
|
818
|
+
const int nb = k / QK_K;
|
|
737
819
|
const uint8_t *x = vx;
|
|
738
820
|
for (int i = 0; i < nb; i++) {
|
|
739
|
-
const
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
y[i*256 + j + l] = (float)q[j+l] * d * ls;
|
|
746
|
-
}
|
|
821
|
+
const uint8_t *block = x + i * 292;
|
|
822
|
+
float d;
|
|
823
|
+
memcpy(&d, block, 4);
|
|
824
|
+
const int8_t *q = (const int8_t*)(block + 4);
|
|
825
|
+
for (int j = 0; j < QK_K; j++) {
|
|
826
|
+
y[i*QK_K + j] = (float)q[j] * d;
|
|
747
827
|
}
|
|
748
828
|
}
|
|
749
829
|
}
|
|
750
830
|
|
|
751
|
-
|
|
752
|
-
static
|
|
753
|
-
if (
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
memcpy(out, data, n_rows * n_cols * sizeof(float));
|
|
757
|
-
return out;
|
|
831
|
+
// Lazy single-row dequantization
|
|
832
|
+
static void dequantize_row_lazy(const EmbedModel *m, int row, float *out) {
|
|
833
|
+
if (!m->raw_tensor_data || row < 0 || row >= m->vocab_size) {
|
|
834
|
+
memset(out, 0, sizeof(float) * m->dim);
|
|
835
|
+
return;
|
|
758
836
|
}
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
837
|
+
|
|
838
|
+
const uint8_t *raw;
|
|
839
|
+
int effective_cols;
|
|
840
|
+
|
|
841
|
+
if (m->need_transpose) {
|
|
842
|
+
int src_row_size;
|
|
843
|
+
switch (m->tensor_type) {
|
|
844
|
+
case GGML_TYPE_F32: src_row_size = m->raw_dim1 * sizeof(float); break;
|
|
845
|
+
case GGML_TYPE_F16: src_row_size = m->raw_dim1 * sizeof(uint16_t); break;
|
|
846
|
+
default: {
|
|
847
|
+
size_t rb = 0;
|
|
848
|
+
int nc = (int)m->raw_dim1;
|
|
849
|
+
switch (m->tensor_type) {
|
|
850
|
+
case GGML_TYPE_Q4_0: rb = (nc / 32) * 18; break;
|
|
851
|
+
case GGML_TYPE_Q4_1: rb = (nc / 32) * 20; break;
|
|
852
|
+
case GGML_TYPE_Q5_0: rb = (nc / 32) * 22; break;
|
|
853
|
+
case GGML_TYPE_Q5_1: rb = (nc / 32) * 24; break;
|
|
854
|
+
case GGML_TYPE_Q8_0: rb = (nc / 32) * 34; break;
|
|
855
|
+
case GGML_TYPE_Q8_1: rb = (nc / 32) * 40; break;
|
|
856
|
+
case GGML_TYPE_Q2_K: rb = (nc / 256) * 84; break;
|
|
857
|
+
case GGML_TYPE_Q3_K: rb = (nc / 256) * 110; break;
|
|
858
|
+
case GGML_TYPE_Q4_K: rb = (nc / 256) * 144; break;
|
|
859
|
+
case GGML_TYPE_Q5_K: rb = (nc / 256) * 176; break;
|
|
860
|
+
case GGML_TYPE_Q6_K: rb = (nc / 256) * 210; break;
|
|
861
|
+
case GGML_TYPE_Q8_K: rb = (nc / 256) * 292; break;
|
|
862
|
+
default: src_row_size = 0; return;
|
|
863
|
+
}
|
|
864
|
+
src_row_size = (int)rb;
|
|
865
|
+
}
|
|
765
866
|
}
|
|
766
|
-
|
|
867
|
+
float *temp_row = malloc(m->raw_dim1 * sizeof(float));
|
|
868
|
+
if (!temp_row) return;
|
|
869
|
+
for (int col = 0; col < m->dim; col++) {
|
|
870
|
+
const uint8_t *src_row = (const uint8_t*)m->raw_tensor_data + col * src_row_size;
|
|
871
|
+
if (m->tensor_type == GGML_TYPE_F32) {
|
|
872
|
+
float val;
|
|
873
|
+
memcpy(&val, src_row + row * sizeof(float), sizeof(float));
|
|
874
|
+
out[col] = val;
|
|
875
|
+
} else if (m->tensor_type == GGML_TYPE_F16) {
|
|
876
|
+
uint16_t val;
|
|
877
|
+
memcpy(&val, src_row + row * sizeof(uint16_t), sizeof(uint16_t));
|
|
878
|
+
out[col] = fp16_to_fp32(val);
|
|
879
|
+
} else {
|
|
880
|
+
memset(out, 0, sizeof(float) * m->dim);
|
|
881
|
+
free(temp_row);
|
|
882
|
+
return;
|
|
883
|
+
}
|
|
884
|
+
}
|
|
885
|
+
free(temp_row);
|
|
886
|
+
return;
|
|
767
887
|
}
|
|
768
888
|
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
case
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
case
|
|
784
|
-
case
|
|
785
|
-
case
|
|
786
|
-
case
|
|
787
|
-
case
|
|
889
|
+
raw = (const uint8_t*)m->raw_tensor_data + row * m->row_bytes;
|
|
890
|
+
effective_cols = m->dim;
|
|
891
|
+
|
|
892
|
+
switch (m->tensor_type) {
|
|
893
|
+
case GGML_TYPE_F32:
|
|
894
|
+
memcpy(out, raw, effective_cols * sizeof(float));
|
|
895
|
+
break;
|
|
896
|
+
case GGML_TYPE_F16:
|
|
897
|
+
for (int j = 0; j < effective_cols; j++) {
|
|
898
|
+
uint16_t h;
|
|
899
|
+
memcpy(&h, raw + j * sizeof(uint16_t), sizeof(uint16_t));
|
|
900
|
+
out[j] = fp16_to_fp32(h);
|
|
901
|
+
}
|
|
902
|
+
break;
|
|
903
|
+
case GGML_TYPE_Q4_0: dequantize_row_q4_0(raw, out, effective_cols); break;
|
|
904
|
+
case GGML_TYPE_Q4_1: dequantize_row_q4_1(raw, out, effective_cols); break;
|
|
905
|
+
case GGML_TYPE_Q5_0: dequantize_row_q5_0(raw, out, effective_cols); break;
|
|
906
|
+
case GGML_TYPE_Q5_1: dequantize_row_q5_1(raw, out, effective_cols); break;
|
|
907
|
+
case GGML_TYPE_Q8_0: dequantize_row_q8_0(raw, out, effective_cols); break;
|
|
908
|
+
case GGML_TYPE_Q8_1: dequantize_row_q8_1(raw, out, effective_cols); break;
|
|
909
|
+
case GGML_TYPE_Q2_K: dequantize_row_q2_K(raw, out, effective_cols); break;
|
|
910
|
+
case GGML_TYPE_Q3_K: dequantize_row_q3_K(raw, out, effective_cols); break;
|
|
911
|
+
case GGML_TYPE_Q4_K: dequantize_row_q4_K(raw, out, effective_cols); break;
|
|
912
|
+
case GGML_TYPE_Q5_K: dequantize_row_q5_K(raw, out, effective_cols); break;
|
|
913
|
+
case GGML_TYPE_Q6_K: dequantize_row_q6_K(raw, out, effective_cols); break;
|
|
914
|
+
case GGML_TYPE_Q8_K: dequantize_row_q8_K(raw, out, effective_cols); break;
|
|
788
915
|
default:
|
|
789
|
-
|
|
790
|
-
return NULL;
|
|
916
|
+
memset(out, 0, sizeof(float) * effective_cols);
|
|
791
917
|
}
|
|
792
918
|
|
|
793
|
-
for (int
|
|
794
|
-
|
|
919
|
+
for (int j = 0; j < effective_cols; j++) {
|
|
920
|
+
if (isnan(out[j]) || isinf(out[j]) || fabsf(out[j]) > 1e10f) {
|
|
921
|
+
out[j] = 0.0f;
|
|
922
|
+
}
|
|
923
|
+
}
|
|
924
|
+
}
|
|
925
|
+
|
|
926
|
+
static size_t get_row_bytes(int type, int n_cols) {
|
|
927
|
+
switch (type) {
|
|
928
|
+
case GGML_TYPE_F32: return n_cols * sizeof(float);
|
|
929
|
+
case GGML_TYPE_F16: return n_cols * sizeof(uint16_t);
|
|
930
|
+
case GGML_TYPE_Q4_0: return (n_cols / 32) * 18;
|
|
931
|
+
case GGML_TYPE_Q4_1: return (n_cols / 32) * 20;
|
|
932
|
+
case GGML_TYPE_Q5_0: return (n_cols / 32) * 22;
|
|
933
|
+
case GGML_TYPE_Q5_1: return (n_cols / 32) * 24;
|
|
934
|
+
case GGML_TYPE_Q8_0: return (n_cols / 32) * 34;
|
|
935
|
+
case GGML_TYPE_Q8_1: return (n_cols / 32) * 40;
|
|
936
|
+
case GGML_TYPE_Q2_K: return (n_cols / 256) * 84;
|
|
937
|
+
case GGML_TYPE_Q3_K: return (n_cols / 256) * 110;
|
|
938
|
+
case GGML_TYPE_Q4_K: return (n_cols / 256) * 144;
|
|
939
|
+
case GGML_TYPE_Q5_K: return (n_cols / 256) * 176;
|
|
940
|
+
case GGML_TYPE_Q6_K: return (n_cols / 256) * 210;
|
|
941
|
+
case GGML_TYPE_Q8_K: return (n_cols / 256) * 292;
|
|
942
|
+
default: return 0;
|
|
795
943
|
}
|
|
796
|
-
return out;
|
|
797
944
|
}
|
|
798
945
|
|
|
799
946
|
/* ------------------------------------------------------------------------- */
|
|
@@ -809,7 +956,7 @@ static int skip_value(uint8_t **p, uint8_t *end, uint32_t type) {
|
|
|
809
956
|
case 9: {
|
|
810
957
|
uint32_t subtype = rd32(p, end);
|
|
811
958
|
uint64_t n = rd64(p, end);
|
|
812
|
-
for (uint64_t i = 0; i < n; i++)
|
|
959
|
+
for (uint64_t i = 0; i < n && i < 1000000; i++)
|
|
813
960
|
if (!skip_value(p, end, subtype)) return 0;
|
|
814
961
|
return 1;
|
|
815
962
|
}
|
|
@@ -835,79 +982,66 @@ static void free_model_contents(EmbedModel *m) {
|
|
|
835
982
|
}
|
|
836
983
|
free(m->table);
|
|
837
984
|
}
|
|
838
|
-
if (m->float_data) free(m->float_data);
|
|
839
985
|
if (m->mapped) munmap(m->mapped, m->mapped_size);
|
|
840
|
-
|
|
841
|
-
// Free BPE tokenization data
|
|
842
986
|
bpe_merge_table_free(&m->merges);
|
|
843
|
-
if (m->pre_patterns) {
|
|
844
|
-
for (int i = 0; i < m->num_pre_patterns; i++) {
|
|
845
|
-
free(m->pre_patterns[i].pattern);
|
|
846
|
-
}
|
|
847
|
-
free(m->pre_patterns);
|
|
848
|
-
}
|
|
849
|
-
|
|
850
987
|
free(m);
|
|
851
988
|
}
|
|
852
989
|
|
|
853
990
|
/* ------------------------------------------------------------------------- */
|
|
854
991
|
/*
 * Return 1 when each of the first `len` bytes of `s` is printable according
 * to isprint(), 0 as soon as one is not. An empty range counts as printable.
 */
static int is_printable_string(const char *s, size_t len) {
    const char *stop = s + len;
    while (s != stop) {
        if (!isprint((unsigned char)*s)) {
            return 0;
        }
        s++;
    }
    return 1;
}
|
|
859
995
|
|
|
860
|
-
/* Fallback: find the start of tensor info by scanning for a valid string */
|
|
861
996
|
/*
 * Scan forward byte by byte from `cur` for something that looks like a
 * length-prefixed name: a 64-bit length in 1..255 followed by that many
 * printable bytes, all inside [cur, end). Returns the position of the length
 * field, or NULL when nothing matches.
 */
static uint8_t *find_tensor_info_start(uint8_t *cur, uint8_t *end) {
    for (uint8_t *pos = cur; pos + 8 < end; pos++) {
        uint64_t name_len;
        memcpy(&name_len, pos, 8);

        if (name_len == 0 || name_len >= 256) continue;
        if (pos + 8 + name_len > end) continue;
        if (!is_printable_string((char *)pos + 8, name_len)) continue;

        return pos;
    }
    return NULL;
}
|
|
875
1007
|
|
|
876
1008
|
/* ------------------------------------------------------------------------- */
|
|
877
|
-
static void
|
|
878
|
-
|
|
879
|
-
const char *
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
1009
|
+
/*
 * Guess which prefix the vocabulary uses to mark a leading space.
 *
 * The first (up to) 5000 tokens are inspected. The three multi-byte
 * candidates are counted whenever a token starts with them; the plain space
 * is counted only for tokens that start with ' ' and have at least one more
 * character. The candidate with the most hits wins (earlier candidates win
 * ties) and is stored in m->space_marker / m->space_marker_len, but only if
 * it was seen more than 10 times; otherwise the model is left untouched.
 */
static void detect_space_marker(EmbedModel *m) {
    static const char *const candidates[4] = {"▁", "Ġ", "ĉ", " "};
    static const int candidate_len[4] = {3, 2, 2, 1};
    int hits[4] = {0, 0, 0, 0};

    const int scan_limit = (m->vocab_size < 5000) ? m->vocab_size : 5000;
    for (int t = 0; t < scan_limit; t++) {
        const char *tok = m->tokens[t];

        for (int c = 0; c < 3; c++) {
            if (strncmp(tok, candidates[c], candidate_len[c]) == 0) {
                hits[c]++;
            }
        }
        /* same as: starts with ' ' and strlen(tok) > 1 */
        if (tok[0] == ' ' && tok[1] != '\0') {
            hits[3]++;
        }
    }

    int winner = 0;
    for (int c = 1; c < 4; c++) {
        if (hits[c] > hits[winner]) winner = c;
    }

    if (hits[winner] <= 10) return;

    strcpy(m->space_marker, candidates[winner]);
    m->space_marker_len = candidate_len[winner];
}
|
|
897
1035
|
|
|
898
|
-
/* ------------------------------------------------------------------------- */
|
|
899
1036
|
static void parse_merge(const char *merge_str, char **left, char **right) {
|
|
900
|
-
// Parse a merge string like "h ello" -> left="h", right="ello"
|
|
901
1037
|
const char *space = strchr(merge_str, ' ');
|
|
902
1038
|
if (space) {
|
|
903
1039
|
int left_len = space - merge_str;
|
|
904
1040
|
*left = malloc(left_len + 1);
|
|
905
1041
|
memcpy(*left, merge_str, left_len);
|
|
906
1042
|
(*left)[left_len] = '\0';
|
|
907
|
-
|
|
908
1043
|
*right = strdup(space + 1);
|
|
909
1044
|
} else {
|
|
910
|
-
// No space - treat as single token
|
|
911
1045
|
*left = strdup(merge_str);
|
|
912
1046
|
*right = strdup("");
|
|
913
1047
|
}
|
|
@@ -918,45 +1052,39 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
918
1052
|
size_t sz;
|
|
919
1053
|
uint8_t *base = map_file(path, &sz);
|
|
920
1054
|
if (!base) return NULL;
|
|
921
|
-
uint8_t *cur = base;
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
if (memcmp(cur, "GGUF", 4) != 0) { munmap(base, sz); return NULL; }
|
|
1055
|
+
uint8_t *cur = base, *end = base + sz;
|
|
1056
|
+
if (sz < 4 || memcmp(cur, "GGUF", 4) != 0) { munmap(base, sz); return NULL; }
|
|
925
1057
|
cur += 4;
|
|
926
1058
|
uint32_t version = rd32(&cur, end);
|
|
927
1059
|
(void)version;
|
|
928
1060
|
uint64_t n_tensors = rd64(&cur, end);
|
|
929
1061
|
uint64_t n_kv = rd64(&cur, end);
|
|
930
1062
|
|
|
1063
|
+
if (n_kv > 1000000 || n_tensors > 1000000) { munmap(base, sz); return NULL; }
|
|
1064
|
+
|
|
931
1065
|
EmbedModel *m = calloc(1, sizeof(*m));
|
|
932
1066
|
if (!m) { munmap(base, sz); return NULL; }
|
|
933
1067
|
m->mapped = base;
|
|
934
1068
|
m->mapped_size = sz;
|
|
935
1069
|
m->table = calloc(HASH_SIZE, sizeof(HashNode*));
|
|
936
1070
|
if (!m->table) { free_model_contents(m); return NULL; }
|
|
937
|
-
|
|
938
|
-
// Initialize BPE structures
|
|
939
1071
|
bpe_merge_table_init(&m->merges);
|
|
940
|
-
setup_default_pre_patterns(m);
|
|
941
|
-
|
|
942
|
-
// Default values
|
|
943
1072
|
m->unknown_token_id = -1;
|
|
944
1073
|
m->bos_token_id = -1;
|
|
945
1074
|
m->eos_token_id = -1;
|
|
946
1075
|
m->vocab_type = LLAMA_VOCAB_TYPE_NONE;
|
|
1076
|
+
m->normalize = NORM_NONE;
|
|
947
1077
|
|
|
948
|
-
/* ---------- Metadata ---------- */
|
|
949
1078
|
int vocab_found = 0;
|
|
950
1079
|
for (uint64_t i = 0; i < n_kv; i++) {
|
|
951
1080
|
char *key = rdstr(&cur, end);
|
|
952
1081
|
if (!key) { free_model_contents(m); return NULL; }
|
|
953
1082
|
uint32_t type = rd32(&cur, end);
|
|
954
1083
|
|
|
955
|
-
if ((strcmp(key, "tokenizer.ggml.tokens") == 0 ||
|
|
956
|
-
strcmp(key, "tokenizer.ggml.token_list") == 0) && type == 9) {
|
|
1084
|
+
if ((strcmp(key, "tokenizer.ggml.tokens") == 0 || strcmp(key, "tokenizer.ggml.token_list") == 0) && type == 9) {
|
|
957
1085
|
uint32_t subtype = rd32(&cur, end);
|
|
958
1086
|
uint64_t n = rd64(&cur, end);
|
|
959
|
-
if (subtype != 8) { free(key); free_model_contents(m); return NULL; }
|
|
1087
|
+
if (subtype != 8 || n > 1000000) { free(key); free_model_contents(m); return NULL; }
|
|
960
1088
|
m->tokens = malloc(sizeof(char*) * n);
|
|
961
1089
|
if (!m->tokens) { free(key); free_model_contents(m); return NULL; }
|
|
962
1090
|
m->vocab_size = (int)n;
|
|
@@ -971,66 +1099,64 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
971
1099
|
uint32_t subtype = rd32(&cur, end);
|
|
972
1100
|
uint64_t n = rd64(&cur, end);
|
|
973
1101
|
if (subtype == 8) {
|
|
974
|
-
// Parse merges
|
|
975
1102
|
for (uint64_t j = 0; j < n && j < MAX_MERGES; j++) {
|
|
976
1103
|
char *merge_str = rdstr(&cur, end);
|
|
977
1104
|
if (merge_str) {
|
|
978
1105
|
char *left, *right;
|
|
979
1106
|
parse_merge(merge_str, &left, &right);
|
|
980
|
-
bpe_merge_table_add(&m->merges, left, right,
|
|
1107
|
+
bpe_merge_table_add(&m->merges, left, right, (int)j);
|
|
981
1108
|
free(left);
|
|
982
1109
|
free(right);
|
|
983
1110
|
free(merge_str);
|
|
1111
|
+
} else {
|
|
1112
|
+
break;
|
|
984
1113
|
}
|
|
985
1114
|
}
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
1115
|
+
if (n > MAX_MERGES) {
|
|
1116
|
+
for (uint64_t j = MAX_MERGES; j < n; j++) {
|
|
1117
|
+
char *merge_str = rdstr(&cur, end);
|
|
1118
|
+
free(merge_str);
|
|
1119
|
+
}
|
|
990
1120
|
}
|
|
1121
|
+
} else {
|
|
1122
|
+
if (!skip_value(&cur, end, type)) { free(key); free_model_contents(m); return NULL; }
|
|
991
1123
|
}
|
|
992
1124
|
} else if (strcmp(key, "tokenizer.ggml.model") == 0 && type == 8) {
|
|
993
1125
|
char *model_type = rdstr(&cur, end);
|
|
994
1126
|
if (model_type) {
|
|
995
|
-
if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0
|
|
1127
|
+
if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0 ||
|
|
1128
|
+
strcmp(model_type, "phi") == 0 || strcmp(model_type, "qwen") == 0)
|
|
996
1129
|
m->vocab_type = LLAMA_VOCAB_TYPE_BPE;
|
|
997
|
-
|
|
1130
|
+
else if (strcmp(model_type, "bert") == 0)
|
|
998
1131
|
m->vocab_type = LLAMA_VOCAB_TYPE_WPM;
|
|
999
|
-
|
|
1132
|
+
else if (strcmp(model_type, "spm") == 0)
|
|
1133
|
+
m->vocab_type = LLAMA_VOCAB_TYPE_SPM;
|
|
1000
1134
|
free(model_type);
|
|
1001
1135
|
}
|
|
1002
1136
|
} else if (strcmp(key, "tokenizer.ggml.pre") == 0 && type == 8) {
|
|
1003
|
-
char *
|
|
1004
|
-
|
|
1005
|
-
// Could load custom regex patterns here if needed
|
|
1006
|
-
free(pre_type);
|
|
1007
|
-
}
|
|
1137
|
+
char *pre = rdstr(&cur, end);
|
|
1138
|
+
free(pre);
|
|
1008
1139
|
} else if (strcmp(key, "tokenizer.ggml.unknown_token_id") == 0 && type == 6) {
|
|
1009
|
-
m->unknown_token_id = rd32(&cur, end);
|
|
1140
|
+
m->unknown_token_id = (int)rd32(&cur, end);
|
|
1010
1141
|
} else if (strcmp(key, "tokenizer.ggml.bos_token_id") == 0 && type == 6) {
|
|
1011
|
-
m->bos_token_id = rd32(&cur, end);
|
|
1142
|
+
m->bos_token_id = (int)rd32(&cur, end);
|
|
1012
1143
|
} else if (strcmp(key, "tokenizer.ggml.eos_token_id") == 0 && type == 6) {
|
|
1013
|
-
m->eos_token_id = rd32(&cur, end);
|
|
1144
|
+
m->eos_token_id = (int)rd32(&cur, end);
|
|
1145
|
+
} else if (strcmp(key, "general.alignment") == 0 && type == 6) {
|
|
1146
|
+
rd32(&cur, end);
|
|
1014
1147
|
} else {
|
|
1015
|
-
if (!skip_value(&cur, end, type)) {
|
|
1016
|
-
free(key); free_model_contents(m); return NULL;
|
|
1017
|
-
}
|
|
1148
|
+
if (!skip_value(&cur, end, type)) { free(key); free_model_contents(m); return NULL; }
|
|
1018
1149
|
}
|
|
1019
1150
|
free(key);
|
|
1020
1151
|
}
|
|
1021
1152
|
|
|
1022
1153
|
if (!vocab_found) { free_model_contents(m); return NULL; }
|
|
1154
|
+
detect_space_marker(m);
|
|
1023
1155
|
|
|
1024
1156
|
uint8_t *after_kv = cur;
|
|
1025
1157
|
align_to_32(&cur, end, base);
|
|
1026
1158
|
uint8_t *tensor_start = cur;
|
|
1027
|
-
|
|
1028
|
-
/* ---------- Tensor info ---------- */
|
|
1029
1159
|
int embd_found = 0;
|
|
1030
|
-
void *raw_tensor_data = NULL;
|
|
1031
|
-
int tensor_type = -1;
|
|
1032
|
-
uint64_t dim0 = 0, dim1 = 0;
|
|
1033
|
-
int need_transpose = 0;
|
|
1034
1160
|
|
|
1035
1161
|
for (int attempt = 0; attempt < 2; attempt++) {
|
|
1036
1162
|
cur = tensor_start;
|
|
@@ -1039,8 +1165,7 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
1039
1165
|
if (!name) break;
|
|
1040
1166
|
uint32_t n_dims = rd32(&cur, end);
|
|
1041
1167
|
uint64_t dims[MAX_DIMS] = {0};
|
|
1042
|
-
for (uint32_t d = 0; d < n_dims && d < MAX_DIMS; d++)
|
|
1043
|
-
dims[d] = rd64(&cur, end);
|
|
1168
|
+
for (uint32_t d = 0; d < n_dims && d < MAX_DIMS; d++) dims[d] = rd64(&cur, end);
|
|
1044
1169
|
uint32_t type = rd32(&cur, end);
|
|
1045
1170
|
uint64_t offset = rd64(&cur, end);
|
|
1046
1171
|
|
|
@@ -1049,29 +1174,55 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
1049
1174
|
strcmp(name, "model.embed_tokens.weight") == 0);
|
|
1050
1175
|
|
|
1051
1176
|
if (!is_token_embd && n_dims == 2 && m->vocab_size > 0) {
|
|
1052
|
-
if ((uint64_t)m->vocab_size == dims[0] && strstr(name, "embd")
|
|
1053
|
-
|
|
1054
|
-
else if ((uint64_t)m->vocab_size == dims[1] && strstr(name, "embd") != NULL)
|
|
1055
|
-
is_token_embd = 1;
|
|
1177
|
+
if ((uint64_t)m->vocab_size == dims[0] && strstr(name, "embd")) is_token_embd = 1;
|
|
1178
|
+
else if ((uint64_t)m->vocab_size == dims[1] && strstr(name, "embd")) is_token_embd = 1;
|
|
1056
1179
|
}
|
|
1057
1180
|
|
|
1058
1181
|
if (!embd_found && is_token_embd) {
|
|
1059
|
-
if (n_dims < 2 || dims[1] == 0) {
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1182
|
+
if (n_dims < 2 || dims[1] == 0) {
|
|
1183
|
+
free(name); free_model_contents(m); return NULL;
|
|
1184
|
+
}
|
|
1185
|
+
|
|
1186
|
+
uint64_t ne0 = dims[0];
|
|
1187
|
+
uint64_t ne1 = dims[1];
|
|
1188
|
+
|
|
1189
|
+
int need_transpose = 0;
|
|
1190
|
+
int dim;
|
|
1191
|
+
|
|
1192
|
+
if (ne1 == (uint64_t)m->vocab_size) {
|
|
1193
|
+
dim = (int)ne0;
|
|
1064
1194
|
need_transpose = 0;
|
|
1065
|
-
} else if (
|
|
1066
|
-
|
|
1195
|
+
} else if (ne0 == (uint64_t)m->vocab_size) {
|
|
1196
|
+
dim = (int)ne1;
|
|
1067
1197
|
need_transpose = 1;
|
|
1068
1198
|
} else {
|
|
1069
|
-
|
|
1070
|
-
need_transpose = (
|
|
1199
|
+
dim = (ne0 < ne1) ? (int)ne0 : (int)ne1;
|
|
1200
|
+
need_transpose = (ne0 > ne1) ? 1 : 0;
|
|
1071
1201
|
}
|
|
1072
|
-
|
|
1073
|
-
|
|
1202
|
+
|
|
1203
|
+
if (dim <= 0 || dim > MAX_DIM) {
|
|
1204
|
+
free(name); free_model_contents(m); return NULL;
|
|
1205
|
+
}
|
|
1206
|
+
|
|
1207
|
+
size_t row_bytes = get_row_bytes(type, (int)(need_transpose ? ne1 : ne0));
|
|
1208
|
+
size_t total_size = (size_t)(need_transpose ? ne1 : ne0) * row_bytes;
|
|
1209
|
+
|
|
1210
|
+
if (offset >= sz || offset + total_size > sz) {
|
|
1211
|
+
free(name);
|
|
1212
|
+
free_model_contents(m);
|
|
1213
|
+
return NULL;
|
|
1214
|
+
}
|
|
1215
|
+
|
|
1216
|
+
m->dim = dim;
|
|
1217
|
+
m->raw_dim0 = ne0;
|
|
1218
|
+
m->raw_dim1 = ne1;
|
|
1219
|
+
m->need_transpose = need_transpose;
|
|
1220
|
+
m->raw_tensor_data = base + offset;
|
|
1221
|
+
m->tensor_type = type;
|
|
1222
|
+
m->row_bytes = row_bytes;
|
|
1074
1223
|
embd_found = 1;
|
|
1224
|
+
free(name);
|
|
1225
|
+
break;
|
|
1075
1226
|
}
|
|
1076
1227
|
free(name);
|
|
1077
1228
|
}
|
|
@@ -1082,110 +1233,122 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
1082
1233
|
}
|
|
1083
1234
|
}
|
|
1084
1235
|
|
|
1085
|
-
if (!embd_found || m->dim == 0) {
|
|
1086
|
-
free_model_contents(m);
|
|
1087
|
-
return NULL;
|
|
1088
|
-
}
|
|
1089
|
-
|
|
1090
|
-
/* Dequantize */
|
|
1091
|
-
if (tensor_type == GGML_TYPE_F32 && !need_transpose) {
|
|
1092
|
-
m->float_data = NULL;
|
|
1093
|
-
m->tensor_data = raw_tensor_data;
|
|
1094
|
-
} else {
|
|
1095
|
-
int n_rows = need_transpose ? (int)dim1 : (int)dim0;
|
|
1096
|
-
int n_cols = need_transpose ? (int)dim0 : (int)dim1;
|
|
1097
|
-
m->float_data = dequantize_tensor(raw_tensor_data, tensor_type, n_rows, n_cols);
|
|
1098
|
-
if (!m->float_data) {
|
|
1099
|
-
free_model_contents(m);
|
|
1100
|
-
return NULL;
|
|
1101
|
-
}
|
|
1102
|
-
m->tensor_data = m->float_data;
|
|
1236
|
+
if (!embd_found || m->dim == 0) {
|
|
1237
|
+
free_model_contents(m); return NULL;
|
|
1103
1238
|
}
|
|
1104
|
-
m->tensor_type = tensor_type;
|
|
1105
1239
|
|
|
1106
1240
|
return m;
|
|
1107
1241
|
}
|
|
1108
1242
|
|
|
1243
|
+
/* ------------------------------------------------------------------------- */
|
|
1244
|
+
// L2 normalization
|
|
1245
|
+
static void normalize_l2(float *vec, int dim) {
|
|
1246
|
+
float sum = 0;
|
|
1247
|
+
for (int i = 0; i < dim; i++) sum += vec[i] * vec[i];
|
|
1248
|
+
float norm = sqrtf(sum);
|
|
1249
|
+
if (norm > 1e-8f) {
|
|
1250
|
+
float inv = 1.0f / norm;
|
|
1251
|
+
for (int i = 0; i < dim; i++) vec[i] *= inv;
|
|
1252
|
+
}
|
|
1253
|
+
}
|
|
1254
|
+
|
|
1109
1255
|
/* ------------------------------------------------------------------------- */
|
|
1110
1256
|
static void embed_text(EmbedModel *m, const char *txt, float *out) {
|
|
1111
1257
|
memset(out, 0, sizeof(float) * m->dim);
|
|
1112
|
-
|
|
1113
|
-
|
|
1258
|
+
if (!txt || !*txt) return;
|
|
1259
|
+
|
|
1114
1260
|
int num_words = 0;
|
|
1115
|
-
char **words =
|
|
1116
|
-
|
|
1261
|
+
char **words = pre_tokenize(txt, &num_words);
|
|
1262
|
+
|
|
1117
1263
|
if (!words || num_words == 0) {
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
while (tok) {
|
|
1127
|
-
int id = hget(m, tok);
|
|
1128
|
-
if (id >= 0 && id < m->vocab_size) {
|
|
1129
|
-
const float *vec = embd_matrix + id * m->dim;
|
|
1130
|
-
for (int i = 0; i < m->dim; i++) out[i] += vec[i];
|
|
1131
|
-
used++;
|
|
1132
|
-
}
|
|
1133
|
-
tok = strtok(NULL, " \t\n\r");
|
|
1134
|
-
}
|
|
1135
|
-
|
|
1136
|
-
if (used > 0) {
|
|
1137
|
-
float inv = 1.0f / used;
|
|
1138
|
-
for (int i = 0; i < m->dim; i++) out[i] *= inv;
|
|
1139
|
-
}
|
|
1140
|
-
free(copy);
|
|
1264
|
+
if (words) free(words);
|
|
1265
|
+
return;
|
|
1266
|
+
}
|
|
1267
|
+
|
|
1268
|
+
int *token_ids = malloc(m->vocab_size * sizeof(int));
|
|
1269
|
+
if (!token_ids) {
|
|
1270
|
+
for (int i = 0; i < num_words; i++) free(words[i]);
|
|
1271
|
+
free(words);
|
|
1141
1272
|
return;
|
|
1142
1273
|
}
|
|
1143
|
-
|
|
1144
|
-
// Tokenize each word using BPE
|
|
1145
|
-
int *token_ids = malloc(m->vocab_size * sizeof(int)); // Max possible tokens
|
|
1146
|
-
int num_tokens = 0;
|
|
1147
|
-
const float *embd_matrix = m->tensor_data;
|
|
1274
|
+
|
|
1148
1275
|
int used = 0;
|
|
1149
|
-
|
|
1276
|
+
float *temp_vec = malloc(m->dim * sizeof(float));
|
|
1277
|
+
|
|
1150
1278
|
for (int i = 0; i < num_words; i++) {
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1279
|
+
char *word = words[i];
|
|
1280
|
+
int id = hget(m, word);
|
|
1281
|
+
|
|
1282
|
+
if (id == -1 && m->space_marker_len > 0) {
|
|
1283
|
+
size_t with_marker_len = m->space_marker_len + strlen(word);
|
|
1284
|
+
char *with_marker = malloc(with_marker_len + 1);
|
|
1285
|
+
if (with_marker) {
|
|
1286
|
+
memcpy(with_marker, m->space_marker, m->space_marker_len);
|
|
1287
|
+
strcpy(with_marker + m->space_marker_len, word);
|
|
1288
|
+
id = hget(m, with_marker);
|
|
1289
|
+
free(with_marker);
|
|
1290
|
+
}
|
|
1291
|
+
}
|
|
1292
|
+
|
|
1293
|
+
if (id != -1 && id >= 0 && id < m->vocab_size) {
|
|
1294
|
+
dequantize_row_lazy(m, id, temp_vec);
|
|
1295
|
+
for (int j = 0; j < m->dim; j++) out[j] += temp_vec[j];
|
|
1296
|
+
used++;
|
|
1297
|
+
} else {
|
|
1298
|
+
int num_tokens = 0;
|
|
1299
|
+
bpe_tokenize_word(&m->merges, word, m, token_ids, &num_tokens);
|
|
1300
|
+
for (int k = 0; k < num_tokens; k++) {
|
|
1301
|
+
int tid = token_ids[k];
|
|
1302
|
+
if (tid >= 0 && tid < m->vocab_size) {
|
|
1303
|
+
dequantize_row_lazy(m, tid, temp_vec);
|
|
1304
|
+
for (int j = 0; j < m->dim; j++) out[j] += temp_vec[j];
|
|
1305
|
+
used++;
|
|
1306
|
+
} else if (m->unknown_token_id != -1 && m->unknown_token_id < m->vocab_size) {
|
|
1307
|
+
dequantize_row_lazy(m, m->unknown_token_id, temp_vec);
|
|
1308
|
+
for (int j = 0; j < m->dim; j++) out[j] += temp_vec[j];
|
|
1309
|
+
used++;
|
|
1310
|
+
}
|
|
1165
1311
|
}
|
|
1166
1312
|
}
|
|
1167
|
-
|
|
1168
|
-
free(words[i]);
|
|
1313
|
+
free(word);
|
|
1169
1314
|
}
|
|
1315
|
+
|
|
1170
1316
|
free(words);
|
|
1171
1317
|
free(token_ids);
|
|
1172
|
-
|
|
1318
|
+
free(temp_vec);
|
|
1319
|
+
|
|
1173
1320
|
if (used > 0) {
|
|
1174
1321
|
float inv = 1.0f / used;
|
|
1175
1322
|
for (int i = 0; i < m->dim; i++) out[i] *= inv;
|
|
1176
1323
|
}
|
|
1324
|
+
|
|
1325
|
+
for (int i = 0; i < m->dim; i++) {
|
|
1326
|
+
if (isnan(out[i]) || isinf(out[i])) {
|
|
1327
|
+
out[i] = 0.0f;
|
|
1328
|
+
}
|
|
1329
|
+
}
|
|
1330
|
+
|
|
1331
|
+
if (m->normalize == NORM_L2) {
|
|
1332
|
+
normalize_l2(out, m->dim);
|
|
1333
|
+
}
|
|
1177
1334
|
}
|
|
1178
1335
|
|
|
1179
1336
|
/* ------------------------------------------------------------------------- */
|
|
1337
|
+
// Ruby bindings
|
|
1180
1338
|
static void rb_embedder_free(void *p) {
|
|
1181
1339
|
ruby_embedder *e = p;
|
|
1182
|
-
if (
|
|
1183
|
-
if (e->model) free_model_contents(e->model);
|
|
1184
|
-
free(e);
|
|
1340
|
+
if (e) { if (e->model) free_model_contents(e->model); free(e); }
|
|
1185
1341
|
}
|
|
1186
1342
|
|
|
1187
1343
|
static size_t rb_embedder_memsize(const void *p) {
|
|
1188
|
-
|
|
1344
|
+
const ruby_embedder *e = p;
|
|
1345
|
+
size_t sz = sizeof(ruby_embedder);
|
|
1346
|
+
if (e && e->model) {
|
|
1347
|
+
sz += e->model->vocab_size * sizeof(char*);
|
|
1348
|
+
sz += e->model->mapped_size;
|
|
1349
|
+
sz += HASH_SIZE * sizeof(HashNode*);
|
|
1350
|
+
}
|
|
1351
|
+
return sz;
|
|
1189
1352
|
}
|
|
1190
1353
|
|
|
1191
1354
|
static const rb_data_type_t ruby_embedder_type = {
|
|
@@ -1203,11 +1366,31 @@ static VALUE rb_embedder_initialize(VALUE self, VALUE opts) {
|
|
|
1203
1366
|
ruby_embedder *e;
|
|
1204
1367
|
TypedData_Get_Struct(self, ruby_embedder, &ruby_embedder_type, e);
|
|
1205
1368
|
|
|
1369
|
+
Check_Type(opts, T_HASH);
|
|
1206
1370
|
VALUE path = rb_hash_aref(opts, ID2SYM(rb_intern("model")));
|
|
1371
|
+
if (NIL_P(path)) rb_raise(rb_eArgError, "missing required key: model");
|
|
1207
1372
|
const char *cpath = StringValueCStr(path);
|
|
1373
|
+
|
|
1374
|
+
VALUE normalize = rb_hash_aref(opts, ID2SYM(rb_intern("normalize")));
|
|
1375
|
+
int norm_type = NORM_NONE;
|
|
1376
|
+
if (!NIL_P(normalize)) {
|
|
1377
|
+
if (SYMBOL_P(normalize)) {
|
|
1378
|
+
ID sym_id = SYM2ID(normalize);
|
|
1379
|
+
if (sym_id == rb_intern("l2") || sym_id == rb_intern("L2")) {
|
|
1380
|
+
norm_type = NORM_L2;
|
|
1381
|
+
}
|
|
1382
|
+
} else if (TYPE(normalize) == T_STRING) {
|
|
1383
|
+
const char *norm_str = StringValueCStr(normalize);
|
|
1384
|
+
if (strcasecmp(norm_str, "l2") == 0) {
|
|
1385
|
+
norm_type = NORM_L2;
|
|
1386
|
+
}
|
|
1387
|
+
}
|
|
1388
|
+
}
|
|
1389
|
+
|
|
1208
1390
|
e->model = embed_load_gguf(cpath);
|
|
1209
|
-
if (!e->model)
|
|
1210
|
-
|
|
1391
|
+
if (!e->model) rb_raise(rb_eRuntimeError, "failed to load GGUF model: %s", cpath);
|
|
1392
|
+
|
|
1393
|
+
e->model->normalize = norm_type;
|
|
1211
1394
|
return self;
|
|
1212
1395
|
}
|
|
1213
1396
|
|
|
@@ -1215,7 +1398,9 @@ static VALUE rb_embed(VALUE self, VALUE opts) {
|
|
|
1215
1398
|
ruby_embedder *e;
|
|
1216
1399
|
TypedData_Get_Struct(self, ruby_embedder, &ruby_embedder_type, e);
|
|
1217
1400
|
|
|
1401
|
+
Check_Type(opts, T_HASH);
|
|
1218
1402
|
VALUE text = rb_hash_aref(opts, ID2SYM(rb_intern("text")));
|
|
1403
|
+
if (NIL_P(text)) rb_raise(rb_eArgError, "missing required key: text");
|
|
1219
1404
|
const char *ctext = StringValueCStr(text);
|
|
1220
1405
|
|
|
1221
1406
|
VALUE out = rb_str_new(NULL, e->model->dim * sizeof(float));
|
|
@@ -1227,5 +1412,5 @@ void Init_mini_embed(void) {
|
|
|
1227
1412
|
VALUE c = rb_define_class("MiniEmbed", rb_cObject);
|
|
1228
1413
|
rb_define_alloc_func(c, rb_embedder_alloc);
|
|
1229
1414
|
rb_define_method(c, "initialize", rb_embedder_initialize, 1);
|
|
1230
|
-
rb_define_method(c, "
|
|
1415
|
+
rb_define_method(c, "embed", rb_embed, 1);
|
|
1231
1416
|
}
|