red-candle 1.1.2 → 1.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +39 -45
- data/Rakefile +79 -88
- data/ext/candle/src/lib.rs +2 -4
- data/ext/candle/src/llm/quantized_gguf.rs +1 -1
- data/ext/candle/src/ruby/device.rs +30 -0
- data/ext/candle/src/ruby/embedding_model.rs +74 -28
- data/ext/candle/src/ruby/llm.rs +96 -1
- data/ext/candle/src/ruby/mod.rs +2 -0
- data/ext/candle/src/{ner.rs → ruby/ner.rs} +47 -15
- data/ext/candle/src/{reranker.rs → ruby/reranker.rs} +24 -2
- data/ext/candle/src/ruby/tensor.rs +101 -26
- data/ext/candle/src/ruby/tokenizer.rs +60 -3
- data/lib/candle/device_utils.rb +3 -15
- data/lib/candle/embedding_model.rb +44 -1
- data/lib/candle/llm.rb +63 -1
- data/lib/candle/ner.rb +34 -22
- data/lib/candle/reranker.rb +20 -1
- data/lib/candle/tensor.rb +15 -0
- data/lib/candle/version.rb +1 -1
- metadata +18 -4
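Read as a whole, the user-facing changes in this diff are that the keyword-argument constructors for `Candle::EmbeddingModel` and `Candle::Reranker` give way to `from_pretrained`, and `Candle::Device` gains a `best` helper (see the device.rs hunks at the end). A minimal sketch of the new-style calls, pieced together from the README and device.rs hunks below rather than taken verbatim from the gem's docs:

```ruby
require 'candle'

# New in 1.2.0: pick the best available backend (Metal, then CUDA, then CPU).
device = Candle::Device.best

# Constructors move to from_pretrained; the model ids are the ones the README uses.
model    = Candle::EmbeddingModel.from_pretrained("jinaai/jina-embeddings-v2-base-en", device: device)
reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2", device: device)

embedding = model.embedding("Hi there!")
results   = reranker.rerank("query", ["doc1", "doc2", "doc3"])
```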
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 3f2d005688d7b0253060d087a9800ea8c1d2c7bbb6ff6c92cca3ebc238d99be3
+  data.tar.gz: 6296db628c2d13a39ef035fe45c41be39de2404e6ab72d9735109ad18879f65c
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 4511b6f96d1356101e10547f740702479c894fb6d1e6f8cb04213b49b624ac8dc73d83b8b91bbab396e095fd31bbb4dd019ca967ce7f790447a4b77dd25d3356
+  data.tar.gz: 5a1ef095e2bbd9967317e0c416fb018da2673e18d76e00c98177cdba5dd2f9c5723fc5404b0d4221800227aab8b74e5560d420aff79f94e216365e6d3cee6f1e
data/README.md
CHANGED
@@ -1,4 +1,4 @@
-
+<img src="/docs/assets/logo-title.png" alt="red-candle" height="80px">
 
 [](https://github.com/assaydepot/red-candle/actions/workflows/build.yml)
 [](https://badge.fury.io/rb/red-candle)
@@ -18,7 +18,7 @@ gem install red-candle
 require 'candle'
 
 # Download a model (one-time, ~650MB) - Mistral, Llama3, Gemma all work!
-llm = Candle::LLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
+llm = Candle::LLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
                                   gguf_file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")
 
 # Chat with it - no API calls, running locally in your Ruby process!
@@ -27,8 +27,8 @@ messages = [
 ]
 
 puts llm.chat(messages)
-# => "Ruby is a dynamic, object-oriented programming language known for its
-#    simplicity, elegance, and productivity, often used for web development
+# => "Ruby is a dynamic, object-oriented programming language known for its
+#    simplicity, elegance, and productivity, often used for web development
 #    with frameworks like Rails."
 ```
 
@@ -99,22 +99,16 @@ x = x.reshape([3, 2])
 require 'candle'
 
 # Default model (JinaBERT) on CPU
-model = Candle::EmbeddingModel.
+model = Candle::EmbeddingModel.from_pretrained
 embedding = model.embedding("Hi there!")
 
 # Specify device (CPU, Metal, or CUDA)
 device = Candle::Device.cpu # or Candle::Device.metal, Candle::Device.cuda
-model = Candle::EmbeddingModel.
-  model_path: "jinaai/jina-embeddings-v2-base-en",
-  device: device
-)
+model = Candle::EmbeddingModel.from_pretrained("jinaai/jina-embeddings-v2-base-en", device: device)
 embedding = model.embedding("Hi there!")
 
 # Reranker also supports device selection
-reranker = Candle::Reranker.
-  model_path: "cross-encoder/ms-marco-MiniLM-L-12-v2",
-  device: device
-)
+reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2", device: device)
 results = reranker.rerank("query", ["doc1", "doc2", "doc3"])
 ```
 
@@ -140,8 +134,8 @@ Red-Candle supports quantized models in GGUF format, offering 4-8x memory reduct
 
 ```ruby
 # Load quantized models - always specify the GGUF filename
-llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
-                                  device: device,
+llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
+                                  device: device,
                                   gguf_file: "llama-2-7b-chat.Q4_K_M.gguf")
 
 # Register custom tokenizer mappings for your models
@@ -155,7 +149,7 @@ Candle::LLM.register_tokenizer("my-org/my-model-GGUF", "my-org/my-tokenizer")
 **Memory usage comparison (7B models):**
 - Full precision: ~28 GB
 - Q8_0 (8-bit): ~7 GB - Best quality, larger size
-- Q5_K_M (5-bit): ~4.5 GB - Very good quality
+- Q5_K_M (5-bit): ~4.5 GB - Very good quality
 - Q4_K_M (4-bit): ~4 GB - Recommended default, best balance
 - Q3_K_M (3-bit): ~3 GB - Good for memory-constrained systems
 
@@ -169,13 +163,13 @@ Candle::LLM.register_tokenizer("my-org/my-model-GGUF", "my-org/my-tokenizer")
 > **Warning**: Q2_K quantization can lead to "weight is negative, too large or not a valid number" errors during inference. Use Q3_K_M or higher for stable operation.
 
 > ### ⚠️ Huggingface login warning
->
+>
 > Many models, including the one below, require you to agree to the terms. You'll need to:
 > 1. Login to [Huggingface](https://huggingface.co)
 > 2. Agree to the terms. For example: [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
 > 3. Authenticate your session. Simplest way is with `huggingface-cli login`. Detail here: [Huggingface CLI](https://huggingface.co/docs/huggingface_hub/en/guides/cli)
 >
-> More details here: [Huggingface Authentication](HUGGINGFACE.md)
+> More details here: [Huggingface Authentication](docs/HUGGINGFACE.md)
 
 ```ruby
 require 'candle'
@@ -208,7 +202,7 @@ response = llm.chat(messages)
 
 ### GPU Acceleration
 
-We see an 18x speed up running LLMs under CUDA vs CPU and a >3x speed up running under Metal vs CPU. Details [here](DEVICE_SUPPORT.md#performance-considerations).
+We see an 18x speed up running LLMs under CUDA vs CPU and a >3x speed up running under Metal vs CPU. Details [here](docs/DEVICE_SUPPORT.md#performance-considerations).
 
 ```ruby
 # CPU works for all models
@@ -216,7 +210,7 @@ device = Candle::Device.cpu
 llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device)
 
 # Metal
-device = Candle::Device.metal
+device = Candle::Device.metal
 
 # CUDA support (for NVIDIA GPUs)
 device = Candle::Device.cuda # Linux/Windows with NVIDIA GPU
@@ -325,9 +319,9 @@ The default model (`jinaai/jina-embeddings-v2-base-en` with the `sentence-transf
 ```ruby
 > require 'candle'
 # Ruby memory = 25.9 MB
-> model = Candle::EmbeddingModel.
+> model = Candle::EmbeddingModel.from_pretrained
 # Ruby memory = 3.50 GB
-> model2 = Candle::EmbeddingModel.
+> model2 = Candle::EmbeddingModel.from_pretrained
 # Ruby memory = 7.04 GB
 > model2 = nil
 > GC.start
@@ -353,7 +347,7 @@ And the following ruby:
 
 ```ruby
 require 'candle'
-model = Candle::EmbeddingModel.
+model = Candle::EmbeddingModel.from_pretrained
 embedding = model.embedding("Hi there!")
 ```
 
@@ -367,13 +361,13 @@ Red-Candle includes support for cross-encoder reranking models, which can be use
 require 'candle'
 
 # Initialize the reranker with a cross-encoder model
-reranker = Candle::Reranker.
+reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
 
 # Define your query and candidate documents
 query = "How many people live in London?"
 documents = [
   "London is known for its financial district",
-  "Around 9 Million people live in London",
+  "Around 9 Million people live in London",
   "The weather in London is often rainy",
   "London is the capital of England"
 ]
@@ -457,7 +451,7 @@ For faster inference on NVIDIA GPUs:
 
 ```ruby
 # Initialize with CUDA if available (falls back to CPU if not)
-reranker = Candle::Reranker.
+reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2", cuda: true)
 ```
 
 ### How It Works
@@ -501,7 +495,7 @@ tokens = tokenizer.encode_to_tokens("Hello, world!")
 
 # Get both IDs and tokens together
 result = tokenizer.encode_with_tokens("preprocessing")
-# => {"ids" => [101, 3653, 22618, 2527, 102],
+# => {"ids" => [101, 3653, 22618, 2527, 102],
 #    "tokens" => ["[CLS]", "prep", "##ro", "##ces", "##sing", "[SEP]"]}
 ```
 
@@ -563,11 +557,11 @@ llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 llm_tokenizer = llm.tokenizer
 
 # From EmbeddingModel
-embedding_model = Candle::EmbeddingModel.
+embedding_model = Candle::EmbeddingModel.from_pretrained
 emb_tokenizer = embedding_model.tokenizer
 
 # From Reranker
-reranker = Candle::Reranker.
+reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
 rank_tokenizer = reranker.tokenizer
 ```
 
@@ -578,7 +572,7 @@ Modern tokenizers split unknown or rare words into subword pieces:
 ```ruby
 # See how words are split into subwords
 result = tokenizer.encode_with_tokens("unbelievable")
-# => {"ids" => [101, 4895, 6499, 102],
+# => {"ids" => [101, 4895, 6499, 102],
 #    "tokens" => ["[CLS]", "un", "##believable", "[SEP]"]}
 
 # The ## prefix indicates a continuation of the previous token
@@ -589,7 +583,7 @@ complex = tokenizer.encode_to_tokens("preprocessing tokenization")
 ### Use Cases
 
 - **Token Analysis**: Understand how your text is being processed by models
-- **Debugging**: See why certain inputs might cause unexpected model behavior
+- **Debugging**: See why certain inputs might cause unexpected model behavior
 - **Custom Preprocessing**: Build your own text processing pipelines
 - **Educational**: Teach how modern NLP models handle text
 - **NER Preparation**: Get aligned tokens for named entity recognition tasks
@@ -616,7 +610,7 @@ text = "Apple Inc. was founded by Steve Jobs and Steve Wozniak in Cupertino, Cal
 entities = ner.extract_entities(text)
 
 entities.each do |entity|
-  puts "#{entity[
+  puts "#{entity[:text]} (#{entity[:label]}) - confidence: #{entity[:confidence].round(2)}"
 end
 # Output:
 # Apple Inc. (ORG) - confidence: 0.99
@@ -668,8 +662,8 @@ drug_recognizer = Candle::GazetteerEntityRecognizer.new("DRUG")
 drug_recognizer.load_from_file("drug_names.txt")
 
 # Case-sensitive matching
-product_recognizer = Candle::GazetteerEntityRecognizer.new("PRODUCT",
-  ["iPhone", "iPad", "MacBook"],
+product_recognizer = Candle::GazetteerEntityRecognizer.new("PRODUCT",
+  ["iPhone", "iPad", "MacBook"],
   case_sensitive: true
 )
 ```
@@ -686,7 +680,7 @@ hybrid = Candle::HybridNER.new("Babelscape/wikineural-multilingual-ner")
 hybrid.add_pattern_recognizer("EMAIL", [/\b[\w._%+-]+@[\w.-]+\.[A-Z|a-z]{2,}\b/])
 hybrid.add_pattern_recognizer("PHONE", [/\b\d{3}[-.]?\d{3}[-.]?\d{4}\b/])
 
-# Add gazetteer recognizers
+# Add gazetteer recognizers
 hybrid.add_gazetteer_recognizer("COMPANY", ["Apple", "Google", "Microsoft"])
 hybrid.add_gazetteer_recognizer("PRODUCT", ["iPhone", "Android", "Windows"])
 
@@ -739,7 +733,7 @@ ner = Candle::NER.from_pretrained("Babelscape/wikineural-multilingual-ner")
 # English NER (requires separate tokenizer)
 ner = Candle::NER.from_pretrained("dslim/bert-base-NER", tokenizer: "bert-base-cased")
 
-# Multilingual NER
+# Multilingual NER
 ner = Candle::NER.from_pretrained("Davlan/bert-base-multilingual-cased-ner-hrl")
 
 # OntoNotes 5 (18 entity types including DATE, TIME, MONEY, etc.)
@@ -758,9 +752,9 @@ ner = Candle::NER.from_pretrained("allenai/scibert_scivocab_uncased")
 ```
 
 2. **Batch Processing**: Process multiple texts together when possible
-
+
 3. **Confidence Threshold**: Balance precision/recall with appropriate thresholds
-
+
 4. **Entity Resolution**: The hybrid NER automatically handles overlapping entities
 
 ### Output Format
@@ -772,7 +766,7 @@ All NER methods return entities in a consistent format:
   "text" => "Apple Inc.", # The entity text
   "label" => "ORG", # Entity type
   "start" => 0, # Character start position
-  "end" => 10, # Character end position
+  "end" => 10, # Character end position
   "confidence" => 0.99, # Confidence score (0-1)
   "token_start" => 0, # Token start index (model-based only)
   "token_end" => 2, # Token end index (model-based only)
@@ -799,8 +793,8 @@ All NER methods return entities in a consistent format:
 - Q3_K_M (3-bit) - Minimum recommended quantization
 
 ```ruby
-llm = Candle::LLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
-                                  device: device,
+llm = Candle::LLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
+                                  device: device,
                                   gguf_file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")
 ```
 
@@ -817,7 +811,7 @@ Failed to load quantized model: cannot find tensor model.embed_tokens.weight (Ru
 1. Ensure you're using the latest version of red-candle (1.0.0 or higher)
 2. Make sure to specify the exact GGUF filename:
 ```ruby
-llm = Candle::LLM.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
+llm = Candle::LLM.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
                                   device: device,
                                   gguf_file: "mistral-7b-instruct-v0.2.Q4_K_M.gguf")
 ```
@@ -835,8 +829,8 @@ Failed to load quantized model: No GGUF file found in repository TheBloke/model-
 **Solution:** Specify the exact GGUF filename:
 ```ruby
 # Visit the HuggingFace repository to find the exact filename
-llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
-                                  device: device,
+llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
+                                  device: device,
                                   gguf_file: "llama-2-7b-chat.Q4_K_M.gguf")
 ```
 
@@ -864,7 +858,7 @@ Failed to load GGUF model: cannot find llama.attention.head_count in metadata (R
 
 **Cause:** Some GGUF files may have been created with older conversion tools that don't include all required metadata fields.
 
-**Solution:**
+**Solution:**
 - Try a different GGUF file from the same model
 - Look for GGUF files from TheBloke or other reputable sources
 - Check if a newer version of the GGUF file is available
data/Rakefile
CHANGED
@@ -1,22 +1,10 @@
 # frozen_string_literal: true
 
 require "bundler/gem_tasks"
-require "rake/testtask"
 require "rake/extensiontask"
+require "rspec/core/rake_task"
 
-task default: :
-Rake::TestTask.new do |t|
-  t.deps << :compile
-  t.libs << "test"
-  t.test_files = FileList["test/**/*_test.rb"]
-    .exclude("test/benchmarks/**/*_test.rb")
-    .exclude("test/llm/llm_test.rb")
-    .exclude("test/llm/gemma_test.rb")
-    .exclude("test/llm/mistral_test.rb")
-    .exclude("test/llm/llama_test.rb")
-    .exclude("test/llm/phi_test.rb")
-    .exclude("test/llm/qwen_test.rb")
-end
+task default: :spec
 
 spec = Bundler.load_gemspec("candle.gemspec")
 Rake::ExtensionTask.new("candle", spec) do |c|
@@ -33,80 +21,6 @@ Rake::ExtensionTask.new("candle", spec) do |c|
   ]
 end
 
-desc "Run device compatibility tests"
-Rake::TestTask.new("test:device") do |t|
-  t.deps << :compile
-  t.libs << "test"
-  t.test_files = FileList["test/device_compatibility_test.rb"]
-  t.verbose = true
-end
-
-desc "Run benchmark tests"
-Rake::TestTask.new("test:benchmark") do |t|
-  t.deps << :compile
-  t.libs << "test"
-  t.test_files = FileList["test/benchmarks/**/*_test.rb"]
-  t.verbose = true
-end
-
-desc "Run all tests including benchmarks"
-task "test:all" => [:test, "test:benchmark"]
-
-desc "Run tests on specific devices"
-namespace :test do
-  %w[cpu metal cuda].each do |device|
-    desc "Run tests on #{device.upcase} only"
-    task "device:#{device}" => :compile do
-      ENV['CANDLE_TEST_DEVICES'] = device
-      Rake::Task["test:device"].invoke
-    end
-  end
-end
-
-desc "Run benchmarks with device tests"
-task "test:device:benchmark" => :compile do
-  ENV['CANDLE_TEST_VERBOSE'] = 'true'
-  Rake::Task["test:device"].invoke
-  Rake::Task["test:benchmark"].invoke
-end
-
-desc "Run LLM tests for specific models"
-namespace :test do
-  namespace :llm do
-    desc "Run tests for Gemma models"
-    task :gemma => :compile do
-      ruby "-Itest", "test/llm/gemma_test.rb"
-    end
-
-    desc "Run tests for Phi models"
-    task :phi => :compile do
-      ruby "-Itest", "test/llm/phi_test.rb"
-    end
-
-    desc "Run tests for Qwen models"
-    task :qwen => :compile do
-      ruby "-Itest", "test/llm/qwen_test.rb"
-    end
-
-    desc "Run tests for Mistral models"
-    task :mistral => :compile do
-      ruby "-Itest", "test/llm/mistral_test.rb"
-    end
-
-    desc "Run tests for Llama models"
-    task :llama => :compile do
-      ruby "-Itest", "test/llm/llama_test.rb"
-    end
-
-    desc "Run tests for TinyLlama models"
-    task :tinyllama => :compile do
-      ruby "-Itest", "test/llm/tinyllama_test.rb"
-    end
-
-    desc "Run all LLM tests (WARNING: downloads large models)"
-    task :all => [:gemma, :phi, :qwen, :mistral, :llama]
-  end
-end
 
 namespace :doc do
   task default: %i[rustdoc yard]
@@ -174,3 +88,80 @@ end
 
 desc "Run Rust tests with coverage (alias)"
 task "coverage:rust" => "rust:coverage:html"
+
+# RSpec tasks
+desc "Run RSpec tests"
+RSpec::Core::RakeTask.new(:spec) do |t|
+  t.rspec_opts = "--format progress"
+end
+
+# Add compile as a dependency for spec task
+task spec: :compile
+
+namespace :spec do
+  desc "Run RSpec tests with all devices"
+  RSpec::Core::RakeTask.new(:device) do |t|
+    t.rspec_opts = "--format documentation --tag device"
+  end
+
+  desc "Run RSpec tests with coverage"
+  task :coverage do
+    ENV['COVERAGE'] = 'true'
+    Rake::Task["spec"].invoke
+  end
+
+  desc "Run RSpec tests in parallel (requires parallel_tests gem)"
+  task :parallel do
+    begin
+      require 'parallel_tests'
+      sh "parallel_rspec spec/"
+    rescue LoadError
+      puts "parallel_tests gem not installed. Run: gem install parallel_tests"
+    end
+  end
+
+  desc "Run specific device tests"
+  %w[cpu metal cuda].each do |device|
+    desc "Run tests on #{device.upcase} only"
+    task "device:#{device}" => :compile do
+      ENV['CANDLE_TEST_DEVICES'] = device
+      sh "rspec spec/device_compatibility_spec.rb --format documentation"
+    end
+  end
+
+  desc "Run LLM tests for specific models"
+  namespace :llm do
+    desc "Run tests for Gemma models"
+    task :gemma => :compile do
+      sh "rspec spec/llm/gemma_spec.rb --format documentation"
+    end
+
+    desc "Run tests for Phi models"
+    task :phi => :compile do
+      sh "rspec spec/llm/phi_spec.rb --format documentation"
+    end
+
+    desc "Run tests for Qwen models"
+    task :qwen => :compile do
+      sh "rspec spec/llm/qwen_spec.rb --format documentation"
+    end
+
+    desc "Run tests for Mistral models"
+    task :mistral => :compile do
+      sh "rspec spec/llm/mistral_spec.rb --format documentation"
+    end
+
+    desc "Run tests for Llama models"
+    task :llama => :compile do
+      sh "rspec spec/llm/llama_spec.rb --format documentation"
+    end
+
+    desc "Run tests for TinyLlama models"
+    task :tinyllama => :compile do
+      sh "rspec spec/llm/tinyllama_spec.rb --format documentation"
    end
+
+    desc "Run all LLM tests (WARNING: requires large models already downloaded)"
+    task :all => [:gemma, :phi, :qwen, :mistral, :llama, :tinyllama]
+  end
+end
data/ext/candle/src/lib.rs
CHANGED
@@ -4,8 +4,6 @@ use crate::ruby::candle_utils;
 use crate::ruby::Result;
 
 pub mod llm;
-pub mod ner;
-pub mod reranker;
 pub mod ruby;
 pub mod structured;
 pub mod tokenizer;
@@ -44,8 +42,8 @@ fn init(ruby: &Ruby) -> Result<()> {
 
     ruby::init_embedding_model(rb_candle)?;
    ruby::init_llm(rb_candle)?;
-    ner::init(rb_candle)?;
-    reranker::init(rb_candle)?;
+    ruby::ner::init(rb_candle)?;
+    ruby::reranker::init(rb_candle)?;
     ruby::dtype::init(rb_candle)?;
     ruby::device::init(rb_candle)?;
     ruby::tensor::init(rb_candle)?;
data/ext/candle/src/ruby/device.rs
CHANGED
@@ -53,6 +53,30 @@ pub fn default_device() -> Device {
     }
 }
 
+/// Get the best available device by checking runtime availability
+pub fn best_device() -> Device {
+    // Try devices in order of preference
+
+    #[cfg(feature = "metal")]
+    {
+        // Check if Metal is actually available at runtime
+        if CoreDevice::new_metal(0).is_ok() {
+            return Device::Metal;
+        }
+    }
+
+    #[cfg(feature = "cuda")]
+    {
+        // Check if CUDA is actually available at runtime
+        if CoreDevice::new_cuda(0).is_ok() {
+            return Device::Cuda;
+        }
+    }
+
+    // Always fall back to CPU
+    Device::Cpu
+}
+
 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
 #[magnus::wrap(class = "Candle::Device")]
 pub enum Device {
@@ -66,6 +90,11 @@ impl Device {
     pub fn cpu() -> Self {
         Self::Cpu
     }
+
+    /// Get the best available device
+    pub fn best() -> Self {
+        best_device()
+    }
 
     /// Create a CUDA device (GPU)
     pub fn cuda() -> Result<Self> {
@@ -195,6 +224,7 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     rb_device.define_singleton_method("metal", function!(Device::metal, 0))?;
     rb_device.define_singleton_method("available_devices", function!(available_devices, 0))?;
     rb_device.define_singleton_method("default", function!(default_device, 0))?;
+    rb_device.define_singleton_method("best", function!(best_device, 0))?;
     rb_device.define_method("to_s", method!(Device::__str__, 0))?;
     rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
     rb_device.define_method("==", method!(Device::__eq__, 1))?;