red-candle 1.0.0.pre.5 → 1.0.0.pre.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +164 -4
- data/Rakefile +1 -3
- data/ext/candle/src/llm/gemma.rs +277 -0
- data/ext/candle/src/llm/generation_config.rs +3 -0
- data/ext/candle/src/llm/llama.rs +16 -79
- data/ext/candle/src/llm/mistral.rs +16 -89
- data/ext/candle/src/llm/mod.rs +59 -0
- data/ext/candle/src/llm/quantized_gguf.rs +496 -0
- data/ext/candle/src/llm/text_generation.rs +0 -4
- data/ext/candle/src/ruby/embedding_model.rs +0 -1
- data/ext/candle/src/ruby/llm.rs +84 -29
- data/lib/candle/llm.rb +91 -2
- data/lib/candle/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 6bea0c9f27d5bbbc43e0c2e8fa5456b2b79511f164f7083512025538ca172077
|
4
|
+
data.tar.gz: b591f559c63c1bdb9fa96beaf1079186284b219e2ceba50d8e2891fa4f3096bd
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 380cde32abf995f9766056638d4d0612acef746a84b31451c0aa2fb8cd205fc76b76b8909729a9af267ee6c549f9149253b2bfb70717385b5afe2cd2ad0b3152
|
7
|
+
data.tar.gz: 7acdfc4de263501be51e3f74933c528b9d723b4da93e78e3abae2454dc2746fb7696c09722fe6666084a48dff098501b2d1af503963b5d739e879bd0fcb96e43
|
data/README.md
CHANGED
@@ -47,9 +47,46 @@ Red-Candle now supports Large Language Models (LLMs) with GPU acceleration!
|
|
47
47
|
|
48
48
|
### Supported Models
|
49
49
|
|
50
|
+
- **Gemma**: Google's Gemma models (e.g., `google/gemma-2b`, `google/gemma-7b`, `google/gemma-2b-it`)
|
50
51
|
- **Llama**: Llama 2 and Llama 3 models (e.g., `TinyLlama/TinyLlama-1.1B-Chat-v1.0`, `meta-llama/Llama-2-7b-hf`, `NousResearch/Llama-2-7b-hf`)
|
51
52
|
- **Mistral**: All Mistral models (e.g., `mistralai/Mistral-7B-Instruct-v0.1`)
|
52
53
|
|
54
|
+
### Quantized Model Support (GGUF)
|
55
|
+
|
56
|
+
Red-Candle supports quantized models in GGUF format, offering 4-8x memory reduction:
|
57
|
+
|
58
|
+
> **Note on GGUF Support**: Red-Candle now uses a unified GGUF loader that automatically detects the model architecture from the GGUF file. This means all GGUF models (including Mistral models from TheBloke) should now work correctly! The loader automatically selects the appropriate tokenizer based on the model type to ensure proper text generation.
|
59
|
+
|
60
|
+
```ruby
|
61
|
+
# Load quantized models - always specify the GGUF filename
|
62
|
+
llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
|
63
|
+
device: device,
|
64
|
+
gguf_file: "llama-2-7b-chat.Q4_K_M.gguf")
|
65
|
+
|
66
|
+
# Register custom tokenizer mappings for your models
|
67
|
+
Candle::LLM.register_tokenizer("my-org/my-model-GGUF", "my-org/my-tokenizer")
|
68
|
+
|
69
|
+
# Popular quantized model sources:
|
70
|
+
# - TheBloke: Extensive collection of GGUF models
|
71
|
+
# - Search HuggingFace for "GGUF" models
|
72
|
+
```
|
73
|
+
|
74
|
+
**Memory usage comparison (7B models):**
|
75
|
+
- Full precision: ~28 GB
|
76
|
+
- Q8_0 (8-bit): ~7 GB - Best quality, larger size
|
77
|
+
- Q5_K_M (5-bit): ~4.5 GB - Very good quality
|
78
|
+
- Q4_K_M (4-bit): ~4 GB - Recommended default, best balance
|
79
|
+
- Q3_K_M (3-bit): ~3 GB - Good for memory-constrained systems
|
80
|
+
|
81
|
+
**Quantization levels explained:**
|
82
|
+
- **Q8_0**: Almost identical to full model, use when quality is paramount
|
83
|
+
- **Q5_K_M**: Excellent quality with good compression
|
84
|
+
- **Q4_K_M**: Best balance of quality/size/speed (recommended default)
|
85
|
+
- **Q3_K_M**: Noticeable quality reduction but very compact
|
86
|
+
- **Q2_K**: ⚠️ **Not recommended** - Can cause inference errors due to extreme quantization
|
87
|
+
|
88
|
+
> **Warning**: Q2_K quantization can lead to "weight is negative, too large or not a valid number" errors during inference. Use Q3_K_M or higher for stable operation.
|
89
|
+
|
53
90
|
> ### ⚠️ Huggingface login warning
|
54
91
|
>
|
55
92
|
> Many models, including the one below, require you to agree to the terms. You'll need to:
|
@@ -67,10 +104,10 @@ device = Candle::Device.cpu # CPU (default)
|
|
67
104
|
device = Candle::Device.metal # Apple GPU (Metal)
|
68
105
|
device = Candle::Device.cuda # NVIDIA GPU (CUDA)
|
69
106
|
|
70
|
-
# Load a
|
71
|
-
llm = Candle::LLM.from_pretrained("
|
72
|
-
#
|
73
|
-
llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device)
|
107
|
+
# Load a model
|
108
|
+
llm = Candle::LLM.from_pretrained("google/gemma-2b-it", device: device) # Gemma
|
109
|
+
# llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device) # Llama
|
110
|
+
# llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device) # Mistral
|
74
111
|
|
75
112
|
# Generate text
|
76
113
|
response = llm.generate("What is Ruby?", config: Candle::GenerationConfig.balanced)
|
@@ -102,6 +139,33 @@ device = Candle::Device.metal
|
|
102
139
|
device = Candle::Device.cuda # Linux/Windows with NVIDIA GPU
|
103
140
|
```
|
104
141
|
|
142
|
+
### Debugging Token Generation
|
143
|
+
|
144
|
+
For debugging purposes, you can enable raw token output to see both token IDs and their raw representations:
|
145
|
+
|
146
|
+
```ruby
|
147
|
+
# Enable debug mode to see raw tokens during generation
|
148
|
+
config = Candle::GenerationConfig.balanced(debug_tokens: true)
|
149
|
+
|
150
|
+
# Non-streaming generation with debug tokens
|
151
|
+
result = llm.generate("Hello, world!", config: config)
|
152
|
+
puts result
|
153
|
+
# Output: [15043:Hello][11:,][1917:world][0:!]
|
154
|
+
|
155
|
+
# Streaming generation with debug tokens
|
156
|
+
llm.generate_stream("Hello, world!", config: config) do |text|
|
157
|
+
print text # Will show each token as it's generated: [15043:Hello][11:,][1917:world][0:!]
|
158
|
+
end
|
159
|
+
|
160
|
+
# Works with all models (Llama, Mistral, Gemma, and quantized GGUF models)
|
161
|
+
```
|
162
|
+
|
163
|
+
This is particularly useful for:
|
164
|
+
- Debugging tokenization issues
|
165
|
+
- Understanding how the model processes text
|
166
|
+
- Troubleshooting generation problems
|
167
|
+
- Analyzing model behavior
|
168
|
+
|
105
169
|
## ⚠️ Model Format Requirement: Safetensors Only
|
106
170
|
|
107
171
|
Red-Candle **only supports embedding models that provide their weights in the [safetensors](https://github.com/huggingface/safetensors) format** (i.e., the model repo must contain a `model.safetensors` file). If the model repo does not provide the required file, loading will fail with a clear error. Most official BERT and DistilBERT models do **not** provide safetensors; many Sentence Transformers and JinaBERT models do.
|
@@ -287,6 +351,102 @@ The reranker uses a BERT-based architecture that:
|
|
287
351
|
|
288
352
|
This joint processing allows cross-encoders to capture subtle semantic relationships between queries and documents, making them more accurate for reranking tasks, though at the cost of higher computational requirements.
|
289
353
|
|
354
|
+
## Common Runtime Errors
|
355
|
+
|
356
|
+
### 1. Weight is negative, too large or not a valid number
|
357
|
+
|
358
|
+
**Error:**
|
359
|
+
```
|
360
|
+
/Users/cpetersen/src/scientist/red-candle/lib/candle/llm.rb:25:in `_generate_stream': Generation failed: A weight is negative, too large or not a valid number (RuntimeError)
|
361
|
+
from /Users/cpetersen/src/scientist/red-candle/lib/candle/llm.rb:25:in `generate_stream'
|
362
|
+
...
|
363
|
+
```
|
364
|
+
|
365
|
+
**Cause:** This error occurs when using overly aggressive quantization levels (particularly Q2_K) that result in numerical instability during inference. The 2-bit quantization can cause weights to become corrupted or produce NaN/Inf values.
|
366
|
+
|
367
|
+
**Solution:** Use a higher quantization level. Recommended options:
|
368
|
+
- Q4_K_M (4-bit) - Best balance of quality and size
|
369
|
+
- Q5_K_M (5-bit) - Higher quality with slightly larger size
|
370
|
+
- Q3_K_M (3-bit) - Minimum recommended quantization
|
371
|
+
|
372
|
+
```ruby
|
373
|
+
# Instead of Q2_K:
|
374
|
+
llm = Candle::LLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
|
375
|
+
device: device,
|
376
|
+
gguf_file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")
|
377
|
+
```
|
378
|
+
|
379
|
+
### 2. Cannot find tensor model.embed_tokens.weight
|
380
|
+
|
381
|
+
**Error:**
|
382
|
+
```
|
383
|
+
Failed to load quantized model: cannot find tensor model.embed_tokens.weight (RuntimeError)
|
384
|
+
```
|
385
|
+
|
386
|
+
**Cause:** This error was common in earlier versions when loading GGUF files with incompatible tensor naming conventions. The unified GGUF loader in version 1.0.0+ should handle most GGUF files correctly.
|
387
|
+
|
388
|
+
**If you still encounter this error:**
|
389
|
+
1. Ensure you're using the latest version of red-candle (1.0.0 or higher)
|
390
|
+
2. Make sure to specify the exact GGUF filename:
|
391
|
+
```ruby
|
392
|
+
llm = Candle::LLM.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
|
393
|
+
device: device,
|
394
|
+
gguf_file: "mistral-7b-instruct-v0.2.Q4_K_M.gguf")
|
395
|
+
```
|
396
|
+
3. If the error persists, the GGUF file may use an unsupported architecture or format
|
397
|
+
|
398
|
+
### 3. No GGUF file found in repository
|
399
|
+
|
400
|
+
**Error:**
|
401
|
+
```
|
402
|
+
Failed to load quantized model: No GGUF file found in repository TheBloke/model-name-GGUF. Try specifying a quantization level like Q4_K_M, Q5_K_M, or Q8_0. (RuntimeError)
|
403
|
+
```
|
404
|
+
|
405
|
+
**Cause:** The automatic GGUF file detection couldn't find a matching file, often due to naming variations.
|
406
|
+
|
407
|
+
**Solution:** Specify the exact GGUF filename:
|
408
|
+
```ruby
|
409
|
+
# Visit the HuggingFace repository to find the exact filename
|
410
|
+
llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
|
411
|
+
device: device,
|
412
|
+
gguf_file: "llama-2-7b-chat.Q4_K_M.gguf")
|
413
|
+
```
|
414
|
+
|
415
|
+
### 4. Failed to download tokenizer
|
416
|
+
|
417
|
+
**Error:**
|
418
|
+
```
|
419
|
+
Failed to load quantized model: Failed to download tokenizer: request error: HTTP status client error (404 Not Found)
|
420
|
+
```
|
421
|
+
|
422
|
+
**Cause:** GGUF repositories often don't include separate tokenizer files, because the tokenizer data is embedded in the GGUF file itself.
|
423
|
+
|
424
|
+
**Solution:** The code now includes fallback tokenizer loading. If you still encounter this error, ensure you're using the latest version of red-candle.
|
425
|
+
|
426
|
+
### 5. Missing metadata in GGUF file
|
427
|
+
|
428
|
+
**Error:**
|
429
|
+
```
|
430
|
+
Failed to load GGUF model: cannot find gemma3.attention.head_count in metadata (RuntimeError)
|
431
|
+
```
|
432
|
+
or
|
433
|
+
```
|
434
|
+
Failed to load GGUF model: cannot find llama.attention.head_count in metadata (RuntimeError)
|
435
|
+
```
|
436
|
+
|
437
|
+
**Cause:** Some GGUF files may have been created with older conversion tools that don't include all required metadata fields.
|
438
|
+
|
439
|
+
**Solution:**
|
440
|
+
- Try a different GGUF file from the same model
|
441
|
+
- Look for GGUF files from TheBloke or other reputable sources
|
442
|
+
- Check if a newer version of the GGUF file is available
|
443
|
+
- Some Gemma GGUF files may not be compatible with the current loader
|
444
|
+
|
445
|
+
**Known compatibility issues:**
|
446
|
+
- `lmstudio-ai/gemma-2b-it-GGUF` - Missing required metadata fields
|
447
|
+
- Gemma 3 GGUF files may require specific tokenizers that are not publicly available
|
448
|
+
- For best compatibility, use Llama or Mistral GGUF files from TheBloke
|
449
|
+
|
290
450
|
## Development
|
291
451
|
|
292
452
|
FORK IT!
|
data/Rakefile
CHANGED
@@ -8,7 +8,7 @@ task default: :test
|
|
8
8
|
Rake::TestTask.new do |t|
|
9
9
|
t.deps << :compile
|
10
10
|
t.libs << "test"
|
11
|
-
t.test_files = FileList["test/**/*_test.rb"]
|
11
|
+
t.test_files = FileList["test/**/*_test.rb"].exclude("test/benchmarks/**/*_test.rb")
|
12
12
|
end
|
13
13
|
|
14
14
|
spec = Bundler.load_gemspec("candle.gemspec")
|
@@ -36,7 +36,6 @@ end
|
|
36
36
|
|
37
37
|
desc "Run benchmark tests"
|
38
38
|
Rake::TestTask.new("test:benchmark") do |t|
|
39
|
-
ENV['CANDLE_RUN_BENCHMARKS'] = 'true'
|
40
39
|
t.deps << :compile
|
41
40
|
t.libs << "test"
|
42
41
|
t.test_files = FileList["test/benchmarks/**/*_test.rb"]
|
@@ -59,7 +58,6 @@ end
|
|
59
58
|
|
60
59
|
desc "Run benchmarks with device tests"
|
61
60
|
task "test:device:benchmark" => :compile do
|
62
|
-
ENV['CANDLE_RUN_BENCHMARKS'] = 'true'
|
63
61
|
ENV['CANDLE_TEST_VERBOSE'] = 'true'
|
64
62
|
Rake::Task["test:device"].invoke
|
65
63
|
Rake::Task["test:benchmark"].invoke
|
data/ext/candle/src/llm/gemma.rs
ADDED
@@ -0,0 +1,277 @@
|
|
1
|
+
use candle_core::{DType, Device, Result as CandleResult, Tensor};
|
2
|
+
use candle_nn::VarBuilder;
|
3
|
+
use candle_transformers::models::gemma::{Config, Model as GemmaModel};
|
4
|
+
use hf_hub::{api::tokio::Api, Repo};
|
5
|
+
use tokenizers::Tokenizer;
|
6
|
+
|
7
|
+
use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
|
8
|
+
|
9
|
+
#[derive(Debug)]
|
10
|
+
pub struct Gemma {
|
11
|
+
model: GemmaModel,
|
12
|
+
tokenizer: TokenizerWrapper,
|
13
|
+
device: Device,
|
14
|
+
model_id: String,
|
15
|
+
eos_token_id: u32,
|
16
|
+
}
|
17
|
+
|
18
|
+
impl Gemma {
|
19
|
+
/// Clear the KV cache between generations
|
20
|
+
pub fn clear_kv_cache(&mut self) {
|
21
|
+
self.model.clear_kv_cache();
|
22
|
+
}
|
23
|
+
|
24
|
+
/// Load a Gemma model from HuggingFace Hub
|
25
|
+
pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
|
26
|
+
let api = Api::new()
|
27
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
|
28
|
+
|
29
|
+
let repo = api.repo(Repo::model(model_id.to_string()));
|
30
|
+
|
31
|
+
// Download model files
|
32
|
+
let config_filename = repo
|
33
|
+
.get("config.json")
|
34
|
+
.await
|
35
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
|
36
|
+
|
37
|
+
let tokenizer_filename = repo
|
38
|
+
.get("tokenizer.json")
|
39
|
+
.await
|
40
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
|
41
|
+
|
42
|
+
// Try different file patterns for model weights
|
43
|
+
let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
|
44
|
+
vec![single_file]
|
45
|
+
} else {
|
46
|
+
// Try to find sharded model files
|
47
|
+
let mut sharded_files = Vec::new();
|
48
|
+
let mut index = 1;
|
49
|
+
loop {
|
50
|
+
// Try common shard counts for Gemma models
|
51
|
+
let mut found = false;
|
52
|
+
for total in [2, 3, 4, 5, 6, 7, 8, 10, 12] {
|
53
|
+
let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
|
54
|
+
if let Ok(file) = repo.get(&filename).await {
|
55
|
+
sharded_files.push(file);
|
56
|
+
found = true;
|
57
|
+
break;
|
58
|
+
}
|
59
|
+
}
|
60
|
+
if !found {
|
61
|
+
break;
|
62
|
+
}
|
63
|
+
index += 1;
|
64
|
+
}
|
65
|
+
|
66
|
+
if sharded_files.is_empty() {
|
67
|
+
return Err(candle_core::Error::Msg(
|
68
|
+
"Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string()
|
69
|
+
));
|
70
|
+
}
|
71
|
+
sharded_files
|
72
|
+
};
|
73
|
+
|
74
|
+
// Load config
|
75
|
+
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
|
76
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
|
77
|
+
|
78
|
+
// Load tokenizer
|
79
|
+
let tokenizer = Tokenizer::from_file(tokenizer_filename)
|
80
|
+
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
|
81
|
+
|
82
|
+
// Gemma uses specific tokens
|
83
|
+
let eos_token_id = {
|
84
|
+
let vocab = tokenizer.get_vocab(true);
|
85
|
+
vocab.get("<eos>")
|
86
|
+
.or_else(|| vocab.get("<end_of_turn>"))
|
87
|
+
.copied()
|
88
|
+
.unwrap_or(1) // Default Gemma EOS
|
89
|
+
};
|
90
|
+
|
91
|
+
// Load model weights
|
92
|
+
let vb = unsafe {
|
93
|
+
VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
|
94
|
+
};
|
95
|
+
|
96
|
+
let model = GemmaModel::new(false, &config, vb)?; // Don't use flash attention for now
|
97
|
+
|
98
|
+
Ok(Self {
|
99
|
+
model,
|
100
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
101
|
+
device,
|
102
|
+
model_id: model_id.to_string(),
|
103
|
+
eos_token_id,
|
104
|
+
})
|
105
|
+
}
|
106
|
+
|
107
|
+
/// Create from existing components (useful for testing)
|
108
|
+
pub fn new(
|
109
|
+
model: GemmaModel,
|
110
|
+
tokenizer: Tokenizer,
|
111
|
+
device: Device,
|
112
|
+
model_id: String,
|
113
|
+
) -> Self {
|
114
|
+
let eos_token_id = {
|
115
|
+
let vocab = tokenizer.get_vocab(true);
|
116
|
+
vocab.get("<eos>")
|
117
|
+
.or_else(|| vocab.get("<end_of_turn>"))
|
118
|
+
.copied()
|
119
|
+
.unwrap_or(1)
|
120
|
+
};
|
121
|
+
|
122
|
+
Self {
|
123
|
+
model,
|
124
|
+
tokenizer: TokenizerWrapper::new(tokenizer),
|
125
|
+
device,
|
126
|
+
model_id,
|
127
|
+
eos_token_id,
|
128
|
+
}
|
129
|
+
}
|
130
|
+
|
131
|
+
fn generate_tokens(
|
132
|
+
&mut self,
|
133
|
+
prompt_tokens: Vec<u32>,
|
134
|
+
config: &GenerationConfig,
|
135
|
+
mut callback: Option<impl FnMut(&str)>,
|
136
|
+
) -> CandleResult<Vec<u32>> {
|
137
|
+
let mut text_gen = TextGeneration::from_config(config);
|
138
|
+
text_gen.set_eos_token_id(self.eos_token_id);
|
139
|
+
text_gen.set_tokens(prompt_tokens.clone());
|
140
|
+
|
141
|
+
let mut all_tokens = prompt_tokens.clone();
|
142
|
+
let start_gen = all_tokens.len();
|
143
|
+
|
144
|
+
for index in 0..config.max_length {
|
145
|
+
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
146
|
+
let start_pos = all_tokens.len().saturating_sub(context_size);
|
147
|
+
let ctxt = &all_tokens[start_pos..];
|
148
|
+
|
149
|
+
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
150
|
+
let input = input.contiguous()?;
|
151
|
+
let logits = self.model.forward(&input, start_pos)?;
|
152
|
+
|
153
|
+
let logits = logits.squeeze(0)?;
|
154
|
+
let logits = if logits.dims().len() == 2 {
|
155
|
+
let seq_len = logits.dim(0)?;
|
156
|
+
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
157
|
+
} else {
|
158
|
+
logits
|
159
|
+
};
|
160
|
+
|
161
|
+
let logits = logits.to_dtype(DType::F32)?;
|
162
|
+
|
163
|
+
let next_token = text_gen.sample_next_token(
|
164
|
+
&logits,
|
165
|
+
Some((config.repetition_penalty, config.repetition_penalty_last_n)),
|
166
|
+
)?;
|
167
|
+
|
168
|
+
all_tokens.push(next_token);
|
169
|
+
|
170
|
+
// Stream callback
|
171
|
+
if let Some(ref mut cb) = callback {
|
172
|
+
if config.debug_tokens {
|
173
|
+
// In debug mode, only show debug tokens
|
174
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
175
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
176
|
+
} else {
|
177
|
+
// Normal mode: use incremental decoding for proper text
|
178
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
179
|
+
cb(&decoded_text);
|
180
|
+
}
|
181
|
+
}
|
182
|
+
|
183
|
+
// Check stop conditions
|
184
|
+
if text_gen.should_stop(next_token, config.max_length) {
|
185
|
+
break;
|
186
|
+
}
|
187
|
+
|
188
|
+
// Check stop sequences
|
189
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
190
|
+
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
191
|
+
break;
|
192
|
+
}
|
193
|
+
}
|
194
|
+
|
195
|
+
Ok(if config.include_prompt {
|
196
|
+
all_tokens
|
197
|
+
} else {
|
198
|
+
all_tokens[start_gen..].to_vec()
|
199
|
+
})
|
200
|
+
}
|
201
|
+
|
202
|
+
/// Apply Gemma chat template
|
203
|
+
pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
|
204
|
+
let mut prompt = String::new();
|
205
|
+
|
206
|
+
// Gemma uses a specific format:
|
207
|
+
// <start_of_turn>user\n{user_message}<end_of_turn>
|
208
|
+
// <start_of_turn>model\n{model_message}<end_of_turn>
|
209
|
+
|
210
|
+
for message in messages {
|
211
|
+
let role = message["role"].as_str().unwrap_or("");
|
212
|
+
let content = message["content"].as_str().unwrap_or("");
|
213
|
+
|
214
|
+
match role {
|
215
|
+
"system" => {
|
216
|
+
// Gemma doesn't have explicit system messages, prepend to first user message
|
217
|
+
prompt.push_str(&format!("<start_of_turn>user\nSystem: {}\n", content));
|
218
|
+
}
|
219
|
+
"user" => {
|
220
|
+
if !prompt.contains("<start_of_turn>user") || prompt.ends_with("<end_of_turn>\n") {
|
221
|
+
prompt.push_str("<start_of_turn>user\n");
|
222
|
+
}
|
223
|
+
prompt.push_str(&format!("{}<end_of_turn>\n", content));
|
224
|
+
}
|
225
|
+
"assistant" | "model" => {
|
226
|
+
prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
|
227
|
+
}
|
228
|
+
_ => {}
|
229
|
+
}
|
230
|
+
}
|
231
|
+
|
232
|
+
// Add the model prompt
|
233
|
+
prompt.push_str("<start_of_turn>model\n");
|
234
|
+
|
235
|
+
Ok(prompt)
|
236
|
+
}
|
237
|
+
}
|
238
|
+
|
239
|
+
impl TextGenerator for Gemma {
|
240
|
+
fn generate(
|
241
|
+
&mut self,
|
242
|
+
prompt: &str,
|
243
|
+
config: &GenerationConfig,
|
244
|
+
) -> CandleResult<String> {
|
245
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
246
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
247
|
+
|
248
|
+
if config.debug_tokens {
|
249
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
250
|
+
} else {
|
251
|
+
self.tokenizer.decode(&output_tokens, true)
|
252
|
+
}
|
253
|
+
}
|
254
|
+
|
255
|
+
fn generate_stream(
|
256
|
+
&mut self,
|
257
|
+
prompt: &str,
|
258
|
+
config: &GenerationConfig,
|
259
|
+
mut callback: impl FnMut(&str),
|
260
|
+
) -> CandleResult<String> {
|
261
|
+
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
262
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
263
|
+
self.tokenizer.decode(&output_tokens, true)
|
264
|
+
}
|
265
|
+
|
266
|
+
fn model_name(&self) -> &str {
|
267
|
+
&self.model_id
|
268
|
+
}
|
269
|
+
|
270
|
+
fn device(&self) -> &Device {
|
271
|
+
&self.device
|
272
|
+
}
|
273
|
+
|
274
|
+
fn clear_cache(&mut self) {
|
275
|
+
self.clear_kv_cache();
|
276
|
+
}
|
277
|
+
}
|
data/ext/candle/src/llm/generation_config.rs
CHANGED
@@ -21,6 +21,8 @@ pub struct GenerationConfig {
|
|
21
21
|
pub stop_sequences: Vec<String>,
|
22
22
|
/// Whether to return the prompt in the output
|
23
23
|
pub include_prompt: bool,
|
24
|
+
/// Whether to show raw tokens during generation (for debugging)
|
25
|
+
pub debug_tokens: bool,
|
24
26
|
}
|
25
27
|
|
26
28
|
/// Generate a random seed based on current time
|
@@ -43,6 +45,7 @@ impl Default for GenerationConfig {
|
|
43
45
|
seed: random_seed(),
|
44
46
|
stop_sequences: vec![],
|
45
47
|
include_prompt: false,
|
48
|
+
debug_tokens: false,
|
46
49
|
}
|
47
50
|
}
|
48
51
|
}
|
data/ext/candle/src/llm/llama.rs
CHANGED
@@ -205,77 +205,14 @@ impl Llama {
|
|
205
205
|
|
206
206
|
// Stream callback
|
207
207
|
if let Some(ref mut cb) = callback {
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
// Check stop sequences
|
218
|
-
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
219
|
-
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
220
|
-
break;
|
221
|
-
}
|
222
|
-
}
|
223
|
-
|
224
|
-
Ok(if config.include_prompt {
|
225
|
-
all_tokens
|
226
|
-
} else {
|
227
|
-
all_tokens[start_gen..].to_vec()
|
228
|
-
})
|
229
|
-
}
|
230
|
-
|
231
|
-
fn generate_tokens_decoded(
|
232
|
-
&mut self,
|
233
|
-
prompt_tokens: Vec<u32>,
|
234
|
-
config: &GenerationConfig,
|
235
|
-
mut callback: Option<impl FnMut(&str)>,
|
236
|
-
) -> CandleResult<Vec<u32>> {
|
237
|
-
let mut text_gen = TextGeneration::from_config(config);
|
238
|
-
text_gen.set_eos_token_id(self.eos_token_id);
|
239
|
-
text_gen.set_tokens(prompt_tokens.clone());
|
240
|
-
|
241
|
-
let mut all_tokens = prompt_tokens.clone();
|
242
|
-
let start_gen = all_tokens.len();
|
243
|
-
let mut previously_decoded = String::new();
|
244
|
-
|
245
|
-
for index in 0..config.max_length {
|
246
|
-
let context_size = if index > 0 { 1 } else { all_tokens.len() };
|
247
|
-
let start_pos = all_tokens.len().saturating_sub(context_size);
|
248
|
-
let ctxt = &all_tokens[start_pos..];
|
249
|
-
|
250
|
-
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
251
|
-
let input = input.contiguous()?;
|
252
|
-
let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
|
253
|
-
|
254
|
-
let logits = logits.squeeze(0)?;
|
255
|
-
let logits = if logits.dims().len() == 2 {
|
256
|
-
let seq_len = logits.dim(0)?;
|
257
|
-
logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
|
258
|
-
} else {
|
259
|
-
logits
|
260
|
-
};
|
261
|
-
|
262
|
-
let logits = logits.to_dtype(DType::F32)?;
|
263
|
-
|
264
|
-
let next_token = text_gen.sample_next_token(
|
265
|
-
&logits,
|
266
|
-
Some((config.repetition_penalty, config.repetition_penalty_last_n)),
|
267
|
-
)?;
|
268
|
-
|
269
|
-
all_tokens.push(next_token);
|
270
|
-
|
271
|
-
// Stream callback with incremental decoding
|
272
|
-
if let Some(ref mut cb) = callback {
|
273
|
-
let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
274
|
-
|
275
|
-
if current_decoded.len() > previously_decoded.len() {
|
276
|
-
let new_text = &current_decoded[previously_decoded.len()..];
|
277
|
-
cb(new_text);
|
278
|
-
previously_decoded = current_decoded;
|
208
|
+
if config.debug_tokens {
|
209
|
+
// In debug mode, only show debug tokens
|
210
|
+
let token_piece = self.tokenizer.token_to_piece(next_token)?;
|
211
|
+
cb(&format!("[{}:{}]", next_token, token_piece));
|
212
|
+
} else {
|
213
|
+
// Normal mode: use incremental decoding for proper text
|
214
|
+
let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
|
215
|
+
cb(&decoded_text);
|
279
216
|
}
|
280
217
|
}
|
281
218
|
|
@@ -285,12 +222,7 @@ impl Llama {
|
|
285
222
|
}
|
286
223
|
|
287
224
|
// Check stop sequences
|
288
|
-
let generated_text =
|
289
|
-
previously_decoded.clone()
|
290
|
-
} else {
|
291
|
-
self.tokenizer.decode(&all_tokens[start_gen..], true)?
|
292
|
-
};
|
293
|
-
|
225
|
+
let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
|
294
226
|
if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
|
295
227
|
break;
|
296
228
|
}
|
@@ -374,7 +306,12 @@ impl TextGenerator for Llama {
|
|
374
306
|
) -> CandleResult<String> {
|
375
307
|
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
376
308
|
let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
|
377
|
-
|
309
|
+
|
310
|
+
if config.debug_tokens {
|
311
|
+
self.tokenizer.format_tokens_with_debug(&output_tokens)
|
312
|
+
} else {
|
313
|
+
self.tokenizer.decode(&output_tokens, true)
|
314
|
+
}
|
378
315
|
}
|
379
316
|
|
380
317
|
fn generate_stream(
|
@@ -384,7 +321,7 @@ impl TextGenerator for Llama {
|
|
384
321
|
mut callback: impl FnMut(&str),
|
385
322
|
) -> CandleResult<String> {
|
386
323
|
let prompt_tokens = self.tokenizer.encode(prompt, true)?;
|
387
|
-
let output_tokens = self.
|
324
|
+
let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
|
388
325
|
self.tokenizer.decode(&output_tokens, true)
|
389
326
|
}
|
390
327
|
|