red-candle 1.0.0.pre.6 → 1.0.0.pre.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +159 -0
- data/Rakefile +1 -3
- data/ext/candle/src/llm/gemma.rs +16 -79
- data/ext/candle/src/llm/generation_config.rs +3 -0
- data/ext/candle/src/llm/llama.rs +16 -79
- data/ext/candle/src/llm/mistral.rs +16 -89
- data/ext/candle/src/llm/mod.rs +58 -0
- data/ext/candle/src/llm/quantized_gguf.rs +496 -0
- data/ext/candle/src/llm/text_generation.rs +0 -4
- data/ext/candle/src/ruby/embedding_model.rs +0 -1
- data/ext/candle/src/ruby/llm.rs +79 -36
- data/lib/candle/llm.rb +91 -2
- data/lib/candle/version.rb +1 -1
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 6bea0c9f27d5bbbc43e0c2e8fa5456b2b79511f164f7083512025538ca172077
+  data.tar.gz: b591f559c63c1bdb9fa96beaf1079186284b219e2ceba50d8e2891fa4f3096bd
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 380cde32abf995f9766056638d4d0612acef746a84b31451c0aa2fb8cd205fc76b76b8909729a9af267ee6c549f9149253b2bfb70717385b5afe2cd2ad0b3152
+  data.tar.gz: 7acdfc4de263501be51e3f74933c528b9d723b4da93e78e3abae2454dc2746fb7696c09722fe6666084a48dff098501b2d1af503963b5d739e879bd0fcb96e43
data/README.md
CHANGED
@@ -51,6 +51,42 @@ Red-Candle now supports Large Language Models (LLMs) with GPU acceleration!
 - **Llama**: Llama 2 and Llama 3 models (e.g., `TinyLlama/TinyLlama-1.1B-Chat-v1.0`, `meta-llama/Llama-2-7b-hf`, `NousResearch/Llama-2-7b-hf`)
 - **Mistral**: All Mistral models (e.g., `mistralai/Mistral-7B-Instruct-v0.1`)
 
+### Quantized Model Support (GGUF)
+
+Red-Candle supports quantized models in GGUF format, offering 4-8x memory reduction:
+
+> **Note on GGUF Support**: Red-Candle now uses a unified GGUF loader that automatically detects the model architecture from the GGUF file. This means all GGUF models (including Mistral models from TheBloke) should now work correctly! The loader automatically selects the appropriate tokenizer based on the model type to ensure proper text generation.
+
+```ruby
+# Load quantized models - always specify the GGUF filename
+llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
+                                  device: device,
+                                  gguf_file: "llama-2-7b-chat.Q4_K_M.gguf")
+
+# Register custom tokenizer mappings for your models
+Candle::LLM.register_tokenizer("my-org/my-model-GGUF", "my-org/my-tokenizer")
+
+# Popular quantized model sources:
+# - TheBloke: Extensive collection of GGUF models
+# - Search HuggingFace for "GGUF" models
+```
+
+**Memory usage comparison (7B models):**
+- Full precision: ~28 GB
+- Q8_0 (8-bit): ~7 GB - Best quality, larger size
+- Q5_K_M (5-bit): ~4.5 GB - Very good quality
+- Q4_K_M (4-bit): ~4 GB - Recommended default, best balance
+- Q3_K_M (3-bit): ~3 GB - Good for memory-constrained systems
+
+**Quantization levels explained:**
+- **Q8_0**: Almost identical to full model, use when quality is paramount
+- **Q5_K_M**: Excellent quality with good compression
+- **Q4_K_M**: Best balance of quality/size/speed (recommended default)
+- **Q3_K_M**: Noticeable quality reduction but very compact
+- **Q2_K**: ⚠️ **Not recommended** - Can cause inference errors due to extreme quantization
+
+> **Warning**: Q2_K quantization can lead to "weight is negative, too large or not a valid number" errors during inference. Use Q3_K_M or higher for stable operation.
+
 > ### ⚠️ Huggingface login warning
 >
 > Many models, including the one below, require you to agree to the terms. You'll need to:
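The memory figures in the table added above follow from parameter count × effective bits per weight. A minimal sketch of that arithmetic in Ruby, assuming approximate effective bit widths per quantization level (the exact values vary with the tensor mix in each GGUF file):

```ruby
# Rough estimate behind the "Memory usage comparison" table: bytes ≈ params * bits / 8.
# The bits-per-weight values below are assumed approximations, not exact format specs.
PARAMS_7B = 7_000_000_000

{
  "F32 (full)" => 32.0,
  "Q8_0"       => 8.0,
  "Q5_K_M"     => 5.1,
  "Q4_K_M"     => 4.6,
  "Q3_K_M"     => 3.4,
}.each do |level, bits|
  gb = PARAMS_7B * bits / 8.0 / 1e9
  puts format("%-11s ~%.1f GB", level, gb)
end
```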
@@ -103,6 +139,33 @@ device = Candle::Device.metal
 device = Candle::Device.cuda # Linux/Windows with NVIDIA GPU
 ```
 
+### Debugging Token Generation
+
+For debugging purposes, you can enable raw token output to see both token IDs and their raw representations:
+
+```ruby
+# Enable debug mode to see raw tokens during generation
+config = Candle::GenerationConfig.balanced(debug_tokens: true)
+
+# Non-streaming generation with debug tokens
+result = llm.generate("Hello, world!", config: config)
+puts result
+# Output: [15043:Hello][11:,][1917:world][0:!]
+
+# Streaming generation with debug tokens
+llm.generate_stream("Hello, world!", config: config) do |text|
+  print text # Will show each token as it's generated: [15043:Hello][11:,][1917:world][0:!]
+end
+
+# Works with all models (Llama, Mistral, Gemma, and quantized GGUF models)
+```
+
+This is particularly useful for:
+- Debugging tokenization issues
+- Understanding how the model processes text
+- Troubleshooting generation problems
+- Analyzing model behavior
+
 ## ⚠️ Model Format Requirement: Safetensors Only
 
 Red-Candle **only supports embedding models that provide their weights in the [safetensors](https://github.com/huggingface/safetensors) format** (i.e., the model repo must contain a `model.safetensors` file). If the model repo does not provide the required file, loading will fail with a clear error. Most official BERT and DistilBERT models do **not** provide safetensors; many Sentence Transformers and JinaBERT models do.
@@ -288,6 +351,102 @@ The reranker uses a BERT-based architecture that:
 
 This joint processing allows cross-encoders to capture subtle semantic relationships between queries and documents, making them more accurate for reranking tasks, though at the cost of higher computational requirements.
 
+## Common Runtime Errors
+
+### 1. Weight is negative, too large or not a valid number
+
+**Error:**
+```
+/Users/cpetersen/src/scientist/red-candle/lib/candle/llm.rb:25:in `_generate_stream': Generation failed: A weight is negative, too large or not a valid number (RuntimeError)
+	from /Users/cpetersen/src/scientist/red-candle/lib/candle/llm.rb:25:in `generate_stream'
+...
+```
+
+**Cause:** This error occurs when using overly aggressive quantization levels (particularly Q2_K) that result in numerical instability during inference. The 2-bit quantization can cause weights to become corrupted or produce NaN/Inf values.
+
+**Solution:** Use a higher quantization level. Recommended options:
+- Q4_K_M (4-bit) - Best balance of quality and size
+- Q5_K_M (5-bit) - Higher quality with slightly larger size
+- Q3_K_M (3-bit) - Minimum recommended quantization
+
+```ruby
+# Instead of Q2_K:
+llm = Candle::LLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
+                                  device: device,
+                                  gguf_file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")
+```
+
+### 2. Cannot find tensor model.embed_tokens.weight
+
+**Error:**
+```
+Failed to load quantized model: cannot find tensor model.embed_tokens.weight (RuntimeError)
+```
+
+**Cause:** This error was common in earlier versions when loading GGUF files with incompatible tensor naming conventions. The unified GGUF loader in version 1.0.0+ should handle most GGUF files correctly.
+
+**If you still encounter this error:**
+1. Ensure you're using the latest version of red-candle (1.0.0 or higher)
+2. Make sure to specify the exact GGUF filename:
+   ```ruby
+   llm = Candle::LLM.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
+                                     device: device,
+                                     gguf_file: "mistral-7b-instruct-v0.2.Q4_K_M.gguf")
+   ```
+3. If the error persists, the GGUF file may use an unsupported architecture or format
+
+### 3. No GGUF file found in repository
+
+**Error:**
+```
+Failed to load quantized model: No GGUF file found in repository TheBloke/model-name-GGUF. Try specifying a quantization level like Q4_K_M, Q5_K_M, or Q8_0. (RuntimeError)
+```
+
+**Cause:** The automatic GGUF file detection couldn't find a matching file, often due to naming variations.
+
+**Solution:** Specify the exact GGUF filename:
+```ruby
+# Visit the HuggingFace repository to find the exact filename
+llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
+                                  device: device,
+                                  gguf_file: "llama-2-7b-chat.Q4_K_M.gguf")
+```
+
+### 4. Failed to download tokenizer
+
+**Error:**
+```
+Failed to load quantized model: Failed to download tokenizer: request error: HTTP status client error (404 Not Found)
+```
+
+**Cause:** GGUF repositories often don't include separate tokenizer files since they're embedded in the GGUF format.
+
+**Solution:** The code now includes fallback tokenizer loading. If you still encounter this error, ensure you're using the latest version of red-candle.
+
+### 5. Missing metadata in GGUF file
+
+**Error:**
+```
+Failed to load GGUF model: cannot find gemma3.attention.head_count in metadata (RuntimeError)
+```
+or
+```
+Failed to load GGUF model: cannot find llama.attention.head_count in metadata (RuntimeError)
+```
+
+**Cause:** Some GGUF files may have been created with older conversion tools that don't include all required metadata fields.
+
+**Solution:**
+- Try a different GGUF file from the same model
+- Look for GGUF files from TheBloke or other reputable sources
+- Check if a newer version of the GGUF file is available
+- Some Gemma GGUF files may not be compatible with the current loader
+
+**Known compatibility issues:**
+- `lmstudio-ai/gemma-2b-it-GGUF` - Missing required metadata fields
+- Gemma 3 GGUF files may require specific tokenizers that are not publicly available
+- For best compatibility, use Llama or Mistral GGUF files from TheBloke
+
 ## Development
 
 FORK IT!
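The failure modes catalogued in the new README section (missing GGUF filename, incomplete metadata, unstable Q2_K inference) all surface as `RuntimeError` at load or generation time, so they can be handled with an ordinary fallback loop. A minimal sketch, assuming TheBloke-style `<basename>.<LEVEL>.gguf` filenames; the `load_quantized` helper is hypothetical, and only `Candle::LLM.from_pretrained` and `Candle::Device` come from the gem itself:

```ruby
# Hypothetical helper: walk down the quantization levels from the README table,
# highest quality first, and fall back when a file is missing or inference is unstable.
def load_quantized(repo, device:, basename:)
  %w[Q5_K_M Q4_K_M Q3_K_M].each do |level|
    begin
      return Candle::LLM.from_pretrained(repo,
                                         device: device,
                                         gguf_file: "#{basename}.#{level}.gguf")
    rescue RuntimeError => e
      warn "#{level} failed: #{e.message}" # e.g. "No GGUF file found" or missing metadata
    end
  end
  raise "No usable GGUF quantization found in #{repo}"
end

device = Candle::Device.metal
llm = load_quantized("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
                     device: device,
                     basename: "tinyllama-1.1b-chat-v1.0")
```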
data/Rakefile
CHANGED
@@ -8,7 +8,7 @@ task default: :test
 Rake::TestTask.new do |t|
   t.deps << :compile
   t.libs << "test"
-  t.test_files = FileList["test/**/*_test.rb"]
+  t.test_files = FileList["test/**/*_test.rb"].exclude("test/benchmarks/**/*_test.rb")
 end
 
 spec = Bundler.load_gemspec("candle.gemspec")
@@ -36,7 +36,6 @@ end
 
 desc "Run benchmark tests"
 Rake::TestTask.new("test:benchmark") do |t|
-  ENV['CANDLE_RUN_BENCHMARKS'] = 'true'
   t.deps << :compile
   t.libs << "test"
   t.test_files = FileList["test/benchmarks/**/*_test.rb"]
@@ -59,7 +58,6 @@ end
 
 desc "Run benchmarks with device tests"
 task "test:device:benchmark" => :compile do
-  ENV['CANDLE_RUN_BENCHMARKS'] = 'true'
   ENV['CANDLE_TEST_VERBOSE'] = 'true'
   Rake::Task["test:device"].invoke
   Rake::Task["test:benchmark"].invoke
data/ext/candle/src/llm/gemma.rs
CHANGED
@@ -169,77 +169,14 @@ impl Gemma {
 
             // Stream callback
             if let Some(ref mut cb) = callback {
-
-
-
-
-
-
-
-
-
-            // Check stop sequences
-            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
-                break;
-            }
-        }
-
-        Ok(if config.include_prompt {
-            all_tokens
-        } else {
-            all_tokens[start_gen..].to_vec()
-        })
-    }
-
-    fn generate_tokens_decoded(
-        &mut self,
-        prompt_tokens: Vec<u32>,
-        config: &GenerationConfig,
-        mut callback: Option<impl FnMut(&str)>,
-    ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
-        text_gen.set_eos_token_id(self.eos_token_id);
-        text_gen.set_tokens(prompt_tokens.clone());
-
-        let mut all_tokens = prompt_tokens.clone();
-        let start_gen = all_tokens.len();
-        let mut previously_decoded = String::new();
-
-        for index in 0..config.max_length {
-            let context_size = if index > 0 { 1 } else { all_tokens.len() };
-            let start_pos = all_tokens.len().saturating_sub(context_size);
-            let ctxt = &all_tokens[start_pos..];
-
-            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            let input = input.contiguous()?;
-            let logits = self.model.forward(&input, start_pos)?;
-
-            let logits = logits.squeeze(0)?;
-            let logits = if logits.dims().len() == 2 {
-                let seq_len = logits.dim(0)?;
-                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
-            } else {
-                logits
-            };
-
-            let logits = logits.to_dtype(DType::F32)?;
-
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
-
-            all_tokens.push(next_token);
-
-            // Stream callback with incremental decoding
-            if let Some(ref mut cb) = callback {
-                let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-
-                if current_decoded.len() > previously_decoded.len() {
-                    let new_text = &current_decoded[previously_decoded.len()..];
-                    cb(new_text);
-                    previously_decoded = current_decoded;
+                if config.debug_tokens {
+                    // In debug mode, only show debug tokens
+                    let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                    cb(&format!("[{}:{}]", next_token, token_piece));
+                } else {
+                    // Normal mode: use incremental decoding for proper text
+                    let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                    cb(&decoded_text);
                 }
             }
 
@@ -249,12 +186,7 @@ impl Gemma {
             }
 
             // Check stop sequences
-            let generated_text =
-                previously_decoded.clone()
-            } else {
-                self.tokenizer.decode(&all_tokens[start_gen..], true)?
-            };
-
+            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
                 break;
             }
@@ -312,7 +244,12 @@ impl TextGenerator for Gemma {
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
-
+
+        if config.debug_tokens {
+            self.tokenizer.format_tokens_with_debug(&output_tokens)
+        } else {
+            self.tokenizer.decode(&output_tokens, true)
+        }
     }
 
     fn generate_stream(
@@ -322,7 +259,7 @@ impl TextGenerator for Gemma {
         mut callback: impl FnMut(&str),
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
-        let output_tokens = self.
+        let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
         self.tokenizer.decode(&output_tokens, true)
     }
 
data/ext/candle/src/llm/generation_config.rs
CHANGED
@@ -21,6 +21,8 @@ pub struct GenerationConfig {
     pub stop_sequences: Vec<String>,
     /// Whether to return the prompt in the output
     pub include_prompt: bool,
+    /// Whether to show raw tokens during generation (for debugging)
+    pub debug_tokens: bool,
 }
 
 /// Generate a random seed based on current time
@@ -43,6 +45,7 @@ impl Default for GenerationConfig {
             seed: random_seed(),
             stop_sequences: vec![],
             include_prompt: false,
+            debug_tokens: false,
         }
     }
 }
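The new `debug_tokens` field defaults to `false` in the `Default` impl above, so existing configurations keep their current behavior and output only changes when the flag is explicitly enabled. A small sketch of the two modes from Ruby, assuming `llm` is an already loaded model and that `GenerationConfig.balanced` accepts keyword overrides as in the README example:

```ruby
# Assumes `llm` is already loaded; passing debug_tokens explicitly mirrors the
# README usage of GenerationConfig.balanced.
plain_config = Candle::GenerationConfig.balanced(debug_tokens: false)
debug_config = Candle::GenerationConfig.balanced(debug_tokens: true)

puts llm.generate("Hello", config: plain_config) # normal decoded text
puts llm.generate("Hello", config: debug_config) # raw "[id:piece]" pairs, e.g. [15043:Hello]
```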
data/ext/candle/src/llm/llama.rs
CHANGED
@@ -205,77 +205,14 @@ impl Llama {
 
             // Stream callback
             if let Some(ref mut cb) = callback {
-
-
-
-
-
-
-
-
-
-            // Check stop sequences
-            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
-                break;
-            }
-        }
-
-        Ok(if config.include_prompt {
-            all_tokens
-        } else {
-            all_tokens[start_gen..].to_vec()
-        })
-    }
-
-    fn generate_tokens_decoded(
-        &mut self,
-        prompt_tokens: Vec<u32>,
-        config: &GenerationConfig,
-        mut callback: Option<impl FnMut(&str)>,
-    ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
-        text_gen.set_eos_token_id(self.eos_token_id);
-        text_gen.set_tokens(prompt_tokens.clone());
-
-        let mut all_tokens = prompt_tokens.clone();
-        let start_gen = all_tokens.len();
-        let mut previously_decoded = String::new();
-
-        for index in 0..config.max_length {
-            let context_size = if index > 0 { 1 } else { all_tokens.len() };
-            let start_pos = all_tokens.len().saturating_sub(context_size);
-            let ctxt = &all_tokens[start_pos..];
-
-            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            let input = input.contiguous()?;
-            let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
-
-            let logits = logits.squeeze(0)?;
-            let logits = if logits.dims().len() == 2 {
-                let seq_len = logits.dim(0)?;
-                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
-            } else {
-                logits
-            };
-
-            let logits = logits.to_dtype(DType::F32)?;
-
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
-
-            all_tokens.push(next_token);
-
-            // Stream callback with incremental decoding
-            if let Some(ref mut cb) = callback {
-                let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-
-                if current_decoded.len() > previously_decoded.len() {
-                    let new_text = &current_decoded[previously_decoded.len()..];
-                    cb(new_text);
-                    previously_decoded = current_decoded;
+                if config.debug_tokens {
+                    // In debug mode, only show debug tokens
+                    let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                    cb(&format!("[{}:{}]", next_token, token_piece));
+                } else {
+                    // Normal mode: use incremental decoding for proper text
+                    let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                    cb(&decoded_text);
                 }
             }
 
@@ -285,12 +222,7 @@ impl Llama {
             }
 
             // Check stop sequences
-            let generated_text =
-                previously_decoded.clone()
-            } else {
-                self.tokenizer.decode(&all_tokens[start_gen..], true)?
-            };
-
+            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
                 break;
             }
@@ -374,7 +306,12 @@ impl TextGenerator for Llama {
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
-
+
+        if config.debug_tokens {
+            self.tokenizer.format_tokens_with_debug(&output_tokens)
+        } else {
+            self.tokenizer.decode(&output_tokens, true)
+        }
     }
 
     fn generate_stream(
@@ -384,7 +321,7 @@ impl TextGenerator for Llama {
         mut callback: impl FnMut(&str),
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
-        let output_tokens = self.
+        let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
         self.tokenizer.decode(&output_tokens, true)
     }
 
data/ext/candle/src/llm/mistral.rs
CHANGED
@@ -180,87 +180,14 @@ impl Mistral {
 
             // Stream callback
             if let Some(ref mut cb) = callback {
-
-
-
-
-
-
-
-
-
-            // Check stop sequences
-            let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
-                break;
-            }
-        }
-
-        Ok(if config.include_prompt {
-            all_tokens
-        } else {
-            all_tokens[start_gen..].to_vec()
-        })
-    }
-
-    fn generate_tokens_decoded(
-        &mut self,
-        prompt_tokens: Vec<u32>,
-        config: &GenerationConfig,
-        mut callback: Option<impl FnMut(&str)>,
-    ) -> CandleResult<Vec<u32>> {
-        let mut text_gen = TextGeneration::from_config(config);
-        text_gen.set_eos_token_id(self.eos_token_id);
-        text_gen.set_tokens(prompt_tokens.clone());
-
-        let mut all_tokens = prompt_tokens.clone();
-        let start_gen = all_tokens.len();
-
-        // For incremental decoding
-        let mut previously_decoded = String::new();
-
-        for index in 0..config.max_length {
-            let context_size = if index > 0 { 1 } else { all_tokens.len() };
-            let start_pos = all_tokens.len().saturating_sub(context_size);
-            let ctxt = &all_tokens[start_pos..];
-
-            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            // Ensure input tensor is contiguous for Metal backend
-            let input = input.contiguous()?;
-            let logits = self.model.forward(&input, start_pos)?;
-
-            // The model returns logits of shape [batch_size, seq_len, vocab_size]
-            // We need to get the logits for the last token only
-            let logits = logits.squeeze(0)?; // Remove batch dimension
-            let logits = if logits.dims().len() == 2 {
-                // If we still have [seq_len, vocab_size], take the last token
-                let seq_len = logits.dim(0)?;
-                logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
-            } else {
-                // Already [vocab_size]
-                logits
-            };
-
-            // Convert to F32 for sampling if needed
-            let logits = logits.to_dtype(DType::F32)?;
-
-            let next_token = text_gen.sample_next_token(
-                &logits,
-                Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-            )?;
-
-            all_tokens.push(next_token);
-
-            // Stream callback with incremental decoding
-            if let Some(ref mut cb) = callback {
-                // Decode all generated tokens so far
-                let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-
-                // Only emit the new text since last callback
-                if current_decoded.len() > previously_decoded.len() {
-                    let new_text = &current_decoded[previously_decoded.len()..];
-                    cb(new_text);
-                    previously_decoded = current_decoded;
+                if config.debug_tokens {
+                    // In debug mode, only show debug tokens
+                    let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                    cb(&format!("[{}:{}]", next_token, token_piece));
+                } else {
+                    // Normal mode: use incremental decoding for proper text
+                    let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                    cb(&decoded_text);
                 }
             }
 
@@ -270,12 +197,7 @@ impl Mistral {
            }
 
            // Check stop sequences
-           let generated_text =
-               previously_decoded.clone()
-           } else {
-               self.tokenizer.decode(&all_tokens[start_gen..], true)?
-           };
-
+           let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
            if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
                break;
            }
@@ -297,7 +219,12 @@ impl TextGenerator for Mistral {
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
-
+
+        if config.debug_tokens {
+            self.tokenizer.format_tokens_with_debug(&output_tokens)
+        } else {
+            self.tokenizer.decode(&output_tokens, true)
+        }
     }
 
     fn generate_stream(
@@ -307,7 +234,7 @@ impl TextGenerator for Mistral {
         mut callback: impl FnMut(&str),
     ) -> CandleResult<String> {
         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
-        let output_tokens = self.
+        let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
         self.tokenizer.decode(&output_tokens, true)
     }
 
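With `debug_tokens` enabled, all three models stream tokens in the `[token_id:piece]` form produced by the `format!` calls above. A minimal Ruby sketch for turning that stream back into `(id, piece)` pairs, assuming pieces never contain `]` and that `llm` is already loaded:

```ruby
# Collect (token_id, piece) pairs from the debug stream; the regex assumes
# token pieces themselves never contain "]".
config = Candle::GenerationConfig.balanced(debug_tokens: true)

pairs = []
llm.generate_stream("Hello, world!", config: config) do |chunk|
  chunk.scan(/\[(\d+):([^\]]*)\]/) { |id, piece| pairs << [Integer(id), piece] }
end

pairs.each { |id, piece| puts format("%6d %p", id, piece) }
```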