tensorflow 0.1.0 → 0.1.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +7 -4
- data/lib/tensorflow/ffi.rb +7 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +17 -0
- data/lib/tensorflow/keras/layers/dense.rb +10 -0
- data/lib/tensorflow/keras/layers/dropout.rb +10 -0
- data/lib/tensorflow/keras/layers/flatten.rb +10 -0
- data/lib/tensorflow/keras/models/sequential.rb +31 -0
- data/lib/tensorflow/math.rb +465 -0
- data/lib/tensorflow/ops.rb +51 -0
- data/lib/tensorflow/raw_ops.rb +4606 -0
- data/lib/tensorflow/tensor.rb +79 -61
- data/lib/tensorflow/utils.rb +133 -14
- data/lib/tensorflow/variable.rb +6 -6
- data/lib/tensorflow/version.rb +1 -1
- data/lib/tensorflow.rb +21 -147
- metadata +52 -2
data/lib/tensorflow/tensor.rb
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
module TensorFlow
|
2
2
|
class Tensor
|
3
|
-
def initialize(value = nil,
|
3
|
+
def initialize(value = nil, dtype: nil, shape: nil, pointer: nil)
|
4
4
|
@status = FFI.TF_NewStatus
|
5
5
|
|
6
6
|
if pointer
|
@@ -18,17 +18,27 @@ module TensorFlow
|
|
18
18
|
|
19
19
|
data = data.flatten
|
20
20
|
|
21
|
-
dtype ||= Utils.infer_type(
|
21
|
+
dtype ||= Utils.infer_type(data)
|
22
22
|
type = FFI::DataType[dtype]
|
23
23
|
case dtype
|
24
|
+
when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
|
25
|
+
data_ptr = ::FFI::MemoryPointer.new(dtype, data.size)
|
26
|
+
data_ptr.send("write_array_of_#{dtype}", data)
|
27
|
+
when :bfloat16
|
28
|
+
# https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
|
29
|
+
data_ptr = ::FFI::MemoryPointer.new(:int8, data.size * 2)
|
30
|
+
data_ptr.write_bytes(data.map { |v| [v].pack("g")[0..1] }.join)
|
31
|
+
when :complex64
|
32
|
+
data_ptr = ::FFI::MemoryPointer.new(:float, data.size * 2)
|
33
|
+
data_ptr.write_array_of_float(data.flat_map { |v| [v.real, v.imaginary] })
|
34
|
+
when :complex128
|
35
|
+
data_ptr = ::FFI::MemoryPointer.new(:double, data.size * 2)
|
36
|
+
data_ptr.write_array_of_double(data.flat_map { |v| [v.real, v.imaginary] })
|
24
37
|
when :string
|
25
38
|
data_ptr = string_ptr(data)
|
26
|
-
when :
|
27
|
-
data_ptr = ::FFI::MemoryPointer.new(:
|
28
|
-
data_ptr.
|
29
|
-
when :int32
|
30
|
-
data_ptr = ::FFI::MemoryPointer.new(:int32, data.size)
|
31
|
-
data_ptr.write_array_of_int32(data)
|
39
|
+
when :bool
|
40
|
+
data_ptr = ::FFI::MemoryPointer.new(:int8, data.size)
|
41
|
+
data_ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 })
|
32
42
|
else
|
33
43
|
raise "Unknown type: #{dtype}"
|
34
44
|
end
|
@@ -42,74 +52,41 @@ module TensorFlow
|
|
42
52
|
check_status @status
|
43
53
|
end
|
44
54
|
|
45
|
-
|
46
|
-
# ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
|
55
|
+
ObjectSpace.define_finalizer(self, self.class.finalize(@pointer, @status, tensor))
|
47
56
|
end
|
48
57
|
|
49
58
|
def +(other)
|
50
|
-
|
59
|
+
Math.add(self, other)
|
51
60
|
end
|
52
61
|
|
53
62
|
def -(other)
|
54
|
-
|
63
|
+
Math.subtract(self, other)
|
55
64
|
end
|
56
65
|
|
57
66
|
def *(other)
|
58
|
-
|
67
|
+
Math.multiply(self, other)
|
59
68
|
end
|
60
69
|
|
61
70
|
def /(other)
|
62
|
-
|
71
|
+
Math.divide(self, other)
|
63
72
|
end
|
64
73
|
|
65
74
|
def %(other)
|
66
|
-
|
67
|
-
end
|
68
|
-
|
69
|
-
def num_dims
|
70
|
-
ret = FFI.TFE_TensorHandleNumDims(@pointer, @status)
|
71
|
-
check_status @status
|
72
|
-
ret
|
73
|
-
end
|
74
|
-
|
75
|
-
def dtype
|
76
|
-
@dtype ||= FFI::DataType[FFI.TFE_TensorHandleDataType(@pointer)]
|
77
|
-
end
|
78
|
-
|
79
|
-
def element_count
|
80
|
-
ret = FFI.TFE_TensorHandleNumElements(@pointer, @status)
|
81
|
-
check_status @status
|
82
|
-
ret
|
83
|
-
end
|
84
|
-
|
85
|
-
def shape
|
86
|
-
@shape ||= begin
|
87
|
-
shape = []
|
88
|
-
num_dims.times do |i|
|
89
|
-
shape << FFI.TFE_TensorHandleDim(@pointer, i, @status)
|
90
|
-
check_status @status
|
91
|
-
end
|
92
|
-
shape
|
93
|
-
end
|
94
|
-
end
|
95
|
-
|
96
|
-
def data_pointer
|
97
|
-
tensor = FFI.TFE_TensorHandleResolve(@pointer, @status)
|
98
|
-
check_status @status
|
99
|
-
FFI.TF_TensorData(tensor)
|
100
|
-
end
|
101
|
-
|
102
|
-
def to_ptr
|
103
|
-
@pointer
|
75
|
+
Math.floormod(self, other)
|
104
76
|
end
|
105
77
|
|
106
78
|
def value
|
107
79
|
value =
|
108
80
|
case dtype
|
109
|
-
when :float
|
110
|
-
data_pointer.
|
111
|
-
when :
|
112
|
-
data_pointer.
|
81
|
+
when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
|
82
|
+
data_pointer.send("read_array_of_#{dtype}", element_count)
|
83
|
+
when :bfloat16
|
84
|
+
byte_str = data_pointer.read_bytes(element_count * 2)
|
85
|
+
element_count.times.map { |i| "#{byte_str[(2 * i)..(2 * i + 1)]}\x00\x00".unpack1("g") }
|
86
|
+
when :complex64
|
87
|
+
data_pointer.read_array_of_float(element_count * 2).each_slice(2).map { |v| Complex(*v) }
|
88
|
+
when :complex128
|
89
|
+
data_pointer.read_array_of_double(element_count * 2).each_slice(2).map { |v| Complex(*v) }
|
113
90
|
when :string
|
114
91
|
# string tensor format
|
115
92
|
# https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
|
@@ -127,6 +104,21 @@ module TensorFlow
|
|
127
104
|
reshape(value, shape)
|
128
105
|
end
|
129
106
|
|
107
|
+
def dtype
|
108
|
+
@dtype ||= FFI::DataType[FFI.TFE_TensorHandleDataType(@pointer)]
|
109
|
+
end
|
110
|
+
|
111
|
+
def shape
|
112
|
+
@shape ||= begin
|
113
|
+
shape = []
|
114
|
+
num_dims.times do |i|
|
115
|
+
shape << FFI.TFE_TensorHandleDim(@pointer, i, @status)
|
116
|
+
check_status @status
|
117
|
+
end
|
118
|
+
shape
|
119
|
+
end
|
120
|
+
end
|
121
|
+
|
130
122
|
def to_s
|
131
123
|
inspect
|
132
124
|
end
|
@@ -139,18 +131,44 @@ module TensorFlow
|
|
139
131
|
value
|
140
132
|
end
|
141
133
|
|
134
|
+
def to_ptr
|
135
|
+
@pointer
|
136
|
+
end
|
137
|
+
|
142
138
|
def inspect
|
143
139
|
inspection = %w(value shape dtype).map { |v| "#{v}: #{send(v).inspect}"}
|
144
140
|
"#<#{self.class} #{inspection.join(", ")}>"
|
145
141
|
end
|
146
142
|
|
147
|
-
def self.finalize(pointer)
|
143
|
+
def self.finalize(pointer, status, tensor)
|
148
144
|
# must use proc instead of stabby lambda
|
149
|
-
proc
|
145
|
+
proc do
|
146
|
+
FFI.TFE_DeleteTensorHandle(pointer)
|
147
|
+
FFI.TFE_DeleteStatus(status)
|
148
|
+
FFI.TFE_DeleteTensor(tensor) if tensor
|
149
|
+
end
|
150
150
|
end
|
151
151
|
|
152
152
|
private
|
153
153
|
|
154
|
+
def num_dims
|
155
|
+
ret = FFI.TFE_TensorHandleNumDims(@pointer, @status)
|
156
|
+
check_status @status
|
157
|
+
ret
|
158
|
+
end
|
159
|
+
|
160
|
+
def element_count
|
161
|
+
ret = FFI.TFE_TensorHandleNumElements(@pointer, @status)
|
162
|
+
check_status @status
|
163
|
+
ret
|
164
|
+
end
|
165
|
+
|
166
|
+
def data_pointer
|
167
|
+
tensor = FFI.TFE_TensorHandleResolve(@pointer, @status)
|
168
|
+
check_status @status
|
169
|
+
FFI.TF_TensorData(tensor)
|
170
|
+
end
|
171
|
+
|
154
172
|
def reshape(arr, dims)
|
155
173
|
return arr.first if dims.empty?
|
156
174
|
arr = arr.flatten
|
@@ -176,12 +194,12 @@ module TensorFlow
|
|
176
194
|
start_offset_size = data.size * 8
|
177
195
|
offsets = [0]
|
178
196
|
data.each do |str|
|
179
|
-
offsets << str.bytesize + 1
|
197
|
+
offsets << offsets.last + str.bytesize + 1
|
180
198
|
end
|
181
199
|
data_ptr = ::FFI::MemoryPointer.new(:char, start_offset_size + offsets.pop)
|
182
200
|
data_ptr.write_array_of_uint64(offsets)
|
183
|
-
data.zip(offsets) do |str,
|
184
|
-
(data_ptr + start_offset_size +
|
201
|
+
data.zip(offsets) do |str, offset|
|
202
|
+
(data_ptr + start_offset_size + offset).write_string(str)
|
185
203
|
end
|
186
204
|
data_ptr
|
187
205
|
end
|
data/lib/tensorflow/utils.rb
CHANGED
@@ -1,21 +1,140 @@
|
|
1
1
|
module TensorFlow
|
2
2
|
module Utils
|
3
|
-
|
4
|
-
|
5
|
-
|
3
|
+
class << self
|
4
|
+
def check_status(status)
|
5
|
+
if FFI.TF_GetCode(status) != 0
|
6
|
+
raise Error, FFI.TF_Message(status)
|
7
|
+
end
|
6
8
|
end
|
7
|
-
end
|
8
9
|
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
10
|
+
def default_context
|
11
|
+
@default_context ||= Context.new
|
12
|
+
end
|
13
|
+
|
14
|
+
def execute(op_name, inputs = [], **attrs)
|
15
|
+
context = default_context
|
16
|
+
status = FFI.TF_NewStatus # TODO reuse status between ops?
|
17
|
+
op = FFI.TFE_NewOp(context, op_name, status)
|
18
|
+
check_status status
|
19
|
+
|
20
|
+
attrs.each do |attr_name, attr_value|
|
21
|
+
next if attr_value.nil?
|
22
|
+
|
23
|
+
attr_name = attr_name.to_s
|
24
|
+
|
25
|
+
is_list = ::FFI::MemoryPointer.new(:int)
|
26
|
+
type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status)
|
27
|
+
check_status status
|
28
|
+
|
29
|
+
case FFI::AttrType[type]
|
30
|
+
when :string
|
31
|
+
FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize)
|
32
|
+
# when :int
|
33
|
+
# when :float
|
34
|
+
# when :bool
|
35
|
+
when :type
|
36
|
+
FFI.TFE_OpSetAttrType(op, attr_name, attr_value)
|
37
|
+
when :shape
|
38
|
+
# TODO set value properly
|
39
|
+
FFI.TFE_OpSetAttrShape(op, attr_name, nil, 0, status)
|
40
|
+
check_status status
|
41
|
+
# when :tensor
|
42
|
+
# when :placeholder
|
43
|
+
# when :func
|
44
|
+
else
|
45
|
+
raise "Unknown type: #{FFI::AttrType[type]}"
|
46
|
+
end
|
47
|
+
end
|
48
|
+
|
49
|
+
inputs.each do |input|
|
50
|
+
input = TensorFlow.convert_to_tensor(input) unless input.respond_to?(:to_ptr)
|
51
|
+
FFI.TFE_OpAddInput(op, input, status)
|
52
|
+
check_status status
|
53
|
+
end
|
54
|
+
|
55
|
+
retvals = ::FFI::MemoryPointer.new(:pointer)
|
56
|
+
num_retvals = ::FFI::MemoryPointer.new(:int)
|
57
|
+
num_retvals.write_int(retvals.size)
|
58
|
+
FFI.TFE_Execute(op, retvals, num_retvals, status)
|
59
|
+
check_status status
|
60
|
+
|
61
|
+
if num_retvals.read_int > 0
|
62
|
+
handle = retvals.read_pointer
|
63
|
+
type = FFI.TFE_TensorHandleDataType(handle)
|
64
|
+
|
65
|
+
case FFI::DataType[type]
|
66
|
+
when :resource
|
67
|
+
handle
|
68
|
+
else
|
69
|
+
Tensor.new(pointer: handle)
|
70
|
+
end
|
71
|
+
end
|
72
|
+
ensure
|
73
|
+
FFI.TF_DeleteStatus(status) if status
|
74
|
+
FFI.TFE_DeleteOp(op) if op
|
75
|
+
end
|
76
|
+
|
77
|
+
def infer_type(value)
|
78
|
+
if value.all? { |v| v.is_a?(String) }
|
79
|
+
:string
|
80
|
+
elsif value.all? { |v| v == true || v == false }
|
81
|
+
:bool
|
82
|
+
elsif value.all? { |v| v.is_a?(Integer) }
|
83
|
+
if value.all? { |v| v >= -2147483648 && v <= 2147483647 }
|
84
|
+
:int32
|
85
|
+
else
|
86
|
+
:int64
|
87
|
+
end
|
88
|
+
elsif value.all? { |v| v.is_a?(Complex) }
|
89
|
+
:complex128
|
90
|
+
elsif value.all? { |v| v.is_a?(Numeric) }
|
91
|
+
:float
|
92
|
+
else
|
93
|
+
raise Error, "Unable to infer data type"
|
94
|
+
end
|
95
|
+
end
|
96
|
+
|
97
|
+
def load_dataset(path, url)
|
98
|
+
# TODO handle this better
|
99
|
+
raise "No HOME" unless ENV["HOME"]
|
100
|
+
datasets_dir = "#{ENV["HOME"]}/.keras/datasets"
|
101
|
+
FileUtils.mkdir_p(datasets_dir)
|
102
|
+
|
103
|
+
path = "#{datasets_dir}/#{path}"
|
104
|
+
Utils.download_file(url, path) unless File.exist?(path)
|
105
|
+
Npy.load_npz(path)
|
106
|
+
end
|
107
|
+
|
108
|
+
def download_file(url, dest)
|
109
|
+
uri = URI(url)
|
110
|
+
|
111
|
+
temp_dir ||= File.dirname(Tempfile.new("tensorflow"))
|
112
|
+
temp_path = "#{temp_dir}/#{Time.now.to_f}" # TODO better name
|
113
|
+
|
114
|
+
# Net::HTTP automatically adds Accept-Encoding for compression
|
115
|
+
# of response bodies and automatically decompresses gzip
|
116
|
+
# and deflateresponses unless a Range header was sent.
|
117
|
+
# https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
|
118
|
+
Net::HTTP.start(uri.host, uri.port, use_ssl: true) do |http|
|
119
|
+
request = Net::HTTP::Get.new(uri)
|
120
|
+
|
121
|
+
print("Downloading dataset")
|
122
|
+
i = 0
|
123
|
+
File.open(temp_path, "wb") do |f|
|
124
|
+
http.request(request) do |response|
|
125
|
+
response.read_body do |chunk|
|
126
|
+
f.write(chunk)
|
127
|
+
|
128
|
+
# print progress
|
129
|
+
putc "." if i % 50 == 0
|
130
|
+
i += 1
|
131
|
+
end
|
132
|
+
end
|
133
|
+
puts # newline
|
134
|
+
end
|
135
|
+
end
|
136
|
+
|
137
|
+
FileUtils.mv(temp_path, dest)
|
19
138
|
end
|
20
139
|
end
|
21
140
|
end
|
data/lib/tensorflow/variable.rb
CHANGED
@@ -1,31 +1,31 @@
|
|
1
1
|
module TensorFlow
|
2
2
|
class Variable
|
3
3
|
def initialize(initial_value, dtype: nil)
|
4
|
-
@dtype = dtype || Utils.infer_type(initial_value)
|
5
|
-
@pointer =
|
4
|
+
@dtype = dtype || Utils.infer_type(Array(initial_value).flatten)
|
5
|
+
@pointer = RawOps.var_handle_op(dtype: type_enum, shape: [], shared_name: Utils.default_context.shared_name)
|
6
6
|
assign(initial_value)
|
7
7
|
end
|
8
8
|
|
9
9
|
def assign(value)
|
10
10
|
value = TensorFlow.convert_to_tensor(value, dtype: @dtype)
|
11
|
-
|
11
|
+
RawOps.assign_variable_op(resource: @pointer, value: value)
|
12
12
|
self
|
13
13
|
end
|
14
14
|
|
15
15
|
def assign_add(value)
|
16
16
|
value = TensorFlow.convert_to_tensor(value, dtype: @dtype)
|
17
|
-
|
17
|
+
RawOps.assign_add_variable_op(resource: @pointer, value: value)
|
18
18
|
self
|
19
19
|
end
|
20
20
|
|
21
21
|
def assign_sub(value)
|
22
22
|
value = TensorFlow.convert_to_tensor(value, dtype: @dtype)
|
23
|
-
|
23
|
+
RawOps.assign_sub_variable_op(resource: @pointer, value: value)
|
24
24
|
self
|
25
25
|
end
|
26
26
|
|
27
27
|
def read_value
|
28
|
-
|
28
|
+
RawOps.read_variable_op(resource: @pointer, dtype: type_enum)
|
29
29
|
end
|
30
30
|
|
31
31
|
def +(other)
|
data/lib/tensorflow/version.rb
CHANGED
data/lib/tensorflow.rb
CHANGED
@@ -1,13 +1,30 @@
|
|
1
1
|
# dependencies
|
2
2
|
require "ffi"
|
3
|
+
require "npy"
|
4
|
+
|
5
|
+
# stdlib
|
6
|
+
require "fileutils"
|
7
|
+
require "forwardable"
|
8
|
+
require "net/http"
|
9
|
+
require "tempfile"
|
3
10
|
|
4
11
|
# modules
|
5
12
|
require "tensorflow/utils"
|
6
13
|
require "tensorflow/context"
|
14
|
+
require "tensorflow/math"
|
15
|
+
require "tensorflow/ops"
|
16
|
+
require "tensorflow/raw_ops"
|
7
17
|
require "tensorflow/tensor"
|
8
18
|
require "tensorflow/variable"
|
9
19
|
require "tensorflow/version"
|
10
20
|
|
21
|
+
# keras
|
22
|
+
require "tensorflow/keras/datasets/mnist"
|
23
|
+
require "tensorflow/keras/layers/dense"
|
24
|
+
require "tensorflow/keras/layers/dropout"
|
25
|
+
require "tensorflow/keras/layers/flatten"
|
26
|
+
require "tensorflow/keras/models/sequential"
|
27
|
+
|
11
28
|
module TensorFlow
|
12
29
|
class Error < StandardError; end
|
13
30
|
|
@@ -20,8 +37,12 @@ module TensorFlow
|
|
20
37
|
autoload :FFI, "tensorflow/ffi"
|
21
38
|
|
22
39
|
class << self
|
40
|
+
include Ops
|
23
41
|
include Utils
|
24
42
|
|
43
|
+
extend Forwardable
|
44
|
+
def_delegators Math, :abs, :acos, :acosh, :add, :add_n, :argmax, :argmin, :asin, :asinh, :atan, :atan2, :atanh, :cos, :cosh, :cumsum, :divide, :equal, :exp, :floor, :greater, :greater_equal, :less, :less_equal, :logical_and, :logical_not, :logical_or, :maximum, :minimum, :multiply, :negative, :not_equal, :pow, :reduce_all, :reduce_any, :reduce_logsumexp, :reduce_max, :reduce_mean, :reduce_min, :reduce_prod, :reduce_sum, :round, :scalar_mul, :sigmoid, :sign, :sin, :sinh, :sqrt, :square, :subtract, :tan, :tanh, :truediv
|
45
|
+
|
25
46
|
def library_version
|
26
47
|
FFI.TF_Version
|
27
48
|
end
|
@@ -34,153 +55,6 @@ module TensorFlow
|
|
34
55
|
value = Tensor.new(value, dtype: dtype) unless value.is_a?(Tensor)
|
35
56
|
value
|
36
57
|
end
|
37
|
-
|
38
|
-
def add(x, y)
|
39
|
-
execute("Add", [x, y])
|
40
|
-
end
|
41
|
-
|
42
|
-
def subtract(x, y)
|
43
|
-
execute("Sub", [x, y])
|
44
|
-
end
|
45
|
-
|
46
|
-
def multiply(x, y)
|
47
|
-
execute("Mul", [x, y])
|
48
|
-
end
|
49
|
-
|
50
|
-
def divide(x, y)
|
51
|
-
execute("Div", [x, y])
|
52
|
-
end
|
53
|
-
|
54
|
-
def abs(x)
|
55
|
-
execute("Abs", [x])
|
56
|
-
end
|
57
|
-
|
58
|
-
def sqrt(x)
|
59
|
-
execute("Sqrt", [x])
|
60
|
-
end
|
61
|
-
|
62
|
-
def matmul(x, y)
|
63
|
-
execute("MatMul", [x, y])
|
64
|
-
end
|
65
|
-
|
66
|
-
def floormod(x, y)
|
67
|
-
execute("Mod", [x, y])
|
68
|
-
end
|
69
|
-
|
70
|
-
def range(start, limit, delta)
|
71
|
-
execute("Range", [start, limit, delta])
|
72
|
-
end
|
73
|
-
|
74
|
-
def transpose(x, perm: [1, 0])
|
75
|
-
execute("Transpose", [x, perm])
|
76
|
-
end
|
77
|
-
|
78
|
-
def equal(x, y)
|
79
|
-
execute("Equal", [x, y])
|
80
|
-
end
|
81
|
-
|
82
|
-
def zeros_like(x)
|
83
|
-
execute("ZerosLike", [x])
|
84
|
-
end
|
85
|
-
|
86
|
-
def fill(dims, value)
|
87
|
-
execute("Fill", [dims, value])
|
88
|
-
end
|
89
|
-
|
90
|
-
def zeros(dims)
|
91
|
-
fill(dims, 0)
|
92
|
-
end
|
93
|
-
|
94
|
-
def ones(dims)
|
95
|
-
fill(dims, 1)
|
96
|
-
end
|
97
|
-
|
98
|
-
def assign_add_variable_op(resource, value)
|
99
|
-
execute("AssignAddVariableOp", [resource, value])
|
100
|
-
end
|
101
|
-
|
102
|
-
def assign_sub_variable_op(resource, value)
|
103
|
-
execute("AssignSubVariableOp", [resource, value])
|
104
|
-
end
|
105
|
-
|
106
|
-
def assign_variable_op(resource, value)
|
107
|
-
execute("AssignVariableOp", [resource, value])
|
108
|
-
end
|
109
|
-
|
110
|
-
def read_variable_op(resource, dtype)
|
111
|
-
execute("ReadVariableOp", [resource], dtype: dtype)
|
112
|
-
end
|
113
|
-
|
114
|
-
def var_handle_op(dtype, shape, container: "", shared_name: "")
|
115
|
-
execute("VarHandleOp", [], container: container, shared_name: shared_name, dtype: dtype, shape: shape)
|
116
|
-
end
|
117
|
-
|
118
|
-
def var_is_initialized_op(resource)
|
119
|
-
execute("VarIsInitializedOp", [resource])
|
120
|
-
end
|
121
|
-
|
122
|
-
private
|
123
|
-
|
124
|
-
def default_context
|
125
|
-
@default_context ||= Context.new
|
126
|
-
end
|
127
|
-
|
128
|
-
def execute(op_name, inputs = [], **attrs)
|
129
|
-
context = default_context
|
130
|
-
status = FFI.TF_NewStatus
|
131
|
-
op = FFI.TFE_NewOp(context, op_name, status)
|
132
|
-
check_status status
|
133
|
-
# TODO clean up status and op
|
134
|
-
|
135
|
-
attrs.each do |attr_name, attr_value|
|
136
|
-
attr_name = attr_name.to_s
|
137
|
-
|
138
|
-
is_list = ::FFI::MemoryPointer.new(:int)
|
139
|
-
type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status)
|
140
|
-
check_status status
|
141
|
-
|
142
|
-
case FFI::AttrType[type]
|
143
|
-
when :type
|
144
|
-
FFI.TFE_OpSetAttrType(op, attr_name, attr_value)
|
145
|
-
when :shape
|
146
|
-
# TODO set value properly
|
147
|
-
FFI.TFE_OpSetAttrShape(op, attr_name, attr_value, 0, status)
|
148
|
-
check_status status
|
149
|
-
when :string
|
150
|
-
FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize)
|
151
|
-
else
|
152
|
-
raise "Unknown type: #{FFI::AttrType[type]}"
|
153
|
-
end
|
154
|
-
end
|
155
|
-
|
156
|
-
inputs.each do |input|
|
157
|
-
input = convert_to_tensor(input) unless input.respond_to?(:to_ptr)
|
158
|
-
FFI.TFE_OpAddInput(op, input, status)
|
159
|
-
check_status status
|
160
|
-
end
|
161
|
-
|
162
|
-
retvals = ::FFI::MemoryPointer.new(:pointer)
|
163
|
-
num_retvals = ::FFI::MemoryPointer.new(:int)
|
164
|
-
num_retvals.write_int(retvals.size)
|
165
|
-
FFI.TFE_Execute(op, retvals, num_retvals, status)
|
166
|
-
check_status status
|
167
|
-
|
168
|
-
if num_retvals.read_int > 0
|
169
|
-
handle = retvals.read_pointer
|
170
|
-
type = FFI.TFE_TensorHandleDataType(handle)
|
171
|
-
|
172
|
-
case FFI::DataType[type]
|
173
|
-
when :resource
|
174
|
-
handle
|
175
|
-
else
|
176
|
-
Tensor.new(pointer: handle)
|
177
|
-
end
|
178
|
-
end
|
179
|
-
end
|
180
|
-
|
181
|
-
def check_status(status)
|
182
|
-
Utils.check_status(status)
|
183
|
-
end
|
184
58
|
end
|
185
59
|
end
|
186
60
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: tensorflow
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-09-
|
11
|
+
date: 2019-09-20 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: ffi
|
@@ -24,6 +24,20 @@ dependencies:
|
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
26
|
version: '0'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: npy
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :runtime
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
27
41
|
- !ruby/object:Gem::Dependency
|
28
42
|
name: bundler
|
29
43
|
requirement: !ruby/object:Gem::Requirement
|
@@ -66,6 +80,34 @@ dependencies:
|
|
66
80
|
- - ">="
|
67
81
|
- !ruby/object:Gem::Version
|
68
82
|
version: '5'
|
83
|
+
- !ruby/object:Gem::Dependency
|
84
|
+
name: google-protobuf
|
85
|
+
requirement: !ruby/object:Gem::Requirement
|
86
|
+
requirements:
|
87
|
+
- - ">="
|
88
|
+
- !ruby/object:Gem::Version
|
89
|
+
version: '0'
|
90
|
+
type: :development
|
91
|
+
prerelease: false
|
92
|
+
version_requirements: !ruby/object:Gem::Requirement
|
93
|
+
requirements:
|
94
|
+
- - ">="
|
95
|
+
- !ruby/object:Gem::Version
|
96
|
+
version: '0'
|
97
|
+
- !ruby/object:Gem::Dependency
|
98
|
+
name: nokogiri
|
99
|
+
requirement: !ruby/object:Gem::Requirement
|
100
|
+
requirements:
|
101
|
+
- - ">="
|
102
|
+
- !ruby/object:Gem::Version
|
103
|
+
version: '0'
|
104
|
+
type: :development
|
105
|
+
prerelease: false
|
106
|
+
version_requirements: !ruby/object:Gem::Requirement
|
107
|
+
requirements:
|
108
|
+
- - ">="
|
109
|
+
- !ruby/object:Gem::Version
|
110
|
+
version: '0'
|
69
111
|
description:
|
70
112
|
email: andrew@chartkick.com
|
71
113
|
executables: []
|
@@ -78,6 +120,14 @@ files:
|
|
78
120
|
- lib/tensorflow.rb
|
79
121
|
- lib/tensorflow/context.rb
|
80
122
|
- lib/tensorflow/ffi.rb
|
123
|
+
- lib/tensorflow/keras/datasets/mnist.rb
|
124
|
+
- lib/tensorflow/keras/layers/dense.rb
|
125
|
+
- lib/tensorflow/keras/layers/dropout.rb
|
126
|
+
- lib/tensorflow/keras/layers/flatten.rb
|
127
|
+
- lib/tensorflow/keras/models/sequential.rb
|
128
|
+
- lib/tensorflow/math.rb
|
129
|
+
- lib/tensorflow/ops.rb
|
130
|
+
- lib/tensorflow/raw_ops.rb
|
81
131
|
- lib/tensorflow/tensor.rb
|
82
132
|
- lib/tensorflow/utils.rb
|
83
133
|
- lib/tensorflow/variable.rb
|