TensorFrost 0.7.0.dev5__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- TensorFrost/TensorFrost.cp313-win_amd64.pyd +0 -0
- TensorFrost/__init__.py +8 -0
- TensorFrost/clipping.py +5 -0
- TensorFrost/default.py +14 -0
- TensorFrost/glad_gl_core_46.dll +0 -0
- TensorFrost/optimizers.py +219 -0
- TensorFrost/random.py +45 -0
- TensorFrost/regularizers.py +5 -0
- TensorFrost/sort.py +187 -0
- TensorFrost-0.7.0.dev5.dist-info/METADATA +912 -0
- TensorFrost-0.7.0.dev5.dist-info/RECORD +13 -0
- TensorFrost-0.7.0.dev5.dist-info/WHEEL +5 -0
- TensorFrost-0.7.0.dev5.dist-info/top_level.txt +1 -0
TensorFrost/TensorFrost.cp313-win_amd64.pyd
Binary file
TensorFrost/__init__.py
ADDED
TensorFrost/clipping.py
ADDED
TensorFrost/default.py
ADDED
@@ -0,0 +1,14 @@
+from . import TensorFrost as tf
+
+def zeros_like(tensor):
+    return tf.zeros(tensor.shape, tensor.type)
+
+def eye(n):
+    i, j = tf.indices([n, n])
+    return tf.select(i == j, 1.0, 0.0)
+
+def eye_like(tensor):
+    return eye(tensor.shape[0])
+
+def ones_like(tensor):
+    return tf.ones(tensor.shape, tensor.type)
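The helpers above are thin NumPy-style constructors over the `tf` core module. A minimal usage sketch (not from the package; it assumes a tensor `x` already exists inside a traced TensorFrost program):

    I = eye(5)          # 5x5 identity built from an index grid and select(i == j, 1.0, 0.0)
    J = eye_like(x)     # identity sized by x's leading dimension
    z = zeros_like(x)   # zeros matching x's shape and dtype
    o = ones_like(x)    # ones matching x's shape and dtype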
TensorFrost/glad_gl_core_46.dll
Binary file
TensorFrost/optimizers.py
ADDED
@@ -0,0 +1,219 @@
+from . import TensorFrost as tf
+
+class ModuleOptimizer(tf.Module):
+    class OptimizerType:
+        ADAM = 0
+        SGD = 1
+        RMSProp = 2
+
+    class RegularizerType:
+        None_ = 0
+        L1 = 1
+        L2 = 2
+
+    class ClippingType:
+        Clamp = 0
+        Norm = 1
+        None_ = 2
+
+    def __init__(self, optimizer_type, regularizer_type, net, params):
+        super().__init__()
+        self.optimizer_type = optimizer_type
+        self.regularizer_type = regularizer_type
+        self.clipping_type = self.ClippingType.Clamp
+        self.epsilon = 1e-8
+
+        # Set passed parameters as attributes
+        self.net = net
+        for k, v in params.items():
+            setattr(self, k, v)
+
+        # Initialize t
+        t = tf.Parameter([1], tf.float32, False)  # mimic Parameter({1}, TFType::Float, false)
+        self.t = t
+
+        self.initializeOptimizer(net)
+
+    def set_clipping_type(self, ctype):
+        self.clipping_type = ctype
+
+    def initializeOptimizer(self, net):
+        net_params = net.parameters()
+        requires_grads = net.requires_grads_list()
+
+        if self.optimizer_type == self.OptimizerType.ADAM:
+            self.initializeParameterArray("m", net_params, requires_grads)
+            self.initializeParameterArray("v", net_params, requires_grads)
+        elif self.optimizer_type == self.OptimizerType.SGD:
+            # No additional parameters needed
+            pass
+        elif self.optimizer_type == self.OptimizerType.RMSProp:
+            self.initializeParameterArray("v", net_params, requires_grads)
+
+    def initializeParameterArray(self, name, net_params, requires_grads):
+        arr = tf.ParameterArray()
+
+        for i, param in enumerate(net_params):
+            if not requires_grads[i]:
+                continue
+
+            new_param = tf.Parameter(param.shape, tf.float32, False)
+            arr[i] = new_param
+
+        setattr(self, name, arr)
+
+    def assert_parameters(self):
+        net_params = self.net.parameters()
+        requires_grads = self.net.requires_grads_list()
+        self.assertParameterArray("m", net_params, requires_grads)
+        self.assertParameterArray("v", net_params, requires_grads)
+
+    def gradient_norm(self, grad):
+        # sum of squares
+        g = grad * grad
+        shape = grad.shape
+        num_dims = len(shape)
+        for i in range(num_dims):
+            g = tf.sum(g)
+        return tf.sqrt(g)
+
+    def assertParameterArray(self, name, net_params, requires_grads):
+        if hasattr(self, name):
+            arr = getattr(self, name)
+            for i, param in enumerate(net_params):
+                if not requires_grads[i]:
+                    continue
+                arr_item = arr[i]
+                arr_item = tf.assert_tensor(arr_item, param.shape, param.type)
+                arr[i] = arr_item
+
+    def step(self, *args):
+        # Overloaded step:
+        # step(X, Y) or step(loss)
+        if len(args) == 2:
+            X, Y = args
+            loss = self.net.loss(X, Y)
+            self._step(loss)
+            return loss
+        elif len(args) == 1:
+            (loss,) = args
+            self._step(loss)
+        else:
+            raise ValueError("Invalid arguments to step")
+
+    def _step(self, loss):
+        # Increment t by 1
+        self.t = self.t + 1.0
+
+        net = self.net
+        net_params = net.parameters()
+        requires_grads = net.requires_grads_list()
+
+        learning_rate = self.learning_rate
+        grad_clip = self.grad_clip
+        has_clip = isinstance(grad_clip, float) and grad_clip > 0.0
+
+        for i, param in enumerate(net_params):
+            if not requires_grads[i]:
+                continue
+
+            grad = tf.grad(loss, param)
+            if has_clip:
+                if self.clipping_type == self.ClippingType.Clamp:
+                    grad = tf.clamp(grad, -grad_clip, grad_clip)
+                elif self.clipping_type == self.ClippingType.Norm:
+                    grad_norm = tf.max(1e-6, self.gradient_norm(grad))
+                    grad = grad * tf.min(1.0, grad_clip / grad_norm)
+
+            if self.optimizer_type == self.OptimizerType.ADAM:
+                update = self.adam_update(i, param, grad, self.t, learning_rate)
+            elif self.optimizer_type == self.OptimizerType.SGD:
+                update = self.sgd_update(param, grad, learning_rate)
+            elif self.optimizer_type == self.OptimizerType.RMSProp:
+                update = self.rmsprop_update(i, param, grad, learning_rate)
+            else:
+                raise RuntimeError("Unknown optimizer type")
+
+            # Apply regularization if needed
+            if self.regularizer_type == self.RegularizerType.L1:
+                param = param - learning_rate * self.reg * tf.sign(param)
+            elif self.regularizer_type == self.RegularizerType.L2:
+                param = param - learning_rate * self.reg * param
+
+            # Update parameter with computed update
+            param = param - update
+            net_params[i] = param
+
+        net.update_parameters(net_params)
+
+    def adam_update(self, i, param, grad, t, learning_rate):
+        beta1 = tf.float(self.beta1)
+        beta2 = tf.float(self.beta2)
+
+        m = self.m[i]
+        v = self.v[i]
+
+        m = tf.lerp(grad, m, beta1)
+        v = tf.lerp(grad * grad, v, beta2)
+
+        # t is a Parameter with shape [1]; get the scalar
+        t_val = self.t[0]
+        mhat = m / (1.0 - tf.pow(beta1, t_val))
+        vhat = v / (1.0 - tf.pow(beta2, t_val))
+
+        self.m[i] = m
+        self.v[i] = v
+
+        return learning_rate * mhat / (tf.sqrt(vhat) + self.epsilon)
+
+    def sgd_update(self, param, grad, learning_rate):
+        return learning_rate * grad
+
+    def rmsprop_update(self, i, param, grad, learning_rate):
+        decay = tf.float(self.decay)
+
+        v = self.v[i]
+        v = tf.lerp(grad * grad, v, decay)
+        self.v[i] = v
+
+        return (grad * learning_rate) / (tf.sqrt(v) + self.epsilon)
+
+
+def adam(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, beta1=0.9, beta2=0.999, clip=0.0, reg=0.0):
+    return ModuleOptimizer(
+        ModuleOptimizer.OptimizerType.ADAM,
+        reg_type,
+        net,
+        {
+            "learning_rate": learning_rate,
+            "beta1": beta1,
+            "beta2": beta2,
+            "grad_clip": clip,
+            "reg": reg,
+        }
+    )
+
+def sgd(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, clip=0.0, reg=0.0):
+    return ModuleOptimizer(
+        ModuleOptimizer.OptimizerType.SGD,
+        reg_type,
+        net,
+        {
+            "learning_rate": learning_rate,
+            "grad_clip": clip,
+            "reg": reg,
+        }
+    )
+
+def rmsprop(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, decay=0.9, clip=0.0, reg=0.0):
+    return ModuleOptimizer(
+        ModuleOptimizer.OptimizerType.RMSProp,
+        reg_type,
+        net,
+        {
+            "learning_rate": learning_rate,
+            "decay": decay,
+            "grad_clip": clip,
+            "reg": reg,
+        }
+    )
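A note on the update math: `tf.lerp(grad, m, beta1)` evaluates to `beta1 * m + (1 - beta1) * grad`, so `m` and `v` are the standard Adam exponential moving averages, bias-corrected by `1 - beta**t` before the step. A hedged usage sketch of the factory functions (not from the package; `MyNet`, `X`, and `Y` are assumed to come from user code, with `MyNet` a `tf.Module` exposing `loss(X, Y)` as `step()` expects):

    net = MyNet()
    opt = adam(net, learning_rate=1e-3, clip=0.1)             # clamp each gradient to [-0.1, 0.1]
    opt.set_clipping_type(ModuleOptimizer.ClippingType.Norm)  # or rescale by the global norm instead

    loss = opt.step(X, Y)  # two-argument overload: computes net.loss(X, Y), then updates
    opt.step(loss)         # one-argument overload: reuses a precomputed loss tensor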
TensorFrost/random.py
ADDED
@@ -0,0 +1,45 @@
+from . import TensorFrost as tf
+
+def randn2(shape, seed=0):
+    #Box-Muller transform
+    r1 = tf.random_value(shape, seed=seed)
+    r2 = tf.random_value(shape, seed=tf.hash(seed))
+    rho = tf.sqrt(-2.0*tf.log(tf.max(1e-6, r1)))
+    theta = 2.0*tf.pi*r2
+    return rho*tf.cos(theta), rho*tf.sin(theta)
+
+def randn(shape, seed=0):
+    return randn2(shape, seed=seed)[0]
+
+def rand(shape, seed=0):
+    return tf.random_value(shape, seed=seed)
+
+def randn_like(tensor, seed=0):
+    return randn(tensor.shape, seed=seed)
+
+def rand_like(tensor, seed=0):
+    return rand(tensor.shape, seed=seed)
+
+def rand_int(seed, max_value):
+    return tf.int(tf.pcg(tf.uint(seed)) % tf.uint(max_value))
+
+def xor_swap(idx, n, seed):
+    xor_seed = rand_int(seed, n)
+    xor_idx = (idx ^ xor_seed)
+    max_idx = tf.max(idx, xor_idx)
+    min_idx = tf.min(idx, xor_idx)
+    swap = rand_int(min_idx * 451 + seed, 2) == 0
+    return tf.select(swap & (max_idx < n), xor_idx, idx)
+
+def reverse(idx, n):
+    return n - 1 - idx
+
+def shuffle(idx, n, seed = 0, iters = 16):
+    for i in range(iters):
+        idx = xor_swap(idx, n, seed + i)
+        idx = reverse(idx, n)
+    return idx
+
+def permutation(n, seed = 0):
+    idx = tf.indices([n])[0]
+    return shuffle(idx, n, seed)
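`randn2` draws two independent standard normals per element via the Box-Muller transform (the `tf.max(1e-6, r1)` guard keeps the log finite when r1 is 0), and `permutation` mixes indices with sixteen rounds of a seeded XOR-swap network interleaved with reversals. For reference, a plain-NumPy illustration of the same Box-Muller step (an illustration of the math only, not TensorFrost code):

    import numpy as np

    def randn2_np(shape, rng):
        r1 = rng.random(shape)
        r2 = rng.random(shape)
        rho = np.sqrt(-2.0 * np.log(np.maximum(1e-6, r1)))  # radius from the exponential tail
        theta = 2.0 * np.pi * r2                            # uniform angle on [0, 2*pi)
        return rho * np.cos(theta), rho * np.sin(theta)     # two independent N(0, 1) draws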
TensorFrost/sort.py
ADDED
@@ -0,0 +1,187 @@
+from . import TensorFrost as tf
+
+#in-place bitonic sort
+def bitonic(keys, values = None):
+    tf.region_begin('Bitonic sort')
+    keys = tf.copy(keys)
+    if values is not None:
+        values = tf.copy(values)
+    element_count = keys.shape[0]
+    log2_count = tf.int(tf.ceil(tf.log2(tf.float(element_count))))
+    count_round = 1 << log2_count
+    idx = tf.indices([count_round / 2])[0]
+    with tf.loop(log2_count) as k:
+        with tf.loop(k+1) as j:
+            s = 1 << (k-j)
+            m_inner = s - 1
+            m_outer = ~m_inner
+            m_xor = s + tf.select(j == 0, m_inner, 0)
+
+            id1 = (2 * (idx & m_outer) + (idx & m_inner))
+            id2 = id1 ^ m_xor
+            key1, key2 = keys[id1], keys[id2]
+            with tf.if_cond((key1 >= key2) & (id1 < element_count) & (id2 < element_count)):
+                if values is not None:
+                    val1, val2 = values[id1], values[id2]
+                    values[id1] = val2
+                    values[id2] = val1
+                keys[id1] = key2
+                keys[id2] = key1
+
+    tf.region_end('Bitonic sort')
+    if values is not None:
+        return keys, values
+    else:
+        return keys
+
+#histogram radix sort
+def radix(keys, values = None, bits_per_pass = 6, max_bits = 32):
+    def prefix_sum_grouped(A, axis = -1):
+        axis = len(A.shape) + axis if axis < 0 else axis
+        group_size = 64
+        grouped = tf.split_dim(A, group_size, axis)
+        group_scan = tf.prefix_sum(tf.sum(grouped, axis = axis + 1), axis = axis)
+        ids = grouped.indices
+        gid, eid = ids[axis], ids[axis + 1]
+        ids = [ids[i] for i in range(len(ids)) if i != axis + 1]
+        ids[axis] = gid - 1
+        group_scan = tf.prefix_sum(grouped + tf.select((gid == 0) | (eid != 0), tf.uint(0), group_scan[tuple(ids)]), axis = axis + 1)
+        full_scan = tf.merge_dim(group_scan, target_size = A.shape[axis], axis = axis + 1)
+        return full_scan
+
+    sign_bit = ~tf.uint(0x7FFFFFFF)
+
+    def map_float_to_uint(x):
+        # Convert float to uint representation
+        ux = tf.asuint(x)
+        # Compute mask
+        mask = tf.select((ux >> 31) == 1, ~tf.uint(0), sign_bit)
+        # Apply XOR
+        return ux ^ mask
+
+    def map_uint_to_float(x):
+        # Compute mask
+        mask = tf.select((x >> 31) == 0, ~tf.uint(0), sign_bit)
+        # Apply XOR and convert back to float
+        return tf.asfloat(x ^ mask)
+
+    def map_int_to_uint(x):
+        return tf.asuint(x) ^ sign_bit
+
+    def map_uint_to_int(x):
+        return tf.asint(x ^ sign_bit)
+
+    tf.region_begin('Radix sort')
+
+    has_values = values is not None
+
+    keys = tf.copy(keys)
+    if has_values:
+        values = tf.copy(values)
+
+    original_type = keys.type
+    if(original_type == tf.float32):
+        keys = map_float_to_uint(keys)
+
+    if(original_type == tf.int32):
+        keys = map_int_to_uint(keys)
+
+    iters = (max_bits + bits_per_pass - 1) // bits_per_pass
+    group_size = 128
+    histogram_size = 2 ** bits_per_pass
+
+    def GetBits(A, i):
+        return (A >> (i * bits_per_pass)) & tf.uint(histogram_size - 1)
+
+    keys1 = tf.buffer(keys.shape, keys.type)
+    values1 = None
+
+    if has_values:
+        values1 = tf.buffer(values.shape, values.type)
+
+    with tf.loop(iters // 2) as iter:
+        def SortIteration(keys_in, keys_out, values_in, values_out, iter):
+            tf.region_begin('Radix sort iteration')
+            grouped = tf.split_dim(GetBits(keys_in, iter), group_size)
+
+            # Do a packed histogram, since we sum 128 elements at a time, we can pack 4 values into a single uint32
+            g, e, i = tf.indices([grouped.shape[0], grouped.shape[1], tf.int(histogram_size/4)])
+            this_key = grouped[g, e]
+            packed_is_bit = (tf.uint(this_key == tf.uint(4*i))) + (tf.uint(this_key == tf.uint(4*i+1)) << 8) + (tf.uint(this_key == tf.uint(4*i+2)) << 16) + (tf.uint(this_key == tf.uint(4*i+3)) << 24)
+            packed_is_bit = tf.select((g*group_size + e) < keys_in.shape[0], packed_is_bit, tf.uint(0))
+            group_histogram_packed = tf.sum(packed_is_bit, axis = 1)
+
+            g, i = tf.indices([grouped.shape[0], histogram_size])
+            group_histogram = tf.uint((group_histogram_packed[g, i / 4] >> (8*(i % 4))) & tf.uint(0xFF))
+
+            group_histogram_scan = prefix_sum_grouped(group_histogram, axis = 0)
+            i, = tf.indices([histogram_size])
+            total_bit_histogram = tf.prefix_sum(group_histogram_scan[group_histogram_scan.shape[0] - 1, i])
+
+            with tf.kernel(grouped.shape, group_size=[group_size]) as (g, e):
+                if(tf.current_backend() == tf.cpu): #dont use group barriers on CPU - doesn't work
+                    element = g * group_size + e
+                    with tf.if_cond(element < keys_in.shape[0]):
+                        old_key = keys_in[element]
+                        old_val = values_in[element]
+                        bit = GetBits(old_key, iter)
+                        total_offset = tf.select(g == 0, tf.uint(0), group_histogram_scan[g - 1, bit]) + tf.select(bit == tf.uint(0), tf.uint(0), total_bit_histogram[bit - tf.uint(1)])
+                        with tf.loop(e) as j:
+                            total_offset.val += tf.uint(grouped[g, j] == bit)
+                        keys_out[total_offset] = old_key
+                        values_out[total_offset] = old_val
+                else:
+                    temp = tf.group_buffer(group_size, tf.uint32)
+                    half_count = tf.group_buffer(histogram_size, tf.uint32)
+                    gtid = g.block_thread_index(0)
+
+                    #initialize counters
+                    for i in range((histogram_size + group_size - 1) // group_size):
+                        index = gtid + i * group_size
+                        with tf.if_cond(index < histogram_size):
+                            half_count[index] = 0
+                    tf.group_barrier()
+
+                    element = g * group_size + e
+                    with tf.if_cond(element < keys_in.shape[0]):
+                        old_key = keys_in[element]
+                        bit = GetBits(old_key, iter)
+                        temp[gtid] = bit
+
+                        #count number of bits set in previous sub groups
+                        quarter_index = e / (group_size // 4)
+                        with tf.if_cond(quarter_index < 3):
+                            tf.scatterAdd(half_count[bit], tf.uint(quarter_index < 1) | (tf.uint(quarter_index < 2) << 8) | (tf.uint(quarter_index < 3) << 16))
+
+                        tf.group_barrier()
+
+                        if has_values:
+                            old_val = values_in[element]
+
+                        total_offset = tf.select(g == 0, tf.uint(0), group_histogram_scan[g - 1, tf.int(bit)]) + tf.select(tf.int(bit) == 0, tf.uint(0), total_bit_histogram[tf.int(bit) - 1])
+                        total_offset += tf.select(quarter_index > 0, (half_count[bit] >> (8*(quarter_index-1))) & tf.uint(0xFF), tf.uint(0))
+                        begin_index = quarter_index * (group_size // 4)
+                        with tf.loop(begin_index, e) as j:
+                            total_offset.val += tf.uint(temp[j] == bit)
+                        keys_out[total_offset] = old_key
+
+                        if has_values:
+                            values_out[total_offset] = old_val
+
+            tf.region_end('Radix sort iteration')
+
+        SortIteration(keys, keys1, values, values1, 2 * iter)
+        SortIteration(keys1, keys, values1, values, 2 * iter + 1)
+
+    tf.region_end('Radix sort')
+
+    if(original_type == tf.float32):
+        keys = map_uint_to_float(keys)
+
+    if(original_type == tf.int32):
+        keys = map_uint_to_int(keys)
+
+    if has_values:
+        return keys, values
+    else:
+        return keys
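Both sorters return fresh tensors: `bitonic(keys[, values])` pads to the next power of two and runs the classic O(n*log^2 n) compare-swap network, while `radix(...)` performs max_bits / bits_per_pass counting passes (rounded up), ping-ponging between `keys` and the scratch buffer `keys1`. Note that the CPU kernel path reads `values_in[element]` unconditionally, so it effectively assumes values are provided on that backend. Float and int keys are first remapped to order-preserving uints: negatives get every bit flipped, non-negatives only the sign bit, so unsigned comparison of the mapped keys matches the original ordering. A plain-NumPy illustration of that key mapping (for reference only, not TensorFrost code):

    import numpy as np

    def map_float_to_uint_np(x):
        ux = x.astype(np.float32).view(np.uint32)
        # all bits flipped for negatives, sign bit flipped for non-negatives
        mask = np.where(ux >> 31 == 1, np.uint32(0xFFFFFFFF), np.uint32(0x80000000))
        return ux ^ mask

    a = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
    assert list(np.argsort(map_float_to_uint_np(a))) == list(np.argsort(a))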