tensorflow-ruby 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
@@ -0,0 +1,106 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Base class for datasets. A dataset wraps a variant tensor produced by a
    # RawOps dataset op; subclasses assign @output_types and @output_shapes and
    # then call super with that variant tensor.
    class Dataset
      # Copied from Python code
      DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB

      include Enumerable

      # TODO remove
      attr_reader :output_types, :output_shapes, :variant_tensor

      # Normalizes +values+ into an Array of dataset components.
      #
      # @param values [Numo::NArray, Tensor, Array, Graph::Operation]
      # @return [Array] a one-element array for NArray/Tensor/Operation input;
      #   for an Array, each element that is not already a Tensor is wrapped
      #   with Tensor.new
      # @raise [Error::UnimplementedError] for any other input type
      def self.to_tensor_array(values)
        case values
        when Numo::NArray
          [Tensor.new(values)]
        when Tensor
          [values]
        when Array
          values.map { |value| value.is_a?(Tensor) ? value : Tensor.new(value) }
        when Graph::Operation
          [values]
        else
          raise(Error::UnimplementedError, "Unsupported dataset element: #{values}")
        end
      end

      # Builds a TensorDataset from +tensors+.
      def self.from_tensors(tensors)
        TensorDataset.new(tensors)
      end

      # Builds a TensorSliceDataset from +tensors+.
      def self.from_tensor_slices(tensors)
        TensorSliceDataset.new(tensors)
      end

      # @param variant_tensor the variant tensor produced by a dataset op
      def initialize(variant_tensor)
        @variant_tensor = variant_tensor
      end

      def to_ptr
        @variant_tensor.to_ptr
      end

      # Dataset options are not supported yet, so the argument is ignored.
      # Returns self (previously nil) so the call can be chained like the
      # other transformations.
      def with_options(_options)
        self
      end

      def batch(batch_size, drop_remainder: false)
        BatchDataset.new(self, batch_size, drop_remainder)
      end

      def shuffle(buffer_size)
        ShuffleDataset.new(self, buffer_size)
      end

      def make_one_shot_iterator
        OneShotIterator.new(self)
      end

      def make_initializable_iterator(shared_name: '')
        InitializableIterator.new(self, shared_name: shared_name)
      end

      # Yields every element of the dataset using an anonymous iterator that
      # is deleted afterwards. Without a block, returns an Enumerator (as
      # Enumerable expects) instead of touching TensorFlow.
      def each
        return enum_for(:each) unless block_given?

        iterator, deleter = RawOps.anonymous_iterator_v2(output_types: @output_types, output_shapes: @output_shapes)
        RawOps.make_iterator(@variant_tensor, iterator)
        begin
          loop do
            values = RawOps.iterator_get_next_sync(iterator, output_types: @output_types, output_shapes: @output_shapes)
            yield values
          end
        rescue Error::OutOfRangeError
          # OutOfRangeError signals the end of the dataset.
        end
      ensure
        RawOps.delete_iterator(iterator, deleter) if iterator
      end

      # !!! DEBUG method. You don't want to use this method because it iterates over
      # the entire dataset and reads it into a ruby array in memory
      def data
        self.map do |slice|
          if slice.is_a?(Array)
            slice.map(&:value)
          else
            slice.value
          end
        end
      end

      def map_func(func)
        MapDataset.new(self, func)
      end

      # Repeats this dataset +count+ times. The argument used to be ignored in
      # favor of a hard-coded 3. The default of -1 means "repeat indefinitely"
      # for the RepeatDataset op.
      #
      # @param count [Integer] number of repetitions, or -1 for no limit
      def repeat(count = -1)
        RepeatDataset.new(self, count)
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset of fixed-size records read from files through the
    # FixedLengthRecordDatasetV2 raw op. Each element is declared as a single
    # scalar string (output type :string, shape []).
    class FixedLengthRecordDataset < Dataset
      # @param filenames file name(s), passed to the raw op unchanged
      # @param record_bytes [Integer] size of one record in bytes
      # @param header_bytes [Integer] header size in bytes (default 0)
      # @param footer_bytes [Integer] footer size in bytes (default 0)
      # @param buffer_size [Integer] read buffer size in bytes
      # @param compression_type [String] passed to the raw op unchanged
      # @param num_parallel_reads [Integer] accepted for API compatibility;
      #   NOTE(review): not used by this implementation
      def initialize(filenames, record_bytes, header_bytes: 0, footer_bytes: 0,
                     buffer_size: DEFAULT_READER_BUFFER_SIZE_BYTES, compression_type: '', num_parallel_reads: 0)
        @output_types = [:string]
        @output_shapes = [[]]

        as_int64 = ->(amount) { Tensor.new(amount, dtype: :int64) }
        record_len = as_int64.call(record_bytes)
        header_len = as_int64.call(header_bytes)
        footer_len = as_int64.call(footer_bytes)
        buffer_len = as_int64.call(buffer_size)

        # Argument order expected by the raw op: header, record, footer, buffer.
        handle = RawOps.fixed_length_record_dataset_v2(
          filenames, header_len, record_len, footer_len, buffer_len, compression_type
        )
        super(handle)
      end
    end
  end
end
|
@@ -0,0 +1,76 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Base class for dataset iterators. It stores the element structure
    # (output_types / output_shapes); subclasses create the iterator resource
    # and keep it in @iterator.
    class Iterator
      attr_reader :output_types, :output_shapes

      # Creates an iterator that is not yet bound to any dataset. Bind it with
      # ReinitializableIterator#make_initializer.
      def self.from_structure(output_types, output_shapes=[], shared_name: '')
        ReinitializableIterator.new(output_types, output_shapes, shared_name: shared_name)
      end

      def initialize(output_types, output_shapes=[])
        @output_types, @output_shapes = output_types, output_shapes
      end

      # Returns the result of RawOps.iterator_get_next for @iterator.
      def get_next
        RawOps.iterator_get_next(@iterator, output_types: output_types, output_shapes: output_shapes)
      end
    end

    # Iterator built from a dataset-factory function that is registered with
    # the current execution context.
    class OneShotIterator < Iterator
      def initialize(dataset)
        super(dataset.output_types, dataset.output_shapes)
        create_one_shot_iterator(dataset)
      end

      private

      def create_one_shot_iterator(dataset)
        factory = make_dataset_function(dataset)
        ExecutionContext.current.add_function(factory)
        @iterator = RawOps.one_shot_iterator(dataset_factory: factory,
                                             output_types: output_types,
                                             output_shapes: output_shapes)
      end

      # Builds the 'MakeDataset' function in a fresh graph. Its single output
      # is +dataset+ passed through optimize_dataset with 'noop_elimination'.
      def make_dataset_function(dataset)
        Graph::Graph.new.as_default do |graph|
          optimized = RawOps.optimize_dataset(dataset.variant_tensor, ['noop_elimination'],
                                              output_types: output_types, output_shapes: output_shapes)
          graph.to_function('MakeDataset', nil, nil, [optimized])
        end
      end
    end

    # Iterator bound to one dataset. Run #initializer before reading.
    class InitializableIterator < Iterator
      attr_reader :initializer

      def initialize(dataset, shared_name: '')
        super(dataset.output_types, dataset.output_shapes)
        create_initializable_iterator(dataset, shared_name)
      end

      private

      def create_initializable_iterator(dataset, shared_name)
        resource = RawOps.iterator_v2(shared_name: shared_name, output_types: output_types, output_shapes: output_shapes)
        @iterator = resource
        @initializer = RawOps.make_iterator(dataset.variant_tensor, resource)
      end
    end

    # Iterator defined only by its structure. It can be pointed at any dataset
    # with a matching structure through #make_initializer.
    class ReinitializableIterator < Iterator
      def initialize(output_types, output_shapes, shared_name: '')
        super(output_types, output_shapes)
        create_iterator_from_structure(shared_name)
      end

      # Returns the make_iterator op that binds this iterator to +dataset+.
      def make_initializer(dataset)
        RawOps.make_iterator(dataset.variant_tensor, @iterator)
      end

      private

      def create_iterator_from_structure(shared_name)
        @iterator = RawOps.iterator_v2(shared_name: shared_name,
                                       output_types: output_types,
                                       output_shapes: output_shapes)
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset that applies +function+ to each element of +input_dataset+
    # through the MapDataset raw op. The element structure is taken from the
    # function's output_types / output_shapes.
    class MapDataset < Dataset
      # @param input_dataset [Dataset] source dataset
      # @param function graph function passed to the op as +f+
      # @param other_arguments [Array] extra captured inputs for the function
      def initialize(input_dataset, function, other_arguments: [])
        @output_types, @output_shapes = function.output_types, function.output_shapes

        handle = RawOps.map_dataset(
          input_dataset.variant_tensor,
          other_arguments,
          f: function,
          output_types: @output_types,
          output_shapes: @output_shapes
        )
        super(handle)
      end
    end
  end
end
|
@@ -0,0 +1,16 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset that repeats +dataset+ +count+ times through the RepeatDataset
    # raw op. The element structure is inherited from the source dataset.
    class RepeatDataset < Dataset
      def initialize(dataset, count)
        @output_types = dataset.output_types
        @output_shapes = dataset.output_shapes

        source = dataset.variant_tensor
        times = Tensor.new(count, dtype: :int64)
        handle = RawOps.repeat_dataset(source, times,
                                       output_types: @output_types,
                                       output_shapes: @output_shapes)
        super(handle)
      end
    end
  end
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset that shuffles +input_dataset+ through the ShuffleDataset raw op,
    # using a buffer of +buffer_size+ elements. Both seeds are drawn from
    # Ruby's ::Random each time an instance is built.
    class ShuffleDataset < Dataset
      def initialize(input_dataset, buffer_size)
        @input_dataset = input_dataset
        @output_types = input_dataset.output_types
        @output_shapes = input_dataset.output_shapes

        as_int64 = ->(number) { Tensor.new(number, dtype: :int64) }
        buffer_tensor = as_int64.call(buffer_size)
        seed, seed2 = Array.new(2) { as_int64.call(::Random.rand(100_000_000)) }

        handle = RawOps.shuffle_dataset(
          input_dataset.variant_tensor,
          buffer_tensor,
          seed,
          seed2,
          output_types: @output_types,
          output_shapes: @output_shapes
        )
        super(handle)
      end
    end
  end
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset made from +elements+ through the TensorDataset raw op. Output
    # types and shapes are the dtype and full shape of each component tensor.
    class TensorDataset < Dataset
      def initialize(elements)
        @tensors = self.class.to_tensor_array(elements)
        @output_types = @tensors.map(&:dtype)
        @output_shapes = @tensors.map(&:shape)

        super(RawOps.tensor_dataset(@tensors, toutput_types: @output_types, output_shapes: @output_shapes))
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset made from +elements+ through the TensorSliceDataset raw op.
    # Each component's output shape is its tensor shape with the first
    # dimension removed.
    class TensorSliceDataset < Dataset
      def initialize(elements)
        @tensors = self.class.to_tensor_array(elements)
        @output_types = @tensors.map(&:dtype)
        @output_shapes = @tensors.map { |component| component.shape[1..] }

        handle = RawOps.tensor_slice_dataset(@tensors,
                                             toutput_types: @output_types,
                                             output_shapes: @output_shapes)
        super(handle)
      end
    end
  end
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset of TFRecord entries read through the TFRecordDataset raw op.
    # Each element is declared as a single scalar string.
    class TfRecordDataset < Dataset
      # NOTE(review): this is 256 MB, while Dataset::DEFAULT_READER_BUFFER_SIZE_BYTES
      # is 256 KB -- confirm which default is intended.
      DEFAULT_BUFFER_SIZE = 256 * 1_048_576 # 256 MB

      # @param filenames [String, Array<String>] coerced to an Array
      # @param compression_type [String] passed to the raw op unchanged
      # @param buffer_size [Integer, nil] wrapped as an int64 tensor when
      #   truthy; otherwise passed through unchanged
      def initialize(filenames, compression_type='', buffer_size=DEFAULT_BUFFER_SIZE)
        files = Array(filenames)
        @output_types = [:string]
        @output_shapes = [[]]

        buffer = buffer_size && Tensor.new(buffer_size, dtype: :int64)
        super(RawOps.tf_record_dataset(files, compression_type, buffer))
      end
    end
  end
end
|
@@ -0,0 +1,24 @@
|
|
1
|
+
module Tensorflow
  module Data
    # Dataset that zips several datasets together through the ZipDataset raw
    # op. Accepts datasets as separate arguments or as one array. The zipped
    # element structure is the concatenation of the inputs' structures.
    class ZipDataset < Dataset
      def initialize(*datasets)
        inputs = datasets.flatten(1)
        handles = inputs.map(&:variant_tensor)

        @output_types = inputs.map(&:output_types).flatten
        @output_shapes = inputs.map(&:output_shapes).flatten(1)

        handle = RawOps.zip_dataset(handles, n: handles.size,
                                             output_types: @output_types,
                                             output_shapes: @output_shapes)
        super(handle)
      end
    end
  end
end
|
@@ -0,0 +1,53 @@
|
|
1
|
+
module Tensorflow
  # Marks the next method defined after Decorator.function as a graph function:
  #
  #   extend Tensorflow::Decorator
  #   Tensorflow::Decorator.function([...])
  #   def my_method(...) ... end
  #
  # The pending Function is consumed by the next method_added /
  # singleton_method_added hook and passed to Graph::FunctionDef.
  module Decorator
    # A pending decoration that stores input signatures until the decorated
    # method is defined.
    class Function
      attr_reader :input_signatures

      def initialize(input_signatures = [])
        @input_signatures = input_signatures
      end

      # Wraps +method+ (a Method or UnboundMethod) in a Graph::FunctionDef.
      def wrap(method)
        Graph::FunctionDef.new(method, self.input_signatures)
      end
    end

    # Hook run when +klass+ extends this module. Stores the decorator module
    # in klass's @tf. When the extender is the top-level `main` object, Object
    # is extended too, so top-level method definitions reach method_added.
    def self.extended(klass)
      @waiting_for_method = false
      this = self
      klass.instance_eval do
        @tf = this
      end

      if klass.is_a?(Object) && klass.to_s == 'main'
        klass.class.extend(self)
      end
    end

    # Registers a pending decoration for the next method definition.
    def self.function(input_signature = [])
      @current_function = Function.new(input_signature)
    end

    # Wraps +method+ if a decoration is pending; otherwise returns nil.
    def self.wrap_method(method)
      # Wrapping the method triggers method_added again, so clear
      # @current_function first; the re-entrant call then becomes a no-op.
      return unless @current_function

      current_function = @current_function
      @current_function = nil
      current_function.wrap(method)
    end

    def singleton_method_added(method_name)
      super(method_name)
      method = self.method(method_name)
      # @tf is only set on the object that called extend. After `main` extends,
      # every class inherits these hooks from Object without having @tf, so
      # fall back to the module itself instead of calling nil.wrap_method.
      (@tf || Decorator).wrap_method(method)
    end

    def method_added(method_name)
      super(method_name)
      method = self.instance_method(method_name)
      # See singleton_method_added: subclasses of an extended Object have no @tf.
      (@tf || Decorator).wrap_method(method)
    end
  end
end
|
@@ -0,0 +1,120 @@
|
|
1
|
+
module Tensorflow
  module Eager
    # Ruby wrapper around a TFE_Context handle from the TensorFlow eager C API.
    # Name-scope helpers are delegated to an internal NameScope.
    class Context
      extend Forwardable
      def_delegators :@name_scope, :name_scope, :scoped_name, :unique_name

      # Number of output-handle slots allocated for each TFE_Execute call.
      # TODO decide how many retvals to allocate
      MAX_RETVALS = 10

      # Lazily created shared context.
      def self.default
        @default ||= Context.new
      end

      # Finalizer for a context pointer. It is built in a class method so the
      # proc closes over only +pointer+, not the Context instance.
      def self.finalize(pointer)
        proc { FFI.TFE_DeleteContext(pointer) }
      end

      # Creates the underlying TFE_Context. The context options are released
      # whether or not creation succeeds; previously they leaked when
      # Status.check raised.
      def initialize
        @name_scope = NameScope.new
        options = FFI.TFE_NewContextOptions
        begin
          Status.check do |status|
            @pointer = FFI.TFE_NewContext(options, status)
          end
        ensure
          FFI.TFE_DeleteContextOptions(options)
        end
        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      # Pushes this context onto ExecutionContext for the duration of the block.
      #
      # @raise [Error::InvalidArgumentError] when no block is given
      def as_default
        raise(Error::InvalidArgumentError, "Must provide block") unless block_given?
        ExecutionContext.push(self)
        begin
          yield self
        ensure
          ExecutionContext.pop
        end
      end

      def create_operation(op_type, inputs=[], attrs={})
        Operation.new(self, op_type, inputs, attrs)
      end

      # Executes +operation+ eagerly.
      #
      # @return [TensorHandle, Array<TensorHandle>, nil] one handle for a single
      #   output, an array for several, nil when there are no outputs
      def execute(operation)
        retvals = ::FFI::MemoryPointer.new(:pointer, MAX_RETVALS)
        num_retvals = ::FFI::MemoryPointer.new(:int)
        # TFE_Execute reads *num_retvals as the capacity of +retvals+, counted
        # in slots. MemoryPointer#size is a byte count (8x too large on 64-bit),
        # so using it would let TensorFlow write past the end of the buffer.
        num_retvals.write_int(MAX_RETVALS)

        Status.check do |status|
          FFI.TFE_Execute(operation, retvals, num_retvals, status)
        end

        n = num_retvals.read_int
        if n > 0
          handles = retvals.read_array_of_pointer(n).map do |handle|
            TensorHandle.new(self, handle)
          end

          # TODO handle case where n = 1 and still want an array for retvals
          n == 1 ? handles.first : handles
        end
      end

      def device_policy
        FFI::ContextDevicePlacementPolicy[FFI.TFE_ContextGetDevicePlacementPolicy(@pointer)]
      end

      def enable_run_metadata
        FFI.TFE_ContextEnableRunMetadata(@pointer)
      end

      def disable_run_metadata
        FFI.TFE_ContextDisableRunMetadata(@pointer)
      end

      def start_step
        FFI.TFE_ContextStartStep(@pointer)
      end

      def end_step
        FFI.TFE_ContextEndStep(@pointer)
      end

      def to_ptr
        @pointer
      end

      def shared_name
        # hard-coded in Python library
        "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
      end

      # Mimic graph api (no-op in eager mode)
      def add_to_collection(name, value)
      end

      # Mimic graph api (no-op in eager mode)
      def add_to_collections(names, value)
      end

      # Mimic graph api; always returns nil in eager mode
      def get_collection_ref(name)
      end

      def add_function(function)
        Status.check do |status|
          FFI.TFE_ContextAddFunction(self, function, status)
        end
      end

      # @param function [Graph::Function, String] function or its name
      def remove_function(function)
        name = function.is_a?(Graph::Function) ? function.name : function
        Status.check do |status|
          FFI.TFE_ContextRemoveFunction(self, name, status)
        end
      end

      # @param function [Graph::Function, String] function or its name
      def function?(function)
        name = function.is_a?(Graph::Function) ? function.name : function
        # result is uchar
        FFI.TFE_ContextHasFunction(self, name) != 0
      end
    end
  end
end
|