tensorflow-ruby 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
@@ -0,0 +1,90 @@
|
|
1
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
2
|
+
# source: stream_executor/dnn.proto
|
3
|
+
|
4
|
+
require 'google/protobuf'
|
5
|
+
|
6
|
+
# Registers the messages and enums of stream_executor/dnn.proto (proto3) in the generated descriptor pool.
# This is protoc output (see the "DO NOT EDIT" header of this file): regenerate it from the .proto instead of
# editing by hand.
Google::Protobuf::DescriptorPool.generated_pool.build do
  add_file("stream_executor/dnn.proto", :syntax => :proto3) do
    # Describes a tensor: its dimensions, element type, and (via the oneof) either a data layout or a filter
    # layout — at most one of the two is set.
    add_message "stream_executor.dnn.TensorDescriptorProto" do
      repeated :dimensions, :int64, 1
      optional :data_type, :enum, 2, "stream_executor.dnn.DataType"
      oneof :layout_oneof do
        optional :data_layout, :enum, 3, "stream_executor.dnn.DataLayout"
        optional :filter_layout, :enum, 4, "stream_executor.dnn.FilterLayout"
      end
    end
    # Identifies an algorithm by id plus the math type it uses.
    add_message "stream_executor.dnn.AlgorithmProto" do
      optional :algo_id, :int64, 1
      optional :math_type, :enum, 2, "stream_executor.dnn.AlgorithmProto.MathType"
    end
    # Enum nested inside AlgorithmProto.
    add_enum "stream_executor.dnn.AlgorithmProto.MathType" do
      value :DEFAULT_MATH, 0
      value :TENSOR_OP_MATH, 1
    end
    # Convolution parameters; note compute_mode reuses the DataType enum.
    add_message "stream_executor.dnn.ConvolutionDescriptorProto" do
      repeated :paddings, :int64, 1
      repeated :strides, :int64, 2
      repeated :dilations, :int64, 3
      optional :compute_mode, :enum, 4, "stream_executor.dnn.DataType"
      optional :group_count, :int32, 5
      optional :convolution_mode, :enum, 6, "stream_executor.dnn.ConvolutionMode"
      optional :name, :string, 7
    end
    add_enum "stream_executor.dnn.DataType" do
      value :kFloat, 0
      value :kDouble, 1
      value :kHalf, 2
      value :kInt8, 3
      value :kInt32, 4
    end
    add_enum "stream_executor.dnn.DataLayout" do
      value :kYXDepthBatch, 0
      value :kYXBatchDepth, 1
      value :kBatchYXDepth, 2
      value :kBatchDepthYX, 3
      value :kBatchDepthYX4, 4
    end
    add_enum "stream_executor.dnn.FilterLayout" do
      value :kOutputInputYX, 0
      value :kOutputYXInput, 1
      value :kOutputInputYX4, 2
      value :kInputYXOutput, 3
      value :kYXInputOutput, 4
    end
    add_enum "stream_executor.dnn.ActivationMode" do
      value :kNone, 0
      value :kSigmoid, 1
      value :kRelu, 2
      value :kRelu6, 3
      value :kReluX, 4
      value :kTanh, 5
      value :kBandPass, 6
    end
    add_enum "stream_executor.dnn.ConvolutionMode" do
      value :CROSS_CORRELATION, 0
      value :CONVOLUTION, 1
    end
    add_enum "stream_executor.dnn.ConvolutionKind" do
      value :INVALID, 0
      value :FORWARD, 1
      value :BACKWARD_FILTER, 2
      value :BACKWARD_DATA, 3
      value :FORWARD_BIAS_ACTIVATION, 4
    end
  end
end
|
76
|
+
|
77
|
+
# Ruby constants for the stream_executor.dnn messages and enums, looked up by full proto name in the generated
# descriptor pool (generated code — regenerate rather than edit).
module StreamExecutor
  module Dnn
    # Message classes (.msgclass) and enum modules (.enummodule).
    TensorDescriptorProto = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.TensorDescriptorProto").msgclass
    AlgorithmProto = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.AlgorithmProto").msgclass
    # Nested enum: must be assigned after AlgorithmProto exists.
    AlgorithmProto::MathType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.AlgorithmProto.MathType").enummodule
    ConvolutionDescriptorProto = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.ConvolutionDescriptorProto").msgclass
    DataType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.DataType").enummodule
    DataLayout = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.DataLayout").enummodule
    FilterLayout = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.FilterLayout").enummodule
    ActivationMode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.ActivationMode").enummodule
    ConvolutionMode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.ConvolutionMode").enummodule
    ConvolutionKind = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("stream_executor.dnn.ConvolutionKind").enummodule
  end
end
|
@@ -0,0 +1,100 @@
|
|
1
|
+
module Tensorflow
  # String operations. Every method is a thin wrapper that forwards its arguments to the matching RawOps
  # operation.
  #
  # TODO: not implemented yet — bytes_split, format, ngrams, to_hash_bucket, to_hash_bucket_fast,
  # to_hash_bucket_strong, unicode_split, unicode_split_with_offsets, unsorted_segment_join.
  module Strings
    class << self
      # Converts +input+ to strings (RawOps.as_string). Formatting options are passed through unchanged.
      def as_string(input, precision: nil, scientific: nil, shortest: nil, width: nil, fill: nil)
        RawOps.as_string(input, precision: precision, scientific: scientific, shortest: shortest, width: width, fill: fill)
      end

      # Joins +inputs+ with +separator+ (RawOps.string_join).
      #
      # @param inputs [#size] the values to join; their count is passed to the op as +n+
      # @param separator [String]
      def join(inputs, separator: "")
        RawOps.string_join(inputs, separator: separator, n: inputs.size)
      end

      # String lengths of +input+ (RawOps.string_length), measured in +unit+.
      def length(input, unit: "BYTE")
        RawOps.string_length(input, unit: unit)
      end

      # Lower-cases +input+ (RawOps.string_lower).
      #
      # A single Ruby String is re-encoded as UTF-8 first. Any other input (arrays, tensors) is passed through
      # unchanged — previously it was unconditionally sent +encode+, which those objects do not implement.
      #
      # @param input [String, Array, Object]
      # @param encoding [String] encoding forwarded to the op; defaults to "utf-8" as before
      def lower(input, encoding: "utf-8")
        input = input.encode("UTF-8") if input.is_a?(String)
        RawOps.string_lower(input, encoding: encoding)
      end

      # Joins strings across +reduction_indices+ (RawOps.reduce_join).
      def reduce_join(inputs, reduction_indices, keep_dims: nil, separator: nil)
        RawOps.reduce_join(inputs, reduction_indices: reduction_indices, keep_dims: keep_dims, separator: separator)
      end

      # Tests whether each entry of +input+ fully matches +pattern+ (RawOps.regex_full_match).
      def regex_full_match(input, pattern)
        RawOps.regex_full_match(input, pattern: pattern)
      end

      # Replaces matches of +pattern+ with +rewrite+ (RawOps.regex_replace).
      def regex_replace(input, pattern, rewrite, replace_global: nil)
        RawOps.regex_replace(input, pattern: pattern, rewrite: rewrite, replace_global: replace_global)
      end

      # Forwards to RawOps.split with +split_dim+ / +value+ / +num_split+.
      #
      # NOTE(review): this wraps the generic tensor split op (split +value+ along +split_dim+), not a
      # string-splitting op — confirm this is the intended behavior for Strings.split.
      def split(split_dim, value, num_split: nil)
        RawOps.split(split_dim: split_dim, value: value, num_split: num_split)
      end

      # Strips leading and trailing whitespace (RawOps.string_strip).
      def strip(input)
        RawOps.string_strip(input)
      end

      # Substrings of +input+ starting at +pos+ with length +len+ (RawOps.substr).
      def substr(input, pos, len, unit: nil)
        RawOps.substr(input, pos, len, unit: unit)
      end

      # Parses +input+ as numbers of +out_type+ (RawOps.string_to_number).
      def to_number(input, out_type: :float)
        RawOps.string_to_number(input, out_type: out_type)
      end

      # Decodes +input+ into code points (RawOps.unicode_decode).
      def unicode_decode(input, input_encoding: nil, errors: nil, replacement_char: nil, replace_control_characters: nil)
        RawOps.unicode_decode(input, input_encoding: input_encoding, errors: errors, replacement_char: replacement_char, replace_control_characters: replace_control_characters)
      end

      # Like #unicode_decode, but also returns offsets (RawOps.unicode_decode_with_offsets).
      def unicode_decode_with_offsets(input, input_encoding: nil, errors: nil, replacement_char: nil, replace_control_characters: nil)
        RawOps.unicode_decode_with_offsets(input, input_encoding: input_encoding, errors: errors, replacement_char: replacement_char, replace_control_characters: replace_control_characters)
      end

      # Encodes code points back into strings (RawOps.unicode_encode).
      def unicode_encode(input_values, input_splits, errors: nil, output_encoding: nil, replacement_char: nil)
        RawOps.unicode_encode(input_values: input_values, input_splits: input_splits, errors: errors, output_encoding: output_encoding, replacement_char: replacement_char)
      end

      # Unicode script codes of +input+ (RawOps.unicode_script).
      def unicode_script(input)
        RawOps.unicode_script(input)
      end

      # Transcodes +input+ between encodings (RawOps.unicode_transcode).
      def unicode_transcode(input, input_encoding: nil, output_encoding: nil, errors: nil, replacement_char: nil, replace_control_characters: nil)
        RawOps.unicode_transcode(input, input_encoding: input_encoding, output_encoding: output_encoding, errors: errors, replacement_char: replacement_char, replace_control_characters: replace_control_characters)
      end

      # Upper-cases +input+ (RawOps.string_upper).
      def upper(input, encoding: "")
        RawOps.string_upper(input, encoding: encoding)
      end
    end
  end
end
|
@@ -0,0 +1,13 @@
|
|
1
|
+
module Tensorflow
  # Entry points for summary writing.
  class Summary
    class << self
      # Builds a ResourceSummaryWriter whose shared name is +name+. The block handed to the writer issues
      # RawOps.create_summary_file_writer for +logdir+ with the given queue/flush/suffix settings.
      #
      # @param logdir [String] directory passed to the op
      # @param max_queue [Integer] forwarded to the op
      # @param flush_millis [Integer] forwarded to the op
      # @param filename_suffix [String] forwarded to the op
      # @param name [String, nil] shared name of the writer resource
      # @return [ResourceSummaryWriter]
      def create_file_writer(logdir, max_queue: 10, flush_millis: 120_000, filename_suffix: '.v2', name: nil)
        ResourceSummaryWriter.new(shared_name: name) do |resource|
          RawOps.create_summary_file_writer(resource, logdir, max_queue, flush_millis, filename_suffix)
        end
      end

      # Returns the current execution context's collection stored under GraphKeys::SUMMARY_COLLECTION.
      def all_v2_summary_ops
        context = ExecutionContext.current
        context.get_collection_ref(Graph::GraphKeys::SUMMARY_COLLECTION)
      end
    end
  end
end
|
@@ -0,0 +1,133 @@
|
|
1
|
+
module Tensorflow
  # Ruby wrapper around a TF_Tensor created through the TensorFlow C API.
  #
  # The tensor's host data is allocated by TensorData and handed to TF_NewTensor together with
  # TensorData::Deallocator; the TF_Tensor itself is deleted by the finalizer built in .finalize.
  class Tensor
    include Operators
    include TensorMixin

    # Builds the finalizer proc that deletes the TF_Tensor behind +pointer+.
    #
    # This is a class method so the proc closes over +pointer+ only and not over the Tensor instance;
    # a finalizer that references its own object would keep that object alive forever.
    def self.finalize(pointer)
      proc do
        FFI.TF_DeleteTensor(pointer)
      end
    end

    # Coerces +value+ into something usable as a tensor.
    #
    # Tensors and Graph::Operations are returned as-is, TensorHandles yield their tensor, Datasets yield their
    # variant tensor; anything else is wrapped in a new Tensor with the optional +dtype+.
    def self.from_value(value, dtype: nil)
      case value
      when Tensor
        value
      when Graph::Operation
        value
      when Eager::TensorHandle
        value.tensor
      when Data::Dataset
        value.variant_tensor
      else
        Tensor.new(value, dtype: dtype)
      end
    end

    # Builds a Tensor from a TensorProto or from its serialized bytes.
    #
    # NOTE(review): only +tensor_content+ is decoded. Protos that carry their values in the typed repeated
    # fields instead, or dtypes without a Numo mapping (e.g. strings), are not handled here — confirm callers
    # never produce such protos.
    def self.from_proto(proto)
      proto = TensorProto.decode(proto) unless proto.is_a?(TensorProto)
      shape = proto.tensor_shape.dim.map(&:size)
      dtype = FFI::DataType[DataType.resolve(proto.dtype)]
      numo_klass = TensorData::DTYPE_TO_NUMO_TYPE_MAP[dtype]

      value = if shape.empty?
                # A scalar: decode as a 1-element array and unwrap it
                numo_klass.from_binary(proto.tensor_content)[0]
              else
                numo_klass.from_binary(proto.tensor_content, shape)
              end

      new(value, dtype: dtype, shape: shape)
    end

    # Wraps an existing TF_Tensor pointer without copying, and takes ownership of it (the pointer is deleted
    # when the returned object is finalized).
    def self.from_pointer(pointer)
      result = allocate
      result.instance_variable_set(:@pointer, pointer)
      ObjectSpace.define_finalizer(result, finalize(pointer))
      result
    end

    # Creates a new TF_Tensor holding +value+.
    #
    # @param value [Numo::NArray, Array, Object] arrays are converted to Numo::NArray; scalars are broadcast to
    #   +shape+ by TensorData.value_with_shape when +shape+ is non-empty
    # @param dtype [Symbol, nil] element type; inferred by TensorData when nil
    # @param shape [Array<Integer>] requested shape for scalar input
    def initialize(value, dtype: nil, shape: [])
      value = case value
              when Numo::NArray
                value
              when Array
                # We convert all arrays to narrays. This makes it a lot easier to support multidimensional arrays
                Numo::NArray.cast(value)
              else
                TensorData.value_with_shape(value, dtype, shape)
              end

      tensor_data = TensorData.new(value, dtype: dtype, shape: shape)
      dtype = tensor_data.dtype
      shape = tensor_data.shape

      # Scalars are passed to TF_NewTensor with no dims pointer and zero dims
      dims_ptr = nil
      if shape && !shape.empty?
        dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
        dims_ptr.write_array_of_int64(shape)
      end

      # TensorFlow takes ownership of tensor_data's buffer and releases it through TensorData::Deallocator
      @pointer = FFI.TF_NewTensor(dtype,
                                  dims_ptr, shape ? shape.size : 0,
                                  tensor_data, tensor_data.byte_size,
                                  TensorData::Deallocator, nil)

      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # The tensor contents converted back to Ruby (see TensorData#value).
    def value
      data.value
    end

    # Element type reported by TF_TensorType.
    def dtype
      FFI.TF_TensorType(self)
    end

    def to_s
      inspect
    end

    # The underlying TF_Tensor pointer (lets this object be passed directly to FFI calls).
    def to_ptr
      @pointer
    end

    # Size in bytes of the tensor's data buffer (TF_TensorByteSize).
    def byte_size
      FFI.TF_TensorByteSize(self)
    end

    def inspect
      inspection = %w(numo shape dtype).map { |attribute| "#{attribute}: #{send(attribute).inspect}" }
      "#<#{self.class} #{inspection.join(", ")}>"
    end

    # A TensorData view over this tensor's buffer (no copy is made here).
    def data
      TensorData.from_pointer(FFI.TF_TensorData(self), byte_size, dtype, shape)
    end

    private

    def num_dims
      FFI.TF_NumDims(self)
    end

    def dim(index)
      FFI.TF_Dim(self, index)
    end

    def element_count
      FFI.TF_TensorElementCount(self)
    end

    # Shape of +value+: its own #shape when available, otherwise the sizes found by following the first
    # element of each nested Array.
    def calculate_shape(value)
      return value.shape if value.respond_to?(:shape)

      shape = []
      current = value
      while current.is_a?(Array)
        shape << current.size
        current = current.first
      end
      shape
    end
  end
end
|
@@ -0,0 +1,310 @@
|
|
1
|
+
require 'rbconfig'
|
2
|
+
module Tensorflow
  # TensorFlow expects client libraries to allocate the memory that a tensor wraps. When a tensor is released,
  # TensorFlow notifies the client via a callback, giving the client a chance to free that memory.
  #
  # We don't use FFI::MemoryPointer because it is garbage collected. If the underlying data is freed before
  # TensorFlow releases the tensor, the result is a use-after-free crash (this can happen even if a Ruby tensor
  # object still references the pointer at GC time).
  #
  # Thus this class allocates its own memory with ruby_xmalloc and frees it only after TensorFlow calls back
  # through Deallocator.
  class TensorData
    extend ::FFI::Library
    ffi_lib "#{RbConfig::CONFIG['RUBY_SO_NAME']}.#{RbConfig::CONFIG['SOEXT']}"
    attach_function :ruby_xmalloc, [:size_t], :pointer
    attach_function :ruby_xfree, [:pointer], :void

    attr_reader :pointer, :byte_size, :dtype, :shape

    # Forward pointer arithmetic and every FFI::Pointer#write_* method to the underlying pointer. Methods of the
    # same name defined further down in this class (e.g. write_array_of_string) replace the delegators.
    extend Forwardable
    def_delegators :to_ptr, :+, *::FFI::Pointer.instance_methods.grep(/^write_/)

    # Callback TensorFlow triggers when a tensor's memory should be freed. Stored in a class constant so it won't
    # be garbage collected while TensorFlow may still call it.
    Deallocator = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |data, _len, _arg|
      ruby_xfree(data)
    end

    DTYPE_TO_NUMO_TYPE_MAP = { bool: Numo::Bit,
                               complex64: Numo::SComplex,
                               complex128: Numo::DComplex,
                               double: Numo::DFloat,
                               float: Numo::SFloat,
                               int8: Numo::Int8,
                               int16: Numo::Int16,
                               int32: Numo::Int32,
                               int64: Numo::Int64,
                               uint8: Numo::UInt8,
                               uint16: Numo::UInt16,
                               uint32: Numo::UInt32,
                               uint64: Numo::UInt64 }

    # Reverse lookup (Numo class => dtype). The values above are unique, so inverting loses nothing.
    NUMO_TYPE_TO_DTYPE_MAP = DTYPE_TO_NUMO_TYPE_MAP.invert

    # Largest finite IEEE-754 single-precision value. Floats strictly inside (-FLOAT32_MAX, FLOAT32_MAX) are
    # stored as float32, others as float64.
    FLOAT32_MAX = 3.402823466e38

    # Signed 32-bit integer range. Integers inside it are stored as int32, others as int64.
    INT32_RANGE = (-2_147_483_648..2_147_483_647).freeze

    # Infers the dtype symbol to use for a Ruby value.
    #
    # @param value [Object] a scalar, Array, Numo::NArray, Tensor, Variable, Eager::TensorHandle,
    #   FFI::Pointer or protobuf message
    # @return [Symbol, nil] nil for a Graph::Operation
    # @raise [Error::InvalidArgumentError] for any other type
    def self.figure_dtype(value)
      case value
      when Numo::RObject
        # Need to look at the first element to see what it is
        figure_dtype(value[0])
      when Numo::NArray
        NUMO_TYPE_TO_DTYPE_MAP[value.class]
      when Array
        figure_dtype(value.first)
      when Integer
        INT32_RANGE.cover?(value) ? :int32 : :int64
      when Complex
        # complex64 only when both the real and the imaginary part fit in float32
        value.rectangular.all? { |part| float32_range?(part) } ? :complex64 : :complex128
      when Numeric
        float32_range?(value) ? :float : :double
      when String
        :string
      when TrueClass, FalseClass
        :bool
      when ::FFI::Pointer
        :pointer
      when Tensor, Variable
        value.dtype
      when Graph::Operation
        nil
      when Eager::TensorHandle
        value.dtype
      when Google::Protobuf::MessageExts
        :string
      else
        raise(Error::InvalidArgumentError, "Unsupported type: #{value.class}")
      end
    end

    # True when +number+ lies strictly inside the finite float32 range (NaN and infinities are not).
    def self.float32_range?(number)
      number > -FLOAT32_MAX && number < FLOAT32_MAX
    end
    private_class_method :float32_range?

    # Size in bytes of one element of +dtype+.
    def self.type_size(dtype)
      case dtype
      when :complex64
        ::FFI.type_size(:float) * 2
      when :complex128
        ::FFI.type_size(:double) * 2
      when :bfloat16
        # 16-bit format with no corresponding FFI type
        2
      else
        ::FFI.type_size(dtype)
      end
    end

    # Broadcasts scalar +value+ into a Numo array of +shape+ when +shape+ is non-empty; otherwise returns
    # +value+ unchanged.
    def self.value_with_shape(value, dtype, shape)
      return value unless shape && shape.size > 0

      dtype ||= figure_dtype(value)
      numo_klass = DTYPE_TO_NUMO_TYPE_MAP[dtype]
      numo_klass.new(shape).fill(value)
    end

    # Wraps an existing buffer (for example one returned by TF_TensorData) without copying or taking ownership.
    def self.from_pointer(pointer, byte_size, dtype, shape)
      result = allocate
      result.instance_variable_set(:@pointer, pointer)
      result.instance_variable_set(:@byte_size, byte_size)
      result.instance_variable_set(:@dtype, dtype)
      result.instance_variable_set(:@shape, shape)
      result
    end

    # Allocates a buffer and serializes +value+ into it.
    #
    # @param value [Numo::NArray, Google::Protobuf::MessageExts, Object] plain Arrays are rejected
    # @param dtype [Symbol, nil] inferred with .figure_dtype when nil
    # @param shape [Array<Integer>] replaced by the narray's own shape for Numo input
    # @raise [Error::InvalidArgumentError] when given an Array
    def initialize(value, dtype: nil, shape: [])
      @dtype = dtype || self.class.figure_dtype(value)
      @shape = shape
      case value
      when Numo::NArray
        write_narray(value)
      when Array
        raise(Error::InvalidArgumentError, "TensorData does not support Arrays. Please use a Numo::NArray")
      when Google::Protobuf::MessageExts
        encoded = value.class.encode(value)
        write_array_of_string([encoded])
      else
        write_scalar(value)
      end
    end

    def to_ptr
      @pointer
    end

    # Reads +count+ complex64 values (pairs of float32).
    def read_array_of_complex64(count)
      pointer.read_array_of_float(2 * count).each_slice(2).map { |real, imaginary| Complex(real, imaginary) }
    end

    # Reads +count+ complex128 values (pairs of float64).
    def read_array_of_complex128(count)
      pointer.read_array_of_double(2 * count).each_slice(2).map { |real, imaginary| Complex(real, imaginary) }
    end

    # Decodes +count+ strings from the buffer: a table of uint64 offsets followed by the TF_StringEncode'd data.
    def read_array_of_string(count)
      # The start of the data section comes after the offset table
      start_offset_size = count * ::FFI.type_size(:int64)
      data_size = byte_size - start_offset_size

      # Read in the string offsets (relative to the start of the data section)
      offsets = pointer.read_array_of_uint64(count)

      offsets.map.with_index do |offset, index|
        # The last string runs to the end of the data section, which excludes the offset table
        src_bytes = (offsets[index + 1] || data_size) - offset
        dst_ptr = ::FFI::MemoryPointer.new(:pointer)
        dst_len_ptr = ::FFI::MemoryPointer.new(:size_t)
        Status.check do |status|
          FFI.TF_StringDecode(pointer + start_offset_size + offset, src_bytes, dst_ptr, dst_len_ptr, status)
        end
        string_pointer = dst_ptr.read_pointer
        string_length = dst_len_ptr.read(:size_t)
        string_pointer.read_string(string_length)
      end
    end

    # Converts the buffer back to Ruby: a single element for scalars (empty shape), otherwise an Array (bfloat16,
    # string) or a Numo array shaped like the tensor.
    #
    # @raise [RuntimeError] for dtypes with no Ruby conversion (e.g. resource, variant)
    def value
      result = case dtype
               when :bfloat16
                 read_bfloat16_array
               when :string
                 read_array_of_string(shape.reduce(1, :*))
               when :bool
                 read_numo(Numo::Int8).cast_to(Numo::Bit)
               else
                 numo_klass = DTYPE_TO_NUMO_TYPE_MAP.fetch(dtype) { raise "Unsupported tensor data type: #{dtype}" }
                 read_numo(numo_klass)
               end

      shape.empty? ? result[0] : result
    end

    # Allocates and fills the buffer with +strings+ in TensorFlow's string tensor layout: a uint64 offset table
    # followed by each string encoded with TF_StringEncode.
    def write_array_of_string(strings)
      # The start of the data section comes after the offset table (one uint64 per string)
      start_offset_size = strings.size * ::FFI.type_size(:int64)

      # Get the encoded sizes for each string
      encoded_sizes = strings.map { |string| FFI.TF_StringEncodedSize(string.bytesize) }

      # Offsets are relative to the beginning of the data section, not the beginning of the pointer. There is
      # exactly one offset per string, so an empty list yields an empty table (and nothing is written past the
      # end of the buffer).
      running_offset = 0
      offsets = encoded_sizes.map do |encoded_size|
        offset = running_offset
        running_offset += encoded_size
        offset
      end

      # Allocate the needed memory
      @byte_size = start_offset_size + encoded_sizes.sum
      @pointer = self.class.ruby_xmalloc(@byte_size)

      # Write the offsets
      pointer.write_array_of_uint64(offsets) unless offsets.empty?

      # Write the strings
      strings.each_with_index do |string, index|
        Status.check do |status|
          FFI.TF_StringEncode(string, string.bytesize, pointer + start_offset_size + offsets[index],
                              encoded_sizes[index], status)
        end
      end
    end

    # Copies a Numo array into a new buffer, adopting its shape and casting it to this object's dtype if needed.
    def write_narray(new_value)
      @shape = new_value.shape

      case dtype
      when :string
        write_array_of_string(new_value.flatten.to_a)
      when :bool
        # Booleans are stored one byte per element
        write_binary(new_value.cast_to(Numo::Int8))
      else
        if NUMO_TYPE_TO_DTYPE_MAP[new_value.class] != dtype
          # Cast the narray if necessary (say the user passed in float but we have a double array)
          new_value = new_value.cast_to(DTYPE_TO_NUMO_TYPE_MAP[dtype])
        end
        write_binary(new_value)
      end
    end

    # Allocates a buffer for one element of dtype and stores +new_value+ in it. Resource and variant dtypes
    # allocate nothing.
    #
    # @raise [RuntimeError] for dtypes that cannot be written as scalars (the buffer is freed first)
    def write_scalar(new_value)
      case dtype
      when :string
        write_array_of_string([new_value])
      when :resource, :variant
        self
      else
        @byte_size = self.class.type_size(dtype)
        @pointer = self.class.ruby_xmalloc(@byte_size)
        begin
          write_scalar_bytes(new_value)
        rescue StandardError
          # This buffer will never be handed to TensorFlow (so Deallocator will never free it); free it here
          self.class.ruby_xfree(@pointer)
          @pointer = nil
          raise
        end
      end
    end

    private

    # Writes +new_value+ into the already-allocated scalar buffer according to dtype.
    def write_scalar_bytes(new_value)
      case dtype
      when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
        pointer.send("write_#{dtype}", new_value)
      when :bfloat16
        pointer.write_bytes([bfloat16_bits(new_value)].pack("S"))
      when :complex64
        pointer.write_array_of_float([new_value.real, new_value.imaginary])
      when :complex128
        pointer.write_array_of_double([new_value.real, new_value.imaginary])
      when :bool
        pointer.write_int8(new_value ? 1 : 0)
      else
        raise "Unsupported tensor data type: #{dtype}"
      end
    end

    # Allocates a buffer of +narray+'s byte size and copies its raw bytes into it.
    def write_binary(narray)
      @byte_size = narray.byte_size
      @pointer = self.class.ruby_xmalloc(@byte_size)
      pointer.write_bytes(narray.to_binary)
    end

    # Reads the whole buffer as a +numo_klass+ array, reshaped to shape unless the tensor is a scalar.
    def read_numo(numo_klass)
      bytes = pointer.read_bytes(byte_size)
      shape.empty? ? numo_klass.from_binary(bytes) : numo_klass.from_binary(bytes, shape)
    end

    # Decodes the buffer as bfloat16 values. TensorFlow stores each bfloat16 as a native-endian uint16 holding
    # the upper 16 bits of an IEEE-754 float32, so widening back to float32 is a 16-bit left shift.
    def read_bfloat16_array
      pointer.read_bytes(byte_size).unpack("S*").map { |bits| [bits << 16].pack("L").unpack1("f") }
    end

    # bfloat16 bit pattern for +number+: its float32 representation rounded to nearest-even on the upper 16 bits.
    # NaN maps to the quiet-NaN pattern 0x7fc0.
    def bfloat16_bits(number)
      float = number.to_f
      return 0x7fc0 if float.nan?

      bits = [float].pack("f").unpack1("L")
      ((bits + 0x7fff + ((bits >> 16) & 1)) >> 16) & 0xffff
    end
  end
end
|