tensorflow-ruby 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,90 @@
1
# Generated by the protocol buffer compiler.  DO NOT EDIT!
# source: stream_executor/dnn.proto
#
# Registers the descriptors for stream_executor/dnn.proto in the generated
# descriptor pool, then exposes the resulting message classes and enum modules
# as constants under StreamExecutor::Dnn.

require 'google/protobuf'

Google::Protobuf::DescriptorPool.generated_pool.build do
  add_file('stream_executor/dnn.proto', syntax: :proto3) do
    add_message 'stream_executor.dnn.TensorDescriptorProto' do
      repeated :dimensions, :int64, 1
      optional :data_type, :enum, 2, 'stream_executor.dnn.DataType'
      # Exactly one of the two layouts is set, depending on the kind of tensor.
      oneof :layout_oneof do
        optional :data_layout, :enum, 3, 'stream_executor.dnn.DataLayout'
        optional :filter_layout, :enum, 4, 'stream_executor.dnn.FilterLayout'
      end
    end

    add_message 'stream_executor.dnn.AlgorithmProto' do
      optional :algo_id, :int64, 1
      optional :math_type, :enum, 2, 'stream_executor.dnn.AlgorithmProto.MathType'
    end

    add_enum 'stream_executor.dnn.AlgorithmProto.MathType' do
      value :DEFAULT_MATH, 0
      value :TENSOR_OP_MATH, 1
    end

    add_message 'stream_executor.dnn.ConvolutionDescriptorProto' do
      repeated :paddings, :int64, 1
      repeated :strides, :int64, 2
      repeated :dilations, :int64, 3
      optional :compute_mode, :enum, 4, 'stream_executor.dnn.DataType'
      optional :group_count, :int32, 5
      optional :convolution_mode, :enum, 6, 'stream_executor.dnn.ConvolutionMode'
      optional :name, :string, 7
    end

    add_enum 'stream_executor.dnn.DataType' do
      value :kFloat, 0
      value :kDouble, 1
      value :kHalf, 2
      value :kInt8, 3
      value :kInt32, 4
    end

    add_enum 'stream_executor.dnn.DataLayout' do
      value :kYXDepthBatch, 0
      value :kYXBatchDepth, 1
      value :kBatchYXDepth, 2
      value :kBatchDepthYX, 3
      value :kBatchDepthYX4, 4
    end

    add_enum 'stream_executor.dnn.FilterLayout' do
      value :kOutputInputYX, 0
      value :kOutputYXInput, 1
      value :kOutputInputYX4, 2
      value :kInputYXOutput, 3
      value :kYXInputOutput, 4
    end

    add_enum 'stream_executor.dnn.ActivationMode' do
      value :kNone, 0
      value :kSigmoid, 1
      value :kRelu, 2
      value :kRelu6, 3
      value :kReluX, 4
      value :kTanh, 5
      value :kBandPass, 6
    end

    add_enum 'stream_executor.dnn.ConvolutionMode' do
      value :CROSS_CORRELATION, 0
      value :CONVOLUTION, 1
    end

    add_enum 'stream_executor.dnn.ConvolutionKind' do
      value :INVALID, 0
      value :FORWARD, 1
      value :BACKWARD_FILTER, 2
      value :BACKWARD_DATA, 3
      value :FORWARD_BIAS_ACTIVATION, 4
    end
  end
end

module StreamExecutor
  module Dnn
    # All lookups go through the same pool the descriptors were registered in above.
    pool = ::Google::Protobuf::DescriptorPool.generated_pool

    # Message classes
    TensorDescriptorProto = pool.lookup('stream_executor.dnn.TensorDescriptorProto').msgclass
    AlgorithmProto = pool.lookup('stream_executor.dnn.AlgorithmProto').msgclass
    AlgorithmProto::MathType = pool.lookup('stream_executor.dnn.AlgorithmProto.MathType').enummodule
    ConvolutionDescriptorProto = pool.lookup('stream_executor.dnn.ConvolutionDescriptorProto').msgclass

    # Enum modules
    DataType = pool.lookup('stream_executor.dnn.DataType').enummodule
    DataLayout = pool.lookup('stream_executor.dnn.DataLayout').enummodule
    FilterLayout = pool.lookup('stream_executor.dnn.FilterLayout').enummodule
    ActivationMode = pool.lookup('stream_executor.dnn.ActivationMode').enummodule
    ConvolutionMode = pool.lookup('stream_executor.dnn.ConvolutionMode').enummodule
    ConvolutionKind = pool.lookup('stream_executor.dnn.ConvolutionKind').enummodule
  end
end
@@ -0,0 +1,100 @@
1
module Tensorflow
  # String helpers. Every method is a thin wrapper that forwards its arguments
  # to the correspondingly named operation on RawOps and returns whatever that
  # operation returns.
  #
  # Not implemented yet: bytes_split, format, ngrams, to_hash_bucket,
  # to_hash_bucket_fast, to_hash_bucket_strong, unicode_split,
  # unicode_split_with_offsets, unsorted_segment_join.
  module Strings
    class << self
      # Forwards to RawOps.as_string. Keyword arguments left as nil are passed
      # through as nil.
      def as_string(input, precision: nil, scientific: nil, shortest: nil, width: nil, fill: nil)
        RawOps.as_string(input, precision: precision, scientific: scientific, shortest: shortest, width: width, fill: fill)
      end

      # Forwards to RawOps.string_join. `inputs` must respond to #size; its
      # size is passed as the op's `n` argument.
      def join(inputs, separator: "")
        RawOps.string_join(inputs, separator: separator, n: inputs.size)
      end

      # Forwards to RawOps.string_length; `unit` defaults to "BYTE".
      def length(input, unit: "BYTE")
        RawOps.string_length(input, unit: unit)
      end

      # Forwards to RawOps.string_lower.
      #
      # Inputs that respond to #encode (plain Ruby strings) are transcoded to
      # UTF-8 first, as before. Other inputs (arrays, tensors, ...) are passed
      # through untouched instead of raising NoMethodError on #encode.
      #
      # `encoding` defaults to 'utf-8', the value this method previously
      # hard-coded, and can now be overridden the same way as in #upper.
      def lower(input, encoding: 'utf-8')
        input = input.encode('UTF-8') if input.respond_to?(:encode)
        RawOps.string_lower(input, encoding: encoding)
      end

      # Forwards to RawOps.reduce_join.
      def reduce_join(inputs, reduction_indices, keep_dims: nil, separator: nil)
        RawOps.reduce_join(inputs, reduction_indices: reduction_indices, keep_dims: keep_dims, separator: separator)
      end

      # Forwards to RawOps.regex_full_match.
      def regex_full_match(input, pattern)
        RawOps.regex_full_match(input, pattern: pattern)
      end

      # Forwards to RawOps.regex_replace.
      def regex_replace(input, pattern, rewrite, replace_global: nil)
        RawOps.regex_replace(input, pattern: pattern, rewrite: rewrite, replace_global: replace_global)
      end

      # Forwards to RawOps.split.
      #
      # NOTE(review): this calls the generic `split` op (split_dim / value /
      # num_split) rather than a string-specific split op - confirm that this
      # is what callers of Strings.split expect.
      def split(split_dim, value, num_split: nil)
        RawOps.split(split_dim: split_dim, value: value, num_split: num_split)
      end

      # Forwards to RawOps.string_strip.
      def strip(input)
        RawOps.string_strip(input)
      end

      # Forwards to RawOps.substr.
      def substr(input, pos, len, unit: nil)
        RawOps.substr(input, pos, len, unit: unit)
      end

      # Forwards to RawOps.string_to_number; `out_type` defaults to :float.
      def to_number(input, out_type: :float)
        RawOps.string_to_number(input, out_type: out_type)
      end

      # Forwards to RawOps.unicode_decode.
      def unicode_decode(input, input_encoding: nil, errors: nil, replacement_char: nil, replace_control_characters: nil)
        RawOps.unicode_decode(input, input_encoding: input_encoding, errors: errors, replacement_char: replacement_char, replace_control_characters: replace_control_characters)
      end

      # Forwards to RawOps.unicode_decode_with_offsets.
      def unicode_decode_with_offsets(input, input_encoding: nil, errors: nil, replacement_char: nil, replace_control_characters: nil)
        RawOps.unicode_decode_with_offsets(input, input_encoding: input_encoding, errors: errors, replacement_char: replacement_char, replace_control_characters: replace_control_characters)
      end

      # Forwards to RawOps.unicode_encode. Unlike most wrappers here, the two
      # inputs are handed over as keyword arguments.
      def unicode_encode(input_values, input_splits, errors: nil, output_encoding: nil, replacement_char: nil)
        RawOps.unicode_encode(input_values: input_values, input_splits: input_splits, errors: errors, output_encoding: output_encoding, replacement_char: replacement_char)
      end

      # Forwards to RawOps.unicode_script.
      def unicode_script(input)
        RawOps.unicode_script(input)
      end

      # Forwards to RawOps.unicode_transcode.
      def unicode_transcode(input, input_encoding: nil, output_encoding: nil, errors: nil, replacement_char: nil, replace_control_characters: nil)
        RawOps.unicode_transcode(input, input_encoding: input_encoding, output_encoding: output_encoding, errors: errors, replacement_char: replacement_char, replace_control_characters: replace_control_characters)
      end

      # Forwards to RawOps.string_upper; `encoding` defaults to "".
      def upper(input, encoding: "")
        RawOps.string_upper(input, encoding: encoding)
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Tensorflow
  # Entry points for creating summary writers and listing summary ops.
  class Summary
    class << self
      # Builds a ResourceSummaryWriter (shared_name: name) and, inside the
      # block handed to its constructor, calls
      # RawOps.create_summary_file_writer with the writer and the settings
      # given here, in positional order.
      #
      # Returns whatever ResourceSummaryWriter.new returns.
      def create_file_writer(logdir, max_queue: 10, flush_millis: 120_000, filename_suffix: '.v2', name: nil)
        settings = [logdir, max_queue, flush_millis, filename_suffix]
        ResourceSummaryWriter.new(shared_name: name) do |summary_writer|
          RawOps.create_summary_file_writer(summary_writer, *settings)
        end
      end

      # Returns the current execution context's collection reference for
      # Graph::GraphKeys::SUMMARY_COLLECTION.
      def all_v2_summary_ops
        collection_key = Graph::GraphKeys::SUMMARY_COLLECTION
        ExecutionContext.current.get_collection_ref(collection_key)
      end
    end
  end
end
@@ -0,0 +1,133 @@
1
module Tensorflow
  # Ruby-side handle for a native tensor. The native pointer lives in @pointer
  # and is exposed through #to_ptr; instances are passed directly to the FFI
  # calls below.
  class Tensor
    include Operators
    include TensorMixin

    # Builds the finalizer proc for a native tensor pointer. It is created in
    # a class method so the proc closes over the pointer only, never over the
    # Tensor instance being finalized.
    def self.finalize(pointer)
      proc { FFI.TF_DeleteTensor(pointer) }
    end

    # Coerces +value+ into something usable as a tensor:
    # * Tensor and Graph::Operation are returned unchanged
    # * Eager::TensorHandle yields its #tensor
    # * Data::Dataset yields its #variant_tensor
    # * anything else is wrapped in a new Tensor (honouring +dtype+)
    def self.from_value(value, dtype: nil)
      case value
      when Tensor, Graph::Operation
        value
      when Eager::TensorHandle
        value.tensor
      when Data::Dataset
        value.variant_tensor
      else
        Tensor.new(value, dtype: dtype)
      end
    end

    # Builds a Tensor from a TensorProto, or from an encoded proto which is
    # decoded first. The payload is taken from the proto's tensor_content.
    def self.from_proto(proto)
      proto = TensorProto.decode(proto) unless proto.is_a?(TensorProto)
      shape = proto.tensor_shape.dim.map(&:size)
      dtype = FFI::DataType[DataType.resolve(proto.dtype)]
      numo_klass = TensorData::DTYPE_TO_NUMO_TYPE_MAP[dtype]

      value = if shape.empty?
                # Scalar: decode as a flat array and take its only element
                numo_klass.from_binary(proto.tensor_content)[0]
              else
                numo_klass.from_binary(proto.tensor_content, shape)
              end

      new(value, dtype: dtype, shape: shape)
    end

    # Wraps an already existing native tensor pointer without running
    # #initialize, and registers the finalizer for it.
    def self.from_pointer(pointer)
      allocate.tap do |tensor|
        tensor.instance_variable_set(:@pointer, pointer)
        ObjectSpace.define_finalizer(tensor, finalize(pointer))
      end
    end

    # Creates a native tensor holding +value+.
    #
    # value - a Numo::NArray, an Array (cast to a Numo::NArray, which makes
    #         multidimensional arrays straightforward to support), or any
    #         other value, which goes through TensorData.value_with_shape
    # dtype - optional dtype; when nil TensorData determines it
    # shape - optional shape, passed on to TensorData
    def initialize(value, dtype: nil, shape: [])
      value = case value
              when Numo::NArray
                value
              when Array
                Numo::NArray.cast(value)
              else
                TensorData.value_with_shape(value, dtype, shape)
              end

      tensor_data = TensorData.new(value, dtype: dtype, shape: shape)
      # From here on use the dtype and shape TensorData settled on
      dtype = tensor_data.dtype
      shape = tensor_data.shape

      rank = shape ? shape.size : 0
      dims_ptr = if rank > 0
                   ::FFI::MemoryPointer.new(:int64, rank).tap { |dims| dims.write_array_of_int64(shape) }
                 end

      @pointer = FFI.TF_NewTensor(dtype,
                                  dims_ptr, rank,
                                  tensor_data, tensor_data.byte_size,
                                  TensorData::Deallocator, nil)

      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # The tensor's contents as produced by TensorData#value.
    def value
      data.value
    end

    # The dtype reported by FFI.TF_TensorType.
    def dtype
      FFI.TF_TensorType(self)
    end

    def to_s
      inspect
    end

    # The native tensor pointer.
    def to_ptr
      @pointer
    end

    # Size of the tensor's data in bytes, as reported by FFI.TF_TensorByteSize.
    def byte_size
      FFI.TF_TensorByteSize(self)
    end

    # NOTE(review): #numo and #shape are not defined in this class; presumably
    # they come from TensorMixin - confirm.
    def inspect
      fields = %w[numo shape dtype].map { |attribute| "#{attribute}: #{send(attribute).inspect}" }
      "#<#{self.class} #{fields.join(', ')}>"
    end

    # A TensorData view over the native tensor's memory.
    def data
      TensorData.from_pointer(FFI.TF_TensorData(self), byte_size, dtype, shape)
    end

    private

    def num_dims
      FFI.TF_NumDims(self)
    end

    def dim(index)
      FFI.TF_Dim(self, index)
    end

    def element_count
      FFI.TF_TensorElementCount(self)
    end

    # Shape of +value+: its own #shape when it has one, otherwise the sizes
    # found by following the first element of each nested Array.
    def calculate_shape(value)
      return value.shape if value.respond_to?(:shape)

      sizes = []
      level = value
      while level.is_a?(Array)
        sizes << level.size
        level = level.first
      end
      sizes
    end
  end
end
@@ -0,0 +1,310 @@
1
+ require 'rbconfig'
2
module Tensorflow
  # Owns the raw memory that backs a tensor's data.
  #
  # Tensorflow expects client libraries to allocate memory for the data that a tensor wraps. When a tensor is released,
  # it notifies the client via a callback that gives the client a chance to release the memory.
  #
  # We don't want to use a FFI::MemoryPointer because they are garbage collected. If the underlying data is freed before
  # the tensor is released, the tensor is left pointing at freed memory (this can happen even if a Ruby tensor object
  # keeps a reference to the pointer at GC time).
  #
  # Thus this class allocates its own memory (ruby_xmalloc) and frees it (ruby_xfree) only after being called back by
  # tensorflow through Deallocator.
  class TensorData
    extend ::FFI::Library
    ffi_lib "#{RbConfig::CONFIG['RUBY_SO_NAME']}.#{RbConfig::CONFIG['SOEXT']}"
    attach_function :ruby_xmalloc, [:size_t], :pointer
    attach_function :ruby_xfree, [:pointer], :void

    attr_reader :pointer, :byte_size, :dtype, :shape

    # Only `+` and the write_* methods are forwarded to the underlying pointer; reads must go through #pointer.
    extend Forwardable
    def_delegators :to_ptr, :+, *::FFI::Pointer.instance_methods.grep(/^write_/)

    # Store the callback as a class constant (so it won't be garbage collected) that tensorflow will trigger
    # when memory should be freed.
    Deallocator = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |data, _len, _arg|
      ruby_xfree(data)
    end

    DTYPE_TO_NUMO_TYPE_MAP = {bool: Numo::Bit,
                              complex64: Numo::SComplex,
                              complex128: Numo::DComplex,
                              double: Numo::DFloat,
                              float: Numo::SFloat,
                              int8: Numo::Int8,
                              int16: Numo::Int16,
                              int32: Numo::Int32,
                              int64: Numo::Int64,
                              uint8: Numo::UInt8,
                              uint16: Numo::UInt16,
                              uint32: Numo::UInt32,
                              uint64: Numo::UInt64}

    NUMO_TYPE_TO_DTYPE_MAP = DTYPE_TO_NUMO_TYPE_MAP.invert

    # Integers inside this range are typed :int32, everything else :int64.
    INT32_RANGE = (-2_147_483_648..2_147_483_647).freeze

    # Largest finite single-precision magnitude. Numbers strictly inside
    # (-FLOAT32_MAX, FLOAT32_MAX) are typed :float / :complex64.
    FLOAT32_MAX = 3.402823466e38

    # Size in bytes of one bfloat16 element.
    BFLOAT16_SIZE = 2

    # Picks the dtype symbol for a Ruby value.
    #
    # Containers (Array, Numo::RObject) are typed by their first element. Returns nil for a Graph::Operation.
    # Raises Error::InvalidArgumentError for unsupported classes.
    def self.figure_dtype(value)
      case value
      when Numo::RObject
        # Need to look at the first element to see what it is
        figure_dtype(value[0])
      when Numo::NArray
        NUMO_TYPE_TO_DTYPE_MAP[value.class]
      when Array
        figure_dtype(value.first)
      when Integer
        INT32_RANGE.cover?(value) ? :int32 : :int64
      when Complex
        # Both components have to fit, otherwise one of them would not survive the narrowing to single precision
        fits_float32?(value.real) && fits_float32?(value.imaginary) ? :complex64 : :complex128
      when Numeric
        fits_float32?(value) ? :float : :double
      when String
        :string
      when TrueClass, FalseClass
        :bool
      when ::FFI::Pointer
        :pointer
      when Tensor, Variable
        value.dtype
      when Graph::Operation
        nil
      when Eager::TensorHandle
        value.dtype
      when Google::Protobuf::MessageExts
        :string
      else
        raise(Error::InvalidArgumentError, "Unsupported type: #{value.class}")
      end
    end

    # True when +number+ lies strictly between -FLOAT32_MAX and FLOAT32_MAX.
    def self.fits_float32?(number)
      number > -FLOAT32_MAX && number < FLOAT32_MAX
    end
    private_class_method :fits_float32?

    # Size in bytes of a single element of +dtype+. Complex and bfloat16 types are sized here; every other dtype is
    # looked up through ::FFI.type_size.
    def self.type_size(dtype)
      case dtype
      when :bfloat16
        BFLOAT16_SIZE
      when :complex64
        ::FFI.type_size(:float) * 2
      when :complex128
        ::FFI.type_size(:double) * 2
      else
        ::FFI.type_size(dtype)
      end
    end

    # With a non-empty shape, returns a Numo array of that shape filled with +value+ (dtype is figured out from the
    # value when not given). Otherwise returns +value+ unchanged.
    def self.value_with_shape(value, dtype, shape)
      return value unless shape && shape.size > 0

      dtype ||= figure_dtype(value)
      DTYPE_TO_NUMO_TYPE_MAP[dtype].new(shape).fill(value)
    end

    # Wraps memory that already exists (for example a tensor's data returned by tensorflow) without allocating or
    # copying anything.
    def self.from_pointer(pointer, byte_size, dtype, shape)
      allocate.tap do |tensor_data|
        tensor_data.instance_variable_set(:@pointer, pointer)
        tensor_data.instance_variable_set(:@byte_size, byte_size)
        tensor_data.instance_variable_set(:@dtype, dtype)
        tensor_data.instance_variable_set(:@shape, shape)
      end
    end

    # Allocates memory and writes +value+ into it.
    #
    # value - a Numo::NArray, a protobuf message (stored as its encoded string) or a scalar
    # dtype - optional dtype, figured out from the value when nil
    # shape - shape to record; replaced by the narray's own shape when value is a Numo::NArray
    #
    # Raises Error::InvalidArgumentError for plain Arrays.
    def initialize(value, dtype: nil, shape: [])
      @dtype = dtype || self.class.figure_dtype(value)
      @shape = shape
      case value
      when Numo::NArray
        write_narray(value)
      when Array
        raise(Error::InvalidArgumentError, "TensorData does not support Arrays. Please use a Numo::NArray")
      when Google::Protobuf::MessageExts
        write_array_of_string([value.class.encode(value)])
      else
        write_scalar(value)
      end
    end

    def to_ptr
      @pointer
    end

    # Reads +count+ bfloat16 values and returns them as Ruby Floats. A bfloat16 is the upper 16 bits of a
    # single-precision float, so each value is shifted into the upper half of a 32 bit word and reinterpreted.
    # Native byte order is used throughout.
    def read_array_of_bfloat16(count)
      pointer.read_bytes(count * BFLOAT16_SIZE).unpack("S*").map do |bits|
        [bits << 16].pack("L").unpack1("f")
      end
    end

    # Reads +count+ complex numbers stored as consecutive (real, imaginary) single-precision pairs.
    def read_array_of_complex64(count)
      pointer.read_array_of_float(2 * count).each_slice(2).map do |real, imaginary|
        Complex(real, imaginary)
      end
    end

    # Reads +count+ complex numbers stored as consecutive (real, imaginary) double-precision pairs.
    def read_array_of_complex128(count)
      pointer.read_array_of_double(2 * count).each_slice(2).map do |real, imaginary|
        Complex(real, imaginary)
      end
    end

    # Reads +count+ strings. The memory starts with a table of +count+ uint64 offsets, followed by the data section
    # holding the encoded strings; offsets are relative to the start of the data section.
    def read_array_of_string(count)
      # The start of the data section comes after the offset table
      start_offset_size = count * ::FFI.type_size(:int64)
      data_size = byte_size - start_offset_size

      # Read in the string offsets
      offsets = pointer.read_array_of_uint64(count)

      offsets.each_with_index.map do |offset, index|
        # An encoded string ends where the next one starts; the last one ends with the data section
        src_bytes = (offsets[index + 1] || data_size) - offset
        dst_ptr = ::FFI::MemoryPointer.new(:pointer)
        dst_len_ptr = ::FFI::MemoryPointer.new(:size_t)
        Status.check do |status|
          FFI.TF_StringDecode(pointer + start_offset_size + offset, src_bytes, dst_ptr, dst_len_ptr, status)
        end
        dst_ptr.read_pointer.read_string(dst_len_ptr.read(:size_t))
      end
    end

    # The stored data as a Ruby value. With an empty shape the single element is returned. Otherwise strings and
    # bfloat16 values come back as a flat Array and all other dtypes as a Numo array of the recorded shape.
    def value
      result = case dtype
               when :bfloat16
                 read_array_of_bfloat16(byte_size / BFLOAT16_SIZE)
               when :string
                 read_array_of_string(shape.reduce(1, :*))
               when :bool
                 # Booleans are stored one byte per element
                 read_narray(Numo::Int8).cast_to(Numo::Bit)
               else
                 read_narray(DTYPE_TO_NUMO_TYPE_MAP[dtype])
               end

      shape.empty? ? result[0] : result
    end

    # Allocates memory for +strings+ and writes them as an offset table followed by the encoded strings.
    def write_array_of_string(strings)
      # The start of the data section comes after the offset table
      start_offset_size = strings.size * ::FFI.type_size(:int64)

      # Get the encoded sizes for each string
      encoded_sizes = strings.map { |string| FFI.TF_StringEncodedSize(string.bytesize) }

      # Now figure the offsets. Offsets are relative to the beginning of data section, not the beginning of the
      # pointer. There is exactly one offset per string, so an empty list produces an empty table.
      offsets = []
      encoded_sizes.reduce(0) do |offset, encoded_size|
        offsets << offset
        offset + encoded_size
      end

      # Allocate the needed memory
      @byte_size = start_offset_size + encoded_sizes.sum
      @pointer = self.class.ruby_xmalloc(@byte_size)

      # Write the offsets
      pointer.write_array_of_uint64(offsets)

      # Write the strings
      strings.zip(offsets, encoded_sizes).each do |string, offset, encoded_size|
        Status.check do |status|
          FFI.TF_StringEncode(string, string.bytesize, pointer + start_offset_size + offset, encoded_size, status)
        end
      end
    end

    # Allocates memory for a Numo array and copies its contents in. The recorded shape becomes the narray's shape.
    def write_narray(new_value)
      @shape = new_value.shape

      case dtype
      when :string
        write_array_of_string(new_value.flatten.to_a)
      when :bool
        # Booleans are stored one byte per element
        store_narray(new_value.cast_to(Numo::Int8))
      else
        if NUMO_TYPE_TO_DTYPE_MAP[new_value.class] != dtype
          # Cast the narray if necessary (say the user passed in float but we have a double array)
          new_value = new_value.cast_to(DTYPE_TO_NUMO_TYPE_MAP[dtype])
        end
        store_narray(new_value)
      end
    end

    # Allocates memory for a single value and writes it. Nothing is allocated for :resource and :variant.
    # Raises RuntimeError for dtypes that cannot be written.
    def write_scalar(new_value)
      case dtype
      when :string
        write_array_of_string([new_value])
      when :resource, :variant
        self
      else
        @byte_size = self.class.type_size(dtype)
        @pointer = self.class.ruby_xmalloc(@byte_size)
        case dtype
        when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
          pointer.send("write_#{dtype}", new_value)
        when :bfloat16
          # Keep the upper 16 bits of the single-precision representation (truncation)
          pointer.write_uint16([new_value].pack("f").unpack1("L") >> 16)
        when :complex64
          pointer.write_array_of_float([new_value.real, new_value.imaginary])
        when :complex128
          pointer.write_array_of_double([new_value.real, new_value.imaginary])
        when :bool
          pointer.write_int8(new_value ? 1 : 0)
        else
          # Nothing was written, so give the memory back before reporting the error
          self.class.ruby_xfree(@pointer)
          @pointer = nil
          raise "Unsupported tensor data type: #{dtype}"
        end
      end
    end

    private

    # Reads all of the data into a Numo array of +numo_klass+, shaped like #shape unless the shape is empty.
    def read_narray(numo_klass)
      bytes = pointer.read_bytes(byte_size)
      shape.empty? ? numo_klass.from_binary(bytes) : numo_klass.from_binary(bytes, shape)
    end

    # Allocates memory for +narray+ and copies its binary representation in.
    def store_narray(narray)
      @byte_size = narray.byte_size
      @pointer = self.class.ruby_xmalloc(@byte_size)
      pointer.write_bytes(narray.to_binary)
    end
  end
end