tensorflow-ruby 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +18 -0
- data/LICENSE.txt +22 -0
- data/README.md +104 -0
- data/lib/datasets/download_manager.rb +49 -0
- data/lib/datasets/images/mnist.rb +54 -0
- data/lib/datasets/resource.rb +19 -0
- data/lib/tensorflow-ruby.rb +182 -0
- data/lib/tensorflow.rb +1 -0
- data/lib/tensorflow/batchable_type_spec.rb +4 -0
- data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
- data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
- data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
- data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
- data/lib/tensorflow/core/framework/function_pb.rb +38 -0
- data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
- data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
- data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
- data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
- data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
- data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
- data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
- data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
- data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
- data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
- data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
- data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
- data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
- data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
- data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
- data/lib/tensorflow/core/framework/types_pb.rb +62 -0
- data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
- data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
- data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
- data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
- data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
- data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
- data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
- data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
- data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
- data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
- data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
- data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
- data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
- data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
- data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
- data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
- data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
- data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
- data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
- data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
- data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
- data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
- data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
- data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
- data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
- data/lib/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
- data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
- data/lib/tensorflow/data/batch_dataset.rb +18 -0
- data/lib/tensorflow/data/dataset.rb +106 -0
- data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
- data/lib/tensorflow/data/iterator.rb +76 -0
- data/lib/tensorflow/data/map_dataset.rb +17 -0
- data/lib/tensorflow/data/repeat_dataset.rb +16 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
- data/lib/tensorflow/data/tensor_dataset.rb +19 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
- data/lib/tensorflow/data/zip_dataset.rb +24 -0
- data/lib/tensorflow/decorators.rb +53 -0
- data/lib/tensorflow/eager/context.rb +120 -0
- data/lib/tensorflow/eager/operation.rb +219 -0
- data/lib/tensorflow/eager/tensor_handle.rb +87 -0
- data/lib/tensorflow/error.rb +54 -0
- data/lib/tensorflow/execution_context.rb +62 -0
- data/lib/tensorflow/extensions/arg_def.rb +58 -0
- data/lib/tensorflow/extensions/array.rb +17 -0
- data/lib/tensorflow/extensions/boolean.rb +25 -0
- data/lib/tensorflow/extensions/narray.rb +7 -0
- data/lib/tensorflow/ffi.rb +291 -0
- data/lib/tensorflow/graph/function.rb +33 -0
- data/lib/tensorflow/graph/function_def.rb +62 -0
- data/lib/tensorflow/graph/gradients.rb +120 -0
- data/lib/tensorflow/graph/graph.rb +252 -0
- data/lib/tensorflow/graph/graph_def_options.rb +24 -0
- data/lib/tensorflow/graph/graph_keys.rb +50 -0
- data/lib/tensorflow/graph/operation.rb +176 -0
- data/lib/tensorflow/graph/operation_attr.rb +153 -0
- data/lib/tensorflow/graph/operation_description.rb +255 -0
- data/lib/tensorflow/graph/operation_output.rb +49 -0
- data/lib/tensorflow/graph/session.rb +156 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/layers/conv.rb +14 -0
- data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
- data/lib/tensorflow/keras/layers/dense.rb +68 -0
- data/lib/tensorflow/keras/layers/dropout.rb +27 -0
- data/lib/tensorflow/keras/layers/flatten.rb +25 -0
- data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
- data/lib/tensorflow/keras/metrics/mean.rb +30 -0
- data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
- data/lib/tensorflow/keras/model.rb +6 -0
- data/lib/tensorflow/keras/models/sequential.rb +56 -0
- data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
- data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
- data/lib/tensorflow/keras/utils.rb +83 -0
- data/lib/tensorflow/name_scope.rb +57 -0
- data/lib/tensorflow/op_def_builder.rb +49 -0
- data/lib/tensorflow/ops/audio.rb +13 -0
- data/lib/tensorflow/ops/bitwise.rb +29 -0
- data/lib/tensorflow/ops/control.rb +13 -0
- data/lib/tensorflow/ops/gradients.rb +21 -0
- data/lib/tensorflow/ops/image.rb +218 -0
- data/lib/tensorflow/ops/io.rb +123 -0
- data/lib/tensorflow/ops/linalg.rb +131 -0
- data/lib/tensorflow/ops/math.rb +493 -0
- data/lib/tensorflow/ops/nn.rb +286 -0
- data/lib/tensorflow/ops/operators.rb +31 -0
- data/lib/tensorflow/ops/ops.rb +102 -0
- data/lib/tensorflow/ops/random.rb +18 -0
- data/lib/tensorflow/ops/raw_ops.rb +5179 -0
- data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
- data/lib/tensorflow/printers/graph.erb +80 -0
- data/lib/tensorflow/printers/graph.rb +26 -0
- data/lib/tensorflow/printers/graph_def.erb +109 -0
- data/lib/tensorflow/printers/graph_def.rb +26 -0
- data/lib/tensorflow/python_compatiblity.rb +55 -0
- data/lib/tensorflow/resource_summary_writer.rb +78 -0
- data/lib/tensorflow/status.rb +49 -0
- data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/summary.rb +13 -0
- data/lib/tensorflow/tensor.rb +133 -0
- data/lib/tensorflow/tensor_data.rb +310 -0
- data/lib/tensorflow/tensor_mixin.rb +32 -0
- data/lib/tensorflow/tensor_spec.rb +10 -0
- data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
- data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
- data/lib/tensorflow/train/optimizer.rb +158 -0
- data/lib/tensorflow/type_spec.rb +4 -0
- data/lib/tensorflow/variable.rb +127 -0
- data/lib/tensorflow/version.rb +3 -0
- metadata +308 -0
@@ -0,0 +1,17 @@
|
|
1
|
+
# Test helper: lets a plain Ruby Array compare equal to a Numo::NArray that
# holds the same values, so that
#
#   assert_equal([1,2], Numo::NArray[1,2])
#
# works, instead of having to write
#
#   assert_equal([1,2], Numo::NArray[1,2].to_a)
class Array
  # Keep a handle on the built-in comparison. The guard prevents aliasing our
  # own override (which would recurse forever) if this file is loaded twice.
  alias_method :original_equals, :== unless method_defined?(:original_equals)

  # Compares against a Numo::NArray by converting it with #to_a and applying
  # the ordinary Array#== semantics, so [1, 2] == Numo::DFloat[1, 2] is true
  # exactly like [1, 2] == [1.0, 2.0]. Every other comparison is delegated to
  # the original Array#==.
  #
  # @param other [Object] value to compare with
  # @return [Boolean]
  def ==(other)
    # defined? keeps this usable when Numo has not been required.
    if defined?(Numo::NArray) && other.is_a?(Numo::NArray)
      original_equals(other.to_a)
    else
      original_equals(other)
    end
  end
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
# encoding: UTF-8
|
2
|
+
|
3
|
+
module Kernel
  # Converts +value+ into a strict true/false.
  #
  # nil and false map to false. Any other value is converted with #to_s and is
  # true only when the whole string is one of t, true, 1, yes or y (case
  # insensitive); everything else is false.
  #
  # Rails converts true/false to 't' and 'f' in this case because it does not
  # have data dictionary information for these fields and doesn't seem to be
  # able to figure it out from the query results.
  #
  # @param value [Object]
  # @return [Boolean]
  def Boolean(value)
    return false unless value

    # \A and \Z anchor to the whole string. ^/$ would match any single line,
    # so "no\ntrue" was previously accepted as true. \Z still tolerates one
    # trailing newline, e.g. "true\n".
    value.to_s.match?(/\A(?:t|true|1|yes|y)\Z/i)
  end
end
|
18
|
+
|
19
|
+
class TrueClass
  # Integer form of +true+: always 1.
  #
  # @return [Integer]
  def to_i
    1
  end
end
|
22
|
+
|
23
|
+
class FalseClass
  # Integer form of +false+: always 0.
  #
  # @return [Integer]
  def to_i
    0
  end
end
|
@@ -0,0 +1,291 @@
|
|
1
|
+
module Tensorflow
  # Raw Ruby-FFI bindings to the TensorFlow C library (libtensorflow).
  #
  # Each section links the C header its signatures were taken from. These are
  # thin declarations only: nothing here frees memory or inspects a TF_Status;
  # that is left to the higher-level wrappers.
  module FFI
    extend ::FFI::Library

    begin
      ffi_lib ["tensorflow", "libtensorflow"]
    rescue LoadError => e
      # Set TENSORFLOW_DEBUG to see the loader's original error message.
      raise e if ENV["TENSORFLOW_DEBUG"]
      raise LoadError, "Could not find Tensorflow"
    end

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_attrtype.h
    AttrType = enum(:string, :int, :float, :bool, :type, :shape, :tensor, :placeholder, :func)

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_datatype.h
    # Numbering starts at 1 (:float); every following symbol takes the next integer.
    DataType = enum(:float, 1, :double, :int32, :uint8, :int16, :int8, :string, :complex64, :int64, :bool, :qint8, :quint8, :qint32, :bfloat16, :qint16, :quint16, :uint16, :complex128, :half, :resource, :variant, :uint32, :uint64)

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h
    class Buffer < ::FFI::Struct
      layout :data, :pointer,
             :length, :size_t,
             :data_deallocator, :pointer
    end
    attach_function :TF_NewBuffer, [], :pointer
    attach_function :TF_DeleteBuffer, [:pointer], :void
    # TF_GetBuffer returns a TF_Buffer struct by value, not a pointer to one.
    attach_function :TF_GetBuffer, [:pointer], Buffer.by_value

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h
    attach_function :TF_Version, [], :string
    attach_function :TF_GetAllOpList, [], Buffer.by_ref

    # Operation/index pair identifying one input of an operation.
    class Input < ::FFI::Struct
      layout :oper, :pointer,
             :index, :int

      def to_s
        "#{self[:oper]}: #{self[:index]}"
      end
    end

    # Operation/index pair identifying one output of an operation.
    class Output < ::FFI::Struct
      layout :oper, :pointer,
             :index, :int

      # Copies +outputs+ into a freshly allocated, contiguous C array of Output
      # structs (the layout expected by C functions taking TF_Output*).
      def self.array_to_ptr(outputs)
        result = ::FFI::MemoryPointer.new(self, outputs.length)
        outputs.each_with_index do |output, i|
          copy_output = self.new(result[i])
          copy_output[:oper] = output[:oper]
          copy_output[:index] = output[:index]
        end
        result
      end

      def to_s
        "#{self[:oper]}: #{self[:index]}"
      end
    end

    attach_function :TF_NewGraph, [], :pointer
    attach_function :TF_DeleteGraph, [:pointer], :void
    attach_function :TF_GraphGetOpDef, [:pointer, :string, :pointer, :pointer], :void

    attach_function :TF_GraphGetTensorNumDims, [:pointer, Output, :pointer], :int
    attach_function :TF_GraphGetTensorShape, [:pointer, Output, :pointer, :int, :pointer], :void
    attach_function :TF_GraphSetTensorShape, [:pointer, Output, :pointer, :int, :pointer], :void

    attach_function :TF_NewOperation, [:pointer, :string, :string], :pointer
    attach_function :TF_FinishOperation, [:pointer, :pointer], :pointer
    attach_function :TF_SetDevice, [:pointer, :string], :void
    attach_function :TF_SetAttrBool, [:pointer, :string, :uchar], :void
    attach_function :TF_SetAttrBoolList, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrInt, [:pointer, :string, :int64], :void
    attach_function :TF_SetAttrIntList, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrFloat, [:pointer, :string, :float], :void
    attach_function :TF_SetAttrFloatList, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrFuncName, [:pointer, :string, :string, :size_t], :void
    attach_function :TF_SetAttrPlaceholder, [:pointer, :string, :string], :void
    attach_function :TF_SetAttrShape, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrShapeList, [:pointer, :string, :pointer, :pointer, :int], :void
    attach_function :TF_SetAttrString, [:pointer, :string, :pointer, :size_t], :void
    attach_function :TF_SetAttrStringList, [:pointer, :string, :pointer, :pointer, :int], :void
    attach_function :TF_SetAttrType, [:pointer, :string, DataType], :void
    attach_function :TF_SetAttrTypeList, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrTensor, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_SetAttrTensorList, [:pointer, :string, :pointer, :int, :pointer], :void

    attach_function :TF_AddInput, [:pointer, Output], :void
    attach_function :TF_AddInputList, [:pointer, :pointer, :int], :void

    attach_function :TF_AddControlInput, [:pointer, :pointer], :void
    attach_function :TF_OperationNumControlInputs, [:pointer], :int
    attach_function :TF_OperationGetControlInputs, [:pointer, :pointer, :int], :int
    attach_function :TF_OperationNumControlOutputs, [:pointer], :int
    attach_function :TF_OperationGetControlOutputs, [:pointer, :pointer, :int], :int

    attach_function :TF_OperationToNodeDef, [:pointer, :pointer, :pointer], :void
    attach_function :TF_OperationNumInputs, [:pointer], :int
    attach_function :TF_OperationInputType, [Input], DataType
    attach_function :TF_OperationInputListLength, [:pointer, :string, :pointer], :int
    attach_function :TF_OperationAllInputs, [:pointer, :pointer, :int], :void

    attach_function :TF_OperationNumOutputs, [:pointer], :int
    attach_function :TF_OperationOutputType, [Output], DataType
    attach_function :TF_OperationOutputListLength, [:pointer, :string, :pointer], :int

    attach_function :TF_OperationOutputNumConsumers, [Output], :int
    attach_function :TF_OperationOutputConsumers, [Output, :pointer, :int], :int

    # Mirrors TF_AttrMetadata, which TF_OperationGetAttrMetadata returns by value.
    class AttrMetadata < ::FFI::Struct
      layout :is_list, :uchar,
             :list_size, :int64,
             :type, AttrType,
             :total_size, :int64
    end

    attach_function :TF_OperationGetAttrMetadata, [:pointer, :string, :pointer], AttrMetadata.by_value
    attach_function :TF_OperationGetAttrBool, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrBoolList, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrFloat, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrFloatList, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrInt, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrIntList, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrShape, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrShapeList, [:pointer, :string, :pointer, :pointer, :int, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrString, [:pointer, :string, :pointer, :size_t, :pointer], :void
    # NOTE(review): the C declaration of TF_OperationGetAttrStringList ends with a
    # TF_Status* argument (8 parameters); this binding declares 7. Left as-is so
    # existing callers keep working — confirm against c_api.h and the callers.
    attach_function :TF_OperationGetAttrStringList, [:pointer, :string, :pointer, :pointer, :int, :pointer, :size_t], :void
    attach_function :TF_OperationGetAttrTensor, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrType, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrTypeList, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrValueProto, [:pointer, :string, :pointer, :pointer], :void

    attach_function :TF_GraphOperationByName, [:pointer, :string], :pointer
    attach_function :TF_GraphNextOperation, [:pointer, :pointer], :pointer

    attach_function :TF_OperationName, [:pointer], :string
    attach_function :TF_OperationOpType, [:pointer], :string
    attach_function :TF_OperationDevice, [:pointer], :string

    attach_function :TF_AddGradients, [:pointer, :pointer, :int, :pointer, :int, :pointer, :pointer, :pointer], :void
    attach_function :TF_AddGradientsWithPrefix, [:pointer, :string, :pointer, :int, :pointer, :int, :pointer, :pointer, :pointer], :void

    attach_function :TF_NewSessionOptions, [], :pointer
    attach_function :TF_SetTarget, [:pointer, :string], :void
    attach_function :TF_SetConfig, [:pointer, :pointer, :size_t, :pointer], :void
    attach_function :TF_DeleteSessionOptions, [:pointer], :void

    attach_function :TF_NewSession, [:pointer, :pointer, :pointer], :pointer
    attach_function :TF_CloseSession, [:pointer, :pointer], :void
    attach_function :TF_DeleteSession, [:pointer, :pointer], :void
    attach_function :TF_SessionRun, [:pointer, Buffer,
                                     :pointer, :pointer, :int,
                                     :pointer, :pointer, :int,
                                     :pointer, :int,
                                     Buffer,
                                     :pointer], :void

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_status.h
    StatusCode = enum(:tf_ok, 0,
                      :tf_cancelled, 1,
                      :tf_unknown, 2,
                      :tf_invalid_argument, 3,
                      :tf_deadline_exceeded, 4,
                      :tf_not_found, 5,
                      :tf_already_exists, 6,
                      :tf_permission_denied, 7,
                      :tf_unauthenticated, 16,
                      :tf_resource_exhausted, 8,
                      :tf_failed_precondition, 9,
                      :tf_aborted, 10,
                      :tf_out_of_range, 11,
                      :tf_unimplemented, 12,
                      :tf_internal, 13,
                      :tf_unavailable, 14,
                      :tf_data_loss, 15)

    attach_function :TF_NewStatus, [], :pointer
    attach_function :TF_DeleteStatus, [:pointer], :void
    attach_function :TF_GetCode, [:pointer], StatusCode
    attach_function :TF_Message, [:pointer], :string
    attach_function :TF_SetStatus, [:pointer, StatusCode, :string], :void

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_tensor.h
    callback :tensor_deallocator, [:pointer, :size_t, :pointer], :void
    attach_function :TF_NewTensor, [DataType, :pointer, :int, :pointer, :size_t, :tensor_deallocator, :pointer], :pointer
    attach_function :TF_DeleteTensor, [:pointer], :void
    attach_function :TF_TensorType, [:pointer], DataType
    attach_function :TF_NumDims, [:pointer], :int
    attach_function :TF_Dim, [:pointer, :int], :int64
    attach_function :TF_TensorByteSize, [:pointer], :size_t
    attach_function :TF_TensorElementCount, [:pointer], :int64
    attach_function :TF_TensorData, [:pointer], :pointer
    attach_function :TF_StringEncode, [:pointer, :size_t, :pointer, :size_t, :pointer], :size_t
    attach_function :TF_StringDecode, [:pointer, :size_t, :pointer, :pointer, :pointer], :size_t
    attach_function :TF_StringEncodedSize, [:size_t], :size_t

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/eager/c_api.h
    ContextDevicePlacementPolicy = enum(:explicit, :warn, :silent, :silent_for_int32)

    attach_function :TFE_NewContextOptions, [], :pointer
    attach_function :TFE_ContextOptionsSetAsync, [:pointer, :char], :void
    attach_function :TFE_DeleteContextOptions, [:pointer], :void
    attach_function :TFE_NewContext, [:pointer, :pointer], :pointer
    attach_function :TFE_DeleteContext, [:pointer], :void
    attach_function :TFE_ContextListDevices, [:pointer, :pointer], :pointer
    attach_function :TFE_ContextGetDevicePlacementPolicy, [:pointer], :int
    attach_function :TFE_ContextAddFunction, [:pointer, :pointer, :pointer], :void
    attach_function :TFE_ContextHasFunction, [:pointer, :string], :uchar
    attach_function :TFE_ContextRemoveFunction, [:pointer, :string, :pointer], :void

    attach_function :TFE_NewTensorHandle, [:pointer, :pointer], :pointer
    attach_function :TFE_DeleteTensorHandle, [:pointer], :void
    attach_function :TFE_TensorHandleDataType, [:pointer], DataType
    attach_function :TFE_TensorHandleNumDims, [:pointer, :pointer], :int
    attach_function :TFE_TensorHandleNumElements, [:pointer, :pointer], :int64
    attach_function :TFE_TensorHandleDim, %i[pointer int pointer], :int64
    attach_function :TFE_TensorHandleDeviceName, [:pointer, :pointer], :string
    attach_function :TFE_TensorHandleBackingDeviceName, [:pointer, :pointer], :string
    attach_function :TFE_TensorHandleResolve, [:pointer, :pointer], :pointer
    attach_function :TFE_NewOp, [:pointer, :string, :pointer], :pointer
    attach_function :TFE_DeleteOp, [:pointer], :void
    attach_function :TFE_OpSetDevice, [:pointer, :string, :pointer], :void
    attach_function :TFE_OpGetDevice, [:pointer, :pointer], :string
    attach_function :TFE_OpAddInput, [:pointer, :pointer, :pointer], :void
    attach_function :TFE_OpAddInputList, %i[pointer pointer int pointer], :void
    attach_function :TFE_OpGetAttrType, %i[pointer string pointer pointer], AttrType
    attach_function :TFE_OpSetAttrString, %i[pointer string pointer size_t], :void
    attach_function :TFE_OpSetAttrInt, %i[pointer string int64], :void
    attach_function :TFE_OpSetAttrFloat, %i[pointer string float], :void
    attach_function :TFE_OpSetAttrFunction, [:pointer, :string, :pointer], :void
    attach_function :TFE_OpSetAttrFunctionName, [:pointer, :string, :string, :size_t], :void
    attach_function :TFE_OpSetAttrFunctionList, [:pointer, :string, :pointer, :int], :void
    attach_function :TFE_OpSetAttrBool, %i[pointer string uint8], :void
    attach_function :TFE_OpSetAttrTensor, %i[pointer string pointer pointer], :void
    attach_function :TFE_OpSetAttrType, %i[pointer string int], :void
    attach_function :TFE_OpSetAttrShape, %i[pointer string pointer int pointer], :void
    attach_function :TFE_OpSetAttrIntList, %i[pointer string pointer int], :void
    attach_function :TFE_OpSetAttrFloatList, %i[pointer string pointer int], :void
    attach_function :TFE_OpSetAttrTypeList, %i[pointer string pointer int], :void
    attach_function :TFE_OpSetAttrShapeList, %i[pointer string pointer pointer int pointer], :void
    attach_function :TFE_Execute, %i[pointer pointer pointer pointer], :void
    attach_function :TFE_ContextEnableRunMetadata, [:pointer], :void
    attach_function :TFE_ContextDisableRunMetadata, [:pointer], :void
    attach_function :TFE_ContextStartStep, [:pointer], :void
    attach_function :TFE_ContextEndStep, [:pointer], :void

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/ops.h
    attach_function :TF_NewOpDefinitionBuilder, [:string], :pointer
    attach_function :TF_RegisterOpDefinition, [:pointer, :pointer], :void
    attach_function :TF_DeleteOpDefinitionBuilder, [:pointer], :void
    attach_function :TF_OpDefinitionBuilderAddAttr, [:pointer, :string], :void
    attach_function :TF_OpDefinitionBuilderAddInput, [:pointer, :string], :void
    attach_function :TF_OpDefinitionBuilderAddOutput, [:pointer, :string], :void
    attach_function :TF_OpDefinitionBuilderSetIsCommutative, [:pointer, :bool], :void
    attach_function :TF_OpDefinitionBuilderSetIsAggregate, [:pointer, :bool], :void
    attach_function :TF_OpDefinitionBuilderSetShapeInferenceFunction, [:pointer, :pointer], :void

    attach_function :TF_GraphToFunction, [:pointer, :string, :uchar,
                                          :int, :pointer,
                                          :int, :pointer,
                                          :int, :pointer,
                                          :pointer, :pointer, :string, :pointer], :pointer
    # :strptr yields [String, FFI::Pointer].
    attach_function :TF_FunctionName, [:pointer], :strptr
    # Returns void in C; the FunctionDef is written into the TF_Buffer argument.
    attach_function :TF_FunctionToFunctionDef, [:pointer, :pointer, :pointer], :void
    attach_function :TF_GraphCopyFunction, [:pointer, :pointer, :pointer, :pointer], :void

    attach_function :TF_GraphToGraphDef, [:pointer, :pointer, :pointer], :void

    attach_function :TF_NewImportGraphDefOptions, [], :pointer
    attach_function :TF_DeleteImportGraphDefOptions, [:pointer], :void
    attach_function :TF_ImportGraphDefOptionsSetPrefix, [:pointer, :string], :void
    attach_function :TF_ImportGraphDefOptionsSetDefaultDevice, [:pointer, :string], :void
    attach_function :TF_ImportGraphDefOptionsSetUniquifyNames, [:pointer, :uchar], :void
    attach_function :TF_ImportGraphDefOptionsSetUniquifyPrefix, [:pointer, :uchar], :void
    attach_function :TF_ImportGraphDefOptionsAddInputMapping, [:pointer, :string, :int, Output], :void
    attach_function :TF_ImportGraphDefOptionsRemapControlDependency, [:pointer, :string, :pointer], :void
    attach_function :TF_ImportGraphDefOptionsAddControlDependency, [:pointer, :pointer], :void
    attach_function :TF_ImportGraphDefOptionsAddReturnOutput, [:pointer, :string, :int], :void
    attach_function :TF_ImportGraphDefOptionsNumReturnOutputs, [:pointer], :int
    attach_function :TF_ImportGraphDefOptionsAddReturnOperation, [:pointer, :string], :void
    attach_function :TF_ImportGraphDefOptionsNumReturnOperations, [:pointer], :int

    attach_function :TF_GraphImportGraphDef, [:pointer, Buffer, :pointer, :pointer], :void
  end
end
|
@@ -0,0 +1,33 @@
|
|
1
|
+
module Tensorflow
  module Graph
    # Ruby handle around a TF_Function pointer plus the output types/shapes the
    # caller supplied for it.
    class Function
      attr_reader :output_types, :output_shapes

      # @param pointer [FFI::Pointer] the underlying TF_Function*
      # @param output_types [Array] output dtypes, stored as given
      # @param output_shapes [Array] output shapes, stored as given
      def initialize(pointer, output_types, output_shapes)
        @pointer = pointer
        @output_types = output_types
        @output_shapes = output_shapes
      end

      # Lets instances be passed directly where FFI expects a pointer.
      def to_ptr
        @pointer
      end

      # @return [String] the name reported by TF_FunctionName
      def name
        # TF_FunctionName is bound as :strptr, which yields [string, pointer];
        # only the string is needed.
        FFI.TF_FunctionName(self).first
      end

      # Serializes this function to a FunctionDef protobuf.
      #
      # @return [Tensorflow::FunctionDef]
      # @raise whatever Status.check raises when TF_FunctionToFunctionDef fails
      def function_def
        buffer_ptr = FFI.TF_NewBuffer
        Status.check do |status|
          FFI.TF_FunctionToFunctionDef(self, buffer_ptr, status)
        end
        buffer = FFI::Buffer.new(buffer_ptr)
        Tensorflow::FunctionDef.decode(buffer[:data].read_string(buffer[:length]))
      ensure
        # Free through buffer_ptr rather than the Buffer wrapper: the wrapper
        # does not exist yet if Status.check raised, and the TF_Buffer would leak.
        FFI.TF_DeleteBuffer(buffer_ptr) if buffer_ptr
      end
    end
  end
end
|
@@ -0,0 +1,62 @@
|
|
1
|
+
module Tensorflow
  module Graph
    # Turns a Ruby instance method into a graph-function factory.
    #
    # Construction keeps the original method under "<name>_original" and
    # redefines <name> so that each call traces the original body into a new
    # Graph, registers the resulting function with ExecutionContext.current,
    # and returns that function.
    class FunctionDef
      attr_reader :ruby_method, :signatures

      # dtype/shape pair describing one placeholder argument.
      Signature = Struct.new(:dtype, :shape)

      # @param ruby_method [UnboundMethod] method to wrap
      # @param input_signatures [Array<Array>] one [dtype, shape] pair per parameter
      # @raise [Error::InvalidArgumentError] if the pair count differs from the
      #   method's parameter count
      def initialize(ruby_method, input_signatures = [])
        @ruby_method = ruby_method
        process_signatures(ruby_method, input_signatures)
        wrap_ruby_method
      end

      # Checks that every parameter has a signature and stores them as Signature structs.
      def process_signatures(ruby_method, input_signatures)
        unless input_signatures.length == ruby_method.parameters.length
          raise(Error::InvalidArgumentError, "Must specify input signature for each method parameter")
        end

        @signatures = input_signatures.map { |(dtype, shape)| Signature.new(dtype, shape) }
      end

      # Name under which the untouched original method is preserved.
      def aliased_name
        "#{ruby_method.original_name}_original"
      end

      # Aliases the original method, then replaces it with a wrapper that
      # builds and registers a graph function.
      def wrap_ruby_method
        backup_name = aliased_name
        method_name = ruby_method.original_name
        definition = self

        ruby_method.owner.instance_eval do
          alias_method(backup_name, method_name)

          # Arguments are ignored: the graph is traced with placeholders instead.
          define_method(method_name) do |*_args|
            graph_function = definition.build_function(self)
            ExecutionContext.current.add_function(graph_function)
            graph_function
          end
        end
      end

      # Traces the original method into a fresh graph, feeding it one
      # placeholder per parameter (named after the parameter), and converts
      # that graph into a function named after the method.
      #
      # @param object [Object] receiver the original method is bound to
      def build_function(object)
        Graph.new.as_default do |graph|
          placeholders = ruby_method.parameters.each_with_index.map do |param, position|
            signature = signatures[position]
            Tensorflow.placeholder(signature.dtype, name: param.last, shape: signature.shape)
          end

          # Run the original body with the placeholders standing in for arguments.
          traced = ruby_method.bind(object).call(*placeholders)

          graph.to_function(ruby_method.original_name.to_s, nil, placeholders, Array(traced))
        end
      end
    end
  end
end
|
@@ -0,0 +1,120 @@
|
|
1
|
+
require 'set'
|
2
|
+
|
3
|
+
module Tensorflow
  module Graph
    # Builds symbolic gradients by walking backwards from an output operation
    # toward the requested inputs. Gradient functions can be registered per op
    # type with Gradients.register. Unregistered op types fall back to
    # #add_api_gradients, which calls TF_AddGradientsWithPrefix.
    class Gradients
      attr_reader :graph

      # Registry of gradient functions keyed by op type. Looking up an op type
      # that was never registered returns the unbound #add_api_gradients method.
      def self.gradients
        @gradients ||= Hash.new(instance_method(:add_api_gradients))
      end

      # Registers +block+ as the gradient function for +op_type+. The block is
      # later called with (gradient, outputs, inputs).
      def self.register(op_type, &block)
        gradients[op_type] = block
      end

      def initialize(graph)
        @graph = graph
      end

      # Operations lying between +input+ and +output+: everything reachable
      # forward from the input that is also reachable backward from the output.
      def path(output, input)
        downstream = graph.forward(input)
        upstream = graph.backward(output)
        downstream.intersection(upstream)
      end

      # Seed gradient for each output of +operation+: a tensor with the
      # output's shape, filled with 1 of the output's dtype.
      def default_gradient(operation)
        operation.outputs.each_with_index.map do |op_output, index|
          dims = Tensorflow.shape(op_output, :int32)
          one = Tensorflow.constant(1, name: "grad_ys_#{index}", dtype: operation.output_types[index])
          Tensorflow.fill(dims, one)
        end
      end

      # Gradients of +output+ with respect to each of +inputs+, built inside
      # the +name+ scope. Inputs with no path to +output+ contribute nothing.
      #
      # @param grad_ys seed gradient; defaults to the first #default_gradient
      # @param stop_operations [Set] operations the backward walk must not enter
      # @return [Array] flattened gradient results with nils removed
      def gradients(output, inputs, grad_ys: nil, name: "gradients", stop_operations: Set.new)
        grad_ys ||= default_gradient(output).first

        graph.name_scope(name) do
          per_input = inputs.map do |input|
            operations_path = path(output, input)
            operations_path.empty? ? nil : derivative(grad_ys, output, stop_operations, operations_path)
          end
          per_input.flatten.compact
        end
      end

      def derivative(gradient, operation, stop_operations, operations_path)
        # This method follows the C api naming conventions for parameters. Visually it looks
        # like this:
        #
        # x ------> y (forward)
        # dy <----- dx (backward)

        return gradient unless operations_path.include?(operation)
        return gradient if stop_operations.include?(operation)

        relevant_inputs = operation.inputs.select do |candidate|
          operations_path.include?(candidate.operation) && !stop_operations.include?(candidate.operation)
        end
        return gradient if relevant_inputs.empty?

        # The last operation on the path has no consumers, so keep all of its
        # outputs; earlier operations only contribute outputs that are consumed.
        relevant_outputs = operation.outputs.select do |candidate|
          consumers = operation.output_consumers(candidate)
          operation == operations_path.first || consumers.count > 0
        end

        handler = self.class.gradients[operation.op_type]
        partials =
          case handler
          when UnboundMethod then handler.bind(self).call(gradient, relevant_outputs, relevant_inputs)
          else handler.call(gradient, relevant_outputs, relevant_inputs)
          end

        # Backpropagate into each input operation; null gradients yield nil.
        relevant_inputs.each_with_index.map do |input, index|
          partial = partials[index]
          next if partial.output[:oper].null?

          derivative(partial.operation, input.operation, stop_operations, operations_path)
        end
      end

      # Fallback gradient function: asks the C API for d(outputs)/d(inputs),
      # seeded with +gradient+'s outputs when given, and returns one
      # OperationOutput per input.
      def add_api_gradients(gradient, outputs, inputs)
        # y: outputs of the operation; x: its inputs.
        y = FFI::Output.array_to_ptr(outputs.map(&:output))
        x = FFI::Output.array_to_ptr(inputs.map(&:output))
        # dx: the gradient being backpropagated (NULL when absent).
        dx = gradient ? FFI::Output.array_to_ptr(gradient.outputs.map(&:output)) : nil
        # dy: zero-initialised slots the C API fills with the computed gradients.
        dy = ::FFI::MemoryPointer.new(FFI::Output, inputs.length, true)

        prefix = graph.scoped_name(inputs.first.operation.name)
        Status.check do |status|
          FFI.TF_AddGradientsWithPrefix(graph, prefix, y, outputs.length, x, inputs.length, dx, status, dy)
        end

        Array.new(inputs.length) { |index| OperationOutput.from_graph(graph, dy[index]) }
      end
    end
  end
end
|