tensorflow 0.1.1 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -1
- data/README.md +33 -2
- data/lib/tensorflow.rb +28 -1
- data/lib/tensorflow/audio.rb +13 -0
- data/lib/tensorflow/bitwise.rb +29 -0
- data/lib/tensorflow/data/batch_dataset.rb +20 -0
- data/lib/tensorflow/data/dataset.rb +46 -0
- data/lib/tensorflow/data/shuffle_dataset.rb +21 -0
- data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
- data/lib/tensorflow/ffi.rb +12 -0
- data/lib/tensorflow/image.rb +218 -0
- data/lib/tensorflow/io.rb +125 -0
- data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
- data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
- data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
- data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
- data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
- data/lib/tensorflow/keras/datasets/mnist.rb +6 -5
- data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
- data/lib/tensorflow/keras/metrics/mean.rb +17 -0
- data/lib/tensorflow/keras/utils.rb +56 -0
- data/lib/tensorflow/linalg.rb +133 -0
- data/lib/tensorflow/math.rb +59 -24
- data/lib/tensorflow/nn.rb +284 -0
- data/lib/tensorflow/ops.rb +10 -9
- data/lib/tensorflow/strings.rb +100 -0
- data/lib/tensorflow/tensor.rb +70 -30
- data/lib/tensorflow/utils.rb +115 -70
- data/lib/tensorflow/variable.rb +1 -1
- data/lib/tensorflow/version.rb +1 -1
- metadata +35 -2
@@ -0,0 +1,100 @@
|
|
1
|
+
module TensorFlow
  # String operations exposed as module-level methods
  # (for example TensorFlow::Strings.lower(x)).
  #
  # Every method here is a thin pass-through: it forwards its arguments,
  # unchanged, to a raw op and returns whatever that call returns. Keyword
  # arguments whose default is nil are forwarded as nil.
  #
  # Placeholder stubs with no implementation yet:
  #   bytes_split, format, ngrams, to_hash_bucket, to_hash_bucket_fast,
  #   to_hash_bucket_strong, unicode_split, unicode_split_with_offsets,
  #   unsorted_segment_join
  module Strings
    # Forwards to RawOps.as_string.
    def self.as_string(input, precision: nil, scientific: nil, shortest: nil, width: nil, fill: nil)
      RawOps.as_string(
        input: input,
        precision: precision,
        scientific: scientific,
        shortest: shortest,
        width: width,
        fill: fill
      )
    end

    # Runs the "StringJoin" op through Utils.execute. The N attribute is set
    # to the number of entries in +inputs+, so +inputs+ must respond to #size.
    def self.join(inputs, separator: "")
      Utils.execute("StringJoin", inputs, separator: separator, N: inputs.size)
    end

    # Forwards to RawOps.string_length; +unit+ defaults to "BYTE".
    def self.length(input, unit: "BYTE")
      RawOps.string_length(input: input, unit: unit)
    end

    # Forwards to RawOps.string_lower; +encoding+ defaults to the empty string.
    def self.lower(input, encoding: "")
      RawOps.string_lower(input: input, encoding: encoding)
    end

    # Forwards to RawOps.reduce_join.
    def self.reduce_join(inputs, reduction_indices, keep_dims: nil, separator: nil)
      RawOps.reduce_join(
        inputs: inputs,
        reduction_indices: reduction_indices,
        keep_dims: keep_dims,
        separator: separator
      )
    end

    # Forwards to RawOps.regex_full_match.
    def self.regex_full_match(input, pattern)
      RawOps.regex_full_match(input: input, pattern: pattern)
    end

    # Forwards to RawOps.regex_replace.
    def self.regex_replace(input, pattern, rewrite, replace_global: nil)
      RawOps.regex_replace(
        input: input,
        pattern: pattern,
        rewrite: rewrite,
        replace_global: replace_global
      )
    end

    # Forwards to RawOps.split.
    # NOTE(review): this is the generic `split` raw op (split_dim / value /
    # num_split), not a string-specific one -- confirm that is intended here.
    def self.split(split_dim, value, num_split: nil)
      RawOps.split(split_dim: split_dim, value: value, num_split: num_split)
    end

    # Forwards to RawOps.string_strip.
    def self.strip(input)
      RawOps.string_strip(input: input)
    end

    # Forwards to RawOps.substr.
    def self.substr(input, pos, len, unit: nil)
      RawOps.substr(input: input, pos: pos, len: len, unit: unit)
    end

    # Forwards to RawOps.string_to_number, passing +input+ as the op's
    # string_tensor argument; +out_type+ defaults to :float.
    def self.to_number(input, out_type: :float)
      RawOps.string_to_number(string_tensor: input, out_type: out_type)
    end

    # Forwards to RawOps.unicode_decode.
    def self.unicode_decode(
      input,
      input_encoding: nil,
      errors: nil,
      replacement_char: nil,
      replace_control_characters: nil
    )
      RawOps.unicode_decode(
        input: input,
        input_encoding: input_encoding,
        errors: errors,
        replacement_char: replacement_char,
        replace_control_characters: replace_control_characters
      )
    end

    # Forwards to RawOps.unicode_decode_with_offsets.
    def self.unicode_decode_with_offsets(
      input,
      input_encoding: nil,
      errors: nil,
      replacement_char: nil,
      replace_control_characters: nil
    )
      RawOps.unicode_decode_with_offsets(
        input: input,
        input_encoding: input_encoding,
        errors: errors,
        replacement_char: replacement_char,
        replace_control_characters: replace_control_characters
      )
    end

    # Forwards to RawOps.unicode_encode.
    def self.unicode_encode(input_values, input_splits, errors: nil, output_encoding: nil, replacement_char: nil)
      RawOps.unicode_encode(
        input_values: input_values,
        input_splits: input_splits,
        errors: errors,
        output_encoding: output_encoding,
        replacement_char: replacement_char
      )
    end

    # Forwards to RawOps.unicode_script.
    def self.unicode_script(input)
      RawOps.unicode_script(input: input)
    end

    # Forwards to RawOps.unicode_transcode.
    def self.unicode_transcode(
      input,
      input_encoding: nil,
      output_encoding: nil,
      errors: nil,
      replacement_char: nil,
      replace_control_characters: nil
    )
      RawOps.unicode_transcode(
        input: input,
        input_encoding: input_encoding,
        output_encoding: output_encoding,
        errors: errors,
        replacement_char: replacement_char,
        replace_control_characters: replace_control_characters
      )
    end

    # Forwards to RawOps.string_upper; +encoding+ defaults to the empty string.
    def self.upper(input, encoding: "")
      RawOps.string_upper(input: input, encoding: encoding)
    end
  end
end
|
data/lib/tensorflow/tensor.rb
CHANGED
@@ -6,7 +6,8 @@ module TensorFlow
|
|
6
6
|
if pointer
|
7
7
|
@pointer = pointer
|
8
8
|
else
|
9
|
-
data =
|
9
|
+
data = value
|
10
|
+
data = Array(data) unless data.is_a?(Array) || data.is_a?(Numo::NArray)
|
10
11
|
shape ||= calculate_shape(value)
|
11
12
|
|
12
13
|
if shape.size > 0
|
@@ -16,37 +17,49 @@ module TensorFlow
|
|
16
17
|
dims_ptr = nil
|
17
18
|
end
|
18
19
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
|
25
|
-
data_ptr = ::FFI::MemoryPointer.new(dtype, data.size)
|
26
|
-
data_ptr.send("write_array_of_#{dtype}", data)
|
27
|
-
when :bfloat16
|
28
|
-
# https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
|
29
|
-
data_ptr = ::FFI::MemoryPointer.new(:int8, data.size * 2)
|
30
|
-
data_ptr.write_bytes(data.map { |v| [v].pack("g")[0..1] }.join)
|
31
|
-
when :complex64
|
32
|
-
data_ptr = ::FFI::MemoryPointer.new(:float, data.size * 2)
|
33
|
-
data_ptr.write_array_of_float(data.flat_map { |v| [v.real, v.imaginary] })
|
34
|
-
when :complex128
|
35
|
-
data_ptr = ::FFI::MemoryPointer.new(:double, data.size * 2)
|
36
|
-
data_ptr.write_array_of_double(data.flat_map { |v| [v.real, v.imaginary] })
|
37
|
-
when :string
|
38
|
-
data_ptr = string_ptr(data)
|
39
|
-
when :bool
|
40
|
-
data_ptr = ::FFI::MemoryPointer.new(:int8, data.size)
|
41
|
-
data_ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 })
|
20
|
+
if data.is_a?(Numo::NArray)
|
21
|
+
dtype ||= Utils.infer_type(data)
|
22
|
+
# TODO use Numo read pointer?
|
23
|
+
data_ptr = ::FFI::MemoryPointer.new(:uchar, data.byte_size)
|
24
|
+
data_ptr.write_bytes(data.to_string)
|
42
25
|
else
|
43
|
-
|
26
|
+
data = data.flatten
|
27
|
+
dtype ||= Utils.infer_type(data)
|
28
|
+
case dtype
|
29
|
+
when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
|
30
|
+
data_ptr = ::FFI::MemoryPointer.new(dtype, data.size)
|
31
|
+
data_ptr.send("write_array_of_#{dtype}", data)
|
32
|
+
when :bfloat16
|
33
|
+
# https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
|
34
|
+
data_ptr = ::FFI::MemoryPointer.new(:int8, data.size * 2)
|
35
|
+
data_ptr.write_bytes(data.map { |v| [v].pack("g")[0..1] }.join)
|
36
|
+
when :complex64
|
37
|
+
data_ptr = ::FFI::MemoryPointer.new(:float, data.size * 2)
|
38
|
+
data_ptr.write_array_of_float(data.flat_map { |v| [v.real, v.imaginary] })
|
39
|
+
when :complex128
|
40
|
+
data_ptr = ::FFI::MemoryPointer.new(:double, data.size * 2)
|
41
|
+
data_ptr.write_array_of_double(data.flat_map { |v| [v.real, v.imaginary] })
|
42
|
+
when :string
|
43
|
+
data_ptr = string_ptr(data)
|
44
|
+
when :bool
|
45
|
+
data_ptr = ::FFI::MemoryPointer.new(:int8, data.size)
|
46
|
+
data_ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 })
|
47
|
+
else
|
48
|
+
raise "Unknown type: #{dtype}"
|
49
|
+
end
|
44
50
|
end
|
45
51
|
|
52
|
+
type = FFI::DataType[dtype]
|
53
|
+
|
46
54
|
callback = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |data, len, arg|
|
47
55
|
# FFI handles deallocation
|
48
56
|
end
|
49
57
|
|
58
|
+
# keep data pointer alive for duration of object
|
59
|
+
@data_ptr = data_ptr
|
60
|
+
@dims_ptr = dims_ptr
|
61
|
+
@callback = callback
|
62
|
+
|
50
63
|
tensor = FFI.TF_NewTensor(type, dims_ptr, shape.size, data_ptr, data_ptr.size, callback, nil)
|
51
64
|
@pointer = FFI.TFE_NewTensorHandle(tensor, @status)
|
52
65
|
check_status @status
|
@@ -75,6 +88,10 @@ module TensorFlow
|
|
75
88
|
Math.floormod(self, other)
|
76
89
|
end
|
77
90
|
|
91
|
+
# Unary minus: evaluating `-tensor` returns the result of
# Math.negative(self).
# NOTE(review): `Math` is resolved from the enclosing namespace, presumably
# TensorFlow::Math rather than Ruby's core ::Math -- confirm.
def -@
  Math.negative(self)
end
|
94
|
+
|
78
95
|
def value
|
79
96
|
value =
|
80
97
|
case dtype
|
@@ -92,10 +109,19 @@ module TensorFlow
|
|
92
109
|
# https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
|
93
110
|
start_offset_size = element_count * 8
|
94
111
|
offsets = data_pointer.read_array_of_uint64(element_count)
|
95
|
-
|
112
|
+
byte_size = FFI.TF_TensorByteSize(tensor_pointer)
|
113
|
+
element_count.times.map do |i|
|
114
|
+
str_len = (offsets[i + 1] || (byte_size - start_offset_size)) - offsets[i]
|
115
|
+
str = (data_pointer + start_offset_size + offsets[i]).read_bytes(str_len)
|
116
|
+
dst = ::FFI::MemoryPointer.new(:char, str.bytesize + 100)
|
117
|
+
dst_len = ::FFI::MemoryPointer.new(:size_t)
|
118
|
+
FFI.TF_StringDecode(str, str.bytesize, dst, dst_len, @status)
|
119
|
+
check_status @status
|
120
|
+
dst.read_pointer.read_bytes(dst_len.read_int32)
|
121
|
+
end
|
96
122
|
when :bool
|
97
123
|
data_pointer.read_array_of_int8(element_count).map { |v| v == 1 }
|
98
|
-
when :resource
|
124
|
+
when :resource, :variant
|
99
125
|
return data_pointer
|
100
126
|
else
|
101
127
|
raise "Unknown type: #{dtype}"
|
@@ -135,8 +161,14 @@ module TensorFlow
|
|
135
161
|
@pointer
|
136
162
|
end
|
137
163
|
|
164
|
+
# Converts this tensor's value into a Numo array.
#
# The Numo class is looked up in Utils::NUMO_TYPE_MAP by this tensor's
# dtype, and the result of `value` is handed to that class's `cast`.
#
# Raises RuntimeError ("Unknown type: ...") when the dtype has no entry in
# the map.
def numo
  narray_class = Utils::NUMO_TYPE_MAP[dtype]

  if narray_class
    narray_class.cast(value)
  else
    raise "Unknown type: #{dtype}"
  end
end
|
169
|
+
|
138
170
|
def inspect
|
139
|
-
inspection = %w(
|
171
|
+
inspection = %w(numo shape dtype).map { |v| "#{v}: #{send(v).inspect}"}
|
140
172
|
"#<#{self.class} #{inspection.join(", ")}>"
|
141
173
|
end
|
142
174
|
|
@@ -164,9 +196,13 @@ module TensorFlow
|
|
164
196
|
end
|
165
197
|
|
166
198
|
def data_pointer
|
199
|
+
FFI.TF_TensorData(tensor_pointer)
|
200
|
+
end
|
201
|
+
|
202
|
+
def tensor_pointer
|
167
203
|
tensor = FFI.TFE_TensorHandleResolve(@pointer, @status)
|
168
204
|
check_status @status
|
169
|
-
|
205
|
+
tensor
|
170
206
|
end
|
171
207
|
|
172
208
|
def reshape(arr, dims)
|
@@ -179,6 +215,8 @@ module TensorFlow
|
|
179
215
|
end
|
180
216
|
|
181
217
|
def calculate_shape(value)
|
218
|
+
return value.shape if value.respond_to?(:shape)
|
219
|
+
|
182
220
|
shape = []
|
183
221
|
d = value
|
184
222
|
while d.is_a?(Array)
|
@@ -199,7 +237,9 @@ module TensorFlow
|
|
199
237
|
data_ptr = ::FFI::MemoryPointer.new(:char, start_offset_size + offsets.pop)
|
200
238
|
data_ptr.write_array_of_uint64(offsets)
|
201
239
|
data.zip(offsets) do |str, offset|
|
202
|
-
|
240
|
+
dst_len = FFI.TF_StringEncodedSize(str.bytesize)
|
241
|
+
FFI.TF_StringEncode(str, str.bytesize, data_ptr + start_offset_size + offset, dst_len, @status)
|
242
|
+
check_status @status
|
203
243
|
end
|
204
244
|
data_ptr
|
205
245
|
end
|
data/lib/tensorflow/utils.rb
CHANGED
@@ -1,5 +1,18 @@
|
|
1
1
|
module TensorFlow
|
2
2
|
module Utils
|
3
|
+
# Lookup table from dtype symbol to the corresponding Numo array class.
#
# Read in two directions: by key to pick the Numo class for a dtype, and by
# scanning the values to find the dtype of an existing Numo array.
#
# Frozen because it is a shared constant that is only ever read; freezing
# prevents accidental mutation at runtime.
NUMO_TYPE_MAP = {
  int8: Numo::Int8,
  int16: Numo::Int16,
  int32: Numo::Int32,
  int64: Numo::Int64,
  uint8: Numo::UInt8,
  uint16: Numo::UInt16,
  uint32: Numo::UInt32,
  uint64: Numo::UInt64,
  float: Numo::SFloat,
  double: Numo::DFloat
}.freeze
|
15
|
+
|
3
16
|
class << self
|
4
17
|
def check_status(status)
|
5
18
|
if FFI.TF_GetCode(status) != 0
|
@@ -26,48 +39,104 @@ module TensorFlow
|
|
26
39
|
type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status)
|
27
40
|
check_status status
|
28
41
|
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
42
|
+
if is_list.read_int == 1
|
43
|
+
num_values = attr_value.size
|
44
|
+
|
45
|
+
case FFI::AttrType[type]
|
46
|
+
when :int
|
47
|
+
values = ::FFI::MemoryPointer.new(:int64, num_values)
|
48
|
+
values.write_array_of_int64(attr_value)
|
49
|
+
FFI.TFE_OpSetAttrIntList(op, attr_name, values, num_values)
|
50
|
+
when :float
|
51
|
+
values = ::FFI::MemoryPointer.new(:float, num_values)
|
52
|
+
values.write_array_of_float(attr_value)
|
53
|
+
FFI.TFE_OpSetAttrFloatList(op, attr_name, values, num_values)
|
54
|
+
when :shape
|
55
|
+
dims_ptrs =
|
56
|
+
attr_value.map do |shape|
|
57
|
+
ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
|
58
|
+
ptr.write_array_of_int64(shape)
|
59
|
+
end
|
60
|
+
dims = ::FFI::MemoryPointer.new(:pointer, num_values)
|
61
|
+
dims.write_array_of_pointer(dims_ptrs)
|
62
|
+
|
63
|
+
num_dims = ::FFI::MemoryPointer.new(:int, num_values)
|
64
|
+
num_dims.write_array_of_int(attr_value.map(&:size))
|
65
|
+
|
66
|
+
FFI.TFE_OpSetAttrShapeList(op, attr_name, dims, num_dims, num_values, status)
|
67
|
+
when :type
|
68
|
+
values = ::FFI::MemoryPointer.new(:int, num_values)
|
69
|
+
types =
|
70
|
+
attr_value.map do |v|
|
71
|
+
if v.is_a?(Symbol)
|
72
|
+
FFI::DataType[v]
|
73
|
+
else
|
74
|
+
v
|
75
|
+
end
|
76
|
+
end
|
77
|
+
values.write_array_of_int(types)
|
78
|
+
FFI.TFE_OpSetAttrTypeList(op, attr_name, values, num_values)
|
79
|
+
else
|
80
|
+
raise "Unknown list type: #{FFI::AttrType[type]}"
|
81
|
+
end
|
44
82
|
else
|
45
|
-
|
83
|
+
case FFI::AttrType[type]
|
84
|
+
when :string
|
85
|
+
FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize)
|
86
|
+
when :int
|
87
|
+
FFI.TFE_OpSetAttrInt(op, attr_name, attr_value)
|
88
|
+
when :float
|
89
|
+
FFI.TFE_OpSetAttrFloat(op, attr_name, attr_value)
|
90
|
+
when :bool
|
91
|
+
FFI.TFE_OpSetAttrBool(op, attr_name, attr_value ? 1 : 0)
|
92
|
+
when :type
|
93
|
+
attr_value = FFI::DataType[attr_value] if attr_value.is_a?(Symbol)
|
94
|
+
FFI.TFE_OpSetAttrType(op, attr_name, attr_value)
|
95
|
+
when :shape
|
96
|
+
ptr = ::FFI::MemoryPointer.new(:int64, attr_value.size)
|
97
|
+
ptr.write_array_of_int64(attr_value)
|
98
|
+
FFI.TFE_OpSetAttrShape(op, attr_name, ptr, attr_value.size, status)
|
99
|
+
check_status status
|
100
|
+
# when :tensor
|
101
|
+
# when :placeholder
|
102
|
+
# when :func
|
103
|
+
else
|
104
|
+
raise "Unknown type: #{FFI::AttrType[type]}"
|
105
|
+
end
|
46
106
|
end
|
47
107
|
end
|
48
108
|
|
49
|
-
inputs.
|
50
|
-
|
51
|
-
|
109
|
+
inputs.each_with_index do |input, i|
|
110
|
+
# TODO handle this better
|
111
|
+
if op_name == "TensorSliceDataset" && i == 0
|
112
|
+
input_ptr = ::FFI::MemoryPointer.new(:pointer, input.size)
|
113
|
+
input_ptr.write_array_of_pointer(input)
|
114
|
+
FFI.TFE_OpAddInputList(op, input_ptr, input.size, status)
|
115
|
+
else
|
116
|
+
raise "Missing argument" if input.nil?
|
117
|
+
|
118
|
+
input = TensorFlow.convert_to_tensor(input) unless input.respond_to?(:to_ptr)
|
119
|
+
FFI.TFE_OpAddInput(op, input, status)
|
120
|
+
end
|
52
121
|
check_status status
|
53
122
|
end
|
54
123
|
|
55
|
-
retvals
|
124
|
+
# TODO decide how many retvals to allocate
|
125
|
+
retvals = ::FFI::MemoryPointer.new(:pointer, 2)
|
56
126
|
num_retvals = ::FFI::MemoryPointer.new(:int)
|
57
127
|
num_retvals.write_int(retvals.size)
|
58
128
|
FFI.TFE_Execute(op, retvals, num_retvals, status)
|
59
129
|
check_status status
|
60
130
|
|
61
|
-
|
62
|
-
|
63
|
-
|
131
|
+
n = num_retvals.read_int
|
132
|
+
if n > 0
|
133
|
+
retvals =
|
134
|
+
retvals.read_array_of_pointer(n).map do |handle|
|
135
|
+
Tensor.new(pointer: handle)
|
136
|
+
end
|
64
137
|
|
65
|
-
case
|
66
|
-
|
67
|
-
handle
|
68
|
-
else
|
69
|
-
Tensor.new(pointer: handle)
|
70
|
-
end
|
138
|
+
# TODO handle case where n = 1 and still want an array for retvals
|
139
|
+
n == 1 ? retvals.first : retvals
|
71
140
|
end
|
72
141
|
ensure
|
73
142
|
FFI.TF_DeleteStatus(status) if status
|
@@ -75,9 +144,18 @@ module TensorFlow
|
|
75
144
|
end
|
76
145
|
|
77
146
|
def infer_type(value)
|
78
|
-
if value.
|
147
|
+
if value.is_a?(Numo::NArray)
|
148
|
+
type = NUMO_TYPE_MAP.find { |k, v| value.is_a?(v) }
|
149
|
+
if type
|
150
|
+
type.first
|
151
|
+
else
|
152
|
+
raise Error, "Unable to infer data type"
|
153
|
+
end
|
154
|
+
elsif value.empty?
|
155
|
+
raise Error, "Unable to infer data type"
|
156
|
+
elsif value.all? { |v| v.is_a?(String) }
|
79
157
|
:string
|
80
|
-
elsif value.all? { |v| v
|
158
|
+
elsif value.all? { |v| v.is_a?(TrueClass) || v.is_a?(FalseClass) }
|
81
159
|
:bool
|
82
160
|
elsif value.all? { |v| v.is_a?(Integer) }
|
83
161
|
if value.all? { |v| v >= -2147483648 && v <= 2147483647 }
|
@@ -94,47 +172,14 @@ module TensorFlow
|
|
94
172
|
end
|
95
173
|
end
|
96
174
|
|
97
|
-
def
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
path = "#{datasets_dir}/#{path}"
|
104
|
-
Utils.download_file(url, path) unless File.exist?(path)
|
105
|
-
Npy.load_npz(path)
|
106
|
-
end
|
107
|
-
|
108
|
-
def download_file(url, dest)
|
109
|
-
uri = URI(url)
|
110
|
-
|
111
|
-
temp_dir ||= File.dirname(Tempfile.new("tensorflow"))
|
112
|
-
temp_path = "#{temp_dir}/#{Time.now.to_f}" # TODO better name
|
113
|
-
|
114
|
-
# Net::HTTP automatically adds Accept-Encoding for compression
|
115
|
-
# of response bodies and automatically decompresses gzip
|
116
|
-
# and deflate responses unless a Range header was sent.
|
117
|
-
# https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
|
118
|
-
Net::HTTP.start(uri.host, uri.port, use_ssl: true) do |http|
|
119
|
-
request = Net::HTTP::Get.new(uri)
|
120
|
-
|
121
|
-
print("Downloading dataset")
|
122
|
-
i = 0
|
123
|
-
File.open(temp_path, "wb") do |f|
|
124
|
-
http.request(request) do |response|
|
125
|
-
response.read_body do |chunk|
|
126
|
-
f.write(chunk)
|
127
|
-
|
128
|
-
# print progress
|
129
|
-
putc "." if i % 50 == 0
|
130
|
-
i += 1
|
131
|
-
end
|
132
|
-
end
|
133
|
-
puts # newline
|
175
|
+
def to_tensor_array(values)
|
176
|
+
values.map do |v|
|
177
|
+
if v.is_a?(Tensor)
|
178
|
+
v
|
179
|
+
else
|
180
|
+
TensorFlow.convert_to_tensor(v)
|
134
181
|
end
|
135
182
|
end
|
136
|
-
|
137
|
-
FileUtils.mv(temp_path, dest)
|
138
183
|
end
|
139
184
|
end
|
140
185
|
end
|