tensorflow-ruby 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
@@ -0,0 +1,153 @@
|
|
1
|
+
module Tensorflow
  module Graph
    # Reads one attribute of a graph operation through the TensorFlow C API.
    #
    # An instance is built from the owning operation, the attribute name and a
    # metadata record that is indexed with :is_list, :type, :list_size and
    # :total_size.
    class OperationAttr
      attr_reader :metadata, :name, :operation

      # operation - the operation that owns the attribute (passed to the C API)
      # name      - the attribute name
      # metadata  - record describing the attribute (see class comment)
      def initialize(operation, name, metadata)
        @operation = operation
        @name = name
        @metadata = metadata
      end

      # True when the metadata marks the attribute as a list.
      def list?
        self.metadata[:is_list] > 0
      end

      # Returns the attribute value, picking the reader from metadata[:type]
      # and from whether the attribute is a list.
      #
      # Raises Error::UnimplementedError for an unrecognised type.
      #
      # NOTE(review): bool_list, int_list, float_list, func_list, string_list
      # and tensor_list are referenced here but are not defined in this class
      # as shown; list attributes of those types will fail unless the readers
      # are supplied elsewhere.
      def value
        case self.metadata[:type]
        when :bool
          self.list? ? self.bool_list : self.bool
        when :int
          self.list? ? self.int_list : self.int
        when :float
          self.list? ? self.float_list : self.float
        when :func
          self.list? ? self.func_list : self.func
        when :shape
          self.list? ? self.shape_list : self.shape
        when :string
          self.list? ? self.string_list : self.string
        when :tensor
          self.list? ? self.tensor_list : self.tensor
        when :type
          self.list? ? self.dtype_list : self.dtype
        else
          raise(Error::UnimplementedError, "Unsupported attribute. #{self.name} - #{self.metadata[:type]}")
        end
      end

      # Reads a boolean attribute.
      def bool
        pointer = ::FFI::MemoryPointer.new(:uchar)
        Status.check do |status|
          FFI.TF_OperationGetAttrBool(self.operation, self.name, pointer, status)
        end
        Boolean(pointer.read_uchar)
      end

      # Reads a data type attribute and maps the native value through FFI::DataType.
      def dtype
        pointer = ::FFI::MemoryPointer.new(FFI::DataType.native_type)
        Status.check do |status|
          FFI.TF_OperationGetAttrType(self.operation, self.name, pointer, status)
        end
        value = pointer.read(FFI::DataType.native_type)
        FFI::DataType[value]
      end

      # Reads a list of data types; the list length comes from metadata[:list_size].
      def dtype_list
        pointer = ::FFI::MemoryPointer.new(FFI::DataType.native_type, self.metadata[:list_size])
        Status.check do |status|
          FFI.TF_OperationGetAttrTypeList(self.operation, self.name, pointer, self.metadata[:list_size], status)
        end
        pointer.read_array_of_type(FFI::DataType.native_type, :read_uint32, self.metadata[:list_size]).map do |value|
          FFI::DataType[value]
        end
      end

      # Reads a float attribute.
      def float
        pointer = ::FFI::MemoryPointer.new(:float)
        Status.check do |status|
          FFI.TF_OperationGetAttrFloat(self.operation, self.name, pointer, status)
        end
        pointer.read_float
      end

      # Returns the function name stored in the attribute's decoded proto.
      def func
        self.proto.func.name
      end

      # Reads an integer attribute.
      def int
        pointer = ::FFI::MemoryPointer.new(:int64)
        Status.check do |status|
          FFI.TF_OperationGetAttrInt(self.operation, self.name, pointer, status)
        end
        # The buffer is 64 bits wide, so it must be read back as 64 bits;
        # a 32-bit read would truncate large values.
        pointer.read_int64
      end

      # Reads a shape attribute as an array of dimensions. Returns an empty
      # array when metadata[:total_size] is -1.
      def shape
        size = self.metadata[:total_size]
        if size == -1
          []
        else
          pointer = ::FFI::MemoryPointer.new(:int64, size)
          Status.check do |status|
            FFI.TF_OperationGetAttrShape(self.operation, self.name, pointer, size, status)
          end
          pointer.read_array_of_int64(size)
        end
      end

      # Reads a list of shapes. The C call fills one pointer and one dimension
      # count per shape, with the dimension values written into storage_ptr.
      def shape_list
        total_size = self.metadata[:total_size]
        storage_ptr = ::FFI::MemoryPointer.new(:int64, total_size)
        dims_pointer = ::FFI::MemoryPointer.new(:pointer, self.metadata[:list_size])
        num_dims_pointer = ::FFI::MemoryPointer.new(:int, self.metadata[:list_size])
        Status.check do |status|
          FFI.TF_OperationGetAttrShapeList(self.operation, self.name,
                                           dims_pointer, num_dims_pointer,
                                           self.metadata[:list_size],
                                           storage_ptr, total_size, status)
        end

        num_dims = num_dims_pointer.read_array_of_int(self.metadata[:list_size])
        num_dims.map.with_index do |dims, i|
          shape_pointer = dims_pointer[i].read_pointer
          shape_pointer.read_array_of_int64(dims)
        end
      end

      # Reads a string attribute of metadata[:total_size] bytes.
      def string
        size = self.metadata[:total_size]
        # Allocate bytes (not pointer-sized slots) and read back exactly
        # +size+ bytes so values containing NUL bytes are not cut short.
        pointer = ::FFI::MemoryPointer.new(:char, size)
        Status.check do |status|
          FFI.TF_OperationGetAttrString(self.operation, self.name, pointer, size, status)
        end
        pointer.read_string(size)
      end

      # Reads a tensor attribute and wraps the returned pointer in a Tensor.
      def tensor
        pointer = ::FFI::MemoryPointer.new(:pointer)
        Status.check do |status|
          FFI.TF_OperationGetAttrTensor(self.operation, self.name, pointer, status)
        end
        Tensor.from_pointer(pointer.read_pointer)
      end

      # Returns the attribute decoded as an AttrValue message.
      def proto
        buffer_ptr = FFI.TF_NewBuffer
        # Wrap the buffer before the fallible call so that the ensure clause
        # can always release it, including when the status check raises.
        buffer = FFI::Buffer.new(buffer_ptr)
        Status.check do |status|
          FFI.TF_OperationGetAttrValueProto(self.operation, self.name, buffer_ptr, status)
        end
        string = buffer[:data].read_string(buffer[:length])
        AttrValue.decode(string)
      ensure
        FFI.TF_DeleteBuffer(buffer) if buffer
      end

      # Renders the attribute as "name: value".
      def to_s
        "#{self.name}: #{self.value}"
      end
    end
  end
end
|
@@ -0,0 +1,255 @@
|
|
1
|
+
module Tensorflow
  module Graph
    # Builds a new operation in a graph: creates the native description,
    # attaches inputs, control inputs and attributes, and is finished by #save.
    class OperationDescription
      attr_reader :graph, :name, :op_def

      # graph   - graph the operation is added to
      # op_type - a Function, or a value passed to graph.op_def to look up the definition
      # inputs  - a single input or an array of inputs
      # attrs   - attribute hash; an optional :name entry names the operation
      #
      # Raises Error::InvalidArgumentError when no op definition is found.
      def initialize(graph, op_type, inputs, attrs)
        @graph = graph
        @op_def = case op_type
                  when Function
                    op_type.function_def.signature
                  else
                    self.graph.op_def(op_type)
                  end
        raise(Error::InvalidArgumentError, "Invalid op type: #{op_type}") unless @op_def

        # Work on a copy so removing :name does not alter the caller's hash.
        attrs = attrs.dup
        raw_name = attrs.delete(:name)&.to_s || self.op_def.name
        @name = self.graph.scoped_name(raw_name)
        @pointer = FFI.TF_NewOperation(graph, self.op_def.name, @name)

        inputs = Array(inputs)
        setup_inputs(inputs, attrs)
        setup_control_inputs(graph.control_inputs)
        setup_attrs(**attrs)
      end

      # Returns the value of the first 'type' attribute found in attrs; when
      # that is missing, the dtype of the first Operation or Variable input.
      def figure_dtype(attrs, inputs)
        type_attr_def = self.op_def.attr.detect do |candidate|
          candidate.type == 'type'
        end

        result = type_attr_def ? attrs[type_attr_def.name.to_sym] : nil
        unless result
          inputs.each do |input|
            case input
            when Operation
              return input.output_types.first
            when Variable
              return input.dtype
            end
          end
        end
        result
      end

      # The native operation description pointer.
      def to_ptr
        @pointer
      end

      # Finishes the description and wraps the result in an Operation.
      def save
        Status.check do |status|
          ptr = FFI.TF_FinishOperation(self, status)
          Operation.new(self.graph, ptr)
        end
      end

      # Sets the device for the operation being built.
      def device=(value)
        FFI.TF_SetDevice(self, value)
      end

      # Adds every control input in the given collection.
      def setup_control_inputs(control_inputs)
        control_inputs.each do |control_input|
          setup_control_input(control_input)
        end
      end

      # Adds one control input. Accepts an Operation, or a Variable (its
      # handle is used); anything else raises Error::InvalidArgumentError.
      def setup_control_input(control_input)
        control_input = case control_input
                        when Operation
                          control_input
                        when Variable
                          control_input.handle
                        else
                          raise(Error::InvalidArgumentError, "Invalid control input")
                        end

        FFI.TF_AddControlInput(self, control_input)
      end

      # Captures the inputs of +operation+ and regroups them to match its
      # input arguments.
      #
      # An argument that is a list (for example the single argument of a
      # TensorSlice dataset) contributes several entries to operation.inputs,
      # so those entries are gathered back into one array per argument.
      def capture_inputs(operation, attrs)
        # First capture the inputs
        inputs = operation.inputs.map do |input|
          self.capture(input.operation)
        end

        i = 0
        operation.op_def.input_arg.map do |input_arg|
          if !input_arg.number_attr.empty?
            input_len = attrs[input_arg.number_attr.to_sym]
            is_sequence = true
          elsif !input_arg.type_list_attr.empty?
            input_len = attrs[input_arg.type_list_attr.to_sym].length
            is_sequence = true
          else
            input_len = 1
            is_sequence = false
          end

          # Take exactly input_len entries; an inclusive range ending at
          # i + input_len would pull in the next argument's first input.
          group = is_sequence ? inputs[i, input_len] : inputs[i]
          i += input_len
          group
        end
      end

      # Recreates +operation+ (and, recursively, its inputs) in this graph
      # with the same type, name and attributes.
      #
      # Raises Error::InvalidArgumentError for stateful operations and placeholders.
      def capture(operation)
        # The check applies to the node being captured, which is the one
        # named in the error message.
        if operation.op_def.is_stateful
          raise(Error::InvalidArgumentError, "Cannot capture a stateful node (name: #{operation.name}, type: #{operation.op_type})")
        elsif operation.op_type == "Placeholder"
          raise(Error::InvalidArgumentError, "Cannot capture a placeholder by value (name: #{operation.name}, type: #{operation.op_type})")
        end

        attrs = operation.attributes.each_with_object({}) do |attr, hash|
          hash[attr.name.to_sym] = attr.value
        end
        attrs[:name] = operation.name

        captured_inputs = self.capture_inputs(operation, attrs)
        self.graph.create_operation(operation.op_type, captured_inputs, **attrs)
      end

      # Converts one input into something that can be attached:
      # an Operation from another graph is captured, a Variable is replaced
      # by its handle or value handle, and any other value becomes a constant.
      def check_input(arg_def, input, dtype)
        case input
        when Operation
          self.graph.equal?(input.graph) ? input : capture(input)
        when OperationOutput
          input
        when Variable
          arg_def.type == :DT_RESOURCE ? input.handle : input.value_handle
        else
          input_name = "#{self.name}/#{arg_def.name}"
          Tensorflow.constant(input, name: input_name, dtype: dtype)
        end
      end

      # Attaches each input in positional order.
      def setup_inputs(inputs, attrs)
        inputs.each_with_index do |input, index|
          self.setup_input(index, input, attrs)
        end
      end

      # Attaches the input at +index+, as a list or a single item depending
      # on the matching input argument definition.
      #
      # Raises Error::InvalidArgumentError when the op definition has no
      # input argument at that index.
      def setup_input(index, value, attrs)
        arg_def = self.op_def.input_arg[index]
        unless arg_def
          raise(Error::InvalidArgumentError, "Too many inputs for #{self.op_def.name}: no input argument at index #{index}")
        end
        dtype = attrs[arg_def.type_attr.to_sym]

        # number_attr marks a homogeneous list, type_list_attr a heterogeneous one.
        list_arg = !arg_def.number_attr.empty? || !arg_def.type_list_attr.empty?

        # Value can be an operation with multiple outputs. For example calling PACK with an input operation of SPLIT
        checked_value = if list_arg && value.is_a?(Array)
                          value.map do |sub_value|
                            self.check_input(arg_def, sub_value, dtype)
                          end
                        else
                          self.check_input(arg_def, value, dtype)
                        end

        if list_arg
          self.add_input_list(checked_value)
        else
          # This input is a single item
          self.add_input(checked_value)
        end
      end

      # Attaches a single input. An operation with several outputs is packed
      # first so that it fits into one input.
      def add_input(operation)
        if operation.is_a?(OperationOutput)
          FFI.TF_AddInput(self, operation)
        elsif operation.num_outputs > 1
          packed = Tensorflow.pack(operation, n: operation.num_outputs)
          FFI.TF_AddInput(self, packed.outputs.first)
        else
          FFI.TF_AddInput(self, operation.outputs.first)
        end
      end

      # Attaches a list input. +operations+ can be several operations *or*
      # one operation with multiple outputs (like SPLIT); items that are
      # already OperationOutputs are used as they are.
      def add_input_list(operations)
        outputs = Array(operations).map do |item|
          item.is_a?(OperationOutput) ? item : item.outputs
        end.flatten
        outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))
        FFI.TF_AddInputList(self, outputs_ptr, outputs.length)
      end

      # Sets every attribute in the hash.
      def setup_attrs(**attrs)
        attrs.each do |attr_name, attr_value|
          self.setup_attr(attr_name, attr_value)
        end
      end

      # Sets one attribute, choosing the C setter from the attribute's
      # declared type in the op definition.
      #
      # Raises Error::UnknownError when the op has no such attribute and
      # Error::UnimplementedError for attribute types without a setter here.
      def setup_attr(name, value)
        attr_def = self.op_def.attr.detect do |candidate|
          name.to_s == candidate.name
        end
        unless attr_def
          raise(Error::UnknownError, "Unknown attribute: #{name}")
        end

        case attr_def.type
        when 'bool'
          FFI.TF_SetAttrBool(self, attr_def.name, value ? 1 : 0)
        when 'int'
          FFI.TF_SetAttrInt(self, attr_def.name, value)
        when 'float'
          FFI.TF_SetAttrFloat(self, attr_def.name, value)
        when 'func'
          function_name = value.is_a?(Function) ? value.name : value
          # The length argument is a byte count, not a character count.
          FFI.TF_SetAttrFuncName(self, attr_def.name, function_name, function_name.bytesize)
        when 'shape'
          pointer = ::FFI::MemoryPointer.new(:int64, value.length)
          pointer.write_array_of_int64(value)
          FFI.TF_SetAttrShape(self, attr_def.name, pointer, value.length)
        when 'list(shape)'
          dims_pointer = ::FFI::MemoryPointer.new(:pointer, value.length)
          num_dims_pointer = ::FFI::MemoryPointer.new(:int32, value.length)
          # dims_pointer only stores raw addresses, so the per-shape buffers
          # are kept referenced here until the C call has returned.
          dim_pointers = value.each_with_index.map do |shape, i|
            dim_pointer = ::FFI::MemoryPointer.new(:int64, shape.length)
            dim_pointer.write_array_of_int64(shape)
            dims_pointer.put_pointer(i * ::FFI.type_size(:pointer), dim_pointer)
            num_dims_pointer.put_int32(i * ::FFI.type_size(:int32), shape.length)
            dim_pointer
          end
          FFI.TF_SetAttrShapeList(self, attr_def.name, dims_pointer, num_dims_pointer, value.length)
          dim_pointers.clear
        when 'string'
          # The length argument is a byte count, not a character count.
          FFI.TF_SetAttrString(self, attr_def.name, value, value.bytesize)
        when 'list(string)'
          # TODO: string lists are accepted but not passed to the native
          # description; the value is currently ignored.
        when 'tensor'
          Status.check do |status|
            FFI.TF_SetAttrTensor(self, attr_def.name, value, status)
          end
        when 'type'
          FFI.TF_SetAttrType(self, attr_def.name, value)
        when 'list(type)'
          value_ptr = ::FFI::MemoryPointer.new(FFI::DataType.native_type.size, value.count)
          value.each_with_index do |a_value, i|
            value_ptr.put_int32(i * FFI::DataType.native_type.size, FFI::DataType[a_value])
          end
          FFI.TF_SetAttrTypeList(self, attr_def.name, value_ptr, value.count)
        else
          raise(Error::UnimplementedError, "Unsupported attribute. #{self.op_def.name} - #{attr_def.name}")
        end
      end
    end
  end
end
|
@@ -0,0 +1,49 @@
|
|
1
|
+
module Tensorflow
  module Graph
    # Pairs an operation with an FFI::Output struct that identifies one of
    # its outputs.
    class OperationOutput
      attr_reader :operation, :output

      class << self
        # Wraps an existing output struct located at +pointer+.
        def from_pointer(operation, pointer)
          new(operation, FFI::Output.new(pointer))
        end

        # Creates a fresh output struct for output number +index+ of +operation+.
        def from_index(operation, index)
          struct = FFI::Output.new
          struct[:index] = index
          struct[:oper] = operation
          new(operation, struct)
        end

        # Wraps the output struct at +pointer+ and builds its Operation from
        # the struct's :oper field and the given graph.
        def from_graph(graph, pointer)
          struct = FFI::Output.new(pointer)
          owner = Operation.new(graph, struct[:oper])
          new(owner, struct)
        end
      end

      def initialize(operation, output)
        @operation = operation
        @output = output
      end

      # Pointer of the underlying output struct.
      def to_ptr
        @output.to_ptr
      end

      # Position of this output within its operation.
      def index
        output[:index]
      end

      # Describes the output as "op_type, name=..., index:(shape=..., dtype=...)".
      # Falls back to the default representation when there is no output struct.
      def to_s
        return super unless output

        [
          operation.op_type,
          "name=#{operation.name}",
          "#{index}:(shape=#{operation.output_shapes[index]}, dtype=#{operation.output_types[index]})"
        ].join(', ')
      end
    end
  end
end
|