tensorflow-ruby 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
@@ -0,0 +1,252 @@
|
|
1
|
+
module Tensorflow
  module Graph
    # Ruby wrapper around a native TF_Graph handle.
    #
    # Besides owning the native graph, it keeps Ruby-side state that the C API
    # does not track: named collections, a name scope (delegated to NameScope)
    # and the control inputs active inside a #control_dependencies block.
    class Graph
      attr_reader :control_inputs

      extend Forwardable
      def_delegators :@name_scope, :name_scope, :scoped_name, :unique_name

      # Returns the default graph, creating it on first use (memoized on the class).
      def self.default
        @default ||= Graph.new
      end

      # Replaces the default graph with a fresh, empty one.
      def self.reset_default
        @default = Graph.new
      end

      # Builds the finalizer that frees the native graph. It is created in a
      # class method so the proc closes over only +pointer+ and not the Graph
      # instance (a finalizer that references its object keeps it alive).
      def self.finalize(pointer)
        proc do
          FFI.TF_DeleteGraph(pointer)
        end
      end

      def initialize
        @collections = {}
        @name_scope = NameScope.new
        @pointer = FFI.TF_NewGraph()
        @control_inputs = []
        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      # Lets a Graph be passed directly wherever FFI expects a pointer.
      def to_ptr
        @pointer
      end

      # @return [Array] the names of all collections created so far
      def collections
        @collections.keys
      end

      # Appends +value+ to the collection +name+, creating the collection if needed.
      # @return [Array] the collection after the append
      def add_to_collection(name, value)
        (@collections[name] ||= []) << value
      end

      # Appends +value+ to every collection listed in +names+.
      def add_to_collections(names, value)
        names.each { |name| add_to_collection(name, value) }
      end

      # Returns the live Array backing collection +name+, or nil if it was never created.
      # NOTE(review): +scope+ is accepted for API parity but currently ignored.
      def get_collection_ref(name, scope=nil)
        @collections[name]
      end

      # Empties collection +name+ by replacing it with a new Array.
      def clear_collection(name)
        @collections[name] = []
      end

      # Pushes this graph onto the ExecutionContext for the duration of the block.
      # @raise [Error::InvalidArgumentError] when no block is given
      def as_default
        raise(Error::InvalidArgumentError, "Must provide block") unless block_given?
        ExecutionContext.push(self)
        begin
          yield self
        ensure
          ExecutionContext.pop
        end
      end

      # Makes +control_inputs+ the active control inputs while the block runs.
      # The previous value is restored afterwards, so nested blocks no longer
      # wipe the outer block's dependencies.
      def control_dependencies(control_inputs)
        previous = @control_inputs
        @control_inputs = Array(control_inputs)
        begin
          yield self
        ensure
          @control_inputs = previous
        end
      end

      # Looks up the OpDef protobuf registered for +op_type+.
      # @return [OpDef]
      def op_def(op_type)
        buffer_ptr = FFI.TF_NewBuffer
        # Wrap immediately so the ensure clause can free the buffer even if
        # Status.check raises.
        buffer = FFI::Buffer.new(buffer_ptr)
        Status.check do |status|
          FFI.TF_GraphGetOpDef(self, op_type, buffer_ptr, status)
        end
        OpDef.decode(buffer[:data].read_string(buffer[:length]))
      ensure
        FFI.TF_DeleteBuffer(buffer) if buffer
      end

      # Returns the Set containing +operation+ and every operation reachable by
      # following consumer edges. Safe on graphs with cycles.
      def forward(operation)
        collect_reachable(operation) do |current|
          current.consumers.map(&:operation)
        end
      end

      # Returns the Set containing +operation+ and every operation reachable by
      # following input edges. Safe on graphs with cycles.
      def backward(operation)
        collect_reachable(operation) do |current|
          current.inputs.map(&:operation)
        end
      end

      # Yields each Operation in the graph, or returns an Enumerator without a block.
      def operations
        return enum_for(:operations) unless block_given?

        # Get a pointer to a size_t set to 0 (the C API's iteration cursor).
        position_ptr = ::FFI::MemoryPointer.new(:size_t, 1, true)
        while (ptr = FFI.TF_GraphNextOperation(self, position_ptr))
          break if ptr.null?
          yield Operation.new(self, ptr)
        end
      end

      # @return [Operation, nil] the operation named +name+, or nil if absent
      def operation(name)
        ptr = FFI.TF_GraphOperationByName(self, name)
        ptr.null? ? nil : Operation.new(self, ptr)
      end

      # Builds and adds an operation via OperationDescription.
      def create_operation(op_type, inputs=[], attrs={})
        op_desc = OperationDescription.new(self, op_type, inputs, attrs)
        op_desc.save
      end

      # Runs +operations+ in a throw-away Session and returns the result.
      # The session is closed even if the run raises.
      def execute(operations, feed_dict={})
        session = Session.new(self, SessionOptions.new)
        session.run(operations, feed_dict)
      ensure
        session.close if session
      end

      # Returns one dims Array per output of +operation+. When TF reports -1
      # dims (unknown rank) the entry is [], which is indistinguishable from a
      # scalar shape here.
      def output_shapes(operation)
        operation.outputs.map do |output|
          num_dims = Status.check do |status|
            FFI.TF_GraphGetTensorNumDims(self, output, status)
          end

          if num_dims == -1
            []
          else
            dims_ptr = ::FFI::MemoryPointer.new(:int64, num_dims)
            Status.check do |status|
              FFI.TF_GraphGetTensorShape(self, output, dims_ptr, num_dims, status)
            end
            dims_ptr.read_array_of_int64(num_dims)
          end
        end
      end

      # Sets the static shape of output 0 of +operation+ to +shape+.
      def tensor_set_shape(operation, shape)
        ptr = ::FFI::MemoryPointer.new(:int64, shape.length)
        ptr.write_array_of_int64(shape)
        output = FFI::Output.new
        output[:oper] = operation
        output[:index] = 0
        Status.check do |status|
          FFI.TF_GraphSetTensorShape(self, output, ptr, shape.length, status)
        end
      end

      # Copies +function+ (and optionally its +gradient+) into this graph.
      def add_function(function, gradient=nil)
        Status.check do |status|
          FFI.TF_GraphCopyFunction(self, function, gradient, status)
        end
      end

      # Converts (part of) this graph into a Function.
      #
      # @param operators passed through to TF_GraphToFunction; nil means -1
      #   (presumably "use the whole graph" — confirm against the C API)
      # @param input_operations [Array<Operation>, nil] their outputs become inputs
      # @param output_operations [Array<Operation>, nil] their outputs become outputs
      # @param output_names [Array<String>, nil] must match the number of outputs
      # @raise [ArgumentError] if output_names has the wrong length
      def to_function(name, operators, input_operations, output_operations, output_names=nil)
        # Treat nil consistently as "no operations" everywhere below.
        input_operations ||= []
        output_operations ||= []

        inputs = input_operations.flat_map(&:outputs)
        inputs_ptr = FFI::Output.array_to_ptr(inputs.map(&:output))

        outputs = output_operations.flat_map(&:outputs)
        outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))

        if output_names && output_names.length != outputs.length
          raise(ArgumentError, "output_names length must equal outputs length or be nil")
        end

        # Keep the per-name string pointers in a local so they are not GC'ed
        # before TF_GraphToFunction has read them.
        output_name_ptrs = nil
        output_names_ptr = nil
        if output_names
          output_name_ptrs = output_names.map { |output_name| ::FFI::MemoryPointer.from_string(output_name) }
          output_names_ptr = ::FFI::MemoryPointer.new(:pointer, output_name_ptrs.length, true)
          output_names_ptr.write_array_of_pointer(output_name_ptrs)
        end

        append_hash_to_fn_name = 0
        options = nil
        description = nil

        func = Status.check do |status|
          FFI.TF_GraphToFunction(self, name, append_hash_to_fn_name,
                                 operators ? operators.length : -1, operators,
                                 inputs.length, inputs_ptr,
                                 outputs.length, outputs_ptr,
                                 output_names_ptr,
                                 options, description, status)
        end

        types = output_operations.flat_map(&:output_types)
        shapes = output_operations.flat_map(&:output_shapes)
        Function.new(func, types, shapes)
      end

      # Serializes the whole graph to a GraphDef protobuf message.
      # @return [GraphDef]
      def as_graph_def
        buffer_ptr = FFI.TF_NewBuffer
        # Wrap immediately so the ensure clause can free the buffer even if
        # Status.check raises.
        buffer = FFI::Buffer.new(buffer_ptr)
        Status.check do |status|
          FFI.TF_GraphToGraphDef(self, buffer_ptr, status)
        end
        GraphDef.decode(buffer[:data].read_string(buffer[:length]))
      ensure
        FFI.TF_DeleteBuffer(buffer) if buffer
      end

      # Imports +graph_def+ (a GraphDef message or its serialized bytes) into this graph.
      def import(graph_def, options=nil)
        options ||= GraphDefOptions.new

        data = if graph_def.is_a?(GraphDef)
                 GraphDef.encode(graph_def)
               else
                 graph_def
               end

        ptr = ::FFI::MemoryPointer.new(:char, data.bytesize)
        ptr.put_bytes(0, data)

        buffer = FFI::Buffer.new
        buffer[:data] = ptr
        buffer[:length] = data.bytesize

        Status.check do |status|
          FFI.TF_GraphImportGraphDef(self, buffer, options, status)
        end
      end

      private

      # Iterative depth-first walk from +start+. The block maps an operation to
      # its neighbours. The visited Set prevents revisiting nodes, which bounds
      # the work and terminates on cycles.
      def collect_reachable(start)
        visited = Set[start]
        pending = [start]
        until pending.empty?
          current = pending.pop
          yield(current).each do |neighbour|
            pending.push(neighbour) if visited.add?(neighbour)
          end
        end
        visited
      end
    end
  end
end
|
@@ -0,0 +1,24 @@
|
|
1
|
+
module Tensorflow
  module Graph
    # Owns a native TF_ImportGraphDefOptions handle and frees it when the
    # Ruby object is garbage collected.
    class GraphDefOptions
      class << self
        # Returns the finalizer proc that releases +pointer+. It is built on the
        # class so the closure holds only the raw pointer and not the instance.
        def finalize(pointer)
          proc { FFI.TF_DeleteImportGraphDefOptions(pointer) }
        end
      end

      def initialize
        handle = FFI.TF_NewImportGraphDefOptions()
        @pointer = handle
        ObjectSpace.define_finalizer(self, self.class.finalize(handle))
      end

      # Allows the instance to be handed straight to FFI functions expecting a pointer.
      def to_ptr
        @pointer
      end

      # Sets the prefix option through TF_ImportGraphDefOptionsSetPrefix.
      def prefix=(value)
        FFI.TF_ImportGraphDefOptionsSetPrefix(self, value)
      end
    end
  end
end
|
@@ -0,0 +1,50 @@
|
|
1
|
+
module Tensorflow
  module Graph
    # Well-known collection names (mirroring TensorFlow's Python GraphKeys),
    # used as keys for Graph#add_to_collection and friends.
    class GraphKeys
      GLOBAL_VARIABLES = "variables"
      LOCAL_VARIABLES = "local_variables"
      METRIC_VARIABLES = "metric_variables"
      MODEL_VARIABLES = "model_variables"
      TRAINABLE_VARIABLES = "trainable_variables"
      SUMMARIES = "summaries"
      QUEUE_RUNNERS = "queue_runners"
      TABLE_INITIALIZERS = "table_initializer"
      ASSET_FILEPATHS = "asset_filepaths"
      MOVING_AVERAGE_VARIABLES = "moving_average_variables"
      REGULARIZATION_LOSSES = "regularization_losses"
      CONCATENATED_VARIABLES = "concatenated_variables"
      SAVERS = "savers"
      WEIGHTS = "weights"
      BIASES = "biases"
      ACTIVATIONS = "activations"
      UPDATE_OPS = "update_ops"
      LOSSES = "losses"
      SAVEABLE_OBJECTS = "saveable_objects"
      RESOURCES = "resources"
      LOCAL_RESOURCES = "local_resources"
      TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"
      INIT_OP = "init_op"
      LOCAL_INIT_OP = "local_init_op"
      READY_OP = "ready_op"
      READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
      SUMMARY_OP = "summary_op"
      GLOBAL_STEP = "global_step"
      EVAL_STEP = "eval_step"
      COND_CONTEXT = "cond_context"
      WHILE_CONTEXT = "while_context"
      SUMMARY_COLLECTION = "_SUMMARY_V2"

      # Every collection that holds variables. Ruby constants must start with
      # an uppercase letter; the previous `_VARIABLE_COLLECTIONS` name made this
      # a throw-away local variable in the class body, unreachable from outside.
      VARIABLE_COLLECTIONS = [
        GLOBAL_VARIABLES,
        LOCAL_VARIABLES,
        METRIC_VARIABLES,
        MODEL_VARIABLES,
        TRAINABLE_VARIABLES,
        MOVING_AVERAGE_VARIABLES,
        CONCATENATED_VARIABLES,
        TRAINABLE_RESOURCE_VARIABLES,
      ].freeze

      # Same fix as above: previously `_STREAMING_MODEL_PORTS`, a local variable.
      STREAMING_MODEL_PORTS = "streaming_model_ports"
    end
  end
end
|
50
|
+
|
@@ -0,0 +1,176 @@
|
|
1
|
+
module Tensorflow
  module Graph
    # Ruby view of a single TF_Operation that lives inside a Graph.
    #
    # Instances only hold the owning graph and the raw operation pointer; every
    # accessor queries the C API on demand. Equality and hashing use the
    # operation name only (NOTE(review): operations from different graphs that
    # share a name therefore compare equal).
    class Operation
      include Operators
      attr_reader :graph

      # @param graph [Graph] graph that owns the operation
      # @param pointer TF_Operation handle returned by the C API
      def initialize(graph, pointer)
        @graph = graph
        @pointer = pointer
      end

      # Lets an Operation be passed directly wherever FFI expects a pointer.
      def to_ptr
        @pointer
      end

      # Name-based equality. Objects without a +name+ (e.g. nil) are simply
      # not equal instead of raising NoMethodError.
      def eql?(other)
        other.respond_to?(:name) && name.eql?(other.name)
      end

      def ==(other)
        other.respond_to?(:name) && name == other.name
      end

      def hash
        name.hash
      end

      def name
        FFI.TF_OperationName(self)
      end

      def op_type
        FFI.TF_OperationOpType(self)
      end

      # @return [OpDef] the registered definition for this operation's type
      def op_def
        graph.op_def(op_type)
      end

      def device
        FFI.TF_OperationDevice(self)
      end

      # Serializes this operation to a NodeDef protobuf message.
      # @return [NodeDef]
      def node_def
        buffer_ptr = FFI.TF_NewBuffer
        # Wrap immediately so the ensure clause can free the buffer even if
        # Status.check raises.
        buffer = FFI::Buffer.new(buffer_ptr)
        Status.check do |status|
          FFI.TF_OperationToNodeDef(self, buffer_ptr, status)
        end
        NodeDef.decode(buffer[:data].read_string(buffer[:length]))
      ensure
        FFI.TF_DeleteBuffer(buffer) if buffer
      end

      def num_inputs
        FFI.TF_OperationNumInputs(self)
      end

      # @return [Array<OperationOutput>] the outputs feeding each input of this operation
      def inputs
        count = num_inputs # query once; it was previously re-fetched three times
        pointer = ::FFI::MemoryPointer.new(FFI::Output, count)
        FFI.TF_OperationAllInputs(self, pointer, count)
        Array.new(count) { |index| OperationOutput.from_graph(graph, pointer[index]) }
      end

      def num_control_outputs
        FFI.TF_OperationNumControlOutputs(self)
      end

      # @return [Array<Operation>] operations that have this one as a control input
      def control_outputs
        count = num_control_outputs
        pointer = ::FFI::MemoryPointer.new(:pointer, count)
        FFI.TF_OperationGetControlOutputs(self, pointer, count)
        operations_from(pointer, count)
      end

      def num_outputs
        FFI.TF_OperationNumOutputs(self)
      end

      # @return [Array<OperationOutput>] one entry per output index
      def outputs
        Array.new(num_outputs) { |index| OperationOutput.from_index(self, index) }
      end

      def [](index)
        outputs[index]
      end

      def output_types
        outputs.map { |output| FFI.TF_OperationOutputType(output) }
      end

      def output_shapes
        graph.output_shapes(self)
      end

      # Shape of the first output.
      def shape
        output_shapes.first
      end

      # Type of the first output.
      def dtype
        output_types.first
      end

      def output_list_length(arg_name)
        Status.check do |status|
          FFI.TF_OperationOutputListLength(self, arg_name, status)
        end
      end

      def num_control_inputs
        FFI.TF_OperationNumControlInputs(self)
      end

      # @return [Array<Operation>] operations this one has as control inputs
      def control_inputs
        count = num_control_inputs
        pointer = ::FFI::MemoryPointer.new(:pointer, count)
        FFI.TF_OperationGetControlInputs(self, pointer, count)
        operations_from(pointer, count)
      end

      # @return [Array<OperationAttr>] one entry per attr declared in the OpDef
      def attributes
        op_def.attr.map { |attr_def| attr(attr_def.name) }
      end

      def attr(attr_name)
        metadata = Status.check do |status|
          FFI.TF_OperationGetAttrMetadata(self, attr_name, status)
        end

        OperationAttr.new(self, attr_name, metadata)
      end

      # Returns the consumers of a single +output+ of this operation.
      def output_consumers(output)
        count = FFI.TF_OperationOutputNumConsumers(output)

        consumers_ptr = ::FFI::MemoryPointer.new(FFI::Input, count)
        FFI.TF_OperationOutputConsumers(output, consumers_ptr, count)

        Array.new(count) { |index| OperationOutput.from_graph(graph, consumers_ptr[index]) }
      end

      # Consumers of all outputs of this operation, concatenated in output order.
      def consumers
        outputs.flat_map { |output| output_consumers(output) }
      end

      # Human-readable summary: op type, name, and each output's shape and dtype.
      def to_s
        # Fetch shapes and types once instead of once per output.
        shapes = output_shapes
        types = output_types
        described = shapes.zip(types).each_with_index.map do |(output_shape, output_type), index|
          "#{index}:(shape=#{output_shape}, dtype=#{output_type})"
        end
        [op_type, "name=#{name}", *described].join(', ')
      end

      private

      # Wraps +count+ TF_Operation* entries stored in +pointer+ as Operations of this graph.
      def operations_from(pointer, count)
        Array.new(count) { |index| self.class.new(graph, pointer[index].read_pointer) }
      end
    end
  end
end
|