tensorflow-ruby 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
@@ -0,0 +1,25 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Layers
      # Flattens every non-batch dimension of its input into a single one,
      # mirroring tf.keras.layers.Flatten. Has no weights.
      class Flatten
        # @param input_shape [Array<Integer>, nil] shape of one example, without the
        #   batch dimension; required by #output_shape
        def initialize(input_shape: nil)
          @input_shape = input_shape
        end

        # @return [Array<Integer>] [-1, product of input_shape]; the -1 leaves the batch
        #   dimension to be inferred. An empty input_shape yields [-1, 1].
        def output_shape
          # Seed the product with 1 so an empty shape gives 1 instead of nil.
          flattened_dim = @input_shape.inject(1, :*)
          [-1, flattened_dim]
        end

        # @return [Integer] always 0 - flattening has no parameters
        def count_params
          0
        end

        # Reshapes +inputs+ to [-1, product of its non-batch dimensions].
        #
        # @param inputs tensor-like object whose #shape is an Array; element 0 is
        #   treated as the batch dimension and skipped
        def call(inputs)
          # A rank-1 input has no non-batch dims; the seed 1 gives [-1, 1] instead of [-1, nil].
          flattened_dim = inputs.shape[1..-1].inject(1, :*)
          Tensorflow.reshape(inputs, [-1, flattened_dim])
        end
      end
    end
  end
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Losses
      # Cross-entropy between integer class labels and predicted class probabilities,
      # averaged over every element of the per-example loss.
      class SparseCategoricalCrossentropy
        # @param target integer class labels (cast to :int64 before use)
        # @param output predicted probabilities, not logits
        # @return mean loss: reduce_sum(loss) / size(loss)
        def call(target, output)
          # The "with logits" op applies a softmax; softmax(log p) == p when each row of
          # +output+ already sums to 1 - presumably why log-probabilities are passed here.
          # NOTE(review): +output+ is not clipped first, so a probability of exactly 0 gives -inf.
          log_probs = Math.log(output)
          labels = Tensorflow.cast(target, :int64)
          per_element_loss, _backprop =
            RawOps.sparse_softmax_cross_entropy_with_logits(features: log_probs, labels: labels)

          loss_sum = Math.reduce_sum(per_element_loss)
          element_count = Tensorflow.cast(RawOps.size(input: per_element_loss), :float)
          loss_sum / element_count
        end
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Metrics
      # Streaming mean: accumulates a running total and element count in two
      # weights and reports total / count (0 when nothing has been seen).
      class Mean
        # @param name [String, nil] accepted for Keras API compatibility; currently unused
        # @param dtype [Symbol] dtype of both accumulators and of the cast inputs
        def initialize(name: nil, dtype: :float)
          @dtype = dtype
          build_weights
        end

        # Alias for #update_state, so the metric can be invoked like a Keras callable.
        def call(*args)
          update_state(*args)
        end

        # Adds the sum and element count of +values+ to the accumulators.
        #
        # @param values tensor-like input; cast to the metric dtype first
        def update_state(values)
          # Previously this cast an undefined local (`input`, always nil) instead of +values+.
          input = Tensorflow.cast(values, @dtype)
          @total.assign_add(Math.reduce_sum(input))
          @count.assign_add(Tensorflow.cast(RawOps.size(input), @dtype))
        end

        # @return total / count, or 0 when count is 0 (div_no_nan)
        def result
          # Cast to the metric dtype so total and count agree for non-:float metrics.
          RawOps.div_no_nan(@total, Tensorflow.cast(@count, @dtype))
        end

        # Discards accumulated state by replacing both accumulators with freshly
        # zero-initialized weights (the same construction #initialize uses).
        def reset_states
          build_weights
        end

        private

        # Creates zero-initialized "total" and "count" weights of the metric dtype.
        def build_weights
          @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype)
          @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype)
        end
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Metrics
      # Accuracy for integer labels against per-class scores: the mean of
      # "arg-max prediction == label", accumulated through Mean.
      class SparseCategoricalAccuracy < Mean
        # @param y_true integer class labels
        # @param y_pred per-class scores; the predicted class is arg_max over axis -1
        def update_state(y_true, y_pred)
          predicted_labels = RawOps.arg_max(y_pred, -1)

          # TODO: casting the prediction to y_true's dtype is still disabled:
          #   y_pred = Tensorflow.cast(y_pred, y_true.dtype) if y_pred.dtype != y_true.dtype
          # NOTE(review): Math.equal may therefore see mismatched dtypes - confirm.

          matches = Math.equal(y_true, predicted_labels)
          super(matches)
        end
      end
    end
  end
end
|
@@ -0,0 +1,56 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Models
      # A linear stack of layers. Only layer bookkeeping and #summary work so far;
      # compile, fit and evaluate raise "Not implemented".
      class Sequential
        # @param layers [Array] layers appended in order via #add
        def initialize(layers = [])
          @layers = []
          layers.each { |layer| add(layer) }
        end

        # Appends +layer+ to the stack.
        # @return [Array] the internal layer list
        def add(layer)
          @layers.push(layer)
        end

        # @raise [RuntimeError] always - not implemented yet
        def compile(optimizer: nil, loss: nil, metrics: nil)
          raise "Not implemented"
        end

        # @raise [RuntimeError] always - not implemented yet
        def fit(x, y, epochs: nil)
          raise "Not implemented"
        end

        # @raise [RuntimeError] always - not implemented yet
        def evaluate(x, y)
          raise "Not implemented"
        end

        # Prints a Keras-style table of layer types, output shapes and parameter
        # counts to stdout.
        #
        # Layers responding to #build are first built with the previous layer's
        # output shape (nil for the first layer). Every layer must respond to
        # #output_shape and #count_params.
        #
        # @return [nil]
        def summary
          divider = "_________________________________________________________________\n"
          rule = "=================================================================\n"

          previous_shape = nil
          @layers.each do |layer|
            layer.build(previous_shape) if layer.respond_to?(:build)
            previous_shape = layer.output_shape
          end

          param_total = @layers.sum(&:count_params)

          rows = @layers.map do |layer|
            type_name = layer.class.name.split("::").last
            # The batch dimension is shown as nil, like Keras prints None.
            shape_text = ([nil] + layer.output_shape[1..-1]).inspect
            format("%-28s %-25s %-10s\n", type_name, shape_text, layer.count_params)
          end

          report = [
            "Model: \"sequential\"\n",
            divider,
            "Layer (type) Output Shape Param # \n",
            rule,
            rows.join(divider),
            rule,
            "Total params: #{param_total}\n",
            "Trainable params: #{param_total}\n",
            "Non-trainable params: 0\n",
            divider
          ].join

          puts report
        end
      end
    end
  end
end
|
@@ -0,0 +1,22 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Preprocessing
      # Image loading helpers modelled on tf.keras.preprocessing.image, backed by MiniMagick.
      module Image
        class << self
          # Opens the image at +path+, optionally resizing it to +target_size+.
          #
          # @param path [String] image file path
          # @param target_size [Array(Integer, Integer), nil] (height, width) - the order
          #   Keras uses; nil keeps the original size
          # @return [MiniMagick::Image]
          def load_img(path, target_size: nil)
            img = MiniMagick::Image.open(path)
            if target_size
              height, width = target_size.map(&:to_i)
              # ImageMagick geometry is WIDTHxHEIGHT, so Keras' (height, width) must be
              # swapped; "!" forces the exact size, "-filter point" samples nearest-neighbour.
              # TODO make resize consistent with Python
              img.resize "#{width}x#{height}!", "-filter", "point"
            end
            img
          end

          # Converts a MiniMagick image into a Numo::SFloat array of its pixel values.
          # NOTE(review): assumes get_pixels returns rows x columns x channels - confirm.
          def img_to_array(img)
            Numo::SFloat.cast(img.get_pixels)
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,83 @@
|
|
1
|
+
module Tensorflow
  module Keras
    module Utils
      class << self
        # Creates a Variable and assigns it an initial value.
        #
        # @param name [String, nil] variable name
        # @param shape [Array<Integer>] variable shape
        # @param initializer [String] "zeros" or "glorot_uniform"
        # @param dtype [Symbol] variable dtype; also the dtype of the glorot_uniform sample
        # @return [Variable]
        # @raise [Error] for an unknown initializer
        def add_weight(name: nil, shape: [], initializer: nil, dtype: :float)
          variable = Variable.new(shape: shape, name: name, dtype: dtype)
          initial_value =
            case initializer
            when "zeros"
              # NOTE(review): fills with a Ruby Float regardless of +dtype+ - confirm the
              # conversion is right for non-float dtypes.
              Tensorflow.fill(shape, 0.0)
            when "glorot_uniform"
              glorot_uniform(shape, dtype)
            else
              raise Error, "Unknown initializer: #{initializer}"
            end

          variable.value = initial_value
          variable
        end

        # Downloads +origin+ to ~/.keras/<cache_subdir>/<fname> unless that file already
        # exists, optionally verifies +file_hash+, and returns the local path.
        #
        # @param fname [String] file name inside the cache directory
        # @param origin [String] http or https URL to download from
        # @param file_hash [String, nil] expected hex digest; 32 characters => MD5, otherwise SHA-256
        # @param cache_subdir [String] subdirectory of ~/.keras
        # @return [String] path of the cached file
        # @raise [Error] on a non-2xx HTTP response or a hash mismatch
        def get_file(fname, origin, file_hash: nil, cache_subdir: "datasets")
          # TODO handle this better
          raise "No HOME" unless ENV["HOME"]
          dest = "#{ENV["HOME"]}/.keras/#{cache_subdir}/#{fname}"
          FileUtils.mkdir_p(File.dirname(dest))

          return dest if File.exist?(dest)

          # Stage the download next to +dest+: the final move is then a same-filesystem
          # rename, and a failed download never sits under the cached name.
          temp_path = "#{dest}.#{Process.pid}.#{Time.now.to_f}.part"
          digest = file_hash&.size == 32 ? Digest::MD5.new : Digest::SHA2.new

          begin
            download(URI(origin), origin, temp_path, digest)
            raise Error, "Bad hash" if file_hash && digest.hexdigest != file_hash

            FileUtils.mv(temp_path, dest)
          ensure
            # No-op after a successful move; removes the partial file on any failure.
            FileUtils.rm_f(temp_path)
          end

          dest
        end

        private

        # Samples a Glorot (Xavier) uniform tensor: U(-limit, limit) with
        # limit = sqrt(3 * scale) and scale = 1 / max(1, (fan_in + fan_out) / 2).
        def glorot_uniform(shape, dtype)
          fan_in, fan_out = compute_fans(shape)
          scale = 1.0 / [1.0, (fan_in + fan_out) / 2.0].max
          limit = ::Math.sqrt(3.0 * scale)
          minval = -limit
          maxval = limit

          rnd = RawOps.random_uniform(shape: shape, dtype: dtype)
          Math.add(rnd * (maxval - minval), minval)
        end

        # Returns [fan_in, fan_out] for a weight shape, following Keras: dense kernels
        # are (fan_in, fan_out); conv kernels (..., in_channels, out_channels) multiply
        # both by the receptive-field size.
        def compute_fans(shape)
          case shape.size
          when 0 then [1, 1]
          when 1 then [shape[0], shape[0]]
          when 2 then [shape[0], shape[1]]
          else
            receptive_field = shape[0...-2].inject(1, :*)
            [shape[-2] * receptive_field, shape[-1] * receptive_field]
          end
        end

        # Streams +uri+ into +path+, feeding every chunk to +digest+ and printing a dot
        # every 50 chunks.
        # @raise [Error] when the server does not answer with a 2xx status
        def download(uri, origin, path, digest)
          # Net::HTTP automatically adds Accept-Encoding for compression
          # of response bodies and automatically decompresses gzip
          # and deflate responses unless a Range header was sent.
          # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            puts "Downloading data from #{origin}"
            chunks = 0
            File.open(path, "wb") do |f|
              http.request(request) do |response|
                # Without this check an error page (404, redirect, ...) would be cached as the file.
                unless response.is_a?(Net::HTTPSuccess)
                  raise Error, "Failed to download #{origin}: #{response.code} #{response.message}"
                end

                response.read_body do |chunk|
                  f.write(chunk)
                  digest.update(chunk)

                  # print progress
                  putc "." if (chunks % 50).zero?
                  chunks += 1
                end
              end
              puts # newline
            end
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,57 @@
|
|
1
|
+
require 'set'
|
2
|
+
|
3
|
+
module Tensorflow
  # Bookkeeping for graph name scopes: a stack of active scope names plus the set
  # of every name handed out so far.
  #
  # Uniqueness is case-insensitive (names are stored downcased); a collision gets
  # "_1", "_2", ... appended to the requested name.
  class NameScope
    attr_reader :stack, :names

    def initialize
      @stack = []
      @names = Set.new
    end

    # Pushes a uniquified +base_name+ onto the scope stack, yields the joined scope
    # path if a block is given, and always pops the scope afterwards.
    #
    # NOTE(review): uniqueness is checked on the bare +base_name+, not on the full
    # "outer/inner" path, so the same inner name under two different outer scopes
    # gets a suffix the second time - confirm that is intended.
    #
    # @return the block's value, or nil without a block
    def name_scope(base_name)
      scope = unique_name(base_name)
      stack.push(scope)
      begin
        yield current_scope if block_given?
      ensure
        stack.pop
      end
    end

    # @return [String, nil] the active scopes joined with "/", or nil when the stack
    #   is empty or its innermost entry is nil
    def current_scope
      stack.join("/") unless stack.last.nil?
    end

    # Prefixes +name+ with the current scope (when one is active) and uniquifies it.
    def scoped_name(name)
      candidate = stack.last.nil? ? name : "#{current_scope}/#{name}"
      unique_name(candidate)
    end

    # Returns +name+, or +name+ plus the smallest "_N" suffix not yet taken, and
    # reserves the result. Returns nil for a nil (or false) name.
    def unique_name(name)
      return nil unless name

      candidate = name
      suffix = 0
      # Set#add? is nil when the key is already present, so this loops until a free name is reserved.
      until names.add?(candidate.downcase)
        suffix += 1
        candidate = "#{name}_#{suffix}"
      end
      candidate
    end
  end
end
|
@@ -0,0 +1,49 @@
|
|
1
|
+
module Tensorflow
  # Thin wrapper over a native TF_OpDefinitionBuilder used to declare and register
  # a custom op. The mutators return +self+ so calls can be chained.
  class OpDefBuilder
    class << self
      # Memoized handle to the TF_ShapeInferenceContextSetUnknownShape function from the
      # first library FFI loaded; intended for #shape_inference.
      def unknown_shape_inference_func
        @unknown_shape_func ||= FFI.ffi_libraries.first.find_function('TF_ShapeInferenceContextSetUnknownShape')
      end

      # Finalizer that deletes the native builder. It is built at class level so the
      # proc closes over the pointer only, never the Ruby instance.
      def finalize(pointer)
        proc { FFI::TF_DeleteOpDefinitionBuilder(pointer) }
      end
    end

    # @param name [String] op name handed to TF_NewOpDefinitionBuilder
    def initialize(name)
      @pointer = FFI.TF_NewOpDefinitionBuilder(name)
      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # Native builder handle; presumably what FFI unwraps when +self+ is passed as a
    # pointer argument below.
    def to_ptr
      @pointer
    end

    # Adds an attribute spec string. @return [OpDefBuilder] self
    def attr(spec)
      tap { |builder| FFI.TF_OpDefinitionBuilderAddAttr(builder, spec) }
    end

    # Adds an input spec string. @return [OpDefBuilder] self
    def input(spec)
      tap { |builder| FFI.TF_OpDefinitionBuilderAddInput(builder, spec) }
    end

    # Adds an output spec string. @return [OpDefBuilder] self
    def output(spec)
      tap { |builder| FFI.TF_OpDefinitionBuilderAddOutput(builder, spec) }
    end

    # Sets the shape inference function. @return [OpDefBuilder] self
    def shape_inference(func)
      tap { |builder| FFI.TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, func) }
    end

    # Registers the op, raising through Status.check on failure. The finalizer is
    # removed first - presumably because TF_RegisterOpDefinition takes ownership of
    # the builder; confirm against the TensorFlow C API docs.
    def register
      ObjectSpace.undefine_finalizer(self)
      Status.check { |status| FFI.TF_RegisterOpDefinition(self, status) }
    end
  end
end
|
@@ -0,0 +1,13 @@
|
|
1
|
+
module Tensorflow
  # WAV encoding/decoding helpers; thin keyword-argument wrappers over RawOps.
  module Audio
    class << self
      # Decodes WAV +contents+.
      #
      # @param contents WAV-encoded data
      # @param desired_channels [Integer] forwarded as-is; -1 by default
      # @param desired_samples [Integer] forwarded as-is; -1 by default
      def decode_wav(contents, desired_channels: -1, desired_samples: -1)
        RawOps.decode_wav(
          contents: contents,
          desired_channels: desired_channels,
          desired_samples: desired_samples
        )
      end

      # Encodes +audio+ as WAV at +sample_rate+.
      def encode_wav(audio, sample_rate)
        RawOps.encode_wav(
          audio: audio,
          sample_rate: sample_rate
        )
      end
    end
  end
end
|
@@ -0,0 +1,29 @@
|
|
1
|
+
module Tensorflow
  # Element-wise bitwise operations, delegating to the generated RawOps.
  # All wrappers pass their tensors to RawOps positionally.
  module Bitwise
    class << self
      # Element-wise bitwise AND of +x+ and +y+.
      def bitwise_and(x, y)
        RawOps.bitwise_and(x, y)
      end

      # Element-wise bitwise OR of +x+ and +y+.
      def bitwise_or(x, y)
        RawOps.bitwise_or(x, y)
      end

      # Element-wise bitwise XOR of +x+ and +y+.
      def bitwise_xor(x, y)
        RawOps.bitwise_xor(x, y)
      end

      # Element-wise bitwise NOT of +x+.
      def invert(x)
        # Positional, like the sibling wrappers; this previously passed a keyword hash (x: x).
        RawOps.invert(x)
      end

      # Shifts each element of +x+ left by the matching element of +y+.
      def left_shift(x, y)
        RawOps.left_shift(x, y)
      end

      # Shifts each element of +x+ right by the matching element of +y+.
      def right_shift(x, y)
        RawOps.right_shift(x, y)
      end
    end
  end
end
|
@@ -0,0 +1,21 @@
|
|
1
|
+
module Tensorflow
  module Ops
    # Multiplies +vector+ into +matrix+ after appending a trailing axis to +vector+
    # (expand_dims at -1), so broadcasting scales each row of +matrix+.
    # NOTE(review): assumes +vector+ has one entry per row of +matrix+ - confirm at call sites.
    def self.broadcast_mul(vector, matrix)
      vector = Tensorflow.expand_dims(vector, -1)
      vector * matrix
    end

    # Gradient of SparseSoftmaxCrossEntropyWithLogits. It reuses the op's second output
    # (operation[1], presumably the backprop tensor from the fused kernel), wrapped in
    # prevent_gradient so that differentiating it again raises +message+. That output is
    # then scaled by the incoming gradient via broadcast_mul.
    Graph::Gradients.register('SparseSoftmaxCrossEntropyWithLogits') do |gradient, outputs, _inputs|
      message = <<~EOS
        Currently there is no way to take the second derivative of sparse_softmax_cross_entropy_with_logits due to the fused
        implementation's interaction with tf.gradients()
      EOS

      operation = outputs[0].operation
      sparse_softmax_grad_without_gradient = Tensorflow.prevent_gradient(operation[1], message: message)
      op = Ops.broadcast_mul(gradient, sparse_softmax_grad_without_gradient)
      op.outputs
    end
  end
end
|