tensorflow-ruby 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,38 @@
1
# Do NOT edit this file. It is generated by `rake generate_ops`

module Tensorflow
  module RawOps
    # Creates the op `op_type` in the current execution context.
    #
    # inputs - input tensors, passed positionally to the op.
    # attrs  - attribute values keyed by attribute name; entries whose value
    #          is nil are dropped before the operation is created.
    #
    # Returns the (unexecuted) operation when the context is a Graph::Graph;
    # otherwise the context executes the operation and its result is returned.
    def self.execute(op_type, inputs=[], attrs={})
      # The context is chosen from the inputs (see ExecutionContext.current).
      context = ExecutionContext.current(inputs)
      attrs = attrs.compact
      operation = context.create_operation(op_type, inputs, attrs)
      if context.is_a?(Graph::Graph)
        operation
      else
        context.execute(operation)
      end
    end

<% Tensorflow.op_defs.values.sort_by(&:name).each do |op_def|
     # One wrapper method is emitted per registered op, in op-name order.
     # Method name: the op name passed through RawOpHelper.underscore, then downcased.
     name = RawOpHelper.underscore(op_def.name).downcase

     # Positional parameters: one per input_arg, each name passed through
     # RawOpHelper.check_name (NOTE(review): presumably to rename identifiers
     # that are not valid Ruby parameter names - confirm).
     input_names = op_def.input_arg.map do |input|
       RawOpHelper.check_name(input.name)
     end

     # Parameter declarations for each attr as produced by
     # RawOpHelper.process_attribute, plus a trailing `name:` keyword whose
     # default is the op's own name.
     attributes_in = op_def.attr.map do |attr_def|
       RawOpHelper.process_attribute(attr_def)
     end
     attributes_in << "name: \"#{op_def.name}\""

     # Entries of the attrs hash handed to execute: the TensorFlow attr name
     # mapped to the (possibly renamed) Ruby parameter, plus the op name.
     attributes_out = op_def.attr.map do |attr_def|
       attribute_name = RawOpHelper.check_attribute_name(attr_def)
       "#{attr_def.name}: #{attribute_name}"
     end
     attributes_out << "name: name" %>

    def self.<%= name %>(<%= (input_names + attributes_in).join(', ') %>)
      self.execute("<%= op_def.name %>", [<%= input_names.join(', ') %>], <%= attributes_out.join(', ') %>)
    end<% end %>
  end
end
@@ -0,0 +1,80 @@
1
<%# Renders a Tensorflow::Graph::Graph in protobuf text format (GraphDef-like).
    Evaluated by Printers::Graph#print, which exposes the graph as `graph`. %>
<% graph.operations.each do |operation|
     op_def = self.graph.op_def(operation.op_type) %>
node {
  name: "<%= operation.name %>"
  op: "<%= operation.op_type %>"
<% operation.inputs.each do |input| %>
  input: "<%= input.operation.name %><%= input.index > 0 ? ":#{input.index}" : "" %>"
<% end %>
<% operation.control_inputs.each do |control_input| %>
  input: "^<%= control_input.name %>"
<% end %>
<% op_def.attr.sort_by(&:name).each do |attr_def|
     attr = operation.attr(attr_def.name) %>
  attr {
    key: "<%= attr_def.name %>"
    value {
<% case attr.metadata[:type]
   when :tensor %>
      tensor {
        dtype: DT_<%= attr.tensor.dtype.upcase %>
        tensor_shape {
<% attr.tensor.shape.each do |dim| %>
          dim {
            size: <%= dim %>
          }
<% end %>
        }
        <%= attr.tensor.dtype %>_val: <%= attr.tensor.value.is_a?(Numo::NArray) ? attr.tensor.value.to_a : attr.tensor.value %>
        tensor_content: <%= attr.proto.tensor.tensor_content.dump %>
      }
<% when :type %>
<% if attr.list? %>
      list {
<% attr.value.each do |list_value| %>
        type: DT_<%= list_value.to_s.upcase %>
<% end %>
      }
<% else %>
      type: DT_<%= attr.value.to_s.upcase %>
<% end %>
<% when :bool %>
      b: <%= attr.value ? 'true' : 'false' %>
<% when :int %>
      i: <%= attr.value %>
<% when :shape %>
<% if attr.list? %>
      list {
<% attr.value.each do |list_value| %>
        shape {
<% list_value.each do |sub_list_value| %>
          dim {
            size: <%= sub_list_value %>
          }
<% end %>
        }
<% end %>
      }
<% else %>
<%# One dim entry per dimension, matching the list form above (text format
    has no "shape { 2, 3 }" syntax). An empty value renders as "shape {}". %>
      shape {
<% attr.value.each do |dim_size| %>
        dim {
          size: <%= dim_size %>
        }
<% end %>
      }
<% end %>
<% when :string %>
      s: "<%= attr.value %>"
<% when :func %>
      func {
        name: "<%= attr.value %>"
      }
<% end %>
    }
  }
<% end %>
}
<% end %>
<%# NOTE(review): the producer version is hard-coded, not read from the linked TensorFlow library. %>
versions {
  producer: 119
}
@@ -0,0 +1,26 @@
1
+ require 'erubi'
2
+
3
module Tensorflow
  module Printers
    # Renders a graph as protobuf-text-style output using the Erubi template
    # graph.erb located next to this file.
    class Graph
      attr_reader :graph

      # graph - the graph to render; the template reaches it through #graph.
      def initialize(graph)
        @graph = graph
      end

      # Absolute path of the template rendered by #print.
      def template_path
        File.join(__dir__, 'graph.erb')
      end

      # Raw template text, read once (in binary mode) and memoized.
      def template
        @template ||= File.read(template_path, mode: 'rb')
      end

      # Appends the rendered graph to io_stream (anything responding to <<)
      # and returns whatever io_stream#<< returns.
      def print(io_stream=STDOUT)
        # Evaluate against this method's binding so the template can call
        # `graph`; passing the template path makes errors point at graph.erb:LINE
        # instead of (eval):LINE.
        io_stream << eval(compiled_template, binding, template_path)
      end

      private

      # Ruby source generated by Erubi from #template, compiled once per printer.
      def compiled_template
        @compiled_template ||= Erubi::Engine.new(template, filename: template_path).src
      end
    end
  end
end
@@ -0,0 +1,109 @@
1
<%# Renders a Tensorflow::GraphDef protobuf in protobuf text format.
    Evaluated by Printers::GraphDef#print, which exposes it as `graph_def`. %>
<% graph_def.node.each do |node_def| %>
node {
  name: "<%= node_def.name %>"
  op: "<%= node_def.op %>"
  device: "<%= node_def.device %>"
<% node_def.input.each do |input| %>
  input: "<%= input %>"
<% end %>
<% node_def.attr.sort_by { |key, _value| key }.each do |key, attr_def| %>
  attr {
    key: "<%= key %>"
    value {
<%# AttrValue keeps its payload in the `value` oneof; the accessor returns the
    name of the field that is set (:s, :i, :f, :b, :type, :shape, :tensor, :list, :func). %>
<% case attr_def.value
   when :tensor %>
      tensor {
        dtype: <%= attr_def.tensor.dtype %>
        tensor_shape {
<% attr_def.tensor.tensor_shape.dim.each do |dim| %>
          dim {
            size: <%= dim.size %>
          }
<% end %>
        }
<% # TensorProto stores values in a per-dtype repeated field; several dtypes
   # share one field (e.g. every 8/16/32-bit integer type uses int_val).
   dtype_name = attr_def.tensor.dtype.to_s.sub(/\ADT_/, '').downcase
   value_method = case dtype_name
                  when 'half', 'bfloat16' then 'half_val'
                  when 'int32', 'int16', 'int8', 'uint8', 'uint16',
                       'qint8', 'quint8', 'qint16', 'quint16', 'qint32' then 'int_val'
                  when 'complex64' then 'scomplex_val'
                  when 'complex128' then 'dcomplex_val'
                  when 'resource' then 'resource_handle_val'
                  else "#{dtype_name}_val"
                  end %>
        <%= value_method %>: <%= attr_def.tensor.public_send(value_method) %>
      }
<% when :type %>
      type: <%= attr_def.type %>
<% when :list %>
      list {
<% if !attr_def.list.type.empty? %>
<% attr_def.list.type.each do |type| %>
        type: <%= type %>
<% end %>
<% elsif !attr_def.list.shape.empty? %>
<% attr_def.list.shape.each do |shape| %>
        shape {
<% shape.dim.each do |dim| %>
          dim {
            size: <%= dim.size %>
          }
<% end %>
        }
<% end %>
<% elsif !attr_def.list.i.empty? %>
<% attr_def.list.i.each do |int_value| %>
        i: <%= int_value %>
<% end %>
<% elsif !attr_def.list.f.empty? %>
<% attr_def.list.f.each do |float_value| %>
        f: <%= float_value %>
<% end %>
<% elsif !attr_def.list.s.empty? %>
<% attr_def.list.s.each do |string_value| %>
        s: "<%= string_value %>"
<% end %>
<% elsif !attr_def.list.b.empty? %>
<% attr_def.list.b.each do |bool_value| %>
        b: <%= bool_value %>
<% end %>
<% end %>
      }
<% when :b %>
      b: <%= attr_def.b %>
<% when :i %>
      i: <%= attr_def.i %>
<% when :f %>
      f: <%= attr_def.f %>
<% when :shape %>
      shape {
<% attr_def.shape.dim.each do |dim| %>
        dim {
          size: <%= dim.size %>
        }
<% end %>
<% if attr_def.shape.unknown_rank %>
        unknown_rank: true
<% end %>
      }
<% when :s %>
      s: "<%= attr_def.s %>"
<% when :func %>
      func {
        name: "<%= attr_def.func.name %>"
      }
<% end %>
    }
  }
<% end %>
}
<% end %>
library {
<% graph_def.library.function.each do |function_def| %>
  function {
    signature {
      name: "<%= function_def.signature.name %>"
<% function_def.signature.input_arg.each do |arg_def| %>
      input_arg {
        name: "<%= arg_def.name %>"
        type: <%= arg_def.type %>
      }
<% end %>
<% function_def.signature.output_arg.each do |arg_def| %>
      output_arg {
        name: "<%= arg_def.name %>"
        type: <%= arg_def.type %>
      }
<% end %>
    }
  }
<% end %>
}
versions {
  producer: <%= graph_def.versions.producer %>
  min_consumer: <%= graph_def.versions.min_consumer %>
}
@@ -0,0 +1,26 @@
1
+ require 'erubi'
2
+
3
module Tensorflow
  module Printers
    # Renders a GraphDef protobuf as protobuf-text-style output using the Erubi
    # template graph_def.erb located next to this file.
    class GraphDef
      attr_reader :graph_def

      # graph_def - the GraphDef to render; the template reaches it through #graph_def.
      def initialize(graph_def)
        @graph_def = graph_def
      end

      # Absolute path of the template rendered by #print.
      def template_path
        File.join(__dir__, 'graph_def.erb')
      end

      # Raw template text, read once (in binary mode) and memoized.
      def template
        @template ||= File.read(template_path, mode: 'rb')
      end

      # Appends the rendered GraphDef to io_stream (anything responding to <<)
      # and returns whatever io_stream#<< returns.
      def print(io_stream=STDOUT)
        # Evaluate against this method's binding so the template can call
        # `graph_def`; passing the template path makes errors point at
        # graph_def.erb:LINE instead of (eval):LINE.
        io_stream << eval(compiled_template, binding, template_path)
      end

      private

      # Ruby source generated by Erubi from #template, compiled once per printer.
      def compiled_template
        @compiled_template ||= Erubi::Engine.new(template, filename: template_path).src
      end
    end
  end
end
@@ -0,0 +1,55 @@
1
module Tensorflow
  # Expose each symbol of FFI::DataType as a module-level method that returns
  # that symbol (e.g. Tensorflow.float returns :float when :float is one of them).
  FFI::DataType.symbols.each do |dtype|
    define_singleton_method(dtype) do
      dtype
    end
  end

  # Python-style alias: tf.float32 is the :float data type.
  def self.float32
    :float
  end

  # Python-style alias: tf.float64 is the :double data type.
  def self.float64
    :double
  end

  # Helpers mirroring Python TensorFlow's top-level API.
  # (The module name's spelling is part of the public API and is kept as-is.)
  module PythonCompatability
    # Switches the execution mode to graph mode.
    def disable_eager_execution
      self.execution_mode = Tensorflow::GRAPH_MODE
    end

    # Switches the execution mode to eager mode.
    def enable_eager_execution
      self.execution_mode = Tensorflow::EAGER_MODE
    end

    # Returns the current graph's GLOBAL_VARIABLES collection, or [] in eager mode.
    def global_variables
      if ExecutionContext.eager?
        []
      else
        ExecutionContext.current.get_collection_ref(Graph::GraphKeys::GLOBAL_VARIABLES)
      end
    end

    # Returns an op that initializes every global variable.
    #
    # In eager mode this is RawOps.no_op. In graph mode with no global
    # variables a named no-op is returned (previously nil), so callers always
    # get an op back - matching Python's tf.compat.v1.global_variables_initializer.
    def global_variables_initializer
      return RawOps.no_op if ExecutionContext.eager?

      variables = Array(ExecutionContext.current.get_collection_ref(Graph::GraphKeys::GLOBAL_VARIABLES))
      variables_initializer(variables)
    end

    # Returns an op that runs the initializer of each variable in `variables`.
    #
    # variables - the variables to initialize (each must respond to #initializer).
    # name      - name of the no-op returned when `variables` is empty in graph mode.
    #
    # In eager mode this is RawOps.no_op.
    def variables_initializer(variables, name: 'init')
      return RawOps.no_op if ExecutionContext.eager?

      variables = Array(variables)
      # Grouping zero initializers is meaningless; Python returns no_op(name=name) here.
      return RawOps.no_op(name: name) if variables.empty?

      Control.group(variables.map(&:initializer))
    end
  end
end

Tf = Tensorflow
@@ -0,0 +1,78 @@
1
module Tensorflow
  # Writes summaries (audio, histograms, images, raw protos, scalars, generic
  # tensors) through a writer resource created with RawOps.summary_writer.
  #
  # Every tagged write method also adds the op it creates to the current
  # execution context's SUMMARY_COLLECTION and returns that op/result.
  class ResourceSummaryWriter
    # #step= is defined explicitly below (it wraps plain values in an int64
    # Tensor), so only the reader is generated here.
    attr_reader :step
    attr_reader :initializer

    # shared_name / container - forwarded to RawOps.summary_writer.
    #
    # An optional block receives the writer resource and should return the op
    # that initializes it; that value is exposed as #initializer. Without a
    # block, #initializer is nil (previously this raised LocalJumpError).
    def initialize(shared_name: "", container: "")
      self.step = 1
      @resource = RawOps.summary_writer(shared_name: shared_name, container: container)
      @initializer = yield(@resource) if block_given?
    end

    # Builds and returns a SummaryMetadata message.
    #
    # plugin_name - the plugin_data.plugin_name to record (defaults to 'scalars').
    def create_summary_metadata(display_name, description, plugin_name: 'scalars')
      # PluginData is a nested message class (a constant), not a method.
      plugin_data = SummaryMetadata::PluginData.new
      plugin_data.plugin_name = plugin_name

      metadata = SummaryMetadata.new
      metadata.display_name = display_name
      metadata.summary_description = description
      metadata.plugin_data = plugin_data
      metadata
    end

    # Sets the global step used by subsequent writes. A Variable is stored
    # as-is; any other value is wrapped in an int64 Tensor.
    def step=(value)
      @step = value.is_a?(Variable) ? value : Tensor.new(value, dtype: :int64)
    end

    def audio(tag, tensor, sample_rate, max_outputs: 3)
      tensor = Tensor.from_value(tensor, dtype: :float)
      add_to_summaries(RawOps.write_audio_summary(@resource, step, tag, tensor, sample_rate, max_outputs: max_outputs))
    end

    # Writes graph.as_graph_def; unlike the other writers this is not added to SUMMARY_COLLECTION.
    def graph(graph)
      RawOps.write_graph_summary(@resource, step, graph.as_graph_def)
    end

    def histogram(tag, values)
      add_to_summaries(RawOps.write_histogram_summary(@resource, step, tag, values))
    end

    # bad_color defaults to opaque red as a uint8 RGBA tensor.
    def image(tag, tensor, bad_color=nil)
      bad_color ||= Tensor.new([255, 0, 0, 255], dtype: :uint8)
      add_to_summaries(RawOps.write_image_summary(@resource, step, tag, tensor, bad_color))
    end

    # `tag` is accepted for API symmetry but unused: the raw-proto write takes no tag.
    def proto(tag, tensor)
      add_to_summaries(RawOps.write_raw_proto_summary(@resource, step, tensor))
    end

    def scalar(tag, value, dtype: nil)
      add_to_summaries(RawOps.write_scalar_summary(@resource, step, tag, value, typeT: dtype))
    end

    # Writes an arbitrary value (wrapped in a Tensor) with serialized summary metadata.
    def write(tag, value, metadata: "".b)
      value = Tensor.new(value)
      add_to_summaries(RawOps.write_summary(@resource, step, value, tag, metadata, typeT: value.dtype))
    end
    alias :generic :write

    def flush
      RawOps.flush_summary_writer(@resource)
    end

    def close
      RawOps.close_summary_writer(@resource)
    end

    private

    # Adds a summary op to the current context's SUMMARY_COLLECTION and returns it.
    def add_to_summaries(result)
      ExecutionContext.current.add_to_collection(Graph::GraphKeys::SUMMARY_COLLECTION, result)
      result
    end
  end
end
@@ -0,0 +1,49 @@
1
# Status wraps a TF_Status from the TensorFlow C API: it carries the error code
# and message of a call and can raise the matching Ruby error.
module Tensorflow
  class Status
    # Builds the finalizer that frees the native TF_Status. It lives in a class
    # method so the proc closes over the pointer only, not the Status instance
    # (a proc referencing the instance would keep it from ever being collected).
    def self.finalize(pointer)
      proc do
        FFI::TF_DeleteStatus(pointer)
      end
    end

    # Yields a fresh Status to the block (which hands it to a C API call),
    # raises via #check if the call reported an error, and otherwise returns
    # the block's result.
    def self.check
      status = Status.new
      result = yield status
      status.check
      result
    end

    def initialize
      @pointer = FFI.TF_NewStatus
      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # Lets a Status be passed wherever FFI expects the underlying pointer.
    def to_ptr
      @pointer
    end

    # Status code as returned by TF_GetCode (e.g. :tf_ok).
    def code
      FFI.TF_GetCode(self)
    end

    def message
      FFI.TF_Message(self)
    end

    def set(code, message)
      FFI.TF_SetStatus(self, code, message)
    end

    # Returns nil when the code is :tf_ok. Otherwise raises the
    # Tensorflow::Error subclass named after the code with the status message,
    # e.g. :tf_out_of_range -> Tensorflow::Error::OutOfRangeError.
    def check
      status_code = code # query the C API once
      return if status_code == :tf_ok

      camel_case = status_code.to_s.delete_prefix('tf_').split('_').map(&:capitalize).join
      raise Tensorflow::Error.const_get("#{camel_case}Error"), message
    end
  end
end