grx-tensor 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,534 @@
1
+ /*
2
+ * grx_core.c — Núcleo C de GRX
3
+ * =============================================================
4
+ * Optimizaciones activas:
5
+ * - AVX2 + FMA: 4 doubles/ciclo con multiply-add fusionado
6
+ * - Loop unrolling x2: mayor ILP (Instruction Level Parallelism)
7
+ * - restrict: elimina alias analysis, permite más vectorización auto
8
+ * - Memoria alineada 32 bytes: habilita _mm256_load_pd (más rápido que loadu)
9
+ * - matmul con tiling: respeta líneas de caché L1 (64 bytes = 8 doubles)
10
+ * - Adam con FMA: beta*m + (1-beta)*grad en una pasada
11
+ * =============================================================
12
+ */
13
+
14
+ #define _USE_MATH_DEFINES /* M_PI en Windows/MSVC */
15
+ #define _POSIX_C_SOURCE 200809L /* posix_memalign, M_PI en glibc */
16
+ #include "grx_core.h"
17
+ #include <stdlib.h>
18
+ #include <stdint.h>
19
+ #include <string.h>
20
+ #include <math.h>
21
+ #include <float.h>
22
+ #include <time.h>
23
+
24
+ #ifndef M_PI
25
+ #define M_PI 3.14159265358979323846
26
+ #endif
27
+
28
+ #if defined(__AVX2__) && defined(__FMA__)
29
+ #include <immintrin.h>
30
+ #define GRX_AVX2_FMA 1
31
+ #elif defined(__AVX2__)
32
+ #include <immintrin.h>
33
+ #define GRX_AVX2 1
34
+ #elif defined(__SSE2__)
35
+ #include <emmintrin.h>
36
+ #define GRX_SSE2 1
37
+ #endif
38
+
39
+ #define TILE 8
40
+
41
+ /* ================================================================
42
+ * MEMORIA
43
+ * ================================================================ */
44
+
45
+ GRX_API double* grx_alloc(size_t n) {
46
+ if (__builtin_expect(n == 0, 0)) return NULL;
47
+ void *ptr = NULL;
48
+ #if defined(_WIN32)
49
+ ptr = _aligned_malloc(n * sizeof(double), 32);
50
+ #else
51
+ if (posix_memalign(&ptr, 32, n * sizeof(double)) != 0) return NULL;
52
+ #endif
53
+ return (double*)ptr;
54
+ }
55
+
56
/* Release a buffer obtained from grx_alloc(). NULL is a safe no-op
 * (both free() and _aligned_free() accept NULL). */
GRX_API void grx_free(double *ptr) {
#if defined(_WIN32)
    _aligned_free(ptr);
#else
    free(ptr);
#endif
}
63
+
64
+ /* ================================================================
65
+ * MACROS SIMD INTERNOS
66
+ * ================================================================ */
67
+
68
+ /* Carga/store: usa aligned si tenemos AVX2+FMA (memoria siempre alineada a 32b) */
69
+ #ifdef GRX_AVX2_FMA
70
+ #define VLD(p) _mm256_load_pd(p)
71
+ #define VST(p, v) _mm256_store_pd(p, v)
72
+ #elif defined(GRX_AVX2)
73
+ #define VLD(p) _mm256_loadu_pd(p)
74
+ #define VST(p, v) _mm256_storeu_pd(p, v)
75
+ #endif
76
+
77
+ /* ================================================================
78
+ * ELEMENT-WISE ARITMÉTICA
79
+ * ================================================================ */
80
+
81
+ #define BINOP_BODY(op_avx, op_scalar) \
82
+ size_t i = 0; \
83
+ for (; i + 8 <= n; i += 8) { \
84
+ VST(out+i, op_avx(VLD(a+i), VLD(b+i))); \
85
+ VST(out+i+4, op_avx(VLD(a+i+4), VLD(b+i+4))); \
86
+ } \
87
+ for (; i + 4 <= n; i += 4) VST(out+i, op_avx(VLD(a+i), VLD(b+i)));\
88
+ for (; i < n; i++) out[i] = op_scalar(a[i], b[i]);
89
+
90
+ #define SCALAR_ADD(x,y) ((x)+(y))
91
+ #define SCALAR_SUB(x,y) ((x)-(y))
92
+ #define SCALAR_MUL(x,y) ((x)*(y))
93
+ #define SCALAR_DIV(x,y) ((x)/(y))
94
+
95
/* out[i] = a[i] + b[i], element-wise.
 * AVX2 path: BINOP_BODY unrolls x2 (8 doubles/iteration), then a 4-wide pass,
 * then a scalar tail. `restrict` promises a, b and out do not alias. */
GRX_API void grx_add(const double * restrict a, const double * restrict b,
                     double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    BINOP_BODY(_mm256_add_pd, SCALAR_ADD)
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] + b[i];
#endif
}
103
+
104
/* out[i] = a[i] - b[i], element-wise. Same SIMD structure as grx_add. */
GRX_API void grx_sub(const double * restrict a, const double * restrict b,
                     double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    BINOP_BODY(_mm256_sub_pd, SCALAR_SUB)
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] - b[i];
#endif
}
112
+
113
/* out[i] = a[i] * b[i], element-wise (Hadamard product). */
GRX_API void grx_mul(const double * restrict a, const double * restrict b,
                     double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    BINOP_BODY(_mm256_mul_pd, SCALAR_MUL)
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] * b[i];
#endif
}
121
+
122
/* out[i] = a[i] / b[i], element-wise. No zero-divisor check: IEEE-754 rules
 * apply (x/0 → ±inf, 0/0 → NaN). */
GRX_API void grx_div(const double * restrict a, const double * restrict b,
                     double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    BINOP_BODY(_mm256_div_pd, SCALAR_DIV)
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] / b[i];
#endif
}
130
+
131
/* out[i] = a[i] * s — scalar broadcast multiply.
 * AVX2 path: s is splatted into a vector once, loop unrolled x2, scalar tail. */
GRX_API void grx_scale(const double * restrict a, double s,
                       double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    __m256d vs = _mm256_set1_pd(s);
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        VST(out+i,   _mm256_mul_pd(VLD(a+i),   vs));
        VST(out+i+4, _mm256_mul_pd(VLD(a+i+4), vs));
    }
    for (; i + 4 <= n; i += 4) VST(out+i, _mm256_mul_pd(VLD(a+i), vs));
    for (; i < n; i++) out[i] = a[i] * s;
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] * s;
#endif
}
146
+
147
/* out[i] = a[i] + s — scalar broadcast add. Same structure as grx_scale. */
GRX_API void grx_add_scalar(const double * restrict a, double s,
                            double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    __m256d vs = _mm256_set1_pd(s);
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        VST(out+i,   _mm256_add_pd(VLD(a+i),   vs));
        VST(out+i+4, _mm256_add_pd(VLD(a+i+4), vs));
    }
    for (; i + 4 <= n; i += 4) VST(out+i, _mm256_add_pd(VLD(a+i), vs));
    for (; i < n; i++) out[i] = a[i] + s;
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] + s;
#endif
}
162
+
163
+ GRX_API void grx_negate(const double * restrict a, double * restrict out, size_t n) {
164
+ grx_scale(a, -1.0, out, n);
165
+ }
166
+
167
+ /* ================================================================
168
+ * MATEMÁTICAS ELEMENT-WISE
169
+ * ================================================================ */
170
+
171
/* out[i] = |a[i]|, element-wise. */
GRX_API void grx_abs(const double * restrict a, double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    /* Mask off the sign bit: AND with 0x7FFFFFFFFFFFFFFF clears bit 63 of
     * each double — branchless fabs for 4 lanes at once. */
    __m256d mask = _mm256_castsi256_pd(
        _mm256_set1_epi64x(0x7FFFFFFFFFFFFFFFLL));
    size_t i = 0;
    for (; i + 4 <= n; i += 4)
        VST(out+i, _mm256_and_pd(VLD(a+i), mask));
    for (; i < n; i++) out[i] = fabs(a[i]);
#else
    for (size_t i = 0; i < n; i++) out[i] = fabs(a[i]);
#endif
}
184
+
185
/* out[i] = sqrt(a[i]), element-wise. Negative inputs yield NaN (IEEE-754). */
GRX_API void grx_sqrt(const double * restrict a, double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    size_t i = 0;
    for (; i + 4 <= n; i += 4)
        VST(out+i, _mm256_sqrt_pd(VLD(a+i)));
    for (; i < n; i++) out[i] = sqrt(a[i]);
#else
    for (size_t i = 0; i < n; i++) out[i] = sqrt(a[i]);
#endif
}
195
+
196
/* out[i] = a[i]^2 — reuses the element-wise multiply with both read-only
 * operands pointing at the same buffer (legal under restrict: neither is
 * written through). */
GRX_API void grx_square(const double * restrict a, double * restrict out, size_t n) {
    grx_mul(a, a, out, n);
}
199
+
200
/* out[i] = ln(a[i]), element-wise.
 * log has no standard SIMD intrinsic; -ffast-math + -march=native lets the
 * compiler auto-vectorize via a vector math library (e.g. SVML) if available. */
GRX_API void grx_log(const double * restrict a, double * restrict out, size_t n) {
    for (size_t i = 0; i < n; i++) out[i] = log(a[i]);
}
205
+
206
/* out[i] = e^a[i], element-wise (scalar libm; see grx_log for why). */
GRX_API void grx_exp(const double * restrict a, double * restrict out, size_t n) {
    for (size_t i = 0; i < n; i++) out[i] = exp(a[i]);
}
209
+
210
/* out[i] = a[i]^e, element-wise, with a single shared exponent e. */
GRX_API void grx_pow(const double * restrict a, double e,
                     double * restrict out, size_t n) {
    for (size_t i = 0; i < n; i++) out[i] = pow(a[i], e);
}
214
+
215
/* out[i] = clamp(a[i], lo, hi). Assumes lo <= hi (not validated).
 * AVX2 path: min(max(v, lo), hi) is branchless. */
GRX_API void grx_clip(const double * restrict a, double lo, double hi,
                      double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    __m256d vlo = _mm256_set1_pd(lo);
    __m256d vhi = _mm256_set1_pd(hi);
    size_t i = 0;
    for (; i + 4 <= n; i += 4)
        VST(out+i, _mm256_min_pd(_mm256_max_pd(VLD(a+i), vlo), vhi));
    for (; i < n; i++) out[i] = a[i] < lo ? lo : (a[i] > hi ? hi : a[i]);
#else
    for (size_t i = 0; i < n; i++)
        out[i] = a[i] < lo ? lo : (a[i] > hi ? hi : a[i]);
#endif
}
229
+
230
+ /* ================================================================
231
+ * REDUCCIONES
232
+ * ================================================================ */
233
+
234
+ GRX_API double grx_sum(const double * restrict a, size_t n) {
235
+ double acc = 0.0;
236
+ #ifdef GRX_AVX2_FMA
237
+ __m256d v0 = _mm256_setzero_pd(), v1 = _mm256_setzero_pd();
238
+ size_t i = 0;
239
+ for (; i + 8 <= n; i += 8) {
240
+ v0 = _mm256_add_pd(v0, VLD(a+i));
241
+ v1 = _mm256_add_pd(v1, VLD(a+i+4));
242
+ }
243
+ v0 = _mm256_add_pd(v0, v1);
244
+ for (; i + 4 <= n; i += 4) v0 = _mm256_add_pd(v0, VLD(a+i));
245
+ double tmp[4]; _mm256_store_pd(tmp, v0);
246
+ acc = tmp[0] + tmp[1] + tmp[2] + tmp[3];
247
+ for (; i < n; i++) acc += a[i];
248
+ #elif defined(GRX_AVX2)
249
+ __m256d vacc = _mm256_setzero_pd();
250
+ size_t i = 0;
251
+ for (; i + 4 <= n; i += 4) vacc = _mm256_add_pd(vacc, VLD(a+i));
252
+ double tmp[4]; _mm256_storeu_pd(tmp, vacc);
253
+ acc = tmp[0] + tmp[1] + tmp[2] + tmp[3];
254
+ for (; i < n; i++) acc += a[i];
255
+ #else
256
+ for (size_t i = 0; i < n; i++) acc += a[i];
257
+ #endif
258
+ return acc;
259
+ }
260
+
261
+ GRX_API double grx_mean(const double * restrict a, size_t n) {
262
+ return n > 0 ? grx_sum(a, n) / (double)n : 0.0;
263
+ }
264
+
265
+ GRX_API double grx_max(const double * restrict a, size_t n) {
266
+ if (n == 0) return -DBL_MAX;
267
+ double m = a[0];
268
+ for (size_t i = 1; i < n; i++) if (a[i] > m) m = a[i];
269
+ return m;
270
+ }
271
+
272
+ GRX_API double grx_min(const double * restrict a, size_t n) {
273
+ if (n == 0) return DBL_MAX;
274
+ double m = a[0];
275
+ for (size_t i = 1; i < n; i++) if (a[i] < m) m = a[i];
276
+ return m;
277
+ }
278
+
279
+ /* ================================================================
280
+ * ÁLGEBRA LINEAL
281
+ * ================================================================ */
282
+
283
+ GRX_API double grx_dot(const double * restrict a, const double * restrict b, size_t n) {
284
+ double acc = 0.0;
285
+ #ifdef GRX_AVX2_FMA
286
+ __m256d v0 = _mm256_setzero_pd(), v1 = _mm256_setzero_pd();
287
+ size_t i = 0;
288
+ for (; i + 8 <= n; i += 8) {
289
+ v0 = _mm256_fmadd_pd(VLD(a+i), VLD(b+i), v0);
290
+ v1 = _mm256_fmadd_pd(VLD(a+i+4), VLD(b+i+4), v1);
291
+ }
292
+ v0 = _mm256_add_pd(v0, v1);
293
+ for (; i + 4 <= n; i += 4) v0 = _mm256_fmadd_pd(VLD(a+i), VLD(b+i), v0);
294
+ double tmp[4]; _mm256_store_pd(tmp, v0);
295
+ acc = tmp[0] + tmp[1] + tmp[2] + tmp[3];
296
+ for (; i < n; i++) acc += a[i] * b[i];
297
+ #elif defined(GRX_AVX2)
298
+ __m256d vacc = _mm256_setzero_pd();
299
+ size_t i = 0;
300
+ for (; i + 4 <= n; i += 4)
301
+ vacc = _mm256_add_pd(vacc, _mm256_mul_pd(VLD(a+i), VLD(b+i)));
302
+ double tmp[4]; _mm256_storeu_pd(tmp, vacc);
303
+ acc = tmp[0] + tmp[1] + tmp[2] + tmp[3];
304
+ for (; i < n; i++) acc += a[i] * b[i];
305
+ #else
306
+ for (size_t i = 0; i < n; i++) acc += a[i] * b[i];
307
+ #endif
308
+ return acc;
309
+ }
310
+
311
/* Cache-tiled matrix multiply: out = A·B, row-major.
 * A is M×K, B is K×N, out is M×N; out must not alias a or b (restrict).
 * TILE = 8 doubles = one 64-byte L1 cache line. The i→k→j loop order inside
 * each tile streams rows of b and out sequentially, and the A element aik is
 * hoisted as a loop invariant of the innermost loop. out is zeroed first so
 * tiles can accumulate with +=. */
GRX_API void grx_matmul(const double * restrict a, const double * restrict b,
                        double * restrict out, size_t M, size_t K, size_t N) {
    memset(out, 0, M * N * sizeof(double));
    for (size_t ii = 0; ii < M; ii += TILE) {
        size_t ie = ii + TILE < M ? ii + TILE : M;
        for (size_t kk = 0; kk < K; kk += TILE) {
            size_t ke = kk + TILE < K ? kk + TILE : K;
            for (size_t jj = 0; jj < N; jj += TILE) {
                size_t je = jj + TILE < N ? jj + TILE : N;
                for (size_t i = ii; i < ie; i++)
                    for (size_t k = kk; k < ke; k++) {
                        double aik = a[i*K+k];  /* invariant of the j loop */
                        for (size_t j = jj; j < je; j++)
                            out[i*N+j] += aik * b[k*N+j];
                    }
            }
        }
    }
}
331
+
332
+ /* ================================================================
333
+ * ACTIVACIONES
334
+ * ================================================================ */
335
+
336
/* ReLU: out[i] = max(a[i], 0). Branchless via vector max against zero. */
GRX_API void grx_relu(const double * restrict a, double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    __m256d vz = _mm256_setzero_pd();
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        VST(out+i,   _mm256_max_pd(VLD(a+i),   vz));
        VST(out+i+4, _mm256_max_pd(VLD(a+i+4), vz));
    }
    for (; i + 4 <= n; i += 4) VST(out+i, _mm256_max_pd(VLD(a+i), vz));
    for (; i < n; i++) out[i] = a[i] > 0.0 ? a[i] : 0.0;
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] > 0.0 ? a[i] : 0.0;
#endif
}
350
+
351
/* Leaky ReLU: out[i] = a[i] if a[i] > 0, else alpha * a[i]. */
GRX_API void grx_leaky_relu(const double * restrict a, double alpha,
                            double * restrict out, size_t n) {
#if defined(GRX_AVX2_FMA) || defined(GRX_AVX2)
    __m256d va = _mm256_set1_pd(alpha);
    __m256d vz = _mm256_setzero_pd();
    size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        __m256d v = VLD(a+i);
        /* blendv selects per lane: where v > 0 take v, elsewhere alpha*v */
        VST(out+i, _mm256_blendv_pd(_mm256_mul_pd(v, va), v,
                                    _mm256_cmp_pd(v, vz, _CMP_GT_OQ)));
    }
    for (; i < n; i++) out[i] = a[i] > 0.0 ? a[i] : alpha * a[i];
#else
    for (size_t i = 0; i < n; i++) out[i] = a[i] > 0.0 ? a[i] : alpha * a[i];
#endif
}
368
+
369
+ GRX_API void grx_tanh_act(const double * restrict a, double * restrict out, size_t n) {
370
+ for (size_t i = 0; i < n; i++) out[i] = tanh(a[i]);
371
+ }
372
+
373
+ GRX_API void grx_sigmoid(const double * restrict a, double * restrict out, size_t n) {
374
+ for (size_t i = 0; i < n; i++) out[i] = 1.0 / (1.0 + exp(-a[i]));
375
+ }
376
+
377
+ GRX_API void grx_softmax(const double * restrict a, double * restrict out, size_t n) {
378
+ double max_val = grx_max(a, n);
379
+ double sum = 0.0;
380
+ for (size_t i = 0; i < n; i++) { out[i] = exp(a[i] - max_val); sum += out[i]; }
381
+ double inv = 1.0 / sum;
382
+ for (size_t i = 0; i < n; i++) out[i] *= inv;
383
+ }
384
+
385
+ /* ================================================================
386
+ * OPTIMIZADORES (in-place sobre parámetros)
387
+ * ================================================================ */
388
+
389
/* SGD update, in place: param[i] -= lr * grad[i].
 * FMA path: fnmadd computes -(lr*grad) + param in one instruction, x2 unroll. */
GRX_API void grx_sgd_step(double * restrict param, const double * restrict grad,
                          double lr, size_t n) {
#ifdef GRX_AVX2_FMA
    __m256d vlr = _mm256_set1_pd(lr);
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        /* param -= lr * grad via FMA: param = -lr*grad + param */
        VST(param+i,   _mm256_fnmadd_pd(vlr, VLD(grad+i),   VLD(param+i)));
        VST(param+i+4, _mm256_fnmadd_pd(vlr, VLD(grad+i+4), VLD(param+i+4)));
    }
    for (; i + 4 <= n; i += 4)
        VST(param+i, _mm256_fnmadd_pd(vlr, VLD(grad+i), VLD(param+i)));
    for (; i < n; i++) param[i] -= lr * grad[i];
#else
    for (size_t i = 0; i < n; i++) param[i] -= lr * grad[i];
#endif
}
407
+
408
/*
 * Adam optimizer step (Kingma & Ba 2015), in place on param, m and v:
 *   m     = beta1*m + (1-beta1)*grad
 *   v     = beta2*v + (1-beta2)*grad^2
 *   m_hat = m / (1 - beta1^t)
 *   v_hat = v / (1 - beta2^t)
 *   param -= lr * m_hat / (sqrt(v_hat) + eps)
 *
 * beta1t = beta1^t and beta2t = beta2^t are passed in pre-computed (the
 * Ruby side advances them each step) so this kernel stays stateless.
 */
GRX_API void grx_adam_step(double * restrict param,
                           double * restrict m, double * restrict v,
                           const double * restrict grad,
                           double lr, double beta1, double beta2,
                           double epsilon, double beta1t, double beta2t,
                           size_t n) {
    /* Hoist per-call scalars: complements and bias-correction reciprocals. */
    double one_m_b1  = 1.0 - beta1;
    double one_m_b2  = 1.0 - beta2;
    double inv_1mb1t = 1.0 / (1.0 - beta1t);
    double inv_1mb2t = 1.0 / (1.0 - beta2t);

#ifdef GRX_AVX2_FMA
    /* Broadcast every scalar once, outside the loop. */
    __m256d vb1    = _mm256_set1_pd(beta1);
    __m256d vb2    = _mm256_set1_pd(beta2);
    __m256d v1mb1  = _mm256_set1_pd(one_m_b1);
    __m256d v1mb2  = _mm256_set1_pd(one_m_b2);
    __m256d vlr    = _mm256_set1_pd(lr);
    __m256d veps   = _mm256_set1_pd(epsilon);
    __m256d vi1b1t = _mm256_set1_pd(inv_1mb1t);
    __m256d vi2b2t = _mm256_set1_pd(inv_1mb2t);

    size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        __m256d g = VLD(grad+i);
        /* m = beta1*m + (1-beta1)*g  (single FMA) */
        __m256d mi = _mm256_fmadd_pd(vb1, VLD(m+i), _mm256_mul_pd(v1mb1, g));
        /* v = beta2*v + (1-beta2)*g^2 */
        __m256d vi = _mm256_fmadd_pd(vb2, VLD(v+i),
                                     _mm256_mul_pd(v1mb2, _mm256_mul_pd(g, g)));
        VST(m+i, mi);
        VST(v+i, vi);
        /* Bias correction: m_hat = m/(1-beta1^t), v_hat = v/(1-beta2^t) */
        __m256d mh = _mm256_mul_pd(mi, vi1b1t);
        __m256d vh = _mm256_mul_pd(vi, vi2b2t);
        /* param -= lr * m_hat / (sqrt(v_hat) + eps)  (fnmadd negates) */
        __m256d denom = _mm256_add_pd(_mm256_sqrt_pd(vh), veps);
        VST(param+i, _mm256_fnmadd_pd(vlr, _mm256_div_pd(mh, denom), VLD(param+i)));
    }
    /* Scalar tail (n not a multiple of 4). */
    for (; i < n; i++) {
        m[i] = beta1 * m[i] + one_m_b1 * grad[i];
        v[i] = beta2 * v[i] + one_m_b2 * grad[i] * grad[i];
        double mh = m[i] * inv_1mb1t;
        double vh = v[i] * inv_1mb2t;
        param[i] -= lr * mh / (sqrt(vh) + epsilon);
    }
#else
    for (size_t i = 0; i < n; i++) {
        m[i] = beta1 * m[i] + one_m_b1 * grad[i];
        v[i] = beta2 * v[i] + one_m_b2 * grad[i] * grad[i];
        double mh = m[i] * inv_1mb1t;
        double vh = v[i] * inv_1mb2t;
        param[i] -= lr * mh / (sqrt(vh) + epsilon);
    }
#endif
}
474
+
475
+ /* ================================================================
476
+ * INICIALIZACIÓN DE PESOS
477
+ * ================================================================ */
478
+
479
+ /* LCG simple (no criptográfico, pero rápido y sin dependencias) */
480
+ static uint64_t grx_rng_state = 0;
481
+
482
+ static void grx_rng_seed(void) {
483
+ grx_rng_state = (uint64_t)time(NULL) ^ (uint64_t)(uintptr_t)&grx_rng_state;
484
+ }
485
+
486
/* Uniform double in [0, 1): one xorshift64 step, then the top 53 bits of the
 * state (a double's full mantissa width) scaled by 2^-53.
 * NOTE(review): if grx_rng_state is ever 0, xorshift stays at 0 — relies on
 * grx_rng_seed() producing a non-zero seed. */
static double grx_rand01(void) {
    /* xorshift64 */
    grx_rng_state ^= grx_rng_state << 13;
    grx_rng_state ^= grx_rng_state >> 7;
    grx_rng_state ^= grx_rng_state << 17;
    return (double)(grx_rng_state >> 11) / (double)(1ULL << 53);
}
494
+
495
/* Box-Muller transform: produce a pair of independent N(0,1) samples from
 * two uniforms. u1 is rejected below 1e-15 so log(u1) stays finite. */
static void grx_box_muller(double *z0, double *z1) {
    double u1, u2;
    do { u1 = grx_rand01(); } while (u1 < 1e-15);
    u2 = grx_rand01();
    double r = sqrt(-2.0 * log(u1));
    *z0 = r * cos(2.0 * M_PI * u2);
    *z1 = r * sin(2.0 * M_PI * u2);
}
504
+
505
/* Xavier/Glorot uniform init: out[i] ~ U(-limit, limit) with
 * limit = sqrt(6 / (fan_in + fan_out)) — keeps activation variance roughly
 * constant across layers.
 * NOTE(review): grx_rng_seed() is called on every invocation; with a
 * time-based seed, confirm back-to-back calls can't repeat sequences. */
GRX_API void grx_init_xavier_uniform(double *out, size_t n,
                                     size_t fan_in, size_t fan_out) {
    grx_rng_seed();
    double limit = sqrt(6.0 / (double)(fan_in + fan_out));
    for (size_t i = 0; i < n; i++)
        out[i] = (grx_rand01() * 2.0 - 1.0) * limit;
}
512
+
513
/* He/Kaiming normal init: out[i] ~ N(0, sqrt(2/fan_in)) — the recommended
 * scaling for ReLU networks. Box-Muller yields samples in pairs; for odd n
 * the final pair's second sample is discarded.
 * NOTE(review): see grx_init_xavier_uniform about the per-call reseed. */
GRX_API void grx_init_he_normal(double *out, size_t n, size_t fan_in) {
    grx_rng_seed();
    double std = sqrt(2.0 / (double)fan_in);
    size_t i = 0;
    for (; i + 1 < n; i += 2) {
        double z0, z1;
        grx_box_muller(&z0, &z1);
        out[i]   = z0 * std;
        out[i+1] = z1 * std;
    }
    if (i < n) {  /* odd n: one sample left */
        double z0, z1;
        grx_box_muller(&z0, &z1);
        out[i] = z0 * std;
    }
}
529
+
530
/* ============================================================
 * RUBY EXTENSION INIT — required by rake-compiler / mkmf.
 * Intentionally empty: the library is loaded via Fiddle (dlopen),
 * not as a native Ruby extension.
 * ============================================================ */
void Init_grx_core(void) { /* no-op */ }
@@ -0,0 +1,85 @@
1
/*
 * grx_core.h — Public API of the GRX C core
 * =============================================================
 */

#ifndef GRX_CORE_H
#define GRX_CORE_H

#include <stddef.h>

#ifdef _WIN32
#define GRX_API __declspec(dllexport)
#else
#define GRX_API __attribute__((visibility("default")))
#endif

#ifdef __cplusplus
extern "C" {
#endif

/* ---- Aligned memory (32-byte, AVX-ready) ------------------------- */
GRX_API double* grx_alloc(size_t n);
GRX_API void    grx_free(double *ptr);

/* ---- Element-wise arithmetic ------------------------------------- */
GRX_API void grx_add   (const double *a, const double *b, double *out, size_t n);
GRX_API void grx_sub   (const double *a, const double *b, double *out, size_t n);
GRX_API void grx_mul   (const double *a, const double *b, double *out, size_t n);
GRX_API void grx_div   (const double *a, const double *b, double *out, size_t n);
GRX_API void grx_scale (const double *a, double s, double *out, size_t n);
GRX_API void grx_negate(const double *a, double *out, size_t n);
GRX_API void grx_add_scalar(const double *a, double s, double *out, size_t n);

/* ---- Element-wise math ------------------------------------------- */
GRX_API void grx_abs    (const double *a, double *out, size_t n);
GRX_API void grx_sqrt   (const double *a, double *out, size_t n);
GRX_API void grx_log    (const double *a, double *out, size_t n);
GRX_API void grx_exp    (const double *a, double *out, size_t n);
GRX_API void grx_pow    (const double *a, double exp, double *out, size_t n);
GRX_API void grx_clip   (const double *a, double lo, double hi, double *out, size_t n);
GRX_API void grx_square (const double *a, double *out, size_t n);

/* ---- Reductions --------------------------------------------------- */
GRX_API double grx_sum (const double *a, size_t n);
GRX_API double grx_mean(const double *a, size_t n);
GRX_API double grx_max (const double *a, size_t n);
GRX_API double grx_min (const double *a, size_t n);

/* ---- Linear algebra ----------------------------------------------- */
GRX_API double grx_dot    (const double *a, const double *b, size_t n);
GRX_API void   grx_matmul (const double *a, const double *b, double *out,
                           size_t M, size_t K, size_t N);

/* ---- Activations --------------------------------------------------- */
GRX_API void grx_relu       (const double *a, double *out, size_t n);
GRX_API void grx_leaky_relu (const double *a, double alpha, double *out, size_t n);
GRX_API void grx_tanh_act   (const double *a, double *out, size_t n);
GRX_API void grx_sigmoid    (const double *a, double *out, size_t n);
GRX_API void grx_softmax    (const double *a, double *out, size_t n);

/* ---- Optimizers (in-place) ----------------------------------------- */
/* SGD: param -= lr * grad */
GRX_API void grx_sgd_step(double *param, const double *grad,
                          double lr, size_t n);

/* Adam: updates param, m, v in place; beta1t/beta2t are beta^t, precomputed */
GRX_API void grx_adam_step(double *param,
                           double *m, double *v,
                           const double *grad,
                           double lr, double beta1, double beta2,
                           double epsilon, double beta1t, double beta2t,
                           size_t n);

/* ---- Weight initialization ----------------------------------------- */
/* Xavier uniform: U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out)) */
GRX_API void grx_init_xavier_uniform(double *out, size_t n,
                                     size_t fan_in, size_t fan_out);
/* He normal: N(0, sqrt(2/fan_in)) */
GRX_API void grx_init_he_normal(double *out, size_t n, size_t fan_in);

#ifdef __cplusplus
}
#endif

#endif /* GRX_CORE_H */
data/ext/unix/Makefile ADDED
@@ -0,0 +1,66 @@
1
# =============================================================
# Makefile — Linux / macOS
#
# Builds grx_core.c DIRECTLY into lib/grx/ (no intermediate objects).
# No .so lives in ext/unix/ — the only binary artifact is in lib/grx/.
#
# Usage:
#   make         → build (OS and SIMD detected automatically)
#   make clean   → remove the binary from lib/grx/
#   make bench   → build, then run the Ruby benchmark
# =============================================================

CC      ?= gcc
CFLAGS   = -O3 -march=native -ffast-math -funroll-loops \
           -fPIC -fvisibility=hidden \
           -Wall -Wextra -std=c11
LDFLAGS  = -lm
SRC_DIR  = ../grx
SRC      = $(SRC_DIR)/grx_core.c
HEADER   = $(SRC_DIR)/grx_core.h

# Final destination — directly in lib/grx/
OUT_DIR  = ../../lib/grx

UNAME := $(shell uname -s)
ifeq ($(UNAME), Darwin)
LIB    = libgrx_core.dylib
SHARED = -dynamiclib
else
LIB    = libgrx_core.so
SHARED = -shared
endif

TARGET = $(OUT_DIR)/$(LIB)

# SIMD detection: ask the compiler which macros -march=native defines for
# THIS machine's CPU. (The previous check only verified the compiler
# *accepts* -mavx2 -mfma — true on virtually every modern gcc/clang — which
# force-enabled AVX2 code paths and crashed with SIGILL on CPUs without AVX2.)
NATIVE_DEFS := $(shell $(CC) -march=native -dM -E -x c /dev/null 2>/dev/null)
ifneq ($(filter __AVX2__,$(NATIVE_DEFS)),)
  ifneq ($(filter __FMA__,$(NATIVE_DEFS)),)
    $(info [GRX] AVX2 + FMA enabled — maximum SIMD speed)
  else
    $(info [GRX] AVX2 enabled (CPU lacks FMA))
  endif
else ifneq ($(filter __SSE2__,$(NATIVE_DEFS)),)
  $(info [GRX] SSE2 enabled)
else
  $(info [GRX] No SIMD — scalar mode)
endif
# -march=native in CFLAGS already emits the right ISA; no extra -m flags needed.

.PHONY: all clean bench
.DELETE_ON_ERROR:

all: $(TARGET)

$(TARGET): $(SRC) $(HEADER)
	@mkdir -p $(OUT_DIR)
	$(CC) $(CFLAGS) $(SHARED) $(SRC) -o $@ $(LDFLAGS)
	@echo "[GRX] Compiled → $@"

bench: all
	@echo "[GRX] Running benchmark..."
	ruby -I../../lib ../../test/benchmark.rb

clean:
	$(RM) $(TARGET)
	@echo "[GRX] Cleaned: $(TARGET)"