tensorflow-ruby 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
@@ -0,0 +1,32 @@
|
|
1
|
+
module Tensorflow
  # Helpers shared by tensor-like classes. An including class must provide
  # `num_dims`, `dim(index)`, `dtype` and `value`.
  module TensorMixin
    # Dimension sizes as an Array, one entry per `dim(index)` for
    # index in 0...num_dims. Computed on first call and cached in @shape.
    def shape
      return @shape if @shape

      status = Status.new
      @shape = Array.new(num_dims) do |index|
        size = dim(index)
        status.check
        size
      end
    end

    # Converts the value to a Numo array based on dtype.
    #
    # Returns nil when dtype is nil, and the dtype symbol itself for :variant and
    # :string. Otherwise it casts `value` with the class registered in
    # TensorData::DTYPE_TO_NUMO_TYPE_MAP.
    #
    # @raise [RuntimeError] when dtype has no entry in that map
    def numo
      kind = dtype
      return nil if kind.nil?
      return kind if %i[variant string].include?(kind)

      numo_class = TensorData::DTYPE_TO_NUMO_TYPE_MAP[kind]
      raise "Unknown type: #{kind}" unless numo_class

      numo_class.cast(value)
    end
  end
end
|
@@ -0,0 +1,93 @@
|
|
1
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
2
|
+
# source: tensorflow/core/util/event.proto
|
3
|
+
|
4
|
+
require 'google/protobuf'
|
5
|
+
|
6
|
+
require 'tensorflow/core/framework/summary_pb'
|
7
|
+
# Registers the messages and enums of tensorflow/core/util/event.proto in the
# shared generated descriptor pool. Field names, types and numbers mirror the
# .proto definition and must not be changed by hand.
#
# NOTE(review): the gem file list ships this proto twice
# (lib/tensorflow/core/util/event_pb.rb and
# lib/tensorflow/tensorflow/core/util/event_pb.rb). Registering the same file in
# the generated pool twice presumably fails, so confirm only one copy is ever
# required.
Google::Protobuf::DescriptorPool.generated_pool.build do
  add_file("tensorflow/core/util/event.proto", syntax: :proto3) do
    # One event record: wall time and step, plus at most one `what` payload.
    add_message("tensorflow.Event") do
      optional(:wall_time, :double, 1)
      optional(:step, :int64, 2)
      oneof(:what) do
        optional(:file_version, :string, 3)
        optional(:graph_def, :bytes, 4)
        optional(:summary, :message, 5, "tensorflow.Summary")
        optional(:log_message, :message, 6, "tensorflow.LogMessage")
        optional(:session_log, :message, 7, "tensorflow.SessionLog")
        optional(:tagged_run_metadata, :message, 8, "tensorflow.TaggedRunMetadata")
        optional(:meta_graph_def, :bytes, 9)
      end
    end

    add_message("tensorflow.LogMessage") do
      optional(:level, :enum, 1, "tensorflow.LogMessage.Level")
      optional(:message, :string, 2)
    end

    add_enum("tensorflow.LogMessage.Level") do
      value(:UNKNOWN, 0)
      value(:DEBUGGING, 10)
      value(:INFO, 20)
      value(:WARN, 30)
      value(:ERROR, 40)
      value(:FATAL, 50)
    end

    add_message("tensorflow.SessionLog") do
      optional(:status, :enum, 1, "tensorflow.SessionLog.SessionStatus")
      optional(:checkpoint_path, :string, 2)
      optional(:msg, :string, 3)
    end

    add_enum("tensorflow.SessionLog.SessionStatus") do
      value(:STATUS_UNSPECIFIED, 0)
      value(:START, 1)
      value(:STOP, 2)
      value(:CHECKPOINT, 3)
    end

    add_message("tensorflow.TaggedRunMetadata") do
      optional(:tag, :string, 1)
      optional(:run_metadata, :bytes, 2)
    end

    add_message("tensorflow.WatchdogConfig") do
      optional(:timeout_ms, :int64, 1)
    end

    add_message("tensorflow.RequestedExitCode") do
      optional(:exit_code, :int32, 1)
    end

    add_message("tensorflow.WorkerHeartbeatRequest") do
      optional(:shutdown_mode, :enum, 1, "tensorflow.WorkerShutdownMode")
      optional(:watchdog_config, :message, 2, "tensorflow.WatchdogConfig")
      optional(:exit_code, :message, 3, "tensorflow.RequestedExitCode")
    end

    add_message("tensorflow.WorkerHeartbeatResponse") do
      optional(:health_status, :enum, 1, "tensorflow.WorkerHealth")
      repeated(:worker_log, :message, 2, "tensorflow.Event")
      optional(:hostname, :string, 3)
    end

    add_enum("tensorflow.WorkerHealth") do
      value(:OK, 0)
      value(:RECEIVED_SHUTDOWN_SIGNAL, 1)
      value(:INTERNAL_ERROR, 2)
      value(:SHUTTING_DOWN, 3)
    end

    add_enum("tensorflow.WorkerShutdownMode") do
      value(:DEFAULT, 0)
      value(:NOT_CONFIGURED, 1)
      value(:WAIT_FOR_COORDINATOR, 2)
      value(:SHUTDOWN_AFTER_TIMEOUT, 3)
    end
  end
end
|
79
|
+
|
80
|
+
# Exposes the event.proto descriptors registered above as Ruby constants:
# message classes via #msgclass and enum modules via #enummodule.
module Tensorflow
  # Local to this module body; it does not leak into Tensorflow's namespace.
  pool = ::Google::Protobuf::DescriptorPool.generated_pool

  Event = pool.lookup("tensorflow.Event").msgclass
  LogMessage = pool.lookup("tensorflow.LogMessage").msgclass
  LogMessage::Level = pool.lookup("tensorflow.LogMessage.Level").enummodule
  SessionLog = pool.lookup("tensorflow.SessionLog").msgclass
  SessionLog::SessionStatus = pool.lookup("tensorflow.SessionLog.SessionStatus").enummodule
  TaggedRunMetadata = pool.lookup("tensorflow.TaggedRunMetadata").msgclass
  WatchdogConfig = pool.lookup("tensorflow.WatchdogConfig").msgclass
  RequestedExitCode = pool.lookup("tensorflow.RequestedExitCode").msgclass
  WorkerHeartbeatRequest = pool.lookup("tensorflow.WorkerHeartbeatRequest").msgclass
  WorkerHeartbeatResponse = pool.lookup("tensorflow.WorkerHeartbeatResponse").msgclass
  WorkerHealth = pool.lookup("tensorflow.WorkerHealth").enummodule
  WorkerShutdownMode = pool.lookup("tensorflow.WorkerShutdownMode").enummodule
end
|
@@ -0,0 +1,33 @@
|
|
1
|
+
# Based on code from https://github.com/jedld/tensor_stream

module Tensorflow
  module Train
    # Gradient descent optimizer. Each (gradient, variable) pair is handed to
    # RawOps.resource_apply_gradient_descent with the current learning rate.
    class GradientDescentOptimizer < Optimizer
      # Learning rate as configured: a value, or a Proc that #prepare calls.
      attr_accessor :learning_rate

      # @param learning_rate [Object, Proc] step size, or a Proc returning it
      # @param use_locking [Boolean] forwarded to Optimizer#initialize
      # @param name [String] optimizer name
      def initialize(learning_rate, use_locking: false, name: "GradientDescent")
        @learning_rate_tensor = nil
        @learning_rate = learning_rate
        super(name: name, use_locking: use_locking)
      end

      protected

      # Builds the learning-rate constant (named "learning_rate") used by
      # #apply_dense. A Proc learning rate is evaluated here.
      def prepare
        rate = call_if_callable(@learning_rate)
        @learning_rate_tensor = Tensorflow.constant(rate, name: "learning_rate")
      end

      # Emits the update op for one variable. The learning rate is cast to the
      # gradient's dtype when the two differ.
      def apply_dense(grad, var)
        grad_dtype = grad.output_types.first
        rate = @learning_rate_tensor
        rate = Tensorflow.cast(rate, grad_dtype) unless rate.output_types.first == grad_dtype
        RawOps.resource_apply_gradient_descent(var, rate, grad)
      end
    end
  end
end
|
@@ -0,0 +1,158 @@
|
|
1
|
+
# Based on code from https://github.com/jedld/tensor_stream
module Tensorflow
  module Train
    # Base class for optimizers.
    #
    # #minimize computes gradients of a loss and hands each (gradient, variable)
    # pair to #apply_dense, which subclasses must implement. #prepare is called
    # once per #apply_gradients before any update op is built.
    class Optimizer
      attr_reader :name

      # @param name [String] optimizer name (required)
      # @param use_locking [Boolean] stored as @use_locking; not read by this base class
      # @raise [Error::InvalidArgumentError] when +name+ is nil or false
      def initialize(name: nil, use_locking: false)
        raise(Error::InvalidArgumentError, "Must specify the optimizer name") unless name

        @name = name
        @use_locking = use_locking
        @slots = {}
        @non_slots = {}
      end

      # Returns ExecutionContext.current. It is used as the graph for
      # collection lookups and gradient construction.
      def graph
        ExecutionContext.current
      end

      # Computes the gradients of +loss+ and applies them.
      #
      # @param loss value to minimize
      # @param var_list [Array, nil] variables to update; defaults to the
      #   TRAINABLE_VARIABLES collection of #graph
      # @param grad_loss optional initial gradient, forwarded as +grad_ys+
      # @param global_step [Variable, nil] when given, incremented by 1 after the updates
      # @param name [String, nil] name for the grouped update op
      # @raise [Error::InvalidArgumentError] when no variable received a gradient
      def minimize(loss, var_list: nil, grad_loss: nil, global_step: nil, name: nil)
        grads_and_vars = compute_gradients(loss, var_list: var_list, grad_loss: grad_loss)
        # compute_gradients returns one pair per variable, so the list itself is
        # never empty. The meaningful check is whether any gradient exists.
        if grads_and_vars.none? { |grad, _var| grad }
          raise(Error::InvalidArgumentError, "No gradients provided for any variable, check your graph for ops that do not support gradients")
        end

        apply_gradients(grads_and_vars, global_step: global_step, name: name)
      end

      # Applies (gradient, variable) pairs via #apply_dense and groups the
      # resulting ops with Control.group.
      #
      # Pairs whose gradient is nil are skipped rather than passed to
      # #apply_dense. These are presumably variables the loss does not depend on.
      #
      # @return the result of #finish, or the result of the global_step
      #   increment block when +global_step+ is given
      def apply_gradients(grads_and_vars, global_step: nil, name: nil)
        # Slot creation (#create_slots) and per-variable name scopes from the
        # tensor_stream port are not wired up here.
        prepare
        apply_ops = grads_and_vars.select { |grad, _var| grad }
                                  .map { |grad, var| apply_dense(grad, var) }

        return finish(apply_ops, name) if global_step.nil?

        # Increment global_step under a control dependency on the grouped updates.
        global_step.handle.graph.control_dependencies([finish(apply_ops, "update")]) do
          global_step.assign_add(Tensorflow.constant(1, dtype: global_step.dtype))
        end
      end

      # @return [Array<Array>] [gradient, variable] pairs, one per variable
      # @raise [Error::InvalidArgumentError] when there are no variables to train
      def compute_gradients(loss, var_list: nil, grad_loss: nil)
        trainable_vars = var_list || graph.get_collection_ref(Tensorflow::Graph::GraphKeys::TRAINABLE_VARIABLES)

        if trainable_vars.nil? || trainable_vars.empty?
          raise(Error::InvalidArgumentError, 'There are no variables to train for the loss function')
        end

        grads = Graph::Gradients.new(graph).gradients(loss, trainable_vars, grad_ys: grad_loss)
        grads.zip(trainable_vars)
      end

      # Returns the slot named +name+ for +var+, or nil if none was created.
      def get_slot(var, name)
        named_slots = @slots[name]
        return nil if named_slots.nil?

        named_slots[var_key(var)]
      end

      # Sorted names of all slot dictionaries created so far.
      def get_slot_names
        @slots.keys.sort
      end

      protected

      # Groups the update ops into a single op named +name_scope+.
      def finish(update_ops, name_scope)
        Control.group(update_ops, name: name_scope)
      end

      # Hook for subclasses that need per-variable slots; no-op here.
      def create_slots(var_list)
        # no implementation
      end

      # Hook called once per #apply_gradients before any #apply_dense; no-op here.
      def prepare
        # no implementation
      end

      # Must be overridden to emit the update op for one variable.
      #
      # @raise [Error::UnimplementedError] always, in this base class
      def apply_dense(_grad, _var)
        raise(Error::UnimplementedError, "Not implemented")
      end

      ##
      # Find or create a slot initialized with 0.0.
      #
      # Args:
      #   var: Variable - A Variable object
      #   slot_name: string - Name for the slot
      #   op_name: string - Name to use when scoping the Variable that needs to be created
      #
      # NOTE(review): create_zeros_slot is not defined in this class. A subclass
      # or mixin presumably must provide it before this method is usable.
      def zeros_slot(var, slot_name, op_name)
        named_slots = slot_dict(slot_name)
        key = var_key(var)
        named_slots[key] = create_zeros_slot(var, op_name) unless named_slots.key?(key)
        named_slots[key]
      end

      ##
      # Returns a dict for caching slots created under the given name.
      #
      # Args:
      #   slot_name string Name for the slot
      #
      # Returns: A dict that maps primary 'Variable' objects to the slot created
      def slot_dict(slot_name)
        @slots[slot_name] ||= {}
      end

      # Key identifying +var+ in slot dictionaries: [graph, op name].
      def var_key(var)
        [var.op.graph, var.op.name]
      end

      # Returns the non-slot variable registered under [name, graph], or nil.
      def get_non_slot_variable(name, graph: nil)
        @non_slots[[name, graph]]
      end

      # Calls +param+ when it is a Proc; otherwise returns it unchanged.
      def call_if_callable(param)
        param.is_a?(Proc) ? param.call : param
      end

      # Returns the non-trainable variable cached under [name, colocate_with.graph].
      # On first use it is created from +initial_value+.
      def create_non_slot_variable(initial_value, name, colocate_with)
        key = [name, colocate_with.graph]
        @non_slots[key] ||= Tensorflow::Variable.new(initial_value, name: name, trainable: false)
      end

      ##
      # Find or create a slot for a variable, using an Initializer.
      #
      # NOTE(review): create_slot_with_initializer is not defined in this class.
      # A subclass or mixin presumably must provide it.
      def get_or_make_slot_with_initializer(var, initializer, shape, dtype, slot_name, op_name)
        named_slots = slot_dict(slot_name)
        key = var_key(var)
        unless named_slots.key?(key)
          named_slots[key] = create_slot_with_initializer(var, initializer, shape, dtype, op_name)
        end
        named_slots[key]
      end
    end
  end
end
|
@@ -0,0 +1,127 @@
|
|
1
|
+
module Tensorflow
  # A resource variable: a VarHandleOp plus read/assign ops on that handle.
  #
  # Each instance is added to the GLOBAL_VARIABLES collection of the current
  # execution context. It is also added to TRAINABLE_VARIABLES when
  # +trainable+ is true.
  class Variable
    include Operators

    attr_reader :handle, :dtype, :name

    # The op returned by the most recent #value= call (nil before any assignment).
    attr_reader :initializer

    # @param initial_value [nil, Graph::Operation, Tensor, Object] starting value.
    #   Any other object is converted with Tensor.from_value.
    # @param dtype [Symbol, nil] element type. When nil it is taken from
    #   +initial_value+; it is ignored when +initial_value+ is a Tensor.
    # @param shape [Array, nil] declared shape. When nil it is taken from
    #   +initial_value+, or [] when there is no initial value. The shape of a
    #   value converted with Tensor.from_value always wins.
    # @param shared_name [String, nil] passed to VarHandleOp; defaults to the unique name
    # @param name [String] base name, made unique via ExecutionContext.current.unique_name
    # @param trainable [Boolean] whether to register in TRAINABLE_VARIABLES
    def initialize(initial_value = nil, dtype: nil, shape: nil, shared_name: nil, name: 'Variable', trainable: true)
      initial_value = case initial_value
                      when NilClass
                        @dtype = dtype
                        # Keep an explicitly requested shape; default to scalar.
                        shape ||= []
                        initial_value
                      when Graph::Operation
                        @dtype = dtype || initial_value.dtype
                        shape ||= initial_value.output_shapes.first
                        initial_value
                      when Tensor
                        @dtype = initial_value.dtype
                        shape ||= initial_value.shape
                        initial_value
                      else
                        tensor = Tensor.from_value(initial_value, dtype: dtype)
                        @dtype = tensor.dtype
                        shape = tensor.shape
                        tensor
                      end

      name = name&.to_s
      shared_name = shared_name&.to_s
      unique_name = ExecutionContext.current.unique_name(name || shared_name)
      shared_name ||= unique_name
      @name = unique_name

      collections = [Graph::GraphKeys::GLOBAL_VARIABLES]
      collections << Graph::GraphKeys::TRAINABLE_VARIABLES if trainable
      ExecutionContext.current.add_to_collections(collections, self)

      @handle = RawOps.var_handle_op(dtype: @dtype, shape: shape, shared_name: shared_name, name: unique_name)
      self.value = initial_value if initial_value
    end

    # Memoized ReadVariableOp on the handle; cleared by #assign_add and #assign_sub.
    def value_handle
      @value_handle ||= RawOps.read_variable_op(self.handle, dtype: @dtype)
    end

    # Returns the value of an Eager::TensorHandle read, or the Graph::Operation
    # read itself. Returns nil for any other handle type.
    def value
      case value_handle
      when Eager::TensorHandle
        value_handle.value
      when Graph::Operation
        value_handle
      end
    end

    # Assigns +value+ via AssignVariableOp and records that op as #initializer.
    #
    # NOTE(review): unlike assign_add/assign_sub this does not clear the cached
    # value_handle. Confirm that a stale eager read cannot be observed afterwards.
    def value=(value)
      @initializer = RawOps.assign_variable_op(self.handle, value, dtype: @dtype)
    end

    # Op reporting whether the variable has been initialized.
    def initialized?
      RawOps.var_is_initialized_op(self.handle)
    end

    # These methods match the operation api to enable gradients and sessions
    def consumers
      self.handle.consumers
    end

    # This enables executing variables to get the values in a session
    def outputs
      [Graph::OperationOutput.from_index(self.value_handle, 0)]
    end

    def to_ptr
      self.handle.to_ptr
    end

    def shape
      self.value_handle.shape
    end

    # @raise [Error::UnavailableError] when running in graph mode
    def tensor
      raise(Error::UnavailableError, "Only supported in eager execution mode") if Tensorflow.execution_mode == Tensorflow::GRAPH_MODE
      self.value_handle.tensor
    end

    def rank
      self.shape.size
    end

    def reshape(shape)
      RawOps.reshape(self, shape)
    end

    # Adds +value+ to the variable in place.
    #
    # @param value amount to add. It is converted with Tensor.from_value and
    #   then cast to this variable's dtype, so the op always receives a
    #   matching dtype.
    # @param dtype [Symbol, nil] dtype used for the Tensor.from_value conversion
    def assign_add(value, dtype: nil)
      @value_handle = nil # the cached read no longer reflects the update
      tensor = Tensorflow.cast(Tensor.from_value(value, dtype: dtype), self.dtype)
      RawOps.assign_add_variable_op(self.handle, tensor, dtype: tensor.dtype)
    end

    # Subtracts +value+ from the variable in place.
    #
    # @param value amount to subtract. It is converted with Tensor.from_value
    #   and then cast to this variable's dtype.
    # @param dtype [Symbol, nil] dtype used for the conversion; defaults to the
    #   variable's own dtype
    def assign_sub(value, dtype: nil)
      @value_handle = nil # the cached read no longer reflects the update
      tensor = Tensorflow.cast(Tensor.from_value(value, dtype: dtype || self.dtype), self.dtype)
      RawOps.assign_sub_variable_op(self.handle, tensor, dtype: tensor.dtype)
    end

    def to_s
      inspect
    end

    def inspect
      inspection = []
      inspection << "name: #{self.handle.name}" if self.handle.respond_to?(:name)
      inspection << "shape: #{self.value_handle.shape}"
      inspection << "dtype: #{self.value_handle.dtype}"
      "#<#{self.class} #{inspection.join(", ")}>"
    end
  end
end
|