TensorFrost 0.7.0.dev5__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
TensorFrost/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .TensorFrost import *
+
+ from . import optimizers
+ from . import regularizers
+ from . import clipping
+ from . import random
+ from . import sort
+ from .default import *
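After installation the top-level package therefore exposes the native module's symbols directly, with the pure-Python helpers namespaced under it. A minimal import sketch (assuming the wheel is installed under the name TensorFrost):

    import TensorFrost as tf

    tf.optimizers    # ModuleOptimizer plus the adam/sgd/rmsprop factories
    tf.random        # randn, rand, permutation, ...
    tf.sort          # bitonic and radix sorts
    tf.eye           # re-exported from default.py via the star import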
TensorFrost/clipping.py ADDED
@@ -0,0 +1,5 @@
+ from .optimizers import *
+
+ clamp = ModuleOptimizer.ClippingType.Clamp
+ norm = ModuleOptimizer.ClippingType.Norm
+ none = ModuleOptimizer.ClippingType.None_
TensorFrost/default.py ADDED
@@ -0,0 +1,14 @@
+ from . import TensorFrost as tf
+
+ def zeros_like(tensor):
+     return tf.zeros(tensor.shape, tensor.type)
+
+ def eye(n):
+     i, j = tf.indices([n, n])
+     return tf.select(i == j, 1.0, 0.0)
+
+ def eye_like(tensor):
+     return eye(tensor.shape[0])
+
+ def ones_like(tensor):
+     return tf.ones(tensor.shape, tensor.type)
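These helpers mirror NumPy-style constructors on top of the native tf.zeros/tf.ones/tf.indices primitives, and the star import in __init__.py re-exports them into the top-level namespace. A small sketch (meant to run inside a traced TensorFrost program; setup is omitted):

    import TensorFrost as tf

    I = tf.eye(4)            # 4x4 identity built from the i == j selection
    Z = tf.zeros_like(I)     # zeros with I's shape and dtype
    O = tf.ones_like(I)      # ones with I's shape and dtype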
Binary file
TensorFrost/optimizers.py ADDED
@@ -0,0 +1,219 @@
+ from . import TensorFrost as tf
+
+ class ModuleOptimizer(tf.Module):
+     class OptimizerType:
+         ADAM = 0
+         SGD = 1
+         RMSProp = 2
+
+     class RegularizerType:
+         None_ = 0
+         L1 = 1
+         L2 = 2
+
+     class ClippingType:
+         Clamp = 0
+         Norm = 1
+         None_ = 2
+
+     def __init__(self, optimizer_type, regularizer_type, net, params):
+         super().__init__()
+         self.optimizer_type = optimizer_type
+         self.regularizer_type = regularizer_type
+         self.clipping_type = self.ClippingType.Clamp
+         self.epsilon = 1e-8
+
+         # Set passed parameters (learning_rate, beta1, grad_clip, ...) as attributes
+         self.net = net
+         for k, v in params.items():
+             setattr(self, k, v)
+
+         # Initialize the step counter t
+         t = tf.Parameter([1], tf.float32, False)  # mimic Parameter({1}, TFType::Float, false)
+         self.t = t
+
+         self.initializeOptimizer(net)
+
+     def set_clipping_type(self, ctype):
+         self.clipping_type = ctype
+
+     def initializeOptimizer(self, net):
+         net_params = net.parameters()
+         requires_grads = net.requires_grads_list()
+
+         if self.optimizer_type == self.OptimizerType.ADAM:
+             self.initializeParameterArray("m", net_params, requires_grads)
+             self.initializeParameterArray("v", net_params, requires_grads)
+         elif self.optimizer_type == self.OptimizerType.SGD:
+             # No additional optimizer state needed
+             pass
+         elif self.optimizer_type == self.OptimizerType.RMSProp:
+             self.initializeParameterArray("v", net_params, requires_grads)
+
+     def initializeParameterArray(self, name, net_params, requires_grads):
+         arr = tf.ParameterArray()
+
+         for i, param in enumerate(net_params):
+             if not requires_grads[i]:
+                 continue
+
+             new_param = tf.Parameter(param.shape, tf.float32, False)
+             arr[i] = new_param
+
+         setattr(self, name, arr)
+
+     def assert_parameters(self):
+         net_params = self.net.parameters()
+         requires_grads = self.net.requires_grads_list()
+         self.assertParameterArray("m", net_params, requires_grads)
+         self.assertParameterArray("v", net_params, requires_grads)
+
+     def gradient_norm(self, grad):
+         # L2 norm: reduce the sum of squares over every dimension, then take the root
+         g = grad * grad
+         num_dims = len(grad.shape)
+         for i in range(num_dims):
+             g = tf.sum(g)
+         return tf.sqrt(g)
+
+     def assertParameterArray(self, name, net_params, requires_grads):
+         if hasattr(self, name):
+             arr = getattr(self, name)
+             for i, param in enumerate(net_params):
+                 if not requires_grads[i]:
+                     continue
+                 arr_item = arr[i]
+                 arr_item = tf.assert_tensor(arr_item, param.shape, param.type)
+                 arr[i] = arr_item
+
+     def step(self, *args):
+         # Overloaded step: step(X, Y) computes the loss first, step(loss) takes it precomputed
+         if len(args) == 2:
+             X, Y = args
+             loss = self.net.loss(X, Y)
+             self._step(loss)
+             return loss
+         elif len(args) == 1:
+             (loss,) = args
+             self._step(loss)
+         else:
+             raise ValueError("Invalid arguments to step")
+
+     def _step(self, loss):
+         # Increment the step counter by 1
+         self.t = self.t + 1.0
+
+         net = self.net
+         net_params = net.parameters()
+         requires_grads = net.requires_grads_list()
+
+         learning_rate = self.learning_rate
+         grad_clip = self.grad_clip
+         # Clipping is active only for a positive float grad_clip
+         has_clip = isinstance(grad_clip, float) and grad_clip > 0.0
+
+         for i, param in enumerate(net_params):
+             if not requires_grads[i]:
+                 continue
+
+             grad = tf.grad(loss, param)
+             if has_clip:
+                 if self.clipping_type == self.ClippingType.Clamp:
+                     grad = tf.clamp(grad, -grad_clip, grad_clip)
+                 elif self.clipping_type == self.ClippingType.Norm:
+                     grad_norm = tf.max(1e-6, self.gradient_norm(grad))
+                     grad = grad * tf.min(1.0, grad_clip / grad_norm)
+
+             if self.optimizer_type == self.OptimizerType.ADAM:
+                 update = self.adam_update(i, param, grad, self.t, learning_rate)
+             elif self.optimizer_type == self.OptimizerType.SGD:
+                 update = self.sgd_update(param, grad, learning_rate)
+             elif self.optimizer_type == self.OptimizerType.RMSProp:
+                 update = self.rmsprop_update(i, param, grad, learning_rate)
+             else:
+                 raise RuntimeError("Unknown optimizer type")
+
+             # Apply weight regularization if enabled
+             if self.regularizer_type == self.RegularizerType.L1:
+                 param = param - learning_rate * self.reg * tf.sign(param)
+             elif self.regularizer_type == self.RegularizerType.L2:
+                 param = param - learning_rate * self.reg * param
+
+             # Apply the optimizer update
+             param = param - update
+             net_params[i] = param
+
+         net.update_parameters(net_params)
+
+     def adam_update(self, i, param, grad, t, learning_rate):
+         beta1 = tf.float(self.beta1)
+         beta2 = tf.float(self.beta2)
+
+         m = self.m[i]
+         v = self.v[i]
+
+         # lerp(a, b, t) = a + (b - a) * t, so these are the usual Adam moving averages
+         m = tf.lerp(grad, m, beta1)
+         v = tf.lerp(grad * grad, v, beta2)
+
+         # t is a Parameter with shape [1]; get the scalar for bias correction
+         t_val = self.t[0]
+         mhat = m / (1.0 - tf.pow(beta1, t_val))
+         vhat = v / (1.0 - tf.pow(beta2, t_val))
+
+         self.m[i] = m
+         self.v[i] = v
+
+         return learning_rate * mhat / (tf.sqrt(vhat) + self.epsilon)
+
+     def sgd_update(self, param, grad, learning_rate):
+         return learning_rate * grad
+
+     def rmsprop_update(self, i, param, grad, learning_rate):
+         decay = tf.float(self.decay)
+
+         v = self.v[i]
+         v = tf.lerp(grad * grad, v, decay)
+         self.v[i] = v
+
+         return (grad * learning_rate) / (tf.sqrt(v) + self.epsilon)
+
+
+ def adam(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, beta1=0.9, beta2=0.999, clip=0.0, reg=0.0):
+     return ModuleOptimizer(
+         ModuleOptimizer.OptimizerType.ADAM,
+         reg_type,
+         net,
+         {
+             "learning_rate": learning_rate,
+             "beta1": beta1,
+             "beta2": beta2,
+             "grad_clip": clip,
+             "reg": reg,
+         }
+     )
+
+ def sgd(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, clip=0.0, reg=0.0):
+     return ModuleOptimizer(
+         ModuleOptimizer.OptimizerType.SGD,
+         reg_type,
+         net,
+         {
+             "learning_rate": learning_rate,
+             "grad_clip": clip,
+             "reg": reg,
+         }
+     )
+
+ def rmsprop(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, decay=0.9, clip=0.0, reg=0.0):
+     return ModuleOptimizer(
+         ModuleOptimizer.OptimizerType.RMSProp,
+         reg_type,
+         net,
+         {
+             "learning_rate": learning_rate,
+             "decay": decay,
+             "grad_clip": clip,
+             "reg": reg,
+         }
+     )
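Note that tf.lerp(a, b, t) = a + (b - a) * t, so tf.lerp(grad, m, beta1) is the familiar Adam moving average m = beta1 * m + (1 - beta1) * grad, and the divisions by 1 - beta^t are the standard bias corrections. A hedged usage sketch (net, X and Y are assumed to exist; only the factory and step come from the code above):

    import TensorFrost as tf

    # net: any tf.Module with a loss(X, Y) method and trainable parameters
    opt = tf.optimizers.adam(net, learning_rate=0.001, clip=1.0)
    opt.set_clipping_type(tf.clipping.norm)

    # One training step: computes net.loss(X, Y), then applies the Adam update
    loss = opt.step(X, Y)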
TensorFrost/random.py ADDED
@@ -0,0 +1,45 @@
+ from . import TensorFrost as tf
+
+ def randn2(shape, seed=0):
+     # Box-Muller transform: two uniforms -> two independent standard normals
+     r1 = tf.random_value(shape, seed=seed)
+     r2 = tf.random_value(shape, seed=tf.hash(seed))
+     rho = tf.sqrt(-2.0 * tf.log(tf.max(1e-6, r1)))
+     theta = 2.0 * tf.pi * r2
+     return rho * tf.cos(theta), rho * tf.sin(theta)
+
+ def randn(shape, seed=0):
+     return randn2(shape, seed=seed)[0]
+
+ def rand(shape, seed=0):
+     return tf.random_value(shape, seed=seed)
+
+ def randn_like(tensor, seed=0):
+     return randn(tensor.shape, seed=seed)
+
+ def rand_like(tensor, seed=0):
+     return rand(tensor.shape, seed=seed)
+
+ def rand_int(seed, max_value):
+     return tf.int(tf.pcg(tf.uint(seed)) % tf.uint(max_value))
+
+ def xor_swap(idx, n, seed):
+     xor_seed = rand_int(seed, n)
+     xor_idx = idx ^ xor_seed
+     max_idx = tf.max(idx, xor_idx)
+     min_idx = tf.min(idx, xor_idx)
+     swap = rand_int(min_idx * 451 + seed, 2) == 0
+     return tf.select(swap & (max_idx < n), xor_idx, idx)
+
+ def reverse(idx, n):
+     return n - 1 - idx
+
+ def shuffle(idx, n, seed=0, iters=16):
+     for i in range(iters):
+         idx = xor_swap(idx, n, seed + i)
+         idx = reverse(idx, n)
+     return idx
+
+ def permutation(n, seed=0):
+     idx = tf.indices([n])[0]
+     return shuffle(idx, n, seed)
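randn2 is a direct Box-Muller transform: rho = sqrt(-2 ln r1) and theta = 2 pi r2 turn two independent uniforms into two independent standard normals, while shuffle approximates a random permutation by alternating seeded XOR swaps with index reversal. A short sketch (meant to run inside a traced TensorFrost program):

    import TensorFrost as tf

    g = tf.random.randn([128, 128], seed=42)   # standard normals via Box-Muller
    u = tf.random.rand([128, 128], seed=7)     # uniform values from tf.random_value
    p = tf.random.permutation(1024, seed=3)    # pseudo-random permutation of 0..1023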
TensorFrost/regularizers.py ADDED
@@ -0,0 +1,5 @@
+ from .optimizers import *
+
+ l1 = ModuleOptimizer.RegularizerType.L1
+ l2 = ModuleOptimizer.RegularizerType.L2
+ none = ModuleOptimizer.RegularizerType.None_
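As with the clipping aliases, these name the RegularizerType values for use with the optimizer factories. Sketch (net assumed as before):

    import TensorFrost as tf

    # L2 (weight-decay style) penalty with strength 1e-4; tf.regularizers.l1
    # selects L1, and tf.regularizers.none disables the penalty term
    opt = tf.optimizers.sgd(net, reg_type=tf.regularizers.l2, reg=1e-4)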
TensorFrost/sort.py ADDED
@@ -0,0 +1,187 @@
+ from . import TensorFrost as tf
+
+ # Bitonic sort; sorts a copy of the inputs in place and returns it
+ def bitonic(keys, values=None):
+     tf.region_begin('Bitonic sort')
+     keys = tf.copy(keys)
+     if values is not None:
+         values = tf.copy(values)
+     element_count = keys.shape[0]
+     log2_count = tf.int(tf.ceil(tf.log2(tf.float(element_count))))
+     count_round = 1 << log2_count
+     idx = tf.indices([count_round / 2])[0]
+     with tf.loop(log2_count) as k:
+         with tf.loop(k + 1) as j:
+             s = 1 << (k - j)
+             m_inner = s - 1
+             m_outer = ~m_inner
+             m_xor = s + tf.select(j == 0, m_inner, 0)
+
+             id1 = 2 * (idx & m_outer) + (idx & m_inner)
+             id2 = id1 ^ m_xor
+             key1, key2 = keys[id1], keys[id2]
+             with tf.if_cond((key1 >= key2) & (id1 < element_count) & (id2 < element_count)):
+                 if values is not None:
+                     val1, val2 = values[id1], values[id2]
+                     values[id1] = val2
+                     values[id2] = val1
+                 keys[id1] = key2
+                 keys[id2] = key1
+
+     tf.region_end('Bitonic sort')
+     if values is not None:
+         return keys, values
+     else:
+         return keys
+
+ # Histogram-based (LSD) radix sort
+ def radix(keys, values=None, bits_per_pass=6, max_bits=32):
+     def prefix_sum_grouped(A, axis=-1):
+         axis = len(A.shape) + axis if axis < 0 else axis
+         group_size = 64
+         grouped = tf.split_dim(A, group_size, axis)
+         group_scan = tf.prefix_sum(tf.sum(grouped, axis=axis + 1), axis=axis)
+         ids = grouped.indices
+         gid, eid = ids[axis], ids[axis + 1]
+         ids = [ids[i] for i in range(len(ids)) if i != axis + 1]
+         ids[axis] = gid - 1
+         group_scan = tf.prefix_sum(grouped + tf.select((gid == 0) | (eid != 0), tf.uint(0), group_scan[tuple(ids)]), axis=axis + 1)
+         full_scan = tf.merge_dim(group_scan, target_size=A.shape[axis], axis=axis + 1)
+         return full_scan
+
+     sign_bit = ~tf.uint(0x7FFFFFFF)
+
+     def map_float_to_uint(x):
+         # Convert float to an order-preserving uint representation
+         ux = tf.asuint(x)
+         # Negative floats get all bits flipped, non-negative ones only the sign bit
+         mask = tf.select((ux >> 31) == 1, ~tf.uint(0), sign_bit)
+         return ux ^ mask
+
+     def map_uint_to_float(x):
+         # Invert map_float_to_uint
+         mask = tf.select((x >> 31) == 0, ~tf.uint(0), sign_bit)
+         return tf.asfloat(x ^ mask)
+
+     def map_int_to_uint(x):
+         return tf.asuint(x) ^ sign_bit
+
+     def map_uint_to_int(x):
+         return tf.asint(x ^ sign_bit)
+
+     tf.region_begin('Radix sort')
+
+     has_values = values is not None
+
+     keys = tf.copy(keys)
+     if has_values:
+         values = tf.copy(values)
+
+     original_type = keys.type
+     if original_type == tf.float32:
+         keys = map_float_to_uint(keys)
+
+     if original_type == tf.int32:
+         keys = map_int_to_uint(keys)
+
+     iters = (max_bits + bits_per_pass - 1) // bits_per_pass
+     group_size = 128
+     histogram_size = 2 ** bits_per_pass
+
+     def GetBits(A, i):
+         return (A >> (i * bits_per_pass)) & tf.uint(histogram_size - 1)
+
+     keys1 = tf.buffer(keys.shape, keys.type)
+     values1 = None
+
+     if has_values:
+         values1 = tf.buffer(values.shape, values.type)
+
+     with tf.loop(iters // 2) as iter:
+         def SortIteration(keys_in, keys_out, values_in, values_out, iter):
+             tf.region_begin('Radix sort iteration')
+             grouped = tf.split_dim(GetBits(keys_in, iter), group_size)
+
+             # Packed histogram: since only 128 elements are summed at a time,
+             # four 8-bit counters fit into a single uint32
+             g, e, i = tf.indices([grouped.shape[0], grouped.shape[1], tf.int(histogram_size / 4)])
+             this_key = grouped[g, e]
+             packed_is_bit = (tf.uint(this_key == tf.uint(4 * i))) + (tf.uint(this_key == tf.uint(4 * i + 1)) << 8) + (tf.uint(this_key == tf.uint(4 * i + 2)) << 16) + (tf.uint(this_key == tf.uint(4 * i + 3)) << 24)
+             packed_is_bit = tf.select((g * group_size + e) < keys_in.shape[0], packed_is_bit, tf.uint(0))
+             group_histogram_packed = tf.sum(packed_is_bit, axis=1)
+
+             g, i = tf.indices([grouped.shape[0], histogram_size])
+             group_histogram = tf.uint((group_histogram_packed[g, i / 4] >> (8 * (i % 4))) & tf.uint(0xFF))
+
+             group_histogram_scan = prefix_sum_grouped(group_histogram, axis=0)
+             i, = tf.indices([histogram_size])
+             total_bit_histogram = tf.prefix_sum(group_histogram_scan[group_histogram_scan.shape[0] - 1, i])
+
+             with tf.kernel(grouped.shape, group_size=[group_size]) as (g, e):
+                 if tf.current_backend() == tf.cpu:  # group barriers are not supported on the CPU backend
+                     element = g * group_size + e
+                     with tf.if_cond(element < keys_in.shape[0]):
+                         old_key = keys_in[element]
+                         if has_values:
+                             old_val = values_in[element]
+                         bit = GetBits(old_key, iter)
+                         total_offset = tf.select(g == 0, tf.uint(0), group_histogram_scan[g - 1, bit]) + tf.select(bit == tf.uint(0), tf.uint(0), total_bit_histogram[bit - tf.uint(1)])
+                         with tf.loop(e) as j:
+                             total_offset.val += tf.uint(grouped[g, j] == bit)
+                         keys_out[total_offset] = old_key
+                         if has_values:
+                             values_out[total_offset] = old_val
+                 else:
+                     temp = tf.group_buffer(group_size, tf.uint32)
+                     half_count = tf.group_buffer(histogram_size, tf.uint32)
+                     gtid = g.block_thread_index(0)
+
+                     # Initialize the packed per-quarter counters
+                     for i in range((histogram_size + group_size - 1) // group_size):
+                         index = gtid + i * group_size
+                         with tf.if_cond(index < histogram_size):
+                             half_count[index] = 0
+                     tf.group_barrier()
+
+                     element = g * group_size + e
+                     with tf.if_cond(element < keys_in.shape[0]):
+                         old_key = keys_in[element]
+                         bit = GetBits(old_key, iter)
+                         temp[gtid] = bit
+
+                         # Count, per bit value, how many elements of the preceding
+                         # quarter-groups share it (packed 8-bit counters)
+                         quarter_index = e / (group_size // 4)
+                         with tf.if_cond(quarter_index < 3):
+                             tf.scatterAdd(half_count[bit], tf.uint(quarter_index < 1) | (tf.uint(quarter_index < 2) << 8) | (tf.uint(quarter_index < 3) << 16))
+
+                         tf.group_barrier()
+
+                         if has_values:
+                             old_val = values_in[element]
+
+                         total_offset = tf.select(g == 0, tf.uint(0), group_histogram_scan[g - 1, tf.int(bit)]) + tf.select(tf.int(bit) == 0, tf.uint(0), total_bit_histogram[tf.int(bit) - 1])
+                         total_offset += tf.select(quarter_index > 0, (half_count[bit] >> (8 * (quarter_index - 1))) & tf.uint(0xFF), tf.uint(0))
+                         begin_index = quarter_index * (group_size // 4)
+                         with tf.loop(begin_index, e) as j:
+                             total_offset.val += tf.uint(temp[j] == bit)
+                         keys_out[total_offset] = old_key
+
+                         if has_values:
+                             values_out[total_offset] = old_val
+
+             tf.region_end('Radix sort iteration')
+
+         SortIteration(keys, keys1, values, values1, 2 * iter)
+         SortIteration(keys1, keys, values1, values, 2 * iter + 1)
+
+     tf.region_end('Radix sort')
+
+     if original_type == tf.float32:
+         keys = map_uint_to_float(keys)
+
+     if original_type == tf.int32:
+         keys = map_uint_to_int(keys)
+
+     if has_values:
+         return keys, values
+     else:
+         return keys
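Both sorts copy their inputs and return the sorted result (plus reordered values when given): bitonic runs the classic O(n log^2 n) compare-and-swap network, while radix performs ceil(max_bits / bits_per_pass) counting passes over the key bits, ping-ponging between two buffers, with float and int keys first remapped to order-preserving uints. A hedged usage sketch (inside a traced TensorFrost program):

    import TensorFrost as tf

    keys = tf.random.rand([4096], seed=1)
    values = tf.random.permutation(4096, seed=2)

    sorted_keys = tf.sort.bitonic(keys)                   # keys only
    sorted_keys, reordered = tf.sort.radix(keys, values)  # key/value pairs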