prompt_guard 0.1.0 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/AGENTS-hf-inference.md +635 -0
- data/AGENTS.md +395 -0
- data/Gemfile +5 -0
- data/INTEGRATION_SPEC.md +404 -0
- data/LICENSE +21 -0
- data/README.md +212 -72
- data/Rakefile +21 -0
- data/lib/prompt_guard/detector.rb +60 -33
- data/lib/prompt_guard/model.rb +113 -140
- data/lib/prompt_guard/utils/hub.rb +162 -0
- data/lib/prompt_guard/version.rb +1 -1
- data/lib/prompt_guard.rb +117 -27
- data/prompt_guard.gemspec +74 -0
- metadata +97 -5
checksums.yaml
CHANGED

```diff
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 3813d3eba942c206b4eb940865a140f4fa98c841bda23e3fdef2923226e5bf6d
+  data.tar.gz: bccc41fc593b7df5bf1845911d2e5f1b73216a0e63afeb917a879e75b8c34cf8
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 6ecbb50428c9b940216dc1ca64150e7d16c513c99d69e0144fe1a9ee016e31073d1ea992c0fbcfec4b1c790159e142ebe377a374f0279946c708121483d99298
+  data.tar.gz: cf4fbed001d527e313a5be8b1e430f602fc9eb946321bed2cec94947bf3013b31228de9237e6828ba025393e7a5e233af41b46adeeda00078f88ad6f43fd240b
```
data/AGENTS-hf-inference.md
ADDED

@@ -0,0 +1,635 @@

# AGENTS.md — Guidelines for Hugging Face ONNX Inference Ruby Gems

This document describes the architecture and conventions for Ruby gems that provide
ML inference by downloading ONNX models from Hugging Face Hub and running them via
ONNX Runtime. Inspired by [ankane/informers](https://github.com/ankane/informers).

It is meant to be copied and adapted for each new gem of this kind.

---

## 1. Core Concept

The gem does NOT bundle any model. Instead, it:

1. **Lazily downloads** model files (`.onnx`, `config.json`, `tokenizer.json`, etc.)
   from Hugging Face Hub on first use.
2. **Caches** them locally following the XDG standard (`~/.cache/<gem_name>/`).
3. **Runs inference** via the `onnxruntime` gem (Ruby bindings for ONNX Runtime).
4. **Exposes a pipeline API** that abstracts model loading, tokenization, and
   post-processing behind a single `pipeline(task, model_id)` call.

Models are referenced by their Hugging Face identifier: `"owner/model-name"`.

---
## 2. Naming Conventions

| Concept | Convention | Example |
|---------|-----------|---------|
| Gem name | `snake_case` | `informers` |
| Module name | `PascalCase` | `Informers` |
| GitHub repo | `<owner>/<gem_name>` | `ankane/informers` |
| Model reference | HF model ID | `"sentence-transformers/all-MiniLM-L6-v2"` |

---

## 3. File Layout

```
lib/
  <gem_name>.rb            # Main entry point — requires + module + pipeline()
  <gem_name>/
    version.rb             # VERSION constant
    env.rb                 # Global config (cache_dir, remote_host, etc.)
    configs.rb             # Model config loader (config.json)
    models.rb              # Model class registry + PreTrainedModel
    tokenizers.rb          # Tokenizer loader (tokenizer.json)
    processors.rb          # Image/audio processors (if needed)
    pipelines.rb           # Pipeline classes (one per task type)
    backends/
      onnx.rb              # ONNX Runtime session wrapper
    utils/
      hub.rb               # Hugging Face Hub download + cache logic
      core.rb              # Shared utilities (softmax, sigmoid, etc.)
      tensor.rb            # Tensor/array helpers
      ...                  # Other domain utils (image, audio, etc.)
test/
  test_helper.rb
  <gem_name>_test.rb
  hub_test.rb
  pipeline_test.rb
  ...
<gem_name>.gemspec
Rakefile
```

---
## 4. Public API Contract

### 4.1 Main entry point — `pipeline()`

The gem exposes a single factory method that returns a callable pipeline object:

```ruby
# Create a pipeline for a specific task, optionally with a specific model.
# When model is omitted, a default model for the task is used.
pipeline = <GemName>.pipeline(task, model_id = nil, **options)

# Execute the pipeline (callable object)
result = pipeline.(input)
```

**Options:**

| Option | Type | Description |
|--------|------|-------------|
| `dtype` | String | Model variant: `"fp32"`, `"fp16"`, `"q8"`, `"q4"`, etc. |
| `device` | String | Execution device: `"cpu"`, `"cuda"`, `"coreml"` |
| `cache_dir` | String | Override default cache directory |
| `revision` | String | Model revision/branch (default: `"main"`) |
| `model_file_name` | String | Override ONNX file path within the model repo |
| `session_options` | Hash | ONNX Runtime session options |
| `progress_callback` | Proc | Called during download with progress info |

### 4.2 Global configuration

```ruby
# Cache directory (default: ~/.cache/<gem_name>/ following XDG)
<GemName>.cache_dir = "/custom/cache/path"

# Remote host (default: "https://huggingface.co/")
<GemName>.remote_host = "https://huggingface.co/"

# URL template for model files
<GemName>.remote_path_template = "{model}/resolve/{revision}/"

# Enable/disable remote downloads (default: true, unless $<GEM>_OFFLINE is set)
<GemName>.allow_remote_models = true
```

---
## 5. Error Classes

```ruby
module <GemName>
  class Error < StandardError; end
  # Extend with domain-specific errors as needed
end
```

---
## 6. Hub Module — Download & Cache

The hub module (`lib/<gem_name>/utils/hub.rb`) is the core of the distribution
mechanism. It MUST:

### 6.1 Responsibilities

1. **Download files** from Hugging Face Hub via HTTP (stdlib `open-uri`).
2. **Cache files** locally in a structured directory.
3. **Check cache** before any download — return cached file if available.
4. **Support authentication** via the `$HF_TOKEN` environment variable.
5. **Handle failures** gracefully — use temp files (`.incomplete`) and clean up.
6. **Use only Ruby stdlib** for HTTP: `open-uri`, `net/http`, `uri`, `json`, `fileutils`.

### 6.2 Cache structure

```
~/.cache/<gem_name>/
  <owner>/<model_name>/
    config.json
    tokenizer.json
    onnx/
      model_quantized.onnx
```

The cache key mirrors the Hugging Face path:
- Default revision (`main`): `<owner>/<model>/filename`
- Specific revision: `<owner>/<model>/<revision>/filename`
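The cache-key rules above fit in a few lines; `cache_key` is an illustrative helper name, not part of any real gem's API:

```ruby
# Hypothetical helper illustrating the cache-key rules above.
def cache_key(model_id, filename, revision: "main")
  parts = [model_id]
  parts << revision unless revision == "main"  # only non-default revisions appear in the key
  File.join(*parts, filename)
end

cache_key("owner/model", "config.json")                 # => "owner/model/config.json"
cache_key("owner/model", "config.json", revision: "v2") # => "owner/model/v2/config.json"
```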
### 6.3 Remote URL construction

```
{remote_host}/{model}/resolve/{revision}/{filename}
```

Example:
```
https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx
```
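As a sanity check, the template expands as shown below; `build_url` is a hypothetical name used only for illustration:

```ruby
# Hypothetical sketch of expanding the URL template above.
require "uri"

def build_url(host, model, revision, filename)
  URI.join(host, "#{model}/resolve/#{revision}/#{filename}").to_s
end

build_url("https://huggingface.co/", "sentence-transformers/all-MiniLM-L6-v2",
          "main", "onnx/model_quantized.onnx")
# => "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx"
```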
### 6.4 FileCache class

```ruby
class FileCache
  def initialize(path)
    @path = path
  end

  # Check if file exists in cache
  def match(request)
    file_path = File.join(@path, request)
    FileResponse.new(file_path) if File.exist?(file_path)
  end

  # Write response to cache with atomic write pattern
  def put(request, response)
    output_path = File.join(@path, request)
    tmp_path = "#{output_path}.incomplete"
    FileUtils.mkdir_p(File.dirname(output_path))
    File.open(tmp_path, "wb") { |f| f.write(response.read(1024 * 1024)) until response.eof? }
    FileUtils.move(tmp_path, output_path)
  end
end
```
### 6.5 Download flow

```
get_model_file(model_id, filename, **options)
        |
        v
[Check FileCache] --> cache hit? --> return cached path
        |
        no
        |
        v
[Build remote URL: host + path_template + filename]
        |
        v
[HTTP GET with User-Agent + optional HF_TOKEN auth header]
        |
        v
[Write to cache via FileCache.put (atomic: .incomplete -> rename)]
        |
        v
[Return local file path]
```
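Condensed into code, the flow above might look like the following sketch. This is an illustrative stand-in for the real hub module, with hypothetical argument names; it only demonstrates the cache check, the auth header, and the atomic write:

```ruby
# Illustrative sketch of the download flow; not the gem's actual implementation.
require "fileutils"
require "open-uri"

def get_model_file(cache_dir, remote_host, model_id, filename, revision: "main")
  local = File.join(cache_dir, model_id, filename)
  return local if File.exist?(local)                     # cache hit

  url = "#{remote_host}#{model_id}/resolve/#{revision}/#{filename}"
  headers = { "User-Agent" => "my-gem" }
  headers["Authorization"] = "Bearer #{ENV["HF_TOKEN"]}" if ENV["HF_TOKEN"]

  tmp = "#{local}.incomplete"                            # atomic write: temp file + rename
  FileUtils.mkdir_p(File.dirname(local))
  begin
    URI.open(url, headers) { |remote| File.binwrite(tmp, remote.read) }
    FileUtils.move(tmp, local)
  ensure
    FileUtils.rm_f(tmp)                                  # clean up leftovers on failure
  end
  local
end
```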
### 6.6 get_model_file signature

```ruby
def self.get_model_file(path_or_repo_id, filename, fatal = true, **options)
  # options: cache_dir, revision, progress_callback, local_files_only
  # Returns: absolute path to the cached file
  # Raises: Error if file not found and fatal is true
  # Returns: nil if file not found and fatal is false
end
```

### 6.7 get_model_json helper

```ruby
def self.get_model_json(model_path, file_name, fatal = true, **options)
  # Downloads a JSON file and parses it
  # Returns: parsed Hash
  # Returns: {} if file not found and fatal is false
end
```

---
## 7. Config Loader

The config loader (`lib/<gem_name>/configs.rb`) downloads and parses `config.json`
from the model repository. This file contains model metadata:

- `model_type` — used to select the correct model class (e.g. `"bert"`, `"gpt2"`)
- `id2label` — label mappings for classification models
- Architecture-specific parameters (hidden size, number of heads, etc.)

```ruby
class PretrainedConfig
  def self.from_pretrained(model_name_or_path, **options)
    data = Hub.get_model_json(model_name_or_path, "config.json", true, **options)
    new(data)
  end

  def initialize(config_json)
    @config_json = config_json
  end

  def [](key)
    @config_json[key.to_s]
  end
end
```

---
## 8. Model Loading

### 8.1 Auto-resolution pattern

Models use an `AutoModel` pattern: the `model_type` field in `config.json` is
used to look up the correct Ruby class from a mapping:

```ruby
class AutoModel < PretrainedMixin
  MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, ...]
end
```

Each mapping is a Hash: `{ "bert" => ["BertConfig", BertForSequenceClassification], ... }`.

The flow:

```
AutoModel.from_pretrained("owner/model")
        |
        v
[Download config.json] --> extract model_type (e.g. "bert")
        |
        v
[Look up model_type in MODEL_CLASS_MAPPINGS]
        |
        v
[Call SpecificModel.from_pretrained(...)]
        |
        v
[Download ONNX file] --> create OnnxRuntime::InferenceSession
        |
        v
[Return model instance]
```

### 8.2 ONNX file naming convention

The ONNX file path depends on the `dtype` option:

| dtype | ONNX file suffix | Example path |
|-------|-----------------|--------------|
| `fp32` | (none) | `onnx/model.onnx` |
| `fp16` | `_fp16` | `onnx/model_fp16.onnx` |
| `q8` | `_quantized` | `onnx/model_quantized.onnx` |
| `int8` | `_quantized` | `onnx/model_quantized.onnx` |
| `uint8` | `_uint8` | `onnx/model_uint8.onnx` |
| `q4` | `_q4` | `onnx/model_q4.onnx` |
| `q4f16` | `_q4f16` | `onnx/model_q4f16.onnx` |
| `bnb4` | `_bnb4` | `onnx/model_bnb4.onnx` |

Default is `q8` (quantized) for smaller download size.
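The suffix table maps naturally onto a small lookup. The constant and method names below are illustrative, not a published API:

```ruby
# Illustrative lookup for the dtype-to-suffix table; names are hypothetical.
DTYPE_SUFFIX = {
  "fp32"  => "",           "fp16"  => "_fp16",
  "q8"    => "_quantized", "int8"  => "_quantized",
  "uint8" => "_uint8",     "q4"    => "_q4",
  "q4f16" => "_q4f16",     "bnb4"  => "_bnb4"
}.freeze

def onnx_file_for(dtype, base_name = "model")
  "onnx/#{base_name}#{DTYPE_SUFFIX.fetch(dtype)}.onnx"  # raises KeyError on unknown dtype
end

onnx_file_for("q8")   # => "onnx/model_quantized.onnx"
onnx_file_for("fp32") # => "onnx/model.onnx"
```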
### 8.3 construct_session

```ruby
def self.construct_session(model_name, file_name, **options)
  model_file = "onnx/#{file_name}#{dtype_suffix}.onnx"
  path = Hub.get_model_file(model_name, model_file, true, **options)
  OnnxRuntime::InferenceSession.new(path, **session_options)
end
```

### 8.4 Model types

Different architectures need different ONNX sessions:

| Model type | Sessions loaded |
|------------|----------------|
| EncoderOnly | `model.onnx` |
| DecoderOnly | `decoder_model_merged.onnx` + `generation_config.json` |
| Seq2Seq | `encoder_model.onnx` + `decoder_model_merged.onnx` + `generation_config.json` |
| Vision2Seq | Same as Seq2Seq |
| EncoderDecoder | `encoder_model.onnx` + `decoder_model_merged.onnx` |

---

## 9. Tokenizer Loading

Tokenizers are loaded from `tokenizer.json` using the `tokenizers` gem
(Rust-based HuggingFace tokenizers with Ruby bindings):

```ruby
class AutoTokenizer
  def self.from_pretrained(model_name_or_path, **options)
    tokenizer_json_path = Hub.get_model_file(model_name_or_path, "tokenizer.json", true, **options)
    tokenizer_config = Hub.get_model_json(model_name_or_path, "tokenizer_config.json", false, **options)
    # Build tokenizer from JSON file
    Tokenizers::Tokenizer.from_file(tokenizer_json_path)
  end
end
```

---
## 10. Pipeline System

### 10.1 Pipeline factory

The `pipeline()` method is the main entry point. It:

1. Looks up the task in `SUPPORTED_TASKS` to find the default model, model class,
   tokenizer class, and pipeline class.
2. Loads the model via `AutoModel.from_pretrained(model_id, **options)`.
3. Loads the tokenizer via `AutoTokenizer.from_pretrained(model_id, **options)`.
4. Loads the processor (for vision/audio tasks) if needed.
5. Returns a pipeline instance that is callable.

```ruby
def self.pipeline(task, model = nil, **options)
  task_info = SUPPORTED_TASKS[task]
  model ||= task_info[:default][:model]

  loaded_model = task_info[:model].from_pretrained(model, **options)
  tokenizer = task_info[:tokenizer]&.from_pretrained(model, **options)
  processor = task_info[:processor]&.from_pretrained(model, **options)

  task_info[:pipeline].new(task: task, model: loaded_model, tokenizer: tokenizer, processor: processor)
end
```

### 10.2 SUPPORTED_TASKS registry

Each supported task is registered in a hash mapping task name to its components:

```ruby
SUPPORTED_TASKS = {
  "sentiment-analysis" => {
    tokenizer: AutoTokenizer,
    pipeline: TextClassificationPipeline,
    model: AutoModelForSequenceClassification,
    default: { model: "Xenova/distilbert-base-uncased-finetuned-sst-2-english" },
    type: "text"
  },
  "embedding" => {
    tokenizer: AutoTokenizer,
    pipeline: EmbeddingPipeline,
    model: AutoModel,
    default: { model: "Xenova/all-MiniLM-L6-v2" },
    type: "text"
  },
  "reranking" => {
    tokenizer: AutoTokenizer,
    pipeline: RerankingPipeline,
    model: AutoModelForSequenceClassification,
    default: { model: "Xenova/ms-marco-MiniLM-L-6-v2" },
    type: "text"
  },
  # ... more tasks
}
```

### 10.3 Pipeline base class

```ruby
class Pipeline
  def initialize(task:, model:, tokenizer: nil, processor: nil)
    @task = task
    @model = model
    @tokenizer = tokenizer
    @processor = processor
  end

  # Subclasses implement call() with task-specific logic:
  # 1. Tokenize/preprocess the input
  # 2. Run model inference
  # 3. Post-process and return structured results
end
```

### 10.4 Common pipeline types

| Pipeline class | Task | Input | Output |
|---------------|------|-------|--------|
| `TextClassificationPipeline` | `sentiment-analysis` | String | `{ label:, score: }` |
| `TokenClassificationPipeline` | `ner` | String | `[{ entity_group:, word:, score:, start:, end: }]` |
| `QuestionAnsweringPipeline` | `question-answering` | question + context | `{ answer:, score:, start:, end: }` |
| `EmbeddingPipeline` | `embedding` | String or [String] | Array of floats (vector) |
| `RerankingPipeline` | `reranking` | query + [docs] | `[{ doc_id:, score: }]` |
| `Text2TextGenerationPipeline` | `text2text-generation` | String | `{ generated_text: }` |
| `TextGenerationPipeline` | `text-generation` | String | `{ generated_text: }` |
| `SummarizationPipeline` | `summarization` | String | `{ summary_text: }` |
| `TranslationPipeline` | `translation` | String | `{ translation_text: }` |
| `FillMaskPipeline` | `fill-mask` | String with `[MASK]` | `[{ score:, token_str:, sequence: }]` |
| `ImageClassificationPipeline` | `image-classification` | image path | `{ label:, score: }` |
| `ObjectDetectionPipeline` | `object-detection` | image path | `[{ label:, score:, box: }]` |
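As an example of the post-processing step, an embedding pipeline's `call` typically mean-pools the token vectors using the attention mask and then L2-normalizes the result. A pure-Ruby sketch with illustrative helper names:

```ruby
# Illustrative post-processing for an embedding pipeline (names hypothetical).
def mean_pooling(token_embeddings, attention_mask)
  dim = token_embeddings.first.size
  sums = Array.new(dim, 0.0)
  kept = 0
  token_embeddings.each_with_index do |vec, i|
    next if attention_mask[i].zero?                 # skip padding tokens
    vec.each_with_index { |v, j| sums[j] += v }
    kept += 1
  end
  sums.map { |s| s / kept }
end

def l2_normalize(vec)
  norm = Math.sqrt(vec.sum { |v| v * v })
  vec.map { |v| v / norm }
end

mean_pooling([[1.0, 3.0], [3.0, 5.0], [9.0, 9.0]], [1, 1, 0])
# => [2.0, 4.0]
l2_normalize([3.0, 4.0])
# => [0.6, 0.8]
```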
---
## 11. Environment Variables

| Variable | Purpose | Default |
|----------|---------|---------|
| `HF_TOKEN` | Hugging Face auth token for private models | (none) |
| `XDG_CACHE_HOME` | Base cache directory | `~/.cache` |
| `<GEM>_OFFLINE` | Disable remote downloads when set | (empty = online) |
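Resolving these defaults takes only a few lines; the method names below are illustrative, not part of any real gem:

```ruby
# Illustrative resolution of the defaults in the table above (hypothetical names).
def default_cache_dir(gem_name)
  base = ENV["XDG_CACHE_HOME"]
  base = File.join(Dir.home, ".cache") if base.nil? || base.empty?
  File.join(base, gem_name)
end

def offline?(env_prefix)
  !ENV["#{env_prefix}_OFFLINE"].to_s.empty?   # any non-empty value disables downloads
end
```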
---

## 12. Runtime Dependencies

| Gem | Purpose |
|-----|---------|
| `onnxruntime` (~> 0.x) | ONNX model inference engine |
| `tokenizers` (~> 0.x) | HuggingFace tokenizers (Rust bindings) |

**No other runtime dependencies.** HTTP download uses Ruby stdlib (`open-uri`, `net/http`, `json`, `fileutils`).

Optional dependencies (documented, not required):
- `ruby-vips` — for image loading (vision tasks)
- `ffmpeg` — for audio loading (audio tasks)

---

## 13. Gemspec Conventions

```ruby
Gem::Specification.new do |spec|
  spec.name = "<gem_name>"
  spec.version = <GemName>::VERSION
  spec.summary = "Fast transformer inference for Ruby"

  spec.required_ruby_version = ">= 3.1.0"

  spec.add_dependency "onnxruntime", "~> 0.9"
  spec.add_dependency "tokenizers", "~> 0.5"

  spec.files = Dir.glob("lib/**/*") + %w[README.md LICENSE.txt CHANGELOG.md]
  spec.require_paths = ["lib"]
end
```

---
## 14. Adapting for a New Gem

To create a new gem following this pattern:

1. **Define your tasks** — What pipelines will the gem support? List them with
   their default models.
2. **Copy the file layout** — Especially `utils/hub.rb` (reusable as-is),
   `env.rb`, `configs.rb`.
3. **Register models** — Create the `MODEL_CLASS_MAPPINGS` for each model type
   you want to support.
4. **Implement pipelines** — One class per task, each implementing `call()` with:
   - Input preprocessing (tokenization, image processing, etc.)
   - Model inference via ONNX Runtime
   - Output post-processing (softmax, argmax, decoding, etc.)
5. **Wire `SUPPORTED_TASKS`** — Map task names to their pipeline, model, tokenizer,
   and default model.
6. **Expose `<GemName>.pipeline(task, model)`** as the main entry point.

### Key design decisions to make:

- **Which ONNX models are compatible?** The model MUST have a `.onnx` file on HF.
  Models from the `Xenova/` namespace are pre-converted and reliable.
- **Default dtype**: `q8` (quantized) is a good default — smaller downloads, fast
  inference, minimal accuracy loss for most tasks.
- **Which tasks to support?** Start minimal (e.g. embedding + reranking) and add
  tasks as needed.

---

## 15. Key Differences from WASM Binary Wrapper Pattern

| Aspect | WASM Wrapper (AGENTS.md) | HF ONNX Inference (this file) |
|--------|--------------------------|-------------------------------|
| Source | GitHub Releases (single binary) | Hugging Face Hub (multiple files) |
| Format | Single `.wasm` file | `.onnx` + `.json` config files |
| Runtime | External CLI (`wasmtime`) | In-process (`onnxruntime` gem) |
| Cache | No cache (explicit `binary_path`) | XDG cache, automatic |
| Download | Explicit (`download_to_binary_path!`) | Implicit (lazy, on first use) |
| Models | One binary per gem | Any compatible HF model |
| Dependencies | Zero (stdlib only) | `onnxruntime` + `tokenizers` |
| API style | `Module.run(*args)` (CLI-like) | `pipeline.(input)` (callable object) |

---
## 16. Testing Strategy

### 16.1 Framework

Use **Minitest** (stdlib). Tests run with `rake test`.

### 16.2 Stubbing conventions

- **Hub/download tests**: Stub HTTP calls. Never download real models in unit tests.
  Use `webmock` or stub `Hub.get_model_file` to return fixture file paths.
- **Pipeline tests**: Use small fixture ONNX models or stub the ONNX session.
- **Integration tests**: Can use a real (small) model for end-to-end validation.
  Guard with `skip "Requires network"` and a CI flag.
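A minimal sketch of the fixture-stub convention described above; the module and fixture paths are illustrative, and a real suite would typically use Minitest stubs or `webmock` instead of redefining the method:

```ruby
# Sketch: replace Hub.get_model_file with a fixture-returning stub so unit
# tests never touch the network. Names here are illustrative.
module Hub
  def self.get_model_file(repo_id, filename, fatal = true, **options)
    File.join("test/fixtures", filename)   # fixture path instead of a download
  end
end

Hub.get_model_file("owner/model", "config.json")
# => "test/fixtures/config.json"
```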
### 16.3 Test fixtures

Provide minimal fixture files in `test/fixtures/`:

```
test/fixtures/
  config.json       # Minimal model config
  tokenizer.json    # Minimal tokenizer
  model.onnx        # Small/dummy ONNX model (optional)
```

### 16.4 Test checklist

**Hub module:**
- [ ] Downloads file from remote when not cached
- [ ] Returns cached file when already downloaded
- [ ] Sends `HF_TOKEN` in auth header when set
- [ ] Raises when `local_files_only: true` and file not in cache
- [ ] Creates intermediate directories
- [ ] Uses atomic write (`.incomplete` + rename)
- [ ] Handles HTTP errors gracefully

**Config:**
- [ ] Parses `config.json` correctly
- [ ] Exposes `model_type`, `id2label`, and other fields

**Pipeline factory:**
- [ ] Returns correct pipeline class for each supported task
- [ ] Uses default model when none specified
- [ ] Passes options through to model/tokenizer loading

**Individual pipelines:**
- [ ] Tokenizes input correctly
- [ ] Returns expected output structure
- [ ] Handles single and batched inputs
- [ ] Handles edge cases (empty input, very long input)

**Integration:**
- [ ] Full workflow: `pipeline(task, model) -> call(input) -> structured result`
- [ ] Offline mode raises when model not cached

---

## 17. Complete Flow Diagram

```
User code:
  model = <GemName>.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
  embeddings = model.("Hello world")

Internal flow:

<GemName>.pipeline(task, model_id, **options)
  │
  ├─ SUPPORTED_TASKS[task]
  │    → { pipeline: EmbeddingPipeline, model: AutoModel, tokenizer: AutoTokenizer, default: ... }
  │
  ├─ AutoConfig.from_pretrained(model_id)
  │    └─ Hub.get_model_json(model_id, "config.json")
  │         ├─ FileCache.match("sentence-transformers/all-MiniLM-L6-v2/config.json")
  │         │    → cache hit? return path
  │         └─ HTTP GET https://huggingface.co/.../config.json
  │              → write to cache → return path
  │
  ├─ AutoModel.from_pretrained(model_id)
  │    ├─ config[:model_type] → "bert" → BertModel
  │    └─ construct_session(model_id, "model")
  │         └─ Hub.get_model_file(model_id, "onnx/model_quantized.onnx")
  │              → download if needed → return path
  │              → OnnxRuntime::InferenceSession.new(path)
  │
  ├─ AutoTokenizer.from_pretrained(model_id)
  │    └─ Hub.get_model_file(model_id, "tokenizer.json")
  │         → download if needed → return path
  │         → Tokenizers::Tokenizer.from_file(path)
  │
  └─ EmbeddingPipeline.new(model:, tokenizer:)
       │
       └─ pipeline.("Hello world")
            ├─ tokenizer.("Hello world") → { input_ids:, attention_mask: }
            ├─ model.(tokenized_input) → raw ONNX output
            ├─ mean_pooling(output, attention_mask)
            ├─ normalize(pooled)
            └─ return [0.012, -0.034, 0.056, ...] # embedding vector
```