onnx-ruby 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,436 @@
1
+ #include <rice/rice.hpp>
2
+ #include <rice/stl.hpp>
3
+ #include <onnxruntime_cxx_api.h>
4
+ #ifdef __APPLE__
5
+ #include <coreml_provider_factory.h>
6
+ #endif
7
+ #include <vector>
8
+ #include <string>
9
+ #include <cstring>
10
+ #include <stdexcept>
11
+ #include <memory>
12
+ #include <unordered_map>
13
+
14
+ using namespace Rice;
15
+
16
// Global ORT environment (initialized once).
//
// Returns the process-wide Ort::Env as a function-local static, so it is
// constructed exactly once (thread-safe under C++11 magic statics).
// NOTE(review): log_level is only applied on the very first call — the env
// is already built on subsequent calls, so later callers' levels are
// silently ignored. Callers that need a different level must pass it on the
// first session they create in the process.
static Ort::Env& get_env(int log_level = ORT_LOGGING_LEVEL_WARNING) {
    static Ort::Env env(static_cast<OrtLoggingLevel>(log_level), "onnx_ruby");
    return env;
}
21
+
22
+ // Map ORT element type to Ruby symbol name
23
+ static std::string ort_type_to_string(ONNXTensorElementDataType type) {
24
+ switch (type) {
25
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return "float32";
26
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return "float64";
27
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: return "int32";
28
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: return "int64";
29
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: return "string";
30
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: return "bool";
31
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: return "uint8";
32
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return "int8";
33
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: return "uint16";
34
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: return "int16";
35
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: return "float16";
36
+ default: return "unknown";
37
+ }
38
+ }
39
+
40
// Flatten a nested Ruby array into a flat vector and compute shape.
//
// Recursively walks `arr`; non-array leaves are appended to `flat`. The
// length of the FIRST sub-array encountered at each nesting depth is
// recorded in `shape` (depth == shape.size() is only true the first time a
// given depth is reached).
//
// NOTE(review): ragged input is not validated — if sibling sub-arrays have
// differing lengths, `flat` and `shape` will disagree silently. Confirm
// callers guarantee rectangular input, or add a check.
static void flatten_ruby_array(VALUE arr, std::vector<VALUE>& flat, std::vector<int64_t>& shape, int depth) {
    if (!RB_TYPE_P(arr, T_ARRAY)) {
        flat.push_back(arr);
        return;
    }

    long len = RARRAY_LEN(arr);
    // First visit at this depth establishes the dimension size.
    if (depth == (int)shape.size()) {
        shape.push_back(len);
    }

    for (long i = 0; i < len; i++) {
        flatten_ruby_array(rb_ary_entry(arr, i), flat, shape, depth + 1);
    }
}
56
+
57
+ // Convert an ORT output tensor to a nested Ruby array
58
+ static Rice::Object tensor_to_ruby(const Ort::Value& tensor) {
59
+ auto type_info = tensor.GetTensorTypeAndShapeInfo();
60
+ auto shape = type_info.GetShape();
61
+ auto elem_type = type_info.GetElementType();
62
+ size_t total = type_info.GetElementCount();
63
+
64
+ // Build flat Ruby array first
65
+ Rice::Array flat;
66
+
67
+ switch (elem_type) {
68
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
69
+ const float* data = tensor.GetTensorData<float>();
70
+ for (size_t i = 0; i < total; i++) {
71
+ flat.push(Rice::Object(rb_float_new(static_cast<double>(data[i]))));
72
+ }
73
+ break;
74
+ }
75
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
76
+ const double* data = tensor.GetTensorData<double>();
77
+ for (size_t i = 0; i < total; i++) {
78
+ flat.push(Rice::Object(rb_float_new(data[i])));
79
+ }
80
+ break;
81
+ }
82
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
83
+ const int32_t* data = tensor.GetTensorData<int32_t>();
84
+ for (size_t i = 0; i < total; i++) {
85
+ flat.push(Rice::Object(INT2NUM(data[i])));
86
+ }
87
+ break;
88
+ }
89
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
90
+ const int64_t* data = tensor.GetTensorData<int64_t>();
91
+ for (size_t i = 0; i < total; i++) {
92
+ flat.push(Rice::Object(LONG2NUM(data[i])));
93
+ }
94
+ break;
95
+ }
96
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
97
+ const bool* data = tensor.GetTensorData<bool>();
98
+ for (size_t i = 0; i < total; i++) {
99
+ flat.push(Rice::Object(data[i] ? Qtrue : Qfalse));
100
+ }
101
+ break;
102
+ }
103
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
104
+ size_t count = total;
105
+ std::vector<std::string> strings(count);
106
+ // GetStringTensorContent approach
107
+ size_t total_len = tensor.GetStringTensorDataLength();
108
+ std::vector<char> buffer(total_len);
109
+ std::vector<size_t> offsets(count);
110
+ tensor.GetStringTensorContent(buffer.data(), total_len, offsets.data(), count);
111
+ for (size_t i = 0; i < count; i++) {
112
+ size_t start = offsets[i];
113
+ size_t end = (i + 1 < count) ? offsets[i + 1] : total_len;
114
+ flat.push(Rice::Object(rb_str_new(buffer.data() + start, end - start)));
115
+ }
116
+ break;
117
+ }
118
+ default:
119
+ throw std::runtime_error("Unsupported output tensor element type: " + ort_type_to_string(elem_type));
120
+ }
121
+
122
+ // Reshape flat array into nested arrays according to shape
123
+ if (shape.empty() || shape.size() == 1) {
124
+ return flat;
125
+ }
126
+
127
+ // Reshape from innermost dimension outward
128
+ Rice::Array current = flat;
129
+ for (int d = (int)shape.size() - 1; d >= 1; d--) {
130
+ int64_t dim_size = shape[d];
131
+ Rice::Array reshaped;
132
+ long total_items = RARRAY_LEN(current.value());
133
+ for (long i = 0; i < total_items; i += dim_size) {
134
+ Rice::Array slice;
135
+ for (int64_t j = 0; j < dim_size; j++) {
136
+ slice.push(Rice::Object(rb_ary_entry(current.value(), i + j)));
137
+ }
138
+ reshaped.push(slice);
139
+ }
140
+ current = reshaped;
141
+ }
142
+
143
+ return current;
144
+ }
145
+
146
+ // Map Ruby optimization level symbol to ORT enum
147
+ static GraphOptimizationLevel parse_opt_level(const std::string& level) {
148
+ if (level == "none" || level == "disabled") return GraphOptimizationLevel::ORT_DISABLE_ALL;
149
+ if (level == "basic") return GraphOptimizationLevel::ORT_ENABLE_BASIC;
150
+ if (level == "extended") return GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
151
+ return GraphOptimizationLevel::ORT_ENABLE_ALL; // "all" or default
152
+ }
153
+
154
+ // Optimize a model and save to disk
155
+ static Rice::Object optimize_model(const std::string& input_path, const std::string& output_path,
156
+ const std::string& opt_level) {
157
+ Ort::SessionOptions opts;
158
+ opts.SetGraphOptimizationLevel(parse_opt_level(opt_level));
159
+ opts.SetOptimizedModelFilePath(output_path.c_str());
160
+
161
+ // Creating the session triggers optimization and saves the optimized model
162
+ Ort::Session session(get_env(), input_path.c_str(), opts);
163
+ return Rice::Object(Qtrue);
164
+ }
165
+
166
+ // Get available execution providers
167
+ static Rice::Array available_providers() {
168
+ Rice::Array result;
169
+ auto providers = Ort::GetAvailableProviders();
170
+ for (const auto& p : providers) {
171
+ result.push(Rice::String(p));
172
+ }
173
+ return result;
174
+ }
175
+
176
+ class SessionWrapper {
177
+ public:
178
+ SessionWrapper(const std::string& model_path, int log_level, int intra_threads, int inter_threads,
179
+ const std::string& opt_level, bool memory_pattern, bool cpu_mem_arena,
180
+ const std::string& execution_mode, Rice::Array providers) {
181
+ Ort::SessionOptions opts;
182
+
183
+ if (intra_threads > 0) opts.SetIntraOpNumThreads(intra_threads);
184
+ if (inter_threads > 0) opts.SetInterOpNumThreads(inter_threads);
185
+ opts.SetGraphOptimizationLevel(parse_opt_level(opt_level));
186
+
187
+ if (!memory_pattern) opts.DisableMemPattern();
188
+ if (!cpu_mem_arena) opts.DisableCpuMemArena();
189
+
190
+ if (execution_mode == "parallel") {
191
+ opts.SetExecutionMode(ExecutionMode::ORT_PARALLEL);
192
+ } else {
193
+ opts.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
194
+ }
195
+
196
+ // Append execution providers
197
+ for (long i = 0; i < RARRAY_LEN(providers.value()); i++) {
198
+ std::string provider = Rice::detail::From_Ruby<std::string>().convert(
199
+ rb_ary_entry(providers.value(), i));
200
+
201
+ if (provider == "coreml") {
202
+ #ifdef __APPLE__
203
+ uint32_t coreml_flags = COREML_FLAG_USE_NONE;
204
+ auto status = OrtSessionOptionsAppendExecutionProvider_CoreML(opts, coreml_flags);
205
+ if (status) {
206
+ std::string msg = Ort::GetApi().GetErrorMessage(status);
207
+ Ort::GetApi().ReleaseStatus(status);
208
+ throw std::runtime_error("CoreML provider error: " + msg);
209
+ }
210
+ #else
211
+ throw std::runtime_error("CoreML provider is only available on macOS/iOS");
212
+ #endif
213
+ } else if (provider == "cuda") {
214
+ OrtCUDAProviderOptions cuda_opts;
215
+ memset(&cuda_opts, 0, sizeof(cuda_opts));
216
+ opts.AppendExecutionProvider_CUDA(cuda_opts);
217
+ } else if (provider == "tensorrt") {
218
+ OrtTensorRTProviderOptions trt_opts;
219
+ memset(&trt_opts, 0, sizeof(trt_opts));
220
+ opts.AppendExecutionProvider_TensorRT(trt_opts);
221
+ } else if (provider == "cpu") {
222
+ // CPU is always available as fallback, no-op
223
+ } else {
224
+ throw std::runtime_error("Unknown execution provider: " + provider);
225
+ }
226
+ }
227
+
228
+ session_ = std::make_unique<Ort::Session>(get_env(log_level), model_path.c_str(), opts);
229
+ allocator_ = Ort::AllocatorWithDefaultOptions();
230
+ }
231
+
232
+ // Get input metadata
233
+ Rice::Array input_info() {
234
+ Rice::Array result;
235
+ size_t count = session_->GetInputCount();
236
+
237
+ for (size_t i = 0; i < count; i++) {
238
+ auto name = session_->GetInputNameAllocated(i, allocator_);
239
+ auto type_info = session_->GetInputTypeInfo(i);
240
+ auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
241
+ auto shape = tensor_info.GetShape();
242
+ auto elem_type = tensor_info.GetElementType();
243
+
244
+ Rice::Hash info;
245
+ info[Rice::Symbol("name")] = Rice::String(name.get());
246
+ info[Rice::Symbol("type")] = Rice::Symbol(ort_type_to_string(elem_type));
247
+
248
+ Rice::Array rb_shape;
249
+ for (auto dim : shape) {
250
+ rb_shape.push(Rice::Object(LONG2NUM(dim)));
251
+ }
252
+ info[Rice::Symbol("shape")] = rb_shape;
253
+
254
+ result.push(Rice::Object(info.value()));
255
+ }
256
+ return result;
257
+ }
258
+
259
+ // Get output metadata
260
+ Rice::Array output_info() {
261
+ Rice::Array result;
262
+ size_t count = session_->GetOutputCount();
263
+
264
+ for (size_t i = 0; i < count; i++) {
265
+ auto name = session_->GetOutputNameAllocated(i, allocator_);
266
+ auto type_info = session_->GetOutputTypeInfo(i);
267
+ auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
268
+ auto shape = tensor_info.GetShape();
269
+ auto elem_type = tensor_info.GetElementType();
270
+
271
+ Rice::Hash info;
272
+ info[Rice::Symbol("name")] = Rice::String(name.get());
273
+ info[Rice::Symbol("type")] = Rice::Symbol(ort_type_to_string(elem_type));
274
+
275
+ Rice::Array rb_shape;
276
+ for (auto dim : shape) {
277
+ rb_shape.push(Rice::Object(LONG2NUM(dim)));
278
+ }
279
+ info[Rice::Symbol("shape")] = rb_shape;
280
+
281
+ result.push(Rice::Object(info.value()));
282
+ }
283
+ return result;
284
+ }
285
+
286
+ // Run inference
287
+ Rice::Object run(Rice::Array input_specs, Rice::Array output_names_filter) {
288
+ std::vector<const char*> input_names;
289
+ std::vector<Ort::Value> input_tensors;
290
+ std::vector<std::string> input_name_strs;
291
+
292
+ // Storage for tensor data (must outlive the Run call)
293
+ std::vector<std::vector<float>> float_buffers;
294
+ std::vector<std::vector<double>> double_buffers;
295
+ std::vector<std::vector<int32_t>> int32_buffers;
296
+ std::vector<std::vector<int64_t>> int64_buffers;
297
+ // Use uint8_t instead of bool because std::vector<bool> is bit-packed and has no .data()
298
+ std::vector<std::vector<uint8_t>> bool_buffers;
299
+
300
+ auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
301
+
302
+ for (long idx = 0; idx < RARRAY_LEN(input_specs.value()); idx++) {
303
+ Rice::Hash spec(rb_ary_entry(input_specs.value(), idx));
304
+
305
+ std::string name = Rice::detail::From_Ruby<std::string>().convert(
306
+ spec[Rice::Symbol("name")].value());
307
+ input_name_strs.push_back(name);
308
+
309
+ Rice::Array data(spec[Rice::Symbol("data")].value());
310
+ Rice::Array rb_shape(spec[Rice::Symbol("shape")].value());
311
+ std::string dtype = Rice::detail::From_Ruby<std::string>().convert(
312
+ spec[Rice::Symbol("dtype")].value());
313
+
314
+ std::vector<int64_t> shape;
315
+ for (long s = 0; s < RARRAY_LEN(rb_shape.value()); s++) {
316
+ shape.push_back(NUM2LONG(rb_ary_entry(rb_shape.value(), s)));
317
+ }
318
+
319
+ size_t total_elements = 1;
320
+ for (auto dim : shape) total_elements *= dim;
321
+
322
+ if (dtype == "float") {
323
+ float_buffers.emplace_back(total_elements);
324
+ auto& buf = float_buffers.back();
325
+ for (size_t i = 0; i < total_elements; i++) {
326
+ buf[i] = static_cast<float>(NUM2DBL(rb_ary_entry(data.value(), i)));
327
+ }
328
+ input_tensors.push_back(
329
+ Ort::Value::CreateTensor<float>(memory_info, buf.data(), total_elements,
330
+ shape.data(), shape.size()));
331
+ } else if (dtype == "double") {
332
+ double_buffers.emplace_back(total_elements);
333
+ auto& buf = double_buffers.back();
334
+ for (size_t i = 0; i < total_elements; i++) {
335
+ buf[i] = NUM2DBL(rb_ary_entry(data.value(), i));
336
+ }
337
+ input_tensors.push_back(
338
+ Ort::Value::CreateTensor<double>(memory_info, buf.data(), total_elements,
339
+ shape.data(), shape.size()));
340
+ } else if (dtype == "int32") {
341
+ int32_buffers.emplace_back(total_elements);
342
+ auto& buf = int32_buffers.back();
343
+ for (size_t i = 0; i < total_elements; i++) {
344
+ buf[i] = static_cast<int32_t>(NUM2INT(rb_ary_entry(data.value(), i)));
345
+ }
346
+ input_tensors.push_back(
347
+ Ort::Value::CreateTensor<int32_t>(memory_info, buf.data(), total_elements,
348
+ shape.data(), shape.size()));
349
+ } else if (dtype == "int64") {
350
+ int64_buffers.emplace_back(total_elements);
351
+ auto& buf = int64_buffers.back();
352
+ for (size_t i = 0; i < total_elements; i++) {
353
+ buf[i] = NUM2LONG(rb_ary_entry(data.value(), i));
354
+ }
355
+ input_tensors.push_back(
356
+ Ort::Value::CreateTensor<int64_t>(memory_info, buf.data(), total_elements,
357
+ shape.data(), shape.size()));
358
+ } else if (dtype == "bool") {
359
+ bool_buffers.emplace_back(total_elements);
360
+ auto& buf = bool_buffers.back();
361
+ for (size_t i = 0; i < total_elements; i++) {
362
+ buf[i] = RTEST(rb_ary_entry(data.value(), i)) ? 1 : 0;
363
+ }
364
+ input_tensors.push_back(
365
+ Ort::Value::CreateTensor(memory_info, reinterpret_cast<bool*>(buf.data()), total_elements,
366
+ shape.data(), shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL));
367
+ } else {
368
+ throw std::runtime_error("Unsupported input dtype: " + dtype);
369
+ }
370
+ }
371
+
372
+ for (auto& s : input_name_strs) {
373
+ input_names.push_back(s.c_str());
374
+ }
375
+
376
+ // Determine output names
377
+ std::vector<std::string> output_name_strs;
378
+ std::vector<const char*> output_names;
379
+
380
+ if (RARRAY_LEN(output_names_filter.value()) > 0) {
381
+ for (long i = 0; i < RARRAY_LEN(output_names_filter.value()); i++) {
382
+ output_name_strs.push_back(
383
+ Rice::detail::From_Ruby<std::string>().convert(
384
+ rb_ary_entry(output_names_filter.value(), i)));
385
+ }
386
+ } else {
387
+ size_t output_count = session_->GetOutputCount();
388
+ for (size_t i = 0; i < output_count; i++) {
389
+ auto name = session_->GetOutputNameAllocated(i, allocator_);
390
+ output_name_strs.push_back(name.get());
391
+ }
392
+ }
393
+
394
+ for (auto& s : output_name_strs) {
395
+ output_names.push_back(s.c_str());
396
+ }
397
+
398
+ // Run inference
399
+ auto results = session_->Run(
400
+ Ort::RunOptions{nullptr},
401
+ input_names.data(), input_tensors.data(), input_names.size(),
402
+ output_names.data(), output_names.size());
403
+
404
+ // Convert results to Ruby Hash
405
+ Rice::Hash output;
406
+ for (size_t i = 0; i < results.size(); i++) {
407
+ Rice::Object rb_tensor = tensor_to_ruby(results[i]);
408
+ output[Rice::String(output_name_strs[i])] = rb_tensor;
409
+ }
410
+
411
+ return output;
412
+ }
413
+
414
+ private:
415
+ std::unique_ptr<Ort::Session> session_;
416
+ Ort::AllocatorWithDefaultOptions allocator_;
417
+ };
418
+
419
// Ruby extension entry point, called by Ruby's `require`.
//
// Registers everything under OnnxRuby::Ext: the SessionWrapper class (with
// its constructor and instance methods) and the module-level helpers
// optimize_model / available_providers.
extern "C" void Init_onnx_ruby_ext() {
    Module rb_mOnnxRuby = define_module("OnnxRuby");
    Module rb_mExt = define_module_under(rb_mOnnxRuby, "Ext");

    define_class_under<SessionWrapper>(rb_mExt, "SessionWrapper")
        .define_constructor(Constructor<SessionWrapper, const std::string&, int, int, int,
                            const std::string&, bool, bool, const std::string&, Rice::Array>(),
                            Arg("model_path"), Arg("log_level"), Arg("intra_threads"), Arg("inter_threads"),
                            Arg("opt_level"), Arg("memory_pattern"), Arg("cpu_mem_arena"),
                            Arg("execution_mode"), Arg("providers"))
        .define_method("input_info", &SessionWrapper::input_info)
        .define_method("output_info", &SessionWrapper::output_info)
        .define_method("run", &SessionWrapper::run);

    rb_mExt.define_module_function("optimize_model", &optimize_model,
        Arg("input_path"), Arg("output_path"), Arg("opt_level"));
    rb_mExt.define_module_function("available_providers", &available_providers);
}
@@ -0,0 +1,107 @@
1
+ # frozen_string_literal: true
2
+
3
module OnnxRuby
  # Classification convenience wrapper around Session.
  #
  # Accepts either text (tokenized via a HuggingFace-style tokenizer) or raw
  # feature vectors, runs the model, and converts the logits row per input
  # into { label:, score:, scores: } via softmax.
  class Classifier
    attr_reader :session, :labels

    # @param model_path [String] path to the ONNX model file
    # @param tokenizer [Object, String, Symbol, nil] an object responding to
    #   #encode / #encode_batch, or a pretrained tokenizer name to load via
    #   the "tokenizers" gem; nil when inputs are raw feature vectors
    # @param labels [Array, nil] index -> label mapping; raw indices if nil
    # @param session_opts [Hash] forwarded to Session.new
    def initialize(model_path, tokenizer: nil, labels: nil, **session_opts)
      @session = Session.new(model_path, **session_opts)
      @labels = labels
      @tokenizer = resolve_tokenizer(tokenizer)
    end

    # Classify a single input
    # @param input [String, Array<Float>] text (requires tokenizer) or feature vector
    # @return [Hash] { label:, score:, scores: }
    def predict(input)
      predict_batch([input]).first
    end

    # Classify a batch of inputs
    # @param inputs [Array<String>, Array<Array<Float>>] batch of texts or feature vectors
    # @return [Array<Hash>] array of { label:, score:, scores: }
    def predict_batch(inputs)
      # Avoid a confusing "inputs must be Strings or Arrays" error (and a
      # pointless model call) on an empty batch.
      return [] if inputs.empty?

      feed = prepare_inputs(inputs)
      result = @session.run(feed)

      logits = find_output(result, %w[logits output probabilities scores])

      logits.map { |row| format_prediction(row) }
    end

    private

    # Accept any duck-typed tokenizer, or lazily load one by name.
    # Checks #encode OR #encode_batch for consistency with Embedder.
    def resolve_tokenizer(tokenizer)
      return nil if tokenizer.nil?

      if tokenizer.respond_to?(:encode) || tokenizer.respond_to?(:encode_batch)
        tokenizer
      else
        begin
          require "tokenizers"
          Tokenizers::Tokenizer.from_pretrained(tokenizer.to_s)
        rescue LoadError
          raise Error, "tokenizer-ruby gem is required for text tokenization. " \
                       "Install with: gem install tokenizers"
        end
      end
    end

    # Build the feed hash: tokenize texts, or pass feature vectors through
    # under the model's first input name.
    def prepare_inputs(inputs)
      if inputs.first.is_a?(String)
        raise Error, "tokenizer is required for text inputs" unless @tokenizer

        tokenize_batch(inputs)
      elsif inputs.first.is_a?(Array)
        # Raw feature vectors
        input_name = @session.inputs.first[:name]
        { input_name => inputs }
      else
        raise Error, "inputs must be Strings or Arrays"
      end
    end

    # Encode texts (batched if the tokenizer supports it), pad every row to
    # the batch max length with 0, and map ids/masks onto the model's input
    # names by substring match.
    def tokenize_batch(texts)
      if @tokenizer.respond_to?(:encode_batch)
        encodings = @tokenizer.encode_batch(texts)
        ids = encodings.map(&:ids)
        masks = encodings.map(&:attention_mask)
      else
        encodings = texts.map { |t| @tokenizer.encode(t) }
        ids = encodings.map(&:ids)
        masks = encodings.map(&:attention_mask)
      end

      max_len = ids.map(&:length).max
      ids = ids.map { |row| row + Array.new(max_len - row.length, 0) }
      masks = masks.map { |row| row + Array.new(max_len - row.length, 0) }

      input_names = @session.inputs.map { |i| i[:name] }
      feed = {}
      feed[input_names.find { |n| n.include?("input_id") } || input_names[0]] = ids
      mask_name = input_names.find { |n| n.include?("mask") || n.include?("attention") }
      feed[mask_name] = masks if mask_name
      feed
    end

    # Pick the first matching output by conventional name, else whatever the
    # model produced first.
    def find_output(result, candidate_names)
      candidate_names.each { |name| return result[name] if result.key?(name) }
      result.values.first
    end

    # Turn one logits row into { label:, score:, scores: } via softmax.
    def format_prediction(logits_row)
      probs = softmax(logits_row)
      max_idx = probs.each_with_index.max_by(&:first).last
      label = @labels ? @labels[max_idx] : max_idx

      { label: label, score: probs[max_idx], scores: probs }
    end

    # Numerically-stable softmax (shifts by the max before exponentiating).
    def softmax(logits)
      max_val = logits.max
      exps = logits.map { |v| Math.exp(v - max_val) }
      sum = exps.sum
      exps.map { |v| v / sum }
    end
  end
end
@@ -0,0 +1,16 @@
1
+ # frozen_string_literal: true
2
+
3
module OnnxRuby
  # Mutable holder for library-wide defaults (model directory, execution
  # providers, logging, and session-pool sizing).
  class Configuration
    attr_accessor :models_path, :default_providers, :default_log_level,
                  :pool_size, :pool_timeout

    # Seed every setting with its out-of-the-box default.
    def initialize
      self.models_path = "app/models/onnx"
      self.default_providers = [:cpu]
      self.default_log_level = :warning
      self.pool_size = 5
      self.pool_timeout = 5
    end
  end
end
@@ -0,0 +1,147 @@
1
+ # frozen_string_literal: true
2
+
3
module OnnxRuby
  # Embedding convenience wrapper around Session: turns texts (via a
  # tokenizer) or pre-tokenized tensor hashes into embedding vectors, with
  # optional mean pooling (for 3-D hidden-state outputs) and L2
  # normalization.
  #
  # NOTE(review): @_masks is stashed as instance state by prepare_inputs /
  # build_feed and consumed later in embed_batch, so instances are NOT
  # thread-safe across concurrent embed_batch calls.
  class Embedder
    attr_reader :session

    # @param model_path [String] path to the ONNX model file
    # @param tokenizer [Object, String, Symbol, nil] object responding to
    #   #encode / #encode_batch, or a pretrained name loaded via the
    #   "tokenizers" gem; nil when callers pass tensor hashes directly
    # @param normalize [Boolean] L2-normalize each embedding (default true)
    # @param session_opts [Hash] forwarded to Session.new
    def initialize(model_path, tokenizer: nil, normalize: true, **session_opts)
      @session = Session.new(model_path, **session_opts)
      @normalize = normalize
      @tokenizer = resolve_tokenizer(tokenizer)
    end

    # Embed a single text or pre-tokenized input
    # @param input [String, Hash] text string (requires tokenizer) or hash of input tensors
    # @return [Array<Float>] embedding vector
    def embed(input)
      embed_batch([input]).first
    end

    # Embed a batch of texts or pre-tokenized inputs
    # @param inputs [Array<String>, Array<Hash>] batch of texts or tensor hashes
    # @return [Array<Array<Float>>] array of embedding vectors
    def embed_batch(inputs)
      # Reset the stash so a previous call's masks can't leak into pooling.
      @_masks = nil
      feed = prepare_inputs(inputs)
      result = @session.run(feed)

      raw = find_output(result, %w[embeddings sentence_embedding output last_hidden_state])

      # If output is 3D (batch, seq_len, dim) — do mean pooling
      embeddings = if raw.first.is_a?(Array) && raw.first.first.is_a?(Array)
        mean_pool(raw, @_masks)
      else
        raw
      end

      embeddings.map { |vec| @normalize ? l2_normalize(vec) : vec }
    end

    private

    # Accept any duck-typed tokenizer (#encode or #encode_batch), or lazily
    # load one by name via the tokenizers gem.
    def resolve_tokenizer(tokenizer)
      return nil if tokenizer.nil?

      if tokenizer.respond_to?(:encode) || tokenizer.respond_to?(:encode_batch)
        tokenizer
      else
        begin
          require "tokenizers"
          Tokenizers::Tokenizer.from_pretrained(tokenizer.to_s)
        rescue LoadError
          raise Error, "tokenizer-ruby gem is required for text tokenization. " \
                       "Install with: gem install tokenizers"
        end
      end
    end

    # Build the feed hash from either texts (tokenized) or tensor hashes
    # (merged into batched arrays). Side effect: may set @_masks.
    def prepare_inputs(inputs)
      if inputs.first.is_a?(String)
        raise Error, "tokenizer is required for text inputs" unless @tokenizer

        tokenize_batch(inputs)
      elsif inputs.first.is_a?(Hash)
        # Merge hash inputs into batched arrays
        merge_input_hashes(inputs)
      else
        raise Error, "inputs must be Strings or Hashes"
      end
    end

    # Encode texts (batched if supported), pad rows to the batch max length
    # with 0, and hand off to build_feed.
    def tokenize_batch(texts)
      if @tokenizer.respond_to?(:encode_batch)
        encodings = @tokenizer.encode_batch(texts)
        ids = encodings.map(&:ids)
        masks = encodings.map(&:attention_mask)
      else
        encodings = texts.map { |t| @tokenizer.encode(t) }
        ids = encodings.map(&:ids)
        masks = encodings.map(&:attention_mask)
      end

      # Pad to max length
      max_len = ids.map(&:length).max
      ids = ids.map { |row| row + Array.new(max_len - row.length, 0) }
      masks = masks.map { |row| row + Array.new(max_len - row.length, 0) }

      build_feed(ids, masks)
    end

    # Batch per-example tensor hashes by key. Assumes every hash has the
    # same keys as the first one.
    def merge_input_hashes(hashes)
      result = {}
      hashes.first.each_key do |key|
        result[key] = hashes.map { |h| h[key] }
      end
      # Stash masks for mean pooling
      mask_key = result.keys.find { |k| k.to_s.include?("mask") || k.to_s.include?("attention") }
      @_masks = result[mask_key] if mask_key
      result
    end

    # Map ids/masks onto the model's input names by substring match, and
    # synthesize zero token_type_ids when the model declares such an input.
    def build_feed(ids, masks)
      input_names = @session.inputs.map { |i| i[:name] }
      feed = {}
      feed[input_names.find { |n| n.include?("input_id") } || input_names[0]] = ids
      mask_name = input_names.find { |n| n.include?("mask") || n.include?("attention") }
      feed[mask_name] = masks if mask_name
      # Supply token_type_ids (zeros) if the model expects it
      tti_name = input_names.find { |n| n.include?("token_type") }
      feed[tti_name] = ids.map { |row| Array.new(row.length, 0) } if tti_name
      @_masks = masks # stash for mean pooling
      feed
    end

    # Pick the first matching output by conventional name, else whatever
    # the model produced first.
    def find_output(result, candidate_names)
      candidate_names.each { |name| return result[name] if result.key?(name) }
      result.values.first
    end

    # Mean pooling over token embeddings, masked by attention_mask.
    # Tokens with a zero mask weight are skipped; with no mask, every token
    # gets weight 1.0.
    def mean_pool(hidden_states, masks)
      hidden_states.each_with_index.map do |tokens, batch_idx|
        mask = masks && masks[batch_idx]
        dim = tokens.first.length
        sum = Array.new(dim, 0.0)
        count = 0.0

        tokens.each_with_index do |token_vec, tok_idx|
          w = (mask && mask[tok_idx]) ? mask[tok_idx].to_f : 1.0
          next if w.zero?

          count += w
          token_vec.each_with_index { |v, d| sum[d] += v * w }
        end

        # Guard against division by zero when every token was masked out.
        count = 1.0 if count.zero?
        sum.map { |v| v / count }
      end
    end

    # Scale a vector to unit L2 norm; zero vectors pass through unchanged.
    def l2_normalize(vec)
      norm = Math.sqrt(vec.sum { |v| v * v })
      return vec if norm.zero?

      vec.map { |v| v / norm }
    end
  end
end