tensorflow-ruby 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
@@ -0,0 +1,38 @@
|
|
1
|
+
# Do NOT edit this file. It is generated by `rake generate_ops`
|
2
|
+
|
3
|
+
module Tensorflow
  module RawOps
    # Creates an operation of type +op_type+ in the current execution context.
    #
    # The context is obtained from ExecutionContext.current(inputs). Attribute
    # entries whose value is nil are removed (Hash#compact) before the
    # operation is created.
    #
    # Returns the created operation when the context is a Graph::Graph;
    # otherwise returns the result of context.execute(operation).
    def self.execute(op_type, inputs=[], attrs={})
      context = ExecutionContext.current(inputs)
      attrs = attrs.compact
      operation = context.create_operation(op_type, inputs, attrs)
      if context.is_a?(Graph::Graph)
        operation
      else
        context.execute(operation)
      end
    end

  <% Tensorflow.op_defs.values.sort_by(&:name).each do |op_def|
    # One wrapper method is emitted per registered op definition, in name order.
    # The Ruby method name is the underscored, lower-cased op name.
    name = RawOpHelper.underscore(op_def.name).downcase

    # Positional parameters of the wrapper: one per op input.
    # NOTE(review): check_name presumably rewrites names that are not usable
    # as Ruby identifiers; confirm against RawOpHelper.
    input_names = op_def.input_arg.map do |input|
      RawOpHelper.check_name(input.name)
    end

    # Keyword parameters of the wrapper: one per op attribute, plus a `name`
    # keyword that defaults to the op's own name.
    attributes_in = op_def.attr.map do |attr_def|
      RawOpHelper.process_attribute(attr_def)
    end
    attributes_in << "name: \"#{op_def.name}\""

    # Attribute hash passed on to execute: original attribute name => the
    # Ruby-side name returned by check_attribute_name, plus `name`.
    attributes_out = op_def.attr.map do |attr_def|
      attribute_name = RawOpHelper.check_attribute_name(attr_def)
      "#{attr_def.name}: #{attribute_name}"
    end
    attributes_out << "name: name" %>

    def self.<%= name %>(<%= (input_names + attributes_in).join(', ') %>)
      self.execute("<%= op_def.name %>", [<%= input_names.join(', ') %>], <%= attributes_out.join(', ') %>)
    end<% end %>
  end
end
|
@@ -0,0 +1,80 @@
|
|
1
|
+
<%# Renders every operation of `graph` as a text-format `node { ... }` entry, followed by a `versions` block. %>
<%# `graph` is resolved in the scope in which the compiled template is evaluated. %>
<% graph.operations.each do |operation|
     op_def = self.graph.op_def(operation.op_type) %>
node {
  name: "<%= operation.name %>"
  op: "<%= operation.op_type %>"
<%# Data inputs: an index of 0 is left implicit, any other index is appended as ":index". %>
<% operation.inputs.each do |input| %>
  input: "<%= input.operation.name %><%= input.index > 0 ? ":#{input.index}" : "" %>"
<% end %>
<%# Control inputs are written with a leading "^". %>
<% operation.control_inputs.each do |control_input| %>
  input: "^<%= control_input.name %>"
<% end %>
<%# Attributes are emitted in name order; the branch taken depends on attr.metadata[:type]. %>
<%# Attribute types without a matching `when` below produce an empty value block. %>
<% op_def.attr.sort_by(&:name).each do |attr_def|
     attr = operation.attr(attr_def.name) %>
  attr {
    key: "<%= attr_def.name %>"
    value {
<% case attr.metadata[:type]
   when :tensor %>
      tensor {
        dtype: DT_<%= attr.tensor.dtype.upcase %>
        tensor_shape {
<% attr.tensor.shape.each do |dim| %>
          dim {
            size: <%= dim %>
          }
<% end %>
        }
        <%= attr.tensor.dtype %>_val: <%= attr.tensor.value.is_a?(Numo::NArray) ? attr.tensor.value.to_a : attr.tensor.value %>
        tensor_content: <%= attr.proto.tensor.tensor_content.dump %>
      }
<% when :type %>
<% if attr.list? %>
      list {
<% attr.value.each do |list_value| %>
        type: DT_<%= list_value.to_s.upcase %>
<% end %>
      }
<% else %>
      type: DT_<%= attr.value.to_s.upcase %>
<% end %>
<% when :bool %>
      b: <%= attr.value ? 'true' : 'false' %>
<% when :int %>
      i: <%= attr.value %>
<% when :shape %>
<% if attr.list? %>
      list {
<% attr.value.each do |list_value| %>
        shape {
<% list_value.each do |sub_list_value| %>
          dim {
            size: <%= sub_list_value %>
          }
<% end %>
        }
<% end %>
      }
<% elsif attr.value.empty? %>
      shape {
      }
<% else %>
<%# NOTE(review): a single non-empty shape is written as a comma-joined list rather than dim entries as in the list case above; confirm this is intended. %>
      shape {
        <%= attr.value.join(', ') %>
      }
<% end %>
<% when :string %>
      s: "<%= attr.value %>"
<% when :func %>
      func {
        name: "<%= attr.value %>"
      }
<% end %>
    }
  }
<% end %>
}
<% end %>
<%# The producer version below is a hard-coded literal, not read from the graph. %>
versions {
  producer: 119
}
|
@@ -0,0 +1,26 @@
|
|
1
|
+
require 'erubi'
|
2
|
+
|
3
|
+
module Tensorflow
  module Printers
    # Writes a textual rendering of a graph using the `graph.erb` template
    # that lives next to this file.
    class Graph
      attr_reader :graph

      # graph - the object exposed to the template through the `graph` reader.
      def initialize(graph)
        @graph = graph
      end

      # Returns the raw template source, read once (binary mode) and memoized.
      def template
        @template ||= File.read(File.join(__dir__, 'graph.erb'), :mode => 'rb')
      end

      # Compiles the template with Erubi, evaluates the generated Ruby in the
      # context of this printer (so the template can call `graph`) and appends
      # the result to io_stream.
      #
      # An earlier ERB-based variant was:
      #   io_stream << ERB.new(self.template, nil, trim_mode: "<>").result_with_hash(:graph => self.graph)
      def print(io_stream=STDOUT)
        engine = Erubi::Engine.new(template, filename: 'graph.erb')
        io_stream << eval(engine.src)
      end
    end
  end
end
|
@@ -0,0 +1,109 @@
|
|
1
|
+
<%# Renders a GraphDef protobuf (`graph_def`) in text format: one `node` block per NodeDef, then the function library signatures, then the version info. %>
<%# `graph_def` is resolved in the scope in which the compiled template is evaluated. %>
<% graph_def.node.each do |node_def| %>
node {
  name: "<%= node_def.name %>"
  op: "<%= node_def.op %>"
  device: "<%= node_def.device %>"
<% node_def.input.each do |input| %>
  input: "<%= input %>"
<% end %>
<%# Attributes are emitted in key order. `attr_def.value` is the AttrValue oneof case, i.e. the name of the field that is set (:s, :i, :b, :type, :shape, :tensor, :list, :func). %>
<% node_def.attr.sort_by {|key, value| key}.each do |key, attr_def| %>
  attr {
    key: "<%= key %>"
    value {
<% case attr_def.value
   when :tensor %>
      tensor {
        dtype: <%= attr_def.tensor.dtype %>
        tensor_shape {
<% attr_def.tensor.tensor_shape.dim.each do |dim| %>
          dim {
            size: <%= dim.size %>
          }
<% end %>
        }
<%# The value field name is derived from the dtype enum name, e.g. DT_FLOAT gives float_val. %>
<%# NOTE(review): dtypes stored in a differently named field (e.g. DT_INT32 uses int_val) are not handled by this derivation; confirm and extend if such tensors must be printed. %>
<% value_method = "#{attr_def.tensor.dtype[3..-1].downcase}_val" %>
        <%= value_method %>: <%= attr_def.tensor.send(value_method.to_sym) %>
      }
<% when :type %>
      type: <%= attr_def.type %>
<% when :list %>
      list {
<% case %>
<% when !attr_def.list.type.empty? %>
<% attr_def.list.type.each do |type| %>
        type: <%= type %>
<% end %>
<% when !attr_def.list.shape.empty? %>
<% attr_def.list.shape.each do |shape| %>
        shape {
<% shape.dim.each do |dim| %>
          dim {
            size: <%= dim.size %>
          }
<% end %>
        }
<% end %>
<% end %>
      }
<% when :b %>
      b: <%= attr_def.b %>
<%# The integer field of AttrValue is named `i`, so the oneof case is :i. %>
<% when :i %>
      i: <%= attr_def.i %>
<% when :shape %>
      shape {
<% attr_def.shape.dim.each do |dim| %>
        dim {
          size: <%= dim.size %>
        }
<% end %>
<% if attr_def.shape.unknown_rank %>
        unknown_rank: true
<% end %>
      }
<% when :s %>
      s: "<%= attr_def.s %>"
<% when :func %>
      func {
        name: "<%= attr_def.func.name %>"
      }
<% end %>
    }
  }
<% end %>
}
<% end %>
library {
<% graph_def.library.function.each do |function_def| %>
  function {
    signature {
      name: "<%= function_def.signature.name %>"
<% function_def.signature.input_arg.each do |arg_def| %>
      input_arg {
        name: "<%= arg_def.name %>"
        type: <%= arg_def.type %>
      }
<% end %>
<% function_def.signature.output_arg.each do |arg_def| %>
      output_arg {
        name: "<%= arg_def.name %>"
        type: <%= arg_def.type %>
      }
<% end %>
    }
  }
<% end %>
}
versions {
  producer: <%= graph_def.versions.producer %>
  min_consumer: <%= graph_def.versions.min_consumer %>
}
|
@@ -0,0 +1,26 @@
|
|
1
|
+
require 'erubi'
|
2
|
+
|
3
|
+
module Tensorflow
  module Printers
    # Writes a textual rendering of a graph definition using the
    # `graph_def.erb` template that lives next to this file.
    class GraphDef
      attr_reader :graph_def

      # graph_def - the object exposed to the template through the
      # `graph_def` reader.
      def initialize(graph_def)
        @graph_def = graph_def
      end

      # Returns the raw template source, read once (binary mode) and memoized.
      def template
        @template ||= File.read(File.join(__dir__, 'graph_def.erb'), :mode => 'rb')
      end

      # Compiles the template with Erubi, evaluates the generated Ruby in the
      # context of this printer (so the template can call `graph_def`) and
      # appends the result to io_stream.
      #
      # An earlier ERB-based variant was:
      #   io_stream << ERB.new(self.template, nil, trim_mode: "<>").result_with_hash(:graph_def => self.graph_def)
      def print(io_stream=STDOUT)
        engine = Erubi::Engine.new(template, filename: 'graph_def.erb')
        io_stream << eval(engine.src)
      end
    end
  end
end
|
@@ -0,0 +1,55 @@
|
|
1
|
+
module Tensorflow
  # Defines one module-level method per symbol in FFI::DataType.symbols;
  # each method simply returns its own symbol.
  FFI::DataType.symbols.each do |type_symbol|
    define_singleton_method(type_symbol) { type_symbol }
  end

  # Alias-style helper: float32 maps to the :float data type symbol.
  def self.float32
    :float
  end

  # Alias-style helper: float64 maps to the :double data type symbol.
  def self.float64
    :double
  end

  # Helpers named after their Python TensorFlow counterparts.
  module PythonCompatability
    # Switches the receiver's execution mode to graph mode.
    def disable_eager_execution
      self.execution_mode = Tensorflow::GRAPH_MODE
    end

    # Switches the receiver's execution mode to eager mode.
    def enable_eager_execution
      self.execution_mode = Tensorflow::EAGER_MODE
    end

    # Returns an empty array in eager mode; otherwise the GLOBAL_VARIABLES
    # collection reference of the current execution context.
    def global_variables
      return [] if ExecutionContext.eager?

      ExecutionContext.current.get_collection_ref(Graph::GraphKeys::GLOBAL_VARIABLES)
    end

    # Returns a no-op in eager mode. Otherwise returns an initializer for all
    # collected global variables, or nil when the collection is empty.
    def global_variables_initializer
      return RawOps.no_op if ExecutionContext.eager?

      collected = Array(ExecutionContext.current.get_collection_ref(Graph::GraphKeys::GLOBAL_VARIABLES))
      self.variables_initializer(collected) unless collected.empty?
    end

    # Returns a no-op in eager mode; otherwise groups the initializers of the
    # given variables through Control.group. The name keyword is accepted but
    # not used by this method.
    def variables_initializer(variables, name: 'init')
      return RawOps.no_op if ExecutionContext.eager?

      Control.group(variables.map(&:initializer))
    end
  end
end

Tf = Tensorflow
|
@@ -0,0 +1,78 @@
|
|
1
|
+
module Tensorflow
  # Wraps a summary-writer resource created through RawOps.summary_writer and
  # exposes one helper per kind of summary. Most helpers add the op result to
  # the SUMMARY_COLLECTION collection of the current execution context and
  # return it.
  class ResourceSummaryWriter
    # step has a custom writer below, so only the readers are generated here.
    attr_reader :step, :initializer

    # Creates the writer resource and yields it to the given block; the
    # block's return value is stored as the initializer.
    #
    # shared_name, container - forwarded to RawOps.summary_writer.
    # A block is required (the resource is yielded unconditionally).
    def initialize(shared_name: "", container: "")
      self.step = 1
      @resource = RawOps.summary_writer(shared_name: shared_name, container: container)
      @initializer = yield @resource
    end

    # Builds a SummaryMetadata message for the 'scalars' plugin.
    #
    # display_name - value for the metadata's display_name field.
    # description  - value for the metadata's summary_description field.
    #
    # Returns the populated SummaryMetadata.
    def create_summary_metadata(display_name, description)
      # PluginData is a constant nested in SummaryMetadata, so it must be
      # reached with `::`; it is populated before being attached.
      plugin_data = SummaryMetadata::PluginData.new
      plugin_data.plugin_name = 'scalars'

      metadata = SummaryMetadata.new
      metadata.display_name = display_name
      metadata.summary_description = description
      metadata.plugin_data = plugin_data
      metadata
    end

    # Sets the step used by subsequent summaries. A Variable is stored as is;
    # any other value is wrapped in an int64 Tensor.
    def step=(value)
      @step = value.is_a?(Variable) ? value : Tensor.new(value, dtype: :int64)
    end

    # Writes an audio summary. The tensor is converted to a float Tensor first.
    def audio(tag, tensor, sample_rate, max_outputs: 3)
      tensor = Tensor.from_value(tensor, dtype: :float)
      result = RawOps.write_audio_summary(@resource, self.step, tag, tensor, sample_rate, max_outputs: max_outputs)
      ExecutionContext.current.add_to_collection(Graph::GraphKeys::SUMMARY_COLLECTION, result)
      result
    end

    # Writes the graph definition of the given graph. The result is returned
    # without being added to the summary collection.
    def graph(graph)
      RawOps.write_graph_summary(@resource, self.step, graph.as_graph_def)
    end

    # Writes a histogram summary of values under tag.
    def histogram(tag, values)
      result = RawOps.write_histogram_summary(@resource, self.step, tag, values)
      ExecutionContext.current.add_to_collection(Graph::GraphKeys::SUMMARY_COLLECTION, result)
      result
    end

    # Writes an image summary. When bad_color is not given, a uint8 tensor
    # [255, 0, 0, 255] is used.
    def image(tag, tensor, bad_color=nil)
      bad_color ||= Tensor.new([255, 0, 0, 255], dtype: :uint8)
      result = RawOps.write_image_summary(@resource, self.step, tag, tensor, bad_color)
      ExecutionContext.current.add_to_collection(Graph::GraphKeys::SUMMARY_COLLECTION, result)
      result
    end

    # Writes a raw proto summary. The tag argument is accepted but is not
    # passed to the underlying op.
    def proto(tag, tensor)
      result = RawOps.write_raw_proto_summary(@resource, self.step, tensor)
      ExecutionContext.current.add_to_collection(Graph::GraphKeys::SUMMARY_COLLECTION, result)
      result
    end

    # Writes a scalar summary. dtype is forwarded as the op's typeT attribute.
    def scalar(tag, value, dtype: nil)
      result = RawOps.write_scalar_summary(@resource, self.step, tag, value, typeT: dtype)
      ExecutionContext.current.add_to_collection(Graph::GraphKeys::SUMMARY_COLLECTION, result)
      result
    end

    # Writes a generic summary. The value is wrapped in a Tensor and that
    # tensor's dtype is forwarded as the op's typeT attribute.
    #
    # metadata - serialized summary metadata; defaults to an empty binary string.
    def write(tag, value, metadata: "".b)
      value = Tensor.new(value)
      dtype = value.dtype

      result = RawOps.write_summary(@resource, self.step, value, tag, metadata, typeT: dtype)
      ExecutionContext.current.add_to_collection(Graph::GraphKeys::SUMMARY_COLLECTION, result)
      result
    end
    alias :generic :write

    # Issues a flush op for the writer resource.
    def flush
      RawOps.flush_summary_writer(@resource)
    end

    # Issues a close op for the writer resource.
    def close
      RawOps.close_summary_writer(@resource)
    end
  end
end
|
@@ -0,0 +1,49 @@
|
|
1
|
+
# Status holds error information returned by TensorFlow. We
|
2
|
+
# can use status to get error and even display the error messages from tensorflow.
|
3
|
+
module Tensorflow
  # Owns a native status pointer and turns non-OK status codes into Ruby
  # exceptions from the Tensorflow::Error namespace.
  class Status
    # Builds the finalizer that deletes the native status. It closes over the
    # pointer only. It must stay a proc (not a lambda) because finalizers are
    # invoked with an argument.
    def self.finalize(pointer)
      proc { FFI.TF_DeleteStatus(pointer) }
    end

    # Yields a fresh Status to the block, raises if the block left a non-OK
    # code in it, and otherwise returns the block's value.
    def self.check
      status = Status.new
      block_result = yield status
      status.check
      block_result
    end

    # Allocates the native status and registers its finalizer.
    def initialize
      @pointer = FFI.TF_NewStatus
      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # Returns the underlying native pointer.
    def to_ptr
      @pointer
    end

    # Returns the status code reported by TF_GetCode.
    def code
      FFI.TF_GetCode(self)
    end

    # Returns the status message reported by TF_Message.
    def message
      FFI.TF_Message(self)
    end

    # Stores the given code and message in the native status.
    def set(code, message)
      FFI.TF_SetStatus(self, code, message)
    end

    # Returns nil when the code is :tf_ok. Otherwise raises the error class
    # whose name is the camel-cased code (without its first three characters)
    # followed by "Error", using the status message.
    def check
      return if code == :tf_ok

      camel_case = code[3..-1].capitalize.gsub(/(?:_|(\/))([a-z\d]*)/i) do
        "#{Regexp.last_match(1)}#{Regexp.last_match(2).capitalize}"
      end
      raise(Tensorflow::Error.const_get("#{camel_case}Error"), message)
    end
  end
end
|