tensorflow-ruby 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,90 @@
1
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
2
+ # source: stream_executor/dnn.proto
3
+
4
+ require 'google/protobuf'
5
+
6
# Registers the messages and enums declared in stream_executor/dnn.proto (proto3) with the
# generated descriptor pool. Fields that reference an enum name it by its fully qualified
# proto name (e.g. "stream_executor.dnn.DataType").
Google::Protobuf::DescriptorPool.generated_pool.build do
  add_file("stream_executor/dnn.proto", :syntax => :proto3) do
    add_message "stream_executor.dnn.TensorDescriptorProto" do
      repeated :dimensions, :int64, 1
      optional :data_type, :enum, 2, "stream_executor.dnn.DataType"
      # Members of a oneof are mutually exclusive: either a data layout or a filter layout.
      oneof :layout_oneof do
        optional :data_layout, :enum, 3, "stream_executor.dnn.DataLayout"
        optional :filter_layout, :enum, 4, "stream_executor.dnn.FilterLayout"
      end
    end
    add_message "stream_executor.dnn.AlgorithmProto" do
      optional :algo_id, :int64, 1
      optional :math_type, :enum, 2, "stream_executor.dnn.AlgorithmProto.MathType"
    end
    # Enum nested inside AlgorithmProto.
    add_enum "stream_executor.dnn.AlgorithmProto.MathType" do
      value :DEFAULT_MATH, 0
      value :TENSOR_OP_MATH, 1
    end
    add_message "stream_executor.dnn.ConvolutionDescriptorProto" do
      repeated :paddings, :int64, 1
      repeated :strides, :int64, 2
      repeated :dilations, :int64, 3
      # compute_mode reuses the DataType enum.
      optional :compute_mode, :enum, 4, "stream_executor.dnn.DataType"
      optional :group_count, :int32, 5
      optional :convolution_mode, :enum, 6, "stream_executor.dnn.ConvolutionMode"
      optional :name, :string, 7
    end
    add_enum "stream_executor.dnn.DataType" do
      value :kFloat, 0
      value :kDouble, 1
      value :kHalf, 2
      value :kInt8, 3
      value :kInt32, 4
    end
    add_enum "stream_executor.dnn.DataLayout" do
      value :kYXDepthBatch, 0
      value :kYXBatchDepth, 1
      value :kBatchYXDepth, 2
      value :kBatchDepthYX, 3
      value :kBatchDepthYX4, 4
    end
    add_enum "stream_executor.dnn.FilterLayout" do
      value :kOutputInputYX, 0
      value :kOutputYXInput, 1
      value :kOutputInputYX4, 2
      value :kInputYXOutput, 3
      value :kYXInputOutput, 4
    end
    add_enum "stream_executor.dnn.ActivationMode" do
      value :kNone, 0
      value :kSigmoid, 1
      value :kRelu, 2
      value :kRelu6, 3
      value :kReluX, 4
      value :kTanh, 5
      value :kBandPass, 6
    end
    add_enum "stream_executor.dnn.ConvolutionMode" do
      value :CROSS_CORRELATION, 0
      value :CONVOLUTION, 1
    end
    add_enum "stream_executor.dnn.ConvolutionKind" do
      value :INVALID, 0
      value :FORWARD, 1
      value :BACKWARD_FILTER, 2
      value :BACKWARD_DATA, 3
      value :FORWARD_BIAS_ACTIVATION, 4
    end
  end
end

# Ruby constants for the descriptors registered in the generated pool: message classes come
# from #msgclass and enum modules from #enummodule, each looked up by fully qualified name.
module StreamExecutor
  module Dnn
    TensorDescriptorProto = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.TensorDescriptorProto").msgclass
    AlgorithmProto = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.AlgorithmProto").msgclass
    AlgorithmProto::MathType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.AlgorithmProto.MathType").enummodule
    ConvolutionDescriptorProto = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.ConvolutionDescriptorProto").msgclass
    DataType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.DataType").enummodule
    DataLayout = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.DataLayout").enummodule
    FilterLayout = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.FilterLayout").enummodule
    ActivationMode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.ActivationMode").enummodule
    ConvolutionMode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.ConvolutionMode").enummodule
    ConvolutionKind = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.ConvolutionKind").enummodule
  end
end
@@ -0,0 +1,100 @@
1
module Tensorflow
  # Ruby counterparts of the tf.strings.* functions. Every implemented method forwards its
  # arguments to the matching RawOps operation; the commented-out stubs mark functions that
  # are not implemented yet.
  module Strings
    class << self
      # Converts each element of +input+ to a string (RawOps.as_string).
      def as_string(input, precision: nil, scientific: nil, shortest: nil, width: nil, fill: nil)
        RawOps.as_string(input, precision: precision, scientific: scientific, shortest: shortest, width: width, fill: fill)
      end

      # def bytes_split
      # end

      # def format
      # end

      # Joins +inputs+ with +separator+ (RawOps.string_join); the op's +n+ attribute is the
      # number of inputs.
      def join(inputs, separator: "")
        RawOps.string_join(inputs, separator: separator, n: inputs.size)
      end

      # Lengths of the strings in +input+, measured in +unit+ (RawOps.string_length).
      def length(input, unit: "BYTE")
        RawOps.string_length(input, unit: unit)
      end

      # Lowercases +input+ (RawOps.string_lower).
      #
      # @param input a Ruby String, or any other value the op accepts (e.g. a tensor)
      # @param encoding [String] encoding attribute passed to the op; defaults to 'utf-8',
      #   the value this method has always used (compare #upper, which exposes the same keyword)
      def lower(input, encoding: 'utf-8')
        # Only Ruby Strings can be transcoded here. Other inputs (tensors, arrays) have no
        # #encode, so they are forwarded unchanged instead of raising NoMethodError.
        input = input.encode('UTF-8') if input.is_a?(String)
        RawOps.string_lower(input, encoding: encoding)
      end

      # def ngrams
      # end

      # Joins strings across +reduction_indices+ (RawOps.reduce_join).
      def reduce_join(inputs, reduction_indices, keep_dims: nil, separator: nil)
        RawOps.reduce_join(inputs, reduction_indices: reduction_indices, keep_dims: keep_dims, separator: separator)
      end

      # Whether each element of +input+ fully matches +pattern+ (RawOps.regex_full_match).
      def regex_full_match(input, pattern)
        RawOps.regex_full_match(input, pattern: pattern)
      end

      # Replaces matches of +pattern+ in +input+ with +rewrite+ (RawOps.regex_replace).
      def regex_replace(input, pattern, rewrite, replace_global: nil)
        RawOps.regex_replace(input, pattern: pattern, rewrite: rewrite, replace_global: replace_global)
      end

      # NOTE(review): forwards to RawOps.split with split_dim/value/num_split arguments, which
      # looks like the tensor Split op rather than a string-splitting op (e.g. StringSplitV2) —
      # confirm this is intended for Strings.split.
      def split(split_dim, value, num_split: nil)
        RawOps.split(split_dim: split_dim, value: value, num_split: num_split)
      end

      # Strips leading and trailing whitespace (RawOps.string_strip).
      def strip(input)
        RawOps.string_strip(input)
      end

      # Substrings of +input+ starting at +pos+ with length +len+ (RawOps.substr).
      def substr(input, pos, len, unit: nil)
        RawOps.substr(input, pos, len, unit: unit)
      end

      # def to_hash_bucket
      # end

      # def to_hash_bucket_fast
      # end

      # def to_hash_bucket_strong
      # end

      # Parses strings as numbers of +out_type+ (RawOps.string_to_number).
      def to_number(input, out_type: :float)
        RawOps.string_to_number(input, out_type: out_type)
      end

      # Decodes strings into Unicode code points (RawOps.unicode_decode).
      def unicode_decode(input, input_encoding: nil, errors: nil, replacement_char: nil, replace_control_characters: nil)
        RawOps.unicode_decode(input, input_encoding: input_encoding, errors: errors, replacement_char: replacement_char, replace_control_characters: replace_control_characters)
      end

      # Like #unicode_decode, also returning offsets (RawOps.unicode_decode_with_offsets).
      def unicode_decode_with_offsets(input, input_encoding: nil, errors: nil, replacement_char: nil, replace_control_characters: nil)
        RawOps.unicode_decode_with_offsets(input, input_encoding: input_encoding, errors: errors, replacement_char: replacement_char, replace_control_characters: replace_control_characters)
      end

      # Encodes code points back into strings (RawOps.unicode_encode).
      def unicode_encode(input_values, input_splits, errors: nil, output_encoding: nil, replacement_char: nil)
        RawOps.unicode_encode(input_values: input_values, input_splits: input_splits, errors: errors, output_encoding: output_encoding, replacement_char: replacement_char)
      end

      # Unicode script codes of the code points in +input+ (RawOps.unicode_script).
      def unicode_script(input)
        RawOps.unicode_script(input)
      end

      # def unicode_split
      # end

      # def unicode_split_with_offsets
      # end

      # Transcodes strings between encodings (RawOps.unicode_transcode).
      def unicode_transcode(input, input_encoding: nil, output_encoding: nil, errors: nil, replacement_char: nil, replace_control_characters: nil)
        RawOps.unicode_transcode(input, input_encoding: input_encoding, output_encoding: output_encoding, errors: errors, replacement_char: replacement_char, replace_control_characters: replace_control_characters)
      end

      # def unsorted_segment_join
      # end

      # Uppercases +input+ (RawOps.string_upper).
      def upper(input, encoding: "")
        RawOps.string_upper(input, encoding: encoding)
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Tensorflow
  # Entry points for creating summary writers and reading the summary op collection.
  class Summary
    class << self
      # Creates a ResourceSummaryWriter whose resource is initialised by the
      # CreateSummaryFileWriter raw op.
      #
      # @param logdir [String] directory handed to the op
      # @param max_queue [Integer] forwarded to the op (default 10)
      # @param flush_millis [Integer] forwarded to the op (default 120_000)
      # @param filename_suffix [String] forwarded to the op (default '.v2')
      # @param name [String, nil] shared name of the writer resource
      # @return [ResourceSummaryWriter]
      def create_file_writer(logdir, max_queue: 10, flush_millis: 120_000, filename_suffix: '.v2', name: nil)
        op_args = [logdir, max_queue, flush_millis, filename_suffix]
        ResourceSummaryWriter.new(shared_name: name) do |resource|
          RawOps.create_summary_file_writer(resource, *op_args)
        end
      end

      # @return the collection registered under Graph::GraphKeys::SUMMARY_COLLECTION in the
      #   current execution context
      def all_v2_summary_ops
        context = ExecutionContext.current
        context.get_collection_ref(Graph::GraphKeys::SUMMARY_COLLECTION)
      end
    end
  end
end
@@ -0,0 +1,133 @@
1
module Tensorflow
  # Ruby wrapper around a TF_Tensor created through the TensorFlow C API. The native tensor is
  # deleted with TF_DeleteTensor when this object is garbage collected.
  class Tensor
    include Operators
    include TensorMixin

    # Returns the finalizer proc that deletes the native tensor at +pointer+.
    # It is built in a class method so the proc does not close over the Tensor instance; a
    # finalizer that references its own object would keep that object alive forever.
    def self.finalize(pointer)
      proc do
        FFI.TF_DeleteTensor(pointer)
      end
    end

    # Coerces +value+ into something usable as a tensor.
    #
    # Tensors and Graph::Operations are returned as-is, an Eager::TensorHandle yields its #tensor,
    # a Data::Dataset its #variant_tensor; anything else is wrapped in a new Tensor with +dtype+.
    def self.from_value(value, dtype: nil)
      case value
      when Tensor
        value
      when Graph::Operation
        value
      when Eager::TensorHandle
        value.tensor
      when Data::Dataset
        value.variant_tensor
      else
        Tensor.new(value, dtype: dtype)
      end
    end

    # Builds a Tensor from a TensorProto or its serialized bytes.
    #
    # Only proto.tensor_content is read.
    # NOTE(review): protos that keep their values in the typed repeated fields (float_val,
    # int_val, ...) instead of tensor_content are not handled — confirm callers never pass those.
    #
    # @raise [Error::InvalidArgumentError] if the proto's dtype has no Numo equivalent
    def self.from_proto(proto)
      proto = proto.is_a?(TensorProto) ? proto : TensorProto.decode(proto)
      shape = proto.tensor_shape.dim.map(&:size)
      dtype = FFI::DataType[DataType.resolve(proto.dtype)]
      numo_klass = TensorData::DTYPE_TO_NUMO_TYPE_MAP.fetch(dtype) do
        raise(Error::InvalidArgumentError, "Unsupported tensor proto dtype: #{dtype}")
      end
      value = if shape.empty?
                # Scalar: decode the content as a flat array and unwrap its single element
                numo_klass.from_binary(proto.tensor_content)[0]
              else
                numo_klass.from_binary(proto.tensor_content, shape)
              end
      self.new(value, dtype: dtype, shape: shape)
    end

    # Wraps an existing TF_Tensor pointer without copying it. The wrapper takes ownership: the
    # native tensor is deleted when the returned object is garbage collected.
    def self.from_pointer(pointer)
      result = self.allocate
      result.instance_variable_set(:@pointer, pointer)
      ObjectSpace.define_finalizer(result, self.finalize(pointer))
      result
    end

    # Creates a native tensor holding +value+.
    #
    # @param value [Numo::NArray, Array, Object] data; Ruby arrays are converted to narrays and
    #   other values are treated as scalars
    # @param dtype [Symbol, nil] element type; inferred by TensorData when nil
    # @param shape [Array<Integer>] shape used to expand a scalar +value+ (empty for a scalar tensor)
    def initialize(value, dtype: nil, shape: [])
      value = case value
              when Numo::NArray
                value
              when Array
                # Convert Ruby arrays to narrays; this makes it a lot easier to support multidimensional arrays
                Numo::NArray.cast(value)
              else
                # With a non-empty +shape+, the scalar is expanded into an narray of that shape filled with it
                TensorData.value_with_shape(value, dtype, shape)
              end

      tensor_data = TensorData.new(value, dtype: dtype, shape: shape)
      dtype = tensor_data.dtype
      shape = tensor_data.shape

      # TF_NewTensor takes the dimensions as an int64 array, or NULL for a scalar
      dims_ptr = nil
      if shape && shape.size > 0
        dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
        dims_ptr.write_array_of_int64(shape)
      end

      # TensorFlow takes over tensor_data's buffer and releases it via TensorData::Deallocator
      @pointer = FFI.TF_NewTensor(dtype,
                                  dims_ptr, shape ? shape.size : 0,
                                  tensor_data, tensor_data.byte_size,
                                  TensorData::Deallocator, nil)

      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # Ruby value of the tensor's contents (see TensorData#value).
    def value
      self.data.value
    end

    # Element type reported by TF_TensorType.
    def dtype
      FFI.TF_TensorType(self)
    end

    def to_s
      inspect
    end

    # Native TF_Tensor pointer; lets FFI accept this object wherever a pointer is expected.
    def to_ptr
      @pointer
    end

    # Size of the tensor's data buffer in bytes (TF_TensorByteSize).
    def byte_size
      FFI.TF_TensorByteSize(self)
    end

    def inspect
      inspection = %w(numo shape dtype).map { |v| "#{v}: #{send(v).inspect}" }
      "#<#{self.class} #{inspection.join(", ")}>"
    end

    # TensorData wrapping this tensor's buffer (TF_TensorData) without copying it.
    def data
      TensorData.from_pointer(FFI.TF_TensorData(self), self.byte_size, self.dtype, self.shape)
    end

    private

    def num_dims
      FFI.TF_NumDims(self)
    end

    def dim(index)
      FFI.TF_Dim(self, index)
    end

    def element_count
      FFI.TF_TensorElementCount(self)
    end

    # Shape of +value+: its own #shape if it has one, otherwise inferred from a nested Ruby
    # Array by following the first element at each level (assumes rectangular nesting).
    def calculate_shape(value)
      return value.shape if value.respond_to?(:shape)

      shape = []
      d = value
      while d.is_a?(Array)
        shape << d.size
        d = d.first
      end
      shape
    end
  end
end
@@ -0,0 +1,310 @@
1
+ require 'rbconfig'
2
module Tensorflow
  # TensorFlow expects client libraries to allocate the memory that a tensor wraps. When the tensor is released,
  # TensorFlow notifies the client through a callback, giving the client a chance to free that memory.
  #
  # We don't use FFI::MemoryPointer because it is garbage collected: if Ruby freed the buffer before TensorFlow
  # released the tensor, the tensor would be left pointing at freed memory (this can happen even if a Ruby
  # tensor object keeps a reference to the pointer at GC time).
  #
  # Thus this class allocates its own memory with ruby_xmalloc and frees it only when TensorFlow calls back.
  class TensorData
    extend ::FFI::Library
    ffi_lib "#{RbConfig::CONFIG['RUBY_SO_NAME']}.#{RbConfig::CONFIG['SOEXT']}"
    attach_function :ruby_xmalloc, [:size_t], :pointer
    attach_function :ruby_xfree, [:pointer], :void

    attr_reader :pointer, :byte_size, :dtype, :shape

    # Pointer arithmetic (+) and every FFI::Pointer#write_* method are forwarded to the buffer pointer.
    # read_* methods are NOT forwarded, so readers below go through #pointer explicitly.
    extend Forwardable
    def_delegators :to_ptr, :+, *::FFI::Pointer.instance_methods.grep(/^write_/)

    # Stored in a class constant so the callback is never garbage collected. TensorFlow invokes it
    # when the tensor's memory should be freed.
    Deallocator = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |data, _len, _arg|
      ruby_xfree(data)
    end

    DTYPE_TO_NUMO_TYPE_MAP = {bool: Numo::Bit,
                              complex64: Numo::SComplex,
                              complex128: Numo::DComplex,
                              double: Numo::DFloat,
                              float: Numo::SFloat,
                              int8: Numo::Int8,
                              int16: Numo::Int16,
                              int32: Numo::Int32,
                              int64: Numo::Int64,
                              uint8: Numo::UInt8,
                              uint16: Numo::UInt16,
                              uint32: Numo::UInt32,
                              uint64: Numo::UInt64}

    NUMO_TYPE_TO_DTYPE_MAP = DTYPE_TO_NUMO_TYPE_MAP.invert

    # Largest finite single-precision float (FLT_MAX). Values of greater magnitude need double precision.
    FLOAT32_MAX = 3.402823466e38

    # Scalar dtypes written with the matching FFI::Pointer#write_<dtype> method.
    NATIVE_SCALAR_DTYPES = [:float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64].freeze

    private_constant :FLOAT32_MAX, :NATIVE_SCALAR_DTYPES

    # Infers the TensorFlow dtype of +value+.
    #
    # Integers pick :int32 when they fit in 32 bits, otherwise :int64. Floats (and complex numbers,
    # judged on both components) pick single precision when their magnitude is below FLT_MAX.
    #
    # @raise [Error::InvalidArgumentError] for values of an unsupported class
    def self.figure_dtype(value)
      case value
      when Numo::RObject
        # Object arrays carry no element type; look at the first element to see what it is
        self.figure_dtype(value[0])
      when Numo::NArray
        NUMO_TYPE_TO_DTYPE_MAP[value.class]
      when Array
        self.figure_dtype(value.first)
      when Integer
        (value >= -2147483648 && value <= 2147483647) ? :int32 : :int64
      when Complex
        (value.real.abs < FLOAT32_MAX && value.imaginary.abs < FLOAT32_MAX) ? :complex64 : :complex128
      when Numeric
        # NaN and infinities fail this comparison and therefore map to :double
        value.abs < FLOAT32_MAX ? :float : :double
      when String
        :string
      when TrueClass, FalseClass
        :bool
      when ::FFI::Pointer
        :pointer
      when Tensor
        value.dtype
      when Variable
        value.dtype
      when Graph::Operation
        nil
      when Eager::TensorHandle
        value.dtype
      when Google::Protobuf::MessageExts
        :string
      else
        raise(Error::InvalidArgumentError, "Unsupported type: #{value.class}")
      end
    end

    # Size in bytes of one element of +dtype+.
    def self.type_size(dtype)
      case dtype
      when :complex64
        ::FFI.type_size(:float) * 2
      when :complex128
        ::FFI.type_size(:double) * 2
      when :bfloat16
        # FFI has no bfloat16 type; each element is stored as 2 raw bytes
        2
      else
        ::FFI.type_size(dtype)
      end
    end

    # Converts +float+ to bfloat16 bits (the upper 16 bits of its IEEE float32 pattern), rounding to
    # nearest even. NaN maps to the canonical quiet NaN 0x7FC0 so rounding cannot turn it into infinity.
    def self.float_to_bfloat16_bits(float)
      return 0x7FC0 if float.to_f.nan?

      bits = [float].pack("f").unpack1("L")
      ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF
    end

    # Expands a scalar +value+ into a Numo array of +shape+ filled with it. Returns +value+ unchanged
    # when +shape+ is nil or empty.
    def self.value_with_shape(value, dtype, shape)
      if shape && shape.size > 0
        dtype ||= self.figure_dtype(value)
        numo_klass = DTYPE_TO_NUMO_TYPE_MAP[dtype]
        numo_klass.new(shape).fill(value)
      else
        value
      end
    end

    # Wraps an existing buffer (e.g. TF_TensorData of a tensor) without allocating or copying.
    def self.from_pointer(pointer, byte_size, dtype, shape)
      result = self.allocate
      result.instance_variable_set(:@pointer, pointer)
      result.instance_variable_set(:@byte_size, byte_size)
      result.instance_variable_set(:@dtype, dtype)
      result.instance_variable_set(:@shape, shape)
      result
    end

    # Allocates a buffer and writes +value+ into it.
    #
    # @param value [Numo::NArray, Google::Protobuf::MessageExts, Object] data to copy; protobuf messages
    #   are stored as their encoded bytes, anything else is written as a scalar
    # @param dtype [Symbol, nil] element type; inferred with .figure_dtype when nil
    # @param shape [Array<Integer>] shape; replaced by the narray's shape for Numo input
    # @raise [Error::InvalidArgumentError] for plain Ruby Arrays
    def initialize(value, dtype: nil, shape: [])
      @dtype = dtype || self.class.figure_dtype(value)
      @shape = shape
      case value
      when Numo::NArray
        self.write_narray(value)
      when Array
        raise(Error::InvalidArgumentError, "TensorData does not support Arrays. Please use a Numo::NArray")
      when Google::Protobuf::MessageExts
        encoded = value.class.encode(value)
        self.write_array_of_string([encoded])
      else
        self.write_scalar(value)
      end
    end

    def to_ptr
      @pointer
    end

    # Reads +count+ complex64 values (pairs of float32 real/imaginary parts).
    def read_array_of_complex64(count)
      values = self.pointer.read_array_of_float(2 * count)
      values.each_slice(2).map do |real, imaginary|
        Complex(real, imaginary)
      end
    end

    # Reads +count+ complex128 values (pairs of float64 real/imaginary parts).
    def read_array_of_complex128(count)
      values = self.pointer.read_array_of_double(2 * count)
      values.each_slice(2).map do |real, imaginary|
        Complex(real, imaginary)
      end
    end

    # Decodes +count+ strings laid out as a uint64 offset table followed by a data section of
    # TF_StringEncode'd strings (the layout written by #write_array_of_string).
    def read_array_of_string(count)
      # The data section starts after the offset table
      start_offset_size = count * ::FFI.type_size(:int64)
      data_size = self.byte_size - start_offset_size

      # Offsets are relative to the start of the data section
      offsets = self.pointer.read_array_of_uint64(count)

      offsets.map.with_index do |offset, index|
        # Each string ends where the next begins; the last one ends at the end of the data section
        src_bytes = (offsets[index + 1] || data_size) - offset
        dst_ptr = ::FFI::MemoryPointer.new(:pointer)
        dst_len_ptr = ::FFI::MemoryPointer.new(:size_t)
        Status.check do |status|
          FFI.TF_StringDecode(self.pointer + start_offset_size + offset, src_bytes, dst_ptr, dst_len_ptr, status)
        end
        string_pointer = dst_ptr.read_pointer
        string_length = dst_len_ptr.read(:size_t)
        string_pointer.read_string(string_length)
      end
    end

    # Decodes the buffer into Ruby values.
    #
    # Numeric and bool data come back as Numo arrays shaped like the tensor, strings as a flat Array.
    # A scalar (empty shape) returns its single element instead of an array.
    #
    # @raise [RuntimeError] for dtypes with no decoder (e.g. :resource, :variant)
    def value
      result = case self.dtype
               when :bfloat16
                 read_bfloat16
               when :string
                 self.read_array_of_string(self.shape.reduce(1, :*))
               when :bool
                 read_numo(Numo::Int8).cast_to(Numo::Bit)
               else
                 numo_klass = DTYPE_TO_NUMO_TYPE_MAP[self.dtype]
                 raise "Unsupported tensor data type: #{self.dtype}" unless numo_klass
                 read_numo(numo_klass)
               end

      self.shape.empty? ? result[0] : result
    end

    # Allocates a buffer and writes +strings+ in TensorFlow's string tensor layout: a uint64 offset per
    # string, followed by the TF_StringEncode'd strings.
    def write_array_of_string(strings)
      # The data section starts after the offset table
      start_offset_size = strings.size * ::FFI.type_size(:int64)

      # Encoded size of each string
      encoded_sizes = strings.map do |string|
        FFI.TF_StringEncodedSize(string.bytesize)
      end

      # Offsets are relative to the beginning of the data section, not the beginning of the pointer:
      # each is the total encoded size of the strings before it. Exactly one offset per string, so an
      # empty input writes no offsets at all.
      running_total = 0
      offsets = encoded_sizes.map do |size|
        offset = running_total
        running_total += size
        offset
      end

      # Allocate the needed memory
      @byte_size = start_offset_size + encoded_sizes.sum
      @pointer = self.class.ruby_xmalloc(@byte_size)

      # Write the offsets
      self.pointer.write_array_of_uint64(offsets)

      # Write the strings
      strings.each_with_index do |string, index|
        offset = offsets[index]
        size = encoded_sizes[index]
        Status.check do |status|
          FFI.TF_StringEncode(string, string.bytesize, self.pointer + start_offset_size + offset, size, status)
        end
      end
    end

    # Copies +new_value+ into a new buffer, casting it to this object's dtype first if needed.
    # Also adopts the narray's shape.
    def write_narray(new_value)
      @shape = new_value.shape

      case self.dtype
      when :string
        self.write_array_of_string(new_value.flatten.to_a)
      when :bool
        # Stored as one byte per value, matching #write_scalar and #value
        copy_narray(new_value.cast_to(Numo::Int8))
      else
        if NUMO_TYPE_TO_DTYPE_MAP[new_value.class] != dtype
          # Cast the narray if necessary (say the user passed in float but we have a double array)
          new_value = new_value.cast_to(DTYPE_TO_NUMO_TYPE_MAP[dtype])
        end
        copy_narray(new_value)
      end
    end

    # Allocates a buffer for a single element and writes +new_value+ into it.
    #
    # :resource and :variant allocate nothing.
    #
    # @raise [RuntimeError] for dtypes with no scalar writer
    def write_scalar(new_value)
      case self.dtype
      when :string
        self.write_array_of_string([new_value])
      when :resource, :variant
        return self
      else
        @byte_size = self.class.type_size(self.dtype)
        @pointer = self.class.ruby_xmalloc(@byte_size)
        case self.dtype
        when *NATIVE_SCALAR_DTYPES
          self.pointer.send("write_#{self.dtype}", new_value)
        when :bfloat16
          self.pointer.write_uint16(self.class.float_to_bfloat16_bits(new_value))
        when :complex64
          self.pointer.write_array_of_float([new_value.real, new_value.imaginary])
        when :complex128
          self.pointer.write_array_of_double([new_value.real, new_value.imaginary])
        when :bool
          self.pointer.write_int8(new_value ? 1 : 0)
        else
          # TensorFlow will never own this buffer, so release it before failing
          self.class.ruby_xfree(@pointer)
          @pointer = nil
          raise "Unsupported tensor data type: #{self.dtype}"
        end
      end
    end

    private

    # Reads the whole buffer as +numo_klass+ elements, shaped like the tensor (flat for a scalar).
    def read_numo(numo_klass)
      bytes = self.pointer.read_bytes(self.byte_size)
      self.shape.empty? ? numo_klass.from_binary(bytes) : numo_klass.from_binary(bytes, self.shape)
    end

    # Decodes bfloat16 elements into a Numo::SFloat shaped like the tensor. Each element is the upper
    # half of a float32 bit pattern, so it is widened by shifting left 16 bits (native byte order).
    def read_bfloat16
      values = self.pointer.read_bytes(self.byte_size).unpack("S*").map do |bits|
        [bits << 16].pack("L").unpack1("f")
      end
      array = Numo::SFloat.cast(values)
      self.shape.empty? ? array : array.reshape(*self.shape)
    end

    # Copies +narray+'s raw bytes into a freshly ruby_xmalloc'ed buffer owned by this object.
    def copy_narray(narray)
      @byte_size = narray.byte_size
      @pointer = self.class.ruby_xmalloc(@byte_size)
      self.pointer.write_bytes(narray.to_binary)
    end
  end
end