red-candle 1.1.2 → 1.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: cd566594df4f0d3ec8ecd592b1d71610ef0ba2091cd91ea53549405a8c5c9b18
-  data.tar.gz: cfaab403935e927371fbd2f605ea60da3f44effd64950bb132651a5e6437e30c
+  metadata.gz: 3f2d005688d7b0253060d087a9800ea8c1d2c7bbb6ff6c92cca3ebc238d99be3
+  data.tar.gz: 6296db628c2d13a39ef035fe45c41be39de2404e6ab72d9735109ad18879f65c
 SHA512:
-  metadata.gz: b18f384766db5a3de2f794764a8a47d3d69261ae4bf5c93992df7228fe9a5817eb1862b516a9ee3e63b5d349936f46534d50f1378b7932c7ac482550270a8bea
-  data.tar.gz: ce74834cde884bf03e3d163f6d081e87f95db0fba8db24300b8c6aa5b0f1f0162542c836bebb414010534778e394ca36faaec807ff1cc652902c3b318503def5
+  metadata.gz: 4511b6f96d1356101e10547f740702479c894fb6d1e6f8cb04213b49b624ac8dc73d83b8b91bbab396e095fd31bbb4dd019ca967ce7f790447a4b77dd25d3356
+  data.tar.gz: 5a1ef095e2bbd9967317e0c416fb018da2673e18d76e00c98177cdba5dd2f9c5723fc5404b0d4221800227aab8b74e5560d420aff79f94e216365e6d3cee6f1e
data/README.md CHANGED
@@ -1,4 +1,4 @@
-# `red-candle` Native LLMs for Ruby 🚀
+<img src="/docs/assets/logo-title.png" alt="red-candle" height="80px">
 
 [![build](https://github.com/assaydepot/red-candle/actions/workflows/build.yml/badge.svg)](https://github.com/assaydepot/red-candle/actions/workflows/build.yml)
 [![Gem Version](https://badge.fury.io/rb/red-candle.svg)](https://badge.fury.io/rb/red-candle)
@@ -18,7 +18,7 @@ gem install red-candle
 require 'candle'
 
 # Download a model (one-time, ~650MB) - Mistral, Llama3, Gemma all work!
-llm = Candle::LLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
+llm = Candle::LLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
                                   gguf_file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")
 
 # Chat with it - no API calls, running locally in your Ruby process!
@@ -27,8 +27,8 @@ messages = [
 ]
 
 puts llm.chat(messages)
-# => "Ruby is a dynamic, object-oriented programming language known for its
-#     simplicity, elegance, and productivity, often used for web development
+# => "Ruby is a dynamic, object-oriented programming language known for its
+#     simplicity, elegance, and productivity, often used for web development
 #     with frameworks like Rails."
 ```
 
@@ -99,22 +99,16 @@ x = x.reshape([3, 2])
 require 'candle'
 
 # Default model (JinaBERT) on CPU
-model = Candle::EmbeddingModel.new
+model = Candle::EmbeddingModel.from_pretrained
 embedding = model.embedding("Hi there!")
 
 # Specify device (CPU, Metal, or CUDA)
 device = Candle::Device.cpu # or Candle::Device.metal, Candle::Device.cuda
-model = Candle::EmbeddingModel.new(
-  model_path: "jinaai/jina-embeddings-v2-base-en",
-  device: device
-)
+model = Candle::EmbeddingModel.from_pretrained("jinaai/jina-embeddings-v2-base-en", device: device)
 embedding = model.embedding("Hi there!")
 
 # Reranker also supports device selection
-reranker = Candle::Reranker.new(
-  model_path: "cross-encoder/ms-marco-MiniLM-L-12-v2",
-  device: device
-)
+reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2", device: device)
 results = reranker.rerank("query", ["doc1", "doc2", "doc3"])
 ```
 
@@ -140,8 +134,8 @@ Red-Candle supports quantized models in GGUF format, offering 4-8x memory reduct
 
 ```ruby
 # Load quantized models - always specify the GGUF filename
-llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
-                                  device: device,
+llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
+                                  device: device,
                                   gguf_file: "llama-2-7b-chat.Q4_K_M.gguf")
 
 # Register custom tokenizer mappings for your models
@@ -155,7 +149,7 @@ Candle::LLM.register_tokenizer("my-org/my-model-GGUF", "my-org/my-tokenizer")
 **Memory usage comparison (7B models):**
 - Full precision: ~28 GB
 - Q8_0 (8-bit): ~7 GB - Best quality, larger size
-- Q5_K_M (5-bit): ~4.5 GB - Very good quality
+- Q5_K_M (5-bit): ~4.5 GB - Very good quality
 - Q4_K_M (4-bit): ~4 GB - Recommended default, best balance
 - Q3_K_M (3-bit): ~3 GB - Good for memory-constrained systems
 
@@ -169,13 +163,13 @@ Candle::LLM.register_tokenizer("my-org/my-model-GGUF", "my-org/my-tokenizer")
 > **Warning**: Q2_K quantization can lead to "weight is negative, too large or not a valid number" errors during inference. Use Q3_K_M or higher for stable operation.
 
 > ### ⚠️ Huggingface login warning
->
+>
 > Many models, including the one below, require you to agree to the terms. You'll need to:
 > 1. Login to [Huggingface](https://huggingface.co)
 > 2. Agree to the terms. For example: [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
 > 3. Authenticate your session. Simplest way is with `huggingface-cli login`. Detail here: [Huggingface CLI](https://huggingface.co/docs/huggingface_hub/en/guides/cli)
 >
-> More details here: [Huggingface Authentication](HUGGINGFACE.md)
+> More details here: [Huggingface Authentication](docs/HUGGINGFACE.md)
 
 ```ruby
 require 'candle'
@@ -208,7 +202,7 @@ response = llm.chat(messages)
 
 ### GPU Acceleration
 
-We see an 18x speed up running LLMs under CUDA vs CPU and a >3x speed up running under Metal vs CPU. Details [here](DEVICE_SUPPORT.md#performance-considerations).
+We see an 18x speed up running LLMs under CUDA vs CPU and a >3x speed up running under Metal vs CPU. Details [here](docs/DEVICE_SUPPORT.md#performance-considerations).
 
 ```ruby
 # CPU works for all models
@@ -216,7 +210,7 @@ device = Candle::Device.cpu
 llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device: device)
 
 # Metal
-device = Candle::Device.metal
+device = Candle::Device.metal
 
 # CUDA support (for NVIDIA GPUs)
 device = Candle::Device.cuda # Linux/Windows with NVIDIA GPU
@@ -325,9 +319,9 @@ The default model (`jinaai/jina-embeddings-v2-base-en` with the `sentence-transf
 ```ruby
 > require 'candle'
 # Ruby memory = 25.9 MB
-> model = Candle::EmbeddingModel.new
+> model = Candle::EmbeddingModel.from_pretrained
 # Ruby memory = 3.50 GB
-> model2 = Candle::EmbeddingModel.new
+> model2 = Candle::EmbeddingModel.from_pretrained
 # Ruby memory = 7.04 GB
 > model2 = nil
 > GC.start
@@ -353,7 +347,7 @@ And the following ruby:
 
 ```ruby
 require 'candle'
-model = Candle::EmbeddingModel.new
+model = Candle::EmbeddingModel.from_pretrained
 embedding = model.embedding("Hi there!")
 ```
 
@@ -367,13 +361,13 @@ Red-Candle includes support for cross-encoder reranking models, which can be use
 require 'candle'
 
 # Initialize the reranker with a cross-encoder model
-reranker = Candle::Reranker.new(model_path: "cross-encoder/ms-marco-MiniLM-L-12-v2")
+reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
 
 # Define your query and candidate documents
 query = "How many people live in London?"
 documents = [
   "London is known for its financial district",
-  "Around 9 Million people live in London",
+  "Around 9 Million people live in London",
   "The weather in London is often rainy",
   "London is the capital of England"
 ]
@@ -457,7 +451,7 @@ For faster inference on NVIDIA GPUs:
 
 ```ruby
 # Initialize with CUDA if available (falls back to CPU if not)
-reranker = Candle::Reranker.new(model_path: "cross-encoder/ms-marco-MiniLM-L-12-v2", cuda: true)
+reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2", cuda: true)
 ```
 
 ### How It Works
@@ -501,7 +495,7 @@ tokens = tokenizer.encode_to_tokens("Hello, world!")
 
 # Get both IDs and tokens together
 result = tokenizer.encode_with_tokens("preprocessing")
-# => {"ids" => [101, 3653, 22618, 2527, 102],
+# => {"ids" => [101, 3653, 22618, 2527, 102],
 #     "tokens" => ["[CLS]", "prep", "##ro", "##ces", "##sing", "[SEP]"]}
 ```
 
@@ -563,11 +557,11 @@ llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 llm_tokenizer = llm.tokenizer
 
 # From EmbeddingModel
-embedding_model = Candle::EmbeddingModel.new
+embedding_model = Candle::EmbeddingModel.from_pretrained
 emb_tokenizer = embedding_model.tokenizer
 
 # From Reranker
-reranker = Candle::Reranker.new(model_path: "cross-encoder/ms-marco-MiniLM-L-12-v2")
+reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
 rank_tokenizer = reranker.tokenizer
 ```
 
@@ -578,7 +572,7 @@ Modern tokenizers split unknown or rare words into subword pieces:
 ```ruby
 # See how words are split into subwords
 result = tokenizer.encode_with_tokens("unbelievable")
-# => {"ids" => [101, 4895, 6499, 102],
+# => {"ids" => [101, 4895, 6499, 102],
 #     "tokens" => ["[CLS]", "un", "##believable", "[SEP]"]}
 
 # The ## prefix indicates a continuation of the previous token
@@ -589,7 +583,7 @@ complex = tokenizer.encode_to_tokens("preprocessing tokenization")
 ### Use Cases
 
 - **Token Analysis**: Understand how your text is being processed by models
-- **Debugging**: See why certain inputs might cause unexpected model behavior
+- **Debugging**: See why certain inputs might cause unexpected model behavior
 - **Custom Preprocessing**: Build your own text processing pipelines
 - **Educational**: Teach how modern NLP models handle text
 - **NER Preparation**: Get aligned tokens for named entity recognition tasks
@@ -616,7 +610,7 @@ text = "Apple Inc. was founded by Steve Jobs and Steve Wozniak in Cupertino, Cal
 entities = ner.extract_entities(text)
 
 entities.each do |entity|
-  puts "#{entity['text']} (#{entity['label']}) - confidence: #{entity['confidence'].round(2)}"
+  puts "#{entity[:text]} (#{entity[:label]}) - confidence: #{entity[:confidence].round(2)}"
 end
 # Output:
 # Apple Inc. (ORG) - confidence: 0.99
@@ -668,8 +662,8 @@ drug_recognizer = Candle::GazetteerEntityRecognizer.new("DRUG")
 drug_recognizer.load_from_file("drug_names.txt")
 
 # Case-sensitive matching
-product_recognizer = Candle::GazetteerEntityRecognizer.new("PRODUCT",
-  ["iPhone", "iPad", "MacBook"],
+product_recognizer = Candle::GazetteerEntityRecognizer.new("PRODUCT",
+  ["iPhone", "iPad", "MacBook"],
   case_sensitive: true
 )
 ```
@@ -686,7 +680,7 @@ hybrid = Candle::HybridNER.new("Babelscape/wikineural-multilingual-ner")
 hybrid.add_pattern_recognizer("EMAIL", [/\b[\w._%+-]+@[\w.-]+\.[A-Z|a-z]{2,}\b/])
 hybrid.add_pattern_recognizer("PHONE", [/\b\d{3}[-.]?\d{3}[-.]?\d{4}\b/])
 
-# Add gazetteer recognizers
+# Add gazetteer recognizers
 hybrid.add_gazetteer_recognizer("COMPANY", ["Apple", "Google", "Microsoft"])
 hybrid.add_gazetteer_recognizer("PRODUCT", ["iPhone", "Android", "Windows"])
 
@@ -739,7 +733,7 @@ ner = Candle::NER.from_pretrained("Babelscape/wikineural-multilingual-ner")
 # English NER (requires separate tokenizer)
 ner = Candle::NER.from_pretrained("dslim/bert-base-NER", tokenizer: "bert-base-cased")
 
-# Multilingual NER
+# Multilingual NER
 ner = Candle::NER.from_pretrained("Davlan/bert-base-multilingual-cased-ner-hrl")
 
 # OntoNotes 5 (18 entity types including DATE, TIME, MONEY, etc.)
@@ -758,9 +752,9 @@ ner = Candle::NER.from_pretrained("allenai/scibert_scivocab_uncased")
 ```
 
 2. **Batch Processing**: Process multiple texts together when possible
-
+
 3. **Confidence Threshold**: Balance precision/recall with appropriate thresholds
-
+
 4. **Entity Resolution**: The hybrid NER automatically handles overlapping entities
 
 ### Output Format
@@ -772,7 +766,7 @@ All NER methods return entities in a consistent format:
   "text" => "Apple Inc.",      # The entity text
   "label" => "ORG",            # Entity type
   "start" => 0,                # Character start position
-  "end" => 10,                 # Character end position
+  "end" => 10,                 # Character end position
   "confidence" => 0.99,        # Confidence score (0-1)
   "token_start" => 0,          # Token start index (model-based only)
   "token_end" => 2,            # Token end index (model-based only)
@@ -799,8 +793,8 @@ All NER methods return entities in a consistent format:
 - Q3_K_M (3-bit) - Minimum recommended quantization
 
 ```ruby
-llm = Candle::LLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
-                                  device: device,
+llm = Candle::LLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
+                                  device: device,
                                   gguf_file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")
 ```
 
@@ -817,7 +811,7 @@ Failed to load quantized model: cannot find tensor model.embed_tokens.weight (Ru
 1. Ensure you're using the latest version of red-candle (1.0.0 or higher)
 2. Make sure to specify the exact GGUF filename:
 ```ruby
-llm = Candle::LLM.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
+llm = Candle::LLM.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
                                   device: device,
                                   gguf_file: "mistral-7b-instruct-v0.2.Q4_K_M.gguf")
 ```
@@ -835,8 +829,8 @@ Failed to load quantized model: No GGUF file found in repository TheBloke/model-
 **Solution:** Specify the exact GGUF filename:
 ```ruby
 # Visit the HuggingFace repository to find the exact filename
-llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
-                                  device: device,
+llm = Candle::LLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
+                                  device: device,
                                   gguf_file: "llama-2-7b-chat.Q4_K_M.gguf")
 ```
 
@@ -864,7 +858,7 @@ Failed to load GGUF model: cannot find llama.attention.head_count in metadata (R
 
 **Cause:** Some GGUF files may have been created with older conversion tools that don't include all required metadata fields.
 
-**Solution:**
+**Solution:**
 - Try a different GGUF file from the same model
 - Look for GGUF files from TheBloke or other reputable sources
 - Check if a newer version of the GGUF file is available
data/Rakefile CHANGED
@@ -1,22 +1,10 @@
 # frozen_string_literal: true
 
 require "bundler/gem_tasks"
-require "rake/testtask"
 require "rake/extensiontask"
+require "rspec/core/rake_task"
 
-task default: :test
-Rake::TestTask.new do |t|
-  t.deps << :compile
-  t.libs << "test"
-  t.test_files = FileList["test/**/*_test.rb"]
-    .exclude("test/benchmarks/**/*_test.rb")
-    .exclude("test/llm/llm_test.rb")
-    .exclude("test/llm/gemma_test.rb")
-    .exclude("test/llm/mistral_test.rb")
-    .exclude("test/llm/llama_test.rb")
-    .exclude("test/llm/phi_test.rb")
-    .exclude("test/llm/qwen_test.rb")
-end
+task default: :spec
 
 spec = Bundler.load_gemspec("candle.gemspec")
 Rake::ExtensionTask.new("candle", spec) do |c|
@@ -33,80 +21,6 @@ Rake::ExtensionTask.new("candle", spec) do |c|
   ]
 end
 
-desc "Run device compatibility tests"
-Rake::TestTask.new("test:device") do |t|
-  t.deps << :compile
-  t.libs << "test"
-  t.test_files = FileList["test/device_compatibility_test.rb"]
-  t.verbose = true
-end
-
-desc "Run benchmark tests"
-Rake::TestTask.new("test:benchmark") do |t|
-  t.deps << :compile
-  t.libs << "test"
-  t.test_files = FileList["test/benchmarks/**/*_test.rb"]
-  t.verbose = true
-end
-
-desc "Run all tests including benchmarks"
-task "test:all" => [:test, "test:benchmark"]
-
-desc "Run tests on specific devices"
-namespace :test do
-  %w[cpu metal cuda].each do |device|
-    desc "Run tests on #{device.upcase} only"
-    task "device:#{device}" => :compile do
-      ENV['CANDLE_TEST_DEVICES'] = device
-      Rake::Task["test:device"].invoke
-    end
-  end
-end
-
-desc "Run benchmarks with device tests"
-task "test:device:benchmark" => :compile do
-  ENV['CANDLE_TEST_VERBOSE'] = 'true'
-  Rake::Task["test:device"].invoke
-  Rake::Task["test:benchmark"].invoke
-end
-
-desc "Run LLM tests for specific models"
-namespace :test do
-  namespace :llm do
-    desc "Run tests for Gemma models"
-    task :gemma => :compile do
-      ruby "-Itest", "test/llm/gemma_test.rb"
-    end
-
-    desc "Run tests for Phi models"
-    task :phi => :compile do
-      ruby "-Itest", "test/llm/phi_test.rb"
-    end
-
-    desc "Run tests for Qwen models"
-    task :qwen => :compile do
-      ruby "-Itest", "test/llm/qwen_test.rb"
-    end
-
-    desc "Run tests for Mistral models"
-    task :mistral => :compile do
-      ruby "-Itest", "test/llm/mistral_test.rb"
-    end
-
-    desc "Run tests for Llama models"
-    task :llama => :compile do
-      ruby "-Itest", "test/llm/llama_test.rb"
-    end
-
-    desc "Run tests for TinyLlama models"
-    task :tinyllama => :compile do
-      ruby "-Itest", "test/llm/tinyllama_test.rb"
-    end
-
-    desc "Run all LLM tests (WARNING: downloads large models)"
-    task :all => [:gemma, :phi, :qwen, :mistral, :llama]
-  end
-end
 
 namespace :doc do
   task default: %i[rustdoc yard]
@@ -174,3 +88,80 @@ end
 
 desc "Run Rust tests with coverage (alias)"
 task "coverage:rust" => "rust:coverage:html"
+
+# RSpec tasks
+desc "Run RSpec tests"
+RSpec::Core::RakeTask.new(:spec) do |t|
+  t.rspec_opts = "--format progress"
+end
+
+# Add compile as a dependency for spec task
+task spec: :compile
+
+namespace :spec do
+  desc "Run RSpec tests with all devices"
+  RSpec::Core::RakeTask.new(:device) do |t|
+    t.rspec_opts = "--format documentation --tag device"
+  end
+
+  desc "Run RSpec tests with coverage"
+  task :coverage do
+    ENV['COVERAGE'] = 'true'
+    Rake::Task["spec"].invoke
+  end
+
+  desc "Run RSpec tests in parallel (requires parallel_tests gem)"
+  task :parallel do
+    begin
+      require 'parallel_tests'
+      sh "parallel_rspec spec/"
+    rescue LoadError
+      puts "parallel_tests gem not installed. Run: gem install parallel_tests"
+    end
+  end
+
+  desc "Run specific device tests"
+  %w[cpu metal cuda].each do |device|
+    desc "Run tests on #{device.upcase} only"
+    task "device:#{device}" => :compile do
+      ENV['CANDLE_TEST_DEVICES'] = device
+      sh "rspec spec/device_compatibility_spec.rb --format documentation"
+    end
+  end
+
+  desc "Run LLM tests for specific models"
+  namespace :llm do
+    desc "Run tests for Gemma models"
+    task :gemma => :compile do
+      sh "rspec spec/llm/gemma_spec.rb --format documentation"
+    end
+
+    desc "Run tests for Phi models"
+    task :phi => :compile do
+      sh "rspec spec/llm/phi_spec.rb --format documentation"
+    end
+
+    desc "Run tests for Qwen models"
+    task :qwen => :compile do
+      sh "rspec spec/llm/qwen_spec.rb --format documentation"
+    end
+
+    desc "Run tests for Mistral models"
+    task :mistral => :compile do
+      sh "rspec spec/llm/mistral_spec.rb --format documentation"
+    end
+
+    desc "Run tests for Llama models"
+    task :llama => :compile do
+      sh "rspec spec/llm/llama_spec.rb --format documentation"
+    end
+
+    desc "Run tests for TinyLlama models"
+    task :tinyllama => :compile do
+      sh "rspec spec/llm/tinyllama_spec.rb --format documentation"
+    end
+
+    desc "Run all LLM tests (WARNING: requires large models already downloaded)"
+    task :all => [:gemma, :phi, :qwen, :mistral, :llama, :tinyllama]
+  end
+end
@@ -4,8 +4,6 @@ use crate::ruby::candle_utils;
 use crate::ruby::Result;
 
 pub mod llm;
-pub mod ner;
-pub mod reranker;
 pub mod ruby;
 pub mod structured;
 pub mod tokenizer;
@@ -44,8 +42,8 @@ fn init(ruby: &Ruby) -> Result<()> {
 
     ruby::init_embedding_model(rb_candle)?;
     ruby::init_llm(rb_candle)?;
-    ner::init(rb_candle)?;
-    reranker::init(rb_candle)?;
+    ruby::ner::init(rb_candle)?;
+    ruby::reranker::init(rb_candle)?;
     ruby::dtype::init(rb_candle)?;
    ruby::device::init(rb_candle)?;
     ruby::tensor::init(rb_candle)?;
@@ -18,7 +18,7 @@ pub struct QuantizedGGUF {
     device: Device,
     model_id: String,
     eos_token_id: u32,
-    architecture: String,
+    pub architecture: String,
     _chat_template: Option<String>,
 }
 
@@ -53,6 +53,30 @@ pub fn default_device() -> Device {
     }
 }
 
+/// Get the best available device by checking runtime availability
+pub fn best_device() -> Device {
+    // Try devices in order of preference
+
+    #[cfg(feature = "metal")]
+    {
+        // Check if Metal is actually available at runtime
+        if CoreDevice::new_metal(0).is_ok() {
+            return Device::Metal;
+        }
+    }
+
+    #[cfg(feature = "cuda")]
+    {
+        // Check if CUDA is actually available at runtime
+        if CoreDevice::new_cuda(0).is_ok() {
+            return Device::Cuda;
+        }
+    }
+
+    // Always fall back to CPU
+    Device::Cpu
+}
+
 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
 #[magnus::wrap(class = "Candle::Device")]
 pub enum Device {
@@ -66,6 +90,11 @@ impl Device {
     pub fn cpu() -> Self {
         Self::Cpu
     }
+
+    /// Get the best available device
+    pub fn best() -> Self {
+        best_device()
+    }
 
     /// Create a CUDA device (GPU)
     pub fn cuda() -> Result<Self> {
@@ -195,6 +224,7 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     rb_device.define_singleton_method("metal", function!(Device::metal, 0))?;
     rb_device.define_singleton_method("available_devices", function!(available_devices, 0))?;
     rb_device.define_singleton_method("default", function!(default_device, 0))?;
+    rb_device.define_singleton_method("best", function!(best_device, 0))?;
     rb_device.define_method("to_s", method!(Device::__str__, 0))?;
     rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
     rb_device.define_method("==", method!(Device::__eq__, 1))?;