tensorflow-ruby 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,252 @@
1
module Tensorflow
  module Graph
    # Ruby wrapper around a native TF_Graph handle. Besides owning the native
    # graph it keeps Python-style named collections, delegates name scoping to
    # a NameScope, and records the control inputs active inside a
    # #control_dependencies block.
    class Graph
      attr_reader :control_inputs

      extend Forwardable
      def_delegators :@name_scope, :name_scope, :scoped_name, :unique_name

      # Lazily created process-wide default graph.
      def self.default
        @default ||= Graph.new
      end

      # Replaces the default graph with a fresh, empty one.
      def self.reset_default
        @default = Graph.new
      end

      # Builds the finalizer that deletes the native graph. It must be a proc,
      # not a lambda: Ruby calls finalizers with the object id as an argument,
      # which a proc silently ignores.
      def self.finalize(pointer)
        proc do
          FFI.TF_DeleteGraph(pointer)
        end
      end

      def initialize
        @collections = {}
        @name_scope = NameScope.new
        @pointer = FFI.TF_NewGraph()
        @control_inputs = []
        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      def to_ptr
        @pointer
      end

      # @return [Array] the names of all collections created so far
      def collections
        @collections.keys
      end

      # Appends +value+ to the collection +name+, creating it on first use.
      def add_to_collection(name, value)
        values = @collections[name] ||= []
        values << value
      end

      # Appends +value+ to every collection listed in +names+.
      def add_to_collections(names, value)
        names.each do |name|
          add_to_collection(name, value)
        end
      end

      # Returns the live array backing collection +name+, or nil if it has
      # never been created. NOTE: +scope+ is accepted but currently ignored.
      def get_collection_ref(name, scope=nil)
        @collections[name]
      end

      # Empties collection +name+ (replacing it with a new array).
      def clear_collection(name)
        @collections[name] = []
      end

      # Pushes this graph onto the ExecutionContext for the duration of the
      # block and pops it afterwards, even if the block raises.
      # @raise [Error::InvalidArgumentError] when called without a block
      def as_default
        raise(Error::InvalidArgumentError, "Must provide block") unless block_given?
        ExecutionContext.push(self)
        begin
          yield self
        ensure
          ExecutionContext.pop
        end
      end

      # Makes +control_inputs+ (a single value or an array) the active control
      # inputs while the block runs. The previous value is restored afterwards,
      # so nested blocks no longer wipe the outer block's dependencies.
      def control_dependencies(control_inputs)
        previous = @control_inputs
        @control_inputs = Array(control_inputs)
        begin
          yield self
        ensure
          @control_inputs = previous
        end
      end

      # Looks up the registered OpDef for +op_type+.
      # @return [OpDef]
      def op_def(op_type)
        buffer_ptr = FFI.TF_NewBuffer
        # Wrap immediately so the ensure clause can always free the native
        # buffer, even when the status check below raises.
        buffer = FFI::Buffer.new(buffer_ptr)
        Status.check do |status|
          FFI.TF_GraphGetOpDef(self, op_type, buffer_ptr, status)
        end
        OpDef.decode(buffer[:data].read_string(buffer[:length]))
      ensure
        FFI.TF_DeleteBuffer(buffer) if buffer
      end

      # Returns a Set containing +operation+ plus every operation reachable
      # by following consumers (downstream edges), in depth-first pre-order.
      def forward(operation)
        reachable_operations(operation) do |current|
          current.consumers.map(&:operation)
        end
      end

      # Returns a Set containing +operation+ plus every operation reachable
      # by following inputs (upstream edges), in depth-first pre-order.
      def backward(operation)
        reachable_operations(operation) do |current|
          current.inputs.map(&:operation)
        end
      end

      # Yields every operation in the graph; returns an Enumerator without a
      # block.
      def operations
        return enum_for(:operations) unless block_given?

        # TF_GraphNextOperation advances this zero-initialised size_t cursor.
        position_ptr = ::FFI::MemoryPointer.new(:size_t, 1, true)
        while (ptr = FFI.TF_GraphNextOperation(self, position_ptr))
          break if ptr.null?
          yield Operation.new(self, ptr)
        end
      end

      # @return [Operation, nil] the operation called +name+, or nil if absent
      def operation(name)
        ptr = FFI.TF_GraphOperationByName(self, name)
        ptr.null? ? nil : Operation.new(self, ptr)
      end

      # Builds and saves a new operation of +op_type+ in this graph.
      def create_operation(op_type, inputs=[], attrs={})
        op_desc = OperationDescription.new(self, op_type, inputs, attrs)
        op_desc.save
      end

      # Runs +operations+ in a throwaway Session and returns the result. The
      # session is closed even when the run raises.
      def execute(operations, feed_dict={})
        session = Session.new(self, SessionOptions.new)
        session.run(operations, feed_dict)
      ensure
        session.close if session
      end

      # Returns one dims array per output of +operation+. An unknown rank
      # (num_dims == -1) is reported as [], the same value a scalar produces.
      def output_shapes(operation)
        operation.outputs.map do |output|
          num_dims = Status.check do |status|
            FFI.TF_GraphGetTensorNumDims(self, output, status)
          end

          # Unknown rank (-1) and rank 0 both yield []; skip the zero-size
          # allocation and FFI round trip.
          next [] if num_dims <= 0

          dims_ptr = ::FFI::MemoryPointer.new(:int64, num_dims)
          Status.check do |status|
            FFI.TF_GraphGetTensorShape(self, output, dims_ptr, num_dims, status)
          end
          dims_ptr.read_array_of_int64(num_dims)
        end
      end

      # Sets the static shape of output +index+ (default 0, the previous
      # hard-coded value) of +operation+ to +shape+.
      def tensor_set_shape(operation, shape, index=0)
        dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.length)
        dims_ptr.write_array_of_int64(shape)
        output = FFI::Output.new
        output[:oper] = operation
        output[:index] = index
        Status.check do |status|
          FFI.TF_GraphSetTensorShape(self, output, dims_ptr, shape.length, status)
        end
      end

      # Copies +function+ (and optional +gradient+) into this graph.
      def add_function(function, gradient=nil)
        Status.check do |status|
          FFI.TF_GraphCopyFunction(self, function, gradient, status)
        end
      end

      # Converts part of this graph into a Function. The inputs are every
      # output of +input_operations+ and the outputs are every output of
      # +output_operations+; either list may be nil (treated as empty).
      # @raise [ArgumentError] if +output_names+ is given with the wrong length
      def to_function(name, operators, input_operations, output_operations, output_names=nil)
        input_operations ||= []
        output_operations ||= []

        inputs = input_operations.flat_map(&:outputs)
        inputs_ptr = FFI::Output.array_to_ptr(inputs.map(&:output))

        outputs = output_operations.flat_map(&:outputs)
        outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))

        if output_names && output_names.length != outputs.length
          raise(ArgumentError, "output_names length must equal outputs length or be nil")
        end

        # Keep the per-name string pointers in a local so they are not GC'ed
        # before the native call below reads them.
        output_name_ptrs = nil
        output_names_ptr = nil
        if output_names
          output_name_ptrs = output_names.map do |output_name|
            ::FFI::MemoryPointer.from_string(output_name)
          end
          output_names_ptr = ::FFI::MemoryPointer.new(:pointer, output_name_ptrs.length, true)
          output_names_ptr.write_array_of_pointer(output_name_ptrs)
        end

        append_hash_to_fn_name = 0
        options = nil
        description = nil

        func = Status.check do |status|
          FFI.TF_GraphToFunction(self, name, append_hash_to_fn_name,
                                 operators ? operators.length : -1, operators,
                                 inputs.length, inputs_ptr,
                                 outputs.length, outputs_ptr,
                                 output_names_ptr,
                                 options, description, status)
        end

        output_types = output_operations.flat_map(&:output_types)
        output_shapes = output_operations.flat_map(&:output_shapes)
        Function.new(func, output_types, output_shapes)
      end

      # Serializes this graph and decodes it as a GraphDef message.
      # @return [GraphDef]
      def as_graph_def
        buffer_ptr = FFI.TF_NewBuffer
        # Wrap immediately so the native buffer is freed even on error.
        buffer = FFI::Buffer.new(buffer_ptr)
        Status.check do |status|
          FFI.TF_GraphToGraphDef(self, buffer_ptr, status)
        end
        GraphDef.decode(buffer[:data].read_string(buffer[:length]))
      ensure
        FFI.TF_DeleteBuffer(buffer) if buffer
      end

      # Imports +graph_def+ (a GraphDef message or already-encoded bytes) into
      # this graph using +options+ (a fresh GraphDefOptions by default).
      def import(graph_def, options=nil)
        options ||= GraphDefOptions.new

        data = graph_def.is_a?(GraphDef) ? GraphDef.encode(graph_def) : graph_def

        # +data_ptr+ must stay referenced until the native import returns.
        data_ptr = ::FFI::MemoryPointer.new(:char, data.bytesize)
        data_ptr.put_bytes(0, data)

        buffer = FFI::Buffer.new
        buffer[:data] = data_ptr
        buffer[:length] = data.bytesize

        Status.check do |status|
          FFI.TF_GraphImportGraphDef(self, buffer, options, status)
        end
      end

      private

      # Iterative depth-first walk starting at +start+. The block returns the
      # neighbouring operations of the operation it is given. Each operation
      # is expanded once, so shared sub-graphs are not re-walked and cycles
      # terminate. Returns a Set in pre-order visit order.
      def reachable_operations(start)
        visited = Set.new
        stack = [start]
        until stack.empty?
          current = stack.pop
          next unless visited.add?(current)
          # Push in reverse so neighbours are visited in their natural order.
          yield(current).reverse_each { |neighbour| stack.push(neighbour) }
        end
        visited
      end
    end
  end
end
@@ -0,0 +1,24 @@
1
module Tensorflow
  module Graph
    # Owns a native TF_ImportGraphDefOptions handle; Graph#import passes an
    # instance of this class to the native import call.
    class GraphDefOptions
      class << self
        # Finalizer that deletes the native options handle. Kept as a proc
        # (not a lambda) because finalizers receive the object id as an
        # argument, which a proc ignores.
        def finalize(pointer)
          proc { FFI.TF_DeleteImportGraphDefOptions(pointer) }
        end
      end

      def initialize
        handle = FFI.TF_NewImportGraphDefOptions
        @pointer = handle
        ObjectSpace.define_finalizer(self, self.class.finalize(handle))
      end

      def to_ptr
        @pointer
      end

      # Sets the name prefix applied to imported nodes.
      def prefix=(value)
        FFI.TF_ImportGraphDefOptionsSetPrefix(self, value)
      end
    end
  end
end
@@ -0,0 +1,50 @@
1
module Tensorflow
  module Graph
    # Standard collection names, mirroring TensorFlow's Python GraphKeys.
    # All values are frozen so callers cannot accidentally mutate the shared
    # constants.
    class GraphKeys
      GLOBAL_VARIABLES = "variables".freeze
      LOCAL_VARIABLES = "local_variables".freeze
      METRIC_VARIABLES = "metric_variables".freeze
      MODEL_VARIABLES = "model_variables".freeze
      TRAINABLE_VARIABLES = "trainable_variables".freeze
      SUMMARIES = "summaries".freeze
      QUEUE_RUNNERS = "queue_runners".freeze
      TABLE_INITIALIZERS = "table_initializer".freeze
      ASSET_FILEPATHS = "asset_filepaths".freeze
      MOVING_AVERAGE_VARIABLES = "moving_average_variables".freeze
      REGULARIZATION_LOSSES = "regularization_losses".freeze
      CONCATENATED_VARIABLES = "concatenated_variables".freeze
      SAVERS = "savers".freeze
      WEIGHTS = "weights".freeze
      BIASES = "biases".freeze
      ACTIVATIONS = "activations".freeze
      UPDATE_OPS = "update_ops".freeze
      LOSSES = "losses".freeze
      SAVEABLE_OBJECTS = "saveable_objects".freeze
      RESOURCES = "resources".freeze
      LOCAL_RESOURCES = "local_resources".freeze
      TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables".freeze
      INIT_OP = "init_op".freeze
      LOCAL_INIT_OP = "local_init_op".freeze
      READY_OP = "ready_op".freeze
      READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op".freeze
      SUMMARY_OP = "summary_op".freeze
      GLOBAL_STEP = "global_step".freeze
      EVAL_STEP = "eval_step".freeze
      COND_CONTEXT = "cond_context".freeze
      WHILE_CONTEXT = "while_context".freeze
      SUMMARY_COLLECTION = "_SUMMARY_V2".freeze

      # Every collection that holds variables. Previously this was spelled
      # `_VARIABLE_COLLECTIONS`; a leading underscore makes it a local
      # variable of the class body rather than a constant, so it was
      # silently discarded and unreachable.
      VARIABLE_COLLECTIONS = [
        GLOBAL_VARIABLES,
        LOCAL_VARIABLES,
        METRIC_VARIABLES,
        MODEL_VARIABLES,
        TRAINABLE_VARIABLES,
        MOVING_AVERAGE_VARIABLES,
        CONCATENATED_VARIABLES,
        TRAINABLE_RESOURCE_VARIABLES,
      ].freeze

      # Same underscore problem as above (was `_STREAMING_MODEL_PORTS`).
      STREAMING_MODEL_PORTS = "streaming_model_ports".freeze
    end
  end
end
50
+
@@ -0,0 +1,176 @@
1
module Tensorflow
  module Graph
    # Wrapper around a native TF_Operation belonging to +graph+. Operations
    # compare and hash by name.
    class Operation
      include Operators
      attr_reader :graph

      def initialize(graph, pointer)
        @graph = graph
        @pointer = pointer
      end

      def to_ptr
        @pointer
      end

      # Name-based equality. Objects that are not Operations are never equal;
      # previously they raised NoMethodError (e.g. `op == nil`).
      def eql?(other)
        other.is_a?(Operation) && name.eql?(other.name)
      end

      def ==(other)
        other.is_a?(Operation) && name == other.name
      end

      def hash
        name.hash
      end

      def name
        FFI.TF_OperationName(self)
      end

      def op_type
        FFI.TF_OperationOpType(self)
      end

      # @return [OpDef] the registered definition for this operation's type
      def op_def
        graph.op_def(op_type)
      end

      def device
        FFI.TF_OperationDevice(self)
      end

      # Serializes this operation and decodes it as a NodeDef message.
      # @return [NodeDef]
      def node_def
        buffer_ptr = FFI.TF_NewBuffer
        # Wrap immediately so the native buffer is freed even when the status
        # check raises (previously it leaked in that case).
        buffer = FFI::Buffer.new(buffer_ptr)
        Status.check do |status|
          FFI.TF_OperationToNodeDef(self, buffer_ptr, status)
        end
        NodeDef.decode(buffer[:data].read_string(buffer[:length]))
      ensure
        FFI.TF_DeleteBuffer(buffer) if buffer
      end

      def num_inputs
        FFI.TF_OperationNumInputs(self)
      end

      # @return [Array<OperationOutput>] the outputs feeding each input slot
      def inputs
        count = num_inputs
        pointer = ::FFI::MemoryPointer.new(FFI::Output, count)
        FFI.TF_OperationAllInputs(self, pointer, count)
        Array.new(count) do |index|
          OperationOutput.from_graph(graph, pointer[index])
        end
      end

      def num_control_outputs
        FFI.TF_OperationNumControlOutputs(self)
      end

      # @return [Array<Operation>] operations with a control edge from this one
      def control_outputs
        count = num_control_outputs
        pointer = ::FFI::MemoryPointer.new(:pointer, count)
        FFI.TF_OperationGetControlOutputs(self, pointer, count)
        Array.new(count) do |index|
          self.class.new(graph, pointer[index].read_pointer)
        end
      end

      def num_outputs
        FFI.TF_OperationNumOutputs(self)
      end

      # @return [Array<OperationOutput>] one entry per output index
      def outputs
        Array.new(num_outputs) do |index|
          OperationOutput.from_index(self, index)
        end
      end

      def [](index)
        outputs[index]
      end

      def output_types
        outputs.map do |output|
          FFI.TF_OperationOutputType(output)
        end
      end

      def output_shapes
        graph.output_shapes(self)
      end

      # Shape of the first output.
      def shape
        output_shapes.first
      end

      # Type of the first output.
      def dtype
        output_types.first
      end

      def output_list_length(arg_name)
        Status.check do |status|
          FFI.TF_OperationOutputListLength(self, arg_name, status)
        end
      end

      def num_control_inputs
        FFI.TF_OperationNumControlInputs(self)
      end

      # @return [Array<Operation>] operations this one has control edges from
      def control_inputs
        count = num_control_inputs
        pointer = ::FFI::MemoryPointer.new(:pointer, count)
        FFI.TF_OperationGetControlInputs(self, pointer, count)
        Array.new(count) do |index|
          self.class.new(graph, pointer[index].read_pointer)
        end
      end

      # @return [Array<OperationAttr>] one entry per attr declared in the OpDef
      def attributes
        op_def.attr.map do |attr_def|
          attr(attr_def.name)
        end
      end

      # @return [OperationAttr] metadata wrapper for attribute +attr_name+
      def attr(attr_name)
        metadata = Status.check do |status|
          FFI.TF_OperationGetAttrMetadata(self, attr_name, status)
        end

        OperationAttr.new(self, attr_name, metadata)
      end

      # Returns the consumers of a single +output+ of this operation.
      def output_consumers(output)
        count = FFI.TF_OperationOutputNumConsumers(output)

        consumers_ptr = ::FFI::MemoryPointer.new(FFI::Input, count)
        FFI.TF_OperationOutputConsumers(output, consumers_ptr, count)

        Array.new(count) do |index|
          OperationOutput.from_graph(graph, consumers_ptr[index])
        end
      end

      # Consumers of all outputs, concatenated in output order.
      def consumers
        outputs.flat_map do |output|
          output_consumers(output)
        end
      end

      # Human-readable summary. Shapes and types are fetched once up front
      # instead of once per output (previously O(n^2) FFI calls).
      def to_s
        shapes = output_shapes
        types = output_types
        parts = [op_type, "name=#{name}"]
        num_outputs.times do |index|
          parts << "#{index}:(shape=#{shapes[index]}, dtype=#{types[index]})"
        end
        parts.join(', ')
      end
    end
  end
end