executorch 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +47 -0
- data/LICENSE.txt +176 -0
- data/README.md +198 -0
- data/ext/executorch/executorch.cpp +582 -0
- data/ext/executorch/extconf.rb +208 -0
- data/ext/executorch/utils.h +140 -0
- data/lib/executorch/version.rb +3 -0
- data/lib/executorch.rb +212 -0
- metadata +66 -0
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
require 'mkmf-rice'
|
|
2
|
+
|
|
3
|
+
# C++ compile flags for the extension: C++17, the custom-generated-macros
# define, and suppression of deprecation warnings.
[
  ' -std=c++17',
  ' -DC10_USING_CUSTOM_GENERATED_MACROS',
  ' -Wno-deprecated-declarations'
].each { |flag| $CXXFLAGS += flag }
|
|
6
|
+
|
|
7
|
+
# ==============================================================================
# ExecuTorch Path Detection
# ==============================================================================
#
# The install prefix is looked up in this order (first match wins):
# 1. --with-executorch-dir flag (bundle config or gem install)
# 2. EXECUTORCH_DIR environment variable
#
# A usable prefix must contain both an include/ and a lib/ subdirectory.

# Header (-I) and library (-L) search directories collected from the prefix
include_dirs, lib_dirs = [], []
|
|
17
|
+
|
|
18
|
+
# Append <prefix>/include and <prefix>/lib to the given lists when both exist.
#
# @param prefix [String, nil] candidate ExecuTorch install prefix. `~` and
#   relative paths are expanded to an absolute path first, so the resulting
#   -I/-L flags and rpath do not depend on the build directory.
# @param include_dirs [Array<String>] appended to in place on success
# @param lib_dirs [Array<String>] appended to in place on success
# @return [Boolean] true if both subdirectories were found and appended;
#   false otherwise (the lists are left untouched)
def add_prefix_paths(prefix, include_dirs, lib_dirs)
  # nil (unset ENV var), true (bare `--with-executorch-dir` flag) and "" cannot
  # name a directory; File.directory?(true) would raise TypeError.
  return false unless prefix.is_a?(String) && !prefix.empty?

  root = begin
    File.expand_path(prefix)
  rescue ArgumentError
    # e.g. "~nosuchuser/..." or "~" with HOME unset
    return false
  end
  return false unless File.directory?(root)

  inc = File.join(root, 'include')
  lib = File.join(root, 'lib')
  return false unless File.directory?(inc) && File.directory?(lib)

  include_dirs << inc
  lib_dirs << lib
  true
end
|
|
33
|
+
|
|
34
|
+
# Priority 1: --with-executorch-dir flag (standard Ruby gem pattern)
# Usage: bundle config set --local build.executorch --with-executorch-dir=/path/to/executorch
#    or: gem install executorch -- --with-executorch-dir=/path/to/executorch
executorch_dir = arg_config('--with-executorch-dir')
if executorch_dir
  # A bare `--with-executorch-dir` (no `=PATH`) makes arg_config return true
  # instead of a String; report that explicitly rather than as a bad path.
  unless executorch_dir.is_a?(String) && !executorch_dir.empty?
    abort 'Error: --with-executorch-dir requires a path (e.g. --with-executorch-dir=/path/to/executorch)'
  end

  if add_prefix_paths(executorch_dir, include_dirs, lib_dirs)
    puts "Using --with-executorch-dir: #{executorch_dir}"
  else
    abort "Error: --with-executorch-dir path is invalid: #{executorch_dir}"
  end
end

# Priority 2: Environment variable (for CI/scripting)
if include_dirs.empty?
  executorch_prefix = ENV['EXECUTORCH_DIR']
  if add_prefix_paths(executorch_prefix, include_dirs, lib_dirs)
    puts "Using EXECUTORCH_DIR: #{executorch_prefix}"
  elsif executorch_prefix && !executorch_prefix.empty?
    # An explicitly set but unusable EXECUTORCH_DIR is a configuration mistake;
    # name it directly instead of falling through to the generic message below.
    abort "Error: EXECUTORCH_DIR path is invalid (expected include/ and lib/ subdirectories): #{executorch_prefix}"
  end
end

include_dirs.compact!
include_dirs.uniq!
lib_dirs.compact!
lib_dirs.uniq!

if include_dirs.empty? || lib_dirs.empty?
  abort <<~ERROR
    ExecuTorch installation not found!

    Configure the path using one of these methods:

    1. Bundle config (recommended - set once per project):
       bundle config set --local build.executorch --with-executorch-dir=/path/to/executorch

    2. Environment variable (useful for CI):
       EXECUTORCH_DIR=/path/to/executorch bundle install

    Need to build ExecuTorch first? See: https://pytorch.org/executorch/
  ERROR
end
|
|
74
|
+
|
|
75
|
+
# Validate headers exist: the extension needs the extension/module headers,
# and per the message below these require EXECUTORCH_BUILD_EXTENSION_MODULE=ON.
include_dir = include_dirs.first
lib_dir = lib_dirs.first

module_header = File.join(include_dir, 'executorch', 'extension', 'module', 'module.h')
unless File.exist?(module_header)
  abort <<~ERROR
    ExecuTorch module.h header not found at: #{include_dir}
    Make sure EXECUTORCH_BUILD_EXTENSION_MODULE=ON was set during build.
  ERROR
end

puts "Include directory: #{include_dir}"
puts "Library directory: #{lib_dir}"

# Configure include paths.
# The portable c10 headers path is added too, for standalone builds.
portable_c10_dir = File.join(include_dir, 'executorch', 'runtime', 'core', 'portable_type', 'c10')
$INCFLAGS = "-I#{include_dir} -I#{portable_c10_dir} #{$INCFLAGS}"

# Configure library search path
$LDFLAGS += " -L#{lib_dir}"

# Add rpath for runtime library loading. The original darwin and non-darwin
# branches emitted the identical `-Wl,-rpath,DIR` flag, so no platform switch
# is needed.
$LDFLAGS += " -Wl,-rpath,#{lib_dir}"
|
|
103
|
+
|
|
104
|
+
# ==============================================================================
# Library Linking Configuration
# ==============================================================================

# Libraries linked by default. Order matters for static linking: the extension
# libraries are listed (and therefore linked) before the core libraries.
DEFAULT_LIBS = %w[
  extension_module_static
  extension_data_loader
  extension_tensor
  extension_named_data_map
  extension_flat_tensor
  extension_threadpool
  executorch
  executorch_core
].freeze

# EXECUTORCH_LIBS (comma-separated) replaces the default list entirely.
custom_libs = ENV['EXECUTORCH_LIBS']
libs =
  if custom_libs
    user_libs = custom_libs.split(',').map(&:strip).reject(&:empty?)
    puts "Using custom library list: #{user_libs.join(', ')}"
    user_libs
  else
    DEFAULT_LIBS.dup
  end

# Link each library that exists as a static archive in lib_dir.
# Missing extension_* libraries are skipped; any other missing library aborts.
libs.each do |lib|
  primary_archive = File.join(lib_dir, "lib#{lib}.a")
  if File.exist?(primary_archive)
    $LDFLAGS += " -l#{lib}"
    puts " Linking: #{lib}"
    next
  end

  # Fall back to the archive name without a trailing `_static` suffix
  fallback = lib.sub(/_static$/, '')
  if File.exist?(File.join(lib_dir, "lib#{fallback}.a"))
    $LDFLAGS += " -l#{fallback}"
    puts " Linking: #{fallback} (alternative)"
  elsif lib.include?('extension_')
    puts " Skipping optional: #{lib} (not found)"
  else
    abort "Required library not found: #{lib} (looked for #{primary_archive})"
  end
end
|
|
152
|
+
|
|
153
|
+
# Portable kernels (optional). The ops library registers kernels from global
# constructors, so the whole archive is force-loaded to keep those objects.
portable_ops_lib = File.join(lib_dir, 'libportable_ops_lib.a')
if File.exist?(portable_ops_lib)
  whole_archive_flags =
    if RUBY_PLATFORM =~ /darwin/
      # macOS: -force_load pulls in all symbols including global constructors
      " -Wl,-force_load,#{portable_ops_lib}"
    else
      # Linux: --whole-archive achieves the same effect
      ' -Wl,--whole-archive -lportable_ops_lib -Wl,--no-whole-archive'
    end
  $LDFLAGS += whole_archive_flags
  puts ' Linking: portable_ops_lib (force_load)'

  # The portable kernels archive is only linked alongside the ops library
  kernels_archive = File.join(lib_dir, 'libportable_kernels.a')
  if File.exist?(kernels_archive)
    $LDFLAGS += ' -lportable_kernels'
    puts ' Linking: portable_kernels'
  end
end

# CPU info and pthreadpool (often required); linked only when present
%w[cpuinfo pthreadpool]
  .select { |name| File.exist?(File.join(lib_dir, "lib#{name}.a")) }
  .each do |name|
    $LDFLAGS += " -l#{name}"
    puts " Linking: #{name}"
  end

# Extra libraries from EXECUTORCH_EXTRA_LIBS (comma-separated), linked after
# everything above; a missing one only produces a warning.
extra_libs_env = ENV['EXECUTORCH_EXTRA_LIBS']
if extra_libs_env
  extra_libs_env.split(',').map(&:strip).reject(&:empty?).each do |name|
    if File.exist?(File.join(lib_dir, "lib#{name}.a"))
      $LDFLAGS += " -l#{name}"
      puts " Linking extra: #{name}"
    else
      puts " Warning: extra library not found: #{name}"
    end
  end
end
|
|
193
|
+
|
|
194
|
+
# ==============================================================================
# Platform-Specific Configuration
# ==============================================================================

# Apple Silicon specific (currently informational only)
apple_silicon = RUBY_PLATFORM.match?(/darwin/) && RUBY_PLATFORM.match?(/arm64/)
puts 'Building for Apple Silicon (arm64)' if apple_silicon

# Linux: link pthread explicitly
$LDFLAGS += ' -lpthread' if RUBY_PLATFORM.match?(/linux/)

# Generate the Makefile for the executorch/executorch extension target
create_makefile('executorch/executorch')

puts "\nConfiguration complete. Run 'make' to build."
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
#ifndef EXECUTORCH_RUBY_UTILS_H
|
|
2
|
+
#define EXECUTORCH_RUBY_UTILS_H
|
|
3
|
+
|
|
4
|
+
#include <rice/rice.hpp>
|
|
5
|
+
#include <rice/stl.hpp>
|
|
6
|
+
#include <executorch/runtime/core/error.h>
|
|
7
|
+
#include <executorch/runtime/core/result.h>
|
|
8
|
+
|
|
9
|
+
// Error handling macros - translate C++ exceptions to Ruby exceptions.
//
// Usage:
//   HANDLE_ET_ERRORS
//     ... code that may throw ...
//   END_HANDLE_ET_ERRORS
//
// A Rice::Exception (which already carries a Ruby exception) is rethrown
// untouched. Any other std::exception is rethrown as a Rice::Exception that
// wraps a RuntimeError ("ExecuTorch error: <what()>"). This assumes the
// enclosing function is invoked through a Rice binding, which converts the
// Rice::Exception into a Ruby exception after the C++ stack has unwound.
//
// NOTE: rb_raise must not be called from inside a catch handler: it longjmps
// out of the handler, so the in-flight C++ exception is never released and
// the C++ runtime's exception bookkeeping is left inconsistent.
#define HANDLE_ET_ERRORS try {

#define END_HANDLE_ET_ERRORS                                                    \
  }                                                                             \
  catch (const Rice::Exception &) {                                             \
    throw;                                                                      \
  }                                                                             \
  catch (const std::exception &ex) {                                            \
    throw Rice::Exception(rb_eRuntimeError, "ExecuTorch error: %s", ex.what()); \
  }
|
|
20
|
+
|
|
21
|
+
namespace executorch_ruby {
|
|
22
|
+
|
|
23
|
+
// Raise a Ruby RuntimeError describing `error` unless it is Error::Ok.
//
// The message has the form "ExecuTorch error: <description>". Error codes
// without a dedicated description are reported with their numeric value.
//
// Throws Rice::Exception rather than calling rb_raise: rb_raise longjmps
// across C++ frames and would skip destructors of live locals in callers.
// Assumes callers run inside a Rice-wrapped function (or a
// HANDLE_ET_ERRORS block), where the Rice::Exception becomes the Ruby error.
inline void check_error(executorch::runtime::Error error) {
  using executorch::runtime::Error;

  if (error == Error::Ok) {
    return;
  }

  const char* error_name = nullptr;
  switch (error) {
    case Error::Internal:
      error_name = "Internal error";
      break;
    case Error::InvalidState:
      error_name = "Invalid state";
      break;
    case Error::EndOfMethod:
      error_name = "End of method";
      break;
    case Error::NotSupported:
      error_name = "Not supported";
      break;
    case Error::NotImplemented:
      error_name = "Not implemented";
      break;
    case Error::InvalidArgument:
      error_name = "Invalid argument";
      break;
    case Error::InvalidType:
      error_name = "Invalid type";
      break;
    case Error::OperatorMissing:
      // Typically a kernel that was not registered/linked into the binary
      error_name = "Operator missing";
      break;
    case Error::NotFound:
      error_name = "Not found";
      break;
    case Error::MemoryAllocationFailed:
      error_name = "Memory allocation failed";
      break;
    case Error::AccessFailed:
      error_name = "Access failed";
      break;
    case Error::InvalidProgram:
      error_name = "Invalid program";
      break;
    case Error::DelegateInvalidCompatibility:
      error_name = "Delegate invalid compatibility";
      break;
    case Error::DelegateMemoryAllocationFailed:
      error_name = "Delegate memory allocation failed";
      break;
    case Error::DelegateInvalidHandle:
      error_name = "Delegate invalid handle";
      break;
    default:
      break;
  }

  if (error_name == nullptr) {
    throw Rice::Exception(rb_eRuntimeError,
                          "ExecuTorch error: Unknown error (code 0x%02x)",
                          static_cast<unsigned int>(error));
  }
  throw Rice::Exception(rb_eRuntimeError, "ExecuTorch error: %s", error_name);
}
|
|
69
|
+
|
|
70
|
+
// Helper to unwrap Result<T> and raise Ruby exception on error
|
|
71
|
+
template <typename T>
|
|
72
|
+
T unwrap_result(executorch::runtime::Result<T>&& result) {
|
|
73
|
+
if (!result.ok()) {
|
|
74
|
+
check_error(result.error());
|
|
75
|
+
// Should never reach here, but just in case
|
|
76
|
+
rb_raise(rb_eRuntimeError, "ExecuTorch: unexpected error");
|
|
77
|
+
}
|
|
78
|
+
return std::move(result.get());
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
// Convert ExecuTorch ScalarType to Ruby symbol
|
|
82
|
+
// Uses short names (:int, :long, :float) to match input symbols
|
|
83
|
+
inline VALUE scalar_type_to_symbol(executorch::aten::ScalarType dtype) {
|
|
84
|
+
switch (dtype) {
|
|
85
|
+
case executorch::aten::ScalarType::Byte:
|
|
86
|
+
return ID2SYM(rb_intern("byte"));
|
|
87
|
+
case executorch::aten::ScalarType::Char:
|
|
88
|
+
return ID2SYM(rb_intern("char"));
|
|
89
|
+
case executorch::aten::ScalarType::Short:
|
|
90
|
+
return ID2SYM(rb_intern("short"));
|
|
91
|
+
case executorch::aten::ScalarType::Int:
|
|
92
|
+
return ID2SYM(rb_intern("int"));
|
|
93
|
+
case executorch::aten::ScalarType::Long:
|
|
94
|
+
return ID2SYM(rb_intern("long"));
|
|
95
|
+
case executorch::aten::ScalarType::Half:
|
|
96
|
+
return ID2SYM(rb_intern("half"));
|
|
97
|
+
case executorch::aten::ScalarType::Float:
|
|
98
|
+
return ID2SYM(rb_intern("float"));
|
|
99
|
+
case executorch::aten::ScalarType::Double:
|
|
100
|
+
return ID2SYM(rb_intern("double"));
|
|
101
|
+
case executorch::aten::ScalarType::Bool:
|
|
102
|
+
return ID2SYM(rb_intern("bool"));
|
|
103
|
+
default:
|
|
104
|
+
return ID2SYM(rb_intern("unknown"));
|
|
105
|
+
}
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
// Convert a Ruby Symbol naming a dtype into an ExecuTorch ScalarType.
//
// Accepts explicit-width names (:uint8, :int8, :int16, :int32, :int64,
// :float16, :float32, :float64), the short names (:byte, :char, :short, :int,
// :long, :half, :float, :double), and :bool.
//
// Raises TypeError if `sym` is not a Symbol, and ArgumentError
// ("Unknown dtype: <name>") for any other name. Errors are thrown as
// Rice::Exception instead of rb_raise so C++ destructors in callers still
// run; assumes a Rice-wrapped caller.
inline executorch::aten::ScalarType symbol_to_scalar_type(VALUE sym) {
  using executorch::aten::ScalarType;

  if (!RB_TYPE_P(sym, T_SYMBOL)) {
    throw Rice::Exception(rb_eTypeError, "Expected Symbol for dtype");
  }

  struct DtypeAlias {
    const char* name;
    ScalarType type;
  };
  static const DtypeAlias kAliases[] = {
      {"uint8", ScalarType::Byte},     {"byte", ScalarType::Byte},
      {"int8", ScalarType::Char},      {"char", ScalarType::Char},
      {"int16", ScalarType::Short},    {"short", ScalarType::Short},
      {"int32", ScalarType::Int},      {"int", ScalarType::Int},
      {"int64", ScalarType::Long},     {"long", ScalarType::Long},
      {"float16", ScalarType::Half},   {"half", ScalarType::Half},
      {"float32", ScalarType::Float},  {"float", ScalarType::Float},
      {"float64", ScalarType::Double}, {"double", ScalarType::Double},
      {"bool", ScalarType::Bool},
  };

  ID id = SYM2ID(sym);
  for (const auto& alias : kAliases) {
    if (id == rb_intern(alias.name)) {
      return alias.type;
    }
  }
  throw Rice::Exception(rb_eArgError, "Unknown dtype: %s", rb_id2name(id));
}
|
|
137
|
+
|
|
138
|
+
} // namespace executorch_ruby
|
|
139
|
+
|
|
140
|
+
#endif // EXECUTORCH_RUBY_UTILS_H
|
data/lib/executorch.rb
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "executorch/version"
|
|
4
|
+
|
|
5
|
+
# Load the compiled native extension. If it cannot be loaded (for example the
# extension has not been compiled yet), print a hint to stderr and re-raise
# the original LoadError.
begin
  require_relative "executorch/executorch"
rescue LoadError => load_error
  warn "Failed to load ExecuTorch native extension: #{load_error.message}"
  warn "Make sure to run: bundle exec rake compile"
  raise
end
|
|
13
|
+
|
|
14
|
+
module Executorch
|
|
15
|
+
# Ruby-friendly extensions to the native Tensor class.
#
# Relies on members supplied by the native extension: the class method
# `create(flat_data, shape, dtype)` and the instance methods `shape` and
# `_original_to_a`. This reopening layers nested-array construction and
# conversion on top of them.
class Tensor
  class << self
    # Create a new tensor from data
    #
    # @param data [Array] Array of numeric values (can be nested or flat)
    # @param shape [Array<Integer>, nil] Shape of the tensor. If nil, inferred from data structure.
    # @param dtype [Symbol] Data type (:float, :double, :int, :long)
    # @return [Tensor]
    # @raise [ArgumentError] when shape is nil and data is not an Array, or is
    #   jagged / inconsistently nested
    #
    # @example Create from nested array (shape inferred)
    #   Tensor.new([[1.0, 2.0], [3.0, 4.0]]) # shape: [2, 2]
    #
    # @example Create from flat array with explicit shape
    #   Tensor.new([1.0, 2.0, 3.0, 4.0], shape: [2, 2])
    #
    # @example Nested array with explicit shape (flattened row-major first)
    #   Tensor.new([[1.0, 2.0], [3.0, 4.0]], shape: [4])
    #
    def new(data, shape: nil, dtype: :float)
      if shape.nil?
        # Infer shape from nested array structure
        shape = infer_shape(data)
        flat_data = flatten_nested(data, shape)
      elsif data.is_a?(Array) && data.any?(Array)
        # Explicit shape with nested data: flatten so `create` always receives
        # a flat list.
        flat_data = deep_flatten(data)
      else
        # Flat (or non-Array) data is passed through unchanged
        flat_data = data
      end

      create(flat_data, shape, dtype)
    end

    private

    # Infer the shape from a nested array structure by following the first
    # element at each level, then validate every level against that shape.
    # @param data [Array] Potentially nested array
    # @return [Array<Integer>] Inferred shape ([0] for an empty array)
    # @raise [ArgumentError] If data is not an Array, or is jagged/inconsistent
    def infer_shape(data)
      unless data.is_a?(Array)
        raise ArgumentError, "Expected Array for tensor data, got #{data.class}"
      end
      return [0] if data.empty?

      shape = []
      current = data

      while current.is_a?(Array)
        shape << current.size
        break if current.empty?

        current = current.first
      end

      # Validate that all elements at each level have consistent sizes
      validate_shape(data, shape, 0)

      shape
    end

    # Validate that the array has consistent shape at all levels
    # @param data [Array] The data to validate
    # @param expected_shape [Array<Integer>] The expected shape
    # @param depth [Integer] Current depth in the array
    # @raise [ArgumentError] If the array is jagged or inconsistent
    def validate_shape(data, expected_shape, depth)
      return if depth >= expected_shape.size

      unless data.is_a?(Array)
        raise ArgumentError, "Inconsistent nesting depth at level #{depth}: expected Array, got #{data.class}"
      end

      # Zero-sized dimensions are checked like any other. Skipping them would
      # accept data such as [[], [1, 2]] (inferred shape [2, 0]) and then
      # silently drop its elements.
      unless data.size == expected_shape[depth]
        raise ArgumentError, "Jagged array at depth #{depth}: expected size #{expected_shape[depth]}, got #{data.size}"
      end

      data.each_with_index do |element, i|
        if depth + 1 < expected_shape.size
          # Expect more nesting
          unless element.is_a?(Array)
            raise ArgumentError, "Inconsistent nesting at depth #{depth}, index #{i}: expected Array, got #{element.class}"
          end
          validate_shape(element, expected_shape, depth + 1)
        elsif element.is_a?(Array)
          # At leaf level, elements should be scalars
          raise ArgumentError, "Inconsistent nesting at depth #{depth}, index #{i}: unexpected Array at leaf level"
        end
      end
    end

    # Flatten a nested array into a 1D array
    # @param data [Array] Potentially nested array
    # @param shape [Array<Integer>] The shape (used to handle empty arrays)
    # @return [Array] Flat array
    def flatten_nested(data, shape)
      return [] if shape.include?(0)

      deep_flatten(data)
    end

    # Recursively flatten an array (row-major order)
    # @param data [Array, Numeric] The data to flatten
    # @return [Array] Flat array
    def deep_flatten(data)
      return [data] unless data.is_a?(Array)

      data.flat_map { |element| deep_flatten(element) }
    end
  end

  # Convert tensor to a nested Ruby array matching the tensor's shape.
  # Tensors with fewer than two dimensions are returned flat.
  # @return [Array] Nested array with structure matching shape
  #
  # @example
  #   tensor = Tensor.new([1, 2, 3, 4], shape: [2, 2])
  #   tensor.to_a # => [[1.0, 2.0], [3.0, 4.0]]
  #
  def to_a
    reshape_flat_to_nested(flat_to_a, shape)
  end

  # Convert tensor to a flat Ruby array (original behavior)
  # @return [Array] Flat array of all values
  alias_method :flat_to_a, :_original_to_a

  private

  # Reshape a flat array into nested arrays according to shape
  # @param flat [Array] Flat array of values
  # @param shape [Array<Integer>] Target shape
  # @return [Array] Nested array
  def reshape_flat_to_nested(flat, shape)
    # 0-D and 1-D shapes are already in their final (flat) form
    return flat if shape.size <= 1

    # Any zero dimension means there are no values to distribute
    return build_empty_nested(shape) if shape.include?(0)

    # Row-major fill: recurse over dimensions, consuming `flat` in order
    build_nested(flat, shape, 0, 0).first
  end

  # Build nested array structure
  # @param flat [Array] Flat data
  # @param shape [Array<Integer>] Shape
  # @param dim [Integer] Current dimension
  # @param offset [Integer] Current offset in flat array
  # @return [Array] [nested_result, new_offset]
  def build_nested(flat, shape, dim, offset)
    if dim == shape.size - 1
      # Last dimension: slice the flat array
      [flat[offset, shape[dim]], offset + shape[dim]]
    else
      # Build sub-arrays, threading the offset through each one
      result = []
      current_offset = offset
      shape[dim].times do
        sub_result, current_offset = build_nested(flat, shape, dim + 1, current_offset)
        result << sub_result
      end
      [result, current_offset]
    end
  end

  # Build empty nested structure for shapes with zero dimensions,
  # e.g. [2, 0] => [[], []] and [0, 3] => [].
  # @param shape [Array<Integer>] Shape with at least one zero
  # @return [Array] Empty nested structure
  def build_empty_nested(shape)
    build_empty_recursive(shape, 0)
  end

  # Recursively build empty nested arrays up to the first zero dimension
  # @param shape [Array<Integer>] Shape
  # @param dim [Integer] Current dimension
  # @return [Array] Empty nested structure
  def build_empty_recursive(shape, dim)
    return [] if dim >= shape.size || shape[dim] == 0

    Array.new(shape[dim]) { build_empty_recursive(shape, dim + 1) }
  end
end
|
|
201
|
+
|
|
202
|
+
# Ruby-friendly extensions to the native Model class.
class Model
  # `predict` is a synonym for `forward`. The alias binds to the `forward`
  # defined at this point, so later redefinitions of `forward` are not
  # picked up by `predict`.
  alias predict forward

  # Make the model callable: `model.call(inputs)` or `model.(inputs)`.
  #
  # @param inputs passed straight through to #forward
  # @return whatever #forward returns
  def call(inputs) = forward(inputs)
end
|
|
212
|
+
end
|
metadata
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
|
2
|
+
name: executorch
|
|
3
|
+
version: !ruby/object:Gem::Version
|
|
4
|
+
version: 0.1.0
|
|
5
|
+
platform: ruby
|
|
6
|
+
authors:
|
|
7
|
+
- Benjamin Garcia
|
|
8
|
+
bindir: bin
|
|
9
|
+
cert_chain: []
|
|
10
|
+
date: 1980-01-02 00:00:00.000000000 Z
|
|
11
|
+
dependencies:
|
|
12
|
+
- !ruby/object:Gem::Dependency
|
|
13
|
+
name: rice
|
|
14
|
+
requirement: !ruby/object:Gem::Requirement
|
|
15
|
+
requirements:
|
|
16
|
+
- - "~>"
|
|
17
|
+
- !ruby/object:Gem::Version
|
|
18
|
+
version: '4.3'
|
|
19
|
+
type: :runtime
|
|
20
|
+
prerelease: false
|
|
21
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
22
|
+
requirements:
|
|
23
|
+
- - "~>"
|
|
24
|
+
- !ruby/object:Gem::Version
|
|
25
|
+
version: '4.3'
|
|
26
|
+
description: Run PyTorch models exported with ExecuTorch in Ruby
|
|
27
|
+
email:
|
|
28
|
+
- hey@bengarcia.dev
|
|
29
|
+
executables: []
|
|
30
|
+
extensions:
|
|
31
|
+
- ext/executorch/extconf.rb
|
|
32
|
+
extra_rdoc_files: []
|
|
33
|
+
files:
|
|
34
|
+
- CHANGELOG.md
|
|
35
|
+
- LICENSE.txt
|
|
36
|
+
- README.md
|
|
37
|
+
- ext/executorch/executorch.cpp
|
|
38
|
+
- ext/executorch/extconf.rb
|
|
39
|
+
- ext/executorch/utils.h
|
|
40
|
+
- lib/executorch.rb
|
|
41
|
+
- lib/executorch/version.rb
|
|
42
|
+
homepage: https://github.com/benngarcia/executorch-ruby
|
|
43
|
+
licenses:
|
|
44
|
+
- Apache-2.0
|
|
45
|
+
metadata:
|
|
46
|
+
homepage_uri: https://github.com/benngarcia/executorch-ruby
|
|
47
|
+
source_code_uri: https://github.com/benngarcia/executorch-ruby
|
|
48
|
+
changelog_uri: https://github.com/benngarcia/executorch-ruby/blob/main/CHANGELOG.md
|
|
49
|
+
rdoc_options: []
|
|
50
|
+
require_paths:
|
|
51
|
+
- lib
|
|
52
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
|
53
|
+
requirements:
|
|
54
|
+
- - ">="
|
|
55
|
+
- !ruby/object:Gem::Version
|
|
56
|
+
version: 3.0.0
|
|
57
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
|
58
|
+
requirements:
|
|
59
|
+
- - ">="
|
|
60
|
+
- !ruby/object:Gem::Version
|
|
61
|
+
version: '0'
|
|
62
|
+
requirements: []
|
|
63
|
+
rubygems_version: 3.6.9
|
|
64
|
+
specification_version: 4
|
|
65
|
+
summary: Ruby bindings for ExecuTorch
|
|
66
|
+
test_files: []
|