tensorflow 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +113 -0
- data/lib/tensorflow/context.rb +26 -0
- data/lib/tensorflow/ffi.rb +60 -0
- data/lib/tensorflow/tensor.rb +193 -0
- data/lib/tensorflow/utils.rb +22 -0
- data/lib/tensorflow/variable.rb +57 -0
- data/lib/tensorflow/version.rb +3 -0
- data/lib/tensorflow.rb +188 -0
- metadata +108 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 32ac7a47334631b2c5393721e5a7f848a26d9535b473557e146e8e6b4b6337be
|
4
|
+
data.tar.gz: 21ca600101471ee43edd9a404f3072d67cbbdfa44f412ba8ec76a2546973087d
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: 601c99ff6035138797f0e14a0f5f2c52ae68b4ee4c5223ef5676e45aaf2f61a7994f0a2ce2ad353b03fa503d7ad8fd6e9c49ba7567fbdcee2156bffe6d36a4c4
|
7
|
+
data.tar.gz: 825d6b40f58f3cbe4d1c6f5b9cff5c7a739b730a89e025bf305bcdf3789269ce34f851709a0294d2c830ca1621a08ff3a67b352422bafd92eb8dec30f65e42e3
|
data/CHANGELOG.md
ADDED
data/LICENSE.txt
ADDED
@@ -0,0 +1,22 @@
|
|
1
|
+
Copyright (c) 2019 Andrew Kane
|
2
|
+
|
3
|
+
MIT License
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining
|
6
|
+
a copy of this software and associated documentation files (the
|
7
|
+
"Software"), to deal in the Software without restriction, including
|
8
|
+
without limitation the rights to use, copy, modify, merge, publish,
|
9
|
+
distribute, sublicense, and/or sell copies of the Software, and to
|
10
|
+
permit persons to whom the Software is furnished to do so, subject to
|
11
|
+
the following conditions:
|
12
|
+
|
13
|
+
The above copyright notice and this permission notice shall be
|
14
|
+
included in all copies or substantial portions of the Software.
|
15
|
+
|
16
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
17
|
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
18
|
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
19
|
+
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
20
|
+
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
21
|
+
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
22
|
+
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
data/README.md
ADDED
@@ -0,0 +1,113 @@
|
|
1
|
+
# TensorFlow
|
2
|
+
|
3
|
+
[TensorFlow](https://github.com/tensorflow/tensorflow) - the end-to-end machine learning platform - for Ruby
|
4
|
+
|
5
|
+
:fire: Uses the C API for blazing performance
|
6
|
+
|
7
|
+
## Installation
|
8
|
+
|
9
|
+
[Install TensorFlow](#tensorflow-installation). For Homebrew, use:
|
10
|
+
|
11
|
+
```sh
|
12
|
+
brew install tensorflow
|
13
|
+
```
|
14
|
+
|
15
|
+
Add this line to your application’s Gemfile:
|
16
|
+
|
17
|
+
```ruby
|
18
|
+
gem 'tensorflow'
|
19
|
+
```
|
20
|
+
|
21
|
+
## Getting Started
|
22
|
+
|
23
|
+
This library follows the TensorFlow 2.0 [Python API](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf). Many methods and options are missing at the moment. PRs welcome!
|
24
|
+
|
25
|
+
## Constants
|
26
|
+
|
27
|
+
```ruby
|
28
|
+
a = Tf.constant(2)
|
29
|
+
b = Tf.constant(3)
|
30
|
+
a + b
|
31
|
+
```
|
32
|
+
|
33
|
+
## Variables
|
34
|
+
|
35
|
+
```ruby
|
36
|
+
v = Tf::Variable.new(0.0)
|
37
|
+
w = v + 1
|
38
|
+
```
|
39
|
+
|
40
|
+
## FizzBuzz
|
41
|
+
|
42
|
+
```ruby
|
43
|
+
def fizzbuzz(max_num)
|
44
|
+
counter = Tf.constant(0)
|
45
|
+
max_num.times do |i|
|
46
|
+
num = Tf.constant(i + 1)
|
47
|
+
if (num % 3).to_i == 0 && (num % 5).to_i == 0
|
48
|
+
puts "FizzBuzz"
|
49
|
+
elsif (num % 3).to_i == 0
|
50
|
+
puts "Fizz"
|
51
|
+
elsif (num % 5).to_i == 0
|
52
|
+
puts "Buzz"
|
53
|
+
else
|
54
|
+
puts num.to_i
|
55
|
+
end
|
56
|
+
end
|
57
|
+
end
|
58
|
+
|
59
|
+
fizzbuzz(15)
|
60
|
+
```
|
61
|
+
|
62
|
+
## Keras
|
63
|
+
|
64
|
+
Coming soon — the Keras API shown below is not implemented yet.
|
65
|
+
|
66
|
+
```ruby
|
67
|
+
mnist = Tf::Keras::Datasets::MNIST
|
68
|
+
(x_train, y_train), (x_test, y_test) = mnist.load_data
|
69
|
+
x_train = x_train / 255.0
|
70
|
+
x_test = x_test / 255.0
|
71
|
+
|
72
|
+
model = Tf::Keras::Models::Sequential.new([
|
73
|
+
Tf::Keras::Layers::Flatten.new(input_shape: [28, 28]),
|
74
|
+
Tf::Keras::Layers::Dense.new(128, activation: "relu"),
|
75
|
+
Tf::Keras::Layers::Dropout.new(0.2),
|
76
|
+
Tf::Keras::Layers::Dense.new(10, activation: "softmax")
|
77
|
+
])
|
78
|
+
|
79
|
+
model.compile(optimizer: "adam", loss: "sparse_categorical_crossentropy", metrics: ["accuracy"])
|
80
|
+
model.fit(x_train, y_train, epochs: 5)
|
81
|
+
model.evaluate(x_test, y_test)
|
82
|
+
```
|
83
|
+
|
84
|
+
## TensorFlow Installation
|
85
|
+
|
86
|
+
### Mac
|
87
|
+
|
88
|
+
Run:
|
89
|
+
|
90
|
+
```sh
|
91
|
+
brew install tensorflow
|
92
|
+
```
|
93
|
+
|
94
|
+
### Linux
|
95
|
+
|
96
|
+
Download the [shared library](https://www.tensorflow.org/install/lang_c#download) and move `libtensorflow.so` to `/usr/local/lib`.
|
97
|
+
|
98
|
+
### Windows
|
99
|
+
|
100
|
+
Download the [shared library](https://www.tensorflow.org/install/lang_c#download) and move `tensorflow.dll` to `C:\Windows\System32`.
|
101
|
+
|
102
|
+
## History
|
103
|
+
|
104
|
+
View the [changelog](https://github.com/ankane/tensorflow/blob/master/CHANGELOG.md)
|
105
|
+
|
106
|
+
## Contributing
|
107
|
+
|
108
|
+
Everyone is encouraged to help improve this project. Here are a few ways you can help:
|
109
|
+
|
110
|
+
- [Report bugs](https://github.com/ankane/tensorflow/issues)
|
111
|
+
- Fix bugs and [submit pull requests](https://github.com/ankane/tensorflow/pulls)
|
112
|
+
- Write, clarify, or fix documentation
|
113
|
+
- Suggest or add new features
|
@@ -0,0 +1,26 @@
|
|
1
|
+
module TensorFlow
  # Owns the TFE_Context that eager ops are created and executed in.
  class Context
    # Creates the eager context.
    #
    # @raise [TensorFlow::Error] if TensorFlow reports a non-OK status
    def initialize
      options = FFI.TFE_NewContextOptions
      status = FFI.TF_NewStatus
      begin
        @pointer = FFI.TFE_NewContext(options, status)
        Utils.check_status status
      ensure
        # options and status are only used while creating the context;
        # release them even when creation fails (check_status raised)
        FFI.TFE_DeleteContextOptions(options)
        FFI.TF_DeleteStatus(status)
      end
      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # Builds the finalizer that frees the native context.
    #
    # @param pointer [FFI::Pointer] the TFE_Context to delete
    # @return [Proc]
    def self.finalize(pointer)
      # must use proc instead of stabby lambda
      proc { FFI.TFE_DeleteContext(pointer) }
    end

    # @return [FFI::Pointer] the native TFE_Context (lets FFI accept this object)
    def to_ptr
      @pointer
    end

    # Shared name used for resource variables created in this context.
    def shared_name
      # hard-coded in Python library
      "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
    end
  end
end
|
@@ -0,0 +1,60 @@
|
|
1
|
+
module TensorFlow
  # Raw bindings to the TensorFlow C API, loaded through the ffi gem from the
  # library names listed in TensorFlow.ffi_lib.
  module FFI
    extend ::FFI::Library

    begin
      ffi_lib TensorFlow.ffi_lib
    rescue LoadError => e
      # set TENSORFLOW_DEBUG to see the underlying loader error
      raise e if ENV["TENSORFLOW_DEBUG"]
      raise LoadError, "Could not find TensorFlow"
    end

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h
    attach_function :TF_Version, %i[], :string

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_attrtype.h
    AttrType = enum(:string, :int, :float, :bool, :type, :shape, :tensor, :placeholder, :func)

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_datatype.h
    # (numbering starts at 1: float = 1)
    DataType = enum(:float, 1, :double, :int32, :uint8, :int16, :int8, :string, :complex64, :int64, :bool, :qint8, :quint8, :qint32, :bfloat16, :qint16, :quint16, :uint16, :complex128, :half, :resource, :variant, :uint32, :uint64)

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_status.h
    attach_function :TF_NewStatus, %i[], :pointer
    attach_function :TF_DeleteStatus, %i[pointer], :void
    attach_function :TF_GetCode, %i[pointer], :int
    attach_function :TF_Message, %i[pointer], :string

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_tensor.h
    attach_function :TF_NewTensor, %i[int pointer int pointer size_t pointer pointer], :pointer
    attach_function :TF_DeleteTensor, %i[pointer], :void
    attach_function :TF_TensorData, %i[pointer], :pointer

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/eager/c_api.h
    attach_function :TFE_NewContextOptions, %i[], :pointer
    attach_function :TFE_ContextOptionsSetAsync, %i[pointer uchar], :void
    attach_function :TFE_DeleteContextOptions, %i[pointer], :void
    attach_function :TFE_NewContext, %i[pointer pointer], :pointer
    attach_function :TFE_DeleteContext, %i[pointer], :void
    attach_function :TFE_ContextListDevices, %i[pointer pointer], :pointer
    attach_function :TFE_ContextGetDevicePlacementPolicy, %i[pointer], :int
    attach_function :TFE_NewTensorHandle, %i[pointer pointer], :pointer
    attach_function :TFE_DeleteTensorHandle, %i[pointer], :void
    attach_function :TFE_TensorHandleDataType, %i[pointer], :int
    attach_function :TFE_TensorHandleNumDims, %i[pointer pointer], :int
    attach_function :TFE_TensorHandleNumElements, %i[pointer pointer], :int64
    attach_function :TFE_TensorHandleDim, %i[pointer int pointer], :int64
    attach_function :TFE_TensorHandleDeviceName, %i[pointer pointer], :string
    attach_function :TFE_TensorHandleBackingDeviceName, %i[pointer pointer], :string
    attach_function :TFE_TensorHandleResolve, %i[pointer pointer], :pointer
    attach_function :TFE_NewOp, %i[pointer string pointer], :pointer
    attach_function :TFE_DeleteOp, %i[pointer], :void
    attach_function :TFE_OpSetDevice, %i[pointer string pointer], :void
    attach_function :TFE_OpGetDevice, %i[pointer pointer], :string
    attach_function :TFE_OpAddInput, %i[pointer pointer pointer], :void
    attach_function :TFE_OpGetAttrType, %i[pointer string pointer pointer], :int
    attach_function :TFE_OpSetAttrString, %i[pointer string pointer size_t], :void
    attach_function :TFE_OpSetAttrType, %i[pointer string int], :void
    attach_function :TFE_OpSetAttrShape, %i[pointer string pointer int pointer], :void
    # results are written to the retvals / num_retvals out-parameters
    attach_function :TFE_Execute, %i[pointer pointer pointer pointer], :void
  end
end
|
@@ -0,0 +1,193 @@
|
|
1
|
+
module TensorFlow
  # Eager tensor wrapping a TFE_TensorHandle.
  #
  # Built either from a Ruby value (scalar or nested Array) or, with
  # +pointer:+, around an existing handle returned by an op.
  class Tensor
    # @param value [Object] scalar or (nested) Array of values
    # @param pointer [FFI::Pointer, nil] existing TFE_TensorHandle to wrap
    # @param dtype [Symbol, String, nil] element type; inferred from +value+ when nil
    # @param shape [Array<Integer>, nil] dimensions; inferred from +value+ when nil.
    #   A single value is repeated to fill an explicit shape.
    # @raise [ArgumentError] if the number of values does not match +shape+
    # @raise [RuntimeError] for unsupported dtypes
    # @raise [TensorFlow::Error] if TensorFlow rejects the tensor
    def initialize(value = nil, pointer: nil, dtype: nil, shape: nil)
      @status = FFI.TF_NewStatus

      if pointer
        @pointer = pointer
      else
        shape ||= calculate_shape(value)
        data = broadcast_data(Array(value).flatten, shape)
        dtype = (dtype || Utils.infer_type(value)).to_sym

        dims_ptr = nil
        unless shape.empty?
          dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
          dims_ptr.write_array_of_int64(shape)
        end

        # TF_NewTensor may use this buffer in place instead of copying it, so
        # the buffer and the deallocator callback are held in instance
        # variables: as locals they could be garbage collected (freeing the
        # memory) while TensorFlow still references them.
        @data_ptr = allocate_data(data, dtype)
        @deallocator = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |_data, _len, _arg|
          # nothing to do: the ffi gem frees @data_ptr
        end

        tensor = FFI.TF_NewTensor(FFI::DataType[dtype], dims_ptr, shape.size, @data_ptr, @data_ptr.size, @deallocator, nil)
        @pointer = FFI.TFE_NewTensorHandle(tensor, @status)
        check_status @status
      end

      # TODO fix segfault
      # ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    def +(other)
      TensorFlow.add(self, other)
    end

    def -(other)
      TensorFlow.subtract(self, other)
    end

    def *(other)
      TensorFlow.multiply(self, other)
    end

    def /(other)
      TensorFlow.divide(self, other)
    end

    def %(other)
      TensorFlow.floormod(self, other)
    end

    # @return [Integer] rank of the tensor
    def num_dims
      ret = FFI.TFE_TensorHandleNumDims(@pointer, @status)
      check_status @status
      ret
    end

    # @return [Symbol, nil] dtype name from FFI::DataType
    def dtype
      @dtype ||= FFI::DataType[FFI.TFE_TensorHandleDataType(@pointer)]
    end

    # @return [Integer] total number of elements
    def element_count
      ret = FFI.TFE_TensorHandleNumElements(@pointer, @status)
      check_status @status
      ret
    end

    # @return [Array<Integer>] size of each dimension
    def shape
      @shape ||= begin
        shape = []
        num_dims.times do |i|
          shape << FFI.TFE_TensorHandleDim(@pointer, i, @status)
          check_status @status
        end
        shape
      end
    end

    # Resolves the handle and returns a pointer to the tensor's raw data.
    # NOTE(review): the TF_Tensor from TFE_TensorHandleResolve is never passed
    # to TF_DeleteTensor — confirm ownership and free it if the caller owns it.
    def data_pointer
      tensor = FFI.TFE_TensorHandleResolve(@pointer, @status)
      check_status @status
      FFI.TF_TensorData(tensor)
    end

    def to_ptr
      @pointer
    end

    # Copies the tensor's contents into Ruby, nested to match #shape
    # (scalars come back unwrapped). Resource tensors return the raw data
    # pointer instead.
    #
    # @raise [RuntimeError] for unsupported dtypes
    def value
      case dtype
      when :resource
        data_pointer
      when :float, :double, :int32, :int64, :string, :bool
        # resolve once; the string branch needs the pointer per element
        reshape(read_values(data_pointer, element_count), shape)
      else
        raise "Unknown type: #{dtype}"
      end
    end

    def to_s
      inspect
    end

    def to_i
      value.to_i
    end

    def to_a
      value
    end

    def inspect
      inspection = %w(value shape dtype).map { |v| "#{v}: #{send(v).inspect}"}
      "#<#{self.class} #{inspection.join(", ")}>"
    end

    def self.finalize(pointer)
      # must use proc instead of stabby lambda
      proc { FFI.TFE_DeleteTensorHandle(pointer) }
    end

    private

    # Nests a flat array according to +dims+; an empty +dims+ (scalar)
    # returns the single element.
    def reshape(arr, dims)
      return arr.first if dims.empty?
      arr = arr.flatten
      dims[1..-1].reverse_each do |dim|
        arr = arr.each_slice(dim)
      end
      arr.to_a
    end

    # Shape of a (nested) Array, following the first element at each level;
    # non-Array values are scalars with shape [].
    def calculate_shape(value)
      shape = []
      d = value
      while d.is_a?(Array)
        shape << d.size
        d = d.first
      end
      shape
    end

    # Repeats a single value to fill +shape+; otherwise checks that the number
    # of values equals the element count +shape+ describes, so TensorFlow never
    # reads past the end of the buffer.
    def broadcast_data(data, shape)
      expected = shape.inject(1, :*)
      return data * expected if data.size == 1 && expected != 1
      return data if data.size == expected
      raise ArgumentError, "Expected #{expected} values for shape #{shape.inspect}, got #{data.size}"
    end

    # Allocates a buffer holding +data+ in the native layout for +dtype+.
    def allocate_data(data, dtype)
      case dtype
      when :string
        string_ptr(data)
      when :bool
        # one byte per element, read back by read_values as int8 == 1
        ptr = ::FFI::MemoryPointer.new(:int8, data.size)
        ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 })
        ptr
      when :float, :double, :int32, :int64
        ptr = ::FFI::MemoryPointer.new(dtype, data.size)
        ptr.public_send(:"write_array_of_#{dtype}", data)
        ptr
      else
        raise "Unknown type: #{dtype}"
      end
    end

    # Reads +count+ elements of this tensor's dtype starting at +ptr+.
    def read_values(ptr, count)
      case dtype
      when :string
        # string tensor format
        # https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
        start_offset_size = count * 8
        offsets = ptr.read_array_of_uint64(count)
        offsets.map { |offset| (ptr + start_offset_size + offset).read_string }
      when :bool
        ptr.read_array_of_int8(count).map { |v| v == 1 }
      else
        ptr.public_send(:"read_array_of_#{dtype}", count)
      end
    end

    # Byte offset of each string inside the data region (strings are stored
    # back to back, each followed by a NUL byte) and the region's total size.
    # Offsets are cumulative, which is how read_values interprets them.
    def string_offsets(data)
      offsets = [0]
      data.each { |str| offsets << offsets.last + str.bytesize + 1 }
      total = offsets.pop
      [offsets, total]
    end

    # string tensor format
    # https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
    # Layout written: one uint64 offset per element, then the string bytes.
    # NOTE(review): the linked header describes each element as a varint
    # length prefix followed by the bytes; this writes NUL-terminated bytes
    # without a prefix — confirm against the TensorFlow version in use.
    def string_ptr(data)
      start_offset_size = data.size * 8
      offsets, total = string_offsets(data)
      # MemoryPointer memory is zero-filled, which supplies the NUL terminators
      data_ptr = ::FFI::MemoryPointer.new(:char, start_offset_size + total)
      data_ptr.write_array_of_uint64(offsets)
      data.zip(offsets) do |str, offset|
        (data_ptr + start_offset_size + offset).write_string(str)
      end
      data_ptr
    end

    def check_status(status)
      Utils.check_status(status)
    end
  end
end
|
@@ -0,0 +1,22 @@
|
|
1
|
+
module TensorFlow
  # Helpers shared by the tensor, variable and op code.
  module Utils
    # Raises TensorFlow::Error carrying the status message when +status+
    # holds a non-zero code.
    #
    # @param status [FFI::Pointer] a TF_Status
    # @raise [TensorFlow::Error]
    def self.check_status(status)
      if FFI.TF_GetCode(status) != 0
        raise Error, FFI.TF_Message(status)
      end
    end

    # Infers a dtype symbol from a Ruby value.
    #
    # Arrays (nested or not) are inferred from their elements.
    # - All elements of one kind give that kind.
    # - A mix of Integers and Floats gives :float, so [1, 2.5] is not
    #   truncated to integers.
    # - Empty arrays and unrecognized values give :int32.
    #
    # @param value [Object]
    # @return [Symbol] :string, :float, :bool or :int32
    # @raise [ArgumentError] if an array mixes incompatible element kinds
    def self.infer_type(value)
      case value
      when Array
        types = value.flatten.map { |element| infer_type(element) }.uniq
        if types.empty?
          :int32
        elsif types.size == 1
          types.first
        elsif types.sort == %i[float int32]
          :float
        else
          raise ArgumentError, "Cannot infer dtype from mixed types: #{types.join(", ")}"
        end
      when String
        :string
      when Float
        :float
      when true, false
        :bool
      else
        :int32
      end
    end
  end
end
|
@@ -0,0 +1,57 @@
|
|
1
|
+
module TensorFlow
  # Mutable state stored in a TensorFlow resource variable.
  class Variable
    # @param initial_value [Object] Ruby value or Tensor to store
    # @param dtype [Symbol, String, nil] element type; taken from a Tensor
    #   initial value, otherwise inferred, when nil
    def initialize(initial_value, dtype: nil)
      @dtype = dtype || (initial_value.is_a?(Tensor) ? initial_value.dtype : Utils.infer_type(initial_value))
      @pointer = TensorFlow.var_handle_op(type_enum, nil, shared_name: TensorFlow.send(:default_context).shared_name)
      assign(initial_value)
    end

    # Replaces the stored value.
    # @return [self]
    def assign(value)
      TensorFlow.assign_variable_op(@pointer, coerce_operand(value))
      self
    end

    # Adds +value+ to the stored value in place.
    # @return [self]
    def assign_add(value)
      TensorFlow.assign_add_variable_op(@pointer, coerce_operand(value))
      self
    end

    # Subtracts +value+ from the stored value in place.
    # @return [self]
    def assign_sub(value)
      TensorFlow.assign_sub_variable_op(@pointer, coerce_operand(value))
      self
    end

    # @return [Tensor] the current value
    def read_value
      TensorFlow.read_variable_op(@pointer, type_enum)
    end

    # Arithmetic returns a new Tensor and leaves the variable unchanged. The
    # operand is converted to this variable's dtype so e.g. a float variable
    # can be combined with an Integer literal.
    def +(other)
      TensorFlow.add(read_value, coerce_operand(other))
    end

    def -(other)
      TensorFlow.subtract(read_value, coerce_operand(other))
    end

    def *(other)
      TensorFlow.multiply(read_value, coerce_operand(other))
    end

    def /(other)
      TensorFlow.divide(read_value, coerce_operand(other))
    end

    def to_s
      inspect
    end

    def inspect
      value = read_value
      inspection = %w(value shape dtype).map { |v| "#{v}: #{value.send(v).inspect}"}
      "#<#{self.class} #{inspection.join(", ")}>"
    end

    private

    # Converts +value+ to a Tensor of this variable's dtype (Tensors pass
    # through unchanged).
    def coerce_operand(value)
      TensorFlow.convert_to_tensor(value, dtype: @dtype)
    end

    # @return [Integer, nil] numeric FFI::DataType value for @dtype
    def type_enum
      FFI::DataType[@dtype.to_sym] if @dtype
    end
  end
end
|
data/lib/tensorflow.rb
ADDED
@@ -0,0 +1,188 @@
|
|
1
|
+
# dependencies
|
2
|
+
require "ffi"
|
3
|
+
|
4
|
+
# modules
|
5
|
+
require "tensorflow/utils"
|
6
|
+
require "tensorflow/context"
|
7
|
+
require "tensorflow/tensor"
|
8
|
+
require "tensorflow/variable"
|
9
|
+
require "tensorflow/version"
|
10
|
+
|
11
|
+
module TensorFlow
  class Error < StandardError; end

  class << self
    attr_accessor :ffi_lib
  end
  # library names handed to ffi_lib when the C API is loaded
  self.ffi_lib = ["tensorflow", "libtensorflow.so"]

  # friendlier error message
  autoload :FFI, "tensorflow/ffi"

  class << self
    include Utils

    # @return [String] version reported by the loaded libtensorflow
    def library_version
      FFI.TF_Version
    end

    def constant(value, dtype: nil, shape: nil)
      Tensor.new(value, dtype: dtype, shape: shape)
    end

    # Wraps +value+ in a Tensor unless it already is one.
    def convert_to_tensor(value, dtype: nil)
      value = Tensor.new(value, dtype: dtype) unless value.is_a?(Tensor)
      value
    end

    def add(x, y)
      execute("Add", [x, y])
    end

    def subtract(x, y)
      execute("Sub", [x, y])
    end

    def multiply(x, y)
      execute("Mul", [x, y])
    end

    def divide(x, y)
      execute("Div", [x, y])
    end

    def abs(x)
      execute("Abs", [x])
    end

    def sqrt(x)
      execute("Sqrt", [x])
    end

    def matmul(x, y)
      execute("MatMul", [x, y])
    end

    # Floored modulo: the result takes the sign of +y+, like Ruby's
    # Integer#%. (TensorFlow's "Mod" op truncates toward zero instead.)
    def floormod(x, y)
      execute("FloorMod", [x, y])
    end

    def range(start, limit, delta)
      execute("Range", [start, limit, delta])
    end

    def transpose(x, perm: [1, 0])
      execute("Transpose", [x, perm])
    end

    def equal(x, y)
      execute("Equal", [x, y])
    end

    def zeros_like(x)
      execute("ZerosLike", [x])
    end

    def fill(dims, value)
      execute("Fill", [dims, value])
    end

    def zeros(dims)
      fill(dims, 0)
    end

    def ones(dims)
      fill(dims, 1)
    end

    def assign_add_variable_op(resource, value)
      execute("AssignAddVariableOp", [resource, value])
    end

    def assign_sub_variable_op(resource, value)
      execute("AssignSubVariableOp", [resource, value])
    end

    def assign_variable_op(resource, value)
      execute("AssignVariableOp", [resource, value])
    end

    def read_variable_op(resource, dtype)
      execute("ReadVariableOp", [resource], dtype: dtype)
    end

    def var_handle_op(dtype, shape, container: "", shared_name: "")
      execute("VarHandleOp", [], container: container, shared_name: shared_name, dtype: dtype, shape: shape)
    end

    def var_is_initialized_op(resource)
      execute("VarIsInitializedOp", [resource])
    end

    private

    # Lazily created context shared by all eager ops.
    def default_context
      @default_context ||= Context.new
    end

    # Runs the eager op +op_name+ on +inputs+ with the given attributes.
    #
    # @param op_name [String] TensorFlow op name, e.g. "Add"
    # @param inputs [Array] Tensors, handles, or Ruby values (converted with
    #   convert_to_tensor)
    # @param attrs [Hash] op attributes of type string, type or shape
    # @return [Tensor, FFI::Pointer, nil] the first output: a raw handle for
    #   resource outputs, a Tensor otherwise, nil if the op has no outputs
    # @raise [TensorFlow::Error] if TensorFlow reports a non-OK status
    def execute(op_name, inputs = [], **attrs)
      status = FFI.TF_NewStatus
      op = FFI.TFE_NewOp(default_context, op_name, status)
      check_status status

      attrs.each do |attr_name, attr_value|
        set_attr(op, attr_name.to_s, attr_value, status)
      end

      # keep converted tensors referenced until the op has executed so the GC
      # cannot reclaim their buffers while TensorFlow is using them
      inputs = inputs.map { |input| input.respond_to?(:to_ptr) ? input : convert_to_tensor(input) }
      inputs.each do |input|
        FFI.TFE_OpAddInput(op, input, status)
        check_status status
      end

      # num_retvals is in/out: the number of slots in retvals going in, the
      # number of outputs produced coming back
      max_retvals = 8
      retvals = ::FFI::MemoryPointer.new(:pointer, max_retvals)
      num_retvals = ::FFI::MemoryPointer.new(:int)
      num_retvals.write_int(max_retvals)
      FFI.TFE_Execute(op, retvals, num_retvals, status)
      check_status status

      return if num_retvals.read_int.zero?

      handle = retvals.read_pointer
      if FFI::DataType[FFI.TFE_TensorHandleDataType(handle)] == :resource
        handle
      else
        Tensor.new(pointer: handle)
      end
    ensure
      # output handles are independent of the op, so both can be freed here
      FFI.TFE_DeleteOp(op) if op && !op.null?
      FFI.TF_DeleteStatus(status) if status
    end

    # Sets a single attribute on +op+ according to the type TensorFlow
    # reports for it.
    #
    # @raise [RuntimeError] for attribute types not handled here
    def set_attr(op, attr_name, attr_value, status)
      is_list = ::FFI::MemoryPointer.new(:int)
      type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status)
      check_status status

      case FFI::AttrType[type]
      when :type
        # accept dtype symbols as well as numeric FFI::DataType values
        attr_value = FFI::DataType[attr_value] if attr_value.is_a?(Symbol)
        FFI.TFE_OpSetAttrType(op, attr_name, attr_value)
      when :shape
        # nil keeps the previous behavior of zero dims
        dims = Array(attr_value)
        dims_ptr = nil
        unless dims.empty?
          dims_ptr = ::FFI::MemoryPointer.new(:int64, dims.size)
          dims_ptr.write_array_of_int64(dims)
        end
        FFI.TFE_OpSetAttrShape(op, attr_name, dims_ptr, dims.size, status)
        check_status status
      when :string
        FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize)
      else
        raise "Unknown type: #{FFI::AttrType[type]}"
      end
    end

    def check_status(status)
      Utils.check_status(status)
    end
  end
end
|
186
|
+
|
187
|
+
# shortcut: lets callers write Tf.constant(...) instead of TensorFlow.constant(...)
Tf = TensorFlow
|
metadata
ADDED
@@ -0,0 +1,108 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: tensorflow
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Andrew Kane
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2019-09-18 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: ffi
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: bundler
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rake
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
- !ruby/object:Gem::Dependency
|
56
|
+
name: minitest
|
57
|
+
requirement: !ruby/object:Gem::Requirement
|
58
|
+
requirements:
|
59
|
+
- - ">="
|
60
|
+
- !ruby/object:Gem::Version
|
61
|
+
version: '5'
|
62
|
+
type: :development
|
63
|
+
prerelease: false
|
64
|
+
version_requirements: !ruby/object:Gem::Requirement
|
65
|
+
requirements:
|
66
|
+
- - ">="
|
67
|
+
- !ruby/object:Gem::Version
|
68
|
+
version: '5'
|
69
|
+
description:
|
70
|
+
email: andrew@chartkick.com
|
71
|
+
executables: []
|
72
|
+
extensions: []
|
73
|
+
extra_rdoc_files: []
|
74
|
+
files:
|
75
|
+
- CHANGELOG.md
|
76
|
+
- LICENSE.txt
|
77
|
+
- README.md
|
78
|
+
- lib/tensorflow.rb
|
79
|
+
- lib/tensorflow/context.rb
|
80
|
+
- lib/tensorflow/ffi.rb
|
81
|
+
- lib/tensorflow/tensor.rb
|
82
|
+
- lib/tensorflow/utils.rb
|
83
|
+
- lib/tensorflow/variable.rb
|
84
|
+
- lib/tensorflow/version.rb
|
85
|
+
homepage: https://github.com/ankane/tensorflow
|
86
|
+
licenses:
|
87
|
+
- MIT
|
88
|
+
metadata: {}
|
89
|
+
post_install_message:
|
90
|
+
rdoc_options: []
|
91
|
+
require_paths:
|
92
|
+
- lib
|
93
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
94
|
+
requirements:
|
95
|
+
- - ">="
|
96
|
+
- !ruby/object:Gem::Version
|
97
|
+
version: '2.4'
|
98
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
99
|
+
requirements:
|
100
|
+
- - ">="
|
101
|
+
- !ruby/object:Gem::Version
|
102
|
+
version: '0'
|
103
|
+
requirements: []
|
104
|
+
rubygems_version: 3.0.3
|
105
|
+
signing_key:
|
106
|
+
specification_version: 4
|
107
|
+
summary: TensorFlow - the end-to-end machine learning platform - for Ruby
|
108
|
+
test_files: []
|