tensorflow-ruby 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,38 @@
# Do NOT edit this file. It is generated by `rake generate_ops`

module Tensorflow
  module RawOps
    # Creates the operation +op_type+ in the context returned by
    # ExecutionContext.current(inputs).
    #
    # op_type - TensorFlow op name (each generated wrapper passes its op_def name).
    # inputs  - op inputs, forwarded unchanged to create_operation.
    # attrs   - attribute hash; nil-valued entries are removed (compact) before
    #           the operation is created.
    #
    # Returns the operation itself when the context is a Graph::Graph,
    # otherwise the result of context.execute(operation).
    def self.execute(op_type, inputs=[], attrs={})
      context = ExecutionContext.current(inputs)
      attrs = attrs.compact
      operation = context.create_operation(op_type, inputs, attrs)
      if context.is_a?(Graph::Graph)
        operation
      else
        context.execute(operation)
      end
    end

<% Tensorflow.op_defs.values.sort_by(&:name).each do |op_def|
  # One wrapper method per entry of Tensorflow.op_defs, emitted in order of op name.
  name = RawOpHelper.underscore(op_def.name).downcase

  # Leading parameters: the op's input args, with names passed through RawOpHelper.check_name.
  input_names = op_def.input_arg.map do |input|
    RawOpHelper.check_name(input.name)
  end

  # Remaining parameter declarations: one per attr as produced by
  # RawOpHelper.process_attribute (presumably "attr: default" — confirm there),
  # followed by a name: keyword defaulting to the op's own name.
  attributes_in = op_def.attr.map do |attr_def|
    RawOpHelper.process_attribute(attr_def)
  end
  attributes_in << "name: \"#{op_def.name}\""

  # "attr_name: local_name" pairs forwarded to execute as the attrs hash;
  # local_name comes from RawOpHelper.check_attribute_name.
  attributes_out = op_def.attr.map do |attr_def|
    attribute_name = RawOpHelper.check_attribute_name(attr_def)
    "#{attr_def.name}: #{attribute_name}"
  end
  attributes_out << "name: name" %>

    def self.<%= name %>(<%= (input_names + attributes_in).join(', ') %>)
      self.execute("<%= op_def.name %>", [<%= input_names.join(', ') %>], <%= attributes_out.join(', ') %>)
    end<% end %>
  end
end
@@ -0,0 +1,80 @@
<%# Renders a Graph as protobuf text format: one node { } per operation, then a fixed versions block. %>
<% graph.operations.each do |operation|
  op_def = self.graph.op_def(operation.op_type) %>
node {
  name: "<%= operation.name %>"
  op: "<%= operation.op_type %>"
<% operation.inputs.each do |input| %>
  input: "<%= input.operation.name %><%= input.index > 0 ? ":#{input.index}" : "" %>"
<% end %>
<%# Control inputs are written with a leading "^". %>
<% operation.control_inputs.each do |control_input| %>
  input: "^<%= control_input.name %>"
<% end %>
<% op_def.attr.sort_by(&:name).each do |attr_def|
  attr = operation.attr(attr_def.name) %>
  attr {
    key: "<%= attr_def.name %>"
    value {
<% case attr.metadata[:type]
   when :tensor %>
      tensor {
        dtype: DT_<%= attr.tensor.dtype.upcase %>
        tensor_shape {
<% attr.tensor.shape.each do |dim| %>
          dim {
            size: <%= dim %>
          }
<% end %>
        }
        <%= attr.tensor.dtype %>_val: <%= attr.tensor.value.is_a?(Numo::NArray) ? attr.tensor.value.to_a : attr.tensor.value %>
        tensor_content: <%= attr.proto.tensor.tensor_content.dump %>
      }
<% when :type %>
<% if attr.list? %>
      list {
<% attr.value.each do |list_value| %>
        type: DT_<%= list_value.to_s.upcase %>
<% end %>
      }
<% else %>
      type: DT_<%= attr.value.to_s.upcase %>
<% end %>
<% when :bool %>
      b: <%= attr.value ? 'true' : 'false' %>
<% when :int %>
      i: <%= attr.value %>
<% when :shape %>
<% if attr.list? %>
      list {
<% attr.value.each do |list_value| %>
        shape {
<% list_value.each do |sub_list_value| %>
          dim {
            size: <%= sub_list_value %>
          }
<% end %>
        }
<% end %>
      }
<% elsif attr.value.empty? %>
      shape {
      }
<% else %>
<%# A single shape is a list of dimension sizes; emit one dim entry per size, as the list branch does. %>
      shape {
<% attr.value.each do |dim| %>
        dim {
          size: <%= dim %>
        }
<% end %>
      }
<% end %>
<% when :string %>
<%# Escape backslashes and double quotes so the quoted text-format string stays well-formed. %>
      s: "<%= attr.value.to_s.gsub(/["\\]/) { |char| "\\#{char}" } %>"
<% when :func %>
      func {
        name: "<%= attr.value %>"
      }
<% end %>
    }
  }
<% end %>
}
<% end %>
versions {
  producer: 119
}
@@ -0,0 +1,26 @@
require 'erubi'

module Tensorflow
  module Printers
    # Renders a Graph as protobuf text format using the graph.erb template
    # stored next to this file.
    class Graph
      TEMPLATE_PATH = File.join(__dir__, 'graph.erb')

      attr_reader :graph

      # graph - the graph to render; the template reads graph.operations and
      #         graph.op_def.
      def initialize(graph)
        @graph = graph
      end

      # Template text, read once in binary mode and memoized.
      def template
        @template ||= File.read(TEMPLATE_PATH, mode: 'rb')
      end

      # Renders the template and appends the resulting text to +io_stream+.
      #
      # io_stream - any object responding to << (defaults to STDOUT).
      # Returns the value of io_stream << text.
      def print(io_stream=STDOUT)
        # Evaluated in this method's binding so the template can call #graph.
        # The file/line arguments make template errors point at graph.erb
        # rather than "(eval)".
        io_stream << eval(compiled_source, binding, TEMPLATE_PATH, 1)
      end

      private

      # Ruby source generated by Erubi for the template, compiled once.
      def compiled_source
        @compiled_source ||= Erubi::Engine.new(template, filename: TEMPLATE_PATH).src
      end
    end
  end
end
@@ -0,0 +1,109 @@
<%# Renders a GraphDef protobuf as protobuf text format. %>
<% graph_def.node.each do |node_def| %>
node {
  name: "<%= node_def.name %>"
  op: "<%= node_def.op %>"
  device: "<%= node_def.device %>"
<% node_def.input.each do |input| %>
  input: "<%= input %>"
<% end %>
<% node_def.attr.sort_by {|key, value| key}.each do |key, attr_def| %>
  attr {
    key: "<%= key %>"
    value {
<%# attr_def.value names the AttrValue oneof field that is set, e.g. :list, :s, :i, :f, :b, :type, :shape, :tensor, :func. %>
<% case attr_def.value
   when :tensor
     dtype_name = attr_def.tensor.dtype[3..-1].downcase
     # Several dtypes share one repeated field in TensorProto instead of having a <dtype>_val of their own.
     value_method = case dtype_name
                    when 'int32', 'int16', 'uint16', 'int8', 'uint8' then 'int_val'
                    when 'bfloat16' then 'half_val'
                    when 'complex64' then 'scomplex_val'
                    when 'complex128' then 'dcomplex_val'
                    when 'resource' then 'resource_handle_val'
                    else "#{dtype_name}_val"
                    end %>
      tensor {
        dtype: <%= attr_def.tensor.dtype %>
        tensor_shape {
<% attr_def.tensor.tensor_shape.dim.each do |dim| %>
          dim {
            size: <%= dim.size %>
          }
<% end %>
        }
        <%= value_method %>: <%= attr_def.tensor.send(value_method.to_sym) %>
      }
<% when :type %>
      type: <%= attr_def.type %>
<% when :list %>
      list {
<% if !attr_def.list.type.empty? %>
<% attr_def.list.type.each do |type| %>
        type: <%= type %>
<% end %>
<% elsif !attr_def.list.shape.empty? %>
<% attr_def.list.shape.each do |shape| %>
        shape {
<% shape.dim.each do |dim| %>
          dim {
            size: <%= dim.size %>
          }
<% end %>
        }
<% end %>
<% end %>
      }
<% when :b %>
      b: <%= attr_def.b %>
<% when :i %>
      i: <%= attr_def.i %>
<% when :f %>
      f: <%= attr_def.f %>
<% when :shape %>
      shape {
<% attr_def.shape.dim.each do |dim| %>
        dim {
          size: <%= dim.size %>
        }
<% end %>
      }
<% when :s %>
<%# Escape backslashes and double quotes so the quoted text-format string stays well-formed. %>
      s: "<%= attr_def.s.gsub(/["\\]/) { |char| "\\#{char}" } %>"
<% when :func %>
      func {
        name: "<%= attr_def.func.name %>"
      }
<% end %>
    }
  }
<% end %>
}
<% end %>
library {
<% graph_def.library.function.each do |function_def| %>
  function {
    signature {
      name: "<%= function_def.signature.name %>"
<% function_def.signature.input_arg.each do |arg_def| %>
      input_arg {
        name: "<%= arg_def.name %>"
        type: <%= arg_def.type %>
      }
<% end %>
<% function_def.signature.output_arg.each do |arg_def| %>
      output_arg {
        name: "<%= arg_def.name %>"
        type: <%= arg_def.type %>
      }
<% end %>
    }
  }
<% end %>
}
versions {
  producer: <%= graph_def.versions.producer %>
  min_consumer: <%= graph_def.versions.min_consumer %>
}
@@ -0,0 +1,26 @@
require 'erubi'

module Tensorflow
  module Printers
    # Renders a GraphDef protobuf as protobuf text format using the
    # graph_def.erb template stored next to this file.
    class GraphDef
      TEMPLATE_PATH = File.join(__dir__, 'graph_def.erb')

      attr_reader :graph_def

      # graph_def - the GraphDef message to render; the template reads its
      #             node, library and versions fields.
      def initialize(graph_def)
        @graph_def = graph_def
      end

      # Template text, read once in binary mode and memoized.
      def template
        @template ||= File.read(TEMPLATE_PATH, mode: 'rb')
      end

      # Renders the template and appends the resulting text to +io_stream+.
      #
      # io_stream - any object responding to << (defaults to STDOUT).
      # Returns the value of io_stream << text.
      def print(io_stream=STDOUT)
        # Evaluated in this method's binding so the template can call
        # #graph_def. The file/line arguments make template errors point at
        # graph_def.erb rather than "(eval)".
        io_stream << eval(compiled_source, binding, TEMPLATE_PATH, 1)
      end

      private

      # Ruby source generated by Erubi for the template, compiled once.
      def compiled_source
        @compiled_source ||= Erubi::Engine.new(template, filename: TEMPLATE_PATH).src
      end
    end
  end
end
@@ -0,0 +1,55 @@
module Tensorflow
  # Define a Tensorflow.<dtype> helper for every symbol in FFI::DataType.
  # Each helper returns the symbol it is named after, so dtypes can be
  # written Python-style as method calls.
  FFI::DataType.symbols.each do |dtype|
    define_singleton_method(dtype) do
      dtype
    end
  end

  # Python-style names for the single/double precision float dtypes. These
  # are defined after the loop above, so they take precedence over any
  # same-named helper it created.
  def self.float32
    :float
  end

  def self.float64
    :double
  end

  # Python TensorFlow-style helpers. They call self.execution_mode=, so they
  # are meant to be mixed into the object that owns the execution mode.
  # NOTE(review): where this module is included or extended is not visible here.
  module PythonCompatability
    # Switch execution to graph mode.
    def disable_eager_execution
      self.execution_mode = Tensorflow::GRAPH_MODE
    end

    # Switch execution to eager mode.
    def enable_eager_execution
      self.execution_mode = Tensorflow::EAGER_MODE
    end

    # Variables in the current graph's GLOBAL_VARIABLES collection.
    # Always [] in eager mode. A missing (nil) collection is also returned as [].
    def global_variables
      return [] if ExecutionContext.eager?

      Array(ExecutionContext.current.get_collection_ref(Graph::GraphKeys::GLOBAL_VARIABLES))
    end

    # Op that initializes every variable in GLOBAL_VARIABLES.
    # Returns RawOps.no_op in eager mode and also in graph mode when there
    # are no global variables, so the result is never nil.
    def global_variables_initializer
      return RawOps.no_op if ExecutionContext.eager?

      variables = ExecutionContext.current.get_collection_ref(Graph::GraphKeys::GLOBAL_VARIABLES)
      variables_initializer(Array(variables))
    end

    # Op grouping the initializer of each variable in +variables+ via
    # Control.group. Returns RawOps.no_op in eager mode or when +variables+
    # is empty or nil, mirroring Python's variables_initializer.
    #
    # name - accepted for parity with the Python API; currently unused.
    def variables_initializer(variables, name: 'init')
      variables = Array(variables)
      if ExecutionContext.eager? || variables.empty?
        RawOps.no_op
      else
        Control.group(variables.map(&:initializer))
      end
    end
  end
end

# Short alias for the Tensorflow namespace.
Tf = Tensorflow
@@ -0,0 +1,78 @@
module Tensorflow
  # Creates summary-writing ops (audio, graph, histogram, image, raw proto,
  # scalar, generic) against one writer resource obtained from
  # RawOps.summary_writer. Every op except #graph, #flush and #close is also
  # added to the current execution context's SUMMARY_COLLECTION.
  class ResourceSummaryWriter
    attr_reader :step
    attr_reader :initializer

    # shared_name, container - forwarded to RawOps.summary_writer.
    #
    # If a block is given, it is called with the writer resource and its
    # return value becomes #initializer. Without a block, #initializer is nil.
    def initialize(shared_name: "", container: "")
      self.step = 1
      @resource = RawOps.summary_writer(shared_name: shared_name, container: container)
      @initializer = block_given? ? yield(@resource) : nil
    end

    # Builds a SummaryMetadata tagged for the 'scalars' plugin and returns it.
    def create_summary_metadata(display_name, description)
      metadata = SummaryMetadata.new
      metadata.display_name = display_name
      metadata.summary_description = description
      # PluginData is a constant nested in SummaryMetadata, so it needs ::.
      # SummaryMetadata.PluginData would be looked up as a method instead.
      metadata.plugin_data = SummaryMetadata::PluginData.new
      metadata.plugin_data.plugin_name = 'scalars'
      metadata
    end

    # Step attached to subsequent summaries. A Variable is stored as-is;
    # any other value is wrapped in an int64 Tensor.
    def step=(value)
      @step = value.is_a?(Variable) ? value : Tensor.new(value, dtype: :int64)
    end

    def audio(tag, tensor, sample_rate, max_outputs: 3)
      tensor = Tensor.from_value(tensor, dtype: :float)
      add_to_summary_collection(
        RawOps.write_audio_summary(@resource, step, tag, tensor, sample_rate, max_outputs: max_outputs)
      )
    end

    # Writes the graph's GraphDef. Not added to SUMMARY_COLLECTION.
    def graph(graph)
      RawOps.write_graph_summary(@resource, step, graph.as_graph_def)
    end

    def histogram(tag, values)
      add_to_summary_collection(RawOps.write_histogram_summary(@resource, step, tag, values))
    end

    # bad_color defaults to opaque red as a uint8 RGBA tensor.
    def image(tag, tensor, bad_color=nil)
      bad_color ||= Tensor.new([255, 0, 0, 255], dtype: :uint8)
      add_to_summary_collection(RawOps.write_image_summary(@resource, step, tag, tensor, bad_color))
    end

    # +tag+ is accepted for symmetry with the other writers but is not passed to the op.
    def proto(tag, tensor)
      add_to_summary_collection(RawOps.write_raw_proto_summary(@resource, step, tensor))
    end

    def scalar(tag, value, dtype: nil)
      add_to_summary_collection(RawOps.write_scalar_summary(@resource, step, tag, value, typeT: dtype))
    end

    # Generic summary: wraps +value+ in a Tensor and passes its dtype as the op's type attr.
    def write(tag, value, metadata: "".b)
      tensor = Tensor.new(value)
      add_to_summary_collection(RawOps.write_summary(@resource, step, tensor, tag, metadata, typeT: tensor.dtype))
    end
    alias :generic :write

    def flush
      RawOps.flush_summary_writer(@resource)
    end

    def close
      RawOps.close_summary_writer(@resource)
    end

    private

    # Adds +summary_op+ to the current context's SUMMARY_COLLECTION and returns it.
    def add_to_summary_collection(summary_op)
      ExecutionContext.current.add_to_collection(Graph::GraphKeys::SUMMARY_COLLECTION, summary_op)
      summary_op
    end
  end
end
@@ -0,0 +1,49 @@
module Tensorflow
  # Wrapper around a native TF_Status handle. TensorFlow C API calls record an
  # error code and message in it; #check turns a non-OK code into an
  # exception from the Tensorflow::Error namespace.
  class Status
    # Returns the finalizer proc that releases +pointer+ with TF_DeleteStatus.
    # It is built in a class method so the proc's closure does not reference
    # the Status instance being finalized.
    def self.finalize(pointer)
      proc { FFI::TF_DeleteStatus(pointer) }
    end

    # Yields a new Status to the block, then calls #check on it (raising on
    # error). Otherwise returns whatever the block returned.
    def self.check
      status = Status.new
      block_result = yield(status)
      status.check
      block_result
    end

    def initialize
      @pointer = FFI.TF_NewStatus
      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # Native handle; lets a Status be passed directly as an FFI argument.
    def to_ptr
      @pointer
    end

    # Status code from TF_GetCode (:tf_ok when there is no error).
    def code
      FFI.TF_GetCode(self)
    end

    # Error message from TF_Message.
    def message
      FFI.TF_Message(self)
    end

    # Stores +code+ and +message+ through TF_SetStatus.
    def set(code, message)
      FFI.TF_SetStatus(self, code, message)
    end

    # Returns nil when the code is :tf_ok. Otherwise it raises
    # Tensorflow::Error::<Name>Error, carrying #message. <Name> is the code
    # with its first three characters dropped, converted to CamelCase; for
    # example, :tf_invalid_argument raises InvalidArgumentError.
    def check
      status_code = code
      return if status_code == :tf_ok

      class_stem = status_code[3..-1].capitalize.gsub(/(?:_|(\/))([a-z\d]*)/i) do
        "#{Regexp.last_match(1)}#{Regexp.last_match(2).capitalize}"
      end
      raise(Tensorflow::Error.const_get("#{class_stem}Error"), message)
    end
  end
end