red-candle 1.0.2 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 7405c9911d6088106dd7a19e96312f12b86e9a80087c3d7745cd3911e263890a
4
- data.tar.gz: a88b75152708e72e019aba9acfeabd899f1ae1d1c567562ded6d2c6aa8eae8d0
3
+ metadata.gz: 5a2b4fac15c2c261d8fea90d34973605209043e2b2a222f82414bf94bc5c47e8
4
+ data.tar.gz: daff4398f34170e20744ab2ee8b1abb5b977ffec15a129d678efced7c9649495
5
5
  SHA512:
6
- metadata.gz: ce1cc52dc1223968f3398ab0972283a6309a80d306c14193f23336cd36ed55c8fa5eaaaf05d756f76c88e442abe19b0d82d2742a49199930ef6effcffd6d4482
7
- data.tar.gz: 8c30f3c0c096f8186b219a9a5d0fe92928621126f545eaabae26ddedc843515b7da8c45890a1f24a7d519c0c68fcd55b0553ea73f977644258ff30c5e5ccd2f1
6
+ metadata.gz: 4391a7fb4072d9ac174bcecdb366c975c5e487f43fcc4b0db75533aa44c94822d38103a633eba0b098ac26cf31313e9dd7da77255ed83692584e6936cca86271
7
+ data.tar.gz: da7a15a86fea349069079537b8c6f6696079842d205168aa908644a59528d3ecf5922548d8b3d153494f0beb3fe6221b30110368f1f1289f0ac1ff4856f6b243
data/Cargo.lock CHANGED
@@ -1,6 +1,6 @@
1
1
  # This file is automatically @generated by Cargo.
2
2
  # It is not intended for manual editing.
3
- version = 3
3
+ version = 4
4
4
 
5
5
  [[package]]
6
6
  name = "accelerate-src"
@@ -121,6 +121,26 @@ version = "0.22.1"
121
121
  source = "registry+https://github.com/rust-lang/crates.io-index"
122
122
  checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
123
123
 
124
+ [[package]]
125
+ name = "bincode"
126
+ version = "2.0.1"
127
+ source = "registry+https://github.com/rust-lang/crates.io-index"
128
+ checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740"
129
+ dependencies = [
130
+ "bincode_derive",
131
+ "serde",
132
+ "unty",
133
+ ]
134
+
135
+ [[package]]
136
+ name = "bincode_derive"
137
+ version = "2.0.1"
138
+ source = "registry+https://github.com/rust-lang/crates.io-index"
139
+ checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09"
140
+ dependencies = [
141
+ "virtue",
142
+ ]
143
+
124
144
  [[package]]
125
145
  name = "bindgen"
126
146
  version = "0.69.5"
@@ -136,7 +156,7 @@ dependencies = [
136
156
  "proc-macro2",
137
157
  "quote",
138
158
  "regex",
139
- "rustc-hash",
159
+ "rustc-hash 1.1.0",
140
160
  "shlex",
141
161
  "syn",
142
162
  ]
@@ -255,13 +275,14 @@ dependencies = [
255
275
  "candle-nn",
256
276
  "candle-transformers",
257
277
  "half",
258
- "hf-hub",
278
+ "hf-hub 0.4.3",
259
279
  "magnus",
280
+ "outlines-core",
260
281
  "rand 0.8.5",
261
282
  "safetensors 0.3.3",
262
283
  "serde",
263
284
  "serde_json",
264
- "tokenizers",
285
+ "tokenizers 0.21.2",
265
286
  "tokio",
266
287
  ]
267
288
 
@@ -641,6 +662,15 @@ dependencies = [
641
662
  "dirs-sys 0.4.1",
642
663
  ]
643
664
 
665
+ [[package]]
666
+ name = "dirs"
667
+ version = "5.0.1"
668
+ source = "registry+https://github.com/rust-lang/crates.io-index"
669
+ checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
670
+ dependencies = [
671
+ "dirs-sys 0.4.1",
672
+ ]
673
+
644
674
  [[package]]
645
675
  name = "dirs"
646
676
  version = "6.0.0"
@@ -1303,13 +1333,30 @@ version = "0.5.2"
1303
1333
  source = "registry+https://github.com/rust-lang/crates.io-index"
1304
1334
  checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
1305
1335
 
1336
+ [[package]]
1337
+ name = "hf-hub"
1338
+ version = "0.3.2"
1339
+ source = "registry+https://github.com/rust-lang/crates.io-index"
1340
+ checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
1341
+ dependencies = [
1342
+ "dirs 5.0.1",
1343
+ "indicatif",
1344
+ "log",
1345
+ "native-tls",
1346
+ "rand 0.8.5",
1347
+ "serde",
1348
+ "serde_json",
1349
+ "thiserror 1.0.69",
1350
+ "ureq",
1351
+ ]
1352
+
1306
1353
  [[package]]
1307
1354
  name = "hf-hub"
1308
1355
  version = "0.4.3"
1309
1356
  source = "registry+https://github.com/rust-lang/crates.io-index"
1310
1357
  checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97"
1311
1358
  dependencies = [
1312
- "dirs",
1359
+ "dirs 6.0.0",
1313
1360
  "futures",
1314
1361
  "http",
1315
1362
  "indicatif",
@@ -1605,6 +1652,12 @@ dependencies = [
1605
1652
  "web-time",
1606
1653
  ]
1607
1654
 
1655
+ [[package]]
1656
+ name = "indoc"
1657
+ version = "2.0.6"
1658
+ source = "registry+https://github.com/rust-lang/crates.io-index"
1659
+ checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
1660
+
1608
1661
  [[package]]
1609
1662
  name = "intel-mkl-src"
1610
1663
  version = "0.8.1"
@@ -1654,6 +1707,15 @@ dependencies = [
1654
1707
  "serde",
1655
1708
  ]
1656
1709
 
1710
+ [[package]]
1711
+ name = "itertools"
1712
+ version = "0.11.0"
1713
+ source = "registry+https://github.com/rust-lang/crates.io-index"
1714
+ checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
1715
+ dependencies = [
1716
+ "either",
1717
+ ]
1718
+
1657
1719
  [[package]]
1658
1720
  name = "itertools"
1659
1721
  version = "0.12.1"
@@ -1815,6 +1877,15 @@ dependencies = [
1815
1877
  "stable_deref_trait",
1816
1878
  ]
1817
1879
 
1880
+ [[package]]
1881
+ name = "memoffset"
1882
+ version = "0.9.1"
1883
+ source = "registry+https://github.com/rust-lang/crates.io-index"
1884
+ checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
1885
+ dependencies = [
1886
+ "autocfg",
1887
+ ]
1888
+
1818
1889
  [[package]]
1819
1890
  name = "metal"
1820
1891
  version = "0.27.0"
@@ -2186,6 +2257,25 @@ version = "0.2.0"
2186
2257
  source = "registry+https://github.com/rust-lang/crates.io-index"
2187
2258
  checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
2188
2259
 
2260
+ [[package]]
2261
+ name = "outlines-core"
2262
+ version = "0.2.3"
2263
+ source = "registry+https://github.com/rust-lang/crates.io-index"
2264
+ checksum = "4f0964d94d3e2322d2c0bbf80549affe085e2c6df08cf6c06e8c558988bcb11b"
2265
+ dependencies = [
2266
+ "bincode",
2267
+ "hf-hub 0.3.2",
2268
+ "once_cell",
2269
+ "regex",
2270
+ "regex-automata",
2271
+ "rustc-hash 2.1.1",
2272
+ "serde",
2273
+ "serde-pyobject",
2274
+ "serde_json",
2275
+ "thiserror 2.0.12",
2276
+ "tokenizers 0.20.3",
2277
+ ]
2278
+
2189
2279
  [[package]]
2190
2280
  name = "paste"
2191
2281
  version = "1.0.15"
@@ -2306,6 +2396,69 @@ dependencies = [
2306
2396
  "version_check",
2307
2397
  ]
2308
2398
 
2399
+ [[package]]
2400
+ name = "pyo3"
2401
+ version = "0.22.6"
2402
+ source = "registry+https://github.com/rust-lang/crates.io-index"
2403
+ checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884"
2404
+ dependencies = [
2405
+ "cfg-if",
2406
+ "indoc",
2407
+ "libc",
2408
+ "memoffset",
2409
+ "once_cell",
2410
+ "portable-atomic",
2411
+ "pyo3-build-config",
2412
+ "pyo3-ffi",
2413
+ "pyo3-macros",
2414
+ "unindent",
2415
+ ]
2416
+
2417
+ [[package]]
2418
+ name = "pyo3-build-config"
2419
+ version = "0.22.6"
2420
+ source = "registry+https://github.com/rust-lang/crates.io-index"
2421
+ checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38"
2422
+ dependencies = [
2423
+ "once_cell",
2424
+ "target-lexicon",
2425
+ ]
2426
+
2427
+ [[package]]
2428
+ name = "pyo3-ffi"
2429
+ version = "0.22.6"
2430
+ source = "registry+https://github.com/rust-lang/crates.io-index"
2431
+ checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636"
2432
+ dependencies = [
2433
+ "libc",
2434
+ "pyo3-build-config",
2435
+ ]
2436
+
2437
+ [[package]]
2438
+ name = "pyo3-macros"
2439
+ version = "0.22.6"
2440
+ source = "registry+https://github.com/rust-lang/crates.io-index"
2441
+ checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453"
2442
+ dependencies = [
2443
+ "proc-macro2",
2444
+ "pyo3-macros-backend",
2445
+ "quote",
2446
+ "syn",
2447
+ ]
2448
+
2449
+ [[package]]
2450
+ name = "pyo3-macros-backend"
2451
+ version = "0.22.6"
2452
+ source = "registry+https://github.com/rust-lang/crates.io-index"
2453
+ checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe"
2454
+ dependencies = [
2455
+ "heck",
2456
+ "proc-macro2",
2457
+ "pyo3-build-config",
2458
+ "quote",
2459
+ "syn",
2460
+ ]
2461
+
2309
2462
  [[package]]
2310
2463
  name = "quote"
2311
2464
  version = "1.0.40"
@@ -2418,6 +2571,17 @@ dependencies = [
2418
2571
  "rayon-core",
2419
2572
  ]
2420
2573
 
2574
+ [[package]]
2575
+ name = "rayon-cond"
2576
+ version = "0.3.0"
2577
+ source = "registry+https://github.com/rust-lang/crates.io-index"
2578
+ checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9"
2579
+ dependencies = [
2580
+ "either",
2581
+ "itertools 0.11.0",
2582
+ "rayon",
2583
+ ]
2584
+
2421
2585
  [[package]]
2422
2586
  name = "rayon-cond"
2423
2587
  version = "0.4.0"
@@ -2604,6 +2768,12 @@ version = "1.1.0"
2604
2768
  source = "registry+https://github.com/rust-lang/crates.io-index"
2605
2769
  checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
2606
2770
 
2771
+ [[package]]
2772
+ name = "rustc-hash"
2773
+ version = "2.1.1"
2774
+ source = "registry+https://github.com/rust-lang/crates.io-index"
2775
+ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
2776
+
2607
2777
  [[package]]
2608
2778
  name = "rustix"
2609
2779
  version = "1.0.7"
@@ -2740,6 +2910,16 @@ dependencies = [
2740
2910
  "serde_derive",
2741
2911
  ]
2742
2912
 
2913
+ [[package]]
2914
+ name = "serde-pyobject"
2915
+ version = "0.4.0"
2916
+ source = "registry+https://github.com/rust-lang/crates.io-index"
2917
+ checksum = "ca4b0aad8b225845739a0030a0d5cc2ae949c56a86a7daf9226c7df7c2016d16"
2918
+ dependencies = [
2919
+ "pyo3",
2920
+ "serde",
2921
+ ]
2922
+
2743
2923
  [[package]]
2744
2924
  name = "serde_derive"
2745
2925
  version = "1.0.219"
@@ -2757,6 +2937,7 @@ version = "1.0.140"
2757
2937
  source = "registry+https://github.com/rust-lang/crates.io-index"
2758
2938
  checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
2759
2939
  dependencies = [
2940
+ "indexmap",
2760
2941
  "itoa",
2761
2942
  "memchr",
2762
2943
  "ryu",
@@ -2995,6 +3176,12 @@ dependencies = [
2995
3176
  "xattr",
2996
3177
  ]
2997
3178
 
3179
+ [[package]]
3180
+ name = "target-lexicon"
3181
+ version = "0.12.16"
3182
+ source = "registry+https://github.com/rust-lang/crates.io-index"
3183
+ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
3184
+
2998
3185
  [[package]]
2999
3186
  name = "tempfile"
3000
3187
  version = "3.20.0"
@@ -3058,6 +3245,39 @@ dependencies = [
3058
3245
  "zerovec",
3059
3246
  ]
3060
3247
 
3248
+ [[package]]
3249
+ name = "tokenizers"
3250
+ version = "0.20.3"
3251
+ source = "registry+https://github.com/rust-lang/crates.io-index"
3252
+ checksum = "67b67c92f6d705e2a1d106fb0b28c696f9074901a9c656ee5d9f5de204c39bf7"
3253
+ dependencies = [
3254
+ "aho-corasick",
3255
+ "derive_builder",
3256
+ "esaxx-rs",
3257
+ "getrandom 0.2.16",
3258
+ "hf-hub 0.3.2",
3259
+ "indicatif",
3260
+ "itertools 0.12.1",
3261
+ "lazy_static",
3262
+ "log",
3263
+ "macro_rules_attribute",
3264
+ "monostate",
3265
+ "onig",
3266
+ "paste",
3267
+ "rand 0.8.5",
3268
+ "rayon",
3269
+ "rayon-cond 0.3.0",
3270
+ "regex",
3271
+ "regex-syntax",
3272
+ "serde",
3273
+ "serde_json",
3274
+ "spm_precompiled",
3275
+ "thiserror 1.0.69",
3276
+ "unicode-normalization-alignments",
3277
+ "unicode-segmentation",
3278
+ "unicode_categories",
3279
+ ]
3280
+
3061
3281
  [[package]]
3062
3282
  name = "tokenizers"
3063
3283
  version = "0.21.2"
@@ -3081,7 +3301,7 @@ dependencies = [
3081
3301
  "paste",
3082
3302
  "rand 0.9.1",
3083
3303
  "rayon",
3084
- "rayon-cond",
3304
+ "rayon-cond 0.4.0",
3085
3305
  "regex",
3086
3306
  "regex-syntax",
3087
3307
  "serde",
@@ -3365,12 +3585,24 @@ version = "0.1.1"
3365
3585
  source = "registry+https://github.com/rust-lang/crates.io-index"
3366
3586
  checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
3367
3587
 
3588
+ [[package]]
3589
+ name = "unindent"
3590
+ version = "0.2.4"
3591
+ source = "registry+https://github.com/rust-lang/crates.io-index"
3592
+ checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
3593
+
3368
3594
  [[package]]
3369
3595
  name = "untrusted"
3370
3596
  version = "0.9.0"
3371
3597
  source = "registry+https://github.com/rust-lang/crates.io-index"
3372
3598
  checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
3373
3599
 
3600
+ [[package]]
3601
+ name = "unty"
3602
+ version = "0.0.4"
3603
+ source = "registry+https://github.com/rust-lang/crates.io-index"
3604
+ checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae"
3605
+
3374
3606
  [[package]]
3375
3607
  name = "ureq"
3376
3608
  version = "2.12.1"
@@ -3431,6 +3663,12 @@ version = "0.9.5"
3431
3663
  source = "registry+https://github.com/rust-lang/crates.io-index"
3432
3664
  checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
3433
3665
 
3666
+ [[package]]
3667
+ name = "virtue"
3668
+ version = "0.0.18"
3669
+ source = "registry+https://github.com/rust-lang/crates.io-index"
3670
+ checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1"
3671
+
3434
3672
  [[package]]
3435
3673
  name = "walkdir"
3436
3674
  version = "2.5.0"
data/README.md CHANGED
@@ -3,7 +3,7 @@
3
3
  [![build](https://github.com/assaydepot/red-candle/actions/workflows/build.yml/badge.svg)](https://github.com/assaydepot/red-candle/actions/workflows/build.yml)
4
4
  [![Gem Version](https://badge.fury.io/rb/red-candle.svg)](https://badge.fury.io/rb/red-candle)
5
5
 
6
- Run state-of-the-art **language models directly from Ruby**. No Python, no APIs, no external services - just Ruby with blazing-fast Rust under the hood. Hardware accelerated with **Metal (Mac)** and **CUDA (NVIDIA).**
6
+ Run state-of-the-art **language models directly from Ruby**. No Python, no APIs, no external services - just Ruby with blazing-fast Rust under the hood. Hardware accelerated with **Metal (Mac)** and **CUDA (NVIDIA).** Red Candle leverages the Rust ecosystem, notably [Candle](https://github.com/huggingface/candle) and [Magnus](https://github.com/matsadler/magnus), to provide a fast and efficient way to run LLMs in Ruby. See [Dependencies](#dependencies) for more.
7
7
 
8
8
  ## Install & Chat in 30 Seconds
9
9
 
@@ -126,6 +126,10 @@ Red-Candle now supports Large Language Models (LLMs) with GPU acceleration!
126
126
  - **Gemma**: Google's Gemma models (e.g., `google/gemma-2b`, `google/gemma-7b`, `google/gemma-2b-it`)
127
127
  - **Llama**: Llama 2 and Llama 3 models (e.g., `TinyLlama/TinyLlama-1.1B-Chat-v1.0`, `meta-llama/Llama-2-7b-hf`, `NousResearch/Llama-2-7b-hf`)
128
128
  - **Mistral**: All Mistral models (e.g., `mistralai/Mistral-7B-Instruct-v0.1`)
129
+ - **Qwen**: Qwen 2 and 2.5 models (e.g., `Qwen/Qwen2-1.5B`, `Qwen/Qwen2.5-7B-Instruct`)
130
+ - **Phi**: Microsoft's Phi-2, Phi-3, Phi-3.5, and Phi-4 models (e.g., `microsoft/phi-2`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/phi-4`)
131
+ - ⚠️ ⚠️ ⚠️ Note: Phi-3 and Phi-4 GGUF models have a known issue with KV cache persistence between generations. The `reset_cache` parameter doesn't work for GGUF models. As a workaround, recreate the model instance for each generation.
132
+ - Related `candle` pull request for Phi-3 GGUF models: https://github.com/huggingface/candle/pull/2937
129
133
 
130
134
  ### Quantized Model Support (GGUF)
131
135
 
@@ -244,6 +248,36 @@ This is particularly useful for:
244
248
  - Troubleshooting generation problems
245
249
  - Analyzing model behavior
246
250
 
251
+ ## Structured Generation
252
+
253
+ Red Candle supports structured generation to constrain LLM outputs to follow specific patterns like JSON schemas or regular expressions:
254
+
255
+ ```ruby
256
+ # Define a JSON schema
257
+ schema = {
258
+ type: "object",
259
+ properties: {
260
+ answer: { type: "string", enum: ["yes", "no"] },
261
+ confidence: { type: "number", minimum: 0, maximum: 1 }
262
+ },
263
+ required: ["answer"]
264
+ }
265
+
266
+ # Generate and parse in one step
267
+ result = llm.generate_structured("Is Ruby easy to learn?", schema: schema)
268
+ puts result["answer"] # "yes"
269
+ puts result["confidence"] # 0.9
270
+
271
+ # Or use regex patterns for non-JSON outputs
272
+ phone_constraint = llm.constraint_from_regex('\d{3}-\d{3}-\d{4}')
273
+ config = Candle::GenerationConfig.balanced(constraint: phone_constraint)
274
+ phone = llm.generate("Generate a phone number:", config: config)
275
+ ```
276
+
277
+ See [STRUCTURED_GENERATION.md](docs/STRUCTURED_GENERATION.md) for detailed documentation.
278
+
279
+ **Note on Reliability**: Structured generation constrains the model's output tokens, but success rates vary by model size and schema complexity. Smaller models (< 7B parameters) may occasionally produce incomplete or invalid JSON, especially with complex schemas. Consider implementing retry logic or fallback strategies in production applications. Larger models generally perform much better with structured generation.
280
+
247
281
  ## ⚠️ Model Format Requirements
248
282
 
249
283
  ### EmbeddingModels and Rerankers: Safetensors Only
@@ -861,7 +895,7 @@ Pull requests are welcome.
861
895
  4. `git push --follow-tags`
862
896
  5. `gem push pkg/red-candle-VERSION_NUMBER.gem`
863
897
 
864
- ## See Also
898
+ ## Dependencies
865
899
 
866
900
  - [Candle](https://github.com/huggingface/candle)
867
901
  - [Magnus](https://github.com/matsadler/magnus)
data/Rakefile CHANGED
@@ -8,7 +8,14 @@ task default: :test
8
8
  Rake::TestTask.new do |t|
9
9
  t.deps << :compile
10
10
  t.libs << "test"
11
- t.test_files = FileList["test/**/*_test.rb"].exclude("test/benchmarks/**/*_test.rb")
11
+ t.test_files = FileList["test/**/*_test.rb"]
12
+ .exclude("test/benchmarks/**/*_test.rb")
13
+ .exclude("test/llm/llm_test.rb")
14
+ .exclude("test/llm/gemma_test.rb")
15
+ .exclude("test/llm/mistral_test.rb")
16
+ .exclude("test/llm/llama_test.rb")
17
+ .exclude("test/llm/phi_test.rb")
18
+ .exclude("test/llm/qwen_test.rb")
12
19
  end
13
20
 
14
21
  spec = Bundler.load_gemspec("candle.gemspec")
@@ -63,6 +70,44 @@ task "test:device:benchmark" => :compile do
63
70
  Rake::Task["test:benchmark"].invoke
64
71
  end
65
72
 
73
+ desc "Run LLM tests for specific models"
74
+ namespace :test do
75
+ namespace :llm do
76
+ desc "Run tests for Gemma models"
77
+ task :gemma => :compile do
78
+ ruby "-Itest", "test/llm/gemma_test.rb"
79
+ end
80
+
81
+ desc "Run tests for Phi models"
82
+ task :phi => :compile do
83
+ ruby "-Itest", "test/llm/phi_test.rb"
84
+ end
85
+
86
+ desc "Run tests for Qwen models"
87
+ task :qwen => :compile do
88
+ ruby "-Itest", "test/llm/qwen_test.rb"
89
+ end
90
+
91
+ desc "Run tests for Mistral models"
92
+ task :mistral => :compile do
93
+ ruby "-Itest", "test/llm/mistral_test.rb"
94
+ end
95
+
96
+ desc "Run tests for Llama models"
97
+ task :llama => :compile do
98
+ ruby "-Itest", "test/llm/llama_test.rb"
99
+ end
100
+
101
+ desc "Run tests for TinyLlama models"
102
+ task :tinyllama => :compile do
103
+ ruby "-Itest", "test/llm/tinyllama_test.rb"
104
+ end
105
+
106
+ desc "Run all LLM tests (WARNING: downloads large models)"
107
+ task :all => [:gemma, :phi, :qwen, :mistral, :llama]
108
+ end
109
+ end
110
+
66
111
  namespace :doc do
67
112
  task default: %i[rustdoc yard]
68
113
 
@@ -3,6 +3,7 @@ name = "candle"
3
3
  version = "0.1.0"
4
4
  edition = "2021"
5
5
  build = "build.rs"
6
+ rust-version = "1.85"
6
7
 
7
8
  [lib]
8
9
  crate-type = ["cdylib"]
@@ -20,6 +21,7 @@ serde_json = "1.0"
20
21
  serde = { version = "1.0", features = ["derive"] }
21
22
  tokio = { version = "1.45", features = ["rt", "macros"] }
22
23
  rand = "0.8"
24
+ outlines-core = "0.2"
23
25
 
24
26
  [features]
25
27
  default = []
@@ -7,6 +7,7 @@ pub mod llm;
7
7
  pub mod ner;
8
8
  pub mod reranker;
9
9
  pub mod ruby;
10
+ pub mod structured;
10
11
  pub mod tokenizer;
11
12
 
12
13
  // Configuration detection from build.rs
@@ -49,6 +50,7 @@ fn init(ruby: &Ruby) -> Result<()> {
49
50
  ruby::device::init(rb_candle)?;
50
51
  ruby::tensor::init(rb_candle)?;
51
52
  ruby::tokenizer::init(rb_candle)?;
53
+ ruby::structured::init_structured(rb_candle)?;
52
54
  candle_utils(rb_candle)?;
53
55
 
54
56
  Ok(())
@@ -0,0 +1,123 @@
1
+ #[cfg(test)]
2
+ mod constrained_generation_tests {
3
+ use super::super::*;
4
+ use crate::structured::{VocabularyAdapter, SchemaProcessor};
5
+ use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
6
+
7
+ #[tokio::test]
8
+ async fn test_constrained_vs_unconstrained_generation() {
9
+ // This test demonstrates the difference between constrained and unconstrained generation
10
+
11
+ // Load a tokenizer for testing
12
+ if let Ok(tokenizer) = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await {
13
+ let wrapper = TokenizerWrapper::new(tokenizer);
14
+
15
+ // Create vocabulary adapter
16
+ let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
17
+ .expect("Should create vocabulary");
18
+
19
+ // Create schema processor
20
+ let processor = SchemaProcessor::new();
21
+
22
+ // Define a simple JSON schema for a yes/no response
23
+ let schema = r#"{
24
+ "type": "object",
25
+ "properties": {
26
+ "answer": {
27
+ "type": "string",
28
+ "enum": ["yes", "no"]
29
+ }
30
+ },
31
+ "required": ["answer"]
32
+ }"#;
33
+
34
+ // Process schema into Index
35
+ let index = processor.process_schema(schema, &vocabulary)
36
+ .expect("Should process schema");
37
+
38
+ // Test configuration with constraint
39
+ let mut config_with_constraint = GenerationConfig::default();
40
+ config_with_constraint.constraint = Some(index.clone());
41
+ config_with_constraint.max_length = 50;
42
+
43
+ // Test configuration without constraint
44
+ let config_without_constraint = GenerationConfig::default();
45
+
46
+ // Create text generation instances
47
+ let mut gen_constrained = TextGeneration::from_config(&config_with_constraint);
48
+ let mut gen_unconstrained = TextGeneration::from_config(&config_without_constraint);
49
+
50
+ // Set EOS token
51
+ gen_constrained.set_eos_token_id(102); // BERT's [SEP] token
52
+ gen_unconstrained.set_eos_token_id(102);
53
+
54
+ // Constraints are set internally - we can't directly verify them
55
+ // but we can test their effects in actual generation
56
+ }
57
+ }
58
+
59
+ #[test]
60
+ fn test_constraint_configuration() {
61
+ // Smoke test: build a TextGeneration from a default config (no constraint is set here)
62
+ let config = GenerationConfig::default();
63
+ let _text_gen = TextGeneration::from_config(&config);
64
+
65
+ // Test that we can create a TextGeneration from config
66
+ // Constraints are private implementation details
67
+ }
68
+
69
+ #[test]
70
+ fn test_repetition_penalty() {
71
+ use candle_core::{Tensor, Device};
72
+
73
+ let device = Device::Cpu;
74
+ let vocab_size = 10;
75
+
76
+ // Create logits with some positive and negative values
77
+ let logits_vec: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0, 0.0, 3.0, -3.0, 1.5, -1.5, 0.5];
78
+ let mut logits = Tensor::from_vec(logits_vec.clone(), vocab_size, &device).unwrap();
79
+
80
+ // Create text generation with some tokens
81
+ let mut text_gen = TextGeneration::new(42, Some(1.0), None, None, 1.0, 64);
82
+ text_gen.push_token(0); // Token that had logit 1.0
83
+ text_gen.push_token(2); // Token that had logit 2.0
84
+ text_gen.push_token(5); // Token that had logit 3.0
85
+
86
+ // Apply repetition penalty
87
+ text_gen.apply_repetition_penalty(&mut logits, 1.5, 10).unwrap();
88
+
89
+ let penalized = logits.to_vec1::<f32>().unwrap();
90
+
91
+ // Check that tokens in context were penalized
92
+ assert!(penalized[0] < logits_vec[0], "Positive logit should be reduced");
93
+ assert!(penalized[2] < logits_vec[2], "Positive logit should be reduced");
94
+ assert!(penalized[5] < logits_vec[5], "Positive logit should be reduced");
95
+
96
+ // Check that other tokens remain unchanged
97
+ assert_eq!(penalized[1], logits_vec[1], "Unsampled token should be unchanged");
98
+ assert_eq!(penalized[3], logits_vec[3], "Unsampled token should be unchanged");
99
+ }
100
+
101
+ #[test]
102
+ fn test_stop_conditions() {
103
+ let mut text_gen = TextGeneration::new(42, Some(1.0), None, None, 1.0, 64);
104
+ text_gen.set_eos_token_id(50256); // Common EOS token
105
+
106
+ // Test max length stop
107
+ for i in 0..10 {
108
+ text_gen.push_token(i);
109
+ }
110
+ assert!(text_gen.should_stop(100, 10), "Should stop at max length");
111
+ assert!(!text_gen.should_stop(100, 20), "Should not stop before max length");
112
+
113
+ // Test EOS token stop
114
+ assert!(text_gen.should_stop(50256, 100), "Should stop at EOS token");
115
+ assert!(!text_gen.should_stop(123, 100), "Should not stop at non-EOS token");
116
+
117
+ // Test stop sequences
118
+ let stop_seqs = vec!["STOP".to_string(), "END".to_string()];
119
+ assert!(text_gen.check_stop_sequences("This is the STOP", &stop_seqs), "Should detect stop sequence");
120
+ assert!(text_gen.check_stop_sequences("The END", &stop_seqs), "Should detect stop sequence");
121
+ assert!(!text_gen.check_stop_sequences("Continue", &stop_seqs), "Should not detect stop sequence");
122
+ }
123
+ }