tensorflow-ruby 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,153 @@
1
module Tensorflow
  module Graph
    # Reads one attribute of a graph operation through the TensorFlow C API
    # (TF_OperationGetAttr*) and converts it to a Ruby value.
    class OperationAttr
      attr_reader :metadata, :name, :operation

      # @param operation [Operation] the operation that owns the attribute
      # @param name [String] the attribute name
      # @param metadata attribute metadata; this class reads :is_list, :type,
      #   :list_size and :total_size from it
      def initialize(operation, name, metadata)
        @operation = operation
        @name = name
        @metadata = metadata
      end

      # @return [Boolean] whether the attribute holds a list of values
      def list?
        self.metadata[:is_list] > 0
      end

      # Returns the attribute value, dispatching on metadata[:type] and on
      # whether the attribute is a list.
      #
      # @raise [Error::UnimplementedError] for attribute types that cannot be read
      def value
        case self.metadata[:type]
        when :bool
          self.list? ? self.bool_list : self.bool
        when :int
          self.list? ? self.int_list : self.int
        when :float
          self.list? ? self.float_list : self.float
        when :func
          self.list? ? self.func_list : self.func
        when :shape
          self.list? ? self.shape_list : self.shape
        when :string
          self.list? ? self.string_list : self.string
        when :tensor
          self.list? ? self.tensor_list : self.tensor
        when :type
          self.list? ? self.dtype_list : self.dtype
        else
          raise(Error::UnimplementedError, "Unsupported attribute. #{self.name} - #{self.metadata[:type]}")
        end
      end

      # @return [Boolean] a scalar bool attribute
      def bool
        pointer = ::FFI::MemoryPointer.new(:uchar)
        Status.check do |status|
          FFI.TF_OperationGetAttrBool(self.operation, self.name, pointer, status)
        end
        Boolean(pointer.read_uchar)
      end

      # @return [Array] a list(bool) attribute, read from the AttrValue proto
      def bool_list
        self.proto.list.b.to_a
      end

      # @return a scalar type attribute converted through FFI::DataType
      def dtype
        pointer = ::FFI::MemoryPointer.new(FFI::DataType.native_type)
        Status.check do |status|
          FFI.TF_OperationGetAttrType(self.operation, self.name, pointer, status)
        end
        value = pointer.read(FFI::DataType.native_type)
        FFI::DataType[value]
      end

      # @return [Array] a list(type) attribute with metadata[:list_size] entries
      def dtype_list
        pointer = ::FFI::MemoryPointer.new(FFI::DataType.native_type, self.metadata[:list_size])
        Status.check do |status|
          FFI.TF_OperationGetAttrTypeList(self.operation, self.name, pointer, self.metadata[:list_size], status)
        end
        pointer.read_array_of_type(FFI::DataType.native_type, :read_uint32, self.metadata[:list_size]).map do |value|
          FFI::DataType[value]
        end
      end

      # @return [Float] a scalar float attribute
      def float
        pointer = ::FFI::MemoryPointer.new(:float)
        Status.check do |status|
          FFI.TF_OperationGetAttrFloat(self.operation, self.name, pointer, status)
        end
        pointer.read_float
      end

      # @return [Array<Float>] a list(float) attribute, read from the AttrValue proto
      def float_list
        self.proto.list.f.to_a
      end

      # @return [String] the function name of a func attribute
      def func
        self.proto.func.name
      end

      # @return [Array<String>] the function names of a list(func) attribute
      def func_list
        self.proto.list.func.map(&:name)
      end

      # @return [Integer] a scalar int attribute
      def int
        pointer = ::FFI::MemoryPointer.new(:int64)
        Status.check do |status|
          FFI.TF_OperationGetAttrInt(self.operation, self.name, pointer, status)
        end
        # The buffer holds a 64-bit integer; read_int would only read 32 bits.
        pointer.read_int64
      end

      # @return [Array<Integer>] a list(int) attribute, read from the AttrValue proto
      def int_list
        self.proto.list.i.to_a
      end

      # @return [Array<Integer>] the dimensions of a shape attribute, or [] when
      #   metadata[:total_size] is -1
      def shape
        size = self.metadata[:total_size]
        if size == -1
          []
        else
          pointer = ::FFI::MemoryPointer.new(:int64, size)
          Status.check do |status|
            FFI.TF_OperationGetAttrShape(self.operation, self.name, pointer, size, status)
          end
          pointer.read_array_of_int64(size)
        end
      end

      # @return [Array<Array<Integer>>] the dimensions of each shape in a list(shape) attribute
      def shape_list
        total_size = self.metadata[:total_size]
        storage_ptr = ::FFI::MemoryPointer.new(:int64, total_size)
        dims_pointer = ::FFI::MemoryPointer.new(:pointer, self.metadata[:list_size])
        num_dims_pointer = ::FFI::MemoryPointer.new(:int, self.metadata[:list_size])
        Status.check do |status|
          FFI.TF_OperationGetAttrShapeList(self.operation, self.name,
                                           dims_pointer, num_dims_pointer,
                                           self.metadata[:list_size],
                                           storage_ptr, total_size, status)
        end

        num_dims = num_dims_pointer.read_array_of_int(self.metadata[:list_size])
        num_dims.map.with_index do |dims, i|
          # A negative rank marks an unknown-rank shape; like #shape, report it as [].
          next [] if dims.negative?

          dims_pointer[i].read_pointer.read_array_of_int64(dims)
        end
      end

      # @return [String] a scalar string attribute of metadata[:total_size] bytes
      def string
        size = self.metadata[:total_size]
        # The C API copies exactly `size` bytes with no trailing NUL, and the
        # value may itself contain NULs, so read a fixed number of bytes.
        pointer = ::FFI::MemoryPointer.new(:char, size)
        Status.check do |status|
          FFI.TF_OperationGetAttrString(self.operation, self.name, pointer, size, status)
        end
        pointer.read_string(size)
      end

      # @return [Array<String>] a list(string) attribute, read from the AttrValue proto
      def string_list
        self.proto.list.s.to_a
      end

      # @return [Tensor] a scalar tensor attribute
      def tensor
        pointer = ::FFI::MemoryPointer.new(:pointer)
        Status.check do |status|
          FFI.TF_OperationGetAttrTensor(self.operation, self.name, pointer, status)
        end
        Tensor.from_pointer(pointer.read_pointer)
      end

      # list(tensor) attributes are not supported yet.
      # @raise [Error::UnimplementedError] always
      def tensor_list
        raise(Error::UnimplementedError, "Unsupported attribute. #{self.name} - list(tensor)")
      end

      # Fetches the attribute as a serialized AttrValue protobuf and decodes it.
      # The TF_Buffer is always released, including when Status.check raises.
      def proto
        buffer_ptr = FFI.TF_NewBuffer
        Status.check do |status|
          FFI.TF_OperationGetAttrValueProto(self.operation, self.name, buffer_ptr, status)
        end
        buffer = FFI::Buffer.new(buffer_ptr)
        string = buffer[:data].read_string(buffer[:length])
        AttrValue.decode(string)
      ensure
        FFI.TF_DeleteBuffer(buffer_ptr) if buffer_ptr
      end

      # @return [String] "name: value"
      def to_s
        "#{self.name}: #{self.value}"
      end
    end
  end
end
@@ -0,0 +1,255 @@
1
module Tensorflow
  module Graph
    # Builds one graph operation: wraps a TF_OperationDescription, wires up its
    # inputs, control inputs and attributes, and finalizes it with #save.
    class OperationDescription
      attr_reader :graph, :name, :op_def

      # @param graph [Graph] graph the operation is added to
      # @param op_type [String, Function] op type resolved via graph.op_def, or a
      #   Function whose signature is used as the op definition
      # @param inputs [Object, Array] operation inputs (wrapped with Array())
      # @param attrs [Hash{Symbol => Object}] attribute values; an optional :name
      #   entry names the operation (scoped by the graph) and is not set as an attribute
      # @raise [Error::InvalidArgumentError] when no op definition is found
      def initialize(graph, op_type, inputs, attrs)
        @graph = graph
        @op_def = case op_type
                  when Function
                    op_type.function_def.signature
                  else
                    self.graph.op_def(op_type)
                  end
        raise(Error::InvalidArgumentError, "Invalid op type: #{op_type}") unless @op_def

        # Copy so that removing :name below does not mutate the caller's hash.
        attrs = attrs.dup
        raw_name = attrs.delete(:name)&.to_s || self.op_def.name
        @name = self.graph.scoped_name(raw_name)
        @pointer = FFI.TF_NewOperation(graph, self.op_def.name, @name)

        inputs = Array(inputs)
        setup_inputs(inputs, attrs)
        setup_control_inputs(graph.control_inputs)
        setup_attrs(**attrs)
      end

      # Picks a dtype: the value given in attrs for the op's first 'type'
      # attribute, otherwise the dtype of the first Operation/Variable input.
      def figure_dtype(attrs, inputs)
        attr_def = self.op_def.attr.detect do |candidate|
          candidate.type == 'type'
        end

        result = attr_def ? attrs[attr_def.name.to_sym] : nil
        unless result
          inputs.each do |input|
            case input
            when Operation
              return input.output_types.first
            when Variable
              return input.dtype
            end
          end
        end
        result
      end

      # Lets instances be passed directly to FFI functions.
      def to_ptr
        @pointer
      end

      # Finalizes the description and returns the created Operation.
      def save
        Status.check do |status|
          ptr = FFI.TF_FinishOperation(self, status)
          Operation.new(self.graph, ptr)
        end
      end

      def device=(value)
        FFI.TF_SetDevice(self, value)
      end

      def setup_control_inputs(control_inputs)
        control_inputs.each do |control_input|
          setup_control_input(control_input)
        end
      end

      # @raise [Error::InvalidArgumentError] unless control_input is an Operation or Variable
      def setup_control_input(control_input)
        control_input = case control_input
                        when Operation
                          control_input
                        when Variable
                          control_input.handle
                        else
                          raise(Error::InvalidArgumentError, "Invalid control input")
                        end

        FFI.TF_AddControlInput(self, control_input)
      end

      # Captures every input of +operation+ and regroups them by the op's input
      # arguments. List arguments (number_attr / type_list_attr) take several
      # consecutive captured inputs; other arguments take exactly one.
      def capture_inputs(operation, attrs)
        # First capture the inputs
        inputs = operation.inputs.map do |input|
          self.capture(input.operation)
        end

        # We now have to group the inputs together. For example, a TensorSlice dataset has 1 input argument
        # which is a list. But the number of inputs returned by the operation is actually the number of items in
        # the list, so it's usually more than one. We need to group them into one array to be able to call
        # the operation to create a captured copy.
        offset = 0
        operation.op_def.input_arg.each_with_object([]) do |input_arg, result|
          if !input_arg.number_attr.empty?
            count = attrs[input_arg.number_attr.to_sym]
            # start/length slice: exactly `count` items (an inclusive range took one too many)
            result << inputs[offset, count]
          elsif !input_arg.type_list_attr.empty?
            count = attrs[input_arg.type_list_attr.to_sym].length
            result << inputs[offset, count]
          else
            count = 1
            result << inputs[offset]
          end
          offset += count
        end
      end

      # Recreates +operation+ (from another graph) in this graph, recursively
      # capturing its inputs.
      #
      # @raise [Error::InvalidArgumentError] if the captured operation is stateful
      #   or is a Placeholder
      def capture(operation)
        if operation.op_def.is_stateful
          raise(Error::InvalidArgumentError, "Cannot capture a stateful node (name: #{operation.name}, type: #{operation.op_type})")
        elsif operation.op_type == "Placeholder"
          raise(Error::InvalidArgumentError, "Cannot capture a placeholder by value (name: #{operation.name}, type: #{operation.op_type})")
        end

        attrs = operation.attributes.each_with_object({}) do |attr, hash|
          hash[attr.name.to_sym] = attr.value
        end
        attrs[:name] = operation.name

        captured_inputs = self.capture_inputs(operation, attrs)
        self.graph.create_operation(operation.op_type, captured_inputs, **attrs)
      end

      # Normalizes one input: Operations from another graph are captured,
      # Variables become their handle (resource args) or value handle, and any
      # other value becomes a constant named "<op name>/<arg name>".
      def check_input(arg_def, input, dtype)
        case input
        when Operation
          self.graph.equal?(input.graph) ? input : capture(input)
        when OperationOutput
          input
        when Variable
          arg_def.type == :DT_RESOURCE ? input.handle : input.value_handle
        else
          input_name = "#{self.name}/#{arg_def.name}"
          Tensorflow.constant(input, name: input_name, dtype: dtype)
        end
      end

      def setup_inputs(inputs, attrs)
        inputs.each_with_index do |input, index|
          self.setup_input(index, input, attrs)
        end
      end

      # @raise [Error::InvalidArgumentError] when the op has no input argument at +index+
      def setup_input(index, value, attrs)
        arg_def = self.op_def.input_arg[index]
        unless arg_def
          raise(Error::InvalidArgumentError, "Unexpected input at index #{index} for op #{self.op_def.name}")
        end
        dtype = attrs[arg_def.type_attr.to_sym]

        # Value can be an operation with multiple outputs. For example calling PACK with an input operation of SPLIT
        checked_value = if (!arg_def.number_attr.empty? || !arg_def.type_list_attr.empty?) && value.is_a?(Array)
                          value.map do |sub_value|
                            self.check_input(arg_def, sub_value, dtype)
                          end
                        else
                          self.check_input(arg_def, value, dtype)
                        end

        if !arg_def.type_list_attr.empty?
          # This input is a heterogeneous list
          self.add_input_list(checked_value)
        elsif !arg_def.number_attr.empty?
          # This input is a homogeneous list
          self.add_input_list(checked_value)
        else
          # This input is a single item
          self.add_input(checked_value)
        end
      end

      def add_input(operation)
        # Check to see if the operation has multiple outputs, and if it does, we need to pack them together
        # to fit into one input
        if operation.is_a?(OperationOutput)
          FFI.TF_AddInput(self, operation)
        elsif operation.num_outputs > 1
          packed = Tensorflow.pack(operation, n: operation.num_outputs)
          FFI.TF_AddInput(self, packed.outputs.first)
        else
          FFI.TF_AddInput(self, operation.outputs.first)
        end
      end

      def add_input_list(operations)
        # Items can be Operations (each contributing all of its outputs, e.g. SPLIT)
        # or OperationOutputs, which already denote a single output.
        outputs = Array(operations).flat_map do |operation|
          operation.is_a?(OperationOutput) ? [operation] : operation.outputs
        end
        outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))
        FFI.TF_AddInputList(self, outputs_ptr, outputs.length)
      end

      def setup_attrs(**attrs)
        attrs.each do |attr_name, attr_value|
          self.setup_attr(attr_name, attr_value)
        end
      end

      # Sets one attribute on the description according to its op_def type.
      #
      # @raise [Error::UnknownError] if the op has no attribute named +name+
      # @raise [Error::UnimplementedError] for unsupported attribute types
      def setup_attr(name, value)
        attr_def = self.op_def.attr.detect do |candidate|
          name.to_s == candidate.name
        end
        unless attr_def
          raise(Error::UnknownError, "Unknown attribute: #{name}")
        end

        case attr_def.type
        when 'bool'
          FFI.TF_SetAttrBool(self, attr_def.name, value ? 1 : 0)
        when 'int'
          FFI.TF_SetAttrInt(self, attr_def.name, value)
        when 'float'
          FFI.TF_SetAttrFloat(self, attr_def.name, value)
        when 'func'
          function_name = value.is_a?(Function) ? value.name : value.to_s
          # The C API takes a length in bytes, not characters.
          FFI.TF_SetAttrFuncName(self, attr_def.name, function_name, function_name.bytesize)
        when 'shape'
          pointer = ::FFI::MemoryPointer.new(:int64, value.length)
          pointer.write_array_of_int64(value)
          FFI.TF_SetAttrShape(self, attr_def.name, pointer, value.length)
        when 'list(shape)'
          dims_pointer = ::FFI::MemoryPointer.new(:pointer, value.length)
          num_dims_pointer = ::FFI::MemoryPointer.new(:int32, value.length)
          # Keep every per-shape buffer referenced until the C call returns;
          # otherwise the GC may free memory that dims_pointer still points to.
          shape_buffers = value.each_with_index.map do |shape, i|
            dim_pointer = ::FFI::MemoryPointer.new(:int64, shape.length)
            dim_pointer.write_array_of_int64(shape)
            dims_pointer.put_pointer(i * ::FFI.type_size(:pointer), dim_pointer)
            num_dims_pointer.put_int32(i * ::FFI.type_size(:int32), shape.length)
            dim_pointer
          end
          FFI.TF_SetAttrShapeList(self, attr_def.name, dims_pointer, num_dims_pointer, shape_buffers.length)
        when 'string'
          string_value = value.to_s
          # The C API takes a length in bytes, not characters.
          FFI.TF_SetAttrString(self, attr_def.name, string_value, string_value.bytesize)
        when 'list(string)'
          # Previously a silent no-op; keep that behavior if the binding is missing.
          if FFI.respond_to?(:TF_SetAttrStringList)
            strings = Array(value).map(&:to_s)
            values_ptr = ::FFI::MemoryPointer.new(:pointer, strings.length)
            lengths_ptr = ::FFI::MemoryPointer.new(:size_t, strings.length)
            # Referenced until after the call so the GC cannot free them early.
            string_buffers = strings.each_with_index.map do |string, i|
              buffer = ::FFI::MemoryPointer.from_string(string)
              values_ptr.put_pointer(i * ::FFI.type_size(:pointer), buffer)
              lengths_ptr.put(:size_t, i * ::FFI.type_size(:size_t), string.bytesize)
              buffer
            end
            FFI.TF_SetAttrStringList(self, attr_def.name, values_ptr, lengths_ptr, string_buffers.length)
          end
        when 'tensor'
          Status.check do |status|
            FFI.TF_SetAttrTensor(self, attr_def.name, value, status)
          end
        when 'type'
          FFI.TF_SetAttrType(self, attr_def.name, value)
        when 'list(type)'
          value_ptr = ::FFI::MemoryPointer.new(FFI::DataType.native_type.size, value.count)
          value.each_with_index do |a_value, i|
            value_ptr.put_int32(i * FFI::DataType.native_type.size, FFI::DataType[a_value])
          end
          FFI.TF_SetAttrTypeList(self, attr_def.name, value_ptr, value.count)
        else
          raise(Error::UnimplementedError, "Unsupported attribute. #{self.op_def.name} - #{attr_def.name}")
        end
      end
    end
  end
end
@@ -0,0 +1,49 @@
1
module Tensorflow
  module Graph
    # One specific output of a graph operation: the owning operation plus the
    # FFI::Output struct (operation pointer + output index) that the C API consumes.
    class OperationOutput
      attr_reader :operation, :output

      class << self
        # Wraps the FFI::Output struct found at +pointer+.
        def from_pointer(operation, pointer)
          new(operation, FFI::Output.new(pointer))
        end

        # Builds a new FFI::Output that refers to output number +index+ of +operation+.
        def from_index(operation, index)
          struct = FFI::Output.new
          struct[:index] = index
          struct[:oper] = operation
          new(operation, struct)
        end

        # Wraps the FFI::Output at +pointer+ and rebuilds its owning Operation in +graph+.
        def from_graph(graph, pointer)
          struct = FFI::Output.new(pointer)
          owner = Operation.new(graph, struct[:oper])
          new(owner, struct)
        end
      end

      # @param operation [Operation] the operation producing this output
      # @param output [FFI::Output] the underlying C struct
      def initialize(operation, output)
        @operation = operation
        @output = output
      end

      # Lets instances be handed straight to FFI calls.
      def to_ptr
        output.to_ptr
      end

      # @return [Integer] position of this output among the operation's outputs
      def index
        output[:index]
      end

      # Human-readable summary: "<op_type>, name=<name>, <index>:(shape=..., dtype=...)".
      # Falls back to Object#to_s when no output struct is attached.
      def to_s
        return super unless output

        parts = [operation.op_type, "name=#{operation.name}"]
        position = index
        parts << "#{position}:(shape=#{operation.output_shapes[position]}, dtype=#{operation.output_types[position]})"
        parts.join(', ')
      end
    end
  end
end