mini_embed 0.1.1 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: fd4a9fa127d0882eef7443594736c7ed633bf4728f75cfcf49a2987a515b3e8e
4
- data.tar.gz: 632a8f4cdd9f2a218dc025b47e6f4c19c03ee81db67f5f50c8184cf14809b05e
3
+ metadata.gz: 2df9f5c081f8a7fa2447261817ebba58e5c062921a1dc6ee3ec8048fdc300022
4
+ data.tar.gz: 4ee4f87161506c59e6deda7dd12b819c33e86ae5f5843aa89a9754b57b27f968
5
5
  SHA512:
6
- metadata.gz: 1a3bf50d26e8d53a560e97f1b3125797b1f6de94c773a344d1acb96ab1ef6a8b7a731707f104e4d9cdd856e3df6e1579b9510dafc959d890c6623441d470fa95
7
- data.tar.gz: ac6f937aafff0dd9dc93193ac85eae4293eb0fa51dbb56897e4ad25c60cf784b9c7896032dc48607b4d702bf81cbb0524b8f3dddbc90d190c2eacdee63200dfb
6
+ metadata.gz: 4646b9f96a6ef525751d3046f0524d308940c1deb3a85623f775666ab2e1bbcddbec62c16f110f2cc1eae620f0b52c900a64f03cfa01f8e856d3546eb404ee98
7
+ data.tar.gz: f03b8103bddc296f1601d62bda851655529dc47cea1eee9f6bfcfc03fb5c005ad33f2b53a81e84d882b2813c670254bdbdb465b2b64f76c4ffc66248f5b27f73
data/README.md CHANGED
@@ -52,15 +52,19 @@ require 'mini_embed'
52
52
  # Load a GGUF model (F32, F16, Q8_0, Q4_K, etc. are all supported)
53
53
  model = MiniEmbed.new(model: '/path/to/gte-small.Q8_0.gguf')
54
54
 
55
- # Get the raw binary string (little‑endian 32‑bit floats)
56
- binary = model.embeddings(text: 'hello world')
57
-
58
- # Get an embedding as an array of floats
59
- embedding = binary.unpack('e*')
55
+ # Get embedding as an array of floats (default)
56
+ embedding = model.embeddings(text: 'hello world')
60
57
  puts embedding.size # e.g. 384
61
58
  puts embedding[0..4] # e.g. [0.0123, -0.0456, ...]
59
+
60
+ # Or get the raw binary string (little‑endian 32‑bit floats)
61
+ binary = model.embeddings(text: 'hello world', type: :binary)
62
+ embedding_from_binary = binary.unpack('e*')
62
63
  ```
63
64
 
65
+ Note: The `type` parameter is optional – it defaults to `:vector`, which returns a Ruby `Array<Float>`. Use `type: :binary` to get the raw binary string (compatible with the original C extension).
66
+
67
+
64
68
  ## Simple tokenization note
65
69
  MiniEmbed uses a naive space‑based tokenizer. This means it splits input on spaces and looks up each token exactly in the model's vocabulary. For models trained with subword tokenization (like BERT), this will not work for out‑of‑vocabulary words.
66
70
  If you need proper subword tokenization, you can:
@@ -41,17 +41,16 @@ enum llama_vocab_type {
41
41
  };
42
42
 
43
43
  /* ------------------------------------------------------------------------- */
44
- // Unicode helper functions (adapted from llama.cpp)
44
+ // Unicode helper functions
45
45
  static int unicode_len_utf8(char c) {
46
46
  if ((c & 0x80) == 0) return 1;
47
47
  if ((c & 0xE0) == 0xC0) return 2;
48
48
  if ((c & 0xF0) == 0xE0) return 3;
49
49
  if ((c & 0xF8) == 0xF0) return 4;
50
- return 1; // fallback
50
+ return 1;
51
51
  }
52
52
 
53
53
  static int unicode_is_letter(uint32_t cp) {
54
- // Basic Unicode letter detection (simplified)
55
54
  return (cp >= 0x41 && cp <= 0x5A) || (cp >= 0x61 && cp <= 0x7A) ||
56
55
  (cp >= 0xC0 && cp <= 0xD6) || (cp >= 0xD8 && cp <= 0xF6) ||
57
56
  (cp >= 0xF8 && cp <= 0x2FF) || (cp >= 0x370 && cp <= 0x37D) ||
@@ -68,74 +67,42 @@ static int unicode_is_number(uint32_t cp) {
68
67
  }
69
68
 
70
69
  static uint32_t unicode_cpt_from_utf8(const char *s, size_t *len) {
71
- uint32_t cp = 0;
72
70
  unsigned char c = (unsigned char)s[0];
73
-
74
- if (c < 0x80) {
75
- *len = 1;
76
- return c;
77
- } else if ((c & 0xE0) == 0xC0) {
78
- *len = 2;
79
- cp = (c & 0x1F) << 6;
80
- cp |= (s[1] & 0x3F);
81
- return cp;
82
- } else if ((c & 0xF0) == 0xE0) {
83
- *len = 3;
84
- cp = (c & 0x0F) << 12;
85
- cp |= (s[1] & 0x3F) << 6;
86
- cp |= (s[2] & 0x3F);
87
- return cp;
88
- } else if ((c & 0xF8) == 0xF0) {
89
- *len = 4;
90
- cp = (c & 0x07) << 18;
91
- cp |= (s[1] & 0x3F) << 12;
92
- cp |= (s[2] & 0x3F) << 6;
93
- cp |= (s[3] & 0x3F);
94
- return cp;
95
- }
96
-
71
+ if (c < 0x80) { *len = 1; return c; }
72
+ if ((c & 0xE0) == 0xC0) { *len = 2; return ((c & 0x1F) << 6) | (s[1] & 0x3F); }
73
+ if ((c & 0xF0) == 0xE0) { *len = 3; return ((c & 0x0F) << 12) | ((s[1] & 0x3F) << 6) | (s[2] & 0x3F); }
74
+ if ((c & 0xF8) == 0xF0) { *len = 4; return ((c & 0x07) << 18) | ((s[1] & 0x3F) << 12) | ((s[2] & 0x3F) << 6) | (s[3] & 0x3F); }
97
75
  *len = 1;
98
76
  return c;
99
77
  }
100
78
 
101
79
  /* ------------------------------------------------------------------------- */
102
- // Simple regex pattern matcher for pre-tokenization
80
+ // Simple regex pattern matcher (simplified)
103
81
  typedef struct {
104
82
  char *pattern;
105
83
  int pattern_len;
106
84
  } RegexPattern;
107
85
 
108
86
  static int match_regex(const char *text, const RegexPattern *patterns, int num_patterns) {
109
- // Simplified implementation for common BPE patterns
110
- // Full regex engine would be complex; this handles the most common cases
111
-
112
87
  for (int i = 0; i < num_patterns; i++) {
113
88
  const char *p = patterns[i].pattern;
114
- int plen = patterns[i].pattern_len;
115
-
116
- // Check for common patterns
117
89
  if (strstr(p, "\\p{L}")) {
118
- // Match Unicode letter
119
90
  size_t len;
120
91
  uint32_t cp = unicode_cpt_from_utf8(text, &len);
121
92
  if (unicode_is_letter(cp)) return 1;
122
93
  } else if (strstr(p, "\\p{N}")) {
123
- // Match Unicode number
124
94
  size_t len;
125
95
  uint32_t cp = unicode_cpt_from_utf8(text, &len);
126
96
  if (unicode_is_number(cp)) return 1;
127
97
  } else if (p[0] == '\\' && p[1] == 's') {
128
- // Match whitespace
129
98
  if (isspace(text[0])) return 1;
130
99
  } else if (p[0] == '\\' && p[1] == 'r') {
131
100
  if (text[0] == '\r') return 1;
132
101
  } else if (p[0] == '\\' && p[1] == 'n') {
133
102
  if (text[0] == '\n') return 1;
134
103
  } else if (p[0] == '.' && p[1] == '*') {
135
- // Match anything
136
104
  return 1;
137
105
  } else if (isalnum(p[0]) || ispunct(p[0])) {
138
- // Match literal character
139
106
  if (text[0] == p[0]) return 1;
140
107
  }
141
108
  }
@@ -144,65 +111,36 @@ static int match_regex(const char *text, const RegexPattern *patterns, int num_p
144
111
 
145
112
  static char** unicode_regex_split(const char *text, const RegexPattern *patterns, int num_patterns, int *num_words) {
146
113
  char **words = NULL;
147
- int word_count = 0;
148
- int word_capacity = 0;
149
-
150
- size_t text_len = strlen(text);
151
- size_t pos = 0;
152
-
114
+ int word_count = 0, word_capacity = 0;
115
+ size_t text_len = strlen(text), pos = 0;
116
+
153
117
  while (pos < text_len) {
154
- // Find the start of a word (character that matches any regex)
155
118
  size_t start = pos;
156
- while (start < text_len) {
157
- if (match_regex(text + start, patterns, num_patterns)) {
158
- break;
159
- }
160
- start++;
161
- }
162
-
119
+ while (start < text_len && !match_regex(text + start, patterns, num_patterns)) start++;
163
120
  if (start >= text_len) break;
164
-
165
- // Find the end of the word (character that doesn't match any regex)
166
121
  size_t end = start;
167
- while (end < text_len) {
168
- if (!match_regex(text + end, patterns, num_patterns)) {
169
- break;
170
- }
171
- end++;
172
- }
173
-
122
+ while (end < text_len && match_regex(text + end, patterns, num_patterns)) end++;
174
123
  if (end > start) {
175
- // Extract the word
176
124
  size_t word_len = end - start;
177
125
  char *word = malloc(word_len + 1);
178
- if (word) {
179
- memcpy(word, text + start, word_len);
180
- word[word_len] = '\0';
181
-
182
- // Add to array
183
- if (word_count >= word_capacity) {
184
- word_capacity = word_capacity == 0 ? 16 : word_capacity * 2;
185
- words = realloc(words, word_capacity * sizeof(char*));
186
- if (!words) {
187
- for (int i = 0; i < word_count; i++) free(words[i]);
188
- free(words);
189
- *num_words = 0;
190
- return NULL;
191
- }
192
- }
193
- words[word_count++] = word;
126
+ if (!word) { while (--word_count >= 0) free(words[word_count]); free(words); *num_words = 0; return NULL; }
127
+ memcpy(word, text + start, word_len);
128
+ word[word_len] = '\0';
129
+ if (word_count >= word_capacity) {
130
+ word_capacity = word_capacity ? word_capacity * 2 : 16;
131
+ words = realloc(words, word_capacity * sizeof(char*));
132
+ if (!words) { free(word); while (--word_count >= 0) free(words[word_count]); *num_words = 0; return NULL; }
194
133
  }
134
+ words[word_count++] = word;
195
135
  }
196
-
197
136
  pos = end;
198
137
  }
199
-
200
138
  *num_words = word_count;
201
139
  return words;
202
140
  }
203
141
 
204
142
  /* ------------------------------------------------------------------------- */
205
- // BPE merge structure
143
+ // BPE merge structures
206
144
  typedef struct {
207
145
  char *left;
208
146
  char *right;
@@ -217,22 +155,19 @@ typedef struct {
217
155
  } BPEMergeTable;
218
156
 
219
157
  static void bpe_merge_table_init(BPEMergeTable *table) {
220
- table->merges = NULL;
221
- table->num_merges = 0;
222
- table->capacity = 0;
158
+ memset(table, 0, sizeof(*table));
223
159
  }
224
160
 
225
161
  static void bpe_merge_table_add(BPEMergeTable *table, const char *left, const char *right, const char *merged, int rank) {
226
162
  if (table->num_merges >= table->capacity) {
227
- table->capacity = table->capacity == 0 ? 100 : table->capacity * 2;
163
+ table->capacity = table->capacity ? table->capacity * 2 : 100;
228
164
  table->merges = realloc(table->merges, table->capacity * sizeof(BPEMerge));
229
165
  }
230
-
231
- BPEMerge *merge = &table->merges[table->num_merges++];
232
- merge->left = strdup(left);
233
- merge->right = strdup(right);
234
- merge->merged = strdup(merged);
235
- merge->rank = rank;
166
+ BPEMerge *m = &table->merges[table->num_merges++];
167
+ m->left = strdup(left);
168
+ m->right = strdup(right);
169
+ m->merged = strdup(merged);
170
+ m->rank = rank;
236
171
  }
237
172
 
238
173
  static void bpe_merge_table_free(BPEMergeTable *table) {
@@ -248,40 +183,25 @@ static void bpe_merge_table_free(BPEMergeTable *table) {
248
183
 
249
184
  static int bpe_merge_rank(const BPEMergeTable *table, const char *left, const char *right) {
250
185
  for (int i = 0; i < table->num_merges; i++) {
251
- if (strcmp(table->merges[i].left, left) == 0 && strcmp(table->merges[i].right, right) == 0) {
186
+ if (strcmp(table->merges[i].left, left) == 0 && strcmp(table->merges[i].right, right) == 0)
252
187
  return table->merges[i].rank;
253
- }
254
188
  }
255
189
  return -1;
256
190
  }
257
191
 
258
- static char* bpe_merge(const BPEMergeTable *table, const char *left, const char *right) {
259
- for (int i = 0; i < table->num_merges; i++) {
260
- if (strcmp(table->merges[i].left, left) == 0 && strcmp(table->merges[i].right, right) == 0) {
261
- return table->merges[i].merged;
262
- }
263
- }
264
- return NULL;
265
- }
266
-
267
192
  /* ------------------------------------------------------------------------- */
268
- // BPE tokenization helper structures
193
+ // BPE tokenization
269
194
  typedef struct {
270
195
  char *text;
271
- int start;
272
- int end;
273
- int prev;
274
- int next;
196
+ int start, end;
197
+ int prev, next;
275
198
  int used;
276
199
  } BPESymbol;
277
200
 
278
201
  static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int (*text_to_id)(void*, const char*), void *vocab_data, int *token_ids, int *num_tokens) {
279
- // Initialize symbols from characters
280
202
  int word_len = strlen(word);
281
203
  int num_symbols = 0;
282
204
  BPESymbol *symbols = malloc(word_len * sizeof(BPESymbol));
283
-
284
- // Split into UTF-8 characters
285
205
  int offset = 0;
286
206
  while (offset < word_len) {
287
207
  int char_len = unicode_len_utf8(word[offset]);
@@ -294,39 +214,25 @@ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int
294
214
  offset += char_len;
295
215
  num_symbols++;
296
216
  }
297
-
217
+
298
218
  if (num_symbols <= 1) {
299
- // Single character, just tokenize it
300
219
  int id = text_to_id(vocab_data, word);
301
- if (id != -1) {
302
- token_ids[*num_tokens] = id;
303
- (*num_tokens)++;
304
- }
220
+ if (id != -1) token_ids[(*num_tokens)++] = id;
305
221
  free(symbols);
306
222
  return;
307
223
  }
308
-
309
- // Build priority queue for merges (simplified)
310
- typedef struct {
311
- int left;
312
- int right;
313
- int rank;
314
- } Bigram;
315
-
316
- Bigram *bigrams = malloc(word_len * word_len * sizeof(Bigram));
224
+
225
+ typedef struct { int left, right, rank; } Bigram;
226
+ Bigram *bigrams = malloc(num_symbols * num_symbols * sizeof(Bigram));
317
227
  int num_bigrams = 0;
318
-
319
- // Initialize bigrams
320
228
  for (int i = 0; i < num_symbols - 1; i++) {
321
229
  if (symbols[i].used && symbols[i+1].used) {
322
- // Get the concatenated string for this pair
323
230
  char *left_str = malloc(symbols[i].end - symbols[i].start + 1);
324
231
  char *right_str = malloc(symbols[i+1].end - symbols[i+1].start + 1);
325
232
  memcpy(left_str, symbols[i].text, symbols[i].end - symbols[i].start);
326
233
  memcpy(right_str, symbols[i+1].text, symbols[i+1].end - symbols[i+1].start);
327
234
  left_str[symbols[i].end - symbols[i].start] = '\0';
328
235
  right_str[symbols[i+1].end - symbols[i+1].start] = '\0';
329
-
330
236
  int rank = bpe_merge_rank(merges, left_str, right_str);
331
237
  if (rank != -1) {
332
238
  bigrams[num_bigrams].left = i;
@@ -334,70 +240,42 @@ static void bpe_tokenize_word(const BPEMergeTable *merges, const char *word, int
334
240
  bigrams[num_bigrams].rank = rank;
335
241
  num_bigrams++;
336
242
  }
337
-
338
- free(left_str);
339
- free(right_str);
243
+ free(left_str); free(right_str);
340
244
  }
341
245
  }
342
-
343
- // Sort bigrams by rank (lower rank = higher priority)
344
- for (int i = 0; i < num_bigrams - 1; i++) {
345
- for (int j = i+1; j < num_bigrams; j++) {
246
+ for (int i = 0; i < num_bigrams - 1; i++)
247
+ for (int j = i+1; j < num_bigrams; j++)
346
248
  if (bigrams[i].rank > bigrams[j].rank) {
347
- Bigram temp = bigrams[i];
249
+ Bigram tmp = bigrams[i];
348
250
  bigrams[i] = bigrams[j];
349
- bigrams[j] = temp;
251
+ bigrams[j] = tmp;
350
252
  }
351
- }
352
- }
353
-
354
- // Apply merges
253
+
355
254
  int *merged = calloc(num_symbols, sizeof(int));
356
255
  for (int i = 0; i < num_bigrams; i++) {
357
- int left = bigrams[i].left;
358
- int right = bigrams[i].right;
359
-
256
+ int left = bigrams[i].left, right = bigrams[i].right;
360
257
  if (merged[left] || merged[right]) continue;
361
-
362
- // Merge right into left
363
258
  symbols[left].end = symbols[right].end;
364
259
  symbols[left].next = symbols[right].next;
365
260
  merged[right] = 1;
366
-
367
- // Update next symbol's prev
368
- if (symbols[right].next < num_symbols) {
369
- symbols[symbols[right].next].prev = left;
370
- }
261
+ if (symbols[right].next < num_symbols) symbols[symbols[right].next].prev = left;
371
262
  }
372
-
373
- // Collect final tokens
263
+
374
264
  for (int i = 0; i < num_symbols; i++) {
375
265
  if (!merged[i] && symbols[i].used) {
376
- // Extract the substring
377
266
  char *substr = malloc(symbols[i].end - symbols[i].start + 1);
378
267
  memcpy(substr, word + symbols[i].start, symbols[i].end - symbols[i].start);
379
268
  substr[symbols[i].end - symbols[i].start] = '\0';
380
-
381
269
  int id = text_to_id(vocab_data, substr);
382
- if (id != -1) {
383
- token_ids[*num_tokens] = id;
384
- (*num_tokens)++;
385
- } else {
386
- // Unknown token - use byte-level fallback
387
- // For simplicity, we'll use space as a placeholder
388
- // In a full implementation, you'd encode bytes individually
389
- }
390
-
270
+ if (id != -1) token_ids[(*num_tokens)++] = id;
391
271
  free(substr);
392
272
  }
393
273
  }
394
-
395
- free(bigrams);
396
- free(merged);
397
- free(symbols);
274
+ free(bigrams); free(merged); free(symbols);
398
275
  }
399
276
 
400
277
  /* ------------------------------------------------------------------------- */
278
+ // GGUF parsing
401
279
  static int safe_advance(uint8_t **p, uint8_t *end, size_t sz) {
402
280
  if (*p + sz > end) return 0;
403
281
  *p += sz;
@@ -405,14 +283,14 @@ static int safe_advance(uint8_t **p, uint8_t *end, size_t sz) {
405
283
  }
406
284
 
407
285
  static uint32_t rd32(uint8_t **p, uint8_t *end) {
408
- uint32_t v = 0;
286
+ uint32_t v;
409
287
  if (!safe_advance(p, end, 4)) return 0;
410
288
  memcpy(&v, *p - 4, 4);
411
289
  return v;
412
290
  }
413
291
 
414
292
  static uint64_t rd64(uint8_t **p, uint8_t *end) {
415
- uint64_t v = 0;
293
+ uint64_t v;
416
294
  if (!safe_advance(p, end, 8)) return 0;
417
295
  memcpy(&v, *p - 8, 8);
418
296
  return v;
@@ -423,9 +301,9 @@ static char *rdstr(uint8_t **p, uint8_t *end) {
423
301
  uint64_t len;
424
302
  memcpy(&len, *p, 8);
425
303
  *p += 8;
426
- if (len == 0 || len > (1 << 20)) return NULL;
304
+ if (len == 0 || len > (1<<20)) return NULL;
427
305
  if (*p + len > end) return NULL;
428
- char *s = malloc(len + 1);
306
+ char *s = malloc(len+1);
429
307
  if (!s) return NULL;
430
308
  memcpy(s, *p, len);
431
309
  s[len] = '\0';
@@ -436,29 +314,27 @@ static char *rdstr(uint8_t **p, uint8_t *end) {
436
314
  static void align_to_32(uint8_t **p, uint8_t *end, uint8_t *base) {
437
315
  size_t off = *p - base;
438
316
  size_t aligned = (off + GGUF_ALIGN - 1) & ~(GGUF_ALIGN - 1);
439
- if (base + aligned <= end)
440
- *p = base + aligned;
317
+ if (base + aligned <= end) *p = base + aligned;
441
318
  }
442
319
 
443
320
  /* ------------------------------------------------------------------------- */
321
+ // Hash table for vocabulary
444
322
  typedef struct HashNode {
445
323
  char *key;
446
- int id;
324
+ int id;
447
325
  struct HashNode *next;
448
326
  } HashNode;
449
327
 
450
328
  typedef struct {
451
- int vocab_size;
452
- int dim;
453
- char **tokens;
454
- float *float_data;
455
- void *tensor_data;
456
- int tensor_type;
457
- void *mapped;
458
- size_t mapped_size;
329
+ int vocab_size;
330
+ int dim;
331
+ char **tokens;
332
+ float *float_data;
333
+ void *tensor_data;
334
+ int tensor_type;
335
+ void *mapped;
336
+ size_t mapped_size;
459
337
  HashNode **table;
460
-
461
- // BPE tokenization data
462
338
  BPEMergeTable merges;
463
339
  RegexPattern *pre_patterns;
464
340
  int num_pre_patterns;
@@ -466,6 +342,7 @@ typedef struct {
466
342
  int bos_token_id;
467
343
  int eos_token_id;
468
344
  int vocab_type;
345
+ char space_marker[8];
469
346
  } EmbedModel;
470
347
 
471
348
  typedef struct {
@@ -498,11 +375,11 @@ static int hget(EmbedModel *m, const char *k) {
498
375
  }
499
376
 
500
377
  static int text_to_id(void *vocab_data, const char *text) {
501
- EmbedModel *m = (EmbedModel*)vocab_data;
502
- return hget(m, text);
378
+ return hget((EmbedModel*)vocab_data, text);
503
379
  }
504
380
 
505
381
  /* ------------------------------------------------------------------------- */
382
+ // File mapping
506
383
  static void *map_file(const char *path, size_t *size) {
507
384
  int fd = open(path, O_RDONLY);
508
385
  if (fd < 0) return NULL;
@@ -511,29 +388,22 @@ static void *map_file(const char *path, size_t *size) {
511
388
  *size = st.st_size;
512
389
  void *data = mmap(NULL, *size, PROT_READ, MAP_PRIVATE, fd, 0);
513
390
  close(fd);
514
- if (data == MAP_FAILED) return NULL;
515
- return data;
391
+ return data == MAP_FAILED ? NULL : data;
516
392
  }
517
393
 
518
394
  /* ------------------------------------------------------------------------- */
395
+ // FP16 conversion
519
396
  static float fp16_to_fp32(uint16_t h) {
520
- const uint16_t sign = (h >> 15) & 1;
521
- const uint16_t exp = (h >> 10) & 0x1F;
522
- const uint16_t mant = h & 0x3FF;
523
- float val;
524
- if (exp == 0) {
525
- val = (mant / 1024.0f) * 6.103515625e-5f;
526
- } else if (exp == 31) {
527
- return 0.0f;
528
- } else {
529
- val = (1.0f + mant / 1024.0f) * (1 << (exp - 15));
530
- }
531
- return sign ? -val : val;
397
+ uint16_t sign = (h >> 15) & 1;
398
+ uint16_t exp = (h >> 10) & 0x1F;
399
+ uint16_t mant = h & 0x3FF;
400
+ if (exp == 0) return (mant / 1024.0f) * 6.103515625e-5f * (sign ? -1.0f : 1.0f);
401
+ if (exp == 31) return 0.0f;
402
+ return (1.0f + mant / 1024.0f) * (1 << (exp - 15)) * (sign ? -1.0f : 1.0f);
532
403
  }
533
404
 
534
405
  /* ------------------------------------------------------------------------- */
535
- /* Block dequantization */
536
-
406
+ // Block dequantization functions
537
407
  static void dequantize_row_q4_0(const void *vx, float *y, int k) {
538
408
  const int nb = k / 32;
539
409
  const uint8_t *x = vx;
@@ -621,7 +491,6 @@ static void dequantize_row_q8_1(const void *vx, float *y, int k) {
621
491
  }
622
492
  }
623
493
 
624
- /* K-quants */
625
494
  static void dequantize_row_q2_K(const void *vx, float *y, int k) {
626
495
  const int nb = k / 256;
627
496
  const uint8_t *x = vx;
@@ -748,7 +617,6 @@ static void dequantize_row_q8_K(const void *vx, float *y, int k) {
748
617
  }
749
618
  }
750
619
 
751
- /* ------------------------------------------------------------------------- */
752
620
  static float* dequantize_tensor(const void *data, int type, int n_rows, int n_cols) {
753
621
  if (type == GGML_TYPE_F32) {
754
622
  float *out = malloc(n_rows * n_cols * sizeof(float));
@@ -793,6 +661,14 @@ static float* dequantize_tensor(const void *data, int type, int n_rows, int n_co
793
661
  for (int r = 0; r < n_rows; r++) {
794
662
  dequant_func(in + r * row_bytes, out + r * n_cols, n_cols);
795
663
  }
664
+
665
+ // Sanitize the tensor: replace NaNs, Infs, and astronomically large values with zero
666
+ int total = n_rows * n_cols;
667
+ for (int i = 0; i < total; i++) {
668
+ if (isnan(out[i]) || isinf(out[i]) || fabs(out[i]) > 1e10f) {
669
+ out[i] = 0.0f;
670
+ }
671
+ }
796
672
  return out;
797
673
  }
798
674
 
@@ -837,45 +713,49 @@ static void free_model_contents(EmbedModel *m) {
837
713
  }
838
714
  if (m->float_data) free(m->float_data);
839
715
  if (m->mapped) munmap(m->mapped, m->mapped_size);
840
-
841
- // Free BPE tokenization data
842
716
  bpe_merge_table_free(&m->merges);
843
717
  if (m->pre_patterns) {
844
- for (int i = 0; i < m->num_pre_patterns; i++) {
845
- free(m->pre_patterns[i].pattern);
846
- }
718
+ for (int i = 0; i < m->num_pre_patterns; i++) free(m->pre_patterns[i].pattern);
847
719
  free(m->pre_patterns);
848
720
  }
849
-
850
721
  free(m);
851
722
  }
852
723
 
853
724
  /* ------------------------------------------------------------------------- */
854
725
  static int is_printable_string(const char *s, size_t len) {
855
- for (size_t i = 0; i < len; i++)
856
- if (!isprint((unsigned char)s[i])) return 0;
726
+ for (size_t i = 0; i < len; i++) if (!isprint((unsigned char)s[i])) return 0;
857
727
  return 1;
858
728
  }
859
729
 
860
- /* Fallback: find the start of tensor info by scanning for a valid string */
861
730
  static uint8_t *find_tensor_info_start(uint8_t *cur, uint8_t *end) {
862
731
  uint8_t *scan = cur;
863
732
  while (scan + 8 < end) {
864
733
  uint64_t len;
865
734
  memcpy(&len, scan, 8);
866
- if (len > 0 && len < 256 && scan + 8 + len <= end) {
867
- if (is_printable_string((char*)scan + 8, len)) {
868
- return scan;
869
- }
870
- }
735
+ if (len > 0 && len < 256 && scan + 8 + len <= end && is_printable_string((char*)scan+8, len))
736
+ return scan;
871
737
  scan++;
872
738
  }
873
739
  return NULL;
874
740
  }
875
741
 
876
742
  /* ------------------------------------------------------------------------- */
743
+ static void detect_space_marker(EmbedModel *m) {
744
+ const char *candidates[] = {"▁", "Ġ", " "};
745
+ for (int i = 0; i < 3; i++) {
746
+ const char *marker = candidates[i];
747
+ int marker_len = strlen(marker);
748
+ for (int j = 0; j < m->vocab_size; j++) {
749
+ if (strncmp(m->tokens[j], marker, marker_len) == 0) {
750
+ strcpy(m->space_marker, marker);
751
+ return;
752
+ }
753
+ }
754
+ }
755
+ m->space_marker[0] = '\0';
756
+ }
757
+
877
758
  static void setup_default_pre_patterns(EmbedModel *m) {
878
- // Default pre-tokenization regex patterns (similar to Llama 3)
879
759
  const char *default_patterns[] = {
880
760
  "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])",
881
761
  "[^\\r\\n\\p{L}\\p{N}]?\\p{L}+",
@@ -885,29 +765,23 @@ static void setup_default_pre_patterns(EmbedModel *m) {
885
765
  "\\s+(?!\\S)",
886
766
  "\\s+"
887
767
  };
888
-
889
- m->num_pre_patterns = sizeof(default_patterns) / sizeof(default_patterns[0]);
768
+ m->num_pre_patterns = sizeof(default_patterns)/sizeof(default_patterns[0]);
890
769
  m->pre_patterns = malloc(m->num_pre_patterns * sizeof(RegexPattern));
891
-
892
770
  for (int i = 0; i < m->num_pre_patterns; i++) {
893
771
  m->pre_patterns[i].pattern = strdup(default_patterns[i]);
894
772
  m->pre_patterns[i].pattern_len = strlen(default_patterns[i]);
895
773
  }
896
774
  }
897
775
 
898
- /* ------------------------------------------------------------------------- */
899
776
  static void parse_merge(const char *merge_str, char **left, char **right) {
900
- // Parse a merge string like "h ello" -> left="h", right="ello"
901
777
  const char *space = strchr(merge_str, ' ');
902
778
  if (space) {
903
779
  int left_len = space - merge_str;
904
- *left = malloc(left_len + 1);
780
+ *left = malloc(left_len+1);
905
781
  memcpy(*left, merge_str, left_len);
906
782
  (*left)[left_len] = '\0';
907
-
908
- *right = strdup(space + 1);
783
+ *right = strdup(space+1);
909
784
  } else {
910
- // No space - treat as single token
911
785
  *left = strdup(merge_str);
912
786
  *right = strdup("");
913
787
  }
@@ -918,13 +792,10 @@ static EmbedModel *embed_load_gguf(const char *path) {
918
792
  size_t sz;
919
793
  uint8_t *base = map_file(path, &sz);
920
794
  if (!base) return NULL;
921
- uint8_t *cur = base;
922
- uint8_t *end = base + sz;
923
-
795
+ uint8_t *cur = base, *end = base + sz;
924
796
  if (memcmp(cur, "GGUF", 4) != 0) { munmap(base, sz); return NULL; }
925
797
  cur += 4;
926
798
  uint32_t version = rd32(&cur, end);
927
- (void)version;
928
799
  uint64_t n_tensors = rd64(&cur, end);
929
800
  uint64_t n_kv = rd64(&cur, end);
930
801
 
@@ -934,26 +805,20 @@ static EmbedModel *embed_load_gguf(const char *path) {
934
805
  m->mapped_size = sz;
935
806
  m->table = calloc(HASH_SIZE, sizeof(HashNode*));
936
807
  if (!m->table) { free_model_contents(m); return NULL; }
937
-
938
- // Initialize BPE structures
939
808
  bpe_merge_table_init(&m->merges);
940
809
  setup_default_pre_patterns(m);
941
-
942
- // Default values
943
810
  m->unknown_token_id = -1;
944
811
  m->bos_token_id = -1;
945
812
  m->eos_token_id = -1;
946
813
  m->vocab_type = LLAMA_VOCAB_TYPE_NONE;
814
+ m->space_marker[0] = '\0';
947
815
 
948
- /* ---------- Metadata ---------- */
949
816
  int vocab_found = 0;
950
817
  for (uint64_t i = 0; i < n_kv; i++) {
951
818
  char *key = rdstr(&cur, end);
952
819
  if (!key) { free_model_contents(m); return NULL; }
953
820
  uint32_t type = rd32(&cur, end);
954
-
955
- if ((strcmp(key, "tokenizer.ggml.tokens") == 0 ||
956
- strcmp(key, "tokenizer.ggml.token_list") == 0) && type == 9) {
821
+ if ((strcmp(key, "tokenizer.ggml.tokens") == 0 || strcmp(key, "tokenizer.ggml.token_list") == 0) && type == 9) {
957
822
  uint32_t subtype = rd32(&cur, end);
958
823
  uint64_t n = rd64(&cur, end);
959
824
  if (subtype != 8) { free(key); free_model_contents(m); return NULL; }
@@ -971,40 +836,29 @@ static EmbedModel *embed_load_gguf(const char *path) {
971
836
  uint32_t subtype = rd32(&cur, end);
972
837
  uint64_t n = rd64(&cur, end);
973
838
  if (subtype == 8) {
974
- // Parse merges
975
839
  for (uint64_t j = 0; j < n && j < MAX_MERGES; j++) {
976
840
  char *merge_str = rdstr(&cur, end);
977
841
  if (merge_str) {
978
842
  char *left, *right;
979
843
  parse_merge(merge_str, &left, &right);
980
- bpe_merge_table_add(&m->merges, left, right, merge_str, j);
981
- free(left);
982
- free(right);
844
+ bpe_merge_table_add(&m->merges, left, right, merge_str, (int)j);
845
+ free(left); free(right);
983
846
  free(merge_str);
984
847
  }
985
848
  }
986
849
  } else {
987
- // Skip if not string array
988
- if (!skip_value(&cur, end, type)) {
989
- free(key); free_model_contents(m); return NULL;
990
- }
850
+ if (!skip_value(&cur, end, type)) { free(key); free_model_contents(m); return NULL; }
991
851
  }
992
852
  } else if (strcmp(key, "tokenizer.ggml.model") == 0 && type == 8) {
993
853
  char *model_type = rdstr(&cur, end);
994
854
  if (model_type) {
995
- if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0) {
996
- m->vocab_type = LLAMA_VOCAB_TYPE_BPE;
997
- } else if (strcmp(model_type, "bert") == 0) {
998
- m->vocab_type = LLAMA_VOCAB_TYPE_WPM;
999
- }
855
+ if (strcmp(model_type, "gpt2") == 0 || strcmp(model_type, "llama") == 0) m->vocab_type = LLAMA_VOCAB_TYPE_BPE;
856
+ else if (strcmp(model_type, "bert") == 0) m->vocab_type = LLAMA_VOCAB_TYPE_WPM;
1000
857
  free(model_type);
1001
858
  }
1002
859
  } else if (strcmp(key, "tokenizer.ggml.pre") == 0 && type == 8) {
1003
- char *pre_type = rdstr(&cur, end);
1004
- if (pre_type) {
1005
- // Could load custom regex patterns here if needed
1006
- free(pre_type);
1007
- }
860
+ char *pre = rdstr(&cur, end);
861
+ if (pre) free(pre);
1008
862
  } else if (strcmp(key, "tokenizer.ggml.unknown_token_id") == 0 && type == 6) {
1009
863
  m->unknown_token_id = rd32(&cur, end);
1010
864
  } else if (strcmp(key, "tokenizer.ggml.bos_token_id") == 0 && type == 6) {
@@ -1012,20 +866,16 @@ static EmbedModel *embed_load_gguf(const char *path) {
1012
866
  } else if (strcmp(key, "tokenizer.ggml.eos_token_id") == 0 && type == 6) {
1013
867
  m->eos_token_id = rd32(&cur, end);
1014
868
  } else {
1015
- if (!skip_value(&cur, end, type)) {
1016
- free(key); free_model_contents(m); return NULL;
1017
- }
869
+ if (!skip_value(&cur, end, type)) { free(key); free_model_contents(m); return NULL; }
1018
870
  }
1019
871
  free(key);
1020
872
  }
1021
-
1022
873
  if (!vocab_found) { free_model_contents(m); return NULL; }
874
+ detect_space_marker(m);
1023
875
 
1024
876
  uint8_t *after_kv = cur;
1025
877
  align_to_32(&cur, end, base);
1026
878
  uint8_t *tensor_start = cur;
1027
-
1028
- /* ---------- Tensor info ---------- */
1029
879
  int embd_found = 0;
1030
880
  void *raw_tensor_data = NULL;
1031
881
  int tensor_type = -1;
@@ -1039,39 +889,27 @@ static EmbedModel *embed_load_gguf(const char *path) {
1039
889
  if (!name) break;
1040
890
  uint32_t n_dims = rd32(&cur, end);
1041
891
  uint64_t dims[MAX_DIMS] = {0};
1042
- for (uint32_t d = 0; d < n_dims && d < MAX_DIMS; d++)
1043
- dims[d] = rd64(&cur, end);
892
+ for (uint32_t d = 0; d < n_dims && d < MAX_DIMS; d++) dims[d] = rd64(&cur, end);
1044
893
  uint32_t type = rd32(&cur, end);
1045
894
  uint64_t offset = rd64(&cur, end);
1046
-
1047
895
  int is_token_embd = (strcmp(name, "token_embd.weight") == 0 ||
1048
896
  strcmp(name, "embeddings.word_embeddings.weight") == 0 ||
1049
897
  strcmp(name, "model.embed_tokens.weight") == 0);
1050
-
1051
898
  if (!is_token_embd && n_dims == 2 && m->vocab_size > 0) {
1052
- if ((uint64_t)m->vocab_size == dims[0] && strstr(name, "embd") != NULL)
1053
- is_token_embd = 1;
1054
- else if ((uint64_t)m->vocab_size == dims[1] && strstr(name, "embd") != NULL)
1055
- is_token_embd = 1;
899
+ if ((uint64_t)m->vocab_size == dims[0] && strstr(name, "embd")) is_token_embd = 1;
900
+ else if ((uint64_t)m->vocab_size == dims[1] && strstr(name, "embd")) is_token_embd = 1;
1056
901
  }
1057
-
1058
902
  if (!embd_found && is_token_embd) {
1059
903
  if (n_dims < 2 || dims[1] == 0) { free(name); free_model_contents(m); return NULL; }
1060
- dim0 = dims[0];
1061
- dim1 = dims[1];
1062
- if (dim0 == (uint64_t)m->vocab_size) {
1063
- m->dim = (int)dim1;
1064
- need_transpose = 0;
1065
- } else if (dim1 == (uint64_t)m->vocab_size) {
1066
- m->dim = (int)dim0;
1067
- need_transpose = 1;
1068
- } else {
1069
- m->dim = (dim0 < dim1) ? (int)dim0 : (int)dim1;
1070
- need_transpose = (dim0 > dim1) ? 1 : 0;
1071
- }
904
+ dim0 = dims[0]; dim1 = dims[1];
905
+ if (dim0 == (uint64_t)m->vocab_size) { m->dim = (int)dim1; need_transpose = 0; }
906
+ else if (dim1 == (uint64_t)m->vocab_size) { m->dim = (int)dim0; need_transpose = 1; }
907
+ else { m->dim = (dim0 < dim1) ? (int)dim0 : (int)dim1; need_transpose = (dim0 > dim1) ? 1 : 0; }
1072
908
  raw_tensor_data = base + offset;
1073
909
  tensor_type = type;
1074
910
  embd_found = 1;
911
+ free(name);
912
+ break;
1075
913
  }
1076
914
  free(name);
1077
915
  }
@@ -1081,13 +919,8 @@ static EmbedModel *embed_load_gguf(const char *path) {
1081
919
  if (!tensor_start) break;
1082
920
  }
1083
921
  }
922
+ if (!embd_found || m->dim == 0) { free_model_contents(m); return NULL; }
1084
923
 
1085
- if (!embd_found || m->dim == 0) {
1086
- free_model_contents(m);
1087
- return NULL;
1088
- }
1089
-
1090
- /* Dequantize */
1091
924
  if (tensor_type == GGML_TYPE_F32 && !need_transpose) {
1092
925
  m->float_data = NULL;
1093
926
  m->tensor_data = raw_tensor_data;
@@ -1095,10 +928,7 @@ static EmbedModel *embed_load_gguf(const char *path) {
1095
928
  int n_rows = need_transpose ? (int)dim1 : (int)dim0;
1096
929
  int n_cols = need_transpose ? (int)dim0 : (int)dim1;
1097
930
  m->float_data = dequantize_tensor(raw_tensor_data, tensor_type, n_rows, n_cols);
1098
- if (!m->float_data) {
1099
- free_model_contents(m);
1100
- return NULL;
1101
- }
931
+ if (!m->float_data) { free_model_contents(m); return NULL; }
1102
932
  m->tensor_data = m->float_data;
1103
933
  }
1104
934
  m->tensor_type = tensor_type;
@@ -1109,79 +939,84 @@ static EmbedModel *embed_load_gguf(const char *path) {
1109
939
  /* ------------------------------------------------------------------------- */
1110
940
  static void embed_text(EmbedModel *m, const char *txt, float *out) {
1111
941
  memset(out, 0, sizeof(float) * m->dim);
1112
-
1113
- // Pre-tokenize using regex
1114
942
  int num_words = 0;
1115
943
  char **words = unicode_regex_split(txt, m->pre_patterns, m->num_pre_patterns, &num_words);
1116
-
1117
944
  if (!words || num_words == 0) {
1118
- // Fallback to space splitting if regex fails
945
+ // Fallback to simple space split
1119
946
  char *copy = strdup(txt);
1120
- if (!copy) return;
1121
-
1122
- char *tok = strtok(copy, " \t\n\r");
1123
- int used = 0;
1124
- const float *embd_matrix = m->tensor_data;
1125
-
1126
- while (tok) {
1127
- int id = hget(m, tok);
1128
- if (id >= 0 && id < m->vocab_size) {
1129
- const float *vec = embd_matrix + id * m->dim;
1130
- for (int i = 0; i < m->dim; i++) out[i] += vec[i];
1131
- used++;
947
+ if (copy) {
948
+ char *tok = strtok(copy, " \t\n\r");
949
+ int used = 0;
950
+ const float *embd = (float*)m->tensor_data;
951
+ while (tok) {
952
+ int id = hget(m, tok);
953
+ if (id >= 0 && id < m->vocab_size) {
954
+ const float *vec = embd + id * m->dim;
955
+ for (int i = 0; i < m->dim; i++) out[i] += vec[i];
956
+ used++;
957
+ }
958
+ tok = strtok(NULL, " \t\n\r");
1132
959
  }
1133
- tok = strtok(NULL, " \t\n\r");
960
+ if (used) { float inv = 1.0f / used; for (int i = 0; i < m->dim; i++) out[i] *= inv; }
961
+ free(copy);
1134
962
  }
1135
-
1136
- if (used > 0) {
1137
- float inv = 1.0f / used;
1138
- for (int i = 0; i < m->dim; i++) out[i] *= inv;
1139
- }
1140
- free(copy);
963
+ if (words) free(words);
1141
964
  return;
1142
965
  }
1143
-
1144
- // Tokenize each word using BPE
1145
- int *token_ids = malloc(m->vocab_size * sizeof(int)); // Max possible tokens
1146
- int num_tokens = 0;
1147
- const float *embd_matrix = m->tensor_data;
966
+
967
+ int *token_ids = malloc(m->vocab_size * sizeof(int));
1148
968
  int used = 0;
1149
-
969
+ const float *embd = (float*)m->tensor_data;
1150
970
  for (int i = 0; i < num_words; i++) {
1151
- num_tokens = 0;
1152
- bpe_tokenize_word(&m->merges, words[i], text_to_id, m, token_ids, &num_tokens);
1153
-
1154
- for (int j = 0; j < num_tokens; j++) {
1155
- int id = token_ids[j];
1156
- if (id >= 0 && id < m->vocab_size) {
1157
- const float *vec = embd_matrix + id * m->dim;
1158
- for (int k = 0; k < m->dim; k++) out[k] += vec[k];
1159
- used++;
1160
- } else if (m->unknown_token_id != -1 && m->unknown_token_id < m->vocab_size) {
1161
- // Use unknown token as fallback
1162
- const float *vec = embd_matrix + m->unknown_token_id * m->dim;
1163
- for (int k = 0; k < m->dim; k++) out[k] += vec[k];
1164
- used++;
971
+ char *word = words[i];
972
+ int id = hget(m, word);
973
+ if (id == -1 && m->space_marker[0]) {
974
+ char *with_marker = malloc(strlen(m->space_marker) + strlen(word) + 1);
975
+ strcpy(with_marker, m->space_marker);
976
+ strcat(with_marker, word);
977
+ id = hget(m, with_marker);
978
+ free(with_marker);
979
+ }
980
+ if (id != -1) {
981
+ const float *vec = embd + id * m->dim;
982
+ for (int j = 0; j < m->dim; j++) out[j] += vec[j];
983
+ used++;
984
+ } else {
985
+ int num_tokens = 0;
986
+ bpe_tokenize_word(&m->merges, word, text_to_id, m, token_ids, &num_tokens);
987
+ for (int k = 0; k < num_tokens; k++) {
988
+ int tid = token_ids[k];
989
+ if (tid >= 0 && tid < m->vocab_size) {
990
+ const float *vec = embd + tid * m->dim;
991
+ for (int j = 0; j < m->dim; j++) out[j] += vec[j];
992
+ used++;
993
+ } else if (m->unknown_token_id != -1 && m->unknown_token_id < m->vocab_size) {
994
+ const float *vec = embd + m->unknown_token_id * m->dim;
995
+ for (int j = 0; j < m->dim; j++) out[j] += vec[j];
996
+ used++;
997
+ }
1165
998
  }
1166
999
  }
1167
-
1168
- free(words[i]);
1000
+ free(word);
1169
1001
  }
1170
1002
  free(words);
1171
1003
  free(token_ids);
1172
-
1173
1004
  if (used > 0) {
1174
1005
  float inv = 1.0f / used;
1175
1006
  for (int i = 0; i < m->dim; i++) out[i] *= inv;
1176
1007
  }
1008
+ for (int i = 0; i < m->dim; i++) {
1009
+ if (isnan(out[i]) || isinf(out[i])) {
1010
+ out[i] = 0.0f;
1011
+ }
1012
+ }
1177
1013
  }
1178
1014
 
1179
1015
  /* ------------------------------------------------------------------------- */
1016
+ // Ruby bindings
1180
1017
  static void rb_embedder_free(void *p) {
1181
1018
  ruby_embedder *e = p;
1182
- if (!e) return;
1183
- if (e->model) free_model_contents(e->model);
1184
- free(e);
1019
+ if (e) { if (e->model) free_model_contents(e->model); free(e); }
1185
1020
  }
1186
1021
 
1187
1022
  static size_t rb_embedder_memsize(const void *p) {
@@ -1202,22 +1037,18 @@ static VALUE rb_embedder_alloc(VALUE klass) {
1202
1037
  static VALUE rb_embedder_initialize(VALUE self, VALUE opts) {
1203
1038
  ruby_embedder *e;
1204
1039
  TypedData_Get_Struct(self, ruby_embedder, &ruby_embedder_type, e);
1205
-
1206
1040
  VALUE path = rb_hash_aref(opts, ID2SYM(rb_intern("model")));
1207
1041
  const char *cpath = StringValueCStr(path);
1208
1042
  e->model = embed_load_gguf(cpath);
1209
- if (!e->model)
1210
- rb_raise(rb_eRuntimeError, "failed to load GGUF model");
1043
+ if (!e->model) rb_raise(rb_eRuntimeError, "failed to load GGUF model");
1211
1044
  return self;
1212
1045
  }
1213
1046
 
1214
1047
  static VALUE rb_embed(VALUE self, VALUE opts) {
1215
1048
  ruby_embedder *e;
1216
1049
  TypedData_Get_Struct(self, ruby_embedder, &ruby_embedder_type, e);
1217
-
1218
1050
  VALUE text = rb_hash_aref(opts, ID2SYM(rb_intern("text")));
1219
1051
  const char *ctext = StringValueCStr(text);
1220
-
1221
1052
  VALUE out = rb_str_new(NULL, e->model->dim * sizeof(float));
1222
1053
  embed_text(e->model, ctext, (float*)RSTRING_PTR(out));
1223
1054
  return out;
@@ -1227,5 +1058,5 @@ void Init_mini_embed(void) {
1227
1058
  VALUE c = rb_define_class("MiniEmbed", rb_cObject);
1228
1059
  rb_define_alloc_func(c, rb_embedder_alloc);
1229
1060
  rb_define_method(c, "initialize", rb_embedder_initialize, 1);
1230
- rb_define_method(c, "embeddings", rb_embed, 1);
1061
+ rb_define_method(c, "embed", rb_embed, 1);
1231
1062
  }
data/lib/mini_embed.rb CHANGED
@@ -1,3 +1,17 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  require 'mini_embed/mini_embed'
4
+
5
class MiniEmbed
  # Computes the embedding of +text+ by calling the C-extension method
  # +embed+ and returns it in the requested representation.
  #
  # @param text [String] text to extract embeddings from (required)
  # @param type [Symbol] :vector (default) for an array of floats, or
  #   :binary for the raw string returned by +embed+
  # @return [Array<Float>, String] the binary string unpacked with 'e*'
  #   (little-endian single-precision floats) for :vector, or the binary
  #   string itself for :binary
  # @raise [ArgumentError] if +type+ is neither :binary nor :vector
  def embeddings(text:, type: :vector)
    # Validate first so an unsupported type does not cost a full embedding run.
    unless %i[binary vector].include?(type)
      raise ArgumentError, "Unsupported data type: #{type}"
    end

    binary_data = embed(text: text) # call original C method
    type == :binary ? binary_data : binary_data.unpack('e*')
  end
end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: mini_embed
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.1
4
+ version: 0.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Makapoxa