mini_embed 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (4)
  1. checksums.yaml +4 -4
  2. data/README.md +96 -37
  3. data/ext/mini_embed/mini_embed.c +542 -15
  4. metadata +1 -1
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: b44c5e93e9fc010a7c97e41f4fe4eaef3a5ee1033d1f8348c1d7a3f8d01d7d39
- data.tar.gz: 61ca897bcf84b44822a15bc2ba9e37299567a797b26fd61ba987a4b7645c45b4
+ metadata.gz: fd4a9fa127d0882eef7443594736c7ed633bf4728f75cfcf49a2987a515b3e8e
+ data.tar.gz: 632a8f4cdd9f2a218dc025b47e6f4c19c03ee81db67f5f50c8184cf14809b05e
  SHA512:
- metadata.gz: ab91e23951d0ef745d37a41467f778de5dcc7bedf21445292d31a1b6286add8cd851ebed6520756689d56226395bac3ce20a4f4bac0066cff8fc8488f26dcff6
- data.tar.gz: 75c2885a6a63dfbc12db9bd80848f4d9526b7e37a23697ae227d1099a29ff5f6a3ea264378e5b66ce8a0b56935ddaa27de513594dfb530d5702652a09c391c27
+ metadata.gz: 1a3bf50d26e8d53a560e97f1b3125797b1f6de94c773a344d1acb96ab1ef6a8b7a731707f104e4d9cdd856e3df6e1579b9510dafc959d890c6623441d470fa95
+ data.tar.gz: ac6f937aafff0dd9dc93193ac85eae4293eb0fa51dbb56897e4ad25c60cf784b9c7896032dc48607b4d702bf81cbb0524b8f3dddbc90d190c2eacdee63200dfb
data/README.md CHANGED
@@ -1,74 +1,133 @@
  # mini_embed
 
- Fast, minimal GGUF embedding extractor for Ruby.
+ A minimal, dependency‑free C extension for Ruby that loads [GGUF](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) embedding models and computes text embeddings **locally**.
+
+ **⚠️ Important:** This gem is intended for **small projects, prototypes, and hobbyist use**. It allows you to experiment with embeddings without relying on external APIs or cloud costs. **Do not use MiniEmbed in production** – it lacks the performance, scalability, and tokenization robustness of dedicated solutions. For real applications, use a proper inference server like [llama.cpp](https://github.com/ggerganov/llama.cpp) with its HTTP API, or managed services such as OpenAI, Cohere, or Hugging Face.
+
+ ---
+
+ ## Why MiniEmbed?
+
+ - **Zero external dependencies** – no TensorFlow, PyTorch, or ONNX runtime.
+ - **Single‑file C extension** – fast loading and mean‑pooled embeddings.
+ - **Supports all common GGUF quantizations** – from `F32` to `Q2_K`.
+ - **Works entirely offline** – your data never leaves your machine.
+ - Perfect for **weekend projects**, **proof‑of‑concepts**, or **learning** about embeddings.
+
+ ---
 
  ## Installation
 
- Add to your Gemfile:
+ Add this line to your application's `Gemfile`:
 
  ```ruby
  gem 'mini_embed'
  ```
- Or install globally:
 
- ```sh
+ Then execute:
+
+ ```bash
+ bundle install
+ ```
+ Or install it globally:
+
+ ```bash
  gem install mini_embed
  ```
 
- Usage
+
+ ## Requirements
+
+ - A POSIX system (Linux, macOS, BSD) – Windows via WSL2 works.
+ - A C compiler and `make` (for compiling the native extension).
+ - A GGUF embedding model file (see "Where to get models" below).
+
+ ## Usage
+
  ```ruby
  require 'mini_embed'
 
- model = MiniEmbed.new(model: 'path/to/model.gguf')
- embeddings_bin = model.embeddings(text: "hello world") # => binary ouput
- embeddings_array = embeddings_bin.unpack('f*') # => array of float
- puts embeddings_array.size # => model dimension
- ```
+ # Load a GGUF model (F32, F16, Q8_0, Q4_K, etc. are all supported)
+ model = MiniEmbed.new(model: '/path/to/gte-small.Q8_0.gguf')
 
- Supported Quantizations
+ # Get the raw binary string (little‑endian 32‑bit floats)
+ binary = model.embeddings(text: 'hello world')
 
+ # Get an embedding as an array of floats
+ embedding = binary.unpack('e*')
+ puts embedding.size # e.g. 384
+ puts embedding[0..4] # e.g. [0.0123, -0.0456, ...]
  ```
- F32, F16
 
- Q4_0, Q4_1
+ ## Simple tokenization note
+ MiniEmbed uses a naive space‑based tokenizer. This means it splits input on spaces and looks up each token exactly in the model's vocabulary. For models trained with subword tokenization (like BERT), this will not work for out‑of‑vocabulary words.
+ If you need proper subword tokenization, you can:
 
- Q5_0, Q5_1
+ - Pre‑tokenize in Ruby using the `tokenizers` gem and pass token IDs (not yet exposed in the C API, but easy to add).
+ - Stick to simple vocabulary words that exist in the model (e.g., "text", "hello", "dog").
 
- Q8_0, Q8_1
+ ## Supported Quantization Types
 
- Q2_K, Q3_K, Q4_K, Q5_K, Q6_K, Q8_K
- ```
+ | Type | Description   |
+ |------|---------------|
+ | 0    | F32 (float32) |
+ | 1    | F16 (float16) |
+ | 2    | Q4_0          |
+ | 3    | Q4_1          |
+ | 6    | Q5_0          |
+ | 7    | Q5_1          |
+ | 8    | Q8_0          |
+ | 9    | Q8_1          |
+ | 10   | Q2_K          |
+ | 11   | Q3_K          |
+ | 12   | Q4_K          |
+ | 13   | Q5_K          |
+ | 14   | Q6_K          |
+ | 15   | Q8_K          |
 
- ## Building the Gem
+ The extension automatically dequantizes the embedding matrix on load, so inference speed is always that of a plain float32 lookup.
 
- From the `mini_embed/` directory:
+ ## Where to get models
+ Hugging Face offers many GGUF models, e.g.:
 
- ```bash
- bundle install
- bundle exec rake compile
- ```
+ - `gte-small`
+ - `all-MiniLM-L6-v2`
 
+ You can convert any safetensors or PyTorch model using the `convert-hf-to-gguf.py` script from llama.cpp.
 
- To build the gem file:
+ For testing, we recommend the `gte-small` model (384 dimensions, ~30k vocabulary).
 
- ```bash
- gem build mini_embed.gemspec
- ```
+ ## Limitations (Why this is not production‑ready)
+
+ - Single‑threaded, blocking C code – embedding computation runs on the Ruby thread, freezing the interpreter.
+ - No batching – only one text at a time.
+ - Space‑based tokenization only – works only for words present exactly in the vocabulary.
+ - Loads the entire embedding matrix into RAM – for large vocabularies this may consume significant memory.
+ - No GPU support – CPU only.
+ - Error handling is minimal – invalid models may crash the Ruby process.
+
+ If you need a robust, scalable solution, consider:
 
- To install locally:
+ - Running llama.cpp as a server (`./server -m model.gguf --embeddings`) and calling its HTTP endpoint.
+ - Using a cloud embeddings API (OpenAI, Cohere, VoyageAI, etc.).
+ - Deploying a dedicated inference service with BentoML or Ray Serve.
+
+ ## Development & Contributing
+ Bug reports and pull requests are welcome on GitHub.
+ To run the tests:
 
  ```bash
- gem install ./mini_embed-0.1.0.gem
+ bundle exec rspec
  ```
- Using in a Rails project
- Add to Gemfile:
 
- ```ruby
- gem 'mini_embed', path: '/path/to/mini_embed'
- ```
+ The gem uses `rake-compiler` to build the extension. After making changes to the C source, run:
 
- Then `bundle install` and use as above.
+ ```bash
+ bundle exec rake compile
+ ```
 
  ## License
-
- MIT License. See [LICENSE](LICENSE).
+ MIT License. See [LICENSE](LICENSE).
data/ext/mini_embed/mini_embed.c CHANGED
@@ -13,6 +13,8 @@
  #define HASH_SIZE 131071
  #define MAX_DIMS 4
  #define GGUF_ALIGN 32
+ #define MAX_MERGES 10000
+ #define MAX_REGEX 256
 
  enum ggml_type {
      GGML_TYPE_F32 = 0,
@@ -31,6 +33,370 @@ enum ggml_type {
      GGML_TYPE_Q8_K = 15,
  };
 
+ enum llama_vocab_type {
+     LLAMA_VOCAB_TYPE_NONE = 0,
+     LLAMA_VOCAB_TYPE_SPM = 1,
+     LLAMA_VOCAB_TYPE_BPE = 2,
+     LLAMA_VOCAB_TYPE_WPM = 3,
+ };
+
+ /* ------------------------------------------------------------------------- */
+ // Unicode helper functions (adapted from llama.cpp)
+ static int unicode_len_utf8(char c) {
+     if ((c & 0x80) == 0) return 1;
+     if ((c & 0xE0) == 0xC0) return 2;
+     if ((c & 0xF0) == 0xE0) return 3;
+     if ((c & 0xF8) == 0xF0) return 4;
+     return 1; // fallback
+ }
+
+ static int unicode_is_letter(uint32_t cp) {
+     // Basic Unicode letter detection (simplified)
+     return (cp >= 0x41 && cp <= 0x5A) || (cp >= 0x61 && cp <= 0x7A) ||
+            (cp >= 0xC0 && cp <= 0xD6) || (cp >= 0xD8 && cp <= 0xF6) ||
+            (cp >= 0xF8 && cp <= 0x2FF) || (cp >= 0x370 && cp <= 0x37D) ||
+            (cp >= 0x37F && cp <= 0x1FFF) || (cp >= 0x200C && cp <= 0x200D) ||
+            (cp >= 0x2070 && cp <= 0x218F) || (cp >= 0x2C00 && cp <= 0x2FEF) ||
+            (cp >= 0x3001 && cp <= 0xD7FF) || (cp >= 0xF900 && cp <= 0xFDCF) ||
+            (cp >= 0xFDF0 && cp <= 0xFFFD);
+ }
+
+ static int unicode_is_number(uint32_t cp) {
+     return (cp >= 0x30 && cp <= 0x39) || (cp >= 0x660 && cp <= 0x669) ||
+            (cp >= 0x6F0 && cp <= 0x6F9) || (cp >= 0x7C0 && cp <= 0x7C9) ||
+            (cp >= 0x966 && cp <= 0x96F);
+ }
+
+ static uint32_t unicode_cpt_from_utf8(const char *s, size_t *len) {
+     uint32_t cp = 0;
+     unsigned char c = (unsigned char)s[0];
+
+     if (c < 0x80) {
+         *len = 1;
+         return c;
+     } else if ((c & 0xE0) == 0xC0) {
+         *len = 2;
+         cp = (c & 0x1F) << 6;
+         cp |= (s[1] & 0x3F);
+         return cp;
+     } else if ((c & 0xF0) == 0xE0) {
+         *len = 3;
+         cp = (c & 0x0F) << 12;
+         cp |= (s[1] & 0x3F) << 6;
+         cp |= (s[2] & 0x3F);
+         return cp;
+     } else if ((c & 0xF8) == 0xF0) {
+         *len = 4;
+         cp = (c & 0x07) << 18;
+         cp |= (s[1] & 0x3F) << 12;
+         cp |= (s[2] & 0x3F) << 6;
+         cp |= (s[3] & 0x3F);
+         return cp;
+     }
+
+     *len = 1;
+     return c;
+ }
+
+ /* ------------------------------------------------------------------------- */
+ // Simple regex pattern matcher for pre-tokenization
+ typedef struct {
+     char *pattern;
+     int pattern_len;
+ } RegexPattern;
+
+ static int match_regex(const char *text, const RegexPattern *patterns, int num_patterns) {
+     // Simplified implementation for common BPE patterns
+     // Full regex engine would be complex; this handles the most common cases
+
+     for (int i = 0; i < num_patterns; i++) {
+         const char *p = patterns[i].pattern;
+         int plen = patterns[i].pattern_len;
+
+         // Check for common patterns
+         if (strstr(p, "\\p{L}")) {
+             // Match Unicode letter
+             size_t len;
+             uint32_t cp = unicode_cpt_from_utf8(text, &len);
+             if (unicode_is_letter(cp)) return 1;
+         } else if (strstr(p, "\\p{N}")) {
+             // Match Unicode number
+             size_t len;
+             uint32_t cp = unicode_cpt_from_utf8(text, &len);
+             if (unicode_is_number(cp)) return 1;
+         } else if (p[0] == '\\' && p[1] == 's') {
+             // Match whitespace
+             if (isspace(text[0])) return 1;
+         } else if (p[0] == '\\' && p[1] == 'r') {
+             if (text[0] == '\r') return 1;
+         } else if (p[0] == '\\' && p[1] == 'n') {
+             if (text[0] == '\n') return 1;
+         } else if (p[0] == '.' && p[1] == '*') {
+             // Match anything
+             return 1;
+         } else if (isalnum(p[0]) || ispunct(p[0])) {
+             // Match literal character
+             if (text[0] == p[0]) return 1;
+         }
+     }
+     return 0;
+ }
+
+ static char** unicode_regex_split(const char *text, const RegexPattern *patterns, int num_patterns, int *num_words) {
+     char **words = NULL;
+     int word_count = 0;
+     int word_capacity = 0;
+
+     size_t text_len = strlen(text);
+     size_t pos = 0;
+
+     while (pos < text_len) {
+         // Find the start of a word (character that matches any regex)
+         size_t start = pos;
+         while (start < text_len) {
+             if (match_regex(text + start, patterns, num_patterns)) {
+                 break;
+             }
+             start++;
+         }
+
+         if (start >= text_len) break;
+
+         // Find the end of the word (character that doesn't match any regex)
+         size_t end = start;
+         while (end < text_len) {
+             if (!match_regex(text + end, patterns, num_patterns)) {
+                 break;
+             }
+             end++;
+         }
+
+         if (end > start) {
+             // Extract the word
+             size_t word_len = end - start;
+             char *word = malloc(word_len + 1);
+             if (word) {
+                 memcpy(word, text + start, word_len);
+                 word[word_len] = '\0';
+
+                 // Add to array
+                 if (word_count >= word_capacity) {
+                     word_capacity = word_capacity == 0 ? 16 : word_capacity * 2;
+                     words = realloc(words, word_capacity * sizeof(char*));
+                     if (!words) {
+                         for (int i = 0; i < word_count; i++) free(words[i]);
+                         free(words);
+                         *num_words = 0;
+                         return NULL;
+                     }
+                 }
+                 words[word_count++] = word;
+             }
+         }
+
+         pos = end;
+     }
+
+     *num_words = word_count;
+     return words;
+ }
+
+ /* ------------------------------------------------------------------------- */
+ // BPE merge structure
+ typedef struct {
+     char *left;
+     char *right;
+     char *merged;
+     int rank;
+ } BPEMerge;
+
+ typedef struct {
+     BPEMerge *merges;
+     int num_merges;
+     int capacity;
+ } BPEMergeTable;
+
+ static void bpe_merge_table_init(BPEMergeTable *table) {
+     table->merges = NULL;
+     table->num_merges = 0;
+     table->capacity = 0;
+ }
+
+ static void bpe_merge_table_add(BPEMergeTable *table, const char *left, const char *right, const char *merged, int rank) {
+     if (table->num_merges >= table->capacity) {
+         table->capacity = table->capacity == 0 ? 100 : table->capacity * 2;
+         table->merges = realloc(table->merges, table->capacity * sizeof(BPEMerge));
+     }
+
+     BPEMerge *merge = &table->merges[table->num_merges++];
+     merge->left = strdup(left);
+     merge->right = strdup(right);
+     merge->merged = strdup(merged);
+     merge->rank = rank;
+ }
+
+ static void bpe_merge_table_free(BPEMergeTable *table) {
+     for (int i = 0; i < table->num_merges; i++) {
+         free(table->merges[i].left);
+         free(table->merges[i].right);
+         free(table->merges[i].merged);
+     }
+     free(table->merges);
+     table->merges = NULL;
+     table->num_merges = 0;
+ }
+
+ static int bpe_merge_rank(const BPEMergeTable *table, const char *left, const char *right) {
+     for (int i = 0; i < table->num_merges; i++) {
+         if (strcmp(table->merges[i].left, left) == 0 && strcmp(table->merges[i].right, right) == 0) {
+             return table->merges[i].rank;
+         }
+     }
+     return -1;
+ }
+
+ static char* bpe_merge(const BPEMergeTable *table, const char *left, const char *right) {
+     for (int i = 0; i < table->num_merges; i++) {
+         if (strcmp(table->merges[i].left, left) == 0 && strcmp(table->merges[i].right, right) == 0) {
+             return table->merges[i].merged;
+         }
+     }
+     return NULL;
+ }
+
+ /* ------------------------------------------------------------------------- */
+ // BPE tokenization helper structures
+ typedef struct {
+     char *text;
+     int start;
+     int end;
+     int prev;
+     int next;
+     int used;
+ } BPESymbol;
+
+ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int (*text_to_id)(void*, const char*), void *vocab_data, int *token_ids, int *num_tokens) {
+     // Initialize symbols from characters
+     int word_len = strlen(word);
+     int num_symbols = 0;
+     BPESymbol *symbols = malloc(word_len * sizeof(BPESymbol));
+
+     // Split into UTF-8 characters
+     int offset = 0;
+     while (offset < word_len) {
+         int char_len = unicode_len_utf8(word[offset]);
+         symbols[num_symbols].text = (char*)word + offset;
+         symbols[num_symbols].start = offset;
+         symbols[num_symbols].end = offset + char_len;
+         symbols[num_symbols].prev = num_symbols - 1;
+         symbols[num_symbols].next = num_symbols + 1;
+         symbols[num_symbols].used = 1;
+         offset += char_len;
+         num_symbols++;
+     }
+
+     if (num_symbols <= 1) {
+         // Single character, just tokenize it
+         int id = text_to_id(vocab_data, word);
+         if (id != -1) {
+             token_ids[*num_tokens] = id;
+             (*num_tokens)++;
+         }
+         free(symbols);
+         return;
+     }
+
+     // Build priority queue for merges (simplified)
+     typedef struct {
+         int left;
+         int right;
+         int rank;
+     } Bigram;
+
+     Bigram *bigrams = malloc(word_len * word_len * sizeof(Bigram));
+     int num_bigrams = 0;
+
+     // Initialize bigrams
+     for (int i = 0; i < num_symbols - 1; i++) {
+         if (symbols[i].used && symbols[i+1].used) {
+             // Get the concatenated string for this pair
+             char *left_str = malloc(symbols[i].end - symbols[i].start + 1);
+             char *right_str = malloc(symbols[i+1].end - symbols[i+1].start + 1);
+             memcpy(left_str, symbols[i].text, symbols[i].end - symbols[i].start);
+             memcpy(right_str, symbols[i+1].text, symbols[i+1].end - symbols[i+1].start);
+             left_str[symbols[i].end - symbols[i].start] = '\0';
+             right_str[symbols[i+1].end - symbols[i+1].start] = '\0';
+
+             int rank = bpe_merge_rank(merges, left_str, right_str);
+             if (rank != -1) {
+                 bigrams[num_bigrams].left = i;
+                 bigrams[num_bigrams].right = i+1;
+                 bigrams[num_bigrams].rank = rank;
+                 num_bigrams++;
+             }
+
+             free(left_str);
+             free(right_str);
+         }
+     }
+
+     // Sort bigrams by rank (lower rank = higher priority)
+     for (int i = 0; i < num_bigrams - 1; i++) {
+         for (int j = i+1; j < num_bigrams; j++) {
+             if (bigrams[i].rank > bigrams[j].rank) {
+                 Bigram temp = bigrams[i];
+                 bigrams[i] = bigrams[j];
+                 bigrams[j] = temp;
+             }
+         }
+     }
+
+     // Apply merges
+     int *merged = calloc(num_symbols, sizeof(int));
+     for (int i = 0; i < num_bigrams; i++) {
+         int left = bigrams[i].left;
+         int right = bigrams[i].right;
+
+         if (merged[left] || merged[right]) continue;
+
+         // Merge right into left
+         symbols[left].end = symbols[right].end;
+         symbols[left].next = symbols[right].next;
+         merged[right] = 1;
+
+         // Update next symbol's prev
+         if (symbols[right].next < num_symbols) {
+             symbols[symbols[right].next].prev = left;
+         }
+     }
+
+     // Collect final tokens
+     for (int i = 0; i < num_symbols; i++) {
+         if (!merged[i] && symbols[i].used) {
+             // Extract the substring
+             char *substr = malloc(symbols[i].end - symbols[i].start + 1);
+             memcpy(substr, word + symbols[i].start, symbols[i].end - symbols[i].start);
+             substr[symbols[i].end - symbols[i].start] = '\0';
+
+             int id = text_to_id(vocab_data, substr);
+             if (id != -1) {
+                 token_ids[*num_tokens] = id;
+                 (*num_tokens)++;
+             } else {
+                 // Unknown token - use byte-level fallback
+                 // For simplicity, we'll use space as a placeholder
+                 // In a full implementation, you'd encode bytes individually
+             }
+
+             free(substr);
+         }
+     }
+
+     free(bigrams);
+     free(merged);
+     free(symbols);
+ }
+
 
  /* ------------------------------------------------------------------------- */
  static int safe_advance(uint8_t **p, uint8_t *end, size_t sz) {
      if (*p + sz > end) return 0;
@@ -91,6 +457,15 @@ typedef struct {
      void *mapped;
      size_t mapped_size;
      HashNode **table;
+
+     // BPE tokenization data
+     BPEMergeTable merges;
+     RegexPattern *pre_patterns;
+     int num_pre_patterns;
+     int unknown_token_id;
+     int bos_token_id;
+     int eos_token_id;
+     int vocab_type;
  } EmbedModel;
 
  typedef struct {
@@ -122,6 +497,11 @@ static int hget(EmbedModel *m, const char *k) {
      return -1;
  }
 
+ static int text_to_id(void *vocab_data, const char *text) {
+     EmbedModel *m = (EmbedModel*)vocab_data;
+     return hget(m, text);
+ }
+
  /* ------------------------------------------------------------------------- */
  static void *map_file(const char *path, size_t *size) {
      int fd = open(path, O_RDONLY);
@@ -457,6 +837,16 @@ static void free_model_contents(EmbedModel *m) {
      }
      if (m->float_data) free(m->float_data);
      if (m->mapped) munmap(m->mapped, m->mapped_size);
+
+     // Free BPE tokenization data
+     bpe_merge_table_free(&m->merges);
+     if (m->pre_patterns) {
+         for (int i = 0; i < m->num_pre_patterns; i++) {
+             free(m->pre_patterns[i].pattern);
+         }
+         free(m->pre_patterns);
+     }
+
      free(m);
  }
 
@@ -483,6 +873,46 @@ static uint8_t *find_tensor_info_start(uint8_t *cur, uint8_t *end) {
      return NULL;
  }
 
+ /* ------------------------------------------------------------------------- */
+ static void setup_default_pre_patterns(EmbedModel *m) {
+     // Default pre-tokenization regex patterns (similar to Llama 3)
+     const char *default_patterns[] = {
+         "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])",
+         "[^\\r\\n\\p{L}\\p{N}]?\\p{L}+",
+         "\\p{N}{1,3}",
+         " ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*",
+         "\\s*[\\r\\n]+",
+         "\\s+(?!\\S)",
+         "\\s+"
+     };
+
+     m->num_pre_patterns = sizeof(default_patterns) / sizeof(default_patterns[0]);
+     m->pre_patterns = malloc(m->num_pre_patterns * sizeof(RegexPattern));
+
+     for (int i = 0; i < m->num_pre_patterns; i++) {
+         m->pre_patterns[i].pattern = strdup(default_patterns[i]);
+         m->pre_patterns[i].pattern_len = strlen(default_patterns[i]);
+     }
+ }
+
+ /* ------------------------------------------------------------------------- */
+ static void parse_merge(const char *merge_str, char **left, char **right) {
+     // Parse a merge string like "h ello" -> left="h", right="ello"
+     const char *space = strchr(merge_str, ' ');
+     if (space) {
+         int left_len = space - merge_str;
+         *left = malloc(left_len + 1);
+         memcpy(*left, merge_str, left_len);
+         (*left)[left_len] = '\0';
+
+         *right = strdup(space + 1);
+     } else {
+         // No space - treat as single token
+         *left = strdup(merge_str);
+         *right = strdup("");
+     }
+ }
+
  /* ------------------------------------------------------------------------- */
  static EmbedModel *embed_load_gguf(const char *path) {
      size_t sz;
@@ -504,6 +934,16 @@ static EmbedModel *embed_load_gguf(const char *path) {
      m->mapped_size = sz;
      m->table = calloc(HASH_SIZE, sizeof(HashNode*));
      if (!m->table) { free_model_contents(m); return NULL; }
+
+     // Initialize BPE structures
+     bpe_merge_table_init(&m->merges);
+     setup_default_pre_patterns(m);
+
+     // Default values
+     m->unknown_token_id = -1;
+     m->bos_token_id = -1;
+     m->eos_token_id = -1;
+     m->vocab_type = LLAMA_VOCAB_TYPE_NONE;
 
      /* ---------- Metadata ---------- */
      int vocab_found = 0;
@@ -527,6 +967,50 @@ static EmbedModel *embed_load_gguf(const char *path) {
              hset(m, tok, (int)j);
          }
          vocab_found = 1;
+     } else if (strcmp(key, "tokenizer.ggml.merges") == 0 && type == 9) {
+         uint32_t subtype = rd32(&cur, end);
+         uint64_t n = rd64(&cur, end);
+         if (subtype == 8) {
+             // Parse merges
+             for (uint64_t j = 0; j < n && j < MAX_MERGES; j++) {
+                 char *merge_str = rdstr(&cur, end);
+                 if (merge_str) {
+                     char *left, *right;
+                     parse_merge(merge_str, &left, &right);
+                     bpe_merge_table_add(&m->merges, left, right, merge_str, j);
+                     free(left);
+                     free(right);
+                     free(merge_str);
+                 }
+             }
+         } else {
+             // Skip if not string array
+             if (!skip_value(&cur, end, type)) {
+                 free(key); free_model_contents(m); return NULL;
+             }
+         }
+     } else if (strcmp(key, "tokenizer.ggml.model") == 0 && type == 8) {
+         char *model_type = rdstr(&cur, end);
+         if (model_type) {
+             if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0) {
+                 m->vocab_type = LLAMA_VOCAB_TYPE_BPE;
+             } else if (strcmp(model_type, "bert") == 0) {
+                 m->vocab_type = LLAMA_VOCAB_TYPE_WPM;
+             }
+             free(model_type);
+         }
+     } else if (strcmp(key, "tokenizer.ggml.pre") == 0 && type == 8) {
+         char *pre_type = rdstr(&cur, end);
+         if (pre_type) {
+             // Could load custom regex patterns here if needed
+             free(pre_type);
+         }
+     } else if (strcmp(key, "tokenizer.ggml.unknown_token_id") == 0 && type == 6) {
+         m->unknown_token_id = rd32(&cur, end);
+     } else if (strcmp(key, "tokenizer.ggml.bos_token_id") == 0 && type == 6) {
+         m->bos_token_id = rd32(&cur, end);
+     } else if (strcmp(key, "tokenizer.ggml.eos_token_id") == 0 && type == 6) {
+         m->eos_token_id = rd32(&cur, end);
      } else {
          if (!skip_value(&cur, end, type)) {
              free(key); free_model_contents(m); return NULL;
@@ -625,28 +1109,71 @@ static EmbedModel *embed_load_gguf(const char *path) {
  /* ------------------------------------------------------------------------- */
  static void embed_text(EmbedModel *m, const char *txt, float *out) {
      memset(out, 0, sizeof(float) * m->dim);
- char *copy = strdup(txt);
- if (!copy) return;
-
- char *tok = strtok(copy, " ");
- int used = 0;
+
+     // Pre-tokenize using regex
+     int num_words = 0;
+     char **words = unicode_regex_split(txt, m->pre_patterns, m->num_pre_patterns, &num_words);
+
+     if (!words || num_words == 0) {
+         // Fallback to space splitting if regex fails
+         char *copy = strdup(txt);
+         if (!copy) return;
+
+         char *tok = strtok(copy, " \t\n\r");
+         int used = 0;
+         const float *embd_matrix = m->tensor_data;
+
+         while (tok) {
+             int id = hget(m, tok);
+             if (id >= 0 && id < m->vocab_size) {
+                 const float *vec = embd_matrix + id * m->dim;
+                 for (int i = 0; i < m->dim; i++) out[i] += vec[i];
+                 used++;
+             }
+             tok = strtok(NULL, " \t\n\r");
+         }
+
+         if (used > 0) {
+             float inv = 1.0f / used;
+             for (int i = 0; i < m->dim; i++) out[i] *= inv;
+         }
+         free(copy);
+         return;
+     }
+
+     // Tokenize each word using BPE
+     int *token_ids = malloc(m->vocab_size * sizeof(int)); // Max possible tokens
+     int num_tokens = 0;
      const float *embd_matrix = m->tensor_data;
-
- while (tok) {
- int id = hget(m, tok);
- if (id >= 0 && id < m->vocab_size) {
- const float *vec = embd_matrix + id * m->dim;
- for (int i = 0; i < m->dim; i++) out[i] += vec[i];
- used++;
+     int used = 0;
+
+     for (int i = 0; i < num_words; i++) {
+         num_tokens = 0;
+         bpe_tokenize_word(&m->merges, words[i], text_to_id, m, token_ids, &num_tokens);
+
+         for (int j = 0; j < num_tokens; j++) {
+             int id = token_ids[j];
+             if (id >= 0 && id < m->vocab_size) {
+                 const float *vec = embd_matrix + id * m->dim;
+                 for (int k = 0; k < m->dim; k++) out[k] += vec[k];
+                 used++;
+             } else if (m->unknown_token_id != -1 && m->unknown_token_id < m->vocab_size) {
+                 // Use unknown token as fallback
+                 const float *vec = embd_matrix + m->unknown_token_id * m->dim;
+                 for (int k = 0; k < m->dim; k++) out[k] += vec[k];
+                 used++;
+             }
          }
- tok = strtok(NULL, " ");
+
+         free(words[i]);
      }
-
+     free(words);
+     free(token_ids);
+
      if (used > 0) {
          float inv = 1.0f / used;
          for (int i = 0; i < m->dim; i++) out[i] *= inv;
      }
- free(copy);
  }
 
  /* ------------------------------------------------------------------------- */
metadata CHANGED
@@ -1,7 +1,7 @@
  --- !ruby/object:Gem::Specification
  name: mini_embed
  version: !ruby/object:Gem::Version
- version: 0.1.0
+ version: 0.1.1
  platform: ruby
  authors:
  - Makapoxa