tensorflow-ruby 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
module Tensorflow
  module Keras
    module Layers
      # Collapses every non-batch axis into one: an input of shape
      # [batch, d1, d2, ...] becomes [batch, d1 * d2 * ...].
      class Flatten
        # @param input_shape [Array<Integer>, nil] per-sample shape (batch axis
        #   excluded); only #output_shape reads it, so it must be set before
        #   that method is called.
        def initialize(input_shape: nil)
          @input_shape = input_shape
        end

        # @return [Array<Integer>] [-1, product of the per-sample dimensions];
        #   -1 stands for the batch size.
        def output_shape
          [-1, flattened_size(@input_shape)]
        end

        # Flatten has no weights.
        #
        # @return [Integer] always 0
        def count_params
          0
        end

        # @param inputs tensor-like object responding to #shape, whose first
        #   entry is the batch axis
        # @return result of Tensorflow.reshape(inputs, [-1, flattened_dim])
        def call(inputs)
          Tensorflow.reshape(inputs, [-1, flattened_size(inputs.shape[1..-1])])
        end

        private

        # Product of +dims+. Seeding with 1 makes an empty list (rank-1 input,
        # scalar samples) flatten to a single feature instead of nil.
        def flattened_size(dims)
          dims.inject(1, :*)
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
module Tensorflow
  module Keras
    module Losses
      # Cross-entropy between integer class labels and predicted class
      # probabilities, averaged over every element of the per-example loss.
      class SparseCategoricalCrossentropy
        # @param target integer class ids; cast to :int64 before use
        # @param output predicted probabilities; their log is fed to the fused
        #   softmax cross-entropy op as "features"
        # @return sum of the per-example costs divided by their element count
        #
        # NOTE(review): unlike Keras' from_logits=false path, +output+ is not
        # clipped to [epsilon, 1 - epsilon] before the log, so a zero
        # probability produces an infinite loss — confirm whether that is wanted.
        def call(target, output)
          log_probs = Math.log(output)
          labels = Tensorflow.cast(target, :int64)
          per_example, _backprop = RawOps.sparse_softmax_cross_entropy_with_logits(features: log_probs, labels: labels)

          summed = Math.reduce_sum(per_example)
          summed / Tensorflow.cast(RawOps.size(input: per_example), :float)
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
module Tensorflow
  module Keras
    module Metrics
      # Streaming mean: accumulates a running total and a running element
      # count in two weights and reports total / count.
      class Mean
        # @param name [String, nil] accepted for API compatibility; not used here
        # @param dtype [Symbol] dtype of both accumulators and of cast inputs
        def initialize(name: nil, dtype: :float)
          @dtype = dtype
          @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype)
          @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype)
        end

        # Forwards all arguments to #update_state, so subclasses with a
        # different update_state arity can be called the same way.
        def call(*args)
          update_state(*args)
        end

        # Adds the sum of +values+ to the total and their element count to
        # the count.
        #
        # @param values tensor-like input; cast to the metric's dtype first
        # @return the result of the count weight's assign_add
        def update_state(values)
          # Cast the argument itself (the old code cast an unassigned local,
          # i.e. nil); dtype is positional as in every other cast here.
          input = Tensorflow.cast(values, @dtype)
          @total.assign_add(Math.reduce_sum(input))
          # NOTE(review): RawOps.size is called positionally here but with
          # `input:` in the losses code — confirm the generated signature.
          @count.assign_add(Tensorflow.cast(RawOps.size(input), @dtype))
        end

        # @return total / count, using div_no_nan so an empty metric yields 0
        #   rather than NaN
        def result
          RawOps.div_no_nan(@total, Tensorflow.cast(@count, :float))
        end

        # NOTE(review): currently a no-op — total and count are NOT reset.
        def reset_states
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
module Tensorflow
  module Keras
    module Metrics
      # Fraction of samples whose arg-max prediction equals the integer label.
      # Feeds the element-wise match tensor into Mean's running total/count.
      class SparseCategoricalAccuracy < Mean
        # @param y_true integer class labels
        # @param y_pred per-class scores; the arg-max is taken over axis -1
        def update_state(y_true, y_pred)
          predicted_class = RawOps.arg_max(y_pred, -1)

          # NOTE(review): no dtype reconciliation is done between y_true and the
          # arg-max result; a disabled sketch was:
          #   predicted_class = Tensorflow.cast(predicted_class, y_true.dtype) if predicted_class.dtype != y_true.dtype
          matches = Math.equal(y_true, predicted_class)
          super(matches)
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
module Tensorflow
  module Keras
    module Models
      # Linear stack of layers. Construction and #summary work; compile, fit
      # and evaluate are placeholders that raise.
      class Sequential
        # @param layers [Array] initial layers, appended in order via #add
        def initialize(layers = [])
          @layers = []
          layers.each { |layer| add(layer) }
        end

        # Appends +layer+ to the end of the stack.
        #
        # @return [Array] the layer list
        def add(layer)
          @layers << layer
        end

        # Placeholder; always raises RuntimeError.
        def compile(optimizer: nil, loss: nil, metrics: nil)
          raise "Not implemented"
        end

        # Placeholder; always raises RuntimeError.
        def fit(x, y, epochs: nil)
          raise "Not implemented"
        end

        # Placeholder; always raises RuntimeError.
        def evaluate(x, y)
          raise "Not implemented"
        end

        # Prints a Keras-style table of layer types, output shapes and
        # parameter counts to stdout, then returns nil.
        #
        # Layers responding to #build are built first, each receiving the
        # previous layer's output shape (nil for the first layer).
        def summary
          divider = "_________________________________________________________________\n"
          double_rule = "=================================================================\n"

          @layers.inject(nil) do |previous_shape, layer|
            layer.build(previous_shape) if layer.respond_to?(:build)
            layer.output_shape
          end

          total_params = @layers.sum(&:count_params)

          rows = @layers.map do |layer|
            kind = layer.class.name.split("::").last
            # The batch axis is shown as nil, like Keras' "None".
            shape_text = ([nil] + layer.output_shape[1..-1]).inspect
            format("%-28s %-25s %-10s\n", kind, shape_text, layer.count_params)
          end

          report = [
            "Model: \"sequential\"\n",
            divider,
            "Layer (type) Output Shape Param # \n",
            double_rule,
            rows.join(divider),
            double_rule,
            "Total params: #{total_params}\n",
            "Trainable params: #{total_params}\n",
            "Non-trainable params: 0\n",
            divider
          ].join

          puts report
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
module Tensorflow
  module Keras
    module Preprocessing
      # Minimal image loading helpers backed by MiniMagick and Numo.
      module Image
        class << self
          # Opens an image file, optionally resizing it to an exact size.
          #
          # @param path [String] image file path
          # @param target_size [Array<Integer>, nil] [height, width], the same
          #   order as Keras' target_size=(img_height, img_width); nil keeps
          #   the original size
          # @return [MiniMagick::Image]
          def load_img(path, target_size: nil)
            img = MiniMagick::Image.open(path)
            if target_size
              height, width = target_size.map(&:to_i)
              # ImageMagick geometry is WIDTHxHEIGHT, so the Keras-ordered pair
              # must be swapped (joining it directly transposed non-square
              # targets). "!" forces the exact size, ignoring aspect ratio.
              # TODO make resize consistent with Python
              # NOTE(review): "-filter point" follows the geometry argument;
              # confirm ImageMagick applies it to this resize.
              img.resize "#{width}x#{height}!", "-filter", "point"
            end
            img
          end

          # Converts the image's pixel data (MiniMagick#get_pixels) into a
          # Numo::SFloat array.
          # NOTE(review): get_pixels appears to yield rows of per-pixel channel
          # arrays, i.e. a (height, width, channels) layout — confirm.
          def img_to_array(img)
            Numo::SFloat.cast(img.get_pixels)
          end
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
module Tensorflow
  module Keras
    module Utils
      class << self
        # Creates a Variable and assigns it an initial value.
        #
        # @param name [String, nil] variable name
        # @param shape [Array<Integer>] variable shape
        # @param initializer [String] "zeros" or "glorot_uniform"
        # @param dtype [Symbol] variable dtype
        # @return [Variable]
        # @raise [Error] for an unknown initializer (the Variable has already
        #   been constructed at that point)
        def add_weight(name: nil, shape: [], initializer: nil, dtype: :float)
          variable = Variable.new(shape: shape, name: name, dtype: dtype)
          initial_value =
            case initializer
            when "zeros"
              # NOTE(review): the fill value is always the float 0.0 regardless
              # of +dtype+ — confirm non-float weights are handled.
              Tensorflow.fill(shape, 0.0)
            when "glorot_uniform"
              fan_in, fan_out = compute_fans(shape)
              scale = 1.0 / [1.0, (fan_in + fan_out) / 2.0].max
              limit = ::Math.sqrt(3.0 * scale)

              minval = -limit
              maxval = limit

              # Scale a [0, 1) uniform sample into [minval, maxval).
              rnd = RawOps.random_uniform(shape: shape, dtype: :float)
              Math.add(rnd * (maxval - minval), minval)
            else
              raise Error, "Unknown initializer: #{initializer}"
            end

          variable.value = initial_value
          variable
        end

        # Downloads +origin+ into ~/.keras/<cache_subdir>/<fname> unless that
        # file already exists, optionally verifying its checksum.
        #
        # @param fname [String] file name inside the cache directory
        # @param origin [String] http or https URL to download from
        # @param file_hash [String, nil] expected hex digest; 32 characters
        #   selects MD5, anything else SHA-256
        # @param cache_subdir [String] subdirectory of ~/.keras
        # @return [String] path of the cached file
        # @raise [RuntimeError] when HOME is not set
        # @raise [Error] on a non-2xx response or a checksum mismatch; nothing
        #   is cached in either case
        def get_file(fname, origin, file_hash: nil, cache_subdir: "datasets")
          # destination
          # TODO handle this better
          raise "No HOME" unless ENV["HOME"]
          dest = "#{ENV["HOME"]}/.keras/#{cache_subdir}/#{fname}"
          FileUtils.mkdir_p(File.dirname(dest))

          return dest if File.exist?(dest)

          # Dir.tmpdir comes from 'tmpdir', which 'tempfile' loads; this avoids
          # creating (and leaking) a Tempfile just to learn the directory.
          temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name

          digest = file_hash&.size == 32 ? Digest::MD5.new : Digest::SHA2.new

          begin
            download(origin, temp_path, digest)

            if file_hash && digest.hexdigest != file_hash
              raise Error, "Bad hash"
            end

            FileUtils.mv(temp_path, dest)
          ensure
            # Never leave a partial or rejected download behind; after a
            # successful move this is a no-op.
            FileUtils.rm_f(temp_path)
          end

          dest
        end

        private

        # Keras-style fan computation: 2-D [in, out]; conv kernels
        # [..spatial.., in, out] multiply in/out by the receptive field size;
        # 1-D uses the single dimension for both; 0-D uses 1.
        def compute_fans(shape)
          case shape.size
          when 0 then [1, 1]
          when 1 then [shape[0], shape[0]]
          when 2 then [shape[0], shape[1]]
          else
            receptive_field = shape[0...-2].inject(1, :*)
            [shape[-2] * receptive_field, shape[-1] * receptive_field]
          end
        end

        # Streams +origin+ into +path+, feeding each chunk to +digest+ and
        # printing one progress dot every 50 chunks.
        #
        # @raise [Error] when the response is not a 2xx success
        def download(origin, path, digest)
          uri = URI(origin)

          # Net::HTTP automatically adds Accept-Encoding for compression
          # of response bodies and automatically decompresses gzip
          # and deflate responses unless a Range header was sent.
          # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            puts "Downloading data from #{origin}"
            http.request(request) do |response|
              # Without this check an error page (404, 500, a redirect body)
              # would be written and permanently cached as the dataset.
              unless response.is_a?(Net::HTTPSuccess)
                raise Error, "Failed to download #{origin}: HTTP #{response.code} #{response.message}"
              end

              chunk_count = 0
              File.open(path, "wb") do |file|
                response.read_body do |chunk|
                  file.write(chunk)
                  digest.update(chunk)

                  # print progress
                  putc "." if (chunk_count % 50).zero?
                  chunk_count += 1
                end
              end
              puts # newline
            end
          end
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
require 'set'
|
|
2
|
+
|
|
3
|
+
module Tensorflow
  # Tracks the stack of open name scopes and every name handed out so far.
  # Produces unique, "/"-joined names; uniqueness is case-insensitive.
  class NameScope
    attr_reader :stack, :names

    def initialize
      @stack = []
      @names = Set.new
    end

    # Opens a uniquely named scope for the duration of the block.
    #
    # @param base_name [String, nil] requested scope name; nil pushes a nil
    #   entry, which makes #current_scope nil inside the block
    # @yield [String, nil] the current scope path
    # @return the block's value, or nil when no block is given
    def name_scope(base_name)
      scope = unique_name(base_name)
      stack.push(scope)
      begin
        block_given? ? yield(current_scope) : nil
      ensure
        stack.pop
      end
    end

    # @return [String, nil] the open scopes joined with "/", or nil when no
    #   scope is open or the innermost entry is nil
    def current_scope
      stack.join("/") unless stack.last.nil?
    end

    # Prefixes +name+ with the current scope (when one is active) and makes
    # the result unique.
    def scoped_name(name)
      qualified = stack.last.nil? ? name : "#{current_scope}/#{name}"
      unique_name(qualified)
    end

    # Returns +name+ itself, or +name+ with the smallest "_N" suffix, such
    # that its downcased form has not been issued yet; records the result.
    #
    # @return [String, nil] nil when +name+ is nil
    def unique_name(name)
      return nil unless name

      candidate = name
      1.step do |suffix|
        break unless names.include?(candidate.downcase)
        candidate = "#{name}_#{suffix}"
      end
      names << candidate.downcase
      candidate
    end
  end
end
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
module Tensorflow
  # Wraps a native TF_OpDefinitionBuilder handle used to declare and register
  # a custom op. The declaration methods return self so calls can be chained.
  class OpDefBuilder
    # Looks up TF_ShapeInferenceContextSetUnknownShape in the first loaded
    # FFI library and caches it once found, for use with #shape_inference.
    def self.unknown_shape_inference_func
      return @unknown_shape_func if @unknown_shape_func

      @unknown_shape_func = FFI.ffi_libraries.first.find_function('TF_ShapeInferenceContextSetUnknownShape')
    end

    # Finalizer that frees the native builder when the Ruby wrapper is
    # collected. It is a proc rather than a lambda so it tolerates the object
    # id argument ObjectSpace passes to finalizers.
    def self.finalize(pointer)
      proc { FFI.TF_DeleteOpDefinitionBuilder(pointer) }
    end

    # @param name [String] op name passed to TF_NewOpDefinitionBuilder
    def initialize(name)
      @pointer = FFI.TF_NewOpDefinitionBuilder(name)
      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # Native handle, used when this object is passed to FFI calls.
    def to_ptr
      @pointer
    end

    # Adds an attribute spec string; returns self.
    def attr(spec)
      tap { FFI.TF_OpDefinitionBuilderAddAttr(self, spec) }
    end

    # Adds an input spec string; returns self.
    def input(spec)
      tap { FFI.TF_OpDefinitionBuilderAddInput(self, spec) }
    end

    # Adds an output spec string; returns self.
    def output(spec)
      tap { FFI.TF_OpDefinitionBuilderAddOutput(self, spec) }
    end

    # Sets the shape inference function; returns self.
    def shape_inference(func)
      tap { FFI.TF_OpDefinitionBuilderSetShapeInferenceFunction(self, func) }
    end

    # Registers the op with TensorFlow and raises via Status.check on failure.
    # The finalizer is removed first because, per tensorflow/c/ops.h,
    # TF_RegisterOpDefinition frees the builder whether or not registration
    # succeeds; deleting it again would be a double free.
    def register
      ObjectSpace.undefine_finalizer(self)
      Status.check { |status| FFI.TF_RegisterOpDefinition(self, status) }
    end
  end
end
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
module Tensorflow
  # Thin wrappers around the WAV audio ops in RawOps.
  module Audio
    # Decodes WAV-encoded +contents+ with RawOps.decode_wav. The -1 defaults
    # for desired_channels and desired_samples are forwarded unchanged.
    def self.decode_wav(contents, desired_channels: -1, desired_samples: -1)
      RawOps.decode_wav(
        contents: contents,
        desired_channels: desired_channels,
        desired_samples: desired_samples
      )
    end

    # Encodes +audio+ at +sample_rate+ with RawOps.encode_wav.
    def self.encode_wav(audio, sample_rate)
      RawOps.encode_wav(
        audio: audio,
        sample_rate: sample_rate
      )
    end
  end
end
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
module Tensorflow
  # Element-wise bit manipulation helpers, each forwarding to the RawOps op of
  # the same name.
  #
  # NOTE(review): #invert passes its operand as a keyword (x:) while the other
  # helpers pass operands positionally — confirm both match the generated
  # RawOps signatures.
  module Bitwise
    def self.bitwise_and(x, y)
      RawOps.bitwise_and(x, y)
    end

    def self.bitwise_or(x, y)
      RawOps.bitwise_or(x, y)
    end

    def self.bitwise_xor(x, y)
      RawOps.bitwise_xor(x, y)
    end

    # Bitwise NOT of +x+.
    def self.invert(x)
      RawOps.invert(x: x)
    end

    # Shifts +x+ left by +y+ bits.
    def self.left_shift(x, y)
      RawOps.left_shift(x, y)
    end

    # Shifts +x+ right by +y+ bits.
    def self.right_shift(x, y)
      RawOps.right_shift(x, y)
    end
  end
end
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
module Tensorflow
  module Ops
    # Multiplies +matrix+ by +vector+ after appending a trailing axis to the
    # vector (expand_dims at -1), so the vector broadcasts along the matrix's
    # last axis.
    # NOTE(review): presumably vector is (batch,) and matrix is
    # (batch, classes) — confirm.
    def self.broadcast_mul(vector, matrix)
      vector = Tensorflow.expand_dims(vector, -1)
      vector * matrix
    end

    # Gradient for SparseSoftmaxCrossEntropyWithLogits. The op's second output
    # (operation[1], the precomputed backprop) is scaled by the incoming loss
    # gradient. It is first wrapped in prevent_gradient with +message+, which
    # presumably raises if a second derivative is requested.
    Graph::Gradients.register('SparseSoftmaxCrossEntropyWithLogits') do |gradient, outputs, _inputs|
      message = <<~EOS
        Currently there is no way to take the second derivative of sparse_softmax_cross_entropy_with_logits due to the fused
        implementation's interaction with tf.gradients()
      EOS

      operation = outputs[0].operation
      sparse_softmax_grad_without_gradient = Tensorflow.prevent_gradient(operation[1], message: message)
      op = Ops.broadcast_mul(gradient, sparse_softmax_grad_without_gradient)
      op.outputs
    end
  end
end
|