tensorflow-ruby 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156):
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,252 @@
1
module Tensorflow
  module Graph
    # Ruby wrapper around a native TF_Graph handle.
    #
    # Besides owning the native graph (released by a finalizer), the wrapper
    # keeps Ruby-side bookkeeping: named collections, a name scope (its
    # methods are delegated below), and the control inputs that are active
    # inside #control_dependencies.
    class Graph
      attr_reader :control_inputs

      extend Forwardable
      def_delegators :@name_scope, :name_scope, :scoped_name, :unique_name

      # Returns the process-wide default graph, creating it on first use.
      def self.default
        @default ||= Graph.new
      end

      # Replaces the default graph with a fresh, empty graph.
      def self.reset_default
        @default = Graph.new
      end

      # Builds the finalizer proc that deletes the native graph.
      # It is created in a class method so the proc does not close over the
      # Graph instance it finalizes.
      def self.finalize(pointer)
        proc do
          FFI.TF_DeleteGraph(pointer)
        end
      end

      def initialize
        @collections = {}
        @name_scope = NameScope.new
        @pointer = FFI.TF_NewGraph()
        @control_inputs = []
        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      # Native handle used when this object is passed to an FFI call.
      def to_ptr
        @pointer
      end

      # Names of every collection created on this graph.
      def collections
        @collections.keys
      end

      # Appends +value+ to the collection +name+, creating it if needed.
      # Returns the collection array.
      def add_to_collection(name, value)
        (@collections[name] ||= []) << value
      end

      # Appends +value+ to each collection listed in +names+.
      def add_to_collections(names, value)
        names.each do |name|
          add_to_collection(name, value)
        end
      end

      # Returns the live array backing collection +name+, or nil when no such
      # collection exists. +scope+ is accepted but currently ignored.
      def get_collection_ref(name, scope=nil)
        @collections[name]
      end

      # Replaces collection +name+ with an empty array.
      def clear_collection(name)
        @collections[name] = []
      end

      # Pushes this graph onto the ExecutionContext for the duration of the
      # block and pops it afterwards, even if the block raises.
      #
      # @raise [Error::InvalidArgumentError] when no block is given
      def as_default
        raise(Error::InvalidArgumentError, "Must provide block") unless block_given?
        ExecutionContext.push(self)
        begin
          yield self
        ensure
          ExecutionContext.pop
        end
      end

      # Sets #control_inputs to +control_inputs+ (coerced with Array()) while
      # the block runs. The previous control inputs are restored afterwards,
      # so nested blocks work. A nested block replaces, rather than merges
      # with, the enclosing set while it runs.
      def control_dependencies(control_inputs)
        previous = @control_inputs
        @control_inputs = Array(control_inputs)
        begin
          yield self
        ensure
          @control_inputs = previous
        end
      end

      # Returns the decoded OpDef protobuf registered for +op_type+.
      def op_def(op_type)
        buffer_ptr = FFI.TF_NewBuffer
        # Wrap immediately so the ensure clause can free the buffer even if
        # Status.check raises.
        buffer = FFI::Buffer.new(buffer_ptr)
        Status.check do |status|
          FFI.TF_GraphGetOpDef(self, op_type, buffer_ptr, status)
        end
        OpDef.decode(buffer[:data].read_string(buffer[:length]))
      ensure
        FFI.TF_DeleteBuffer(buffer) if buffer
      end

      # Returns a Set containing +operation+ and every operation reachable
      # from it by following consumers (downstream).
      def forward(operation)
        collect_reachable(Set.new([operation]), operation) do |op|
          op.consumers.map(&:operation)
        end
      end

      # Returns a Set containing +operation+ and every operation reachable
      # from it by following inputs (upstream).
      def backward(operation)
        collect_reachable(Set.new([operation]), operation) do |op|
          op.inputs.map(&:operation)
        end
      end

      # Yields each operation in the graph; returns an Enumerator when no
      # block is given. TF_GraphNextOperation advances a position cursor that
      # starts zero-initialised.
      def operations
        return enum_for(:operations) unless block_given?

        position_ptr = ::FFI::MemoryPointer.new(:size_t, 1, true)
        while (ptr = FFI.TF_GraphNextOperation(self, position_ptr))
          break if ptr.null?
          yield Operation.new(self, ptr)
        end
      end

      # Looks up an operation by name; returns nil when it does not exist.
      def operation(name)
        ptr = FFI.TF_GraphOperationByName(self, name)
        ptr.null? ? nil : Operation.new(self, ptr)
      end

      # Builds an OperationDescription for +op_type+ and returns the result
      # of saving it (presumably the new Operation; confirm in
      # OperationDescription#save).
      def create_operation(op_type, inputs=[], attrs={})
        OperationDescription.new(self, op_type, inputs, attrs).save
      end

      # Runs +operations+ in a throwaway Session and returns the result.
      # The session is closed even when #run raises.
      def execute(operations, feed_dict={})
        session = Session.new(self, SessionOptions.new)
        begin
          session.run(operations, feed_dict)
        ensure
          session.close
        end
      end

      # Returns one dimension array per output of +operation+.
      # NOTE(review): an unknown rank (num_dims == -1) is reported as [],
      # which is indistinguishable from a scalar's shape.
      def output_shapes(operation)
        operation.outputs.map do |output|
          num_dims = Status.check do |status|
            FFI.TF_GraphGetTensorNumDims(self, output, status)
          end

          if num_dims == -1
            []
          else
            dims_ptr = ::FFI::MemoryPointer.new(:int64, num_dims)
            Status.check do |status|
              FFI.TF_GraphGetTensorShape(self, output, dims_ptr, num_dims, status)
            end
            dims_ptr.read_array_of_int64(num_dims)
          end
        end
      end

      # Sets the static shape of output 0 of +operation+ to +shape+.
      def tensor_set_shape(operation, shape)
        ptr = ::FFI::MemoryPointer.new(:int64, shape.length)
        ptr.write_array_of_int64(shape)
        output = FFI::Output.new
        output[:oper] = operation
        output[:index] = 0
        Status.check do |status|
          FFI.TF_GraphSetTensorShape(self, output, ptr, shape.length, status)
        end
      end

      # Copies +function+ (and optionally its +gradient+) into this graph.
      def add_function(function, gradient=nil)
        Status.check do |status|
          FFI.TF_GraphCopyFunction(self, function, gradient, status)
        end
      end

      # Converts part of this graph into a Function named +name+.
      # Inputs and outputs are all outputs of +input_operations+ and
      # +output_operations+; either may be nil, meaning "none".
      #
      # @raise [ArgumentError] when +output_names+ is given with a length
      #   different from the number of outputs
      def to_function(name, operators, input_operations, output_operations, output_names=nil)
        output_ops = output_operations || []

        inputs = input_operations ? input_operations.map(&:outputs).flatten : []
        inputs_ptr = FFI::Output.array_to_ptr(inputs.map(&:output))

        outputs = output_ops.map(&:outputs).flatten
        outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))

        if output_names && output_names.length != outputs.length
          raise(ArgumentError, "output_names length must equal outputs length or be nil")
        end

        # Keep the per-name string pointers in a local so they are not GC'ed
        # before TF_GraphToFunction reads them.
        output_names_ptrs = nil
        output_names_ptr = nil
        if output_names
          output_names_ptrs = output_names.map do |output_name|
            ::FFI::MemoryPointer.from_string(output_name)
          end
          output_names_ptr = ::FFI::MemoryPointer.new(:pointer, output_names_ptrs.length, true)
          output_names_ptr.write_array_of_pointer(output_names_ptrs)
        end

        append_hash_to_fn_name = 0
        options = nil
        description = nil

        func = Status.check do |status|
          FFI.TF_GraphToFunction(self, name, append_hash_to_fn_name,
                                 operators ? operators.length : -1, operators,
                                 inputs.length, inputs_ptr,
                                 outputs.length, outputs_ptr,
                                 output_names_ptr,
                                 options, description, status)
        end
        output_types = output_ops.map(&:output_types).flatten(1)
        output_shapes = output_ops.map(&:output_shapes).flatten(1)
        Function.new(func, output_types, output_shapes)
      end

      # Serializes this graph and returns it decoded as a GraphDef protobuf.
      def as_graph_def
        buffer_ptr = FFI.TF_NewBuffer
        # Wrap immediately so the ensure clause can free the buffer even if
        # Status.check raises.
        buffer = FFI::Buffer.new(buffer_ptr)
        Status.check do |status|
          FFI.TF_GraphToGraphDef(self, buffer_ptr, status)
        end
        GraphDef.decode(buffer[:data].read_string(buffer[:length]))
      ensure
        FFI.TF_DeleteBuffer(buffer) if buffer
      end

      # Imports +graph_def+ (a GraphDef message or its serialized bytes) into
      # this graph, using +options+ or a default GraphDefOptions.
      def import(graph_def, options=nil)
        options ||= GraphDefOptions.new

        data = graph_def.is_a?(GraphDef) ? GraphDef.encode(graph_def) : graph_def

        ptr = ::FFI::MemoryPointer.new(:char, data.bytesize)
        ptr.put_bytes(0, data)

        buffer = FFI::Buffer.new
        buffer[:data] = ptr
        buffer[:length] = data.bytesize

        Status.check do |status|
          FFI.TF_GraphImportGraphDef(self, buffer, options, status)
        end
      end

      private

      # Depth-first walk from +operation+. Each neighbour returned by the
      # block is added to +visited+. Neighbours already in the set are not
      # expanded again, so shared sub-graphs are walked once and cycles
      # terminate. Returns +visited+.
      def collect_reachable(visited, operation, &neighbors)
        neighbors.call(operation).each do |neighbor|
          next unless visited.add?(neighbor)
          collect_reachable(visited, neighbor, &neighbors)
        end
        visited
      end
    end
  end
end
@@ -0,0 +1,24 @@
1
module Tensorflow
  module Graph
    # Owns a native TF_ImportGraphDefOptions handle. Graph#import passes an
    # instance of this class to TF_GraphImportGraphDef.
    class GraphDefOptions
      class << self
        # Returns the finalizer proc that frees the native options handle.
        # Defined at class level so the proc never references the instance.
        def finalize(pointer)
          proc { FFI.TF_DeleteImportGraphDefOptions(pointer) }
        end
      end

      def initialize
        @pointer = FFI.TF_NewImportGraphDefOptions
        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      # Native handle used when this object is passed to an FFI call.
      def to_ptr
        @pointer
      end

      # Forwards +value+ to TF_ImportGraphDefOptionsSetPrefix. Per the TF C
      # API, this is a prefix prepended to the names of imported nodes.
      def prefix=(value)
        FFI.TF_ImportGraphDefOptionsSetPrefix(self, value)
      end
    end
  end
end
@@ -0,0 +1,50 @@
1
module Tensorflow
  module Graph
    # Well-known collection names, e.g. for use with Graph#add_to_collection.
    # The values appear to mirror Python TensorFlow's GraphKeys.
    class GraphKeys
      GLOBAL_VARIABLES = "variables"
      LOCAL_VARIABLES = "local_variables"
      METRIC_VARIABLES = "metric_variables"
      MODEL_VARIABLES = "model_variables"
      TRAINABLE_VARIABLES = "trainable_variables"
      SUMMARIES = "summaries"
      QUEUE_RUNNERS = "queue_runners"
      TABLE_INITIALIZERS = "table_initializer"
      ASSET_FILEPATHS = "asset_filepaths"
      MOVING_AVERAGE_VARIABLES = "moving_average_variables"
      REGULARIZATION_LOSSES = "regularization_losses"
      CONCATENATED_VARIABLES = "concatenated_variables"
      SAVERS = "savers"
      WEIGHTS = "weights"
      BIASES = "biases"
      ACTIVATIONS = "activations"
      UPDATE_OPS = "update_ops"
      LOSSES = "losses"
      SAVEABLE_OBJECTS = "saveable_objects"
      RESOURCES = "resources"
      LOCAL_RESOURCES = "local_resources"
      TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"
      INIT_OP = "init_op"
      LOCAL_INIT_OP = "local_init_op"
      READY_OP = "ready_op"
      READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
      SUMMARY_OP = "summary_op"
      GLOBAL_STEP = "global_step"
      EVAL_STEP = "eval_step"
      COND_CONTEXT = "cond_context"
      WHILE_CONTEXT = "while_context"
      SUMMARY_COLLECTION = "_SUMMARY_V2"

      # Collections whose members are variables.
      # Ruby constants must start with an uppercase letter: an
      # underscore-prefixed name here would only be a local variable of the
      # class body and could not be referenced from outside.
      VARIABLE_COLLECTIONS = [
        GLOBAL_VARIABLES,
        LOCAL_VARIABLES,
        METRIC_VARIABLES,
        MODEL_VARIABLES,
        TRAINABLE_VARIABLES,
        MOVING_AVERAGE_VARIABLES,
        CONCATENATED_VARIABLES,
        TRAINABLE_RESOURCE_VARIABLES,
      ].freeze

      STREAMING_MODEL_PORTS = "streaming_model_ports"
    end
  end
end
50
+
@@ -0,0 +1,176 @@
1
module Tensorflow
  module Graph
    # Ruby wrapper around a native TF_Operation pointer that belongs to
    # +graph+.
    #
    # Equality and hashing are by operation name, so two wrappers around the
    # same node are equal and collapse to one entry in a Set.
    # NOTE(review): operations from different graphs that share a name also
    # compare equal.
    class Operation
      include Operators
      attr_reader :graph

      def initialize(graph, pointer)
        @graph = graph
        @pointer = pointer
      end

      # Native handle used when this object is passed to an FFI call.
      def to_ptr
        @pointer
      end

      # Name-based equality. Objects that do not respond to #name are never
      # equal; previously this raised NoMethodError (e.g. for `op == nil`).
      def eql?(other)
        other.respond_to?(:name) && name.eql?(other.name)
      end

      def ==(other)
        other.respond_to?(:name) && name == other.name
      end

      # Consistent with #eql?: operations with the same name share a hash.
      def hash
        name.hash
      end

      def name
        FFI.TF_OperationName(self)
      end

      def op_type
        FFI.TF_OperationOpType(self)
      end

      # OpDef protobuf for this operation's type, looked up in the owning
      # graph.
      def op_def
        graph.op_def(op_type)
      end

      def device
        FFI.TF_OperationDevice(self)
      end

      # Returns the decoded NodeDef protobuf describing this operation.
      def node_def
        buffer_ptr = FFI.TF_NewBuffer
        # Wrap immediately so the ensure clause can free the buffer even if
        # Status.check raises.
        buffer = FFI::Buffer.new(buffer_ptr)
        Status.check do |status|
          FFI.TF_OperationToNodeDef(self, buffer_ptr, status)
        end
        NodeDef.decode(buffer[:data].read_string(buffer[:length]))
      ensure
        FFI.TF_DeleteBuffer(buffer) if buffer
      end

      def num_inputs
        FFI.TF_OperationNumInputs(self)
      end

      # One OperationOutput per data input, in input-index order.
      def inputs
        count = num_inputs
        pointer = ::FFI::MemoryPointer.new(FFI::Output, count)
        FFI.TF_OperationAllInputs(self, pointer, count)
        Array.new(count) { |index| OperationOutput.from_graph(graph, pointer[index]) }
      end

      def num_control_outputs
        FFI.TF_OperationNumControlOutputs(self)
      end

      # Operations that have this operation as a control input.
      def control_outputs
        count = num_control_outputs
        pointer = ::FFI::MemoryPointer.new(:pointer, count)
        FFI.TF_OperationGetControlOutputs(self, pointer, count)
        Array.new(count) { |index| self.class.new(graph, pointer[index].read_pointer) }
      end

      def num_outputs
        FFI.TF_OperationNumOutputs(self)
      end

      # One OperationOutput per output index.
      def outputs
        Array.new(num_outputs) { |index| OperationOutput.from_index(self, index) }
      end

      def [](index)
        outputs[index]
      end

      # Data type of each output, in output order.
      def output_types
        outputs.map do |output|
          FFI.TF_OperationOutputType(output)
        end
      end

      # Shape of each output, as computed by Graph#output_shapes.
      def output_shapes
        graph.output_shapes(self)
      end

      # Shape of the first output (nil if there are no outputs).
      def shape
        output_shapes.first
      end

      # Data type of the first output (nil if there are no outputs).
      def dtype
        output_types.first
      end

      def output_list_length(arg_name)
        Status.check do |status|
          FFI.TF_OperationOutputListLength(self, arg_name, status)
        end
      end

      def num_control_inputs
        FFI.TF_OperationNumControlInputs(self)
      end

      # Operations this operation depends on via control edges.
      def control_inputs
        count = num_control_inputs
        pointer = ::FFI::MemoryPointer.new(:pointer, count)
        FFI.TF_OperationGetControlInputs(self, pointer, count)
        Array.new(count) { |index| self.class.new(graph, pointer[index].read_pointer) }
      end

      # One OperationAttr for each attribute declared in this op's OpDef.
      def attributes
        op_def.attr.map do |attr_def|
          attr(attr_def.name)
        end
      end

      # Wraps the metadata of attribute +attr_name+ in an OperationAttr.
      def attr(attr_name)
        metadata = Status.check do |status|
          FFI.TF_OperationGetAttrMetadata(self, attr_name, status)
        end

        OperationAttr.new(self, attr_name, metadata)
      end

      # Consumers (inputs of other operations) fed by +output+. Each is
      # wrapped as an OperationOutput from the native TF_Input array.
      def output_consumers(output)
        count = FFI.TF_OperationOutputNumConsumers(output)
        consumers_ptr = ::FFI::MemoryPointer.new(FFI::Input, count)
        FFI.TF_OperationOutputConsumers(output, consumers_ptr, count)
        Array.new(count) { |index| OperationOutput.from_graph(graph, consumers_ptr[index]) }
      end

      # Consumers of all outputs, concatenated in output order.
      def consumers
        outputs.flat_map { |output| output_consumers(output) }
      end

      # Human-readable summary: op type, name, and each output's shape and
      # dtype. Shapes and types are computed once, not once per output.
      def to_s
        shapes = output_shapes
        types = output_types
        parts = [op_type, "name=#{name}"]
        types.each_index do |index|
          parts << "#{index}:(shape=#{shapes[index]}, dtype=#{types[index]})"
        end
        parts.join(', ')
      end
    end
  end
end