tensorflow-ruby 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
@@ -0,0 +1,106 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Base class for tf.data-style datasets. Wraps the variant tensor produced by a
    # dataset RawOp and provides the chaining API (batch, shuffle, repeat, map, ...)
    # plus eager iteration through Enumerable.
    class Dataset
      # Value copied from TensorFlow's Python implementation.
      DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB

      include Enumerable

      # TODO remove
      attr_reader :output_types, :output_shapes, :variant_tensor

      # Normalizes +values+ into an Array of tensor-like objects.
      #
      # @param values [Numo::NArray, Tensor, Array, Graph::Operation]
      # @return [Array] one entry per dataset component; non-Tensor Array items are wrapped in Tensor.new
      # @raise [Error::UnimplementedError] for any other input type
      def self.to_tensor_array(values)
        case values
        when Numo::NArray
          [Tensor.new(values)]
        when Tensor
          [values]
        when Array
          values.map { |v| v.is_a?(Tensor) ? v : Tensor.new(v) }
        when Graph::Operation
          [values]
        else
          raise(Error::UnimplementedError, "Unsupported dataset element: #{values}")
        end
      end

      # @return [TensorDataset] dataset built from +tensors+
      def self.from_tensors(tensors)
        TensorDataset.new(tensors)
      end

      # @return [TensorSliceDataset] dataset built from slices of +tensors+
      def self.from_tensor_slices(tensors)
        TensorSliceDataset.new(tensors)
      end

      # @param variant_tensor the dataset handle returned by a dataset RawOp
      def initialize(variant_tensor)
        @variant_tensor = variant_tensor
      end

      def to_ptr
        @variant_tensor.to_ptr
      end

      # Dataset options are not supported yet: +options+ is ignored. Returns +self+
      # so calls such as `ds.with_options(opts).batch(32)` keep chaining instead of
      # failing on nil.
      #
      # @return [Dataset] self
      def with_options(options)
        self
      end

      # @return [BatchDataset]
      def batch(batch_size, drop_remainder: false)
        BatchDataset.new(self, batch_size, drop_remainder)
      end

      # @return [ShuffleDataset]
      def shuffle(buffer_size)
        ShuffleDataset.new(self, buffer_size)
      end

      # @return [OneShotIterator]
      def make_one_shot_iterator
        OneShotIterator.new(self)
      end

      # @return [InitializableIterator]
      def make_initializable_iterator(shared_name: '')
        InitializableIterator.new(self, shared_name: shared_name)
      end

      # Iterates eagerly over the dataset, yielding each value returned by
      # RawOps.iterator_get_next_sync. Without a block, returns an Enumerator, so no
      # native iterator is created until iteration actually starts.
      def each
        return enum_for(:each) unless block_given?

        iterator, deleter = RawOps.anonymous_iterator_v2(output_types: @output_types, output_shapes: @output_shapes)
        begin
          RawOps.make_iterator(@variant_tensor, iterator)
          begin
            loop do
              values = RawOps.iterator_get_next_sync(iterator, output_types: @output_types, output_shapes: @output_shapes)
              yield values
            end
          rescue Error::OutOfRangeError
            # Signals that the iterator has no more elements; ends iteration normally.
          end
        ensure
          # Always release the anonymous iterator, even if the block raised or broke out.
          RawOps.delete_iterator(iterator, deleter) if iterator
        end
      end

      # !!! DEBUG method. You don't want to use this method because it iterates over
      # the entire dataset and reads it into a ruby array in memory
      def data
        map do |slice|
          if slice.is_a?(Array)
            slice.map(&:value)
          else
            slice.value
          end
        end
      end

      # @return [MapDataset] dataset applying +func+ to every element
      def map_func(func)
        MapDataset.new(self, func)
      end

      # Repeats this dataset +count+ times.
      #
      # @param count [Integer] number of repetitions; -1 (the default) repeats
      #   indefinitely, matching TensorFlow's RepeatDataset semantics
      # @return [RepeatDataset]
      def repeat(count = -1)
        RepeatDataset.new(self, count)
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset whose elements are fixed-length records read from one or more files
    # through the FixedLengthRecordDatasetV2 op.
    class FixedLengthRecordDataset < Dataset
      # @param filenames file name(s) forwarded to the op unchanged
      # @param record_bytes [Integer] record size in bytes
      # @param header_bytes [Integer] header size in bytes
      # @param footer_bytes [Integer] footer size in bytes
      # @param buffer_size [Integer] read buffer size in bytes
      # @param compression_type [String] forwarded to the op ('' by default)
      # @param num_parallel_reads [Integer] accepted for API compatibility; not used by this class
      def initialize(filenames, record_bytes, header_bytes: 0, footer_bytes: 0,
                     buffer_size: DEFAULT_READER_BUFFER_SIZE_BYTES, compression_type: '', num_parallel_reads: 0)
        # Each element is a single scalar string.
        @output_types = [:string]
        @output_shapes = [[]]

        # All size arguments are passed to the op as int64 scalars.
        record_t, header_t, footer_t, buffer_t =
          [record_bytes, header_bytes, footer_bytes, buffer_size].map { |bytes| Tensor.new(bytes, dtype: :int64) }

        super(RawOps.fixed_length_record_dataset_v2(filenames, header_t, record_t, footer_t, buffer_t,
                                                    compression_type))
      end
    end
  end
end
|
@@ -0,0 +1,76 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Base class for dataset iterators. Holds the element structure
    # (output_types / output_shapes) and the underlying iterator resource.
    class Iterator
      attr_reader :output_types, :output_shapes

      # Creates an iterator that is not bound to a dataset yet; bind it later with
      # ReinitializableIterator#make_initializer.
      #
      # @return [ReinitializableIterator]
      def self.from_structure(output_types, output_shapes = [], shared_name: '')
        ReinitializableIterator.new(output_types, output_shapes, shared_name: shared_name)
      end

      def initialize(output_types, output_shapes = [])
        @output_types = output_types
        @output_shapes = output_shapes
      end

      # @return the result of RawOps.iterator_get_next for this iterator
      def get_next
        RawOps.iterator_get_next(@iterator, **iterator_structure)
      end

      private

      # Keyword arguments describing the element structure, shared by the iterator ops.
      def iterator_structure
        { output_types: output_types, output_shapes: output_shapes }
      end

      # Creates an IteratorV2 resource with this iterator's element structure.
      def new_iterator_resource(shared_name)
        RawOps.iterator_v2(shared_name: shared_name, **iterator_structure)
      end
    end

    # Iterator created from a dataset factory function registered with the
    # current execution context.
    class OneShotIterator < Iterator
      def initialize(dataset)
        super(dataset.output_types, dataset.output_shapes)
        factory = dataset_factory_function(dataset)
        ExecutionContext.current.add_function(factory)
        @iterator = RawOps.one_shot_iterator(dataset_factory: factory, **iterator_structure)
      end

      private

      # Builds the 'MakeDataset' graph function that returns +dataset+ passed
      # through optimize_dataset with the 'noop_elimination' rewrite.
      def dataset_factory_function(dataset)
        Graph::Graph.new.as_default do |func_graph|
          optimized = RawOps.optimize_dataset(dataset.variant_tensor, ['noop_elimination'], **iterator_structure)
          func_graph.to_function('MakeDataset', nil, nil, [optimized])
        end
      end
    end

    # Iterator bound to one dataset; run +initializer+ before calling get_next.
    class InitializableIterator < Iterator
      attr_reader :initializer

      def initialize(dataset, shared_name: '')
        super(dataset.output_types, dataset.output_shapes)
        @iterator = new_iterator_resource(shared_name)
        @initializer = RawOps.make_iterator(dataset.variant_tensor, @iterator)
      end
    end

    # Iterator that can be (re)bound to any dataset with a matching structure.
    class ReinitializableIterator < Iterator
      def initialize(output_types, output_shapes, shared_name: '')
        super(output_types, output_shapes)
        @iterator = new_iterator_resource(shared_name)
      end

      # @return the make_iterator op that points this iterator at +dataset+
      def make_initializer(dataset)
        RawOps.make_iterator(dataset.variant_tensor, @iterator)
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset that applies a graph function to every element of another dataset.
    class MapDataset < Dataset
      # @param input_dataset [Dataset] source dataset
      # @param function the function to apply; its output_types / output_shapes
      #   become this dataset's element structure
      # @param other_arguments [Array] forwarded as the op's other_arguments input
      def initialize(input_dataset, function, other_arguments: [])
        @output_types, @output_shapes = function.output_types, function.output_shapes

        mapped = RawOps.map_dataset(
          input_dataset.variant_tensor,
          other_arguments,
          f: function,
          output_types: @output_types,
          output_shapes: @output_shapes
        )
        super(mapped)
      end
    end
  end
end
|
@@ -0,0 +1,16 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset that repeats another dataset +count+ times.
    class RepeatDataset < Dataset
      # @param dataset [Dataset] source dataset; its element structure is reused
      # @param count [Integer] repetition count, passed to the op as an int64 scalar
      def initialize(dataset, count)
        @output_types = dataset.output_types
        @output_shapes = dataset.output_shapes

        count_tensor = Tensor.new(count, dtype: :int64)
        super(RawOps.repeat_dataset(dataset.variant_tensor, count_tensor,
                                    output_types: @output_types,
                                    output_shapes: @output_shapes))
      end
    end
  end
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset that shuffles the elements of +input_dataset+ through a buffer of
    # +buffer_size+ elements (ShuffleDataset op).
    class ShuffleDataset < Dataset
      # Exclusive upper bound for seeds drawn from Ruby's global Random.
      RANDOM_SEED_LIMIT = 100_000_000

      # @param input_dataset [Dataset] dataset to shuffle; its element structure is reused
      # @param buffer_size [Integer] shuffle buffer size, passed as an int64 scalar
      # @param seed [Integer, nil] optional first seed for the op
      # @param seed2 [Integer, nil] optional second seed for the op
      #
      # When neither seed is given, both are drawn from Ruby's global Random (the
      # default behavior). When either is given, the missing one defaults to 0, so
      # callers can seed the shuffle deterministically. Per the ShuffleDataset op
      # documentation, passing 0 for both lets the runtime choose a random seed.
      def initialize(input_dataset, buffer_size, seed: nil, seed2: nil)
        @input_dataset = input_dataset
        @output_types = input_dataset.output_types
        @output_shapes = input_dataset.output_shapes

        if seed.nil? && seed2.nil?
          seed = ::Random.rand(RANDOM_SEED_LIMIT)
          seed2 = ::Random.rand(RANDOM_SEED_LIMIT)
        else
          seed ||= 0
          seed2 ||= 0
        end

        buffer_size_tensor = Tensor.new(buffer_size, dtype: :int64)
        seed_tensor = Tensor.new(seed, dtype: :int64)
        seed2_tensor = Tensor.new(seed2, dtype: :int64)

        variant_tensor = RawOps.shuffle_dataset(input_dataset.variant_tensor,
                                                buffer_size_tensor,
                                                seed_tensor,
                                                seed2_tensor,
                                                output_types: @output_types,
                                                output_shapes: @output_shapes)
        super(variant_tensor)
      end
    end
  end
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset that emits +elements+ as a single element (used by Dataset.from_tensors).
    class TensorDataset < Dataset
      # @param elements value(s) accepted by Dataset.to_tensor_array
      def initialize(elements)
        @tensors = self.class.to_tensor_array(elements)
        # Element structure is taken directly from the component tensors.
        @output_types = @tensors.map(&:dtype)
        @output_shapes = @tensors.map(&:shape)

        super(RawOps.tensor_dataset(@tensors, toutput_types: @output_types, output_shapes: @output_shapes))
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset that emits slices of +elements+ along their first dimension
    # (used by Dataset.from_tensor_slices).
    class TensorSliceDataset < Dataset
      # @param elements value(s) accepted by Dataset.to_tensor_array
      def initialize(elements)
        @tensors = self.class.to_tensor_array(elements)
        @output_types = @tensors.map(&:dtype)
        # A slice drops the leading dimension of its source tensor.
        @output_shapes = @tensors.map { |t| t.shape[1..] }

        sliced = RawOps.tensor_slice_dataset(@tensors, toutput_types: @output_types, output_shapes: @output_shapes)
        super(sliced)
      end
    end
  end
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset of serialized records read from one or more TFRecord files.
    class TfRecordDataset < Dataset
      DEFAULT_BUFFER_SIZE = 256 * 1_048_576 # 256 MB

      # @param filenames [String, Array<String>] file name(s) to read; a single name is wrapped in an Array
      # @param compression_type [String, nil] forwarded to the op; nil is treated as '' (no compression)
      # @param buffer_size [Integer, nil] read buffer size in bytes; nil falls back to DEFAULT_BUFFER_SIZE
      def initialize(filenames, compression_type = '', buffer_size = DEFAULT_BUFFER_SIZE)
        filenames = Array(filenames)
        @output_types = [:string]
        @output_shapes = [[]]

        # The TFRecordDataset op takes compression_type and buffer_size as required
        # inputs, so nil must become a concrete value rather than be forwarded as nil.
        compression_type ||= ''
        buffer_size_tensor = Tensor.new(buffer_size || DEFAULT_BUFFER_SIZE, dtype: :int64)
        variant_tensor = RawOps.tf_record_dataset(filenames, compression_type, buffer_size_tensor)

        super(variant_tensor)
      end
    end
  end
end
|
@@ -0,0 +1,24 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset that zips several datasets together element-wise.
    class ZipDataset < Dataset
      # @param datasets [Array<Dataset>] datasets to zip, given either as separate
      #   arguments or as a single Array
      def initialize(*datasets)
        inputs = datasets.flatten(1)

        # The zipped element's components are the inputs' components, concatenated in order.
        @output_types = inputs.map(&:output_types).flatten
        @output_shapes = inputs.map(&:output_shapes).flatten(1)

        handles = inputs.map(&:variant_tensor)
        super(RawOps.zip_dataset(handles, n: handles.count, output_types: @output_types,
                                          output_shapes: @output_shapes))
      end
    end
  end
end
|
@@ -0,0 +1,53 @@
|
|
1
|
+
module Tensorflow
  # Python-style `tf.function` decoration. Call Decorator.function(...) right
  # before a `def` on an object that extends this module. The method-added hooks
  # below then hand the newly defined method to the pending Function#wrap.
  module Decorator
    # A pending decoration: remembers the input signatures given to
    # Decorator.function until the next method definition consumes them.
    class Function
      attr_reader :input_signatures

      def initialize(input_signatures = [])
        @input_signatures = input_signatures
      end

      # @param method [Method, UnboundMethod] the method that was just defined
      # @return [Graph::FunctionDef]
      def wrap(method)
        Graph::FunctionDef.new(method, input_signatures)
      end
    end

    # Runs when an object extends Decorator. Stores this module in the extender's
    # @tf so the method-added hooks can reach it.
    def self.extended(klass)
      @waiting_for_method = false
      decorator = self
      klass.instance_eval { @tf = decorator }

      # For the top-level `main` object, top-level `def`s land on Object, so
      # Object must be extended too for method_added to fire.
      return unless klass.is_a?(Object) && klass.to_s == 'main'

      klass.class.extend(self)
    end

    # Starts a decoration for the next method defined on an extending object.
    #
    # @return [Function] the pending decoration
    def self.function(input_signature = [])
      @current_function = Function.new(input_signature)
    end

    # Consumes the pending decoration (if any) and wraps +method+ with it.
    # The pending Function is cleared before wrapping, because wrapping triggers
    # method_added again and must not re-wrap.
    #
    # @return [Graph::FunctionDef, nil] nil when no decoration was pending
    def self.wrap_method(method)
      pending = @current_function
      return unless pending

      @current_function = nil
      pending.wrap(method)
    end

    # Hook for `def self.name` on an extending object.
    def singleton_method_added(method_name)
      super
      @tf.wrap_method(method(method_name))
    end

    # Hook for instance-method definitions on an extending class or module.
    def method_added(method_name)
      super
      @tf.wrap_method(instance_method(method_name))
    end
  end
end
|
@@ -0,0 +1,120 @@
|
|
1
|
+
module Tensorflow
  module Eager
    # Ruby wrapper around a TensorFlow eager context (TFE_Context): executes eager
    # operations and registers functions with the runtime.
    class Context
      extend Forwardable
      def_delegators :@name_scope, :name_scope, :scoped_name, :unique_name

      # Number of output-handle slots allocated for each TFE_Execute call.
      # TODO decide how many retvals to allocate
      MAX_RETVALS = 10

      # Lazily created default context.
      def self.default
        @default ||= Context.new
      end

      # Builds the GC finalizer for a context pointer. It is a class method so the
      # proc closes over only +pointer+, not the Context instance. A finalizer that
      # references its own object would keep that object alive.
      def self.finalize(pointer)
        proc { FFI.TFE_DeleteContext(pointer) }
      end

      # Creates a native context with default options. Errors reported through
      # +status+ are surfaced by Status.check.
      def initialize
        @name_scope = NameScope.new
        options = FFI.TFE_NewContextOptions
        begin
          Status.check do |status|
            @pointer = FFI.TFE_NewContext(options, status)
          end
        ensure
          # Options are only needed while creating the context. Free them even when
          # creation fails, so they do not leak.
          FFI.TFE_DeleteContextOptions(options)
        end
        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      # Makes this the current execution context for the duration of the block.
      #
      # @raise [Error::InvalidArgumentError] if no block is given
      def as_default
        raise(Error::InvalidArgumentError, "Must provide block") unless block_given?
        ExecutionContext.push(self)
        begin
          yield self
        ensure
          ExecutionContext.pop
        end
      end

      # @return [Operation] a new eager operation bound to this context
      def create_operation(op_type, inputs = [], attrs = {})
        Operation.new(self, op_type, inputs, attrs)
      end

      # Executes +operation+ eagerly.
      #
      # @return [TensorHandle, Array<TensorHandle>, nil] one handle when a single
      #   output was produced, an Array for several, nil for none
      def execute(operation)
        retvals = ::FFI::MemoryPointer.new(:pointer, MAX_RETVALS)
        num_retvals = ::FFI::MemoryPointer.new(:int)
        # On input, TFE_Execute reads this as the capacity of +retvals+ in elements.
        # MemoryPointer#size is a byte count, so it must not be used here: it would
        # overstate the capacity and let an op with many outputs write past +retvals+.
        num_retvals.write_int(MAX_RETVALS)

        Status.check do |status|
          FFI.TFE_Execute(operation, retvals, num_retvals, status)
        end

        # On output it holds the number of handles actually written.
        n = num_retvals.read_int
        return unless n.positive?

        handles = retvals.read_array_of_pointer(n).map do |handle|
          TensorHandle.new(self, handle)
        end

        # TODO handle case where n = 1 and still want an array for retvals
        n == 1 ? handles.first : handles
      end

      def device_policy
        FFI::ContextDevicePlacementPolicy[FFI.TFE_ContextGetDevicePlacementPolicy(@pointer)]
      end

      def enable_run_metadata
        FFI.TFE_ContextEnableRunMetadata(@pointer)
      end

      def disable_run_metadata
        FFI.TFE_ContextDisableRunMetadata(@pointer)
      end

      def start_step
        FFI.TFE_ContextStartStep(@pointer)
      end

      def end_step
        FFI.TFE_ContextEndStep(@pointer)
      end

      def to_ptr
        @pointer
      end

      def shared_name
        # hard-coded in Python library
        "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
      end

      # Mimic graph api (no-op in eager mode)
      def add_to_collection(name, value)
      end

      # Mimic graph api (no-op in eager mode)
      def add_to_collections(names, value)
      end

      # Mimic graph api; always returns nil in eager mode
      def get_collection_ref(name)
      end

      # Registers +function+ with the native context.
      def add_function(function)
        Status.check do |status|
          FFI.TFE_ContextAddFunction(self, function, status)
        end
      end

      # Removes a function, given either a Graph::Function or its name.
      def remove_function(function)
        name = function.is_a?(Graph::Function) ? function.name : function
        Status.check do |status|
          FFI.TFE_ContextRemoveFunction(self, name, status)
        end
      end

      # @return [Boolean] whether a function (object or name) is registered
      def function?(function)
        name = function.is_a?(Graph::Function) ? function.name : function
        # result is uchar
        FFI.TFE_ContextHasFunction(self, name) != 0
      end
    end
  end
end
|