tensorflow-ruby 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
@@ -0,0 +1,90 @@
|
|
1
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
2
|
+
# source: stream_executor/dnn.proto
|
3
|
+
|
4
|
+
require 'google/protobuf'
|
5
|
+
|
6
|
+
# Registers every message and enum declared in stream_executor/dnn.proto with the
# process-wide generated descriptor pool, keyed by fully qualified name
# ("stream_executor.dnn.*"). Ruby classes/modules for these types are obtained by
# looking those names up in the same pool.
# NOTE: this is protoc output (see the file header); regenerate it from the .proto
# instead of editing it by hand.
Google::Protobuf::DescriptorPool.generated_pool.build do
  add_file("stream_executor/dnn.proto", :syntax => :proto3) do
    # Tensor description: dimensions, element type and at most one layout.
    add_message "stream_executor.dnn.TensorDescriptorProto" do
      repeated :dimensions, :int64, 1
      optional :data_type, :enum, 2, "stream_executor.dnn.DataType"
      # Protobuf oneof: setting one of these layout fields clears the other.
      oneof :layout_oneof do
        optional :data_layout, :enum, 3, "stream_executor.dnn.DataLayout"
        optional :filter_layout, :enum, 4, "stream_executor.dnn.FilterLayout"
      end
    end
    add_message "stream_executor.dnn.AlgorithmProto" do
      optional :algo_id, :int64, 1
      optional :math_type, :enum, 2, "stream_executor.dnn.AlgorithmProto.MathType"
    end
    # Nested enum, exposed in Ruby as AlgorithmProto::MathType.
    add_enum "stream_executor.dnn.AlgorithmProto.MathType" do
      value :DEFAULT_MATH, 0
      value :TENSOR_OP_MATH, 1
    end
    add_message "stream_executor.dnn.ConvolutionDescriptorProto" do
      repeated :paddings, :int64, 1
      repeated :strides, :int64, 2
      repeated :dilations, :int64, 3
      # compute_mode is typed by the DataType enum, not by a dedicated enum.
      optional :compute_mode, :enum, 4, "stream_executor.dnn.DataType"
      optional :group_count, :int32, 5
      optional :convolution_mode, :enum, 6, "stream_executor.dnn.ConvolutionMode"
      optional :name, :string, 7
    end
    add_enum "stream_executor.dnn.DataType" do
      value :kFloat, 0
      value :kDouble, 1
      value :kHalf, 2
      value :kInt8, 3
      value :kInt32, 4
    end
    add_enum "stream_executor.dnn.DataLayout" do
      value :kYXDepthBatch, 0
      value :kYXBatchDepth, 1
      value :kBatchYXDepth, 2
      value :kBatchDepthYX, 3
      value :kBatchDepthYX4, 4
    end
    add_enum "stream_executor.dnn.FilterLayout" do
      value :kOutputInputYX, 0
      value :kOutputYXInput, 1
      value :kOutputInputYX4, 2
      value :kInputYXOutput, 3
      value :kYXInputOutput, 4
    end
    add_enum "stream_executor.dnn.ActivationMode" do
      value :kNone, 0
      value :kSigmoid, 1
      value :kRelu, 2
      value :kRelu6, 3
      value :kReluX, 4
      value :kTanh, 5
      value :kBandPass, 6
    end
    add_enum "stream_executor.dnn.ConvolutionMode" do
      value :CROSS_CORRELATION, 0
      value :CONVOLUTION, 1
    end
    add_enum "stream_executor.dnn.ConvolutionKind" do
      value :INVALID, 0
      value :FORWARD, 1
      value :BACKWARD_FILTER, 2
      value :BACKWARD_DATA, 3
      value :FORWARD_BIAS_ACTIVATION, 4
    end
  end
end
|
76
|
+
|
77
|
+
# Ruby handles for the stream_executor.dnn types, resolved by name from the generated
# descriptor pool. Messages are bound via #msgclass (instantiable classes); enums via
# #enummodule (modules of named integer constants).
# NOTE: generated by protoc; note the namespace is StreamExecutor::Dnn, not Tensorflow.
module StreamExecutor
  module Dnn
    TensorDescriptorProto = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.TensorDescriptorProto").msgclass
    AlgorithmProto = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.AlgorithmProto").msgclass
    # Nested enum attached to the message class it belongs to.
    AlgorithmProto::MathType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.AlgorithmProto.MathType").enummodule
    ConvolutionDescriptorProto = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.ConvolutionDescriptorProto").msgclass
    DataType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.DataType").enummodule
    DataLayout = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.DataLayout").enummodule
    FilterLayout = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.FilterLayout").enummodule
    ActivationMode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.ActivationMode").enummodule
    ConvolutionMode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.ConvolutionMode").enummodule
    ConvolutionKind = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.ConvolutionKind").enummodule
  end
end
|
@@ -0,0 +1,100 @@
|
|
1
|
+
module Tensorflow
  # Wrappers for TensorFlow's string ops. Every method forwards its arguments to the matching
  # generated RawOps binding; the commented-out stubs mark ops that are not wrapped yet.
  module Strings
    class << self
      # Converts each element of +input+ to a string via RawOps.as_string.
      # All formatting options are forwarded as given (nil when omitted).
      def as_string(input, precision: nil, scientific: nil, shortest: nil, width: nil, fill: nil)
        RawOps.as_string(input, precision: precision, scientific: scientific, shortest: shortest, width: width, fill: fill)
      end

      # def bytes_split
      # end

      # def format
      # end

      # Joins +inputs+ with +separator+ via RawOps.string_join; +n+ is set to the number of inputs.
      def join(inputs, separator: "")
        RawOps.string_join(inputs, separator: separator, n: inputs.size)
      end

      # String length via RawOps.string_length.
      #
      # @param unit [String] forwarded as-is; defaults to "BYTE"
      def length(input, unit: "BYTE")
        RawOps.string_length(input, unit: unit)
      end

      # Lowercases +input+ via RawOps.string_lower.
      #
      # @param input [String, Object] a Ruby String is transcoded to UTF-8 before being forwarded;
      #   any other value (Array, tensor, ...) is forwarded unchanged. Calling #encode on every
      #   input made this method raise NoMethodError for anything that was not a String.
      # @param encoding [String] forwarded to the op; defaults to "utf-8", the value that used to
      #   be hard-coded (mirrors the +encoding:+ keyword accepted by #upper)
      def lower(input, encoding: 'utf-8')
        input = input.encode('UTF-8') if input.is_a?(::String)
        RawOps.string_lower(input, encoding: encoding)
      end

      # def ngrams
      # end

      # Joins strings across +reduction_indices+ via RawOps.reduce_join.
      def reduce_join(inputs, reduction_indices, keep_dims: nil, separator: nil)
        RawOps.reduce_join(inputs, reduction_indices: reduction_indices, keep_dims: keep_dims, separator: separator)
      end

      # Matches +input+ against +pattern+ via RawOps.regex_full_match.
      def regex_full_match(input, pattern)
        RawOps.regex_full_match(input, pattern: pattern)
      end

      # Replaces matches of +pattern+ with +rewrite+ via RawOps.regex_replace.
      def regex_replace(input, pattern, rewrite, replace_global: nil)
        RawOps.regex_replace(input, pattern: pattern, rewrite: rewrite, replace_global: replace_global)
      end

      # Forwards to RawOps.split.
      # NOTE(review): the keywords (split_dim, value, num_split) match TensorFlow's generic Split
      # op rather than a string tokenizer — confirm this is the intended op for Strings.split.
      def split(split_dim, value, num_split: nil)
        RawOps.split(split_dim: split_dim, value: value, num_split: num_split)
      end

      # Strips +input+ via RawOps.string_strip.
      def strip(input)
        RawOps.string_strip(input)
      end

      # Substring extraction via RawOps.substr (pos and len are passed positionally).
      def substr(input, pos, len, unit: nil)
        RawOps.substr(input, pos, len, unit: unit)
      end

      # def to_hash_bucket
      # end

      # def to_hash_bucket_fast
      # end

      # def to_hash_bucket_strong
      # end

      # Parses strings into numbers via RawOps.string_to_number.
      #
      # @param out_type [Symbol] result dtype; defaults to :float
      def to_number(input, out_type: :float)
        RawOps.string_to_number(input, out_type: out_type)
      end

      # Decodes +input+ into code points via RawOps.unicode_decode.
      def unicode_decode(input, input_encoding: nil, errors: nil, replacement_char: nil, replace_control_characters: nil)
        RawOps.unicode_decode(input, input_encoding: input_encoding, errors: errors, replacement_char: replacement_char, replace_control_characters: replace_control_characters)
      end

      # Like #unicode_decode, but via RawOps.unicode_decode_with_offsets.
      def unicode_decode_with_offsets(input, input_encoding: nil, errors: nil, replacement_char: nil, replace_control_characters: nil)
        RawOps.unicode_decode_with_offsets(input, input_encoding: input_encoding, errors: errors, replacement_char: replacement_char, replace_control_characters: replace_control_characters)
      end

      # Encodes code points back into strings via RawOps.unicode_encode.
      def unicode_encode(input_values, input_splits, errors: nil, output_encoding: nil, replacement_char: nil)
        RawOps.unicode_encode(input_values: input_values, input_splits: input_splits, errors: errors, output_encoding: output_encoding, replacement_char: replacement_char)
      end

      # Unicode script lookup via RawOps.unicode_script.
      def unicode_script(input)
        RawOps.unicode_script(input)
      end

      # def unicode_split
      # end

      # def unicode_split_with_offsets
      # end

      # Transcodes strings between encodings via RawOps.unicode_transcode.
      def unicode_transcode(input, input_encoding: nil, output_encoding: nil, errors: nil, replacement_char: nil, replace_control_characters: nil)
        RawOps.unicode_transcode(input, input_encoding: input_encoding, output_encoding: output_encoding, errors: errors, replacement_char: replacement_char, replace_control_characters: replace_control_characters)
      end

      # def unsorted_segment_join
      # end

      # Uppercases +input+ via RawOps.string_upper.
      #
      # @param encoding [String] forwarded to the op; defaults to ""
      def upper(input, encoding: "")
        RawOps.string_upper(input, encoding: encoding)
      end
    end
  end
end
|
@@ -0,0 +1,13 @@
|
|
1
|
+
module Tensorflow
  # Class-level helpers for summary writing.
  class Summary
    class << self
      # Builds a ResourceSummaryWriter named +name+. The block handed to the writer receives the
      # writer and calls RawOps.create_summary_file_writer with it and the options below.
      #
      # @param logdir [String] forwarded to the op
      # @param max_queue [Integer] forwarded to the op (default 10)
      # @param flush_millis [Integer] forwarded to the op (default 120_000)
      # @param filename_suffix [String] forwarded to the op (default '.v2')
      # @param name [String, nil] passed to the writer as its +shared_name+
      # @return [ResourceSummaryWriter]
      def create_file_writer(logdir, max_queue: 10, flush_millis: 120_000, filename_suffix: '.v2', name: nil)
        ResourceSummaryWriter.new(shared_name: name) do |resource|
          # Argument order expected by the positional RawOps binding.
          RawOps.create_summary_file_writer(resource, logdir, max_queue, flush_millis, filename_suffix)
        end
      end

      # Returns the collection registered under Graph::GraphKeys::SUMMARY_COLLECTION in the
      # current execution context (fetched with get_collection_ref).
      def all_v2_summary_ops
        context = ExecutionContext.current
        context.get_collection_ref(Graph::GraphKeys::SUMMARY_COLLECTION)
      end
    end
  end
end
|
@@ -0,0 +1,133 @@
|
|
1
|
+
module Tensorflow
  # Ruby wrapper around a TF_Tensor created through the TensorFlow C API. The underlying
  # TF_Tensor is released with TF_DeleteTensor when this object is garbage collected.
  class Tensor
    include Operators
    include TensorMixin

    # Returns the finalizer proc that deletes +pointer+ with TF_DeleteTensor.
    #
    # Built in a class method so the proc captures only the raw pointer and not the Tensor
    # instance; a finalizer that references its own object keeps that object from being collected.
    def self.finalize(pointer)
      proc do
        FFI.TF_DeleteTensor(pointer)
      end
    end

    # Coerces +value+ into a tensor-like object.
    #
    # Tensors and graph operations are returned unchanged, eager tensor handles are unwrapped to
    # their tensor, datasets yield their variant tensor, and any other value is wrapped in a new
    # Tensor (using +dtype+ when given).
    def self.from_value(value, dtype: nil)
      case value
      when Tensor, Graph::Operation
        value
      when Eager::TensorHandle
        value.tensor
      when Data::Dataset
        value.variant_tensor
      else
        Tensor.new(value, dtype: dtype)
      end
    end

    # Builds a Tensor from a TensorProto or its serialized bytes.
    #
    # Only +tensor_content+ is decoded. NOTE(review): protos that carry their values in the typed
    # *_val fields instead of +tensor_content+ are not handled here — confirm callers never pass those.
    #
    # @raise [Error::InvalidArgumentError] if the proto's dtype has no entry in
    #   TensorData::DTYPE_TO_NUMO_TYPE_MAP (e.g. string tensors)
    def self.from_proto(proto)
      proto = TensorProto.decode(proto) unless proto.is_a?(TensorProto)
      shape = proto.tensor_shape.dim.map(&:size)
      dtype = FFI::DataType[DataType.resolve(proto.dtype)]
      numo_klass = TensorData::DTYPE_TO_NUMO_TYPE_MAP[dtype]
      # Without this guard an unmapped dtype surfaced as a NoMethodError on nil.
      raise(Error::InvalidArgumentError, "Unsupported tensor proto dtype: #{dtype.inspect}") unless numo_klass

      value = if shape.empty?
                # Scalars: decode a one-element array and unwrap it
                numo_klass.from_binary(proto.tensor_content)[0]
              else
                numo_klass.from_binary(proto.tensor_content, shape)
              end
      self.new(value, dtype: dtype, shape: shape)
    end

    # Wraps an existing TF_Tensor +pointer+ without copying its data. The TF_Tensor is deleted
    # when the returned wrapper is garbage collected.
    def self.from_pointer(pointer)
      result = self.allocate
      result.instance_variable_set(:@pointer, pointer)
      ObjectSpace.define_finalizer(result, self.finalize(pointer))
      result
    end

    # Allocates a new TF_Tensor holding +value+.
    #
    # @param value [Numo::NArray, Array, Object] Arrays are converted to a Numo::NArray first; other
    #   non-NArray values go through TensorData.value_with_shape, which fills an NArray of +shape+
    #   when a non-empty shape is given.
    # @param dtype [Symbol, nil] element type; inferred by TensorData when nil
    # @param shape [Array<Integer>] requested shape (an NArray's own shape takes precedence)
    def initialize(value, dtype: nil, shape: [])
      value = case value
              when Numo::NArray
                value
              when Array
                # We convert all arrays to narrays. This makes it a lot easier to support multidimensional arrays
                Numo::NArray.cast(value)
              else
                TensorData.value_with_shape(value, dtype, shape)
              end

      tensor_data = TensorData.new(value, dtype: dtype, shape: shape)
      dtype = tensor_data.dtype
      shape = tensor_data.shape

      if shape && shape.size > 0
        dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
        dims_ptr.write_array_of_int64(shape)
      else
        # Scalars have no dimensions
        dims_ptr = nil
      end

      # The TensorData buffer is handed to TensorFlow, which releases it via TensorData::Deallocator.
      @pointer = FFI.TF_NewTensor(dtype,
                                  dims_ptr, shape ? shape.size : 0,
                                  tensor_data, tensor_data.byte_size,
                                  TensorData::Deallocator, nil)

      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # Ruby value of the tensor's contents (see TensorData#value).
    def value
      self.data.value
    end

    # Element type reported by TF_TensorType.
    def dtype
      FFI.TF_TensorType(self)
    end

    def to_s
      inspect
    end

    # Raw TF_Tensor pointer; lets FFI accept a Tensor wherever a pointer is expected.
    def to_ptr
      @pointer
    end

    # Size of the tensor's buffer in bytes (TF_TensorByteSize).
    def byte_size
      FFI.TF_TensorByteSize(self)
    end

    # NOTE(review): #numo and #shape are not defined in this class; presumably they come from TensorMixin.
    def inspect
      inspection = %w(numo shape dtype).map { |v| "#{v}: #{send(v).inspect}" }
      "#<#{self.class} #{inspection.join(", ")}>"
    end

    # TensorData view over the TF_Tensor's buffer (wraps the pointer from TF_TensorData; no copy).
    def data
      TensorData.from_pointer(FFI.TF_TensorData(self), self.byte_size, self.dtype, self.shape)
    end

    private

    def num_dims
      FFI.TF_NumDims(self)
    end

    def dim(index)
      FFI.TF_Dim(self, index)
    end

    def element_count
      FFI.TF_TensorElementCount(self)
    end

    # Shape of +value+: its own #shape when it has one, otherwise the lengths found by descending
    # through the first element of each nested Array.
    def calculate_shape(value)
      return value.shape if value.respond_to?(:shape)

      shape = []
      d = value
      while d.is_a?(Array)
        shape << d.size
        d = d.first
      end
      shape
    end
  end
end
|
@@ -0,0 +1,310 @@
|
|
1
|
+
require 'rbconfig'
|
2
|
+
module Tensorflow
  # TensorFlow expects client libraries to allocate the memory for the data that a tensor wraps. When a tensor is
  # released, TensorFlow notifies the client via a callback that gives the client a chance to release the memory.
  #
  # We don't want to use an FFI::MemoryPointer because those are garbage collected. If the underlying data is freed
  # before the tensor is released you get a use-after-free crash (this can happen even if a Ruby tensor object keeps
  # a reference to the pointer at GC time).
  #
  # Thus this class allocates its own memory with ruby_xmalloc and frees it only after being called back by
  # TensorFlow (see Deallocator).
  class TensorData
    extend ::FFI::Library
    ffi_lib "#{RbConfig::CONFIG['RUBY_SO_NAME']}.#{RbConfig::CONFIG['SOEXT']}"
    attach_function :ruby_xmalloc, [:size_t], :pointer
    attach_function :ruby_xfree, [:pointer], :void

    attr_reader :pointer, :byte_size, :dtype, :shape

    extend Forwardable
    # Forward pointer arithmetic and every FFI::Pointer#write_* helper to the underlying pointer.
    # NOTE: read_* helpers are NOT forwarded; read through #pointer.
    def_delegators :to_ptr, :+, *::FFI::Pointer.instance_methods.grep(/^write_/)

    # Store the callback in a class constant (so it won't be garbage collected) that TensorFlow will trigger
    # when the memory should be freed.
    Deallocator = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |data, _len, _arg|
      ruby_xfree(data)
    end

    DTYPE_TO_NUMO_TYPE_MAP = {bool: Numo::Bit,
                              complex64: Numo::SComplex,
                              complex128: Numo::DComplex,
                              double: Numo::DFloat,
                              float: Numo::SFloat,
                              int8: Numo::Int8,
                              int16: Numo::Int16,
                              int32: Numo::Int32,
                              int64: Numo::Int64,
                              uint8: Numo::UInt8,
                              uint16: Numo::UInt16,
                              uint32: Numo::UInt32,
                              uint64: Numo::UInt64}

    NUMO_TYPE_TO_DTYPE_MAP = DTYPE_TO_NUMO_TYPE_MAP.invert

    # Largest finite float32 magnitude (FLT_MAX). Ruby numbers whose magnitude is below this bound default to
    # :float / :complex64; anything larger (or NaN/Infinity) defaults to :double / :complex128.
    FLOAT32_MAX = 3.402823466e38

    # Infers a TensorFlow dtype symbol for +value+.
    #
    # @raise [Error::InvalidArgumentError] for unsupported Ruby types
    def self.figure_dtype(value)
      case value
      when Numo::RObject
        # Need to look at the first element to see what it is
        self.figure_dtype(value[0])
      when Numo::NArray
        NUMO_TYPE_TO_DTYPE_MAP[value.class]
      when Array
        self.figure_dtype(value.first)
      when Integer
        (value >= -2147483648 && value <= 2147483647) ? :int32 : :int64
      when Complex
        # Both parts must be representable as float32 for complex64.
        (fits_float32?(value.real) && fits_float32?(value.imaginary)) ? :complex64 : :complex128
      when Numeric
        # The range check is symmetric; the old lower bound (-1.175494351e38) was a garbled FLT_MIN.
        fits_float32?(value) ? :float : :double
      when String
        :string
      when TrueClass, FalseClass
        :bool
      when ::FFI::Pointer
        :pointer
      when Tensor
        value.dtype
      when Variable
        value.dtype
      when Graph::Operation
        nil
      when Eager::TensorHandle
        value.dtype
      when Google::Protobuf::MessageExts
        :string
      else
        raise(Error::InvalidArgumentError, "Unsupported type: #{value.class}")
      end
    end

    # True when +number+'s magnitude is within float32 range (false for NaN and Infinity).
    def self.fits_float32?(number)
      number.abs < FLOAT32_MAX
    end
    private_class_method :fits_float32?

    # Size in bytes of one element of +dtype+.
    def self.type_size(dtype)
      case dtype
      when :complex64
        ::FFI.type_size(:float) * 2
      when :complex128
        ::FFI.type_size(:double) * 2
      when :bfloat16
        # FFI has no bfloat16 type; it is a 16-bit value
        2
      else
        ::FFI.type_size(dtype)
      end
    end

    # Expands a scalar into a Numo array of +shape+ filled with +value+ when a non-empty shape is
    # given; otherwise returns +value+ unchanged.
    def self.value_with_shape(value, dtype, shape)
      if shape && shape.size > 0
        dtype ||= self.figure_dtype(value)
        numo_klass = DTYPE_TO_NUMO_TYPE_MAP[dtype]
        numo_klass.new(shape).fill(value)
      else
        value
      end
    end

    # Wraps memory owned by someone else (e.g. a TF_Tensor's buffer) without copying or taking ownership.
    def self.from_pointer(pointer, byte_size, dtype, shape)
      result = self.allocate
      result.instance_variable_set(:@pointer, pointer)
      result.instance_variable_set(:@byte_size, byte_size)
      result.instance_variable_set(:@dtype, dtype)
      result.instance_variable_set(:@shape, shape)
      result
    end

    # Allocates a buffer with ruby_xmalloc and serializes +value+ into it.
    #
    # @param value [Numo::NArray, Google::Protobuf::MessageExts, Object] Ruby Arrays are rejected; protobuf
    #   messages are stored as a single encoded string; anything else is written as a scalar
    # @raise [Error::InvalidArgumentError] when +value+ is a Ruby Array
    def initialize(value, dtype: nil, shape: [])
      @dtype = dtype || self.class.figure_dtype(value)
      @shape = shape
      case value
      when Numo::NArray
        self.write_narray(value)
      when Array
        raise(Error::InvalidArgumentError, "TensorData does not support Arrays. Please use a Numo::NArray")
      when Google::Protobuf::MessageExts
        encoded = value.class.encode(value)
        self.write_array_of_string([encoded])
      else
        self.write_scalar(value)
      end
    end

    def to_ptr
      @pointer
    end

    # Reads +count+ complex64 values (pairs of float32).
    def read_array_of_complex64(count)
      # read_* helpers are not delegated, so go through the pointer explicitly
      values = self.pointer.read_array_of_float(2 * count)
      values.each_slice(2).map do |real, imaginary|
        Complex(real, imaginary)
      end
    end

    # Reads +count+ complex128 values (pairs of float64).
    def read_array_of_complex128(count)
      values = self.pointer.read_array_of_double(2 * count)
      values.each_slice(2).map do |real, imaginary|
        Complex(real, imaginary)
      end
    end

    # Decodes +count+ strings from the buffer: a table of uint64 offsets followed by a data section of
    # TF_StringEncode-d strings. Offsets are relative to the start of the data section.
    def read_array_of_string(count)
      # The start of the data section comes after the offset table
      start_offset_size = count * ::FFI.type_size(:int64)
      data_size = self.byte_size - start_offset_size

      # Read in the string offsets
      offsets = self.pointer.read_array_of_uint64(count)

      offsets.map.with_index do |offset, index|
        # A string ends where the next one starts; the last ends with the data section (not the whole buffer,
        # which would overshoot by the size of the offset table).
        src_bytes = (offsets[index + 1] || data_size) - offset
        dst_ptr = ::FFI::MemoryPointer.new(:pointer)
        dst_len_ptr = ::FFI::MemoryPointer.new(:size_t)
        Status.check do |status|
          FFI.TF_StringDecode(self.pointer + start_offset_size + offset, src_bytes, dst_ptr, dst_len_ptr, status)
        end
        string_pointer = dst_ptr.read_pointer
        string_length = dst_len_ptr.read(:size_t)
        string_pointer.read_string(string_length)
      end
    end

    # Converts the buffer back into Ruby.
    #
    # @return [Object] the single element when +shape+ is empty; otherwise a Numo::NArray of +shape+
    #   (string tensors return a flat Array of Strings)
    # @raise [RuntimeError] for dtypes with no Numo counterpart (e.g. :resource, :variant)
    def value
      result = case self.dtype
               when :bfloat16
                 read_bfloat16
               when :string
                 self.read_array_of_string(self.shape.reduce(1, :*))
               when :bool
                 # Booleans are stored one byte each
                 read_numo(Numo::Int8).cast_to(Numo::Bit)
               else
                 numo_klass = DTYPE_TO_NUMO_TYPE_MAP[self.dtype]
                 raise "Unsupported tensor data type: #{self.dtype}" unless numo_klass
                 read_numo(numo_klass)
               end

      self.shape.empty? ? result[0] : result
    end

    # Allocates and fills the buffer with +strings+ in TensorFlow's string tensor layout: one uint64 offset per
    # string followed by the TF_StringEncode-d strings.
    def write_array_of_string(strings)
      # The start of the data section comes after the offset table
      start_offset_size = strings.size * ::FFI.type_size(:int64)

      # Get the encoded sizes for each string
      encoded_sizes = strings.map do |string|
        FFI.TF_StringEncodedSize(string.bytesize)
      end

      # Offsets are relative to the beginning of the data section, not the beginning of the pointer. Emit exactly
      # one offset per string, so an empty list writes no offsets into its (empty) allocation.
      data_size = 0
      offsets = encoded_sizes.map do |encoded_size|
        offset = data_size
        data_size += encoded_size
        offset
      end

      # Allocate the needed memory
      @byte_size = start_offset_size + data_size
      @pointer = self.class.ruby_xmalloc(@byte_size)

      # Write the offsets
      self.pointer.write_array_of_uint64(offsets)

      # Write the strings
      strings.each_with_index do |string, index|
        offset = offsets[index]
        size = encoded_sizes[index]
        Status.check do |status|
          FFI.TF_StringEncode(string, string.bytesize, self.pointer + start_offset_size + offset, size, status)
        end
      end
    end

    # Copies +new_value+ into a freshly allocated buffer, converting it to this tensor's dtype.
    def write_narray(new_value)
      @shape = new_value.shape

      case self.dtype
      when :string
        self.write_array_of_string(new_value.flatten.to_a)
      when :bfloat16
        bits = new_value.flatten.to_a.map { |number| encode_bfloat16(number) }
        @byte_size = bits.size * self.class.type_size(:bfloat16)
        @pointer = self.class.ruby_xmalloc(@byte_size)
        self.pointer.write_array_of_uint16(bits)
      when :bool
        copy_narray(new_value.cast_to(Numo::Int8))
      else
        if NUMO_TYPE_TO_DTYPE_MAP[new_value.class] != dtype
          # Cast the narray if necessary (say the user passed in float but we have a double array)
          new_value = new_value.cast_to(DTYPE_TO_NUMO_TYPE_MAP[dtype])
        end
        copy_narray(new_value)
      end
    end

    # Allocates a one-element buffer and writes +new_value+ into it as this tensor's dtype.
    # :resource and :variant values allocate nothing.
    #
    # @raise [RuntimeError] for unsupported dtypes
    def write_scalar(new_value)
      case self.dtype
      when :string
        self.write_array_of_string([new_value])
      when :resource, :variant
        return self
      else
        @byte_size = self.class.type_size(self.dtype)
        @pointer = self.class.ruby_xmalloc(@byte_size)
        case self.dtype
        when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
          self.pointer.send("write_#{self.dtype}", new_value)
        when :bfloat16
          self.pointer.write_uint16(encode_bfloat16(new_value))
        when :complex64
          self.pointer.write_array_of_float([new_value.real, new_value.imaginary])
        when :complex128
          self.pointer.write_array_of_double([new_value.real, new_value.imaginary])
        when :bool
          self.pointer.write_int8(new_value ? 1 : 0)
        else
          raise "Unsupported tensor data type: #{self.dtype}"
        end
      end
    end

    private

    # Reads the whole buffer as a +numo_klass+ array (1 element when +shape+ is empty).
    def read_numo(numo_klass)
      bytes = self.pointer.read_bytes(self.byte_size)
      self.shape.empty? ? numo_klass.from_binary(bytes) : numo_klass.from_binary(bytes, self.shape)
    end

    # Decodes a bfloat16 buffer. Each element is a native-endian uint16 holding the upper 16 bits of an
    # IEEE-754 float32, so shifting it back up and reinterpreting the 32 bits yields the float.
    def read_bfloat16
      count = self.byte_size / self.class.type_size(:bfloat16)
      floats = self.pointer.read_array_of_uint16(count).map { |bits| [bits << 16].pack('L').unpack1('f') }
      array = Numo::SFloat.cast(floats)
      self.shape.empty? ? array : array.reshape(*self.shape)
    end

    # Converts +number+ to bfloat16 bits: the upper half of its float32 encoding, rounded to nearest-even.
    def encode_bfloat16(number)
      float = number.to_f
      # Canonical quiet NaN; rounding could otherwise carry a NaN payload into infinity
      return 0x7FC0 if float.nan?

      bits = [float].pack('f').unpack1('L')
      ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF
    end

    # Allocates a buffer the size of +narray+ and copies its raw bytes into it.
    def copy_narray(narray)
      @byte_size = narray.byte_size
      @pointer = self.class.ruby_xmalloc(@byte_size)
      self.pointer.write_bytes(narray.to_binary)
    end
  end
end
|