mini_embed 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mini_embed/mini_embed.c +743 -389
- data/lib/mini_embed.rb +1 -1
- metadata +1 -1
data/ext/mini_embed/mini_embed.c
CHANGED
|
@@ -8,13 +8,18 @@
|
|
|
8
8
|
#include <fcntl.h>
|
|
9
9
|
#include <unistd.h>
|
|
10
10
|
#include <ctype.h>
|
|
11
|
+
#include <limits.h>
|
|
11
12
|
#include "ruby.h"
|
|
12
13
|
|
|
13
|
-
#define HASH_SIZE
|
|
14
|
-
#define MAX_DIMS
|
|
15
|
-
#define GGUF_ALIGN
|
|
16
|
-
#define MAX_MERGES
|
|
17
|
-
#define
|
|
14
|
+
#define HASH_SIZE 131071
|
|
15
|
+
#define MAX_DIMS 4
|
|
16
|
+
#define GGUF_ALIGN 32
|
|
17
|
+
#define MAX_MERGES 100000
|
|
18
|
+
#define MERGE_HASH_SIZE 65537
|
|
19
|
+
#define QK8_0 32
|
|
20
|
+
#define QK_K 256
|
|
21
|
+
#define K_SCALE_SIZE 12
|
|
22
|
+
#define MAX_DIM 16384
|
|
18
23
|
|
|
19
24
|
enum ggml_type {
|
|
20
25
|
GGML_TYPE_F32 = 0,
|
|
@@ -40,6 +45,11 @@ enum llama_vocab_type {
|
|
|
40
45
|
LLAMA_VOCAB_TYPE_WPM = 3,
|
|
41
46
|
};
|
|
42
47
|
|
|
48
|
+
enum normalize_type {
|
|
49
|
+
NORM_NONE = 0,
|
|
50
|
+
NORM_L2 = 1,
|
|
51
|
+
};
|
|
52
|
+
|
|
43
53
|
/* ------------------------------------------------------------------------- */
|
|
44
54
|
// Unicode helper functions
|
|
45
55
|
static int unicode_len_utf8(char c) {
|
|
@@ -77,135 +87,233 @@ static uint32_t unicode_cpt_from_utf8(const char *s, size_t *len) {
|
|
|
77
87
|
}
|
|
78
88
|
|
|
79
89
|
/* ------------------------------------------------------------------------- */
|
|
80
|
-
//
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
90
|
+
// Pre-tokenizer (GPT-2/Llama style, replaces broken regex)
|
|
91
|
+
#define CHAR_CLASS_SPACE 0
|
|
92
|
+
#define CHAR_CLASS_LETTER 1
|
|
93
|
+
#define CHAR_CLASS_NUMBER 2
|
|
94
|
+
#define CHAR_CLASS_NEWLINE 3
|
|
95
|
+
#define CHAR_CLASS_OTHER 4
|
|
96
|
+
|
|
97
|
+
static int get_char_class(uint32_t cp) {
|
|
98
|
+
if (unicode_is_letter(cp)) return CHAR_CLASS_LETTER;
|
|
99
|
+
if (unicode_is_number(cp)) return CHAR_CLASS_NUMBER;
|
|
100
|
+
if (cp == '\n' || cp == '\r') return CHAR_CLASS_NEWLINE;
|
|
101
|
+
if (cp == ' ' || cp == '\t') return CHAR_CLASS_SPACE;
|
|
102
|
+
return CHAR_CLASS_OTHER;
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
/*
 * Report whether an English-style contraction suffix starts at text[pos].
 *
 * A contraction is an apostrophe followed by one of the letters
 * s, t, r, v, m, l, d (case-insensitive).  The apostrophe may be ASCII (')
 * or the three-byte UTF-8 sequence E2 80 99 / E2 80 98.
 *
 * text      byte buffer being tokenized
 * pos       byte offset to inspect
 * text_len  number of valid bytes in text
 *
 * Returns 1 when a contraction starts at pos, 0 otherwise (including when
 * pos is out of range or the apostrophe is the last character).
 */
static int is_contraction(const char *text, size_t pos, size_t text_len) {
    if (pos >= text_len) return 0;

    /* Work on unsigned bytes: comparing a plain (possibly signed) char with
     * 0x80/0x99/0x98 is always false on signed-char platforms. */
    const unsigned char *p = (const unsigned char *)text + pos;
    const size_t remaining = text_len - pos;

    size_t apostrophe_len;
    if (p[0] == '\'') {
        apostrophe_len = 1;
    } else if (remaining >= 3 && p[0] == 0xE2 && p[1] == 0x80 &&
               (p[2] == 0x99 || p[2] == 0x98)) {
        apostrophe_len = 3;
    } else {
        return 0;
    }

    /* The apostrophe must be followed by at least one more byte. */
    if (remaining <= apostrophe_len) return 0;

    switch (tolower(p[apostrophe_len])) {
    case 's': case 't': case 'r': case 'v':
    case 'm': case 'l': case 'd':
        return 1;
    default:
        return 0;
    }
}
|
|
111
122
|
|
|
112
|
-
static
|
|
123
|
+
static size_t contraction_len(const char *text, size_t pos) {
|
|
124
|
+
unsigned char c = (unsigned char)text[pos];
|
|
125
|
+
if (c == '\'') return 2;
|
|
126
|
+
return 4;
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
static char** pre_tokenize(const char *text, int *num_words) {
|
|
113
130
|
char **words = NULL;
|
|
114
131
|
int word_count = 0, word_capacity = 0;
|
|
115
|
-
size_t text_len = strlen(text)
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
132
|
+
size_t text_len = strlen(text);
|
|
133
|
+
|
|
134
|
+
if (text_len == 0) {
|
|
135
|
+
*num_words = 0;
|
|
136
|
+
return NULL;
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
#define ADD_WORD(ptr, len) do { \
|
|
140
|
+
char *w = malloc((len) + 1); \
|
|
141
|
+
if (!w) goto error; \
|
|
142
|
+
memcpy(w, ptr, len); \
|
|
143
|
+
w[len] = '\0'; \
|
|
144
|
+
if (word_count >= word_capacity) { \
|
|
145
|
+
word_capacity = word_capacity ? word_capacity * 2 : 16; \
|
|
146
|
+
char **nw = realloc(words, word_capacity * sizeof(char*)); \
|
|
147
|
+
if (!nw) { free(w); goto error; } \
|
|
148
|
+
words = nw; \
|
|
149
|
+
} \
|
|
150
|
+
words[word_count++] = w; \
|
|
151
|
+
} while(0)
|
|
152
|
+
|
|
153
|
+
size_t i = 0;
|
|
154
|
+
while (i < text_len) {
|
|
155
|
+
size_t char_len;
|
|
156
|
+
uint32_t cp = unicode_cpt_from_utf8(text + i, &char_len);
|
|
157
|
+
int cls = get_char_class(cp);
|
|
158
|
+
|
|
159
|
+
if (cls == CHAR_CLASS_NEWLINE) {
|
|
160
|
+
ADD_WORD(text + i, char_len);
|
|
161
|
+
i += char_len;
|
|
162
|
+
continue;
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
if (cls == CHAR_CLASS_SPACE) {
|
|
166
|
+
size_t space_start = i;
|
|
167
|
+
while (i < text_len) {
|
|
168
|
+
size_t cl;
|
|
169
|
+
uint32_t c = unicode_cpt_from_utf8(text + i, &cl);
|
|
170
|
+
int cc = get_char_class(c);
|
|
171
|
+
if (cc != CHAR_CLASS_SPACE) break;
|
|
172
|
+
i += cl;
|
|
133
173
|
}
|
|
134
|
-
|
|
174
|
+
if (i >= text_len) break;
|
|
175
|
+
size_t space_len = i - space_start;
|
|
176
|
+
ADD_WORD(text + space_start, space_len);
|
|
177
|
+
continue;
|
|
135
178
|
}
|
|
136
|
-
|
|
179
|
+
|
|
180
|
+
size_t start = i;
|
|
181
|
+
i += char_len;
|
|
182
|
+
|
|
183
|
+
while (i < text_len) {
|
|
184
|
+
size_t cl;
|
|
185
|
+
uint32_t c = unicode_cpt_from_utf8(text + i, &cl);
|
|
186
|
+
int ccls = get_char_class(c);
|
|
187
|
+
|
|
188
|
+
if (is_contraction(text, i, text_len)) {
|
|
189
|
+
size_t clen = contraction_len(text, i);
|
|
190
|
+
i += clen;
|
|
191
|
+
continue;
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
if (ccls != cls) break;
|
|
195
|
+
if (cls == CHAR_CLASS_NUMBER) {
|
|
196
|
+
int digits = 0;
|
|
197
|
+
size_t check = start;
|
|
198
|
+
while (check < i) {
|
|
199
|
+
size_t dl;
|
|
200
|
+
uint32_t dc = unicode_cpt_from_utf8(text + check, &dl);
|
|
201
|
+
if (get_char_class(dc) == CHAR_CLASS_NUMBER) digits++;
|
|
202
|
+
check += dl;
|
|
203
|
+
}
|
|
204
|
+
if (digits >= 3) break;
|
|
205
|
+
}
|
|
206
|
+
i += cl;
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
ADD_WORD(text + start, i - start);
|
|
137
210
|
}
|
|
211
|
+
|
|
212
|
+
#undef ADD_WORD
|
|
138
213
|
*num_words = word_count;
|
|
139
214
|
return words;
|
|
215
|
+
|
|
216
|
+
error:
|
|
217
|
+
for (int j = 0; j < word_count; j++) free(words[j]);
|
|
218
|
+
free(words);
|
|
219
|
+
*num_words = 0;
|
|
220
|
+
return NULL;
|
|
140
221
|
}
|
|
141
222
|
|
|
142
223
|
/* ------------------------------------------------------------------------- */
|
|
143
|
-
// BPE merge structures
|
|
144
|
-
typedef struct {
|
|
224
|
+
// BPE merge structures with hash table for O(1) lookup
|
|
225
|
+
/* Fallback so this section stays self-contained; a no-op when the constant
 * is already defined earlier in the file. */
#ifndef MERGE_HASH_SIZE
#define MERGE_HASH_SIZE 65537
#endif

/* One chained entry: the (left, right) pair, its rank, and the next entry
 * in the same bucket. */
typedef struct MergeHashNode {
    char *left;
    char *right;
    int rank;
    struct MergeHashNode *next;
} MergeHashNode;

/* Chained hash table of merges.  `table` is NULL when the table is
 * uninitialised, freed, or its allocation failed; every function below
 * tolerates that state. */
typedef struct {
    MergeHashNode **table;
    int table_size;
    int num_merges;
} BPEMergeTable;

/* 64-bit FNV-1a over left, a single ' ' separator byte, then right. */
static uint64_t merge_hash(const char *left, const char *right) {
    uint64_t h = 0xcbf29ce484222325ULL;
    while (*left) { h ^= (uint64_t)(unsigned char)*left++; h *= 0x100000001b3ULL; }
    h ^= (uint64_t)' ';
    h *= 0x100000001b3ULL;
    while (*right) { h ^= (uint64_t)(unsigned char)*right++; h *= 0x100000001b3ULL; }
    return h;
}

/* Heap copy of a NUL-terminated string; returns NULL on allocation failure.
 * Caller frees. */
static char *merge_copy_string(const char *s) {
    const size_t n = strlen(s) + 1;
    char *copy = malloc(n);
    if (copy) memcpy(copy, s, n);
    return copy;
}

/* Allocate MERGE_HASH_SIZE empty buckets.  On allocation failure the table
 * is left empty (table == NULL, table_size == 0). */
static void bpe_merge_table_init(BPEMergeTable *table) {
    table->table = calloc(MERGE_HASH_SIZE, sizeof *table->table);
    table->table_size = table->table ? MERGE_HASH_SIZE : 0;
    table->num_merges = 0;
}

/* Insert (left, right) -> rank at the head of its bucket.  The strings are
 * copied.  The call is a silent no-op when the table is unusable, an
 * argument is NULL, or memory runs out; no half-built node is ever linked. */
static void bpe_merge_table_add(BPEMergeTable *table, const char *left, const char *right, int rank) {
    if (!table->table || table->table_size <= 0 || !left || !right) return;

    MergeHashNode *n = malloc(sizeof *n);
    if (!n) return;
    n->left = merge_copy_string(left);
    n->right = merge_copy_string(right);
    if (!n->left || !n->right) {
        free(n->left);
        free(n->right);
        free(n);
        return;
    }
    n->rank = rank;

    const size_t slot = (size_t)(merge_hash(left, right) % (uint64_t)table->table_size);
    n->next = table->table[slot];
    table->table[slot] = n;
    table->num_merges++;
}

/* Release every node and the bucket array, leaving the table empty so later
 * add/rank calls are safe no-ops. */
static void bpe_merge_table_free(BPEMergeTable *table) {
    if (!table->table) return;
    for (int i = 0; i < table->table_size; i++) {
        MergeHashNode *n = table->table[i];
        while (n) {
            MergeHashNode *next = n->next;
            free(n->left);
            free(n->right);
            free(n);
            n = next;
        }
    }
    free(table->table);
    table->table = NULL;
    table->table_size = 0;
    table->num_merges = 0;
}

/* Return the rank stored for (left, right), or -1 when the pair is absent
 * or the table is unusable. */
static int bpe_merge_rank(const BPEMergeTable *table, const char *left, const char *right) {
    if (!table->table || table->table_size <= 0 || !left || !right) return -1;

    const size_t slot = (size_t)(merge_hash(left, right) % (uint64_t)table->table_size);
    for (const MergeHashNode *n = table->table[slot]; n; n = n->next) {
        if (strcmp(n->left, left) == 0 && strcmp(n->right, right) == 0)
            return n->rank;
    }
    return -1;
}
|
|
191
291
|
|
|
192
292
|
/* ------------------------------------------------------------------------- */
|
|
193
|
-
// BPE tokenization
|
|
293
|
+
// BPE tokenization (correct iterative algorithm)
|
|
194
294
|
typedef struct {
|
|
195
|
-
char *text;
|
|
295
|
+
const char *text;
|
|
196
296
|
int start, end;
|
|
197
297
|
int prev, next;
|
|
198
298
|
int used;
|
|
199
299
|
} BPESymbol;
|
|
200
300
|
|
|
201
|
-
static
|
|
301
|
+
static int text_to_id(void *vocab_data, const char *text);
|
|
302
|
+
|
|
303
|
+
static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word,
|
|
304
|
+
void *vocab_data, int *token_ids, int *num_tokens) {
|
|
202
305
|
int word_len = strlen(word);
|
|
306
|
+
if (word_len == 0) return;
|
|
307
|
+
|
|
203
308
|
int num_symbols = 0;
|
|
204
309
|
BPESymbol *symbols = malloc(word_len * sizeof(BPESymbol));
|
|
310
|
+
if (!symbols) return;
|
|
311
|
+
|
|
205
312
|
int offset = 0;
|
|
206
313
|
while (offset < word_len) {
|
|
207
314
|
int char_len = unicode_len_utf8(word[offset]);
|
|
208
|
-
|
|
315
|
+
if (offset + char_len > word_len) char_len = word_len - offset;
|
|
316
|
+
symbols[num_symbols].text = word;
|
|
209
317
|
symbols[num_symbols].start = offset;
|
|
210
318
|
symbols[num_symbols].end = offset + char_len;
|
|
211
319
|
symbols[num_symbols].prev = num_symbols - 1;
|
|
@@ -215,6 +323,8 @@ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int
|
|
|
215
323
|
num_symbols++;
|
|
216
324
|
}
|
|
217
325
|
|
|
326
|
+
if (num_symbols > 0) symbols[num_symbols - 1].next = -1;
|
|
327
|
+
|
|
218
328
|
if (num_symbols <= 1) {
|
|
219
329
|
int id = text_to_id(vocab_data, word);
|
|
220
330
|
if (id != -1) token_ids[(*num_tokens)++] = id;
|
|
@@ -222,56 +332,61 @@ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int
|
|
|
222
332
|
return;
|
|
223
333
|
}
|
|
224
334
|
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
335
|
+
while (1) {
|
|
336
|
+
int best_rank = INT_MAX;
|
|
337
|
+
int best_idx = -1;
|
|
338
|
+
|
|
339
|
+
int idx = 0;
|
|
340
|
+
while (idx != -1) {
|
|
341
|
+
int next = symbols[idx].next;
|
|
342
|
+
if (next != -1 && symbols[idx].used && symbols[next].used) {
|
|
343
|
+
int left_len = symbols[idx].end - symbols[idx].start;
|
|
344
|
+
int right_len = symbols[next].end - symbols[next].start;
|
|
345
|
+
char *left_str = malloc(left_len + 1);
|
|
346
|
+
char *right_str = malloc(right_len + 1);
|
|
347
|
+
if (left_str && right_str) {
|
|
348
|
+
memcpy(left_str, word + symbols[idx].start, left_len);
|
|
349
|
+
left_str[left_len] = '\0';
|
|
350
|
+
memcpy(right_str, word + symbols[next].start, right_len);
|
|
351
|
+
right_str[right_len] = '\0';
|
|
352
|
+
int rank = bpe_merge_rank(merges, left_str, right_str);
|
|
353
|
+
if (rank != -1 && rank < best_rank) {
|
|
354
|
+
best_rank = rank;
|
|
355
|
+
best_idx = idx;
|
|
356
|
+
}
|
|
357
|
+
}
|
|
358
|
+
free(left_str);
|
|
359
|
+
free(right_str);
|
|
242
360
|
}
|
|
243
|
-
|
|
361
|
+
idx = symbols[idx].next;
|
|
244
362
|
}
|
|
245
|
-
}
|
|
246
|
-
for (int i = 0; i < num_bigrams - 1; i++)
|
|
247
|
-
for (int j = i+1; j < num_bigrams; j++)
|
|
248
|
-
if (bigrams[i].rank > bigrams[j].rank) {
|
|
249
|
-
Bigram tmp = bigrams[i];
|
|
250
|
-
bigrams[i] = bigrams[j];
|
|
251
|
-
bigrams[j] = tmp;
|
|
252
|
-
}
|
|
253
363
|
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
int
|
|
257
|
-
|
|
258
|
-
symbols[
|
|
259
|
-
symbols[
|
|
260
|
-
|
|
261
|
-
if (symbols[
|
|
364
|
+
if (best_idx == -1) break;
|
|
365
|
+
|
|
366
|
+
int right_idx = symbols[best_idx].next;
|
|
367
|
+
symbols[best_idx].end = symbols[right_idx].end;
|
|
368
|
+
symbols[best_idx].next = symbols[right_idx].next;
|
|
369
|
+
symbols[right_idx].used = 0;
|
|
370
|
+
|
|
371
|
+
if (symbols[right_idx].next != -1) {
|
|
372
|
+
symbols[symbols[right_idx].next].prev = best_idx;
|
|
373
|
+
}
|
|
262
374
|
}
|
|
263
375
|
|
|
264
376
|
for (int i = 0; i < num_symbols; i++) {
|
|
265
|
-
if (
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
substr
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
377
|
+
if (symbols[i].used) {
|
|
378
|
+
int len = symbols[i].end - symbols[i].start;
|
|
379
|
+
char *substr = malloc(len + 1);
|
|
380
|
+
if (substr) {
|
|
381
|
+
memcpy(substr, word + symbols[i].start, len);
|
|
382
|
+
substr[len] = '\0';
|
|
383
|
+
int id = text_to_id(vocab_data, substr);
|
|
384
|
+
if (id != -1) token_ids[(*num_tokens)++] = id;
|
|
385
|
+
free(substr);
|
|
386
|
+
}
|
|
272
387
|
}
|
|
273
388
|
}
|
|
274
|
-
free(
|
|
389
|
+
free(symbols);
|
|
275
390
|
}
|
|
276
391
|
|
|
277
392
|
/* ------------------------------------------------------------------------- */
|
|
@@ -329,36 +444,41 @@ typedef struct {
|
|
|
329
444
|
int vocab_size;
|
|
330
445
|
int dim;
|
|
331
446
|
char **tokens;
|
|
332
|
-
float *float_data;
|
|
333
|
-
void *tensor_data;
|
|
334
|
-
int tensor_type;
|
|
335
447
|
void *mapped;
|
|
336
448
|
size_t mapped_size;
|
|
337
449
|
HashNode **table;
|
|
338
450
|
BPEMergeTable merges;
|
|
339
|
-
RegexPattern *pre_patterns;
|
|
340
|
-
int num_pre_patterns;
|
|
341
451
|
int unknown_token_id;
|
|
342
452
|
int bos_token_id;
|
|
343
453
|
int eos_token_id;
|
|
344
454
|
int vocab_type;
|
|
345
455
|
char space_marker[8];
|
|
456
|
+
int space_marker_len;
|
|
457
|
+
const void *raw_tensor_data;
|
|
458
|
+
int tensor_type;
|
|
459
|
+
size_t row_bytes;
|
|
460
|
+
int need_transpose;
|
|
461
|
+
uint64_t raw_dim0, raw_dim1;
|
|
462
|
+
int normalize;
|
|
346
463
|
} EmbedModel;
|
|
347
464
|
|
|
348
465
|
typedef struct {
|
|
349
466
|
EmbedModel *model;
|
|
350
467
|
} ruby_embedder;
|
|
351
468
|
|
|
352
|
-
static
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
469
|
+
static uint64_t vocab_hash(const char *s) {
|
|
470
|
+
uint64_t h = 0xcbf29ce484222325ULL;
|
|
471
|
+
while (*s) {
|
|
472
|
+
h ^= (uint64_t)(unsigned char)*s++;
|
|
473
|
+
h *= 0x100000001b3ULL;
|
|
474
|
+
}
|
|
356
475
|
return h % HASH_SIZE;
|
|
357
476
|
}
|
|
358
477
|
|
|
359
478
|
static void hset(EmbedModel *m, char *k, int id) {
|
|
360
|
-
|
|
479
|
+
uint64_t h = vocab_hash(k);
|
|
361
480
|
HashNode *n = malloc(sizeof(*n));
|
|
481
|
+
if (!n) return;
|
|
362
482
|
n->key = k;
|
|
363
483
|
n->id = id;
|
|
364
484
|
n->next = m->table[h];
|
|
@@ -366,7 +486,8 @@ static void hset(EmbedModel *m, char *k, int id) {
|
|
|
366
486
|
}
|
|
367
487
|
|
|
368
488
|
static int hget(EmbedModel *m, const char *k) {
|
|
369
|
-
|
|
489
|
+
if (!k || !m->table) return -1;
|
|
490
|
+
HashNode *n = m->table[vocab_hash(k)];
|
|
370
491
|
while (n) {
|
|
371
492
|
if (strcmp(n->key, k) == 0) return n->id;
|
|
372
493
|
n = n->next;
|
|
@@ -386,30 +507,51 @@ static void *map_file(const char *path, size_t *size) {
|
|
|
386
507
|
struct stat st;
|
|
387
508
|
if (fstat(fd, &st) != 0) { close(fd); return NULL; }
|
|
388
509
|
*size = st.st_size;
|
|
510
|
+
if (*size == 0) { close(fd); return NULL; }
|
|
389
511
|
void *data = mmap(NULL, *size, PROT_READ, MAP_PRIVATE, fd, 0);
|
|
390
512
|
close(fd);
|
|
391
513
|
return data == MAP_FAILED ? NULL : data;
|
|
392
514
|
}
|
|
393
515
|
|
|
394
516
|
/* ------------------------------------------------------------------------- */
|
|
395
|
-
// FP16 conversion
|
|
517
|
+
// FP16 conversion (corrected)
|
|
396
518
|
static float fp16_to_fp32(uint16_t h) {
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
519
|
+
const uint32_t sign = (h >> 15) & 1;
|
|
520
|
+
const uint32_t exp = (h >> 10) & 0x1F;
|
|
521
|
+
const uint32_t mant = h & 0x3FF;
|
|
522
|
+
|
|
523
|
+
uint32_t f;
|
|
524
|
+
if (exp == 0) {
|
|
525
|
+
if (mant == 0) {
|
|
526
|
+
f = sign << 31;
|
|
527
|
+
} else {
|
|
528
|
+
uint32_t e = 0;
|
|
529
|
+
uint32_t m = mant;
|
|
530
|
+
while (!(m & 0x400)) { m <<= 1; e++; }
|
|
531
|
+
f = (sign << 31) | ((127 - 15 - e + 1) << 23) | ((m & 0x3FF) << 13);
|
|
532
|
+
}
|
|
533
|
+
} else if (exp == 31) {
|
|
534
|
+
f = (sign << 31) | (0xFF << 23) | (mant << 13);
|
|
535
|
+
} else {
|
|
536
|
+
f = (sign << 31) | ((exp + 127 - 15) << 23) | (mant << 13);
|
|
537
|
+
}
|
|
538
|
+
|
|
539
|
+
float result;
|
|
540
|
+
memcpy(&result, &f, sizeof(result));
|
|
541
|
+
return result;
|
|
403
542
|
}
|
|
404
543
|
|
|
405
544
|
/* ------------------------------------------------------------------------- */
|
|
406
|
-
// Block dequantization functions
|
|
545
|
+
// Block dequantization functions (correct sizes)
|
|
407
546
|
static void dequantize_row_q4_0(const void *vx, float *y, int k) {
|
|
408
|
-
const int nb = k /
|
|
547
|
+
const int nb = k / QK8_0;
|
|
409
548
|
const uint8_t *x = vx;
|
|
410
549
|
for (int i = 0; i < nb; i++) {
|
|
411
|
-
const
|
|
412
|
-
|
|
550
|
+
const uint8_t *block = x + i * 18;
|
|
551
|
+
uint16_t d16;
|
|
552
|
+
memcpy(&d16, block, 2);
|
|
553
|
+
const float d = fp16_to_fp32(d16);
|
|
554
|
+
const uint8_t *q = block + 2;
|
|
413
555
|
for (int j = 0; j < 32; j++) {
|
|
414
556
|
const int v = (q[j/2] >> (4*(j%2))) & 0x0F;
|
|
415
557
|
y[i*32 + j] = (v - 8.0f) * d;
|
|
@@ -418,12 +560,16 @@ static void dequantize_row_q4_0(const void *vx, float *y, int k) {
|
|
|
418
560
|
}
|
|
419
561
|
|
|
420
562
|
static void dequantize_row_q4_1(const void *vx, float *y, int k) {
|
|
421
|
-
const int nb = k /
|
|
563
|
+
const int nb = k / QK8_0;
|
|
422
564
|
const uint8_t *x = vx;
|
|
423
565
|
for (int i = 0; i < nb; i++) {
|
|
424
|
-
const
|
|
425
|
-
|
|
426
|
-
|
|
566
|
+
const uint8_t *block = x + i * 20;
|
|
567
|
+
uint16_t d16, m16;
|
|
568
|
+
memcpy(&d16, block, 2);
|
|
569
|
+
memcpy(&m16, block + 2, 2);
|
|
570
|
+
const float d = fp16_to_fp32(d16);
|
|
571
|
+
const float m = fp16_to_fp32(m16);
|
|
572
|
+
const uint8_t *q = block + 4;
|
|
427
573
|
for (int j = 0; j < 32; j++) {
|
|
428
574
|
const int v = (q[j/2] >> (4*(j%2))) & 0x0F;
|
|
429
575
|
y[i*32 + j] = v * d + m;
|
|
@@ -432,14 +578,16 @@ static void dequantize_row_q4_1(const void *vx, float *y, int k) {
|
|
|
432
578
|
}
|
|
433
579
|
|
|
434
580
|
static void dequantize_row_q5_0(const void *vx, float *y, int k) {
|
|
435
|
-
const int nb = k /
|
|
581
|
+
const int nb = k / QK8_0;
|
|
436
582
|
const uint8_t *x = vx;
|
|
437
583
|
for (int i = 0; i < nb; i++) {
|
|
438
|
-
const
|
|
439
|
-
|
|
440
|
-
|
|
584
|
+
const uint8_t *block = x + i * 22;
|
|
585
|
+
uint16_t d16;
|
|
586
|
+
memcpy(&d16, block, 2);
|
|
587
|
+
const float d = fp16_to_fp32(d16);
|
|
441
588
|
uint32_t qh32;
|
|
442
|
-
memcpy(&qh32,
|
|
589
|
+
memcpy(&qh32, block + 2, 4);
|
|
590
|
+
const uint8_t *ql = block + 6;
|
|
443
591
|
for (int j = 0; j < 32; j++) {
|
|
444
592
|
const uint8_t vh = (qh32 >> j) & 1;
|
|
445
593
|
const int v = ((ql[j/2] >> (4*(j%2))) & 0x0F) | (vh << 4);
|
|
@@ -449,15 +597,18 @@ static void dequantize_row_q5_0(const void *vx, float *y, int k) {
|
|
|
449
597
|
}
|
|
450
598
|
|
|
451
599
|
static void dequantize_row_q5_1(const void *vx, float *y, int k) {
|
|
452
|
-
const int nb = k /
|
|
600
|
+
const int nb = k / QK8_0;
|
|
453
601
|
const uint8_t *x = vx;
|
|
454
602
|
for (int i = 0; i < nb; i++) {
|
|
455
|
-
const
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
603
|
+
const uint8_t *block = x + i * 24;
|
|
604
|
+
uint16_t d16, m16;
|
|
605
|
+
memcpy(&d16, block, 2);
|
|
606
|
+
memcpy(&m16, block + 2, 2);
|
|
607
|
+
const float d = fp16_to_fp32(d16);
|
|
608
|
+
const float m = fp16_to_fp32(m16);
|
|
459
609
|
uint32_t qh32;
|
|
460
|
-
memcpy(&qh32,
|
|
610
|
+
memcpy(&qh32, block + 4, 4);
|
|
611
|
+
const uint8_t *ql = block + 8;
|
|
461
612
|
for (int j = 0; j < 32; j++) {
|
|
462
613
|
const uint8_t vh = (qh32 >> j) & 1;
|
|
463
614
|
const int v = ((ql[j/2] >> (4*(j%2))) & 0x0F) | (vh << 4);
|
|
@@ -467,11 +618,13 @@ static void dequantize_row_q5_1(const void *vx, float *y, int k) {
|
|
|
467
618
|
}
|
|
468
619
|
|
|
469
620
|
static void dequantize_row_q8_0(const void *vx, float *y, int k) {
|
|
470
|
-
const int nb = k /
|
|
621
|
+
const int nb = k / QK8_0;
|
|
471
622
|
const uint8_t *x = vx;
|
|
472
623
|
for (int i = 0; i < nb; i++) {
|
|
473
|
-
const
|
|
474
|
-
|
|
624
|
+
const uint8_t *block = x + i * 34;
|
|
625
|
+
float d;
|
|
626
|
+
memcpy(&d, block, 4);
|
|
627
|
+
const int8_t *q = (const int8_t*)(block + 4);
|
|
475
628
|
for (int j = 0; j < 32; j++) {
|
|
476
629
|
y[i*32 + j] = (float)q[j] * d;
|
|
477
630
|
}
|
|
@@ -479,197 +632,315 @@ static void dequantize_row_q8_0(const void *vx, float *y, int k) {
|
|
|
479
632
|
}
|
|
480
633
|
|
|
481
634
|
static void dequantize_row_q8_1(const void *vx, float *y, int k) {
|
|
482
|
-
const int nb = k /
|
|
635
|
+
const int nb = k / QK8_0;
|
|
483
636
|
const uint8_t *x = vx;
|
|
484
637
|
for (int i = 0; i < nb; i++) {
|
|
485
|
-
const
|
|
486
|
-
|
|
487
|
-
|
|
638
|
+
const uint8_t *block = x + i * 40;
|
|
639
|
+
float d, s;
|
|
640
|
+
memcpy(&d, block, 4);
|
|
641
|
+
memcpy(&s, block + 4, 4);
|
|
642
|
+
const int8_t *q = (const int8_t*)(block + 8);
|
|
488
643
|
for (int j = 0; j < 32; j++) {
|
|
489
644
|
y[i*32 + j] = (float)q[j] * d + s;
|
|
490
645
|
}
|
|
491
646
|
}
|
|
492
647
|
}
|
|
493
648
|
|
|
649
|
+
// K-quant scale helpers
|
|
650
|
+
static inline void get_scale_min_k4(int j, const uint8_t *q, uint8_t *d, uint8_t *m) {
|
|
651
|
+
if (j < 4) {
|
|
652
|
+
*d = q[j] & 63;
|
|
653
|
+
*m = q[j + 4] & 63;
|
|
654
|
+
} else {
|
|
655
|
+
*d = (q[j+4] & 0xF) | ((q[j-3] >> 6) << 4);
|
|
656
|
+
*m = (q[j+4] >> 4) | ((q[j-1] >> 6) << 4);
|
|
657
|
+
}
|
|
658
|
+
}
|
|
659
|
+
|
|
494
660
|
static void dequantize_row_q2_K(const void *vx, float *y, int k) {
|
|
495
|
-
const int nb = k /
|
|
661
|
+
const int nb = k / QK_K;
|
|
496
662
|
const uint8_t *x = vx;
|
|
497
663
|
for (int i = 0; i < nb; i++) {
|
|
498
|
-
const
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
664
|
+
const uint8_t *block = x + i * 84;
|
|
665
|
+
uint16_t d16, dmin16;
|
|
666
|
+
memcpy(&d16, block, 2);
|
|
667
|
+
memcpy(&dmin16, block + 2, 2);
|
|
668
|
+
const float d = fp16_to_fp32(d16);
|
|
669
|
+
const float min = fp16_to_fp32(dmin16);
|
|
670
|
+
const uint8_t *scales = block + 4;
|
|
671
|
+
const uint8_t *q = block + 20;
|
|
672
|
+
for (int j = 0; j < QK_K; j += 64) {
|
|
673
|
+
const float dl = d * (scales[j/64] & 0xF);
|
|
674
|
+
const float ml = min * (scales[j/64] >> 4);
|
|
675
|
+
for (int l = 0; l < 64; l++) {
|
|
506
676
|
const int v = (q[(j+l)/4] >> (2*((j+l)%4))) & 0x03;
|
|
507
|
-
|
|
508
|
-
const float ml = m * (ms - 32);
|
|
509
|
-
y[i*256 + j + l] = v * dl + ml;
|
|
677
|
+
y[i*QK_K + j + l] = v * dl + ml;
|
|
510
678
|
}
|
|
511
679
|
}
|
|
512
680
|
}
|
|
513
681
|
}
|
|
514
682
|
|
|
515
683
|
static void dequantize_row_q3_K(const void *vx, float *y, int k) {
|
|
516
|
-
const int nb = k /
|
|
684
|
+
const int nb = k / QK_K;
|
|
517
685
|
const uint8_t *x = vx;
|
|
518
686
|
for (int i = 0; i < nb; i++) {
|
|
519
|
-
const
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
const
|
|
523
|
-
|
|
687
|
+
const uint8_t *block = x + i * 110;
|
|
688
|
+
uint16_t d16;
|
|
689
|
+
memcpy(&d16, block, 2);
|
|
690
|
+
const float d = fp16_to_fp32(d16);
|
|
691
|
+
const uint8_t *hmask = block + 2;
|
|
692
|
+
const uint8_t *q = block + 34;
|
|
693
|
+
const uint8_t *scales = block + 98;
|
|
694
|
+
for (int j = 0; j < QK_K; j += 64) {
|
|
524
695
|
const uint8_t ls1 = scales[j/64] & 0x1F;
|
|
525
|
-
const uint8_t ls2 = (scales[j/64] >>
|
|
526
|
-
const uint8_t
|
|
696
|
+
const uint8_t ls2 = (scales[j/64] >> 5) | ((scales[j/64 + 1] & 0x7) << 3);
|
|
697
|
+
const uint8_t ls3 = ((scales[j/64 + 1] >> 3) & 0x1F);
|
|
698
|
+
const uint8_t ls4 = (scales[j/64 + 1] >> 8);
|
|
527
699
|
for (int l = 0; l < 64; l++) {
|
|
528
700
|
int v = (q[(j+l)/2] >> (4*((j+l)%2))) & 0x0F;
|
|
529
|
-
const int bit = (
|
|
701
|
+
const int bit = (hmask[(j+l)/8] >> ((j+l)%8)) & 1;
|
|
530
702
|
v |= bit << 4;
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
703
|
+
float ls;
|
|
704
|
+
if (l < 16) ls = ls1;
|
|
705
|
+
else if (l < 32) ls = ls2;
|
|
706
|
+
else if (l < 48) ls = ls3;
|
|
707
|
+
else ls = ls4;
|
|
708
|
+
y[i*QK_K + j + l] = (v - 32.0f) * d * ls;
|
|
534
709
|
}
|
|
535
710
|
}
|
|
536
711
|
}
|
|
537
712
|
}
|
|
538
713
|
|
|
539
714
|
static void dequantize_row_q4_K(const void *vx, float *y, int k) {
|
|
540
|
-
const int nb = k /
|
|
715
|
+
const int nb = k / QK_K;
|
|
541
716
|
const uint8_t *x = vx;
|
|
542
717
|
for (int i = 0; i < nb; i++) {
|
|
543
|
-
const
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
718
|
+
const uint8_t *block = x + i * 144;
|
|
719
|
+
uint16_t d16, dmin16;
|
|
720
|
+
memcpy(&d16, block, 2);
|
|
721
|
+
memcpy(&dmin16, block + 2, 2);
|
|
722
|
+
const float d = fp16_to_fp32(d16);
|
|
723
|
+
const float min = fp16_to_fp32(dmin16);
|
|
724
|
+
const uint8_t *scales = block + 4;
|
|
725
|
+
const uint8_t *q = block + 16;
|
|
726
|
+
int is = 0;
|
|
727
|
+
for (int j = 0; j < QK_K; j += 64) {
|
|
728
|
+
uint8_t sc, m;
|
|
729
|
+
get_scale_min_k4(is, scales, &sc, &m);
|
|
730
|
+
float d1 = d * sc;
|
|
731
|
+
float m1 = min * m;
|
|
732
|
+
get_scale_min_k4(is + 1, scales, &sc, &m);
|
|
733
|
+
float d2 = d * sc;
|
|
734
|
+
float m2 = min * m;
|
|
735
|
+
for (int l = 0; l < 32; l++) {
|
|
736
|
+
y[i*QK_K + j + l] = d1 * (q[l] & 0xF) - m1;
|
|
737
|
+
}
|
|
550
738
|
for (int l = 0; l < 32; l++) {
|
|
551
|
-
|
|
552
|
-
const float dl = d * (ls - 32);
|
|
553
|
-
const float ml = m * (ms - 2);
|
|
554
|
-
y[i*256 + j + l] = v * dl + ml;
|
|
739
|
+
y[i*QK_K + j + 32 + l] = d2 * (q[l] >> 4) - m2;
|
|
555
740
|
}
|
|
741
|
+
q += 32;
|
|
742
|
+
is += 2;
|
|
556
743
|
}
|
|
557
744
|
}
|
|
558
745
|
}
|
|
559
746
|
|
|
560
747
|
static void dequantize_row_q5_K(const void *vx, float *y, int k) {
|
|
561
|
-
const int nb = k /
|
|
748
|
+
const int nb = k / QK_K;
|
|
562
749
|
const uint8_t *x = vx;
|
|
563
750
|
for (int i = 0; i < nb; i++) {
|
|
564
|
-
const
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
const
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
751
|
+
const uint8_t *block = x + i * 176;
|
|
752
|
+
uint16_t d16, dmin16;
|
|
753
|
+
memcpy(&d16, block, 2);
|
|
754
|
+
memcpy(&dmin16, block + 2, 2);
|
|
755
|
+
const float d = fp16_to_fp32(d16);
|
|
756
|
+
const float min = fp16_to_fp32(dmin16);
|
|
757
|
+
const uint8_t *scales = block + 4;
|
|
758
|
+
const uint8_t *qh = block + 16;
|
|
759
|
+
const uint8_t *ql = block + 48;
|
|
760
|
+
int is = 0;
|
|
761
|
+
for (int j = 0; j < QK_K; j += 64) {
|
|
762
|
+
uint8_t sc, m;
|
|
763
|
+
get_scale_min_k4(is, scales, &sc, &m);
|
|
764
|
+
float d1 = d * sc;
|
|
765
|
+
float m1 = min * m;
|
|
766
|
+
get_scale_min_k4(is + 1, scales, &sc, &m);
|
|
767
|
+
float d2 = d * sc;
|
|
768
|
+
float m2 = min * m;
|
|
572
769
|
for (int l = 0; l < 32; l++) {
|
|
573
|
-
int
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
770
|
+
int vh = (qh[j/64 * 4 + l/8] >> (l%8)) & 1;
|
|
771
|
+
int v = (ql[l] & 0xF) | (vh << 4);
|
|
772
|
+
y[i*QK_K + j + l] = d1 * v - m1;
|
|
773
|
+
}
|
|
774
|
+
for (int l = 0; l < 32; l++) {
|
|
775
|
+
int vh = (qh[j/64 * 4 + 4 + l/8] >> (l%8)) & 1;
|
|
776
|
+
int v = (ql[l] >> 4) | (vh << 4);
|
|
777
|
+
y[i*QK_K + j + 32 + l] = d2 * v - m2;
|
|
579
778
|
}
|
|
779
|
+
ql += 32;
|
|
780
|
+
is += 2;
|
|
580
781
|
}
|
|
581
782
|
}
|
|
582
783
|
}
|
|
583
784
|
|
|
584
785
|
static void dequantize_row_q6_K(const void *vx, float *y, int k) {
|
|
585
|
-
const int nb = k /
|
|
786
|
+
const int nb = k / QK_K;
|
|
586
787
|
const uint8_t *x = vx;
|
|
587
788
|
for (int i = 0; i < nb; i++) {
|
|
588
|
-
const
|
|
589
|
-
const uint8_t *
|
|
590
|
-
const uint8_t *qh =
|
|
591
|
-
const
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
v
|
|
598
|
-
y[i*
|
|
789
|
+
const uint8_t *block = x + i * 210;
|
|
790
|
+
const uint8_t *ql = block;
|
|
791
|
+
const uint8_t *qh = block + 128;
|
|
792
|
+
const int8_t *scales = (const int8_t*)(block + 192);
|
|
793
|
+
uint16_t d16;
|
|
794
|
+
memcpy(&d16, block + 208, 2);
|
|
795
|
+
const float d = fp16_to_fp32(d16);
|
|
796
|
+
for (int j = 0; j < QK_K; j += 128) {
|
|
797
|
+
for (int l = 0; l < 32; l++) {
|
|
798
|
+
int v = (ql[j/2 + l] & 0xF) | (((qh[j/4 + l/2] >> ((l%2)*4)) & 0xF) << 4);
|
|
799
|
+
y[i*QK_K + j + l] = v * d * scales[j/128 * 8 + l/4];
|
|
800
|
+
}
|
|
801
|
+
for (int l = 0; l < 32; l++) {
|
|
802
|
+
int v = (ql[j/2 + 32 + l] >> 4) | (((qh[j/4 + 16 + l/2] >> ((l%2)*4)) & 0xF) << 4);
|
|
803
|
+
y[i*QK_K + j + 32 + l] = v * d * scales[j/128 * 8 + 8 + l/4];
|
|
804
|
+
}
|
|
805
|
+
for (int l = 0; l < 32; l++) {
|
|
806
|
+
int v = (ql[j/2 + 64 + l] & 0xF) | (((qh[j/4 + 32 + l/2] >> ((l%2)*4)) & 0xF) << 4);
|
|
807
|
+
y[i*QK_K + j + 64 + l] = v * d * scales[j/128 * 8 + 4 + l/4];
|
|
808
|
+
}
|
|
809
|
+
for (int l = 0; l < 32; l++) {
|
|
810
|
+
int v = (ql[j/2 + 96 + l] >> 4) | (((qh[j/4 + 48 + l/2] >> ((l%2)*4)) & 0xF) << 4);
|
|
811
|
+
y[i*QK_K + j + 96 + l] = v * d * scales[j/128 * 8 + 12 + l/4];
|
|
599
812
|
}
|
|
600
813
|
}
|
|
601
814
|
}
|
|
602
815
|
}
|
|
603
816
|
|
|
604
817
|
static void dequantize_row_q8_K(const void *vx, float *y, int k) {
|
|
605
|
-
const int nb = k /
|
|
818
|
+
const int nb = k / QK_K;
|
|
606
819
|
const uint8_t *x = vx;
|
|
607
820
|
for (int i = 0; i < nb; i++) {
|
|
608
|
-
const
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
y[i*256 + j + l] = (float)q[j+l] * d * ls;
|
|
615
|
-
}
|
|
821
|
+
const uint8_t *block = x + i * 292;
|
|
822
|
+
float d;
|
|
823
|
+
memcpy(&d, block, 4);
|
|
824
|
+
const int8_t *q = (const int8_t*)(block + 4);
|
|
825
|
+
for (int j = 0; j < QK_K; j++) {
|
|
826
|
+
y[i*QK_K + j] = (float)q[j] * d;
|
|
616
827
|
}
|
|
617
828
|
}
|
|
618
829
|
}
|
|
619
830
|
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
return out;
|
|
831
|
+
// Lazy single-row dequantization
|
|
832
|
+
static void dequantize_row_lazy(const EmbedModel *m, int row, float *out) {
|
|
833
|
+
if (!m->raw_tensor_data || row < 0 || row >= m->vocab_size) {
|
|
834
|
+
memset(out, 0, sizeof(float) * m->dim);
|
|
835
|
+
return;
|
|
626
836
|
}
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
837
|
+
|
|
838
|
+
const uint8_t *raw;
|
|
839
|
+
int effective_cols;
|
|
840
|
+
|
|
841
|
+
if (m->need_transpose) {
|
|
842
|
+
int src_row_size;
|
|
843
|
+
switch (m->tensor_type) {
|
|
844
|
+
case GGML_TYPE_F32: src_row_size = m->raw_dim1 * sizeof(float); break;
|
|
845
|
+
case GGML_TYPE_F16: src_row_size = m->raw_dim1 * sizeof(uint16_t); break;
|
|
846
|
+
default: {
|
|
847
|
+
size_t rb = 0;
|
|
848
|
+
int nc = (int)m->raw_dim1;
|
|
849
|
+
switch (m->tensor_type) {
|
|
850
|
+
case GGML_TYPE_Q4_0: rb = (nc / 32) * 18; break;
|
|
851
|
+
case GGML_TYPE_Q4_1: rb = (nc / 32) * 20; break;
|
|
852
|
+
case GGML_TYPE_Q5_0: rb = (nc / 32) * 22; break;
|
|
853
|
+
case GGML_TYPE_Q5_1: rb = (nc / 32) * 24; break;
|
|
854
|
+
case GGML_TYPE_Q8_0: rb = (nc / 32) * 34; break;
|
|
855
|
+
case GGML_TYPE_Q8_1: rb = (nc / 32) * 40; break;
|
|
856
|
+
case GGML_TYPE_Q2_K: rb = (nc / 256) * 84; break;
|
|
857
|
+
case GGML_TYPE_Q3_K: rb = (nc / 256) * 110; break;
|
|
858
|
+
case GGML_TYPE_Q4_K: rb = (nc / 256) * 144; break;
|
|
859
|
+
case GGML_TYPE_Q5_K: rb = (nc / 256) * 176; break;
|
|
860
|
+
case GGML_TYPE_Q6_K: rb = (nc / 256) * 210; break;
|
|
861
|
+
case GGML_TYPE_Q8_K: rb = (nc / 256) * 292; break;
|
|
862
|
+
default: src_row_size = 0; return;
|
|
863
|
+
}
|
|
864
|
+
src_row_size = (int)rb;
|
|
865
|
+
}
|
|
866
|
+
}
|
|
867
|
+
float *temp_row = malloc(m->raw_dim1 * sizeof(float));
|
|
868
|
+
if (!temp_row) return;
|
|
869
|
+
for (int col = 0; col < m->dim; col++) {
|
|
870
|
+
const uint8_t *src_row = (const uint8_t*)m->raw_tensor_data + col * src_row_size;
|
|
871
|
+
if (m->tensor_type == GGML_TYPE_F32) {
|
|
872
|
+
float val;
|
|
873
|
+
memcpy(&val, src_row + row * sizeof(float), sizeof(float));
|
|
874
|
+
out[col] = val;
|
|
875
|
+
} else if (m->tensor_type == GGML_TYPE_F16) {
|
|
876
|
+
uint16_t val;
|
|
877
|
+
memcpy(&val, src_row + row * sizeof(uint16_t), sizeof(uint16_t));
|
|
878
|
+
out[col] = fp16_to_fp32(val);
|
|
879
|
+
} else {
|
|
880
|
+
memset(out, 0, sizeof(float) * m->dim);
|
|
881
|
+
free(temp_row);
|
|
882
|
+
return;
|
|
883
|
+
}
|
|
633
884
|
}
|
|
634
|
-
|
|
885
|
+
free(temp_row);
|
|
886
|
+
return;
|
|
635
887
|
}
|
|
636
888
|
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
case
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
case
|
|
652
|
-
case
|
|
653
|
-
case
|
|
654
|
-
case
|
|
655
|
-
case
|
|
889
|
+
raw = (const uint8_t*)m->raw_tensor_data + row * m->row_bytes;
|
|
890
|
+
effective_cols = m->dim;
|
|
891
|
+
|
|
892
|
+
switch (m->tensor_type) {
|
|
893
|
+
case GGML_TYPE_F32:
|
|
894
|
+
memcpy(out, raw, effective_cols * sizeof(float));
|
|
895
|
+
break;
|
|
896
|
+
case GGML_TYPE_F16:
|
|
897
|
+
for (int j = 0; j < effective_cols; j++) {
|
|
898
|
+
uint16_t h;
|
|
899
|
+
memcpy(&h, raw + j * sizeof(uint16_t), sizeof(uint16_t));
|
|
900
|
+
out[j] = fp16_to_fp32(h);
|
|
901
|
+
}
|
|
902
|
+
break;
|
|
903
|
+
case GGML_TYPE_Q4_0: dequantize_row_q4_0(raw, out, effective_cols); break;
|
|
904
|
+
case GGML_TYPE_Q4_1: dequantize_row_q4_1(raw, out, effective_cols); break;
|
|
905
|
+
case GGML_TYPE_Q5_0: dequantize_row_q5_0(raw, out, effective_cols); break;
|
|
906
|
+
case GGML_TYPE_Q5_1: dequantize_row_q5_1(raw, out, effective_cols); break;
|
|
907
|
+
case GGML_TYPE_Q8_0: dequantize_row_q8_0(raw, out, effective_cols); break;
|
|
908
|
+
case GGML_TYPE_Q8_1: dequantize_row_q8_1(raw, out, effective_cols); break;
|
|
909
|
+
case GGML_TYPE_Q2_K: dequantize_row_q2_K(raw, out, effective_cols); break;
|
|
910
|
+
case GGML_TYPE_Q3_K: dequantize_row_q3_K(raw, out, effective_cols); break;
|
|
911
|
+
case GGML_TYPE_Q4_K: dequantize_row_q4_K(raw, out, effective_cols); break;
|
|
912
|
+
case GGML_TYPE_Q5_K: dequantize_row_q5_K(raw, out, effective_cols); break;
|
|
913
|
+
case GGML_TYPE_Q6_K: dequantize_row_q6_K(raw, out, effective_cols); break;
|
|
914
|
+
case GGML_TYPE_Q8_K: dequantize_row_q8_K(raw, out, effective_cols); break;
|
|
656
915
|
default:
|
|
657
|
-
|
|
658
|
-
return NULL;
|
|
916
|
+
memset(out, 0, sizeof(float) * effective_cols);
|
|
659
917
|
}
|
|
660
918
|
|
|
661
|
-
for (int
|
|
662
|
-
|
|
919
|
+
for (int j = 0; j < effective_cols; j++) {
|
|
920
|
+
if (isnan(out[j]) || isinf(out[j]) || fabsf(out[j]) > 1e10f) {
|
|
921
|
+
out[j] = 0.0f;
|
|
922
|
+
}
|
|
663
923
|
}
|
|
924
|
+
}
|
|
664
925
|
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
926
|
+
static size_t get_row_bytes(int type, int n_cols) {
|
|
927
|
+
switch (type) {
|
|
928
|
+
case GGML_TYPE_F32: return n_cols * sizeof(float);
|
|
929
|
+
case GGML_TYPE_F16: return n_cols * sizeof(uint16_t);
|
|
930
|
+
case GGML_TYPE_Q4_0: return (n_cols / 32) * 18;
|
|
931
|
+
case GGML_TYPE_Q4_1: return (n_cols / 32) * 20;
|
|
932
|
+
case GGML_TYPE_Q5_0: return (n_cols / 32) * 22;
|
|
933
|
+
case GGML_TYPE_Q5_1: return (n_cols / 32) * 24;
|
|
934
|
+
case GGML_TYPE_Q8_0: return (n_cols / 32) * 34;
|
|
935
|
+
case GGML_TYPE_Q8_1: return (n_cols / 32) * 40;
|
|
936
|
+
case GGML_TYPE_Q2_K: return (n_cols / 256) * 84;
|
|
937
|
+
case GGML_TYPE_Q3_K: return (n_cols / 256) * 110;
|
|
938
|
+
case GGML_TYPE_Q4_K: return (n_cols / 256) * 144;
|
|
939
|
+
case GGML_TYPE_Q5_K: return (n_cols / 256) * 176;
|
|
940
|
+
case GGML_TYPE_Q6_K: return (n_cols / 256) * 210;
|
|
941
|
+
case GGML_TYPE_Q8_K: return (n_cols / 256) * 292;
|
|
942
|
+
default: return 0;
|
|
671
943
|
}
|
|
672
|
-
return out;
|
|
673
944
|
}
|
|
674
945
|
|
|
675
946
|
/* ------------------------------------------------------------------------- */
|
|
@@ -685,7 +956,7 @@ static int skip_value(uint8_t **p, uint8_t *end, uint32_t type) {
|
|
|
685
956
|
case 9: {
|
|
686
957
|
uint32_t subtype = rd32(p, end);
|
|
687
958
|
uint64_t n = rd64(p, end);
|
|
688
|
-
for (uint64_t i = 0; i < n; i++)
|
|
959
|
+
for (uint64_t i = 0; i < n && i < 1000000; i++)
|
|
689
960
|
if (!skip_value(p, end, subtype)) return 0;
|
|
690
961
|
return 1;
|
|
691
962
|
}
|
|
@@ -711,13 +982,8 @@ static void free_model_contents(EmbedModel *m) {
|
|
|
711
982
|
}
|
|
712
983
|
free(m->table);
|
|
713
984
|
}
|
|
714
|
-
if (m->float_data) free(m->float_data);
|
|
715
985
|
if (m->mapped) munmap(m->mapped, m->mapped_size);
|
|
716
986
|
bpe_merge_table_free(&m->merges);
|
|
717
|
-
if (m->pre_patterns) {
|
|
718
|
-
for (int i = 0; i < m->num_pre_patterns; i++) free(m->pre_patterns[i].pattern);
|
|
719
|
-
free(m->pre_patterns);
|
|
720
|
-
}
|
|
721
987
|
free(m);
|
|
722
988
|
}
|
|
723
989
|
|
|
@@ -741,35 +1007,29 @@ static uint8_t *find_tensor_info_start(uint8_t *cur, uint8_t *end) {
|
|
|
741
1007
|
|
|
742
1008
|
/* ------------------------------------------------------------------------- */
|
|
743
1009
|
/*
 * Guess which prefix the vocabulary uses to mark a leading space.
 *
 * Scans at most the first 5000 tokens and counts how many start with each
 * candidate. The most frequent candidate (earliest wins on a tie) is stored
 * in m->space_marker / m->space_marker_len, but only if it was seen more than
 * 10 times; otherwise the model fields are left untouched.
 */
static void detect_space_marker(EmbedModel *m) {
    struct candidate {
        const char *text;
        int len;
        int hits;
    };
    struct candidate cands[4] = {
        {"▁", 3, 0},
        {"Ġ", 2, 0},
        {"ĉ", 2, 0},
        {" ", 1, 0},
    };

    const int sample = (m->vocab_size < 5000) ? m->vocab_size : 5000;
    for (int t = 0; t < sample; t++) {
        const char *tok = m->tokens[t];

        /* The three multi-byte markers are plain prefix matches. */
        for (int c = 0; c < 3; c++) {
            if (strncmp(tok, cands[c].text, cands[c].len) == 0) {
                cands[c].hits++;
            }
        }
        /* A literal space only counts when something follows it. */
        if (tok[0] == ' ' && tok[1] != '\0') {
            cands[3].hits++;
        }
    }

    const struct candidate *winner = &cands[0];
    for (int c = 1; c < 4; c++) {
        if (cands[c].hits > winner->hits) {
            winner = &cands[c];
        }
    }

    if (winner->hits > 10) {
        strcpy(m->space_marker, winner->text);
        m->space_marker_len = winner->len;
    }
}
|
|
775
1035
|
|
|
@@ -777,10 +1037,10 @@ static void parse_merge(const char *merge_str, char **left, char **right) {
|
|
|
777
1037
|
const char *space = strchr(merge_str, ' ');
|
|
778
1038
|
if (space) {
|
|
779
1039
|
int left_len = space - merge_str;
|
|
780
|
-
*left = malloc(left_len+1);
|
|
1040
|
+
*left = malloc(left_len + 1);
|
|
781
1041
|
memcpy(*left, merge_str, left_len);
|
|
782
1042
|
(*left)[left_len] = '\0';
|
|
783
|
-
*right = strdup(space+1);
|
|
1043
|
+
*right = strdup(space + 1);
|
|
784
1044
|
} else {
|
|
785
1045
|
*left = strdup(merge_str);
|
|
786
1046
|
*right = strdup("");
|
|
@@ -793,12 +1053,15 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
793
1053
|
uint8_t *base = map_file(path, &sz);
|
|
794
1054
|
if (!base) return NULL;
|
|
795
1055
|
uint8_t *cur = base, *end = base + sz;
|
|
796
|
-
if (memcmp(cur, "GGUF", 4) != 0) { munmap(base, sz); return NULL; }
|
|
1056
|
+
if (sz < 4 || memcmp(cur, "GGUF", 4) != 0) { munmap(base, sz); return NULL; }
|
|
797
1057
|
cur += 4;
|
|
798
1058
|
uint32_t version = rd32(&cur, end);
|
|
1059
|
+
(void)version;
|
|
799
1060
|
uint64_t n_tensors = rd64(&cur, end);
|
|
800
1061
|
uint64_t n_kv = rd64(&cur, end);
|
|
801
1062
|
|
|
1063
|
+
if (n_kv > 1000000 || n_tensors > 1000000) { munmap(base, sz); return NULL; }
|
|
1064
|
+
|
|
802
1065
|
EmbedModel *m = calloc(1, sizeof(*m));
|
|
803
1066
|
if (!m) { munmap(base, sz); return NULL; }
|
|
804
1067
|
m->mapped = base;
|
|
@@ -806,22 +1069,22 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
806
1069
|
m->table = calloc(HASH_SIZE, sizeof(HashNode*));
|
|
807
1070
|
if (!m->table) { free_model_contents(m); return NULL; }
|
|
808
1071
|
bpe_merge_table_init(&m->merges);
|
|
809
|
-
setup_default_pre_patterns(m);
|
|
810
1072
|
m->unknown_token_id = -1;
|
|
811
1073
|
m->bos_token_id = -1;
|
|
812
1074
|
m->eos_token_id = -1;
|
|
813
1075
|
m->vocab_type = LLAMA_VOCAB_TYPE_NONE;
|
|
814
|
-
m->
|
|
1076
|
+
m->normalize = NORM_NONE;
|
|
815
1077
|
|
|
816
1078
|
int vocab_found = 0;
|
|
817
1079
|
for (uint64_t i = 0; i < n_kv; i++) {
|
|
818
1080
|
char *key = rdstr(&cur, end);
|
|
819
1081
|
if (!key) { free_model_contents(m); return NULL; }
|
|
820
1082
|
uint32_t type = rd32(&cur, end);
|
|
1083
|
+
|
|
821
1084
|
if ((strcmp(key, "tokenizer.ggml.tokens") == 0 || strcmp(key, "tokenizer.ggml.token_list") == 0) && type == 9) {
|
|
822
1085
|
uint32_t subtype = rd32(&cur, end);
|
|
823
1086
|
uint64_t n = rd64(&cur, end);
|
|
824
|
-
if (subtype != 8) { free(key); free_model_contents(m); return NULL; }
|
|
1087
|
+
if (subtype != 8 || n > 1000000) { free(key); free_model_contents(m); return NULL; }
|
|
825
1088
|
m->tokens = malloc(sizeof(char*) * n);
|
|
826
1089
|
if (!m->tokens) { free(key); free_model_contents(m); return NULL; }
|
|
827
1090
|
m->vocab_size = (int)n;
|
|
@@ -841,8 +1104,17 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
841
1104
|
if (merge_str) {
|
|
842
1105
|
char *left, *right;
|
|
843
1106
|
parse_merge(merge_str, &left, &right);
|
|
844
|
-
bpe_merge_table_add(&m->merges, left, right,
|
|
845
|
-
free(left);
|
|
1107
|
+
bpe_merge_table_add(&m->merges, left, right, (int)j);
|
|
1108
|
+
free(left);
|
|
1109
|
+
free(right);
|
|
1110
|
+
free(merge_str);
|
|
1111
|
+
} else {
|
|
1112
|
+
break;
|
|
1113
|
+
}
|
|
1114
|
+
}
|
|
1115
|
+
if (n > MAX_MERGES) {
|
|
1116
|
+
for (uint64_t j = MAX_MERGES; j < n; j++) {
|
|
1117
|
+
char *merge_str = rdstr(&cur, end);
|
|
846
1118
|
free(merge_str);
|
|
847
1119
|
}
|
|
848
1120
|
}
|
|
@@ -852,24 +1124,32 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
852
1124
|
} else if (strcmp(key, "tokenizer.ggml.model") == 0 && type == 8) {
|
|
853
1125
|
char *model_type = rdstr(&cur, end);
|
|
854
1126
|
if (model_type) {
|
|
855
|
-
if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0
|
|
856
|
-
|
|
1127
|
+
if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0 ||
|
|
1128
|
+
strcmp(model_type, "phi") == 0 || strcmp(model_type, "qwen") == 0)
|
|
1129
|
+
m->vocab_type = LLAMA_VOCAB_TYPE_BPE;
|
|
1130
|
+
else if (strcmp(model_type, "bert") == 0)
|
|
1131
|
+
m->vocab_type = LLAMA_VOCAB_TYPE_WPM;
|
|
1132
|
+
else if (strcmp(model_type, "spm") == 0)
|
|
1133
|
+
m->vocab_type = LLAMA_VOCAB_TYPE_SPM;
|
|
857
1134
|
free(model_type);
|
|
858
1135
|
}
|
|
859
1136
|
} else if (strcmp(key, "tokenizer.ggml.pre") == 0 && type == 8) {
|
|
860
1137
|
char *pre = rdstr(&cur, end);
|
|
861
|
-
|
|
1138
|
+
free(pre);
|
|
862
1139
|
} else if (strcmp(key, "tokenizer.ggml.unknown_token_id") == 0 && type == 6) {
|
|
863
|
-
m->unknown_token_id = rd32(&cur, end);
|
|
1140
|
+
m->unknown_token_id = (int)rd32(&cur, end);
|
|
864
1141
|
} else if (strcmp(key, "tokenizer.ggml.bos_token_id") == 0 && type == 6) {
|
|
865
|
-
m->bos_token_id = rd32(&cur, end);
|
|
1142
|
+
m->bos_token_id = (int)rd32(&cur, end);
|
|
866
1143
|
} else if (strcmp(key, "tokenizer.ggml.eos_token_id") == 0 && type == 6) {
|
|
867
|
-
m->eos_token_id = rd32(&cur, end);
|
|
1144
|
+
m->eos_token_id = (int)rd32(&cur, end);
|
|
1145
|
+
} else if (strcmp(key, "general.alignment") == 0 && type == 6) {
|
|
1146
|
+
rd32(&cur, end);
|
|
868
1147
|
} else {
|
|
869
1148
|
if (!skip_value(&cur, end, type)) { free(key); free_model_contents(m); return NULL; }
|
|
870
1149
|
}
|
|
871
1150
|
free(key);
|
|
872
1151
|
}
|
|
1152
|
+
|
|
873
1153
|
if (!vocab_found) { free_model_contents(m); return NULL; }
|
|
874
1154
|
detect_space_marker(m);
|
|
875
1155
|
|
|
@@ -877,10 +1157,6 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
877
1157
|
align_to_32(&cur, end, base);
|
|
878
1158
|
uint8_t *tensor_start = cur;
|
|
879
1159
|
int embd_found = 0;
|
|
880
|
-
void *raw_tensor_data = NULL;
|
|
881
|
-
int tensor_type = -1;
|
|
882
|
-
uint64_t dim0 = 0, dim1 = 0;
|
|
883
|
-
int need_transpose = 0;
|
|
884
1160
|
|
|
885
1161
|
for (int attempt = 0; attempt < 2; attempt++) {
|
|
886
1162
|
cur = tensor_start;
|
|
@@ -892,21 +1168,58 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
892
1168
|
for (uint32_t d = 0; d < n_dims && d < MAX_DIMS; d++) dims[d] = rd64(&cur, end);
|
|
893
1169
|
uint32_t type = rd32(&cur, end);
|
|
894
1170
|
uint64_t offset = rd64(&cur, end);
|
|
1171
|
+
|
|
895
1172
|
int is_token_embd = (strcmp(name, "token_embd.weight") == 0 ||
|
|
896
1173
|
strcmp(name, "embeddings.word_embeddings.weight") == 0 ||
|
|
897
1174
|
strcmp(name, "model.embed_tokens.weight") == 0);
|
|
1175
|
+
|
|
898
1176
|
if (!is_token_embd && n_dims == 2 && m->vocab_size > 0) {
|
|
899
1177
|
if ((uint64_t)m->vocab_size == dims[0] && strstr(name, "embd")) is_token_embd = 1;
|
|
900
1178
|
else if ((uint64_t)m->vocab_size == dims[1] && strstr(name, "embd")) is_token_embd = 1;
|
|
901
1179
|
}
|
|
1180
|
+
|
|
902
1181
|
if (!embd_found && is_token_embd) {
|
|
903
|
-
if (n_dims < 2 || dims[1] == 0) {
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
1182
|
+
if (n_dims < 2 || dims[1] == 0) {
|
|
1183
|
+
free(name); free_model_contents(m); return NULL;
|
|
1184
|
+
}
|
|
1185
|
+
|
|
1186
|
+
uint64_t ne0 = dims[0];
|
|
1187
|
+
uint64_t ne1 = dims[1];
|
|
1188
|
+
|
|
1189
|
+
int need_transpose = 0;
|
|
1190
|
+
int dim;
|
|
1191
|
+
|
|
1192
|
+
if (ne1 == (uint64_t)m->vocab_size) {
|
|
1193
|
+
dim = (int)ne0;
|
|
1194
|
+
need_transpose = 0;
|
|
1195
|
+
} else if (ne0 == (uint64_t)m->vocab_size) {
|
|
1196
|
+
dim = (int)ne1;
|
|
1197
|
+
need_transpose = 1;
|
|
1198
|
+
} else {
|
|
1199
|
+
dim = (ne0 < ne1) ? (int)ne0 : (int)ne1;
|
|
1200
|
+
need_transpose = (ne0 > ne1) ? 1 : 0;
|
|
1201
|
+
}
|
|
1202
|
+
|
|
1203
|
+
if (dim <= 0 || dim > MAX_DIM) {
|
|
1204
|
+
free(name); free_model_contents(m); return NULL;
|
|
1205
|
+
}
|
|
1206
|
+
|
|
1207
|
+
size_t row_bytes = get_row_bytes(type, (int)(need_transpose ? ne1 : ne0));
|
|
1208
|
+
size_t total_size = (size_t)(need_transpose ? ne1 : ne0) * row_bytes;
|
|
1209
|
+
|
|
1210
|
+
if (offset >= sz || offset + total_size > sz) {
|
|
1211
|
+
free(name);
|
|
1212
|
+
free_model_contents(m);
|
|
1213
|
+
return NULL;
|
|
1214
|
+
}
|
|
1215
|
+
|
|
1216
|
+
m->dim = dim;
|
|
1217
|
+
m->raw_dim0 = ne0;
|
|
1218
|
+
m->raw_dim1 = ne1;
|
|
1219
|
+
m->need_transpose = need_transpose;
|
|
1220
|
+
m->raw_tensor_data = base + offset;
|
|
1221
|
+
m->tensor_type = type;
|
|
1222
|
+
m->row_bytes = row_bytes;
|
|
910
1223
|
embd_found = 1;
|
|
911
1224
|
free(name);
|
|
912
1225
|
break;
|
|
@@ -919,97 +1232,105 @@ static EmbedModel *embed_load_gguf(const char *path) {
|
|
|
919
1232
|
if (!tensor_start) break;
|
|
920
1233
|
}
|
|
921
1234
|
}
|
|
922
|
-
if (!embd_found || m->dim == 0) { free_model_contents(m); return NULL; }
|
|
923
1235
|
|
|
924
|
-
if (
|
|
925
|
-
m
|
|
926
|
-
m->tensor_data = raw_tensor_data;
|
|
927
|
-
} else {
|
|
928
|
-
int n_rows = need_transpose ? (int)dim1 : (int)dim0;
|
|
929
|
-
int n_cols = need_transpose ? (int)dim0 : (int)dim1;
|
|
930
|
-
m->float_data = dequantize_tensor(raw_tensor_data, tensor_type, n_rows, n_cols);
|
|
931
|
-
if (!m->float_data) { free_model_contents(m); return NULL; }
|
|
932
|
-
m->tensor_data = m->float_data;
|
|
1236
|
+
if (!embd_found || m->dim == 0) {
|
|
1237
|
+
free_model_contents(m); return NULL;
|
|
933
1238
|
}
|
|
934
|
-
m->tensor_type = tensor_type;
|
|
935
1239
|
|
|
936
1240
|
return m;
|
|
937
1241
|
}
|
|
938
1242
|
|
|
1243
|
+
/* ------------------------------------------------------------------------- */
|
|
1244
|
+
// L2 normalization
|
|
1245
|
+
static void normalize_l2(float *vec, int dim) {
|
|
1246
|
+
float sum = 0;
|
|
1247
|
+
for (int i = 0; i < dim; i++) sum += vec[i] * vec[i];
|
|
1248
|
+
float norm = sqrtf(sum);
|
|
1249
|
+
if (norm > 1e-8f) {
|
|
1250
|
+
float inv = 1.0f / norm;
|
|
1251
|
+
for (int i = 0; i < dim; i++) vec[i] *= inv;
|
|
1252
|
+
}
|
|
1253
|
+
}
|
|
1254
|
+
|
|
939
1255
|
/* ------------------------------------------------------------------------- */
|
|
940
1256
|
static void embed_text(EmbedModel *m, const char *txt, float *out) {
|
|
941
1257
|
memset(out, 0, sizeof(float) * m->dim);
|
|
1258
|
+
if (!txt || !*txt) return;
|
|
1259
|
+
|
|
942
1260
|
int num_words = 0;
|
|
943
|
-
char **words =
|
|
1261
|
+
char **words = pre_tokenize(txt, &num_words);
|
|
1262
|
+
|
|
944
1263
|
if (!words || num_words == 0) {
|
|
945
|
-
// Fallback to simple space split
|
|
946
|
-
char *copy = strdup(txt);
|
|
947
|
-
if (copy) {
|
|
948
|
-
char *tok = strtok(copy, " \t\n\r");
|
|
949
|
-
int used = 0;
|
|
950
|
-
const float *embd = (float*)m->tensor_data;
|
|
951
|
-
while (tok) {
|
|
952
|
-
int id = hget(m, tok);
|
|
953
|
-
if (id >= 0 && id < m->vocab_size) {
|
|
954
|
-
const float *vec = embd + id * m->dim;
|
|
955
|
-
for (int i = 0; i < m->dim; i++) out[i] += vec[i];
|
|
956
|
-
used++;
|
|
957
|
-
}
|
|
958
|
-
tok = strtok(NULL, " \t\n\r");
|
|
959
|
-
}
|
|
960
|
-
if (used) { float inv = 1.0f / used; for (int i = 0; i < m->dim; i++) out[i] *= inv; }
|
|
961
|
-
free(copy);
|
|
962
|
-
}
|
|
963
1264
|
if (words) free(words);
|
|
964
1265
|
return;
|
|
965
1266
|
}
|
|
966
1267
|
|
|
967
1268
|
int *token_ids = malloc(m->vocab_size * sizeof(int));
|
|
1269
|
+
if (!token_ids) {
|
|
1270
|
+
for (int i = 0; i < num_words; i++) free(words[i]);
|
|
1271
|
+
free(words);
|
|
1272
|
+
return;
|
|
1273
|
+
}
|
|
1274
|
+
|
|
968
1275
|
int used = 0;
|
|
969
|
-
|
|
1276
|
+
float *temp_vec = malloc(m->dim * sizeof(float));
|
|
1277
|
+
|
|
970
1278
|
for (int i = 0; i < num_words; i++) {
|
|
971
1279
|
char *word = words[i];
|
|
972
1280
|
int id = hget(m, word);
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
1281
|
+
|
|
1282
|
+
if (id == -1 && m->space_marker_len > 0) {
|
|
1283
|
+
size_t with_marker_len = m->space_marker_len + strlen(word);
|
|
1284
|
+
char *with_marker = malloc(with_marker_len + 1);
|
|
1285
|
+
if (with_marker) {
|
|
1286
|
+
memcpy(with_marker, m->space_marker, m->space_marker_len);
|
|
1287
|
+
strcpy(with_marker + m->space_marker_len, word);
|
|
1288
|
+
id = hget(m, with_marker);
|
|
1289
|
+
free(with_marker);
|
|
1290
|
+
}
|
|
979
1291
|
}
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
1292
|
+
|
|
1293
|
+
if (id != -1 && id >= 0 && id < m->vocab_size) {
|
|
1294
|
+
dequantize_row_lazy(m, id, temp_vec);
|
|
1295
|
+
for (int j = 0; j < m->dim; j++) out[j] += temp_vec[j];
|
|
983
1296
|
used++;
|
|
984
1297
|
} else {
|
|
985
1298
|
int num_tokens = 0;
|
|
986
|
-
bpe_tokenize_word(&m->merges, word,
|
|
1299
|
+
bpe_tokenize_word(&m->merges, word, m, token_ids, &num_tokens);
|
|
987
1300
|
for (int k = 0; k < num_tokens; k++) {
|
|
988
1301
|
int tid = token_ids[k];
|
|
989
1302
|
if (tid >= 0 && tid < m->vocab_size) {
|
|
990
|
-
|
|
991
|
-
for (int j = 0; j < m->dim; j++) out[j] +=
|
|
1303
|
+
dequantize_row_lazy(m, tid, temp_vec);
|
|
1304
|
+
for (int j = 0; j < m->dim; j++) out[j] += temp_vec[j];
|
|
992
1305
|
used++;
|
|
993
1306
|
} else if (m->unknown_token_id != -1 && m->unknown_token_id < m->vocab_size) {
|
|
994
|
-
|
|
995
|
-
for (int j = 0; j < m->dim; j++) out[j] +=
|
|
1307
|
+
dequantize_row_lazy(m, m->unknown_token_id, temp_vec);
|
|
1308
|
+
for (int j = 0; j < m->dim; j++) out[j] += temp_vec[j];
|
|
996
1309
|
used++;
|
|
997
1310
|
}
|
|
998
1311
|
}
|
|
999
1312
|
}
|
|
1000
1313
|
free(word);
|
|
1001
1314
|
}
|
|
1315
|
+
|
|
1002
1316
|
free(words);
|
|
1003
1317
|
free(token_ids);
|
|
1318
|
+
free(temp_vec);
|
|
1319
|
+
|
|
1004
1320
|
if (used > 0) {
|
|
1005
1321
|
float inv = 1.0f / used;
|
|
1006
1322
|
for (int i = 0; i < m->dim; i++) out[i] *= inv;
|
|
1007
1323
|
}
|
|
1324
|
+
|
|
1008
1325
|
for (int i = 0; i < m->dim; i++) {
|
|
1009
1326
|
if (isnan(out[i]) || isinf(out[i])) {
|
|
1010
1327
|
out[i] = 0.0f;
|
|
1011
1328
|
}
|
|
1012
1329
|
}
|
|
1330
|
+
|
|
1331
|
+
if (m->normalize == NORM_L2) {
|
|
1332
|
+
normalize_l2(out, m->dim);
|
|
1333
|
+
}
|
|
1013
1334
|
}
|
|
1014
1335
|
|
|
1015
1336
|
/* ------------------------------------------------------------------------- */
|
|
@@ -1020,7 +1341,14 @@ static void rb_embedder_free(void *p) {
|
|
|
1020
1341
|
}
|
|
1021
1342
|
|
|
1022
1343
|
static size_t rb_embedder_memsize(const void *p) {
|
|
1023
|
-
|
|
1344
|
+
const ruby_embedder *e = p;
|
|
1345
|
+
size_t sz = sizeof(ruby_embedder);
|
|
1346
|
+
if (e && e->model) {
|
|
1347
|
+
sz += e->model->vocab_size * sizeof(char*);
|
|
1348
|
+
sz += e->model->mapped_size;
|
|
1349
|
+
sz += HASH_SIZE * sizeof(HashNode*);
|
|
1350
|
+
}
|
|
1351
|
+
return sz;
|
|
1024
1352
|
}
|
|
1025
1353
|
|
|
1026
1354
|
static const rb_data_type_t ruby_embedder_type = {
|
|
@@ -1037,18 +1365,44 @@ static VALUE rb_embedder_alloc(VALUE klass) {
|
|
|
1037
1365
|
/*
 * Embedder#initialize(opts)
 *
 * opts must be a Hash:
 *   :model     (required) path to a GGUF file
 *   :normalize (optional) :l2 / :L2 or a String equal to "l2" ignoring case
 *              enables L2 normalization; any other value leaves it off
 *
 * Raises TypeError when opts is not a Hash, ArgumentError when :model is
 * missing, and RuntimeError when the model cannot be loaded.
 *
 * The new model is loaded before any previously attached one is released, so
 * calling initialize again neither leaks the old model nor leaves the object
 * without a model if the reload fails.
 */
static VALUE rb_embedder_initialize(VALUE self, VALUE opts) {
    ruby_embedder *e;
    TypedData_Get_Struct(self, ruby_embedder, &ruby_embedder_type, e);

    Check_Type(opts, T_HASH);
    VALUE path = rb_hash_aref(opts, ID2SYM(rb_intern("model")));
    if (NIL_P(path)) rb_raise(rb_eArgError, "missing required key: model");
    const char *cpath = StringValueCStr(path);

    VALUE normalize = rb_hash_aref(opts, ID2SYM(rb_intern("normalize")));
    int norm_type = NORM_NONE;
    if (SYMBOL_P(normalize)) {
        ID sym_id = SYM2ID(normalize);
        if (sym_id == rb_intern("l2") || sym_id == rb_intern("L2")) {
            norm_type = NORM_L2;
        }
    } else if (TYPE(normalize) == T_STRING) {
        const char *norm_str = StringValueCStr(normalize);
        if (strcasecmp(norm_str, "l2") == 0) {
            norm_type = NORM_L2;
        }
    }

    /* Everything that can raise has run; nothing is allocated before here. */
    EmbedModel *loaded = embed_load_gguf(cpath);
    if (!loaded) rb_raise(rb_eRuntimeError, "failed to load GGUF model: %s", cpath);
    loaded->normalize = norm_type;

    /* Release a model left over from an earlier initialize call. */
    if (e->model) free_model_contents(e->model);
    e->model = loaded;
    return self;
}
|
|
1046
1396
|
|
|
1047
1397
|
static VALUE rb_embed(VALUE self, VALUE opts) {
|
|
1048
1398
|
ruby_embedder *e;
|
|
1049
1399
|
TypedData_Get_Struct(self, ruby_embedder, &ruby_embedder_type, e);
|
|
1400
|
+
|
|
1401
|
+
Check_Type(opts, T_HASH);
|
|
1050
1402
|
VALUE text = rb_hash_aref(opts, ID2SYM(rb_intern("text")));
|
|
1403
|
+
if (NIL_P(text)) rb_raise(rb_eArgError, "missing required key: text");
|
|
1051
1404
|
const char *ctext = StringValueCStr(text);
|
|
1405
|
+
|
|
1052
1406
|
VALUE out = rb_str_new(NULL, e->model->dim * sizeof(float));
|
|
1053
1407
|
embed_text(e->model, ctext, (float*)RSTRING_PTR(out));
|
|
1054
1408
|
return out;
|