tensorflow 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +113 -0
- data/lib/tensorflow/context.rb +26 -0
- data/lib/tensorflow/ffi.rb +60 -0
- data/lib/tensorflow/tensor.rb +193 -0
- data/lib/tensorflow/utils.rb +22 -0
- data/lib/tensorflow/variable.rb +57 -0
- data/lib/tensorflow/version.rb +3 -0
- data/lib/tensorflow.rb +188 -0
- metadata +108 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 32ac7a47334631b2c5393721e5a7f848a26d9535b473557e146e8e6b4b6337be
|
4
|
+
data.tar.gz: 21ca600101471ee43edd9a404f3072d67cbbdfa44f412ba8ec76a2546973087d
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: 601c99ff6035138797f0e14a0f5f2c52ae68b4ee4c5223ef5676e45aaf2f61a7994f0a2ce2ad353b03fa503d7ad8fd6e9c49ba7567fbdcee2156bffe6d36a4c4
|
7
|
+
data.tar.gz: 825d6b40f58f3cbe4d1c6f5b9cff5c7a739b730a89e025bf305bcdf3789269ce34f851709a0294d2c830ca1621a08ff3a67b352422bafd92eb8dec30f65e42e3
|
data/CHANGELOG.md
ADDED
data/LICENSE.txt
ADDED
@@ -0,0 +1,22 @@
|
|
1
|
+
Copyright (c) 2019 Andrew Kane
|
2
|
+
|
3
|
+
MIT License
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining
|
6
|
+
a copy of this software and associated documentation files (the
|
7
|
+
"Software"), to deal in the Software without restriction, including
|
8
|
+
without limitation the rights to use, copy, modify, merge, publish,
|
9
|
+
distribute, sublicense, and/or sell copies of the Software, and to
|
10
|
+
permit persons to whom the Software is furnished to do so, subject to
|
11
|
+
the following conditions:
|
12
|
+
|
13
|
+
The above copyright notice and this permission notice shall be
|
14
|
+
included in all copies or substantial portions of the Software.
|
15
|
+
|
16
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
17
|
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
18
|
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
19
|
+
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
20
|
+
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
21
|
+
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
22
|
+
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
data/README.md
ADDED
@@ -0,0 +1,113 @@
|
|
1
|
+
# TensorFlow
|
2
|
+
|
3
|
+
[TensorFlow](https://github.com/tensorflow/tensorflow) - the end-to-end machine learning platform - for Ruby
|
4
|
+
|
5
|
+
:fire: Uses the C API for blazing performance
|
6
|
+
|
7
|
+
## Installation
|
8
|
+
|
9
|
+
[Install TensorFlow](#tensorflow-installation). For Homebrew, use:
|
10
|
+
|
11
|
+
```sh
|
12
|
+
brew install tensorflow
|
13
|
+
```
|
14
|
+
|
15
|
+
Add this line to your application’s Gemfile:
|
16
|
+
|
17
|
+
```ruby
|
18
|
+
gem 'tensorflow'
|
19
|
+
```
|
20
|
+
|
21
|
+
## Getting Started
|
22
|
+
|
23
|
+
This library follows the TensorFlow 2.0 [Python API](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf). Many methods and options are missing at the moment. PRs welcome!
|
24
|
+
|
25
|
+
## Constants
|
26
|
+
|
27
|
+
```ruby
|
28
|
+
a = Tf.constant(2)
|
29
|
+
b = Tf.constant(3)
|
30
|
+
a + b
|
31
|
+
```
|
32
|
+
|
33
|
+
## Variables
|
34
|
+
|
35
|
+
```ruby
|
36
|
+
v = Tf::Variable.new(0.0)
|
37
|
+
w = v + 1
|
38
|
+
```
|
39
|
+
|
40
|
+
## FizzBuzz
|
41
|
+
|
42
|
+
```ruby
|
43
|
+
def fizzbuzz(max_num)
|
44
|
+
counter = Tf.constant(0)
|
45
|
+
max_num.times do |i|
|
46
|
+
num = Tf.constant(i + 1)
|
47
|
+
if (num % 3).to_i == 0 && (num % 5).to_i == 0
|
48
|
+
puts "FizzBuzz"
|
49
|
+
elsif (num % 3).to_i == 0
|
50
|
+
puts "Fizz"
|
51
|
+
elsif (num % 5).to_i == 0
|
52
|
+
puts "Buzz"
|
53
|
+
else
|
54
|
+
puts num.to_i
|
55
|
+
end
|
56
|
+
end
|
57
|
+
end
|
58
|
+
|
59
|
+
fizzbuzz(15)
|
60
|
+
```
|
61
|
+
|
62
|
+
## Keras
|
63
|
+
|
64
|
+
Coming soon — the Keras API shown below is planned but not yet available in this version:
|
65
|
+
|
66
|
+
```ruby
|
67
|
+
mnist = Tf::Keras::Datasets::MNIST
|
68
|
+
(x_train, y_train), (x_test, y_test) = mnist.load_data
|
69
|
+
x_train = x_train / 255.0
|
70
|
+
x_test = x_test / 255.0
|
71
|
+
|
72
|
+
model = Tf::Keras::Models::Sequential.new([
|
73
|
+
Tf::Keras::Layers::Flatten.new(input_shape: [28, 28]),
|
74
|
+
Tf::Keras::Layers::Dense.new(128, activation: "relu"),
|
75
|
+
Tf::Keras::Layers::Dropout.new(0.2),
|
76
|
+
Tf::Keras::Layers::Dense.new(10, activation: "softmax")
|
77
|
+
])
|
78
|
+
|
79
|
+
model.compile(optimizer: "adam", loss: "sparse_categorical_crossentropy", metrics: ["accuracy"])
|
80
|
+
model.fit(x_train, y_train, epochs: 5)
|
81
|
+
model.evaluate(x_test, y_test)
|
82
|
+
```
|
83
|
+
|
84
|
+
## TensorFlow Installation
|
85
|
+
|
86
|
+
### Mac
|
87
|
+
|
88
|
+
Run:
|
89
|
+
|
90
|
+
```sh
|
91
|
+
brew install tensorflow
|
92
|
+
```
|
93
|
+
|
94
|
+
### Linux
|
95
|
+
|
96
|
+
Download the [shared library](https://www.tensorflow.org/install/lang_c#download) and move `libtensorflow.so`, along with the `libtensorflow_framework.so` files it depends on, to `/usr/local/lib`. Then run `sudo ldconfig`.
|
97
|
+
|
98
|
+
### Windows
|
99
|
+
|
100
|
+
Download the [shared library](https://www.tensorflow.org/install/lang_c#download) and move `tensorflow.dll` to `C:\Windows\System32`.
|
101
|
+
|
102
|
+
## History
|
103
|
+
|
104
|
+
View the [changelog](https://github.com/ankane/tensorflow/blob/master/CHANGELOG.md)
|
105
|
+
|
106
|
+
## Contributing
|
107
|
+
|
108
|
+
Everyone is encouraged to help improve this project. Here are a few ways you can help:
|
109
|
+
|
110
|
+
- [Report bugs](https://github.com/ankane/tensorflow/issues)
|
111
|
+
- Fix bugs and [submit pull requests](https://github.com/ankane/tensorflow/pulls)
|
112
|
+
- Write, clarify, or fix documentation
|
113
|
+
- Suggest or add new features
|
@@ -0,0 +1,26 @@
|
|
1
|
+
module TensorFlow
  # Wraps a TFE_Context, the eager-execution context that ops are created in.
  class Context
    # Creates an eager context with default options.
    #
    # @raise [TensorFlow::Error] if TensorFlow reports an error creating the context
    def initialize
      options = FFI.TFE_NewContextOptions
      @status = TensorFlow::FFI.TF_NewStatus
      begin
        @pointer = FFI.TFE_NewContext(options, @status)
        Utils.check_status @status
        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      ensure
        # The options are only needed to build the context. Release them on the
        # error path too: previously check_status raised before the delete was
        # reached, leaking the options whenever context creation failed.
        FFI.TFE_DeleteContextOptions(options)
      end
    end

    # Builds the finalizer that deletes the native context.
    # Defined on the class so the proc does not capture the Context instance
    # (which would keep it from ever being collected).
    #
    # @param pointer [FFI::Pointer] the TFE_Context to delete
    # @return [Proc]
    def self.finalize(pointer)
      # must use proc instead of stabby lambda: finalizers are called with the
      # object id, and a zero-arity lambda would raise ArgumentError
      proc { FFI.TFE_DeleteContext(pointer) }
    end

    # @return [FFI::Pointer] the TFE_Context; lets this object be passed
    #   directly wherever an FFI :pointer argument is expected
    def to_ptr
      @pointer
    end

    # Shared name used for VarHandleOp by Variable.
    # NOTE(review): this GUID is the reserved name TensorFlow's Python eager
    # code passes so the runtime generates a unique resource per variable — confirm.
    #
    # @return [String]
    def shared_name
      # hard-coded in Python library
      "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
    end
  end
end
|
@@ -0,0 +1,60 @@
|
|
1
|
+
module TensorFlow
  # Raw bindings to the TensorFlow C API.
  #
  # Signatures mirror the C headers linked below. Functions declared `void` in C
  # are attached with a :void return type; declaring them as :pointer (as before
  # for TF_DeleteStatus, TFE_OpSetDevice and TFE_Execute) reads an undefined
  # return value and wraps it as a bogus pointer.
  module FFI
    extend ::FFI::Library

    begin
      ffi_lib TensorFlow.ffi_lib
    rescue LoadError => e
      # set TENSORFLOW_DEBUG to see the underlying loader error
      raise e if ENV["TENSORFLOW_DEBUG"]
      raise LoadError, "Could not find TensorFlow"
    end

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h
    attach_function :TF_Version, %i[], :string

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_attrtype.h
    AttrType = enum(:string, :int, :float, :bool, :type, :shape, :tensor, :placeholder, :func)

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_datatype.h
    # (numbering starts at 1: TF_FLOAT = 1)
    DataType = enum(:float, 1, :double, :int32, :uint8, :int16, :int8, :string, :complex64, :int64, :bool, :qint8, :quint8, :qint32, :bfloat16, :qint16, :quint16, :uint16, :complex128, :half, :resource, :variant, :uint32, :uint64)

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_status.h
    attach_function :TF_NewStatus, %i[], :pointer
    attach_function :TF_DeleteStatus, %i[pointer], :void
    attach_function :TF_GetCode, %i[pointer], :int
    attach_function :TF_Message, %i[pointer], :string

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_tensor.h
    attach_function :TF_NewTensor, %i[int pointer int pointer size_t pointer pointer], :pointer
    attach_function :TF_DeleteTensor, %i[pointer], :void
    attach_function :TF_TensorData, %i[pointer], :pointer

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/eager/c_api.h
    attach_function :TFE_NewContextOptions, %i[], :pointer
    # the C parameter is `unsigned char enable`
    attach_function :TFE_ContextOptionsSetAsync, %i[pointer uchar], :void
    attach_function :TFE_DeleteContextOptions, %i[pointer], :void
    attach_function :TFE_NewContext, %i[pointer pointer], :pointer
    attach_function :TFE_DeleteContext, %i[pointer], :void
    attach_function :TFE_ContextListDevices, %i[pointer pointer], :pointer
    attach_function :TFE_ContextGetDevicePlacementPolicy, %i[pointer], :int
    attach_function :TFE_NewTensorHandle, %i[pointer pointer], :pointer
    attach_function :TFE_DeleteTensorHandle, %i[pointer], :void
    attach_function :TFE_TensorHandleDataType, %i[pointer], :int
    attach_function :TFE_TensorHandleNumDims, %i[pointer pointer], :int
    attach_function :TFE_TensorHandleNumElements, %i[pointer pointer], :int64
    attach_function :TFE_TensorHandleDim, %i[pointer int pointer], :int64
    attach_function :TFE_TensorHandleDeviceName, %i[pointer pointer], :string
    attach_function :TFE_TensorHandleBackingDeviceName, %i[pointer pointer], :string
    attach_function :TFE_TensorHandleResolve, %i[pointer pointer], :pointer
    attach_function :TFE_NewOp, %i[pointer string pointer], :pointer
    attach_function :TFE_DeleteOp, %i[pointer], :void
    attach_function :TFE_OpSetDevice, %i[pointer string pointer], :void
    attach_function :TFE_OpGetDevice, %i[pointer pointer], :string
    attach_function :TFE_OpAddInput, %i[pointer pointer pointer], :void
    attach_function :TFE_OpGetAttrType, %i[pointer string pointer pointer], :int
    attach_function :TFE_OpSetAttrString, %i[pointer string pointer size_t], :void
    attach_function :TFE_OpSetAttrType, %i[pointer string int], :void
    attach_function :TFE_OpSetAttrShape, %i[pointer string pointer int pointer], :void
    attach_function :TFE_Execute, %i[pointer pointer pointer pointer], :void
  end
end
|
@@ -0,0 +1,193 @@
|
|
1
|
+
module TensorFlow
  # An eager tensor backed by a TFE_TensorHandle.
  class Tensor
    # Ruby-allocated data buffers passed to TF_NewTensor, keyed by address.
    # TF_NewTensor takes a caller-owned buffer plus a deallocator, so TensorFlow
    # may keep using the buffer after #initialize returns. Previously the buffer
    # was only a local variable and could be garbage collected while TensorFlow
    # still referenced it. Entries are removed when TensorFlow calls the
    # deallocator. NOTE(review): handles are never deleted yet (see the
    # finalizer TODO), so buffers they reference presumably live until exit.
    BUFFERS = {}

    # dtypes #value knows how to read back
    READABLE_TYPES = %i[float double int32 int64 bool string resource].freeze

    # Deallocator handed to TF_NewTensor: drops our reference to the buffer so
    # Ruby can free it. Memoized on the class so the FFI::Function itself stays
    # reachable for as long as TensorFlow might call it.
    #
    # @return [FFI::Function]
    def self.deallocator
      @deallocator ||= ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |data, _len, _arg|
        BUFFERS.delete(data.address)
      end
    end

    # @param value [Object, Array, nil] scalar or (nested) array of values;
    #   ignored when `pointer` is given
    # @param pointer [FFI::Pointer, nil] existing TFE_TensorHandle to wrap
    # @param dtype [Symbol, nil] :float, :double, :int32, :int64, :bool or
    #   :string; inferred from `value` when nil
    # @param shape [Array<Integer>, nil] dimensions; derived from the nesting of
    #   `value` when nil
    # @raise [TensorFlow::Error] if TensorFlow rejects the data
    # @raise [RuntimeError] for an unsupported dtype
    def initialize(value = nil, pointer: nil, dtype: nil, shape: nil)
      @status = FFI.TF_NewStatus

      if pointer
        @pointer = pointer
      else
        data = Array(value)
        shape ||= calculate_shape(value)

        if shape.size > 0
          dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
          dims_ptr.write_array_of_int64(shape)
        else
          dims_ptr = nil
        end

        data = data.flatten

        dtype ||= Utils.infer_type(value)
        type = FFI::DataType[dtype]
        data_ptr = data_ptr_for(dtype, data)

        # Register before TF_NewTensor: per the C API docs it may invoke the
        # deallocator before returning (it does so when it rejects the data).
        BUFFERS[data_ptr.address] = data_ptr
        tensor = FFI.TF_NewTensor(type, dims_ptr, shape.size, data_ptr, data_ptr.size, Tensor.deallocator, nil)
        # NULL means the data did not match dtype/shape; passing it on would crash
        raise Error, "Could not create tensor" if tensor.null?

        begin
          @pointer = FFI.TFE_NewTensorHandle(tensor, @status)
          check_status @status
        ensure
          # Previously this TF_Tensor was leaked. NOTE(review): relies on the
          # handle holding its own reference to the tensor's data — confirm
          # against the linked TensorFlow version.
          FFI.TF_DeleteTensor(tensor)
        end
      end

      # TODO fix segfault
      # ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    def +(other)
      TensorFlow.add(self, other)
    end

    def -(other)
      TensorFlow.subtract(self, other)
    end

    def *(other)
      TensorFlow.multiply(self, other)
    end

    def /(other)
      TensorFlow.divide(self, other)
    end

    def %(other)
      TensorFlow.floormod(self, other)
    end

    # @return [Integer] rank of the tensor
    def num_dims
      ret = FFI.TFE_TensorHandleNumDims(@pointer, @status)
      check_status @status
      ret
    end

    # @return [Symbol] element type, e.g. :float
    def dtype
      @dtype ||= FFI::DataType[FFI.TFE_TensorHandleDataType(@pointer)]
    end

    # @return [Integer] total number of elements
    def element_count
      ret = FFI.TFE_TensorHandleNumElements(@pointer, @status)
      check_status @status
      ret
    end

    # @return [Array<Integer>] dimensions (empty for a scalar)
    def shape
      @shape ||= begin
        shape = []
        num_dims.times do |i|
          shape << FFI.TFE_TensorHandleDim(@pointer, i, @status)
          check_status @status
        end
        shape
      end
    end

    # Resolves the handle into a TF_Tensor and returns its raw data pointer.
    # Each call creates a new TF_Tensor that is not freed.
    def data_pointer
      tensor = FFI.TFE_TensorHandleResolve(@pointer, @status)
      check_status @status
      FFI.TF_TensorData(tensor)
    end

    def to_ptr
      @pointer
    end

    # Reads the tensor's contents back into Ruby.
    #
    # @return [Object] a scalar for rank-0 tensors, otherwise nested Arrays
    #   matching #shape; for :resource tensors, the raw data pointer
    # @raise [RuntimeError] for an unsupported dtype
    def value
      type = dtype
      raise "Unknown type: #{type}" unless READABLE_TYPES.include?(type)

      # Resolve once. Every data_pointer call creates an unfreed TF_Tensor, and
      # the string branch previously called it once per element.
      ptr = data_pointer
      return ptr if type == :resource

      count = element_count
      values =
        case type
        when :float then ptr.read_array_of_float(count)
        when :double then ptr.read_array_of_double(count)
        when :int32 then ptr.read_array_of_int32(count)
        when :int64 then ptr.read_array_of_int64(count)
        when :bool then ptr.read_array_of_int8(count).map { |v| v != 0 }
        when :string then read_strings(ptr, count)
        end

      reshape(values, shape)
    end

    def to_s
      inspect
    end

    def to_i
      value.to_i
    end

    def to_a
      value
    end

    def inspect
      inspection = %w(value shape dtype).map { |v| "#{v}: #{send(v).inspect}"}
      "#<#{self.class} #{inspection.join(", ")}>"
    end

    def self.finalize(pointer)
      # must use proc instead of stabby lambda
      proc { FFI.TFE_DeleteTensorHandle(pointer) }
    end

    private

    # Allocates a Ruby-owned buffer holding `data` in the C layout for `dtype`.
    def data_ptr_for(dtype, data)
      case dtype
      when :string
        string_ptr(data)
      when :float
        ::FFI::MemoryPointer.new(:float, data.size).tap { |ptr| ptr.write_array_of_float(data) }
      when :double
        ::FFI::MemoryPointer.new(:double, data.size).tap { |ptr| ptr.write_array_of_double(data) }
      when :int32
        ::FFI::MemoryPointer.new(:int32, data.size).tap { |ptr| ptr.write_array_of_int32(data) }
      when :int64
        ::FFI::MemoryPointer.new(:int64, data.size).tap { |ptr| ptr.write_array_of_int64(data) }
      when :bool
        # bools are one byte each (matching how #value reads them back)
        ::FFI::MemoryPointer.new(:int8, data.size).tap { |ptr| ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 }) }
      else
        raise "Unknown type: #{dtype}"
      end
    end

    def reshape(arr, dims)
      return arr.first if dims.empty?
      arr = arr.flatten
      dims[1..-1].reverse.each do |dim|
        arr = arr.each_slice(dim)
      end
      arr.to_a
    end

    def calculate_shape(value)
      shape = []
      d = value
      while d.is_a?(Array)
        shape << d.size
        d = d.first
      end
      shape
    end

    # string tensor format
    # https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
    # A table of uint64 start offsets (one per element, relative to the end of
    # the table), then each string as its byte length (a varint) followed by
    # its bytes. Offsets must be cumulative: previously each offset was the
    # previous string's own length, so elements overlapped and the buffer was
    # too small. Strings were also written null-terminated, not length-prefixed.
    def string_ptr(data)
      start_offset_size = data.size * 8
      encoded = data.map { |str| encode_varint(str.bytesize) + str.b }

      offsets = []
      total = 0
      encoded.each do |bytes|
        offsets << total
        total += bytes.bytesize
      end

      data_ptr = ::FFI::MemoryPointer.new(:char, start_offset_size + total)
      data_ptr.write_array_of_uint64(offsets)
      encoded.zip(offsets) do |bytes, offset|
        data_ptr.put_bytes(start_offset_size + offset, bytes)
      end
      data_ptr
    end

    # Decodes `count` elements stored in the string tensor format above.
    # Returns binary (ASCII-8BIT) strings.
    def read_strings(ptr, count)
      start_offset_size = count * 8
      ptr.read_array_of_uint64(count).map do |offset|
        pos = start_offset_size + offset
        length = 0
        shift = 0
        loop do
          byte = ptr.get_uint8(pos)
          pos += 1
          length |= (byte & 0x7f) << shift
          break if byte < 0x80
          shift += 7
        end
        ptr.get_bytes(pos, length)
      end
    end

    # Little-endian base-128 varint (7 bits per byte, high bit = more follows).
    def encode_varint(num)
      bytes = []
      while num >= 0x80
        bytes << ((num & 0x7f) | 0x80)
        num >>= 7
      end
      bytes << num
      bytes.pack("C*")
    end

    def check_status(status)
      Utils.check_status(status)
    end
  end
end
|
@@ -0,0 +1,22 @@
|
|
1
|
+
module TensorFlow
  # Helpers shared by the tensor, variable and op-execution code.
  module Utils
    # Raises if a TF_Status reports an error.
    #
    # @param status [FFI::Pointer] a TF_Status
    # @raise [TensorFlow::Error] with TensorFlow's message when the code is non-zero
    def self.check_status(status)
      if FFI.TF_GetCode(status) != 0
        raise Error, FFI.TF_Message(status)
      end
    end

    # Picks a TensorFlow dtype for a Ruby value.
    #
    # Arrays (including nested arrays) are inferred from their elements: any
    # Float makes the whole array :float (so [1, 2.5] stays floating point);
    # otherwise the first element decides. Previously every Array fell through
    # to :int32, so float arrays were written as int32 and string arrays could
    # not be converted at all.
    #
    # @param value [Object]
    # @return [Symbol] :string, :float, :bool, or :int32 (the fallback, also
    #   used for nil and empty arrays)
    def self.infer_type(value)
      if value.is_a?(Array)
        elements = value.flatten
        return :float if elements.any? { |v| v.is_a?(Float) }
        return infer_type(elements.first)
      end

      case value
      when String
        :string
      when Float
        :float
      when true, false
        :bool
      else
        :int32
      end
    end
  end
end
|
@@ -0,0 +1,57 @@
|
|
1
|
+
module TensorFlow
  # A mutable resource variable. Reads and writes go through the resource
  # handle returned by VarHandleOp.
  class Variable
    # @param initial_value [Object] value assigned on creation (converted to dtype)
    # @param dtype [Symbol, nil] element type; inferred from initial_value when nil
    def initialize(initial_value, dtype: nil)
      @dtype = dtype || Utils.infer_type(initial_value)
      @pointer = TensorFlow.var_handle_op(type_enum, nil, shared_name: TensorFlow.send(:default_context).shared_name)
      assign(initial_value)
    end

    # Replaces the stored value.
    # @return [Variable] self
    def assign(value)
      value = TensorFlow.convert_to_tensor(value, dtype: @dtype)
      TensorFlow.assign_variable_op(@pointer, value)
      self
    end

    # Adds `value` to the stored value in place.
    # @return [Variable] self
    def assign_add(value)
      value = TensorFlow.convert_to_tensor(value, dtype: @dtype)
      TensorFlow.assign_add_variable_op(@pointer, value)
      self
    end

    # Subtracts `value` from the stored value in place.
    # @return [Variable] self
    def assign_sub(value)
      value = TensorFlow.convert_to_tensor(value, dtype: @dtype)
      TensorFlow.assign_sub_variable_op(@pointer, value)
      self
    end

    # @return [Tensor] the current value
    def read_value
      TensorFlow.read_variable_op(@pointer, type_enum)
    end

    # @return [Tensor] current value + other; this variable is not modified
    def +(other)
      # Compute with Add directly. Previously this copied the value into a
      # temporary Variable and used AssignAddVariableOp, and each temporary
      # created a new VarHandleOp resource that nothing ever destroyed.
      TensorFlow.add(read_value, TensorFlow.convert_to_tensor(other, dtype: @dtype))
    end

    # @return [Tensor] current value - other; this variable is not modified
    def -(other)
      # see #+ for why no temporary Variable is created
      TensorFlow.subtract(read_value, TensorFlow.convert_to_tensor(other, dtype: @dtype))
    end

    def to_s
      inspect
    end

    def inspect
      value = read_value
      inspection = %w(value shape dtype).map { |v| "#{v}: #{value.send(v).inspect}"}
      "#<#{self.class} #{inspection.join(", ")}>"
    end

    private

    # @return [Integer, nil] the TF_DataType enum value for @dtype
    def type_enum
      FFI::DataType[@dtype.to_sym] if @dtype
    end
  end
end
|
data/lib/tensorflow.rb
ADDED
@@ -0,0 +1,188 @@
|
|
1
|
+
# dependencies
|
2
|
+
require "ffi"
|
3
|
+
|
4
|
+
# modules
|
5
|
+
require "tensorflow/utils"
|
6
|
+
require "tensorflow/context"
|
7
|
+
require "tensorflow/tensor"
|
8
|
+
require "tensorflow/variable"
|
9
|
+
require "tensorflow/version"
|
10
|
+
|
11
|
+
module TensorFlow
  class Error < StandardError; end

  class << self
    # library name(s) handed to FFI's ffi_lib; set before TensorFlow::FFI loads
    attr_accessor :ffi_lib
  end
  self.ffi_lib = ["tensorflow", "libtensorflow.so"]

  # friendlier error message
  autoload :FFI, "tensorflow/ffi"

  class << self
    include Utils

    # @return [String] version string reported by the C library
    def library_version
      FFI.TF_Version
    end

    def constant(value, dtype: nil, shape: nil)
      Tensor.new(value, dtype: dtype, shape: shape)
    end

    # @return [Tensor] `value` itself when already a Tensor, otherwise a new Tensor
    def convert_to_tensor(value, dtype: nil)
      value = Tensor.new(value, dtype: dtype) unless value.is_a?(Tensor)
      value
    end

    def add(x, y)
      execute("Add", [x, y])
    end

    def subtract(x, y)
      execute("Sub", [x, y])
    end

    def multiply(x, y)
      execute("Mul", [x, y])
    end

    def divide(x, y)
      execute("Div", [x, y])
    end

    def abs(x)
      execute("Abs", [x])
    end

    def sqrt(x)
      execute("Sqrt", [x])
    end

    def matmul(x, y)
      execute("MatMul", [x, y])
    end

    # Element-wise remainder whose sign follows the divisor, matching Ruby's
    # Integer#% (which Tensor#% maps to). Uses TensorFlow's FloorMod op; the
    # previous "Mod" op truncates toward zero, so floormod(-1, 3) gave -1, not 2.
    def floormod(x, y)
      execute("FloorMod", [x, y])
    end

    def range(start, limit, delta)
      execute("Range", [start, limit, delta])
    end

    def transpose(x, perm: [1, 0])
      execute("Transpose", [x, perm])
    end

    def equal(x, y)
      execute("Equal", [x, y])
    end

    def zeros_like(x)
      execute("ZerosLike", [x])
    end

    def fill(dims, value)
      execute("Fill", [dims, value])
    end

    def zeros(dims)
      fill(dims, 0)
    end

    def ones(dims)
      fill(dims, 1)
    end

    def assign_add_variable_op(resource, value)
      execute("AssignAddVariableOp", [resource, value])
    end

    def assign_sub_variable_op(resource, value)
      execute("AssignSubVariableOp", [resource, value])
    end

    def assign_variable_op(resource, value)
      execute("AssignVariableOp", [resource, value])
    end

    def read_variable_op(resource, dtype)
      execute("ReadVariableOp", [resource], dtype: dtype)
    end

    # @param shape [Array<Integer>, nil] variable shape; nil keeps the original
    #   behavior of passing zero dimensions
    def var_handle_op(dtype, shape, container: "", shared_name: "")
      execute("VarHandleOp", [], container: container, shared_name: shared_name, dtype: dtype, shape: shape)
    end

    def var_is_initialized_op(resource)
      execute("VarIsInitializedOp", [resource])
    end

    private

    # Lazily creates the single eager context used by every op.
    def default_context
      @default_context ||= Context.new
    end

    # Runs one eager op and wraps its first output.
    #
    # @param op_name [String] TensorFlow op type, e.g. "Add"
    # @param inputs [Array] Tensors, raw handles, or Ruby values (converted with convert_to_tensor)
    # @param attrs [Hash] op attributes; type, shape and string attributes are supported
    # @return [Tensor, FFI::Pointer, nil] a Tensor, the raw handle for :resource
    #   outputs, or nil when the op produced no outputs
    # @raise [TensorFlow::Error] if TensorFlow reports an error
    def execute(op_name, inputs = [], **attrs)
      context = default_context
      status = FFI.TF_NewStatus
      op = FFI.TFE_NewOp(context, op_name, status)
      check_status status

      attrs.each do |attr_name, attr_value|
        set_attr(op, attr_name.to_s, attr_value, status)
      end

      inputs.each do |input|
        input = convert_to_tensor(input) unless input.respond_to?(:to_ptr)
        FFI.TFE_OpAddInput(op, input, status)
        check_status status
      end

      # The C API expects *num_retvals to be the number of slots in retvals.
      # It was previously set to retvals.size, which is the byte size (8), so an
      # op with several outputs could write past this one-pointer buffer.
      retvals = ::FFI::MemoryPointer.new(:pointer, 1)
      num_retvals = ::FFI::MemoryPointer.new(:int)
      num_retvals.write_int(1)
      FFI.TFE_Execute(op, retvals, num_retvals, status)
      check_status status

      wrap_retval(retvals.read_pointer) if num_retvals.read_int > 0
    ensure
      # previously neither the op nor the status was ever freed
      FFI.TFE_DeleteOp(op) if op && !op.null?
      FFI.TF_DeleteStatus(status) if status
    end

    # Sets one attribute on `op`, dispatching on the attribute's declared type.
    def set_attr(op, attr_name, attr_value, status)
      is_list = ::FFI::MemoryPointer.new(:int)
      type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status)
      check_status status

      case FFI::AttrType[type]
      when :type
        # accept dtype symbols (e.g. :float) as well as raw enum values
        attr_value = FFI::DataType[attr_value] if attr_value.is_a?(Symbol)
        FFI.TFE_OpSetAttrType(op, attr_name, attr_value)
      when :shape
        if attr_value
          dims = Array(attr_value)
          dims_ptr = ::FFI::MemoryPointer.new(:int64, dims.size)
          dims_ptr.write_array_of_int64(dims)
          FFI.TFE_OpSetAttrShape(op, attr_name, dims_ptr, dims.size, status)
        else
          # nil keeps the previous behavior: no dims and num_dims 0 (rank 0)
          FFI.TFE_OpSetAttrShape(op, attr_name, nil, 0, status)
        end
        check_status status
      when :string
        FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize)
      else
        raise "Unknown type: #{FFI::AttrType[type]}"
      end
    end

    # Resource handles are returned raw (Variable stores them as its pointer);
    # every other output is wrapped in a Tensor.
    def wrap_retval(handle)
      if FFI::DataType[FFI.TFE_TensorHandleDataType(handle)] == :resource
        handle
      else
        Tensor.new(pointer: handle)
      end
    end

    def check_status(status)
      Utils.check_status(status)
    end
  end
end
|
186
|
+
|
187
|
+
# shortcut
|
188
|
+
# Short top-level alias so callers can write Tf.constant(...), Tf::Variable, etc.
# (the form used throughout the README examples).
Tf = TensorFlow
|
metadata
ADDED
@@ -0,0 +1,108 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: tensorflow
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Andrew Kane
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2019-09-18 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: ffi
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: bundler
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rake
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
- !ruby/object:Gem::Dependency
|
56
|
+
name: minitest
|
57
|
+
requirement: !ruby/object:Gem::Requirement
|
58
|
+
requirements:
|
59
|
+
- - ">="
|
60
|
+
- !ruby/object:Gem::Version
|
61
|
+
version: '5'
|
62
|
+
type: :development
|
63
|
+
prerelease: false
|
64
|
+
version_requirements: !ruby/object:Gem::Requirement
|
65
|
+
requirements:
|
66
|
+
- - ">="
|
67
|
+
- !ruby/object:Gem::Version
|
68
|
+
version: '5'
|
69
|
+
description:
|
70
|
+
email: andrew@chartkick.com
|
71
|
+
executables: []
|
72
|
+
extensions: []
|
73
|
+
extra_rdoc_files: []
|
74
|
+
files:
|
75
|
+
- CHANGELOG.md
|
76
|
+
- LICENSE.txt
|
77
|
+
- README.md
|
78
|
+
- lib/tensorflow.rb
|
79
|
+
- lib/tensorflow/context.rb
|
80
|
+
- lib/tensorflow/ffi.rb
|
81
|
+
- lib/tensorflow/tensor.rb
|
82
|
+
- lib/tensorflow/utils.rb
|
83
|
+
- lib/tensorflow/variable.rb
|
84
|
+
- lib/tensorflow/version.rb
|
85
|
+
homepage: https://github.com/ankane/tensorflow
|
86
|
+
licenses:
|
87
|
+
- MIT
|
88
|
+
metadata: {}
|
89
|
+
post_install_message:
|
90
|
+
rdoc_options: []
|
91
|
+
require_paths:
|
92
|
+
- lib
|
93
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
94
|
+
requirements:
|
95
|
+
- - ">="
|
96
|
+
- !ruby/object:Gem::Version
|
97
|
+
version: '2.4'
|
98
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
99
|
+
requirements:
|
100
|
+
- - ">="
|
101
|
+
- !ruby/object:Gem::Version
|
102
|
+
version: '0'
|
103
|
+
requirements: []
|
104
|
+
rubygems_version: 3.0.3
|
105
|
+
signing_key:
|
106
|
+
specification_version: 4
|
107
|
+
summary: TensorFlow - the end-to-end machine learning platform - for Ruby
|
108
|
+
test_files: []
|