tensorflow-ruby 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
module Tensorflow
  module Eager
    # Builds an eager-mode TFE_Op: resolves the op definition, converts and adds
    # the inputs, then applies the attributes through the TensorFlow C API.
    class Operation
      attr_reader :context, :guessed_dtype, :op_def, :status

      # @param context [Eager::Context] context the op is created in
      # @param op_type [String, Graph::Function] op name looked up with
      #   Tensorflow.op_def, or a function whose signature is used as the op def
      # @param inputs [Object, Array, nil] op inputs (normalized with Array())
      # @param attrs [Hash, nil] attribute values keyed by attribute name. A :name
      #   entry is ignored because it is not an op attribute. The caller's hash
      #   is not modified.
      # @raise [Error::InvalidArgumentError] when no op definition is found
      def initialize(context, op_type, inputs, attrs)
        @context = context
        @op_def = case op_type
                  when Graph::Function
                    op_type.function_def.signature
                  else
                    Tensorflow.op_def(op_type)
                  end
        raise(Error::InvalidArgumentError, "Invalid op type: #{op_type}") unless @op_def

        @status = Status.new
        @pointer = FFI.TFE_NewOp(context, op_def.name, status)
        # Surface a failed TFE_NewOp now rather than handing its result to the
        # input/attribute calls below.
        status.check

        # Filter out :name on a copy instead of deleting it from the caller's hash.
        attrs = (attrs || {}).reject { |key, _value| key == :name }

        inputs = Array(inputs)
        @guessed_dtype = figure_dtype(attrs, inputs)

        setup_inputs(inputs, attrs)
        setup_attrs(attrs)
      end

      # Raw TFE_Op pointer; lets this object be passed directly to FFI calls.
      def to_ptr
        @pointer
      end

      # Calls TFE_OpGetAttrType for the 'dtype' attribute and returns its result.
      #
      # NOTE(review): setup_attrs treats TFE_OpGetAttrType's return value as an
      # attribute *kind* (:int, :type, :shape, ...), so this likely returns the
      # kind of the 'dtype' attr rather than the dtype itself. The status is not
      # checked here. Confirm the intended behavior before relying on it.
      def dtype
        list_ptr = ::FFI::MemoryPointer.new(:int)
        FFI.TFE_OpGetAttrType(self, 'dtype', list_ptr, status)
      end

      # Best-effort guess of the op's dtype.
      #
      # Uses the value given in +attrs+ for the first attribute whose definition
      # has type 'type'. When that is missing or falsy, falls back to the first
      # input that is an Operation or a Variable.
      #
      # @return [Object, nil] the guessed dtype, or the (falsy) attr value
      def figure_dtype(attrs, inputs)
        type_attr = op_def.attr.find { |candidate| candidate.type == 'type' }
        result = type_attr ? attrs[type_attr.name.to_sym] : nil
        return result if result

        inputs.each do |input|
          case input
          when Operation
            # NOTE(review): this class defines no #output_types; confirm which
            # Operation class is expected to reach this branch.
            return input.output_types.first
          when Variable
            return input.dtype
          end
        end
        result
      end

      # Applies every attribute to the op. For each one, it asks the C API for
      # the attribute's kind and whether it is a list, then dispatches to
      # add_list_attr or add_scalar_attr.
      #
      # Only nil values are skipped. +false+ is forwarded so that an explicit
      # false can override an attribute's default.
      def setup_attrs(attrs)
        attrs.each do |attr_name, attr_value|
          next if attr_value.nil?

          attr_name = attr_name.to_s
          is_list_ptr = ::FFI::MemoryPointer.new(:int)
          type = FFI.TFE_OpGetAttrType(self, attr_name, is_list_ptr, status)
          status.check

          if Boolean(is_list_ptr.read_int)
            add_list_attr(type, attr_name, attr_value)
          else
            add_scalar_attr(type, attr_name, attr_value)
          end
        end
      end

      # Sets a list-valued attribute. Handles int, float, shape and type lists.
      #
      # @param type [Symbol] attribute kind reported by TFE_OpGetAttrType
      # @param attr_name [String]
      # @param attr_value [Array]
      # @raise [RuntimeError] for any other list kind
      def add_list_attr(type, attr_name, attr_value)
        num_values = attr_value.size

        case type
        when :int
          values = ::FFI::MemoryPointer.new(:int64, num_values)
          values.write_array_of_int64(attr_value)
          FFI.TFE_OpSetAttrIntList(self, attr_name, values, num_values)
        when :float
          values = ::FFI::MemoryPointer.new(:float, num_values)
          values.write_array_of_float(attr_value)
          FFI.TFE_OpSetAttrFloatList(self, attr_name, values, num_values)
        when :shape
          # One int64 buffer per shape. Holding them in a Ruby array keeps the
          # MemoryPointer objects (and their native memory) alive until the C
          # call returns; dims_pointer itself only stores raw addresses, which
          # do not protect them from garbage collection.
          dim_pointers = attr_value.map do |shape|
            dim_pointer = ::FFI::MemoryPointer.new(:int64, shape.length)
            dim_pointer.write_array_of_int64(shape)
            dim_pointer
          end
          dims_pointer = ::FFI::MemoryPointer.new(:pointer, num_values)
          dims_pointer.write_array_of_pointer(dim_pointers)
          num_dims_pointer = ::FFI::MemoryPointer.new(:int32, num_values)
          num_dims_pointer.write_array_of_int32(attr_value.map(&:length))
          FFI.TFE_OpSetAttrShapeList(self, attr_name, dims_pointer, num_dims_pointer, num_values, status)
          status.check
        when :type
          values = ::FFI::MemoryPointer.new(:int, num_values)
          # Symbols are translated through FFI::DataType; other values pass through.
          types = attr_value.map { |v| v.is_a?(Symbol) ? FFI::DataType[v] : v }
          values.write_array_of_int(types)
          FFI.TFE_OpSetAttrTypeList(self, attr_name, values, num_values)
        else
          raise "Unknown list type: #{type}"
        end
      end

      # Sets a single-valued attribute.
      #
      # Unsupported kinds, and :func values that are neither a Graph::Function
      # nor a String, are recorded on the status. The trailing status.check
      # then surfaces them together with any C API error.
      def add_scalar_attr(type, attr_name, attr_value)
        case type
        when :string
          FFI.TFE_OpSetAttrString(self, attr_name, attr_value, attr_value.bytesize)
        when :int
          FFI.TFE_OpSetAttrInt(self, attr_name, attr_value)
        when :float
          FFI.TFE_OpSetAttrFloat(self, attr_name, attr_value)
        when :bool
          FFI.TFE_OpSetAttrBool(self, attr_name, attr_value ? 1 : 0)
        when :type
          attr_value = FFI::DataType[attr_value] if attr_value.is_a?(Symbol)
          FFI.TFE_OpSetAttrType(self, attr_name, attr_value)
        when :shape
          ptr = ::FFI::MemoryPointer.new(:int64, attr_value.size)
          ptr.write_array_of_int64(attr_value)
          FFI.TFE_OpSetAttrShape(self, attr_name, ptr, attr_value.size, status)
        when :tensor
          handle = TensorHandle.from_value(context, attr_value)
          FFI.TFE_OpSetAttrTensor(self, attr_name, handle.tensor, status)
        when :func
          # The length argument is a byte count, as with :string above, so use
          # bytesize (not character length) for non-ASCII names.
          case attr_value
          when Graph::Function
            FFI.TFE_OpSetAttrFunctionName(self, attr_name, attr_value.name, attr_value.name.bytesize)
          when String
            FFI.TFE_OpSetAttrFunctionName(self, attr_name, attr_value, attr_value.bytesize)
          else
            status.set(:tf_invalid_argument, "Invalid function attribute for attribute: #{attr_name}")
          end
        else
          status.set(:tf_unknown, "Unsupported attribute type: #{type}")
        end
        status.check
      end

      # Adds each input in order, pairing it with the op def's input_arg at the
      # same index.
      def setup_inputs(inputs, attrs)
        inputs.each_with_index do |input, index|
          setup_input(index, input, attrs)
        end
      end

      # Converts one input into something the C API accepts. A Variable yields
      # its resource handle for DT_RESOURCE args and its value handle otherwise.
      # Anything else goes through TensorHandle.from_value with +dtype+.
      def check_input(arg_def, input, dtype)
        case input
        when Variable
          arg_def.type == :DT_RESOURCE ? input.handle : input.value_handle
        else
          TensorHandle.from_value(context, input, dtype: dtype)
        end
      end

      # Converts and adds the input at +index+. A missing value, or an index past
      # the op def's declared inputs, is reported through the status.
      def setup_input(index, value, attrs)
        if value.nil?
          status.set(:tf_invalid_argument, "Argument is unset. Index: #{index}")
          status.check
        end

        arg_def = op_def.input_arg[index]
        if arg_def.nil?
          status.set(:tf_invalid_argument, "Too many arguments. Index: #{index}")
          status.check
        end

        dtype = attrs[arg_def.type_attr.to_sym]

        # Value can be an operation with multiple outputs. For example, calling
        # PACK with an input operation of SPLIT.
        list_arg = !arg_def.number_attr.empty? || !arg_def.type_list_attr.empty?
        checked_value = if list_arg && value.is_a?(Array)
                          value.map { |sub_value| check_input(arg_def, sub_value, dtype) }
                        else
                          check_input(arg_def, value, dtype)
                        end

        if !arg_def.type_list_attr.empty?
          # Heterogeneous list input
          add_input_list(checked_value)
        elsif !arg_def.number_attr.empty? && !arg_def.type_attr.empty?
          # Homogeneous list input
          add_input_list(checked_value)
        elsif !arg_def.number_attr.empty?
          # A list that has to be added one input at a time
          checked_value.each { |sub_checked_value| add_input(sub_checked_value) }
        else
          # A single input
          add_input(checked_value)
        end
      end

      # Adds one input. An array with more than one element (e.g. the outputs of
      # a multi-output op) is first combined with Tensorflow.pack.
      def add_input(value)
        if value.is_a?(Array) && value.length > 1
          packed = Tensorflow.pack(value)
          FFI.TFE_OpAddInput(self, packed, status)
        else
          FFI.TFE_OpAddInput(self, value, status)
        end
        status.check
      end

      # Adds a list input via TFE_OpAddInputList, passing an array of pointers
      # built from +values+.
      def add_input_list(values)
        input_ptr = ::FFI::MemoryPointer.new(:pointer, values.length)
        input_ptr.write_array_of_pointer(values)
        FFI.TFE_OpAddInputList(self, input_ptr, values.length, status)
        status.check
      end
    end
  end
end
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
module Tensorflow
  module Eager
    # Wraps a native TFE_TensorHandle pointer. The native handle is deleted by
    # a finalizer when this Ruby object is garbage collected.
    class TensorHandle
      include TensorMixin
      include Operators

      attr_reader :context

      # Returns the finalizer proc for +pointer+. It is built in a class method
      # so the proc does not capture the TensorHandle instance; a finalizer that
      # references its own object prevents that object from being collected.
      def self.finalize(pointer)
        proc do
          FFI.TFE_DeleteTensorHandle(pointer)
        end
      end

      # Coerces +value+ into a TensorHandle-like object.
      #
      # - A TensorHandle is returned as-is.
      # - A Data::Dataset yields its variant tensor.
      # - A Tensor is wrapped directly.
      # - A Variable yields its value handle.
      # - Anything else is first converted with Tensor.new(value, dtype: dtype).
      def self.from_value(context, value, dtype: nil)
        case value
        when TensorHandle
          value
        when Data::Dataset
          value.variant_tensor
        when Tensor
          TensorHandle.new(context, value)
        when Variable
          value.value_handle
        else
          TensorHandle.new(context, Tensor.new(value, dtype: dtype))
        end
      end

      # @param context the eager context this handle belongs to
      # @param value [::FFI::Pointer, Tensor] an existing TFE_TensorHandle
      #   pointer, or a Tensor to create a handle for. The registered finalizer
      #   deletes the resulting pointer in both cases.
      # @raise [Error::InvalidArgumentError] for any other value
      def initialize(context, value)
        @context = context
        case value
        when ::FFI::Pointer
          @pointer = value
        when Tensor
          Status.check do |status|
            @pointer = FFI.TFE_NewTensorHandle(value, status)
          end
          # Keep the tensor referenced so it is not freed while the handle lives.
          @tensor = value
        else
          raise(Error::InvalidArgumentError, "Invalid value passed to tensor_handle: #{value}")
        end

        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      # Raw TFE_TensorHandle pointer; lets this object be passed to FFI calls.
      def to_ptr
        @pointer
      end

      # Resolves the handle into a Tensor.
      #
      # The status is checked before the returned pointer is wrapped, so a
      # failed resolve surfaces as an error instead of building a Tensor around
      # the pointer returned by a failed call.
      def tensor
        pointer = Status.check do |status|
          FFI.TFE_TensorHandleResolve(self, status)
        end
        Tensor.from_pointer(pointer)
      end

      # Data type reported by TFE_TensorHandleDataType.
      def dtype
        FFI.TFE_TensorHandleDataType(self)
      end

      # Number of elements reported by TFE_TensorHandleNumElements.
      def element_count
        Status.check do |status|
          FFI.TFE_TensorHandleNumElements(self, status)
        end
      end

      # Resolves the handle and returns the Tensor's value.
      def value
        tensor.value
      end

      private

      # Rank reported by TFE_TensorHandleNumDims.
      def num_dims
        Status.check do |status|
          FFI.TFE_TensorHandleNumDims(self, status)
        end
      end

      # Size of dimension +index+, reported by TFE_TensorHandleDim.
      def dim(index)
        Status.check do |status|
          FFI.TFE_TensorHandleDim(self, index, status)
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
module Tensorflow
  # Exception classes for the TensorFlow bindings. Except for OpError, the
  # class names mirror TensorFlow's error codes. Every class here is a direct
  # subclass of StandardError.
  module Error
    AbortedError = Class.new(StandardError)
    AlreadyExistsError = Class.new(StandardError)
    CancelledError = Class.new(StandardError)
    DataLossError = Class.new(StandardError)
    DeadlineExceededError = Class.new(StandardError)
    FailedPreconditionError = Class.new(StandardError)
    InternalError = Class.new(StandardError)
    InvalidArgumentError = Class.new(StandardError)
    NotFoundError = Class.new(StandardError)
    OpError = Class.new(StandardError)
    OutOfRangeError = Class.new(StandardError)
    PermissionDeniedError = Class.new(StandardError)
    ResourceExhaustedError = Class.new(StandardError)
    UnauthenticatedError = Class.new(StandardError)
    UnavailableError = Class.new(StandardError)
    UnimplementedError = Class.new(StandardError)
    UnknownError = Class.new(StandardError)
  end
end
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
module Tensorflow
  # Per-thread stack of execution contexts (eager contexts or graphs).
  #
  # The class-level push, pop, current, eager? and graph? forward to the
  # instance owned by the calling thread.
  class ExecutionContext
    class << self
      extend Forwardable

      def_delegators :context, :push, :pop, :current, :eager?, :graph?

      # The calling thread's ExecutionContext, created on first use.
      def context
        Thread.current[:execution_context] ||= new
      end
    end

    def initialize
      @stack = []
    end

    # Makes +value+ the innermost context. Returns the stack.
    def push(value)
      @stack << value
    end

    # Removes and returns the innermost context (nil when empty).
    def pop
      @stack.pop
    end

    # Infers a context from the (possibly nested) +inputs+: the graph of the
    # first Graph::Operation, or the context of the first Eager::TensorHandle,
    # whichever comes first. Returns nil when no input identifies one.
    def figure_from_inputs(inputs = [])
      inputs.flatten.each do |input|
        case input
        when Graph::Operation then return input.graph
        when Eager::TensorHandle then return input.context
        end
      end
      nil
    end

    # Innermost pushed context, or nil when nothing has been pushed.
    def figure_from_context
      @stack.last
    end

    # Default graph in graph mode, otherwise the default eager context.
    def figure_from_execution_mode
      graph_mode = ::Tensorflow.execution_mode == Tensorflow::GRAPH_MODE
      graph_mode ? Graph::Graph.default : Eager::Context.default
    end

    # Resolution order: explicitly pushed context, then one inferred from
    # +inputs+, then the global execution mode's default.
    def current(inputs = [])
      figure_from_context || figure_from_inputs(inputs) || figure_from_execution_mode
    end

    # True when the resolved context is an Eager::Context.
    def eager?(inputs = [])
      current(inputs).is_a?(Eager::Context)
    end

    # True when the resolved context is a Graph::Graph.
    def graph?(inputs = [])
      current(inputs).is_a?(Graph::Graph)
    end
  end
end
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
module Tensorflow
  class OpDef
    class ArgDef
      # DataType enum symbol => Ruby dtype symbol.
      DTYPE_FOR_TYPE = {
        DT_FLOAT: :float,
        DT_DOUBLE: :double,
        DT_INT32: :int32,
        DT_UINT8: :uint8,
        DT_INT16: :int16,
        DT_INT8: :int8,
        DT_STRING: :string,
        DT_COMPLEX64: :complex64,
        DT_INT64: :int64,
        DT_BOOL: :bool,
        DT_QINT8: :qint8,
        DT_QUINT8: :quint8,
        DT_QINT32: :qint32,
        DT_BFLOAT16: :bfloat16,
        DT_QINT16: :qint16,
        DT_QUINT16: :quint16,
        DT_UINT16: :uint16,
        DT_COMPLEX128: :complex128,
        DT_HALF: :half,
        DT_RESOURCE: :resource,
        DT_VARIANT: :variant,
        DT_UINT32: :uint32,
        DT_UINT64: :uint64
      }.freeze
      private_constant :DTYPE_FOR_TYPE

      # Ruby dtype symbol for this argument's +type+ (e.g. :DT_FLOAT -> :float).
      # Returns nil for :DT_INVALID and for any type not in the table above.
      def dtype
        DTYPE_FOR_TYPE[self.type]
      end
    end
  end
end
|