tensorflow-ruby 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,153 @@
1
module Tensorflow
  module Graph
    # Reads one attribute of a graph operation through the TensorFlow C API
    # (TF_OperationGetAttr*) and converts it to a Ruby value.
    #
    # +metadata+ is the attribute's metadata record; this class reads its
    # :type, :is_list, :list_size and :total_size entries.
    class OperationAttr
      attr_reader :metadata, :name, :operation

      # @param operation the operation that owns the attribute (passed to the C API)
      # @param name [String] attribute name
      # @param metadata attribute metadata (see class comment)
      def initialize(operation, name, metadata)
        @operation = operation
        @name = name
        @metadata = metadata
      end

      # @return [Boolean] true when the metadata reports a list-valued attribute
      def list?
        self.metadata[:is_list] > 0
      end

      # Returns the attribute value as a Ruby object, dispatching on
      # metadata[:type] and on whether the attribute is a list.
      #
      # @raise [Error::UnimplementedError] for unsupported attribute types
      def value
        case self.metadata[:type]
        when :bool
          self.list? ? self.bool_list : self.bool
        when :int
          self.list? ? self.int_list : self.int
        when :float
          self.list? ? self.float_list : self.float
        when :func
          self.list? ? self.func_list : self.func
        when :shape
          self.list? ? self.shape_list : self.shape
        when :string
          self.list? ? self.string_list : self.string
        when :tensor
          self.list? ? self.tensor_list : self.tensor
        when :type
          self.list? ? self.dtype_list : self.dtype
        else
          raise(Error::UnimplementedError, "Unsupported attribute. #{self.name} - #{self.metadata[:type]}")
        end
      end

      # @return [Boolean] scalar bool attribute (read as an unsigned char)
      def bool
        pointer = ::FFI::MemoryPointer.new(:uchar)
        Status.check do |status|
          FFI.TF_OperationGetAttrBool(self.operation, self.name, pointer, status)
        end
        Boolean(pointer.read_uchar)
      end

      # @return the scalar type attribute, mapped through FFI::DataType
      def dtype
        pointer = ::FFI::MemoryPointer.new(FFI::DataType.native_type)
        Status.check do |status|
          FFI.TF_OperationGetAttrType(self.operation, self.name, pointer, status)
        end
        value = pointer.read(FFI::DataType.native_type)
        FFI::DataType[value]
      end

      # @return [Array] list(type) attribute, each entry mapped through FFI::DataType
      def dtype_list
        pointer = ::FFI::MemoryPointer.new(FFI::DataType.native_type, self.metadata[:list_size])
        Status.check do |status|
          FFI.TF_OperationGetAttrTypeList(self.operation, self.name, pointer, self.metadata[:list_size], status)
        end
        pointer.read_array_of_type(FFI::DataType.native_type, :read_uint32, self.metadata[:list_size]).map do |value|
          FFI::DataType[value]
        end
      end

      # @return [Float] scalar float attribute
      def float
        pointer = ::FFI::MemoryPointer.new(:float)
        Status.check do |status|
          FFI.TF_OperationGetAttrFloat(self.operation, self.name, pointer, status)
        end
        pointer.read_float
      end

      # @return [String] name of the function referenced by a func attribute
      def func
        self.proto.func.name
      end

      # @return [Integer] scalar int attribute
      def int
        pointer = ::FFI::MemoryPointer.new(:int64)
        Status.check do |status|
          FFI.TF_OperationGetAttrInt(self.operation, self.name, pointer, status)
        end
        # The buffer holds an int64; read all 8 bytes so large values are not truncated.
        pointer.read_int64
      end

      # List-valued getters below decode the AttrValue protobuf (see #proto),
      # whose ListValue message carries repeated b/i/f/s/func fields.

      # @return [Array<Boolean>]
      def bool_list
        self.proto.list.b.to_a
      end

      # @return [Array<Integer>]
      def int_list
        self.proto.list.i.to_a
      end

      # @return [Array<Float>]
      def float_list
        self.proto.list.f.to_a
      end

      # @return [Array<String>] raw byte strings
      def string_list
        self.proto.list.s.to_a
      end

      # @return [Array<String>] names of the referenced functions
      def func_list
        self.proto.list.func.map(&:name)
      end

      # list(tensor) attributes are not supported yet.
      # @raise [Error::UnimplementedError]
      def tensor_list
        raise(Error::UnimplementedError, "Unsupported attribute. #{self.name} - list(#{self.metadata[:type]})")
      end

      # @return [Array<Integer>] dimensions of a shape attribute; [] when
      #   metadata[:total_size] is -1
      def shape
        size = self.metadata[:total_size]
        if size == -1
          []
        else
          pointer = ::FFI::MemoryPointer.new(:int64, size)
          Status.check do |status|
            FFI.TF_OperationGetAttrShape(self.operation, self.name, pointer, size, status)
          end
          pointer.read_array_of_int64(size)
        end
      end

      # @return [Array<Array<Integer>>] one dimension array per shape in the list
      def shape_list
        list_size = self.metadata[:list_size]
        total_size = self.metadata[:total_size]
        storage_ptr = ::FFI::MemoryPointer.new(:int64, total_size)
        dims_pointer = ::FFI::MemoryPointer.new(:pointer, list_size)
        num_dims_pointer = ::FFI::MemoryPointer.new(:int, list_size)
        Status.check do |status|
          FFI.TF_OperationGetAttrShapeList(self.operation, self.name,
                                           dims_pointer, num_dims_pointer,
                                           list_size,
                                           storage_ptr, total_size, status)
        end

        num_dims = num_dims_pointer.read_array_of_int(list_size)
        num_dims.map.with_index do |dims, i|
          # A negative rank marks an unknown shape; mirror #shape and return [].
          next [] if dims.negative?

          dims_pointer[i].read_pointer.read_array_of_int64(dims)
        end
      end

      # @return [String] string attribute, read as exactly metadata[:total_size] bytes
      def string
        size = self.metadata[:total_size]
        return '' unless size.positive?

        pointer = ::FFI::MemoryPointer.new(:char, size)
        Status.check do |status|
          FFI.TF_OperationGetAttrString(self.operation, self.name, pointer, size, status)
        end
        # Read by length: the buffer is filled with `size` bytes, which may contain NULs.
        pointer.read_string(size)
      end

      # @return the tensor attribute, wrapped via Tensor.from_pointer
      def tensor
        pointer = ::FFI::MemoryPointer.new(:pointer)
        Status.check do |status|
          FFI.TF_OperationGetAttrTensor(self.operation, self.name, pointer, status)
        end
        Tensor.from_pointer(pointer.read_pointer)
      end

      # Fetches the serialized AttrValue for this attribute and decodes it.
      # @return [AttrValue]
      def proto
        buffer_ptr = FFI.TF_NewBuffer
        # Wrap immediately so the ensure clause can free the buffer even if
        # Status.check raises.
        buffer = FFI::Buffer.new(buffer_ptr)
        Status.check do |status|
          FFI.TF_OperationGetAttrValueProto(self.operation, self.name, buffer_ptr, status)
        end
        string = buffer[:data].read_string(buffer[:length])
        AttrValue.decode(string)
      ensure
        FFI.TF_DeleteBuffer(buffer) if buffer
      end

      # @return [String] "name: value"
      def to_s
        "#{self.name}: #{self.value}"
      end
    end
  end
end
@@ -0,0 +1,255 @@
1
module Tensorflow
  module Graph
    # Builds a new graph operation through the TensorFlow C API.
    #
    # The constructor creates the operation description and attaches inputs,
    # the graph's current control inputs and attributes; #save finishes it and
    # returns the resulting Operation.
    class OperationDescription
      attr_reader :graph, :name, :op_def

      # @param graph the graph the operation is added to
      # @param op_type a Function (its signature is used as the op def) or an op
      #   type resolved with graph.op_def
      # @param inputs a single input or an array of inputs (see #check_input)
      # @param attrs [Hash] attribute values keyed by symbol. An optional :name
      #   entry names the operation (scoped via graph.scoped_name) and is not set
      #   as an attribute. The caller's hash is not modified.
      # @raise [Error::InvalidArgumentError] if no op def is found for op_type
      def initialize(graph, op_type, inputs, attrs)
        @graph = graph
        @op_def = case op_type
                  when Function
                    op_type.function_def.signature
                  else
                    self.graph.op_def(op_type)
                  end
        raise(Error::InvalidArgumentError, "Invalid op type: #{op_type}") unless @op_def

        # Copy so that removing :name below does not mutate the caller's hash.
        attrs = attrs.dup
        raw_name = attrs.delete(:name)&.to_s || self.op_def.name
        @name = self.graph.scoped_name(raw_name)
        @pointer = FFI.TF_NewOperation(graph, self.op_def.name, @name)

        inputs = Array(inputs)
        setup_inputs(inputs, attrs)
        setup_control_inputs(graph.control_inputs)
        setup_attrs(**attrs)
      end

      # Best-effort dtype for this op: the attrs value of the op def's first
      # 'type' attribute, otherwise the dtype of the first Operation or Variable
      # input. Returns nil when neither is available.
      def figure_dtype(attrs, inputs)
        attr_def = self.op_def.attr.detect do |candidate|
          candidate.type == 'type'
        end

        result = attr_def ? attrs[attr_def.name.to_sym] : nil
        unless result
          inputs.each do |input|
            case input
            when Operation
              return input.output_types.first
            when Variable
              return input.dtype
            end
          end
        end
        result
      end

      # Allows this object to be passed where the C API expects the description pointer.
      def to_ptr
        @pointer
      end

      # Finishes the operation in the graph.
      # @return [Operation]
      def save
        Status.check do |status|
          ptr = FFI.TF_FinishOperation(self, status)
          Operation.new(self.graph, ptr)
        end
      end

      # Sets the device the operation is placed on.
      def device=(value)
        FFI.TF_SetDevice(self, value)
      end

      # Adds each entry of +control_inputs+ as a control dependency.
      def setup_control_inputs(control_inputs)
        control_inputs.each do |control_input|
          setup_control_input(control_input)
        end
      end

      # Adds one control dependency; a Variable contributes its handle.
      # @raise [Error::InvalidArgumentError] for anything other than an Operation or Variable
      def setup_control_input(control_input)
        control_input = case control_input
                        when Operation
                          control_input
                        when Variable
                          control_input.handle
                        else
                          raise(Error::InvalidArgumentError, "Invalid control input")
                        end

        FFI.TF_AddControlInput(self, control_input)
      end

      # Captures every input of +operation+ into this graph and groups them to
      # match the op def's input arguments.
      #
      # A list argument (number_attr or type_list_attr) receives an array of
      # exactly as many captured inputs as the attribute specifies; a single
      # argument receives one captured input.
      # For example, a TensorSlice dataset has one list argument. The operation
      # reports one input per list item, so those items must be regrouped into a
      # single array before the operation can be re-created.
      def capture_inputs(operation, attrs)
        captured = operation.inputs.map do |input|
          self.capture(input.operation)
        end

        offset = 0
        operation.op_def.input_arg.map do |input_arg|
          if !input_arg.number_attr.empty?
            length = attrs[input_arg.number_attr.to_sym]
            group = captured[offset, length]
          elsif !input_arg.type_list_attr.empty?
            length = attrs[input_arg.type_list_attr.to_sym].length
            group = captured[offset, length]
          else
            length = 1
            group = captured[offset]
          end
          offset += length
          group
        end
      end

      # Re-creates +operation+ (and, recursively, its inputs) in this graph.
      #
      # @raise [Error::InvalidArgumentError] if +operation+ is stateful or a Placeholder
      def capture(operation)
        if operation.op_def.is_stateful
          raise(Error::InvalidArgumentError, "Cannot capture a stateful node (name: #{operation.name}, type: #{operation.op_type})")
        elsif operation.op_type == "Placeholder"
          raise(Error::InvalidArgumentError, "Cannot capture a placeholder by value (name: #{operation.name}, type: #{operation.op_type})")
        end

        attrs = operation.attributes.each_with_object({}) do |attr, hash|
          hash[attr.name.to_sym] = attr.value
        end
        attrs[:name] = operation.name

        captured_inputs = self.capture_inputs(operation, attrs)
        self.graph.create_operation(operation.op_type, captured_inputs, **attrs)
      end

      # Normalizes one input value for +arg_def+:
      # - an Operation from another graph is captured into this graph;
      # - an OperationOutput is used as-is;
      # - a Variable contributes its handle (resource args) or value handle;
      # - anything else becomes a constant named "<op name>/<arg name>".
      def check_input(arg_def, input, dtype)
        case input
        when Operation
          self.graph.equal?(input.graph) ? input : capture(input)
        when OperationOutput
          input
        when Variable
          arg_def.type == :DT_RESOURCE ? input.handle : input.value_handle
        else
          input_name = "#{self.name}/#{arg_def.name}"
          Tensorflow.constant(input, name: input_name, dtype: dtype)
        end
      end

      # Attaches each input positionally to the op def's input arguments.
      def setup_inputs(inputs, attrs)
        inputs.each_with_index do |input, index|
          self.setup_input(index, input, attrs)
        end
      end

      # Attaches input +value+ to the +index+-th input argument.
      #
      # @raise [Error::InvalidArgumentError] if the op def has no argument at +index+
      def setup_input(index, value, attrs)
        arg_def = self.op_def.input_arg[index]
        unless arg_def
          raise(Error::InvalidArgumentError,
                "Too many inputs for #{self.op_def.name}: expected #{self.op_def.input_arg.length}")
        end

        dtype = attrs[arg_def.type_attr.to_sym]
        list_arg = !arg_def.number_attr.empty? || !arg_def.type_list_attr.empty?

        # Value can be an operation with multiple outputs. For example calling PACK with an input operation of SPLIT
        checked_value = if list_arg && value.is_a?(Array)
                          value.map do |sub_value|
                            self.check_input(arg_def, sub_value, dtype)
                          end
                        else
                          self.check_input(arg_def, value, dtype)
                        end

        if list_arg
          # Heterogeneous (type_list_attr) or homogeneous (number_attr) list input
          self.add_input_list(checked_value)
        else
          # This input is a single item
          self.add_input(checked_value)
        end
      end

      # Adds a single input. An operation with several outputs is first packed
      # into one tensor so that it fits a single input slot.
      def add_input(operation)
        if operation.is_a?(OperationOutput)
          FFI.TF_AddInput(self, operation)
        elsif operation.num_outputs > 1
          packed = Tensorflow.pack(operation, n: operation.num_outputs)
          FFI.TF_AddInput(self, packed.outputs.first)
        else
          FFI.TF_AddInput(self, operation.outputs.first)
        end
      end

      # Adds a list input. Each entry is either an OperationOutput (used as-is)
      # or an operation whose outputs are all added (e.g. SPLIT yields several).
      def add_input_list(operations)
        outputs = Array(operations).flat_map do |operation|
          operation.is_a?(OperationOutput) ? [operation] : operation.outputs
        end
        outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))
        FFI.TF_AddInputList(self, outputs_ptr, outputs.length)
      end

      # Sets every attribute in +attrs+ (see #setup_attr).
      def setup_attrs(**attrs)
        attrs.each do |attr_name, attr_value|
          self.setup_attr(attr_name, attr_value)
        end
      end

      # Sets one attribute, converting +value+ according to the op def's attr type.
      #
      # @raise [Error::UnknownError] if the op def has no attribute +name+
      # @raise [Error::UnimplementedError] for unsupported attribute types
      def setup_attr(name, value)
        attr_def = self.op_def.attr.detect do |candidate|
          name.to_s == candidate.name
        end
        unless attr_def
          raise(Error::UnknownError, "Unknown attribute: #{name}")
        end

        case attr_def.type
        when 'bool'
          FFI.TF_SetAttrBool(self, attr_def.name, value ? 1 : 0)
        when 'int'
          FFI.TF_SetAttrInt(self, attr_def.name, value)
        when 'float'
          FFI.TF_SetAttrFloat(self, attr_def.name, value)
        when 'func'
          function_name = (value.is_a?(Function) ? value.name : value).to_s
          # The C API expects a byte length, which differs from #length for non-ASCII text.
          FFI.TF_SetAttrFuncName(self, attr_def.name, function_name, function_name.bytesize)
        when 'shape'
          pointer = ::FFI::MemoryPointer.new(:int64, value.length)
          pointer.write_array_of_int64(value)
          FFI.TF_SetAttrShape(self, attr_def.name, pointer, value.length)
        when 'list(shape)'
          dims_pointer = ::FFI::MemoryPointer.new(:pointer, value.length)
          num_dims_pointer = ::FFI::MemoryPointer.new(:int32, value.length)
          value.each_with_index do |shape, i|
            dim_pointer = ::FFI::MemoryPointer.new(:int64, shape.length)
            dim_pointer.write_array_of_int64(shape)
            dims_pointer.put_pointer(i * ::FFI.type_size(:pointer), dim_pointer)
            num_dims_pointer.put_int32(i * ::FFI.type_size(:int32), shape.length)
          end
          FFI.TF_SetAttrShapeList(self, attr_def.name, dims_pointer, num_dims_pointer, value.length)
        when 'string'
          string_value = value.to_s
          # Byte length, not character count, so multibyte strings are passed whole.
          FFI.TF_SetAttrString(self, attr_def.name, string_value, string_value.bytesize)
        when 'list(string)'
          # TODO: list(string) attributes are currently ignored (not forwarded to TensorFlow).
          nil
        when 'tensor'
          Status.check do |status|
            FFI.TF_SetAttrTensor(self, attr_def.name, value, status)
          end
        when 'type'
          FFI.TF_SetAttrType(self, attr_def.name, value)
        when 'list(type)'
          value_ptr = ::FFI::MemoryPointer.new(FFI::DataType.native_type.size, value.count)
          value.each_with_index do |a_value, i|
            value_ptr.put_int32(i * FFI::DataType.native_type.size, FFI::DataType[a_value])
          end
          FFI.TF_SetAttrTypeList(self, attr_def.name, value_ptr, value.count)
        else
          raise(Error::UnimplementedError, "Unsupported attribute. #{self.op_def.name} - #{attr_def.name}")
        end
      end
    end
  end
end
@@ -0,0 +1,49 @@
1
module Tensorflow
  module Graph
    # One output slot of a graph operation, backed by a TF_Output struct
    # (FFI::Output) that holds the operation and the output index.
    class OperationOutput
      attr_reader :operation, :output

      class << self
        # Wraps the TF_Output stored at +pointer+, owned by +operation+.
        def from_pointer(operation, pointer)
          new(operation, FFI::Output.new(pointer))
        end

        # Builds a fresh TF_Output referring to output number +index+ of +operation+.
        def from_index(operation, index)
          struct = FFI::Output.new
          struct[:index] = index
          struct[:oper] = operation
          new(operation, struct)
        end

        # Wraps the TF_Output at +pointer+ and rebuilds its owning Operation in
        # +graph+ from the struct's :oper field.
        def from_graph(graph, pointer)
          struct = FFI::Output.new(pointer)
          owner = Operation.new(graph, struct[:oper])
          new(owner, struct)
        end
      end

      # @param operation the operation this output belongs to
      # @param output [FFI::Output] the underlying TF_Output struct
      def initialize(operation, output)
        @operation = operation
        @output = output
      end

      # Lets this object be passed directly where the C API wants a TF_Output pointer.
      def to_ptr
        output.to_ptr
      end

      # @return [Integer] position of this output within its operation
      def index
        output[:index]
      end

      # Summary of the form "<op type>, name=<op name>, <index>:(shape=..., dtype=...)".
      def to_s
        return super unless output

        slot = index
        summary = "#{slot}:(shape=#{operation.output_shapes[slot]}, dtype=#{operation.output_types[slot]})"
        [operation.op_type, "name=#{operation.name}", summary].join(', ')
      end
    end
  end
end