red-candle 1.0.0.pre.7 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. checksums.yaml +4 -4
  2. data/Gemfile +1 -10
  3. data/README.md +322 -4
  4. data/ext/candle/src/lib.rs +6 -3
  5. data/ext/candle/src/llm/gemma.rs +5 -0
  6. data/ext/candle/src/llm/llama.rs +5 -0
  7. data/ext/candle/src/llm/mistral.rs +5 -0
  8. data/ext/candle/src/llm/mod.rs +1 -89
  9. data/ext/candle/src/llm/quantized_gguf.rs +5 -0
  10. data/ext/candle/src/ner.rs +423 -0
  11. data/ext/candle/src/reranker.rs +24 -21
  12. data/ext/candle/src/ruby/device.rs +6 -6
  13. data/ext/candle/src/ruby/dtype.rs +4 -4
  14. data/ext/candle/src/ruby/embedding_model.rs +36 -33
  15. data/ext/candle/src/ruby/llm.rs +31 -13
  16. data/ext/candle/src/ruby/mod.rs +1 -2
  17. data/ext/candle/src/ruby/tensor.rs +66 -66
  18. data/ext/candle/src/ruby/tokenizer.rs +269 -0
  19. data/ext/candle/src/ruby/utils.rs +6 -24
  20. data/ext/candle/src/tokenizer/loader.rs +108 -0
  21. data/ext/candle/src/tokenizer/mod.rs +103 -0
  22. data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
  23. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
  24. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
  25. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
  26. data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
  27. data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
  28. data/lib/candle/build_info.rb +2 -0
  29. data/lib/candle/device_utils.rb +2 -0
  30. data/lib/candle/ner.rb +345 -0
  31. data/lib/candle/reranker.rb +1 -1
  32. data/lib/candle/tensor.rb +2 -0
  33. data/lib/candle/tokenizer.rb +139 -0
  34. data/lib/candle/version.rb +4 -2
  35. data/lib/candle.rb +2 -0
  36. metadata +126 -3
  37. data/ext/candle/src/ruby/qtensor.rs +0 -69
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 6bea0c9f27d5bbbc43e0c2e8fa5456b2b79511f164f7083512025538ca172077
- data.tar.gz: b591f559c63c1bdb9fa96beaf1079186284b219e2ceba50d8e2891fa4f3096bd
+ metadata.gz: 829a937851c782dfd58b8fb724dc7b08d524d26400047e9f5fc7a5bd0de9cb4b
+ data.tar.gz: e8a9420fc310e977968aa396a47e5e5269d107eb8cb7246ca9e2f980a0a28f4d
  SHA512:
- metadata.gz: 380cde32abf995f9766056638d4d0612acef746a84b31451c0aa2fb8cd205fc76b76b8909729a9af267ee6c549f9149253b2bfb70717385b5afe2cd2ad0b3152
- data.tar.gz: 7acdfc4de263501be51e3f74933c528b9d723b4da93e78e3abae2454dc2746fb7696c09722fe6666084a48dff098501b2d1af503963b5d739e879bd0fcb96e43
+ metadata.gz: 020e23df61d5679612a7892bdbfc7dbcf2d28055df9fc6b9199a8e09933e74d03a0a0fb97d6ddc998f8f8e70e856d63d7eb758638e072cd492734b21681267b7
+ data.tar.gz: 5e10f888c2bd74dfdf01c97ca18f6666440edb000d524784ad0ac0676208d9485b8d708fa72d8fe73672f018c038b3526109153540ce5bbc53a81e4e30deccd1
data/Gemfile CHANGED
@@ -1,12 +1,3 @@
  source "https://rubygems.org"

- gemspec
-
- gem "minitest"
- gem "rake"
- gem "rake-compiler"
-
- gem "yard", require: false
- gem "yard-rustdoc", require: false
-
- gem "redcarpet", "~> 3.6"
+ gemspec
data/README.md CHANGED
@@ -127,6 +127,8 @@ response = llm.chat(messages)

  ### GPU Acceleration

+ We see an 18x speedup running LLMs under CUDA vs. CPU, and a >3x speedup under Metal vs. CPU. Details [here](DEVICE_SUPPORT.md#performance-considerations).
+
  ```ruby
  # CPU works for all models
  device = Candle::Device.cpu
@@ -166,9 +168,11 @@ This is particularly useful for:
  - Troubleshooting generation problems
  - Analyzing model behavior

- ## ⚠️ Model Format Requirement: Safetensors Only
+ ## ⚠️ Model Format Requirements
+
+ ### EmbeddingModels and Rerankers: Safetensors Only

- Red-Candle **only supports embedding models that provide their weights in the [safetensors](https://github.com/huggingface/safetensors) format** (i.e., the model repo must contain a `model.safetensors` file). If the model repo does not provide the required file, loading will fail with a clear error. Most official BERT and DistilBERT models do **not** provide safetensors; many Sentence Transformers and JinaBERT models do.
+ Red-Candle **only supports embedding models and rerankers that provide their weights in the [safetensors](https://github.com/huggingface/safetensors) format** (i.e., the model repo must contain a `model.safetensors` file). If the model repo does not provide the required file, loading will fail with a clear error. Most official BERT and DistilBERT models do **not** provide safetensors; many Sentence Transformers and JinaBERT models do.

  **If you encounter an error like:**

@@ -178,13 +182,22 @@ RuntimeError: model.safetensors not found after download. Only safetensors model

  this means the selected model is not compatible. Please choose a model repo that provides the required file.
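
  One way to guard against this at load time (a sketch only; assumes the loader raises `RuntimeError` with the message shown above):

  ```ruby
  begin
    model = Candle::EmbeddingModel.new
  rescue RuntimeError => e
    # Handle only the missing-safetensors case; re-raise anything else
    raise unless e.message.include?("model.safetensors not found")
    warn "Chosen repo does not provide safetensors weights; pick a compatible model"
  end
  ```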

+ ### LLMs: Safetensors and GGUF Support
+
+ LLM models support two formats:
+ 1. **Safetensors format** - Standard HuggingFace models (e.g., `TinyLlama/TinyLlama-1.1B-Chat-v1.0`)
+ 2. **GGUF quantized format** - Memory-efficient quantized models (e.g., `TheBloke/Llama-2-7B-Chat-GGUF`)
+
+ See the [Quantized Model Support](#quantized-model-support-gguf) section for details on using GGUF models, and the sketch below for loading each format.
+
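+ A minimal loading sketch (model ids from the list above; options for selecting a specific file within a GGUF repo are covered in that section):
+
+ ```ruby
+ # Safetensors: a standard HuggingFace repo
+ llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+
+ # GGUF: a quantized repo (see Quantized Model Support for file selection)
+ quantized_llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF")
+ ```
+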
  ## Supported Embedding Models

  Red-Candle supports the following embedding model types from Hugging Face:

  1. `Candle::EmbeddingModelType::JINA_BERT` - Jina BERT models (e.g., `jinaai/jina-embeddings-v2-base-en`) (**safetensors required**)
- 2. `Candle::EmbeddingModelType::STANDARD_BERT` - Standard BERT models (e.g., `sentence-transformers/all-MiniLM-L6-v2`) (**safetensors required**)
+ 2. `Candle::EmbeddingModelType::MINILM` - MiniLM models (e.g., `sentence-transformers/all-MiniLM-L6-v2`) (**safetensors required**)
  3. `Candle::EmbeddingModelType::DISTILBERT` - DistilBERT models (e.g., `distilbert-base-uncased-finetuned-sst-2-english`) (**safetensors required**)
+ 4. `Candle::EmbeddingModelType::STANDARD_BERT` - Standard BERT models (e.g., `scientistcom/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext`) (**safetensors required**)

  > **Note:** Most official BERT and DistilBERT models do _not_ provide safetensors. Please check the model repo before use.

@@ -260,7 +273,7 @@ ranked_results = reranker.rerank(query, documents, pooling_method: "pooler", app
  # Or apply sigmoid activation to get scores between 0 and 1
  sigmoid_results = reranker.rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true)

- # The pooler method is the default and is recommended for cross-encoders, as is apply_sigmod, so the above is the same as:
+ # The pooler method is the default and is recommended for cross-encoders, as is apply_sigmoid, so the above is the same as:
  ranked_results = reranker.rerank(query, documents)

  # Results are returned as an array of hashes, sorted by relevance
@@ -351,6 +364,311 @@ The reranker uses a BERT-based architecture that:

  This joint processing allows cross-encoders to capture subtle semantic relationships between queries and documents, making them more accurate for reranking tasks, though at the cost of higher computational requirements.

+ ## Tokenizer
+
+ Red-Candle provides direct access to tokenizers for text preprocessing and analysis. This is useful for understanding how models process text, debugging issues, and building custom NLP pipelines.
+
+ ### Basic Usage
+
+ ```ruby
+ require 'candle'
+
+ # Load a tokenizer from HuggingFace
+ tokenizer = Candle::Tokenizer.from_pretrained("bert-base-uncased")
+
+ # Encode text to token IDs
+ token_ids = tokenizer.encode("Hello, world!")
+ # => [101, 7592, 1010, 2088, 999, 102]
+
+ # Decode token IDs back to text
+ text = tokenizer.decode(token_ids)
+ # => "hello, world!"
+
+ # Get token strings (subwords) - useful for visualization
+ tokens = tokenizer.encode_to_tokens("Hello, world!")
+ # => ["[CLS]", "hello", ",", "world", "!", "[SEP]"]
+
+ # Get both IDs and tokens together
+ result = tokenizer.encode_with_tokens("preprocessing")
+ # => {"ids" => [101, 3653, 22618, 2527, 102],
+ #     "tokens" => ["[CLS]", "prep", "##ro", "##ces", "##sing", "[SEP]"]}
+ ```
+
+ ### Batch Processing
+
+ ```ruby
+ # Encode multiple texts at once
+ texts = ["Hello world", "How are you?", "Tokenizers are cool"]
+ batch_ids = tokenizer.encode_batch(texts)
+
+ # Get token strings for multiple texts
+ batch_tokens = tokenizer.encode_batch_to_tokens(texts)
+ ```
+
+ ### Vocabulary Access
+
+ ```ruby
+ # Get vocabulary size
+ vocab_size = tokenizer.vocab_size
+ # => 30522
+
+ # Get full vocabulary as a hash
+ vocab = tokenizer.get_vocab
+ # vocab["hello"] => 7592
+
+ # Convert a specific token ID to its string
+ token_str = tokenizer.id_to_token(7592)
+ # => "hello"
+
+ # Get special tokens
+ special = tokenizer.get_special_tokens
+ # => {"cls_token" => 101, "sep_token" => 102, "pad_token" => 0, ...}
+ ```
+
+ ### Configuration
+
+ ```ruby
+ # Create a tokenizer with padding enabled
+ padded_tokenizer = tokenizer.with_padding(length: 128)
+
+ # Create a tokenizer with truncation
+ truncated_tokenizer = tokenizer.with_truncation(512)
+
+ # Configure padding with more options
+ padded_tokenizer = tokenizer.with_padding(
+   length: 128,        # Fixed length padding
+   direction: "right", # Pad on the right (default)
+   pad_token: "[PAD]"  # Padding token
+ )
+ ```
+
+ ### Model Integration
+
+ All models expose their tokenizers:
+
+ ```ruby
+ # From LLM
+ llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+ llm_tokenizer = llm.tokenizer
+
+ # From EmbeddingModel
+ embedding_model = Candle::EmbeddingModel.new
+ emb_tokenizer = embedding_model.tokenizer
+
+ # From Reranker
+ reranker = Candle::Reranker.new(model_path: "cross-encoder/ms-marco-MiniLM-L-12-v2")
+ rank_tokenizer = reranker.tokenizer
+ ```
+
+ ### Understanding Subword Tokenization
+
+ Modern tokenizers split unknown or rare words into subword pieces:
+
+ ```ruby
+ # See how words are split into subwords
+ result = tokenizer.encode_with_tokens("unbelievable")
+ # => {"ids" => [101, 4895, 6499, 102],
+ #     "tokens" => ["[CLS]", "un", "##believable", "[SEP]"]}
+
+ # The ## prefix indicates a continuation of the previous token
+ complex = tokenizer.encode_to_tokens("preprocessing tokenization")
+ # => ["[CLS]", "prep", "##ro", "##ces", "##sing", "token", "##ization", "[SEP]"]
+ ```
+
+ ### Use Cases
+
+ - **Token Analysis**: Understand how your text is being processed by models
+ - **Debugging**: See why certain inputs might cause unexpected model behavior
+ - **Custom Preprocessing**: Build your own text processing pipelines
+ - **Educational**: Teach how modern NLP models handle text
+ - **NER Preparation**: Get aligned tokens for named entity recognition tasks
+
+ ## Named Entity Recognition (NER)
+
+ Red-Candle includes comprehensive Named Entity Recognition capabilities for extracting entities like people, organizations, locations, and custom entity types from text.
+
+ ### Model-based NER
+
+ Load pre-trained NER models from HuggingFace:
+
+ ```ruby
+ require 'candle'
+
+ # Load a pre-trained NER model
+ ner = Candle::NER.from_pretrained("Babelscape/wikineural-multilingual-ner")
+
+ # Or load a model with a specific tokenizer (for models without tokenizer.json)
+ ner = Candle::NER.from_pretrained("dslim/bert-base-NER", tokenizer: "bert-base-cased")
+
+ # Extract entities from text
+ text = "Apple Inc. was founded by Steve Jobs and Steve Wozniak in Cupertino, California."
+ entities = ner.extract_entities(text)
+
+ entities.each do |entity|
+   puts "#{entity['text']} (#{entity['label']}) - confidence: #{entity['confidence'].round(2)}"
+ end
+ # Output:
+ # Apple Inc. (ORG) - confidence: 0.99
+ # Steve Jobs (PER) - confidence: 0.99
+ # Steve Wozniak (PER) - confidence: 0.98
+ # Cupertino (LOC) - confidence: 0.97
+ # California (LOC) - confidence: 0.98
+
+ # Adjust confidence threshold (default: 0.9)
+ entities = ner.extract_entities(text, confidence_threshold: 0.95)
+
+ # Get token-level predictions for detailed analysis
+ tokens = ner.predict_tokens(text)
+ ```
+
+ ### Pattern-based Recognition
+
+ For domain-specific entities, use regex patterns:
+
+ ```ruby
+ # Create pattern-based recognizers
+ email_recognizer = Candle::PatternEntityRecognizer.new("EMAIL", [
+   /\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b/
+ ])
+
+ phone_recognizer = Candle::PatternEntityRecognizer.new("PHONE", [
+   /\b\d{3}[-.]?\d{3}[-.]?\d{4}\b/,      # 555-123-4567
+   /\b\(\d{3}\)\s*\d{3}[-.]?\d{4}\b/,    # (555) 123-4567
+   /\b\+1\s*\d{3}[-.]?\d{3}[-.]?\d{4}\b/ # +1 555-123-4567
+ ])
+
+ # Extract entities
+ text = "Contact us at info@example.com or call 555-123-4567"
+ email_entities = email_recognizer.recognize(text)
+ phone_entities = phone_recognizer.recognize(text)
+ ```
+
+ ### Gazetteer-based Recognition
+
+ Use dictionaries for known entities:
+
+ ```ruby
+ # Create gazetteer recognizers
+ companies = ["Apple", "Google", "Microsoft", "Amazon", "Tesla"]
+ company_recognizer = Candle::GazetteerEntityRecognizer.new("COMPANY", companies)
+
+ # Load from file
+ drug_recognizer = Candle::GazetteerEntityRecognizer.new("DRUG")
+ drug_recognizer.load_from_file("drug_names.txt")
+
+ # Case-sensitive matching
+ product_recognizer = Candle::GazetteerEntityRecognizer.new("PRODUCT",
+   ["iPhone", "iPad", "MacBook"],
+   case_sensitive: true
+ )
+ ```
+
+ ### Hybrid NER
+
+ Combine ML models with rule-based approaches for best results:
+
+ ```ruby
+ # Create hybrid NER system
+ hybrid = Candle::HybridNER.new("Babelscape/wikineural-multilingual-ner")
+
+ # Add pattern recognizers
+ hybrid.add_pattern_recognizer("EMAIL", [/\b[\w._%+-]+@[\w.-]+\.[A-Za-z]{2,}\b/])
+ hybrid.add_pattern_recognizer("PHONE", [/\b\d{3}[-.]?\d{3}[-.]?\d{4}\b/])
+
+ # Add gazetteer recognizers
+ hybrid.add_gazetteer_recognizer("COMPANY", ["Apple", "Google", "Microsoft"])
+ hybrid.add_gazetteer_recognizer("PRODUCT", ["iPhone", "Android", "Windows"])
+
+ # Extract all entities
+ text = "John Smith (john@apple.com) from Apple called about the new iPhone. Reach him at 555-0123."
+ entities = hybrid.extract_entities(text)
+
+ # Results include entities from all recognizers
+ # Overlapping entities are automatically resolved (highest confidence wins)
+ ```
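+
+ For intuition, a minimal sketch of the "highest confidence wins" rule, using the entity hashes shown under Output Format (illustrative only; the gem performs this resolution internally):
+
+ ```ruby
+ # Keep the highest-confidence entity whenever two spans overlap
+ def resolve_overlaps(entities)
+   kept = []
+   entities.sort_by { |e| -e["confidence"] }.each do |entity|
+     overlap = kept.any? do |k|
+       entity["start"] < k["end"] && k["start"] < entity["end"]
+     end
+     kept << entity unless overlap
+   end
+   kept.sort_by { |e| e["start"] }
+ end
+ ```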
+
+ ### Custom Entity Types
+
+ Perfect for specialized domains:
+
+ ```ruby
+ # Biomedical entities
+ gene_patterns = [
+   /\b[A-Z][A-Z0-9]{2,}\b/, # TP53, BRCA1, EGFR
+   /\bCD\d+\b/,             # CD4, CD8, CD34
+   /\b[A-Z]+\d[A-Z]\d*\b/   # RAD51C, PALB2
+ ]
+ gene_recognizer = Candle::PatternEntityRecognizer.new("GENE", gene_patterns)
+
+ # Financial entities
+ ticker_patterns = [
+   /\$[A-Z]{1,5}\b/,        # $AAPL, $GOOGL
+   /\b[A-Z]{1,5}\.NYSE\b/,  # AAPL.NYSE
+   /\b[A-Z]{1,5}\.NASDAQ\b/ # GOOGL.NASDAQ
+ ]
+ ticker_recognizer = Candle::PatternEntityRecognizer.new("TICKER", ticker_patterns)
+
+ # Legal entities
+ case_patterns = [
+   /\b\d+\s+F\.\d+\s+\d+\b/, # 123 F.3d 456
+   /\b\d+\s+U\.S\.\s+\d+\b/, # 123 U.S. 456
+   /\bNo\.\s+\d+-\d+\b/      # No. 20-1234
+ ]
+ case_recognizer = Candle::PatternEntityRecognizer.new("CASE", case_patterns)
+ ```
+
+ ### Available Pre-trained Models
+
+ Popular NER models on HuggingFace:
+
+ ```ruby
+ # General multilingual NER (4 entity types: PER, ORG, LOC, MISC)
+ ner = Candle::NER.from_pretrained("Babelscape/wikineural-multilingual-ner")
+
+ # English NER (requires separate tokenizer)
+ ner = Candle::NER.from_pretrained("dslim/bert-base-NER", tokenizer: "bert-base-cased")
+
+ # Multilingual NER
+ ner = Candle::NER.from_pretrained("Davlan/bert-base-multilingual-cased-ner-hrl")
+
+ # OntoNotes 5 (18 entity types including DATE, TIME, MONEY, etc.)
+ ner = Candle::NER.from_pretrained("flair/ner-english-ontonotes-large")
+
+ # Biomedical NER
+ ner = Candle::NER.from_pretrained("dmis-lab/biobert-base-cased-v1.2")
+ ner = Candle::NER.from_pretrained("allenai/scibert_scivocab_uncased")
+ ```
+
+ ### Performance Tips
+
+ 1. **Device Selection**: Use GPU for faster inference
+ ```ruby
+ ner = Candle::NER.from_pretrained("Babelscape/wikineural-multilingual-ner", device: Candle::Device.metal)
+ ```
+
+ 2. **Batch Processing**: Process multiple texts together when possible (see the sketch after this list)
+
+ 3. **Confidence Threshold**: Balance precision/recall with appropriate thresholds
+
+ 4. **Entity Resolution**: The hybrid NER automatically handles overlapping entities
+
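+ A minimal sketch for tip 2 (plain Ruby iteration; `ner` is the model loaded once above and reused across texts):
+
+ ```ruby
+ texts = ["Apple hired John Smith.", "Google opened an office in Paris."]
+ all_entities = texts.map { |text| ner.extract_entities(text) }
+ ```
+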
+ ### Output Format
+
+ All NER methods return entities in a consistent format:
+
+ ```ruby
+ {
+   "text" => "Apple Inc.",     # The entity text
+   "label" => "ORG",           # Entity type
+   "start" => 0,               # Character start position
+   "end" => 10,                # Character end position
+   "confidence" => 0.99,       # Confidence score (0-1)
+   "token_start" => 0,         # Token start index (model-based only)
+   "token_end" => 2,           # Token end index (model-based only)
+   "source" => "model"         # Source: "model", "pattern", or "gazetteer"
+ }
+ ```
+

  ## Common Runtime Errors

  ### 1. Weight is negative, too large or not a valid number
data/ext/candle/src/lib.rs CHANGED
@@ -1,11 +1,13 @@
  use magnus::{function, prelude::*, Ruby};

  use crate::ruby::candle_utils;
- use crate::ruby::Result as RbResult;
+ use crate::ruby::Result;

  pub mod llm;
+ pub mod ner;
  pub mod reranker;
  pub mod ruby;
+ pub mod tokenizer;

  // Configuration detection from build.rs
  #[cfg(all(has_metal, not(force_cpu)))]
@@ -33,7 +35,7 @@ pub fn get_build_info() -> magnus::RHash {
  }

  #[magnus::init]
- fn init(ruby: &Ruby) -> RbResult<()> {
+ fn init(ruby: &Ruby) -> Result<()> {
      let rb_candle = ruby.define_module("Candle")?;

      // Export build info
@@ -41,11 +43,12 @@ fn init(ruby: &Ruby) -> RbResult<()> {

      ruby::init_embedding_model(rb_candle)?;
      ruby::init_llm(rb_candle)?;
+     ner::init(rb_candle)?;
      reranker::init(rb_candle)?;
      ruby::dtype::init(rb_candle)?;
-     ruby::qtensor::init(rb_candle)?;
      ruby::device::init(rb_candle)?;
      ruby::tensor::init(rb_candle)?;
+     ruby::tokenizer::init(rb_candle)?;
      candle_utils(rb_candle)?;

      Ok(())
data/ext/candle/src/llm/gemma.rs CHANGED
@@ -21,6 +21,11 @@ impl Gemma {
          self.model.clear_kv_cache();
      }

+     /// Get the tokenizer
+     pub fn tokenizer(&self) -> &TokenizerWrapper {
+         &self.tokenizer
+     }
+
      /// Load a Gemma model from HuggingFace Hub
      pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
          let api = Api::new()
data/ext/candle/src/llm/llama.rs CHANGED
@@ -28,6 +28,11 @@ impl Llama {
          }
      }

+     /// Get the tokenizer
+     pub fn tokenizer(&self) -> &TokenizerWrapper {
+         &self.tokenizer
+     }
+
      /// Load a Llama model from HuggingFace Hub
      pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
          let api = Api::new()
data/ext/candle/src/llm/mistral.rs CHANGED
@@ -21,6 +21,11 @@ impl Mistral {
          self.model.clear_kv_cache();
      }

+     /// Get the tokenizer
+     pub fn tokenizer(&self) -> &TokenizerWrapper {
+         &self.tokenizer
+     }
+
      /// Load a Mistral model from HuggingFace Hub
      pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
          let api = Api::new()
data/ext/candle/src/llm/mod.rs CHANGED
@@ -1,5 +1,4 @@
  use candle_core::{Device, Result as CandleResult};
- use tokenizers::Tokenizer;

  pub mod mistral;
  pub mod llama;
@@ -11,6 +10,7 @@ pub mod quantized_gguf;
  pub use generation_config::GenerationConfig;
  pub use text_generation::TextGeneration;
  pub use quantized_gguf::QuantizedGGUF;
+ pub use crate::tokenizer::TokenizerWrapper;

  /// Trait for text generation models
  pub trait TextGenerator: Send + Sync {
@@ -37,92 +37,4 @@ pub trait TextGenerator: Send + Sync {

      /// Clear any cached state (like KV cache)
      fn clear_cache(&mut self);
- }
-
- /// Common structure for managing tokenizer
- #[derive(Debug)]
- pub struct TokenizerWrapper {
-     tokenizer: Tokenizer,
- }
-
- impl TokenizerWrapper {
-     pub fn new(tokenizer: Tokenizer) -> Self {
-         Self { tokenizer }
-     }
-
-     pub fn encode(&self, text: &str, add_special_tokens: bool) -> CandleResult<Vec<u32>> {
-         let encoding = self.tokenizer
-             .encode(text, add_special_tokens)
-             .map_err(|e| candle_core::Error::Msg(format!("Tokenizer error: {}", e)))?;
-         Ok(encoding.get_ids().to_vec())
-     }
-
-     pub fn decode(&self, tokens: &[u32], skip_special_tokens: bool) -> CandleResult<String> {
-         self.tokenizer
-             .decode(tokens, skip_special_tokens)
-             .map_err(|e| candle_core::Error::Msg(format!("Tokenizer decode error: {}", e)))
-     }
-
-     pub fn token_to_piece(&self, token: u32) -> CandleResult<String> {
-         self.tokenizer
-             .id_to_token(token)
-             .map(|s| s.to_string())
-             .ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
-     }
-
-     /// Decode a single token for streaming output
-     pub fn decode_token(&self, token: u32) -> CandleResult<String> {
-         // Decode the single token properly
-         self.decode(&[token], true)
-     }
-
-     /// Decode tokens incrementally for streaming
-     /// This is more efficient than decoding single tokens
-     pub fn decode_incremental(&self, all_tokens: &[u32], new_tokens_start: usize) -> CandleResult<String> {
-         if new_tokens_start >= all_tokens.len() {
-             return Ok(String::new());
-         }
-
-         // Decode all tokens up to this point
-         let full_text = self.decode(all_tokens, true)?;
-
-         // If we're at the start, return everything
-         if new_tokens_start == 0 {
-             return Ok(full_text);
-         }
-
-         // Otherwise, decode up to the previous token and return the difference
-         let previous_text = self.decode(&all_tokens[..new_tokens_start], true)?;
-
-         // Find the common prefix between the two strings to handle cases where
-         // the tokenizer might produce slightly different text when decoding
-         // different token sequences
-         let common_prefix_len = full_text
-             .char_indices()
-             .zip(previous_text.chars())
-             .take_while(|((_, c1), c2)| c1 == c2)
-             .count();
-
-         // Find the byte position of the character boundary
-         let byte_pos = full_text
-             .char_indices()
-             .nth(common_prefix_len)
-             .map(|(pos, _)| pos)
-             .unwrap_or(full_text.len());
-
-         // Return only the new portion
-         Ok(full_text[byte_pos..].to_string())
-     }
-
-     /// Format tokens with debug information
-     pub fn format_tokens_with_debug(&self, tokens: &[u32]) -> CandleResult<String> {
-         let mut result = String::new();
-
-         for &token in tokens {
-             let token_piece = self.token_to_piece(token)?;
-             result.push_str(&format!("[{}:{}]", token, token_piece));
-         }
-
-         Ok(result)
-     }
  }
data/ext/candle/src/llm/quantized_gguf.rs CHANGED
@@ -28,6 +28,11 @@ enum ModelType {
  }

  impl QuantizedGGUF {
+     /// Get the tokenizer
+     pub fn tokenizer(&self) -> &TokenizerWrapper {
+         &self.tokenizer
+     }
+
      /// Load a quantized model from a GGUF file
      pub async fn from_pretrained(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
          // Check if user specified an exact GGUF filename