red-candle 1.0.0.pre.5 → 1.0.0.pre.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 91a4c43a1a12d6d8960f1a1d190c9bfe8ea60db75f687233012d09b8c90b5020
- data.tar.gz: 3f6ce143cd38856365231baebe25a188e7d5824d930ecdae79c1660b3ad6c787
+ metadata.gz: 6bea0c9f27d5bbbc43e0c2e8fa5456b2b79511f164f7083512025538ca172077
+ data.tar.gz: b591f559c63c1bdb9fa96beaf1079186284b219e2ceba50d8e2891fa4f3096bd
  SHA512:
- metadata.gz: 273a01c438b085509a433602097b5ad4bcdb3420fc19ebe84c4fd37bf43ae5e6e1040701ecaad9e2864c0cda31d6128d4b2df6c395290994c17f135620282be6
- data.tar.gz: 9193d01d8bfc704b982c839f9aebac23217ae5d56b5afae6228daae73b60e219b049fe8fc8d5ad13657ceb110476489c68bbb9aa6397688bf70a976a8d62d41e
+ metadata.gz: 380cde32abf995f9766056638d4d0612acef746a84b31451c0aa2fb8cd205fc76b76b8909729a9af267ee6c549f9149253b2bfb70717385b5afe2cd2ad0b3152
+ data.tar.gz: 7acdfc4de263501be51e3f74933c528b9d723b4da93e78e3abae2454dc2746fb7696c09722fe6666084a48dff098501b2d1af503963b5d739e879bd0fcb96e43
data/README.md CHANGED
@@ -47,9 +47,46 @@ Red-Candle now supports Large Language Models (LLMs) with GPU acceleration!
 
  ### Supported Models
 
+ - **Gemma**: Google's Gemma models (e.g., `google/gemma-2b`, `google/gemma-7b`, `google/gemma-2b-it`)
  - **Llama**: Llama 2 and Llama 3 models (e.g., `TinyLlama/TinyLlama-1.1B-Chat-v1.0`, `meta-llama/Llama-2-7b-hf`, `NousResearch/Llama-2-7b-hf`)
  - **Mistral**: All Mistral models (e.g., `mistralai/Mistral-7B-Instruct-v0.1`)
 
+ ### Quantized Model Support (GGUF)
+
+ Red-Candle supports quantized models in GGUF format, offering 4-8x memory reduction:
+
+ > **Note on GGUF Support**: Red-Candle now uses a unified GGUF loader that automatically detects the model architecture from the GGUF file. This means all GGUF models (including Mistral models from TheBloke) should now work correctly! The loader automatically selects the appropriate tokenizer based on the model type to ensure proper text generation.
+
+ ```ruby
+ # Load quantized models - always specify the GGUF filename
+ llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
+ device: device,
+ gguf_file: "llama-2-7b-chat.Q4_K_M.gguf")
+
+ # Register custom tokenizer mappings for your models
+ Candle::LLM.register_tokenizer("my-org/my-model-GGUF", "my-org/my-tokenizer")
+
+ # Popular quantized model sources:
+ # - TheBloke: Extensive collection of GGUF models
+ # - Search HuggingFace for "GGUF" models
+ ```
+
+ **Memory usage comparison (7B models):**
+ - Full precision: ~28 GB
+ - Q8_0 (8-bit): ~7 GB - Best quality, larger size
+ - Q5_K_M (5-bit): ~4.5 GB - Very good quality
+ - Q4_K_M (4-bit): ~4 GB - Recommended default, best balance
+ - Q3_K_M (3-bit): ~3 GB - Good for memory-constrained systems
+
+ **Quantization levels explained:**
+ - **Q8_0**: Almost identical to full model, use when quality is paramount
+ - **Q5_K_M**: Excellent quality with good compression
+ - **Q4_K_M**: Best balance of quality/size/speed (recommended default)
+ - **Q3_K_M**: Noticeable quality reduction but very compact
+ - **Q2_K**: ⚠️ **Not recommended** - Can cause inference errors due to extreme quantization
+
+ > **Warning**: Q2_K quantization can lead to "weight is negative, too large or not a valid number" errors during inference. Use Q3_K_M or higher for stable operation.
+
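As a rough sanity check on the table above (assuming a 7B-parameter model, 4 bytes per weight at full F32 precision, about 1 byte per weight for Q8_0, and roughly 4-5 bits per weight for Q4_K_M):

```ruby
# Back-of-the-envelope estimate of weight memory only (excludes KV cache and activations)
params = 7_000_000_000
puts "F32:    ~#{(params * 4.0 / 1e9).round} GB"      # ~28 GB
puts "Q8_0:   ~#{(params * 1.0 / 1e9).round} GB"      # ~7 GB
puts "Q4_K_M: ~#{(params * 4.5 / 8 / 1e9).round} GB"  # ~4 GB
```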
  > ### ⚠️ Huggingface login warning
  >
  > Many models, including the one below, require you to agree to the terms. You'll need to:
@@ -67,10 +104,10 @@ device = Candle::Device.cpu # CPU (default)
  device = Candle::Device.metal # Apple GPU (Metal)
  device = Candle::Device.cuda # NVIDIA GPU (CUDA)
 
- # Load a Llama model
- llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device)
- # Or a Mistral model
- llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device)
+ # Load a model
+ llm = Candle::LLM.from_pretrained("google/gemma-2b-it", device: device) # Gemma
+ # llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device) # Llama
+ # llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device) # Mistral
 
  # Generate text
  response = llm.generate("What is Ruby?", config: Candle::GenerationConfig.balanced)
@@ -102,6 +139,33 @@ device = Candle::Device.metal
  device = Candle::Device.cuda # Linux/Windows with NVIDIA GPU
  ```
 
+ ### Debugging Token Generation
+
+ For debugging purposes, you can enable raw token output to see both token IDs and their raw representations:
+
+ ```ruby
+ # Enable debug mode to see raw tokens during generation
+ config = Candle::GenerationConfig.balanced(debug_tokens: true)
+
+ # Non-streaming generation with debug tokens
+ result = llm.generate("Hello, world!", config: config)
+ puts result
+ # Output: [15043:Hello][11:,][1917:world][0:!]
+
+ # Streaming generation with debug tokens
+ llm.generate_stream("Hello, world!", config: config) do |text|
+   print text # Will show each token as it's generated: [15043:Hello][11:,][1917:world][0:!]
+ end
+
+ # Works with all models (Llama, Mistral, Gemma, and quantized GGUF models)
+ ```
+
+ This is particularly useful for:
+ - Debugging tokenization issues
+ - Understanding how the model processes text
+ - Troubleshooting generation problems
+ - Analyzing model behavior
+
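When digging into tokenization issues, it can also help to pull the `[id:piece]` pairs back apart. A minimal sketch based on the output format shown above (the helper and regex are illustrative, not part of the red-candle API):

```ruby
# Split debug output like "[15043:Hello][11:,][1917:world][0:!]" into [id, piece] pairs
def parse_debug_tokens(output)
  output.scan(/\[(\d+):(.*?)\]/).map { |id, piece| [Integer(id), piece] }
end

parse_debug_tokens("[15043:Hello][11:,][1917:world][0:!]")
# => [[15043, "Hello"], [11, ","], [1917, "world"], [0, "!"]]
```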
  ## ⚠️ Model Format Requirement: Safetensors Only
 
  Red-Candle **only supports embedding models that provide their weights in the [safetensors](https://github.com/huggingface/safetensors) format** (i.e., the model repo must contain a `model.safetensors` file). If the model repo does not provide the required file, loading will fail with a clear error. Most official BERT and DistilBERT models do **not** provide safetensors; many Sentence Transformers and JinaBERT models do.
@@ -287,6 +351,102 @@ The reranker uses a BERT-based architecture that:
 
  This joint processing allows cross-encoders to capture subtle semantic relationships between queries and documents, making them more accurate for reranking tasks, though at the cost of higher computational requirements.
 
+ ## Common Runtime Errors
+
+ ### 1. Weight is negative, too large or not a valid number
+
+ **Error:**
+ ```
+ /Users/cpetersen/src/scientist/red-candle/lib/candle/llm.rb:25:in `_generate_stream': Generation failed: A weight is negative, too large or not a valid number (RuntimeError)
+ from /Users/cpetersen/src/scientist/red-candle/lib/candle/llm.rb:25:in `generate_stream'
+ ...
+ ```
+
+ **Cause:** This error occurs when using overly aggressive quantization levels (particularly Q2_K) that result in numerical instability during inference. The 2-bit quantization can cause weights to become corrupted or produce NaN/Inf values.
+
+ **Solution:** Use a higher quantization level. Recommended options:
+ - Q4_K_M (4-bit) - Best balance of quality and size
+ - Q5_K_M (5-bit) - Higher quality with slightly larger size
+ - Q3_K_M (3-bit) - Minimum recommended quantization
+
+ ```ruby
+ # Instead of Q2_K:
+ llm = Candle::LLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
+ device: device,
+ gguf_file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")
+ ```
+
+ ### 2. Cannot find tensor model.embed_tokens.weight
+
+ **Error:**
+ ```
+ Failed to load quantized model: cannot find tensor model.embed_tokens.weight (RuntimeError)
+ ```
+
+ **Cause:** This error was common in earlier versions when loading GGUF files with incompatible tensor naming conventions. The unified GGUF loader in version 1.0.0+ should handle most GGUF files correctly.
+
+ **If you still encounter this error:**
+ 1. Ensure you're using the latest version of red-candle (1.0.0 or higher)
+ 2. Make sure to specify the exact GGUF filename:
+ ```ruby
+ llm = Candle::LLM.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
+ device: device,
+ gguf_file: "mistral-7b-instruct-v0.2.Q4_K_M.gguf")
+ ```
+ 3. If the error persists, the GGUF file may use an unsupported architecture or format
+
+ ### 3. No GGUF file found in repository
+
+ **Error:**
+ ```
+ Failed to load quantized model: No GGUF file found in repository TheBloke/model-name-GGUF. Try specifying a quantization level like Q4_K_M, Q5_K_M, or Q8_0. (RuntimeError)
+ ```
+
+ **Cause:** The automatic GGUF file detection couldn't find a matching file, often due to naming variations.
+
+ **Solution:** Specify the exact GGUF filename:
+ ```ruby
+ # Visit the HuggingFace repository to find the exact filename
+ llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
+ device: device,
+ gguf_file: "llama-2-7b-chat.Q4_K_M.gguf")
+ ```
+
+ ### 4. Failed to download tokenizer
+
+ **Error:**
+ ```
+ Failed to load quantized model: Failed to download tokenizer: request error: HTTP status client error (404 Not Found)
+ ```
+
+ **Cause:** GGUF repositories often don't include separate tokenizer files since they're embedded in the GGUF format.
+
+ **Solution:** The code now includes fallback tokenizer loading. If you still encounter this error, ensure you're using the latest version of red-candle.
+
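If the fallback still cannot find a tokenizer for a particular repository, one workaround is to register a tokenizer mapping yourself using the `register_tokenizer` call shown earlier. A sketch (the tokenizer source repository here is only an example; pick whichever repo actually ships the matching `tokenizer.json`):

```ruby
# Tell red-candle where to fetch a tokenizer for a GGUF repo that lacks one
Candle::LLM.register_tokenizer("TheBloke/Llama-2-7B-Chat-GGUF",
                               "meta-llama/Llama-2-7b-chat-hf")

llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
                                  device: device,
                                  gguf_file: "llama-2-7b-chat.Q4_K_M.gguf")
```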
+ ### 5. Missing metadata in GGUF file
+
+ **Error:**
+ ```
+ Failed to load GGUF model: cannot find gemma3.attention.head_count in metadata (RuntimeError)
+ ```
+ or
+ ```
+ Failed to load GGUF model: cannot find llama.attention.head_count in metadata (RuntimeError)
+ ```
+
+ **Cause:** Some GGUF files may have been created with older conversion tools that don't include all required metadata fields.
+
+ **Solution:**
+ - Try a different GGUF file from the same model
+ - Look for GGUF files from TheBloke or other reputable sources
+ - Check if a newer version of the GGUF file is available
+ - Some Gemma GGUF files may not be compatible with the current loader
+
+ **Known compatibility issues:**
+ - `lmstudio-ai/gemma-2b-it-GGUF` - Missing required metadata fields
+ - Gemma 3 GGUF files may require specific tokenizers that are not publicly available
+ - For best compatibility, use Llama or Mistral GGUF files from TheBloke
+
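If you need to handle an incompatible file gracefully, one possible pattern (a sketch only; the Gemma filename below is illustrative, and the fallback reuses the known-good example from section 3) is to rescue the load error and retry with a different GGUF:

```ruby
begin
  llm = Candle::LLM.from_pretrained("lmstudio-ai/gemma-2b-it-GGUF",
                                    device: device,
                                    gguf_file: "gemma-2b-it-q4_k_m.gguf")
rescue RuntimeError => e
  warn "GGUF load failed (#{e.message}); falling back to a known-good model"
  llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
                                    device: device,
                                    gguf_file: "llama-2-7b-chat.Q4_K_M.gguf")
end
```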
  ## Development
 
  FORK IT!
data/Rakefile CHANGED
@@ -8,7 +8,7 @@ task default: :test
  Rake::TestTask.new do |t|
    t.deps << :compile
    t.libs << "test"
-   t.test_files = FileList["test/**/*_test.rb"]
+   t.test_files = FileList["test/**/*_test.rb"].exclude("test/benchmarks/**/*_test.rb")
  end
 
  spec = Bundler.load_gemspec("candle.gemspec")
@@ -36,7 +36,6 @@ end
 
  desc "Run benchmark tests"
  Rake::TestTask.new("test:benchmark") do |t|
-   ENV['CANDLE_RUN_BENCHMARKS'] = 'true'
    t.deps << :compile
    t.libs << "test"
    t.test_files = FileList["test/benchmarks/**/*_test.rb"]
@@ -59,7 +58,6 @@ end
 
  desc "Run benchmarks with device tests"
  task "test:device:benchmark" => :compile do
-   ENV['CANDLE_RUN_BENCHMARKS'] = 'true'
    ENV['CANDLE_TEST_VERBOSE'] = 'true'
    Rake::Task["test:device"].invoke
    Rake::Task["test:benchmark"].invoke
@@ -0,0 +1,277 @@
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
+ use candle_nn::VarBuilder;
+ use candle_transformers::models::gemma::{Config, Model as GemmaModel};
+ use hf_hub::{api::tokio::Api, Repo};
+ use tokenizers::Tokenizer;
+
+ use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+ #[derive(Debug)]
+ pub struct Gemma {
+     model: GemmaModel,
+     tokenizer: TokenizerWrapper,
+     device: Device,
+     model_id: String,
+     eos_token_id: u32,
+ }
+
+ impl Gemma {
+     /// Clear the KV cache between generations
+     pub fn clear_kv_cache(&mut self) {
+         self.model.clear_kv_cache();
+     }
+
+     /// Load a Gemma model from HuggingFace Hub
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         let api = Api::new()
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+         let repo = api.repo(Repo::model(model_id.to_string()));
+
+         // Download model files
+         let config_filename = repo
+             .get("config.json")
+             .await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+
+         let tokenizer_filename = repo
+             .get("tokenizer.json")
+             .await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+
+         // Try different file patterns for model weights
+         let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
+             vec![single_file]
+         } else {
+             // Try to find sharded model files
+             let mut sharded_files = Vec::new();
+             let mut index = 1;
+             loop {
+                 // Try common shard counts for Gemma models
+                 let mut found = false;
+                 for total in [2, 3, 4, 5, 6, 7, 8, 10, 12] {
+                     let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
+                     if let Ok(file) = repo.get(&filename).await {
+                         sharded_files.push(file);
+                         found = true;
+                         break;
+                     }
+                 }
+                 if !found {
+                     break;
+                 }
+                 index += 1;
+             }
+
+             if sharded_files.is_empty() {
+                 return Err(candle_core::Error::Msg(
+                     "Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string()
+                 ));
+             }
+             sharded_files
+         };
+
+         // Load config
+         let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
+
+         // Load tokenizer
+         let tokenizer = Tokenizer::from_file(tokenizer_filename)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+         // Gemma uses specific tokens
+         let eos_token_id = {
+             let vocab = tokenizer.get_vocab(true);
+             vocab.get("<eos>")
+                 .or_else(|| vocab.get("<end_of_turn>"))
+                 .copied()
+                 .unwrap_or(1) // Default Gemma EOS
+         };
+
+         // Load model weights
+         let vb = unsafe {
+             VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+         };
+
+         let model = GemmaModel::new(false, &config, vb)?; // Don't use flash attention for now
+
+         Ok(Self {
+             model,
+             tokenizer: TokenizerWrapper::new(tokenizer),
+             device,
+             model_id: model_id.to_string(),
+             eos_token_id,
+         })
+     }
+
+     /// Create from existing components (useful for testing)
+     pub fn new(
+         model: GemmaModel,
+         tokenizer: Tokenizer,
+         device: Device,
+         model_id: String,
+     ) -> Self {
+         let eos_token_id = {
+             let vocab = tokenizer.get_vocab(true);
+             vocab.get("<eos>")
+                 .or_else(|| vocab.get("<end_of_turn>"))
+                 .copied()
+                 .unwrap_or(1)
+         };
+
+         Self {
+             model,
+             tokenizer: TokenizerWrapper::new(tokenizer),
+             device,
+             model_id,
+             eos_token_id,
+         }
+     }
+
+     fn generate_tokens(
+         &mut self,
+         prompt_tokens: Vec<u32>,
+         config: &GenerationConfig,
+         mut callback: Option<impl FnMut(&str)>,
+     ) -> CandleResult<Vec<u32>> {
+         let mut text_gen = TextGeneration::from_config(config);
+         text_gen.set_eos_token_id(self.eos_token_id);
+         text_gen.set_tokens(prompt_tokens.clone());
+
+         let mut all_tokens = prompt_tokens.clone();
+         let start_gen = all_tokens.len();
+
+         for index in 0..config.max_length {
+             let context_size = if index > 0 { 1 } else { all_tokens.len() };
+             let start_pos = all_tokens.len().saturating_sub(context_size);
+             let ctxt = &all_tokens[start_pos..];
+
+             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+             let input = input.contiguous()?;
+             let logits = self.model.forward(&input, start_pos)?;
+
+             let logits = logits.squeeze(0)?;
+             let logits = if logits.dims().len() == 2 {
+                 let seq_len = logits.dim(0)?;
+                 logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+             } else {
+                 logits
+             };
+
+             let logits = logits.to_dtype(DType::F32)?;
+
+             let next_token = text_gen.sample_next_token(
+                 &logits,
+                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+             )?;
+
+             all_tokens.push(next_token);
+
+             // Stream callback
+             if let Some(ref mut cb) = callback {
+                 if config.debug_tokens {
+                     // In debug mode, only show debug tokens
+                     let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                     cb(&format!("[{}:{}]", next_token, token_piece));
+                 } else {
+                     // Normal mode: use incremental decoding for proper text
+                     let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                     cb(&decoded_text);
+                 }
+             }
+
+             // Check stop conditions
+             if text_gen.should_stop(next_token, config.max_length) {
+                 break;
+             }
+
+             // Check stop sequences
+             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                 break;
+             }
+         }
+
+         Ok(if config.include_prompt {
+             all_tokens
+         } else {
+             all_tokens[start_gen..].to_vec()
+         })
+     }
+
+     /// Apply Gemma chat template
+     pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
+         let mut prompt = String::new();
+
+         // Gemma uses a specific format:
+         // <start_of_turn>user\n{user_message}<end_of_turn>
+         // <start_of_turn>model\n{model_message}<end_of_turn>
+
+         for message in messages {
+             let role = message["role"].as_str().unwrap_or("");
+             let content = message["content"].as_str().unwrap_or("");
+
+             match role {
+                 "system" => {
+                     // Gemma doesn't have explicit system messages, prepend to first user message
+                     prompt.push_str(&format!("<start_of_turn>user\nSystem: {}\n", content));
+                 }
+                 "user" => {
+                     if !prompt.contains("<start_of_turn>user") || prompt.ends_with("<end_of_turn>\n") {
+                         prompt.push_str("<start_of_turn>user\n");
+                     }
+                     prompt.push_str(&format!("{}<end_of_turn>\n", content));
+                 }
+                 "assistant" | "model" => {
+                     prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
+                 }
+                 _ => {}
+             }
+         }
+
+         // Add the model prompt
+         prompt.push_str("<start_of_turn>model\n");
+
+         Ok(prompt)
+     }
+ }
+
+ impl TextGenerator for Gemma {
+     fn generate(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+
+         if config.debug_tokens {
+             self.tokenizer.format_tokens_with_debug(&output_tokens)
+         } else {
+             self.tokenizer.decode(&output_tokens, true)
+         }
+     }
+
+     fn generate_stream(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+         mut callback: impl FnMut(&str),
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
+         self.tokenizer.decode(&output_tokens, true)
+     }
+
+     fn model_name(&self) -> &str {
+         &self.model_id
+     }
+
+     fn device(&self) -> &Device {
+         &self.device
+     }
+
+     fn clear_cache(&mut self) {
+         self.clear_kv_cache();
+     }
+ }
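For reference, the turn format built by `apply_chat_template` above can be sketched in Ruby (an illustration of the string layout only; this is not an API exposed by red-candle):

```ruby
# Mirrors the Gemma turn format used by apply_chat_template above
def gemma_prompt(messages)
  prompt = +""
  messages.each do |m|
    case m[:role]
    when "system"
      prompt << "<start_of_turn>user\nSystem: #{m[:content]}\n"
    when "user"
      # Open a user turn unless one is already open
      prompt << "<start_of_turn>user\n" if !prompt.include?("<start_of_turn>user") || prompt.end_with?("<end_of_turn>\n")
      prompt << "#{m[:content]}<end_of_turn>\n"
    when "assistant", "model"
      prompt << "<start_of_turn>model\n#{m[:content]}<end_of_turn>\n"
    end
  end
  prompt << "<start_of_turn>model\n"  # prompt the model for its turn
end

gemma_prompt([{ role: "user", content: "What is Ruby?" }])
# => "<start_of_turn>user\nWhat is Ruby?<end_of_turn>\n<start_of_turn>model\n"
```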
@@ -21,6 +21,8 @@ pub struct GenerationConfig {
      pub stop_sequences: Vec<String>,
      /// Whether to return the prompt in the output
      pub include_prompt: bool,
+     /// Whether to show raw tokens during generation (for debugging)
+     pub debug_tokens: bool,
  }
 
  /// Generate a random seed based on current time
@@ -43,6 +45,7 @@ impl Default for GenerationConfig {
              seed: random_seed(),
              stop_sequences: vec![],
              include_prompt: false,
+             debug_tokens: false,
          }
      }
  }
@@ -205,77 +205,14 @@ impl Llama {
 
              // Stream callback
              if let Some(ref mut cb) = callback {
-                 let token_text = self.tokenizer.token_to_piece(next_token)?;
-                 cb(&token_text);
-             }
-
-             // Check stop conditions
-             if text_gen.should_stop(next_token, config.max_length) {
-                 break;
-             }
-
-             // Check stop sequences
-             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
-                 break;
-             }
-         }
-
-         Ok(if config.include_prompt {
-             all_tokens
-         } else {
-             all_tokens[start_gen..].to_vec()
-         })
-     }
-
-     fn generate_tokens_decoded(
-         &mut self,
-         prompt_tokens: Vec<u32>,
-         config: &GenerationConfig,
-         mut callback: Option<impl FnMut(&str)>,
-     ) -> CandleResult<Vec<u32>> {
-         let mut text_gen = TextGeneration::from_config(config);
-         text_gen.set_eos_token_id(self.eos_token_id);
-         text_gen.set_tokens(prompt_tokens.clone());
-
-         let mut all_tokens = prompt_tokens.clone();
-         let start_gen = all_tokens.len();
-         let mut previously_decoded = String::new();
-
-         for index in 0..config.max_length {
-             let context_size = if index > 0 { 1 } else { all_tokens.len() };
-             let start_pos = all_tokens.len().saturating_sub(context_size);
-             let ctxt = &all_tokens[start_pos..];
-
-             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-             let input = input.contiguous()?;
-             let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
-
-             let logits = logits.squeeze(0)?;
-             let logits = if logits.dims().len() == 2 {
-                 let seq_len = logits.dim(0)?;
-                 logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
-             } else {
-                 logits
-             };
-
-             let logits = logits.to_dtype(DType::F32)?;
-
-             let next_token = text_gen.sample_next_token(
-                 &logits,
-                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
-             )?;
-
-             all_tokens.push(next_token);
-
-             // Stream callback with incremental decoding
-             if let Some(ref mut cb) = callback {
-                 let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
-
-                 if current_decoded.len() > previously_decoded.len() {
-                     let new_text = &current_decoded[previously_decoded.len()..];
-                     cb(new_text);
-                     previously_decoded = current_decoded;
+                 if config.debug_tokens {
+                     // In debug mode, only show debug tokens
+                     let token_piece = self.tokenizer.token_to_piece(next_token)?;
+                     cb(&format!("[{}:{}]", next_token, token_piece));
+                 } else {
+                     // Normal mode: use incremental decoding for proper text
+                     let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
+                     cb(&decoded_text);
                  }
              }
 
@@ -285,12 +222,7 @@ impl Llama {
              }
 
              // Check stop sequences
-             let generated_text = if callback.is_some() {
-                 previously_decoded.clone()
-             } else {
-                 self.tokenizer.decode(&all_tokens[start_gen..], true)?
-             };
-
+             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
              if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
                  break;
              }
@@ -374,7 +306,12 @@ impl TextGenerator for Llama {
      ) -> CandleResult<String> {
          let prompt_tokens = self.tokenizer.encode(prompt, true)?;
          let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
-         self.tokenizer.decode(&output_tokens, true)
+
+         if config.debug_tokens {
+             self.tokenizer.format_tokens_with_debug(&output_tokens)
+         } else {
+             self.tokenizer.decode(&output_tokens, true)
+         }
      }
 
      fn generate_stream(
@@ -384,7 +321,7 @@ impl TextGenerator for Llama {
          mut callback: impl FnMut(&str),
      ) -> CandleResult<String> {
          let prompt_tokens = self.tokenizer.encode(prompt, true)?;
-         let output_tokens = self.generate_tokens_decoded(prompt_tokens, config, Some(&mut callback))?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
          self.tokenizer.decode(&output_tokens, true)
      }