tensorflow 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +7 -4
- data/lib/tensorflow/ffi.rb +7 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +17 -0
- data/lib/tensorflow/keras/layers/dense.rb +10 -0
- data/lib/tensorflow/keras/layers/dropout.rb +10 -0
- data/lib/tensorflow/keras/layers/flatten.rb +10 -0
- data/lib/tensorflow/keras/models/sequential.rb +31 -0
- data/lib/tensorflow/math.rb +465 -0
- data/lib/tensorflow/ops.rb +51 -0
- data/lib/tensorflow/raw_ops.rb +4606 -0
- data/lib/tensorflow/tensor.rb +79 -61
- data/lib/tensorflow/utils.rb +133 -14
- data/lib/tensorflow/variable.rb +6 -6
- data/lib/tensorflow/version.rb +1 -1
- data/lib/tensorflow.rb +21 -147
- metadata +52 -2
data/lib/tensorflow/tensor.rb
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
module TensorFlow
|
2
2
|
class Tensor
|
3
|
-
def initialize(value = nil,
|
3
|
+
def initialize(value = nil, dtype: nil, shape: nil, pointer: nil)
|
4
4
|
@status = FFI.TF_NewStatus
|
5
5
|
|
6
6
|
if pointer
|
@@ -18,17 +18,27 @@ module TensorFlow
|
|
18
18
|
|
19
19
|
data = data.flatten
|
20
20
|
|
21
|
-
dtype ||= Utils.infer_type(
|
21
|
+
dtype ||= Utils.infer_type(data)
|
22
22
|
type = FFI::DataType[dtype]
|
23
23
|
case dtype
|
24
|
+
when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
|
25
|
+
data_ptr = ::FFI::MemoryPointer.new(dtype, data.size)
|
26
|
+
data_ptr.send("write_array_of_#{dtype}", data)
|
27
|
+
when :bfloat16
|
28
|
+
# https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
|
29
|
+
data_ptr = ::FFI::MemoryPointer.new(:int8, data.size * 2)
|
30
|
+
data_ptr.write_bytes(data.map { |v| [v].pack("g")[0..1] }.join)
|
31
|
+
when :complex64
|
32
|
+
data_ptr = ::FFI::MemoryPointer.new(:float, data.size * 2)
|
33
|
+
data_ptr.write_array_of_float(data.flat_map { |v| [v.real, v.imaginary] })
|
34
|
+
when :complex128
|
35
|
+
data_ptr = ::FFI::MemoryPointer.new(:double, data.size * 2)
|
36
|
+
data_ptr.write_array_of_double(data.flat_map { |v| [v.real, v.imaginary] })
|
24
37
|
when :string
|
25
38
|
data_ptr = string_ptr(data)
|
26
|
-
when :
|
27
|
-
data_ptr = ::FFI::MemoryPointer.new(:
|
28
|
-
data_ptr.
|
29
|
-
when :int32
|
30
|
-
data_ptr = ::FFI::MemoryPointer.new(:int32, data.size)
|
31
|
-
data_ptr.write_array_of_int32(data)
|
39
|
+
when :bool
|
40
|
+
data_ptr = ::FFI::MemoryPointer.new(:int8, data.size)
|
41
|
+
data_ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 })
|
32
42
|
else
|
33
43
|
raise "Unknown type: #{dtype}"
|
34
44
|
end
|
@@ -42,74 +52,41 @@ module TensorFlow
|
|
42
52
|
check_status @status
|
43
53
|
end
|
44
54
|
|
45
|
-
|
46
|
-
# ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
|
55
|
+
ObjectSpace.define_finalizer(self, self.class.finalize(@pointer, @status, tensor))
|
47
56
|
end
|
48
57
|
|
49
58
|
def +(other)
|
50
|
-
|
59
|
+
Math.add(self, other)
|
51
60
|
end
|
52
61
|
|
53
62
|
def -(other)
|
54
|
-
|
63
|
+
Math.subtract(self, other)
|
55
64
|
end
|
56
65
|
|
57
66
|
def *(other)
|
58
|
-
|
67
|
+
Math.multiply(self, other)
|
59
68
|
end
|
60
69
|
|
61
70
|
def /(other)
|
62
|
-
|
71
|
+
Math.divide(self, other)
|
63
72
|
end
|
64
73
|
|
65
74
|
def %(other)
|
66
|
-
|
67
|
-
end
|
68
|
-
|
69
|
-
def num_dims
|
70
|
-
ret = FFI.TFE_TensorHandleNumDims(@pointer, @status)
|
71
|
-
check_status @status
|
72
|
-
ret
|
73
|
-
end
|
74
|
-
|
75
|
-
def dtype
|
76
|
-
@dtype ||= FFI::DataType[FFI.TFE_TensorHandleDataType(@pointer)]
|
77
|
-
end
|
78
|
-
|
79
|
-
def element_count
|
80
|
-
ret = FFI.TFE_TensorHandleNumElements(@pointer, @status)
|
81
|
-
check_status @status
|
82
|
-
ret
|
83
|
-
end
|
84
|
-
|
85
|
-
def shape
|
86
|
-
@shape ||= begin
|
87
|
-
shape = []
|
88
|
-
num_dims.times do |i|
|
89
|
-
shape << FFI.TFE_TensorHandleDim(@pointer, i, @status)
|
90
|
-
check_status @status
|
91
|
-
end
|
92
|
-
shape
|
93
|
-
end
|
94
|
-
end
|
95
|
-
|
96
|
-
def data_pointer
|
97
|
-
tensor = FFI.TFE_TensorHandleResolve(@pointer, @status)
|
98
|
-
check_status @status
|
99
|
-
FFI.TF_TensorData(tensor)
|
100
|
-
end
|
101
|
-
|
102
|
-
def to_ptr
|
103
|
-
@pointer
|
75
|
+
Math.floormod(self, other)
|
104
76
|
end
|
105
77
|
|
106
78
|
def value
|
107
79
|
value =
|
108
80
|
case dtype
|
109
|
-
when :float
|
110
|
-
data_pointer.
|
111
|
-
when :
|
112
|
-
data_pointer.
|
81
|
+
when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
|
82
|
+
data_pointer.send("read_array_of_#{dtype}", element_count)
|
83
|
+
when :bfloat16
|
84
|
+
byte_str = data_pointer.read_bytes(element_count * 2)
|
85
|
+
element_count.times.map { |i| "#{byte_str[(2 * i)..(2 * i + 1)]}\x00\x00".unpack1("g") }
|
86
|
+
when :complex64
|
87
|
+
data_pointer.read_array_of_float(element_count * 2).each_slice(2).map { |v| Complex(*v) }
|
88
|
+
when :complex128
|
89
|
+
data_pointer.read_array_of_double(element_count * 2).each_slice(2).map { |v| Complex(*v) }
|
113
90
|
when :string
|
114
91
|
# string tensor format
|
115
92
|
# https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
|
@@ -127,6 +104,21 @@ module TensorFlow
|
|
127
104
|
reshape(value, shape)
|
128
105
|
end
|
129
106
|
|
107
|
+
def dtype
|
108
|
+
@dtype ||= FFI::DataType[FFI.TFE_TensorHandleDataType(@pointer)]
|
109
|
+
end
|
110
|
+
|
111
|
+
def shape
|
112
|
+
@shape ||= begin
|
113
|
+
shape = []
|
114
|
+
num_dims.times do |i|
|
115
|
+
shape << FFI.TFE_TensorHandleDim(@pointer, i, @status)
|
116
|
+
check_status @status
|
117
|
+
end
|
118
|
+
shape
|
119
|
+
end
|
120
|
+
end
|
121
|
+
|
130
122
|
def to_s
|
131
123
|
inspect
|
132
124
|
end
|
@@ -139,18 +131,44 @@ module TensorFlow
|
|
139
131
|
value
|
140
132
|
end
|
141
133
|
|
134
|
+
def to_ptr
|
135
|
+
@pointer
|
136
|
+
end
|
137
|
+
|
142
138
|
def inspect
|
143
139
|
inspection = %w(value shape dtype).map { |v| "#{v}: #{send(v).inspect}"}
|
144
140
|
"#<#{self.class} #{inspection.join(", ")}>"
|
145
141
|
end
|
146
142
|
|
147
|
-
def self.finalize(pointer)
|
143
|
+
def self.finalize(pointer, status, tensor)
|
148
144
|
# must use proc instead of stabby lambda
|
149
|
-
proc
|
145
|
+
proc do
|
146
|
+
FFI.TFE_DeleteTensorHandle(pointer)
|
147
|
+
FFI.TFE_DeleteStatus(status)
|
148
|
+
FFI.TFE_DeleteTensor(tensor) if tensor
|
149
|
+
end
|
150
150
|
end
|
151
151
|
|
152
152
|
private
|
153
153
|
|
154
|
+
def num_dims
|
155
|
+
ret = FFI.TFE_TensorHandleNumDims(@pointer, @status)
|
156
|
+
check_status @status
|
157
|
+
ret
|
158
|
+
end
|
159
|
+
|
160
|
+
def element_count
|
161
|
+
ret = FFI.TFE_TensorHandleNumElements(@pointer, @status)
|
162
|
+
check_status @status
|
163
|
+
ret
|
164
|
+
end
|
165
|
+
|
166
|
+
def data_pointer
|
167
|
+
tensor = FFI.TFE_TensorHandleResolve(@pointer, @status)
|
168
|
+
check_status @status
|
169
|
+
FFI.TF_TensorData(tensor)
|
170
|
+
end
|
171
|
+
|
154
172
|
def reshape(arr, dims)
|
155
173
|
return arr.first if dims.empty?
|
156
174
|
arr = arr.flatten
|
@@ -176,12 +194,12 @@ module TensorFlow
|
|
176
194
|
start_offset_size = data.size * 8
|
177
195
|
offsets = [0]
|
178
196
|
data.each do |str|
|
179
|
-
offsets << str.bytesize + 1
|
197
|
+
offsets << offsets.last + str.bytesize + 1
|
180
198
|
end
|
181
199
|
data_ptr = ::FFI::MemoryPointer.new(:char, start_offset_size + offsets.pop)
|
182
200
|
data_ptr.write_array_of_uint64(offsets)
|
183
|
-
data.zip(offsets) do |str,
|
184
|
-
(data_ptr + start_offset_size +
|
201
|
+
data.zip(offsets) do |str, offset|
|
202
|
+
(data_ptr + start_offset_size + offset).write_string(str)
|
185
203
|
end
|
186
204
|
data_ptr
|
187
205
|
end
|
data/lib/tensorflow/utils.rb
CHANGED
@@ -1,21 +1,140 @@
|
|
1
1
|
module TensorFlow
|
2
2
|
module Utils
|
3
|
-
|
4
|
-
|
5
|
-
|
3
|
+
class << self
|
4
|
+
def check_status(status)
|
5
|
+
if FFI.TF_GetCode(status) != 0
|
6
|
+
raise Error, FFI.TF_Message(status)
|
7
|
+
end
|
6
8
|
end
|
7
|
-
end
|
8
9
|
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
10
|
+
def default_context
|
11
|
+
@default_context ||= Context.new
|
12
|
+
end
|
13
|
+
|
14
|
+
def execute(op_name, inputs = [], **attrs)
|
15
|
+
context = default_context
|
16
|
+
status = FFI.TF_NewStatus # TODO reuse status between ops?
|
17
|
+
op = FFI.TFE_NewOp(context, op_name, status)
|
18
|
+
check_status status
|
19
|
+
|
20
|
+
attrs.each do |attr_name, attr_value|
|
21
|
+
next if attr_value.nil?
|
22
|
+
|
23
|
+
attr_name = attr_name.to_s
|
24
|
+
|
25
|
+
is_list = ::FFI::MemoryPointer.new(:int)
|
26
|
+
type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status)
|
27
|
+
check_status status
|
28
|
+
|
29
|
+
case FFI::AttrType[type]
|
30
|
+
when :string
|
31
|
+
FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize)
|
32
|
+
# when :int
|
33
|
+
# when :float
|
34
|
+
# when :bool
|
35
|
+
when :type
|
36
|
+
FFI.TFE_OpSetAttrType(op, attr_name, attr_value)
|
37
|
+
when :shape
|
38
|
+
# TODO set value properly
|
39
|
+
FFI.TFE_OpSetAttrShape(op, attr_name, nil, 0, status)
|
40
|
+
check_status status
|
41
|
+
# when :tensor
|
42
|
+
# when :placeholder
|
43
|
+
# when :func
|
44
|
+
else
|
45
|
+
raise "Unknown type: #{FFI::AttrType[type]}"
|
46
|
+
end
|
47
|
+
end
|
48
|
+
|
49
|
+
inputs.each do |input|
|
50
|
+
input = TensorFlow.convert_to_tensor(input) unless input.respond_to?(:to_ptr)
|
51
|
+
FFI.TFE_OpAddInput(op, input, status)
|
52
|
+
check_status status
|
53
|
+
end
|
54
|
+
|
55
|
+
retvals = ::FFI::MemoryPointer.new(:pointer)
|
56
|
+
num_retvals = ::FFI::MemoryPointer.new(:int)
|
57
|
+
num_retvals.write_int(retvals.size)
|
58
|
+
FFI.TFE_Execute(op, retvals, num_retvals, status)
|
59
|
+
check_status status
|
60
|
+
|
61
|
+
if num_retvals.read_int > 0
|
62
|
+
handle = retvals.read_pointer
|
63
|
+
type = FFI.TFE_TensorHandleDataType(handle)
|
64
|
+
|
65
|
+
case FFI::DataType[type]
|
66
|
+
when :resource
|
67
|
+
handle
|
68
|
+
else
|
69
|
+
Tensor.new(pointer: handle)
|
70
|
+
end
|
71
|
+
end
|
72
|
+
ensure
|
73
|
+
FFI.TF_DeleteStatus(status) if status
|
74
|
+
FFI.TFE_DeleteOp(op) if op
|
75
|
+
end
|
76
|
+
|
77
|
+
def infer_type(value)
|
78
|
+
if value.all? { |v| v.is_a?(String) }
|
79
|
+
:string
|
80
|
+
elsif value.all? { |v| v == true || v == false }
|
81
|
+
:bool
|
82
|
+
elsif value.all? { |v| v.is_a?(Integer) }
|
83
|
+
if value.all? { |v| v >= -2147483648 && v <= 2147483647 }
|
84
|
+
:int32
|
85
|
+
else
|
86
|
+
:int64
|
87
|
+
end
|
88
|
+
elsif value.all? { |v| v.is_a?(Complex) }
|
89
|
+
:complex128
|
90
|
+
elsif value.all? { |v| v.is_a?(Numeric) }
|
91
|
+
:float
|
92
|
+
else
|
93
|
+
raise Error, "Unable to infer data type"
|
94
|
+
end
|
95
|
+
end
|
96
|
+
|
97
|
+
def load_dataset(path, url)
|
98
|
+
# TODO handle this better
|
99
|
+
raise "No HOME" unless ENV["HOME"]
|
100
|
+
datasets_dir = "#{ENV["HOME"]}/.keras/datasets"
|
101
|
+
FileUtils.mkdir_p(datasets_dir)
|
102
|
+
|
103
|
+
path = "#{datasets_dir}/#{path}"
|
104
|
+
Utils.download_file(url, path) unless File.exist?(path)
|
105
|
+
Npy.load_npz(path)
|
106
|
+
end
|
107
|
+
|
108
|
+
def download_file(url, dest)
|
109
|
+
uri = URI(url)
|
110
|
+
|
111
|
+
temp_dir ||= File.dirname(Tempfile.new("tensorflow"))
|
112
|
+
temp_path = "#{temp_dir}/#{Time.now.to_f}" # TODO better name
|
113
|
+
|
114
|
+
# Net::HTTP automatically adds Accept-Encoding for compression
|
115
|
+
# of response bodies and automatically decompresses gzip
|
116
|
+
# and deflate responses unless a Range header was sent.
|
117
|
+
# https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
|
118
|
+
Net::HTTP.start(uri.host, uri.port, use_ssl: true) do |http|
|
119
|
+
request = Net::HTTP::Get.new(uri)
|
120
|
+
|
121
|
+
print("Downloading dataset")
|
122
|
+
i = 0
|
123
|
+
File.open(temp_path, "wb") do |f|
|
124
|
+
http.request(request) do |response|
|
125
|
+
response.read_body do |chunk|
|
126
|
+
f.write(chunk)
|
127
|
+
|
128
|
+
# print progress
|
129
|
+
putc "." if i % 50 == 0
|
130
|
+
i += 1
|
131
|
+
end
|
132
|
+
end
|
133
|
+
puts # newline
|
134
|
+
end
|
135
|
+
end
|
136
|
+
|
137
|
+
FileUtils.mv(temp_path, dest)
|
19
138
|
end
|
20
139
|
end
|
21
140
|
end
|
data/lib/tensorflow/variable.rb
CHANGED
@@ -1,31 +1,31 @@
|
|
1
1
|
module TensorFlow
|
2
2
|
class Variable
|
3
3
|
def initialize(initial_value, dtype: nil)
|
4
|
-
@dtype = dtype || Utils.infer_type(initial_value)
|
5
|
-
@pointer =
|
4
|
+
@dtype = dtype || Utils.infer_type(Array(initial_value).flatten)
|
5
|
+
@pointer = RawOps.var_handle_op(dtype: type_enum, shape: [], shared_name: Utils.default_context.shared_name)
|
6
6
|
assign(initial_value)
|
7
7
|
end
|
8
8
|
|
9
9
|
def assign(value)
|
10
10
|
value = TensorFlow.convert_to_tensor(value, dtype: @dtype)
|
11
|
-
|
11
|
+
RawOps.assign_variable_op(resource: @pointer, value: value)
|
12
12
|
self
|
13
13
|
end
|
14
14
|
|
15
15
|
def assign_add(value)
|
16
16
|
value = TensorFlow.convert_to_tensor(value, dtype: @dtype)
|
17
|
-
|
17
|
+
RawOps.assign_add_variable_op(resource: @pointer, value: value)
|
18
18
|
self
|
19
19
|
end
|
20
20
|
|
21
21
|
def assign_sub(value)
|
22
22
|
value = TensorFlow.convert_to_tensor(value, dtype: @dtype)
|
23
|
-
|
23
|
+
RawOps.assign_sub_variable_op(resource: @pointer, value: value)
|
24
24
|
self
|
25
25
|
end
|
26
26
|
|
27
27
|
def read_value
|
28
|
-
|
28
|
+
RawOps.read_variable_op(resource: @pointer, dtype: type_enum)
|
29
29
|
end
|
30
30
|
|
31
31
|
def +(other)
|
data/lib/tensorflow/version.rb
CHANGED
data/lib/tensorflow.rb
CHANGED
@@ -1,13 +1,30 @@
|
|
1
1
|
# dependencies
|
2
2
|
require "ffi"
|
3
|
+
require "npy"
|
4
|
+
|
5
|
+
# stdlib
|
6
|
+
require "fileutils"
|
7
|
+
require "forwardable"
|
8
|
+
require "net/http"
|
9
|
+
require "tempfile"
|
3
10
|
|
4
11
|
# modules
|
5
12
|
require "tensorflow/utils"
|
6
13
|
require "tensorflow/context"
|
14
|
+
require "tensorflow/math"
|
15
|
+
require "tensorflow/ops"
|
16
|
+
require "tensorflow/raw_ops"
|
7
17
|
require "tensorflow/tensor"
|
8
18
|
require "tensorflow/variable"
|
9
19
|
require "tensorflow/version"
|
10
20
|
|
21
|
+
# keras
|
22
|
+
require "tensorflow/keras/datasets/mnist"
|
23
|
+
require "tensorflow/keras/layers/dense"
|
24
|
+
require "tensorflow/keras/layers/dropout"
|
25
|
+
require "tensorflow/keras/layers/flatten"
|
26
|
+
require "tensorflow/keras/models/sequential"
|
27
|
+
|
11
28
|
module TensorFlow
|
12
29
|
class Error < StandardError; end
|
13
30
|
|
@@ -20,8 +37,12 @@ module TensorFlow
|
|
20
37
|
autoload :FFI, "tensorflow/ffi"
|
21
38
|
|
22
39
|
class << self
|
40
|
+
include Ops
|
23
41
|
include Utils
|
24
42
|
|
43
|
+
extend Forwardable
|
44
|
+
def_delegators Math, :abs, :acos, :acosh, :add, :add_n, :argmax, :argmin, :asin, :asinh, :atan, :atan2, :atanh, :cos, :cosh, :cumsum, :divide, :equal, :exp, :floor, :greater, :greater_equal, :less, :less_equal, :logical_and, :logical_not, :logical_or, :maximum, :minimum, :multiply, :negative, :not_equal, :pow, :reduce_all, :reduce_any, :reduce_logsumexp, :reduce_max, :reduce_mean, :reduce_min, :reduce_prod, :reduce_sum, :round, :scalar_mul, :sigmoid, :sign, :sin, :sinh, :sqrt, :square, :subtract, :tan, :tanh, :truediv
|
45
|
+
|
25
46
|
def library_version
|
26
47
|
FFI.TF_Version
|
27
48
|
end
|
@@ -34,153 +55,6 @@ module TensorFlow
|
|
34
55
|
value = Tensor.new(value, dtype: dtype) unless value.is_a?(Tensor)
|
35
56
|
value
|
36
57
|
end
|
37
|
-
|
38
|
-
def add(x, y)
|
39
|
-
execute("Add", [x, y])
|
40
|
-
end
|
41
|
-
|
42
|
-
def subtract(x, y)
|
43
|
-
execute("Sub", [x, y])
|
44
|
-
end
|
45
|
-
|
46
|
-
def multiply(x, y)
|
47
|
-
execute("Mul", [x, y])
|
48
|
-
end
|
49
|
-
|
50
|
-
def divide(x, y)
|
51
|
-
execute("Div", [x, y])
|
52
|
-
end
|
53
|
-
|
54
|
-
def abs(x)
|
55
|
-
execute("Abs", [x])
|
56
|
-
end
|
57
|
-
|
58
|
-
def sqrt(x)
|
59
|
-
execute("Sqrt", [x])
|
60
|
-
end
|
61
|
-
|
62
|
-
def matmul(x, y)
|
63
|
-
execute("MatMul", [x, y])
|
64
|
-
end
|
65
|
-
|
66
|
-
def floormod(x, y)
|
67
|
-
execute("Mod", [x, y])
|
68
|
-
end
|
69
|
-
|
70
|
-
def range(start, limit, delta)
|
71
|
-
execute("Range", [start, limit, delta])
|
72
|
-
end
|
73
|
-
|
74
|
-
def transpose(x, perm: [1, 0])
|
75
|
-
execute("Transpose", [x, perm])
|
76
|
-
end
|
77
|
-
|
78
|
-
def equal(x, y)
|
79
|
-
execute("Equal", [x, y])
|
80
|
-
end
|
81
|
-
|
82
|
-
def zeros_like(x)
|
83
|
-
execute("ZerosLike", [x])
|
84
|
-
end
|
85
|
-
|
86
|
-
def fill(dims, value)
|
87
|
-
execute("Fill", [dims, value])
|
88
|
-
end
|
89
|
-
|
90
|
-
def zeros(dims)
|
91
|
-
fill(dims, 0)
|
92
|
-
end
|
93
|
-
|
94
|
-
def ones(dims)
|
95
|
-
fill(dims, 1)
|
96
|
-
end
|
97
|
-
|
98
|
-
def assign_add_variable_op(resource, value)
|
99
|
-
execute("AssignAddVariableOp", [resource, value])
|
100
|
-
end
|
101
|
-
|
102
|
-
def assign_sub_variable_op(resource, value)
|
103
|
-
execute("AssignSubVariableOp", [resource, value])
|
104
|
-
end
|
105
|
-
|
106
|
-
def assign_variable_op(resource, value)
|
107
|
-
execute("AssignVariableOp", [resource, value])
|
108
|
-
end
|
109
|
-
|
110
|
-
def read_variable_op(resource, dtype)
|
111
|
-
execute("ReadVariableOp", [resource], dtype: dtype)
|
112
|
-
end
|
113
|
-
|
114
|
-
def var_handle_op(dtype, shape, container: "", shared_name: "")
|
115
|
-
execute("VarHandleOp", [], container: container, shared_name: shared_name, dtype: dtype, shape: shape)
|
116
|
-
end
|
117
|
-
|
118
|
-
def var_is_initialized_op(resource)
|
119
|
-
execute("VarIsInitializedOp", [resource])
|
120
|
-
end
|
121
|
-
|
122
|
-
private
|
123
|
-
|
124
|
-
def default_context
|
125
|
-
@default_context ||= Context.new
|
126
|
-
end
|
127
|
-
|
128
|
-
def execute(op_name, inputs = [], **attrs)
|
129
|
-
context = default_context
|
130
|
-
status = FFI.TF_NewStatus
|
131
|
-
op = FFI.TFE_NewOp(context, op_name, status)
|
132
|
-
check_status status
|
133
|
-
# TODO clean up status and op
|
134
|
-
|
135
|
-
attrs.each do |attr_name, attr_value|
|
136
|
-
attr_name = attr_name.to_s
|
137
|
-
|
138
|
-
is_list = ::FFI::MemoryPointer.new(:int)
|
139
|
-
type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status)
|
140
|
-
check_status status
|
141
|
-
|
142
|
-
case FFI::AttrType[type]
|
143
|
-
when :type
|
144
|
-
FFI.TFE_OpSetAttrType(op, attr_name, attr_value)
|
145
|
-
when :shape
|
146
|
-
# TODO set value properly
|
147
|
-
FFI.TFE_OpSetAttrShape(op, attr_name, attr_value, 0, status)
|
148
|
-
check_status status
|
149
|
-
when :string
|
150
|
-
FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize)
|
151
|
-
else
|
152
|
-
raise "Unknown type: #{FFI::AttrType[type]}"
|
153
|
-
end
|
154
|
-
end
|
155
|
-
|
156
|
-
inputs.each do |input|
|
157
|
-
input = convert_to_tensor(input) unless input.respond_to?(:to_ptr)
|
158
|
-
FFI.TFE_OpAddInput(op, input, status)
|
159
|
-
check_status status
|
160
|
-
end
|
161
|
-
|
162
|
-
retvals = ::FFI::MemoryPointer.new(:pointer)
|
163
|
-
num_retvals = ::FFI::MemoryPointer.new(:int)
|
164
|
-
num_retvals.write_int(retvals.size)
|
165
|
-
FFI.TFE_Execute(op, retvals, num_retvals, status)
|
166
|
-
check_status status
|
167
|
-
|
168
|
-
if num_retvals.read_int > 0
|
169
|
-
handle = retvals.read_pointer
|
170
|
-
type = FFI.TFE_TensorHandleDataType(handle)
|
171
|
-
|
172
|
-
case FFI::DataType[type]
|
173
|
-
when :resource
|
174
|
-
handle
|
175
|
-
else
|
176
|
-
Tensor.new(pointer: handle)
|
177
|
-
end
|
178
|
-
end
|
179
|
-
end
|
180
|
-
|
181
|
-
def check_status(status)
|
182
|
-
Utils.check_status(status)
|
183
|
-
end
|
184
58
|
end
|
185
59
|
end
|
186
60
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: tensorflow
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-09-
|
11
|
+
date: 2019-09-20 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: ffi
|
@@ -24,6 +24,20 @@ dependencies:
|
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
26
|
version: '0'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: npy
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :runtime
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
27
41
|
- !ruby/object:Gem::Dependency
|
28
42
|
name: bundler
|
29
43
|
requirement: !ruby/object:Gem::Requirement
|
@@ -66,6 +80,34 @@ dependencies:
|
|
66
80
|
- - ">="
|
67
81
|
- !ruby/object:Gem::Version
|
68
82
|
version: '5'
|
83
|
+
- !ruby/object:Gem::Dependency
|
84
|
+
name: google-protobuf
|
85
|
+
requirement: !ruby/object:Gem::Requirement
|
86
|
+
requirements:
|
87
|
+
- - ">="
|
88
|
+
- !ruby/object:Gem::Version
|
89
|
+
version: '0'
|
90
|
+
type: :development
|
91
|
+
prerelease: false
|
92
|
+
version_requirements: !ruby/object:Gem::Requirement
|
93
|
+
requirements:
|
94
|
+
- - ">="
|
95
|
+
- !ruby/object:Gem::Version
|
96
|
+
version: '0'
|
97
|
+
- !ruby/object:Gem::Dependency
|
98
|
+
name: nokogiri
|
99
|
+
requirement: !ruby/object:Gem::Requirement
|
100
|
+
requirements:
|
101
|
+
- - ">="
|
102
|
+
- !ruby/object:Gem::Version
|
103
|
+
version: '0'
|
104
|
+
type: :development
|
105
|
+
prerelease: false
|
106
|
+
version_requirements: !ruby/object:Gem::Requirement
|
107
|
+
requirements:
|
108
|
+
- - ">="
|
109
|
+
- !ruby/object:Gem::Version
|
110
|
+
version: '0'
|
69
111
|
description:
|
70
112
|
email: andrew@chartkick.com
|
71
113
|
executables: []
|
@@ -78,6 +120,14 @@ files:
|
|
78
120
|
- lib/tensorflow.rb
|
79
121
|
- lib/tensorflow/context.rb
|
80
122
|
- lib/tensorflow/ffi.rb
|
123
|
+
- lib/tensorflow/keras/datasets/mnist.rb
|
124
|
+
- lib/tensorflow/keras/layers/dense.rb
|
125
|
+
- lib/tensorflow/keras/layers/dropout.rb
|
126
|
+
- lib/tensorflow/keras/layers/flatten.rb
|
127
|
+
- lib/tensorflow/keras/models/sequential.rb
|
128
|
+
- lib/tensorflow/math.rb
|
129
|
+
- lib/tensorflow/ops.rb
|
130
|
+
- lib/tensorflow/raw_ops.rb
|
81
131
|
- lib/tensorflow/tensor.rb
|
82
132
|
- lib/tensorflow/utils.rb
|
83
133
|
- lib/tensorflow/variable.rb
|