tensorflow 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 32ac7a47334631b2c5393721e5a7f848a26d9535b473557e146e8e6b4b6337be
4
+ data.tar.gz: 21ca600101471ee43edd9a404f3072d67cbbdfa44f412ba8ec76a2546973087d
5
+ SHA512:
6
+ metadata.gz: 601c99ff6035138797f0e14a0f5f2c52ae68b4ee4c5223ef5676e45aaf2f61a7994f0a2ce2ad353b03fa503d7ad8fd6e9c49ba7567fbdcee2156bffe6d36a4c4
7
+ data.tar.gz: 825d6b40f58f3cbe4d1c6f5b9cff5c7a739b730a89e025bf305bcdf3789269ce34f851709a0294d2c830ca1621a08ff3a67b352422bafd92eb8dec30f65e42e3
data/CHANGELOG.md ADDED
@@ -0,0 +1,3 @@
1
+ ## 0.1.0
2
+
3
+ - First release
data/LICENSE.txt ADDED
@@ -0,0 +1,22 @@
1
+ Copyright (c) 2019 Andrew Kane
2
+
3
+ MIT License
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining
6
+ a copy of this software and associated documentation files (the
7
+ "Software"), to deal in the Software without restriction, including
8
+ without limitation the rights to use, copy, modify, merge, publish,
9
+ distribute, sublicense, and/or sell copies of the Software, and to
10
+ permit persons to whom the Software is furnished to do so, subject to
11
+ the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be
14
+ included in all copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
data/README.md ADDED
@@ -0,0 +1,113 @@
1
+ # TensorFlow
2
+
3
+ [TensorFlow](https://github.com/tensorflow/tensorflow) - the end-to-end machine learning platform - for Ruby
4
+
5
+ :fire: Uses the C API for blazing performance
6
+
7
+ ## Installation
8
+
9
+ [Install TensorFlow](#tensorflow-installation). For Homebrew, use:
10
+
11
+ ```sh
12
+ brew install tensorflow
13
+ ```
14
+
15
+ Add this line to your application’s Gemfile:
16
+
17
+ ```ruby
18
+ gem 'tensorflow'
19
+ ```
20
+
21
+ ## Getting Started
22
+
23
+ This library follows the TensorFlow 2.0 [Python API](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf). Many methods and options are missing at the moment. PRs welcome!
24
+
25
+ ## Constants
26
+
27
+ ```ruby
28
+ a = Tf.constant(2)
29
+ b = Tf.constant(3)
30
+ a + b
31
+ ```
32
+
33
+ ## Variables
34
+
35
+ ```ruby
36
+ v = Tf::Variable.new(0.0)
37
+ w = v + 1
38
+ ```
39
+
40
+ ## FizzBuzz
41
+
42
+ ```ruby
43
+ def fizzbuzz(max_num)
44
+ counter = Tf.constant(0)
45
+ max_num.times do |i|
46
+ num = Tf.constant(i + 1)
47
+ if (num % 3).to_i == 0 && (num % 5).to_i == 0
48
+ puts "FizzBuzz"
49
+ elsif (num % 3).to_i == 0
50
+ puts "Fizz"
51
+ elsif (num % 5).to_i == 0
52
+ puts "Buzz"
53
+ else
54
+ puts num.to_i
55
+ end
56
+ end
57
+ end
58
+
59
+ fizzbuzz(15)
60
+ ```
61
+
62
+ ## Keras
63
+
64
+ Coming soon
65
+
66
+ ```ruby
67
+ mnist = Tf::Keras::Datasets::MNIST
68
+ (x_train, y_train), (x_test, y_test) = mnist.load_data
69
+ x_train = x_train / 255.0
70
+ x_test = x_test / 255.0
71
+
72
+ model = Tf::Keras::Models::Sequential.new([
73
+ Tf::Keras::Layers::Flatten.new(input_shape: [28, 28]),
74
+ Tf::Keras::Layers::Dense.new(128, activation: "relu"),
75
+ Tf::Keras::Layers::Dropout.new(0.2),
76
+ Tf::Keras::Layers::Dense.new(10, activation: "softmax")
77
+ ])
78
+
79
+ model.compile(optimizer: "adam", loss: "sparse_categorical_crossentropy", metrics: ["accuracy"])
80
+ model.fit(x_train, y_train, epochs: 5)
81
+ model.evaluate(x_test, y_test)
82
+ ```
83
+
84
+ ## TensorFlow Installation
85
+
86
+ ### Mac
87
+
88
+ Run:
89
+
90
+ ```sh
91
+ brew install tensorflow
92
+ ```
93
+
94
+ ### Linux
95
+
96
+ Download the [shared library](https://www.tensorflow.org/install/lang_c#download) and move `libtensorflow.so` to `/usr/local/lib`.
97
+
98
+ ### Windows
99
+
100
+ Download the [shared library](https://www.tensorflow.org/install/lang_c#download) and move `tensorflow.dll` to `C:\Windows\System32`.
101
+
102
+ ## History
103
+
104
+ View the [changelog](https://github.com/ankane/tensorflow/blob/master/CHANGELOG.md)
105
+
106
+ ## Contributing
107
+
108
+ Everyone is encouraged to help improve this project. Here are a few ways you can help:
109
+
110
+ - [Report bugs](https://github.com/ankane/tensorflow/issues)
111
+ - Fix bugs and [submit pull requests](https://github.com/ankane/tensorflow/pulls)
112
+ - Write, clarify, or fix documentation
113
+ - Suggest or add new features
@@ -0,0 +1,26 @@
module TensorFlow
  # Owns an eager-execution context created through the TensorFlow C API.
  class Context
    # Creates a new context with default options.
    #
    # Raises whatever Utils.check_status raises when the C API reports a
    # failure while creating the context.
    def initialize
      options = FFI.TFE_NewContextOptions
      @status = FFI.TF_NewStatus

      begin
        @pointer = FFI.TFE_NewContext(options, @status)
        Utils.check_status @status
      ensure
        # the options are only needed while the context is being created,
        # so release them on the failure path as well
        FFI.TFE_DeleteContextOptions(options)
      end

      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # Builds the finalizer that deletes the native context.
    #
    # Defined on the class so the returned proc closes over +pointer+ only
    # and not over the instance being finalized.
    def self.finalize(pointer)
      # must use proc instead of stabby lambda
      proc { FFI.TFE_DeleteContext(pointer) }
    end

    # Returns the native context pointer.
    def to_ptr
      @pointer
    end

    # Returns the shared name passed to VarHandleOp when creating variables.
    def shared_name
      # hard-coded in Python library
      "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
    end
  end
end
@@ -0,0 +1,60 @@
module TensorFlow
  # Bindings to the TensorFlow C API, attached through the ffi gem.
  #
  # The library name(s) come from TensorFlow.ffi_lib. Loading this module
  # raises LoadError when the shared library cannot be found.
  module FFI
    extend ::FFI::Library

    begin
      ffi_lib TensorFlow.ffi_lib
    rescue LoadError => e
      # show the loader's own error only when debugging
      raise e if ENV["TENSORFLOW_DEBUG"]
      raise LoadError, "Could not find TensorFlow"
    end

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h
    attach_function :TF_Version, %i[], :string

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_attrtype.h
    AttrType = enum(:string, :int, :float, :bool, :type, :shape, :tensor, :placeholder, :func)

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_datatype.h
    DataType = enum(:float, 1, :double, :int32, :uint8, :int16, :int8, :string, :complex64, :int64, :bool, :qint8, :quint8, :qint32, :bfloat16, :qint16, :quint16, :uint16, :complex128, :half, :resource, :variant, :uint32, :uint64)

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_status.h
    attach_function :TF_NewStatus, %i[], :pointer
    # declared void in the C header
    attach_function :TF_DeleteStatus, %i[pointer], :void
    attach_function :TF_GetCode, %i[pointer], :int
    attach_function :TF_Message, %i[pointer], :string

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_tensor.h
    attach_function :TF_NewTensor, %i[int pointer int pointer size_t pointer pointer], :pointer
    attach_function :TF_DeleteTensor, %i[pointer], :void
    attach_function :TF_TensorData, %i[pointer], :pointer

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/eager/c_api.h
    attach_function :TFE_NewContextOptions, %i[], :pointer
    attach_function :TFE_ContextOptionsSetAsync, %i[pointer char], :void
    attach_function :TFE_DeleteContextOptions, %i[pointer], :void
    attach_function :TFE_NewContext, %i[pointer pointer], :pointer
    attach_function :TFE_DeleteContext, %i[pointer], :void
    attach_function :TFE_ContextListDevices, %i[pointer pointer], :pointer
    attach_function :TFE_ContextGetDevicePlacementPolicy, %i[pointer], :int
    attach_function :TFE_NewTensorHandle, %i[pointer pointer], :pointer
    attach_function :TFE_DeleteTensorHandle, %i[pointer], :void
    attach_function :TFE_TensorHandleDataType, %i[pointer], :int
    attach_function :TFE_TensorHandleNumDims, %i[pointer pointer], :int
    attach_function :TFE_TensorHandleNumElements, %i[pointer pointer], :int64
    attach_function :TFE_TensorHandleDim, %i[pointer int pointer], :int64
    attach_function :TFE_TensorHandleDeviceName, %i[pointer pointer], :string
    attach_function :TFE_TensorHandleBackingDeviceName, %i[pointer pointer], :string
    attach_function :TFE_TensorHandleResolve, %i[pointer pointer], :pointer
    attach_function :TFE_NewOp, %i[pointer string pointer], :pointer
    attach_function :TFE_DeleteOp, %i[pointer], :void
    # declared void in the C header; failures are reported through the status argument
    attach_function :TFE_OpSetDevice, %i[pointer string pointer], :void
    attach_function :TFE_OpGetDevice, %i[pointer pointer], :string
    attach_function :TFE_OpAddInput, %i[pointer pointer pointer], :void
    attach_function :TFE_OpGetAttrType, %i[pointer string pointer pointer], :int
    attach_function :TFE_OpSetAttrString, %i[pointer string pointer size_t], :void
    attach_function :TFE_OpSetAttrType, %i[pointer string int], :void
    attach_function :TFE_OpSetAttrShape, %i[pointer string pointer int pointer], :void
    # declared void in the C header; outputs are returned through the retvals argument
    attach_function :TFE_Execute, %i[pointer pointer pointer pointer], :void
  end
end
@@ -0,0 +1,193 @@
module TensorFlow
  # Ruby wrapper around an eager tensor handle from the TensorFlow C API.
  #
  # A tensor is either built from a Ruby value (a scalar or a nested array)
  # or wraps an existing handle passed as +pointer+.
  class Tensor
    # value   - scalar or (nested) array; ignored when +pointer+ is given
    # pointer - existing tensor handle to wrap
    # dtype   - element type (:string, :float, :int32 or :bool); inferred
    #           from the elements when nil
    # shape   - array of dimension sizes; derived from +value+ when nil
    #
    # Raises RuntimeError for an element type that cannot be written, and
    # whatever Utils.check_status raises when the handle cannot be created.
    def initialize(value = nil, pointer: nil, dtype: nil, shape: nil)
      @status = FFI.TF_NewStatus

      if pointer
        @pointer = pointer
      else
        shape ||= calculate_shape(value)
        data = Array(value).flatten

        dims_ptr = nil
        unless shape.empty?
          dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
          dims_ptr.write_array_of_int64(shape)
        end

        dtype ||= infer_dtype(data)
        type = FFI::DataType[dtype]
        data_ptr = build_data_ptr(dtype, data)

        # intentionally empty: the buffer belongs to the Ruby MemoryPointer
        deallocator = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |_data, _len, _arg|
          # FFI handles deallocation
        end

        # Keep the buffers and the callback reachable for as long as this
        # object lives, so they cannot be collected while the native tensor
        # may still point at them.
        # NOTE(review): assumes TF_NewTensor references the buffer rather than copying it - confirm
        @dims_ptr = dims_ptr
        @data_ptr = data_ptr
        @deallocator = deallocator

        tensor = FFI.TF_NewTensor(type, dims_ptr, shape.size, data_ptr, data_ptr.size, deallocator, nil)
        @pointer = FFI.TFE_NewTensorHandle(tensor, @status)
        check_status @status
      end

      # TODO fix segfault
      # ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # Element-wise arithmetic, delegated to the module-level operations.
    def +(other)
      TensorFlow.add(self, other)
    end

    def -(other)
      TensorFlow.subtract(self, other)
    end

    def *(other)
      TensorFlow.multiply(self, other)
    end

    def /(other)
      TensorFlow.divide(self, other)
    end

    def %(other)
      TensorFlow.floormod(self, other)
    end

    # Returns the number of dimensions reported for the handle.
    def num_dims
      ret = FFI.TFE_TensorHandleNumDims(@pointer, @status)
      check_status @status
      ret
    end

    # Returns the element type as a symbol from FFI::DataType (memoized).
    def dtype
      @dtype ||= FFI::DataType[FFI.TFE_TensorHandleDataType(@pointer)]
    end

    # Returns the total number of elements reported for the handle.
    def element_count
      ret = FFI.TFE_TensorHandleNumElements(@pointer, @status)
      check_status @status
      ret
    end

    # Returns the size of every dimension as an array (memoized).
    def shape
      @shape ||= begin
        dims = []
        num_dims.times do |i|
          dims << FFI.TFE_TensorHandleDim(@pointer, i, @status)
          check_status @status
        end
        dims
      end
    end

    # Resolves the handle to a tensor and returns a pointer to its data.
    def data_pointer
      tensor = FFI.TFE_TensorHandleResolve(@pointer, @status)
      check_status @status
      FFI.TF_TensorData(tensor)
    end

    # Returns the native tensor handle.
    def to_ptr
      @pointer
    end

    # Reads the tensor back into Ruby: a scalar for a rank-0 tensor, nested
    # arrays otherwise. Resource tensors return the raw data pointer.
    #
    # Raises RuntimeError for element types that cannot be read.
    def value
      value =
        case dtype
        when :float
          data_pointer.read_array_of_float(element_count)
        when :int32
          data_pointer.read_array_of_int32(element_count)
        when :string
          # string tensor format
          # https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
          count = element_count
          # resolve once instead of once per element
          base = data_pointer
          data_start = count * 8
          offsets = base.read_array_of_uint64(count)
          offsets.map { |off| (base + data_start + off).read_string }
        when :bool
          data_pointer.read_array_of_int8(element_count).map { |v| v == 1 }
        when :resource
          return data_pointer
        else
          raise "Unknown type: #{dtype}"
        end

      reshape(value, shape)
    end

    def to_s
      inspect
    end

    def to_i
      value.to_i
    end

    def to_a
      value
    end

    def inspect
      inspection = %w(value shape dtype).map { |v| "#{v}: #{send(v).inspect}" }
      "#<#{self.class} #{inspection.join(", ")}>"
    end

    # Builds a finalizer that deletes the given tensor handle.
    def self.finalize(pointer)
      # must use proc instead of stabby lambda
      proc { FFI.TFE_DeleteTensorHandle(pointer) }
    end

    private

    # Rebuilds nested arrays of shape +dims+ from a flat list of elements.
    # Returns the single element when +dims+ is empty.
    def reshape(arr, dims)
      return arr.first if dims.empty?
      arr = arr.flatten
      # slice from the innermost dimension outwards
      dims.drop(1).reverse_each do |dim|
        arr = arr.each_slice(dim)
      end
      arr.to_a
    end

    # Derives the shape of a nested array by following the first element at
    # each level. A non-array value yields an empty shape.
    def calculate_shape(value)
      shape = []
      d = value
      while d.is_a?(Array)
        shape << d.size
        d = d.first
      end
      shape
    end

    # Picks the element type from the flattened elements rather than from
    # the array container. A Float element wins, so mixed Integer/Float
    # data is not truncated to int32.
    def infer_dtype(data)
      sample = data.find { |v| v.is_a?(Float) } || data.first
      Utils.infer_type(sample)
    end

    # Allocates a native buffer for +data+ and fills it according to +dtype+.
    def build_data_ptr(dtype, data)
      case dtype
      when :string
        string_ptr(data)
      when :float
        ::FFI::MemoryPointer.new(:float, data.size).tap { |ptr| ptr.write_array_of_float(data) }
      when :int32
        ::FFI::MemoryPointer.new(:int32, data.size).tap { |ptr| ptr.write_array_of_int32(data) }
      when :bool
        # one byte per element, mirroring how #value reads booleans
        ::FFI::MemoryPointer.new(:int8, data.size).tap do |ptr|
          ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 })
        end
      else
        raise "Unknown type: #{dtype}"
      end
    end

    # Returns [offsets, total_bytes] for +data+: the start offset of every
    # string within the data section and the size of that section. Each
    # string takes its byte size plus one terminating byte.
    def string_offsets(data)
      offsets = []
      total = 0
      data.each do |str|
        offsets << total
        total += str.bytesize + 1
      end
      [offsets, total]
    end

    # string tensor format
    # https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
    #
    # Layout written here: a table of uint64 start offsets (one per string)
    # followed by the string bytes.
    # NOTE(review): the linked header appears to describe length-prefixed
    # elements; this writer stores raw bytes - confirm against the
    # TensorFlow version in use.
    def string_ptr(data)
      offsets, data_size = string_offsets(data)
      table_size = data.size * 8
      data_ptr = ::FFI::MemoryPointer.new(:char, table_size + data_size)
      data_ptr.write_array_of_uint64(offsets)
      data.zip(offsets) do |str, off|
        (data_ptr + table_size + off).write_string(str)
      end
      data_ptr
    end

    def check_status(status)
      Utils.check_status(status)
    end
  end
end
@@ -0,0 +1,22 @@
module TensorFlow
  # Helpers shared by the wrappers around the C API.
  module Utils
    # Raises Error carrying the status message when +status+ holds a
    # non-zero code; returns nil otherwise.
    def self.check_status(status)
      return if FFI.TF_GetCode(status).zero?

      raise Error, FFI.TF_Message(status)
    end

    # Infers the element type for a Ruby value.
    #
    # value - a scalar or a (nested) array of scalars
    #
    # Arrays are classified by their elements rather than by the container:
    # all strings give :string, all booleans give :bool, any Float gives
    # :float. Everything else, including an empty array, gives :int32.
    def self.infer_type(value)
      leaves = value.is_a?(Array) ? value.flatten : [value]
      # guard first: all? is vacuously true on an empty list
      return :int32 if leaves.empty?

      if leaves.all? { |v| v.is_a?(String) }
        :string
      elsif leaves.all? { |v| v == true || v == false }
        :bool
      elsif leaves.any? { |v| v.is_a?(Float) }
        :float
      else
        :int32
      end
    end
  end
end
@@ -0,0 +1,57 @@
module TensorFlow
  # A mutable value backed by a variable resource handle.
  class Variable
    # initial_value - value assigned right after the handle is created
    # dtype         - element type; inferred from +initial_value+ when nil
    def initialize(initial_value, dtype: nil)
      @dtype = dtype || Utils.infer_type(initial_value)
      @pointer = TensorFlow.var_handle_op(type_enum, nil, shared_name: TensorFlow.send(:default_context).shared_name)
      assign(initial_value)
    end

    # Replaces the stored value. Returns self so calls can be chained.
    def assign(value)
      TensorFlow.assign_variable_op(@pointer, coerce(value))
      self
    end

    # Adds +value+ to the stored value in place. Returns self.
    def assign_add(value)
      TensorFlow.assign_add_variable_op(@pointer, coerce(value))
      self
    end

    # Subtracts +value+ from the stored value in place. Returns self.
    def assign_sub(value)
      TensorFlow.assign_sub_variable_op(@pointer, coerce(value))
      self
    end

    # Reads the current value.
    def read_value
      TensorFlow.read_variable_op(@pointer, type_enum)
    end

    # Returns the current value plus +other+ without modifying this variable.
    #
    # Computed directly on the value read, so no temporary variable
    # resource is created per call.
    def +(other)
      TensorFlow.add(read_value, coerce(other))
    end

    # Returns the current value minus +other+ without modifying this variable.
    def -(other)
      TensorFlow.subtract(read_value, coerce(other))
    end

    def to_s
      inspect
    end

    def inspect
      current = read_value
      inspection = %w(value shape dtype).map { |v| "#{v}: #{current.public_send(v).inspect}" }
      "#<#{self.class} #{inspection.join(", ")}>"
    end

    private

    # Converts +value+ to a tensor, passing this variable's dtype along.
    def coerce(value)
      TensorFlow.convert_to_tensor(value, dtype: @dtype)
    end

    def type_enum
      FFI::DataType[@dtype.to_sym] if @dtype
    end
  end
end
@@ -0,0 +1,3 @@
module TensorFlow
  # Version of this gem; the linked library's version is reported
  # separately by TensorFlow.library_version.
  VERSION = "0.1.0".freeze
end
data/lib/tensorflow.rb ADDED
@@ -0,0 +1,188 @@
1
+ # dependencies
2
+ require "ffi"
3
+
4
+ # modules
5
+ require "tensorflow/utils"
6
+ require "tensorflow/context"
7
+ require "tensorflow/tensor"
8
+ require "tensorflow/variable"
9
+ require "tensorflow/version"
10
+
module TensorFlow
  # Raised by Utils.check_status when the C API reports a non-zero status code.
  class Error < StandardError; end

  class << self
    # Library name(s) handed to ffi_lib when the bindings are loaded.
    attr_accessor :ffi_lib
  end
  self.ffi_lib = ["tensorflow", "libtensorflow.so"]

  # friendlier error message
  autoload :FFI, "tensorflow/ffi"

  class << self
    # Returns the version string reported by the linked TensorFlow library.
    def library_version
      FFI.TF_Version
    end

    # Builds a Tensor from a Ruby value.
    def constant(value, dtype: nil, shape: nil)
      Tensor.new(value, dtype: dtype, shape: shape)
    end

    # Returns +value+ unchanged when it already is a Tensor, otherwise
    # wraps it in one.
    def convert_to_tensor(value, dtype: nil)
      value = Tensor.new(value, dtype: dtype) unless value.is_a?(Tensor)
      value
    end

    # Each operation below runs the named TensorFlow op eagerly.

    def add(x, y)
      execute("Add", [x, y])
    end

    def subtract(x, y)
      execute("Sub", [x, y])
    end

    def multiply(x, y)
      execute("Mul", [x, y])
    end

    def divide(x, y)
      execute("Div", [x, y])
    end

    def abs(x)
      execute("Abs", [x])
    end

    def sqrt(x)
      execute("Sqrt", [x])
    end

    def matmul(x, y)
      execute("MatMul", [x, y])
    end

    def floormod(x, y)
      execute("Mod", [x, y])
    end

    def range(start, limit, delta)
      execute("Range", [start, limit, delta])
    end

    def transpose(x, perm: [1, 0])
      execute("Transpose", [x, perm])
    end

    def equal(x, y)
      execute("Equal", [x, y])
    end

    def zeros_like(x)
      execute("ZerosLike", [x])
    end

    def fill(dims, value)
      execute("Fill", [dims, value])
    end

    def zeros(dims)
      fill(dims, 0)
    end

    def ones(dims)
      fill(dims, 1)
    end

    def assign_add_variable_op(resource, value)
      execute("AssignAddVariableOp", [resource, value])
    end

    def assign_sub_variable_op(resource, value)
      execute("AssignSubVariableOp", [resource, value])
    end

    def assign_variable_op(resource, value)
      execute("AssignVariableOp", [resource, value])
    end

    def read_variable_op(resource, dtype)
      execute("ReadVariableOp", [resource], dtype: dtype)
    end

    def var_handle_op(dtype, shape, container: "", shared_name: "")
      execute("VarHandleOp", [], container: container, shared_name: shared_name, dtype: dtype, shape: shape)
    end

    def var_is_initialized_op(resource)
      execute("VarIsInitializedOp", [resource])
    end

    private

    # Lazily created context shared by every operation.
    def default_context
      @default_context ||= Context.new
    end

    # Runs +op_name+ eagerly in the default context.
    #
    # inputs - values fed to the op; anything that does not respond to
    #          to_ptr is converted to a Tensor first
    # attrs  - op attributes, set according to the type the op declares
    #
    # Returns the first output: the raw handle for a resource output, a
    # Tensor for any other type, or nil when the op produced no output.
    # Raises whatever check_status raises when the C API reports a failure.
    def execute(op_name, inputs = [], **attrs)
      context = default_context
      status = FFI.TF_NewStatus

      begin
        op = FFI.TFE_NewOp(context, op_name, status)
        check_status status

        begin
          set_attributes(op, attrs, status)

          inputs.each do |input|
            input = convert_to_tensor(input) unless input.respond_to?(:to_ptr)
            FFI.TFE_OpAddInput(op, input, status)
            check_status status
          end

          # the count given to TFE_Execute must match the number of pointer
          # slots actually allocated for the outputs
          max_retvals = 8
          retvals = ::FFI::MemoryPointer.new(:pointer, max_retvals)
          num_retvals = ::FFI::MemoryPointer.new(:int)
          num_retvals.write_int(max_retvals)
          FFI.TFE_Execute(op, retvals, num_retvals, status)
          check_status status

          if num_retvals.read_int > 0
            handle = retvals.read_pointer
            type = FFI.TFE_TensorHandleDataType(handle)

            case FFI::DataType[type]
            when :resource
              handle
            else
              Tensor.new(pointer: handle)
            end
          end
        ensure
          FFI.TFE_DeleteOp(op)
        end
      ensure
        # check_status has already copied the message into the exception
        FFI.TF_DeleteStatus(status)
      end
    end

    # Sets every attribute in +attrs+ on +op+, using the attribute type the
    # op reports for that name. Raises RuntimeError for unsupported types.
    def set_attributes(op, attrs, status)
      attrs.each do |attr_name, attr_value|
        attr_name = attr_name.to_s

        is_list = ::FFI::MemoryPointer.new(:int)
        type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status)
        check_status status

        case FFI::AttrType[type]
        when :type
          FFI.TFE_OpSetAttrType(op, attr_name, attr_value)
        when :shape
          # TODO set value properly
          FFI.TFE_OpSetAttrShape(op, attr_name, attr_value, 0, status)
          check_status status
        when :string
          FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize)
        else
          raise "Unknown type: #{FFI::AttrType[type]}"
        end
      end
    end

    def check_status(status)
      Utils.check_status(status)
    end
  end
end

# shortcut
Tf = TensorFlow
metadata ADDED
@@ -0,0 +1,108 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: tensorflow
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Andrew Kane
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2019-09-18 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: ffi
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: '0'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: '0'
27
+ - !ruby/object:Gem::Dependency
28
+ name: bundler
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - ">="
32
+ - !ruby/object:Gem::Version
33
+ version: '0'
34
+ type: :development
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - ">="
39
+ - !ruby/object:Gem::Version
40
+ version: '0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rake
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: '0'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
55
+ - !ruby/object:Gem::Dependency
56
+ name: minitest
57
+ requirement: !ruby/object:Gem::Requirement
58
+ requirements:
59
+ - - ">="
60
+ - !ruby/object:Gem::Version
61
+ version: '5'
62
+ type: :development
63
+ prerelease: false
64
+ version_requirements: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - ">="
67
+ - !ruby/object:Gem::Version
68
+ version: '5'
69
+ description:
70
+ email: andrew@chartkick.com
71
+ executables: []
72
+ extensions: []
73
+ extra_rdoc_files: []
74
+ files:
75
+ - CHANGELOG.md
76
+ - LICENSE.txt
77
+ - README.md
78
+ - lib/tensorflow.rb
79
+ - lib/tensorflow/context.rb
80
+ - lib/tensorflow/ffi.rb
81
+ - lib/tensorflow/tensor.rb
82
+ - lib/tensorflow/utils.rb
83
+ - lib/tensorflow/variable.rb
84
+ - lib/tensorflow/version.rb
85
+ homepage: https://github.com/ankane/tensorflow
86
+ licenses:
87
+ - MIT
88
+ metadata: {}
89
+ post_install_message:
90
+ rdoc_options: []
91
+ require_paths:
92
+ - lib
93
+ required_ruby_version: !ruby/object:Gem::Requirement
94
+ requirements:
95
+ - - ">="
96
+ - !ruby/object:Gem::Version
97
+ version: '2.4'
98
+ required_rubygems_version: !ruby/object:Gem::Requirement
99
+ requirements:
100
+ - - ">="
101
+ - !ruby/object:Gem::Version
102
+ version: '0'
103
+ requirements: []
104
+ rubygems_version: 3.0.3
105
+ signing_key:
106
+ specification_version: 4
107
+ summary: TensorFlow - the end-to-end machine learning platform - for Ruby
108
+ test_files: []