grx-tensor 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +54 -0
- data/LICENSE.txt +21 -0
- data/README.md +471 -0
- data/ext/grx/extconf.rb +31 -0
- data/ext/grx/grx_core.c +534 -0
- data/ext/grx/grx_core.h +85 -0
- data/ext/unix/Makefile +66 -0
- data/ext/windows/Makefile.mingw +50 -0
- data/grx-tensor.gemspec +88 -0
- data/lib/grx/c_api.rb +96 -0
- data/lib/grx/errors.rb +8 -0
- data/lib/grx/loss.rb +81 -0
- data/lib/grx/nn.rb +262 -0
- data/lib/grx/optim.rb +121 -0
- data/lib/grx/storage.rb +85 -0
- data/lib/grx/tensor.rb +623 -0
- data/lib/grx/version.rb +5 -0
- data/lib/grx.rb +49 -0
- metadata +159 -0
data/ext/grx/grx_core.c
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* grx_core.c — Núcleo C de GRX
|
|
3
|
+
* =============================================================
|
|
4
|
+
* Optimizaciones activas:
|
|
5
|
+
* - AVX2 + FMA: 4 doubles/ciclo con multiply-add fusionado
|
|
6
|
+
* - Loop unrolling x2: mayor ILP (Instruction Level Parallelism)
|
|
7
|
+
* - restrict: elimina alias analysis, permite más vectorización auto
|
|
8
|
+
* - Memoria alineada 32 bytes: habilita _mm256_load_pd (más rápido que loadu)
|
|
9
|
+
* - matmul con tiling: respeta líneas de caché L1 (64 bytes = 8 doubles)
|
|
10
|
+
* - Adam con FMA: beta*m + (1-beta)*grad en una pasada
|
|
11
|
+
* =============================================================
|
|
12
|
+
*/
|
|
13
|
+
|
|
14
|
+
#define _USE_MATH_DEFINES /* M_PI en Windows/MSVC */
|
|
15
|
+
#define _POSIX_C_SOURCE 200809L /* posix_memalign, M_PI en glibc */
|
|
16
|
+
#include "grx_core.h"
|
|
17
|
+
#include <stdlib.h>
|
|
18
|
+
#include <stdint.h>
|
|
19
|
+
#include <string.h>
|
|
20
|
+
#include <math.h>
|
|
21
|
+
#include <float.h>
|
|
22
|
+
#include <time.h>
|
|
23
|
+
|
|
24
|
+
#ifndef M_PI
|
|
25
|
+
#define M_PI 3.14159265358979323846
|
|
26
|
+
#endif
|
|
27
|
+
|
|
28
|
+
#if defined(__AVX2__) && defined(__FMA__)
|
|
29
|
+
#include <immintrin.h>
|
|
30
|
+
#define GRX_AVX2_FMA 1
|
|
31
|
+
#elif defined(__AVX2__)
|
|
32
|
+
#include <immintrin.h>
|
|
33
|
+
#define GRX_AVX2 1
|
|
34
|
+
#elif defined(__SSE2__)
|
|
35
|
+
#include <emmintrin.h>
|
|
36
|
+
#define GRX_SSE2 1
|
|
37
|
+
#endif
|
|
38
|
+
|
|
39
|
+
#define TILE 8
|
|
40
|
+
|
|
41
|
+
/* ================================================================
|
|
42
|
+
* MEMORIA
|
|
43
|
+
* ================================================================ */
|
|
44
|
+
|
|
45
|
+
/* Allocate a 32-byte-aligned buffer of n doubles (alignment enables the
 * aligned AVX2 load/store path). Returns NULL on n == 0, on size
 * overflow, or on allocation failure. Release with grx_free(), never
 * with plain free() on Windows. */
GRX_API double* grx_alloc(size_t n) {
    if (__builtin_expect(n == 0, 0)) return NULL;
    /* BUGFIX: guard n * sizeof(double) against size_t overflow, which
     * would silently allocate a too-small buffer. */
    if (__builtin_expect(n > SIZE_MAX / sizeof(double), 0)) return NULL;
    void *ptr = NULL;
#if defined(_WIN32)
    ptr = _aligned_malloc(n * sizeof(double), 32);
#else
    if (posix_memalign(&ptr, 32, n * sizeof(double)) != 0) return NULL;
#endif
    return (double*)ptr;
}
|
|
55
|
+
|
|
56
|
+
/* Release a buffer obtained from grx_alloc(). Safe to call with NULL. */
GRX_API void grx_free(double *ptr) {
#if defined(_WIN32)
    /* _aligned_malloc memory must be released with _aligned_free. */
    _aligned_free(ptr);
#else
    /* posix_memalign memory is compatible with plain free(). */
    free(ptr);
#endif
}
|
|
63
|
+
|
|
64
|
+
/* ================================================================
|
|
65
|
+
* MACROS SIMD INTERNOS
|
|
66
|
+
* ================================================================ */
|
|
67
|
+
|
|
68
|
+
/* Load/store: aligned variants when AVX2+FMA is available (buffers from
 * grx_alloc are always 32-byte aligned); unaligned variants otherwise.
 * NOTE(review): the aligned path assumes callers only pass grx_alloc'd
 * memory — an unaligned pointer would fault. Confirm against c_api.rb. */
#ifdef GRX_AVX2_FMA
#define VLD(p) _mm256_load_pd(p)
#define VST(p, v) _mm256_store_pd(p, v)
#elif defined(GRX_AVX2)
#define VLD(p) _mm256_loadu_pd(p)
#define VST(p, v) _mm256_storeu_pd(p, v)
#endif

/* ================================================================
 * ELEMENT-WISE ARITHMETIC
 * ================================================================ */

/* Shared body for binary element-wise ops: an 8-wide unrolled AVX loop
 * (two independent 4-lane ops per iteration for ILP), then a 4-wide
 * loop, then a scalar tail for the remaining n % 4 elements.
 * Expects `a`, `b`, `out`, `n` and declares `i` in the caller's scope. */
#define BINOP_BODY(op_avx, op_scalar) \
    size_t i = 0; \
    for (; i + 8 <= n; i += 8) { \
        VST(out+i, op_avx(VLD(a+i), VLD(b+i))); \
        VST(out+i+4, op_avx(VLD(a+i+4), VLD(b+i+4))); \
    } \
    for (; i + 4 <= n; i += 4) VST(out+i, op_avx(VLD(a+i), VLD(b+i)));\
    for (; i < n; i++) out[i] = op_scalar(a[i], b[i]);

/* Scalar fallbacks matching the AVX ops above. */
#define SCALAR_ADD(x,y) ((x)+(y))
#define SCALAR_SUB(x,y) ((x)-(y))
#define SCALAR_MUL(x,y) ((x)*(y))
#define SCALAR_DIV(x,y) ((x)/(y))
|
|
94
|
+
|
|
95
|
+
/* out[i] = a[i] + b[i] for i in [0, n). Buffers must not overlap
 * (restrict); with aligned-AVX builds they must come from grx_alloc. */
GRX_API void grx_add(const double * restrict a, const double * restrict b,
                     double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    BINOP_BODY(_mm256_add_pd, SCALAR_ADD)
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] + b[i];
#endif
}
|
|
103
|
+
|
|
104
|
+
/* out[i] = a[i] - b[i] for i in [0, n). Buffers must not overlap. */
GRX_API void grx_sub(const double * restrict a, const double * restrict b,
                     double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    BINOP_BODY(_mm256_sub_pd, SCALAR_SUB)
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] - b[i];
#endif
}
|
|
112
|
+
|
|
113
|
+
/* out[i] = a[i] * b[i] for i in [0, n) (Hadamard product). */
GRX_API void grx_mul(const double * restrict a, const double * restrict b,
                     double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    BINOP_BODY(_mm256_mul_pd, SCALAR_MUL)
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] * b[i];
#endif
}
|
|
121
|
+
|
|
122
|
+
/* out[i] = a[i] / b[i] for i in [0, n). No zero-divisor guard:
 * division by 0.0 yields IEEE inf/NaN, never a trap. */
GRX_API void grx_div(const double * restrict a, const double * restrict b,
                     double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    BINOP_BODY(_mm256_div_pd, SCALAR_DIV)
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] / b[i];
#endif
}
|
|
130
|
+
|
|
131
|
+
/* out[i] = a[i] * s. AVX path is unrolled x2 (two independent 4-lane
 * multiplies per iteration) for instruction-level parallelism. */
GRX_API void grx_scale(const double * restrict a, double s,
                       double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    __m256d vs = _mm256_set1_pd(s);   /* broadcast s to all 4 lanes */
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        VST(out+i, _mm256_mul_pd(VLD(a+i), vs));
        VST(out+i+4, _mm256_mul_pd(VLD(a+i+4), vs));
    }
    for (; i + 4 <= n; i += 4) VST(out+i, _mm256_mul_pd(VLD(a+i), vs));
    for (; i < n; i++) out[i] = a[i] * s;   /* scalar tail */
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] * s;
#endif
}
|
|
146
|
+
|
|
147
|
+
/* out[i] = a[i] + s (broadcast scalar add). Same unrolling scheme as
 * grx_scale. */
GRX_API void grx_add_scalar(const double * restrict a, double s,
                            double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    __m256d vs = _mm256_set1_pd(s);
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        VST(out+i, _mm256_add_pd(VLD(a+i), vs));
        VST(out+i+4, _mm256_add_pd(VLD(a+i+4), vs));
    }
    for (; i + 4 <= n; i += 4) VST(out+i, _mm256_add_pd(VLD(a+i), vs));
    for (; i < n; i++) out[i] = a[i] + s;
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] + s;
#endif
}
|
|
162
|
+
|
|
163
|
+
/* out[i] = -a[i], implemented as multiplication by -1.0 so it reuses
 * grx_scale's SIMD path (exact: multiplying by -1 only flips the sign). */
GRX_API void grx_negate(const double * restrict a, double * restrict out, size_t n) {
    grx_scale(a, -1.0, out, n);
}
|
|
166
|
+
|
|
167
|
+
/* ================================================================
|
|
168
|
+
* MATEMÁTICAS ELEMENT-WISE
|
|
169
|
+
* ================================================================ */
|
|
170
|
+
|
|
171
|
+
/* out[i] = |a[i]|. The vector path clears the IEEE-754 sign bit with a
 * bitwise AND, which is exact for every double including -0.0 and NaN. */
GRX_API void grx_abs(const double * restrict a, double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    /* Mask that clears the sign bit (AND with 0x7FFFFFFFFFFFFFFF). */
    __m256d mask = _mm256_castsi256_pd(
        _mm256_set1_epi64x(0x7FFFFFFFFFFFFFFFLL));
    size_t i = 0;
    for (; i + 4 <= n; i += 4)
        VST(out+i, _mm256_and_pd(VLD(a+i), mask));
    for (; i < n; i++) out[i] = fabs(a[i]);
#else
    for (size_t i = 0; i < n; i++) out[i] = fabs(a[i]);
#endif
}
|
|
184
|
+
|
|
185
|
+
/* out[i] = sqrt(a[i]); negative inputs produce NaN (IEEE semantics). */
GRX_API void grx_sqrt(const double * restrict a, double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    size_t i = 0;
    for (; i + 4 <= n; i += 4)
        VST(out+i, _mm256_sqrt_pd(VLD(a+i)));
    for (; i < n; i++) out[i] = sqrt(a[i]);
#else
    for (size_t i = 0; i < n; i++) out[i] = sqrt(a[i]);
#endif
}
|
|
195
|
+
|
|
196
|
+
/* out[i] = a[i]^2, delegated to grx_mul(a, a, ...). Passing the same
 * pointer for both restrict-qualified inputs is valid here because both
 * are only read; restrict is violated only when an aliased object is
 * also modified. */
GRX_API void grx_square(const double * restrict a, double * restrict out, size_t n) {
    grx_mul(a, a, out, n);
}
|
|
199
|
+
|
|
200
|
+
GRX_API void grx_log(const double * restrict a, double * restrict out, size_t n) {
|
|
201
|
+
/* log no tiene intrínseco SIMD estándar; -ffast-math + -march=native
|
|
202
|
+
* permite al compilador auto-vectorizar con SVML si está disponible */
|
|
203
|
+
for (size_t i = 0; i < n; i++) out[i] = log(a[i]);
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
GRX_API void grx_exp(const double * restrict a, double * restrict out, size_t n) {
|
|
207
|
+
for (size_t i = 0; i < n; i++) out[i] = exp(a[i]);
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
GRX_API void grx_pow(const double * restrict a, double e,
|
|
211
|
+
double * restrict out, size_t n) {
|
|
212
|
+
for (size_t i = 0; i < n; i++) out[i] = pow(a[i], e);
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
/* out[i] = a[i] clamped into [lo, hi]. Caller is expected to pass
 * lo <= hi; the vector path applies max-then-min, so lo wins if not. */
GRX_API void grx_clip(const double * restrict a, double lo, double hi,
                      double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    __m256d vlo = _mm256_set1_pd(lo);
    __m256d vhi = _mm256_set1_pd(hi);
    size_t i = 0;
    for (; i + 4 <= n; i += 4)
        VST(out+i, _mm256_min_pd(_mm256_max_pd(VLD(a+i), vlo), vhi));
    for (; i < n; i++) out[i] = a[i] < lo ? lo : (a[i] > hi ? hi : a[i]);
#else
    for (size_t i = 0; i < n; i++)
        out[i] = a[i] < lo ? lo : (a[i] > hi ? hi : a[i]);
#endif
}
|
|
229
|
+
|
|
230
|
+
/* ================================================================
|
|
231
|
+
* REDUCCIONES
|
|
232
|
+
* ================================================================ */
|
|
233
|
+
|
|
234
|
+
/* Sum of a[0..n). The AVX2+FMA path keeps two independent vector
 * accumulators so back-to-back adds do not serialize on one register;
 * a scalar tail handles the last n % 4 elements. */
GRX_API double grx_sum(const double * restrict a, size_t n) {
    double acc = 0.0;
#ifdef GRX_AVX2_FMA
    __m256d v0 = _mm256_setzero_pd(), v1 = _mm256_setzero_pd();
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        v0 = _mm256_add_pd(v0, VLD(a+i));
        v1 = _mm256_add_pd(v1, VLD(a+i+4));
    }
    v0 = _mm256_add_pd(v0, v1);
    for (; i + 4 <= n; i += 4) v0 = _mm256_add_pd(v0, VLD(a+i));
    /* BUGFIX: `tmp` is an ordinary stack array with no 32-byte
     * alignment guarantee, so the previous aligned _mm256_store_pd was
     * undefined behavior (potential #GP fault). Use the unaligned
     * store, as the GRX_AVX2 branch below already does. */
    double tmp[4]; _mm256_storeu_pd(tmp, v0);
    acc = tmp[0] + tmp[1] + tmp[2] + tmp[3];
    for (; i < n; i++) acc += a[i];
#elif defined(GRX_AVX2)
    __m256d vacc = _mm256_setzero_pd();
    size_t i = 0;
    for (; i + 4 <= n; i += 4) vacc = _mm256_add_pd(vacc, VLD(a+i));
    double tmp[4]; _mm256_storeu_pd(tmp, vacc);
    acc = tmp[0] + tmp[1] + tmp[2] + tmp[3];
    for (; i < n; i++) acc += a[i];
#else
    for (size_t i = 0; i < n; i++) acc += a[i];
#endif
    return acc;
}
|
|
260
|
+
|
|
261
|
+
/* Arithmetic mean of a[0..n); defined as 0.0 for an empty array. */
GRX_API double grx_mean(const double * restrict a, size_t n) {
    if (n == 0) return 0.0;
    return grx_sum(a, n) / (double)n;
}
|
|
264
|
+
|
|
265
|
+
/* Maximum element of a[0..n); returns -DBL_MAX as the empty-input
 * sentinel (matches grx_min's DBL_MAX and keeps softmax well-defined). */
GRX_API double grx_max(const double * restrict a, size_t n) {
    if (n == 0) return -DBL_MAX;
    double best = a[0];
    for (size_t k = 1; k < n; k++)
        best = (a[k] > best) ? a[k] : best;
    return best;
}
|
|
271
|
+
|
|
272
|
+
/* Minimum element of a[0..n); returns DBL_MAX as the empty-input
 * sentinel (mirror of grx_max). */
GRX_API double grx_min(const double * restrict a, size_t n) {
    if (n == 0) return DBL_MAX;
    double best = a[0];
    for (size_t k = 1; k < n; k++)
        best = (a[k] < best) ? a[k] : best;
    return best;
}
|
|
278
|
+
|
|
279
|
+
/* ================================================================
|
|
280
|
+
* ÁLGEBRA LINEAL
|
|
281
|
+
* ================================================================ */
|
|
282
|
+
|
|
283
|
+
/* Dot product of a and b over [0, n). The FMA path uses two
 * independent fused multiply-add accumulators to hide FMA latency. */
GRX_API double grx_dot(const double * restrict a, const double * restrict b, size_t n) {
    double acc = 0.0;
#ifdef GRX_AVX2_FMA
    __m256d v0 = _mm256_setzero_pd(), v1 = _mm256_setzero_pd();
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        v0 = _mm256_fmadd_pd(VLD(a+i), VLD(b+i), v0);
        v1 = _mm256_fmadd_pd(VLD(a+i+4), VLD(b+i+4), v1);
    }
    v0 = _mm256_add_pd(v0, v1);
    for (; i + 4 <= n; i += 4) v0 = _mm256_fmadd_pd(VLD(a+i), VLD(b+i), v0);
    /* BUGFIX: `tmp` has no 32-byte alignment guarantee on the stack;
     * the aligned _mm256_store_pd here was undefined behavior. Use the
     * unaligned store (same as the GRX_AVX2 branch). */
    double tmp[4]; _mm256_storeu_pd(tmp, v0);
    acc = tmp[0] + tmp[1] + tmp[2] + tmp[3];
    for (; i < n; i++) acc += a[i] * b[i];
#elif defined(GRX_AVX2)
    __m256d vacc = _mm256_setzero_pd();
    size_t i = 0;
    for (; i + 4 <= n; i += 4)
        vacc = _mm256_add_pd(vacc, _mm256_mul_pd(VLD(a+i), VLD(b+i)));
    double tmp[4]; _mm256_storeu_pd(tmp, vacc);
    acc = tmp[0] + tmp[1] + tmp[2] + tmp[3];
    for (; i < n; i++) acc += a[i] * b[i];
#else
    for (size_t i = 0; i < n; i++) acc += a[i] * b[i];
#endif
    return acc;
}
|
|
310
|
+
|
|
311
|
+
/* Row-major matrix multiply: out[M x N] = a[M x K] * b[K x N].
 * Blocked (tiled) in all three dimensions with TILE=8 (one 64-byte L1
 * cache line of doubles) so each b tile is reused while resident. */
GRX_API void grx_matmul(const double * restrict a, const double * restrict b,
                        double * restrict out, size_t M, size_t K, size_t N) {
    /* out must start at zero: the tiled k-blocks accumulate into it. */
    memset(out, 0, M * N * sizeof(double));
    for (size_t ii = 0; ii < M; ii += TILE) {
        size_t ie = ii + TILE < M ? ii + TILE : M;   /* row-block end */
        for (size_t kk = 0; kk < K; kk += TILE) {
            size_t ke = kk + TILE < K ? kk + TILE : K;
            for (size_t jj = 0; jj < N; jj += TILE) {
                size_t je = jj + TILE < N ? jj + TILE : N;
                /* i-k-j order inside the tile: innermost j walks both
                 * out and b contiguously (unit stride). */
                for (size_t i = ii; i < ie; i++)
                    for (size_t k = kk; k < ke; k++) {
                        double aik = a[i*K+k];  /* hoisted, reused across j */
                        for (size_t j = jj; j < je; j++)
                            out[i*N+j] += aik * b[k*N+j];
                    }
            }
        }
    }
}
|
|
331
|
+
|
|
332
|
+
/* ================================================================
|
|
333
|
+
* ACTIVACIONES
|
|
334
|
+
* ================================================================ */
|
|
335
|
+
|
|
336
|
+
/* ReLU: out[i] = max(a[i], 0.0). */
GRX_API void grx_relu(const double * restrict a, double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    __m256d vz = _mm256_setzero_pd();
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        VST(out+i, _mm256_max_pd(VLD(a+i), vz));
        VST(out+i+4, _mm256_max_pd(VLD(a+i+4), vz));
    }
    for (; i + 4 <= n; i += 4) VST(out+i, _mm256_max_pd(VLD(a+i), vz));
    for (; i < n; i++) out[i] = a[i] > 0.0 ? a[i] : 0.0;
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] > 0.0 ? a[i] : 0.0;
#endif
}
|
|
350
|
+
|
|
351
|
+
/* Leaky ReLU: out[i] = a[i] when a[i] > 0, otherwise alpha * a[i]. */
GRX_API void grx_leaky_relu(const double * restrict a, double alpha,
                            double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    __m256d va = _mm256_set1_pd(alpha);
    __m256d vz = _mm256_setzero_pd();
    size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        __m256d v = VLD(a+i);
        /* blendv picks v where the v > 0 compare mask is set,
         * alpha*v elsewhere. */
        VST(out+i, _mm256_blendv_pd(_mm256_mul_pd(v, va), v,
                                    _mm256_cmp_pd(v, vz, _CMP_GT_OQ)));
    }
    for (; i < n; i++) out[i] = a[i] > 0.0 ? a[i] : alpha * a[i];
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] > 0.0 ? a[i] : alpha * a[i];
#endif
}
|
|
368
|
+
|
|
369
|
+
GRX_API void grx_tanh_act(const double * restrict a, double * restrict out, size_t n) {
|
|
370
|
+
for (size_t i = 0; i < n; i++) out[i] = tanh(a[i]);
|
|
371
|
+
}
|
|
372
|
+
|
|
373
|
+
GRX_API void grx_sigmoid(const double * restrict a, double * restrict out, size_t n) {
|
|
374
|
+
for (size_t i = 0; i < n; i++) out[i] = 1.0 / (1.0 + exp(-a[i]));
|
|
375
|
+
}
|
|
376
|
+
|
|
377
|
+
/* Softmax over a[0..n): out[i] = e^(a[i] - max) / sum_j e^(a[j] - max).
 * Subtracting the max keeps every exponent <= 0, avoiding overflow. */
GRX_API void grx_softmax(const double * restrict a, double * restrict out, size_t n) {
    double hi = grx_max(a, n);
    double total = 0.0;
    for (size_t k = 0; k < n; k++) {
        out[k] = exp(a[k] - hi);
        total += out[k];
    }
    /* One division, then n multiplies — cheaper than n divisions. */
    double scale = 1.0 / total;
    for (size_t k = 0; k < n; k++)
        out[k] *= scale;
}
|
|
384
|
+
|
|
385
|
+
/* ================================================================
|
|
386
|
+
* OPTIMIZADORES (in-place sobre parámetros)
|
|
387
|
+
* ================================================================ */
|
|
388
|
+
|
|
389
|
+
/* SGD: param[i] -= lr * grad[i], updating param in place. */
GRX_API void grx_sgd_step(double * restrict param, const double * restrict grad,
                          double lr, size_t n) {
#ifdef GRX_AVX2_FMA
    __m256d vlr = _mm256_set1_pd(lr);
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        /* param -= lr * grad via fused negated multiply-add:
         * fnmadd computes (-lr * grad) + param in one instruction. */
        VST(param+i, _mm256_fnmadd_pd(vlr, VLD(grad+i), VLD(param+i)));
        VST(param+i+4, _mm256_fnmadd_pd(vlr, VLD(grad+i+4), VLD(param+i+4)));
    }
    for (; i + 4 <= n; i += 4)
        VST(param+i, _mm256_fnmadd_pd(vlr, VLD(grad+i), VLD(param+i)));
    for (; i < n; i++) param[i] -= lr * grad[i];
#else
    for (size_t i = 0; i < n; i++) param[i] -= lr * grad[i];
#endif
}
|
|
407
|
+
|
|
408
|
+
/*
 * Adam optimizer step (Kingma & Ba 2015), all buffers updated in place:
 *   m = beta1*m + (1-beta1)*grad
 *   v = beta2*v + (1-beta2)*grad^2
 *   m_hat = m / (1 - beta1^t)
 *   v_hat = v / (1 - beta2^t)
 *   param -= lr * m_hat / (sqrt(v_hat) + eps)
 *
 * beta1t = beta1^t and beta2t = beta2^t are precomputed by the caller
 * (the Ruby side) and advanced once per step, so this function stays
 * stateless. Caller must ensure beta1t, beta2t != 1.0 (t >= 1).
 */
GRX_API void grx_adam_step(double * restrict param,
                           double * restrict m, double * restrict v,
                           const double * restrict grad,
                           double lr, double beta1, double beta2,
                           double epsilon, double beta1t, double beta2t,
                           size_t n) {
    /* Hoist the per-step scalars out of the loop. */
    double one_m_b1 = 1.0 - beta1;
    double one_m_b2 = 1.0 - beta2;
    double inv_1mb1t = 1.0 / (1.0 - beta1t);
    double inv_1mb2t = 1.0 / (1.0 - beta2t);

#ifdef GRX_AVX2_FMA
    __m256d vb1 = _mm256_set1_pd(beta1);
    __m256d vb2 = _mm256_set1_pd(beta2);
    __m256d v1mb1 = _mm256_set1_pd(one_m_b1);
    __m256d v1mb2 = _mm256_set1_pd(one_m_b2);
    __m256d vlr = _mm256_set1_pd(lr);
    __m256d veps = _mm256_set1_pd(epsilon);
    __m256d vi1b1t = _mm256_set1_pd(inv_1mb1t);
    __m256d vi2b2t = _mm256_set1_pd(inv_1mb2t);

    size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        __m256d g = VLD(grad+i);
        /* m = beta1*m + (1-beta1)*g  (one FMA) */
        __m256d mi = _mm256_fmadd_pd(vb1, VLD(m+i), _mm256_mul_pd(v1mb1, g));
        /* v = beta2*v + (1-beta2)*g^2 */
        __m256d vi = _mm256_fmadd_pd(vb2, VLD(v+i),
                                     _mm256_mul_pd(v1mb2, _mm256_mul_pd(g, g)));
        VST(m+i, mi);
        VST(v+i, vi);
        /* Bias correction: m_hat = m/(1-beta1^t), v_hat = v/(1-beta2^t) */
        __m256d mh = _mm256_mul_pd(mi, vi1b1t);
        __m256d vh = _mm256_mul_pd(vi, vi2b2t);
        /* param -= lr * m_hat / (sqrt(v_hat) + eps) */
        __m256d denom = _mm256_add_pd(_mm256_sqrt_pd(vh), veps);
        VST(param+i, _mm256_fnmadd_pd(vlr, _mm256_div_pd(mh, denom), VLD(param+i)));
    }
    /* Scalar tail: identical update for the last n % 4 elements. */
    for (; i < n; i++) {
        m[i] = beta1 * m[i] + one_m_b1 * grad[i];
        v[i] = beta2 * v[i] + one_m_b2 * grad[i] * grad[i];
        double mh = m[i] * inv_1mb1t;
        double vh = v[i] * inv_1mb2t;
        param[i] -= lr * mh / (sqrt(vh) + epsilon);
    }
#else
    for (size_t i = 0; i < n; i++) {
        m[i] = beta1 * m[i] + one_m_b1 * grad[i];
        v[i] = beta2 * v[i] + one_m_b2 * grad[i] * grad[i];
        double mh = m[i] * inv_1mb1t;
        double vh = v[i] * inv_1mb2t;
        param[i] -= lr * mh / (sqrt(vh) + epsilon);
    }
#endif
}
|
|
474
|
+
|
|
475
|
+
/* ================================================================
|
|
476
|
+
* INICIALIZACIÓN DE PESOS
|
|
477
|
+
* ================================================================ */
|
|
478
|
+
|
|
479
|
+
/* xorshift64 PRNG (non-cryptographic, fast, dependency-free).
 * Invariant: the state must never be zero — zero is xorshift's only
 * fixed point, and a zero state would emit zeros forever. */
static uint64_t grx_rng_state = 0;

/* Seed once from the wall clock XOR an ASLR-randomized address.
 * BUGFIX: this used to reseed on every weight-init call; time(NULL)
 * has one-second resolution, so layers initialized within the same
 * second replayed the identical random sequence. Now idempotent.
 * Also remaps an (unlikely) all-zero seed to a nonzero constant so
 * xorshift cannot get stuck at zero. */
static void grx_rng_seed(void) {
    if (grx_rng_state != 0) return;  /* already seeded */
    uint64_t s = (uint64_t)time(NULL) ^ (uint64_t)(uintptr_t)&grx_rng_state;
    grx_rng_state = s ? s : UINT64_C(0x9E3779B97F4A7C15);
}

/* Uniform double in [0, 1): one xorshift64 step, then the top 53 bits
 * of the state scaled by 2^-53 (full double-precision mantissa). */
static double grx_rand01(void) {
    grx_rng_state ^= grx_rng_state << 13;
    grx_rng_state ^= grx_rng_state >> 7;
    grx_rng_state ^= grx_rng_state << 17;
    return (double)(grx_rng_state >> 11) / (double)(1ULL << 53);
}
|
|
494
|
+
|
|
495
|
+
/* Box-Muller transform: converts two uniforms into a pair of
 * independent N(0,1) samples. u1 is resampled while < 1e-15 to keep
 * log(u1) finite. */
static void grx_box_muller(double *z0, double *z1) {
    double u1, u2;
    do { u1 = grx_rand01(); } while (u1 < 1e-15);
    u2 = grx_rand01();
    double r = sqrt(-2.0 * log(u1));
    *z0 = r * cos(2.0 * M_PI * u2);
    *z1 = r * sin(2.0 * M_PI * u2);
}
|
|
504
|
+
|
|
505
|
+
/* Xavier/Glorot uniform init: out[i] ~ U(-limit, limit) with
 * limit = sqrt(6 / (fan_in + fan_out)).
 * NOTE(review): assumes fan_in + fan_out > 0; a zero sum divides by
 * zero (limit becomes inf) — confirm callers always pass layer dims. */
GRX_API void grx_init_xavier_uniform(double *out, size_t n,
                                     size_t fan_in, size_t fan_out) {
    grx_rng_seed();
    double limit = sqrt(6.0 / (double)(fan_in + fan_out));
    for (size_t i = 0; i < n; i++)
        out[i] = (grx_rand01() * 2.0 - 1.0) * limit;   /* map [0,1) to [-1,1) */
}
|
|
512
|
+
|
|
513
|
+
/* He (Kaiming) normal init: out[i] ~ N(0, sqrt(2 / fan_in)^2).
 * Box-Muller yields samples in pairs; for odd n the final pair's
 * second sample is discarded.
 * NOTE(review): assumes fan_in > 0 — zero would divide by zero. */
GRX_API void grx_init_he_normal(double *out, size_t n, size_t fan_in) {
    grx_rng_seed();
    double std = sqrt(2.0 / (double)fan_in);
    size_t i = 0;
    for (; i + 1 < n; i += 2) {
        double z0, z1;
        grx_box_muller(&z0, &z1);
        out[i] = z0 * std;
        out[i+1] = z1 * std;
    }
    if (i < n) {   /* odd n: one sample left to fill */
        double z0, z1;
        grx_box_muller(&z0, &z1);
        out[i] = z0 * std;
    }
}
|
|
529
|
+
|
|
530
|
+
/* ============================================================
 * RUBY EXTENSION INIT — required by rake-compiler / mkmf.
 * Intentionally empty: the library is loaded via Fiddle, not as a
 * native Ruby extension.
 * ============================================================ */
void Init_grx_core(void) { /* no-op */ }
|
data/ext/grx/grx_core.h
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
/*
 * grx_core.h — Public API of the GRX C core
 * =============================================================
 */

#ifndef GRX_CORE_H
#define GRX_CORE_H

#include <stddef.h>

/* Export marker: dllexport on Windows, default visibility elsewhere
 * (the shared library is built with -fvisibility=hidden). */
#ifdef _WIN32
#define GRX_API __declspec(dllexport)
#else
#define GRX_API __attribute__((visibility("default")))
#endif

#ifdef __cplusplus
extern "C" {
#endif

/* ---- Aligned memory ---------------------------------------------- */
/* grx_alloc returns a 32-byte-aligned buffer of n doubles (NULL on
 * failure); release it with grx_free, never plain free on Windows. */
GRX_API double* grx_alloc(size_t n);
GRX_API void grx_free(double *ptr);

/* ---- Element-wise arithmetic ------------------------------------- */
GRX_API void grx_add (const double *a, const double *b, double *out, size_t n);
GRX_API void grx_sub (const double *a, const double *b, double *out, size_t n);
GRX_API void grx_mul (const double *a, const double *b, double *out, size_t n);
GRX_API void grx_div (const double *a, const double *b, double *out, size_t n);
GRX_API void grx_scale (const double *a, double s, double *out, size_t n);
GRX_API void grx_negate(const double *a, double *out, size_t n);
GRX_API void grx_add_scalar(const double *a, double s, double *out, size_t n);

/* ---- Element-wise math ------------------------------------------- */
GRX_API void grx_abs (const double *a, double *out, size_t n);
GRX_API void grx_sqrt (const double *a, double *out, size_t n);
GRX_API void grx_log (const double *a, double *out, size_t n);
GRX_API void grx_exp (const double *a, double *out, size_t n);
GRX_API void grx_pow (const double *a, double exp, double *out, size_t n);
GRX_API void grx_clip (const double *a, double lo, double hi, double *out, size_t n);
GRX_API void grx_square (const double *a, double *out, size_t n);

/* ---- Reductions -------------------------------------------------- */
GRX_API double grx_sum (const double *a, size_t n);
GRX_API double grx_mean(const double *a, size_t n);
GRX_API double grx_max (const double *a, size_t n);
GRX_API double grx_min (const double *a, size_t n);

/* ---- Linear algebra ---------------------------------------------- */
GRX_API double grx_dot (const double *a, const double *b, size_t n);
/* Row-major: out[M x N] = a[M x K] * b[K x N] */
GRX_API void grx_matmul (const double *a, const double *b, double *out,
                         size_t M, size_t K, size_t N);

/* ---- Activations ------------------------------------------------- */
GRX_API void grx_relu (const double *a, double *out, size_t n);
GRX_API void grx_leaky_relu (const double *a, double alpha, double *out, size_t n);
GRX_API void grx_tanh_act (const double *a, double *out, size_t n);
GRX_API void grx_sigmoid (const double *a, double *out, size_t n);
GRX_API void grx_softmax (const double *a, double *out, size_t n);

/* ---- Optimizers (in-place) --------------------------------------- */
/* SGD: param -= lr * grad */
GRX_API void grx_sgd_step(double *param, const double *grad,
                          double lr, size_t n);

/* Adam: updates param, m, v in-place; beta1t/beta2t are beta^t,
 * precomputed by the caller and advanced once per step. */
GRX_API void grx_adam_step(double *param,
                           double *m, double *v,
                           const double *grad,
                           double lr, double beta1, double beta2,
                           double epsilon, double beta1t, double beta2t,
                           size_t n);

/* ---- Weight initialization --------------------------------------- */
/* Xavier uniform: U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out)) */
GRX_API void grx_init_xavier_uniform(double *out, size_t n,
                                     size_t fan_in, size_t fan_out);
/* He normal: N(0, sqrt(2/fan_in)) */
GRX_API void grx_init_he_normal(double *out, size_t n, size_t fan_in);

#ifdef __cplusplus
}
#endif

#endif /* GRX_CORE_H */
|
data/ext/unix/Makefile
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# =============================================================
# Makefile — Linux / macOS
#
# Compiles grx_core.c DIRECTLY into lib/grx/ (no intermediate file).
# There is no .so in ext/unix/ — the only binary lives in lib/grx/.
#
# Usage:
#   make       → build (auto-detects OS and SIMD support)
#   make clean → remove the binary from lib/grx/
#   make bench → build and run the benchmark
# =============================================================

CC = gcc
CFLAGS = -O3 -march=native -ffast-math -funroll-loops \
         -fPIC -fvisibility=hidden \
         -Wall -Wextra -std=c11
LDFLAGS = -lm
SRC_DIR = ../grx
SRC = $(SRC_DIR)/grx_core.c
HEADER = $(SRC_DIR)/grx_core.h

# Final destination — directly in lib/grx/
OUT_DIR = ../../lib/grx

# Pick the platform's shared-library name and linker flag.
UNAME := $(shell uname -s)
ifeq ($(UNAME), Darwin)
LIB = libgrx_core.dylib
SHARED = -dynamiclib
else
LIB = libgrx_core.so
SHARED = -shared
endif

TARGET = $(OUT_DIR)/$(LIB)

# SIMD detection: try compiling an empty program with the candidate
# flags; empty $(shell) output means the compiler accepted them.
AVX2_TEST := $(shell echo 'int main(){}' | $(CC) -mavx2 -mfma -x c - -o /dev/null 2>&1)
ifeq ($(AVX2_TEST),)
CFLAGS += -mavx2 -mfma
$(info [GRX] AVX2 + FMA habilitados — máxima velocidad SIMD)
else
SSE2_TEST := $(shell echo 'int main(){}' | $(CC) -msse2 -x c - -o /dev/null 2>&1)
ifeq ($(SSE2_TEST),)
CFLAGS += -msse2
$(info [GRX] SSE2 habilitado)
else
$(info [GRX] Sin SIMD — modo escalar)
endif
endif

.PHONY: all clean bench

all: $(TARGET)

$(TARGET): $(SRC) $(HEADER)
	@mkdir -p $(OUT_DIR)
	$(CC) $(CFLAGS) $(SHARED) $(SRC) -o $(TARGET) $(LDFLAGS)
	@echo "[GRX] Compilado → $(TARGET)"

bench: all
	@echo "[GRX] Corriendo benchmark..."
	ruby -I../../lib ../../test/benchmark.rb

clean:
	rm -f $(TARGET)
	@echo "[GRX] Limpiado: $(TARGET)"
|